Mirror of https://github.com/remsky/Kokoro-FastAPI.git (synced 2025-08-05 16:48:53 +00:00)
- Modified voice loading to copy base voices on init
- Adjustments to the combine-voices functionality: error handling and analysis

parent 510b01cc90
commit 05e1e30c47

11 changed files with 612 additions and 145 deletions
BIN  .coverage
Binary file not shown.
1  .gitignore (vendored)

@@ -5,6 +5,7 @@ output/
*.db
*.pyc
*.pth
*.pt

Kokoro-82M/*
__pycache__/
80  README.md
@@ -4,8 +4,8 @@

# Kokoro TTS API

[Kokoro-82M model card](https://huggingface.co/hexgrad/Kokoro-82M/tree/8228a351f87c8a6076502c1e3b7e72e821ebec9a) (badge images not shown)

FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model, providing an OpenAI-compatible endpoint with:
- NVIDIA GPU-accelerated inference (or CPU) option
@@ -30,33 +30,40 @@ docker compose up --build
# For CPU-only deployment (~10x slower, but doesn't require an NVIDIA GPU):
docker compose -f docker-compose.cpu.yml up --build
```

Test all voices (from another terminal):
```bash
python examples/test_all_voices.py
```

Quick tests (run from another terminal):

Test OpenAI compatibility:
```bash
# Test OpenAI Compatibility
python examples/test_openai_tts.py
# Test all available voices
python examples/test_all_voices.py
```

## OpenAI-Compatible API

List available voices:
```python
# Using OpenAI's Python library
from openai import OpenAI
client = OpenAI(base_url="http://localhost:8880", api_key="not-needed")
response = client.audio.speech.create(
    model="kokoro",  # Not used but required for compatibility, also accepts library defaults
    voice="af_bella",
    input="Hello world!",
    response_format="mp3"
)

response.stream_to_file("output.mp3")
```
Or Via Requests:
```python
import requests

# Get list of all available voices
response = requests.get("http://localhost:8880/audio/voices")
voices = response.json()["voices"]
```

Generate speech:
```python
import requests

# Generate audio
response = requests.post(
    "http://localhost:8880/audio/speech",
    json={
@@ -73,20 +80,28 @@ with open("output.mp3", "wb") as f:
     f.write(response.content)
 ```

-Using OpenAI's Python library:
+## Voice Combination
+
+Combine voices and generate audio:
 ```python
-from openai import OpenAI
+import requests

-client = OpenAI(base_url="http://localhost:8880", api_key="not-needed")
-
-response = client.audio.speech.create(
-    model="kokoro",  # Not used but required for compatibility, also accepts library defaults
-    voice="af_bella",
-    input="Hello world!",
-    response_format="mp3"
-)
+# Create combined voice (saved locally on server)
+response = requests.post(
+    "http://localhost:8880/v1/audio/voices/combine",
+    json=["af_bella", "af_sarah"]
+)
+combined_voice = response.json()["voice"]

-response.stream_to_file("output.mp3")
+# Generate audio with combined voice
+response = requests.post(
+    "http://localhost:8880/v1/audio/speech",
+    json={
+        "input": "Hello world!",
+        "voice": combined_voice,
+        "response_format": "mp3"
+    }
+)
 ```

 ## Performance Benchmarks
@@ -115,6 +130,13 @@ Key Performance Metrics:
 - Multiple audio formats: mp3, wav, opus, flac (aac & pcm not implemented)
+- Natural Boundary Detection:
+  - Automatically splits and stitches at sentence boundaries to reduce artifacts and maintain performance
+- Voice Combination:
+  - Averages model weights of any existing voicepacks
+  - Saves generated voicepacks for future use
+
+<p align="center">
+  <img src="examples/benchmarks/analysis_comparison.png" width="60%" alt="Voice Analysis Comparison" style="border: 2px solid #333; padding: 10px;">
+</p>

 *Note: CPU Inference is currently a very basic implementation, and not heavily tested*
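The voice-combination feature listed above works by averaging voicepack tensors. A minimal sketch of that idea (illustrative only, not part of this commit; file paths follow the local voices/ directory the service uses):

```python
import torch

# Average two voicepacks element-wise -- the same operation the service
# performs via torch.mean(torch.stack(t_voices), dim=0) in combine_voices.
bella = torch.load("voices/af_bella.pt", weights_only=True)
sarah = torch.load("voices/af_sarah.pt", weights_only=True)

combined = torch.mean(torch.stack([bella, sarah]), dim=0)
torch.save(combined, "voices/af_bella_af_sarah.pt")
```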
@@ -133,11 +155,3 @@ This project is licensed under the Apache License 2.0 - see below for details:
 - The inference code adapted from StyleTTS2 is MIT licensed

 The full Apache 2.0 license text can be found at: https://www.apache.org/licenses/LICENSE-2.0
-
-## Sample
-
-<div align="center";">
-
-https://user-images.githubusercontent.com/338912d2-90f3-41fb-bca0-5db7b4e02287.mp4
-
-</div>
@@ -26,6 +26,13 @@ async def create_speech(
 ):
     """OpenAI-compatible endpoint for text-to-speech"""
     try:
+        # Validate voice exists
+        available_voices = tts_service.list_voices()
+        if request.voice not in available_voices:
+            raise ValueError(
+                f"Voice '{request.voice}' not found. Available voices: {', '.join(sorted(available_voices))}"
+            )
+
         # Generate audio directly using TTSService's method
         audio, _ = tts_service._generate_audio(
             text=request.input,
@@ -45,9 +52,18 @@ async def create_speech(
             },
         )

+    except ValueError as e:
+        logger.error(f"Invalid request: {str(e)}")
+        raise HTTPException(
+            status_code=400,
+            detail={"error": "Invalid request", "message": str(e)}
+        )
     except Exception as e:
         logger.error(f"Error generating speech: {str(e)}")
-        raise HTTPException(status_code=500, detail=str(e))
+        raise HTTPException(
+            status_code=500,
+            detail={"error": "Server error", "message": str(e)}
+        )


 @router.get("/audio/voices")
@@ -63,10 +79,41 @@ async def list_voices(tts_service: TTSService = Depends(get_tts_service)):

 @router.post("/audio/voices/combine")
 async def combine_voices(request: List[str], tts_service: TTSService = Depends(get_tts_service)):
+    """Combine multiple voices into a new voice.
+
+    Args:
+        request: List of voice names to combine
+
+    Returns:
+        Dict with combined voice name and list of all available voices
+
+    Raises:
+        HTTPException:
+            - 400: Invalid request (wrong number of voices, voice not found)
+            - 500: Server error (file system issues, combination failed)
+    """
     try:
-        t = tts_service.combine_voices(voices=request)
+        combined_voice = tts_service.combine_voices(voices=request)
         voices = tts_service.list_voices()
-        return {"voices": voices, "voice": t}
+        return {"voices": voices, "voice": combined_voice}
+
+    except ValueError as e:
+        logger.error(f"Invalid voice combination request: {str(e)}")
+        raise HTTPException(
+            status_code=400,
+            detail={"error": "Invalid request", "message": str(e)}
+        )
+
+    except RuntimeError as e:
+        logger.error(f"Server error during voice combination: {str(e)}")
+        raise HTTPException(
+            status_code=500,
+            detail={"error": "Server error", "message": str(e)}
+        )
+
     except Exception as e:
-        logger.error(f"Error listing voices: {str(e)}")
-        raise HTTPException(status_code=500, detail=str(e))
+        logger.error(f"Unexpected error during voice combination: {str(e)}")
+        raise HTTPException(
+            status_code=500,
+            detail={"error": "Unexpected error", "message": str(e)}
+        )
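The endpoint above now returns structured error details ({"error": ..., "message": ...}) with distinct 400/500 status codes. A client-side sketch exercising that contract (assumes the server from the README on localhost:8880; the invalid voice name is hypothetical, and the snippet is not part of the commit):

```python
import requests

# Request a combined voice; on failure, inspect the structured detail payload.
response = requests.post(
    "http://localhost:8880/v1/audio/voices/combine",
    json=["af_bella", "not_a_voice"],  # hypothetical invalid second voice
)
if response.status_code == 200:
    print("Combined voice:", response.json()["voice"])
else:
    detail = response.json()["detail"]
    print(f"{response.status_code}: {detail['error']} - {detail['message']}")
```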
@@ -3,7 +3,7 @@ import os
 import re
 import threading
 import time
-from typing import List, Tuple
+from typing import List, Tuple, Optional

 import numpy as np
 import scipy.io.wavfile as wavfile
@@ -24,6 +24,9 @@ class TTSModel:
     _lock = threading.Lock()
     _voicepacks = {}

+    # Directory for all voices (copied base voices, and any created combined voices)
+    VOICES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "voices")
+
     @classmethod
     def get_instance(cls):
         if cls._instance is None:
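Since the tests below patch api.src.services.tts, this module presumably lives at api/src/services/tts.py, so the two dirname calls above resolve VOICES_DIR to api/src/voices. A tiny sketch of that resolution (the /app prefix is a hypothetical install location):

```python
import os

# Mimics: VOICES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "voices")
module_file = "/app/api/src/services/tts.py"  # hypothetical value of __file__
voices_dir = os.path.join(os.path.dirname(os.path.dirname(module_file)), "voices")
print(voices_dir)  # -> /app/api/src/voices
```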
@@ -39,18 +42,18 @@ class TTSModel:

     @classmethod
     def get_voicepack(cls, voice_name: str) -> torch.Tensor:
+        """Get a voice pack from the voices directory."""
         model, device = cls.get_instance()
         if voice_name not in cls._voicepacks:
             try:
-                voice_path = os.path.join(
-                    settings.model_dir, settings.voices_dir, f"{voice_name}.pt"
-                )
-                voicepack = torch.load(
-                    voice_path, map_location=device, weights_only=True
-                )
+                voice_path = os.path.join(cls.VOICES_DIR, f"{voice_name}.pt")
+                if not os.path.exists(voice_path):
+                    raise FileNotFoundError(f"Voice file not found: {voice_name}")
+
+                voicepack = torch.load(voice_path, map_location=device, weights_only=True)
                 cls._voicepacks[voice_name] = voicepack
             except Exception as e:
-                print(f"Error loading voice {voice_name}: {str(e)}")
+                logger.error(f"Error loading voice {voice_name}: {str(e)}")
                 if voice_name != "af":
                     return cls.get_voicepack("af")
                 raise
@@ -60,13 +63,45 @@ class TTSModel:
 class TTSService:
     def __init__(self, output_dir: str = None, start_worker: bool = False):
         self.output_dir = output_dir
+        self._ensure_voices()
         if start_worker:
             self.start_worker()

+    def _ensure_voices(self):
+        """Copy base voices to local voices directory during initialization"""
+        os.makedirs(TTSModel.VOICES_DIR, exist_ok=True)
+
+        base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
+        if os.path.exists(base_voices_dir):
+            for file in os.listdir(base_voices_dir):
+                if file.endswith(".pt"):
+                    voice_name = file[:-3]
+                    voice_path = os.path.join(TTSModel.VOICES_DIR, file)
+                    if not os.path.exists(voice_path):
+                        try:
+                            base_path = os.path.join(base_voices_dir, file)
+                            logger.info(f"Copying base voice {voice_name} to voices directory")
+                            voicepack = torch.load(base_path, map_location=TTSModel.get_instance()[1], weights_only=True)
+                            torch.save(voicepack, voice_path)
+                        except Exception as e:
+                            logger.error(f"Error copying voice {voice_name}: {str(e)}")
+
     def _split_text(self, text: str) -> List[str]:
         """Split text into sentences"""
         return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]

+    def _get_voice_path(self, voice_name: str) -> Optional[str]:
+        """Get the path to a voice file.
+
+        Args:
+            voice_name: Name of the voice to find
+
+        Returns:
+            Path to the voice file if found, None otherwise
+        """
+        voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice_name}.pt")
+        return voice_path if os.path.exists(voice_path) else None
+
     def _generate_audio(
         self, text: str, voice: str, speed: float, stitch_long_output: bool = True
     ) -> Tuple[torch.Tensor, float]:
@@ -79,9 +114,15 @@ class TTSService:
             if not text:
                 raise ValueError("Text is empty after preprocessing")

-            # Get model instance and voicepack
+            # Get model instance
             model, device = TTSModel.get_instance()
-            voicepack = TTSModel.get_voicepack(voice)
+
+            # Load voice
+            voice_path = self._get_voice_path(voice)
+            if not voice_path:
+                raise ValueError(f"Voice not found: {voice}")
+
+            voicepack = torch.load(voice_path, map_location=device, weights_only=True)

             # Generate audio with or without stitching
             if stitch_long_output:
@@ -143,34 +184,63 @@ class TTSService:
         return buffer.getvalue()

     def combine_voices(self, voices: List[str]) -> str:
+        """Combine multiple voices into a new voice.
+
+        Args:
+            voices: List of voice names to combine
+
+        Returns:
+            Name of the combined voice
+
+        Raises:
+            ValueError: If less than 2 voices provided or voice loading fails
+            RuntimeError: If voice combination or saving fails
+        """
         if len(voices) < 2:
-            return "af"
+            raise ValueError("At least 2 voices are required for combination")

+        # Load voices
+        t_voices: List[torch.Tensor] = []
+        v_name: List[str] = []
+
+        for voice in voices:
+            voice_path = self._get_voice_path(voice)
+            if not voice_path:
+                raise ValueError(f"Voice not found: {voice}")
+
+            try:
+                voicepack = torch.load(voice_path, map_location=TTSModel.get_instance()[1], weights_only=True)
+                t_voices.append(voicepack)
+                v_name.append(voice)
+            except Exception as e:
+                raise ValueError(f"Failed to load voice {voice}: {str(e)}")
+
+        # Combine voices
         try:
-            t_voices: List[torch.Tensor] = []
-            v_name: List[str] = []
-
-            for file in os.listdir("voices"):
-                voice_name = file[:-3]  # Remove .pt extension
-                for n in voices:
-                    if n == voice_name:
-                        v_name.append(voice_name)
-                        t_voices.append(torch.load(f"voices/{file}", weights_only=True))
-
-            f: str = "_".join(v_name)
-            v = torch.mean(torch.stack(t_voices), dim=0)
-            torch.save(v, f"voices/{f}.pt")
-            return f
+            f: str = "_".join(v_name)
+            v = torch.mean(torch.stack(t_voices), dim=0)
+            combined_path = os.path.join(TTSModel.VOICES_DIR, f"{f}.pt")
+
+            # Save combined voice
+            try:
+                torch.save(v, combined_path)
+            except Exception as e:
+                raise RuntimeError(f"Failed to save combined voice to {combined_path}: {str(e)}")
+
+            return f
+
         except Exception as e:
-            print(f"Error combining voices: {str(e)}")
-            return "af"
+            if not isinstance(e, (ValueError, RuntimeError)):
+                raise RuntimeError(f"Error combining voices: {str(e)}")
+            raise

     def list_voices(self) -> List[str]:
         """List all available voices"""
         voices = []
         try:
-            voices_path = os.path.join(settings.model_dir, settings.voices_dir)
-            for file in os.listdir(voices_path):
+            for file in os.listdir(TTSModel.VOICES_DIR):
                 if file.endswith(".pt"):
-                    voice_name = file[:-3]  # Remove .pt extension
-                    voices.append(voice_name)
+                    voices.append(file[:-3])  # Remove .pt extension
         except Exception as e:
-            print(f"Error listing voices: {str(e)}")
-        return voices
+            logger.error(f"Error listing voices: {str(e)}")
+        return sorted(voices)
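Taken together, a minimal usage sketch of the reworked service (import path taken from the test patch targets below; illustrative, not part of the commit):

```python
from api.src.services.tts import TTSService

# __init__ now calls _ensure_voices(), copying base voicepacks into
# TTSModel.VOICES_DIR before any are loaded.
service = TTSService(output_dir="output")

print(service.list_voices())  # sorted names, e.g. ['af', 'af_bella', ...]

# Raises ValueError for fewer than 2 voices instead of silently returning "af".
combined = service.combine_voices(voices=["af_bella", "af_sarah"])
print(combined)  # 'af_bella_af_sarah'
```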
@@ -16,18 +16,10 @@ class TTSStatus(str, Enum):
 class OpenAISpeechRequest(BaseModel):
     model: Literal["tts-1", "tts-1-hd", "kokoro"] = "kokoro"
     input: str = Field(..., description="The text to generate audio for")
-    voice: Literal[
-        "am_adam",
-        "am_michael",
-        "bm_lewis",
-        "af",
-        "bm_george",
-        "bf_isabella",
-        "bf_emma",
-        "af_sarah",
-        "af_bella",
-        "af_nicole",
-    ] = Field(default="af", description="The voice to use for generation")
+    voice: str = Field(
+        default="af",
+        description="The voice to use for generation. Can be a base voice or a combined voice name."
+    )
     response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field(
         default="mp3",
         description="The format to return audio in. Supported formats: mp3, opus, flac, wav. AAC and PCM are not currently supported.",
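Because voice is now a plain str, an unknown voice passes schema validation and is instead rejected by the endpoint's runtime check, which is why the endpoint test below expects 400 rather than 422. A quick sketch (assumes OpenAISpeechRequest is importable from the schemas module above, whose path is not shown in this diff):

```python
# from <schemas module> import OpenAISpeechRequest  (module path not shown in this diff)
from pydantic import ValidationError

# No longer rejected at the schema layer: any string is a syntactically valid voice.
req = OpenAISpeechRequest(input="Hello world!", voice="af_bella_af_sarah")

# Unsupported formats are still caught by the response_format Literal.
try:
    OpenAISpeechRequest(input="Hello world!", voice="af", response_format="ogg")
except ValidationError as e:
    print("schema rejected:", e.errors()[0]["loc"])
```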
@@ -78,7 +78,8 @@ def test_openai_speech_invalid_voice(mock_tts_service):
         "speed": 1.0,
     }
     response = client.post("/v1/audio/speech", json=test_request)
-    assert response.status_code == 422  # Validation error
+    assert response.status_code == 400  # Bad request
+    assert "not found" in response.json()["detail"]["message"]


 def test_openai_speech_invalid_speed(mock_tts_service):
@@ -106,4 +107,49 @@ def test_openai_speech_generation_error(mock_tts_service):
     }
     response = client.post("/v1/audio/speech", json=test_request)
     assert response.status_code == 500
-    assert "Generation failed" in response.json()["detail"]
+    assert "Generation failed" in response.json()["detail"]["message"]
+
+
+def test_combine_voices_success(mock_tts_service):
+    """Test successful voice combination"""
+    test_voices = ["af_bella", "af_sarah"]
+    mock_tts_service.combine_voices.return_value = "af_bella_af_sarah"
+
+    response = client.post("/v1/audio/voices/combine", json=test_voices)
+
+    assert response.status_code == 200
+    assert response.json()["voice"] == "af_bella_af_sarah"
+    mock_tts_service.combine_voices.assert_called_once_with(voices=test_voices)
+
+
+def test_combine_voices_single_voice(mock_tts_service):
+    """Test combining single voice returns default voice"""
+    test_voices = ["af_bella"]
+    mock_tts_service.combine_voices.return_value = "af"
+
+    response = client.post("/v1/audio/voices/combine", json=test_voices)
+
+    assert response.status_code == 200
+    assert response.json()["voice"] == "af"
+
+
+def test_combine_voices_empty_list(mock_tts_service):
+    """Test combining empty voice list returns default voice"""
+    test_voices = []
+    mock_tts_service.combine_voices.return_value = "af"
+
+    response = client.post("/v1/audio/voices/combine", json=test_voices)
+
+    assert response.status_code == 200
+    assert response.json()["voice"] == "af"
+
+
+def test_combine_voices_error(mock_tts_service):
+    """Test error handling in voice combination"""
+    test_voices = ["af_bella", "af_sarah"]
+    mock_tts_service.combine_voices.side_effect = Exception("Combination failed")
+
+    response = client.post("/v1/audio/voices/combine", json=test_voices)
+
+    assert response.status_code == 500
+    assert "Combination failed" in response.json()["detail"]["message"]
@@ -79,36 +79,42 @@ def test_generate_audio_empty_text(mock_generate, mock_tokenize, mock_phonemize,


 @patch('api.src.services.tts.TTSModel.get_instance')
-@patch('api.src.services.tts.TTSModel.get_voicepack')
+@patch('os.path.exists')
 @patch('api.src.services.tts.normalize_text')
 @patch('api.src.services.tts.phonemize')
 @patch('api.src.services.tts.tokenize')
 @patch('api.src.services.tts.generate')
-def test_generate_audio_no_chunks(mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_voicepack, mock_instance, tts_service):
+@patch('torch.load')
+def test_generate_audio_no_chunks(mock_torch_load, mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_exists, mock_instance, tts_service):
     """Test generating audio with no successful chunks"""
     mock_normalize.return_value = "Test text"
     mock_phonemize.return_value = "Test text"
     mock_tokenize.return_value = ["test", "text"]
     mock_generate.return_value = (None, None)
     mock_instance.return_value = (MagicMock(), "cpu")
+    mock_exists.return_value = True
+    mock_torch_load.return_value = MagicMock()

     with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
         tts_service._generate_audio("Test text", "af", 1.0)


 @patch('api.src.services.tts.TTSModel.get_instance')
-@patch('api.src.services.tts.TTSModel.get_voicepack')
+@patch('os.path.exists')
 @patch('api.src.services.tts.normalize_text')
 @patch('api.src.services.tts.phonemize')
 @patch('api.src.services.tts.tokenize')
 @patch('api.src.services.tts.generate')
-def test_generate_audio_success(mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_voicepack, mock_instance, tts_service, sample_audio):
+@patch('torch.load')
+def test_generate_audio_success(mock_torch_load, mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_exists, mock_instance, tts_service, sample_audio):
     """Test successful audio generation"""
     mock_normalize.return_value = "Test text"
     mock_phonemize.return_value = "Test text"
     mock_tokenize.return_value = ["test", "text"]
     mock_generate.return_value = (sample_audio, None)
     mock_instance.return_value = (MagicMock(), "cpu")
+    mock_exists.return_value = True
+    mock_torch_load.return_value = MagicMock()

     audio, processing_time = tts_service._generate_audio("Test text", "af", 1.0)
     assert isinstance(audio, np.ndarray)
@@ -148,34 +154,19 @@ def test_model_initialization_cpu(mock_build_model, mock_cuda_available):
     mock_build_model.assert_called_once()


+@patch('os.path.exists')
 @patch('api.src.services.tts.torch.load')
 @patch('os.path.join')
-def test_voicepack_loading_error(mock_join, mock_torch_load):
+def test_voicepack_loading_error(mock_join, mock_torch_load, mock_exists):
     """Test voicepack loading error handling"""
     mock_join.side_effect = lambda *args: '/'.join(args)
-    mock_torch_load.side_effect = [Exception("Failed to load voice"), MagicMock()]
+    mock_exists.side_effect = lambda x: False  # All voice files don't exist

     TTSModel._instance = (MagicMock(), "cpu")  # Mock instance
     TTSModel._voicepacks = {}  # Reset voicepacks

-    # Should fall back to 'af' voice
-    voicepack = TTSModel.get_voicepack("nonexistent_voice")
-    assert mock_torch_load.call_count == 2  # Tried original voice then fallback
-    assert isinstance(voicepack, MagicMock)  # Successfully got fallback voice
-
-
-@patch('api.src.services.tts.torch.load')
-@patch('os.path.join')
-def test_voicepack_loading_error_af(mock_join, mock_torch_load):
-    """Test voicepack loading error for 'af' voice"""
-    mock_join.side_effect = lambda *args: '/'.join(args)
-    mock_torch_load.side_effect = Exception("Failed to load voice")
-
-    TTSModel._instance = (MagicMock(), "cpu")  # Mock instance
-    TTSModel._voicepacks = {}  # Reset voicepacks
-
-    with pytest.raises(Exception):
-        TTSModel.get_voicepack("af")
+    with pytest.raises(FileNotFoundError, match="Voice file not found: af"):
+        TTSModel.get_voicepack("nonexistent_voice")


 def test_save_audio(tts_service, sample_audio, tmp_path):
@@ -188,14 +179,17 @@ def test_save_audio(tts_service, sample_audio, tmp_path):


 @patch('api.src.services.tts.TTSModel.get_instance')
-@patch('api.src.services.tts.TTSModel.get_voicepack')
+@patch('os.path.exists')
 @patch('api.src.services.tts.normalize_text')
 @patch('api.src.services.tts.generate')
-def test_generate_audio_without_stitching(mock_generate, mock_normalize, mock_voicepack, mock_instance, tts_service, sample_audio):
+@patch('torch.load')
+def test_generate_audio_without_stitching(mock_torch_load, mock_generate, mock_normalize, mock_exists, mock_instance, tts_service, sample_audio):
     """Test generating audio without text stitching"""
     mock_normalize.return_value = "Test text"
     mock_generate.return_value = (sample_audio, None)
     mock_instance.return_value = (MagicMock(), "cpu")
+    mock_exists.return_value = True
+    mock_torch_load.return_value = MagicMock()

     audio, processing_time = tts_service._generate_audio("Test text", "af", 1.0, stitch_long_output=False)
     assert isinstance(audio, np.ndarray)
@@ -214,16 +208,19 @@ def test_list_voices_error(mock_listdir, tts_service):


 @patch('api.src.services.tts.TTSModel.get_instance')
-@patch('api.src.services.tts.TTSModel.get_voicepack')
+@patch('os.path.exists')
 @patch('api.src.services.tts.normalize_text')
 @patch('api.src.services.tts.phonemize')
 @patch('api.src.services.tts.tokenize')
 @patch('api.src.services.tts.generate')
-def test_generate_audio_phonemize_error(mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_voicepack, mock_instance, tts_service):
+@patch('torch.load')
+def test_generate_audio_phonemize_error(mock_torch_load, mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_exists, mock_instance, tts_service):
     """Test handling phonemization error"""
     mock_normalize.return_value = "Test text"
     mock_phonemize.side_effect = Exception("Phonemization failed")
     mock_instance.return_value = (MagicMock(), "cpu")
+    mock_exists.return_value = True
+    mock_torch_load.return_value = MagicMock()
     mock_generate.return_value = (None, None)

     with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
@@ -231,14 +228,17 @@ def test_generate_audio_phonemize_error(mock_generate, mock_tokenize, mock_phone


 @patch('api.src.services.tts.TTSModel.get_instance')
-@patch('api.src.services.tts.TTSModel.get_voicepack')
+@patch('os.path.exists')
 @patch('api.src.services.tts.normalize_text')
 @patch('api.src.services.tts.generate')
-def test_generate_audio_error(mock_generate, mock_normalize, mock_voicepack, mock_instance, tts_service):
+@patch('torch.load')
+def test_generate_audio_error(mock_torch_load, mock_generate, mock_normalize, mock_exists, mock_instance, tts_service):
     """Test handling generation error"""
     mock_normalize.return_value = "Test text"
     mock_generate.side_effect = Exception("Generation failed")
     mock_instance.return_value = (MagicMock(), "cpu")
+    mock_exists.return_value = True
+    mock_torch_load.return_value = MagicMock()

     with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
         tts_service._generate_audio("Test text", "af", 1.0)
BIN  examples/benchmarks/analysis_comparison.png (new file)
Binary file not shown. (Size: 754 KiB)
330  examples/test_analyze_combined_voices.py (new file)
@@ -0,0 +1,330 @@
#!/usr/bin/env python3
import argparse
import os
from typing import List, Optional, Dict, Tuple

import requests
import numpy as np
from scipy.io import wavfile
import matplotlib.pyplot as plt


def submit_combine_voices(voices: List[str], base_url: str = "http://localhost:8880") -> Optional[str]:
    """Combine multiple voices into a new voice.

    Args:
        voices: List of voice names to combine (e.g. ["af_bella", "af_sarah"])
        base_url: API base URL

    Returns:
        Name of the combined voice (e.g. "af_bella_af_sarah") or None if error
    """
    try:
        response = requests.post(f"{base_url}/v1/audio/voices/combine", json=voices)
        print(f"Response status: {response.status_code}")
        print(f"Raw response: {response.text}")

        # Accept both 200 and 201 as success
        if response.status_code not in [200, 201]:
            try:
                error = response.json()["detail"]["message"]
                print(f"Error combining voices: {error}")
            except:
                print(f"Error combining voices: {response.text}")
            return None

        try:
            data = response.json()
            if "voices" in data:
                print(f"Available voices: {', '.join(sorted(data['voices']))}")
            return data["voice"]
        except Exception as e:
            print(f"Error parsing response: {e}")
            return None
    except Exception as e:
        print(f"Error: {e}")
        return None


def generate_speech(text: str, voice: str, base_url: str = "http://localhost:8880", output_file: str = "output.mp3") -> bool:
    """Generate speech using specified voice.

    Args:
        text: Text to convert to speech
        voice: Voice name to use
        base_url: API base URL
        output_file: Path to save audio file

    Returns:
        True if successful, False otherwise
    """
    try:
        response = requests.post(
            f"{base_url}/v1/audio/speech",
            json={
                "input": text,
                "voice": voice,
                "speed": 1.0,
                "response_format": "wav"  # Use WAV for analysis
            }
        )

        if response.status_code != 200:
            error = response.json().get("detail", {}).get("message", response.text)
            print(f"Error generating speech: {error}")
            return False

        # Save the audio
        os.makedirs(os.path.dirname(output_file) if os.path.dirname(output_file) else ".", exist_ok=True)
        with open(output_file, "wb") as f:
            f.write(response.content)
        print(f"Saved audio to {output_file}")
        return True

    except Exception as e:
        print(f"Error: {e}")
        return False


def analyze_audio(filepath: str) -> Tuple[np.ndarray, int, dict]:
    """Analyze audio file and return samples, sample rate, and audio characteristics.

    Args:
        filepath: Path to audio file

    Returns:
        Tuple of (samples, sample_rate, characteristics)
    """
    sample_rate, samples = wavfile.read(filepath)

    # Convert to mono if stereo
    if len(samples.shape) > 1:
        samples = np.mean(samples, axis=1)

    # Calculate basic stats
    max_amp = np.max(np.abs(samples))
    rms = np.sqrt(np.mean(samples**2))
    duration = len(samples) / sample_rate

    # Zero crossing rate (helps identify voice characteristics)
    zero_crossings = np.sum(np.abs(np.diff(np.signbit(samples)))) / len(samples)

    # Simple frequency analysis
    if len(samples) > 0:
        # Use FFT to get frequency components
        fft_result = np.fft.fft(samples)
        freqs = np.fft.fftfreq(len(samples), 1/sample_rate)

        # Get positive frequencies only
        pos_mask = freqs > 0
        freqs = freqs[pos_mask]
        magnitudes = np.abs(fft_result)[pos_mask]

        # Find dominant frequencies (top 3)
        top_indices = np.argsort(magnitudes)[-3:]
        dominant_freqs = freqs[top_indices]

        # Calculate spectral centroid (brightness of sound)
        spectral_centroid = np.sum(freqs * magnitudes) / np.sum(magnitudes)
    else:
        dominant_freqs = []
        spectral_centroid = 0

    characteristics = {
        "max_amplitude": max_amp,
        "rms": rms,
        "duration": duration,
        "zero_crossing_rate": zero_crossings,
        "dominant_frequencies": dominant_freqs,
        "spectral_centroid": spectral_centroid
    }

    return samples, sample_rate, characteristics


def setup_plot(fig, ax, title):
    """Configure plot styling"""
    # Improve grid
    ax.grid(True, linestyle="--", alpha=0.3, color="#ffffff")

    # Set title and labels with better fonts
    ax.set_title(title, pad=20, fontsize=16, fontweight="bold", color="#ffffff")
    ax.set_xlabel(ax.get_xlabel(), fontsize=14, fontweight="medium", color="#ffffff")
    ax.set_ylabel(ax.get_ylabel(), fontsize=14, fontweight="medium", color="#ffffff")

    # Improve tick labels
    ax.tick_params(labelsize=12, colors="#ffffff")

    # Style spines
    for spine in ax.spines.values():
        spine.set_color("#ffffff")
        spine.set_alpha(0.3)
        spine.set_linewidth(0.5)

    # Set background colors
    ax.set_facecolor("#1a1a2e")
    fig.patch.set_facecolor("#1a1a2e")

    return fig, ax


def plot_analysis(audio_files: Dict[str, str], output_dir: str):
    """Plot comprehensive voice analysis including waveforms and metrics comparison.

    Args:
        audio_files: Dictionary of label -> filepath
        output_dir: Directory to save plot files
    """
    # Set dark style
    plt.style.use('dark_background')

    # Create figure with subplots
    fig = plt.figure(figsize=(15, 15))
    fig.patch.set_facecolor("#1a1a2e")
    num_files = len(audio_files)

    # Create subplot grid with proper spacing
    gs = plt.GridSpec(num_files + 1, 2, height_ratios=[1.5]*num_files + [1],
                      hspace=0.4, wspace=0.3)

    # Analyze all files first
    all_chars = {}
    for i, (label, filepath) in enumerate(audio_files.items()):
        samples, sample_rate, chars = analyze_audio(filepath)
        all_chars[label] = chars

        # Plot waveform spanning both columns
        ax = plt.subplot(gs[i, :])
        time = np.arange(len(samples)) / sample_rate
        plt.plot(time, samples / chars['max_amplitude'], linewidth=0.5, color="#ff2a6d")
        ax.set_xlabel("Time (seconds)")
        ax.set_ylabel("Normalized Amplitude")
        ax.set_ylim(-1.1, 1.1)
        setup_plot(fig, ax, f"Waveform: {label}")

    # Colors for voices
    colors = ["#ff2a6d", "#05d9e8", "#d1f7ff"]

    # Create two subplots for metrics with similar scales
    # Left subplot: Brightness and Volume
    ax1 = plt.subplot(gs[num_files, 0])
    metrics1 = [
        ('Brightness', [chars['spectral_centroid']/1000 for chars in all_chars.values()], 'kHz'),
        ('Volume', [chars['rms']*100 for chars in all_chars.values()], 'RMS×100')
    ]

    # Right subplot: Voice Pitch and Texture
    ax2 = plt.subplot(gs[num_files, 1])
    metrics2 = [
        ('Voice Pitch', [min(chars['dominant_frequencies']) for chars in all_chars.values()], 'Hz'),
        ('Texture', [chars['zero_crossing_rate']*1000 for chars in all_chars.values()], 'ZCR×1000')
    ]

    def plot_grouped_bars(ax, metrics, show_legend=True):
        n_groups = len(metrics)
        n_voices = len(audio_files)
        bar_width = 0.25

        indices = np.arange(n_groups)

        # Get max value for y-axis scaling
        max_val = max(max(m[1]) for m in metrics)

        for i, (voice, color) in enumerate(zip(audio_files.keys(), colors)):
            values = [m[1][i] for m in metrics]
            offset = (i - n_voices/2 + 0.5) * bar_width
            bars = ax.bar(indices + offset, values, bar_width,
                          label=voice, color=color, alpha=0.8)

            # Add value labels on top of bars
            for bar in bars:
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width()/2., height,
                        f'{height:.1f}',
                        ha='center', va='bottom', color='white',
                        fontsize=10)

        ax.set_xticks(indices)
        ax.set_xticklabels([f"{m[0]}\n({m[2]})" for m in metrics])

        # Set y-axis limits with some padding
        ax.set_ylim(0, max_val * 1.2)

        if show_legend:
            ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left',
                      facecolor="#1a1a2e", edgecolor="#ffffff")

    # Plot both subplots
    plot_grouped_bars(ax1, metrics1, show_legend=True)
    plot_grouped_bars(ax2, metrics2, show_legend=False)

    # Style both subplots
    setup_plot(fig, ax1, 'Brightness and Volume')
    setup_plot(fig, ax2, 'Voice Pitch and Texture')

    # Add y-axis labels
    ax1.set_ylabel('Value')
    ax2.set_ylabel('Value')

    # Adjust the figure size to accommodate the legend
    fig.set_size_inches(15, 15)

    # Add padding around the entire figure
    plt.subplots_adjust(right=0.85, top=0.95, bottom=0.05, left=0.1)
    plt.savefig(os.path.join(output_dir, "analysis_comparison.png"), dpi=300)
    print(f"Saved analysis comparison to {output_dir}/analysis_comparison.png")

    # Print detailed comparative analysis
    print("\nDetailed Voice Analysis:")
    for label, chars in all_chars.items():
        print(f"\n{label}:")
        print(f"  Max Amplitude: {chars['max_amplitude']:.2f}")
        print(f"  RMS (loudness): {chars['rms']:.2f}")
        print(f"  Duration: {chars['duration']:.2f}s")
        print(f"  Zero Crossing Rate: {chars['zero_crossing_rate']:.3f}")
        print(f"  Spectral Centroid: {chars['spectral_centroid']:.0f}Hz")
        print(f"  Dominant Frequencies: {', '.join(f'{f:.0f}Hz' for f in chars['dominant_frequencies'])}")


def main():
    parser = argparse.ArgumentParser(description="Kokoro Voice Analysis Demo")
    parser.add_argument("--voices", nargs="+", type=str, help="Voices to combine")
    parser.add_argument("--text", type=str, default="Hello! This is a test of combined voices.", help="Text to speak")
    parser.add_argument("--url", default="http://localhost:8880", help="API base URL")
    parser.add_argument("--output-dir", default="examples/output", help="Output directory for audio files")
    args = parser.parse_args()

    if not args.voices:
        print("No voices provided, using default test voices")
        args.voices = ["af_bella", "af_nicole"]

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Dictionary to store audio files for analysis
    audio_files = {}

    # Generate speech with individual voices
    print("Generating speech with individual voices...")
    for voice in args.voices:
        output_file = os.path.join(args.output_dir, f"analysis_{voice}.wav")
        if generate_speech(args.text, voice, args.url, output_file):
            audio_files[voice] = output_file

    # Generate speech with combined voice
    print(f"\nCombining voices: {', '.join(args.voices)}")
    combined_voice = submit_combine_voices(args.voices, args.url)

    if combined_voice:
        print(f"Successfully created combined voice: {combined_voice}")
        output_file = os.path.join(args.output_dir, f"analysis_combined_{combined_voice}.wav")
        if generate_speech(args.text, combined_voice, args.url, output_file):
            audio_files["combined"] = output_file

        # Generate comparison plots
        plot_analysis(audio_files, args.output_dir)
    else:
        print("Failed to combine voices")


if __name__ == "__main__":
    main()
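Based on the argparse flags above, the script can be run against a live server with, e.g., `python examples/test_analyze_combined_voices.py --voices af_bella af_sarah --output-dir examples/output`; it saves per-voice and combined WAV files plus the analysis_comparison.png shown in the README.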
@@ -1,33 +0,0 @@
-#!/usr/bin/env python3
-import argparse
-from typing import List, Optional
-
-import requests
-
-
-def submit_combine_voices(voices: List[str], base_url: str = "http://localhost:8880") -> Optional[List[str]]:
-    try:
-        response = requests.post(f"{base_url}/v1/audio/voices/combine", json=voices)
-        if response.status_code != 200:
-            print(f"Error submitting request: {response.text}")
-            return None
-        return response.json()["voices"]
-    except requests.exceptions.RequestException as e:
-        print(f"Error: {e}")
-        return None
-
-
-def main():
-    parser = argparse.ArgumentParser(description="Kokoro TTS CLI")
-    parser.add_argument("--voices", nargs="+", type=str, help="Voices to combine")
-    parser.add_argument("--url", default="http://localhost:8880", help="API base URL")
-    args = parser.parse_args()
-
-    success = submit_combine_voices(args.voices, args.url)
-    if success:
-        for voice in success:
-            print(f"  {voice}")
-
-
-if __name__ == "__main__":
-    main()