"""Model management and caching."""

from typing import Dict, Optional

import torch
from loguru import logger

from .base import BaseModelBackend
from .voice_manager import get_manager as get_voice_manager
from .onnx_cpu import ONNXCPUBackend
from .onnx_gpu import ONNXGPUBackend
from .pytorch_cpu import PyTorchCPUBackend
from .pytorch_gpu import PyTorchGPUBackend
from ..core import paths
from ..core.config import settings
from ..structures.model_schemas import ModelConfig


class ModelManager:
    """Manages model loading and inference across backends."""

    def __init__(self, config: Optional[ModelConfig] = None):
        """Initialize model manager.

        Args:
            config: Optional configuration
        """
        self._config = config or ModelConfig()
        self._backends: Dict[str, BaseModelBackend] = {}
        self._current_backend: Optional[str] = None
        self._voice_manager = get_voice_manager()
        self._initialize_backends()

    def _initialize_backends(self) -> None:
        """Initialize available backends."""
        # Initialize GPU backends if available
        if settings.use_gpu and torch.cuda.is_available():
            try:
                # PyTorch GPU
                self._backends['pytorch_gpu'] = PyTorchGPUBackend()
                self._current_backend = 'pytorch_gpu'
                logger.info("Initialized PyTorch GPU backend")

                # ONNX GPU
                self._backends['onnx_gpu'] = ONNXGPUBackend()
                logger.info("Initialized ONNX GPU backend")
            except Exception as e:
                logger.error(f"Failed to initialize GPU backends: {e}")
                # Fall back to CPU if GPU initialization fails
                self._initialize_cpu_backends()
        else:
            self._initialize_cpu_backends()

    def _initialize_cpu_backends(self) -> None:
        """Initialize CPU backends."""
        try:
            # PyTorch CPU
            self._backends['pytorch_cpu'] = PyTorchCPUBackend()
            self._current_backend = 'pytorch_cpu'
            logger.info("Initialized PyTorch CPU backend")

            # ONNX CPU
            self._backends['onnx_cpu'] = ONNXCPUBackend()
            logger.info("Initialized ONNX CPU backend")
        except Exception as e:
            logger.error(f"Failed to initialize CPU backends: {e}")
            raise RuntimeError("No backends available") from e

    def get_backend(self, backend_type: Optional[str] = None) -> BaseModelBackend:
        """Get specified backend.

        Args:
            backend_type: Backend type ('pytorch_cpu', 'pytorch_gpu',
                'onnx_cpu', 'onnx_gpu'); uses the default backend if None

        Returns:
            Model backend instance

        Raises:
            ValueError: If backend type is invalid
            RuntimeError: If no backends are available
        """
        if not self._backends:
            raise RuntimeError("No backends available")

        if backend_type is None:
            backend_type = self._current_backend

        if backend_type not in self._backends:
            raise ValueError(
                f"Invalid backend type: {backend_type}. "
                f"Available backends: {', '.join(self._backends.keys())}"
            )

        return self._backends[backend_type]

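    # Example: a minimal usage sketch, assuming a manager already exists via
    # get_manager() below. 'onnx_cpu' is only an illustrative choice and is
    # present only if that backend initialized successfully.
    #
    #     manager = get_manager()
    #     backend = manager.get_backend()         # current default backend
    #     onnx = manager.get_backend('onnx_cpu')  # explicit choice; raises
    #                                             # ValueError if unavailable
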
    def _determine_backend(self, model_path: str) -> str:
        """Determine appropriate backend based on model file.

        Args:
            model_path: Path to model file

        Returns:
            Backend type to use
        """
        is_onnx = model_path.lower().endswith('.onnx')
        has_gpu = settings.use_gpu and torch.cuda.is_available()

        if is_onnx:
            return 'onnx_gpu' if has_gpu else 'onnx_cpu'
        return 'pytorch_gpu' if has_gpu else 'pytorch_cpu'

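    # Example mapping (file names are hypothetical): with use_gpu enabled and
    # CUDA available, 'model.onnx' -> 'onnx_gpu' and 'model.pth' -> 'pytorch_gpu';
    # without a usable GPU the same files map to 'onnx_cpu' and 'pytorch_cpu'.
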
    async def load_model(
        self,
        model_path: str,
        backend_type: Optional[str] = None
    ) -> None:
        """Load model on specified backend.

        Args:
            model_path: Path to model file
            backend_type: Backend to load on; uses default if None

        Raises:
            RuntimeError: If model loading fails
        """
        try:
            # Get absolute model path
            abs_path = await paths.get_model_path(model_path)

            # Auto-determine backend if not specified
            if backend_type is None:
                backend_type = self._determine_backend(abs_path)

            backend = self.get_backend(backend_type)

            # Load model and run warmup
            await backend.load_model(abs_path)
            logger.info(f"Loaded model on {backend_type} backend")
            await self._warmup_inference(backend)

        except Exception as e:
            raise RuntimeError(f"Failed to load model: {e}") from e

    async def _warmup_inference(self, backend: BaseModelBackend) -> None:
        """Run warmup inference to initialize model."""
        try:
            # Import here to avoid circular imports
            from ..text_processing import process_text

            # Load default voice for warmup
            voice = await self._voice_manager.load_voice(
                settings.default_voice, device=backend.device
            )
            logger.info(f"Loaded voice {settings.default_voice} for warmup")

            # Use real text so the full pipeline is exercised
            text = "Testing text to speech synthesis."
            logger.info(f"Running warmup inference with voice: {settings.default_voice}")

            # Process through pipeline
            sequences = process_text(text)
            if not sequences:
                raise ValueError("Text processing failed")

            # Run inference
            backend.generate(sequences[0], voice, speed=1.0)

        except Exception as e:
            logger.warning(f"Warmup inference failed: {e}")
            raise

    async def generate(
        self,
        tokens: list[int],
        voice_name: str,
        speed: float = 1.0,
        backend_type: Optional[str] = None
    ) -> torch.Tensor:
        """Generate audio using specified backend.

        Args:
            tokens: Input token IDs
            voice_name: Name of voice to use
            speed: Speed multiplier
            backend_type: Backend to use; uses default if None

        Returns:
            Generated audio tensor

        Raises:
            RuntimeError: If generation fails
        """
        backend = self.get_backend(backend_type)
        if not backend.is_loaded:
            raise RuntimeError("Model not loaded")

        try:
            # Load voice using voice manager
            voice = await self._voice_manager.load_voice(voice_name, device=backend.device)

            # Generate audio
            return backend.generate(tokens, voice, speed)

        except Exception as e:
            raise RuntimeError(f"Generation failed: {e}") from e

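    # Example: a minimal sketch of a call site (the input text is illustrative;
    # token IDs normally come from the text_processing pipeline, as in
    # _warmup_inference above):
    #
    #     sequences = process_text("Hello world")
    #     audio = await manager.generate(sequences[0], settings.default_voice)
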
    def unload_all(self) -> None:
        """Unload models from all backends."""
        for backend in self._backends.values():
            backend.unload()
        logger.info("Unloaded all models")

    @property
    def available_backends(self) -> list[str]:
        """Get list of available backends.

        Returns:
            List of backend names
        """
        return list(self._backends.keys())

    @property
    def current_backend(self) -> str:
        """Get current default backend.

        Returns:
            Backend name
        """
        return self._current_backend


# Module-level instance
_manager: Optional[ModelManager] = None


def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
    """Get or create global model manager instance.

    Args:
        config: Optional model configuration

    Returns:
        ModelManager instance
    """
    global _manager
    if _manager is None:
        _manager = ModelManager(config)
    return _manager
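
# Example: a minimal end-to-end sketch. The model filename is hypothetical,
# and the awaits must run inside an event loop in application code (this
# module uses relative imports, so it cannot be executed directly):
#
#     manager = get_manager()
#     await manager.load_model("model.pth")  # backend auto-selected from extension
#     audio = await manager.generate(tokens, settings.default_voice, speed=1.0)
#     manager.unload_all()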