2025-01-01 17:38:22 -07:00
|
|
|
"""Tests for FastAPI application"""
|
2025-01-01 21:50:41 -07:00
|
|
|
|
|
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
|
2024-12-31 02:55:51 -07:00
|
|
|
import pytest
|
|
|
|
from fastapi.testclient import TestClient
|
2025-01-01 21:50:41 -07:00
|
|
|
|
2025-01-01 17:38:22 -07:00
|
|
|
from api.src.main import app, lifespan
|
2024-12-31 02:55:51 -07:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def test_client():
    """Provide a ``TestClient`` wrapping the FastAPI ``app`` under test."""
    # The client is returned without being entered as a context manager,
    # exactly as before.
    client = TestClient(app)
    return client
|
|
|
|
|
|
|
|
|
2025-01-01 17:38:22 -07:00
|
|
|
def test_health_check(test_client):
    """GET ``/health`` must answer 200 with the ``healthy`` status payload."""
    expected_payload = {"status": "healthy"}

    reply = test_client.get("/health")

    assert reply.status_code == 200
    assert reply.json() == expected_payload
|
|
|
|
|
|
|
|
|
2025-01-01 17:38:22 -07:00
|
|
|
@pytest.mark.asyncio
@patch("api.src.main.TTSModel")
@patch("api.src.main.logger")
async def test_lifespan_successful_warmup(mock_logger, mock_tts_model):
    """Test successful model warmup in lifespan.

    ``api.src.main.TTSModel`` and ``api.src.main.logger`` are replaced by
    mocks. ``patch`` decorators are applied bottom-up, so the logger mock is
    the first argument and the model mock the second.

    The test enters the lifespan context and verifies that the three expected
    info messages were logged and that ``TTSModel.initialize`` was called
    exactly once.
    """
    # Mock the model initialization with model info and voicepack count
    mock_model = MagicMock()
    mock_tts_model.VOICES_DIR = "/mock/voices"
    mock_tts_model.initialize.return_value = (mock_model, 3)  # 3 voice files
    mock_tts_model._device = "cuda"  # Set device class variable

    # Mock file system for voice counting
    voice_files = ["voice1.pt", "voice2.pt", "voice3.pt"]
    with patch("os.listdir", return_value=voice_files):
        # ``async with`` guarantees the lifespan context is exited even when
        # one of the assertions below fails, unlike a manual
        # ``__aenter__`` / ``__aexit__`` pair.
        async with lifespan(MagicMock()):
            # Verify the expected logging sequence
            mock_logger.info.assert_any_call("Loading TTS model and voice packs...")
            mock_logger.info.assert_any_call("Model loaded and warmed up on cuda")
            mock_logger.info.assert_any_call("3 voice packs loaded successfully")

            # Verify model initialization was called
            mock_tts_model.initialize.assert_called_once()
|
2024-12-31 02:55:51 -07:00
|
|
|
|
|
|
|
|
2025-01-01 17:38:22 -07:00
|
|
|
@pytest.mark.asyncio
@patch("api.src.main.TTSModel")
@patch("api.src.main.logger")
async def test_lifespan_failed_warmup(mock_logger, mock_tts_model):
    """Test failed model warmup in lifespan.

    ``TTSModel.initialize`` is made to raise, and the test checks that the
    error propagates out of entering the lifespan context and that the most
    recent info log call is the "loading" message.
    """
    # Make model initialization blow up
    init_failure = Exception("Failed to initialize model")
    mock_tts_model.initialize.side_effect = init_failure

    # Build the lifespan context manager without entering it yet
    lifespan_ctx = lifespan(MagicMock())

    # Entering the context must surface the initialization error
    with pytest.raises(Exception, match="Failed to initialize model"):
        await lifespan_ctx.__aenter__()

    # The last info message logged must be the loading notice
    mock_logger.info.assert_called_with("Loading TTS model and voice packs...")

    # Clean up
    # NOTE(review): entering already failed, so this exit call is presumably
    # a no-op -- confirm it is still wanted.
    await lifespan_ctx.__aexit__(None, None, None)
|
2024-12-31 02:55:51 -07:00
|
|
|
|
|
|
|
|
2025-01-01 17:38:22 -07:00
|
|
|
@pytest.mark.asyncio
@patch("api.src.main.TTSModel")
async def test_lifespan_cuda_warmup(mock_tts_model):
    """Test model warmup specifically on CUDA.

    ``api.src.main.TTSModel`` is replaced by a mock whose ``_device`` is
    ``"cuda"`` and whose ``initialize`` reports two voice packs. The test
    enters the lifespan context and verifies ``initialize`` was called
    exactly once.
    """
    # Mock the model initialization with CUDA and voicepacks
    mock_model = MagicMock()
    mock_tts_model.VOICES_DIR = "/mock/voices"
    mock_tts_model.initialize.return_value = (mock_model, 2)  # 2 voice files
    mock_tts_model._device = "cuda"  # Set device class variable

    # Mock file system for voice counting
    with patch("os.listdir", return_value=["voice1.pt", "voice2.pt"]):
        # ``async with`` guarantees the lifespan context is exited even when
        # the assertion below fails.
        async with lifespan(MagicMock()):
            # Verify model was initialized
            mock_tts_model.initialize.assert_called_once()
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
@patch("api.src.main.TTSModel")
async def test_lifespan_cpu_fallback(mock_tts_model):
    """Test model warmup falling back to CPU.

    ``api.src.main.TTSModel`` is replaced by a mock whose ``_device`` is
    ``"cpu"`` and whose ``initialize`` reports four voice packs. The test
    enters the lifespan context and verifies ``initialize`` was called
    exactly once.
    """
    # Mock the model initialization with CPU and voicepacks
    mock_model = MagicMock()
    mock_tts_model.VOICES_DIR = "/mock/voices"
    mock_tts_model.initialize.return_value = (mock_model, 4)  # 4 voice files
    mock_tts_model._device = "cpu"  # Set device class variable

    # Mock file system for voice counting
    voice_files = ["voice1.pt", "voice2.pt", "voice3.pt", "voice4.pt"]
    with patch("os.listdir", return_value=voice_files):
        # ``async with`` guarantees the lifespan context is exited even when
        # the assertion below fails.
        async with lifespan(MagicMock()):
            # Verify model was initialized
            mock_tts_model.initialize.assert_called_once()
|