diff --git a/.coverage b/.coverage index f449db0..052e756 100644 Binary files a/.coverage and b/.coverage differ diff --git a/api/src/services/audio.py b/api/src/services/audio.py index f909519..e13d91f 100644 --- a/api/src/services/audio.py +++ b/api/src/services/audio.py @@ -62,10 +62,10 @@ class AudioService: logger.info("Writing to WAV format...") # Always include WAV header for WAV format sf.write(buffer, normalized_audio, sample_rate, format="WAV", subtype='PCM_16') - elif output_format in ["mp3", "aac"]: - logger.info(f"Converting to {output_format.upper()} format...") + elif output_format == "mp3": + logger.info("Converting to MP3 format...") # Use lower bitrate for streaming - sf.write(buffer, normalized_audio, sample_rate, format=output_format.upper()) + sf.write(buffer, normalized_audio, sample_rate, format="MP3") elif output_format == "opus": logger.info("Converting to Opus format...") # Use lower bitrate and smaller frame size for streaming @@ -76,9 +76,14 @@ class AudioService: sf.write(buffer, normalized_audio, sample_rate, format="FLAC", subtype='PCM_16') else: - raise ValueError( - f"Format {output_format} not supported. Supported formats are: wav, mp3, opus, flac, pcm." - ) + if output_format == "aac": + raise ValueError( + "Format aac not supported. Supported formats are: wav, mp3, opus, flac, pcm." + ) + else: + raise ValueError( + f"Format {output_format} not supported. Supported formats are: wav, mp3, opus, flac, pcm." + ) buffer.seek(0) return buffer.getvalue() diff --git a/api/tests/test_endpoints.py b/api/tests/test_endpoints.py index 80fe733..6142e12 100644 --- a/api/tests/test_endpoints.py +++ b/api/tests/test_endpoints.py @@ -1,13 +1,21 @@ from unittest.mock import Mock import pytest +import pytest_asyncio from fastapi.testclient import TestClient +from httpx import AsyncClient from ..src.main import app # Create test client client = TestClient(app) +# Create async client fixture +@pytest_asyncio.fixture +async def async_client(): + async with AsyncClient(app=app, base_url="http://test") as ac: + yield ac + # Mock services @pytest.fixture @@ -34,12 +42,12 @@ def mock_tts_service(monkeypatch): @pytest.fixture def mock_audio_service(monkeypatch): - def mock_convert(*args): - return b"converted mock audio data" - + mock_service = Mock() + mock_service.convert_audio.return_value = b"converted mock audio data" monkeypatch.setattr( - "api.src.routers.openai_compatible.AudioService.convert_audio", mock_convert + "api.src.routers.openai_compatible.AudioService", mock_service ) + return mock_service def test_health_check(): @@ -153,3 +161,87 @@ def test_combine_voices_error(mock_tts_service): assert response.status_code == 500 assert "Combination failed" in response.json()["detail"]["message"] + + +@pytest.mark.asyncio +async def test_openai_speech_pcm_streaming(mock_tts_service, async_client): + """Test streaming PCM audio for real-time playback""" + test_request = { + "model": "kokoro", + "input": "Hello world", + "voice": "af", + "response_format": "pcm", + } + + # Mock streaming response + async def mock_stream(): + yield b"chunk1" + yield b"chunk2" + mock_tts_service.generate_audio_stream.return_value = mock_stream() + + # Add streaming header + headers = {"x-raw-response": "stream"} + response = await async_client.post("/v1/audio/speech", json=test_request, headers=headers) + + assert response.status_code == 200 + assert response.headers["content-type"] == "audio/pcm" + # Just verify status and content type + assert response.status_code == 200 + assert response.headers["content-type"] == "audio/pcm" + + +@pytest.mark.asyncio +async def test_openai_speech_streaming_mp3(mock_tts_service, async_client): + """Test streaming MP3 audio to file""" + test_request = { + "model": "kokoro", + "input": "Hello world", + "voice": "af", + "response_format": "mp3", + } + + # Mock streaming response + async def mock_stream(): + yield b"mp3header" + yield b"mp3data" + mock_tts_service.generate_audio_stream.return_value = mock_stream() + + # Add streaming header + headers = {"x-raw-response": "stream"} + response = await async_client.post("/v1/audio/speech", json=test_request, headers=headers) + + assert response.status_code == 200 + assert response.headers["content-type"] == "audio/mpeg" + assert response.headers["content-disposition"] == "attachment; filename=speech.mp3" + # Just verify status and content type + assert response.status_code == 200 + assert response.headers["content-type"] == "audio/mpeg" + assert response.headers["content-disposition"] == "attachment; filename=speech.mp3" + + +@pytest.mark.asyncio +async def test_openai_speech_streaming_generator(mock_tts_service, async_client): + """Test streaming with async generator""" + test_request = { + "model": "kokoro", + "input": "Hello world", + "voice": "af", + "response_format": "pcm", + } + + # Mock streaming response + async def mock_stream(): + yield b"chunk1" + yield b"chunk2" + + mock_tts_service.generate_audio_stream.return_value = mock_stream() + + # Add streaming header + headers = {"x-raw-response": "stream"} + response = await async_client.post("/v1/audio/speech", json=test_request, headers=headers) + + assert response.status_code == 200 + assert response.headers["content-type"] == "audio/pcm" + # Just verify status and content type + assert response.status_code == 200 + assert response.headers["content-type"] == "audio/pcm" diff --git a/api/tests/test_main.py b/api/tests/test_main.py index c6a972e..51026c5 100644 --- a/api/tests/test_main.py +++ b/api/tests/test_main.py @@ -1,6 +1,6 @@ """Tests for FastAPI application""" -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, patch, call import pytest from fastapi.testclient import TestClient @@ -39,8 +39,12 @@ async def test_lifespan_successful_warmup(mock_logger, mock_tts_model): # Verify the expected logging sequence mock_logger.info.assert_any_call("Loading TTS model and voice packs...") - mock_logger.info.assert_any_call("Model loaded and warmed up on cuda") - mock_logger.info.assert_any_call("3 voice packs loaded successfully") + + # Check for the startup message containing the required info + startup_calls = [call[0][0] for call in mock_logger.info.call_args_list] + startup_msg = next(msg for msg in startup_calls if "Model loaded and warmed up on" in msg) + assert "Model loaded and warmed up on cuda" in startup_msg + assert "3 voice packs loaded successfully" in startup_msg # Verify model setup was called mock_tts_model.setup.assert_called_once()