Ruff Check + Format

This commit is contained in:
remsky 2025-01-01 21:50:41 -07:00
parent e749b3bc88
commit f051984805
27 changed files with 638 additions and 504 deletions

BIN
.coverage

Binary file not shown.

View file

@ -1,10 +1,10 @@
from typing import List
from fastapi import APIRouter, Depends, HTTPException, Response
from loguru import logger
from fastapi import Depends, Response, APIRouter, HTTPException
from ..services.audio import AudioService
from ..services.tts import TTSService
from ..services.audio import AudioService
from ..structures.schemas import OpenAISpeechRequest
router = APIRouter(
@ -32,7 +32,7 @@ async def create_speech(
raise ValueError(
f"Voice '{request.voice}' not found. Available voices: {', '.join(sorted(available_voices))}"
)
# Generate audio directly using TTSService's method
audio, _ = tts_service._generate_audio(
text=request.input,
@ -55,14 +55,12 @@ async def create_speech(
except ValueError as e:
logger.error(f"Invalid request: {str(e)}")
raise HTTPException(
status_code=400,
detail={"error": "Invalid request", "message": str(e)}
status_code=400, detail={"error": "Invalid request", "message": str(e)}
)
except Exception as e:
logger.error(f"Error generating speech: {str(e)}")
raise HTTPException(
status_code=500,
detail={"error": "Server error", "message": str(e)}
status_code=500, detail={"error": "Server error", "message": str(e)}
)
@ -78,17 +76,19 @@ async def list_voices(tts_service: TTSService = Depends(get_tts_service)):
@router.post("/audio/voices/combine")
async def combine_voices(request: List[str], tts_service: TTSService = Depends(get_tts_service)):
async def combine_voices(
request: List[str], tts_service: TTSService = Depends(get_tts_service)
):
"""Combine multiple voices into a new voice.
Args:
request: List of voice names to combine
Returns:
Dict with combined voice name and list of all available voices
Raises:
HTTPException:
HTTPException:
- 400: Invalid request (wrong number of voices, voice not found)
- 500: Server error (file system issues, combination failed)
"""
@ -96,24 +96,21 @@ async def combine_voices(request: List[str], tts_service: TTSService = Depends(g
combined_voice = tts_service.combine_voices(voices=request)
voices = tts_service.list_voices()
return {"voices": voices, "voice": combined_voice}
except ValueError as e:
logger.error(f"Invalid voice combination request: {str(e)}")
raise HTTPException(
status_code=400,
detail={"error": "Invalid request", "message": str(e)}
status_code=400, detail={"error": "Invalid request", "message": str(e)}
)
except RuntimeError as e:
logger.error(f"Server error during voice combination: {str(e)}")
raise HTTPException(
status_code=500,
detail={"error": "Server error", "message": str(e)}
status_code=500, detail={"error": "Server error", "message": str(e)}
)
except Exception as e:
logger.error(f"Unexpected error during voice combination: {str(e)}")
raise HTTPException(
status_code=500,
detail={"error": "Unexpected error", "message": str(e)}
status_code=500, detail={"error": "Unexpected error", "message": str(e)}
)

View file

@ -1,17 +1,16 @@
import io
import os
import re
import threading
import time
import threading
from typing import List, Tuple, Optional
import numpy as np
import scipy.io.wavfile as wavfile
import tiktoken
import torch
import tiktoken
import scipy.io.wavfile as wavfile
from kokoro import generate, tokenize, phonemize, normalize_text
from loguru import logger
from kokoro import generate, normalize_text, phonemize, tokenize
from models import build_model
from ..core.config import settings
@ -23,7 +22,7 @@ class TTSModel:
_instance = None
_device = None
_lock = threading.Lock()
# Directory for all voices (copied base voices, and any created combined voices)
VOICES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "voices")
@ -38,10 +37,10 @@ class TTSModel:
model_path = os.path.join(settings.model_dir, settings.model_path)
model = build_model(model_path, cls._device)
cls._instance = model
# Ensure voices directory exists
os.makedirs(cls.VOICES_DIR, exist_ok=True)
# Copy base voices to local directory
base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
if os.path.exists(base_voices_dir):
@ -51,25 +50,37 @@ class TTSModel:
voice_path = os.path.join(cls.VOICES_DIR, file)
if not os.path.exists(voice_path):
try:
logger.info(f"Copying base voice {voice_name} to voices directory")
logger.info(
f"Copying base voice {voice_name} to voices directory"
)
base_path = os.path.join(base_voices_dir, file)
voicepack = torch.load(base_path, map_location=cls._device, weights_only=True)
voicepack = torch.load(
base_path,
map_location=cls._device,
weights_only=True,
)
torch.save(voicepack, voice_path)
except Exception as e:
logger.error(f"Error copying voice {voice_name}: {str(e)}")
logger.error(
f"Error copying voice {voice_name}: {str(e)}"
)
# Warm up with default voice
try:
dummy_text = "Hello"
voice_path = os.path.join(cls.VOICES_DIR, "af.pt")
dummy_voicepack = torch.load(voice_path, map_location=cls._device, weights_only=True)
generate(model, dummy_text, dummy_voicepack, lang='a', speed=1.0)
dummy_voicepack = torch.load(
voice_path, map_location=cls._device, weights_only=True
)
generate(model, dummy_text, dummy_voicepack, lang="a", speed=1.0)
logger.info("Model warm-up complete")
except Exception as e:
logger.warning(f"Model warm-up failed: {e}")
# Count voices in directory for validation
voice_count = len([f for f in os.listdir(cls.VOICES_DIR) if f.endswith('.pt')])
voice_count = len(
[f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")]
)
return cls._instance, voice_count
@classmethod
@ -86,11 +97,11 @@ class TTSService:
self._ensure_voices()
if start_worker:
self.start_worker()
def _ensure_voices(self):
"""Copy base voices to local voices directory during initialization"""
os.makedirs(TTSModel.VOICES_DIR, exist_ok=True)
base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
if os.path.exists(base_voices_dir):
for file in os.listdir(base_voices_dir):
@ -99,9 +110,15 @@ class TTSService:
voice_path = os.path.join(TTSModel.VOICES_DIR, file)
if not os.path.exists(voice_path):
try:
logger.info(f"Copying base voice {voice_name} to voices directory")
logger.info(
f"Copying base voice {voice_name} to voices directory"
)
base_path = os.path.join(base_voices_dir, file)
voicepack = torch.load(base_path, map_location=TTSModel._device, weights_only=True)
voicepack = torch.load(
base_path,
map_location=TTSModel._device,
weights_only=True,
)
torch.save(voicepack, voice_path)
except Exception as e:
logger.error(f"Error copying voice {voice_name}: {str(e)}")
@ -112,10 +129,10 @@ class TTSService:
def _get_voice_path(self, voice_name: str) -> Optional[str]:
"""Get the path to a voice file.
Args:
voice_name: Name of the voice to find
Returns:
Path to the voice file if found, None otherwise
"""
@ -141,7 +158,9 @@ class TTSService:
# Load model and voice
model = TTSModel._instance
voicepack = torch.load(voice_path, map_location=TTSModel._device, weights_only=True)
voicepack = torch.load(
voice_path, map_location=TTSModel._device, weights_only=True
)
# Generate audio with or without stitching
if stitch_long_output:
@ -152,11 +171,11 @@ class TTSService:
for i, chunk in enumerate(chunks):
try:
# Validate phonemization first
ps = phonemize(chunk, voice[0])
tokens = tokenize(ps)
logger.debug(
f"Processing chunk {i + 1}/{len(chunks)}: {len(tokens)} tokens"
)
# ps = phonemize(chunk, voice[0])
# tokens = tokenize(ps)
# logger.debug(
# f"Processing chunk {i + 1}/{len(chunks)}: {len(tokens)} tokens"
# )
# Only proceed if phonemization succeeded
chunk_audio, _ = generate(
@ -205,47 +224,51 @@ class TTSService:
def combine_voices(self, voices: List[str]) -> str:
"""Combine multiple voices into a new voice.
Args:
voices: List of voice names to combine
Returns:
Name of the combined voice
Raises:
ValueError: If less than 2 voices provided or voice loading fails
RuntimeError: If voice combination or saving fails
"""
if len(voices) < 2:
raise ValueError("At least 2 voices are required for combination")
# Load voices
t_voices: List[torch.Tensor] = []
v_name: List[str] = []
for voice in voices:
try:
voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice}.pt")
voicepack = torch.load(voice_path, map_location=TTSModel._device, weights_only=True)
voicepack = torch.load(
voice_path, map_location=TTSModel._device, weights_only=True
)
t_voices.append(voicepack)
v_name.append(voice)
except Exception as e:
raise ValueError(f"Failed to load voice {voice}: {str(e)}")
# Combine voices
try:
f: str = "_".join(v_name)
v = torch.mean(torch.stack(t_voices), dim=0)
combined_path = os.path.join(TTSModel.VOICES_DIR, f"{f}.pt")
# Save combined voice
try:
torch.save(v, combined_path)
except Exception as e:
raise RuntimeError(f"Failed to save combined voice to {combined_path}: {str(e)}")
raise RuntimeError(
f"Failed to save combined voice to {combined_path}: {str(e)}"
)
return f
except Exception as e:
if not isinstance(e, (ValueError, RuntimeError)):
raise RuntimeError(f"Error combining voices: {str(e)}")

View file

@ -17,8 +17,8 @@ class OpenAISpeechRequest(BaseModel):
model: Literal["tts-1", "tts-1-hd", "kokoro"] = "kokoro"
input: str = Field(..., description="The text to generate audio for")
voice: str = Field(
default="af",
description="The voice to use for generation. Can be a base voice or a combined voice name."
default="af",
description="The voice to use for generation. Can be a base voice or a combined voice name.",
)
response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field(
default="mp3",

View file

@ -1,16 +1,18 @@
import os
import shutil
import sys
import shutil
from unittest.mock import Mock, patch
import pytest
def cleanup_mock_dirs():
"""Clean up any MagicMock directories created during tests"""
mock_dir = "MagicMock"
if os.path.exists(mock_dir):
shutil.rmtree(mock_dir)
@pytest.fixture(autouse=True)
def cleanup():
"""Automatically clean up before and after each test"""
@ -18,6 +20,7 @@ def cleanup():
yield
cleanup_mock_dirs()
# Mock torch and other ML modules before they're imported
sys.modules["torch"] = Mock()
sys.modules["transformers"] = Mock()

View file

@ -1,6 +1,8 @@
"""Tests for AudioService"""
import numpy as np
import pytest
from api.src.services.audio import AudioService

View file

@ -114,9 +114,9 @@ def test_combine_voices_success(mock_tts_service):
"""Test successful voice combination"""
test_voices = ["af_bella", "af_sarah"]
mock_tts_service.combine_voices.return_value = "af_bella_af_sarah"
response = client.post("/v1/audio/voices/combine", json=test_voices)
assert response.status_code == 200
assert response.json()["voice"] == "af_bella_af_sarah"
mock_tts_service.combine_voices.assert_called_once_with(voices=test_voices)
@ -126,9 +126,9 @@ def test_combine_voices_single_voice(mock_tts_service):
"""Test combining single voice returns default voice"""
test_voices = ["af_bella"]
mock_tts_service.combine_voices.return_value = "af"
response = client.post("/v1/audio/voices/combine", json=test_voices)
assert response.status_code == 200
assert response.json()["voice"] == "af"
@ -137,9 +137,9 @@ def test_combine_voices_empty_list(mock_tts_service):
"""Test combining empty voice list returns default voice"""
test_voices = []
mock_tts_service.combine_voices.return_value = "af"
response = client.post("/v1/audio/voices/combine", json=test_voices)
assert response.status_code == 200
assert response.json()["voice"] == "af"
@ -148,8 +148,8 @@ def test_combine_voices_error(mock_tts_service):
"""Test error handling in voice combination"""
test_voices = ["af_bella", "af_sarah"]
mock_tts_service.combine_voices.side_effect = Exception("Combination failed")
response = client.post("/v1/audio/voices/combine", json=test_voices)
assert response.status_code == 500
assert "Combination failed" in response.json()["detail"]["message"]

View file

@ -1,7 +1,10 @@
"""Tests for FastAPI application"""
from unittest.mock import MagicMock, patch
import pytest
from unittest.mock import patch, MagicMock
from fastapi.testclient import TestClient
from api.src.main import app, lifespan
@ -19,98 +22,100 @@ def test_health_check(test_client):
@pytest.mark.asyncio
@patch('api.src.main.TTSModel')
@patch('api.src.main.logger')
@patch("api.src.main.TTSModel")
@patch("api.src.main.logger")
async def test_lifespan_successful_warmup(mock_logger, mock_tts_model):
"""Test successful model warmup in lifespan"""
# Mock the model initialization with model info and voicepack count
mock_model = MagicMock()
# Mock file system for voice counting
mock_tts_model.VOICES_DIR = "/mock/voices"
with patch('os.listdir', return_value=['voice1.pt', 'voice2.pt', 'voice3.pt']):
with patch("os.listdir", return_value=["voice1.pt", "voice2.pt", "voice3.pt"]):
mock_tts_model.initialize.return_value = (mock_model, 3) # 3 voice files
mock_tts_model._device = "cuda" # Set device class variable
# Create an async generator from the lifespan context manager
async_gen = lifespan(MagicMock())
# Start the context manager
await async_gen.__aenter__()
# Verify the expected logging sequence
mock_logger.info.assert_any_call("Loading TTS model and voice packs...")
mock_logger.info.assert_any_call("Model loaded and warmed up on cuda")
mock_logger.info.assert_any_call("3 voice packs loaded successfully")
# Verify model initialization was called
mock_tts_model.initialize.assert_called_once()
# Clean up
await async_gen.__aexit__(None, None, None)
@pytest.mark.asyncio
@patch('api.src.main.TTSModel')
@patch('api.src.main.logger')
@patch("api.src.main.TTSModel")
@patch("api.src.main.logger")
async def test_lifespan_failed_warmup(mock_logger, mock_tts_model):
"""Test failed model warmup in lifespan"""
# Mock the model initialization to fail
mock_tts_model.initialize.side_effect = Exception("Failed to initialize model")
# Create an async generator from the lifespan context manager
async_gen = lifespan(MagicMock())
# Verify the exception is raised
with pytest.raises(Exception, match="Failed to initialize model"):
await async_gen.__aenter__()
# Verify the expected logging sequence
mock_logger.info.assert_called_with("Loading TTS model and voice packs...")
# Clean up
await async_gen.__aexit__(None, None, None)
@pytest.mark.asyncio
@patch('api.src.main.TTSModel')
@patch("api.src.main.TTSModel")
async def test_lifespan_cuda_warmup(mock_tts_model):
"""Test model warmup specifically on CUDA"""
# Mock the model initialization with CUDA and voicepacks
mock_model = MagicMock()
# Mock file system for voice counting
mock_tts_model.VOICES_DIR = "/mock/voices"
with patch('os.listdir', return_value=['voice1.pt', 'voice2.pt']):
with patch("os.listdir", return_value=["voice1.pt", "voice2.pt"]):
mock_tts_model.initialize.return_value = (mock_model, 2) # 2 voice files
mock_tts_model._device = "cuda" # Set device class variable
# Create an async generator from the lifespan context manager
async_gen = lifespan(MagicMock())
await async_gen.__aenter__()
# Verify model was initialized
mock_tts_model.initialize.assert_called_once()
# Clean up
await async_gen.__aexit__(None, None, None)
@pytest.mark.asyncio
@patch('api.src.main.TTSModel')
@patch("api.src.main.TTSModel")
async def test_lifespan_cpu_fallback(mock_tts_model):
"""Test model warmup falling back to CPU"""
# Mock the model initialization with CPU and voicepacks
mock_model = MagicMock()
# Mock file system for voice counting
mock_tts_model.VOICES_DIR = "/mock/voices"
with patch('os.listdir', return_value=['voice1.pt', 'voice2.pt', 'voice3.pt', 'voice4.pt']):
with patch(
"os.listdir", return_value=["voice1.pt", "voice2.pt", "voice3.pt", "voice4.pt"]
):
mock_tts_model.initialize.return_value = (mock_model, 4) # 4 voice files
mock_tts_model._device = "cpu" # Set device class variable
# Create an async generator from the lifespan context manager
async_gen = lifespan(MagicMock())
await async_gen.__aenter__()
# Verify model was initialized
mock_tts_model.initialize.assert_called_once()
# Clean up
await async_gen.__aexit__(None, None, None)

View file

@ -1,9 +1,12 @@
"""Tests for TTSService"""
import os
from unittest.mock import MagicMock, call, patch
import numpy as np
import pytest
from unittest.mock import patch, MagicMock, call
from api.src.services.tts import TTSService, TTSModel
from api.src.services.tts import TTSModel, TTSService
@pytest.fixture
@ -50,42 +53,59 @@ def test_audio_to_bytes(tts_service, sample_audio):
assert len(audio_bytes) > 0
@patch('os.listdir')
@patch('os.path.join')
@patch("os.listdir")
@patch("os.path.join")
def test_list_voices(mock_join, mock_listdir, tts_service):
"""Test listing available voices"""
mock_listdir.return_value = ['voice1.pt', 'voice2.pt', 'not_a_voice.txt']
mock_join.return_value = '/fake/path'
mock_listdir.return_value = ["voice1.pt", "voice2.pt", "not_a_voice.txt"]
mock_join.return_value = "/fake/path"
voices = tts_service.list_voices()
assert len(voices) == 2
assert 'voice1' in voices
assert 'voice2' in voices
assert 'not_a_voice' not in voices
assert "voice1" in voices
assert "voice2" in voices
assert "not_a_voice" not in voices
@patch('api.src.services.tts.TTSModel.get_instance')
@patch('api.src.services.tts.TTSModel.get_voicepack')
@patch('api.src.services.tts.normalize_text')
@patch('api.src.services.tts.phonemize')
@patch('api.src.services.tts.tokenize')
@patch('api.src.services.tts.generate')
def test_generate_audio_empty_text(mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_voicepack, mock_instance, tts_service):
@patch("api.src.services.tts.TTSModel.get_instance")
@patch("api.src.services.tts.TTSModel.get_voicepack")
@patch("api.src.services.tts.normalize_text")
@patch("api.src.services.tts.phonemize")
@patch("api.src.services.tts.tokenize")
@patch("api.src.services.tts.generate")
def test_generate_audio_empty_text(
mock_generate,
mock_tokenize,
mock_phonemize,
mock_normalize,
mock_voicepack,
mock_instance,
tts_service,
):
"""Test generating audio with empty text"""
mock_normalize.return_value = ""
with pytest.raises(ValueError, match="Text is empty after preprocessing"):
tts_service._generate_audio("", "af", 1.0)
@patch('api.src.services.tts.TTSModel.get_instance')
@patch('os.path.exists')
@patch('api.src.services.tts.normalize_text')
@patch('api.src.services.tts.phonemize')
@patch('api.src.services.tts.tokenize')
@patch('api.src.services.tts.generate')
@patch('torch.load')
def test_generate_audio_no_chunks(mock_torch_load, mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_exists, mock_instance, tts_service):
@patch("api.src.services.tts.TTSModel.get_instance")
@patch("os.path.exists")
@patch("api.src.services.tts.normalize_text")
@patch("api.src.services.tts.phonemize")
@patch("api.src.services.tts.tokenize")
@patch("api.src.services.tts.generate")
@patch("torch.load")
def test_generate_audio_no_chunks(
mock_torch_load,
mock_generate,
mock_tokenize,
mock_phonemize,
mock_normalize,
mock_exists,
mock_instance,
tts_service,
):
"""Test generating audio with no successful chunks"""
mock_normalize.return_value = "Test text"
mock_phonemize.return_value = "Test text"
@ -94,19 +114,29 @@ def test_generate_audio_no_chunks(mock_torch_load, mock_generate, mock_tokenize,
mock_instance.return_value = (MagicMock(), "cpu")
mock_exists.return_value = True
mock_torch_load.return_value = MagicMock()
with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
tts_service._generate_audio("Test text", "af", 1.0)
@patch('api.src.services.tts.TTSModel.get_instance')
@patch('os.path.exists')
@patch('api.src.services.tts.normalize_text')
@patch('api.src.services.tts.phonemize')
@patch('api.src.services.tts.tokenize')
@patch('api.src.services.tts.generate')
@patch('torch.load')
def test_generate_audio_success(mock_torch_load, mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_exists, mock_instance, tts_service, sample_audio):
@patch("api.src.services.tts.TTSModel.get_instance")
@patch("os.path.exists")
@patch("api.src.services.tts.normalize_text")
@patch("api.src.services.tts.phonemize")
@patch("api.src.services.tts.tokenize")
@patch("api.src.services.tts.generate")
@patch("torch.load")
def test_generate_audio_success(
mock_torch_load,
mock_generate,
mock_tokenize,
mock_phonemize,
mock_normalize,
mock_exists,
mock_instance,
tts_service,
sample_audio,
):
"""Test successful audio generation"""
mock_normalize.return_value = "Test text"
mock_phonemize.return_value = "Test text"
@ -115,15 +145,15 @@ def test_generate_audio_success(mock_torch_load, mock_generate, mock_tokenize, m
mock_instance.return_value = (MagicMock(), "cpu")
mock_exists.return_value = True
mock_torch_load.return_value = MagicMock()
audio, processing_time = tts_service._generate_audio("Test text", "af", 1.0)
assert isinstance(audio, np.ndarray)
assert isinstance(processing_time, float)
assert len(audio) > 0
@patch('api.src.services.tts.torch.cuda.is_available')
@patch('api.src.services.tts.build_model')
@patch("api.src.services.tts.torch.cuda.is_available")
@patch("api.src.services.tts.build_model")
def test_model_initialization_cuda(mock_build_model, mock_cuda_available):
"""Test model initialization with CUDA"""
mock_cuda_available.return_value = True
@ -132,14 +162,14 @@ def test_model_initialization_cuda(mock_build_model, mock_cuda_available):
TTSModel._instance = None # Reset singleton
model, voice_count = TTSModel.initialize()
assert TTSModel._device == "cuda" # Check the class variable instead
assert model == mock_model
mock_build_model.assert_called_once()
@patch('api.src.services.tts.torch.cuda.is_available')
@patch('api.src.services.tts.build_model')
@patch("api.src.services.tts.torch.cuda.is_available")
@patch("api.src.services.tts.build_model")
def test_model_initialization_cpu(mock_build_model, mock_cuda_available):
"""Test model initialization with CPU"""
mock_cuda_available.return_value = False
@ -148,76 +178,95 @@ def test_model_initialization_cpu(mock_build_model, mock_cuda_available):
TTSModel._instance = None # Reset singleton
model, voice_count = TTSModel.initialize()
assert TTSModel._device == "cpu" # Check the class variable instead
assert model == mock_model
mock_build_model.assert_called_once()
@patch('api.src.services.tts.TTSService._get_voice_path')
@patch('api.src.services.tts.TTSModel.get_instance')
@patch("api.src.services.tts.TTSService._get_voice_path")
@patch("api.src.services.tts.TTSModel.get_instance")
def test_voicepack_loading_error(mock_get_instance, mock_get_voice_path):
"""Test voicepack loading error handling"""
mock_get_voice_path.return_value = None
mock_get_instance.return_value = (MagicMock(), "cpu")
TTSModel._voicepacks = {} # Reset voicepacks
service = TTSService(start_worker=False)
with pytest.raises(ValueError, match="Voice not found: nonexistent_voice"):
service._generate_audio("test", "nonexistent_voice", 1.0)
@patch('api.src.services.tts.TTSModel')
@patch("api.src.services.tts.TTSModel")
def test_save_audio(mock_tts_model, tts_service, sample_audio, tmp_path):
"""Test saving audio to file"""
output_dir = os.path.join(tmp_path, "test_output")
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, "audio.wav")
tts_service._save_audio(sample_audio, output_path)
assert os.path.exists(output_path)
assert os.path.getsize(output_path) > 0
@patch('api.src.services.tts.TTSModel.get_instance')
@patch('os.path.exists')
@patch('api.src.services.tts.normalize_text')
@patch('api.src.services.tts.generate')
@patch('torch.load')
def test_generate_audio_without_stitching(mock_torch_load, mock_generate, mock_normalize, mock_exists, mock_instance, tts_service, sample_audio):
@patch("api.src.services.tts.TTSModel.get_instance")
@patch("os.path.exists")
@patch("api.src.services.tts.normalize_text")
@patch("api.src.services.tts.generate")
@patch("torch.load")
def test_generate_audio_without_stitching(
mock_torch_load,
mock_generate,
mock_normalize,
mock_exists,
mock_instance,
tts_service,
sample_audio,
):
"""Test generating audio without text stitching"""
mock_normalize.return_value = "Test text"
mock_generate.return_value = (sample_audio, None)
mock_instance.return_value = (MagicMock(), "cpu")
mock_exists.return_value = True
mock_torch_load.return_value = MagicMock()
audio, processing_time = tts_service._generate_audio("Test text", "af", 1.0, stitch_long_output=False)
audio, processing_time = tts_service._generate_audio(
"Test text", "af", 1.0, stitch_long_output=False
)
assert isinstance(audio, np.ndarray)
assert isinstance(processing_time, float)
assert len(audio) > 0
mock_generate.assert_called_once()
@patch('os.listdir')
@patch("os.listdir")
def test_list_voices_error(mock_listdir, tts_service):
"""Test error handling in list_voices"""
mock_listdir.side_effect = Exception("Failed to list directory")
voices = tts_service.list_voices()
assert voices == []
@patch('api.src.services.tts.TTSModel.get_instance')
@patch('os.path.exists')
@patch('api.src.services.tts.normalize_text')
@patch('api.src.services.tts.phonemize')
@patch('api.src.services.tts.tokenize')
@patch('api.src.services.tts.generate')
@patch('torch.load')
def test_generate_audio_phonemize_error(mock_torch_load, mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_exists, mock_instance, tts_service):
@patch("api.src.services.tts.TTSModel.get_instance")
@patch("os.path.exists")
@patch("api.src.services.tts.normalize_text")
@patch("api.src.services.tts.phonemize")
@patch("api.src.services.tts.tokenize")
@patch("api.src.services.tts.generate")
@patch("torch.load")
def test_generate_audio_phonemize_error(
mock_torch_load,
mock_generate,
mock_tokenize,
mock_phonemize,
mock_normalize,
mock_exists,
mock_instance,
tts_service,
):
"""Test handling phonemization error"""
mock_normalize.return_value = "Test text"
mock_phonemize.side_effect = Exception("Phonemization failed")
@ -225,23 +274,30 @@ def test_generate_audio_phonemize_error(mock_torch_load, mock_generate, mock_tok
mock_exists.return_value = True
mock_torch_load.return_value = MagicMock()
mock_generate.return_value = (None, None)
with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
tts_service._generate_audio("Test text", "af", 1.0)
@patch('api.src.services.tts.TTSModel.get_instance')
@patch('os.path.exists')
@patch('api.src.services.tts.normalize_text')
@patch('api.src.services.tts.generate')
@patch('torch.load')
def test_generate_audio_error(mock_torch_load, mock_generate, mock_normalize, mock_exists, mock_instance, tts_service):
@patch("api.src.services.tts.TTSModel.get_instance")
@patch("os.path.exists")
@patch("api.src.services.tts.normalize_text")
@patch("api.src.services.tts.generate")
@patch("torch.load")
def test_generate_audio_error(
mock_torch_load,
mock_generate,
mock_normalize,
mock_exists,
mock_instance,
tts_service,
):
"""Test handling generation error"""
mock_normalize.return_value = "Test text"
mock_generate.side_effect = Exception("Generation failed")
mock_instance.return_value = (MagicMock(), "cpu")
mock_exists.return_value = True
mock_torch_load.return_value = MagicMock()
with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
tts_service._generate_audio("Test text", "af", 1.0)

View file

@ -19,7 +19,6 @@ output_dir = Path(__file__).parent / "output"
output_dir.mkdir(exist_ok=True)
def test_voice(voice: str):
speech_file = output_dir / f"speech_{voice}.mp3"
print(f"\nTesting voice: {voice}")

View file

@ -1,21 +1,23 @@
#!/usr/bin/env python3
import argparse
import os
from typing import List, Optional, Dict, Tuple
import argparse
from typing import Dict, List, Tuple, Optional
import requests
import numpy as np
from scipy.io import wavfile
import requests
import matplotlib.pyplot as plt
from scipy.io import wavfile
def submit_combine_voices(voices: List[str], base_url: str = "http://localhost:8880") -> Optional[str]:
def submit_combine_voices(
voices: List[str], base_url: str = "http://localhost:8880"
) -> Optional[str]:
"""Combine multiple voices into a new voice.
Args:
voices: List of voice names to combine (e.g. ["af_bella", "af_sarah"])
base_url: API base URL
Returns:
Name of the combined voice (e.g. "af_bella_af_sarah") or None if error
"""
@ -23,7 +25,7 @@ def submit_combine_voices(voices: List[str], base_url: str = "http://localhost:8
response = requests.post(f"{base_url}/v1/audio/voices/combine", json=voices)
print(f"Response status: {response.status_code}")
print(f"Raw response: {response.text}")
# Accept both 200 and 201 as success
if response.status_code not in [200, 201]:
try:
@ -32,7 +34,7 @@ def submit_combine_voices(voices: List[str], base_url: str = "http://localhost:8
except:
print(f"Error combining voices: {response.text}")
return None
try:
data = response.json()
if "voices" in data:
@ -46,15 +48,20 @@ def submit_combine_voices(voices: List[str], base_url: str = "http://localhost:8
return None
def generate_speech(text: str, voice: str, base_url: str = "http://localhost:8880", output_file: str = "output.mp3") -> bool:
def generate_speech(
text: str,
voice: str,
base_url: str = "http://localhost:8880",
output_file: str = "output.mp3",
) -> bool:
"""Generate speech using specified voice.
Args:
text: Text to convert to speech
voice: Voice name to use
base_url: API base URL
output_file: Path to save audio file
Returns:
True if successful, False otherwise
"""
@ -65,22 +72,25 @@ def generate_speech(text: str, voice: str, base_url: str = "http://localhost:888
"input": text,
"voice": voice,
"speed": 1.0,
"response_format": "wav" # Use WAV for analysis
}
"response_format": "wav", # Use WAV for analysis
},
)
if response.status_code != 200:
error = response.json().get("detail", {}).get("message", response.text)
print(f"Error generating speech: {error}")
return False
# Save the audio
os.makedirs(os.path.dirname(output_file) if os.path.dirname(output_file) else ".", exist_ok=True)
os.makedirs(
os.path.dirname(output_file) if os.path.dirname(output_file) else ".",
exist_ok=True,
)
with open(output_file, "wb") as f:
f.write(response.content)
print(f"Saved audio to {output_file}")
return True
except Exception as e:
print(f"Error: {e}")
return False
@ -88,57 +98,57 @@ def generate_speech(text: str, voice: str, base_url: str = "http://localhost:888
def analyze_audio(filepath: str) -> Tuple[np.ndarray, int, dict]:
"""Analyze audio file and return samples, sample rate, and audio characteristics.
Args:
filepath: Path to audio file
Returns:
Tuple of (samples, sample_rate, characteristics)
"""
sample_rate, samples = wavfile.read(filepath)
# Convert to mono if stereo
if len(samples.shape) > 1:
samples = np.mean(samples, axis=1)
# Calculate basic stats
max_amp = np.max(np.abs(samples))
rms = np.sqrt(np.mean(samples**2))
duration = len(samples) / sample_rate
# Zero crossing rate (helps identify voice characteristics)
zero_crossings = np.sum(np.abs(np.diff(np.signbit(samples)))) / len(samples)
# Simple frequency analysis
if len(samples) > 0:
# Use FFT to get frequency components
fft_result = np.fft.fft(samples)
freqs = np.fft.fftfreq(len(samples), 1/sample_rate)
freqs = np.fft.fftfreq(len(samples), 1 / sample_rate)
# Get positive frequencies only
pos_mask = freqs > 0
freqs = freqs[pos_mask]
magnitudes = np.abs(fft_result)[pos_mask]
# Find dominant frequencies (top 3)
top_indices = np.argsort(magnitudes)[-3:]
dominant_freqs = freqs[top_indices]
# Calculate spectral centroid (brightness of sound)
spectral_centroid = np.sum(freqs * magnitudes) / np.sum(magnitudes)
else:
dominant_freqs = []
spectral_centroid = 0
characteristics = {
"max_amplitude": max_amp,
"rms": rms,
"duration": duration,
"zero_crossing_rate": zero_crossings,
"dominant_frequencies": dominant_freqs,
"spectral_centroid": spectral_centroid
"spectral_centroid": spectral_centroid,
}
return samples, sample_rate, characteristics
@ -167,112 +177,136 @@ def setup_plot(fig, ax, title):
return fig, ax
def plot_analysis(audio_files: Dict[str, str], output_dir: str):
"""Plot comprehensive voice analysis including waveforms and metrics comparison.
Args:
audio_files: Dictionary of label -> filepath
output_dir: Directory to save plot files
"""
# Set dark style
plt.style.use('dark_background')
plt.style.use("dark_background")
# Create figure with subplots
fig = plt.figure(figsize=(15, 15))
fig.patch.set_facecolor("#1a1a2e")
num_files = len(audio_files)
# Create subplot grid with proper spacing
gs = plt.GridSpec(num_files + 1, 2, height_ratios=[1.5]*num_files + [1],
hspace=0.4, wspace=0.3)
gs = plt.GridSpec(
num_files + 1, 2, height_ratios=[1.5] * num_files + [1], hspace=0.4, wspace=0.3
)
# Analyze all files first
all_chars = {}
for i, (label, filepath) in enumerate(audio_files.items()):
samples, sample_rate, chars = analyze_audio(filepath)
all_chars[label] = chars
# Plot waveform spanning both columns
ax = plt.subplot(gs[i, :])
time = np.arange(len(samples)) / sample_rate
plt.plot(time, samples / chars['max_amplitude'], linewidth=0.5, color="#ff2a6d")
plt.plot(time, samples / chars["max_amplitude"], linewidth=0.5, color="#ff2a6d")
ax.set_xlabel("Time (seconds)")
ax.set_ylabel("Normalized Amplitude")
ax.set_ylim(-1.1, 1.1)
setup_plot(fig, ax, f"Waveform: {label}")
# Colors for voices
colors = ["#ff2a6d", "#05d9e8", "#d1f7ff"]
# Create two subplots for metrics with similar scales
# Left subplot: Brightness and Volume
ax1 = plt.subplot(gs[num_files, 0])
metrics1 = [
('Brightness', [chars['spectral_centroid']/1000 for chars in all_chars.values()], 'kHz'),
('Volume', [chars['rms']*100 for chars in all_chars.values()], 'RMS×100')
(
"Brightness",
[chars["spectral_centroid"] / 1000 for chars in all_chars.values()],
"kHz",
),
("Volume", [chars["rms"] * 100 for chars in all_chars.values()], "RMS×100"),
]
# Right subplot: Voice Pitch and Texture
ax2 = plt.subplot(gs[num_files, 1])
metrics2 = [
('Voice Pitch', [min(chars['dominant_frequencies']) for chars in all_chars.values()], 'Hz'),
('Texture', [chars['zero_crossing_rate']*1000 for chars in all_chars.values()], 'ZCR×1000')
(
"Voice Pitch",
[min(chars["dominant_frequencies"]) for chars in all_chars.values()],
"Hz",
),
(
"Texture",
[chars["zero_crossing_rate"] * 1000 for chars in all_chars.values()],
"ZCR×1000",
),
]
def plot_grouped_bars(ax, metrics, show_legend=True):
n_groups = len(metrics)
n_voices = len(audio_files)
bar_width = 0.25
indices = np.arange(n_groups)
# Get max value for y-axis scaling
max_val = max(max(m[1]) for m in metrics)
for i, (voice, color) in enumerate(zip(audio_files.keys(), colors)):
values = [m[1][i] for m in metrics]
offset = (i - n_voices/2 + 0.5) * bar_width
bars = ax.bar(indices + offset, values, bar_width,
label=voice, color=color, alpha=0.8)
offset = (i - n_voices / 2 + 0.5) * bar_width
bars = ax.bar(
indices + offset, values, bar_width, label=voice, color=color, alpha=0.8
)
# Add value labels on top of bars
for bar in bars:
height = bar.get_height()
ax.text(bar.get_x() + bar.get_width()/2., height,
f'{height:.1f}',
ha='center', va='bottom', color='white',
fontsize=10)
ax.text(
bar.get_x() + bar.get_width() / 2.0,
height,
f"{height:.1f}",
ha="center",
va="bottom",
color="white",
fontsize=10,
)
ax.set_xticks(indices)
ax.set_xticklabels([f"{m[0]}\n({m[2]})" for m in metrics])
# Set y-axis limits with some padding
ax.set_ylim(0, max_val * 1.2)
if show_legend:
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left',
facecolor="#1a1a2e", edgecolor="#ffffff")
ax.legend(
bbox_to_anchor=(1.05, 1),
loc="upper left",
facecolor="#1a1a2e",
edgecolor="#ffffff",
)
# Plot both subplots
plot_grouped_bars(ax1, metrics1, show_legend=True)
plot_grouped_bars(ax2, metrics2, show_legend=False)
# Style both subplots
setup_plot(fig, ax1, 'Brightness and Volume')
setup_plot(fig, ax2, 'Voice Pitch and Texture')
setup_plot(fig, ax1, "Brightness and Volume")
setup_plot(fig, ax2, "Voice Pitch and Texture")
# Add y-axis labels
ax1.set_ylabel('Value')
ax2.set_ylabel('Value')
ax1.set_ylabel("Value")
ax2.set_ylabel("Value")
# Adjust the figure size to accommodate the legend
fig.set_size_inches(15, 15)
# Add padding around the entire figure
plt.subplots_adjust(right=0.85, top=0.95, bottom=0.05, left=0.1)
plt.savefig(os.path.join(output_dir, "analysis_comparison.png"), dpi=300)
print(f"Saved analysis comparison to {output_dir}/analysis_comparison.png")
# Print detailed comparative analysis
print("\nDetailed Voice Analysis:")
for label, chars in all_chars.items():
@ -282,44 +316,57 @@ def plot_analysis(audio_files: Dict[str, str], output_dir: str):
print(f" Duration: {chars['duration']:.2f}s")
print(f" Zero Crossing Rate: {chars['zero_crossing_rate']:.3f}")
print(f" Spectral Centroid: {chars['spectral_centroid']:.0f}Hz")
print(f" Dominant Frequencies: {', '.join(f'{f:.0f}Hz' for f in chars['dominant_frequencies'])}")
print(
f" Dominant Frequencies: {', '.join(f'{f:.0f}Hz' for f in chars['dominant_frequencies'])}"
)
def main():
parser = argparse.ArgumentParser(description="Kokoro Voice Analysis Demo")
parser.add_argument("--voices", nargs="+", type=str, help="Voices to combine")
parser.add_argument("--text", type=str, default="Hello! This is a test of combined voices.", help="Text to speak")
parser.add_argument(
"--text",
type=str,
default="Hello! This is a test of combined voices.",
help="Text to speak",
)
parser.add_argument("--url", default="http://localhost:8880", help="API base URL")
parser.add_argument("--output-dir", default="examples/output", help="Output directory for audio files")
parser.add_argument(
"--output-dir",
default="examples/output",
help="Output directory for audio files",
)
args = parser.parse_args()
if not args.voices:
print("No voices provided, using default test voices")
args.voices = ["af_bella", "af_nicole"]
# Create output directory
os.makedirs(args.output_dir, exist_ok=True)
# Dictionary to store audio files for analysis
audio_files = {}
# Generate speech with individual voices
print("Generating speech with individual voices...")
for voice in args.voices:
output_file = os.path.join(args.output_dir, f"analysis_{voice}.wav")
if generate_speech(args.text, voice, args.url, output_file):
audio_files[voice] = output_file
# Generate speech with combined voice
print(f"\nCombining voices: {', '.join(args.voices)}")
combined_voice = submit_combine_voices(args.voices, args.url)
if combined_voice:
print(f"Successfully created combined voice: {combined_voice}")
output_file = os.path.join(args.output_dir, f"analysis_combined_{combined_voice}.wav")
output_file = os.path.join(
args.output_dir, f"analysis_combined_{combined_voice}.wav"
)
if generate_speech(args.text, combined_voice, args.url, output_file):
audio_files["combined"] = output_file
# Generate comparison plots
plot_analysis(audio_files, args.output_dir)
else:

View file

@ -60,7 +60,7 @@ def test_speed(speed: float):
# Test different formats
for format in ["wav", "mp3", "opus", "aac", "flac", "pcm"]:
test_format(format) # aac and pcm should fail as they are not supported
test_format(format) # aac and pcm should fail as they are not supported
# Test different speeds
for speed in [0.25, 1.0, 2.0, 4.0]: # 5.0 should fail as it's out of range

View file

@ -10,3 +10,5 @@ sqlalchemy==2.0.27
pytest==8.0.0
httpx==0.26.0
pytest-asyncio==0.23.5
pytest-cov==6.0.0
gradio==4.19.2

View file

@ -2,8 +2,4 @@ from lib.interface import create_interface
if __name__ == "__main__":
demo = create_interface()
demo.launch(
server_name="0.0.0.0",
server_port=7860,
show_error=True
)
demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)

View file

@ -1,16 +1,19 @@
import requests
from typing import Tuple, List, Optional
import os
import datetime
from typing import List, Tuple, Optional
import requests
from .config import API_URL, OUTPUTS_DIR
def check_api_status() -> Tuple[bool, List[str]]:
"""Check TTS service status and get available voices."""
try:
# Use a longer timeout during startup
response = requests.get(
f"{API_URL}/v1/audio/voices",
timeout=30 # Increased timeout for initial startup period
timeout=30, # Increased timeout for initial startup period
)
response.raise_for_status()
voices = response.json().get("voices", [])
@ -31,16 +34,19 @@ def check_api_status() -> Tuple[bool, List[str]]:
print(f"Unexpected error checking API status: {str(e)}")
return False, []
def text_to_speech(text: str, voice_id: str, format: str, speed: float) -> Optional[str]:
def text_to_speech(
text: str, voice_id: str, format: str, speed: float
) -> Optional[str]:
"""Generate speech from text using TTS API."""
if not text.strip():
return None
# Create output filename
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
output_filename = f"output_{timestamp}_voice-{voice_id}_speed-{speed}.{format}"
output_path = os.path.join(OUTPUTS_DIR, output_filename)
try:
response = requests.post(
f"{API_URL}/v1/audio/speech",
@ -49,17 +55,17 @@ def text_to_speech(text: str, voice_id: str, format: str, speed: float) -> Optio
"input": text,
"voice": voice_id,
"response_format": format,
"speed": float(speed)
"speed": float(speed),
},
headers={"Content-Type": "application/json"},
timeout=300 # Longer timeout for speech generation
timeout=300, # Longer timeout for speech generation
)
response.raise_for_status()
with open(output_path, "wb") as f:
f.write(response.content)
return output_path
except requests.exceptions.Timeout:
print("Speech generation request timed out")
return None
@ -70,6 +76,7 @@ def text_to_speech(text: str, voice_id: str, format: str, speed: float) -> Optio
print(f"Unexpected error generating speech: {str(e)}")
return None
def get_status_html(is_available: bool) -> str:
"""Generate HTML for status indicator."""
color = "green" if is_available else "red"

View file

@ -1,7 +1,10 @@
import gradio as gr
from typing import Tuple
import gradio as gr
from .. import files
def create_input_column() -> Tuple[gr.Column, dict]:
"""Create the input column with text input and file handling."""
with gr.Column(scale=1) as col:
@ -11,49 +14,36 @@ def create_input_column() -> Tuple[gr.Column, dict]:
# Direct Input Tab
with gr.TabItem("Direct Input"):
text_input = gr.Textbox(
label="Text to speak",
placeholder="Enter text here...",
lines=4
label="Text to speak", placeholder="Enter text here...", lines=4
)
text_submit = gr.Button(
"Generate Speech",
variant="primary",
size="lg"
)
text_submit = gr.Button("Generate Speech", variant="primary", size="lg")
# File Input Tab
with gr.TabItem("From File"):
# Existing files dropdown
input_files_list = gr.Dropdown(
label="Select Existing File",
choices=files.list_input_files(),
value=None
value=None,
)
# Simple file upload
file_upload = gr.File(
label="Upload Text File (.txt)",
file_types=[".txt"]
label="Upload Text File (.txt)", file_types=[".txt"]
)
file_preview = gr.Textbox(
label="File Content Preview",
interactive=False,
lines=4
label="File Content Preview", interactive=False, lines=4
)
with gr.Row():
file_submit = gr.Button(
"Generate Speech",
variant="primary",
size="lg"
"Generate Speech", variant="primary", size="lg"
)
clear_files = gr.Button(
"Clear Files",
variant="secondary",
size="lg"
"Clear Files", variant="secondary", size="lg"
)
components = {
"tabs": tabs,
"text_input": text_input,
@ -62,7 +52,7 @@ def create_input_column() -> Tuple[gr.Column, dict]:
"file_preview": file_preview,
"text_submit": text_submit,
"file_submit": file_submit,
"clear_files": clear_files
"clear_files": clear_files,
}
return col, components

View file

@ -1,45 +1,41 @@
import gradio as gr
from typing import Tuple, Optional
import gradio as gr
from .. import api, config
def create_model_column(voice_ids: Optional[list] = None) -> Tuple[gr.Column, dict]:
"""Create the model settings column."""
if voice_ids is None:
voice_ids = []
with gr.Column(scale=1) as col:
gr.Markdown("### Model Settings")
# Status button starts in waiting state
status_btn = gr.Button(
"⌛ TTS Service: Waiting for Service...",
variant="secondary"
"⌛ TTS Service: Waiting for Service...", variant="secondary"
)
voice_input = gr.Dropdown(
choices=voice_ids,
label="Voice",
value=voice_ids[0] if voice_ids else None,
interactive=True
interactive=True,
)
format_input = gr.Dropdown(
choices=config.AUDIO_FORMATS,
label="Audio Format",
value="mp3"
choices=config.AUDIO_FORMATS, label="Audio Format", value="mp3"
)
speed_input = gr.Slider(
minimum=0.5,
maximum=2.0,
value=1.0,
step=0.1,
label="Speed"
minimum=0.5, maximum=2.0, value=1.0, step=0.1, label="Speed"
)
components = {
"status_btn": status_btn,
"voice": voice_input,
"format": format_input,
"speed": speed_input
"speed": speed_input,
}
return col, components

View file

@ -1,40 +1,42 @@
import gradio as gr
from typing import Tuple
import gradio as gr
from .. import files
def create_output_column() -> Tuple[gr.Column, dict]:
"""Create the output column with audio player and file list."""
with gr.Column(scale=1) as col:
gr.Markdown("### Latest Output")
audio_output = gr.Audio(
label="Generated Speech",
type="filepath"
)
audio_output = gr.Audio(label="Generated Speech", type="filepath")
gr.Markdown("### Generated Files")
output_files = gr.Dropdown(
label="Previous Outputs",
choices=files.list_output_files(),
value=None,
allow_custom_value=False
allow_custom_value=False,
)
play_btn = gr.Button("▶️ Play Selected", size="sm")
selected_audio = gr.Audio(
label="Selected Output",
type="filepath",
visible=False
label="Selected Output", type="filepath", visible=False
)
clear_outputs = gr.Button("⚠️ Delete All Previously Generated Output Audio 🗑️", size="sm", variant="secondary")
clear_outputs = gr.Button(
"⚠️ Delete All Previously Generated Output Audio 🗑️",
size="sm",
variant="secondary",
)
components = {
"audio_output": audio_output,
"output_files": output_files,
"play_btn": play_btn,
"selected_audio": selected_audio,
"clear_outputs": clear_outputs
"clear_outputs": clear_outputs,
}
return col, components

View file

@ -1,17 +1,23 @@
import os
from typing import List, Optional, Tuple
import datetime
from typing import List, Tuple, Optional
from .config import INPUTS_DIR, OUTPUTS_DIR, AUDIO_FORMATS
def list_input_files() -> List[str]:
"""List all input text files."""
return [f for f in os.listdir(INPUTS_DIR) if f.endswith('.txt')]
return [f for f in os.listdir(INPUTS_DIR) if f.endswith(".txt")]
def list_output_files() -> List[str]:
"""List all output audio files."""
return [os.path.join(OUTPUTS_DIR, f)
for f in os.listdir(OUTPUTS_DIR)
if any(f.endswith(ext) for ext in AUDIO_FORMATS)]
return [
os.path.join(OUTPUTS_DIR, f)
for f in os.listdir(OUTPUTS_DIR)
if any(f.endswith(ext) for ext in AUDIO_FORMATS)
]
def read_text_file(filename: str) -> str:
"""Read content of a text file."""
@ -19,16 +25,17 @@ def read_text_file(filename: str) -> str:
return ""
try:
file_path = os.path.join(INPUTS_DIR, filename)
with open(file_path, 'r', encoding='utf-8') as f:
with open(file_path, "r", encoding="utf-8") as f:
return f.read()
except:
return ""
def save_text(text: str, filename: Optional[str] = None) -> Optional[str]:
"""Save text to a file. Returns the filename if successful."""
if not text.strip():
return None
if filename is None:
# Use input_1.txt, input_2.txt, etc.
base = "input"
@ -41,12 +48,12 @@ def save_text(text: str, filename: Optional[str] = None) -> Optional[str]:
else:
# Handle duplicate filenames by adding _1, _2, etc.
base = os.path.splitext(filename)[0]
ext = os.path.splitext(filename)[1] or '.txt'
ext = os.path.splitext(filename)[1] or ".txt"
counter = 1
while os.path.exists(os.path.join(INPUTS_DIR, filename)):
filename = f"{base}_{counter}{ext}"
counter += 1
filepath = os.path.join(INPUTS_DIR, filename)
try:
with open(filepath, "w", encoding="utf-8") as f:
@ -56,11 +63,12 @@ def save_text(text: str, filename: Optional[str] = None) -> Optional[str]:
print(f"Error saving file: {e}")
return None
def delete_all_input_files() -> bool:
"""Delete all files from the inputs directory. Returns True if successful."""
try:
for filename in os.listdir(INPUTS_DIR):
if filename.endswith('.txt'):
if filename.endswith(".txt"):
file_path = os.path.join(INPUTS_DIR, filename)
os.remove(file_path)
return True
@ -68,6 +76,7 @@ def delete_all_input_files() -> bool:
print(f"Error deleting input files: {e}")
return False
def delete_all_output_files() -> bool:
"""Delete all audio files from the outputs directory. Returns True if successful."""
try:
@ -80,19 +89,20 @@ def delete_all_output_files() -> bool:
print(f"Error deleting output files: {e}")
return False
def process_uploaded_file(file_path: str) -> bool:
"""Save uploaded file to inputs directory. Returns True if successful."""
if not file_path:
return False
try:
filename = os.path.basename(file_path)
if not filename.endswith('.txt'):
if not filename.endswith(".txt"):
return False
# Create target path in inputs directory
target_path = os.path.join(INPUTS_DIR, filename)
# If file exists, add number suffix
base, ext = os.path.splitext(filename)
counter = 1
@ -100,12 +110,13 @@ def process_uploaded_file(file_path: str) -> bool:
new_name = f"{base}_{counter}{ext}"
target_path = os.path.join(INPUTS_DIR, new_name)
counter += 1
# Copy file to inputs directory
import shutil
shutil.copy2(file_path, target_path)
return True
except Exception as e:
print(f"Error saving uploaded file: {e}")
return False

View file

@ -1,16 +1,19 @@
import gradio as gr
import os
import shutil
import gradio as gr
from . import api, files
def setup_event_handlers(components: dict):
"""Set up all event handlers for the UI components."""
def refresh_status():
try:
is_available, voices = api.check_api_status()
status = "Available" if is_available else "Waiting for Service..."
if is_available and voices:
# Preserve current voice selection if it exists and is still valid
current_voice = components["model"]["voice"].value
@ -19,17 +22,17 @@ def setup_event_handlers(components: dict):
gr.update(
value=f"🔄 TTS Service: {status}",
interactive=True,
variant="secondary"
variant="secondary",
),
gr.update(choices=voices, value=default_voice)
gr.update(choices=voices, value=default_voice),
]
return [
gr.update(
value=f"⌛ TTS Service: {status}",
interactive=True,
variant="secondary"
variant="secondary",
),
gr.update(choices=[], value=None)
gr.update(choices=[], value=None),
]
except Exception as e:
print(f"Error in refresh status: {str(e)}")
@ -37,11 +40,11 @@ def setup_event_handlers(components: dict):
gr.update(
value="❌ TTS Service: Connection Error",
interactive=True,
variant="secondary"
variant="secondary",
),
gr.update(choices=[], value=None)
gr.update(choices=[], value=None),
]
def handle_file_select(filename):
if filename:
try:
@ -52,16 +55,16 @@ def setup_event_handlers(components: dict):
except Exception as e:
print(f"Error reading file: {e}")
return gr.update(value="")
def handle_file_upload(file):
if file is None:
return gr.update(choices=files.list_input_files())
try:
# Copy file to inputs directory
filename = os.path.basename(file.name)
target_path = os.path.join(files.INPUTS_DIR, filename)
# Handle duplicate filenames
base, ext = os.path.splitext(filename)
counter = 1
@ -69,43 +72,36 @@ def setup_event_handlers(components: dict):
new_name = f"{base}_{counter}{ext}"
target_path = os.path.join(files.INPUTS_DIR, new_name)
counter += 1
shutil.copy2(file.name, target_path)
except Exception as e:
print(f"Error uploading file: {e}")
return gr.update(choices=files.list_input_files())
def generate_from_text(text, voice, format, speed):
"""Generate speech from direct text input"""
is_available, _ = api.check_api_status()
if not is_available:
gr.Warning("TTS Service is currently unavailable")
return [
None,
gr.update(choices=files.list_output_files())
]
return [None, gr.update(choices=files.list_output_files())]
if not text or not text.strip():
gr.Warning("Please enter text in the input box")
return [
None,
gr.update(choices=files.list_output_files())
]
return [None, gr.update(choices=files.list_output_files())]
files.save_text(text)
result = api.text_to_speech(text, voice, format, speed)
if result is None:
gr.Warning("Failed to generate speech. Please try again.")
return [
None,
gr.update(choices=files.list_output_files())
]
return [None, gr.update(choices=files.list_output_files())]
return [
result,
gr.update(choices=files.list_output_files(), value=os.path.basename(result))
gr.update(
choices=files.list_output_files(), value=os.path.basename(result)
),
]
def generate_from_file(selected_file, voice, format, speed):
@ -113,37 +109,30 @@ def setup_event_handlers(components: dict):
is_available, _ = api.check_api_status()
if not is_available:
gr.Warning("TTS Service is currently unavailable")
return [
None,
gr.update(choices=files.list_output_files())
]
return [None, gr.update(choices=files.list_output_files())]
if not selected_file:
gr.Warning("Please select a file")
return [
None,
gr.update(choices=files.list_output_files())
]
return [None, gr.update(choices=files.list_output_files())]
text = files.read_text_file(selected_file)
result = api.text_to_speech(text, voice, format, speed)
if result is None:
gr.Warning("Failed to generate speech. Please try again.")
return [
None,
gr.update(choices=files.list_output_files())
]
return [None, gr.update(choices=files.list_output_files())]
return [
result,
gr.update(choices=files.list_output_files(), value=os.path.basename(result))
gr.update(
choices=files.list_output_files(), value=os.path.basename(result)
),
]
def play_selected(file_path):
if file_path and os.path.exists(file_path):
return gr.update(value=file_path, visible=True)
return gr.update(visible=False)
def clear_files(voice, format, speed):
"""Delete all input files and clear UI components while preserving model settings"""
files.delete_all_input_files()
@ -155,7 +144,7 @@ def setup_event_handlers(components: dict):
gr.update(choices=files.list_output_files()), # output_files
gr.update(value=voice), # voice
gr.update(value=format), # format
gr.update(value=speed) # speed
gr.update(value=speed), # speed
]
def clear_outputs():
@ -164,43 +153,40 @@ def setup_event_handlers(components: dict):
return [
None, # audio_output
gr.update(choices=[], value=None), # output_files
gr.update(visible=False) # selected_audio
gr.update(visible=False), # selected_audio
]
# Connect event handlers
components["model"]["status_btn"].click(
fn=refresh_status,
outputs=[
components["model"]["status_btn"],
components["model"]["voice"]
]
outputs=[components["model"]["status_btn"], components["model"]["voice"]],
)
components["input"]["file_select"].change(
fn=handle_file_select,
inputs=[components["input"]["file_select"]],
outputs=[components["input"]["file_preview"]]
outputs=[components["input"]["file_preview"]],
)
components["input"]["file_upload"].upload(
fn=handle_file_upload,
inputs=[components["input"]["file_upload"]],
outputs=[components["input"]["file_select"]]
outputs=[components["input"]["file_select"]],
)
components["output"]["play_btn"].click(
fn=play_selected,
inputs=[components["output"]["output_files"]],
outputs=[components["output"]["selected_audio"]]
outputs=[components["output"]["selected_audio"]],
)
# Connect clear files button
components["input"]["clear_files"].click(
fn=clear_files,
inputs=[
components["model"]["voice"],
components["model"]["format"],
components["model"]["speed"]
components["model"]["speed"],
],
outputs=[
components["input"]["file_select"],
@ -210,10 +196,10 @@ def setup_event_handlers(components: dict):
components["output"]["output_files"],
components["model"]["voice"],
components["model"]["format"],
components["model"]["speed"]
]
components["model"]["speed"],
],
)
# Connect submit buttons for each tab
components["input"]["text_submit"].click(
fn=generate_from_text,
@ -221,22 +207,22 @@ def setup_event_handlers(components: dict):
components["input"]["text_input"],
components["model"]["voice"],
components["model"]["format"],
components["model"]["speed"]
components["model"]["speed"],
],
outputs=[
components["output"]["audio_output"],
components["output"]["output_files"]
]
components["output"]["output_files"],
],
)
# Connect clear outputs button
components["output"]["clear_outputs"].click(
fn=clear_outputs,
outputs=[
components["output"]["audio_output"],
components["output"]["output_files"],
components["output"]["selected_audio"]
]
components["output"]["selected_audio"],
],
)
components["input"]["file_submit"].click(
@ -245,10 +231,10 @@ def setup_event_handlers(components: dict):
components["input"]["file_select"],
components["model"]["voice"],
components["model"]["format"],
components["model"]["speed"]
components["model"]["speed"],
],
outputs=[
components["output"]["audio_output"],
components["output"]["output_files"]
]
components["output"]["output_files"],
],
)

View file

@ -1,69 +1,75 @@
import gradio as gr
from . import api
from .components import create_input_column, create_model_column, create_output_column
from .handlers import setup_event_handlers
from .components import create_input_column, create_model_column, create_output_column
def create_interface():
"""Create the main Gradio interface."""
# Skip initial status check - let the timer handle it
is_available, available_voices = False, []
with gr.Blocks(
title="Kokoro TTS Demo",
theme=gr.themes.Monochrome()
) as demo:
gr.HTML(value='<div style="display: flex; gap: 0;">'
'<a href="https://huggingface.co/hexgrad/Kokoro-82M" target="_blank" style="color: #2196F3; text-decoration: none; margin: 2px; border: 1px solid #2196F3; padding: 4px 8px; height: 24px; box-sizing: border-box; display: inline-flex; align-items: center;">Kokoro-82M HF Repo</a>'
'<a href="https://github.com/remsky/Kokoro-FastAPI" target="_blank" style="color: #2196F3; text-decoration: none; margin: 2px; border: 1px solid #2196F3; padding: 4px 8px; height: 24px; box-sizing: border-box; display: inline-flex; align-items: center;">Kokoro-FastAPI Repo</a>'
'</div>', show_label=False)
with gr.Blocks(title="Kokoro TTS Demo", theme=gr.themes.Monochrome()) as demo:
gr.HTML(
value='<div style="display: flex; gap: 0;">'
'<a href="https://huggingface.co/hexgrad/Kokoro-82M" target="_blank" style="color: #2196F3; text-decoration: none; margin: 2px; border: 1px solid #2196F3; padding: 4px 8px; height: 24px; box-sizing: border-box; display: inline-flex; align-items: center;">Kokoro-82M HF Repo</a>'
'<a href="https://github.com/remsky/Kokoro-FastAPI" target="_blank" style="color: #2196F3; text-decoration: none; margin: 2px; border: 1px solid #2196F3; padding: 4px 8px; height: 24px; box-sizing: border-box; display: inline-flex; align-items: center;">Kokoro-FastAPI Repo</a>'
"</div>",
show_label=False,
)
# Main interface
with gr.Row():
# Create columns
input_col, input_components = create_input_column()
model_col, model_components = create_model_column(available_voices) # Pass initial voices
model_col, model_components = create_model_column(
available_voices
) # Pass initial voices
output_col, output_components = create_output_column()
# Collect all components
components = {
"input": input_components,
"model": model_components,
"output": output_components
"output": output_components,
}
# Set up event handlers
setup_event_handlers(components)
# Add periodic status check with Timer
def update_status():
try:
is_available, voices = api.check_api_status()
status = "Available" if is_available else "Waiting for Service..."
if is_available and voices:
# Service is available, update UI and stop timer
current_voice = components["model"]["voice"].value
default_voice = current_voice if current_voice in voices else voices[0]
default_voice = (
current_voice if current_voice in voices else voices[0]
)
# Return values in same order as outputs list
return [
gr.update(
value=f"🔄 TTS Service: {status}",
interactive=True,
variant="secondary"
variant="secondary",
),
gr.update(choices=voices, value=default_voice),
gr.update(active=False) # Stop timer
gr.update(active=False), # Stop timer
]
# Service not available yet, keep checking
return [
gr.update(
value=f"⌛ TTS Service: {status}",
interactive=True,
variant="secondary"
variant="secondary",
),
gr.update(choices=[], value=None),
gr.update(active=True)
gr.update(active=True),
]
except Exception as e:
print(f"Error in status update: {str(e)}")
@ -72,20 +78,20 @@ def create_interface():
gr.update(
value="❌ TTS Service: Connection Error",
interactive=True,
variant="secondary"
variant="secondary",
),
gr.update(choices=[], value=None),
gr.update(active=True)
gr.update(active=True),
]
timer = gr.Timer(value=5) # Check every 5 seconds
timer.tick(
fn=update_status,
outputs=[
components["model"]["status_btn"],
components["model"]["voice"],
timer
]
timer,
],
)
return demo

View file

@ -1,5 +1,5 @@
import pytest
import gradio as gr
import pytest
@pytest.fixture

View file

@ -1,6 +1,8 @@
from unittest.mock import patch, mock_open
import pytest
import requests
from unittest.mock import patch, mock_open
from ui.lib import api
@ -57,12 +59,11 @@ def test_check_api_status_connection_error():
def test_text_to_speech_success(mock_response, tmp_path):
"""Test successful speech generation"""
with patch("requests.post", return_value=mock_response({})), \
patch("ui.lib.api.OUTPUTS_DIR", str(tmp_path)), \
patch("builtins.open", mock_open()) as mock_file:
with patch("requests.post", return_value=mock_response({})), patch(
"ui.lib.api.OUTPUTS_DIR", str(tmp_path)
), patch("builtins.open", mock_open()) as mock_file:
result = api.text_to_speech("test text", "voice1", "mp3", 1.0)
assert result is not None
assert "output_" in result
assert result.endswith(".mp3")
@ -105,25 +106,24 @@ def test_get_status_html_unavailable():
def test_text_to_speech_api_params(mock_response, tmp_path):
"""Test correct API parameters are sent"""
with patch("requests.post") as mock_post, \
patch("ui.lib.api.OUTPUTS_DIR", str(tmp_path)), \
patch("builtins.open", mock_open()):
with patch("requests.post") as mock_post, patch(
"ui.lib.api.OUTPUTS_DIR", str(tmp_path)
), patch("builtins.open", mock_open()):
mock_post.return_value = mock_response({})
api.text_to_speech("test text", "voice1", "mp3", 1.5)
mock_post.assert_called_once()
args, kwargs = mock_post.call_args
# Check request body
assert kwargs["json"] == {
"model": "kokoro",
"input": "test text",
"voice": "voice1",
"response_format": "mp3",
"speed": 1.5
"speed": 1.5,
}
# Check headers and timeout
assert kwargs["headers"] == {"Content-Type": "application/json"}
assert kwargs["timeout"] == 300

View file

@ -1,8 +1,9 @@
import pytest
import gradio as gr
import pytest
from ui.lib.config import AUDIO_FORMATS
from ui.lib.components.model import create_model_column
from ui.lib.components.output import create_output_column
from ui.lib.config import AUDIO_FORMATS
def test_create_model_column_structure():
@ -15,12 +16,7 @@ def test_create_model_column_structure():
assert isinstance(components, dict)
# Test expected components presence
expected_components = {
"status_btn",
"voice",
"format",
"speed"
}
expected_components = {"status_btn", "voice", "format", "speed"}
assert set(components.keys()) == expected_components
# Test component types
@ -78,7 +74,7 @@ def test_create_output_column_structure():
"output_files",
"play_btn",
"selected_audio",
"clear_outputs"
"clear_outputs",
}
assert set(components.keys()) == expected_components

View file

@ -1,6 +1,8 @@
import os
import pytest
from unittest.mock import patch
import pytest
from ui.lib import files
from ui.lib.config import AUDIO_FORMATS

View file

@ -1,5 +1,6 @@
import pytest
import gradio as gr
import pytest
from ui.lib.components.input import create_input_column

View file

@ -1,12 +1,15 @@
import pytest
from unittest.mock import MagicMock, PropertyMock, patch
import gradio as gr
from unittest.mock import patch, MagicMock, PropertyMock
import pytest
from ui.lib.interface import create_interface
@pytest.fixture
def mock_timer():
"""Create a mock timer with events property"""
class MockEvent:
def __init__(self, fn):
self.fn = fn
@ -30,7 +33,7 @@ def test_create_interface_structure():
"""Test the basic structure of the created interface"""
with patch("ui.lib.api.check_api_status", return_value=(False, [])):
demo = create_interface()
# Test interface type and theme
assert isinstance(demo, gr.Blocks)
assert demo.title == "Kokoro TTS Demo"
@ -41,15 +44,14 @@ def test_interface_html_links():
"""Test that HTML links are properly configured"""
with patch("ui.lib.api.check_api_status", return_value=(False, [])):
demo = create_interface()
# Find HTML component
html_components = [
comp for comp in demo.blocks.values()
if isinstance(comp, gr.HTML)
comp for comp in demo.blocks.values() if isinstance(comp, gr.HTML)
]
assert len(html_components) > 0
html = html_components[0]
# Check for required links
assert 'href="https://huggingface.co/hexgrad/Kokoro-82M"' in html.value
assert 'href="https://github.com/remsky/Kokoro-FastAPI"' in html.value
@ -60,16 +62,17 @@ def test_interface_html_links():
def test_update_status_available(mock_timer):
"""Test status update when service is available"""
voices = ["voice1", "voice2"]
with patch("ui.lib.api.check_api_status", return_value=(True, voices)), \
patch("gradio.Timer", return_value=mock_timer):
with patch("ui.lib.api.check_api_status", return_value=(True, voices)), patch(
"gradio.Timer", return_value=mock_timer
):
demo = create_interface()
# Get the update function
update_fn = mock_timer.events[0].fn
# Test update with available service
updates = update_fn()
assert "Available" in updates[0]["value"]
assert updates[1]["choices"] == voices
assert updates[1]["value"] == voices[0]
@ -78,13 +81,14 @@ def test_update_status_available(mock_timer):
def test_update_status_unavailable(mock_timer):
"""Test status update when service is unavailable"""
with patch("ui.lib.api.check_api_status", return_value=(False, [])), \
patch("gradio.Timer", return_value=mock_timer):
with patch("ui.lib.api.check_api_status", return_value=(False, [])), patch(
"gradio.Timer", return_value=mock_timer
):
demo = create_interface()
update_fn = mock_timer.events[0].fn
updates = update_fn()
assert "Waiting for Service" in updates[0]["value"]
assert updates[1]["choices"] == []
assert updates[1]["value"] is None
@ -93,13 +97,14 @@ def test_update_status_unavailable(mock_timer):
def test_update_status_error(mock_timer):
"""Test status update when an error occurs"""
with patch("ui.lib.api.check_api_status", side_effect=Exception("Test error")), \
patch("gradio.Timer", return_value=mock_timer):
with patch(
"ui.lib.api.check_api_status", side_effect=Exception("Test error")
), patch("gradio.Timer", return_value=mock_timer):
demo = create_interface()
update_fn = mock_timer.events[0].fn
updates = update_fn()
assert "Connection Error" in updates[0]["value"]
assert updates[1]["choices"] == []
assert updates[1]["value"] is None
@ -108,10 +113,11 @@ def test_update_status_error(mock_timer):
def test_timer_configuration(mock_timer):
"""Test timer configuration"""
with patch("ui.lib.api.check_api_status", return_value=(False, [])), \
patch("gradio.Timer", return_value=mock_timer):
with patch("ui.lib.api.check_api_status", return_value=(False, [])), patch(
"gradio.Timer", return_value=mock_timer
):
demo = create_interface()
assert mock_timer.value == 5 # Check interval is 5 seconds
assert len(mock_timer.events) == 1 # Should have one event handler
@ -120,20 +126,21 @@ def test_interface_components_presence():
"""Test that all required components are present"""
with patch("ui.lib.api.check_api_status", return_value=(False, [])):
demo = create_interface()
# Check for main component sections
components = {
comp.label for comp in demo.blocks.values()
if hasattr(comp, 'label') and comp.label
comp.label
for comp in demo.blocks.values()
if hasattr(comp, "label") and comp.label
}
required_components = {
"Text to speak",
"Voice",
"Audio Format",
"Speed",
"Generated Speech",
"Previous Outputs"
"Previous Outputs",
}
assert required_components.issubset(components)