Kokoro-FastAPI/api/src/services/tts.py

import io
import os
import re
import threading
import time
from typing import List, Tuple, Optional

import numpy as np
import scipy.io.wavfile as wavfile
import tiktoken
import torch
from loguru import logger

from kokoro import generate, normalize_text, phonemize, tokenize
from models import build_model

from ..core.config import settings

enc = tiktoken.get_encoding("cl100k_base")


class TTSModel:
    _instance = None
    _device = None
    _lock = threading.Lock()

    # Directory for all voices (copied base voices, and any created combined voices)
    VOICES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "voices")

    @classmethod
    def initialize(cls):
        """Initialize and warm up the model"""
        with cls._lock:
            if cls._instance is None:
                # Initialize model
                cls._device = "cuda" if torch.cuda.is_available() else "cpu"
                logger.info(f"Initializing model on {cls._device}")
                model_path = os.path.join(settings.model_dir, settings.model_path)
                model = build_model(model_path, cls._device)
                cls._instance = model

                # Ensure voices directory exists
                os.makedirs(cls.VOICES_DIR, exist_ok=True)

                # Copy base voices to local directory
                base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
                if os.path.exists(base_voices_dir):
                    for file in os.listdir(base_voices_dir):
                        if file.endswith(".pt"):
                            voice_name = file[:-3]
                            voice_path = os.path.join(cls.VOICES_DIR, file)
                            if not os.path.exists(voice_path):
                                try:
                                    logger.info(f"Copying base voice {voice_name} to voices directory")
                                    base_path = os.path.join(base_voices_dir, file)
                                    voicepack = torch.load(base_path, map_location=cls._device, weights_only=True)
                                    torch.save(voicepack, voice_path)
                                except Exception as e:
                                    logger.error(f"Error copying voice {voice_name}: {str(e)}")

                # Warm up with default voice
                try:
                    dummy_text = "Hello"
                    voice_path = os.path.join(cls.VOICES_DIR, "af.pt")
                    dummy_voicepack = torch.load(voice_path, map_location=cls._device, weights_only=True)
                    generate(model, dummy_text, dummy_voicepack, lang='a', speed=1.0)
                    logger.info("Model warm-up complete")
                except Exception as e:
                    logger.warning(f"Model warm-up failed: {e}")

            # Count voices in directory for validation
            voice_count = len([f for f in os.listdir(cls.VOICES_DIR) if f.endswith('.pt')])
            return cls._instance, voice_count

    @classmethod
    def get_instance(cls):
        """Get the initialized instance or raise an error"""
        if cls._instance is None:
            raise RuntimeError("Model not initialized. Call initialize() first.")
        return cls._instance, cls._device
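
# TTSModel is a thread-safe lazy singleton: initialize() builds the model once,
# copies the bundled voicepacks into VOICES_DIR, runs a warm-up generation, and
# returns (model, voice_count); get_instance() returns the cached (model, device)
# pair and raises RuntimeError if initialize() was never called. A minimal usage
# sketch, assuming initialize() is invoked once at application startup:
#
#     model, voice_count = TTSModel.initialize()
#     model, device = TTSModel.get_instance()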


class TTSService:
    def __init__(self, output_dir: str = None, start_worker: bool = False):
        self.output_dir = output_dir
        self._ensure_voices()
        if start_worker:
            self.start_worker()

    def _ensure_voices(self):
        """Copy base voices to local voices directory during initialization"""
        os.makedirs(TTSModel.VOICES_DIR, exist_ok=True)

        base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
        if os.path.exists(base_voices_dir):
            for file in os.listdir(base_voices_dir):
                if file.endswith(".pt"):
                    voice_name = file[:-3]
                    voice_path = os.path.join(TTSModel.VOICES_DIR, file)
                    if not os.path.exists(voice_path):
                        try:
                            logger.info(f"Copying base voice {voice_name} to voices directory")
                            base_path = os.path.join(base_voices_dir, file)
                            voicepack = torch.load(base_path, map_location=TTSModel._device, weights_only=True)
                            torch.save(voicepack, voice_path)
                        except Exception as e:
                            logger.error(f"Error copying voice {voice_name}: {str(e)}")

    def _split_text(self, text: str) -> List[str]:
        """Split text into sentences"""
        return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]

    def _get_voice_path(self, voice_name: str) -> Optional[str]:
        """Get the path to a voice file.

        Args:
            voice_name: Name of the voice to find

        Returns:
            Path to the voice file if found, None otherwise
        """
        voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice_name}.pt")
        return voice_path if os.path.exists(voice_path) else None

    def _generate_audio(
        self, text: str, voice: str, speed: float, stitch_long_output: bool = True
    ) -> Tuple[torch.Tensor, float]:
        """Generate audio and measure processing time"""
        start_time = time.time()

        try:
            # Normalize text once at the start
            text = normalize_text(text)
            if not text:
                raise ValueError("Text is empty after preprocessing")

            # Check voice exists
            voice_path = self._get_voice_path(voice)
            if not voice_path:
                raise ValueError(f"Voice not found: {voice}")

            # Load model and voice
            model = TTSModel._instance
            voicepack = torch.load(voice_path, map_location=TTSModel._device, weights_only=True)

            # Generate audio with or without stitching
            if stitch_long_output:
                chunks = self._split_text(text)
                audio_chunks = []

                # Process all chunks with same model/voicepack instance
                for i, chunk in enumerate(chunks):
                    try:
                        # Validate phonemization first
                        ps = phonemize(chunk, voice[0])
                        tokens = tokenize(ps)
                        logger.debug(
                            f"Processing chunk {i + 1}/{len(chunks)}: {len(tokens)} tokens"
                        )

                        # Only proceed if phonemization succeeded
                        chunk_audio, _ = generate(
                            model, chunk, voicepack, lang=voice[0], speed=speed
                        )
                        if chunk_audio is not None:
                            audio_chunks.append(chunk_audio)
                        else:
                            logger.error(
                                f"No audio generated for chunk {i + 1}/{len(chunks)}"
                            )
                    except Exception as e:
                        logger.error(
                            f"Failed to generate audio for chunk {i + 1}/{len(chunks)}: '{chunk}'. Error: {str(e)}"
                        )
                        continue

                if not audio_chunks:
                    raise ValueError("No audio chunks were generated successfully")

                audio = (
                    np.concatenate(audio_chunks)
                    if len(audio_chunks) > 1
                    else audio_chunks[0]
                )
            else:
                audio, _ = generate(model, text, voicepack, lang=voice[0], speed=speed)

            processing_time = time.time() - start_time
            return audio, processing_time
        except Exception as e:
            logger.error(f"Error in audio generation: {str(e)}")
            raise
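
    # Note on stitching: when stitch_long_output is True, the input is split on
    # sentence boundaries by _split_text (e.g. "Hello there. How are you?" becomes
    # two chunks), each chunk is phonemized and generated separately with the same
    # model/voicepack, and the resulting arrays are concatenated. The language code
    # passed to generate() is the first letter of the voice name (voice[0]).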

    def _save_audio(self, audio: torch.Tensor, filepath: str):
        """Save audio to file"""
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        wavfile.write(filepath, 24000, audio)

    def _audio_to_bytes(self, audio: torch.Tensor) -> bytes:
        """Convert audio tensor to WAV bytes"""
        buffer = io.BytesIO()
        wavfile.write(buffer, 24000, audio)
        return buffer.getvalue()

    def combine_voices(self, voices: List[str]) -> str:
        """Combine multiple voices into a new voice.

        Args:
            voices: List of voice names to combine

        Returns:
            Name of the combined voice

        Raises:
            ValueError: If fewer than 2 voices are provided or voice loading fails
            RuntimeError: If voice combination or saving fails
        """
        if len(voices) < 2:
            raise ValueError("At least 2 voices are required for combination")

        # Load voices
        t_voices: List[torch.Tensor] = []
        v_name: List[str] = []

        for voice in voices:
            try:
                voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice}.pt")
                voicepack = torch.load(voice_path, map_location=TTSModel._device, weights_only=True)
                t_voices.append(voicepack)
                v_name.append(voice)
            except Exception as e:
                raise ValueError(f"Failed to load voice {voice}: {str(e)}")

        # Combine voices
        try:
            f: str = "_".join(v_name)
            v = torch.mean(torch.stack(t_voices), dim=0)
            combined_path = os.path.join(TTSModel.VOICES_DIR, f"{f}.pt")

            # Save combined voice
            try:
                torch.save(v, combined_path)
            except Exception as e:
                raise RuntimeError(f"Failed to save combined voice to {combined_path}: {str(e)}")

            return f
        except Exception as e:
            if not isinstance(e, (ValueError, RuntimeError)):
                raise RuntimeError(f"Error combining voices: {str(e)}")
            raise
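
    # How the combined voice is produced: the listed voicepack tensors are stacked
    # and averaged element-wise, and the result is saved under a name joined with
    # underscores. For example (hypothetical voice names), combining
    # ["af_bella", "af_sarah"] would write "af_bella_af_sarah.pt" into VOICES_DIR
    # and return "af_bella_af_sarah".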

    def list_voices(self) -> List[str]:
        """List all available voices"""
        voices = []
        try:
            for file in os.listdir(TTSModel.VOICES_DIR):
                if file.endswith(".pt"):
                    voices.append(file[:-3])  # Remove .pt extension
        except Exception as e:
            logger.error(f"Error listing voices: {str(e)}")
        return sorted(voices)
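

# A minimal smoke-test sketch of the classes above. This block is illustrative:
# the output directory and voice selection are assumptions, and it calls the
# private _generate_audio/_save_audio helpers directly rather than going through
# the API layer.
if __name__ == "__main__":
    TTSModel.initialize()
    service = TTSService(output_dir="output")

    available = service.list_voices()
    logger.info(f"Available voices: {available}")

    if available:
        # "af" is the default warm-up voice above; fall back to whatever is present.
        voice = "af" if "af" in available else available[0]
        audio, seconds = service._generate_audio(
            "Hello there. This is a quick smoke test.", voice=voice, speed=1.0
        )
        service._save_audio(audio, os.path.join("output", f"{voice}_test.wav"))
        logger.info(f"Generated audio in {seconds:.2f}s")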