Kokoro-FastAPI/api/src/services/tts_model.py

95 lines
3.9 KiB
Python
Raw Normal View History

import os
import threading
import torch
from loguru import logger
from kokoro import tokenize, phonemize
from ..core.config import settings
from .tts_cpu import TTSCPUModel
from .tts_gpu import TTSGPUModel
class TTSModel:
    """Class-level entry point that selects, initializes and warms up the TTS backend.

    All state lives on the class itself; the methods are classmethods and no
    instance is required. ``initialize()`` must be called before
    ``get_device()``.
    """

    # Device string ("cuda" or "cpu") chosen by initialize(); None until then.
    _device = None
    # Held for the whole of initialize() so concurrent calls run one at a time.
    _lock = threading.Lock()
    # "voices" directory located one level above this module's directory.
    VOICES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "voices")

    @classmethod
    def initialize(cls) -> int:
        """Initialize and warm up the model.

        Picks "cuda" when ``torch.cuda.is_available()`` and "cpu" otherwise,
        initializes the matching backend, copies the base voices into
        ``VOICES_DIR`` and runs a best-effort warm-up generation.

        Returns:
            The number of ``.pt`` files present in ``VOICES_DIR``.

        Raises:
            RuntimeError: If no backend could be initialized. In that case
                ``_device`` is restored to its previous value, so
                ``get_device()`` does not report a device for a model that
                failed to come up.
        """
        with cls._lock:
            previous_device = cls._device
            # Set device and initialize model
            cls._device = "cuda" if torch.cuda.is_available() else "cpu"
            try:
                logger.info(f"Initializing model on {cls._device}")
                use_torch_backend = cls._initialize_backend(cls._device)
            except Exception:
                # Roll back so a failed initialization is not mistaken for a
                # successful one by get_device().
                cls._device = previous_device
                raise

            # Setup voices directory
            os.makedirs(cls.VOICES_DIR, exist_ok=True)
            cls._copy_base_voices()

            cls._warm_up(use_torch_backend)

            # Count voices in directory
            return sum(
                1 for name in os.listdir(cls.VOICES_DIR) if name.endswith(".pt")
            )

    @classmethod
    def _initialize_backend(cls, device) -> bool:
        """Initialize the backend for *device*.

        Returns:
            True when ``TTSGPUModel`` (the PyTorch backend) is the one that was
            initialized, False when ``TTSCPUModel`` (the ONNX backend) was.

        Raises:
            RuntimeError: If the required backend reports failure.
        """
        if device == "cuda":
            if not TTSGPUModel.initialize(settings.model_dir, settings.model_path):
                raise RuntimeError("Failed to initialize GPU model")
            return True

        # Try CPU ONNX first, fallback to CPU PyTorch if needed
        if TTSCPUModel.initialize(settings.model_dir):
            return False

        logger.warning("ONNX initialization failed, falling back to PyTorch CPU")
        if not TTSGPUModel.initialize(settings.model_dir, settings.model_path):
            raise RuntimeError("Failed to initialize CPU model")
        return True

    @classmethod
    def _copy_base_voices(cls) -> None:
        """Copy base ``.pt`` voices from the model directory into ``VOICES_DIR``.

        Voices that already exist locally are left untouched. A failure on one
        voice is logged and does not stop the remaining voices from being
        copied.
        """
        base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
        # isdir rather than exists: a regular file at this path would make
        # os.listdir raise and abort the whole initialization.
        if not os.path.isdir(base_voices_dir):
            return

        for file in os.listdir(base_voices_dir):
            if not file.endswith(".pt"):
                continue
            voice_name = file[:-3]
            voice_path = os.path.join(cls.VOICES_DIR, file)
            if os.path.exists(voice_path):
                continue
            try:
                logger.info(
                    f"Copying base voice {voice_name} to voices directory"
                )
                base_path = os.path.join(base_voices_dir, file)
                # Load and re-save (instead of a raw file copy) so the stored
                # tensors are mapped to the active device.
                voicepack = torch.load(
                    base_path,
                    map_location=cls._device,
                    weights_only=True,
                )
                torch.save(voicepack, voice_path)
            except Exception as e:
                logger.error(
                    f"Error copying voice {voice_name}: {str(e)}"
                )

    @classmethod
    def _warm_up(cls, use_torch_backend) -> None:
        """Run one dummy generation with the default voice (best effort).

        Args:
            use_torch_backend: True to warm up ``TTSGPUModel``, False to warm
                up ``TTSCPUModel``. This must be the backend that was actually
                initialized; on a CPU host that fell back to PyTorch, warming
                up the ONNX backend would exercise a model that never loaded.

        Any failure is logged as a warning and otherwise ignored.
        """
        try:
            dummy_text = "Hello"
            voice_path = os.path.join(cls.VOICES_DIR, "af.pt")
            dummy_voicepack = torch.load(
                voice_path, map_location=cls._device, weights_only=True
            )
            if use_torch_backend:
                # NOTE(review): "a" is presumably a language code - confirm
                # against TTSGPUModel.generate.
                TTSGPUModel.generate(dummy_text, dummy_voicepack, "a", 1.0)
            else:
                ps = phonemize(dummy_text, "a")
                tokens = tokenize(ps)
                # Surround the tokens with 0 on both sides; presumably
                # start/end padding expected by the model - TODO confirm.
                tokens = [0] + tokens + [0]
                TTSCPUModel.generate(tokens, dummy_voicepack, 1.0)
            logger.info("Model warm-up complete")
        except Exception as e:
            logger.warning(f"Model warm-up failed: {e}")

    @classmethod
    def get_device(cls) -> str:
        """Get the current device or raise an error.

        Returns:
            The device string stored by ``initialize()``.

        Raises:
            RuntimeError: If ``initialize()` has not completed successfully.
        """
        if cls._device is None:
            raise RuntimeError("Model not initialized. Call initialize() first.")
        return cls._device