From 4123ab0891ded15dce57a0f4805ba58e7ebf5153 Mon Sep 17 00:00:00 2001 From: remsky Date: Tue, 31 Dec 2024 02:55:51 -0700 Subject: [PATCH] Refactor TTS API and enhance testing setup with coverage and logging improvements --- .coverage | Bin 0 -> 53248 bytes .coveragerc | 12 ++ .ruff.toml | 11 ++ api/src/main.py | 7 +- api/src/routers/openai_compatible.py | 9 +- api/src/services/__init__.py | 2 +- api/src/services/audio.py | 9 +- api/src/services/tts.py | 21 +-- api/src/structures/schemas.py | 7 +- api/tests/conftest.py | 5 +- api/tests/test_audio_service.py | 67 ++++++++ api/tests/test_endpoints.py | 14 +- api/tests/test_main.py | 45 +++++ api/tests/test_tts_service.py | 244 +++++++++++++++++++++++++++ examples/benchmarks/benchmark_tts.py | 17 +- examples/test_all_voices.py | 4 +- examples/test_openai_tts.py | 1 + pytest.ini | 2 +- 18 files changed, 432 insertions(+), 45 deletions(-) create mode 100644 .coverage create mode 100644 .coveragerc create mode 100644 .ruff.toml create mode 100644 api/tests/test_audio_service.py create mode 100644 api/tests/test_main.py create mode 100644 api/tests/test_tts_service.py diff --git a/.coverage b/.coverage new file mode 100644 index 0000000000000000000000000000000000000000..7a62bb5cc5afc92e2668b64a4493cccc3a7e371e GIT binary patch literal 53248 zcmeI4UyR&F9mnmpcfH=VcXkuP<_Oift?1>*-L(f2q(y3SNiKidinI{=07tgIUEjsc zt-Z0mcgX`ywuM9@ph`ScLP8ayNPVb;0Er?YBqD)8@PH`1fB+Fyf(8UgNPrrGfZvSm zTi+!Y#fLLR{^m2knc4BK{oyAbb0f!E@Pn=$S?`tVk}OMmElZLlon9ur z(xB0Xn(k04XSO%m)TOEAiza(gGD@G7*m3hHn>4;)-dBCi=<1)YzNp!Hn@%8r00@8p z2>gE&7(8KU&5;rL*-u3F{E`y|cH0T!*xdi2*_mUr*0Gs=kIh=~F>7bVqA)pW%~*lI zV)dNBT5y*f%k?^L+m2js(TYwuY3_dLbofL^$7rhK8K*>fzSVL&R4Q^7Nun3HT{}2q zopjDjhz!ZQMb7CcK0qRlyXf&8*3P8(xD_}HPT+WLCycGw>2}5|%WSrGemOYeQO3@d`Z^mrUL(LuXj{1|^g1N#O-D&ht~O!-D*+ zsj#rw>v!h~`P2Pwk2emLh~GH!H!8VN*~m{Sd&_NegSH&jcTIDq$a#y#rTU(Gs_(2e z{)9CwIU%ay`4bke1bM)_s-@hI&W61_ubVeLWepCjRk48g}&G3T_xmoX#2}fU@tmT$=P^V zc&i7s?$wqmgsakW}HhB>D2tF?HS8+`x_>o~3 zST_;ey7@gjI$>RJJ3mZ@j}=SJ4{Y1WaBf0tn#ZN*K2{JWUl!>E8LUMJwtnAf#H72o^p+g?9{y< z_ruVoX_Ze~;X~W2bY3%2nlfO?qh)_Xmj5=jVU38shDw>?-BUY+E z)gvZcwQG0c?#>$}qquUyxlS&gjSkq66S-Z-;#q_-!m(ED&}s(`ous*Rd`!8S8}Qa@ zP4k$qlFR&~++bAJn!9$%gMQK(<8h`H(xl$CQ?rxZ!Wt)8cXH@2M`^&NPdiyZ;p{Rv z_-_0Td~l0fa6YDx=1M;6RL1!NUuM@N`bPi(5C8!X009sH0T2KI5C8!X009uV^9U$% zLDu;CzrZd_?9UV-fB*=900@8p2!H?xfB*=900@8p2;7zgjDm6(OMl_mPdWQ|v(PYVB(^ul8>9iusZ`XWnD{(zs~MRNt(=N?8y<00ck) z1V8`;KmY_l00i<7IMh(3!Pfl`&3z=KTbt*erpJYJ=MFgG$;j``J?fwI1Ap?M9Y!;c zADOdz?p)-AQ7Gb86ooA+wCuJWy32WQorH>;66)HnM+%2nK2yS=Jr_nnza7!z0K&QW z_TrF)4%bL%+pU|jPN>y#JvVB#Na=$nDczkCGCX z$2O;zF>7~Kkyi4lB#dV(`3%T3x8rqs^neh0rbh3~ew}2h`7B6WC7iSS9oHw3y=4-~ z;W50K=qWS~Mc#tDNDBKjQqc36Arvw!XG)}@<~Jm~uoVo#w6U zWEw{>&Gg1S3P~w>r8rnFN_E;(*o;7?GP~tmKAI`XzyEJkJ&FCGw#)pzd9U#+<1@9d zn_sT^#=h!RcAm}F{$YN+`c2A$00JNY0w4eaAOHd&00JP8i@;%ZuvN~xP~eM%b?g7e zj5;W8PAFNM)^nP&ZrMIu|Fa|NU|SAT)(K^nBDJIH;O?ACH(URk2h{=Fq+Z4><3VL5 zr%J-On$v*H5;MF0uN+VZ)tnZrE}GNzzn;@$u3!I`_p1Xvry1+l|Jpuvpykw9txCN9 zFC9_`jm=7Anzi_dI%ww5$u!RT^}o7T9Voe_GV6b3Qv#XF6m}~sIWEm)O5*u{xdw1# z1_2NN0T2KI5C8!X009sH0T2Lz+l_!iKN>Lk_y02cr$qkD%(qxzNhre+25My zWNG{B34TkJOGTceSTR*f->JSR7kC(y724gRyzp009sH0T2KI5C8!X009sH0T2KI5V(U0@bCYz{=b787*zuS5C8!X009sH z0T2KI5C8!X00ANp-~ay#|NZ~p=?MUTq-OxU&VEhz0Q`*onEjA_k9~)Ii|zyXDtiUE zKmY_l00ck)1V8`;KmY_l00ck)1pWsElm 0 + + +def test_convert_to_mp3(sample_audio): + """Test converting to MP3 format""" + audio_data, sample_rate = sample_audio + result = AudioService.convert_audio(audio_data, sample_rate, "mp3") + assert isinstance(result, bytes) + assert len(result) > 0 + + +def test_convert_to_opus(sample_audio): + """Test converting to Opus format""" + audio_data, sample_rate = sample_audio + result = AudioService.convert_audio(audio_data, sample_rate, "opus") + assert isinstance(result, bytes) + assert len(result) > 0 + + +def test_convert_to_flac(sample_audio): + """Test converting to FLAC format""" + audio_data, sample_rate = sample_audio + result = AudioService.convert_audio(audio_data, sample_rate, "flac") + assert isinstance(result, bytes) + assert len(result) > 0 + + +def test_convert_to_aac_raises_error(sample_audio): + """Test that converting to AAC raises an error""" + audio_data, sample_rate = sample_audio + with pytest.raises(ValueError, match="AAC format is not currently supported"): + AudioService.convert_audio(audio_data, sample_rate, "aac") + + +def test_convert_to_pcm_raises_error(sample_audio): + """Test that converting to PCM raises an error""" + audio_data, sample_rate = sample_audio + with pytest.raises(ValueError, match="PCM format is not currently supported"): + AudioService.convert_audio(audio_data, sample_rate, "pcm") + + +def test_convert_to_invalid_format_raises_error(sample_audio): + """Test that converting to an invalid format raises an error""" + audio_data, sample_rate = sample_audio + with pytest.raises(ValueError, match="Format invalid not supported"): + AudioService.convert_audio(audio_data, sample_rate, "invalid") diff --git a/api/tests/test_endpoints.py b/api/tests/test_endpoints.py index c2223f0..97789f5 100644 --- a/api/tests/test_endpoints.py +++ b/api/tests/test_endpoints.py @@ -1,6 +1,8 @@ -from fastapi.testclient import TestClient -import pytest from unittest.mock import Mock + +import pytest +from fastapi.testclient import TestClient + from ..src.main import app # Create test client @@ -50,7 +52,7 @@ def test_health_check(): def test_openai_speech_endpoint(mock_tts_service, mock_audio_service): """Test the OpenAI-compatible speech endpoint""" test_request = { - "model": "tts-1", + "model": "kokoro", "input": "Hello world", "voice": "bm_lewis", "response_format": "wav", @@ -69,7 +71,7 @@ def test_openai_speech_endpoint(mock_tts_service, mock_audio_service): def test_openai_speech_invalid_voice(mock_tts_service): """Test the OpenAI-compatible speech endpoint with invalid voice""" test_request = { - "model": "tts-1", + "model": "kokoro", "input": "Hello world", "voice": "invalid_voice", "response_format": "wav", @@ -82,7 +84,7 @@ def test_openai_speech_invalid_voice(mock_tts_service): def test_openai_speech_invalid_speed(mock_tts_service): """Test the OpenAI-compatible speech endpoint with invalid speed""" test_request = { - "model": "tts-1", + "model": "kokoro", "input": "Hello world", "voice": "af", "response_format": "wav", @@ -96,7 +98,7 @@ def test_openai_speech_generation_error(mock_tts_service): """Test error handling in speech generation""" mock_tts_service._generate_audio.side_effect = Exception("Generation failed") test_request = { - "model": "tts-1", + "model": "kokoro", "input": "Hello world", "voice": "af", "response_format": "wav", diff --git a/api/tests/test_main.py b/api/tests/test_main.py new file mode 100644 index 0000000..9493d27 --- /dev/null +++ b/api/tests/test_main.py @@ -0,0 +1,45 @@ +"""Tests for main FastAPI application""" +import pytest +from unittest.mock import patch, MagicMock +from fastapi.testclient import TestClient + +from api.src.main import app + + +@pytest.fixture +def client(): + """Create a test client""" + return TestClient(app) + + +def test_health_check(client): + """Test health check endpoint""" + response = client.get("/health") + assert response.status_code == 200 + assert response.json() == {"status": "healthy"} + + +def test_test_endpoint(client): + """Test the test endpoint""" + response = client.get("/v1/test") + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + +def test_cors_headers(client): + """Test CORS headers are present""" + response = client.get( + "/health", + headers={"Origin": "http://testserver"}, + ) + assert response.status_code == 200 + assert response.headers["access-control-allow-origin"] == "*" + + +def test_openapi_schema(client): + """Test OpenAPI schema is accessible""" + response = client.get("/openapi.json") + assert response.status_code == 200 + schema = response.json() + assert schema["info"]["title"] == app.title + assert schema["info"]["version"] == app.version diff --git a/api/tests/test_tts_service.py b/api/tests/test_tts_service.py new file mode 100644 index 0000000..3f35a2b --- /dev/null +++ b/api/tests/test_tts_service.py @@ -0,0 +1,244 @@ +"""Tests for TTSService""" +import os +import numpy as np +import pytest +from unittest.mock import patch, MagicMock, call +from api.src.services.tts import TTSService, TTSModel + + +@pytest.fixture +def tts_service(): + """Create a TTSService instance for testing""" + return TTSService(start_worker=False) + + +@pytest.fixture +def sample_audio(): + """Generate a simple sine wave for testing""" + sample_rate = 24000 + duration = 0.1 # 100ms + t = np.linspace(0, duration, int(sample_rate * duration)) + frequency = 440 # A4 note + return np.sin(2 * np.pi * frequency * t).astype(np.float32) + + +def test_split_text(tts_service): + """Test text splitting into sentences""" + text = "First sentence. Second sentence! Third sentence?" + sentences = tts_service._split_text(text) + assert len(sentences) == 3 + assert sentences[0] == "First sentence." + assert sentences[1] == "Second sentence!" + assert sentences[2] == "Third sentence?" + + +def test_split_text_empty(tts_service): + """Test splitting empty text""" + assert tts_service._split_text("") == [] + + +def test_split_text_single_sentence(tts_service): + """Test splitting single sentence""" + text = "Just one sentence." + assert tts_service._split_text(text) == ["Just one sentence."] + + +def test_audio_to_bytes(tts_service, sample_audio): + """Test converting audio tensor to bytes""" + audio_bytes = tts_service._audio_to_bytes(sample_audio) + assert isinstance(audio_bytes, bytes) + assert len(audio_bytes) > 0 + + +@patch('os.listdir') +@patch('os.path.join') +def test_list_voices(mock_join, mock_listdir, tts_service): + """Test listing available voices""" + mock_listdir.return_value = ['voice1.pt', 'voice2.pt', 'not_a_voice.txt'] + mock_join.return_value = '/fake/path' + + voices = tts_service.list_voices() + assert len(voices) == 2 + assert 'voice1' in voices + assert 'voice2' in voices + assert 'not_a_voice' not in voices + + +@patch('api.src.services.tts.TTSModel.get_instance') +@patch('api.src.services.tts.TTSModel.get_voicepack') +@patch('api.src.services.tts.normalize_text') +@patch('api.src.services.tts.phonemize') +@patch('api.src.services.tts.tokenize') +@patch('api.src.services.tts.generate') +def test_generate_audio_empty_text(mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_voicepack, mock_instance, tts_service): + """Test generating audio with empty text""" + mock_normalize.return_value = "" + + with pytest.raises(ValueError, match="Text is empty after preprocessing"): + tts_service._generate_audio("", "af", 1.0) + + +@patch('api.src.services.tts.TTSModel.get_instance') +@patch('api.src.services.tts.TTSModel.get_voicepack') +@patch('api.src.services.tts.normalize_text') +@patch('api.src.services.tts.phonemize') +@patch('api.src.services.tts.tokenize') +@patch('api.src.services.tts.generate') +def test_generate_audio_no_chunks(mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_voicepack, mock_instance, tts_service): + """Test generating audio with no successful chunks""" + mock_normalize.return_value = "Test text" + mock_phonemize.return_value = "Test text" + mock_tokenize.return_value = ["test", "text"] + mock_generate.return_value = (None, None) + mock_instance.return_value = (MagicMock(), "cpu") + + with pytest.raises(ValueError, match="No audio chunks were generated successfully"): + tts_service._generate_audio("Test text", "af", 1.0) + + +@patch('api.src.services.tts.TTSModel.get_instance') +@patch('api.src.services.tts.TTSModel.get_voicepack') +@patch('api.src.services.tts.normalize_text') +@patch('api.src.services.tts.phonemize') +@patch('api.src.services.tts.tokenize') +@patch('api.src.services.tts.generate') +def test_generate_audio_success(mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_voicepack, mock_instance, tts_service, sample_audio): + """Test successful audio generation""" + mock_normalize.return_value = "Test text" + mock_phonemize.return_value = "Test text" + mock_tokenize.return_value = ["test", "text"] + mock_generate.return_value = (sample_audio, None) + mock_instance.return_value = (MagicMock(), "cpu") + + audio, processing_time = tts_service._generate_audio("Test text", "af", 1.0) + assert isinstance(audio, np.ndarray) + assert isinstance(processing_time, float) + assert len(audio) > 0 + + +@patch('api.src.services.tts.torch.cuda.is_available') +@patch('api.src.services.tts.build_model') +def test_model_initialization_cuda(mock_build_model, mock_cuda_available): + """Test model initialization with CUDA""" + mock_cuda_available.return_value = True + mock_model = MagicMock() + mock_build_model.return_value = mock_model + + TTSModel._instance = None # Reset singleton + model, device = TTSModel.get_instance() + + assert device == "cuda" + assert model == mock_model + mock_build_model.assert_called_once() + + +@patch('api.src.services.tts.torch.cuda.is_available') +@patch('api.src.services.tts.build_model') +def test_model_initialization_cpu(mock_build_model, mock_cuda_available): + """Test model initialization with CPU""" + mock_cuda_available.return_value = False + mock_model = MagicMock() + mock_build_model.return_value = mock_model + + TTSModel._instance = None # Reset singleton + model, device = TTSModel.get_instance() + + assert device == "cpu" + assert model == mock_model + mock_build_model.assert_called_once() + + +@patch('api.src.services.tts.torch.load') +@patch('os.path.join') +def test_voicepack_loading_error(mock_join, mock_torch_load): + """Test voicepack loading error handling""" + mock_join.side_effect = lambda *args: '/'.join(args) + mock_torch_load.side_effect = [Exception("Failed to load voice"), MagicMock()] + + TTSModel._instance = (MagicMock(), "cpu") # Mock instance + TTSModel._voicepacks = {} # Reset voicepacks + + # Should fall back to 'af' voice + voicepack = TTSModel.get_voicepack("nonexistent_voice") + assert mock_torch_load.call_count == 2 # Tried original voice then fallback + assert isinstance(voicepack, MagicMock) # Successfully got fallback voice + + +@patch('api.src.services.tts.torch.load') +@patch('os.path.join') +def test_voicepack_loading_error_af(mock_join, mock_torch_load): + """Test voicepack loading error for 'af' voice""" + mock_join.side_effect = lambda *args: '/'.join(args) + mock_torch_load.side_effect = Exception("Failed to load voice") + + TTSModel._instance = (MagicMock(), "cpu") # Mock instance + TTSModel._voicepacks = {} # Reset voicepacks + + with pytest.raises(Exception): + TTSModel.get_voicepack("af") + + +def test_save_audio(tts_service, sample_audio, tmp_path): + """Test saving audio to file""" + output_path = os.path.join(tmp_path, "test_output", "audio.wav") + tts_service._save_audio(sample_audio, output_path) + + assert os.path.exists(output_path) + assert os.path.getsize(output_path) > 0 + + +@patch('api.src.services.tts.TTSModel.get_instance') +@patch('api.src.services.tts.TTSModel.get_voicepack') +@patch('api.src.services.tts.normalize_text') +@patch('api.src.services.tts.generate') +def test_generate_audio_without_stitching(mock_generate, mock_normalize, mock_voicepack, mock_instance, tts_service, sample_audio): + """Test generating audio without text stitching""" + mock_normalize.return_value = "Test text" + mock_generate.return_value = (sample_audio, None) + mock_instance.return_value = (MagicMock(), "cpu") + + audio, processing_time = tts_service._generate_audio("Test text", "af", 1.0, stitch_long_output=False) + assert isinstance(audio, np.ndarray) + assert isinstance(processing_time, float) + assert len(audio) > 0 + mock_generate.assert_called_once() + + +@patch('os.listdir') +def test_list_voices_error(mock_listdir, tts_service): + """Test error handling in list_voices""" + mock_listdir.side_effect = Exception("Failed to list directory") + + voices = tts_service.list_voices() + assert voices == [] + + +@patch('api.src.services.tts.TTSModel.get_instance') +@patch('api.src.services.tts.TTSModel.get_voicepack') +@patch('api.src.services.tts.normalize_text') +@patch('api.src.services.tts.phonemize') +@patch('api.src.services.tts.tokenize') +@patch('api.src.services.tts.generate') +def test_generate_audio_phonemize_error(mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_voicepack, mock_instance, tts_service): + """Test handling phonemization error""" + mock_normalize.return_value = "Test text" + mock_phonemize.side_effect = Exception("Phonemization failed") + mock_instance.return_value = (MagicMock(), "cpu") + mock_generate.return_value = (None, None) + + with pytest.raises(ValueError, match="No audio chunks were generated successfully"): + tts_service._generate_audio("Test text", "af", 1.0) + + +@patch('api.src.services.tts.TTSModel.get_instance') +@patch('api.src.services.tts.TTSModel.get_voicepack') +@patch('api.src.services.tts.normalize_text') +@patch('api.src.services.tts.generate') +def test_generate_audio_error(mock_generate, mock_normalize, mock_voicepack, mock_instance, tts_service): + """Test handling generation error""" + mock_normalize.return_value = "Test text" + mock_generate.side_effect = Exception("Generation failed") + mock_instance.return_value = (MagicMock(), "cpu") + + with pytest.raises(ValueError, match="No audio chunks were generated successfully"): + tts_service._generate_audio("Test text", "af", 1.0) diff --git a/examples/benchmarks/benchmark_tts.py b/examples/benchmarks/benchmark_tts.py index 61b51f1..2e657ce 100644 --- a/examples/benchmarks/benchmark_tts.py +++ b/examples/benchmarks/benchmark_tts.py @@ -1,16 +1,17 @@ import os -import time import json -import scipy.io.wavfile as wavfile -import requests -import pandas as pd -import seaborn as sns -import matplotlib.pyplot as plt -import tiktoken -import psutil +import time import subprocess from datetime import datetime +import pandas as pd +import psutil +import seaborn as sns +import requests +import tiktoken +import scipy.io.wavfile as wavfile +import matplotlib.pyplot as plt + enc = tiktoken.get_encoding("cl100k_base") diff --git a/examples/test_all_voices.py b/examples/test_all_voices.py index 3f1c88a..c0645a4 100644 --- a/examples/test_all_voices.py +++ b/examples/test_all_voices.py @@ -1,4 +1,5 @@ from pathlib import Path + import openai import requests @@ -18,6 +19,7 @@ output_dir = Path(__file__).parent / "output" output_dir.mkdir(exist_ok=True) + def test_voice(voice: str): speech_file = output_dir / f"speech_{voice}.wav" print(f"\nTesting voice: {voice}") @@ -25,7 +27,7 @@ def test_voice(voice: str): try: response = client.audio.speech.create( - model="tts-1", voice=voice, input=SAMPLE_TEXT, response_format="wav" + model="kokoro", voice=voice, input=SAMPLE_TEXT, response_format="wav" ) print("Got response, saving to file...") diff --git a/examples/test_openai_tts.py b/examples/test_openai_tts.py index fd9d7d6..932aa11 100644 --- a/examples/test_openai_tts.py +++ b/examples/test_openai_tts.py @@ -1,4 +1,5 @@ from pathlib import Path + import openai # Configure OpenAI client to use our local endpoint diff --git a/pytest.ini b/pytest.ini index e7ea054..3bcd461 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,5 +1,5 @@ [pytest] testpaths = api/tests python_files = test_*.py -addopts = -v --tb=short +addopts = -v --tb=short --cov=api --cov-report=term-missing --cov-config=.coveragerc pythonpath = .