mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00
Added basic pytest on the fastapi side
This commit is contained in:
parent
60a19bde43
commit
175daea325
8 changed files with 244 additions and 1 deletions
1
api/__init__.py
Normal file
1
api/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
# Make api directory a Python package
|
|
@ -30,7 +30,7 @@ class TTSRequest(BaseModel):
|
|||
text: str
|
||||
voice: str = "af" # Default voice
|
||||
local: bool = False # Whether to save file locally or return bytes
|
||||
speed: float = 1.0
|
||||
speed: float = Field(default=1.0, gt=0.0, description="Speed multiplier (must be positive)")
|
||||
stitch_long_output: bool = True # Whether to stitch together long outputs
|
||||
|
||||
|
||||
|
|
1
api/tests/__init__.py
Normal file
1
api/tests/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
# Make tests directory a Python package
|
58
api/tests/conftest.py
Normal file
58
api/tests/conftest.py
Normal file
|
@ -0,0 +1,58 @@
|
|||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
import sys
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
# Mock torch and other ML modules before they're imported
|
||||
sys.modules['torch'] = Mock()
|
||||
sys.modules['transformers'] = Mock()
|
||||
sys.modules['phonemizer'] = Mock()
|
||||
sys.modules['models'] = Mock()
|
||||
sys.modules['models.build_model'] = Mock()
|
||||
sys.modules['kokoro'] = Mock()
|
||||
sys.modules['kokoro.generate'] = Mock()
|
||||
sys.modules['kokoro.phonemize'] = Mock()
|
||||
sys.modules['kokoro.tokenize'] = Mock()
|
||||
|
||||
from api.src.database.database import Base, get_db
|
||||
from api.src.main import app
|
||||
|
||||
# Use SQLite for testing
|
||||
SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"
|
||||
engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})
|
||||
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def db():
|
||||
"""Create a fresh database for each test"""
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db = TestingSessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
Base.metadata.drop_all(bind=engine)
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client(db):
|
||||
"""Create a test client with database dependency override"""
|
||||
def override_get_db():
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
app.dependency_overrides[get_db] = override_get_db
|
||||
yield app.dependency_overrides
|
||||
app.dependency_overrides = {}
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_tts_model():
|
||||
"""Mock TTSModel to avoid loading real models during tests"""
|
||||
with patch("api.src.services.tts.TTSModel") as mock:
|
||||
model_instance = Mock()
|
||||
model_instance.get_instance.return_value = model_instance
|
||||
model_instance.get_voicepack.return_value = None
|
||||
mock.get_instance.return_value = model_instance
|
||||
yield model_instance
|
160
api/tests/test_endpoints.py
Normal file
160
api/tests/test_endpoints.py
Normal file
|
@ -0,0 +1,160 @@
|
|||
from fastapi.testclient import TestClient
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from ..src.main import app
|
||||
from ..src.services.tts import TTSService
|
||||
|
||||
# Create test client
|
||||
client = TestClient(app)
|
||||
|
||||
# Mock TTSService methods
|
||||
@pytest.fixture
|
||||
def mock_tts_service():
|
||||
with patch("api.src.routers.tts.TTSService") as mock_service:
|
||||
# Setup mock returns
|
||||
service_instance = Mock()
|
||||
service_instance.list_voices.return_value = ["af", "en"]
|
||||
service_instance.create_tts_request.return_value = 1
|
||||
service_instance.get_request_status.return_value = Mock(
|
||||
id=1,
|
||||
status="completed",
|
||||
output_file="test.wav",
|
||||
processing_time=1.0
|
||||
)
|
||||
mock_service.return_value = service_instance
|
||||
yield service_instance
|
||||
|
||||
def test_health_check():
|
||||
"""Test the health check endpoint"""
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "healthy"}
|
||||
|
||||
def test_list_voices(mock_tts_service):
|
||||
"""Test listing available voices"""
|
||||
response = client.get("/tts/voices")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"voices": ["af", "en"],
|
||||
"default": "af"
|
||||
}
|
||||
|
||||
def test_create_tts_request(mock_tts_service):
|
||||
"""Test creating a TTS request"""
|
||||
test_request = {
|
||||
"text": "Hello world",
|
||||
"voice": "af",
|
||||
"speed": 1.0,
|
||||
"stitch_long_output": True
|
||||
}
|
||||
response = client.post("/tts", json=test_request)
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"request_id": 1,
|
||||
"status": "pending",
|
||||
"output_file": None,
|
||||
"processing_time": None
|
||||
}
|
||||
|
||||
def test_create_tts_invalid_voice(mock_tts_service):
|
||||
"""Test creating a TTS request with invalid voice"""
|
||||
test_request = {
|
||||
"text": "Hello world",
|
||||
"voice": "invalid_voice",
|
||||
"speed": 1.0,
|
||||
"stitch_long_output": True
|
||||
}
|
||||
response = client.post("/tts", json=test_request)
|
||||
assert response.status_code == 400
|
||||
assert "Voice 'invalid_voice' not found" in response.json()["detail"]
|
||||
|
||||
def test_get_tts_status(mock_tts_service):
|
||||
"""Test getting TTS request status"""
|
||||
response = client.get("/tts/1")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"request_id": 1,
|
||||
"status": "completed",
|
||||
"output_file": "test.wav",
|
||||
"processing_time": 1.0
|
||||
}
|
||||
|
||||
def test_get_tts_status_not_found(mock_tts_service):
|
||||
"""Test getting status of non-existent request"""
|
||||
mock_tts_service.get_request_status.return_value = None
|
||||
response = client.get("/tts/999")
|
||||
assert response.status_code == 404
|
||||
assert response.json()["detail"] == "Request not found"
|
||||
|
||||
@patch("builtins.open", create=True)
|
||||
@patch("os.path.exists", return_value=True)
|
||||
def test_get_audio_file(mock_exists, mock_open, mock_tts_service):
|
||||
"""Test downloading audio file"""
|
||||
# Set up mock request status with output file
|
||||
mock_request = Mock(
|
||||
id=1,
|
||||
status="completed", # Must match the status check in router
|
||||
output_file="test.wav",
|
||||
processing_time=1.0
|
||||
)
|
||||
mock_tts_service.get_request_status.return_value = mock_request
|
||||
|
||||
# Mock file read
|
||||
mock_open.return_value.__enter__.return_value.read.return_value = b"audio data"
|
||||
|
||||
response = client.get("/tts/file/1")
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "audio/wav"
|
||||
assert response.headers["content-disposition"] == "attachment; filename=speech_1.wav"
|
||||
assert response.content == b"audio data"
|
||||
|
||||
def test_get_audio_file_not_found(mock_tts_service):
|
||||
"""Test downloading non-existent audio file"""
|
||||
mock_tts_service.get_request_status.return_value = None
|
||||
response = client.get("/tts/file/999")
|
||||
assert response.status_code == 404
|
||||
assert response.json()["detail"] == "Request not found"
|
||||
|
||||
def test_create_tts_invalid_speed(mock_tts_service):
|
||||
"""Test creating a TTS request with invalid speed"""
|
||||
test_request = {
|
||||
"text": "Hello world",
|
||||
"voice": "af",
|
||||
"speed": -1.0, # Invalid speed
|
||||
"stitch_long_output": True
|
||||
}
|
||||
response = client.post("/tts", json=test_request)
|
||||
assert response.status_code == 422 # Validation error
|
||||
|
||||
def test_get_audio_file_not_completed(mock_tts_service):
|
||||
"""Test getting audio file for request that's still processing"""
|
||||
mock_request = Mock(
|
||||
id=1,
|
||||
status="processing", # Not completed yet
|
||||
output_file=None,
|
||||
processing_time=None
|
||||
)
|
||||
mock_tts_service.get_request_status.return_value = mock_request
|
||||
|
||||
response = client.get("/tts/file/1")
|
||||
assert response.status_code == 400
|
||||
assert response.json()["detail"] == "Audio generation not complete"
|
||||
|
||||
def test_get_tts_status_processing(mock_tts_service):
|
||||
"""Test getting status of a processing request"""
|
||||
mock_request = Mock(
|
||||
id=1,
|
||||
status="processing",
|
||||
output_file=None,
|
||||
processing_time=None
|
||||
)
|
||||
mock_tts_service.get_request_status.return_value = mock_request
|
||||
|
||||
response = client.get("/tts/1")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"request_id": 1,
|
||||
"status": "processing",
|
||||
"output_file": None,
|
||||
"processing_time": None
|
||||
}
|
5
pytest.ini
Normal file
5
pytest.ini
Normal file
|
@ -0,0 +1,5 @@
|
|||
[pytest]
|
||||
testpaths = api/tests
|
||||
python_files = test_*.py
|
||||
addopts = -v --tb=short
|
||||
pythonpath = .
|
12
requirements-test.txt
Normal file
12
requirements-test.txt
Normal file
|
@ -0,0 +1,12 @@
|
|||
# Core dependencies for testing
|
||||
fastapi==0.115.6
|
||||
uvicorn==0.34.0
|
||||
pydantic==2.10.4
|
||||
pydantic-settings==2.7.0
|
||||
python-dotenv==1.0.1
|
||||
sqlalchemy==2.0.27
|
||||
|
||||
# Testing
|
||||
pytest==8.0.0
|
||||
httpx==0.26.0
|
||||
pytest-asyncio==0.23.5
|
|
@ -1,3 +1,4 @@
|
|||
# Primarily for reference, as Dockerfile refer
|
||||
# Core dependencies
|
||||
fastapi==0.115.6
|
||||
uvicorn==0.34.0
|
||||
|
@ -19,3 +20,8 @@ regex==2024.11.6
|
|||
tqdm==4.67.1
|
||||
requests==2.32.3
|
||||
munch==4.0.0
|
||||
|
||||
# Testing
|
||||
pytest==8.0.0
|
||||
httpx==0.26.0
|
||||
pytest-asyncio==0.23.5
|
||||
|
|
Loading…
Add table
Reference in a new issue