2024-12-30 13:25:30 -07:00
|
|
|
from fastapi.testclient import TestClient
|
|
|
|
import pytest
|
2024-12-31 01:57:00 -07:00
|
|
|
from unittest.mock import Mock
|
2024-12-30 13:25:30 -07:00
|
|
|
from ..src.main import app
|
|
|
|
|
|
|
|
# Module-wide HTTP client shared by every test in this file. It is created
# outside a ``with`` block, so Starlette does not run the app's lifespan /
# startup handlers for it.
client = TestClient(app=app)
|
|
|
|
|
2024-12-31 01:57:00 -07:00
|
|
|
|
2024-12-31 01:52:16 -07:00
|
|
|
# Mock services
@pytest.fixture
def mock_tts_service(monkeypatch):
    """Swap the router's ``TTSService`` class for a factory returning one shared Mock.

    Any construction of ``TTSService`` inside
    ``api.src.routers.openai_compatible`` yields the same stub. A test can
    therefore drive the endpoint and then inspect the calls the stub received.

    Returns:
        The shared ``Mock`` standing in for the TTS service.
    """
    known_voices = [
        "af",
        "bm_lewis",
        "bf_isabella",
        "bf_emma",
        "af_sarah",
        "af_bella",
        "am_adam",
        "am_michael",
        "bm_george",
    ]
    # Dotted keys are applied via Mock.configure_mock: each sets the
    # return_value of the named child mock.
    tts_stub = Mock(
        **{
            # (bytes, float) pair; NOTE(review): confirm what the float means
            # in the real TTSService._generate_audio return value.
            "_generate_audio.return_value": (bytes([0, 1, 2, 3]), 1.0),
            "list_voices.return_value": known_voices,
        }
    )

    def _build_stub(*_args, **_kwargs):
        # Constructor arguments are ignored; every instantiation shares one stub.
        return tts_stub

    monkeypatch.setattr("api.src.routers.openai_compatible.TTSService", _build_stub)
    return tts_stub
|
|
|
|
|
2024-12-31 01:57:00 -07:00
|
|
|
|
2024-12-31 01:52:16 -07:00
|
|
|
@pytest.fixture
def mock_audio_service(monkeypatch):
    """Replace ``AudioService.convert_audio`` as seen by the OpenAI router.

    The replacement ignores its inputs and always returns the fixed payload
    ``b"converted mock audio data"``, so endpoint tests can assert on the
    exact response body.
    """

    def mock_convert(*args, **kwargs):
        # Accept keyword arguments as well as positional ones. A positional-only
        # stub would raise TypeError inside the stub if the router passes any
        # argument by keyword, masking what the test is meant to check. If the
        # router calls this on an instance instead of the class, the bound
        # ``self`` also lands in ``*args``.
        return b"converted mock audio data"

    monkeypatch.setattr(
        "api.src.routers.openai_compatible.AudioService.convert_audio", mock_convert
    )
|
|
|
|
|
2024-12-30 13:25:30 -07:00
|
|
|
|
|
|
|
def test_health_check():
    """GET /health should answer 200 with a ``healthy`` status payload."""
    result = client.get("/health")
    assert result.status_code == 200
    payload = result.json()
    assert payload == {"status": "healthy"}
|
|
|
|
|
2024-12-31 01:57:00 -07:00
|
|
|
|
2024-12-31 01:52:16 -07:00
|
|
|
def test_openai_speech_endpoint(mock_tts_service, mock_audio_service):
    """POST a valid WAV request and check the headers, the service call and the body."""
    payload = dict(
        model="tts-1",
        input="Hello world",
        voice="bm_lewis",
        response_format="wav",
        speed=1.0,
    )
    # Exact keyword arguments the router is expected to pass to the TTS stub.
    expected_generate_kwargs = dict(
        text="Hello world", voice="bm_lewis", speed=1.0, stitch_long_output=True
    )

    reply = client.post("/v1/audio/speech", json=payload)

    assert reply.status_code == 200
    headers = reply.headers
    assert headers["content-type"] == "audio/wav"
    assert headers["content-disposition"] == "attachment; filename=speech.wav"
    mock_tts_service._generate_audio.assert_called_once_with(**expected_generate_kwargs)
    # The body must be exactly what the patched convert_audio stub returns.
    assert reply.content == b"converted mock audio data"
|
2024-12-30 13:25:30 -07:00
|
|
|
|
2024-12-31 01:57:00 -07:00
|
|
|
|
2024-12-31 01:52:16 -07:00
|
|
|
def test_openai_speech_invalid_voice(mock_tts_service):
    """An unrecognised voice name should be rejected as a validation error (422)."""
    bad_request = dict(
        model="tts-1",
        input="Hello world",
        voice="invalid_voice",
        response_format="wav",
        speed=1.0,
    )
    outcome = client.post("/v1/audio/speech", json=bad_request)
    assert outcome.status_code == 422  # Validation error
|
2024-12-30 13:25:30 -07:00
|
|
|
|
2024-12-31 01:57:00 -07:00
|
|
|
|
2024-12-31 01:52:16 -07:00
|
|
|
def test_openai_speech_invalid_speed(mock_tts_service):
    """A negative speed should fail request validation (422)."""
    bad_request = dict(
        model="tts-1",
        input="Hello world",
        voice="af",
        response_format="wav",
        speed=-1.0,  # Invalid speed
    )
    outcome = client.post("/v1/audio/speech", json=bad_request)
    assert outcome.status_code == 422  # Validation error
|
|
|
|
|
2024-12-31 01:57:00 -07:00
|
|
|
|
2024-12-31 01:52:16 -07:00
|
|
|
def test_openai_speech_generation_error(mock_tts_service):
    """A failure inside audio generation should surface as HTTP 500 carrying the message."""
    # Make the stubbed generator blow up on its first call.
    mock_tts_service._generate_audio.side_effect = Exception("Generation failed")
    request_body = dict(
        model="tts-1",
        input="Hello world",
        voice="af",
        response_format="wav",
        speed=1.0,
    )

    resp = client.post("/v1/audio/speech", json=request_body)

    assert resp.status_code == 500
    detail = resp.json()["detail"]
    assert "Generation failed" in detail
|