Mirror of https://github.com/remsky/Kokoro-FastAPI.git (synced 2025-08-05 16:48:53 +00:00)

Commit f051984805 ("Ruff Check + Format"), parent e749b3bc88
27 changed files with 638 additions and 504 deletions
.coverage (BIN) - binary file not shown
@@ -1,10 +1,10 @@
 from typing import List
 
-from fastapi import APIRouter, Depends, HTTPException, Response
 from loguru import logger
+from fastapi import Depends, Response, APIRouter, HTTPException
 
-from ..services.audio import AudioService
 from ..services.tts import TTSService
+from ..services.audio import AudioService
 from ..structures.schemas import OpenAISpeechRequest
 
 router = APIRouter(

@@ -32,7 +32,7 @@ async def create_speech(
             raise ValueError(
                 f"Voice '{request.voice}' not found. Available voices: {', '.join(sorted(available_voices))}"
             )
 
         # Generate audio directly using TTSService's method
         audio, _ = tts_service._generate_audio(
             text=request.input,

@@ -55,14 +55,12 @@ async def create_speech(
     except ValueError as e:
         logger.error(f"Invalid request: {str(e)}")
         raise HTTPException(
-            status_code=400,
-            detail={"error": "Invalid request", "message": str(e)}
+            status_code=400, detail={"error": "Invalid request", "message": str(e)}
         )
     except Exception as e:
         logger.error(f"Error generating speech: {str(e)}")
         raise HTTPException(
-            status_code=500,
-            detail={"error": "Server error", "message": str(e)}
+            status_code=500, detail={"error": "Server error", "message": str(e)}
         )
 
 

@@ -78,17 +76,19 @@ async def list_voices(tts_service: TTSService = Depends(get_tts_service)):
 
 
 @router.post("/audio/voices/combine")
-async def combine_voices(request: List[str], tts_service: TTSService = Depends(get_tts_service)):
+async def combine_voices(
+    request: List[str], tts_service: TTSService = Depends(get_tts_service)
+):
     """Combine multiple voices into a new voice.
 
     Args:
         request: List of voice names to combine
 
     Returns:
         Dict with combined voice name and list of all available voices
 
     Raises:
         HTTPException:
             - 400: Invalid request (wrong number of voices, voice not found)
             - 500: Server error (file system issues, combination failed)
     """

@@ -96,24 +96,21 @@ async def combine_voices(request: List[str], tts_service: TTSService = Depends(g
         combined_voice = tts_service.combine_voices(voices=request)
         voices = tts_service.list_voices()
         return {"voices": voices, "voice": combined_voice}
 
     except ValueError as e:
         logger.error(f"Invalid voice combination request: {str(e)}")
         raise HTTPException(
-            status_code=400,
-            detail={"error": "Invalid request", "message": str(e)}
+            status_code=400, detail={"error": "Invalid request", "message": str(e)}
         )
 
     except RuntimeError as e:
         logger.error(f"Server error during voice combination: {str(e)}")
         raise HTTPException(
-            status_code=500,
-            detail={"error": "Server error", "message": str(e)}
+            status_code=500, detail={"error": "Server error", "message": str(e)}
         )
 
     except Exception as e:
         logger.error(f"Unexpected error during voice combination: {str(e)}")
         raise HTTPException(
-            status_code=500,
-            detail={"error": "Unexpected error", "message": str(e)}
+            status_code=500, detail={"error": "Unexpected error", "message": str(e)}
         )
 
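Note: the endpoints touched in the hunks above can be exercised end to end once the service is running. A minimal sketch with the requests library, assuming a local deployment on port 8880 (the default base URL used by the example script later in this commit) and assuming the speech route is mounted at /v1/audio/speech in the OpenAI-compatible layout; the voice-combination path is the one the tests call:

import requests

BASE_URL = "http://localhost:8880"  # assumed local deployment

# Combine two base voices; the response carries the new voice name plus all voices.
resp = requests.post(f"{BASE_URL}/v1/audio/voices/combine", json=["af_bella", "af_sarah"])
resp.raise_for_status()
combined_voice = resp.json()["voice"]  # e.g. "af_bella_af_sarah"

# Generate speech with the combined voice (speech endpoint path assumed, not shown in this diff).
speech = requests.post(
    f"{BASE_URL}/v1/audio/speech",
    json={"model": "kokoro", "input": "Hello there!", "voice": combined_voice, "response_format": "mp3"},
)
with open("combined_voice.mp3", "wb") as out:
    out.write(speech.content)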
@@ -1,17 +1,16 @@
 import io
 import os
 import re
-import threading
 import time
+import threading
 from typing import List, Tuple, Optional
 
 import numpy as np
-import scipy.io.wavfile as wavfile
-import tiktoken
 import torch
+import tiktoken
+import scipy.io.wavfile as wavfile
+from kokoro import generate, tokenize, phonemize, normalize_text
 from loguru import logger
-
-from kokoro import generate, normalize_text, phonemize, tokenize
 from models import build_model
 
 from ..core.config import settings

@@ -23,7 +22,7 @@ class TTSModel:
     _instance = None
     _device = None
     _lock = threading.Lock()
 
     # Directory for all voices (copied base voices, and any created combined voices)
     VOICES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "voices")
 

@@ -38,10 +37,10 @@ class TTSModel:
             model_path = os.path.join(settings.model_dir, settings.model_path)
             model = build_model(model_path, cls._device)
             cls._instance = model
 
             # Ensure voices directory exists
             os.makedirs(cls.VOICES_DIR, exist_ok=True)
 
             # Copy base voices to local directory
             base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
             if os.path.exists(base_voices_dir):

@@ -51,25 +50,37 @@ class TTSModel:
                     voice_path = os.path.join(cls.VOICES_DIR, file)
                     if not os.path.exists(voice_path):
                         try:
-                            logger.info(f"Copying base voice {voice_name} to voices directory")
+                            logger.info(
+                                f"Copying base voice {voice_name} to voices directory"
+                            )
                             base_path = os.path.join(base_voices_dir, file)
-                            voicepack = torch.load(base_path, map_location=cls._device, weights_only=True)
+                            voicepack = torch.load(
+                                base_path,
+                                map_location=cls._device,
+                                weights_only=True,
+                            )
                             torch.save(voicepack, voice_path)
                         except Exception as e:
-                            logger.error(f"Error copying voice {voice_name}: {str(e)}")
+                            logger.error(
+                                f"Error copying voice {voice_name}: {str(e)}"
+                            )
 
             # Warm up with default voice
             try:
                 dummy_text = "Hello"
                 voice_path = os.path.join(cls.VOICES_DIR, "af.pt")
-                dummy_voicepack = torch.load(voice_path, map_location=cls._device, weights_only=True)
-                generate(model, dummy_text, dummy_voicepack, lang='a', speed=1.0)
+                dummy_voicepack = torch.load(
+                    voice_path, map_location=cls._device, weights_only=True
+                )
+                generate(model, dummy_text, dummy_voicepack, lang="a", speed=1.0)
                 logger.info("Model warm-up complete")
             except Exception as e:
                 logger.warning(f"Model warm-up failed: {e}")
 
             # Count voices in directory for validation
-            voice_count = len([f for f in os.listdir(cls.VOICES_DIR) if f.endswith('.pt')])
+            voice_count = len(
+                [f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")]
+            )
             return cls._instance, voice_count
 
     @classmethod

@@ -86,11 +97,11 @@ class TTSService:
         self._ensure_voices()
         if start_worker:
             self.start_worker()
 
     def _ensure_voices(self):
         """Copy base voices to local voices directory during initialization"""
         os.makedirs(TTSModel.VOICES_DIR, exist_ok=True)
 
         base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
         if os.path.exists(base_voices_dir):
             for file in os.listdir(base_voices_dir):

@@ -99,9 +110,15 @@ class TTSService:
                     voice_path = os.path.join(TTSModel.VOICES_DIR, file)
                     if not os.path.exists(voice_path):
                         try:
-                            logger.info(f"Copying base voice {voice_name} to voices directory")
+                            logger.info(
+                                f"Copying base voice {voice_name} to voices directory"
+                            )
                             base_path = os.path.join(base_voices_dir, file)
-                            voicepack = torch.load(base_path, map_location=TTSModel._device, weights_only=True)
+                            voicepack = torch.load(
+                                base_path,
+                                map_location=TTSModel._device,
+                                weights_only=True,
+                            )
                             torch.save(voicepack, voice_path)
                         except Exception as e:
                             logger.error(f"Error copying voice {voice_name}: {str(e)}")

@@ -112,10 +129,10 @@ class TTSService:
 
     def _get_voice_path(self, voice_name: str) -> Optional[str]:
         """Get the path to a voice file.
 
         Args:
             voice_name: Name of the voice to find
 
         Returns:
             Path to the voice file if found, None otherwise
         """

@@ -141,7 +158,9 @@ class TTSService:
 
             # Load model and voice
             model = TTSModel._instance
-            voicepack = torch.load(voice_path, map_location=TTSModel._device, weights_only=True)
+            voicepack = torch.load(
+                voice_path, map_location=TTSModel._device, weights_only=True
+            )
 
             # Generate audio with or without stitching
             if stitch_long_output:

@@ -152,11 +171,11 @@ class TTSService:
                 for i, chunk in enumerate(chunks):
                     try:
                         # Validate phonemization first
-                        ps = phonemize(chunk, voice[0])
-                        tokens = tokenize(ps)
-                        logger.debug(
-                            f"Processing chunk {i + 1}/{len(chunks)}: {len(tokens)} tokens"
-                        )
+                        # ps = phonemize(chunk, voice[0])
+                        # tokens = tokenize(ps)
+                        # logger.debug(
+                        #     f"Processing chunk {i + 1}/{len(chunks)}: {len(tokens)} tokens"
+                        # )
 
                         # Only proceed if phonemization succeeded
                         chunk_audio, _ = generate(

@@ -205,47 +224,51 @@ class TTSService:
 
     def combine_voices(self, voices: List[str]) -> str:
         """Combine multiple voices into a new voice.
 
         Args:
             voices: List of voice names to combine
 
         Returns:
             Name of the combined voice
 
         Raises:
             ValueError: If less than 2 voices provided or voice loading fails
             RuntimeError: If voice combination or saving fails
         """
         if len(voices) < 2:
             raise ValueError("At least 2 voices are required for combination")
 
         # Load voices
         t_voices: List[torch.Tensor] = []
         v_name: List[str] = []
 
         for voice in voices:
             try:
                 voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice}.pt")
-                voicepack = torch.load(voice_path, map_location=TTSModel._device, weights_only=True)
+                voicepack = torch.load(
+                    voice_path, map_location=TTSModel._device, weights_only=True
+                )
                 t_voices.append(voicepack)
                 v_name.append(voice)
             except Exception as e:
                 raise ValueError(f"Failed to load voice {voice}: {str(e)}")
 
         # Combine voices
         try:
             f: str = "_".join(v_name)
             v = torch.mean(torch.stack(t_voices), dim=0)
             combined_path = os.path.join(TTSModel.VOICES_DIR, f"{f}.pt")
 
             # Save combined voice
             try:
                 torch.save(v, combined_path)
             except Exception as e:
-                raise RuntimeError(f"Failed to save combined voice to {combined_path}: {str(e)}")
+                raise RuntimeError(
+                    f"Failed to save combined voice to {combined_path}: {str(e)}"
+                )
 
             return f
 
         except Exception as e:
             if not isinstance(e, (ValueError, RuntimeError)):
                 raise RuntimeError(f"Error combining voices: {str(e)}")
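Note: combine_voices above builds the new voice by stacking the loaded voicepack tensors and taking their element-wise mean, then saving the result under VOICES_DIR. A standalone sketch of that step with random stand-in tensors (the shapes here are illustrative, not the real voicepack shapes):

import torch

voice_a = torch.randn(510, 1, 256)  # stand-in for a loaded voicepack
voice_b = torch.randn(510, 1, 256)  # stand-in for a second voicepack

combined = torch.mean(torch.stack([voice_a, voice_b]), dim=0)  # element-wise average
torch.save(combined, "af_bella_af_sarah.pt")  # name mirrors the "_".join(...) convention above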
@@ -17,8 +17,8 @@ class OpenAISpeechRequest(BaseModel):
     model: Literal["tts-1", "tts-1-hd", "kokoro"] = "kokoro"
     input: str = Field(..., description="The text to generate audio for")
     voice: str = Field(
         default="af",
-        description="The voice to use for generation. Can be a base voice or a combined voice name."
+        description="The voice to use for generation. Can be a base voice or a combined voice name.",
     )
     response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field(
         default="mp3",
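Note: a request body that satisfies the OpenAISpeechRequest fields visible in this hunk could look like the following (values are illustrative; the speed field also appears in this commit's example script):

payload = {
    "model": "kokoro",
    "input": "The quick brown fox jumped over the lazy dog.",
    "voice": "af",             # base or combined voice name
    "response_format": "mp3",  # one of: mp3, opus, aac, flac, wav, pcm
    "speed": 1.0,
}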
@@ -1,16 +1,18 @@
 import os
-import shutil
 import sys
+import shutil
 from unittest.mock import Mock, patch
 
 import pytest
 
 
 def cleanup_mock_dirs():
     """Clean up any MagicMock directories created during tests"""
     mock_dir = "MagicMock"
     if os.path.exists(mock_dir):
         shutil.rmtree(mock_dir)
 
 
 @pytest.fixture(autouse=True)
 def cleanup():
     """Automatically clean up before and after each test"""

@@ -18,6 +20,7 @@ def cleanup():
     yield
     cleanup_mock_dirs()
 
 
 # Mock torch and other ML modules before they're imported
 sys.modules["torch"] = Mock()
 sys.modules["transformers"] = Mock()
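Note: the conftest above installs Mock objects into sys.modules before the application code is imported, so collecting the test suite never loads the real torch or transformers packages. A minimal sketch of the same pattern in isolation:

import sys
from unittest.mock import Mock

sys.modules["torch"] = Mock()  # must happen before anything does `import torch`

import torch  # resolves to the Mock registered above, not the real package

torch.load("voicepack.pt")  # recorded on the mock instead of touching the filesystem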
@@ -1,6 +1,8 @@
 """Tests for AudioService"""
 
 import numpy as np
 import pytest
 
 from api.src.services.audio import AudioService
 
 
@@ -114,9 +114,9 @@ def test_combine_voices_success(mock_tts_service):
     """Test successful voice combination"""
     test_voices = ["af_bella", "af_sarah"]
     mock_tts_service.combine_voices.return_value = "af_bella_af_sarah"
 
     response = client.post("/v1/audio/voices/combine", json=test_voices)
 
     assert response.status_code == 200
     assert response.json()["voice"] == "af_bella_af_sarah"
     mock_tts_service.combine_voices.assert_called_once_with(voices=test_voices)

@@ -126,9 +126,9 @@ def test_combine_voices_single_voice(mock_tts_service):
     """Test combining single voice returns default voice"""
     test_voices = ["af_bella"]
     mock_tts_service.combine_voices.return_value = "af"
 
     response = client.post("/v1/audio/voices/combine", json=test_voices)
 
     assert response.status_code == 200
     assert response.json()["voice"] == "af"
 

@@ -137,9 +137,9 @@ def test_combine_voices_empty_list(mock_tts_service):
     """Test combining empty voice list returns default voice"""
     test_voices = []
     mock_tts_service.combine_voices.return_value = "af"
 
     response = client.post("/v1/audio/voices/combine", json=test_voices)
 
     assert response.status_code == 200
     assert response.json()["voice"] == "af"
 

@@ -148,8 +148,8 @@ def test_combine_voices_error(mock_tts_service):
     """Test error handling in voice combination"""
     test_voices = ["af_bella", "af_sarah"]
     mock_tts_service.combine_voices.side_effect = Exception("Combination failed")
 
     response = client.post("/v1/audio/voices/combine", json=test_voices)
 
     assert response.status_code == 500
     assert "Combination failed" in response.json()["detail"]["message"]
@@ -1,7 +1,10 @@
 """Tests for FastAPI application"""
 
+from unittest.mock import MagicMock, patch
+
 import pytest
-from unittest.mock import patch, MagicMock
 from fastapi.testclient import TestClient
 
 from api.src.main import app, lifespan
 
 

@@ -19,98 +22,100 @@ def test_health_check(test_client):
 
 
 @pytest.mark.asyncio
-@patch('api.src.main.TTSModel')
-@patch('api.src.main.logger')
+@patch("api.src.main.TTSModel")
+@patch("api.src.main.logger")
 async def test_lifespan_successful_warmup(mock_logger, mock_tts_model):
     """Test successful model warmup in lifespan"""
     # Mock the model initialization with model info and voicepack count
     mock_model = MagicMock()
     # Mock file system for voice counting
     mock_tts_model.VOICES_DIR = "/mock/voices"
-    with patch('os.listdir', return_value=['voice1.pt', 'voice2.pt', 'voice3.pt']):
+    with patch("os.listdir", return_value=["voice1.pt", "voice2.pt", "voice3.pt"]):
         mock_tts_model.initialize.return_value = (mock_model, 3)  # 3 voice files
         mock_tts_model._device = "cuda"  # Set device class variable
 
         # Create an async generator from the lifespan context manager
         async_gen = lifespan(MagicMock())
         # Start the context manager
         await async_gen.__aenter__()
 
         # Verify the expected logging sequence
         mock_logger.info.assert_any_call("Loading TTS model and voice packs...")
         mock_logger.info.assert_any_call("Model loaded and warmed up on cuda")
         mock_logger.info.assert_any_call("3 voice packs loaded successfully")
 
         # Verify model initialization was called
         mock_tts_model.initialize.assert_called_once()
 
         # Clean up
         await async_gen.__aexit__(None, None, None)
 
 
 @pytest.mark.asyncio
-@patch('api.src.main.TTSModel')
-@patch('api.src.main.logger')
+@patch("api.src.main.TTSModel")
+@patch("api.src.main.logger")
 async def test_lifespan_failed_warmup(mock_logger, mock_tts_model):
     """Test failed model warmup in lifespan"""
     # Mock the model initialization to fail
     mock_tts_model.initialize.side_effect = Exception("Failed to initialize model")
 
     # Create an async generator from the lifespan context manager
     async_gen = lifespan(MagicMock())
 
     # Verify the exception is raised
     with pytest.raises(Exception, match="Failed to initialize model"):
         await async_gen.__aenter__()
 
     # Verify the expected logging sequence
     mock_logger.info.assert_called_with("Loading TTS model and voice packs...")
 
     # Clean up
     await async_gen.__aexit__(None, None, None)
 
 
 @pytest.mark.asyncio
-@patch('api.src.main.TTSModel')
+@patch("api.src.main.TTSModel")
 async def test_lifespan_cuda_warmup(mock_tts_model):
     """Test model warmup specifically on CUDA"""
     # Mock the model initialization with CUDA and voicepacks
     mock_model = MagicMock()
     # Mock file system for voice counting
     mock_tts_model.VOICES_DIR = "/mock/voices"
-    with patch('os.listdir', return_value=['voice1.pt', 'voice2.pt']):
+    with patch("os.listdir", return_value=["voice1.pt", "voice2.pt"]):
         mock_tts_model.initialize.return_value = (mock_model, 2)  # 2 voice files
         mock_tts_model._device = "cuda"  # Set device class variable
 
         # Create an async generator from the lifespan context manager
         async_gen = lifespan(MagicMock())
         await async_gen.__aenter__()
 
         # Verify model was initialized
         mock_tts_model.initialize.assert_called_once()
 
         # Clean up
         await async_gen.__aexit__(None, None, None)
 
 
 @pytest.mark.asyncio
-@patch('api.src.main.TTSModel')
+@patch("api.src.main.TTSModel")
 async def test_lifespan_cpu_fallback(mock_tts_model):
     """Test model warmup falling back to CPU"""
     # Mock the model initialization with CPU and voicepacks
     mock_model = MagicMock()
     # Mock file system for voice counting
     mock_tts_model.VOICES_DIR = "/mock/voices"
-    with patch('os.listdir', return_value=['voice1.pt', 'voice2.pt', 'voice3.pt', 'voice4.pt']):
+    with patch(
+        "os.listdir", return_value=["voice1.pt", "voice2.pt", "voice3.pt", "voice4.pt"]
+    ):
         mock_tts_model.initialize.return_value = (mock_model, 4)  # 4 voice files
         mock_tts_model._device = "cpu"  # Set device class variable
 
         # Create an async generator from the lifespan context manager
         async_gen = lifespan(MagicMock())
         await async_gen.__aenter__()
 
         # Verify model was initialized
         mock_tts_model.initialize.assert_called_once()
 
         # Clean up
         await async_gen.__aexit__(None, None, None)
 
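Note: the lifespan tests above drive FastAPI's startup and shutdown hooks by hand rather than booting a server. A reduced sketch of that pattern, assuming an asynccontextmanager-based lifespan like the one these tests import from api.src.main:

import asyncio
from contextlib import asynccontextmanager
from unittest.mock import MagicMock


@asynccontextmanager
async def lifespan(app):
    print("loading model...")   # startup work goes here
    yield
    print("shutting down")      # shutdown work goes here


async def exercise():
    cm = lifespan(MagicMock())             # the tests pass a MagicMock in place of the app
    await cm.__aenter__()                  # run the startup phase
    await cm.__aexit__(None, None, None)   # run the shutdown phase


asyncio.run(exercise())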
@@ -1,9 +1,12 @@
 """Tests for TTSService"""
 
 import os
+from unittest.mock import MagicMock, call, patch
 
 import numpy as np
 import pytest
-from unittest.mock import patch, MagicMock, call
-from api.src.services.tts import TTSService, TTSModel
+
+from api.src.services.tts import TTSModel, TTSService
 
 
 @pytest.fixture

@@ -50,42 +53,59 @@ def test_audio_to_bytes(tts_service, sample_audio):
     assert len(audio_bytes) > 0
 
 
-@patch('os.listdir')
-@patch('os.path.join')
+@patch("os.listdir")
+@patch("os.path.join")
 def test_list_voices(mock_join, mock_listdir, tts_service):
     """Test listing available voices"""
-    mock_listdir.return_value = ['voice1.pt', 'voice2.pt', 'not_a_voice.txt']
-    mock_join.return_value = '/fake/path'
+    mock_listdir.return_value = ["voice1.pt", "voice2.pt", "not_a_voice.txt"]
+    mock_join.return_value = "/fake/path"
 
     voices = tts_service.list_voices()
     assert len(voices) == 2
-    assert 'voice1' in voices
-    assert 'voice2' in voices
-    assert 'not_a_voice' not in voices
+    assert "voice1" in voices
+    assert "voice2" in voices
+    assert "not_a_voice" not in voices
 
 
-@patch('api.src.services.tts.TTSModel.get_instance')
-@patch('api.src.services.tts.TTSModel.get_voicepack')
-@patch('api.src.services.tts.normalize_text')
-@patch('api.src.services.tts.phonemize')
-@patch('api.src.services.tts.tokenize')
-@patch('api.src.services.tts.generate')
-def test_generate_audio_empty_text(mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_voicepack, mock_instance, tts_service):
+@patch("api.src.services.tts.TTSModel.get_instance")
+@patch("api.src.services.tts.TTSModel.get_voicepack")
+@patch("api.src.services.tts.normalize_text")
+@patch("api.src.services.tts.phonemize")
+@patch("api.src.services.tts.tokenize")
+@patch("api.src.services.tts.generate")
+def test_generate_audio_empty_text(
+    mock_generate,
+    mock_tokenize,
+    mock_phonemize,
+    mock_normalize,
+    mock_voicepack,
+    mock_instance,
+    tts_service,
+):
     """Test generating audio with empty text"""
     mock_normalize.return_value = ""
 
     with pytest.raises(ValueError, match="Text is empty after preprocessing"):
         tts_service._generate_audio("", "af", 1.0)
 
 
-@patch('api.src.services.tts.TTSModel.get_instance')
-@patch('os.path.exists')
-@patch('api.src.services.tts.normalize_text')
-@patch('api.src.services.tts.phonemize')
-@patch('api.src.services.tts.tokenize')
-@patch('api.src.services.tts.generate')
-@patch('torch.load')
-def test_generate_audio_no_chunks(mock_torch_load, mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_exists, mock_instance, tts_service):
+@patch("api.src.services.tts.TTSModel.get_instance")
+@patch("os.path.exists")
+@patch("api.src.services.tts.normalize_text")
+@patch("api.src.services.tts.phonemize")
+@patch("api.src.services.tts.tokenize")
+@patch("api.src.services.tts.generate")
+@patch("torch.load")
+def test_generate_audio_no_chunks(
+    mock_torch_load,
+    mock_generate,
+    mock_tokenize,
+    mock_phonemize,
+    mock_normalize,
+    mock_exists,
+    mock_instance,
+    tts_service,
+):
     """Test generating audio with no successful chunks"""
     mock_normalize.return_value = "Test text"
     mock_phonemize.return_value = "Test text"

@@ -94,19 +114,29 @@ def test_generate_audio_no_chunks(mock_torch_load, mock_generate, mock_tokenize,
     mock_instance.return_value = (MagicMock(), "cpu")
     mock_exists.return_value = True
     mock_torch_load.return_value = MagicMock()
 
     with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
         tts_service._generate_audio("Test text", "af", 1.0)
 
 
-@patch('api.src.services.tts.TTSModel.get_instance')
-@patch('os.path.exists')
-@patch('api.src.services.tts.normalize_text')
-@patch('api.src.services.tts.phonemize')
-@patch('api.src.services.tts.tokenize')
-@patch('api.src.services.tts.generate')
-@patch('torch.load')
-def test_generate_audio_success(mock_torch_load, mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_exists, mock_instance, tts_service, sample_audio):
+@patch("api.src.services.tts.TTSModel.get_instance")
+@patch("os.path.exists")
+@patch("api.src.services.tts.normalize_text")
+@patch("api.src.services.tts.phonemize")
+@patch("api.src.services.tts.tokenize")
+@patch("api.src.services.tts.generate")
+@patch("torch.load")
+def test_generate_audio_success(
+    mock_torch_load,
+    mock_generate,
+    mock_tokenize,
+    mock_phonemize,
+    mock_normalize,
+    mock_exists,
+    mock_instance,
+    tts_service,
+    sample_audio,
+):
     """Test successful audio generation"""
     mock_normalize.return_value = "Test text"
     mock_phonemize.return_value = "Test text"

@@ -115,15 +145,15 @@ def test_generate_audio_success(mock_torch_load, mock_generate, mock_tokenize, m
     mock_instance.return_value = (MagicMock(), "cpu")
     mock_exists.return_value = True
     mock_torch_load.return_value = MagicMock()
 
     audio, processing_time = tts_service._generate_audio("Test text", "af", 1.0)
     assert isinstance(audio, np.ndarray)
     assert isinstance(processing_time, float)
     assert len(audio) > 0
 
 
-@patch('api.src.services.tts.torch.cuda.is_available')
-@patch('api.src.services.tts.build_model')
+@patch("api.src.services.tts.torch.cuda.is_available")
+@patch("api.src.services.tts.build_model")
 def test_model_initialization_cuda(mock_build_model, mock_cuda_available):
     """Test model initialization with CUDA"""
     mock_cuda_available.return_value = True

@@ -132,14 +162,14 @@ def test_model_initialization_cuda(mock_build_model, mock_cuda_available):
 
     TTSModel._instance = None  # Reset singleton
     model, voice_count = TTSModel.initialize()
 
     assert TTSModel._device == "cuda"  # Check the class variable instead
     assert model == mock_model
     mock_build_model.assert_called_once()
 
 
-@patch('api.src.services.tts.torch.cuda.is_available')
-@patch('api.src.services.tts.build_model')
+@patch("api.src.services.tts.torch.cuda.is_available")
+@patch("api.src.services.tts.build_model")
 def test_model_initialization_cpu(mock_build_model, mock_cuda_available):
     """Test model initialization with CPU"""
     mock_cuda_available.return_value = False

@@ -148,76 +178,95 @@ def test_model_initialization_cpu(mock_build_model, mock_cuda_available):
 
     TTSModel._instance = None  # Reset singleton
     model, voice_count = TTSModel.initialize()
 
     assert TTSModel._device == "cpu"  # Check the class variable instead
     assert model == mock_model
     mock_build_model.assert_called_once()
 
 
-@patch('api.src.services.tts.TTSService._get_voice_path')
-@patch('api.src.services.tts.TTSModel.get_instance')
+@patch("api.src.services.tts.TTSService._get_voice_path")
+@patch("api.src.services.tts.TTSModel.get_instance")
 def test_voicepack_loading_error(mock_get_instance, mock_get_voice_path):
     """Test voicepack loading error handling"""
     mock_get_voice_path.return_value = None
     mock_get_instance.return_value = (MagicMock(), "cpu")
 
     TTSModel._voicepacks = {}  # Reset voicepacks
 
     service = TTSService(start_worker=False)
     with pytest.raises(ValueError, match="Voice not found: nonexistent_voice"):
         service._generate_audio("test", "nonexistent_voice", 1.0)
 
 
-@patch('api.src.services.tts.TTSModel')
+@patch("api.src.services.tts.TTSModel")
 def test_save_audio(mock_tts_model, tts_service, sample_audio, tmp_path):
     """Test saving audio to file"""
     output_dir = os.path.join(tmp_path, "test_output")
     os.makedirs(output_dir, exist_ok=True)
     output_path = os.path.join(output_dir, "audio.wav")
 
     tts_service._save_audio(sample_audio, output_path)
 
     assert os.path.exists(output_path)
     assert os.path.getsize(output_path) > 0
 
 
-@patch('api.src.services.tts.TTSModel.get_instance')
-@patch('os.path.exists')
-@patch('api.src.services.tts.normalize_text')
-@patch('api.src.services.tts.generate')
-@patch('torch.load')
-def test_generate_audio_without_stitching(mock_torch_load, mock_generate, mock_normalize, mock_exists, mock_instance, tts_service, sample_audio):
+@patch("api.src.services.tts.TTSModel.get_instance")
+@patch("os.path.exists")
+@patch("api.src.services.tts.normalize_text")
+@patch("api.src.services.tts.generate")
+@patch("torch.load")
+def test_generate_audio_without_stitching(
+    mock_torch_load,
+    mock_generate,
+    mock_normalize,
+    mock_exists,
+    mock_instance,
+    tts_service,
+    sample_audio,
+):
     """Test generating audio without text stitching"""
     mock_normalize.return_value = "Test text"
     mock_generate.return_value = (sample_audio, None)
     mock_instance.return_value = (MagicMock(), "cpu")
     mock_exists.return_value = True
     mock_torch_load.return_value = MagicMock()
 
-    audio, processing_time = tts_service._generate_audio("Test text", "af", 1.0, stitch_long_output=False)
+    audio, processing_time = tts_service._generate_audio(
+        "Test text", "af", 1.0, stitch_long_output=False
+    )
     assert isinstance(audio, np.ndarray)
     assert isinstance(processing_time, float)
     assert len(audio) > 0
     mock_generate.assert_called_once()
 
 
-@patch('os.listdir')
+@patch("os.listdir")
 def test_list_voices_error(mock_listdir, tts_service):
     """Test error handling in list_voices"""
     mock_listdir.side_effect = Exception("Failed to list directory")
 
     voices = tts_service.list_voices()
     assert voices == []
 
 
-@patch('api.src.services.tts.TTSModel.get_instance')
-@patch('os.path.exists')
-@patch('api.src.services.tts.normalize_text')
-@patch('api.src.services.tts.phonemize')
-@patch('api.src.services.tts.tokenize')
-@patch('api.src.services.tts.generate')
-@patch('torch.load')
-def test_generate_audio_phonemize_error(mock_torch_load, mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_exists, mock_instance, tts_service):
+@patch("api.src.services.tts.TTSModel.get_instance")
+@patch("os.path.exists")
+@patch("api.src.services.tts.normalize_text")
+@patch("api.src.services.tts.phonemize")
+@patch("api.src.services.tts.tokenize")
+@patch("api.src.services.tts.generate")
+@patch("torch.load")
+def test_generate_audio_phonemize_error(
+    mock_torch_load,
+    mock_generate,
+    mock_tokenize,
+    mock_phonemize,
+    mock_normalize,
+    mock_exists,
+    mock_instance,
+    tts_service,
+):
     """Test handling phonemization error"""
     mock_normalize.return_value = "Test text"
     mock_phonemize.side_effect = Exception("Phonemization failed")

@@ -225,23 +274,30 @@ def test_generate_audio_phonemize_error(mock_torch_load, mock_generate, mock_tok
     mock_exists.return_value = True
     mock_torch_load.return_value = MagicMock()
     mock_generate.return_value = (None, None)
 
     with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
         tts_service._generate_audio("Test text", "af", 1.0)
 
 
-@patch('api.src.services.tts.TTSModel.get_instance')
-@patch('os.path.exists')
-@patch('api.src.services.tts.normalize_text')
-@patch('api.src.services.tts.generate')
-@patch('torch.load')
-def test_generate_audio_error(mock_torch_load, mock_generate, mock_normalize, mock_exists, mock_instance, tts_service):
+@patch("api.src.services.tts.TTSModel.get_instance")
+@patch("os.path.exists")
+@patch("api.src.services.tts.normalize_text")
+@patch("api.src.services.tts.generate")
+@patch("torch.load")
+def test_generate_audio_error(
+    mock_torch_load,
+    mock_generate,
+    mock_normalize,
+    mock_exists,
+    mock_instance,
+    tts_service,
+):
     """Test handling generation error"""
     mock_normalize.return_value = "Test text"
     mock_generate.side_effect = Exception("Generation failed")
     mock_instance.return_value = (MagicMock(), "cpu")
     mock_exists.return_value = True
     mock_torch_load.return_value = MagicMock()
 
     with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
         tts_service._generate_audio("Test text", "af", 1.0)
 
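Note: most of the churn in the test hunks above is mechanical: single quotes become double quotes, and any decorator or signature that no longer fits the line-length limit is exploded one argument per line with a trailing comma. A before/after sketch on a made-up function, assuming Ruff's default 88-character limit:

# Before: one long signature line with single-quoted strings.
#   def check(mock_generate, mock_tokenize, mock_phonemize, mock_normalize, mock_instance, tts_service):
#       assert 'voice1' in ['voice1', 'voice2']

# After `ruff format`: double quotes, one parameter per line, trailing comma.
def check(
    mock_generate,
    mock_tokenize,
    mock_phonemize,
    mock_normalize,
    mock_instance,
    tts_service,
):
    assert "voice1" in ["voice1", "voice2"]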
@@ -19,7 +19,6 @@ output_dir = Path(__file__).parent / "output"
 output_dir.mkdir(exist_ok=True)
 
 
-
 def test_voice(voice: str):
     speech_file = output_dir / f"speech_{voice}.mp3"
     print(f"\nTesting voice: {voice}")
|
@ -1,21 +1,23 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import argparse
|
|
||||||
import os
|
import os
|
||||||
from typing import List, Optional, Dict, Tuple
|
import argparse
|
||||||
|
from typing import Dict, List, Tuple, Optional
|
||||||
|
|
||||||
import requests
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy.io import wavfile
|
import requests
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
from scipy.io import wavfile
|
||||||
|
|
||||||
|
|
||||||
def submit_combine_voices(voices: List[str], base_url: str = "http://localhost:8880") -> Optional[str]:
|
def submit_combine_voices(
|
||||||
|
voices: List[str], base_url: str = "http://localhost:8880"
|
||||||
|
) -> Optional[str]:
|
||||||
"""Combine multiple voices into a new voice.
|
"""Combine multiple voices into a new voice.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
voices: List of voice names to combine (e.g. ["af_bella", "af_sarah"])
|
voices: List of voice names to combine (e.g. ["af_bella", "af_sarah"])
|
||||||
base_url: API base URL
|
base_url: API base URL
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Name of the combined voice (e.g. "af_bella_af_sarah") or None if error
|
Name of the combined voice (e.g. "af_bella_af_sarah") or None if error
|
||||||
"""
|
"""
|
||||||
|
@ -23,7 +25,7 @@ def submit_combine_voices(voices: List[str], base_url: str = "http://localhost:8
|
||||||
response = requests.post(f"{base_url}/v1/audio/voices/combine", json=voices)
|
response = requests.post(f"{base_url}/v1/audio/voices/combine", json=voices)
|
||||||
print(f"Response status: {response.status_code}")
|
print(f"Response status: {response.status_code}")
|
||||||
print(f"Raw response: {response.text}")
|
print(f"Raw response: {response.text}")
|
||||||
|
|
||||||
# Accept both 200 and 201 as success
|
# Accept both 200 and 201 as success
|
||||||
if response.status_code not in [200, 201]:
|
if response.status_code not in [200, 201]:
|
||||||
try:
|
try:
|
||||||
|
@ -32,7 +34,7 @@ def submit_combine_voices(voices: List[str], base_url: str = "http://localhost:8
|
||||||
except:
|
except:
|
||||||
print(f"Error combining voices: {response.text}")
|
print(f"Error combining voices: {response.text}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data = response.json()
|
data = response.json()
|
||||||
if "voices" in data:
|
if "voices" in data:
|
||||||
|
@ -46,15 +48,20 @@ def submit_combine_voices(voices: List[str], base_url: str = "http://localhost:8
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def generate_speech(text: str, voice: str, base_url: str = "http://localhost:8880", output_file: str = "output.mp3") -> bool:
|
def generate_speech(
|
||||||
|
text: str,
|
||||||
|
voice: str,
|
||||||
|
base_url: str = "http://localhost:8880",
|
||||||
|
output_file: str = "output.mp3",
|
||||||
|
) -> bool:
|
||||||
"""Generate speech using specified voice.
|
"""Generate speech using specified voice.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: Text to convert to speech
|
text: Text to convert to speech
|
||||||
voice: Voice name to use
|
voice: Voice name to use
|
||||||
base_url: API base URL
|
base_url: API base URL
|
||||||
output_file: Path to save audio file
|
output_file: Path to save audio file
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if successful, False otherwise
|
True if successful, False otherwise
|
||||||
"""
|
"""
|
||||||
|
@ -65,22 +72,25 @@ def generate_speech(text: str, voice: str, base_url: str = "http://localhost:888
|
||||||
"input": text,
|
"input": text,
|
||||||
"voice": voice,
|
"voice": voice,
|
||||||
"speed": 1.0,
|
"speed": 1.0,
|
||||||
"response_format": "wav" # Use WAV for analysis
|
"response_format": "wav", # Use WAV for analysis
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
error = response.json().get("detail", {}).get("message", response.text)
|
error = response.json().get("detail", {}).get("message", response.text)
|
||||||
print(f"Error generating speech: {error}")
|
print(f"Error generating speech: {error}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Save the audio
|
# Save the audio
|
||||||
os.makedirs(os.path.dirname(output_file) if os.path.dirname(output_file) else ".", exist_ok=True)
|
os.makedirs(
|
||||||
|
os.path.dirname(output_file) if os.path.dirname(output_file) else ".",
|
||||||
|
exist_ok=True,
|
||||||
|
)
|
||||||
with open(output_file, "wb") as f:
|
with open(output_file, "wb") as f:
|
||||||
f.write(response.content)
|
f.write(response.content)
|
||||||
print(f"Saved audio to {output_file}")
|
print(f"Saved audio to {output_file}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error: {e}")
|
print(f"Error: {e}")
|
||||||
return False
|
return False
|
||||||
|
@@ -88,57 +98,57 @@ def generate_speech

def analyze_audio(filepath: str) -> Tuple[np.ndarray, int, dict]:
    """Analyze audio file and return samples, sample rate, and audio characteristics.

    Args:
        filepath: Path to audio file

    Returns:
        Tuple of (samples, sample_rate, characteristics)
    """
    sample_rate, samples = wavfile.read(filepath)

    # Convert to mono if stereo
    if len(samples.shape) > 1:
        samples = np.mean(samples, axis=1)

    # Calculate basic stats
    max_amp = np.max(np.abs(samples))
    rms = np.sqrt(np.mean(samples**2))
    duration = len(samples) / sample_rate

    # Zero crossing rate (helps identify voice characteristics)
    zero_crossings = np.sum(np.abs(np.diff(np.signbit(samples)))) / len(samples)

    # Simple frequency analysis
    if len(samples) > 0:
        # Use FFT to get frequency components
        fft_result = np.fft.fft(samples)
        freqs = np.fft.fftfreq(len(samples), 1 / sample_rate)

        # Get positive frequencies only
        pos_mask = freqs > 0
        freqs = freqs[pos_mask]
        magnitudes = np.abs(fft_result)[pos_mask]

        # Find dominant frequencies (top 3)
        top_indices = np.argsort(magnitudes)[-3:]
        dominant_freqs = freqs[top_indices]

        # Calculate spectral centroid (brightness of sound)
        spectral_centroid = np.sum(freqs * magnitudes) / np.sum(magnitudes)
    else:
        dominant_freqs = []
        spectral_centroid = 0

    characteristics = {
        "max_amplitude": max_amp,
        "rms": rms,
        "duration": duration,
        "zero_crossing_rate": zero_crossings,
        "dominant_frequencies": dominant_freqs,
        "spectral_centroid": spectral_centroid,
    }

    return samples, sample_rate, characteristics
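The "brightness" and "texture" metrics above reduce to two short NumPy expressions. A standalone sketch of the same math on a synthetic tone (the 440 Hz sine and 24 kHz rate are illustration values, not anything the demo requires):

import numpy as np

# One second of a 440 Hz sine at 24 kHz, for illustration only.
sample_rate = 24000
t = np.arange(sample_rate) / sample_rate
samples = np.sin(2 * np.pi * 440 * t)

# Zero crossing rate: fraction of consecutive samples whose sign flips.
zcr = np.sum(np.abs(np.diff(np.signbit(samples)))) / len(samples)

# Spectral centroid: magnitude-weighted mean of the positive FFT frequencies.
freqs = np.fft.fftfreq(len(samples), 1 / sample_rate)
magnitudes = np.abs(np.fft.fft(samples))
pos = freqs > 0
centroid = np.sum(freqs[pos] * magnitudes[pos]) / np.sum(magnitudes[pos])

print(f"ZCR: {zcr:.4f}, spectral centroid: {centroid:.0f} Hz")  # centroid lands near 440 Hz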
@@ -167,112 +177,136 @@ def setup_plot(fig, ax, title):

    return fig, ax


def plot_analysis(audio_files: Dict[str, str], output_dir: str):
    """Plot comprehensive voice analysis including waveforms and metrics comparison.

    Args:
        audio_files: Dictionary of label -> filepath
        output_dir: Directory to save plot files
    """
    # Set dark style
    plt.style.use("dark_background")

    # Create figure with subplots
    fig = plt.figure(figsize=(15, 15))
    fig.patch.set_facecolor("#1a1a2e")
    num_files = len(audio_files)

    # Create subplot grid with proper spacing
    gs = plt.GridSpec(
        num_files + 1, 2, height_ratios=[1.5] * num_files + [1], hspace=0.4, wspace=0.3
    )

    # Analyze all files first
    all_chars = {}
    for i, (label, filepath) in enumerate(audio_files.items()):
        samples, sample_rate, chars = analyze_audio(filepath)
        all_chars[label] = chars

        # Plot waveform spanning both columns
        ax = plt.subplot(gs[i, :])
        time = np.arange(len(samples)) / sample_rate
        plt.plot(time, samples / chars["max_amplitude"], linewidth=0.5, color="#ff2a6d")
        ax.set_xlabel("Time (seconds)")
        ax.set_ylabel("Normalized Amplitude")
        ax.set_ylim(-1.1, 1.1)
        setup_plot(fig, ax, f"Waveform: {label}")

    # Colors for voices
    colors = ["#ff2a6d", "#05d9e8", "#d1f7ff"]

    # Create two subplots for metrics with similar scales
    # Left subplot: Brightness and Volume
    ax1 = plt.subplot(gs[num_files, 0])
    metrics1 = [
        (
            "Brightness",
            [chars["spectral_centroid"] / 1000 for chars in all_chars.values()],
            "kHz",
        ),
        ("Volume", [chars["rms"] * 100 for chars in all_chars.values()], "RMS×100"),
    ]

    # Right subplot: Voice Pitch and Texture
    ax2 = plt.subplot(gs[num_files, 1])
    metrics2 = [
        (
            "Voice Pitch",
            [min(chars["dominant_frequencies"]) for chars in all_chars.values()],
            "Hz",
        ),
        (
            "Texture",
            [chars["zero_crossing_rate"] * 1000 for chars in all_chars.values()],
            "ZCR×1000",
        ),
    ]

    def plot_grouped_bars(ax, metrics, show_legend=True):
        n_groups = len(metrics)
        n_voices = len(audio_files)
        bar_width = 0.25

        indices = np.arange(n_groups)

        # Get max value for y-axis scaling
        max_val = max(max(m[1]) for m in metrics)

        for i, (voice, color) in enumerate(zip(audio_files.keys(), colors)):
            values = [m[1][i] for m in metrics]
            offset = (i - n_voices / 2 + 0.5) * bar_width
            bars = ax.bar(
                indices + offset, values, bar_width, label=voice, color=color, alpha=0.8
            )

            # Add value labels on top of bars
            for bar in bars:
                height = bar.get_height()
                ax.text(
                    bar.get_x() + bar.get_width() / 2.0,
                    height,
                    f"{height:.1f}",
                    ha="center",
                    va="bottom",
                    color="white",
                    fontsize=10,
                )

        ax.set_xticks(indices)
        ax.set_xticklabels([f"{m[0]}\n({m[2]})" for m in metrics])

        # Set y-axis limits with some padding
        ax.set_ylim(0, max_val * 1.2)

        if show_legend:
            ax.legend(
                bbox_to_anchor=(1.05, 1),
                loc="upper left",
                facecolor="#1a1a2e",
                edgecolor="#ffffff",
            )

    # Plot both subplots
    plot_grouped_bars(ax1, metrics1, show_legend=True)
    plot_grouped_bars(ax2, metrics2, show_legend=False)

    # Style both subplots
    setup_plot(fig, ax1, "Brightness and Volume")
    setup_plot(fig, ax2, "Voice Pitch and Texture")

    # Add y-axis labels
    ax1.set_ylabel("Value")
    ax2.set_ylabel("Value")

    # Adjust the figure size to accommodate the legend
    fig.set_size_inches(15, 15)

    # Add padding around the entire figure
    plt.subplots_adjust(right=0.85, top=0.95, bottom=0.05, left=0.1)
    plt.savefig(os.path.join(output_dir, "analysis_comparison.png"), dpi=300)
    print(f"Saved analysis comparison to {output_dir}/analysis_comparison.png")

    # Print detailed comparative analysis
    print("\nDetailed Voice Analysis:")
    for label, chars in all_chars.items():
@@ -282,44 +316,57 @@ def plot_analysis(audio_files: Dict[str, str], output_dir: str):
        print(f"  Duration: {chars['duration']:.2f}s")
        print(f"  Zero Crossing Rate: {chars['zero_crossing_rate']:.3f}")
        print(f"  Spectral Centroid: {chars['spectral_centroid']:.0f}Hz")
        print(
            f"  Dominant Frequencies: {', '.join(f'{f:.0f}Hz' for f in chars['dominant_frequencies'])}"
        )


def main():
    parser = argparse.ArgumentParser(description="Kokoro Voice Analysis Demo")
    parser.add_argument("--voices", nargs="+", type=str, help="Voices to combine")
    parser.add_argument(
        "--text",
        type=str,
        default="Hello! This is a test of combined voices.",
        help="Text to speak",
    )
    parser.add_argument("--url", default="http://localhost:8880", help="API base URL")
    parser.add_argument(
        "--output-dir",
        default="examples/output",
        help="Output directory for audio files",
    )
    args = parser.parse_args()

    if not args.voices:
        print("No voices provided, using default test voices")
        args.voices = ["af_bella", "af_nicole"]

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Dictionary to store audio files for analysis
    audio_files = {}

    # Generate speech with individual voices
    print("Generating speech with individual voices...")
    for voice in args.voices:
        output_file = os.path.join(args.output_dir, f"analysis_{voice}.wav")
        if generate_speech(args.text, voice, args.url, output_file):
            audio_files[voice] = output_file

    # Generate speech with combined voice
    print(f"\nCombining voices: {', '.join(args.voices)}")
    combined_voice = submit_combine_voices(args.voices, args.url)

    if combined_voice:
        print(f"Successfully created combined voice: {combined_voice}")
        output_file = os.path.join(
            args.output_dir, f"analysis_combined_{combined_voice}.wav"
        )
        if generate_speech(args.text, combined_voice, args.url, output_file):
            audio_files["combined"] = output_file

        # Generate comparison plots
        plot_analysis(audio_files, args.output_dir)
    else:
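The plots are optional; the per-voice characteristics can also be compared numerically once the WAV files exist. A minimal sketch, assuming the functions above are importable (the module name analyze_voices and the file paths are assumptions for illustration):

# Sketch only: compare generated clips numerically instead of plotting.
from analyze_voices import analyze_audio  # assumed module name

files = {
    "af_bella": "examples/output/analysis_af_bella.wav",
    "af_nicole": "examples/output/analysis_af_nicole.wav",
}

for label, path in files.items():
    _, _, chars = analyze_audio(path)
    print(
        f"{label}: {chars['duration']:.2f}s, rms={chars['rms']:.3f}, "
        f"centroid={chars['spectral_centroid']:.0f}Hz"
    )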
@@ -60,7 +60,7 @@ def test_speed(speed: float):

# Test different formats
for format in ["wav", "mp3", "opus", "aac", "flac", "pcm"]:
    test_format(format)  # aac and pcm should fail as they are not supported

# Test different speeds
for speed in [0.25, 1.0, 2.0, 4.0]:  # 5.0 should fail as it's out of range
@@ -10,3 +10,5 @@ sqlalchemy==2.0.27
pytest==8.0.0
httpx==0.26.0
pytest-asyncio==0.23.5
pytest-cov==6.0.0
gradio==4.19.2
@@ -2,8 +2,4 @@ from lib.interface import create_interface

if __name__ == "__main__":
    demo = create_interface()
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
@@ -1,16 +1,19 @@
import os
import datetime
from typing import List, Tuple, Optional

import requests

from .config import API_URL, OUTPUTS_DIR


def check_api_status() -> Tuple[bool, List[str]]:
    """Check TTS service status and get available voices."""
    try:
        # Use a longer timeout during startup
        response = requests.get(
            f"{API_URL}/v1/audio/voices",
            timeout=30,  # Increased timeout for initial startup period
        )
        response.raise_for_status()
        voices = response.json().get("voices", [])
@@ -31,16 +34,19 @@ def check_api_status
        print(f"Unexpected error checking API status: {str(e)}")
        return False, []


def text_to_speech(
    text: str, voice_id: str, format: str, speed: float
) -> Optional[str]:
    """Generate speech from text using TTS API."""
    if not text.strip():
        return None

    # Create output filename
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    output_filename = f"output_{timestamp}_voice-{voice_id}_speed-{speed}.{format}"
    output_path = os.path.join(OUTPUTS_DIR, output_filename)

    try:
        response = requests.post(
            f"{API_URL}/v1/audio/speech",
@@ -49,17 +55,17 @@ def text_to_speech
                "input": text,
                "voice": voice_id,
                "response_format": format,
                "speed": float(speed),
            },
            headers={"Content-Type": "application/json"},
            timeout=300,  # Longer timeout for speech generation
        )
        response.raise_for_status()

        with open(output_path, "wb") as f:
            f.write(response.content)
        return output_path

    except requests.exceptions.Timeout:
        print("Speech generation request timed out")
        return None
@@ -70,6 +76,7 @@ def text_to_speech
        print(f"Unexpected error generating speech: {str(e)}")
        return None


def get_status_html(is_available: bool) -> str:
    """Generate HTML for status indicator."""
    color = "green" if is_available else "red"
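Taken together, these helpers give the Gradio layer a very small client surface. A usage sketch, assuming the package is importable as ui.lib (the import path used by the tests later in this commit) and the service at API_URL is up:

from ui.lib import api

# check_api_status returns (is_available, voices); poll it once here.
is_available, voices = api.check_api_status()
if is_available and voices:
    # Generate an MP3 at normal speed with the first advertised voice.
    output_path = api.text_to_speech("A quick smoke test.", voices[0], "mp3", 1.0)
    print(output_path or "generation failed")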
@@ -1,7 +1,10 @@
from typing import Tuple

import gradio as gr

from .. import files


def create_input_column() -> Tuple[gr.Column, dict]:
    """Create the input column with text input and file handling."""
    with gr.Column(scale=1) as col:
@@ -11,49 +14,36 @@ def create_input_column
            # Direct Input Tab
            with gr.TabItem("Direct Input"):
                text_input = gr.Textbox(
                    label="Text to speak", placeholder="Enter text here...", lines=4
                )
                text_submit = gr.Button("Generate Speech", variant="primary", size="lg")

            # File Input Tab
            with gr.TabItem("From File"):
                # Existing files dropdown
                input_files_list = gr.Dropdown(
                    label="Select Existing File",
                    choices=files.list_input_files(),
                    value=None,
                )

                # Simple file upload
                file_upload = gr.File(
                    label="Upload Text File (.txt)", file_types=[".txt"]
                )

                file_preview = gr.Textbox(
                    label="File Content Preview", interactive=False, lines=4
                )

                with gr.Row():
                    file_submit = gr.Button(
                        "Generate Speech", variant="primary", size="lg"
                    )
                    clear_files = gr.Button(
                        "Clear Files", variant="secondary", size="lg"
                    )

        components = {
            "tabs": tabs,
            "text_input": text_input,
@@ -62,7 +52,7 @@ def create_input_column
            "file_preview": file_preview,
            "text_submit": text_submit,
            "file_submit": file_submit,
            "clear_files": clear_files,
        }

    return col, components
@@ -1,45 +1,41 @@
from typing import Tuple, Optional

import gradio as gr

from .. import api, config


def create_model_column(voice_ids: Optional[list] = None) -> Tuple[gr.Column, dict]:
    """Create the model settings column."""
    if voice_ids is None:
        voice_ids = []

    with gr.Column(scale=1) as col:
        gr.Markdown("### Model Settings")

        # Status button starts in waiting state
        status_btn = gr.Button(
            "⌛ TTS Service: Waiting for Service...", variant="secondary"
        )

        voice_input = gr.Dropdown(
            choices=voice_ids,
            label="Voice",
            value=voice_ids[0] if voice_ids else None,
            interactive=True,
        )
        format_input = gr.Dropdown(
            choices=config.AUDIO_FORMATS, label="Audio Format", value="mp3"
        )
        speed_input = gr.Slider(
            minimum=0.5, maximum=2.0, value=1.0, step=0.1, label="Speed"
        )

        components = {
            "status_btn": status_btn,
            "voice": voice_input,
            "format": format_input,
            "speed": speed_input,
        }

    return col, components
@@ -1,40 +1,42 @@
from typing import Tuple

import gradio as gr

from .. import files


def create_output_column() -> Tuple[gr.Column, dict]:
    """Create the output column with audio player and file list."""
    with gr.Column(scale=1) as col:
        gr.Markdown("### Latest Output")
        audio_output = gr.Audio(label="Generated Speech", type="filepath")

        gr.Markdown("### Generated Files")
        output_files = gr.Dropdown(
            label="Previous Outputs",
            choices=files.list_output_files(),
            value=None,
            allow_custom_value=False,
        )

        play_btn = gr.Button("▶️ Play Selected", size="sm")

        selected_audio = gr.Audio(
            label="Selected Output", type="filepath", visible=False
        )

        clear_outputs = gr.Button(
            "⚠️ Delete All Previously Generated Output Audio 🗑️",
            size="sm",
            variant="secondary",
        )

        components = {
            "audio_output": audio_output,
            "output_files": output_files,
            "play_btn": play_btn,
            "selected_audio": selected_audio,
            "clear_outputs": clear_outputs,
        }

    return col, components
@@ -1,17 +1,23 @@
import os
import datetime
from typing import List, Tuple, Optional

from .config import INPUTS_DIR, OUTPUTS_DIR, AUDIO_FORMATS


def list_input_files() -> List[str]:
    """List all input text files."""
    return [f for f in os.listdir(INPUTS_DIR) if f.endswith(".txt")]


def list_output_files() -> List[str]:
    """List all output audio files."""
    return [
        os.path.join(OUTPUTS_DIR, f)
        for f in os.listdir(OUTPUTS_DIR)
        if any(f.endswith(ext) for ext in AUDIO_FORMATS)
    ]


def read_text_file(filename: str) -> str:
    """Read content of a text file."""
@@ -19,16 +25,17 @@ def read_text_file
        return ""
    try:
        file_path = os.path.join(INPUTS_DIR, filename)
        with open(file_path, "r", encoding="utf-8") as f:
            return f.read()
    except:
        return ""


def save_text(text: str, filename: Optional[str] = None) -> Optional[str]:
    """Save text to a file. Returns the filename if successful."""
    if not text.strip():
        return None

    if filename is None:
        # Use input_1.txt, input_2.txt, etc.
        base = "input"
@@ -41,12 +48,12 @@ def save_text
    else:
        # Handle duplicate filenames by adding _1, _2, etc.
        base = os.path.splitext(filename)[0]
        ext = os.path.splitext(filename)[1] or ".txt"
        counter = 1
        while os.path.exists(os.path.join(INPUTS_DIR, filename)):
            filename = f"{base}_{counter}{ext}"
            counter += 1

    filepath = os.path.join(INPUTS_DIR, filename)
    try:
        with open(filepath, "w", encoding="utf-8") as f:
@@ -56,11 +63,12 @@ def save_text
        print(f"Error saving file: {e}")
        return None


def delete_all_input_files() -> bool:
    """Delete all files from the inputs directory. Returns True if successful."""
    try:
        for filename in os.listdir(INPUTS_DIR):
            if filename.endswith(".txt"):
                file_path = os.path.join(INPUTS_DIR, filename)
                os.remove(file_path)
        return True
@@ -68,6 +76,7 @@ def delete_all_input_files
        print(f"Error deleting input files: {e}")
        return False


def delete_all_output_files() -> bool:
    """Delete all audio files from the outputs directory. Returns True if successful."""
    try:
@@ -80,19 +89,20 @@ def delete_all_output_files
        print(f"Error deleting output files: {e}")
        return False


def process_uploaded_file(file_path: str) -> bool:
    """Save uploaded file to inputs directory. Returns True if successful."""
    if not file_path:
        return False

    try:
        filename = os.path.basename(file_path)
        if not filename.endswith(".txt"):
            return False

        # Create target path in inputs directory
        target_path = os.path.join(INPUTS_DIR, filename)

        # If file exists, add number suffix
        base, ext = os.path.splitext(filename)
        counter = 1
@@ -100,12 +110,13 @@ def process_uploaded_file
            new_name = f"{base}_{counter}{ext}"
            target_path = os.path.join(INPUTS_DIR, new_name)
            counter += 1

        # Copy file to inputs directory
        import shutil

        shutil.copy2(file_path, target_path)
        return True

    except Exception as e:
        print(f"Error saving uploaded file: {e}")
        return False
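One behavior worth noting: both save_text and process_uploaded_file suffix duplicate names instead of overwriting. A small sketch of the resulting naming, assuming an empty inputs directory and that the package is importable as ui.lib:

from ui.lib import files

# With an empty INPUTS_DIR, repeated saves of the same name get numeric suffixes
# rather than overwriting the earlier file (expected behavior per the docstrings).
first = files.save_text("hello world", "notes.txt")   # expected: "notes.txt"
second = files.save_text("hello again", "notes.txt")  # expected: "notes_1.txt"
print(first, second)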
@@ -1,16 +1,19 @@
import os
import shutil

import gradio as gr

from . import api, files


def setup_event_handlers(components: dict):
    """Set up all event handlers for the UI components."""

    def refresh_status():
        try:
            is_available, voices = api.check_api_status()
            status = "Available" if is_available else "Waiting for Service..."

            if is_available and voices:
                # Preserve current voice selection if it exists and is still valid
                current_voice = components["model"]["voice"].value
@@ -19,17 +22,17 @@ def setup_event_handlers(components: dict):
                    gr.update(
                        value=f"🔄 TTS Service: {status}",
                        interactive=True,
                        variant="secondary",
                    ),
                    gr.update(choices=voices, value=default_voice),
                ]
            return [
                gr.update(
                    value=f"⌛ TTS Service: {status}",
                    interactive=True,
                    variant="secondary",
                ),
                gr.update(choices=[], value=None),
            ]
        except Exception as e:
            print(f"Error in refresh status: {str(e)}")
@@ -37,11 +40,11 @@ def setup_event_handlers(components: dict):
                gr.update(
                    value="❌ TTS Service: Connection Error",
                    interactive=True,
                    variant="secondary",
                ),
                gr.update(choices=[], value=None),
            ]

    def handle_file_select(filename):
        if filename:
            try:
@@ -52,16 +55,16 @@ def setup_event_handlers(components: dict):
            except Exception as e:
                print(f"Error reading file: {e}")
        return gr.update(value="")

    def handle_file_upload(file):
        if file is None:
            return gr.update(choices=files.list_input_files())

        try:
            # Copy file to inputs directory
            filename = os.path.basename(file.name)
            target_path = os.path.join(files.INPUTS_DIR, filename)

            # Handle duplicate filenames
            base, ext = os.path.splitext(filename)
            counter = 1
@@ -69,43 +72,36 @@ def setup_event_handlers(components: dict):
                new_name = f"{base}_{counter}{ext}"
                target_path = os.path.join(files.INPUTS_DIR, new_name)
                counter += 1

            shutil.copy2(file.name, target_path)

        except Exception as e:
            print(f"Error uploading file: {e}")

        return gr.update(choices=files.list_input_files())

    def generate_from_text(text, voice, format, speed):
        """Generate speech from direct text input"""
        is_available, _ = api.check_api_status()
        if not is_available:
            gr.Warning("TTS Service is currently unavailable")
            return [None, gr.update(choices=files.list_output_files())]

        if not text or not text.strip():
            gr.Warning("Please enter text in the input box")
            return [None, gr.update(choices=files.list_output_files())]

        files.save_text(text)
        result = api.text_to_speech(text, voice, format, speed)
        if result is None:
            gr.Warning("Failed to generate speech. Please try again.")
            return [None, gr.update(choices=files.list_output_files())]

        return [
            result,
            gr.update(
                choices=files.list_output_files(), value=os.path.basename(result)
            ),
        ]

    def generate_from_file(selected_file, voice, format, speed):
@@ -113,37 +109,30 @@ def setup_event_handlers(components: dict):
        is_available, _ = api.check_api_status()
        if not is_available:
            gr.Warning("TTS Service is currently unavailable")
            return [None, gr.update(choices=files.list_output_files())]

        if not selected_file:
            gr.Warning("Please select a file")
            return [None, gr.update(choices=files.list_output_files())]

        text = files.read_text_file(selected_file)
        result = api.text_to_speech(text, voice, format, speed)
        if result is None:
            gr.Warning("Failed to generate speech. Please try again.")
            return [None, gr.update(choices=files.list_output_files())]

        return [
            result,
            gr.update(
                choices=files.list_output_files(), value=os.path.basename(result)
            ),
        ]

    def play_selected(file_path):
        if file_path and os.path.exists(file_path):
            return gr.update(value=file_path, visible=True)
        return gr.update(visible=False)

    def clear_files(voice, format, speed):
        """Delete all input files and clear UI components while preserving model settings"""
        files.delete_all_input_files()
@@ -155,7 +144,7 @@ def setup_event_handlers(components: dict):
            gr.update(choices=files.list_output_files()),  # output_files
            gr.update(value=voice),  # voice
            gr.update(value=format),  # format
            gr.update(value=speed),  # speed
        ]

    def clear_outputs():
@@ -164,43 +153,40 @@ def setup_event_handlers(components: dict):
        return [
            None,  # audio_output
            gr.update(choices=[], value=None),  # output_files
            gr.update(visible=False),  # selected_audio
        ]

    # Connect event handlers
    components["model"]["status_btn"].click(
        fn=refresh_status,
        outputs=[components["model"]["status_btn"], components["model"]["voice"]],
    )

    components["input"]["file_select"].change(
        fn=handle_file_select,
        inputs=[components["input"]["file_select"]],
        outputs=[components["input"]["file_preview"]],
    )

    components["input"]["file_upload"].upload(
        fn=handle_file_upload,
        inputs=[components["input"]["file_upload"]],
        outputs=[components["input"]["file_select"]],
    )

    components["output"]["play_btn"].click(
        fn=play_selected,
        inputs=[components["output"]["output_files"]],
        outputs=[components["output"]["selected_audio"]],
    )

    # Connect clear files button
    components["input"]["clear_files"].click(
        fn=clear_files,
        inputs=[
            components["model"]["voice"],
            components["model"]["format"],
            components["model"]["speed"],
        ],
        outputs=[
            components["input"]["file_select"],
@@ -210,10 +196,10 @@ def setup_event_handlers(components: dict):
            components["output"]["output_files"],
            components["model"]["voice"],
            components["model"]["format"],
            components["model"]["speed"],
        ],
    )

    # Connect submit buttons for each tab
    components["input"]["text_submit"].click(
        fn=generate_from_text,
@@ -221,22 +207,22 @@ def setup_event_handlers(components: dict):
            components["input"]["text_input"],
            components["model"]["voice"],
            components["model"]["format"],
            components["model"]["speed"],
        ],
        outputs=[
            components["output"]["audio_output"],
            components["output"]["output_files"],
        ],
    )

    # Connect clear outputs button
    components["output"]["clear_outputs"].click(
        fn=clear_outputs,
        outputs=[
            components["output"]["audio_output"],
            components["output"]["output_files"],
            components["output"]["selected_audio"],
        ],
    )

    components["input"]["file_submit"].click(
@@ -245,10 +231,10 @@ def setup_event_handlers(components: dict):
            components["input"]["file_select"],
            components["model"]["voice"],
            components["model"]["format"],
            components["model"]["speed"],
        ],
        outputs=[
            components["output"]["audio_output"],
            components["output"]["output_files"],
        ],
    )
@@ -1,69 +1,75 @@
import gradio as gr

from . import api
from .handlers import setup_event_handlers
from .components import create_input_column, create_model_column, create_output_column


def create_interface():
    """Create the main Gradio interface."""
    # Skip initial status check - let the timer handle it
    is_available, available_voices = False, []

    with gr.Blocks(title="Kokoro TTS Demo", theme=gr.themes.Monochrome()) as demo:
        gr.HTML(
            value='<div style="display: flex; gap: 0;">'
            '<a href="https://huggingface.co/hexgrad/Kokoro-82M" target="_blank" style="color: #2196F3; text-decoration: none; margin: 2px; border: 1px solid #2196F3; padding: 4px 8px; height: 24px; box-sizing: border-box; display: inline-flex; align-items: center;">Kokoro-82M HF Repo</a>'
            '<a href="https://github.com/remsky/Kokoro-FastAPI" target="_blank" style="color: #2196F3; text-decoration: none; margin: 2px; border: 1px solid #2196F3; padding: 4px 8px; height: 24px; box-sizing: border-box; display: inline-flex; align-items: center;">Kokoro-FastAPI Repo</a>'
            "</div>",
            show_label=False,
        )

        # Main interface
        with gr.Row():
            # Create columns
            input_col, input_components = create_input_column()
            model_col, model_components = create_model_column(
                available_voices
            )  # Pass initial voices
            output_col, output_components = create_output_column()

        # Collect all components
        components = {
            "input": input_components,
            "model": model_components,
            "output": output_components,
        }

        # Set up event handlers
        setup_event_handlers(components)

        # Add periodic status check with Timer
        def update_status():
            try:
                is_available, voices = api.check_api_status()
                status = "Available" if is_available else "Waiting for Service..."

                if is_available and voices:
                    # Service is available, update UI and stop timer
                    current_voice = components["model"]["voice"].value
                    default_voice = (
                        current_voice if current_voice in voices else voices[0]
                    )
                    # Return values in same order as outputs list
                    return [
                        gr.update(
                            value=f"🔄 TTS Service: {status}",
                            interactive=True,
                            variant="secondary",
                        ),
                        gr.update(choices=voices, value=default_voice),
                        gr.update(active=False),  # Stop timer
                    ]

                # Service not available yet, keep checking
                return [
                    gr.update(
                        value=f"⌛ TTS Service: {status}",
                        interactive=True,
                        variant="secondary",
                    ),
                    gr.update(choices=[], value=None),
                    gr.update(active=True),
                ]
            except Exception as e:
                print(f"Error in status update: {str(e)}")
@@ -72,20 +78,20 @@ def create_interface():
                    gr.update(
                        value="❌ TTS Service: Connection Error",
                        interactive=True,
                        variant="secondary",
                    ),
                    gr.update(choices=[], value=None),
                    gr.update(active=True),
                ]

        timer = gr.Timer(value=5)  # Check every 5 seconds
        timer.tick(
            fn=update_status,
            outputs=[
                components["model"]["status_btn"],
                components["model"]["voice"],
                timer,
            ],
        )

    return demo
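The timer-driven status check is the one non-obvious piece of wiring here: the tick handler lists the timer itself among its outputs so it can deactivate once the service responds. A stripped-down sketch of that pattern, assuming a Gradio version that provides gr.Timer and using a placeholder service_ready() check:

import gradio as gr

def service_ready() -> bool:
    # Placeholder for a real availability check (e.g. api.check_api_status()).
    return True

with gr.Blocks() as demo:
    status = gr.Button("⌛ Waiting for Service...")
    timer = gr.Timer(value=5)  # poll every 5 seconds

    def poll():
        if service_ready():
            # Update the label and return an inactive Timer update to stop polling.
            return [gr.update(value="🔄 Service: Available"), gr.update(active=False)]
        return [gr.update(value="⌛ Service: Waiting..."), gr.update(active=True)]

    timer.tick(fn=poll, outputs=[status, timer])

demo.launch()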
@@ -1,5 +1,5 @@
import gradio as gr
import pytest


@pytest.fixture
@@ -1,6 +1,8 @@
from unittest.mock import patch, mock_open

import pytest
import requests

from ui.lib import api
@@ -57,12 +59,11 @@
def test_text_to_speech_success(mock_response, tmp_path):
    """Test successful speech generation"""
    with patch("requests.post", return_value=mock_response({})), patch(
        "ui.lib.api.OUTPUTS_DIR", str(tmp_path)
    ), patch("builtins.open", mock_open()) as mock_file:
        result = api.text_to_speech("test text", "voice1", "mp3", 1.0)

        assert result is not None
        assert "output_" in result
        assert result.endswith(".mp3")
@ -105,25 +106,24 @@ def test_get_status_html_unavailable():
|
||||||
|
|
||||||
def test_text_to_speech_api_params(mock_response, tmp_path):
|
def test_text_to_speech_api_params(mock_response, tmp_path):
|
||||||
"""Test correct API parameters are sent"""
|
"""Test correct API parameters are sent"""
|
||||||
with patch("requests.post") as mock_post, \
|
with patch("requests.post") as mock_post, patch(
|
||||||
patch("ui.lib.api.OUTPUTS_DIR", str(tmp_path)), \
|
"ui.lib.api.OUTPUTS_DIR", str(tmp_path)
|
||||||
patch("builtins.open", mock_open()):
|
), patch("builtins.open", mock_open()):
|
||||||
|
|
||||||
mock_post.return_value = mock_response({})
|
mock_post.return_value = mock_response({})
|
||||||
api.text_to_speech("test text", "voice1", "mp3", 1.5)
|
api.text_to_speech("test text", "voice1", "mp3", 1.5)
|
||||||
|
|
||||||
mock_post.assert_called_once()
|
mock_post.assert_called_once()
|
||||||
args, kwargs = mock_post.call_args
|
args, kwargs = mock_post.call_args
|
||||||
|
|
||||||
# Check request body
|
# Check request body
|
||||||
assert kwargs["json"] == {
|
assert kwargs["json"] == {
|
||||||
"model": "kokoro",
|
"model": "kokoro",
|
||||||
"input": "test text",
|
"input": "test text",
|
||||||
"voice": "voice1",
|
"voice": "voice1",
|
||||||
"response_format": "mp3",
|
"response_format": "mp3",
|
||||||
"speed": 1.5
|
"speed": 1.5,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Check headers and timeout
|
# Check headers and timeout
|
||||||
assert kwargs["headers"] == {"Content-Type": "application/json"}
|
assert kwargs["headers"] == {"Content-Type": "application/json"}
|
||||||
assert kwargs["timeout"] == 300
|
assert kwargs["timeout"] == 300
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
import pytest
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from ui.lib.config import AUDIO_FORMATS
|
||||||
from ui.lib.components.model import create_model_column
|
from ui.lib.components.model import create_model_column
|
||||||
from ui.lib.components.output import create_output_column
|
from ui.lib.components.output import create_output_column
|
||||||
from ui.lib.config import AUDIO_FORMATS
|
|
||||||
|
|
||||||
|
|
||||||
def test_create_model_column_structure():
|
def test_create_model_column_structure():
|
||||||
|
@ -15,12 +16,7 @@ def test_create_model_column_structure():
|
||||||
assert isinstance(components, dict)
|
assert isinstance(components, dict)
|
||||||
|
|
||||||
# Test expected components presence
|
# Test expected components presence
|
||||||
expected_components = {
|
expected_components = {"status_btn", "voice", "format", "speed"}
|
||||||
"status_btn",
|
|
||||||
"voice",
|
|
||||||
"format",
|
|
||||||
"speed"
|
|
||||||
}
|
|
||||||
assert set(components.keys()) == expected_components
|
assert set(components.keys()) == expected_components
|
||||||
|
|
||||||
# Test component types
|
# Test component types
|
||||||
|
@ -78,7 +74,7 @@ def test_create_output_column_structure():
|
||||||
"output_files",
|
"output_files",
|
||||||
"play_btn",
|
"play_btn",
|
||||||
"selected_audio",
|
"selected_audio",
|
||||||
"clear_outputs"
|
"clear_outputs",
|
||||||
}
|
}
|
||||||
assert set(components.keys()) == expected_components
|
assert set(components.keys()) == expected_components
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
import os
|
import os
|
||||||
import pytest
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from ui.lib import files
|
from ui.lib import files
|
||||||
from ui.lib.config import AUDIO_FORMATS
|
from ui.lib.config import AUDIO_FORMATS
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import pytest
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
import pytest
|
||||||
|
|
||||||
from ui.lib.components.input import create_input_column
|
from ui.lib.components.input import create_input_column
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,12 +1,15 @@
|
||||||
import pytest
|
from unittest.mock import MagicMock, PropertyMock, patch
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
from unittest.mock import patch, MagicMock, PropertyMock
|
import pytest
|
||||||
|
|
||||||
from ui.lib.interface import create_interface
|
from ui.lib.interface import create_interface
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_timer():
|
def mock_timer():
|
||||||
"""Create a mock timer with events property"""
|
"""Create a mock timer with events property"""
|
||||||
|
|
||||||
class MockEvent:
|
class MockEvent:
|
||||||
def __init__(self, fn):
|
def __init__(self, fn):
|
||||||
self.fn = fn
|
self.fn = fn
|
||||||
|
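
The fixture above is truncated by the hunk boundary. A complete stand-in along these lines (an assumption about how the rest of the fixture might look, not a copy of it) only needs to record the function handed to `tick()` so the tests can invoke the update logic directly, without running Gradio.

```python
class MockEvent:
    def __init__(self, fn):
        self.fn = fn


class MockTimer:
    """Minimal timer double: remembers its interval and any tick handlers."""

    def __init__(self, value=5, active=True):
        self.value = value
        self.active = active
        self.events = []

    def tick(self, fn, inputs=None, outputs=None):
        # Record the handler so tests can call it later via events[0].fn().
        self.events.append(MockEvent(fn))


timer = MockTimer(value=5)
timer.tick(fn=lambda: ["status", "voices", "timer"], outputs=[])
assert timer.value == 5
assert len(timer.events) == 1
assert timer.events[0].fn() == ["status", "voices", "timer"]
```
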
@ -30,7 +33,7 @@ def test_create_interface_structure():
|
||||||
"""Test the basic structure of the created interface"""
|
"""Test the basic structure of the created interface"""
|
||||||
with patch("ui.lib.api.check_api_status", return_value=(False, [])):
|
with patch("ui.lib.api.check_api_status", return_value=(False, [])):
|
||||||
demo = create_interface()
|
demo = create_interface()
|
||||||
|
|
||||||
# Test interface type and theme
|
# Test interface type and theme
|
||||||
assert isinstance(demo, gr.Blocks)
|
assert isinstance(demo, gr.Blocks)
|
||||||
assert demo.title == "Kokoro TTS Demo"
|
assert demo.title == "Kokoro TTS Demo"
|
||||||
|
@ -41,15 +44,14 @@ def test_interface_html_links():
|
||||||
"""Test that HTML links are properly configured"""
|
"""Test that HTML links are properly configured"""
|
||||||
with patch("ui.lib.api.check_api_status", return_value=(False, [])):
|
with patch("ui.lib.api.check_api_status", return_value=(False, [])):
|
||||||
demo = create_interface()
|
demo = create_interface()
|
||||||
|
|
||||||
# Find HTML component
|
# Find HTML component
|
||||||
html_components = [
|
html_components = [
|
||||||
comp for comp in demo.blocks.values()
|
comp for comp in demo.blocks.values() if isinstance(comp, gr.HTML)
|
||||||
if isinstance(comp, gr.HTML)
|
|
||||||
]
|
]
|
||||||
assert len(html_components) > 0
|
assert len(html_components) > 0
|
||||||
html = html_components[0]
|
html = html_components[0]
|
||||||
|
|
||||||
# Check for required links
|
# Check for required links
|
||||||
assert 'href="https://huggingface.co/hexgrad/Kokoro-82M"' in html.value
|
assert 'href="https://huggingface.co/hexgrad/Kokoro-82M"' in html.value
|
||||||
assert 'href="https://github.com/remsky/Kokoro-FastAPI"' in html.value
|
assert 'href="https://github.com/remsky/Kokoro-FastAPI"' in html.value
|
||||||
|
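
The link and label checks above iterate over `demo.blocks`, which maps block ids to component instances in the Gradio versions this UI targets. A short hedged sketch of the same inspection pattern outside the test suite:

```python
import gradio as gr

with gr.Blocks() as demo:
    gr.HTML('<a href="https://github.com/remsky/Kokoro-FastAPI">Repository</a>')
    gr.Textbox(label="Text to speak")

# Collect components by type, as test_interface_html_links does.
html_components = [c for c in demo.blocks.values() if isinstance(c, gr.HTML)]
assert len(html_components) > 0
assert "Kokoro-FastAPI" in html_components[0].value

# Collect labelled components, as test_interface_components_presence does.
labels = {c.label for c in demo.blocks.values() if hasattr(c, "label") and c.label}
assert "Text to speak" in labels
```
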
@ -60,16 +62,17 @@ def test_interface_html_links():
|
||||||
def test_update_status_available(mock_timer):
|
def test_update_status_available(mock_timer):
|
||||||
"""Test status update when service is available"""
|
"""Test status update when service is available"""
|
||||||
voices = ["voice1", "voice2"]
|
voices = ["voice1", "voice2"]
|
||||||
with patch("ui.lib.api.check_api_status", return_value=(True, voices)), \
|
with patch("ui.lib.api.check_api_status", return_value=(True, voices)), patch(
|
||||||
patch("gradio.Timer", return_value=mock_timer):
|
"gradio.Timer", return_value=mock_timer
|
||||||
|
):
|
||||||
demo = create_interface()
|
demo = create_interface()
|
||||||
|
|
||||||
# Get the update function
|
# Get the update function
|
||||||
update_fn = mock_timer.events[0].fn
|
update_fn = mock_timer.events[0].fn
|
||||||
|
|
||||||
# Test update with available service
|
# Test update with available service
|
||||||
updates = update_fn()
|
updates = update_fn()
|
||||||
|
|
||||||
assert "Available" in updates[0]["value"]
|
assert "Available" in updates[0]["value"]
|
||||||
assert updates[1]["choices"] == voices
|
assert updates[1]["choices"] == voices
|
||||||
assert updates[1]["value"] == voices[0]
|
assert updates[1]["value"] == voices[0]
|
||||||
|
@ -78,13 +81,14 @@ def test_update_status_available(mock_timer):
|
||||||
|
|
||||||
def test_update_status_unavailable(mock_timer):
|
def test_update_status_unavailable(mock_timer):
|
||||||
"""Test status update when service is unavailable"""
|
"""Test status update when service is unavailable"""
|
||||||
with patch("ui.lib.api.check_api_status", return_value=(False, [])), \
|
with patch("ui.lib.api.check_api_status", return_value=(False, [])), patch(
|
||||||
patch("gradio.Timer", return_value=mock_timer):
|
"gradio.Timer", return_value=mock_timer
|
||||||
|
):
|
||||||
demo = create_interface()
|
demo = create_interface()
|
||||||
update_fn = mock_timer.events[0].fn
|
update_fn = mock_timer.events[0].fn
|
||||||
|
|
||||||
updates = update_fn()
|
updates = update_fn()
|
||||||
|
|
||||||
assert "Waiting for Service" in updates[0]["value"]
|
assert "Waiting for Service" in updates[0]["value"]
|
||||||
assert updates[1]["choices"] == []
|
assert updates[1]["choices"] == []
|
||||||
assert updates[1]["value"] is None
|
assert updates[1]["value"] is None
|
||||||
|
@ -93,13 +97,14 @@ def test_update_status_unavailable(mock_timer):
|
||||||
|
|
||||||
def test_update_status_error(mock_timer):
|
def test_update_status_error(mock_timer):
|
||||||
"""Test status update when an error occurs"""
|
"""Test status update when an error occurs"""
|
||||||
with patch("ui.lib.api.check_api_status", side_effect=Exception("Test error")), \
|
with patch(
|
||||||
patch("gradio.Timer", return_value=mock_timer):
|
"ui.lib.api.check_api_status", side_effect=Exception("Test error")
|
||||||
|
), patch("gradio.Timer", return_value=mock_timer):
|
||||||
demo = create_interface()
|
demo = create_interface()
|
||||||
update_fn = mock_timer.events[0].fn
|
update_fn = mock_timer.events[0].fn
|
||||||
|
|
||||||
updates = update_fn()
|
updates = update_fn()
|
||||||
|
|
||||||
assert "Connection Error" in updates[0]["value"]
|
assert "Connection Error" in updates[0]["value"]
|
||||||
assert updates[1]["choices"] == []
|
assert updates[1]["choices"] == []
|
||||||
assert updates[1]["value"] is None
|
assert updates[1]["value"] is None
|
||||||
|
@ -108,10 +113,11 @@ def test_update_status_error(mock_timer):
|
||||||
|
|
||||||
def test_timer_configuration(mock_timer):
|
def test_timer_configuration(mock_timer):
|
||||||
"""Test timer configuration"""
|
"""Test timer configuration"""
|
||||||
with patch("ui.lib.api.check_api_status", return_value=(False, [])), \
|
with patch("ui.lib.api.check_api_status", return_value=(False, [])), patch(
|
||||||
patch("gradio.Timer", return_value=mock_timer):
|
"gradio.Timer", return_value=mock_timer
|
||||||
|
):
|
||||||
demo = create_interface()
|
demo = create_interface()
|
||||||
|
|
||||||
assert mock_timer.value == 5 # Check interval is 5 seconds
|
assert mock_timer.value == 5 # Check interval is 5 seconds
|
||||||
assert len(mock_timer.events) == 1 # Should have one event handler
|
assert len(mock_timer.events) == 1 # Should have one event handler
|
||||||
|
|
||||||
|
@ -120,20 +126,21 @@ def test_interface_components_presence():
|
||||||
"""Test that all required components are present"""
|
"""Test that all required components are present"""
|
||||||
with patch("ui.lib.api.check_api_status", return_value=(False, [])):
|
with patch("ui.lib.api.check_api_status", return_value=(False, [])):
|
||||||
demo = create_interface()
|
demo = create_interface()
|
||||||
|
|
||||||
# Check for main component sections
|
# Check for main component sections
|
||||||
components = {
|
components = {
|
||||||
comp.label for comp in demo.blocks.values()
|
comp.label
|
||||||
if hasattr(comp, 'label') and comp.label
|
for comp in demo.blocks.values()
|
||||||
|
if hasattr(comp, "label") and comp.label
|
||||||
}
|
}
|
||||||
|
|
||||||
required_components = {
|
required_components = {
|
||||||
"Text to speak",
|
"Text to speak",
|
||||||
"Voice",
|
"Voice",
|
||||||
"Audio Format",
|
"Audio Format",
|
||||||
"Speed",
|
"Speed",
|
||||||
"Generated Speech",
|
"Generated Speech",
|
||||||
"Previous Outputs"
|
"Previous Outputs",
|
||||||
}
|
}
|
||||||
|
|
||||||
assert required_components.issubset(components)
|
assert required_components.issubset(components)
|
||||||
|
|
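
The assertions throughout these tests index into the returned updates (`updates[0]["value"]`, `updates[1]["choices"]`), which works because `gr.update(...)` returns a plain dict containing only the properties being changed, plus internal marker keys (observed Gradio behaviour; treat the extra keys as an implementation detail). A minimal check:

```python
import gradio as gr

status_update = gr.update(value="🔄 TTS Service: Available", interactive=True)
voice_update = gr.update(choices=["voice1", "voice2"], value="voice1")

# Only the properties passed in are asserted on; anything else is internal.
assert status_update["value"] == "🔄 TTS Service: Available"
assert status_update["interactive"] is True
assert voice_update["choices"] == ["voice1", "voice2"]
assert voice_update["value"] == "voice1"
```
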