"""Model management and caching."""

import asyncio
from typing import Dict, Optional, Tuple

import torch
from loguru import logger

from ..core import paths
from ..core.config import settings
from ..core.model_config import ModelConfig, model_config
from .base import BaseModelBackend
from .onnx_cpu import ONNXCPUBackend
from .onnx_gpu import ONNXGPUBackend
from .pytorch_cpu import PyTorchCPUBackend
from .pytorch_gpu import PyTorchGPUBackend
from .session_pool import CPUSessionPool, StreamingSessionPool

# Global singleton instance and lock for thread-safe initialization
_manager_instance = None
_manager_lock = asyncio.Lock()


class ModelManager:
    """Manages model loading and inference across backends."""

    # Class-level state for shared resources
    _loaded_models = {}
    _backends = {}

    def __init__(self, config: Optional[ModelConfig] = None):
        """Initialize model manager.

        Note:
            This should not be called directly. Use get_manager() instead.
        """
        self._config = config or model_config

        # Initialize session pools
        self._session_pools = {
            'onnx_cpu': CPUSessionPool(),
            'onnx_gpu': StreamingSessionPool()
        }

        # Initialize locks
        self._backend_locks: Dict[str, asyncio.Lock] = {}

    def _determine_device(self) -> str:
        """Determine device based on settings."""
        if settings.use_gpu and torch.cuda.is_available():
            return "cuda"
        return "cpu"

    async def initialize(self) -> None:
        """Initialize backends."""
        if self._backends:
            logger.debug("Using existing backend instances")
            return

        device = self._determine_device()

        try:
            if device == "cuda":
                if settings.use_onnx:
                    self._backends['onnx_gpu'] = ONNXGPUBackend()
                    self._current_backend = 'onnx_gpu'
                    logger.info("Initialized new ONNX GPU backend")
                else:
                    self._backends['pytorch_gpu'] = PyTorchGPUBackend()
                    self._current_backend = 'pytorch_gpu'
                    logger.info("Initialized new PyTorch GPU backend")
            else:
                if settings.use_onnx:
                    self._backends['onnx_cpu'] = ONNXCPUBackend()
                    self._current_backend = 'onnx_cpu'
                    logger.info("Initialized new ONNX CPU backend")
                else:
                    self._backends['pytorch_cpu'] = PyTorchCPUBackend()
                    self._current_backend = 'pytorch_cpu'
                    logger.info("Initialized new PyTorch CPU backend")

            # Initialize locks for each backend
            for backend in self._backends:
                self._backend_locks[backend] = asyncio.Lock()

        except Exception as e:
            logger.error(f"Failed to initialize backend: {e}")
            raise RuntimeError(f"Failed to initialize backend: {e}") from e

    async def initialize_with_warmup(self, voice_manager) -> tuple[str, str, int]:
        """Initialize model with warmup and pre-cache voices.

        Args:
            voice_manager: Voice manager instance for loading voices

        Returns:
            Tuple of (device type, model type, number of loaded voices)

        Raises:
            RuntimeError: If initialization fails
        """
        try:
            # Determine backend type based on settings
            if settings.use_gpu and torch.cuda.is_available():
                backend_type = 'onnx_gpu' if settings.use_onnx else 'pytorch_gpu'
            else:
                backend_type = 'onnx_cpu' if settings.use_onnx else 'pytorch_cpu'

            # Get backend
            backend = self.get_backend(backend_type)

            # Get and verify model path
            model_file = model_config.onnx_model_file if settings.use_onnx else model_config.pytorch_model_file
            model_path = await paths.get_model_path(model_file)

            if not await paths.verify_model_path(model_path):
                raise RuntimeError(f"Model file not found: {model_path}")

            # Pre-cache default voice and use it for warmup
            warmup_voice = await voice_manager.load_voice(
                settings.default_voice, device=backend.device)
            logger.info(f"Pre-cached voice {settings.default_voice} for warmup")

            # Initialize model with warmup voice
            await self.load_model(model_path, warmup_voice, backend_type)

            # Only pre-cache the default voice to avoid memory bloat
            logger.info(f"Using {settings.default_voice} as warmup voice")

            # Get available voices count
            voices = await voice_manager.list_voices()
            voicepack_count = len(voices)

            # Get device info for return
            device = "GPU" if (settings.use_gpu and torch.cuda.is_available()) else "CPU"
            model = "ONNX" if settings.use_onnx else "PyTorch"

            return device, model, voicepack_count

        except Exception as e:
            logger.error(f"Failed to initialize model with warmup: {e}")
            raise RuntimeError(f"Failed to initialize model with warmup: {e}") from e
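
    # Illustrative startup flow (a sketch, not part of this class's API):
    # `app_voice_manager` and the surrounding async context are assumptions,
    # not names defined by this module.
    #
    #     manager = await get_manager()
    #     device, model, voices = await manager.initialize_with_warmup(app_voice_manager)
    #     logger.info(f"Ready on {device} with {model} ({voices} voices)")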

    def get_backend(self, backend_type: Optional[str] = None) -> BaseModelBackend:
        """Get specified backend.

        Args:
            backend_type: Backend type ('pytorch_cpu', 'pytorch_gpu', 'onnx_cpu', 'onnx_gpu'),
                uses default if None

        Returns:
            Model backend instance

        Raises:
            ValueError: If backend type is invalid
            RuntimeError: If no backends are available
        """
        if not self._backends:
            raise RuntimeError("No backends available")

        if backend_type is None:
            backend_type = self._current_backend

        if backend_type not in self._backends:
            raise ValueError(
                f"Invalid backend type: {backend_type}. "
                f"Available backends: {', '.join(self._backends.keys())}"
            )

        return self._backends[backend_type]
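
    # For example (a sketch, assuming initialize() has already registered a
    # CPU backend), a caller could request a specific backend explicitly:
    #
    #     backend = manager.get_backend('pytorch_cpu')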

    def _determine_backend(self, model_path: str) -> str:
        """Determine appropriate backend based on model file and settings.

        Args:
            model_path: Path to model file

        Returns:
            Backend type to use
        """
        has_gpu = settings.use_gpu and torch.cuda.is_available()

        # If ONNX is preferred or model is ONNX format
        if settings.use_onnx or model_path.lower().endswith('.onnx'):
            return 'onnx_gpu' if has_gpu else 'onnx_cpu'
        else:
            return 'pytorch_gpu' if has_gpu else 'pytorch_cpu'

    async def load_model(
        self,
        model_path: str,
        warmup_voice: Optional[torch.Tensor] = None,
        backend_type: Optional[str] = None
    ) -> None:
        """Load model on specified backend.

        Args:
            model_path: Path to model file
            warmup_voice: Optional voice tensor for warmup, skips warmup if None
            backend_type: Backend to load on, uses default if None

        Raises:
            RuntimeError: If model loading fails
        """
        try:
            # Get absolute model path
            abs_path = await paths.get_model_path(model_path)

            # Auto-determine backend if not specified
            if backend_type is None:
                backend_type = self._determine_backend(abs_path)

            # Get backend lock
            lock = self._backend_locks[backend_type]

            async with lock:
                backend = self.get_backend(backend_type)

                # For ONNX backends, use session pool
                if backend_type.startswith('onnx'):
                    pool = self._session_pools[backend_type]
                    backend._session = await pool.get_session(abs_path)
                    self._loaded_models[backend_type] = abs_path
                    logger.info(f"Fetched model instance from {backend_type} pool")

                # For PyTorch backends, load normally
                else:
                    # Check if model is already loaded
                    if (backend_type in self._loaded_models and
                            self._loaded_models[backend_type] == abs_path and
                            backend.is_loaded):
                        logger.info(f"Fetching existing model instance from {backend_type}")
                        return

                    # Load model
                    await backend.load_model(abs_path)
                    self._loaded_models[backend_type] = abs_path
                    logger.info(f"Initialized new model instance on {backend_type}")

                # Run warmup if voice provided
                if warmup_voice is not None:
                    await self._warmup_inference(backend, warmup_voice)

        except Exception as e:
            # Clear cached path on failure
            self._loaded_models.pop(backend_type, None)
            raise RuntimeError(f"Failed to load model: {e}") from e
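
    # Illustrative call (a sketch; "model.pth" is a placeholder path, not a
    # file shipped by this module, and `manager` comes from get_manager()):
    #
    #     await manager.load_model("model.pth", backend_type='pytorch_cpu')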

    async def _warmup_inference(self, backend: BaseModelBackend, voice: torch.Tensor) -> None:
        """Run warmup inference to initialize model.

        Args:
            backend: Model backend to warm up
            voice: Voice tensor already loaded on correct device
        """
        try:
            # Import here to avoid circular imports
            from ..services.text_processing import process_text

            # Use real text
            text = "Testing text to speech synthesis."

            # Process through pipeline
            tokens = process_text(text)
            if not tokens:
                raise ValueError("Text processing failed")

            # Run inference
            backend.generate(tokens, voice, speed=1.0)
            logger.debug("Completed warmup inference")

        except Exception as e:
            logger.warning(f"Warmup inference failed: {e}")
            raise

    async def generate(
        self,
        tokens: list[int],
        voice: torch.Tensor,
        speed: float = 1.0,
        backend_type: Optional[str] = None
    ) -> torch.Tensor:
        """Generate audio using specified backend.

        Args:
            tokens: Input token IDs
            voice: Voice tensor already loaded on correct device
            speed: Speed multiplier
            backend_type: Backend to use, uses default if None

        Returns:
            Generated audio tensor

        Raises:
            RuntimeError: If generation fails
        """
        backend = self.get_backend(backend_type)
        if not backend.is_loaded:
            raise RuntimeError("Model not loaded")

        try:
            # Generate audio using provided voice tensor;
            # no lock needed here since inference is thread-safe
            return backend.generate(tokens, voice, speed)
        except Exception as e:
            raise RuntimeError(f"Generation failed: {e}") from e
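
    # Illustrative generation call (a sketch; `manager`, `tokens`, and `voice`
    # are assumed to come from the surrounding application):
    #
    #     audio = await manager.generate(tokens, voice, speed=1.0)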

    def unload_all(self) -> None:
        """Unload models from all backends and clear cache."""
        # Clean up session pools
        for pool in self._session_pools.values():
            pool.cleanup()

        # Unload PyTorch backends
        for backend in self._backends.values():
            backend.unload()

        self._loaded_models.clear()
        logger.info("Unloaded all models and cleared cache")

    @property
    def available_backends(self) -> list[str]:
        """Get list of available backends."""
        return list(self._backends.keys())

    @property
    def current_backend(self) -> str:
        """Get current default backend."""
        return self._current_backend


async def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
    """Get global model manager instance.

    Args:
        config: Optional model configuration

    Returns:
        ModelManager instance

    Thread Safety:
        Safe for concurrent use: the fast path returns an existing instance
        without locking, and creation is guarded by an asyncio lock with a
        double-check before initialization.
    """
    global _manager_instance

    # Fast path - return existing instance without lock
    if _manager_instance is not None:
        return _manager_instance

    # Slow path - create new instance with lock
    async with _manager_lock:
        # Double-check pattern
        if _manager_instance is None:
            _manager_instance = ModelManager(config)
            await _manager_instance.initialize()
        return _manager_instance
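
# Illustrative access pattern (a sketch; concurrent callers all receive the
# same singleton because creation is guarded by `_manager_lock`):
#
#     manager = await get_manager()
#     assert manager is await get_manager()  # repeat calls return the same instance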