mirror of https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00

- Add debug endpoint for system stats
- Adjust headers, generate from phonemes, etc.

This commit is contained in:
parent 2e318051f8
commit f61f79981d

19 changed files with 686 additions and 659 deletions
.gitignore vendored: 28 lines changed

@@ -39,19 +39,9 @@ ENV/
 *.pth
 *.tar*
 
-# Audio files
-examples/*.wav
-examples/*.pcm
-examples/*.mp3
-examples/*.flac
-examples/*.acc
-examples/*.ogg
-examples/speech.mp3
-examples/phoneme_examples/output/example_1.wav
-examples/phoneme_examples/output/example_2.wav
-examples/phoneme_examples/output/example_3.wav
-
 # Other project files
+.env
 Kokoro-82M/
 ui/data/
 EXTERNAL_UV_DOCUMENTATION*
@@ -61,10 +51,20 @@ api/temp_files/
 # Docker
 Dockerfile*
 docker-compose*
-examples/assorted_checks/River_of_Teet_-_Sarah_Gailey.epub
 examples/ebook_test/chapter_to_audio.py
 examples/ebook_test/chapters_to_audio.py
 examples/ebook_test/parse_epub.py
-examples/ebook_test/River_of_Teet_-_Sarah_Gailey.epub
-examples/ebook_test/River_of_Teet_-_Sarah_Gailey.txt
 api/src/voices/af_jadzia.pt
+examples/assorted_checks/test_combinations/output/*
+examples/assorted_checks/test_openai/output/*
+
+# Audio files
+examples/*.wav
+examples/*.pcm
+examples/*.mp3
+examples/*.flac
+examples/*.acc
+examples/*.ogg
+examples/speech.mp3
+examples/phoneme_examples/output/*.wav
@@ -38,7 +38,6 @@ class Settings(BaseSettings):
     max_temp_dir_age_hours: int = 1  # Remove temp files older than 1 hour
     max_temp_dir_count: int = 3  # Maximum number of temp files to keep
-
 
     class Config:
         env_file = ".env"
@@ -128,7 +128,8 @@ class BaseSessionPool:
         # Check if we can create new session
         if len(self._sessions) >= self._max_size:
             raise RuntimeError(
-                f"Maximum number of sessions reached ({self._max_size})"
+                f"Maximum number of sessions reached ({self._max_size}). "
+                "Try again later or reduce concurrent requests."
             )

         # Create new session
@@ -132,3 +132,57 @@ async def get_system_info():
         "network": network_info,
         "gpu": gpu_info
     }
+
+
+@router.get("/debug/session_pools")
+async def get_session_pool_info():
+    """Get information about ONNX session pools."""
+    from ..inference.model_manager import get_manager
+
+    manager = await get_manager()
+    pools = manager._session_pools
+    current_time = time.time()
+
+    pool_info = {}
+
+    # Get CPU pool info
+    if 'onnx_cpu' in pools:
+        cpu_pool = pools['onnx_cpu']
+        pool_info['cpu'] = {
+            "active_sessions": len(cpu_pool._sessions),
+            "max_sessions": cpu_pool._max_size,
+            "sessions": [{
+                "model": path,
+                "age_seconds": current_time - info.last_used
+            } for path, info in cpu_pool._sessions.items()]
+        }
+
+    # Get GPU pool info
+    if 'onnx_gpu' in pools:
+        gpu_pool = pools['onnx_gpu']
+        pool_info['gpu'] = {
+            "active_sessions": len(gpu_pool._sessions),
+            "max_streams": gpu_pool._max_size,
+            "available_streams": len(gpu_pool._available_streams),
+            "sessions": [{
+                "model": path,
+                "age_seconds": current_time - info.last_used,
+                "stream_id": info.stream_id
+            } for path, info in gpu_pool._sessions.items()]
+        }
+
+    # Add GPU memory info if available
+    if GPU_AVAILABLE:
+        try:
+            gpus = GPUtil.getGPUs()
+            if gpus:
+                gpu = gpus[0]  # Assume first GPU
+                pool_info['gpu']['memory'] = {
+                    "total_mb": gpu.memoryTotal,
+                    "used_mb": gpu.memoryUsed,
+                    "free_mb": gpu.memoryFree,
+                    "percent_used": (gpu.memoryUsed / gpu.memoryTotal) * 100
+                }
+        except Exception:
+            pass
+
+    return pool_info
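A quick way to exercise the new endpoint once the server is up. This snippet is an illustrative sketch, not part of the commit; the host and port follow the `localhost:8880` used in the repo's `.http` examples later in this diff:

```python
import requests

# Query the new session-pool debug endpoint (assumes a local server on port 8880).
info = requests.get("http://localhost:8880/debug/session_pools").json()
for device in ("cpu", "gpu"):
    if device in info:
        print(f"{device}: {info[device]['active_sessions']} active sessions")
```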
@@ -3,10 +3,13 @@ from typing import List
 import numpy as np
 import torch
 from fastapi import APIRouter, Depends, HTTPException, Request, Response
+from fastapi.responses import StreamingResponse
 from loguru import logger

-from ..services.audio import AudioService
-from ..services.text_processing import phonemize, tokenize
+from ..services.audio import AudioService, AudioNormalizer
+from ..services.streaming_audio_writer import StreamingAudioWriter
+from ..services.text_processing import phonemize, smart_split
+from ..services.text_processing.vocabulary import tokenize
 from ..services.tts_service import TTSService
 from ..structures.text_schemas import (
     GenerateFromPhonemesRequest,
@@ -21,8 +24,6 @@ async def get_tts_service() -> TTSService:
     """Dependency to get TTSService instance"""
     return await TTSService.create()  # Create service with properly initialized managers


-@router.post("/text/phonemize", response_model=PhonemeResponse, tags=["deprecated"])
 @router.post("/dev/phonemize", response_model=PhonemeResponse)
 async def phonemize_text(request: PhonemeRequest) -> PhonemeResponse:
     """Convert text to phonemes and tokens
@@ -56,108 +57,95 @@ async def phonemize_text(request: PhonemeRequest) -> PhonemeResponse:
         raise HTTPException(
             status_code=500, detail={"error": "Server error", "message": str(e)}
         )


-@router.post("/text/generate_from_phonemes", tags=["deprecated"])
 @router.post("/dev/generate_from_phonemes")
 async def generate_from_phonemes(
     request: GenerateFromPhonemesRequest,
+    client_request: Request,
     tts_service: TTSService = Depends(get_tts_service),
-) -> Response:
-    """Generate audio directly from phonemes
-
-    Args:
-        request: Request containing phonemes and generation parameters
-        tts_service: Injected TTSService instance
-
-    Returns:
-        WAV audio bytes
-    """
-    # Validate phonemes first
-    if not request.phonemes:
-        raise HTTPException(
-            status_code=400,
-            detail={"error": "Invalid request", "message": "Phonemes cannot be empty"},
-        )
-
+) -> StreamingResponse:
+    """Generate audio directly from phonemes with proper streaming"""
     try:
-        # Validate voice exists
-        available_voices = await tts_service.list_voices()
-        if request.voice not in available_voices:
-            raise HTTPException(
-                status_code=400,
-                detail={
-                    "error": "Invalid request",
-                    "message": f"Voice not found: {request.voice}",
-                },
-            )
-
-        # Handle both single string and list of chunks
-        phoneme_chunks = [request.phonemes] if isinstance(request.phonemes, str) else request.phonemes
-        audio_chunks = []
-
-        # Load voice tensor first since we'll need it for all chunks
-        voice_tensor = await tts_service._voice_manager.load_voice(
-            request.voice,
-            device=tts_service.model_manager.get_backend().device
-        )
-
-        try:
-            # Process each chunk
-            for chunk in phoneme_chunks:
-                # Convert chunk to tokens
-                tokens = tokenize(chunk)
-                tokens = [0] + tokens + [0]  # Add start/end tokens
-
-                # Validate chunk length
-                if len(tokens) > 510:  # 510 to leave room for start/end tokens
-                    raise ValueError(
-                        f"Chunk too long ({len(tokens)} tokens). Each chunk must be under 510 tokens."
-                    )
-
-                # Generate audio for chunk
-                chunk_audio = await tts_service.model_manager.generate(
-                    tokens,
-                    voice_tensor,
-                    speed=request.speed
-                )
-                if chunk_audio is not None:
-                    audio_chunks.append(chunk_audio)
-
-            # Combine chunks if needed
-            if len(audio_chunks) > 1:
-                audio = np.concatenate(audio_chunks)
-            elif len(audio_chunks) == 1:
-                audio = audio_chunks[0]
-            else:
-                raise ValueError("No audio chunks were generated")
-
-        finally:
-            # Clean up voice tensor
-            del voice_tensor
-            torch.cuda.empty_cache()
-
-        # Convert to WAV bytes
-        wav_bytes = await AudioService.convert_audio(
-            audio, 24000, "wav", is_first_chunk=True, is_last_chunk=True, stream=False,
-        )
-
-        return Response(
-            content=wav_bytes,
-            media_type="audio/wav",
-            headers={
-                "Content-Disposition": "attachment; filename=speech.wav",
-                "Cache-Control": "no-cache",
-            },
-        )
+        # Basic validation
+        if not isinstance(request.phonemes, str):
+            raise ValueError("Phonemes must be a string")
+        if not request.phonemes:
+            raise ValueError("Phonemes cannot be empty")
+
+        # Create streaming audio writer and normalizer
+        writer = StreamingAudioWriter(format="wav", sample_rate=24000, channels=1)
+        normalizer = AudioNormalizer()
+
+        async def generate_chunks():
+            try:
+                has_data = False
+                # Process phonemes in chunks
+                async for chunk_text, _ in smart_split(request.phonemes):
+                    # Check if client is still connected
+                    is_disconnected = client_request.is_disconnected
+                    if callable(is_disconnected):
+                        is_disconnected = await is_disconnected()
+                    if is_disconnected:
+                        logger.info("Client disconnected, stopping audio generation")
+                        break
+
+                    chunk_audio, _ = await tts_service.generate_from_phonemes(
+                        phonemes=chunk_text,
+                        voice=request.voice,
+                        speed=1.0
+                    )
+                    if chunk_audio is not None:
+                        has_data = True
+                        # Normalize audio before writing
+                        normalized_audio = await normalizer.normalize(chunk_audio)
+                        # Write chunk and yield bytes
+                        chunk_bytes = writer.write_chunk(normalized_audio)
+                        if chunk_bytes:
+                            yield chunk_bytes
+
+                if not has_data:
+                    raise ValueError("Failed to generate any audio data")
+
+                # Finalize and yield remaining bytes if we still have a connection
+                if not (callable(is_disconnected) and await is_disconnected()):
+                    final_bytes = writer.write_chunk(finalize=True)
+                    if final_bytes:
+                        yield final_bytes
+            except Exception as e:
+                logger.error(f"Error in audio chunk generation: {str(e)}")
+                # Clean up writer on error
+                writer.write_chunk(finalize=True)
+                # Re-raise the original exception
+                raise
+
+        return StreamingResponse(
+            generate_chunks(),
+            media_type="audio/wav",
+            headers={
+                "Content-Disposition": "attachment; filename=speech.wav",
+                "X-Accel-Buffering": "no",
+                "Cache-Control": "no-cache",
+                "Transfer-Encoding": "chunked"
+            }
+        )

     except ValueError as e:
-        logger.error(f"Invalid request: {str(e)}")
+        logger.error(f"Error generating audio: {str(e)}")
         raise HTTPException(
-            status_code=400, detail={"error": "Invalid request", "message": str(e)}
+            status_code=400,
+            detail={
+                "error": "validation_error",
+                "message": str(e),
+                "type": "invalid_request_error"
+            }
         )
     except Exception as e:
         logger.error(f"Error generating audio: {str(e)}")
         raise HTTPException(
-            status_code=500, detail={"error": "Server error", "message": str(e)}
+            status_code=500,
+            detail={
+                "error": "processing_error",
+                "message": str(e),
+                "type": "server_error"
+            }
         )
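An illustrative client sketch for the reworked endpoint, streaming the WAV bytes to disk as they arrive. The phoneme string and voice name here are placeholders, and the port follows the repo's `.http` examples; this is not part of the commit:

```python
import requests

payload = {"phonemes": "h\u02c8\u025blo\u028a w\u02c8\u025c\u02d0ld", "voice": "af_bella"}
with requests.post(
    "http://localhost:8880/dev/generate_from_phonemes", json=payload, stream=True
) as resp:
    resp.raise_for_status()
    with open("speech.wav", "wb") as f:
        # Consume the chunked response incrementally rather than buffering it all.
        for chunk in resp.iter_content(chunk_size=8192):
            f.write(chunk)
```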
@@ -156,12 +156,8 @@ async def create_speech(
     )

     try:
-        model_name = get_model_name(request.model)
+        # model_name = get_model_name(request.model)

-        # Get global service instance
         tts_service = await get_tts_service()

-        # Process voice combination and validate
         voice_to_use = await process_voices(request.voice, tts_service)

         # Set content type based on format
@@ -238,13 +234,14 @@ async def create_speech(
             audio, _ = await tts_service.generate_audio(
                 text=request.input,
                 voice=voice_to_use,
-                speed=request.speed,
-                stitch_long_output=True
+                speed=request.speed
             )

-            # Convert to requested format - removed stream parameter
+            # Convert to requested format with proper finalization
             content = await AudioService.convert_audio(
-                audio, 24000, request.response_format, is_first_chunk=True
+                audio, 24000, request.response_format,
+                is_first_chunk=True,
+                is_last_chunk=True
             )

             return Response(
@@ -105,13 +105,16 @@ class AudioService:
             )
             writer = AudioService._writers[writer_key]

-            # Write chunk or finalize
-            if is_last_chunk:
-                chunk_data = writer.write_chunk(finalize=True)
-                del AudioService._writers[writer_key]
-            else:
+            # Write audio data first
+            if len(normalized_audio) > 0:
                 chunk_data = writer.write_chunk(normalized_audio)

+            # Then finalize if this is the last chunk
+            if is_last_chunk:
+                final_data = writer.write_chunk(finalize=True)
+                del AudioService._writers[writer_key]
+                return final_data if final_data else b''
+
             return chunk_data if chunk_data else b''

         except Exception as e:
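The reordering matters for one-shot conversions: previously, `is_last_chunk=True` finalized the writer before any samples were written, so a single-call conversion produced a header-only file. A minimal sketch of the call pattern that now works (mirrors the test changes later in this diff; illustrative only):

```python
async def one_shot_wav(audio):
    # Data is written first, then the writer is finalized, so the returned
    # bytes hold both a patched header and the audio samples.
    return await AudioService.convert_audio(
        audio, 24000, "wav", is_first_chunk=True, is_last_chunk=True
    )
```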
@@ -2,7 +2,7 @@

 from io import BytesIO
 import struct
-from typing import Generator, Optional
+from typing import Optional

 import numpy as np
 import soundfile as sf
@@ -21,7 +21,7 @@ class StreamingAudioWriter:

         # Format-specific setup
         if self.format == "wav":
-            self._write_wav_header()
+            self._write_wav_header_initial()
         elif self.format in ["ogg", "opus"]:
             # For OGG/Opus, write to memory buffer
             self.writer = sf.SoundFile(
@@ -53,23 +53,21 @@ class StreamingAudioWriter:
         else:
             raise ValueError(f"Unsupported format: {format}")

-    def _write_wav_header(self) -> bytes:
-        """Write WAV header with correct streaming format"""
-        header = BytesIO()
-        header.write(b'RIFF')
-        header.write(struct.pack('<L', 0))  # Placeholder for file size
-        header.write(b'WAVE')
-        header.write(b'fmt ')
-        header.write(struct.pack('<L', 16))  # fmt chunk size
-        header.write(struct.pack('<H', 1))  # PCM format
-        header.write(struct.pack('<H', self.channels))
-        header.write(struct.pack('<L', self.sample_rate))
-        header.write(struct.pack('<L', self.sample_rate * self.channels * 2))  # Byte rate
-        header.write(struct.pack('<H', self.channels * 2))  # Block align
-        header.write(struct.pack('<H', 16))  # Bits per sample
-        header.write(b'data')
-        header.write(struct.pack('<L', 0))  # Placeholder for data size
-        return header.getvalue()
+    def _write_wav_header_initial(self) -> None:
+        """Write initial WAV header with placeholders"""
+        self.buffer.write(b'RIFF')
+        self.buffer.write(struct.pack('<L', 0))  # Placeholder for file size
+        self.buffer.write(b'WAVE')
+        self.buffer.write(b'fmt ')
+        self.buffer.write(struct.pack('<L', 16))  # fmt chunk size
+        self.buffer.write(struct.pack('<H', 1))  # PCM format
+        self.buffer.write(struct.pack('<H', self.channels))
+        self.buffer.write(struct.pack('<L', self.sample_rate))
+        self.buffer.write(struct.pack('<L', self.sample_rate * self.channels * 2))  # Byte rate
+        self.buffer.write(struct.pack('<H', self.channels * 2))  # Block align
+        self.buffer.write(struct.pack('<H', 16))  # Bits per sample
+        self.buffer.write(b'data')
+        self.buffer.write(struct.pack('<L', 0))  # Placeholder for data size

     def write_chunk(self, audio_data: Optional[np.ndarray] = None, finalize: bool = False) -> bytes:
         """Write a chunk of audio data and return bytes in the target format.
@@ -82,47 +80,45 @@ class StreamingAudioWriter:

         if finalize:
             if self.format == "wav":
-                # Write final WAV header with correct sizes
-                output_buffer.write(b'RIFF')
-                output_buffer.write(struct.pack('<L', self.bytes_written + 36))
-                output_buffer.write(b'WAVE')
-                output_buffer.write(b'fmt ')
-                output_buffer.write(struct.pack('<L', 16))
-                output_buffer.write(struct.pack('<H', 1))
-                output_buffer.write(struct.pack('<H', self.channels))
-                output_buffer.write(struct.pack('<L', self.sample_rate))
-                output_buffer.write(struct.pack('<L', self.sample_rate * self.channels * 2))
-                output_buffer.write(struct.pack('<H', self.channels * 2))
-                output_buffer.write(struct.pack('<H', 16))
-                output_buffer.write(b'data')
-                output_buffer.write(struct.pack('<L', self.bytes_written))
+                # Calculate actual file and data sizes
+                file_size = self.bytes_written + 36  # RIFF header bytes
+                data_size = self.bytes_written
+
+                # Seek to the beginning to overwrite the placeholders
+                self.buffer.seek(4)
+                self.buffer.write(struct.pack('<L', file_size))
+                self.buffer.seek(40)
+                self.buffer.write(struct.pack('<L', data_size))
+
+                self.buffer.seek(0)
+                return self.buffer.read()
             elif self.format in ["ogg", "opus", "flac"]:
                 self.writer.close()
                 return self.buffer.getvalue()
             elif self.format in ["mp3", "aac"]:
-                # Final export of any remaining audio
                 if hasattr(self, 'encoder') and len(self.encoder) > 0:
-                    # Export with duration metadata
                     format_args = {
                         "mp3": {"format": "mp3", "codec": "libmp3lame"},
                         "aac": {"format": "adts", "codec": "aac"}
                     }[self.format]

-                    # On finalization, include proper headers and duration metadata
-                    parameters = [
-                        "-q:a", "2",
-                        "-write_xing", "1" if self.format == "mp3" else "0",  # XING header for MP3 only
-                        "-metadata", f"duration={self.total_duration/1000}",  # Duration in seconds
-                        "-write_id3v1", "1" if self.format == "mp3" else "0",  # ID3v1 tag for MP3
-                        "-write_id3v2", "1" if self.format == "mp3" else "0"  # ID3v2 tag for MP3
-                    ]
+                    parameters = []
                     if self.format == "mp3":
-                        # For MP3, ensure proper VBR headers
                         parameters.extend([
+                            "-q:a", "2",
+                            "-write_xing", "1",  # XING header for MP3
+                            "-id3v1", "1",
+                            "-id3v2", "1",
                             "-write_vbr", "1",
                             "-vbr_quality", "2"
                         ])
+                    elif self.format == "aac":
+                        parameters.extend([
+                            "-q:a", "2",
+                            "-write_xing", "0",
+                            "-write_id3v1", "0",
+                            "-write_id3v2", "0"
+                        ])

                     self.encoder.export(
                         output_buffer,
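The two seek offsets fall directly out of the fixed 44-byte canonical header written by `_write_wav_header_initial`: the RIFF file-size field sits at byte 4 and the data-size field at byte 40. A standalone sketch of that layout (illustrative, using mono 16-bit audio at 24 kHz as in the diff):

```python
import struct

# Byte layout of the 44-byte canonical WAV header:
#  0-3  b'RIFF'     4-7  file size (placeholder, patched at offset 4)
#  8-11 b'WAVE'    12-15 b'fmt '   16-19 fmt chunk size (16)
# 20-21 format (1=PCM)  22-23 channels  24-27 sample rate  28-31 byte rate
# 32-33 block align     34-35 bits/sample
# 36-39 b'data'   40-43 data size (placeholder, patched at offset 40)
header = (
    b'RIFF' + struct.pack('<L', 0) + b'WAVE'
    + b'fmt ' + struct.pack('<LHHLLHH', 16, 1, 1, 24000, 24000 * 2, 2, 16)
    + b'data' + struct.pack('<L', 0)
)
assert len(header) == 44
assert header.index(b'data') + 4 == 40  # the data-size placeholder starts at byte 40
```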
@@ -131,28 +127,23 @@ class StreamingAudioWriter:
                         parameters=parameters
                     )
                     self.encoder = None

             return output_buffer.getvalue()

         if audio_data is None or len(audio_data) == 0:
             return b''

         if self.format == "wav":
-            # For WAV, write raw PCM after the first chunk
-            if self.bytes_written == 0:
-                output_buffer.write(self._write_wav_header())
-            output_buffer.write(audio_data.tobytes())
+            # Write raw PCM data
+            self.buffer.write(audio_data.tobytes())
             self.bytes_written += len(audio_data.tobytes())
+            return b''

         elif self.format in ["ogg", "opus", "flac"]:
             # Write to soundfile buffer
             self.writer.write(audio_data)
             self.writer.flush()
-            # Get current buffer contents
-            data = self.buffer.getvalue()
-            # Clear buffer for next chunk
-            self.buffer.seek(0)
-            self.buffer.truncate()
-            return data
+            return self.buffer.getvalue()

         elif self.format in ["mp3", "aac"]:
             # Convert chunk to AudioSegment and encode
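Under the reworked semantics, the WAV path buffers PCM internally and returns nothing per chunk; the complete file arrives only at finalization. A usage sketch of the intended consumption pattern (`pcm_chunks` is a placeholder for a sequence of int16 numpy arrays):

```python
writer = StreamingAudioWriter(format="wav", sample_rate=24000, channels=1)
for chunk in pcm_chunks:
    # WAV chunks now accumulate in the internal buffer and yield no bytes.
    assert writer.write_chunk(chunk) == b''
# The full file, with its size fields patched in, comes back on finalize.
wav_file = writer.write_chunk(finalize=True)
```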
@@ -167,9 +158,9 @@ class StreamingAudioWriter:
             self.total_duration += len(segment)

             # Add segment to encoder
-            self.encoder = self.encoder + segment
+            self.encoder += segment

-            # Export current state to buffer
+            # Export current state to buffer without final metadata
             format_args = {
                 "mp3": {"format": "mp3", "codec": "libmp3lame"},
                 "aac": {"format": "adts", "codec": "aac"}
@@ -190,44 +181,27 @@ class StreamingAudioWriter:
             return encoded_data

         elif self.format == "pcm":
-            # For PCM, just write raw bytes
+            # Write raw bytes
             return audio_data.tobytes()

-        return output_buffer.getvalue()
+        return b''

     def close(self) -> Optional[bytes]:
         """Finish the audio file and return any remaining data"""
         if self.format == "wav":
-            # Update WAV header with final file size
-            buffer = BytesIO()
-            buffer.write(b'RIFF')
-            buffer.write(struct.pack('<L', self.bytes_written + 36))  # File size
-            buffer.write(b'WAVE')
-            buffer.write(b'fmt ')
-            buffer.write(struct.pack('<L', 16))
-            buffer.write(struct.pack('<H', 1))
-            buffer.write(struct.pack('<H', self.channels))
-            buffer.write(struct.pack('<L', self.sample_rate))
-            buffer.write(struct.pack('<L', self.sample_rate * self.channels * 2))
-            buffer.write(struct.pack('<H', self.channels * 2))
-            buffer.write(struct.pack('<H', 16))
-            buffer.write(b'data')
-            buffer.write(struct.pack('<L', self.bytes_written))
-            return buffer.getvalue()
+            # Re-finalize WAV file by updating headers
+            self.buffer.seek(0)
+            file_content = self.write_chunk(finalize=True)
+            return file_content

         elif self.format in ["ogg", "opus", "flac"]:
+            # Finalize other formats
             self.writer.close()
             return self.buffer.getvalue()

         elif self.format in ["mp3", "aac"]:
-            # Flush any remaining audio
-            buffer = BytesIO()
-            if hasattr(self, 'encoder') and len(self.encoder) > 0:
-                format_args = {
-                    "mp3": {"format": "mp3", "codec": "libmp3lame"},
-                    "aac": {"format": "adts", "codec": "aac"}
-                }[self.format]
-                self.encoder.export(buffer, **format_args)
-            return buffer.getvalue()
+            # Finalize MP3/AAC
+            final_data = self.write_chunk(finalize=True)
+            return final_data

         return None
@@ -13,31 +13,37 @@ TARGET_MIN = 200
 TARGET_MAX = 350
 ABSOLUTE_MAX = 500

-def process_text_chunk(text: str, language: str = "a") -> List[int]:
+def process_text_chunk(text: str, language: str = "a", skip_phonemize: bool = False) -> List[int]:
     """Process a chunk of text through normalization, phonemization, and tokenization.

     Args:
         text: Text chunk to process
         language: Language code for phonemization
+        skip_phonemize: If True, treat input as phonemes and skip normalization/phonemization

     Returns:
         List of token IDs
     """
     start_time = time.time()

-    # Normalize
-    t0 = time.time()
-    normalized = normalize_text(text)
-    t1 = time.time()
-    logger.debug(f"Normalization took {(t1-t0)*1000:.2f}ms for {len(text)} chars")
-
-    # Phonemize
-    t0 = time.time()
-    phonemes = phonemize(normalized, language, normalize=False)  # Already normalized
-    t1 = time.time()
-    logger.debug(f"Phonemization took {(t1-t0)*1000:.2f}ms for {len(normalized)} chars")
-
-    # Convert to token IDs
-    t0 = time.time()
-    tokens = tokenize(phonemes)
-    t1 = time.time()
+    if skip_phonemize:
+        # Input is already phonemes, just tokenize
+        t0 = time.time()
+        tokens = tokenize(text)
+        t1 = time.time()
+        logger.debug(f"Tokenization took {(t1-t0)*1000:.2f}ms for {len(text)} chars")
+    else:
+        # Normal text processing pipeline
+        t0 = time.time()
+        normalized = normalize_text(text)
+        t1 = time.time()
+        logger.debug(f"Normalization took {(t1-t0)*1000:.2f}ms for {len(text)} chars")
+
+        t0 = time.time()
+        phonemes = phonemize(normalized, language, normalize=False)  # Already normalized
+        t1 = time.time()
+        logger.debug(f"Phonemization took {(t1-t0)*1000:.2f}ms for {len(normalized)} chars")
+
+        t0 = time.time()
+        tokens = tokenize(phonemes)
+        t1 = time.time()
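A minimal sketch of the new flag in use (the phoneme string is a placeholder):

```python
# Bypass normalization and phonemization when the input is already phonemes;
# this reduces to tokenize(text) plus the debug timing log.
tokens = process_text_chunk("h\u02c8\u025blo\u028a", language="a", skip_phonemize=True)
```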
@@ -13,6 +13,7 @@ from ..inference.model_manager import get_manager as get_model_manager
 from ..inference.voice_manager import get_manager as get_voice_manager
 from .audio import AudioNormalizer, AudioService
 from .text_processing.text_processor import process_text_chunk, smart_split
+from .text_processing import tokenize

 class TTSService:
     """Text-to-speech service."""

@@ -169,6 +170,55 @@ class TTSService:
             del voice_tensor
             torch.cuda.empty_cache()

+    async def generate_from_phonemes(
+        self, phonemes: str, voice: str, speed: float = 1.0
+    ) -> Tuple[np.ndarray, float]:
+        """Generate audio from phonemes.
+
+        Args:
+            phonemes: Phoneme string to synthesize
+            voice: Voice ID to use
+            speed: Speed multiplier
+
+        Returns:
+            Tuple of (audio array, processing time)
+        """
+        start_time = time.time()
+        voice_tensor = None
+
+        try:
+            # Get backend and load voice
+            backend = self.model_manager.get_backend()
+            voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device)
+
+            # Convert phonemes to tokens
+            tokens = tokenize(phonemes)
+            if len(tokens) > 500:  # Model context limit
+                raise ValueError(f"Phoneme sequence too long ({len(tokens)} tokens, max 500)")
+
+            tokens = [0] + tokens + [0]  # Add start/end tokens
+
+            # Generate audio
+            audio = await self.model_manager.generate(
+                tokens,
+                voice_tensor,
+                speed=speed
+            )
+
+            if audio is None:
+                raise ValueError("Failed to generate audio")
+
+            processing_time = time.time() - start_time
+            return audio, processing_time
+
+        except Exception as e:
+            logger.error(f"Error in phoneme audio generation: {str(e)}")
+            raise
+        finally:
+            if voice_tensor is not None:
+                del voice_tensor
+                torch.cuda.empty_cache()
+
     async def generate_audio(
         self, text: str, voice: str, speed: float = 1.0
     ) -> Tuple[np.ndarray, float]:
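A usage sketch for the new service method. Service creation mirrors `get_tts_service()` in the dev router; the phoneme string and voice name are placeholders, not values defined by the commit:

```python
import asyncio

async def demo():
    service = await TTSService.create()
    audio, took = await service.generate_from_phonemes("h\u02c8\u025blo\u028a", voice="af_bella")
    print(f"{len(audio)} samples generated in {took:.2f}s")

asyncio.run(demo())
```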
@@ -33,15 +33,6 @@ class StitchOptions(BaseModel):


 class GenerateFromPhonemesRequest(BaseModel):
-    phonemes: Union[str, List[str]] = Field(
-        ...,
-        description="Single phoneme string or list of phoneme chunks to stitch together"
-    )
+    """Simple request for phoneme-to-speech generation"""
+    phonemes: str = Field(..., description="Phoneme string to synthesize")
     voice: str = Field(..., description="Voice ID to use for generation")
-    speed: float = Field(
-        default=1.0, ge=0.1, le=5.0, description="Speed factor for generation"
-    )
-    options: Optional[StitchOptions] = Field(
-        default=None,
-        description="Optional settings for audio generation and stitching"
-    )
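After the simplification, the model validates to exactly two required fields. An illustrative sketch (the serialization call depends on the pydantic major version in use, which this diff does not pin):

```python
req = GenerateFromPhonemesRequest(phonemes="h\u02c8\u025blo\u028a", voice="af_bella")
# req.model_dump() on pydantic v2, req.dict() on v1:
# {'phonemes': 'h\u02c8\u025blo\u028a', 'voice': 'af_bella'}
```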
@@ -30,7 +30,14 @@ def sample_audio():
 async def test_convert_to_wav(sample_audio):
     """Test converting to WAV format"""
     audio_data, sample_rate = sample_audio
-    result = await AudioService.convert_audio(audio_data, sample_rate, "wav")
+    # Write and finalize in one step for WAV
+    result = await AudioService.convert_audio(
+        audio_data,
+        sample_rate,
+        "wav",
+        is_first_chunk=True,
+        is_last_chunk=True
+    )
     assert isinstance(result, bytes)
     assert len(result) > 0
     # Check WAV header

@@ -106,7 +113,14 @@ async def test_normalization_wav(sample_audio):
     audio_data, sample_rate = sample_audio
     # Create audio data outside int16 range
     large_audio = audio_data * 1e5
-    result = await AudioService.convert_audio(large_audio, sample_rate, "wav")
+    # Write and finalize in one step for WAV
+    result = await AudioService.convert_audio(
+        large_audio,
+        sample_rate,
+        "wav",
+        is_first_chunk=True,
+        is_last_chunk=True
+    )
     assert isinstance(result, bytes)
     assert len(result) > 0

@@ -138,7 +152,13 @@ async def test_different_sample_rates(sample_audio):
     sample_rates = [8000, 16000, 44100, 48000]

     for rate in sample_rates:
-        result = await AudioService.convert_audio(audio_data, rate, "wav")
+        result = await AudioService.convert_audio(
+            audio_data,
+            rate,
+            "wav",
+            is_first_chunk=True,
+            is_last_chunk=True
+        )
         assert isinstance(result, bytes)
         assert len(result) > 0

@@ -147,7 +167,20 @@ async def test_different_sample_rates(sample_audio):
 async def test_buffer_position_after_conversion(sample_audio):
     """Test that buffer position is reset after writing"""
     audio_data, sample_rate = sample_audio
-    result = await AudioService.convert_audio(audio_data, sample_rate, "wav")
+    # Write and finalize in one step for first conversion
+    result = await AudioService.convert_audio(
+        audio_data,
+        sample_rate,
+        "wav",
+        is_first_chunk=True,
+        is_last_chunk=True
+    )
     # Convert again to ensure buffer was properly reset
-    result2 = await AudioService.convert_audio(audio_data, sample_rate, "wav")
+    result2 = await AudioService.convert_audio(
+        audio_data,
+        sample_rate,
+        "wav",
+        is_first_chunk=True,
+        is_last_chunk=True
+    )
     assert len(result) == len(result2)
@@ -192,8 +192,13 @@ def mock_tts_service(mock_audio_bytes):
             mock_get.side_effect = None
             yield service

-def test_openai_speech_endpoint(mock_tts_service, test_voice):
+@patch('api.src.services.audio.AudioService.convert_audio')
+def test_openai_speech_endpoint(mock_convert, mock_tts_service, test_voice, mock_audio_bytes):
     """Test the OpenAI-compatible speech endpoint with basic MP3 generation"""
+    # Configure mocks
+    mock_tts_service.generate_audio.return_value = (np.zeros(1000), 0.1)
+    mock_convert.return_value = mock_audio_bytes
+
     response = client.post(
         "/v1/audio/speech",
         json={

@@ -207,6 +212,10 @@ def test_openai_speech_endpoint(mock_tts_service, test_voice):
     assert response.status_code == 200
     assert response.headers["content-type"] == "audio/mpeg"
     assert len(response.content) > 0
+    assert response.content == mock_audio_bytes
+
+    mock_tts_service.generate_audio.assert_called_once()
+    mock_convert.assert_called_once()

 def test_openai_speech_streaming(mock_tts_service, test_voice, mock_audio_bytes):
     """Test the OpenAI-compatible speech endpoint with streaming"""

@@ -357,12 +366,8 @@ def test_server_error(mock_tts_service, test_voice):

 def test_streaming_error(mock_tts_service, test_voice):
     """Test handling streaming errors"""
-    async def mock_error_stream(*args, **kwargs) -> AsyncGenerator[bytes, None]:
-        if False:  # This makes it a proper generator
-            yield b""
-        raise RuntimeError("Streaming failed")
-
-    mock_tts_service.generate_audio_stream = mock_error_stream
+    # Mock process_voices to raise the error
+    mock_tts_service.list_voices.side_effect = RuntimeError("Streaming failed")

     response = client.post(
         "/v1/audio/speech",

@@ -374,10 +379,12 @@ def test_streaming_error(mock_tts_service, test_voice):
             "stream": True
         }
     )

     assert response.status_code == 500
-    error_response = response.json()
-    assert error_response["detail"]["error"] == "processing_error"
-    assert error_response["detail"]["type"] == "server_error"
+    error_data = response.json()
+    assert error_data["detail"]["error"] == "processing_error"
+    assert error_data["detail"]["type"] == "server_error"
+    assert "Streaming failed" in error_data["detail"]["message"]

 @pytest.mark.asyncio
 async def test_streaming_initialization_error():
@@ -9,3 +9,9 @@ Accept: application/json
 ### Get System Information
 GET http://localhost:8880/debug/system
 Accept: application/json
+
+### Get Session Pool Status
+# Shows active ONNX sessions, CUDA stream usage, and session ages
+# Useful for debugging resource exhaustion issues
+GET http://localhost:8880/debug/session_pools
+Accept: application/json
docs/architecture/streaming_audio_writer_analysis.md: 167 lines (new file)

@@ -0,0 +1,167 @@
# Streaming Audio Writer Analysis

This auto-generated document provides an in-depth technical analysis of the `StreamingAudioWriter` class, detailing the streaming and non-streaming paths, supported formats, header management, and challenges faced in the implementation.

## Overview

The `StreamingAudioWriter` class is designed to handle streaming audio format conversions efficiently. It supports various audio formats and provides methods to write audio data in chunks, finalize the stream, and manage audio headers meticulously to ensure compatibility and integrity of the resulting audio files.

## Supported Formats

The class supports the following audio formats:

- **WAV**
- **OGG**
- **Opus**
- **FLAC**
- **MP3**
- **AAC**
- **PCM**

## Initialization

Upon initialization, the class sets up format-specific configurations to prepare for audio data processing:

- **WAV**:
  - Writes an initial WAV header with placeholders for file size (`RIFF` chunk) and data size (`data` chunk).
  - Utilizes the `_write_wav_header` method to generate the header.

- **OGG/Opus/FLAC**:
  - Uses `soundfile.SoundFile` to write audio data to a memory buffer (`BytesIO`).
  - Configures the writer with the appropriate format and subtype based on the specified format.

- **MP3/AAC**:
  - Utilizes `pydub.AudioSegment` for incremental writing.
  - Initializes an empty `AudioSegment` as the encoder to accumulate audio data.

- **PCM**:
  - Prepares to write raw PCM bytes without additional headers.

Initialization ensures that each format is correctly configured to handle the specific requirements of streaming and finalizing audio data.

## Streaming Path

### Writing Chunks

The `write_chunk` method handles the incoming audio data, processing it according to the specified format:

- **WAV**:
  - **First Chunk**: Writes the initial WAV header to the buffer.
  - **Subsequent Chunks**: Writes raw PCM data directly after the header.
  - Updates `bytes_written` to track the total size of audio data written.

- **OGG/Opus/FLAC**:
  - Writes audio data to the `soundfile` buffer.
  - Flushes the writer to ensure data integrity.
  - Retrieves the current buffer contents and truncates the buffer for the next chunk.

- **MP3/AAC**:
  - Converts incoming audio data (`np.ndarray`) to a `pydub.AudioSegment`.
  - Accumulates segments in the encoder.
  - Exports the current state to the output buffer without writing duration metadata or XING headers for chunks.
  - Resets the encoder after exporting to prevent memory growth.

- **PCM**:
  - Directly writes raw bytes from the audio data to the output buffer.

### Finalizing

Finalizing the audio stream involves ensuring that all audio data is correctly written and that headers are updated to reflect the accurate file and data sizes:

- **WAV**:
  - Rewrites the `RIFF` and `data` chunks in the header with the actual file size (`bytes_written + 36`) and data size (`bytes_written`).
  - Creates a new buffer with the complete WAV file by copying audio data from the original buffer starting at byte 44 (the end of the initial header).

- **OGG/Opus/FLAC**:
  - Closes the `soundfile` writer to flush all remaining data to the buffer.
  - Returns the final buffer content, ensuring that all necessary headers and data are correctly written.

- **MP3/AAC**:
  - Exports any remaining audio data with proper headers and metadata, including duration and VBR quality for MP3.
  - Writes ID3v1 and ID3v2 tags for MP3 formats.
  - Performs final exports to ensure that all audio data is properly encoded and formatted.

- **PCM**:
  - No finalization is needed, as PCM is raw data without headers.

## Non-Streaming Path

The `StreamingAudioWriter` class is inherently designed for streaming audio data. However, it is essential to understand how it behaves when handling complete files versus streaming data:

### Full File Writing

- **Process**:
  - Accumulate all audio data in memory or a buffer.
  - Write the complete file with accurate headers and data sizes upon finalization.

- **Advantages**:
  - Simplifies header management, since the total data size is known before writing.
  - Reduces complexity in data handling and processing.

- **Disadvantages**:
  - High memory consumption for large audio files.
  - Audio data is unavailable until the entire file is processed.

### Stream-to-File Writing

- **Process**:
  - Incrementally write audio data in chunks.
  - Update headers and finalize the file dynamically as data flows.

- **Advantages**:
  - Lower memory usage, as data is processed in smaller chunks.
  - Immediate availability of audio data, suitable for real-time streaming applications.

- **Disadvantages**:
  - Complex header management to accommodate dynamic data sizes.
  - Increased likelihood of header synchronization issues, leading to potential file corruption.

**Challenges**:
- Balancing memory usage with processing speed.
- Ensuring consistent and accurate header updates during streaming operations.

## Header Management

### WAV Headers

WAV files use `RIFF` headers to describe file structure:

- **Initial Header**:
  - Contains placeholders for file size and data size (`struct.pack('<L', 0)`).

- **Final Header**:
  - Calculates and writes the actual file size (`bytes_written + 36`) and data size (`bytes_written`).
  - Ensures that audio players can correctly interpret the file by carrying accurate header information.

**Technical Details**:
- The `_write_wav_header` method initializes the WAV header with placeholders.
- Upon finalization, the `write_chunk` method creates a new buffer, writes the correct sizes, and appends the audio data from the original buffer starting at byte 44 (the end of the initial header).

**Challenges**:
- Maintaining synchronization between the audio data size and the header placeholders.
- Ensuring that the header is correctly rewritten upon finalization to prevent file corruption.

### MP3/AAC Headers

MP3 and AAC formats require proper metadata and headers to ensure compatibility:

- **XING Headers (MP3)**:
  - Essential for Variable Bit Rate (VBR) audio files.
  - Control the quality and indexing of the MP3 file.

- **ID3 Tags (MP3)**:
  - Provide metadata such as artist, title, and album information.

- **ADTS Headers (AAC)**:
  - Describe the AAC frame headers necessary for decoding.

**Technical Details**:
- During finalization, the `write_chunk` method for MP3/AAC formats includes:
  - Duration metadata (`-metadata duration`).
  - VBR headers for MP3 (`-write_vbr`, `-vbr_quality`).
  - ID3 tags for MP3 (`-write_id3v1`, `-write_id3v2`).
- Ensures that all remaining audio data is correctly encoded and formatted with the necessary headers.

**Challenges**:
- Ensuring that metadata is accurately written during the finalization process.
- Managing VBR headers to maintain audio quality and file integrity.
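A minimal sketch of the pydub-based MP3 finalization the analysis describes. The parameter values mirror the diff above; `segments` is a placeholder for the per-chunk `AudioSegment` objects accumulated during streaming:

```python
from io import BytesIO
from pydub import AudioSegment

# Accumulate chunks in an encoder segment, as the streaming path does.
encoder = AudioSegment.empty()
for segment in segments:
    encoder += segment

# Final export with VBR/XING parameters passed through to ffmpeg.
buf = BytesIO()
encoder.export(
    buf, format="mp3", codec="libmp3lame",
    parameters=["-q:a", "2", "-write_xing", "1", "-write_vbr", "1", "-vbr_quality", "2"],
)
mp3_bytes = buf.getvalue()
```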
@ -1,394 +1,138 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import os
|
import os
|
||||||
import argparse
|
import time
|
||||||
from typing import Dict, List, Tuple, Optional
|
import wave
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from scipy.io import wavfile
|
from openai import OpenAI
|
||||||
|
|
||||||
|
# Create output directory
|
||||||
|
output_dir = Path(__file__).parent / "output"
|
||||||
|
output_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
def submit_combine_voices(
|
# Initialize OpenAI client
|
||||||
voices: List[str], base_url: str = "http://localhost:8880"
|
client = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed")
|
||||||
) -> Optional[str]:
|
|
||||||
"""Combine multiple voices into a new voice.
|
|
||||||
|
|
||||||
Args:
|
# Test text that showcases voice characteristics
|
||||||
voices: List of voice names to combine (e.g. ["af_bella", "af_sarah"])
|
text = """The quick brown fox jumps over the lazy dog.
|
||||||
base_url: API base URL
|
How vexingly quick daft zebras jump!
|
||||||
|
The five boxing wizards jump quickly."""
|
||||||
|
|
||||||
Returns:
|
def generate_and_save_audio(voice: str, output_path: str):
|
||||||
Name of the combined voice (e.g. "af_bella_af_sarah") or None if error
|
"""Generate audio using specified voice and save to WAV file."""
|
||||||
"""
|
print(f"\nGenerating audio for voice: {voice}")
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Generate audio using streaming response
|
||||||
|
with client.audio.speech.with_streaming_response.create(
|
||||||
|
model="kokoro",
|
||||||
|
voice=voice,
|
||||||
|
response_format="wav",
|
||||||
|
input=text,
|
||||||
|
) as response:
|
||||||
|
# Save the audio stream to file
|
||||||
|
with open(output_path, "wb") as f:
|
||||||
|
for chunk in response.iter_bytes():
|
||||||
|
f.write(chunk)
|
||||||
|
|
||||||
|
duration = time.time() - start_time
|
||||||
|
print(f"Generated in {duration:.2f}s")
|
||||||
|
print(f"Saved to {output_path}")
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
def analyze_audio(filepath: str):
|
||||||
|
"""Analyze audio file and return key characteristics."""
|
||||||
|
print(f"\nAnalyzing {filepath}")
|
||||||
try:
|
try:
|
||||||
response = requests.post(f"{base_url}/v1/audio/voices/combine", json=voices)
|
print(f"\nTrying to read {filepath}")
|
||||||
print(f"Response status: {response.status_code}")
|
with wave.open(filepath, 'rb') as wf:
|
||||||
print(f"Raw response: {response.text}")
|
sample_rate = wf.getframerate()
|
||||||
|
samples = np.frombuffer(wf.readframes(wf.getnframes()), dtype=np.int16)
|
||||||
# Accept both 200 and 201 as success
|
print(f"Successfully read file:")
|
||||||
if response.status_code not in [200, 201]:
|
print(f"Sample rate: {sample_rate}")
|
||||||
try:
|
print(f"Samples shape: {samples.shape}")
|
||||||
error = response.json()["detail"]["message"]
|
print(f"Samples dtype: {samples.dtype}")
|
||||||
print(f"Error combining voices: {error}")
|
print(f"First few samples: {samples[:10]}")
|
||||||
except:
|
|
||||||
print(f"Error combining voices: {response.text}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
data = response.json()
|
|
||||||
if "voices" in data:
|
|
||||||
print(f"Available voices: {', '.join(sorted(data['voices']))}")
|
|
||||||
return data["voice"]
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error parsing response: {e}")
|
print(f"Error reading file: {str(e)}")
|
||||||
return None
|
raise
|
||||||
except Exception as e:
|
|
||||||
print(f"Error: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
-def generate_speech(
-    text: str,
-    voice: str,
-    base_url: str = "http://localhost:8880",
-    output_file: str = "output.mp3",
-) -> bool:
-    """Generate speech using specified voice.
-
-    Args:
-        text: Text to convert to speech
-        voice: Voice name to use
-        base_url: API base URL
-        output_file: Path to save audio file
-
-    Returns:
-        True if successful, False otherwise
-    """
-    try:
-        response = requests.post(
-            f"{base_url}/v1/audio/speech",
-            json={
-                "input": text,
-                "voice": voice,
-                "speed": 1.0,
-                "response_format": "wav",  # Use WAV for analysis
-                "stream": False,
-            },
-        )
-
-        if response.status_code != 200:
-            error = response.json().get("detail", {}).get("message", response.text)
-            print(f"Error generating speech: {error}")
-            return False
-
-        # Save the audio
-        os.makedirs(
-            os.path.dirname(output_file) if os.path.dirname(output_file) else ".",
-            exist_ok=True,
-        )
-        with open(output_file, "wb") as f:
-            f.write(response.content)
-        print(f"Saved audio to {output_file}")
-        return True
-
-    except Exception as e:
-        print(f"Error: {e}")
-        return False
-
-
-def analyze_audio(filepath: str) -> Tuple[np.ndarray, int, dict]:
-    """Analyze audio file and return samples, sample rate, and audio characteristics.
-
-    Args:
-        filepath: Path to audio file
-
-    Returns:
-        Tuple of (samples, sample_rate, characteristics)
-    """
-    sample_rate, samples = wavfile.read(filepath)
+    # Convert to float64 for calculations
+    samples = samples.astype(np.float64) / 32768.0  # Normalize 16-bit audio

     # Convert to mono if stereo
     if len(samples.shape) > 1:
         samples = np.mean(samples, axis=1)

     # Calculate basic stats
+    duration = len(samples) / sample_rate
     max_amp = np.max(np.abs(samples))
     rms = np.sqrt(np.mean(samples**2))
-    duration = len(samples) / sample_rate
-
-    # Zero crossing rate (helps identify voice characteristics)
-    zero_crossings = np.sum(np.abs(np.diff(np.signbit(samples)))) / len(samples)
-
-    # Simple frequency analysis
-    if len(samples) > 0:
-        # Use FFT to get frequency components
-        fft_result = np.fft.fft(samples)
-        freqs = np.fft.fftfreq(len(samples), 1 / sample_rate)
-
-        # Get positive frequencies only
-        pos_mask = freqs > 0
-        freqs = freqs[pos_mask]
-        magnitudes = np.abs(fft_result)[pos_mask]
-
-        # Find dominant frequencies (top 3)
-        top_indices = np.argsort(magnitudes)[-3:]
-        dominant_freqs = freqs[top_indices]
-
-        # Calculate spectral centroid (brightness of sound)
-        spectral_centroid = np.sum(freqs * magnitudes) / np.sum(magnitudes)
-    else:
-        dominant_freqs = []
-        spectral_centroid = 0
-
-    characteristics = {
-        "max_amplitude": max_amp,
-        "rms": rms,
-        "duration": duration,
-        "zero_crossing_rate": zero_crossings,
-        "dominant_frequencies": dominant_freqs,
-        "spectral_centroid": spectral_centroid,
-    }
-
-    return samples, sample_rate, characteristics
+
+    # Calculate frequency characteristics
+    # Compute FFT
+    N = len(samples)
+    yf = np.fft.fft(samples)
+    xf = np.fft.fftfreq(N, 1 / sample_rate)[:N//2]
+    magnitude = 2.0/N * np.abs(yf[0:N//2])
+
+    # Calculate spectral centroid
+    spectral_centroid = np.sum(xf * magnitude) / np.sum(magnitude)
+
+    # Determine dominant frequencies
+    dominant_freqs = xf[magnitude.argsort()[-5:]][::-1].tolist()
+
+    return {
+        'samples': samples,
+        'sample_rate': sample_rate,
+        'duration': duration,
+        'max_amplitude': max_amp,
+        'rms': rms,
+        'spectral_centroid': spectral_centroid,
+        'dominant_frequencies': dominant_freqs
+    }
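Note: the rewritten analyze_audio keeps only the one-sided FFT and summarizes it with a spectral centroid. A minimal standalone check of that math (numpy only, synthetic 440 Hz tone; not repo code, and the 24 kHz rate is just for illustration):

    import numpy as np

    sample_rate = 24000
    t = np.arange(sample_rate) / sample_rate
    samples = 0.5 * np.sin(2 * np.pi * 440 * t)  # pure 440 Hz tone

    # One-sided FFT magnitude, mirroring the new analyze_audio
    N = len(samples)
    yf = np.fft.fft(samples)
    xf = np.fft.fftfreq(N, 1 / sample_rate)[:N // 2]
    magnitude = 2.0 / N * np.abs(yf[:N // 2])

    # Spectral centroid: magnitude-weighted mean frequency
    centroid = np.sum(xf * magnitude) / np.sum(magnitude)
    print(f"Spectral centroid: {centroid:.1f} Hz")  # lands near 440 Hz

For a pure tone the centroid should sit at the tone frequency, which makes this a quick sanity check on the weighting.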
+def plot_comparison(analyses, output_path):
+    """Create comparison plot of the audio analyses."""
+    plt.style.use('dark_background')
+    fig = plt.figure(figsize=(15, 10))
+    fig.patch.set_facecolor('#1a1a2e')
+
+    # Plot waveforms
+    for i, (name, data) in enumerate(analyses.items()):
+        ax = plt.subplot(3, 1, i+1)
+        samples = data['samples']
+        time = np.arange(len(samples)) / data['sample_rate']
+        plt.plot(time, samples / data['max_amplitude'], linewidth=0.5, color='#ff2a6d')
+        plt.title(f"Waveform: {name}", color='white', pad=20)
+        plt.xlabel("Time (seconds)", color='white')
+        plt.ylabel("Normalized Amplitude", color='white')
+        plt.grid(True, alpha=0.3)
+        ax.set_facecolor('#1a1a2e')
+        plt.ylim(-1.1, 1.1)
+
+    plt.tight_layout()
+    plt.savefig(output_path, dpi=300, bbox_inches='tight')
+    print(f"\nSaved comparison plot to {output_path}")
-def setup_plot(fig, ax, title):
-    """Configure plot styling"""
-    # Improve grid
-    ax.grid(True, linestyle="--", alpha=0.3, color="#ffffff")
-
-    # Set title and labels with better fonts
-    ax.set_title(title, pad=20, fontsize=16, fontweight="bold", color="#ffffff")
-    ax.set_xlabel(ax.get_xlabel(), fontsize=14, fontweight="medium", color="#ffffff")
-    ax.set_ylabel(ax.get_ylabel(), fontsize=14, fontweight="medium", color="#ffffff")
-
-    # Improve tick labels
-    ax.tick_params(labelsize=12, colors="#ffffff")
-
-    # Style spines
-    for spine in ax.spines.values():
-        spine.set_color("#ffffff")
-        spine.set_alpha(0.3)
-        spine.set_linewidth(0.5)
-
-    # Set background colors
-    ax.set_facecolor("#1a1a2e")
-    fig.patch.set_facecolor("#1a1a2e")
-
-    return fig, ax
-
-
-def plot_analysis(audio_files: Dict[str, str], output_dir: str):
-    """Plot comprehensive voice analysis including waveforms and metrics comparison.
-
-    Args:
-        audio_files: Dictionary of label -> filepath
-        output_dir: Directory to save plot files
-    """
-    # Set dark style
-    plt.style.use("dark_background")
-
-    # Create figure with subplots
-    fig = plt.figure(figsize=(15, 15))
-    fig.patch.set_facecolor("#1a1a2e")
-    num_files = len(audio_files)
-
-    # Create subplot grid with proper spacing for waveforms and metrics
-    total_rows = num_files + 2  # Add one more row for metrics
-    gs = plt.GridSpec(
-        total_rows, 2, height_ratios=[1.5] * num_files + [1, 1], hspace=0.4, wspace=0.3
-    )
-
-    # Analyze all files first
-    all_chars = {}
-    for i, (label, filepath) in enumerate(audio_files.items()):
-        samples, sample_rate, chars = analyze_audio(filepath)
-        all_chars[label] = chars
-
-        # Plot waveform spanning both columns
-        ax = plt.subplot(gs[i, :])
-        time = np.arange(len(samples)) / sample_rate
-        plt.plot(time, samples / chars["max_amplitude"], linewidth=0.5, color="#ff2a6d")
-        ax.set_xlabel("Time (seconds)")
-        ax.set_ylabel("Normalized Amplitude")
-        ax.set_ylim(-1.1, 1.1)
-        setup_plot(fig, ax, f"Waveform: {label}")
-
-    # Colors for voices
-    colors = ["#ff2a6d", "#05d9e8", "#d1f7ff"]
-
-    # Create metrics for each subplot
-    metrics = [
-        (
-            plt.subplot(gs[num_files, 0]),
-            [
-                (
-                    "Volume",
-                    [chars["rms"] * 100 for chars in all_chars.values()],
-                    "RMS×100",
-                )
-            ],
-        ),
-        (
-            plt.subplot(gs[num_files, 1]),
-            [
-                (
-                    "Brightness",
-                    [chars["spectral_centroid"] / 1000 for chars in all_chars.values()],
-                    "kHz",
-                )
-            ],
-        ),
-        (
-            plt.subplot(gs[num_files + 1, 0]),
-            [
-                (
-                    "Voice Pitch",
-                    [
-                        min(chars["dominant_frequencies"])
-                        for chars in all_chars.values()
-                    ],
-                    "Hz",
-                )
-            ],
-        ),
-        (
-            plt.subplot(gs[num_files + 1, 1]),
-            [
-                (
-                    "Texture",
-                    [
-                        chars["zero_crossing_rate"] * 1000
-                        for chars in all_chars.values()
-                    ],
-                    "ZCR×1000",
-                )
-            ],
-        ),
-    ]
-
-    # Plot each metric
-    for i, (ax, metric_data) in enumerate(metrics):
-        n_voices = len(audio_files)
-        bar_width = 0.25
-        indices = np.array([0])
-
-        values = metric_data[0][1]
-        max_val = max(values)
-
-        for j, (voice, color) in enumerate(zip(audio_files.keys(), colors)):
-            offset = (j - n_voices / 2 + 0.5) * bar_width
-            bars = ax.bar(
-                indices + offset,
-                [values[j]],
-                bar_width,
-                label=voice,
-                color=color,
-                alpha=0.8,
-            )
-
-            # Add value labels on top of bars
-            for bar in bars:
-                height = bar.get_height()
-                ax.text(
-                    bar.get_x() + bar.get_width() / 2.0,
-                    height,
-                    f"{height:.1f}",
-                    ha="center",
-                    va="bottom",
-                    color="white",
-                    fontsize=10,
-                )
-
-        ax.set_xticks(indices)
-        ax.set_xticklabels([f"{metric_data[0][0]}\n({metric_data[0][2]})"])
-        ax.set_ylim(0, max_val * 1.2)
-        ax.set_ylabel("Value")
-
-        # Only show legend on first metric plot
-        if i == 0:
-            ax.legend(
-                bbox_to_anchor=(1.05, 1),
-                loc="upper left",
-                facecolor="#1a1a2e",
-                edgecolor="#ffffff",
-            )
-
-        # Style the subplot
-        setup_plot(fig, ax, metric_data[0][0])
-
-    # Adjust the figure size and padding
-    fig.set_size_inches(15, 20)
-    plt.subplots_adjust(right=0.85, top=0.95, bottom=0.05, left=0.1)
-    plt.savefig(os.path.join(output_dir, "analysis_comparison.png"), dpi=300)
-    print(f"Saved analysis comparison to {output_dir}/analysis_comparison.png")
-
-    # Print detailed comparative analysis
-    print("\nDetailed Voice Analysis:")
-    for label, chars in all_chars.items():
-        print(f"\n{label}:")
-        print(f"  Max Amplitude: {chars['max_amplitude']:.2f}")
-        print(f"  RMS (loudness): {chars['rms']:.2f}")
-        print(f"  Duration: {chars['duration']:.2f}s")
-        print(f"  Zero Crossing Rate: {chars['zero_crossing_rate']:.3f}")
-        print(f"  Spectral Centroid: {chars['spectral_centroid']:.0f}Hz")
-        print(
-            f"  Dominant Frequencies: {', '.join(f'{f:.0f}Hz' for f in chars['dominant_frequencies'])}"
-        )
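Note: the removed plot_analysis stacked full-width waveform rows above a 2x2 metrics grid using GridSpec height ratios. A minimal sketch of that layout idea (synthetic data, not repo code):

    import matplotlib.pyplot as plt
    import numpy as np

    fig = plt.figure(figsize=(10, 8))
    # Two full-width waveform rows (weight 1.5 each) above two metric rows (weight 1)
    gs = plt.GridSpec(4, 2, height_ratios=[1.5, 1.5, 1, 1], hspace=0.4, wspace=0.3)

    for row in range(2):
        ax = fig.add_subplot(gs[row, :])          # waveform spans both columns
        ax.plot(np.sin(np.linspace(0, 20, 500)))
    for row, col in [(2, 0), (2, 1), (3, 0), (3, 1)]:
        ax = fig.add_subplot(gs[row, col])        # one bar-metric panel per cell
        ax.bar([0], [np.random.rand()])
    plt.savefig("layout_sketch.png", dpi=150)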
 def main():
-    parser = argparse.ArgumentParser(description="Kokoro Voice Analysis Demo")
-    parser.add_argument("--voices", nargs="+", type=str, help="Voices to combine")
-    parser.add_argument(
-        "--text",
-        type=str,
-        default="Hello! This is a test of combined voices.",
-        help="Text to speak",
-    )
-    parser.add_argument("--url", default="http://localhost:8880", help="API base URL")
-    parser.add_argument(
-        "--output-dir",
-        default="examples/assorted_checks/test_combinations/output",
-        help="Output directory for audio files",
-    )
-    args = parser.parse_args()
-
-    if not args.voices:
-        print("No voices provided, using default test voices")
-        args.voices = ["af_bella", "af_nicole"]
-
-    # Create output directory
-    os.makedirs(args.output_dir, exist_ok=True)
-
-    # Dictionary to store audio files for analysis
-    audio_files = {}
-
-    # Generate speech with individual voices
-    print("Generating speech with individual voices...")
-    for voice in args.voices:
-        output_file = os.path.join(args.output_dir, f"analysis_{voice}.wav")
-        if generate_speech(args.text, voice, args.url, output_file):
-            audio_files[voice] = output_file
-
-    # Generate speech with combined voice
-    print(f"\nCombining voices: {', '.join(args.voices)}")
-    combined_voice = submit_combine_voices(args.voices, args.url)
-
-    if combined_voice:
-        print(f"Successfully created combined voice: {combined_voice}")
-        output_file = os.path.join(
-            args.output_dir, f"analysis_combined_{combined_voice}.wav"
-        )
-        if generate_speech(args.text, combined_voice, args.url, output_file):
-            audio_files["combined"] = output_file
-
-        # Generate comparison plots
-        plot_analysis(audio_files, args.output_dir)
-    else:
-        print("Failed to combine voices")
+    # Generate audio for each voice
+    voices = {
+        'af_bella': output_dir / 'af_bella.wav',
+        'af_irulan': output_dir / 'af_irulan.wav',
+        'af_bella+af_irulan': output_dir / 'af_bella+af_irulan.wav'
+    }
+
+    for voice, path in voices.items():
+        generate_and_save_audio(voice, str(path))
+
+    # Analyze each audio file
+    analyses = {}
+    for name, path in voices.items():
+        analyses[name] = analyze_audio(str(path))
+
+    # Create comparison plot
+    plot_comparison(analyses, output_dir / 'voice_comparison.png')

 if __name__ == "__main__":
     main()
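Note: the new main() leans on server-side voice blending via `+`-joined voice names. A sketch of requesting a blend directly, with the endpoint and payload shape taken from the removed generate_speech helper (equal-weight blending is assumed here, not established by this diff):

    import requests

    # "af_bella+af_irulan" is parsed server-side as a combination of the two base voices
    response = requests.post(
        "http://localhost:8880/v1/audio/speech",
        json={
            "input": "Hello! This is a test of combined voices.",
            "voice": "af_bella+af_irulan",
            "response_format": "wav",
            "stream": False,
        },
    )
    response.raise_for_status()
    with open("combined_voice.wav", "wb") as f:
        f.write(response.content)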
@@ -31,7 +31,7 @@ def stream_to_speakers() -> None:

     with openai.audio.speech.with_streaming_response.create(
         model="kokoro",
-        voice="af_bella",
+        voice="af_bella+af_irulan",
         response_format="pcm",  # similar to WAV, but without a header chunk at the start.
         input="""I see skies of blue and clouds of white
             The bright blessed days, the dark sacred nights
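Note: this streaming example is what pulls pyaudio into uv.lock further down. A minimal playback sketch, assuming the PCM stream is 16-bit mono at 24 kHz (the format is an assumption, not confirmed by this diff; the explicit OpenAI(base_url=...) client stands in for the configured module client in the example file):

    import openai
    import pyaudio

    client = openai.OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed")

    p = pyaudio.PyAudio()
    # Assumed output format: 16-bit signed mono PCM at 24 kHz
    stream = p.open(format=pyaudio.paInt16, channels=1, rate=24000, output=True)

    with client.audio.speech.with_streaming_response.create(
        model="kokoro",
        voice="af_bella+af_irulan",
        response_format="pcm",
        input="Streaming straight to the speakers.",
    ) as response:
        for chunk in response.iter_bytes(chunk_size=1024):
            stream.write(chunk)

    stream.stop_stream()
    stream.close()
    p.terminate()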
@@ -1,5 +1,5 @@
 import json
-from typing import Tuple, Optional
+from typing import Tuple, Optional, Union, List
 from pathlib import Path

 import requests

@@ -22,7 +22,7 @@ def get_phonemes(text: str, language: str = "a") -> Tuple[str, list[int]]:
     payload = {"text": text, "language": language}

     # Make POST request to the phonemize endpoint
-    response = requests.post("http://localhost:8880/text/phonemize", json=payload)
+    response = requests.post("http://localhost:8880/dev/phonemize", json=payload)

     # Raise exception for error status codes
     response.raise_for_status()
@@ -32,44 +32,29 @@ def get_phonemes(text: str, language: str = "a") -> Tuple[str, list[int]]:
     return result["phonemes"], result["tokens"]


-def generate_audio_from_phonemes(
-    phonemes: str, voice: str = "af_bella", speed: float = 1.0
-) -> Optional[bytes]:
-    """Generate audio from phonemes.
-
-    Args:
-        phonemes: Phoneme string to synthesize
-        voice: Voice ID to use (defaults to af_bella)
-        speed: Speed factor (defaults to 1.0)
-
-    Returns:
-        WAV audio bytes if successful, None if failed
-    """
-    # Create the request payload
-    payload = {
-        "phonemes": phonemes,
-        "voice": voice,
-        "speed": speed,
-        "stitch_long_content": True  # Default to false to get the error message
-    }
-
-    try:
-        # Make POST request to generate audio
-        response = requests.post(
-            "http://localhost:8880/text/generate_from_phonemes", json=payload
-        )
-        response.raise_for_status()
-        return response.content
-    except requests.HTTPError as e:
-        # Get the error details from the response
-        try:
-            error_details = response.json()
-            error_msg = error_details.get('detail', {}).get('message', str(e))
-            print(f"Server Error: {error_msg}")
-        except:
-            print(f"Error: {e}")
-        return None
+def generate_audio_from_phonemes(phonemes: str, voice: str = "af_bella") -> Optional[bytes]:
+    """Generate audio from phonemes."""
+    response = requests.post(
+        "http://localhost:8880/dev/generate_from_phonemes",
+        json={"phonemes": phonemes, "voice": voice},
+        headers={"Accept": "audio/wav"}
+    )
+    print(f"Response status: {response.status_code}")
+    print(f"Response headers: {dict(response.headers)}")
+    print(f"Response content type: {response.headers.get('Content-Type')}")
+    print(f"Response length: {len(response.content)} bytes")
+
+    if response.status_code != 200:
+        print(f"Error response: {response.text}")
+        return None
+
+    if not response.content:
+        print("Error: Empty response content")
+        return None
+
+    return response.content


 def main():
     # Example texts to convert

@@ -103,11 +88,15 @@ def main():

         # Generate audio from phonemes
         print("Generating audio...")
-        if len(phonemes) > 500:  # split into arrays of 500 phonemes
-            phonemes = [phonemes[i:i+500] for i in range(0, len(phonemes), 500)]
-
         audio_bytes = generate_audio_from_phonemes(phonemes)
+        if not audio_bytes:
+            print("Error: No audio data generated")
+            continue
+
+        # Log response size
+        print(f"Generated {len(audio_bytes)} bytes of audio data")

         if audio_bytes:
             # Save audio file
             output_path = output_dir / f"example_{i+1}.wav"
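Note: taken together, the two helpers above give a phoneme round trip. A short usage sketch of the intended flow (the text and output path are illustrative):

    from pathlib import Path

    # 1. Text -> phonemes + tokens via /dev/phonemize
    phonemes, tokens = get_phonemes("Hello world!", language="a")
    print(f"{len(tokens)} tokens: {phonemes}")

    # 2. Phonemes -> WAV bytes via /dev/generate_from_phonemes
    audio_bytes = generate_audio_from_phonemes(phonemes, voice="af_bella")

    # 3. Save if generation succeeded
    if audio_bytes:
        out = Path("output") / "phoneme_demo.wav"
        out.parent.mkdir(parents=True, exist_ok=True)
        out.write_bytes(audio_bytes)
        print(f"Saved {len(audio_bytes)} bytes to {out}")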
18
uv.lock
generated
@@ -1016,6 +1016,7 @@ dependencies = [
     { name = "openai" },
     { name = "phonemizer" },
     { name = "psutil" },
+    { name = "pyaudio" },
     { name = "pydantic" },
     { name = "pydantic-settings" },
     { name = "pydub" },

@@ -1070,6 +1071,7 @@ requires-dist = [
     { name = "openai", marker = "extra == 'test'", specifier = ">=1.59.6" },
     { name = "phonemizer", specifier = "==3.3.0" },
     { name = "psutil", specifier = ">=6.1.1" },
+    { name = "pyaudio", specifier = ">=0.2.14" },
     { name = "pydantic", specifier = "==2.10.4" },
     { name = "pydantic-settings", specifier = "==2.7.0" },
     { name = "pydub", specifier = ">=0.25.1" },

@@ -2128,6 +2130,22 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/7b/d7/7831438e6c3ebbfa6e01a927127a6cb42ad3ab844247f3c5b96bea25d73d/psutil-6.1.1-cp37-abi3-win_amd64.whl", hash = "sha256:f35cfccb065fff93529d2afb4a2e89e363fe63ca1e4a5da22b603a85833c2649", size = 254444 },
 ]

+[[package]]
+name = "pyaudio"
+version = "0.2.14"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/26/1d/8878c7752febb0f6716a7e1a52cb92ac98871c5aa522cba181878091607c/PyAudio-0.2.14.tar.gz", hash = "sha256:78dfff3879b4994d1f4fc6485646a57755c6ee3c19647a491f790a0895bd2f87", size = 47066 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/90/90/1553487277e6aa25c0b7c2c38709cdd2b49e11c66c0b25c6e8b7b6638c72/PyAudio-0.2.14-cp310-cp310-win32.whl", hash = "sha256:126065b5e82a1c03ba16e7c0404d8f54e17368836e7d2d92427358ad44fefe61", size = 144624 },
+    { url = "https://files.pythonhosted.org/packages/27/bc/719d140ee63cf4b0725016531d36743a797ffdbab85e8536922902c9349a/PyAudio-0.2.14-cp310-cp310-win_amd64.whl", hash = "sha256:2a166fc88d435a2779810dd2678354adc33499e9d4d7f937f28b20cc55893e83", size = 164069 },
+    { url = "https://files.pythonhosted.org/packages/7b/f0/b0eab89eafa70a86b7b566a4df2f94c7880a2d483aa8de1c77d335335b5b/PyAudio-0.2.14-cp311-cp311-win32.whl", hash = "sha256:506b32a595f8693811682ab4b127602d404df7dfc453b499c91a80d0f7bad289", size = 144624 },
+    { url = "https://files.pythonhosted.org/packages/82/d8/f043c854aad450a76e476b0cf9cda1956419e1dacf1062eb9df3c0055abe/PyAudio-0.2.14-cp311-cp311-win_amd64.whl", hash = "sha256:bbeb01d36a2f472ae5ee5e1451cacc42112986abe622f735bb870a5db77cf903", size = 164070 },
+    { url = "https://files.pythonhosted.org/packages/8d/45/8d2b76e8f6db783f9326c1305f3f816d4a12c8eda5edc6a2e1d03c097c3b/PyAudio-0.2.14-cp312-cp312-win32.whl", hash = "sha256:5fce4bcdd2e0e8c063d835dbe2860dac46437506af509353c7f8114d4bacbd5b", size = 144750 },
+    { url = "https://files.pythonhosted.org/packages/b0/6a/d25812e5f79f06285767ec607b39149d02aa3b31d50c2269768f48768930/PyAudio-0.2.14-cp312-cp312-win_amd64.whl", hash = "sha256:12f2f1ba04e06ff95d80700a78967897a489c05e093e3bffa05a84ed9c0a7fa3", size = 164126 },
+    { url = "https://files.pythonhosted.org/packages/3a/77/66cd37111a87c1589b63524f3d3c848011d21ca97828422c7fde7665ff0d/PyAudio-0.2.14-cp313-cp313-win32.whl", hash = "sha256:95328285b4dab57ea8c52a4a996cb52be6d629353315be5bfda403d15932a497", size = 150982 },
+    { url = "https://files.pythonhosted.org/packages/a5/8b/7f9a061c1cc2b230f9ac02a6003fcd14c85ce1828013aecbaf45aa988d20/PyAudio-0.2.14-cp313-cp313-win_amd64.whl", hash = "sha256:692d8c1446f52ed2662120bcd9ddcb5aa2b71f38bda31e58b19fb4672fffba69", size = 173655 },
+]

 [[package]]
 name = "pycparser"
 version = "2.22"