mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
Update dependencies, enhance voice management, and add captioned speech support
This commit is contained in:
parent
9198de2d95
commit
6c234a3b67
31 changed files with 979 additions and 169 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -69,3 +69,4 @@ examples/*.ogg
|
||||||
examples/speech.mp3
|
examples/speech.mp3
|
||||||
examples/phoneme_examples/output/*.wav
|
examples/phoneme_examples/output/*.wav
|
||||||
examples/assorted_checks/benchmarks/output_audio/*
|
examples/assorted_checks/benchmarks/output_audio/*
|
||||||
|
uv.lock
|
||||||
|
|
|
@ -1,4 +1,9 @@
|
||||||
"""Model configuration for Kokoro V1."""
|
"""Model configuration for Kokoro V1.
|
||||||
|
|
||||||
|
This module provides model-specific configuration settings that complement the application-level
|
||||||
|
settings in config.py. While config.py handles general application settings (API, paths, etc.),
|
||||||
|
this module focuses on memory management and model file paths.
|
||||||
|
"""
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
@ -9,51 +14,29 @@ class KokoroV1Config(BaseModel):
|
||||||
class Config:
|
class Config:
|
||||||
frozen = True
|
frozen = True
|
||||||
|
|
||||||
class PyTorchCPUConfig(BaseModel):
|
class PyTorchConfig(BaseModel):
|
||||||
"""PyTorch CPU backend configuration."""
|
"""PyTorch backend configuration."""
|
||||||
|
|
||||||
memory_threshold: float = Field(0.8, description="Memory threshold for cleanup")
|
memory_threshold: float = Field(0.8, description="Memory threshold for cleanup")
|
||||||
retry_on_oom: bool = Field(True, description="Whether to retry on OOM errors")
|
retry_on_oom: bool = Field(True, description="Whether to retry on OOM errors")
|
||||||
num_threads: int = Field(8, description="Number of threads for parallel operations")
|
|
||||||
pin_memory: bool = Field(True, description="Whether to pin memory for faster CPU-GPU transfer")
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
frozen = True
|
frozen = True
|
||||||
|
|
||||||
|
|
||||||
class PyTorchGPUConfig(BaseModel):
|
|
||||||
"""PyTorch GPU backend configuration."""
|
|
||||||
|
|
||||||
device_id: int = Field(0, description="CUDA device ID")
|
|
||||||
use_triton: bool = Field(True, description="Whether to use Triton for CUDA kernels")
|
|
||||||
memory_threshold: float = Field(0.8, description="Memory threshold for cleanup")
|
|
||||||
retry_on_oom: bool = Field(True, description="Whether to retry on OOM errors")
|
|
||||||
sync_cuda: bool = Field(True, description="Whether to synchronize CUDA operations")
|
|
||||||
cuda_streams: int = Field(2, description="Number of CUDA streams for inference")
|
|
||||||
stream_timeout: int = Field(60, description="Stream timeout in seconds")
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
frozen = True
|
|
||||||
|
|
||||||
|
|
||||||
class ModelConfig(BaseModel):
|
class ModelConfig(BaseModel):
|
||||||
"""Kokoro V1 model configuration."""
|
"""Kokoro V1 model configuration."""
|
||||||
|
|
||||||
# General settings
|
# General settings
|
||||||
device_type: str = Field("cpu", description="Device type ('cpu' or 'gpu')")
|
|
||||||
cache_voices: bool = Field(True, description="Whether to cache voice tensors")
|
cache_voices: bool = Field(True, description="Whether to cache voice tensors")
|
||||||
voice_cache_size: int = Field(2, description="Maximum number of cached voices")
|
voice_cache_size: int = Field(2, description="Maximum number of cached voices")
|
||||||
|
|
||||||
# Model filename
|
# Model filename
|
||||||
pytorch_kokoro_v1_file: str = Field("v1_0/kokoro-v1_0.pth", description="PyTorch Kokoro V1 model filename")
|
pytorch_kokoro_v1_file: str = Field("v1_0/kokoro-v1_0.pth", description="PyTorch Kokoro V1 model filename")
|
||||||
|
|
||||||
# Backend configs
|
# Backend config
|
||||||
pytorch_cpu: PyTorchCPUConfig = Field(default_factory=PyTorchCPUConfig)
|
pytorch_gpu: PyTorchConfig = Field(default_factory=PyTorchConfig)
|
||||||
pytorch_gpu: PyTorchGPUConfig = Field(default_factory=PyTorchGPUConfig)
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
frozen = True
|
frozen = True
|
||||||
|
|
||||||
|
|
||||||
# Global instance
|
# Global instance
|
||||||
model_config = ModelConfig()
|
model_config = ModelConfig()
|
|
@ -160,7 +160,7 @@ async def list_voices() -> List[str]:
|
||||||
return sorted([name[:-3] for name in voices]) # Remove .pt extension
|
return sorted([name[:-3] for name in voices]) # Remove .pt extension
|
||||||
|
|
||||||
|
|
||||||
async def load_voice_tensor(voice_path: str, device: str = "cpu") -> torch.Tensor:
|
async def load_voice_tensor(voice_path: str, device: str = "cpu", weights_only=False) -> torch.Tensor:
|
||||||
"""Load voice tensor from file.
|
"""Load voice tensor from file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -179,7 +179,7 @@ async def load_voice_tensor(voice_path: str, device: str = "cpu") -> torch.Tenso
|
||||||
return torch.load(
|
return torch.load(
|
||||||
io.BytesIO(data),
|
io.BytesIO(data),
|
||||||
map_location=device,
|
map_location=device,
|
||||||
weights_only=True
|
weights_only=weights_only
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Failed to load voice tensor from {voice_path}: {e}")
|
raise RuntimeError(f"Failed to load voice tensor from {voice_path}: {e}")
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import aiofiles
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from ..core import paths
|
from ..core import paths
|
||||||
|
@ -57,7 +58,7 @@ class VoiceManager:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Failed to load voice {voice_name}: {e}")
|
raise RuntimeError(f"Failed to load voice {voice_name}: {e}")
|
||||||
|
|
||||||
async def combine_voices(self, voices: List[str], device: Optional[str] = None) -> str:
|
async def combine_voices(self, voices: List[str], device: Optional[str] = None) -> torch.Tensor:
|
||||||
"""Combine multiple voices.
|
"""Combine multiple voices.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -65,7 +66,7 @@ class VoiceManager:
|
||||||
device: Optional override for target device
|
device: Optional override for target device
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Name of combined voice
|
Combined voice tensor
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
RuntimeError: If any voice not found
|
RuntimeError: If any voice not found
|
||||||
|
@ -80,10 +81,7 @@ class VoiceManager:
|
||||||
voice_tensors.append(voice)
|
voice_tensors.append(voice)
|
||||||
|
|
||||||
combined = torch.mean(torch.stack(voice_tensors), dim=0)
|
combined = torch.mean(torch.stack(voice_tensors), dim=0)
|
||||||
combined_name = "+".join(voices)
|
return combined
|
||||||
self._voices[combined_name] = combined
|
|
||||||
|
|
||||||
return combined_name
|
|
||||||
|
|
||||||
async def list_voices(self) -> List[str]:
|
async def list_voices(self) -> List[str]:
|
||||||
"""List available voice names.
|
"""List available voice names.
|
||||||
|
|
|
@ -8,14 +8,19 @@ from loguru import logger
|
||||||
|
|
||||||
from ..services.audio import AudioService, AudioNormalizer
|
from ..services.audio import AudioService, AudioNormalizer
|
||||||
from ..services.streaming_audio_writer import StreamingAudioWriter
|
from ..services.streaming_audio_writer import StreamingAudioWriter
|
||||||
from ..services.text_processing import phonemize, smart_split
|
from ..services.text_processing import smart_split
|
||||||
from ..services.text_processing.vocabulary import tokenize
|
from kokoro import KPipeline
|
||||||
from ..services.tts_service import TTSService
|
from ..services.tts_service import TTSService
|
||||||
from ..structures.text_schemas import (
|
from ..structures.text_schemas import (
|
||||||
GenerateFromPhonemesRequest,
|
GenerateFromPhonemesRequest,
|
||||||
PhonemeRequest,
|
PhonemeRequest,
|
||||||
PhonemeResponse,
|
PhonemeResponse,
|
||||||
)
|
)
|
||||||
|
from ..structures import (
|
||||||
|
CaptionedSpeechRequest,
|
||||||
|
CaptionedSpeechResponse,
|
||||||
|
WordTimestamp
|
||||||
|
)
|
||||||
|
|
||||||
router = APIRouter(tags=["text processing"])
|
router = APIRouter(tags=["text processing"])
|
||||||
|
|
||||||
|
@ -26,11 +31,10 @@ async def get_tts_service() -> TTSService:
|
||||||
|
|
||||||
@router.post("/dev/phonemize", response_model=PhonemeResponse)
|
@router.post("/dev/phonemize", response_model=PhonemeResponse)
|
||||||
async def phonemize_text(request: PhonemeRequest) -> PhonemeResponse:
|
async def phonemize_text(request: PhonemeRequest) -> PhonemeResponse:
|
||||||
"""Convert text to phonemes and tokens
|
"""Convert text to phonemes using Kokoro's quiet mode.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: Request containing text and language
|
request: Request containing text and language
|
||||||
tts_service: Injected TTSService instance
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Phonemes and token IDs
|
Phonemes and token IDs
|
||||||
|
@ -39,14 +43,17 @@ async def phonemize_text(request: PhonemeRequest) -> PhonemeResponse:
|
||||||
if not request.text:
|
if not request.text:
|
||||||
raise ValueError("Text cannot be empty")
|
raise ValueError("Text cannot be empty")
|
||||||
|
|
||||||
# Get phonemes
|
# Initialize Kokoro pipeline in quiet mode (no model)
|
||||||
phonemes = phonemize(request.text, request.language)
|
pipeline = KPipeline(lang_code=request.language, model=False)
|
||||||
if not phonemes:
|
|
||||||
raise ValueError("Failed to generate phonemes")
|
# Get first result from pipeline (we only need one since we're not chunking)
|
||||||
|
for result in pipeline(request.text):
|
||||||
|
# result.graphemes = original text
|
||||||
|
# result.phonemes = phonemized text
|
||||||
|
# result.tokens = token objects (if available)
|
||||||
|
return PhonemeResponse(phonemes=result.phonemes, tokens=[])
|
||||||
|
|
||||||
# Get tokens (without adding start/end tokens to match process_text behavior)
|
raise ValueError("Failed to generate phonemes")
|
||||||
tokens = tokenize(phonemes)
|
|
||||||
return PhonemeResponse(phonemes=phonemes, tokens=tokens)
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.error(f"Error in phoneme generation: {str(e)}")
|
logger.error(f"Error in phoneme generation: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
@ -63,7 +70,7 @@ async def generate_from_phonemes(
|
||||||
client_request: Request,
|
client_request: Request,
|
||||||
tts_service: TTSService = Depends(get_tts_service),
|
tts_service: TTSService = Depends(get_tts_service),
|
||||||
) -> StreamingResponse:
|
) -> StreamingResponse:
|
||||||
"""Generate audio directly from phonemes with proper streaming"""
|
"""Generate audio directly from phonemes using Kokoro's phoneme format"""
|
||||||
try:
|
try:
|
||||||
# Basic validation
|
# Basic validation
|
||||||
if not isinstance(request.phonemes, str):
|
if not isinstance(request.phonemes, str):
|
||||||
|
@ -77,41 +84,30 @@ async def generate_from_phonemes(
|
||||||
|
|
||||||
async def generate_chunks():
|
async def generate_chunks():
|
||||||
try:
|
try:
|
||||||
has_data = False
|
# Generate audio from phonemes
|
||||||
# Process phonemes in chunks
|
chunk_audio, _ = await tts_service.generate_from_phonemes(
|
||||||
async for chunk_text, _ in smart_split(request.phonemes):
|
phonemes=request.phonemes, # Pass complete phoneme string
|
||||||
# Check if client is still connected
|
voice=request.voice,
|
||||||
is_disconnected = client_request.is_disconnected
|
speed=1.0
|
||||||
if callable(is_disconnected):
|
)
|
||||||
is_disconnected = await is_disconnected()
|
|
||||||
if is_disconnected:
|
if chunk_audio is not None:
|
||||||
logger.info("Client disconnected, stopping audio generation")
|
# Normalize audio before writing
|
||||||
break
|
normalized_audio = await normalizer.normalize(chunk_audio)
|
||||||
|
# Write chunk and yield bytes
|
||||||
chunk_audio, _ = await tts_service.generate_from_phonemes(
|
chunk_bytes = writer.write_chunk(normalized_audio)
|
||||||
phonemes=chunk_text,
|
if chunk_bytes:
|
||||||
voice=request.voice,
|
yield chunk_bytes
|
||||||
speed=1.0
|
|
||||||
)
|
# Finalize and yield remaining bytes
|
||||||
if chunk_audio is not None:
|
|
||||||
has_data = True
|
|
||||||
# Normalize audio before writing
|
|
||||||
normalized_audio = await normalizer.normalize(chunk_audio)
|
|
||||||
# Write chunk and yield bytes
|
|
||||||
chunk_bytes = writer.write_chunk(normalized_audio)
|
|
||||||
if chunk_bytes:
|
|
||||||
yield chunk_bytes
|
|
||||||
|
|
||||||
if not has_data:
|
|
||||||
raise ValueError("Failed to generate any audio data")
|
|
||||||
|
|
||||||
# Finalize and yield remaining bytes if we still have a connection
|
|
||||||
if not (callable(is_disconnected) and await is_disconnected()):
|
|
||||||
final_bytes = writer.write_chunk(finalize=True)
|
final_bytes = writer.write_chunk(finalize=True)
|
||||||
if final_bytes:
|
if final_bytes:
|
||||||
yield final_bytes
|
yield final_bytes
|
||||||
|
else:
|
||||||
|
raise ValueError("Failed to generate audio data")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in audio chunk generation: {str(e)}")
|
logger.error(f"Error in audio generation: {str(e)}")
|
||||||
# Clean up writer on error
|
# Clean up writer on error
|
||||||
writer.write_chunk(finalize=True)
|
writer.write_chunk(finalize=True)
|
||||||
# Re-raise the original exception
|
# Re-raise the original exception
|
||||||
|
@ -128,7 +124,6 @@ async def generate_from_phonemes(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.error(f"Error generating audio: {str(e)}")
|
logger.error(f"Error generating audio: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
@ -149,3 +144,92 @@ async def generate_from_phonemes(
|
||||||
"type": "server_error"
|
"type": "server_error"
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@router.post("/dev/captioned_speech")
|
||||||
|
async def create_captioned_speech(
|
||||||
|
request: CaptionedSpeechRequest,
|
||||||
|
tts_service: TTSService = Depends(get_tts_service),
|
||||||
|
) -> StreamingResponse:
|
||||||
|
"""Generate audio with word-level timestamps using Kokoro's output"""
|
||||||
|
try:
|
||||||
|
# Get voice path
|
||||||
|
voice_name, voice_path = await tts_service._get_voice_path(request.voice)
|
||||||
|
|
||||||
|
# Generate audio with timestamps
|
||||||
|
audio, _, word_timestamps = await tts_service.generate_audio(
|
||||||
|
text=request.input,
|
||||||
|
voice=voice_name,
|
||||||
|
speed=request.speed,
|
||||||
|
return_timestamps=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create streaming audio writer
|
||||||
|
writer = StreamingAudioWriter(format=request.response_format, sample_rate=24000, channels=1)
|
||||||
|
normalizer = AudioNormalizer()
|
||||||
|
|
||||||
|
async def generate_chunks():
|
||||||
|
try:
|
||||||
|
if audio is not None:
|
||||||
|
# Normalize audio before writing
|
||||||
|
normalized_audio = await normalizer.normalize(audio)
|
||||||
|
# Write chunk and yield bytes
|
||||||
|
chunk_bytes = writer.write_chunk(normalized_audio)
|
||||||
|
if chunk_bytes:
|
||||||
|
yield chunk_bytes
|
||||||
|
|
||||||
|
# Finalize and yield remaining bytes
|
||||||
|
final_bytes = writer.write_chunk(finalize=True)
|
||||||
|
if final_bytes:
|
||||||
|
yield final_bytes
|
||||||
|
else:
|
||||||
|
raise ValueError("Failed to generate audio data")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in audio generation: {str(e)}")
|
||||||
|
# Clean up writer on error
|
||||||
|
writer.write_chunk(finalize=True)
|
||||||
|
# Re-raise the original exception
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Convert timestamps to JSON and add as header
|
||||||
|
import json
|
||||||
|
logger.debug(f"Processing {len(word_timestamps)} word timestamps")
|
||||||
|
timestamps_json = json.dumps([{
|
||||||
|
'word': str(ts['word']), # Ensure string for text
|
||||||
|
'start_time': float(ts['start_time']), # Ensure float for timestamps
|
||||||
|
'end_time': float(ts['end_time'])
|
||||||
|
} for ts in word_timestamps])
|
||||||
|
logger.debug(f"Generated timestamps JSON: {timestamps_json}")
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
generate_chunks(),
|
||||||
|
media_type=f"audio/{request.response_format}",
|
||||||
|
headers={
|
||||||
|
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Transfer-Encoding": "chunked",
|
||||||
|
"X-Word-Timestamps": timestamps_json
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(f"Error in captioned speech generation: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail={
|
||||||
|
"error": "validation_error",
|
||||||
|
"message": str(e),
|
||||||
|
"type": "invalid_request_error"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in captioned speech generation: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail={
|
||||||
|
"error": "processing_error",
|
||||||
|
"message": str(e),
|
||||||
|
"type": "server_error"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
|
@ -2,15 +2,19 @@
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import io
|
||||||
|
import tempfile
|
||||||
from typing import AsyncGenerator, Dict, List, Union
|
from typing import AsyncGenerator, Dict, List, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import aiofiles
|
||||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
|
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
|
||||||
from fastapi.responses import StreamingResponse, FileResponse
|
from fastapi.responses import StreamingResponse, FileResponse
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from ..services.audio import AudioService
|
from ..services.audio import AudioService
|
||||||
from ..services.tts_service import TTSService
|
from ..services.tts_service import TTSService
|
||||||
from ..structures.schemas import OpenAISpeechRequest
|
from ..structures import OpenAISpeechRequest
|
||||||
from ..core.config import settings
|
from ..core.config import settings
|
||||||
|
|
||||||
# Load OpenAI mappings
|
# Load OpenAI mappings
|
||||||
|
@ -72,55 +76,65 @@ def get_model_name(model: str) -> str:
|
||||||
async def process_voices(
|
async def process_voices(
|
||||||
voice_input: Union[str, List[str]], tts_service: TTSService
|
voice_input: Union[str, List[str]], tts_service: TTSService
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Process voice input into a combined voice, handling both string and list formats"""
|
"""Process voice input, handling both string and list formats
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Voice name to use (with weights if specified)
|
||||||
|
"""
|
||||||
# Convert input to list of voices
|
# Convert input to list of voices
|
||||||
if isinstance(voice_input, str):
|
if isinstance(voice_input, str):
|
||||||
# Check if it's an OpenAI voice name
|
# Check if it's an OpenAI voice name
|
||||||
mapped_voice = _openai_mappings["voices"].get(voice_input)
|
mapped_voice = _openai_mappings["voices"].get(voice_input)
|
||||||
if mapped_voice:
|
if mapped_voice:
|
||||||
voice_input = mapped_voice
|
voice_input = mapped_voice
|
||||||
voices = [v.strip() for v in voice_input.split("+") if v.strip()]
|
# Split on + but preserve any parentheses
|
||||||
|
voices = []
|
||||||
|
for part in voice_input.split("+"):
|
||||||
|
part = part.strip()
|
||||||
|
if not part:
|
||||||
|
continue
|
||||||
|
# Extract voice name without weight
|
||||||
|
voice_name = part.split("(")[0].strip()
|
||||||
|
# Check if it's a valid voice
|
||||||
|
available_voices = await tts_service.list_voices()
|
||||||
|
if voice_name not in available_voices:
|
||||||
|
raise ValueError(
|
||||||
|
f"Voice '{voice_name}' not found. Available voices: {', '.join(sorted(available_voices))}"
|
||||||
|
)
|
||||||
|
voices.append(part)
|
||||||
else:
|
else:
|
||||||
# For list input, map each voice if it's an OpenAI voice name
|
# For list input, map each voice if it's an OpenAI voice name
|
||||||
voices = [_openai_mappings["voices"].get(v, v) for v in voice_input]
|
voices = []
|
||||||
voices = [v.strip() for v in voices if v.strip()]
|
for v in voice_input:
|
||||||
|
mapped = _openai_mappings["voices"].get(v, v)
|
||||||
|
voice_name = mapped.split("(")[0].strip()
|
||||||
|
# Check if it's a valid voice
|
||||||
|
available_voices = await tts_service.list_voices()
|
||||||
|
if voice_name not in available_voices:
|
||||||
|
raise ValueError(
|
||||||
|
f"Voice '{voice_name}' not found. Available voices: {', '.join(sorted(available_voices))}"
|
||||||
|
)
|
||||||
|
voices.append(mapped)
|
||||||
|
|
||||||
if not voices:
|
if not voices:
|
||||||
raise ValueError("No voices provided")
|
raise ValueError("No voices provided")
|
||||||
|
|
||||||
# If single voice, validate and return it
|
# For multiple voices, combine them with +
|
||||||
if len(voices) == 1:
|
return "+".join(voices)
|
||||||
available_voices = await tts_service.list_voices()
|
|
||||||
if voices[0] not in available_voices:
|
|
||||||
raise ValueError(
|
|
||||||
f"Voice '{voices[0]}' not found. Available voices: {', '.join(sorted(available_voices))}"
|
|
||||||
)
|
|
||||||
return voices[0]
|
|
||||||
|
|
||||||
# For multiple voices, validate base voices exist
|
|
||||||
available_voices = await tts_service.list_voices()
|
|
||||||
for voice in voices:
|
|
||||||
if voice not in available_voices:
|
|
||||||
raise ValueError(
|
|
||||||
f"Base voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Combine voices
|
|
||||||
return await tts_service.combine_voices(voices=voices)
|
|
||||||
|
|
||||||
|
|
||||||
async def stream_audio_chunks(
|
async def stream_audio_chunks(
|
||||||
tts_service: TTSService,
|
tts_service: TTSService,
|
||||||
request: OpenAISpeechRequest,
|
request: OpenAISpeechRequest,
|
||||||
client_request: Request
|
client_request: Request
|
||||||
) -> AsyncGenerator[bytes, None]:
|
) -> AsyncGenerator[bytes, None]:
|
||||||
"""Stream audio chunks as they're generated with client disconnect handling"""
|
"""Stream audio chunks as they're generated with client disconnect handling"""
|
||||||
voice_to_use = await process_voices(request.voice, tts_service)
|
voice_name = await process_voices(request.voice, tts_service)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for chunk in tts_service.generate_audio_stream(
|
async for chunk in tts_service.generate_audio_stream(
|
||||||
text=request.input,
|
text=request.input,
|
||||||
voice=voice_to_use,
|
voice=voice_name,
|
||||||
speed=request.speed,
|
speed=request.speed,
|
||||||
output_format=request.response_format,
|
output_format=request.response_format,
|
||||||
):
|
):
|
||||||
|
@ -159,7 +173,7 @@ async def create_speech(
|
||||||
try:
|
try:
|
||||||
# model_name = get_model_name(request.model)
|
# model_name = get_model_name(request.model)
|
||||||
tts_service = await get_tts_service()
|
tts_service = await get_tts_service()
|
||||||
voice_to_use = await process_voices(request.voice, tts_service)
|
voice_name = await process_voices(request.voice, tts_service)
|
||||||
|
|
||||||
# Set content type based on format
|
# Set content type based on format
|
||||||
content_type = {
|
content_type = {
|
||||||
|
@ -237,7 +251,7 @@ async def create_speech(
|
||||||
# Generate complete audio using public interface
|
# Generate complete audio using public interface
|
||||||
audio, _ = await tts_service.generate_audio(
|
audio, _ = await tts_service.generate_audio(
|
||||||
text=request.input,
|
text=request.input,
|
||||||
voice=voice_to_use,
|
voice=voice_name,
|
||||||
speed=request.speed
|
speed=request.speed
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -350,14 +364,14 @@ async def list_voices():
|
||||||
|
|
||||||
@router.post("/audio/voices/combine")
|
@router.post("/audio/voices/combine")
|
||||||
async def combine_voices(request: Union[str, List[str]]):
|
async def combine_voices(request: Union[str, List[str]]):
|
||||||
"""Combine multiple voices into a new voice.
|
"""Combine multiple voices into a new voice and return the .pt file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: Either a string with voices separated by + (e.g. "voice1+voice2")
|
request: Either a string with voices separated by + (e.g. "voice1+voice2")
|
||||||
or a list of voice names to combine
|
or a list of voice names to combine
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict with combined voice name and list of all available voices
|
FileResponse with the combined voice .pt file
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
HTTPException:
|
HTTPException:
|
||||||
|
@ -365,10 +379,51 @@ async def combine_voices(request: Union[str, List[str]]):
|
||||||
- 500: Server error (file system issues, combination failed)
|
- 500: Server error (file system issues, combination failed)
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# Convert input to list of voices
|
||||||
|
if isinstance(request, str):
|
||||||
|
# Check if it's an OpenAI voice name
|
||||||
|
mapped_voice = _openai_mappings["voices"].get(request)
|
||||||
|
if mapped_voice:
|
||||||
|
request = mapped_voice
|
||||||
|
voices = [v.strip() for v in request.split("+") if v.strip()]
|
||||||
|
else:
|
||||||
|
# For list input, map each voice if it's an OpenAI voice name
|
||||||
|
voices = [_openai_mappings["voices"].get(v, v) for v in request]
|
||||||
|
voices = [v.strip() for v in voices if v.strip()]
|
||||||
|
|
||||||
|
if not voices:
|
||||||
|
raise ValueError("No voices provided")
|
||||||
|
|
||||||
|
# For multiple voices, validate base voices exist
|
||||||
tts_service = await get_tts_service()
|
tts_service = await get_tts_service()
|
||||||
combined_voice = await process_voices(request, tts_service)
|
available_voices = await tts_service.list_voices()
|
||||||
voices = await tts_service.list_voices()
|
for voice in voices:
|
||||||
return {"voices": voices, "voice": combined_voice}
|
if voice not in available_voices:
|
||||||
|
raise ValueError(
|
||||||
|
f"Base voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Combine voices
|
||||||
|
combined_tensor = await tts_service.combine_voices(voices=voices)
|
||||||
|
combined_name = "+".join(voices)
|
||||||
|
|
||||||
|
# Save to temp file
|
||||||
|
temp_dir = tempfile.gettempdir()
|
||||||
|
voice_path = os.path.join(temp_dir, f"{combined_name}.pt")
|
||||||
|
buffer = io.BytesIO()
|
||||||
|
torch.save(combined_tensor, buffer)
|
||||||
|
async with aiofiles.open(voice_path, 'wb') as f:
|
||||||
|
await f.write(buffer.getvalue())
|
||||||
|
|
||||||
|
return FileResponse(
|
||||||
|
voice_path,
|
||||||
|
media_type="application/octet-stream",
|
||||||
|
filename=f"{combined_name}.pt",
|
||||||
|
headers={
|
||||||
|
"Content-Disposition": f"attachment; filename={combined_name}.pt",
|
||||||
|
"Cache-Control": "no-cache"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.warning(f"Invalid voice combination request: {str(e)}")
|
logger.warning(f"Invalid voice combination request: {str(e)}")
|
||||||
|
|
|
@ -17,6 +17,7 @@ from .audio import AudioNormalizer, AudioService
|
||||||
from .text_processing.text_processor import process_text_chunk, smart_split
|
from .text_processing.text_processor import process_text_chunk, smart_split
|
||||||
from .text_processing import tokenize
|
from .text_processing import tokenize
|
||||||
from ..inference.kokoro_v1 import KokoroV1
|
from ..inference.kokoro_v1 import KokoroV1
|
||||||
|
from kokoro import KPipeline
|
||||||
|
|
||||||
|
|
||||||
class TTSService:
|
class TTSService:
|
||||||
|
@ -154,23 +155,43 @@ class TTSService:
|
||||||
try:
|
try:
|
||||||
# Check if it's a combined voice
|
# Check if it's a combined voice
|
||||||
if "+" in voice:
|
if "+" in voice:
|
||||||
voices = [v.strip() for v in voice.split("+") if v.strip()]
|
# Split on + but preserve any parentheses
|
||||||
if len(voices) < 2:
|
voice_parts = []
|
||||||
|
weights = []
|
||||||
|
for part in voice.split("+"):
|
||||||
|
part = part.strip()
|
||||||
|
if not part:
|
||||||
|
continue
|
||||||
|
# Extract voice name and weight if present
|
||||||
|
if "(" in part and ")" in part:
|
||||||
|
voice_name = part.split("(")[0].strip()
|
||||||
|
weight = float(part.split("(")[1].split(")")[0])
|
||||||
|
else:
|
||||||
|
voice_name = part
|
||||||
|
weight = 1.0
|
||||||
|
voice_parts.append(voice_name)
|
||||||
|
weights.append(weight)
|
||||||
|
|
||||||
|
if len(voice_parts) < 2:
|
||||||
raise RuntimeError(f"Invalid combined voice name: {voice}")
|
raise RuntimeError(f"Invalid combined voice name: {voice}")
|
||||||
|
|
||||||
|
# Normalize weights to sum to 1
|
||||||
|
total_weight = sum(weights)
|
||||||
|
weights = [w/total_weight for w in weights]
|
||||||
|
|
||||||
# Load and combine voices
|
# Load and combine voices
|
||||||
voice_tensors = []
|
voice_tensors = []
|
||||||
for v in voices:
|
for v, w in zip(voice_parts, weights):
|
||||||
path = await self._voice_manager.get_voice_path(v)
|
path = await self._voice_manager.get_voice_path(v)
|
||||||
if not path:
|
if not path:
|
||||||
raise RuntimeError(f"Voice not found: {v}")
|
raise RuntimeError(f"Voice not found: {v}")
|
||||||
logger.debug(f"Loading voice tensor from: {path}")
|
logger.debug(f"Loading voice tensor from: {path}")
|
||||||
voice_tensor = torch.load(path, map_location="cpu")
|
voice_tensor = torch.load(path, map_location="cpu")
|
||||||
voice_tensors.append(voice_tensor)
|
voice_tensors.append(voice_tensor * w)
|
||||||
|
|
||||||
# Average the voice tensors
|
# Sum the weighted voice tensors
|
||||||
logger.debug(f"Combining {len(voice_tensors)} voice tensors")
|
logger.debug(f"Combining {len(voice_tensors)} voice tensors with weights {weights}")
|
||||||
combined = torch.mean(torch.stack(voice_tensors), dim=0)
|
combined = torch.sum(torch.stack(voice_tensors), dim=0)
|
||||||
|
|
||||||
# Save combined tensor
|
# Save combined tensor
|
||||||
temp_dir = tempfile.gettempdir()
|
temp_dir = tempfile.gettempdir()
|
||||||
|
@ -259,43 +280,237 @@ class TTSService:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def generate_audio(
|
async def generate_audio(
|
||||||
self, text: str, voice: str, speed: float = 1.0
|
self, text: str, voice: str, speed: float = 1.0, return_timestamps: bool = False
|
||||||
) -> Tuple[np.ndarray, float]:
|
) -> Union[Tuple[np.ndarray, float], Tuple[np.ndarray, float, List[dict]]]:
|
||||||
"""Generate complete audio for text using streaming internally."""
|
"""Generate complete audio for text using streaming internally."""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
chunks = []
|
chunks = []
|
||||||
|
word_timestamps = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Use streaming generator but collect all valid chunks
|
# Get backend and voice path
|
||||||
async for chunk in self.generate_audio_stream(
|
backend = self.model_manager.get_backend()
|
||||||
text, voice, speed, # Default to WAV for raw audio
|
voice_name, voice_path = await self._get_voice_path(voice)
|
||||||
):
|
|
||||||
if chunk is not None:
|
|
||||||
chunks.append(chunk)
|
|
||||||
|
|
||||||
if not chunks:
|
if isinstance(backend, KokoroV1):
|
||||||
raise ValueError("No audio chunks were generated successfully")
|
# Initialize quiet pipeline for text chunking
|
||||||
|
quiet_pipeline = KPipeline(lang_code='a', model=False)
|
||||||
|
|
||||||
|
# Split text into chunks and get initial tokens
|
||||||
|
text_chunks = []
|
||||||
|
current_offset = 0.0 # Track time offset for timestamps
|
||||||
|
|
||||||
|
logger.debug("Splitting text into chunks...")
|
||||||
|
for result in quiet_pipeline(text):
|
||||||
|
if result.graphemes and result.phonemes:
|
||||||
|
text_chunks.append((result.graphemes, result.phonemes))
|
||||||
|
logger.debug(f"Split text into {len(text_chunks)} chunks")
|
||||||
|
|
||||||
|
# Process each chunk
|
||||||
|
for chunk_idx, (chunk_text, chunk_phonemes) in enumerate(text_chunks):
|
||||||
|
logger.debug(f"Processing chunk {chunk_idx + 1}/{len(text_chunks)}: '{chunk_text[:50]}...'")
|
||||||
|
|
||||||
|
# Generate audio and timestamps for this chunk
|
||||||
|
for result in backend._pipeline(
|
||||||
|
chunk_text,
|
||||||
|
voice=voice_path,
|
||||||
|
speed=speed,
|
||||||
|
model=backend._model
|
||||||
|
):
|
||||||
|
# Collect audio chunks
|
||||||
|
if result.audio is not None:
|
||||||
|
chunks.append(result.audio.numpy())
|
||||||
|
|
||||||
|
# Process timestamps for this chunk
|
||||||
|
if return_timestamps and hasattr(result, 'tokens') and result.tokens:
|
||||||
|
logger.debug(f"Processing chunk timestamps with {len(result.tokens)} tokens")
|
||||||
|
if result.pred_dur is not None:
|
||||||
|
try:
|
||||||
|
# Join timestamps for this chunk's tokens
|
||||||
|
KPipeline.join_timestamps(result.tokens, result.pred_dur)
|
||||||
|
|
||||||
|
# Add timestamps with offset
|
||||||
|
for token in result.tokens:
|
||||||
|
if not all(hasattr(token, attr) for attr in ['text', 'start_ts', 'end_ts']):
|
||||||
|
continue
|
||||||
|
if not token.text or not token.text.strip():
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Apply offset to timestamps
|
||||||
|
start_time = float(token.start_ts) + current_offset
|
||||||
|
end_time = float(token.end_ts) + current_offset
|
||||||
|
|
||||||
|
word_timestamps.append({
|
||||||
|
'word': str(token.text).strip(),
|
||||||
|
'start_time': start_time,
|
||||||
|
'end_time': end_time
|
||||||
|
})
|
||||||
|
logger.debug(f"Added timestamp for word '{token.text}': {start_time:.3f}s - {end_time:.3f}s")
|
||||||
|
|
||||||
|
# Update offset for next chunk based on pred_dur
|
||||||
|
chunk_duration = float(result.pred_dur.sum()) / 80 # Convert frames to seconds
|
||||||
|
current_offset = max(current_offset + chunk_duration, end_time)
|
||||||
|
logger.debug(f"Updated time offset to {current_offset:.3f}s")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to process timestamps for chunk: {e}")
|
||||||
|
logger.debug(f"Processing timestamps with pred_dur shape: {result.pred_dur.shape}")
|
||||||
|
try:
|
||||||
|
# Join timestamps for this chunk's tokens
|
||||||
|
KPipeline.join_timestamps(result.tokens, result.pred_dur)
|
||||||
|
logger.debug("Successfully joined timestamps for chunk")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to join timestamps for chunk: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Convert tokens to timestamps
|
||||||
|
for token in result.tokens:
|
||||||
|
try:
|
||||||
|
# Skip tokens without required attributes
|
||||||
|
if not all(hasattr(token, attr) for attr in ['text', 'start_ts', 'end_ts']):
|
||||||
|
logger.debug(f"Skipping token missing attributes: {dir(token)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get and validate text
|
||||||
|
text = str(token.text).strip() if token.text is not None else ''
|
||||||
|
if not text:
|
||||||
|
logger.debug("Skipping empty token")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get and validate timestamps
|
||||||
|
start_ts = getattr(token, 'start_ts', None)
|
||||||
|
end_ts = getattr(token, 'end_ts', None)
|
||||||
|
if start_ts is None or end_ts is None:
|
||||||
|
logger.debug(f"Skipping token with None timestamps: {text}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Convert timestamps to float
|
||||||
|
try:
|
||||||
|
start_time = float(start_ts)
|
||||||
|
end_time = float(end_ts)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
logger.debug(f"Skipping token with invalid timestamps: {text}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Add timestamp
|
||||||
|
word_timestamps.append({
|
||||||
|
'word': text,
|
||||||
|
'start_time': start_time,
|
||||||
|
'end_time': end_time
|
||||||
|
})
|
||||||
|
logger.debug(f"Added timestamp for word '{text}': {start_time:.3f}s - {end_time:.3f}s")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error processing token: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not chunks:
|
||||||
|
raise ValueError("No audio chunks were generated successfully")
|
||||||
|
|
||||||
|
# Combine chunks
|
||||||
|
audio = np.concatenate(chunks) if len(chunks) > 1 else chunks[0]
|
||||||
|
processing_time = time.time() - start_time
|
||||||
|
|
||||||
|
if return_timestamps:
|
||||||
|
# Validate timestamps before returning
|
||||||
|
if not word_timestamps:
|
||||||
|
logger.warning("No valid timestamps were generated")
|
||||||
|
else:
|
||||||
|
# Sort timestamps by start time to ensure proper order
|
||||||
|
word_timestamps.sort(key=lambda x: x['start_time'])
|
||||||
|
# Validate timestamp sequence
|
||||||
|
for i in range(1, len(word_timestamps)):
|
||||||
|
prev = word_timestamps[i-1]
|
||||||
|
curr = word_timestamps[i]
|
||||||
|
if curr['start_time'] < prev['end_time']:
|
||||||
|
logger.warning(f"Overlapping timestamps detected: '{prev['word']}' ({prev['start_time']:.3f}-{prev['end_time']:.3f}) and '{curr['word']}' ({curr['start_time']:.3f}-{curr['end_time']:.3f})")
|
||||||
|
|
||||||
|
logger.debug(f"Returning {len(word_timestamps)} word timestamps")
|
||||||
|
logger.debug(f"First timestamp: {word_timestamps[0]['word']} at {word_timestamps[0]['start_time']:.3f}s")
|
||||||
|
logger.debug(f"Last timestamp: {word_timestamps[-1]['word']} at {word_timestamps[-1]['end_time']:.3f}s")
|
||||||
|
|
||||||
|
return audio, processing_time, word_timestamps
|
||||||
|
return audio, processing_time
|
||||||
|
|
||||||
# Combine chunks, ensuring we have valid arrays
|
|
||||||
if len(chunks) == 1:
|
|
||||||
audio = chunks[0]
|
|
||||||
else:
|
else:
|
||||||
# Filter out any zero-dimensional arrays
|
# For legacy backends
|
||||||
valid_chunks = [c for c in chunks if c.ndim > 0]
|
async for chunk in self.generate_audio_stream(
|
||||||
if not valid_chunks:
|
text, voice, speed, # Default to WAV for raw audio
|
||||||
raise ValueError("No valid audio chunks to concatenate")
|
):
|
||||||
audio = np.concatenate(valid_chunks)
|
if chunk is not None:
|
||||||
processing_time = time.time() - start_time
|
chunks.append(chunk)
|
||||||
return audio, processing_time
|
|
||||||
|
if not chunks:
|
||||||
|
raise ValueError("No audio chunks were generated successfully")
|
||||||
|
|
||||||
|
# Combine chunks
|
||||||
|
audio = np.concatenate(chunks) if len(chunks) > 1 else chunks[0]
|
||||||
|
processing_time = time.time() - start_time
|
||||||
|
|
||||||
|
if return_timestamps:
|
||||||
|
return audio, processing_time, [] # Empty timestamps for legacy backends
|
||||||
|
return audio, processing_time
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in audio generation: {str(e)}")
|
logger.error(f"Error in audio generation: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def combine_voices(self, voices: List[str]) -> str:
|
async def combine_voices(self, voices: List[str]) -> torch.Tensor:
|
||||||
"""Combine multiple voices."""
|
"""Combine multiple voices.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Combined voice tensor
|
||||||
|
"""
|
||||||
return await self._voice_manager.combine_voices(voices)
|
return await self._voice_manager.combine_voices(voices)
|
||||||
|
|
||||||
async def list_voices(self) -> List[str]:
|
async def list_voices(self) -> List[str]:
|
||||||
"""List available voices."""
|
"""List available voices."""
|
||||||
return await self._voice_manager.list_voices()
|
return await self._voice_manager.list_voices()
|
||||||
|
|
||||||
|
async def generate_from_phonemes(
|
||||||
|
self,
|
||||||
|
phonemes: str,
|
||||||
|
voice: str,
|
||||||
|
speed: float = 1.0
|
||||||
|
) -> Tuple[np.ndarray, float]:
|
||||||
|
"""Generate audio directly from phonemes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
phonemes: Phonemes in Kokoro format
|
||||||
|
voice: Voice name
|
||||||
|
speed: Speed multiplier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (audio array, processing time)
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
try:
|
||||||
|
# Get backend and voice path
|
||||||
|
raise ValueError("Not yet implemented")
|
||||||
|
# linked to https://github.com/hexgrad/kokoro/pull/53 or similiar
|
||||||
|
backend = self.model_manager.get_backend()
|
||||||
|
voice_name, voice_path = await self._get_voice_path(voice)
|
||||||
|
|
||||||
|
# if isinstance(backend, KokoroV1):
|
||||||
|
# # For Kokoro V1, pass phonemes directly to pipeline
|
||||||
|
# result = None
|
||||||
|
# for r in backend._pipeline(
|
||||||
|
# phonemes,
|
||||||
|
# voice=voice_path,
|
||||||
|
# speed=speed,
|
||||||
|
# model=backend._model
|
||||||
|
# ):
|
||||||
|
# if r.audio is not None:
|
||||||
|
# result = r
|
||||||
|
# break
|
||||||
|
|
||||||
|
# if result is None or result.audio is None:
|
||||||
|
# raise ValueError("No audio generated")
|
||||||
|
|
||||||
|
# processing_time = time.time() - start_time
|
||||||
|
# return result.audio.numpy(), processing_time
|
||||||
|
# else:
|
||||||
|
pass
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in phoneme audio generation: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
|
@ -1,3 +1,17 @@
|
||||||
from .schemas import OpenAISpeechRequest
|
from .schemas import (
|
||||||
|
OpenAISpeechRequest,
|
||||||
|
CaptionedSpeechRequest,
|
||||||
|
CaptionedSpeechResponse,
|
||||||
|
WordTimestamp,
|
||||||
|
TTSStatus,
|
||||||
|
VoiceCombineRequest
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = ["OpenAISpeechRequest"]
|
__all__ = [
|
||||||
|
"OpenAISpeechRequest",
|
||||||
|
"CaptionedSpeechRequest",
|
||||||
|
"CaptionedSpeechResponse",
|
||||||
|
"WordTimestamp",
|
||||||
|
"TTSStatus",
|
||||||
|
"VoiceCombineRequest"
|
||||||
|
]
|
||||||
|
|
|
@ -22,7 +22,19 @@ class TTSStatus(str, Enum):
|
||||||
|
|
||||||
|
|
||||||
# OpenAI-compatible schemas
|
# OpenAI-compatible schemas
|
||||||
|
class WordTimestamp(BaseModel):
|
||||||
|
"""Word-level timestamp information"""
|
||||||
|
word: str = Field(..., description="The word or token")
|
||||||
|
start_time: float = Field(..., description="Start time in seconds")
|
||||||
|
end_time: float = Field(..., description="End time in seconds")
|
||||||
|
|
||||||
|
class CaptionedSpeechResponse(BaseModel):
|
||||||
|
"""Response schema for captioned speech endpoint"""
|
||||||
|
audio: bytes = Field(..., description="The generated audio data")
|
||||||
|
words: List[WordTimestamp] = Field(..., description="Word-level timestamps")
|
||||||
|
|
||||||
class OpenAISpeechRequest(BaseModel):
|
class OpenAISpeechRequest(BaseModel):
|
||||||
|
"""Request schema for OpenAI-compatible speech endpoint"""
|
||||||
model: str = Field(
|
model: str = Field(
|
||||||
default="kokoro",
|
default="kokoro",
|
||||||
description="The model to use for generation. Supported models: tts-1, tts-1-hd, kokoro"
|
description="The model to use for generation. Supported models: tts-1, tts-1-hd, kokoro"
|
||||||
|
@ -50,3 +62,29 @@ class OpenAISpeechRequest(BaseModel):
|
||||||
default=False,
|
default=False,
|
||||||
description="If true, returns a download link in X-Download-Path header after streaming completes",
|
description="If true, returns a download link in X-Download-Path header after streaming completes",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
class CaptionedSpeechRequest(BaseModel):
|
||||||
|
"""Request schema for captioned speech endpoint"""
|
||||||
|
model: str = Field(
|
||||||
|
default="kokoro",
|
||||||
|
description="The model to use for generation. Supported models: tts-1, tts-1-hd, kokoro"
|
||||||
|
)
|
||||||
|
input: str = Field(..., description="The text to generate audio for")
|
||||||
|
voice: str = Field(
|
||||||
|
default="af",
|
||||||
|
description="The voice to use for generation. Can be a base voice or a combined voice name.",
|
||||||
|
)
|
||||||
|
response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field(
|
||||||
|
default="mp3",
|
||||||
|
description="The format to return audio in. Supported formats: mp3, opus, flac, wav, pcm. PCM format returns raw 16-bit samples without headers. AAC is not currently supported.",
|
||||||
|
)
|
||||||
|
speed: float = Field(
|
||||||
|
default=1.0,
|
||||||
|
ge=0.25,
|
||||||
|
le=4.0,
|
||||||
|
description="The speed of the generated audio. Select a value from 0.25 to 4.0.",
|
||||||
|
)
|
||||||
|
return_timestamps: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="If true (default), returns word-level timestamps in the response",
|
||||||
|
)
|
||||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -92,15 +92,16 @@ def analyze_audio(filepath: str):
|
||||||
'dominant_frequencies': dominant_freqs
|
'dominant_frequencies': dominant_freqs
|
||||||
}
|
}
|
||||||
|
|
||||||
def plot_comparison(analyses, output_path):
|
def plot_comparison(analyses, output_dir):
|
||||||
"""Create comparison plot of the audio analyses."""
|
"""Create detailed comparison plots of the audio analyses."""
|
||||||
plt.style.use('dark_background')
|
plt.style.use('dark_background')
|
||||||
fig = plt.figure(figsize=(15, 10))
|
|
||||||
fig.patch.set_facecolor('#1a1a2e')
|
|
||||||
|
|
||||||
# Plot waveforms
|
# Plot waveforms
|
||||||
|
fig_wave = plt.figure(figsize=(15, 10))
|
||||||
|
fig_wave.patch.set_facecolor('#1a1a2e')
|
||||||
|
|
||||||
for i, (name, data) in enumerate(analyses.items()):
|
for i, (name, data) in enumerate(analyses.items()):
|
||||||
ax = plt.subplot(3, 1, i+1)
|
ax = plt.subplot(len(analyses), 1, i+1)
|
||||||
samples = data['samples']
|
samples = data['samples']
|
||||||
time = np.arange(len(samples)) / data['sample_rate']
|
time = np.arange(len(samples)) / data['sample_rate']
|
||||||
plt.plot(time, samples / data['max_amplitude'], linewidth=0.5, color='#ff2a6d')
|
plt.plot(time, samples / data['max_amplitude'], linewidth=0.5, color='#ff2a6d')
|
||||||
|
@ -112,27 +113,96 @@ def plot_comparison(analyses, output_path):
|
||||||
plt.ylim(-1.1, 1.1)
|
plt.ylim(-1.1, 1.1)
|
||||||
|
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plt.savefig(output_path, dpi=300, bbox_inches='tight')
|
plt.savefig(output_dir / 'waveforms.png', dpi=300, bbox_inches='tight')
|
||||||
print(f"\nSaved comparison plot to {output_path}")
|
plt.close()
|
||||||
|
|
||||||
|
# Plot spectral characteristics
|
||||||
|
fig_spec = plt.figure(figsize=(15, 10))
|
||||||
|
fig_spec.patch.set_facecolor('#1a1a2e')
|
||||||
|
|
||||||
|
for i, (name, data) in enumerate(analyses.items()):
|
||||||
|
# Calculate spectrogram
|
||||||
|
samples = data['samples']
|
||||||
|
sample_rate = data['sample_rate']
|
||||||
|
nperseg = 2048
|
||||||
|
f, t, Sxx = plt.mlab.specgram(samples, NFFT=2048, Fs=sample_rate,
|
||||||
|
noverlap=nperseg//2, scale='dB')
|
||||||
|
|
||||||
|
ax = plt.subplot(len(analyses), 1, i+1)
|
||||||
|
plt.pcolormesh(t, f, Sxx, shading='gouraud', cmap='magma')
|
||||||
|
plt.title(f"Spectrogram: {name}", color='white', pad=20)
|
||||||
|
plt.ylabel('Frequency [Hz]', color='white')
|
||||||
|
plt.xlabel('Time [sec]', color='white')
|
||||||
|
plt.colorbar(label='Intensity [dB]')
|
||||||
|
ax.set_facecolor('#1a1a2e')
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(output_dir / 'spectrograms.png', dpi=300, bbox_inches='tight')
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
# Plot voice characteristics comparison
|
||||||
|
fig_chars = plt.figure(figsize=(15, 8))
|
||||||
|
fig_chars.patch.set_facecolor('#1a1a2e')
|
||||||
|
|
||||||
|
# Extract characteristics
|
||||||
|
names = list(analyses.keys())
|
||||||
|
rms_values = [data['rms'] for data in analyses.values()]
|
||||||
|
centroids = [data['spectral_centroid'] for data in analyses.values()]
|
||||||
|
max_amps = [data['max_amplitude'] for data in analyses.values()]
|
||||||
|
|
||||||
|
# Plot characteristics
|
||||||
|
x = np.arange(len(names))
|
||||||
|
width = 0.25
|
||||||
|
|
||||||
|
ax = plt.subplot(111)
|
||||||
|
ax.bar(x - width, rms_values, width, label='RMS (Texture)', color='#ff2a6d')
|
||||||
|
ax.bar(x, [c/1000 for c in centroids], width, label='Spectral Centroid/1000 (Brightness)', color='#05d9e8')
|
||||||
|
ax.bar(x + width, max_amps, width, label='Max Amplitude', color='#ff65bd')
|
||||||
|
|
||||||
|
ax.set_xticks(x)
|
||||||
|
ax.set_xticklabels(names, rotation=45, ha='right')
|
||||||
|
ax.legend()
|
||||||
|
ax.set_title('Voice Characteristics Comparison', color='white', pad=20)
|
||||||
|
ax.set_facecolor('#1a1a2e')
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(output_dir / 'characteristics.png', dpi=300, bbox_inches='tight')
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
print(f"\nSaved comparison plots to {output_dir}")
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# Generate audio for each voice
|
# Test different voice combinations with weights
|
||||||
voices = {
|
voices = {
|
||||||
'af_bella': output_dir / 'af_bella.wav',
|
'af_bella': output_dir / 'af_bella.wav',
|
||||||
'af_irulan': output_dir / 'af_irulan.wav',
|
'af_kore': output_dir / 'af_kore.wav',
|
||||||
'af_bella+af_irulan': output_dir / 'af_bella+af_irulan.wav'
|
'af_bella(0.2)+af_kore(0.8)': output_dir / 'af_bella_20_af_kore_80.wav',
|
||||||
|
'af_bella(0.8)+af_kore(0.2)': output_dir / 'af_bella_80_af_kore_20.wav',
|
||||||
|
'af_bella(0.5)+af_kore(0.5)': output_dir / 'af_bella_50_af_kore_50.wav'
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Generate audio for each voice/combination
|
||||||
for voice, path in voices.items():
|
for voice, path in voices.items():
|
||||||
generate_and_save_audio(voice, str(path))
|
try:
|
||||||
|
generate_and_save_audio(voice, str(path))
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error generating audio for {voice}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
# Analyze each audio file
|
# Analyze each audio file
|
||||||
analyses = {}
|
analyses = {}
|
||||||
for name, path in voices.items():
|
for name, path in voices.items():
|
||||||
analyses[name] = analyze_audio(str(path))
|
try:
|
||||||
|
analyses[name] = analyze_audio(str(path))
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error analyzing {name}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
# Create comparison plot
|
# Create comparison plots
|
||||||
plot_comparison(analyses, output_dir / 'voice_comparison.png')
|
if analyses:
|
||||||
|
plot_comparison(analyses, output_dir)
|
||||||
|
else:
|
||||||
|
print("No analyses to plot")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -0,0 +1,74 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
import requests
|
||||||
|
|
||||||
|
# Create output directory
|
||||||
|
output_dir = Path(__file__).parent / "output"
|
||||||
|
output_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
def download_combined_voice(voice1: str, voice2: str, weights: tuple[float, float] = None) -> str:
|
||||||
|
"""Download a combined voice file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
voice1: First voice name
|
||||||
|
voice2: Second voice name
|
||||||
|
weights: Optional tuple of weights (w1, w2). If not provided, uses equal weights.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to downloaded .pt file
|
||||||
|
"""
|
||||||
|
print(f"\nDownloading combined voice: {voice1} + {voice2}")
|
||||||
|
|
||||||
|
# Construct voice string with optional weights
|
||||||
|
if weights:
|
||||||
|
voice_str = f"{voice1}({weights[0]})+{voice2}({weights[1]})"
|
||||||
|
else:
|
||||||
|
voice_str = f"{voice1}+{voice2}"
|
||||||
|
|
||||||
|
# Make the request to combine voices
|
||||||
|
response = requests.post(
|
||||||
|
"http://localhost:8880/v1/audio/voices/combine",
|
||||||
|
json=voice_str
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise Exception(f"Failed to combine voices: {response.text}")
|
||||||
|
|
||||||
|
# Save the .pt file
|
||||||
|
output_path = output_dir / f"{voice_str}.pt"
|
||||||
|
with open(output_path, "wb") as f:
|
||||||
|
f.write(response.content)
|
||||||
|
|
||||||
|
print(f"Saved combined voice to {output_path}")
|
||||||
|
return str(output_path)
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Test downloading various voice combinations
|
||||||
|
combinations = [
|
||||||
|
# Equal weights (default)
|
||||||
|
("af_bella", "af_kore"),
|
||||||
|
|
||||||
|
# Different weight combinations
|
||||||
|
("af_bella", "af_kore", (0.2, 0.8)),
|
||||||
|
("af_bella", "af_kore", (0.8, 0.2)),
|
||||||
|
("af_bella", "af_kore", (0.5, 0.5)),
|
||||||
|
|
||||||
|
# Test with different voices
|
||||||
|
("af_bella", "af_jadzia"),
|
||||||
|
("af_bella", "af_jadzia", (0.3, 0.7))
|
||||||
|
]
|
||||||
|
|
||||||
|
for combo in combinations:
|
||||||
|
try:
|
||||||
|
if len(combo) == 3:
|
||||||
|
voice1, voice2, weights = combo
|
||||||
|
download_combined_voice(voice1, voice2, weights)
|
||||||
|
else:
|
||||||
|
voice1, voice2 = combo
|
||||||
|
download_combined_voice(voice1, voice2)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error downloading combination {combo}: {e}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -0,0 +1,54 @@
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
def analyze_voice_file(file_path):
|
||||||
|
"""Analyze dimensions and statistics of a voice tensor."""
|
||||||
|
try:
|
||||||
|
tensor = torch.load(file_path, map_location="cpu")
|
||||||
|
logger.info(f"\nAnalyzing {os.path.basename(file_path)}:")
|
||||||
|
logger.info(f"Shape: {tensor.shape}")
|
||||||
|
logger.info(f"Mean: {tensor.mean().item():.4f}")
|
||||||
|
logger.info(f"Std: {tensor.std().item():.4f}")
|
||||||
|
logger.info(f"Min: {tensor.min().item():.4f}")
|
||||||
|
logger.info(f"Max: {tensor.max().item():.4f}")
|
||||||
|
return tensor.shape
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error analyzing {file_path}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Analyze voice files in the voices directory."""
|
||||||
|
# Get the project root directory
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
project_root = os.path.dirname(os.path.dirname(os.path.dirname(current_dir)))
|
||||||
|
voices_dir = os.path.join(project_root, "api", "src", "voices", "v1_0")
|
||||||
|
|
||||||
|
logger.info(f"Scanning voices in: {voices_dir}")
|
||||||
|
|
||||||
|
# Track shapes for comparison
|
||||||
|
shapes = {}
|
||||||
|
|
||||||
|
# Analyze each .pt file
|
||||||
|
for file in os.listdir(voices_dir):
|
||||||
|
if file.endswith('.pt'):
|
||||||
|
file_path = os.path.join(voices_dir, file)
|
||||||
|
shape = analyze_voice_file(file_path)
|
||||||
|
if shape:
|
||||||
|
shapes[file] = shape
|
||||||
|
|
||||||
|
# Report findings
|
||||||
|
logger.info("\nShape Analysis:")
|
||||||
|
shape_groups = {}
|
||||||
|
for file, shape in shapes.items():
|
||||||
|
if shape not in shape_groups:
|
||||||
|
shape_groups[shape] = []
|
||||||
|
shape_groups[shape].append(file)
|
||||||
|
|
||||||
|
for shape, files in shape_groups.items():
|
||||||
|
logger.info(f"\nShape {shape}:")
|
||||||
|
for file in files:
|
||||||
|
logger.info(f" - {file}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -0,0 +1,85 @@
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
def analyze_voice_content(tensor):
|
||||||
|
"""Analyze the content distribution in the voice tensor."""
|
||||||
|
# Look at the variance along the first dimension to see where the information is concentrated
|
||||||
|
variance = torch.var(tensor, dim=(1,2)) # Variance across features
|
||||||
|
logger.info(f"Variance distribution:")
|
||||||
|
logger.info(f"First 5 rows variance: {variance[:5].mean().item():.6f}")
|
||||||
|
logger.info(f"Last 5 rows variance: {variance[-5:].mean().item():.6f}")
|
||||||
|
return variance
|
||||||
|
|
||||||
|
def trim_voice_tensor(tensor):
|
||||||
|
"""Trim a 511x1x256 tensor to 510x1x256 by removing the row with least impact."""
|
||||||
|
if tensor.shape[0] != 511:
|
||||||
|
raise ValueError(f"Expected tensor with first dimension 511, got {tensor.shape[0]}")
|
||||||
|
|
||||||
|
# Analyze variance contribution of each row
|
||||||
|
variance = analyze_voice_content(tensor)
|
||||||
|
|
||||||
|
# Determine which end has lower variance (less information)
|
||||||
|
start_var = variance[:5].mean().item()
|
||||||
|
end_var = variance[-5:].mean().item()
|
||||||
|
|
||||||
|
# Remove from the end with lower variance
|
||||||
|
if end_var < start_var:
|
||||||
|
logger.info("Trimming last row (lower variance at end)")
|
||||||
|
return tensor[:-1]
|
||||||
|
else:
|
||||||
|
logger.info("Trimming first row (lower variance at start)")
|
||||||
|
return tensor[1:]
|
||||||
|
|
||||||
|
def process_voice_file(file_path):
|
||||||
|
"""Process a single voice file."""
|
||||||
|
try:
|
||||||
|
tensor = torch.load(file_path, map_location="cpu")
|
||||||
|
if tensor.shape[0] != 511:
|
||||||
|
logger.info(f"Skipping {os.path.basename(file_path)} - already correct shape {tensor.shape}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
logger.info(f"\nProcessing {os.path.basename(file_path)}:")
|
||||||
|
logger.info(f"Original shape: {tensor.shape}")
|
||||||
|
|
||||||
|
# Create backup
|
||||||
|
backup_path = file_path + ".backup"
|
||||||
|
if not os.path.exists(backup_path):
|
||||||
|
torch.save(tensor, backup_path)
|
||||||
|
logger.info(f"Created backup at {backup_path}")
|
||||||
|
|
||||||
|
# Trim tensor
|
||||||
|
trimmed = trim_voice_tensor(tensor)
|
||||||
|
logger.info(f"New shape: {trimmed.shape}")
|
||||||
|
|
||||||
|
# Save trimmed tensor
|
||||||
|
torch.save(trimmed, file_path)
|
||||||
|
logger.info(f"Saved trimmed tensor to {file_path}")
|
||||||
|
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing {file_path}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Process voice files in the voices directory."""
|
||||||
|
# Get the project root directory
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
project_root = os.path.dirname(os.path.dirname(os.path.dirname(current_dir)))
|
||||||
|
voices_dir = os.path.join(project_root, "api", "src", "voices", "v1_0")
|
||||||
|
|
||||||
|
logger.info(f"Processing voices in: {voices_dir}")
|
||||||
|
|
||||||
|
processed = 0
|
||||||
|
for file in os.listdir(voices_dir):
|
||||||
|
if file.endswith('.pt') and not file.endswith('.backup'):
|
||||||
|
file_path = os.path.join(voices_dir, file)
|
||||||
|
if process_voice_file(file_path):
|
||||||
|
processed += 1
|
||||||
|
|
||||||
|
logger.info(f"\nProcessed {processed} voice files")
|
||||||
|
logger.info("Backups created with .backup extension")
|
||||||
|
logger.info("To restore backups if needed, remove .backup extension to replace trimmed files")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
98
examples/captioned_speech_example.py
Normal file
98
examples/captioned_speech_example.py
Normal file
|
@ -0,0 +1,98 @@
|
||||||
|
import json
|
||||||
|
from typing import Tuple, Optional, Dict, List
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
# Get the directory this script is in
|
||||||
|
SCRIPT_DIR = Path(__file__).absolute().parent
|
||||||
|
|
||||||
|
def generate_captioned_speech(
|
||||||
|
text: str,
|
||||||
|
voice: str = "af_bella",
|
||||||
|
speed: float = 1.0,
|
||||||
|
response_format: str = "wav"
|
||||||
|
) -> Tuple[Optional[bytes], Optional[List[Dict]]]:
|
||||||
|
"""Generate audio with word-level timestamps."""
|
||||||
|
response = requests.post(
|
||||||
|
"http://localhost:8880/dev/captioned_speech",
|
||||||
|
json={
|
||||||
|
"model": "kokoro",
|
||||||
|
"input": text,
|
||||||
|
"voice": voice,
|
||||||
|
"speed": speed,
|
||||||
|
"response_format": response_format
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Response status: {response.status_code}")
|
||||||
|
print(f"Response headers: {dict(response.headers)}")
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
print(f"Error response: {response.text}")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get timestamps from header
|
||||||
|
timestamps_json = response.headers.get('X-Word-Timestamps', '[]')
|
||||||
|
word_timestamps = json.loads(timestamps_json)
|
||||||
|
|
||||||
|
# Get audio bytes from content
|
||||||
|
audio_bytes = response.content
|
||||||
|
|
||||||
|
if not audio_bytes:
|
||||||
|
print("Error: Empty audio content")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
return audio_bytes, word_timestamps
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
print(f"Error parsing timestamps: {e}")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Example texts to convert
|
||||||
|
examples = [
|
||||||
|
"Hello world! Welcome to the captioned speech system.",
|
||||||
|
"The quick brown fox jumps over the lazy dog.",
|
||||||
|
"""If you have access to a room where gasoline is stored, remember that gas vapor accumulating in a closed room will explode after a time if you leave a candle burning in the room. A good deal of evaporation, however, must occur from the gasoline tins into the air of the room. If removal of the tops of the tins does not expose enough gasoline to the air to ensure copious evaporation, you can open lightly constructed tins further with a knife, ice pick or sharpened nail file. Or puncture a tiny hole in the tank which will permit gasoline to leak out on the floor. This will greatly increase the rate of evaporation. Before you light your candle, be sure that windows are closed and the room is as air-tight as you can make it. If you can see that windows in a neighboring room are opened wide, you have a chance of setting a large fire which will not only destroy the gasoline but anything else nearby; when the gasoline explodes, the doors of the storage room will be blown open, a draft to the neighboring windows will be created which will whip up a fine conflagration"""
|
||||||
|
]
|
||||||
|
|
||||||
|
print("Generating captioned speech for example texts...\n")
|
||||||
|
|
||||||
|
# Create output directory in same directory as script
|
||||||
|
output_dir = SCRIPT_DIR / "output"
|
||||||
|
output_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
for i, text in enumerate(examples):
|
||||||
|
print(f"\nExample {i+1}:")
|
||||||
|
print(f"Input text: {text}")
|
||||||
|
try:
|
||||||
|
# Generate audio and get timestamps
|
||||||
|
audio_bytes, word_timestamps = generate_captioned_speech(text)
|
||||||
|
|
||||||
|
if not audio_bytes or not word_timestamps:
|
||||||
|
print("Error: No audio data or timestamps generated")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Save audio file
|
||||||
|
audio_path = output_dir / f"captioned_example_{i+1}.wav"
|
||||||
|
with audio_path.open("wb") as f:
|
||||||
|
f.write(audio_bytes)
|
||||||
|
print(f"Audio saved to: {audio_path}")
|
||||||
|
|
||||||
|
# Save timestamps to JSON
|
||||||
|
timestamps_path = output_dir / f"captioned_example_{i+1}_timestamps.json"
|
||||||
|
with timestamps_path.open("w") as f:
|
||||||
|
json.dump(word_timestamps, f, indent=2)
|
||||||
|
print(f"Timestamps saved to: {timestamps_path}")
|
||||||
|
|
||||||
|
# Print timestamps
|
||||||
|
print("\nWord-level timestamps:")
|
||||||
|
for ts in word_timestamps:
|
||||||
|
print(f"{ts['word']}: {ts['start_time']:.3f}s - {ts['end_time']:.3f}s")
|
||||||
|
|
||||||
|
except requests.RequestException as e:
|
||||||
|
print(f"Error: {e}\n")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -37,8 +37,8 @@ dependencies = [
|
||||||
"semchunk>=3.0.1",
|
"semchunk>=3.0.1",
|
||||||
"mutagen>=1.47.0",
|
"mutagen>=1.47.0",
|
||||||
"psutil>=6.1.1",
|
"psutil>=6.1.1",
|
||||||
"kokoro==0.3.5",
|
"kokoro==0.7.4",
|
||||||
'misaki[en,ja,ko,zh,vi]==0.6.7',
|
'misaki[en,ja,ko,zh,vi]==0.7.4',
|
||||||
"spacy>=3.7.6",
|
"spacy>=3.7.6",
|
||||||
"en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl"
|
"en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl"
|
||||||
]
|
]
|
||||||
|
|
|
@ -30,6 +30,20 @@ export class VoiceSelector {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Weight adjustment and voice removal
|
||||||
|
this.elements.selectedVoices.addEventListener('input', (e) => {
|
||||||
|
if (e.target.type === 'number') {
|
||||||
|
const voice = e.target.dataset.voice;
|
||||||
|
let weight = parseFloat(e.target.value);
|
||||||
|
|
||||||
|
// Ensure weight is between 0.1 and 10
|
||||||
|
weight = Math.max(0.1, Math.min(10, weight));
|
||||||
|
e.target.value = weight;
|
||||||
|
|
||||||
|
this.voiceService.updateWeight(voice, weight);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
// Remove selected voice
|
// Remove selected voice
|
||||||
this.elements.selectedVoices.addEventListener('click', (e) => {
|
this.elements.selectedVoices.addEventListener('click', (e) => {
|
||||||
if (e.target.classList.contains('remove-voice')) {
|
if (e.target.classList.contains('remove-voice')) {
|
||||||
|
@ -73,12 +87,22 @@ export class VoiceSelector {
|
||||||
}
|
}
|
||||||
|
|
||||||
updateSelectedVoicesDisplay() {
|
updateSelectedVoicesDisplay() {
|
||||||
const selectedVoices = this.voiceService.getSelectedVoices();
|
const selectedVoices = this.voiceService.getSelectedVoiceWeights();
|
||||||
this.elements.selectedVoices.innerHTML = selectedVoices
|
this.elements.selectedVoices.innerHTML = selectedVoices
|
||||||
.map(voice => `
|
.map(({voice, weight}) => `
|
||||||
<span class="selected-voice-tag">
|
<span class="selected-voice-tag">
|
||||||
${voice}
|
<span class="voice-name">${voice}</span>
|
||||||
<span class="remove-voice" data-voice="${voice}">×</span>
|
<span class="voice-weight">
|
||||||
|
<input type="number"
|
||||||
|
value="${weight}"
|
||||||
|
min="0.1"
|
||||||
|
max="10"
|
||||||
|
step="0.1"
|
||||||
|
data-voice="${voice}"
|
||||||
|
class="weight-input"
|
||||||
|
title="Voice weight (0.1 to 10)">
|
||||||
|
</span>
|
||||||
|
<span class="remove-voice" data-voice="${voice}" title="Remove voice">×</span>
|
||||||
</span>
|
</span>
|
||||||
`)
|
`)
|
||||||
.join('');
|
.join('');
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
export class VoiceService {
|
export class VoiceService {
|
||||||
constructor() {
|
constructor() {
|
||||||
this.availableVoices = [];
|
this.availableVoices = [];
|
||||||
this.selectedVoices = new Set();
|
this.selectedVoices = new Map(); // Changed to Map to store voice:weight pairs
|
||||||
}
|
}
|
||||||
|
|
||||||
async loadVoices() {
|
async loadVoices() {
|
||||||
|
@ -39,16 +39,33 @@ export class VoiceService {
|
||||||
}
|
}
|
||||||
|
|
||||||
getSelectedVoices() {
|
getSelectedVoices() {
|
||||||
return Array.from(this.selectedVoices);
|
return Array.from(this.selectedVoices.keys());
|
||||||
|
}
|
||||||
|
|
||||||
|
getSelectedVoiceWeights() {
|
||||||
|
return Array.from(this.selectedVoices.entries()).map(([voice, weight]) => ({
|
||||||
|
voice,
|
||||||
|
weight
|
||||||
|
}));
|
||||||
}
|
}
|
||||||
|
|
||||||
getSelectedVoiceString() {
|
getSelectedVoiceString() {
|
||||||
return Array.from(this.selectedVoices).join('+');
|
return Array.from(this.selectedVoices.entries())
|
||||||
|
.map(([voice, weight]) => `${voice}(${weight})`)
|
||||||
|
.join('+');
|
||||||
}
|
}
|
||||||
|
|
||||||
addVoice(voice) {
|
addVoice(voice, weight = 1) {
|
||||||
if (this.availableVoices.includes(voice)) {
|
if (this.availableVoices.includes(voice)) {
|
||||||
this.selectedVoices.add(voice);
|
this.selectedVoices.set(voice, parseFloat(weight) || 1);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
updateWeight(voice, weight) {
|
||||||
|
if (this.selectedVoices.has(voice)) {
|
||||||
|
this.selectedVoices.set(voice, parseFloat(weight) || 1);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
|
|
Loading…
Add table
Reference in a new issue