mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00
fixed a bunch of tests
This commit is contained in:
parent
1a03ac7464
commit
1a6e7abac3
2 changed files with 15 additions and 14 deletions
|
@ -1,24 +1,24 @@
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import patch, MagicMock
|
from unittest.mock import patch, MagicMock
|
||||||
import requests
|
import requests
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
|
||||||
def test_generate_captioned_speech():
|
def test_generate_captioned_speech():
|
||||||
"""Test the generate_captioned_speech function with mocked responses"""
|
"""Test the generate_captioned_speech function with mocked responses"""
|
||||||
# Mock the API responses
|
# Mock the API responses
|
||||||
mock_audio_response = MagicMock()
|
mock_audio_response = MagicMock()
|
||||||
mock_audio_response.status_code = 200
|
mock_audio_response.status_code = 200
|
||||||
mock_audio_response.content = b"mock audio data"
|
|
||||||
mock_audio_response.headers = {"X-Timestamps-Path": "test.json"}
|
|
||||||
|
|
||||||
mock_timestamps_response = MagicMock()
|
mock_timestamps_response = MagicMock()
|
||||||
mock_timestamps_response.status_code = 200
|
mock_timestamps_response.status_code = 200
|
||||||
mock_timestamps_response.json.return_value = [
|
mock_timestamps_response.content = json.dumps({
|
||||||
{"word": "test", "start_time": 0.0, "end_time": 1.0}
|
"audio":base64.b64encode(b"mock audio data").decode("utf-8"),
|
||||||
]
|
"timestamps":[{"word": "test", "start_time": 0.0, "end_time": 1.0}]
|
||||||
|
})
|
||||||
|
|
||||||
# Patch both HTTP requests
|
# Patch the HTTP requests
|
||||||
with patch('requests.post', return_value=mock_audio_response), \
|
with patch('requests.post', return_value=mock_timestamps_response):
|
||||||
patch('requests.get', return_value=mock_timestamps_response):
|
|
||||||
|
|
||||||
# Import here to avoid module-level import issues
|
# Import here to avoid module-level import issues
|
||||||
from examples.captioned_speech_example import generate_captioned_speech
|
from examples.captioned_speech_example import generate_captioned_speech
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator, Tuple
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from api.src.inference.base import AudioChunk
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
@ -144,7 +145,7 @@ async def test_stream_audio_chunks_client_disconnect():
|
||||||
|
|
||||||
async def mock_stream(*args, **kwargs):
|
async def mock_stream(*args, **kwargs):
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
yield b"chunk"
|
yield (b"chunk",AudioChunk(np.ndarray([],np.int16)))
|
||||||
|
|
||||||
mock_service.generate_audio_stream = mock_stream
|
mock_service.generate_audio_stream = mock_stream
|
||||||
mock_service.list_voices.return_value = ["test_voice"]
|
mock_service.list_voices.return_value = ["test_voice"]
|
||||||
|
@ -236,10 +237,10 @@ def mock_tts_service(mock_audio_bytes):
|
||||||
"""Mock TTS service for testing."""
|
"""Mock TTS service for testing."""
|
||||||
with patch("api.src.routers.openai_compatible.get_tts_service") as mock_get:
|
with patch("api.src.routers.openai_compatible.get_tts_service") as mock_get:
|
||||||
service = AsyncMock(spec=TTSService)
|
service = AsyncMock(spec=TTSService)
|
||||||
service.generate_audio.return_value = (np.zeros(1000), 0.1)
|
service.generate_audio.return_value = (np.zeros(1000), AudioChunk(np.zeros(1000,np.int16)))
|
||||||
|
|
||||||
async def mock_stream(*args, **kwargs) -> AsyncGenerator[bytes, None]:
|
async def mock_stream(*args, **kwargs) -> AsyncGenerator[Tuple[bytes,AudioChunk], None]:
|
||||||
yield mock_audio_bytes
|
yield (mock_audio_bytes, AudioChunk(np.ndarray([],np.int16)))
|
||||||
|
|
||||||
service.generate_audio_stream = mock_stream
|
service.generate_audio_stream = mock_stream
|
||||||
service.list_voices.return_value = ["test_voice", "voice1", "voice2"]
|
service.list_voices.return_value = ["test_voice", "voice1", "voice2"]
|
||||||
|
@ -256,7 +257,7 @@ def test_openai_speech_endpoint(
|
||||||
):
|
):
|
||||||
"""Test the OpenAI-compatible speech endpoint with basic MP3 generation"""
|
"""Test the OpenAI-compatible speech endpoint with basic MP3 generation"""
|
||||||
# Configure mocks
|
# Configure mocks
|
||||||
mock_tts_service.generate_audio.return_value = (np.zeros(1000), 0.1)
|
mock_tts_service.generate_audio.return_value = (np.zeros(1000), AudioChunk(np.zeros(1000,np.int16)))
|
||||||
mock_convert.return_value = mock_audio_bytes
|
mock_convert.return_value = mock_audio_bytes
|
||||||
|
|
||||||
response = client.post(
|
response = client.post(
|
||||||
|
|
Loading…
Add table
Reference in a new issue