mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
Refactor TTS API and enhance testing setup with coverage and logging improvements
This commit is contained in:
parent
c11a6ea6ea
commit
4123ab0891
18 changed files with 432 additions and 45 deletions
BIN
.coverage
Normal file
BIN
.coverage
Normal file
Binary file not shown.
12
.coveragerc
Normal file
12
.coveragerc
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
[run]
|
||||||
|
source = api
|
||||||
|
omit = Kokoro-82M/*
|
||||||
|
|
||||||
|
[report]
|
||||||
|
exclude_lines =
|
||||||
|
pragma: no cover
|
||||||
|
def __repr__
|
||||||
|
raise NotImplementedError
|
||||||
|
if __name__ == .__main__.:
|
||||||
|
pass
|
||||||
|
raise ImportError
|
11
.ruff.toml
Normal file
11
.ruff.toml
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
line-length = 88
|
||||||
|
|
||||||
|
[lint]
|
||||||
|
select = ["I"]
|
||||||
|
|
||||||
|
[lint.isort]
|
||||||
|
combine-as-imports = true
|
||||||
|
force-wrap-aliases = true
|
||||||
|
length-sort = true
|
||||||
|
split-on-trailing-comma = true
|
||||||
|
section-order = ["future", "standard-library", "third-party", "first-party", "local-folder"]
|
|
@ -2,15 +2,16 @@
|
||||||
FastAPI OpenAI Compatible API
|
FastAPI OpenAI Compatible API
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import uvicorn
|
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
import uvicorn
|
||||||
|
from loguru import logger
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
from .core.config import settings
|
from .core.config import settings
|
||||||
from .routers.openai_compatible import router as openai_router
|
|
||||||
from .services.tts import TTSModel, TTSService
|
from .services.tts import TTSModel, TTSService
|
||||||
|
from .routers.openai_compatible import router as openai_router
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
|
|
|
@ -1,10 +1,9 @@
|
||||||
from fastapi import APIRouter, HTTPException, Response, Depends
|
from loguru import logger
|
||||||
import logging
|
from fastapi import Depends, Response, APIRouter, HTTPException
|
||||||
from ..structures.schemas import OpenAISpeechRequest
|
|
||||||
from ..services.tts import TTSService
|
from ..services.tts import TTSService
|
||||||
from ..services.audio import AudioService
|
from ..services.audio import AudioService
|
||||||
|
from ..structures.schemas import OpenAISpeechRequest
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
tags=["OpenAI Compatible TTS"],
|
tags=["OpenAI Compatible TTS"],
|
||||||
|
|
|
@ -1,3 +1,3 @@
|
||||||
from .tts import TTSService, TTSModel
|
from .tts import TTSModel, TTSService
|
||||||
|
|
||||||
__all__ = ["TTSService", "TTSModel"]
|
__all__ = ["TTSService", "TTSModel"]
|
||||||
|
|
|
@ -1,12 +1,11 @@
|
||||||
"""Audio conversion service"""
|
"""Audio conversion service"""
|
||||||
|
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
import numpy as np
|
|
||||||
import scipy.io.wavfile as wavfile
|
|
||||||
import soundfile as sf
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
|
import scipy.io.wavfile as wavfile
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
class AudioService:
|
class AudioService:
|
||||||
|
|
|
@ -1,19 +1,20 @@
|
||||||
import os
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
import io
|
import io
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import scipy.io.wavfile as wavfile
|
|
||||||
from models import build_model
|
|
||||||
from kokoro import generate, phonemize, tokenize, normalize_text
|
|
||||||
from ..core.config import settings
|
|
||||||
import re
|
|
||||||
import logging
|
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
import scipy.io.wavfile as wavfile
|
||||||
|
from kokoro import generate, tokenize, phonemize, normalize_text
|
||||||
|
from loguru import logger
|
||||||
|
from models import build_model
|
||||||
|
|
||||||
|
from ..core.config import settings
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
enc = tiktoken.get_encoding("cl100k_base")
|
enc = tiktoken.get_encoding("cl100k_base")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from typing import Literal
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
from pydantic import Field, BaseModel
|
||||||
|
|
||||||
|
|
||||||
class TTSStatus(str, Enum):
|
class TTSStatus(str, Enum):
|
||||||
|
@ -13,7 +14,7 @@ class TTSStatus(str, Enum):
|
||||||
|
|
||||||
# OpenAI-compatible schemas
|
# OpenAI-compatible schemas
|
||||||
class OpenAISpeechRequest(BaseModel):
|
class OpenAISpeechRequest(BaseModel):
|
||||||
model: Literal["tts-1", "tts-1-hd"] = "tts-1"
|
model: Literal["tts-1", "tts-1-hd", "kokoro"] = "kokoro"
|
||||||
input: str = Field(..., description="The text to generate audio for")
|
input: str = Field(..., description="The text to generate audio for")
|
||||||
voice: Literal[
|
voice: Literal[
|
||||||
"am_adam",
|
"am_adam",
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import pytest
|
|
||||||
from unittest.mock import Mock, patch
|
|
||||||
import sys
|
import sys
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
# Mock torch and other ML modules before they're imported
|
# Mock torch and other ML modules before they're imported
|
||||||
sys.modules["torch"] = Mock()
|
sys.modules["torch"] = Mock()
|
||||||
|
|
67
api/tests/test_audio_service.py
Normal file
67
api/tests/test_audio_service.py
Normal file
|
@ -0,0 +1,67 @@
|
||||||
|
"""Tests for AudioService"""
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from api.src.services.audio import AudioService
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_audio():
|
||||||
|
"""Generate a simple sine wave for testing"""
|
||||||
|
sample_rate = 24000
|
||||||
|
duration = 0.1 # 100ms
|
||||||
|
t = np.linspace(0, duration, int(sample_rate * duration))
|
||||||
|
frequency = 440 # A4 note
|
||||||
|
return np.sin(2 * np.pi * frequency * t).astype(np.float32), sample_rate
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_to_wav(sample_audio):
|
||||||
|
"""Test converting to WAV format"""
|
||||||
|
audio_data, sample_rate = sample_audio
|
||||||
|
result = AudioService.convert_audio(audio_data, sample_rate, "wav")
|
||||||
|
assert isinstance(result, bytes)
|
||||||
|
assert len(result) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_to_mp3(sample_audio):
|
||||||
|
"""Test converting to MP3 format"""
|
||||||
|
audio_data, sample_rate = sample_audio
|
||||||
|
result = AudioService.convert_audio(audio_data, sample_rate, "mp3")
|
||||||
|
assert isinstance(result, bytes)
|
||||||
|
assert len(result) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_to_opus(sample_audio):
|
||||||
|
"""Test converting to Opus format"""
|
||||||
|
audio_data, sample_rate = sample_audio
|
||||||
|
result = AudioService.convert_audio(audio_data, sample_rate, "opus")
|
||||||
|
assert isinstance(result, bytes)
|
||||||
|
assert len(result) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_to_flac(sample_audio):
|
||||||
|
"""Test converting to FLAC format"""
|
||||||
|
audio_data, sample_rate = sample_audio
|
||||||
|
result = AudioService.convert_audio(audio_data, sample_rate, "flac")
|
||||||
|
assert isinstance(result, bytes)
|
||||||
|
assert len(result) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_to_aac_raises_error(sample_audio):
|
||||||
|
"""Test that converting to AAC raises an error"""
|
||||||
|
audio_data, sample_rate = sample_audio
|
||||||
|
with pytest.raises(ValueError, match="AAC format is not currently supported"):
|
||||||
|
AudioService.convert_audio(audio_data, sample_rate, "aac")
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_to_pcm_raises_error(sample_audio):
|
||||||
|
"""Test that converting to PCM raises an error"""
|
||||||
|
audio_data, sample_rate = sample_audio
|
||||||
|
with pytest.raises(ValueError, match="PCM format is not currently supported"):
|
||||||
|
AudioService.convert_audio(audio_data, sample_rate, "pcm")
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_to_invalid_format_raises_error(sample_audio):
|
||||||
|
"""Test that converting to an invalid format raises an error"""
|
||||||
|
audio_data, sample_rate = sample_audio
|
||||||
|
with pytest.raises(ValueError, match="Format invalid not supported"):
|
||||||
|
AudioService.convert_audio(audio_data, sample_rate, "invalid")
|
|
@ -1,6 +1,8 @@
|
||||||
from fastapi.testclient import TestClient
|
|
||||||
import pytest
|
|
||||||
from unittest.mock import Mock
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from ..src.main import app
|
from ..src.main import app
|
||||||
|
|
||||||
# Create test client
|
# Create test client
|
||||||
|
@ -50,7 +52,7 @@ def test_health_check():
|
||||||
def test_openai_speech_endpoint(mock_tts_service, mock_audio_service):
|
def test_openai_speech_endpoint(mock_tts_service, mock_audio_service):
|
||||||
"""Test the OpenAI-compatible speech endpoint"""
|
"""Test the OpenAI-compatible speech endpoint"""
|
||||||
test_request = {
|
test_request = {
|
||||||
"model": "tts-1",
|
"model": "kokoro",
|
||||||
"input": "Hello world",
|
"input": "Hello world",
|
||||||
"voice": "bm_lewis",
|
"voice": "bm_lewis",
|
||||||
"response_format": "wav",
|
"response_format": "wav",
|
||||||
|
@ -69,7 +71,7 @@ def test_openai_speech_endpoint(mock_tts_service, mock_audio_service):
|
||||||
def test_openai_speech_invalid_voice(mock_tts_service):
|
def test_openai_speech_invalid_voice(mock_tts_service):
|
||||||
"""Test the OpenAI-compatible speech endpoint with invalid voice"""
|
"""Test the OpenAI-compatible speech endpoint with invalid voice"""
|
||||||
test_request = {
|
test_request = {
|
||||||
"model": "tts-1",
|
"model": "kokoro",
|
||||||
"input": "Hello world",
|
"input": "Hello world",
|
||||||
"voice": "invalid_voice",
|
"voice": "invalid_voice",
|
||||||
"response_format": "wav",
|
"response_format": "wav",
|
||||||
|
@ -82,7 +84,7 @@ def test_openai_speech_invalid_voice(mock_tts_service):
|
||||||
def test_openai_speech_invalid_speed(mock_tts_service):
|
def test_openai_speech_invalid_speed(mock_tts_service):
|
||||||
"""Test the OpenAI-compatible speech endpoint with invalid speed"""
|
"""Test the OpenAI-compatible speech endpoint with invalid speed"""
|
||||||
test_request = {
|
test_request = {
|
||||||
"model": "tts-1",
|
"model": "kokoro",
|
||||||
"input": "Hello world",
|
"input": "Hello world",
|
||||||
"voice": "af",
|
"voice": "af",
|
||||||
"response_format": "wav",
|
"response_format": "wav",
|
||||||
|
@ -96,7 +98,7 @@ def test_openai_speech_generation_error(mock_tts_service):
|
||||||
"""Test error handling in speech generation"""
|
"""Test error handling in speech generation"""
|
||||||
mock_tts_service._generate_audio.side_effect = Exception("Generation failed")
|
mock_tts_service._generate_audio.side_effect = Exception("Generation failed")
|
||||||
test_request = {
|
test_request = {
|
||||||
"model": "tts-1",
|
"model": "kokoro",
|
||||||
"input": "Hello world",
|
"input": "Hello world",
|
||||||
"voice": "af",
|
"voice": "af",
|
||||||
"response_format": "wav",
|
"response_format": "wav",
|
||||||
|
|
45
api/tests/test_main.py
Normal file
45
api/tests/test_main.py
Normal file
|
@ -0,0 +1,45 @@
|
||||||
|
"""Tests for main FastAPI application"""
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from api.src.main import app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client():
|
||||||
|
"""Create a test client"""
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def test_health_check(client):
|
||||||
|
"""Test health check endpoint"""
|
||||||
|
response = client.get("/health")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"status": "healthy"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_test_endpoint(client):
|
||||||
|
"""Test the test endpoint"""
|
||||||
|
response = client.get("/v1/test")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json() == {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_cors_headers(client):
|
||||||
|
"""Test CORS headers are present"""
|
||||||
|
response = client.get(
|
||||||
|
"/health",
|
||||||
|
headers={"Origin": "http://testserver"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers["access-control-allow-origin"] == "*"
|
||||||
|
|
||||||
|
|
||||||
|
def test_openapi_schema(client):
|
||||||
|
"""Test OpenAPI schema is accessible"""
|
||||||
|
response = client.get("/openapi.json")
|
||||||
|
assert response.status_code == 200
|
||||||
|
schema = response.json()
|
||||||
|
assert schema["info"]["title"] == app.title
|
||||||
|
assert schema["info"]["version"] == app.version
|
244
api/tests/test_tts_service.py
Normal file
244
api/tests/test_tts_service.py
Normal file
|
@ -0,0 +1,244 @@
|
||||||
|
"""Tests for TTSService"""
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock, call
|
||||||
|
from api.src.services.tts import TTSService, TTSModel
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tts_service():
|
||||||
|
"""Create a TTSService instance for testing"""
|
||||||
|
return TTSService(start_worker=False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_audio():
|
||||||
|
"""Generate a simple sine wave for testing"""
|
||||||
|
sample_rate = 24000
|
||||||
|
duration = 0.1 # 100ms
|
||||||
|
t = np.linspace(0, duration, int(sample_rate * duration))
|
||||||
|
frequency = 440 # A4 note
|
||||||
|
return np.sin(2 * np.pi * frequency * t).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_text(tts_service):
|
||||||
|
"""Test text splitting into sentences"""
|
||||||
|
text = "First sentence. Second sentence! Third sentence?"
|
||||||
|
sentences = tts_service._split_text(text)
|
||||||
|
assert len(sentences) == 3
|
||||||
|
assert sentences[0] == "First sentence."
|
||||||
|
assert sentences[1] == "Second sentence!"
|
||||||
|
assert sentences[2] == "Third sentence?"
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_text_empty(tts_service):
|
||||||
|
"""Test splitting empty text"""
|
||||||
|
assert tts_service._split_text("") == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_text_single_sentence(tts_service):
|
||||||
|
"""Test splitting single sentence"""
|
||||||
|
text = "Just one sentence."
|
||||||
|
assert tts_service._split_text(text) == ["Just one sentence."]
|
||||||
|
|
||||||
|
|
||||||
|
def test_audio_to_bytes(tts_service, sample_audio):
|
||||||
|
"""Test converting audio tensor to bytes"""
|
||||||
|
audio_bytes = tts_service._audio_to_bytes(sample_audio)
|
||||||
|
assert isinstance(audio_bytes, bytes)
|
||||||
|
assert len(audio_bytes) > 0
|
||||||
|
|
||||||
|
|
||||||
|
@patch('os.listdir')
|
||||||
|
@patch('os.path.join')
|
||||||
|
def test_list_voices(mock_join, mock_listdir, tts_service):
|
||||||
|
"""Test listing available voices"""
|
||||||
|
mock_listdir.return_value = ['voice1.pt', 'voice2.pt', 'not_a_voice.txt']
|
||||||
|
mock_join.return_value = '/fake/path'
|
||||||
|
|
||||||
|
voices = tts_service.list_voices()
|
||||||
|
assert len(voices) == 2
|
||||||
|
assert 'voice1' in voices
|
||||||
|
assert 'voice2' in voices
|
||||||
|
assert 'not_a_voice' not in voices
|
||||||
|
|
||||||
|
|
||||||
|
@patch('api.src.services.tts.TTSModel.get_instance')
|
||||||
|
@patch('api.src.services.tts.TTSModel.get_voicepack')
|
||||||
|
@patch('api.src.services.tts.normalize_text')
|
||||||
|
@patch('api.src.services.tts.phonemize')
|
||||||
|
@patch('api.src.services.tts.tokenize')
|
||||||
|
@patch('api.src.services.tts.generate')
|
||||||
|
def test_generate_audio_empty_text(mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_voicepack, mock_instance, tts_service):
|
||||||
|
"""Test generating audio with empty text"""
|
||||||
|
mock_normalize.return_value = ""
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Text is empty after preprocessing"):
|
||||||
|
tts_service._generate_audio("", "af", 1.0)
|
||||||
|
|
||||||
|
|
||||||
|
@patch('api.src.services.tts.TTSModel.get_instance')
|
||||||
|
@patch('api.src.services.tts.TTSModel.get_voicepack')
|
||||||
|
@patch('api.src.services.tts.normalize_text')
|
||||||
|
@patch('api.src.services.tts.phonemize')
|
||||||
|
@patch('api.src.services.tts.tokenize')
|
||||||
|
@patch('api.src.services.tts.generate')
|
||||||
|
def test_generate_audio_no_chunks(mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_voicepack, mock_instance, tts_service):
|
||||||
|
"""Test generating audio with no successful chunks"""
|
||||||
|
mock_normalize.return_value = "Test text"
|
||||||
|
mock_phonemize.return_value = "Test text"
|
||||||
|
mock_tokenize.return_value = ["test", "text"]
|
||||||
|
mock_generate.return_value = (None, None)
|
||||||
|
mock_instance.return_value = (MagicMock(), "cpu")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
|
||||||
|
tts_service._generate_audio("Test text", "af", 1.0)
|
||||||
|
|
||||||
|
|
||||||
|
@patch('api.src.services.tts.TTSModel.get_instance')
|
||||||
|
@patch('api.src.services.tts.TTSModel.get_voicepack')
|
||||||
|
@patch('api.src.services.tts.normalize_text')
|
||||||
|
@patch('api.src.services.tts.phonemize')
|
||||||
|
@patch('api.src.services.tts.tokenize')
|
||||||
|
@patch('api.src.services.tts.generate')
|
||||||
|
def test_generate_audio_success(mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_voicepack, mock_instance, tts_service, sample_audio):
|
||||||
|
"""Test successful audio generation"""
|
||||||
|
mock_normalize.return_value = "Test text"
|
||||||
|
mock_phonemize.return_value = "Test text"
|
||||||
|
mock_tokenize.return_value = ["test", "text"]
|
||||||
|
mock_generate.return_value = (sample_audio, None)
|
||||||
|
mock_instance.return_value = (MagicMock(), "cpu")
|
||||||
|
|
||||||
|
audio, processing_time = tts_service._generate_audio("Test text", "af", 1.0)
|
||||||
|
assert isinstance(audio, np.ndarray)
|
||||||
|
assert isinstance(processing_time, float)
|
||||||
|
assert len(audio) > 0
|
||||||
|
|
||||||
|
|
||||||
|
@patch('api.src.services.tts.torch.cuda.is_available')
|
||||||
|
@patch('api.src.services.tts.build_model')
|
||||||
|
def test_model_initialization_cuda(mock_build_model, mock_cuda_available):
|
||||||
|
"""Test model initialization with CUDA"""
|
||||||
|
mock_cuda_available.return_value = True
|
||||||
|
mock_model = MagicMock()
|
||||||
|
mock_build_model.return_value = mock_model
|
||||||
|
|
||||||
|
TTSModel._instance = None # Reset singleton
|
||||||
|
model, device = TTSModel.get_instance()
|
||||||
|
|
||||||
|
assert device == "cuda"
|
||||||
|
assert model == mock_model
|
||||||
|
mock_build_model.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@patch('api.src.services.tts.torch.cuda.is_available')
|
||||||
|
@patch('api.src.services.tts.build_model')
|
||||||
|
def test_model_initialization_cpu(mock_build_model, mock_cuda_available):
|
||||||
|
"""Test model initialization with CPU"""
|
||||||
|
mock_cuda_available.return_value = False
|
||||||
|
mock_model = MagicMock()
|
||||||
|
mock_build_model.return_value = mock_model
|
||||||
|
|
||||||
|
TTSModel._instance = None # Reset singleton
|
||||||
|
model, device = TTSModel.get_instance()
|
||||||
|
|
||||||
|
assert device == "cpu"
|
||||||
|
assert model == mock_model
|
||||||
|
mock_build_model.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@patch('api.src.services.tts.torch.load')
|
||||||
|
@patch('os.path.join')
|
||||||
|
def test_voicepack_loading_error(mock_join, mock_torch_load):
|
||||||
|
"""Test voicepack loading error handling"""
|
||||||
|
mock_join.side_effect = lambda *args: '/'.join(args)
|
||||||
|
mock_torch_load.side_effect = [Exception("Failed to load voice"), MagicMock()]
|
||||||
|
|
||||||
|
TTSModel._instance = (MagicMock(), "cpu") # Mock instance
|
||||||
|
TTSModel._voicepacks = {} # Reset voicepacks
|
||||||
|
|
||||||
|
# Should fall back to 'af' voice
|
||||||
|
voicepack = TTSModel.get_voicepack("nonexistent_voice")
|
||||||
|
assert mock_torch_load.call_count == 2 # Tried original voice then fallback
|
||||||
|
assert isinstance(voicepack, MagicMock) # Successfully got fallback voice
|
||||||
|
|
||||||
|
|
||||||
|
@patch('api.src.services.tts.torch.load')
|
||||||
|
@patch('os.path.join')
|
||||||
|
def test_voicepack_loading_error_af(mock_join, mock_torch_load):
|
||||||
|
"""Test voicepack loading error for 'af' voice"""
|
||||||
|
mock_join.side_effect = lambda *args: '/'.join(args)
|
||||||
|
mock_torch_load.side_effect = Exception("Failed to load voice")
|
||||||
|
|
||||||
|
TTSModel._instance = (MagicMock(), "cpu") # Mock instance
|
||||||
|
TTSModel._voicepacks = {} # Reset voicepacks
|
||||||
|
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
TTSModel.get_voicepack("af")
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_audio(tts_service, sample_audio, tmp_path):
|
||||||
|
"""Test saving audio to file"""
|
||||||
|
output_path = os.path.join(tmp_path, "test_output", "audio.wav")
|
||||||
|
tts_service._save_audio(sample_audio, output_path)
|
||||||
|
|
||||||
|
assert os.path.exists(output_path)
|
||||||
|
assert os.path.getsize(output_path) > 0
|
||||||
|
|
||||||
|
|
||||||
|
@patch('api.src.services.tts.TTSModel.get_instance')
|
||||||
|
@patch('api.src.services.tts.TTSModel.get_voicepack')
|
||||||
|
@patch('api.src.services.tts.normalize_text')
|
||||||
|
@patch('api.src.services.tts.generate')
|
||||||
|
def test_generate_audio_without_stitching(mock_generate, mock_normalize, mock_voicepack, mock_instance, tts_service, sample_audio):
|
||||||
|
"""Test generating audio without text stitching"""
|
||||||
|
mock_normalize.return_value = "Test text"
|
||||||
|
mock_generate.return_value = (sample_audio, None)
|
||||||
|
mock_instance.return_value = (MagicMock(), "cpu")
|
||||||
|
|
||||||
|
audio, processing_time = tts_service._generate_audio("Test text", "af", 1.0, stitch_long_output=False)
|
||||||
|
assert isinstance(audio, np.ndarray)
|
||||||
|
assert isinstance(processing_time, float)
|
||||||
|
assert len(audio) > 0
|
||||||
|
mock_generate.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@patch('os.listdir')
|
||||||
|
def test_list_voices_error(mock_listdir, tts_service):
|
||||||
|
"""Test error handling in list_voices"""
|
||||||
|
mock_listdir.side_effect = Exception("Failed to list directory")
|
||||||
|
|
||||||
|
voices = tts_service.list_voices()
|
||||||
|
assert voices == []
|
||||||
|
|
||||||
|
|
||||||
|
@patch('api.src.services.tts.TTSModel.get_instance')
|
||||||
|
@patch('api.src.services.tts.TTSModel.get_voicepack')
|
||||||
|
@patch('api.src.services.tts.normalize_text')
|
||||||
|
@patch('api.src.services.tts.phonemize')
|
||||||
|
@patch('api.src.services.tts.tokenize')
|
||||||
|
@patch('api.src.services.tts.generate')
|
||||||
|
def test_generate_audio_phonemize_error(mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_voicepack, mock_instance, tts_service):
|
||||||
|
"""Test handling phonemization error"""
|
||||||
|
mock_normalize.return_value = "Test text"
|
||||||
|
mock_phonemize.side_effect = Exception("Phonemization failed")
|
||||||
|
mock_instance.return_value = (MagicMock(), "cpu")
|
||||||
|
mock_generate.return_value = (None, None)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
|
||||||
|
tts_service._generate_audio("Test text", "af", 1.0)
|
||||||
|
|
||||||
|
|
||||||
|
@patch('api.src.services.tts.TTSModel.get_instance')
|
||||||
|
@patch('api.src.services.tts.TTSModel.get_voicepack')
|
||||||
|
@patch('api.src.services.tts.normalize_text')
|
||||||
|
@patch('api.src.services.tts.generate')
|
||||||
|
def test_generate_audio_error(mock_generate, mock_normalize, mock_voicepack, mock_instance, tts_service):
|
||||||
|
"""Test handling generation error"""
|
||||||
|
mock_normalize.return_value = "Test text"
|
||||||
|
mock_generate.side_effect = Exception("Generation failed")
|
||||||
|
mock_instance.return_value = (MagicMock(), "cpu")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
|
||||||
|
tts_service._generate_audio("Test text", "af", 1.0)
|
|
@ -1,16 +1,17 @@
|
||||||
import os
|
import os
|
||||||
import time
|
|
||||||
import json
|
import json
|
||||||
import scipy.io.wavfile as wavfile
|
import time
|
||||||
import requests
|
|
||||||
import pandas as pd
|
|
||||||
import seaborn as sns
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import tiktoken
|
|
||||||
import psutil
|
|
||||||
import subprocess
|
import subprocess
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import psutil
|
||||||
|
import seaborn as sns
|
||||||
|
import requests
|
||||||
|
import tiktoken
|
||||||
|
import scipy.io.wavfile as wavfile
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
enc = tiktoken.get_encoding("cl100k_base")
|
enc = tiktoken.get_encoding("cl100k_base")
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
@ -18,6 +19,7 @@ output_dir = Path(__file__).parent / "output"
|
||||||
output_dir.mkdir(exist_ok=True)
|
output_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_voice(voice: str):
|
def test_voice(voice: str):
|
||||||
speech_file = output_dir / f"speech_{voice}.wav"
|
speech_file = output_dir / f"speech_{voice}.wav"
|
||||||
print(f"\nTesting voice: {voice}")
|
print(f"\nTesting voice: {voice}")
|
||||||
|
@ -25,7 +27,7 @@ def test_voice(voice: str):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = client.audio.speech.create(
|
response = client.audio.speech.create(
|
||||||
model="tts-1", voice=voice, input=SAMPLE_TEXT, response_format="wav"
|
model="kokoro", voice=voice, input=SAMPLE_TEXT, response_format="wav"
|
||||||
)
|
)
|
||||||
|
|
||||||
print("Got response, saving to file...")
|
print("Got response, saving to file...")
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
|
|
||||||
# Configure OpenAI client to use our local endpoint
|
# Configure OpenAI client to use our local endpoint
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
[pytest]
|
[pytest]
|
||||||
testpaths = api/tests
|
testpaths = api/tests
|
||||||
python_files = test_*.py
|
python_files = test_*.py
|
||||||
addopts = -v --tb=short
|
addopts = -v --tb=short --cov=api --cov-report=term-missing --cov-config=.coveragerc
|
||||||
pythonpath = .
|
pythonpath = .
|
||||||
|
|
Loading…
Add table
Reference in a new issue