Update dependencies, enhance voice management, and add captioned speech support

This commit is contained in:
remsky 2025-02-04 19:41:41 -07:00
parent 9198de2d95
commit 6c234a3b67
31 changed files with 979 additions and 169 deletions

1
.gitignore vendored
View file

@ -69,3 +69,4 @@ examples/*.ogg
examples/speech.mp3
examples/phoneme_examples/output/*.wav
examples/assorted_checks/benchmarks/output_audio/*
uv.lock

View file

@ -1,4 +1,9 @@
"""Model configuration for Kokoro V1."""
"""Model configuration for Kokoro V1.
This module provides model-specific configuration settings that complement the application-level
settings in config.py. While config.py handles general application settings (API, paths, etc.),
this module focuses on memory management and model file paths.
"""
from pydantic import BaseModel, Field
@ -9,51 +14,29 @@ class KokoroV1Config(BaseModel):
class Config:
frozen = True
class PyTorchCPUConfig(BaseModel):
"""PyTorch CPU backend configuration."""
class PyTorchConfig(BaseModel):
"""PyTorch backend configuration."""
memory_threshold: float = Field(0.8, description="Memory threshold for cleanup")
retry_on_oom: bool = Field(True, description="Whether to retry on OOM errors")
num_threads: int = Field(8, description="Number of threads for parallel operations")
pin_memory: bool = Field(True, description="Whether to pin memory for faster CPU-GPU transfer")
class Config:
frozen = True
class PyTorchGPUConfig(BaseModel):
"""PyTorch GPU backend configuration."""
device_id: int = Field(0, description="CUDA device ID")
use_triton: bool = Field(True, description="Whether to use Triton for CUDA kernels")
memory_threshold: float = Field(0.8, description="Memory threshold for cleanup")
retry_on_oom: bool = Field(True, description="Whether to retry on OOM errors")
sync_cuda: bool = Field(True, description="Whether to synchronize CUDA operations")
cuda_streams: int = Field(2, description="Number of CUDA streams for inference")
stream_timeout: int = Field(60, description="Stream timeout in seconds")
class Config:
frozen = True
class ModelConfig(BaseModel):
"""Kokoro V1 model configuration."""
# General settings
device_type: str = Field("cpu", description="Device type ('cpu' or 'gpu')")
cache_voices: bool = Field(True, description="Whether to cache voice tensors")
voice_cache_size: int = Field(2, description="Maximum number of cached voices")
# Model filename
pytorch_kokoro_v1_file: str = Field("v1_0/kokoro-v1_0.pth", description="PyTorch Kokoro V1 model filename")
# Backend configs
pytorch_cpu: PyTorchCPUConfig = Field(default_factory=PyTorchCPUConfig)
pytorch_gpu: PyTorchGPUConfig = Field(default_factory=PyTorchGPUConfig)
# Backend config
pytorch_gpu: PyTorchConfig = Field(default_factory=PyTorchConfig)
class Config:
frozen = True
# Global instance
model_config = ModelConfig()

View file

@ -160,7 +160,7 @@ async def list_voices() -> List[str]:
return sorted([name[:-3] for name in voices]) # Remove .pt extension
async def load_voice_tensor(voice_path: str, device: str = "cpu") -> torch.Tensor:
async def load_voice_tensor(voice_path: str, device: str = "cpu", weights_only=False) -> torch.Tensor:
"""Load voice tensor from file.
Args:
@ -179,7 +179,7 @@ async def load_voice_tensor(voice_path: str, device: str = "cpu") -> torch.Tenso
return torch.load(
io.BytesIO(data),
map_location=device,
weights_only=True
weights_only=weights_only
)
except Exception as e:
raise RuntimeError(f"Failed to load voice tensor from {voice_path}: {e}")

View file

@ -3,6 +3,7 @@
from typing import Dict, List, Optional
import torch
import aiofiles
from loguru import logger
from ..core import paths
@ -57,7 +58,7 @@ class VoiceManager:
except Exception as e:
raise RuntimeError(f"Failed to load voice {voice_name}: {e}")
async def combine_voices(self, voices: List[str], device: Optional[str] = None) -> str:
async def combine_voices(self, voices: List[str], device: Optional[str] = None) -> torch.Tensor:
"""Combine multiple voices.
Args:
@ -65,7 +66,7 @@ class VoiceManager:
device: Optional override for target device
Returns:
Name of combined voice
Combined voice tensor
Raises:
RuntimeError: If any voice not found
@ -80,10 +81,7 @@ class VoiceManager:
voice_tensors.append(voice)
combined = torch.mean(torch.stack(voice_tensors), dim=0)
combined_name = "+".join(voices)
self._voices[combined_name] = combined
return combined_name
return combined
async def list_voices(self) -> List[str]:
"""List available voice names.

View file

@ -8,14 +8,19 @@ from loguru import logger
from ..services.audio import AudioService, AudioNormalizer
from ..services.streaming_audio_writer import StreamingAudioWriter
from ..services.text_processing import phonemize, smart_split
from ..services.text_processing.vocabulary import tokenize
from ..services.text_processing import smart_split
from kokoro import KPipeline
from ..services.tts_service import TTSService
from ..structures.text_schemas import (
GenerateFromPhonemesRequest,
PhonemeRequest,
PhonemeResponse,
)
from ..structures import (
CaptionedSpeechRequest,
CaptionedSpeechResponse,
WordTimestamp
)
router = APIRouter(tags=["text processing"])
@ -26,11 +31,10 @@ async def get_tts_service() -> TTSService:
@router.post("/dev/phonemize", response_model=PhonemeResponse)
async def phonemize_text(request: PhonemeRequest) -> PhonemeResponse:
"""Convert text to phonemes and tokens
"""Convert text to phonemes using Kokoro's quiet mode.
Args:
request: Request containing text and language
tts_service: Injected TTSService instance
Returns:
Phonemes and token IDs
@ -39,14 +43,17 @@ async def phonemize_text(request: PhonemeRequest) -> PhonemeResponse:
if not request.text:
raise ValueError("Text cannot be empty")
# Get phonemes
phonemes = phonemize(request.text, request.language)
if not phonemes:
raise ValueError("Failed to generate phonemes")
# Initialize Kokoro pipeline in quiet mode (no model)
pipeline = KPipeline(lang_code=request.language, model=False)
# Get first result from pipeline (we only need one since we're not chunking)
for result in pipeline(request.text):
# result.graphemes = original text
# result.phonemes = phonemized text
# result.tokens = token objects (if available)
return PhonemeResponse(phonemes=result.phonemes, tokens=[])
# Get tokens (without adding start/end tokens to match process_text behavior)
tokens = tokenize(phonemes)
return PhonemeResponse(phonemes=phonemes, tokens=tokens)
raise ValueError("Failed to generate phonemes")
except ValueError as e:
logger.error(f"Error in phoneme generation: {str(e)}")
raise HTTPException(
@ -63,7 +70,7 @@ async def generate_from_phonemes(
client_request: Request,
tts_service: TTSService = Depends(get_tts_service),
) -> StreamingResponse:
"""Generate audio directly from phonemes with proper streaming"""
"""Generate audio directly from phonemes using Kokoro's phoneme format"""
try:
# Basic validation
if not isinstance(request.phonemes, str):
@ -77,41 +84,30 @@ async def generate_from_phonemes(
async def generate_chunks():
try:
has_data = False
# Process phonemes in chunks
async for chunk_text, _ in smart_split(request.phonemes):
# Check if client is still connected
is_disconnected = client_request.is_disconnected
if callable(is_disconnected):
is_disconnected = await is_disconnected()
if is_disconnected:
logger.info("Client disconnected, stopping audio generation")
break
chunk_audio, _ = await tts_service.generate_from_phonemes(
phonemes=chunk_text,
voice=request.voice,
speed=1.0
)
if chunk_audio is not None:
has_data = True
# Normalize audio before writing
normalized_audio = await normalizer.normalize(chunk_audio)
# Write chunk and yield bytes
chunk_bytes = writer.write_chunk(normalized_audio)
if chunk_bytes:
yield chunk_bytes
if not has_data:
raise ValueError("Failed to generate any audio data")
# Finalize and yield remaining bytes if we still have a connection
if not (callable(is_disconnected) and await is_disconnected()):
# Generate audio from phonemes
chunk_audio, _ = await tts_service.generate_from_phonemes(
phonemes=request.phonemes, # Pass complete phoneme string
voice=request.voice,
speed=1.0
)
if chunk_audio is not None:
# Normalize audio before writing
normalized_audio = await normalizer.normalize(chunk_audio)
# Write chunk and yield bytes
chunk_bytes = writer.write_chunk(normalized_audio)
if chunk_bytes:
yield chunk_bytes
# Finalize and yield remaining bytes
final_bytes = writer.write_chunk(finalize=True)
if final_bytes:
yield final_bytes
else:
raise ValueError("Failed to generate audio data")
except Exception as e:
logger.error(f"Error in audio chunk generation: {str(e)}")
logger.error(f"Error in audio generation: {str(e)}")
# Clean up writer on error
writer.write_chunk(finalize=True)
# Re-raise the original exception
@ -128,7 +124,6 @@ async def generate_from_phonemes(
}
)
except ValueError as e:
logger.error(f"Error generating audio: {str(e)}")
raise HTTPException(
@ -149,3 +144,92 @@ async def generate_from_phonemes(
"type": "server_error"
}
)
@router.post("/dev/captioned_speech")
async def create_captioned_speech(
request: CaptionedSpeechRequest,
tts_service: TTSService = Depends(get_tts_service),
) -> StreamingResponse:
"""Generate audio with word-level timestamps using Kokoro's output"""
try:
# Get voice path
voice_name, voice_path = await tts_service._get_voice_path(request.voice)
# Generate audio with timestamps
audio, _, word_timestamps = await tts_service.generate_audio(
text=request.input,
voice=voice_name,
speed=request.speed,
return_timestamps=True
)
# Create streaming audio writer
writer = StreamingAudioWriter(format=request.response_format, sample_rate=24000, channels=1)
normalizer = AudioNormalizer()
async def generate_chunks():
try:
if audio is not None:
# Normalize audio before writing
normalized_audio = await normalizer.normalize(audio)
# Write chunk and yield bytes
chunk_bytes = writer.write_chunk(normalized_audio)
if chunk_bytes:
yield chunk_bytes
# Finalize and yield remaining bytes
final_bytes = writer.write_chunk(finalize=True)
if final_bytes:
yield final_bytes
else:
raise ValueError("Failed to generate audio data")
except Exception as e:
logger.error(f"Error in audio generation: {str(e)}")
# Clean up writer on error
writer.write_chunk(finalize=True)
# Re-raise the original exception
raise
# Convert timestamps to JSON and add as header
import json
logger.debug(f"Processing {len(word_timestamps)} word timestamps")
timestamps_json = json.dumps([{
'word': str(ts['word']), # Ensure string for text
'start_time': float(ts['start_time']), # Ensure float for timestamps
'end_time': float(ts['end_time'])
} for ts in word_timestamps])
logger.debug(f"Generated timestamps JSON: {timestamps_json}")
return StreamingResponse(
generate_chunks(),
media_type=f"audio/{request.response_format}",
headers={
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
"X-Accel-Buffering": "no",
"Cache-Control": "no-cache",
"Transfer-Encoding": "chunked",
"X-Word-Timestamps": timestamps_json
}
)
except ValueError as e:
logger.error(f"Error in captioned speech generation: {str(e)}")
raise HTTPException(
status_code=400,
detail={
"error": "validation_error",
"message": str(e),
"type": "invalid_request_error"
}
)
except Exception as e:
logger.error(f"Error in captioned speech generation: {str(e)}")
raise HTTPException(
status_code=500,
detail={
"error": "processing_error",
"message": str(e),
"type": "server_error"
}
)

View file

@ -2,15 +2,19 @@
import json
import os
import io
import tempfile
from typing import AsyncGenerator, Dict, List, Union
import torch
import aiofiles
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from fastapi.responses import StreamingResponse, FileResponse
from loguru import logger
from ..services.audio import AudioService
from ..services.tts_service import TTSService
from ..structures.schemas import OpenAISpeechRequest
from ..structures import OpenAISpeechRequest
from ..core.config import settings
# Load OpenAI mappings
@ -72,55 +76,65 @@ def get_model_name(model: str) -> str:
async def process_voices(
voice_input: Union[str, List[str]], tts_service: TTSService
) -> str:
"""Process voice input into a combined voice, handling both string and list formats"""
"""Process voice input, handling both string and list formats
Returns:
Voice name to use (with weights if specified)
"""
# Convert input to list of voices
if isinstance(voice_input, str):
# Check if it's an OpenAI voice name
mapped_voice = _openai_mappings["voices"].get(voice_input)
if mapped_voice:
voice_input = mapped_voice
voices = [v.strip() for v in voice_input.split("+") if v.strip()]
# Split on + but preserve any parentheses
voices = []
for part in voice_input.split("+"):
part = part.strip()
if not part:
continue
# Extract voice name without weight
voice_name = part.split("(")[0].strip()
# Check if it's a valid voice
available_voices = await tts_service.list_voices()
if voice_name not in available_voices:
raise ValueError(
f"Voice '{voice_name}' not found. Available voices: {', '.join(sorted(available_voices))}"
)
voices.append(part)
else:
# For list input, map each voice if it's an OpenAI voice name
voices = [_openai_mappings["voices"].get(v, v) for v in voice_input]
voices = [v.strip() for v in voices if v.strip()]
voices = []
for v in voice_input:
mapped = _openai_mappings["voices"].get(v, v)
voice_name = mapped.split("(")[0].strip()
# Check if it's a valid voice
available_voices = await tts_service.list_voices()
if voice_name not in available_voices:
raise ValueError(
f"Voice '{voice_name}' not found. Available voices: {', '.join(sorted(available_voices))}"
)
voices.append(mapped)
if not voices:
raise ValueError("No voices provided")
# If single voice, validate and return it
if len(voices) == 1:
available_voices = await tts_service.list_voices()
if voices[0] not in available_voices:
raise ValueError(
f"Voice '{voices[0]}' not found. Available voices: {', '.join(sorted(available_voices))}"
)
return voices[0]
# For multiple voices, validate base voices exist
available_voices = await tts_service.list_voices()
for voice in voices:
if voice not in available_voices:
raise ValueError(
f"Base voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}"
)
# Combine voices
return await tts_service.combine_voices(voices=voices)
# For multiple voices, combine them with +
return "+".join(voices)
async def stream_audio_chunks(
tts_service: TTSService,
tts_service: TTSService,
request: OpenAISpeechRequest,
client_request: Request
) -> AsyncGenerator[bytes, None]:
"""Stream audio chunks as they're generated with client disconnect handling"""
voice_to_use = await process_voices(request.voice, tts_service)
voice_name = await process_voices(request.voice, tts_service)
try:
async for chunk in tts_service.generate_audio_stream(
text=request.input,
voice=voice_to_use,
voice=voice_name,
speed=request.speed,
output_format=request.response_format,
):
@ -159,7 +173,7 @@ async def create_speech(
try:
# model_name = get_model_name(request.model)
tts_service = await get_tts_service()
voice_to_use = await process_voices(request.voice, tts_service)
voice_name = await process_voices(request.voice, tts_service)
# Set content type based on format
content_type = {
@ -237,7 +251,7 @@ async def create_speech(
# Generate complete audio using public interface
audio, _ = await tts_service.generate_audio(
text=request.input,
voice=voice_to_use,
voice=voice_name,
speed=request.speed
)
@ -350,14 +364,14 @@ async def list_voices():
@router.post("/audio/voices/combine")
async def combine_voices(request: Union[str, List[str]]):
"""Combine multiple voices into a new voice.
"""Combine multiple voices into a new voice and return the .pt file.
Args:
request: Either a string with voices separated by + (e.g. "voice1+voice2")
or a list of voice names to combine
Returns:
Dict with combined voice name and list of all available voices
FileResponse with the combined voice .pt file
Raises:
HTTPException:
@ -365,10 +379,51 @@ async def combine_voices(request: Union[str, List[str]]):
- 500: Server error (file system issues, combination failed)
"""
try:
# Convert input to list of voices
if isinstance(request, str):
# Check if it's an OpenAI voice name
mapped_voice = _openai_mappings["voices"].get(request)
if mapped_voice:
request = mapped_voice
voices = [v.strip() for v in request.split("+") if v.strip()]
else:
# For list input, map each voice if it's an OpenAI voice name
voices = [_openai_mappings["voices"].get(v, v) for v in request]
voices = [v.strip() for v in voices if v.strip()]
if not voices:
raise ValueError("No voices provided")
# For multiple voices, validate base voices exist
tts_service = await get_tts_service()
combined_voice = await process_voices(request, tts_service)
voices = await tts_service.list_voices()
return {"voices": voices, "voice": combined_voice}
available_voices = await tts_service.list_voices()
for voice in voices:
if voice not in available_voices:
raise ValueError(
f"Base voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}"
)
# Combine voices
combined_tensor = await tts_service.combine_voices(voices=voices)
combined_name = "+".join(voices)
# Save to temp file
temp_dir = tempfile.gettempdir()
voice_path = os.path.join(temp_dir, f"{combined_name}.pt")
buffer = io.BytesIO()
torch.save(combined_tensor, buffer)
async with aiofiles.open(voice_path, 'wb') as f:
await f.write(buffer.getvalue())
return FileResponse(
voice_path,
media_type="application/octet-stream",
filename=f"{combined_name}.pt",
headers={
"Content-Disposition": f"attachment; filename={combined_name}.pt",
"Cache-Control": "no-cache"
}
)
except ValueError as e:
logger.warning(f"Invalid voice combination request: {str(e)}")

View file

@ -17,6 +17,7 @@ from .audio import AudioNormalizer, AudioService
from .text_processing.text_processor import process_text_chunk, smart_split
from .text_processing import tokenize
from ..inference.kokoro_v1 import KokoroV1
from kokoro import KPipeline
class TTSService:
@ -154,23 +155,43 @@ class TTSService:
try:
# Check if it's a combined voice
if "+" in voice:
voices = [v.strip() for v in voice.split("+") if v.strip()]
if len(voices) < 2:
# Split on + but preserve any parentheses
voice_parts = []
weights = []
for part in voice.split("+"):
part = part.strip()
if not part:
continue
# Extract voice name and weight if present
if "(" in part and ")" in part:
voice_name = part.split("(")[0].strip()
weight = float(part.split("(")[1].split(")")[0])
else:
voice_name = part
weight = 1.0
voice_parts.append(voice_name)
weights.append(weight)
if len(voice_parts) < 2:
raise RuntimeError(f"Invalid combined voice name: {voice}")
# Normalize weights to sum to 1
total_weight = sum(weights)
weights = [w/total_weight for w in weights]
# Load and combine voices
voice_tensors = []
for v in voices:
for v, w in zip(voice_parts, weights):
path = await self._voice_manager.get_voice_path(v)
if not path:
raise RuntimeError(f"Voice not found: {v}")
logger.debug(f"Loading voice tensor from: {path}")
voice_tensor = torch.load(path, map_location="cpu")
voice_tensors.append(voice_tensor)
voice_tensors.append(voice_tensor * w)
# Average the voice tensors
logger.debug(f"Combining {len(voice_tensors)} voice tensors")
combined = torch.mean(torch.stack(voice_tensors), dim=0)
# Sum the weighted voice tensors
logger.debug(f"Combining {len(voice_tensors)} voice tensors with weights {weights}")
combined = torch.sum(torch.stack(voice_tensors), dim=0)
# Save combined tensor
temp_dir = tempfile.gettempdir()
@ -259,43 +280,237 @@ class TTSService:
raise
async def generate_audio(
self, text: str, voice: str, speed: float = 1.0
) -> Tuple[np.ndarray, float]:
self, text: str, voice: str, speed: float = 1.0, return_timestamps: bool = False
) -> Union[Tuple[np.ndarray, float], Tuple[np.ndarray, float, List[dict]]]:
"""Generate complete audio for text using streaming internally."""
start_time = time.time()
chunks = []
word_timestamps = []
try:
# Use streaming generator but collect all valid chunks
async for chunk in self.generate_audio_stream(
text, voice, speed, # Default to WAV for raw audio
):
if chunk is not None:
chunks.append(chunk)
# Get backend and voice path
backend = self.model_manager.get_backend()
voice_name, voice_path = await self._get_voice_path(voice)
if not chunks:
raise ValueError("No audio chunks were generated successfully")
if isinstance(backend, KokoroV1):
# Initialize quiet pipeline for text chunking
quiet_pipeline = KPipeline(lang_code='a', model=False)
# Split text into chunks and get initial tokens
text_chunks = []
current_offset = 0.0 # Track time offset for timestamps
logger.debug("Splitting text into chunks...")
for result in quiet_pipeline(text):
if result.graphemes and result.phonemes:
text_chunks.append((result.graphemes, result.phonemes))
logger.debug(f"Split text into {len(text_chunks)} chunks")
# Process each chunk
for chunk_idx, (chunk_text, chunk_phonemes) in enumerate(text_chunks):
logger.debug(f"Processing chunk {chunk_idx + 1}/{len(text_chunks)}: '{chunk_text[:50]}...'")
# Generate audio and timestamps for this chunk
for result in backend._pipeline(
chunk_text,
voice=voice_path,
speed=speed,
model=backend._model
):
# Collect audio chunks
if result.audio is not None:
chunks.append(result.audio.numpy())
# Process timestamps for this chunk
if return_timestamps and hasattr(result, 'tokens') and result.tokens:
logger.debug(f"Processing chunk timestamps with {len(result.tokens)} tokens")
if result.pred_dur is not None:
try:
# Join timestamps for this chunk's tokens
KPipeline.join_timestamps(result.tokens, result.pred_dur)
# Add timestamps with offset
for token in result.tokens:
if not all(hasattr(token, attr) for attr in ['text', 'start_ts', 'end_ts']):
continue
if not token.text or not token.text.strip():
continue
# Apply offset to timestamps
start_time = float(token.start_ts) + current_offset
end_time = float(token.end_ts) + current_offset
word_timestamps.append({
'word': str(token.text).strip(),
'start_time': start_time,
'end_time': end_time
})
logger.debug(f"Added timestamp for word '{token.text}': {start_time:.3f}s - {end_time:.3f}s")
# Update offset for next chunk based on pred_dur
chunk_duration = float(result.pred_dur.sum()) / 80 # Convert frames to seconds
current_offset = max(current_offset + chunk_duration, end_time)
logger.debug(f"Updated time offset to {current_offset:.3f}s")
except Exception as e:
logger.error(f"Failed to process timestamps for chunk: {e}")
logger.debug(f"Processing timestamps with pred_dur shape: {result.pred_dur.shape}")
try:
# Join timestamps for this chunk's tokens
KPipeline.join_timestamps(result.tokens, result.pred_dur)
logger.debug("Successfully joined timestamps for chunk")
except Exception as e:
logger.error(f"Failed to join timestamps for chunk: {e}")
continue
# Convert tokens to timestamps
for token in result.tokens:
try:
# Skip tokens without required attributes
if not all(hasattr(token, attr) for attr in ['text', 'start_ts', 'end_ts']):
logger.debug(f"Skipping token missing attributes: {dir(token)}")
continue
# Get and validate text
text = str(token.text).strip() if token.text is not None else ''
if not text:
logger.debug("Skipping empty token")
continue
# Get and validate timestamps
start_ts = getattr(token, 'start_ts', None)
end_ts = getattr(token, 'end_ts', None)
if start_ts is None or end_ts is None:
logger.debug(f"Skipping token with None timestamps: {text}")
continue
# Convert timestamps to float
try:
start_time = float(start_ts)
end_time = float(end_ts)
except (TypeError, ValueError):
logger.debug(f"Skipping token with invalid timestamps: {text}")
continue
# Add timestamp
word_timestamps.append({
'word': text,
'start_time': start_time,
'end_time': end_time
})
logger.debug(f"Added timestamp for word '{text}': {start_time:.3f}s - {end_time:.3f}s")
except Exception as e:
logger.warning(f"Error processing token: {e}")
continue
if not chunks:
raise ValueError("No audio chunks were generated successfully")
# Combine chunks
audio = np.concatenate(chunks) if len(chunks) > 1 else chunks[0]
processing_time = time.time() - start_time
if return_timestamps:
# Validate timestamps before returning
if not word_timestamps:
logger.warning("No valid timestamps were generated")
else:
# Sort timestamps by start time to ensure proper order
word_timestamps.sort(key=lambda x: x['start_time'])
# Validate timestamp sequence
for i in range(1, len(word_timestamps)):
prev = word_timestamps[i-1]
curr = word_timestamps[i]
if curr['start_time'] < prev['end_time']:
logger.warning(f"Overlapping timestamps detected: '{prev['word']}' ({prev['start_time']:.3f}-{prev['end_time']:.3f}) and '{curr['word']}' ({curr['start_time']:.3f}-{curr['end_time']:.3f})")
logger.debug(f"Returning {len(word_timestamps)} word timestamps")
logger.debug(f"First timestamp: {word_timestamps[0]['word']} at {word_timestamps[0]['start_time']:.3f}s")
logger.debug(f"Last timestamp: {word_timestamps[-1]['word']} at {word_timestamps[-1]['end_time']:.3f}s")
return audio, processing_time, word_timestamps
return audio, processing_time
# Combine chunks, ensuring we have valid arrays
if len(chunks) == 1:
audio = chunks[0]
else:
# Filter out any zero-dimensional arrays
valid_chunks = [c for c in chunks if c.ndim > 0]
if not valid_chunks:
raise ValueError("No valid audio chunks to concatenate")
audio = np.concatenate(valid_chunks)
processing_time = time.time() - start_time
return audio, processing_time
# For legacy backends
async for chunk in self.generate_audio_stream(
text, voice, speed, # Default to WAV for raw audio
):
if chunk is not None:
chunks.append(chunk)
if not chunks:
raise ValueError("No audio chunks were generated successfully")
# Combine chunks
audio = np.concatenate(chunks) if len(chunks) > 1 else chunks[0]
processing_time = time.time() - start_time
if return_timestamps:
return audio, processing_time, [] # Empty timestamps for legacy backends
return audio, processing_time
except Exception as e:
logger.error(f"Error in audio generation: {str(e)}")
raise
async def combine_voices(self, voices: List[str]) -> str:
"""Combine multiple voices."""
async def combine_voices(self, voices: List[str]) -> torch.Tensor:
"""Combine multiple voices.
Returns:
Combined voice tensor
"""
return await self._voice_manager.combine_voices(voices)
async def list_voices(self) -> List[str]:
"""List available voices."""
return await self._voice_manager.list_voices()
async def generate_from_phonemes(
self,
phonemes: str,
voice: str,
speed: float = 1.0
) -> Tuple[np.ndarray, float]:
"""Generate audio directly from phonemes.
Args:
phonemes: Phonemes in Kokoro format
voice: Voice name
speed: Speed multiplier
Returns:
Tuple of (audio array, processing time)
"""
start_time = time.time()
try:
# Get backend and voice path
raise ValueError("Not yet implemented")
# linked to https://github.com/hexgrad/kokoro/pull/53 or similiar
backend = self.model_manager.get_backend()
voice_name, voice_path = await self._get_voice_path(voice)
# if isinstance(backend, KokoroV1):
# # For Kokoro V1, pass phonemes directly to pipeline
# result = None
# for r in backend._pipeline(
# phonemes,
# voice=voice_path,
# speed=speed,
# model=backend._model
# ):
# if r.audio is not None:
# result = r
# break
# if result is None or result.audio is None:
# raise ValueError("No audio generated")
# processing_time = time.time() - start_time
# return result.audio.numpy(), processing_time
# else:
pass
except Exception as e:
logger.error(f"Error in phoneme audio generation: {str(e)}")
raise

View file

@ -1,3 +1,17 @@
from .schemas import OpenAISpeechRequest
from .schemas import (
OpenAISpeechRequest,
CaptionedSpeechRequest,
CaptionedSpeechResponse,
WordTimestamp,
TTSStatus,
VoiceCombineRequest
)
__all__ = ["OpenAISpeechRequest"]
__all__ = [
"OpenAISpeechRequest",
"CaptionedSpeechRequest",
"CaptionedSpeechResponse",
"WordTimestamp",
"TTSStatus",
"VoiceCombineRequest"
]

View file

@ -22,7 +22,19 @@ class TTSStatus(str, Enum):
# OpenAI-compatible schemas
class WordTimestamp(BaseModel):
"""Word-level timestamp information"""
word: str = Field(..., description="The word or token")
start_time: float = Field(..., description="Start time in seconds")
end_time: float = Field(..., description="End time in seconds")
class CaptionedSpeechResponse(BaseModel):
"""Response schema for captioned speech endpoint"""
audio: bytes = Field(..., description="The generated audio data")
words: List[WordTimestamp] = Field(..., description="Word-level timestamps")
class OpenAISpeechRequest(BaseModel):
"""Request schema for OpenAI-compatible speech endpoint"""
model: str = Field(
default="kokoro",
description="The model to use for generation. Supported models: tts-1, tts-1-hd, kokoro"
@ -50,3 +62,29 @@ class OpenAISpeechRequest(BaseModel):
default=False,
description="If true, returns a download link in X-Download-Path header after streaming completes",
)
class CaptionedSpeechRequest(BaseModel):
"""Request schema for captioned speech endpoint"""
model: str = Field(
default="kokoro",
description="The model to use for generation. Supported models: tts-1, tts-1-hd, kokoro"
)
input: str = Field(..., description="The text to generate audio for")
voice: str = Field(
default="af",
description="The voice to use for generation. Can be a base voice or a combined voice name.",
)
response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field(
default="mp3",
description="The format to return audio in. Supported formats: mp3, opus, flac, wav, pcm. PCM format returns raw 16-bit samples without headers. AAC is not currently supported.",
)
speed: float = Field(
default=1.0,
ge=0.25,
le=4.0,
description="The speed of the generated audio. Select a value from 0.25 to 4.0.",
)
return_timestamps: bool = Field(
default=True,
description="If true (default), returns word-level timestamps in the response",
)

Binary file not shown.

View file

@ -92,15 +92,16 @@ def analyze_audio(filepath: str):
'dominant_frequencies': dominant_freqs
}
def plot_comparison(analyses, output_path):
"""Create comparison plot of the audio analyses."""
def plot_comparison(analyses, output_dir):
"""Create detailed comparison plots of the audio analyses."""
plt.style.use('dark_background')
fig = plt.figure(figsize=(15, 10))
fig.patch.set_facecolor('#1a1a2e')
# Plot waveforms
fig_wave = plt.figure(figsize=(15, 10))
fig_wave.patch.set_facecolor('#1a1a2e')
for i, (name, data) in enumerate(analyses.items()):
ax = plt.subplot(3, 1, i+1)
ax = plt.subplot(len(analyses), 1, i+1)
samples = data['samples']
time = np.arange(len(samples)) / data['sample_rate']
plt.plot(time, samples / data['max_amplitude'], linewidth=0.5, color='#ff2a6d')
@ -112,27 +113,96 @@ def plot_comparison(analyses, output_path):
plt.ylim(-1.1, 1.1)
plt.tight_layout()
plt.savefig(output_path, dpi=300, bbox_inches='tight')
print(f"\nSaved comparison plot to {output_path}")
plt.savefig(output_dir / 'waveforms.png', dpi=300, bbox_inches='tight')
plt.close()
# Plot spectral characteristics
fig_spec = plt.figure(figsize=(15, 10))
fig_spec.patch.set_facecolor('#1a1a2e')
for i, (name, data) in enumerate(analyses.items()):
# Calculate spectrogram
samples = data['samples']
sample_rate = data['sample_rate']
nperseg = 2048
f, t, Sxx = plt.mlab.specgram(samples, NFFT=2048, Fs=sample_rate,
noverlap=nperseg//2, scale='dB')
ax = plt.subplot(len(analyses), 1, i+1)
plt.pcolormesh(t, f, Sxx, shading='gouraud', cmap='magma')
plt.title(f"Spectrogram: {name}", color='white', pad=20)
plt.ylabel('Frequency [Hz]', color='white')
plt.xlabel('Time [sec]', color='white')
plt.colorbar(label='Intensity [dB]')
ax.set_facecolor('#1a1a2e')
plt.tight_layout()
plt.savefig(output_dir / 'spectrograms.png', dpi=300, bbox_inches='tight')
plt.close()
# Plot voice characteristics comparison
fig_chars = plt.figure(figsize=(15, 8))
fig_chars.patch.set_facecolor('#1a1a2e')
# Extract characteristics
names = list(analyses.keys())
rms_values = [data['rms'] for data in analyses.values()]
centroids = [data['spectral_centroid'] for data in analyses.values()]
max_amps = [data['max_amplitude'] for data in analyses.values()]
# Plot characteristics
x = np.arange(len(names))
width = 0.25
ax = plt.subplot(111)
ax.bar(x - width, rms_values, width, label='RMS (Texture)', color='#ff2a6d')
ax.bar(x, [c/1000 for c in centroids], width, label='Spectral Centroid/1000 (Brightness)', color='#05d9e8')
ax.bar(x + width, max_amps, width, label='Max Amplitude', color='#ff65bd')
ax.set_xticks(x)
ax.set_xticklabels(names, rotation=45, ha='right')
ax.legend()
ax.set_title('Voice Characteristics Comparison', color='white', pad=20)
ax.set_facecolor('#1a1a2e')
plt.tight_layout()
plt.savefig(output_dir / 'characteristics.png', dpi=300, bbox_inches='tight')
plt.close()
print(f"\nSaved comparison plots to {output_dir}")
def main():
# Generate audio for each voice
# Test different voice combinations with weights
voices = {
'af_bella': output_dir / 'af_bella.wav',
'af_irulan': output_dir / 'af_irulan.wav',
'af_bella+af_irulan': output_dir / 'af_bella+af_irulan.wav'
'af_kore': output_dir / 'af_kore.wav',
'af_bella(0.2)+af_kore(0.8)': output_dir / 'af_bella_20_af_kore_80.wav',
'af_bella(0.8)+af_kore(0.2)': output_dir / 'af_bella_80_af_kore_20.wav',
'af_bella(0.5)+af_kore(0.5)': output_dir / 'af_bella_50_af_kore_50.wav'
}
# Generate audio for each voice/combination
for voice, path in voices.items():
generate_and_save_audio(voice, str(path))
try:
generate_and_save_audio(voice, str(path))
except Exception as e:
print(f"Error generating audio for {voice}: {e}")
continue
# Analyze each audio file
analyses = {}
for name, path in voices.items():
analyses[name] = analyze_audio(str(path))
try:
analyses[name] = analyze_audio(str(path))
except Exception as e:
print(f"Error analyzing {name}: {e}")
continue
# Create comparison plot
plot_comparison(analyses, output_dir / 'voice_comparison.png')
# Create comparison plots
if analyses:
plot_comparison(analyses, output_dir)
else:
print("No analyses to plot")
if __name__ == "__main__":
main()

View file

@ -0,0 +1,74 @@
#!/usr/bin/env python3
import os
from pathlib import Path
import requests
# Create output directory
output_dir = Path(__file__).parent / "output"
output_dir.mkdir(exist_ok=True)
def download_combined_voice(voice1: str, voice2: str, weights: tuple[float, float] = None) -> str:
"""Download a combined voice file.
Args:
voice1: First voice name
voice2: Second voice name
weights: Optional tuple of weights (w1, w2). If not provided, uses equal weights.
Returns:
Path to downloaded .pt file
"""
print(f"\nDownloading combined voice: {voice1} + {voice2}")
# Construct voice string with optional weights
if weights:
voice_str = f"{voice1}({weights[0]})+{voice2}({weights[1]})"
else:
voice_str = f"{voice1}+{voice2}"
# Make the request to combine voices
response = requests.post(
"http://localhost:8880/v1/audio/voices/combine",
json=voice_str
)
if response.status_code != 200:
raise Exception(f"Failed to combine voices: {response.text}")
# Save the .pt file
output_path = output_dir / f"{voice_str}.pt"
with open(output_path, "wb") as f:
f.write(response.content)
print(f"Saved combined voice to {output_path}")
return str(output_path)
def main():
# Test downloading various voice combinations
combinations = [
# Equal weights (default)
("af_bella", "af_kore"),
# Different weight combinations
("af_bella", "af_kore", (0.2, 0.8)),
("af_bella", "af_kore", (0.8, 0.2)),
("af_bella", "af_kore", (0.5, 0.5)),
# Test with different voices
("af_bella", "af_jadzia"),
("af_bella", "af_jadzia", (0.3, 0.7))
]
for combo in combinations:
try:
if len(combo) == 3:
voice1, voice2, weights = combo
download_combined_voice(voice1, voice2, weights)
else:
voice1, voice2 = combo
download_combined_voice(voice1, voice2)
except Exception as e:
print(f"Error downloading combination {combo}: {e}")
if __name__ == "__main__":
main()

View file

@ -0,0 +1,54 @@
import os
import torch
from loguru import logger
def analyze_voice_file(file_path):
"""Analyze dimensions and statistics of a voice tensor."""
try:
tensor = torch.load(file_path, map_location="cpu")
logger.info(f"\nAnalyzing {os.path.basename(file_path)}:")
logger.info(f"Shape: {tensor.shape}")
logger.info(f"Mean: {tensor.mean().item():.4f}")
logger.info(f"Std: {tensor.std().item():.4f}")
logger.info(f"Min: {tensor.min().item():.4f}")
logger.info(f"Max: {tensor.max().item():.4f}")
return tensor.shape
except Exception as e:
logger.error(f"Error analyzing {file_path}: {e}")
return None
def main():
"""Analyze voice files in the voices directory."""
# Get the project root directory
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(os.path.dirname(os.path.dirname(current_dir)))
voices_dir = os.path.join(project_root, "api", "src", "voices", "v1_0")
logger.info(f"Scanning voices in: {voices_dir}")
# Track shapes for comparison
shapes = {}
# Analyze each .pt file
for file in os.listdir(voices_dir):
if file.endswith('.pt'):
file_path = os.path.join(voices_dir, file)
shape = analyze_voice_file(file_path)
if shape:
shapes[file] = shape
# Report findings
logger.info("\nShape Analysis:")
shape_groups = {}
for file, shape in shapes.items():
if shape not in shape_groups:
shape_groups[shape] = []
shape_groups[shape].append(file)
for shape, files in shape_groups.items():
logger.info(f"\nShape {shape}:")
for file in files:
logger.info(f" - {file}")
if __name__ == "__main__":
main()

View file

@ -0,0 +1,85 @@
import os
import torch
from loguru import logger
def analyze_voice_content(tensor):
"""Analyze the content distribution in the voice tensor."""
# Look at the variance along the first dimension to see where the information is concentrated
variance = torch.var(tensor, dim=(1,2)) # Variance across features
logger.info(f"Variance distribution:")
logger.info(f"First 5 rows variance: {variance[:5].mean().item():.6f}")
logger.info(f"Last 5 rows variance: {variance[-5:].mean().item():.6f}")
return variance
def trim_voice_tensor(tensor):
"""Trim a 511x1x256 tensor to 510x1x256 by removing the row with least impact."""
if tensor.shape[0] != 511:
raise ValueError(f"Expected tensor with first dimension 511, got {tensor.shape[0]}")
# Analyze variance contribution of each row
variance = analyze_voice_content(tensor)
# Determine which end has lower variance (less information)
start_var = variance[:5].mean().item()
end_var = variance[-5:].mean().item()
# Remove from the end with lower variance
if end_var < start_var:
logger.info("Trimming last row (lower variance at end)")
return tensor[:-1]
else:
logger.info("Trimming first row (lower variance at start)")
return tensor[1:]
def process_voice_file(file_path):
"""Process a single voice file."""
try:
tensor = torch.load(file_path, map_location="cpu")
if tensor.shape[0] != 511:
logger.info(f"Skipping {os.path.basename(file_path)} - already correct shape {tensor.shape}")
return False
logger.info(f"\nProcessing {os.path.basename(file_path)}:")
logger.info(f"Original shape: {tensor.shape}")
# Create backup
backup_path = file_path + ".backup"
if not os.path.exists(backup_path):
torch.save(tensor, backup_path)
logger.info(f"Created backup at {backup_path}")
# Trim tensor
trimmed = trim_voice_tensor(tensor)
logger.info(f"New shape: {trimmed.shape}")
# Save trimmed tensor
torch.save(trimmed, file_path)
logger.info(f"Saved trimmed tensor to {file_path}")
return True
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
return False
def main():
"""Process voice files in the voices directory."""
# Get the project root directory
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(os.path.dirname(os.path.dirname(current_dir)))
voices_dir = os.path.join(project_root, "api", "src", "voices", "v1_0")
logger.info(f"Processing voices in: {voices_dir}")
processed = 0
for file in os.listdir(voices_dir):
if file.endswith('.pt') and not file.endswith('.backup'):
file_path = os.path.join(voices_dir, file)
if process_voice_file(file_path):
processed += 1
logger.info(f"\nProcessed {processed} voice files")
logger.info("Backups created with .backup extension")
logger.info("To restore backups if needed, remove .backup extension to replace trimmed files")
if __name__ == "__main__":
main()

View file

@ -0,0 +1,98 @@
import json
from typing import Tuple, Optional, Dict, List
from pathlib import Path
import requests
# Get the directory this script is in
SCRIPT_DIR = Path(__file__).absolute().parent
def generate_captioned_speech(
text: str,
voice: str = "af_bella",
speed: float = 1.0,
response_format: str = "wav"
) -> Tuple[Optional[bytes], Optional[List[Dict]]]:
"""Generate audio with word-level timestamps."""
response = requests.post(
"http://localhost:8880/dev/captioned_speech",
json={
"model": "kokoro",
"input": text,
"voice": voice,
"speed": speed,
"response_format": response_format
}
)
print(f"Response status: {response.status_code}")
print(f"Response headers: {dict(response.headers)}")
if response.status_code != 200:
print(f"Error response: {response.text}")
return None, None
try:
# Get timestamps from header
timestamps_json = response.headers.get('X-Word-Timestamps', '[]')
word_timestamps = json.loads(timestamps_json)
# Get audio bytes from content
audio_bytes = response.content
if not audio_bytes:
print("Error: Empty audio content")
return None, None
return audio_bytes, word_timestamps
except json.JSONDecodeError as e:
print(f"Error parsing timestamps: {e}")
return None, None
def main():
# Example texts to convert
examples = [
"Hello world! Welcome to the captioned speech system.",
"The quick brown fox jumps over the lazy dog.",
"""If you have access to a room where gasoline is stored, remember that gas vapor accumulating in a closed room will explode after a time if you leave a candle burning in the room. A good deal of evaporation, however, must occur from the gasoline tins into the air of the room. If removal of the tops of the tins does not expose enough gasoline to the air to ensure copious evaporation, you can open lightly constructed tins further with a knife, ice pick or sharpened nail file. Or puncture a tiny hole in the tank which will permit gasoline to leak out on the floor. This will greatly increase the rate of evaporation. Before you light your candle, be sure that windows are closed and the room is as air-tight as you can make it. If you can see that windows in a neighboring room are opened wide, you have a chance of setting a large fire which will not only destroy the gasoline but anything else nearby; when the gasoline explodes, the doors of the storage room will be blown open, a draft to the neighboring windows will be created which will whip up a fine conflagration"""
]
print("Generating captioned speech for example texts...\n")
# Create output directory in same directory as script
output_dir = SCRIPT_DIR / "output"
output_dir.mkdir(exist_ok=True)
for i, text in enumerate(examples):
print(f"\nExample {i+1}:")
print(f"Input text: {text}")
try:
# Generate audio and get timestamps
audio_bytes, word_timestamps = generate_captioned_speech(text)
if not audio_bytes or not word_timestamps:
print("Error: No audio data or timestamps generated")
continue
# Save audio file
audio_path = output_dir / f"captioned_example_{i+1}.wav"
with audio_path.open("wb") as f:
f.write(audio_bytes)
print(f"Audio saved to: {audio_path}")
# Save timestamps to JSON
timestamps_path = output_dir / f"captioned_example_{i+1}_timestamps.json"
with timestamps_path.open("w") as f:
json.dump(word_timestamps, f, indent=2)
print(f"Timestamps saved to: {timestamps_path}")
# Print timestamps
print("\nWord-level timestamps:")
for ts in word_timestamps:
print(f"{ts['word']}: {ts['start_time']:.3f}s - {ts['end_time']:.3f}s")
except requests.RequestException as e:
print(f"Error: {e}\n")
if __name__ == "__main__":
main()

View file

@ -37,8 +37,8 @@ dependencies = [
"semchunk>=3.0.1",
"mutagen>=1.47.0",
"psutil>=6.1.1",
"kokoro==0.3.5",
'misaki[en,ja,ko,zh,vi]==0.6.7',
"kokoro==0.7.4",
'misaki[en,ja,ko,zh,vi]==0.7.4',
"spacy>=3.7.6",
"en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl"
]

View file

@ -30,6 +30,20 @@ export class VoiceSelector {
}
});
// Weight adjustment and voice removal
this.elements.selectedVoices.addEventListener('input', (e) => {
if (e.target.type === 'number') {
const voice = e.target.dataset.voice;
let weight = parseFloat(e.target.value);
// Ensure weight is between 0.1 and 10
weight = Math.max(0.1, Math.min(10, weight));
e.target.value = weight;
this.voiceService.updateWeight(voice, weight);
}
});
// Remove selected voice
this.elements.selectedVoices.addEventListener('click', (e) => {
if (e.target.classList.contains('remove-voice')) {
@ -73,12 +87,22 @@ export class VoiceSelector {
}
updateSelectedVoicesDisplay() {
const selectedVoices = this.voiceService.getSelectedVoices();
const selectedVoices = this.voiceService.getSelectedVoiceWeights();
this.elements.selectedVoices.innerHTML = selectedVoices
.map(voice => `
.map(({voice, weight}) => `
<span class="selected-voice-tag">
${voice}
<span class="remove-voice" data-voice="${voice}">×</span>
<span class="voice-name">${voice}</span>
<span class="voice-weight">
<input type="number"
value="${weight}"
min="0.1"
max="10"
step="0.1"
data-voice="${voice}"
class="weight-input"
title="Voice weight (0.1 to 10)">
</span>
<span class="remove-voice" data-voice="${voice}" title="Remove voice">×</span>
</span>
`)
.join('');

View file

@ -1,7 +1,7 @@
export class VoiceService {
constructor() {
this.availableVoices = [];
this.selectedVoices = new Set();
this.selectedVoices = new Map(); // Changed to Map to store voice:weight pairs
}
async loadVoices() {
@ -39,16 +39,33 @@ export class VoiceService {
}
getSelectedVoices() {
return Array.from(this.selectedVoices);
return Array.from(this.selectedVoices.keys());
}
getSelectedVoiceWeights() {
return Array.from(this.selectedVoices.entries()).map(([voice, weight]) => ({
voice,
weight
}));
}
getSelectedVoiceString() {
return Array.from(this.selectedVoices).join('+');
return Array.from(this.selectedVoices.entries())
.map(([voice, weight]) => `${voice}(${weight})`)
.join('+');
}
addVoice(voice) {
addVoice(voice, weight = 1) {
if (this.availableVoices.includes(voice)) {
this.selectedVoices.add(voice);
this.selectedVoices.set(voice, parseFloat(weight) || 1);
return true;
}
return false;
}
updateWeight(voice, weight) {
if (this.selectedVoices.has(voice)) {
this.selectedVoices.set(voice, parseFloat(weight) || 1);
return true;
}
return false;