Refactor TTS API and enhance testing setup with coverage and logging improvements

remsky committed 2024-12-31 02:55:51 -07:00
parent c11a6ea6ea · commit 4123ab0891
18 changed files with 432 additions and 45 deletions

.coverage (new binary file, contents not shown)

.coveragerc (new file, 12 lines)

@@ -0,0 +1,12 @@
[run]
source = api
omit = Kokoro-82M/*

[report]
exclude_lines =
    pragma: no cover
    def __repr__
    raise NotImplementedError
    if __name__ == .__main__.:
    pass
    raise ImportError

.ruff.toml (new file, 11 lines)

@@ -0,0 +1,11 @@
line-length = 88

[lint]
select = ["I"]

[lint.isort]
combine-as-imports = true
force-wrap-aliases = true
length-sort = true
split-on-trailing-comma = true
section-order = ["future", "standard-library", "third-party", "first-party", "local-folder"]
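
These isort settings explain the import reshuffling in every hunk below: length-sort orders each import section shortest-first, and section-order keeps standard-library imports ahead of third-party and local ones. A rough illustration (not code from the repo), mirroring the api/src/services/tts.py hunk later in this commit:

# Illustration only: the ordering produced by the isort rules above.
import io          # standard-library section, length-sorted ascending
import os
import re
import time
import threading

import numpy as np  # third-party section follows the standard library
import torch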

api/src/main.py

@@ -2,15 +2,16 @@
FastAPI OpenAI Compatible API
"""
-import uvicorn
from contextlib import asynccontextmanager
+import uvicorn
+from loguru import logger
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
-from loguru import logger
from .core.config import settings
-from .routers.openai_compatible import router as openai_router
from .services.tts import TTSModel, TTSService
+from .routers.openai_compatible import router as openai_router
@asynccontextmanager

api/src/routers/openai_compatible.py

@@ -1,10 +1,9 @@
-from fastapi import APIRouter, HTTPException, Response, Depends
-import logging
-from ..structures.schemas import OpenAISpeechRequest
+from loguru import logger
+from fastapi import Depends, Response, APIRouter, HTTPException
from ..services.tts import TTSService
from ..services.audio import AudioService
+from ..structures.schemas import OpenAISpeechRequest
-logger = logging.getLogger(__name__)
router = APIRouter(
    tags=["OpenAI Compatible TTS"],

api/src/services/__init__.py

@@ -1,3 +1,3 @@
-from .tts import TTSService, TTSModel
+from .tts import TTSModel, TTSService
__all__ = ["TTSService", "TTSModel"]

api/src/services/audio.py

@@ -1,12 +1,11 @@
"""Audio conversion service"""
from io import BytesIO
-import numpy as np
-import scipy.io.wavfile as wavfile
-import soundfile as sf
-import logging
-logger = logging.getLogger(__name__)
+import numpy as np
+import soundfile as sf
+import scipy.io.wavfile as wavfile
+from loguru import logger
class AudioService:

api/src/services/tts.py

@@ -1,19 +1,20 @@
-import os
-import threading
-import time
import io
+import os
+import re
+import time
+import threading
from typing import List, Tuple
import numpy as np
import torch
-import scipy.io.wavfile as wavfile
-from models import build_model
-from kokoro import generate, phonemize, tokenize, normalize_text
-from ..core.config import settings
-import re
-import logging
import tiktoken
+import scipy.io.wavfile as wavfile
+from kokoro import generate, tokenize, phonemize, normalize_text
+from loguru import logger
+from models import build_model
+from ..core.config import settings
-logger = logging.getLogger(__name__)
enc = tiktoken.get_encoding("cl100k_base")

api/src/structures/schemas.py

@@ -1,6 +1,7 @@
-from pydantic import BaseModel, Field
-from typing import Literal
from enum import Enum
+from typing import Literal
+from pydantic import Field, BaseModel
class TTSStatus(str, Enum):
@@ -13,7 +14,7 @@ class TTSStatus(str, Enum):
# OpenAI-compatible schemas
class OpenAISpeechRequest(BaseModel):
-    model: Literal["tts-1", "tts-1-hd"] = "tts-1"
+    model: Literal["tts-1", "tts-1-hd", "kokoro"] = "kokoro"
    input: str = Field(..., description="The text to generate audio for")
    voice: Literal[
        "am_adam",

api/tests/conftest.py

@@ -1,6 +1,7 @@
-import pytest
-from unittest.mock import Mock, patch
import sys
+from unittest.mock import Mock, patch
+import pytest
# Mock torch and other ML modules before they're imported
sys.modules["torch"] = Mock()
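
The conftest pattern above stubs heavy ML dependencies in sys.modules before any api.src module is imported, so unit tests never load real model weights. A minimal sketch of the same idea (the exact module list is an assumption; the hunk only shows "torch"):

import sys
from unittest.mock import Mock

# Stub heavy imports before the application modules are imported.
for heavy_module in ("torch", "tiktoken"):  # assumed list for illustration
    sys.modules.setdefault(heavy_module, Mock())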


@@ -0,0 +1,67 @@
"""Tests for AudioService"""
import numpy as np
import pytest

from api.src.services.audio import AudioService


@pytest.fixture
def sample_audio():
    """Generate a simple sine wave for testing"""
    sample_rate = 24000
    duration = 0.1  # 100ms
    t = np.linspace(0, duration, int(sample_rate * duration))
    frequency = 440  # A4 note
    return np.sin(2 * np.pi * frequency * t).astype(np.float32), sample_rate


def test_convert_to_wav(sample_audio):
    """Test converting to WAV format"""
    audio_data, sample_rate = sample_audio
    result = AudioService.convert_audio(audio_data, sample_rate, "wav")
    assert isinstance(result, bytes)
    assert len(result) > 0


def test_convert_to_mp3(sample_audio):
    """Test converting to MP3 format"""
    audio_data, sample_rate = sample_audio
    result = AudioService.convert_audio(audio_data, sample_rate, "mp3")
    assert isinstance(result, bytes)
    assert len(result) > 0


def test_convert_to_opus(sample_audio):
    """Test converting to Opus format"""
    audio_data, sample_rate = sample_audio
    result = AudioService.convert_audio(audio_data, sample_rate, "opus")
    assert isinstance(result, bytes)
    assert len(result) > 0


def test_convert_to_flac(sample_audio):
    """Test converting to FLAC format"""
    audio_data, sample_rate = sample_audio
    result = AudioService.convert_audio(audio_data, sample_rate, "flac")
    assert isinstance(result, bytes)
    assert len(result) > 0


def test_convert_to_aac_raises_error(sample_audio):
    """Test that converting to AAC raises an error"""
    audio_data, sample_rate = sample_audio
    with pytest.raises(ValueError, match="AAC format is not currently supported"):
        AudioService.convert_audio(audio_data, sample_rate, "aac")


def test_convert_to_pcm_raises_error(sample_audio):
    """Test that converting to PCM raises an error"""
    audio_data, sample_rate = sample_audio
    with pytest.raises(ValueError, match="PCM format is not currently supported"):
        AudioService.convert_audio(audio_data, sample_rate, "pcm")


def test_convert_to_invalid_format_raises_error(sample_audio):
    """Test that converting to an invalid format raises an error"""
    audio_data, sample_rate = sample_audio
    with pytest.raises(ValueError, match="Format invalid not supported"):
        AudioService.convert_audio(audio_data, sample_rate, "invalid")
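
Taken together, these tests pin down the AudioService surface: convert_audio(audio, sample_rate, format) returns bytes for wav/mp3/opus/flac and raises ValueError for aac, pcm, and unknown formats. A minimal sketch of that interface, assuming soundfile does the container writing (the real api/src/services/audio.py may differ):

from io import BytesIO

import numpy as np
import soundfile as sf


class AudioServiceSketch:
    """Hypothetical stand-in for AudioService, shaped only by the tests above."""

    @staticmethod
    def convert_audio(audio: np.ndarray, sample_rate: int, output_format: str) -> bytes:
        if output_format in ("aac", "pcm"):
            raise ValueError(f"{output_format.upper()} format is not currently supported")
        if output_format not in ("wav", "mp3", "opus", "flac"):
            raise ValueError(f"Format {output_format} not supported")
        buffer = BytesIO()
        # Assumption: libsndfile can write these containers (opus may need
        # format="OGG" with subtype="OPUS"; mp3 needs libsndfile >= 1.1).
        container = {"wav": "WAV", "flac": "FLAC", "opus": "OGG", "mp3": "MP3"}[output_format]
        sf.write(buffer, audio, sample_rate, format=container)
        return buffer.getvalue()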


@@ -1,6 +1,8 @@
-from fastapi.testclient import TestClient
-import pytest
from unittest.mock import Mock
+import pytest
+from fastapi.testclient import TestClient
from ..src.main import app
# Create test client
@@ -50,7 +52,7 @@ def test_health_check():
def test_openai_speech_endpoint(mock_tts_service, mock_audio_service):
    """Test the OpenAI-compatible speech endpoint"""
    test_request = {
-        "model": "tts-1",
+        "model": "kokoro",
        "input": "Hello world",
        "voice": "bm_lewis",
        "response_format": "wav",
@@ -69,7 +71,7 @@ def test_openai_speech_endpoint(mock_tts_service, mock_audio_service):
def test_openai_speech_invalid_voice(mock_tts_service):
    """Test the OpenAI-compatible speech endpoint with invalid voice"""
    test_request = {
-        "model": "tts-1",
+        "model": "kokoro",
        "input": "Hello world",
        "voice": "invalid_voice",
        "response_format": "wav",
@@ -82,7 +84,7 @@ def test_openai_speech_invalid_voice(mock_tts_service):
def test_openai_speech_invalid_speed(mock_tts_service):
    """Test the OpenAI-compatible speech endpoint with invalid speed"""
    test_request = {
-        "model": "tts-1",
+        "model": "kokoro",
        "input": "Hello world",
        "voice": "af",
        "response_format": "wav",
@@ -96,7 +98,7 @@ def test_openai_speech_generation_error(mock_tts_service):
    """Test error handling in speech generation"""
    mock_tts_service._generate_audio.side_effect = Exception("Generation failed")
    test_request = {
-        "model": "tts-1",
+        "model": "kokoro",
        "input": "Hello world",
        "voice": "af",
        "response_format": "wav",

api/tests/test_main.py (new file, 45 lines)

@@ -0,0 +1,45 @@
"""Tests for main FastAPI application"""
import pytest
from unittest.mock import patch, MagicMock
from fastapi.testclient import TestClient

from api.src.main import app


@pytest.fixture
def client():
    """Create a test client"""
    return TestClient(app)


def test_health_check(client):
    """Test health check endpoint"""
    response = client.get("/health")
    assert response.status_code == 200
    assert response.json() == {"status": "healthy"}


def test_test_endpoint(client):
    """Test the test endpoint"""
    response = client.get("/v1/test")
    assert response.status_code == 200
    assert response.json() == {"status": "ok"}


def test_cors_headers(client):
    """Test CORS headers are present"""
    response = client.get(
        "/health",
        headers={"Origin": "http://testserver"},
    )
    assert response.status_code == 200
    assert response.headers["access-control-allow-origin"] == "*"


def test_openapi_schema(client):
    """Test OpenAPI schema is accessible"""
    response = client.get("/openapi.json")
    assert response.status_code == 200
    schema = response.json()
    assert schema["info"]["title"] == app.title
    assert schema["info"]["version"] == app.version


@@ -0,0 +1,244 @@
"""Tests for TTSService"""
import os
import numpy as np
import pytest
from unittest.mock import patch, MagicMock, call

from api.src.services.tts import TTSService, TTSModel


@pytest.fixture
def tts_service():
    """Create a TTSService instance for testing"""
    return TTSService(start_worker=False)


@pytest.fixture
def sample_audio():
    """Generate a simple sine wave for testing"""
    sample_rate = 24000
    duration = 0.1  # 100ms
    t = np.linspace(0, duration, int(sample_rate * duration))
    frequency = 440  # A4 note
    return np.sin(2 * np.pi * frequency * t).astype(np.float32)


def test_split_text(tts_service):
    """Test text splitting into sentences"""
    text = "First sentence. Second sentence! Third sentence?"
    sentences = tts_service._split_text(text)
    assert len(sentences) == 3
    assert sentences[0] == "First sentence."
    assert sentences[1] == "Second sentence!"
    assert sentences[2] == "Third sentence?"


def test_split_text_empty(tts_service):
    """Test splitting empty text"""
    assert tts_service._split_text("") == []


def test_split_text_single_sentence(tts_service):
    """Test splitting single sentence"""
    text = "Just one sentence."
    assert tts_service._split_text(text) == ["Just one sentence."]


def test_audio_to_bytes(tts_service, sample_audio):
    """Test converting audio tensor to bytes"""
    audio_bytes = tts_service._audio_to_bytes(sample_audio)
    assert isinstance(audio_bytes, bytes)
    assert len(audio_bytes) > 0


@patch('os.listdir')
@patch('os.path.join')
def test_list_voices(mock_join, mock_listdir, tts_service):
    """Test listing available voices"""
    mock_listdir.return_value = ['voice1.pt', 'voice2.pt', 'not_a_voice.txt']
    mock_join.return_value = '/fake/path'
    voices = tts_service.list_voices()
    assert len(voices) == 2
    assert 'voice1' in voices
    assert 'voice2' in voices
    assert 'not_a_voice' not in voices


@patch('api.src.services.tts.TTSModel.get_instance')
@patch('api.src.services.tts.TTSModel.get_voicepack')
@patch('api.src.services.tts.normalize_text')
@patch('api.src.services.tts.phonemize')
@patch('api.src.services.tts.tokenize')
@patch('api.src.services.tts.generate')
def test_generate_audio_empty_text(mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_voicepack, mock_instance, tts_service):
    """Test generating audio with empty text"""
    mock_normalize.return_value = ""
    with pytest.raises(ValueError, match="Text is empty after preprocessing"):
        tts_service._generate_audio("", "af", 1.0)


@patch('api.src.services.tts.TTSModel.get_instance')
@patch('api.src.services.tts.TTSModel.get_voicepack')
@patch('api.src.services.tts.normalize_text')
@patch('api.src.services.tts.phonemize')
@patch('api.src.services.tts.tokenize')
@patch('api.src.services.tts.generate')
def test_generate_audio_no_chunks(mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_voicepack, mock_instance, tts_service):
    """Test generating audio with no successful chunks"""
    mock_normalize.return_value = "Test text"
    mock_phonemize.return_value = "Test text"
    mock_tokenize.return_value = ["test", "text"]
    mock_generate.return_value = (None, None)
    mock_instance.return_value = (MagicMock(), "cpu")
    with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
        tts_service._generate_audio("Test text", "af", 1.0)


@patch('api.src.services.tts.TTSModel.get_instance')
@patch('api.src.services.tts.TTSModel.get_voicepack')
@patch('api.src.services.tts.normalize_text')
@patch('api.src.services.tts.phonemize')
@patch('api.src.services.tts.tokenize')
@patch('api.src.services.tts.generate')
def test_generate_audio_success(mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_voicepack, mock_instance, tts_service, sample_audio):
    """Test successful audio generation"""
    mock_normalize.return_value = "Test text"
    mock_phonemize.return_value = "Test text"
    mock_tokenize.return_value = ["test", "text"]
    mock_generate.return_value = (sample_audio, None)
    mock_instance.return_value = (MagicMock(), "cpu")
    audio, processing_time = tts_service._generate_audio("Test text", "af", 1.0)
    assert isinstance(audio, np.ndarray)
    assert isinstance(processing_time, float)
    assert len(audio) > 0


@patch('api.src.services.tts.torch.cuda.is_available')
@patch('api.src.services.tts.build_model')
def test_model_initialization_cuda(mock_build_model, mock_cuda_available):
    """Test model initialization with CUDA"""
    mock_cuda_available.return_value = True
    mock_model = MagicMock()
    mock_build_model.return_value = mock_model
    TTSModel._instance = None  # Reset singleton
    model, device = TTSModel.get_instance()
    assert device == "cuda"
    assert model == mock_model
    mock_build_model.assert_called_once()


@patch('api.src.services.tts.torch.cuda.is_available')
@patch('api.src.services.tts.build_model')
def test_model_initialization_cpu(mock_build_model, mock_cuda_available):
    """Test model initialization with CPU"""
    mock_cuda_available.return_value = False
    mock_model = MagicMock()
    mock_build_model.return_value = mock_model
    TTSModel._instance = None  # Reset singleton
    model, device = TTSModel.get_instance()
    assert device == "cpu"
    assert model == mock_model
    mock_build_model.assert_called_once()


@patch('api.src.services.tts.torch.load')
@patch('os.path.join')
def test_voicepack_loading_error(mock_join, mock_torch_load):
    """Test voicepack loading error handling"""
    mock_join.side_effect = lambda *args: '/'.join(args)
    mock_torch_load.side_effect = [Exception("Failed to load voice"), MagicMock()]
    TTSModel._instance = (MagicMock(), "cpu")  # Mock instance
    TTSModel._voicepacks = {}  # Reset voicepacks
    # Should fall back to 'af' voice
    voicepack = TTSModel.get_voicepack("nonexistent_voice")
    assert mock_torch_load.call_count == 2  # Tried original voice then fallback
    assert isinstance(voicepack, MagicMock)  # Successfully got fallback voice


@patch('api.src.services.tts.torch.load')
@patch('os.path.join')
def test_voicepack_loading_error_af(mock_join, mock_torch_load):
    """Test voicepack loading error for 'af' voice"""
    mock_join.side_effect = lambda *args: '/'.join(args)
    mock_torch_load.side_effect = Exception("Failed to load voice")
    TTSModel._instance = (MagicMock(), "cpu")  # Mock instance
    TTSModel._voicepacks = {}  # Reset voicepacks
    with pytest.raises(Exception):
        TTSModel.get_voicepack("af")


def test_save_audio(tts_service, sample_audio, tmp_path):
    """Test saving audio to file"""
    output_path = os.path.join(tmp_path, "test_output", "audio.wav")
    tts_service._save_audio(sample_audio, output_path)
    assert os.path.exists(output_path)
    assert os.path.getsize(output_path) > 0


@patch('api.src.services.tts.TTSModel.get_instance')
@patch('api.src.services.tts.TTSModel.get_voicepack')
@patch('api.src.services.tts.normalize_text')
@patch('api.src.services.tts.generate')
def test_generate_audio_without_stitching(mock_generate, mock_normalize, mock_voicepack, mock_instance, tts_service, sample_audio):
    """Test generating audio without text stitching"""
    mock_normalize.return_value = "Test text"
    mock_generate.return_value = (sample_audio, None)
    mock_instance.return_value = (MagicMock(), "cpu")
    audio, processing_time = tts_service._generate_audio("Test text", "af", 1.0, stitch_long_output=False)
    assert isinstance(audio, np.ndarray)
    assert isinstance(processing_time, float)
    assert len(audio) > 0
    mock_generate.assert_called_once()


@patch('os.listdir')
def test_list_voices_error(mock_listdir, tts_service):
    """Test error handling in list_voices"""
    mock_listdir.side_effect = Exception("Failed to list directory")
    voices = tts_service.list_voices()
    assert voices == []


@patch('api.src.services.tts.TTSModel.get_instance')
@patch('api.src.services.tts.TTSModel.get_voicepack')
@patch('api.src.services.tts.normalize_text')
@patch('api.src.services.tts.phonemize')
@patch('api.src.services.tts.tokenize')
@patch('api.src.services.tts.generate')
def test_generate_audio_phonemize_error(mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_voicepack, mock_instance, tts_service):
    """Test handling phonemization error"""
    mock_normalize.return_value = "Test text"
    mock_phonemize.side_effect = Exception("Phonemization failed")
    mock_instance.return_value = (MagicMock(), "cpu")
    mock_generate.return_value = (None, None)
    with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
        tts_service._generate_audio("Test text", "af", 1.0)


@patch('api.src.services.tts.TTSModel.get_instance')
@patch('api.src.services.tts.TTSModel.get_voicepack')
@patch('api.src.services.tts.normalize_text')
@patch('api.src.services.tts.generate')
def test_generate_audio_error(mock_generate, mock_normalize, mock_voicepack, mock_instance, tts_service):
    """Test handling generation error"""
    mock_normalize.return_value = "Test text"
    mock_generate.side_effect = Exception("Generation failed")
    mock_instance.return_value = (MagicMock(), "cpu")
    with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
        tts_service._generate_audio("Test text", "af", 1.0)
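
The model tests above encode a small contract: TTSModel is a lazily built singleton returning a (model, device) pair, and get_voicepack falls back to the default 'af' voice when a requested voicepack fails to load, re-raising only if 'af' itself fails. A hypothetical sketch of that pattern, with the voices directory and build_model arguments assumed rather than taken from the diff:

import os

import torch
from models import build_model  # same import the service uses


class TTSModelSketch:
    _instance = None   # lazily built (model, device) pair
    _voicepacks = {}   # cache of loaded voice tensors
    VOICES_DIR = "voices"  # assumed location of the *.pt voicepacks

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            # build_model arguments here are assumptions for illustration
            model = build_model("kokoro-v0_19.pth", device)
            cls._instance = (model, device)
        return cls._instance

    @classmethod
    def get_voicepack(cls, voice_name):
        model, device = cls.get_instance()
        if voice_name not in cls._voicepacks:
            try:
                path = os.path.join(cls.VOICES_DIR, f"{voice_name}.pt")
                cls._voicepacks[voice_name] = torch.load(path, map_location=device)
            except Exception:
                if voice_name == "af":
                    raise  # no further fallback, mirrors test_voicepack_loading_error_af
                return cls.get_voicepack("af")  # fall back to the default voice
        return cls._voicepacks[voice_name]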


@@ -1,16 +1,17 @@
import os
-import time
import json
-import scipy.io.wavfile as wavfile
-import requests
-import pandas as pd
-import seaborn as sns
-import matplotlib.pyplot as plt
-import tiktoken
-import psutil
+import time
import subprocess
from datetime import datetime
+import pandas as pd
+import psutil
+import seaborn as sns
+import requests
+import tiktoken
+import scipy.io.wavfile as wavfile
+import matplotlib.pyplot as plt
enc = tiktoken.get_encoding("cl100k_base")


@@ -1,4 +1,5 @@
from pathlib import Path
import openai
import requests
@@ -18,6 +19,7 @@ output_dir = Path(__file__).parent / "output"
output_dir.mkdir(exist_ok=True)
def test_voice(voice: str):
    speech_file = output_dir / f"speech_{voice}.wav"
    print(f"\nTesting voice: {voice}")
@@ -25,7 +27,7 @@ def test_voice(voice: str):
    try:
        response = client.audio.speech.create(
-            model="tts-1", voice=voice, input=SAMPLE_TEXT, response_format="wav"
+            model="kokoro", voice=voice, input=SAMPLE_TEXT, response_format="wav"
        )
        print("Got response, saving to file...")


@@ -1,4 +1,5 @@
from pathlib import Path
import openai
# Configure OpenAI client to use our local endpoint

pytest.ini

@@ -1,5 +1,5 @@
[pytest]
testpaths = api/tests
python_files = test_*.py
-addopts = -v --tb=short
+addopts = -v --tb=short --cov=api --cov-report=term-missing --cov-config=.coveragerc
pythonpath = .
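
With this change, a plain pytest run now emits a terminal coverage report scoped to the api package via the .coveragerc added above. A hedged equivalent invocation from Python (assumes pytest-cov is installed and the working directory is the repo root):

import pytest

# Mirrors the new addopts; pytest.ini already applies these flags on a bare "pytest" run.
pytest.main([
    "api/tests",
    "-v",
    "--tb=short",
    "--cov=api",
    "--cov-report=term-missing",
    "--cov-config=.coveragerc",
])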