WIP: v1_0_0 migration

This commit is contained in:
remsky 2025-01-28 13:52:57 -07:00
parent 1345b6c81a
commit 9867fc398f
16 changed files with 422 additions and 269 deletions

View file

@ -53,10 +53,12 @@ The service can be accessed through either the API endpoints or the Gradio web i
git clone https://github.com/remsky/Kokoro-FastAPI.git git clone https://github.com/remsky/Kokoro-FastAPI.git
cd Kokoro-FastAPI cd Kokoro-FastAPI
cd docker/gpu # OR cd docker/gpu # OR
# cd docker/cpu # Run this or the above # cd docker/cpu # Run this or the above
docker compose up --build docker compose up --build
# if you are missing any models, run the .py or .sh scrips in the respective folders # if you are missing any models, run:
# python ../scripts/download_model.py --type pth # for GPU
# python ../scripts/download_model.py --type onnx # for CPU
``` ```
Once started: Once started:

View file

@ -335,4 +335,5 @@ async def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
if _manager_instance is None: if _manager_instance is None:
_manager_instance = ModelManager(config) _manager_instance = ModelManager(config)
await _manager_instance.initialize() await _manager_instance.initialize()
return _manager_instance return _manager_instance

View file

@ -5,6 +5,7 @@ FastAPI OpenAI Compatible API
import os import os
import sys import sys
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from pathlib import Path
import torch import torch
import uvicorn import uvicorn
@ -57,16 +58,30 @@ async def lifespan(app: FastAPI):
# Initialize model with warmup and get status # Initialize model with warmup and get status
device, model, voicepack_count = await model_manager.initialize_with_warmup(voice_manager) device, model, voicepack_count = await model_manager.initialize_with_warmup(voice_manager)
except FileNotFoundError:
logger.error("""
Model files not found! You need to either:
1. Download models using the scripts:
GPU: python docker/scripts/download_model.py --type pth
CPU: python docker/scripts/download_model.py --type onnx
2. Set environment variables in docker-compose:
GPU: DOWNLOAD_PTH=true
CPU: DOWNLOAD_ONNX=true
""")
raise
except Exception as e: except Exception as e:
logger.error(f"Failed to initialize model: {e}") logger.error(f"Failed to initialize model: {e}")
raise raise
boundary = "" * 2*12 boundary = "" * 2*12
startup_msg = f""" startup_msg = f"""
{boundary} {boundary}

View file

@ -1,3 +1,5 @@
"""OpenAI-compatible router for text-to-speech"""
import json import json
import os import os
from typing import AsyncGenerator, Dict, List, Union from typing import AsyncGenerator, Dict, List, Union
@ -217,9 +219,9 @@ async def create_speech(
stitch_long_output=True stitch_long_output=True
) )
# Convert to requested format # Convert to requested format - removed stream parameter
content = await AudioService.convert_audio( content = await AudioService.convert_audio(
audio, 24000, request.response_format, is_first_chunk=True, stream=False audio, 24000, request.response_format, is_first_chunk=True
) )
return Response( return Response(

View file

@ -16,7 +16,6 @@ class AudioNormalizer:
"""Handles audio normalization state for a single stream""" """Handles audio normalization state for a single stream"""
def __init__(self): def __init__(self):
self.int16_max = np.iinfo(np.int16).max
self.chunk_trim_ms = settings.gap_trim_ms self.chunk_trim_ms = settings.gap_trim_ms
self.sample_rate = 24000 # Sample rate of the audio self.sample_rate = 24000 # Sample rate of the audio
self.samples_to_trim = int(self.chunk_trim_ms * self.sample_rate / 1000) self.samples_to_trim = int(self.chunk_trim_ms * self.sample_rate / 1000)
@ -30,20 +29,23 @@ class AudioNormalizer:
Returns: Returns:
Normalized and trimmed audio data Normalized and trimmed audio data
""" """
# Convert to float32 for processing if len(audio_data) == 0:
audio_float = audio_data.astype(np.float32) raise ValueError("Empty audio data")
# Trim start and end if enough samples # Trim start and end if enough samples
if len(audio_float) > (2 * self.samples_to_trim): if len(audio_data) > (2 * self.samples_to_trim):
audio_float = audio_float[self.samples_to_trim:-self.samples_to_trim] audio_data = audio_data[self.samples_to_trim:-self.samples_to_trim]
# Scale to int16 range # Scale directly to int16 range with clipping
return (audio_float * 32767).astype(np.int16) return np.clip(audio_data * 32767, -32768, 32767).astype(np.int16)
class AudioService: class AudioService:
"""Service for audio format conversions with streaming support""" """Service for audio format conversions with streaming support"""
# Supported formats
SUPPORTED_FORMATS = {"wav", "mp3", "opus", "flac", "aac", "pcm", "ogg"}
# Default audio format settings balanced for speed and compression # Default audio format settings balanced for speed and compression
DEFAULT_SETTINGS = { DEFAULT_SETTINGS = {
"mp3": { "mp3": {
@ -86,6 +88,10 @@ class AudioService:
Bytes of the converted audio chunk Bytes of the converted audio chunk
""" """
try: try:
# Validate format
if output_format not in AudioService.SUPPORTED_FORMATS:
raise ValueError(f"Format {output_format} not supported")
# Always normalize audio to ensure proper amplitude scaling # Always normalize audio to ensure proper amplitude scaling
if normalizer is None: if normalizer is None:
normalizer = AudioNormalizer() normalizer = AudioNormalizer()

View file

@ -17,26 +17,41 @@ class StreamingAudioWriter:
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.channels = channels self.channels = channels
self.bytes_written = 0 self.bytes_written = 0
self.buffer = BytesIO()
# Format-specific setup # Format-specific setup
if self.format == "wav": if self.format == "wav":
self._write_wav_header() self._write_wav_header()
elif self.format == "ogg": elif self.format in ["ogg", "opus"]:
# For OGG/Opus, write to memory buffer
self.writer = sf.SoundFile( self.writer = sf.SoundFile(
file=BytesIO(), file=self.buffer,
mode='w', mode='w',
samplerate=sample_rate, samplerate=sample_rate,
channels=channels, channels=channels,
format='OGG', format='OGG',
subtype='VORBIS' subtype='VORBIS' if self.format == "ogg" else "OPUS"
) )
elif self.format == "mp3": elif self.format == "flac":
# For MP3, we'll use pydub's incremental writer # For FLAC, write to memory buffer
self.buffer = BytesIO() self.writer = sf.SoundFile(
file=self.buffer,
mode='w',
samplerate=sample_rate,
channels=channels,
format='FLAC'
)
elif self.format in ["mp3", "aac"]:
# For MP3/AAC, we'll use pydub's incremental writer
self.segments = [] # Store segments until we have enough data self.segments = [] # Store segments until we have enough data
self.total_duration = 0 # Track total duration in milliseconds self.total_duration = 0 # Track total duration in milliseconds
# Initialize an empty AudioSegment as our encoder # Initialize an empty AudioSegment as our encoder
self.encoder = AudioSegment.silent(duration=0, frame_rate=self.sample_rate) self.encoder = AudioSegment.silent(duration=0, frame_rate=self.sample_rate)
elif self.format == "pcm":
# PCM doesn't need initialization, we'll write raw bytes
pass
else:
raise ValueError(f"Unsupported format: {format}")
def _write_wav_header(self) -> bytes: def _write_wav_header(self) -> bytes:
"""Write WAV header with correct streaming format""" """Write WAV header with correct streaming format"""
@ -63,42 +78,48 @@ class StreamingAudioWriter:
audio_data: Audio data to write, or None if finalizing audio_data: Audio data to write, or None if finalizing
finalize: Whether this is the final write to close the stream finalize: Whether this is the final write to close the stream
""" """
buffer = BytesIO() output_buffer = BytesIO()
if finalize: if finalize:
if self.format == "wav": if self.format == "wav":
# Write final WAV header with correct sizes # Write final WAV header with correct sizes
buffer.write(b'RIFF') output_buffer.write(b'RIFF')
buffer.write(struct.pack('<L', self.bytes_written + 36)) output_buffer.write(struct.pack('<L', self.bytes_written + 36))
buffer.write(b'WAVE') output_buffer.write(b'WAVE')
buffer.write(b'fmt ') output_buffer.write(b'fmt ')
buffer.write(struct.pack('<L', 16)) output_buffer.write(struct.pack('<L', 16))
buffer.write(struct.pack('<H', 1)) output_buffer.write(struct.pack('<H', 1))
buffer.write(struct.pack('<H', self.channels)) output_buffer.write(struct.pack('<H', self.channels))
buffer.write(struct.pack('<L', self.sample_rate)) output_buffer.write(struct.pack('<L', self.sample_rate))
buffer.write(struct.pack('<L', self.sample_rate * self.channels * 2)) output_buffer.write(struct.pack('<L', self.sample_rate * self.channels * 2))
buffer.write(struct.pack('<H', self.channels * 2)) output_buffer.write(struct.pack('<H', self.channels * 2))
buffer.write(struct.pack('<H', 16)) output_buffer.write(struct.pack('<H', 16))
buffer.write(b'data') output_buffer.write(b'data')
buffer.write(struct.pack('<L', self.bytes_written)) output_buffer.write(struct.pack('<L', self.bytes_written))
elif self.format == "ogg": elif self.format in ["ogg", "opus", "flac"]:
self.writer.close() self.writer.close()
elif self.format == "mp3": return self.buffer.getvalue()
elif self.format in ["mp3", "aac"]:
# Final export of any remaining audio # Final export of any remaining audio
if hasattr(self, 'encoder') and len(self.encoder) > 0: if hasattr(self, 'encoder') and len(self.encoder) > 0:
# Export with duration metadata # Export with duration metadata
format_args = {
"mp3": {"format": "mp3", "codec": "libmp3lame"},
"aac": {"format": "adts", "codec": "aac"}
}[self.format]
self.encoder.export( self.encoder.export(
buffer, output_buffer,
format="mp3", **format_args,
bitrate="192k", bitrate="192k",
parameters=[ parameters=[
"-q:a", "2", "-q:a", "2",
"-write_xing", "1", # Force XING/LAME header "-write_xing", "1" if self.format == "mp3" else "0", # XING header for MP3 only
"-metadata", f"duration={self.total_duration/1000}" # Duration in seconds "-metadata", f"duration={self.total_duration/1000}" # Duration in seconds
] ]
) )
self.encoder = None self.encoder = None
return buffer.getvalue() return output_buffer.getvalue()
if audio_data is None or len(audio_data) == 0: if audio_data is None or len(audio_data) == 0:
return b'' return b''
@ -106,22 +127,22 @@ class StreamingAudioWriter:
if self.format == "wav": if self.format == "wav":
# For WAV, write raw PCM after the first chunk # For WAV, write raw PCM after the first chunk
if self.bytes_written == 0: if self.bytes_written == 0:
buffer.write(self._write_wav_header()) output_buffer.write(self._write_wav_header())
buffer.write(audio_data.tobytes()) output_buffer.write(audio_data.tobytes())
self.bytes_written += len(audio_data.tobytes()) self.bytes_written += len(audio_data.tobytes())
elif self.format == "ogg": elif self.format in ["ogg", "opus", "flac"]:
# OGG/Vorbis handles streaming naturally # Write to soundfile buffer
self.writer.write(audio_data) self.writer.write(audio_data)
self.writer.flush() self.writer.flush()
buffer = self.writer.file # Get current buffer contents
buffer.seek(0, 2) # Seek to end data = self.buffer.getvalue()
chunk = buffer.getvalue() # Clear buffer for next chunk
buffer.seek(0) self.buffer.seek(0)
buffer.truncate() self.buffer.truncate()
return chunk return data
elif self.format == "mp3": elif self.format in ["mp3", "aac"]:
# Convert chunk to AudioSegment and encode # Convert chunk to AudioSegment and encode
segment = AudioSegment( segment = AudioSegment(
audio_data.tobytes(), audio_data.tobytes(),
@ -137,21 +158,30 @@ class StreamingAudioWriter:
self.encoder = self.encoder + segment self.encoder = self.encoder + segment
# Export current state to buffer # Export current state to buffer
self.encoder.export(buffer, format="mp3", bitrate="192k", parameters=[ format_args = {
"mp3": {"format": "mp3", "codec": "libmp3lame"},
"aac": {"format": "adts", "codec": "aac"}
}[self.format]
self.encoder.export(output_buffer, **format_args, bitrate="192k", parameters=[
"-q:a", "2", "-q:a", "2",
"-write_xing", "1", # Force XING/LAME header "-write_xing", "1" if self.format == "mp3" else "0", # XING header for MP3 only
"-metadata", f"duration={self.total_duration/1000}" # Duration in seconds "-metadata", f"duration={self.total_duration/1000}" # Duration in seconds
]) ])
# Get the encoded data # Get the encoded data
encoded_data = buffer.getvalue() encoded_data = output_buffer.getvalue()
# Reset encoder to prevent memory growth # Reset encoder to prevent memory growth
self.encoder = AudioSegment.silent(duration=0, frame_rate=self.sample_rate) self.encoder = AudioSegment.silent(duration=0, frame_rate=self.sample_rate)
return encoded_data return encoded_data
return buffer.getvalue() elif self.format == "pcm":
# For PCM, just write raw bytes
return audio_data.tobytes()
return output_buffer.getvalue()
def close(self) -> Optional[bytes]: def close(self) -> Optional[bytes]:
"""Finish the audio file and return any remaining data""" """Finish the audio file and return any remaining data"""
@ -161,16 +191,31 @@ class StreamingAudioWriter:
buffer.write(b'RIFF') buffer.write(b'RIFF')
buffer.write(struct.pack('<L', self.bytes_written + 36)) # File size buffer.write(struct.pack('<L', self.bytes_written + 36)) # File size
buffer.write(b'WAVE') buffer.write(b'WAVE')
# ... rest of header ... buffer.write(b'fmt ')
buffer.write(struct.pack('<L', self.bytes_written)) # Data size buffer.write(struct.pack('<L', 16))
buffer.write(struct.pack('<H', 1))
buffer.write(struct.pack('<H', self.channels))
buffer.write(struct.pack('<L', self.sample_rate))
buffer.write(struct.pack('<L', self.sample_rate * self.channels * 2))
buffer.write(struct.pack('<H', self.channels * 2))
buffer.write(struct.pack('<H', 16))
buffer.write(b'data')
buffer.write(struct.pack('<L', self.bytes_written))
return buffer.getvalue() return buffer.getvalue()
elif self.format == "ogg": elif self.format in ["ogg", "opus", "flac"]:
self.writer.close() self.writer.close()
return None return self.buffer.getvalue()
elif self.format == "mp3": elif self.format in ["mp3", "aac"]:
# Flush any remaining MP3 frames # Flush any remaining audio
buffer = BytesIO() buffer = BytesIO()
self.encoder.export(buffer, format="mp3") if hasattr(self, 'encoder') and len(self.encoder) > 0:
return buffer.getvalue() format_args = {
"mp3": {"format": "mp3", "codec": "libmp3lame"},
"aac": {"format": "adts", "codec": "aac"}
}[self.format]
self.encoder.export(buffer, **format_args)
return buffer.getvalue()
return None

View file

@ -26,106 +26,128 @@ def sample_audio():
return np.sin(2 * np.pi * frequency * t).astype(np.float32), sample_rate return np.sin(2 * np.pi * frequency * t).astype(np.float32), sample_rate
def test_convert_to_wav(sample_audio): @pytest.mark.asyncio
async def test_convert_to_wav(sample_audio):
"""Test converting to WAV format""" """Test converting to WAV format"""
audio_data, sample_rate = sample_audio audio_data, sample_rate = sample_audio
result = AudioService.convert_audio(audio_data, sample_rate, "wav") result = await AudioService.convert_audio(audio_data, sample_rate, "wav")
assert isinstance(result, bytes) assert isinstance(result, bytes)
assert len(result) > 0 assert len(result) > 0
# Check WAV header
assert result.startswith(b'RIFF')
assert b'WAVE' in result[:12]
def test_convert_to_mp3(sample_audio): @pytest.mark.asyncio
async def test_convert_to_mp3(sample_audio):
"""Test converting to MP3 format""" """Test converting to MP3 format"""
audio_data, sample_rate = sample_audio audio_data, sample_rate = sample_audio
result = AudioService.convert_audio(audio_data, sample_rate, "mp3") result = await AudioService.convert_audio(audio_data, sample_rate, "mp3")
assert isinstance(result, bytes) assert isinstance(result, bytes)
assert len(result) > 0 assert len(result) > 0
# Check MP3 header (ID3 or MPEG frame sync)
assert result.startswith(b'ID3') or result.startswith(b'\xff\xfb')
def test_convert_to_opus(sample_audio): @pytest.mark.asyncio
async def test_convert_to_opus(sample_audio):
"""Test converting to Opus format""" """Test converting to Opus format"""
audio_data, sample_rate = sample_audio audio_data, sample_rate = sample_audio
result = AudioService.convert_audio(audio_data, sample_rate, "opus") result = await AudioService.convert_audio(audio_data, sample_rate, "opus")
assert isinstance(result, bytes) assert isinstance(result, bytes)
assert len(result) > 0 assert len(result) > 0
# Check OGG header
assert result.startswith(b'OggS')
def test_convert_to_flac(sample_audio): @pytest.mark.asyncio
async def test_convert_to_flac(sample_audio):
"""Test converting to FLAC format""" """Test converting to FLAC format"""
audio_data, sample_rate = sample_audio audio_data, sample_rate = sample_audio
result = AudioService.convert_audio(audio_data, sample_rate, "flac") result = await AudioService.convert_audio(audio_data, sample_rate, "flac")
assert isinstance(result, bytes) assert isinstance(result, bytes)
assert len(result) > 0 assert len(result) > 0
# Check FLAC header
assert result.startswith(b'fLaC')
def test_convert_to_aac(sample_audio): @pytest.mark.asyncio
async def test_convert_to_aac(sample_audio):
"""Test converting to AAC format""" """Test converting to AAC format"""
audio_data, sample_rate = sample_audio audio_data, sample_rate = sample_audio
result = AudioService.convert_audio(audio_data, sample_rate, "aac") result = await AudioService.convert_audio(audio_data, sample_rate, "aac")
assert isinstance(result, bytes) assert isinstance(result, bytes)
assert len(result) > 0 assert len(result) > 0
# AAC files typically start with an ADTS header # Check ADTS header (AAC)
assert result.startswith(b'\xff\xf1') or result.startswith(b'\xff\xf9') assert result.startswith(b'\xff\xf0') or result.startswith(b'\xff\xf1')
def test_convert_to_pcm(sample_audio): @pytest.mark.asyncio
async def test_convert_to_pcm(sample_audio):
"""Test converting to PCM format""" """Test converting to PCM format"""
audio_data, sample_rate = sample_audio audio_data, sample_rate = sample_audio
result = AudioService.convert_audio(audio_data, sample_rate, "pcm") result = await AudioService.convert_audio(audio_data, sample_rate, "pcm")
assert isinstance(result, bytes) assert isinstance(result, bytes)
assert len(result) > 0 assert len(result) > 0
# PCM is raw bytes, so no header to check
def test_convert_to_invalid_format_raises_error(sample_audio): @pytest.mark.asyncio
async def test_convert_to_invalid_format_raises_error(sample_audio):
"""Test that converting to an invalid format raises an error""" """Test that converting to an invalid format raises an error"""
audio_data, sample_rate = sample_audio audio_data, sample_rate = sample_audio
with pytest.raises(ValueError, match="Format invalid not supported"): with pytest.raises(ValueError, match="Format invalid not supported"):
AudioService.convert_audio(audio_data, sample_rate, "invalid") await AudioService.convert_audio(audio_data, sample_rate, "invalid")
def test_normalization_wav(sample_audio): @pytest.mark.asyncio
async def test_normalization_wav(sample_audio):
"""Test that WAV output is properly normalized to int16 range""" """Test that WAV output is properly normalized to int16 range"""
audio_data, sample_rate = sample_audio audio_data, sample_rate = sample_audio
# Create audio data outside int16 range # Create audio data outside int16 range
large_audio = audio_data * 1e5 large_audio = audio_data * 1e5
result = AudioService.convert_audio(large_audio, sample_rate, "wav") result = await AudioService.convert_audio(large_audio, sample_rate, "wav")
assert isinstance(result, bytes) assert isinstance(result, bytes)
assert len(result) > 0 assert len(result) > 0
def test_normalization_pcm(sample_audio): @pytest.mark.asyncio
async def test_normalization_pcm(sample_audio):
"""Test that PCM output is properly normalized to int16 range""" """Test that PCM output is properly normalized to int16 range"""
audio_data, sample_rate = sample_audio audio_data, sample_rate = sample_audio
# Create audio data outside int16 range # Create audio data outside int16 range
large_audio = audio_data * 1e5 large_audio = audio_data * 1e5
result = AudioService.convert_audio(large_audio, sample_rate, "pcm") result = await AudioService.convert_audio(large_audio, sample_rate, "pcm")
assert isinstance(result, bytes) assert isinstance(result, bytes)
assert len(result) > 0 assert len(result) > 0
def test_invalid_audio_data(): @pytest.mark.asyncio
async def test_invalid_audio_data():
"""Test handling of invalid audio data""" """Test handling of invalid audio data"""
invalid_audio = np.array([]) # Empty array invalid_audio = np.array([]) # Empty array
sample_rate = 24000 sample_rate = 24000
with pytest.raises(ValueError): with pytest.raises(ValueError):
AudioService.convert_audio(invalid_audio, sample_rate, "wav") await AudioService.convert_audio(invalid_audio, sample_rate, "wav")
def test_different_sample_rates(sample_audio): @pytest.mark.asyncio
async def test_different_sample_rates(sample_audio):
"""Test converting audio with different sample rates""" """Test converting audio with different sample rates"""
audio_data, _ = sample_audio audio_data, _ = sample_audio
sample_rates = [8000, 16000, 44100, 48000] sample_rates = [8000, 16000, 44100, 48000]
for rate in sample_rates: for rate in sample_rates:
result = AudioService.convert_audio(audio_data, rate, "wav") result = await AudioService.convert_audio(audio_data, rate, "wav")
assert isinstance(result, bytes) assert isinstance(result, bytes)
assert len(result) > 0 assert len(result) > 0
def test_buffer_position_after_conversion(sample_audio): @pytest.mark.asyncio
async def test_buffer_position_after_conversion(sample_audio):
"""Test that buffer position is reset after writing""" """Test that buffer position is reset after writing"""
audio_data, sample_rate = sample_audio audio_data, sample_rate = sample_audio
result = AudioService.convert_audio(audio_data, sample_rate, "wav") result = await AudioService.convert_audio(audio_data, sample_rate, "wav")
# Convert again to ensure buffer was properly reset # Convert again to ensure buffer was properly reset
result2 = AudioService.convert_audio(audio_data, sample_rate, "wav") result2 = await AudioService.convert_audio(audio_data, sample_rate, "wav")
assert len(result) == len(result2) assert len(result) == len(result2)

View file

@ -59,7 +59,7 @@ async def test_empty_text(tts_service, test_voice):
voice=test_voice, voice=test_voice,
speed=1.0 speed=1.0
) )
assert "Text is empty after preprocessing" in str(exc_info.value) assert "No audio chunks were generated successfully" in str(exc_info.value)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invalid_voice(tts_service): async def test_invalid_voice(tts_service):
@ -126,15 +126,17 @@ async def test_combine_voices(tts_service):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_chunked_text_processing(tts_service, test_voice, mock_audio_output): async def test_chunked_text_processing(tts_service, test_voice, mock_audio_output):
"""Test processing chunked text""" """Test processing chunked text"""
long_text = "First sentence. Second sentence. Third sentence." # Create text that will force chunking by exceeding max tokens
long_text = "This is a test sentence." * 100 # Should be way over 500 tokens
# Don't mock smart_split - let it actually split the text
audio, processing_time = await tts_service.generate_audio( audio, processing_time = await tts_service.generate_audio(
text=long_text, text=long_text,
voice=test_voice, voice=test_voice,
speed=1.0, speed=1.0
stitch_long_output=True
) )
# Should be called multiple times due to chunking
assert tts_service.model_manager.generate.call_count > 1 assert tts_service.model_manager.generate.call_count > 1
assert isinstance(audio, np.ndarray) assert isinstance(audio, np.ndarray)
assert processing_time > 0 assert processing_time > 0

View file

@ -34,6 +34,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
# Copy project files including models # Copy project files including models
COPY --chown=appuser:appuser api ./api COPY --chown=appuser:appuser api ./api
COPY --chown=appuser:appuser web ./web COPY --chown=appuser:appuser web ./web
COPY --chown=appuser:appuser docker/scripts/download_model.* ./
# Install project # Install project
RUN --mount=type=cache,target=/root/.cache/uv \ RUN --mount=type=cache,target=/root/.cache/uv \
@ -44,8 +45,19 @@ ENV PYTHONUNBUFFERED=1
ENV PYTHONPATH=/app ENV PYTHONPATH=/app
ENV PATH="/app/.venv/bin:$PATH" ENV PATH="/app/.venv/bin:$PATH"
ENV UV_LINK_MODE=copy ENV UV_LINK_MODE=copy
ENV USE_GPU=false ENV USE_GPU=false
ENV USE_ONNX=false ENV USE_ONNX=true
ENV DOWNLOAD_ONNX=true
ENV DOWNLOAD_PTH=false
# Download models based on environment variables
RUN if [ "$DOWNLOAD_ONNX" = "true" ]; then \
python download_model.py --type onnx; \
fi && \
if [ "$DOWNLOAD_PTH" = "true" ]; then \
python download_model.py --type pth; \
fi
# Run FastAPI server # Run FastAPI server
CMD ["uv", "run", "python", "-m", "uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"] CMD ["uv", "run", "python", "-m", "uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]

View file

@ -1,53 +0,0 @@
#!/usr/bin/env python3
import os
import sys
import requests
from pathlib import Path
from typing import List
def download_file(url: str, output_dir: Path) -> None:
"""Download a file from URL to the specified directory."""
filename = os.path.basename(url)
output_path = output_dir / filename
print(f"Downloading {filename}...")
response = requests.get(url, stream=True)
response.raise_for_status()
with open(output_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
def find_project_root() -> Path:
"""Find project root by looking for api directory."""
max_steps = 5
current = Path(__file__).resolve()
for _ in range(max_steps):
if (current / 'api').is_dir():
return current
current = current.parent
raise RuntimeError("Could not find project root (no api directory found)")
def main(custom_models: List[str] = None):
# Always use top-level models directory relative to project root
project_root = find_project_root()
models_dir = project_root / 'api' / 'src' / 'models'
models_dir.mkdir(exist_ok=True, parents=True)
# Default ONNX model if no arguments provided
default_models = [
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.onnx",
# "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19_fp16.onnx"
]
# Use provided models or default
models_to_download = custom_models if custom_models else default_models
for model_url in models_to_download:
try:
download_file(model_url, models_dir)
except Exception as e:
print(f"Error downloading {model_url}: {e}")
if __name__ == "__main__":
main(sys.argv[1:] if len(sys.argv) > 1 else None)

View file

@ -1,32 +0,0 @@
#!/bin/bash
# Ensure models directory exists
mkdir -p api/src/models
# Function to download a file
download_file() {
local url="$1"
local filename=$(basename "$url")
echo "Downloading $filename..."
curl -L "$url" -o "api/src/models/$filename"
}
# Default ONNX model if no arguments provided
DEFAULT_MODELS=(
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.onnx"
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19_fp16.onnx"
)
# Use provided models or default
if [ $# -gt 0 ]; then
MODELS=("$@")
else
MODELS=("${DEFAULT_MODELS[@]}")
fi
# Download all models
for model in "${MODELS[@]}"; do
download_file "$model"
done
echo "ONNX model download complete!"

View file

@ -38,6 +38,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
# Copy project files including models # Copy project files including models
COPY --chown=appuser:appuser api ./api COPY --chown=appuser:appuser api ./api
COPY --chown=appuser:appuser web ./web COPY --chown=appuser:appuser web ./web
COPY --chown=appuser:appuser docker/scripts/download_model.* ./
# Install project with GPU extras # Install project with GPU extras
RUN --mount=type=cache,target=/root/.cache/uv \ RUN --mount=type=cache,target=/root/.cache/uv \
@ -48,7 +49,19 @@ ENV PYTHONUNBUFFERED=1
ENV PYTHONPATH=/app ENV PYTHONPATH=/app
ENV PATH="/app/.venv/bin:$PATH" ENV PATH="/app/.venv/bin:$PATH"
ENV UV_LINK_MODE=copy ENV UV_LINK_MODE=copy
ENV USE_GPU=true ENV USE_GPU=true
ENV USE_ONNX=false
ENV DOWNLOAD_PTH=true
ENV DOWNLOAD_ONNX=false
# Download models based on environment variables
RUN if [ "$DOWNLOAD_PTH" = "true" ]; then \
python download_model.py --type pth; \
fi && \
if [ "$DOWNLOAD_ONNX" = "true" ]; then \
python download_model.py --type onnx; \
fi
# Run FastAPI server # Run FastAPI server
CMD ["uv", "run", "python", "-m", "uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"] CMD ["uv", "run", "python", "-m", "uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]

View file

@ -1,57 +0,0 @@
#!/usr/bin/env python3
import os
import sys
import requests
from pathlib import Path
from typing import List
def download_file(url: str, output_dir: Path) -> None:
"""Download a file from URL to the specified directory."""
filename = os.path.basename(url)
if not filename.endswith('.pth'):
print(f"Warning: {filename} is not a .pth file")
return
output_path = output_dir / filename
print(f"Downloading {filename}...")
response = requests.get(url, stream=True)
response.raise_for_status()
with open(output_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
def find_project_root() -> Path:
"""Find project root by looking for api directory."""
max_steps = 5
current = Path(__file__).resolve()
for _ in range(max_steps):
if (current / 'api').is_dir():
return current
current = current.parent
raise RuntimeError("Could not find project root (no api directory found)")
def main(custom_models: List[str] = None):
# Find project root and ensure models directory exists
project_root = find_project_root()
models_dir = project_root / 'api' / 'src' / 'models'
print(f"Downloading models to {models_dir}")
models_dir.mkdir(exist_ok=True)
# Default PTH model if no arguments provided
default_models = [
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth"
]
# Use provided models or default
models_to_download = custom_models if custom_models else default_models
for model_url in models_to_download:
try:
download_file(model_url, models_dir)
except Exception as e:
print(f"Error downloading {model_url}: {e}")
if __name__ == "__main__":
main(sys.argv[1:] if len(sys.argv) > 1 else None)

View file

@ -1,31 +0,0 @@
#!/bin/bash
# Ensure models directory exists
mkdir -p api/src/models
# Function to download a file
download_file() {
local url="$1"
local filename=$(basename "$url")
echo "Downloading $filename..."
curl -L "$url" -o "api/src/models/$filename"
}
# Default PTH model if no arguments provided
DEFAULT_MODELS=(
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth"
)
# Use provided models or default
if [ $# -gt 0 ]; then
MODELS=("$@")
else
MODELS=("${DEFAULT_MODELS[@]}")
fi
# Download all models
for model in "${MODELS[@]}"; do
download_file "$model"
done
echo "PyTorch model download complete!"

View file

@ -0,0 +1,97 @@
#!/usr/bin/env python3
import os
import sys
import argparse
import requests
from pathlib import Path
from typing import List
def download_file(url: str, output_dir: Path, model_type: str) -> bool:
"""Download a file from URL to the specified directory.
Returns:
bool: True if download succeeded, False otherwise
"""
filename = os.path.basename(url)
if not filename.endswith(f'.{model_type}'):
print(f"Warning: {filename} is not a .{model_type} file", file=sys.stderr)
return False
output_path = output_dir / filename
print(f"Downloading {filename}...")
try:
response = requests.get(url, stream=True)
response.raise_for_status()
with open(output_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
print(f"Successfully downloaded {filename}")
return True
except Exception as e:
print(f"Error downloading {filename}: {e}", file=sys.stderr)
return False
def find_project_root() -> Path:
"""Find project root by looking for api directory."""
max_steps = 5
current = Path(__file__).resolve()
for _ in range(max_steps):
if (current / 'api').is_dir():
return current
current = current.parent
raise RuntimeError("Could not find project root (no api directory found)")
def main() -> int:
"""Download models to the project.
Returns:
int: Exit code (0 for success, 1 for failure)
"""
parser = argparse.ArgumentParser(description='Download model files')
parser.add_argument('--type', choices=['pth', 'onnx'], required=True,
help='Model type to download (pth or onnx)')
parser.add_argument('urls', nargs='*', help='Optional model URLs to download')
args = parser.parse_args()
try:
# Find project root and ensure models directory exists
project_root = find_project_root()
models_dir = project_root / 'api' / 'src' / 'models'
print(f"Downloading models to {models_dir}")
models_dir.mkdir(exist_ok=True)
# Default models if no arguments provided
default_models = {
'pth': [
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth"
],
'onnx': [
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.onnx",
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19_fp16.onnx"
]
}
# Use provided models or default
models_to_download = args.urls if args.urls else default_models[args.type]
# Download all models
success = True
for model_url in models_to_download:
if not download_file(model_url, models_dir, args.type):
success = False
if success:
print(f"{args.type.upper()} model download complete!")
return 0
else:
print("Some downloads failed", file=sys.stderr)
return 1
except Exception as e:
print(f"Error: {e}", file=sys.stderr)
return 1
if __name__ == "__main__":
sys.exit(main())

View file

@ -0,0 +1,109 @@
#!/bin/bash
# Find project root by looking for api directory
find_project_root() {
local current_dir="$PWD"
local max_steps=5
local steps=0
while [ $steps -lt $max_steps ]; do
if [ -d "$current_dir/api" ]; then
echo "$current_dir"
return 0
fi
current_dir="$(dirname "$current_dir")"
((steps++))
done
echo "Error: Could not find project root (no api directory found)" >&2
exit 1
}
# Function to download a file
download_file() {
local url="$1"
local output_dir="$2"
local model_type="$3"
local filename=$(basename "$url")
# Validate file extension
if [[ ! "$filename" =~ \.$model_type$ ]]; then
echo "Warning: $filename is not a .$model_type file" >&2
return 1
}
echo "Downloading $filename..."
if curl -L "$url" -o "$output_dir/$filename"; then
echo "Successfully downloaded $filename"
return 0
else
echo "Error downloading $filename" >&2
return 1
fi
}
# Parse arguments
MODEL_TYPE=""
while [[ $# -gt 0 ]]; do
case $1 in
--type)
MODEL_TYPE="$2"
shift 2
;;
*)
# If no flag specified, treat remaining args as model URLs
break
;;
esac
done
# Validate model type
if [ "$MODEL_TYPE" != "pth" ] && [ "$MODEL_TYPE" != "onnx" ]; then
echo "Error: Must specify model type with --type (pth or onnx)" >&2
exit 1
fi
# Find project root and ensure models directory exists
PROJECT_ROOT=$(find_project_root)
if [ $? -ne 0 ]; then
exit 1
fi
MODELS_DIR="$PROJECT_ROOT/api/src/models"
echo "Downloading models to $MODELS_DIR"
mkdir -p "$MODELS_DIR"
# Default models if no arguments provided
if [ "$MODEL_TYPE" = "pth" ]; then
DEFAULT_MODELS=(
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth"
)
else
DEFAULT_MODELS=(
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.onnx"
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19_fp16.onnx"
)
fi
# Use provided models or default
if [ $# -gt 0 ]; then
MODELS=("$@")
else
MODELS=("${DEFAULT_MODELS[@]}")
fi
# Download all models
success=true
for model in "${MODELS[@]}"; do
if ! download_file "$model" "$MODELS_DIR" "$MODEL_TYPE"; then
success=false
fi
done
if [ "$success" = true ]; then
echo "${MODEL_TYPE^^} model download complete!"
exit 0
else
echo "Some downloads failed" >&2
exit 1
fi