diff --git a/README.md b/README.md
index a8557af..44d9600 100644
--- a/README.md
+++ b/README.md
@@ -53,10 +53,12 @@ The service can be accessed through either the API endpoints or the Gradio web i
     git clone https://github.com/remsky/Kokoro-FastAPI.git
     cd Kokoro-FastAPI
 
-    cd docker/gpu # OR
+    cd docker/gpu # OR
     # cd docker/cpu # Run this or the above
-    docker compose up --build
-    # if you are missing any models, run the .py or .sh scrips in the respective folders
+    docker compose up --build
+    # if you are missing any models, run:
+    # python ../scripts/download_model.py --type pth   # for GPU
+    # python ../scripts/download_model.py --type onnx  # for CPU
     ```
 
 Once started:
diff --git a/api/src/inference/model_manager.py b/api/src/inference/model_manager.py
index 5920843..5fe53c9 100644
--- a/api/src/inference/model_manager.py
+++ b/api/src/inference/model_manager.py
@@ -335,4 +335,5 @@ async def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
     if _manager_instance is None:
         _manager_instance = ModelManager(config)
         await _manager_instance.initialize()
-    return _manager_instance
\ No newline at end of file
+    return _manager_instance
+
diff --git a/api/src/main.py b/api/src/main.py
index d4f00e8..759bbd7 100644
--- a/api/src/main.py
+++ b/api/src/main.py
@@ -5,6 +5,7 @@ FastAPI OpenAI Compatible API
 import os
 import sys
 from contextlib import asynccontextmanager
+from pathlib import Path
 
 import torch
 import uvicorn
@@ -57,16 +58,30 @@ async def lifespan(app: FastAPI):
         # Initialize model with warmup and get status
         device, model, voicepack_count = await model_manager.initialize_with_warmup(voice_manager)
+    except FileNotFoundError:
+        logger.error("""
+Model files not found! You need to either:
+
+1. Download models using the scripts:
+   GPU: python docker/scripts/download_model.py --type pth
+   CPU: python docker/scripts/download_model.py --type onnx
+
+2. Set environment variables in docker-compose:
+   GPU: DOWNLOAD_PTH=true
+   CPU: DOWNLOAD_ONNX=true
+""")
+        raise
     except Exception as e:
         logger.error(f"Failed to initialize model: {e}")
         raise
 
+    boundary = "░" * 24
     startup_msg = f"""
 {boundary}
 
     ╔═╗┌─┐┌─┐┌┬┐
-    ╠╣ ├─┤└─┐ │ 
+    ╠╣ ├─┤└─┐ │
     ╚  ┴ ┴└─┘ ┴
     ╦╔═┌─┐┬┌─┌─┐
     ╠╩╗│ │├┴┐│ │
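The new FileNotFoundError branch surfaces download instructions when initialize_with_warmup cannot find model weights. A minimal pre-flight check in the same spirit (a sketch only: the glob patterns and the assumption that models live under api/src/models mirror the download scripts later in this patch, not code in main.py):

    from pathlib import Path

    MODELS_DIR = Path("api/src/models")  # matches the download scripts' target directory

    def models_present(use_gpu: bool) -> bool:
        """Return True if at least one model file of the expected type exists."""
        pattern = "*.pth" if use_gpu else "*.onnx"
        return any(MODELS_DIR.glob(pattern))

    if not models_present(use_gpu=True):
        raise FileNotFoundError(
            "No .pth models found; run: python docker/scripts/download_model.py --type pth"
        )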
diff --git a/api/src/routers/openai_compatible.py b/api/src/routers/openai_compatible.py
index 5908a56..f4b8684 100644
--- a/api/src/routers/openai_compatible.py
+++ b/api/src/routers/openai_compatible.py
@@ -1,3 +1,5 @@
+"""OpenAI-compatible router for text-to-speech"""
+
 import json
 import os
 from typing import AsyncGenerator, Dict, List, Union
@@ -217,9 +219,9 @@ async def create_speech(
             stitch_long_output=True
         )
 
-        # Convert to requested format
+        # Convert to requested format; convert_audio no longer takes a stream parameter
         content = await AudioService.convert_audio(
-            audio, 24000, request.response_format, is_first_chunk=True, stream=False
+            audio, 24000, request.response_format, is_first_chunk=True
         )
 
         return Response(
diff --git a/api/src/services/audio.py b/api/src/services/audio.py
index bfe419b..2633594 100644
--- a/api/src/services/audio.py
+++ b/api/src/services/audio.py
@@ -16,7 +16,6 @@ class AudioNormalizer:
     """Handles audio normalization state for a single stream"""
 
     def __init__(self):
-        self.int16_max = np.iinfo(np.int16).max
         self.chunk_trim_ms = settings.gap_trim_ms
         self.sample_rate = 24000  # Sample rate of the audio
         self.samples_to_trim = int(self.chunk_trim_ms * self.sample_rate / 1000)
@@ -30,20 +29,23 @@ class AudioNormalizer:
         Returns:
             Normalized and trimmed audio data
         """
-        # Convert to float32 for processing
-        audio_float = audio_data.astype(np.float32)
-
+        if len(audio_data) == 0:
+            raise ValueError("Empty audio data")
+
         # Trim start and end if enough samples
-        if len(audio_float) > (2 * self.samples_to_trim):
-            audio_float = audio_float[self.samples_to_trim:-self.samples_to_trim]
+        if len(audio_data) > (2 * self.samples_to_trim):
+            audio_data = audio_data[self.samples_to_trim:-self.samples_to_trim]
 
-        # Scale to int16 range
-        return (audio_float * 32767).astype(np.int16)
+        # Scale directly to int16 range with clipping
+        return np.clip(audio_data * 32767, -32768, 32767).astype(np.int16)
 
 class AudioService:
     """Service for audio format conversions with streaming support"""
 
+    # Supported formats
+    SUPPORTED_FORMATS = {"wav", "mp3", "opus", "flac", "aac", "pcm", "ogg"}
+
     # Default audio format settings balanced for speed and compression
     DEFAULT_SETTINGS = {
         "mp3": {
@@ -86,6 +88,10 @@ class AudioService:
             Bytes of the converted audio chunk
         """
         try:
+            # Validate format
+            if output_format not in AudioService.SUPPORTED_FORMATS:
+                raise ValueError(f"Format {output_format} not supported")
+
             # Always normalize audio to ensure proper amplitude scaling
             if normalizer is None:
                 normalizer = AudioNormalizer()
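The rewritten normalize() clips before casting instead of letting the cast overflow. A standalone demonstration of why that matters (values chosen arbitrarily; int16 casts of out-of-range floats are platform-defined and typically wrap):

    import numpy as np

    audio = np.array([-2.0, -1.0, 0.0, 1.0, 2.0], dtype=np.float32)

    # Naive scaling: out-of-range samples wrap around to the opposite sign
    wrapped = (audio * 32767).astype(np.int16)  # e.g. [2, -32767, 0, 32767, -2]

    # Clipping first pins out-of-range samples to the int16 limits
    clipped = np.clip(audio * 32767, -32768, 32767).astype(np.int16)  # [-32768, -32767, 0, 32767, 32767]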
diff --git a/api/src/services/streaming_audio_writer.py b/api/src/services/streaming_audio_writer.py
index 23d63c0..1084c91 100644
--- a/api/src/services/streaming_audio_writer.py
+++ b/api/src/services/streaming_audio_writer.py
@@ -17,26 +17,41 @@ class StreamingAudioWriter:
         self.sample_rate = sample_rate
         self.channels = channels
         self.bytes_written = 0
+        self.buffer = BytesIO()
 
         # Format-specific setup
         if self.format == "wav":
             self._write_wav_header()
-        elif self.format == "ogg":
+        elif self.format in ["ogg", "opus"]:
+            # For OGG/Opus, write to memory buffer
             self.writer = sf.SoundFile(
-                file=BytesIO(),
+                file=self.buffer,
                 mode='w',
                 samplerate=sample_rate,
                 channels=channels,
                 format='OGG',
-                subtype='VORBIS'
+                subtype='VORBIS' if self.format == "ogg" else "OPUS"
             )
-        elif self.format == "mp3":
-            # For MP3, we'll use pydub's incremental writer
-            self.buffer = BytesIO()
+        elif self.format == "flac":
+            # For FLAC, write to memory buffer
+            self.writer = sf.SoundFile(
+                file=self.buffer,
+                mode='w',
+                samplerate=sample_rate,
+                channels=channels,
+                format='FLAC'
+            )
+        elif self.format in ["mp3", "aac"]:
+            # For MP3/AAC, we'll use pydub's incremental writer
             self.segments = []  # Store segments until we have enough data
             self.total_duration = 0  # Track total duration in milliseconds
             # Initialize an empty AudioSegment as our encoder
             self.encoder = AudioSegment.silent(duration=0, frame_rate=self.sample_rate)
+        elif self.format == "pcm":
+            # PCM doesn't need initialization, we'll write raw bytes
+            pass
+        else:
+            raise ValueError(f"Unsupported format: {self.format}")
 
     def _write_wav_header(self) -> bytes:
         """Write WAV header with correct streaming format"""
@@ -63,42 +78,48 @@ class StreamingAudioWriter:
             audio_data: Audio data to write, or None if finalizing
             finalize: Whether this is the final write to close the stream
         """
-        buffer = BytesIO()
+        output_buffer = BytesIO()
 
         if finalize:
             if self.format == "wav":
                 # Write final WAV header with correct sizes
-                buffer.write(b'RIFF')
-                buffer.write(struct.pack('<L', self.bytes_written + 36))
-                buffer.write(b'WAVE')
+                output_buffer.write(b'RIFF')
+                output_buffer.write(struct.pack('<L', self.bytes_written + 36))
+                output_buffer.write(b'WAVE')
+                return output_buffer.getvalue()
+
+            elif self.format in ["ogg", "opus", "flac"]:
+                # Close the soundfile writer to flush the remaining encoded data
+                self.writer.close()
+                return self.buffer.getvalue()
 
-            elif self.format == "mp3":
+            elif self.format in ["mp3", "aac"]:
                 if len(self.encoder) > 0:
                     # Export with duration metadata
+                    format_args = {
+                        "mp3": {"format": "mp3", "codec": "libmp3lame"},
+                        "aac": {"format": "adts", "codec": "aac"}
+                    }[self.format]
+
                     self.encoder.export(
-                        buffer,
-                        format="mp3",
+                        output_buffer,
+                        **format_args,
                         bitrate="192k",
                         parameters=[
                             "-q:a", "2",
-                            "-write_xing", "1",  # Force XING/LAME header
+                            "-write_xing", "1" if self.format == "mp3" else "0",  # XING header for MP3 only
                             "-metadata", f"duration={self.total_duration/1000}"  # Duration in seconds
                         ]
                     )
                     self.encoder = None
-            return buffer.getvalue()
+            return output_buffer.getvalue()
 
         if audio_data is None or len(audio_data) == 0:
             return b''
@@ -106,22 +127,22 @@ class StreamingAudioWriter:
         if self.format == "wav":
             # For WAV, write raw PCM after the first chunk
             if self.bytes_written == 0:
-                buffer.write(self._write_wav_header())
-            buffer.write(audio_data.tobytes())
+                output_buffer.write(self._write_wav_header())
+            output_buffer.write(audio_data.tobytes())
             self.bytes_written += len(audio_data.tobytes())
 
-        elif self.format == "ogg":
-            # OGG/Vorbis handles streaming naturally
+        elif self.format in ["ogg", "opus", "flac"]:
+            # Write to soundfile buffer
             self.writer.write(audio_data)
             self.writer.flush()
-            buffer = self.writer.file
-            buffer.seek(0, 2)  # Seek to end
-            chunk = buffer.getvalue()
-            buffer.seek(0)
-            buffer.truncate()
-            return chunk
+            # Get current buffer contents
+            data = self.buffer.getvalue()
+            # Clear buffer for next chunk
+            self.buffer.seek(0)
+            self.buffer.truncate()
+            return data
 
-        elif self.format == "mp3":
+        elif self.format in ["mp3", "aac"]:
             # Convert chunk to AudioSegment and encode
             segment = AudioSegment(
                 audio_data.tobytes(),
@@ -137,21 +158,30 @@ class StreamingAudioWriter:
             self.encoder = self.encoder + segment
 
             # Export current state to buffer
-            self.encoder.export(buffer, format="mp3", bitrate="192k", parameters=[
+            format_args = {
+                "mp3": {"format": "mp3", "codec": "libmp3lame"},
+                "aac": {"format": "adts", "codec": "aac"}
+            }[self.format]
+
+            self.encoder.export(output_buffer, **format_args, bitrate="192k", parameters=[
                 "-q:a", "2",
-                "-write_xing", "1",  # Force XING/LAME header
+                "-write_xing", "1" if self.format == "mp3" else "0",  # XING header for MP3 only
                 "-metadata", f"duration={self.total_duration/1000}"  # Duration in seconds
             ])
 
             # Get the encoded data
-            encoded_data = buffer.getvalue()
+            encoded_data = output_buffer.getvalue()
             # Reset encoder to prevent memory growth
             self.encoder = AudioSegment.silent(duration=0, frame_rate=self.sample_rate)
             return encoded_data
 
-        return buffer.getvalue()
+        elif self.format == "pcm":
+            # For PCM, just write raw bytes
+            return audio_data.tobytes()
+
+        return output_buffer.getvalue()
 
     def close(self) -> Optional[bytes]:
         """Finish the audio file and return any remaining data"""
@@ -161,16 +191,31 @@ class StreamingAudioWriter:
         if self.format == "wav":
             buffer = BytesIO()
             buffer.write(b'RIFF')
             buffer.write(struct.pack('<L', self.bytes_written + 36))
             buffer.write(b'WAVE')
             return buffer.getvalue()
 
-        elif self.format == "mp3":
+        elif self.format in ["ogg", "opus", "flac"]:
+            self.writer.close()
+            return self.buffer.getvalue()
+
+        elif self.format in ["mp3", "aac"]:
+            buffer = BytesIO()
             if len(self.encoder) > 0:
+                format_args = {
+                    "mp3": {"format": "mp3", "codec": "libmp3lame"},
+                    "aac": {"format": "adts", "codec": "aac"}
+                }[self.format]
+                self.encoder.export(buffer, **format_args)
+                return buffer.getvalue()
+
+        return None
\ No newline at end of file
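A rough sketch of how the reworked writer is driven chunk by chunk (the constructor keyword names are inferred from the attributes set in __init__ above and may not match the actual signature):

    import numpy as np
    from api.src.services.streaming_audio_writer import StreamingAudioWriter

    # Keyword names are assumptions inferred from the fields set in __init__
    writer = StreamingAudioWriter(format="mp3", sample_rate=24000, channels=1)

    parts = []
    for _ in range(3):
        chunk = (np.random.uniform(-0.5, 0.5, 24000) * 32767).astype(np.int16)
        parts.append(writer.write_chunk(chunk))      # encoded bytes for this chunk
    parts.append(writer.write_chunk(finalize=True))  # flush remaining encoder state

    mp3_bytes = b"".join(parts)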
diff --git a/api/tests/test_audio_service.py b/api/tests/test_audio_service.py
index 8131c9f..b201c92 100644
--- a/api/tests/test_audio_service.py
+++ b/api/tests/test_audio_service.py
@@ -26,106 +26,128 @@ def sample_audio():
     return np.sin(2 * np.pi * frequency * t).astype(np.float32), sample_rate
 
-def test_convert_to_wav(sample_audio):
+@pytest.mark.asyncio
+async def test_convert_to_wav(sample_audio):
     """Test converting to WAV format"""
     audio_data, sample_rate = sample_audio
-    result = AudioService.convert_audio(audio_data, sample_rate, "wav")
+    result = await AudioService.convert_audio(audio_data, sample_rate, "wav")
     assert isinstance(result, bytes)
     assert len(result) > 0
+    # Check WAV header
+    assert result.startswith(b'RIFF')
+    assert b'WAVE' in result[:12]
 
-def test_convert_to_mp3(sample_audio):
+@pytest.mark.asyncio
+async def test_convert_to_mp3(sample_audio):
     """Test converting to MP3 format"""
     audio_data, sample_rate = sample_audio
-    result = AudioService.convert_audio(audio_data, sample_rate, "mp3")
+    result = await AudioService.convert_audio(audio_data, sample_rate, "mp3")
     assert isinstance(result, bytes)
     assert len(result) > 0
+    # Check MP3 header (ID3 or MPEG frame sync)
+    assert result.startswith(b'ID3') or result.startswith(b'\xff\xfb')
 
-def test_convert_to_opus(sample_audio):
+@pytest.mark.asyncio
+async def test_convert_to_opus(sample_audio):
     """Test converting to Opus format"""
     audio_data, sample_rate = sample_audio
-    result = AudioService.convert_audio(audio_data, sample_rate, "opus")
+    result = await AudioService.convert_audio(audio_data, sample_rate, "opus")
     assert isinstance(result, bytes)
     assert len(result) > 0
+    # Check OGG header
+    assert result.startswith(b'OggS')
 
-def test_convert_to_flac(sample_audio):
+@pytest.mark.asyncio
+async def test_convert_to_flac(sample_audio):
     """Test converting to FLAC format"""
     audio_data, sample_rate = sample_audio
-    result = AudioService.convert_audio(audio_data, sample_rate, "flac")
+    result = await AudioService.convert_audio(audio_data, sample_rate, "flac")
     assert isinstance(result, bytes)
     assert len(result) > 0
+    # Check FLAC header
+    assert result.startswith(b'fLaC')
 
-def test_convert_to_aac(sample_audio):
+@pytest.mark.asyncio
+async def test_convert_to_aac(sample_audio):
     """Test converting to AAC format"""
     audio_data, sample_rate = sample_audio
-    result = AudioService.convert_audio(audio_data, sample_rate, "aac")
+    result = await AudioService.convert_audio(audio_data, sample_rate, "aac")
     assert isinstance(result, bytes)
     assert len(result) > 0
-    # AAC files typically start with an ADTS header
-    assert result.startswith(b'\xff\xf1') or result.startswith(b'\xff\xf9')
+    # Check ADTS header (AAC)
+    assert result.startswith(b'\xff\xf0') or result.startswith(b'\xff\xf1')
 
-def test_convert_to_pcm(sample_audio):
+@pytest.mark.asyncio
+async def test_convert_to_pcm(sample_audio):
     """Test converting to PCM format"""
     audio_data, sample_rate = sample_audio
-    result = AudioService.convert_audio(audio_data, sample_rate, "pcm")
+    result = await AudioService.convert_audio(audio_data, sample_rate, "pcm")
     assert isinstance(result, bytes)
     assert len(result) > 0
+    # PCM is raw bytes, so no header to check
 
-def test_convert_to_invalid_format_raises_error(sample_audio):
+@pytest.mark.asyncio
+async def test_convert_to_invalid_format_raises_error(sample_audio):
     """Test that converting to an invalid format raises an error"""
    audio_data, sample_rate = sample_audio
     with pytest.raises(ValueError, match="Format invalid not supported"):
-        AudioService.convert_audio(audio_data, sample_rate, "invalid")
+        await AudioService.convert_audio(audio_data, sample_rate, "invalid")
 
-def test_normalization_wav(sample_audio):
+@pytest.mark.asyncio
+async def test_normalization_wav(sample_audio):
     """Test that WAV output is properly normalized to int16 range"""
     audio_data, sample_rate = sample_audio
     # Create audio data outside int16 range
     large_audio = audio_data * 1e5
-    result = AudioService.convert_audio(large_audio, sample_rate, "wav")
+    result = await AudioService.convert_audio(large_audio, sample_rate, "wav")
     assert isinstance(result, bytes)
     assert len(result) > 0
 
-def test_normalization_pcm(sample_audio):
+@pytest.mark.asyncio
+async def test_normalization_pcm(sample_audio):
     """Test that PCM output is properly normalized to int16 range"""
     audio_data, sample_rate = sample_audio
     # Create audio data outside int16 range
     large_audio = audio_data * 1e5
-    result = AudioService.convert_audio(large_audio, sample_rate, "pcm")
+    result = await AudioService.convert_audio(large_audio, sample_rate, "pcm")
     assert isinstance(result, bytes)
     assert len(result) > 0
 
-def test_invalid_audio_data():
+@pytest.mark.asyncio
+async def test_invalid_audio_data():
     """Test handling of invalid audio data"""
     invalid_audio = np.array([])  # Empty array
     sample_rate = 24000
     with pytest.raises(ValueError):
-        AudioService.convert_audio(invalid_audio, sample_rate, "wav")
+        await AudioService.convert_audio(invalid_audio, sample_rate, "wav")
 
-def test_different_sample_rates(sample_audio):
+@pytest.mark.asyncio
+async def test_different_sample_rates(sample_audio):
     """Test converting audio with different sample rates"""
     audio_data, _ = sample_audio
     sample_rates = [8000, 16000, 44100, 48000]
 
     for rate in sample_rates:
-        result = AudioService.convert_audio(audio_data, rate, "wav")
+        result = await AudioService.convert_audio(audio_data, rate, "wav")
         assert isinstance(result, bytes)
         assert len(result) > 0
 
-def test_buffer_position_after_conversion(sample_audio):
+@pytest.mark.asyncio
+async def test_buffer_position_after_conversion(sample_audio):
     """Test that buffer position is reset after writing"""
     audio_data, sample_rate = sample_audio
-    result = AudioService.convert_audio(audio_data, sample_rate, "wav")
+    result = await AudioService.convert_audio(audio_data, sample_rate, "wav")
     # Convert again to ensure buffer was properly reset
-    result2 = AudioService.convert_audio(audio_data, sample_rate, "wav")
+    result2 = await AudioService.convert_audio(audio_data, sample_rate, "wav")
     assert len(result) == len(result2)
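Because convert_audio is now a coroutine, callers outside the FastAPI request cycle need an event loop. A minimal sketch (the import path is assumed from the repo layout shown in this patch):

    import asyncio
    import numpy as np
    from api.src.services.audio import AudioService  # assumed import path

    async def demo():
        tone = np.sin(2 * np.pi * 440 * np.linspace(0, 1, 24000)).astype(np.float32)
        wav_bytes = await AudioService.convert_audio(tone, 24000, "wav")
        assert wav_bytes.startswith(b'RIFF')

    asyncio.run(demo())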
diff --git a/api/tests/test_tts_service_new.py b/api/tests/test_tts_service_new.py
index f9a1cdd..4b2887f 100644
--- a/api/tests/test_tts_service_new.py
+++ b/api/tests/test_tts_service_new.py
@@ -59,7 +59,7 @@ async def test_empty_text(tts_service, test_voice):
         voice=test_voice,
         speed=1.0
     )
-    assert "Text is empty after preprocessing" in str(exc_info.value)
+    assert "No audio chunks were generated successfully" in str(exc_info.value)
 
 @pytest.mark.asyncio
 async def test_invalid_voice(tts_service):
@@ -126,15 +126,17 @@ async def test_combine_voices(tts_service):
 @pytest.mark.asyncio
 async def test_chunked_text_processing(tts_service, test_voice, mock_audio_output):
     """Test processing chunked text"""
-    long_text = "First sentence. Second sentence. Third sentence."
+    # Create text that will force chunking by exceeding max tokens
+    long_text = "This is a test sentence." * 100  # Should be way over 500 tokens
 
+    # Don't mock smart_split - let it actually split the text
     audio, processing_time = await tts_service.generate_audio(
         text=long_text,
         voice=test_voice,
-        speed=1.0,
-        stitch_long_output=True
+        speed=1.0
     )
 
+    # Should be called multiple times due to chunking
     assert tts_service.model_manager.generate.call_count > 1
     assert isinstance(audio, np.ndarray)
     assert processing_time > 0
\ No newline at end of file
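The updated test leans on the service's real smart_split rather than a mock. As a rough illustration of the kind of budgeted splitting involved (a simplified stand-in that counts words rather than tokens; the 500 figure comes from the comment above, and this is not the repo's actual implementation):

    def naive_split(text: str, budget: int = 500):
        """Simplified stand-in for smart_split: pack sentences under a word budget."""
        chunk, used = [], 0
        for sentence in text.split(". "):
            words = len(sentence.split())
            if chunk and used + words > budget:
                yield ". ".join(chunk) + "."
                chunk, used = [], 0
            chunk.append(sentence)
            used += words
        if chunk:
            yield ". ".join(chunk)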
diff --git a/docker/cpu/Dockerfile b/docker/cpu/Dockerfile
index ccd42cd..8363ed3 100644
--- a/docker/cpu/Dockerfile
+++ b/docker/cpu/Dockerfile
@@ -34,6 +34,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
 # Copy project files including models
 COPY --chown=appuser:appuser api ./api
 COPY --chown=appuser:appuser web ./web
+COPY --chown=appuser:appuser docker/scripts/download_model.* ./
 
 # Install project
 RUN --mount=type=cache,target=/root/.cache/uv \
@@ -44,8 +45,19 @@ ENV PYTHONUNBUFFERED=1
 ENV PYTHONPATH=/app
 ENV PATH="/app/.venv/bin:$PATH"
 ENV UV_LINK_MODE=copy
+
 ENV USE_GPU=false
-ENV USE_ONNX=false
+ENV USE_ONNX=true
+ENV DOWNLOAD_ONNX=true
+ENV DOWNLOAD_PTH=false
+
+# Download models based on environment variables
+RUN if [ "$DOWNLOAD_ONNX" = "true" ]; then \
+        python download_model.py --type onnx; \
+    fi && \
+    if [ "$DOWNLOAD_PTH" = "true" ]; then \
+        python download_model.py --type pth; \
+    fi
 
 # Run FastAPI server
 CMD ["uv", "run", "python", "-m", "uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]
diff --git a/docker/cpu/download_onnx.py b/docker/cpu/download_onnx.py
deleted file mode 100755
index a97daf9..0000000
--- a/docker/cpu/download_onnx.py
+++ /dev/null
@@ -1,53 +0,0 @@
-#!/usr/bin/env python3
-import os
-import sys
-import requests
-from pathlib import Path
-from typing import List
-
-def download_file(url: str, output_dir: Path) -> None:
-    """Download a file from URL to the specified directory."""
-    filename = os.path.basename(url)
-    output_path = output_dir / filename
-
-    print(f"Downloading {filename}...")
-    response = requests.get(url, stream=True)
-    response.raise_for_status()
-
-    with open(output_path, 'wb') as f:
-        for chunk in response.iter_content(chunk_size=8192):
-            f.write(chunk)
-
-def find_project_root() -> Path:
-    """Find project root by looking for api directory."""
-    max_steps = 5
-    current = Path(__file__).resolve()
-    for _ in range(max_steps):
-        if (current / 'api').is_dir():
-            return current
-        current = current.parent
-    raise RuntimeError("Could not find project root (no api directory found)")
-
-def main(custom_models: List[str] = None):
-    # Always use top-level models directory relative to project root
-    project_root = find_project_root()
-    models_dir = project_root / 'api' / 'src' / 'models'
-    models_dir.mkdir(exist_ok=True, parents=True)
-
-    # Default ONNX model if no arguments provided
-    default_models = [
-        "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.onnx",
-        # "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19_fp16.onnx"
-    ]
-
-    # Use provided models or default
-    models_to_download = custom_models if custom_models else default_models
-
-    for model_url in models_to_download:
-        try:
-            download_file(model_url, models_dir)
-        except Exception as e:
-            print(f"Error downloading {model_url}: {e}")
-
-if __name__ == "__main__":
-    main(sys.argv[1:] if len(sys.argv) > 1 else None)
\ No newline at end of file
diff --git a/docker/cpu/download_onnx.sh b/docker/cpu/download_onnx.sh
deleted file mode 100755
index c0a250b..0000000
--- a/docker/cpu/download_onnx.sh
+++ /dev/null
@@ -1,32 +0,0 @@
-#!/bin/bash
-
-# Ensure models directory exists
-mkdir -p api/src/models
-
-# Function to download a file
-download_file() {
-    local url="$1"
-    local filename=$(basename "$url")
-    echo "Downloading $filename..."
-    curl -L "$url" -o "api/src/models/$filename"
-}
-
-# Default ONNX model if no arguments provided
-DEFAULT_MODELS=(
-    "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.onnx"
-    "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19_fp16.onnx"
-)
-
-# Use provided models or default
-if [ $# -gt 0 ]; then
-    MODELS=("$@")
-else
-    MODELS=("${DEFAULT_MODELS[@]}")
-fi
-
-# Download all models
-for model in "${MODELS[@]}"; do
-    download_file "$model"
-done
-
-echo "ONNX model download complete!"
\ No newline at end of file
diff --git a/docker/gpu/Dockerfile b/docker/gpu/Dockerfile
index 93613aa..3134f4e 100644
--- a/docker/gpu/Dockerfile
+++ b/docker/gpu/Dockerfile
@@ -38,6 +38,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
 # Copy project files including models
 COPY --chown=appuser:appuser api ./api
 COPY --chown=appuser:appuser web ./web
+COPY --chown=appuser:appuser docker/scripts/download_model.* ./
 
 # Install project with GPU extras
 RUN --mount=type=cache,target=/root/.cache/uv \
@@ -48,7 +49,19 @@ ENV PYTHONUNBUFFERED=1
 ENV PYTHONPATH=/app
 ENV PATH="/app/.venv/bin:$PATH"
 ENV UV_LINK_MODE=copy
+
 ENV USE_GPU=true
+ENV USE_ONNX=false
+ENV DOWNLOAD_PTH=true
+ENV DOWNLOAD_ONNX=false
+
+# Download models based on environment variables
+RUN if [ "$DOWNLOAD_PTH" = "true" ]; then \
+        python download_model.py --type pth; \
+    fi && \
+    if [ "$DOWNLOAD_ONNX" = "true" ]; then \
+        python download_model.py --type onnx; \
+    fi
 
 # Run FastAPI server
 CMD ["uv", "run", "python", "-m", "uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]
diff --git a/docker/gpu/download_pth.py b/docker/gpu/download_pth.py
deleted file mode 100755
index f58f29d..0000000
--- a/docker/gpu/download_pth.py
+++ /dev/null
@@ -1,57 +0,0 @@
-#!/usr/bin/env python3
-import os
-import sys
-import requests
-from pathlib import Path
-from typing import List
-
-def download_file(url: str, output_dir: Path) -> None:
-    """Download a file from URL to the specified directory."""
-    filename = os.path.basename(url)
-    if not filename.endswith('.pth'):
-        print(f"Warning: {filename} is not a .pth file")
-        return
-
-    output_path = output_dir / filename
-
-    print(f"Downloading {filename}...")
-    response = requests.get(url, stream=True)
-    response.raise_for_status()
-
-    with open(output_path, 'wb') as f:
-        for chunk in response.iter_content(chunk_size=8192):
-            f.write(chunk)
-
-def find_project_root() -> Path:
-    """Find project root by looking for api directory."""
-    max_steps = 5
-    current = Path(__file__).resolve()
-    for _ in range(max_steps):
-        if (current / 'api').is_dir():
-            return current
-        current = current.parent
-    raise RuntimeError("Could not find project root (no api directory found)")
-
-def main(custom_models: List[str] = None):
-    # Find project root and ensure models directory exists
-    project_root = find_project_root()
-    models_dir = project_root / 'api' / 'src' / 'models'
-    print(f"Downloading models to {models_dir}")
-    models_dir.mkdir(exist_ok=True)
-
-    # Default PTH model if no arguments provided
-    default_models = [
-        "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth"
-    ]
-
-    # Use provided models or default
-    models_to_download = custom_models if custom_models else default_models
-
-    for model_url in models_to_download:
-        try:
-            download_file(model_url, models_dir)
-        except Exception as e:
-            print(f"Error downloading {model_url}: {e}")
-
-if __name__ == "__main__":
-    main(sys.argv[1:] if len(sys.argv) > 1 else None)
\ No newline at end of file
diff --git a/docker/gpu/download_pth.sh b/docker/gpu/download_pth.sh
deleted file mode 100755
index c8bda83..0000000
--- a/docker/gpu/download_pth.sh
+++ /dev/null
@@ -1,31 +0,0 @@
-#!/bin/bash
-
-# Ensure models directory exists
-mkdir -p api/src/models
-
-# Function to download a file
-download_file() {
-    local url="$1"
-    local filename=$(basename "$url")
-    echo "Downloading $filename..."
-    curl -L "$url" -o "api/src/models/$filename"
-}
-
-# Default PTH model if no arguments provided
-DEFAULT_MODELS=(
-    "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth"
-)
-
-# Use provided models or default
-if [ $# -gt 0 ]; then
-    MODELS=("$@")
-else
-    MODELS=("${DEFAULT_MODELS[@]}")
-fi
-
-# Download all models
-for model in "${MODELS[@]}"; do
-    download_file "$model"
-done
-
-echo "PyTorch model download complete!"
\ No newline at end of file
diff --git a/docker/scripts/download_model.py b/docker/scripts/download_model.py
new file mode 100644
index 0000000..bc808df
--- /dev/null
+++ b/docker/scripts/download_model.py
@@ -0,0 +1,97 @@
+#!/usr/bin/env python3
+import os
+import sys
+import argparse
+import requests
+from pathlib import Path
+from typing import List
+
+def download_file(url: str, output_dir: Path, model_type: str) -> bool:
+    """Download a file from URL to the specified directory.
+
+    Returns:
+        bool: True if download succeeded, False otherwise
+    """
+    filename = os.path.basename(url)
+    if not filename.endswith(f'.{model_type}'):
+        print(f"Warning: {filename} is not a .{model_type} file", file=sys.stderr)
+        return False
+
+    output_path = output_dir / filename
+
+    print(f"Downloading {filename}...")
+    try:
+        response = requests.get(url, stream=True)
+        response.raise_for_status()
+
+        with open(output_path, 'wb') as f:
+            for chunk in response.iter_content(chunk_size=8192):
+                f.write(chunk)
+        print(f"Successfully downloaded {filename}")
+        return True
+    except Exception as e:
+        print(f"Error downloading {filename}: {e}", file=sys.stderr)
+        return False
+
+def find_project_root() -> Path:
+    """Find project root by looking for api directory."""
+    max_steps = 5
+    current = Path(__file__).resolve()
+    for _ in range(max_steps):
+        if (current / 'api').is_dir():
+            return current
+        current = current.parent
+    raise RuntimeError("Could not find project root (no api directory found)")
+
+def main() -> int:
+    """Download models to the project.
+
+    Returns:
+        int: Exit code (0 for success, 1 for failure)
+    """
+    parser = argparse.ArgumentParser(description='Download model files')
+    parser.add_argument('--type', choices=['pth', 'onnx'], required=True,
+                        help='Model type to download (pth or onnx)')
+    parser.add_argument('urls', nargs='*', help='Optional model URLs to download')
+    args = parser.parse_args()
+
+    try:
+        # Find project root and ensure models directory exists
+        project_root = find_project_root()
+        models_dir = project_root / 'api' / 'src' / 'models'
+        print(f"Downloading models to {models_dir}")
+        models_dir.mkdir(exist_ok=True)
+
+        # Default models if no arguments provided
+        default_models = {
+            'pth': [
+                "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth"
+            ],
+            'onnx': [
+                "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.onnx",
+                "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19_fp16.onnx"
+            ]
+        }
+
+        # Use provided models or default
+        models_to_download = args.urls if args.urls else default_models[args.type]
+
+        # Download all models
+        success = True
+        for model_url in models_to_download:
+            if not download_file(model_url, models_dir, args.type):
+                success = False
+
+        if success:
+            print(f"{args.type.upper()} model download complete!")
+            return 0
+        else:
+            print("Some downloads failed", file=sys.stderr)
+            return 1
+
+    except Exception as e:
+        print(f"Error: {e}", file=sys.stderr)
+        return 1
+
+if __name__ == "__main__":
+    sys.exit(main())
\ No newline at end of file
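The new script can also be driven programmatically; a small sketch (assumes it is run from the repo root so docker/scripts is importable):

    import sys
    sys.path.insert(0, "docker/scripts")  # assumption: running from the repo root

    import download_model

    # Equivalent to: python docker/scripts/download_model.py --type onnx
    sys.argv = ["download_model.py", "--type", "onnx"]
    exit_code = download_model.main()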
diff --git a/docker/scripts/download_model.sh b/docker/scripts/download_model.sh
new file mode 100644
index 0000000..926a831
--- /dev/null
+++ b/docker/scripts/download_model.sh
@@ -0,0 +1,109 @@
+#!/bin/bash
+
+# Find project root by looking for api directory
+find_project_root() {
+    local current_dir="$PWD"
+    local max_steps=5
+    local steps=0
+
+    while [ $steps -lt $max_steps ]; do
+        if [ -d "$current_dir/api" ]; then
+            echo "$current_dir"
+            return 0
+        fi
+        current_dir="$(dirname "$current_dir")"
+        ((steps++))
+    done
+
+    echo "Error: Could not find project root (no api directory found)" >&2
+    exit 1
+}
+
+# Function to download a file
+download_file() {
+    local url="$1"
+    local output_dir="$2"
+    local model_type="$3"
+    local filename=$(basename "$url")
+
+    # Validate file extension
+    if [[ ! "$filename" =~ \.$model_type$ ]]; then
+        echo "Warning: $filename is not a .$model_type file" >&2
+        return 1
+    fi
+
+    echo "Downloading $filename..."
+    if curl -L "$url" -o "$output_dir/$filename"; then
+        echo "Successfully downloaded $filename"
+        return 0
+    else
+        echo "Error downloading $filename" >&2
+        return 1
+    fi
+}
+
+# Parse arguments
+MODEL_TYPE=""
+while [[ $# -gt 0 ]]; do
+    case $1 in
+        --type)
+            MODEL_TYPE="$2"
+            shift 2
+            ;;
+        *)
+            # If no flag specified, treat remaining args as model URLs
+            break
+            ;;
+    esac
+done
+
+# Validate model type
+if [ "$MODEL_TYPE" != "pth" ] && [ "$MODEL_TYPE" != "onnx" ]; then
+    echo "Error: Must specify model type with --type (pth or onnx)" >&2
+    exit 1
+fi
+
+# Find project root and ensure models directory exists
+PROJECT_ROOT=$(find_project_root)
+if [ $? -ne 0 ]; then
+    exit 1
+fi
+
+MODELS_DIR="$PROJECT_ROOT/api/src/models"
+echo "Downloading models to $MODELS_DIR"
+mkdir -p "$MODELS_DIR"
+
+# Default models if no arguments provided
+if [ "$MODEL_TYPE" = "pth" ]; then
+    DEFAULT_MODELS=(
+        "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth"
+    )
+else
+    DEFAULT_MODELS=(
+        "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.onnx"
+        "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19_fp16.onnx"
+    )
+fi
+
+# Use provided models or default
+if [ $# -gt 0 ]; then
+    MODELS=("$@")
+else
+    MODELS=("${DEFAULT_MODELS[@]}")
+fi
+
+# Download all models
+success=true
+for model in "${MODELS[@]}"; do
+    if ! download_file "$model" "$MODELS_DIR" "$MODEL_TYPE"; then
+        success=false
+    fi
+done
+
+if [ "$success" = true ]; then
+    echo "${MODEL_TYPE^^} model download complete!"
+    exit 0
+else
+    echo "Some downloads failed" >&2
+    exit 1
+fi
\ No newline at end of file
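With a container built from either Dockerfile running, the speech endpoint can be exercised over HTTP. A hedged end-to-end sketch (the /v1/audio/speech path follows the OpenAI API shape the router mimics, and the model/voice values are placeholders; check the README for the exact names):

    import requests

    resp = requests.post(
        "http://localhost:8880/v1/audio/speech",  # path assumed from the OpenAI API shape
        json={
            "model": "kokoro",  # placeholder model name
            "voice": "af",      # placeholder voice name
            "input": "Hello from Kokoro-FastAPI!",
            "response_format": "mp3",
        },
    )
    resp.raise_for_status()
    with open("hello.mp3", "wb") as f:
        f.write(resp.content)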