mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00
263 lines
10 KiB
Python
263 lines
10 KiB
Python
import io
|
|
import os
|
|
import re
|
|
import threading
|
|
import time
|
|
from typing import List, Tuple, Optional
|
|
|
|
import numpy as np
|
|
import scipy.io.wavfile as wavfile
|
|
import tiktoken
|
|
import torch
|
|
from loguru import logger
|
|
|
|
from kokoro import generate, normalize_text, phonemize, tokenize
|
|
from models import build_model
|
|
|
|
from ..core.config import settings
|
|
|
|
# Load the "cl100k_base" tiktoken encoding once, at import time.
# NOTE(review): `enc` is not referenced anywhere else in this module — confirm
# whether another module imports it before considering removal.
enc = tiktoken.get_encoding("cl100k_base")
|
|
|
|
|
|
class TTSModel:
    """Class-level holder for the single loaded TTS model.

    All state lives on the class itself (never on instances): the loaded
    model, the device it was built for, and a lock that serializes
    ``initialize()`` calls.
    """

    _instance = None  # model returned by build_model(); None until initialize() runs
    _device = None  # "cuda" or "cpu", chosen inside initialize()
    _lock = threading.Lock()  # guards the one-time setup in initialize()

    # Directory for all voices (copied base voices, and any created combined voices)
    VOICES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "voices")

    # Voice (file name without ".pt") used for the warm-up generation.
    _WARMUP_VOICE = "af"

    @classmethod
    def initialize(cls):
        """Initialize and warm up the model.

        The first call builds the model, copies the base voices into
        ``VOICES_DIR`` and runs one warm-up generation. Later calls skip all
        of that and only recount the voices.

        Returns:
            Tuple ``(model, voice_count)`` where ``voice_count`` is the number
            of ``.pt`` files currently present in ``VOICES_DIR``.
        """
        with cls._lock:
            if cls._instance is None:
                # Initialize model
                cls._device = "cuda" if torch.cuda.is_available() else "cpu"
                logger.info(f"Initializing model on {cls._device}")
                model_path = os.path.join(settings.model_dir, settings.model_path)
                model = build_model(model_path, cls._device)
                cls._instance = model

                cls._copy_base_voices()
                cls._warm_up(model)

            # Count voices in directory for validation
            voice_count = sum(
                1 for name in os.listdir(cls.VOICES_DIR) if name.endswith(".pt")
            )
            return cls._instance, voice_count

    @classmethod
    def _copy_base_voices(cls):
        """Copy every base ``.pt`` voice that is missing from ``VOICES_DIR``."""
        # Ensure voices directory exists
        os.makedirs(cls.VOICES_DIR, exist_ok=True)

        base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
        if not os.path.exists(base_voices_dir):
            return

        for filename in os.listdir(base_voices_dir):
            if not filename.endswith(".pt"):
                continue
            voice_name = filename[:-3]
            voice_path = os.path.join(cls.VOICES_DIR, filename)
            if os.path.exists(voice_path):
                # Never overwrite a voice that is already present locally.
                continue
            try:
                logger.info(f"Copying base voice {voice_name} to voices directory")
                base_path = os.path.join(base_voices_dir, filename)
                voicepack = torch.load(
                    base_path, map_location=cls._device, weights_only=True
                )
                torch.save(voicepack, voice_path)
            except Exception as e:
                # Best effort: one unreadable voice must not abort start-up.
                logger.error(f"Error copying voice {voice_name}: {str(e)}")

    @classmethod
    def _warm_up(cls, model):
        """Run one short generation with the warm-up voice; failures only warn."""
        try:
            dummy_text = "Hello"
            voice_path = os.path.join(cls.VOICES_DIR, f"{cls._WARMUP_VOICE}.pt")
            dummy_voicepack = torch.load(
                voice_path, map_location=cls._device, weights_only=True
            )
            generate(model, dummy_text, dummy_voicepack, lang="a", speed=1.0)
            logger.info("Model warm-up complete")
        except Exception as e:
            logger.warning(f"Model warm-up failed: {e}")

    @classmethod
    def get_instance(cls):
        """Get the initialized instance or raise an error.

        Returns:
            Tuple ``(model, device)``.

        Raises:
            RuntimeError: If ``initialize()`` has not populated the model yet.
        """
        if cls._instance is None:
            raise RuntimeError("Model not initialized. Call initialize() first.")
        return cls._instance, cls._device
|
|
|
|
|
|
class TTSService:
    """Text-to-speech operations built on the model held by ``TTSModel``.

    Covers audio generation (optionally sentence by sentence), WAV
    serialization, and management of the voice files in
    ``TTSModel.VOICES_DIR``.
    """

    # Sample rate (Hz) written into every WAV produced by this service.
    SAMPLE_RATE = 24000

    def __init__(self, output_dir: Optional[str] = None, start_worker: bool = False):
        """Store the output directory and make sure base voices are available.

        Args:
            output_dir: Stored on the instance as ``self.output_dir``.
            start_worker: When true, ``start_worker()`` is called if the
                instance provides one.
        """
        self.output_dir = output_dir
        self._ensure_voices()
        if start_worker:
            # This class defines no start_worker(); calling it unconditionally
            # raised AttributeError. Only call it when a subclass supplies one.
            starter = getattr(self, "start_worker", None)
            if callable(starter):
                starter()
            else:
                logger.warning(
                    "start_worker=True requested but no worker is available; ignoring"
                )

    def _ensure_voices(self):
        """Copy base voices to local voices directory during initialization"""
        os.makedirs(TTSModel.VOICES_DIR, exist_ok=True)

        base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
        if not os.path.exists(base_voices_dir):
            return

        for filename in os.listdir(base_voices_dir):
            if not filename.endswith(".pt"):
                continue
            voice_name = filename[:-3]
            voice_path = os.path.join(TTSModel.VOICES_DIR, filename)
            if os.path.exists(voice_path):
                # Existing local voices are never overwritten.
                continue
            try:
                logger.info(f"Copying base voice {voice_name} to voices directory")
                base_path = os.path.join(base_voices_dir, filename)
                voicepack = torch.load(
                    base_path, map_location=TTSModel._device, weights_only=True
                )
                torch.save(voicepack, voice_path)
            except Exception as e:
                # Best effort: skip a voice that cannot be copied.
                logger.error(f"Error copying voice {voice_name}: {str(e)}")

    def _split_text(self, text: str) -> List[str]:
        """Split text into sentences.

        Splits on whitespace that follows ``.``, ``!`` or ``?`` and drops
        empty pieces.
        """
        pieces = re.split(r"(?<=[.!?])\s+", text)
        return [piece.strip() for piece in pieces if piece.strip()]

    def _get_voice_path(self, voice_name: str) -> Optional[str]:
        """Get the path to a voice file.

        Args:
            voice_name: Name of the voice to find

        Returns:
            Path to the voice file if found, None otherwise
        """
        # NOTE(review): voice_name is joined into a path without validation;
        # confirm callers never pass untrusted names containing path separators.
        voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice_name}.pt")
        return voice_path if os.path.exists(voice_path) else None

    def _generate_audio(
        self, text: str, voice: str, speed: float, stitch_long_output: bool = True
    ) -> Tuple[torch.Tensor, float]:
        """Generate audio and measure processing time.

        Args:
            text: Input text; normalized once before generation.
            voice: Voice name; its first character is passed as the language
                code to the phonemizer and generator.
            speed: Speed value forwarded to ``generate``.
            stitch_long_output: When true, generate sentence by sentence and
                concatenate the chunks; otherwise generate in one call.

        Returns:
            Tuple ``(audio, processing_time_seconds)``.

        Raises:
            ValueError: If the text is empty after normalization, the voice
                file does not exist, or no chunk produced audio.
        """
        start_time = time.time()

        try:
            # Normalize text once at the start
            text = normalize_text(text)
            if not text:
                raise ValueError("Text is empty after preprocessing")

            # Check voice exists
            voice_path = self._get_voice_path(voice)
            if not voice_path:
                raise ValueError(f"Voice not found: {voice}")

            # Load model and voice
            model = TTSModel._instance
            voicepack = torch.load(
                voice_path, map_location=TTSModel._device, weights_only=True
            )
            lang = voice[0]

            # Generate audio with or without stitching
            if stitch_long_output:
                chunks = self._split_text(text)
                total = len(chunks)
                audio_chunks = []

                # Process all chunks with same model/voicepack instance
                for i, chunk in enumerate(chunks):
                    try:
                        # Validate phonemization first
                        ps = phonemize(chunk, lang)
                        tokens = tokenize(ps)
                        logger.debug(
                            f"Processing chunk {i + 1}/{total}: {len(tokens)} tokens"
                        )

                        # Only proceed if phonemization succeeded
                        chunk_audio, _ = generate(
                            model, chunk, voicepack, lang=lang, speed=speed
                        )
                        if chunk_audio is not None:
                            audio_chunks.append(chunk_audio)
                        else:
                            logger.error(
                                f"No audio generated for chunk {i + 1}/{total}"
                            )
                    except Exception as e:
                        # A failing chunk is logged and skipped so the
                        # remaining chunks can still be generated.
                        logger.error(
                            f"Failed to generate audio for chunk {i + 1}/{total}: '{chunk}'. Error: {str(e)}"
                        )
                        continue

                if not audio_chunks:
                    raise ValueError("No audio chunks were generated successfully")

                audio = (
                    np.concatenate(audio_chunks)
                    if len(audio_chunks) > 1
                    else audio_chunks[0]
                )
            else:
                audio, _ = generate(model, text, voicepack, lang=lang, speed=speed)

            processing_time = time.time() - start_time
            return audio, processing_time

        except Exception as e:
            # Report through the module logger (like the rest of this file)
            # instead of print(), then let the caller handle the error.
            logger.error(f"Error in audio generation: {str(e)}")
            raise

    def _save_audio(self, audio: torch.Tensor, filepath: str):
        """Save audio to a WAV file at ``filepath``.

        The parent directory is created when ``filepath`` has one; a bare
        file name is written relative to the current working directory.
        """
        directory = os.path.dirname(filepath)
        if directory:
            # os.makedirs("") raises FileNotFoundError, so skip it for bare names.
            os.makedirs(directory, exist_ok=True)
        wavfile.write(filepath, self.SAMPLE_RATE, audio)

    def _audio_to_bytes(self, audio: torch.Tensor) -> bytes:
        """Convert audio tensor to WAV bytes"""
        buffer = io.BytesIO()
        wavfile.write(buffer, self.SAMPLE_RATE, audio)
        return buffer.getvalue()

    def combine_voices(self, voices: List[str]) -> str:
        """Combine multiple voices into a new voice.

        The loaded voicepacks are stacked and averaged, and the result is
        saved in the voices directory under the input names joined by ``_``.

        Args:
            voices: List of voice names to combine

        Returns:
            Name of the combined voice

        Raises:
            ValueError: If less than 2 voices provided or voice loading fails
            RuntimeError: If voice combination or saving fails
        """
        if len(voices) < 2:
            raise ValueError("At least 2 voices are required for combination")

        # Load voices
        t_voices: List[torch.Tensor] = []
        v_name: List[str] = []

        for voice in voices:
            try:
                voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice}.pt")
                voicepack = torch.load(
                    voice_path, map_location=TTSModel._device, weights_only=True
                )
            except Exception as e:
                raise ValueError(f"Failed to load voice {voice}: {str(e)}") from e
            t_voices.append(voicepack)
            v_name.append(voice)

        # Combine voices
        try:
            combined_name = "_".join(v_name)
            combined = torch.mean(torch.stack(t_voices), dim=0)
            combined_path = os.path.join(TTSModel.VOICES_DIR, f"{combined_name}.pt")

            # Save combined voice
            try:
                torch.save(combined, combined_path)
            except Exception as e:
                raise RuntimeError(
                    f"Failed to save combined voice to {combined_path}: {str(e)}"
                ) from e

            return combined_name

        except (ValueError, RuntimeError):
            # Already one of the documented error types; pass through unchanged.
            raise
        except Exception as e:
            raise RuntimeError(f"Error combining voices: {str(e)}") from e

    def list_voices(self) -> List[str]:
        """List all available voices.

        Returns:
            Sorted voice names (``.pt`` file names without the extension);
            an empty list if the voices directory cannot be read.
        """
        voices: List[str] = []
        try:
            voices = [
                name[:-3]  # Remove .pt extension
                for name in os.listdir(TTSModel.VOICES_DIR)
                if name.endswith(".pt")
            ]
        except Exception as e:
            logger.error(f"Error listing voices: {str(e)}")
        return sorted(voices)
|