"""Tests for FastAPI application""" from unittest.mock import MagicMock, patch import pytest from fastapi.testclient import TestClient from api.src.main import app, lifespan @pytest.fixture def test_client(): """Create a test client""" return TestClient(app) def test_health_check(test_client): """Test health check endpoint""" response = test_client.get("/health") assert response.status_code == 200 assert response.json() == {"status": "healthy"} @pytest.mark.asyncio @patch("api.src.main.TTSModel") @patch("api.src.main.logger") async def test_lifespan_successful_warmup(mock_logger, mock_tts_model): """Test successful model warmup in lifespan""" # Mock the model initialization with model info and voicepack count mock_model = MagicMock() # Mock file system for voice counting mock_tts_model.VOICES_DIR = "/mock/voices" with patch("os.listdir", return_value=["voice1.pt", "voice2.pt", "voice3.pt"]): mock_tts_model.initialize.return_value = (mock_model, 3) # 3 voice files mock_tts_model._device = "cuda" # Set device class variable # Create an async generator from the lifespan context manager async_gen = lifespan(MagicMock()) # Start the context manager await async_gen.__aenter__() # Verify the expected logging sequence mock_logger.info.assert_any_call("Loading TTS model and voice packs...") mock_logger.info.assert_any_call("Model loaded and warmed up on cuda") mock_logger.info.assert_any_call("3 voice packs loaded successfully") # Verify model initialization was called mock_tts_model.initialize.assert_called_once() # Clean up await async_gen.__aexit__(None, None, None) @pytest.mark.asyncio @patch("api.src.main.TTSModel") @patch("api.src.main.logger") async def test_lifespan_failed_warmup(mock_logger, mock_tts_model): """Test failed model warmup in lifespan""" # Mock the model initialization to fail mock_tts_model.initialize.side_effect = Exception("Failed to initialize model") # Create an async generator from the lifespan context manager async_gen = lifespan(MagicMock()) # Verify the exception is raised with pytest.raises(Exception, match="Failed to initialize model"): await async_gen.__aenter__() # Verify the expected logging sequence mock_logger.info.assert_called_with("Loading TTS model and voice packs...") # Clean up await async_gen.__aexit__(None, None, None) @pytest.mark.asyncio @patch("api.src.main.TTSModel") async def test_lifespan_cuda_warmup(mock_tts_model): """Test model warmup specifically on CUDA""" # Mock the model initialization with CUDA and voicepacks mock_model = MagicMock() # Mock file system for voice counting mock_tts_model.VOICES_DIR = "/mock/voices" with patch("os.listdir", return_value=["voice1.pt", "voice2.pt"]): mock_tts_model.initialize.return_value = (mock_model, 2) # 2 voice files mock_tts_model._device = "cuda" # Set device class variable # Create an async generator from the lifespan context manager async_gen = lifespan(MagicMock()) await async_gen.__aenter__() # Verify model was initialized mock_tts_model.initialize.assert_called_once() # Clean up await async_gen.__aexit__(None, None, None) @pytest.mark.asyncio @patch("api.src.main.TTSModel") async def test_lifespan_cpu_fallback(mock_tts_model): """Test model warmup falling back to CPU""" # Mock the model initialization with CPU and voicepacks mock_model = MagicMock() # Mock file system for voice counting mock_tts_model.VOICES_DIR = "/mock/voices" with patch( "os.listdir", return_value=["voice1.pt", "voice2.pt", "voice3.pt", "voice4.pt"] ): mock_tts_model.initialize.return_value = (mock_model, 4) # 4 voice files mock_tts_model._device = "cpu" # Set device class variable # Create an async generator from the lifespan context manager async_gen = lifespan(MagicMock()) await async_gen.__aenter__() # Verify model was initialized mock_tts_model.initialize.assert_called_once() # Clean up await async_gen.__aexit__(None, None, None)