mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00
fixed a bunch of tests
This commit is contained in:
parent
1a03ac7464
commit
1a6e7abac3
2 changed files with 15 additions and 14 deletions
|
@ -1,24 +1,24 @@
|
|||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import requests
|
||||
import base64
|
||||
import json
|
||||
|
||||
def test_generate_captioned_speech():
|
||||
"""Test the generate_captioned_speech function with mocked responses"""
|
||||
# Mock the API responses
|
||||
mock_audio_response = MagicMock()
|
||||
mock_audio_response.status_code = 200
|
||||
mock_audio_response.content = b"mock audio data"
|
||||
mock_audio_response.headers = {"X-Timestamps-Path": "test.json"}
|
||||
|
||||
mock_timestamps_response = MagicMock()
|
||||
mock_timestamps_response.status_code = 200
|
||||
mock_timestamps_response.json.return_value = [
|
||||
{"word": "test", "start_time": 0.0, "end_time": 1.0}
|
||||
]
|
||||
mock_timestamps_response.content = json.dumps({
|
||||
"audio":base64.b64encode(b"mock audio data").decode("utf-8"),
|
||||
"timestamps":[{"word": "test", "start_time": 0.0, "end_time": 1.0}]
|
||||
})
|
||||
|
||||
# Patch both HTTP requests
|
||||
with patch('requests.post', return_value=mock_audio_response), \
|
||||
patch('requests.get', return_value=mock_timestamps_response):
|
||||
# Patch the HTTP requests
|
||||
with patch('requests.post', return_value=mock_timestamps_response):
|
||||
|
||||
# Import here to avoid module-level import issues
|
||||
from examples.captioned_speech_example import generate_captioned_speech
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from typing import AsyncGenerator
|
||||
from typing import AsyncGenerator, Tuple
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from api.src.inference.base import AudioChunk
|
||||
import numpy as np
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
@ -144,7 +145,7 @@ async def test_stream_audio_chunks_client_disconnect():
|
|||
|
||||
async def mock_stream(*args, **kwargs):
|
||||
for i in range(5):
|
||||
yield b"chunk"
|
||||
yield (b"chunk",AudioChunk(np.ndarray([],np.int16)))
|
||||
|
||||
mock_service.generate_audio_stream = mock_stream
|
||||
mock_service.list_voices.return_value = ["test_voice"]
|
||||
|
@ -236,10 +237,10 @@ def mock_tts_service(mock_audio_bytes):
|
|||
"""Mock TTS service for testing."""
|
||||
with patch("api.src.routers.openai_compatible.get_tts_service") as mock_get:
|
||||
service = AsyncMock(spec=TTSService)
|
||||
service.generate_audio.return_value = (np.zeros(1000), 0.1)
|
||||
service.generate_audio.return_value = (np.zeros(1000), AudioChunk(np.zeros(1000,np.int16)))
|
||||
|
||||
async def mock_stream(*args, **kwargs) -> AsyncGenerator[bytes, None]:
|
||||
yield mock_audio_bytes
|
||||
async def mock_stream(*args, **kwargs) -> AsyncGenerator[Tuple[bytes,AudioChunk], None]:
|
||||
yield (mock_audio_bytes, AudioChunk(np.ndarray([],np.int16)))
|
||||
|
||||
service.generate_audio_stream = mock_stream
|
||||
service.list_voices.return_value = ["test_voice", "voice1", "voice2"]
|
||||
|
@ -256,7 +257,7 @@ def test_openai_speech_endpoint(
|
|||
):
|
||||
"""Test the OpenAI-compatible speech endpoint with basic MP3 generation"""
|
||||
# Configure mocks
|
||||
mock_tts_service.generate_audio.return_value = (np.zeros(1000), 0.1)
|
||||
mock_tts_service.generate_audio.return_value = (np.zeros(1000), AudioChunk(np.zeros(1000,np.int16)))
|
||||
mock_convert.return_value = mock_audio_bytes
|
||||
|
||||
response = client.post(
|
||||
|
|
Loading…
Add table
Reference in a new issue