- Modified voice loading to copy base voices into a local voices directory on init

- Adjusted the voice combination functionality
- Improved error handling and added a voice analysis example
remsky 2024-12-31 18:55:26 -07:00
parent 510b01cc90
commit 05e1e30c47
11 changed files with 612 additions and 145 deletions

BIN
.coverage

Binary file not shown.

.gitignore vendored

@@ -5,6 +5,7 @@ output/
 *.db
 *.pyc
 *.pth
+*.pt
 Kokoro-82M/*
 __pycache__/


@@ -4,8 +4,8 @@
 # Kokoro TTS API
 [![Model Commit](https://img.shields.io/badge/model--commit-a67f113-blue)](https://huggingface.co/hexgrad/Kokoro-82M/tree/8228a351f87c8a6076502c1e3b7e72e821ebec9a)
-[![Tests](https://img.shields.io/badge/tests-33%20passed-darkgreen)]()
-[![Coverage](https://img.shields.io/badge/coverage-97%25-darkgreen)]()
+[![Tests](https://img.shields.io/badge/tests-36%20passed-darkgreen)]()
+[![Coverage](https://img.shields.io/badge/coverage-91%25-darkgreen)]()
 FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model, providing an OpenAI-compatible endpoint with:
 - NVIDIA GPU accelerated inference (or CPU) option
@@ -30,33 +30,40 @@ docker compose up --build
 # For CPU-only deployment (~10x slower, but doesn't require an NVIDIA GPU):
 docker compose -f docker-compose.cpu.yml up --build
 ```
-Test all voices (from another terminal):
-```bash
-python examples/test_all_voices.py
-```
-Test OpenAI compatibility:
-```bash
-python examples/test_openai_tts.py
-```
+Quick tests (run from another terminal):
+```bash
+# Test OpenAI Compatibility
+python examples/test_openai_tts.py
+# Test all available voices
+python examples/test_all_voices.py
+```
 ## OpenAI-Compatible API
-List available voices:
+```python
+# Using OpenAI's Python library
+from openai import OpenAI
+client = OpenAI(base_url="http://localhost:8880", api_key="not-needed")
+response = client.audio.speech.create(
+    model="kokoro",  # Not used but required for compatibility, also accepts library defaults
+    voice="af_bella",
+    input="Hello world!",
+    response_format="mp3"
+)
+response.stream_to_file("output.mp3")
+```
+Or via requests:
 ```python
 import requests
+# Get list of all available voices
 response = requests.get("http://localhost:8880/audio/voices")
 voices = response.json()["voices"]
-```
-Generate speech:
-```python
-import requests
+# Generate audio
 response = requests.post(
     "http://localhost:8880/audio/speech",
     json={
@ -73,20 +80,28 @@ with open("output.mp3", "wb") as f:
f.write(response.content) f.write(response.content)
``` ```
Using OpenAI's Python library: ## Voice Combination
Combine voices and generate audio:
```python ```python
from openai import OpenAI import requests
client = OpenAI(base_url="http://localhost:8880", api_key="not-needed") # Create combined voice (saved locally on server)
response = requests.post(
response = client.audio.speech.create( "http://localhost:8880/v1/audio/voices/combine",
model="kokoro", # Not used but required for compatibility, also accepts library defaults json=["af_bella", "af_sarah"]
voice="af_bella",
input="Hello world!",
response_format="mp3"
) )
combined_voice = response.json()["voice"]
response.stream_to_file("output.mp3") # Generate audio with combined voice
response = requests.post(
"http://localhost:8880/v1/audio/speech",
json={
"input": "Hello world!",
"voice": combined_voice,
"response_format": "mp3"
}
)
``` ```
## Performance Benchmarks ## Performance Benchmarks
@@ -115,6 +130,13 @@ Key Performance Metrics:
 - Multiple audio formats: mp3, wav, opus, flac, (aac & pcm not implemented)
 - Natural Boundary Detection:
   - Automatically splits and stitches at sentence boundaries to reduce artifacts and maintain performance
+- Voice Combination:
+  - Averages model weights of any existing voicepacks
+  - Saves generated voicepacks for future use
+<p align="center">
+  <img src="examples/benchmarks/analysis_comparison.png" width="60%" alt="Voice Analysis Comparison" style="border: 2px solid #333; padding: 10px;">
+</p>
 *Note: CPU Inference is currently a very basic implementation, and not heavily tested*
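Implementation-wise, the Voice Combination feature above is just a tensor average: the server stacks the selected voicepacks and takes their element-wise mean. A minimal standalone sketch of the idea (the local file names here are hypothetical; the real logic is in the `combine_voices` diff further down):

```python
import torch

# Hypothetical local copies of two voicepack files
packs = [torch.load(p, weights_only=True) for p in ["af_bella.pt", "af_sarah.pt"]]

# Element-wise mean across the stacked voice tensors
combined = torch.mean(torch.stack(packs), dim=0)

# Saved so the combined voice can be reused later
torch.save(combined, "af_bella_af_sarah.pt")
```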
@@ -133,11 +155,3 @@ This project is licensed under the Apache License 2.0 - see below for details:
 - The inference code adapted from StyleTTS2 is MIT licensed
 The full Apache 2.0 license text can be found at: https://www.apache.org/licenses/LICENSE-2.0
-## Sample
-<div align="center";">
-https://user-images.githubusercontent.com/338912d2-90f3-41fb-bca0-5db7b4e02287.mp4
-</div>


@@ -26,6 +26,13 @@ async def create_speech(
 ):
     """OpenAI-compatible endpoint for text-to-speech"""
     try:
+        # Validate voice exists
+        available_voices = tts_service.list_voices()
+        if request.voice not in available_voices:
+            raise ValueError(
+                f"Voice '{request.voice}' not found. Available voices: {', '.join(sorted(available_voices))}"
+            )
         # Generate audio directly using TTSService's method
         audio, _ = tts_service._generate_audio(
             text=request.input,
@@ -45,9 +52,18 @@
             },
         )
+    except ValueError as e:
+        logger.error(f"Invalid request: {str(e)}")
+        raise HTTPException(
+            status_code=400,
+            detail={"error": "Invalid request", "message": str(e)}
+        )
     except Exception as e:
         logger.error(f"Error generating speech: {str(e)}")
-        raise HTTPException(status_code=500, detail=str(e))
+        raise HTTPException(
+            status_code=500,
+            detail={"error": "Server error", "message": str(e)}
+        )

 @router.get("/audio/voices")
@@ -63,10 +79,41 @@ async def list_voices(tts_service: TTSService = Depends(get_tts_service)):
 @router.post("/audio/voices/combine")
 async def combine_voices(request: List[str], tts_service: TTSService = Depends(get_tts_service)):
+    """Combine multiple voices into a new voice.
+
+    Args:
+        request: List of voice names to combine
+
+    Returns:
+        Dict with combined voice name and list of all available voices
+
+    Raises:
+        HTTPException:
+            - 400: Invalid request (wrong number of voices, voice not found)
+            - 500: Server error (file system issues, combination failed)
+    """
     try:
-        t = tts_service.combine_voices(voices=request)
+        combined_voice = tts_service.combine_voices(voices=request)
         voices = tts_service.list_voices()
-        return {"voices": voices, "voice": t}
+        return {"voices": voices, "voice": combined_voice}
+    except ValueError as e:
+        logger.error(f"Invalid voice combination request: {str(e)}")
+        raise HTTPException(
+            status_code=400,
+            detail={"error": "Invalid request", "message": str(e)}
+        )
+    except RuntimeError as e:
+        logger.error(f"Server error during voice combination: {str(e)}")
+        raise HTTPException(
+            status_code=500,
+            detail={"error": "Server error", "message": str(e)}
+        )
     except Exception as e:
-        logger.error(f"Error listing voices: {str(e)}")
-        raise HTTPException(status_code=500, detail=str(e))
+        logger.error(f"Unexpected error during voice combination: {str(e)}")
+        raise HTTPException(
+            status_code=500,
+            detail={"error": "Unexpected error", "message": str(e)}
+        )
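With these handlers, both endpoints return `detail` as a structured object rather than a bare string, so clients can separate the error class from the human-readable message. A small client-side sketch (endpoint path and payload shape as in the diff above):

```python
import requests

resp = requests.post(
    "http://localhost:8880/v1/audio/voices/combine",
    json=["af_bella", "not_a_voice"],
)
if resp.status_code == 200:
    print("Combined voice:", resp.json()["voice"])
else:
    # 400 and 500 responses both carry {"error": ..., "message": ...}
    detail = resp.json()["detail"]
    print(f"{detail['error']}: {detail['message']}")
```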


@@ -3,7 +3,7 @@ import os
 import re
 import threading
 import time
-from typing import List, Tuple
+from typing import List, Tuple, Optional

 import numpy as np
 import scipy.io.wavfile as wavfile
@@ -24,6 +24,9 @@ class TTSModel:
     _lock = threading.Lock()
     _voicepacks = {}

+    # Directory for all voices (copied base voices, and any created combined voices)
+    VOICES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "voices")
+
     @classmethod
     def get_instance(cls):
         if cls._instance is None:
@@ -39,18 +42,18 @@
     @classmethod
     def get_voicepack(cls, voice_name: str) -> torch.Tensor:
+        """Get a voice pack from the voices directory."""
         model, device = cls.get_instance()
         if voice_name not in cls._voicepacks:
             try:
-                voice_path = os.path.join(
-                    settings.model_dir, settings.voices_dir, f"{voice_name}.pt"
-                )
-                voicepack = torch.load(
-                    voice_path, map_location=device, weights_only=True
-                )
+                voice_path = os.path.join(cls.VOICES_DIR, f"{voice_name}.pt")
+                if not os.path.exists(voice_path):
+                    raise FileNotFoundError(f"Voice file not found: {voice_name}")
+                voicepack = torch.load(voice_path, map_location=device, weights_only=True)
                 cls._voicepacks[voice_name] = voicepack
             except Exception as e:
-                print(f"Error loading voice {voice_name}: {str(e)}")
+                logger.error(f"Error loading voice {voice_name}: {str(e)}")
                 if voice_name != "af":
                     return cls.get_voicepack("af")
                 raise
@@ -60,13 +63,45 @@
 class TTSService:
     def __init__(self, output_dir: str = None, start_worker: bool = False):
         self.output_dir = output_dir
+        self._ensure_voices()
         if start_worker:
             self.start_worker()

+    def _ensure_voices(self):
+        """Copy base voices to local voices directory during initialization"""
+        os.makedirs(TTSModel.VOICES_DIR, exist_ok=True)
+        base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
+        if os.path.exists(base_voices_dir):
+            for file in os.listdir(base_voices_dir):
+                if file.endswith(".pt"):
+                    voice_name = file[:-3]
+                    voice_path = os.path.join(TTSModel.VOICES_DIR, file)
+                    if not os.path.exists(voice_path):
+                        try:
+                            base_path = os.path.join(base_voices_dir, file)
+                            logger.info(f"Copying base voice {voice_name} to voices directory")
+                            voicepack = torch.load(base_path, map_location=TTSModel.get_instance()[1], weights_only=True)
+                            torch.save(voicepack, voice_path)
+                        except Exception as e:
+                            logger.error(f"Error copying voice {voice_name}: {str(e)}")
+
     def _split_text(self, text: str) -> List[str]:
         """Split text into sentences"""
         return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]

+    def _get_voice_path(self, voice_name: str) -> Optional[str]:
+        """Get the path to a voice file.
+
+        Args:
+            voice_name: Name of the voice to find
+
+        Returns:
+            Path to the voice file if found, None otherwise
+        """
+        voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice_name}.pt")
+        return voice_path if os.path.exists(voice_path) else None
+
     def _generate_audio(
         self, text: str, voice: str, speed: float, stitch_long_output: bool = True
     ) -> Tuple[torch.Tensor, float]:
@@ -79,9 +114,15 @@
         if not text:
             raise ValueError("Text is empty after preprocessing")

-        # Get model instance and voicepack
+        # Get model instance
         model, device = TTSModel.get_instance()
-        voicepack = TTSModel.get_voicepack(voice)
+
+        # Load voice
+        voice_path = self._get_voice_path(voice)
+        if not voice_path:
+            raise ValueError(f"Voice not found: {voice}")
+        voicepack = torch.load(voice_path, map_location=device, weights_only=True)

         # Generate audio with or without stitching
         if stitch_long_output:
@@ -143,34 +184,63 @@
         return buffer.getvalue()

     def combine_voices(self, voices: List[str]) -> str:
+        """Combine multiple voices into a new voice.
+
+        Args:
+            voices: List of voice names to combine
+
+        Returns:
+            Name of the combined voice
+
+        Raises:
+            ValueError: If less than 2 voices provided or voice loading fails
+            RuntimeError: If voice combination or saving fails
+        """
         if len(voices) < 2:
-            return "af"
+            raise ValueError("At least 2 voices are required for combination")

+        # Load voices
         t_voices: List[torch.Tensor] = []
         v_name: List[str] = []
-        try:
-            for file in os.listdir("voices"):
-                voice_name = file[:-3]  # Remove .pt extension
-                for n in voices:
-                    if n == voice_name:
-                        v_name.append(voice_name)
-                        t_voices.append(torch.load(f"voices/{file}", weights_only=True))
-        except Exception as e:
-            print(f"Error combining voices: {str(e)}")
-            return "af"
+        for voice in voices:
+            voice_path = self._get_voice_path(voice)
+            if not voice_path:
+                raise ValueError(f"Voice not found: {voice}")
+            try:
+                voicepack = torch.load(voice_path, map_location=TTSModel.get_instance()[1], weights_only=True)
+                t_voices.append(voicepack)
+                v_name.append(voice)
+            except Exception as e:
+                raise ValueError(f"Failed to load voice {voice}: {str(e)}")

-        f: str = "_".join(v_name)
-        v = torch.mean(torch.stack(t_voices), dim=0)
-        torch.save(v, f"voices/{f}.pt")
-        return f
+        # Combine voices
+        try:
+            f: str = "_".join(v_name)
+            v = torch.mean(torch.stack(t_voices), dim=0)
+            combined_path = os.path.join(TTSModel.VOICES_DIR, f"{f}.pt")
+
+            # Save combined voice
+            try:
+                torch.save(v, combined_path)
+            except Exception as e:
+                raise RuntimeError(f"Failed to save combined voice to {combined_path}: {str(e)}")
+
+            return f
+        except Exception as e:
+            if not isinstance(e, (ValueError, RuntimeError)):
+                raise RuntimeError(f"Error combining voices: {str(e)}")
+            raise

     def list_voices(self) -> List[str]:
         """List all available voices"""
         voices = []
         try:
-            voices_path = os.path.join(settings.model_dir, settings.voices_dir)
-            for file in os.listdir(voices_path):
+            for file in os.listdir(TTSModel.VOICES_DIR):
                 if file.endswith(".pt"):
-                    voice_name = file[:-3]  # Remove .pt extension
-                    voices.append(voice_name)
+                    voices.append(file[:-3])  # Remove .pt extension
         except Exception as e:
-            print(f"Error listing voices: {str(e)}")
-        return voices
+            logger.error(f"Error listing voices: {str(e)}")
+        return sorted(voices)
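Taken together, every voice now resolves through the local `TTSModel.VOICES_DIR`, whether it shipped with the model or was produced by `combine_voices`. A rough usage sketch, assuming the import path the tests use (`api.src.services.tts`); note that `_generate_audio` is private, so this is illustration only:

```python
from api.src.services.tts import TTSService

service = TTSService()  # __init__ copies base voices into TTSModel.VOICES_DIR

print(service.list_voices())  # sorted base + combined voices
name = service.combine_voices(["af_bella", "af_sarah"])  # -> "af_bella_af_sarah"
audio, seconds = service._generate_audio("Hello world!", name, speed=1.0)
```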


@@ -16,18 +16,10 @@ class TTSStatus(str, Enum):
 class OpenAISpeechRequest(BaseModel):
     model: Literal["tts-1", "tts-1-hd", "kokoro"] = "kokoro"
     input: str = Field(..., description="The text to generate audio for")
-    voice: Literal[
-        "am_adam",
-        "am_michael",
-        "bm_lewis",
-        "af",
-        "bm_george",
-        "bf_isabella",
-        "bf_emma",
-        "af_sarah",
-        "af_bella",
-        "af_nicole",
-    ] = Field(default="af", description="The voice to use for generation")
+    voice: str = Field(
+        default="af",
+        description="The voice to use for generation. Can be a base voice or a combined voice name."
+    )
     response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field(
         default="mp3",
         description="The format to return audio in. Supported formats: mp3, opus, flac, wav. AAC and PCM are not currently supported.",


@@ -78,7 +78,8 @@ def test_openai_speech_invalid_voice(mock_tts_service):
         "speed": 1.0,
     }
     response = client.post("/v1/audio/speech", json=test_request)
-    assert response.status_code == 422  # Validation error
+    assert response.status_code == 400  # Bad request
+    assert "not found" in response.json()["detail"]["message"]

 def test_openai_speech_invalid_speed(mock_tts_service):
@@ -106,4 +107,49 @@ def test_openai_speech_generation_error(mock_tts_service):
     }
     response = client.post("/v1/audio/speech", json=test_request)
     assert response.status_code == 500
-    assert "Generation failed" in response.json()["detail"]
+    assert "Generation failed" in response.json()["detail"]["message"]
+
+
+def test_combine_voices_success(mock_tts_service):
+    """Test successful voice combination"""
+    test_voices = ["af_bella", "af_sarah"]
+    mock_tts_service.combine_voices.return_value = "af_bella_af_sarah"
+
+    response = client.post("/v1/audio/voices/combine", json=test_voices)
+
+    assert response.status_code == 200
+    assert response.json()["voice"] == "af_bella_af_sarah"
+    mock_tts_service.combine_voices.assert_called_once_with(voices=test_voices)
+
+
+def test_combine_voices_single_voice(mock_tts_service):
+    """Test combining single voice returns default voice"""
+    test_voices = ["af_bella"]
+    mock_tts_service.combine_voices.return_value = "af"
+
+    response = client.post("/v1/audio/voices/combine", json=test_voices)
+
+    assert response.status_code == 200
+    assert response.json()["voice"] == "af"
+
+
+def test_combine_voices_empty_list(mock_tts_service):
+    """Test combining empty voice list returns default voice"""
+    test_voices = []
+    mock_tts_service.combine_voices.return_value = "af"
+
+    response = client.post("/v1/audio/voices/combine", json=test_voices)
+
+    assert response.status_code == 200
+    assert response.json()["voice"] == "af"
+
+
+def test_combine_voices_error(mock_tts_service):
+    """Test error handling in voice combination"""
+    test_voices = ["af_bella", "af_sarah"]
+    mock_tts_service.combine_voices.side_effect = Exception("Combination failed")
+
+    response = client.post("/v1/audio/voices/combine", json=test_voices)
+
+    assert response.status_code == 500
+    assert "Combination failed" in response.json()["detail"]["message"]


@@ -79,36 +79,42 @@ def test_generate_audio_empty_text(mock_generate, mock_tokenize, mock_phonemize,
 @patch('api.src.services.tts.TTSModel.get_instance')
-@patch('api.src.services.tts.TTSModel.get_voicepack')
+@patch('os.path.exists')
 @patch('api.src.services.tts.normalize_text')
 @patch('api.src.services.tts.phonemize')
 @patch('api.src.services.tts.tokenize')
 @patch('api.src.services.tts.generate')
-def test_generate_audio_no_chunks(mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_voicepack, mock_instance, tts_service):
+@patch('torch.load')
+def test_generate_audio_no_chunks(mock_torch_load, mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_exists, mock_instance, tts_service):
     """Test generating audio with no successful chunks"""
     mock_normalize.return_value = "Test text"
     mock_phonemize.return_value = "Test text"
     mock_tokenize.return_value = ["test", "text"]
     mock_generate.return_value = (None, None)
     mock_instance.return_value = (MagicMock(), "cpu")
+    mock_exists.return_value = True
+    mock_torch_load.return_value = MagicMock()

     with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
         tts_service._generate_audio("Test text", "af", 1.0)


 @patch('api.src.services.tts.TTSModel.get_instance')
-@patch('api.src.services.tts.TTSModel.get_voicepack')
+@patch('os.path.exists')
 @patch('api.src.services.tts.normalize_text')
 @patch('api.src.services.tts.phonemize')
 @patch('api.src.services.tts.tokenize')
 @patch('api.src.services.tts.generate')
-def test_generate_audio_success(mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_voicepack, mock_instance, tts_service, sample_audio):
+@patch('torch.load')
+def test_generate_audio_success(mock_torch_load, mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_exists, mock_instance, tts_service, sample_audio):
     """Test successful audio generation"""
     mock_normalize.return_value = "Test text"
     mock_phonemize.return_value = "Test text"
     mock_tokenize.return_value = ["test", "text"]
     mock_generate.return_value = (sample_audio, None)
     mock_instance.return_value = (MagicMock(), "cpu")
+    mock_exists.return_value = True
+    mock_torch_load.return_value = MagicMock()

     audio, processing_time = tts_service._generate_audio("Test text", "af", 1.0)

     assert isinstance(audio, np.ndarray)
@@ -148,34 +154,19 @@ def test_model_initialization_cpu(mock_build_model, mock_cuda_available):
     mock_build_model.assert_called_once()


+@patch('os.path.exists')
 @patch('api.src.services.tts.torch.load')
 @patch('os.path.join')
-def test_voicepack_loading_error(mock_join, mock_torch_load):
+def test_voicepack_loading_error(mock_join, mock_torch_load, mock_exists):
     """Test voicepack loading error handling"""
     mock_join.side_effect = lambda *args: '/'.join(args)
-    mock_torch_load.side_effect = [Exception("Failed to load voice"), MagicMock()]
+    mock_exists.side_effect = lambda x: False  # All voice files don't exist

     TTSModel._instance = (MagicMock(), "cpu")  # Mock instance
     TTSModel._voicepacks = {}  # Reset voicepacks

-    # Should fall back to 'af' voice
-    voicepack = TTSModel.get_voicepack("nonexistent_voice")
-    assert mock_torch_load.call_count == 2  # Tried original voice then fallback
-    assert isinstance(voicepack, MagicMock)  # Successfully got fallback voice
-
-
-@patch('api.src.services.tts.torch.load')
-@patch('os.path.join')
-def test_voicepack_loading_error_af(mock_join, mock_torch_load):
-    """Test voicepack loading error for 'af' voice"""
-    mock_join.side_effect = lambda *args: '/'.join(args)
-    mock_torch_load.side_effect = Exception("Failed to load voice")
-
-    TTSModel._instance = (MagicMock(), "cpu")  # Mock instance
-    TTSModel._voicepacks = {}  # Reset voicepacks
-
-    with pytest.raises(Exception):
-        TTSModel.get_voicepack("af")
+    with pytest.raises(FileNotFoundError, match="Voice file not found: af"):
+        TTSModel.get_voicepack("nonexistent_voice")


 def test_save_audio(tts_service, sample_audio, tmp_path):
@@ -188,14 +179,17 @@ def test_save_audio(tts_service, sample_audio, tmp_path):
 @patch('api.src.services.tts.TTSModel.get_instance')
-@patch('api.src.services.tts.TTSModel.get_voicepack')
+@patch('os.path.exists')
 @patch('api.src.services.tts.normalize_text')
 @patch('api.src.services.tts.generate')
-def test_generate_audio_without_stitching(mock_generate, mock_normalize, mock_voicepack, mock_instance, tts_service, sample_audio):
+@patch('torch.load')
+def test_generate_audio_without_stitching(mock_torch_load, mock_generate, mock_normalize, mock_exists, mock_instance, tts_service, sample_audio):
     """Test generating audio without text stitching"""
     mock_normalize.return_value = "Test text"
     mock_generate.return_value = (sample_audio, None)
     mock_instance.return_value = (MagicMock(), "cpu")
+    mock_exists.return_value = True
+    mock_torch_load.return_value = MagicMock()

     audio, processing_time = tts_service._generate_audio("Test text", "af", 1.0, stitch_long_output=False)

     assert isinstance(audio, np.ndarray)
@@ -214,16 +208,19 @@ def test_list_voices_error(mock_listdir, tts_service):
 @patch('api.src.services.tts.TTSModel.get_instance')
-@patch('api.src.services.tts.TTSModel.get_voicepack')
+@patch('os.path.exists')
 @patch('api.src.services.tts.normalize_text')
 @patch('api.src.services.tts.phonemize')
 @patch('api.src.services.tts.tokenize')
 @patch('api.src.services.tts.generate')
-def test_generate_audio_phonemize_error(mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_voicepack, mock_instance, tts_service):
+@patch('torch.load')
+def test_generate_audio_phonemize_error(mock_torch_load, mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_exists, mock_instance, tts_service):
     """Test handling phonemization error"""
     mock_normalize.return_value = "Test text"
     mock_phonemize.side_effect = Exception("Phonemization failed")
     mock_instance.return_value = (MagicMock(), "cpu")
+    mock_exists.return_value = True
+    mock_torch_load.return_value = MagicMock()
     mock_generate.return_value = (None, None)

     with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
@@ -231,14 +228,17 @@ def test_generate_audio_phonemize_error(mock_generate, mock_tokenize, mock_phone
 @patch('api.src.services.tts.TTSModel.get_instance')
-@patch('api.src.services.tts.TTSModel.get_voicepack')
+@patch('os.path.exists')
 @patch('api.src.services.tts.normalize_text')
 @patch('api.src.services.tts.generate')
-def test_generate_audio_error(mock_generate, mock_normalize, mock_voicepack, mock_instance, tts_service):
+@patch('torch.load')
+def test_generate_audio_error(mock_torch_load, mock_generate, mock_normalize, mock_exists, mock_instance, tts_service):
    """Test handling generation error"""
     mock_normalize.return_value = "Test text"
     mock_generate.side_effect = Exception("Generation failed")
     mock_instance.return_value = (MagicMock(), "cpu")
+    mock_exists.return_value = True
+    mock_torch_load.return_value = MagicMock()

     with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
         tts_service._generate_audio("Test text", "af", 1.0)

Binary file not shown. (new image: 754 KiB)


@@ -0,0 +1,330 @@
#!/usr/bin/env python3
import argparse
import os
from typing import List, Optional, Dict, Tuple

import requests
import numpy as np
from scipy.io import wavfile
import matplotlib.pyplot as plt


def submit_combine_voices(voices: List[str], base_url: str = "http://localhost:8880") -> Optional[str]:
    """Combine multiple voices into a new voice.

    Args:
        voices: List of voice names to combine (e.g. ["af_bella", "af_sarah"])
        base_url: API base URL

    Returns:
        Name of the combined voice (e.g. "af_bella_af_sarah") or None if error
    """
    try:
        response = requests.post(f"{base_url}/v1/audio/voices/combine", json=voices)
        print(f"Response status: {response.status_code}")
        print(f"Raw response: {response.text}")

        # Accept both 200 and 201 as success
        if response.status_code not in [200, 201]:
            try:
                error = response.json()["detail"]["message"]
                print(f"Error combining voices: {error}")
            except Exception:
                print(f"Error combining voices: {response.text}")
            return None

        try:
            data = response.json()
            if "voices" in data:
                print(f"Available voices: {', '.join(sorted(data['voices']))}")
            return data["voice"]
        except Exception as e:
            print(f"Error parsing response: {e}")
            return None
    except Exception as e:
        print(f"Error: {e}")
        return None

def generate_speech(text: str, voice: str, base_url: str = "http://localhost:8880", output_file: str = "output.mp3") -> bool:
    """Generate speech using specified voice.

    Args:
        text: Text to convert to speech
        voice: Voice name to use
        base_url: API base URL
        output_file: Path to save audio file

    Returns:
        True if successful, False otherwise
    """
    try:
        response = requests.post(
            f"{base_url}/v1/audio/speech",
            json={
                "input": text,
                "voice": voice,
                "speed": 1.0,
                "response_format": "wav"  # Use WAV for analysis
            }
        )

        if response.status_code != 200:
            error = response.json().get("detail", {}).get("message", response.text)
            print(f"Error generating speech: {error}")
            return False

        # Save the audio
        os.makedirs(os.path.dirname(output_file) if os.path.dirname(output_file) else ".", exist_ok=True)
        with open(output_file, "wb") as f:
            f.write(response.content)
        print(f"Saved audio to {output_file}")
        return True
    except Exception as e:
        print(f"Error: {e}")
        return False

def analyze_audio(filepath: str) -> Tuple[np.ndarray, int, dict]:
    """Analyze audio file and return samples, sample rate, and audio characteristics.

    Args:
        filepath: Path to audio file

    Returns:
        Tuple of (samples, sample_rate, characteristics)
    """
    sample_rate, samples = wavfile.read(filepath)

    # Convert to mono if stereo
    if len(samples.shape) > 1:
        samples = np.mean(samples, axis=1)

    # Calculate basic stats
    max_amp = np.max(np.abs(samples))
    rms = np.sqrt(np.mean(samples**2))
    duration = len(samples) / sample_rate

    # Zero crossing rate (helps identify voice characteristics)
    zero_crossings = np.sum(np.abs(np.diff(np.signbit(samples)))) / len(samples)

    # Simple frequency analysis
    if len(samples) > 0:
        # Use FFT to get frequency components
        fft_result = np.fft.fft(samples)
        freqs = np.fft.fftfreq(len(samples), 1/sample_rate)
        # Get positive frequencies only
        pos_mask = freqs > 0
        freqs = freqs[pos_mask]
        magnitudes = np.abs(fft_result)[pos_mask]
        # Find dominant frequencies (top 3)
        top_indices = np.argsort(magnitudes)[-3:]
        dominant_freqs = freqs[top_indices]
        # Calculate spectral centroid (brightness of sound)
        spectral_centroid = np.sum(freqs * magnitudes) / np.sum(magnitudes)
    else:
        dominant_freqs = []
        spectral_centroid = 0

    characteristics = {
        "max_amplitude": max_amp,
        "rms": rms,
        "duration": duration,
        "zero_crossing_rate": zero_crossings,
        "dominant_frequencies": dominant_freqs,
        "spectral_centroid": spectral_centroid
    }

    return samples, sample_rate, characteristics

def setup_plot(fig, ax, title):
    """Configure plot styling"""
    # Improve grid
    ax.grid(True, linestyle="--", alpha=0.3, color="#ffffff")

    # Set title and labels with better fonts
    ax.set_title(title, pad=20, fontsize=16, fontweight="bold", color="#ffffff")
    ax.set_xlabel(ax.get_xlabel(), fontsize=14, fontweight="medium", color="#ffffff")
    ax.set_ylabel(ax.get_ylabel(), fontsize=14, fontweight="medium", color="#ffffff")

    # Improve tick labels
    ax.tick_params(labelsize=12, colors="#ffffff")

    # Style spines
    for spine in ax.spines.values():
        spine.set_color("#ffffff")
        spine.set_alpha(0.3)
        spine.set_linewidth(0.5)

    # Set background colors
    ax.set_facecolor("#1a1a2e")
    fig.patch.set_facecolor("#1a1a2e")

    return fig, ax

def plot_analysis(audio_files: Dict[str, str], output_dir: str):
    """Plot comprehensive voice analysis including waveforms and metrics comparison.

    Args:
        audio_files: Dictionary of label -> filepath
        output_dir: Directory to save plot files
    """
    # Set dark style
    plt.style.use('dark_background')

    # Create figure with subplots
    fig = plt.figure(figsize=(15, 15))
    fig.patch.set_facecolor("#1a1a2e")
    num_files = len(audio_files)

    # Create subplot grid with proper spacing
    gs = plt.GridSpec(num_files + 1, 2, height_ratios=[1.5]*num_files + [1],
                      hspace=0.4, wspace=0.3)

    # Analyze all files first
    all_chars = {}
    for i, (label, filepath) in enumerate(audio_files.items()):
        samples, sample_rate, chars = analyze_audio(filepath)
        all_chars[label] = chars

        # Plot waveform spanning both columns
        ax = plt.subplot(gs[i, :])
        time = np.arange(len(samples)) / sample_rate
        plt.plot(time, samples / chars['max_amplitude'], linewidth=0.5, color="#ff2a6d")
        ax.set_xlabel("Time (seconds)")
        ax.set_ylabel("Normalized Amplitude")
        ax.set_ylim(-1.1, 1.1)
        setup_plot(fig, ax, f"Waveform: {label}")

    # Colors for voices
    colors = ["#ff2a6d", "#05d9e8", "#d1f7ff"]

    # Create two subplots for metrics with similar scales
    # Left subplot: Brightness and Volume
    ax1 = plt.subplot(gs[num_files, 0])
    metrics1 = [
        ('Brightness', [chars['spectral_centroid']/1000 for chars in all_chars.values()], 'kHz'),
        ('Volume', [chars['rms']*100 for chars in all_chars.values()], 'RMS×100')
    ]

    # Right subplot: Voice Pitch and Texture
    ax2 = plt.subplot(gs[num_files, 1])
    metrics2 = [
        ('Voice Pitch', [min(chars['dominant_frequencies']) for chars in all_chars.values()], 'Hz'),
        ('Texture', [chars['zero_crossing_rate']*1000 for chars in all_chars.values()], 'ZCR×1000')
    ]

    def plot_grouped_bars(ax, metrics, show_legend=True):
        n_groups = len(metrics)
        n_voices = len(audio_files)
        bar_width = 0.25
        indices = np.arange(n_groups)

        # Get max value for y-axis scaling
        max_val = max(max(m[1]) for m in metrics)

        for i, (voice, color) in enumerate(zip(audio_files.keys(), colors)):
            values = [m[1][i] for m in metrics]
            offset = (i - n_voices/2 + 0.5) * bar_width
            bars = ax.bar(indices + offset, values, bar_width,
                          label=voice, color=color, alpha=0.8)

            # Add value labels on top of bars
            for bar in bars:
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width()/2., height,
                        f'{height:.1f}',
                        ha='center', va='bottom', color='white',
                        fontsize=10)

        ax.set_xticks(indices)
        ax.set_xticklabels([f"{m[0]}\n({m[2]})" for m in metrics])

        # Set y-axis limits with some padding
        ax.set_ylim(0, max_val * 1.2)

        if show_legend:
            ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left',
                      facecolor="#1a1a2e", edgecolor="#ffffff")

    # Plot both subplots
    plot_grouped_bars(ax1, metrics1, show_legend=True)
    plot_grouped_bars(ax2, metrics2, show_legend=False)

    # Style both subplots
    setup_plot(fig, ax1, 'Brightness and Volume')
    setup_plot(fig, ax2, 'Voice Pitch and Texture')

    # Add y-axis labels
    ax1.set_ylabel('Value')
    ax2.set_ylabel('Value')

    # Adjust the figure size to accommodate the legend
    fig.set_size_inches(15, 15)

    # Add padding around the entire figure
    plt.subplots_adjust(right=0.85, top=0.95, bottom=0.05, left=0.1)

    plt.savefig(os.path.join(output_dir, "analysis_comparison.png"), dpi=300)
    print(f"Saved analysis comparison to {output_dir}/analysis_comparison.png")

    # Print detailed comparative analysis
    print("\nDetailed Voice Analysis:")
    for label, chars in all_chars.items():
        print(f"\n{label}:")
        print(f"  Max Amplitude: {chars['max_amplitude']:.2f}")
        print(f"  RMS (loudness): {chars['rms']:.2f}")
        print(f"  Duration: {chars['duration']:.2f}s")
        print(f"  Zero Crossing Rate: {chars['zero_crossing_rate']:.3f}")
        print(f"  Spectral Centroid: {chars['spectral_centroid']:.0f}Hz")
        print(f"  Dominant Frequencies: {', '.join(f'{f:.0f}Hz' for f in chars['dominant_frequencies'])}")

def main():
    parser = argparse.ArgumentParser(description="Kokoro Voice Analysis Demo")
    parser.add_argument("--voices", nargs="+", type=str, help="Voices to combine")
    parser.add_argument("--text", type=str, default="Hello! This is a test of combined voices.", help="Text to speak")
    parser.add_argument("--url", default="http://localhost:8880", help="API base URL")
    parser.add_argument("--output-dir", default="examples/output", help="Output directory for audio files")
    args = parser.parse_args()

    if not args.voices:
        print("No voices provided, using default test voices")
        args.voices = ["af_bella", "af_nicole"]

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Dictionary to store audio files for analysis
    audio_files = {}

    # Generate speech with individual voices
    print("Generating speech with individual voices...")
    for voice in args.voices:
        output_file = os.path.join(args.output_dir, f"analysis_{voice}.wav")
        if generate_speech(args.text, voice, args.url, output_file):
            audio_files[voice] = output_file

    # Generate speech with combined voice
    print(f"\nCombining voices: {', '.join(args.voices)}")
    combined_voice = submit_combine_voices(args.voices, args.url)
    if combined_voice:
        print(f"Successfully created combined voice: {combined_voice}")
        output_file = os.path.join(args.output_dir, f"analysis_combined_{combined_voice}.wav")
        if generate_speech(args.text, combined_voice, args.url, output_file):
            audio_files["combined"] = output_file

        # Generate comparison plots
        plot_analysis(audio_files, args.output_dir)
    else:
        print("Failed to combine voices")


if __name__ == "__main__":
    main()
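The helpers can also be driven piecemeal rather than through `main()`; a minimal sketch using the defaults defined above (server assumed to be running on localhost:8880):

```python
combined = submit_combine_voices(["af_bella", "af_nicole"])
if combined:
    # generate_speech always requests WAV internally, regardless of extension
    generate_speech("Hello!", combined, output_file="examples/output/combined.wav")
```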


@@ -1,33 +0,0 @@
#!/usr/bin/env python3
import argparse
from typing import List, Optional

import requests


def submit_combine_voices(voices: List[str], base_url: str = "http://localhost:8880") -> Optional[List[str]]:
    try:
        response = requests.post(f"{base_url}/v1/audio/voices/combine", json=voices)
        if response.status_code != 200:
            print(f"Error submitting request: {response.text}")
            return None
        return response.json()["voices"]
    except requests.exceptions.RequestException as e:
        print(f"Error: {e}")
        return None


def main():
    parser = argparse.ArgumentParser(description="Kokoro TTS CLI")
    parser.add_argument("--voices", nargs="+", type=str, help="Voices to combine")
    parser.add_argument("--url", default="http://localhost:8880", help="API base URL")
    args = parser.parse_args()

    success = submit_combine_voices(args.voices, args.url)
    if success:
        for voice in success:
            print(f"  {voice}")


if __name__ == "__main__":
    main()