Kokoro-FastAPI/api/tests/test_tts_service.py
remsky 130b084cce - Added support for combining voices via any endpoint
- Updated the `process_voices` function to handle both string and list formats for voice input.
2025-01-07 03:50:08 -07:00

227 lines
7.7 KiB
Python

"""Tests for TTSService"""
import os
from unittest.mock import MagicMock, call, patch, AsyncMock
import numpy as np
import torch
import pytest
from onnxruntime import InferenceSession
from aiofiles import threadpool
from api.src.core.config import settings
from api.src.services.tts_model import TTSModel
from api.src.services.tts_service import TTSService
from api.src.services.tts_cpu import TTSCPUModel
from api.src.services.tts_gpu import TTSGPUModel
@pytest.fixture
def tts_service():
"""Create a TTSService instance for testing"""
return TTSService()
@pytest.fixture
def sample_audio():
"""Generate a simple sine wave for testing"""
sample_rate = 24000
duration = 0.1 # 100ms
t = np.linspace(0, duration, int(sample_rate * duration))
frequency = 440 # A4 note
return np.sin(2 * np.pi * frequency * t).astype(np.float32)
def test_audio_to_bytes(tts_service, sample_audio):
"""Test converting audio tensor to bytes"""
audio_bytes = tts_service._audio_to_bytes(sample_audio)
assert isinstance(audio_bytes, bytes)
assert len(audio_bytes) > 0
@pytest.mark.asyncio
async def test_list_voices(tts_service):
"""Test listing available voices"""
# Mock os.listdir to return test files
with patch('os.listdir', return_value=["voice1.pt", "voice2.pt", "not_a_voice.txt"]):
# Register mock with threadpool
async_listdir = AsyncMock(return_value=["voice1.pt", "voice2.pt", "not_a_voice.txt"])
threadpool.async_wrap = MagicMock(return_value=async_listdir)
voices = await tts_service.list_voices()
assert len(voices) == 2
assert "voice1" in voices
assert "voice2" in voices
assert "not_a_voice" not in voices
@pytest.mark.asyncio
async def test_list_voices_error(tts_service):
"""Test error handling in list_voices"""
# Mock os.listdir to raise an exception
with patch('os.listdir', side_effect=Exception("Failed to list directory")):
# Register mock with threadpool
async_listdir = AsyncMock(side_effect=Exception("Failed to list directory"))
threadpool.async_wrap = MagicMock(return_value=async_listdir)
voices = await tts_service.list_voices()
assert voices == []
def mock_model_setup(cuda_available=False):
"""Helper function to mock model setup"""
# Reset model state
TTSModel._instance = None
TTSModel._device = None
TTSModel._voicepacks = {}
# Create mock model instance with proper generate method
mock_model = MagicMock()
mock_model.generate.return_value = np.zeros(24000, dtype=np.float32)
TTSModel._instance = mock_model
# Set device based on CUDA availability
TTSModel._device = "cuda" if cuda_available else "cpu"
return 3 # Return voice count (including af.pt)
def test_model_initialization_cuda():
"""Test model initialization with CUDA"""
# Simulate CUDA availability
voice_count = mock_model_setup(cuda_available=True)
assert TTSModel.get_device() == "cuda"
assert voice_count == 3 # voice1.pt, voice2.pt, af.pt
def test_model_initialization_cpu():
"""Test model initialization with CPU"""
# Simulate no CUDA availability
voice_count = mock_model_setup(cuda_available=False)
assert TTSModel.get_device() == "cpu"
assert voice_count == 3 # voice1.pt, voice2.pt, af.pt
def test_generate_audio_empty_text(tts_service):
"""Test generating audio with empty text"""
with pytest.raises(ValueError, match="Text is empty after preprocessing"):
tts_service._generate_audio("", "af", 1.0)
@patch("api.src.services.tts_model.TTSModel.get_instance")
@patch("api.src.services.tts_model.TTSModel.get_device")
@patch("os.path.exists")
@patch("kokoro.normalize_text")
@patch("kokoro.phonemize")
@patch("kokoro.tokenize")
@patch("kokoro.generate")
@patch("torch.load")
def test_generate_audio_phonemize_error(
mock_torch_load,
mock_generate,
mock_tokenize,
mock_phonemize,
mock_normalize,
mock_exists,
mock_get_device,
mock_instance,
tts_service,
):
"""Test handling phonemization error"""
mock_normalize.return_value = "Test text"
mock_phonemize.side_effect = Exception("Phonemization failed")
mock_instance.return_value = (mock_generate, "cpu") # Use the same mock for consistency
mock_get_device.return_value = "cpu"
mock_exists.return_value = True
mock_torch_load.return_value = torch.zeros((10, 24000))
mock_generate.return_value = (None, None)
with pytest.raises(ValueError, match="No chunks were processed successfully"):
tts_service._generate_audio("Test text", "af", 1.0)
@patch("api.src.services.tts_model.TTSModel.get_instance")
@patch("api.src.services.tts_model.TTSModel.get_device")
@patch("os.path.exists")
@patch("kokoro.normalize_text")
@patch("kokoro.phonemize")
@patch("kokoro.tokenize")
@patch("kokoro.generate")
@patch("torch.load")
def test_generate_audio_error(
mock_torch_load,
mock_generate,
mock_tokenize,
mock_phonemize,
mock_normalize,
mock_exists,
mock_get_device,
mock_instance,
tts_service,
):
"""Test handling generation error"""
mock_normalize.return_value = "Test text"
mock_phonemize.return_value = "Test text"
mock_tokenize.return_value = [1, 2] # Return integers instead of strings
mock_generate.side_effect = Exception("Generation failed")
mock_instance.return_value = (mock_generate, "cpu") # Use the same mock for consistency
mock_get_device.return_value = "cpu"
mock_exists.return_value = True
mock_torch_load.return_value = torch.zeros((10, 24000))
with pytest.raises(ValueError, match="No chunks were processed successfully"):
tts_service._generate_audio("Test text", "af", 1.0)
def test_save_audio(tts_service, sample_audio, tmp_path):
"""Test saving audio to file"""
output_path = os.path.join(tmp_path, "test_output.wav")
tts_service._save_audio(sample_audio, output_path)
assert os.path.exists(output_path)
assert os.path.getsize(output_path) > 0
@pytest.mark.asyncio
async def test_combine_voices(tts_service):
"""Test combining multiple voices"""
# Setup mocks for torch operations
with patch('torch.load', return_value=torch.tensor([1.0, 2.0])), \
patch('torch.stack', return_value=torch.tensor([[1.0, 2.0], [3.0, 4.0]])), \
patch('torch.mean', return_value=torch.tensor([2.0, 3.0])), \
patch('torch.save'), \
patch('os.path.exists', return_value=True):
# Test combining two voices
result = await tts_service.combine_voices(["voice1", "voice2"])
assert result == "voice1_voice2"
@pytest.mark.asyncio
async def test_combine_voices_invalid_input(tts_service):
"""Test combining voices with invalid input"""
# Test with empty list
with pytest.raises(ValueError, match="At least 2 voices are required"):
await tts_service.combine_voices([])
# Test with single voice
with pytest.raises(ValueError, match="At least 2 voices are required"):
await tts_service.combine_voices(["voice1"])
@patch("api.src.services.tts_service.TTSService._get_voice_path")
@patch("api.src.services.tts_model.TTSModel.get_instance")
def test_voicepack_loading_error(mock_get_instance, mock_get_voice_path):
"""Test voicepack loading error handling"""
mock_get_voice_path.return_value = None
mock_instance = MagicMock()
mock_instance.generate.return_value = np.zeros(24000, dtype=np.float32)
mock_get_instance.return_value = (mock_instance, "cpu")
TTSModel._voicepacks = {} # Reset voicepacks
service = TTSService()
with pytest.raises(ValueError, match="Voice not found: nonexistent_voice"):
service._generate_audio("test", "nonexistent_voice", 1.0)