mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00

-alignment with streaming standards -audio processing config settings -more comprehensive model warmup -minor model improvements -enhancing testing, benchmarking -cool ascii logo
218 lines
7.1 KiB
Python
218 lines
7.1 KiB
Python
"""Tests for TTSService"""
|
|
|
|
import os
|
|
from unittest.mock import MagicMock, call, patch
|
|
|
|
import numpy as np
|
|
import torch
|
|
import pytest
|
|
from onnxruntime import InferenceSession
|
|
|
|
from api.src.core.config import settings
|
|
from api.src.services.tts_model import TTSModel
|
|
from api.src.services.tts_service import TTSService
|
|
from api.src.services.tts_cpu import TTSCPUModel
|
|
from api.src.services.tts_gpu import TTSGPUModel
|
|
|
|
|
|
@pytest.fixture
|
|
def tts_service():
|
|
"""Create a TTSService instance for testing"""
|
|
return TTSService()
|
|
|
|
|
|
@pytest.fixture
|
|
def sample_audio():
|
|
"""Generate a simple sine wave for testing"""
|
|
sample_rate = 24000
|
|
duration = 0.1 # 100ms
|
|
t = np.linspace(0, duration, int(sample_rate * duration))
|
|
frequency = 440 # A4 note
|
|
return np.sin(2 * np.pi * frequency * t).astype(np.float32)
|
|
|
|
|
|
def test_audio_to_bytes(tts_service, sample_audio):
|
|
"""Test converting audio tensor to bytes"""
|
|
audio_bytes = tts_service._audio_to_bytes(sample_audio)
|
|
assert isinstance(audio_bytes, bytes)
|
|
assert len(audio_bytes) > 0
|
|
|
|
|
|
@patch("os.listdir")
|
|
@patch("os.path.join")
|
|
def test_list_voices(mock_join, mock_listdir, tts_service):
|
|
"""Test listing available voices"""
|
|
mock_listdir.return_value = ["voice1.pt", "voice2.pt", "not_a_voice.txt"]
|
|
mock_join.return_value = "/fake/path"
|
|
|
|
voices = tts_service.list_voices()
|
|
assert len(voices) == 2
|
|
assert "voice1" in voices
|
|
assert "voice2" in voices
|
|
assert "not_a_voice" not in voices
|
|
|
|
|
|
@patch("os.listdir")
|
|
def test_list_voices_error(mock_listdir, tts_service):
|
|
"""Test error handling in list_voices"""
|
|
mock_listdir.side_effect = Exception("Failed to list directory")
|
|
|
|
voices = tts_service.list_voices()
|
|
assert voices == []
|
|
|
|
|
|
def mock_model_setup(cuda_available=False):
|
|
"""Helper function to mock model setup"""
|
|
# Reset model state
|
|
TTSModel._instance = None
|
|
TTSModel._device = None
|
|
TTSModel._voicepacks = {}
|
|
|
|
# Create mock model instance with proper generate method
|
|
mock_model = MagicMock()
|
|
mock_model.generate.return_value = np.zeros(24000, dtype=np.float32)
|
|
TTSModel._instance = mock_model
|
|
|
|
# Set device based on CUDA availability
|
|
TTSModel._device = "cuda" if cuda_available else "cpu"
|
|
|
|
return 3 # Return voice count (including af.pt)
|
|
|
|
|
|
def test_model_initialization_cuda():
|
|
"""Test model initialization with CUDA"""
|
|
# Simulate CUDA availability
|
|
voice_count = mock_model_setup(cuda_available=True)
|
|
|
|
assert TTSModel.get_device() == "cuda"
|
|
assert voice_count == 3 # voice1.pt, voice2.pt, af.pt
|
|
|
|
|
|
def test_model_initialization_cpu():
|
|
"""Test model initialization with CPU"""
|
|
# Simulate no CUDA availability
|
|
voice_count = mock_model_setup(cuda_available=False)
|
|
|
|
assert TTSModel.get_device() == "cpu"
|
|
assert voice_count == 3 # voice1.pt, voice2.pt, af.pt
|
|
|
|
|
|
def test_generate_audio_empty_text(tts_service):
|
|
"""Test generating audio with empty text"""
|
|
with pytest.raises(ValueError, match="Text is empty after preprocessing"):
|
|
tts_service._generate_audio("", "af", 1.0)
|
|
|
|
|
|
@patch("api.src.services.tts_model.TTSModel.get_instance")
|
|
@patch("api.src.services.tts_model.TTSModel.get_device")
|
|
@patch("os.path.exists")
|
|
@patch("kokoro.normalize_text")
|
|
@patch("kokoro.phonemize")
|
|
@patch("kokoro.tokenize")
|
|
@patch("kokoro.generate")
|
|
@patch("torch.load")
|
|
def test_generate_audio_phonemize_error(
|
|
mock_torch_load,
|
|
mock_generate,
|
|
mock_tokenize,
|
|
mock_phonemize,
|
|
mock_normalize,
|
|
mock_exists,
|
|
mock_get_device,
|
|
mock_instance,
|
|
tts_service,
|
|
):
|
|
"""Test handling phonemization error"""
|
|
mock_normalize.return_value = "Test text"
|
|
mock_phonemize.side_effect = Exception("Phonemization failed")
|
|
mock_instance.return_value = (mock_generate, "cpu") # Use the same mock for consistency
|
|
mock_get_device.return_value = "cpu"
|
|
mock_exists.return_value = True
|
|
mock_torch_load.return_value = torch.zeros((10, 24000))
|
|
mock_generate.return_value = (None, None)
|
|
|
|
with pytest.raises(ValueError, match="No chunks were processed successfully"):
|
|
tts_service._generate_audio("Test text", "af", 1.0)
|
|
|
|
|
|
@patch("api.src.services.tts_model.TTSModel.get_instance")
|
|
@patch("api.src.services.tts_model.TTSModel.get_device")
|
|
@patch("os.path.exists")
|
|
@patch("kokoro.normalize_text")
|
|
@patch("kokoro.phonemize")
|
|
@patch("kokoro.tokenize")
|
|
@patch("kokoro.generate")
|
|
@patch("torch.load")
|
|
def test_generate_audio_error(
|
|
mock_torch_load,
|
|
mock_generate,
|
|
mock_tokenize,
|
|
mock_phonemize,
|
|
mock_normalize,
|
|
mock_exists,
|
|
mock_get_device,
|
|
mock_instance,
|
|
tts_service,
|
|
):
|
|
"""Test handling generation error"""
|
|
mock_normalize.return_value = "Test text"
|
|
mock_phonemize.return_value = "Test text"
|
|
mock_tokenize.return_value = [1, 2] # Return integers instead of strings
|
|
mock_generate.side_effect = Exception("Generation failed")
|
|
mock_instance.return_value = (mock_generate, "cpu") # Use the same mock for consistency
|
|
mock_get_device.return_value = "cpu"
|
|
mock_exists.return_value = True
|
|
mock_torch_load.return_value = torch.zeros((10, 24000))
|
|
|
|
with pytest.raises(ValueError, match="No chunks were processed successfully"):
|
|
tts_service._generate_audio("Test text", "af", 1.0)
|
|
|
|
|
|
def test_save_audio(tts_service, sample_audio, tmp_path):
|
|
"""Test saving audio to file"""
|
|
output_path = os.path.join(tmp_path, "test_output.wav")
|
|
tts_service._save_audio(sample_audio, output_path)
|
|
assert os.path.exists(output_path)
|
|
assert os.path.getsize(output_path) > 0
|
|
|
|
|
|
def test_combine_voices(tts_service):
|
|
"""Test combining multiple voices"""
|
|
# Setup mocks for torch operations
|
|
with patch('torch.load', return_value=torch.tensor([1.0, 2.0])), \
|
|
patch('torch.stack', return_value=torch.tensor([[1.0, 2.0], [3.0, 4.0]])), \
|
|
patch('torch.mean', return_value=torch.tensor([2.0, 3.0])), \
|
|
patch('torch.save'), \
|
|
patch('os.path.exists', return_value=True):
|
|
|
|
# Test combining two voices
|
|
result = tts_service.combine_voices(["voice1", "voice2"])
|
|
|
|
assert result == "voice1_voice2"
|
|
|
|
|
|
def test_combine_voices_invalid_input(tts_service):
|
|
"""Test combining voices with invalid input"""
|
|
# Test with empty list
|
|
with pytest.raises(ValueError, match="At least 2 voices are required"):
|
|
tts_service.combine_voices([])
|
|
|
|
# Test with single voice
|
|
with pytest.raises(ValueError, match="At least 2 voices are required"):
|
|
tts_service.combine_voices(["voice1"])
|
|
|
|
|
|
@patch("api.src.services.tts_service.TTSService._get_voice_path")
|
|
@patch("api.src.services.tts_model.TTSModel.get_instance")
|
|
def test_voicepack_loading_error(mock_get_instance, mock_get_voice_path):
|
|
"""Test voicepack loading error handling"""
|
|
mock_get_voice_path.return_value = None
|
|
mock_instance = MagicMock()
|
|
mock_instance.generate.return_value = np.zeros(24000, dtype=np.float32)
|
|
mock_get_instance.return_value = (mock_instance, "cpu")
|
|
|
|
TTSModel._voicepacks = {} # Reset voicepacks
|
|
|
|
service = TTSService()
|
|
with pytest.raises(ValueError, match="Voice not found: nonexistent_voice"):
|
|
service._generate_audio("test", "nonexistent_voice", 1.0)
|