Kokoro-FastAPI/api/src/services/tts.py

264 lines
10 KiB
Python
Raw Normal View History

import io
import os
import re
import threading
2024-12-31 10:30:12 -05:00
import time
from typing import List, Tuple, Optional
import numpy as np
import scipy.io.wavfile as wavfile
2024-12-31 10:30:12 -05:00
import tiktoken
import torch
from loguru import logger
2024-12-31 10:30:12 -05:00
from kokoro import generate, normalize_text, phonemize, tokenize
from models import build_model
from ..core.config import settings
enc = tiktoken.get_encoding("cl100k_base")
class TTSModel:
_instance = None
_device = None
_lock = threading.Lock()
# Directory for all voices (copied base voices, and any created combined voices)
VOICES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "voices")
@classmethod
def initialize(cls):
"""Initialize and warm up the model"""
with cls._lock:
if cls._instance is None:
# Initialize model
cls._device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Initializing model on {cls._device}")
model_path = os.path.join(settings.model_dir, settings.model_path)
model = build_model(model_path, cls._device)
cls._instance = model
# Ensure voices directory exists
os.makedirs(cls.VOICES_DIR, exist_ok=True)
# Copy base voices to local directory
base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
if os.path.exists(base_voices_dir):
for file in os.listdir(base_voices_dir):
if file.endswith(".pt"):
voice_name = file[:-3]
voice_path = os.path.join(cls.VOICES_DIR, file)
if not os.path.exists(voice_path):
try:
logger.info(f"Copying base voice {voice_name} to voices directory")
base_path = os.path.join(base_voices_dir, file)
voicepack = torch.load(base_path, map_location=cls._device, weights_only=True)
torch.save(voicepack, voice_path)
except Exception as e:
logger.error(f"Error copying voice {voice_name}: {str(e)}")
# Warm up with default voice
try:
dummy_text = "Hello"
voice_path = os.path.join(cls.VOICES_DIR, "af.pt")
dummy_voicepack = torch.load(voice_path, map_location=cls._device, weights_only=True)
generate(model, dummy_text, dummy_voicepack, lang='a', speed=1.0)
logger.info("Model warm-up complete")
except Exception as e:
logger.warning(f"Model warm-up failed: {e}")
# Count voices in directory for validation
voice_count = len([f for f in os.listdir(cls.VOICES_DIR) if f.endswith('.pt')])
return cls._instance, voice_count
@classmethod
def get_instance(cls):
"""Get the initialized instance or raise an error"""
if cls._instance is None:
raise RuntimeError("Model not initialized. Call initialize() first.")
return cls._instance, cls._device
class TTSService:
def __init__(self, output_dir: str = None, start_worker: bool = False):
self.output_dir = output_dir
self._ensure_voices()
if start_worker:
self.start_worker()
def _ensure_voices(self):
"""Copy base voices to local voices directory during initialization"""
os.makedirs(TTSModel.VOICES_DIR, exist_ok=True)
base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
if os.path.exists(base_voices_dir):
for file in os.listdir(base_voices_dir):
if file.endswith(".pt"):
voice_name = file[:-3]
voice_path = os.path.join(TTSModel.VOICES_DIR, file)
if not os.path.exists(voice_path):
try:
logger.info(f"Copying base voice {voice_name} to voices directory")
base_path = os.path.join(base_voices_dir, file)
voicepack = torch.load(base_path, map_location=TTSModel._device, weights_only=True)
torch.save(voicepack, voice_path)
except Exception as e:
logger.error(f"Error copying voice {voice_name}: {str(e)}")
def _split_text(self, text: str) -> List[str]:
"""Split text into sentences"""
return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]
def _get_voice_path(self, voice_name: str) -> Optional[str]:
"""Get the path to a voice file.
Args:
voice_name: Name of the voice to find
Returns:
Path to the voice file if found, None otherwise
"""
voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice_name}.pt")
return voice_path if os.path.exists(voice_path) else None
def _generate_audio(
self, text: str, voice: str, speed: float, stitch_long_output: bool = True
) -> Tuple[torch.Tensor, float]:
"""Generate audio and measure processing time"""
start_time = time.time()
try:
# Normalize text once at the start
text = normalize_text(text)
if not text:
raise ValueError("Text is empty after preprocessing")
# Check voice exists
voice_path = self._get_voice_path(voice)
if not voice_path:
raise ValueError(f"Voice not found: {voice}")
# Load model and voice
model = TTSModel._instance
voicepack = torch.load(voice_path, map_location=TTSModel._device, weights_only=True)
# Generate audio with or without stitching
if stitch_long_output:
chunks = self._split_text(text)
audio_chunks = []
# Process all chunks with same model/voicepack instance
for i, chunk in enumerate(chunks):
try:
# Validate phonemization first
ps = phonemize(chunk, voice[0])
tokens = tokenize(ps)
logger.debug(
2024-12-31 10:30:12 -05:00
f"Processing chunk {i + 1}/{len(chunks)}: {len(tokens)} tokens"
)
# Only proceed if phonemization succeeded
chunk_audio, _ = generate(
model, chunk, voicepack, lang=voice[0], speed=speed
)
if chunk_audio is not None:
audio_chunks.append(chunk_audio)
else:
logger.error(
2024-12-31 10:30:12 -05:00
f"No audio generated for chunk {i + 1}/{len(chunks)}"
)
except Exception as e:
logger.error(
2024-12-31 10:30:12 -05:00
f"Failed to generate audio for chunk {i + 1}/{len(chunks)}: '{chunk}'. Error: {str(e)}"
)
continue
2024-12-30 13:39:35 -05:00
if not audio_chunks:
raise ValueError("No audio chunks were generated successfully")
2024-12-30 13:39:35 -05:00
audio = (
np.concatenate(audio_chunks)
if len(audio_chunks) > 1
else audio_chunks[0]
)
else:
audio, _ = generate(model, text, voicepack, lang=voice[0], speed=speed)
processing_time = time.time() - start_time
return audio, processing_time
except Exception as e:
print(f"Error in audio generation: {str(e)}")
raise
def _save_audio(self, audio: torch.Tensor, filepath: str):
"""Save audio to file"""
os.makedirs(os.path.dirname(filepath), exist_ok=True)
wavfile.write(filepath, 24000, audio)
def _audio_to_bytes(self, audio: torch.Tensor) -> bytes:
"""Convert audio tensor to WAV bytes"""
buffer = io.BytesIO()
wavfile.write(buffer, 24000, audio)
return buffer.getvalue()
2024-12-31 10:30:12 -05:00
def combine_voices(self, voices: List[str]) -> str:
"""Combine multiple voices into a new voice.
Args:
voices: List of voice names to combine
Returns:
Name of the combined voice
Raises:
ValueError: If less than 2 voices provided or voice loading fails
RuntimeError: If voice combination or saving fails
"""
2024-12-31 10:30:12 -05:00
if len(voices) < 2:
raise ValueError("At least 2 voices are required for combination")
# Load voices
2024-12-31 10:30:12 -05:00
t_voices: List[torch.Tensor] = []
v_name: List[str] = []
for voice in voices:
try:
voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice}.pt")
voicepack = torch.load(voice_path, map_location=TTSModel._device, weights_only=True)
t_voices.append(voicepack)
v_name.append(voice)
except Exception as e:
raise ValueError(f"Failed to load voice {voice}: {str(e)}")
# Combine voices
2024-12-31 10:30:12 -05:00
try:
f: str = "_".join(v_name)
v = torch.mean(torch.stack(t_voices), dim=0)
combined_path = os.path.join(TTSModel.VOICES_DIR, f"{f}.pt")
# Save combined voice
try:
torch.save(v, combined_path)
except Exception as e:
raise RuntimeError(f"Failed to save combined voice to {combined_path}: {str(e)}")
return f
2024-12-31 10:30:12 -05:00
except Exception as e:
if not isinstance(e, (ValueError, RuntimeError)):
raise RuntimeError(f"Error combining voices: {str(e)}")
raise
2024-12-31 10:30:12 -05:00
def list_voices(self) -> List[str]:
"""List all available voices"""
voices = []
try:
for file in os.listdir(TTSModel.VOICES_DIR):
if file.endswith(".pt"):
voices.append(file[:-3]) # Remove .pt extension
except Exception as e:
logger.error(f"Error listing voices: {str(e)}")
return sorted(voices)