- Added support for combining voices via any endpoint

- Updated the `process_voices` function to handle both string and list formats for voice input.
This commit is contained in:
remsky 2025-01-07 03:50:08 -07:00
parent bb1f9b54ba
commit 130b084cce
10 changed files with 259 additions and 104 deletions

@ -1 +1 @@
Subproject commit c97b7bbc3e60f447383c79b2f94fee861ff156ac
Subproject commit 3095858c40fc22e28c46429da9340dfda1f8cf28

View file

@ -3,17 +3,17 @@
</p>
# Kokoro TTS API
[![Tests](https://img.shields.io/badge/tests-98%20passed-darkgreen)]()
[![Coverage](https://img.shields.io/badge/coverage-73%25-darkgreen)]()
[![Tests](https://img.shields.io/badge/tests-105%20passed-darkgreen)]()
[![Coverage](https://img.shields.io/badge/coverage-74%25-darkgreen)]()
[![Tested at Model Commit](https://img.shields.io/badge/last--tested--model--commit-a67f113-blue)](https://huggingface.co/hexgrad/Kokoro-82M/tree/c3b0d86e2a980e027ef71c28819ea02e351c2667) [![Try on Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Try%20on-Spaces-blue)](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero)
Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model
- OpenAI-compatible Speech endpoint, with voice combination functionality
- OpenAI-compatible Speech endpoint, with inline voice combination functionality
- NVIDIA GPU accelerated inference (or CPU) option
- very fast generation time
- ~ 35x real time speed via 4060Ti, ~300ms latency
- ~ 6x real time spead via M3 Pro CPU, ~1000ms latency
- streaming support w/ variable chunking control latency & artifacts
- streaming support w/ variable chunking to control latency & artifacts
- simple audio generation web ui utility
@ -39,7 +39,7 @@ The service can be accessed through either the API endpoints or the Gradio web i
response = client.audio.speech.create(
model="kokoro",
voice="af_bella",
voice="af_sky+af_bella", #single or multiple voicepack combo
input="Hello world!",
response_format="mp3"
)
@ -61,7 +61,7 @@ from openai import OpenAI
client = OpenAI(base_url="http://localhost:8880", api_key="not-needed")
response = client.audio.speech.create(
model="kokoro", # Not used but required for compatibility, also accepts library defaults
voice="af_bella",
voice="af_bella+af_sky",
input="Hello world!",
response_format="mp3"
)
@ -105,6 +105,7 @@ python examples/test_all_voices.py # Test all available voices
- Averages model weights of any existing voicepacks
- Saves generated voicepacks for future use
- (new) Available through any endpoint, simply concatenate desired packs with "+"
Combine voices and generate audio:
```python
@ -119,12 +120,12 @@ response = requests.post(
)
combined_voice = response.json()["voice"]
# Generate audio with combined voice
# Generate audio with combined voice (or, simply pass multiple directly with `+` )
response = requests.post(
"http://localhost:8880/v1/audio/speech",
json={
"input": "Hello world!",
"voice": combined_voice,
"voice": combined_voice, # or skip the above step with f"{voices[0]}+{voices[1]}"
"response_format": "mp3"
}
)

View file

@ -1,4 +1,4 @@
from typing import List
from typing import List, Union
from loguru import logger
from fastapi import Depends, Response, APIRouter, HTTPException
@ -20,18 +20,43 @@ def get_tts_service() -> TTSService:
return TTSService() # Initialize TTSService with default settings
async def process_voices(voice_input: Union[str, List[str]], tts_service: TTSService) -> str:
"""Process voice input into a combined voice, handling both string and list formats"""
# Convert input to list of voices
if isinstance(voice_input, str):
voices = [v.strip() for v in voice_input.split("+") if v.strip()]
else:
voices = voice_input
if not voices:
raise ValueError("No voices provided")
# Check if all voices exist
available_voices = await tts_service.list_voices()
for voice in voices:
if voice not in available_voices:
raise ValueError(f"Voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}")
# If single voice, return it directly
if len(voices) == 1:
return voices[0]
# Otherwise combine voices
return await tts_service.combine_voices(voices=voices)
async def stream_audio_chunks(tts_service: TTSService, request: OpenAISpeechRequest) -> AsyncGenerator[bytes, None]:
"""Stream audio chunks as they're generated"""
voice_to_use = await process_voices(request.voice, tts_service)
async for chunk in tts_service.generate_audio_stream(
text=request.input,
voice=request.voice,
voice=voice_to_use,
speed=request.speed,
output_format=request.response_format
):
yield chunk
@router.post("/audio/speech")
async def create_speech(
request: OpenAISpeechRequest,
@ -40,12 +65,8 @@ async def create_speech(
):
"""OpenAI-compatible endpoint for text-to-speech"""
try:
# Validate voice exists
available_voices = tts_service.list_voices()
if request.voice not in available_voices:
raise ValueError(
f"Voice '{request.voice}' not found. Available voices: {', '.join(sorted(available_voices))}"
)
# Process voice combination and validate
voice_to_use = await process_voices(request.voice, tts_service)
# Set content type based on format
content_type = {
@ -73,7 +94,7 @@ async def create_speech(
# Generate complete audio
audio, _ = tts_service._generate_audio(
text=request.input,
voice=request.voice,
voice=voice_to_use,
speed=request.speed,
stitch_long_output=True,
)
@ -111,7 +132,7 @@ async def create_speech(
async def list_voices(tts_service: TTSService = Depends(get_tts_service)):
"""List all available voices for text-to-speech"""
try:
voices = tts_service.list_voices()
voices = await tts_service.list_voices()
return {"voices": voices}
except Exception as e:
logger.error(f"Error listing voices: {str(e)}")
@ -120,12 +141,13 @@ async def list_voices(tts_service: TTSService = Depends(get_tts_service)):
@router.post("/audio/voices/combine")
async def combine_voices(
request: List[str], tts_service: TTSService = Depends(get_tts_service)
request: Union[str, List[str]], tts_service: TTSService = Depends(get_tts_service)
):
"""Combine multiple voices into a new voice.
Args:
request: List of voice names to combine
request: Either a string with voices separated by + (e.g. "voice1+voice2")
or a list of voice names to combine
Returns:
Dict with combined voice name and list of all available voices
@ -136,8 +158,8 @@ async def combine_voices(
- 500: Server error (file system issues, combination failed)
"""
try:
combined_voice = tts_service.combine_voices(voices=request)
voices = tts_service.list_voices()
combined_voice = await process_voices(request, tts_service)
voices = await tts_service.list_voices()
return {"voices": voices, "voice": combined_voice}
except ValueError as e:
@ -146,14 +168,8 @@ async def combine_voices(
status_code=400, detail={"error": "Invalid request", "message": str(e)}
)
except RuntimeError as e:
except Exception as e:
logger.error(f"Server error during voice combination: {str(e)}")
raise HTTPException(
status_code=500, detail={"error": "Server error", "message": str(e)}
)
except Exception as e:
logger.error(f"Unexpected error during voice combination: {str(e)}")
raise HTTPException(
status_code=500, detail={"error": "Unexpected error", "message": str(e)}
status_code=500, detail={"error": "Server error", "message": "Server error"}
)

View file

@ -1,9 +1,11 @@
import aiofiles
import io
import os
import re
import time
from typing import List, Tuple, Optional
from functools import lru_cache
from aiofiles import threadpool
import numpy as np
import torch
@ -211,7 +213,7 @@ class TTSService:
wavfile.write(buffer, 24000, audio)
return buffer.getvalue()
def combine_voices(self, voices: List[str]) -> str:
async def combine_voices(self, voices: List[str]) -> str:
"""Combine multiple voices into a new voice"""
if len(voices) < 2:
raise ValueError("At least 2 voices are required for combination")
@ -252,11 +254,13 @@ class TTSService:
raise RuntimeError(f"Error combining voices: {str(e)}")
raise
def list_voices(self) -> List[str]:
async def list_voices(self) -> List[str]:
"""List all available voices"""
voices = []
try:
for file in os.listdir(TTSModel.VOICES_DIR):
# Use os.listdir in a thread pool
files = await threadpool.async_wrap(os.listdir)(TTSModel.VOICES_DIR)
for file in files:
if file.endswith(".pt"):
voices.append(file[:-3]) # Remove .pt extension
except Exception as e:

View file

@ -1,9 +1,17 @@
from enum import Enum
from typing import Literal
from typing import Literal, Union, List
from pydantic import Field, BaseModel
class VoiceCombineRequest(BaseModel):
"""Request schema for voice combination endpoint that accepts either a string with + or a list"""
voices: Union[str, List[str]] = Field(
...,
description="Either a string with voices separated by + (e.g. 'voice1+voice2') or a list of voice names to combine"
)
class TTSStatus(str, Enum):
PENDING = "pending"
PROCESSING = "processing"

View file

@ -29,7 +29,9 @@ def mock_tts_service(monkeypatch):
for chunk in [b"chunk1", b"chunk2"]:
yield chunk
mock_service.generate_audio_stream = mock_stream
mock_service.list_voices.return_value = [
# Create async mocks
mock_service.list_voices = AsyncMock(return_value=[
"af",
"bm_lewis",
"bf_isabella",
@ -39,7 +41,8 @@ def mock_tts_service(monkeypatch):
"am_adam",
"am_michael",
"bm_george",
]
])
mock_service.combine_voices = AsyncMock()
monkeypatch.setattr(
"api.src.routers.openai_compatible.TTSService",
lambda *args, **kwargs: mock_service,
@ -64,7 +67,8 @@ def test_health_check():
assert response.json() == {"status": "healthy"}
def test_openai_speech_endpoint(mock_tts_service, mock_audio_service):
@pytest.mark.asyncio
async def test_openai_speech_endpoint(mock_tts_service, mock_audio_service, async_client):
"""Test the OpenAI-compatible speech endpoint"""
test_request = {
"model": "kokoro",
@ -74,7 +78,7 @@ def test_openai_speech_endpoint(mock_tts_service, mock_audio_service):
"speed": 1.0,
"stream": False # Explicitly disable streaming
}
response = client.post("/v1/audio/speech", json=test_request)
response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 200
assert response.headers["content-type"] == "audio/wav"
assert response.headers["content-disposition"] == "attachment; filename=speech.wav"
@ -84,7 +88,8 @@ def test_openai_speech_endpoint(mock_tts_service, mock_audio_service):
assert response.content == b"converted mock audio data"
def test_openai_speech_invalid_voice(mock_tts_service):
@pytest.mark.asyncio
async def test_openai_speech_invalid_voice(mock_tts_service, async_client):
"""Test the OpenAI-compatible speech endpoint with invalid voice"""
test_request = {
"model": "kokoro",
@ -94,12 +99,13 @@ def test_openai_speech_invalid_voice(mock_tts_service):
"speed": 1.0,
"stream": False # Explicitly disable streaming
}
response = client.post("/v1/audio/speech", json=test_request)
response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 400 # Bad request
assert "not found" in response.json()["detail"]["message"]
def test_openai_speech_invalid_speed(mock_tts_service):
@pytest.mark.asyncio
async def test_openai_speech_invalid_speed(mock_tts_service, async_client):
"""Test the OpenAI-compatible speech endpoint with invalid speed"""
test_request = {
"model": "kokoro",
@ -109,11 +115,12 @@ def test_openai_speech_invalid_speed(mock_tts_service):
"speed": -1.0, # Invalid speed
"stream": False # Explicitly disable streaming
}
response = client.post("/v1/audio/speech", json=test_request)
response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 422 # Validation error
def test_openai_speech_generation_error(mock_tts_service):
@pytest.mark.asyncio
async def test_openai_speech_generation_error(mock_tts_service, async_client):
"""Test error handling in speech generation"""
mock_tts_service._generate_audio.side_effect = Exception("Generation failed")
test_request = {
@ -124,54 +131,173 @@ def test_openai_speech_generation_error(mock_tts_service):
"speed": 1.0,
"stream": False # Explicitly disable streaming
}
response = client.post("/v1/audio/speech", json=test_request)
response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 500
assert "Generation failed" in response.json()["detail"]["message"]
def test_combine_voices_success(mock_tts_service):
"""Test successful voice combination"""
@pytest.mark.asyncio
async def test_combine_voices_list_success(mock_tts_service, async_client):
"""Test successful voice combination using list format"""
test_voices = ["af_bella", "af_sarah"]
mock_tts_service.combine_voices.return_value = "af_bella_af_sarah"
mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")
response = client.post("/v1/audio/voices/combine", json=test_voices)
response = await async_client.post("/v1/audio/voices/combine", json=test_voices)
assert response.status_code == 200
assert response.json()["voice"] == "af_bella_af_sarah"
mock_tts_service.combine_voices.assert_called_once_with(voices=test_voices)
def test_combine_voices_single_voice(mock_tts_service):
"""Test combining single voice returns default voice"""
@pytest.mark.asyncio
async def test_combine_voices_string_success(mock_tts_service, async_client):
"""Test successful voice combination using string format with +"""
test_voices = "af_bella+af_sarah"
mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")
response = await async_client.post("/v1/audio/voices/combine", json=test_voices)
assert response.status_code == 200
assert response.json()["voice"] == "af_bella_af_sarah"
mock_tts_service.combine_voices.assert_called_once_with(voices=["af_bella", "af_sarah"])
@pytest.mark.asyncio
async def test_combine_voices_single_voice(mock_tts_service, async_client):
"""Test combining single voice returns same voice"""
test_voices = ["af_bella"]
mock_tts_service.combine_voices.return_value = "af"
response = client.post("/v1/audio/voices/combine", json=test_voices)
response = await async_client.post("/v1/audio/voices/combine", json=test_voices)
assert response.status_code == 200
assert response.json()["voice"] == "af"
assert response.json()["voice"] == "af_bella"
def test_combine_voices_empty_list(mock_tts_service):
"""Test combining empty voice list returns default voice"""
@pytest.mark.asyncio
async def test_combine_voices_empty_list(mock_tts_service, async_client):
"""Test combining empty voice list returns error"""
test_voices = []
mock_tts_service.combine_voices.return_value = "af"
response = client.post("/v1/audio/voices/combine", json=test_voices)
assert response.status_code == 200
assert response.json()["voice"] == "af"
response = await async_client.post("/v1/audio/voices/combine", json=test_voices)
assert response.status_code == 400
assert "No voices provided" in response.json()["detail"]["message"]
def test_combine_voices_error(mock_tts_service):
@pytest.mark.asyncio
async def test_combine_voices_error(mock_tts_service, async_client):
"""Test error handling in voice combination"""
test_voices = ["af_bella", "af_sarah"]
mock_tts_service.combine_voices.side_effect = Exception("Combination failed")
response = client.post("/v1/audio/voices/combine", json=test_voices)
mock_tts_service.combine_voices = AsyncMock(side_effect=Exception("Combination failed"))
response = await async_client.post("/v1/audio/voices/combine", json=test_voices)
assert response.status_code == 500
assert "Combination failed" in response.json()["detail"]["message"]
assert "Server error" in response.json()["detail"]["message"]
@pytest.mark.asyncio
async def test_speech_with_combined_voice(mock_tts_service, mock_audio_service, async_client):
"""Test speech generation with combined voice using + syntax"""
mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": "af_bella+af_sarah",
"response_format": "wav",
"speed": 1.0,
"stream": False
}
response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 200
assert response.headers["content-type"] == "audio/wav"
mock_tts_service._generate_audio.assert_called_once_with(
text="Hello world",
voice="af_bella_af_sarah",
speed=1.0,
stitch_long_output=True
)
@pytest.mark.asyncio
async def test_speech_with_whitespace_in_voice(mock_tts_service, mock_audio_service, async_client):
"""Test speech generation with whitespace in voice combination"""
mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": " af_bella + af_sarah ",
"response_format": "wav",
"speed": 1.0,
"stream": False
}
response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 200
assert response.headers["content-type"] == "audio/wav"
mock_tts_service.combine_voices.assert_called_once_with(voices=["af_bella", "af_sarah"])
@pytest.mark.asyncio
async def test_speech_with_empty_voice_combination(mock_tts_service, async_client):
"""Test speech generation with empty voice combination"""
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": "+",
"response_format": "wav",
"speed": 1.0,
"stream": False
}
response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 400
assert "No voices provided" in response.json()["detail"]["message"]
@pytest.mark.asyncio
async def test_speech_with_invalid_combined_voice(mock_tts_service, async_client):
"""Test speech generation with invalid voice combination"""
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": "invalid+combination",
"response_format": "wav",
"speed": 1.0,
"stream": False
}
response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 400
assert "not found" in response.json()["detail"]["message"]
@pytest.mark.asyncio
async def test_speech_streaming_with_combined_voice(mock_tts_service, async_client):
"""Test streaming speech with combined voice using + syntax"""
mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": "af_bella+af_sarah",
"response_format": "mp3",
"stream": True
}
# Create streaming mock
async def mock_stream(*args, **kwargs):
for chunk in [b"mp3header", b"mp3data"]:
yield chunk
mock_tts_service.generate_audio_stream = mock_stream
# Add streaming header
headers = {"x-raw-response": "stream"}
response = await async_client.post("/v1/audio/speech", json=test_request, headers=headers)
assert response.status_code == 200
assert response.headers["content-type"] == "audio/mpeg"
assert response.headers["content-disposition"] == "attachment; filename=speech.mp3"
@pytest.mark.asyncio
@ -197,9 +323,6 @@ async def test_openai_speech_pcm_streaming(mock_tts_service, async_client):
assert response.status_code == 200
assert response.headers["content-type"] == "audio/pcm"
# Just verify status and content type
assert response.status_code == 200
assert response.headers["content-type"] == "audio/pcm"
@pytest.mark.asyncio
@ -226,10 +349,6 @@ async def test_openai_speech_streaming_mp3(mock_tts_service, async_client):
assert response.status_code == 200
assert response.headers["content-type"] == "audio/mpeg"
assert response.headers["content-disposition"] == "attachment; filename=speech.mp3"
# Just verify status and content type
assert response.status_code == 200
assert response.headers["content-type"] == "audio/mpeg"
assert response.headers["content-disposition"] == "attachment; filename=speech.mp3"
@pytest.mark.asyncio
@ -255,6 +374,3 @@ async def test_openai_speech_streaming_generator(mock_tts_service, async_client)
assert response.status_code == 200
assert response.headers["content-type"] == "audio/pcm"
# Just verify status and content type
assert response.status_code == 200
assert response.headers["content-type"] == "audio/pcm"

View file

@ -1,12 +1,13 @@
"""Tests for TTSService"""
import os
from unittest.mock import MagicMock, call, patch
from unittest.mock import MagicMock, call, patch, AsyncMock
import numpy as np
import torch
import pytest
from onnxruntime import InferenceSession
from aiofiles import threadpool
from api.src.core.config import settings
from api.src.services.tts_model import TTSModel
@ -38,27 +39,33 @@ def test_audio_to_bytes(tts_service, sample_audio):
assert len(audio_bytes) > 0
@patch("os.listdir")
@patch("os.path.join")
def test_list_voices(mock_join, mock_listdir, tts_service):
@pytest.mark.asyncio
async def test_list_voices(tts_service):
"""Test listing available voices"""
mock_listdir.return_value = ["voice1.pt", "voice2.pt", "not_a_voice.txt"]
mock_join.return_value = "/fake/path"
voices = tts_service.list_voices()
assert len(voices) == 2
assert "voice1" in voices
assert "voice2" in voices
assert "not_a_voice" not in voices
# Mock os.listdir to return test files
with patch('os.listdir', return_value=["voice1.pt", "voice2.pt", "not_a_voice.txt"]):
# Register mock with threadpool
async_listdir = AsyncMock(return_value=["voice1.pt", "voice2.pt", "not_a_voice.txt"])
threadpool.async_wrap = MagicMock(return_value=async_listdir)
voices = await tts_service.list_voices()
assert len(voices) == 2
assert "voice1" in voices
assert "voice2" in voices
assert "not_a_voice" not in voices
@patch("os.listdir")
def test_list_voices_error(mock_listdir, tts_service):
@pytest.mark.asyncio
async def test_list_voices_error(tts_service):
"""Test error handling in list_voices"""
mock_listdir.side_effect = Exception("Failed to list directory")
voices = tts_service.list_voices()
assert voices == []
# Mock os.listdir to raise an exception
with patch('os.listdir', side_effect=Exception("Failed to list directory")):
# Register mock with threadpool
async_listdir = AsyncMock(side_effect=Exception("Failed to list directory"))
threadpool.async_wrap = MagicMock(return_value=async_listdir)
voices = await tts_service.list_voices()
assert voices == []
def mock_model_setup(cuda_available=False):
@ -176,7 +183,8 @@ def test_save_audio(tts_service, sample_audio, tmp_path):
assert os.path.getsize(output_path) > 0
def test_combine_voices(tts_service):
@pytest.mark.asyncio
async def test_combine_voices(tts_service):
"""Test combining multiple voices"""
# Setup mocks for torch operations
with patch('torch.load', return_value=torch.tensor([1.0, 2.0])), \
@ -186,20 +194,21 @@ def test_combine_voices(tts_service):
patch('os.path.exists', return_value=True):
# Test combining two voices
result = tts_service.combine_voices(["voice1", "voice2"])
result = await tts_service.combine_voices(["voice1", "voice2"])
assert result == "voice1_voice2"
def test_combine_voices_invalid_input(tts_service):
@pytest.mark.asyncio
async def test_combine_voices_invalid_input(tts_service):
"""Test combining voices with invalid input"""
# Test with empty list
with pytest.raises(ValueError, match="At least 2 voices are required"):
tts_service.combine_voices([])
await tts_service.combine_voices([])
# Test with single voice
with pytest.raises(ValueError, match="At least 2 voices are required"):
tts_service.combine_voices(["voice1"])
await tts_service.combine_voices(["voice1"])
@patch("api.src.services.tts_service.TTSService._get_voice_path")

View file

@ -34,7 +34,7 @@ def stream_to_speakers() -> None:
with openai.audio.speech.with_streaming_response.create(
model="kokoro",
voice="af",
voice="af_sky+af_bella+bm_george",
response_format="pcm", # similar to WAV, but without a header chunk at the start.
input="""My dear sir, that is just where you are wrong. That is just where the whole world has gone wrong. We are always getting away from the present moment. Our mental existences, which are immaterial and have no dimensions, are passing along the Time-Dimension with a uniform velocity from the cradle to the grave. Just as we should travel down if we began our existence fifty miles above the earths surface""",
) as response:

BIN
examples/speech.mp3 Normal file

Binary file not shown.

View file

@ -20,6 +20,7 @@ phonemizer==3.3.0
regex==2024.11.6
# Utilities
aiofiles==24.1.0
tqdm==4.67.1
requests==2.32.3
munch==4.0.0