Kokoro-FastAPI/api/tests/test_tts_service.py
remsky 720c1fb97d -update soundfile version
-alignment with streaming standards
-audio processing config settings
-more comprehensive model warmup
-minor model improvements
-enhancing testing, benchmarking
-cool ascii logo
2025-01-06 03:32:41 -07:00

218 lines
7.1 KiB
Python

"""Tests for TTSService"""
import os
from unittest.mock import MagicMock, call, patch
import numpy as np
import torch
import pytest
from onnxruntime import InferenceSession
from api.src.core.config import settings
from api.src.services.tts_model import TTSModel
from api.src.services.tts_service import TTSService
from api.src.services.tts_cpu import TTSCPUModel
from api.src.services.tts_gpu import TTSGPUModel
@pytest.fixture
def tts_service():
    """Provide a freshly constructed TTSService to each test that requests it."""
    service = TTSService()
    return service
@pytest.fixture
def sample_audio():
    """Provide 100 ms of a 440 Hz (A4) sine tone sampled at 24 kHz.

    Returns:
        A 1-D float32 numpy array of ``int(24000 * 0.1)`` samples whose
        values lie in [-1, 1].
    """
    sample_rate = 24000
    duration = 0.1  # 100ms
    num_samples = int(sample_rate * duration)
    # Sample instants are k / sample_rate. np.linspace(0, duration, n) would
    # include the endpoint and space samples duration / (n - 1) apart, which
    # detunes the tone relative to the declared sample rate.
    t = np.arange(num_samples) / sample_rate
    frequency = 440  # A4 note
    return np.sin(2 * np.pi * frequency * t).astype(np.float32)
def test_audio_to_bytes(tts_service, sample_audio):
    """_audio_to_bytes should turn a float32 waveform into a non-empty bytes blob."""
    result = tts_service._audio_to_bytes(sample_audio)
    assert isinstance(result, bytes)
    assert result  # an empty bytes object would be falsy
@patch("os.listdir")
@patch("os.path.join")
def test_list_voices(mock_join, mock_listdir, tts_service):
    """list_voices should report only the ``.pt`` entries, without the extension."""
    # The fake directory holds two voice packs and one unrelated file.
    mock_listdir.return_value = ["voice1.pt", "voice2.pt", "not_a_voice.txt"]
    # Every path the service builds resolves to the same fake location.
    mock_join.return_value = "/fake/path"

    voices = tts_service.list_voices()

    # Exactly the two voice packs, nothing else.
    assert sorted(voices) == ["voice1", "voice2"]
    assert "not_a_voice" not in voices
@patch("os.listdir")
def test_list_voices_error(mock_listdir, tts_service):
    """A failing directory listing should make list_voices return an empty list."""
    mock_listdir.side_effect = Exception("Failed to list directory")
    assert tts_service.list_voices() == []
def mock_model_setup(cuda_available=False):
    """Install a MagicMock as the TTSModel singleton and pin its device.

    Clears the ``_instance``, ``_device`` and ``_voicepacks`` class attributes
    of TTSModel. It then installs a mock model whose ``generate`` returns
    24000 zero-valued float32 samples, and sets ``_device`` to "cuda" or
    "cpu".

    NOTE(review): the class attributes are overwritten in place and never
    restored, so this state leaks into tests that run afterwards. The
    returned voice count is a hard-coded constant and is not derived from
    any voice directory.

    Args:
        cuda_available: When True the device is set to "cuda", otherwise "cpu".

    Returns:
        The constant 3 (intended voice count, including af.pt).
    """
    fake_model = MagicMock()
    fake_model.generate.return_value = np.zeros(24000, dtype=np.float32)

    # Wipe any cached singleton state before installing the fake.
    TTSModel._instance = None
    TTSModel._device = None
    TTSModel._voicepacks = {}

    TTSModel._instance = fake_model
    TTSModel._device = "cuda" if cuda_available else "cpu"
    return 3  # voice count (including af.pt)
def test_model_initialization_cuda():
    """With CUDA reported as available, TTSModel should report the "cuda" device."""
    count = mock_model_setup(cuda_available=True)
    assert TTSModel.get_device() == "cuda"
    # Fixed count returned by the helper: voice1.pt, voice2.pt, af.pt
    assert count == 3
def test_model_initialization_cpu():
    """Without CUDA, TTSModel should report the "cpu" device."""
    count = mock_model_setup(cuda_available=False)
    assert TTSModel.get_device() == "cpu"
    # Fixed count returned by the helper: voice1.pt, voice2.pt, af.pt
    assert count == 3
def test_generate_audio_empty_text(tts_service):
    """Empty input text must be rejected with a ValueError."""
    expected_message = "Text is empty after preprocessing"
    with pytest.raises(ValueError, match=expected_message):
        tts_service._generate_audio("", "af", 1.0)
@patch("api.src.services.tts_model.TTSModel.get_instance")
@patch("api.src.services.tts_model.TTSModel.get_device")
@patch("os.path.exists")
@patch("kokoro.normalize_text")
@patch("kokoro.phonemize")
@patch("kokoro.tokenize")
@patch("kokoro.generate")
@patch("torch.load")
def test_generate_audio_phonemize_error(
    mock_torch_load,
    mock_generate,
    mock_tokenize,
    mock_phonemize,
    mock_normalize,
    mock_exists,
    mock_get_device,
    mock_instance,
    tts_service,
):
    """A phonemizer failure should end in "No chunks were processed successfully".

    NOTE(review): the ``kokoro.*`` patches only take effect if the service
    looks these functions up through the ``kokoro`` module at call time;
    TODO confirm against tts_service's imports.
    """
    # Model/device plumbing: a CPU "model" backed by the generate mock.
    mock_get_device.return_value = "cpu"
    mock_instance.return_value = (mock_generate, "cpu")
    mock_generate.return_value = (None, None)

    # Voice pack lookup: the file exists and loads as a zero tensor.
    mock_exists.return_value = True
    mock_torch_load.return_value = torch.zeros((10, 24000))

    # Text pipeline: normalization succeeds, phonemization blows up.
    mock_normalize.return_value = "Test text"
    mock_phonemize.side_effect = Exception("Phonemization failed")

    with pytest.raises(ValueError, match="No chunks were processed successfully"):
        tts_service._generate_audio("Test text", "af", 1.0)
@patch("api.src.services.tts_model.TTSModel.get_instance")
@patch("api.src.services.tts_model.TTSModel.get_device")
@patch("os.path.exists")
@patch("kokoro.normalize_text")
@patch("kokoro.phonemize")
@patch("kokoro.tokenize")
@patch("kokoro.generate")
@patch("torch.load")
def test_generate_audio_error(
    mock_torch_load,
    mock_generate,
    mock_tokenize,
    mock_phonemize,
    mock_normalize,
    mock_exists,
    mock_get_device,
    mock_instance,
    tts_service,
):
    """A failure inside generation should end in "No chunks were processed successfully".

    NOTE(review): the ``kokoro.*`` patches only take effect if the service
    looks these functions up through the ``kokoro`` module at call time;
    TODO confirm against tts_service's imports.
    """
    # Model/device plumbing: the generate mock doubles as the model and fails.
    mock_get_device.return_value = "cpu"
    mock_instance.return_value = (mock_generate, "cpu")
    mock_generate.side_effect = Exception("Generation failed")

    # Voice pack lookup: the file exists and loads as a zero tensor.
    mock_exists.return_value = True
    mock_torch_load.return_value = torch.zeros((10, 24000))

    # Text pipeline succeeds all the way through tokenization.
    mock_normalize.return_value = "Test text"
    mock_phonemize.return_value = "Test text"
    mock_tokenize.return_value = [1, 2]  # integer token ids, not strings

    with pytest.raises(ValueError, match="No chunks were processed successfully"):
        tts_service._generate_audio("Test text", "af", 1.0)
def test_save_audio(tts_service, sample_audio, tmp_path):
    """_save_audio should create a non-empty file at the requested path."""
    target = str(tmp_path / "test_output.wav")
    tts_service._save_audio(sample_audio, target)
    assert os.path.exists(target)
    assert os.path.getsize(target) > 0
def test_combine_voices(tts_service):
    """Combining two voices should return their names joined with an underscore.

    Disk access and the torch tensor ops are all mocked, so only the naming
    and control flow of combine_voices are exercised.
    """
    loaded = torch.tensor([1.0, 2.0])
    stacked = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    averaged = torch.tensor([2.0, 3.0])

    with patch("torch.load", return_value=loaded), \
            patch("torch.stack", return_value=stacked), \
            patch("torch.mean", return_value=averaged), \
            patch("torch.save"), \
            patch("os.path.exists", return_value=True):
        combined = tts_service.combine_voices(["voice1", "voice2"])

    assert combined == "voice1_voice2"
def test_combine_voices_invalid_input(tts_service):
    """combine_voices must reject an empty list and a single-voice list."""
    for too_few in ([], ["voice1"]):
        with pytest.raises(ValueError, match="At least 2 voices are required"):
            tts_service.combine_voices(too_few)
@patch("api.src.services.tts_service.TTSService._get_voice_path")
@patch("api.src.services.tts_model.TTSModel.get_instance")
def test_voicepack_loading_error(mock_get_instance, mock_get_voice_path):
    """An unresolvable voice should raise ValueError("Voice not found: <name>").

    ``_get_voice_path`` is patched to return None, and ``get_instance`` returns
    a mock model on "cpu" so that no real model is loaded.
    """
    mock_get_voice_path.return_value = None
    mock_instance = MagicMock()
    mock_instance.generate.return_value = np.zeros(24000, dtype=np.float32)
    mock_get_instance.return_value = (mock_instance, "cpu")

    # Start from an empty voicepack cache, but scope that reset to this test.
    # patch.object restores TTSModel._voicepacks on exit, whereas plain
    # assignment would leak the emptied cache into later tests. create=True
    # keeps this working even if the attribute is not already defined.
    with patch.object(TTSModel, "_voicepacks", {}, create=True):
        service = TTSService()
        with pytest.raises(ValueError, match="Voice not found: nonexistent_voice"):
            service._generate_audio("test", "nonexistent_voice", 1.0)