mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
- Added support for combining voices via any endpoint
- Updated the `process_voices` function to handle both string and list formats for voice input.
This commit is contained in:
parent
bb1f9b54ba
commit
130b084cce
10 changed files with 259 additions and 104 deletions
|
@ -1 +1 @@
|
|||
Subproject commit c97b7bbc3e60f447383c79b2f94fee861ff156ac
|
||||
Subproject commit 3095858c40fc22e28c46429da9340dfda1f8cf28
|
17
README.md
17
README.md
|
@ -3,17 +3,17 @@
|
|||
</p>
|
||||
|
||||
# Kokoro TTS API
|
||||
[]()
|
||||
[]()
|
||||
[]()
|
||||
[]()
|
||||
[](https://huggingface.co/hexgrad/Kokoro-82M/tree/c3b0d86e2a980e027ef71c28819ea02e351c2667) [](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero)
|
||||
|
||||
Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model
|
||||
- OpenAI-compatible Speech endpoint, with voice combination functionality
|
||||
- OpenAI-compatible Speech endpoint, with inline voice combination functionality
|
||||
- NVIDIA GPU accelerated inference (or CPU) option
|
||||
- very fast generation time
|
||||
- ~ 35x real time speed via 4060Ti, ~300ms latency
|
||||
- ~ 6x real time spead via M3 Pro CPU, ~1000ms latency
|
||||
- streaming support w/ variable chunking control latency & artifacts
|
||||
- streaming support w/ variable chunking to control latency & artifacts
|
||||
- simple audio generation web ui utility
|
||||
|
||||
|
||||
|
@ -39,7 +39,7 @@ The service can be accessed through either the API endpoints or the Gradio web i
|
|||
|
||||
response = client.audio.speech.create(
|
||||
model="kokoro",
|
||||
voice="af_bella",
|
||||
voice="af_sky+af_bella", #single or multiple voicepack combo
|
||||
input="Hello world!",
|
||||
response_format="mp3"
|
||||
)
|
||||
|
@ -61,7 +61,7 @@ from openai import OpenAI
|
|||
client = OpenAI(base_url="http://localhost:8880", api_key="not-needed")
|
||||
response = client.audio.speech.create(
|
||||
model="kokoro", # Not used but required for compatibility, also accepts library defaults
|
||||
voice="af_bella",
|
||||
voice="af_bella+af_sky",
|
||||
input="Hello world!",
|
||||
response_format="mp3"
|
||||
)
|
||||
|
@ -105,6 +105,7 @@ python examples/test_all_voices.py # Test all available voices
|
|||
|
||||
- Averages model weights of any existing voicepacks
|
||||
- Saves generated voicepacks for future use
|
||||
- (new) Available through any endpoint, simply concatenate desired packs with "+"
|
||||
|
||||
Combine voices and generate audio:
|
||||
```python
|
||||
|
@ -119,12 +120,12 @@ response = requests.post(
|
|||
)
|
||||
combined_voice = response.json()["voice"]
|
||||
|
||||
# Generate audio with combined voice
|
||||
# Generate audio with combined voice (or, simply pass multiple directly with `+` )
|
||||
response = requests.post(
|
||||
"http://localhost:8880/v1/audio/speech",
|
||||
json={
|
||||
"input": "Hello world!",
|
||||
"voice": combined_voice,
|
||||
"voice": combined_voice, # or skip the above step with f"{voices[0]}+{voices[1]}"
|
||||
"response_format": "mp3"
|
||||
}
|
||||
)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import List
|
||||
from typing import List, Union
|
||||
|
||||
from loguru import logger
|
||||
from fastapi import Depends, Response, APIRouter, HTTPException
|
||||
|
@ -20,18 +20,43 @@ def get_tts_service() -> TTSService:
|
|||
return TTSService() # Initialize TTSService with default settings
|
||||
|
||||
|
||||
async def process_voices(voice_input: Union[str, List[str]], tts_service: TTSService) -> str:
|
||||
"""Process voice input into a combined voice, handling both string and list formats"""
|
||||
# Convert input to list of voices
|
||||
if isinstance(voice_input, str):
|
||||
voices = [v.strip() for v in voice_input.split("+") if v.strip()]
|
||||
else:
|
||||
voices = voice_input
|
||||
|
||||
if not voices:
|
||||
raise ValueError("No voices provided")
|
||||
|
||||
# Check if all voices exist
|
||||
available_voices = await tts_service.list_voices()
|
||||
for voice in voices:
|
||||
if voice not in available_voices:
|
||||
raise ValueError(f"Voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}")
|
||||
|
||||
# If single voice, return it directly
|
||||
if len(voices) == 1:
|
||||
return voices[0]
|
||||
|
||||
# Otherwise combine voices
|
||||
return await tts_service.combine_voices(voices=voices)
|
||||
|
||||
|
||||
async def stream_audio_chunks(tts_service: TTSService, request: OpenAISpeechRequest) -> AsyncGenerator[bytes, None]:
|
||||
"""Stream audio chunks as they're generated"""
|
||||
voice_to_use = await process_voices(request.voice, tts_service)
|
||||
async for chunk in tts_service.generate_audio_stream(
|
||||
text=request.input,
|
||||
voice=request.voice,
|
||||
voice=voice_to_use,
|
||||
speed=request.speed,
|
||||
output_format=request.response_format
|
||||
):
|
||||
yield chunk
|
||||
|
||||
|
||||
|
||||
@router.post("/audio/speech")
|
||||
async def create_speech(
|
||||
request: OpenAISpeechRequest,
|
||||
|
@ -40,12 +65,8 @@ async def create_speech(
|
|||
):
|
||||
"""OpenAI-compatible endpoint for text-to-speech"""
|
||||
try:
|
||||
# Validate voice exists
|
||||
available_voices = tts_service.list_voices()
|
||||
if request.voice not in available_voices:
|
||||
raise ValueError(
|
||||
f"Voice '{request.voice}' not found. Available voices: {', '.join(sorted(available_voices))}"
|
||||
)
|
||||
# Process voice combination and validate
|
||||
voice_to_use = await process_voices(request.voice, tts_service)
|
||||
|
||||
# Set content type based on format
|
||||
content_type = {
|
||||
|
@ -73,7 +94,7 @@ async def create_speech(
|
|||
# Generate complete audio
|
||||
audio, _ = tts_service._generate_audio(
|
||||
text=request.input,
|
||||
voice=request.voice,
|
||||
voice=voice_to_use,
|
||||
speed=request.speed,
|
||||
stitch_long_output=True,
|
||||
)
|
||||
|
@ -111,7 +132,7 @@ async def create_speech(
|
|||
async def list_voices(tts_service: TTSService = Depends(get_tts_service)):
|
||||
"""List all available voices for text-to-speech"""
|
||||
try:
|
||||
voices = tts_service.list_voices()
|
||||
voices = await tts_service.list_voices()
|
||||
return {"voices": voices}
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing voices: {str(e)}")
|
||||
|
@ -120,12 +141,13 @@ async def list_voices(tts_service: TTSService = Depends(get_tts_service)):
|
|||
|
||||
@router.post("/audio/voices/combine")
|
||||
async def combine_voices(
|
||||
request: List[str], tts_service: TTSService = Depends(get_tts_service)
|
||||
request: Union[str, List[str]], tts_service: TTSService = Depends(get_tts_service)
|
||||
):
|
||||
"""Combine multiple voices into a new voice.
|
||||
|
||||
Args:
|
||||
request: List of voice names to combine
|
||||
request: Either a string with voices separated by + (e.g. "voice1+voice2")
|
||||
or a list of voice names to combine
|
||||
|
||||
Returns:
|
||||
Dict with combined voice name and list of all available voices
|
||||
|
@ -136,8 +158,8 @@ async def combine_voices(
|
|||
- 500: Server error (file system issues, combination failed)
|
||||
"""
|
||||
try:
|
||||
combined_voice = tts_service.combine_voices(voices=request)
|
||||
voices = tts_service.list_voices()
|
||||
combined_voice = await process_voices(request, tts_service)
|
||||
voices = await tts_service.list_voices()
|
||||
return {"voices": voices, "voice": combined_voice}
|
||||
|
||||
except ValueError as e:
|
||||
|
@ -146,14 +168,8 @@ async def combine_voices(
|
|||
status_code=400, detail={"error": "Invalid request", "message": str(e)}
|
||||
)
|
||||
|
||||
except RuntimeError as e:
|
||||
except Exception as e:
|
||||
logger.error(f"Server error during voice combination: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail={"error": "Server error", "message": str(e)}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during voice combination: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail={"error": "Unexpected error", "message": str(e)}
|
||||
status_code=500, detail={"error": "Server error", "message": "Server error"}
|
||||
)
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
import aiofiles
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from typing import List, Tuple, Optional
|
||||
from functools import lru_cache
|
||||
from aiofiles import threadpool
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -211,7 +213,7 @@ class TTSService:
|
|||
wavfile.write(buffer, 24000, audio)
|
||||
return buffer.getvalue()
|
||||
|
||||
def combine_voices(self, voices: List[str]) -> str:
|
||||
async def combine_voices(self, voices: List[str]) -> str:
|
||||
"""Combine multiple voices into a new voice"""
|
||||
if len(voices) < 2:
|
||||
raise ValueError("At least 2 voices are required for combination")
|
||||
|
@ -252,11 +254,13 @@ class TTSService:
|
|||
raise RuntimeError(f"Error combining voices: {str(e)}")
|
||||
raise
|
||||
|
||||
def list_voices(self) -> List[str]:
|
||||
async def list_voices(self) -> List[str]:
|
||||
"""List all available voices"""
|
||||
voices = []
|
||||
try:
|
||||
for file in os.listdir(TTSModel.VOICES_DIR):
|
||||
# Use os.listdir in a thread pool
|
||||
files = await threadpool.async_wrap(os.listdir)(TTSModel.VOICES_DIR)
|
||||
for file in files:
|
||||
if file.endswith(".pt"):
|
||||
voices.append(file[:-3]) # Remove .pt extension
|
||||
except Exception as e:
|
||||
|
|
|
@ -1,9 +1,17 @@
|
|||
from enum import Enum
|
||||
from typing import Literal
|
||||
from typing import Literal, Union, List
|
||||
|
||||
from pydantic import Field, BaseModel
|
||||
|
||||
|
||||
class VoiceCombineRequest(BaseModel):
|
||||
"""Request schema for voice combination endpoint that accepts either a string with + or a list"""
|
||||
voices: Union[str, List[str]] = Field(
|
||||
...,
|
||||
description="Either a string with voices separated by + (e.g. 'voice1+voice2') or a list of voice names to combine"
|
||||
)
|
||||
|
||||
|
||||
class TTSStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
PROCESSING = "processing"
|
||||
|
|
|
@ -29,7 +29,9 @@ def mock_tts_service(monkeypatch):
|
|||
for chunk in [b"chunk1", b"chunk2"]:
|
||||
yield chunk
|
||||
mock_service.generate_audio_stream = mock_stream
|
||||
mock_service.list_voices.return_value = [
|
||||
|
||||
# Create async mocks
|
||||
mock_service.list_voices = AsyncMock(return_value=[
|
||||
"af",
|
||||
"bm_lewis",
|
||||
"bf_isabella",
|
||||
|
@ -39,7 +41,8 @@ def mock_tts_service(monkeypatch):
|
|||
"am_adam",
|
||||
"am_michael",
|
||||
"bm_george",
|
||||
]
|
||||
])
|
||||
mock_service.combine_voices = AsyncMock()
|
||||
monkeypatch.setattr(
|
||||
"api.src.routers.openai_compatible.TTSService",
|
||||
lambda *args, **kwargs: mock_service,
|
||||
|
@ -64,7 +67,8 @@ def test_health_check():
|
|||
assert response.json() == {"status": "healthy"}
|
||||
|
||||
|
||||
def test_openai_speech_endpoint(mock_tts_service, mock_audio_service):
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_speech_endpoint(mock_tts_service, mock_audio_service, async_client):
|
||||
"""Test the OpenAI-compatible speech endpoint"""
|
||||
test_request = {
|
||||
"model": "kokoro",
|
||||
|
@ -74,7 +78,7 @@ def test_openai_speech_endpoint(mock_tts_service, mock_audio_service):
|
|||
"speed": 1.0,
|
||||
"stream": False # Explicitly disable streaming
|
||||
}
|
||||
response = client.post("/v1/audio/speech", json=test_request)
|
||||
response = await async_client.post("/v1/audio/speech", json=test_request)
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "audio/wav"
|
||||
assert response.headers["content-disposition"] == "attachment; filename=speech.wav"
|
||||
|
@ -84,7 +88,8 @@ def test_openai_speech_endpoint(mock_tts_service, mock_audio_service):
|
|||
assert response.content == b"converted mock audio data"
|
||||
|
||||
|
||||
def test_openai_speech_invalid_voice(mock_tts_service):
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_speech_invalid_voice(mock_tts_service, async_client):
|
||||
"""Test the OpenAI-compatible speech endpoint with invalid voice"""
|
||||
test_request = {
|
||||
"model": "kokoro",
|
||||
|
@ -94,12 +99,13 @@ def test_openai_speech_invalid_voice(mock_tts_service):
|
|||
"speed": 1.0,
|
||||
"stream": False # Explicitly disable streaming
|
||||
}
|
||||
response = client.post("/v1/audio/speech", json=test_request)
|
||||
response = await async_client.post("/v1/audio/speech", json=test_request)
|
||||
assert response.status_code == 400 # Bad request
|
||||
assert "not found" in response.json()["detail"]["message"]
|
||||
|
||||
|
||||
def test_openai_speech_invalid_speed(mock_tts_service):
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_speech_invalid_speed(mock_tts_service, async_client):
|
||||
"""Test the OpenAI-compatible speech endpoint with invalid speed"""
|
||||
test_request = {
|
||||
"model": "kokoro",
|
||||
|
@ -109,11 +115,12 @@ def test_openai_speech_invalid_speed(mock_tts_service):
|
|||
"speed": -1.0, # Invalid speed
|
||||
"stream": False # Explicitly disable streaming
|
||||
}
|
||||
response = client.post("/v1/audio/speech", json=test_request)
|
||||
response = await async_client.post("/v1/audio/speech", json=test_request)
|
||||
assert response.status_code == 422 # Validation error
|
||||
|
||||
|
||||
def test_openai_speech_generation_error(mock_tts_service):
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_speech_generation_error(mock_tts_service, async_client):
|
||||
"""Test error handling in speech generation"""
|
||||
mock_tts_service._generate_audio.side_effect = Exception("Generation failed")
|
||||
test_request = {
|
||||
|
@ -124,54 +131,173 @@ def test_openai_speech_generation_error(mock_tts_service):
|
|||
"speed": 1.0,
|
||||
"stream": False # Explicitly disable streaming
|
||||
}
|
||||
response = client.post("/v1/audio/speech", json=test_request)
|
||||
response = await async_client.post("/v1/audio/speech", json=test_request)
|
||||
assert response.status_code == 500
|
||||
assert "Generation failed" in response.json()["detail"]["message"]
|
||||
|
||||
|
||||
def test_combine_voices_success(mock_tts_service):
|
||||
"""Test successful voice combination"""
|
||||
@pytest.mark.asyncio
|
||||
async def test_combine_voices_list_success(mock_tts_service, async_client):
|
||||
"""Test successful voice combination using list format"""
|
||||
test_voices = ["af_bella", "af_sarah"]
|
||||
mock_tts_service.combine_voices.return_value = "af_bella_af_sarah"
|
||||
mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")
|
||||
|
||||
response = client.post("/v1/audio/voices/combine", json=test_voices)
|
||||
response = await async_client.post("/v1/audio/voices/combine", json=test_voices)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["voice"] == "af_bella_af_sarah"
|
||||
mock_tts_service.combine_voices.assert_called_once_with(voices=test_voices)
|
||||
|
||||
|
||||
def test_combine_voices_single_voice(mock_tts_service):
|
||||
"""Test combining single voice returns default voice"""
|
||||
@pytest.mark.asyncio
|
||||
async def test_combine_voices_string_success(mock_tts_service, async_client):
|
||||
"""Test successful voice combination using string format with +"""
|
||||
test_voices = "af_bella+af_sarah"
|
||||
mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")
|
||||
|
||||
response = await async_client.post("/v1/audio/voices/combine", json=test_voices)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["voice"] == "af_bella_af_sarah"
|
||||
mock_tts_service.combine_voices.assert_called_once_with(voices=["af_bella", "af_sarah"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_combine_voices_single_voice(mock_tts_service, async_client):
|
||||
"""Test combining single voice returns same voice"""
|
||||
test_voices = ["af_bella"]
|
||||
mock_tts_service.combine_voices.return_value = "af"
|
||||
|
||||
response = client.post("/v1/audio/voices/combine", json=test_voices)
|
||||
|
||||
response = await async_client.post("/v1/audio/voices/combine", json=test_voices)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["voice"] == "af"
|
||||
assert response.json()["voice"] == "af_bella"
|
||||
|
||||
|
||||
def test_combine_voices_empty_list(mock_tts_service):
|
||||
"""Test combining empty voice list returns default voice"""
|
||||
@pytest.mark.asyncio
|
||||
async def test_combine_voices_empty_list(mock_tts_service, async_client):
|
||||
"""Test combining empty voice list returns error"""
|
||||
test_voices = []
|
||||
mock_tts_service.combine_voices.return_value = "af"
|
||||
|
||||
response = client.post("/v1/audio/voices/combine", json=test_voices)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["voice"] == "af"
|
||||
response = await async_client.post("/v1/audio/voices/combine", json=test_voices)
|
||||
assert response.status_code == 400
|
||||
assert "No voices provided" in response.json()["detail"]["message"]
|
||||
|
||||
|
||||
def test_combine_voices_error(mock_tts_service):
|
||||
@pytest.mark.asyncio
|
||||
async def test_combine_voices_error(mock_tts_service, async_client):
|
||||
"""Test error handling in voice combination"""
|
||||
test_voices = ["af_bella", "af_sarah"]
|
||||
mock_tts_service.combine_voices.side_effect = Exception("Combination failed")
|
||||
|
||||
response = client.post("/v1/audio/voices/combine", json=test_voices)
|
||||
mock_tts_service.combine_voices = AsyncMock(side_effect=Exception("Combination failed"))
|
||||
|
||||
response = await async_client.post("/v1/audio/voices/combine", json=test_voices)
|
||||
assert response.status_code == 500
|
||||
assert "Combination failed" in response.json()["detail"]["message"]
|
||||
assert "Server error" in response.json()["detail"]["message"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_speech_with_combined_voice(mock_tts_service, mock_audio_service, async_client):
|
||||
"""Test speech generation with combined voice using + syntax"""
|
||||
mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")
|
||||
|
||||
test_request = {
|
||||
"model": "kokoro",
|
||||
"input": "Hello world",
|
||||
"voice": "af_bella+af_sarah",
|
||||
"response_format": "wav",
|
||||
"speed": 1.0,
|
||||
"stream": False
|
||||
}
|
||||
|
||||
response = await async_client.post("/v1/audio/speech", json=test_request)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "audio/wav"
|
||||
mock_tts_service._generate_audio.assert_called_once_with(
|
||||
text="Hello world",
|
||||
voice="af_bella_af_sarah",
|
||||
speed=1.0,
|
||||
stitch_long_output=True
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_speech_with_whitespace_in_voice(mock_tts_service, mock_audio_service, async_client):
|
||||
"""Test speech generation with whitespace in voice combination"""
|
||||
mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")
|
||||
|
||||
test_request = {
|
||||
"model": "kokoro",
|
||||
"input": "Hello world",
|
||||
"voice": " af_bella + af_sarah ",
|
||||
"response_format": "wav",
|
||||
"speed": 1.0,
|
||||
"stream": False
|
||||
}
|
||||
|
||||
response = await async_client.post("/v1/audio/speech", json=test_request)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "audio/wav"
|
||||
mock_tts_service.combine_voices.assert_called_once_with(voices=["af_bella", "af_sarah"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_speech_with_empty_voice_combination(mock_tts_service, async_client):
|
||||
"""Test speech generation with empty voice combination"""
|
||||
test_request = {
|
||||
"model": "kokoro",
|
||||
"input": "Hello world",
|
||||
"voice": "+",
|
||||
"response_format": "wav",
|
||||
"speed": 1.0,
|
||||
"stream": False
|
||||
}
|
||||
|
||||
response = await async_client.post("/v1/audio/speech", json=test_request)
|
||||
assert response.status_code == 400
|
||||
assert "No voices provided" in response.json()["detail"]["message"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_speech_with_invalid_combined_voice(mock_tts_service, async_client):
|
||||
"""Test speech generation with invalid voice combination"""
|
||||
test_request = {
|
||||
"model": "kokoro",
|
||||
"input": "Hello world",
|
||||
"voice": "invalid+combination",
|
||||
"response_format": "wav",
|
||||
"speed": 1.0,
|
||||
"stream": False
|
||||
}
|
||||
|
||||
response = await async_client.post("/v1/audio/speech", json=test_request)
|
||||
assert response.status_code == 400
|
||||
assert "not found" in response.json()["detail"]["message"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_speech_streaming_with_combined_voice(mock_tts_service, async_client):
|
||||
"""Test streaming speech with combined voice using + syntax"""
|
||||
mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")
|
||||
|
||||
test_request = {
|
||||
"model": "kokoro",
|
||||
"input": "Hello world",
|
||||
"voice": "af_bella+af_sarah",
|
||||
"response_format": "mp3",
|
||||
"stream": True
|
||||
}
|
||||
|
||||
# Create streaming mock
|
||||
async def mock_stream(*args, **kwargs):
|
||||
for chunk in [b"mp3header", b"mp3data"]:
|
||||
yield chunk
|
||||
mock_tts_service.generate_audio_stream = mock_stream
|
||||
|
||||
# Add streaming header
|
||||
headers = {"x-raw-response": "stream"}
|
||||
response = await async_client.post("/v1/audio/speech", json=test_request, headers=headers)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "audio/mpeg"
|
||||
assert response.headers["content-disposition"] == "attachment; filename=speech.mp3"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -197,9 +323,6 @@ async def test_openai_speech_pcm_streaming(mock_tts_service, async_client):
|
|||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "audio/pcm"
|
||||
# Just verify status and content type
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "audio/pcm"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -226,10 +349,6 @@ async def test_openai_speech_streaming_mp3(mock_tts_service, async_client):
|
|||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "audio/mpeg"
|
||||
assert response.headers["content-disposition"] == "attachment; filename=speech.mp3"
|
||||
# Just verify status and content type
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "audio/mpeg"
|
||||
assert response.headers["content-disposition"] == "attachment; filename=speech.mp3"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -255,6 +374,3 @@ async def test_openai_speech_streaming_generator(mock_tts_service, async_client)
|
|||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "audio/pcm"
|
||||
# Just verify status and content type
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "audio/pcm"
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
"""Tests for TTSService"""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
from unittest.mock import MagicMock, call, patch, AsyncMock
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import pytest
|
||||
from onnxruntime import InferenceSession
|
||||
from aiofiles import threadpool
|
||||
|
||||
from api.src.core.config import settings
|
||||
from api.src.services.tts_model import TTSModel
|
||||
|
@ -38,27 +39,33 @@ def test_audio_to_bytes(tts_service, sample_audio):
|
|||
assert len(audio_bytes) > 0
|
||||
|
||||
|
||||
@patch("os.listdir")
|
||||
@patch("os.path.join")
|
||||
def test_list_voices(mock_join, mock_listdir, tts_service):
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_voices(tts_service):
|
||||
"""Test listing available voices"""
|
||||
mock_listdir.return_value = ["voice1.pt", "voice2.pt", "not_a_voice.txt"]
|
||||
mock_join.return_value = "/fake/path"
|
||||
|
||||
voices = tts_service.list_voices()
|
||||
assert len(voices) == 2
|
||||
assert "voice1" in voices
|
||||
assert "voice2" in voices
|
||||
assert "not_a_voice" not in voices
|
||||
# Mock os.listdir to return test files
|
||||
with patch('os.listdir', return_value=["voice1.pt", "voice2.pt", "not_a_voice.txt"]):
|
||||
# Register mock with threadpool
|
||||
async_listdir = AsyncMock(return_value=["voice1.pt", "voice2.pt", "not_a_voice.txt"])
|
||||
threadpool.async_wrap = MagicMock(return_value=async_listdir)
|
||||
|
||||
voices = await tts_service.list_voices()
|
||||
assert len(voices) == 2
|
||||
assert "voice1" in voices
|
||||
assert "voice2" in voices
|
||||
assert "not_a_voice" not in voices
|
||||
|
||||
|
||||
@patch("os.listdir")
|
||||
def test_list_voices_error(mock_listdir, tts_service):
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_voices_error(tts_service):
|
||||
"""Test error handling in list_voices"""
|
||||
mock_listdir.side_effect = Exception("Failed to list directory")
|
||||
|
||||
voices = tts_service.list_voices()
|
||||
assert voices == []
|
||||
# Mock os.listdir to raise an exception
|
||||
with patch('os.listdir', side_effect=Exception("Failed to list directory")):
|
||||
# Register mock with threadpool
|
||||
async_listdir = AsyncMock(side_effect=Exception("Failed to list directory"))
|
||||
threadpool.async_wrap = MagicMock(return_value=async_listdir)
|
||||
|
||||
voices = await tts_service.list_voices()
|
||||
assert voices == []
|
||||
|
||||
|
||||
def mock_model_setup(cuda_available=False):
|
||||
|
@ -176,7 +183,8 @@ def test_save_audio(tts_service, sample_audio, tmp_path):
|
|||
assert os.path.getsize(output_path) > 0
|
||||
|
||||
|
||||
def test_combine_voices(tts_service):
|
||||
@pytest.mark.asyncio
|
||||
async def test_combine_voices(tts_service):
|
||||
"""Test combining multiple voices"""
|
||||
# Setup mocks for torch operations
|
||||
with patch('torch.load', return_value=torch.tensor([1.0, 2.0])), \
|
||||
|
@ -186,20 +194,21 @@ def test_combine_voices(tts_service):
|
|||
patch('os.path.exists', return_value=True):
|
||||
|
||||
# Test combining two voices
|
||||
result = tts_service.combine_voices(["voice1", "voice2"])
|
||||
result = await tts_service.combine_voices(["voice1", "voice2"])
|
||||
|
||||
assert result == "voice1_voice2"
|
||||
|
||||
|
||||
def test_combine_voices_invalid_input(tts_service):
|
||||
@pytest.mark.asyncio
|
||||
async def test_combine_voices_invalid_input(tts_service):
|
||||
"""Test combining voices with invalid input"""
|
||||
# Test with empty list
|
||||
with pytest.raises(ValueError, match="At least 2 voices are required"):
|
||||
tts_service.combine_voices([])
|
||||
await tts_service.combine_voices([])
|
||||
|
||||
# Test with single voice
|
||||
with pytest.raises(ValueError, match="At least 2 voices are required"):
|
||||
tts_service.combine_voices(["voice1"])
|
||||
await tts_service.combine_voices(["voice1"])
|
||||
|
||||
|
||||
@patch("api.src.services.tts_service.TTSService._get_voice_path")
|
||||
|
|
|
@ -34,7 +34,7 @@ def stream_to_speakers() -> None:
|
|||
|
||||
with openai.audio.speech.with_streaming_response.create(
|
||||
model="kokoro",
|
||||
voice="af",
|
||||
voice="af_sky+af_bella+bm_george",
|
||||
response_format="pcm", # similar to WAV, but without a header chunk at the start.
|
||||
input="""My dear sir, that is just where you are wrong. That is just where the whole world has gone wrong. We are always getting away from the present moment. Our mental existences, which are immaterial and have no dimensions, are passing along the Time-Dimension with a uniform velocity from the cradle to the grave. Just as we should travel down if we began our existence fifty miles above the earth’s surface""",
|
||||
) as response:
|
||||
|
|
BIN
examples/speech.mp3
Normal file
BIN
examples/speech.mp3
Normal file
Binary file not shown.
|
@ -20,6 +20,7 @@ phonemizer==3.3.0
|
|||
regex==2024.11.6
|
||||
|
||||
# Utilities
|
||||
aiofiles==24.1.0
|
||||
tqdm==4.67.1
|
||||
requests==2.32.3
|
||||
munch==4.0.0
|
||||
|
|
Loading…
Add table
Reference in a new issue