mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
WIP: v1_0_0 migration
This commit is contained in:
parent
1345b6c81a
commit
9867fc398f
16 changed files with 422 additions and 269 deletions
|
@ -53,10 +53,12 @@ The service can be accessed through either the API endpoints or the Gradio web i
|
||||||
git clone https://github.com/remsky/Kokoro-FastAPI.git
|
git clone https://github.com/remsky/Kokoro-FastAPI.git
|
||||||
cd Kokoro-FastAPI
|
cd Kokoro-FastAPI
|
||||||
|
|
||||||
cd docker/gpu # OR
|
cd docker/gpu # OR
|
||||||
# cd docker/cpu # Run this or the above
|
# cd docker/cpu # Run this or the above
|
||||||
docker compose up --build
|
docker compose up --build
|
||||||
# if you are missing any models, run the .py or .sh scrips in the respective folders
|
# if you are missing any models, run:
|
||||||
|
# python ../scripts/download_model.py --type pth # for GPU
|
||||||
|
# python ../scripts/download_model.py --type onnx # for CPU
|
||||||
```
|
```
|
||||||
|
|
||||||
Once started:
|
Once started:
|
||||||
|
|
|
@ -335,4 +335,5 @@ async def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
|
||||||
if _manager_instance is None:
|
if _manager_instance is None:
|
||||||
_manager_instance = ModelManager(config)
|
_manager_instance = ModelManager(config)
|
||||||
await _manager_instance.initialize()
|
await _manager_instance.initialize()
|
||||||
return _manager_instance
|
return _manager_instance
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@ FastAPI OpenAI Compatible API
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
@ -57,16 +58,30 @@ async def lifespan(app: FastAPI):
|
||||||
|
|
||||||
# Initialize model with warmup and get status
|
# Initialize model with warmup and get status
|
||||||
device, model, voicepack_count = await model_manager.initialize_with_warmup(voice_manager)
|
device, model, voicepack_count = await model_manager.initialize_with_warmup(voice_manager)
|
||||||
|
except FileNotFoundError:
|
||||||
|
logger.error("""
|
||||||
|
Model files not found! You need to either:
|
||||||
|
|
||||||
|
1. Download models using the scripts:
|
||||||
|
GPU: python docker/scripts/download_model.py --type pth
|
||||||
|
CPU: python docker/scripts/download_model.py --type onnx
|
||||||
|
|
||||||
|
2. Set environment variables in docker-compose:
|
||||||
|
GPU: DOWNLOAD_PTH=true
|
||||||
|
CPU: DOWNLOAD_ONNX=true
|
||||||
|
""")
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to initialize model: {e}")
|
logger.error(f"Failed to initialize model: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
boundary = "░" * 2*12
|
boundary = "░" * 2*12
|
||||||
startup_msg = f"""
|
startup_msg = f"""
|
||||||
|
|
||||||
{boundary}
|
{boundary}
|
||||||
|
|
||||||
╔═╗┌─┐┌─┐┌┬┐
|
╔═╗┌─┐┌─┐┌┬┐
|
||||||
╠╣ ├─┤└─┐ │
|
╠╣ ├─┤└─┐ │
|
||||||
╚ ┴ ┴└─┘ ┴
|
╚ ┴ ┴└─┘ ┴
|
||||||
╦╔═┌─┐┬┌─┌─┐
|
╦╔═┌─┐┬┌─┌─┐
|
||||||
╠╩╗│ │├┴┐│ │
|
╠╩╗│ │├┴┐│ │
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
"""OpenAI-compatible router for text-to-speech"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import AsyncGenerator, Dict, List, Union
|
from typing import AsyncGenerator, Dict, List, Union
|
||||||
|
@ -217,9 +219,9 @@ async def create_speech(
|
||||||
stitch_long_output=True
|
stitch_long_output=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# Convert to requested format
|
# Convert to requested format - removed stream parameter
|
||||||
content = await AudioService.convert_audio(
|
content = await AudioService.convert_audio(
|
||||||
audio, 24000, request.response_format, is_first_chunk=True, stream=False
|
audio, 24000, request.response_format, is_first_chunk=True
|
||||||
)
|
)
|
||||||
|
|
||||||
return Response(
|
return Response(
|
||||||
|
|
|
@ -16,7 +16,6 @@ class AudioNormalizer:
|
||||||
"""Handles audio normalization state for a single stream"""
|
"""Handles audio normalization state for a single stream"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.int16_max = np.iinfo(np.int16).max
|
|
||||||
self.chunk_trim_ms = settings.gap_trim_ms
|
self.chunk_trim_ms = settings.gap_trim_ms
|
||||||
self.sample_rate = 24000 # Sample rate of the audio
|
self.sample_rate = 24000 # Sample rate of the audio
|
||||||
self.samples_to_trim = int(self.chunk_trim_ms * self.sample_rate / 1000)
|
self.samples_to_trim = int(self.chunk_trim_ms * self.sample_rate / 1000)
|
||||||
|
@ -30,20 +29,23 @@ class AudioNormalizer:
|
||||||
Returns:
|
Returns:
|
||||||
Normalized and trimmed audio data
|
Normalized and trimmed audio data
|
||||||
"""
|
"""
|
||||||
# Convert to float32 for processing
|
if len(audio_data) == 0:
|
||||||
audio_float = audio_data.astype(np.float32)
|
raise ValueError("Empty audio data")
|
||||||
|
|
||||||
# Trim start and end if enough samples
|
# Trim start and end if enough samples
|
||||||
if len(audio_float) > (2 * self.samples_to_trim):
|
if len(audio_data) > (2 * self.samples_to_trim):
|
||||||
audio_float = audio_float[self.samples_to_trim:-self.samples_to_trim]
|
audio_data = audio_data[self.samples_to_trim:-self.samples_to_trim]
|
||||||
|
|
||||||
# Scale to int16 range
|
# Scale directly to int16 range with clipping
|
||||||
return (audio_float * 32767).astype(np.int16)
|
return np.clip(audio_data * 32767, -32768, 32767).astype(np.int16)
|
||||||
|
|
||||||
|
|
||||||
class AudioService:
|
class AudioService:
|
||||||
"""Service for audio format conversions with streaming support"""
|
"""Service for audio format conversions with streaming support"""
|
||||||
|
|
||||||
|
# Supported formats
|
||||||
|
SUPPORTED_FORMATS = {"wav", "mp3", "opus", "flac", "aac", "pcm", "ogg"}
|
||||||
|
|
||||||
# Default audio format settings balanced for speed and compression
|
# Default audio format settings balanced for speed and compression
|
||||||
DEFAULT_SETTINGS = {
|
DEFAULT_SETTINGS = {
|
||||||
"mp3": {
|
"mp3": {
|
||||||
|
@ -86,6 +88,10 @@ class AudioService:
|
||||||
Bytes of the converted audio chunk
|
Bytes of the converted audio chunk
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# Validate format
|
||||||
|
if output_format not in AudioService.SUPPORTED_FORMATS:
|
||||||
|
raise ValueError(f"Format {output_format} not supported")
|
||||||
|
|
||||||
# Always normalize audio to ensure proper amplitude scaling
|
# Always normalize audio to ensure proper amplitude scaling
|
||||||
if normalizer is None:
|
if normalizer is None:
|
||||||
normalizer = AudioNormalizer()
|
normalizer = AudioNormalizer()
|
||||||
|
|
|
@ -17,26 +17,41 @@ class StreamingAudioWriter:
|
||||||
self.sample_rate = sample_rate
|
self.sample_rate = sample_rate
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.bytes_written = 0
|
self.bytes_written = 0
|
||||||
|
self.buffer = BytesIO()
|
||||||
|
|
||||||
# Format-specific setup
|
# Format-specific setup
|
||||||
if self.format == "wav":
|
if self.format == "wav":
|
||||||
self._write_wav_header()
|
self._write_wav_header()
|
||||||
elif self.format == "ogg":
|
elif self.format in ["ogg", "opus"]:
|
||||||
|
# For OGG/Opus, write to memory buffer
|
||||||
self.writer = sf.SoundFile(
|
self.writer = sf.SoundFile(
|
||||||
file=BytesIO(),
|
file=self.buffer,
|
||||||
mode='w',
|
mode='w',
|
||||||
samplerate=sample_rate,
|
samplerate=sample_rate,
|
||||||
channels=channels,
|
channels=channels,
|
||||||
format='OGG',
|
format='OGG',
|
||||||
subtype='VORBIS'
|
subtype='VORBIS' if self.format == "ogg" else "OPUS"
|
||||||
)
|
)
|
||||||
elif self.format == "mp3":
|
elif self.format == "flac":
|
||||||
# For MP3, we'll use pydub's incremental writer
|
# For FLAC, write to memory buffer
|
||||||
self.buffer = BytesIO()
|
self.writer = sf.SoundFile(
|
||||||
|
file=self.buffer,
|
||||||
|
mode='w',
|
||||||
|
samplerate=sample_rate,
|
||||||
|
channels=channels,
|
||||||
|
format='FLAC'
|
||||||
|
)
|
||||||
|
elif self.format in ["mp3", "aac"]:
|
||||||
|
# For MP3/AAC, we'll use pydub's incremental writer
|
||||||
self.segments = [] # Store segments until we have enough data
|
self.segments = [] # Store segments until we have enough data
|
||||||
self.total_duration = 0 # Track total duration in milliseconds
|
self.total_duration = 0 # Track total duration in milliseconds
|
||||||
# Initialize an empty AudioSegment as our encoder
|
# Initialize an empty AudioSegment as our encoder
|
||||||
self.encoder = AudioSegment.silent(duration=0, frame_rate=self.sample_rate)
|
self.encoder = AudioSegment.silent(duration=0, frame_rate=self.sample_rate)
|
||||||
|
elif self.format == "pcm":
|
||||||
|
# PCM doesn't need initialization, we'll write raw bytes
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported format: {format}")
|
||||||
|
|
||||||
def _write_wav_header(self) -> bytes:
|
def _write_wav_header(self) -> bytes:
|
||||||
"""Write WAV header with correct streaming format"""
|
"""Write WAV header with correct streaming format"""
|
||||||
|
@ -63,42 +78,48 @@ class StreamingAudioWriter:
|
||||||
audio_data: Audio data to write, or None if finalizing
|
audio_data: Audio data to write, or None if finalizing
|
||||||
finalize: Whether this is the final write to close the stream
|
finalize: Whether this is the final write to close the stream
|
||||||
"""
|
"""
|
||||||
buffer = BytesIO()
|
output_buffer = BytesIO()
|
||||||
|
|
||||||
if finalize:
|
if finalize:
|
||||||
if self.format == "wav":
|
if self.format == "wav":
|
||||||
# Write final WAV header with correct sizes
|
# Write final WAV header with correct sizes
|
||||||
buffer.write(b'RIFF')
|
output_buffer.write(b'RIFF')
|
||||||
buffer.write(struct.pack('<L', self.bytes_written + 36))
|
output_buffer.write(struct.pack('<L', self.bytes_written + 36))
|
||||||
buffer.write(b'WAVE')
|
output_buffer.write(b'WAVE')
|
||||||
buffer.write(b'fmt ')
|
output_buffer.write(b'fmt ')
|
||||||
buffer.write(struct.pack('<L', 16))
|
output_buffer.write(struct.pack('<L', 16))
|
||||||
buffer.write(struct.pack('<H', 1))
|
output_buffer.write(struct.pack('<H', 1))
|
||||||
buffer.write(struct.pack('<H', self.channels))
|
output_buffer.write(struct.pack('<H', self.channels))
|
||||||
buffer.write(struct.pack('<L', self.sample_rate))
|
output_buffer.write(struct.pack('<L', self.sample_rate))
|
||||||
buffer.write(struct.pack('<L', self.sample_rate * self.channels * 2))
|
output_buffer.write(struct.pack('<L', self.sample_rate * self.channels * 2))
|
||||||
buffer.write(struct.pack('<H', self.channels * 2))
|
output_buffer.write(struct.pack('<H', self.channels * 2))
|
||||||
buffer.write(struct.pack('<H', 16))
|
output_buffer.write(struct.pack('<H', 16))
|
||||||
buffer.write(b'data')
|
output_buffer.write(b'data')
|
||||||
buffer.write(struct.pack('<L', self.bytes_written))
|
output_buffer.write(struct.pack('<L', self.bytes_written))
|
||||||
elif self.format == "ogg":
|
elif self.format in ["ogg", "opus", "flac"]:
|
||||||
self.writer.close()
|
self.writer.close()
|
||||||
elif self.format == "mp3":
|
return self.buffer.getvalue()
|
||||||
|
elif self.format in ["mp3", "aac"]:
|
||||||
# Final export of any remaining audio
|
# Final export of any remaining audio
|
||||||
if hasattr(self, 'encoder') and len(self.encoder) > 0:
|
if hasattr(self, 'encoder') and len(self.encoder) > 0:
|
||||||
# Export with duration metadata
|
# Export with duration metadata
|
||||||
|
format_args = {
|
||||||
|
"mp3": {"format": "mp3", "codec": "libmp3lame"},
|
||||||
|
"aac": {"format": "adts", "codec": "aac"}
|
||||||
|
}[self.format]
|
||||||
|
|
||||||
self.encoder.export(
|
self.encoder.export(
|
||||||
buffer,
|
output_buffer,
|
||||||
format="mp3",
|
**format_args,
|
||||||
bitrate="192k",
|
bitrate="192k",
|
||||||
parameters=[
|
parameters=[
|
||||||
"-q:a", "2",
|
"-q:a", "2",
|
||||||
"-write_xing", "1", # Force XING/LAME header
|
"-write_xing", "1" if self.format == "mp3" else "0", # XING header for MP3 only
|
||||||
"-metadata", f"duration={self.total_duration/1000}" # Duration in seconds
|
"-metadata", f"duration={self.total_duration/1000}" # Duration in seconds
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
self.encoder = None
|
self.encoder = None
|
||||||
return buffer.getvalue()
|
return output_buffer.getvalue()
|
||||||
|
|
||||||
if audio_data is None or len(audio_data) == 0:
|
if audio_data is None or len(audio_data) == 0:
|
||||||
return b''
|
return b''
|
||||||
|
@ -106,22 +127,22 @@ class StreamingAudioWriter:
|
||||||
if self.format == "wav":
|
if self.format == "wav":
|
||||||
# For WAV, write raw PCM after the first chunk
|
# For WAV, write raw PCM after the first chunk
|
||||||
if self.bytes_written == 0:
|
if self.bytes_written == 0:
|
||||||
buffer.write(self._write_wav_header())
|
output_buffer.write(self._write_wav_header())
|
||||||
buffer.write(audio_data.tobytes())
|
output_buffer.write(audio_data.tobytes())
|
||||||
self.bytes_written += len(audio_data.tobytes())
|
self.bytes_written += len(audio_data.tobytes())
|
||||||
|
|
||||||
elif self.format == "ogg":
|
elif self.format in ["ogg", "opus", "flac"]:
|
||||||
# OGG/Vorbis handles streaming naturally
|
# Write to soundfile buffer
|
||||||
self.writer.write(audio_data)
|
self.writer.write(audio_data)
|
||||||
self.writer.flush()
|
self.writer.flush()
|
||||||
buffer = self.writer.file
|
# Get current buffer contents
|
||||||
buffer.seek(0, 2) # Seek to end
|
data = self.buffer.getvalue()
|
||||||
chunk = buffer.getvalue()
|
# Clear buffer for next chunk
|
||||||
buffer.seek(0)
|
self.buffer.seek(0)
|
||||||
buffer.truncate()
|
self.buffer.truncate()
|
||||||
return chunk
|
return data
|
||||||
|
|
||||||
elif self.format == "mp3":
|
elif self.format in ["mp3", "aac"]:
|
||||||
# Convert chunk to AudioSegment and encode
|
# Convert chunk to AudioSegment and encode
|
||||||
segment = AudioSegment(
|
segment = AudioSegment(
|
||||||
audio_data.tobytes(),
|
audio_data.tobytes(),
|
||||||
|
@ -137,21 +158,30 @@ class StreamingAudioWriter:
|
||||||
self.encoder = self.encoder + segment
|
self.encoder = self.encoder + segment
|
||||||
|
|
||||||
# Export current state to buffer
|
# Export current state to buffer
|
||||||
self.encoder.export(buffer, format="mp3", bitrate="192k", parameters=[
|
format_args = {
|
||||||
|
"mp3": {"format": "mp3", "codec": "libmp3lame"},
|
||||||
|
"aac": {"format": "adts", "codec": "aac"}
|
||||||
|
}[self.format]
|
||||||
|
|
||||||
|
self.encoder.export(output_buffer, **format_args, bitrate="192k", parameters=[
|
||||||
"-q:a", "2",
|
"-q:a", "2",
|
||||||
"-write_xing", "1", # Force XING/LAME header
|
"-write_xing", "1" if self.format == "mp3" else "0", # XING header for MP3 only
|
||||||
"-metadata", f"duration={self.total_duration/1000}" # Duration in seconds
|
"-metadata", f"duration={self.total_duration/1000}" # Duration in seconds
|
||||||
])
|
])
|
||||||
|
|
||||||
# Get the encoded data
|
# Get the encoded data
|
||||||
encoded_data = buffer.getvalue()
|
encoded_data = output_buffer.getvalue()
|
||||||
|
|
||||||
# Reset encoder to prevent memory growth
|
# Reset encoder to prevent memory growth
|
||||||
self.encoder = AudioSegment.silent(duration=0, frame_rate=self.sample_rate)
|
self.encoder = AudioSegment.silent(duration=0, frame_rate=self.sample_rate)
|
||||||
|
|
||||||
return encoded_data
|
return encoded_data
|
||||||
|
|
||||||
return buffer.getvalue()
|
elif self.format == "pcm":
|
||||||
|
# For PCM, just write raw bytes
|
||||||
|
return audio_data.tobytes()
|
||||||
|
|
||||||
|
return output_buffer.getvalue()
|
||||||
|
|
||||||
def close(self) -> Optional[bytes]:
|
def close(self) -> Optional[bytes]:
|
||||||
"""Finish the audio file and return any remaining data"""
|
"""Finish the audio file and return any remaining data"""
|
||||||
|
@ -161,16 +191,31 @@ class StreamingAudioWriter:
|
||||||
buffer.write(b'RIFF')
|
buffer.write(b'RIFF')
|
||||||
buffer.write(struct.pack('<L', self.bytes_written + 36)) # File size
|
buffer.write(struct.pack('<L', self.bytes_written + 36)) # File size
|
||||||
buffer.write(b'WAVE')
|
buffer.write(b'WAVE')
|
||||||
# ... rest of header ...
|
buffer.write(b'fmt ')
|
||||||
buffer.write(struct.pack('<L', self.bytes_written)) # Data size
|
buffer.write(struct.pack('<L', 16))
|
||||||
|
buffer.write(struct.pack('<H', 1))
|
||||||
|
buffer.write(struct.pack('<H', self.channels))
|
||||||
|
buffer.write(struct.pack('<L', self.sample_rate))
|
||||||
|
buffer.write(struct.pack('<L', self.sample_rate * self.channels * 2))
|
||||||
|
buffer.write(struct.pack('<H', self.channels * 2))
|
||||||
|
buffer.write(struct.pack('<H', 16))
|
||||||
|
buffer.write(b'data')
|
||||||
|
buffer.write(struct.pack('<L', self.bytes_written))
|
||||||
return buffer.getvalue()
|
return buffer.getvalue()
|
||||||
|
|
||||||
elif self.format == "ogg":
|
elif self.format in ["ogg", "opus", "flac"]:
|
||||||
self.writer.close()
|
self.writer.close()
|
||||||
return None
|
return self.buffer.getvalue()
|
||||||
|
|
||||||
elif self.format == "mp3":
|
elif self.format in ["mp3", "aac"]:
|
||||||
# Flush any remaining MP3 frames
|
# Flush any remaining audio
|
||||||
buffer = BytesIO()
|
buffer = BytesIO()
|
||||||
self.encoder.export(buffer, format="mp3")
|
if hasattr(self, 'encoder') and len(self.encoder) > 0:
|
||||||
return buffer.getvalue()
|
format_args = {
|
||||||
|
"mp3": {"format": "mp3", "codec": "libmp3lame"},
|
||||||
|
"aac": {"format": "adts", "codec": "aac"}
|
||||||
|
}[self.format]
|
||||||
|
self.encoder.export(buffer, **format_args)
|
||||||
|
return buffer.getvalue()
|
||||||
|
|
||||||
|
return None
|
|
@ -26,106 +26,128 @@ def sample_audio():
|
||||||
return np.sin(2 * np.pi * frequency * t).astype(np.float32), sample_rate
|
return np.sin(2 * np.pi * frequency * t).astype(np.float32), sample_rate
|
||||||
|
|
||||||
|
|
||||||
def test_convert_to_wav(sample_audio):
|
@pytest.mark.asyncio
|
||||||
|
async def test_convert_to_wav(sample_audio):
|
||||||
"""Test converting to WAV format"""
|
"""Test converting to WAV format"""
|
||||||
audio_data, sample_rate = sample_audio
|
audio_data, sample_rate = sample_audio
|
||||||
result = AudioService.convert_audio(audio_data, sample_rate, "wav")
|
result = await AudioService.convert_audio(audio_data, sample_rate, "wav")
|
||||||
assert isinstance(result, bytes)
|
assert isinstance(result, bytes)
|
||||||
assert len(result) > 0
|
assert len(result) > 0
|
||||||
|
# Check WAV header
|
||||||
|
assert result.startswith(b'RIFF')
|
||||||
|
assert b'WAVE' in result[:12]
|
||||||
|
|
||||||
|
|
||||||
def test_convert_to_mp3(sample_audio):
|
@pytest.mark.asyncio
|
||||||
|
async def test_convert_to_mp3(sample_audio):
|
||||||
"""Test converting to MP3 format"""
|
"""Test converting to MP3 format"""
|
||||||
audio_data, sample_rate = sample_audio
|
audio_data, sample_rate = sample_audio
|
||||||
result = AudioService.convert_audio(audio_data, sample_rate, "mp3")
|
result = await AudioService.convert_audio(audio_data, sample_rate, "mp3")
|
||||||
assert isinstance(result, bytes)
|
assert isinstance(result, bytes)
|
||||||
assert len(result) > 0
|
assert len(result) > 0
|
||||||
|
# Check MP3 header (ID3 or MPEG frame sync)
|
||||||
|
assert result.startswith(b'ID3') or result.startswith(b'\xff\xfb')
|
||||||
|
|
||||||
|
|
||||||
def test_convert_to_opus(sample_audio):
|
@pytest.mark.asyncio
|
||||||
|
async def test_convert_to_opus(sample_audio):
|
||||||
"""Test converting to Opus format"""
|
"""Test converting to Opus format"""
|
||||||
audio_data, sample_rate = sample_audio
|
audio_data, sample_rate = sample_audio
|
||||||
result = AudioService.convert_audio(audio_data, sample_rate, "opus")
|
result = await AudioService.convert_audio(audio_data, sample_rate, "opus")
|
||||||
assert isinstance(result, bytes)
|
assert isinstance(result, bytes)
|
||||||
assert len(result) > 0
|
assert len(result) > 0
|
||||||
|
# Check OGG header
|
||||||
|
assert result.startswith(b'OggS')
|
||||||
|
|
||||||
|
|
||||||
def test_convert_to_flac(sample_audio):
|
@pytest.mark.asyncio
|
||||||
|
async def test_convert_to_flac(sample_audio):
|
||||||
"""Test converting to FLAC format"""
|
"""Test converting to FLAC format"""
|
||||||
audio_data, sample_rate = sample_audio
|
audio_data, sample_rate = sample_audio
|
||||||
result = AudioService.convert_audio(audio_data, sample_rate, "flac")
|
result = await AudioService.convert_audio(audio_data, sample_rate, "flac")
|
||||||
assert isinstance(result, bytes)
|
assert isinstance(result, bytes)
|
||||||
assert len(result) > 0
|
assert len(result) > 0
|
||||||
|
# Check FLAC header
|
||||||
|
assert result.startswith(b'fLaC')
|
||||||
|
|
||||||
|
|
||||||
def test_convert_to_aac(sample_audio):
|
@pytest.mark.asyncio
|
||||||
|
async def test_convert_to_aac(sample_audio):
|
||||||
"""Test converting to AAC format"""
|
"""Test converting to AAC format"""
|
||||||
audio_data, sample_rate = sample_audio
|
audio_data, sample_rate = sample_audio
|
||||||
result = AudioService.convert_audio(audio_data, sample_rate, "aac")
|
result = await AudioService.convert_audio(audio_data, sample_rate, "aac")
|
||||||
assert isinstance(result, bytes)
|
assert isinstance(result, bytes)
|
||||||
assert len(result) > 0
|
assert len(result) > 0
|
||||||
# AAC files typically start with an ADTS header
|
# Check ADTS header (AAC)
|
||||||
assert result.startswith(b'\xff\xf1') or result.startswith(b'\xff\xf9')
|
assert result.startswith(b'\xff\xf0') or result.startswith(b'\xff\xf1')
|
||||||
|
|
||||||
|
|
||||||
def test_convert_to_pcm(sample_audio):
|
@pytest.mark.asyncio
|
||||||
|
async def test_convert_to_pcm(sample_audio):
|
||||||
"""Test converting to PCM format"""
|
"""Test converting to PCM format"""
|
||||||
audio_data, sample_rate = sample_audio
|
audio_data, sample_rate = sample_audio
|
||||||
result = AudioService.convert_audio(audio_data, sample_rate, "pcm")
|
result = await AudioService.convert_audio(audio_data, sample_rate, "pcm")
|
||||||
assert isinstance(result, bytes)
|
assert isinstance(result, bytes)
|
||||||
assert len(result) > 0
|
assert len(result) > 0
|
||||||
|
# PCM is raw bytes, so no header to check
|
||||||
|
|
||||||
|
|
||||||
def test_convert_to_invalid_format_raises_error(sample_audio):
|
@pytest.mark.asyncio
|
||||||
|
async def test_convert_to_invalid_format_raises_error(sample_audio):
|
||||||
"""Test that converting to an invalid format raises an error"""
|
"""Test that converting to an invalid format raises an error"""
|
||||||
audio_data, sample_rate = sample_audio
|
audio_data, sample_rate = sample_audio
|
||||||
with pytest.raises(ValueError, match="Format invalid not supported"):
|
with pytest.raises(ValueError, match="Format invalid not supported"):
|
||||||
AudioService.convert_audio(audio_data, sample_rate, "invalid")
|
await AudioService.convert_audio(audio_data, sample_rate, "invalid")
|
||||||
|
|
||||||
|
|
||||||
def test_normalization_wav(sample_audio):
|
@pytest.mark.asyncio
|
||||||
|
async def test_normalization_wav(sample_audio):
|
||||||
"""Test that WAV output is properly normalized to int16 range"""
|
"""Test that WAV output is properly normalized to int16 range"""
|
||||||
audio_data, sample_rate = sample_audio
|
audio_data, sample_rate = sample_audio
|
||||||
# Create audio data outside int16 range
|
# Create audio data outside int16 range
|
||||||
large_audio = audio_data * 1e5
|
large_audio = audio_data * 1e5
|
||||||
result = AudioService.convert_audio(large_audio, sample_rate, "wav")
|
result = await AudioService.convert_audio(large_audio, sample_rate, "wav")
|
||||||
assert isinstance(result, bytes)
|
assert isinstance(result, bytes)
|
||||||
assert len(result) > 0
|
assert len(result) > 0
|
||||||
|
|
||||||
|
|
||||||
def test_normalization_pcm(sample_audio):
|
@pytest.mark.asyncio
|
||||||
|
async def test_normalization_pcm(sample_audio):
|
||||||
"""Test that PCM output is properly normalized to int16 range"""
|
"""Test that PCM output is properly normalized to int16 range"""
|
||||||
audio_data, sample_rate = sample_audio
|
audio_data, sample_rate = sample_audio
|
||||||
# Create audio data outside int16 range
|
# Create audio data outside int16 range
|
||||||
large_audio = audio_data * 1e5
|
large_audio = audio_data * 1e5
|
||||||
result = AudioService.convert_audio(large_audio, sample_rate, "pcm")
|
result = await AudioService.convert_audio(large_audio, sample_rate, "pcm")
|
||||||
assert isinstance(result, bytes)
|
assert isinstance(result, bytes)
|
||||||
assert len(result) > 0
|
assert len(result) > 0
|
||||||
|
|
||||||
|
|
||||||
def test_invalid_audio_data():
|
@pytest.mark.asyncio
|
||||||
|
async def test_invalid_audio_data():
|
||||||
"""Test handling of invalid audio data"""
|
"""Test handling of invalid audio data"""
|
||||||
invalid_audio = np.array([]) # Empty array
|
invalid_audio = np.array([]) # Empty array
|
||||||
sample_rate = 24000
|
sample_rate = 24000
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
AudioService.convert_audio(invalid_audio, sample_rate, "wav")
|
await AudioService.convert_audio(invalid_audio, sample_rate, "wav")
|
||||||
|
|
||||||
|
|
||||||
def test_different_sample_rates(sample_audio):
|
@pytest.mark.asyncio
|
||||||
|
async def test_different_sample_rates(sample_audio):
|
||||||
"""Test converting audio with different sample rates"""
|
"""Test converting audio with different sample rates"""
|
||||||
audio_data, _ = sample_audio
|
audio_data, _ = sample_audio
|
||||||
sample_rates = [8000, 16000, 44100, 48000]
|
sample_rates = [8000, 16000, 44100, 48000]
|
||||||
|
|
||||||
for rate in sample_rates:
|
for rate in sample_rates:
|
||||||
result = AudioService.convert_audio(audio_data, rate, "wav")
|
result = await AudioService.convert_audio(audio_data, rate, "wav")
|
||||||
assert isinstance(result, bytes)
|
assert isinstance(result, bytes)
|
||||||
assert len(result) > 0
|
assert len(result) > 0
|
||||||
|
|
||||||
|
|
||||||
def test_buffer_position_after_conversion(sample_audio):
|
@pytest.mark.asyncio
|
||||||
|
async def test_buffer_position_after_conversion(sample_audio):
|
||||||
"""Test that buffer position is reset after writing"""
|
"""Test that buffer position is reset after writing"""
|
||||||
audio_data, sample_rate = sample_audio
|
audio_data, sample_rate = sample_audio
|
||||||
result = AudioService.convert_audio(audio_data, sample_rate, "wav")
|
result = await AudioService.convert_audio(audio_data, sample_rate, "wav")
|
||||||
# Convert again to ensure buffer was properly reset
|
# Convert again to ensure buffer was properly reset
|
||||||
result2 = AudioService.convert_audio(audio_data, sample_rate, "wav")
|
result2 = await AudioService.convert_audio(audio_data, sample_rate, "wav")
|
||||||
assert len(result) == len(result2)
|
assert len(result) == len(result2)
|
||||||
|
|
|
@ -59,7 +59,7 @@ async def test_empty_text(tts_service, test_voice):
|
||||||
voice=test_voice,
|
voice=test_voice,
|
||||||
speed=1.0
|
speed=1.0
|
||||||
)
|
)
|
||||||
assert "Text is empty after preprocessing" in str(exc_info.value)
|
assert "No audio chunks were generated successfully" in str(exc_info.value)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_invalid_voice(tts_service):
|
async def test_invalid_voice(tts_service):
|
||||||
|
@ -126,15 +126,17 @@ async def test_combine_voices(tts_service):
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_chunked_text_processing(tts_service, test_voice, mock_audio_output):
|
async def test_chunked_text_processing(tts_service, test_voice, mock_audio_output):
|
||||||
"""Test processing chunked text"""
|
"""Test processing chunked text"""
|
||||||
long_text = "First sentence. Second sentence. Third sentence."
|
# Create text that will force chunking by exceeding max tokens
|
||||||
|
long_text = "This is a test sentence." * 100 # Should be way over 500 tokens
|
||||||
|
|
||||||
|
# Don't mock smart_split - let it actually split the text
|
||||||
audio, processing_time = await tts_service.generate_audio(
|
audio, processing_time = await tts_service.generate_audio(
|
||||||
text=long_text,
|
text=long_text,
|
||||||
voice=test_voice,
|
voice=test_voice,
|
||||||
speed=1.0,
|
speed=1.0
|
||||||
stitch_long_output=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Should be called multiple times due to chunking
|
||||||
assert tts_service.model_manager.generate.call_count > 1
|
assert tts_service.model_manager.generate.call_count > 1
|
||||||
assert isinstance(audio, np.ndarray)
|
assert isinstance(audio, np.ndarray)
|
||||||
assert processing_time > 0
|
assert processing_time > 0
|
|
@ -34,6 +34,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||||
# Copy project files including models
|
# Copy project files including models
|
||||||
COPY --chown=appuser:appuser api ./api
|
COPY --chown=appuser:appuser api ./api
|
||||||
COPY --chown=appuser:appuser web ./web
|
COPY --chown=appuser:appuser web ./web
|
||||||
|
COPY --chown=appuser:appuser docker/scripts/download_model.* ./
|
||||||
|
|
||||||
# Install project
|
# Install project
|
||||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||||
|
@ -44,8 +45,19 @@ ENV PYTHONUNBUFFERED=1
|
||||||
ENV PYTHONPATH=/app
|
ENV PYTHONPATH=/app
|
||||||
ENV PATH="/app/.venv/bin:$PATH"
|
ENV PATH="/app/.venv/bin:$PATH"
|
||||||
ENV UV_LINK_MODE=copy
|
ENV UV_LINK_MODE=copy
|
||||||
|
|
||||||
ENV USE_GPU=false
|
ENV USE_GPU=false
|
||||||
ENV USE_ONNX=false
|
ENV USE_ONNX=true
|
||||||
|
ENV DOWNLOAD_ONNX=true
|
||||||
|
ENV DOWNLOAD_PTH=false
|
||||||
|
|
||||||
|
# Download models based on environment variables
|
||||||
|
RUN if [ "$DOWNLOAD_ONNX" = "true" ]; then \
|
||||||
|
python download_model.py --type onnx; \
|
||||||
|
fi && \
|
||||||
|
if [ "$DOWNLOAD_PTH" = "true" ]; then \
|
||||||
|
python download_model.py --type pth; \
|
||||||
|
fi
|
||||||
|
|
||||||
# Run FastAPI server
|
# Run FastAPI server
|
||||||
CMD ["uv", "run", "python", "-m", "uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]
|
CMD ["uv", "run", "python", "-m", "uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]
|
||||||
|
|
|
@ -1,53 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import requests
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
def download_file(url: str, output_dir: Path) -> None:
|
|
||||||
"""Download a file from URL to the specified directory."""
|
|
||||||
filename = os.path.basename(url)
|
|
||||||
output_path = output_dir / filename
|
|
||||||
|
|
||||||
print(f"Downloading {filename}...")
|
|
||||||
response = requests.get(url, stream=True)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
with open(output_path, 'wb') as f:
|
|
||||||
for chunk in response.iter_content(chunk_size=8192):
|
|
||||||
f.write(chunk)
|
|
||||||
|
|
||||||
def find_project_root() -> Path:
|
|
||||||
"""Find project root by looking for api directory."""
|
|
||||||
max_steps = 5
|
|
||||||
current = Path(__file__).resolve()
|
|
||||||
for _ in range(max_steps):
|
|
||||||
if (current / 'api').is_dir():
|
|
||||||
return current
|
|
||||||
current = current.parent
|
|
||||||
raise RuntimeError("Could not find project root (no api directory found)")
|
|
||||||
|
|
||||||
def main(custom_models: List[str] = None):
|
|
||||||
# Always use top-level models directory relative to project root
|
|
||||||
project_root = find_project_root()
|
|
||||||
models_dir = project_root / 'api' / 'src' / 'models'
|
|
||||||
models_dir.mkdir(exist_ok=True, parents=True)
|
|
||||||
|
|
||||||
# Default ONNX model if no arguments provided
|
|
||||||
default_models = [
|
|
||||||
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.onnx",
|
|
||||||
# "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19_fp16.onnx"
|
|
||||||
]
|
|
||||||
|
|
||||||
# Use provided models or default
|
|
||||||
models_to_download = custom_models if custom_models else default_models
|
|
||||||
|
|
||||||
for model_url in models_to_download:
|
|
||||||
try:
|
|
||||||
download_file(model_url, models_dir)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error downloading {model_url}: {e}")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main(sys.argv[1:] if len(sys.argv) > 1 else None)
|
|
|
@ -1,32 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
|
|
||||||
# Ensure models directory exists
|
|
||||||
mkdir -p api/src/models
|
|
||||||
|
|
||||||
# Function to download a file
|
|
||||||
download_file() {
|
|
||||||
local url="$1"
|
|
||||||
local filename=$(basename "$url")
|
|
||||||
echo "Downloading $filename..."
|
|
||||||
curl -L "$url" -o "api/src/models/$filename"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Default ONNX model if no arguments provided
|
|
||||||
DEFAULT_MODELS=(
|
|
||||||
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.onnx"
|
|
||||||
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19_fp16.onnx"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use provided models or default
|
|
||||||
if [ $# -gt 0 ]; then
|
|
||||||
MODELS=("$@")
|
|
||||||
else
|
|
||||||
MODELS=("${DEFAULT_MODELS[@]}")
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Download all models
|
|
||||||
for model in "${MODELS[@]}"; do
|
|
||||||
download_file "$model"
|
|
||||||
done
|
|
||||||
|
|
||||||
echo "ONNX model download complete!"
|
|
|
@ -38,6 +38,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||||
# Copy project files including models
|
# Copy project files including models
|
||||||
COPY --chown=appuser:appuser api ./api
|
COPY --chown=appuser:appuser api ./api
|
||||||
COPY --chown=appuser:appuser web ./web
|
COPY --chown=appuser:appuser web ./web
|
||||||
|
COPY --chown=appuser:appuser docker/scripts/download_model.* ./
|
||||||
|
|
||||||
# Install project with GPU extras
|
# Install project with GPU extras
|
||||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||||
|
@ -48,7 +49,19 @@ ENV PYTHONUNBUFFERED=1
|
||||||
ENV PYTHONPATH=/app
|
ENV PYTHONPATH=/app
|
||||||
ENV PATH="/app/.venv/bin:$PATH"
|
ENV PATH="/app/.venv/bin:$PATH"
|
||||||
ENV UV_LINK_MODE=copy
|
ENV UV_LINK_MODE=copy
|
||||||
|
|
||||||
ENV USE_GPU=true
|
ENV USE_GPU=true
|
||||||
|
ENV USE_ONNX=false
|
||||||
|
ENV DOWNLOAD_PTH=true
|
||||||
|
ENV DOWNLOAD_ONNX=false
|
||||||
|
|
||||||
|
# Download models based on environment variables
|
||||||
|
RUN if [ "$DOWNLOAD_PTH" = "true" ]; then \
|
||||||
|
python download_model.py --type pth; \
|
||||||
|
fi && \
|
||||||
|
if [ "$DOWNLOAD_ONNX" = "true" ]; then \
|
||||||
|
python download_model.py --type onnx; \
|
||||||
|
fi
|
||||||
|
|
||||||
# Run FastAPI server
|
# Run FastAPI server
|
||||||
CMD ["uv", "run", "python", "-m", "uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]
|
CMD ["uv", "run", "python", "-m", "uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]
|
||||||
|
|
|
@ -1,57 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import requests
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
def download_file(url: str, output_dir: Path) -> None:
|
|
||||||
"""Download a file from URL to the specified directory."""
|
|
||||||
filename = os.path.basename(url)
|
|
||||||
if not filename.endswith('.pth'):
|
|
||||||
print(f"Warning: {filename} is not a .pth file")
|
|
||||||
return
|
|
||||||
|
|
||||||
output_path = output_dir / filename
|
|
||||||
|
|
||||||
print(f"Downloading {filename}...")
|
|
||||||
response = requests.get(url, stream=True)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
with open(output_path, 'wb') as f:
|
|
||||||
for chunk in response.iter_content(chunk_size=8192):
|
|
||||||
f.write(chunk)
|
|
||||||
|
|
||||||
def find_project_root() -> Path:
|
|
||||||
"""Find project root by looking for api directory."""
|
|
||||||
max_steps = 5
|
|
||||||
current = Path(__file__).resolve()
|
|
||||||
for _ in range(max_steps):
|
|
||||||
if (current / 'api').is_dir():
|
|
||||||
return current
|
|
||||||
current = current.parent
|
|
||||||
raise RuntimeError("Could not find project root (no api directory found)")
|
|
||||||
|
|
||||||
def main(custom_models: List[str] = None):
|
|
||||||
# Find project root and ensure models directory exists
|
|
||||||
project_root = find_project_root()
|
|
||||||
models_dir = project_root / 'api' / 'src' / 'models'
|
|
||||||
print(f"Downloading models to {models_dir}")
|
|
||||||
models_dir.mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
# Default PTH model if no arguments provided
|
|
||||||
default_models = [
|
|
||||||
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth"
|
|
||||||
]
|
|
||||||
|
|
||||||
# Use provided models or default
|
|
||||||
models_to_download = custom_models if custom_models else default_models
|
|
||||||
|
|
||||||
for model_url in models_to_download:
|
|
||||||
try:
|
|
||||||
download_file(model_url, models_dir)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error downloading {model_url}: {e}")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main(sys.argv[1:] if len(sys.argv) > 1 else None)
|
|
|
@ -1,31 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
|
|
||||||
# Ensure models directory exists
|
|
||||||
mkdir -p api/src/models
|
|
||||||
|
|
||||||
# Function to download a file
|
|
||||||
download_file() {
|
|
||||||
local url="$1"
|
|
||||||
local filename=$(basename "$url")
|
|
||||||
echo "Downloading $filename..."
|
|
||||||
curl -L "$url" -o "api/src/models/$filename"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Default PTH model if no arguments provided
|
|
||||||
DEFAULT_MODELS=(
|
|
||||||
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use provided models or default
|
|
||||||
if [ $# -gt 0 ]; then
|
|
||||||
MODELS=("$@")
|
|
||||||
else
|
|
||||||
MODELS=("${DEFAULT_MODELS[@]}")
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Download all models
|
|
||||||
for model in "${MODELS[@]}"; do
|
|
||||||
download_file "$model"
|
|
||||||
done
|
|
||||||
|
|
||||||
echo "PyTorch model download complete!"
|
|
97
docker/scripts/download_model.py
Normal file
97
docker/scripts/download_model.py
Normal file
|
@ -0,0 +1,97 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import argparse
|
||||||
|
import requests
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
def download_file(url: str, output_dir: Path, model_type: str) -> bool:
|
||||||
|
"""Download a file from URL to the specified directory.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if download succeeded, False otherwise
|
||||||
|
"""
|
||||||
|
filename = os.path.basename(url)
|
||||||
|
if not filename.endswith(f'.{model_type}'):
|
||||||
|
print(f"Warning: {filename} is not a .{model_type} file", file=sys.stderr)
|
||||||
|
return False
|
||||||
|
|
||||||
|
output_path = output_dir / filename
|
||||||
|
|
||||||
|
print(f"Downloading {filename}...")
|
||||||
|
try:
|
||||||
|
response = requests.get(url, stream=True)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
with open(output_path, 'wb') as f:
|
||||||
|
for chunk in response.iter_content(chunk_size=8192):
|
||||||
|
f.write(chunk)
|
||||||
|
print(f"Successfully downloaded {filename}")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error downloading {filename}: {e}", file=sys.stderr)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def find_project_root() -> Path:
|
||||||
|
"""Find project root by looking for api directory."""
|
||||||
|
max_steps = 5
|
||||||
|
current = Path(__file__).resolve()
|
||||||
|
for _ in range(max_steps):
|
||||||
|
if (current / 'api').is_dir():
|
||||||
|
return current
|
||||||
|
current = current.parent
|
||||||
|
raise RuntimeError("Could not find project root (no api directory found)")
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
"""Download models to the project.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Exit code (0 for success, 1 for failure)
|
||||||
|
"""
|
||||||
|
parser = argparse.ArgumentParser(description='Download model files')
|
||||||
|
parser.add_argument('--type', choices=['pth', 'onnx'], required=True,
|
||||||
|
help='Model type to download (pth or onnx)')
|
||||||
|
parser.add_argument('urls', nargs='*', help='Optional model URLs to download')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Find project root and ensure models directory exists
|
||||||
|
project_root = find_project_root()
|
||||||
|
models_dir = project_root / 'api' / 'src' / 'models'
|
||||||
|
print(f"Downloading models to {models_dir}")
|
||||||
|
models_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
# Default models if no arguments provided
|
||||||
|
default_models = {
|
||||||
|
'pth': [
|
||||||
|
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth"
|
||||||
|
],
|
||||||
|
'onnx': [
|
||||||
|
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.onnx",
|
||||||
|
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19_fp16.onnx"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
# Use provided models or default
|
||||||
|
models_to_download = args.urls if args.urls else default_models[args.type]
|
||||||
|
|
||||||
|
# Download all models
|
||||||
|
success = True
|
||||||
|
for model_url in models_to_download:
|
||||||
|
if not download_file(model_url, models_dir, args.type):
|
||||||
|
success = False
|
||||||
|
|
||||||
|
if success:
|
||||||
|
print(f"{args.type.upper()} model download complete!")
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
print("Some downloads failed", file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e}", file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
109
docker/scripts/download_model.sh
Normal file
109
docker/scripts/download_model.sh
Normal file
|
@ -0,0 +1,109 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Find project root by looking for api directory
|
||||||
|
find_project_root() {
|
||||||
|
local current_dir="$PWD"
|
||||||
|
local max_steps=5
|
||||||
|
local steps=0
|
||||||
|
|
||||||
|
while [ $steps -lt $max_steps ]; do
|
||||||
|
if [ -d "$current_dir/api" ]; then
|
||||||
|
echo "$current_dir"
|
||||||
|
return 0
|
||||||
|
fi
|
||||||
|
current_dir="$(dirname "$current_dir")"
|
||||||
|
((steps++))
|
||||||
|
done
|
||||||
|
|
||||||
|
echo "Error: Could not find project root (no api directory found)" >&2
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
|
# Function to download a file
|
||||||
|
download_file() {
|
||||||
|
local url="$1"
|
||||||
|
local output_dir="$2"
|
||||||
|
local model_type="$3"
|
||||||
|
local filename=$(basename "$url")
|
||||||
|
|
||||||
|
# Validate file extension
|
||||||
|
if [[ ! "$filename" =~ \.$model_type$ ]]; then
|
||||||
|
echo "Warning: $filename is not a .$model_type file" >&2
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
echo "Downloading $filename..."
|
||||||
|
if curl -L "$url" -o "$output_dir/$filename"; then
|
||||||
|
echo "Successfully downloaded $filename"
|
||||||
|
return 0
|
||||||
|
else
|
||||||
|
echo "Error downloading $filename" >&2
|
||||||
|
return 1
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# Parse arguments
|
||||||
|
MODEL_TYPE=""
|
||||||
|
while [[ $# -gt 0 ]]; do
|
||||||
|
case $1 in
|
||||||
|
--type)
|
||||||
|
MODEL_TYPE="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
# If no flag specified, treat remaining args as model URLs
|
||||||
|
break
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
|
||||||
|
# Validate model type
|
||||||
|
if [ "$MODEL_TYPE" != "pth" ] && [ "$MODEL_TYPE" != "onnx" ]; then
|
||||||
|
echo "Error: Must specify model type with --type (pth or onnx)" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Find project root and ensure models directory exists
|
||||||
|
PROJECT_ROOT=$(find_project_root)
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
MODELS_DIR="$PROJECT_ROOT/api/src/models"
|
||||||
|
echo "Downloading models to $MODELS_DIR"
|
||||||
|
mkdir -p "$MODELS_DIR"
|
||||||
|
|
||||||
|
# Default models if no arguments provided
|
||||||
|
if [ "$MODEL_TYPE" = "pth" ]; then
|
||||||
|
DEFAULT_MODELS=(
|
||||||
|
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth"
|
||||||
|
)
|
||||||
|
else
|
||||||
|
DEFAULT_MODELS=(
|
||||||
|
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.onnx"
|
||||||
|
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19_fp16.onnx"
|
||||||
|
)
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Use provided models or default
|
||||||
|
if [ $# -gt 0 ]; then
|
||||||
|
MODELS=("$@")
|
||||||
|
else
|
||||||
|
MODELS=("${DEFAULT_MODELS[@]}")
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Download all models
|
||||||
|
success=true
|
||||||
|
for model in "${MODELS[@]}"; do
|
||||||
|
if ! download_file "$model" "$MODELS_DIR" "$MODEL_TYPE"; then
|
||||||
|
success=false
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
if [ "$success" = true ]; then
|
||||||
|
echo "${MODEL_TYPE^^} model download complete!"
|
||||||
|
exit 0
|
||||||
|
else
|
||||||
|
echo "Some downloads failed" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
Loading…
Add table
Reference in a new issue