Mirror of https://github.com/remsky/Kokoro-FastAPI.git (synced 2025-04-13 09:39:17 +00:00)

Refactor Docker configurations and update test mocks for development routers

parent e8c1284032, commit 926ea8cecf

11 changed files with 63 additions and 49 deletions
.gitignore (vendored, 1 change)

@@ -6,7 +6,6 @@ ui/data/*
 *.db
 *.pyc
 *.pth
-*.pt

 Kokoro-82M/*
 __pycache__/
@@ -93,6 +93,7 @@ async def create_speech(
                "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
                "X-Accel-Buffering": "no",  # Disable proxy buffering
                "Cache-Control": "no-cache",  # Prevent caching
+               "Transfer-Encoding": "chunked",  # Enable chunked transfer encoding
            },
        )
    else:
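For orientation, here is a minimal sketch of how these streaming headers typically sit on a FastAPI StreamingResponse; the route path, media type, and generator below are illustrative assumptions, not code from this repository.

```python
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


async def fake_audio_chunks():
    """Hypothetical stand-in for the TTS audio generator."""
    for _ in range(3):
        yield b"\x00" * 1024


@app.post("/v1/audio/speech")
async def create_speech_sketch():
    return StreamingResponse(
        fake_audio_chunks(),
        media_type="audio/mpeg",
        headers={
            "Content-Disposition": "attachment; filename=speech.mp3",
            "X-Accel-Buffering": "no",  # Disable proxy buffering
            "Cache-Control": "no-cache",  # Prevent caching
            "Transfer-Encoding": "chunked",  # Enable chunked transfer encoding
        },
    )
```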
@@ -104,7 +104,7 @@ class AudioService:
            # Raw 16-bit PCM samples, no header
            buffer.write(normalized_audio.tobytes())
        elif output_format == "wav":
-           # Always use soundfile for WAV to ensure proper headers and normalization
+           # WAV format with headers
            sf.write(
                buffer,
                normalized_audio,
@@ -113,14 +113,14 @@ class AudioService:
                subtype="PCM_16",
            )
        elif output_format == "mp3":
-           # Use format settings or defaults
+           # MP3 format with proper framing
            settings = format_settings.get("mp3", {}) if format_settings else {}
            settings = {**AudioService.DEFAULT_SETTINGS["mp3"], **settings}
            sf.write(
                buffer, normalized_audio, sample_rate, format="MP3", **settings
            )

        elif output_format == "opus":
+           # Opus format in OGG container
            settings = format_settings.get("opus", {}) if format_settings else {}
            settings = {**AudioService.DEFAULT_SETTINGS["opus"], **settings}
            sf.write(
@@ -131,8 +131,8 @@ class AudioService:
                subtype="OPUS",
                **settings,
            )

        elif output_format == "flac":
+           # FLAC format with proper framing
            if is_first_chunk:
                logger.info("Starting FLAC stream...")
            settings = format_settings.get("flac", {}) if format_settings else {}
@@ -145,15 +145,14 @@ class AudioService:
                subtype="PCM_16",
                **settings,
            )
+       elif output_format == "aac":
+           raise ValueError(
+               "Format aac not currently supported. Supported formats are: wav, mp3, opus, flac, pcm."
+           )
        else:
-           if output_format == "aac":
-               raise ValueError(
-                   "Format aac not supported. Supported formats are: wav, mp3, opus, flac, pcm."
-               )
-           else:
-               raise ValueError(
-                   f"Format {output_format} not supported. Supported formats are: wav, mp3, opus, flac, pcm."
-               )
+           raise ValueError(
+               f"Format {output_format} not supported. Supported formats are: wav, mp3, opus, flac, pcm, aac."
+           )

        buffer.seek(0)
        return buffer.getvalue()
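As a rough, self-contained illustration of the format dispatch these hunks adjust, the sketch below writes a NumPy array to an in-memory buffer with soundfile; the function name and the lack of per-format settings are assumptions, and MP3 output requires a libsndfile build with MPEG support.

```python
from io import BytesIO

import numpy as np
import soundfile as sf


def convert_audio_sketch(audio: np.ndarray, sample_rate: int, output_format: str) -> bytes:
    """Encode an int16 audio array into the requested container (illustrative only)."""
    buffer = BytesIO()
    if output_format == "pcm":
        # Raw 16-bit PCM samples, no header
        buffer.write(audio.astype(np.int16).tobytes())
    elif output_format == "wav":
        sf.write(buffer, audio, sample_rate, format="WAV", subtype="PCM_16")
    elif output_format == "mp3":
        sf.write(buffer, audio, sample_rate, format="MP3")
    elif output_format == "opus":
        sf.write(buffer, audio, sample_rate, format="OGG", subtype="OPUS")
    elif output_format == "flac":
        sf.write(buffer, audio, sample_rate, format="FLAC", subtype="PCM_16")
    elif output_format == "aac":
        raise ValueError(
            "Format aac not currently supported. Supported formats are: wav, mp3, opus, flac, pcm."
        )
    else:
        raise ValueError(
            f"Format {output_format} not supported. Supported formats are: wav, mp3, opus, flac, pcm, aac."
        )
    buffer.seek(0)
    return buffer.getvalue()


# Example: one second of silence at 24 kHz encoded as FLAC
flac_bytes = convert_audio_sketch(np.zeros(24000, dtype=np.int16), 24000, "flac")
```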
@@ -177,7 +177,7 @@ class TTSService:
                    )

                    if chunk_audio is not None:
-                       # Convert chunk with proper header handling
+                       # Convert chunk with proper streaming header handling
                        chunk_bytes = AudioService.convert_audio(
                            chunk_audio,
                            24000,
@@ -185,6 +185,7 @@ class TTSService:
                            is_first_chunk=is_first,
                            normalizer=stream_normalizer,
                            is_last_chunk=(next_chunk is None),  # Last if no next chunk
+                           stream=True  # Ensure proper streaming format handling
                        )

                        yield chunk_bytes
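The one-chunk lookahead used above, is_last_chunk=(next_chunk is None), can be sketched in isolation as follows; the generator and the convert callable are hypothetical stand-ins for TTSService internals and AudioService.convert_audio, not code from this commit.

```python
from typing import AsyncIterator, Callable, Iterable, Optional

import numpy as np


async def stream_converted_chunks(
    chunks: Iterable[np.ndarray],
    convert: Callable[..., bytes],
) -> AsyncIterator[bytes]:
    """Yield encoded chunks, flagging the first and last one (illustrative sketch)."""
    iterator = iter(chunks)
    current: Optional[np.ndarray] = next(iterator, None)
    is_first = True
    while current is not None:
        next_chunk = next(iterator, None)  # Look ahead so the last chunk can be flagged
        yield convert(
            current,
            24000,
            is_first_chunk=is_first,
            is_last_chunk=(next_chunk is None),
            stream=True,
        )
        is_first = False
        current = next_chunk
```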
@@ -181,7 +181,7 @@ def mock_tts_service(monkeypatch):
    # Mock TTSModel.generate_from_tokens since we call it directly
    mock_generate = Mock(return_value=np.zeros(48000))
    monkeypatch.setattr(
-       "api.src.routers.text_processing.TTSModel.generate_from_tokens", mock_generate
+       "api.src.routers.development.TTSModel.generate_from_tokens", mock_generate
    )

    return mock_service
@@ -192,5 +192,5 @@ def mock_audio_service(monkeypatch):
    """Mock AudioService"""
    mock_service = Mock()
    mock_service.convert_audio.return_value = b"mock audio data"
-   monkeypatch.setattr("api.src.routers.text_processing.AudioService", mock_service)
+   monkeypatch.setattr("api.src.routers.development.AudioService", mock_service)
    return mock_service
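The fixture changes above boil down to patching names where they are looked up: once the endpoints moved from a text_processing router to a development router, the dotted path handed to monkeypatch.setattr (or unittest.mock.patch) has to follow. A minimal sketch, with the module path taken from the diff and the fixture name invented for illustration:

```python
from unittest.mock import Mock

import numpy as np
import pytest


@pytest.fixture
def mock_generate_from_tokens(monkeypatch):
    """Patch TTSModel.generate_from_tokens where the development router imports it."""
    mock_generate = Mock(return_value=np.zeros(48000))
    monkeypatch.setattr(
        "api.src.routers.development.TTSModel.generate_from_tokens", mock_generate
    )
    return mock_generate
```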
@@ -63,7 +63,7 @@ def test_convert_to_aac_raises_error(sample_audio):
    audio_data, sample_rate = sample_audio
    with pytest.raises(
        ValueError,
-       match="Format aac not supported. Supported formats are: wav, mp3, opus, flac, pcm.",
+       match="Failed to convert audio to aac: Format aac not currently supported. Supported formats are: wav, mp3, opus, flac, pcm.",
    ):
        AudioService.convert_audio(audio_data, sample_rate, "aac")

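One detail worth noting about the updated assertion: pytest.raises(match=...) applies re.search to the exception text, so literal dots in a message like the one above match loosely unless they are escaped. A small illustrative sketch (the wrapper function here is invented to mimic how a service layer might re-raise the error):

```python
import re

import pytest


def convert_audio_stub(audio: bytes, sample_rate: int, output_format: str) -> bytes:
    """Invented stand-in that wraps the low-level error with extra context."""
    try:
        raise ValueError(
            "Format aac not currently supported. Supported formats are: wav, mp3, opus, flac, pcm."
        )
    except ValueError as exc:
        raise ValueError(f"Failed to convert audio to {output_format}: {exc}") from exc


def test_aac_raises_wrapped_error():
    with pytest.raises(
        ValueError,
        match=re.escape("Failed to convert audio to aac: Format aac not currently supported"),
    ):
        convert_audio_stub(b"", 24000, "aac")
```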
@@ -20,8 +20,8 @@ async def async_client():
 @pytest.mark.asyncio
 async def test_phonemize_endpoint(async_client):
     """Test phoneme generation endpoint"""
-    with patch("api.src.routers.text_processing.phonemize") as mock_phonemize, patch(
-        "api.src.routers.text_processing.tokenize"
+    with patch("api.src.routers.development.phonemize") as mock_phonemize, patch(
+        "api.src.routers.development.tokenize"
     ) as mock_tokenize:
         # Setup mocks
         mock_phonemize.return_value = "həlˈoʊ"
@@ -56,7 +56,7 @@ async def test_generate_from_phonemes(
 ):
     """Test audio generation from phonemes"""
     with patch(
-        "api.src.routers.text_processing.TTSService", return_value=mock_tts_service
+        "api.src.routers.development.TTSService", return_value=mock_tts_service
     ):
         response = await async_client.post(
             "/text/generate_from_phonemes",
@@ -76,7 +76,7 @@ async def test_generate_from_phonemes_invalid_voice(async_client, mock_tts_servi
     """Test audio generation with invalid voice"""
     mock_tts_service._get_voice_path.return_value = None
     with patch(
-        "api.src.routers.text_processing.TTSService", return_value=mock_tts_service
+        "api.src.routers.development.TTSService", return_value=mock_tts_service
     ):
         response = await async_client.post(
             "/text/generate_from_phonemes",
@@ -111,7 +111,7 @@ async def test_generate_from_phonemes_invalid_speed(async_client, monkeypatch):
 async def test_generate_from_phonemes_empty_phonemes(async_client, mock_tts_service):
     """Test audio generation with empty phonemes"""
     with patch(
-        "api.src.routers.text_processing.TTSService", return_value=mock_tts_service
+        "api.src.routers.development.TTSService", return_value=mock_tts_service
     ):
         response = await async_client.post(
             "/text/generate_from_phonemes",
@@ -26,11 +26,11 @@ services:
       start_period: 1s

   kokoro-tts:
-    # image: ghcr.io/remsky/kokoro-fastapi:latest-cpu
-    # Uncomment below to build from source instead of using the released image
-    build:
-      context: .
-      dockerfile: Dockerfile.cpu
+    image: ghcr.io/remsky/kokoro-fastapi:latest-cpu
+    # Uncomment below (and comment out above) to build from source instead of using the released image
+    # build:
+    #   context: .
+    #   dockerfile: Dockerfile.cpu
     volumes:
       - ./api/src:/app/api/src
       - ./Kokoro-82M:/app/Kokoro-82M
@@ -52,8 +52,8 @@ services:

   # Gradio UI service [Comment out everything below if you don't need it]
   gradio-ui:
-    # image: ghcr.io/remsky/kokoro-fastapi:latest-ui
-    # Uncomment below to build from source instead of using the released image
+    image: ghcr.io/remsky/kokoro-fastapi:latest-ui
+    # Uncomment below (and comment out above) to build from source instead of using the released image
     build:
       context: ./ui
     ports:
@@ -63,3 +63,4 @@ services:
       - ./ui/app.py:/app/app.py  # Mount app.py for hot reload
     environment:
       - GRADIO_WATCH=True  # Enable hot reloading
+      - PYTHONUNBUFFERED=1  # Ensure Python output is not buffered
@@ -32,10 +32,10 @@ services:
       start_period: 1s

   kokoro-tts:
-    # image: ghcr.io/remsky/kokoro-fastapi:latest
-    # Uncomment below to build from source instead of using the released image
-    build:
-      context: .
+    image: ghcr.io/remsky/kokoro-fastapi:latest
+    # Uncomment below (and comment out above) to build from source instead of using the released image
+    # build:
+    #   context: .
     volumes:
       - ./api/src:/app/api/src
       - ./Kokoro-82M:/app/Kokoro-82M
|
@ -50,16 +50,22 @@ services:
|
||||||
- driver: nvidia
|
- driver: nvidia
|
||||||
count: 1
|
count: 1
|
||||||
capabilities: [gpu]
|
capabilities: [gpu]
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "curl", "-f", "http://localhost:8880/v1/audio/voices"]
|
||||||
|
interval: 10s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 30
|
||||||
|
start_period: 30s
|
||||||
depends_on:
|
depends_on:
|
||||||
model-fetcher:
|
model-fetcher:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
|
|
||||||
# Gradio UI service [Comment out everything below if you don't need it]
|
# Gradio UI service [Comment out everything below if you don't need it]
|
||||||
gradio-ui:
|
gradio-ui:
|
||||||
# image: ghcr.io/remsky/kokoro-fastapi:latest-ui
|
image: ghcr.io/remsky/kokoro-fastapi:latest-ui
|
||||||
# Uncomment below to build from source instead of using the released image
|
# Uncomment below (and comment out above) to build from source instead of using the released image
|
||||||
build:
|
# build:
|
||||||
context: ./ui
|
# context: ./ui
|
||||||
ports:
|
ports:
|
||||||
- "7860:7860"
|
- "7860:7860"
|
||||||
volumes:
|
volumes:
|
||||||
|
@@ -67,3 +67,7 @@ services:
       - ./ui/app.py:/app/app.py  # Mount app.py for hot reload
     environment:
       - GRADIO_WATCH=True  # Enable hot reloading
+      - PYTHONUNBUFFERED=1  # Ensure Python output is not buffered
+    depends_on:
+      kokoro-tts:
+        condition: service_healthy
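The healthcheck added above polls GET /v1/audio/voices on port 8880 via curl; the same readiness check can be approximated from Python with only the standard library. The function name and retry defaults below are illustrative, chosen to mirror the compose settings rather than taken from this repository.

```python
import time
import urllib.error
import urllib.request


def wait_for_kokoro(
    url: str = "http://localhost:8880/v1/audio/voices",
    retries: int = 30,
    interval: float = 10.0,
) -> bool:
    """Poll the voices endpoint until the TTS container answers with HTTP 200."""
    for _ in range(retries):
        try:
            with urllib.request.urlopen(url, timeout=5) as response:
                if response.status == 200:
                    return True
        except (urllib.error.URLError, TimeoutError):
            pass
        time.sleep(interval)
    return False


if __name__ == "__main__":
    print("kokoro-tts healthy" if wait_for_kokoro() else "kokoro-tts not reachable")
```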
@@ -1,23 +1,25 @@
 #!/usr/bin/env rye run python
-# %%
 import time
 from pathlib import Path

 from openai import OpenAI

 # gets OPENAI_API_KEY from your environment variables
-openai = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed-for-local")
+openai = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed-for-local")

 speech_file_path = Path(__file__).parent / "speech.mp3"


-
-
-
-
 def main() -> None:
     stream_to_speakers()

+    # Create text-to-speech audio file
+    with openai.audio.speech.with_streaming_response.create(
+        model="kokoro",
+        voice="af_bella",
+        input="the quick brown fox jumped over the lazy dogs",
+    ) as response:
+        response.stream_to_file(speech_file_path)


 def stream_to_speakers() -> None:
@@ -31,9 +33,12 @@ def stream_to_speakers() -> None:

     with openai.audio.speech.with_streaming_response.create(
         model="kokoro",
-        voice=VOICE,
-        response_format="mp3",  # similar to WAV, but without a header chunk at the start.
-        input="""My dear sir, that is just where you are wrong. That is just where the whole world has gone wrong. We are always getting away from the present moment. Our mental existences, which are immaterial and have no dimensions, are passing along the Time-Dimension""",
+        voice="af_bella",
+        response_format="pcm",  # similar to WAV, but without a header chunk at the start.
+        input="""I see skies of blue and clouds of white
+The bright blessed days, the dark sacred nights
+And I think to myself
+What a wonderful world""",
     ) as response:
         print(f"Time to first byte: {int((time.time() - start_time) * 1000)}ms")
         for chunk in response.iter_bytes(chunk_size=1024):
@@ -44,5 +49,3 @@ def stream_to_speakers() -> None:

 if __name__ == "__main__":
     main()
-
-# %%
@@ -74,7 +74,7 @@ def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"):
             all_audio_data.extend(chunk)

             # Log progress every 10 chunks
-            if chunk_count % 10 == 0:
+            if chunk_count % 100 == 0:
                 elapsed = time.time() - start_time
                 print(
                     f"Progress: {chunk_count} chunks, {total_bytes/1024:.1f}KB received, {elapsed:.1f}s elapsed"