Ruff Check + Format

This commit is contained in:
remsky 2025-01-01 21:50:41 -07:00
parent e749b3bc88
commit f051984805
27 changed files with 638 additions and 504 deletions

BIN
.coverage

Binary file not shown.

View file

@ -1,10 +1,10 @@
from typing import List from typing import List
from fastapi import APIRouter, Depends, HTTPException, Response
from loguru import logger from loguru import logger
from fastapi import Depends, Response, APIRouter, HTTPException
from ..services.audio import AudioService
from ..services.tts import TTSService from ..services.tts import TTSService
from ..services.audio import AudioService
from ..structures.schemas import OpenAISpeechRequest from ..structures.schemas import OpenAISpeechRequest
router = APIRouter( router = APIRouter(
@ -55,14 +55,12 @@ async def create_speech(
except ValueError as e: except ValueError as e:
logger.error(f"Invalid request: {str(e)}") logger.error(f"Invalid request: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=400, status_code=400, detail={"error": "Invalid request", "message": str(e)}
detail={"error": "Invalid request", "message": str(e)}
) )
except Exception as e: except Exception as e:
logger.error(f"Error generating speech: {str(e)}") logger.error(f"Error generating speech: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail={"error": "Server error", "message": str(e)}
detail={"error": "Server error", "message": str(e)}
) )
@ -78,7 +76,9 @@ async def list_voices(tts_service: TTSService = Depends(get_tts_service)):
@router.post("/audio/voices/combine") @router.post("/audio/voices/combine")
async def combine_voices(request: List[str], tts_service: TTSService = Depends(get_tts_service)): async def combine_voices(
request: List[str], tts_service: TTSService = Depends(get_tts_service)
):
"""Combine multiple voices into a new voice. """Combine multiple voices into a new voice.
Args: Args:
@ -100,20 +100,17 @@ async def combine_voices(request: List[str], tts_service: TTSService = Depends(g
except ValueError as e: except ValueError as e:
logger.error(f"Invalid voice combination request: {str(e)}") logger.error(f"Invalid voice combination request: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=400, status_code=400, detail={"error": "Invalid request", "message": str(e)}
detail={"error": "Invalid request", "message": str(e)}
) )
except RuntimeError as e: except RuntimeError as e:
logger.error(f"Server error during voice combination: {str(e)}") logger.error(f"Server error during voice combination: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail={"error": "Server error", "message": str(e)}
detail={"error": "Server error", "message": str(e)}
) )
except Exception as e: except Exception as e:
logger.error(f"Unexpected error during voice combination: {str(e)}") logger.error(f"Unexpected error during voice combination: {str(e)}")
raise HTTPException( raise HTTPException(
status_code=500, status_code=500, detail={"error": "Unexpected error", "message": str(e)}
detail={"error": "Unexpected error", "message": str(e)}
) )

View file

@ -1,17 +1,16 @@
import io import io
import os import os
import re import re
import threading
import time import time
import threading
from typing import List, Tuple, Optional from typing import List, Tuple, Optional
import numpy as np import numpy as np
import scipy.io.wavfile as wavfile
import tiktoken
import torch import torch
import tiktoken
import scipy.io.wavfile as wavfile
from kokoro import generate, tokenize, phonemize, normalize_text
from loguru import logger from loguru import logger
from kokoro import generate, normalize_text, phonemize, tokenize
from models import build_model from models import build_model
from ..core.config import settings from ..core.config import settings
@ -51,25 +50,37 @@ class TTSModel:
voice_path = os.path.join(cls.VOICES_DIR, file) voice_path = os.path.join(cls.VOICES_DIR, file)
if not os.path.exists(voice_path): if not os.path.exists(voice_path):
try: try:
logger.info(f"Copying base voice {voice_name} to voices directory") logger.info(
f"Copying base voice {voice_name} to voices directory"
)
base_path = os.path.join(base_voices_dir, file) base_path = os.path.join(base_voices_dir, file)
voicepack = torch.load(base_path, map_location=cls._device, weights_only=True) voicepack = torch.load(
base_path,
map_location=cls._device,
weights_only=True,
)
torch.save(voicepack, voice_path) torch.save(voicepack, voice_path)
except Exception as e: except Exception as e:
logger.error(f"Error copying voice {voice_name}: {str(e)}") logger.error(
f"Error copying voice {voice_name}: {str(e)}"
)
# Warm up with default voice # Warm up with default voice
try: try:
dummy_text = "Hello" dummy_text = "Hello"
voice_path = os.path.join(cls.VOICES_DIR, "af.pt") voice_path = os.path.join(cls.VOICES_DIR, "af.pt")
dummy_voicepack = torch.load(voice_path, map_location=cls._device, weights_only=True) dummy_voicepack = torch.load(
generate(model, dummy_text, dummy_voicepack, lang='a', speed=1.0) voice_path, map_location=cls._device, weights_only=True
)
generate(model, dummy_text, dummy_voicepack, lang="a", speed=1.0)
logger.info("Model warm-up complete") logger.info("Model warm-up complete")
except Exception as e: except Exception as e:
logger.warning(f"Model warm-up failed: {e}") logger.warning(f"Model warm-up failed: {e}")
# Count voices in directory for validation # Count voices in directory for validation
voice_count = len([f for f in os.listdir(cls.VOICES_DIR) if f.endswith('.pt')]) voice_count = len(
[f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")]
)
return cls._instance, voice_count return cls._instance, voice_count
@classmethod @classmethod
@ -99,9 +110,15 @@ class TTSService:
voice_path = os.path.join(TTSModel.VOICES_DIR, file) voice_path = os.path.join(TTSModel.VOICES_DIR, file)
if not os.path.exists(voice_path): if not os.path.exists(voice_path):
try: try:
logger.info(f"Copying base voice {voice_name} to voices directory") logger.info(
f"Copying base voice {voice_name} to voices directory"
)
base_path = os.path.join(base_voices_dir, file) base_path = os.path.join(base_voices_dir, file)
voicepack = torch.load(base_path, map_location=TTSModel._device, weights_only=True) voicepack = torch.load(
base_path,
map_location=TTSModel._device,
weights_only=True,
)
torch.save(voicepack, voice_path) torch.save(voicepack, voice_path)
except Exception as e: except Exception as e:
logger.error(f"Error copying voice {voice_name}: {str(e)}") logger.error(f"Error copying voice {voice_name}: {str(e)}")
@ -141,7 +158,9 @@ class TTSService:
# Load model and voice # Load model and voice
model = TTSModel._instance model = TTSModel._instance
voicepack = torch.load(voice_path, map_location=TTSModel._device, weights_only=True) voicepack = torch.load(
voice_path, map_location=TTSModel._device, weights_only=True
)
# Generate audio with or without stitching # Generate audio with or without stitching
if stitch_long_output: if stitch_long_output:
@ -152,11 +171,11 @@ class TTSService:
for i, chunk in enumerate(chunks): for i, chunk in enumerate(chunks):
try: try:
# Validate phonemization first # Validate phonemization first
ps = phonemize(chunk, voice[0]) # ps = phonemize(chunk, voice[0])
tokens = tokenize(ps) # tokens = tokenize(ps)
logger.debug( # logger.debug(
f"Processing chunk {i + 1}/{len(chunks)}: {len(tokens)} tokens" # f"Processing chunk {i + 1}/{len(chunks)}: {len(tokens)} tokens"
) # )
# Only proceed if phonemization succeeded # Only proceed if phonemization succeeded
chunk_audio, _ = generate( chunk_audio, _ = generate(
@ -226,7 +245,9 @@ class TTSService:
for voice in voices: for voice in voices:
try: try:
voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice}.pt") voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice}.pt")
voicepack = torch.load(voice_path, map_location=TTSModel._device, weights_only=True) voicepack = torch.load(
voice_path, map_location=TTSModel._device, weights_only=True
)
t_voices.append(voicepack) t_voices.append(voicepack)
v_name.append(voice) v_name.append(voice)
except Exception as e: except Exception as e:
@ -242,7 +263,9 @@ class TTSService:
try: try:
torch.save(v, combined_path) torch.save(v, combined_path)
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to save combined voice to {combined_path}: {str(e)}") raise RuntimeError(
f"Failed to save combined voice to {combined_path}: {str(e)}"
)
return f return f

View file

@ -18,7 +18,7 @@ class OpenAISpeechRequest(BaseModel):
input: str = Field(..., description="The text to generate audio for") input: str = Field(..., description="The text to generate audio for")
voice: str = Field( voice: str = Field(
default="af", default="af",
description="The voice to use for generation. Can be a base voice or a combined voice name." description="The voice to use for generation. Can be a base voice or a combined voice name.",
) )
response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field( response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field(
default="mp3", default="mp3",

View file

@ -1,16 +1,18 @@
import os import os
import shutil
import sys import sys
import shutil
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import pytest import pytest
def cleanup_mock_dirs(): def cleanup_mock_dirs():
"""Clean up any MagicMock directories created during tests""" """Clean up any MagicMock directories created during tests"""
mock_dir = "MagicMock" mock_dir = "MagicMock"
if os.path.exists(mock_dir): if os.path.exists(mock_dir):
shutil.rmtree(mock_dir) shutil.rmtree(mock_dir)
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def cleanup(): def cleanup():
"""Automatically clean up before and after each test""" """Automatically clean up before and after each test"""
@ -18,6 +20,7 @@ def cleanup():
yield yield
cleanup_mock_dirs() cleanup_mock_dirs()
# Mock torch and other ML modules before they're imported # Mock torch and other ML modules before they're imported
sys.modules["torch"] = Mock() sys.modules["torch"] = Mock()
sys.modules["transformers"] = Mock() sys.modules["transformers"] = Mock()

View file

@ -1,6 +1,8 @@
"""Tests for AudioService""" """Tests for AudioService"""
import numpy as np import numpy as np
import pytest import pytest
from api.src.services.audio import AudioService from api.src.services.audio import AudioService

View file

@ -1,7 +1,10 @@
"""Tests for FastAPI application""" """Tests for FastAPI application"""
from unittest.mock import MagicMock, patch
import pytest import pytest
from unittest.mock import patch, MagicMock
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from api.src.main import app, lifespan from api.src.main import app, lifespan
@ -19,15 +22,15 @@ def test_health_check(test_client):
@pytest.mark.asyncio @pytest.mark.asyncio
@patch('api.src.main.TTSModel') @patch("api.src.main.TTSModel")
@patch('api.src.main.logger') @patch("api.src.main.logger")
async def test_lifespan_successful_warmup(mock_logger, mock_tts_model): async def test_lifespan_successful_warmup(mock_logger, mock_tts_model):
"""Test successful model warmup in lifespan""" """Test successful model warmup in lifespan"""
# Mock the model initialization with model info and voicepack count # Mock the model initialization with model info and voicepack count
mock_model = MagicMock() mock_model = MagicMock()
# Mock file system for voice counting # Mock file system for voice counting
mock_tts_model.VOICES_DIR = "/mock/voices" mock_tts_model.VOICES_DIR = "/mock/voices"
with patch('os.listdir', return_value=['voice1.pt', 'voice2.pt', 'voice3.pt']): with patch("os.listdir", return_value=["voice1.pt", "voice2.pt", "voice3.pt"]):
mock_tts_model.initialize.return_value = (mock_model, 3) # 3 voice files mock_tts_model.initialize.return_value = (mock_model, 3) # 3 voice files
mock_tts_model._device = "cuda" # Set device class variable mock_tts_model._device = "cuda" # Set device class variable
@ -49,8 +52,8 @@ async def test_lifespan_successful_warmup(mock_logger, mock_tts_model):
@pytest.mark.asyncio @pytest.mark.asyncio
@patch('api.src.main.TTSModel') @patch("api.src.main.TTSModel")
@patch('api.src.main.logger') @patch("api.src.main.logger")
async def test_lifespan_failed_warmup(mock_logger, mock_tts_model): async def test_lifespan_failed_warmup(mock_logger, mock_tts_model):
"""Test failed model warmup in lifespan""" """Test failed model warmup in lifespan"""
# Mock the model initialization to fail # Mock the model initialization to fail
@ -71,14 +74,14 @@ async def test_lifespan_failed_warmup(mock_logger, mock_tts_model):
@pytest.mark.asyncio @pytest.mark.asyncio
@patch('api.src.main.TTSModel') @patch("api.src.main.TTSModel")
async def test_lifespan_cuda_warmup(mock_tts_model): async def test_lifespan_cuda_warmup(mock_tts_model):
"""Test model warmup specifically on CUDA""" """Test model warmup specifically on CUDA"""
# Mock the model initialization with CUDA and voicepacks # Mock the model initialization with CUDA and voicepacks
mock_model = MagicMock() mock_model = MagicMock()
# Mock file system for voice counting # Mock file system for voice counting
mock_tts_model.VOICES_DIR = "/mock/voices" mock_tts_model.VOICES_DIR = "/mock/voices"
with patch('os.listdir', return_value=['voice1.pt', 'voice2.pt']): with patch("os.listdir", return_value=["voice1.pt", "voice2.pt"]):
mock_tts_model.initialize.return_value = (mock_model, 2) # 2 voice files mock_tts_model.initialize.return_value = (mock_model, 2) # 2 voice files
mock_tts_model._device = "cuda" # Set device class variable mock_tts_model._device = "cuda" # Set device class variable
@ -94,14 +97,16 @@ async def test_lifespan_cuda_warmup(mock_tts_model):
@pytest.mark.asyncio @pytest.mark.asyncio
@patch('api.src.main.TTSModel') @patch("api.src.main.TTSModel")
async def test_lifespan_cpu_fallback(mock_tts_model): async def test_lifespan_cpu_fallback(mock_tts_model):
"""Test model warmup falling back to CPU""" """Test model warmup falling back to CPU"""
# Mock the model initialization with CPU and voicepacks # Mock the model initialization with CPU and voicepacks
mock_model = MagicMock() mock_model = MagicMock()
# Mock file system for voice counting # Mock file system for voice counting
mock_tts_model.VOICES_DIR = "/mock/voices" mock_tts_model.VOICES_DIR = "/mock/voices"
with patch('os.listdir', return_value=['voice1.pt', 'voice2.pt', 'voice3.pt', 'voice4.pt']): with patch(
"os.listdir", return_value=["voice1.pt", "voice2.pt", "voice3.pt", "voice4.pt"]
):
mock_tts_model.initialize.return_value = (mock_model, 4) # 4 voice files mock_tts_model.initialize.return_value = (mock_model, 4) # 4 voice files
mock_tts_model._device = "cpu" # Set device class variable mock_tts_model._device = "cpu" # Set device class variable

View file

@ -1,9 +1,12 @@
"""Tests for TTSService""" """Tests for TTSService"""
import os import os
from unittest.mock import MagicMock, call, patch
import numpy as np import numpy as np
import pytest import pytest
from unittest.mock import patch, MagicMock, call
from api.src.services.tts import TTSService, TTSModel from api.src.services.tts import TTSModel, TTSService
@pytest.fixture @pytest.fixture
@ -50,27 +53,35 @@ def test_audio_to_bytes(tts_service, sample_audio):
assert len(audio_bytes) > 0 assert len(audio_bytes) > 0
@patch('os.listdir') @patch("os.listdir")
@patch('os.path.join') @patch("os.path.join")
def test_list_voices(mock_join, mock_listdir, tts_service): def test_list_voices(mock_join, mock_listdir, tts_service):
"""Test listing available voices""" """Test listing available voices"""
mock_listdir.return_value = ['voice1.pt', 'voice2.pt', 'not_a_voice.txt'] mock_listdir.return_value = ["voice1.pt", "voice2.pt", "not_a_voice.txt"]
mock_join.return_value = '/fake/path' mock_join.return_value = "/fake/path"
voices = tts_service.list_voices() voices = tts_service.list_voices()
assert len(voices) == 2 assert len(voices) == 2
assert 'voice1' in voices assert "voice1" in voices
assert 'voice2' in voices assert "voice2" in voices
assert 'not_a_voice' not in voices assert "not_a_voice" not in voices
@patch('api.src.services.tts.TTSModel.get_instance') @patch("api.src.services.tts.TTSModel.get_instance")
@patch('api.src.services.tts.TTSModel.get_voicepack') @patch("api.src.services.tts.TTSModel.get_voicepack")
@patch('api.src.services.tts.normalize_text') @patch("api.src.services.tts.normalize_text")
@patch('api.src.services.tts.phonemize') @patch("api.src.services.tts.phonemize")
@patch('api.src.services.tts.tokenize') @patch("api.src.services.tts.tokenize")
@patch('api.src.services.tts.generate') @patch("api.src.services.tts.generate")
def test_generate_audio_empty_text(mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_voicepack, mock_instance, tts_service): def test_generate_audio_empty_text(
mock_generate,
mock_tokenize,
mock_phonemize,
mock_normalize,
mock_voicepack,
mock_instance,
tts_service,
):
"""Test generating audio with empty text""" """Test generating audio with empty text"""
mock_normalize.return_value = "" mock_normalize.return_value = ""
@ -78,14 +89,23 @@ def test_generate_audio_empty_text(mock_generate, mock_tokenize, mock_phonemize,
tts_service._generate_audio("", "af", 1.0) tts_service._generate_audio("", "af", 1.0)
@patch('api.src.services.tts.TTSModel.get_instance') @patch("api.src.services.tts.TTSModel.get_instance")
@patch('os.path.exists') @patch("os.path.exists")
@patch('api.src.services.tts.normalize_text') @patch("api.src.services.tts.normalize_text")
@patch('api.src.services.tts.phonemize') @patch("api.src.services.tts.phonemize")
@patch('api.src.services.tts.tokenize') @patch("api.src.services.tts.tokenize")
@patch('api.src.services.tts.generate') @patch("api.src.services.tts.generate")
@patch('torch.load') @patch("torch.load")
def test_generate_audio_no_chunks(mock_torch_load, mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_exists, mock_instance, tts_service): def test_generate_audio_no_chunks(
mock_torch_load,
mock_generate,
mock_tokenize,
mock_phonemize,
mock_normalize,
mock_exists,
mock_instance,
tts_service,
):
"""Test generating audio with no successful chunks""" """Test generating audio with no successful chunks"""
mock_normalize.return_value = "Test text" mock_normalize.return_value = "Test text"
mock_phonemize.return_value = "Test text" mock_phonemize.return_value = "Test text"
@ -99,14 +119,24 @@ def test_generate_audio_no_chunks(mock_torch_load, mock_generate, mock_tokenize,
tts_service._generate_audio("Test text", "af", 1.0) tts_service._generate_audio("Test text", "af", 1.0)
@patch('api.src.services.tts.TTSModel.get_instance') @patch("api.src.services.tts.TTSModel.get_instance")
@patch('os.path.exists') @patch("os.path.exists")
@patch('api.src.services.tts.normalize_text') @patch("api.src.services.tts.normalize_text")
@patch('api.src.services.tts.phonemize') @patch("api.src.services.tts.phonemize")
@patch('api.src.services.tts.tokenize') @patch("api.src.services.tts.tokenize")
@patch('api.src.services.tts.generate') @patch("api.src.services.tts.generate")
@patch('torch.load') @patch("torch.load")
def test_generate_audio_success(mock_torch_load, mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_exists, mock_instance, tts_service, sample_audio): def test_generate_audio_success(
mock_torch_load,
mock_generate,
mock_tokenize,
mock_phonemize,
mock_normalize,
mock_exists,
mock_instance,
tts_service,
sample_audio,
):
"""Test successful audio generation""" """Test successful audio generation"""
mock_normalize.return_value = "Test text" mock_normalize.return_value = "Test text"
mock_phonemize.return_value = "Test text" mock_phonemize.return_value = "Test text"
@ -122,8 +152,8 @@ def test_generate_audio_success(mock_torch_load, mock_generate, mock_tokenize, m
assert len(audio) > 0 assert len(audio) > 0
@patch('api.src.services.tts.torch.cuda.is_available') @patch("api.src.services.tts.torch.cuda.is_available")
@patch('api.src.services.tts.build_model') @patch("api.src.services.tts.build_model")
def test_model_initialization_cuda(mock_build_model, mock_cuda_available): def test_model_initialization_cuda(mock_build_model, mock_cuda_available):
"""Test model initialization with CUDA""" """Test model initialization with CUDA"""
mock_cuda_available.return_value = True mock_cuda_available.return_value = True
@ -138,8 +168,8 @@ def test_model_initialization_cuda(mock_build_model, mock_cuda_available):
mock_build_model.assert_called_once() mock_build_model.assert_called_once()
@patch('api.src.services.tts.torch.cuda.is_available') @patch("api.src.services.tts.torch.cuda.is_available")
@patch('api.src.services.tts.build_model') @patch("api.src.services.tts.build_model")
def test_model_initialization_cpu(mock_build_model, mock_cuda_available): def test_model_initialization_cpu(mock_build_model, mock_cuda_available):
"""Test model initialization with CPU""" """Test model initialization with CPU"""
mock_cuda_available.return_value = False mock_cuda_available.return_value = False
@ -154,8 +184,8 @@ def test_model_initialization_cpu(mock_build_model, mock_cuda_available):
mock_build_model.assert_called_once() mock_build_model.assert_called_once()
@patch('api.src.services.tts.TTSService._get_voice_path') @patch("api.src.services.tts.TTSService._get_voice_path")
@patch('api.src.services.tts.TTSModel.get_instance') @patch("api.src.services.tts.TTSModel.get_instance")
def test_voicepack_loading_error(mock_get_instance, mock_get_voice_path): def test_voicepack_loading_error(mock_get_instance, mock_get_voice_path):
"""Test voicepack loading error handling""" """Test voicepack loading error handling"""
mock_get_voice_path.return_value = None mock_get_voice_path.return_value = None
@ -168,7 +198,7 @@ def test_voicepack_loading_error(mock_get_instance, mock_get_voice_path):
service._generate_audio("test", "nonexistent_voice", 1.0) service._generate_audio("test", "nonexistent_voice", 1.0)
@patch('api.src.services.tts.TTSModel') @patch("api.src.services.tts.TTSModel")
def test_save_audio(mock_tts_model, tts_service, sample_audio, tmp_path): def test_save_audio(mock_tts_model, tts_service, sample_audio, tmp_path):
"""Test saving audio to file""" """Test saving audio to file"""
output_dir = os.path.join(tmp_path, "test_output") output_dir = os.path.join(tmp_path, "test_output")
@ -181,12 +211,20 @@ def test_save_audio(mock_tts_model, tts_service, sample_audio, tmp_path):
assert os.path.getsize(output_path) > 0 assert os.path.getsize(output_path) > 0
@patch('api.src.services.tts.TTSModel.get_instance') @patch("api.src.services.tts.TTSModel.get_instance")
@patch('os.path.exists') @patch("os.path.exists")
@patch('api.src.services.tts.normalize_text') @patch("api.src.services.tts.normalize_text")
@patch('api.src.services.tts.generate') @patch("api.src.services.tts.generate")
@patch('torch.load') @patch("torch.load")
def test_generate_audio_without_stitching(mock_torch_load, mock_generate, mock_normalize, mock_exists, mock_instance, tts_service, sample_audio): def test_generate_audio_without_stitching(
mock_torch_load,
mock_generate,
mock_normalize,
mock_exists,
mock_instance,
tts_service,
sample_audio,
):
"""Test generating audio without text stitching""" """Test generating audio without text stitching"""
mock_normalize.return_value = "Test text" mock_normalize.return_value = "Test text"
mock_generate.return_value = (sample_audio, None) mock_generate.return_value = (sample_audio, None)
@ -194,14 +232,16 @@ def test_generate_audio_without_stitching(mock_torch_load, mock_generate, mock_n
mock_exists.return_value = True mock_exists.return_value = True
mock_torch_load.return_value = MagicMock() mock_torch_load.return_value = MagicMock()
audio, processing_time = tts_service._generate_audio("Test text", "af", 1.0, stitch_long_output=False) audio, processing_time = tts_service._generate_audio(
"Test text", "af", 1.0, stitch_long_output=False
)
assert isinstance(audio, np.ndarray) assert isinstance(audio, np.ndarray)
assert isinstance(processing_time, float) assert isinstance(processing_time, float)
assert len(audio) > 0 assert len(audio) > 0
mock_generate.assert_called_once() mock_generate.assert_called_once()
@patch('os.listdir') @patch("os.listdir")
def test_list_voices_error(mock_listdir, tts_service): def test_list_voices_error(mock_listdir, tts_service):
"""Test error handling in list_voices""" """Test error handling in list_voices"""
mock_listdir.side_effect = Exception("Failed to list directory") mock_listdir.side_effect = Exception("Failed to list directory")
@ -210,14 +250,23 @@ def test_list_voices_error(mock_listdir, tts_service):
assert voices == [] assert voices == []
@patch('api.src.services.tts.TTSModel.get_instance') @patch("api.src.services.tts.TTSModel.get_instance")
@patch('os.path.exists') @patch("os.path.exists")
@patch('api.src.services.tts.normalize_text') @patch("api.src.services.tts.normalize_text")
@patch('api.src.services.tts.phonemize') @patch("api.src.services.tts.phonemize")
@patch('api.src.services.tts.tokenize') @patch("api.src.services.tts.tokenize")
@patch('api.src.services.tts.generate') @patch("api.src.services.tts.generate")
@patch('torch.load') @patch("torch.load")
def test_generate_audio_phonemize_error(mock_torch_load, mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_exists, mock_instance, tts_service): def test_generate_audio_phonemize_error(
mock_torch_load,
mock_generate,
mock_tokenize,
mock_phonemize,
mock_normalize,
mock_exists,
mock_instance,
tts_service,
):
"""Test handling phonemization error""" """Test handling phonemization error"""
mock_normalize.return_value = "Test text" mock_normalize.return_value = "Test text"
mock_phonemize.side_effect = Exception("Phonemization failed") mock_phonemize.side_effect = Exception("Phonemization failed")
@ -230,12 +279,19 @@ def test_generate_audio_phonemize_error(mock_torch_load, mock_generate, mock_tok
tts_service._generate_audio("Test text", "af", 1.0) tts_service._generate_audio("Test text", "af", 1.0)
@patch('api.src.services.tts.TTSModel.get_instance') @patch("api.src.services.tts.TTSModel.get_instance")
@patch('os.path.exists') @patch("os.path.exists")
@patch('api.src.services.tts.normalize_text') @patch("api.src.services.tts.normalize_text")
@patch('api.src.services.tts.generate') @patch("api.src.services.tts.generate")
@patch('torch.load') @patch("torch.load")
def test_generate_audio_error(mock_torch_load, mock_generate, mock_normalize, mock_exists, mock_instance, tts_service): def test_generate_audio_error(
mock_torch_load,
mock_generate,
mock_normalize,
mock_exists,
mock_instance,
tts_service,
):
"""Test handling generation error""" """Test handling generation error"""
mock_normalize.return_value = "Test text" mock_normalize.return_value = "Test text"
mock_generate.side_effect = Exception("Generation failed") mock_generate.side_effect = Exception("Generation failed")

View file

@ -19,7 +19,6 @@ output_dir = Path(__file__).parent / "output"
output_dir.mkdir(exist_ok=True) output_dir.mkdir(exist_ok=True)
def test_voice(voice: str): def test_voice(voice: str):
speech_file = output_dir / f"speech_{voice}.mp3" speech_file = output_dir / f"speech_{voice}.mp3"
print(f"\nTesting voice: {voice}") print(f"\nTesting voice: {voice}")

View file

@ -1,15 +1,17 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import argparse
import os import os
from typing import List, Optional, Dict, Tuple import argparse
from typing import Dict, List, Tuple, Optional
import requests
import numpy as np import numpy as np
from scipy.io import wavfile import requests
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from scipy.io import wavfile
def submit_combine_voices(voices: List[str], base_url: str = "http://localhost:8880") -> Optional[str]: def submit_combine_voices(
voices: List[str], base_url: str = "http://localhost:8880"
) -> Optional[str]:
"""Combine multiple voices into a new voice. """Combine multiple voices into a new voice.
Args: Args:
@ -46,7 +48,12 @@ def submit_combine_voices(voices: List[str], base_url: str = "http://localhost:8
return None return None
def generate_speech(text: str, voice: str, base_url: str = "http://localhost:8880", output_file: str = "output.mp3") -> bool: def generate_speech(
text: str,
voice: str,
base_url: str = "http://localhost:8880",
output_file: str = "output.mp3",
) -> bool:
"""Generate speech using specified voice. """Generate speech using specified voice.
Args: Args:
@ -65,8 +72,8 @@ def generate_speech(text: str, voice: str, base_url: str = "http://localhost:888
"input": text, "input": text,
"voice": voice, "voice": voice,
"speed": 1.0, "speed": 1.0,
"response_format": "wav" # Use WAV for analysis "response_format": "wav", # Use WAV for analysis
} },
) )
if response.status_code != 200: if response.status_code != 200:
@ -75,7 +82,10 @@ def generate_speech(text: str, voice: str, base_url: str = "http://localhost:888
return False return False
# Save the audio # Save the audio
os.makedirs(os.path.dirname(output_file) if os.path.dirname(output_file) else ".", exist_ok=True) os.makedirs(
os.path.dirname(output_file) if os.path.dirname(output_file) else ".",
exist_ok=True,
)
with open(output_file, "wb") as f: with open(output_file, "wb") as f:
f.write(response.content) f.write(response.content)
print(f"Saved audio to {output_file}") print(f"Saved audio to {output_file}")
@ -136,7 +146,7 @@ def analyze_audio(filepath: str) -> Tuple[np.ndarray, int, dict]:
"duration": duration, "duration": duration,
"zero_crossing_rate": zero_crossings, "zero_crossing_rate": zero_crossings,
"dominant_frequencies": dominant_freqs, "dominant_frequencies": dominant_freqs,
"spectral_centroid": spectral_centroid "spectral_centroid": spectral_centroid,
} }
return samples, sample_rate, characteristics return samples, sample_rate, characteristics
@ -167,6 +177,7 @@ def setup_plot(fig, ax, title):
return fig, ax return fig, ax
def plot_analysis(audio_files: Dict[str, str], output_dir: str): def plot_analysis(audio_files: Dict[str, str], output_dir: str):
"""Plot comprehensive voice analysis including waveforms and metrics comparison. """Plot comprehensive voice analysis including waveforms and metrics comparison.
@ -175,7 +186,7 @@ def plot_analysis(audio_files: Dict[str, str], output_dir: str):
output_dir: Directory to save plot files output_dir: Directory to save plot files
""" """
# Set dark style # Set dark style
plt.style.use('dark_background') plt.style.use("dark_background")
# Create figure with subplots # Create figure with subplots
fig = plt.figure(figsize=(15, 15)) fig = plt.figure(figsize=(15, 15))
@ -183,8 +194,9 @@ def plot_analysis(audio_files: Dict[str, str], output_dir: str):
num_files = len(audio_files) num_files = len(audio_files)
# Create subplot grid with proper spacing # Create subplot grid with proper spacing
gs = plt.GridSpec(num_files + 1, 2, height_ratios=[1.5]*num_files + [1], gs = plt.GridSpec(
hspace=0.4, wspace=0.3) num_files + 1, 2, height_ratios=[1.5] * num_files + [1], hspace=0.4, wspace=0.3
)
# Analyze all files first # Analyze all files first
all_chars = {} all_chars = {}
@ -195,7 +207,7 @@ def plot_analysis(audio_files: Dict[str, str], output_dir: str):
# Plot waveform spanning both columns # Plot waveform spanning both columns
ax = plt.subplot(gs[i, :]) ax = plt.subplot(gs[i, :])
time = np.arange(len(samples)) / sample_rate time = np.arange(len(samples)) / sample_rate
plt.plot(time, samples / chars['max_amplitude'], linewidth=0.5, color="#ff2a6d") plt.plot(time, samples / chars["max_amplitude"], linewidth=0.5, color="#ff2a6d")
ax.set_xlabel("Time (seconds)") ax.set_xlabel("Time (seconds)")
ax.set_ylabel("Normalized Amplitude") ax.set_ylabel("Normalized Amplitude")
ax.set_ylim(-1.1, 1.1) ax.set_ylim(-1.1, 1.1)
@ -208,15 +220,27 @@ def plot_analysis(audio_files: Dict[str, str], output_dir: str):
# Left subplot: Brightness and Volume # Left subplot: Brightness and Volume
ax1 = plt.subplot(gs[num_files, 0]) ax1 = plt.subplot(gs[num_files, 0])
metrics1 = [ metrics1 = [
('Brightness', [chars['spectral_centroid']/1000 for chars in all_chars.values()], 'kHz'), (
('Volume', [chars['rms']*100 for chars in all_chars.values()], 'RMS×100') "Brightness",
[chars["spectral_centroid"] / 1000 for chars in all_chars.values()],
"kHz",
),
("Volume", [chars["rms"] * 100 for chars in all_chars.values()], "RMS×100"),
] ]
# Right subplot: Voice Pitch and Texture # Right subplot: Voice Pitch and Texture
ax2 = plt.subplot(gs[num_files, 1]) ax2 = plt.subplot(gs[num_files, 1])
metrics2 = [ metrics2 = [
('Voice Pitch', [min(chars['dominant_frequencies']) for chars in all_chars.values()], 'Hz'), (
('Texture', [chars['zero_crossing_rate']*1000 for chars in all_chars.values()], 'ZCR×1000') "Voice Pitch",
[min(chars["dominant_frequencies"]) for chars in all_chars.values()],
"Hz",
),
(
"Texture",
[chars["zero_crossing_rate"] * 1000 for chars in all_chars.values()],
"ZCR×1000",
),
] ]
def plot_grouped_bars(ax, metrics, show_legend=True): def plot_grouped_bars(ax, metrics, show_legend=True):
@ -232,16 +256,22 @@ def plot_analysis(audio_files: Dict[str, str], output_dir: str):
for i, (voice, color) in enumerate(zip(audio_files.keys(), colors)): for i, (voice, color) in enumerate(zip(audio_files.keys(), colors)):
values = [m[1][i] for m in metrics] values = [m[1][i] for m in metrics]
offset = (i - n_voices / 2 + 0.5) * bar_width offset = (i - n_voices / 2 + 0.5) * bar_width
bars = ax.bar(indices + offset, values, bar_width, bars = ax.bar(
label=voice, color=color, alpha=0.8) indices + offset, values, bar_width, label=voice, color=color, alpha=0.8
)
# Add value labels on top of bars # Add value labels on top of bars
for bar in bars: for bar in bars:
height = bar.get_height() height = bar.get_height()
ax.text(bar.get_x() + bar.get_width()/2., height, ax.text(
f'{height:.1f}', bar.get_x() + bar.get_width() / 2.0,
ha='center', va='bottom', color='white', height,
fontsize=10) f"{height:.1f}",
ha="center",
va="bottom",
color="white",
fontsize=10,
)
ax.set_xticks(indices) ax.set_xticks(indices)
ax.set_xticklabels([f"{m[0]}\n({m[2]})" for m in metrics]) ax.set_xticklabels([f"{m[0]}\n({m[2]})" for m in metrics])
@ -250,20 +280,24 @@ def plot_analysis(audio_files: Dict[str, str], output_dir: str):
ax.set_ylim(0, max_val * 1.2) ax.set_ylim(0, max_val * 1.2)
if show_legend: if show_legend:
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', ax.legend(
facecolor="#1a1a2e", edgecolor="#ffffff") bbox_to_anchor=(1.05, 1),
loc="upper left",
facecolor="#1a1a2e",
edgecolor="#ffffff",
)
# Plot both subplots # Plot both subplots
plot_grouped_bars(ax1, metrics1, show_legend=True) plot_grouped_bars(ax1, metrics1, show_legend=True)
plot_grouped_bars(ax2, metrics2, show_legend=False) plot_grouped_bars(ax2, metrics2, show_legend=False)
# Style both subplots # Style both subplots
setup_plot(fig, ax1, 'Brightness and Volume') setup_plot(fig, ax1, "Brightness and Volume")
setup_plot(fig, ax2, 'Voice Pitch and Texture') setup_plot(fig, ax2, "Voice Pitch and Texture")
# Add y-axis labels # Add y-axis labels
ax1.set_ylabel('Value') ax1.set_ylabel("Value")
ax2.set_ylabel('Value') ax2.set_ylabel("Value")
# Adjust the figure size to accommodate the legend # Adjust the figure size to accommodate the legend
fig.set_size_inches(15, 15) fig.set_size_inches(15, 15)
@ -282,15 +316,26 @@ def plot_analysis(audio_files: Dict[str, str], output_dir: str):
print(f" Duration: {chars['duration']:.2f}s") print(f" Duration: {chars['duration']:.2f}s")
print(f" Zero Crossing Rate: {chars['zero_crossing_rate']:.3f}") print(f" Zero Crossing Rate: {chars['zero_crossing_rate']:.3f}")
print(f" Spectral Centroid: {chars['spectral_centroid']:.0f}Hz") print(f" Spectral Centroid: {chars['spectral_centroid']:.0f}Hz")
print(f" Dominant Frequencies: {', '.join(f'{f:.0f}Hz' for f in chars['dominant_frequencies'])}") print(
f" Dominant Frequencies: {', '.join(f'{f:.0f}Hz' for f in chars['dominant_frequencies'])}"
)
def main(): def main():
parser = argparse.ArgumentParser(description="Kokoro Voice Analysis Demo") parser = argparse.ArgumentParser(description="Kokoro Voice Analysis Demo")
parser.add_argument("--voices", nargs="+", type=str, help="Voices to combine") parser.add_argument("--voices", nargs="+", type=str, help="Voices to combine")
parser.add_argument("--text", type=str, default="Hello! This is a test of combined voices.", help="Text to speak") parser.add_argument(
"--text",
type=str,
default="Hello! This is a test of combined voices.",
help="Text to speak",
)
parser.add_argument("--url", default="http://localhost:8880", help="API base URL") parser.add_argument("--url", default="http://localhost:8880", help="API base URL")
parser.add_argument("--output-dir", default="examples/output", help="Output directory for audio files") parser.add_argument(
"--output-dir",
default="examples/output",
help="Output directory for audio files",
)
args = parser.parse_args() args = parser.parse_args()
if not args.voices: if not args.voices:
@ -316,7 +361,9 @@ def main():
if combined_voice: if combined_voice:
print(f"Successfully created combined voice: {combined_voice}") print(f"Successfully created combined voice: {combined_voice}")
output_file = os.path.join(args.output_dir, f"analysis_combined_{combined_voice}.wav") output_file = os.path.join(
args.output_dir, f"analysis_combined_{combined_voice}.wav"
)
if generate_speech(args.text, combined_voice, args.url, output_file): if generate_speech(args.text, combined_voice, args.url, output_file):
audio_files["combined"] = output_file audio_files["combined"] = output_file

View file

@ -10,3 +10,5 @@ sqlalchemy==2.0.27
pytest==8.0.0 pytest==8.0.0
httpx==0.26.0 httpx==0.26.0
pytest-asyncio==0.23.5 pytest-asyncio==0.23.5
pytest-cov==6.0.0
gradio==4.19.2

View file

@ -2,8 +2,4 @@ from lib.interface import create_interface
if __name__ == "__main__": if __name__ == "__main__":
demo = create_interface() demo = create_interface()
demo.launch( demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
server_name="0.0.0.0",
server_port=7860,
show_error=True
)

View file

@ -1,16 +1,19 @@
import requests
from typing import Tuple, List, Optional
import os import os
import datetime import datetime
from typing import List, Tuple, Optional
import requests
from .config import API_URL, OUTPUTS_DIR from .config import API_URL, OUTPUTS_DIR
def check_api_status() -> Tuple[bool, List[str]]: def check_api_status() -> Tuple[bool, List[str]]:
"""Check TTS service status and get available voices.""" """Check TTS service status and get available voices."""
try: try:
# Use a longer timeout during startup # Use a longer timeout during startup
response = requests.get( response = requests.get(
f"{API_URL}/v1/audio/voices", f"{API_URL}/v1/audio/voices",
timeout=30 # Increased timeout for initial startup period timeout=30, # Increased timeout for initial startup period
) )
response.raise_for_status() response.raise_for_status()
voices = response.json().get("voices", []) voices = response.json().get("voices", [])
@ -31,7 +34,10 @@ def check_api_status() -> Tuple[bool, List[str]]:
print(f"Unexpected error checking API status: {str(e)}") print(f"Unexpected error checking API status: {str(e)}")
return False, [] return False, []
def text_to_speech(text: str, voice_id: str, format: str, speed: float) -> Optional[str]:
def text_to_speech(
text: str, voice_id: str, format: str, speed: float
) -> Optional[str]:
"""Generate speech from text using TTS API.""" """Generate speech from text using TTS API."""
if not text.strip(): if not text.strip():
return None return None
@ -49,10 +55,10 @@ def text_to_speech(text: str, voice_id: str, format: str, speed: float) -> Optio
"input": text, "input": text,
"voice": voice_id, "voice": voice_id,
"response_format": format, "response_format": format,
"speed": float(speed) "speed": float(speed),
}, },
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
timeout=300 # Longer timeout for speech generation timeout=300, # Longer timeout for speech generation
) )
response.raise_for_status() response.raise_for_status()
@ -70,6 +76,7 @@ def text_to_speech(text: str, voice_id: str, format: str, speed: float) -> Optio
print(f"Unexpected error generating speech: {str(e)}") print(f"Unexpected error generating speech: {str(e)}")
return None return None
def get_status_html(is_available: bool) -> str: def get_status_html(is_available: bool) -> str:
"""Generate HTML for status indicator.""" """Generate HTML for status indicator."""
color = "green" if is_available else "red" color = "green" if is_available else "red"

View file

@ -1,7 +1,10 @@
import gradio as gr
from typing import Tuple from typing import Tuple
import gradio as gr
from .. import files from .. import files
def create_input_column() -> Tuple[gr.Column, dict]: def create_input_column() -> Tuple[gr.Column, dict]:
"""Create the input column with text input and file handling.""" """Create the input column with text input and file handling."""
with gr.Column(scale=1) as col: with gr.Column(scale=1) as col:
@ -11,15 +14,9 @@ def create_input_column() -> Tuple[gr.Column, dict]:
# Direct Input Tab # Direct Input Tab
with gr.TabItem("Direct Input"): with gr.TabItem("Direct Input"):
text_input = gr.Textbox( text_input = gr.Textbox(
label="Text to speak", label="Text to speak", placeholder="Enter text here...", lines=4
placeholder="Enter text here...",
lines=4
)
text_submit = gr.Button(
"Generate Speech",
variant="primary",
size="lg"
) )
text_submit = gr.Button("Generate Speech", variant="primary", size="lg")
# File Input Tab # File Input Tab
with gr.TabItem("From File"): with gr.TabItem("From File"):
@ -27,31 +24,24 @@ def create_input_column() -> Tuple[gr.Column, dict]:
input_files_list = gr.Dropdown( input_files_list = gr.Dropdown(
label="Select Existing File", label="Select Existing File",
choices=files.list_input_files(), choices=files.list_input_files(),
value=None value=None,
) )
# Simple file upload # Simple file upload
file_upload = gr.File( file_upload = gr.File(
label="Upload Text File (.txt)", label="Upload Text File (.txt)", file_types=[".txt"]
file_types=[".txt"]
) )
file_preview = gr.Textbox( file_preview = gr.Textbox(
label="File Content Preview", label="File Content Preview", interactive=False, lines=4
interactive=False,
lines=4
) )
with gr.Row(): with gr.Row():
file_submit = gr.Button( file_submit = gr.Button(
"Generate Speech", "Generate Speech", variant="primary", size="lg"
variant="primary",
size="lg"
) )
clear_files = gr.Button( clear_files = gr.Button(
"Clear Files", "Clear Files", variant="secondary", size="lg"
variant="secondary",
size="lg"
) )
components = { components = {
@ -62,7 +52,7 @@ def create_input_column() -> Tuple[gr.Column, dict]:
"file_preview": file_preview, "file_preview": file_preview,
"text_submit": text_submit, "text_submit": text_submit,
"file_submit": file_submit, "file_submit": file_submit,
"clear_files": clear_files "clear_files": clear_files,
} }
return col, components return col, components

View file

@ -1,7 +1,10 @@
import gradio as gr
from typing import Tuple, Optional from typing import Tuple, Optional
import gradio as gr
from .. import api, config from .. import api, config
def create_model_column(voice_ids: Optional[list] = None) -> Tuple[gr.Column, dict]: def create_model_column(voice_ids: Optional[list] = None) -> Tuple[gr.Column, dict]:
"""Create the model settings column.""" """Create the model settings column."""
if voice_ids is None: if voice_ids is None:
@ -12,34 +15,27 @@ def create_model_column(voice_ids: Optional[list] = None) -> Tuple[gr.Column, di
# Status button starts in waiting state # Status button starts in waiting state
status_btn = gr.Button( status_btn = gr.Button(
"⌛ TTS Service: Waiting for Service...", "⌛ TTS Service: Waiting for Service...", variant="secondary"
variant="secondary"
) )
voice_input = gr.Dropdown( voice_input = gr.Dropdown(
choices=voice_ids, choices=voice_ids,
label="Voice", label="Voice",
value=voice_ids[0] if voice_ids else None, value=voice_ids[0] if voice_ids else None,
interactive=True interactive=True,
) )
format_input = gr.Dropdown( format_input = gr.Dropdown(
choices=config.AUDIO_FORMATS, choices=config.AUDIO_FORMATS, label="Audio Format", value="mp3"
label="Audio Format",
value="mp3"
) )
speed_input = gr.Slider( speed_input = gr.Slider(
minimum=0.5, minimum=0.5, maximum=2.0, value=1.0, step=0.1, label="Speed"
maximum=2.0,
value=1.0,
step=0.1,
label="Speed"
) )
components = { components = {
"status_btn": status_btn, "status_btn": status_btn,
"voice": voice_input, "voice": voice_input,
"format": format_input, "format": format_input,
"speed": speed_input "speed": speed_input,
} }
return col, components return col, components

View file

@ -1,40 +1,42 @@
import gradio as gr
from typing import Tuple from typing import Tuple
import gradio as gr
from .. import files from .. import files
def create_output_column() -> Tuple[gr.Column, dict]: def create_output_column() -> Tuple[gr.Column, dict]:
"""Create the output column with audio player and file list.""" """Create the output column with audio player and file list."""
with gr.Column(scale=1) as col: with gr.Column(scale=1) as col:
gr.Markdown("### Latest Output") gr.Markdown("### Latest Output")
audio_output = gr.Audio( audio_output = gr.Audio(label="Generated Speech", type="filepath")
label="Generated Speech",
type="filepath"
)
gr.Markdown("### Generated Files") gr.Markdown("### Generated Files")
output_files = gr.Dropdown( output_files = gr.Dropdown(
label="Previous Outputs", label="Previous Outputs",
choices=files.list_output_files(), choices=files.list_output_files(),
value=None, value=None,
allow_custom_value=False allow_custom_value=False,
) )
play_btn = gr.Button("▶️ Play Selected", size="sm") play_btn = gr.Button("▶️ Play Selected", size="sm")
selected_audio = gr.Audio( selected_audio = gr.Audio(
label="Selected Output", label="Selected Output", type="filepath", visible=False
type="filepath",
visible=False
) )
clear_outputs = gr.Button("⚠️ Delete All Previously Generated Output Audio 🗑️", size="sm", variant="secondary") clear_outputs = gr.Button(
"⚠️ Delete All Previously Generated Output Audio 🗑️",
size="sm",
variant="secondary",
)
components = { components = {
"audio_output": audio_output, "audio_output": audio_output,
"output_files": output_files, "output_files": output_files,
"play_btn": play_btn, "play_btn": play_btn,
"selected_audio": selected_audio, "selected_audio": selected_audio,
"clear_outputs": clear_outputs "clear_outputs": clear_outputs,
} }
return col, components return col, components

View file

@ -1,17 +1,23 @@
import os import os
from typing import List, Optional, Tuple
import datetime import datetime
from typing import List, Tuple, Optional
from .config import INPUTS_DIR, OUTPUTS_DIR, AUDIO_FORMATS from .config import INPUTS_DIR, OUTPUTS_DIR, AUDIO_FORMATS
def list_input_files() -> List[str]: def list_input_files() -> List[str]:
"""List all input text files.""" """List all input text files."""
return [f for f in os.listdir(INPUTS_DIR) if f.endswith('.txt')] return [f for f in os.listdir(INPUTS_DIR) if f.endswith(".txt")]
def list_output_files() -> List[str]: def list_output_files() -> List[str]:
"""List all output audio files.""" """List all output audio files."""
return [os.path.join(OUTPUTS_DIR, f) return [
os.path.join(OUTPUTS_DIR, f)
for f in os.listdir(OUTPUTS_DIR) for f in os.listdir(OUTPUTS_DIR)
if any(f.endswith(ext) for ext in AUDIO_FORMATS)] if any(f.endswith(ext) for ext in AUDIO_FORMATS)
]
def read_text_file(filename: str) -> str: def read_text_file(filename: str) -> str:
"""Read content of a text file.""" """Read content of a text file."""
@ -19,11 +25,12 @@ def read_text_file(filename: str) -> str:
return "" return ""
try: try:
file_path = os.path.join(INPUTS_DIR, filename) file_path = os.path.join(INPUTS_DIR, filename)
with open(file_path, 'r', encoding='utf-8') as f: with open(file_path, "r", encoding="utf-8") as f:
return f.read() return f.read()
except: except:
return "" return ""
def save_text(text: str, filename: Optional[str] = None) -> Optional[str]: def save_text(text: str, filename: Optional[str] = None) -> Optional[str]:
"""Save text to a file. Returns the filename if successful.""" """Save text to a file. Returns the filename if successful."""
if not text.strip(): if not text.strip():
@ -41,7 +48,7 @@ def save_text(text: str, filename: Optional[str] = None) -> Optional[str]:
else: else:
# Handle duplicate filenames by adding _1, _2, etc. # Handle duplicate filenames by adding _1, _2, etc.
base = os.path.splitext(filename)[0] base = os.path.splitext(filename)[0]
ext = os.path.splitext(filename)[1] or '.txt' ext = os.path.splitext(filename)[1] or ".txt"
counter = 1 counter = 1
while os.path.exists(os.path.join(INPUTS_DIR, filename)): while os.path.exists(os.path.join(INPUTS_DIR, filename)):
filename = f"{base}_{counter}{ext}" filename = f"{base}_{counter}{ext}"
@ -56,11 +63,12 @@ def save_text(text: str, filename: Optional[str] = None) -> Optional[str]:
print(f"Error saving file: {e}") print(f"Error saving file: {e}")
return None return None
def delete_all_input_files() -> bool: def delete_all_input_files() -> bool:
"""Delete all files from the inputs directory. Returns True if successful.""" """Delete all files from the inputs directory. Returns True if successful."""
try: try:
for filename in os.listdir(INPUTS_DIR): for filename in os.listdir(INPUTS_DIR):
if filename.endswith('.txt'): if filename.endswith(".txt"):
file_path = os.path.join(INPUTS_DIR, filename) file_path = os.path.join(INPUTS_DIR, filename)
os.remove(file_path) os.remove(file_path)
return True return True
@ -68,6 +76,7 @@ def delete_all_input_files() -> bool:
print(f"Error deleting input files: {e}") print(f"Error deleting input files: {e}")
return False return False
def delete_all_output_files() -> bool: def delete_all_output_files() -> bool:
"""Delete all audio files from the outputs directory. Returns True if successful.""" """Delete all audio files from the outputs directory. Returns True if successful."""
try: try:
@ -80,6 +89,7 @@ def delete_all_output_files() -> bool:
print(f"Error deleting output files: {e}") print(f"Error deleting output files: {e}")
return False return False
def process_uploaded_file(file_path: str) -> bool: def process_uploaded_file(file_path: str) -> bool:
"""Save uploaded file to inputs directory. Returns True if successful.""" """Save uploaded file to inputs directory. Returns True if successful."""
if not file_path: if not file_path:
@ -87,7 +97,7 @@ def process_uploaded_file(file_path: str) -> bool:
try: try:
filename = os.path.basename(file_path) filename = os.path.basename(file_path)
if not filename.endswith('.txt'): if not filename.endswith(".txt"):
return False return False
# Create target path in inputs directory # Create target path in inputs directory
@ -103,6 +113,7 @@ def process_uploaded_file(file_path: str) -> bool:
# Copy file to inputs directory # Copy file to inputs directory
import shutil import shutil
shutil.copy2(file_path, target_path) shutil.copy2(file_path, target_path)
return True return True

View file

@ -1,8 +1,11 @@
import gradio as gr
import os import os
import shutil import shutil
import gradio as gr
from . import api, files from . import api, files
def setup_event_handlers(components: dict): def setup_event_handlers(components: dict):
"""Set up all event handlers for the UI components.""" """Set up all event handlers for the UI components."""
@ -19,17 +22,17 @@ def setup_event_handlers(components: dict):
gr.update( gr.update(
value=f"🔄 TTS Service: {status}", value=f"🔄 TTS Service: {status}",
interactive=True, interactive=True,
variant="secondary" variant="secondary",
), ),
gr.update(choices=voices, value=default_voice) gr.update(choices=voices, value=default_voice),
] ]
return [ return [
gr.update( gr.update(
value=f"⌛ TTS Service: {status}", value=f"⌛ TTS Service: {status}",
interactive=True, interactive=True,
variant="secondary" variant="secondary",
), ),
gr.update(choices=[], value=None) gr.update(choices=[], value=None),
] ]
except Exception as e: except Exception as e:
print(f"Error in refresh status: {str(e)}") print(f"Error in refresh status: {str(e)}")
@ -37,9 +40,9 @@ def setup_event_handlers(components: dict):
gr.update( gr.update(
value="❌ TTS Service: Connection Error", value="❌ TTS Service: Connection Error",
interactive=True, interactive=True,
variant="secondary" variant="secondary",
), ),
gr.update(choices=[], value=None) gr.update(choices=[], value=None),
] ]
def handle_file_select(filename): def handle_file_select(filename):
@ -82,30 +85,23 @@ def setup_event_handlers(components: dict):
is_available, _ = api.check_api_status() is_available, _ = api.check_api_status()
if not is_available: if not is_available:
gr.Warning("TTS Service is currently unavailable") gr.Warning("TTS Service is currently unavailable")
return [ return [None, gr.update(choices=files.list_output_files())]
None,
gr.update(choices=files.list_output_files())
]
if not text or not text.strip(): if not text or not text.strip():
gr.Warning("Please enter text in the input box") gr.Warning("Please enter text in the input box")
return [ return [None, gr.update(choices=files.list_output_files())]
None,
gr.update(choices=files.list_output_files())
]
files.save_text(text) files.save_text(text)
result = api.text_to_speech(text, voice, format, speed) result = api.text_to_speech(text, voice, format, speed)
if result is None: if result is None:
gr.Warning("Failed to generate speech. Please try again.") gr.Warning("Failed to generate speech. Please try again.")
return [ return [None, gr.update(choices=files.list_output_files())]
None,
gr.update(choices=files.list_output_files())
]
return [ return [
result, result,
gr.update(choices=files.list_output_files(), value=os.path.basename(result)) gr.update(
choices=files.list_output_files(), value=os.path.basename(result)
),
] ]
def generate_from_file(selected_file, voice, format, speed): def generate_from_file(selected_file, voice, format, speed):
@ -113,30 +109,23 @@ def setup_event_handlers(components: dict):
is_available, _ = api.check_api_status() is_available, _ = api.check_api_status()
if not is_available: if not is_available:
gr.Warning("TTS Service is currently unavailable") gr.Warning("TTS Service is currently unavailable")
return [ return [None, gr.update(choices=files.list_output_files())]
None,
gr.update(choices=files.list_output_files())
]
if not selected_file: if not selected_file:
gr.Warning("Please select a file") gr.Warning("Please select a file")
return [ return [None, gr.update(choices=files.list_output_files())]
None,
gr.update(choices=files.list_output_files())
]
text = files.read_text_file(selected_file) text = files.read_text_file(selected_file)
result = api.text_to_speech(text, voice, format, speed) result = api.text_to_speech(text, voice, format, speed)
if result is None: if result is None:
gr.Warning("Failed to generate speech. Please try again.") gr.Warning("Failed to generate speech. Please try again.")
return [ return [None, gr.update(choices=files.list_output_files())]
None,
gr.update(choices=files.list_output_files())
]
return [ return [
result, result,
gr.update(choices=files.list_output_files(), value=os.path.basename(result)) gr.update(
choices=files.list_output_files(), value=os.path.basename(result)
),
] ]
def play_selected(file_path): def play_selected(file_path):
@ -155,7 +144,7 @@ def setup_event_handlers(components: dict):
gr.update(choices=files.list_output_files()), # output_files gr.update(choices=files.list_output_files()), # output_files
gr.update(value=voice), # voice gr.update(value=voice), # voice
gr.update(value=format), # format gr.update(value=format), # format
gr.update(value=speed) # speed gr.update(value=speed), # speed
] ]
def clear_outputs(): def clear_outputs():
@ -164,34 +153,31 @@ def setup_event_handlers(components: dict):
return [ return [
None, # audio_output None, # audio_output
gr.update(choices=[], value=None), # output_files gr.update(choices=[], value=None), # output_files
gr.update(visible=False) # selected_audio gr.update(visible=False), # selected_audio
] ]
# Connect event handlers # Connect event handlers
components["model"]["status_btn"].click( components["model"]["status_btn"].click(
fn=refresh_status, fn=refresh_status,
outputs=[ outputs=[components["model"]["status_btn"], components["model"]["voice"]],
components["model"]["status_btn"],
components["model"]["voice"]
]
) )
components["input"]["file_select"].change( components["input"]["file_select"].change(
fn=handle_file_select, fn=handle_file_select,
inputs=[components["input"]["file_select"]], inputs=[components["input"]["file_select"]],
outputs=[components["input"]["file_preview"]] outputs=[components["input"]["file_preview"]],
) )
components["input"]["file_upload"].upload( components["input"]["file_upload"].upload(
fn=handle_file_upload, fn=handle_file_upload,
inputs=[components["input"]["file_upload"]], inputs=[components["input"]["file_upload"]],
outputs=[components["input"]["file_select"]] outputs=[components["input"]["file_select"]],
) )
components["output"]["play_btn"].click( components["output"]["play_btn"].click(
fn=play_selected, fn=play_selected,
inputs=[components["output"]["output_files"]], inputs=[components["output"]["output_files"]],
outputs=[components["output"]["selected_audio"]] outputs=[components["output"]["selected_audio"]],
) )
# Connect clear files button # Connect clear files button
@ -200,7 +186,7 @@ def setup_event_handlers(components: dict):
inputs=[ inputs=[
components["model"]["voice"], components["model"]["voice"],
components["model"]["format"], components["model"]["format"],
components["model"]["speed"] components["model"]["speed"],
], ],
outputs=[ outputs=[
components["input"]["file_select"], components["input"]["file_select"],
@ -210,8 +196,8 @@ def setup_event_handlers(components: dict):
components["output"]["output_files"], components["output"]["output_files"],
components["model"]["voice"], components["model"]["voice"],
components["model"]["format"], components["model"]["format"],
components["model"]["speed"] components["model"]["speed"],
] ],
) )
# Connect submit buttons for each tab # Connect submit buttons for each tab
@ -221,12 +207,12 @@ def setup_event_handlers(components: dict):
components["input"]["text_input"], components["input"]["text_input"],
components["model"]["voice"], components["model"]["voice"],
components["model"]["format"], components["model"]["format"],
components["model"]["speed"] components["model"]["speed"],
], ],
outputs=[ outputs=[
components["output"]["audio_output"], components["output"]["audio_output"],
components["output"]["output_files"] components["output"]["output_files"],
] ],
) )
# Connect clear outputs button # Connect clear outputs button
@ -235,8 +221,8 @@ def setup_event_handlers(components: dict):
outputs=[ outputs=[
components["output"]["audio_output"], components["output"]["audio_output"],
components["output"]["output_files"], components["output"]["output_files"],
components["output"]["selected_audio"] components["output"]["selected_audio"],
] ],
) )
components["input"]["file_submit"].click( components["input"]["file_submit"].click(
@ -245,10 +231,10 @@ def setup_event_handlers(components: dict):
components["input"]["file_select"], components["input"]["file_select"],
components["model"]["voice"], components["model"]["voice"],
components["model"]["format"], components["model"]["format"],
components["model"]["speed"] components["model"]["speed"],
], ],
outputs=[ outputs=[
components["output"]["audio_output"], components["output"]["audio_output"],
components["output"]["output_files"] components["output"]["output_files"],
] ],
) )

View file

@ -1,34 +1,38 @@
import gradio as gr import gradio as gr
from . import api from . import api
from .components import create_input_column, create_model_column, create_output_column
from .handlers import setup_event_handlers from .handlers import setup_event_handlers
from .components import create_input_column, create_model_column, create_output_column
def create_interface(): def create_interface():
"""Create the main Gradio interface.""" """Create the main Gradio interface."""
# Skip initial status check - let the timer handle it # Skip initial status check - let the timer handle it
is_available, available_voices = False, [] is_available, available_voices = False, []
with gr.Blocks( with gr.Blocks(title="Kokoro TTS Demo", theme=gr.themes.Monochrome()) as demo:
title="Kokoro TTS Demo", gr.HTML(
theme=gr.themes.Monochrome() value='<div style="display: flex; gap: 0;">'
) as demo:
gr.HTML(value='<div style="display: flex; gap: 0;">'
'<a href="https://huggingface.co/hexgrad/Kokoro-82M" target="_blank" style="color: #2196F3; text-decoration: none; margin: 2px; border: 1px solid #2196F3; padding: 4px 8px; height: 24px; box-sizing: border-box; display: inline-flex; align-items: center;">Kokoro-82M HF Repo</a>' '<a href="https://huggingface.co/hexgrad/Kokoro-82M" target="_blank" style="color: #2196F3; text-decoration: none; margin: 2px; border: 1px solid #2196F3; padding: 4px 8px; height: 24px; box-sizing: border-box; display: inline-flex; align-items: center;">Kokoro-82M HF Repo</a>'
'<a href="https://github.com/remsky/Kokoro-FastAPI" target="_blank" style="color: #2196F3; text-decoration: none; margin: 2px; border: 1px solid #2196F3; padding: 4px 8px; height: 24px; box-sizing: border-box; display: inline-flex; align-items: center;">Kokoro-FastAPI Repo</a>' '<a href="https://github.com/remsky/Kokoro-FastAPI" target="_blank" style="color: #2196F3; text-decoration: none; margin: 2px; border: 1px solid #2196F3; padding: 4px 8px; height: 24px; box-sizing: border-box; display: inline-flex; align-items: center;">Kokoro-FastAPI Repo</a>'
'</div>', show_label=False) "</div>",
show_label=False,
)
# Main interface # Main interface
with gr.Row(): with gr.Row():
# Create columns # Create columns
input_col, input_components = create_input_column() input_col, input_components = create_input_column()
model_col, model_components = create_model_column(available_voices) # Pass initial voices model_col, model_components = create_model_column(
available_voices
) # Pass initial voices
output_col, output_components = create_output_column() output_col, output_components = create_output_column()
# Collect all components # Collect all components
components = { components = {
"input": input_components, "input": input_components,
"model": model_components, "model": model_components,
"output": output_components "output": output_components,
} }
# Set up event handlers # Set up event handlers
@ -43,16 +47,18 @@ def create_interface():
if is_available and voices: if is_available and voices:
# Service is available, update UI and stop timer # Service is available, update UI and stop timer
current_voice = components["model"]["voice"].value current_voice = components["model"]["voice"].value
default_voice = current_voice if current_voice in voices else voices[0] default_voice = (
current_voice if current_voice in voices else voices[0]
)
# Return values in same order as outputs list # Return values in same order as outputs list
return [ return [
gr.update( gr.update(
value=f"🔄 TTS Service: {status}", value=f"🔄 TTS Service: {status}",
interactive=True, interactive=True,
variant="secondary" variant="secondary",
), ),
gr.update(choices=voices, value=default_voice), gr.update(choices=voices, value=default_voice),
gr.update(active=False) # Stop timer gr.update(active=False), # Stop timer
] ]
# Service not available yet, keep checking # Service not available yet, keep checking
@ -60,10 +66,10 @@ def create_interface():
gr.update( gr.update(
value=f"⌛ TTS Service: {status}", value=f"⌛ TTS Service: {status}",
interactive=True, interactive=True,
variant="secondary" variant="secondary",
), ),
gr.update(choices=[], value=None), gr.update(choices=[], value=None),
gr.update(active=True) gr.update(active=True),
] ]
except Exception as e: except Exception as e:
print(f"Error in status update: {str(e)}") print(f"Error in status update: {str(e)}")
@ -72,10 +78,10 @@ def create_interface():
gr.update( gr.update(
value="❌ TTS Service: Connection Error", value="❌ TTS Service: Connection Error",
interactive=True, interactive=True,
variant="secondary" variant="secondary",
), ),
gr.update(choices=[], value=None), gr.update(choices=[], value=None),
gr.update(active=True) gr.update(active=True),
] ]
timer = gr.Timer(value=5) # Check every 5 seconds timer = gr.Timer(value=5) # Check every 5 seconds
@ -84,8 +90,8 @@ def create_interface():
outputs=[ outputs=[
components["model"]["status_btn"], components["model"]["status_btn"],
components["model"]["voice"], components["model"]["voice"],
timer timer,
] ],
) )
return demo return demo

View file

@ -1,5 +1,5 @@
import pytest
import gradio as gr import gradio as gr
import pytest
@pytest.fixture @pytest.fixture

View file

@ -1,6 +1,8 @@
from unittest.mock import patch, mock_open
import pytest import pytest
import requests import requests
from unittest.mock import patch, mock_open
from ui.lib import api from ui.lib import api
@ -57,10 +59,9 @@ def test_check_api_status_connection_error():
def test_text_to_speech_success(mock_response, tmp_path): def test_text_to_speech_success(mock_response, tmp_path):
"""Test successful speech generation""" """Test successful speech generation"""
with patch("requests.post", return_value=mock_response({})), \ with patch("requests.post", return_value=mock_response({})), patch(
patch("ui.lib.api.OUTPUTS_DIR", str(tmp_path)), \ "ui.lib.api.OUTPUTS_DIR", str(tmp_path)
patch("builtins.open", mock_open()) as mock_file: ), patch("builtins.open", mock_open()) as mock_file:
result = api.text_to_speech("test text", "voice1", "mp3", 1.0) result = api.text_to_speech("test text", "voice1", "mp3", 1.0)
assert result is not None assert result is not None
@ -105,10 +106,9 @@ def test_get_status_html_unavailable():
def test_text_to_speech_api_params(mock_response, tmp_path): def test_text_to_speech_api_params(mock_response, tmp_path):
"""Test correct API parameters are sent""" """Test correct API parameters are sent"""
with patch("requests.post") as mock_post, \ with patch("requests.post") as mock_post, patch(
patch("ui.lib.api.OUTPUTS_DIR", str(tmp_path)), \ "ui.lib.api.OUTPUTS_DIR", str(tmp_path)
patch("builtins.open", mock_open()): ), patch("builtins.open", mock_open()):
mock_post.return_value = mock_response({}) mock_post.return_value = mock_response({})
api.text_to_speech("test text", "voice1", "mp3", 1.5) api.text_to_speech("test text", "voice1", "mp3", 1.5)
@ -121,7 +121,7 @@ def test_text_to_speech_api_params(mock_response, tmp_path):
"input": "test text", "input": "test text",
"voice": "voice1", "voice": "voice1",
"response_format": "mp3", "response_format": "mp3",
"speed": 1.5 "speed": 1.5,
} }
# Check headers and timeout # Check headers and timeout

View file

@ -1,8 +1,9 @@
import pytest
import gradio as gr import gradio as gr
import pytest
from ui.lib.config import AUDIO_FORMATS
from ui.lib.components.model import create_model_column from ui.lib.components.model import create_model_column
from ui.lib.components.output import create_output_column from ui.lib.components.output import create_output_column
from ui.lib.config import AUDIO_FORMATS
def test_create_model_column_structure(): def test_create_model_column_structure():
@ -15,12 +16,7 @@ def test_create_model_column_structure():
assert isinstance(components, dict) assert isinstance(components, dict)
# Test expected components presence # Test expected components presence
expected_components = { expected_components = {"status_btn", "voice", "format", "speed"}
"status_btn",
"voice",
"format",
"speed"
}
assert set(components.keys()) == expected_components assert set(components.keys()) == expected_components
# Test component types # Test component types
@ -78,7 +74,7 @@ def test_create_output_column_structure():
"output_files", "output_files",
"play_btn", "play_btn",
"selected_audio", "selected_audio",
"clear_outputs" "clear_outputs",
} }
assert set(components.keys()) == expected_components assert set(components.keys()) == expected_components

View file

@ -1,6 +1,8 @@
import os import os
import pytest
from unittest.mock import patch from unittest.mock import patch
import pytest
from ui.lib import files from ui.lib import files
from ui.lib.config import AUDIO_FORMATS from ui.lib.config import AUDIO_FORMATS

View file

@ -1,5 +1,6 @@
import pytest
import gradio as gr import gradio as gr
import pytest
from ui.lib.components.input import create_input_column from ui.lib.components.input import create_input_column

View file

@ -1,12 +1,15 @@
import pytest from unittest.mock import MagicMock, PropertyMock, patch
import gradio as gr import gradio as gr
from unittest.mock import patch, MagicMock, PropertyMock import pytest
from ui.lib.interface import create_interface from ui.lib.interface import create_interface
@pytest.fixture @pytest.fixture
def mock_timer(): def mock_timer():
"""Create a mock timer with events property""" """Create a mock timer with events property"""
class MockEvent: class MockEvent:
def __init__(self, fn): def __init__(self, fn):
self.fn = fn self.fn = fn
@ -44,8 +47,7 @@ def test_interface_html_links():
# Find HTML component # Find HTML component
html_components = [ html_components = [
comp for comp in demo.blocks.values() comp for comp in demo.blocks.values() if isinstance(comp, gr.HTML)
if isinstance(comp, gr.HTML)
] ]
assert len(html_components) > 0 assert len(html_components) > 0
html = html_components[0] html = html_components[0]
@ -60,8 +62,9 @@ def test_interface_html_links():
def test_update_status_available(mock_timer): def test_update_status_available(mock_timer):
"""Test status update when service is available""" """Test status update when service is available"""
voices = ["voice1", "voice2"] voices = ["voice1", "voice2"]
with patch("ui.lib.api.check_api_status", return_value=(True, voices)), \ with patch("ui.lib.api.check_api_status", return_value=(True, voices)), patch(
patch("gradio.Timer", return_value=mock_timer): "gradio.Timer", return_value=mock_timer
):
demo = create_interface() demo = create_interface()
# Get the update function # Get the update function
@ -78,8 +81,9 @@ def test_update_status_available(mock_timer):
def test_update_status_unavailable(mock_timer): def test_update_status_unavailable(mock_timer):
"""Test status update when service is unavailable""" """Test status update when service is unavailable"""
with patch("ui.lib.api.check_api_status", return_value=(False, [])), \ with patch("ui.lib.api.check_api_status", return_value=(False, [])), patch(
patch("gradio.Timer", return_value=mock_timer): "gradio.Timer", return_value=mock_timer
):
demo = create_interface() demo = create_interface()
update_fn = mock_timer.events[0].fn update_fn = mock_timer.events[0].fn
@ -93,8 +97,9 @@ def test_update_status_unavailable(mock_timer):
def test_update_status_error(mock_timer): def test_update_status_error(mock_timer):
"""Test status update when an error occurs""" """Test status update when an error occurs"""
with patch("ui.lib.api.check_api_status", side_effect=Exception("Test error")), \ with patch(
patch("gradio.Timer", return_value=mock_timer): "ui.lib.api.check_api_status", side_effect=Exception("Test error")
), patch("gradio.Timer", return_value=mock_timer):
demo = create_interface() demo = create_interface()
update_fn = mock_timer.events[0].fn update_fn = mock_timer.events[0].fn
@ -108,8 +113,9 @@ def test_update_status_error(mock_timer):
def test_timer_configuration(mock_timer): def test_timer_configuration(mock_timer):
"""Test timer configuration""" """Test timer configuration"""
with patch("ui.lib.api.check_api_status", return_value=(False, [])), \ with patch("ui.lib.api.check_api_status", return_value=(False, [])), patch(
patch("gradio.Timer", return_value=mock_timer): "gradio.Timer", return_value=mock_timer
):
demo = create_interface() demo = create_interface()
assert mock_timer.value == 5 # Check interval is 5 seconds assert mock_timer.value == 5 # Check interval is 5 seconds
@ -123,8 +129,9 @@ def test_interface_components_presence():
# Check for main component sections # Check for main component sections
components = { components = {
comp.label for comp in demo.blocks.values() comp.label
if hasattr(comp, 'label') and comp.label for comp in demo.blocks.values()
if hasattr(comp, "label") and comp.label
} }
required_components = { required_components = {
@ -133,7 +140,7 @@ def test_interface_components_presence():
"Audio Format", "Audio Format",
"Speed", "Speed",
"Generated Speech", "Generated Speech",
"Previous Outputs" "Previous Outputs",
} }
assert required_components.issubset(components) assert required_components.issubset(components)