diff --git a/api/tests/__init__.py b/api/tests/__init__.py deleted file mode 100644 index b9911d8..0000000 --- a/api/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Make tests directory a Python package diff --git a/api/tests/conftest.py b/api/tests/conftest.py deleted file mode 100644 index 2e3bba8..0000000 --- a/api/tests/conftest.py +++ /dev/null @@ -1,71 +0,0 @@ -import os -from pathlib import Path -from unittest.mock import AsyncMock, MagicMock, patch - -import numpy as np -import pytest -import pytest_asyncio -import torch - -from api.src.inference.model_manager import ModelManager -from api.src.inference.voice_manager import VoiceManager -from api.src.services.tts_service import TTSService -from api.src.structures.model_schemas import VoiceConfig - - -@pytest.fixture -def mock_voice_tensor(): - """Load a real voice tensor for testing.""" - voice_path = os.path.join( - os.path.dirname(os.path.dirname(__file__)), "src/voices/af_bella.pt" - ) - return torch.load(voice_path, map_location="cpu", weights_only=False) - - -@pytest.fixture -def mock_audio_output(): - """Load pre-generated test audio for consistent testing.""" - test_audio_path = os.path.join( - os.path.dirname(__file__), "test_data/test_audio.npy" - ) - return np.load(test_audio_path) # Return as numpy array instead of bytes - - -@pytest_asyncio.fixture -async def mock_model_manager(mock_audio_output): - """Mock model manager for testing.""" - manager = AsyncMock(spec=ModelManager) - manager.get_backend = MagicMock() - - async def mock_generate(*args, **kwargs): - # Simulate successful audio generation - return np.random.rand(24000).astype(np.float32) # 1 second of random audio data - - manager.generate = AsyncMock(side_effect=mock_generate) - return manager - - -@pytest_asyncio.fixture -async def mock_voice_manager(mock_voice_tensor): - """Mock voice manager for testing.""" - manager = AsyncMock(spec=VoiceManager) - manager.get_voice_path = MagicMock(return_value="/mock/path/voice.pt") - manager.load_voice = AsyncMock(return_value=mock_voice_tensor) - manager.list_voices = AsyncMock(return_value=["voice1", "voice2"]) - manager.combine_voices = AsyncMock(return_value="voice1_voice2") - return manager - - -@pytest_asyncio.fixture -async def tts_service(mock_model_manager, mock_voice_manager): - """Get mocked TTS service instance.""" - service = TTSService() - service.model_manager = mock_model_manager - service._voice_manager = mock_voice_manager - return service - - -@pytest.fixture -def test_voice(): - """Return a test voice name.""" - return "voice1" diff --git a/api/tests/test_audio_service.py b/api/tests/test_audio_service.py deleted file mode 100644 index 5ba5392..0000000 --- a/api/tests/test_audio_service.py +++ /dev/null @@ -1,256 +0,0 @@ -"""Tests for AudioService""" - -from unittest.mock import patch - -import numpy as np -import pytest - -from api.src.inference.base import AudioChunk -from api.src.services.audio import AudioNormalizer, AudioService -from api.src.services.streaming_audio_writer import StreamingAudioWriter - - -@pytest.fixture(autouse=True) -def mock_settings(): - """Mock settings for all tests""" - with patch("api.src.services.audio.settings") as mock_settings: - mock_settings.gap_trim_ms = 250 - yield mock_settings - - -@pytest.fixture -def sample_audio(): - """Generate a simple sine wave for testing""" - sample_rate = 24000 - duration = 0.1 # 100ms - t = np.linspace(0, duration, int(sample_rate * duration)) - frequency = 440 # A4 note - return np.sin(2 * np.pi * frequency * t).astype(np.float32), 
sample_rate - - -@pytest.mark.asyncio -async def test_convert_to_wav(sample_audio): - """Test converting to WAV format""" - audio_data, sample_rate = sample_audio - - writer = StreamingAudioWriter("wav", sample_rate=24000) - # Write and finalize in one step for WAV - audio_chunk = await AudioService.convert_audio( - AudioChunk(audio_data), "wav", writer, is_last_chunk=False - ) - - writer.close() - - assert isinstance(audio_chunk.output, bytes) - assert isinstance(audio_chunk, AudioChunk) - assert len(audio_chunk.output) > 0 - # Check WAV header - assert audio_chunk.output.startswith(b"RIFF") - assert b"WAVE" in audio_chunk.output[:12] - - -@pytest.mark.asyncio -async def test_convert_to_mp3(sample_audio): - """Test converting to MP3 format""" - audio_data, sample_rate = sample_audio - - writer = StreamingAudioWriter("mp3", sample_rate=24000) - - audio_chunk = await AudioService.convert_audio( - AudioChunk(audio_data), "mp3", writer - ) - - writer.close() - - assert isinstance(audio_chunk.output, bytes) - assert isinstance(audio_chunk, AudioChunk) - assert len(audio_chunk.output) > 0 - # Check MP3 header (ID3 or MPEG frame sync) - assert audio_chunk.output.startswith(b"ID3") or audio_chunk.output.startswith( - b"\xff\xfb" - ) - - -@pytest.mark.asyncio -async def test_convert_to_opus(sample_audio): - """Test converting to Opus format""" - - audio_data, sample_rate = sample_audio - - writer = StreamingAudioWriter("opus", sample_rate=24000) - - audio_chunk = await AudioService.convert_audio( - AudioChunk(audio_data), "opus", writer - ) - - writer.close() - - assert isinstance(audio_chunk.output, bytes) - assert isinstance(audio_chunk, AudioChunk) - assert len(audio_chunk.output) > 0 - # Check OGG header - assert audio_chunk.output.startswith(b"OggS") - - -@pytest.mark.asyncio -async def test_convert_to_flac(sample_audio): - """Test converting to FLAC format""" - audio_data, sample_rate = sample_audio - - writer = StreamingAudioWriter("flac", sample_rate=24000) - - audio_chunk = await AudioService.convert_audio( - AudioChunk(audio_data), "flac", writer - ) - - writer.close() - - assert isinstance(audio_chunk.output, bytes) - assert isinstance(audio_chunk, AudioChunk) - assert len(audio_chunk.output) > 0 - # Check FLAC header - assert audio_chunk.output.startswith(b"fLaC") - - -@pytest.mark.asyncio -async def test_convert_to_aac(sample_audio): - """Test converting to M4A format""" - audio_data, sample_rate = sample_audio - - writer = StreamingAudioWriter("aac", sample_rate=24000) - - audio_chunk = await AudioService.convert_audio( - AudioChunk(audio_data), "aac", writer - ) - - writer.close() - - assert isinstance(audio_chunk.output, bytes) - assert isinstance(audio_chunk, AudioChunk) - assert len(audio_chunk.output) > 0 - # Check ADTS header (AAC) - assert audio_chunk.output.startswith(b"\xff\xf0") or audio_chunk.output.startswith( - b"\xff\xf1" - ) - - -@pytest.mark.asyncio -async def test_convert_to_pcm(sample_audio): - """Test converting to PCM format""" - audio_data, sample_rate = sample_audio - - writer = StreamingAudioWriter("pcm", sample_rate=24000) - - audio_chunk = await AudioService.convert_audio( - AudioChunk(audio_data), "pcm", writer - ) - - writer.close() - - assert isinstance(audio_chunk.output, bytes) - assert isinstance(audio_chunk, AudioChunk) - assert len(audio_chunk.output) > 0 - # PCM is raw bytes, so no header to check - - -@pytest.mark.asyncio -async def test_convert_to_invalid_format_raises_error(sample_audio): - """Test that converting to an invalid format raises an 
error""" - # audio_data, sample_rate = sample_audio - with pytest.raises(ValueError, match="Unsupported format: invalid"): - writer = StreamingAudioWriter("invalid", sample_rate=24000) - - -@pytest.mark.asyncio -async def test_normalization_wav(sample_audio): - """Test that WAV output is properly normalized to int16 range""" - audio_data, sample_rate = sample_audio - - writer = StreamingAudioWriter("wav", sample_rate=24000) - - # Create audio data outside int16 range - large_audio = audio_data * 1e5 - # Write and finalize in one step for WAV - audio_chunk = await AudioService.convert_audio( - AudioChunk(large_audio), "wav", writer - ) - - writer.close() - - assert isinstance(audio_chunk.output, bytes) - assert isinstance(audio_chunk, AudioChunk) - assert len(audio_chunk.output) > 0 - - -@pytest.mark.asyncio -async def test_normalization_pcm(sample_audio): - """Test that PCM output is properly normalized to int16 range""" - audio_data, sample_rate = sample_audio - - writer = StreamingAudioWriter("pcm", sample_rate=24000) - - # Create audio data outside int16 range - large_audio = audio_data * 1e5 - audio_chunk = await AudioService.convert_audio( - AudioChunk(large_audio), "pcm", writer - ) - assert isinstance(audio_chunk.output, bytes) - assert isinstance(audio_chunk, AudioChunk) - assert len(audio_chunk.output) > 0 - - -@pytest.mark.asyncio -async def test_invalid_audio_data(): - """Test handling of invalid audio data""" - invalid_audio = np.array([]) # Empty array - sample_rate = 24000 - - writer = StreamingAudioWriter("wav", sample_rate=24000) - - with pytest.raises(ValueError): - await AudioService.convert_audio(invalid_audio, sample_rate, "wav", writer) - - -@pytest.mark.asyncio -async def test_different_sample_rates(sample_audio): - """Test converting audio with different sample rates""" - audio_data, _ = sample_audio - sample_rates = [8000, 16000, 44100, 48000] - - for rate in sample_rates: - writer = StreamingAudioWriter("wav", sample_rate=rate) - - audio_chunk = await AudioService.convert_audio( - AudioChunk(audio_data), "wav", writer - ) - - writer.close() - - assert isinstance(audio_chunk.output, bytes) - assert isinstance(audio_chunk, AudioChunk) - assert len(audio_chunk.output) > 0 - - -@pytest.mark.asyncio -async def test_buffer_position_after_conversion(sample_audio): - """Test that buffer position is reset after writing""" - audio_data, sample_rate = sample_audio - - writer = StreamingAudioWriter("wav", sample_rate=24000) - - # Write and finalize in one step for first conversion - audio_chunk1 = await AudioService.convert_audio( - AudioChunk(audio_data), "wav", writer, is_last_chunk=True - ) - assert isinstance(audio_chunk1.output, bytes) - assert isinstance(audio_chunk1, AudioChunk) - # Convert again to ensure buffer was properly reset - - writer = StreamingAudioWriter("wav", sample_rate=24000) - - audio_chunk2 = await AudioService.convert_audio( - AudioChunk(audio_data), "wav", writer, is_last_chunk=True - ) - assert isinstance(audio_chunk2.output, bytes) - assert isinstance(audio_chunk2, AudioChunk) - assert len(audio_chunk1.output) == len(audio_chunk2.output) diff --git a/api/tests/test_data/generate_test_data.py b/api/tests/test_data/generate_test_data.py deleted file mode 100644 index 3f6b7cf..0000000 --- a/api/tests/test_data/generate_test_data.py +++ /dev/null @@ -1,23 +0,0 @@ -import os - -import numpy as np - - -def generate_test_audio(): - """Generate test audio data - 1 second of 440Hz tone""" - # Create 1 second of silence at 24kHz - audio = np.zeros(24000, 
dtype=np.float32) - - # Add a simple sine wave to make it non-zero - t = np.linspace(0, 1, 24000) - audio += 0.5 * np.sin(2 * np.pi * 440 * t) # 440 Hz tone at half amplitude - - # Create test_data directory if it doesn't exist - os.makedirs("api/tests/test_data", exist_ok=True) - - # Save the test audio - np.save("api/tests/test_data/test_audio.npy", audio) - - -if __name__ == "__main__": - generate_test_audio() diff --git a/api/tests/test_data/test_audio.npy b/api/tests/test_data/test_audio.npy deleted file mode 100644 index 2e06aa9..0000000 Binary files a/api/tests/test_data/test_audio.npy and /dev/null differ diff --git a/api/tests/test_development.py b/api/tests/test_development.py deleted file mode 100644 index a03b3ba..0000000 --- a/api/tests/test_development.py +++ /dev/null @@ -1,34 +0,0 @@ -import base64 -import json -from unittest.mock import MagicMock, patch - -import pytest -import requests - - -def test_generate_captioned_speech(): - """Test the generate_captioned_speech function with mocked responses""" - # Mock the API responses - mock_audio_response = MagicMock() - mock_audio_response.status_code = 200 - - mock_timestamps_response = MagicMock() - mock_timestamps_response.status_code = 200 - mock_timestamps_response.content = json.dumps( - { - "audio": base64.b64encode(b"mock audio data").decode("utf-8"), - "timestamps": [{"word": "test", "start_time": 0.0, "end_time": 1.0}], - } - ) - - # Patch the HTTP requests - with patch("requests.post", return_value=mock_timestamps_response): - # Import here to avoid module-level import issues - from examples.captioned_speech_example import generate_captioned_speech - - # Test the function - audio, timestamps = generate_captioned_speech("test text") - - # Verify we got both audio and timestamps - assert audio == b"mock audio data" - assert timestamps == [{"word": "test", "start_time": 0.0, "end_time": 1.0}] diff --git a/api/tests/test_kokoro_v1.py b/api/tests/test_kokoro_v1.py deleted file mode 100644 index 29d83c5..0000000 --- a/api/tests/test_kokoro_v1.py +++ /dev/null @@ -1,165 +0,0 @@ -from unittest.mock import ANY, MagicMock, patch - -import numpy as np -import pytest -import torch - -from api.src.inference.kokoro_v1 import KokoroV1 - - -@pytest.fixture -def kokoro_backend(): - """Create a KokoroV1 instance for testing.""" - return KokoroV1() - - -def test_initial_state(kokoro_backend): - """Test initial state of KokoroV1.""" - assert not kokoro_backend.is_loaded - assert kokoro_backend._model is None - assert kokoro_backend._pipelines == {} # Now using dict of pipelines - # Device should be set based on settings - assert kokoro_backend.device in ["cuda", "cpu"] - - -@patch("torch.cuda.is_available", return_value=True) -@patch("torch.cuda.memory_allocated", return_value=5e9) -def test_memory_management(mock_memory, mock_cuda, kokoro_backend): - """Test GPU memory management functions.""" - # Patch backend so it thinks we have cuda - with patch.object(kokoro_backend, "_device", "cuda"): - # Test memory check - with patch("api.src.inference.kokoro_v1.model_config") as mock_config: - mock_config.pytorch_gpu.memory_threshold = 4 - assert kokoro_backend._check_memory() == True - - mock_config.pytorch_gpu.memory_threshold = 6 - assert kokoro_backend._check_memory() == False - - -@patch("torch.cuda.empty_cache") -@patch("torch.cuda.synchronize") -def test_clear_memory(mock_sync, mock_clear, kokoro_backend): - """Test memory clearing.""" - with patch.object(kokoro_backend, "_device", "cuda"): - kokoro_backend._clear_memory() - 
mock_clear.assert_called_once() - mock_sync.assert_called_once() - - -@pytest.mark.asyncio -async def test_load_model_validation(kokoro_backend): - """Test model loading validation.""" - with pytest.raises(RuntimeError, match="Failed to load Kokoro model"): - await kokoro_backend.load_model("nonexistent_model.pth") - - -def test_unload_with_pipelines(kokoro_backend): - """Test model unloading with multiple pipelines.""" - # Mock loaded state with multiple pipelines - kokoro_backend._model = MagicMock() - pipeline_a = MagicMock() - pipeline_e = MagicMock() - kokoro_backend._pipelines = {"a": pipeline_a, "e": pipeline_e} - assert kokoro_backend.is_loaded - - # Test unload - kokoro_backend.unload() - assert not kokoro_backend.is_loaded - assert kokoro_backend._model is None - assert kokoro_backend._pipelines == {} # All pipelines should be cleared - - -@pytest.mark.asyncio -async def test_generate_validation(kokoro_backend): - """Test generation validation.""" - with pytest.raises(RuntimeError, match="Model not loaded"): - async for _ in kokoro_backend.generate("test", "voice"): - pass - - -@pytest.mark.asyncio -async def test_generate_from_tokens_validation(kokoro_backend): - """Test token generation validation.""" - with pytest.raises(RuntimeError, match="Model not loaded"): - async for _ in kokoro_backend.generate_from_tokens("test tokens", "voice"): - pass - - -def test_get_pipeline_creates_new(kokoro_backend): - """Test that _get_pipeline creates new pipeline for new language code.""" - # Mock loaded state - kokoro_backend._model = MagicMock() - - # Mock KPipeline - mock_pipeline = MagicMock() - with patch( - "api.src.inference.kokoro_v1.KPipeline", return_value=mock_pipeline - ) as mock_kpipeline: - # Get pipeline for Spanish - pipeline_e = kokoro_backend._get_pipeline("e") - - # Should create new pipeline with correct params - mock_kpipeline.assert_called_once_with( - lang_code="e", model=kokoro_backend._model, device=kokoro_backend._device - ) - assert pipeline_e == mock_pipeline - assert kokoro_backend._pipelines["e"] == mock_pipeline - - -def test_get_pipeline_reuses_existing(kokoro_backend): - """Test that _get_pipeline reuses existing pipeline for same language code.""" - # Mock loaded state - kokoro_backend._model = MagicMock() - - # Mock KPipeline - mock_pipeline = MagicMock() - with patch( - "api.src.inference.kokoro_v1.KPipeline", return_value=mock_pipeline - ) as mock_kpipeline: - # Get pipeline twice for same language - pipeline1 = kokoro_backend._get_pipeline("e") - pipeline2 = kokoro_backend._get_pipeline("e") - - # Should only create pipeline once - mock_kpipeline.assert_called_once() - assert pipeline1 == pipeline2 - assert kokoro_backend._pipelines["e"] == mock_pipeline - - -@pytest.mark.asyncio -async def test_generate_uses_correct_pipeline(kokoro_backend): - """Test that generate uses correct pipeline for language code.""" - # Mock loaded state - kokoro_backend._model = MagicMock() - - # Mock voice path handling - with ( - patch("api.src.core.paths.load_voice_tensor") as mock_load_voice, - patch("api.src.core.paths.save_voice_tensor"), - patch("tempfile.gettempdir") as mock_tempdir, - ): - mock_load_voice.return_value = torch.ones(1) - mock_tempdir.return_value = "/tmp" - - # Mock KPipeline - mock_pipeline = MagicMock() - mock_pipeline.return_value = iter([]) # Empty generator for testing - with patch("api.src.inference.kokoro_v1.KPipeline", return_value=mock_pipeline): - # Generate with Spanish voice and explicit lang_code - async for _ in 
kokoro_backend.generate("test", "ef_voice", lang_code="e"): - pass - - # Should create pipeline with Spanish lang_code - assert "e" in kokoro_backend._pipelines - # Use ANY to match the temp file path since it's dynamic - mock_pipeline.assert_called_with( - "test", - voice=ANY, # Don't check exact path since it's dynamic - speed=1.0, - model=kokoro_backend._model, - ) - # Verify the voice path is a temp file path - call_args = mock_pipeline.call_args - assert isinstance(call_args[1]["voice"], str) - assert call_args[1]["voice"].startswith("/tmp/temp_voice_") diff --git a/api/tests/test_normalizer.py b/api/tests/test_normalizer.py deleted file mode 100644 index 3db0801..0000000 --- a/api/tests/test_normalizer.py +++ /dev/null @@ -1,317 +0,0 @@ -"""Tests for text normalization service""" - -import pytest - -from api.src.services.text_processing.normalizer import normalize_text -from api.src.structures.schemas import NormalizationOptions - - -def test_url_protocols(): - """Test URL protocol handling""" - assert ( - normalize_text( - "Check out https://example.com", - normalization_options=NormalizationOptions(), - ) - == "Check out https example dot com" - ) - assert ( - normalize_text( - "Visit http://site.com", normalization_options=NormalizationOptions() - ) - == "Visit http site dot com" - ) - assert ( - normalize_text( - "Go to https://test.org/path", normalization_options=NormalizationOptions() - ) - == "Go to https test dot org slash path" - ) - - -def test_url_www(): - """Test www prefix handling""" - assert ( - normalize_text( - "Go to www.example.com", normalization_options=NormalizationOptions() - ) - == "Go to www example dot com" - ) - assert ( - normalize_text( - "Visit www.test.org/docs", normalization_options=NormalizationOptions() - ) - == "Visit www test dot org slash docs" - ) - assert ( - normalize_text( - "Check www.site.com?q=test", normalization_options=NormalizationOptions() - ) - == "Check www site dot com question-mark q equals test" - ) - - -def test_url_localhost(): - """Test localhost URL handling""" - assert ( - normalize_text( - "Running on localhost:7860", normalization_options=NormalizationOptions() - ) - == "Running on localhost colon seventy-eight sixty" - ) - assert ( - normalize_text( - "Server at localhost:8080/api", normalization_options=NormalizationOptions() - ) - == "Server at localhost colon eighty eighty slash api" - ) - assert ( - normalize_text( - "Test localhost:3000/test?v=1", normalization_options=NormalizationOptions() - ) - == "Test localhost colon three thousand slash test question-mark v equals one" - ) - - -def test_url_ip_addresses(): - """Test IP address URL handling""" - assert ( - normalize_text( - "Access 0.0.0.0:9090/test", normalization_options=NormalizationOptions() - ) - == "Access zero dot zero dot zero dot zero colon ninety ninety slash test" - ) - assert ( - normalize_text( - "API at 192.168.1.1:8000", normalization_options=NormalizationOptions() - ) - == "API at one hundred and ninety-two dot one hundred and sixty-eight dot one dot one colon eight thousand" - ) - assert ( - normalize_text("Server 127.0.0.1", normalization_options=NormalizationOptions()) - == "Server one hundred and twenty-seven dot zero dot zero dot one" - ) - - -def test_url_raw_domains(): - """Test raw domain handling""" - assert ( - normalize_text( - "Visit google.com/search", normalization_options=NormalizationOptions() - ) - == "Visit google dot com slash search" - ) - assert ( - normalize_text( - "Go to example.com/path?q=test", - 
normalization_options=NormalizationOptions(), - ) - == "Go to example dot com slash path question-mark q equals test" - ) - assert ( - normalize_text( - "Check docs.test.com", normalization_options=NormalizationOptions() - ) - == "Check docs dot test dot com" - ) - - -def test_url_email_addresses(): - """Test email address handling""" - assert ( - normalize_text( - "Email me at user@example.com", normalization_options=NormalizationOptions() - ) - == "Email me at user at example dot com" - ) - assert ( - normalize_text( - "Contact admin@test.org", normalization_options=NormalizationOptions() - ) - == "Contact admin at test dot org" - ) - assert ( - normalize_text( - "Send to test.user@site.com", normalization_options=NormalizationOptions() - ) - == "Send to test dot user at site dot com" - ) - - -def test_money(): - """Test that money text is normalized correctly""" - assert ( - normalize_text( - "He lost $5.3 thousand.", normalization_options=NormalizationOptions() - ) - == "He lost five point three thousand dollars." - ) - - assert ( - normalize_text( - "He went gambling and lost about $25.05k.", - normalization_options=NormalizationOptions(), - ) - == "He went gambling and lost about twenty-five point zero five thousand dollars." - ) - - assert ( - normalize_text( - "To put it weirdly -$6.9 million", - normalization_options=NormalizationOptions(), - ) - == "To put it weirdly minus six point nine million dollars" - ) - - assert ( - normalize_text("It costs $50.3.", normalization_options=NormalizationOptions()) - == "It costs fifty dollars and thirty cents." - ) - - assert ( - normalize_text( - "The plant cost $200,000.8.", normalization_options=NormalizationOptions() - ) - == "The plant cost two hundred thousand dollars and eighty cents." - ) - - assert ( - normalize_text( - "€30.2 is in euros", normalization_options=NormalizationOptions() - ) - == "thirty euros and twenty cents is in euros" - ) - - -def test_time(): - """Test time normalization""" - - assert ( - normalize_text( - "Your flight leaves at 10:35 pm", - normalization_options=NormalizationOptions(), - ) - == "Your flight leaves at ten thirty-five pm" - ) - - assert ( - normalize_text( - "He departed for london around 5:03 am.", - normalization_options=NormalizationOptions(), - ) - == "He departed for london around five oh three am." - ) - - assert ( - normalize_text( - "Only the 13:42 and 15:12 slots are available.", - normalization_options=NormalizationOptions(), - ) - == "Only the thirteen forty-two and fifteen twelve slots are available." 
- ) - - assert ( - normalize_text( - "It is currently 1:00 pm", normalization_options=NormalizationOptions() - ) - == "It is currently one pm" - ) - - assert ( - normalize_text( - "It is currently 3:00", normalization_options=NormalizationOptions() - ) - == "It is currently three o'clock" - ) - - assert ( - normalize_text( - "12:00 am is midnight", normalization_options=NormalizationOptions() - ) - == "twelve am is midnight" - ) - - -def test_number(): - """Test number normalization""" - - assert ( - normalize_text( - "I bought 1035 cans of soda", normalization_options=NormalizationOptions() - ) - == "I bought one thousand and thirty-five cans of soda" - ) - - assert ( - normalize_text( - "The bus has a maximum capacity of 62 people", - normalization_options=NormalizationOptions(), - ) - == "The bus has a maximum capacity of sixty-two people" - ) - - assert ( - normalize_text( - "There are 1300 products left in stock", - normalization_options=NormalizationOptions(), - ) - == "There are one thousand, three hundred products left in stock" - ) - - assert ( - normalize_text( - "The population is 7,890,000 people.", - normalization_options=NormalizationOptions(), - ) - == "The population is seven million, eight hundred and ninety thousand people." - ) - - assert ( - normalize_text( - "He looked around but only found 1.6k of the 10k bricks", - normalization_options=NormalizationOptions(), - ) - == "He looked around but only found one point six thousand of the ten thousand bricks" - ) - - assert ( - normalize_text( - "The book has 342 pages.", normalization_options=NormalizationOptions() - ) - == "The book has three hundred and forty-two pages." - ) - - assert ( - normalize_text( - "He made -50 sales today.", normalization_options=NormalizationOptions() - ) - == "He made minus fifty sales today." - ) - - assert ( - normalize_text( - "56.789 to the power of 1.35 million", - normalization_options=NormalizationOptions(), - ) - == "fifty-six point seven eight nine to the power of one point three five million" - ) - - -def test_non_url_text(): - """Test that non-URL text is unaffected""" - assert ( - normalize_text( - "This is not.a.url text", normalization_options=NormalizationOptions() - ) - == "This is not-a-url text" - ) - assert ( - normalize_text( - "Hello, how are you today?", normalization_options=NormalizationOptions() - ) - == "Hello, how are you today?" - ) - assert ( - normalize_text("It costs $50.", normalization_options=NormalizationOptions()) - == "It costs fifty dollars." 
- ) diff --git a/api/tests/test_openai_endpoints.py b/api/tests/test_openai_endpoints.py deleted file mode 100644 index d5c7efc..0000000 --- a/api/tests/test_openai_endpoints.py +++ /dev/null @@ -1,499 +0,0 @@ -import asyncio -import json -import os -from typing import AsyncGenerator, Tuple -from unittest.mock import AsyncMock, MagicMock, patch - -import numpy as np -import pytest -from fastapi.testclient import TestClient - -from api.src.core.config import settings -from api.src.inference.base import AudioChunk -from api.src.main import app -from api.src.routers.openai_compatible import ( - get_tts_service, - load_openai_mappings, - stream_audio_chunks, -) -from api.src.services.streaming_audio_writer import StreamingAudioWriter -from api.src.services.tts_service import TTSService -from api.src.structures.schemas import OpenAISpeechRequest - -client = TestClient(app) - - -@pytest.fixture -def test_voice(): - """Fixture providing a test voice name.""" - return "test_voice" - - -@pytest.fixture -def mock_openai_mappings(): - """Mock OpenAI mappings for testing.""" - with patch( - "api.src.routers.openai_compatible._openai_mappings", - { - "models": {"tts-1": "kokoro-v1_0", "tts-1-hd": "kokoro-v1_0"}, - "voices": {"alloy": "am_adam", "nova": "bf_isabella"}, - }, - ): - yield - - -@pytest.fixture -def mock_json_file(tmp_path): - """Create a temporary mock JSON file.""" - content = { - "models": {"test-model": "test-kokoro"}, - "voices": {"test-voice": "test-internal"}, - } - json_file = tmp_path / "test_mappings.json" - json_file.write_text(json.dumps(content)) - return json_file - - -def test_load_openai_mappings(mock_json_file): - """Test loading OpenAI mappings from JSON file""" - with patch("os.path.join", return_value=str(mock_json_file)): - mappings = load_openai_mappings() - assert "models" in mappings - assert "voices" in mappings - assert mappings["models"]["test-model"] == "test-kokoro" - assert mappings["voices"]["test-voice"] == "test-internal" - - -def test_load_openai_mappings_file_not_found(): - """Test handling of missing mappings file""" - with patch("os.path.join", return_value="/nonexistent/path"): - mappings = load_openai_mappings() - assert mappings == {"models": {}, "voices": {}} - - -def test_list_models(mock_openai_mappings): - """Test listing available models endpoint""" - response = client.get("/v1/models") - assert response.status_code == 200 - data = response.json() - assert data["object"] == "list" - assert isinstance(data["data"], list) - assert len(data["data"]) == 3 # tts-1, tts-1-hd, and kokoro - - # Verify all expected models are present - model_ids = [model["id"] for model in data["data"]] - assert "tts-1" in model_ids - assert "tts-1-hd" in model_ids - assert "kokoro" in model_ids - - # Verify model format - for model in data["data"]: - assert model["object"] == "model" - assert "created" in model - assert model["owned_by"] == "kokoro" - - -def test_retrieve_model(mock_openai_mappings): - """Test retrieving a specific model endpoint""" - # Test successful model retrieval - response = client.get("/v1/models/tts-1") - assert response.status_code == 200 - data = response.json() - assert data["id"] == "tts-1" - assert data["object"] == "model" - assert data["owned_by"] == "kokoro" - assert "created" in data - - # Test non-existent model - response = client.get("/v1/models/nonexistent-model") - assert response.status_code == 404 - error = response.json() - assert error["detail"]["error"] == "model_not_found" - assert "not found" in error["detail"]["message"] - 
assert error["detail"]["type"] == "invalid_request_error" - - -@pytest.mark.asyncio -async def test_get_tts_service_initialization(): - """Test TTSService initialization""" - with patch("api.src.routers.openai_compatible._tts_service", None): - with patch("api.src.routers.openai_compatible._init_lock", None): - with patch("api.src.services.tts_service.TTSService.create") as mock_create: - mock_service = AsyncMock() - mock_create.return_value = mock_service - - # Test concurrent access - async def get_service(): - return await get_tts_service() - - # Create multiple concurrent requests - tasks = [get_service() for _ in range(5)] - results = await asyncio.gather(*tasks) - - # Verify service was created only once - mock_create.assert_called_once() - assert all(r == mock_service for r in results) - - -@pytest.mark.asyncio -async def test_stream_audio_chunks_client_disconnect(): - """Test handling of client disconnect during streaming""" - mock_request = MagicMock() - mock_request.is_disconnected = AsyncMock(return_value=True) - - mock_service = AsyncMock() - - async def mock_stream(*args, **kwargs): - for i in range(5): - yield AudioChunk(np.ndarray([], np.int16), output=b"chunk") - - mock_service.generate_audio_stream = mock_stream - mock_service.list_voices.return_value = ["test_voice"] - - request = OpenAISpeechRequest( - model="kokoro", - input="Test text", - voice="test_voice", - response_format="mp3", - stream=True, - speed=1.0, - ) - - writer = StreamingAudioWriter("mp3", 24000) - - chunks = [] - async for chunk in stream_audio_chunks(mock_service, request, mock_request, writer): - chunks.append(chunk) - - writer.close() - - assert len(chunks) == 0 # Should stop immediately due to disconnect - - -def test_openai_voice_mapping(mock_tts_service, mock_openai_mappings): - """Test OpenAI voice name mapping""" - mock_tts_service.list_voices.return_value = ["am_adam", "bf_isabella"] - - response = client.post( - "/v1/audio/speech", - json={ - "model": "tts-1", - "input": "Hello world", - "voice": "alloy", # OpenAI voice name - "response_format": "mp3", - "stream": False, - }, - ) - assert response.status_code == 200 - mock_tts_service.generate_audio.assert_called_once() - assert mock_tts_service.generate_audio.call_args[1]["voice"] == "am_adam" - - -def test_openai_voice_mapping_streaming( - mock_tts_service, mock_openai_mappings, mock_audio_bytes -): - """Test OpenAI voice mapping in streaming mode""" - mock_tts_service.list_voices.return_value = ["am_adam", "bf_isabella"] - - response = client.post( - "/v1/audio/speech", - json={ - "model": "tts-1-hd", - "input": "Hello world", - "voice": "nova", # OpenAI voice name - "response_format": "mp3", - "stream": True, - }, - ) - assert response.status_code == 200 - content = b"" - for chunk in response.iter_bytes(): - content += chunk - assert content == mock_audio_bytes - - -def test_invalid_openai_model(mock_tts_service, mock_openai_mappings): - """Test error handling for invalid OpenAI model""" - response = client.post( - "/v1/audio/speech", - json={ - "model": "invalid-model", - "input": "Hello world", - "voice": "alloy", - "response_format": "mp3", - "stream": False, - }, - ) - assert response.status_code == 400 - error_response = response.json() - assert error_response["detail"]["error"] == "invalid_model" - assert "Unsupported model" in error_response["detail"]["message"] - - -@pytest.fixture -def mock_audio_bytes(): - """Mock audio bytes for testing.""" - return b"mock audio data" - - -@pytest.fixture -def 
mock_tts_service(mock_audio_bytes): - """Mock TTS service for testing.""" - with patch("api.src.routers.openai_compatible.get_tts_service") as mock_get: - service = AsyncMock(spec=TTSService) - service.generate_audio.return_value = AudioChunk(np.zeros(1000, np.int16)) - - async def mock_stream(*args, **kwargs) -> AsyncGenerator[AudioChunk, None]: - yield AudioChunk(np.ndarray([], np.int16), output=mock_audio_bytes) - - service.generate_audio_stream = mock_stream - service.list_voices.return_value = ["test_voice", "voice1", "voice2"] - service.combine_voices.return_value = "voice1_voice2" - - mock_get.return_value = service - mock_get.side_effect = None - yield service - - -@patch("api.src.services.audio.AudioService.convert_audio") -def test_openai_speech_endpoint( - mock_convert, mock_tts_service, test_voice, mock_audio_bytes -): - """Test the OpenAI-compatible speech endpoint with basic MP3 generation""" - # Configure mocks - mock_tts_service.generate_audio.return_value = AudioChunk(np.zeros(1000, np.int16)) - mock_convert.return_value = AudioChunk( - np.zeros(1000, np.int16), output=mock_audio_bytes - ) - - response = client.post( - "/v1/audio/speech", - json={ - "model": "kokoro", - "input": "Hello world", - "voice": test_voice, - "response_format": "mp3", - "stream": False, - }, - ) - assert response.status_code == 200 - assert response.headers["content-type"] == "audio/mpeg" - assert len(response.content) > 0 - assert response.content == mock_audio_bytes + mock_audio_bytes - - mock_tts_service.generate_audio.assert_called_once() - assert mock_convert.call_count == 2 - - -def test_openai_speech_streaming(mock_tts_service, test_voice, mock_audio_bytes): - """Test the OpenAI-compatible speech endpoint with streaming""" - response = client.post( - "/v1/audio/speech", - json={ - "model": "kokoro", - "input": "Hello world", - "voice": test_voice, - "response_format": "mp3", - "stream": True, - }, - ) - assert response.status_code == 200 - assert response.headers["content-type"] == "audio/mpeg" - assert "Transfer-Encoding" in response.headers - assert response.headers["Transfer-Encoding"] == "chunked" - - content = b"" - for chunk in response.iter_bytes(): - content += chunk - assert content == mock_audio_bytes - - -def test_openai_speech_pcm_streaming(mock_tts_service, test_voice, mock_audio_bytes): - """Test PCM streaming format""" - response = client.post( - "/v1/audio/speech", - json={ - "model": "kokoro", - "input": "Hello world", - "voice": test_voice, - "response_format": "pcm", - "stream": True, - }, - ) - assert response.status_code == 200 - assert response.headers["content-type"] == "audio/pcm" - - content = b"" - for chunk in response.iter_bytes(): - content += chunk - assert content == mock_audio_bytes - - -def test_openai_speech_invalid_voice(mock_tts_service): - """Test error handling for invalid voice""" - mock_tts_service.generate_audio.side_effect = ValueError( - "Voice 'invalid_voice' not found" - ) - - response = client.post( - "/v1/audio/speech", - json={ - "model": "kokoro", - "input": "Hello world", - "voice": "invalid_voice", - "response_format": "mp3", - "stream": False, - }, - ) - assert response.status_code == 400 - error_response = response.json() - assert error_response["detail"]["error"] == "validation_error" - assert "Voice 'invalid_voice' not found" in error_response["detail"]["message"] - assert error_response["detail"]["type"] == "invalid_request_error" - - -def test_openai_speech_empty_text(mock_tts_service, test_voice): - """Test error handling for empty 
text""" - - async def mock_error_stream(*args, **kwargs): - raise ValueError("Text is empty after preprocessing") - - mock_tts_service.generate_audio = mock_error_stream - mock_tts_service.list_voices.return_value = ["test_voice"] - - response = client.post( - "/v1/audio/speech", - json={ - "model": "kokoro", - "input": "", - "voice": test_voice, - "response_format": "mp3", - "stream": False, - }, - ) - assert response.status_code == 400 - error_response = response.json() - assert error_response["detail"]["error"] == "validation_error" - assert "Text is empty after preprocessing" in error_response["detail"]["message"] - assert error_response["detail"]["type"] == "invalid_request_error" - - -def test_openai_speech_invalid_format(mock_tts_service, test_voice): - """Test error handling for invalid format""" - response = client.post( - "/v1/audio/speech", - json={ - "model": "kokoro", - "input": "Hello world", - "voice": test_voice, - "response_format": "invalid_format", - "stream": False, - }, - ) - assert response.status_code == 422 # Validation error from Pydantic - - -def test_list_voices(mock_tts_service): - """Test listing available voices""" - # Override the mock for this specific test - mock_tts_service.list_voices.return_value = ["voice1", "voice2"] - - response = client.get("/v1/audio/voices") - assert response.status_code == 200 - data = response.json() - assert "voices" in data - assert len(data["voices"]) == 2 - assert "voice1" in data["voices"] - assert "voice2" in data["voices"] - - -@patch("api.src.routers.openai_compatible.settings") -def test_combine_voices(mock_settings, mock_tts_service): - """Test combining voices endpoint""" - # Enable local voice saving for this test - mock_settings.allow_local_voice_saving = True - - response = client.post("/v1/audio/voices/combine", json="voice1+voice2") - assert response.status_code == 200 - assert response.headers["content-type"] == "application/octet-stream" - assert "voice1+voice2.pt" in response.headers["content-disposition"] - - -def test_server_error(mock_tts_service, test_voice): - """Test handling of server errors""" - - async def mock_error_stream(*args, **kwargs): - raise RuntimeError("Internal server error") - - mock_tts_service.generate_audio = mock_error_stream - mock_tts_service.list_voices.return_value = ["test_voice"] - - response = client.post( - "/v1/audio/speech", - json={ - "model": "kokoro", - "input": "Hello world", - "voice": test_voice, - "response_format": "mp3", - "stream": False, - }, - ) - assert response.status_code == 500 - error_response = response.json() - assert error_response["detail"]["error"] == "processing_error" - assert error_response["detail"]["type"] == "server_error" - - -def test_streaming_error(mock_tts_service, test_voice): - """Test handling streaming errors""" - # Mock process_voices to raise the error - mock_tts_service.list_voices.side_effect = RuntimeError("Streaming failed") - - response = client.post( - "/v1/audio/speech", - json={ - "model": "kokoro", - "input": "Hello world", - "voice": test_voice, - "response_format": "mp3", - "stream": True, - }, - ) - - assert response.status_code == 500 - error_data = response.json() - assert error_data["detail"]["error"] == "processing_error" - assert error_data["detail"]["type"] == "server_error" - assert "Streaming failed" in error_data["detail"]["message"] - - -@pytest.mark.asyncio -async def test_streaming_initialization_error(): - """Test handling of streaming initialization errors""" - mock_service = AsyncMock() - - async def 
mock_error_stream(*args, **kwargs): - if False: # This makes it a proper generator - yield b"" - raise RuntimeError("Failed to initialize stream") - - mock_service.generate_audio_stream = mock_error_stream - mock_service.list_voices.return_value = ["test_voice"] - - request = OpenAISpeechRequest( - model="kokoro", - input="Test text", - voice="test_voice", - response_format="mp3", - stream=True, - speed=1.0, - ) - - writer = StreamingAudioWriter("mp3", 24000) - - with pytest.raises(RuntimeError) as exc: - async for _ in stream_audio_chunks(mock_service, request, MagicMock(), writer): - pass - - writer.close() - assert "Failed to initialize stream" in str(exc.value) diff --git a/api/tests/test_paths.py b/api/tests/test_paths.py deleted file mode 100644 index 715934e..0000000 --- a/api/tests/test_paths.py +++ /dev/null @@ -1,138 +0,0 @@ -import os -from unittest.mock import patch - -import pytest - -from api.src.core.paths import ( - _find_file, - _scan_directories, - get_content_type, - get_temp_dir_size, - get_temp_file_path, - list_temp_files, -) - - -@pytest.mark.asyncio -async def test_find_file_exists(): - """Test finding existing file.""" - with patch("aiofiles.os.path.exists") as mock_exists: - mock_exists.return_value = True - path = await _find_file("test.txt", ["/test/path"]) - assert path == "/test/path/test.txt" - - -@pytest.mark.asyncio -async def test_find_file_not_exists(): - """Test finding non-existent file.""" - with patch("aiofiles.os.path.exists") as mock_exists: - mock_exists.return_value = False - with pytest.raises(FileNotFoundError, match="File not found"): - await _find_file("test.txt", ["/test/path"]) - - -@pytest.mark.asyncio -async def test_find_file_with_filter(): - """Test finding file with filter function.""" - with patch("aiofiles.os.path.exists") as mock_exists: - mock_exists.return_value = True - filter_fn = lambda p: p.endswith(".txt") - path = await _find_file("test.txt", ["/test/path"], filter_fn) - assert path == "/test/path/test.txt" - - -@pytest.mark.asyncio -async def test_scan_directories(): - """Test scanning directories.""" - mock_entry = type("MockEntry", (), {"name": "test.txt"})() - - with ( - patch("aiofiles.os.path.exists") as mock_exists, - patch("aiofiles.os.scandir") as mock_scandir, - ): - mock_exists.return_value = True - mock_scandir.return_value = [mock_entry] - - files = await _scan_directories(["/test/path"]) - assert "test.txt" in files - - -@pytest.mark.asyncio -async def test_get_content_type(): - """Test content type detection.""" - test_cases = [ - ("test.html", "text/html"), - ("test.js", "application/javascript"), - ("test.css", "text/css"), - ("test.png", "image/png"), - ("test.unknown", "application/octet-stream"), - ] - - for filename, expected in test_cases: - content_type = await get_content_type(filename) - assert content_type == expected - - -@pytest.mark.asyncio -async def test_get_temp_file_path(): - """Test temp file path generation.""" - with ( - patch("aiofiles.os.path.exists") as mock_exists, - patch("aiofiles.os.makedirs") as mock_makedirs, - ): - mock_exists.return_value = False - - path = await get_temp_file_path("test.wav") - assert "test.wav" in path - mock_makedirs.assert_called_once() - - -@pytest.mark.asyncio -async def test_list_temp_files(): - """Test listing temp files.""" - - class MockEntry: - def __init__(self, name): - self.name = name - - def is_file(self): - return True - - mock_entry = MockEntry("test.wav") - - with ( - patch("aiofiles.os.path.exists") as mock_exists, - 
patch("aiofiles.os.scandir") as mock_scandir, - ): - mock_exists.return_value = True - mock_scandir.return_value = [mock_entry] - - files = await list_temp_files() - assert "test.wav" in files - - -@pytest.mark.asyncio -async def test_get_temp_dir_size(): - """Test getting temp directory size.""" - - class MockEntry: - def __init__(self, path): - self.path = path - - def is_file(self): - return True - - mock_entry = MockEntry("/tmp/test.wav") - mock_stat = type("MockStat", (), {"st_size": 1024})() - - with ( - patch("aiofiles.os.path.exists") as mock_exists, - patch("aiofiles.os.scandir") as mock_scandir, - patch("aiofiles.os.stat") as mock_stat_fn, - ): - mock_exists.return_value = True - mock_scandir.return_value = [mock_entry] - mock_stat_fn.return_value = mock_stat - - size = await get_temp_dir_size() - assert size == 1024 diff --git a/api/tests/test_text_processor.py b/api/tests/test_text_processor.py deleted file mode 100644 index 6ff8282..0000000 --- a/api/tests/test_text_processor.py +++ /dev/null @@ -1,167 +0,0 @@ -import pytest - -from api.src.services.text_processing.text_processor import ( - get_sentence_info, - process_text_chunk, - smart_split, -) - - -def test_process_text_chunk_basic(): - """Test basic text chunk processing.""" - text = "Hello world" - tokens = process_text_chunk(text) - assert isinstance(tokens, list) - assert len(tokens) > 0 - - -def test_process_text_chunk_empty(): - """Test processing empty text.""" - text = "" - tokens = process_text_chunk(text) - assert isinstance(tokens, list) - assert len(tokens) == 0 - - -def test_process_text_chunk_phonemes(): - """Test processing with skip_phonemize.""" - phonemes = "h @ l @U" # Example phoneme sequence - tokens = process_text_chunk(phonemes, skip_phonemize=True) - assert isinstance(tokens, list) - assert len(tokens) > 0 - - -def test_get_sentence_info(): - """Test sentence splitting and info extraction.""" - text = "This is sentence one. This is sentence two! What about three?" - results = get_sentence_info(text, {}) - - assert len(results) == 3 - for sentence, tokens, count in results: - assert isinstance(sentence, str) - assert isinstance(tokens, list) - assert isinstance(count, int) - assert count == len(tokens) - assert count > 0 - - -def test_get_sentence_info_phenomoes(): - """Test sentence splitting and info extraction.""" - text = ( - "This is sentence one. This is two! What about three?" - ) - results = get_sentence_info(text, {"": r"sˈɛntᵊns"}) - - assert len(results) == 3 - assert "sˈɛntᵊns" in results[1][0] - for sentence, tokens, count in results: - assert isinstance(sentence, str) - assert isinstance(tokens, list) - assert isinstance(count, int) - assert count == len(tokens) - assert count > 0 - - -@pytest.mark.asyncio -async def test_smart_split_short_text(): - """Test smart splitting with text under max tokens.""" - text = "This is a short test sentence." - chunks = [] - async for chunk_text, chunk_tokens in smart_split(text): - chunks.append((chunk_text, chunk_tokens)) - - assert len(chunks) == 1 - assert isinstance(chunks[0][0], str) - assert isinstance(chunks[0][1], list) - - -@pytest.mark.asyncio -async def test_smart_split_long_text(): - """Test smart splitting with longer text.""" - # Create text that should split into multiple chunks - text = ". 
".join(["This is test sentence number " + str(i) for i in range(20)]) - - chunks = [] - async for chunk_text, chunk_tokens in smart_split(text): - chunks.append((chunk_text, chunk_tokens)) - - assert len(chunks) > 1 - for chunk_text, chunk_tokens in chunks: - assert isinstance(chunk_text, str) - assert isinstance(chunk_tokens, list) - assert len(chunk_tokens) > 0 - - -@pytest.mark.asyncio -async def test_smart_split_with_punctuation(): - """Test smart splitting handles punctuation correctly.""" - text = "First sentence! Second sentence? Third sentence; Fourth sentence: Fifth sentence." - - chunks = [] - async for chunk_text, chunk_tokens in smart_split(text): - chunks.append(chunk_text) - - # Verify punctuation is preserved - assert all(any(p in chunk for p in "!?;:.") for chunk in chunks) - -def test_process_text_chunk_chinese_phonemes(): - """Test processing with Chinese pinyin phonemes.""" - pinyin = "nǐ hǎo lì" # Example pinyin sequence with tones - tokens = process_text_chunk(pinyin, skip_phonemize=True, language="z") - assert isinstance(tokens, list) - assert len(tokens) > 0 - - -def test_get_sentence_info_chinese(): - """Test Chinese sentence splitting and info extraction.""" - text = "这是一个句子。这是第二个句子!第三个问题?" - results = get_sentence_info(text, {}, lang_code="z") - - assert len(results) == 3 - for sentence, tokens, count in results: - assert isinstance(sentence, str) - assert isinstance(tokens, list) - assert isinstance(count, int) - assert count == len(tokens) - assert count > 0 - -@pytest.mark.asyncio -async def test_smart_split_chinese_short(): - """Test Chinese smart splitting with short text.""" - text = "这是一句话。" - chunks = [] - async for chunk_text, chunk_tokens in smart_split(text, lang_code="z"): - chunks.append((chunk_text, chunk_tokens)) - - assert len(chunks) == 1 - assert isinstance(chunks[0][0], str) - assert isinstance(chunks[0][1], list) - - -@pytest.mark.asyncio -async def test_smart_split_chinese_long(): - """Test Chinese smart splitting with longer text.""" - text = "。".join([f"测试句子 {i}" for i in range(20)]) - - chunks = [] - async for chunk_text, chunk_tokens in smart_split(text, lang_code="z"): - chunks.append((chunk_text, chunk_tokens)) - - assert len(chunks) > 1 - for chunk_text, chunk_tokens in chunks: - assert isinstance(chunk_text, str) - assert isinstance(chunk_tokens, list) - assert len(chunk_tokens) > 0 - - -@pytest.mark.asyncio -async def test_smart_split_chinese_punctuation(): - """Test Chinese smart splitting with punctuation preservation.""" - text = "第一句!第二问?第三句;第四句:第五句。" - - chunks = [] - async for chunk_text, _ in smart_split(text, lang_code="z"): - chunks.append(chunk_text) - - # Verify Chinese punctuation is preserved - assert all(any(p in chunk for p in "!?;:。") for chunk in chunks) \ No newline at end of file diff --git a/api/tests/test_tts_service.py b/api/tests/test_tts_service.py deleted file mode 100644 index ae8447a..0000000 --- a/api/tests/test_tts_service.py +++ /dev/null @@ -1,126 +0,0 @@ -from unittest.mock import AsyncMock, MagicMock, patch - -import numpy as np -import pytest -import torch - -from api.src.services.tts_service import TTSService - - -@pytest.fixture -def mock_managers(): - """Mock model and voice managers.""" - - async def _mock_managers(): - model_manager = AsyncMock() - model_manager.get_backend.return_value = MagicMock() - - voice_manager = AsyncMock() - voice_manager.get_voice_path.return_value = "/path/to/voice.pt" - voice_manager.list_voices.return_value = ["voice1", "voice2"] - - with ( - 
patch("api.src.services.tts_service.get_model_manager") as mock_get_model, - patch("api.src.services.tts_service.get_voice_manager") as mock_get_voice, - ): - mock_get_model.return_value = model_manager - mock_get_voice.return_value = voice_manager - return model_manager, voice_manager - - return _mock_managers() - - -@pytest.fixture -def tts_service(mock_managers): - """Create TTSService instance with mocked dependencies.""" - - async def _create_service(): - return await TTSService.create("test_output") - - return _create_service() - - -@pytest.mark.asyncio -async def test_service_creation(): - """Test service creation and initialization.""" - model_manager = AsyncMock() - voice_manager = AsyncMock() - - with ( - patch("api.src.services.tts_service.get_model_manager") as mock_get_model, - patch("api.src.services.tts_service.get_voice_manager") as mock_get_voice, - ): - mock_get_model.return_value = model_manager - mock_get_voice.return_value = voice_manager - - service = await TTSService.create("test_output") - assert service.output_dir == "test_output" - assert service.model_manager is model_manager - assert service._voice_manager is voice_manager - - -@pytest.mark.asyncio -async def test_get_voice_path_single(): - """Test getting path for single voice.""" - model_manager = AsyncMock() - voice_manager = AsyncMock() - voice_manager.get_voice_path.return_value = "/path/to/voice1.pt" - - with ( - patch("api.src.services.tts_service.get_model_manager") as mock_get_model, - patch("api.src.services.tts_service.get_voice_manager") as mock_get_voice, - ): - mock_get_model.return_value = model_manager - mock_get_voice.return_value = voice_manager - - service = await TTSService.create("test_output") - name, path = await service._get_voices_path("voice1") - assert name == "voice1" - assert path == "/path/to/voice1.pt" - voice_manager.get_voice_path.assert_called_once_with("voice1") - - -@pytest.mark.asyncio -async def test_get_voice_path_combined(): - """Test getting path for combined voices.""" - model_manager = AsyncMock() - voice_manager = AsyncMock() - voice_manager.get_voice_path.return_value = "/path/to/voice.pt" - - with ( - patch("api.src.services.tts_service.get_model_manager") as mock_get_model, - patch("api.src.services.tts_service.get_voice_manager") as mock_get_voice, - patch("torch.load") as mock_load, - patch("torch.save") as mock_save, - patch("tempfile.gettempdir") as mock_temp, - ): - mock_get_model.return_value = model_manager - mock_get_voice.return_value = voice_manager - mock_temp.return_value = "/tmp" - mock_load.return_value = torch.ones(10) - - service = await TTSService.create("test_output") - name, path = await service._get_voices_path("voice1+voice2") - assert name == "voice1+voice2" - assert path.endswith("voice1+voice2.pt") - mock_save.assert_called_once() - - -@pytest.mark.asyncio -async def test_list_voices(): - """Test listing available voices.""" - model_manager = AsyncMock() - voice_manager = AsyncMock() - voice_manager.list_voices.return_value = ["voice1", "voice2"] - - with ( - patch("api.src.services.tts_service.get_model_manager") as mock_get_model, - patch("api.src.services.tts_service.get_voice_manager") as mock_get_voice, - ): - mock_get_model.return_value = model_manager - mock_get_voice.return_value = voice_manager - - service = await TTSService.create("test_output") - voices = await service.list_voices() - assert voices == ["voice1", "voice2"] - voice_manager.list_voices.assert_called_once() diff --git a/docker/gpu/Dockerfile b/docker/gpu/Dockerfile 
index 44c1ba7..1f9d58e 100644
--- a/docker/gpu/Dockerfile
+++ b/docker/gpu/Dockerfile
@@ -1,67 +1,59 @@
-FROM --platform=$BUILDPLATFORM nvidia/cuda:12.8.0-cudnn-runtime-ubuntu24.04
+FROM --platform=$BUILDPLATFORM nvidia/cuda:12.9.0-base-ubuntu24.04
 
 # Set non-interactive frontend
 ENV DEBIAN_FRONTEND=noninteractive
 
 # Install Python and other dependencies
-RUN apt-get update && apt-get install -y \
-    python3.10 \
-    python3-venv \
-    espeak-ng \
-    espeak-ng-data \
-    git \
-    libsndfile1 \
-    curl \
-    ffmpeg \
-    g++ \
-    && apt-get clean && rm -rf /var/lib/apt/lists/* \
-    && mkdir -p /usr/share/espeak-ng-data \
-    && ln -s /usr/lib/*/espeak-ng-data/* /usr/share/espeak-ng-data/
-
-# Install UV using the installer script
-RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
+RUN apt-get update -y && \
+    apt-get install -y python3 python3-venv libsndfile1 curl ffmpeg g++ && \
+    apt-get clean && rm -rf /var/lib/apt/lists/* && \
+    curl -LsSf https://astral.sh/uv/install.sh | sh && \
     mv /root/.local/bin/uv /usr/local/bin/ && \
-    mv /root/.local/bin/uvx /usr/local/bin/
-
-# Create non-root user and set up directories and permissions
-RUN useradd -m -u 1001 appuser && \
+    mv /root/.local/bin/uvx /usr/local/bin/ && \
+    useradd -m -u 1001 appuser && \
     mkdir -p /app/api/src/models/v1_0 && \
     chown -R appuser:appuser /app
 
 USER appuser
 WORKDIR /app
 
-# Copy dependency files
-COPY --chown=appuser:appuser pyproject.toml ./pyproject.toml
-
 ENV PHONEMIZER_ESPEAK_PATH=/usr/bin \
     PHONEMIZER_ESPEAK_DATA=/usr/share/espeak-ng-data \
-    ESPEAK_DATA_PATH=/usr/share/espeak-ng-data
+    PYTHONUNBUFFERED=1 \
+    PYTHONPATH=/app:/app/api \
+    PATH="/app/.venv/bin:$PATH" \
+    UV_LINK_MODE=copy \
+    USE_GPU=true \
+    DEVICE="gpu"
 
-# Install dependencies with GPU extras (using cache mounts)
-RUN --mount=type=cache,target=/root/.cache/uv \
-    uv venv --python 3.10 && \
-    uv sync --extra gpu
-
-# Copy project files including models
+COPY --chown=appuser:appuser pyproject.toml ./pyproject.toml
 COPY --chown=appuser:appuser api ./api
 COPY --chown=appuser:appuser web ./web
 COPY --chown=appuser:appuser docker/scripts/ ./
 RUN chmod +x ./entrypoint.sh
+
+RUN --mount=type=cache,target=/root/.cache/uv \
+    uv venv --python 3.10 && \
+    uv sync --extra gpu && \
+    uv cache clean && \
+    python download_model.py --output api/src/models/v1_0
 
-
-# Set all environment variables in one go
-ENV PYTHONUNBUFFERED=1 \
-    PYTHONPATH=/app:/app/api \
-    PATH="/app/.venv/bin:$PATH" \
-    UV_LINK_MODE=copy \
-    USE_GPU=true
-
-ENV DOWNLOAD_MODEL=true
-# Download model if enabled
-RUN if [ "$DOWNLOAD_MODEL" = "true" ]; then \
-    python download_model.py --output api/src/models/v1_0; \
-    fi
-
-ENV DEVICE="gpu"
 # Run FastAPI server through entrypoint.sh
 CMD ["./entrypoint.sh"]
+
+
+
+# If you want to test the docker image locally, run this from the project root:
+# docker build -f docker/gpu/Dockerfile -t kokoro .
+# Run it with
+# docker run --gpus all -p 8880:8880 --name kokoro kokoro
+#
+# You can log into the container with
+# docker exec -it kokoro /bin/bash
+#
+# Other commands:
+# 1. Stop and remove container
+# docker stop kokoro
+# docker container remove kokoro
+# 2. List and remove images
+# docker images
+# docker image remove kokoro
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 5d082f7..e983575 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,10 +31,9 @@ dependencies = [
     "matplotlib>=3.10.0",
     "mutagen>=1.47.0",
     "psutil>=6.1.1",
-    "espeakng-loader==0.2.4",
     "kokoro==0.9.2",
     "misaki[en,ja,ko,zh]==0.9.3",
-    "spacy==3.8.5",
+    "spacy==3.8.7",
     "en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl",
     "inflect>=7.5.0",
     "phonemizer-fork>=3.3.2",