Ruff Check + Format

This commit is contained in:
remsky 2025-01-01 21:50:41 -07:00
parent e749b3bc88
commit f051984805
27 changed files with 638 additions and 504 deletions

BIN
.coverage

Binary file not shown.

View file

@ -1,10 +1,10 @@
from typing import List
from fastapi import APIRouter, Depends, HTTPException, Response
from loguru import logger
from fastapi import Depends, Response, APIRouter, HTTPException
from ..services.audio import AudioService
from ..services.tts import TTSService
from ..services.audio import AudioService
from ..structures.schemas import OpenAISpeechRequest
router = APIRouter(
@ -55,14 +55,12 @@ async def create_speech(
except ValueError as e:
logger.error(f"Invalid request: {str(e)}")
raise HTTPException(
status_code=400,
detail={"error": "Invalid request", "message": str(e)}
status_code=400, detail={"error": "Invalid request", "message": str(e)}
)
except Exception as e:
logger.error(f"Error generating speech: {str(e)}")
raise HTTPException(
status_code=500,
detail={"error": "Server error", "message": str(e)}
status_code=500, detail={"error": "Server error", "message": str(e)}
)
@ -78,7 +76,9 @@ async def list_voices(tts_service: TTSService = Depends(get_tts_service)):
@router.post("/audio/voices/combine")
async def combine_voices(request: List[str], tts_service: TTSService = Depends(get_tts_service)):
async def combine_voices(
request: List[str], tts_service: TTSService = Depends(get_tts_service)
):
"""Combine multiple voices into a new voice.
Args:
@ -100,20 +100,17 @@ async def combine_voices(request: List[str], tts_service: TTSService = Depends(g
except ValueError as e:
logger.error(f"Invalid voice combination request: {str(e)}")
raise HTTPException(
status_code=400,
detail={"error": "Invalid request", "message": str(e)}
status_code=400, detail={"error": "Invalid request", "message": str(e)}
)
except RuntimeError as e:
logger.error(f"Server error during voice combination: {str(e)}")
raise HTTPException(
status_code=500,
detail={"error": "Server error", "message": str(e)}
status_code=500, detail={"error": "Server error", "message": str(e)}
)
except Exception as e:
logger.error(f"Unexpected error during voice combination: {str(e)}")
raise HTTPException(
status_code=500,
detail={"error": "Unexpected error", "message": str(e)}
status_code=500, detail={"error": "Unexpected error", "message": str(e)}
)

View file

@ -1,17 +1,16 @@
import io
import os
import re
import threading
import time
import threading
from typing import List, Tuple, Optional
import numpy as np
import scipy.io.wavfile as wavfile
import tiktoken
import torch
import tiktoken
import scipy.io.wavfile as wavfile
from kokoro import generate, tokenize, phonemize, normalize_text
from loguru import logger
from kokoro import generate, normalize_text, phonemize, tokenize
from models import build_model
from ..core.config import settings
@ -51,25 +50,37 @@ class TTSModel:
voice_path = os.path.join(cls.VOICES_DIR, file)
if not os.path.exists(voice_path):
try:
logger.info(f"Copying base voice {voice_name} to voices directory")
logger.info(
f"Copying base voice {voice_name} to voices directory"
)
base_path = os.path.join(base_voices_dir, file)
voicepack = torch.load(base_path, map_location=cls._device, weights_only=True)
voicepack = torch.load(
base_path,
map_location=cls._device,
weights_only=True,
)
torch.save(voicepack, voice_path)
except Exception as e:
logger.error(f"Error copying voice {voice_name}: {str(e)}")
logger.error(
f"Error copying voice {voice_name}: {str(e)}"
)
# Warm up with default voice
try:
dummy_text = "Hello"
voice_path = os.path.join(cls.VOICES_DIR, "af.pt")
dummy_voicepack = torch.load(voice_path, map_location=cls._device, weights_only=True)
generate(model, dummy_text, dummy_voicepack, lang='a', speed=1.0)
dummy_voicepack = torch.load(
voice_path, map_location=cls._device, weights_only=True
)
generate(model, dummy_text, dummy_voicepack, lang="a", speed=1.0)
logger.info("Model warm-up complete")
except Exception as e:
logger.warning(f"Model warm-up failed: {e}")
# Count voices in directory for validation
voice_count = len([f for f in os.listdir(cls.VOICES_DIR) if f.endswith('.pt')])
voice_count = len(
[f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")]
)
return cls._instance, voice_count
@classmethod
@ -99,9 +110,15 @@ class TTSService:
voice_path = os.path.join(TTSModel.VOICES_DIR, file)
if not os.path.exists(voice_path):
try:
logger.info(f"Copying base voice {voice_name} to voices directory")
logger.info(
f"Copying base voice {voice_name} to voices directory"
)
base_path = os.path.join(base_voices_dir, file)
voicepack = torch.load(base_path, map_location=TTSModel._device, weights_only=True)
voicepack = torch.load(
base_path,
map_location=TTSModel._device,
weights_only=True,
)
torch.save(voicepack, voice_path)
except Exception as e:
logger.error(f"Error copying voice {voice_name}: {str(e)}")
@ -141,7 +158,9 @@ class TTSService:
# Load model and voice
model = TTSModel._instance
voicepack = torch.load(voice_path, map_location=TTSModel._device, weights_only=True)
voicepack = torch.load(
voice_path, map_location=TTSModel._device, weights_only=True
)
# Generate audio with or without stitching
if stitch_long_output:
@ -152,11 +171,11 @@ class TTSService:
for i, chunk in enumerate(chunks):
try:
# Validate phonemization first
ps = phonemize(chunk, voice[0])
tokens = tokenize(ps)
logger.debug(
f"Processing chunk {i + 1}/{len(chunks)}: {len(tokens)} tokens"
)
# ps = phonemize(chunk, voice[0])
# tokens = tokenize(ps)
# logger.debug(
# f"Processing chunk {i + 1}/{len(chunks)}: {len(tokens)} tokens"
# )
# Only proceed if phonemization succeeded
chunk_audio, _ = generate(
@ -226,7 +245,9 @@ class TTSService:
for voice in voices:
try:
voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice}.pt")
voicepack = torch.load(voice_path, map_location=TTSModel._device, weights_only=True)
voicepack = torch.load(
voice_path, map_location=TTSModel._device, weights_only=True
)
t_voices.append(voicepack)
v_name.append(voice)
except Exception as e:
@ -242,7 +263,9 @@ class TTSService:
try:
torch.save(v, combined_path)
except Exception as e:
raise RuntimeError(f"Failed to save combined voice to {combined_path}: {str(e)}")
raise RuntimeError(
f"Failed to save combined voice to {combined_path}: {str(e)}"
)
return f

View file

@ -18,7 +18,7 @@ class OpenAISpeechRequest(BaseModel):
input: str = Field(..., description="The text to generate audio for")
voice: str = Field(
default="af",
description="The voice to use for generation. Can be a base voice or a combined voice name."
description="The voice to use for generation. Can be a base voice or a combined voice name.",
)
response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field(
default="mp3",

View file

@ -1,16 +1,18 @@
import os
import shutil
import sys
import shutil
from unittest.mock import Mock, patch
import pytest
def cleanup_mock_dirs():
"""Clean up any MagicMock directories created during tests"""
mock_dir = "MagicMock"
if os.path.exists(mock_dir):
shutil.rmtree(mock_dir)
@pytest.fixture(autouse=True)
def cleanup():
"""Automatically clean up before and after each test"""
@ -18,6 +20,7 @@ def cleanup():
yield
cleanup_mock_dirs()
# Mock torch and other ML modules before they're imported
sys.modules["torch"] = Mock()
sys.modules["transformers"] = Mock()

View file

@ -1,6 +1,8 @@
"""Tests for AudioService"""
import numpy as np
import pytest
from api.src.services.audio import AudioService

View file

@ -1,7 +1,10 @@
"""Tests for FastAPI application"""
from unittest.mock import MagicMock, patch
import pytest
from unittest.mock import patch, MagicMock
from fastapi.testclient import TestClient
from api.src.main import app, lifespan
@ -19,15 +22,15 @@ def test_health_check(test_client):
@pytest.mark.asyncio
@patch('api.src.main.TTSModel')
@patch('api.src.main.logger')
@patch("api.src.main.TTSModel")
@patch("api.src.main.logger")
async def test_lifespan_successful_warmup(mock_logger, mock_tts_model):
"""Test successful model warmup in lifespan"""
# Mock the model initialization with model info and voicepack count
mock_model = MagicMock()
# Mock file system for voice counting
mock_tts_model.VOICES_DIR = "/mock/voices"
with patch('os.listdir', return_value=['voice1.pt', 'voice2.pt', 'voice3.pt']):
with patch("os.listdir", return_value=["voice1.pt", "voice2.pt", "voice3.pt"]):
mock_tts_model.initialize.return_value = (mock_model, 3) # 3 voice files
mock_tts_model._device = "cuda" # Set device class variable
@ -49,8 +52,8 @@ async def test_lifespan_successful_warmup(mock_logger, mock_tts_model):
@pytest.mark.asyncio
@patch('api.src.main.TTSModel')
@patch('api.src.main.logger')
@patch("api.src.main.TTSModel")
@patch("api.src.main.logger")
async def test_lifespan_failed_warmup(mock_logger, mock_tts_model):
"""Test failed model warmup in lifespan"""
# Mock the model initialization to fail
@ -71,14 +74,14 @@ async def test_lifespan_failed_warmup(mock_logger, mock_tts_model):
@pytest.mark.asyncio
@patch('api.src.main.TTSModel')
@patch("api.src.main.TTSModel")
async def test_lifespan_cuda_warmup(mock_tts_model):
"""Test model warmup specifically on CUDA"""
# Mock the model initialization with CUDA and voicepacks
mock_model = MagicMock()
# Mock file system for voice counting
mock_tts_model.VOICES_DIR = "/mock/voices"
with patch('os.listdir', return_value=['voice1.pt', 'voice2.pt']):
with patch("os.listdir", return_value=["voice1.pt", "voice2.pt"]):
mock_tts_model.initialize.return_value = (mock_model, 2) # 2 voice files
mock_tts_model._device = "cuda" # Set device class variable
@ -94,14 +97,16 @@ async def test_lifespan_cuda_warmup(mock_tts_model):
@pytest.mark.asyncio
@patch('api.src.main.TTSModel')
@patch("api.src.main.TTSModel")
async def test_lifespan_cpu_fallback(mock_tts_model):
"""Test model warmup falling back to CPU"""
# Mock the model initialization with CPU and voicepacks
mock_model = MagicMock()
# Mock file system for voice counting
mock_tts_model.VOICES_DIR = "/mock/voices"
with patch('os.listdir', return_value=['voice1.pt', 'voice2.pt', 'voice3.pt', 'voice4.pt']):
with patch(
"os.listdir", return_value=["voice1.pt", "voice2.pt", "voice3.pt", "voice4.pt"]
):
mock_tts_model.initialize.return_value = (mock_model, 4) # 4 voice files
mock_tts_model._device = "cpu" # Set device class variable

View file

@ -1,9 +1,12 @@
"""Tests for TTSService"""
import os
from unittest.mock import MagicMock, call, patch
import numpy as np
import pytest
from unittest.mock import patch, MagicMock, call
from api.src.services.tts import TTSService, TTSModel
from api.src.services.tts import TTSModel, TTSService
@pytest.fixture
@ -50,27 +53,35 @@ def test_audio_to_bytes(tts_service, sample_audio):
assert len(audio_bytes) > 0
@patch('os.listdir')
@patch('os.path.join')
@patch("os.listdir")
@patch("os.path.join")
def test_list_voices(mock_join, mock_listdir, tts_service):
"""Test listing available voices"""
mock_listdir.return_value = ['voice1.pt', 'voice2.pt', 'not_a_voice.txt']
mock_join.return_value = '/fake/path'
mock_listdir.return_value = ["voice1.pt", "voice2.pt", "not_a_voice.txt"]
mock_join.return_value = "/fake/path"
voices = tts_service.list_voices()
assert len(voices) == 2
assert 'voice1' in voices
assert 'voice2' in voices
assert 'not_a_voice' not in voices
assert "voice1" in voices
assert "voice2" in voices
assert "not_a_voice" not in voices
@patch('api.src.services.tts.TTSModel.get_instance')
@patch('api.src.services.tts.TTSModel.get_voicepack')
@patch('api.src.services.tts.normalize_text')
@patch('api.src.services.tts.phonemize')
@patch('api.src.services.tts.tokenize')
@patch('api.src.services.tts.generate')
def test_generate_audio_empty_text(mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_voicepack, mock_instance, tts_service):
@patch("api.src.services.tts.TTSModel.get_instance")
@patch("api.src.services.tts.TTSModel.get_voicepack")
@patch("api.src.services.tts.normalize_text")
@patch("api.src.services.tts.phonemize")
@patch("api.src.services.tts.tokenize")
@patch("api.src.services.tts.generate")
def test_generate_audio_empty_text(
mock_generate,
mock_tokenize,
mock_phonemize,
mock_normalize,
mock_voicepack,
mock_instance,
tts_service,
):
"""Test generating audio with empty text"""
mock_normalize.return_value = ""
@ -78,14 +89,23 @@ def test_generate_audio_empty_text(mock_generate, mock_tokenize, mock_phonemize,
tts_service._generate_audio("", "af", 1.0)
@patch('api.src.services.tts.TTSModel.get_instance')
@patch('os.path.exists')
@patch('api.src.services.tts.normalize_text')
@patch('api.src.services.tts.phonemize')
@patch('api.src.services.tts.tokenize')
@patch('api.src.services.tts.generate')
@patch('torch.load')
def test_generate_audio_no_chunks(mock_torch_load, mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_exists, mock_instance, tts_service):
@patch("api.src.services.tts.TTSModel.get_instance")
@patch("os.path.exists")
@patch("api.src.services.tts.normalize_text")
@patch("api.src.services.tts.phonemize")
@patch("api.src.services.tts.tokenize")
@patch("api.src.services.tts.generate")
@patch("torch.load")
def test_generate_audio_no_chunks(
mock_torch_load,
mock_generate,
mock_tokenize,
mock_phonemize,
mock_normalize,
mock_exists,
mock_instance,
tts_service,
):
"""Test generating audio with no successful chunks"""
mock_normalize.return_value = "Test text"
mock_phonemize.return_value = "Test text"
@ -99,14 +119,24 @@ def test_generate_audio_no_chunks(mock_torch_load, mock_generate, mock_tokenize,
tts_service._generate_audio("Test text", "af", 1.0)
@patch('api.src.services.tts.TTSModel.get_instance')
@patch('os.path.exists')
@patch('api.src.services.tts.normalize_text')
@patch('api.src.services.tts.phonemize')
@patch('api.src.services.tts.tokenize')
@patch('api.src.services.tts.generate')
@patch('torch.load')
def test_generate_audio_success(mock_torch_load, mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_exists, mock_instance, tts_service, sample_audio):
@patch("api.src.services.tts.TTSModel.get_instance")
@patch("os.path.exists")
@patch("api.src.services.tts.normalize_text")
@patch("api.src.services.tts.phonemize")
@patch("api.src.services.tts.tokenize")
@patch("api.src.services.tts.generate")
@patch("torch.load")
def test_generate_audio_success(
mock_torch_load,
mock_generate,
mock_tokenize,
mock_phonemize,
mock_normalize,
mock_exists,
mock_instance,
tts_service,
sample_audio,
):
"""Test successful audio generation"""
mock_normalize.return_value = "Test text"
mock_phonemize.return_value = "Test text"
@ -122,8 +152,8 @@ def test_generate_audio_success(mock_torch_load, mock_generate, mock_tokenize, m
assert len(audio) > 0
@patch('api.src.services.tts.torch.cuda.is_available')
@patch('api.src.services.tts.build_model')
@patch("api.src.services.tts.torch.cuda.is_available")
@patch("api.src.services.tts.build_model")
def test_model_initialization_cuda(mock_build_model, mock_cuda_available):
"""Test model initialization with CUDA"""
mock_cuda_available.return_value = True
@ -138,8 +168,8 @@ def test_model_initialization_cuda(mock_build_model, mock_cuda_available):
mock_build_model.assert_called_once()
@patch('api.src.services.tts.torch.cuda.is_available')
@patch('api.src.services.tts.build_model')
@patch("api.src.services.tts.torch.cuda.is_available")
@patch("api.src.services.tts.build_model")
def test_model_initialization_cpu(mock_build_model, mock_cuda_available):
"""Test model initialization with CPU"""
mock_cuda_available.return_value = False
@ -154,8 +184,8 @@ def test_model_initialization_cpu(mock_build_model, mock_cuda_available):
mock_build_model.assert_called_once()
@patch('api.src.services.tts.TTSService._get_voice_path')
@patch('api.src.services.tts.TTSModel.get_instance')
@patch("api.src.services.tts.TTSService._get_voice_path")
@patch("api.src.services.tts.TTSModel.get_instance")
def test_voicepack_loading_error(mock_get_instance, mock_get_voice_path):
"""Test voicepack loading error handling"""
mock_get_voice_path.return_value = None
@ -168,7 +198,7 @@ def test_voicepack_loading_error(mock_get_instance, mock_get_voice_path):
service._generate_audio("test", "nonexistent_voice", 1.0)
@patch('api.src.services.tts.TTSModel')
@patch("api.src.services.tts.TTSModel")
def test_save_audio(mock_tts_model, tts_service, sample_audio, tmp_path):
"""Test saving audio to file"""
output_dir = os.path.join(tmp_path, "test_output")
@ -181,12 +211,20 @@ def test_save_audio(mock_tts_model, tts_service, sample_audio, tmp_path):
assert os.path.getsize(output_path) > 0
@patch('api.src.services.tts.TTSModel.get_instance')
@patch('os.path.exists')
@patch('api.src.services.tts.normalize_text')
@patch('api.src.services.tts.generate')
@patch('torch.load')
def test_generate_audio_without_stitching(mock_torch_load, mock_generate, mock_normalize, mock_exists, mock_instance, tts_service, sample_audio):
@patch("api.src.services.tts.TTSModel.get_instance")
@patch("os.path.exists")
@patch("api.src.services.tts.normalize_text")
@patch("api.src.services.tts.generate")
@patch("torch.load")
def test_generate_audio_without_stitching(
mock_torch_load,
mock_generate,
mock_normalize,
mock_exists,
mock_instance,
tts_service,
sample_audio,
):
"""Test generating audio without text stitching"""
mock_normalize.return_value = "Test text"
mock_generate.return_value = (sample_audio, None)
@ -194,14 +232,16 @@ def test_generate_audio_without_stitching(mock_torch_load, mock_generate, mock_n
mock_exists.return_value = True
mock_torch_load.return_value = MagicMock()
audio, processing_time = tts_service._generate_audio("Test text", "af", 1.0, stitch_long_output=False)
audio, processing_time = tts_service._generate_audio(
"Test text", "af", 1.0, stitch_long_output=False
)
assert isinstance(audio, np.ndarray)
assert isinstance(processing_time, float)
assert len(audio) > 0
mock_generate.assert_called_once()
@patch('os.listdir')
@patch("os.listdir")
def test_list_voices_error(mock_listdir, tts_service):
"""Test error handling in list_voices"""
mock_listdir.side_effect = Exception("Failed to list directory")
@ -210,14 +250,23 @@ def test_list_voices_error(mock_listdir, tts_service):
assert voices == []
@patch('api.src.services.tts.TTSModel.get_instance')
@patch('os.path.exists')
@patch('api.src.services.tts.normalize_text')
@patch('api.src.services.tts.phonemize')
@patch('api.src.services.tts.tokenize')
@patch('api.src.services.tts.generate')
@patch('torch.load')
def test_generate_audio_phonemize_error(mock_torch_load, mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_exists, mock_instance, tts_service):
@patch("api.src.services.tts.TTSModel.get_instance")
@patch("os.path.exists")
@patch("api.src.services.tts.normalize_text")
@patch("api.src.services.tts.phonemize")
@patch("api.src.services.tts.tokenize")
@patch("api.src.services.tts.generate")
@patch("torch.load")
def test_generate_audio_phonemize_error(
mock_torch_load,
mock_generate,
mock_tokenize,
mock_phonemize,
mock_normalize,
mock_exists,
mock_instance,
tts_service,
):
"""Test handling phonemization error"""
mock_normalize.return_value = "Test text"
mock_phonemize.side_effect = Exception("Phonemization failed")
@ -230,12 +279,19 @@ def test_generate_audio_phonemize_error(mock_torch_load, mock_generate, mock_tok
tts_service._generate_audio("Test text", "af", 1.0)
@patch('api.src.services.tts.TTSModel.get_instance')
@patch('os.path.exists')
@patch('api.src.services.tts.normalize_text')
@patch('api.src.services.tts.generate')
@patch('torch.load')
def test_generate_audio_error(mock_torch_load, mock_generate, mock_normalize, mock_exists, mock_instance, tts_service):
@patch("api.src.services.tts.TTSModel.get_instance")
@patch("os.path.exists")
@patch("api.src.services.tts.normalize_text")
@patch("api.src.services.tts.generate")
@patch("torch.load")
def test_generate_audio_error(
mock_torch_load,
mock_generate,
mock_normalize,
mock_exists,
mock_instance,
tts_service,
):
"""Test handling generation error"""
mock_normalize.return_value = "Test text"
mock_generate.side_effect = Exception("Generation failed")

View file

@ -19,7 +19,6 @@ output_dir = Path(__file__).parent / "output"
output_dir.mkdir(exist_ok=True)
def test_voice(voice: str):
speech_file = output_dir / f"speech_{voice}.mp3"
print(f"\nTesting voice: {voice}")

View file

@ -1,15 +1,17 @@
#!/usr/bin/env python3
import argparse
import os
from typing import List, Optional, Dict, Tuple
import argparse
from typing import Dict, List, Tuple, Optional
import requests
import numpy as np
from scipy.io import wavfile
import requests
import matplotlib.pyplot as plt
from scipy.io import wavfile
def submit_combine_voices(voices: List[str], base_url: str = "http://localhost:8880") -> Optional[str]:
def submit_combine_voices(
voices: List[str], base_url: str = "http://localhost:8880"
) -> Optional[str]:
"""Combine multiple voices into a new voice.
Args:
@ -46,7 +48,12 @@ def submit_combine_voices(voices: List[str], base_url: str = "http://localhost:8
return None
def generate_speech(text: str, voice: str, base_url: str = "http://localhost:8880", output_file: str = "output.mp3") -> bool:
def generate_speech(
text: str,
voice: str,
base_url: str = "http://localhost:8880",
output_file: str = "output.mp3",
) -> bool:
"""Generate speech using specified voice.
Args:
@ -65,8 +72,8 @@ def generate_speech(text: str, voice: str, base_url: str = "http://localhost:888
"input": text,
"voice": voice,
"speed": 1.0,
"response_format": "wav" # Use WAV for analysis
}
"response_format": "wav", # Use WAV for analysis
},
)
if response.status_code != 200:
@ -75,7 +82,10 @@ def generate_speech(text: str, voice: str, base_url: str = "http://localhost:888
return False
# Save the audio
os.makedirs(os.path.dirname(output_file) if os.path.dirname(output_file) else ".", exist_ok=True)
os.makedirs(
os.path.dirname(output_file) if os.path.dirname(output_file) else ".",
exist_ok=True,
)
with open(output_file, "wb") as f:
f.write(response.content)
print(f"Saved audio to {output_file}")
@ -136,7 +146,7 @@ def analyze_audio(filepath: str) -> Tuple[np.ndarray, int, dict]:
"duration": duration,
"zero_crossing_rate": zero_crossings,
"dominant_frequencies": dominant_freqs,
"spectral_centroid": spectral_centroid
"spectral_centroid": spectral_centroid,
}
return samples, sample_rate, characteristics
@ -167,6 +177,7 @@ def setup_plot(fig, ax, title):
return fig, ax
def plot_analysis(audio_files: Dict[str, str], output_dir: str):
"""Plot comprehensive voice analysis including waveforms and metrics comparison.
@ -175,7 +186,7 @@ def plot_analysis(audio_files: Dict[str, str], output_dir: str):
output_dir: Directory to save plot files
"""
# Set dark style
plt.style.use('dark_background')
plt.style.use("dark_background")
# Create figure with subplots
fig = plt.figure(figsize=(15, 15))
@ -183,8 +194,9 @@ def plot_analysis(audio_files: Dict[str, str], output_dir: str):
num_files = len(audio_files)
# Create subplot grid with proper spacing
gs = plt.GridSpec(num_files + 1, 2, height_ratios=[1.5]*num_files + [1],
hspace=0.4, wspace=0.3)
gs = plt.GridSpec(
num_files + 1, 2, height_ratios=[1.5] * num_files + [1], hspace=0.4, wspace=0.3
)
# Analyze all files first
all_chars = {}
@ -195,7 +207,7 @@ def plot_analysis(audio_files: Dict[str, str], output_dir: str):
# Plot waveform spanning both columns
ax = plt.subplot(gs[i, :])
time = np.arange(len(samples)) / sample_rate
plt.plot(time, samples / chars['max_amplitude'], linewidth=0.5, color="#ff2a6d")
plt.plot(time, samples / chars["max_amplitude"], linewidth=0.5, color="#ff2a6d")
ax.set_xlabel("Time (seconds)")
ax.set_ylabel("Normalized Amplitude")
ax.set_ylim(-1.1, 1.1)
@ -208,15 +220,27 @@ def plot_analysis(audio_files: Dict[str, str], output_dir: str):
# Left subplot: Brightness and Volume
ax1 = plt.subplot(gs[num_files, 0])
metrics1 = [
('Brightness', [chars['spectral_centroid']/1000 for chars in all_chars.values()], 'kHz'),
('Volume', [chars['rms']*100 for chars in all_chars.values()], 'RMS×100')
(
"Brightness",
[chars["spectral_centroid"] / 1000 for chars in all_chars.values()],
"kHz",
),
("Volume", [chars["rms"] * 100 for chars in all_chars.values()], "RMS×100"),
]
# Right subplot: Voice Pitch and Texture
ax2 = plt.subplot(gs[num_files, 1])
metrics2 = [
('Voice Pitch', [min(chars['dominant_frequencies']) for chars in all_chars.values()], 'Hz'),
('Texture', [chars['zero_crossing_rate']*1000 for chars in all_chars.values()], 'ZCR×1000')
(
"Voice Pitch",
[min(chars["dominant_frequencies"]) for chars in all_chars.values()],
"Hz",
),
(
"Texture",
[chars["zero_crossing_rate"] * 1000 for chars in all_chars.values()],
"ZCR×1000",
),
]
def plot_grouped_bars(ax, metrics, show_legend=True):
@ -232,16 +256,22 @@ def plot_analysis(audio_files: Dict[str, str], output_dir: str):
for i, (voice, color) in enumerate(zip(audio_files.keys(), colors)):
values = [m[1][i] for m in metrics]
offset = (i - n_voices / 2 + 0.5) * bar_width
bars = ax.bar(indices + offset, values, bar_width,
label=voice, color=color, alpha=0.8)
bars = ax.bar(
indices + offset, values, bar_width, label=voice, color=color, alpha=0.8
)
# Add value labels on top of bars
for bar in bars:
height = bar.get_height()
ax.text(bar.get_x() + bar.get_width()/2., height,
f'{height:.1f}',
ha='center', va='bottom', color='white',
fontsize=10)
ax.text(
bar.get_x() + bar.get_width() / 2.0,
height,
f"{height:.1f}",
ha="center",
va="bottom",
color="white",
fontsize=10,
)
ax.set_xticks(indices)
ax.set_xticklabels([f"{m[0]}\n({m[2]})" for m in metrics])
@ -250,20 +280,24 @@ def plot_analysis(audio_files: Dict[str, str], output_dir: str):
ax.set_ylim(0, max_val * 1.2)
if show_legend:
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left',
facecolor="#1a1a2e", edgecolor="#ffffff")
ax.legend(
bbox_to_anchor=(1.05, 1),
loc="upper left",
facecolor="#1a1a2e",
edgecolor="#ffffff",
)
# Plot both subplots
plot_grouped_bars(ax1, metrics1, show_legend=True)
plot_grouped_bars(ax2, metrics2, show_legend=False)
# Style both subplots
setup_plot(fig, ax1, 'Brightness and Volume')
setup_plot(fig, ax2, 'Voice Pitch and Texture')
setup_plot(fig, ax1, "Brightness and Volume")
setup_plot(fig, ax2, "Voice Pitch and Texture")
# Add y-axis labels
ax1.set_ylabel('Value')
ax2.set_ylabel('Value')
ax1.set_ylabel("Value")
ax2.set_ylabel("Value")
# Adjust the figure size to accommodate the legend
fig.set_size_inches(15, 15)
@ -282,15 +316,26 @@ def plot_analysis(audio_files: Dict[str, str], output_dir: str):
print(f" Duration: {chars['duration']:.2f}s")
print(f" Zero Crossing Rate: {chars['zero_crossing_rate']:.3f}")
print(f" Spectral Centroid: {chars['spectral_centroid']:.0f}Hz")
print(f" Dominant Frequencies: {', '.join(f'{f:.0f}Hz' for f in chars['dominant_frequencies'])}")
print(
f" Dominant Frequencies: {', '.join(f'{f:.0f}Hz' for f in chars['dominant_frequencies'])}"
)
def main():
parser = argparse.ArgumentParser(description="Kokoro Voice Analysis Demo")
parser.add_argument("--voices", nargs="+", type=str, help="Voices to combine")
parser.add_argument("--text", type=str, default="Hello! This is a test of combined voices.", help="Text to speak")
parser.add_argument(
"--text",
type=str,
default="Hello! This is a test of combined voices.",
help="Text to speak",
)
parser.add_argument("--url", default="http://localhost:8880", help="API base URL")
parser.add_argument("--output-dir", default="examples/output", help="Output directory for audio files")
parser.add_argument(
"--output-dir",
default="examples/output",
help="Output directory for audio files",
)
args = parser.parse_args()
if not args.voices:
@ -316,7 +361,9 @@ def main():
if combined_voice:
print(f"Successfully created combined voice: {combined_voice}")
output_file = os.path.join(args.output_dir, f"analysis_combined_{combined_voice}.wav")
output_file = os.path.join(
args.output_dir, f"analysis_combined_{combined_voice}.wav"
)
if generate_speech(args.text, combined_voice, args.url, output_file):
audio_files["combined"] = output_file

View file

@ -10,3 +10,5 @@ sqlalchemy==2.0.27
pytest==8.0.0
httpx==0.26.0
pytest-asyncio==0.23.5
pytest-cov==6.0.0
gradio==4.19.2

View file

@ -2,8 +2,4 @@ from lib.interface import create_interface
if __name__ == "__main__":
demo = create_interface()
demo.launch(
server_name="0.0.0.0",
server_port=7860,
show_error=True
)
demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)

View file

@ -1,16 +1,19 @@
import requests
from typing import Tuple, List, Optional
import os
import datetime
from typing import List, Tuple, Optional
import requests
from .config import API_URL, OUTPUTS_DIR
def check_api_status() -> Tuple[bool, List[str]]:
"""Check TTS service status and get available voices."""
try:
# Use a longer timeout during startup
response = requests.get(
f"{API_URL}/v1/audio/voices",
timeout=30 # Increased timeout for initial startup period
timeout=30, # Increased timeout for initial startup period
)
response.raise_for_status()
voices = response.json().get("voices", [])
@ -31,7 +34,10 @@ def check_api_status() -> Tuple[bool, List[str]]:
print(f"Unexpected error checking API status: {str(e)}")
return False, []
def text_to_speech(text: str, voice_id: str, format: str, speed: float) -> Optional[str]:
def text_to_speech(
text: str, voice_id: str, format: str, speed: float
) -> Optional[str]:
"""Generate speech from text using TTS API."""
if not text.strip():
return None
@ -49,10 +55,10 @@ def text_to_speech(text: str, voice_id: str, format: str, speed: float) -> Optio
"input": text,
"voice": voice_id,
"response_format": format,
"speed": float(speed)
"speed": float(speed),
},
headers={"Content-Type": "application/json"},
timeout=300 # Longer timeout for speech generation
timeout=300, # Longer timeout for speech generation
)
response.raise_for_status()
@ -70,6 +76,7 @@ def text_to_speech(text: str, voice_id: str, format: str, speed: float) -> Optio
print(f"Unexpected error generating speech: {str(e)}")
return None
def get_status_html(is_available: bool) -> str:
"""Generate HTML for status indicator."""
color = "green" if is_available else "red"

View file

@ -1,7 +1,10 @@
import gradio as gr
from typing import Tuple
import gradio as gr
from .. import files
def create_input_column() -> Tuple[gr.Column, dict]:
"""Create the input column with text input and file handling."""
with gr.Column(scale=1) as col:
@ -11,15 +14,9 @@ def create_input_column() -> Tuple[gr.Column, dict]:
# Direct Input Tab
with gr.TabItem("Direct Input"):
text_input = gr.Textbox(
label="Text to speak",
placeholder="Enter text here...",
lines=4
)
text_submit = gr.Button(
"Generate Speech",
variant="primary",
size="lg"
label="Text to speak", placeholder="Enter text here...", lines=4
)
text_submit = gr.Button("Generate Speech", variant="primary", size="lg")
# File Input Tab
with gr.TabItem("From File"):
@ -27,31 +24,24 @@ def create_input_column() -> Tuple[gr.Column, dict]:
input_files_list = gr.Dropdown(
label="Select Existing File",
choices=files.list_input_files(),
value=None
value=None,
)
# Simple file upload
file_upload = gr.File(
label="Upload Text File (.txt)",
file_types=[".txt"]
label="Upload Text File (.txt)", file_types=[".txt"]
)
file_preview = gr.Textbox(
label="File Content Preview",
interactive=False,
lines=4
label="File Content Preview", interactive=False, lines=4
)
with gr.Row():
file_submit = gr.Button(
"Generate Speech",
variant="primary",
size="lg"
"Generate Speech", variant="primary", size="lg"
)
clear_files = gr.Button(
"Clear Files",
variant="secondary",
size="lg"
"Clear Files", variant="secondary", size="lg"
)
components = {
@ -62,7 +52,7 @@ def create_input_column() -> Tuple[gr.Column, dict]:
"file_preview": file_preview,
"text_submit": text_submit,
"file_submit": file_submit,
"clear_files": clear_files
"clear_files": clear_files,
}
return col, components

View file

@ -1,7 +1,10 @@
import gradio as gr
from typing import Tuple, Optional
import gradio as gr
from .. import api, config
def create_model_column(voice_ids: Optional[list] = None) -> Tuple[gr.Column, dict]:
"""Create the model settings column."""
if voice_ids is None:
@ -12,34 +15,27 @@ def create_model_column(voice_ids: Optional[list] = None) -> Tuple[gr.Column, di
# Status button starts in waiting state
status_btn = gr.Button(
"⌛ TTS Service: Waiting for Service...",
variant="secondary"
"⌛ TTS Service: Waiting for Service...", variant="secondary"
)
voice_input = gr.Dropdown(
choices=voice_ids,
label="Voice",
value=voice_ids[0] if voice_ids else None,
interactive=True
interactive=True,
)
format_input = gr.Dropdown(
choices=config.AUDIO_FORMATS,
label="Audio Format",
value="mp3"
choices=config.AUDIO_FORMATS, label="Audio Format", value="mp3"
)
speed_input = gr.Slider(
minimum=0.5,
maximum=2.0,
value=1.0,
step=0.1,
label="Speed"
minimum=0.5, maximum=2.0, value=1.0, step=0.1, label="Speed"
)
components = {
"status_btn": status_btn,
"voice": voice_input,
"format": format_input,
"speed": speed_input
"speed": speed_input,
}
return col, components

View file

@ -1,40 +1,42 @@
import gradio as gr
from typing import Tuple
import gradio as gr
from .. import files
def create_output_column() -> Tuple[gr.Column, dict]:
"""Create the output column with audio player and file list."""
with gr.Column(scale=1) as col:
gr.Markdown("### Latest Output")
audio_output = gr.Audio(
label="Generated Speech",
type="filepath"
)
audio_output = gr.Audio(label="Generated Speech", type="filepath")
gr.Markdown("### Generated Files")
output_files = gr.Dropdown(
label="Previous Outputs",
choices=files.list_output_files(),
value=None,
allow_custom_value=False
allow_custom_value=False,
)
play_btn = gr.Button("▶️ Play Selected", size="sm")
selected_audio = gr.Audio(
label="Selected Output",
type="filepath",
visible=False
label="Selected Output", type="filepath", visible=False
)
clear_outputs = gr.Button("⚠️ Delete All Previously Generated Output Audio 🗑️", size="sm", variant="secondary")
clear_outputs = gr.Button(
"⚠️ Delete All Previously Generated Output Audio 🗑️",
size="sm",
variant="secondary",
)
components = {
"audio_output": audio_output,
"output_files": output_files,
"play_btn": play_btn,
"selected_audio": selected_audio,
"clear_outputs": clear_outputs
"clear_outputs": clear_outputs,
}
return col, components

View file

@ -1,17 +1,23 @@
import os
from typing import List, Optional, Tuple
import datetime
from typing import List, Tuple, Optional
from .config import INPUTS_DIR, OUTPUTS_DIR, AUDIO_FORMATS
def list_input_files() -> List[str]:
"""List all input text files."""
return [f for f in os.listdir(INPUTS_DIR) if f.endswith('.txt')]
return [f for f in os.listdir(INPUTS_DIR) if f.endswith(".txt")]
def list_output_files() -> List[str]:
"""List all output audio files."""
return [os.path.join(OUTPUTS_DIR, f)
return [
os.path.join(OUTPUTS_DIR, f)
for f in os.listdir(OUTPUTS_DIR)
if any(f.endswith(ext) for ext in AUDIO_FORMATS)]
if any(f.endswith(ext) for ext in AUDIO_FORMATS)
]
def read_text_file(filename: str) -> str:
"""Read content of a text file."""
@ -19,11 +25,12 @@ def read_text_file(filename: str) -> str:
return ""
try:
file_path = os.path.join(INPUTS_DIR, filename)
with open(file_path, 'r', encoding='utf-8') as f:
with open(file_path, "r", encoding="utf-8") as f:
return f.read()
except:
return ""
def save_text(text: str, filename: Optional[str] = None) -> Optional[str]:
"""Save text to a file. Returns the filename if successful."""
if not text.strip():
@ -41,7 +48,7 @@ def save_text(text: str, filename: Optional[str] = None) -> Optional[str]:
else:
# Handle duplicate filenames by adding _1, _2, etc.
base = os.path.splitext(filename)[0]
ext = os.path.splitext(filename)[1] or '.txt'
ext = os.path.splitext(filename)[1] or ".txt"
counter = 1
while os.path.exists(os.path.join(INPUTS_DIR, filename)):
filename = f"{base}_{counter}{ext}"
@ -56,11 +63,12 @@ def save_text(text: str, filename: Optional[str] = None) -> Optional[str]:
print(f"Error saving file: {e}")
return None
def delete_all_input_files() -> bool:
"""Delete all files from the inputs directory. Returns True if successful."""
try:
for filename in os.listdir(INPUTS_DIR):
if filename.endswith('.txt'):
if filename.endswith(".txt"):
file_path = os.path.join(INPUTS_DIR, filename)
os.remove(file_path)
return True
@ -68,6 +76,7 @@ def delete_all_input_files() -> bool:
print(f"Error deleting input files: {e}")
return False
def delete_all_output_files() -> bool:
"""Delete all audio files from the outputs directory. Returns True if successful."""
try:
@ -80,6 +89,7 @@ def delete_all_output_files() -> bool:
print(f"Error deleting output files: {e}")
return False
def process_uploaded_file(file_path: str) -> bool:
"""Save uploaded file to inputs directory. Returns True if successful."""
if not file_path:
@ -87,7 +97,7 @@ def process_uploaded_file(file_path: str) -> bool:
try:
filename = os.path.basename(file_path)
if not filename.endswith('.txt'):
if not filename.endswith(".txt"):
return False
# Create target path in inputs directory
@ -103,6 +113,7 @@ def process_uploaded_file(file_path: str) -> bool:
# Copy file to inputs directory
import shutil
shutil.copy2(file_path, target_path)
return True

View file

@ -1,8 +1,11 @@
import gradio as gr
import os
import shutil
import gradio as gr
from . import api, files
def setup_event_handlers(components: dict):
"""Set up all event handlers for the UI components."""
@ -19,17 +22,17 @@ def setup_event_handlers(components: dict):
gr.update(
value=f"🔄 TTS Service: {status}",
interactive=True,
variant="secondary"
variant="secondary",
),
gr.update(choices=voices, value=default_voice)
gr.update(choices=voices, value=default_voice),
]
return [
gr.update(
value=f"⌛ TTS Service: {status}",
interactive=True,
variant="secondary"
variant="secondary",
),
gr.update(choices=[], value=None)
gr.update(choices=[], value=None),
]
except Exception as e:
print(f"Error in refresh status: {str(e)}")
@ -37,9 +40,9 @@ def setup_event_handlers(components: dict):
gr.update(
value="❌ TTS Service: Connection Error",
interactive=True,
variant="secondary"
variant="secondary",
),
gr.update(choices=[], value=None)
gr.update(choices=[], value=None),
]
def handle_file_select(filename):
@ -82,30 +85,23 @@ def setup_event_handlers(components: dict):
is_available, _ = api.check_api_status()
if not is_available:
gr.Warning("TTS Service is currently unavailable")
return [
None,
gr.update(choices=files.list_output_files())
]
return [None, gr.update(choices=files.list_output_files())]
if not text or not text.strip():
gr.Warning("Please enter text in the input box")
return [
None,
gr.update(choices=files.list_output_files())
]
return [None, gr.update(choices=files.list_output_files())]
files.save_text(text)
result = api.text_to_speech(text, voice, format, speed)
if result is None:
gr.Warning("Failed to generate speech. Please try again.")
return [
None,
gr.update(choices=files.list_output_files())
]
return [None, gr.update(choices=files.list_output_files())]
return [
result,
gr.update(choices=files.list_output_files(), value=os.path.basename(result))
gr.update(
choices=files.list_output_files(), value=os.path.basename(result)
),
]
def generate_from_file(selected_file, voice, format, speed):
@ -113,30 +109,23 @@ def setup_event_handlers(components: dict):
is_available, _ = api.check_api_status()
if not is_available:
gr.Warning("TTS Service is currently unavailable")
return [
None,
gr.update(choices=files.list_output_files())
]
return [None, gr.update(choices=files.list_output_files())]
if not selected_file:
gr.Warning("Please select a file")
return [
None,
gr.update(choices=files.list_output_files())
]
return [None, gr.update(choices=files.list_output_files())]
text = files.read_text_file(selected_file)
result = api.text_to_speech(text, voice, format, speed)
if result is None:
gr.Warning("Failed to generate speech. Please try again.")
return [
None,
gr.update(choices=files.list_output_files())
]
return [None, gr.update(choices=files.list_output_files())]
return [
result,
gr.update(choices=files.list_output_files(), value=os.path.basename(result))
gr.update(
choices=files.list_output_files(), value=os.path.basename(result)
),
]
def play_selected(file_path):
@ -155,7 +144,7 @@ def setup_event_handlers(components: dict):
gr.update(choices=files.list_output_files()), # output_files
gr.update(value=voice), # voice
gr.update(value=format), # format
gr.update(value=speed) # speed
gr.update(value=speed), # speed
]
def clear_outputs():
@ -164,34 +153,31 @@ def setup_event_handlers(components: dict):
return [
None, # audio_output
gr.update(choices=[], value=None), # output_files
gr.update(visible=False) # selected_audio
gr.update(visible=False), # selected_audio
]
# Connect event handlers
components["model"]["status_btn"].click(
fn=refresh_status,
outputs=[
components["model"]["status_btn"],
components["model"]["voice"]
]
outputs=[components["model"]["status_btn"], components["model"]["voice"]],
)
components["input"]["file_select"].change(
fn=handle_file_select,
inputs=[components["input"]["file_select"]],
outputs=[components["input"]["file_preview"]]
outputs=[components["input"]["file_preview"]],
)
components["input"]["file_upload"].upload(
fn=handle_file_upload,
inputs=[components["input"]["file_upload"]],
outputs=[components["input"]["file_select"]]
outputs=[components["input"]["file_select"]],
)
components["output"]["play_btn"].click(
fn=play_selected,
inputs=[components["output"]["output_files"]],
outputs=[components["output"]["selected_audio"]]
outputs=[components["output"]["selected_audio"]],
)
# Connect clear files button
@ -200,7 +186,7 @@ def setup_event_handlers(components: dict):
inputs=[
components["model"]["voice"],
components["model"]["format"],
components["model"]["speed"]
components["model"]["speed"],
],
outputs=[
components["input"]["file_select"],
@ -210,8 +196,8 @@ def setup_event_handlers(components: dict):
components["output"]["output_files"],
components["model"]["voice"],
components["model"]["format"],
components["model"]["speed"]
]
components["model"]["speed"],
],
)
# Connect submit buttons for each tab
@ -221,12 +207,12 @@ def setup_event_handlers(components: dict):
components["input"]["text_input"],
components["model"]["voice"],
components["model"]["format"],
components["model"]["speed"]
components["model"]["speed"],
],
outputs=[
components["output"]["audio_output"],
components["output"]["output_files"]
]
components["output"]["output_files"],
],
)
# Connect clear outputs button
@ -235,8 +221,8 @@ def setup_event_handlers(components: dict):
outputs=[
components["output"]["audio_output"],
components["output"]["output_files"],
components["output"]["selected_audio"]
]
components["output"]["selected_audio"],
],
)
components["input"]["file_submit"].click(
@ -245,10 +231,10 @@ def setup_event_handlers(components: dict):
components["input"]["file_select"],
components["model"]["voice"],
components["model"]["format"],
components["model"]["speed"]
components["model"]["speed"],
],
outputs=[
components["output"]["audio_output"],
components["output"]["output_files"]
]
components["output"]["output_files"],
],
)

View file

@ -1,34 +1,38 @@
import gradio as gr
from . import api
from .components import create_input_column, create_model_column, create_output_column
from .handlers import setup_event_handlers
from .components import create_input_column, create_model_column, create_output_column
def create_interface():
"""Create the main Gradio interface."""
# Skip initial status check - let the timer handle it
is_available, available_voices = False, []
with gr.Blocks(
title="Kokoro TTS Demo",
theme=gr.themes.Monochrome()
) as demo:
gr.HTML(value='<div style="display: flex; gap: 0;">'
with gr.Blocks(title="Kokoro TTS Demo", theme=gr.themes.Monochrome()) as demo:
gr.HTML(
value='<div style="display: flex; gap: 0;">'
'<a href="https://huggingface.co/hexgrad/Kokoro-82M" target="_blank" style="color: #2196F3; text-decoration: none; margin: 2px; border: 1px solid #2196F3; padding: 4px 8px; height: 24px; box-sizing: border-box; display: inline-flex; align-items: center;">Kokoro-82M HF Repo</a>'
'<a href="https://github.com/remsky/Kokoro-FastAPI" target="_blank" style="color: #2196F3; text-decoration: none; margin: 2px; border: 1px solid #2196F3; padding: 4px 8px; height: 24px; box-sizing: border-box; display: inline-flex; align-items: center;">Kokoro-FastAPI Repo</a>'
'</div>', show_label=False)
"</div>",
show_label=False,
)
# Main interface
with gr.Row():
# Create columns
input_col, input_components = create_input_column()
model_col, model_components = create_model_column(available_voices) # Pass initial voices
model_col, model_components = create_model_column(
available_voices
) # Pass initial voices
output_col, output_components = create_output_column()
# Collect all components
components = {
"input": input_components,
"model": model_components,
"output": output_components
"output": output_components,
}
# Set up event handlers
@ -43,16 +47,18 @@ def create_interface():
if is_available and voices:
# Service is available, update UI and stop timer
current_voice = components["model"]["voice"].value
default_voice = current_voice if current_voice in voices else voices[0]
default_voice = (
current_voice if current_voice in voices else voices[0]
)
# Return values in same order as outputs list
return [
gr.update(
value=f"🔄 TTS Service: {status}",
interactive=True,
variant="secondary"
variant="secondary",
),
gr.update(choices=voices, value=default_voice),
gr.update(active=False) # Stop timer
gr.update(active=False), # Stop timer
]
# Service not available yet, keep checking
@ -60,10 +66,10 @@ def create_interface():
gr.update(
value=f"⌛ TTS Service: {status}",
interactive=True,
variant="secondary"
variant="secondary",
),
gr.update(choices=[], value=None),
gr.update(active=True)
gr.update(active=True),
]
except Exception as e:
print(f"Error in status update: {str(e)}")
@ -72,10 +78,10 @@ def create_interface():
gr.update(
value="❌ TTS Service: Connection Error",
interactive=True,
variant="secondary"
variant="secondary",
),
gr.update(choices=[], value=None),
gr.update(active=True)
gr.update(active=True),
]
timer = gr.Timer(value=5) # Check every 5 seconds
@ -84,8 +90,8 @@ def create_interface():
outputs=[
components["model"]["status_btn"],
components["model"]["voice"],
timer
]
timer,
],
)
return demo

View file

@ -1,5 +1,5 @@
import pytest
import gradio as gr
import pytest
@pytest.fixture

View file

@ -1,6 +1,8 @@
from unittest.mock import patch, mock_open
import pytest
import requests
from unittest.mock import patch, mock_open
from ui.lib import api
@ -57,10 +59,9 @@ def test_check_api_status_connection_error():
def test_text_to_speech_success(mock_response, tmp_path):
"""Test successful speech generation"""
with patch("requests.post", return_value=mock_response({})), \
patch("ui.lib.api.OUTPUTS_DIR", str(tmp_path)), \
patch("builtins.open", mock_open()) as mock_file:
with patch("requests.post", return_value=mock_response({})), patch(
"ui.lib.api.OUTPUTS_DIR", str(tmp_path)
), patch("builtins.open", mock_open()) as mock_file:
result = api.text_to_speech("test text", "voice1", "mp3", 1.0)
assert result is not None
@ -105,10 +106,9 @@ def test_get_status_html_unavailable():
def test_text_to_speech_api_params(mock_response, tmp_path):
"""Test correct API parameters are sent"""
with patch("requests.post") as mock_post, \
patch("ui.lib.api.OUTPUTS_DIR", str(tmp_path)), \
patch("builtins.open", mock_open()):
with patch("requests.post") as mock_post, patch(
"ui.lib.api.OUTPUTS_DIR", str(tmp_path)
), patch("builtins.open", mock_open()):
mock_post.return_value = mock_response({})
api.text_to_speech("test text", "voice1", "mp3", 1.5)
@ -121,7 +121,7 @@ def test_text_to_speech_api_params(mock_response, tmp_path):
"input": "test text",
"voice": "voice1",
"response_format": "mp3",
"speed": 1.5
"speed": 1.5,
}
# Check headers and timeout

View file

@ -1,8 +1,9 @@
import pytest
import gradio as gr
import pytest
from ui.lib.config import AUDIO_FORMATS
from ui.lib.components.model import create_model_column
from ui.lib.components.output import create_output_column
from ui.lib.config import AUDIO_FORMATS
def test_create_model_column_structure():
@ -15,12 +16,7 @@ def test_create_model_column_structure():
assert isinstance(components, dict)
# Test expected components presence
expected_components = {
"status_btn",
"voice",
"format",
"speed"
}
expected_components = {"status_btn", "voice", "format", "speed"}
assert set(components.keys()) == expected_components
# Test component types
@ -78,7 +74,7 @@ def test_create_output_column_structure():
"output_files",
"play_btn",
"selected_audio",
"clear_outputs"
"clear_outputs",
}
assert set(components.keys()) == expected_components

View file

@ -1,6 +1,8 @@
import os
import pytest
from unittest.mock import patch
import pytest
from ui.lib import files
from ui.lib.config import AUDIO_FORMATS

View file

@ -1,5 +1,6 @@
import pytest
import gradio as gr
import pytest
from ui.lib.components.input import create_input_column

View file

@ -1,12 +1,15 @@
import pytest
from unittest.mock import MagicMock, PropertyMock, patch
import gradio as gr
from unittest.mock import patch, MagicMock, PropertyMock
import pytest
from ui.lib.interface import create_interface
@pytest.fixture
def mock_timer():
"""Create a mock timer with events property"""
class MockEvent:
def __init__(self, fn):
self.fn = fn
@ -44,8 +47,7 @@ def test_interface_html_links():
# Find HTML component
html_components = [
comp for comp in demo.blocks.values()
if isinstance(comp, gr.HTML)
comp for comp in demo.blocks.values() if isinstance(comp, gr.HTML)
]
assert len(html_components) > 0
html = html_components[0]
@ -60,8 +62,9 @@ def test_interface_html_links():
def test_update_status_available(mock_timer):
"""Test status update when service is available"""
voices = ["voice1", "voice2"]
with patch("ui.lib.api.check_api_status", return_value=(True, voices)), \
patch("gradio.Timer", return_value=mock_timer):
with patch("ui.lib.api.check_api_status", return_value=(True, voices)), patch(
"gradio.Timer", return_value=mock_timer
):
demo = create_interface()
# Get the update function
@ -78,8 +81,9 @@ def test_update_status_available(mock_timer):
def test_update_status_unavailable(mock_timer):
"""Test status update when service is unavailable"""
with patch("ui.lib.api.check_api_status", return_value=(False, [])), \
patch("gradio.Timer", return_value=mock_timer):
with patch("ui.lib.api.check_api_status", return_value=(False, [])), patch(
"gradio.Timer", return_value=mock_timer
):
demo = create_interface()
update_fn = mock_timer.events[0].fn
@ -93,8 +97,9 @@ def test_update_status_unavailable(mock_timer):
def test_update_status_error(mock_timer):
"""Test status update when an error occurs"""
with patch("ui.lib.api.check_api_status", side_effect=Exception("Test error")), \
patch("gradio.Timer", return_value=mock_timer):
with patch(
"ui.lib.api.check_api_status", side_effect=Exception("Test error")
), patch("gradio.Timer", return_value=mock_timer):
demo = create_interface()
update_fn = mock_timer.events[0].fn
@ -108,8 +113,9 @@ def test_update_status_error(mock_timer):
def test_timer_configuration(mock_timer):
"""Test timer configuration"""
with patch("ui.lib.api.check_api_status", return_value=(False, [])), \
patch("gradio.Timer", return_value=mock_timer):
with patch("ui.lib.api.check_api_status", return_value=(False, [])), patch(
"gradio.Timer", return_value=mock_timer
):
demo = create_interface()
assert mock_timer.value == 5 # Check interval is 5 seconds
@ -123,8 +129,9 @@ def test_interface_components_presence():
# Check for main component sections
components = {
comp.label for comp in demo.blocks.values()
if hasattr(comp, 'label') and comp.label
comp.label
for comp in demo.blocks.values()
if hasattr(comp, "label") and comp.label
}
required_components = {
@ -133,7 +140,7 @@ def test_interface_components_presence():
"Audio Format",
"Speed",
"Generated Speech",
"Previous Outputs"
"Previous Outputs",
}
assert required_components.issubset(components)