WIP: v1_0_0 migration

remsky 2025-01-28 13:52:57 -07:00
parent 1345b6c81a
commit 9867fc398f
16 changed files with 422 additions and 269 deletions


@@ -56,7 +56,9 @@ The service can be accessed through either the API endpoints or the Gradio web interface
cd docker/gpu # OR
# cd docker/cpu # Run this or the above
docker compose up --build
# if you are missing any models, run the .py or .sh scripts in the respective folders
# if you are missing any models, run:
# python ../scripts/download_model.py --type pth # for GPU
# python ../scripts/download_model.py --type onnx # for CPU
```
Once started:


@@ -336,3 +336,4 @@ async def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
_manager_instance = ModelManager(config)
await _manager_instance.initialize()
return _manager_instance
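
The hunk above is the tail of a lazily initialized async singleton. For context, a minimal self-contained sketch of the full pattern it implies (the `ModelManager` stub here is hypothetical; the real class lives elsewhere in the repo):

```python
from typing import Optional

class ModelManager:  # stub standing in for the real class
    def __init__(self, config=None):
        self.config = config

    async def initialize(self):  # the real version loads models and warms up
        pass

_manager_instance: Optional[ModelManager] = None  # module-level cache

async def get_manager(config=None) -> ModelManager:
    """Return the shared ModelManager, creating and initializing it on first call."""
    global _manager_instance
    if _manager_instance is None:
        _manager_instance = ModelManager(config)
        await _manager_instance.initialize()
    return _manager_instance
```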


@@ -5,6 +5,7 @@ FastAPI OpenAI Compatible API
import os
import sys
from contextlib import asynccontextmanager
from pathlib import Path
import torch
import uvicorn
@@ -57,9 +58,23 @@ async def lifespan(app: FastAPI):
# Initialize model with warmup and get status
device, model, voicepack_count = await model_manager.initialize_with_warmup(voice_manager)
except FileNotFoundError:
logger.error("""
Model files not found! You need to either:
1. Download models using the scripts:
GPU: python docker/scripts/download_model.py --type pth
CPU: python docker/scripts/download_model.py --type onnx
2. Set environment variables in docker-compose:
GPU: DOWNLOAD_PTH=true
CPU: DOWNLOAD_ONNX=true
""")
raise
except Exception as e:
logger.error(f"Failed to initialize model: {e}")
raise
boundary = "░" * 2*12
startup_msg = f"""
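
The hunk is truncated mid f-string above. For orientation, a minimal sketch of the FastAPI lifespan shape this error handling plugs into; `warmup()` is a hypothetical stand-in for `model_manager.initialize_with_warmup(...)`, and the loguru import is an assumption:

```python
from contextlib import asynccontextmanager
from fastapi import FastAPI
from loguru import logger

async def warmup() -> None:  # stub so the sketch runs standalone
    pass

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: initialize models before serving; fail fast with guidance if files are missing.
    try:
        await warmup()
    except FileNotFoundError:
        logger.error("Model files not found - run download_model.py or set DOWNLOAD_PTH/DOWNLOAD_ONNX")
        raise
    yield  # the application serves requests here
    # Shutdown cleanup (if any) follows the yield.

app = FastAPI(lifespan=lifespan)
```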


@@ -1,3 +1,5 @@
"""OpenAI-compatible router for text-to-speech"""
import json
import os
from typing import AsyncGenerator, Dict, List, Union
@@ -217,9 +219,9 @@ async def create_speech(
stitch_long_output=True
)
# Convert to requested format
# Convert to requested format - removed stream parameter
content = await AudioService.convert_audio(
audio, 24000, request.response_format, is_first_chunk=True, stream=False
audio, 24000, request.response_format, is_first_chunk=True
)
return Response(


@@ -16,7 +16,6 @@ class AudioNormalizer:
"""Handles audio normalization state for a single stream"""
def __init__(self):
self.int16_max = np.iinfo(np.int16).max
self.chunk_trim_ms = settings.gap_trim_ms
self.sample_rate = 24000 # Sample rate of the audio
self.samples_to_trim = int(self.chunk_trim_ms * self.sample_rate / 1000)
@@ -30,20 +29,23 @@ AudioNormalizer:
Returns:
Normalized and trimmed audio data
"""
# Convert to float32 for processing
audio_float = audio_data.astype(np.float32)
if len(audio_data) == 0:
raise ValueError("Empty audio data")
# Trim start and end if enough samples
if len(audio_float) > (2 * self.samples_to_trim):
audio_float = audio_float[self.samples_to_trim:-self.samples_to_trim]
if len(audio_data) > (2 * self.samples_to_trim):
audio_data = audio_data[self.samples_to_trim:-self.samples_to_trim]
# Scale to int16 range
return (audio_float * 32767).astype(np.int16)
# Scale directly to int16 range with clipping
return np.clip(audio_data * 32767, -32768, 32767).astype(np.int16)
class AudioService:
"""Service for audio format conversions with streaming support"""
# Supported formats
SUPPORTED_FORMATS = {"wav", "mp3", "opus", "flac", "aac", "pcm", "ogg"}
# Default audio format settings balanced for speed and compression
DEFAULT_SETTINGS = {
"mp3": {
@@ -86,6 +88,10 @@ class AudioService:
Bytes of the converted audio chunk
"""
try:
# Validate format
if output_format not in AudioService.SUPPORTED_FORMATS:
raise ValueError(f"Format {output_format} not supported")
# Always normalize audio to ensure proper amplitude scaling
if normalizer is None:
normalizer = AudioNormalizer()
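
The direct int16 scaling above clips out-of-range samples instead of letting them wrap around on cast; a small standalone demonstration of that behavior:

```python
import numpy as np

def normalize_to_int16(audio: np.ndarray) -> np.ndarray:
    """Scale float audio (nominally in [-1.0, 1.0]) to int16, clipping any overshoot."""
    return np.clip(audio * 32767, -32768, 32767).astype(np.int16)

samples = np.array([0.5, 1.1, -1.1], dtype=np.float32)
print(normalize_to_int16(samples))  # [ 16383  32767 -32768]
# Without the clip, casting 1.1 * 32767 straight to int16 would wrap to a negative value.
```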


@@ -17,26 +17,41 @@ class StreamingAudioWriter:
self.sample_rate = sample_rate
self.channels = channels
self.bytes_written = 0
self.buffer = BytesIO()
# Format-specific setup
if self.format == "wav":
self._write_wav_header()
elif self.format == "ogg":
elif self.format in ["ogg", "opus"]:
# For OGG/Opus, write to memory buffer
self.writer = sf.SoundFile(
file=BytesIO(),
file=self.buffer,
mode='w',
samplerate=sample_rate,
channels=channels,
format='OGG',
subtype='VORBIS'
subtype='VORBIS' if self.format == "ogg" else "OPUS"
)
elif self.format == "mp3":
# For MP3, we'll use pydub's incremental writer
self.buffer = BytesIO()
elif self.format == "flac":
# For FLAC, write to memory buffer
self.writer = sf.SoundFile(
file=self.buffer,
mode='w',
samplerate=sample_rate,
channels=channels,
format='FLAC'
)
elif self.format in ["mp3", "aac"]:
# For MP3/AAC, we'll use pydub's incremental writer
self.segments = [] # Store segments until we have enough data
self.total_duration = 0 # Track total duration in milliseconds
# Initialize an empty AudioSegment as our encoder
self.encoder = AudioSegment.silent(duration=0, frame_rate=self.sample_rate)
elif self.format == "pcm":
# PCM doesn't need initialization, we'll write raw bytes
pass
else:
raise ValueError(f"Unsupported format: {format}")
def _write_wav_header(self) -> bytes:
"""Write WAV header with correct streaming format"""
@@ -63,42 +78,48 @@ class StreamingAudioWriter:
audio_data: Audio data to write, or None if finalizing
finalize: Whether this is the final write to close the stream
"""
buffer = BytesIO()
output_buffer = BytesIO()
if finalize:
if self.format == "wav":
# Write final WAV header with correct sizes
buffer.write(b'RIFF')
buffer.write(struct.pack('<L', self.bytes_written + 36))
buffer.write(b'WAVE')
buffer.write(b'fmt ')
buffer.write(struct.pack('<L', 16))
buffer.write(struct.pack('<H', 1))
buffer.write(struct.pack('<H', self.channels))
buffer.write(struct.pack('<L', self.sample_rate))
buffer.write(struct.pack('<L', self.sample_rate * self.channels * 2))
buffer.write(struct.pack('<H', self.channels * 2))
buffer.write(struct.pack('<H', 16))
buffer.write(b'data')
buffer.write(struct.pack('<L', self.bytes_written))
elif self.format == "ogg":
output_buffer.write(b'RIFF')
output_buffer.write(struct.pack('<L', self.bytes_written + 36))
output_buffer.write(b'WAVE')
output_buffer.write(b'fmt ')
output_buffer.write(struct.pack('<L', 16))
output_buffer.write(struct.pack('<H', 1))
output_buffer.write(struct.pack('<H', self.channels))
output_buffer.write(struct.pack('<L', self.sample_rate))
output_buffer.write(struct.pack('<L', self.sample_rate * self.channels * 2))
output_buffer.write(struct.pack('<H', self.channels * 2))
output_buffer.write(struct.pack('<H', 16))
output_buffer.write(b'data')
output_buffer.write(struct.pack('<L', self.bytes_written))
elif self.format in ["ogg", "opus", "flac"]:
self.writer.close()
elif self.format == "mp3":
return self.buffer.getvalue()
elif self.format in ["mp3", "aac"]:
# Final export of any remaining audio
if hasattr(self, 'encoder') and len(self.encoder) > 0:
# Export with duration metadata
format_args = {
"mp3": {"format": "mp3", "codec": "libmp3lame"},
"aac": {"format": "adts", "codec": "aac"}
}[self.format]
self.encoder.export(
buffer,
format="mp3",
output_buffer,
**format_args,
bitrate="192k",
parameters=[
"-q:a", "2",
"-write_xing", "1", # Force XING/LAME header
"-write_xing", "1" if self.format == "mp3" else "0", # XING header for MP3 only
"-metadata", f"duration={self.total_duration/1000}" # Duration in seconds
]
)
self.encoder = None
return buffer.getvalue()
return output_buffer.getvalue()
if audio_data is None or len(audio_data) == 0:
return b''
@@ -106,22 +127,22 @@ class StreamingAudioWriter:
if self.format == "wav":
# For WAV, write raw PCM after the first chunk
if self.bytes_written == 0:
buffer.write(self._write_wav_header())
buffer.write(audio_data.tobytes())
output_buffer.write(self._write_wav_header())
output_buffer.write(audio_data.tobytes())
self.bytes_written += len(audio_data.tobytes())
elif self.format == "ogg":
# OGG/Vorbis handles streaming naturally
elif self.format in ["ogg", "opus", "flac"]:
# Write to soundfile buffer
self.writer.write(audio_data)
self.writer.flush()
buffer = self.writer.file
buffer.seek(0, 2) # Seek to end
chunk = buffer.getvalue()
buffer.seek(0)
buffer.truncate()
return chunk
# Get current buffer contents
data = self.buffer.getvalue()
# Clear buffer for next chunk
self.buffer.seek(0)
self.buffer.truncate()
return data
elif self.format == "mp3":
elif self.format in ["mp3", "aac"]:
# Convert chunk to AudioSegment and encode
segment = AudioSegment(
audio_data.tobytes(),
@@ -137,21 +158,30 @@ class StreamingAudioWriter:
self.encoder = self.encoder + segment
# Export current state to buffer
self.encoder.export(buffer, format="mp3", bitrate="192k", parameters=[
format_args = {
"mp3": {"format": "mp3", "codec": "libmp3lame"},
"aac": {"format": "adts", "codec": "aac"}
}[self.format]
self.encoder.export(output_buffer, **format_args, bitrate="192k", parameters=[
"-q:a", "2",
"-write_xing", "1", # Force XING/LAME header
"-write_xing", "1" if self.format == "mp3" else "0", # XING header for MP3 only
"-metadata", f"duration={self.total_duration/1000}" # Duration in seconds
])
# Get the encoded data
encoded_data = buffer.getvalue()
encoded_data = output_buffer.getvalue()
# Reset encoder to prevent memory growth
self.encoder = AudioSegment.silent(duration=0, frame_rate=self.sample_rate)
return encoded_data
return buffer.getvalue()
elif self.format == "pcm":
# For PCM, just write raw bytes
return audio_data.tobytes()
return output_buffer.getvalue()
def close(self) -> Optional[bytes]:
"""Finish the audio file and return any remaining data"""
@@ -161,16 +191,31 @@ class StreamingAudioWriter:
buffer.write(b'RIFF')
buffer.write(struct.pack('<L', self.bytes_written + 36)) # File size
buffer.write(b'WAVE')
# ... rest of header ...
buffer.write(struct.pack('<L', self.bytes_written)) # Data size
buffer.write(b'fmt ')
buffer.write(struct.pack('<L', 16))
buffer.write(struct.pack('<H', 1))
buffer.write(struct.pack('<H', self.channels))
buffer.write(struct.pack('<L', self.sample_rate))
buffer.write(struct.pack('<L', self.sample_rate * self.channels * 2))
buffer.write(struct.pack('<H', self.channels * 2))
buffer.write(struct.pack('<H', 16))
buffer.write(b'data')
buffer.write(struct.pack('<L', self.bytes_written))
return buffer.getvalue()
elif self.format == "ogg":
elif self.format in ["ogg", "opus", "flac"]:
self.writer.close()
return None
return self.buffer.getvalue()
elif self.format == "mp3":
# Flush any remaining MP3 frames
elif self.format in ["mp3", "aac"]:
# Flush any remaining audio
buffer = BytesIO()
self.encoder.export(buffer, format="mp3")
if hasattr(self, 'encoder') and len(self.encoder) > 0:
format_args = {
"mp3": {"format": "mp3", "codec": "libmp3lame"},
"aac": {"format": "adts", "codec": "aac"}
}[self.format]
self.encoder.export(buffer, **format_args)
return buffer.getvalue()
return None
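
A hypothetical usage sketch of the writer's streaming loop (the `write_chunk` name and its `finalize` flag are taken from the hunk's docstrings, not verified against the full file):

```python
def stream_encoded(writer, chunks):
    """Drive a StreamingAudioWriter-style object: encode chunks as they arrive, then flush."""
    for chunk in chunks:  # int16 numpy arrays from the TTS pipeline
        encoded = writer.write_chunk(chunk)  # may return b'' while the encoder buffers
        if encoded:
            yield encoded
    tail = writer.write_chunk(finalize=True)  # final WAV header rewrite / encoder flush
    if tail:
        yield tail
```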


@@ -26,106 +26,128 @@ def sample_audio():
return np.sin(2 * np.pi * frequency * t).astype(np.float32), sample_rate
def test_convert_to_wav(sample_audio):
@pytest.mark.asyncio
async def test_convert_to_wav(sample_audio):
"""Test converting to WAV format"""
audio_data, sample_rate = sample_audio
result = AudioService.convert_audio(audio_data, sample_rate, "wav")
result = await AudioService.convert_audio(audio_data, sample_rate, "wav")
assert isinstance(result, bytes)
assert len(result) > 0
# Check WAV header
assert result.startswith(b'RIFF')
assert b'WAVE' in result[:12]
def test_convert_to_mp3(sample_audio):
@pytest.mark.asyncio
async def test_convert_to_mp3(sample_audio):
"""Test converting to MP3 format"""
audio_data, sample_rate = sample_audio
result = AudioService.convert_audio(audio_data, sample_rate, "mp3")
result = await AudioService.convert_audio(audio_data, sample_rate, "mp3")
assert isinstance(result, bytes)
assert len(result) > 0
# Check MP3 header (ID3 or MPEG frame sync)
assert result.startswith(b'ID3') or result.startswith(b'\xff\xfb')
def test_convert_to_opus(sample_audio):
@pytest.mark.asyncio
async def test_convert_to_opus(sample_audio):
"""Test converting to Opus format"""
audio_data, sample_rate = sample_audio
result = AudioService.convert_audio(audio_data, sample_rate, "opus")
result = await AudioService.convert_audio(audio_data, sample_rate, "opus")
assert isinstance(result, bytes)
assert len(result) > 0
# Check OGG header
assert result.startswith(b'OggS')
def test_convert_to_flac(sample_audio):
@pytest.mark.asyncio
async def test_convert_to_flac(sample_audio):
"""Test converting to FLAC format"""
audio_data, sample_rate = sample_audio
result = AudioService.convert_audio(audio_data, sample_rate, "flac")
result = await AudioService.convert_audio(audio_data, sample_rate, "flac")
assert isinstance(result, bytes)
assert len(result) > 0
# Check FLAC header
assert result.startswith(b'fLaC')
def test_convert_to_aac(sample_audio):
@pytest.mark.asyncio
async def test_convert_to_aac(sample_audio):
"""Test converting to AAC format"""
audio_data, sample_rate = sample_audio
result = AudioService.convert_audio(audio_data, sample_rate, "aac")
result = await AudioService.convert_audio(audio_data, sample_rate, "aac")
assert isinstance(result, bytes)
assert len(result) > 0
# AAC files typically start with an ADTS header
assert result.startswith(b'\xff\xf1') or result.startswith(b'\xff\xf9')
# Check ADTS header (AAC)
assert result.startswith(b'\xff\xf0') or result.startswith(b'\xff\xf1')
def test_convert_to_pcm(sample_audio):
@pytest.mark.asyncio
async def test_convert_to_pcm(sample_audio):
"""Test converting to PCM format"""
audio_data, sample_rate = sample_audio
result = AudioService.convert_audio(audio_data, sample_rate, "pcm")
result = await AudioService.convert_audio(audio_data, sample_rate, "pcm")
assert isinstance(result, bytes)
assert len(result) > 0
# PCM is raw bytes, so no header to check
def test_convert_to_invalid_format_raises_error(sample_audio):
@pytest.mark.asyncio
async def test_convert_to_invalid_format_raises_error(sample_audio):
"""Test that converting to an invalid format raises an error"""
audio_data, sample_rate = sample_audio
with pytest.raises(ValueError, match="Format invalid not supported"):
AudioService.convert_audio(audio_data, sample_rate, "invalid")
await AudioService.convert_audio(audio_data, sample_rate, "invalid")
def test_normalization_wav(sample_audio):
@pytest.mark.asyncio
async def test_normalization_wav(sample_audio):
"""Test that WAV output is properly normalized to int16 range"""
audio_data, sample_rate = sample_audio
# Create audio data outside int16 range
large_audio = audio_data * 1e5
result = AudioService.convert_audio(large_audio, sample_rate, "wav")
result = await AudioService.convert_audio(large_audio, sample_rate, "wav")
assert isinstance(result, bytes)
assert len(result) > 0
def test_normalization_pcm(sample_audio):
@pytest.mark.asyncio
async def test_normalization_pcm(sample_audio):
"""Test that PCM output is properly normalized to int16 range"""
audio_data, sample_rate = sample_audio
# Create audio data outside int16 range
large_audio = audio_data * 1e5
result = AudioService.convert_audio(large_audio, sample_rate, "pcm")
result = await AudioService.convert_audio(large_audio, sample_rate, "pcm")
assert isinstance(result, bytes)
assert len(result) > 0
def test_invalid_audio_data():
@pytest.mark.asyncio
async def test_invalid_audio_data():
"""Test handling of invalid audio data"""
invalid_audio = np.array([]) # Empty array
sample_rate = 24000
with pytest.raises(ValueError):
AudioService.convert_audio(invalid_audio, sample_rate, "wav")
await AudioService.convert_audio(invalid_audio, sample_rate, "wav")
def test_different_sample_rates(sample_audio):
@pytest.mark.asyncio
async def test_different_sample_rates(sample_audio):
"""Test converting audio with different sample rates"""
audio_data, _ = sample_audio
sample_rates = [8000, 16000, 44100, 48000]
for rate in sample_rates:
result = AudioService.convert_audio(audio_data, rate, "wav")
result = await AudioService.convert_audio(audio_data, rate, "wav")
assert isinstance(result, bytes)
assert len(result) > 0
def test_buffer_position_after_conversion(sample_audio):
@pytest.mark.asyncio
async def test_buffer_position_after_conversion(sample_audio):
"""Test that buffer position is reset after writing"""
audio_data, sample_rate = sample_audio
result = AudioService.convert_audio(audio_data, sample_rate, "wav")
result = await AudioService.convert_audio(audio_data, sample_rate, "wav")
# Convert again to ensure buffer was properly reset
result2 = AudioService.convert_audio(audio_data, sample_rate, "wav")
result2 = await AudioService.convert_audio(audio_data, sample_rate, "wav")
assert len(result) == len(result2)


@@ -59,7 +59,7 @@ async def test_empty_text(tts_service, test_voice):
voice=test_voice,
speed=1.0
)
assert "Text is empty after preprocessing" in str(exc_info.value)
assert "No audio chunks were generated successfully" in str(exc_info.value)
@pytest.mark.asyncio
async def test_invalid_voice(tts_service):
@@ -126,15 +126,17 @@ async def test_combine_voices(tts_service):
@pytest.mark.asyncio
async def test_chunked_text_processing(tts_service, test_voice, mock_audio_output):
"""Test processing chunked text"""
long_text = "First sentence. Second sentence. Third sentence."
# Create text that will force chunking by exceeding max tokens
long_text = "This is a test sentence." * 100 # Should be way over 500 tokens
# Don't mock smart_split - let it actually split the text
audio, processing_time = await tts_service.generate_audio(
text=long_text,
voice=test_voice,
speed=1.0,
stitch_long_output=True
speed=1.0
)
# Should be called multiple times due to chunking
assert tts_service.model_manager.generate.call_count > 1
assert isinstance(audio, np.ndarray)
assert processing_time > 0


@@ -34,6 +34,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
# Copy project files including models
COPY --chown=appuser:appuser api ./api
COPY --chown=appuser:appuser web ./web
COPY --chown=appuser:appuser docker/scripts/download_model.* ./
# Install project
RUN --mount=type=cache,target=/root/.cache/uv \
@@ -44,8 +45,19 @@ ENV PYTHONUNBUFFERED=1
ENV PYTHONPATH=/app
ENV PATH="/app/.venv/bin:$PATH"
ENV UV_LINK_MODE=copy
ENV USE_GPU=false
ENV USE_ONNX=false
ENV USE_ONNX=true
ENV DOWNLOAD_ONNX=true
ENV DOWNLOAD_PTH=false
# Download models based on environment variables
RUN if [ "$DOWNLOAD_ONNX" = "true" ]; then \
python download_model.py --type onnx; \
fi && \
if [ "$DOWNLOAD_PTH" = "true" ]; then \
python download_model.py --type pth; \
fi
# Run FastAPI server
CMD ["uv", "run", "python", "-m", "uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]


@@ -1,53 +0,0 @@
#!/usr/bin/env python3
import os
import sys
import requests
from pathlib import Path
from typing import List
def download_file(url: str, output_dir: Path) -> None:
"""Download a file from URL to the specified directory."""
filename = os.path.basename(url)
output_path = output_dir / filename
print(f"Downloading {filename}...")
response = requests.get(url, stream=True)
response.raise_for_status()
with open(output_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
def find_project_root() -> Path:
"""Find project root by looking for api directory."""
max_steps = 5
current = Path(__file__).resolve()
for _ in range(max_steps):
if (current / 'api').is_dir():
return current
current = current.parent
raise RuntimeError("Could not find project root (no api directory found)")
def main(custom_models: List[str] = None):
# Always use top-level models directory relative to project root
project_root = find_project_root()
models_dir = project_root / 'api' / 'src' / 'models'
models_dir.mkdir(exist_ok=True, parents=True)
# Default ONNX model if no arguments provided
default_models = [
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.onnx",
# "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19_fp16.onnx"
]
# Use provided models or default
models_to_download = custom_models if custom_models else default_models
for model_url in models_to_download:
try:
download_file(model_url, models_dir)
except Exception as e:
print(f"Error downloading {model_url}: {e}")
if __name__ == "__main__":
main(sys.argv[1:] if len(sys.argv) > 1 else None)


@@ -1,32 +0,0 @@
#!/bin/bash
# Ensure models directory exists
mkdir -p api/src/models
# Function to download a file
download_file() {
local url="$1"
local filename=$(basename "$url")
echo "Downloading $filename..."
curl -L "$url" -o "api/src/models/$filename"
}
# Default ONNX model if no arguments provided
DEFAULT_MODELS=(
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.onnx"
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19_fp16.onnx"
)
# Use provided models or default
if [ $# -gt 0 ]; then
MODELS=("$@")
else
MODELS=("${DEFAULT_MODELS[@]}")
fi
# Download all models
for model in "${MODELS[@]}"; do
download_file "$model"
done
echo "ONNX model download complete!"


@@ -38,6 +38,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
# Copy project files including models
COPY --chown=appuser:appuser api ./api
COPY --chown=appuser:appuser web ./web
COPY --chown=appuser:appuser docker/scripts/download_model.* ./
# Install project with GPU extras
RUN --mount=type=cache,target=/root/.cache/uv \
@@ -48,7 +49,19 @@ ENV PYTHONUNBUFFERED=1
ENV PYTHONPATH=/app
ENV PATH="/app/.venv/bin:$PATH"
ENV UV_LINK_MODE=copy
ENV USE_GPU=true
ENV USE_ONNX=false
ENV DOWNLOAD_PTH=true
ENV DOWNLOAD_ONNX=false
# Download models based on environment variables
RUN if [ "$DOWNLOAD_PTH" = "true" ]; then \
python download_model.py --type pth; \
fi && \
if [ "$DOWNLOAD_ONNX" = "true" ]; then \
python download_model.py --type onnx; \
fi
# Run FastAPI server
CMD ["uv", "run", "python", "-m", "uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]


@@ -1,57 +0,0 @@
#!/usr/bin/env python3
import os
import sys
import requests
from pathlib import Path
from typing import List
def download_file(url: str, output_dir: Path) -> None:
"""Download a file from URL to the specified directory."""
filename = os.path.basename(url)
if not filename.endswith('.pth'):
print(f"Warning: {filename} is not a .pth file")
return
output_path = output_dir / filename
print(f"Downloading {filename}...")
response = requests.get(url, stream=True)
response.raise_for_status()
with open(output_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
def find_project_root() -> Path:
"""Find project root by looking for api directory."""
max_steps = 5
current = Path(__file__).resolve()
for _ in range(max_steps):
if (current / 'api').is_dir():
return current
current = current.parent
raise RuntimeError("Could not find project root (no api directory found)")
def main(custom_models: List[str] = None):
# Find project root and ensure models directory exists
project_root = find_project_root()
models_dir = project_root / 'api' / 'src' / 'models'
print(f"Downloading models to {models_dir}")
models_dir.mkdir(exist_ok=True)
# Default PTH model if no arguments provided
default_models = [
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth"
]
# Use provided models or default
models_to_download = custom_models if custom_models else default_models
for model_url in models_to_download:
try:
download_file(model_url, models_dir)
except Exception as e:
print(f"Error downloading {model_url}: {e}")
if __name__ == "__main__":
main(sys.argv[1:] if len(sys.argv) > 1 else None)


@@ -1,31 +0,0 @@
#!/bin/bash
# Ensure models directory exists
mkdir -p api/src/models
# Function to download a file
download_file() {
local url="$1"
local filename=$(basename "$url")
echo "Downloading $filename..."
curl -L "$url" -o "api/src/models/$filename"
}
# Default PTH model if no arguments provided
DEFAULT_MODELS=(
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth"
)
# Use provided models or default
if [ $# -gt 0 ]; then
MODELS=("$@")
else
MODELS=("${DEFAULT_MODELS[@]}")
fi
# Download all models
for model in "${MODELS[@]}"; do
download_file "$model"
done
echo "PyTorch model download complete!"


@@ -0,0 +1,97 @@
#!/usr/bin/env python3
import os
import sys
import argparse
import requests
from pathlib import Path
from typing import List
def download_file(url: str, output_dir: Path, model_type: str) -> bool:
"""Download a file from URL to the specified directory.
Returns:
bool: True if download succeeded, False otherwise
"""
filename = os.path.basename(url)
if not filename.endswith(f'.{model_type}'):
print(f"Warning: {filename} is not a .{model_type} file", file=sys.stderr)
return False
output_path = output_dir / filename
print(f"Downloading {filename}...")
try:
response = requests.get(url, stream=True)
response.raise_for_status()
with open(output_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
print(f"Successfully downloaded {filename}")
return True
except Exception as e:
print(f"Error downloading {filename}: {e}", file=sys.stderr)
return False
def find_project_root() -> Path:
"""Find project root by looking for api directory."""
max_steps = 5
current = Path(__file__).resolve()
for _ in range(max_steps):
if (current / 'api').is_dir():
return current
current = current.parent
raise RuntimeError("Could not find project root (no api directory found)")
def main() -> int:
"""Download models to the project.
Returns:
int: Exit code (0 for success, 1 for failure)
"""
parser = argparse.ArgumentParser(description='Download model files')
parser.add_argument('--type', choices=['pth', 'onnx'], required=True,
help='Model type to download (pth or onnx)')
parser.add_argument('urls', nargs='*', help='Optional model URLs to download')
args = parser.parse_args()
try:
# Find project root and ensure models directory exists
project_root = find_project_root()
models_dir = project_root / 'api' / 'src' / 'models'
print(f"Downloading models to {models_dir}")
models_dir.mkdir(exist_ok=True)
# Default models if no arguments provided
default_models = {
'pth': [
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth"
],
'onnx': [
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.onnx",
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19_fp16.onnx"
]
}
# Use provided models or default
models_to_download = args.urls if args.urls else default_models[args.type]
# Download all models
success = True
for model_url in models_to_download:
if not download_file(model_url, models_dir, args.type):
success = False
if success:
print(f"{args.type.upper()} model download complete!")
return 0
else:
print("Some downloads failed", file=sys.stderr)
return 1
except Exception as e:
print(f"Error: {e}", file=sys.stderr)
return 1
if __name__ == "__main__":
sys.exit(main())
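
For reference, the kinds of invocations the argparse setup above accepts, matching the commands shown in the README hunk and the Dockerfiles (the custom URL is a placeholder):

```
python docker/scripts/download_model.py --type onnx
python docker/scripts/download_model.py --type pth https://example.com/kokoro-custom.pth
```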


@@ -0,0 +1,109 @@
#!/bin/bash
# Find project root by looking for api directory
find_project_root() {
local current_dir="$PWD"
local max_steps=5
local steps=0
while [ $steps -lt $max_steps ]; do
if [ -d "$current_dir/api" ]; then
echo "$current_dir"
return 0
fi
current_dir="$(dirname "$current_dir")"
((steps++))
done
echo "Error: Could not find project root (no api directory found)" >&2
exit 1
}
# Function to download a file
download_file() {
local url="$1"
local output_dir="$2"
local model_type="$3"
local filename=$(basename "$url")
# Validate file extension
if [[ ! "$filename" =~ \.$model_type$ ]]; then
echo "Warning: $filename is not a .$model_type file" >&2
return 1
fi
echo "Downloading $filename..."
if curl -L "$url" -o "$output_dir/$filename"; then
echo "Successfully downloaded $filename"
return 0
else
echo "Error downloading $filename" >&2
return 1
fi
}
# Parse arguments
MODEL_TYPE=""
while [[ $# -gt 0 ]]; do
case $1 in
--type)
MODEL_TYPE="$2"
shift 2
;;
*)
# If no flag specified, treat remaining args as model URLs
break
;;
esac
done
# Validate model type
if [ "$MODEL_TYPE" != "pth" ] && [ "$MODEL_TYPE" != "onnx" ]; then
echo "Error: Must specify model type with --type (pth or onnx)" >&2
exit 1
fi
# Find project root and ensure models directory exists
PROJECT_ROOT=$(find_project_root)
if [ $? -ne 0 ]; then
exit 1
fi
MODELS_DIR="$PROJECT_ROOT/api/src/models"
echo "Downloading models to $MODELS_DIR"
mkdir -p "$MODELS_DIR"
# Default models if no arguments provided
if [ "$MODEL_TYPE" = "pth" ]; then
DEFAULT_MODELS=(
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth"
)
else
DEFAULT_MODELS=(
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.onnx"
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19_fp16.onnx"
)
fi
# Use provided models or default
if [ $# -gt 0 ]; then
MODELS=("$@")
else
MODELS=("${DEFAULT_MODELS[@]}")
fi
# Download all models
success=true
for model in "${MODELS[@]}"; do
if ! download_file "$model" "$MODELS_DIR" "$MODEL_TYPE"; then
success=false
fi
done
if [ "$success" = true ]; then
echo "${MODEL_TYPE^^} model download complete!"
exit 0
else
echo "Some downloads failed" >&2
exit 1
fi