diff --git a/.coverage b/.coverage
index 7a62bb5..b5254a6 100644
Binary files a/.coverage and b/.coverage differ
diff --git a/.gitignore b/.gitignore
index ec379db..1d9db35 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,6 +5,7 @@ output/
 *.db
 *.pyc
 *.pth
+*.pt
 Kokoro-82M/*
 __pycache__/
diff --git a/README.md b/README.md
index 45c1597..3976730 100644
--- a/README.md
+++ b/README.md
@@ -4,8 +4,8 @@
 # Kokoro TTS API
 [![Model Commit](https://img.shields.io/badge/model--commit-a67f113-blue)](https://huggingface.co/hexgrad/Kokoro-82M/tree/8228a351f87c8a6076502c1e3b7e72e821ebec9a)
-[![Tests](https://img.shields.io/badge/tests-33%20passed-darkgreen)]()
-[![Coverage](https://img.shields.io/badge/coverage-97%25-darkgreen)]()
+[![Tests](https://img.shields.io/badge/tests-36%20passed-darkgreen)]()
+[![Coverage](https://img.shields.io/badge/coverage-91%25-darkgreen)]()
 FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model, providing an OpenAI-compatible endpoint with:
 - NVIDIA GPU accelerated inference (or CPU) option
@@ -30,33 +30,40 @@ docker compose up --build
 # For CPU-only deployment (~10x slower, but doesn't require an NVIDIA GPU):
 docker compose -f docker-compose.cpu.yml up --build
 ```
-
-
-Test all voices (from another terminal):
-```bash
-python examples/test_all_voices.py
-```
+Quick tests (run from another terminal):
 
 Test OpenAI compatibility:
 ```bash
+# Test OpenAI Compatibility
 python examples/test_openai_tts.py
+# Test all available voices
+python examples/test_all_voices.py
 ```
 
 ## OpenAI-Compatible API
 
-List available voices:
+```python
+# Using OpenAI's Python library
+from openai import OpenAI
+client = OpenAI(base_url="http://localhost:8880", api_key="not-needed")
+response = client.audio.speech.create(
+    model="kokoro", # Not used but required for compatibility, also accepts library defaults
+    voice="af_bella",
+    input="Hello world!",
+    response_format="mp3"
+)
+
+response.stream_to_file("output.mp3")
+```
+Or via requests:
 ```python
 import requests
+# Get list of all available voices
 response = requests.get("http://localhost:8880/audio/voices")
 voices = response.json()["voices"]
-```
-
-Generate speech:
-```python
-import requests
+# Generate audio
 response = requests.post(
     "http://localhost:8880/audio/speech",
     json={
@@ -73,20 +80,28 @@ with open("output.mp3", "wb") as f:
     f.write(response.content)
 ```
 
-Using OpenAI's Python library:
+## Voice Combination
+
+Combine voices and generate audio:
 ```python
-from openai import OpenAI
+import requests
 
-client = OpenAI(base_url="http://localhost:8880", api_key="not-needed")
-
-response = client.audio.speech.create(
-    model="kokoro", # Not used but required for compatibility, also accepts library defaults
-    voice="af_bella",
-    input="Hello world!",
-    response_format="mp3"
+# Create combined voice (saved locally on server)
+response = requests.post(
+    "http://localhost:8880/v1/audio/voices/combine",
+    json=["af_bella", "af_sarah"]
 )
+combined_voice = response.json()["voice"]
 
-response.stream_to_file("output.mp3")
+# Generate audio with combined voice
+response = requests.post(
+    "http://localhost:8880/v1/audio/speech",
+    json={
+        "input": "Hello world!",
+        "voice": combined_voice,
+        "response_format": "mp3"
+    }
+)
 ```
 
 ## Performance Benchmarks
@@ -115,6 +130,13 @@ Key Performance Metrics:
 - Multiple audio formats: mp3, wav, opus, flac, (aac & pcm not implemented)
 - Natural Boundary Detection:
   - Automatically splits and stitches at sentence boundaries to reduce artifacts and maintain performance
+- Voice Combination:
+  - Averages model weights of any existing voicepacks
+  - Saves generated voicepacks for future use
+

+ Voice Analysis Comparison +

 
 *Note: CPU Inference is currently a very basic implementation, and not heavily tested*
@@ -133,11 +155,3 @@ This project is licensed under the Apache License 2.0 - see below for details:
 - The inference code adapted from StyleTTS2 is MIT licensed
 
 The full Apache 2.0 license text can be found at: https://www.apache.org/licenses/LICENSE-2.0
-
-## Sample
-
-https://user-images.githubusercontent.com/338912d2-90f3-41fb-bca0-5db7b4e02287.mp4
-
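For reference, the voice "combination" this README section documents is plain weight averaging of voicepack tensors, as implemented in `combine_voices()` further down in this diff. A minimal standalone sketch of the same operation, assuming two voicepack files of identical shape are available locally (the paths here are illustrative; the server keeps its voicepacks in its own voices directory):

```python
import torch

# Illustrative local paths; on the server these live under the voices directory
bella = torch.load("voices/af_bella.pt", map_location="cpu", weights_only=True)
sarah = torch.load("voices/af_sarah.pt", map_location="cpu", weights_only=True)

# Element-wise mean of the stacked tensors -- the same
# torch.mean(torch.stack(...)) call the service performs
combined = torch.mean(torch.stack([bella, sarah]), dim=0)

# The service derives the new name by joining the inputs: "af_bella_af_sarah"
torch.save(combined, "voices/af_bella_af_sarah.pt")
```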
diff --git a/api/src/routers/openai_compatible.py b/api/src/routers/openai_compatible.py
index 9cb7370..983fdc1 100644
--- a/api/src/routers/openai_compatible.py
+++ b/api/src/routers/openai_compatible.py
@@ -26,6 +26,13 @@ async def create_speech(
 ):
     """OpenAI-compatible endpoint for text-to-speech"""
     try:
+        # Validate voice exists
+        available_voices = tts_service.list_voices()
+        if request.voice not in available_voices:
+            raise ValueError(
+                f"Voice '{request.voice}' not found. Available voices: {', '.join(sorted(available_voices))}"
+            )
+
         # Generate audio directly using TTSService's method
         audio, _ = tts_service._generate_audio(
             text=request.input,
@@ -45,9 +52,18 @@
             },
         )
 
+    except ValueError as e:
+        logger.error(f"Invalid request: {str(e)}")
+        raise HTTPException(
+            status_code=400,
+            detail={"error": "Invalid request", "message": str(e)}
+        )
     except Exception as e:
         logger.error(f"Error generating speech: {str(e)}")
-        raise HTTPException(status_code=500, detail=str(e))
+        raise HTTPException(
+            status_code=500,
+            detail={"error": "Server error", "message": str(e)}
+        )
 
 
 @router.get("/audio/voices")
@@ -63,10 +79,41 @@ async def list_voices(tts_service: TTSService = Depends(get_tts_service)):
 
 @router.post("/audio/voices/combine")
 async def combine_voices(request: List[str], tts_service: TTSService = Depends(get_tts_service)):
+    """Combine multiple voices into a new voice.
+
+    Args:
+        request: List of voice names to combine
+
+    Returns:
+        Dict with combined voice name and list of all available voices
+
+    Raises:
+        HTTPException:
+            - 400: Invalid request (wrong number of voices, voice not found)
+            - 500: Server error (file system issues, combination failed)
+    """
     try:
-        t = tts_service.combine_voices(voices=request)
+        combined_voice = tts_service.combine_voices(voices=request)
         voices = tts_service.list_voices()
-        return {"voices": voices, "voice": t}
+        return {"voices": voices, "voice": combined_voice}
+
+    except ValueError as e:
+        logger.error(f"Invalid voice combination request: {str(e)}")
+        raise HTTPException(
+            status_code=400,
+            detail={"error": "Invalid request", "message": str(e)}
+        )
+
+    except RuntimeError as e:
+        logger.error(f"Server error during voice combination: {str(e)}")
+        raise HTTPException(
+            status_code=500,
+            detail={"error": "Server error", "message": str(e)}
+        )
+
     except Exception as e:
-        logger.error(f"Error listing voices: {str(e)}")
-        raise HTTPException(status_code=500, detail=str(e))
+        logger.error(f"Unexpected error during voice combination: {str(e)}")
+        raise HTTPException(
+            status_code=500,
+            detail={"error": "Unexpected error", "message": str(e)}
+        )
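A side effect of the structured `detail` payloads above is that clients can now reliably distinguish a rejected voice name (400) from a generation failure (500) and surface the `message` field. A brief client-side sketch of that handling, using the speech endpoint shown in the README:

```python
import requests

response = requests.post(
    "http://localhost:8880/v1/audio/speech",
    json={"input": "Hello world!", "voice": "not_a_voice", "response_format": "mp3"},
)

if response.ok:
    with open("output.mp3", "wb") as f:
        f.write(response.content)
else:
    # Both 400 and 500 now carry {"error": ..., "message": ...} under "detail"
    detail = response.json()["detail"]
    print(f"{response.status_code} {detail['error']}: {detail['message']}")
```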
diff --git a/api/src/services/tts.py b/api/src/services/tts.py
index f3a24f0..686ef5d 100644
--- a/api/src/services/tts.py
+++ b/api/src/services/tts.py
@@ -3,7 +3,7 @@ import os
 import re
 import threading
 import time
-from typing import List, Tuple
+from typing import List, Tuple, Optional
 
 import numpy as np
 import scipy.io.wavfile as wavfile
@@ -23,6 +23,9 @@ class TTSModel:
     _instance = None
     _lock = threading.Lock()
     _voicepacks = {}
+
+    # Directory for all voices (copied base voices, and any created combined voices)
+    VOICES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "voices")
 
     @classmethod
     def get_instance(cls):
@@ -36,21 +39,21 @@
             # Note: RNN memory optimization is handled internally by the model
             cls._instance = (model, device)
         return cls._instance
-    
+
     @classmethod
     def get_voicepack(cls, voice_name: str) -> torch.Tensor:
+        """Get a voice pack from the voices directory."""
         model, device = cls.get_instance()
         if voice_name not in cls._voicepacks:
             try:
-                voice_path = os.path.join(
-                    settings.model_dir, settings.voices_dir, f"{voice_name}.pt"
-                )
-                voicepack = torch.load(
-                    voice_path, map_location=device, weights_only=True
-                )
+                voice_path = os.path.join(cls.VOICES_DIR, f"{voice_name}.pt")
+                if not os.path.exists(voice_path):
+                    raise FileNotFoundError(f"Voice file not found: {voice_name}")
+
+                voicepack = torch.load(voice_path, map_location=device, weights_only=True)
                 cls._voicepacks[voice_name] = voicepack
             except Exception as e:
-                print(f"Error loading voice {voice_name}: {str(e)}")
+                logger.error(f"Error loading voice {voice_name}: {str(e)}")
                 if voice_name != "af":
                     return cls.get_voicepack("af")
                 raise
@@ -60,13 +63,45 @@
 class TTSService:
     def __init__(self, output_dir: str = None, start_worker: bool = False):
         self.output_dir = output_dir
+        self._ensure_voices()
         if start_worker:
             self.start_worker()
+
+    def _ensure_voices(self):
+        """Copy base voices to local voices directory during initialization"""
+        os.makedirs(TTSModel.VOICES_DIR, exist_ok=True)
+
+        base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
+        if os.path.exists(base_voices_dir):
+            for file in os.listdir(base_voices_dir):
+                if file.endswith(".pt"):
+                    voice_name = file[:-3]
+                    voice_path = os.path.join(TTSModel.VOICES_DIR, file)
+                    if not os.path.exists(voice_path):
+                        try:
+                            base_path = os.path.join(base_voices_dir, file)
+                            logger.info(f"Copying base voice {voice_name} to voices directory")
+                            voicepack = torch.load(base_path, map_location=TTSModel.get_instance()[1], weights_only=True)
+                            torch.save(voicepack, voice_path)
+                        except Exception as e:
+                            logger.error(f"Error copying voice {voice_name}: {str(e)}")
 
     def _split_text(self, text: str) -> List[str]:
         """Split text into sentences"""
         return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]
 
+    def _get_voice_path(self, voice_name: str) -> Optional[str]:
+        """Get the path to a voice file.
+
+        Args:
+            voice_name: Name of the voice to find
+
+        Returns:
+            Path to the voice file if found, None otherwise
+        """
+        voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice_name}.pt")
+        return voice_path if os.path.exists(voice_path) else None
+
     def _generate_audio(
         self, text: str, voice: str, speed: float, stitch_long_output: bool = True
     ) -> Tuple[torch.Tensor, float]:
@@ -79,9 +114,15 @@
             if not text:
                 raise ValueError("Text is empty after preprocessing")
 
-            # Get model instance and voicepack
+            # Get model instance
             model, device = TTSModel.get_instance()
-            voicepack = TTSModel.get_voicepack(voice)
+
+            # Load voice
+            voice_path = self._get_voice_path(voice)
+            if not voice_path:
+                raise ValueError(f"Voice not found: {voice}")
+
+            voicepack = torch.load(voice_path, map_location=device, weights_only=True)
 
             # Generate audio with or without stitching
             if stitch_long_output:
@@ -143,34 +184,63 @@
         return buffer.getvalue()
 
     def combine_voices(self, voices: List[str]) -> str:
+        """Combine multiple voices into a new voice.
+
+        Args:
+            voices: List of voice names to combine
+
+        Returns:
+            Name of the combined voice
+
+        Raises:
+            ValueError: If fewer than 2 voices are provided or voice loading fails
+            RuntimeError: If voice combination or saving fails
+        """
         if len(voices) < 2:
-            return "af"
+            raise ValueError("At least 2 voices are required for combination")
+
+        # Load voices
         t_voices: List[torch.Tensor] = []
         v_name: List[str] = []
+
+        for voice in voices:
+            voice_path = self._get_voice_path(voice)
+            if not voice_path:
+                raise ValueError(f"Voice not found: {voice}")
+
+            try:
+                voicepack = torch.load(voice_path, map_location=TTSModel.get_instance()[1], weights_only=True)
+                t_voices.append(voicepack)
+                v_name.append(voice)
+            except Exception as e:
+                raise ValueError(f"Failed to load voice {voice}: {str(e)}")
+
+        # Combine voices
         try:
-            for file in os.listdir("voices"):
-                voice_name = file[:-3]  # Remove .pt extension
-                for n in voices:
-                    if n == voice_name:
-                        v_name.append(voice_name)
-                        t_voices.append(torch.load(f"voices/{file}", weights_only=True))
+            f: str = "_".join(v_name)
+            v = torch.mean(torch.stack(t_voices), dim=0)
+            combined_path = os.path.join(TTSModel.VOICES_DIR, f"{f}.pt")
+
+            # Save combined voice
+            try:
+                torch.save(v, combined_path)
+            except Exception as e:
+                raise RuntimeError(f"Failed to save combined voice to {combined_path}: {str(e)}")
+
+            return f
+
         except Exception as e:
-            print(f"Error combining voices: {str(e)}")
-            return "af"
-        f: str = "_".join(v_name)
-        v = torch.mean(torch.stack(t_voices), dim=0)
-        torch.save(v, f"voices/{f}.pt")
-        return f
+            if not isinstance(e, (ValueError, RuntimeError)):
+                raise RuntimeError(f"Error combining voices: {str(e)}")
+            raise
 
     def list_voices(self) -> List[str]:
         """List all available voices"""
         voices = []
         try:
-            voices_path = os.path.join(settings.model_dir, settings.voices_dir)
-            for file in os.listdir(voices_path):
+            for file in os.listdir(TTSModel.VOICES_DIR):
                 if file.endswith(".pt"):
-                    voice_name = file[:-3]  # Remove .pt extension
-                    voices.append(voice_name)
+                    voices.append(file[:-3])  # Remove .pt extension
         except Exception as e:
-            print(f"Error listing voices: {str(e)}")
-            return voices
+            logger.error(f"Error listing voices: {str(e)}")
+        return sorted(voices)
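Taken together, the service-level contract is now explicit: fewer than two voices raises `ValueError` (which the router maps to HTTP 400), a successful combination returns the joined name, and the saved voicepack immediately shows up in `list_voices()` because that method scans `VOICES_DIR`. An illustrative sketch of that contract, assuming a configured service instance with the base voices already in place (instantiating the service will also trigger model and voice initialization):

```python
from api.src.services.tts import TTSService

service = TTSService()  # copies base voices into VOICES_DIR on init

try:
    service.combine_voices(["af_bella"])  # too few voices
except ValueError as e:
    print(e)  # "At least 2 voices are required for combination"

name = service.combine_voices(["af_bella", "af_sarah"])
print(name)                           # "af_bella_af_sarah"
print(name in service.list_voices())  # True: the new .pt file is picked up
```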
diff --git a/api/src/structures/schemas.py b/api/src/structures/schemas.py
index 2031e97..8ef36e4 100644
--- a/api/src/structures/schemas.py
+++ b/api/src/structures/schemas.py
@@ -16,18 +16,10 @@ class TTSStatus(str, Enum):
 class OpenAISpeechRequest(BaseModel):
     model: Literal["tts-1", "tts-1-hd", "kokoro"] = "kokoro"
     input: str = Field(..., description="The text to generate audio for")
-    voice: Literal[
-        "am_adam",
-        "am_michael",
-        "bm_lewis",
-        "af",
-        "bm_george",
-        "bf_isabella",
-        "bf_emma",
-        "af_sarah",
-        "af_bella",
-        "af_nicole",
-    ] = Field(default="af", description="The voice to use for generation")
+    voice: str = Field(
+        default="af",
+        description="The voice to use for generation. Can be a base voice or a combined voice name."
+    )
     response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field(
         default="mp3",
         description="The format to return audio in. Supported formats: mp3, opus, flac, wav. AAC and PCM are not currently supported.",
diff --git a/api/tests/test_endpoints.py b/api/tests/test_endpoints.py
index 97789f5..7c11008 100644
--- a/api/tests/test_endpoints.py
+++ b/api/tests/test_endpoints.py
@@ -78,7 +78,8 @@ def test_openai_speech_invalid_voice(mock_tts_service):
         "speed": 1.0,
     }
     response = client.post("/v1/audio/speech", json=test_request)
-    assert response.status_code == 422  # Validation error
+    assert response.status_code == 400  # Bad request
+    assert "not found" in response.json()["detail"]["message"]
 
 
 def test_openai_speech_invalid_speed(mock_tts_service):
@@ -106,4 +107,49 @@
     }
     response = client.post("/v1/audio/speech", json=test_request)
     assert response.status_code == 500
-    assert "Generation failed" in response.json()["detail"]
+    assert "Generation failed" in response.json()["detail"]["message"]
+
+
+def test_combine_voices_success(mock_tts_service):
+    """Test successful voice combination"""
+    test_voices = ["af_bella", "af_sarah"]
+    mock_tts_service.combine_voices.return_value = "af_bella_af_sarah"
+
+    response = client.post("/v1/audio/voices/combine", json=test_voices)
+
+    assert response.status_code == 200
+    assert response.json()["voice"] == "af_bella_af_sarah"
+    mock_tts_service.combine_voices.assert_called_once_with(voices=test_voices)
+
+
+def test_combine_voices_single_voice(mock_tts_service):
+    """Test combining single voice returns default voice"""
+    test_voices = ["af_bella"]
+    mock_tts_service.combine_voices.return_value = "af"
+
+    response = client.post("/v1/audio/voices/combine", json=test_voices)
+
+    assert response.status_code == 200
+    assert response.json()["voice"] == "af"
+
+
+def test_combine_voices_empty_list(mock_tts_service):
+    """Test combining empty voice list returns default voice"""
+    test_voices = []
+    mock_tts_service.combine_voices.return_value = "af"
+
+    response = client.post("/v1/audio/voices/combine", json=test_voices)
+
+    assert response.status_code == 200
+    assert response.json()["voice"] == "af"
+
+
+def test_combine_voices_error(mock_tts_service):
+    """Test error handling in voice combination"""
+    test_voices = ["af_bella", "af_sarah"]
+    mock_tts_service.combine_voices.side_effect = Exception("Combination failed")
+
+    response = client.post("/v1/audio/voices/combine", json=test_voices)
+
+    assert response.status_code == 500
+    assert "Combination failed" in response.json()["detail"]["message"]
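Relaxing `voice` from a `Literal` to a plain `str` is what lets dynamically created combined voices pass request validation; unknown names are now caught by the router (HTTP 400) instead of by Pydantic (HTTP 422), which is exactly the change the updated `test_openai_speech_invalid_voice` asserts. A quick sketch of the new behavior, assuming the remaining fields keep their defaults:

```python
from api.src.structures.schemas import OpenAISpeechRequest

# Under the old Literal type this raised a ValidationError; now any string
# validates, and existence is checked server-side against list_voices()
req = OpenAISpeechRequest(input="Hello world!", voice="af_bella_af_sarah")
print(req.voice)  # "af_bella_af_sarah"
```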
diff --git a/api/tests/test_tts_service.py b/api/tests/test_tts_service.py
index 3f35a2b..533c514 100644
--- a/api/tests/test_tts_service.py
+++ b/api/tests/test_tts_service.py
@@ -79,36 +79,42 @@ def test_generate_audio_empty_text(mock_generate, mock_tokenize, mock_phonemize,
 @patch('api.src.services.tts.TTSModel.get_instance')
-@patch('api.src.services.tts.TTSModel.get_voicepack')
+@patch('os.path.exists')
 @patch('api.src.services.tts.normalize_text')
 @patch('api.src.services.tts.phonemize')
 @patch('api.src.services.tts.tokenize')
 @patch('api.src.services.tts.generate')
-def test_generate_audio_no_chunks(mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_voicepack, mock_instance, tts_service):
+@patch('torch.load')
+def test_generate_audio_no_chunks(mock_torch_load, mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_exists, mock_instance, tts_service):
     """Test generating audio with no successful chunks"""
     mock_normalize.return_value = "Test text"
     mock_phonemize.return_value = "Test text"
     mock_tokenize.return_value = ["test", "text"]
     mock_generate.return_value = (None, None)
     mock_instance.return_value = (MagicMock(), "cpu")
+    mock_exists.return_value = True
+    mock_torch_load.return_value = MagicMock()
 
     with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
         tts_service._generate_audio("Test text", "af", 1.0)
 
 
 @patch('api.src.services.tts.TTSModel.get_instance')
-@patch('api.src.services.tts.TTSModel.get_voicepack')
+@patch('os.path.exists')
 @patch('api.src.services.tts.normalize_text')
 @patch('api.src.services.tts.phonemize')
 @patch('api.src.services.tts.tokenize')
 @patch('api.src.services.tts.generate')
-def test_generate_audio_success(mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_voicepack, mock_instance, tts_service, sample_audio):
+@patch('torch.load')
+def test_generate_audio_success(mock_torch_load, mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_exists, mock_instance, tts_service, sample_audio):
     """Test successful audio generation"""
     mock_normalize.return_value = "Test text"
     mock_phonemize.return_value = "Test text"
     mock_tokenize.return_value = ["test", "text"]
     mock_generate.return_value = (sample_audio, None)
     mock_instance.return_value = (MagicMock(), "cpu")
+    mock_exists.return_value = True
+    mock_torch_load.return_value = MagicMock()
 
     audio, processing_time = tts_service._generate_audio("Test text", "af", 1.0)
     assert isinstance(audio, np.ndarray)
@@ -148,34 +154,19 @@ def test_model_initialization_cpu(mock_build_model, mock_cuda_available):
     mock_build_model.assert_called_once()
 
 
+@patch('os.path.exists')
 @patch('api.src.services.tts.torch.load')
 @patch('os.path.join')
-def test_voicepack_loading_error(mock_join, mock_torch_load):
+def test_voicepack_loading_error(mock_join, mock_torch_load, mock_exists):
     """Test voicepack loading error handling"""
     mock_join.side_effect = lambda *args: '/'.join(args)
-    mock_torch_load.side_effect = [Exception("Failed to load voice"), MagicMock()]
+    mock_exists.side_effect = lambda x: False  # All voice files don't exist
 
     TTSModel._instance = (MagicMock(), "cpu")  # Mock instance
     TTSModel._voicepacks = {}  # Reset voicepacks
 
-    # Should fall back to 'af' voice
-    voicepack = TTSModel.get_voicepack("nonexistent_voice")
-    assert mock_torch_load.call_count == 2  # Tried original voice then fallback
-    assert isinstance(voicepack, MagicMock)  # Successfully got fallback voice
-
-
-@patch('api.src.services.tts.torch.load')
-@patch('os.path.join')
-def test_voicepack_loading_error_af(mock_join, mock_torch_load):
-    """Test voicepack loading error for 'af' voice"""
-    mock_join.side_effect = lambda *args: '/'.join(args)
-    mock_torch_load.side_effect = Exception("Failed to load voice")
-
-    TTSModel._instance = (MagicMock(), "cpu")  # Mock instance
-    TTSModel._voicepacks = {}  # Reset voicepacks
-
-    with pytest.raises(Exception):
-        TTSModel.get_voicepack("af")
+    with pytest.raises(FileNotFoundError, match="Voice file not found: af"):
+        TTSModel.get_voicepack("nonexistent_voice")
 
 
 def test_save_audio(tts_service, sample_audio, tmp_path):
@@ -188,14 +179,17 @@
 
 @patch('api.src.services.tts.TTSModel.get_instance')
-@patch('api.src.services.tts.TTSModel.get_voicepack')
+@patch('os.path.exists')
 @patch('api.src.services.tts.normalize_text')
 @patch('api.src.services.tts.generate')
-def test_generate_audio_without_stitching(mock_generate, mock_normalize, mock_voicepack, mock_instance, tts_service, sample_audio):
+@patch('torch.load')
+def test_generate_audio_without_stitching(mock_torch_load, mock_generate, mock_normalize, mock_exists, mock_instance, tts_service, sample_audio):
     """Test generating audio without text stitching"""
     mock_normalize.return_value = "Test text"
     mock_generate.return_value = (sample_audio, None)
     mock_instance.return_value = (MagicMock(), "cpu")
+    mock_exists.return_value = True
+    mock_torch_load.return_value = MagicMock()
 
     audio, processing_time = tts_service._generate_audio("Test text", "af", 1.0, stitch_long_output=False)
     assert isinstance(audio, np.ndarray)
@@ -214,16 +208,19 @@ def test_list_voices_error(mock_listdir, tts_service):
 
 @patch('api.src.services.tts.TTSModel.get_instance')
-@patch('api.src.services.tts.TTSModel.get_voicepack')
+@patch('os.path.exists')
 @patch('api.src.services.tts.normalize_text')
 @patch('api.src.services.tts.phonemize')
 @patch('api.src.services.tts.tokenize')
 @patch('api.src.services.tts.generate')
-def test_generate_audio_phonemize_error(mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_voicepack, mock_instance, tts_service):
+@patch('torch.load')
+def test_generate_audio_phonemize_error(mock_torch_load, mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_exists, mock_instance, tts_service):
    """Test handling phonemization error"""
     mock_normalize.return_value = "Test text"
     mock_phonemize.side_effect = Exception("Phonemization failed")
     mock_instance.return_value = (MagicMock(), "cpu")
+    mock_exists.return_value = True
+    mock_torch_load.return_value = MagicMock()
     mock_generate.return_value = (None, None)
 
     with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
         tts_service._generate_audio("Test text", "af", 1.0)
@@ -231,14 +228,17 @@
 
 @patch('api.src.services.tts.TTSModel.get_instance')
-@patch('api.src.services.tts.TTSModel.get_voicepack')
+@patch('os.path.exists')
 @patch('api.src.services.tts.normalize_text')
 @patch('api.src.services.tts.generate')
-def test_generate_audio_error(mock_generate, mock_normalize, mock_voicepack, mock_instance, tts_service):
+@patch('torch.load')
+def test_generate_audio_error(mock_torch_load, mock_generate, mock_normalize, mock_exists, mock_instance, tts_service):
     """Test handling generation error"""
     mock_normalize.return_value = "Test text"
     mock_generate.side_effect = Exception("Generation failed")
     mock_instance.return_value = (MagicMock(), "cpu")
+    mock_exists.return_value = True
+    mock_torch_load.return_value = MagicMock()
 
     with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
         tts_service._generate_audio("Test text", "af", 1.0)
diff --git a/examples/benchmarks/analysis_comparison.png b/examples/benchmarks/analysis_comparison.png
new file mode 100644
index 0000000..87a6d13
Binary files /dev/null and b/examples/benchmarks/analysis_comparison.png differ
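The new example script that follows computes a zero-crossing rate and a spectral centroid for each generated WAV. Both metrics have closed-form expectations for a pure tone (centroid near the tone's frequency, ZCR near 2·f/sr), which makes a quick offline sanity check possible. A self-contained sketch using the same formulas as the script (the 24 kHz sample rate is an assumption for illustration):

```python
import numpy as np

sr = 24000   # assumed sample rate for illustration
f = 440.0    # test tone in Hz
t = np.arange(sr) / sr
samples = np.sin(2 * np.pi * f * t)

# Zero-crossing rate, computed as in analyze_audio below
zcr = np.sum(np.abs(np.diff(np.signbit(samples)))) / len(samples)
print(f"ZCR: {zcr:.4f} (expected ~{2 * f / sr:.4f})")

# Spectral centroid over positive frequencies, as in analyze_audio below
fft_result = np.fft.fft(samples)
freqs = np.fft.fftfreq(len(samples), 1 / sr)
pos_mask = freqs > 0
magnitudes = np.abs(fft_result)[pos_mask]
centroid = np.sum(freqs[pos_mask] * magnitudes) / np.sum(magnitudes)
print(f"Spectral centroid: {centroid:.0f} Hz (expected ~{f:.0f} Hz)")
```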
diff --git a/examples/test_analyze_combined_voices.py b/examples/test_analyze_combined_voices.py
new file mode 100644
index 0000000..f48be90
--- /dev/null
+++ b/examples/test_analyze_combined_voices.py
@@ -0,0 +1,330 @@
+#!/usr/bin/env python3
+import argparse
+import os
+from typing import List, Optional, Dict, Tuple
+
+import requests
+import numpy as np
+from scipy.io import wavfile
+import matplotlib.pyplot as plt
+
+
+def submit_combine_voices(voices: List[str], base_url: str = "http://localhost:8880") -> Optional[str]:
+    """Combine multiple voices into a new voice.
+
+    Args:
+        voices: List of voice names to combine (e.g. ["af_bella", "af_sarah"])
+        base_url: API base URL
+
+    Returns:
+        Name of the combined voice (e.g. "af_bella_af_sarah") or None if error
+    """
+    try:
+        response = requests.post(f"{base_url}/v1/audio/voices/combine", json=voices)
+        print(f"Response status: {response.status_code}")
+        print(f"Raw response: {response.text}")
+
+        # Accept both 200 and 201 as success
+        if response.status_code not in [200, 201]:
+            try:
+                error = response.json()["detail"]["message"]
+                print(f"Error combining voices: {error}")
+            except Exception:
+                print(f"Error combining voices: {response.text}")
+            return None
+
+        try:
+            data = response.json()
+            if "voices" in data:
+                print(f"Available voices: {', '.join(sorted(data['voices']))}")
+            return data["voice"]
+        except Exception as e:
+            print(f"Error parsing response: {e}")
+            return None
+    except Exception as e:
+        print(f"Error: {e}")
+        return None
+
+
+def generate_speech(text: str, voice: str, base_url: str = "http://localhost:8880", output_file: str = "output.mp3") -> bool:
+    """Generate speech using specified voice.
+
+    Args:
+        text: Text to convert to speech
+        voice: Voice name to use
+        base_url: API base URL
+        output_file: Path to save audio file
+
+    Returns:
+        True if successful, False otherwise
+    """
+    try:
+        response = requests.post(
+            f"{base_url}/v1/audio/speech",
+            json={
+                "input": text,
+                "voice": voice,
+                "speed": 1.0,
+                "response_format": "wav"  # Use WAV for analysis
+            }
+        )
+
+        if response.status_code != 200:
+            error = response.json().get("detail", {}).get("message", response.text)
+            print(f"Error generating speech: {error}")
+            return False
+
+        # Save the audio
+        os.makedirs(os.path.dirname(output_file) if os.path.dirname(output_file) else ".", exist_ok=True)
+        with open(output_file, "wb") as f:
+            f.write(response.content)
+        print(f"Saved audio to {output_file}")
+        return True
+
+    except Exception as e:
+        print(f"Error: {e}")
+        return False
+
+
+def analyze_audio(filepath: str) -> Tuple[np.ndarray, int, dict]:
+    """Analyze audio file and return samples, sample rate, and audio characteristics.
+
+    Args:
+        filepath: Path to audio file
+
+    Returns:
+        Tuple of (samples, sample_rate, characteristics)
+    """
+    sample_rate, samples = wavfile.read(filepath)
+
+    # Convert to mono if stereo
+    if len(samples.shape) > 1:
+        samples = np.mean(samples, axis=1)
+
+    # Calculate basic stats
+    max_amp = np.max(np.abs(samples))
+    rms = np.sqrt(np.mean(samples**2))
+    duration = len(samples) / sample_rate
+
+    # Zero crossing rate (helps identify voice characteristics)
+    zero_crossings = np.sum(np.abs(np.diff(np.signbit(samples)))) / len(samples)
+
+    # Simple frequency analysis
+    if len(samples) > 0:
+        # Use FFT to get frequency components
+        fft_result = np.fft.fft(samples)
+        freqs = np.fft.fftfreq(len(samples), 1/sample_rate)
+
+        # Get positive frequencies only
+        pos_mask = freqs > 0
+        freqs = freqs[pos_mask]
+        magnitudes = np.abs(fft_result)[pos_mask]
+
+        # Find dominant frequencies (top 3)
+        top_indices = np.argsort(magnitudes)[-3:]
+        dominant_freqs = freqs[top_indices]
+
+        # Calculate spectral centroid (brightness of sound)
+        spectral_centroid = np.sum(freqs * magnitudes) / np.sum(magnitudes)
+    else:
+        dominant_freqs = []
+        spectral_centroid = 0
+
+    characteristics = {
+        "max_amplitude": max_amp,
+        "rms": rms,
+        "duration": duration,
+        "zero_crossing_rate": zero_crossings,
+        "dominant_frequencies": dominant_freqs,
+        "spectral_centroid": spectral_centroid
+    }
+
+    return samples, sample_rate, characteristics
+
+
+def setup_plot(fig, ax, title):
+    """Configure plot styling"""
+    # Improve grid
+    ax.grid(True, linestyle="--", alpha=0.3, color="#ffffff")
+
+    # Set title and labels with better fonts
+    ax.set_title(title, pad=20, fontsize=16, fontweight="bold", color="#ffffff")
+    ax.set_xlabel(ax.get_xlabel(), fontsize=14, fontweight="medium", color="#ffffff")
+    ax.set_ylabel(ax.get_ylabel(), fontsize=14, fontweight="medium", color="#ffffff")
+
+    # Improve tick labels
+    ax.tick_params(labelsize=12, colors="#ffffff")
+
+    # Style spines
+    for spine in ax.spines.values():
+        spine.set_color("#ffffff")
+        spine.set_alpha(0.3)
+        spine.set_linewidth(0.5)
+
+    # Set background colors
+    ax.set_facecolor("#1a1a2e")
+    fig.patch.set_facecolor("#1a1a2e")
+
+    return fig, ax
+
+def plot_analysis(audio_files: Dict[str, str], output_dir: str):
+    """Plot comprehensive voice analysis including waveforms and metrics comparison.
+
+    Args:
+        audio_files: Dictionary of label -> filepath
+        output_dir: Directory to save plot files
+    """
+    # Set dark style
+    plt.style.use('dark_background')
+
+    # Create figure with subplots
+    fig = plt.figure(figsize=(15, 15))
+    fig.patch.set_facecolor("#1a1a2e")
+    num_files = len(audio_files)
+
+    # Create subplot grid with proper spacing
+    gs = plt.GridSpec(num_files + 1, 2, height_ratios=[1.5]*num_files + [1],
+                      hspace=0.4, wspace=0.3)
+
+    # Analyze all files first
+    all_chars = {}
+    for i, (label, filepath) in enumerate(audio_files.items()):
+        samples, sample_rate, chars = analyze_audio(filepath)
+        all_chars[label] = chars
+
+        # Plot waveform spanning both columns
+        ax = plt.subplot(gs[i, :])
+        time = np.arange(len(samples)) / sample_rate
+        plt.plot(time, samples / chars['max_amplitude'], linewidth=0.5, color="#ff2a6d")
+        ax.set_xlabel("Time (seconds)")
+        ax.set_ylabel("Normalized Amplitude")
+        ax.set_ylim(-1.1, 1.1)
+        setup_plot(fig, ax, f"Waveform: {label}")
+
+    # Colors for voices
+    colors = ["#ff2a6d", "#05d9e8", "#d1f7ff"]
+
+    # Create two subplots for metrics with similar scales
+    # Left subplot: Brightness and Volume
+    ax1 = plt.subplot(gs[num_files, 0])
+    metrics1 = [
+        ('Brightness', [chars['spectral_centroid']/1000 for chars in all_chars.values()], 'kHz'),
+        ('Volume', [chars['rms']*100 for chars in all_chars.values()], 'RMS×100')
+    ]
+
+    # Right subplot: Voice Pitch and Texture
+    ax2 = plt.subplot(gs[num_files, 1])
+    metrics2 = [
+        ('Voice Pitch', [min(chars['dominant_frequencies']) for chars in all_chars.values()], 'Hz'),
+        ('Texture', [chars['zero_crossing_rate']*1000 for chars in all_chars.values()], 'ZCR×1000')
+    ]
+
+    def plot_grouped_bars(ax, metrics, show_legend=True):
+        n_groups = len(metrics)
+        n_voices = len(audio_files)
+        bar_width = 0.25
+
+        indices = np.arange(n_groups)
+
+        # Get max value for y-axis scaling
+        max_val = max(max(m[1]) for m in metrics)
+
+        for i, (voice, color) in enumerate(zip(audio_files.keys(), colors)):
+            values = [m[1][i] for m in metrics]
+            offset = (i - n_voices/2 + 0.5) * bar_width
+            bars = ax.bar(indices + offset, values, bar_width,
+                          label=voice, color=color, alpha=0.8)
+
+            # Add value labels on top of bars
+            for bar in bars:
+                height = bar.get_height()
+                ax.text(bar.get_x() + bar.get_width()/2., height,
+                        f'{height:.1f}',
+                        ha='center', va='bottom', color='white',
+                        fontsize=10)
+
+        ax.set_xticks(indices)
+        ax.set_xticklabels([f"{m[0]}\n({m[2]})" for m in metrics])
+
+        # Set y-axis limits with some padding
+        ax.set_ylim(0, max_val * 1.2)
+
+        if show_legend:
+            ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left',
+                      facecolor="#1a1a2e", edgecolor="#ffffff")
+
+    # Plot both subplots
+    plot_grouped_bars(ax1, metrics1, show_legend=True)
+    plot_grouped_bars(ax2, metrics2, show_legend=False)
+
+    # Style both subplots
+    setup_plot(fig, ax1, 'Brightness and Volume')
+    setup_plot(fig, ax2, 'Voice Pitch and Texture')
+
+    # Add y-axis labels
+    ax1.set_ylabel('Value')
+    ax2.set_ylabel('Value')
+
+    # Adjust the figure size to accommodate the legend
+    fig.set_size_inches(15, 15)
+
+    # Add padding around the entire figure
+    plt.subplots_adjust(right=0.85, top=0.95, bottom=0.05, left=0.1)
+    plt.savefig(os.path.join(output_dir, "analysis_comparison.png"), dpi=300)
+    print(f"Saved analysis comparison to {output_dir}/analysis_comparison.png")
+
+    # Print detailed comparative analysis
+    print("\nDetailed Voice Analysis:")
+    for label, chars in all_chars.items():
+        print(f"\n{label}:")
+        print(f"  Max Amplitude: {chars['max_amplitude']:.2f}")
+        print(f"  RMS (loudness): {chars['rms']:.2f}")
+        print(f"  Duration: {chars['duration']:.2f}s")
+        print(f"  Zero Crossing Rate: {chars['zero_crossing_rate']:.3f}")
+        print(f"  Spectral Centroid: {chars['spectral_centroid']:.0f}Hz")
+        print(f"  Dominant Frequencies: {', '.join(f'{f:.0f}Hz' for f in chars['dominant_frequencies'])}")
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Kokoro Voice Analysis Demo")
+    parser.add_argument("--voices", nargs="+", type=str, help="Voices to combine")
+    parser.add_argument("--text", type=str, default="Hello! This is a test of combined voices.", help="Text to speak")
+    parser.add_argument("--url", default="http://localhost:8880", help="API base URL")
+    parser.add_argument("--output-dir", default="examples/output", help="Output directory for audio files")
+    args = parser.parse_args()
+
+    if not args.voices:
+        print("No voices provided, using default test voices")
+        args.voices = ["af_bella", "af_nicole"]
+
+    # Create output directory
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    # Dictionary to store audio files for analysis
+    audio_files = {}
+
+    # Generate speech with individual voices
+    print("Generating speech with individual voices...")
+    for voice in args.voices:
+        output_file = os.path.join(args.output_dir, f"analysis_{voice}.wav")
+        if generate_speech(args.text, voice, args.url, output_file):
+            audio_files[voice] = output_file
+
+    # Generate speech with combined voice
+    print(f"\nCombining voices: {', '.join(args.voices)}")
+    combined_voice = submit_combine_voices(args.voices, args.url)
+
+    if combined_voice:
+        print(f"Successfully created combined voice: {combined_voice}")
+        output_file = os.path.join(args.output_dir, f"analysis_combined_{combined_voice}.wav")
+        if generate_speech(args.text, combined_voice, args.url, output_file):
+            audio_files["combined"] = output_file
+
+        # Generate comparison plots
+        plot_analysis(audio_files, args.output_dir)
+    else:
+        print("Failed to combine voices")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/test_combine_voices.py b/examples/test_combine_voices.py
deleted file mode 100644
index 993d8b5..0000000
--- a/examples/test_combine_voices.py
+++ /dev/null
@@ -1,33 +0,0 @@
-#!/usr/bin/env python3
-import argparse
-from typing import List, Optional
-
-import requests
-
-
-def submit_combine_voices(voices: List[str], base_url: str = "http://localhost:8880") -> Optional[List[str]]:
-    try:
-        response = requests.post(f"{base_url}/v1/audio/voices/combine", json=voices)
-        if response.status_code != 200:
-            print(f"Error submitting request: {response.text}")
-            return None
-        return response.json()["voices"]
-    except requests.exceptions.RequestException as e:
-        print(f"Error: {e}")
-        return None
-
-
-def main():
-    parser = argparse.ArgumentParser(description="Kokoro TTS CLI")
-    parser.add_argument("--voices", nargs="+", type=str, help="Voices to combine")
-    parser.add_argument("--url", default="http://localhost:8880", help="API base URL")
-    args = parser.parse_args()
-
-    success = submit_combine_voices(args.voices, args.url)
-    if success:
-        for voice in success:
-            print(f"  {voice}")
-
-
-if __name__ == "__main__":
-    main()
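A typical invocation of the new analysis demo, mirroring its argparse defaults, is `python examples/test_analyze_combined_voices.py --voices af_bella af_sarah --output-dir examples/output`; it generates per-voice and combined WAVs against a running server, then writes `analysis_comparison.png` alongside the printed per-voice metrics.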