2025-01-09 07:20:14 -07:00
|
|
|
|
"""Tests for text processing endpoints"""
|
2025-01-09 18:41:44 -07:00
|
|
|
|
|
2025-01-09 07:20:14 -07:00
|
|
|
|
from unittest.mock import Mock, patch
|
2025-01-09 18:41:44 -07:00
|
|
|
|
|
|
|
|
|
import numpy as np
|
2025-01-09 07:20:14 -07:00
|
|
|
|
import pytest
|
|
|
|
|
import pytest_asyncio
|
|
|
|
|
from httpx import AsyncClient
|
|
|
|
|
|
|
|
|
|
from .conftest import MockTTSModel
|
2025-01-09 18:41:44 -07:00
|
|
|
|
from ..src.main import app
|
|
|
|
|
|
2025-01-09 07:20:14 -07:00
|
|
|
|
|
|
|
|
|
@pytest_asyncio.fixture
async def async_client():
    """Yield an httpx ``AsyncClient`` that talks to ``app`` in-process.

    Requests are routed straight into the ASGI application through
    ``ASGITransport``, so no network socket is opened. The client is closed
    when the consuming test finishes.

    Yields:
        AsyncClient: client bound to ``app`` with base URL ``http://test``.
    """
    # Imported locally so this fixture is self-contained. The ``app=``
    # shortcut on AsyncClient was deprecated in httpx 0.27 and removed in
    # 0.28; passing an explicit ASGITransport is the supported equivalent
    # and also works on older httpx releases.
    from httpx import ASGITransport

    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        yield ac
|
|
|
|
|
|
2025-01-09 18:41:44 -07:00
|
|
|
|
|
2025-01-09 07:20:14 -07:00
|
|
|
|
@pytest.mark.asyncio
async def test_phonemize_endpoint(async_client):
    """POST /text/phonemize returns the phonemes and the padded token list.

    ``phonemize`` and ``tokenize`` are patched where the development router
    looks them up, so only the endpoint's own handling is exercised.
    """
    phonemize_patch = patch("api.src.routers.development.phonemize")
    tokenize_patch = patch("api.src.routers.development.tokenize")

    with phonemize_patch as fake_phonemize, tokenize_patch as fake_tokenize:
        # Canned outputs for the two patched helpers.
        fake_phonemize.return_value = "həlˈoʊ"
        fake_tokenize.return_value = [1, 2, 3]

        # Issue the request against the patched router.
        payload = {"text": "hello", "language": "a"}
        response = await async_client.post("/text/phonemize", json=payload)

        # Status first, then the decoded body.
        assert response.status_code == 200
        body = response.json()
        assert body["phonemes"] == "həlˈoʊ"
        # The endpoint is expected to wrap the tokens in start/end markers (0).
        assert body["tokens"] == [0, 1, 2, 3, 0]
|
|
|
|
|
|
2025-01-09 18:41:44 -07:00
|
|
|
|
|
2025-01-09 07:20:14 -07:00
|
|
|
|
@pytest.mark.asyncio
async def test_phonemize_empty_text(async_client):
    """POST /text/phonemize with an empty ``text`` is expected to return a 500.

    The response's ``detail`` must contain ``"error"``.
    """
    payload = {"text": "", "language": "a"}
    response = await async_client.post("/text/phonemize", json=payload)

    assert response.status_code == 500
    detail = response.json()["detail"]
    assert "error" in detail
|
|
|
|
|
|
2025-01-09 18:41:44 -07:00
|
|
|
|
|
2025-01-09 07:20:14 -07:00
|
|
|
|
@pytest.mark.asyncio
async def test_generate_from_phonemes(
    async_client, mock_tts_service, mock_audio_service
):
    """POST /text/generate_from_phonemes returns a WAV attachment.

    ``TTSService`` is patched in the development router so that constructing
    it yields the ``mock_tts_service`` fixture.
    """
    service_patch = patch(
        "api.src.routers.development.TTSService", return_value=mock_tts_service
    )
    payload = {"phonemes": "həlˈoʊ", "voice": "af_bella", "speed": 1.0}

    with service_patch:
        response = await async_client.post(
            "/text/generate_from_phonemes", json=payload
        )

        assert response.status_code == 200

        # Header checks: media type and download filename.
        headers = response.headers
        assert headers["content-type"] == "audio/wav"
        assert headers["content-disposition"] == "attachment; filename=speech.wav"

        assert response.content == b"mock audio data"
|
|
|
|
|
|
2025-01-09 18:41:44 -07:00
|
|
|
|
|
2025-01-09 07:20:14 -07:00
|
|
|
|
@pytest.mark.asyncio
async def test_generate_from_phonemes_invalid_voice(async_client, mock_tts_service):
    """An unknown voice on /text/generate_from_phonemes yields a 400.

    The mocked service's ``_get_voice_path`` is configured to return ``None``
    before the request is sent.
    """
    mock_tts_service._get_voice_path.return_value = None

    service_patch = patch(
        "api.src.routers.development.TTSService", return_value=mock_tts_service
    )
    payload = {"phonemes": "həlˈoʊ", "voice": "invalid_voice", "speed": 1.0}

    with service_patch:
        response = await async_client.post(
            "/text/generate_from_phonemes", json=payload
        )

        assert response.status_code == 400
        detail = response.json()["detail"]
        assert "Voice not found" in detail["message"]
|
|
|
|
|
|
2025-01-09 18:41:44 -07:00
|
|
|
|
|
2025-01-09 07:20:14 -07:00
|
|
|
|
@pytest.mark.asyncio
async def test_generate_from_phonemes_invalid_speed(async_client, monkeypatch):
    """A negative ``speed`` on /text/generate_from_phonemes yields a 422."""
    # Stand-in model whose ``generate_from_tokens`` returns a zero-filled array.
    silence = np.zeros(48000)
    fake_model = Mock()
    fake_model.generate_from_tokens = Mock(return_value=silence)

    # Replace both the cached instance and its accessor on TTSModel.
    monkeypatch.setattr("api.src.services.tts_model.TTSModel._instance", fake_model)
    instance_getter = Mock(return_value=fake_model)
    monkeypatch.setattr(
        "api.src.services.tts_model.TTSModel.get_instance", instance_getter
    )

    payload = {"phonemes": "həlˈoʊ", "voice": "af_bella", "speed": -1.0}
    response = await async_client.post("/text/generate_from_phonemes", json=payload)

    # 422: request validation error.
    assert response.status_code == 422
|
|
|
|
|
|
2025-01-09 18:41:44 -07:00
|
|
|
|
|
2025-01-09 07:20:14 -07:00
|
|
|
|
@pytest.mark.asyncio
async def test_generate_from_phonemes_empty_phonemes(async_client, mock_tts_service):
    """An empty ``phonemes`` string on /text/generate_from_phonemes yields a 400.

    The response's ``detail["error"]`` must contain ``"Invalid request"``.
    """
    service_patch = patch(
        "api.src.routers.development.TTSService", return_value=mock_tts_service
    )
    payload = {"phonemes": "", "voice": "af_bella", "speed": 1.0}

    with service_patch:
        response = await async_client.post(
            "/text/generate_from_phonemes", json=payload
        )

        assert response.status_code == 400
        detail = response.json()["detail"]
        assert "Invalid request" in detail["error"]
|