"""Tests for the OpenAI-compatible TTS router.

Exercises ``api.src.routers.openai_compatible``: loading the OpenAI name
mappings, lazy ``TTSService`` creation via ``get_tts_service``, the
``stream_audio_chunks`` helper, and the ``/v1/audio/speech``,
``/v1/audio/voices`` and ``/v1/audio/voices/combine`` endpoints.
"""

import asyncio
import json
import os
from typing import AsyncGenerator
from unittest.mock import AsyncMock, MagicMock, patch

import numpy as np
import pytest
from fastapi.testclient import TestClient

from api.src.main import app
from api.src.services.tts_service import TTSService
from api.src.core.config import settings
from api.src.routers.openai_compatible import (
    load_openai_mappings,
    get_tts_service,
    stream_audio_chunks,
)
from api.src.structures.schemas import OpenAISpeechRequest

client = TestClient(app)


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture
def test_voice():
    """Fixture providing a test voice name."""
    return "test_voice"


@pytest.fixture
def mock_openai_mappings():
    """Replace the router's cached OpenAI mappings with a fixed table.

    Maps the OpenAI model names ``tts-1``/``tts-1-hd`` to ``kokoro-v0_19`` and
    the OpenAI voices ``alloy``/``nova`` to ``am_adam``/``bf_isabella``.
    """
    with patch("api.src.routers.openai_compatible._openai_mappings", {
        "models": {
            "tts-1": "kokoro-v0_19",
            "tts-1-hd": "kokoro-v0_19",
        },
        "voices": {
            "alloy": "am_adam",
            "nova": "bf_isabella",
        },
    }):
        yield


@pytest.fixture
def mock_json_file(tmp_path):
    """Write a small mappings JSON file to a temp dir and return its path."""
    content = {
        "models": {"test-model": "test-kokoro"},
        "voices": {"test-voice": "test-internal"},
    }
    json_file = tmp_path / "test_mappings.json"
    json_file.write_text(json.dumps(content))
    return json_file


@pytest.fixture
def mock_audio_bytes():
    """Mock audio bytes for testing."""
    return b"mock audio data"


@pytest.fixture
def mock_tts_service(mock_audio_bytes):
    """Patch the router's ``get_tts_service`` to return a mocked TTSService.

    The mock's ``generate_audio`` returns ``(np.zeros(1000), 0.1)``, its
    ``generate_audio_stream`` yields ``mock_audio_bytes`` once,
    ``list_voices`` returns three voices and ``combine_voices`` returns
    ``"voice1_voice2"``. Tests may override any of these.
    """
    with patch("api.src.routers.openai_compatible.get_tts_service") as mock_get:
        service = AsyncMock(spec=TTSService)
        service.generate_audio.return_value = (np.zeros(1000), 0.1)

        async def mock_stream(*args, **kwargs) -> AsyncGenerator[bytes, None]:
            yield mock_audio_bytes

        service.generate_audio_stream = mock_stream
        service.list_voices.return_value = ["test_voice", "voice1", "voice2"]
        service.combine_voices.return_value = "voice1_voice2"
        mock_get.return_value = service
        yield service


# ---------------------------------------------------------------------------
# Mapping loading and service initialisation
# ---------------------------------------------------------------------------


def test_load_openai_mappings(mock_json_file):
    """Test loading OpenAI mappings from JSON file"""
    # Redirect the path the loader builds to our temp file.
    with patch("os.path.join", return_value=str(mock_json_file)):
        mappings = load_openai_mappings()
        assert "models" in mappings
        assert "voices" in mappings
        assert mappings["models"]["test-model"] == "test-kokoro"
        assert mappings["voices"]["test-voice"] == "test-internal"


def test_load_openai_mappings_file_not_found():
    """Test handling of missing mappings file"""
    with patch("os.path.join", return_value="/nonexistent/path"):
        mappings = load_openai_mappings()
        assert mappings == {"models": {}, "voices": {}}


@pytest.mark.asyncio
async def test_get_tts_service_initialization():
    """Test TTSService initialization"""
    # Reset the router's cached service and lock so creation happens here.
    with patch("api.src.routers.openai_compatible._tts_service", None):
        with patch("api.src.routers.openai_compatible._init_lock", None):
            with patch("api.src.services.tts_service.TTSService.create") as mock_create:
                mock_service = AsyncMock()
                mock_create.return_value = mock_service

                # Fire several concurrent requests for the service.
                results = await asyncio.gather(
                    *(get_tts_service() for _ in range(5))
                )

                # The service must be created exactly once and shared.
                mock_create.assert_called_once()
                assert all(r is mock_service for r in results)


# ---------------------------------------------------------------------------
# stream_audio_chunks helper
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_stream_audio_chunks_client_disconnect():
    """Test handling of client disconnect during streaming"""
    mock_request = MagicMock()
    mock_request.is_disconnected = AsyncMock(return_value=True)

    mock_service = AsyncMock()

    async def mock_stream(*args, **kwargs):
        for _ in range(5):
            yield b"chunk"

    mock_service.generate_audio_stream = mock_stream
    mock_service.list_voices.return_value = ["test_voice"]

    request = OpenAISpeechRequest(
        model="kokoro",
        input="Test text",
        voice="test_voice",
        response_format="mp3",
        stream=True,
        speed=1.0,
    )

    chunks = [
        chunk
        async for chunk in stream_audio_chunks(mock_service, request, mock_request)
    ]

    assert len(chunks) == 0  # Should stop immediately due to disconnect


# ---------------------------------------------------------------------------
# /v1/audio/speech with OpenAI model/voice names
# ---------------------------------------------------------------------------


def test_openai_voice_mapping(mock_tts_service, mock_openai_mappings):
    """Test OpenAI voice name mapping"""
    mock_tts_service.list_voices.return_value = ["am_adam", "bf_isabella"]

    response = client.post(
        "/v1/audio/speech",
        json={
            "model": "tts-1",
            "input": "Hello world",
            "voice": "alloy",  # OpenAI voice name
            "response_format": "mp3",
            "stream": False,
        },
    )
    assert response.status_code == 200
    mock_tts_service.generate_audio.assert_called_once()
    assert mock_tts_service.generate_audio.call_args[1]["voice"] == "am_adam"


def test_openai_voice_mapping_streaming(mock_tts_service, mock_openai_mappings, mock_audio_bytes):
    """Test OpenAI voice mapping in streaming mode"""
    mock_tts_service.list_voices.return_value = ["am_adam", "bf_isabella"]

    response = client.post(
        "/v1/audio/speech",
        json={
            "model": "tts-1-hd",
            "input": "Hello world",
            "voice": "nova",  # OpenAI voice name
            "response_format": "mp3",
            "stream": True,
        },
    )
    assert response.status_code == 200
    assert b"".join(response.iter_bytes()) == mock_audio_bytes


def test_invalid_openai_model(mock_tts_service, mock_openai_mappings):
    """Test error handling for invalid OpenAI model"""
    response = client.post(
        "/v1/audio/speech",
        json={
            "model": "invalid-model",
            "input": "Hello world",
            "voice": "alloy",
            "response_format": "mp3",
            "stream": False,
        },
    )
    assert response.status_code == 400
    error_response = response.json()
    assert error_response["detail"]["error"] == "invalid_model"
    assert "Unsupported model" in error_response["detail"]["message"]


# ---------------------------------------------------------------------------
# /v1/audio/speech with internal voice names
# ---------------------------------------------------------------------------


def test_openai_speech_endpoint(mock_tts_service, test_voice):
    """Test the OpenAI-compatible speech endpoint with basic MP3 generation"""
    response = client.post(
        "/v1/audio/speech",
        json={
            "model": "kokoro",
            "input": "Hello world",
            "voice": test_voice,
            "response_format": "mp3",
            "stream": False,
        },
    )
    assert response.status_code == 200
    assert response.headers["content-type"] == "audio/mpeg"
    assert len(response.content) > 0


def test_openai_speech_streaming(mock_tts_service, test_voice, mock_audio_bytes):
    """Test the OpenAI-compatible speech endpoint with streaming"""
    response = client.post(
        "/v1/audio/speech",
        json={
            "model": "kokoro",
            "input": "Hello world",
            "voice": test_voice,
            "response_format": "mp3",
            "stream": True,
        },
    )
    assert response.status_code == 200
    assert response.headers["content-type"] == "audio/mpeg"
    assert "Transfer-Encoding" in response.headers
    assert response.headers["Transfer-Encoding"] == "chunked"
    assert b"".join(response.iter_bytes()) == mock_audio_bytes


def test_openai_speech_pcm_streaming(mock_tts_service, test_voice, mock_audio_bytes):
    """Test PCM streaming format"""
    response = client.post(
        "/v1/audio/speech",
        json={
            "model": "kokoro",
            "input": "Hello world",
            "voice": test_voice,
            "response_format": "pcm",
            "stream": True,
        },
    )
    assert response.status_code == 200
    assert response.headers["content-type"] == "audio/pcm"
    assert b"".join(response.iter_bytes()) == mock_audio_bytes


def test_openai_speech_invalid_voice(mock_tts_service):
    """Test error handling for invalid voice"""
    mock_tts_service.generate_audio.side_effect = ValueError("Voice 'invalid_voice' not found")

    response = client.post(
        "/v1/audio/speech",
        json={
            "model": "kokoro",
            "input": "Hello world",
            "voice": "invalid_voice",
            "response_format": "mp3",
            "stream": False,
        },
    )
    assert response.status_code == 400
    error_response = response.json()
    assert error_response["detail"]["error"] == "validation_error"
    assert "Voice 'invalid_voice' not found" in error_response["detail"]["message"]
    assert error_response["detail"]["type"] == "invalid_request_error"


def test_openai_speech_empty_text(mock_tts_service, test_voice):
    """Test error handling for empty text"""

    # Replaces the non-streaming generate_audio call, not the stream.
    async def failing_generate_audio(*args, **kwargs):
        raise ValueError("Text is empty after preprocessing")

    mock_tts_service.generate_audio = failing_generate_audio
    mock_tts_service.list_voices.return_value = ["test_voice"]

    response = client.post(
        "/v1/audio/speech",
        json={
            "model": "kokoro",
            "input": "",
            "voice": test_voice,
            "response_format": "mp3",
            "stream": False,
        },
    )
    assert response.status_code == 400
    error_response = response.json()
    assert error_response["detail"]["error"] == "validation_error"
    assert "Text is empty after preprocessing" in error_response["detail"]["message"]
    assert error_response["detail"]["type"] == "invalid_request_error"


def test_openai_speech_invalid_format(mock_tts_service, test_voice):
    """Test error handling for invalid format"""
    response = client.post(
        "/v1/audio/speech",
        json={
            "model": "kokoro",
            "input": "Hello world",
            "voice": test_voice,
            "response_format": "invalid_format",
            "stream": False,
        },
    )
    assert response.status_code == 422  # Validation error from Pydantic


# ---------------------------------------------------------------------------
# Voice listing / combining
# ---------------------------------------------------------------------------


def test_list_voices(mock_tts_service):
    """Test listing available voices"""
    # Override the mock for this specific test
    mock_tts_service.list_voices.return_value = ["voice1", "voice2"]

    response = client.get("/v1/audio/voices")
    assert response.status_code == 200
    data = response.json()
    assert "voices" in data
    assert len(data["voices"]) == 2
    assert "voice1" in data["voices"]
    assert "voice2" in data["voices"]


def test_combine_voices(mock_tts_service):
    """Test combining voices endpoint"""
    response = client.post("/v1/audio/voices/combine", json="voice1+voice2")
    assert response.status_code == 200
    data = response.json()
    assert "voice" in data
    assert data["voice"] == "voice1_voice2"


# ---------------------------------------------------------------------------
# Server-side failures
# ---------------------------------------------------------------------------


def test_server_error(mock_tts_service, test_voice):
    """Test handling of server errors"""

    # Replaces the non-streaming generate_audio call, not the stream.
    async def failing_generate_audio(*args, **kwargs):
        raise RuntimeError("Internal server error")

    mock_tts_service.generate_audio = failing_generate_audio
    mock_tts_service.list_voices.return_value = ["test_voice"]

    response = client.post(
        "/v1/audio/speech",
        json={
            "model": "kokoro",
            "input": "Hello world",
            "voice": test_voice,
            "response_format": "mp3",
            "stream": False,
        },
    )
    assert response.status_code == 500
    error_response = response.json()
    assert error_response["detail"]["error"] == "processing_error"
    assert error_response["detail"]["type"] == "server_error"


def test_streaming_error(mock_tts_service, test_voice):
    """Test handling streaming errors"""

    async def mock_error_stream(*args, **kwargs) -> AsyncGenerator[bytes, None]:
        if False:  # Unreachable yield makes this an async generator
            yield b""
        raise RuntimeError("Streaming failed")

    mock_tts_service.generate_audio_stream = mock_error_stream

    response = client.post(
        "/v1/audio/speech",
        json={
            "model": "kokoro",
            "input": "Hello world",
            "voice": test_voice,
            "response_format": "mp3",
            "stream": True,
        },
    )
    assert response.status_code == 500
    error_response = response.json()
    assert error_response["detail"]["error"] == "processing_error"
    assert error_response["detail"]["type"] == "server_error"


@pytest.mark.asyncio
async def test_streaming_initialization_error():
    """Test handling of streaming initialization errors"""
    mock_service = AsyncMock()

    async def mock_error_stream(*args, **kwargs):
        if False:  # Unreachable yield makes this an async generator
            yield b""
        raise RuntimeError("Failed to initialize stream")

    mock_service.generate_audio_stream = mock_error_stream
    mock_service.list_voices.return_value = ["test_voice"]

    request = OpenAISpeechRequest(
        model="kokoro",
        input="Test text",
        voice="test_voice",
        response_format="mp3",
        stream=True,
        speed=1.0,
    )

    # A plain MagicMock's is_disconnected() is not awaitable; give the request
    # an awaitable "still connected" check so this test does not depend on
    # whether stream_audio_chunks checks for disconnect before or after the
    # first chunk.
    mock_request = MagicMock()
    mock_request.is_disconnected = AsyncMock(return_value=False)

    with pytest.raises(RuntimeError) as exc:
        async for _ in stream_audio_chunks(mock_service, request, mock_request):
            pass
    assert "Failed to initialize stream" in str(exc.value)