Kokoro-FastAPI/api/tests/test_openai_endpoints.py
2025-04-04 16:58:07 -06:00

499 lines
16 KiB
Python

import asyncio
import json
import os
from typing import AsyncGenerator, Tuple
from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
from fastapi.testclient import TestClient
from api.src.core.config import settings
from api.src.inference.base import AudioChunk
from api.src.main import app
from api.src.routers.openai_compatible import (
get_tts_service,
load_openai_mappings,
stream_audio_chunks,
)
from api.src.services.streaming_audio_writer import StreamingAudioWriter
from api.src.services.tts_service import TTSService
from api.src.structures.schemas import OpenAISpeechRequest
client = TestClient(app)
@pytest.fixture
def test_voice():
"""Fixture providing a test voice name."""
return "test_voice"
@pytest.fixture
def mock_openai_mappings():
"""Mock OpenAI mappings for testing."""
with patch(
"api.src.routers.openai_compatible._openai_mappings",
{
"models": {"tts-1": "kokoro-v1_0", "tts-1-hd": "kokoro-v1_0"},
"voices": {"alloy": "am_adam", "nova": "bf_isabella"},
},
):
yield
@pytest.fixture
def mock_json_file(tmp_path):
"""Create a temporary mock JSON file."""
content = {
"models": {"test-model": "test-kokoro"},
"voices": {"test-voice": "test-internal"},
}
json_file = tmp_path / "test_mappings.json"
json_file.write_text(json.dumps(content))
return json_file
def test_load_openai_mappings(mock_json_file):
"""Test loading OpenAI mappings from JSON file"""
with patch("os.path.join", return_value=str(mock_json_file)):
mappings = load_openai_mappings()
assert "models" in mappings
assert "voices" in mappings
assert mappings["models"]["test-model"] == "test-kokoro"
assert mappings["voices"]["test-voice"] == "test-internal"
def test_load_openai_mappings_file_not_found():
"""Test handling of missing mappings file"""
with patch("os.path.join", return_value="/nonexistent/path"):
mappings = load_openai_mappings()
assert mappings == {"models": {}, "voices": {}}
def test_list_models(mock_openai_mappings):
"""Test listing available models endpoint"""
response = client.get("/v1/models")
assert response.status_code == 200
data = response.json()
assert data["object"] == "list"
assert isinstance(data["data"], list)
assert len(data["data"]) == 3 # tts-1, tts-1-hd, and kokoro
# Verify all expected models are present
model_ids = [model["id"] for model in data["data"]]
assert "tts-1" in model_ids
assert "tts-1-hd" in model_ids
assert "kokoro" in model_ids
# Verify model format
for model in data["data"]:
assert model["object"] == "model"
assert "created" in model
assert model["owned_by"] == "kokoro"
def test_retrieve_model(mock_openai_mappings):
"""Test retrieving a specific model endpoint"""
# Test successful model retrieval
response = client.get("/v1/models/tts-1")
assert response.status_code == 200
data = response.json()
assert data["id"] == "tts-1"
assert data["object"] == "model"
assert data["owned_by"] == "kokoro"
assert "created" in data
# Test non-existent model
response = client.get("/v1/models/nonexistent-model")
assert response.status_code == 404
error = response.json()
assert error["detail"]["error"] == "model_not_found"
assert "not found" in error["detail"]["message"]
assert error["detail"]["type"] == "invalid_request_error"
@pytest.mark.asyncio
async def test_get_tts_service_initialization():
"""Test TTSService initialization"""
with patch("api.src.routers.openai_compatible._tts_service", None):
with patch("api.src.routers.openai_compatible._init_lock", None):
with patch("api.src.services.tts_service.TTSService.create") as mock_create:
mock_service = AsyncMock()
mock_create.return_value = mock_service
# Test concurrent access
async def get_service():
return await get_tts_service()
# Create multiple concurrent requests
tasks = [get_service() for _ in range(5)]
results = await asyncio.gather(*tasks)
# Verify service was created only once
mock_create.assert_called_once()
assert all(r == mock_service for r in results)
@pytest.mark.asyncio
async def test_stream_audio_chunks_client_disconnect():
"""Test handling of client disconnect during streaming"""
mock_request = MagicMock()
mock_request.is_disconnected = AsyncMock(return_value=True)
mock_service = AsyncMock()
async def mock_stream(*args, **kwargs):
for i in range(5):
yield AudioChunk(np.ndarray([], np.int16), output=b"chunk")
mock_service.generate_audio_stream = mock_stream
mock_service.list_voices.return_value = ["test_voice"]
request = OpenAISpeechRequest(
model="kokoro",
input="Test text",
voice="test_voice",
response_format="mp3",
stream=True,
speed=1.0,
)
writer = StreamingAudioWriter("mp3", 24000)
chunks = []
async for chunk in stream_audio_chunks(mock_service, request, mock_request, writer):
chunks.append(chunk)
writer.close()
assert len(chunks) == 0 # Should stop immediately due to disconnect
def test_openai_voice_mapping(mock_tts_service, mock_openai_mappings):
"""Test OpenAI voice name mapping"""
mock_tts_service.list_voices.return_value = ["am_adam", "bf_isabella"]
response = client.post(
"/v1/audio/speech",
json={
"model": "tts-1",
"input": "Hello world",
"voice": "alloy", # OpenAI voice name
"response_format": "mp3",
"stream": False,
},
)
assert response.status_code == 200
mock_tts_service.generate_audio.assert_called_once()
assert mock_tts_service.generate_audio.call_args[1]["voice"] == "am_adam"
def test_openai_voice_mapping_streaming(
mock_tts_service, mock_openai_mappings, mock_audio_bytes
):
"""Test OpenAI voice mapping in streaming mode"""
mock_tts_service.list_voices.return_value = ["am_adam", "bf_isabella"]
response = client.post(
"/v1/audio/speech",
json={
"model": "tts-1-hd",
"input": "Hello world",
"voice": "nova", # OpenAI voice name
"response_format": "mp3",
"stream": True,
},
)
assert response.status_code == 200
content = b""
for chunk in response.iter_bytes():
content += chunk
assert content == mock_audio_bytes
def test_invalid_openai_model(mock_tts_service, mock_openai_mappings):
"""Test error handling for invalid OpenAI model"""
response = client.post(
"/v1/audio/speech",
json={
"model": "invalid-model",
"input": "Hello world",
"voice": "alloy",
"response_format": "mp3",
"stream": False,
},
)
assert response.status_code == 400
error_response = response.json()
assert error_response["detail"]["error"] == "invalid_model"
assert "Unsupported model" in error_response["detail"]["message"]
@pytest.fixture
def mock_audio_bytes():
"""Mock audio bytes for testing."""
return b"mock audio data"
@pytest.fixture
def mock_tts_service(mock_audio_bytes):
"""Mock TTS service for testing."""
with patch("api.src.routers.openai_compatible.get_tts_service") as mock_get:
service = AsyncMock(spec=TTSService)
service.generate_audio.return_value = AudioChunk(np.zeros(1000, np.int16))
async def mock_stream(*args, **kwargs) -> AsyncGenerator[AudioChunk, None]:
yield AudioChunk(np.ndarray([], np.int16), output=mock_audio_bytes)
service.generate_audio_stream = mock_stream
service.list_voices.return_value = ["test_voice", "voice1", "voice2"]
service.combine_voices.return_value = "voice1_voice2"
mock_get.return_value = service
mock_get.side_effect = None
yield service
@patch("api.src.services.audio.AudioService.convert_audio")
def test_openai_speech_endpoint(
mock_convert, mock_tts_service, test_voice, mock_audio_bytes
):
"""Test the OpenAI-compatible speech endpoint with basic MP3 generation"""
# Configure mocks
mock_tts_service.generate_audio.return_value = AudioChunk(np.zeros(1000, np.int16))
mock_convert.return_value = AudioChunk(
np.zeros(1000, np.int16), output=mock_audio_bytes
)
response = client.post(
"/v1/audio/speech",
json={
"model": "kokoro",
"input": "Hello world",
"voice": test_voice,
"response_format": "mp3",
"stream": False,
},
)
assert response.status_code == 200
assert response.headers["content-type"] == "audio/mpeg"
assert len(response.content) > 0
assert response.content == mock_audio_bytes + mock_audio_bytes
mock_tts_service.generate_audio.assert_called_once()
assert mock_convert.call_count == 2
def test_openai_speech_streaming(mock_tts_service, test_voice, mock_audio_bytes):
"""Test the OpenAI-compatible speech endpoint with streaming"""
response = client.post(
"/v1/audio/speech",
json={
"model": "kokoro",
"input": "Hello world",
"voice": test_voice,
"response_format": "mp3",
"stream": True,
},
)
assert response.status_code == 200
assert response.headers["content-type"] == "audio/mpeg"
assert "Transfer-Encoding" in response.headers
assert response.headers["Transfer-Encoding"] == "chunked"
content = b""
for chunk in response.iter_bytes():
content += chunk
assert content == mock_audio_bytes
def test_openai_speech_pcm_streaming(mock_tts_service, test_voice, mock_audio_bytes):
"""Test PCM streaming format"""
response = client.post(
"/v1/audio/speech",
json={
"model": "kokoro",
"input": "Hello world",
"voice": test_voice,
"response_format": "pcm",
"stream": True,
},
)
assert response.status_code == 200
assert response.headers["content-type"] == "audio/pcm"
content = b""
for chunk in response.iter_bytes():
content += chunk
assert content == mock_audio_bytes
def test_openai_speech_invalid_voice(mock_tts_service):
"""Test error handling for invalid voice"""
mock_tts_service.generate_audio.side_effect = ValueError(
"Voice 'invalid_voice' not found"
)
response = client.post(
"/v1/audio/speech",
json={
"model": "kokoro",
"input": "Hello world",
"voice": "invalid_voice",
"response_format": "mp3",
"stream": False,
},
)
assert response.status_code == 400
error_response = response.json()
assert error_response["detail"]["error"] == "validation_error"
assert "Voice 'invalid_voice' not found" in error_response["detail"]["message"]
assert error_response["detail"]["type"] == "invalid_request_error"
def test_openai_speech_empty_text(mock_tts_service, test_voice):
"""Test error handling for empty text"""
async def mock_error_stream(*args, **kwargs):
raise ValueError("Text is empty after preprocessing")
mock_tts_service.generate_audio = mock_error_stream
mock_tts_service.list_voices.return_value = ["test_voice"]
response = client.post(
"/v1/audio/speech",
json={
"model": "kokoro",
"input": "",
"voice": test_voice,
"response_format": "mp3",
"stream": False,
},
)
assert response.status_code == 400
error_response = response.json()
assert error_response["detail"]["error"] == "validation_error"
assert "Text is empty after preprocessing" in error_response["detail"]["message"]
assert error_response["detail"]["type"] == "invalid_request_error"
def test_openai_speech_invalid_format(mock_tts_service, test_voice):
"""Test error handling for invalid format"""
response = client.post(
"/v1/audio/speech",
json={
"model": "kokoro",
"input": "Hello world",
"voice": test_voice,
"response_format": "invalid_format",
"stream": False,
},
)
assert response.status_code == 422 # Validation error from Pydantic
def test_list_voices(mock_tts_service):
"""Test listing available voices"""
# Override the mock for this specific test
mock_tts_service.list_voices.return_value = ["voice1", "voice2"]
response = client.get("/v1/audio/voices")
assert response.status_code == 200
data = response.json()
assert "voices" in data
assert len(data["voices"]) == 2
assert "voice1" in data["voices"]
assert "voice2" in data["voices"]
@patch("api.src.routers.openai_compatible.settings")
def test_combine_voices(mock_settings, mock_tts_service):
"""Test combining voices endpoint"""
# Enable local voice saving for this test
mock_settings.allow_local_voice_saving = True
response = client.post("/v1/audio/voices/combine", json="voice1+voice2")
assert response.status_code == 200
assert response.headers["content-type"] == "application/octet-stream"
assert "voice1+voice2.pt" in response.headers["content-disposition"]
def test_server_error(mock_tts_service, test_voice):
"""Test handling of server errors"""
async def mock_error_stream(*args, **kwargs):
raise RuntimeError("Internal server error")
mock_tts_service.generate_audio = mock_error_stream
mock_tts_service.list_voices.return_value = ["test_voice"]
response = client.post(
"/v1/audio/speech",
json={
"model": "kokoro",
"input": "Hello world",
"voice": test_voice,
"response_format": "mp3",
"stream": False,
},
)
assert response.status_code == 500
error_response = response.json()
assert error_response["detail"]["error"] == "processing_error"
assert error_response["detail"]["type"] == "server_error"
def test_streaming_error(mock_tts_service, test_voice):
"""Test handling streaming errors"""
# Mock process_voices to raise the error
mock_tts_service.list_voices.side_effect = RuntimeError("Streaming failed")
response = client.post(
"/v1/audio/speech",
json={
"model": "kokoro",
"input": "Hello world",
"voice": test_voice,
"response_format": "mp3",
"stream": True,
},
)
assert response.status_code == 500
error_data = response.json()
assert error_data["detail"]["error"] == "processing_error"
assert error_data["detail"]["type"] == "server_error"
assert "Streaming failed" in error_data["detail"]["message"]
@pytest.mark.asyncio
async def test_streaming_initialization_error():
"""Test handling of streaming initialization errors"""
mock_service = AsyncMock()
async def mock_error_stream(*args, **kwargs):
if False: # This makes it a proper generator
yield b""
raise RuntimeError("Failed to initialize stream")
mock_service.generate_audio_stream = mock_error_stream
mock_service.list_voices.return_value = ["test_voice"]
request = OpenAISpeechRequest(
model="kokoro",
input="Test text",
voice="test_voice",
response_format="mp3",
stream=True,
speed=1.0,
)
writer = StreamingAudioWriter("mp3", 24000)
with pytest.raises(RuntimeError) as exc:
async for _ in stream_audio_chunks(mock_service, request, MagicMock(), writer):
pass
writer.close()
assert "Failed to initialize stream" in str(exc.value)