mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00
Dockerfile optimizations:
Moved to the NVIDIA CUDA base image. Merged all apt-get install commands into one. Removed espeak-ng. Added a uv cache clean command at the end. Removed g++ at the end. Concatenated all ENV commands. Removed the api tests folder.
This commit is contained in:
parent
543cbecc1a
commit
a3d23e9dad
14 changed files with 39 additions and 1845 deletions
|
@ -1 +0,0 @@
|
|||
# Make tests directory a Python package
|
|
@ -1,71 +0,0 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import torch
|
||||
|
||||
from api.src.inference.model_manager import ModelManager
|
||||
from api.src.inference.voice_manager import VoiceManager
|
||||
from api.src.services.tts_service import TTSService
|
||||
from api.src.structures.model_schemas import VoiceConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_voice_tensor():
|
||||
"""Load a real voice tensor for testing."""
|
||||
voice_path = os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)), "src/voices/af_bella.pt"
|
||||
)
|
||||
return torch.load(voice_path, map_location="cpu", weights_only=False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_audio_output():
|
||||
"""Load pre-generated test audio for consistent testing."""
|
||||
test_audio_path = os.path.join(
|
||||
os.path.dirname(__file__), "test_data/test_audio.npy"
|
||||
)
|
||||
return np.load(test_audio_path) # Return as numpy array instead of bytes
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def mock_model_manager(mock_audio_output):
|
||||
"""Mock model manager for testing."""
|
||||
manager = AsyncMock(spec=ModelManager)
|
||||
manager.get_backend = MagicMock()
|
||||
|
||||
async def mock_generate(*args, **kwargs):
|
||||
# Simulate successful audio generation
|
||||
return np.random.rand(24000).astype(np.float32) # 1 second of random audio data
|
||||
|
||||
manager.generate = AsyncMock(side_effect=mock_generate)
|
||||
return manager
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def mock_voice_manager(mock_voice_tensor):
|
||||
"""Mock voice manager for testing."""
|
||||
manager = AsyncMock(spec=VoiceManager)
|
||||
manager.get_voice_path = MagicMock(return_value="/mock/path/voice.pt")
|
||||
manager.load_voice = AsyncMock(return_value=mock_voice_tensor)
|
||||
manager.list_voices = AsyncMock(return_value=["voice1", "voice2"])
|
||||
manager.combine_voices = AsyncMock(return_value="voice1_voice2")
|
||||
return manager
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def tts_service(mock_model_manager, mock_voice_manager):
|
||||
"""Get mocked TTS service instance."""
|
||||
service = TTSService()
|
||||
service.model_manager = mock_model_manager
|
||||
service._voice_manager = mock_voice_manager
|
||||
return service
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_voice():
|
||||
"""Return a test voice name."""
|
||||
return "voice1"
|
|
@ -1,256 +0,0 @@
|
|||
"""Tests for AudioService"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from api.src.inference.base import AudioChunk
|
||||
from api.src.services.audio import AudioNormalizer, AudioService
|
||||
from api.src.services.streaming_audio_writer import StreamingAudioWriter
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_settings():
|
||||
"""Mock settings for all tests"""
|
||||
with patch("api.src.services.audio.settings") as mock_settings:
|
||||
mock_settings.gap_trim_ms = 250
|
||||
yield mock_settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_audio():
|
||||
"""Generate a simple sine wave for testing"""
|
||||
sample_rate = 24000
|
||||
duration = 0.1 # 100ms
|
||||
t = np.linspace(0, duration, int(sample_rate * duration))
|
||||
frequency = 440 # A4 note
|
||||
return np.sin(2 * np.pi * frequency * t).astype(np.float32), sample_rate
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_to_wav(sample_audio):
|
||||
"""Test converting to WAV format"""
|
||||
audio_data, sample_rate = sample_audio
|
||||
|
||||
writer = StreamingAudioWriter("wav", sample_rate=24000)
|
||||
# Write and finalize in one step for WAV
|
||||
audio_chunk = await AudioService.convert_audio(
|
||||
AudioChunk(audio_data), "wav", writer, is_last_chunk=False
|
||||
)
|
||||
|
||||
writer.close()
|
||||
|
||||
assert isinstance(audio_chunk.output, bytes)
|
||||
assert isinstance(audio_chunk, AudioChunk)
|
||||
assert len(audio_chunk.output) > 0
|
||||
# Check WAV header
|
||||
assert audio_chunk.output.startswith(b"RIFF")
|
||||
assert b"WAVE" in audio_chunk.output[:12]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_to_mp3(sample_audio):
|
||||
"""Test converting to MP3 format"""
|
||||
audio_data, sample_rate = sample_audio
|
||||
|
||||
writer = StreamingAudioWriter("mp3", sample_rate=24000)
|
||||
|
||||
audio_chunk = await AudioService.convert_audio(
|
||||
AudioChunk(audio_data), "mp3", writer
|
||||
)
|
||||
|
||||
writer.close()
|
||||
|
||||
assert isinstance(audio_chunk.output, bytes)
|
||||
assert isinstance(audio_chunk, AudioChunk)
|
||||
assert len(audio_chunk.output) > 0
|
||||
# Check MP3 header (ID3 or MPEG frame sync)
|
||||
assert audio_chunk.output.startswith(b"ID3") or audio_chunk.output.startswith(
|
||||
b"\xff\xfb"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_to_opus(sample_audio):
|
||||
"""Test converting to Opus format"""
|
||||
|
||||
audio_data, sample_rate = sample_audio
|
||||
|
||||
writer = StreamingAudioWriter("opus", sample_rate=24000)
|
||||
|
||||
audio_chunk = await AudioService.convert_audio(
|
||||
AudioChunk(audio_data), "opus", writer
|
||||
)
|
||||
|
||||
writer.close()
|
||||
|
||||
assert isinstance(audio_chunk.output, bytes)
|
||||
assert isinstance(audio_chunk, AudioChunk)
|
||||
assert len(audio_chunk.output) > 0
|
||||
# Check OGG header
|
||||
assert audio_chunk.output.startswith(b"OggS")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_to_flac(sample_audio):
|
||||
"""Test converting to FLAC format"""
|
||||
audio_data, sample_rate = sample_audio
|
||||
|
||||
writer = StreamingAudioWriter("flac", sample_rate=24000)
|
||||
|
||||
audio_chunk = await AudioService.convert_audio(
|
||||
AudioChunk(audio_data), "flac", writer
|
||||
)
|
||||
|
||||
writer.close()
|
||||
|
||||
assert isinstance(audio_chunk.output, bytes)
|
||||
assert isinstance(audio_chunk, AudioChunk)
|
||||
assert len(audio_chunk.output) > 0
|
||||
# Check FLAC header
|
||||
assert audio_chunk.output.startswith(b"fLaC")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_to_aac(sample_audio):
|
||||
"""Test converting to M4A format"""
|
||||
audio_data, sample_rate = sample_audio
|
||||
|
||||
writer = StreamingAudioWriter("aac", sample_rate=24000)
|
||||
|
||||
audio_chunk = await AudioService.convert_audio(
|
||||
AudioChunk(audio_data), "aac", writer
|
||||
)
|
||||
|
||||
writer.close()
|
||||
|
||||
assert isinstance(audio_chunk.output, bytes)
|
||||
assert isinstance(audio_chunk, AudioChunk)
|
||||
assert len(audio_chunk.output) > 0
|
||||
# Check ADTS header (AAC)
|
||||
assert audio_chunk.output.startswith(b"\xff\xf0") or audio_chunk.output.startswith(
|
||||
b"\xff\xf1"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_to_pcm(sample_audio):
|
||||
"""Test converting to PCM format"""
|
||||
audio_data, sample_rate = sample_audio
|
||||
|
||||
writer = StreamingAudioWriter("pcm", sample_rate=24000)
|
||||
|
||||
audio_chunk = await AudioService.convert_audio(
|
||||
AudioChunk(audio_data), "pcm", writer
|
||||
)
|
||||
|
||||
writer.close()
|
||||
|
||||
assert isinstance(audio_chunk.output, bytes)
|
||||
assert isinstance(audio_chunk, AudioChunk)
|
||||
assert len(audio_chunk.output) > 0
|
||||
# PCM is raw bytes, so no header to check
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_to_invalid_format_raises_error(sample_audio):
|
||||
"""Test that converting to an invalid format raises an error"""
|
||||
# audio_data, sample_rate = sample_audio
|
||||
with pytest.raises(ValueError, match="Unsupported format: invalid"):
|
||||
writer = StreamingAudioWriter("invalid", sample_rate=24000)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_normalization_wav(sample_audio):
|
||||
"""Test that WAV output is properly normalized to int16 range"""
|
||||
audio_data, sample_rate = sample_audio
|
||||
|
||||
writer = StreamingAudioWriter("wav", sample_rate=24000)
|
||||
|
||||
# Create audio data outside int16 range
|
||||
large_audio = audio_data * 1e5
|
||||
# Write and finalize in one step for WAV
|
||||
audio_chunk = await AudioService.convert_audio(
|
||||
AudioChunk(large_audio), "wav", writer
|
||||
)
|
||||
|
||||
writer.close()
|
||||
|
||||
assert isinstance(audio_chunk.output, bytes)
|
||||
assert isinstance(audio_chunk, AudioChunk)
|
||||
assert len(audio_chunk.output) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_normalization_pcm(sample_audio):
|
||||
"""Test that PCM output is properly normalized to int16 range"""
|
||||
audio_data, sample_rate = sample_audio
|
||||
|
||||
writer = StreamingAudioWriter("pcm", sample_rate=24000)
|
||||
|
||||
# Create audio data outside int16 range
|
||||
large_audio = audio_data * 1e5
|
||||
audio_chunk = await AudioService.convert_audio(
|
||||
AudioChunk(large_audio), "pcm", writer
|
||||
)
|
||||
assert isinstance(audio_chunk.output, bytes)
|
||||
assert isinstance(audio_chunk, AudioChunk)
|
||||
assert len(audio_chunk.output) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_audio_data():
|
||||
"""Test handling of invalid audio data"""
|
||||
invalid_audio = np.array([]) # Empty array
|
||||
sample_rate = 24000
|
||||
|
||||
writer = StreamingAudioWriter("wav", sample_rate=24000)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await AudioService.convert_audio(invalid_audio, sample_rate, "wav", writer)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_sample_rates(sample_audio):
|
||||
"""Test converting audio with different sample rates"""
|
||||
audio_data, _ = sample_audio
|
||||
sample_rates = [8000, 16000, 44100, 48000]
|
||||
|
||||
for rate in sample_rates:
|
||||
writer = StreamingAudioWriter("wav", sample_rate=rate)
|
||||
|
||||
audio_chunk = await AudioService.convert_audio(
|
||||
AudioChunk(audio_data), "wav", writer
|
||||
)
|
||||
|
||||
writer.close()
|
||||
|
||||
assert isinstance(audio_chunk.output, bytes)
|
||||
assert isinstance(audio_chunk, AudioChunk)
|
||||
assert len(audio_chunk.output) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_buffer_position_after_conversion(sample_audio):
|
||||
"""Test that buffer position is reset after writing"""
|
||||
audio_data, sample_rate = sample_audio
|
||||
|
||||
writer = StreamingAudioWriter("wav", sample_rate=24000)
|
||||
|
||||
# Write and finalize in one step for first conversion
|
||||
audio_chunk1 = await AudioService.convert_audio(
|
||||
AudioChunk(audio_data), "wav", writer, is_last_chunk=True
|
||||
)
|
||||
assert isinstance(audio_chunk1.output, bytes)
|
||||
assert isinstance(audio_chunk1, AudioChunk)
|
||||
# Convert again to ensure buffer was properly reset
|
||||
|
||||
writer = StreamingAudioWriter("wav", sample_rate=24000)
|
||||
|
||||
audio_chunk2 = await AudioService.convert_audio(
|
||||
AudioChunk(audio_data), "wav", writer, is_last_chunk=True
|
||||
)
|
||||
assert isinstance(audio_chunk2.output, bytes)
|
||||
assert isinstance(audio_chunk2, AudioChunk)
|
||||
assert len(audio_chunk1.output) == len(audio_chunk2.output)
|
|
@ -1,23 +0,0 @@
|
|||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def generate_test_audio():
|
||||
"""Generate test audio data - 1 second of 440Hz tone"""
|
||||
# Create 1 second of silence at 24kHz
|
||||
audio = np.zeros(24000, dtype=np.float32)
|
||||
|
||||
# Add a simple sine wave to make it non-zero
|
||||
t = np.linspace(0, 1, 24000)
|
||||
audio += 0.5 * np.sin(2 * np.pi * 440 * t) # 440 Hz tone at half amplitude
|
||||
|
||||
# Create test_data directory if it doesn't exist
|
||||
os.makedirs("api/tests/test_data", exist_ok=True)
|
||||
|
||||
# Save the test audio
|
||||
np.save("api/tests/test_data/test_audio.npy", audio)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_test_audio()
|
Binary file not shown.
|
@ -1,34 +0,0 @@
|
|||
import base64
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
def test_generate_captioned_speech():
|
||||
"""Test the generate_captioned_speech function with mocked responses"""
|
||||
# Mock the API responses
|
||||
mock_audio_response = MagicMock()
|
||||
mock_audio_response.status_code = 200
|
||||
|
||||
mock_timestamps_response = MagicMock()
|
||||
mock_timestamps_response.status_code = 200
|
||||
mock_timestamps_response.content = json.dumps(
|
||||
{
|
||||
"audio": base64.b64encode(b"mock audio data").decode("utf-8"),
|
||||
"timestamps": [{"word": "test", "start_time": 0.0, "end_time": 1.0}],
|
||||
}
|
||||
)
|
||||
|
||||
# Patch the HTTP requests
|
||||
with patch("requests.post", return_value=mock_timestamps_response):
|
||||
# Import here to avoid module-level import issues
|
||||
from examples.captioned_speech_example import generate_captioned_speech
|
||||
|
||||
# Test the function
|
||||
audio, timestamps = generate_captioned_speech("test text")
|
||||
|
||||
# Verify we got both audio and timestamps
|
||||
assert audio == b"mock audio data"
|
||||
assert timestamps == [{"word": "test", "start_time": 0.0, "end_time": 1.0}]
|
|
@ -1,165 +0,0 @@
|
|||
from unittest.mock import ANY, MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from api.src.inference.kokoro_v1 import KokoroV1
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def kokoro_backend():
|
||||
"""Create a KokoroV1 instance for testing."""
|
||||
return KokoroV1()
|
||||
|
||||
|
||||
def test_initial_state(kokoro_backend):
|
||||
"""Test initial state of KokoroV1."""
|
||||
assert not kokoro_backend.is_loaded
|
||||
assert kokoro_backend._model is None
|
||||
assert kokoro_backend._pipelines == {} # Now using dict of pipelines
|
||||
# Device should be set based on settings
|
||||
assert kokoro_backend.device in ["cuda", "cpu"]
|
||||
|
||||
|
||||
@patch("torch.cuda.is_available", return_value=True)
|
||||
@patch("torch.cuda.memory_allocated", return_value=5e9)
|
||||
def test_memory_management(mock_memory, mock_cuda, kokoro_backend):
|
||||
"""Test GPU memory management functions."""
|
||||
# Patch backend so it thinks we have cuda
|
||||
with patch.object(kokoro_backend, "_device", "cuda"):
|
||||
# Test memory check
|
||||
with patch("api.src.inference.kokoro_v1.model_config") as mock_config:
|
||||
mock_config.pytorch_gpu.memory_threshold = 4
|
||||
assert kokoro_backend._check_memory() == True
|
||||
|
||||
mock_config.pytorch_gpu.memory_threshold = 6
|
||||
assert kokoro_backend._check_memory() == False
|
||||
|
||||
|
||||
@patch("torch.cuda.empty_cache")
|
||||
@patch("torch.cuda.synchronize")
|
||||
def test_clear_memory(mock_sync, mock_clear, kokoro_backend):
|
||||
"""Test memory clearing."""
|
||||
with patch.object(kokoro_backend, "_device", "cuda"):
|
||||
kokoro_backend._clear_memory()
|
||||
mock_clear.assert_called_once()
|
||||
mock_sync.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_model_validation(kokoro_backend):
|
||||
"""Test model loading validation."""
|
||||
with pytest.raises(RuntimeError, match="Failed to load Kokoro model"):
|
||||
await kokoro_backend.load_model("nonexistent_model.pth")
|
||||
|
||||
|
||||
def test_unload_with_pipelines(kokoro_backend):
|
||||
"""Test model unloading with multiple pipelines."""
|
||||
# Mock loaded state with multiple pipelines
|
||||
kokoro_backend._model = MagicMock()
|
||||
pipeline_a = MagicMock()
|
||||
pipeline_e = MagicMock()
|
||||
kokoro_backend._pipelines = {"a": pipeline_a, "e": pipeline_e}
|
||||
assert kokoro_backend.is_loaded
|
||||
|
||||
# Test unload
|
||||
kokoro_backend.unload()
|
||||
assert not kokoro_backend.is_loaded
|
||||
assert kokoro_backend._model is None
|
||||
assert kokoro_backend._pipelines == {} # All pipelines should be cleared
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_validation(kokoro_backend):
|
||||
"""Test generation validation."""
|
||||
with pytest.raises(RuntimeError, match="Model not loaded"):
|
||||
async for _ in kokoro_backend.generate("test", "voice"):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_from_tokens_validation(kokoro_backend):
|
||||
"""Test token generation validation."""
|
||||
with pytest.raises(RuntimeError, match="Model not loaded"):
|
||||
async for _ in kokoro_backend.generate_from_tokens("test tokens", "voice"):
|
||||
pass
|
||||
|
||||
|
||||
def test_get_pipeline_creates_new(kokoro_backend):
|
||||
"""Test that _get_pipeline creates new pipeline for new language code."""
|
||||
# Mock loaded state
|
||||
kokoro_backend._model = MagicMock()
|
||||
|
||||
# Mock KPipeline
|
||||
mock_pipeline = MagicMock()
|
||||
with patch(
|
||||
"api.src.inference.kokoro_v1.KPipeline", return_value=mock_pipeline
|
||||
) as mock_kpipeline:
|
||||
# Get pipeline for Spanish
|
||||
pipeline_e = kokoro_backend._get_pipeline("e")
|
||||
|
||||
# Should create new pipeline with correct params
|
||||
mock_kpipeline.assert_called_once_with(
|
||||
lang_code="e", model=kokoro_backend._model, device=kokoro_backend._device
|
||||
)
|
||||
assert pipeline_e == mock_pipeline
|
||||
assert kokoro_backend._pipelines["e"] == mock_pipeline
|
||||
|
||||
|
||||
def test_get_pipeline_reuses_existing(kokoro_backend):
|
||||
"""Test that _get_pipeline reuses existing pipeline for same language code."""
|
||||
# Mock loaded state
|
||||
kokoro_backend._model = MagicMock()
|
||||
|
||||
# Mock KPipeline
|
||||
mock_pipeline = MagicMock()
|
||||
with patch(
|
||||
"api.src.inference.kokoro_v1.KPipeline", return_value=mock_pipeline
|
||||
) as mock_kpipeline:
|
||||
# Get pipeline twice for same language
|
||||
pipeline1 = kokoro_backend._get_pipeline("e")
|
||||
pipeline2 = kokoro_backend._get_pipeline("e")
|
||||
|
||||
# Should only create pipeline once
|
||||
mock_kpipeline.assert_called_once()
|
||||
assert pipeline1 == pipeline2
|
||||
assert kokoro_backend._pipelines["e"] == mock_pipeline
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_uses_correct_pipeline(kokoro_backend):
|
||||
"""Test that generate uses correct pipeline for language code."""
|
||||
# Mock loaded state
|
||||
kokoro_backend._model = MagicMock()
|
||||
|
||||
# Mock voice path handling
|
||||
with (
|
||||
patch("api.src.core.paths.load_voice_tensor") as mock_load_voice,
|
||||
patch("api.src.core.paths.save_voice_tensor"),
|
||||
patch("tempfile.gettempdir") as mock_tempdir,
|
||||
):
|
||||
mock_load_voice.return_value = torch.ones(1)
|
||||
mock_tempdir.return_value = "/tmp"
|
||||
|
||||
# Mock KPipeline
|
||||
mock_pipeline = MagicMock()
|
||||
mock_pipeline.return_value = iter([]) # Empty generator for testing
|
||||
with patch("api.src.inference.kokoro_v1.KPipeline", return_value=mock_pipeline):
|
||||
# Generate with Spanish voice and explicit lang_code
|
||||
async for _ in kokoro_backend.generate("test", "ef_voice", lang_code="e"):
|
||||
pass
|
||||
|
||||
# Should create pipeline with Spanish lang_code
|
||||
assert "e" in kokoro_backend._pipelines
|
||||
# Use ANY to match the temp file path since it's dynamic
|
||||
mock_pipeline.assert_called_with(
|
||||
"test",
|
||||
voice=ANY, # Don't check exact path since it's dynamic
|
||||
speed=1.0,
|
||||
model=kokoro_backend._model,
|
||||
)
|
||||
# Verify the voice path is a temp file path
|
||||
call_args = mock_pipeline.call_args
|
||||
assert isinstance(call_args[1]["voice"], str)
|
||||
assert call_args[1]["voice"].startswith("/tmp/temp_voice_")
|
|
@ -1,317 +0,0 @@
|
|||
"""Tests for text normalization service"""
|
||||
|
||||
import pytest
|
||||
|
||||
from api.src.services.text_processing.normalizer import normalize_text
|
||||
from api.src.structures.schemas import NormalizationOptions
|
||||
|
||||
|
||||
def test_url_protocols():
|
||||
"""Test URL protocol handling"""
|
||||
assert (
|
||||
normalize_text(
|
||||
"Check out https://example.com",
|
||||
normalization_options=NormalizationOptions(),
|
||||
)
|
||||
== "Check out https example dot com"
|
||||
)
|
||||
assert (
|
||||
normalize_text(
|
||||
"Visit http://site.com", normalization_options=NormalizationOptions()
|
||||
)
|
||||
== "Visit http site dot com"
|
||||
)
|
||||
assert (
|
||||
normalize_text(
|
||||
"Go to https://test.org/path", normalization_options=NormalizationOptions()
|
||||
)
|
||||
== "Go to https test dot org slash path"
|
||||
)
|
||||
|
||||
|
||||
def test_url_www():
|
||||
"""Test www prefix handling"""
|
||||
assert (
|
||||
normalize_text(
|
||||
"Go to www.example.com", normalization_options=NormalizationOptions()
|
||||
)
|
||||
== "Go to www example dot com"
|
||||
)
|
||||
assert (
|
||||
normalize_text(
|
||||
"Visit www.test.org/docs", normalization_options=NormalizationOptions()
|
||||
)
|
||||
== "Visit www test dot org slash docs"
|
||||
)
|
||||
assert (
|
||||
normalize_text(
|
||||
"Check www.site.com?q=test", normalization_options=NormalizationOptions()
|
||||
)
|
||||
== "Check www site dot com question-mark q equals test"
|
||||
)
|
||||
|
||||
|
||||
def test_url_localhost():
|
||||
"""Test localhost URL handling"""
|
||||
assert (
|
||||
normalize_text(
|
||||
"Running on localhost:7860", normalization_options=NormalizationOptions()
|
||||
)
|
||||
== "Running on localhost colon seventy-eight sixty"
|
||||
)
|
||||
assert (
|
||||
normalize_text(
|
||||
"Server at localhost:8080/api", normalization_options=NormalizationOptions()
|
||||
)
|
||||
== "Server at localhost colon eighty eighty slash api"
|
||||
)
|
||||
assert (
|
||||
normalize_text(
|
||||
"Test localhost:3000/test?v=1", normalization_options=NormalizationOptions()
|
||||
)
|
||||
== "Test localhost colon three thousand slash test question-mark v equals one"
|
||||
)
|
||||
|
||||
|
||||
def test_url_ip_addresses():
|
||||
"""Test IP address URL handling"""
|
||||
assert (
|
||||
normalize_text(
|
||||
"Access 0.0.0.0:9090/test", normalization_options=NormalizationOptions()
|
||||
)
|
||||
== "Access zero dot zero dot zero dot zero colon ninety ninety slash test"
|
||||
)
|
||||
assert (
|
||||
normalize_text(
|
||||
"API at 192.168.1.1:8000", normalization_options=NormalizationOptions()
|
||||
)
|
||||
== "API at one hundred and ninety-two dot one hundred and sixty-eight dot one dot one colon eight thousand"
|
||||
)
|
||||
assert (
|
||||
normalize_text("Server 127.0.0.1", normalization_options=NormalizationOptions())
|
||||
== "Server one hundred and twenty-seven dot zero dot zero dot one"
|
||||
)
|
||||
|
||||
|
||||
def test_url_raw_domains():
|
||||
"""Test raw domain handling"""
|
||||
assert (
|
||||
normalize_text(
|
||||
"Visit google.com/search", normalization_options=NormalizationOptions()
|
||||
)
|
||||
== "Visit google dot com slash search"
|
||||
)
|
||||
assert (
|
||||
normalize_text(
|
||||
"Go to example.com/path?q=test",
|
||||
normalization_options=NormalizationOptions(),
|
||||
)
|
||||
== "Go to example dot com slash path question-mark q equals test"
|
||||
)
|
||||
assert (
|
||||
normalize_text(
|
||||
"Check docs.test.com", normalization_options=NormalizationOptions()
|
||||
)
|
||||
== "Check docs dot test dot com"
|
||||
)
|
||||
|
||||
|
||||
def test_url_email_addresses():
|
||||
"""Test email address handling"""
|
||||
assert (
|
||||
normalize_text(
|
||||
"Email me at user@example.com", normalization_options=NormalizationOptions()
|
||||
)
|
||||
== "Email me at user at example dot com"
|
||||
)
|
||||
assert (
|
||||
normalize_text(
|
||||
"Contact admin@test.org", normalization_options=NormalizationOptions()
|
||||
)
|
||||
== "Contact admin at test dot org"
|
||||
)
|
||||
assert (
|
||||
normalize_text(
|
||||
"Send to test.user@site.com", normalization_options=NormalizationOptions()
|
||||
)
|
||||
== "Send to test dot user at site dot com"
|
||||
)
|
||||
|
||||
|
||||
def test_money():
|
||||
"""Test that money text is normalized correctly"""
|
||||
assert (
|
||||
normalize_text(
|
||||
"He lost $5.3 thousand.", normalization_options=NormalizationOptions()
|
||||
)
|
||||
== "He lost five point three thousand dollars."
|
||||
)
|
||||
|
||||
assert (
|
||||
normalize_text(
|
||||
"He went gambling and lost about $25.05k.",
|
||||
normalization_options=NormalizationOptions(),
|
||||
)
|
||||
== "He went gambling and lost about twenty-five point zero five thousand dollars."
|
||||
)
|
||||
|
||||
assert (
|
||||
normalize_text(
|
||||
"To put it weirdly -$6.9 million",
|
||||
normalization_options=NormalizationOptions(),
|
||||
)
|
||||
== "To put it weirdly minus six point nine million dollars"
|
||||
)
|
||||
|
||||
assert (
|
||||
normalize_text("It costs $50.3.", normalization_options=NormalizationOptions())
|
||||
== "It costs fifty dollars and thirty cents."
|
||||
)
|
||||
|
||||
assert (
|
||||
normalize_text(
|
||||
"The plant cost $200,000.8.", normalization_options=NormalizationOptions()
|
||||
)
|
||||
== "The plant cost two hundred thousand dollars and eighty cents."
|
||||
)
|
||||
|
||||
assert (
|
||||
normalize_text(
|
||||
"€30.2 is in euros", normalization_options=NormalizationOptions()
|
||||
)
|
||||
== "thirty euros and twenty cents is in euros"
|
||||
)
|
||||
|
||||
|
||||
def test_time():
|
||||
"""Test time normalization"""
|
||||
|
||||
assert (
|
||||
normalize_text(
|
||||
"Your flight leaves at 10:35 pm",
|
||||
normalization_options=NormalizationOptions(),
|
||||
)
|
||||
== "Your flight leaves at ten thirty-five pm"
|
||||
)
|
||||
|
||||
assert (
|
||||
normalize_text(
|
||||
"He departed for london around 5:03 am.",
|
||||
normalization_options=NormalizationOptions(),
|
||||
)
|
||||
== "He departed for london around five oh three am."
|
||||
)
|
||||
|
||||
assert (
|
||||
normalize_text(
|
||||
"Only the 13:42 and 15:12 slots are available.",
|
||||
normalization_options=NormalizationOptions(),
|
||||
)
|
||||
== "Only the thirteen forty-two and fifteen twelve slots are available."
|
||||
)
|
||||
|
||||
assert (
|
||||
normalize_text(
|
||||
"It is currently 1:00 pm", normalization_options=NormalizationOptions()
|
||||
)
|
||||
== "It is currently one pm"
|
||||
)
|
||||
|
||||
assert (
|
||||
normalize_text(
|
||||
"It is currently 3:00", normalization_options=NormalizationOptions()
|
||||
)
|
||||
== "It is currently three o'clock"
|
||||
)
|
||||
|
||||
assert (
|
||||
normalize_text(
|
||||
"12:00 am is midnight", normalization_options=NormalizationOptions()
|
||||
)
|
||||
== "twelve am is midnight"
|
||||
)
|
||||
|
||||
|
||||
def test_number():
|
||||
"""Test number normalization"""
|
||||
|
||||
assert (
|
||||
normalize_text(
|
||||
"I bought 1035 cans of soda", normalization_options=NormalizationOptions()
|
||||
)
|
||||
== "I bought one thousand and thirty-five cans of soda"
|
||||
)
|
||||
|
||||
assert (
|
||||
normalize_text(
|
||||
"The bus has a maximum capacity of 62 people",
|
||||
normalization_options=NormalizationOptions(),
|
||||
)
|
||||
== "The bus has a maximum capacity of sixty-two people"
|
||||
)
|
||||
|
||||
assert (
|
||||
normalize_text(
|
||||
"There are 1300 products left in stock",
|
||||
normalization_options=NormalizationOptions(),
|
||||
)
|
||||
== "There are one thousand, three hundred products left in stock"
|
||||
)
|
||||
|
||||
assert (
|
||||
normalize_text(
|
||||
"The population is 7,890,000 people.",
|
||||
normalization_options=NormalizationOptions(),
|
||||
)
|
||||
== "The population is seven million, eight hundred and ninety thousand people."
|
||||
)
|
||||
|
||||
assert (
|
||||
normalize_text(
|
||||
"He looked around but only found 1.6k of the 10k bricks",
|
||||
normalization_options=NormalizationOptions(),
|
||||
)
|
||||
== "He looked around but only found one point six thousand of the ten thousand bricks"
|
||||
)
|
||||
|
||||
assert (
|
||||
normalize_text(
|
||||
"The book has 342 pages.", normalization_options=NormalizationOptions()
|
||||
)
|
||||
== "The book has three hundred and forty-two pages."
|
||||
)
|
||||
|
||||
assert (
|
||||
normalize_text(
|
||||
"He made -50 sales today.", normalization_options=NormalizationOptions()
|
||||
)
|
||||
== "He made minus fifty sales today."
|
||||
)
|
||||
|
||||
assert (
|
||||
normalize_text(
|
||||
"56.789 to the power of 1.35 million",
|
||||
normalization_options=NormalizationOptions(),
|
||||
)
|
||||
== "fifty-six point seven eight nine to the power of one point three five million"
|
||||
)
|
||||
|
||||
|
||||
def test_non_url_text():
|
||||
"""Test that non-URL text is unaffected"""
|
||||
assert (
|
||||
normalize_text(
|
||||
"This is not.a.url text", normalization_options=NormalizationOptions()
|
||||
)
|
||||
== "This is not-a-url text"
|
||||
)
|
||||
assert (
|
||||
normalize_text(
|
||||
"Hello, how are you today?", normalization_options=NormalizationOptions()
|
||||
)
|
||||
== "Hello, how are you today?"
|
||||
)
|
||||
assert (
|
||||
normalize_text("It costs $50.", normalization_options=NormalizationOptions())
|
||||
== "It costs fifty dollars."
|
||||
)
|
|
@ -1,499 +0,0 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from typing import AsyncGenerator, Tuple
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from api.src.core.config import settings
|
||||
from api.src.inference.base import AudioChunk
|
||||
from api.src.main import app
|
||||
from api.src.routers.openai_compatible import (
|
||||
get_tts_service,
|
||||
load_openai_mappings,
|
||||
stream_audio_chunks,
|
||||
)
|
||||
from api.src.services.streaming_audio_writer import StreamingAudioWriter
|
||||
from api.src.services.tts_service import TTSService
|
||||
from api.src.structures.schemas import OpenAISpeechRequest
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_voice():
|
||||
"""Fixture providing a test voice name."""
|
||||
return "test_voice"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_mappings():
|
||||
"""Mock OpenAI mappings for testing."""
|
||||
with patch(
|
||||
"api.src.routers.openai_compatible._openai_mappings",
|
||||
{
|
||||
"models": {"tts-1": "kokoro-v1_0", "tts-1-hd": "kokoro-v1_0"},
|
||||
"voices": {"alloy": "am_adam", "nova": "bf_isabella"},
|
||||
},
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_json_file(tmp_path):
|
||||
"""Create a temporary mock JSON file."""
|
||||
content = {
|
||||
"models": {"test-model": "test-kokoro"},
|
||||
"voices": {"test-voice": "test-internal"},
|
||||
}
|
||||
json_file = tmp_path / "test_mappings.json"
|
||||
json_file.write_text(json.dumps(content))
|
||||
return json_file
|
||||
|
||||
|
||||
def test_load_openai_mappings(mock_json_file):
|
||||
"""Test loading OpenAI mappings from JSON file"""
|
||||
with patch("os.path.join", return_value=str(mock_json_file)):
|
||||
mappings = load_openai_mappings()
|
||||
assert "models" in mappings
|
||||
assert "voices" in mappings
|
||||
assert mappings["models"]["test-model"] == "test-kokoro"
|
||||
assert mappings["voices"]["test-voice"] == "test-internal"
|
||||
|
||||
|
||||
def test_load_openai_mappings_file_not_found():
    """A path that does not exist yields empty model and voice tables."""
    expected_fallback = {"models": {}, "voices": {}}
    with patch("os.path.join", return_value="/nonexistent/path"):
        loaded = load_openai_mappings()
        assert loaded == expected_fallback
|
||||
|
||||
|
||||
def test_list_models(mock_openai_mappings):
    """GET /v1/models returns the three expected models as an OpenAI-style list."""
    resp = client.get("/v1/models")
    assert resp.status_code == 200

    body = resp.json()
    assert body["object"] == "list"
    entries = body["data"]
    assert isinstance(entries, list)
    assert len(entries) == 3  # tts-1, tts-1-hd, and kokoro

    # Every expected model id must be listed
    listed_ids = [entry["id"] for entry in entries]
    for expected_id in ("tts-1", "tts-1-hd", "kokoro"):
        assert expected_id in listed_ids

    # Each entry must follow the model object format
    for entry in entries:
        assert entry["object"] == "model"
        assert "created" in entry
        assert entry["owned_by"] == "kokoro"
|
||||
|
||||
|
||||
def test_retrieve_model(mock_openai_mappings):
    """GET /v1/models/{id} returns a known model and a 404 error body otherwise."""
    # Known model: full model object comes back
    found = client.get("/v1/models/tts-1")
    assert found.status_code == 200
    model = found.json()
    assert model["id"] == "tts-1"
    assert model["object"] == "model"
    assert model["owned_by"] == "kokoro"
    assert "created" in model

    # Unknown model: structured not-found error
    missing = client.get("/v1/models/nonexistent-model")
    assert missing.status_code == 404
    detail = missing.json()["detail"]
    assert detail["error"] == "model_not_found"
    assert "not found" in detail["message"]
    assert detail["type"] == "invalid_request_error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_tts_service_initialization():
    """Five concurrent ``get_tts_service()`` calls create the service exactly once."""
    with patch("api.src.routers.openai_compatible._tts_service", None), patch(
        "api.src.routers.openai_compatible._init_lock", None
    ), patch("api.src.services.tts_service.TTSService.create") as create_mock:
        shared_service = AsyncMock()
        create_mock.return_value = shared_service

        # Launch several lookups at once and wait for all of them
        pending = [get_tts_service() for _ in range(5)]
        resolved = await asyncio.gather(*pending)

        # One creation only, and every caller received that same service
        create_mock.assert_called_once()
        assert all(service == shared_service for service in resolved)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_stream_audio_chunks_client_disconnect():
    """No chunks are yielded when the request reports the client as disconnected."""

    async def five_chunk_stream(*args, **kwargs):
        for _ in range(5):
            yield AudioChunk(np.ndarray([], np.int16), output=b"chunk")

    # Request whose disconnect check always answers True
    disconnected_request = MagicMock()
    disconnected_request.is_disconnected = AsyncMock(return_value=True)

    service = AsyncMock()
    service.generate_audio_stream = five_chunk_stream
    service.list_voices.return_value = ["test_voice"]

    speech_request = OpenAISpeechRequest(
        model="kokoro",
        input="Test text",
        voice="test_voice",
        response_format="mp3",
        stream=True,
        speed=1.0,
    )
    audio_writer = StreamingAudioWriter("mp3", 24000)

    received = [
        chunk
        async for chunk in stream_audio_chunks(
            service, speech_request, disconnected_request, audio_writer
        )
    ]

    audio_writer.close()

    assert len(received) == 0  # Should stop immediately due to disconnect
|
||||
|
||||
|
||||
def test_openai_voice_mapping(mock_tts_service, mock_openai_mappings):
    """The OpenAI voice name ``alloy`` reaches the service as ``am_adam``."""
    mock_tts_service.list_voices.return_value = ["am_adam", "bf_isabella"]

    payload = {
        "model": "tts-1",
        "input": "Hello world",
        "voice": "alloy",  # OpenAI voice name
        "response_format": "mp3",
        "stream": False,
    }
    resp = client.post("/v1/audio/speech", json=payload)

    assert resp.status_code == 200
    mock_tts_service.generate_audio.assert_called_once()
    # call_args[1] holds the keyword arguments of the recorded call
    passed_kwargs = mock_tts_service.generate_audio.call_args[1]
    assert passed_kwargs["voice"] == "am_adam"
|
||||
|
||||
|
||||
def test_openai_voice_mapping_streaming(
    mock_tts_service, mock_openai_mappings, mock_audio_bytes
):
    """A streaming request using an OpenAI voice name returns the mocked audio bytes."""
    mock_tts_service.list_voices.return_value = ["am_adam", "bf_isabella"]

    payload = {
        "model": "tts-1-hd",
        "input": "Hello world",
        "voice": "nova",  # OpenAI voice name
        "response_format": "mp3",
        "stream": True,
    }
    resp = client.post("/v1/audio/speech", json=payload)

    assert resp.status_code == 200
    # Reassemble the streamed body before comparing it
    streamed = b"".join(resp.iter_bytes())
    assert streamed == mock_audio_bytes
|
||||
|
||||
|
||||
def test_invalid_openai_model(mock_tts_service, mock_openai_mappings):
    """An unknown model name is rejected with a 400 ``invalid_model`` error."""
    payload = {
        "model": "invalid-model",
        "input": "Hello world",
        "voice": "alloy",
        "response_format": "mp3",
        "stream": False,
    }
    resp = client.post("/v1/audio/speech", json=payload)

    assert resp.status_code == 400
    detail = resp.json()["detail"]
    assert detail["error"] == "invalid_model"
    assert "Unsupported model" in detail["message"]
|
||||
|
||||
|
||||
@pytest.fixture
def mock_audio_bytes():
    """Provide the placeholder byte string the tests treat as encoded audio."""
    placeholder = b"mock audio data"
    return placeholder
|
||||
|
||||
|
||||
@pytest.fixture
def mock_tts_service(mock_audio_bytes):
    """Yield an ``AsyncMock`` TTSService returned by the patched router ``get_tts_service``."""

    async def single_chunk_stream(*args, **kwargs) -> AsyncGenerator[AudioChunk, None]:
        # One chunk whose encoded output is the placeholder audio bytes
        yield AudioChunk(np.ndarray([], np.int16), output=mock_audio_bytes)

    with patch("api.src.routers.openai_compatible.get_tts_service") as getter:
        fake_service = AsyncMock(spec=TTSService)
        fake_service.generate_audio.return_value = AudioChunk(np.zeros(1000, np.int16))
        fake_service.generate_audio_stream = single_chunk_stream
        fake_service.list_voices.return_value = ["test_voice", "voice1", "voice2"]
        fake_service.combine_voices.return_value = "voice1_voice2"

        getter.return_value = fake_service
        getter.side_effect = None
        yield fake_service
|
||||
|
||||
|
||||
@patch("api.src.services.audio.AudioService.convert_audio")
def test_openai_speech_endpoint(
    mock_convert, mock_tts_service, test_voice, mock_audio_bytes
):
    """A non-streaming MP3 request returns the bytes produced by the patched converter."""
    # The service supplies raw samples; the patched converter supplies encoded bytes
    mock_tts_service.generate_audio.return_value = AudioChunk(np.zeros(1000, np.int16))
    converted = AudioChunk(np.zeros(1000, np.int16), output=mock_audio_bytes)
    mock_convert.return_value = converted

    payload = {
        "model": "kokoro",
        "input": "Hello world",
        "voice": test_voice,
        "response_format": "mp3",
        "stream": False,
    }
    resp = client.post("/v1/audio/speech", json=payload)

    assert resp.status_code == 200
    assert resp.headers["content-type"] == "audio/mpeg"
    assert len(resp.content) > 0
    # The body is expected to hold the converter output twice, matching two calls
    assert resp.content == mock_audio_bytes + mock_audio_bytes

    mock_tts_service.generate_audio.assert_called_once()
    assert mock_convert.call_count == 2
|
||||
|
||||
|
||||
def test_openai_speech_streaming(mock_tts_service, test_voice, mock_audio_bytes):
    """A streaming MP3 request is sent chunked and carries the mocked audio bytes."""
    payload = {
        "model": "kokoro",
        "input": "Hello world",
        "voice": test_voice,
        "response_format": "mp3",
        "stream": True,
    }
    resp = client.post("/v1/audio/speech", json=payload)

    assert resp.status_code == 200
    headers = resp.headers
    assert headers["content-type"] == "audio/mpeg"
    assert "Transfer-Encoding" in headers
    assert headers["Transfer-Encoding"] == "chunked"

    streamed = b"".join(resp.iter_bytes())
    assert streamed == mock_audio_bytes
|
||||
|
||||
|
||||
def test_openai_speech_pcm_streaming(mock_tts_service, test_voice, mock_audio_bytes):
    """A streaming PCM request answers with ``audio/pcm`` and the mocked audio bytes."""
    payload = {
        "model": "kokoro",
        "input": "Hello world",
        "voice": test_voice,
        "response_format": "pcm",
        "stream": True,
    }
    resp = client.post("/v1/audio/speech", json=payload)

    assert resp.status_code == 200
    assert resp.headers["content-type"] == "audio/pcm"

    streamed = b"".join(resp.iter_bytes())
    assert streamed == mock_audio_bytes
|
||||
|
||||
|
||||
def test_openai_speech_invalid_voice(mock_tts_service):
    """A ``ValueError`` from audio generation surfaces as a 400 validation error."""
    voice_error = ValueError("Voice 'invalid_voice' not found")
    mock_tts_service.generate_audio.side_effect = voice_error

    payload = {
        "model": "kokoro",
        "input": "Hello world",
        "voice": "invalid_voice",
        "response_format": "mp3",
        "stream": False,
    }
    resp = client.post("/v1/audio/speech", json=payload)

    assert resp.status_code == 400
    detail = resp.json()["detail"]
    assert detail["error"] == "validation_error"
    assert "Voice 'invalid_voice' not found" in detail["message"]
    assert detail["type"] == "invalid_request_error"
|
||||
|
||||
|
||||
def test_openai_speech_empty_text(mock_tts_service, test_voice):
    """An empty-text ``ValueError`` from generation is reported as a 400 validation error."""

    async def raise_empty_text(*args, **kwargs):
        raise ValueError("Text is empty after preprocessing")

    mock_tts_service.generate_audio = raise_empty_text
    mock_tts_service.list_voices.return_value = ["test_voice"]

    payload = {
        "model": "kokoro",
        "input": "",
        "voice": test_voice,
        "response_format": "mp3",
        "stream": False,
    }
    resp = client.post("/v1/audio/speech", json=payload)

    assert resp.status_code == 400
    detail = resp.json()["detail"]
    assert detail["error"] == "validation_error"
    assert "Text is empty after preprocessing" in detail["message"]
    assert detail["type"] == "invalid_request_error"
|
||||
|
||||
|
||||
def test_openai_speech_invalid_format(mock_tts_service, test_voice):
    """An unsupported ``response_format`` value is rejected with HTTP 422."""
    payload = {
        "model": "kokoro",
        "input": "Hello world",
        "voice": test_voice,
        "response_format": "invalid_format",
        "stream": False,
    }
    resp = client.post("/v1/audio/speech", json=payload)
    assert resp.status_code == 422  # Validation error from Pydantic
|
||||
|
||||
|
||||
def test_list_voices(mock_tts_service):
    """GET /v1/audio/voices returns exactly the voices the service reports."""
    # Replace the fixture's default voice list for this test
    mock_tts_service.list_voices.return_value = ["voice1", "voice2"]

    resp = client.get("/v1/audio/voices")
    assert resp.status_code == 200

    body = resp.json()
    assert "voices" in body
    voices = body["voices"]
    assert len(voices) == 2
    for expected in ("voice1", "voice2"):
        assert expected in voices
|
||||
|
||||
|
||||
@patch("api.src.routers.openai_compatible.settings")
def test_combine_voices(mock_settings, mock_tts_service):
    """Combining two voices returns a downloadable ``.pt`` attachment."""
    # Local voice saving is switched on for this test
    mock_settings.allow_local_voice_saving = True

    resp = client.post("/v1/audio/voices/combine", json="voice1+voice2")

    assert resp.status_code == 200
    headers = resp.headers
    assert headers["content-type"] == "application/octet-stream"
    assert "voice1+voice2.pt" in headers["content-disposition"]
|
||||
|
||||
|
||||
def test_server_error(mock_tts_service, test_voice):
    """A ``RuntimeError`` from audio generation is reported as a 500 server error."""

    async def raise_internal_error(*args, **kwargs):
        raise RuntimeError("Internal server error")

    mock_tts_service.generate_audio = raise_internal_error
    mock_tts_service.list_voices.return_value = ["test_voice"]

    payload = {
        "model": "kokoro",
        "input": "Hello world",
        "voice": test_voice,
        "response_format": "mp3",
        "stream": False,
    }
    resp = client.post("/v1/audio/speech", json=payload)

    assert resp.status_code == 500
    detail = resp.json()["detail"]
    assert detail["error"] == "processing_error"
    assert detail["type"] == "server_error"
|
||||
|
||||
|
||||
def test_streaming_error(mock_tts_service, test_voice):
    """A ``RuntimeError`` raised by ``list_voices`` on a streaming request gives a 500."""
    # list_voices raises, so the request fails before any audio is streamed
    mock_tts_service.list_voices.side_effect = RuntimeError("Streaming failed")

    payload = {
        "model": "kokoro",
        "input": "Hello world",
        "voice": test_voice,
        "response_format": "mp3",
        "stream": True,
    }
    resp = client.post("/v1/audio/speech", json=payload)

    assert resp.status_code == 500
    detail = resp.json()["detail"]
    assert detail["error"] == "processing_error"
    assert detail["type"] == "server_error"
    assert "Streaming failed" in detail["message"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_streaming_initialization_error():
    """An error raised when the audio stream starts propagates out of ``stream_audio_chunks``."""

    async def failing_stream(*args, **kwargs):
        raise RuntimeError("Failed to initialize stream")
        yield b""  # never reached; its presence makes this an async generator

    service = AsyncMock()
    service.generate_audio_stream = failing_stream
    service.list_voices.return_value = ["test_voice"]

    speech_request = OpenAISpeechRequest(
        model="kokoro",
        input="Test text",
        voice="test_voice",
        response_format="mp3",
        stream=True,
        speed=1.0,
    )
    audio_writer = StreamingAudioWriter("mp3", 24000)

    with pytest.raises(RuntimeError) as caught:
        async for _ in stream_audio_chunks(
            service, speech_request, MagicMock(), audio_writer
        ):
            pass

    audio_writer.close()
    assert "Failed to initialize stream" in str(caught.value)
|
|
@ -1,138 +0,0 @@
|
|||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.src.core.paths import (
|
||||
_find_file,
|
||||
_scan_directories,
|
||||
get_content_type,
|
||||
get_temp_dir_size,
|
||||
get_temp_file_path,
|
||||
list_temp_files,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_find_file_exists():
    """``_find_file`` returns the joined path when the existence check succeeds."""
    with patch("aiofiles.os.path.exists") as exists_mock:
        exists_mock.return_value = True
        found = await _find_file("test.txt", ["/test/path"])
        assert found == "/test/path/test.txt"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_find_file_not_exists():
    """``_find_file`` raises ``FileNotFoundError`` when no search path has the file."""
    with patch("aiofiles.os.path.exists") as exists_mock:
        exists_mock.return_value = False
        with pytest.raises(FileNotFoundError, match="File not found"):
            await _find_file("test.txt", ["/test/path"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_find_file_with_filter():
    """``_find_file`` accepts a filter callable and still returns the matching path."""

    def is_txt(candidate):
        return candidate.endswith(".txt")

    with patch("aiofiles.os.path.exists") as exists_mock:
        exists_mock.return_value = True
        found = await _find_file("test.txt", ["/test/path"], is_txt)
        assert found == "/test/path/test.txt"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_scan_directories():
    """``_scan_directories`` reports the names of the entries the scan returns."""

    class MockEntry:
        # Stand-in directory entry carrying only a name
        name = "test.txt"

    entry = MockEntry()

    with (
        patch("aiofiles.os.path.exists") as exists_mock,
        patch("aiofiles.os.scandir") as scandir_mock,
    ):
        exists_mock.return_value = True
        scandir_mock.return_value = [entry]

        names = await _scan_directories(["/test/path"])
        assert "test.txt" in names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_content_type():
    """Known extensions map to their MIME type; an unknown one maps to octet-stream."""
    expected_by_filename = {
        "test.html": "text/html",
        "test.js": "application/javascript",
        "test.css": "text/css",
        "test.png": "image/png",
        "test.unknown": "application/octet-stream",
    }

    for filename, expected_type in expected_by_filename.items():
        detected = await get_content_type(filename)
        assert detected == expected_type
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_temp_file_path():
    """The temp path contains the file name, and the directory is created when absent."""
    with (
        patch("aiofiles.os.path.exists") as exists_mock,
        patch("aiofiles.os.makedirs") as makedirs_mock,
    ):
        # Report the directory as missing so that it has to be created
        exists_mock.return_value = False

        temp_path = await get_temp_file_path("test.wav")
        assert "test.wav" in temp_path
        makedirs_mock.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_list_temp_files():
    """``list_temp_files`` includes the name of a file entry returned by the scan."""

    class MockEntry:
        """Stand-in directory entry that always reports being a file."""

        def __init__(self, name):
            self.name = name

        def is_file(self):
            return True

    wav_entry = MockEntry("test.wav")

    with (
        patch("aiofiles.os.path.exists") as exists_mock,
        patch("aiofiles.os.scandir") as scandir_mock,
    ):
        exists_mock.return_value = True
        scandir_mock.return_value = [wav_entry]

        listed = await list_temp_files()
        assert "test.wav" in listed
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_temp_dir_size():
    """``get_temp_dir_size`` returns the stat size of the single mocked file."""

    class MockEntry:
        """Stand-in directory entry with a path that always reports being a file."""

        def __init__(self, path):
            self.path = path

        def is_file(self):
            return True

    class MockStat:
        # Stat result stand-in exposing only the size
        st_size = 1024

    wav_entry = MockEntry("/tmp/test.wav")
    stat_result = MockStat()

    with (
        patch("aiofiles.os.path.exists") as exists_mock,
        patch("aiofiles.os.scandir") as scandir_mock,
        patch("aiofiles.os.stat") as stat_mock,
    ):
        exists_mock.return_value = True
        scandir_mock.return_value = [wav_entry]
        stat_mock.return_value = stat_result

        total = await get_temp_dir_size()
        assert total == 1024
|
|
@ -1,167 +0,0 @@
|
|||
import pytest
|
||||
|
||||
from api.src.services.text_processing.text_processor import (
|
||||
get_sentence_info,
|
||||
process_text_chunk,
|
||||
smart_split,
|
||||
)
|
||||
|
||||
|
||||
def test_process_text_chunk_basic():
    """Plain text is turned into a non-empty list of tokens."""
    token_ids = process_text_chunk("Hello world")
    assert isinstance(token_ids, list)
    assert len(token_ids) > 0
|
||||
|
||||
|
||||
def test_process_text_chunk_empty():
    """An empty string produces an empty token list."""
    token_ids = process_text_chunk("")
    assert isinstance(token_ids, list)
    assert len(token_ids) == 0
|
||||
|
||||
|
||||
def test_process_text_chunk_phonemes():
    """Input passed with ``skip_phonemize=True`` still yields a non-empty token list."""
    phoneme_text = "h @ l @U"  # Example phoneme sequence
    token_ids = process_text_chunk(phoneme_text, skip_phonemize=True)
    assert isinstance(token_ids, list)
    assert len(token_ids) > 0
|
||||
|
||||
|
||||
def test_get_sentence_info():
    """Three sentences give three (sentence, tokens, count) results with matching counts."""
    sample = "This is sentence one. This is sentence two! What about three?"
    infos = get_sentence_info(sample, {})

    assert len(infos) == 3
    for sentence_text, token_ids, token_count in infos:
        assert isinstance(sentence_text, str)
        assert isinstance(token_ids, list)
        assert isinstance(token_count, int)
        assert token_count == len(token_ids)
        assert token_count > 0
|
||||
|
||||
|
||||
def test_get_sentence_info_phenomoes():
    """A custom-phoneme placeholder is replaced by its phonemes in the second sentence."""
    placeholder_map = {"</|custom_phonemes_0|/>": r"sˈɛntᵊns"}
    sample = (
        "This is sentence one. This is </|custom_phonemes_0|/> two! What about three?"
    )
    infos = get_sentence_info(sample, placeholder_map)

    assert len(infos) == 3
    second_sentence = infos[1][0]
    assert "sˈɛntᵊns" in second_sentence
    for sentence_text, token_ids, token_count in infos:
        assert isinstance(sentence_text, str)
        assert isinstance(token_ids, list)
        assert isinstance(token_count, int)
        assert token_count == len(token_ids)
        assert token_count > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_smart_split_short_text():
    """A single short sentence comes back as exactly one (text, tokens) chunk."""
    sample = "This is a short test sentence."
    pieces = [
        (piece_text, piece_tokens)
        async for piece_text, piece_tokens in smart_split(sample)
    ]

    assert len(pieces) == 1
    only_text, only_tokens = pieces[0]
    assert isinstance(only_text, str)
    assert isinstance(only_tokens, list)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_smart_split_long_text():
    """Twenty joined sentences are split into several non-empty chunks."""
    # Build text long enough that it should not fit into one chunk
    sentences = ["This is test sentence number " + str(i) for i in range(20)]
    sample = ". ".join(sentences)

    pieces = [
        (piece_text, piece_tokens)
        async for piece_text, piece_tokens in smart_split(sample)
    ]

    assert len(pieces) > 1
    for piece_text, piece_tokens in pieces:
        assert isinstance(piece_text, str)
        assert isinstance(piece_tokens, list)
        assert len(piece_tokens) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_smart_split_with_punctuation():
    """Every chunk produced from punctuated text still contains a punctuation mark."""
    sample = "First sentence! Second sentence? Third sentence; Fourth sentence: Fifth sentence."

    piece_texts = [piece_text async for piece_text, _ in smart_split(sample)]

    # Each chunk must keep at least one of the punctuation characters
    for piece_text in piece_texts:
        assert any(mark in piece_text for mark in "!?;:.")
|
||||
|
||||
def test_process_text_chunk_chinese_phonemes():
    """Pinyin passed with ``skip_phonemize=True`` and language ``z`` yields tokens."""
    pinyin_text = "nǐ hǎo lì"  # Example pinyin sequence with tones
    token_ids = process_text_chunk(pinyin_text, skip_phonemize=True, language="z")
    assert isinstance(token_ids, list)
    assert len(token_ids) > 0
|
||||
|
||||
|
||||
def test_get_sentence_info_chinese():
    """Three Chinese sentences give three results whose counts match their tokens."""
    sample = "这是一个句子。这是第二个句子!第三个问题?"
    infos = get_sentence_info(sample, {}, lang_code="z")

    assert len(infos) == 3
    for sentence_text, token_ids, token_count in infos:
        assert isinstance(sentence_text, str)
        assert isinstance(token_ids, list)
        assert isinstance(token_count, int)
        assert token_count == len(token_ids)
        assert token_count > 0
|
||||
|
||||
@pytest.mark.asyncio
async def test_smart_split_chinese_short():
    """A single short Chinese sentence comes back as exactly one chunk."""
    sample = "这是一句话。"
    pieces = [
        (piece_text, piece_tokens)
        async for piece_text, piece_tokens in smart_split(sample, lang_code="z")
    ]

    assert len(pieces) == 1
    only_text, only_tokens = pieces[0]
    assert isinstance(only_text, str)
    assert isinstance(only_tokens, list)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_smart_split_chinese_long():
    """Twenty joined Chinese sentences are split into several non-empty chunks."""
    sentences = [f"测试句子 {i}" for i in range(20)]
    sample = "。".join(sentences)

    pieces = [
        (piece_text, piece_tokens)
        async for piece_text, piece_tokens in smart_split(sample, lang_code="z")
    ]

    assert len(pieces) > 1
    for piece_text, piece_tokens in pieces:
        assert isinstance(piece_text, str)
        assert isinstance(piece_tokens, list)
        assert len(piece_tokens) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_smart_split_chinese_punctuation():
    """Every chunk produced from punctuated Chinese text keeps a punctuation mark."""
    sample = "第一句!第二问?第三句;第四句:第五句。"

    piece_texts = [
        piece_text async for piece_text, _ in smart_split(sample, lang_code="z")
    ]

    # Each chunk must keep at least one of the Chinese punctuation characters
    for piece_text in piece_texts:
        assert any(mark in piece_text for mark in "!?;:。")
|
|
@ -1,126 +0,0 @@
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from api.src.services.tts_service import TTSService
|
||||
|
||||
|
||||
@pytest.fixture
def mock_managers():
    """Return a coroutine that builds mocked model and voice managers.

    NOTE(review): the fixture hands back the un-awaited coroutine object, and
    the ``patch`` context inside it exits as soon as the coroutine returns, so
    the patches are no longer active when the managers are used. Confirm
    whether they were meant to stay in place.
    """

    async def _build_managers():
        voices = AsyncMock()
        voices.get_voice_path.return_value = "/path/to/voice.pt"
        voices.list_voices.return_value = ["voice1", "voice2"]

        models = AsyncMock()
        models.get_backend.return_value = MagicMock()

        with (
            patch("api.src.services.tts_service.get_model_manager") as model_getter,
            patch("api.src.services.tts_service.get_voice_manager") as voice_getter,
        ):
            model_getter.return_value = models
            voice_getter.return_value = voices
            return models, voices

    return _build_managers()
|
||||
|
||||
|
||||
@pytest.fixture
def tts_service(mock_managers):
    """Return a coroutine that creates a ``TTSService`` for ``"test_output"``.

    The coroutine is returned un-awaited; a consumer has to await it to get
    the service instance.
    """

    async def _make_service():
        service = await TTSService.create("test_output")
        return service

    return _make_service()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_service_creation():
    """``TTSService.create`` stores the output directory and both patched managers."""
    models = AsyncMock()
    voices = AsyncMock()

    with (
        patch("api.src.services.tts_service.get_model_manager") as model_getter,
        patch("api.src.services.tts_service.get_voice_manager") as voice_getter,
    ):
        model_getter.return_value = models
        voice_getter.return_value = voices

        created = await TTSService.create("test_output")
        assert created.output_dir == "test_output"
        assert created.model_manager is models
        assert created._voice_manager is voices
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_voice_path_single():
    """A single voice name resolves to the path reported by the voice manager."""
    models = AsyncMock()
    voices = AsyncMock()
    voices.get_voice_path.return_value = "/path/to/voice1.pt"

    with (
        patch("api.src.services.tts_service.get_model_manager") as model_getter,
        patch("api.src.services.tts_service.get_voice_manager") as voice_getter,
    ):
        model_getter.return_value = models
        voice_getter.return_value = voices

        created = await TTSService.create("test_output")
        resolved_name, resolved_path = await created._get_voices_path("voice1")
        assert resolved_name == "voice1"
        assert resolved_path == "/path/to/voice1.pt"
        voices.get_voice_path.assert_called_once_with("voice1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_voice_path_combined():
    """A ``voice1+voice2`` request saves one combined tensor and returns its ``.pt`` path."""
    models = AsyncMock()
    voices = AsyncMock()
    voices.get_voice_path.return_value = "/path/to/voice.pt"

    with (
        patch("api.src.services.tts_service.get_model_manager") as model_getter,
        patch("api.src.services.tts_service.get_voice_manager") as voice_getter,
        patch("torch.load") as load_mock,
        patch("torch.save") as save_mock,
        patch("tempfile.gettempdir") as tempdir_mock,
    ):
        model_getter.return_value = models
        voice_getter.return_value = voices
        tempdir_mock.return_value = "/tmp"
        # Each patched load returns the same ten-element tensor of ones
        load_mock.return_value = torch.ones(10)

        created = await TTSService.create("test_output")
        resolved_name, resolved_path = await created._get_voices_path("voice1+voice2")
        assert resolved_name == "voice1+voice2"
        assert resolved_path.endswith("voice1+voice2.pt")
        save_mock.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_list_voices():
    """``TTSService.list_voices`` returns the voice manager's list after one call to it."""
    expected_voices = ["voice1", "voice2"]
    models = AsyncMock()
    voices = AsyncMock()
    voices.list_voices.return_value = expected_voices

    with (
        patch("api.src.services.tts_service.get_model_manager") as model_getter,
        patch("api.src.services.tts_service.get_voice_manager") as voice_getter,
    ):
        model_getter.return_value = models
        voice_getter.return_value = voices

        created = await TTSService.create("test_output")
        listed = await created.list_voices()
        assert listed == ["voice1", "voice2"]
        voices.list_voices.assert_called_once()
|
|
@ -1,67 +1,59 @@
|
|||
FROM --platform=$BUILDPLATFORM nvidia/cuda:12.8.0-cudnn-runtime-ubuntu24.04
|
||||
FROM --platform=$BUILDPLATFORM nvidia/cuda:12.9.0-base-ubuntu24.04
|
||||
# Set non-interactive frontend
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Install Python and other dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
python3.10 \
|
||||
python3-venv \
|
||||
espeak-ng \
|
||||
espeak-ng-data \
|
||||
git \
|
||||
libsndfile1 \
|
||||
curl \
|
||||
ffmpeg \
|
||||
g++ \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/* \
|
||||
&& mkdir -p /usr/share/espeak-ng-data \
|
||||
&& ln -s /usr/lib/*/espeak-ng-data/* /usr/share/espeak-ng-data/
|
||||
|
||||
# Install UV using the installer script
|
||||
RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
||||
RUN apt-get update -y && \
|
||||
apt-get install -y python3 python3-venv libsndfile1 curl ffmpeg g++ && \
|
||||
apt-get clean && rm -rf /var/lib/apt/lists/* && \
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
||||
mv /root/.local/bin/uv /usr/local/bin/ && \
|
||||
mv /root/.local/bin/uvx /usr/local/bin/
|
||||
|
||||
# Create non-root user and set up directories and permissions
|
||||
RUN useradd -m -u 1001 appuser && \
|
||||
mv /root/.local/bin/uvx /usr/local/bin/ && \
|
||||
useradd -m -u 1001 appuser && \
|
||||
mkdir -p /app/api/src/models/v1_0 && \
|
||||
chown -R appuser:appuser /app
|
||||
|
||||
USER appuser
|
||||
WORKDIR /app
|
||||
|
||||
# Copy dependency files
|
||||
COPY --chown=appuser:appuser pyproject.toml ./pyproject.toml
|
||||
|
||||
ENV PHONEMIZER_ESPEAK_PATH=/usr/bin \
|
||||
PHONEMIZER_ESPEAK_DATA=/usr/share/espeak-ng-data \
|
||||
ESPEAK_DATA_PATH=/usr/share/espeak-ng-data
|
||||
PYTHONUNBUFFERED=1 \
|
||||
PYTHONPATH=/app:/app/api \
|
||||
PATH="/app/.venv/bin:$PATH" \
|
||||
UV_LINK_MODE=copy \
|
||||
USE_GPU=true \
|
||||
DEVICE="gpu"
|
||||
|
||||
# Install dependencies with GPU extras (using cache mounts)
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv venv --python 3.10 && \
|
||||
uv sync --extra gpu
|
||||
|
||||
# Copy project files including models
|
||||
COPY --chown=appuser:appuser pyproject.toml ./pyproject.toml
|
||||
COPY --chown=appuser:appuser api ./api
|
||||
COPY --chown=appuser:appuser web ./web
|
||||
COPY --chown=appuser:appuser docker/scripts/ ./
|
||||
RUN chmod +x ./entrypoint.sh
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv venv --python 3.10 && \
|
||||
uv sync --extra gpu && \
|
||||
uv cache clean && \
|
||||
python download_model.py --output api/src/models/v1_0
|
||||
|
||||
|
||||
# Set all environment variables in one go
|
||||
ENV PYTHONUNBUFFERED=1 \
|
||||
PYTHONPATH=/app:/app/api \
|
||||
PATH="/app/.venv/bin:$PATH" \
|
||||
UV_LINK_MODE=copy \
|
||||
USE_GPU=true
|
||||
|
||||
ENV DOWNLOAD_MODEL=true
|
||||
# Download model if enabled
|
||||
RUN if [ "$DOWNLOAD_MODEL" = "true" ]; then \
|
||||
python download_model.py --output api/src/models/v1_0; \
|
||||
fi
|
||||
|
||||
ENV DEVICE="gpu"
|
||||
# Run FastAPI server through entrypoint.sh
|
||||
CMD ["./entrypoint.sh"]
|
||||
|
||||
|
||||
|
||||
# If you want to test the docker image locally, run this from the project root:
|
||||
# docker build -f docker\gpu\Dockerfile -t kokoro .
|
||||
# Run it with
|
||||
# docker run --gpus all -p 8880:8880 --name kokoro kokoro
|
||||
#
|
||||
# You can log into the container with
|
||||
# docker exec -it kokoro /bin/bash
|
||||
#
|
||||
# Other commands:
|
||||
# 1. Stop and remove container
|
||||
# docker stop kokoro
|
||||
# docker container remove kokoro
|
||||
# 2. List and remove images
|
||||
# docker images
|
||||
# docker image remove kokoro
|
|
@ -31,10 +31,9 @@ dependencies = [
|
|||
"matplotlib>=3.10.0",
|
||||
"mutagen>=1.47.0",
|
||||
"psutil>=6.1.1",
|
||||
"espeakng-loader==0.2.4",
|
||||
"kokoro==0.9.2",
|
||||
"misaki[en,ja,ko,zh]==0.9.3",
|
||||
"spacy==3.8.5",
|
||||
"spacy==3.8.7",
|
||||
"en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl",
|
||||
"inflect>=7.5.0",
|
||||
"phonemizer-fork>=3.3.2",
|
||||
|
|
Loading…
Add table
Reference in a new issue