diff --git a/api/depr_tests/test_endpoints.py b/api/depr_tests/test_endpoints.py deleted file mode 100644 index b5af29e..0000000 --- a/api/depr_tests/test_endpoints.py +++ /dev/null @@ -1,206 +0,0 @@ -"""Tests for API endpoints""" -import pytest -import torch -from fastapi.testclient import TestClient - -from ..src.main import app - -# Create test client for non-async tests -client = TestClient(app) - - -def test_health_check(): - """Test the health check endpoint""" - response = client.get("/health") - assert response.status_code == 200 - assert response.json() == {"status": "healthy"} - - -@pytest.mark.asyncio -async def test_openai_speech_endpoint(async_client, mock_tts_service): - """Test the OpenAI-compatible speech endpoint""" - # Setup mocks - mock_tts_service._voice_manager.list_voices.return_value = ["bm_lewis"] - mock_tts_service.generate_audio.return_value = (torch.zeros(48000).numpy(), 1.0) - mock_tts_service._voice_manager.load_voice.return_value = torch.zeros(192) - - # Mock voice validation - mock_tts_service._voice_manager.get_voice_path.return_value = "/mock/voices/bm_lewis.pt" - - test_request = { - "model": "kokoro", - "input": "Hello world", - "voice": "bm_lewis", - "response_format": "wav", - "speed": 1.0, - "stream": False, - } - response = await async_client.post("/v1/audio/speech", json=test_request) - assert response.status_code == 200 - assert response.headers["content-type"] == "audio/wav" - assert response.headers["content-disposition"] == "attachment; filename=speech.wav" - mock_tts_service.generate_audio.assert_called_once() - - -@pytest.mark.asyncio -async def test_openai_speech_invalid_voice(async_client, mock_tts_service): - """Test the OpenAI-compatible speech endpoint with invalid voice""" - # Setup mocks - mock_tts_service._voice_manager.list_voices.return_value = ["af", "bm_lewis"] - mock_tts_service._voice_manager.get_voice_path.return_value = None - - test_request = { - "model": "kokoro", - "input": "Hello world", - "voice": "invalid_voice", - "response_format": "wav", - "speed": 1.0, - "stream": False, - } - response = await async_client.post("/v1/audio/speech", json=test_request) - assert response.status_code == 400 - assert "not found" in response.json()["detail"]["message"] - - -@pytest.mark.asyncio -async def test_openai_speech_generation_error(async_client, mock_tts_service): - """Test error handling in speech generation""" - # Setup mocks - mock_tts_service._voice_manager.list_voices.return_value = ["af"] - mock_tts_service.generate_audio.side_effect = RuntimeError("Generation failed") - mock_tts_service._voice_manager.load_voice.return_value = torch.zeros(192) - mock_tts_service._voice_manager.get_voice_path.return_value = "/mock/voices/af.pt" - - test_request = { - "model": "kokoro", - "input": "Hello world", - "voice": "af", - "response_format": "wav", - "speed": 1.0, - "stream": False, - } - response = await async_client.post("/v1/audio/speech", json=test_request) - assert response.status_code == 500 - assert "Generation failed" in response.json()["detail"]["message"] - - -@pytest.mark.asyncio -async def test_combine_voices_list_success(async_client, mock_tts_service): - """Test successful voice combination using list format""" - # Setup mocks - mock_tts_service._voice_manager.list_voices.return_value = ["af_bella", "af_sarah"] - mock_tts_service._voice_manager.combine_voices.return_value = "af_bella_af_sarah" - mock_tts_service._voice_manager.load_voice.return_value = torch.zeros(192) - 
mock_tts_service._voice_manager.get_voice_path.return_value = "/mock/voices/af_bella.pt" - - test_voices = ["af_bella", "af_sarah"] - response = await async_client.post("/v1/audio/voices/combine", json=test_voices) - - assert response.status_code == 200 - assert response.json()["voice"] == "af_bella_af_sarah" - mock_tts_service._voice_manager.combine_voices.assert_called_once() - - -@pytest.mark.asyncio -async def test_combine_voices_empty_list(async_client, mock_tts_service): - """Test combining empty voice list returns error""" - test_voices = [] - response = await async_client.post("/v1/audio/voices/combine", json=test_voices) - assert response.status_code == 400 - assert "No voices provided" in response.json()["detail"]["message"] - - -@pytest.mark.asyncio -async def test_speech_streaming_with_combined_voice(async_client, mock_tts_service): - """Test streaming speech with combined voice using + syntax""" - # Setup mocks - mock_tts_service._voice_manager.list_voices.return_value = ["af_bella", "af_sarah"] - mock_tts_service._voice_manager.combine_voices.return_value = "af_bella_af_sarah" - mock_tts_service._voice_manager.load_voice.return_value = torch.zeros(192) - mock_tts_service._voice_manager.get_voice_path.return_value = "/mock/voices/af_bella.pt" - - async def mock_stream(): - yield b"chunk1" - yield b"chunk2" - - mock_tts_service.generate_audio_stream.return_value = mock_stream() - - test_request = { - "model": "kokoro", - "input": "Hello world", - "voice": "af_bella+af_sarah", - "response_format": "mp3", - "stream": True, - } - - headers = {"x-raw-response": "stream"} - response = await async_client.post( - "/v1/audio/speech", json=test_request, headers=headers - ) - - assert response.status_code == 200 - assert response.headers["content-type"] == "audio/mpeg" - assert response.headers["content-disposition"] == "attachment; filename=speech.mp3" - - -@pytest.mark.asyncio -async def test_openai_speech_pcm_streaming(async_client, mock_tts_service): - """Test streaming PCM audio for real-time playback""" - # Setup mocks - mock_tts_service._voice_manager.list_voices.return_value = ["af"] - mock_tts_service._voice_manager.load_voice.return_value = torch.zeros(192) - mock_tts_service._voice_manager.get_voice_path.return_value = "/mock/voices/af.pt" - - async def mock_stream(): - yield b"chunk1" - yield b"chunk2" - - mock_tts_service.generate_audio_stream.return_value = mock_stream() - - test_request = { - "model": "kokoro", - "input": "Hello world", - "voice": "af", - "response_format": "pcm", - "stream": True, - } - - headers = {"x-raw-response": "stream"} - response = await async_client.post( - "/v1/audio/speech", json=test_request, headers=headers - ) - - assert response.status_code == 200 - assert response.headers["content-type"] == "audio/pcm" - - -@pytest.mark.asyncio -async def test_openai_speech_streaming_mp3(async_client, mock_tts_service): - """Test streaming MP3 audio to file""" - # Setup mocks - mock_tts_service._voice_manager.list_voices.return_value = ["af"] - mock_tts_service._voice_manager.load_voice.return_value = torch.zeros(192) - mock_tts_service._voice_manager.get_voice_path.return_value = "/mock/voices/af.pt" - - async def mock_stream(): - yield b"chunk1" - yield b"chunk2" - - mock_tts_service.generate_audio_stream.return_value = mock_stream() - - test_request = { - "model": "kokoro", - "input": "Hello world", - "voice": "af", - "response_format": "mp3", - "stream": True, - } - - headers = {"x-raw-response": "stream"} - response = await async_client.post( - 
"/v1/audio/speech", json=test_request, headers=headers - ) - - assert response.status_code == 200 - assert response.headers["content-type"] == "audio/mpeg" - assert response.headers["content-disposition"] == "attachment; filename=speech.mp3" diff --git a/api/depr_tests/test_main.py b/api/depr_tests/test_main.py deleted file mode 100644 index dd5ac12..0000000 --- a/api/depr_tests/test_main.py +++ /dev/null @@ -1,104 +0,0 @@ -"""Tests for FastAPI application""" -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -import torch -from fastapi.testclient import TestClient - -from api.src.main import app, lifespan - - -@pytest.fixture -def test_client(): - """Create a test client""" - return TestClient(app) - - -def test_health_check(test_client): - """Test health check endpoint""" - response = test_client.get("/health") - assert response.status_code == 200 - assert response.json() == {"status": "healthy"} - - -@pytest.mark.asyncio -async def test_lifespan_successful_warmup(): - """Test successful model warmup in lifespan""" - with patch("api.src.inference.model_manager.get_manager") as mock_get_model, \ - patch("api.src.inference.voice_manager.get_manager") as mock_get_voice, \ - patch("api.src.main.logger") as mock_logger, \ - patch("os.path.exists") as mock_exists, \ - patch("torch.cuda.is_available") as mock_cuda: - - # Setup mocks - mock_model = AsyncMock() - mock_voice = AsyncMock() - mock_get_model.return_value = mock_model - mock_get_voice.return_value = mock_voice - mock_exists.return_value = True - mock_cuda.return_value = False - - # Setup model manager - mock_backend = MagicMock() - mock_backend.device = "cpu" - mock_model.get_backend.return_value = mock_backend - mock_model.load_model = AsyncMock() - - # Setup voice manager - mock_voice_tensor = torch.zeros(192) - mock_voice.load_voice = AsyncMock(return_value=mock_voice_tensor) - mock_voice.list_voices = AsyncMock(return_value=["af", "af_bella", "af_sarah"]) - - # Create an async generator from the lifespan context manager - async_gen = lifespan(MagicMock()) - - # Start the context manager - await async_gen.__aenter__() - - # Verify managers were initialized - mock_get_model.assert_called_once() - mock_get_voice.assert_called_once() - mock_model.load_model.assert_called_once() - - # Clean up - await async_gen.__aexit__(None, None, None) - - -@pytest.mark.asyncio -async def test_lifespan_failed_warmup(): - """Test failed model warmup in lifespan""" - with patch("api.src.inference.model_manager.get_manager") as mock_get_model: - # Mock the model manager to fail - mock_get_model.side_effect = RuntimeError("Failed to initialize model") - - # Create an async generator from the lifespan context manager - async_gen = lifespan(MagicMock()) - - # Verify the exception is raised - with pytest.raises(RuntimeError, match="Failed to initialize model"): - await async_gen.__aenter__() - - # Clean up - await async_gen.__aexit__(None, None, None) - - -@pytest.mark.asyncio -async def test_lifespan_voice_manager_failure(): - """Test failure when voice manager fails to initialize""" - with patch("api.src.inference.model_manager.get_manager") as mock_get_model, \ - patch("api.src.inference.voice_manager.get_manager") as mock_get_voice: - - # Setup model manager success but voice manager failure - mock_model = AsyncMock() - mock_get_model.return_value = mock_model - mock_get_voice.side_effect = RuntimeError("Failed to initialize voice manager") - - # Create an async generator from the lifespan context manager - async_gen = 
lifespan(MagicMock()) - - # Verify the exception is raised - with pytest.raises(RuntimeError, match="Failed to initialize voice manager"): - await async_gen.__aenter__() - - # Clean up - await async_gen.__aexit__(None, None, None) diff --git a/api/depr_tests/test_managers.py b/api/depr_tests/test_managers.py deleted file mode 100644 index 64bb8c6..0000000 --- a/api/depr_tests/test_managers.py +++ /dev/null @@ -1,190 +0,0 @@ -"""Tests for model and voice managers""" -import os -import numpy as np -import pytest -import torch -from unittest.mock import AsyncMock, MagicMock, Mock, patch - -from api.src.inference.model_manager import get_manager as get_model_manager -from api.src.inference.voice_manager import get_manager as get_voice_manager - -# Get project root path -PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -MOCK_VOICES_DIR = os.path.join(PROJECT_ROOT, "api", "src", "voices") -MOCK_MODEL_DIR = os.path.join(PROJECT_ROOT, "api", "src", "models") - - -@pytest.mark.asyncio -async def test_model_manager_initialization(): - """Test model manager initialization""" - with patch("api.src.inference.model_manager.settings") as mock_settings, \ - patch("api.src.core.paths.get_model_path") as mock_get_path: - - mock_settings.model_dir = MOCK_MODEL_DIR - mock_settings.onnx_model_path = "model.onnx" - mock_get_path.return_value = os.path.join(MOCK_MODEL_DIR, "model.onnx") - - manager = await get_model_manager() - assert manager is not None - backend = manager.get_backend() - assert backend is not None - - -@pytest.mark.asyncio -async def test_model_manager_generate(): - """Test model generation""" - with patch("api.src.inference.model_manager.settings") as mock_settings, \ - patch("api.src.core.paths.get_model_path") as mock_get_path, \ - patch("torch.load") as mock_torch_load: - - mock_settings.model_dir = MOCK_MODEL_DIR - mock_settings.onnx_model_path = "model.onnx" - mock_settings.use_onnx = True - mock_settings.use_gpu = False - mock_get_path.return_value = os.path.join(MOCK_MODEL_DIR, "model.onnx") - - # Mock torch load to return a tensor - mock_torch_load.return_value = torch.zeros(192) - - manager = await get_model_manager() - - # Set up mock backend - mock_backend = AsyncMock() - mock_backend.is_loaded = True - mock_backend.device = "cpu" - - # Create audio tensor and ensure it's properly mocked - audio_data = torch.zeros(48000, dtype=torch.float32) - async def mock_generate(*args, **kwargs): - return audio_data - mock_backend.generate.side_effect = mock_generate - - # Set up manager with mock backend - manager._backends['onnx_cpu'] = mock_backend - manager._current_backend = 'onnx_cpu' - - # Generate audio - tokens = [1, 2, 3] - voice_tensor = torch.zeros(192) - audio = await manager.generate(tokens, voice_tensor, speed=1.0) - - assert isinstance(audio, torch.Tensor), "Generated audio should be torch tensor" - assert audio.dtype == torch.float32, "Audio should be 32-bit float" - assert audio.shape == (48000,), "Audio should have 48000 samples" - assert mock_backend.generate.call_count == 1 - - -@pytest.mark.asyncio -async def test_voice_manager_initialization(): - """Test voice manager initialization""" - with patch("api.src.inference.voice_manager.settings") as mock_settings, \ - patch("os.path.exists") as mock_exists: - - mock_settings.voices_dir = MOCK_VOICES_DIR - mock_exists.return_value = True - - manager = await get_voice_manager() - assert manager is not None - - -@pytest.mark.asyncio -async def test_voice_manager_list_voices(): - """Test 
listing available voices""" - with patch("api.src.inference.voice_manager.settings") as mock_settings, \ - patch("os.listdir") as mock_listdir, \ - patch("os.makedirs") as mock_makedirs, \ - patch("os.path.exists") as mock_exists: - - mock_settings.voices_dir = MOCK_VOICES_DIR - mock_listdir.return_value = ["af_bella.pt", "af_sarah.pt", "bm_lewis.pt"] - mock_exists.return_value = True - - manager = await get_voice_manager() - voices = await manager.list_voices() - - assert isinstance(voices, list) - assert len(voices) == 3, f"Expected 3 voices but got {len(voices)}" - assert sorted(voices) == ["af_bella", "af_sarah", "bm_lewis"] - mock_listdir.assert_called_once() - - -@pytest.mark.asyncio -async def test_voice_manager_load_voice(): - """Test loading a voice""" - with patch("api.src.inference.voice_manager.settings") as mock_settings, \ - patch("torch.load") as mock_torch_load, \ - patch("os.path.exists") as mock_exists: - - mock_settings.voices_dir = MOCK_VOICES_DIR - mock_exists.return_value = True - - # Create a mock tensor - mock_tensor = torch.zeros(192) - mock_torch_load.return_value = mock_tensor - - manager = await get_voice_manager() - voice_tensor = await manager.load_voice("af_bella", device="cpu") - - assert isinstance(voice_tensor, torch.Tensor) - assert voice_tensor.shape == (192,) - mock_torch_load.assert_called_once() - - -@pytest.mark.asyncio -async def test_voice_manager_combine_voices(): - """Test combining voices""" - with patch("api.src.inference.voice_manager.settings") as mock_settings, \ - patch("torch.load") as mock_load, \ - patch("torch.save") as mock_save, \ - patch("os.makedirs") as mock_makedirs, \ - patch("os.path.exists") as mock_exists: - - mock_settings.voices_dir = MOCK_VOICES_DIR - mock_exists.return_value = True - - # Create mock tensors - mock_tensor1 = torch.ones(192) - mock_tensor2 = torch.ones(192) * 2 - mock_load.side_effect = [mock_tensor1, mock_tensor2] - - manager = await get_voice_manager() - combined_name = await manager.combine_voices(["af_bella", "af_sarah"]) - - assert combined_name == "af_bella_af_sarah" - assert mock_load.call_count == 2 - mock_save.assert_called_once() - - # Verify the combined tensor was saved - saved_tensor = mock_save.call_args[0][0] - assert isinstance(saved_tensor, torch.Tensor) - assert saved_tensor.shape == (192,) - # Should be average of the two tensors - assert torch.allclose(saved_tensor, torch.ones(192) * 1.5) - - -@pytest.mark.asyncio -async def test_voice_manager_invalid_voice(): - """Test loading invalid voice""" - with patch("api.src.inference.voice_manager.settings") as mock_settings, \ - patch("os.path.exists") as mock_exists: - - mock_settings.voices_dir = MOCK_VOICES_DIR - mock_exists.return_value = False - - manager = await get_voice_manager() - with pytest.raises(RuntimeError, match="Voice not found"): - await manager.load_voice("invalid_voice", device="cpu") - - -@pytest.mark.asyncio -async def test_voice_manager_combine_invalid_voices(): - """Test combining with invalid voices""" - with patch("api.src.inference.voice_manager.settings") as mock_settings, \ - patch("os.path.exists") as mock_exists: - - mock_settings.voices_dir = MOCK_VOICES_DIR - mock_exists.return_value = False - - manager = await get_voice_manager() - with pytest.raises(RuntimeError, match="Voice not found"): - await manager.combine_voices(["invalid_voice1", "invalid_voice2"]) \ No newline at end of file diff --git a/api/depr_tests/test_text_processing.py b/api/depr_tests/test_text_processing.py deleted file mode 100644 index 
7e63491..0000000 --- a/api/depr_tests/test_text_processing.py +++ /dev/null @@ -1,139 +0,0 @@ -"""Tests for text processing endpoints""" -import os -import pytest -import torch -from fastapi.testclient import TestClient - -from ..src.main import app - -# Get project root path -PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -MOCK_VOICES_DIR = os.path.join(PROJECT_ROOT, "api", "src", "voices") - -client = TestClient(app) - - -@pytest.mark.asyncio -async def test_generate_from_phonemes(async_client, mock_tts_service): - """Test generating audio from phonemes""" - # Setup mocks - mock_tts_service._voice_manager.list_voices.return_value = ["af_bella"] - mock_tts_service.generate_audio.return_value = (torch.zeros(48000).numpy(), 1.0) - mock_tts_service._voice_manager.load_voice.return_value = torch.zeros(192) - mock_tts_service._voice_manager.get_voice_path.return_value = "/mock/voices/af_bella.pt" - - test_request = { - "model": "kokoro", - "input": "h @ l oU w r= l d", - "voice": "af_bella", - "response_format": "wav", - "speed": 1.0, - "stream": False, - } - - response = await async_client.post("/v1/audio/speech", json=test_request) - assert response.status_code == 200 - assert response.headers["content-type"] == "audio/wav" - assert response.headers["content-disposition"] == "attachment; filename=speech.wav" - mock_tts_service.generate_audio.assert_called_once() - - -@pytest.mark.asyncio -async def test_generate_from_phonemes_invalid_voice(async_client, mock_tts_service): - """Test generating audio from phonemes with invalid voice""" - # Setup mocks - mock_tts_service._voice_manager.list_voices.return_value = ["af_bella"] - mock_tts_service._voice_manager.get_voice_path.return_value = None - - test_request = { - "model": "kokoro", - "input": "h @ l oU w r= l d", - "voice": "invalid_voice", - "response_format": "wav", - "speed": 1.0, - "stream": False, - } - - response = await async_client.post("/v1/audio/speech", json=test_request) - assert response.status_code == 400 - assert "Voice not found" in response.json()["detail"]["message"] - - -@pytest.mark.asyncio -async def test_generate_from_phonemes_chunked(async_client, mock_tts_service): - """Test generating chunked audio from phonemes""" - # Setup mocks - mock_tts_service._voice_manager.list_voices.return_value = ["af_bella"] - mock_tts_service._voice_manager.load_voice.return_value = torch.zeros(192) - mock_tts_service._voice_manager.get_voice_path.return_value = "/mock/voices/af_bella.pt" - - async def mock_stream(): - yield b"chunk1" - yield b"chunk2" - - mock_tts_service.generate_audio_stream.return_value = mock_stream() - - test_request = { - "model": "kokoro", - "input": "h @ l oU w r= l d", - "voice": "af_bella", - "response_format": "mp3", - "stream": True, - } - - headers = {"x-raw-response": "stream"} - response = await async_client.post( - "/v1/audio/speech", json=test_request, headers=headers - ) - - assert response.status_code == 200 - assert response.headers["content-type"] == "audio/mpeg" - assert response.headers["content-disposition"] == "attachment; filename=speech.mp3" - - -@pytest.mark.asyncio -async def test_invalid_phonemes(async_client, mock_tts_service): - """Test handling invalid phonemes""" - # Setup mocks - mock_tts_service._voice_manager.list_voices.return_value = ["af_bella"] - mock_tts_service._voice_manager.load_voice.return_value = torch.zeros(192) - mock_tts_service._voice_manager.get_voice_path.return_value = "/mock/voices/af_bella.pt" - - test_request = { - 
"model": "kokoro", - "input": "", # Empty input - "voice": "af_bella", - "response_format": "wav", - "speed": 1.0, - "stream": False, - } - - response = await async_client.post("/v1/audio/speech", json=test_request) - assert response.status_code == 400 - assert "Text is empty" in response.json()["detail"]["message"] - - -@pytest.mark.asyncio -async def test_phonemes_with_combined_voice(async_client, mock_tts_service): - """Test generating audio from phonemes with combined voice""" - # Setup mocks - mock_tts_service._voice_manager.list_voices.return_value = ["af_bella", "af_sarah"] - mock_tts_service._voice_manager.combine_voices.return_value = "af_bella_af_sarah" - mock_tts_service._voice_manager.load_voice.return_value = torch.zeros(192) - mock_tts_service._voice_manager.get_voice_path.return_value = "/mock/voices/af_bella_af_sarah.pt" - mock_tts_service.generate_audio.return_value = (torch.zeros(48000).numpy(), 1.0) - - test_request = { - "model": "kokoro", - "input": "h @ l oU w r= l d", - "voice": "af_bella+af_sarah", - "response_format": "wav", - "speed": 1.0, - "stream": False, - } - - response = await async_client.post("/v1/audio/speech", json=test_request) - assert response.status_code == 200 - assert response.headers["content-type"] == "audio/wav" - mock_tts_service._voice_manager.combine_voices.assert_called_once() - mock_tts_service.generate_audio.assert_called_once() diff --git a/api/depr_tests/test_tts_service.py b/api/depr_tests/test_tts_service.py deleted file mode 100644 index ac33c5f..0000000 --- a/api/depr_tests/test_tts_service.py +++ /dev/null @@ -1,118 +0,0 @@ -"""Tests for TTSService""" -import os -import numpy as np -import pytest -import torch -from unittest.mock import AsyncMock, MagicMock, Mock, patch - -from api.src.services.tts_service import TTSService - -# Get project root path -PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -MOCK_VOICES_DIR = os.path.join(PROJECT_ROOT, "api", "src", "voices") -MOCK_MODEL_DIR = os.path.join(PROJECT_ROOT, "api", "src", "models") - - -@pytest.mark.asyncio -async def test_service_initialization(mock_model_manager, mock_voice_manager): - """Test TTSService initialization""" - # Create service using factory method - with patch("api.src.services.tts_service.get_model_manager", return_value=mock_model_manager), \ - patch("api.src.services.tts_service.get_voice_manager", return_value=mock_voice_manager): - service = await TTSService.create() - assert service is not None - assert service.model_manager == mock_model_manager - assert service._voice_manager == mock_voice_manager - - -@pytest.mark.asyncio -async def test_generate_audio_basic(mock_tts_service): - """Test basic audio generation""" - text = "Hello world" - voice = "af" - audio, duration = await mock_tts_service.generate_audio(text, voice) - assert isinstance(audio, np.ndarray) - assert duration > 0 - - -@pytest.mark.asyncio -async def test_generate_audio_empty_text(mock_tts_service): - """Test handling empty text input""" - with pytest.raises(ValueError, match="Text is empty after preprocessing"): - await mock_tts_service.generate_audio("", "af") - - -@pytest.mark.asyncio -async def test_generate_audio_stream(mock_tts_service): - """Test streaming audio generation""" - text = "Hello world" - voice = "af" - - # Setup mock stream - async def mock_stream(): - yield b"chunk1" - yield b"chunk2" - mock_tts_service.generate_audio_stream.return_value = mock_stream() - - # Test streaming - stream = 
mock_tts_service.generate_audio_stream(text, voice)
-    chunks = []
-    async for chunk in await stream:
-        chunks.append(chunk)
-
-    assert len(chunks) > 0
-    assert all(isinstance(chunk, bytes) for chunk in chunks)
-
-
-@pytest.mark.asyncio
-async def test_list_voices(mock_tts_service):
-    """Test listing available voices"""
-    with patch("api.src.inference.voice_manager.settings") as mock_settings:
-        mock_settings.voices_dir = MOCK_VOICES_DIR
-        voices = await mock_tts_service.list_voices()
-        assert isinstance(voices, list)
-        assert len(voices) == 4  # ["af", "af_bella", "af_sarah", "bm_lewis"]
-        assert all(isinstance(voice, str) for voice in voices)
-
-
-@pytest.mark.asyncio
-async def test_combine_voices(mock_tts_service):
-    """Test combining voices"""
-    with patch("api.src.inference.voice_manager.settings") as mock_settings:
-        mock_settings.voices_dir = MOCK_VOICES_DIR
-        voices = ["af_bella", "af_sarah"]
-        result = await mock_tts_service.combine_voices(voices)
-        assert isinstance(result, str)
-        assert result == "af_bella_af_sarah"
-
-
-@pytest.mark.asyncio
-async def test_audio_to_bytes(mock_tts_service):
-    """Test converting audio to bytes"""
-    audio = np.zeros(48000, dtype=np.float32)
-    audio_bytes = mock_tts_service._audio_to_bytes(audio)
-    assert isinstance(audio_bytes, bytes)
-    assert len(audio_bytes) > 0
-
-
-@pytest.mark.asyncio
-async def test_voice_loading(mock_tts_service):
-    """Test voice loading"""
-    with patch("api.src.inference.voice_manager.settings") as mock_settings, \
-         patch("os.path.exists", return_value=True), \
-         patch("torch.load", return_value=torch.zeros(192)):
-        mock_settings.voices_dir = MOCK_VOICES_DIR
-        voice = await mock_tts_service._voice_manager.load_voice("af", device="cpu")
-        assert isinstance(voice, torch.Tensor)
-        assert voice.shape == (192,)
-
-
-@pytest.mark.asyncio
-async def test_model_generation(mock_tts_service):
-    """Test model generation"""
-    tokens = [1, 2, 3]
-    voice_tensor = torch.zeros(192)
-    audio = await mock_tts_service.model_manager.generate(tokens, voice_tensor)
-    assert isinstance(audio, torch.Tensor)
-    assert audio.shape == (48000,)
-    assert audio.dtype == torch.float32
diff --git a/api/src/core/config.py b/api/src/core/config.py
index 4c8fbbd..8588fc7 100644
--- a/api/src/core/config.py
+++ b/api/src/core/config.py
@@ -15,7 +15,7 @@ class Settings(BaseSettings):
     default_voice: str = "af"
     use_gpu: bool = False  # Whether to use GPU acceleration if available
     use_onnx: bool = True  # Whether to use ONNX runtime
-    allow_local_voice_saving: bool = True  # Whether to allow saving combined voices locally
+    allow_local_voice_saving: bool = False  # Whether to allow saving combined voices locally
 
     # Container absolute paths
     model_dir: str = "/app/api/src/models"  # Absolute path in container
diff --git a/api/src/inference/voice_manager.py b/api/src/inference/voice_manager.py
index 55644fd..de9765c 100644
--- a/api/src/inference/voice_manager.py
+++ b/api/src/inference/voice_manager.py
@@ -49,11 +49,29 @@ class VoiceManager:
         Raises:
             RuntimeError: If voice loading fails
         """
+        # Check if it's a combined voice request
+        if "+" in voice_name:
+            voices = [v.strip() for v in voice_name.split("+") if v.strip()]
+            if len(voices) < 2:
+                raise RuntimeError(f"Invalid combined voice name: {voice_name}")
+
+            # Load and combine voices
+            voice_tensors = []
+            for voice in voices:
+                try:
+                    voice_tensor = await self.load_voice(voice, device)
+                    voice_tensors.append(voice_tensor)
+                except Exception as e:
+                    raise RuntimeError(f"Failed to load base voice {voice}: {e}")
+
+            return torch.mean(torch.stack(voice_tensors), dim=0)
+
+        # Handle single voice
         voice_path = self.get_voice_path(voice_name)
         if not voice_path:
             raise RuntimeError(f"Voice not found: {voice_name}")
 
-        # Check cache first
+        # Check cache
         cache_key = f"{voice_path}_{device}"
         if self._config.use_cache and cache_key in self._voice_cache:
             return self._voice_cache[cache_key]
@@ -98,48 +116,39 @@ class VoiceManager:
         if len(voices) < 2:
             raise ValueError("At least 2 voices are required for combination")
 
-        # Load voices
-        voice_tensors: List[torch.Tensor] = []
-        for voice in voices:
-            try:
-                voice_tensor = await self.load_voice(voice, device)
-                voice_tensors.append(voice_tensor)
-            except Exception as e:
-                raise RuntimeError(f"Failed to load voice {voice}: {e}")
+        # Create combined name using + as separator
+        combined_name = "+".join(voices)
 
-        try:
-            # Combine voices
-            combined_name = "_".join(voices)
-            combined_tensor = torch.mean(torch.stack(voice_tensors), dim=0)
-
-            # Get api directory path
-            api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
-            voices_dir = os.path.join(api_dir, settings.voices_dir)
-            os.makedirs(voices_dir, exist_ok=True)
-
-            # Only save to disk if local voice saving is allowed
-            if settings.allow_local_voice_saving:
+        # If saving is enabled, try to save the combination
+        if settings.allow_local_voice_saving:
+            try:
+                # Load and combine voices
+                combined_tensor = await self.load_voice(combined_name, device)
+
+                # Save to disk
+                api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
+                voices_dir = os.path.join(api_dir, settings.voices_dir)
+                os.makedirs(voices_dir, exist_ok=True)
+
                 combined_path = os.path.join(voices_dir, f"{combined_name}.pt")
                 try:
                     torch.save(combined_tensor, combined_path)
-                    # Cache the new combined voice with disk path
+                    # Cache with path-based key
                     self._voice_cache[f"{combined_path}_{device}"] = combined_tensor
                 except Exception as e:
                     raise RuntimeError(f"Failed to save combined voice: {e}")
-            else:
-                # Just cache the combined voice in memory without saving to disk
-                self._voice_cache[f"{combined_name}_{device}"] = combined_tensor
-            return combined_name
+            except Exception as e:
+                logger.warning(f"Failed to save combined voice: {e}")
+                # Continue without saving - will be combined on-the-fly when needed
 
-        except Exception as e:
-            raise RuntimeError(f"Failed to combine voices: {e}")
+        return combined_name
 
     async def list_voices(self) -> List[str]:
         """List available voices.
 
         Returns:
-            List of voice names, including both disk-saved and in-memory combined voices
+            List of voice names
         """
         voices = set()  # Use set to avoid duplicates
         try:
@@ -151,14 +160,6 @@ class VoiceManager:
             for entry in os.listdir(voices_dir):
                 if entry.endswith(".pt"):
                     voices.add(entry[:-3])
-
-            # Add in-memory combined voices from cache
-            for cache_key in self._voice_cache:
-                # Extract voice name from cache key (format: "name_device" or "path_device")
-                voice_name = cache_key.split("_")[0]
-                if "/" in voice_name:  # It's a path
-                    voice_name = os.path.basename(voice_name)[:-3]  # Remove .pt extension
-                voices.add(voice_name)
         except Exception as e:
             logger.error(f"Error listing voices: {e}")
diff --git a/api/src/routers/openai_compatible.py b/api/src/routers/openai_compatible.py
index 9a34297..69573b5 100644
--- a/api/src/routers/openai_compatible.py
+++ b/api/src/routers/openai_compatible.py
@@ -8,6 +8,7 @@
 from ..services.audio import AudioService
 from ..services.tts_service import TTSService
 from ..structures.schemas import OpenAISpeechRequest
+
 router = APIRouter(
     tags=["OpenAI Compatible TTS"],
     responses={404: {"description": "Not found"}},
@@ -17,6 +18,7 @@ router = APIRouter(
 _tts_service = None
 _init_lock = None
 
+
 async def get_tts_service() -> TTSService:
     """Get global TTSService instance"""
     global _tts_service, _init_lock
@@ -50,19 +52,24 @@ async def process_voices(
     if not voices:
         raise ValueError("No voices provided")
 
-    # Check if all voices exist
+    # If single voice, validate and return it
+    if len(voices) == 1:
+        available_voices = await tts_service.list_voices()
+        if voices[0] not in available_voices:
+            raise ValueError(
+                f"Voice '{voices[0]}' not found. Available voices: {', '.join(sorted(available_voices))}"
+            )
+        return voices[0]
+
+    # For multiple voices, validate base voices exist
     available_voices = await tts_service.list_voices()
     for voice in voices:
         if voice not in available_voices:
             raise ValueError(
-                f"Voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}"
+                f"Base voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}"
             )
 
-    # If single voice, return it directly
-    if len(voices) == 1:
-        return voices[0]
-
-    # Otherwise combine voices
+    # Combine voices
     return await tts_service.combine_voices(voices=voices)
diff --git a/api/tests/test_voice_manager.py b/api/tests/test_voice_manager.py
index 2c129e8..0205486 100644
--- a/api/tests/test_voice_manager.py
+++ b/api/tests/test_voice_manager.py
@@ -1,82 +1,149 @@
 import pytest
-from unittest.mock import AsyncMock, patch
+from unittest.mock import AsyncMock, patch, MagicMock
 import torch
 from pathlib import Path
 
-@pytest.mark.asyncio
-async def test_list_available_voices(mock_voice_manager):
-    """Test listing available voices"""
-    voices = await mock_voice_manager.list_voices()
-    assert len(voices) == 2
-    assert "voice1" in voices
-    assert "voice2" in voices
+from ..src.inference.voice_manager import VoiceManager
+from ..src.structures.model_schemas import VoiceConfig
+
+
+@pytest.fixture
+def mock_voice_tensor():
+    return torch.randn(10, 10)  # Dummy tensor
+
+
+@pytest.fixture
+def voice_manager():
+    return VoiceManager(VoiceConfig())
+
 @pytest.mark.asyncio
-async def test_get_voice_path(mock_voice_manager):
-    """Test getting path for a specific voice"""
-    voice_path = mock_voice_manager.get_voice_path("voice1")
-    assert voice_path == "/mock/path/voice.pt"
+async def test_load_voice(voice_manager, mock_voice_tensor):
+    """Test loading a single voice"""
+    with patch("api.src.core.paths.load_voice_tensor", new_callable=AsyncMock) as mock_load:
+        mock_load.return_value = mock_voice_tensor
+        with patch("os.path.exists", return_value=True):
+            voice = await voice_manager.load_voice("af_bella", "cpu")
+            assert torch.equal(voice, mock_voice_tensor)
-    # Test invalid voice
-    mock_voice_manager.get_voice_path.return_value = None
-    assert mock_voice_manager.get_voice_path("invalid_voice") is None
 
 @pytest.mark.asyncio
-async def test_load_voice(mock_voice_manager, mock_voice_tensor):
-    """Test loading a voice tensor"""
-    voice_tensor = await mock_voice_manager.load_voice("voice1")
-    assert torch.equal(voice_tensor, mock_voice_tensor)
-    mock_voice_manager.load_voice.assert_called_once_with("voice1")
-
-@pytest.mark.asyncio
-async def test_load_voice_not_found(mock_voice_manager):
+async def test_load_voice_not_found(voice_manager):
     """Test loading non-existent voice"""
-    mock_voice_manager.get_voice_path.return_value = None
-    mock_voice_manager.load_voice.side_effect = ValueError("Voice not found: invalid_voice")
-
-    with pytest.raises(ValueError, match="Voice not found: invalid_voice"):
-        await mock_voice_manager.load_voice("invalid_voice")
+    with patch("os.path.exists", return_value=False):
+        with pytest.raises(RuntimeError, match="Voice not found: invalid_voice"):
+            await voice_manager.load_voice("invalid_voice", "cpu")
+
 @pytest.mark.asyncio
-async def test_combine_voices(mock_voice_manager):
-    """Test combining two voices"""
-    voices = ["voice1", "voice2"]
-    weights = [0.5, 0.5]
-
-    combined_id = await mock_voice_manager.combine_voices(voices, weights)
-    assert combined_id == "voice1_voice2"
-    mock_voice_manager.combine_voices.assert_called_once_with(voices, weights)
+async def test_combine_voices_with_saving(voice_manager, mock_voice_tensor):
+    """Test combining voices with local saving enabled"""
+    with patch("api.src.core.paths.load_voice_tensor", new_callable=AsyncMock) as mock_load, \
+         patch("torch.save") as mock_save, \
+         patch("os.makedirs"), \
+         patch("os.path.exists", return_value=True):
+
+        # Setup mocks
+ mock_load.return_value = mock_voice_tensor + + # Mock settings + with patch("api.src.core.config.settings") as mock_settings: + mock_settings.allow_local_voice_saving = True + mock_settings.voices_dir = "/mock/voices" + + # Combine voices + combined = await voice_manager.combine_voices(["af_bella", "af_sarah"], "cpu") + assert combined == "af_bella+af_sarah" # Note: using + separator + + # Verify voice was saved + mock_save.assert_called_once() + @pytest.mark.asyncio -async def test_combine_voices_invalid_weights(mock_voice_manager): - """Test combining voices with invalid weights""" - voices = ["voice1", "voice2"] - weights = [0.3, 0.3] # Doesn't sum to 1 - - mock_voice_manager.combine_voices.side_effect = ValueError("Weights must sum to 1") - with pytest.raises(ValueError, match="Weights must sum to 1"): - await mock_voice_manager.combine_voices(voices, weights) +async def test_combine_voices_without_saving(voice_manager, mock_voice_tensor): + """Test combining voices without local saving""" + with patch("api.src.core.paths.load_voice_tensor", new_callable=AsyncMock) as mock_load, \ + patch("torch.save") as mock_save, \ + patch("os.makedirs"), \ + patch("os.path.exists", return_value=True): + + # Setup mocks + mock_load.return_value = mock_voice_tensor + + # Mock settings + with patch("api.src.core.config.settings") as mock_settings: + mock_settings.allow_local_voice_saving = False + mock_settings.voices_dir = "/mock/voices" + + # Combine voices + combined = await voice_manager.combine_voices(["af_bella", "af_sarah"], "cpu") + assert combined == "af_bella+af_sarah" # Note: using + separator + + # Verify voice was not saved + mock_save.assert_not_called() + @pytest.mark.asyncio -async def test_combine_voices_single_voice(mock_voice_manager): +async def test_combine_voices_single_voice(voice_manager): """Test combining with single voice""" - voices = ["voice1"] - weights = [1.0] - - mock_voice_manager.combine_voices.side_effect = ValueError("At least 2 voices are required") with pytest.raises(ValueError, match="At least 2 voices are required"): - await mock_voice_manager.combine_voices(voices, weights) + await voice_manager.combine_voices(["af_bella"], "cpu") + @pytest.mark.asyncio -async def test_cache_management(mock_voice_manager, mock_voice_tensor): +async def test_list_voices(voice_manager): + """Test listing available voices""" + with patch("os.listdir", return_value=["af_bella.pt", "af_sarah.pt", "af_bella+af_sarah.pt"]), \ + patch("os.makedirs"): + voices = await voice_manager.list_voices() + assert len(voices) == 3 + assert "af_bella" in voices + assert "af_sarah" in voices + assert "af_bella+af_sarah" in voices + + +@pytest.mark.asyncio +async def test_load_combined_voice(voice_manager, mock_voice_tensor): + """Test loading a combined voice""" + with patch("api.src.core.paths.load_voice_tensor", new_callable=AsyncMock) as mock_load: + mock_load.return_value = mock_voice_tensor + with patch("os.path.exists", return_value=True): + voice = await voice_manager.load_voice("af_bella+af_sarah", "cpu") + assert torch.equal(voice, mock_voice_tensor) + + +def test_cache_management(voice_manager, mock_voice_tensor): """Test voice cache management""" - # Mock cache info - mock_voice_manager.cache_info = {"size": 1, "max_size": 10} + # Set small cache size + voice_manager._config.cache_size = 2 - # Load voice to test caching - await mock_voice_manager.load_voice("voice1") + # Add items to cache + voice_manager._voice_cache = { + "voice1_cpu": torch.randn(5, 5), + "voice2_cpu": 
torch.randn(5, 5), + } - # Check cache info - cache_info = mock_voice_manager.cache_info - assert cache_info["size"] == 1 - assert cache_info["max_size"] == 10 \ No newline at end of file + # Try adding another item + voice_manager._manage_cache() + + # Check cache size maintained + assert len(voice_manager._voice_cache) <= 2 + + +@pytest.mark.asyncio +async def test_voice_loading_with_cache(voice_manager, mock_voice_tensor): + """Test voice loading with cache enabled""" + with patch("api.src.core.paths.load_voice_tensor", new_callable=AsyncMock) as mock_load, \ + patch("os.path.exists", return_value=True): + + mock_load.return_value = mock_voice_tensor + + # First load should hit disk + voice1 = await voice_manager.load_voice("af_bella", "cpu") + assert mock_load.call_count == 1 + + # Second load should hit cache + voice2 = await voice_manager.load_voice("af_bella", "cpu") + assert mock_load.call_count == 1 # Still 1 + + assert torch.equal(voice1, voice2) \ No newline at end of file diff --git a/web/app.js b/web/app.js index bc44e9a..5e122a6 100644 --- a/web/app.js +++ b/web/app.js @@ -2,21 +2,88 @@ class KokoroPlayer { constructor() { this.elements = { textInput: document.getElementById('text-input'), - voiceSelect: document.getElementById('voice-select'), - streamToggle: document.getElementById('stream-toggle'), + voiceSearch: document.getElementById('voice-search'), + voiceDropdown: document.getElementById('voice-dropdown'), + voiceOptions: document.getElementById('voice-options'), + selectedVoices: document.getElementById('selected-voices'), autoplayToggle: document.getElementById('autoplay-toggle'), + formatSelect: document.getElementById('format-select'), generateBtn: document.getElementById('generate-btn'), - audioPlayer: document.getElementById('audio-player'), + cancelBtn: document.getElementById('cancel-btn'), + playPauseBtn: document.getElementById('play-pause-btn'), + waveContainer: document.getElementById('wave-container'), + timeDisplay: document.getElementById('time-display'), + downloadBtn: document.getElementById('download-btn'), status: document.getElementById('status') }; this.isGenerating = false; + this.availableVoices = []; + this.selectedVoiceSet = new Set(); + this.currentController = null; + this.audioChunks = []; + this.sound = null; + this.wave = null; this.init(); } async init() { await this.loadVoices(); + this.setupWave(); this.setupEventListeners(); + this.setupAudioControls(); + } + + setupWave() { + this.wave = new SiriWave({ + container: this.elements.waveContainer, + width: this.elements.waveContainer.clientWidth, + height: 50, + style: 'ios', + color: '#6366f1', + speed: 0.02, + amplitude: 0.7, + frequency: 4 + }); + } + + formatTime(secs) { + const minutes = Math.floor(secs / 60); + const seconds = Math.floor(secs % 60); + return `${minutes}:${seconds.toString().padStart(2, '0')}`; + } + + updateTimeDisplay() { + if (!this.sound) return; + const seek = this.sound.seek() || 0; + const duration = this.sound.duration() || 0; + this.elements.timeDisplay.textContent = `${this.formatTime(seek)} / ${this.formatTime(duration)}`; + + // Update seek slider + const seekSlider = document.getElementById('seek-slider'); + seekSlider.value = (seek / duration) * 100 || 0; + + if (this.sound.playing()) { + requestAnimationFrame(() => this.updateTimeDisplay()); + } + } + + setupAudioControls() { + const seekSlider = document.getElementById('seek-slider'); + const volumeSlider = document.getElementById('volume-slider'); + + seekSlider.addEventListener('input', (e) => { + 
if (!this.sound) return; + const duration = this.sound.duration(); + const seekTime = (duration * e.target.value) / 100; + this.sound.seek(seekTime); + }); + + volumeSlider.addEventListener('input', (e) => { + if (!this.sound) return; + const volume = e.target.value / 100; + this.sound.volume(volume); + }); } async loadVoices() { @@ -32,27 +99,132 @@ class KokoroPlayer { throw new Error('No voices available'); } - this.elements.voiceSelect.innerHTML = data.voices - .map(voice => ``) - .join(''); + this.availableVoices = data.voices; + this.renderVoiceOptions(this.availableVoices); - // Select first voice by default - if (data.voices.length > 0) { - this.elements.voiceSelect.value = data.voices[0]; + if (this.selectedVoiceSet.size === 0) { + const firstVoice = this.availableVoices.find(voice => voice && voice.trim()); + if (firstVoice) { + this.addSelectedVoice(firstVoice); + } } this.showStatus('Voices loaded successfully', 'success'); } catch (error) { this.showStatus('Failed to load voices: ' + error.message, 'error'); - // Disable generate button if no voices this.elements.generateBtn.disabled = true; } } + renderVoiceOptions(voices) { + this.elements.voiceOptions.innerHTML = voices + .map(voice => ` + + `) + .join(''); + this.updateSelectedVoicesDisplay(); + } + + updateSelectedVoicesDisplay() { + this.elements.selectedVoices.innerHTML = Array.from(this.selectedVoiceSet) + .map(voice => ` + + ${voice} + × + + `) + .join(''); + + if (this.selectedVoiceSet.size > 0) { + this.elements.voiceSearch.placeholder = 'Search voices...'; + } else { + this.elements.voiceSearch.placeholder = 'Search and select voices...'; + } + } + + addSelectedVoice(voice) { + this.selectedVoiceSet.add(voice); + this.updateSelectedVoicesDisplay(); + } + + removeSelectedVoice(voice) { + this.selectedVoiceSet.delete(voice); + this.updateSelectedVoicesDisplay(); + const checkbox = this.elements.voiceOptions.querySelector(`input[value="${voice}"]`); + if (checkbox) checkbox.checked = false; + } + + filterVoices(searchTerm) { + const filtered = this.availableVoices.filter(voice => + voice.toLowerCase().includes(searchTerm.toLowerCase()) + ); + this.renderVoiceOptions(filtered); + } + setupEventListeners() { + window.addEventListener('beforeunload', () => { + if (this.currentController) { + this.currentController.abort(); + } + if (this.sound) { + this.sound.unload(); + } + }); + + this.elements.voiceSearch.addEventListener('input', (e) => { + this.filterVoices(e.target.value); + }); + + this.elements.voiceOptions.addEventListener('change', (e) => { + if (e.target.type === 'checkbox') { + if (e.target.checked) { + this.addSelectedVoice(e.target.value); + } else { + this.removeSelectedVoice(e.target.value); + } + } + }); + + this.elements.selectedVoices.addEventListener('click', (e) => { + if (e.target.classList.contains('remove-voice')) { + const voice = e.target.dataset.voice; + this.removeSelectedVoice(voice); + } + }); + this.elements.generateBtn.addEventListener('click', () => this.generateSpeech()); - this.elements.audioPlayer.addEventListener('ended', () => { - this.elements.generateBtn.disabled = false; + this.elements.cancelBtn.addEventListener('click', () => this.cancelGeneration()); + this.elements.playPauseBtn.addEventListener('click', () => this.togglePlayPause()); + this.elements.downloadBtn.addEventListener('click', () => this.downloadAudio()); + + document.addEventListener('click', (e) => { + if (!this.elements.voiceSearch.contains(e.target) && + !this.elements.voiceDropdown.contains(e.target)) { + 
this.elements.voiceDropdown.style.display = 'none'; + } + }); + + this.elements.voiceSearch.addEventListener('focus', () => { + this.elements.voiceDropdown.style.display = 'block'; + if (!this.elements.voiceSearch.value) { + this.elements.voiceSearch.placeholder = 'Search voices...'; + } + }); + + this.elements.voiceSearch.addEventListener('blur', () => { + if (!this.elements.voiceSearch.value && this.selectedVoiceSet.size === 0) { + this.elements.voiceSearch.placeholder = 'Search and select voices...'; + } + }); + + window.addEventListener('resize', () => { + if (this.wave) { + this.wave.width = this.elements.waveContainer.clientWidth; + } }); } @@ -67,7 +239,8 @@ class KokoroPlayer { setLoading(loading) { this.isGenerating = loading; this.elements.generateBtn.disabled = loading; - this.elements.generateBtn.className = loading ? 'primary loading' : 'primary'; + this.elements.generateBtn.className = loading ? 'loading' : ''; + this.elements.cancelBtn.style.display = loading ? 'block' : 'none'; } validateInput() { @@ -77,8 +250,7 @@ class KokoroPlayer { return false; } - const voice = this.elements.voiceSelect.value; - if (!voice) { + if (this.selectedVoiceSet.size === 0) { this.showStatus('Please select a voice', 'error'); return false; } @@ -86,89 +258,68 @@ class KokoroPlayer { return true; } - async generateSpeech() { - if (this.isGenerating || !this.validateInput()) return; - - const text = this.elements.textInput.value.trim(); - const voice = this.elements.voiceSelect.value; - const stream = this.elements.streamToggle.checked; - - this.setLoading(true); - - try { - if (stream) { - await this.handleStreamingAudio(text, voice); - } else { - await this.handleNonStreamingAudio(text, voice); + cancelGeneration() { + if (this.currentController) { + this.currentController.abort(); + this.currentController = null; + if (this.sound) { + this.sound.unload(); + this.sound = null; } - } catch (error) { - this.showStatus('Error generating speech: ' + error.message, 'error'); - } finally { + this.wave.stop(); + this.showStatus('Generation cancelled', 'info'); this.setLoading(false); } } - async handleStreamingAudio(text, voice) { - this.showStatus('Initializing audio stream...', 'info'); + togglePlayPause() { + if (!this.sound) return; - const response = await fetch('/v1/audio/speech', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ - input: text, - voice: voice, - response_format: 'mp3', - stream: true - }) - }); - - if (!response.ok) { - const error = await response.json(); - throw new Error(error.detail?.message || 'Failed to generate speech'); + if (this.sound.playing()) { + this.sound.pause(); + this.wave.stop(); + this.elements.playPauseBtn.textContent = 'Play'; + } else { + this.sound.play(); + this.wave.start(); + this.elements.playPauseBtn.textContent = 'Pause'; + this.updateTimeDisplay(); } - - const mediaSource = new MediaSource(); - this.elements.audioPlayer.src = URL.createObjectURL(mediaSource); - - return new Promise((resolve, reject) => { - mediaSource.addEventListener('sourceopen', async () => { - try { - const sourceBuffer = mediaSource.addSourceBuffer('audio/mpeg'); - const reader = response.body.getReader(); - let totalChunks = 0; - - while (true) { - const {done, value} = await reader.read(); - if (done) break; - - // Wait for the buffer to be ready - if (sourceBuffer.updating) { - await new Promise(resolve => { - sourceBuffer.addEventListener('updateend', resolve, {once: true}); - }); - } - - sourceBuffer.appendBuffer(value); - 
totalChunks++; - this.showStatus(`Received chunk ${totalChunks}...`, 'info'); - } -mediaSource.endOfStream(); -if (this.elements.autoplayToggle.checked) { - await this.elements.audioPlayer.play(); -} -this.showStatus('Audio stream ready', 'success'); - this.showStatus('Audio stream ready', 'success'); - resolve(); - } catch (error) { - mediaSource.endOfStream(); - this.showStatus('Error during streaming: ' + error.message, 'error'); - reject(error); - } - }); - }); } - async handleNonStreamingAudio(text, voice) { + async generateSpeech() { + if (this.isGenerating || !this.validateInput()) return; + + if (this.sound) { + this.sound.unload(); + this.sound = null; + } + this.wave.stop(); + + this.elements.downloadBtn.style.display = 'none'; + this.audioChunks = []; + + const text = this.elements.textInput.value.trim(); + const voice = Array.from(this.selectedVoiceSet).join('+'); + + this.setLoading(true); + this.currentController = new AbortController(); + + try { + await this.handleAudio(text, voice); + } catch (error) { + if (error.name === 'AbortError') { + this.showStatus('Generation cancelled', 'info'); + } else { + this.showStatus('Error generating speech: ' + error.message, 'error'); + } + } finally { + this.currentController = null; + this.setLoading(false); + } + } + + async handleAudio(text, voice) { this.showStatus('Generating audio...', 'info'); const response = await fetch('/v1/audio/speech', { @@ -178,8 +329,9 @@ this.showStatus('Audio stream ready', 'success'); input: text, voice: voice, response_format: 'mp3', - stream: false - }) + stream: true + }), + signal: this.currentController.signal }); if (!response.ok) { @@ -187,17 +339,97 @@ this.showStatus('Audio stream ready', 'success'); throw new Error(error.detail?.message || 'Failed to generate speech'); } - const blob = await response.blob(); - const url = URL.createObjectURL(blob); - this.elements.audioPlayer.src = url; - if (this.elements.autoplayToggle.checked) { - await this.elements.audioPlayer.play(); + const chunks = []; + const reader = response.body.getReader(); + let totalChunks = 0; + + try { + while (true) { + const {value, done} = await reader.read(); + + if (done) { + this.showStatus('Processing complete', 'success'); + break; + } + + chunks.push(value); + this.audioChunks.push(value.slice(0)); + totalChunks++; + + if (totalChunks % 5 === 0) { + this.showStatus(`Received ${totalChunks} chunks...`, 'info'); + } + } + + const blob = new Blob(chunks, { type: 'audio/mpeg' }); + const url = URL.createObjectURL(blob); + + if (this.sound) { + this.sound.unload(); + } + + this.sound = new Howl({ + src: [url], + format: ['mp3'], + html5: true, + onplay: () => { + this.elements.playPauseBtn.textContent = 'Pause'; + this.wave.start(); + this.updateTimeDisplay(); + }, + onpause: () => { + this.elements.playPauseBtn.textContent = 'Play'; + this.wave.stop(); + }, + onend: () => { + this.elements.playPauseBtn.textContent = 'Play'; + this.wave.stop(); + this.elements.generateBtn.disabled = false; + }, + onload: () => { + URL.revokeObjectURL(url); + this.showStatus('Audio ready', 'success'); + this.enableDownload(); + if (this.elements.autoplayToggle.checked) { + this.sound.play(); + } + }, + onloaderror: () => { + URL.revokeObjectURL(url); + this.showStatus('Error loading audio', 'error'); + } + }); + + } catch (error) { + if (error.name === 'AbortError') { + throw error; + } + console.error('Streaming error:', error); + this.showStatus('Error during streaming', 'error'); + throw error; } - this.showStatus('Audio ready', 
'success'); + } + + enableDownload() { + this.elements.downloadBtn.style.display = 'flex'; + } + + downloadAudio() { + if (this.audioChunks.length === 0) return; + + const format = this.elements.formatSelect.value; + const blob = new Blob(this.audioChunks, { type: `audio/${format}` }); + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = `generated-speech.${format}`; + document.body.appendChild(a); + a.click(); + document.body.removeChild(a); + URL.revokeObjectURL(url); } } -// Initialize the player when the page loads document.addEventListener('DOMContentLoaded', () => { new KokoroPlayer(); }); \ No newline at end of file diff --git a/web/favicon.svg b/web/favicon.svg new file mode 100644 index 0000000..ae7545d --- /dev/null +++ b/web/favicon.svg @@ -0,0 +1,47 @@ + \ No newline at end of file diff --git a/web/index.html b/web/index.html index e22bc52..e9c5616 100644 --- a/web/index.html +++ b/web/index.html @@ -3,13 +3,35 @@
-Kokoro-FastAPI TTS System