From 175daea3257b86a250869b84143f5b5c61a5d681 Mon Sep 17 00:00:00 2001 From: remsky Date: Mon, 30 Dec 2024 13:25:30 -0700 Subject: [PATCH] Added basic pytest on the fastapi side --- api/__init__.py | 1 + api/src/models/schemas.py | 2 +- api/tests/__init__.py | 1 + api/tests/conftest.py | 58 +++++++++++++ api/tests/test_endpoints.py | 160 ++++++++++++++++++++++++++++++++++++ pytest.ini | 5 ++ requirements-test.txt | 12 +++ requirements.txt | 6 ++ 8 files changed, 244 insertions(+), 1 deletion(-) create mode 100644 api/__init__.py create mode 100644 api/tests/__init__.py create mode 100644 api/tests/conftest.py create mode 100644 api/tests/test_endpoints.py create mode 100644 pytest.ini create mode 100644 requirements-test.txt diff --git a/api/__init__.py b/api/__init__.py new file mode 100644 index 0000000..d4d3984 --- /dev/null +++ b/api/__init__.py @@ -0,0 +1 @@ +# Make api directory a Python package diff --git a/api/src/models/schemas.py b/api/src/models/schemas.py index 3458d14..228388c 100644 --- a/api/src/models/schemas.py +++ b/api/src/models/schemas.py @@ -30,7 +30,7 @@ class TTSRequest(BaseModel): text: str voice: str = "af" # Default voice local: bool = False # Whether to save file locally or return bytes - speed: float = 1.0 + speed: float = Field(default=1.0, gt=0.0, description="Speed multiplier (must be positive)") stitch_long_output: bool = True # Whether to stitch together long outputs diff --git a/api/tests/__init__.py b/api/tests/__init__.py new file mode 100644 index 0000000..b9911d8 --- /dev/null +++ b/api/tests/__init__.py @@ -0,0 +1 @@ +# Make tests directory a Python package diff --git a/api/tests/conftest.py b/api/tests/conftest.py new file mode 100644 index 0000000..015adec --- /dev/null +++ b/api/tests/conftest.py @@ -0,0 +1,58 @@ +import pytest +from unittest.mock import Mock, patch +import sys +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +# Mock torch and other ML modules before they're imported +sys.modules['torch'] = Mock() +sys.modules['transformers'] = Mock() +sys.modules['phonemizer'] = Mock() +sys.modules['models'] = Mock() +sys.modules['models.build_model'] = Mock() +sys.modules['kokoro'] = Mock() +sys.modules['kokoro.generate'] = Mock() +sys.modules['kokoro.phonemize'] = Mock() +sys.modules['kokoro.tokenize'] = Mock() + +from api.src.database.database import Base, get_db +from api.src.main import app + +# Use SQLite for testing +SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db" +engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}) +TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +@pytest.fixture(scope="function") +def db(): + """Create a fresh database for each test""" + Base.metadata.create_all(bind=engine) + db = TestingSessionLocal() + try: + yield db + finally: + db.close() + Base.metadata.drop_all(bind=engine) + +@pytest.fixture(scope="function") +def client(db): + """Create a test client with database dependency override""" + def override_get_db(): + try: + yield db + finally: + db.close() + + app.dependency_overrides[get_db] = override_get_db + yield app.dependency_overrides + app.dependency_overrides = {} + +@pytest.fixture(autouse=True) +def mock_tts_model(): + """Mock TTSModel to avoid loading real models during tests""" + with patch("api.src.services.tts.TTSModel") as mock: + model_instance = Mock() + model_instance.get_instance.return_value = model_instance + model_instance.get_voicepack.return_value = None + mock.get_instance.return_value = model_instance + yield model_instance diff --git a/api/tests/test_endpoints.py b/api/tests/test_endpoints.py new file mode 100644 index 0000000..fd9cf78 --- /dev/null +++ b/api/tests/test_endpoints.py @@ -0,0 +1,160 @@ +from fastapi.testclient import TestClient +import pytest +from unittest.mock import Mock, patch +from ..src.main import app +from ..src.services.tts import TTSService + +# Create test client +client = TestClient(app) + +# Mock TTSService methods +@pytest.fixture +def mock_tts_service(): + with patch("api.src.routers.tts.TTSService") as mock_service: + # Setup mock returns + service_instance = Mock() + service_instance.list_voices.return_value = ["af", "en"] + service_instance.create_tts_request.return_value = 1 + service_instance.get_request_status.return_value = Mock( + id=1, + status="completed", + output_file="test.wav", + processing_time=1.0 + ) + mock_service.return_value = service_instance + yield service_instance + +def test_health_check(): + """Test the health check endpoint""" + response = client.get("/health") + assert response.status_code == 200 + assert response.json() == {"status": "healthy"} + +def test_list_voices(mock_tts_service): + """Test listing available voices""" + response = client.get("/tts/voices") + assert response.status_code == 200 + assert response.json() == { + "voices": ["af", "en"], + "default": "af" + } + +def test_create_tts_request(mock_tts_service): + """Test creating a TTS request""" + test_request = { + "text": "Hello world", + "voice": "af", + "speed": 1.0, + "stitch_long_output": True + } + response = client.post("/tts", json=test_request) + assert response.status_code == 200 + assert response.json() == { + "request_id": 1, + "status": "pending", + "output_file": None, + "processing_time": None + } + +def test_create_tts_invalid_voice(mock_tts_service): + """Test creating a TTS request with invalid voice""" + test_request = { + "text": "Hello world", + "voice": "invalid_voice", + "speed": 1.0, + "stitch_long_output": True + } + response = client.post("/tts", json=test_request) + assert response.status_code == 400 + assert "Voice 'invalid_voice' not found" in response.json()["detail"] + +def test_get_tts_status(mock_tts_service): + """Test getting TTS request status""" + response = client.get("/tts/1") + assert response.status_code == 200 + assert response.json() == { + "request_id": 1, + "status": "completed", + "output_file": "test.wav", + "processing_time": 1.0 + } + +def test_get_tts_status_not_found(mock_tts_service): + """Test getting status of non-existent request""" + mock_tts_service.get_request_status.return_value = None + response = client.get("/tts/999") + assert response.status_code == 404 + assert response.json()["detail"] == "Request not found" + +@patch("builtins.open", create=True) +@patch("os.path.exists", return_value=True) +def test_get_audio_file(mock_exists, mock_open, mock_tts_service): + """Test downloading audio file""" + # Set up mock request status with output file + mock_request = Mock( + id=1, + status="completed", # Must match the status check in router + output_file="test.wav", + processing_time=1.0 + ) + mock_tts_service.get_request_status.return_value = mock_request + + # Mock file read + mock_open.return_value.__enter__.return_value.read.return_value = b"audio data" + + response = client.get("/tts/file/1") + assert response.status_code == 200 + assert response.headers["content-type"] == "audio/wav" + assert response.headers["content-disposition"] == "attachment; filename=speech_1.wav" + assert response.content == b"audio data" + +def test_get_audio_file_not_found(mock_tts_service): + """Test downloading non-existent audio file""" + mock_tts_service.get_request_status.return_value = None + response = client.get("/tts/file/999") + assert response.status_code == 404 + assert response.json()["detail"] == "Request not found" + +def test_create_tts_invalid_speed(mock_tts_service): + """Test creating a TTS request with invalid speed""" + test_request = { + "text": "Hello world", + "voice": "af", + "speed": -1.0, # Invalid speed + "stitch_long_output": True + } + response = client.post("/tts", json=test_request) + assert response.status_code == 422 # Validation error + +def test_get_audio_file_not_completed(mock_tts_service): + """Test getting audio file for request that's still processing""" + mock_request = Mock( + id=1, + status="processing", # Not completed yet + output_file=None, + processing_time=None + ) + mock_tts_service.get_request_status.return_value = mock_request + + response = client.get("/tts/file/1") + assert response.status_code == 400 + assert response.json()["detail"] == "Audio generation not complete" + +def test_get_tts_status_processing(mock_tts_service): + """Test getting status of a processing request""" + mock_request = Mock( + id=1, + status="processing", + output_file=None, + processing_time=None + ) + mock_tts_service.get_request_status.return_value = mock_request + + response = client.get("/tts/1") + assert response.status_code == 200 + assert response.json() == { + "request_id": 1, + "status": "processing", + "output_file": None, + "processing_time": None + } diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..e7ea054 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,5 @@ +[pytest] +testpaths = api/tests +python_files = test_*.py +addopts = -v --tb=short +pythonpath = . diff --git a/requirements-test.txt b/requirements-test.txt new file mode 100644 index 0000000..53135bd --- /dev/null +++ b/requirements-test.txt @@ -0,0 +1,12 @@ +# Core dependencies for testing +fastapi==0.115.6 +uvicorn==0.34.0 +pydantic==2.10.4 +pydantic-settings==2.7.0 +python-dotenv==1.0.1 +sqlalchemy==2.0.27 + +# Testing +pytest==8.0.0 +httpx==0.26.0 +pytest-asyncio==0.23.5 diff --git a/requirements.txt b/requirements.txt index a5a314d..253d591 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +# Primarily for reference, as Dockerfile refer # Core dependencies fastapi==0.115.6 uvicorn==0.34.0 @@ -19,3 +20,8 @@ regex==2024.11.6 tqdm==4.67.1 requests==2.32.3 munch==4.0.0 + +# Testing +pytest==8.0.0 +httpx==0.26.0 +pytest-asyncio==0.23.5