# Mirror of https://github.com/remsky/Kokoro-FastAPI.git
# Synced 2025-04-13 09:39:17 +00:00 (original file: 190 lines, 7 KiB, Python, no trailing newline)
"""Tests for model and voice managers"""
|
|
import os
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
|
|
|
from api.src.inference.model_manager import get_manager as get_model_manager
|
|
from api.src.inference.voice_manager import get_manager as get_voice_manager
|
|
|
|
# Project root: strip three trailing components from this file's absolute path.
PROJECT_ROOT = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)

# Directories under <root>/api/src that the mocked settings point at.
MOCK_VOICES_DIR, MOCK_MODEL_DIR = (
    os.path.join(PROJECT_ROOT, "api", "src", subdir) for subdir in ("voices", "models")
)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_model_manager_initialization():
    """Test model manager initialization.

    The model manager module's ``settings`` and ``api.src.core.paths.get_model_path``
    are patched to point at MOCK_MODEL_DIR. The test then checks that the accessor
    returns a manager whose ``get_backend()`` is not None.
    """
    onnx_path = os.path.join(MOCK_MODEL_DIR, "model.onnx")
    with patch("api.src.inference.model_manager.settings") as fake_settings:
        with patch("api.src.core.paths.get_model_path") as fake_get_model_path:
            fake_settings.model_dir = MOCK_MODEL_DIR
            fake_settings.onnx_model_path = "model.onnx"
            fake_get_model_path.return_value = onnx_path

            model_manager = await get_model_manager()
            assert model_manager is not None
            assert model_manager.get_backend() is not None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_model_manager_generate():
    """Test model generation through a mocked ONNX CPU backend.

    A mock backend is registered under ``'onnx_cpu'`` and returns a canned
    48000-sample float32 tensor. The test checks that ``manager.generate`` passes
    that tensor through and awaits the backend exactly once.

    The manager's private ``_backends`` entry and ``_current_backend`` attribute
    are overridden only inside ``patch.dict`` / ``patch.object`` scopes, so both
    are restored when the test finishes. ``get_model_manager()`` may hand back a
    shared instance, and a mock left installed there would leak into later tests.
    """
    with patch("api.src.inference.model_manager.settings") as mock_settings, \
            patch("api.src.core.paths.get_model_path") as mock_get_path, \
            patch("torch.load") as mock_torch_load:
        mock_settings.model_dir = MOCK_MODEL_DIR
        mock_settings.onnx_model_path = "model.onnx"
        mock_settings.use_onnx = True
        mock_settings.use_gpu = False
        mock_get_path.return_value = os.path.join(MOCK_MODEL_DIR, "model.onnx")

        # While patched, torch.load returns a 192-element zero tensor.
        mock_torch_load.return_value = torch.zeros(192)

        manager = await get_model_manager()

        # Mock backend; AsyncMock child attributes are awaitable, and awaiting
        # generate() yields the canned audio tensor.
        audio_data = torch.zeros(48000, dtype=torch.float32)
        mock_backend = AsyncMock()
        mock_backend.is_loaded = True
        mock_backend.device = "cpu"
        mock_backend.generate.return_value = audio_data

        # Install the mock only for the generate() call; both overrides are undone on exit.
        # create=True copes with _current_backend not existing yet on the instance.
        with patch.dict(manager._backends, {"onnx_cpu": mock_backend}), \
                patch.object(manager, "_current_backend", "onnx_cpu", create=True):
            tokens = [1, 2, 3]
            voice_tensor = torch.zeros(192)
            audio = await manager.generate(tokens, voice_tensor, speed=1.0)

        assert isinstance(audio, torch.Tensor), "Generated audio should be torch tensor"
        assert audio.dtype == torch.float32, "Audio should be 32-bit float"
        assert audio.shape == (48000,), "Audio should have 48000 samples"
        mock_backend.generate.assert_awaited_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_voice_manager_initialization():
    """Test voice manager initialization.

    ``settings.voices_dir`` is pointed at MOCK_VOICES_DIR, and ``os.path.exists``
    is patched to report every path as present. The accessor must then return
    a non-None manager.
    """
    with patch("api.src.inference.voice_manager.settings") as fake_settings:
        with patch("os.path.exists", return_value=True):
            fake_settings.voices_dir = MOCK_VOICES_DIR
            voice_manager = await get_voice_manager()
            assert voice_manager is not None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_voice_manager_list_voices():
    """Test listing available voices.

    ``os.listdir`` is patched to report three ``.pt`` files, and ``os.path.exists``
    is patched to report every path as present. ``os.makedirs`` is patched so no
    directory is created. ``list_voices()`` must return the three names without
    the ``.pt`` suffix and list the directory exactly once.
    """
    voice_files = ["af_bella.pt", "af_sarah.pt", "bm_lewis.pt"]
    expected_names = ["af_bella", "af_sarah", "bm_lewis"]

    with patch("api.src.inference.voice_manager.settings") as fake_settings, \
            patch("os.listdir", return_value=voice_files) as fake_listdir, \
            patch("os.makedirs"), \
            patch("os.path.exists", return_value=True):
        fake_settings.voices_dir = MOCK_VOICES_DIR

        voice_manager = await get_voice_manager()
        voices = await voice_manager.list_voices()

        assert isinstance(voices, list)
        assert len(voices) == 3, f"Expected 3 voices but got {len(voices)}"
        assert sorted(voices) == expected_names
        fake_listdir.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_voice_manager_load_voice():
    """Test loading a voice.

    ``torch.load`` is patched to return a 192-element zero tensor, and
    ``os.path.exists`` reports every path as present. ``load_voice`` must return a
    tensor of shape (192,) and call ``torch.load`` exactly once.
    """
    stored_voice = torch.zeros(192)

    with patch("api.src.inference.voice_manager.settings") as fake_settings, \
            patch("torch.load", return_value=stored_voice) as fake_load, \
            patch("os.path.exists", return_value=True):
        fake_settings.voices_dir = MOCK_VOICES_DIR

        voice_manager = await get_voice_manager()
        loaded = await voice_manager.load_voice("af_bella", device="cpu")

        assert isinstance(loaded, torch.Tensor)
        assert loaded.shape == (192,)
        fake_load.assert_called_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_voice_manager_combine_voices():
    """Test combining voices.

    ``torch.load`` yields ``ones(192)`` and then ``2 * ones(192)``. The combined
    voice must be named ``af_bella_af_sarah`` and saved once via ``torch.save``,
    and the saved tensor must equal the element-wise mean (1.5 everywhere).
    ``torch.save``, ``os.makedirs`` and ``os.path.exists`` are patched, so nothing
    is written to disk.
    """
    first_voice = torch.ones(192)
    second_voice = torch.ones(192) * 2

    with patch("api.src.inference.voice_manager.settings") as fake_settings, \
            patch("torch.load", side_effect=[first_voice, second_voice]) as fake_load, \
            patch("torch.save") as fake_save, \
            patch("os.makedirs"), \
            patch("os.path.exists", return_value=True):
        fake_settings.voices_dir = MOCK_VOICES_DIR

        voice_manager = await get_voice_manager()
        blended_name = await voice_manager.combine_voices(["af_bella", "af_sarah"])

        assert blended_name == "af_bella_af_sarah"
        assert fake_load.call_count == 2
        fake_save.assert_called_once()

        # The first positional argument of the single torch.save call is the blend.
        saved_tensor = fake_save.call_args.args[0]
        assert isinstance(saved_tensor, torch.Tensor)
        assert saved_tensor.shape == (192,)
        assert torch.allclose(saved_tensor, torch.ones(192) * 1.5)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_voice_manager_invalid_voice():
    """Test loading invalid voice.

    With ``os.path.exists`` patched to report every path as missing, ``load_voice``
    must raise a RuntimeError whose message matches "Voice not found".
    """
    with patch("api.src.inference.voice_manager.settings") as fake_settings:
        with patch("os.path.exists", return_value=False):
            fake_settings.voices_dir = MOCK_VOICES_DIR

            voice_manager = await get_voice_manager()
            with pytest.raises(RuntimeError, match="Voice not found"):
                await voice_manager.load_voice("invalid_voice", device="cpu")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_voice_manager_combine_invalid_voices():
    """Test combining with invalid voices.

    With ``os.path.exists`` patched to report every path as missing,
    ``combine_voices`` must raise a RuntimeError whose message matches
    "Voice not found".
    """
    with patch("api.src.inference.voice_manager.settings") as fake_settings:
        with patch("os.path.exists", return_value=False):
            fake_settings.voices_dir = MOCK_VOICES_DIR

            voice_manager = await get_voice_manager()
            with pytest.raises(RuntimeError, match="Voice not found"):
                await voice_manager.combine_voices(["invalid_voice1", "invalid_voice2"])