mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
WIP: v1_0_0 migration
This commit is contained in:
parent
1345b6c81a
commit
9867fc398f
16 changed files with 422 additions and 269 deletions
|
@ -56,7 +56,9 @@ The service can be accessed through either the API endpoints or the Gradio web i
|
|||
cd docker/gpu # OR
|
||||
# cd docker/cpu # Run this or the above
|
||||
docker compose up --build
|
||||
# if you are missing any models, run the .py or .sh scrips in the respective folders
|
||||
# if you are missing any models, run:
|
||||
# python ../scripts/download_model.py --type pth # for GPU
|
||||
# python ../scripts/download_model.py --type onnx # for CPU
|
||||
```
|
||||
|
||||
Once started:
|
||||
|
|
|
@ -336,3 +336,4 @@ async def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
|
|||
_manager_instance = ModelManager(config)
|
||||
await _manager_instance.initialize()
|
||||
return _manager_instance
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ FastAPI OpenAI Compatible API
|
|||
import os
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import uvicorn
|
||||
|
@ -57,9 +58,23 @@ async def lifespan(app: FastAPI):
|
|||
|
||||
# Initialize model with warmup and get status
|
||||
device, model, voicepack_count = await model_manager.initialize_with_warmup(voice_manager)
|
||||
except FileNotFoundError:
|
||||
logger.error("""
|
||||
Model files not found! You need to either:
|
||||
|
||||
1. Download models using the scripts:
|
||||
GPU: python docker/scripts/download_model.py --type pth
|
||||
CPU: python docker/scripts/download_model.py --type onnx
|
||||
|
||||
2. Set environment variables in docker-compose:
|
||||
GPU: DOWNLOAD_PTH=true
|
||||
CPU: DOWNLOAD_ONNX=true
|
||||
""")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize model: {e}")
|
||||
raise
|
||||
|
||||
boundary = "░" * 2*12
|
||||
startup_msg = f"""
|
||||
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
"""OpenAI-compatible router for text-to-speech"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import AsyncGenerator, Dict, List, Union
|
||||
|
@ -217,9 +219,9 @@ async def create_speech(
|
|||
stitch_long_output=True
|
||||
)
|
||||
|
||||
# Convert to requested format
|
||||
# Convert to requested format - removed stream parameter
|
||||
content = await AudioService.convert_audio(
|
||||
audio, 24000, request.response_format, is_first_chunk=True, stream=False
|
||||
audio, 24000, request.response_format, is_first_chunk=True
|
||||
)
|
||||
|
||||
return Response(
|
||||
|
|
|
@ -16,7 +16,6 @@ class AudioNormalizer:
|
|||
"""Handles audio normalization state for a single stream"""
|
||||
|
||||
def __init__(self):
|
||||
self.int16_max = np.iinfo(np.int16).max
|
||||
self.chunk_trim_ms = settings.gap_trim_ms
|
||||
self.sample_rate = 24000 # Sample rate of the audio
|
||||
self.samples_to_trim = int(self.chunk_trim_ms * self.sample_rate / 1000)
|
||||
|
@ -30,20 +29,23 @@ class AudioNormalizer:
|
|||
Returns:
|
||||
Normalized and trimmed audio data
|
||||
"""
|
||||
# Convert to float32 for processing
|
||||
audio_float = audio_data.astype(np.float32)
|
||||
if len(audio_data) == 0:
|
||||
raise ValueError("Empty audio data")
|
||||
|
||||
# Trim start and end if enough samples
|
||||
if len(audio_float) > (2 * self.samples_to_trim):
|
||||
audio_float = audio_float[self.samples_to_trim:-self.samples_to_trim]
|
||||
if len(audio_data) > (2 * self.samples_to_trim):
|
||||
audio_data = audio_data[self.samples_to_trim:-self.samples_to_trim]
|
||||
|
||||
# Scale to int16 range
|
||||
return (audio_float * 32767).astype(np.int16)
|
||||
# Scale directly to int16 range with clipping
|
||||
return np.clip(audio_data * 32767, -32768, 32767).astype(np.int16)
|
||||
|
||||
|
||||
class AudioService:
|
||||
"""Service for audio format conversions with streaming support"""
|
||||
|
||||
# Supported formats
|
||||
SUPPORTED_FORMATS = {"wav", "mp3", "opus", "flac", "aac", "pcm", "ogg"}
|
||||
|
||||
# Default audio format settings balanced for speed and compression
|
||||
DEFAULT_SETTINGS = {
|
||||
"mp3": {
|
||||
|
@ -86,6 +88,10 @@ class AudioService:
|
|||
Bytes of the converted audio chunk
|
||||
"""
|
||||
try:
|
||||
# Validate format
|
||||
if output_format not in AudioService.SUPPORTED_FORMATS:
|
||||
raise ValueError(f"Format {output_format} not supported")
|
||||
|
||||
# Always normalize audio to ensure proper amplitude scaling
|
||||
if normalizer is None:
|
||||
normalizer = AudioNormalizer()
|
||||
|
|
|
@ -17,26 +17,41 @@ class StreamingAudioWriter:
|
|||
self.sample_rate = sample_rate
|
||||
self.channels = channels
|
||||
self.bytes_written = 0
|
||||
self.buffer = BytesIO()
|
||||
|
||||
# Format-specific setup
|
||||
if self.format == "wav":
|
||||
self._write_wav_header()
|
||||
elif self.format == "ogg":
|
||||
elif self.format in ["ogg", "opus"]:
|
||||
# For OGG/Opus, write to memory buffer
|
||||
self.writer = sf.SoundFile(
|
||||
file=BytesIO(),
|
||||
file=self.buffer,
|
||||
mode='w',
|
||||
samplerate=sample_rate,
|
||||
channels=channels,
|
||||
format='OGG',
|
||||
subtype='VORBIS'
|
||||
subtype='VORBIS' if self.format == "ogg" else "OPUS"
|
||||
)
|
||||
elif self.format == "mp3":
|
||||
# For MP3, we'll use pydub's incremental writer
|
||||
self.buffer = BytesIO()
|
||||
elif self.format == "flac":
|
||||
# For FLAC, write to memory buffer
|
||||
self.writer = sf.SoundFile(
|
||||
file=self.buffer,
|
||||
mode='w',
|
||||
samplerate=sample_rate,
|
||||
channels=channels,
|
||||
format='FLAC'
|
||||
)
|
||||
elif self.format in ["mp3", "aac"]:
|
||||
# For MP3/AAC, we'll use pydub's incremental writer
|
||||
self.segments = [] # Store segments until we have enough data
|
||||
self.total_duration = 0 # Track total duration in milliseconds
|
||||
# Initialize an empty AudioSegment as our encoder
|
||||
self.encoder = AudioSegment.silent(duration=0, frame_rate=self.sample_rate)
|
||||
elif self.format == "pcm":
|
||||
# PCM doesn't need initialization, we'll write raw bytes
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"Unsupported format: {format}")
|
||||
|
||||
def _write_wav_header(self) -> bytes:
|
||||
"""Write WAV header with correct streaming format"""
|
||||
|
@ -63,42 +78,48 @@ class StreamingAudioWriter:
|
|||
audio_data: Audio data to write, or None if finalizing
|
||||
finalize: Whether this is the final write to close the stream
|
||||
"""
|
||||
buffer = BytesIO()
|
||||
output_buffer = BytesIO()
|
||||
|
||||
if finalize:
|
||||
if self.format == "wav":
|
||||
# Write final WAV header with correct sizes
|
||||
buffer.write(b'RIFF')
|
||||
buffer.write(struct.pack('<L', self.bytes_written + 36))
|
||||
buffer.write(b'WAVE')
|
||||
buffer.write(b'fmt ')
|
||||
buffer.write(struct.pack('<L', 16))
|
||||
buffer.write(struct.pack('<H', 1))
|
||||
buffer.write(struct.pack('<H', self.channels))
|
||||
buffer.write(struct.pack('<L', self.sample_rate))
|
||||
buffer.write(struct.pack('<L', self.sample_rate * self.channels * 2))
|
||||
buffer.write(struct.pack('<H', self.channels * 2))
|
||||
buffer.write(struct.pack('<H', 16))
|
||||
buffer.write(b'data')
|
||||
buffer.write(struct.pack('<L', self.bytes_written))
|
||||
elif self.format == "ogg":
|
||||
output_buffer.write(b'RIFF')
|
||||
output_buffer.write(struct.pack('<L', self.bytes_written + 36))
|
||||
output_buffer.write(b'WAVE')
|
||||
output_buffer.write(b'fmt ')
|
||||
output_buffer.write(struct.pack('<L', 16))
|
||||
output_buffer.write(struct.pack('<H', 1))
|
||||
output_buffer.write(struct.pack('<H', self.channels))
|
||||
output_buffer.write(struct.pack('<L', self.sample_rate))
|
||||
output_buffer.write(struct.pack('<L', self.sample_rate * self.channels * 2))
|
||||
output_buffer.write(struct.pack('<H', self.channels * 2))
|
||||
output_buffer.write(struct.pack('<H', 16))
|
||||
output_buffer.write(b'data')
|
||||
output_buffer.write(struct.pack('<L', self.bytes_written))
|
||||
elif self.format in ["ogg", "opus", "flac"]:
|
||||
self.writer.close()
|
||||
elif self.format == "mp3":
|
||||
return self.buffer.getvalue()
|
||||
elif self.format in ["mp3", "aac"]:
|
||||
# Final export of any remaining audio
|
||||
if hasattr(self, 'encoder') and len(self.encoder) > 0:
|
||||
# Export with duration metadata
|
||||
format_args = {
|
||||
"mp3": {"format": "mp3", "codec": "libmp3lame"},
|
||||
"aac": {"format": "adts", "codec": "aac"}
|
||||
}[self.format]
|
||||
|
||||
self.encoder.export(
|
||||
buffer,
|
||||
format="mp3",
|
||||
output_buffer,
|
||||
**format_args,
|
||||
bitrate="192k",
|
||||
parameters=[
|
||||
"-q:a", "2",
|
||||
"-write_xing", "1", # Force XING/LAME header
|
||||
"-write_xing", "1" if self.format == "mp3" else "0", # XING header for MP3 only
|
||||
"-metadata", f"duration={self.total_duration/1000}" # Duration in seconds
|
||||
]
|
||||
)
|
||||
self.encoder = None
|
||||
return buffer.getvalue()
|
||||
return output_buffer.getvalue()
|
||||
|
||||
if audio_data is None or len(audio_data) == 0:
|
||||
return b''
|
||||
|
@ -106,22 +127,22 @@ class StreamingAudioWriter:
|
|||
if self.format == "wav":
|
||||
# For WAV, write raw PCM after the first chunk
|
||||
if self.bytes_written == 0:
|
||||
buffer.write(self._write_wav_header())
|
||||
buffer.write(audio_data.tobytes())
|
||||
output_buffer.write(self._write_wav_header())
|
||||
output_buffer.write(audio_data.tobytes())
|
||||
self.bytes_written += len(audio_data.tobytes())
|
||||
|
||||
elif self.format == "ogg":
|
||||
# OGG/Vorbis handles streaming naturally
|
||||
elif self.format in ["ogg", "opus", "flac"]:
|
||||
# Write to soundfile buffer
|
||||
self.writer.write(audio_data)
|
||||
self.writer.flush()
|
||||
buffer = self.writer.file
|
||||
buffer.seek(0, 2) # Seek to end
|
||||
chunk = buffer.getvalue()
|
||||
buffer.seek(0)
|
||||
buffer.truncate()
|
||||
return chunk
|
||||
# Get current buffer contents
|
||||
data = self.buffer.getvalue()
|
||||
# Clear buffer for next chunk
|
||||
self.buffer.seek(0)
|
||||
self.buffer.truncate()
|
||||
return data
|
||||
|
||||
elif self.format == "mp3":
|
||||
elif self.format in ["mp3", "aac"]:
|
||||
# Convert chunk to AudioSegment and encode
|
||||
segment = AudioSegment(
|
||||
audio_data.tobytes(),
|
||||
|
@ -137,21 +158,30 @@ class StreamingAudioWriter:
|
|||
self.encoder = self.encoder + segment
|
||||
|
||||
# Export current state to buffer
|
||||
self.encoder.export(buffer, format="mp3", bitrate="192k", parameters=[
|
||||
format_args = {
|
||||
"mp3": {"format": "mp3", "codec": "libmp3lame"},
|
||||
"aac": {"format": "adts", "codec": "aac"}
|
||||
}[self.format]
|
||||
|
||||
self.encoder.export(output_buffer, **format_args, bitrate="192k", parameters=[
|
||||
"-q:a", "2",
|
||||
"-write_xing", "1", # Force XING/LAME header
|
||||
"-write_xing", "1" if self.format == "mp3" else "0", # XING header for MP3 only
|
||||
"-metadata", f"duration={self.total_duration/1000}" # Duration in seconds
|
||||
])
|
||||
|
||||
# Get the encoded data
|
||||
encoded_data = buffer.getvalue()
|
||||
encoded_data = output_buffer.getvalue()
|
||||
|
||||
# Reset encoder to prevent memory growth
|
||||
self.encoder = AudioSegment.silent(duration=0, frame_rate=self.sample_rate)
|
||||
|
||||
return encoded_data
|
||||
|
||||
return buffer.getvalue()
|
||||
elif self.format == "pcm":
|
||||
# For PCM, just write raw bytes
|
||||
return audio_data.tobytes()
|
||||
|
||||
return output_buffer.getvalue()
|
||||
|
||||
def close(self) -> Optional[bytes]:
|
||||
"""Finish the audio file and return any remaining data"""
|
||||
|
@ -161,16 +191,31 @@ class StreamingAudioWriter:
|
|||
buffer.write(b'RIFF')
|
||||
buffer.write(struct.pack('<L', self.bytes_written + 36)) # File size
|
||||
buffer.write(b'WAVE')
|
||||
# ... rest of header ...
|
||||
buffer.write(struct.pack('<L', self.bytes_written)) # Data size
|
||||
buffer.write(b'fmt ')
|
||||
buffer.write(struct.pack('<L', 16))
|
||||
buffer.write(struct.pack('<H', 1))
|
||||
buffer.write(struct.pack('<H', self.channels))
|
||||
buffer.write(struct.pack('<L', self.sample_rate))
|
||||
buffer.write(struct.pack('<L', self.sample_rate * self.channels * 2))
|
||||
buffer.write(struct.pack('<H', self.channels * 2))
|
||||
buffer.write(struct.pack('<H', 16))
|
||||
buffer.write(b'data')
|
||||
buffer.write(struct.pack('<L', self.bytes_written))
|
||||
return buffer.getvalue()
|
||||
|
||||
elif self.format == "ogg":
|
||||
elif self.format in ["ogg", "opus", "flac"]:
|
||||
self.writer.close()
|
||||
return None
|
||||
return self.buffer.getvalue()
|
||||
|
||||
elif self.format == "mp3":
|
||||
# Flush any remaining MP3 frames
|
||||
elif self.format in ["mp3", "aac"]:
|
||||
# Flush any remaining audio
|
||||
buffer = BytesIO()
|
||||
self.encoder.export(buffer, format="mp3")
|
||||
if hasattr(self, 'encoder') and len(self.encoder) > 0:
|
||||
format_args = {
|
||||
"mp3": {"format": "mp3", "codec": "libmp3lame"},
|
||||
"aac": {"format": "adts", "codec": "aac"}
|
||||
}[self.format]
|
||||
self.encoder.export(buffer, **format_args)
|
||||
return buffer.getvalue()
|
||||
|
||||
return None
|
|
@ -26,106 +26,128 @@ def sample_audio():
|
|||
return np.sin(2 * np.pi * frequency * t).astype(np.float32), sample_rate
|
||||
|
||||
|
||||
def test_convert_to_wav(sample_audio):
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_to_wav(sample_audio):
|
||||
"""Test converting to WAV format"""
|
||||
audio_data, sample_rate = sample_audio
|
||||
result = AudioService.convert_audio(audio_data, sample_rate, "wav")
|
||||
result = await AudioService.convert_audio(audio_data, sample_rate, "wav")
|
||||
assert isinstance(result, bytes)
|
||||
assert len(result) > 0
|
||||
# Check WAV header
|
||||
assert result.startswith(b'RIFF')
|
||||
assert b'WAVE' in result[:12]
|
||||
|
||||
|
||||
def test_convert_to_mp3(sample_audio):
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_to_mp3(sample_audio):
|
||||
"""Test converting to MP3 format"""
|
||||
audio_data, sample_rate = sample_audio
|
||||
result = AudioService.convert_audio(audio_data, sample_rate, "mp3")
|
||||
result = await AudioService.convert_audio(audio_data, sample_rate, "mp3")
|
||||
assert isinstance(result, bytes)
|
||||
assert len(result) > 0
|
||||
# Check MP3 header (ID3 or MPEG frame sync)
|
||||
assert result.startswith(b'ID3') or result.startswith(b'\xff\xfb')
|
||||
|
||||
|
||||
def test_convert_to_opus(sample_audio):
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_to_opus(sample_audio):
|
||||
"""Test converting to Opus format"""
|
||||
audio_data, sample_rate = sample_audio
|
||||
result = AudioService.convert_audio(audio_data, sample_rate, "opus")
|
||||
result = await AudioService.convert_audio(audio_data, sample_rate, "opus")
|
||||
assert isinstance(result, bytes)
|
||||
assert len(result) > 0
|
||||
# Check OGG header
|
||||
assert result.startswith(b'OggS')
|
||||
|
||||
|
||||
def test_convert_to_flac(sample_audio):
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_to_flac(sample_audio):
|
||||
"""Test converting to FLAC format"""
|
||||
audio_data, sample_rate = sample_audio
|
||||
result = AudioService.convert_audio(audio_data, sample_rate, "flac")
|
||||
result = await AudioService.convert_audio(audio_data, sample_rate, "flac")
|
||||
assert isinstance(result, bytes)
|
||||
assert len(result) > 0
|
||||
# Check FLAC header
|
||||
assert result.startswith(b'fLaC')
|
||||
|
||||
|
||||
def test_convert_to_aac(sample_audio):
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_to_aac(sample_audio):
|
||||
"""Test converting to AAC format"""
|
||||
audio_data, sample_rate = sample_audio
|
||||
result = AudioService.convert_audio(audio_data, sample_rate, "aac")
|
||||
result = await AudioService.convert_audio(audio_data, sample_rate, "aac")
|
||||
assert isinstance(result, bytes)
|
||||
assert len(result) > 0
|
||||
# AAC files typically start with an ADTS header
|
||||
assert result.startswith(b'\xff\xf1') or result.startswith(b'\xff\xf9')
|
||||
# Check ADTS header (AAC)
|
||||
assert result.startswith(b'\xff\xf0') or result.startswith(b'\xff\xf1')
|
||||
|
||||
|
||||
def test_convert_to_pcm(sample_audio):
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_to_pcm(sample_audio):
|
||||
"""Test converting to PCM format"""
|
||||
audio_data, sample_rate = sample_audio
|
||||
result = AudioService.convert_audio(audio_data, sample_rate, "pcm")
|
||||
result = await AudioService.convert_audio(audio_data, sample_rate, "pcm")
|
||||
assert isinstance(result, bytes)
|
||||
assert len(result) > 0
|
||||
# PCM is raw bytes, so no header to check
|
||||
|
||||
|
||||
def test_convert_to_invalid_format_raises_error(sample_audio):
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_to_invalid_format_raises_error(sample_audio):
|
||||
"""Test that converting to an invalid format raises an error"""
|
||||
audio_data, sample_rate = sample_audio
|
||||
with pytest.raises(ValueError, match="Format invalid not supported"):
|
||||
AudioService.convert_audio(audio_data, sample_rate, "invalid")
|
||||
await AudioService.convert_audio(audio_data, sample_rate, "invalid")
|
||||
|
||||
|
||||
def test_normalization_wav(sample_audio):
|
||||
@pytest.mark.asyncio
|
||||
async def test_normalization_wav(sample_audio):
|
||||
"""Test that WAV output is properly normalized to int16 range"""
|
||||
audio_data, sample_rate = sample_audio
|
||||
# Create audio data outside int16 range
|
||||
large_audio = audio_data * 1e5
|
||||
result = AudioService.convert_audio(large_audio, sample_rate, "wav")
|
||||
result = await AudioService.convert_audio(large_audio, sample_rate, "wav")
|
||||
assert isinstance(result, bytes)
|
||||
assert len(result) > 0
|
||||
|
||||
|
||||
def test_normalization_pcm(sample_audio):
|
||||
@pytest.mark.asyncio
|
||||
async def test_normalization_pcm(sample_audio):
|
||||
"""Test that PCM output is properly normalized to int16 range"""
|
||||
audio_data, sample_rate = sample_audio
|
||||
# Create audio data outside int16 range
|
||||
large_audio = audio_data * 1e5
|
||||
result = AudioService.convert_audio(large_audio, sample_rate, "pcm")
|
||||
result = await AudioService.convert_audio(large_audio, sample_rate, "pcm")
|
||||
assert isinstance(result, bytes)
|
||||
assert len(result) > 0
|
||||
|
||||
|
||||
def test_invalid_audio_data():
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_audio_data():
|
||||
"""Test handling of invalid audio data"""
|
||||
invalid_audio = np.array([]) # Empty array
|
||||
sample_rate = 24000
|
||||
with pytest.raises(ValueError):
|
||||
AudioService.convert_audio(invalid_audio, sample_rate, "wav")
|
||||
await AudioService.convert_audio(invalid_audio, sample_rate, "wav")
|
||||
|
||||
|
||||
def test_different_sample_rates(sample_audio):
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_sample_rates(sample_audio):
|
||||
"""Test converting audio with different sample rates"""
|
||||
audio_data, _ = sample_audio
|
||||
sample_rates = [8000, 16000, 44100, 48000]
|
||||
|
||||
for rate in sample_rates:
|
||||
result = AudioService.convert_audio(audio_data, rate, "wav")
|
||||
result = await AudioService.convert_audio(audio_data, rate, "wav")
|
||||
assert isinstance(result, bytes)
|
||||
assert len(result) > 0
|
||||
|
||||
|
||||
def test_buffer_position_after_conversion(sample_audio):
|
||||
@pytest.mark.asyncio
|
||||
async def test_buffer_position_after_conversion(sample_audio):
|
||||
"""Test that buffer position is reset after writing"""
|
||||
audio_data, sample_rate = sample_audio
|
||||
result = AudioService.convert_audio(audio_data, sample_rate, "wav")
|
||||
result = await AudioService.convert_audio(audio_data, sample_rate, "wav")
|
||||
# Convert again to ensure buffer was properly reset
|
||||
result2 = AudioService.convert_audio(audio_data, sample_rate, "wav")
|
||||
result2 = await AudioService.convert_audio(audio_data, sample_rate, "wav")
|
||||
assert len(result) == len(result2)
|
||||
|
|
|
@ -59,7 +59,7 @@ async def test_empty_text(tts_service, test_voice):
|
|||
voice=test_voice,
|
||||
speed=1.0
|
||||
)
|
||||
assert "Text is empty after preprocessing" in str(exc_info.value)
|
||||
assert "No audio chunks were generated successfully" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_voice(tts_service):
|
||||
|
@ -126,15 +126,17 @@ async def test_combine_voices(tts_service):
|
|||
@pytest.mark.asyncio
|
||||
async def test_chunked_text_processing(tts_service, test_voice, mock_audio_output):
|
||||
"""Test processing chunked text"""
|
||||
long_text = "First sentence. Second sentence. Third sentence."
|
||||
# Create text that will force chunking by exceeding max tokens
|
||||
long_text = "This is a test sentence." * 100 # Should be way over 500 tokens
|
||||
|
||||
# Don't mock smart_split - let it actually split the text
|
||||
audio, processing_time = await tts_service.generate_audio(
|
||||
text=long_text,
|
||||
voice=test_voice,
|
||||
speed=1.0,
|
||||
stitch_long_output=True
|
||||
speed=1.0
|
||||
)
|
||||
|
||||
# Should be called multiple times due to chunking
|
||||
assert tts_service.model_manager.generate.call_count > 1
|
||||
assert isinstance(audio, np.ndarray)
|
||||
assert processing_time > 0
|
|
@ -34,6 +34,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
|||
# Copy project files including models
|
||||
COPY --chown=appuser:appuser api ./api
|
||||
COPY --chown=appuser:appuser web ./web
|
||||
COPY --chown=appuser:appuser docker/scripts/download_model.* ./
|
||||
|
||||
# Install project
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
|
@ -44,8 +45,19 @@ ENV PYTHONUNBUFFERED=1
|
|||
ENV PYTHONPATH=/app
|
||||
ENV PATH="/app/.venv/bin:$PATH"
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
ENV USE_GPU=false
|
||||
ENV USE_ONNX=false
|
||||
ENV USE_ONNX=true
|
||||
ENV DOWNLOAD_ONNX=true
|
||||
ENV DOWNLOAD_PTH=false
|
||||
|
||||
# Download models based on environment variables
|
||||
RUN if [ "$DOWNLOAD_ONNX" = "true" ]; then \
|
||||
python download_model.py --type onnx; \
|
||||
fi && \
|
||||
if [ "$DOWNLOAD_PTH" = "true" ]; then \
|
||||
python download_model.py --type pth; \
|
||||
fi
|
||||
|
||||
# Run FastAPI server
|
||||
CMD ["uv", "run", "python", "-m", "uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]
|
||||
|
|
|
@ -1,53 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
import os
|
||||
import sys
|
||||
import requests
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
def download_file(url: str, output_dir: Path) -> None:
|
||||
"""Download a file from URL to the specified directory."""
|
||||
filename = os.path.basename(url)
|
||||
output_path = output_dir / filename
|
||||
|
||||
print(f"Downloading {filename}...")
|
||||
response = requests.get(url, stream=True)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(output_path, 'wb') as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
|
||||
def find_project_root() -> Path:
|
||||
"""Find project root by looking for api directory."""
|
||||
max_steps = 5
|
||||
current = Path(__file__).resolve()
|
||||
for _ in range(max_steps):
|
||||
if (current / 'api').is_dir():
|
||||
return current
|
||||
current = current.parent
|
||||
raise RuntimeError("Could not find project root (no api directory found)")
|
||||
|
||||
def main(custom_models: List[str] = None):
|
||||
# Always use top-level models directory relative to project root
|
||||
project_root = find_project_root()
|
||||
models_dir = project_root / 'api' / 'src' / 'models'
|
||||
models_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# Default ONNX model if no arguments provided
|
||||
default_models = [
|
||||
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.onnx",
|
||||
# "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19_fp16.onnx"
|
||||
]
|
||||
|
||||
# Use provided models or default
|
||||
models_to_download = custom_models if custom_models else default_models
|
||||
|
||||
for model_url in models_to_download:
|
||||
try:
|
||||
download_file(model_url, models_dir)
|
||||
except Exception as e:
|
||||
print(f"Error downloading {model_url}: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(sys.argv[1:] if len(sys.argv) > 1 else None)
|
|
@ -1,32 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Ensure models directory exists
|
||||
mkdir -p api/src/models
|
||||
|
||||
# Function to download a file
|
||||
download_file() {
|
||||
local url="$1"
|
||||
local filename=$(basename "$url")
|
||||
echo "Downloading $filename..."
|
||||
curl -L "$url" -o "api/src/models/$filename"
|
||||
}
|
||||
|
||||
# Default ONNX model if no arguments provided
|
||||
DEFAULT_MODELS=(
|
||||
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.onnx"
|
||||
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19_fp16.onnx"
|
||||
)
|
||||
|
||||
# Use provided models or default
|
||||
if [ $# -gt 0 ]; then
|
||||
MODELS=("$@")
|
||||
else
|
||||
MODELS=("${DEFAULT_MODELS[@]}")
|
||||
fi
|
||||
|
||||
# Download all models
|
||||
for model in "${MODELS[@]}"; do
|
||||
download_file "$model"
|
||||
done
|
||||
|
||||
echo "ONNX model download complete!"
|
|
@ -38,6 +38,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
|||
# Copy project files including models
|
||||
COPY --chown=appuser:appuser api ./api
|
||||
COPY --chown=appuser:appuser web ./web
|
||||
COPY --chown=appuser:appuser docker/scripts/download_model.* ./
|
||||
|
||||
# Install project with GPU extras
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
|
@ -48,7 +49,19 @@ ENV PYTHONUNBUFFERED=1
|
|||
ENV PYTHONPATH=/app
|
||||
ENV PATH="/app/.venv/bin:$PATH"
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
ENV USE_GPU=true
|
||||
ENV USE_ONNX=false
|
||||
ENV DOWNLOAD_PTH=true
|
||||
ENV DOWNLOAD_ONNX=false
|
||||
|
||||
# Download models based on environment variables
|
||||
RUN if [ "$DOWNLOAD_PTH" = "true" ]; then \
|
||||
python download_model.py --type pth; \
|
||||
fi && \
|
||||
if [ "$DOWNLOAD_ONNX" = "true" ]; then \
|
||||
python download_model.py --type onnx; \
|
||||
fi
|
||||
|
||||
# Run FastAPI server
|
||||
CMD ["uv", "run", "python", "-m", "uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]
|
||||
|
|
|
@ -1,57 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
import os
|
||||
import sys
|
||||
import requests
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
def download_file(url: str, output_dir: Path) -> None:
|
||||
"""Download a file from URL to the specified directory."""
|
||||
filename = os.path.basename(url)
|
||||
if not filename.endswith('.pth'):
|
||||
print(f"Warning: {filename} is not a .pth file")
|
||||
return
|
||||
|
||||
output_path = output_dir / filename
|
||||
|
||||
print(f"Downloading {filename}...")
|
||||
response = requests.get(url, stream=True)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(output_path, 'wb') as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
|
||||
def find_project_root() -> Path:
|
||||
"""Find project root by looking for api directory."""
|
||||
max_steps = 5
|
||||
current = Path(__file__).resolve()
|
||||
for _ in range(max_steps):
|
||||
if (current / 'api').is_dir():
|
||||
return current
|
||||
current = current.parent
|
||||
raise RuntimeError("Could not find project root (no api directory found)")
|
||||
|
||||
def main(custom_models: List[str] = None):
|
||||
# Find project root and ensure models directory exists
|
||||
project_root = find_project_root()
|
||||
models_dir = project_root / 'api' / 'src' / 'models'
|
||||
print(f"Downloading models to {models_dir}")
|
||||
models_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Default PTH model if no arguments provided
|
||||
default_models = [
|
||||
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth"
|
||||
]
|
||||
|
||||
# Use provided models or default
|
||||
models_to_download = custom_models if custom_models else default_models
|
||||
|
||||
for model_url in models_to_download:
|
||||
try:
|
||||
download_file(model_url, models_dir)
|
||||
except Exception as e:
|
||||
print(f"Error downloading {model_url}: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(sys.argv[1:] if len(sys.argv) > 1 else None)
|
|
@ -1,31 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Ensure models directory exists
|
||||
mkdir -p api/src/models
|
||||
|
||||
# Function to download a file
|
||||
download_file() {
|
||||
local url="$1"
|
||||
local filename=$(basename "$url")
|
||||
echo "Downloading $filename..."
|
||||
curl -L "$url" -o "api/src/models/$filename"
|
||||
}
|
||||
|
||||
# Default PTH model if no arguments provided
|
||||
DEFAULT_MODELS=(
|
||||
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth"
|
||||
)
|
||||
|
||||
# Use provided models or default
|
||||
if [ $# -gt 0 ]; then
|
||||
MODELS=("$@")
|
||||
else
|
||||
MODELS=("${DEFAULT_MODELS[@]}")
|
||||
fi
|
||||
|
||||
# Download all models
|
||||
for model in "${MODELS[@]}"; do
|
||||
download_file "$model"
|
||||
done
|
||||
|
||||
echo "PyTorch model download complete!"
|
97
docker/scripts/download_model.py
Normal file
97
docker/scripts/download_model.py
Normal file
|
@ -0,0 +1,97 @@
|
|||
#!/usr/bin/env python3
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import requests
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
def download_file(url: str, output_dir: Path, model_type: str) -> bool:
|
||||
"""Download a file from URL to the specified directory.
|
||||
|
||||
Returns:
|
||||
bool: True if download succeeded, False otherwise
|
||||
"""
|
||||
filename = os.path.basename(url)
|
||||
if not filename.endswith(f'.{model_type}'):
|
||||
print(f"Warning: {filename} is not a .{model_type} file", file=sys.stderr)
|
||||
return False
|
||||
|
||||
output_path = output_dir / filename
|
||||
|
||||
print(f"Downloading {filename}...")
|
||||
try:
|
||||
response = requests.get(url, stream=True)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(output_path, 'wb') as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
print(f"Successfully downloaded {filename}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Error downloading {filename}: {e}", file=sys.stderr)
|
||||
return False
|
||||
|
||||
def find_project_root() -> Path:
|
||||
"""Find project root by looking for api directory."""
|
||||
max_steps = 5
|
||||
current = Path(__file__).resolve()
|
||||
for _ in range(max_steps):
|
||||
if (current / 'api').is_dir():
|
||||
return current
|
||||
current = current.parent
|
||||
raise RuntimeError("Could not find project root (no api directory found)")
|
||||
|
||||
def main() -> int:
|
||||
"""Download models to the project.
|
||||
|
||||
Returns:
|
||||
int: Exit code (0 for success, 1 for failure)
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description='Download model files')
|
||||
parser.add_argument('--type', choices=['pth', 'onnx'], required=True,
|
||||
help='Model type to download (pth or onnx)')
|
||||
parser.add_argument('urls', nargs='*', help='Optional model URLs to download')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
# Find project root and ensure models directory exists
|
||||
project_root = find_project_root()
|
||||
models_dir = project_root / 'api' / 'src' / 'models'
|
||||
print(f"Downloading models to {models_dir}")
|
||||
models_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Default models if no arguments provided
|
||||
default_models = {
|
||||
'pth': [
|
||||
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth"
|
||||
],
|
||||
'onnx': [
|
||||
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.onnx",
|
||||
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19_fp16.onnx"
|
||||
]
|
||||
}
|
||||
|
||||
# Use provided models or default
|
||||
models_to_download = args.urls if args.urls else default_models[args.type]
|
||||
|
||||
# Download all models
|
||||
success = True
|
||||
for model_url in models_to_download:
|
||||
if not download_file(model_url, models_dir, args.type):
|
||||
success = False
|
||||
|
||||
if success:
|
||||
print(f"{args.type.upper()} model download complete!")
|
||||
return 0
|
||||
else:
|
||||
print("Some downloads failed", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
109
docker/scripts/download_model.sh
Normal file
109
docker/scripts/download_model.sh
Normal file
|
@ -0,0 +1,109 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Find project root by looking for api directory
|
||||
find_project_root() {
|
||||
local current_dir="$PWD"
|
||||
local max_steps=5
|
||||
local steps=0
|
||||
|
||||
while [ $steps -lt $max_steps ]; do
|
||||
if [ -d "$current_dir/api" ]; then
|
||||
echo "$current_dir"
|
||||
return 0
|
||||
fi
|
||||
current_dir="$(dirname "$current_dir")"
|
||||
((steps++))
|
||||
done
|
||||
|
||||
echo "Error: Could not find project root (no api directory found)" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Function to download a file
|
||||
download_file() {
|
||||
local url="$1"
|
||||
local output_dir="$2"
|
||||
local model_type="$3"
|
||||
local filename=$(basename "$url")
|
||||
|
||||
# Validate file extension
|
||||
if [[ ! "$filename" =~ \.$model_type$ ]]; then
|
||||
echo "Warning: $filename is not a .$model_type file" >&2
|
||||
return 1
|
||||
}
|
||||
|
||||
echo "Downloading $filename..."
|
||||
if curl -L "$url" -o "$output_dir/$filename"; then
|
||||
echo "Successfully downloaded $filename"
|
||||
return 0
|
||||
else
|
||||
echo "Error downloading $filename" >&2
|
||||
return 1
|
||||
fi
|
||||
}
|
||||
|
||||
# Parse arguments
|
||||
MODEL_TYPE=""
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
--type)
|
||||
MODEL_TYPE="$2"
|
||||
shift 2
|
||||
;;
|
||||
*)
|
||||
# If no flag specified, treat remaining args as model URLs
|
||||
break
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Validate model type
|
||||
if [ "$MODEL_TYPE" != "pth" ] && [ "$MODEL_TYPE" != "onnx" ]; then
|
||||
echo "Error: Must specify model type with --type (pth or onnx)" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Find project root and ensure models directory exists
|
||||
PROJECT_ROOT=$(find_project_root)
|
||||
if [ $? -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
MODELS_DIR="$PROJECT_ROOT/api/src/models"
|
||||
echo "Downloading models to $MODELS_DIR"
|
||||
mkdir -p "$MODELS_DIR"
|
||||
|
||||
# Default models if no arguments provided
|
||||
if [ "$MODEL_TYPE" = "pth" ]; then
|
||||
DEFAULT_MODELS=(
|
||||
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth"
|
||||
)
|
||||
else
|
||||
DEFAULT_MODELS=(
|
||||
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.onnx"
|
||||
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19_fp16.onnx"
|
||||
)
|
||||
fi
|
||||
|
||||
# Use provided models or default
|
||||
if [ $# -gt 0 ]; then
|
||||
MODELS=("$@")
|
||||
else
|
||||
MODELS=("${DEFAULT_MODELS[@]}")
|
||||
fi
|
||||
|
||||
# Download all models
|
||||
success=true
|
||||
for model in "${MODELS[@]}"; do
|
||||
if ! download_file "$model" "$MODELS_DIR" "$MODEL_TYPE"; then
|
||||
success=false
|
||||
fi
|
||||
done
|
||||
|
||||
if [ "$success" = true ]; then
|
||||
echo "${MODEL_TYPE^^} model download complete!"
|
||||
exit 0
|
||||
else
|
||||
echo "Some downloads failed" >&2
|
||||
exit 1
|
||||
fi
|
Loading…
Add table
Reference in a new issue