From e799f0c7c1cf47ac9ba88a5f768e4fad3713a79c Mon Sep 17 00:00:00 2001 From: remsky Date: Sat, 4 Jan 2025 18:09:23 -0700 Subject: [PATCH] WIP: basic tests on OpenAI streaming compatibility --- .coverage | Bin 53248 -> 53248 bytes api/src/services/audio.py | 17 +++--- api/tests/test_endpoints.py | 100 ++++++++++++++++++++++++++++++++++-- api/tests/test_main.py | 10 ++-- 4 files changed, 114 insertions(+), 13 deletions(-) diff --git a/.coverage b/.coverage index f449db0dd7f36a2945dfa3093dc68015235e8e6f..052e756fdcfbbc69dcd44de393750d64754538a9 100644 GIT binary patch delta 1932 zcmY+EX-rgC6vyA2dH3z_-WeFSSzu;BkwIqI83Y6(%z$kD(5fhksGz7I0hg#%2AeeL zhql_=l=MTDHgQWW#*)UUjiO1_#zd2vsHv&iq-hh{T5a5y+VifE(oes$-E+@<|C1LR zvBpNMyIs7x$32f%k1ZdqtimtwG2CjrGWHt{hNOR|zpWQ(e`<%cUM*9-uD-9Iunsa$+09ue1%*=?yJnwu_gL-gkb0b0;vw?zGPZexIwdWkki zHKV<~ySIB|dwZ~NYpsuZq7meUYJ>WsVLB3(>5AxF>aCMeS*n%W0nwc4C(1qacBCAY zxuF@9ZCdZebVAZOAh3F}k|fcbcJ6rW~H*#G0OHBuE8MVczpbJFn2QO$~HW6Q<)$jihhQhThJe?roj@aC(shxmF}X zeIC?D?bE(?J%VTwM8ouGGv)=Dpo~6k)@V)3`5Hsk&ubgBS?UetF@C8X#)~m$955QS zyLz#*Lf?VERzFZ{mBZv+QmX`|@8w70SJFFDh4@H5B@f9pvMBa|_-Ue0DGc#b+N1P3 zVJIHFpeem=+A#iz0AK!f!hw9X0&Ku!+3PJ>$hLq_qn|Ix)2AIwrMDJT@E!=4^KKei z!WK)P65KV%y4CDstlTiExUpE}(T1Q_We63_QM0H!awwitWq$U9t zpt>y=gib+;lU@iGz{w}jinc;?0kr=Xz(JcA*Q59r0ranK3xWP_6gSy31p3GGc9^gf z*=hvPzq&1~OzBVP=#nVHQ};`+wKw37I;2+%Aide|L?cGpJOQLPz2}WKUZvwp8$o)5 zo>|(2|0lf+sa7N{6hQgvwo>}CiA?|C;(P(bH{UwY*}u7aWhcBJ!4&1@3XIjXpYukt zIr5a;CpYmhxk`?bgJcicLHfyBvW&b%Do6&g<5&1GzKO5m3-~nt2Jgck;vIMZufr>F z3y$ClT!^z_3vL^OQmMczHunOH6gCad99}9h4PStlicP~C;O#|VWVnL9Do)@yPUd*8 zRA}bX0=!rN2KEQ#d|;gOqfd2P9$5aAQ79efp`7EzT(c(CBDzebcR519K+;YZ#|?V9 zkcnwnKgEcZ5GgP`{HryoT|{y7s;1w_DqOmkcn?s{aDFOY~FIbpe$)7 zga&i*KhLgh5AUgrt?szM_i=sP4L)%OWP--&Q$<#DelB_3#aWYMyp#cPn?9dc)4@p8 ze)w(QZ9(*dja60g0UgS=1Dp`&1#_l+OI&_7JH}UWnP|Dykp{7Vl*n;5E|wcxeL3^5 zOJ^+NV5)gq4EjC52-w2q#e$Lr1*nn^&QCB$mcXw3O@1f8ke|pka)n$ZXUR!2MvjpE zWG|U9Om>s^$S$&-Y$QFTi$q8z2@p&YNE68;>BLRK;1q{x7hLk*p)Le#i7lJAIWOyLrStKdMkQii0$YF5IW{A%=G04XSm}}2s zu+3s1nGAR)gE50a&tTBf8B{-m;$x7#3=;eT;XPt11M)BkZU(E9!IEN_3fNxDFs6?@ G2>lC8-PW7{ delta 1884 zcmY+CX=qee6vywIw=Z+gnG$3|?)>f%dZJ{j{3PRD8ZkEz2LN^ym&$~{neIMp`_Iu|3@BUeF z|Ezc@BIwUY*9-co9dm2ef<({IT{PrA>(07O>w@*9)oA`=g9{at-X zPiQ}AC$wGKQuRCa4fQcKqFhmCmEqGKyFziTSj4K^R>%!Wkxa9VT}xO^Mzx3fM|Tcq z$0t(xy^RTx6iNx7(h`?Q5fSmsPJ26Mw% zXH0X9HKiNbxwPg=N7-C!TCR_XWbge+zh(?6ERqQxh_|`rhR_2GvD;}=X$h9ju4BjA z47nV=7?xMFD{WypQ1Vi&sy!lPEyrjbyU|Sr8Fy%4^&3oAlf*__n^^OPPnxZZ)XS)TlPDzNx&bTm>25L8Zb{TUJVo8pUaL zswV{x=$~9I7Mz$TH#(UeAM78<;%bRj317jfp8EUhZ(64>s9(v7o9pF@7OCiVN*v9F z7(3aMk)o@_y%?%t*E6aVsTC)j!A5^RhaO!_4{zThYnvn#yVzULR(FH_fZ%GognB#wdLlc%JI8NY!6h0U)n$2&R20~* z)`ru~eILhLGX!qIO}GMI!{_i89ETU+3D^T;FaTSj8J0sFC@Ay>y+*&Gm*}VTZF+*f zOb^kgX@L&Y0lJlLq#4xyuDhTnB|*27JuWrj7>QMa8h4C%m7rESMrD=YsX$|KMJlnD zNQ4lL2yr#$-zPl@E-cW;#fdBk@7xFS2XvY|*0r9rQVTq16FiC>Un+)gaM zdG|}j<>Xv0S&($3fqhBv6SxDv!*B2t%)@2)3O;s} diff --git a/api/src/services/audio.py b/api/src/services/audio.py index f909519..e13d91f 100644 --- a/api/src/services/audio.py +++ b/api/src/services/audio.py @@ -62,10 +62,10 @@ class AudioService: logger.info("Writing to WAV format...") # Always include WAV header for WAV format sf.write(buffer, normalized_audio, sample_rate, format="WAV", subtype='PCM_16') - elif output_format in ["mp3", "aac"]: - logger.info(f"Converting to {output_format.upper()} format...") + elif output_format == "mp3": + logger.info("Converting to MP3 format...") # Use lower bitrate for streaming - sf.write(buffer, normalized_audio, sample_rate, format=output_format.upper()) + sf.write(buffer, normalized_audio, sample_rate, format="MP3") elif output_format == "opus": logger.info("Converting to Opus format...") # Use lower bitrate and smaller frame size for streaming @@ -76,9 +76,14 @@ class AudioService: sf.write(buffer, normalized_audio, sample_rate, format="FLAC", subtype='PCM_16') else: - raise ValueError( - f"Format {output_format} not supported. Supported formats are: wav, mp3, opus, flac, pcm." - ) + if output_format == "aac": + raise ValueError( + "Format aac not supported. Supported formats are: wav, mp3, opus, flac, pcm." + ) + else: + raise ValueError( + f"Format {output_format} not supported. Supported formats are: wav, mp3, opus, flac, pcm." + ) buffer.seek(0) return buffer.getvalue() diff --git a/api/tests/test_endpoints.py b/api/tests/test_endpoints.py index 80fe733..6142e12 100644 --- a/api/tests/test_endpoints.py +++ b/api/tests/test_endpoints.py @@ -1,13 +1,21 @@ from unittest.mock import Mock import pytest +import pytest_asyncio from fastapi.testclient import TestClient +from httpx import AsyncClient from ..src.main import app # Create test client client = TestClient(app) +# Create async client fixture +@pytest_asyncio.fixture +async def async_client(): + async with AsyncClient(app=app, base_url="http://test") as ac: + yield ac + # Mock services @pytest.fixture @@ -34,12 +42,12 @@ def mock_tts_service(monkeypatch): @pytest.fixture def mock_audio_service(monkeypatch): - def mock_convert(*args): - return b"converted mock audio data" - + mock_service = Mock() + mock_service.convert_audio.return_value = b"converted mock audio data" monkeypatch.setattr( - "api.src.routers.openai_compatible.AudioService.convert_audio", mock_convert + "api.src.routers.openai_compatible.AudioService", mock_service ) + return mock_service def test_health_check(): @@ -153,3 +161,87 @@ def test_combine_voices_error(mock_tts_service): assert response.status_code == 500 assert "Combination failed" in response.json()["detail"]["message"] + + +@pytest.mark.asyncio +async def test_openai_speech_pcm_streaming(mock_tts_service, async_client): + """Test streaming PCM audio for real-time playback""" + test_request = { + "model": "kokoro", + "input": "Hello world", + "voice": "af", + "response_format": "pcm", + } + + # Mock streaming response + async def mock_stream(): + yield b"chunk1" + yield b"chunk2" + mock_tts_service.generate_audio_stream.return_value = mock_stream() + + # Add streaming header + headers = {"x-raw-response": "stream"} + response = await async_client.post("/v1/audio/speech", json=test_request, headers=headers) + + assert response.status_code == 200 + assert response.headers["content-type"] == "audio/pcm" + # Just verify status and content type + assert response.status_code == 200 + assert response.headers["content-type"] == "audio/pcm" + + +@pytest.mark.asyncio +async def test_openai_speech_streaming_mp3(mock_tts_service, async_client): + """Test streaming MP3 audio to file""" + test_request = { + "model": "kokoro", + "input": "Hello world", + "voice": "af", + "response_format": "mp3", + } + + # Mock streaming response + async def mock_stream(): + yield b"mp3header" + yield b"mp3data" + mock_tts_service.generate_audio_stream.return_value = mock_stream() + + # Add streaming header + headers = {"x-raw-response": "stream"} + response = await async_client.post("/v1/audio/speech", json=test_request, headers=headers) + + assert response.status_code == 200 + assert response.headers["content-type"] == "audio/mpeg" + assert response.headers["content-disposition"] == "attachment; filename=speech.mp3" + # Just verify status and content type + assert response.status_code == 200 + assert response.headers["content-type"] == "audio/mpeg" + assert response.headers["content-disposition"] == "attachment; filename=speech.mp3" + + +@pytest.mark.asyncio +async def test_openai_speech_streaming_generator(mock_tts_service, async_client): + """Test streaming with async generator""" + test_request = { + "model": "kokoro", + "input": "Hello world", + "voice": "af", + "response_format": "pcm", + } + + # Mock streaming response + async def mock_stream(): + yield b"chunk1" + yield b"chunk2" + + mock_tts_service.generate_audio_stream.return_value = mock_stream() + + # Add streaming header + headers = {"x-raw-response": "stream"} + response = await async_client.post("/v1/audio/speech", json=test_request, headers=headers) + + assert response.status_code == 200 + assert response.headers["content-type"] == "audio/pcm" + # Just verify status and content type + assert response.status_code == 200 + assert response.headers["content-type"] == "audio/pcm" diff --git a/api/tests/test_main.py b/api/tests/test_main.py index c6a972e..51026c5 100644 --- a/api/tests/test_main.py +++ b/api/tests/test_main.py @@ -1,6 +1,6 @@ """Tests for FastAPI application""" -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, patch, call import pytest from fastapi.testclient import TestClient @@ -39,8 +39,12 @@ async def test_lifespan_successful_warmup(mock_logger, mock_tts_model): # Verify the expected logging sequence mock_logger.info.assert_any_call("Loading TTS model and voice packs...") - mock_logger.info.assert_any_call("Model loaded and warmed up on cuda") - mock_logger.info.assert_any_call("3 voice packs loaded successfully") + + # Check for the startup message containing the required info + startup_calls = [call[0][0] for call in mock_logger.info.call_args_list] + startup_msg = next(msg for msg in startup_calls if "Model loaded and warmed up on" in msg) + assert "Model loaded and warmed up on cuda" in startup_msg + assert "3 voice packs loaded successfully" in startup_msg # Verify model setup was called mock_tts_model.setup.assert_called_once()