Refactor Docker configurations and update test mocks for development routers

This commit is contained in:
remsky 2025-01-10 22:03:16 -07:00
parent e8c1284032
commit 926ea8cecf
11 changed files with 63 additions and 49 deletions

1
.gitignore vendored
View file

@ -6,7 +6,6 @@ ui/data/*
*.db *.db
*.pyc *.pyc
*.pth *.pth
*.pt
Kokoro-82M/* Kokoro-82M/*
__pycache__/ __pycache__/

View file

@ -93,6 +93,7 @@ async def create_speech(
"Content-Disposition": f"attachment; filename=speech.{request.response_format}", "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
"X-Accel-Buffering": "no", # Disable proxy buffering "X-Accel-Buffering": "no", # Disable proxy buffering
"Cache-Control": "no-cache", # Prevent caching "Cache-Control": "no-cache", # Prevent caching
"Transfer-Encoding": "chunked", # Enable chunked transfer encoding
}, },
) )
else: else:

View file

@ -104,7 +104,7 @@ class AudioService:
# Raw 16-bit PCM samples, no header # Raw 16-bit PCM samples, no header
buffer.write(normalized_audio.tobytes()) buffer.write(normalized_audio.tobytes())
elif output_format == "wav": elif output_format == "wav":
# Always use soundfile for WAV to ensure proper headers and normalization # WAV format with headers
sf.write( sf.write(
buffer, buffer,
normalized_audio, normalized_audio,
@ -113,14 +113,14 @@ class AudioService:
subtype="PCM_16", subtype="PCM_16",
) )
elif output_format == "mp3": elif output_format == "mp3":
# Use format settings or defaults # MP3 format with proper framing
settings = format_settings.get("mp3", {}) if format_settings else {} settings = format_settings.get("mp3", {}) if format_settings else {}
settings = {**AudioService.DEFAULT_SETTINGS["mp3"], **settings} settings = {**AudioService.DEFAULT_SETTINGS["mp3"], **settings}
sf.write( sf.write(
buffer, normalized_audio, sample_rate, format="MP3", **settings buffer, normalized_audio, sample_rate, format="MP3", **settings
) )
elif output_format == "opus": elif output_format == "opus":
# Opus format in OGG container
settings = format_settings.get("opus", {}) if format_settings else {} settings = format_settings.get("opus", {}) if format_settings else {}
settings = {**AudioService.DEFAULT_SETTINGS["opus"], **settings} settings = {**AudioService.DEFAULT_SETTINGS["opus"], **settings}
sf.write( sf.write(
@ -131,8 +131,8 @@ class AudioService:
subtype="OPUS", subtype="OPUS",
**settings, **settings,
) )
elif output_format == "flac": elif output_format == "flac":
# FLAC format with proper framing
if is_first_chunk: if is_first_chunk:
logger.info("Starting FLAC stream...") logger.info("Starting FLAC stream...")
settings = format_settings.get("flac", {}) if format_settings else {} settings = format_settings.get("flac", {}) if format_settings else {}
@ -145,15 +145,14 @@ class AudioService:
subtype="PCM_16", subtype="PCM_16",
**settings, **settings,
) )
elif output_format == "aac":
raise ValueError(
"Format aac not currently supported. Supported formats are: wav, mp3, opus, flac, pcm."
)
else: else:
if output_format == "aac": raise ValueError(
raise ValueError( f"Format {output_format} not supported. Supported formats are: wav, mp3, opus, flac, pcm, aac."
"Format aac not supported. Supported formats are: wav, mp3, opus, flac, pcm." )
)
else:
raise ValueError(
f"Format {output_format} not supported. Supported formats are: wav, mp3, opus, flac, pcm."
)
buffer.seek(0) buffer.seek(0)
return buffer.getvalue() return buffer.getvalue()

View file

@ -177,7 +177,7 @@ class TTSService:
) )
if chunk_audio is not None: if chunk_audio is not None:
# Convert chunk with proper header handling # Convert chunk with proper streaming header handling
chunk_bytes = AudioService.convert_audio( chunk_bytes = AudioService.convert_audio(
chunk_audio, chunk_audio,
24000, 24000,
@ -185,6 +185,7 @@ class TTSService:
is_first_chunk=is_first, is_first_chunk=is_first,
normalizer=stream_normalizer, normalizer=stream_normalizer,
is_last_chunk=(next_chunk is None), # Last if no next chunk is_last_chunk=(next_chunk is None), # Last if no next chunk
stream=True # Ensure proper streaming format handling
) )
yield chunk_bytes yield chunk_bytes

View file

@ -181,7 +181,7 @@ def mock_tts_service(monkeypatch):
# Mock TTSModel.generate_from_tokens since we call it directly # Mock TTSModel.generate_from_tokens since we call it directly
mock_generate = Mock(return_value=np.zeros(48000)) mock_generate = Mock(return_value=np.zeros(48000))
monkeypatch.setattr( monkeypatch.setattr(
"api.src.routers.text_processing.TTSModel.generate_from_tokens", mock_generate "api.src.routers.development.TTSModel.generate_from_tokens", mock_generate
) )
return mock_service return mock_service
@ -192,5 +192,5 @@ def mock_audio_service(monkeypatch):
"""Mock AudioService""" """Mock AudioService"""
mock_service = Mock() mock_service = Mock()
mock_service.convert_audio.return_value = b"mock audio data" mock_service.convert_audio.return_value = b"mock audio data"
monkeypatch.setattr("api.src.routers.text_processing.AudioService", mock_service) monkeypatch.setattr("api.src.routers.development.AudioService", mock_service)
return mock_service return mock_service

View file

@ -63,7 +63,7 @@ def test_convert_to_aac_raises_error(sample_audio):
audio_data, sample_rate = sample_audio audio_data, sample_rate = sample_audio
with pytest.raises( with pytest.raises(
ValueError, ValueError,
match="Format aac not supported. Supported formats are: wav, mp3, opus, flac, pcm.", match="Failed to convert audio to aac: Format aac not currently supported. Supported formats are: wav, mp3, opus, flac, pcm.",
): ):
AudioService.convert_audio(audio_data, sample_rate, "aac") AudioService.convert_audio(audio_data, sample_rate, "aac")

View file

@ -20,8 +20,8 @@ async def async_client():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_phonemize_endpoint(async_client): async def test_phonemize_endpoint(async_client):
"""Test phoneme generation endpoint""" """Test phoneme generation endpoint"""
with patch("api.src.routers.text_processing.phonemize") as mock_phonemize, patch( with patch("api.src.routers.development.phonemize") as mock_phonemize, patch(
"api.src.routers.text_processing.tokenize" "api.src.routers.development.tokenize"
) as mock_tokenize: ) as mock_tokenize:
# Setup mocks # Setup mocks
mock_phonemize.return_value = "həlˈ" mock_phonemize.return_value = "həlˈ"
@ -56,7 +56,7 @@ async def test_generate_from_phonemes(
): ):
"""Test audio generation from phonemes""" """Test audio generation from phonemes"""
with patch( with patch(
"api.src.routers.text_processing.TTSService", return_value=mock_tts_service "api.src.routers.development.TTSService", return_value=mock_tts_service
): ):
response = await async_client.post( response = await async_client.post(
"/text/generate_from_phonemes", "/text/generate_from_phonemes",
@ -76,7 +76,7 @@ async def test_generate_from_phonemes_invalid_voice(async_client, mock_tts_servi
"""Test audio generation with invalid voice""" """Test audio generation with invalid voice"""
mock_tts_service._get_voice_path.return_value = None mock_tts_service._get_voice_path.return_value = None
with patch( with patch(
"api.src.routers.text_processing.TTSService", return_value=mock_tts_service "api.src.routers.development.TTSService", return_value=mock_tts_service
): ):
response = await async_client.post( response = await async_client.post(
"/text/generate_from_phonemes", "/text/generate_from_phonemes",
@ -111,7 +111,7 @@ async def test_generate_from_phonemes_invalid_speed(async_client, monkeypatch):
async def test_generate_from_phonemes_empty_phonemes(async_client, mock_tts_service): async def test_generate_from_phonemes_empty_phonemes(async_client, mock_tts_service):
"""Test audio generation with empty phonemes""" """Test audio generation with empty phonemes"""
with patch( with patch(
"api.src.routers.text_processing.TTSService", return_value=mock_tts_service "api.src.routers.development.TTSService", return_value=mock_tts_service
): ):
response = await async_client.post( response = await async_client.post(
"/text/generate_from_phonemes", "/text/generate_from_phonemes",

View file

@ -26,11 +26,11 @@ services:
start_period: 1s start_period: 1s
kokoro-tts: kokoro-tts:
# image: ghcr.io/remsky/kokoro-fastapi:latest-cpu image: ghcr.io/remsky/kokoro-fastapi:latest-cpu
# Uncomment below to build from source instead of using the released image # Uncomment below (and comment out above) to build from source instead of using the released image
build: # build:
context: . # context: .
dockerfile: Dockerfile.cpu # dockerfile: Dockerfile.cpu
volumes: volumes:
- ./api/src:/app/api/src - ./api/src:/app/api/src
- ./Kokoro-82M:/app/Kokoro-82M - ./Kokoro-82M:/app/Kokoro-82M
@ -52,8 +52,8 @@ services:
# Gradio UI service [Comment out everything below if you don't need it] # Gradio UI service [Comment out everything below if you don't need it]
gradio-ui: gradio-ui:
# image: ghcr.io/remsky/kokoro-fastapi:latest-ui image: ghcr.io/remsky/kokoro-fastapi:latest-ui
# Uncomment below to build from source instead of using the released image # Uncomment below (and comment out above) to build from source instead of using the released image
build: build:
context: ./ui context: ./ui
ports: ports:
@ -63,3 +63,4 @@ services:
- ./ui/app.py:/app/app.py # Mount app.py for hot reload - ./ui/app.py:/app/app.py # Mount app.py for hot reload
environment: environment:
- GRADIO_WATCH=True # Enable hot reloading - GRADIO_WATCH=True # Enable hot reloading
- PYTHONUNBUFFERED=1 # Ensure Python output is not buffered

View file

@ -32,10 +32,10 @@ services:
start_period: 1s start_period: 1s
kokoro-tts: kokoro-tts:
# image: ghcr.io/remsky/kokoro-fastapi:latest image: ghcr.io/remsky/kokoro-fastapi:latest
# Uncomment below to build from source instead of using the released image # Uncomment below (and comment out above) to build from source instead of using the released image
build: # build:
context: . # context: .
volumes: volumes:
- ./api/src:/app/api/src - ./api/src:/app/api/src
- ./Kokoro-82M:/app/Kokoro-82M - ./Kokoro-82M:/app/Kokoro-82M
@ -50,16 +50,22 @@ services:
- driver: nvidia - driver: nvidia
count: 1 count: 1
capabilities: [gpu] capabilities: [gpu]
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8880/v1/audio/voices"]
interval: 10s
timeout: 5s
retries: 30
start_period: 30s
depends_on: depends_on:
model-fetcher: model-fetcher:
condition: service_healthy condition: service_healthy
# Gradio UI service [Comment out everything below if you don't need it] # Gradio UI service [Comment out everything below if you don't need it]
gradio-ui: gradio-ui:
# image: ghcr.io/remsky/kokoro-fastapi:latest-ui image: ghcr.io/remsky/kokoro-fastapi:latest-ui
# Uncomment below to build from source instead of using the released image # Uncomment below (and comment out above) to build from source instead of using the released image
build: # build:
context: ./ui # context: ./ui
ports: ports:
- "7860:7860" - "7860:7860"
volumes: volumes:
@ -67,3 +73,7 @@ services:
- ./ui/app.py:/app/app.py # Mount app.py for hot reload - ./ui/app.py:/app/app.py # Mount app.py for hot reload
environment: environment:
- GRADIO_WATCH=True # Enable hot reloading - GRADIO_WATCH=True # Enable hot reloading
- PYTHONUNBUFFERED=1 # Ensure Python output is not buffered
depends_on:
kokoro-tts:
condition: service_healthy

View file

@ -1,23 +1,25 @@
#!/usr/bin/env rye run python #!/usr/bin/env rye run python
# %%
import time import time
from pathlib import Path from pathlib import Path
from openai import OpenAI from openai import OpenAI
# gets OPENAI_API_KEY from your environment variables # gets OPENAI_API_KEY from your environment variables
openai = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed-for-local") openai = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed-for-local")
speech_file_path = Path(__file__).parent / "speech.mp3" speech_file_path = Path(__file__).parent / "speech.mp3"
def main() -> None: def main() -> None:
stream_to_speakers() stream_to_speakers()
# Create text-to-speech audio file
with openai.audio.speech.with_streaming_response.create(
model="kokoro",
voice="af_bella",
input="the quick brown fox jumped over the lazy dogs",
) as response:
response.stream_to_file(speech_file_path)
def stream_to_speakers() -> None: def stream_to_speakers() -> None:
@ -31,9 +33,12 @@ def stream_to_speakers() -> None:
with openai.audio.speech.with_streaming_response.create( with openai.audio.speech.with_streaming_response.create(
model="kokoro", model="kokoro",
voice=VOICE, voice="af_bella",
response_format="mp3", # similar to WAV, but without a header chunk at the start. response_format="pcm", # similar to WAV, but without a header chunk at the start.
input="""My dear sir, that is just where you are wrong. That is just where the whole world has gone wrong. We are always getting away from the present moment. Our mental existences, which are immaterial and have no dimensions, are passing along the Time-Dimension""", input="""I see skies of blue and clouds of white
The bright blessed days, the dark sacred nights
And I think to myself
What a wonderful world""",
) as response: ) as response:
print(f"Time to first byte: {int((time.time() - start_time) * 1000)}ms") print(f"Time to first byte: {int((time.time() - start_time) * 1000)}ms")
for chunk in response.iter_bytes(chunk_size=1024): for chunk in response.iter_bytes(chunk_size=1024):
@ -44,5 +49,3 @@ def stream_to_speakers() -> None:
if __name__ == "__main__": if __name__ == "__main__":
main() main()
# %%

View file

@ -74,7 +74,7 @@ def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"):
all_audio_data.extend(chunk) all_audio_data.extend(chunk)
# Log progress every 10 chunks # Log progress every 10 chunks
if chunk_count % 10 == 0: if chunk_count % 100 == 0:
elapsed = time.time() - start_time elapsed = time.time() - start_time
print( print(
f"Progress: {chunk_count} chunks, {total_bytes/1024:.1f}KB received, {elapsed:.1f}s elapsed" f"Progress: {chunk_count} chunks, {total_bytes/1024:.1f}KB received, {elapsed:.1f}s elapsed"