Mirror of https://github.com/remsky/Kokoro-FastAPI.git (synced 2025-04-13 09:39:17 +00:00)

Update dependencies, enhance voice management, and add captioned speech support

This commit is contained in:
parent 9198de2d95
commit 6c234a3b67

31 changed files with 979 additions and 169 deletions

.gitignore (vendored): 1 line changed
@@ -69,3 +69,4 @@ examples/*.ogg
examples/speech.mp3
examples/phoneme_examples/output/*.wav
examples/assorted_checks/benchmarks/output_audio/*
uv.lock
@@ -1,4 +1,9 @@
"""Model configuration for Kokoro V1."""
"""Model configuration for Kokoro V1.

This module provides model-specific configuration settings that complement the application-level
settings in config.py. While config.py handles general application settings (API, paths, etc.),
this module focuses on memory management and model file paths.
"""

from pydantic import BaseModel, Field


@@ -9,51 +14,29 @@ class KokoroV1Config(BaseModel):
    class Config:
        frozen = True


class PyTorchCPUConfig(BaseModel):
    """PyTorch CPU backend configuration."""

class PyTorchConfig(BaseModel):
    """PyTorch backend configuration."""

    memory_threshold: float = Field(0.8, description="Memory threshold for cleanup")
    retry_on_oom: bool = Field(True, description="Whether to retry on OOM errors")
    num_threads: int = Field(8, description="Number of threads for parallel operations")
    pin_memory: bool = Field(True, description="Whether to pin memory for faster CPU-GPU transfer")

    class Config:
        frozen = True


class PyTorchGPUConfig(BaseModel):
    """PyTorch GPU backend configuration."""

    device_id: int = Field(0, description="CUDA device ID")
    use_triton: bool = Field(True, description="Whether to use Triton for CUDA kernels")
    memory_threshold: float = Field(0.8, description="Memory threshold for cleanup")
    retry_on_oom: bool = Field(True, description="Whether to retry on OOM errors")
    sync_cuda: bool = Field(True, description="Whether to synchronize CUDA operations")
    cuda_streams: int = Field(2, description="Number of CUDA streams for inference")
    stream_timeout: int = Field(60, description="Stream timeout in seconds")

    class Config:
        frozen = True


class ModelConfig(BaseModel):
    """Kokoro V1 model configuration."""

    # General settings
    device_type: str = Field("cpu", description="Device type ('cpu' or 'gpu')")
    cache_voices: bool = Field(True, description="Whether to cache voice tensors")
    voice_cache_size: int = Field(2, description="Maximum number of cached voices")

    # Model filename
    pytorch_kokoro_v1_file: str = Field("v1_0/kokoro-v1_0.pth", description="PyTorch Kokoro V1 model filename")

    # Backend configs
    pytorch_cpu: PyTorchCPUConfig = Field(default_factory=PyTorchCPUConfig)
    pytorch_gpu: PyTorchGPUConfig = Field(default_factory=PyTorchGPUConfig)
    # Backend config
    pytorch_gpu: PyTorchConfig = Field(default_factory=PyTorchConfig)

    class Config:
        frozen = True


# Global instance
model_config = ModelConfig()
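A minimal usage sketch (not part of the diff) of how the frozen global config above might be consumed. The field names come from the hunks; the import path and the describe_backend helper are assumptions for illustration only.

from api.src.inference.model_config import model_config  # module path is an assumption


def describe_backend() -> str:
    """Summarize the active backend settings from the frozen global config."""
    cfg = model_config.pytorch_gpu  # a single PyTorchConfig after the CPU/GPU configs were consolidated
    return (
        f"device={model_config.device_type}, "
        f"model file={model_config.pytorch_kokoro_v1_file}, "
        f"memory threshold={cfg.memory_threshold}, "
        f"threads={cfg.num_threads}"
    )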
@@ -160,7 +160,7 @@ async def list_voices() -> List[str]:
        return sorted([name[:-3] for name in voices])  # Remove .pt extension


async def load_voice_tensor(voice_path: str, device: str = "cpu") -> torch.Tensor:
async def load_voice_tensor(voice_path: str, device: str = "cpu", weights_only=False) -> torch.Tensor:
    """Load voice tensor from file.

    Args:

@@ -179,7 +179,7 @@ async def load_voice_tensor(voice_path: str, device: str = "cpu") -> torch.Tensor:
        return torch.load(
            io.BytesIO(data),
            map_location=device,
            weights_only=True
            weights_only=weights_only
        )
    except Exception as e:
        raise RuntimeError(f"Failed to load voice tensor from {voice_path}: {e}")
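A small sketch, assuming the core.paths module shown above: the new weights_only parameter now defaults to False, and a caller can still opt back into the stricter torch.load behaviour explicitly. The import path and the example voice filename are assumptions.

import asyncio

from api.src.core import paths  # import path is an assumption


async def load_for_inference(voice_path: str):
    # Default now allows full tensor deserialization (weights_only=False)
    voice = await paths.load_voice_tensor(voice_path, device="cpu")
    # Stricter loading can still be requested explicitly
    safe_voice = await paths.load_voice_tensor(voice_path, device="cpu", weights_only=True)
    return voice, safe_voice


# asyncio.run(load_for_inference("v1_0/af_bella.pt"))  # hypothetical path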
@@ -3,6 +3,7 @@
from typing import Dict, List, Optional

import torch
import aiofiles
from loguru import logger

from ..core import paths

@@ -57,7 +58,7 @@ class VoiceManager:
        except Exception as e:
            raise RuntimeError(f"Failed to load voice {voice_name}: {e}")

    async def combine_voices(self, voices: List[str], device: Optional[str] = None) -> str:
    async def combine_voices(self, voices: List[str], device: Optional[str] = None) -> torch.Tensor:
        """Combine multiple voices.

        Args:

@@ -65,7 +66,7 @@ class VoiceManager:
            device: Optional override for target device

        Returns:
            Name of combined voice
            Combined voice tensor

        Raises:
            RuntimeError: If any voice not found

@@ -80,10 +81,7 @@ class VoiceManager:
            voice_tensors.append(voice)

        combined = torch.mean(torch.stack(voice_tensors), dim=0)
        combined_name = "+".join(voices)
        self._voices[combined_name] = combined

        return combined_name
        return combined

    async def list_voices(self) -> List[str]:
        """List available voice names.
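A short sketch of the new behaviour, assuming a VoiceManager instance is already constructed: combine_voices now hands back the averaged tensor instead of registering a named voice, so persisting it is the caller's job. The helper function and output path are illustrative only.

import torch


async def export_combined_voice(manager, out_path: str = "af_bella+af_kore.pt"):
    # combine_voices returns the mean of the stacked voice tensors (see the hunk above)
    combined = await manager.combine_voices(["af_bella", "af_kore"])
    torch.save(combined, out_path)  # caller persists the tensor; the manager no longer caches it by name
    return combined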
@@ -8,14 +8,19 @@ from loguru import logger

from ..services.audio import AudioService, AudioNormalizer
from ..services.streaming_audio_writer import StreamingAudioWriter
from ..services.text_processing import phonemize, smart_split
from ..services.text_processing.vocabulary import tokenize
from ..services.text_processing import smart_split
from kokoro import KPipeline
from ..services.tts_service import TTSService
from ..structures.text_schemas import (
    GenerateFromPhonemesRequest,
    PhonemeRequest,
    PhonemeResponse,
)
from ..structures import (
    CaptionedSpeechRequest,
    CaptionedSpeechResponse,
    WordTimestamp
)

router = APIRouter(tags=["text processing"])
@@ -26,11 +31,10 @@ async def get_tts_service() -> TTSService:

@router.post("/dev/phonemize", response_model=PhonemeResponse)
async def phonemize_text(request: PhonemeRequest) -> PhonemeResponse:
    """Convert text to phonemes and tokens
    """Convert text to phonemes using Kokoro's quiet mode.

    Args:
        request: Request containing text and language
        tts_service: Injected TTSService instance

    Returns:
        Phonemes and token IDs
@@ -39,14 +43,17 @@ async def phonemize_text(request: PhonemeRequest) -> PhonemeResponse:
        if not request.text:
            raise ValueError("Text cannot be empty")

        # Get phonemes
        phonemes = phonemize(request.text, request.language)
        if not phonemes:
            raise ValueError("Failed to generate phonemes")
        # Initialize Kokoro pipeline in quiet mode (no model)
        pipeline = KPipeline(lang_code=request.language, model=False)

        # Get first result from pipeline (we only need one since we're not chunking)
        for result in pipeline(request.text):
            # result.graphemes = original text
            # result.phonemes = phonemized text
            # result.tokens = token objects (if available)
            return PhonemeResponse(phonemes=result.phonemes, tokens=[])

        # Get tokens (without adding start/end tokens to match process_text behavior)
        tokens = tokenize(phonemes)
        return PhonemeResponse(phonemes=phonemes, tokens=tokens)
        raise ValueError("Failed to generate phonemes")
    except ValueError as e:
        logger.error(f"Error in phoneme generation: {str(e)}")
        raise HTTPException(
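For context, a minimal sketch of the quiet-mode pattern the reworked endpoint relies on. It only uses the KPipeline behaviour visible in the hunk above (model=False and results exposing .phonemes); the standalone function framing and the sample sentence are assumptions.

from kokoro import KPipeline


def phonemize_once(text: str, lang_code: str = "a") -> str:
    """Return phonemes for the first pipeline result, without loading a model."""
    pipeline = KPipeline(lang_code=lang_code, model=False)  # quiet mode: text processing only
    for result in pipeline(text):
        return result.phonemes
    return ""


# phonemize_once("Hello world!")  # e.g. run against a short sentence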
@@ -63,7 +70,7 @@ async def generate_from_phonemes(
    client_request: Request,
    tts_service: TTSService = Depends(get_tts_service),
) -> StreamingResponse:
    """Generate audio directly from phonemes with proper streaming"""
    """Generate audio directly from phonemes using Kokoro's phoneme format"""
    try:
        # Basic validation
        if not isinstance(request.phonemes, str):
@@ -77,41 +84,30 @@ async def generate_from_phonemes(

        async def generate_chunks():
            try:
                has_data = False
                # Process phonemes in chunks
                async for chunk_text, _ in smart_split(request.phonemes):
                    # Check if client is still connected
                    is_disconnected = client_request.is_disconnected
                    if callable(is_disconnected):
                        is_disconnected = await is_disconnected()
                    if is_disconnected:
                        logger.info("Client disconnected, stopping audio generation")
                        break

                    chunk_audio, _ = await tts_service.generate_from_phonemes(
                        phonemes=chunk_text,
                        voice=request.voice,
                        speed=1.0
                    )
                    if chunk_audio is not None:
                        has_data = True
                        # Normalize audio before writing
                        normalized_audio = await normalizer.normalize(chunk_audio)
                        # Write chunk and yield bytes
                        chunk_bytes = writer.write_chunk(normalized_audio)
                        if chunk_bytes:
                            yield chunk_bytes

                if not has_data:
                    raise ValueError("Failed to generate any audio data")

                # Finalize and yield remaining bytes if we still have a connection
                if not (callable(is_disconnected) and await is_disconnected()):
                # Generate audio from phonemes
                chunk_audio, _ = await tts_service.generate_from_phonemes(
                    phonemes=request.phonemes,  # Pass complete phoneme string
                    voice=request.voice,
                    speed=1.0
                )

                if chunk_audio is not None:
                    # Normalize audio before writing
                    normalized_audio = await normalizer.normalize(chunk_audio)
                    # Write chunk and yield bytes
                    chunk_bytes = writer.write_chunk(normalized_audio)
                    if chunk_bytes:
                        yield chunk_bytes

                    # Finalize and yield remaining bytes
                    final_bytes = writer.write_chunk(finalize=True)
                    if final_bytes:
                        yield final_bytes
                else:
                    raise ValueError("Failed to generate audio data")

            except Exception as e:
                logger.error(f"Error in audio chunk generation: {str(e)}")
                logger.error(f"Error in audio generation: {str(e)}")
                # Clean up writer on error
                writer.write_chunk(finalize=True)
                # Re-raise the original exception
@@ -128,7 +124,6 @@ async def generate_from_phonemes(
                }
            )

    except ValueError as e:
        logger.error(f"Error generating audio: {str(e)}")
        raise HTTPException(
@@ -149,3 +144,92 @@ async def generate_from_phonemes(
                "type": "server_error"
            }
        )

@router.post("/dev/captioned_speech")
async def create_captioned_speech(
    request: CaptionedSpeechRequest,
    tts_service: TTSService = Depends(get_tts_service),
) -> StreamingResponse:
    """Generate audio with word-level timestamps using Kokoro's output"""
    try:
        # Get voice path
        voice_name, voice_path = await tts_service._get_voice_path(request.voice)

        # Generate audio with timestamps
        audio, _, word_timestamps = await tts_service.generate_audio(
            text=request.input,
            voice=voice_name,
            speed=request.speed,
            return_timestamps=True
        )

        # Create streaming audio writer
        writer = StreamingAudioWriter(format=request.response_format, sample_rate=24000, channels=1)
        normalizer = AudioNormalizer()

        async def generate_chunks():
            try:
                if audio is not None:
                    # Normalize audio before writing
                    normalized_audio = await normalizer.normalize(audio)
                    # Write chunk and yield bytes
                    chunk_bytes = writer.write_chunk(normalized_audio)
                    if chunk_bytes:
                        yield chunk_bytes

                    # Finalize and yield remaining bytes
                    final_bytes = writer.write_chunk(finalize=True)
                    if final_bytes:
                        yield final_bytes
                else:
                    raise ValueError("Failed to generate audio data")

            except Exception as e:
                logger.error(f"Error in audio generation: {str(e)}")
                # Clean up writer on error
                writer.write_chunk(finalize=True)
                # Re-raise the original exception
                raise

        # Convert timestamps to JSON and add as header
        import json
        logger.debug(f"Processing {len(word_timestamps)} word timestamps")
        timestamps_json = json.dumps([{
            'word': str(ts['word']),  # Ensure string for text
            'start_time': float(ts['start_time']),  # Ensure float for timestamps
            'end_time': float(ts['end_time'])
        } for ts in word_timestamps])
        logger.debug(f"Generated timestamps JSON: {timestamps_json}")

        return StreamingResponse(
            generate_chunks(),
            media_type=f"audio/{request.response_format}",
            headers={
                "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
                "X-Accel-Buffering": "no",
                "Cache-Control": "no-cache",
                "Transfer-Encoding": "chunked",
                "X-Word-Timestamps": timestamps_json
            }
        )

    except ValueError as e:
        logger.error(f"Error in captioned speech generation: {str(e)}")
        raise HTTPException(
            status_code=400,
            detail={
                "error": "validation_error",
                "message": str(e),
                "type": "invalid_request_error"
            }
        )
    except Exception as e:
        logger.error(f"Error in captioned speech generation: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": "processing_error",
                "message": str(e),
                "type": "server_error"
            }
        )
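The timestamps travel in the X-Word-Timestamps response header as a JSON array. A hedged client-side sketch of reading that header follows; the local URL and payload values are assumptions, and the full example script appears later in this commit.

import json

import requests

resp = requests.post(
    "http://localhost:8880/dev/captioned_speech",  # assumed local server
    json={"model": "kokoro", "input": "Hello world!", "voice": "af_bella", "response_format": "wav"},
)
words = json.loads(resp.headers.get("X-Word-Timestamps", "[]"))
# Each entry looks like {"word": ..., "start_time": ..., "end_time": ...}
for ts in words:
    print(f"{ts['word']}: {ts['start_time']:.3f}s - {ts['end_time']:.3f}s")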
@@ -2,15 +2,19 @@

import json
import os
import io
import tempfile
from typing import AsyncGenerator, Dict, List, Union

import torch
import aiofiles
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from fastapi.responses import StreamingResponse, FileResponse
from loguru import logger

from ..services.audio import AudioService
from ..services.tts_service import TTSService
from ..structures.schemas import OpenAISpeechRequest
from ..structures import OpenAISpeechRequest
from ..core.config import settings

# Load OpenAI mappings
@@ -72,55 +76,65 @@ def get_model_name(model: str) -> str:
async def process_voices(
    voice_input: Union[str, List[str]], tts_service: TTSService
) -> str:
    """Process voice input into a combined voice, handling both string and list formats"""
    """Process voice input, handling both string and list formats

    Returns:
        Voice name to use (with weights if specified)
    """
    # Convert input to list of voices
    if isinstance(voice_input, str):
        # Check if it's an OpenAI voice name
        mapped_voice = _openai_mappings["voices"].get(voice_input)
        if mapped_voice:
            voice_input = mapped_voice
        voices = [v.strip() for v in voice_input.split("+") if v.strip()]
        # Split on + but preserve any parentheses
        voices = []
        for part in voice_input.split("+"):
            part = part.strip()
            if not part:
                continue
            # Extract voice name without weight
            voice_name = part.split("(")[0].strip()
            # Check if it's a valid voice
            available_voices = await tts_service.list_voices()
            if voice_name not in available_voices:
                raise ValueError(
                    f"Voice '{voice_name}' not found. Available voices: {', '.join(sorted(available_voices))}"
                )
            voices.append(part)
    else:
        # For list input, map each voice if it's an OpenAI voice name
        voices = [_openai_mappings["voices"].get(v, v) for v in voice_input]
        voices = [v.strip() for v in voices if v.strip()]
        voices = []
        for v in voice_input:
            mapped = _openai_mappings["voices"].get(v, v)
            voice_name = mapped.split("(")[0].strip()
            # Check if it's a valid voice
            available_voices = await tts_service.list_voices()
            if voice_name not in available_voices:
                raise ValueError(
                    f"Voice '{voice_name}' not found. Available voices: {', '.join(sorted(available_voices))}"
                )
            voices.append(mapped)

    if not voices:
        raise ValueError("No voices provided")

    # If single voice, validate and return it
    if len(voices) == 1:
        available_voices = await tts_service.list_voices()
        if voices[0] not in available_voices:
            raise ValueError(
                f"Voice '{voices[0]}' not found. Available voices: {', '.join(sorted(available_voices))}"
            )
        return voices[0]

    # For multiple voices, validate base voices exist
    available_voices = await tts_service.list_voices()
    for voice in voices:
        if voice not in available_voices:
            raise ValueError(
                f"Base voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}"
            )

    # Combine voices
    return await tts_service.combine_voices(voices=voices)
    # For multiple voices, combine them with +
    return "+".join(voices)


async def stream_audio_chunks(
    tts_service: TTSService,
    tts_service: TTSService,
    request: OpenAISpeechRequest,
    client_request: Request
) -> AsyncGenerator[bytes, None]:
    """Stream audio chunks as they're generated with client disconnect handling"""
    voice_to_use = await process_voices(request.voice, tts_service)
    voice_name = await process_voices(request.voice, tts_service)

    try:
        async for chunk in tts_service.generate_audio_stream(
            text=request.input,
            voice=voice_to_use,
            voice=voice_name,
            speed=request.speed,
            output_format=request.response_format,
        ):
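A small sketch of the voice-string convention the reworked process_voices accepts: base names joined with +, each optionally carrying a weight in parentheses. The helper below is hypothetical; only the string format comes from the hunks in this commit.

def split_voice_string(voice_string: str) -> list[tuple[str, float]]:
    """Parse e.g. 'af_bella(0.2)+af_kore(0.8)' into [(name, weight), ...]; weight defaults to 1.0."""
    parsed = []
    for part in voice_string.split("+"):
        part = part.strip()
        if not part:
            continue
        if "(" in part and ")" in part:
            name = part.split("(")[0].strip()
            weight = float(part.split("(")[1].split(")")[0])
        else:
            name, weight = part, 1.0
        parsed.append((name, weight))
    return parsed


# split_voice_string("af_bella(0.2)+af_kore(0.8)") -> [("af_bella", 0.2), ("af_kore", 0.8)]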
@@ -159,7 +173,7 @@ async def create_speech(
    try:
        # model_name = get_model_name(request.model)
        tts_service = await get_tts_service()
        voice_to_use = await process_voices(request.voice, tts_service)
        voice_name = await process_voices(request.voice, tts_service)

        # Set content type based on format
        content_type = {
@@ -237,7 +251,7 @@ async def create_speech(
            # Generate complete audio using public interface
            audio, _ = await tts_service.generate_audio(
                text=request.input,
                voice=voice_to_use,
                voice=voice_name,
                speed=request.speed
            )
@@ -350,14 +364,14 @@ async def list_voices():

@router.post("/audio/voices/combine")
async def combine_voices(request: Union[str, List[str]]):
    """Combine multiple voices into a new voice.
    """Combine multiple voices into a new voice and return the .pt file.

    Args:
        request: Either a string with voices separated by + (e.g. "voice1+voice2")
            or a list of voice names to combine

    Returns:
        Dict with combined voice name and list of all available voices
        FileResponse with the combined voice .pt file

    Raises:
        HTTPException:
@@ -365,10 +379,51 @@ async def combine_voices(request: Union[str, List[str]]):
        - 500: Server error (file system issues, combination failed)
    """
    try:
        # Convert input to list of voices
        if isinstance(request, str):
            # Check if it's an OpenAI voice name
            mapped_voice = _openai_mappings["voices"].get(request)
            if mapped_voice:
                request = mapped_voice
            voices = [v.strip() for v in request.split("+") if v.strip()]
        else:
            # For list input, map each voice if it's an OpenAI voice name
            voices = [_openai_mappings["voices"].get(v, v) for v in request]
            voices = [v.strip() for v in voices if v.strip()]

        if not voices:
            raise ValueError("No voices provided")

        # For multiple voices, validate base voices exist
        tts_service = await get_tts_service()
        combined_voice = await process_voices(request, tts_service)
        voices = await tts_service.list_voices()
        return {"voices": voices, "voice": combined_voice}
        available_voices = await tts_service.list_voices()
        for voice in voices:
            if voice not in available_voices:
                raise ValueError(
                    f"Base voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}"
                )

        # Combine voices
        combined_tensor = await tts_service.combine_voices(voices=voices)
        combined_name = "+".join(voices)

        # Save to temp file
        temp_dir = tempfile.gettempdir()
        voice_path = os.path.join(temp_dir, f"{combined_name}.pt")
        buffer = io.BytesIO()
        torch.save(combined_tensor, buffer)
        async with aiofiles.open(voice_path, 'wb') as f:
            await f.write(buffer.getvalue())

        return FileResponse(
            voice_path,
            media_type="application/octet-stream",
            filename=f"{combined_name}.pt",
            headers={
                "Content-Disposition": f"attachment; filename={combined_name}.pt",
                "Cache-Control": "no-cache"
            }
        )

    except ValueError as e:
        logger.warning(f"Invalid voice combination request: {str(e)}")
@@ -17,6 +17,7 @@ from .audio import AudioNormalizer, AudioService
from .text_processing.text_processor import process_text_chunk, smart_split
from .text_processing import tokenize
from ..inference.kokoro_v1 import KokoroV1
from kokoro import KPipeline


class TTSService:
@@ -154,23 +155,43 @@ class TTSService:
        try:
            # Check if it's a combined voice
            if "+" in voice:
                voices = [v.strip() for v in voice.split("+") if v.strip()]
                if len(voices) < 2:
                # Split on + but preserve any parentheses
                voice_parts = []
                weights = []
                for part in voice.split("+"):
                    part = part.strip()
                    if not part:
                        continue
                    # Extract voice name and weight if present
                    if "(" in part and ")" in part:
                        voice_name = part.split("(")[0].strip()
                        weight = float(part.split("(")[1].split(")")[0])
                    else:
                        voice_name = part
                        weight = 1.0
                    voice_parts.append(voice_name)
                    weights.append(weight)

                if len(voice_parts) < 2:
                    raise RuntimeError(f"Invalid combined voice name: {voice}")

                # Normalize weights to sum to 1
                total_weight = sum(weights)
                weights = [w/total_weight for w in weights]

                # Load and combine voices
                voice_tensors = []
                for v in voices:
                for v, w in zip(voice_parts, weights):
                    path = await self._voice_manager.get_voice_path(v)
                    if not path:
                        raise RuntimeError(f"Voice not found: {v}")
                    logger.debug(f"Loading voice tensor from: {path}")
                    voice_tensor = torch.load(path, map_location="cpu")
                    voice_tensors.append(voice_tensor)
                    voice_tensors.append(voice_tensor * w)

                # Average the voice tensors
                logger.debug(f"Combining {len(voice_tensors)} voice tensors")
                combined = torch.mean(torch.stack(voice_tensors), dim=0)
                # Sum the weighted voice tensors
                logger.debug(f"Combining {len(voice_tensors)} voice tensors with weights {weights}")
                combined = torch.sum(torch.stack(voice_tensors), dim=0)

                # Save combined tensor
                temp_dir = tempfile.gettempdir()
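A worked numerical sketch of the weighted combination above, using made-up two-element tensors purely for illustration: weights are normalized to sum to 1, each voice tensor is scaled by its weight, and the scaled tensors are summed.

import torch

# Hypothetical tiny "voice tensors" standing in for the real voice embeddings
bella = torch.tensor([1.0, 0.0])
kore = torch.tensor([0.0, 1.0])

weights = [0.2, 0.8]
total = sum(weights)
weights = [w / total for w in weights]  # already normalized here, but mirrors the service code

combined = torch.sum(torch.stack([bella * weights[0], kore * weights[1]]), dim=0)
print(combined)  # tensor([0.2000, 0.8000]) for "af_bella(0.2)+af_kore(0.8)"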
@@ -259,43 +280,237 @@ class TTSService:
            raise

    async def generate_audio(
        self, text: str, voice: str, speed: float = 1.0
    ) -> Tuple[np.ndarray, float]:
        self, text: str, voice: str, speed: float = 1.0, return_timestamps: bool = False
    ) -> Union[Tuple[np.ndarray, float], Tuple[np.ndarray, float, List[dict]]]:
        """Generate complete audio for text using streaming internally."""
        start_time = time.time()
        chunks = []
        word_timestamps = []

        try:
            # Use streaming generator but collect all valid chunks
            async for chunk in self.generate_audio_stream(
                text, voice, speed,  # Default to WAV for raw audio
            ):
                if chunk is not None:
                    chunks.append(chunk)
            # Get backend and voice path
            backend = self.model_manager.get_backend()
            voice_name, voice_path = await self._get_voice_path(voice)

            if not chunks:
                raise ValueError("No audio chunks were generated successfully")
            if isinstance(backend, KokoroV1):
                # Initialize quiet pipeline for text chunking
                quiet_pipeline = KPipeline(lang_code='a', model=False)

                # Split text into chunks and get initial tokens
                text_chunks = []
                current_offset = 0.0  # Track time offset for timestamps

                logger.debug("Splitting text into chunks...")
                for result in quiet_pipeline(text):
                    if result.graphemes and result.phonemes:
                        text_chunks.append((result.graphemes, result.phonemes))
                logger.debug(f"Split text into {len(text_chunks)} chunks")

                # Process each chunk
                for chunk_idx, (chunk_text, chunk_phonemes) in enumerate(text_chunks):
                    logger.debug(f"Processing chunk {chunk_idx + 1}/{len(text_chunks)}: '{chunk_text[:50]}...'")

                    # Generate audio and timestamps for this chunk
                    for result in backend._pipeline(
                        chunk_text,
                        voice=voice_path,
                        speed=speed,
                        model=backend._model
                    ):
                        # Collect audio chunks
                        if result.audio is not None:
                            chunks.append(result.audio.numpy())

                        # Process timestamps for this chunk
                        if return_timestamps and hasattr(result, 'tokens') and result.tokens:
                            logger.debug(f"Processing chunk timestamps with {len(result.tokens)} tokens")
                            if result.pred_dur is not None:
                                try:
                                    # Join timestamps for this chunk's tokens
                                    KPipeline.join_timestamps(result.tokens, result.pred_dur)

                                    # Add timestamps with offset
                                    for token in result.tokens:
                                        if not all(hasattr(token, attr) for attr in ['text', 'start_ts', 'end_ts']):
                                            continue
                                        if not token.text or not token.text.strip():
                                            continue

                                        # Apply offset to timestamps
                                        start_time = float(token.start_ts) + current_offset
                                        end_time = float(token.end_ts) + current_offset

                                        word_timestamps.append({
                                            'word': str(token.text).strip(),
                                            'start_time': start_time,
                                            'end_time': end_time
                                        })
                                        logger.debug(f"Added timestamp for word '{token.text}': {start_time:.3f}s - {end_time:.3f}s")

                                    # Update offset for next chunk based on pred_dur
                                    chunk_duration = float(result.pred_dur.sum()) / 80  # Convert frames to seconds
                                    current_offset = max(current_offset + chunk_duration, end_time)
                                    logger.debug(f"Updated time offset to {current_offset:.3f}s")

                                except Exception as e:
                                    logger.error(f"Failed to process timestamps for chunk: {e}")
                            logger.debug(f"Processing timestamps with pred_dur shape: {result.pred_dur.shape}")
                            try:
                                # Join timestamps for this chunk's tokens
                                KPipeline.join_timestamps(result.tokens, result.pred_dur)
                                logger.debug("Successfully joined timestamps for chunk")
                            except Exception as e:
                                logger.error(f"Failed to join timestamps for chunk: {e}")
                                continue

                            # Convert tokens to timestamps
                            for token in result.tokens:
                                try:
                                    # Skip tokens without required attributes
                                    if not all(hasattr(token, attr) for attr in ['text', 'start_ts', 'end_ts']):
                                        logger.debug(f"Skipping token missing attributes: {dir(token)}")
                                        continue

                                    # Get and validate text
                                    text = str(token.text).strip() if token.text is not None else ''
                                    if not text:
                                        logger.debug("Skipping empty token")
                                        continue

                                    # Get and validate timestamps
                                    start_ts = getattr(token, 'start_ts', None)
                                    end_ts = getattr(token, 'end_ts', None)
                                    if start_ts is None or end_ts is None:
                                        logger.debug(f"Skipping token with None timestamps: {text}")
                                        continue

                                    # Convert timestamps to float
                                    try:
                                        start_time = float(start_ts)
                                        end_time = float(end_ts)
                                    except (TypeError, ValueError):
                                        logger.debug(f"Skipping token with invalid timestamps: {text}")
                                        continue

                                    # Add timestamp
                                    word_timestamps.append({
                                        'word': text,
                                        'start_time': start_time,
                                        'end_time': end_time
                                    })
                                    logger.debug(f"Added timestamp for word '{text}': {start_time:.3f}s - {end_time:.3f}s")
                                except Exception as e:
                                    logger.warning(f"Error processing token: {e}")
                                    continue

                if not chunks:
                    raise ValueError("No audio chunks were generated successfully")

                # Combine chunks
                audio = np.concatenate(chunks) if len(chunks) > 1 else chunks[0]
                processing_time = time.time() - start_time

                if return_timestamps:
                    # Validate timestamps before returning
                    if not word_timestamps:
                        logger.warning("No valid timestamps were generated")
                    else:
                        # Sort timestamps by start time to ensure proper order
                        word_timestamps.sort(key=lambda x: x['start_time'])
                        # Validate timestamp sequence
                        for i in range(1, len(word_timestamps)):
                            prev = word_timestamps[i-1]
                            curr = word_timestamps[i]
                            if curr['start_time'] < prev['end_time']:
                                logger.warning(f"Overlapping timestamps detected: '{prev['word']}' ({prev['start_time']:.3f}-{prev['end_time']:.3f}) and '{curr['word']}' ({curr['start_time']:.3f}-{curr['end_time']:.3f})")

                        logger.debug(f"Returning {len(word_timestamps)} word timestamps")
                        logger.debug(f"First timestamp: {word_timestamps[0]['word']} at {word_timestamps[0]['start_time']:.3f}s")
                        logger.debug(f"Last timestamp: {word_timestamps[-1]['word']} at {word_timestamps[-1]['end_time']:.3f}s")

                    return audio, processing_time, word_timestamps
                return audio, processing_time

            # Combine chunks, ensuring we have valid arrays
            if len(chunks) == 1:
                audio = chunks[0]
            else:
                # Filter out any zero-dimensional arrays
                valid_chunks = [c for c in chunks if c.ndim > 0]
                if not valid_chunks:
                    raise ValueError("No valid audio chunks to concatenate")
                audio = np.concatenate(valid_chunks)
            processing_time = time.time() - start_time
            return audio, processing_time
            # For legacy backends
            async for chunk in self.generate_audio_stream(
                text, voice, speed,  # Default to WAV for raw audio
            ):
                if chunk is not None:
                    chunks.append(chunk)

            if not chunks:
                raise ValueError("No audio chunks were generated successfully")

            # Combine chunks
            audio = np.concatenate(chunks) if len(chunks) > 1 else chunks[0]
            processing_time = time.time() - start_time

            if return_timestamps:
                return audio, processing_time, []  # Empty timestamps for legacy backends
            return audio, processing_time

        except Exception as e:
            logger.error(f"Error in audio generation: {str(e)}")
            raise

    async def combine_voices(self, voices: List[str]) -> str:
        """Combine multiple voices."""
    async def combine_voices(self, voices: List[str]) -> torch.Tensor:
        """Combine multiple voices.

        Returns:
            Combined voice tensor
        """
        return await self._voice_manager.combine_voices(voices)

    async def list_voices(self) -> List[str]:
        """List available voices."""
        return await self._voice_manager.list_voices()

    async def generate_from_phonemes(
        self,
        phonemes: str,
        voice: str,
        speed: float = 1.0
    ) -> Tuple[np.ndarray, float]:
        """Generate audio directly from phonemes.

        Args:
            phonemes: Phonemes in Kokoro format
            voice: Voice name
            speed: Speed multiplier

        Returns:
            Tuple of (audio array, processing time)
        """
        start_time = time.time()
        try:
            # Get backend and voice path
            raise ValueError("Not yet implemented")
            # linked to https://github.com/hexgrad/kokoro/pull/53 or similar
            backend = self.model_manager.get_backend()
            voice_name, voice_path = await self._get_voice_path(voice)

            # if isinstance(backend, KokoroV1):
            #     # For Kokoro V1, pass phonemes directly to pipeline
            #     result = None
            #     for r in backend._pipeline(
            #         phonemes,
            #         voice=voice_path,
            #         speed=speed,
            #         model=backend._model
            #     ):
            #         if r.audio is not None:
            #             result = r
            #             break

            #     if result is None or result.audio is None:
            #         raise ValueError("No audio generated")

            #     processing_time = time.time() - start_time
            #     return result.audio.numpy(), processing_time
            # else:
            pass

        except Exception as e:
            logger.error(f"Error in phoneme audio generation: {str(e)}")
            raise
@@ -1,3 +1,17 @@
from .schemas import OpenAISpeechRequest
from .schemas import (
    OpenAISpeechRequest,
    CaptionedSpeechRequest,
    CaptionedSpeechResponse,
    WordTimestamp,
    TTSStatus,
    VoiceCombineRequest
)

__all__ = ["OpenAISpeechRequest"]
__all__ = [
    "OpenAISpeechRequest",
    "CaptionedSpeechRequest",
    "CaptionedSpeechResponse",
    "WordTimestamp",
    "TTSStatus",
    "VoiceCombineRequest"
]
@@ -22,7 +22,19 @@ class TTSStatus(str, Enum):


# OpenAI-compatible schemas
class WordTimestamp(BaseModel):
    """Word-level timestamp information"""
    word: str = Field(..., description="The word or token")
    start_time: float = Field(..., description="Start time in seconds")
    end_time: float = Field(..., description="End time in seconds")

class CaptionedSpeechResponse(BaseModel):
    """Response schema for captioned speech endpoint"""
    audio: bytes = Field(..., description="The generated audio data")
    words: List[WordTimestamp] = Field(..., description="Word-level timestamps")

class OpenAISpeechRequest(BaseModel):
    """Request schema for OpenAI-compatible speech endpoint"""
    model: str = Field(
        default="kokoro",
        description="The model to use for generation. Supported models: tts-1, tts-1-hd, kokoro"
@@ -50,3 +62,29 @@ class OpenAISpeechRequest(BaseModel):
        default=False,
        description="If true, returns a download link in X-Download-Path header after streaming completes",
    )

class CaptionedSpeechRequest(BaseModel):
    """Request schema for captioned speech endpoint"""
    model: str = Field(
        default="kokoro",
        description="The model to use for generation. Supported models: tts-1, tts-1-hd, kokoro"
    )
    input: str = Field(..., description="The text to generate audio for")
    voice: str = Field(
        default="af",
        description="The voice to use for generation. Can be a base voice or a combined voice name.",
    )
    response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field(
        default="mp3",
        description="The format to return audio in. Supported formats: mp3, opus, flac, wav, pcm. PCM format returns raw 16-bit samples without headers. AAC is not currently supported.",
    )
    speed: float = Field(
        default=1.0,
        ge=0.25,
        le=4.0,
        description="The speed of the generated audio. Select a value from 0.25 to 4.0.",
    )
    return_timestamps: bool = Field(
        default=True,
        description="If true (default), returns word-level timestamps in the response",
    )
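A request body that satisfies the CaptionedSpeechRequest schema above; the values are chosen only as an illustration.

# Values are illustrative; field names and defaults come from CaptionedSpeechRequest above
captioned_request = {
    "model": "kokoro",
    "input": "The quick brown fox jumps over the lazy dog.",
    "voice": "af_bella(0.5)+af_kore(0.5)",  # combined voice names with weights are accepted
    "response_format": "wav",
    "speed": 1.0,
    "return_timestamps": True,
}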
14 binary files changed (not shown).
@@ -92,15 +92,16 @@ def analyze_audio(filepath: str):
        'dominant_frequencies': dominant_freqs
    }

def plot_comparison(analyses, output_path):
    """Create comparison plot of the audio analyses."""
def plot_comparison(analyses, output_dir):
    """Create detailed comparison plots of the audio analyses."""
    plt.style.use('dark_background')
    fig = plt.figure(figsize=(15, 10))
    fig.patch.set_facecolor('#1a1a2e')

    # Plot waveforms
    fig_wave = plt.figure(figsize=(15, 10))
    fig_wave.patch.set_facecolor('#1a1a2e')

    for i, (name, data) in enumerate(analyses.items()):
        ax = plt.subplot(3, 1, i+1)
        ax = plt.subplot(len(analyses), 1, i+1)
        samples = data['samples']
        time = np.arange(len(samples)) / data['sample_rate']
        plt.plot(time, samples / data['max_amplitude'], linewidth=0.5, color='#ff2a6d')
@@ -112,27 +113,96 @@ def plot_comparison(analyses, output_path):
        plt.ylim(-1.1, 1.1)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"\nSaved comparison plot to {output_path}")
    plt.savefig(output_dir / 'waveforms.png', dpi=300, bbox_inches='tight')
    plt.close()

    # Plot spectral characteristics
    fig_spec = plt.figure(figsize=(15, 10))
    fig_spec.patch.set_facecolor('#1a1a2e')

    for i, (name, data) in enumerate(analyses.items()):
        # Calculate spectrogram
        samples = data['samples']
        sample_rate = data['sample_rate']
        nperseg = 2048
        f, t, Sxx = plt.mlab.specgram(samples, NFFT=2048, Fs=sample_rate,
                                      noverlap=nperseg//2, scale='dB')

        ax = plt.subplot(len(analyses), 1, i+1)
        plt.pcolormesh(t, f, Sxx, shading='gouraud', cmap='magma')
        plt.title(f"Spectrogram: {name}", color='white', pad=20)
        plt.ylabel('Frequency [Hz]', color='white')
        plt.xlabel('Time [sec]', color='white')
        plt.colorbar(label='Intensity [dB]')
        ax.set_facecolor('#1a1a2e')

    plt.tight_layout()
    plt.savefig(output_dir / 'spectrograms.png', dpi=300, bbox_inches='tight')
    plt.close()

    # Plot voice characteristics comparison
    fig_chars = plt.figure(figsize=(15, 8))
    fig_chars.patch.set_facecolor('#1a1a2e')

    # Extract characteristics
    names = list(analyses.keys())
    rms_values = [data['rms'] for data in analyses.values()]
    centroids = [data['spectral_centroid'] for data in analyses.values()]
    max_amps = [data['max_amplitude'] for data in analyses.values()]

    # Plot characteristics
    x = np.arange(len(names))
    width = 0.25

    ax = plt.subplot(111)
    ax.bar(x - width, rms_values, width, label='RMS (Texture)', color='#ff2a6d')
    ax.bar(x, [c/1000 for c in centroids], width, label='Spectral Centroid/1000 (Brightness)', color='#05d9e8')
    ax.bar(x + width, max_amps, width, label='Max Amplitude', color='#ff65bd')

    ax.set_xticks(x)
    ax.set_xticklabels(names, rotation=45, ha='right')
    ax.legend()
    ax.set_title('Voice Characteristics Comparison', color='white', pad=20)
    ax.set_facecolor('#1a1a2e')

    plt.tight_layout()
    plt.savefig(output_dir / 'characteristics.png', dpi=300, bbox_inches='tight')
    plt.close()

    print(f"\nSaved comparison plots to {output_dir}")

def main():
    # Generate audio for each voice
    # Test different voice combinations with weights
    voices = {
        'af_bella': output_dir / 'af_bella.wav',
        'af_irulan': output_dir / 'af_irulan.wav',
        'af_bella+af_irulan': output_dir / 'af_bella+af_irulan.wav'
        'af_kore': output_dir / 'af_kore.wav',
        'af_bella(0.2)+af_kore(0.8)': output_dir / 'af_bella_20_af_kore_80.wav',
        'af_bella(0.8)+af_kore(0.2)': output_dir / 'af_bella_80_af_kore_20.wav',
        'af_bella(0.5)+af_kore(0.5)': output_dir / 'af_bella_50_af_kore_50.wav'
    }

    # Generate audio for each voice/combination
    for voice, path in voices.items():
        generate_and_save_audio(voice, str(path))
        try:
            generate_and_save_audio(voice, str(path))
        except Exception as e:
            print(f"Error generating audio for {voice}: {e}")
            continue

    # Analyze each audio file
    analyses = {}
    for name, path in voices.items():
        analyses[name] = analyze_audio(str(path))
        try:
            analyses[name] = analyze_audio(str(path))
        except Exception as e:
            print(f"Error analyzing {name}: {e}")
            continue

    # Create comparison plot
    plot_comparison(analyses, output_dir / 'voice_comparison.png')
    # Create comparison plots
    if analyses:
        plot_comparison(analyses, output_dir)
    else:
        print("No analyses to plot")

if __name__ == "__main__":
    main()
@@ -0,0 +1,74 @@
#!/usr/bin/env python3
import os
from pathlib import Path
import requests

# Create output directory
output_dir = Path(__file__).parent / "output"
output_dir.mkdir(exist_ok=True)

def download_combined_voice(voice1: str, voice2: str, weights: tuple[float, float] = None) -> str:
    """Download a combined voice file.

    Args:
        voice1: First voice name
        voice2: Second voice name
        weights: Optional tuple of weights (w1, w2). If not provided, uses equal weights.

    Returns:
        Path to downloaded .pt file
    """
    print(f"\nDownloading combined voice: {voice1} + {voice2}")

    # Construct voice string with optional weights
    if weights:
        voice_str = f"{voice1}({weights[0]})+{voice2}({weights[1]})"
    else:
        voice_str = f"{voice1}+{voice2}"

    # Make the request to combine voices
    response = requests.post(
        "http://localhost:8880/v1/audio/voices/combine",
        json=voice_str
    )

    if response.status_code != 200:
        raise Exception(f"Failed to combine voices: {response.text}")

    # Save the .pt file
    output_path = output_dir / f"{voice_str}.pt"
    with open(output_path, "wb") as f:
        f.write(response.content)

    print(f"Saved combined voice to {output_path}")
    return str(output_path)

def main():
    # Test downloading various voice combinations
    combinations = [
        # Equal weights (default)
        ("af_bella", "af_kore"),

        # Different weight combinations
        ("af_bella", "af_kore", (0.2, 0.8)),
        ("af_bella", "af_kore", (0.8, 0.2)),
        ("af_bella", "af_kore", (0.5, 0.5)),

        # Test with different voices
        ("af_bella", "af_jadzia"),
        ("af_bella", "af_jadzia", (0.3, 0.7))
    ]

    for combo in combinations:
        try:
            if len(combo) == 3:
                voice1, voice2, weights = combo
                download_combined_voice(voice1, voice2, weights)
            else:
                voice1, voice2 = combo
                download_combined_voice(voice1, voice2)
        except Exception as e:
            print(f"Error downloading combination {combo}: {e}")

if __name__ == "__main__":
    main()
@@ -0,0 +1,54 @@
import os
import torch
from loguru import logger

def analyze_voice_file(file_path):
    """Analyze dimensions and statistics of a voice tensor."""
    try:
        tensor = torch.load(file_path, map_location="cpu")
        logger.info(f"\nAnalyzing {os.path.basename(file_path)}:")
        logger.info(f"Shape: {tensor.shape}")
        logger.info(f"Mean: {tensor.mean().item():.4f}")
        logger.info(f"Std: {tensor.std().item():.4f}")
        logger.info(f"Min: {tensor.min().item():.4f}")
        logger.info(f"Max: {tensor.max().item():.4f}")
        return tensor.shape
    except Exception as e:
        logger.error(f"Error analyzing {file_path}: {e}")
        return None

def main():
    """Analyze voice files in the voices directory."""
    # Get the project root directory
    current_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.dirname(os.path.dirname(os.path.dirname(current_dir)))
    voices_dir = os.path.join(project_root, "api", "src", "voices", "v1_0")

    logger.info(f"Scanning voices in: {voices_dir}")

    # Track shapes for comparison
    shapes = {}

    # Analyze each .pt file
    for file in os.listdir(voices_dir):
        if file.endswith('.pt'):
            file_path = os.path.join(voices_dir, file)
            shape = analyze_voice_file(file_path)
            if shape:
                shapes[file] = shape

    # Report findings
    logger.info("\nShape Analysis:")
    shape_groups = {}
    for file, shape in shapes.items():
        if shape not in shape_groups:
            shape_groups[shape] = []
        shape_groups[shape].append(file)

    for shape, files in shape_groups.items():
        logger.info(f"\nShape {shape}:")
        for file in files:
            logger.info(f"  - {file}")

if __name__ == "__main__":
    main()
@@ -0,0 +1,85 @@
import os
import torch
from loguru import logger

def analyze_voice_content(tensor):
    """Analyze the content distribution in the voice tensor."""
    # Look at the variance along the first dimension to see where the information is concentrated
    variance = torch.var(tensor, dim=(1,2))  # Variance across features
    logger.info(f"Variance distribution:")
    logger.info(f"First 5 rows variance: {variance[:5].mean().item():.6f}")
    logger.info(f"Last 5 rows variance: {variance[-5:].mean().item():.6f}")
    return variance

def trim_voice_tensor(tensor):
    """Trim a 511x1x256 tensor to 510x1x256 by removing the row with least impact."""
    if tensor.shape[0] != 511:
        raise ValueError(f"Expected tensor with first dimension 511, got {tensor.shape[0]}")

    # Analyze variance contribution of each row
    variance = analyze_voice_content(tensor)

    # Determine which end has lower variance (less information)
    start_var = variance[:5].mean().item()
    end_var = variance[-5:].mean().item()

    # Remove from the end with lower variance
    if end_var < start_var:
        logger.info("Trimming last row (lower variance at end)")
        return tensor[:-1]
    else:
        logger.info("Trimming first row (lower variance at start)")
        return tensor[1:]

def process_voice_file(file_path):
    """Process a single voice file."""
    try:
        tensor = torch.load(file_path, map_location="cpu")
        if tensor.shape[0] != 511:
            logger.info(f"Skipping {os.path.basename(file_path)} - already correct shape {tensor.shape}")
            return False

        logger.info(f"\nProcessing {os.path.basename(file_path)}:")
        logger.info(f"Original shape: {tensor.shape}")

        # Create backup
        backup_path = file_path + ".backup"
        if not os.path.exists(backup_path):
            torch.save(tensor, backup_path)
            logger.info(f"Created backup at {backup_path}")

        # Trim tensor
        trimmed = trim_voice_tensor(tensor)
        logger.info(f"New shape: {trimmed.shape}")

        # Save trimmed tensor
        torch.save(trimmed, file_path)
        logger.info(f"Saved trimmed tensor to {file_path}")

        return True
    except Exception as e:
        logger.error(f"Error processing {file_path}: {e}")
        return False

def main():
    """Process voice files in the voices directory."""
    # Get the project root directory
    current_dir = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.dirname(os.path.dirname(os.path.dirname(current_dir)))
    voices_dir = os.path.join(project_root, "api", "src", "voices", "v1_0")

    logger.info(f"Processing voices in: {voices_dir}")

    processed = 0
    for file in os.listdir(voices_dir):
        if file.endswith('.pt') and not file.endswith('.backup'):
            file_path = os.path.join(voices_dir, file)
            if process_voice_file(file_path):
                processed += 1

    logger.info(f"\nProcessed {processed} voice files")
    logger.info("Backups created with .backup extension")
    logger.info("To restore backups if needed, remove .backup extension to replace trimmed files")

if __name__ == "__main__":
    main()
examples/captioned_speech_example.py (new file, 98 lines)
@@ -0,0 +1,98 @@
import json
from typing import Tuple, Optional, Dict, List
from pathlib import Path

import requests

# Get the directory this script is in
SCRIPT_DIR = Path(__file__).absolute().parent

def generate_captioned_speech(
    text: str,
    voice: str = "af_bella",
    speed: float = 1.0,
    response_format: str = "wav"
) -> Tuple[Optional[bytes], Optional[List[Dict]]]:
    """Generate audio with word-level timestamps."""
    response = requests.post(
        "http://localhost:8880/dev/captioned_speech",
        json={
            "model": "kokoro",
            "input": text,
            "voice": voice,
            "speed": speed,
            "response_format": response_format
        }
    )

    print(f"Response status: {response.status_code}")
    print(f"Response headers: {dict(response.headers)}")

    if response.status_code != 200:
        print(f"Error response: {response.text}")
        return None, None

    try:
        # Get timestamps from header
        timestamps_json = response.headers.get('X-Word-Timestamps', '[]')
        word_timestamps = json.loads(timestamps_json)

        # Get audio bytes from content
        audio_bytes = response.content

        if not audio_bytes:
            print("Error: Empty audio content")
            return None, None

        return audio_bytes, word_timestamps
    except json.JSONDecodeError as e:
        print(f"Error parsing timestamps: {e}")
        return None, None

def main():
    # Example texts to convert
    examples = [
        "Hello world! Welcome to the captioned speech system.",
        "The quick brown fox jumps over the lazy dog.",
        # Third example in the original script: a much longer historical passage used as a
        # long-input stress test (not reproduced here)
    ]

    print("Generating captioned speech for example texts...\n")

    # Create output directory in same directory as script
    output_dir = SCRIPT_DIR / "output"
    output_dir.mkdir(exist_ok=True)

    for i, text in enumerate(examples):
        print(f"\nExample {i+1}:")
        print(f"Input text: {text}")
        try:
            # Generate audio and get timestamps
            audio_bytes, word_timestamps = generate_captioned_speech(text)

            if not audio_bytes or not word_timestamps:
                print("Error: No audio data or timestamps generated")
                continue

            # Save audio file
            audio_path = output_dir / f"captioned_example_{i+1}.wav"
            with audio_path.open("wb") as f:
                f.write(audio_bytes)
            print(f"Audio saved to: {audio_path}")

            # Save timestamps to JSON
            timestamps_path = output_dir / f"captioned_example_{i+1}_timestamps.json"
            with timestamps_path.open("w") as f:
                json.dump(word_timestamps, f, indent=2)
            print(f"Timestamps saved to: {timestamps_path}")

            # Print timestamps
            print("\nWord-level timestamps:")
            for ts in word_timestamps:
                print(f"{ts['word']}: {ts['start_time']:.3f}s - {ts['end_time']:.3f}s")

        except requests.RequestException as e:
            print(f"Error: {e}\n")

if __name__ == "__main__":
    main()
@@ -37,8 +37,8 @@ dependencies = [
    "semchunk>=3.0.1",
    "mutagen>=1.47.0",
    "psutil>=6.1.1",
    "kokoro==0.3.5",
    'misaki[en,ja,ko,zh,vi]==0.6.7',
    "kokoro==0.7.4",
    'misaki[en,ja,ko,zh,vi]==0.7.4',
    "spacy>=3.7.6",
    "en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl"
]
@@ -30,6 +30,20 @@ export class VoiceSelector {
            }
        });

        // Weight adjustment and voice removal
        this.elements.selectedVoices.addEventListener('input', (e) => {
            if (e.target.type === 'number') {
                const voice = e.target.dataset.voice;
                let weight = parseFloat(e.target.value);

                // Ensure weight is between 0.1 and 10
                weight = Math.max(0.1, Math.min(10, weight));
                e.target.value = weight;

                this.voiceService.updateWeight(voice, weight);
            }
        });

        // Remove selected voice
        this.elements.selectedVoices.addEventListener('click', (e) => {
            if (e.target.classList.contains('remove-voice')) {
@@ -73,12 +87,22 @@ export class VoiceSelector {
    }

    updateSelectedVoicesDisplay() {
        const selectedVoices = this.voiceService.getSelectedVoices();
        const selectedVoices = this.voiceService.getSelectedVoiceWeights();
        this.elements.selectedVoices.innerHTML = selectedVoices
            .map(voice => `
            .map(({voice, weight}) => `
                <span class="selected-voice-tag">
                    ${voice}
                    <span class="remove-voice" data-voice="${voice}">×</span>
                    <span class="voice-name">${voice}</span>
                    <span class="voice-weight">
                        <input type="number"
                               value="${weight}"
                               min="0.1"
                               max="10"
                               step="0.1"
                               data-voice="${voice}"
                               class="weight-input"
                               title="Voice weight (0.1 to 10)">
                    </span>
                    <span class="remove-voice" data-voice="${voice}" title="Remove voice">×</span>
                </span>
            `)
            .join('');
@@ -1,7 +1,7 @@
export class VoiceService {
    constructor() {
        this.availableVoices = [];
        this.selectedVoices = new Set();
        this.selectedVoices = new Map(); // Changed to Map to store voice:weight pairs
    }

    async loadVoices() {
@@ -39,16 +39,33 @@ export class VoiceService {
    }

    getSelectedVoices() {
        return Array.from(this.selectedVoices);
        return Array.from(this.selectedVoices.keys());
    }

    getSelectedVoiceWeights() {
        return Array.from(this.selectedVoices.entries()).map(([voice, weight]) => ({
            voice,
            weight
        }));
    }

    getSelectedVoiceString() {
        return Array.from(this.selectedVoices).join('+');
        return Array.from(this.selectedVoices.entries())
            .map(([voice, weight]) => `${voice}(${weight})`)
            .join('+');
    }

    addVoice(voice) {
    addVoice(voice, weight = 1) {
        if (this.availableVoices.includes(voice)) {
            this.selectedVoices.add(voice);
            this.selectedVoices.set(voice, parseFloat(weight) || 1);
            return true;
        }
        return false;
    }

    updateWeight(voice, weight) {
        if (this.selectedVoices.has(voice)) {
            this.selectedVoices.set(voice, parseFloat(weight) || 1);
            return true;
        }
        return false;