mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
- Added support for combining voices via any endpoint
- Updated the `process_voices` function to handle both string and list formats for voice input.
This commit is contained in:
parent
bb1f9b54ba
commit
130b084cce
10 changed files with 259 additions and 104 deletions
|
@ -1 +1 @@
|
||||||
Subproject commit c97b7bbc3e60f447383c79b2f94fee861ff156ac
|
Subproject commit 3095858c40fc22e28c46429da9340dfda1f8cf28
|
17
README.md
17
README.md
|
@ -3,17 +3,17 @@
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
# Kokoro TTS API
|
# Kokoro TTS API
|
||||||
[]()
|
[]()
|
||||||
[]()
|
[]()
|
||||||
[](https://huggingface.co/hexgrad/Kokoro-82M/tree/c3b0d86e2a980e027ef71c28819ea02e351c2667) [](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero)
|
[](https://huggingface.co/hexgrad/Kokoro-82M/tree/c3b0d86e2a980e027ef71c28819ea02e351c2667) [](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero)
|
||||||
|
|
||||||
Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model
|
Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model
|
||||||
- OpenAI-compatible Speech endpoint, with voice combination functionality
|
- OpenAI-compatible Speech endpoint, with inline voice combination functionality
|
||||||
- NVIDIA GPU accelerated inference (or CPU) option
|
- NVIDIA GPU accelerated inference (or CPU) option
|
||||||
- very fast generation time
|
- very fast generation time
|
||||||
- ~ 35x real time speed via 4060Ti, ~300ms latency
|
- ~ 35x real time speed via 4060Ti, ~300ms latency
|
||||||
- ~ 6x real time spead via M3 Pro CPU, ~1000ms latency
|
- ~ 6x real time spead via M3 Pro CPU, ~1000ms latency
|
||||||
- streaming support w/ variable chunking control latency & artifacts
|
- streaming support w/ variable chunking to control latency & artifacts
|
||||||
- simple audio generation web ui utility
|
- simple audio generation web ui utility
|
||||||
|
|
||||||
|
|
||||||
|
@ -39,7 +39,7 @@ The service can be accessed through either the API endpoints or the Gradio web i
|
||||||
|
|
||||||
response = client.audio.speech.create(
|
response = client.audio.speech.create(
|
||||||
model="kokoro",
|
model="kokoro",
|
||||||
voice="af_bella",
|
voice="af_sky+af_bella", #single or multiple voicepack combo
|
||||||
input="Hello world!",
|
input="Hello world!",
|
||||||
response_format="mp3"
|
response_format="mp3"
|
||||||
)
|
)
|
||||||
|
@ -61,7 +61,7 @@ from openai import OpenAI
|
||||||
client = OpenAI(base_url="http://localhost:8880", api_key="not-needed")
|
client = OpenAI(base_url="http://localhost:8880", api_key="not-needed")
|
||||||
response = client.audio.speech.create(
|
response = client.audio.speech.create(
|
||||||
model="kokoro", # Not used but required for compatibility, also accepts library defaults
|
model="kokoro", # Not used but required for compatibility, also accepts library defaults
|
||||||
voice="af_bella",
|
voice="af_bella+af_sky",
|
||||||
input="Hello world!",
|
input="Hello world!",
|
||||||
response_format="mp3"
|
response_format="mp3"
|
||||||
)
|
)
|
||||||
|
@ -105,6 +105,7 @@ python examples/test_all_voices.py # Test all available voices
|
||||||
|
|
||||||
- Averages model weights of any existing voicepacks
|
- Averages model weights of any existing voicepacks
|
||||||
- Saves generated voicepacks for future use
|
- Saves generated voicepacks for future use
|
||||||
|
- (new) Available through any endpoint, simply concatenate desired packs with "+"
|
||||||
|
|
||||||
Combine voices and generate audio:
|
Combine voices and generate audio:
|
||||||
```python
|
```python
|
||||||
|
@ -119,12 +120,12 @@ response = requests.post(
|
||||||
)
|
)
|
||||||
combined_voice = response.json()["voice"]
|
combined_voice = response.json()["voice"]
|
||||||
|
|
||||||
# Generate audio with combined voice
|
# Generate audio with combined voice (or, simply pass multiple directly with `+` )
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
"http://localhost:8880/v1/audio/speech",
|
"http://localhost:8880/v1/audio/speech",
|
||||||
json={
|
json={
|
||||||
"input": "Hello world!",
|
"input": "Hello world!",
|
||||||
"voice": combined_voice,
|
"voice": combined_voice, # or skip the above step with f"{voices[0]}+{voices[1]}"
|
||||||
"response_format": "mp3"
|
"response_format": "mp3"
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import List
|
from typing import List, Union
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from fastapi import Depends, Response, APIRouter, HTTPException
|
from fastapi import Depends, Response, APIRouter, HTTPException
|
||||||
|
@ -20,18 +20,43 @@ def get_tts_service() -> TTSService:
|
||||||
return TTSService() # Initialize TTSService with default settings
|
return TTSService() # Initialize TTSService with default settings
|
||||||
|
|
||||||
|
|
||||||
|
async def process_voices(voice_input: Union[str, List[str]], tts_service: TTSService) -> str:
|
||||||
|
"""Process voice input into a combined voice, handling both string and list formats"""
|
||||||
|
# Convert input to list of voices
|
||||||
|
if isinstance(voice_input, str):
|
||||||
|
voices = [v.strip() for v in voice_input.split("+") if v.strip()]
|
||||||
|
else:
|
||||||
|
voices = voice_input
|
||||||
|
|
||||||
|
if not voices:
|
||||||
|
raise ValueError("No voices provided")
|
||||||
|
|
||||||
|
# Check if all voices exist
|
||||||
|
available_voices = await tts_service.list_voices()
|
||||||
|
for voice in voices:
|
||||||
|
if voice not in available_voices:
|
||||||
|
raise ValueError(f"Voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}")
|
||||||
|
|
||||||
|
# If single voice, return it directly
|
||||||
|
if len(voices) == 1:
|
||||||
|
return voices[0]
|
||||||
|
|
||||||
|
# Otherwise combine voices
|
||||||
|
return await tts_service.combine_voices(voices=voices)
|
||||||
|
|
||||||
|
|
||||||
async def stream_audio_chunks(tts_service: TTSService, request: OpenAISpeechRequest) -> AsyncGenerator[bytes, None]:
|
async def stream_audio_chunks(tts_service: TTSService, request: OpenAISpeechRequest) -> AsyncGenerator[bytes, None]:
|
||||||
"""Stream audio chunks as they're generated"""
|
"""Stream audio chunks as they're generated"""
|
||||||
|
voice_to_use = await process_voices(request.voice, tts_service)
|
||||||
async for chunk in tts_service.generate_audio_stream(
|
async for chunk in tts_service.generate_audio_stream(
|
||||||
text=request.input,
|
text=request.input,
|
||||||
voice=request.voice,
|
voice=voice_to_use,
|
||||||
speed=request.speed,
|
speed=request.speed,
|
||||||
output_format=request.response_format
|
output_format=request.response_format
|
||||||
):
|
):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/audio/speech")
|
@router.post("/audio/speech")
|
||||||
async def create_speech(
|
async def create_speech(
|
||||||
request: OpenAISpeechRequest,
|
request: OpenAISpeechRequest,
|
||||||
|
@ -40,12 +65,8 @@ async def create_speech(
|
||||||
):
|
):
|
||||||
"""OpenAI-compatible endpoint for text-to-speech"""
|
"""OpenAI-compatible endpoint for text-to-speech"""
|
||||||
try:
|
try:
|
||||||
# Validate voice exists
|
# Process voice combination and validate
|
||||||
available_voices = tts_service.list_voices()
|
voice_to_use = await process_voices(request.voice, tts_service)
|
||||||
if request.voice not in available_voices:
|
|
||||||
raise ValueError(
|
|
||||||
f"Voice '{request.voice}' not found. Available voices: {', '.join(sorted(available_voices))}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set content type based on format
|
# Set content type based on format
|
||||||
content_type = {
|
content_type = {
|
||||||
|
@ -73,7 +94,7 @@ async def create_speech(
|
||||||
# Generate complete audio
|
# Generate complete audio
|
||||||
audio, _ = tts_service._generate_audio(
|
audio, _ = tts_service._generate_audio(
|
||||||
text=request.input,
|
text=request.input,
|
||||||
voice=request.voice,
|
voice=voice_to_use,
|
||||||
speed=request.speed,
|
speed=request.speed,
|
||||||
stitch_long_output=True,
|
stitch_long_output=True,
|
||||||
)
|
)
|
||||||
|
@ -111,7 +132,7 @@ async def create_speech(
|
||||||
async def list_voices(tts_service: TTSService = Depends(get_tts_service)):
|
async def list_voices(tts_service: TTSService = Depends(get_tts_service)):
|
||||||
"""List all available voices for text-to-speech"""
|
"""List all available voices for text-to-speech"""
|
||||||
try:
|
try:
|
||||||
voices = tts_service.list_voices()
|
voices = await tts_service.list_voices()
|
||||||
return {"voices": voices}
|
return {"voices": voices}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error listing voices: {str(e)}")
|
logger.error(f"Error listing voices: {str(e)}")
|
||||||
|
@ -120,12 +141,13 @@ async def list_voices(tts_service: TTSService = Depends(get_tts_service)):
|
||||||
|
|
||||||
@router.post("/audio/voices/combine")
|
@router.post("/audio/voices/combine")
|
||||||
async def combine_voices(
|
async def combine_voices(
|
||||||
request: List[str], tts_service: TTSService = Depends(get_tts_service)
|
request: Union[str, List[str]], tts_service: TTSService = Depends(get_tts_service)
|
||||||
):
|
):
|
||||||
"""Combine multiple voices into a new voice.
|
"""Combine multiple voices into a new voice.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: List of voice names to combine
|
request: Either a string with voices separated by + (e.g. "voice1+voice2")
|
||||||
|
or a list of voice names to combine
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict with combined voice name and list of all available voices
|
Dict with combined voice name and list of all available voices
|
||||||
|
@ -136,8 +158,8 @@ async def combine_voices(
|
||||||
- 500: Server error (file system issues, combination failed)
|
- 500: Server error (file system issues, combination failed)
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
combined_voice = tts_service.combine_voices(voices=request)
|
combined_voice = await process_voices(request, tts_service)
|
||||||
voices = tts_service.list_voices()
|
voices = await tts_service.list_voices()
|
||||||
return {"voices": voices, "voice": combined_voice}
|
return {"voices": voices, "voice": combined_voice}
|
||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
|
@ -146,14 +168,8 @@ async def combine_voices(
|
||||||
status_code=400, detail={"error": "Invalid request", "message": str(e)}
|
status_code=400, detail={"error": "Invalid request", "message": str(e)}
|
||||||
)
|
)
|
||||||
|
|
||||||
except RuntimeError as e:
|
except Exception as e:
|
||||||
logger.error(f"Server error during voice combination: {str(e)}")
|
logger.error(f"Server error during voice combination: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500, detail={"error": "Server error", "message": str(e)}
|
status_code=500, detail={"error": "Server error", "message": "Server error"}
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Unexpected error during voice combination: {str(e)}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=500, detail={"error": "Unexpected error", "message": str(e)}
|
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
|
import aiofiles
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from typing import List, Tuple, Optional
|
from typing import List, Tuple, Optional
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
from aiofiles import threadpool
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -211,7 +213,7 @@ class TTSService:
|
||||||
wavfile.write(buffer, 24000, audio)
|
wavfile.write(buffer, 24000, audio)
|
||||||
return buffer.getvalue()
|
return buffer.getvalue()
|
||||||
|
|
||||||
def combine_voices(self, voices: List[str]) -> str:
|
async def combine_voices(self, voices: List[str]) -> str:
|
||||||
"""Combine multiple voices into a new voice"""
|
"""Combine multiple voices into a new voice"""
|
||||||
if len(voices) < 2:
|
if len(voices) < 2:
|
||||||
raise ValueError("At least 2 voices are required for combination")
|
raise ValueError("At least 2 voices are required for combination")
|
||||||
|
@ -252,11 +254,13 @@ class TTSService:
|
||||||
raise RuntimeError(f"Error combining voices: {str(e)}")
|
raise RuntimeError(f"Error combining voices: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def list_voices(self) -> List[str]:
|
async def list_voices(self) -> List[str]:
|
||||||
"""List all available voices"""
|
"""List all available voices"""
|
||||||
voices = []
|
voices = []
|
||||||
try:
|
try:
|
||||||
for file in os.listdir(TTSModel.VOICES_DIR):
|
# Use os.listdir in a thread pool
|
||||||
|
files = await threadpool.async_wrap(os.listdir)(TTSModel.VOICES_DIR)
|
||||||
|
for file in files:
|
||||||
if file.endswith(".pt"):
|
if file.endswith(".pt"):
|
||||||
voices.append(file[:-3]) # Remove .pt extension
|
voices.append(file[:-3]) # Remove .pt extension
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
@ -1,9 +1,17 @@
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Literal
|
from typing import Literal, Union, List
|
||||||
|
|
||||||
from pydantic import Field, BaseModel
|
from pydantic import Field, BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class VoiceCombineRequest(BaseModel):
|
||||||
|
"""Request schema for voice combination endpoint that accepts either a string with + or a list"""
|
||||||
|
voices: Union[str, List[str]] = Field(
|
||||||
|
...,
|
||||||
|
description="Either a string with voices separated by + (e.g. 'voice1+voice2') or a list of voice names to combine"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TTSStatus(str, Enum):
|
class TTSStatus(str, Enum):
|
||||||
PENDING = "pending"
|
PENDING = "pending"
|
||||||
PROCESSING = "processing"
|
PROCESSING = "processing"
|
||||||
|
|
|
@ -29,7 +29,9 @@ def mock_tts_service(monkeypatch):
|
||||||
for chunk in [b"chunk1", b"chunk2"]:
|
for chunk in [b"chunk1", b"chunk2"]:
|
||||||
yield chunk
|
yield chunk
|
||||||
mock_service.generate_audio_stream = mock_stream
|
mock_service.generate_audio_stream = mock_stream
|
||||||
mock_service.list_voices.return_value = [
|
|
||||||
|
# Create async mocks
|
||||||
|
mock_service.list_voices = AsyncMock(return_value=[
|
||||||
"af",
|
"af",
|
||||||
"bm_lewis",
|
"bm_lewis",
|
||||||
"bf_isabella",
|
"bf_isabella",
|
||||||
|
@ -39,7 +41,8 @@ def mock_tts_service(monkeypatch):
|
||||||
"am_adam",
|
"am_adam",
|
||||||
"am_michael",
|
"am_michael",
|
||||||
"bm_george",
|
"bm_george",
|
||||||
]
|
])
|
||||||
|
mock_service.combine_voices = AsyncMock()
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"api.src.routers.openai_compatible.TTSService",
|
"api.src.routers.openai_compatible.TTSService",
|
||||||
lambda *args, **kwargs: mock_service,
|
lambda *args, **kwargs: mock_service,
|
||||||
|
@ -64,7 +67,8 @@ def test_health_check():
|
||||||
assert response.json() == {"status": "healthy"}
|
assert response.json() == {"status": "healthy"}
|
||||||
|
|
||||||
|
|
||||||
def test_openai_speech_endpoint(mock_tts_service, mock_audio_service):
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_speech_endpoint(mock_tts_service, mock_audio_service, async_client):
|
||||||
"""Test the OpenAI-compatible speech endpoint"""
|
"""Test the OpenAI-compatible speech endpoint"""
|
||||||
test_request = {
|
test_request = {
|
||||||
"model": "kokoro",
|
"model": "kokoro",
|
||||||
|
@ -74,7 +78,7 @@ def test_openai_speech_endpoint(mock_tts_service, mock_audio_service):
|
||||||
"speed": 1.0,
|
"speed": 1.0,
|
||||||
"stream": False # Explicitly disable streaming
|
"stream": False # Explicitly disable streaming
|
||||||
}
|
}
|
||||||
response = client.post("/v1/audio/speech", json=test_request)
|
response = await async_client.post("/v1/audio/speech", json=test_request)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.headers["content-type"] == "audio/wav"
|
assert response.headers["content-type"] == "audio/wav"
|
||||||
assert response.headers["content-disposition"] == "attachment; filename=speech.wav"
|
assert response.headers["content-disposition"] == "attachment; filename=speech.wav"
|
||||||
|
@ -84,7 +88,8 @@ def test_openai_speech_endpoint(mock_tts_service, mock_audio_service):
|
||||||
assert response.content == b"converted mock audio data"
|
assert response.content == b"converted mock audio data"
|
||||||
|
|
||||||
|
|
||||||
def test_openai_speech_invalid_voice(mock_tts_service):
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_speech_invalid_voice(mock_tts_service, async_client):
|
||||||
"""Test the OpenAI-compatible speech endpoint with invalid voice"""
|
"""Test the OpenAI-compatible speech endpoint with invalid voice"""
|
||||||
test_request = {
|
test_request = {
|
||||||
"model": "kokoro",
|
"model": "kokoro",
|
||||||
|
@ -94,12 +99,13 @@ def test_openai_speech_invalid_voice(mock_tts_service):
|
||||||
"speed": 1.0,
|
"speed": 1.0,
|
||||||
"stream": False # Explicitly disable streaming
|
"stream": False # Explicitly disable streaming
|
||||||
}
|
}
|
||||||
response = client.post("/v1/audio/speech", json=test_request)
|
response = await async_client.post("/v1/audio/speech", json=test_request)
|
||||||
assert response.status_code == 400 # Bad request
|
assert response.status_code == 400 # Bad request
|
||||||
assert "not found" in response.json()["detail"]["message"]
|
assert "not found" in response.json()["detail"]["message"]
|
||||||
|
|
||||||
|
|
||||||
def test_openai_speech_invalid_speed(mock_tts_service):
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_speech_invalid_speed(mock_tts_service, async_client):
|
||||||
"""Test the OpenAI-compatible speech endpoint with invalid speed"""
|
"""Test the OpenAI-compatible speech endpoint with invalid speed"""
|
||||||
test_request = {
|
test_request = {
|
||||||
"model": "kokoro",
|
"model": "kokoro",
|
||||||
|
@ -109,11 +115,12 @@ def test_openai_speech_invalid_speed(mock_tts_service):
|
||||||
"speed": -1.0, # Invalid speed
|
"speed": -1.0, # Invalid speed
|
||||||
"stream": False # Explicitly disable streaming
|
"stream": False # Explicitly disable streaming
|
||||||
}
|
}
|
||||||
response = client.post("/v1/audio/speech", json=test_request)
|
response = await async_client.post("/v1/audio/speech", json=test_request)
|
||||||
assert response.status_code == 422 # Validation error
|
assert response.status_code == 422 # Validation error
|
||||||
|
|
||||||
|
|
||||||
def test_openai_speech_generation_error(mock_tts_service):
|
@pytest.mark.asyncio
|
||||||
|
async def test_openai_speech_generation_error(mock_tts_service, async_client):
|
||||||
"""Test error handling in speech generation"""
|
"""Test error handling in speech generation"""
|
||||||
mock_tts_service._generate_audio.side_effect = Exception("Generation failed")
|
mock_tts_service._generate_audio.side_effect = Exception("Generation failed")
|
||||||
test_request = {
|
test_request = {
|
||||||
|
@ -124,54 +131,173 @@ def test_openai_speech_generation_error(mock_tts_service):
|
||||||
"speed": 1.0,
|
"speed": 1.0,
|
||||||
"stream": False # Explicitly disable streaming
|
"stream": False # Explicitly disable streaming
|
||||||
}
|
}
|
||||||
response = client.post("/v1/audio/speech", json=test_request)
|
response = await async_client.post("/v1/audio/speech", json=test_request)
|
||||||
assert response.status_code == 500
|
assert response.status_code == 500
|
||||||
assert "Generation failed" in response.json()["detail"]["message"]
|
assert "Generation failed" in response.json()["detail"]["message"]
|
||||||
|
|
||||||
|
|
||||||
def test_combine_voices_success(mock_tts_service):
|
@pytest.mark.asyncio
|
||||||
"""Test successful voice combination"""
|
async def test_combine_voices_list_success(mock_tts_service, async_client):
|
||||||
|
"""Test successful voice combination using list format"""
|
||||||
test_voices = ["af_bella", "af_sarah"]
|
test_voices = ["af_bella", "af_sarah"]
|
||||||
mock_tts_service.combine_voices.return_value = "af_bella_af_sarah"
|
mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")
|
||||||
|
|
||||||
response = client.post("/v1/audio/voices/combine", json=test_voices)
|
response = await async_client.post("/v1/audio/voices/combine", json=test_voices)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["voice"] == "af_bella_af_sarah"
|
assert response.json()["voice"] == "af_bella_af_sarah"
|
||||||
mock_tts_service.combine_voices.assert_called_once_with(voices=test_voices)
|
mock_tts_service.combine_voices.assert_called_once_with(voices=test_voices)
|
||||||
|
|
||||||
|
|
||||||
def test_combine_voices_single_voice(mock_tts_service):
|
@pytest.mark.asyncio
|
||||||
"""Test combining single voice returns default voice"""
|
async def test_combine_voices_string_success(mock_tts_service, async_client):
|
||||||
|
"""Test successful voice combination using string format with +"""
|
||||||
|
test_voices = "af_bella+af_sarah"
|
||||||
|
mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")
|
||||||
|
|
||||||
|
response = await async_client.post("/v1/audio/voices/combine", json=test_voices)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["voice"] == "af_bella_af_sarah"
|
||||||
|
mock_tts_service.combine_voices.assert_called_once_with(voices=["af_bella", "af_sarah"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_combine_voices_single_voice(mock_tts_service, async_client):
|
||||||
|
"""Test combining single voice returns same voice"""
|
||||||
test_voices = ["af_bella"]
|
test_voices = ["af_bella"]
|
||||||
mock_tts_service.combine_voices.return_value = "af"
|
response = await async_client.post("/v1/audio/voices/combine", json=test_voices)
|
||||||
|
|
||||||
response = client.post("/v1/audio/voices/combine", json=test_voices)
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["voice"] == "af"
|
assert response.json()["voice"] == "af_bella"
|
||||||
|
|
||||||
|
|
||||||
def test_combine_voices_empty_list(mock_tts_service):
|
@pytest.mark.asyncio
|
||||||
"""Test combining empty voice list returns default voice"""
|
async def test_combine_voices_empty_list(mock_tts_service, async_client):
|
||||||
|
"""Test combining empty voice list returns error"""
|
||||||
test_voices = []
|
test_voices = []
|
||||||
mock_tts_service.combine_voices.return_value = "af"
|
response = await async_client.post("/v1/audio/voices/combine", json=test_voices)
|
||||||
|
assert response.status_code == 400
|
||||||
response = client.post("/v1/audio/voices/combine", json=test_voices)
|
assert "No voices provided" in response.json()["detail"]["message"]
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["voice"] == "af"
|
|
||||||
|
|
||||||
|
|
||||||
def test_combine_voices_error(mock_tts_service):
|
@pytest.mark.asyncio
|
||||||
|
async def test_combine_voices_error(mock_tts_service, async_client):
|
||||||
"""Test error handling in voice combination"""
|
"""Test error handling in voice combination"""
|
||||||
test_voices = ["af_bella", "af_sarah"]
|
test_voices = ["af_bella", "af_sarah"]
|
||||||
mock_tts_service.combine_voices.side_effect = Exception("Combination failed")
|
mock_tts_service.combine_voices = AsyncMock(side_effect=Exception("Combination failed"))
|
||||||
|
|
||||||
response = client.post("/v1/audio/voices/combine", json=test_voices)
|
|
||||||
|
|
||||||
|
response = await async_client.post("/v1/audio/voices/combine", json=test_voices)
|
||||||
assert response.status_code == 500
|
assert response.status_code == 500
|
||||||
assert "Combination failed" in response.json()["detail"]["message"]
|
assert "Server error" in response.json()["detail"]["message"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_speech_with_combined_voice(mock_tts_service, mock_audio_service, async_client):
|
||||||
|
"""Test speech generation with combined voice using + syntax"""
|
||||||
|
mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")
|
||||||
|
|
||||||
|
test_request = {
|
||||||
|
"model": "kokoro",
|
||||||
|
"input": "Hello world",
|
||||||
|
"voice": "af_bella+af_sarah",
|
||||||
|
"response_format": "wav",
|
||||||
|
"speed": 1.0,
|
||||||
|
"stream": False
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await async_client.post("/v1/audio/speech", json=test_request)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers["content-type"] == "audio/wav"
|
||||||
|
mock_tts_service._generate_audio.assert_called_once_with(
|
||||||
|
text="Hello world",
|
||||||
|
voice="af_bella_af_sarah",
|
||||||
|
speed=1.0,
|
||||||
|
stitch_long_output=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_speech_with_whitespace_in_voice(mock_tts_service, mock_audio_service, async_client):
|
||||||
|
"""Test speech generation with whitespace in voice combination"""
|
||||||
|
mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")
|
||||||
|
|
||||||
|
test_request = {
|
||||||
|
"model": "kokoro",
|
||||||
|
"input": "Hello world",
|
||||||
|
"voice": " af_bella + af_sarah ",
|
||||||
|
"response_format": "wav",
|
||||||
|
"speed": 1.0,
|
||||||
|
"stream": False
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await async_client.post("/v1/audio/speech", json=test_request)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers["content-type"] == "audio/wav"
|
||||||
|
mock_tts_service.combine_voices.assert_called_once_with(voices=["af_bella", "af_sarah"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_speech_with_empty_voice_combination(mock_tts_service, async_client):
|
||||||
|
"""Test speech generation with empty voice combination"""
|
||||||
|
test_request = {
|
||||||
|
"model": "kokoro",
|
||||||
|
"input": "Hello world",
|
||||||
|
"voice": "+",
|
||||||
|
"response_format": "wav",
|
||||||
|
"speed": 1.0,
|
||||||
|
"stream": False
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await async_client.post("/v1/audio/speech", json=test_request)
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert "No voices provided" in response.json()["detail"]["message"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_speech_with_invalid_combined_voice(mock_tts_service, async_client):
|
||||||
|
"""Test speech generation with invalid voice combination"""
|
||||||
|
test_request = {
|
||||||
|
"model": "kokoro",
|
||||||
|
"input": "Hello world",
|
||||||
|
"voice": "invalid+combination",
|
||||||
|
"response_format": "wav",
|
||||||
|
"speed": 1.0,
|
||||||
|
"stream": False
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await async_client.post("/v1/audio/speech", json=test_request)
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert "not found" in response.json()["detail"]["message"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_speech_streaming_with_combined_voice(mock_tts_service, async_client):
|
||||||
|
"""Test streaming speech with combined voice using + syntax"""
|
||||||
|
mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")
|
||||||
|
|
||||||
|
test_request = {
|
||||||
|
"model": "kokoro",
|
||||||
|
"input": "Hello world",
|
||||||
|
"voice": "af_bella+af_sarah",
|
||||||
|
"response_format": "mp3",
|
||||||
|
"stream": True
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create streaming mock
|
||||||
|
async def mock_stream(*args, **kwargs):
|
||||||
|
for chunk in [b"mp3header", b"mp3data"]:
|
||||||
|
yield chunk
|
||||||
|
mock_tts_service.generate_audio_stream = mock_stream
|
||||||
|
|
||||||
|
# Add streaming header
|
||||||
|
headers = {"x-raw-response": "stream"}
|
||||||
|
response = await async_client.post("/v1/audio/speech", json=test_request, headers=headers)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers["content-type"] == "audio/mpeg"
|
||||||
|
assert response.headers["content-disposition"] == "attachment; filename=speech.mp3"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@ -197,9 +323,6 @@ async def test_openai_speech_pcm_streaming(mock_tts_service, async_client):
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.headers["content-type"] == "audio/pcm"
|
assert response.headers["content-type"] == "audio/pcm"
|
||||||
# Just verify status and content type
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.headers["content-type"] == "audio/pcm"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@ -226,10 +349,6 @@ async def test_openai_speech_streaming_mp3(mock_tts_service, async_client):
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.headers["content-type"] == "audio/mpeg"
|
assert response.headers["content-type"] == "audio/mpeg"
|
||||||
assert response.headers["content-disposition"] == "attachment; filename=speech.mp3"
|
assert response.headers["content-disposition"] == "attachment; filename=speech.mp3"
|
||||||
# Just verify status and content type
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.headers["content-type"] == "audio/mpeg"
|
|
||||||
assert response.headers["content-disposition"] == "attachment; filename=speech.mp3"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@ -255,6 +374,3 @@ async def test_openai_speech_streaming_generator(mock_tts_service, async_client)
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.headers["content-type"] == "audio/pcm"
|
assert response.headers["content-type"] == "audio/pcm"
|
||||||
# Just verify status and content type
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.headers["content-type"] == "audio/pcm"
|
|
||||||
|
|
|
@ -1,12 +1,13 @@
|
||||||
"""Tests for TTSService"""
|
"""Tests for TTSService"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from unittest.mock import MagicMock, call, patch
|
from unittest.mock import MagicMock, call, patch, AsyncMock
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import pytest
|
import pytest
|
||||||
from onnxruntime import InferenceSession
|
from onnxruntime import InferenceSession
|
||||||
|
from aiofiles import threadpool
|
||||||
|
|
||||||
from api.src.core.config import settings
|
from api.src.core.config import settings
|
||||||
from api.src.services.tts_model import TTSModel
|
from api.src.services.tts_model import TTSModel
|
||||||
|
@ -38,27 +39,33 @@ def test_audio_to_bytes(tts_service, sample_audio):
|
||||||
assert len(audio_bytes) > 0
|
assert len(audio_bytes) > 0
|
||||||
|
|
||||||
|
|
||||||
@patch("os.listdir")
|
@pytest.mark.asyncio
|
||||||
@patch("os.path.join")
|
async def test_list_voices(tts_service):
|
||||||
def test_list_voices(mock_join, mock_listdir, tts_service):
|
|
||||||
"""Test listing available voices"""
|
"""Test listing available voices"""
|
||||||
mock_listdir.return_value = ["voice1.pt", "voice2.pt", "not_a_voice.txt"]
|
# Mock os.listdir to return test files
|
||||||
mock_join.return_value = "/fake/path"
|
with patch('os.listdir', return_value=["voice1.pt", "voice2.pt", "not_a_voice.txt"]):
|
||||||
|
# Register mock with threadpool
|
||||||
voices = tts_service.list_voices()
|
async_listdir = AsyncMock(return_value=["voice1.pt", "voice2.pt", "not_a_voice.txt"])
|
||||||
assert len(voices) == 2
|
threadpool.async_wrap = MagicMock(return_value=async_listdir)
|
||||||
assert "voice1" in voices
|
|
||||||
assert "voice2" in voices
|
voices = await tts_service.list_voices()
|
||||||
assert "not_a_voice" not in voices
|
assert len(voices) == 2
|
||||||
|
assert "voice1" in voices
|
||||||
|
assert "voice2" in voices
|
||||||
|
assert "not_a_voice" not in voices
|
||||||
|
|
||||||
|
|
||||||
@patch("os.listdir")
|
@pytest.mark.asyncio
|
||||||
def test_list_voices_error(mock_listdir, tts_service):
|
async def test_list_voices_error(tts_service):
|
||||||
"""Test error handling in list_voices"""
|
"""Test error handling in list_voices"""
|
||||||
mock_listdir.side_effect = Exception("Failed to list directory")
|
# Mock os.listdir to raise an exception
|
||||||
|
with patch('os.listdir', side_effect=Exception("Failed to list directory")):
|
||||||
voices = tts_service.list_voices()
|
# Register mock with threadpool
|
||||||
assert voices == []
|
async_listdir = AsyncMock(side_effect=Exception("Failed to list directory"))
|
||||||
|
threadpool.async_wrap = MagicMock(return_value=async_listdir)
|
||||||
|
|
||||||
|
voices = await tts_service.list_voices()
|
||||||
|
assert voices == []
|
||||||
|
|
||||||
|
|
||||||
def mock_model_setup(cuda_available=False):
|
def mock_model_setup(cuda_available=False):
|
||||||
|
@ -176,7 +183,8 @@ def test_save_audio(tts_service, sample_audio, tmp_path):
|
||||||
assert os.path.getsize(output_path) > 0
|
assert os.path.getsize(output_path) > 0
|
||||||
|
|
||||||
|
|
||||||
def test_combine_voices(tts_service):
|
@pytest.mark.asyncio
|
||||||
|
async def test_combine_voices(tts_service):
|
||||||
"""Test combining multiple voices"""
|
"""Test combining multiple voices"""
|
||||||
# Setup mocks for torch operations
|
# Setup mocks for torch operations
|
||||||
with patch('torch.load', return_value=torch.tensor([1.0, 2.0])), \
|
with patch('torch.load', return_value=torch.tensor([1.0, 2.0])), \
|
||||||
|
@ -186,20 +194,21 @@ def test_combine_voices(tts_service):
|
||||||
patch('os.path.exists', return_value=True):
|
patch('os.path.exists', return_value=True):
|
||||||
|
|
||||||
# Test combining two voices
|
# Test combining two voices
|
||||||
result = tts_service.combine_voices(["voice1", "voice2"])
|
result = await tts_service.combine_voices(["voice1", "voice2"])
|
||||||
|
|
||||||
assert result == "voice1_voice2"
|
assert result == "voice1_voice2"
|
||||||
|
|
||||||
|
|
||||||
def test_combine_voices_invalid_input(tts_service):
|
@pytest.mark.asyncio
|
||||||
|
async def test_combine_voices_invalid_input(tts_service):
|
||||||
"""Test combining voices with invalid input"""
|
"""Test combining voices with invalid input"""
|
||||||
# Test with empty list
|
# Test with empty list
|
||||||
with pytest.raises(ValueError, match="At least 2 voices are required"):
|
with pytest.raises(ValueError, match="At least 2 voices are required"):
|
||||||
tts_service.combine_voices([])
|
await tts_service.combine_voices([])
|
||||||
|
|
||||||
# Test with single voice
|
# Test with single voice
|
||||||
with pytest.raises(ValueError, match="At least 2 voices are required"):
|
with pytest.raises(ValueError, match="At least 2 voices are required"):
|
||||||
tts_service.combine_voices(["voice1"])
|
await tts_service.combine_voices(["voice1"])
|
||||||
|
|
||||||
|
|
||||||
@patch("api.src.services.tts_service.TTSService._get_voice_path")
|
@patch("api.src.services.tts_service.TTSService._get_voice_path")
|
||||||
|
|
|
@ -34,7 +34,7 @@ def stream_to_speakers() -> None:
|
||||||
|
|
||||||
with openai.audio.speech.with_streaming_response.create(
|
with openai.audio.speech.with_streaming_response.create(
|
||||||
model="kokoro",
|
model="kokoro",
|
||||||
voice="af",
|
voice="af_sky+af_bella+bm_george",
|
||||||
response_format="pcm", # similar to WAV, but without a header chunk at the start.
|
response_format="pcm", # similar to WAV, but without a header chunk at the start.
|
||||||
input="""My dear sir, that is just where you are wrong. That is just where the whole world has gone wrong. We are always getting away from the present moment. Our mental existences, which are immaterial and have no dimensions, are passing along the Time-Dimension with a uniform velocity from the cradle to the grave. Just as we should travel down if we began our existence fifty miles above the earth’s surface""",
|
input="""My dear sir, that is just where you are wrong. That is just where the whole world has gone wrong. We are always getting away from the present moment. Our mental existences, which are immaterial and have no dimensions, are passing along the Time-Dimension with a uniform velocity from the cradle to the grave. Just as we should travel down if we began our existence fifty miles above the earth’s surface""",
|
||||||
) as response:
|
) as response:
|
||||||
|
|
BIN
examples/speech.mp3
Normal file
BIN
examples/speech.mp3
Normal file
Binary file not shown.
|
@ -20,6 +20,7 @@ phonemizer==3.3.0
|
||||||
regex==2024.11.6
|
regex==2024.11.6
|
||||||
|
|
||||||
# Utilities
|
# Utilities
|
||||||
|
aiofiles==24.1.0
|
||||||
tqdm==4.67.1
|
tqdm==4.67.1
|
||||||
requests==2.32.3
|
requests==2.32.3
|
||||||
munch==4.0.0
|
munch==4.0.0
|
||||||
|
|
Loading…
Add table
Reference in a new issue