Ruff Check + Format

This commit is contained in:
remsky 2025-01-01 21:50:41 -07:00
parent e749b3bc88
commit f051984805
27 changed files with 638 additions and 504 deletions

BIN
.coverage

Binary file not shown.

View file

@ -1,10 +1,10 @@
from typing import List from typing import List
from fastapi import APIRouter, Depends, HTTPException, Response
from loguru import logger from loguru import logger
from fastapi import Depends, Response, APIRouter, HTTPException
from ..services.audio import AudioService
from ..services.tts import TTSService from ..services.tts import TTSService
from ..services.audio import AudioService
from ..structures.schemas import OpenAISpeechRequest from ..structures.schemas import OpenAISpeechRequest
router = APIRouter( router = APIRouter(
@ -32,7 +32,7 @@ async def create_speech(
raise ValueError( raise ValueError(
f"Voice '{request.voice}' not found. Available voices: {', '.join(sorted(available_voices))}" f"Voice '{request.voice}' not found. Available voices: {', '.join(sorted(available_voices))}"
) )
# Generate audio directly using TTSService's method # Generate audio directly using TTSService's method
audio, _ = tts_service._generate_audio( audio, _ = tts_service._generate_audio(
text=request.input, text=request.input,
@ -55,14 +55,12 @@ async def create_speech(
except ValueError as e: except ValueError as e:
logger.error(f"Invalid request: {str(e)}") logger.error(f"Invalid request: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=400, status_code=400, detail={"error": "Invalid request", "message": str(e)}
detail={"error": "Invalid request", "message": str(e)}
) )
except Exception as e: except Exception as e:
logger.error(f"Error generating speech: {str(e)}") logger.error(f"Error generating speech: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail={"error": "Server error", "message": str(e)}
detail={"error": "Server error", "message": str(e)}
) )
@ -78,17 +76,19 @@ async def list_voices(tts_service: TTSService = Depends(get_tts_service)):
@router.post("/audio/voices/combine") @router.post("/audio/voices/combine")
async def combine_voices(request: List[str], tts_service: TTSService = Depends(get_tts_service)): async def combine_voices(
request: List[str], tts_service: TTSService = Depends(get_tts_service)
):
"""Combine multiple voices into a new voice. """Combine multiple voices into a new voice.
Args: Args:
request: List of voice names to combine request: List of voice names to combine
Returns: Returns:
Dict with combined voice name and list of all available voices Dict with combined voice name and list of all available voices
Raises: Raises:
HTTPException: HTTPException:
- 400: Invalid request (wrong number of voices, voice not found) - 400: Invalid request (wrong number of voices, voice not found)
- 500: Server error (file system issues, combination failed) - 500: Server error (file system issues, combination failed)
""" """
@ -96,24 +96,21 @@ async def combine_voices(request: List[str], tts_service: TTSService = Depends(g
combined_voice = tts_service.combine_voices(voices=request) combined_voice = tts_service.combine_voices(voices=request)
voices = tts_service.list_voices() voices = tts_service.list_voices()
return {"voices": voices, "voice": combined_voice} return {"voices": voices, "voice": combined_voice}
except ValueError as e: except ValueError as e:
logger.error(f"Invalid voice combination request: {str(e)}") logger.error(f"Invalid voice combination request: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=400, status_code=400, detail={"error": "Invalid request", "message": str(e)}
detail={"error": "Invalid request", "message": str(e)}
) )
except RuntimeError as e: except RuntimeError as e:
logger.error(f"Server error during voice combination: {str(e)}") logger.error(f"Server error during voice combination: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail={"error": "Server error", "message": str(e)}
detail={"error": "Server error", "message": str(e)}
) )
except Exception as e: except Exception as e:
logger.error(f"Unexpected error during voice combination: {str(e)}") logger.error(f"Unexpected error during voice combination: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail={"error": "Unexpected error", "message": str(e)}
detail={"error": "Unexpected error", "message": str(e)}
) )

View file

@ -1,17 +1,16 @@
import io import io
import os import os
import re import re
import threading
import time import time
import threading
from typing import List, Tuple, Optional from typing import List, Tuple, Optional
import numpy as np import numpy as np
import scipy.io.wavfile as wavfile
import tiktoken
import torch import torch
import tiktoken
import scipy.io.wavfile as wavfile
from kokoro import generate, tokenize, phonemize, normalize_text
from loguru import logger from loguru import logger
from kokoro import generate, normalize_text, phonemize, tokenize
from models import build_model from models import build_model
from ..core.config import settings from ..core.config import settings
@ -23,7 +22,7 @@ class TTSModel:
_instance = None _instance = None
_device = None _device = None
_lock = threading.Lock() _lock = threading.Lock()
# Directory for all voices (copied base voices, and any created combined voices) # Directory for all voices (copied base voices, and any created combined voices)
VOICES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "voices") VOICES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "voices")
@ -38,10 +37,10 @@ class TTSModel:
model_path = os.path.join(settings.model_dir, settings.model_path) model_path = os.path.join(settings.model_dir, settings.model_path)
model = build_model(model_path, cls._device) model = build_model(model_path, cls._device)
cls._instance = model cls._instance = model
# Ensure voices directory exists # Ensure voices directory exists
os.makedirs(cls.VOICES_DIR, exist_ok=True) os.makedirs(cls.VOICES_DIR, exist_ok=True)
# Copy base voices to local directory # Copy base voices to local directory
base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir) base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
if os.path.exists(base_voices_dir): if os.path.exists(base_voices_dir):
@ -51,25 +50,37 @@ class TTSModel:
voice_path = os.path.join(cls.VOICES_DIR, file) voice_path = os.path.join(cls.VOICES_DIR, file)
if not os.path.exists(voice_path): if not os.path.exists(voice_path):
try: try:
logger.info(f"Copying base voice {voice_name} to voices directory") logger.info(
f"Copying base voice {voice_name} to voices directory"
)
base_path = os.path.join(base_voices_dir, file) base_path = os.path.join(base_voices_dir, file)
voicepack = torch.load(base_path, map_location=cls._device, weights_only=True) voicepack = torch.load(
base_path,
map_location=cls._device,
weights_only=True,
)
torch.save(voicepack, voice_path) torch.save(voicepack, voice_path)
except Exception as e: except Exception as e:
logger.error(f"Error copying voice {voice_name}: {str(e)}") logger.error(
f"Error copying voice {voice_name}: {str(e)}"
)
# Warm up with default voice # Warm up with default voice
try: try:
dummy_text = "Hello" dummy_text = "Hello"
voice_path = os.path.join(cls.VOICES_DIR, "af.pt") voice_path = os.path.join(cls.VOICES_DIR, "af.pt")
dummy_voicepack = torch.load(voice_path, map_location=cls._device, weights_only=True) dummy_voicepack = torch.load(
generate(model, dummy_text, dummy_voicepack, lang='a', speed=1.0) voice_path, map_location=cls._device, weights_only=True
)
generate(model, dummy_text, dummy_voicepack, lang="a", speed=1.0)
logger.info("Model warm-up complete") logger.info("Model warm-up complete")
except Exception as e: except Exception as e:
logger.warning(f"Model warm-up failed: {e}") logger.warning(f"Model warm-up failed: {e}")
# Count voices in directory for validation # Count voices in directory for validation
voice_count = len([f for f in os.listdir(cls.VOICES_DIR) if f.endswith('.pt')]) voice_count = len(
[f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")]
)
return cls._instance, voice_count return cls._instance, voice_count
@classmethod @classmethod
@ -86,11 +97,11 @@ class TTSService:
self._ensure_voices() self._ensure_voices()
if start_worker: if start_worker:
self.start_worker() self.start_worker()
def _ensure_voices(self): def _ensure_voices(self):
"""Copy base voices to local voices directory during initialization""" """Copy base voices to local voices directory during initialization"""
os.makedirs(TTSModel.VOICES_DIR, exist_ok=True) os.makedirs(TTSModel.VOICES_DIR, exist_ok=True)
base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir) base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
if os.path.exists(base_voices_dir): if os.path.exists(base_voices_dir):
for file in os.listdir(base_voices_dir): for file in os.listdir(base_voices_dir):
@ -99,9 +110,15 @@ class TTSService:
voice_path = os.path.join(TTSModel.VOICES_DIR, file) voice_path = os.path.join(TTSModel.VOICES_DIR, file)
if not os.path.exists(voice_path): if not os.path.exists(voice_path):
try: try:
logger.info(f"Copying base voice {voice_name} to voices directory") logger.info(
f"Copying base voice {voice_name} to voices directory"
)
base_path = os.path.join(base_voices_dir, file) base_path = os.path.join(base_voices_dir, file)
voicepack = torch.load(base_path, map_location=TTSModel._device, weights_only=True) voicepack = torch.load(
base_path,
map_location=TTSModel._device,
weights_only=True,
)
torch.save(voicepack, voice_path) torch.save(voicepack, voice_path)
except Exception as e: except Exception as e:
logger.error(f"Error copying voice {voice_name}: {str(e)}") logger.error(f"Error copying voice {voice_name}: {str(e)}")
@ -112,10 +129,10 @@ class TTSService:
def _get_voice_path(self, voice_name: str) -> Optional[str]: def _get_voice_path(self, voice_name: str) -> Optional[str]:
"""Get the path to a voice file. """Get the path to a voice file.
Args: Args:
voice_name: Name of the voice to find voice_name: Name of the voice to find
Returns: Returns:
Path to the voice file if found, None otherwise Path to the voice file if found, None otherwise
""" """
@ -141,7 +158,9 @@ class TTSService:
# Load model and voice # Load model and voice
model = TTSModel._instance model = TTSModel._instance
voicepack = torch.load(voice_path, map_location=TTSModel._device, weights_only=True) voicepack = torch.load(
voice_path, map_location=TTSModel._device, weights_only=True
)
# Generate audio with or without stitching # Generate audio with or without stitching
if stitch_long_output: if stitch_long_output:
@ -152,11 +171,11 @@ class TTSService:
for i, chunk in enumerate(chunks): for i, chunk in enumerate(chunks):
try: try:
# Validate phonemization first # Validate phonemization first
ps = phonemize(chunk, voice[0]) # ps = phonemize(chunk, voice[0])
tokens = tokenize(ps) # tokens = tokenize(ps)
logger.debug( # logger.debug(
f"Processing chunk {i + 1}/{len(chunks)}: {len(tokens)} tokens" # f"Processing chunk {i + 1}/{len(chunks)}: {len(tokens)} tokens"
) # )
# Only proceed if phonemization succeeded # Only proceed if phonemization succeeded
chunk_audio, _ = generate( chunk_audio, _ = generate(
@ -205,47 +224,51 @@ class TTSService:
def combine_voices(self, voices: List[str]) -> str: def combine_voices(self, voices: List[str]) -> str:
"""Combine multiple voices into a new voice. """Combine multiple voices into a new voice.
Args: Args:
voices: List of voice names to combine voices: List of voice names to combine
Returns: Returns:
Name of the combined voice Name of the combined voice
Raises: Raises:
ValueError: If less than 2 voices provided or voice loading fails ValueError: If less than 2 voices provided or voice loading fails
RuntimeError: If voice combination or saving fails RuntimeError: If voice combination or saving fails
""" """
if len(voices) < 2: if len(voices) < 2:
raise ValueError("At least 2 voices are required for combination") raise ValueError("At least 2 voices are required for combination")
# Load voices # Load voices
t_voices: List[torch.Tensor] = [] t_voices: List[torch.Tensor] = []
v_name: List[str] = [] v_name: List[str] = []
for voice in voices: for voice in voices:
try: try:
voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice}.pt") voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice}.pt")
voicepack = torch.load(voice_path, map_location=TTSModel._device, weights_only=True) voicepack = torch.load(
voice_path, map_location=TTSModel._device, weights_only=True
)
t_voices.append(voicepack) t_voices.append(voicepack)
v_name.append(voice) v_name.append(voice)
except Exception as e: except Exception as e:
raise ValueError(f"Failed to load voice {voice}: {str(e)}") raise ValueError(f"Failed to load voice {voice}: {str(e)}")
# Combine voices # Combine voices
try: try:
f: str = "_".join(v_name) f: str = "_".join(v_name)
v = torch.mean(torch.stack(t_voices), dim=0) v = torch.mean(torch.stack(t_voices), dim=0)
combined_path = os.path.join(TTSModel.VOICES_DIR, f"{f}.pt") combined_path = os.path.join(TTSModel.VOICES_DIR, f"{f}.pt")
# Save combined voice # Save combined voice
try: try:
torch.save(v, combined_path) torch.save(v, combined_path)
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to save combined voice to {combined_path}: {str(e)}") raise RuntimeError(
f"Failed to save combined voice to {combined_path}: {str(e)}"
)
return f return f
except Exception as e: except Exception as e:
if not isinstance(e, (ValueError, RuntimeError)): if not isinstance(e, (ValueError, RuntimeError)):
raise RuntimeError(f"Error combining voices: {str(e)}") raise RuntimeError(f"Error combining voices: {str(e)}")

View file

@ -17,8 +17,8 @@ class OpenAISpeechRequest(BaseModel):
model: Literal["tts-1", "tts-1-hd", "kokoro"] = "kokoro" model: Literal["tts-1", "tts-1-hd", "kokoro"] = "kokoro"
input: str = Field(..., description="The text to generate audio for") input: str = Field(..., description="The text to generate audio for")
voice: str = Field( voice: str = Field(
default="af", default="af",
description="The voice to use for generation. Can be a base voice or a combined voice name." description="The voice to use for generation. Can be a base voice or a combined voice name.",
) )
response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field( response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field(
default="mp3", default="mp3",

View file

@ -1,16 +1,18 @@
import os import os
import shutil
import sys import sys
import shutil
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import pytest import pytest
def cleanup_mock_dirs(): def cleanup_mock_dirs():
"""Clean up any MagicMock directories created during tests""" """Clean up any MagicMock directories created during tests"""
mock_dir = "MagicMock" mock_dir = "MagicMock"
if os.path.exists(mock_dir): if os.path.exists(mock_dir):
shutil.rmtree(mock_dir) shutil.rmtree(mock_dir)
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def cleanup(): def cleanup():
"""Automatically clean up before and after each test""" """Automatically clean up before and after each test"""
@ -18,6 +20,7 @@ def cleanup():
yield yield
cleanup_mock_dirs() cleanup_mock_dirs()
# Mock torch and other ML modules before they're imported # Mock torch and other ML modules before they're imported
sys.modules["torch"] = Mock() sys.modules["torch"] = Mock()
sys.modules["transformers"] = Mock() sys.modules["transformers"] = Mock()

View file

@ -1,6 +1,8 @@
"""Tests for AudioService""" """Tests for AudioService"""
import numpy as np import numpy as np
import pytest import pytest
from api.src.services.audio import AudioService from api.src.services.audio import AudioService

View file

@ -114,9 +114,9 @@ def test_combine_voices_success(mock_tts_service):
"""Test successful voice combination""" """Test successful voice combination"""
test_voices = ["af_bella", "af_sarah"] test_voices = ["af_bella", "af_sarah"]
mock_tts_service.combine_voices.return_value = "af_bella_af_sarah" mock_tts_service.combine_voices.return_value = "af_bella_af_sarah"
response = client.post("/v1/audio/voices/combine", json=test_voices) response = client.post("/v1/audio/voices/combine", json=test_voices)
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["voice"] == "af_bella_af_sarah" assert response.json()["voice"] == "af_bella_af_sarah"
mock_tts_service.combine_voices.assert_called_once_with(voices=test_voices) mock_tts_service.combine_voices.assert_called_once_with(voices=test_voices)
@ -126,9 +126,9 @@ def test_combine_voices_single_voice(mock_tts_service):
"""Test combining single voice returns default voice""" """Test combining single voice returns default voice"""
test_voices = ["af_bella"] test_voices = ["af_bella"]
mock_tts_service.combine_voices.return_value = "af" mock_tts_service.combine_voices.return_value = "af"
response = client.post("/v1/audio/voices/combine", json=test_voices) response = client.post("/v1/audio/voices/combine", json=test_voices)
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["voice"] == "af" assert response.json()["voice"] == "af"
@ -137,9 +137,9 @@ def test_combine_voices_empty_list(mock_tts_service):
"""Test combining empty voice list returns default voice""" """Test combining empty voice list returns default voice"""
test_voices = [] test_voices = []
mock_tts_service.combine_voices.return_value = "af" mock_tts_service.combine_voices.return_value = "af"
response = client.post("/v1/audio/voices/combine", json=test_voices) response = client.post("/v1/audio/voices/combine", json=test_voices)
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["voice"] == "af" assert response.json()["voice"] == "af"
@ -148,8 +148,8 @@ def test_combine_voices_error(mock_tts_service):
"""Test error handling in voice combination""" """Test error handling in voice combination"""
test_voices = ["af_bella", "af_sarah"] test_voices = ["af_bella", "af_sarah"]
mock_tts_service.combine_voices.side_effect = Exception("Combination failed") mock_tts_service.combine_voices.side_effect = Exception("Combination failed")
response = client.post("/v1/audio/voices/combine", json=test_voices) response = client.post("/v1/audio/voices/combine", json=test_voices)
assert response.status_code == 500 assert response.status_code == 500
assert "Combination failed" in response.json()["detail"]["message"] assert "Combination failed" in response.json()["detail"]["message"]

View file

@ -1,7 +1,10 @@
"""Tests for FastAPI application""" """Tests for FastAPI application"""
from unittest.mock import MagicMock, patch
import pytest import pytest
from unittest.mock import patch, MagicMock
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from api.src.main import app, lifespan from api.src.main import app, lifespan
@ -19,98 +22,100 @@ def test_health_check(test_client):
@pytest.mark.asyncio @pytest.mark.asyncio
@patch('api.src.main.TTSModel') @patch("api.src.main.TTSModel")
@patch('api.src.main.logger') @patch("api.src.main.logger")
async def test_lifespan_successful_warmup(mock_logger, mock_tts_model): async def test_lifespan_successful_warmup(mock_logger, mock_tts_model):
"""Test successful model warmup in lifespan""" """Test successful model warmup in lifespan"""
# Mock the model initialization with model info and voicepack count # Mock the model initialization with model info and voicepack count
mock_model = MagicMock() mock_model = MagicMock()
# Mock file system for voice counting # Mock file system for voice counting
mock_tts_model.VOICES_DIR = "/mock/voices" mock_tts_model.VOICES_DIR = "/mock/voices"
with patch('os.listdir', return_value=['voice1.pt', 'voice2.pt', 'voice3.pt']): with patch("os.listdir", return_value=["voice1.pt", "voice2.pt", "voice3.pt"]):
mock_tts_model.initialize.return_value = (mock_model, 3) # 3 voice files mock_tts_model.initialize.return_value = (mock_model, 3) # 3 voice files
mock_tts_model._device = "cuda" # Set device class variable mock_tts_model._device = "cuda" # Set device class variable
# Create an async generator from the lifespan context manager # Create an async generator from the lifespan context manager
async_gen = lifespan(MagicMock()) async_gen = lifespan(MagicMock())
# Start the context manager # Start the context manager
await async_gen.__aenter__() await async_gen.__aenter__()
# Verify the expected logging sequence # Verify the expected logging sequence
mock_logger.info.assert_any_call("Loading TTS model and voice packs...") mock_logger.info.assert_any_call("Loading TTS model and voice packs...")
mock_logger.info.assert_any_call("Model loaded and warmed up on cuda") mock_logger.info.assert_any_call("Model loaded and warmed up on cuda")
mock_logger.info.assert_any_call("3 voice packs loaded successfully") mock_logger.info.assert_any_call("3 voice packs loaded successfully")
# Verify model initialization was called # Verify model initialization was called
mock_tts_model.initialize.assert_called_once() mock_tts_model.initialize.assert_called_once()
# Clean up # Clean up
await async_gen.__aexit__(None, None, None) await async_gen.__aexit__(None, None, None)
@pytest.mark.asyncio @pytest.mark.asyncio
@patch('api.src.main.TTSModel') @patch("api.src.main.TTSModel")
@patch('api.src.main.logger') @patch("api.src.main.logger")
async def test_lifespan_failed_warmup(mock_logger, mock_tts_model): async def test_lifespan_failed_warmup(mock_logger, mock_tts_model):
"""Test failed model warmup in lifespan""" """Test failed model warmup in lifespan"""
# Mock the model initialization to fail # Mock the model initialization to fail
mock_tts_model.initialize.side_effect = Exception("Failed to initialize model") mock_tts_model.initialize.side_effect = Exception("Failed to initialize model")
# Create an async generator from the lifespan context manager # Create an async generator from the lifespan context manager
async_gen = lifespan(MagicMock()) async_gen = lifespan(MagicMock())
# Verify the exception is raised # Verify the exception is raised
with pytest.raises(Exception, match="Failed to initialize model"): with pytest.raises(Exception, match="Failed to initialize model"):
await async_gen.__aenter__() await async_gen.__aenter__()
# Verify the expected logging sequence # Verify the expected logging sequence
mock_logger.info.assert_called_with("Loading TTS model and voice packs...") mock_logger.info.assert_called_with("Loading TTS model and voice packs...")
# Clean up # Clean up
await async_gen.__aexit__(None, None, None) await async_gen.__aexit__(None, None, None)
@pytest.mark.asyncio @pytest.mark.asyncio
@patch('api.src.main.TTSModel') @patch("api.src.main.TTSModel")
async def test_lifespan_cuda_warmup(mock_tts_model): async def test_lifespan_cuda_warmup(mock_tts_model):
"""Test model warmup specifically on CUDA""" """Test model warmup specifically on CUDA"""
# Mock the model initialization with CUDA and voicepacks # Mock the model initialization with CUDA and voicepacks
mock_model = MagicMock() mock_model = MagicMock()
# Mock file system for voice counting # Mock file system for voice counting
mock_tts_model.VOICES_DIR = "/mock/voices" mock_tts_model.VOICES_DIR = "/mock/voices"
with patch('os.listdir', return_value=['voice1.pt', 'voice2.pt']): with patch("os.listdir", return_value=["voice1.pt", "voice2.pt"]):
mock_tts_model.initialize.return_value = (mock_model, 2) # 2 voice files mock_tts_model.initialize.return_value = (mock_model, 2) # 2 voice files
mock_tts_model._device = "cuda" # Set device class variable mock_tts_model._device = "cuda" # Set device class variable
# Create an async generator from the lifespan context manager # Create an async generator from the lifespan context manager
async_gen = lifespan(MagicMock()) async_gen = lifespan(MagicMock())
await async_gen.__aenter__() await async_gen.__aenter__()
# Verify model was initialized # Verify model was initialized
mock_tts_model.initialize.assert_called_once() mock_tts_model.initialize.assert_called_once()
# Clean up # Clean up
await async_gen.__aexit__(None, None, None) await async_gen.__aexit__(None, None, None)
@pytest.mark.asyncio @pytest.mark.asyncio
@patch('api.src.main.TTSModel') @patch("api.src.main.TTSModel")
async def test_lifespan_cpu_fallback(mock_tts_model): async def test_lifespan_cpu_fallback(mock_tts_model):
"""Test model warmup falling back to CPU""" """Test model warmup falling back to CPU"""
# Mock the model initialization with CPU and voicepacks # Mock the model initialization with CPU and voicepacks
mock_model = MagicMock() mock_model = MagicMock()
# Mock file system for voice counting # Mock file system for voice counting
mock_tts_model.VOICES_DIR = "/mock/voices" mock_tts_model.VOICES_DIR = "/mock/voices"
with patch('os.listdir', return_value=['voice1.pt', 'voice2.pt', 'voice3.pt', 'voice4.pt']): with patch(
"os.listdir", return_value=["voice1.pt", "voice2.pt", "voice3.pt", "voice4.pt"]
):
mock_tts_model.initialize.return_value = (mock_model, 4) # 4 voice files mock_tts_model.initialize.return_value = (mock_model, 4) # 4 voice files
mock_tts_model._device = "cpu" # Set device class variable mock_tts_model._device = "cpu" # Set device class variable
# Create an async generator from the lifespan context manager # Create an async generator from the lifespan context manager
async_gen = lifespan(MagicMock()) async_gen = lifespan(MagicMock())
await async_gen.__aenter__() await async_gen.__aenter__()
# Verify model was initialized # Verify model was initialized
mock_tts_model.initialize.assert_called_once() mock_tts_model.initialize.assert_called_once()
# Clean up # Clean up
await async_gen.__aexit__(None, None, None) await async_gen.__aexit__(None, None, None)

View file

@ -1,9 +1,12 @@
"""Tests for TTSService""" """Tests for TTSService"""
import os import os
from unittest.mock import MagicMock, call, patch
import numpy as np import numpy as np
import pytest import pytest
from unittest.mock import patch, MagicMock, call
from api.src.services.tts import TTSService, TTSModel from api.src.services.tts import TTSModel, TTSService
@pytest.fixture @pytest.fixture
@ -50,42 +53,59 @@ def test_audio_to_bytes(tts_service, sample_audio):
assert len(audio_bytes) > 0 assert len(audio_bytes) > 0
@patch('os.listdir') @patch("os.listdir")
@patch('os.path.join') @patch("os.path.join")
def test_list_voices(mock_join, mock_listdir, tts_service): def test_list_voices(mock_join, mock_listdir, tts_service):
"""Test listing available voices""" """Test listing available voices"""
mock_listdir.return_value = ['voice1.pt', 'voice2.pt', 'not_a_voice.txt'] mock_listdir.return_value = ["voice1.pt", "voice2.pt", "not_a_voice.txt"]
mock_join.return_value = '/fake/path' mock_join.return_value = "/fake/path"
voices = tts_service.list_voices() voices = tts_service.list_voices()
assert len(voices) == 2 assert len(voices) == 2
assert 'voice1' in voices assert "voice1" in voices
assert 'voice2' in voices assert "voice2" in voices
assert 'not_a_voice' not in voices assert "not_a_voice" not in voices
@patch('api.src.services.tts.TTSModel.get_instance') @patch("api.src.services.tts.TTSModel.get_instance")
@patch('api.src.services.tts.TTSModel.get_voicepack') @patch("api.src.services.tts.TTSModel.get_voicepack")
@patch('api.src.services.tts.normalize_text') @patch("api.src.services.tts.normalize_text")
@patch('api.src.services.tts.phonemize') @patch("api.src.services.tts.phonemize")
@patch('api.src.services.tts.tokenize') @patch("api.src.services.tts.tokenize")
@patch('api.src.services.tts.generate') @patch("api.src.services.tts.generate")
def test_generate_audio_empty_text(mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_voicepack, mock_instance, tts_service): def test_generate_audio_empty_text(
mock_generate,
mock_tokenize,
mock_phonemize,
mock_normalize,
mock_voicepack,
mock_instance,
tts_service,
):
"""Test generating audio with empty text""" """Test generating audio with empty text"""
mock_normalize.return_value = "" mock_normalize.return_value = ""
with pytest.raises(ValueError, match="Text is empty after preprocessing"): with pytest.raises(ValueError, match="Text is empty after preprocessing"):
tts_service._generate_audio("", "af", 1.0) tts_service._generate_audio("", "af", 1.0)
@patch('api.src.services.tts.TTSModel.get_instance') @patch("api.src.services.tts.TTSModel.get_instance")
@patch('os.path.exists') @patch("os.path.exists")
@patch('api.src.services.tts.normalize_text') @patch("api.src.services.tts.normalize_text")
@patch('api.src.services.tts.phonemize') @patch("api.src.services.tts.phonemize")
@patch('api.src.services.tts.tokenize') @patch("api.src.services.tts.tokenize")
@patch('api.src.services.tts.generate') @patch("api.src.services.tts.generate")
@patch('torch.load') @patch("torch.load")
def test_generate_audio_no_chunks(mock_torch_load, mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_exists, mock_instance, tts_service): def test_generate_audio_no_chunks(
mock_torch_load,
mock_generate,
mock_tokenize,
mock_phonemize,
mock_normalize,
mock_exists,
mock_instance,
tts_service,
):
"""Test generating audio with no successful chunks""" """Test generating audio with no successful chunks"""
mock_normalize.return_value = "Test text" mock_normalize.return_value = "Test text"
mock_phonemize.return_value = "Test text" mock_phonemize.return_value = "Test text"
@ -94,19 +114,29 @@ def test_generate_audio_no_chunks(mock_torch_load, mock_generate, mock_tokenize,
mock_instance.return_value = (MagicMock(), "cpu") mock_instance.return_value = (MagicMock(), "cpu")
mock_exists.return_value = True mock_exists.return_value = True
mock_torch_load.return_value = MagicMock() mock_torch_load.return_value = MagicMock()
with pytest.raises(ValueError, match="No audio chunks were generated successfully"): with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
tts_service._generate_audio("Test text", "af", 1.0) tts_service._generate_audio("Test text", "af", 1.0)
@patch('api.src.services.tts.TTSModel.get_instance') @patch("api.src.services.tts.TTSModel.get_instance")
@patch('os.path.exists') @patch("os.path.exists")
@patch('api.src.services.tts.normalize_text') @patch("api.src.services.tts.normalize_text")
@patch('api.src.services.tts.phonemize') @patch("api.src.services.tts.phonemize")
@patch('api.src.services.tts.tokenize') @patch("api.src.services.tts.tokenize")
@patch('api.src.services.tts.generate') @patch("api.src.services.tts.generate")
@patch('torch.load') @patch("torch.load")
def test_generate_audio_success(mock_torch_load, mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_exists, mock_instance, tts_service, sample_audio): def test_generate_audio_success(
mock_torch_load,
mock_generate,
mock_tokenize,
mock_phonemize,
mock_normalize,
mock_exists,
mock_instance,
tts_service,
sample_audio,
):
"""Test successful audio generation""" """Test successful audio generation"""
mock_normalize.return_value = "Test text" mock_normalize.return_value = "Test text"
mock_phonemize.return_value = "Test text" mock_phonemize.return_value = "Test text"
@ -115,15 +145,15 @@ def test_generate_audio_success(mock_torch_load, mock_generate, mock_tokenize, m
mock_instance.return_value = (MagicMock(), "cpu") mock_instance.return_value = (MagicMock(), "cpu")
mock_exists.return_value = True mock_exists.return_value = True
mock_torch_load.return_value = MagicMock() mock_torch_load.return_value = MagicMock()
audio, processing_time = tts_service._generate_audio("Test text", "af", 1.0) audio, processing_time = tts_service._generate_audio("Test text", "af", 1.0)
assert isinstance(audio, np.ndarray) assert isinstance(audio, np.ndarray)
assert isinstance(processing_time, float) assert isinstance(processing_time, float)
assert len(audio) > 0 assert len(audio) > 0
@patch('api.src.services.tts.torch.cuda.is_available') @patch("api.src.services.tts.torch.cuda.is_available")
@patch('api.src.services.tts.build_model') @patch("api.src.services.tts.build_model")
def test_model_initialization_cuda(mock_build_model, mock_cuda_available): def test_model_initialization_cuda(mock_build_model, mock_cuda_available):
"""Test model initialization with CUDA""" """Test model initialization with CUDA"""
mock_cuda_available.return_value = True mock_cuda_available.return_value = True
@ -132,14 +162,14 @@ def test_model_initialization_cuda(mock_build_model, mock_cuda_available):
TTSModel._instance = None # Reset singleton TTSModel._instance = None # Reset singleton
model, voice_count = TTSModel.initialize() model, voice_count = TTSModel.initialize()
assert TTSModel._device == "cuda" # Check the class variable instead assert TTSModel._device == "cuda" # Check the class variable instead
assert model == mock_model assert model == mock_model
mock_build_model.assert_called_once() mock_build_model.assert_called_once()
@patch('api.src.services.tts.torch.cuda.is_available') @patch("api.src.services.tts.torch.cuda.is_available")
@patch('api.src.services.tts.build_model') @patch("api.src.services.tts.build_model")
def test_model_initialization_cpu(mock_build_model, mock_cuda_available): def test_model_initialization_cpu(mock_build_model, mock_cuda_available):
"""Test model initialization with CPU""" """Test model initialization with CPU"""
mock_cuda_available.return_value = False mock_cuda_available.return_value = False
@ -148,76 +178,95 @@ def test_model_initialization_cpu(mock_build_model, mock_cuda_available):
TTSModel._instance = None # Reset singleton TTSModel._instance = None # Reset singleton
model, voice_count = TTSModel.initialize() model, voice_count = TTSModel.initialize()
assert TTSModel._device == "cpu" # Check the class variable instead assert TTSModel._device == "cpu" # Check the class variable instead
assert model == mock_model assert model == mock_model
mock_build_model.assert_called_once() mock_build_model.assert_called_once()
@patch('api.src.services.tts.TTSService._get_voice_path') @patch("api.src.services.tts.TTSService._get_voice_path")
@patch('api.src.services.tts.TTSModel.get_instance') @patch("api.src.services.tts.TTSModel.get_instance")
def test_voicepack_loading_error(mock_get_instance, mock_get_voice_path): def test_voicepack_loading_error(mock_get_instance, mock_get_voice_path):
"""Test voicepack loading error handling""" """Test voicepack loading error handling"""
mock_get_voice_path.return_value = None mock_get_voice_path.return_value = None
mock_get_instance.return_value = (MagicMock(), "cpu") mock_get_instance.return_value = (MagicMock(), "cpu")
TTSModel._voicepacks = {} # Reset voicepacks TTSModel._voicepacks = {} # Reset voicepacks
service = TTSService(start_worker=False) service = TTSService(start_worker=False)
with pytest.raises(ValueError, match="Voice not found: nonexistent_voice"): with pytest.raises(ValueError, match="Voice not found: nonexistent_voice"):
service._generate_audio("test", "nonexistent_voice", 1.0) service._generate_audio("test", "nonexistent_voice", 1.0)
@patch('api.src.services.tts.TTSModel') @patch("api.src.services.tts.TTSModel")
def test_save_audio(mock_tts_model, tts_service, sample_audio, tmp_path): def test_save_audio(mock_tts_model, tts_service, sample_audio, tmp_path):
"""Test saving audio to file""" """Test saving audio to file"""
output_dir = os.path.join(tmp_path, "test_output") output_dir = os.path.join(tmp_path, "test_output")
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, "audio.wav") output_path = os.path.join(output_dir, "audio.wav")
tts_service._save_audio(sample_audio, output_path) tts_service._save_audio(sample_audio, output_path)
assert os.path.exists(output_path) assert os.path.exists(output_path)
assert os.path.getsize(output_path) > 0 assert os.path.getsize(output_path) > 0
@patch('api.src.services.tts.TTSModel.get_instance') @patch("api.src.services.tts.TTSModel.get_instance")
@patch('os.path.exists') @patch("os.path.exists")
@patch('api.src.services.tts.normalize_text') @patch("api.src.services.tts.normalize_text")
@patch('api.src.services.tts.generate') @patch("api.src.services.tts.generate")
@patch('torch.load') @patch("torch.load")
def test_generate_audio_without_stitching(mock_torch_load, mock_generate, mock_normalize, mock_exists, mock_instance, tts_service, sample_audio): def test_generate_audio_without_stitching(
mock_torch_load,
mock_generate,
mock_normalize,
mock_exists,
mock_instance,
tts_service,
sample_audio,
):
"""Test generating audio without text stitching""" """Test generating audio without text stitching"""
mock_normalize.return_value = "Test text" mock_normalize.return_value = "Test text"
mock_generate.return_value = (sample_audio, None) mock_generate.return_value = (sample_audio, None)
mock_instance.return_value = (MagicMock(), "cpu") mock_instance.return_value = (MagicMock(), "cpu")
mock_exists.return_value = True mock_exists.return_value = True
mock_torch_load.return_value = MagicMock() mock_torch_load.return_value = MagicMock()
audio, processing_time = tts_service._generate_audio("Test text", "af", 1.0, stitch_long_output=False) audio, processing_time = tts_service._generate_audio(
"Test text", "af", 1.0, stitch_long_output=False
)
assert isinstance(audio, np.ndarray) assert isinstance(audio, np.ndarray)
assert isinstance(processing_time, float) assert isinstance(processing_time, float)
assert len(audio) > 0 assert len(audio) > 0
mock_generate.assert_called_once() mock_generate.assert_called_once()
@patch('os.listdir') @patch("os.listdir")
def test_list_voices_error(mock_listdir, tts_service): def test_list_voices_error(mock_listdir, tts_service):
"""Test error handling in list_voices""" """Test error handling in list_voices"""
mock_listdir.side_effect = Exception("Failed to list directory") mock_listdir.side_effect = Exception("Failed to list directory")
voices = tts_service.list_voices() voices = tts_service.list_voices()
assert voices == [] assert voices == []
@patch('api.src.services.tts.TTSModel.get_instance') @patch("api.src.services.tts.TTSModel.get_instance")
@patch('os.path.exists') @patch("os.path.exists")
@patch('api.src.services.tts.normalize_text') @patch("api.src.services.tts.normalize_text")
@patch('api.src.services.tts.phonemize') @patch("api.src.services.tts.phonemize")
@patch('api.src.services.tts.tokenize') @patch("api.src.services.tts.tokenize")
@patch('api.src.services.tts.generate') @patch("api.src.services.tts.generate")
@patch('torch.load') @patch("torch.load")
def test_generate_audio_phonemize_error(mock_torch_load, mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_exists, mock_instance, tts_service): def test_generate_audio_phonemize_error(
mock_torch_load,
mock_generate,
mock_tokenize,
mock_phonemize,
mock_normalize,
mock_exists,
mock_instance,
tts_service,
):
"""Test handling phonemization error""" """Test handling phonemization error"""
mock_normalize.return_value = "Test text" mock_normalize.return_value = "Test text"
mock_phonemize.side_effect = Exception("Phonemization failed") mock_phonemize.side_effect = Exception("Phonemization failed")
@ -225,23 +274,30 @@ def test_generate_audio_phonemize_error(mock_torch_load, mock_generate, mock_tok
mock_exists.return_value = True mock_exists.return_value = True
mock_torch_load.return_value = MagicMock() mock_torch_load.return_value = MagicMock()
mock_generate.return_value = (None, None) mock_generate.return_value = (None, None)
with pytest.raises(ValueError, match="No audio chunks were generated successfully"): with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
tts_service._generate_audio("Test text", "af", 1.0) tts_service._generate_audio("Test text", "af", 1.0)
@patch('api.src.services.tts.TTSModel.get_instance') @patch("api.src.services.tts.TTSModel.get_instance")
@patch('os.path.exists') @patch("os.path.exists")
@patch('api.src.services.tts.normalize_text') @patch("api.src.services.tts.normalize_text")
@patch('api.src.services.tts.generate') @patch("api.src.services.tts.generate")
@patch('torch.load') @patch("torch.load")
def test_generate_audio_error(mock_torch_load, mock_generate, mock_normalize, mock_exists, mock_instance, tts_service): def test_generate_audio_error(
mock_torch_load,
mock_generate,
mock_normalize,
mock_exists,
mock_instance,
tts_service,
):
"""Test handling generation error""" """Test handling generation error"""
mock_normalize.return_value = "Test text" mock_normalize.return_value = "Test text"
mock_generate.side_effect = Exception("Generation failed") mock_generate.side_effect = Exception("Generation failed")
mock_instance.return_value = (MagicMock(), "cpu") mock_instance.return_value = (MagicMock(), "cpu")
mock_exists.return_value = True mock_exists.return_value = True
mock_torch_load.return_value = MagicMock() mock_torch_load.return_value = MagicMock()
with pytest.raises(ValueError, match="No audio chunks were generated successfully"): with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
tts_service._generate_audio("Test text", "af", 1.0) tts_service._generate_audio("Test text", "af", 1.0)

View file

@ -19,7 +19,6 @@ output_dir = Path(__file__).parent / "output"
output_dir.mkdir(exist_ok=True) output_dir.mkdir(exist_ok=True)
def test_voice(voice: str): def test_voice(voice: str):
speech_file = output_dir / f"speech_{voice}.mp3" speech_file = output_dir / f"speech_{voice}.mp3"
print(f"\nTesting voice: {voice}") print(f"\nTesting voice: {voice}")

View file

@ -1,21 +1,23 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import argparse
import os import os
from typing import List, Optional, Dict, Tuple import argparse
from typing import Dict, List, Tuple, Optional
import requests
import numpy as np import numpy as np
from scipy.io import wavfile import requests
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from scipy.io import wavfile
def submit_combine_voices(voices: List[str], base_url: str = "http://localhost:8880") -> Optional[str]: def submit_combine_voices(
voices: List[str], base_url: str = "http://localhost:8880"
) -> Optional[str]:
"""Combine multiple voices into a new voice. """Combine multiple voices into a new voice.
Args: Args:
voices: List of voice names to combine (e.g. ["af_bella", "af_sarah"]) voices: List of voice names to combine (e.g. ["af_bella", "af_sarah"])
base_url: API base URL base_url: API base URL
Returns: Returns:
Name of the combined voice (e.g. "af_bella_af_sarah") or None if error Name of the combined voice (e.g. "af_bella_af_sarah") or None if error
""" """
@ -23,7 +25,7 @@ def submit_combine_voices(voices: List[str], base_url: str = "http://localhost:8
response = requests.post(f"{base_url}/v1/audio/voices/combine", json=voices) response = requests.post(f"{base_url}/v1/audio/voices/combine", json=voices)
print(f"Response status: {response.status_code}") print(f"Response status: {response.status_code}")
print(f"Raw response: {response.text}") print(f"Raw response: {response.text}")
# Accept both 200 and 201 as success # Accept both 200 and 201 as success
if response.status_code not in [200, 201]: if response.status_code not in [200, 201]:
try: try:
@ -32,7 +34,7 @@ def submit_combine_voices(voices: List[str], base_url: str = "http://localhost:8
except: except:
print(f"Error combining voices: {response.text}") print(f"Error combining voices: {response.text}")
return None return None
try: try:
data = response.json() data = response.json()
if "voices" in data: if "voices" in data:
@ -46,15 +48,20 @@ def submit_combine_voices(voices: List[str], base_url: str = "http://localhost:8
return None return None
def generate_speech(text: str, voice: str, base_url: str = "http://localhost:8880", output_file: str = "output.mp3") -> bool: def generate_speech(
text: str,
voice: str,
base_url: str = "http://localhost:8880",
output_file: str = "output.mp3",
) -> bool:
"""Generate speech using specified voice. """Generate speech using specified voice.
Args: Args:
text: Text to convert to speech text: Text to convert to speech
voice: Voice name to use voice: Voice name to use
base_url: API base URL base_url: API base URL
output_file: Path to save audio file output_file: Path to save audio file
Returns: Returns:
True if successful, False otherwise True if successful, False otherwise
""" """
@ -65,22 +72,25 @@ def generate_speech(text: str, voice: str, base_url: str = "http://localhost:888
"input": text, "input": text,
"voice": voice, "voice": voice,
"speed": 1.0, "speed": 1.0,
"response_format": "wav" # Use WAV for analysis "response_format": "wav", # Use WAV for analysis
} },
) )
if response.status_code != 200: if response.status_code != 200:
error = response.json().get("detail", {}).get("message", response.text) error = response.json().get("detail", {}).get("message", response.text)
print(f"Error generating speech: {error}") print(f"Error generating speech: {error}")
return False return False
# Save the audio # Save the audio
os.makedirs(os.path.dirname(output_file) if os.path.dirname(output_file) else ".", exist_ok=True) os.makedirs(
os.path.dirname(output_file) if os.path.dirname(output_file) else ".",
exist_ok=True,
)
with open(output_file, "wb") as f: with open(output_file, "wb") as f:
f.write(response.content) f.write(response.content)
print(f"Saved audio to {output_file}") print(f"Saved audio to {output_file}")
return True return True
except Exception as e: except Exception as e:
print(f"Error: {e}") print(f"Error: {e}")
return False return False
@ -88,57 +98,57 @@ def generate_speech(text: str, voice: str, base_url: str = "http://localhost:888
def analyze_audio(filepath: str) -> Tuple[np.ndarray, int, dict]: def analyze_audio(filepath: str) -> Tuple[np.ndarray, int, dict]:
"""Analyze audio file and return samples, sample rate, and audio characteristics. """Analyze audio file and return samples, sample rate, and audio characteristics.
Args: Args:
filepath: Path to audio file filepath: Path to audio file
Returns: Returns:
Tuple of (samples, sample_rate, characteristics) Tuple of (samples, sample_rate, characteristics)
""" """
sample_rate, samples = wavfile.read(filepath) sample_rate, samples = wavfile.read(filepath)
# Convert to mono if stereo # Convert to mono if stereo
if len(samples.shape) > 1: if len(samples.shape) > 1:
samples = np.mean(samples, axis=1) samples = np.mean(samples, axis=1)
# Calculate basic stats # Calculate basic stats
max_amp = np.max(np.abs(samples)) max_amp = np.max(np.abs(samples))
rms = np.sqrt(np.mean(samples**2)) rms = np.sqrt(np.mean(samples**2))
duration = len(samples) / sample_rate duration = len(samples) / sample_rate
# Zero crossing rate (helps identify voice characteristics) # Zero crossing rate (helps identify voice characteristics)
zero_crossings = np.sum(np.abs(np.diff(np.signbit(samples)))) / len(samples) zero_crossings = np.sum(np.abs(np.diff(np.signbit(samples)))) / len(samples)
# Simple frequency analysis # Simple frequency analysis
if len(samples) > 0: if len(samples) > 0:
# Use FFT to get frequency components # Use FFT to get frequency components
fft_result = np.fft.fft(samples) fft_result = np.fft.fft(samples)
freqs = np.fft.fftfreq(len(samples), 1/sample_rate) freqs = np.fft.fftfreq(len(samples), 1 / sample_rate)
# Get positive frequencies only # Get positive frequencies only
pos_mask = freqs > 0 pos_mask = freqs > 0
freqs = freqs[pos_mask] freqs = freqs[pos_mask]
magnitudes = np.abs(fft_result)[pos_mask] magnitudes = np.abs(fft_result)[pos_mask]
# Find dominant frequencies (top 3) # Find dominant frequencies (top 3)
top_indices = np.argsort(magnitudes)[-3:] top_indices = np.argsort(magnitudes)[-3:]
dominant_freqs = freqs[top_indices] dominant_freqs = freqs[top_indices]
# Calculate spectral centroid (brightness of sound) # Calculate spectral centroid (brightness of sound)
spectral_centroid = np.sum(freqs * magnitudes) / np.sum(magnitudes) spectral_centroid = np.sum(freqs * magnitudes) / np.sum(magnitudes)
else: else:
dominant_freqs = [] dominant_freqs = []
spectral_centroid = 0 spectral_centroid = 0
characteristics = { characteristics = {
"max_amplitude": max_amp, "max_amplitude": max_amp,
"rms": rms, "rms": rms,
"duration": duration, "duration": duration,
"zero_crossing_rate": zero_crossings, "zero_crossing_rate": zero_crossings,
"dominant_frequencies": dominant_freqs, "dominant_frequencies": dominant_freqs,
"spectral_centroid": spectral_centroid "spectral_centroid": spectral_centroid,
} }
return samples, sample_rate, characteristics return samples, sample_rate, characteristics
@ -167,112 +177,136 @@ def setup_plot(fig, ax, title):
return fig, ax return fig, ax
def plot_analysis(audio_files: Dict[str, str], output_dir: str): def plot_analysis(audio_files: Dict[str, str], output_dir: str):
"""Plot comprehensive voice analysis including waveforms and metrics comparison. """Plot comprehensive voice analysis including waveforms and metrics comparison.
Args: Args:
audio_files: Dictionary of label -> filepath audio_files: Dictionary of label -> filepath
output_dir: Directory to save plot files output_dir: Directory to save plot files
""" """
# Set dark style # Set dark style
plt.style.use('dark_background') plt.style.use("dark_background")
# Create figure with subplots # Create figure with subplots
fig = plt.figure(figsize=(15, 15)) fig = plt.figure(figsize=(15, 15))
fig.patch.set_facecolor("#1a1a2e") fig.patch.set_facecolor("#1a1a2e")
num_files = len(audio_files) num_files = len(audio_files)
# Create subplot grid with proper spacing # Create subplot grid with proper spacing
gs = plt.GridSpec(num_files + 1, 2, height_ratios=[1.5]*num_files + [1], gs = plt.GridSpec(
hspace=0.4, wspace=0.3) num_files + 1, 2, height_ratios=[1.5] * num_files + [1], hspace=0.4, wspace=0.3
)
# Analyze all files first # Analyze all files first
all_chars = {} all_chars = {}
for i, (label, filepath) in enumerate(audio_files.items()): for i, (label, filepath) in enumerate(audio_files.items()):
samples, sample_rate, chars = analyze_audio(filepath) samples, sample_rate, chars = analyze_audio(filepath)
all_chars[label] = chars all_chars[label] = chars
# Plot waveform spanning both columns # Plot waveform spanning both columns
ax = plt.subplot(gs[i, :]) ax = plt.subplot(gs[i, :])
time = np.arange(len(samples)) / sample_rate time = np.arange(len(samples)) / sample_rate
plt.plot(time, samples / chars['max_amplitude'], linewidth=0.5, color="#ff2a6d") plt.plot(time, samples / chars["max_amplitude"], linewidth=0.5, color="#ff2a6d")
ax.set_xlabel("Time (seconds)") ax.set_xlabel("Time (seconds)")
ax.set_ylabel("Normalized Amplitude") ax.set_ylabel("Normalized Amplitude")
ax.set_ylim(-1.1, 1.1) ax.set_ylim(-1.1, 1.1)
setup_plot(fig, ax, f"Waveform: {label}") setup_plot(fig, ax, f"Waveform: {label}")
# Colors for voices # Colors for voices
colors = ["#ff2a6d", "#05d9e8", "#d1f7ff"] colors = ["#ff2a6d", "#05d9e8", "#d1f7ff"]
# Create two subplots for metrics with similar scales # Create two subplots for metrics with similar scales
# Left subplot: Brightness and Volume # Left subplot: Brightness and Volume
ax1 = plt.subplot(gs[num_files, 0]) ax1 = plt.subplot(gs[num_files, 0])
metrics1 = [ metrics1 = [
('Brightness', [chars['spectral_centroid']/1000 for chars in all_chars.values()], 'kHz'), (
('Volume', [chars['rms']*100 for chars in all_chars.values()], 'RMS×100') "Brightness",
[chars["spectral_centroid"] / 1000 for chars in all_chars.values()],
"kHz",
),
("Volume", [chars["rms"] * 100 for chars in all_chars.values()], "RMS×100"),
] ]
# Right subplot: Voice Pitch and Texture # Right subplot: Voice Pitch and Texture
ax2 = plt.subplot(gs[num_files, 1]) ax2 = plt.subplot(gs[num_files, 1])
metrics2 = [ metrics2 = [
('Voice Pitch', [min(chars['dominant_frequencies']) for chars in all_chars.values()], 'Hz'), (
('Texture', [chars['zero_crossing_rate']*1000 for chars in all_chars.values()], 'ZCR×1000') "Voice Pitch",
[min(chars["dominant_frequencies"]) for chars in all_chars.values()],
"Hz",
),
(
"Texture",
[chars["zero_crossing_rate"] * 1000 for chars in all_chars.values()],
"ZCR×1000",
),
] ]
def plot_grouped_bars(ax, metrics, show_legend=True): def plot_grouped_bars(ax, metrics, show_legend=True):
n_groups = len(metrics) n_groups = len(metrics)
n_voices = len(audio_files) n_voices = len(audio_files)
bar_width = 0.25 bar_width = 0.25
indices = np.arange(n_groups) indices = np.arange(n_groups)
# Get max value for y-axis scaling # Get max value for y-axis scaling
max_val = max(max(m[1]) for m in metrics) max_val = max(max(m[1]) for m in metrics)
for i, (voice, color) in enumerate(zip(audio_files.keys(), colors)): for i, (voice, color) in enumerate(zip(audio_files.keys(), colors)):
values = [m[1][i] for m in metrics] values = [m[1][i] for m in metrics]
offset = (i - n_voices/2 + 0.5) * bar_width offset = (i - n_voices / 2 + 0.5) * bar_width
bars = ax.bar(indices + offset, values, bar_width, bars = ax.bar(
label=voice, color=color, alpha=0.8) indices + offset, values, bar_width, label=voice, color=color, alpha=0.8
)
# Add value labels on top of bars # Add value labels on top of bars
for bar in bars: for bar in bars:
height = bar.get_height() height = bar.get_height()
ax.text(bar.get_x() + bar.get_width()/2., height, ax.text(
f'{height:.1f}', bar.get_x() + bar.get_width() / 2.0,
ha='center', va='bottom', color='white', height,
fontsize=10) f"{height:.1f}",
ha="center",
va="bottom",
color="white",
fontsize=10,
)
ax.set_xticks(indices) ax.set_xticks(indices)
ax.set_xticklabels([f"{m[0]}\n({m[2]})" for m in metrics]) ax.set_xticklabels([f"{m[0]}\n({m[2]})" for m in metrics])
# Set y-axis limits with some padding # Set y-axis limits with some padding
ax.set_ylim(0, max_val * 1.2) ax.set_ylim(0, max_val * 1.2)
if show_legend: if show_legend:
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', ax.legend(
facecolor="#1a1a2e", edgecolor="#ffffff") bbox_to_anchor=(1.05, 1),
loc="upper left",
facecolor="#1a1a2e",
edgecolor="#ffffff",
)
# Plot both subplots # Plot both subplots
plot_grouped_bars(ax1, metrics1, show_legend=True) plot_grouped_bars(ax1, metrics1, show_legend=True)
plot_grouped_bars(ax2, metrics2, show_legend=False) plot_grouped_bars(ax2, metrics2, show_legend=False)
# Style both subplots # Style both subplots
setup_plot(fig, ax1, 'Brightness and Volume') setup_plot(fig, ax1, "Brightness and Volume")
setup_plot(fig, ax2, 'Voice Pitch and Texture') setup_plot(fig, ax2, "Voice Pitch and Texture")
# Add y-axis labels # Add y-axis labels
ax1.set_ylabel('Value') ax1.set_ylabel("Value")
ax2.set_ylabel('Value') ax2.set_ylabel("Value")
# Adjust the figure size to accommodate the legend # Adjust the figure size to accommodate the legend
fig.set_size_inches(15, 15) fig.set_size_inches(15, 15)
# Add padding around the entire figure # Add padding around the entire figure
plt.subplots_adjust(right=0.85, top=0.95, bottom=0.05, left=0.1) plt.subplots_adjust(right=0.85, top=0.95, bottom=0.05, left=0.1)
plt.savefig(os.path.join(output_dir, "analysis_comparison.png"), dpi=300) plt.savefig(os.path.join(output_dir, "analysis_comparison.png"), dpi=300)
print(f"Saved analysis comparison to {output_dir}/analysis_comparison.png") print(f"Saved analysis comparison to {output_dir}/analysis_comparison.png")
# Print detailed comparative analysis # Print detailed comparative analysis
print("\nDetailed Voice Analysis:") print("\nDetailed Voice Analysis:")
for label, chars in all_chars.items(): for label, chars in all_chars.items():
@ -282,44 +316,57 @@ def plot_analysis(audio_files: Dict[str, str], output_dir: str):
print(f" Duration: {chars['duration']:.2f}s") print(f" Duration: {chars['duration']:.2f}s")
print(f" Zero Crossing Rate: {chars['zero_crossing_rate']:.3f}") print(f" Zero Crossing Rate: {chars['zero_crossing_rate']:.3f}")
print(f" Spectral Centroid: {chars['spectral_centroid']:.0f}Hz") print(f" Spectral Centroid: {chars['spectral_centroid']:.0f}Hz")
print(f" Dominant Frequencies: {', '.join(f'{f:.0f}Hz' for f in chars['dominant_frequencies'])}") print(
f" Dominant Frequencies: {', '.join(f'{f:.0f}Hz' for f in chars['dominant_frequencies'])}"
)
def main(): def main():
parser = argparse.ArgumentParser(description="Kokoro Voice Analysis Demo") parser = argparse.ArgumentParser(description="Kokoro Voice Analysis Demo")
parser.add_argument("--voices", nargs="+", type=str, help="Voices to combine") parser.add_argument("--voices", nargs="+", type=str, help="Voices to combine")
parser.add_argument("--text", type=str, default="Hello! This is a test of combined voices.", help="Text to speak") parser.add_argument(
"--text",
type=str,
default="Hello! This is a test of combined voices.",
help="Text to speak",
)
parser.add_argument("--url", default="http://localhost:8880", help="API base URL") parser.add_argument("--url", default="http://localhost:8880", help="API base URL")
parser.add_argument("--output-dir", default="examples/output", help="Output directory for audio files") parser.add_argument(
"--output-dir",
default="examples/output",
help="Output directory for audio files",
)
args = parser.parse_args() args = parser.parse_args()
if not args.voices: if not args.voices:
print("No voices provided, using default test voices") print("No voices provided, using default test voices")
args.voices = ["af_bella", "af_nicole"] args.voices = ["af_bella", "af_nicole"]
# Create output directory # Create output directory
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
# Dictionary to store audio files for analysis # Dictionary to store audio files for analysis
audio_files = {} audio_files = {}
# Generate speech with individual voices # Generate speech with individual voices
print("Generating speech with individual voices...") print("Generating speech with individual voices...")
for voice in args.voices: for voice in args.voices:
output_file = os.path.join(args.output_dir, f"analysis_{voice}.wav") output_file = os.path.join(args.output_dir, f"analysis_{voice}.wav")
if generate_speech(args.text, voice, args.url, output_file): if generate_speech(args.text, voice, args.url, output_file):
audio_files[voice] = output_file audio_files[voice] = output_file
# Generate speech with combined voice # Generate speech with combined voice
print(f"\nCombining voices: {', '.join(args.voices)}") print(f"\nCombining voices: {', '.join(args.voices)}")
combined_voice = submit_combine_voices(args.voices, args.url) combined_voice = submit_combine_voices(args.voices, args.url)
if combined_voice: if combined_voice:
print(f"Successfully created combined voice: {combined_voice}") print(f"Successfully created combined voice: {combined_voice}")
output_file = os.path.join(args.output_dir, f"analysis_combined_{combined_voice}.wav") output_file = os.path.join(
args.output_dir, f"analysis_combined_{combined_voice}.wav"
)
if generate_speech(args.text, combined_voice, args.url, output_file): if generate_speech(args.text, combined_voice, args.url, output_file):
audio_files["combined"] = output_file audio_files["combined"] = output_file
# Generate comparison plots # Generate comparison plots
plot_analysis(audio_files, args.output_dir) plot_analysis(audio_files, args.output_dir)
else: else:

View file

@ -60,7 +60,7 @@ def test_speed(speed: float):
# Test different formats # Test different formats
for format in ["wav", "mp3", "opus", "aac", "flac", "pcm"]: for format in ["wav", "mp3", "opus", "aac", "flac", "pcm"]:
test_format(format) # aac and pcm should fail as they are not supported test_format(format) # aac and pcm should fail as they are not supported
# Test different speeds # Test different speeds
for speed in [0.25, 1.0, 2.0, 4.0]: # 5.0 should fail as it's out of range for speed in [0.25, 1.0, 2.0, 4.0]: # 5.0 should fail as it's out of range

View file

@ -10,3 +10,5 @@ sqlalchemy==2.0.27
pytest==8.0.0 pytest==8.0.0
httpx==0.26.0 httpx==0.26.0
pytest-asyncio==0.23.5 pytest-asyncio==0.23.5
pytest-cov==6.0.0
gradio==4.19.2

View file

@ -2,8 +2,4 @@ from lib.interface import create_interface
if __name__ == "__main__": if __name__ == "__main__":
demo = create_interface() demo = create_interface()
demo.launch( demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
server_name="0.0.0.0",
server_port=7860,
show_error=True
)

View file

@ -1,16 +1,19 @@
import requests
from typing import Tuple, List, Optional
import os import os
import datetime import datetime
from typing import List, Tuple, Optional
import requests
from .config import API_URL, OUTPUTS_DIR from .config import API_URL, OUTPUTS_DIR
def check_api_status() -> Tuple[bool, List[str]]: def check_api_status() -> Tuple[bool, List[str]]:
"""Check TTS service status and get available voices.""" """Check TTS service status and get available voices."""
try: try:
# Use a longer timeout during startup # Use a longer timeout during startup
response = requests.get( response = requests.get(
f"{API_URL}/v1/audio/voices", f"{API_URL}/v1/audio/voices",
timeout=30 # Increased timeout for initial startup period timeout=30, # Increased timeout for initial startup period
) )
response.raise_for_status() response.raise_for_status()
voices = response.json().get("voices", []) voices = response.json().get("voices", [])
@ -31,16 +34,19 @@ def check_api_status() -> Tuple[bool, List[str]]:
print(f"Unexpected error checking API status: {str(e)}") print(f"Unexpected error checking API status: {str(e)}")
return False, [] return False, []
def text_to_speech(text: str, voice_id: str, format: str, speed: float) -> Optional[str]:
def text_to_speech(
text: str, voice_id: str, format: str, speed: float
) -> Optional[str]:
"""Generate speech from text using TTS API.""" """Generate speech from text using TTS API."""
if not text.strip(): if not text.strip():
return None return None
# Create output filename # Create output filename
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
output_filename = f"output_{timestamp}_voice-{voice_id}_speed-{speed}.{format}" output_filename = f"output_{timestamp}_voice-{voice_id}_speed-{speed}.{format}"
output_path = os.path.join(OUTPUTS_DIR, output_filename) output_path = os.path.join(OUTPUTS_DIR, output_filename)
try: try:
response = requests.post( response = requests.post(
f"{API_URL}/v1/audio/speech", f"{API_URL}/v1/audio/speech",
@ -49,17 +55,17 @@ def text_to_speech(text: str, voice_id: str, format: str, speed: float) -> Optio
"input": text, "input": text,
"voice": voice_id, "voice": voice_id,
"response_format": format, "response_format": format,
"speed": float(speed) "speed": float(speed),
}, },
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
timeout=300 # Longer timeout for speech generation timeout=300, # Longer timeout for speech generation
) )
response.raise_for_status() response.raise_for_status()
with open(output_path, "wb") as f: with open(output_path, "wb") as f:
f.write(response.content) f.write(response.content)
return output_path return output_path
except requests.exceptions.Timeout: except requests.exceptions.Timeout:
print("Speech generation request timed out") print("Speech generation request timed out")
return None return None
@ -70,6 +76,7 @@ def text_to_speech(text: str, voice_id: str, format: str, speed: float) -> Optio
print(f"Unexpected error generating speech: {str(e)}") print(f"Unexpected error generating speech: {str(e)}")
return None return None
def get_status_html(is_available: bool) -> str: def get_status_html(is_available: bool) -> str:
"""Generate HTML for status indicator.""" """Generate HTML for status indicator."""
color = "green" if is_available else "red" color = "green" if is_available else "red"

View file

@ -1,7 +1,10 @@
import gradio as gr
from typing import Tuple from typing import Tuple
import gradio as gr
from .. import files from .. import files
def create_input_column() -> Tuple[gr.Column, dict]: def create_input_column() -> Tuple[gr.Column, dict]:
"""Create the input column with text input and file handling.""" """Create the input column with text input and file handling."""
with gr.Column(scale=1) as col: with gr.Column(scale=1) as col:
@ -11,49 +14,36 @@ def create_input_column() -> Tuple[gr.Column, dict]:
# Direct Input Tab # Direct Input Tab
with gr.TabItem("Direct Input"): with gr.TabItem("Direct Input"):
text_input = gr.Textbox( text_input = gr.Textbox(
label="Text to speak", label="Text to speak", placeholder="Enter text here...", lines=4
placeholder="Enter text here...",
lines=4
) )
text_submit = gr.Button( text_submit = gr.Button("Generate Speech", variant="primary", size="lg")
"Generate Speech",
variant="primary",
size="lg"
)
# File Input Tab # File Input Tab
with gr.TabItem("From File"): with gr.TabItem("From File"):
# Existing files dropdown # Existing files dropdown
input_files_list = gr.Dropdown( input_files_list = gr.Dropdown(
label="Select Existing File", label="Select Existing File",
choices=files.list_input_files(), choices=files.list_input_files(),
value=None value=None,
) )
# Simple file upload # Simple file upload
file_upload = gr.File( file_upload = gr.File(
label="Upload Text File (.txt)", label="Upload Text File (.txt)", file_types=[".txt"]
file_types=[".txt"]
) )
file_preview = gr.Textbox( file_preview = gr.Textbox(
label="File Content Preview", label="File Content Preview", interactive=False, lines=4
interactive=False,
lines=4
) )
with gr.Row(): with gr.Row():
file_submit = gr.Button( file_submit = gr.Button(
"Generate Speech", "Generate Speech", variant="primary", size="lg"
variant="primary",
size="lg"
) )
clear_files = gr.Button( clear_files = gr.Button(
"Clear Files", "Clear Files", variant="secondary", size="lg"
variant="secondary",
size="lg"
) )
components = { components = {
"tabs": tabs, "tabs": tabs,
"text_input": text_input, "text_input": text_input,
@ -62,7 +52,7 @@ def create_input_column() -> Tuple[gr.Column, dict]:
"file_preview": file_preview, "file_preview": file_preview,
"text_submit": text_submit, "text_submit": text_submit,
"file_submit": file_submit, "file_submit": file_submit,
"clear_files": clear_files "clear_files": clear_files,
} }
return col, components return col, components

View file

@ -1,45 +1,41 @@
import gradio as gr
from typing import Tuple, Optional from typing import Tuple, Optional
import gradio as gr
from .. import api, config from .. import api, config
def create_model_column(voice_ids: Optional[list] = None) -> Tuple[gr.Column, dict]: def create_model_column(voice_ids: Optional[list] = None) -> Tuple[gr.Column, dict]:
"""Create the model settings column.""" """Create the model settings column."""
if voice_ids is None: if voice_ids is None:
voice_ids = [] voice_ids = []
with gr.Column(scale=1) as col: with gr.Column(scale=1) as col:
gr.Markdown("### Model Settings") gr.Markdown("### Model Settings")
# Status button starts in waiting state # Status button starts in waiting state
status_btn = gr.Button( status_btn = gr.Button(
"⌛ TTS Service: Waiting for Service...", "⌛ TTS Service: Waiting for Service...", variant="secondary"
variant="secondary"
) )
voice_input = gr.Dropdown( voice_input = gr.Dropdown(
choices=voice_ids, choices=voice_ids,
label="Voice", label="Voice",
value=voice_ids[0] if voice_ids else None, value=voice_ids[0] if voice_ids else None,
interactive=True interactive=True,
) )
format_input = gr.Dropdown( format_input = gr.Dropdown(
choices=config.AUDIO_FORMATS, choices=config.AUDIO_FORMATS, label="Audio Format", value="mp3"
label="Audio Format",
value="mp3"
) )
speed_input = gr.Slider( speed_input = gr.Slider(
minimum=0.5, minimum=0.5, maximum=2.0, value=1.0, step=0.1, label="Speed"
maximum=2.0,
value=1.0,
step=0.1,
label="Speed"
) )
components = { components = {
"status_btn": status_btn, "status_btn": status_btn,
"voice": voice_input, "voice": voice_input,
"format": format_input, "format": format_input,
"speed": speed_input "speed": speed_input,
} }
return col, components return col, components

View file

@ -1,40 +1,42 @@
import gradio as gr
from typing import Tuple from typing import Tuple
import gradio as gr
from .. import files from .. import files
def create_output_column() -> Tuple[gr.Column, dict]: def create_output_column() -> Tuple[gr.Column, dict]:
"""Create the output column with audio player and file list.""" """Create the output column with audio player and file list."""
with gr.Column(scale=1) as col: with gr.Column(scale=1) as col:
gr.Markdown("### Latest Output") gr.Markdown("### Latest Output")
audio_output = gr.Audio( audio_output = gr.Audio(label="Generated Speech", type="filepath")
label="Generated Speech",
type="filepath"
)
gr.Markdown("### Generated Files") gr.Markdown("### Generated Files")
output_files = gr.Dropdown( output_files = gr.Dropdown(
label="Previous Outputs", label="Previous Outputs",
choices=files.list_output_files(), choices=files.list_output_files(),
value=None, value=None,
allow_custom_value=False allow_custom_value=False,
) )
play_btn = gr.Button("▶️ Play Selected", size="sm") play_btn = gr.Button("▶️ Play Selected", size="sm")
selected_audio = gr.Audio( selected_audio = gr.Audio(
label="Selected Output", label="Selected Output", type="filepath", visible=False
type="filepath",
visible=False
) )
clear_outputs = gr.Button("⚠️ Delete All Previously Generated Output Audio 🗑️", size="sm", variant="secondary") clear_outputs = gr.Button(
"⚠️ Delete All Previously Generated Output Audio 🗑️",
size="sm",
variant="secondary",
)
components = { components = {
"audio_output": audio_output, "audio_output": audio_output,
"output_files": output_files, "output_files": output_files,
"play_btn": play_btn, "play_btn": play_btn,
"selected_audio": selected_audio, "selected_audio": selected_audio,
"clear_outputs": clear_outputs "clear_outputs": clear_outputs,
} }
return col, components return col, components

View file

@ -1,17 +1,23 @@
import os import os
from typing import List, Optional, Tuple
import datetime import datetime
from typing import List, Tuple, Optional
from .config import INPUTS_DIR, OUTPUTS_DIR, AUDIO_FORMATS from .config import INPUTS_DIR, OUTPUTS_DIR, AUDIO_FORMATS
def list_input_files() -> List[str]: def list_input_files() -> List[str]:
"""List all input text files.""" """List all input text files."""
return [f for f in os.listdir(INPUTS_DIR) if f.endswith('.txt')] return [f for f in os.listdir(INPUTS_DIR) if f.endswith(".txt")]
def list_output_files() -> List[str]: def list_output_files() -> List[str]:
"""List all output audio files.""" """List all output audio files."""
return [os.path.join(OUTPUTS_DIR, f) return [
for f in os.listdir(OUTPUTS_DIR) os.path.join(OUTPUTS_DIR, f)
if any(f.endswith(ext) for ext in AUDIO_FORMATS)] for f in os.listdir(OUTPUTS_DIR)
if any(f.endswith(ext) for ext in AUDIO_FORMATS)
]
def read_text_file(filename: str) -> str: def read_text_file(filename: str) -> str:
"""Read content of a text file.""" """Read content of a text file."""
@ -19,16 +25,17 @@ def read_text_file(filename: str) -> str:
return "" return ""
try: try:
file_path = os.path.join(INPUTS_DIR, filename) file_path = os.path.join(INPUTS_DIR, filename)
with open(file_path, 'r', encoding='utf-8') as f: with open(file_path, "r", encoding="utf-8") as f:
return f.read() return f.read()
except: except:
return "" return ""
def save_text(text: str, filename: Optional[str] = None) -> Optional[str]: def save_text(text: str, filename: Optional[str] = None) -> Optional[str]:
"""Save text to a file. Returns the filename if successful.""" """Save text to a file. Returns the filename if successful."""
if not text.strip(): if not text.strip():
return None return None
if filename is None: if filename is None:
# Use input_1.txt, input_2.txt, etc. # Use input_1.txt, input_2.txt, etc.
base = "input" base = "input"
@ -41,12 +48,12 @@ def save_text(text: str, filename: Optional[str] = None) -> Optional[str]:
else: else:
# Handle duplicate filenames by adding _1, _2, etc. # Handle duplicate filenames by adding _1, _2, etc.
base = os.path.splitext(filename)[0] base = os.path.splitext(filename)[0]
ext = os.path.splitext(filename)[1] or '.txt' ext = os.path.splitext(filename)[1] or ".txt"
counter = 1 counter = 1
while os.path.exists(os.path.join(INPUTS_DIR, filename)): while os.path.exists(os.path.join(INPUTS_DIR, filename)):
filename = f"{base}_{counter}{ext}" filename = f"{base}_{counter}{ext}"
counter += 1 counter += 1
filepath = os.path.join(INPUTS_DIR, filename) filepath = os.path.join(INPUTS_DIR, filename)
try: try:
with open(filepath, "w", encoding="utf-8") as f: with open(filepath, "w", encoding="utf-8") as f:
@ -56,11 +63,12 @@ def save_text(text: str, filename: Optional[str] = None) -> Optional[str]:
print(f"Error saving file: {e}") print(f"Error saving file: {e}")
return None return None
def delete_all_input_files() -> bool: def delete_all_input_files() -> bool:
"""Delete all files from the inputs directory. Returns True if successful.""" """Delete all files from the inputs directory. Returns True if successful."""
try: try:
for filename in os.listdir(INPUTS_DIR): for filename in os.listdir(INPUTS_DIR):
if filename.endswith('.txt'): if filename.endswith(".txt"):
file_path = os.path.join(INPUTS_DIR, filename) file_path = os.path.join(INPUTS_DIR, filename)
os.remove(file_path) os.remove(file_path)
return True return True
@ -68,6 +76,7 @@ def delete_all_input_files() -> bool:
print(f"Error deleting input files: {e}") print(f"Error deleting input files: {e}")
return False return False
def delete_all_output_files() -> bool: def delete_all_output_files() -> bool:
"""Delete all audio files from the outputs directory. Returns True if successful.""" """Delete all audio files from the outputs directory. Returns True if successful."""
try: try:
@ -80,19 +89,20 @@ def delete_all_output_files() -> bool:
print(f"Error deleting output files: {e}") print(f"Error deleting output files: {e}")
return False return False
def process_uploaded_file(file_path: str) -> bool: def process_uploaded_file(file_path: str) -> bool:
"""Save uploaded file to inputs directory. Returns True if successful.""" """Save uploaded file to inputs directory. Returns True if successful."""
if not file_path: if not file_path:
return False return False
try: try:
filename = os.path.basename(file_path) filename = os.path.basename(file_path)
if not filename.endswith('.txt'): if not filename.endswith(".txt"):
return False return False
# Create target path in inputs directory # Create target path in inputs directory
target_path = os.path.join(INPUTS_DIR, filename) target_path = os.path.join(INPUTS_DIR, filename)
# If file exists, add number suffix # If file exists, add number suffix
base, ext = os.path.splitext(filename) base, ext = os.path.splitext(filename)
counter = 1 counter = 1
@ -100,12 +110,13 @@ def process_uploaded_file(file_path: str) -> bool:
new_name = f"{base}_{counter}{ext}" new_name = f"{base}_{counter}{ext}"
target_path = os.path.join(INPUTS_DIR, new_name) target_path = os.path.join(INPUTS_DIR, new_name)
counter += 1 counter += 1
# Copy file to inputs directory # Copy file to inputs directory
import shutil import shutil
shutil.copy2(file_path, target_path) shutil.copy2(file_path, target_path)
return True return True
except Exception as e: except Exception as e:
print(f"Error saving uploaded file: {e}") print(f"Error saving uploaded file: {e}")
return False return False

View file

@ -1,16 +1,19 @@
import gradio as gr
import os import os
import shutil import shutil
import gradio as gr
from . import api, files from . import api, files
def setup_event_handlers(components: dict): def setup_event_handlers(components: dict):
"""Set up all event handlers for the UI components.""" """Set up all event handlers for the UI components."""
def refresh_status(): def refresh_status():
try: try:
is_available, voices = api.check_api_status() is_available, voices = api.check_api_status()
status = "Available" if is_available else "Waiting for Service..." status = "Available" if is_available else "Waiting for Service..."
if is_available and voices: if is_available and voices:
# Preserve current voice selection if it exists and is still valid # Preserve current voice selection if it exists and is still valid
current_voice = components["model"]["voice"].value current_voice = components["model"]["voice"].value
@ -19,17 +22,17 @@ def setup_event_handlers(components: dict):
gr.update( gr.update(
value=f"🔄 TTS Service: {status}", value=f"🔄 TTS Service: {status}",
interactive=True, interactive=True,
variant="secondary" variant="secondary",
), ),
gr.update(choices=voices, value=default_voice) gr.update(choices=voices, value=default_voice),
] ]
return [ return [
gr.update( gr.update(
value=f"⌛ TTS Service: {status}", value=f"⌛ TTS Service: {status}",
interactive=True, interactive=True,
variant="secondary" variant="secondary",
), ),
gr.update(choices=[], value=None) gr.update(choices=[], value=None),
] ]
except Exception as e: except Exception as e:
print(f"Error in refresh status: {str(e)}") print(f"Error in refresh status: {str(e)}")
@ -37,11 +40,11 @@ def setup_event_handlers(components: dict):
gr.update( gr.update(
value="❌ TTS Service: Connection Error", value="❌ TTS Service: Connection Error",
interactive=True, interactive=True,
variant="secondary" variant="secondary",
), ),
gr.update(choices=[], value=None) gr.update(choices=[], value=None),
] ]
def handle_file_select(filename): def handle_file_select(filename):
if filename: if filename:
try: try:
@ -52,16 +55,16 @@ def setup_event_handlers(components: dict):
except Exception as e: except Exception as e:
print(f"Error reading file: {e}") print(f"Error reading file: {e}")
return gr.update(value="") return gr.update(value="")
def handle_file_upload(file): def handle_file_upload(file):
if file is None: if file is None:
return gr.update(choices=files.list_input_files()) return gr.update(choices=files.list_input_files())
try: try:
# Copy file to inputs directory # Copy file to inputs directory
filename = os.path.basename(file.name) filename = os.path.basename(file.name)
target_path = os.path.join(files.INPUTS_DIR, filename) target_path = os.path.join(files.INPUTS_DIR, filename)
# Handle duplicate filenames # Handle duplicate filenames
base, ext = os.path.splitext(filename) base, ext = os.path.splitext(filename)
counter = 1 counter = 1
@ -69,43 +72,36 @@ def setup_event_handlers(components: dict):
new_name = f"{base}_{counter}{ext}" new_name = f"{base}_{counter}{ext}"
target_path = os.path.join(files.INPUTS_DIR, new_name) target_path = os.path.join(files.INPUTS_DIR, new_name)
counter += 1 counter += 1
shutil.copy2(file.name, target_path) shutil.copy2(file.name, target_path)
except Exception as e: except Exception as e:
print(f"Error uploading file: {e}") print(f"Error uploading file: {e}")
return gr.update(choices=files.list_input_files()) return gr.update(choices=files.list_input_files())
def generate_from_text(text, voice, format, speed): def generate_from_text(text, voice, format, speed):
"""Generate speech from direct text input""" """Generate speech from direct text input"""
is_available, _ = api.check_api_status() is_available, _ = api.check_api_status()
if not is_available: if not is_available:
gr.Warning("TTS Service is currently unavailable") gr.Warning("TTS Service is currently unavailable")
return [ return [None, gr.update(choices=files.list_output_files())]
None,
gr.update(choices=files.list_output_files())
]
if not text or not text.strip(): if not text or not text.strip():
gr.Warning("Please enter text in the input box") gr.Warning("Please enter text in the input box")
return [ return [None, gr.update(choices=files.list_output_files())]
None,
gr.update(choices=files.list_output_files())
]
files.save_text(text) files.save_text(text)
result = api.text_to_speech(text, voice, format, speed) result = api.text_to_speech(text, voice, format, speed)
if result is None: if result is None:
gr.Warning("Failed to generate speech. Please try again.") gr.Warning("Failed to generate speech. Please try again.")
return [ return [None, gr.update(choices=files.list_output_files())]
None,
gr.update(choices=files.list_output_files())
]
return [ return [
result, result,
gr.update(choices=files.list_output_files(), value=os.path.basename(result)) gr.update(
choices=files.list_output_files(), value=os.path.basename(result)
),
] ]
def generate_from_file(selected_file, voice, format, speed): def generate_from_file(selected_file, voice, format, speed):
@ -113,37 +109,30 @@ def setup_event_handlers(components: dict):
is_available, _ = api.check_api_status() is_available, _ = api.check_api_status()
if not is_available: if not is_available:
gr.Warning("TTS Service is currently unavailable") gr.Warning("TTS Service is currently unavailable")
return [ return [None, gr.update(choices=files.list_output_files())]
None,
gr.update(choices=files.list_output_files())
]
if not selected_file: if not selected_file:
gr.Warning("Please select a file") gr.Warning("Please select a file")
return [ return [None, gr.update(choices=files.list_output_files())]
None,
gr.update(choices=files.list_output_files())
]
text = files.read_text_file(selected_file) text = files.read_text_file(selected_file)
result = api.text_to_speech(text, voice, format, speed) result = api.text_to_speech(text, voice, format, speed)
if result is None: if result is None:
gr.Warning("Failed to generate speech. Please try again.") gr.Warning("Failed to generate speech. Please try again.")
return [ return [None, gr.update(choices=files.list_output_files())]
None,
gr.update(choices=files.list_output_files())
]
return [ return [
result, result,
gr.update(choices=files.list_output_files(), value=os.path.basename(result)) gr.update(
choices=files.list_output_files(), value=os.path.basename(result)
),
] ]
def play_selected(file_path): def play_selected(file_path):
if file_path and os.path.exists(file_path): if file_path and os.path.exists(file_path):
return gr.update(value=file_path, visible=True) return gr.update(value=file_path, visible=True)
return gr.update(visible=False) return gr.update(visible=False)
def clear_files(voice, format, speed): def clear_files(voice, format, speed):
"""Delete all input files and clear UI components while preserving model settings""" """Delete all input files and clear UI components while preserving model settings"""
files.delete_all_input_files() files.delete_all_input_files()
@ -155,7 +144,7 @@ def setup_event_handlers(components: dict):
gr.update(choices=files.list_output_files()), # output_files gr.update(choices=files.list_output_files()), # output_files
gr.update(value=voice), # voice gr.update(value=voice), # voice
gr.update(value=format), # format gr.update(value=format), # format
gr.update(value=speed) # speed gr.update(value=speed), # speed
] ]
def clear_outputs(): def clear_outputs():
@ -164,43 +153,40 @@ def setup_event_handlers(components: dict):
return [ return [
None, # audio_output None, # audio_output
gr.update(choices=[], value=None), # output_files gr.update(choices=[], value=None), # output_files
gr.update(visible=False) # selected_audio gr.update(visible=False), # selected_audio
] ]
# Connect event handlers # Connect event handlers
components["model"]["status_btn"].click( components["model"]["status_btn"].click(
fn=refresh_status, fn=refresh_status,
outputs=[ outputs=[components["model"]["status_btn"], components["model"]["voice"]],
components["model"]["status_btn"],
components["model"]["voice"]
]
) )
components["input"]["file_select"].change( components["input"]["file_select"].change(
fn=handle_file_select, fn=handle_file_select,
inputs=[components["input"]["file_select"]], inputs=[components["input"]["file_select"]],
outputs=[components["input"]["file_preview"]] outputs=[components["input"]["file_preview"]],
) )
components["input"]["file_upload"].upload( components["input"]["file_upload"].upload(
fn=handle_file_upload, fn=handle_file_upload,
inputs=[components["input"]["file_upload"]], inputs=[components["input"]["file_upload"]],
outputs=[components["input"]["file_select"]] outputs=[components["input"]["file_select"]],
) )
components["output"]["play_btn"].click( components["output"]["play_btn"].click(
fn=play_selected, fn=play_selected,
inputs=[components["output"]["output_files"]], inputs=[components["output"]["output_files"]],
outputs=[components["output"]["selected_audio"]] outputs=[components["output"]["selected_audio"]],
) )
# Connect clear files button # Connect clear files button
components["input"]["clear_files"].click( components["input"]["clear_files"].click(
fn=clear_files, fn=clear_files,
inputs=[ inputs=[
components["model"]["voice"], components["model"]["voice"],
components["model"]["format"], components["model"]["format"],
components["model"]["speed"] components["model"]["speed"],
], ],
outputs=[ outputs=[
components["input"]["file_select"], components["input"]["file_select"],
@ -210,10 +196,10 @@ def setup_event_handlers(components: dict):
components["output"]["output_files"], components["output"]["output_files"],
components["model"]["voice"], components["model"]["voice"],
components["model"]["format"], components["model"]["format"],
components["model"]["speed"] components["model"]["speed"],
] ],
) )
# Connect submit buttons for each tab # Connect submit buttons for each tab
components["input"]["text_submit"].click( components["input"]["text_submit"].click(
fn=generate_from_text, fn=generate_from_text,
@ -221,22 +207,22 @@ def setup_event_handlers(components: dict):
components["input"]["text_input"], components["input"]["text_input"],
components["model"]["voice"], components["model"]["voice"],
components["model"]["format"], components["model"]["format"],
components["model"]["speed"] components["model"]["speed"],
], ],
outputs=[ outputs=[
components["output"]["audio_output"], components["output"]["audio_output"],
components["output"]["output_files"] components["output"]["output_files"],
] ],
) )
# Connect clear outputs button # Connect clear outputs button
components["output"]["clear_outputs"].click( components["output"]["clear_outputs"].click(
fn=clear_outputs, fn=clear_outputs,
outputs=[ outputs=[
components["output"]["audio_output"], components["output"]["audio_output"],
components["output"]["output_files"], components["output"]["output_files"],
components["output"]["selected_audio"] components["output"]["selected_audio"],
] ],
) )
components["input"]["file_submit"].click( components["input"]["file_submit"].click(
@ -245,10 +231,10 @@ def setup_event_handlers(components: dict):
components["input"]["file_select"], components["input"]["file_select"],
components["model"]["voice"], components["model"]["voice"],
components["model"]["format"], components["model"]["format"],
components["model"]["speed"] components["model"]["speed"],
], ],
outputs=[ outputs=[
components["output"]["audio_output"], components["output"]["audio_output"],
components["output"]["output_files"] components["output"]["output_files"],
] ],
) )

View file

@ -1,69 +1,75 @@
import gradio as gr import gradio as gr
from . import api from . import api
from .components import create_input_column, create_model_column, create_output_column
from .handlers import setup_event_handlers from .handlers import setup_event_handlers
from .components import create_input_column, create_model_column, create_output_column
def create_interface(): def create_interface():
"""Create the main Gradio interface.""" """Create the main Gradio interface."""
# Skip initial status check - let the timer handle it # Skip initial status check - let the timer handle it
is_available, available_voices = False, [] is_available, available_voices = False, []
with gr.Blocks( with gr.Blocks(title="Kokoro TTS Demo", theme=gr.themes.Monochrome()) as demo:
title="Kokoro TTS Demo", gr.HTML(
theme=gr.themes.Monochrome() value='<div style="display: flex; gap: 0;">'
) as demo: '<a href="https://huggingface.co/hexgrad/Kokoro-82M" target="_blank" style="color: #2196F3; text-decoration: none; margin: 2px; border: 1px solid #2196F3; padding: 4px 8px; height: 24px; box-sizing: border-box; display: inline-flex; align-items: center;">Kokoro-82M HF Repo</a>'
gr.HTML(value='<div style="display: flex; gap: 0;">' '<a href="https://github.com/remsky/Kokoro-FastAPI" target="_blank" style="color: #2196F3; text-decoration: none; margin: 2px; border: 1px solid #2196F3; padding: 4px 8px; height: 24px; box-sizing: border-box; display: inline-flex; align-items: center;">Kokoro-FastAPI Repo</a>'
'<a href="https://huggingface.co/hexgrad/Kokoro-82M" target="_blank" style="color: #2196F3; text-decoration: none; margin: 2px; border: 1px solid #2196F3; padding: 4px 8px; height: 24px; box-sizing: border-box; display: inline-flex; align-items: center;">Kokoro-82M HF Repo</a>' "</div>",
'<a href="https://github.com/remsky/Kokoro-FastAPI" target="_blank" style="color: #2196F3; text-decoration: none; margin: 2px; border: 1px solid #2196F3; padding: 4px 8px; height: 24px; box-sizing: border-box; display: inline-flex; align-items: center;">Kokoro-FastAPI Repo</a>' show_label=False,
'</div>', show_label=False) )
# Main interface # Main interface
with gr.Row(): with gr.Row():
# Create columns # Create columns
input_col, input_components = create_input_column() input_col, input_components = create_input_column()
model_col, model_components = create_model_column(available_voices) # Pass initial voices model_col, model_components = create_model_column(
available_voices
) # Pass initial voices
output_col, output_components = create_output_column() output_col, output_components = create_output_column()
# Collect all components # Collect all components
components = { components = {
"input": input_components, "input": input_components,
"model": model_components, "model": model_components,
"output": output_components "output": output_components,
} }
# Set up event handlers # Set up event handlers
setup_event_handlers(components) setup_event_handlers(components)
# Add periodic status check with Timer # Add periodic status check with Timer
def update_status(): def update_status():
try: try:
is_available, voices = api.check_api_status() is_available, voices = api.check_api_status()
status = "Available" if is_available else "Waiting for Service..." status = "Available" if is_available else "Waiting for Service..."
if is_available and voices: if is_available and voices:
# Service is available, update UI and stop timer # Service is available, update UI and stop timer
current_voice = components["model"]["voice"].value current_voice = components["model"]["voice"].value
default_voice = current_voice if current_voice in voices else voices[0] default_voice = (
current_voice if current_voice in voices else voices[0]
)
# Return values in same order as outputs list # Return values in same order as outputs list
return [ return [
gr.update( gr.update(
value=f"🔄 TTS Service: {status}", value=f"🔄 TTS Service: {status}",
interactive=True, interactive=True,
variant="secondary" variant="secondary",
), ),
gr.update(choices=voices, value=default_voice), gr.update(choices=voices, value=default_voice),
gr.update(active=False) # Stop timer gr.update(active=False), # Stop timer
] ]
# Service not available yet, keep checking # Service not available yet, keep checking
return [ return [
gr.update( gr.update(
value=f"⌛ TTS Service: {status}", value=f"⌛ TTS Service: {status}",
interactive=True, interactive=True,
variant="secondary" variant="secondary",
), ),
gr.update(choices=[], value=None), gr.update(choices=[], value=None),
gr.update(active=True) gr.update(active=True),
] ]
except Exception as e: except Exception as e:
print(f"Error in status update: {str(e)}") print(f"Error in status update: {str(e)}")
@ -72,20 +78,20 @@ def create_interface():
gr.update( gr.update(
value="❌ TTS Service: Connection Error", value="❌ TTS Service: Connection Error",
interactive=True, interactive=True,
variant="secondary" variant="secondary",
), ),
gr.update(choices=[], value=None), gr.update(choices=[], value=None),
gr.update(active=True) gr.update(active=True),
] ]
timer = gr.Timer(value=5) # Check every 5 seconds timer = gr.Timer(value=5) # Check every 5 seconds
timer.tick( timer.tick(
fn=update_status, fn=update_status,
outputs=[ outputs=[
components["model"]["status_btn"], components["model"]["status_btn"],
components["model"]["voice"], components["model"]["voice"],
timer timer,
] ],
) )
return demo return demo

View file

@ -1,5 +1,5 @@
import pytest
import gradio as gr import gradio as gr
import pytest
@pytest.fixture @pytest.fixture

View file

@ -1,6 +1,8 @@
from unittest.mock import patch, mock_open
import pytest import pytest
import requests import requests
from unittest.mock import patch, mock_open
from ui.lib import api from ui.lib import api
@ -57,12 +59,11 @@ def test_check_api_status_connection_error():
def test_text_to_speech_success(mock_response, tmp_path): def test_text_to_speech_success(mock_response, tmp_path):
"""Test successful speech generation""" """Test successful speech generation"""
with patch("requests.post", return_value=mock_response({})), \ with patch("requests.post", return_value=mock_response({})), patch(
patch("ui.lib.api.OUTPUTS_DIR", str(tmp_path)), \ "ui.lib.api.OUTPUTS_DIR", str(tmp_path)
patch("builtins.open", mock_open()) as mock_file: ), patch("builtins.open", mock_open()) as mock_file:
result = api.text_to_speech("test text", "voice1", "mp3", 1.0) result = api.text_to_speech("test text", "voice1", "mp3", 1.0)
assert result is not None assert result is not None
assert "output_" in result assert "output_" in result
assert result.endswith(".mp3") assert result.endswith(".mp3")
@ -105,25 +106,24 @@ def test_get_status_html_unavailable():
def test_text_to_speech_api_params(mock_response, tmp_path): def test_text_to_speech_api_params(mock_response, tmp_path):
"""Test correct API parameters are sent""" """Test correct API parameters are sent"""
with patch("requests.post") as mock_post, \ with patch("requests.post") as mock_post, patch(
patch("ui.lib.api.OUTPUTS_DIR", str(tmp_path)), \ "ui.lib.api.OUTPUTS_DIR", str(tmp_path)
patch("builtins.open", mock_open()): ), patch("builtins.open", mock_open()):
mock_post.return_value = mock_response({}) mock_post.return_value = mock_response({})
api.text_to_speech("test text", "voice1", "mp3", 1.5) api.text_to_speech("test text", "voice1", "mp3", 1.5)
mock_post.assert_called_once() mock_post.assert_called_once()
args, kwargs = mock_post.call_args args, kwargs = mock_post.call_args
# Check request body # Check request body
assert kwargs["json"] == { assert kwargs["json"] == {
"model": "kokoro", "model": "kokoro",
"input": "test text", "input": "test text",
"voice": "voice1", "voice": "voice1",
"response_format": "mp3", "response_format": "mp3",
"speed": 1.5 "speed": 1.5,
} }
# Check headers and timeout # Check headers and timeout
assert kwargs["headers"] == {"Content-Type": "application/json"} assert kwargs["headers"] == {"Content-Type": "application/json"}
assert kwargs["timeout"] == 300 assert kwargs["timeout"] == 300

View file

@ -1,8 +1,9 @@
import pytest
import gradio as gr import gradio as gr
import pytest
from ui.lib.config import AUDIO_FORMATS
from ui.lib.components.model import create_model_column from ui.lib.components.model import create_model_column
from ui.lib.components.output import create_output_column from ui.lib.components.output import create_output_column
from ui.lib.config import AUDIO_FORMATS
def test_create_model_column_structure(): def test_create_model_column_structure():
@ -15,12 +16,7 @@ def test_create_model_column_structure():
assert isinstance(components, dict) assert isinstance(components, dict)
# Test expected components presence # Test expected components presence
expected_components = { expected_components = {"status_btn", "voice", "format", "speed"}
"status_btn",
"voice",
"format",
"speed"
}
assert set(components.keys()) == expected_components assert set(components.keys()) == expected_components
# Test component types # Test component types
@ -78,7 +74,7 @@ def test_create_output_column_structure():
"output_files", "output_files",
"play_btn", "play_btn",
"selected_audio", "selected_audio",
"clear_outputs" "clear_outputs",
} }
assert set(components.keys()) == expected_components assert set(components.keys()) == expected_components

View file

@ -1,6 +1,8 @@
import os import os
import pytest
from unittest.mock import patch from unittest.mock import patch
import pytest
from ui.lib import files from ui.lib import files
from ui.lib.config import AUDIO_FORMATS from ui.lib.config import AUDIO_FORMATS

View file

@ -1,5 +1,6 @@
import pytest
import gradio as gr import gradio as gr
import pytest
from ui.lib.components.input import create_input_column from ui.lib.components.input import create_input_column

View file

@ -1,12 +1,15 @@
import pytest from unittest.mock import MagicMock, PropertyMock, patch
import gradio as gr import gradio as gr
from unittest.mock import patch, MagicMock, PropertyMock import pytest
from ui.lib.interface import create_interface from ui.lib.interface import create_interface
@pytest.fixture @pytest.fixture
def mock_timer(): def mock_timer():
"""Create a mock timer with events property""" """Create a mock timer with events property"""
class MockEvent: class MockEvent:
def __init__(self, fn): def __init__(self, fn):
self.fn = fn self.fn = fn
@ -30,7 +33,7 @@ def test_create_interface_structure():
"""Test the basic structure of the created interface""" """Test the basic structure of the created interface"""
with patch("ui.lib.api.check_api_status", return_value=(False, [])): with patch("ui.lib.api.check_api_status", return_value=(False, [])):
demo = create_interface() demo = create_interface()
# Test interface type and theme # Test interface type and theme
assert isinstance(demo, gr.Blocks) assert isinstance(demo, gr.Blocks)
assert demo.title == "Kokoro TTS Demo" assert demo.title == "Kokoro TTS Demo"
@ -41,15 +44,14 @@ def test_interface_html_links():
"""Test that HTML links are properly configured""" """Test that HTML links are properly configured"""
with patch("ui.lib.api.check_api_status", return_value=(False, [])): with patch("ui.lib.api.check_api_status", return_value=(False, [])):
demo = create_interface() demo = create_interface()
# Find HTML component # Find HTML component
html_components = [ html_components = [
comp for comp in demo.blocks.values() comp for comp in demo.blocks.values() if isinstance(comp, gr.HTML)
if isinstance(comp, gr.HTML)
] ]
assert len(html_components) > 0 assert len(html_components) > 0
html = html_components[0] html = html_components[0]
# Check for required links # Check for required links
assert 'href="https://huggingface.co/hexgrad/Kokoro-82M"' in html.value assert 'href="https://huggingface.co/hexgrad/Kokoro-82M"' in html.value
assert 'href="https://github.com/remsky/Kokoro-FastAPI"' in html.value assert 'href="https://github.com/remsky/Kokoro-FastAPI"' in html.value
@ -60,16 +62,17 @@ def test_interface_html_links():
def test_update_status_available(mock_timer): def test_update_status_available(mock_timer):
"""Test status update when service is available""" """Test status update when service is available"""
voices = ["voice1", "voice2"] voices = ["voice1", "voice2"]
with patch("ui.lib.api.check_api_status", return_value=(True, voices)), \ with patch("ui.lib.api.check_api_status", return_value=(True, voices)), patch(
patch("gradio.Timer", return_value=mock_timer): "gradio.Timer", return_value=mock_timer
):
demo = create_interface() demo = create_interface()
# Get the update function # Get the update function
update_fn = mock_timer.events[0].fn update_fn = mock_timer.events[0].fn
# Test update with available service # Test update with available service
updates = update_fn() updates = update_fn()
assert "Available" in updates[0]["value"] assert "Available" in updates[0]["value"]
assert updates[1]["choices"] == voices assert updates[1]["choices"] == voices
assert updates[1]["value"] == voices[0] assert updates[1]["value"] == voices[0]
@ -78,13 +81,14 @@ def test_update_status_available(mock_timer):
def test_update_status_unavailable(mock_timer): def test_update_status_unavailable(mock_timer):
"""Test status update when service is unavailable""" """Test status update when service is unavailable"""
with patch("ui.lib.api.check_api_status", return_value=(False, [])), \ with patch("ui.lib.api.check_api_status", return_value=(False, [])), patch(
patch("gradio.Timer", return_value=mock_timer): "gradio.Timer", return_value=mock_timer
):
demo = create_interface() demo = create_interface()
update_fn = mock_timer.events[0].fn update_fn = mock_timer.events[0].fn
updates = update_fn() updates = update_fn()
assert "Waiting for Service" in updates[0]["value"] assert "Waiting for Service" in updates[0]["value"]
assert updates[1]["choices"] == [] assert updates[1]["choices"] == []
assert updates[1]["value"] is None assert updates[1]["value"] is None
@ -93,13 +97,14 @@ def test_update_status_unavailable(mock_timer):
def test_update_status_error(mock_timer): def test_update_status_error(mock_timer):
"""Test status update when an error occurs""" """Test status update when an error occurs"""
with patch("ui.lib.api.check_api_status", side_effect=Exception("Test error")), \ with patch(
patch("gradio.Timer", return_value=mock_timer): "ui.lib.api.check_api_status", side_effect=Exception("Test error")
), patch("gradio.Timer", return_value=mock_timer):
demo = create_interface() demo = create_interface()
update_fn = mock_timer.events[0].fn update_fn = mock_timer.events[0].fn
updates = update_fn() updates = update_fn()
assert "Connection Error" in updates[0]["value"] assert "Connection Error" in updates[0]["value"]
assert updates[1]["choices"] == [] assert updates[1]["choices"] == []
assert updates[1]["value"] is None assert updates[1]["value"] is None
@ -108,10 +113,11 @@ def test_update_status_error(mock_timer):
def test_timer_configuration(mock_timer): def test_timer_configuration(mock_timer):
"""Test timer configuration""" """Test timer configuration"""
with patch("ui.lib.api.check_api_status", return_value=(False, [])), \ with patch("ui.lib.api.check_api_status", return_value=(False, [])), patch(
patch("gradio.Timer", return_value=mock_timer): "gradio.Timer", return_value=mock_timer
):
demo = create_interface() demo = create_interface()
assert mock_timer.value == 5 # Check interval is 5 seconds assert mock_timer.value == 5 # Check interval is 5 seconds
assert len(mock_timer.events) == 1 # Should have one event handler assert len(mock_timer.events) == 1 # Should have one event handler
@ -120,20 +126,21 @@ def test_interface_components_presence():
"""Test that all required components are present""" """Test that all required components are present"""
with patch("ui.lib.api.check_api_status", return_value=(False, [])): with patch("ui.lib.api.check_api_status", return_value=(False, [])):
demo = create_interface() demo = create_interface()
# Check for main component sections # Check for main component sections
components = { components = {
comp.label for comp in demo.blocks.values() comp.label
if hasattr(comp, 'label') and comp.label for comp in demo.blocks.values()
if hasattr(comp, "label") and comp.label
} }
required_components = { required_components = {
"Text to speak", "Text to speak",
"Voice", "Voice",
"Audio Format", "Audio Format",
"Speed", "Speed",
"Generated Speech", "Generated Speech",
"Previous Outputs" "Previous Outputs",
} }
assert required_components.issubset(components) assert required_components.issubset(components)