- Added support for combining voices via any endpoint

- Updated the `process_voices` function to handle both string and list formats for voice input.
This commit is contained in:
remsky 2025-01-07 03:50:08 -07:00
parent bb1f9b54ba
commit 130b084cce
10 changed files with 259 additions and 104 deletions

@ -1 +1 @@
Subproject commit c97b7bbc3e60f447383c79b2f94fee861ff156ac Subproject commit 3095858c40fc22e28c46429da9340dfda1f8cf28

View file

@ -3,17 +3,17 @@
</p> </p>
# Kokoro TTS API # Kokoro TTS API
[![Tests](https://img.shields.io/badge/tests-98%20passed-darkgreen)]() [![Tests](https://img.shields.io/badge/tests-105%20passed-darkgreen)]()
[![Coverage](https://img.shields.io/badge/coverage-73%25-darkgreen)]() [![Coverage](https://img.shields.io/badge/coverage-74%25-darkgreen)]()
[![Tested at Model Commit](https://img.shields.io/badge/last--tested--model--commit-a67f113-blue)](https://huggingface.co/hexgrad/Kokoro-82M/tree/c3b0d86e2a980e027ef71c28819ea02e351c2667) [![Try on Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Try%20on-Spaces-blue)](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero) [![Tested at Model Commit](https://img.shields.io/badge/last--tested--model--commit-a67f113-blue)](https://huggingface.co/hexgrad/Kokoro-82M/tree/c3b0d86e2a980e027ef71c28819ea02e351c2667) [![Try on Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Try%20on-Spaces-blue)](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero)
Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model
- OpenAI-compatible Speech endpoint, with voice combination functionality - OpenAI-compatible Speech endpoint, with inline voice combination functionality
- NVIDIA GPU accelerated inference (or CPU) option - NVIDIA GPU accelerated inference (or CPU) option
- very fast generation time - very fast generation time
- ~ 35x real time speed via 4060Ti, ~300ms latency - ~ 35x real time speed via 4060Ti, ~300ms latency
- ~ 6x real time spead via M3 Pro CPU, ~1000ms latency - ~ 6x real time spead via M3 Pro CPU, ~1000ms latency
- streaming support w/ variable chunking control latency & artifacts - streaming support w/ variable chunking to control latency & artifacts
- simple audio generation web ui utility - simple audio generation web ui utility
@ -39,7 +39,7 @@ The service can be accessed through either the API endpoints or the Gradio web i
response = client.audio.speech.create( response = client.audio.speech.create(
model="kokoro", model="kokoro",
voice="af_bella", voice="af_sky+af_bella", #single or multiple voicepack combo
input="Hello world!", input="Hello world!",
response_format="mp3" response_format="mp3"
) )
@ -61,7 +61,7 @@ from openai import OpenAI
client = OpenAI(base_url="http://localhost:8880", api_key="not-needed") client = OpenAI(base_url="http://localhost:8880", api_key="not-needed")
response = client.audio.speech.create( response = client.audio.speech.create(
model="kokoro", # Not used but required for compatibility, also accepts library defaults model="kokoro", # Not used but required for compatibility, also accepts library defaults
voice="af_bella", voice="af_bella+af_sky",
input="Hello world!", input="Hello world!",
response_format="mp3" response_format="mp3"
) )
@ -105,6 +105,7 @@ python examples/test_all_voices.py # Test all available voices
- Averages model weights of any existing voicepacks - Averages model weights of any existing voicepacks
- Saves generated voicepacks for future use - Saves generated voicepacks for future use
- (new) Available through any endpoint, simply concatenate desired packs with "+"
Combine voices and generate audio: Combine voices and generate audio:
```python ```python
@ -119,12 +120,12 @@ response = requests.post(
) )
combined_voice = response.json()["voice"] combined_voice = response.json()["voice"]
# Generate audio with combined voice # Generate audio with combined voice (or, simply pass multiple directly with `+` )
response = requests.post( response = requests.post(
"http://localhost:8880/v1/audio/speech", "http://localhost:8880/v1/audio/speech",
json={ json={
"input": "Hello world!", "input": "Hello world!",
"voice": combined_voice, "voice": combined_voice, # or skip the above step with f"{voices[0]}+{voices[1]}"
"response_format": "mp3" "response_format": "mp3"
} }
) )

View file

@ -1,4 +1,4 @@
from typing import List from typing import List, Union
from loguru import logger from loguru import logger
from fastapi import Depends, Response, APIRouter, HTTPException from fastapi import Depends, Response, APIRouter, HTTPException
@ -20,18 +20,43 @@ def get_tts_service() -> TTSService:
return TTSService() # Initialize TTSService with default settings return TTSService() # Initialize TTSService with default settings
async def process_voices(voice_input: Union[str, List[str]], tts_service: TTSService) -> str:
"""Process voice input into a combined voice, handling both string and list formats"""
# Convert input to list of voices
if isinstance(voice_input, str):
voices = [v.strip() for v in voice_input.split("+") if v.strip()]
else:
voices = voice_input
if not voices:
raise ValueError("No voices provided")
# Check if all voices exist
available_voices = await tts_service.list_voices()
for voice in voices:
if voice not in available_voices:
raise ValueError(f"Voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}")
# If single voice, return it directly
if len(voices) == 1:
return voices[0]
# Otherwise combine voices
return await tts_service.combine_voices(voices=voices)
async def stream_audio_chunks(tts_service: TTSService, request: OpenAISpeechRequest) -> AsyncGenerator[bytes, None]: async def stream_audio_chunks(tts_service: TTSService, request: OpenAISpeechRequest) -> AsyncGenerator[bytes, None]:
"""Stream audio chunks as they're generated""" """Stream audio chunks as they're generated"""
voice_to_use = await process_voices(request.voice, tts_service)
async for chunk in tts_service.generate_audio_stream( async for chunk in tts_service.generate_audio_stream(
text=request.input, text=request.input,
voice=request.voice, voice=voice_to_use,
speed=request.speed, speed=request.speed,
output_format=request.response_format output_format=request.response_format
): ):
yield chunk yield chunk
@router.post("/audio/speech") @router.post("/audio/speech")
async def create_speech( async def create_speech(
request: OpenAISpeechRequest, request: OpenAISpeechRequest,
@ -40,12 +65,8 @@ async def create_speech(
): ):
"""OpenAI-compatible endpoint for text-to-speech""" """OpenAI-compatible endpoint for text-to-speech"""
try: try:
# Validate voice exists # Process voice combination and validate
available_voices = tts_service.list_voices() voice_to_use = await process_voices(request.voice, tts_service)
if request.voice not in available_voices:
raise ValueError(
f"Voice '{request.voice}' not found. Available voices: {', '.join(sorted(available_voices))}"
)
# Set content type based on format # Set content type based on format
content_type = { content_type = {
@ -73,7 +94,7 @@ async def create_speech(
# Generate complete audio # Generate complete audio
audio, _ = tts_service._generate_audio( audio, _ = tts_service._generate_audio(
text=request.input, text=request.input,
voice=request.voice, voice=voice_to_use,
speed=request.speed, speed=request.speed,
stitch_long_output=True, stitch_long_output=True,
) )
@ -111,7 +132,7 @@ async def create_speech(
async def list_voices(tts_service: TTSService = Depends(get_tts_service)): async def list_voices(tts_service: TTSService = Depends(get_tts_service)):
"""List all available voices for text-to-speech""" """List all available voices for text-to-speech"""
try: try:
voices = tts_service.list_voices() voices = await tts_service.list_voices()
return {"voices": voices} return {"voices": voices}
except Exception as e: except Exception as e:
logger.error(f"Error listing voices: {str(e)}") logger.error(f"Error listing voices: {str(e)}")
@ -120,12 +141,13 @@ async def list_voices(tts_service: TTSService = Depends(get_tts_service)):
@router.post("/audio/voices/combine") @router.post("/audio/voices/combine")
async def combine_voices( async def combine_voices(
request: List[str], tts_service: TTSService = Depends(get_tts_service) request: Union[str, List[str]], tts_service: TTSService = Depends(get_tts_service)
): ):
"""Combine multiple voices into a new voice. """Combine multiple voices into a new voice.
Args: Args:
request: List of voice names to combine request: Either a string with voices separated by + (e.g. "voice1+voice2")
or a list of voice names to combine
Returns: Returns:
Dict with combined voice name and list of all available voices Dict with combined voice name and list of all available voices
@ -136,8 +158,8 @@ async def combine_voices(
- 500: Server error (file system issues, combination failed) - 500: Server error (file system issues, combination failed)
""" """
try: try:
combined_voice = tts_service.combine_voices(voices=request) combined_voice = await process_voices(request, tts_service)
voices = tts_service.list_voices() voices = await tts_service.list_voices()
return {"voices": voices, "voice": combined_voice} return {"voices": voices, "voice": combined_voice}
except ValueError as e: except ValueError as e:
@ -146,14 +168,8 @@ async def combine_voices(
status_code=400, detail={"error": "Invalid request", "message": str(e)} status_code=400, detail={"error": "Invalid request", "message": str(e)}
) )
except RuntimeError as e: except Exception as e:
logger.error(f"Server error during voice combination: {str(e)}") logger.error(f"Server error during voice combination: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=500, detail={"error": "Server error", "message": str(e)} status_code=500, detail={"error": "Server error", "message": "Server error"}
)
except Exception as e:
logger.error(f"Unexpected error during voice combination: {str(e)}")
raise HTTPException(
status_code=500, detail={"error": "Unexpected error", "message": str(e)}
) )

View file

@ -1,9 +1,11 @@
import aiofiles
import io import io
import os import os
import re import re
import time import time
from typing import List, Tuple, Optional from typing import List, Tuple, Optional
from functools import lru_cache from functools import lru_cache
from aiofiles import threadpool
import numpy as np import numpy as np
import torch import torch
@ -211,7 +213,7 @@ class TTSService:
wavfile.write(buffer, 24000, audio) wavfile.write(buffer, 24000, audio)
return buffer.getvalue() return buffer.getvalue()
def combine_voices(self, voices: List[str]) -> str: async def combine_voices(self, voices: List[str]) -> str:
"""Combine multiple voices into a new voice""" """Combine multiple voices into a new voice"""
if len(voices) < 2: if len(voices) < 2:
raise ValueError("At least 2 voices are required for combination") raise ValueError("At least 2 voices are required for combination")
@ -252,11 +254,13 @@ class TTSService:
raise RuntimeError(f"Error combining voices: {str(e)}") raise RuntimeError(f"Error combining voices: {str(e)}")
raise raise
def list_voices(self) -> List[str]: async def list_voices(self) -> List[str]:
"""List all available voices""" """List all available voices"""
voices = [] voices = []
try: try:
for file in os.listdir(TTSModel.VOICES_DIR): # Use os.listdir in a thread pool
files = await threadpool.async_wrap(os.listdir)(TTSModel.VOICES_DIR)
for file in files:
if file.endswith(".pt"): if file.endswith(".pt"):
voices.append(file[:-3]) # Remove .pt extension voices.append(file[:-3]) # Remove .pt extension
except Exception as e: except Exception as e:

View file

@ -1,9 +1,17 @@
from enum import Enum from enum import Enum
from typing import Literal from typing import Literal, Union, List
from pydantic import Field, BaseModel from pydantic import Field, BaseModel
class VoiceCombineRequest(BaseModel):
"""Request schema for voice combination endpoint that accepts either a string with + or a list"""
voices: Union[str, List[str]] = Field(
...,
description="Either a string with voices separated by + (e.g. 'voice1+voice2') or a list of voice names to combine"
)
class TTSStatus(str, Enum): class TTSStatus(str, Enum):
PENDING = "pending" PENDING = "pending"
PROCESSING = "processing" PROCESSING = "processing"

View file

@ -29,7 +29,9 @@ def mock_tts_service(monkeypatch):
for chunk in [b"chunk1", b"chunk2"]: for chunk in [b"chunk1", b"chunk2"]:
yield chunk yield chunk
mock_service.generate_audio_stream = mock_stream mock_service.generate_audio_stream = mock_stream
mock_service.list_voices.return_value = [
# Create async mocks
mock_service.list_voices = AsyncMock(return_value=[
"af", "af",
"bm_lewis", "bm_lewis",
"bf_isabella", "bf_isabella",
@ -39,7 +41,8 @@ def mock_tts_service(monkeypatch):
"am_adam", "am_adam",
"am_michael", "am_michael",
"bm_george", "bm_george",
] ])
mock_service.combine_voices = AsyncMock()
monkeypatch.setattr( monkeypatch.setattr(
"api.src.routers.openai_compatible.TTSService", "api.src.routers.openai_compatible.TTSService",
lambda *args, **kwargs: mock_service, lambda *args, **kwargs: mock_service,
@ -64,7 +67,8 @@ def test_health_check():
assert response.json() == {"status": "healthy"} assert response.json() == {"status": "healthy"}
def test_openai_speech_endpoint(mock_tts_service, mock_audio_service): @pytest.mark.asyncio
async def test_openai_speech_endpoint(mock_tts_service, mock_audio_service, async_client):
"""Test the OpenAI-compatible speech endpoint""" """Test the OpenAI-compatible speech endpoint"""
test_request = { test_request = {
"model": "kokoro", "model": "kokoro",
@ -74,7 +78,7 @@ def test_openai_speech_endpoint(mock_tts_service, mock_audio_service):
"speed": 1.0, "speed": 1.0,
"stream": False # Explicitly disable streaming "stream": False # Explicitly disable streaming
} }
response = client.post("/v1/audio/speech", json=test_request) response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 200 assert response.status_code == 200
assert response.headers["content-type"] == "audio/wav" assert response.headers["content-type"] == "audio/wav"
assert response.headers["content-disposition"] == "attachment; filename=speech.wav" assert response.headers["content-disposition"] == "attachment; filename=speech.wav"
@ -84,7 +88,8 @@ def test_openai_speech_endpoint(mock_tts_service, mock_audio_service):
assert response.content == b"converted mock audio data" assert response.content == b"converted mock audio data"
def test_openai_speech_invalid_voice(mock_tts_service): @pytest.mark.asyncio
async def test_openai_speech_invalid_voice(mock_tts_service, async_client):
"""Test the OpenAI-compatible speech endpoint with invalid voice""" """Test the OpenAI-compatible speech endpoint with invalid voice"""
test_request = { test_request = {
"model": "kokoro", "model": "kokoro",
@ -94,12 +99,13 @@ def test_openai_speech_invalid_voice(mock_tts_service):
"speed": 1.0, "speed": 1.0,
"stream": False # Explicitly disable streaming "stream": False # Explicitly disable streaming
} }
response = client.post("/v1/audio/speech", json=test_request) response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 400 # Bad request assert response.status_code == 400 # Bad request
assert "not found" in response.json()["detail"]["message"] assert "not found" in response.json()["detail"]["message"]
def test_openai_speech_invalid_speed(mock_tts_service): @pytest.mark.asyncio
async def test_openai_speech_invalid_speed(mock_tts_service, async_client):
"""Test the OpenAI-compatible speech endpoint with invalid speed""" """Test the OpenAI-compatible speech endpoint with invalid speed"""
test_request = { test_request = {
"model": "kokoro", "model": "kokoro",
@ -109,11 +115,12 @@ def test_openai_speech_invalid_speed(mock_tts_service):
"speed": -1.0, # Invalid speed "speed": -1.0, # Invalid speed
"stream": False # Explicitly disable streaming "stream": False # Explicitly disable streaming
} }
response = client.post("/v1/audio/speech", json=test_request) response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 422 # Validation error assert response.status_code == 422 # Validation error
def test_openai_speech_generation_error(mock_tts_service): @pytest.mark.asyncio
async def test_openai_speech_generation_error(mock_tts_service, async_client):
"""Test error handling in speech generation""" """Test error handling in speech generation"""
mock_tts_service._generate_audio.side_effect = Exception("Generation failed") mock_tts_service._generate_audio.side_effect = Exception("Generation failed")
test_request = { test_request = {
@ -124,54 +131,173 @@ def test_openai_speech_generation_error(mock_tts_service):
"speed": 1.0, "speed": 1.0,
"stream": False # Explicitly disable streaming "stream": False # Explicitly disable streaming
} }
response = client.post("/v1/audio/speech", json=test_request) response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 500 assert response.status_code == 500
assert "Generation failed" in response.json()["detail"]["message"] assert "Generation failed" in response.json()["detail"]["message"]
def test_combine_voices_success(mock_tts_service): @pytest.mark.asyncio
"""Test successful voice combination""" async def test_combine_voices_list_success(mock_tts_service, async_client):
"""Test successful voice combination using list format"""
test_voices = ["af_bella", "af_sarah"] test_voices = ["af_bella", "af_sarah"]
mock_tts_service.combine_voices.return_value = "af_bella_af_sarah" mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")
response = client.post("/v1/audio/voices/combine", json=test_voices) response = await async_client.post("/v1/audio/voices/combine", json=test_voices)
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["voice"] == "af_bella_af_sarah" assert response.json()["voice"] == "af_bella_af_sarah"
mock_tts_service.combine_voices.assert_called_once_with(voices=test_voices) mock_tts_service.combine_voices.assert_called_once_with(voices=test_voices)
def test_combine_voices_single_voice(mock_tts_service): @pytest.mark.asyncio
"""Test combining single voice returns default voice""" async def test_combine_voices_string_success(mock_tts_service, async_client):
"""Test successful voice combination using string format with +"""
test_voices = "af_bella+af_sarah"
mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")
response = await async_client.post("/v1/audio/voices/combine", json=test_voices)
assert response.status_code == 200
assert response.json()["voice"] == "af_bella_af_sarah"
mock_tts_service.combine_voices.assert_called_once_with(voices=["af_bella", "af_sarah"])
@pytest.mark.asyncio
async def test_combine_voices_single_voice(mock_tts_service, async_client):
"""Test combining single voice returns same voice"""
test_voices = ["af_bella"] test_voices = ["af_bella"]
mock_tts_service.combine_voices.return_value = "af" response = await async_client.post("/v1/audio/voices/combine", json=test_voices)
response = client.post("/v1/audio/voices/combine", json=test_voices)
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["voice"] == "af" assert response.json()["voice"] == "af_bella"
def test_combine_voices_empty_list(mock_tts_service): @pytest.mark.asyncio
"""Test combining empty voice list returns default voice""" async def test_combine_voices_empty_list(mock_tts_service, async_client):
"""Test combining empty voice list returns error"""
test_voices = [] test_voices = []
mock_tts_service.combine_voices.return_value = "af" response = await async_client.post("/v1/audio/voices/combine", json=test_voices)
assert response.status_code == 400
response = client.post("/v1/audio/voices/combine", json=test_voices) assert "No voices provided" in response.json()["detail"]["message"]
assert response.status_code == 200
assert response.json()["voice"] == "af"
def test_combine_voices_error(mock_tts_service): @pytest.mark.asyncio
async def test_combine_voices_error(mock_tts_service, async_client):
"""Test error handling in voice combination""" """Test error handling in voice combination"""
test_voices = ["af_bella", "af_sarah"] test_voices = ["af_bella", "af_sarah"]
mock_tts_service.combine_voices.side_effect = Exception("Combination failed") mock_tts_service.combine_voices = AsyncMock(side_effect=Exception("Combination failed"))
response = client.post("/v1/audio/voices/combine", json=test_voices)
response = await async_client.post("/v1/audio/voices/combine", json=test_voices)
assert response.status_code == 500 assert response.status_code == 500
assert "Combination failed" in response.json()["detail"]["message"] assert "Server error" in response.json()["detail"]["message"]
@pytest.mark.asyncio
async def test_speech_with_combined_voice(mock_tts_service, mock_audio_service, async_client):
"""Test speech generation with combined voice using + syntax"""
mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": "af_bella+af_sarah",
"response_format": "wav",
"speed": 1.0,
"stream": False
}
response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 200
assert response.headers["content-type"] == "audio/wav"
mock_tts_service._generate_audio.assert_called_once_with(
text="Hello world",
voice="af_bella_af_sarah",
speed=1.0,
stitch_long_output=True
)
@pytest.mark.asyncio
async def test_speech_with_whitespace_in_voice(mock_tts_service, mock_audio_service, async_client):
"""Test speech generation with whitespace in voice combination"""
mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": " af_bella + af_sarah ",
"response_format": "wav",
"speed": 1.0,
"stream": False
}
response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 200
assert response.headers["content-type"] == "audio/wav"
mock_tts_service.combine_voices.assert_called_once_with(voices=["af_bella", "af_sarah"])
@pytest.mark.asyncio
async def test_speech_with_empty_voice_combination(mock_tts_service, async_client):
"""Test speech generation with empty voice combination"""
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": "+",
"response_format": "wav",
"speed": 1.0,
"stream": False
}
response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 400
assert "No voices provided" in response.json()["detail"]["message"]
@pytest.mark.asyncio
async def test_speech_with_invalid_combined_voice(mock_tts_service, async_client):
"""Test speech generation with invalid voice combination"""
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": "invalid+combination",
"response_format": "wav",
"speed": 1.0,
"stream": False
}
response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 400
assert "not found" in response.json()["detail"]["message"]
@pytest.mark.asyncio
async def test_speech_streaming_with_combined_voice(mock_tts_service, async_client):
"""Test streaming speech with combined voice using + syntax"""
mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": "af_bella+af_sarah",
"response_format": "mp3",
"stream": True
}
# Create streaming mock
async def mock_stream(*args, **kwargs):
for chunk in [b"mp3header", b"mp3data"]:
yield chunk
mock_tts_service.generate_audio_stream = mock_stream
# Add streaming header
headers = {"x-raw-response": "stream"}
response = await async_client.post("/v1/audio/speech", json=test_request, headers=headers)
assert response.status_code == 200
assert response.headers["content-type"] == "audio/mpeg"
assert response.headers["content-disposition"] == "attachment; filename=speech.mp3"
@pytest.mark.asyncio @pytest.mark.asyncio
@ -197,9 +323,6 @@ async def test_openai_speech_pcm_streaming(mock_tts_service, async_client):
assert response.status_code == 200 assert response.status_code == 200
assert response.headers["content-type"] == "audio/pcm" assert response.headers["content-type"] == "audio/pcm"
# Just verify status and content type
assert response.status_code == 200
assert response.headers["content-type"] == "audio/pcm"
@pytest.mark.asyncio @pytest.mark.asyncio
@ -226,10 +349,6 @@ async def test_openai_speech_streaming_mp3(mock_tts_service, async_client):
assert response.status_code == 200 assert response.status_code == 200
assert response.headers["content-type"] == "audio/mpeg" assert response.headers["content-type"] == "audio/mpeg"
assert response.headers["content-disposition"] == "attachment; filename=speech.mp3" assert response.headers["content-disposition"] == "attachment; filename=speech.mp3"
# Just verify status and content type
assert response.status_code == 200
assert response.headers["content-type"] == "audio/mpeg"
assert response.headers["content-disposition"] == "attachment; filename=speech.mp3"
@pytest.mark.asyncio @pytest.mark.asyncio
@ -255,6 +374,3 @@ async def test_openai_speech_streaming_generator(mock_tts_service, async_client)
assert response.status_code == 200 assert response.status_code == 200
assert response.headers["content-type"] == "audio/pcm" assert response.headers["content-type"] == "audio/pcm"
# Just verify status and content type
assert response.status_code == 200
assert response.headers["content-type"] == "audio/pcm"

View file

@ -1,12 +1,13 @@
"""Tests for TTSService""" """Tests for TTSService"""
import os import os
from unittest.mock import MagicMock, call, patch from unittest.mock import MagicMock, call, patch, AsyncMock
import numpy as np import numpy as np
import torch import torch
import pytest import pytest
from onnxruntime import InferenceSession from onnxruntime import InferenceSession
from aiofiles import threadpool
from api.src.core.config import settings from api.src.core.config import settings
from api.src.services.tts_model import TTSModel from api.src.services.tts_model import TTSModel
@ -38,27 +39,33 @@ def test_audio_to_bytes(tts_service, sample_audio):
assert len(audio_bytes) > 0 assert len(audio_bytes) > 0
@patch("os.listdir") @pytest.mark.asyncio
@patch("os.path.join") async def test_list_voices(tts_service):
def test_list_voices(mock_join, mock_listdir, tts_service):
"""Test listing available voices""" """Test listing available voices"""
mock_listdir.return_value = ["voice1.pt", "voice2.pt", "not_a_voice.txt"] # Mock os.listdir to return test files
mock_join.return_value = "/fake/path" with patch('os.listdir', return_value=["voice1.pt", "voice2.pt", "not_a_voice.txt"]):
# Register mock with threadpool
voices = tts_service.list_voices() async_listdir = AsyncMock(return_value=["voice1.pt", "voice2.pt", "not_a_voice.txt"])
assert len(voices) == 2 threadpool.async_wrap = MagicMock(return_value=async_listdir)
assert "voice1" in voices
assert "voice2" in voices voices = await tts_service.list_voices()
assert "not_a_voice" not in voices assert len(voices) == 2
assert "voice1" in voices
assert "voice2" in voices
assert "not_a_voice" not in voices
@patch("os.listdir") @pytest.mark.asyncio
def test_list_voices_error(mock_listdir, tts_service): async def test_list_voices_error(tts_service):
"""Test error handling in list_voices""" """Test error handling in list_voices"""
mock_listdir.side_effect = Exception("Failed to list directory") # Mock os.listdir to raise an exception
with patch('os.listdir', side_effect=Exception("Failed to list directory")):
voices = tts_service.list_voices() # Register mock with threadpool
assert voices == [] async_listdir = AsyncMock(side_effect=Exception("Failed to list directory"))
threadpool.async_wrap = MagicMock(return_value=async_listdir)
voices = await tts_service.list_voices()
assert voices == []
def mock_model_setup(cuda_available=False): def mock_model_setup(cuda_available=False):
@ -176,7 +183,8 @@ def test_save_audio(tts_service, sample_audio, tmp_path):
assert os.path.getsize(output_path) > 0 assert os.path.getsize(output_path) > 0
def test_combine_voices(tts_service): @pytest.mark.asyncio
async def test_combine_voices(tts_service):
"""Test combining multiple voices""" """Test combining multiple voices"""
# Setup mocks for torch operations # Setup mocks for torch operations
with patch('torch.load', return_value=torch.tensor([1.0, 2.0])), \ with patch('torch.load', return_value=torch.tensor([1.0, 2.0])), \
@ -186,20 +194,21 @@ def test_combine_voices(tts_service):
patch('os.path.exists', return_value=True): patch('os.path.exists', return_value=True):
# Test combining two voices # Test combining two voices
result = tts_service.combine_voices(["voice1", "voice2"]) result = await tts_service.combine_voices(["voice1", "voice2"])
assert result == "voice1_voice2" assert result == "voice1_voice2"
def test_combine_voices_invalid_input(tts_service): @pytest.mark.asyncio
async def test_combine_voices_invalid_input(tts_service):
"""Test combining voices with invalid input""" """Test combining voices with invalid input"""
# Test with empty list # Test with empty list
with pytest.raises(ValueError, match="At least 2 voices are required"): with pytest.raises(ValueError, match="At least 2 voices are required"):
tts_service.combine_voices([]) await tts_service.combine_voices([])
# Test with single voice # Test with single voice
with pytest.raises(ValueError, match="At least 2 voices are required"): with pytest.raises(ValueError, match="At least 2 voices are required"):
tts_service.combine_voices(["voice1"]) await tts_service.combine_voices(["voice1"])
@patch("api.src.services.tts_service.TTSService._get_voice_path") @patch("api.src.services.tts_service.TTSService._get_voice_path")

View file

@ -34,7 +34,7 @@ def stream_to_speakers() -> None:
with openai.audio.speech.with_streaming_response.create( with openai.audio.speech.with_streaming_response.create(
model="kokoro", model="kokoro",
voice="af", voice="af_sky+af_bella+bm_george",
response_format="pcm", # similar to WAV, but without a header chunk at the start. response_format="pcm", # similar to WAV, but without a header chunk at the start.
input="""My dear sir, that is just where you are wrong. That is just where the whole world has gone wrong. We are always getting away from the present moment. Our mental existences, which are immaterial and have no dimensions, are passing along the Time-Dimension with a uniform velocity from the cradle to the grave. Just as we should travel down if we began our existence fifty miles above the earths surface""", input="""My dear sir, that is just where you are wrong. That is just where the whole world has gone wrong. We are always getting away from the present moment. Our mental existences, which are immaterial and have no dimensions, are passing along the Time-Dimension with a uniform velocity from the cradle to the grave. Just as we should travel down if we began our existence fifty miles above the earths surface""",
) as response: ) as response:

BIN
examples/speech.mp3 Normal file

Binary file not shown.

View file

@ -20,6 +20,7 @@ phonemizer==3.3.0
regex==2024.11.6 regex==2024.11.6
# Utilities # Utilities
aiofiles==24.1.0
tqdm==4.67.1 tqdm==4.67.1
requests==2.32.3 requests==2.32.3
munch==4.0.0 munch==4.0.0