Kokoro-FastAPI/api/tests/test_voice_manager.py

149 lines
No EOL
5.4 KiB
Python

import pytest
from unittest.mock import AsyncMock, patch, MagicMock
import torch
from pathlib import Path
from ..src.inference.voice_manager import VoiceManager
from ..src.structures.model_schemas import VoiceConfig
@pytest.fixture
def mock_voice_tensor():
    """Supply a 10x10 standard-normal tensor to stand in for a real voice.

    Tests hand this object to mocked loaders and then check that the same
    tensor comes back out, so its actual values are irrelevant.
    """
    dummy_shape = (10, 10)
    return torch.randn(dummy_shape)
@pytest.fixture
def voice_manager():
    """Create a VoiceManager built from a default VoiceConfig.

    Fixtures are function-scoped by default, so each test receives its own
    manager (and therefore its own voice cache).
    """
    default_config = VoiceConfig()
    manager = VoiceManager(default_config)
    return manager
@pytest.mark.asyncio
async def test_load_voice(voice_manager, mock_voice_tensor):
    """Loading a single existing voice returns the tensor from the loader."""
    # NOTE(review): load_voice_tensor is patched in the module that defines it;
    # if voice_manager binds it via `from ... import`, the patch target should
    # be the voice_manager module instead — confirm against the source.
    loader_patch = patch("api.src.core.paths.load_voice_tensor", new_callable=AsyncMock)
    exists_patch = patch("os.path.exists", return_value=True)

    with loader_patch as fake_loader:
        fake_loader.return_value = mock_voice_tensor
        with exists_patch:
            loaded = await voice_manager.load_voice("af_bella", "cpu")

    assert torch.equal(loaded, mock_voice_tensor)
@pytest.mark.asyncio
async def test_load_voice_not_found(voice_manager):
    """A voice whose file is reported missing raises RuntimeError."""
    file_missing = patch("os.path.exists", return_value=False)
    expect_error = pytest.raises(RuntimeError, match="Voice not found: invalid_voice")

    # Entered left to right: the filesystem patch is active before the
    # raises-context starts watching for the error.
    with file_missing, expect_error:
        await voice_manager.load_voice("invalid_voice", "cpu")
@pytest.mark.asyncio
async def test_combine_voices_with_saving(voice_manager, mock_voice_tensor):
    """Combining voices with local saving enabled writes the result once.

    The combined voice name is the inputs joined with "+", and ``torch.save``
    (mocked) must be called exactly one time.
    """
    with patch("api.src.core.paths.load_voice_tensor", new_callable=AsyncMock) as fake_loader, \
            patch("torch.save") as fake_save, \
            patch("os.makedirs"), \
            patch("os.path.exists", return_value=True):
        fake_loader.return_value = mock_voice_tensor

        # NOTE(review): settings is patched on api.src.core.config; if
        # voice_manager imports `settings` by name, this patch may not be
        # what it reads — confirm against the source.
        with patch("api.src.core.config.settings") as fake_settings:
            fake_settings.allow_local_voice_saving = True
            fake_settings.voices_dir = "/mock/voices"
            combined_name = await voice_manager.combine_voices(["af_bella", "af_sarah"], "cpu")

    # "+" is the separator between component voice names.
    assert combined_name == "af_bella+af_sarah"
    fake_save.assert_called_once()
@pytest.mark.asyncio
async def test_combine_voices_without_saving(voice_manager, mock_voice_tensor):
    """Combining voices with local saving disabled never calls torch.save.

    The combined name is still produced ("+"-joined inputs); only the
    persistence step is expected to be skipped.
    """
    with patch("api.src.core.paths.load_voice_tensor", new_callable=AsyncMock) as fake_loader, \
            patch("torch.save") as fake_save, \
            patch("os.makedirs"), \
            patch("os.path.exists", return_value=True):
        fake_loader.return_value = mock_voice_tensor

        # NOTE(review): same caveat as the saving variant — confirm that
        # voice_manager reads settings through api.src.core.config.settings.
        with patch("api.src.core.config.settings") as fake_settings:
            fake_settings.allow_local_voice_saving = False
            fake_settings.voices_dir = "/mock/voices"
            combined_name = await voice_manager.combine_voices(["af_bella", "af_sarah"], "cpu")

    # "+" is the separator between component voice names.
    assert combined_name == "af_bella+af_sarah"
    fake_save.assert_not_called()
@pytest.mark.asyncio
async def test_combine_voices_single_voice(voice_manager):
    """A single-voice combination request is rejected with ValueError."""
    lone_voice = ["af_bella"]
    with pytest.raises(ValueError, match="At least 2 voices are required"):
        await voice_manager.combine_voices(lone_voice, "cpu")
@pytest.mark.asyncio
async def test_list_voices(voice_manager):
    """Every ``.pt`` file in the voices directory is listed by its bare name."""
    files_on_disk = ["af_bella.pt", "af_sarah.pt", "af_bella+af_sarah.pt"]

    with patch("os.listdir", return_value=files_on_disk), patch("os.makedirs"):
        voices = await voice_manager.list_voices()

    assert len(voices) == 3
    # Combined voices ("a+b") are listed alongside the plain ones.
    for expected_name in ("af_bella", "af_sarah", "af_bella+af_sarah"):
        assert expected_name in voices
@pytest.mark.asyncio
async def test_load_combined_voice(voice_manager, mock_voice_tensor):
    """A "+"-joined combined voice name loads like any other voice."""
    # NOTE(review): load_voice_tensor is patched at its defining module; if
    # voice_manager imports it by name, the patch target may need to change.
    loader_patch = patch("api.src.core.paths.load_voice_tensor", new_callable=AsyncMock)
    exists_patch = patch("os.path.exists", return_value=True)

    with loader_patch as fake_loader:
        fake_loader.return_value = mock_voice_tensor
        with exists_patch:
            loaded = await voice_manager.load_voice("af_bella+af_sarah", "cpu")

    assert torch.equal(loaded, mock_voice_tensor)
def test_cache_management(voice_manager, mock_voice_tensor):
    """_manage_cache brings an over-full voice cache back within cache_size.

    The cache is seeded with one more entry than the configured limit, so the
    final assertion only passes if ``_manage_cache`` actually evicts. Filling
    it just up to the limit would let a no-op implementation pass.
    """
    cache_limit = 2
    voice_manager._config.cache_size = cache_limit

    # Refill the manager's existing cache object in place instead of rebinding
    # it to a fresh dict, so any ordered-mapping type the manager chose is kept.
    # Assumes _voice_cache is a mutable mapping — TODO confirm.
    cache = voice_manager._voice_cache
    cache.clear()
    cache.update(
        {
            "voice1_cpu": torch.randn(5, 5),
            "voice2_cpu": torch.randn(5, 5),
            "voice3_cpu": mock_voice_tensor,
        }
    )
    # Precondition: the cache really is over its limit before eviction.
    assert len(cache) > cache_limit

    voice_manager._manage_cache()

    assert len(voice_manager._voice_cache) <= cache_limit
@pytest.mark.asyncio
async def test_voice_loading_with_cache(voice_manager, mock_voice_tensor):
    """A second load of the same voice/device is served from the cache.

    The mocked loader must be called only for the first request.
    """
    with patch("api.src.core.paths.load_voice_tensor", new_callable=AsyncMock) as fake_loader, \
            patch("os.path.exists", return_value=True):
        fake_loader.return_value = mock_voice_tensor

        first = await voice_manager.load_voice("af_bella", "cpu")
        assert fake_loader.call_count == 1  # cold cache: loader used once

        second = await voice_manager.load_voice("af_bella", "cpu")
        assert fake_loader.call_count == 1  # warm cache: no extra loader call

    assert torch.equal(first, second)