mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00
Ruff Check + Format
This commit is contained in:
parent
e749b3bc88
commit
f051984805
27 changed files with 638 additions and 504 deletions
BIN
.coverage
BIN
.coverage
Binary file not shown.
|
@ -1,10 +1,10 @@
|
|||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response
|
||||
from loguru import logger
|
||||
from fastapi import Depends, Response, APIRouter, HTTPException
|
||||
|
||||
from ..services.audio import AudioService
|
||||
from ..services.tts import TTSService
|
||||
from ..services.audio import AudioService
|
||||
from ..structures.schemas import OpenAISpeechRequest
|
||||
|
||||
router = APIRouter(
|
||||
|
@ -55,14 +55,12 @@ async def create_speech(
|
|||
except ValueError as e:
|
||||
logger.error(f"Invalid request: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"error": "Invalid request", "message": str(e)}
|
||||
status_code=400, detail={"error": "Invalid request", "message": str(e)}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating speech: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"error": "Server error", "message": str(e)}
|
||||
status_code=500, detail={"error": "Server error", "message": str(e)}
|
||||
)
|
||||
|
||||
|
||||
|
@ -78,7 +76,9 @@ async def list_voices(tts_service: TTSService = Depends(get_tts_service)):
|
|||
|
||||
|
||||
@router.post("/audio/voices/combine")
|
||||
async def combine_voices(request: List[str], tts_service: TTSService = Depends(get_tts_service)):
|
||||
async def combine_voices(
|
||||
request: List[str], tts_service: TTSService = Depends(get_tts_service)
|
||||
):
|
||||
"""Combine multiple voices into a new voice.
|
||||
|
||||
Args:
|
||||
|
@ -100,20 +100,17 @@ async def combine_voices(request: List[str], tts_service: TTSService = Depends(g
|
|||
except ValueError as e:
|
||||
logger.error(f"Invalid voice combination request: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"error": "Invalid request", "message": str(e)}
|
||||
status_code=400, detail={"error": "Invalid request", "message": str(e)}
|
||||
)
|
||||
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Server error during voice combination: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"error": "Server error", "message": str(e)}
|
||||
status_code=500, detail={"error": "Server error", "message": str(e)}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during voice combination: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"error": "Unexpected error", "message": str(e)}
|
||||
status_code=500, detail={"error": "Unexpected error", "message": str(e)}
|
||||
)
|
||||
|
|
|
@ -1,17 +1,16 @@
|
|||
import io
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
import threading
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
import numpy as np
|
||||
import scipy.io.wavfile as wavfile
|
||||
import tiktoken
|
||||
import torch
|
||||
import tiktoken
|
||||
import scipy.io.wavfile as wavfile
|
||||
from kokoro import generate, tokenize, phonemize, normalize_text
|
||||
from loguru import logger
|
||||
|
||||
from kokoro import generate, normalize_text, phonemize, tokenize
|
||||
from models import build_model
|
||||
|
||||
from ..core.config import settings
|
||||
|
@ -51,25 +50,37 @@ class TTSModel:
|
|||
voice_path = os.path.join(cls.VOICES_DIR, file)
|
||||
if not os.path.exists(voice_path):
|
||||
try:
|
||||
logger.info(f"Copying base voice {voice_name} to voices directory")
|
||||
logger.info(
|
||||
f"Copying base voice {voice_name} to voices directory"
|
||||
)
|
||||
base_path = os.path.join(base_voices_dir, file)
|
||||
voicepack = torch.load(base_path, map_location=cls._device, weights_only=True)
|
||||
voicepack = torch.load(
|
||||
base_path,
|
||||
map_location=cls._device,
|
||||
weights_only=True,
|
||||
)
|
||||
torch.save(voicepack, voice_path)
|
||||
except Exception as e:
|
||||
logger.error(f"Error copying voice {voice_name}: {str(e)}")
|
||||
logger.error(
|
||||
f"Error copying voice {voice_name}: {str(e)}"
|
||||
)
|
||||
|
||||
# Warm up with default voice
|
||||
try:
|
||||
dummy_text = "Hello"
|
||||
voice_path = os.path.join(cls.VOICES_DIR, "af.pt")
|
||||
dummy_voicepack = torch.load(voice_path, map_location=cls._device, weights_only=True)
|
||||
generate(model, dummy_text, dummy_voicepack, lang='a', speed=1.0)
|
||||
dummy_voicepack = torch.load(
|
||||
voice_path, map_location=cls._device, weights_only=True
|
||||
)
|
||||
generate(model, dummy_text, dummy_voicepack, lang="a", speed=1.0)
|
||||
logger.info("Model warm-up complete")
|
||||
except Exception as e:
|
||||
logger.warning(f"Model warm-up failed: {e}")
|
||||
|
||||
# Count voices in directory for validation
|
||||
voice_count = len([f for f in os.listdir(cls.VOICES_DIR) if f.endswith('.pt')])
|
||||
voice_count = len(
|
||||
[f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")]
|
||||
)
|
||||
return cls._instance, voice_count
|
||||
|
||||
@classmethod
|
||||
|
@ -99,9 +110,15 @@ class TTSService:
|
|||
voice_path = os.path.join(TTSModel.VOICES_DIR, file)
|
||||
if not os.path.exists(voice_path):
|
||||
try:
|
||||
logger.info(f"Copying base voice {voice_name} to voices directory")
|
||||
logger.info(
|
||||
f"Copying base voice {voice_name} to voices directory"
|
||||
)
|
||||
base_path = os.path.join(base_voices_dir, file)
|
||||
voicepack = torch.load(base_path, map_location=TTSModel._device, weights_only=True)
|
||||
voicepack = torch.load(
|
||||
base_path,
|
||||
map_location=TTSModel._device,
|
||||
weights_only=True,
|
||||
)
|
||||
torch.save(voicepack, voice_path)
|
||||
except Exception as e:
|
||||
logger.error(f"Error copying voice {voice_name}: {str(e)}")
|
||||
|
@ -141,7 +158,9 @@ class TTSService:
|
|||
|
||||
# Load model and voice
|
||||
model = TTSModel._instance
|
||||
voicepack = torch.load(voice_path, map_location=TTSModel._device, weights_only=True)
|
||||
voicepack = torch.load(
|
||||
voice_path, map_location=TTSModel._device, weights_only=True
|
||||
)
|
||||
|
||||
# Generate audio with or without stitching
|
||||
if stitch_long_output:
|
||||
|
@ -152,11 +171,11 @@ class TTSService:
|
|||
for i, chunk in enumerate(chunks):
|
||||
try:
|
||||
# Validate phonemization first
|
||||
ps = phonemize(chunk, voice[0])
|
||||
tokens = tokenize(ps)
|
||||
logger.debug(
|
||||
f"Processing chunk {i + 1}/{len(chunks)}: {len(tokens)} tokens"
|
||||
)
|
||||
# ps = phonemize(chunk, voice[0])
|
||||
# tokens = tokenize(ps)
|
||||
# logger.debug(
|
||||
# f"Processing chunk {i + 1}/{len(chunks)}: {len(tokens)} tokens"
|
||||
# )
|
||||
|
||||
# Only proceed if phonemization succeeded
|
||||
chunk_audio, _ = generate(
|
||||
|
@ -226,7 +245,9 @@ class TTSService:
|
|||
for voice in voices:
|
||||
try:
|
||||
voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice}.pt")
|
||||
voicepack = torch.load(voice_path, map_location=TTSModel._device, weights_only=True)
|
||||
voicepack = torch.load(
|
||||
voice_path, map_location=TTSModel._device, weights_only=True
|
||||
)
|
||||
t_voices.append(voicepack)
|
||||
v_name.append(voice)
|
||||
except Exception as e:
|
||||
|
@ -242,7 +263,9 @@ class TTSService:
|
|||
try:
|
||||
torch.save(v, combined_path)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to save combined voice to {combined_path}: {str(e)}")
|
||||
raise RuntimeError(
|
||||
f"Failed to save combined voice to {combined_path}: {str(e)}"
|
||||
)
|
||||
|
||||
return f
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ class OpenAISpeechRequest(BaseModel):
|
|||
input: str = Field(..., description="The text to generate audio for")
|
||||
voice: str = Field(
|
||||
default="af",
|
||||
description="The voice to use for generation. Can be a base voice or a combined voice name."
|
||||
description="The voice to use for generation. Can be a base voice or a combined voice name.",
|
||||
)
|
||||
response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field(
|
||||
default="mp3",
|
||||
|
|
|
@ -1,16 +1,18 @@
|
|||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import shutil
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def cleanup_mock_dirs():
|
||||
"""Clean up any MagicMock directories created during tests"""
|
||||
mock_dir = "MagicMock"
|
||||
if os.path.exists(mock_dir):
|
||||
shutil.rmtree(mock_dir)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup():
|
||||
"""Automatically clean up before and after each test"""
|
||||
|
@ -18,6 +20,7 @@ def cleanup():
|
|||
yield
|
||||
cleanup_mock_dirs()
|
||||
|
||||
|
||||
# Mock torch and other ML modules before they're imported
|
||||
sys.modules["torch"] = Mock()
|
||||
sys.modules["transformers"] = Mock()
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
"""Tests for AudioService"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from api.src.services.audio import AudioService
|
||||
|
||||
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
"""Tests for FastAPI application"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from api.src.main import app, lifespan
|
||||
|
||||
|
||||
|
@ -19,15 +22,15 @@ def test_health_check(test_client):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('api.src.main.TTSModel')
|
||||
@patch('api.src.main.logger')
|
||||
@patch("api.src.main.TTSModel")
|
||||
@patch("api.src.main.logger")
|
||||
async def test_lifespan_successful_warmup(mock_logger, mock_tts_model):
|
||||
"""Test successful model warmup in lifespan"""
|
||||
# Mock the model initialization with model info and voicepack count
|
||||
mock_model = MagicMock()
|
||||
# Mock file system for voice counting
|
||||
mock_tts_model.VOICES_DIR = "/mock/voices"
|
||||
with patch('os.listdir', return_value=['voice1.pt', 'voice2.pt', 'voice3.pt']):
|
||||
with patch("os.listdir", return_value=["voice1.pt", "voice2.pt", "voice3.pt"]):
|
||||
mock_tts_model.initialize.return_value = (mock_model, 3) # 3 voice files
|
||||
mock_tts_model._device = "cuda" # Set device class variable
|
||||
|
||||
|
@ -49,8 +52,8 @@ async def test_lifespan_successful_warmup(mock_logger, mock_tts_model):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('api.src.main.TTSModel')
|
||||
@patch('api.src.main.logger')
|
||||
@patch("api.src.main.TTSModel")
|
||||
@patch("api.src.main.logger")
|
||||
async def test_lifespan_failed_warmup(mock_logger, mock_tts_model):
|
||||
"""Test failed model warmup in lifespan"""
|
||||
# Mock the model initialization to fail
|
||||
|
@ -71,14 +74,14 @@ async def test_lifespan_failed_warmup(mock_logger, mock_tts_model):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('api.src.main.TTSModel')
|
||||
@patch("api.src.main.TTSModel")
|
||||
async def test_lifespan_cuda_warmup(mock_tts_model):
|
||||
"""Test model warmup specifically on CUDA"""
|
||||
# Mock the model initialization with CUDA and voicepacks
|
||||
mock_model = MagicMock()
|
||||
# Mock file system for voice counting
|
||||
mock_tts_model.VOICES_DIR = "/mock/voices"
|
||||
with patch('os.listdir', return_value=['voice1.pt', 'voice2.pt']):
|
||||
with patch("os.listdir", return_value=["voice1.pt", "voice2.pt"]):
|
||||
mock_tts_model.initialize.return_value = (mock_model, 2) # 2 voice files
|
||||
mock_tts_model._device = "cuda" # Set device class variable
|
||||
|
||||
|
@ -94,14 +97,16 @@ async def test_lifespan_cuda_warmup(mock_tts_model):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('api.src.main.TTSModel')
|
||||
@patch("api.src.main.TTSModel")
|
||||
async def test_lifespan_cpu_fallback(mock_tts_model):
|
||||
"""Test model warmup falling back to CPU"""
|
||||
# Mock the model initialization with CPU and voicepacks
|
||||
mock_model = MagicMock()
|
||||
# Mock file system for voice counting
|
||||
mock_tts_model.VOICES_DIR = "/mock/voices"
|
||||
with patch('os.listdir', return_value=['voice1.pt', 'voice2.pt', 'voice3.pt', 'voice4.pt']):
|
||||
with patch(
|
||||
"os.listdir", return_value=["voice1.pt", "voice2.pt", "voice3.pt", "voice4.pt"]
|
||||
):
|
||||
mock_tts_model.initialize.return_value = (mock_model, 4) # 4 voice files
|
||||
mock_tts_model._device = "cpu" # Set device class variable
|
||||
|
||||
|
|
|
@ -1,9 +1,12 @@
|
|||
"""Tests for TTSService"""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock, call
|
||||
from api.src.services.tts import TTSService, TTSModel
|
||||
|
||||
from api.src.services.tts import TTSModel, TTSService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -50,27 +53,35 @@ def test_audio_to_bytes(tts_service, sample_audio):
|
|||
assert len(audio_bytes) > 0
|
||||
|
||||
|
||||
@patch('os.listdir')
|
||||
@patch('os.path.join')
|
||||
@patch("os.listdir")
|
||||
@patch("os.path.join")
|
||||
def test_list_voices(mock_join, mock_listdir, tts_service):
|
||||
"""Test listing available voices"""
|
||||
mock_listdir.return_value = ['voice1.pt', 'voice2.pt', 'not_a_voice.txt']
|
||||
mock_join.return_value = '/fake/path'
|
||||
mock_listdir.return_value = ["voice1.pt", "voice2.pt", "not_a_voice.txt"]
|
||||
mock_join.return_value = "/fake/path"
|
||||
|
||||
voices = tts_service.list_voices()
|
||||
assert len(voices) == 2
|
||||
assert 'voice1' in voices
|
||||
assert 'voice2' in voices
|
||||
assert 'not_a_voice' not in voices
|
||||
assert "voice1" in voices
|
||||
assert "voice2" in voices
|
||||
assert "not_a_voice" not in voices
|
||||
|
||||
|
||||
@patch('api.src.services.tts.TTSModel.get_instance')
|
||||
@patch('api.src.services.tts.TTSModel.get_voicepack')
|
||||
@patch('api.src.services.tts.normalize_text')
|
||||
@patch('api.src.services.tts.phonemize')
|
||||
@patch('api.src.services.tts.tokenize')
|
||||
@patch('api.src.services.tts.generate')
|
||||
def test_generate_audio_empty_text(mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_voicepack, mock_instance, tts_service):
|
||||
@patch("api.src.services.tts.TTSModel.get_instance")
|
||||
@patch("api.src.services.tts.TTSModel.get_voicepack")
|
||||
@patch("api.src.services.tts.normalize_text")
|
||||
@patch("api.src.services.tts.phonemize")
|
||||
@patch("api.src.services.tts.tokenize")
|
||||
@patch("api.src.services.tts.generate")
|
||||
def test_generate_audio_empty_text(
|
||||
mock_generate,
|
||||
mock_tokenize,
|
||||
mock_phonemize,
|
||||
mock_normalize,
|
||||
mock_voicepack,
|
||||
mock_instance,
|
||||
tts_service,
|
||||
):
|
||||
"""Test generating audio with empty text"""
|
||||
mock_normalize.return_value = ""
|
||||
|
||||
|
@ -78,14 +89,23 @@ def test_generate_audio_empty_text(mock_generate, mock_tokenize, mock_phonemize,
|
|||
tts_service._generate_audio("", "af", 1.0)
|
||||
|
||||
|
||||
@patch('api.src.services.tts.TTSModel.get_instance')
|
||||
@patch('os.path.exists')
|
||||
@patch('api.src.services.tts.normalize_text')
|
||||
@patch('api.src.services.tts.phonemize')
|
||||
@patch('api.src.services.tts.tokenize')
|
||||
@patch('api.src.services.tts.generate')
|
||||
@patch('torch.load')
|
||||
def test_generate_audio_no_chunks(mock_torch_load, mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_exists, mock_instance, tts_service):
|
||||
@patch("api.src.services.tts.TTSModel.get_instance")
|
||||
@patch("os.path.exists")
|
||||
@patch("api.src.services.tts.normalize_text")
|
||||
@patch("api.src.services.tts.phonemize")
|
||||
@patch("api.src.services.tts.tokenize")
|
||||
@patch("api.src.services.tts.generate")
|
||||
@patch("torch.load")
|
||||
def test_generate_audio_no_chunks(
|
||||
mock_torch_load,
|
||||
mock_generate,
|
||||
mock_tokenize,
|
||||
mock_phonemize,
|
||||
mock_normalize,
|
||||
mock_exists,
|
||||
mock_instance,
|
||||
tts_service,
|
||||
):
|
||||
"""Test generating audio with no successful chunks"""
|
||||
mock_normalize.return_value = "Test text"
|
||||
mock_phonemize.return_value = "Test text"
|
||||
|
@ -99,14 +119,24 @@ def test_generate_audio_no_chunks(mock_torch_load, mock_generate, mock_tokenize,
|
|||
tts_service._generate_audio("Test text", "af", 1.0)
|
||||
|
||||
|
||||
@patch('api.src.services.tts.TTSModel.get_instance')
|
||||
@patch('os.path.exists')
|
||||
@patch('api.src.services.tts.normalize_text')
|
||||
@patch('api.src.services.tts.phonemize')
|
||||
@patch('api.src.services.tts.tokenize')
|
||||
@patch('api.src.services.tts.generate')
|
||||
@patch('torch.load')
|
||||
def test_generate_audio_success(mock_torch_load, mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_exists, mock_instance, tts_service, sample_audio):
|
||||
@patch("api.src.services.tts.TTSModel.get_instance")
|
||||
@patch("os.path.exists")
|
||||
@patch("api.src.services.tts.normalize_text")
|
||||
@patch("api.src.services.tts.phonemize")
|
||||
@patch("api.src.services.tts.tokenize")
|
||||
@patch("api.src.services.tts.generate")
|
||||
@patch("torch.load")
|
||||
def test_generate_audio_success(
|
||||
mock_torch_load,
|
||||
mock_generate,
|
||||
mock_tokenize,
|
||||
mock_phonemize,
|
||||
mock_normalize,
|
||||
mock_exists,
|
||||
mock_instance,
|
||||
tts_service,
|
||||
sample_audio,
|
||||
):
|
||||
"""Test successful audio generation"""
|
||||
mock_normalize.return_value = "Test text"
|
||||
mock_phonemize.return_value = "Test text"
|
||||
|
@ -122,8 +152,8 @@ def test_generate_audio_success(mock_torch_load, mock_generate, mock_tokenize, m
|
|||
assert len(audio) > 0
|
||||
|
||||
|
||||
@patch('api.src.services.tts.torch.cuda.is_available')
|
||||
@patch('api.src.services.tts.build_model')
|
||||
@patch("api.src.services.tts.torch.cuda.is_available")
|
||||
@patch("api.src.services.tts.build_model")
|
||||
def test_model_initialization_cuda(mock_build_model, mock_cuda_available):
|
||||
"""Test model initialization with CUDA"""
|
||||
mock_cuda_available.return_value = True
|
||||
|
@ -138,8 +168,8 @@ def test_model_initialization_cuda(mock_build_model, mock_cuda_available):
|
|||
mock_build_model.assert_called_once()
|
||||
|
||||
|
||||
@patch('api.src.services.tts.torch.cuda.is_available')
|
||||
@patch('api.src.services.tts.build_model')
|
||||
@patch("api.src.services.tts.torch.cuda.is_available")
|
||||
@patch("api.src.services.tts.build_model")
|
||||
def test_model_initialization_cpu(mock_build_model, mock_cuda_available):
|
||||
"""Test model initialization with CPU"""
|
||||
mock_cuda_available.return_value = False
|
||||
|
@ -154,8 +184,8 @@ def test_model_initialization_cpu(mock_build_model, mock_cuda_available):
|
|||
mock_build_model.assert_called_once()
|
||||
|
||||
|
||||
@patch('api.src.services.tts.TTSService._get_voice_path')
|
||||
@patch('api.src.services.tts.TTSModel.get_instance')
|
||||
@patch("api.src.services.tts.TTSService._get_voice_path")
|
||||
@patch("api.src.services.tts.TTSModel.get_instance")
|
||||
def test_voicepack_loading_error(mock_get_instance, mock_get_voice_path):
|
||||
"""Test voicepack loading error handling"""
|
||||
mock_get_voice_path.return_value = None
|
||||
|
@ -168,7 +198,7 @@ def test_voicepack_loading_error(mock_get_instance, mock_get_voice_path):
|
|||
service._generate_audio("test", "nonexistent_voice", 1.0)
|
||||
|
||||
|
||||
@patch('api.src.services.tts.TTSModel')
|
||||
@patch("api.src.services.tts.TTSModel")
|
||||
def test_save_audio(mock_tts_model, tts_service, sample_audio, tmp_path):
|
||||
"""Test saving audio to file"""
|
||||
output_dir = os.path.join(tmp_path, "test_output")
|
||||
|
@ -181,12 +211,20 @@ def test_save_audio(mock_tts_model, tts_service, sample_audio, tmp_path):
|
|||
assert os.path.getsize(output_path) > 0
|
||||
|
||||
|
||||
@patch('api.src.services.tts.TTSModel.get_instance')
|
||||
@patch('os.path.exists')
|
||||
@patch('api.src.services.tts.normalize_text')
|
||||
@patch('api.src.services.tts.generate')
|
||||
@patch('torch.load')
|
||||
def test_generate_audio_without_stitching(mock_torch_load, mock_generate, mock_normalize, mock_exists, mock_instance, tts_service, sample_audio):
|
||||
@patch("api.src.services.tts.TTSModel.get_instance")
|
||||
@patch("os.path.exists")
|
||||
@patch("api.src.services.tts.normalize_text")
|
||||
@patch("api.src.services.tts.generate")
|
||||
@patch("torch.load")
|
||||
def test_generate_audio_without_stitching(
|
||||
mock_torch_load,
|
||||
mock_generate,
|
||||
mock_normalize,
|
||||
mock_exists,
|
||||
mock_instance,
|
||||
tts_service,
|
||||
sample_audio,
|
||||
):
|
||||
"""Test generating audio without text stitching"""
|
||||
mock_normalize.return_value = "Test text"
|
||||
mock_generate.return_value = (sample_audio, None)
|
||||
|
@ -194,14 +232,16 @@ def test_generate_audio_without_stitching(mock_torch_load, mock_generate, mock_n
|
|||
mock_exists.return_value = True
|
||||
mock_torch_load.return_value = MagicMock()
|
||||
|
||||
audio, processing_time = tts_service._generate_audio("Test text", "af", 1.0, stitch_long_output=False)
|
||||
audio, processing_time = tts_service._generate_audio(
|
||||
"Test text", "af", 1.0, stitch_long_output=False
|
||||
)
|
||||
assert isinstance(audio, np.ndarray)
|
||||
assert isinstance(processing_time, float)
|
||||
assert len(audio) > 0
|
||||
mock_generate.assert_called_once()
|
||||
|
||||
|
||||
@patch('os.listdir')
|
||||
@patch("os.listdir")
|
||||
def test_list_voices_error(mock_listdir, tts_service):
|
||||
"""Test error handling in list_voices"""
|
||||
mock_listdir.side_effect = Exception("Failed to list directory")
|
||||
|
@ -210,14 +250,23 @@ def test_list_voices_error(mock_listdir, tts_service):
|
|||
assert voices == []
|
||||
|
||||
|
||||
@patch('api.src.services.tts.TTSModel.get_instance')
|
||||
@patch('os.path.exists')
|
||||
@patch('api.src.services.tts.normalize_text')
|
||||
@patch('api.src.services.tts.phonemize')
|
||||
@patch('api.src.services.tts.tokenize')
|
||||
@patch('api.src.services.tts.generate')
|
||||
@patch('torch.load')
|
||||
def test_generate_audio_phonemize_error(mock_torch_load, mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_exists, mock_instance, tts_service):
|
||||
@patch("api.src.services.tts.TTSModel.get_instance")
|
||||
@patch("os.path.exists")
|
||||
@patch("api.src.services.tts.normalize_text")
|
||||
@patch("api.src.services.tts.phonemize")
|
||||
@patch("api.src.services.tts.tokenize")
|
||||
@patch("api.src.services.tts.generate")
|
||||
@patch("torch.load")
|
||||
def test_generate_audio_phonemize_error(
|
||||
mock_torch_load,
|
||||
mock_generate,
|
||||
mock_tokenize,
|
||||
mock_phonemize,
|
||||
mock_normalize,
|
||||
mock_exists,
|
||||
mock_instance,
|
||||
tts_service,
|
||||
):
|
||||
"""Test handling phonemization error"""
|
||||
mock_normalize.return_value = "Test text"
|
||||
mock_phonemize.side_effect = Exception("Phonemization failed")
|
||||
|
@ -230,12 +279,19 @@ def test_generate_audio_phonemize_error(mock_torch_load, mock_generate, mock_tok
|
|||
tts_service._generate_audio("Test text", "af", 1.0)
|
||||
|
||||
|
||||
@patch('api.src.services.tts.TTSModel.get_instance')
|
||||
@patch('os.path.exists')
|
||||
@patch('api.src.services.tts.normalize_text')
|
||||
@patch('api.src.services.tts.generate')
|
||||
@patch('torch.load')
|
||||
def test_generate_audio_error(mock_torch_load, mock_generate, mock_normalize, mock_exists, mock_instance, tts_service):
|
||||
@patch("api.src.services.tts.TTSModel.get_instance")
|
||||
@patch("os.path.exists")
|
||||
@patch("api.src.services.tts.normalize_text")
|
||||
@patch("api.src.services.tts.generate")
|
||||
@patch("torch.load")
|
||||
def test_generate_audio_error(
|
||||
mock_torch_load,
|
||||
mock_generate,
|
||||
mock_normalize,
|
||||
mock_exists,
|
||||
mock_instance,
|
||||
tts_service,
|
||||
):
|
||||
"""Test handling generation error"""
|
||||
mock_normalize.return_value = "Test text"
|
||||
mock_generate.side_effect = Exception("Generation failed")
|
||||
|
|
|
@ -19,7 +19,6 @@ output_dir = Path(__file__).parent / "output"
|
|||
output_dir.mkdir(exist_ok=True)
|
||||
|
||||
|
||||
|
||||
def test_voice(voice: str):
|
||||
speech_file = output_dir / f"speech_{voice}.mp3"
|
||||
print(f"\nTesting voice: {voice}")
|
||||
|
|
|
@ -1,15 +1,17 @@
|
|||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import os
|
||||
from typing import List, Optional, Dict, Tuple
|
||||
import argparse
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
|
||||
import requests
|
||||
import numpy as np
|
||||
from scipy.io import wavfile
|
||||
import requests
|
||||
import matplotlib.pyplot as plt
|
||||
from scipy.io import wavfile
|
||||
|
||||
|
||||
def submit_combine_voices(voices: List[str], base_url: str = "http://localhost:8880") -> Optional[str]:
|
||||
def submit_combine_voices(
|
||||
voices: List[str], base_url: str = "http://localhost:8880"
|
||||
) -> Optional[str]:
|
||||
"""Combine multiple voices into a new voice.
|
||||
|
||||
Args:
|
||||
|
@ -46,7 +48,12 @@ def submit_combine_voices(voices: List[str], base_url: str = "http://localhost:8
|
|||
return None
|
||||
|
||||
|
||||
def generate_speech(text: str, voice: str, base_url: str = "http://localhost:8880", output_file: str = "output.mp3") -> bool:
|
||||
def generate_speech(
|
||||
text: str,
|
||||
voice: str,
|
||||
base_url: str = "http://localhost:8880",
|
||||
output_file: str = "output.mp3",
|
||||
) -> bool:
|
||||
"""Generate speech using specified voice.
|
||||
|
||||
Args:
|
||||
|
@ -65,8 +72,8 @@ def generate_speech(text: str, voice: str, base_url: str = "http://localhost:888
|
|||
"input": text,
|
||||
"voice": voice,
|
||||
"speed": 1.0,
|
||||
"response_format": "wav" # Use WAV for analysis
|
||||
}
|
||||
"response_format": "wav", # Use WAV for analysis
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
|
@ -75,7 +82,10 @@ def generate_speech(text: str, voice: str, base_url: str = "http://localhost:888
|
|||
return False
|
||||
|
||||
# Save the audio
|
||||
os.makedirs(os.path.dirname(output_file) if os.path.dirname(output_file) else ".", exist_ok=True)
|
||||
os.makedirs(
|
||||
os.path.dirname(output_file) if os.path.dirname(output_file) else ".",
|
||||
exist_ok=True,
|
||||
)
|
||||
with open(output_file, "wb") as f:
|
||||
f.write(response.content)
|
||||
print(f"Saved audio to {output_file}")
|
||||
|
@ -113,7 +123,7 @@ def analyze_audio(filepath: str) -> Tuple[np.ndarray, int, dict]:
|
|||
if len(samples) > 0:
|
||||
# Use FFT to get frequency components
|
||||
fft_result = np.fft.fft(samples)
|
||||
freqs = np.fft.fftfreq(len(samples), 1/sample_rate)
|
||||
freqs = np.fft.fftfreq(len(samples), 1 / sample_rate)
|
||||
|
||||
# Get positive frequencies only
|
||||
pos_mask = freqs > 0
|
||||
|
@ -136,7 +146,7 @@ def analyze_audio(filepath: str) -> Tuple[np.ndarray, int, dict]:
|
|||
"duration": duration,
|
||||
"zero_crossing_rate": zero_crossings,
|
||||
"dominant_frequencies": dominant_freqs,
|
||||
"spectral_centroid": spectral_centroid
|
||||
"spectral_centroid": spectral_centroid,
|
||||
}
|
||||
|
||||
return samples, sample_rate, characteristics
|
||||
|
@ -167,6 +177,7 @@ def setup_plot(fig, ax, title):
|
|||
|
||||
return fig, ax
|
||||
|
||||
|
||||
def plot_analysis(audio_files: Dict[str, str], output_dir: str):
|
||||
"""Plot comprehensive voice analysis including waveforms and metrics comparison.
|
||||
|
||||
|
@ -175,7 +186,7 @@ def plot_analysis(audio_files: Dict[str, str], output_dir: str):
|
|||
output_dir: Directory to save plot files
|
||||
"""
|
||||
# Set dark style
|
||||
plt.style.use('dark_background')
|
||||
plt.style.use("dark_background")
|
||||
|
||||
# Create figure with subplots
|
||||
fig = plt.figure(figsize=(15, 15))
|
||||
|
@ -183,8 +194,9 @@ def plot_analysis(audio_files: Dict[str, str], output_dir: str):
|
|||
num_files = len(audio_files)
|
||||
|
||||
# Create subplot grid with proper spacing
|
||||
gs = plt.GridSpec(num_files + 1, 2, height_ratios=[1.5]*num_files + [1],
|
||||
hspace=0.4, wspace=0.3)
|
||||
gs = plt.GridSpec(
|
||||
num_files + 1, 2, height_ratios=[1.5] * num_files + [1], hspace=0.4, wspace=0.3
|
||||
)
|
||||
|
||||
# Analyze all files first
|
||||
all_chars = {}
|
||||
|
@ -195,7 +207,7 @@ def plot_analysis(audio_files: Dict[str, str], output_dir: str):
|
|||
# Plot waveform spanning both columns
|
||||
ax = plt.subplot(gs[i, :])
|
||||
time = np.arange(len(samples)) / sample_rate
|
||||
plt.plot(time, samples / chars['max_amplitude'], linewidth=0.5, color="#ff2a6d")
|
||||
plt.plot(time, samples / chars["max_amplitude"], linewidth=0.5, color="#ff2a6d")
|
||||
ax.set_xlabel("Time (seconds)")
|
||||
ax.set_ylabel("Normalized Amplitude")
|
||||
ax.set_ylim(-1.1, 1.1)
|
||||
|
@ -208,15 +220,27 @@ def plot_analysis(audio_files: Dict[str, str], output_dir: str):
|
|||
# Left subplot: Brightness and Volume
|
||||
ax1 = plt.subplot(gs[num_files, 0])
|
||||
metrics1 = [
|
||||
('Brightness', [chars['spectral_centroid']/1000 for chars in all_chars.values()], 'kHz'),
|
||||
('Volume', [chars['rms']*100 for chars in all_chars.values()], 'RMS×100')
|
||||
(
|
||||
"Brightness",
|
||||
[chars["spectral_centroid"] / 1000 for chars in all_chars.values()],
|
||||
"kHz",
|
||||
),
|
||||
("Volume", [chars["rms"] * 100 for chars in all_chars.values()], "RMS×100"),
|
||||
]
|
||||
|
||||
# Right subplot: Voice Pitch and Texture
|
||||
ax2 = plt.subplot(gs[num_files, 1])
|
||||
metrics2 = [
|
||||
('Voice Pitch', [min(chars['dominant_frequencies']) for chars in all_chars.values()], 'Hz'),
|
||||
('Texture', [chars['zero_crossing_rate']*1000 for chars in all_chars.values()], 'ZCR×1000')
|
||||
(
|
||||
"Voice Pitch",
|
||||
[min(chars["dominant_frequencies"]) for chars in all_chars.values()],
|
||||
"Hz",
|
||||
),
|
||||
(
|
||||
"Texture",
|
||||
[chars["zero_crossing_rate"] * 1000 for chars in all_chars.values()],
|
||||
"ZCR×1000",
|
||||
),
|
||||
]
|
||||
|
||||
def plot_grouped_bars(ax, metrics, show_legend=True):
|
||||
|
@ -231,17 +255,23 @@ def plot_analysis(audio_files: Dict[str, str], output_dir: str):
|
|||
|
||||
for i, (voice, color) in enumerate(zip(audio_files.keys(), colors)):
|
||||
values = [m[1][i] for m in metrics]
|
||||
offset = (i - n_voices/2 + 0.5) * bar_width
|
||||
bars = ax.bar(indices + offset, values, bar_width,
|
||||
label=voice, color=color, alpha=0.8)
|
||||
offset = (i - n_voices / 2 + 0.5) * bar_width
|
||||
bars = ax.bar(
|
||||
indices + offset, values, bar_width, label=voice, color=color, alpha=0.8
|
||||
)
|
||||
|
||||
# Add value labels on top of bars
|
||||
for bar in bars:
|
||||
height = bar.get_height()
|
||||
ax.text(bar.get_x() + bar.get_width()/2., height,
|
||||
f'{height:.1f}',
|
||||
ha='center', va='bottom', color='white',
|
||||
fontsize=10)
|
||||
ax.text(
|
||||
bar.get_x() + bar.get_width() / 2.0,
|
||||
height,
|
||||
f"{height:.1f}",
|
||||
ha="center",
|
||||
va="bottom",
|
||||
color="white",
|
||||
fontsize=10,
|
||||
)
|
||||
|
||||
ax.set_xticks(indices)
|
||||
ax.set_xticklabels([f"{m[0]}\n({m[2]})" for m in metrics])
|
||||
|
@ -250,20 +280,24 @@ def plot_analysis(audio_files: Dict[str, str], output_dir: str):
|
|||
ax.set_ylim(0, max_val * 1.2)
|
||||
|
||||
if show_legend:
|
||||
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left',
|
||||
facecolor="#1a1a2e", edgecolor="#ffffff")
|
||||
ax.legend(
|
||||
bbox_to_anchor=(1.05, 1),
|
||||
loc="upper left",
|
||||
facecolor="#1a1a2e",
|
||||
edgecolor="#ffffff",
|
||||
)
|
||||
|
||||
# Plot both subplots
|
||||
plot_grouped_bars(ax1, metrics1, show_legend=True)
|
||||
plot_grouped_bars(ax2, metrics2, show_legend=False)
|
||||
|
||||
# Style both subplots
|
||||
setup_plot(fig, ax1, 'Brightness and Volume')
|
||||
setup_plot(fig, ax2, 'Voice Pitch and Texture')
|
||||
setup_plot(fig, ax1, "Brightness and Volume")
|
||||
setup_plot(fig, ax2, "Voice Pitch and Texture")
|
||||
|
||||
# Add y-axis labels
|
||||
ax1.set_ylabel('Value')
|
||||
ax2.set_ylabel('Value')
|
||||
ax1.set_ylabel("Value")
|
||||
ax2.set_ylabel("Value")
|
||||
|
||||
# Adjust the figure size to accommodate the legend
|
||||
fig.set_size_inches(15, 15)
|
||||
|
@ -282,15 +316,26 @@ def plot_analysis(audio_files: Dict[str, str], output_dir: str):
|
|||
print(f" Duration: {chars['duration']:.2f}s")
|
||||
print(f" Zero Crossing Rate: {chars['zero_crossing_rate']:.3f}")
|
||||
print(f" Spectral Centroid: {chars['spectral_centroid']:.0f}Hz")
|
||||
print(f" Dominant Frequencies: {', '.join(f'{f:.0f}Hz' for f in chars['dominant_frequencies'])}")
|
||||
print(
|
||||
f" Dominant Frequencies: {', '.join(f'{f:.0f}Hz' for f in chars['dominant_frequencies'])}"
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Kokoro Voice Analysis Demo")
|
||||
parser.add_argument("--voices", nargs="+", type=str, help="Voices to combine")
|
||||
parser.add_argument("--text", type=str, default="Hello! This is a test of combined voices.", help="Text to speak")
|
||||
parser.add_argument(
|
||||
"--text",
|
||||
type=str,
|
||||
default="Hello! This is a test of combined voices.",
|
||||
help="Text to speak",
|
||||
)
|
||||
parser.add_argument("--url", default="http://localhost:8880", help="API base URL")
|
||||
parser.add_argument("--output-dir", default="examples/output", help="Output directory for audio files")
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
default="examples/output",
|
||||
help="Output directory for audio files",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.voices:
|
||||
|
@ -316,7 +361,9 @@ def main():
|
|||
|
||||
if combined_voice:
|
||||
print(f"Successfully created combined voice: {combined_voice}")
|
||||
output_file = os.path.join(args.output_dir, f"analysis_combined_{combined_voice}.wav")
|
||||
output_file = os.path.join(
|
||||
args.output_dir, f"analysis_combined_{combined_voice}.wav"
|
||||
)
|
||||
if generate_speech(args.text, combined_voice, args.url, output_file):
|
||||
audio_files["combined"] = output_file
|
||||
|
||||
|
|
|
@ -60,7 +60,7 @@ def test_speed(speed: float):
|
|||
|
||||
# Test different formats
|
||||
for format in ["wav", "mp3", "opus", "aac", "flac", "pcm"]:
|
||||
test_format(format) # aac and pcm should fail as they are not supported
|
||||
test_format(format) # aac and pcm should fail as they are not supported
|
||||
|
||||
# Test different speeds
|
||||
for speed in [0.25, 1.0, 2.0, 4.0]: # 5.0 should fail as it's out of range
|
||||
|
|
|
@ -10,3 +10,5 @@ sqlalchemy==2.0.27
|
|||
pytest==8.0.0
|
||||
httpx==0.26.0
|
||||
pytest-asyncio==0.23.5
|
||||
pytest-cov==6.0.0
|
||||
gradio==4.19.2
|
||||
|
|
|
@ -2,8 +2,4 @@ from lib.interface import create_interface
|
|||
|
||||
if __name__ == "__main__":
|
||||
demo = create_interface()
|
||||
demo.launch(
|
||||
server_name="0.0.0.0",
|
||||
server_port=7860,
|
||||
show_error=True
|
||||
)
|
||||
demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
|
||||
|
|
|
@ -1,16 +1,19 @@
|
|||
import requests
|
||||
from typing import Tuple, List, Optional
|
||||
import os
|
||||
import datetime
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from .config import API_URL, OUTPUTS_DIR
|
||||
|
||||
|
||||
def check_api_status() -> Tuple[bool, List[str]]:
|
||||
"""Check TTS service status and get available voices."""
|
||||
try:
|
||||
# Use a longer timeout during startup
|
||||
response = requests.get(
|
||||
f"{API_URL}/v1/audio/voices",
|
||||
timeout=30 # Increased timeout for initial startup period
|
||||
timeout=30, # Increased timeout for initial startup period
|
||||
)
|
||||
response.raise_for_status()
|
||||
voices = response.json().get("voices", [])
|
||||
|
@ -31,7 +34,10 @@ def check_api_status() -> Tuple[bool, List[str]]:
|
|||
print(f"Unexpected error checking API status: {str(e)}")
|
||||
return False, []
|
||||
|
||||
def text_to_speech(text: str, voice_id: str, format: str, speed: float) -> Optional[str]:
|
||||
|
||||
def text_to_speech(
|
||||
text: str, voice_id: str, format: str, speed: float
|
||||
) -> Optional[str]:
|
||||
"""Generate speech from text using TTS API."""
|
||||
if not text.strip():
|
||||
return None
|
||||
|
@ -49,10 +55,10 @@ def text_to_speech(text: str, voice_id: str, format: str, speed: float) -> Optio
|
|||
"input": text,
|
||||
"voice": voice_id,
|
||||
"response_format": format,
|
||||
"speed": float(speed)
|
||||
"speed": float(speed),
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=300 # Longer timeout for speech generation
|
||||
timeout=300, # Longer timeout for speech generation
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
|
@ -70,6 +76,7 @@ def text_to_speech(text: str, voice_id: str, format: str, speed: float) -> Optio
|
|||
print(f"Unexpected error generating speech: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
def get_status_html(is_available: bool) -> str:
|
||||
"""Generate HTML for status indicator."""
|
||||
color = "green" if is_available else "red"
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
import gradio as gr
|
||||
from typing import Tuple
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from .. import files
|
||||
|
||||
|
||||
def create_input_column() -> Tuple[gr.Column, dict]:
|
||||
"""Create the input column with text input and file handling."""
|
||||
with gr.Column(scale=1) as col:
|
||||
|
@ -11,15 +14,9 @@ def create_input_column() -> Tuple[gr.Column, dict]:
|
|||
# Direct Input Tab
|
||||
with gr.TabItem("Direct Input"):
|
||||
text_input = gr.Textbox(
|
||||
label="Text to speak",
|
||||
placeholder="Enter text here...",
|
||||
lines=4
|
||||
)
|
||||
text_submit = gr.Button(
|
||||
"Generate Speech",
|
||||
variant="primary",
|
||||
size="lg"
|
||||
label="Text to speak", placeholder="Enter text here...", lines=4
|
||||
)
|
||||
text_submit = gr.Button("Generate Speech", variant="primary", size="lg")
|
||||
|
||||
# File Input Tab
|
||||
with gr.TabItem("From File"):
|
||||
|
@ -27,31 +24,24 @@ def create_input_column() -> Tuple[gr.Column, dict]:
|
|||
input_files_list = gr.Dropdown(
|
||||
label="Select Existing File",
|
||||
choices=files.list_input_files(),
|
||||
value=None
|
||||
value=None,
|
||||
)
|
||||
|
||||
# Simple file upload
|
||||
file_upload = gr.File(
|
||||
label="Upload Text File (.txt)",
|
||||
file_types=[".txt"]
|
||||
label="Upload Text File (.txt)", file_types=[".txt"]
|
||||
)
|
||||
|
||||
file_preview = gr.Textbox(
|
||||
label="File Content Preview",
|
||||
interactive=False,
|
||||
lines=4
|
||||
label="File Content Preview", interactive=False, lines=4
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
file_submit = gr.Button(
|
||||
"Generate Speech",
|
||||
variant="primary",
|
||||
size="lg"
|
||||
"Generate Speech", variant="primary", size="lg"
|
||||
)
|
||||
clear_files = gr.Button(
|
||||
"Clear Files",
|
||||
variant="secondary",
|
||||
size="lg"
|
||||
"Clear Files", variant="secondary", size="lg"
|
||||
)
|
||||
|
||||
components = {
|
||||
|
@ -62,7 +52,7 @@ def create_input_column() -> Tuple[gr.Column, dict]:
|
|||
"file_preview": file_preview,
|
||||
"text_submit": text_submit,
|
||||
"file_submit": file_submit,
|
||||
"clear_files": clear_files
|
||||
"clear_files": clear_files,
|
||||
}
|
||||
|
||||
return col, components
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
import gradio as gr
|
||||
from typing import Tuple, Optional
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from .. import api, config
|
||||
|
||||
|
||||
def create_model_column(voice_ids: Optional[list] = None) -> Tuple[gr.Column, dict]:
|
||||
"""Create the model settings column."""
|
||||
if voice_ids is None:
|
||||
|
@ -12,34 +15,27 @@ def create_model_column(voice_ids: Optional[list] = None) -> Tuple[gr.Column, di
|
|||
|
||||
# Status button starts in waiting state
|
||||
status_btn = gr.Button(
|
||||
"⌛ TTS Service: Waiting for Service...",
|
||||
variant="secondary"
|
||||
"⌛ TTS Service: Waiting for Service...", variant="secondary"
|
||||
)
|
||||
|
||||
voice_input = gr.Dropdown(
|
||||
choices=voice_ids,
|
||||
label="Voice",
|
||||
value=voice_ids[0] if voice_ids else None,
|
||||
interactive=True
|
||||
interactive=True,
|
||||
)
|
||||
format_input = gr.Dropdown(
|
||||
choices=config.AUDIO_FORMATS,
|
||||
label="Audio Format",
|
||||
value="mp3"
|
||||
choices=config.AUDIO_FORMATS, label="Audio Format", value="mp3"
|
||||
)
|
||||
speed_input = gr.Slider(
|
||||
minimum=0.5,
|
||||
maximum=2.0,
|
||||
value=1.0,
|
||||
step=0.1,
|
||||
label="Speed"
|
||||
minimum=0.5, maximum=2.0, value=1.0, step=0.1, label="Speed"
|
||||
)
|
||||
|
||||
components = {
|
||||
"status_btn": status_btn,
|
||||
"voice": voice_input,
|
||||
"format": format_input,
|
||||
"speed": speed_input
|
||||
"speed": speed_input,
|
||||
}
|
||||
|
||||
return col, components
|
||||
|
|
|
@ -1,40 +1,42 @@
|
|||
import gradio as gr
|
||||
from typing import Tuple
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from .. import files
|
||||
|
||||
|
||||
def create_output_column() -> Tuple[gr.Column, dict]:
|
||||
"""Create the output column with audio player and file list."""
|
||||
with gr.Column(scale=1) as col:
|
||||
gr.Markdown("### Latest Output")
|
||||
audio_output = gr.Audio(
|
||||
label="Generated Speech",
|
||||
type="filepath"
|
||||
)
|
||||
audio_output = gr.Audio(label="Generated Speech", type="filepath")
|
||||
|
||||
gr.Markdown("### Generated Files")
|
||||
output_files = gr.Dropdown(
|
||||
label="Previous Outputs",
|
||||
choices=files.list_output_files(),
|
||||
value=None,
|
||||
allow_custom_value=False
|
||||
allow_custom_value=False,
|
||||
)
|
||||
|
||||
play_btn = gr.Button("▶️ Play Selected", size="sm")
|
||||
|
||||
selected_audio = gr.Audio(
|
||||
label="Selected Output",
|
||||
type="filepath",
|
||||
visible=False
|
||||
label="Selected Output", type="filepath", visible=False
|
||||
)
|
||||
|
||||
clear_outputs = gr.Button("⚠️ Delete All Previously Generated Output Audio 🗑️", size="sm", variant="secondary")
|
||||
clear_outputs = gr.Button(
|
||||
"⚠️ Delete All Previously Generated Output Audio 🗑️",
|
||||
size="sm",
|
||||
variant="secondary",
|
||||
)
|
||||
|
||||
components = {
|
||||
"audio_output": audio_output,
|
||||
"output_files": output_files,
|
||||
"play_btn": play_btn,
|
||||
"selected_audio": selected_audio,
|
||||
"clear_outputs": clear_outputs
|
||||
"clear_outputs": clear_outputs,
|
||||
}
|
||||
|
||||
return col, components
|
||||
|
|
|
@ -1,17 +1,23 @@
|
|||
import os
|
||||
from typing import List, Optional, Tuple
|
||||
import datetime
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
from .config import INPUTS_DIR, OUTPUTS_DIR, AUDIO_FORMATS
|
||||
|
||||
|
||||
def list_input_files() -> List[str]:
|
||||
"""List all input text files."""
|
||||
return [f for f in os.listdir(INPUTS_DIR) if f.endswith('.txt')]
|
||||
return [f for f in os.listdir(INPUTS_DIR) if f.endswith(".txt")]
|
||||
|
||||
|
||||
def list_output_files() -> List[str]:
|
||||
"""List all output audio files."""
|
||||
return [os.path.join(OUTPUTS_DIR, f)
|
||||
for f in os.listdir(OUTPUTS_DIR)
|
||||
if any(f.endswith(ext) for ext in AUDIO_FORMATS)]
|
||||
return [
|
||||
os.path.join(OUTPUTS_DIR, f)
|
||||
for f in os.listdir(OUTPUTS_DIR)
|
||||
if any(f.endswith(ext) for ext in AUDIO_FORMATS)
|
||||
]
|
||||
|
||||
|
||||
def read_text_file(filename: str) -> str:
|
||||
"""Read content of a text file."""
|
||||
|
@ -19,11 +25,12 @@ def read_text_file(filename: str) -> str:
|
|||
return ""
|
||||
try:
|
||||
file_path = os.path.join(INPUTS_DIR, filename)
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
except:
|
||||
return ""
|
||||
|
||||
|
||||
def save_text(text: str, filename: Optional[str] = None) -> Optional[str]:
|
||||
"""Save text to a file. Returns the filename if successful."""
|
||||
if not text.strip():
|
||||
|
@ -41,7 +48,7 @@ def save_text(text: str, filename: Optional[str] = None) -> Optional[str]:
|
|||
else:
|
||||
# Handle duplicate filenames by adding _1, _2, etc.
|
||||
base = os.path.splitext(filename)[0]
|
||||
ext = os.path.splitext(filename)[1] or '.txt'
|
||||
ext = os.path.splitext(filename)[1] or ".txt"
|
||||
counter = 1
|
||||
while os.path.exists(os.path.join(INPUTS_DIR, filename)):
|
||||
filename = f"{base}_{counter}{ext}"
|
||||
|
@ -56,11 +63,12 @@ def save_text(text: str, filename: Optional[str] = None) -> Optional[str]:
|
|||
print(f"Error saving file: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def delete_all_input_files() -> bool:
|
||||
"""Delete all files from the inputs directory. Returns True if successful."""
|
||||
try:
|
||||
for filename in os.listdir(INPUTS_DIR):
|
||||
if filename.endswith('.txt'):
|
||||
if filename.endswith(".txt"):
|
||||
file_path = os.path.join(INPUTS_DIR, filename)
|
||||
os.remove(file_path)
|
||||
return True
|
||||
|
@ -68,6 +76,7 @@ def delete_all_input_files() -> bool:
|
|||
print(f"Error deleting input files: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def delete_all_output_files() -> bool:
|
||||
"""Delete all audio files from the outputs directory. Returns True if successful."""
|
||||
try:
|
||||
|
@ -80,6 +89,7 @@ def delete_all_output_files() -> bool:
|
|||
print(f"Error deleting output files: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def process_uploaded_file(file_path: str) -> bool:
|
||||
"""Save uploaded file to inputs directory. Returns True if successful."""
|
||||
if not file_path:
|
||||
|
@ -87,7 +97,7 @@ def process_uploaded_file(file_path: str) -> bool:
|
|||
|
||||
try:
|
||||
filename = os.path.basename(file_path)
|
||||
if not filename.endswith('.txt'):
|
||||
if not filename.endswith(".txt"):
|
||||
return False
|
||||
|
||||
# Create target path in inputs directory
|
||||
|
@ -103,6 +113,7 @@ def process_uploaded_file(file_path: str) -> bool:
|
|||
|
||||
# Copy file to inputs directory
|
||||
import shutil
|
||||
|
||||
shutil.copy2(file_path, target_path)
|
||||
return True
|
||||
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
import gradio as gr
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from . import api, files
|
||||
|
||||
|
||||
def setup_event_handlers(components: dict):
|
||||
"""Set up all event handlers for the UI components."""
|
||||
|
||||
|
@ -19,17 +22,17 @@ def setup_event_handlers(components: dict):
|
|||
gr.update(
|
||||
value=f"🔄 TTS Service: {status}",
|
||||
interactive=True,
|
||||
variant="secondary"
|
||||
variant="secondary",
|
||||
),
|
||||
gr.update(choices=voices, value=default_voice)
|
||||
gr.update(choices=voices, value=default_voice),
|
||||
]
|
||||
return [
|
||||
gr.update(
|
||||
value=f"⌛ TTS Service: {status}",
|
||||
interactive=True,
|
||||
variant="secondary"
|
||||
variant="secondary",
|
||||
),
|
||||
gr.update(choices=[], value=None)
|
||||
gr.update(choices=[], value=None),
|
||||
]
|
||||
except Exception as e:
|
||||
print(f"Error in refresh status: {str(e)}")
|
||||
|
@ -37,9 +40,9 @@ def setup_event_handlers(components: dict):
|
|||
gr.update(
|
||||
value="❌ TTS Service: Connection Error",
|
||||
interactive=True,
|
||||
variant="secondary"
|
||||
variant="secondary",
|
||||
),
|
||||
gr.update(choices=[], value=None)
|
||||
gr.update(choices=[], value=None),
|
||||
]
|
||||
|
||||
def handle_file_select(filename):
|
||||
|
@ -82,30 +85,23 @@ def setup_event_handlers(components: dict):
|
|||
is_available, _ = api.check_api_status()
|
||||
if not is_available:
|
||||
gr.Warning("TTS Service is currently unavailable")
|
||||
return [
|
||||
None,
|
||||
gr.update(choices=files.list_output_files())
|
||||
]
|
||||
return [None, gr.update(choices=files.list_output_files())]
|
||||
|
||||
if not text or not text.strip():
|
||||
gr.Warning("Please enter text in the input box")
|
||||
return [
|
||||
None,
|
||||
gr.update(choices=files.list_output_files())
|
||||
]
|
||||
return [None, gr.update(choices=files.list_output_files())]
|
||||
|
||||
files.save_text(text)
|
||||
result = api.text_to_speech(text, voice, format, speed)
|
||||
if result is None:
|
||||
gr.Warning("Failed to generate speech. Please try again.")
|
||||
return [
|
||||
None,
|
||||
gr.update(choices=files.list_output_files())
|
||||
]
|
||||
return [None, gr.update(choices=files.list_output_files())]
|
||||
|
||||
return [
|
||||
result,
|
||||
gr.update(choices=files.list_output_files(), value=os.path.basename(result))
|
||||
gr.update(
|
||||
choices=files.list_output_files(), value=os.path.basename(result)
|
||||
),
|
||||
]
|
||||
|
||||
def generate_from_file(selected_file, voice, format, speed):
|
||||
|
@ -113,30 +109,23 @@ def setup_event_handlers(components: dict):
|
|||
is_available, _ = api.check_api_status()
|
||||
if not is_available:
|
||||
gr.Warning("TTS Service is currently unavailable")
|
||||
return [
|
||||
None,
|
||||
gr.update(choices=files.list_output_files())
|
||||
]
|
||||
return [None, gr.update(choices=files.list_output_files())]
|
||||
|
||||
if not selected_file:
|
||||
gr.Warning("Please select a file")
|
||||
return [
|
||||
None,
|
||||
gr.update(choices=files.list_output_files())
|
||||
]
|
||||
return [None, gr.update(choices=files.list_output_files())]
|
||||
|
||||
text = files.read_text_file(selected_file)
|
||||
result = api.text_to_speech(text, voice, format, speed)
|
||||
if result is None:
|
||||
gr.Warning("Failed to generate speech. Please try again.")
|
||||
return [
|
||||
None,
|
||||
gr.update(choices=files.list_output_files())
|
||||
]
|
||||
return [None, gr.update(choices=files.list_output_files())]
|
||||
|
||||
return [
|
||||
result,
|
||||
gr.update(choices=files.list_output_files(), value=os.path.basename(result))
|
||||
gr.update(
|
||||
choices=files.list_output_files(), value=os.path.basename(result)
|
||||
),
|
||||
]
|
||||
|
||||
def play_selected(file_path):
|
||||
|
@ -155,7 +144,7 @@ def setup_event_handlers(components: dict):
|
|||
gr.update(choices=files.list_output_files()), # output_files
|
||||
gr.update(value=voice), # voice
|
||||
gr.update(value=format), # format
|
||||
gr.update(value=speed) # speed
|
||||
gr.update(value=speed), # speed
|
||||
]
|
||||
|
||||
def clear_outputs():
|
||||
|
@ -164,34 +153,31 @@ def setup_event_handlers(components: dict):
|
|||
return [
|
||||
None, # audio_output
|
||||
gr.update(choices=[], value=None), # output_files
|
||||
gr.update(visible=False) # selected_audio
|
||||
gr.update(visible=False), # selected_audio
|
||||
]
|
||||
|
||||
# Connect event handlers
|
||||
components["model"]["status_btn"].click(
|
||||
fn=refresh_status,
|
||||
outputs=[
|
||||
components["model"]["status_btn"],
|
||||
components["model"]["voice"]
|
||||
]
|
||||
outputs=[components["model"]["status_btn"], components["model"]["voice"]],
|
||||
)
|
||||
|
||||
components["input"]["file_select"].change(
|
||||
fn=handle_file_select,
|
||||
inputs=[components["input"]["file_select"]],
|
||||
outputs=[components["input"]["file_preview"]]
|
||||
outputs=[components["input"]["file_preview"]],
|
||||
)
|
||||
|
||||
components["input"]["file_upload"].upload(
|
||||
fn=handle_file_upload,
|
||||
inputs=[components["input"]["file_upload"]],
|
||||
outputs=[components["input"]["file_select"]]
|
||||
outputs=[components["input"]["file_select"]],
|
||||
)
|
||||
|
||||
components["output"]["play_btn"].click(
|
||||
fn=play_selected,
|
||||
inputs=[components["output"]["output_files"]],
|
||||
outputs=[components["output"]["selected_audio"]]
|
||||
outputs=[components["output"]["selected_audio"]],
|
||||
)
|
||||
|
||||
# Connect clear files button
|
||||
|
@ -200,7 +186,7 @@ def setup_event_handlers(components: dict):
|
|||
inputs=[
|
||||
components["model"]["voice"],
|
||||
components["model"]["format"],
|
||||
components["model"]["speed"]
|
||||
components["model"]["speed"],
|
||||
],
|
||||
outputs=[
|
||||
components["input"]["file_select"],
|
||||
|
@ -210,8 +196,8 @@ def setup_event_handlers(components: dict):
|
|||
components["output"]["output_files"],
|
||||
components["model"]["voice"],
|
||||
components["model"]["format"],
|
||||
components["model"]["speed"]
|
||||
]
|
||||
components["model"]["speed"],
|
||||
],
|
||||
)
|
||||
|
||||
# Connect submit buttons for each tab
|
||||
|
@ -221,12 +207,12 @@ def setup_event_handlers(components: dict):
|
|||
components["input"]["text_input"],
|
||||
components["model"]["voice"],
|
||||
components["model"]["format"],
|
||||
components["model"]["speed"]
|
||||
components["model"]["speed"],
|
||||
],
|
||||
outputs=[
|
||||
components["output"]["audio_output"],
|
||||
components["output"]["output_files"]
|
||||
]
|
||||
components["output"]["output_files"],
|
||||
],
|
||||
)
|
||||
|
||||
# Connect clear outputs button
|
||||
|
@ -235,8 +221,8 @@ def setup_event_handlers(components: dict):
|
|||
outputs=[
|
||||
components["output"]["audio_output"],
|
||||
components["output"]["output_files"],
|
||||
components["output"]["selected_audio"]
|
||||
]
|
||||
components["output"]["selected_audio"],
|
||||
],
|
||||
)
|
||||
|
||||
components["input"]["file_submit"].click(
|
||||
|
@ -245,10 +231,10 @@ def setup_event_handlers(components: dict):
|
|||
components["input"]["file_select"],
|
||||
components["model"]["voice"],
|
||||
components["model"]["format"],
|
||||
components["model"]["speed"]
|
||||
components["model"]["speed"],
|
||||
],
|
||||
outputs=[
|
||||
components["output"]["audio_output"],
|
||||
components["output"]["output_files"]
|
||||
]
|
||||
components["output"]["output_files"],
|
||||
],
|
||||
)
|
||||
|
|
|
@ -1,34 +1,38 @@
|
|||
import gradio as gr
|
||||
|
||||
from . import api
|
||||
from .components import create_input_column, create_model_column, create_output_column
|
||||
from .handlers import setup_event_handlers
|
||||
from .components import create_input_column, create_model_column, create_output_column
|
||||
|
||||
|
||||
def create_interface():
|
||||
"""Create the main Gradio interface."""
|
||||
# Skip initial status check - let the timer handle it
|
||||
is_available, available_voices = False, []
|
||||
|
||||
with gr.Blocks(
|
||||
title="Kokoro TTS Demo",
|
||||
theme=gr.themes.Monochrome()
|
||||
) as demo:
|
||||
gr.HTML(value='<div style="display: flex; gap: 0;">'
|
||||
'<a href="https://huggingface.co/hexgrad/Kokoro-82M" target="_blank" style="color: #2196F3; text-decoration: none; margin: 2px; border: 1px solid #2196F3; padding: 4px 8px; height: 24px; box-sizing: border-box; display: inline-flex; align-items: center;">Kokoro-82M HF Repo</a>'
|
||||
'<a href="https://github.com/remsky/Kokoro-FastAPI" target="_blank" style="color: #2196F3; text-decoration: none; margin: 2px; border: 1px solid #2196F3; padding: 4px 8px; height: 24px; box-sizing: border-box; display: inline-flex; align-items: center;">Kokoro-FastAPI Repo</a>'
|
||||
'</div>', show_label=False)
|
||||
with gr.Blocks(title="Kokoro TTS Demo", theme=gr.themes.Monochrome()) as demo:
|
||||
gr.HTML(
|
||||
value='<div style="display: flex; gap: 0;">'
|
||||
'<a href="https://huggingface.co/hexgrad/Kokoro-82M" target="_blank" style="color: #2196F3; text-decoration: none; margin: 2px; border: 1px solid #2196F3; padding: 4px 8px; height: 24px; box-sizing: border-box; display: inline-flex; align-items: center;">Kokoro-82M HF Repo</a>'
|
||||
'<a href="https://github.com/remsky/Kokoro-FastAPI" target="_blank" style="color: #2196F3; text-decoration: none; margin: 2px; border: 1px solid #2196F3; padding: 4px 8px; height: 24px; box-sizing: border-box; display: inline-flex; align-items: center;">Kokoro-FastAPI Repo</a>'
|
||||
"</div>",
|
||||
show_label=False,
|
||||
)
|
||||
|
||||
# Main interface
|
||||
with gr.Row():
|
||||
# Create columns
|
||||
input_col, input_components = create_input_column()
|
||||
model_col, model_components = create_model_column(available_voices) # Pass initial voices
|
||||
model_col, model_components = create_model_column(
|
||||
available_voices
|
||||
) # Pass initial voices
|
||||
output_col, output_components = create_output_column()
|
||||
|
||||
# Collect all components
|
||||
components = {
|
||||
"input": input_components,
|
||||
"model": model_components,
|
||||
"output": output_components
|
||||
"output": output_components,
|
||||
}
|
||||
|
||||
# Set up event handlers
|
||||
|
@ -43,16 +47,18 @@ def create_interface():
|
|||
if is_available and voices:
|
||||
# Service is available, update UI and stop timer
|
||||
current_voice = components["model"]["voice"].value
|
||||
default_voice = current_voice if current_voice in voices else voices[0]
|
||||
default_voice = (
|
||||
current_voice if current_voice in voices else voices[0]
|
||||
)
|
||||
# Return values in same order as outputs list
|
||||
return [
|
||||
gr.update(
|
||||
value=f"🔄 TTS Service: {status}",
|
||||
interactive=True,
|
||||
variant="secondary"
|
||||
variant="secondary",
|
||||
),
|
||||
gr.update(choices=voices, value=default_voice),
|
||||
gr.update(active=False) # Stop timer
|
||||
gr.update(active=False), # Stop timer
|
||||
]
|
||||
|
||||
# Service not available yet, keep checking
|
||||
|
@ -60,10 +66,10 @@ def create_interface():
|
|||
gr.update(
|
||||
value=f"⌛ TTS Service: {status}",
|
||||
interactive=True,
|
||||
variant="secondary"
|
||||
variant="secondary",
|
||||
),
|
||||
gr.update(choices=[], value=None),
|
||||
gr.update(active=True)
|
||||
gr.update(active=True),
|
||||
]
|
||||
except Exception as e:
|
||||
print(f"Error in status update: {str(e)}")
|
||||
|
@ -72,10 +78,10 @@ def create_interface():
|
|||
gr.update(
|
||||
value="❌ TTS Service: Connection Error",
|
||||
interactive=True,
|
||||
variant="secondary"
|
||||
variant="secondary",
|
||||
),
|
||||
gr.update(choices=[], value=None),
|
||||
gr.update(active=True)
|
||||
gr.update(active=True),
|
||||
]
|
||||
|
||||
timer = gr.Timer(value=5) # Check every 5 seconds
|
||||
|
@ -84,8 +90,8 @@ def create_interface():
|
|||
outputs=[
|
||||
components["model"]["status_btn"],
|
||||
components["model"]["voice"],
|
||||
timer
|
||||
]
|
||||
timer,
|
||||
],
|
||||
)
|
||||
|
||||
return demo
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import pytest
|
||||
import gradio as gr
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
from unittest.mock import patch, mock_open
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from unittest.mock import patch, mock_open
|
||||
|
||||
from ui.lib import api
|
||||
|
||||
|
||||
|
@ -57,10 +59,9 @@ def test_check_api_status_connection_error():
|
|||
|
||||
def test_text_to_speech_success(mock_response, tmp_path):
|
||||
"""Test successful speech generation"""
|
||||
with patch("requests.post", return_value=mock_response({})), \
|
||||
patch("ui.lib.api.OUTPUTS_DIR", str(tmp_path)), \
|
||||
patch("builtins.open", mock_open()) as mock_file:
|
||||
|
||||
with patch("requests.post", return_value=mock_response({})), patch(
|
||||
"ui.lib.api.OUTPUTS_DIR", str(tmp_path)
|
||||
), patch("builtins.open", mock_open()) as mock_file:
|
||||
result = api.text_to_speech("test text", "voice1", "mp3", 1.0)
|
||||
|
||||
assert result is not None
|
||||
|
@ -105,10 +106,9 @@ def test_get_status_html_unavailable():
|
|||
|
||||
def test_text_to_speech_api_params(mock_response, tmp_path):
|
||||
"""Test correct API parameters are sent"""
|
||||
with patch("requests.post") as mock_post, \
|
||||
patch("ui.lib.api.OUTPUTS_DIR", str(tmp_path)), \
|
||||
patch("builtins.open", mock_open()):
|
||||
|
||||
with patch("requests.post") as mock_post, patch(
|
||||
"ui.lib.api.OUTPUTS_DIR", str(tmp_path)
|
||||
), patch("builtins.open", mock_open()):
|
||||
mock_post.return_value = mock_response({})
|
||||
api.text_to_speech("test text", "voice1", "mp3", 1.5)
|
||||
|
||||
|
@ -121,7 +121,7 @@ def test_text_to_speech_api_params(mock_response, tmp_path):
|
|||
"input": "test text",
|
||||
"voice": "voice1",
|
||||
"response_format": "mp3",
|
||||
"speed": 1.5
|
||||
"speed": 1.5,
|
||||
}
|
||||
|
||||
# Check headers and timeout
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
import pytest
|
||||
import gradio as gr
|
||||
import pytest
|
||||
|
||||
from ui.lib.config import AUDIO_FORMATS
|
||||
from ui.lib.components.model import create_model_column
|
||||
from ui.lib.components.output import create_output_column
|
||||
from ui.lib.config import AUDIO_FORMATS
|
||||
|
||||
|
||||
def test_create_model_column_structure():
|
||||
|
@ -15,12 +16,7 @@ def test_create_model_column_structure():
|
|||
assert isinstance(components, dict)
|
||||
|
||||
# Test expected components presence
|
||||
expected_components = {
|
||||
"status_btn",
|
||||
"voice",
|
||||
"format",
|
||||
"speed"
|
||||
}
|
||||
expected_components = {"status_btn", "voice", "format", "speed"}
|
||||
assert set(components.keys()) == expected_components
|
||||
|
||||
# Test component types
|
||||
|
@ -78,7 +74,7 @@ def test_create_output_column_structure():
|
|||
"output_files",
|
||||
"play_btn",
|
||||
"selected_audio",
|
||||
"clear_outputs"
|
||||
"clear_outputs",
|
||||
}
|
||||
assert set(components.keys()) == expected_components
|
||||
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from ui.lib import files
|
||||
from ui.lib.config import AUDIO_FORMATS
|
||||
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import pytest
|
||||
import gradio as gr
|
||||
import pytest
|
||||
|
||||
from ui.lib.components.input import create_input_column
|
||||
|
||||
|
||||
|
|
|
@ -1,12 +1,15 @@
|
|||
import pytest
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import gradio as gr
|
||||
from unittest.mock import patch, MagicMock, PropertyMock
|
||||
import pytest
|
||||
|
||||
from ui.lib.interface import create_interface
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_timer():
|
||||
"""Create a mock timer with events property"""
|
||||
|
||||
class MockEvent:
|
||||
def __init__(self, fn):
|
||||
self.fn = fn
|
||||
|
@ -44,8 +47,7 @@ def test_interface_html_links():
|
|||
|
||||
# Find HTML component
|
||||
html_components = [
|
||||
comp for comp in demo.blocks.values()
|
||||
if isinstance(comp, gr.HTML)
|
||||
comp for comp in demo.blocks.values() if isinstance(comp, gr.HTML)
|
||||
]
|
||||
assert len(html_components) > 0
|
||||
html = html_components[0]
|
||||
|
@ -60,8 +62,9 @@ def test_interface_html_links():
|
|||
def test_update_status_available(mock_timer):
|
||||
"""Test status update when service is available"""
|
||||
voices = ["voice1", "voice2"]
|
||||
with patch("ui.lib.api.check_api_status", return_value=(True, voices)), \
|
||||
patch("gradio.Timer", return_value=mock_timer):
|
||||
with patch("ui.lib.api.check_api_status", return_value=(True, voices)), patch(
|
||||
"gradio.Timer", return_value=mock_timer
|
||||
):
|
||||
demo = create_interface()
|
||||
|
||||
# Get the update function
|
||||
|
@ -78,8 +81,9 @@ def test_update_status_available(mock_timer):
|
|||
|
||||
def test_update_status_unavailable(mock_timer):
|
||||
"""Test status update when service is unavailable"""
|
||||
with patch("ui.lib.api.check_api_status", return_value=(False, [])), \
|
||||
patch("gradio.Timer", return_value=mock_timer):
|
||||
with patch("ui.lib.api.check_api_status", return_value=(False, [])), patch(
|
||||
"gradio.Timer", return_value=mock_timer
|
||||
):
|
||||
demo = create_interface()
|
||||
update_fn = mock_timer.events[0].fn
|
||||
|
||||
|
@ -93,8 +97,9 @@ def test_update_status_unavailable(mock_timer):
|
|||
|
||||
def test_update_status_error(mock_timer):
|
||||
"""Test status update when an error occurs"""
|
||||
with patch("ui.lib.api.check_api_status", side_effect=Exception("Test error")), \
|
||||
patch("gradio.Timer", return_value=mock_timer):
|
||||
with patch(
|
||||
"ui.lib.api.check_api_status", side_effect=Exception("Test error")
|
||||
), patch("gradio.Timer", return_value=mock_timer):
|
||||
demo = create_interface()
|
||||
update_fn = mock_timer.events[0].fn
|
||||
|
||||
|
@ -108,8 +113,9 @@ def test_update_status_error(mock_timer):
|
|||
|
||||
def test_timer_configuration(mock_timer):
|
||||
"""Test timer configuration"""
|
||||
with patch("ui.lib.api.check_api_status", return_value=(False, [])), \
|
||||
patch("gradio.Timer", return_value=mock_timer):
|
||||
with patch("ui.lib.api.check_api_status", return_value=(False, [])), patch(
|
||||
"gradio.Timer", return_value=mock_timer
|
||||
):
|
||||
demo = create_interface()
|
||||
|
||||
assert mock_timer.value == 5 # Check interval is 5 seconds
|
||||
|
@ -123,8 +129,9 @@ def test_interface_components_presence():
|
|||
|
||||
# Check for main component sections
|
||||
components = {
|
||||
comp.label for comp in demo.blocks.values()
|
||||
if hasattr(comp, 'label') and comp.label
|
||||
comp.label
|
||||
for comp in demo.blocks.values()
|
||||
if hasattr(comp, "label") and comp.label
|
||||
}
|
||||
|
||||
required_components = {
|
||||
|
@ -133,7 +140,7 @@ def test_interface_components_presence():
|
|||
"Audio Format",
|
||||
"Speed",
|
||||
"Generated Speech",
|
||||
"Previous Outputs"
|
||||
"Previous Outputs",
|
||||
}
|
||||
|
||||
assert required_components.issubset(components)
|
||||
|
|
Loading…
Add table
Reference in a new issue