fixed a bunch of tests

This commit is contained in:
Fireblade 2025-02-15 09:40:01 -05:00
parent 1a03ac7464
commit 1a6e7abac3
2 changed files with 15 additions and 14 deletions

View file

@ -1,24 +1,24 @@
import pytest
from unittest.mock import patch, MagicMock
import requests
import base64
import json
def test_generate_captioned_speech():
"""Test the generate_captioned_speech function with mocked responses"""
# Mock the API responses
mock_audio_response = MagicMock()
mock_audio_response.status_code = 200
mock_audio_response.content = b"mock audio data"
mock_audio_response.headers = {"X-Timestamps-Path": "test.json"}
mock_timestamps_response = MagicMock()
mock_timestamps_response.status_code = 200
mock_timestamps_response.json.return_value = [
{"word": "test", "start_time": 0.0, "end_time": 1.0}
]
mock_timestamps_response.content = json.dumps({
"audio":base64.b64encode(b"mock audio data").decode("utf-8"),
"timestamps":[{"word": "test", "start_time": 0.0, "end_time": 1.0}]
})
# Patch both HTTP requests
with patch('requests.post', return_value=mock_audio_response), \
patch('requests.get', return_value=mock_timestamps_response):
# Patch the HTTP requests
with patch('requests.post', return_value=mock_timestamps_response):
# Import here to avoid module-level import issues
from examples.captioned_speech_example import generate_captioned_speech

View file

@ -1,9 +1,10 @@
import asyncio
import json
import os
from typing import AsyncGenerator
from typing import AsyncGenerator, Tuple
from unittest.mock import AsyncMock, MagicMock, patch
from api.src.inference.base import AudioChunk
import numpy as np
import pytest
from fastapi.testclient import TestClient
@ -144,7 +145,7 @@ async def test_stream_audio_chunks_client_disconnect():
async def mock_stream(*args, **kwargs):
for i in range(5):
yield b"chunk"
yield (b"chunk",AudioChunk(np.ndarray([],np.int16)))
mock_service.generate_audio_stream = mock_stream
mock_service.list_voices.return_value = ["test_voice"]
@ -236,10 +237,10 @@ def mock_tts_service(mock_audio_bytes):
"""Mock TTS service for testing."""
with patch("api.src.routers.openai_compatible.get_tts_service") as mock_get:
service = AsyncMock(spec=TTSService)
service.generate_audio.return_value = (np.zeros(1000), 0.1)
service.generate_audio.return_value = (np.zeros(1000), AudioChunk(np.zeros(1000,np.int16)))
async def mock_stream(*args, **kwargs) -> AsyncGenerator[bytes, None]:
yield mock_audio_bytes
async def mock_stream(*args, **kwargs) -> AsyncGenerator[Tuple[bytes,AudioChunk], None]:
yield (mock_audio_bytes, AudioChunk(np.ndarray([],np.int16)))
service.generate_audio_stream = mock_stream
service.list_voices.return_value = ["test_voice", "voice1", "voice2"]
@ -256,7 +257,7 @@ def test_openai_speech_endpoint(
):
"""Test the OpenAI-compatible speech endpoint with basic MP3 generation"""
# Configure mocks
mock_tts_service.generate_audio.return_value = (np.zeros(1000), 0.1)
mock_tts_service.generate_audio.return_value = (np.zeros(1000), AudioChunk(np.zeros(1000,np.int16)))
mock_convert.return_value = mock_audio_bytes
response = client.post(