"""Tests for TTSService""" import os from unittest.mock import MagicMock, call, patch, AsyncMock import numpy as np import torch import pytest from onnxruntime import InferenceSession from aiofiles import threadpool from api.src.core.config import settings from api.src.services.tts_model import TTSModel from api.src.services.tts_service import TTSService from api.src.services.tts_cpu import TTSCPUModel from api.src.services.tts_gpu import TTSGPUModel @pytest.fixture def tts_service(): """Create a TTSService instance for testing""" return TTSService() @pytest.fixture def sample_audio(): """Generate a simple sine wave for testing""" sample_rate = 24000 duration = 0.1 # 100ms t = np.linspace(0, duration, int(sample_rate * duration)) frequency = 440 # A4 note return np.sin(2 * np.pi * frequency * t).astype(np.float32) def test_audio_to_bytes(tts_service, sample_audio): """Test converting audio tensor to bytes""" audio_bytes = tts_service._audio_to_bytes(sample_audio) assert isinstance(audio_bytes, bytes) assert len(audio_bytes) > 0 @pytest.mark.asyncio async def test_list_voices(tts_service): """Test listing available voices""" # Mock os.listdir to return test files with patch('os.listdir', return_value=["voice1.pt", "voice2.pt", "not_a_voice.txt"]): # Register mock with threadpool async_listdir = AsyncMock(return_value=["voice1.pt", "voice2.pt", "not_a_voice.txt"]) threadpool.async_wrap = MagicMock(return_value=async_listdir) voices = await tts_service.list_voices() assert len(voices) == 2 assert "voice1" in voices assert "voice2" in voices assert "not_a_voice" not in voices @pytest.mark.asyncio async def test_list_voices_error(tts_service): """Test error handling in list_voices""" # Mock os.listdir to raise an exception with patch('os.listdir', side_effect=Exception("Failed to list directory")): # Register mock with threadpool async_listdir = AsyncMock(side_effect=Exception("Failed to list directory")) threadpool.async_wrap = MagicMock(return_value=async_listdir) voices = await tts_service.list_voices() assert voices == [] def mock_model_setup(cuda_available=False): """Helper function to mock model setup""" # Reset model state TTSModel._instance = None TTSModel._device = None TTSModel._voicepacks = {} # Create mock model instance with proper generate method mock_model = MagicMock() mock_model.generate.return_value = np.zeros(24000, dtype=np.float32) TTSModel._instance = mock_model # Set device based on CUDA availability TTSModel._device = "cuda" if cuda_available else "cpu" return 3 # Return voice count (including af.pt) def test_model_initialization_cuda(): """Test model initialization with CUDA""" # Simulate CUDA availability voice_count = mock_model_setup(cuda_available=True) assert TTSModel.get_device() == "cuda" assert voice_count == 3 # voice1.pt, voice2.pt, af.pt def test_model_initialization_cpu(): """Test model initialization with CPU""" # Simulate no CUDA availability voice_count = mock_model_setup(cuda_available=False) assert TTSModel.get_device() == "cpu" assert voice_count == 3 # voice1.pt, voice2.pt, af.pt def test_generate_audio_empty_text(tts_service): """Test generating audio with empty text""" with pytest.raises(ValueError, match="Text is empty after preprocessing"): tts_service._generate_audio("", "af", 1.0) @patch("api.src.services.tts_model.TTSModel.get_instance") @patch("api.src.services.tts_model.TTSModel.get_device") @patch("os.path.exists") @patch("kokoro.normalize_text") @patch("kokoro.phonemize") @patch("kokoro.tokenize") @patch("kokoro.generate") @patch("torch.load") def test_generate_audio_phonemize_error( mock_torch_load, mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_exists, mock_get_device, mock_instance, tts_service, ): """Test handling phonemization error""" mock_normalize.return_value = "Test text" mock_phonemize.side_effect = Exception("Phonemization failed") mock_instance.return_value = (mock_generate, "cpu") # Use the same mock for consistency mock_get_device.return_value = "cpu" mock_exists.return_value = True mock_torch_load.return_value = torch.zeros((10, 24000)) mock_generate.return_value = (None, None) with pytest.raises(ValueError, match="No chunks were processed successfully"): tts_service._generate_audio("Test text", "af", 1.0) @patch("api.src.services.tts_model.TTSModel.get_instance") @patch("api.src.services.tts_model.TTSModel.get_device") @patch("os.path.exists") @patch("kokoro.normalize_text") @patch("kokoro.phonemize") @patch("kokoro.tokenize") @patch("kokoro.generate") @patch("torch.load") def test_generate_audio_error( mock_torch_load, mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_exists, mock_get_device, mock_instance, tts_service, ): """Test handling generation error""" mock_normalize.return_value = "Test text" mock_phonemize.return_value = "Test text" mock_tokenize.return_value = [1, 2] # Return integers instead of strings mock_generate.side_effect = Exception("Generation failed") mock_instance.return_value = (mock_generate, "cpu") # Use the same mock for consistency mock_get_device.return_value = "cpu" mock_exists.return_value = True mock_torch_load.return_value = torch.zeros((10, 24000)) with pytest.raises(ValueError, match="No chunks were processed successfully"): tts_service._generate_audio("Test text", "af", 1.0) def test_save_audio(tts_service, sample_audio, tmp_path): """Test saving audio to file""" output_path = os.path.join(tmp_path, "test_output.wav") tts_service._save_audio(sample_audio, output_path) assert os.path.exists(output_path) assert os.path.getsize(output_path) > 0 @pytest.mark.asyncio async def test_combine_voices(tts_service): """Test combining multiple voices""" # Setup mocks for torch operations with patch('torch.load', return_value=torch.tensor([1.0, 2.0])), \ patch('torch.stack', return_value=torch.tensor([[1.0, 2.0], [3.0, 4.0]])), \ patch('torch.mean', return_value=torch.tensor([2.0, 3.0])), \ patch('torch.save'), \ patch('os.path.exists', return_value=True): # Test combining two voices result = await tts_service.combine_voices(["voice1", "voice2"]) assert result == "voice1_voice2" @pytest.mark.asyncio async def test_combine_voices_invalid_input(tts_service): """Test combining voices with invalid input""" # Test with empty list with pytest.raises(ValueError, match="At least 2 voices are required"): await tts_service.combine_voices([]) # Test with single voice with pytest.raises(ValueError, match="At least 2 voices are required"): await tts_service.combine_voices(["voice1"]) @patch("api.src.services.tts_service.TTSService._get_voice_path") @patch("api.src.services.tts_model.TTSModel.get_instance") def test_voicepack_loading_error(mock_get_instance, mock_get_voice_path): """Test voicepack loading error handling""" mock_get_voice_path.return_value = None mock_instance = MagicMock() mock_instance.generate.return_value = np.zeros(24000, dtype=np.float32) mock_get_instance.return_value = (mock_instance, "cpu") TTSModel._voicepacks = {} # Reset voicepacks service = TTSService() with pytest.raises(ValueError, match="Voice not found: nonexistent_voice"): service._generate_audio("test", "nonexistent_voice", 1.0)