WIP: basic tests on OpenAI streaming compatibility

remsky 2025-01-04 18:09:23 -07:00
parent 0e9f77fc79
commit e799f0c7c1
4 changed files with 114 additions and 13 deletions

BIN
.coverage

Binary file not shown.

View file

@@ -62,10 +62,10 @@ class AudioService:
             logger.info("Writing to WAV format...")
             # Always include WAV header for WAV format
             sf.write(buffer, normalized_audio, sample_rate, format="WAV", subtype='PCM_16')
-        elif output_format in ["mp3", "aac"]:
-            logger.info(f"Converting to {output_format.upper()} format...")
+        elif output_format == "mp3":
+            logger.info("Converting to MP3 format...")
             # Use lower bitrate for streaming
-            sf.write(buffer, normalized_audio, sample_rate, format=output_format.upper())
+            sf.write(buffer, normalized_audio, sample_rate, format="MP3")
         elif output_format == "opus":
             logger.info("Converting to Opus format...")
             # Use lower bitrate and smaller frame size for streaming
@@ -75,6 +75,11 @@ class AudioService:
             # Use smaller block size for streaming
             sf.write(buffer, normalized_audio, sample_rate, format="FLAC",
                      subtype='PCM_16')
         else:
+            if output_format == "aac":
+                raise ValueError(
+                    "Format aac not supported. Supported formats are: wav, mp3, opus, flac, pcm."
+                )
+            else:
                 raise ValueError(
                     f"Format {output_format} not supported. Supported formats are: wav, mp3, opus, flac, pcm."

View file

@@ -1,13 +1,21 @@
 from unittest.mock import Mock
 import pytest
+import pytest_asyncio
 from fastapi.testclient import TestClient
+from httpx import AsyncClient
 from ..src.main import app

 # Create test client
 client = TestClient(app)

+# Create async client fixture
+@pytest_asyncio.fixture
+async def async_client():
+    async with AsyncClient(app=app, base_url="http://test") as ac:
+        yield ac
+
 # Mock services
 @pytest.fixture
@@ -34,12 +42,12 @@ def mock_tts_service(monkeypatch):
 @pytest.fixture
 def mock_audio_service(monkeypatch):
-    def mock_convert(*args):
-        return b"converted mock audio data"
+    mock_service = Mock()
+    mock_service.convert_audio.return_value = b"converted mock audio data"
     monkeypatch.setattr(
-        "api.src.routers.openai_compatible.AudioService.convert_audio", mock_convert
+        "api.src.routers.openai_compatible.AudioService", mock_service
     )
     return mock_service

 def test_health_check():
@@ -153,3 +161,87 @@ def test_combine_voices_error(mock_tts_service):
     assert response.status_code == 500
     assert "Combination failed" in response.json()["detail"]["message"]
+
+
+@pytest.mark.asyncio
+async def test_openai_speech_pcm_streaming(mock_tts_service, async_client):
+    """Test streaming PCM audio for real-time playback"""
+    test_request = {
+        "model": "kokoro",
+        "input": "Hello world",
+        "voice": "af",
+        "response_format": "pcm",
+    }
+
+    # Mock streaming response
+    async def mock_stream():
+        yield b"chunk1"
+        yield b"chunk2"
+
+    mock_tts_service.generate_audio_stream.return_value = mock_stream()
+
+    # Add streaming header
+    headers = {"x-raw-response": "stream"}
+    response = await async_client.post("/v1/audio/speech", json=test_request, headers=headers)
+    assert response.status_code == 200
+    assert response.headers["content-type"] == "audio/pcm"
+
+    # Just verify status and content type
+    assert response.status_code == 200
+    assert response.headers["content-type"] == "audio/pcm"
+
+
+@pytest.mark.asyncio
+async def test_openai_speech_streaming_mp3(mock_tts_service, async_client):
+    """Test streaming MP3 audio to file"""
+    test_request = {
+        "model": "kokoro",
+        "input": "Hello world",
+        "voice": "af",
+        "response_format": "mp3",
+    }
+
+    # Mock streaming response
+    async def mock_stream():
+        yield b"mp3header"
+        yield b"mp3data"
+
+    mock_tts_service.generate_audio_stream.return_value = mock_stream()
+
+    # Add streaming header
+    headers = {"x-raw-response": "stream"}
+    response = await async_client.post("/v1/audio/speech", json=test_request, headers=headers)
+    assert response.status_code == 200
+    assert response.headers["content-type"] == "audio/mpeg"
+    assert response.headers["content-disposition"] == "attachment; filename=speech.mp3"
+
+    # Just verify status and content type
+    assert response.status_code == 200
+    assert response.headers["content-type"] == "audio/mpeg"
+    assert response.headers["content-disposition"] == "attachment; filename=speech.mp3"
+
+
+@pytest.mark.asyncio
+async def test_openai_speech_streaming_generator(mock_tts_service, async_client):
+    """Test streaming with async generator"""
+    test_request = {
+        "model": "kokoro",
+        "input": "Hello world",
+        "voice": "af",
+        "response_format": "pcm",
+    }
+
+    # Mock streaming response
+    async def mock_stream():
+        yield b"chunk1"
+        yield b"chunk2"
+
+    mock_tts_service.generate_audio_stream.return_value = mock_stream()
+
+    # Add streaming header
+    headers = {"x-raw-response": "stream"}
+    response = await async_client.post("/v1/audio/speech", json=test_request, headers=headers)
+    assert response.status_code == 200
+    assert response.headers["content-type"] == "audio/pcm"
+
+    # Just verify status and content type
+    assert response.status_code == 200
+    assert response.headers["content-type"] == "audio/pcm"
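For reference, the flow these tests simulate against the test client can also be exercised against a running server. A minimal httpx sketch; the base URL and port are assumptions, while the path, header, and payload fields mirror the tests above:

import httpx

payload = {
    "model": "kokoro",
    "input": "Hello world",
    "voice": "af",
    "response_format": "mp3",
}
headers = {"x-raw-response": "stream"}  # opts into the raw streaming response, as in the tests

# Base URL/port are assumed for illustration; adjust to wherever the API is served.
with httpx.Client(base_url="http://localhost:8880") as client:
    with client.stream("POST", "/v1/audio/speech", json=payload, headers=headers) as response:
        response.raise_for_status()
        with open("speech.mp3", "wb") as f:
            for chunk in response.iter_bytes():
                f.write(chunk)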

View file

@@ -1,6 +1,6 @@
 """Tests for FastAPI application"""

-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, patch, call

 import pytest
 from fastapi.testclient import TestClient
@@ -39,8 +39,12 @@ async def test_lifespan_successful_warmup(mock_logger, mock_tts_model):
     # Verify the expected logging sequence
     mock_logger.info.assert_any_call("Loading TTS model and voice packs...")
-    mock_logger.info.assert_any_call("Model loaded and warmed up on cuda")
-    mock_logger.info.assert_any_call("3 voice packs loaded successfully")
+
+    # Check for the startup message containing the required info
+    startup_calls = [call[0][0] for call in mock_logger.info.call_args_list]
+    startup_msg = next(msg for msg in startup_calls if "Model loaded and warmed up on" in msg)
+    assert "Model loaded and warmed up on cuda" in startup_msg
+    assert "3 voice packs loaded successfully" in startup_msg

     # Verify model setup was called
     mock_tts_model.setup.assert_called_once()
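A note on the indexing in the rewritten assertion: each entry of Mock.call_args_list behaves like an (args, kwargs) tuple, so entry[0][0] is the first positional argument passed to logger.info on that call. A small standalone illustration; the log string here is made up, since the app's actual startup message is not shown in this diff:

from unittest.mock import Mock

logger = Mock()
logger.info("startup: model ready on cuda, 3 voice packs loaded")  # placeholder message

# entry[0] is the positional-args tuple, entry[0][0] the first positional argument
messages = [entry[0][0] for entry in logger.info.call_args_list]
assert any("model ready on" in msg for msg in messages)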