Added basic pytest on the fastapi side

This commit is contained in:
remsky 2024-12-30 13:25:30 -07:00
parent 60a19bde43
commit 175daea325
8 changed files with 244 additions and 1 deletions

1
api/__init__.py Normal file
View file

@ -0,0 +1 @@
# Make api directory a Python package

View file

@ -30,7 +30,7 @@ class TTSRequest(BaseModel):
text: str
voice: str = "af" # Default voice
local: bool = False # Whether to save file locally or return bytes
speed: float = 1.0
speed: float = Field(default=1.0, gt=0.0, description="Speed multiplier (must be positive)")
stitch_long_output: bool = True # Whether to stitch together long outputs

1
api/tests/__init__.py Normal file
View file

@ -0,0 +1 @@
# Make tests directory a Python package

58
api/tests/conftest.py Normal file
View file

@ -0,0 +1,58 @@
import pytest
from unittest.mock import Mock, patch
import sys
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
# Mock torch and other ML modules before they're imported
sys.modules['torch'] = Mock()
sys.modules['transformers'] = Mock()
sys.modules['phonemizer'] = Mock()
sys.modules['models'] = Mock()
sys.modules['models.build_model'] = Mock()
sys.modules['kokoro'] = Mock()
sys.modules['kokoro.generate'] = Mock()
sys.modules['kokoro.phonemize'] = Mock()
sys.modules['kokoro.tokenize'] = Mock()
from api.src.database.database import Base, get_db
from api.src.main import app
# Use SQLite for testing
SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"
engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
@pytest.fixture(scope="function")
def db():
"""Create a fresh database for each test"""
Base.metadata.create_all(bind=engine)
db = TestingSessionLocal()
try:
yield db
finally:
db.close()
Base.metadata.drop_all(bind=engine)
@pytest.fixture(scope="function")
def client(db):
"""Create a test client with database dependency override"""
def override_get_db():
try:
yield db
finally:
db.close()
app.dependency_overrides[get_db] = override_get_db
yield app.dependency_overrides
app.dependency_overrides = {}
@pytest.fixture(autouse=True)
def mock_tts_model():
"""Mock TTSModel to avoid loading real models during tests"""
with patch("api.src.services.tts.TTSModel") as mock:
model_instance = Mock()
model_instance.get_instance.return_value = model_instance
model_instance.get_voicepack.return_value = None
mock.get_instance.return_value = model_instance
yield model_instance

160
api/tests/test_endpoints.py Normal file
View file

@ -0,0 +1,160 @@
from fastapi.testclient import TestClient
import pytest
from unittest.mock import Mock, patch
from ..src.main import app
from ..src.services.tts import TTSService
# Create test client
client = TestClient(app)
# Mock TTSService methods
@pytest.fixture
def mock_tts_service():
with patch("api.src.routers.tts.TTSService") as mock_service:
# Setup mock returns
service_instance = Mock()
service_instance.list_voices.return_value = ["af", "en"]
service_instance.create_tts_request.return_value = 1
service_instance.get_request_status.return_value = Mock(
id=1,
status="completed",
output_file="test.wav",
processing_time=1.0
)
mock_service.return_value = service_instance
yield service_instance
def test_health_check():
"""Test the health check endpoint"""
response = client.get("/health")
assert response.status_code == 200
assert response.json() == {"status": "healthy"}
def test_list_voices(mock_tts_service):
"""Test listing available voices"""
response = client.get("/tts/voices")
assert response.status_code == 200
assert response.json() == {
"voices": ["af", "en"],
"default": "af"
}
def test_create_tts_request(mock_tts_service):
"""Test creating a TTS request"""
test_request = {
"text": "Hello world",
"voice": "af",
"speed": 1.0,
"stitch_long_output": True
}
response = client.post("/tts", json=test_request)
assert response.status_code == 200
assert response.json() == {
"request_id": 1,
"status": "pending",
"output_file": None,
"processing_time": None
}
def test_create_tts_invalid_voice(mock_tts_service):
"""Test creating a TTS request with invalid voice"""
test_request = {
"text": "Hello world",
"voice": "invalid_voice",
"speed": 1.0,
"stitch_long_output": True
}
response = client.post("/tts", json=test_request)
assert response.status_code == 400
assert "Voice 'invalid_voice' not found" in response.json()["detail"]
def test_get_tts_status(mock_tts_service):
"""Test getting TTS request status"""
response = client.get("/tts/1")
assert response.status_code == 200
assert response.json() == {
"request_id": 1,
"status": "completed",
"output_file": "test.wav",
"processing_time": 1.0
}
def test_get_tts_status_not_found(mock_tts_service):
"""Test getting status of non-existent request"""
mock_tts_service.get_request_status.return_value = None
response = client.get("/tts/999")
assert response.status_code == 404
assert response.json()["detail"] == "Request not found"
@patch("builtins.open", create=True)
@patch("os.path.exists", return_value=True)
def test_get_audio_file(mock_exists, mock_open, mock_tts_service):
"""Test downloading audio file"""
# Set up mock request status with output file
mock_request = Mock(
id=1,
status="completed", # Must match the status check in router
output_file="test.wav",
processing_time=1.0
)
mock_tts_service.get_request_status.return_value = mock_request
# Mock file read
mock_open.return_value.__enter__.return_value.read.return_value = b"audio data"
response = client.get("/tts/file/1")
assert response.status_code == 200
assert response.headers["content-type"] == "audio/wav"
assert response.headers["content-disposition"] == "attachment; filename=speech_1.wav"
assert response.content == b"audio data"
def test_get_audio_file_not_found(mock_tts_service):
"""Test downloading non-existent audio file"""
mock_tts_service.get_request_status.return_value = None
response = client.get("/tts/file/999")
assert response.status_code == 404
assert response.json()["detail"] == "Request not found"
def test_create_tts_invalid_speed(mock_tts_service):
"""Test creating a TTS request with invalid speed"""
test_request = {
"text": "Hello world",
"voice": "af",
"speed": -1.0, # Invalid speed
"stitch_long_output": True
}
response = client.post("/tts", json=test_request)
assert response.status_code == 422 # Validation error
def test_get_audio_file_not_completed(mock_tts_service):
"""Test getting audio file for request that's still processing"""
mock_request = Mock(
id=1,
status="processing", # Not completed yet
output_file=None,
processing_time=None
)
mock_tts_service.get_request_status.return_value = mock_request
response = client.get("/tts/file/1")
assert response.status_code == 400
assert response.json()["detail"] == "Audio generation not complete"
def test_get_tts_status_processing(mock_tts_service):
"""Test getting status of a processing request"""
mock_request = Mock(
id=1,
status="processing",
output_file=None,
processing_time=None
)
mock_tts_service.get_request_status.return_value = mock_request
response = client.get("/tts/1")
assert response.status_code == 200
assert response.json() == {
"request_id": 1,
"status": "processing",
"output_file": None,
"processing_time": None
}

5
pytest.ini Normal file
View file

@ -0,0 +1,5 @@
[pytest]
testpaths = api/tests
python_files = test_*.py
addopts = -v --tb=short
pythonpath = .

12
requirements-test.txt Normal file
View file

@ -0,0 +1,12 @@
# Core dependencies for testing
fastapi==0.115.6
uvicorn==0.34.0
pydantic==2.10.4
pydantic-settings==2.7.0
python-dotenv==1.0.1
sqlalchemy==2.0.27
# Testing
pytest==8.0.0
httpx==0.26.0
pytest-asyncio==0.23.5

View file

@ -1,3 +1,4 @@
# Primarily for reference, as Dockerfile refer
# Core dependencies
fastapi==0.115.6
uvicorn==0.34.0
@ -19,3 +20,8 @@ regex==2024.11.6
tqdm==4.67.1
requests==2.32.3
munch==4.0.0
# Testing
pytest==8.0.0
httpx==0.26.0
pytest-asyncio==0.23.5