diff --git a/.coverage b/.coverage new file mode 100644 index 0000000..7a62bb5 Binary files /dev/null and b/.coverage differ diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..f422eda --- /dev/null +++ b/.coveragerc @@ -0,0 +1,12 @@ +[run] +source = api +omit = Kokoro-82M/* + +[report] +exclude_lines = + pragma: no cover + def __repr__ + raise NotImplementedError + if __name__ == .__main__.: + pass + raise ImportError diff --git a/.ruff.toml b/.ruff.toml new file mode 100644 index 0000000..833539b --- /dev/null +++ b/.ruff.toml @@ -0,0 +1,11 @@ +line-length = 88 + +[lint] +select = ["I"] + +[lint.isort] +combine-as-imports = true +force-wrap-aliases = true +length-sort = true +split-on-trailing-comma = true +section-order = ["future", "standard-library", "third-party", "first-party", "local-folder"] diff --git a/api/src/main.py b/api/src/main.py index 362602c..9115ade 100644 --- a/api/src/main.py +++ b/api/src/main.py @@ -2,15 +2,16 @@ FastAPI OpenAI Compatible API """ -import uvicorn from contextlib import asynccontextmanager + +import uvicorn +from loguru import logger from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from loguru import logger from .core.config import settings -from .routers.openai_compatible import router as openai_router from .services.tts import TTSModel, TTSService +from .routers.openai_compatible import router as openai_router @asynccontextmanager diff --git a/api/src/routers/openai_compatible.py b/api/src/routers/openai_compatible.py index df20d66..b42c794 100644 --- a/api/src/routers/openai_compatible.py +++ b/api/src/routers/openai_compatible.py @@ -1,10 +1,9 @@ -from fastapi import APIRouter, HTTPException, Response, Depends -import logging -from ..structures.schemas import OpenAISpeechRequest +from loguru import logger +from fastapi import Depends, Response, APIRouter, HTTPException + from ..services.tts import TTSService from ..services.audio import AudioService - -logger = logging.getLogger(__name__) +from ..structures.schemas import OpenAISpeechRequest router = APIRouter( tags=["OpenAI Compatible TTS"], diff --git a/api/src/services/__init__.py b/api/src/services/__init__.py index ee384e9..46f2e93 100644 --- a/api/src/services/__init__.py +++ b/api/src/services/__init__.py @@ -1,3 +1,3 @@ -from .tts import TTSService, TTSModel +from .tts import TTSModel, TTSService __all__ = ["TTSService", "TTSModel"] diff --git a/api/src/services/audio.py b/api/src/services/audio.py index d3408ff..0aa852d 100644 --- a/api/src/services/audio.py +++ b/api/src/services/audio.py @@ -1,12 +1,11 @@ """Audio conversion service""" from io import BytesIO -import numpy as np -import scipy.io.wavfile as wavfile -import soundfile as sf -import logging -logger = logging.getLogger(__name__) +import numpy as np +import soundfile as sf +import scipy.io.wavfile as wavfile +from loguru import logger class AudioService: diff --git a/api/src/services/tts.py b/api/src/services/tts.py index 35d9bcb..76b83cc 100644 --- a/api/src/services/tts.py +++ b/api/src/services/tts.py @@ -1,19 +1,20 @@ -import os -import threading -import time import io +import os +import re +import time +import threading from typing import List, Tuple + import numpy as np import torch -import scipy.io.wavfile as wavfile -from models import build_model -from kokoro import generate, phonemize, tokenize, normalize_text -from ..core.config import settings -import re -import logging import tiktoken +import scipy.io.wavfile as wavfile +from kokoro import generate, tokenize, phonemize, normalize_text +from loguru import logger +from models import build_model + +from ..core.config import settings -logger = logging.getLogger(__name__) enc = tiktoken.get_encoding("cl100k_base") diff --git a/api/src/structures/schemas.py b/api/src/structures/schemas.py index bb00fc7..5a36e4a 100644 --- a/api/src/structures/schemas.py +++ b/api/src/structures/schemas.py @@ -1,6 +1,7 @@ -from pydantic import BaseModel, Field -from typing import Literal from enum import Enum +from typing import Literal + +from pydantic import Field, BaseModel class TTSStatus(str, Enum): @@ -13,7 +14,7 @@ class TTSStatus(str, Enum): # OpenAI-compatible schemas class OpenAISpeechRequest(BaseModel): - model: Literal["tts-1", "tts-1-hd"] = "tts-1" + model: Literal["tts-1", "tts-1-hd", "kokoro"] = "kokoro" input: str = Field(..., description="The text to generate audio for") voice: Literal[ "am_adam", diff --git a/api/tests/conftest.py b/api/tests/conftest.py index ecb8229..6648c15 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -1,6 +1,7 @@ -import pytest -from unittest.mock import Mock, patch import sys +from unittest.mock import Mock, patch + +import pytest # Mock torch and other ML modules before they're imported sys.modules["torch"] = Mock() diff --git a/api/tests/test_audio_service.py b/api/tests/test_audio_service.py new file mode 100644 index 0000000..0e1d1bc --- /dev/null +++ b/api/tests/test_audio_service.py @@ -0,0 +1,67 @@ +"""Tests for AudioService""" +import numpy as np +import pytest +from api.src.services.audio import AudioService + + +@pytest.fixture +def sample_audio(): + """Generate a simple sine wave for testing""" + sample_rate = 24000 + duration = 0.1 # 100ms + t = np.linspace(0, duration, int(sample_rate * duration)) + frequency = 440 # A4 note + return np.sin(2 * np.pi * frequency * t).astype(np.float32), sample_rate + + +def test_convert_to_wav(sample_audio): + """Test converting to WAV format""" + audio_data, sample_rate = sample_audio + result = AudioService.convert_audio(audio_data, sample_rate, "wav") + assert isinstance(result, bytes) + assert len(result) > 0 + + +def test_convert_to_mp3(sample_audio): + """Test converting to MP3 format""" + audio_data, sample_rate = sample_audio + result = AudioService.convert_audio(audio_data, sample_rate, "mp3") + assert isinstance(result, bytes) + assert len(result) > 0 + + +def test_convert_to_opus(sample_audio): + """Test converting to Opus format""" + audio_data, sample_rate = sample_audio + result = AudioService.convert_audio(audio_data, sample_rate, "opus") + assert isinstance(result, bytes) + assert len(result) > 0 + + +def test_convert_to_flac(sample_audio): + """Test converting to FLAC format""" + audio_data, sample_rate = sample_audio + result = AudioService.convert_audio(audio_data, sample_rate, "flac") + assert isinstance(result, bytes) + assert len(result) > 0 + + +def test_convert_to_aac_raises_error(sample_audio): + """Test that converting to AAC raises an error""" + audio_data, sample_rate = sample_audio + with pytest.raises(ValueError, match="AAC format is not currently supported"): + AudioService.convert_audio(audio_data, sample_rate, "aac") + + +def test_convert_to_pcm_raises_error(sample_audio): + """Test that converting to PCM raises an error""" + audio_data, sample_rate = sample_audio + with pytest.raises(ValueError, match="PCM format is not currently supported"): + AudioService.convert_audio(audio_data, sample_rate, "pcm") + + +def test_convert_to_invalid_format_raises_error(sample_audio): + """Test that converting to an invalid format raises an error""" + audio_data, sample_rate = sample_audio + with pytest.raises(ValueError, match="Format invalid not supported"): + AudioService.convert_audio(audio_data, sample_rate, "invalid") diff --git a/api/tests/test_endpoints.py b/api/tests/test_endpoints.py index c2223f0..97789f5 100644 --- a/api/tests/test_endpoints.py +++ b/api/tests/test_endpoints.py @@ -1,6 +1,8 @@ -from fastapi.testclient import TestClient -import pytest from unittest.mock import Mock + +import pytest +from fastapi.testclient import TestClient + from ..src.main import app # Create test client @@ -50,7 +52,7 @@ def test_health_check(): def test_openai_speech_endpoint(mock_tts_service, mock_audio_service): """Test the OpenAI-compatible speech endpoint""" test_request = { - "model": "tts-1", + "model": "kokoro", "input": "Hello world", "voice": "bm_lewis", "response_format": "wav", @@ -69,7 +71,7 @@ def test_openai_speech_endpoint(mock_tts_service, mock_audio_service): def test_openai_speech_invalid_voice(mock_tts_service): """Test the OpenAI-compatible speech endpoint with invalid voice""" test_request = { - "model": "tts-1", + "model": "kokoro", "input": "Hello world", "voice": "invalid_voice", "response_format": "wav", @@ -82,7 +84,7 @@ def test_openai_speech_invalid_voice(mock_tts_service): def test_openai_speech_invalid_speed(mock_tts_service): """Test the OpenAI-compatible speech endpoint with invalid speed""" test_request = { - "model": "tts-1", + "model": "kokoro", "input": "Hello world", "voice": "af", "response_format": "wav", @@ -96,7 +98,7 @@ def test_openai_speech_generation_error(mock_tts_service): """Test error handling in speech generation""" mock_tts_service._generate_audio.side_effect = Exception("Generation failed") test_request = { - "model": "tts-1", + "model": "kokoro", "input": "Hello world", "voice": "af", "response_format": "wav", diff --git a/api/tests/test_main.py b/api/tests/test_main.py new file mode 100644 index 0000000..9493d27 --- /dev/null +++ b/api/tests/test_main.py @@ -0,0 +1,45 @@ +"""Tests for main FastAPI application""" +import pytest +from unittest.mock import patch, MagicMock +from fastapi.testclient import TestClient + +from api.src.main import app + + +@pytest.fixture +def client(): + """Create a test client""" + return TestClient(app) + + +def test_health_check(client): + """Test health check endpoint""" + response = client.get("/health") + assert response.status_code == 200 + assert response.json() == {"status": "healthy"} + + +def test_test_endpoint(client): + """Test the test endpoint""" + response = client.get("/v1/test") + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + +def test_cors_headers(client): + """Test CORS headers are present""" + response = client.get( + "/health", + headers={"Origin": "http://testserver"}, + ) + assert response.status_code == 200 + assert response.headers["access-control-allow-origin"] == "*" + + +def test_openapi_schema(client): + """Test OpenAPI schema is accessible""" + response = client.get("/openapi.json") + assert response.status_code == 200 + schema = response.json() + assert schema["info"]["title"] == app.title + assert schema["info"]["version"] == app.version diff --git a/api/tests/test_tts_service.py b/api/tests/test_tts_service.py new file mode 100644 index 0000000..3f35a2b --- /dev/null +++ b/api/tests/test_tts_service.py @@ -0,0 +1,244 @@ +"""Tests for TTSService""" +import os +import numpy as np +import pytest +from unittest.mock import patch, MagicMock, call +from api.src.services.tts import TTSService, TTSModel + + +@pytest.fixture +def tts_service(): + """Create a TTSService instance for testing""" + return TTSService(start_worker=False) + + +@pytest.fixture +def sample_audio(): + """Generate a simple sine wave for testing""" + sample_rate = 24000 + duration = 0.1 # 100ms + t = np.linspace(0, duration, int(sample_rate * duration)) + frequency = 440 # A4 note + return np.sin(2 * np.pi * frequency * t).astype(np.float32) + + +def test_split_text(tts_service): + """Test text splitting into sentences""" + text = "First sentence. Second sentence! Third sentence?" + sentences = tts_service._split_text(text) + assert len(sentences) == 3 + assert sentences[0] == "First sentence." + assert sentences[1] == "Second sentence!" + assert sentences[2] == "Third sentence?" + + +def test_split_text_empty(tts_service): + """Test splitting empty text""" + assert tts_service._split_text("") == [] + + +def test_split_text_single_sentence(tts_service): + """Test splitting single sentence""" + text = "Just one sentence." + assert tts_service._split_text(text) == ["Just one sentence."] + + +def test_audio_to_bytes(tts_service, sample_audio): + """Test converting audio tensor to bytes""" + audio_bytes = tts_service._audio_to_bytes(sample_audio) + assert isinstance(audio_bytes, bytes) + assert len(audio_bytes) > 0 + + +@patch('os.listdir') +@patch('os.path.join') +def test_list_voices(mock_join, mock_listdir, tts_service): + """Test listing available voices""" + mock_listdir.return_value = ['voice1.pt', 'voice2.pt', 'not_a_voice.txt'] + mock_join.return_value = '/fake/path' + + voices = tts_service.list_voices() + assert len(voices) == 2 + assert 'voice1' in voices + assert 'voice2' in voices + assert 'not_a_voice' not in voices + + +@patch('api.src.services.tts.TTSModel.get_instance') +@patch('api.src.services.tts.TTSModel.get_voicepack') +@patch('api.src.services.tts.normalize_text') +@patch('api.src.services.tts.phonemize') +@patch('api.src.services.tts.tokenize') +@patch('api.src.services.tts.generate') +def test_generate_audio_empty_text(mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_voicepack, mock_instance, tts_service): + """Test generating audio with empty text""" + mock_normalize.return_value = "" + + with pytest.raises(ValueError, match="Text is empty after preprocessing"): + tts_service._generate_audio("", "af", 1.0) + + +@patch('api.src.services.tts.TTSModel.get_instance') +@patch('api.src.services.tts.TTSModel.get_voicepack') +@patch('api.src.services.tts.normalize_text') +@patch('api.src.services.tts.phonemize') +@patch('api.src.services.tts.tokenize') +@patch('api.src.services.tts.generate') +def test_generate_audio_no_chunks(mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_voicepack, mock_instance, tts_service): + """Test generating audio with no successful chunks""" + mock_normalize.return_value = "Test text" + mock_phonemize.return_value = "Test text" + mock_tokenize.return_value = ["test", "text"] + mock_generate.return_value = (None, None) + mock_instance.return_value = (MagicMock(), "cpu") + + with pytest.raises(ValueError, match="No audio chunks were generated successfully"): + tts_service._generate_audio("Test text", "af", 1.0) + + +@patch('api.src.services.tts.TTSModel.get_instance') +@patch('api.src.services.tts.TTSModel.get_voicepack') +@patch('api.src.services.tts.normalize_text') +@patch('api.src.services.tts.phonemize') +@patch('api.src.services.tts.tokenize') +@patch('api.src.services.tts.generate') +def test_generate_audio_success(mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_voicepack, mock_instance, tts_service, sample_audio): + """Test successful audio generation""" + mock_normalize.return_value = "Test text" + mock_phonemize.return_value = "Test text" + mock_tokenize.return_value = ["test", "text"] + mock_generate.return_value = (sample_audio, None) + mock_instance.return_value = (MagicMock(), "cpu") + + audio, processing_time = tts_service._generate_audio("Test text", "af", 1.0) + assert isinstance(audio, np.ndarray) + assert isinstance(processing_time, float) + assert len(audio) > 0 + + +@patch('api.src.services.tts.torch.cuda.is_available') +@patch('api.src.services.tts.build_model') +def test_model_initialization_cuda(mock_build_model, mock_cuda_available): + """Test model initialization with CUDA""" + mock_cuda_available.return_value = True + mock_model = MagicMock() + mock_build_model.return_value = mock_model + + TTSModel._instance = None # Reset singleton + model, device = TTSModel.get_instance() + + assert device == "cuda" + assert model == mock_model + mock_build_model.assert_called_once() + + +@patch('api.src.services.tts.torch.cuda.is_available') +@patch('api.src.services.tts.build_model') +def test_model_initialization_cpu(mock_build_model, mock_cuda_available): + """Test model initialization with CPU""" + mock_cuda_available.return_value = False + mock_model = MagicMock() + mock_build_model.return_value = mock_model + + TTSModel._instance = None # Reset singleton + model, device = TTSModel.get_instance() + + assert device == "cpu" + assert model == mock_model + mock_build_model.assert_called_once() + + +@patch('api.src.services.tts.torch.load') +@patch('os.path.join') +def test_voicepack_loading_error(mock_join, mock_torch_load): + """Test voicepack loading error handling""" + mock_join.side_effect = lambda *args: '/'.join(args) + mock_torch_load.side_effect = [Exception("Failed to load voice"), MagicMock()] + + TTSModel._instance = (MagicMock(), "cpu") # Mock instance + TTSModel._voicepacks = {} # Reset voicepacks + + # Should fall back to 'af' voice + voicepack = TTSModel.get_voicepack("nonexistent_voice") + assert mock_torch_load.call_count == 2 # Tried original voice then fallback + assert isinstance(voicepack, MagicMock) # Successfully got fallback voice + + +@patch('api.src.services.tts.torch.load') +@patch('os.path.join') +def test_voicepack_loading_error_af(mock_join, mock_torch_load): + """Test voicepack loading error for 'af' voice""" + mock_join.side_effect = lambda *args: '/'.join(args) + mock_torch_load.side_effect = Exception("Failed to load voice") + + TTSModel._instance = (MagicMock(), "cpu") # Mock instance + TTSModel._voicepacks = {} # Reset voicepacks + + with pytest.raises(Exception): + TTSModel.get_voicepack("af") + + +def test_save_audio(tts_service, sample_audio, tmp_path): + """Test saving audio to file""" + output_path = os.path.join(tmp_path, "test_output", "audio.wav") + tts_service._save_audio(sample_audio, output_path) + + assert os.path.exists(output_path) + assert os.path.getsize(output_path) > 0 + + +@patch('api.src.services.tts.TTSModel.get_instance') +@patch('api.src.services.tts.TTSModel.get_voicepack') +@patch('api.src.services.tts.normalize_text') +@patch('api.src.services.tts.generate') +def test_generate_audio_without_stitching(mock_generate, mock_normalize, mock_voicepack, mock_instance, tts_service, sample_audio): + """Test generating audio without text stitching""" + mock_normalize.return_value = "Test text" + mock_generate.return_value = (sample_audio, None) + mock_instance.return_value = (MagicMock(), "cpu") + + audio, processing_time = tts_service._generate_audio("Test text", "af", 1.0, stitch_long_output=False) + assert isinstance(audio, np.ndarray) + assert isinstance(processing_time, float) + assert len(audio) > 0 + mock_generate.assert_called_once() + + +@patch('os.listdir') +def test_list_voices_error(mock_listdir, tts_service): + """Test error handling in list_voices""" + mock_listdir.side_effect = Exception("Failed to list directory") + + voices = tts_service.list_voices() + assert voices == [] + + +@patch('api.src.services.tts.TTSModel.get_instance') +@patch('api.src.services.tts.TTSModel.get_voicepack') +@patch('api.src.services.tts.normalize_text') +@patch('api.src.services.tts.phonemize') +@patch('api.src.services.tts.tokenize') +@patch('api.src.services.tts.generate') +def test_generate_audio_phonemize_error(mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_voicepack, mock_instance, tts_service): + """Test handling phonemization error""" + mock_normalize.return_value = "Test text" + mock_phonemize.side_effect = Exception("Phonemization failed") + mock_instance.return_value = (MagicMock(), "cpu") + mock_generate.return_value = (None, None) + + with pytest.raises(ValueError, match="No audio chunks were generated successfully"): + tts_service._generate_audio("Test text", "af", 1.0) + + +@patch('api.src.services.tts.TTSModel.get_instance') +@patch('api.src.services.tts.TTSModel.get_voicepack') +@patch('api.src.services.tts.normalize_text') +@patch('api.src.services.tts.generate') +def test_generate_audio_error(mock_generate, mock_normalize, mock_voicepack, mock_instance, tts_service): + """Test handling generation error""" + mock_normalize.return_value = "Test text" + mock_generate.side_effect = Exception("Generation failed") + mock_instance.return_value = (MagicMock(), "cpu") + + with pytest.raises(ValueError, match="No audio chunks were generated successfully"): + tts_service._generate_audio("Test text", "af", 1.0) diff --git a/examples/benchmarks/benchmark_tts.py b/examples/benchmarks/benchmark_tts.py index 61b51f1..2e657ce 100644 --- a/examples/benchmarks/benchmark_tts.py +++ b/examples/benchmarks/benchmark_tts.py @@ -1,16 +1,17 @@ import os -import time import json -import scipy.io.wavfile as wavfile -import requests -import pandas as pd -import seaborn as sns -import matplotlib.pyplot as plt -import tiktoken -import psutil +import time import subprocess from datetime import datetime +import pandas as pd +import psutil +import seaborn as sns +import requests +import tiktoken +import scipy.io.wavfile as wavfile +import matplotlib.pyplot as plt + enc = tiktoken.get_encoding("cl100k_base") diff --git a/examples/test_all_voices.py b/examples/test_all_voices.py index 3f1c88a..c0645a4 100644 --- a/examples/test_all_voices.py +++ b/examples/test_all_voices.py @@ -1,4 +1,5 @@ from pathlib import Path + import openai import requests @@ -18,6 +19,7 @@ output_dir = Path(__file__).parent / "output" output_dir.mkdir(exist_ok=True) + def test_voice(voice: str): speech_file = output_dir / f"speech_{voice}.wav" print(f"\nTesting voice: {voice}") @@ -25,7 +27,7 @@ def test_voice(voice: str): try: response = client.audio.speech.create( - model="tts-1", voice=voice, input=SAMPLE_TEXT, response_format="wav" + model="kokoro", voice=voice, input=SAMPLE_TEXT, response_format="wav" ) print("Got response, saving to file...") diff --git a/examples/test_openai_tts.py b/examples/test_openai_tts.py index fd9d7d6..932aa11 100644 --- a/examples/test_openai_tts.py +++ b/examples/test_openai_tts.py @@ -1,4 +1,5 @@ from pathlib import Path + import openai # Configure OpenAI client to use our local endpoint diff --git a/pytest.ini b/pytest.ini index e7ea054..3bcd461 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,5 +1,5 @@ [pytest] testpaths = api/tests python_files = test_*.py -addopts = -v --tb=short +addopts = -v --tb=short --cov=api --cov-report=term-missing --cov-config=.coveragerc pythonpath = .