mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
WIP: basic tests on OpenAI streaming compatibility
This commit is contained in:
parent
0e9f77fc79
commit
e799f0c7c1
4 changed files with 114 additions and 13 deletions
BIN
.coverage
BIN
.coverage
Binary file not shown.
|
@ -62,10 +62,10 @@ class AudioService:
|
|||
logger.info("Writing to WAV format...")
|
||||
# Always include WAV header for WAV format
|
||||
sf.write(buffer, normalized_audio, sample_rate, format="WAV", subtype='PCM_16')
|
||||
elif output_format in ["mp3", "aac"]:
|
||||
logger.info(f"Converting to {output_format.upper()} format...")
|
||||
elif output_format == "mp3":
|
||||
logger.info("Converting to MP3 format...")
|
||||
# Use lower bitrate for streaming
|
||||
sf.write(buffer, normalized_audio, sample_rate, format=output_format.upper())
|
||||
sf.write(buffer, normalized_audio, sample_rate, format="MP3")
|
||||
elif output_format == "opus":
|
||||
logger.info("Converting to Opus format...")
|
||||
# Use lower bitrate and smaller frame size for streaming
|
||||
|
@ -76,9 +76,14 @@ class AudioService:
|
|||
sf.write(buffer, normalized_audio, sample_rate, format="FLAC",
|
||||
subtype='PCM_16')
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Format {output_format} not supported. Supported formats are: wav, mp3, opus, flac, pcm."
|
||||
)
|
||||
if output_format == "aac":
|
||||
raise ValueError(
|
||||
"Format aac not supported. Supported formats are: wav, mp3, opus, flac, pcm."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Format {output_format} not supported. Supported formats are: wav, mp3, opus, flac, pcm."
|
||||
)
|
||||
|
||||
buffer.seek(0)
|
||||
return buffer.getvalue()
|
||||
|
|
|
@ -1,13 +1,21 @@
|
|||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from fastapi.testclient import TestClient
|
||||
from httpx import AsyncClient
|
||||
|
||||
from ..src.main import app
|
||||
|
||||
# Create test client
|
||||
client = TestClient(app)
|
||||
|
||||
# Create async client fixture
|
||||
@pytest_asyncio.fixture
|
||||
async def async_client():
|
||||
async with AsyncClient(app=app, base_url="http://test") as ac:
|
||||
yield ac
|
||||
|
||||
|
||||
# Mock services
|
||||
@pytest.fixture
|
||||
|
@ -34,12 +42,12 @@ def mock_tts_service(monkeypatch):
|
|||
|
||||
@pytest.fixture
|
||||
def mock_audio_service(monkeypatch):
|
||||
def mock_convert(*args):
|
||||
return b"converted mock audio data"
|
||||
|
||||
mock_service = Mock()
|
||||
mock_service.convert_audio.return_value = b"converted mock audio data"
|
||||
monkeypatch.setattr(
|
||||
"api.src.routers.openai_compatible.AudioService.convert_audio", mock_convert
|
||||
"api.src.routers.openai_compatible.AudioService", mock_service
|
||||
)
|
||||
return mock_service
|
||||
|
||||
|
||||
def test_health_check():
|
||||
|
@ -153,3 +161,87 @@ def test_combine_voices_error(mock_tts_service):
|
|||
|
||||
assert response.status_code == 500
|
||||
assert "Combination failed" in response.json()["detail"]["message"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_speech_pcm_streaming(mock_tts_service, async_client):
|
||||
"""Test streaming PCM audio for real-time playback"""
|
||||
test_request = {
|
||||
"model": "kokoro",
|
||||
"input": "Hello world",
|
||||
"voice": "af",
|
||||
"response_format": "pcm",
|
||||
}
|
||||
|
||||
# Mock streaming response
|
||||
async def mock_stream():
|
||||
yield b"chunk1"
|
||||
yield b"chunk2"
|
||||
mock_tts_service.generate_audio_stream.return_value = mock_stream()
|
||||
|
||||
# Add streaming header
|
||||
headers = {"x-raw-response": "stream"}
|
||||
response = await async_client.post("/v1/audio/speech", json=test_request, headers=headers)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "audio/pcm"
|
||||
# Just verify status and content type
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "audio/pcm"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_speech_streaming_mp3(mock_tts_service, async_client):
|
||||
"""Test streaming MP3 audio to file"""
|
||||
test_request = {
|
||||
"model": "kokoro",
|
||||
"input": "Hello world",
|
||||
"voice": "af",
|
||||
"response_format": "mp3",
|
||||
}
|
||||
|
||||
# Mock streaming response
|
||||
async def mock_stream():
|
||||
yield b"mp3header"
|
||||
yield b"mp3data"
|
||||
mock_tts_service.generate_audio_stream.return_value = mock_stream()
|
||||
|
||||
# Add streaming header
|
||||
headers = {"x-raw-response": "stream"}
|
||||
response = await async_client.post("/v1/audio/speech", json=test_request, headers=headers)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "audio/mpeg"
|
||||
assert response.headers["content-disposition"] == "attachment; filename=speech.mp3"
|
||||
# Just verify status and content type
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "audio/mpeg"
|
||||
assert response.headers["content-disposition"] == "attachment; filename=speech.mp3"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_speech_streaming_generator(mock_tts_service, async_client):
|
||||
"""Test streaming with async generator"""
|
||||
test_request = {
|
||||
"model": "kokoro",
|
||||
"input": "Hello world",
|
||||
"voice": "af",
|
||||
"response_format": "pcm",
|
||||
}
|
||||
|
||||
# Mock streaming response
|
||||
async def mock_stream():
|
||||
yield b"chunk1"
|
||||
yield b"chunk2"
|
||||
|
||||
mock_tts_service.generate_audio_stream.return_value = mock_stream()
|
||||
|
||||
# Add streaming header
|
||||
headers = {"x-raw-response": "stream"}
|
||||
response = await async_client.post("/v1/audio/speech", json=test_request, headers=headers)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "audio/pcm"
|
||||
# Just verify status and content type
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "audio/pcm"
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
"""Tests for FastAPI application"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock, patch, call
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
@ -39,8 +39,12 @@ async def test_lifespan_successful_warmup(mock_logger, mock_tts_model):
|
|||
|
||||
# Verify the expected logging sequence
|
||||
mock_logger.info.assert_any_call("Loading TTS model and voice packs...")
|
||||
mock_logger.info.assert_any_call("Model loaded and warmed up on cuda")
|
||||
mock_logger.info.assert_any_call("3 voice packs loaded successfully")
|
||||
|
||||
# Check for the startup message containing the required info
|
||||
startup_calls = [call[0][0] for call in mock_logger.info.call_args_list]
|
||||
startup_msg = next(msg for msg in startup_calls if "Model loaded and warmed up on" in msg)
|
||||
assert "Model loaded and warmed up on cuda" in startup_msg
|
||||
assert "3 voice packs loaded successfully" in startup_msg
|
||||
|
||||
# Verify model setup was called
|
||||
mock_tts_model.setup.assert_called_once()
|
||||
|
|
Loading…
Add table
Reference in a new issue