Dockerfile optimizations:

Moved to the nvidia/cuda "base" image flavor (12.9.0-base instead of 12.8.0-cudnn-runtime).
Merged all apt-get install commands into one.
Removed espeak-ng.
Added a uv cache clean command at the end.
Removed g++ at the end.
Concatenated all ENV commands.
Removed the api tests folder.
faltiska 2025-06-07 11:02:10 +03:00
parent 543cbecc1a
commit a3d23e9dad
14 changed files with 39 additions and 1845 deletions
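
Condensed from the new-side lines of the Dockerfile diff further down, the optimized image now builds around three steps: one merged apt-get/uv/user-setup RUN layer, one consolidated ENV block, and one cached uv sync layer that also runs uv cache clean and the model download. A rough sketch of that structure (intermediate COPY steps and the entrypoint omitted; the full diff has the exact contents):

FROM --platform=$BUILDPLATFORM nvidia/cuda:12.9.0-base-ubuntu24.04

# Single layer: system packages, uv, and the non-root user
RUN apt-get update -y && \
    apt-get install -y python3 python3-venv libsndfile1 curl ffmpeg g++ && \
    apt-get clean && rm -rf /var/lib/apt/lists/* && \
    curl -LsSf https://astral.sh/uv/install.sh | sh && \
    mv /root/.local/bin/uv /usr/local/bin/ && \
    mv /root/.local/bin/uvx /usr/local/bin/ && \
    useradd -m -u 1001 appuser && \
    mkdir -p /app/api/src/models/v1_0 && \
    chown -R appuser:appuser /app

# Single ENV block (the variables that appear on the new side of the diff)
ENV PYTHONUNBUFFERED=1 \
    PYTHONPATH=/app:/app/api \
    PATH="/app/.venv/bin:$PATH" \
    UV_LINK_MODE=copy \
    USE_GPU=true \
    DEVICE="gpu"

# Single cached dependency layer, ending with the cache clean and model download
RUN --mount=type=cache,target=/root/.cache/uv \
    uv venv --python 3.10 && \
    uv sync --extra gpu && \
    uv cache clean && \
    python download_model.py --output api/src/models/v1_0

Cleaning the apt lists in the same RUN layer that creates them keeps that data out of the image layers; the uv cache sits in a BuildKit cache mount, so uv cache clean trims that mount rather than the image itself.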


@@ -1 +0,0 @@
# Make tests directory a Python package


@@ -1,71 +0,0 @@
import os
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import numpy as np
import pytest
import pytest_asyncio
import torch

from api.src.inference.model_manager import ModelManager
from api.src.inference.voice_manager import VoiceManager
from api.src.services.tts_service import TTSService
from api.src.structures.model_schemas import VoiceConfig


@pytest.fixture
def mock_voice_tensor():
    """Load a real voice tensor for testing."""
    voice_path = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), "src/voices/af_bella.pt"
    )
    return torch.load(voice_path, map_location="cpu", weights_only=False)


@pytest.fixture
def mock_audio_output():
    """Load pre-generated test audio for consistent testing."""
    test_audio_path = os.path.join(
        os.path.dirname(__file__), "test_data/test_audio.npy"
    )
    return np.load(test_audio_path)  # Return as numpy array instead of bytes


@pytest_asyncio.fixture
async def mock_model_manager(mock_audio_output):
    """Mock model manager for testing."""
    manager = AsyncMock(spec=ModelManager)
    manager.get_backend = MagicMock()

    async def mock_generate(*args, **kwargs):
        # Simulate successful audio generation
        return np.random.rand(24000).astype(np.float32)  # 1 second of random audio data

    manager.generate = AsyncMock(side_effect=mock_generate)
    return manager


@pytest_asyncio.fixture
async def mock_voice_manager(mock_voice_tensor):
    """Mock voice manager for testing."""
    manager = AsyncMock(spec=VoiceManager)
    manager.get_voice_path = MagicMock(return_value="/mock/path/voice.pt")
    manager.load_voice = AsyncMock(return_value=mock_voice_tensor)
    manager.list_voices = AsyncMock(return_value=["voice1", "voice2"])
    manager.combine_voices = AsyncMock(return_value="voice1_voice2")
    return manager


@pytest_asyncio.fixture
async def tts_service(mock_model_manager, mock_voice_manager):
    """Get mocked TTS service instance."""
    service = TTSService()
    service.model_manager = mock_model_manager
    service._voice_manager = mock_voice_manager
    return service


@pytest.fixture
def test_voice():
    """Return a test voice name."""
    return "voice1"


@@ -1,256 +0,0 @@
"""Tests for AudioService"""
from unittest.mock import patch
import numpy as np
import pytest
from api.src.inference.base import AudioChunk
from api.src.services.audio import AudioNormalizer, AudioService
from api.src.services.streaming_audio_writer import StreamingAudioWriter
@pytest.fixture(autouse=True)
def mock_settings():
"""Mock settings for all tests"""
with patch("api.src.services.audio.settings") as mock_settings:
mock_settings.gap_trim_ms = 250
yield mock_settings
@pytest.fixture
def sample_audio():
"""Generate a simple sine wave for testing"""
sample_rate = 24000
duration = 0.1 # 100ms
t = np.linspace(0, duration, int(sample_rate * duration))
frequency = 440 # A4 note
return np.sin(2 * np.pi * frequency * t).astype(np.float32), sample_rate
@pytest.mark.asyncio
async def test_convert_to_wav(sample_audio):
"""Test converting to WAV format"""
audio_data, sample_rate = sample_audio
writer = StreamingAudioWriter("wav", sample_rate=24000)
# Write and finalize in one step for WAV
audio_chunk = await AudioService.convert_audio(
AudioChunk(audio_data), "wav", writer, is_last_chunk=False
)
writer.close()
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
assert len(audio_chunk.output) > 0
# Check WAV header
assert audio_chunk.output.startswith(b"RIFF")
assert b"WAVE" in audio_chunk.output[:12]
@pytest.mark.asyncio
async def test_convert_to_mp3(sample_audio):
"""Test converting to MP3 format"""
audio_data, sample_rate = sample_audio
writer = StreamingAudioWriter("mp3", sample_rate=24000)
audio_chunk = await AudioService.convert_audio(
AudioChunk(audio_data), "mp3", writer
)
writer.close()
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
assert len(audio_chunk.output) > 0
# Check MP3 header (ID3 or MPEG frame sync)
assert audio_chunk.output.startswith(b"ID3") or audio_chunk.output.startswith(
b"\xff\xfb"
)
@pytest.mark.asyncio
async def test_convert_to_opus(sample_audio):
"""Test converting to Opus format"""
audio_data, sample_rate = sample_audio
writer = StreamingAudioWriter("opus", sample_rate=24000)
audio_chunk = await AudioService.convert_audio(
AudioChunk(audio_data), "opus", writer
)
writer.close()
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
assert len(audio_chunk.output) > 0
# Check OGG header
assert audio_chunk.output.startswith(b"OggS")
@pytest.mark.asyncio
async def test_convert_to_flac(sample_audio):
"""Test converting to FLAC format"""
audio_data, sample_rate = sample_audio
writer = StreamingAudioWriter("flac", sample_rate=24000)
audio_chunk = await AudioService.convert_audio(
AudioChunk(audio_data), "flac", writer
)
writer.close()
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
assert len(audio_chunk.output) > 0
# Check FLAC header
assert audio_chunk.output.startswith(b"fLaC")
@pytest.mark.asyncio
async def test_convert_to_aac(sample_audio):
"""Test converting to M4A format"""
audio_data, sample_rate = sample_audio
writer = StreamingAudioWriter("aac", sample_rate=24000)
audio_chunk = await AudioService.convert_audio(
AudioChunk(audio_data), "aac", writer
)
writer.close()
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
assert len(audio_chunk.output) > 0
# Check ADTS header (AAC)
assert audio_chunk.output.startswith(b"\xff\xf0") or audio_chunk.output.startswith(
b"\xff\xf1"
)
@pytest.mark.asyncio
async def test_convert_to_pcm(sample_audio):
"""Test converting to PCM format"""
audio_data, sample_rate = sample_audio
writer = StreamingAudioWriter("pcm", sample_rate=24000)
audio_chunk = await AudioService.convert_audio(
AudioChunk(audio_data), "pcm", writer
)
writer.close()
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
assert len(audio_chunk.output) > 0
# PCM is raw bytes, so no header to check
@pytest.mark.asyncio
async def test_convert_to_invalid_format_raises_error(sample_audio):
"""Test that converting to an invalid format raises an error"""
# audio_data, sample_rate = sample_audio
with pytest.raises(ValueError, match="Unsupported format: invalid"):
writer = StreamingAudioWriter("invalid", sample_rate=24000)
@pytest.mark.asyncio
async def test_normalization_wav(sample_audio):
"""Test that WAV output is properly normalized to int16 range"""
audio_data, sample_rate = sample_audio
writer = StreamingAudioWriter("wav", sample_rate=24000)
# Create audio data outside int16 range
large_audio = audio_data * 1e5
# Write and finalize in one step for WAV
audio_chunk = await AudioService.convert_audio(
AudioChunk(large_audio), "wav", writer
)
writer.close()
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
assert len(audio_chunk.output) > 0
@pytest.mark.asyncio
async def test_normalization_pcm(sample_audio):
"""Test that PCM output is properly normalized to int16 range"""
audio_data, sample_rate = sample_audio
writer = StreamingAudioWriter("pcm", sample_rate=24000)
# Create audio data outside int16 range
large_audio = audio_data * 1e5
audio_chunk = await AudioService.convert_audio(
AudioChunk(large_audio), "pcm", writer
)
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
assert len(audio_chunk.output) > 0
@pytest.mark.asyncio
async def test_invalid_audio_data():
"""Test handling of invalid audio data"""
invalid_audio = np.array([]) # Empty array
sample_rate = 24000
writer = StreamingAudioWriter("wav", sample_rate=24000)
with pytest.raises(ValueError):
await AudioService.convert_audio(invalid_audio, sample_rate, "wav", writer)
@pytest.mark.asyncio
async def test_different_sample_rates(sample_audio):
"""Test converting audio with different sample rates"""
audio_data, _ = sample_audio
sample_rates = [8000, 16000, 44100, 48000]
for rate in sample_rates:
writer = StreamingAudioWriter("wav", sample_rate=rate)
audio_chunk = await AudioService.convert_audio(
AudioChunk(audio_data), "wav", writer
)
writer.close()
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
assert len(audio_chunk.output) > 0
@pytest.mark.asyncio
async def test_buffer_position_after_conversion(sample_audio):
"""Test that buffer position is reset after writing"""
audio_data, sample_rate = sample_audio
writer = StreamingAudioWriter("wav", sample_rate=24000)
# Write and finalize in one step for first conversion
audio_chunk1 = await AudioService.convert_audio(
AudioChunk(audio_data), "wav", writer, is_last_chunk=True
)
assert isinstance(audio_chunk1.output, bytes)
assert isinstance(audio_chunk1, AudioChunk)
# Convert again to ensure buffer was properly reset
writer = StreamingAudioWriter("wav", sample_rate=24000)
audio_chunk2 = await AudioService.convert_audio(
AudioChunk(audio_data), "wav", writer, is_last_chunk=True
)
assert isinstance(audio_chunk2.output, bytes)
assert isinstance(audio_chunk2, AudioChunk)
assert len(audio_chunk1.output) == len(audio_chunk2.output)


@@ -1,23 +0,0 @@
import os

import numpy as np


def generate_test_audio():
    """Generate test audio data - 1 second of 440Hz tone"""
    # Create 1 second of silence at 24kHz
    audio = np.zeros(24000, dtype=np.float32)
    # Add a simple sine wave to make it non-zero
    t = np.linspace(0, 1, 24000)
    audio += 0.5 * np.sin(2 * np.pi * 440 * t)  # 440 Hz tone at half amplitude

    # Create test_data directory if it doesn't exist
    os.makedirs("api/tests/test_data", exist_ok=True)
    # Save the test audio
    np.save("api/tests/test_data/test_audio.npy", audio)


if __name__ == "__main__":
    generate_test_audio()

Binary file not shown.


@@ -1,34 +0,0 @@
import base64
import json
from unittest.mock import MagicMock, patch

import pytest
import requests


def test_generate_captioned_speech():
    """Test the generate_captioned_speech function with mocked responses"""
    # Mock the API responses
    mock_audio_response = MagicMock()
    mock_audio_response.status_code = 200

    mock_timestamps_response = MagicMock()
    mock_timestamps_response.status_code = 200
    mock_timestamps_response.content = json.dumps(
        {
            "audio": base64.b64encode(b"mock audio data").decode("utf-8"),
            "timestamps": [{"word": "test", "start_time": 0.0, "end_time": 1.0}],
        }
    )

    # Patch the HTTP requests
    with patch("requests.post", return_value=mock_timestamps_response):
        # Import here to avoid module-level import issues
        from examples.captioned_speech_example import generate_captioned_speech

        # Test the function
        audio, timestamps = generate_captioned_speech("test text")

        # Verify we got both audio and timestamps
        assert audio == b"mock audio data"
        assert timestamps == [{"word": "test", "start_time": 0.0, "end_time": 1.0}]


@@ -1,165 +0,0 @@
from unittest.mock import ANY, MagicMock, patch
import numpy as np
import pytest
import torch
from api.src.inference.kokoro_v1 import KokoroV1
@pytest.fixture
def kokoro_backend():
"""Create a KokoroV1 instance for testing."""
return KokoroV1()
def test_initial_state(kokoro_backend):
"""Test initial state of KokoroV1."""
assert not kokoro_backend.is_loaded
assert kokoro_backend._model is None
assert kokoro_backend._pipelines == {} # Now using dict of pipelines
# Device should be set based on settings
assert kokoro_backend.device in ["cuda", "cpu"]
@patch("torch.cuda.is_available", return_value=True)
@patch("torch.cuda.memory_allocated", return_value=5e9)
def test_memory_management(mock_memory, mock_cuda, kokoro_backend):
"""Test GPU memory management functions."""
# Patch backend so it thinks we have cuda
with patch.object(kokoro_backend, "_device", "cuda"):
# Test memory check
with patch("api.src.inference.kokoro_v1.model_config") as mock_config:
mock_config.pytorch_gpu.memory_threshold = 4
assert kokoro_backend._check_memory() == True
mock_config.pytorch_gpu.memory_threshold = 6
assert kokoro_backend._check_memory() == False
@patch("torch.cuda.empty_cache")
@patch("torch.cuda.synchronize")
def test_clear_memory(mock_sync, mock_clear, kokoro_backend):
"""Test memory clearing."""
with patch.object(kokoro_backend, "_device", "cuda"):
kokoro_backend._clear_memory()
mock_clear.assert_called_once()
mock_sync.assert_called_once()
@pytest.mark.asyncio
async def test_load_model_validation(kokoro_backend):
"""Test model loading validation."""
with pytest.raises(RuntimeError, match="Failed to load Kokoro model"):
await kokoro_backend.load_model("nonexistent_model.pth")
def test_unload_with_pipelines(kokoro_backend):
"""Test model unloading with multiple pipelines."""
# Mock loaded state with multiple pipelines
kokoro_backend._model = MagicMock()
pipeline_a = MagicMock()
pipeline_e = MagicMock()
kokoro_backend._pipelines = {"a": pipeline_a, "e": pipeline_e}
assert kokoro_backend.is_loaded
# Test unload
kokoro_backend.unload()
assert not kokoro_backend.is_loaded
assert kokoro_backend._model is None
assert kokoro_backend._pipelines == {} # All pipelines should be cleared
@pytest.mark.asyncio
async def test_generate_validation(kokoro_backend):
"""Test generation validation."""
with pytest.raises(RuntimeError, match="Model not loaded"):
async for _ in kokoro_backend.generate("test", "voice"):
pass
@pytest.mark.asyncio
async def test_generate_from_tokens_validation(kokoro_backend):
"""Test token generation validation."""
with pytest.raises(RuntimeError, match="Model not loaded"):
async for _ in kokoro_backend.generate_from_tokens("test tokens", "voice"):
pass
def test_get_pipeline_creates_new(kokoro_backend):
"""Test that _get_pipeline creates new pipeline for new language code."""
# Mock loaded state
kokoro_backend._model = MagicMock()
# Mock KPipeline
mock_pipeline = MagicMock()
with patch(
"api.src.inference.kokoro_v1.KPipeline", return_value=mock_pipeline
) as mock_kpipeline:
# Get pipeline for Spanish
pipeline_e = kokoro_backend._get_pipeline("e")
# Should create new pipeline with correct params
mock_kpipeline.assert_called_once_with(
lang_code="e", model=kokoro_backend._model, device=kokoro_backend._device
)
assert pipeline_e == mock_pipeline
assert kokoro_backend._pipelines["e"] == mock_pipeline
def test_get_pipeline_reuses_existing(kokoro_backend):
"""Test that _get_pipeline reuses existing pipeline for same language code."""
# Mock loaded state
kokoro_backend._model = MagicMock()
# Mock KPipeline
mock_pipeline = MagicMock()
with patch(
"api.src.inference.kokoro_v1.KPipeline", return_value=mock_pipeline
) as mock_kpipeline:
# Get pipeline twice for same language
pipeline1 = kokoro_backend._get_pipeline("e")
pipeline2 = kokoro_backend._get_pipeline("e")
# Should only create pipeline once
mock_kpipeline.assert_called_once()
assert pipeline1 == pipeline2
assert kokoro_backend._pipelines["e"] == mock_pipeline
@pytest.mark.asyncio
async def test_generate_uses_correct_pipeline(kokoro_backend):
"""Test that generate uses correct pipeline for language code."""
# Mock loaded state
kokoro_backend._model = MagicMock()
# Mock voice path handling
with (
patch("api.src.core.paths.load_voice_tensor") as mock_load_voice,
patch("api.src.core.paths.save_voice_tensor"),
patch("tempfile.gettempdir") as mock_tempdir,
):
mock_load_voice.return_value = torch.ones(1)
mock_tempdir.return_value = "/tmp"
# Mock KPipeline
mock_pipeline = MagicMock()
mock_pipeline.return_value = iter([]) # Empty generator for testing
with patch("api.src.inference.kokoro_v1.KPipeline", return_value=mock_pipeline):
# Generate with Spanish voice and explicit lang_code
async for _ in kokoro_backend.generate("test", "ef_voice", lang_code="e"):
pass
# Should create pipeline with Spanish lang_code
assert "e" in kokoro_backend._pipelines
# Use ANY to match the temp file path since it's dynamic
mock_pipeline.assert_called_with(
"test",
voice=ANY, # Don't check exact path since it's dynamic
speed=1.0,
model=kokoro_backend._model,
)
# Verify the voice path is a temp file path
call_args = mock_pipeline.call_args
assert isinstance(call_args[1]["voice"], str)
assert call_args[1]["voice"].startswith("/tmp/temp_voice_")


@@ -1,317 +0,0 @@
"""Tests for text normalization service"""
import pytest
from api.src.services.text_processing.normalizer import normalize_text
from api.src.structures.schemas import NormalizationOptions
def test_url_protocols():
"""Test URL protocol handling"""
assert (
normalize_text(
"Check out https://example.com",
normalization_options=NormalizationOptions(),
)
== "Check out https example dot com"
)
assert (
normalize_text(
"Visit http://site.com", normalization_options=NormalizationOptions()
)
== "Visit http site dot com"
)
assert (
normalize_text(
"Go to https://test.org/path", normalization_options=NormalizationOptions()
)
== "Go to https test dot org slash path"
)
def test_url_www():
"""Test www prefix handling"""
assert (
normalize_text(
"Go to www.example.com", normalization_options=NormalizationOptions()
)
== "Go to www example dot com"
)
assert (
normalize_text(
"Visit www.test.org/docs", normalization_options=NormalizationOptions()
)
== "Visit www test dot org slash docs"
)
assert (
normalize_text(
"Check www.site.com?q=test", normalization_options=NormalizationOptions()
)
== "Check www site dot com question-mark q equals test"
)
def test_url_localhost():
"""Test localhost URL handling"""
assert (
normalize_text(
"Running on localhost:7860", normalization_options=NormalizationOptions()
)
== "Running on localhost colon seventy-eight sixty"
)
assert (
normalize_text(
"Server at localhost:8080/api", normalization_options=NormalizationOptions()
)
== "Server at localhost colon eighty eighty slash api"
)
assert (
normalize_text(
"Test localhost:3000/test?v=1", normalization_options=NormalizationOptions()
)
== "Test localhost colon three thousand slash test question-mark v equals one"
)
def test_url_ip_addresses():
"""Test IP address URL handling"""
assert (
normalize_text(
"Access 0.0.0.0:9090/test", normalization_options=NormalizationOptions()
)
== "Access zero dot zero dot zero dot zero colon ninety ninety slash test"
)
assert (
normalize_text(
"API at 192.168.1.1:8000", normalization_options=NormalizationOptions()
)
== "API at one hundred and ninety-two dot one hundred and sixty-eight dot one dot one colon eight thousand"
)
assert (
normalize_text("Server 127.0.0.1", normalization_options=NormalizationOptions())
== "Server one hundred and twenty-seven dot zero dot zero dot one"
)
def test_url_raw_domains():
"""Test raw domain handling"""
assert (
normalize_text(
"Visit google.com/search", normalization_options=NormalizationOptions()
)
== "Visit google dot com slash search"
)
assert (
normalize_text(
"Go to example.com/path?q=test",
normalization_options=NormalizationOptions(),
)
== "Go to example dot com slash path question-mark q equals test"
)
assert (
normalize_text(
"Check docs.test.com", normalization_options=NormalizationOptions()
)
== "Check docs dot test dot com"
)
def test_url_email_addresses():
"""Test email address handling"""
assert (
normalize_text(
"Email me at user@example.com", normalization_options=NormalizationOptions()
)
== "Email me at user at example dot com"
)
assert (
normalize_text(
"Contact admin@test.org", normalization_options=NormalizationOptions()
)
== "Contact admin at test dot org"
)
assert (
normalize_text(
"Send to test.user@site.com", normalization_options=NormalizationOptions()
)
== "Send to test dot user at site dot com"
)
def test_money():
"""Test that money text is normalized correctly"""
assert (
normalize_text(
"He lost $5.3 thousand.", normalization_options=NormalizationOptions()
)
== "He lost five point three thousand dollars."
)
assert (
normalize_text(
"He went gambling and lost about $25.05k.",
normalization_options=NormalizationOptions(),
)
== "He went gambling and lost about twenty-five point zero five thousand dollars."
)
assert (
normalize_text(
"To put it weirdly -$6.9 million",
normalization_options=NormalizationOptions(),
)
== "To put it weirdly minus six point nine million dollars"
)
assert (
normalize_text("It costs $50.3.", normalization_options=NormalizationOptions())
== "It costs fifty dollars and thirty cents."
)
assert (
normalize_text(
"The plant cost $200,000.8.", normalization_options=NormalizationOptions()
)
== "The plant cost two hundred thousand dollars and eighty cents."
)
assert (
normalize_text(
"€30.2 is in euros", normalization_options=NormalizationOptions()
)
== "thirty euros and twenty cents is in euros"
)
def test_time():
"""Test time normalization"""
assert (
normalize_text(
"Your flight leaves at 10:35 pm",
normalization_options=NormalizationOptions(),
)
== "Your flight leaves at ten thirty-five pm"
)
assert (
normalize_text(
"He departed for london around 5:03 am.",
normalization_options=NormalizationOptions(),
)
== "He departed for london around five oh three am."
)
assert (
normalize_text(
"Only the 13:42 and 15:12 slots are available.",
normalization_options=NormalizationOptions(),
)
== "Only the thirteen forty-two and fifteen twelve slots are available."
)
assert (
normalize_text(
"It is currently 1:00 pm", normalization_options=NormalizationOptions()
)
== "It is currently one pm"
)
assert (
normalize_text(
"It is currently 3:00", normalization_options=NormalizationOptions()
)
== "It is currently three o'clock"
)
assert (
normalize_text(
"12:00 am is midnight", normalization_options=NormalizationOptions()
)
== "twelve am is midnight"
)
def test_number():
"""Test number normalization"""
assert (
normalize_text(
"I bought 1035 cans of soda", normalization_options=NormalizationOptions()
)
== "I bought one thousand and thirty-five cans of soda"
)
assert (
normalize_text(
"The bus has a maximum capacity of 62 people",
normalization_options=NormalizationOptions(),
)
== "The bus has a maximum capacity of sixty-two people"
)
assert (
normalize_text(
"There are 1300 products left in stock",
normalization_options=NormalizationOptions(),
)
== "There are one thousand, three hundred products left in stock"
)
assert (
normalize_text(
"The population is 7,890,000 people.",
normalization_options=NormalizationOptions(),
)
== "The population is seven million, eight hundred and ninety thousand people."
)
assert (
normalize_text(
"He looked around but only found 1.6k of the 10k bricks",
normalization_options=NormalizationOptions(),
)
== "He looked around but only found one point six thousand of the ten thousand bricks"
)
assert (
normalize_text(
"The book has 342 pages.", normalization_options=NormalizationOptions()
)
== "The book has three hundred and forty-two pages."
)
assert (
normalize_text(
"He made -50 sales today.", normalization_options=NormalizationOptions()
)
== "He made minus fifty sales today."
)
assert (
normalize_text(
"56.789 to the power of 1.35 million",
normalization_options=NormalizationOptions(),
)
== "fifty-six point seven eight nine to the power of one point three five million"
)
def test_non_url_text():
"""Test that non-URL text is unaffected"""
assert (
normalize_text(
"This is not.a.url text", normalization_options=NormalizationOptions()
)
== "This is not-a-url text"
)
assert (
normalize_text(
"Hello, how are you today?", normalization_options=NormalizationOptions()
)
== "Hello, how are you today?"
)
assert (
normalize_text("It costs $50.", normalization_options=NormalizationOptions())
== "It costs fifty dollars."
)


@@ -1,499 +0,0 @@
import asyncio
import json
import os
from typing import AsyncGenerator, Tuple
from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
from fastapi.testclient import TestClient
from api.src.core.config import settings
from api.src.inference.base import AudioChunk
from api.src.main import app
from api.src.routers.openai_compatible import (
get_tts_service,
load_openai_mappings,
stream_audio_chunks,
)
from api.src.services.streaming_audio_writer import StreamingAudioWriter
from api.src.services.tts_service import TTSService
from api.src.structures.schemas import OpenAISpeechRequest
client = TestClient(app)
@pytest.fixture
def test_voice():
"""Fixture providing a test voice name."""
return "test_voice"
@pytest.fixture
def mock_openai_mappings():
"""Mock OpenAI mappings for testing."""
with patch(
"api.src.routers.openai_compatible._openai_mappings",
{
"models": {"tts-1": "kokoro-v1_0", "tts-1-hd": "kokoro-v1_0"},
"voices": {"alloy": "am_adam", "nova": "bf_isabella"},
},
):
yield
@pytest.fixture
def mock_json_file(tmp_path):
"""Create a temporary mock JSON file."""
content = {
"models": {"test-model": "test-kokoro"},
"voices": {"test-voice": "test-internal"},
}
json_file = tmp_path / "test_mappings.json"
json_file.write_text(json.dumps(content))
return json_file
def test_load_openai_mappings(mock_json_file):
"""Test loading OpenAI mappings from JSON file"""
with patch("os.path.join", return_value=str(mock_json_file)):
mappings = load_openai_mappings()
assert "models" in mappings
assert "voices" in mappings
assert mappings["models"]["test-model"] == "test-kokoro"
assert mappings["voices"]["test-voice"] == "test-internal"
def test_load_openai_mappings_file_not_found():
"""Test handling of missing mappings file"""
with patch("os.path.join", return_value="/nonexistent/path"):
mappings = load_openai_mappings()
assert mappings == {"models": {}, "voices": {}}
def test_list_models(mock_openai_mappings):
"""Test listing available models endpoint"""
response = client.get("/v1/models")
assert response.status_code == 200
data = response.json()
assert data["object"] == "list"
assert isinstance(data["data"], list)
assert len(data["data"]) == 3 # tts-1, tts-1-hd, and kokoro
# Verify all expected models are present
model_ids = [model["id"] for model in data["data"]]
assert "tts-1" in model_ids
assert "tts-1-hd" in model_ids
assert "kokoro" in model_ids
# Verify model format
for model in data["data"]:
assert model["object"] == "model"
assert "created" in model
assert model["owned_by"] == "kokoro"
def test_retrieve_model(mock_openai_mappings):
"""Test retrieving a specific model endpoint"""
# Test successful model retrieval
response = client.get("/v1/models/tts-1")
assert response.status_code == 200
data = response.json()
assert data["id"] == "tts-1"
assert data["object"] == "model"
assert data["owned_by"] == "kokoro"
assert "created" in data
# Test non-existent model
response = client.get("/v1/models/nonexistent-model")
assert response.status_code == 404
error = response.json()
assert error["detail"]["error"] == "model_not_found"
assert "not found" in error["detail"]["message"]
assert error["detail"]["type"] == "invalid_request_error"
@pytest.mark.asyncio
async def test_get_tts_service_initialization():
"""Test TTSService initialization"""
with patch("api.src.routers.openai_compatible._tts_service", None):
with patch("api.src.routers.openai_compatible._init_lock", None):
with patch("api.src.services.tts_service.TTSService.create") as mock_create:
mock_service = AsyncMock()
mock_create.return_value = mock_service
# Test concurrent access
async def get_service():
return await get_tts_service()
# Create multiple concurrent requests
tasks = [get_service() for _ in range(5)]
results = await asyncio.gather(*tasks)
# Verify service was created only once
mock_create.assert_called_once()
assert all(r == mock_service for r in results)
@pytest.mark.asyncio
async def test_stream_audio_chunks_client_disconnect():
"""Test handling of client disconnect during streaming"""
mock_request = MagicMock()
mock_request.is_disconnected = AsyncMock(return_value=True)
mock_service = AsyncMock()
async def mock_stream(*args, **kwargs):
for i in range(5):
yield AudioChunk(np.ndarray([], np.int16), output=b"chunk")
mock_service.generate_audio_stream = mock_stream
mock_service.list_voices.return_value = ["test_voice"]
request = OpenAISpeechRequest(
model="kokoro",
input="Test text",
voice="test_voice",
response_format="mp3",
stream=True,
speed=1.0,
)
writer = StreamingAudioWriter("mp3", 24000)
chunks = []
async for chunk in stream_audio_chunks(mock_service, request, mock_request, writer):
chunks.append(chunk)
writer.close()
assert len(chunks) == 0 # Should stop immediately due to disconnect
def test_openai_voice_mapping(mock_tts_service, mock_openai_mappings):
"""Test OpenAI voice name mapping"""
mock_tts_service.list_voices.return_value = ["am_adam", "bf_isabella"]
response = client.post(
"/v1/audio/speech",
json={
"model": "tts-1",
"input": "Hello world",
"voice": "alloy", # OpenAI voice name
"response_format": "mp3",
"stream": False,
},
)
assert response.status_code == 200
mock_tts_service.generate_audio.assert_called_once()
assert mock_tts_service.generate_audio.call_args[1]["voice"] == "am_adam"
def test_openai_voice_mapping_streaming(
mock_tts_service, mock_openai_mappings, mock_audio_bytes
):
"""Test OpenAI voice mapping in streaming mode"""
mock_tts_service.list_voices.return_value = ["am_adam", "bf_isabella"]
response = client.post(
"/v1/audio/speech",
json={
"model": "tts-1-hd",
"input": "Hello world",
"voice": "nova", # OpenAI voice name
"response_format": "mp3",
"stream": True,
},
)
assert response.status_code == 200
content = b""
for chunk in response.iter_bytes():
content += chunk
assert content == mock_audio_bytes
def test_invalid_openai_model(mock_tts_service, mock_openai_mappings):
"""Test error handling for invalid OpenAI model"""
response = client.post(
"/v1/audio/speech",
json={
"model": "invalid-model",
"input": "Hello world",
"voice": "alloy",
"response_format": "mp3",
"stream": False,
},
)
assert response.status_code == 400
error_response = response.json()
assert error_response["detail"]["error"] == "invalid_model"
assert "Unsupported model" in error_response["detail"]["message"]
@pytest.fixture
def mock_audio_bytes():
"""Mock audio bytes for testing."""
return b"mock audio data"
@pytest.fixture
def mock_tts_service(mock_audio_bytes):
"""Mock TTS service for testing."""
with patch("api.src.routers.openai_compatible.get_tts_service") as mock_get:
service = AsyncMock(spec=TTSService)
service.generate_audio.return_value = AudioChunk(np.zeros(1000, np.int16))
async def mock_stream(*args, **kwargs) -> AsyncGenerator[AudioChunk, None]:
yield AudioChunk(np.ndarray([], np.int16), output=mock_audio_bytes)
service.generate_audio_stream = mock_stream
service.list_voices.return_value = ["test_voice", "voice1", "voice2"]
service.combine_voices.return_value = "voice1_voice2"
mock_get.return_value = service
mock_get.side_effect = None
yield service
@patch("api.src.services.audio.AudioService.convert_audio")
def test_openai_speech_endpoint(
mock_convert, mock_tts_service, test_voice, mock_audio_bytes
):
"""Test the OpenAI-compatible speech endpoint with basic MP3 generation"""
# Configure mocks
mock_tts_service.generate_audio.return_value = AudioChunk(np.zeros(1000, np.int16))
mock_convert.return_value = AudioChunk(
np.zeros(1000, np.int16), output=mock_audio_bytes
)
response = client.post(
"/v1/audio/speech",
json={
"model": "kokoro",
"input": "Hello world",
"voice": test_voice,
"response_format": "mp3",
"stream": False,
},
)
assert response.status_code == 200
assert response.headers["content-type"] == "audio/mpeg"
assert len(response.content) > 0
assert response.content == mock_audio_bytes + mock_audio_bytes
mock_tts_service.generate_audio.assert_called_once()
assert mock_convert.call_count == 2
def test_openai_speech_streaming(mock_tts_service, test_voice, mock_audio_bytes):
"""Test the OpenAI-compatible speech endpoint with streaming"""
response = client.post(
"/v1/audio/speech",
json={
"model": "kokoro",
"input": "Hello world",
"voice": test_voice,
"response_format": "mp3",
"stream": True,
},
)
assert response.status_code == 200
assert response.headers["content-type"] == "audio/mpeg"
assert "Transfer-Encoding" in response.headers
assert response.headers["Transfer-Encoding"] == "chunked"
content = b""
for chunk in response.iter_bytes():
content += chunk
assert content == mock_audio_bytes
def test_openai_speech_pcm_streaming(mock_tts_service, test_voice, mock_audio_bytes):
"""Test PCM streaming format"""
response = client.post(
"/v1/audio/speech",
json={
"model": "kokoro",
"input": "Hello world",
"voice": test_voice,
"response_format": "pcm",
"stream": True,
},
)
assert response.status_code == 200
assert response.headers["content-type"] == "audio/pcm"
content = b""
for chunk in response.iter_bytes():
content += chunk
assert content == mock_audio_bytes
def test_openai_speech_invalid_voice(mock_tts_service):
"""Test error handling for invalid voice"""
mock_tts_service.generate_audio.side_effect = ValueError(
"Voice 'invalid_voice' not found"
)
response = client.post(
"/v1/audio/speech",
json={
"model": "kokoro",
"input": "Hello world",
"voice": "invalid_voice",
"response_format": "mp3",
"stream": False,
},
)
assert response.status_code == 400
error_response = response.json()
assert error_response["detail"]["error"] == "validation_error"
assert "Voice 'invalid_voice' not found" in error_response["detail"]["message"]
assert error_response["detail"]["type"] == "invalid_request_error"
def test_openai_speech_empty_text(mock_tts_service, test_voice):
"""Test error handling for empty text"""
async def mock_error_stream(*args, **kwargs):
raise ValueError("Text is empty after preprocessing")
mock_tts_service.generate_audio = mock_error_stream
mock_tts_service.list_voices.return_value = ["test_voice"]
response = client.post(
"/v1/audio/speech",
json={
"model": "kokoro",
"input": "",
"voice": test_voice,
"response_format": "mp3",
"stream": False,
},
)
assert response.status_code == 400
error_response = response.json()
assert error_response["detail"]["error"] == "validation_error"
assert "Text is empty after preprocessing" in error_response["detail"]["message"]
assert error_response["detail"]["type"] == "invalid_request_error"
def test_openai_speech_invalid_format(mock_tts_service, test_voice):
"""Test error handling for invalid format"""
response = client.post(
"/v1/audio/speech",
json={
"model": "kokoro",
"input": "Hello world",
"voice": test_voice,
"response_format": "invalid_format",
"stream": False,
},
)
assert response.status_code == 422 # Validation error from Pydantic
def test_list_voices(mock_tts_service):
"""Test listing available voices"""
# Override the mock for this specific test
mock_tts_service.list_voices.return_value = ["voice1", "voice2"]
response = client.get("/v1/audio/voices")
assert response.status_code == 200
data = response.json()
assert "voices" in data
assert len(data["voices"]) == 2
assert "voice1" in data["voices"]
assert "voice2" in data["voices"]
@patch("api.src.routers.openai_compatible.settings")
def test_combine_voices(mock_settings, mock_tts_service):
"""Test combining voices endpoint"""
# Enable local voice saving for this test
mock_settings.allow_local_voice_saving = True
response = client.post("/v1/audio/voices/combine", json="voice1+voice2")
assert response.status_code == 200
assert response.headers["content-type"] == "application/octet-stream"
assert "voice1+voice2.pt" in response.headers["content-disposition"]
def test_server_error(mock_tts_service, test_voice):
"""Test handling of server errors"""
async def mock_error_stream(*args, **kwargs):
raise RuntimeError("Internal server error")
mock_tts_service.generate_audio = mock_error_stream
mock_tts_service.list_voices.return_value = ["test_voice"]
response = client.post(
"/v1/audio/speech",
json={
"model": "kokoro",
"input": "Hello world",
"voice": test_voice,
"response_format": "mp3",
"stream": False,
},
)
assert response.status_code == 500
error_response = response.json()
assert error_response["detail"]["error"] == "processing_error"
assert error_response["detail"]["type"] == "server_error"
def test_streaming_error(mock_tts_service, test_voice):
"""Test handling streaming errors"""
# Mock process_voices to raise the error
mock_tts_service.list_voices.side_effect = RuntimeError("Streaming failed")
response = client.post(
"/v1/audio/speech",
json={
"model": "kokoro",
"input": "Hello world",
"voice": test_voice,
"response_format": "mp3",
"stream": True,
},
)
assert response.status_code == 500
error_data = response.json()
assert error_data["detail"]["error"] == "processing_error"
assert error_data["detail"]["type"] == "server_error"
assert "Streaming failed" in error_data["detail"]["message"]
@pytest.mark.asyncio
async def test_streaming_initialization_error():
"""Test handling of streaming initialization errors"""
mock_service = AsyncMock()
async def mock_error_stream(*args, **kwargs):
if False: # This makes it a proper generator
yield b""
raise RuntimeError("Failed to initialize stream")
mock_service.generate_audio_stream = mock_error_stream
mock_service.list_voices.return_value = ["test_voice"]
request = OpenAISpeechRequest(
model="kokoro",
input="Test text",
voice="test_voice",
response_format="mp3",
stream=True,
speed=1.0,
)
writer = StreamingAudioWriter("mp3", 24000)
with pytest.raises(RuntimeError) as exc:
async for _ in stream_audio_chunks(mock_service, request, MagicMock(), writer):
pass
writer.close()
assert "Failed to initialize stream" in str(exc.value)


@@ -1,138 +0,0 @@
import os
from unittest.mock import patch
import pytest
from api.src.core.paths import (
_find_file,
_scan_directories,
get_content_type,
get_temp_dir_size,
get_temp_file_path,
list_temp_files,
)
@pytest.mark.asyncio
async def test_find_file_exists():
"""Test finding existing file."""
with patch("aiofiles.os.path.exists") as mock_exists:
mock_exists.return_value = True
path = await _find_file("test.txt", ["/test/path"])
assert path == "/test/path/test.txt"
@pytest.mark.asyncio
async def test_find_file_not_exists():
"""Test finding non-existent file."""
with patch("aiofiles.os.path.exists") as mock_exists:
mock_exists.return_value = False
with pytest.raises(FileNotFoundError, match="File not found"):
await _find_file("test.txt", ["/test/path"])
@pytest.mark.asyncio
async def test_find_file_with_filter():
"""Test finding file with filter function."""
with patch("aiofiles.os.path.exists") as mock_exists:
mock_exists.return_value = True
filter_fn = lambda p: p.endswith(".txt")
path = await _find_file("test.txt", ["/test/path"], filter_fn)
assert path == "/test/path/test.txt"
@pytest.mark.asyncio
async def test_scan_directories():
"""Test scanning directories."""
mock_entry = type("MockEntry", (), {"name": "test.txt"})()
with (
patch("aiofiles.os.path.exists") as mock_exists,
patch("aiofiles.os.scandir") as mock_scandir,
):
mock_exists.return_value = True
mock_scandir.return_value = [mock_entry]
files = await _scan_directories(["/test/path"])
assert "test.txt" in files
@pytest.mark.asyncio
async def test_get_content_type():
"""Test content type detection."""
test_cases = [
("test.html", "text/html"),
("test.js", "application/javascript"),
("test.css", "text/css"),
("test.png", "image/png"),
("test.unknown", "application/octet-stream"),
]
for filename, expected in test_cases:
content_type = await get_content_type(filename)
assert content_type == expected
@pytest.mark.asyncio
async def test_get_temp_file_path():
"""Test temp file path generation."""
with (
patch("aiofiles.os.path.exists") as mock_exists,
patch("aiofiles.os.makedirs") as mock_makedirs,
):
mock_exists.return_value = False
path = await get_temp_file_path("test.wav")
assert "test.wav" in path
mock_makedirs.assert_called_once()
@pytest.mark.asyncio
async def test_list_temp_files():
"""Test listing temp files."""
class MockEntry:
def __init__(self, name):
self.name = name
def is_file(self):
return True
mock_entry = MockEntry("test.wav")
with (
patch("aiofiles.os.path.exists") as mock_exists,
patch("aiofiles.os.scandir") as mock_scandir,
):
mock_exists.return_value = True
mock_scandir.return_value = [mock_entry]
files = await list_temp_files()
assert "test.wav" in files
@pytest.mark.asyncio
async def test_get_temp_dir_size():
"""Test getting temp directory size."""
class MockEntry:
def __init__(self, path):
self.path = path
def is_file(self):
return True
mock_entry = MockEntry("/tmp/test.wav")
mock_stat = type("MockStat", (), {"st_size": 1024})()
with (
patch("aiofiles.os.path.exists") as mock_exists,
patch("aiofiles.os.scandir") as mock_scandir,
patch("aiofiles.os.stat") as mock_stat_fn,
):
mock_exists.return_value = True
mock_scandir.return_value = [mock_entry]
mock_stat_fn.return_value = mock_stat
size = await get_temp_dir_size()
assert size == 1024


@@ -1,167 +0,0 @@
import pytest
from api.src.services.text_processing.text_processor import (
get_sentence_info,
process_text_chunk,
smart_split,
)
def test_process_text_chunk_basic():
"""Test basic text chunk processing."""
text = "Hello world"
tokens = process_text_chunk(text)
assert isinstance(tokens, list)
assert len(tokens) > 0
def test_process_text_chunk_empty():
"""Test processing empty text."""
text = ""
tokens = process_text_chunk(text)
assert isinstance(tokens, list)
assert len(tokens) == 0
def test_process_text_chunk_phonemes():
"""Test processing with skip_phonemize."""
phonemes = "h @ l @U" # Example phoneme sequence
tokens = process_text_chunk(phonemes, skip_phonemize=True)
assert isinstance(tokens, list)
assert len(tokens) > 0
def test_get_sentence_info():
"""Test sentence splitting and info extraction."""
text = "This is sentence one. This is sentence two! What about three?"
results = get_sentence_info(text, {})
assert len(results) == 3
for sentence, tokens, count in results:
assert isinstance(sentence, str)
assert isinstance(tokens, list)
assert isinstance(count, int)
assert count == len(tokens)
assert count > 0
def test_get_sentence_info_phenomoes():
"""Test sentence splitting and info extraction."""
text = (
"This is sentence one. This is </|custom_phonemes_0|/> two! What about three?"
)
results = get_sentence_info(text, {"</|custom_phonemes_0|/>": r"sˈɛntᵊns"})
assert len(results) == 3
assert "sˈɛntᵊns" in results[1][0]
for sentence, tokens, count in results:
assert isinstance(sentence, str)
assert isinstance(tokens, list)
assert isinstance(count, int)
assert count == len(tokens)
assert count > 0
@pytest.mark.asyncio
async def test_smart_split_short_text():
"""Test smart splitting with text under max tokens."""
text = "This is a short test sentence."
chunks = []
async for chunk_text, chunk_tokens in smart_split(text):
chunks.append((chunk_text, chunk_tokens))
assert len(chunks) == 1
assert isinstance(chunks[0][0], str)
assert isinstance(chunks[0][1], list)
@pytest.mark.asyncio
async def test_smart_split_long_text():
"""Test smart splitting with longer text."""
# Create text that should split into multiple chunks
text = ". ".join(["This is test sentence number " + str(i) for i in range(20)])
chunks = []
async for chunk_text, chunk_tokens in smart_split(text):
chunks.append((chunk_text, chunk_tokens))
assert len(chunks) > 1
for chunk_text, chunk_tokens in chunks:
assert isinstance(chunk_text, str)
assert isinstance(chunk_tokens, list)
assert len(chunk_tokens) > 0
@pytest.mark.asyncio
async def test_smart_split_with_punctuation():
"""Test smart splitting handles punctuation correctly."""
text = "First sentence! Second sentence? Third sentence; Fourth sentence: Fifth sentence."
chunks = []
async for chunk_text, chunk_tokens in smart_split(text):
chunks.append(chunk_text)
# Verify punctuation is preserved
assert all(any(p in chunk for p in "!?;:.") for chunk in chunks)
def test_process_text_chunk_chinese_phonemes():
"""Test processing with Chinese pinyin phonemes."""
pinyin = "nǐ hǎo lì" # Example pinyin sequence with tones
tokens = process_text_chunk(pinyin, skip_phonemize=True, language="z")
assert isinstance(tokens, list)
assert len(tokens) > 0
def test_get_sentence_info_chinese():
"""Test Chinese sentence splitting and info extraction."""
text = "这是一个句子。这是第二个句子!第三个问题?"
results = get_sentence_info(text, {}, lang_code="z")
assert len(results) == 3
for sentence, tokens, count in results:
assert isinstance(sentence, str)
assert isinstance(tokens, list)
assert isinstance(count, int)
assert count == len(tokens)
assert count > 0
@pytest.mark.asyncio
async def test_smart_split_chinese_short():
"""Test Chinese smart splitting with short text."""
text = "这是一句话。"
chunks = []
async for chunk_text, chunk_tokens in smart_split(text, lang_code="z"):
chunks.append((chunk_text, chunk_tokens))
assert len(chunks) == 1
assert isinstance(chunks[0][0], str)
assert isinstance(chunks[0][1], list)
@pytest.mark.asyncio
async def test_smart_split_chinese_long():
"""Test Chinese smart splitting with longer text."""
text = "".join([f"测试句子 {i}" for i in range(20)])
chunks = []
async for chunk_text, chunk_tokens in smart_split(text, lang_code="z"):
chunks.append((chunk_text, chunk_tokens))
assert len(chunks) > 1
for chunk_text, chunk_tokens in chunks:
assert isinstance(chunk_text, str)
assert isinstance(chunk_tokens, list)
assert len(chunk_tokens) > 0
@pytest.mark.asyncio
async def test_smart_split_chinese_punctuation():
"""Test Chinese smart splitting with punctuation preservation."""
text = "第一句!第二问?第三句;第四句:第五句。"
chunks = []
async for chunk_text, _ in smart_split(text, lang_code="z"):
chunks.append(chunk_text)
# Verify Chinese punctuation is preserved
assert all(any(p in chunk for p in "!?;:。") for chunk in chunks)


@@ -1,126 +0,0 @@
from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
import torch
from api.src.services.tts_service import TTSService
@pytest.fixture
def mock_managers():
"""Mock model and voice managers."""
async def _mock_managers():
model_manager = AsyncMock()
model_manager.get_backend.return_value = MagicMock()
voice_manager = AsyncMock()
voice_manager.get_voice_path.return_value = "/path/to/voice.pt"
voice_manager.list_voices.return_value = ["voice1", "voice2"]
with (
patch("api.src.services.tts_service.get_model_manager") as mock_get_model,
patch("api.src.services.tts_service.get_voice_manager") as mock_get_voice,
):
mock_get_model.return_value = model_manager
mock_get_voice.return_value = voice_manager
return model_manager, voice_manager
return _mock_managers()
@pytest.fixture
def tts_service(mock_managers):
"""Create TTSService instance with mocked dependencies."""
async def _create_service():
return await TTSService.create("test_output")
return _create_service()
@pytest.mark.asyncio
async def test_service_creation():
"""Test service creation and initialization."""
model_manager = AsyncMock()
voice_manager = AsyncMock()
with (
patch("api.src.services.tts_service.get_model_manager") as mock_get_model,
patch("api.src.services.tts_service.get_voice_manager") as mock_get_voice,
):
mock_get_model.return_value = model_manager
mock_get_voice.return_value = voice_manager
service = await TTSService.create("test_output")
assert service.output_dir == "test_output"
assert service.model_manager is model_manager
assert service._voice_manager is voice_manager
@pytest.mark.asyncio
async def test_get_voice_path_single():
"""Test getting path for single voice."""
model_manager = AsyncMock()
voice_manager = AsyncMock()
voice_manager.get_voice_path.return_value = "/path/to/voice1.pt"
with (
patch("api.src.services.tts_service.get_model_manager") as mock_get_model,
patch("api.src.services.tts_service.get_voice_manager") as mock_get_voice,
):
mock_get_model.return_value = model_manager
mock_get_voice.return_value = voice_manager
service = await TTSService.create("test_output")
name, path = await service._get_voices_path("voice1")
assert name == "voice1"
assert path == "/path/to/voice1.pt"
voice_manager.get_voice_path.assert_called_once_with("voice1")
@pytest.mark.asyncio
async def test_get_voice_path_combined():
"""Test getting path for combined voices."""
model_manager = AsyncMock()
voice_manager = AsyncMock()
voice_manager.get_voice_path.return_value = "/path/to/voice.pt"
with (
patch("api.src.services.tts_service.get_model_manager") as mock_get_model,
patch("api.src.services.tts_service.get_voice_manager") as mock_get_voice,
patch("torch.load") as mock_load,
patch("torch.save") as mock_save,
patch("tempfile.gettempdir") as mock_temp,
):
mock_get_model.return_value = model_manager
mock_get_voice.return_value = voice_manager
mock_temp.return_value = "/tmp"
mock_load.return_value = torch.ones(10)
service = await TTSService.create("test_output")
name, path = await service._get_voices_path("voice1+voice2")
assert name == "voice1+voice2"
assert path.endswith("voice1+voice2.pt")
mock_save.assert_called_once()
@pytest.mark.asyncio
async def test_list_voices():
"""Test listing available voices."""
model_manager = AsyncMock()
voice_manager = AsyncMock()
voice_manager.list_voices.return_value = ["voice1", "voice2"]
with (
patch("api.src.services.tts_service.get_model_manager") as mock_get_model,
patch("api.src.services.tts_service.get_voice_manager") as mock_get_voice,
):
mock_get_model.return_value = model_manager
mock_get_voice.return_value = voice_manager
service = await TTSService.create("test_output")
voices = await service.list_voices()
assert voices == ["voice1", "voice2"]
voice_manager.list_voices.assert_called_once()


@@ -1,67 +1,59 @@
FROM --platform=$BUILDPLATFORM nvidia/cuda:12.8.0-cudnn-runtime-ubuntu24.04
FROM --platform=$BUILDPLATFORM nvidia/cuda:12.9.0-base-ubuntu24.04
# Set non-interactive frontend
ENV DEBIAN_FRONTEND=noninteractive
# Install Python and other dependencies
RUN apt-get update && apt-get install -y \
python3.10 \
python3-venv \
espeak-ng \
espeak-ng-data \
git \
libsndfile1 \
curl \
ffmpeg \
g++ \
&& apt-get clean && rm -rf /var/lib/apt/lists/* \
&& mkdir -p /usr/share/espeak-ng-data \
&& ln -s /usr/lib/*/espeak-ng-data/* /usr/share/espeak-ng-data/
# Install UV using the installer script
RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
RUN apt-get update -y && \
apt-get install -y python3 python3-venv libsndfile1 curl ffmpeg g++ && \
apt-get clean && rm -rf /var/lib/apt/lists/* && \
curl -LsSf https://astral.sh/uv/install.sh | sh && \
mv /root/.local/bin/uv /usr/local/bin/ && \
mv /root/.local/bin/uvx /usr/local/bin/
# Create non-root user and set up directories and permissions
RUN useradd -m -u 1001 appuser && \
mv /root/.local/bin/uvx /usr/local/bin/ && \
useradd -m -u 1001 appuser && \
mkdir -p /app/api/src/models/v1_0 && \
chown -R appuser:appuser /app
USER appuser
WORKDIR /app
# Copy dependency files
COPY --chown=appuser:appuser pyproject.toml ./pyproject.toml
ENV PHONEMIZER_ESPEAK_PATH=/usr/bin \
PHONEMIZER_ESPEAK_DATA=/usr/share/espeak-ng-data \
ESPEAK_DATA_PATH=/usr/share/espeak-ng-data
PYTHONUNBUFFERED=1 \
PYTHONPATH=/app:/app/api \
PATH="/app/.venv/bin:$PATH" \
UV_LINK_MODE=copy \
USE_GPU=true \
DEVICE="gpu"
# Install dependencies with GPU extras (using cache mounts)
RUN --mount=type=cache,target=/root/.cache/uv \
uv venv --python 3.10 && \
uv sync --extra gpu
# Copy project files including models
COPY --chown=appuser:appuser pyproject.toml ./pyproject.toml
COPY --chown=appuser:appuser api ./api
COPY --chown=appuser:appuser web ./web
COPY --chown=appuser:appuser docker/scripts/ ./
RUN chmod +x ./entrypoint.sh
RUN --mount=type=cache,target=/root/.cache/uv \
uv venv --python 3.10 && \
uv sync --extra gpu && \
uv cache clean && \
python download_model.py --output api/src/models/v1_0
# Set all environment variables in one go
ENV PYTHONUNBUFFERED=1 \
PYTHONPATH=/app:/app/api \
PATH="/app/.venv/bin:$PATH" \
UV_LINK_MODE=copy \
USE_GPU=true
ENV DOWNLOAD_MODEL=true
# Download model if enabled
RUN if [ "$DOWNLOAD_MODEL" = "true" ]; then \
python download_model.py --output api/src/models/v1_0; \
fi
ENV DEVICE="gpu"
# Run FastAPI server through entrypoint.sh
CMD ["./entrypoint.sh"]
# If you want to test the docker image locally, run this from the project root:
# docker build -f docker\gpu\Dockerfile -t kokoro .
# Run it with
# docker run -p 8880:8880 --gpus all --name kokoro kokoro
#
# You can log into the container with
# docker exec -it kokoro /bin/bash
#
# Other commands:
# 1. Stop and remove container
# docker stop kokoro
# docker container remove kokoro
# 2. List and remove images
# docker images
# docker image remove kokoro


@@ -31,10 +31,9 @@ dependencies = [
"matplotlib>=3.10.0",
"mutagen>=1.47.0",
"psutil>=6.1.1",
"espeakng-loader==0.2.4",
"kokoro==0.9.2",
"misaki[en,ja,ko,zh]==0.9.3",
"spacy==3.8.5",
"spacy==3.8.7",
"en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl",
"inflect>=7.5.0",
"phonemizer-fork>=3.3.2",