Kokoro-FastAPI/api/tests/conftest.py

73 lines
2.2 KiB
Python
Raw Normal View History

2025-02-09 18:32:17 -07:00
import os
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
2025-02-09 18:32:17 -07:00
import numpy as np
2025-02-09 18:32:17 -07:00
import pytest
import pytest_asyncio
import torch
2025-02-09 18:32:17 -07:00
from api.src.inference.model_manager import ModelManager
2025-02-09 18:32:17 -07:00
from api.src.inference.voice_manager import VoiceManager
from api.src.services.tts_service import TTSService
from api.src.structures.model_schemas import VoiceConfig
2025-02-09 18:32:17 -07:00
@pytest.fixture
def mock_voice_tensor():
    """Provide the deserialized contents of the ``af_bella.pt`` voice file.

    Despite the ``mock_`` prefix, this reads a real file from disk. The path
    is built from the parent of the directory containing this module, joined
    with ``src/voices/af_bella.pt``.

    Returns:
        Whatever object ``torch.load`` deserializes from that file, with
        storages mapped onto the CPU.
    """
    this_dir = os.path.dirname(__file__)
    parent_dir = os.path.dirname(this_dir)
    tensor_file = os.path.join(parent_dir, "src/voices/af_bella.pt")
    # NOTE(review): weights_only=False permits full unpickling, so this must
    # only ever point at a trusted file.
    return torch.load(tensor_file, map_location="cpu", weights_only=False)
@pytest.fixture
def mock_audio_output():
    """Provide pre-generated audio read from ``test_data/test_audio.npy``.

    The file is looked up relative to the directory containing this module.

    Returns:
        The object produced by ``np.load`` for that file (a NumPy array
        rather than raw bytes).
    """
    fixture_dir = os.path.dirname(__file__)
    audio_file = os.path.join(fixture_dir, "test_data/test_audio.npy")
    samples = np.load(audio_file)
    return samples
2025-02-09 18:32:17 -07:00
@pytest_asyncio.fixture
async def mock_model_manager(mock_audio_output):
    """Build an ``AsyncMock`` stand-in for ``ModelManager``.

    ``get_backend`` is replaced by a synchronous ``MagicMock`` and
    ``generate`` by an ``AsyncMock`` that produces fresh random audio on
    every call.

    Args:
        mock_audio_output: Requested as a fixture dependency but not read
            here; the generated audio is random, not this array.

    Returns:
        The configured mock manager.
    """

    async def _fake_generate(*args, **kwargs):
        # Pretend generation succeeded: 24000 random float32 samples,
        # standing in for one second of audio.
        noise = np.random.rand(24000)
        return noise.astype(np.float32)

    stub = AsyncMock(spec=ModelManager)
    stub.get_backend = MagicMock()
    stub.generate = AsyncMock(side_effect=_fake_generate)
    return stub
2025-02-09 18:32:17 -07:00
@pytest_asyncio.fixture
async def mock_voice_manager(mock_voice_tensor):
    """Build an ``AsyncMock`` stand-in for ``VoiceManager`` with canned results.

    Args:
        mock_voice_tensor: Value that awaiting ``load_voice`` resolves to.

    Returns:
        The mock manager, where ``get_voice_path`` is a synchronous mock
        returning a fixed path and ``load_voice``, ``list_voices`` and
        ``combine_voices`` are async mocks with fixed return values.
    """
    stub = AsyncMock(spec=VoiceManager)
    # Applied in insertion order, one attribute assignment per entry.
    canned = {
        "get_voice_path": MagicMock(return_value="/mock/path/voice.pt"),
        "load_voice": AsyncMock(return_value=mock_voice_tensor),
        "list_voices": AsyncMock(return_value=["voice1", "voice2"]),
        "combine_voices": AsyncMock(return_value="voice1_voice2"),
    }
    for attr_name, replacement in canned.items():
        setattr(stub, attr_name, replacement)
    return stub
2025-02-09 18:32:17 -07:00
@pytest_asyncio.fixture
async def tts_service(mock_model_manager, mock_voice_manager):
    """Create a ``TTSService`` whose managers are replaced by mocks.

    Args:
        mock_model_manager: Assigned to the service's ``model_manager``.
        mock_voice_manager: Assigned to the service's ``_voice_manager``.

    Returns:
        The ``TTSService`` instance with both attributes overwritten.
    """
    svc = TTSService()
    # Overwrite the collaborators after construction, in the same order as
    # before in case either attribute is a property with side effects.
    svc.model_manager = mock_model_manager
    svc._voice_manager = mock_voice_manager
    return svc
2025-02-09 18:32:17 -07:00
@pytest.fixture
def test_voice():
    """Provide the voice name used by tests.

    Returns:
        The literal string ``"voice1"``.
    """
    voice_name = "voice1"
    return voice_name