mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00
Fixes for running the unit tests on windows.
This commit is contained in:
parent
29066f7c9f
commit
eac7ab4449
3 changed files with 13 additions and 12 deletions
|
@ -1,6 +1,6 @@
|
||||||
|
import os
|
||||||
from unittest.mock import ANY, MagicMock, patch
|
from unittest.mock import ANY, MagicMock, patch
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
@ -50,7 +50,7 @@ def test_clear_memory(mock_sync, mock_clear, kokoro_backend):
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_load_model_validation(kokoro_backend):
|
async def test_load_model_validation(kokoro_backend):
|
||||||
"""Test model loading validation."""
|
"""Test model loading validation."""
|
||||||
with pytest.raises(RuntimeError, match="Failed to load Kokoro model"):
|
with pytest.raises(FileNotFoundError):
|
||||||
await kokoro_backend.load_model("nonexistent_model.pth")
|
await kokoro_backend.load_model("nonexistent_model.pth")
|
||||||
|
|
||||||
|
|
||||||
|
@ -140,7 +140,7 @@ async def test_generate_uses_correct_pipeline(kokoro_backend):
|
||||||
patch("tempfile.gettempdir") as mock_tempdir,
|
patch("tempfile.gettempdir") as mock_tempdir,
|
||||||
):
|
):
|
||||||
mock_load_voice.return_value = torch.ones(1)
|
mock_load_voice.return_value = torch.ones(1)
|
||||||
mock_tempdir.return_value = "/tmp"
|
mock_tempdir.return_value = f"{os.sep}tmp"
|
||||||
|
|
||||||
# Mock KPipeline
|
# Mock KPipeline
|
||||||
mock_pipeline = MagicMock()
|
mock_pipeline = MagicMock()
|
||||||
|
@ -162,4 +162,4 @@ async def test_generate_uses_correct_pipeline(kokoro_backend):
|
||||||
# Verify the voice path is a temp file path
|
# Verify the voice path is a temp file path
|
||||||
call_args = mock_pipeline.call_args
|
call_args = mock_pipeline.call_args
|
||||||
assert isinstance(call_args[1]["voice"], str)
|
assert isinstance(call_args[1]["voice"], str)
|
||||||
assert call_args[1]["voice"].startswith("/tmp/temp_voice_")
|
assert call_args[1]["voice"].startswith(f"{os.sep}tmp{os.sep}temp_voice_")
|
||||||
|
|
|
@ -18,8 +18,8 @@ async def test_find_file_exists():
|
||||||
"""Test finding existing file."""
|
"""Test finding existing file."""
|
||||||
with patch("aiofiles.os.path.exists") as mock_exists:
|
with patch("aiofiles.os.path.exists") as mock_exists:
|
||||||
mock_exists.return_value = True
|
mock_exists.return_value = True
|
||||||
path = await _find_file("test.txt", ["/test/path"])
|
path = await _find_file("test.txt", [f"{os.sep}test{os.sep}path"])
|
||||||
assert path == "/test/path/test.txt"
|
assert path == f"{os.sep}test{os.sep}path{os.sep}test.txt"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@ -37,8 +37,8 @@ async def test_find_file_with_filter():
|
||||||
with patch("aiofiles.os.path.exists") as mock_exists:
|
with patch("aiofiles.os.path.exists") as mock_exists:
|
||||||
mock_exists.return_value = True
|
mock_exists.return_value = True
|
||||||
filter_fn = lambda p: p.endswith(".txt")
|
filter_fn = lambda p: p.endswith(".txt")
|
||||||
path = await _find_file("test.txt", ["/test/path"], filter_fn)
|
path = await _find_file("test.txt", [f"{os.sep}test{os.sep}path"], filter_fn)
|
||||||
assert path == "/test/path/test.txt"
|
assert path == f"{os.sep}test{os.sep}path{os.sep}test.txt"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
|
@ -1,9 +1,8 @@
|
||||||
|
import os
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import os
|
|
||||||
|
|
||||||
from api.src.services.tts_service import TTSService
|
from api.src.services.tts_service import TTSService
|
||||||
|
|
||||||
|
@ -86,6 +85,7 @@ async def test_get_voice_path_combined():
|
||||||
"""Test getting path for combined voices."""
|
"""Test getting path for combined voices."""
|
||||||
model_manager = AsyncMock()
|
model_manager = AsyncMock()
|
||||||
voice_manager = AsyncMock()
|
voice_manager = AsyncMock()
|
||||||
|
|
||||||
voice_manager.get_voice_path.return_value = "/path/to/voice.pt"
|
voice_manager.get_voice_path.return_value = "/path/to/voice.pt"
|
||||||
|
|
||||||
with (
|
with (
|
||||||
|
@ -97,14 +97,15 @@ async def test_get_voice_path_combined():
|
||||||
):
|
):
|
||||||
mock_get_model.return_value = model_manager
|
mock_get_model.return_value = model_manager
|
||||||
mock_get_voice.return_value = voice_manager
|
mock_get_voice.return_value = voice_manager
|
||||||
mock_temp.return_value = "/tmp"
|
mock_temp.return_value = f"{os.sep}tmp"
|
||||||
mock_load.return_value = torch.ones(10)
|
mock_load.return_value = torch.ones(10)
|
||||||
|
|
||||||
service = await TTSService.create("test_output")
|
service = await TTSService.create("test_output")
|
||||||
name, path = await service._get_voices_path("voice1+voice2")
|
name, path = await service._get_voices_path("voice1+voice2")
|
||||||
assert name == "voice1+voice2"
|
assert name == "voice1+voice2"
|
||||||
|
print(f"{path=}")
|
||||||
# Verify the path points to a temporary file with expected format
|
# Verify the path points to a temporary file with expected format
|
||||||
assert path.startswith("/tmp/")
|
assert path.startswith(f"{os.sep}tmp{os.sep}")
|
||||||
assert "voice1+voice2" in path
|
assert "voice1+voice2" in path
|
||||||
assert path.endswith(".pt")
|
assert path.endswith(".pt")
|
||||||
mock_save.assert_called_once()
|
mock_save.assert_called_once()
|
||||||
|
|
Loading…
Add table
Reference in a new issue