2025-01-22 21:11:47 -07:00
|
|
|
import asyncio
|
2025-01-23 04:11:31 -07:00
|
|
|
import json
|
2025-02-09 18:32:17 -07:00
|
|
|
import os
|
2025-02-15 09:40:01 -05:00
|
|
|
from typing import AsyncGenerator, Tuple
|
2025-02-09 18:32:17 -07:00
|
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import pytest
|
|
|
|
from fastapi.testclient import TestClient
|
2025-01-22 21:11:47 -07:00
|
|
|
|
2025-01-23 04:11:31 -07:00
|
|
|
from api.src.core.config import settings
|
2025-04-04 16:50:46 -06:00
|
|
|
from api.src.inference.base import AudioChunk
|
2025-02-09 18:32:17 -07:00
|
|
|
from api.src.main import app
|
2025-01-23 04:11:31 -07:00
|
|
|
from api.src.routers.openai_compatible import (
|
|
|
|
get_tts_service,
|
2025-02-09 18:32:17 -07:00
|
|
|
load_openai_mappings,
|
|
|
|
stream_audio_chunks,
|
2025-01-23 04:11:31 -07:00
|
|
|
)
|
2025-04-04 16:50:46 -06:00
|
|
|
from api.src.services.streaming_audio_writer import StreamingAudioWriter
|
2025-02-09 18:32:17 -07:00
|
|
|
from api.src.services.tts_service import TTSService
|
2025-01-23 04:11:31 -07:00
|
|
|
from api.src.structures.schemas import OpenAISpeechRequest
|
2025-01-22 21:11:47 -07:00
|
|
|
|
|
|
|
# Module-level TestClient wrapping the imported FastAPI ``app``; the endpoint
# tests in this module issue their HTTP requests through it.
client = TestClient(app)
|
|
|
|
|
2025-02-09 18:32:17 -07:00
|
|
|
|
2025-01-23 04:11:31 -07:00
|
|
|
@pytest.fixture
def test_voice() -> str:
    """Return the voice name used by tests that need a generic voice."""
    voice_name = "test_voice"
    return voice_name
|
|
|
|
|
2025-02-09 18:32:17 -07:00
|
|
|
|
2025-01-23 04:11:31 -07:00
|
|
|
@pytest.fixture
def mock_openai_mappings():
    """Replace the router's ``_openai_mappings`` with a fixed table.

    For the duration of the test the module attribute holds two model
    aliases (``tts-1``, ``tts-1-hd``) and two voice aliases (``alloy``,
    ``nova``). The fixture yields nothing; the patch is undone on teardown.
    """
    fake_mappings = {
        "models": {"tts-1": "kokoro-v1_0", "tts-1-hd": "kokoro-v1_0"},
        "voices": {"alloy": "am_adam", "nova": "bf_isabella"},
    }
    mappings_patch = patch(
        "api.src.routers.openai_compatible._openai_mappings",
        fake_mappings,
    )
    with mappings_patch:
        yield
|
|
|
|
|
2025-02-09 18:32:17 -07:00
|
|
|
|
2025-01-23 04:11:31 -07:00
|
|
|
@pytest.fixture
def mock_json_file(tmp_path):
    """Write a small mappings JSON file under ``tmp_path`` and return its path.

    The file contains one model entry and one voice entry.
    """
    mapping_path = tmp_path / "test_mappings.json"
    serialized = json.dumps(
        {
            "models": {"test-model": "test-kokoro"},
            "voices": {"test-voice": "test-internal"},
        }
    )
    mapping_path.write_text(serialized)
    return mapping_path
|
|
|
|
|
2025-02-09 18:32:17 -07:00
|
|
|
|
2025-01-23 04:11:31 -07:00
|
|
|
def test_load_openai_mappings(mock_json_file):
    """Check that ``load_openai_mappings`` returns the JSON file's contents.

    ``os.path.join`` is patched to return the temporary file's path for the
    duration of the call.
    """
    fake_path = str(mock_json_file)
    with patch("os.path.join", return_value=fake_path):
        loaded = load_openai_mappings()

        # Both top-level sections must be present.
        for section in ("models", "voices"):
            assert section in loaded

        assert loaded["models"]["test-model"] == "test-kokoro"
        assert loaded["voices"]["test-voice"] == "test-internal"
|
|
|
|
|
2025-02-09 18:32:17 -07:00
|
|
|
|
2025-01-23 04:11:31 -07:00
|
|
|
def test_load_openai_mappings_file_not_found():
    """Check the fallback result when the mappings file does not exist.

    ``os.path.join`` is patched to return a nonexistent path; the loader is
    expected to return empty ``models`` and ``voices`` tables.
    """
    empty_mappings = {"models": {}, "voices": {}}
    with patch("os.path.join", return_value="/nonexistent/path"):
        loaded = load_openai_mappings()
        assert loaded == empty_mappings
|
|
|
|
|
2025-02-09 18:32:17 -07:00
|
|
|
|
2025-02-09 20:55:21 -07:00
|
|
|
def test_list_models(mock_openai_mappings):
    """Check the ``GET /v1/models`` listing endpoint."""
    resp = client.get("/v1/models")
    assert resp.status_code == 200

    body = resp.json()
    assert body["object"] == "list"
    assert isinstance(body["data"], list)
    assert len(body["data"]) == 3  # tts-1, tts-1-hd, and kokoro

    # Every expected model id must appear in the listing.
    listed_ids = [entry["id"] for entry in body["data"]]
    for expected_id in ("tts-1", "tts-1-hd", "kokoro"):
        assert expected_id in listed_ids

    # Each entry must have the expected fields.
    for entry in body["data"]:
        assert entry["object"] == "model"
        assert "created" in entry
        assert entry["owned_by"] == "kokoro"
|
|
|
|
|
|
|
|
|
|
|
|
def test_retrieve_model(mock_openai_mappings):
    """Check ``GET /v1/models/{id}`` for a known and an unknown model."""
    # Known model: expect a 200 with the model record.
    found = client.get("/v1/models/tts-1")
    assert found.status_code == 200
    record = found.json()
    assert record["id"] == "tts-1"
    assert record["object"] == "model"
    assert record["owned_by"] == "kokoro"
    assert "created" in record

    # Unknown model: expect a 404 with a structured error payload.
    missing = client.get("/v1/models/nonexistent-model")
    assert missing.status_code == 404
    detail = missing.json()["detail"]
    assert detail["error"] == "model_not_found"
    assert "not found" in detail["message"]
    assert detail["type"] == "invalid_request_error"
|
|
|
|
|
|
|
|
|
|
|
|
|
2025-01-23 04:11:31 -07:00
|
|
|
@pytest.mark.asyncio
async def test_get_tts_service_initialization():
    """Check that concurrent ``get_tts_service`` calls create one service.

    The router's cached service and init lock are patched to ``None`` and
    ``TTSService.create`` is replaced by a mock.
    """
    with patch("api.src.routers.openai_compatible._tts_service", None), patch(
        "api.src.routers.openai_compatible._init_lock", None
    ), patch("api.src.services.tts_service.TTSService.create") as create_mock:
        fake_service = AsyncMock()
        create_mock.return_value = fake_service

        # Issue five lookups concurrently.
        pending = [get_tts_service() for _ in range(5)]
        services = await asyncio.gather(*pending)

        # The factory must have run once and every caller gets that instance.
        create_mock.assert_called_once()
        assert all(svc == fake_service for svc in services)
|
|
|
|
|
2025-02-09 18:32:17 -07:00
|
|
|
|
2025-01-23 04:11:31 -07:00
|
|
|
@pytest.mark.asyncio
async def test_stream_audio_chunks_client_disconnect():
    """Test handling of client disconnect during streaming.

    The request mock reports ``is_disconnected() == True`` and the service
    mock streams five chunks. The test expects ``stream_audio_chunks`` to
    yield nothing.
    """
    mock_request = MagicMock()
    mock_request.is_disconnected = AsyncMock(return_value=True)

    mock_service = AsyncMock()

    async def mock_stream(*args, **kwargs):
        for _ in range(5):
            # np.array([], ...) is a real empty buffer. np.ndarray([], ...)
            # would allocate an uninitialised 0-d array instead.
            yield AudioChunk(np.array([], dtype=np.int16), output=b"chunk")

    mock_service.generate_audio_stream = mock_stream
    mock_service.list_voices.return_value = ["test_voice"]

    request = OpenAISpeechRequest(
        model="kokoro",
        input="Test text",
        voice="test_voice",
        response_format="mp3",
        stream=True,
        speed=1.0,
    )

    writer = StreamingAudioWriter("mp3", 24000)

    chunks = []
    try:
        async for chunk in stream_audio_chunks(
            mock_service, request, mock_request, writer
        ):
            chunks.append(chunk)
    finally:
        # Close the writer even if iteration raises.
        writer.close()

    assert len(chunks) == 0  # Should stop immediately due to disconnect
|
|
|
|
|
2025-02-09 18:32:17 -07:00
|
|
|
|
2025-01-23 04:11:31 -07:00
|
|
|
def test_openai_voice_mapping(mock_tts_service, mock_openai_mappings):
    """Check that an OpenAI voice alias reaches the service as its mapped name.

    The request uses voice ``alloy``; ``generate_audio`` is expected to be
    called with ``voice="am_adam"``.
    """
    mock_tts_service.list_voices.return_value = ["am_adam", "bf_isabella"]

    payload = {
        "model": "tts-1",
        "input": "Hello world",
        "voice": "alloy",  # OpenAI voice name
        "response_format": "mp3",
        "stream": False,
    }
    resp = client.post("/v1/audio/speech", json=payload)
    assert resp.status_code == 200

    generate = mock_tts_service.generate_audio
    generate.assert_called_once()
    _, call_kwargs = generate.call_args
    assert call_kwargs["voice"] == "am_adam"
|
|
|
|
|
2025-02-09 18:32:17 -07:00
|
|
|
|
|
|
|
def test_openai_voice_mapping_streaming(
    mock_tts_service, mock_openai_mappings, mock_audio_bytes
):
    """Check a streaming request that uses OpenAI model and voice aliases.

    The concatenated response body must equal the mocked audio bytes.
    """
    mock_tts_service.list_voices.return_value = ["am_adam", "bf_isabella"]

    payload = {
        "model": "tts-1-hd",
        "input": "Hello world",
        "voice": "nova",  # OpenAI voice name
        "response_format": "mp3",
        "stream": True,
    }
    resp = client.post("/v1/audio/speech", json=payload)
    assert resp.status_code == 200

    streamed = b"".join(resp.iter_bytes())
    assert streamed == mock_audio_bytes
|
|
|
|
|
2025-02-09 18:32:17 -07:00
|
|
|
|
2025-01-23 04:11:31 -07:00
|
|
|
def test_invalid_openai_model(mock_tts_service, mock_openai_mappings):
    """Check that an unknown model name is rejected with a 400 error."""
    payload = {
        "model": "invalid-model",
        "input": "Hello world",
        "voice": "alloy",
        "response_format": "mp3",
        "stream": False,
    }
    resp = client.post("/v1/audio/speech", json=payload)
    assert resp.status_code == 400

    detail = resp.json()["detail"]
    assert detail["error"] == "invalid_model"
    assert "Unsupported model" in detail["message"]
|
|
|
|
|
2025-02-09 18:32:17 -07:00
|
|
|
|
2025-01-22 21:11:47 -07:00
|
|
|
@pytest.fixture
def mock_audio_bytes() -> bytes:
    """Return the fixed byte string that stands in for encoded audio."""
    fake_audio = b"mock audio data"
    return fake_audio
|
|
|
|
|
2025-02-09 18:32:17 -07:00
|
|
|
|
2025-01-22 21:11:47 -07:00
|
|
|
@pytest.fixture
def mock_tts_service(mock_audio_bytes):
    """Patch the router's ``get_tts_service`` and yield the mocked service.

    The yielded object is an ``AsyncMock`` with the ``TTSService`` spec:

    * ``generate_audio`` returns an ``AudioChunk`` of 1000 zero samples.
    * ``generate_audio_stream`` is an async generator yielding one chunk
      whose ``output`` is ``mock_audio_bytes``.
    * ``list_voices`` and ``combine_voices`` return fixed values.
    """
    with patch("api.src.routers.openai_compatible.get_tts_service") as mock_get:
        service = AsyncMock(spec=TTSService)
        service.generate_audio.return_value = AudioChunk(np.zeros(1000, np.int16))

        async def mock_stream(*args, **kwargs) -> AsyncGenerator[AudioChunk, None]:
            # np.array([], ...) is a real empty sample buffer. np.ndarray([], ...)
            # would allocate an uninitialised 0-d array instead.
            yield AudioChunk(np.array([], dtype=np.int16), output=mock_audio_bytes)

        service.generate_audio_stream = mock_stream
        service.list_voices.return_value = ["test_voice", "voice1", "voice2"]
        service.combine_voices.return_value = "voice1_voice2"

        mock_get.return_value = service
        mock_get.side_effect = None
        yield service
|
|
|
|
|
2025-02-09 18:32:17 -07:00
|
|
|
|
|
|
|
@patch("api.src.services.audio.AudioService.convert_audio")
def test_openai_speech_endpoint(
    mock_convert, mock_tts_service, test_voice, mock_audio_bytes
):
    """Check a non-streaming MP3 request against the speech endpoint.

    ``AudioService.convert_audio`` is mocked to return a chunk whose
    ``output`` is ``mock_audio_bytes``. The assertions expect it to be
    called twice and the body to be those bytes repeated twice.
    """
    # Configure mocks
    silence = np.zeros(1000, np.int16)
    mock_tts_service.generate_audio.return_value = AudioChunk(silence)
    mock_convert.return_value = AudioChunk(
        np.zeros(1000, np.int16), output=mock_audio_bytes
    )

    payload = {
        "model": "kokoro",
        "input": "Hello world",
        "voice": test_voice,
        "response_format": "mp3",
        "stream": False,
    }
    resp = client.post("/v1/audio/speech", json=payload)

    assert resp.status_code == 200
    assert resp.headers["content-type"] == "audio/mpeg"
    assert len(resp.content) > 0
    assert resp.content == mock_audio_bytes + mock_audio_bytes

    mock_tts_service.generate_audio.assert_called_once()
    assert mock_convert.call_count == 2
|
2025-01-22 21:11:47 -07:00
|
|
|
|
2025-02-09 18:32:17 -07:00
|
|
|
|
2025-01-22 21:11:47 -07:00
|
|
|
def test_openai_speech_streaming(mock_tts_service, test_voice, mock_audio_bytes):
    """Check a streaming MP3 request against the speech endpoint.

    Verifies the status, the content type, the chunked transfer encoding
    header, and that the streamed body equals the mocked audio bytes.
    """
    payload = {
        "model": "kokoro",
        "input": "Hello world",
        "voice": test_voice,
        "response_format": "mp3",
        "stream": True,
    }
    resp = client.post("/v1/audio/speech", json=payload)

    assert resp.status_code == 200
    assert resp.headers["content-type"] == "audio/mpeg"
    assert "Transfer-Encoding" in resp.headers
    assert resp.headers["Transfer-Encoding"] == "chunked"

    streamed = b"".join(resp.iter_bytes())
    assert streamed == mock_audio_bytes
|
|
|
|
|
2025-02-09 18:32:17 -07:00
|
|
|
|
2025-01-22 21:11:47 -07:00
|
|
|
def test_openai_speech_pcm_streaming(mock_tts_service, test_voice, mock_audio_bytes):
    """Check a streaming request that asks for the ``pcm`` response format."""
    payload = {
        "model": "kokoro",
        "input": "Hello world",
        "voice": test_voice,
        "response_format": "pcm",
        "stream": True,
    }
    resp = client.post("/v1/audio/speech", json=payload)

    assert resp.status_code == 200
    assert resp.headers["content-type"] == "audio/pcm"

    streamed = b"".join(resp.iter_bytes())
    assert streamed == mock_audio_bytes
|
|
|
|
|
2025-02-09 18:32:17 -07:00
|
|
|
|
2025-01-22 21:11:47 -07:00
|
|
|
def test_openai_speech_invalid_voice(mock_tts_service):
    """Check that a ``ValueError`` from ``generate_audio`` becomes a 400.

    The mocked service raises "Voice 'invalid_voice' not found"; the error
    payload must carry that message as a validation error.
    """
    voice_error = ValueError("Voice 'invalid_voice' not found")
    mock_tts_service.generate_audio.side_effect = voice_error

    payload = {
        "model": "kokoro",
        "input": "Hello world",
        "voice": "invalid_voice",
        "response_format": "mp3",
        "stream": False,
    }
    resp = client.post("/v1/audio/speech", json=payload)
    assert resp.status_code == 400

    detail = resp.json()["detail"]
    assert detail["error"] == "validation_error"
    assert "Voice 'invalid_voice' not found" in detail["message"]
    assert detail["type"] == "invalid_request_error"
|
|
|
|
|
2025-02-09 18:32:17 -07:00
|
|
|
|
2025-01-22 21:11:47 -07:00
|
|
|
def test_openai_speech_empty_text(mock_tts_service, test_voice):
    """Check the 400 response when generation rejects empty input text.

    ``generate_audio`` is replaced by a coroutine that raises
    ``ValueError("Text is empty after preprocessing")``.
    """

    async def raise_empty_text(*args, **kwargs):
        raise ValueError("Text is empty after preprocessing")

    mock_tts_service.generate_audio = raise_empty_text
    mock_tts_service.list_voices.return_value = ["test_voice"]

    payload = {
        "model": "kokoro",
        "input": "",
        "voice": test_voice,
        "response_format": "mp3",
        "stream": False,
    }
    resp = client.post("/v1/audio/speech", json=payload)
    assert resp.status_code == 400

    detail = resp.json()["detail"]
    assert detail["error"] == "validation_error"
    assert "Text is empty after preprocessing" in detail["message"]
    assert detail["type"] == "invalid_request_error"
|
|
|
|
|
2025-02-09 18:32:17 -07:00
|
|
|
|
2025-01-22 21:11:47 -07:00
|
|
|
def test_openai_speech_invalid_format(mock_tts_service, test_voice):
    """Check that an unsupported ``response_format`` is rejected with 422."""
    payload = {
        "model": "kokoro",
        "input": "Hello world",
        "voice": test_voice,
        "response_format": "invalid_format",
        "stream": False,
    }
    resp = client.post("/v1/audio/speech", json=payload)
    assert resp.status_code == 422  # Validation error from Pydantic
|
|
|
|
|
2025-02-09 18:32:17 -07:00
|
|
|
|
2025-01-22 21:11:47 -07:00
|
|
|
def test_list_voices(mock_tts_service):
    """Check that ``GET /v1/audio/voices`` returns the service's voice list."""
    # Replace the fixture's default voice list for this test.
    expected_voices = ["voice1", "voice2"]
    mock_tts_service.list_voices.return_value = expected_voices

    resp = client.get("/v1/audio/voices")
    assert resp.status_code == 200

    body = resp.json()
    assert "voices" in body
    assert len(body["voices"]) == 2
    for voice in ("voice1", "voice2"):
        assert voice in body["voices"]
|
|
|
|
|
2025-02-09 18:32:17 -07:00
|
|
|
|
|
|
|
@patch("api.src.routers.openai_compatible.settings")
def test_combine_voices(mock_settings, mock_tts_service):
    """Check the voice-combine endpoint with local voice saving enabled.

    The router's ``settings`` object is mocked and
    ``allow_local_voice_saving`` is set to ``True``. The response must be
    an octet-stream download whose disposition names ``voice1+voice2.pt``.
    """
    # Enable local voice saving for this test
    mock_settings.allow_local_voice_saving = True

    resp = client.post("/v1/audio/voices/combine", json="voice1+voice2")

    assert resp.status_code == 200
    headers = resp.headers
    assert headers["content-type"] == "application/octet-stream"
    assert "voice1+voice2.pt" in headers["content-disposition"]
|
2025-01-22 21:11:47 -07:00
|
|
|
|
2025-02-09 18:32:17 -07:00
|
|
|
|
2025-01-22 21:11:47 -07:00
|
|
|
def test_server_error(mock_tts_service, test_voice):
    """Check that a ``RuntimeError`` from generation becomes a 500 response.

    ``generate_audio`` is replaced by a coroutine that raises
    ``RuntimeError("Internal server error")``.
    """

    async def raise_runtime_error(*args, **kwargs):
        raise RuntimeError("Internal server error")

    mock_tts_service.generate_audio = raise_runtime_error
    mock_tts_service.list_voices.return_value = ["test_voice"]

    payload = {
        "model": "kokoro",
        "input": "Hello world",
        "voice": test_voice,
        "response_format": "mp3",
        "stream": False,
    }
    resp = client.post("/v1/audio/speech", json=payload)
    assert resp.status_code == 500

    detail = resp.json()["detail"]
    assert detail["error"] == "processing_error"
    assert detail["type"] == "server_error"
|
|
|
|
|
2025-02-09 18:32:17 -07:00
|
|
|
|
2025-01-22 21:11:47 -07:00
|
|
|
def test_streaming_error(mock_tts_service, test_voice):
    """Check the 500 response when a streaming request fails.

    ``list_voices`` on the mocked service is configured to raise
    ``RuntimeError("Streaming failed")``.
    """
    # Make the service's voice listing raise the error.
    mock_tts_service.list_voices.side_effect = RuntimeError("Streaming failed")

    payload = {
        "model": "kokoro",
        "input": "Hello world",
        "voice": test_voice,
        "response_format": "mp3",
        "stream": True,
    }
    resp = client.post("/v1/audio/speech", json=payload)

    assert resp.status_code == 500
    detail = resp.json()["detail"]
    assert detail["error"] == "processing_error"
    assert detail["type"] == "server_error"
    assert "Streaming failed" in detail["message"]
|
2025-01-23 04:11:31 -07:00
|
|
|
|
2025-02-09 18:32:17 -07:00
|
|
|
|
2025-01-23 04:11:31 -07:00
|
|
|
@pytest.mark.asyncio
async def test_streaming_initialization_error():
    """Test handling of streaming initialization errors.

    The service's ``generate_audio_stream`` is an async generator that
    raises ``RuntimeError`` before yielding anything. The test expects
    that error to propagate out of ``stream_audio_chunks``.
    """
    mock_service = AsyncMock()

    async def mock_error_stream(*args, **kwargs):
        if False:  # This makes it a proper generator
            yield b""
        raise RuntimeError("Failed to initialize stream")

    mock_service.generate_audio_stream = mock_error_stream
    mock_service.list_voices.return_value = ["test_voice"]

    request = OpenAISpeechRequest(
        model="kokoro",
        input="Test text",
        voice="test_voice",
        response_format="mp3",
        stream=True,
        speed=1.0,
    )

    writer = StreamingAudioWriter("mp3", 24000)

    try:
        with pytest.raises(RuntimeError) as exc:
            async for _ in stream_audio_chunks(
                mock_service, request, MagicMock(), writer
            ):
                pass
    finally:
        # Close the writer whether or not the expected error was raised.
        writer.close()

    assert "Failed to initialize stream" in str(exc.value)
|