diff --git a/.coverage b/.coverage index b5254a6..2133e27 100644 Binary files a/.coverage and b/.coverage differ diff --git a/api/src/routers/openai_compatible.py b/api/src/routers/openai_compatible.py index 983fdc1..4183d39 100644 --- a/api/src/routers/openai_compatible.py +++ b/api/src/routers/openai_compatible.py @@ -1,10 +1,10 @@ from typing import List -from fastapi import APIRouter, Depends, HTTPException, Response from loguru import logger +from fastapi import Depends, Response, APIRouter, HTTPException -from ..services.audio import AudioService from ..services.tts import TTSService +from ..services.audio import AudioService from ..structures.schemas import OpenAISpeechRequest router = APIRouter( @@ -32,7 +32,7 @@ async def create_speech( raise ValueError( f"Voice '{request.voice}' not found. Available voices: {', '.join(sorted(available_voices))}" ) - + # Generate audio directly using TTSService's method audio, _ = tts_service._generate_audio( text=request.input, @@ -55,14 +55,12 @@ async def create_speech( except ValueError as e: logger.error(f"Invalid request: {str(e)}") raise HTTPException( - status_code=400, - detail={"error": "Invalid request", "message": str(e)} + status_code=400, detail={"error": "Invalid request", "message": str(e)} ) except Exception as e: logger.error(f"Error generating speech: {str(e)}") raise HTTPException( - status_code=500, - detail={"error": "Server error", "message": str(e)} + status_code=500, detail={"error": "Server error", "message": str(e)} ) @@ -78,17 +76,19 @@ async def list_voices(tts_service: TTSService = Depends(get_tts_service)): @router.post("/audio/voices/combine") -async def combine_voices(request: List[str], tts_service: TTSService = Depends(get_tts_service)): +async def combine_voices( + request: List[str], tts_service: TTSService = Depends(get_tts_service) +): """Combine multiple voices into a new voice. - + Args: request: List of voice names to combine - + Returns: Dict with combined voice name and list of all available voices - + Raises: - HTTPException: + HTTPException: - 400: Invalid request (wrong number of voices, voice not found) - 500: Server error (file system issues, combination failed) """ @@ -96,24 +96,21 @@ async def combine_voices(request: List[str], tts_service: TTSService = Depends(g combined_voice = tts_service.combine_voices(voices=request) voices = tts_service.list_voices() return {"voices": voices, "voice": combined_voice} - + except ValueError as e: logger.error(f"Invalid voice combination request: {str(e)}") raise HTTPException( - status_code=400, - detail={"error": "Invalid request", "message": str(e)} + status_code=400, detail={"error": "Invalid request", "message": str(e)} ) - + except RuntimeError as e: logger.error(f"Server error during voice combination: {str(e)}") raise HTTPException( - status_code=500, - detail={"error": "Server error", "message": str(e)} + status_code=500, detail={"error": "Server error", "message": str(e)} ) - + except Exception as e: logger.error(f"Unexpected error during voice combination: {str(e)}") raise HTTPException( - status_code=500, - detail={"error": "Unexpected error", "message": str(e)} + status_code=500, detail={"error": "Unexpected error", "message": str(e)} ) diff --git a/api/src/services/tts.py b/api/src/services/tts.py index de76836..c1abd9f 100644 --- a/api/src/services/tts.py +++ b/api/src/services/tts.py @@ -1,17 +1,16 @@ import io import os import re -import threading import time +import threading from typing import List, Tuple, Optional import numpy as np -import scipy.io.wavfile as wavfile -import tiktoken import torch +import tiktoken +import scipy.io.wavfile as wavfile +from kokoro import generate, tokenize, phonemize, normalize_text from loguru import logger - -from kokoro import generate, normalize_text, phonemize, tokenize from models import build_model from ..core.config import settings @@ -23,7 +22,7 @@ class TTSModel: _instance = None _device = None _lock = threading.Lock() - + # Directory for all voices (copied base voices, and any created combined voices) VOICES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "voices") @@ -38,10 +37,10 @@ class TTSModel: model_path = os.path.join(settings.model_dir, settings.model_path) model = build_model(model_path, cls._device) cls._instance = model - + # Ensure voices directory exists os.makedirs(cls.VOICES_DIR, exist_ok=True) - + # Copy base voices to local directory base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir) if os.path.exists(base_voices_dir): @@ -51,25 +50,37 @@ class TTSModel: voice_path = os.path.join(cls.VOICES_DIR, file) if not os.path.exists(voice_path): try: - logger.info(f"Copying base voice {voice_name} to voices directory") + logger.info( + f"Copying base voice {voice_name} to voices directory" + ) base_path = os.path.join(base_voices_dir, file) - voicepack = torch.load(base_path, map_location=cls._device, weights_only=True) + voicepack = torch.load( + base_path, + map_location=cls._device, + weights_only=True, + ) torch.save(voicepack, voice_path) except Exception as e: - logger.error(f"Error copying voice {voice_name}: {str(e)}") - + logger.error( + f"Error copying voice {voice_name}: {str(e)}" + ) + # Warm up with default voice try: dummy_text = "Hello" voice_path = os.path.join(cls.VOICES_DIR, "af.pt") - dummy_voicepack = torch.load(voice_path, map_location=cls._device, weights_only=True) - generate(model, dummy_text, dummy_voicepack, lang='a', speed=1.0) + dummy_voicepack = torch.load( + voice_path, map_location=cls._device, weights_only=True + ) + generate(model, dummy_text, dummy_voicepack, lang="a", speed=1.0) logger.info("Model warm-up complete") except Exception as e: logger.warning(f"Model warm-up failed: {e}") - + # Count voices in directory for validation - voice_count = len([f for f in os.listdir(cls.VOICES_DIR) if f.endswith('.pt')]) + voice_count = len( + [f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")] + ) return cls._instance, voice_count @classmethod @@ -86,11 +97,11 @@ class TTSService: self._ensure_voices() if start_worker: self.start_worker() - + def _ensure_voices(self): """Copy base voices to local voices directory during initialization""" os.makedirs(TTSModel.VOICES_DIR, exist_ok=True) - + base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir) if os.path.exists(base_voices_dir): for file in os.listdir(base_voices_dir): @@ -99,9 +110,15 @@ class TTSService: voice_path = os.path.join(TTSModel.VOICES_DIR, file) if not os.path.exists(voice_path): try: - logger.info(f"Copying base voice {voice_name} to voices directory") + logger.info( + f"Copying base voice {voice_name} to voices directory" + ) base_path = os.path.join(base_voices_dir, file) - voicepack = torch.load(base_path, map_location=TTSModel._device, weights_only=True) + voicepack = torch.load( + base_path, + map_location=TTSModel._device, + weights_only=True, + ) torch.save(voicepack, voice_path) except Exception as e: logger.error(f"Error copying voice {voice_name}: {str(e)}") @@ -112,10 +129,10 @@ class TTSService: def _get_voice_path(self, voice_name: str) -> Optional[str]: """Get the path to a voice file. - + Args: voice_name: Name of the voice to find - + Returns: Path to the voice file if found, None otherwise """ @@ -141,7 +158,9 @@ class TTSService: # Load model and voice model = TTSModel._instance - voicepack = torch.load(voice_path, map_location=TTSModel._device, weights_only=True) + voicepack = torch.load( + voice_path, map_location=TTSModel._device, weights_only=True + ) # Generate audio with or without stitching if stitch_long_output: @@ -152,11 +171,11 @@ class TTSService: for i, chunk in enumerate(chunks): try: # Validate phonemization first - ps = phonemize(chunk, voice[0]) - tokens = tokenize(ps) - logger.debug( - f"Processing chunk {i + 1}/{len(chunks)}: {len(tokens)} tokens" - ) + # ps = phonemize(chunk, voice[0]) + # tokens = tokenize(ps) + # logger.debug( + # f"Processing chunk {i + 1}/{len(chunks)}: {len(tokens)} tokens" + # ) # Only proceed if phonemization succeeded chunk_audio, _ = generate( @@ -205,47 +224,51 @@ class TTSService: def combine_voices(self, voices: List[str]) -> str: """Combine multiple voices into a new voice. - + Args: voices: List of voice names to combine - + Returns: Name of the combined voice - + Raises: ValueError: If less than 2 voices provided or voice loading fails RuntimeError: If voice combination or saving fails """ if len(voices) < 2: raise ValueError("At least 2 voices are required for combination") - + # Load voices t_voices: List[torch.Tensor] = [] v_name: List[str] = [] - + for voice in voices: try: voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice}.pt") - voicepack = torch.load(voice_path, map_location=TTSModel._device, weights_only=True) + voicepack = torch.load( + voice_path, map_location=TTSModel._device, weights_only=True + ) t_voices.append(voicepack) v_name.append(voice) except Exception as e: raise ValueError(f"Failed to load voice {voice}: {str(e)}") - + # Combine voices try: f: str = "_".join(v_name) v = torch.mean(torch.stack(t_voices), dim=0) combined_path = os.path.join(TTSModel.VOICES_DIR, f"{f}.pt") - + # Save combined voice try: torch.save(v, combined_path) except Exception as e: - raise RuntimeError(f"Failed to save combined voice to {combined_path}: {str(e)}") - + raise RuntimeError( + f"Failed to save combined voice to {combined_path}: {str(e)}" + ) + return f - + except Exception as e: if not isinstance(e, (ValueError, RuntimeError)): raise RuntimeError(f"Error combining voices: {str(e)}") diff --git a/api/src/structures/schemas.py b/api/src/structures/schemas.py index 8ef36e4..bc778bb 100644 --- a/api/src/structures/schemas.py +++ b/api/src/structures/schemas.py @@ -17,8 +17,8 @@ class OpenAISpeechRequest(BaseModel): model: Literal["tts-1", "tts-1-hd", "kokoro"] = "kokoro" input: str = Field(..., description="The text to generate audio for") voice: str = Field( - default="af", - description="The voice to use for generation. Can be a base voice or a combined voice name." + default="af", + description="The voice to use for generation. Can be a base voice or a combined voice name.", ) response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field( default="mp3", diff --git a/api/tests/conftest.py b/api/tests/conftest.py index 5972003..c41172f 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -1,16 +1,18 @@ import os -import shutil import sys +import shutil from unittest.mock import Mock, patch import pytest + def cleanup_mock_dirs(): """Clean up any MagicMock directories created during tests""" mock_dir = "MagicMock" if os.path.exists(mock_dir): shutil.rmtree(mock_dir) + @pytest.fixture(autouse=True) def cleanup(): """Automatically clean up before and after each test""" @@ -18,6 +20,7 @@ def cleanup(): yield cleanup_mock_dirs() + # Mock torch and other ML modules before they're imported sys.modules["torch"] = Mock() sys.modules["transformers"] = Mock() diff --git a/api/tests/test_audio_service.py b/api/tests/test_audio_service.py index 0e1d1bc..ac0780e 100644 --- a/api/tests/test_audio_service.py +++ b/api/tests/test_audio_service.py @@ -1,6 +1,8 @@ """Tests for AudioService""" + import numpy as np import pytest + from api.src.services.audio import AudioService diff --git a/api/tests/test_endpoints.py b/api/tests/test_endpoints.py index 7c11008..80fe733 100644 --- a/api/tests/test_endpoints.py +++ b/api/tests/test_endpoints.py @@ -114,9 +114,9 @@ def test_combine_voices_success(mock_tts_service): """Test successful voice combination""" test_voices = ["af_bella", "af_sarah"] mock_tts_service.combine_voices.return_value = "af_bella_af_sarah" - + response = client.post("/v1/audio/voices/combine", json=test_voices) - + assert response.status_code == 200 assert response.json()["voice"] == "af_bella_af_sarah" mock_tts_service.combine_voices.assert_called_once_with(voices=test_voices) @@ -126,9 +126,9 @@ def test_combine_voices_single_voice(mock_tts_service): """Test combining single voice returns default voice""" test_voices = ["af_bella"] mock_tts_service.combine_voices.return_value = "af" - + response = client.post("/v1/audio/voices/combine", json=test_voices) - + assert response.status_code == 200 assert response.json()["voice"] == "af" @@ -137,9 +137,9 @@ def test_combine_voices_empty_list(mock_tts_service): """Test combining empty voice list returns default voice""" test_voices = [] mock_tts_service.combine_voices.return_value = "af" - + response = client.post("/v1/audio/voices/combine", json=test_voices) - + assert response.status_code == 200 assert response.json()["voice"] == "af" @@ -148,8 +148,8 @@ def test_combine_voices_error(mock_tts_service): """Test error handling in voice combination""" test_voices = ["af_bella", "af_sarah"] mock_tts_service.combine_voices.side_effect = Exception("Combination failed") - + response = client.post("/v1/audio/voices/combine", json=test_voices) - + assert response.status_code == 500 assert "Combination failed" in response.json()["detail"]["message"] diff --git a/api/tests/test_main.py b/api/tests/test_main.py index 4eedc64..5b23749 100644 --- a/api/tests/test_main.py +++ b/api/tests/test_main.py @@ -1,7 +1,10 @@ """Tests for FastAPI application""" + +from unittest.mock import MagicMock, patch + import pytest -from unittest.mock import patch, MagicMock from fastapi.testclient import TestClient + from api.src.main import app, lifespan @@ -19,98 +22,100 @@ def test_health_check(test_client): @pytest.mark.asyncio -@patch('api.src.main.TTSModel') -@patch('api.src.main.logger') +@patch("api.src.main.TTSModel") +@patch("api.src.main.logger") async def test_lifespan_successful_warmup(mock_logger, mock_tts_model): """Test successful model warmup in lifespan""" # Mock the model initialization with model info and voicepack count mock_model = MagicMock() # Mock file system for voice counting mock_tts_model.VOICES_DIR = "/mock/voices" - with patch('os.listdir', return_value=['voice1.pt', 'voice2.pt', 'voice3.pt']): + with patch("os.listdir", return_value=["voice1.pt", "voice2.pt", "voice3.pt"]): mock_tts_model.initialize.return_value = (mock_model, 3) # 3 voice files mock_tts_model._device = "cuda" # Set device class variable - + # Create an async generator from the lifespan context manager async_gen = lifespan(MagicMock()) # Start the context manager await async_gen.__aenter__() - + # Verify the expected logging sequence mock_logger.info.assert_any_call("Loading TTS model and voice packs...") mock_logger.info.assert_any_call("Model loaded and warmed up on cuda") mock_logger.info.assert_any_call("3 voice packs loaded successfully") - + # Verify model initialization was called mock_tts_model.initialize.assert_called_once() - + # Clean up await async_gen.__aexit__(None, None, None) @pytest.mark.asyncio -@patch('api.src.main.TTSModel') -@patch('api.src.main.logger') +@patch("api.src.main.TTSModel") +@patch("api.src.main.logger") async def test_lifespan_failed_warmup(mock_logger, mock_tts_model): """Test failed model warmup in lifespan""" # Mock the model initialization to fail mock_tts_model.initialize.side_effect = Exception("Failed to initialize model") - + # Create an async generator from the lifespan context manager async_gen = lifespan(MagicMock()) - + # Verify the exception is raised with pytest.raises(Exception, match="Failed to initialize model"): await async_gen.__aenter__() - + # Verify the expected logging sequence mock_logger.info.assert_called_with("Loading TTS model and voice packs...") - + # Clean up await async_gen.__aexit__(None, None, None) @pytest.mark.asyncio -@patch('api.src.main.TTSModel') +@patch("api.src.main.TTSModel") async def test_lifespan_cuda_warmup(mock_tts_model): """Test model warmup specifically on CUDA""" # Mock the model initialization with CUDA and voicepacks mock_model = MagicMock() # Mock file system for voice counting mock_tts_model.VOICES_DIR = "/mock/voices" - with patch('os.listdir', return_value=['voice1.pt', 'voice2.pt']): + with patch("os.listdir", return_value=["voice1.pt", "voice2.pt"]): mock_tts_model.initialize.return_value = (mock_model, 2) # 2 voice files mock_tts_model._device = "cuda" # Set device class variable - + # Create an async generator from the lifespan context manager async_gen = lifespan(MagicMock()) await async_gen.__aenter__() - + # Verify model was initialized mock_tts_model.initialize.assert_called_once() - + # Clean up await async_gen.__aexit__(None, None, None) @pytest.mark.asyncio -@patch('api.src.main.TTSModel') +@patch("api.src.main.TTSModel") async def test_lifespan_cpu_fallback(mock_tts_model): """Test model warmup falling back to CPU""" # Mock the model initialization with CPU and voicepacks mock_model = MagicMock() # Mock file system for voice counting mock_tts_model.VOICES_DIR = "/mock/voices" - with patch('os.listdir', return_value=['voice1.pt', 'voice2.pt', 'voice3.pt', 'voice4.pt']): + with patch( + "os.listdir", return_value=["voice1.pt", "voice2.pt", "voice3.pt", "voice4.pt"] + ): mock_tts_model.initialize.return_value = (mock_model, 4) # 4 voice files mock_tts_model._device = "cpu" # Set device class variable - + # Create an async generator from the lifespan context manager async_gen = lifespan(MagicMock()) await async_gen.__aenter__() - + # Verify model was initialized mock_tts_model.initialize.assert_called_once() - + # Clean up await async_gen.__aexit__(None, None, None) diff --git a/api/tests/test_tts_service.py b/api/tests/test_tts_service.py index a0273ad..8616c5f 100644 --- a/api/tests/test_tts_service.py +++ b/api/tests/test_tts_service.py @@ -1,9 +1,12 @@ """Tests for TTSService""" + import os +from unittest.mock import MagicMock, call, patch + import numpy as np import pytest -from unittest.mock import patch, MagicMock, call -from api.src.services.tts import TTSService, TTSModel + +from api.src.services.tts import TTSModel, TTSService @pytest.fixture @@ -50,42 +53,59 @@ def test_audio_to_bytes(tts_service, sample_audio): assert len(audio_bytes) > 0 -@patch('os.listdir') -@patch('os.path.join') +@patch("os.listdir") +@patch("os.path.join") def test_list_voices(mock_join, mock_listdir, tts_service): """Test listing available voices""" - mock_listdir.return_value = ['voice1.pt', 'voice2.pt', 'not_a_voice.txt'] - mock_join.return_value = '/fake/path' - + mock_listdir.return_value = ["voice1.pt", "voice2.pt", "not_a_voice.txt"] + mock_join.return_value = "/fake/path" + voices = tts_service.list_voices() assert len(voices) == 2 - assert 'voice1' in voices - assert 'voice2' in voices - assert 'not_a_voice' not in voices + assert "voice1" in voices + assert "voice2" in voices + assert "not_a_voice" not in voices -@patch('api.src.services.tts.TTSModel.get_instance') -@patch('api.src.services.tts.TTSModel.get_voicepack') -@patch('api.src.services.tts.normalize_text') -@patch('api.src.services.tts.phonemize') -@patch('api.src.services.tts.tokenize') -@patch('api.src.services.tts.generate') -def test_generate_audio_empty_text(mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_voicepack, mock_instance, tts_service): +@patch("api.src.services.tts.TTSModel.get_instance") +@patch("api.src.services.tts.TTSModel.get_voicepack") +@patch("api.src.services.tts.normalize_text") +@patch("api.src.services.tts.phonemize") +@patch("api.src.services.tts.tokenize") +@patch("api.src.services.tts.generate") +def test_generate_audio_empty_text( + mock_generate, + mock_tokenize, + mock_phonemize, + mock_normalize, + mock_voicepack, + mock_instance, + tts_service, +): """Test generating audio with empty text""" mock_normalize.return_value = "" - + with pytest.raises(ValueError, match="Text is empty after preprocessing"): tts_service._generate_audio("", "af", 1.0) -@patch('api.src.services.tts.TTSModel.get_instance') -@patch('os.path.exists') -@patch('api.src.services.tts.normalize_text') -@patch('api.src.services.tts.phonemize') -@patch('api.src.services.tts.tokenize') -@patch('api.src.services.tts.generate') -@patch('torch.load') -def test_generate_audio_no_chunks(mock_torch_load, mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_exists, mock_instance, tts_service): +@patch("api.src.services.tts.TTSModel.get_instance") +@patch("os.path.exists") +@patch("api.src.services.tts.normalize_text") +@patch("api.src.services.tts.phonemize") +@patch("api.src.services.tts.tokenize") +@patch("api.src.services.tts.generate") +@patch("torch.load") +def test_generate_audio_no_chunks( + mock_torch_load, + mock_generate, + mock_tokenize, + mock_phonemize, + mock_normalize, + mock_exists, + mock_instance, + tts_service, +): """Test generating audio with no successful chunks""" mock_normalize.return_value = "Test text" mock_phonemize.return_value = "Test text" @@ -94,19 +114,29 @@ def test_generate_audio_no_chunks(mock_torch_load, mock_generate, mock_tokenize, mock_instance.return_value = (MagicMock(), "cpu") mock_exists.return_value = True mock_torch_load.return_value = MagicMock() - + with pytest.raises(ValueError, match="No audio chunks were generated successfully"): tts_service._generate_audio("Test text", "af", 1.0) -@patch('api.src.services.tts.TTSModel.get_instance') -@patch('os.path.exists') -@patch('api.src.services.tts.normalize_text') -@patch('api.src.services.tts.phonemize') -@patch('api.src.services.tts.tokenize') -@patch('api.src.services.tts.generate') -@patch('torch.load') -def test_generate_audio_success(mock_torch_load, mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_exists, mock_instance, tts_service, sample_audio): +@patch("api.src.services.tts.TTSModel.get_instance") +@patch("os.path.exists") +@patch("api.src.services.tts.normalize_text") +@patch("api.src.services.tts.phonemize") +@patch("api.src.services.tts.tokenize") +@patch("api.src.services.tts.generate") +@patch("torch.load") +def test_generate_audio_success( + mock_torch_load, + mock_generate, + mock_tokenize, + mock_phonemize, + mock_normalize, + mock_exists, + mock_instance, + tts_service, + sample_audio, +): """Test successful audio generation""" mock_normalize.return_value = "Test text" mock_phonemize.return_value = "Test text" @@ -115,15 +145,15 @@ def test_generate_audio_success(mock_torch_load, mock_generate, mock_tokenize, m mock_instance.return_value = (MagicMock(), "cpu") mock_exists.return_value = True mock_torch_load.return_value = MagicMock() - + audio, processing_time = tts_service._generate_audio("Test text", "af", 1.0) assert isinstance(audio, np.ndarray) assert isinstance(processing_time, float) assert len(audio) > 0 -@patch('api.src.services.tts.torch.cuda.is_available') -@patch('api.src.services.tts.build_model') +@patch("api.src.services.tts.torch.cuda.is_available") +@patch("api.src.services.tts.build_model") def test_model_initialization_cuda(mock_build_model, mock_cuda_available): """Test model initialization with CUDA""" mock_cuda_available.return_value = True @@ -132,14 +162,14 @@ def test_model_initialization_cuda(mock_build_model, mock_cuda_available): TTSModel._instance = None # Reset singleton model, voice_count = TTSModel.initialize() - + assert TTSModel._device == "cuda" # Check the class variable instead assert model == mock_model mock_build_model.assert_called_once() -@patch('api.src.services.tts.torch.cuda.is_available') -@patch('api.src.services.tts.build_model') +@patch("api.src.services.tts.torch.cuda.is_available") +@patch("api.src.services.tts.build_model") def test_model_initialization_cpu(mock_build_model, mock_cuda_available): """Test model initialization with CPU""" mock_cuda_available.return_value = False @@ -148,76 +178,95 @@ def test_model_initialization_cpu(mock_build_model, mock_cuda_available): TTSModel._instance = None # Reset singleton model, voice_count = TTSModel.initialize() - + assert TTSModel._device == "cpu" # Check the class variable instead assert model == mock_model mock_build_model.assert_called_once() -@patch('api.src.services.tts.TTSService._get_voice_path') -@patch('api.src.services.tts.TTSModel.get_instance') +@patch("api.src.services.tts.TTSService._get_voice_path") +@patch("api.src.services.tts.TTSModel.get_instance") def test_voicepack_loading_error(mock_get_instance, mock_get_voice_path): """Test voicepack loading error handling""" mock_get_voice_path.return_value = None mock_get_instance.return_value = (MagicMock(), "cpu") - + TTSModel._voicepacks = {} # Reset voicepacks - + service = TTSService(start_worker=False) with pytest.raises(ValueError, match="Voice not found: nonexistent_voice"): service._generate_audio("test", "nonexistent_voice", 1.0) -@patch('api.src.services.tts.TTSModel') +@patch("api.src.services.tts.TTSModel") def test_save_audio(mock_tts_model, tts_service, sample_audio, tmp_path): """Test saving audio to file""" output_dir = os.path.join(tmp_path, "test_output") os.makedirs(output_dir, exist_ok=True) output_path = os.path.join(output_dir, "audio.wav") - + tts_service._save_audio(sample_audio, output_path) - + assert os.path.exists(output_path) assert os.path.getsize(output_path) > 0 -@patch('api.src.services.tts.TTSModel.get_instance') -@patch('os.path.exists') -@patch('api.src.services.tts.normalize_text') -@patch('api.src.services.tts.generate') -@patch('torch.load') -def test_generate_audio_without_stitching(mock_torch_load, mock_generate, mock_normalize, mock_exists, mock_instance, tts_service, sample_audio): +@patch("api.src.services.tts.TTSModel.get_instance") +@patch("os.path.exists") +@patch("api.src.services.tts.normalize_text") +@patch("api.src.services.tts.generate") +@patch("torch.load") +def test_generate_audio_without_stitching( + mock_torch_load, + mock_generate, + mock_normalize, + mock_exists, + mock_instance, + tts_service, + sample_audio, +): """Test generating audio without text stitching""" mock_normalize.return_value = "Test text" mock_generate.return_value = (sample_audio, None) mock_instance.return_value = (MagicMock(), "cpu") mock_exists.return_value = True mock_torch_load.return_value = MagicMock() - - audio, processing_time = tts_service._generate_audio("Test text", "af", 1.0, stitch_long_output=False) + + audio, processing_time = tts_service._generate_audio( + "Test text", "af", 1.0, stitch_long_output=False + ) assert isinstance(audio, np.ndarray) assert isinstance(processing_time, float) assert len(audio) > 0 mock_generate.assert_called_once() -@patch('os.listdir') +@patch("os.listdir") def test_list_voices_error(mock_listdir, tts_service): """Test error handling in list_voices""" mock_listdir.side_effect = Exception("Failed to list directory") - + voices = tts_service.list_voices() assert voices == [] -@patch('api.src.services.tts.TTSModel.get_instance') -@patch('os.path.exists') -@patch('api.src.services.tts.normalize_text') -@patch('api.src.services.tts.phonemize') -@patch('api.src.services.tts.tokenize') -@patch('api.src.services.tts.generate') -@patch('torch.load') -def test_generate_audio_phonemize_error(mock_torch_load, mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_exists, mock_instance, tts_service): +@patch("api.src.services.tts.TTSModel.get_instance") +@patch("os.path.exists") +@patch("api.src.services.tts.normalize_text") +@patch("api.src.services.tts.phonemize") +@patch("api.src.services.tts.tokenize") +@patch("api.src.services.tts.generate") +@patch("torch.load") +def test_generate_audio_phonemize_error( + mock_torch_load, + mock_generate, + mock_tokenize, + mock_phonemize, + mock_normalize, + mock_exists, + mock_instance, + tts_service, +): """Test handling phonemization error""" mock_normalize.return_value = "Test text" mock_phonemize.side_effect = Exception("Phonemization failed") @@ -225,23 +274,30 @@ def test_generate_audio_phonemize_error(mock_torch_load, mock_generate, mock_tok mock_exists.return_value = True mock_torch_load.return_value = MagicMock() mock_generate.return_value = (None, None) - + with pytest.raises(ValueError, match="No audio chunks were generated successfully"): tts_service._generate_audio("Test text", "af", 1.0) -@patch('api.src.services.tts.TTSModel.get_instance') -@patch('os.path.exists') -@patch('api.src.services.tts.normalize_text') -@patch('api.src.services.tts.generate') -@patch('torch.load') -def test_generate_audio_error(mock_torch_load, mock_generate, mock_normalize, mock_exists, mock_instance, tts_service): +@patch("api.src.services.tts.TTSModel.get_instance") +@patch("os.path.exists") +@patch("api.src.services.tts.normalize_text") +@patch("api.src.services.tts.generate") +@patch("torch.load") +def test_generate_audio_error( + mock_torch_load, + mock_generate, + mock_normalize, + mock_exists, + mock_instance, + tts_service, +): """Test handling generation error""" mock_normalize.return_value = "Test text" mock_generate.side_effect = Exception("Generation failed") mock_instance.return_value = (MagicMock(), "cpu") mock_exists.return_value = True mock_torch_load.return_value = MagicMock() - + with pytest.raises(ValueError, match="No audio chunks were generated successfully"): tts_service._generate_audio("Test text", "af", 1.0) diff --git a/examples/test_all_voices.py b/examples/test_all_voices.py index 5f9cf47..a143e83 100644 --- a/examples/test_all_voices.py +++ b/examples/test_all_voices.py @@ -19,7 +19,6 @@ output_dir = Path(__file__).parent / "output" output_dir.mkdir(exist_ok=True) - def test_voice(voice: str): speech_file = output_dir / f"speech_{voice}.mp3" print(f"\nTesting voice: {voice}") diff --git a/examples/test_analyze_combined_voices.py b/examples/test_analyze_combined_voices.py index f48be90..8db7865 100644 --- a/examples/test_analyze_combined_voices.py +++ b/examples/test_analyze_combined_voices.py @@ -1,21 +1,23 @@ #!/usr/bin/env python3 -import argparse import os -from typing import List, Optional, Dict, Tuple +import argparse +from typing import Dict, List, Tuple, Optional -import requests import numpy as np -from scipy.io import wavfile +import requests import matplotlib.pyplot as plt +from scipy.io import wavfile -def submit_combine_voices(voices: List[str], base_url: str = "http://localhost:8880") -> Optional[str]: +def submit_combine_voices( + voices: List[str], base_url: str = "http://localhost:8880" +) -> Optional[str]: """Combine multiple voices into a new voice. - + Args: voices: List of voice names to combine (e.g. ["af_bella", "af_sarah"]) base_url: API base URL - + Returns: Name of the combined voice (e.g. "af_bella_af_sarah") or None if error """ @@ -23,7 +25,7 @@ def submit_combine_voices(voices: List[str], base_url: str = "http://localhost:8 response = requests.post(f"{base_url}/v1/audio/voices/combine", json=voices) print(f"Response status: {response.status_code}") print(f"Raw response: {response.text}") - + # Accept both 200 and 201 as success if response.status_code not in [200, 201]: try: @@ -32,7 +34,7 @@ def submit_combine_voices(voices: List[str], base_url: str = "http://localhost:8 except: print(f"Error combining voices: {response.text}") return None - + try: data = response.json() if "voices" in data: @@ -46,15 +48,20 @@ def submit_combine_voices(voices: List[str], base_url: str = "http://localhost:8 return None -def generate_speech(text: str, voice: str, base_url: str = "http://localhost:8880", output_file: str = "output.mp3") -> bool: +def generate_speech( + text: str, + voice: str, + base_url: str = "http://localhost:8880", + output_file: str = "output.mp3", +) -> bool: """Generate speech using specified voice. - + Args: text: Text to convert to speech voice: Voice name to use base_url: API base URL output_file: Path to save audio file - + Returns: True if successful, False otherwise """ @@ -65,22 +72,25 @@ def generate_speech(text: str, voice: str, base_url: str = "http://localhost:888 "input": text, "voice": voice, "speed": 1.0, - "response_format": "wav" # Use WAV for analysis - } + "response_format": "wav", # Use WAV for analysis + }, ) - + if response.status_code != 200: error = response.json().get("detail", {}).get("message", response.text) print(f"Error generating speech: {error}") return False - + # Save the audio - os.makedirs(os.path.dirname(output_file) if os.path.dirname(output_file) else ".", exist_ok=True) + os.makedirs( + os.path.dirname(output_file) if os.path.dirname(output_file) else ".", + exist_ok=True, + ) with open(output_file, "wb") as f: f.write(response.content) print(f"Saved audio to {output_file}") return True - + except Exception as e: print(f"Error: {e}") return False @@ -88,57 +98,57 @@ def generate_speech(text: str, voice: str, base_url: str = "http://localhost:888 def analyze_audio(filepath: str) -> Tuple[np.ndarray, int, dict]: """Analyze audio file and return samples, sample rate, and audio characteristics. - + Args: filepath: Path to audio file - + Returns: Tuple of (samples, sample_rate, characteristics) """ sample_rate, samples = wavfile.read(filepath) - + # Convert to mono if stereo if len(samples.shape) > 1: samples = np.mean(samples, axis=1) - + # Calculate basic stats max_amp = np.max(np.abs(samples)) rms = np.sqrt(np.mean(samples**2)) duration = len(samples) / sample_rate - + # Zero crossing rate (helps identify voice characteristics) zero_crossings = np.sum(np.abs(np.diff(np.signbit(samples)))) / len(samples) - + # Simple frequency analysis if len(samples) > 0: # Use FFT to get frequency components fft_result = np.fft.fft(samples) - freqs = np.fft.fftfreq(len(samples), 1/sample_rate) - + freqs = np.fft.fftfreq(len(samples), 1 / sample_rate) + # Get positive frequencies only pos_mask = freqs > 0 freqs = freqs[pos_mask] magnitudes = np.abs(fft_result)[pos_mask] - + # Find dominant frequencies (top 3) top_indices = np.argsort(magnitudes)[-3:] dominant_freqs = freqs[top_indices] - + # Calculate spectral centroid (brightness of sound) spectral_centroid = np.sum(freqs * magnitudes) / np.sum(magnitudes) else: dominant_freqs = [] spectral_centroid = 0 - + characteristics = { "max_amplitude": max_amp, "rms": rms, "duration": duration, "zero_crossing_rate": zero_crossings, "dominant_frequencies": dominant_freqs, - "spectral_centroid": spectral_centroid + "spectral_centroid": spectral_centroid, } - + return samples, sample_rate, characteristics @@ -167,112 +177,136 @@ def setup_plot(fig, ax, title): return fig, ax + def plot_analysis(audio_files: Dict[str, str], output_dir: str): """Plot comprehensive voice analysis including waveforms and metrics comparison. - + Args: audio_files: Dictionary of label -> filepath output_dir: Directory to save plot files """ # Set dark style - plt.style.use('dark_background') - + plt.style.use("dark_background") + # Create figure with subplots fig = plt.figure(figsize=(15, 15)) fig.patch.set_facecolor("#1a1a2e") num_files = len(audio_files) - + # Create subplot grid with proper spacing - gs = plt.GridSpec(num_files + 1, 2, height_ratios=[1.5]*num_files + [1], - hspace=0.4, wspace=0.3) - + gs = plt.GridSpec( + num_files + 1, 2, height_ratios=[1.5] * num_files + [1], hspace=0.4, wspace=0.3 + ) + # Analyze all files first all_chars = {} for i, (label, filepath) in enumerate(audio_files.items()): samples, sample_rate, chars = analyze_audio(filepath) all_chars[label] = chars - + # Plot waveform spanning both columns ax = plt.subplot(gs[i, :]) time = np.arange(len(samples)) / sample_rate - plt.plot(time, samples / chars['max_amplitude'], linewidth=0.5, color="#ff2a6d") + plt.plot(time, samples / chars["max_amplitude"], linewidth=0.5, color="#ff2a6d") ax.set_xlabel("Time (seconds)") ax.set_ylabel("Normalized Amplitude") ax.set_ylim(-1.1, 1.1) setup_plot(fig, ax, f"Waveform: {label}") - + # Colors for voices colors = ["#ff2a6d", "#05d9e8", "#d1f7ff"] - + # Create two subplots for metrics with similar scales # Left subplot: Brightness and Volume ax1 = plt.subplot(gs[num_files, 0]) metrics1 = [ - ('Brightness', [chars['spectral_centroid']/1000 for chars in all_chars.values()], 'kHz'), - ('Volume', [chars['rms']*100 for chars in all_chars.values()], 'RMS×100') + ( + "Brightness", + [chars["spectral_centroid"] / 1000 for chars in all_chars.values()], + "kHz", + ), + ("Volume", [chars["rms"] * 100 for chars in all_chars.values()], "RMS×100"), ] - + # Right subplot: Voice Pitch and Texture ax2 = plt.subplot(gs[num_files, 1]) metrics2 = [ - ('Voice Pitch', [min(chars['dominant_frequencies']) for chars in all_chars.values()], 'Hz'), - ('Texture', [chars['zero_crossing_rate']*1000 for chars in all_chars.values()], 'ZCR×1000') + ( + "Voice Pitch", + [min(chars["dominant_frequencies"]) for chars in all_chars.values()], + "Hz", + ), + ( + "Texture", + [chars["zero_crossing_rate"] * 1000 for chars in all_chars.values()], + "ZCR×1000", + ), ] - + def plot_grouped_bars(ax, metrics, show_legend=True): n_groups = len(metrics) n_voices = len(audio_files) bar_width = 0.25 - + indices = np.arange(n_groups) - + # Get max value for y-axis scaling max_val = max(max(m[1]) for m in metrics) - + for i, (voice, color) in enumerate(zip(audio_files.keys(), colors)): values = [m[1][i] for m in metrics] - offset = (i - n_voices/2 + 0.5) * bar_width - bars = ax.bar(indices + offset, values, bar_width, - label=voice, color=color, alpha=0.8) - + offset = (i - n_voices / 2 + 0.5) * bar_width + bars = ax.bar( + indices + offset, values, bar_width, label=voice, color=color, alpha=0.8 + ) + # Add value labels on top of bars for bar in bars: height = bar.get_height() - ax.text(bar.get_x() + bar.get_width()/2., height, - f'{height:.1f}', - ha='center', va='bottom', color='white', - fontsize=10) - + ax.text( + bar.get_x() + bar.get_width() / 2.0, + height, + f"{height:.1f}", + ha="center", + va="bottom", + color="white", + fontsize=10, + ) + ax.set_xticks(indices) ax.set_xticklabels([f"{m[0]}\n({m[2]})" for m in metrics]) - + # Set y-axis limits with some padding ax.set_ylim(0, max_val * 1.2) - + if show_legend: - ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', - facecolor="#1a1a2e", edgecolor="#ffffff") - + ax.legend( + bbox_to_anchor=(1.05, 1), + loc="upper left", + facecolor="#1a1a2e", + edgecolor="#ffffff", + ) + # Plot both subplots plot_grouped_bars(ax1, metrics1, show_legend=True) plot_grouped_bars(ax2, metrics2, show_legend=False) - + # Style both subplots - setup_plot(fig, ax1, 'Brightness and Volume') - setup_plot(fig, ax2, 'Voice Pitch and Texture') - + setup_plot(fig, ax1, "Brightness and Volume") + setup_plot(fig, ax2, "Voice Pitch and Texture") + # Add y-axis labels - ax1.set_ylabel('Value') - ax2.set_ylabel('Value') - + ax1.set_ylabel("Value") + ax2.set_ylabel("Value") + # Adjust the figure size to accommodate the legend fig.set_size_inches(15, 15) - + # Add padding around the entire figure plt.subplots_adjust(right=0.85, top=0.95, bottom=0.05, left=0.1) plt.savefig(os.path.join(output_dir, "analysis_comparison.png"), dpi=300) print(f"Saved analysis comparison to {output_dir}/analysis_comparison.png") - + # Print detailed comparative analysis print("\nDetailed Voice Analysis:") for label, chars in all_chars.items(): @@ -282,44 +316,57 @@ def plot_analysis(audio_files: Dict[str, str], output_dir: str): print(f" Duration: {chars['duration']:.2f}s") print(f" Zero Crossing Rate: {chars['zero_crossing_rate']:.3f}") print(f" Spectral Centroid: {chars['spectral_centroid']:.0f}Hz") - print(f" Dominant Frequencies: {', '.join(f'{f:.0f}Hz' for f in chars['dominant_frequencies'])}") + print( + f" Dominant Frequencies: {', '.join(f'{f:.0f}Hz' for f in chars['dominant_frequencies'])}" + ) def main(): parser = argparse.ArgumentParser(description="Kokoro Voice Analysis Demo") parser.add_argument("--voices", nargs="+", type=str, help="Voices to combine") - parser.add_argument("--text", type=str, default="Hello! This is a test of combined voices.", help="Text to speak") + parser.add_argument( + "--text", + type=str, + default="Hello! This is a test of combined voices.", + help="Text to speak", + ) parser.add_argument("--url", default="http://localhost:8880", help="API base URL") - parser.add_argument("--output-dir", default="examples/output", help="Output directory for audio files") + parser.add_argument( + "--output-dir", + default="examples/output", + help="Output directory for audio files", + ) args = parser.parse_args() if not args.voices: print("No voices provided, using default test voices") args.voices = ["af_bella", "af_nicole"] - + # Create output directory os.makedirs(args.output_dir, exist_ok=True) - + # Dictionary to store audio files for analysis audio_files = {} - + # Generate speech with individual voices print("Generating speech with individual voices...") for voice in args.voices: output_file = os.path.join(args.output_dir, f"analysis_{voice}.wav") if generate_speech(args.text, voice, args.url, output_file): audio_files[voice] = output_file - + # Generate speech with combined voice print(f"\nCombining voices: {', '.join(args.voices)}") combined_voice = submit_combine_voices(args.voices, args.url) - + if combined_voice: print(f"Successfully created combined voice: {combined_voice}") - output_file = os.path.join(args.output_dir, f"analysis_combined_{combined_voice}.wav") + output_file = os.path.join( + args.output_dir, f"analysis_combined_{combined_voice}.wav" + ) if generate_speech(args.text, combined_voice, args.url, output_file): audio_files["combined"] = output_file - + # Generate comparison plots plot_analysis(audio_files, args.output_dir) else: diff --git a/examples/test_openai_tts.py b/examples/test_openai_tts.py index 7cc8104..80e3602 100644 --- a/examples/test_openai_tts.py +++ b/examples/test_openai_tts.py @@ -60,7 +60,7 @@ def test_speed(speed: float): # Test different formats for format in ["wav", "mp3", "opus", "aac", "flac", "pcm"]: - test_format(format) # aac and pcm should fail as they are not supported + test_format(format) # aac and pcm should fail as they are not supported # Test different speeds for speed in [0.25, 1.0, 2.0, 4.0]: # 5.0 should fail as it's out of range diff --git a/requirements-test.txt b/requirements-test.txt index 53135bd..26a7791 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -10,3 +10,5 @@ sqlalchemy==2.0.27 pytest==8.0.0 httpx==0.26.0 pytest-asyncio==0.23.5 +pytest-cov==6.0.0 +gradio==4.19.2 diff --git a/ui/app.py b/ui/app.py index a3d9939..96aae35 100644 --- a/ui/app.py +++ b/ui/app.py @@ -2,8 +2,4 @@ from lib.interface import create_interface if __name__ == "__main__": demo = create_interface() - demo.launch( - server_name="0.0.0.0", - server_port=7860, - show_error=True - ) + demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True) diff --git a/ui/lib/api.py b/ui/lib/api.py index a9c6a19..1528656 100644 --- a/ui/lib/api.py +++ b/ui/lib/api.py @@ -1,16 +1,19 @@ -import requests -from typing import Tuple, List, Optional import os import datetime +from typing import List, Tuple, Optional + +import requests + from .config import API_URL, OUTPUTS_DIR + def check_api_status() -> Tuple[bool, List[str]]: """Check TTS service status and get available voices.""" try: # Use a longer timeout during startup response = requests.get( f"{API_URL}/v1/audio/voices", - timeout=30 # Increased timeout for initial startup period + timeout=30, # Increased timeout for initial startup period ) response.raise_for_status() voices = response.json().get("voices", []) @@ -31,16 +34,19 @@ def check_api_status() -> Tuple[bool, List[str]]: print(f"Unexpected error checking API status: {str(e)}") return False, [] -def text_to_speech(text: str, voice_id: str, format: str, speed: float) -> Optional[str]: + +def text_to_speech( + text: str, voice_id: str, format: str, speed: float +) -> Optional[str]: """Generate speech from text using TTS API.""" if not text.strip(): return None - + # Create output filename timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") output_filename = f"output_{timestamp}_voice-{voice_id}_speed-{speed}.{format}" output_path = os.path.join(OUTPUTS_DIR, output_filename) - + try: response = requests.post( f"{API_URL}/v1/audio/speech", @@ -49,17 +55,17 @@ def text_to_speech(text: str, voice_id: str, format: str, speed: float) -> Optio "input": text, "voice": voice_id, "response_format": format, - "speed": float(speed) + "speed": float(speed), }, headers={"Content-Type": "application/json"}, - timeout=300 # Longer timeout for speech generation + timeout=300, # Longer timeout for speech generation ) response.raise_for_status() - + with open(output_path, "wb") as f: f.write(response.content) return output_path - + except requests.exceptions.Timeout: print("Speech generation request timed out") return None @@ -70,6 +76,7 @@ def text_to_speech(text: str, voice_id: str, format: str, speed: float) -> Optio print(f"Unexpected error generating speech: {str(e)}") return None + def get_status_html(is_available: bool) -> str: """Generate HTML for status indicator.""" color = "green" if is_available else "red" diff --git a/ui/lib/components/input.py b/ui/lib/components/input.py index 2644060..793a89e 100644 --- a/ui/lib/components/input.py +++ b/ui/lib/components/input.py @@ -1,7 +1,10 @@ -import gradio as gr from typing import Tuple + +import gradio as gr + from .. import files + def create_input_column() -> Tuple[gr.Column, dict]: """Create the input column with text input and file handling.""" with gr.Column(scale=1) as col: @@ -11,49 +14,36 @@ def create_input_column() -> Tuple[gr.Column, dict]: # Direct Input Tab with gr.TabItem("Direct Input"): text_input = gr.Textbox( - label="Text to speak", - placeholder="Enter text here...", - lines=4 + label="Text to speak", placeholder="Enter text here...", lines=4 ) - text_submit = gr.Button( - "Generate Speech", - variant="primary", - size="lg" - ) - + text_submit = gr.Button("Generate Speech", variant="primary", size="lg") + # File Input Tab with gr.TabItem("From File"): # Existing files dropdown input_files_list = gr.Dropdown( label="Select Existing File", choices=files.list_input_files(), - value=None + value=None, ) - + # Simple file upload file_upload = gr.File( - label="Upload Text File (.txt)", - file_types=[".txt"] + label="Upload Text File (.txt)", file_types=[".txt"] ) - + file_preview = gr.Textbox( - label="File Content Preview", - interactive=False, - lines=4 + label="File Content Preview", interactive=False, lines=4 ) - + with gr.Row(): file_submit = gr.Button( - "Generate Speech", - variant="primary", - size="lg" + "Generate Speech", variant="primary", size="lg" ) clear_files = gr.Button( - "Clear Files", - variant="secondary", - size="lg" + "Clear Files", variant="secondary", size="lg" ) - + components = { "tabs": tabs, "text_input": text_input, @@ -62,7 +52,7 @@ def create_input_column() -> Tuple[gr.Column, dict]: "file_preview": file_preview, "text_submit": text_submit, "file_submit": file_submit, - "clear_files": clear_files + "clear_files": clear_files, } - + return col, components diff --git a/ui/lib/components/model.py b/ui/lib/components/model.py index 3b7ae96..444d0f8 100644 --- a/ui/lib/components/model.py +++ b/ui/lib/components/model.py @@ -1,45 +1,41 @@ -import gradio as gr from typing import Tuple, Optional + +import gradio as gr + from .. import api, config + def create_model_column(voice_ids: Optional[list] = None) -> Tuple[gr.Column, dict]: """Create the model settings column.""" if voice_ids is None: voice_ids = [] - + with gr.Column(scale=1) as col: gr.Markdown("### Model Settings") - + # Status button starts in waiting state status_btn = gr.Button( - "⌛ TTS Service: Waiting for Service...", - variant="secondary" + "⌛ TTS Service: Waiting for Service...", variant="secondary" ) - + voice_input = gr.Dropdown( choices=voice_ids, label="Voice", value=voice_ids[0] if voice_ids else None, - interactive=True + interactive=True, ) format_input = gr.Dropdown( - choices=config.AUDIO_FORMATS, - label="Audio Format", - value="mp3" + choices=config.AUDIO_FORMATS, label="Audio Format", value="mp3" ) speed_input = gr.Slider( - minimum=0.5, - maximum=2.0, - value=1.0, - step=0.1, - label="Speed" + minimum=0.5, maximum=2.0, value=1.0, step=0.1, label="Speed" ) - + components = { "status_btn": status_btn, "voice": voice_input, "format": format_input, - "speed": speed_input + "speed": speed_input, } - + return col, components diff --git a/ui/lib/components/output.py b/ui/lib/components/output.py index 8ef4640..e25601d 100644 --- a/ui/lib/components/output.py +++ b/ui/lib/components/output.py @@ -1,40 +1,42 @@ -import gradio as gr from typing import Tuple + +import gradio as gr + from .. import files + def create_output_column() -> Tuple[gr.Column, dict]: """Create the output column with audio player and file list.""" with gr.Column(scale=1) as col: gr.Markdown("### Latest Output") - audio_output = gr.Audio( - label="Generated Speech", - type="filepath" - ) - + audio_output = gr.Audio(label="Generated Speech", type="filepath") + gr.Markdown("### Generated Files") output_files = gr.Dropdown( label="Previous Outputs", choices=files.list_output_files(), value=None, - allow_custom_value=False + allow_custom_value=False, ) - + play_btn = gr.Button("▶️ Play Selected", size="sm") - + selected_audio = gr.Audio( - label="Selected Output", - type="filepath", - visible=False + label="Selected Output", type="filepath", visible=False ) - - clear_outputs = gr.Button("⚠️ Delete All Previously Generated Output Audio 🗑️", size="sm", variant="secondary") - + + clear_outputs = gr.Button( + "⚠️ Delete All Previously Generated Output Audio 🗑️", + size="sm", + variant="secondary", + ) + components = { "audio_output": audio_output, "output_files": output_files, "play_btn": play_btn, "selected_audio": selected_audio, - "clear_outputs": clear_outputs + "clear_outputs": clear_outputs, } - + return col, components diff --git a/ui/lib/files.py b/ui/lib/files.py index 98867f3..867f4f4 100644 --- a/ui/lib/files.py +++ b/ui/lib/files.py @@ -1,17 +1,23 @@ import os -from typing import List, Optional, Tuple import datetime +from typing import List, Tuple, Optional + from .config import INPUTS_DIR, OUTPUTS_DIR, AUDIO_FORMATS + def list_input_files() -> List[str]: """List all input text files.""" - return [f for f in os.listdir(INPUTS_DIR) if f.endswith('.txt')] + return [f for f in os.listdir(INPUTS_DIR) if f.endswith(".txt")] + def list_output_files() -> List[str]: """List all output audio files.""" - return [os.path.join(OUTPUTS_DIR, f) - for f in os.listdir(OUTPUTS_DIR) - if any(f.endswith(ext) for ext in AUDIO_FORMATS)] + return [ + os.path.join(OUTPUTS_DIR, f) + for f in os.listdir(OUTPUTS_DIR) + if any(f.endswith(ext) for ext in AUDIO_FORMATS) + ] + def read_text_file(filename: str) -> str: """Read content of a text file.""" @@ -19,16 +25,17 @@ def read_text_file(filename: str) -> str: return "" try: file_path = os.path.join(INPUTS_DIR, filename) - with open(file_path, 'r', encoding='utf-8') as f: + with open(file_path, "r", encoding="utf-8") as f: return f.read() except: return "" + def save_text(text: str, filename: Optional[str] = None) -> Optional[str]: """Save text to a file. Returns the filename if successful.""" if not text.strip(): return None - + if filename is None: # Use input_1.txt, input_2.txt, etc. base = "input" @@ -41,12 +48,12 @@ def save_text(text: str, filename: Optional[str] = None) -> Optional[str]: else: # Handle duplicate filenames by adding _1, _2, etc. base = os.path.splitext(filename)[0] - ext = os.path.splitext(filename)[1] or '.txt' + ext = os.path.splitext(filename)[1] or ".txt" counter = 1 while os.path.exists(os.path.join(INPUTS_DIR, filename)): filename = f"{base}_{counter}{ext}" counter += 1 - + filepath = os.path.join(INPUTS_DIR, filename) try: with open(filepath, "w", encoding="utf-8") as f: @@ -56,11 +63,12 @@ def save_text(text: str, filename: Optional[str] = None) -> Optional[str]: print(f"Error saving file: {e}") return None + def delete_all_input_files() -> bool: """Delete all files from the inputs directory. Returns True if successful.""" try: for filename in os.listdir(INPUTS_DIR): - if filename.endswith('.txt'): + if filename.endswith(".txt"): file_path = os.path.join(INPUTS_DIR, filename) os.remove(file_path) return True @@ -68,6 +76,7 @@ def delete_all_input_files() -> bool: print(f"Error deleting input files: {e}") return False + def delete_all_output_files() -> bool: """Delete all audio files from the outputs directory. Returns True if successful.""" try: @@ -80,19 +89,20 @@ def delete_all_output_files() -> bool: print(f"Error deleting output files: {e}") return False + def process_uploaded_file(file_path: str) -> bool: """Save uploaded file to inputs directory. Returns True if successful.""" if not file_path: return False - + try: filename = os.path.basename(file_path) - if not filename.endswith('.txt'): + if not filename.endswith(".txt"): return False - + # Create target path in inputs directory target_path = os.path.join(INPUTS_DIR, filename) - + # If file exists, add number suffix base, ext = os.path.splitext(filename) counter = 1 @@ -100,12 +110,13 @@ def process_uploaded_file(file_path: str) -> bool: new_name = f"{base}_{counter}{ext}" target_path = os.path.join(INPUTS_DIR, new_name) counter += 1 - + # Copy file to inputs directory import shutil + shutil.copy2(file_path, target_path) return True - + except Exception as e: print(f"Error saving uploaded file: {e}") return False diff --git a/ui/lib/handlers.py b/ui/lib/handlers.py index 94c9574..eba6cda 100644 --- a/ui/lib/handlers.py +++ b/ui/lib/handlers.py @@ -1,16 +1,19 @@ -import gradio as gr import os import shutil + +import gradio as gr + from . import api, files + def setup_event_handlers(components: dict): """Set up all event handlers for the UI components.""" - + def refresh_status(): try: is_available, voices = api.check_api_status() status = "Available" if is_available else "Waiting for Service..." - + if is_available and voices: # Preserve current voice selection if it exists and is still valid current_voice = components["model"]["voice"].value @@ -19,17 +22,17 @@ def setup_event_handlers(components: dict): gr.update( value=f"🔄 TTS Service: {status}", interactive=True, - variant="secondary" + variant="secondary", ), - gr.update(choices=voices, value=default_voice) + gr.update(choices=voices, value=default_voice), ] return [ gr.update( value=f"⌛ TTS Service: {status}", interactive=True, - variant="secondary" + variant="secondary", ), - gr.update(choices=[], value=None) + gr.update(choices=[], value=None), ] except Exception as e: print(f"Error in refresh status: {str(e)}") @@ -37,11 +40,11 @@ def setup_event_handlers(components: dict): gr.update( value="❌ TTS Service: Connection Error", interactive=True, - variant="secondary" + variant="secondary", ), - gr.update(choices=[], value=None) + gr.update(choices=[], value=None), ] - + def handle_file_select(filename): if filename: try: @@ -52,16 +55,16 @@ def setup_event_handlers(components: dict): except Exception as e: print(f"Error reading file: {e}") return gr.update(value="") - + def handle_file_upload(file): if file is None: return gr.update(choices=files.list_input_files()) - + try: # Copy file to inputs directory filename = os.path.basename(file.name) target_path = os.path.join(files.INPUTS_DIR, filename) - + # Handle duplicate filenames base, ext = os.path.splitext(filename) counter = 1 @@ -69,43 +72,36 @@ def setup_event_handlers(components: dict): new_name = f"{base}_{counter}{ext}" target_path = os.path.join(files.INPUTS_DIR, new_name) counter += 1 - + shutil.copy2(file.name, target_path) - + except Exception as e: print(f"Error uploading file: {e}") - + return gr.update(choices=files.list_input_files()) - + def generate_from_text(text, voice, format, speed): """Generate speech from direct text input""" is_available, _ = api.check_api_status() if not is_available: gr.Warning("TTS Service is currently unavailable") - return [ - None, - gr.update(choices=files.list_output_files()) - ] + return [None, gr.update(choices=files.list_output_files())] if not text or not text.strip(): gr.Warning("Please enter text in the input box") - return [ - None, - gr.update(choices=files.list_output_files()) - ] + return [None, gr.update(choices=files.list_output_files())] files.save_text(text) result = api.text_to_speech(text, voice, format, speed) if result is None: gr.Warning("Failed to generate speech. Please try again.") - return [ - None, - gr.update(choices=files.list_output_files()) - ] - + return [None, gr.update(choices=files.list_output_files())] + return [ result, - gr.update(choices=files.list_output_files(), value=os.path.basename(result)) + gr.update( + choices=files.list_output_files(), value=os.path.basename(result) + ), ] def generate_from_file(selected_file, voice, format, speed): @@ -113,37 +109,30 @@ def setup_event_handlers(components: dict): is_available, _ = api.check_api_status() if not is_available: gr.Warning("TTS Service is currently unavailable") - return [ - None, - gr.update(choices=files.list_output_files()) - ] + return [None, gr.update(choices=files.list_output_files())] if not selected_file: gr.Warning("Please select a file") - return [ - None, - gr.update(choices=files.list_output_files()) - ] + return [None, gr.update(choices=files.list_output_files())] text = files.read_text_file(selected_file) result = api.text_to_speech(text, voice, format, speed) if result is None: gr.Warning("Failed to generate speech. Please try again.") - return [ - None, - gr.update(choices=files.list_output_files()) - ] - + return [None, gr.update(choices=files.list_output_files())] + return [ result, - gr.update(choices=files.list_output_files(), value=os.path.basename(result)) + gr.update( + choices=files.list_output_files(), value=os.path.basename(result) + ), ] def play_selected(file_path): if file_path and os.path.exists(file_path): return gr.update(value=file_path, visible=True) return gr.update(visible=False) - + def clear_files(voice, format, speed): """Delete all input files and clear UI components while preserving model settings""" files.delete_all_input_files() @@ -155,7 +144,7 @@ def setup_event_handlers(components: dict): gr.update(choices=files.list_output_files()), # output_files gr.update(value=voice), # voice gr.update(value=format), # format - gr.update(value=speed) # speed + gr.update(value=speed), # speed ] def clear_outputs(): @@ -164,43 +153,40 @@ def setup_event_handlers(components: dict): return [ None, # audio_output gr.update(choices=[], value=None), # output_files - gr.update(visible=False) # selected_audio + gr.update(visible=False), # selected_audio ] # Connect event handlers components["model"]["status_btn"].click( fn=refresh_status, - outputs=[ - components["model"]["status_btn"], - components["model"]["voice"] - ] + outputs=[components["model"]["status_btn"], components["model"]["voice"]], ) - + components["input"]["file_select"].change( fn=handle_file_select, inputs=[components["input"]["file_select"]], - outputs=[components["input"]["file_preview"]] + outputs=[components["input"]["file_preview"]], ) - + components["input"]["file_upload"].upload( fn=handle_file_upload, inputs=[components["input"]["file_upload"]], - outputs=[components["input"]["file_select"]] + outputs=[components["input"]["file_select"]], ) - + components["output"]["play_btn"].click( fn=play_selected, inputs=[components["output"]["output_files"]], - outputs=[components["output"]["selected_audio"]] + outputs=[components["output"]["selected_audio"]], ) - + # Connect clear files button components["input"]["clear_files"].click( fn=clear_files, inputs=[ components["model"]["voice"], components["model"]["format"], - components["model"]["speed"] + components["model"]["speed"], ], outputs=[ components["input"]["file_select"], @@ -210,10 +196,10 @@ def setup_event_handlers(components: dict): components["output"]["output_files"], components["model"]["voice"], components["model"]["format"], - components["model"]["speed"] - ] + components["model"]["speed"], + ], ) - + # Connect submit buttons for each tab components["input"]["text_submit"].click( fn=generate_from_text, @@ -221,22 +207,22 @@ def setup_event_handlers(components: dict): components["input"]["text_input"], components["model"]["voice"], components["model"]["format"], - components["model"]["speed"] + components["model"]["speed"], ], outputs=[ components["output"]["audio_output"], - components["output"]["output_files"] - ] + components["output"]["output_files"], + ], ) - + # Connect clear outputs button components["output"]["clear_outputs"].click( fn=clear_outputs, outputs=[ components["output"]["audio_output"], components["output"]["output_files"], - components["output"]["selected_audio"] - ] + components["output"]["selected_audio"], + ], ) components["input"]["file_submit"].click( @@ -245,10 +231,10 @@ def setup_event_handlers(components: dict): components["input"]["file_select"], components["model"]["voice"], components["model"]["format"], - components["model"]["speed"] + components["model"]["speed"], ], outputs=[ components["output"]["audio_output"], - components["output"]["output_files"] - ] + components["output"]["output_files"], + ], ) diff --git a/ui/lib/interface.py b/ui/lib/interface.py index 5361217..a23ed7c 100644 --- a/ui/lib/interface.py +++ b/ui/lib/interface.py @@ -1,69 +1,75 @@ import gradio as gr + from . import api -from .components import create_input_column, create_model_column, create_output_column from .handlers import setup_event_handlers +from .components import create_input_column, create_model_column, create_output_column + def create_interface(): """Create the main Gradio interface.""" # Skip initial status check - let the timer handle it is_available, available_voices = False, [] - with gr.Blocks( - title="Kokoro TTS Demo", - theme=gr.themes.Monochrome() -) as demo: - gr.HTML(value='
' - 'Kokoro-82M HF Repo' - 'Kokoro-FastAPI Repo' - '
', show_label=False) - + with gr.Blocks(title="Kokoro TTS Demo", theme=gr.themes.Monochrome()) as demo: + gr.HTML( + value='
' + 'Kokoro-82M HF Repo' + 'Kokoro-FastAPI Repo' + "
", + show_label=False, + ) + # Main interface with gr.Row(): # Create columns input_col, input_components = create_input_column() - model_col, model_components = create_model_column(available_voices) # Pass initial voices + model_col, model_components = create_model_column( + available_voices + ) # Pass initial voices output_col, output_components = create_output_column() - + # Collect all components components = { "input": input_components, "model": model_components, - "output": output_components + "output": output_components, } - + # Set up event handlers setup_event_handlers(components) - + # Add periodic status check with Timer def update_status(): try: is_available, voices = api.check_api_status() status = "Available" if is_available else "Waiting for Service..." - + if is_available and voices: # Service is available, update UI and stop timer current_voice = components["model"]["voice"].value - default_voice = current_voice if current_voice in voices else voices[0] + default_voice = ( + current_voice if current_voice in voices else voices[0] + ) # Return values in same order as outputs list return [ gr.update( value=f"🔄 TTS Service: {status}", interactive=True, - variant="secondary" + variant="secondary", ), gr.update(choices=voices, value=default_voice), - gr.update(active=False) # Stop timer + gr.update(active=False), # Stop timer ] - + # Service not available yet, keep checking return [ gr.update( value=f"⌛ TTS Service: {status}", interactive=True, - variant="secondary" + variant="secondary", ), gr.update(choices=[], value=None), - gr.update(active=True) + gr.update(active=True), ] except Exception as e: print(f"Error in status update: {str(e)}") @@ -72,20 +78,20 @@ def create_interface(): gr.update( value="❌ TTS Service: Connection Error", interactive=True, - variant="secondary" + variant="secondary", ), gr.update(choices=[], value=None), - gr.update(active=True) + gr.update(active=True), ] - + timer = gr.Timer(value=5) # Check every 5 seconds timer.tick( fn=update_status, outputs=[ components["model"]["status_btn"], components["model"]["voice"], - timer - ] + timer, + ], ) return demo diff --git a/ui/tests/conftest.py b/ui/tests/conftest.py index 05ae58d..e9bc035 100644 --- a/ui/tests/conftest.py +++ b/ui/tests/conftest.py @@ -1,5 +1,5 @@ -import pytest import gradio as gr +import pytest @pytest.fixture diff --git a/ui/tests/test_api.py b/ui/tests/test_api.py index c9b37db..fe5dbe7 100644 --- a/ui/tests/test_api.py +++ b/ui/tests/test_api.py @@ -1,6 +1,8 @@ +from unittest.mock import patch, mock_open + import pytest import requests -from unittest.mock import patch, mock_open + from ui.lib import api @@ -57,12 +59,11 @@ def test_check_api_status_connection_error(): def test_text_to_speech_success(mock_response, tmp_path): """Test successful speech generation""" - with patch("requests.post", return_value=mock_response({})), \ - patch("ui.lib.api.OUTPUTS_DIR", str(tmp_path)), \ - patch("builtins.open", mock_open()) as mock_file: - + with patch("requests.post", return_value=mock_response({})), patch( + "ui.lib.api.OUTPUTS_DIR", str(tmp_path) + ), patch("builtins.open", mock_open()) as mock_file: result = api.text_to_speech("test text", "voice1", "mp3", 1.0) - + assert result is not None assert "output_" in result assert result.endswith(".mp3") @@ -105,25 +106,24 @@ def test_get_status_html_unavailable(): def test_text_to_speech_api_params(mock_response, tmp_path): """Test correct API parameters are sent""" - with patch("requests.post") as mock_post, \ - patch("ui.lib.api.OUTPUTS_DIR", str(tmp_path)), \ - patch("builtins.open", mock_open()): - + with patch("requests.post") as mock_post, patch( + "ui.lib.api.OUTPUTS_DIR", str(tmp_path) + ), patch("builtins.open", mock_open()): mock_post.return_value = mock_response({}) api.text_to_speech("test text", "voice1", "mp3", 1.5) - + mock_post.assert_called_once() args, kwargs = mock_post.call_args - + # Check request body assert kwargs["json"] == { "model": "kokoro", "input": "test text", "voice": "voice1", "response_format": "mp3", - "speed": 1.5 + "speed": 1.5, } - + # Check headers and timeout assert kwargs["headers"] == {"Content-Type": "application/json"} assert kwargs["timeout"] == 300 diff --git a/ui/tests/test_components.py b/ui/tests/test_components.py index 9ddb1ad..b125cb7 100644 --- a/ui/tests/test_components.py +++ b/ui/tests/test_components.py @@ -1,8 +1,9 @@ -import pytest import gradio as gr +import pytest + +from ui.lib.config import AUDIO_FORMATS from ui.lib.components.model import create_model_column from ui.lib.components.output import create_output_column -from ui.lib.config import AUDIO_FORMATS def test_create_model_column_structure(): @@ -15,12 +16,7 @@ def test_create_model_column_structure(): assert isinstance(components, dict) # Test expected components presence - expected_components = { - "status_btn", - "voice", - "format", - "speed" - } + expected_components = {"status_btn", "voice", "format", "speed"} assert set(components.keys()) == expected_components # Test component types @@ -78,7 +74,7 @@ def test_create_output_column_structure(): "output_files", "play_btn", "selected_audio", - "clear_outputs" + "clear_outputs", } assert set(components.keys()) == expected_components diff --git a/ui/tests/test_files.py b/ui/tests/test_files.py index aaa0fe8..2e7e038 100644 --- a/ui/tests/test_files.py +++ b/ui/tests/test_files.py @@ -1,6 +1,8 @@ import os -import pytest from unittest.mock import patch + +import pytest + from ui.lib import files from ui.lib.config import AUDIO_FORMATS diff --git a/ui/tests/test_input.py b/ui/tests/test_input.py index 807a483..2919fd0 100644 --- a/ui/tests/test_input.py +++ b/ui/tests/test_input.py @@ -1,5 +1,6 @@ -import pytest import gradio as gr +import pytest + from ui.lib.components.input import create_input_column diff --git a/ui/tests/test_interface.py b/ui/tests/test_interface.py index 550591f..cff4825 100644 --- a/ui/tests/test_interface.py +++ b/ui/tests/test_interface.py @@ -1,12 +1,15 @@ -import pytest +from unittest.mock import MagicMock, PropertyMock, patch + import gradio as gr -from unittest.mock import patch, MagicMock, PropertyMock +import pytest + from ui.lib.interface import create_interface @pytest.fixture def mock_timer(): """Create a mock timer with events property""" + class MockEvent: def __init__(self, fn): self.fn = fn @@ -30,7 +33,7 @@ def test_create_interface_structure(): """Test the basic structure of the created interface""" with patch("ui.lib.api.check_api_status", return_value=(False, [])): demo = create_interface() - + # Test interface type and theme assert isinstance(demo, gr.Blocks) assert demo.title == "Kokoro TTS Demo" @@ -41,15 +44,14 @@ def test_interface_html_links(): """Test that HTML links are properly configured""" with patch("ui.lib.api.check_api_status", return_value=(False, [])): demo = create_interface() - + # Find HTML component html_components = [ - comp for comp in demo.blocks.values() - if isinstance(comp, gr.HTML) + comp for comp in demo.blocks.values() if isinstance(comp, gr.HTML) ] assert len(html_components) > 0 html = html_components[0] - + # Check for required links assert 'href="https://huggingface.co/hexgrad/Kokoro-82M"' in html.value assert 'href="https://github.com/remsky/Kokoro-FastAPI"' in html.value @@ -60,16 +62,17 @@ def test_interface_html_links(): def test_update_status_available(mock_timer): """Test status update when service is available""" voices = ["voice1", "voice2"] - with patch("ui.lib.api.check_api_status", return_value=(True, voices)), \ - patch("gradio.Timer", return_value=mock_timer): + with patch("ui.lib.api.check_api_status", return_value=(True, voices)), patch( + "gradio.Timer", return_value=mock_timer + ): demo = create_interface() - + # Get the update function update_fn = mock_timer.events[0].fn - + # Test update with available service updates = update_fn() - + assert "Available" in updates[0]["value"] assert updates[1]["choices"] == voices assert updates[1]["value"] == voices[0] @@ -78,13 +81,14 @@ def test_update_status_available(mock_timer): def test_update_status_unavailable(mock_timer): """Test status update when service is unavailable""" - with patch("ui.lib.api.check_api_status", return_value=(False, [])), \ - patch("gradio.Timer", return_value=mock_timer): + with patch("ui.lib.api.check_api_status", return_value=(False, [])), patch( + "gradio.Timer", return_value=mock_timer + ): demo = create_interface() update_fn = mock_timer.events[0].fn - + updates = update_fn() - + assert "Waiting for Service" in updates[0]["value"] assert updates[1]["choices"] == [] assert updates[1]["value"] is None @@ -93,13 +97,14 @@ def test_update_status_unavailable(mock_timer): def test_update_status_error(mock_timer): """Test status update when an error occurs""" - with patch("ui.lib.api.check_api_status", side_effect=Exception("Test error")), \ - patch("gradio.Timer", return_value=mock_timer): + with patch( + "ui.lib.api.check_api_status", side_effect=Exception("Test error") + ), patch("gradio.Timer", return_value=mock_timer): demo = create_interface() update_fn = mock_timer.events[0].fn - + updates = update_fn() - + assert "Connection Error" in updates[0]["value"] assert updates[1]["choices"] == [] assert updates[1]["value"] is None @@ -108,10 +113,11 @@ def test_update_status_error(mock_timer): def test_timer_configuration(mock_timer): """Test timer configuration""" - with patch("ui.lib.api.check_api_status", return_value=(False, [])), \ - patch("gradio.Timer", return_value=mock_timer): + with patch("ui.lib.api.check_api_status", return_value=(False, [])), patch( + "gradio.Timer", return_value=mock_timer + ): demo = create_interface() - + assert mock_timer.value == 5 # Check interval is 5 seconds assert len(mock_timer.events) == 1 # Should have one event handler @@ -120,20 +126,21 @@ def test_interface_components_presence(): """Test that all required components are present""" with patch("ui.lib.api.check_api_status", return_value=(False, [])): demo = create_interface() - + # Check for main component sections components = { - comp.label for comp in demo.blocks.values() - if hasattr(comp, 'label') and comp.label + comp.label + for comp in demo.blocks.values() + if hasattr(comp, "label") and comp.label } - + required_components = { "Text to speak", "Voice", "Audio Format", "Speed", "Generated Speech", - "Previous Outputs" + "Previous Outputs", } - + assert required_components.issubset(components)