fixed a bunch of tests

This commit is contained in:
Fireblade 2025-02-15 09:40:01 -05:00
parent 1a03ac7464
commit 1a6e7abac3
2 changed files with 15 additions and 14 deletions

View file

@ -1,24 +1,24 @@
import pytest import pytest
from unittest.mock import patch, MagicMock from unittest.mock import patch, MagicMock
import requests import requests
import base64
import json
def test_generate_captioned_speech(): def test_generate_captioned_speech():
"""Test the generate_captioned_speech function with mocked responses""" """Test the generate_captioned_speech function with mocked responses"""
# Mock the API responses # Mock the API responses
mock_audio_response = MagicMock() mock_audio_response = MagicMock()
mock_audio_response.status_code = 200 mock_audio_response.status_code = 200
mock_audio_response.content = b"mock audio data"
mock_audio_response.headers = {"X-Timestamps-Path": "test.json"}
mock_timestamps_response = MagicMock() mock_timestamps_response = MagicMock()
mock_timestamps_response.status_code = 200 mock_timestamps_response.status_code = 200
mock_timestamps_response.json.return_value = [ mock_timestamps_response.content = json.dumps({
{"word": "test", "start_time": 0.0, "end_time": 1.0} "audio":base64.b64encode(b"mock audio data").decode("utf-8"),
] "timestamps":[{"word": "test", "start_time": 0.0, "end_time": 1.0}]
})
# Patch both HTTP requests # Patch the HTTP requests
with patch('requests.post', return_value=mock_audio_response), \ with patch('requests.post', return_value=mock_timestamps_response):
patch('requests.get', return_value=mock_timestamps_response):
# Import here to avoid module-level import issues # Import here to avoid module-level import issues
from examples.captioned_speech_example import generate_captioned_speech from examples.captioned_speech_example import generate_captioned_speech

View file

@ -1,9 +1,10 @@
import asyncio import asyncio
import json import json
import os import os
from typing import AsyncGenerator from typing import AsyncGenerator, Tuple
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
from api.src.inference.base import AudioChunk
import numpy as np import numpy as np
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
@ -144,7 +145,7 @@ async def test_stream_audio_chunks_client_disconnect():
async def mock_stream(*args, **kwargs): async def mock_stream(*args, **kwargs):
for i in range(5): for i in range(5):
yield b"chunk" yield (b"chunk",AudioChunk(np.ndarray([],np.int16)))
mock_service.generate_audio_stream = mock_stream mock_service.generate_audio_stream = mock_stream
mock_service.list_voices.return_value = ["test_voice"] mock_service.list_voices.return_value = ["test_voice"]
@ -236,10 +237,10 @@ def mock_tts_service(mock_audio_bytes):
"""Mock TTS service for testing.""" """Mock TTS service for testing."""
with patch("api.src.routers.openai_compatible.get_tts_service") as mock_get: with patch("api.src.routers.openai_compatible.get_tts_service") as mock_get:
service = AsyncMock(spec=TTSService) service = AsyncMock(spec=TTSService)
service.generate_audio.return_value = (np.zeros(1000), 0.1) service.generate_audio.return_value = (np.zeros(1000), AudioChunk(np.zeros(1000,np.int16)))
async def mock_stream(*args, **kwargs) -> AsyncGenerator[bytes, None]: async def mock_stream(*args, **kwargs) -> AsyncGenerator[Tuple[bytes,AudioChunk], None]:
yield mock_audio_bytes yield (mock_audio_bytes, AudioChunk(np.ndarray([],np.int16)))
service.generate_audio_stream = mock_stream service.generate_audio_stream = mock_stream
service.list_voices.return_value = ["test_voice", "voice1", "voice2"] service.list_voices.return_value = ["test_voice", "voice1", "voice2"]
@ -256,7 +257,7 @@ def test_openai_speech_endpoint(
): ):
"""Test the OpenAI-compatible speech endpoint with basic MP3 generation""" """Test the OpenAI-compatible speech endpoint with basic MP3 generation"""
# Configure mocks # Configure mocks
mock_tts_service.generate_audio.return_value = (np.zeros(1000), 0.1) mock_tts_service.generate_audio.return_value = (np.zeros(1000), AudioChunk(np.zeros(1000,np.int16)))
mock_convert.return_value = mock_audio_bytes mock_convert.return_value = mock_audio_bytes
response = client.post( response = client.post(