WIP: basic tests on OpenAI streaming compatibility

This commit is contained in:
remsky 2025-01-04 18:09:23 -07:00
parent 0e9f77fc79
commit e799f0c7c1
4 changed files with 114 additions and 13 deletions

BIN
.coverage

Binary file not shown.

View file

@ -62,10 +62,10 @@ class AudioService:
logger.info("Writing to WAV format...") logger.info("Writing to WAV format...")
# Always include WAV header for WAV format # Always include WAV header for WAV format
sf.write(buffer, normalized_audio, sample_rate, format="WAV", subtype='PCM_16') sf.write(buffer, normalized_audio, sample_rate, format="WAV", subtype='PCM_16')
elif output_format in ["mp3", "aac"]: elif output_format == "mp3":
logger.info(f"Converting to {output_format.upper()} format...") logger.info("Converting to MP3 format...")
# Use lower bitrate for streaming # Use lower bitrate for streaming
sf.write(buffer, normalized_audio, sample_rate, format=output_format.upper()) sf.write(buffer, normalized_audio, sample_rate, format="MP3")
elif output_format == "opus": elif output_format == "opus":
logger.info("Converting to Opus format...") logger.info("Converting to Opus format...")
# Use lower bitrate and smaller frame size for streaming # Use lower bitrate and smaller frame size for streaming
@ -75,6 +75,11 @@ class AudioService:
# Use smaller block size for streaming # Use smaller block size for streaming
sf.write(buffer, normalized_audio, sample_rate, format="FLAC", sf.write(buffer, normalized_audio, sample_rate, format="FLAC",
subtype='PCM_16') subtype='PCM_16')
else:
if output_format == "aac":
raise ValueError(
"Format aac not supported. Supported formats are: wav, mp3, opus, flac, pcm."
)
else: else:
raise ValueError( raise ValueError(
f"Format {output_format} not supported. Supported formats are: wav, mp3, opus, flac, pcm." f"Format {output_format} not supported. Supported formats are: wav, mp3, opus, flac, pcm."

View file

@ -1,13 +1,21 @@
from unittest.mock import Mock from unittest.mock import Mock
import pytest import pytest
import pytest_asyncio
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from httpx import AsyncClient
from ..src.main import app from ..src.main import app
# Create test client # Create test client
client = TestClient(app) client = TestClient(app)
# Create async client fixture
@pytest_asyncio.fixture
async def async_client():
async with AsyncClient(app=app, base_url="http://test") as ac:
yield ac
# Mock services # Mock services
@pytest.fixture @pytest.fixture
@ -34,12 +42,12 @@ def mock_tts_service(monkeypatch):
@pytest.fixture @pytest.fixture
def mock_audio_service(monkeypatch): def mock_audio_service(monkeypatch):
def mock_convert(*args): mock_service = Mock()
return b"converted mock audio data" mock_service.convert_audio.return_value = b"converted mock audio data"
monkeypatch.setattr( monkeypatch.setattr(
"api.src.routers.openai_compatible.AudioService.convert_audio", mock_convert "api.src.routers.openai_compatible.AudioService", mock_service
) )
return mock_service
def test_health_check(): def test_health_check():
@ -153,3 +161,87 @@ def test_combine_voices_error(mock_tts_service):
assert response.status_code == 500 assert response.status_code == 500
assert "Combination failed" in response.json()["detail"]["message"] assert "Combination failed" in response.json()["detail"]["message"]
@pytest.mark.asyncio
async def test_openai_speech_pcm_streaming(mock_tts_service, async_client):
"""Test streaming PCM audio for real-time playback"""
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": "af",
"response_format": "pcm",
}
# Mock streaming response
async def mock_stream():
yield b"chunk1"
yield b"chunk2"
mock_tts_service.generate_audio_stream.return_value = mock_stream()
# Add streaming header
headers = {"x-raw-response": "stream"}
response = await async_client.post("/v1/audio/speech", json=test_request, headers=headers)
assert response.status_code == 200
assert response.headers["content-type"] == "audio/pcm"
# Just verify status and content type
assert response.status_code == 200
assert response.headers["content-type"] == "audio/pcm"
@pytest.mark.asyncio
async def test_openai_speech_streaming_mp3(mock_tts_service, async_client):
"""Test streaming MP3 audio to file"""
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": "af",
"response_format": "mp3",
}
# Mock streaming response
async def mock_stream():
yield b"mp3header"
yield b"mp3data"
mock_tts_service.generate_audio_stream.return_value = mock_stream()
# Add streaming header
headers = {"x-raw-response": "stream"}
response = await async_client.post("/v1/audio/speech", json=test_request, headers=headers)
assert response.status_code == 200
assert response.headers["content-type"] == "audio/mpeg"
assert response.headers["content-disposition"] == "attachment; filename=speech.mp3"
# Just verify status and content type
assert response.status_code == 200
assert response.headers["content-type"] == "audio/mpeg"
assert response.headers["content-disposition"] == "attachment; filename=speech.mp3"
@pytest.mark.asyncio
async def test_openai_speech_streaming_generator(mock_tts_service, async_client):
"""Test streaming with async generator"""
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": "af",
"response_format": "pcm",
}
# Mock streaming response
async def mock_stream():
yield b"chunk1"
yield b"chunk2"
mock_tts_service.generate_audio_stream.return_value = mock_stream()
# Add streaming header
headers = {"x-raw-response": "stream"}
response = await async_client.post("/v1/audio/speech", json=test_request, headers=headers)
assert response.status_code == 200
assert response.headers["content-type"] == "audio/pcm"
# Just verify status and content type
assert response.status_code == 200
assert response.headers["content-type"] == "audio/pcm"

View file

@ -1,6 +1,6 @@
"""Tests for FastAPI application""" """Tests for FastAPI application"""
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch, call
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
@ -39,8 +39,12 @@ async def test_lifespan_successful_warmup(mock_logger, mock_tts_model):
# Verify the expected logging sequence # Verify the expected logging sequence
mock_logger.info.assert_any_call("Loading TTS model and voice packs...") mock_logger.info.assert_any_call("Loading TTS model and voice packs...")
mock_logger.info.assert_any_call("Model loaded and warmed up on cuda")
mock_logger.info.assert_any_call("3 voice packs loaded successfully") # Check for the startup message containing the required info
startup_calls = [call[0][0] for call in mock_logger.info.call_args_list]
startup_msg = next(msg for msg in startup_calls if "Model loaded and warmed up on" in msg)
assert "Model loaded and warmed up on cuda" in startup_msg
assert "3 voice packs loaded successfully" in startup_msg
# Verify model setup was called # Verify model setup was called
mock_tts_model.setup.assert_called_once() mock_tts_model.setup.assert_called_once()