mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
Refactor inference architecture: remove legacy TTS model, add ONNX and PyTorch backends, and introduce model configuration schemas
This commit is contained in:
parent 83c55ca735
commit ab28a62e86

16 changed files with 1606 additions and 813 deletions
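A minimal end-to-end sketch of how the pieces added in this commit fit together. This is hypothetical caller code inferred from the modules in the diff below; the import path `api.src.inference`, the file name "model.onnx", the token list, and the voice name are placeholders, not artifacts shipped by this commit.

import asyncio

from api.src.inference import get_manager  # assumes the package is importable under this path


async def main():
    manager = get_manager()                   # module-level ModelManager singleton
    await manager.load_model("model.onnx")    # backend is picked from the file extension
    audio = await manager.generate(
        tokens=[12, 34, 56],                  # token IDs from the text-processing pipeline
        voice_name="af",                      # placeholder voice name
        speed=1.0,
    )
    print(audio.shape)


asyncio.run(main())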
3  .gitignore  vendored
@@ -18,7 +18,8 @@ __pycache__/
*.egg
dist/
build/

*.onnx
*.pth

# Environment
# .env
.venv/
198  api/src/core/paths.py  Normal file
@@ -0,0 +1,198 @@

"""Async file and path operations."""

import io
import os
from pathlib import Path
from typing import List, Optional, AsyncIterator, Callable, Set

import aiofiles
import aiofiles.os
import torch
from loguru import logger

from .config import settings


async def _find_file(
    filename: str,
    search_paths: List[str],
    filter_fn: Optional[Callable[[str], bool]] = None
) -> str:
    """Find file in search paths.

    Args:
        filename: Name of file to find
        search_paths: List of paths to search in
        filter_fn: Optional function to filter files

    Returns:
        Absolute path to file

    Raises:
        RuntimeError: If file not found
    """
    if os.path.isabs(filename) and await aiofiles.os.path.exists(filename):
        return filename

    for path in search_paths:
        full_path = os.path.join(path, filename)
        if await aiofiles.os.path.exists(full_path):
            if filter_fn is None or filter_fn(full_path):
                return full_path

    raise RuntimeError(f"File not found: {filename} in paths: {search_paths}")


async def _scan_directories(
    search_paths: List[str],
    filter_fn: Optional[Callable[[str], bool]] = None
) -> Set[str]:
    """Scan directories for files.

    Args:
        search_paths: List of paths to scan
        filter_fn: Optional function to filter files

    Returns:
        Set of matching filenames
    """
    results = set()

    for path in search_paths:
        if not await aiofiles.os.path.exists(path):
            continue

        try:
            # Get directory entries first
            entries = await aiofiles.os.scandir(path)
            # Then process entries after await completes
            for entry in entries:
                if filter_fn is None or filter_fn(entry.name):
                    results.add(entry.name)
        except Exception as e:
            logger.warning(f"Error scanning {path}: {e}")

    return results


async def get_model_path(model_name: str) -> str:
    """Get path to model file.

    Args:
        model_name: Name of model file

    Returns:
        Absolute path to model file

    Raises:
        RuntimeError: If model not found
    """
    search_paths = [
        settings.model_dir,
        os.path.join(os.path.dirname(__file__), "..", "..", "..", "models")
    ]

    return await _find_file(model_name, search_paths)


async def get_voice_path(voice_name: str) -> str:
    """Get path to voice file.

    Args:
        voice_name: Name of voice file (without .pt extension)

    Returns:
        Absolute path to voice file

    Raises:
        RuntimeError: If voice not found
    """
    voice_file = f"{voice_name}.pt"

    search_paths = [
        os.path.join(settings.model_dir, "..", settings.voices_dir),
        os.path.join(os.path.dirname(__file__), "..", settings.voices_dir)
    ]

    return await _find_file(voice_file, search_paths)


async def list_voices() -> List[str]:
    """List available voice files.

    Returns:
        List of voice names (without .pt extension)
    """
    search_paths = [
        os.path.join(settings.model_dir, "..", settings.voices_dir),
        os.path.join(os.path.dirname(__file__), "..", settings.voices_dir)
    ]

    def filter_voice_files(name: str) -> bool:
        return name.endswith('.pt')

    voices = await _scan_directories(search_paths, filter_voice_files)
    return sorted([name[:-3] for name in voices])  # Remove .pt extension


async def load_voice_tensor(voice_path: str, device: str = "cpu") -> torch.Tensor:
    """Load voice tensor from file.

    Args:
        voice_path: Path to voice file
        device: Device to load tensor to

    Returns:
        Voice tensor

    Raises:
        RuntimeError: If file cannot be read
    """
    try:
        async with aiofiles.open(voice_path, 'rb') as f:
            data = await f.read()
            return torch.load(
                io.BytesIO(data),
                map_location=device,
                weights_only=True
            )
    except Exception as e:
        raise RuntimeError(f"Failed to load voice tensor from {voice_path}: {e}")


async def save_voice_tensor(tensor: torch.Tensor, voice_path: str) -> None:
    """Save voice tensor to file.

    Args:
        tensor: Voice tensor to save
        voice_path: Path to save voice file

    Raises:
        RuntimeError: If file cannot be written
    """
    try:
        buffer = io.BytesIO()
        torch.save(tensor, buffer)
        async with aiofiles.open(voice_path, 'wb') as f:
            await f.write(buffer.getvalue())
    except Exception as e:
        raise RuntimeError(f"Failed to save voice tensor to {voice_path}: {e}")


async def read_file(path: str) -> str:
    """Read text file asynchronously.

    Args:
        path: Path to file

    Returns:
        File contents as string

    Raises:
        RuntimeError: If file cannot be read
    """
    try:
        async with aiofiles.open(path, 'r', encoding='utf-8') as f:
            return await f.read()
    except Exception as e:
        raise RuntimeError(f"Failed to read file {path}: {e}")
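A short usage sketch for the async path helpers above. The import path and the file name "model.onnx" are placeholders assumed for illustration; the functions themselves are the ones defined in this file.

import asyncio

from api.src.core import paths  # assumed import path


async def main():
    voices = await paths.list_voices()                 # voice names without the .pt suffix
    model = await paths.get_model_path("model.onnx")   # raises RuntimeError if not found
    print(voices, model)


asyncio.run(main())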
20  api/src/inference/__init__.py  Normal file
@@ -0,0 +1,20 @@

"""Inference backends and model management."""

from .base import BaseModelBackend
from .model_manager import ModelManager, get_manager
from .onnx_cpu import ONNXCPUBackend
from .onnx_gpu import ONNXGPUBackend
from .pytorch_cpu import PyTorchCPUBackend
from .pytorch_gpu import PyTorchGPUBackend
from ..structures.model_schemas import ModelConfig

__all__ = [
    'BaseModelBackend',
    'ModelManager',
    'get_manager',
    'ModelConfig',
    'ONNXCPUBackend',
    'ONNXGPUBackend',
    'PyTorchCPUBackend',
    'PyTorchGPUBackend'
]
97  api/src/inference/base.py  Normal file
@@ -0,0 +1,97 @@

"""Base interfaces for model inference."""

from abc import ABC, abstractmethod
from typing import List, Optional

import numpy as np
import torch


class ModelBackend(ABC):
    """Abstract base class for model inference backends."""

    @abstractmethod
    async def load_model(self, path: str) -> None:
        """Load model from path.

        Args:
            path: Path to model file

        Raises:
            RuntimeError: If model loading fails
        """
        pass

    @abstractmethod
    def generate(
        self,
        tokens: List[int],
        voice: torch.Tensor,
        speed: float = 1.0
    ) -> np.ndarray:
        """Generate audio from tokens.

        Args:
            tokens: Input token IDs
            voice: Voice embedding tensor
            speed: Speed multiplier

        Returns:
            Generated audio samples

        Raises:
            RuntimeError: If generation fails
        """
        pass

    @abstractmethod
    def unload(self) -> None:
        """Unload model and free resources."""
        pass

    @property
    @abstractmethod
    def is_loaded(self) -> bool:
        """Check if model is loaded.

        Returns:
            True if model is loaded, False otherwise
        """
        pass

    @property
    @abstractmethod
    def device(self) -> str:
        """Get device model is running on.

        Returns:
            Device string ('cpu' or 'cuda')
        """
        pass


class BaseModelBackend(ModelBackend):
    """Base implementation of model backend."""

    def __init__(self):
        """Initialize base backend."""
        self._model: Optional[torch.nn.Module] = None
        self._device: str = "cpu"

    @property
    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        return self._model is not None

    @property
    def device(self) -> str:
        """Get device model is running on."""
        return self._device

    def unload(self) -> None:
        """Unload model and free resources."""
        if self._model is not None:
            del self._model
            self._model = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
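Sketch of what a further backend would have to provide on top of BaseModelBackend: only load_model and generate remain abstract, while unload, is_loaded and device come from the base class. EchoBackend is illustrative only and not part of this commit.

import numpy as np
import torch

from api.src.inference.base import BaseModelBackend  # assumed import path


class EchoBackend(BaseModelBackend):
    """Toy backend that returns silence, showing the required surface."""

    async def load_model(self, path: str) -> None:
        self._model = torch.nn.Identity()  # any non-None value makes is_loaded True

    def generate(self, tokens, voice: torch.Tensor, speed: float = 1.0) -> np.ndarray:
        if not self.is_loaded:
            raise RuntimeError("Model not loaded")
        return np.zeros(len(tokens) * 100, dtype=np.float32)  # placeholder audio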
251  api/src/inference/model_manager.py  Normal file
@@ -0,0 +1,251 @@

"""Model management and caching."""

import os
from typing import Dict, List, Optional, Union

import torch
from loguru import logger
from pydantic import BaseModel

from .base import BaseModelBackend
from .voice_manager import get_manager as get_voice_manager
from .onnx_cpu import ONNXCPUBackend
from .onnx_gpu import ONNXGPUBackend
from .pytorch_cpu import PyTorchCPUBackend
from .pytorch_gpu import PyTorchGPUBackend
from ..core import paths
from ..core.config import settings
from ..structures.model_schemas import ModelConfig


class ModelManager:
    """Manages model loading and inference across backends."""

    def __init__(self, config: Optional[ModelConfig] = None):
        """Initialize model manager.

        Args:
            config: Optional configuration
        """
        self._config = config or ModelConfig()
        self._backends: Dict[str, BaseModelBackend] = {}
        self._current_backend: Optional[str] = None
        self._voice_manager = get_voice_manager()
        self._initialize_backends()

    def _initialize_backends(self) -> None:
        """Initialize available backends."""
        # Initialize GPU backends if available
        if settings.use_gpu and torch.cuda.is_available():
            try:
                # PyTorch GPU
                self._backends['pytorch_gpu'] = PyTorchGPUBackend()
                self._current_backend = 'pytorch_gpu'
                logger.info("Initialized PyTorch GPU backend")

                # ONNX GPU
                self._backends['onnx_gpu'] = ONNXGPUBackend()
                logger.info("Initialized ONNX GPU backend")
            except Exception as e:
                logger.error(f"Failed to initialize GPU backends: {e}")
                # Fall back to CPU if GPU fails
                self._initialize_cpu_backends()
        else:
            self._initialize_cpu_backends()

    def _initialize_cpu_backends(self) -> None:
        """Initialize CPU backends."""
        try:
            # PyTorch CPU
            self._backends['pytorch_cpu'] = PyTorchCPUBackend()
            self._current_backend = 'pytorch_cpu'
            logger.info("Initialized PyTorch CPU backend")

            # ONNX CPU
            self._backends['onnx_cpu'] = ONNXCPUBackend()
            logger.info("Initialized ONNX CPU backend")
        except Exception as e:
            logger.error(f"Failed to initialize CPU backends: {e}")
            raise RuntimeError("No backends available")

    def get_backend(self, backend_type: Optional[str] = None) -> BaseModelBackend:
        """Get specified backend.

        Args:
            backend_type: Backend type ('pytorch_cpu', 'pytorch_gpu', 'onnx_cpu', 'onnx_gpu'),
                uses default if None

        Returns:
            Model backend instance

        Raises:
            ValueError: If backend type is invalid
            RuntimeError: If no backends are available
        """
        if not self._backends:
            raise RuntimeError("No backends available")

        if backend_type is None:
            backend_type = self._current_backend

        if backend_type not in self._backends:
            raise ValueError(
                f"Invalid backend type: {backend_type}. "
                f"Available backends: {', '.join(self._backends.keys())}"
            )

        return self._backends[backend_type]

    def _determine_backend(self, model_path: str) -> str:
        """Determine appropriate backend based on model file.

        Args:
            model_path: Path to model file

        Returns:
            Backend type to use
        """
        is_onnx = model_path.lower().endswith('.onnx')
        has_gpu = settings.use_gpu and torch.cuda.is_available()

        if is_onnx:
            return 'onnx_gpu' if has_gpu else 'onnx_cpu'
        else:
            return 'pytorch_gpu' if has_gpu else 'pytorch_cpu'

    async def load_model(
        self,
        model_path: str,
        backend_type: Optional[str] = None
    ) -> None:
        """Load model on specified backend.

        Args:
            model_path: Path to model file
            backend_type: Backend to load on, uses default if None

        Raises:
            RuntimeError: If model loading fails
        """
        try:
            # Get absolute model path
            abs_path = await paths.get_model_path(model_path)

            # Auto-determine backend if not specified
            if backend_type is None:
                backend_type = self._determine_backend(abs_path)

            backend = self.get_backend(backend_type)

            # Load model and run warmup
            await backend.load_model(abs_path)
            logger.info(f"Loaded model on {backend_type} backend")
            await self._warmup_inference(backend)

        except Exception as e:
            raise RuntimeError(f"Failed to load model: {e}")

    async def _warmup_inference(self, backend: BaseModelBackend) -> None:
        """Run warmup inference to initialize model."""
        try:
            # Import here to avoid circular imports
            from ..text_processing import process_text

            # Load default voice for warmup
            voice = await self._voice_manager.load_voice(settings.default_voice, device=backend.device)
            logger.info(f"Loaded voice {settings.default_voice} for warmup")

            # Use real text
            text = "Testing text to speech synthesis."
            logger.info(f"Running warmup inference with voice: {settings.default_voice}")

            # Process through pipeline
            sequences = process_text(text)
            if not sequences:
                raise ValueError("Text processing failed")

            # Run inference
            backend.generate(sequences[0], voice, speed=1.0)

        except Exception as e:
            logger.warning(f"Warmup inference failed: {e}")
            raise

    async def generate(
        self,
        tokens: list[int],
        voice_name: str,
        speed: float = 1.0,
        backend_type: Optional[str] = None
    ) -> torch.Tensor:
        """Generate audio using specified backend.

        Args:
            tokens: Input token IDs
            voice_name: Name of voice to use
            speed: Speed multiplier
            backend_type: Backend to use, uses default if None

        Returns:
            Generated audio tensor

        Raises:
            RuntimeError: If generation fails
        """
        backend = self.get_backend(backend_type)
        if not backend.is_loaded:
            raise RuntimeError("Model not loaded")

        try:
            # Load voice using voice manager
            voice = await self._voice_manager.load_voice(voice_name, device=backend.device)

            # Generate audio
            return backend.generate(tokens, voice, speed)

        except Exception as e:
            raise RuntimeError(f"Generation failed: {e}")

    def unload_all(self) -> None:
        """Unload models from all backends."""
        for backend in self._backends.values():
            backend.unload()
        logger.info("Unloaded all models")

    @property
    def available_backends(self) -> list[str]:
        """Get list of available backends.

        Returns:
            List of backend names
        """
        return list(self._backends.keys())

    @property
    def current_backend(self) -> str:
        """Get current default backend.

        Returns:
            Backend name
        """
        return self._current_backend


# Module-level instance
_manager: Optional[ModelManager] = None


def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
    """Get or create global model manager instance.

    Args:
        config: Optional model configuration

    Returns:
        ModelManager instance
    """
    global _manager
    if _manager is None:
        _manager = ModelManager(config)
    return _manager
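How backend selection plays out in practice, sketched from _determine_backend above; the file names are placeholders. A ".onnx" path maps to 'onnx_gpu' when settings.use_gpu and CUDA are available, otherwise 'onnx_cpu'; any other extension maps to 'pytorch_gpu' or 'pytorch_cpu' under the same condition. A specific backend can also be forced:

from api.src.inference import ModelManager  # assumed import path


async def load_on_cpu(manager: ModelManager) -> None:
    # Override auto-selection and load on the ONNX CPU backend explicitly.
    await manager.load_model("model.onnx", backend_type="onnx_cpu")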
154  api/src/inference/onnx_cpu.py  Normal file
@@ -0,0 +1,154 @@

"""CPU-based ONNX inference backend."""

from typing import Dict, Optional

import numpy as np
import torch
from loguru import logger
from onnxruntime import (
    ExecutionMode,
    GraphOptimizationLevel,
    InferenceSession,
    SessionOptions
)

from ..core import paths
from ..core.config import settings
from ..structures.model_schemas import ONNXConfig
from .base import BaseModelBackend


class ONNXCPUBackend(BaseModelBackend):
    """ONNX-based CPU inference backend."""

    def __init__(self):
        """Initialize CPU backend."""
        super().__init__()
        self._device = "cpu"
        self._session: Optional[InferenceSession] = None
        self._config = ONNXConfig(
            optimization_level=settings.onnx_optimization_level,
            num_threads=settings.onnx_num_threads,
            inter_op_threads=settings.onnx_inter_op_threads,
            execution_mode=settings.onnx_execution_mode,
            memory_pattern=settings.onnx_memory_pattern,
            arena_extend_strategy=settings.onnx_arena_extend_strategy
        )

    async def load_model(self, path: str) -> None:
        """Load ONNX model.

        Args:
            path: Path to model file

        Raises:
            RuntimeError: If model loading fails
        """
        try:
            # Get verified model path
            model_path = await paths.get_model_path(path)

            logger.info(f"Loading ONNX model: {model_path}")

            # Configure session
            options = self._create_session_options()
            provider_options = self._create_provider_options()

            # Create session
            self._session = InferenceSession(
                model_path,
                sess_options=options,
                providers=["CPUExecutionProvider"],
                provider_options=[provider_options]
            )

        except Exception as e:
            raise RuntimeError(f"Failed to load ONNX model: {e}")

    def generate(
        self,
        tokens: list[int],
        voice: torch.Tensor,
        speed: float = 1.0
    ) -> np.ndarray:
        """Generate audio using ONNX model.

        Args:
            tokens: Input token IDs
            voice: Voice embedding tensor
            speed: Speed multiplier

        Returns:
            Generated audio samples

        Raises:
            RuntimeError: If generation fails
        """
        if not self.is_loaded:
            raise RuntimeError("Model not loaded")

        try:
            # Prepare inputs
            tokens_input = np.array([tokens], dtype=np.int64)
            style_input = voice[len(tokens)].numpy()
            speed_input = np.full(1, speed, dtype=np.float32)

            # Run inference
            result = self._session.run(
                None,
                {
                    "tokens": tokens_input,
                    "style": style_input,
                    "speed": speed_input
                }
            )

            return result[0]

        except Exception as e:
            raise RuntimeError(f"Generation failed: {e}")

    def _create_session_options(self) -> SessionOptions:
        """Create ONNX session options.

        Returns:
            Configured session options
        """
        options = SessionOptions()

        # Set optimization level
        if self._config.optimization_level == "all":
            options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
        elif self._config.optimization_level == "basic":
            options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
        else:
            options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL

        # Configure threading
        options.intra_op_num_threads = self._config.num_threads
        options.inter_op_num_threads = self._config.inter_op_threads

        # Set execution mode
        options.execution_mode = (
            ExecutionMode.ORT_PARALLEL
            if self._config.execution_mode == "parallel"
            else ExecutionMode.ORT_SEQUENTIAL
        )

        # Configure memory optimization
        options.enable_mem_pattern = self._config.memory_pattern

        return options

    def _create_provider_options(self) -> Dict:
        """Create CPU provider options.

        Returns:
            Provider configuration
        """
        return {
            "CPUExecutionProvider": {
                "arena_extend_strategy": self._config.arena_extend_strategy,
                "cpu_memory_arena_cfg": "cpu:0"
            }
        }
163  api/src/inference/onnx_gpu.py  Normal file
@@ -0,0 +1,163 @@

"""GPU-based ONNX inference backend."""

from typing import Dict, Optional

import numpy as np
import torch
from loguru import logger
from onnxruntime import (
    ExecutionMode,
    GraphOptimizationLevel,
    InferenceSession,
    SessionOptions
)

from ..core import paths
from ..core.config import settings
from ..structures.model_schemas import ONNXGPUConfig
from .base import BaseModelBackend


class ONNXGPUBackend(BaseModelBackend):
    """ONNX-based GPU inference backend."""

    def __init__(self):
        """Initialize GPU backend."""
        super().__init__()
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA not available")
        self._device = "cuda"
        self._session: Optional[InferenceSession] = None
        self._config = ONNXGPUConfig(
            optimization_level=settings.onnx_optimization_level,
            num_threads=settings.onnx_num_threads,
            inter_op_threads=settings.onnx_inter_op_threads,
            execution_mode=settings.onnx_execution_mode,
            memory_pattern=settings.onnx_memory_pattern,
            arena_extend_strategy=settings.onnx_arena_extend_strategy,
            device_id=0,
            gpu_mem_limit=0.7,
            cudnn_conv_algo_search="EXHAUSTIVE",
            do_copy_in_default_stream=True
        )

    async def load_model(self, path: str) -> None:
        """Load ONNX model.

        Args:
            path: Path to model file

        Raises:
            RuntimeError: If model loading fails
        """
        try:
            # Get verified model path
            model_path = await paths.get_model_path(path)

            logger.info(f"Loading ONNX model on GPU: {model_path}")

            # Configure session
            options = self._create_session_options()
            provider_options = self._create_provider_options()

            # Create session with CUDA provider
            self._session = InferenceSession(
                model_path,
                sess_options=options,
                providers=["CUDAExecutionProvider"],
                provider_options=[provider_options]
            )

        except Exception as e:
            raise RuntimeError(f"Failed to load ONNX model: {e}")

    def generate(
        self,
        tokens: list[int],
        voice: torch.Tensor,
        speed: float = 1.0
    ) -> np.ndarray:
        """Generate audio using ONNX model.

        Args:
            tokens: Input token IDs
            voice: Voice embedding tensor
            speed: Speed multiplier

        Returns:
            Generated audio samples

        Raises:
            RuntimeError: If generation fails
        """
        if not self.is_loaded:
            raise RuntimeError("Model not loaded")

        try:
            # Prepare inputs
            tokens_input = np.array([tokens], dtype=np.int64)
            style_input = voice[len(tokens)].cpu().numpy()  # Move to CPU for ONNX
            speed_input = np.full(1, speed, dtype=np.float32)

            # Run inference
            result = self._session.run(
                None,
                {
                    "tokens": tokens_input,
                    "style": style_input,
                    "speed": speed_input
                }
            )

            return result[0]

        except Exception as e:
            raise RuntimeError(f"Generation failed: {e}")

    def _create_session_options(self) -> SessionOptions:
        """Create ONNX session options.

        Returns:
            Configured session options
        """
        options = SessionOptions()

        # Set optimization level
        if self._config.optimization_level == "all":
            options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
        elif self._config.optimization_level == "basic":
            options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
        else:
            options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL

        # Configure threading
        options.intra_op_num_threads = self._config.num_threads
        options.inter_op_num_threads = self._config.inter_op_threads

        # Set execution mode
        options.execution_mode = (
            ExecutionMode.ORT_PARALLEL
            if self._config.execution_mode == "parallel"
            else ExecutionMode.ORT_SEQUENTIAL
        )

        # Configure memory optimization
        options.enable_mem_pattern = self._config.memory_pattern

        return options

    def _create_provider_options(self) -> Dict:
        """Create CUDA provider options.

        Returns:
            Provider configuration
        """
        return {
            "CUDAExecutionProvider": {
                "device_id": self._config.device_id,
                "arena_extend_strategy": self._config.arena_extend_strategy,
                "gpu_mem_limit": int(self._config.gpu_mem_limit * torch.cuda.get_device_properties(0).total_memory),
                "cudnn_conv_algo_search": self._config.cudnn_conv_algo_search,
                "do_copy_in_default_stream": self._config.do_copy_in_default_stream
            }
        }
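Worth noting: gpu_mem_limit is stored as a fraction (0.7) in the config, and _create_provider_options converts it into a byte count before handing it to the CUDA provider, which appears to be the unit that option expects. A rough sketch of the arithmetic for a hypothetical 16 GiB card:

total = 16 * 1024**3      # stand-in for torch.cuda.get_device_properties(0).total_memory
limit = int(0.7 * total)  # -> 12025908428 bytes passed as gpu_mem_limit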
181  api/src/inference/pytorch_cpu.py  Normal file
@@ -0,0 +1,181 @@

"""CPU-based PyTorch inference backend."""

import gc
from typing import Optional

import numpy as np
import torch
from loguru import logger

from ..builds.models import build_model
from ..core import paths
from ..structures.model_schemas import PyTorchCPUConfig
from .base import BaseModelBackend


@torch.no_grad()
def forward(model: torch.nn.Module, tokens: list[int], ref_s: torch.Tensor, speed: float) -> np.ndarray:
    """Forward pass through model with memory management.

    Args:
        model: PyTorch model
        tokens: Input tokens
        ref_s: Reference signal
        speed: Speed multiplier

    Returns:
        Generated audio
    """
    device = ref_s.device
    pred_aln_trg = None
    asr = None

    try:
        # Initial tensor setup
        tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
        text_mask = length_to_mask(input_lengths).to(device)

        # Split reference signals
        s_content = ref_s[:, 128:].clone().to(device)
        s_ref = ref_s[:, :128].clone().to(device)

        # BERT and encoder pass
        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

        # Predictor forward pass
        d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
        x, _ = model.predictor.lstm(d)

        # Duration prediction
        duration = model.predictor.duration_proj(x)
        duration = torch.sigmoid(duration).sum(axis=-1) / speed
        pred_dur = torch.round(duration).clamp(min=1).long()
        del duration, x  # Free large intermediates

        # Alignment matrix construction
        pred_aln_trg = torch.zeros(
            input_lengths.item(),
            pred_dur.sum().item(),
            device=device
        )
        c_frame = 0
        for i in range(pred_aln_trg.size(0)):
            pred_aln_trg[i, c_frame:c_frame + pred_dur[0, i].item()] = 1
            c_frame += pred_dur[0, i].item()
        pred_aln_trg = pred_aln_trg.unsqueeze(0)

        # Matrix multiplications with cleanup
        en = d.transpose(-1, -2) @ pred_aln_trg
        del d

        F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
        del en

        # Final text encoding and decoding
        t_en = model.text_encoder(tokens, input_lengths, text_mask)
        asr = t_en @ pred_aln_trg
        del t_en

        # Generate output
        output = model.decoder(asr, F0_pred, N_pred, s_ref)
        result = output.squeeze().cpu().numpy()

        return result

    finally:
        # Clean up largest tensors if they were created
        if pred_aln_trg is not None:
            del pred_aln_trg
        if asr is not None:
            del asr


def length_to_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Create attention mask from lengths.

    Args:
        lengths: Sequence lengths

    Returns:
        Boolean mask tensor
    """
    max_len = lengths.max()
    mask = torch.arange(max_len, device=lengths.device)[None, :].expand(
        lengths.shape[0], -1
    )
    if lengths.dtype != mask.dtype:
        mask = mask.to(dtype=lengths.dtype)
    return mask + 1 > lengths[:, None]


class PyTorchCPUBackend(BaseModelBackend):
    """PyTorch CPU inference backend."""

    def __init__(self):
        """Initialize CPU backend."""
        super().__init__()
        self._device = "cpu"
        self._model: Optional[torch.nn.Module] = None
        self._config = PyTorchCPUConfig()

        # Configure PyTorch CPU settings
        if self._config.num_threads > 0:
            torch.set_num_threads(self._config.num_threads)
        if self._config.pin_memory:
            torch.set_default_tensor_type(torch.FloatTensor)

    async def load_model(self, path: str) -> None:
        """Load PyTorch model.

        Args:
            path: Path to model file

        Raises:
            RuntimeError: If model loading fails
        """
        try:
            # Get verified model path
            model_path = await paths.get_model_path(path)

            logger.info(f"Loading PyTorch model on CPU: {model_path}")
            self._model = await build_model(model_path, self._device)

        except Exception as e:
            raise RuntimeError(f"Failed to load PyTorch model: {e}")

    def generate(
        self,
        tokens: list[int],
        voice: torch.Tensor,
        speed: float = 1.0
    ) -> np.ndarray:
        """Generate audio using CPU model.

        Args:
            tokens: Input token IDs
            voice: Voice embedding tensor
            speed: Speed multiplier

        Returns:
            Generated audio samples

        Raises:
            RuntimeError: If generation fails
        """
        if not self.is_loaded:
            raise RuntimeError("Model not loaded")

        try:
            # Prepare input
            ref_s = voice[len(tokens)].clone()

            # Generate audio
            return forward(self._model, tokens, ref_s, speed)

        except Exception as e:
            raise RuntimeError(f"Generation failed: {e}")
        finally:
            # Clean up memory
            gc.collect()
170  api/src/inference/pytorch_gpu.py  Normal file
@@ -0,0 +1,170 @@

"""GPU-based PyTorch inference backend."""

import gc
from typing import Optional

import numpy as np
import torch
from loguru import logger

from ..builds.models import build_model
from ..core import paths
from ..structures.model_schemas import PyTorchConfig
from .base import BaseModelBackend


@torch.no_grad()
def forward(model: torch.nn.Module, tokens: list[int], ref_s: torch.Tensor, speed: float) -> np.ndarray:
    """Forward pass through model.

    Args:
        model: PyTorch model
        tokens: Input tokens
        ref_s: Reference signal (shape: [1, n_features])
        speed: Speed multiplier

    Returns:
        Generated audio
    """
    device = ref_s.device

    # Initial tensor setup
    tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
    input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
    text_mask = length_to_mask(input_lengths).to(device)

    # Split reference signals (style_dim=128 from config)
    style_dim = 128
    s_ref = ref_s[:, :style_dim].clone().to(device)
    s_content = ref_s[:, style_dim:].clone().to(device)

    # BERT and encoder pass
    bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
    d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

    # Predictor forward pass
    d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
    x, _ = model.predictor.lstm(d)

    # Duration prediction
    duration = model.predictor.duration_proj(x)
    duration = torch.sigmoid(duration).sum(axis=-1) / speed
    pred_dur = torch.round(duration).clamp(min=1).long()
    del duration, x

    # Alignment matrix construction
    pred_aln_trg = torch.zeros(input_lengths.item(), pred_dur.sum().item(), device=device)
    c_frame = 0
    for i in range(pred_aln_trg.size(0)):
        pred_aln_trg[i, c_frame:c_frame + pred_dur[0, i].item()] = 1
        c_frame += pred_dur[0, i].item()
    pred_aln_trg = pred_aln_trg.unsqueeze(0)

    # Matrix multiplications
    en = d.transpose(-1, -2) @ pred_aln_trg
    del d

    F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
    del en

    # Final text encoding and decoding
    t_en = model.text_encoder(tokens, input_lengths, text_mask)
    asr = t_en @ pred_aln_trg
    del t_en

    # Generate output
    output = model.decoder(asr, F0_pred, N_pred, s_ref)
    return output.squeeze().cpu().numpy()


def length_to_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Create attention mask from lengths."""
    max_len = lengths.max()
    mask = torch.arange(max_len, device=lengths.device)[None, :].expand(lengths.shape[0], -1)
    if lengths.dtype != mask.dtype:
        mask = mask.to(dtype=lengths.dtype)
    return mask + 1 > lengths[:, None]


class PyTorchGPUBackend(BaseModelBackend):
    """PyTorch GPU inference backend."""

    def __init__(self):
        """Initialize GPU backend."""
        super().__init__()
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA not available")
        self._device = "cuda"
        self._model: Optional[torch.nn.Module] = None
        self._config = PyTorchConfig()

    async def load_model(self, path: str) -> None:
        """Load PyTorch model.

        Args:
            path: Path to model file

        Raises:
            RuntimeError: If model loading fails
        """
        try:
            # Get verified model path
            model_path = await paths.get_model_path(path)

            logger.info(f"Loading PyTorch model: {model_path}")
            self._model = await build_model(model_path, self._device)

        except Exception as e:
            raise RuntimeError(f"Failed to load PyTorch model: {e}")

    def generate(
        self,
        tokens: list[int],
        voice: torch.Tensor,
        speed: float = 1.0
    ) -> np.ndarray:
        """Generate audio using GPU model.

        Args:
            tokens: Input token IDs
            voice: Voice embedding tensor
            speed: Speed multiplier

        Returns:
            Generated audio samples

        Raises:
            RuntimeError: If generation fails
        """
        if not self.is_loaded:
            raise RuntimeError("Model not loaded")

        try:
            # Check memory and clean up if needed
            if self._check_memory():
                self._clear_memory()

            # Get reference style from voice pack
            ref_s = voice[len(tokens)].clone().to(self._device)
            if ref_s.dim() == 1:
                ref_s = ref_s.unsqueeze(0)  # Add batch dimension if needed

            # Generate audio
            return forward(self._model, tokens, ref_s, speed)

        except Exception as e:
            logger.error(f"Generation failed: {e}")
            raise

    def _check_memory(self) -> bool:
        """Check if memory usage is above threshold."""
        if torch.cuda.is_available():
            memory_gb = torch.cuda.memory_allocated() / 1e9
            return memory_gb > self._config.memory_threshold
        return False

    def _clear_memory(self) -> None:
        """Clear GPU memory."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            gc.collect()
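A small worked example of the length_to_mask helper used by both PyTorch backends; True marks the padded positions past each sequence length:

lengths = torch.LongTensor([3, 5])
length_to_mask(lengths)
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])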
191  api/src/inference/voice_manager.py  Normal file
@@ -0,0 +1,191 @@

"""Voice pack management and caching."""

import os
from typing import Dict, List, Optional, Union

import torch
from loguru import logger
from pydantic import BaseModel

from ..core import paths
from ..core.config import settings
from ..structures.model_schemas import VoiceConfig


class VoiceManager:
    """Manages voice loading, caching, and operations."""

    def __init__(self, config: Optional[VoiceConfig] = None):
        """Initialize voice manager.

        Args:
            config: Optional voice configuration
        """
        self._config = config or VoiceConfig()
        self._voice_cache: Dict[str, torch.Tensor] = {}

    def get_voice_path(self, voice_name: str) -> Optional[str]:
        """Get path to voice file.

        Args:
            voice_name: Name of voice

        Returns:
            Path to voice file if exists, None otherwise
        """
        voice_path = os.path.join(settings.voices_dir, f"{voice_name}.pt")
        return voice_path if os.path.exists(voice_path) else None

    async def load_voice(self, voice_name: str, device: str = "cpu") -> torch.Tensor:
        """Load voice tensor.

        Args:
            voice_name: Name of voice to load
            device: Device to load voice on

        Returns:
            Voice tensor

        Raises:
            RuntimeError: If voice loading fails
        """
        voice_path = self.get_voice_path(voice_name)
        if not voice_path:
            raise RuntimeError(f"Voice not found: {voice_name}")

        # Check cache first
        cache_key = f"{voice_path}_{device}"
        if self._config.use_cache and cache_key in self._voice_cache:
            return self._voice_cache[cache_key]

        try:
            # Load voice tensor
            voice = await paths.load_voice_tensor(voice_path, device=device)

            # Cache if enabled
            if self._config.use_cache:
                self._manage_cache()
                self._voice_cache[cache_key] = voice
                logger.debug(f"Cached voice: {voice_name} on {device}")

            return voice

        except Exception as e:
            raise RuntimeError(f"Failed to load voice {voice_name}: {e}")

    def _manage_cache(self) -> None:
        """Manage voice cache size."""
        if len(self._voice_cache) >= self._config.cache_size:
            # Remove oldest voice
            oldest = next(iter(self._voice_cache))
            del self._voice_cache[oldest]
            logger.debug(f"Removed from voice cache: {oldest}")

    async def combine_voices(self, voices: List[str], device: str = "cpu") -> str:
        """Combine multiple voices into a new voice.

        Args:
            voices: List of voice names to combine
            device: Device to load voices on

        Returns:
            Name of combined voice

        Raises:
            ValueError: If fewer than 2 voices provided
            RuntimeError: If voice combination fails
        """
        if len(voices) < 2:
            raise ValueError("At least 2 voices are required for combination")

        # Load voices
        voice_tensors: List[torch.Tensor] = []
        for voice in voices:
            try:
                voice_tensor = await self.load_voice(voice, device)
                voice_tensors.append(voice_tensor)
            except Exception as e:
                raise RuntimeError(f"Failed to load voice {voice}: {e}")

        try:
            # Combine voices
            combined_name = "_".join(voices)
            combined_tensor = torch.mean(torch.stack(voice_tensors), dim=0)

            # Save combined voice
            combined_path = os.path.join(settings.voices_dir, f"{combined_name}.pt")
            try:
                torch.save(combined_tensor, combined_path)
            except Exception as e:
                raise RuntimeError(f"Failed to save combined voice: {e}")

            return combined_name

        except Exception as e:
            raise RuntimeError(f"Failed to combine voices: {e}")

    async def list_voices(self) -> List[str]:
        """List available voices.

        Returns:
            List of voice names
        """
        voices = []
        try:
            for entry in os.listdir(settings.voices_dir):
                if entry.endswith(".pt"):
                    voices.append(entry[:-3])  # Remove .pt extension
        except Exception as e:
            logger.error(f"Error listing voices: {e}")
        return sorted(voices)

    def validate_voice(self, voice_path: str) -> bool:
        """Validate voice file.

        Args:
            voice_path: Path to voice file

        Returns:
            True if valid, False otherwise
        """
        try:
            if not os.path.exists(voice_path):
                return False

            # Try loading voice
            voice = torch.load(voice_path, map_location="cpu")
            return isinstance(voice, torch.Tensor)

        except Exception:
            return False

    @property
    def cache_info(self) -> Dict[str, int]:
        """Get cache statistics.

        Returns:
            Dictionary with cache info
        """
        return {
            'size': len(self._voice_cache),
            'max_size': self._config.cache_size
        }


# Module-level instance
_manager: Optional[VoiceManager] = None


def get_manager(config: Optional[VoiceConfig] = None) -> VoiceManager:
    """Get or create global voice manager instance.

    Args:
        config: Optional voice configuration

    Returns:
        VoiceManager instance
    """
    global _manager
    if _manager is None:
        _manager = VoiceManager(config)
    return _manager
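Sketch of combine_voices in use: the tensors of the named voices are stacked and averaged, and the blend is saved as "<a>_<b>.pt" under settings.voices_dir. The voice names here are placeholders.

from api.src.inference.voice_manager import VoiceManager  # assumed import path


async def make_blend(manager: VoiceManager) -> str:
    name = await manager.combine_voices(["voice_a", "voice_b"])
    return name  # "voice_a_voice_b"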
@@ -1,175 +0,0 @@

import os
import threading
from abc import ABC, abstractmethod
from typing import List, Tuple

import numpy as np
import torch
from loguru import logger

from ..core.config import settings


class TTSBaseModel(ABC):
    _instance = None
    _lock = threading.Lock()
    _device = None
    VOICES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "voices")

    @classmethod
    async def setup(cls):
        """Initialize model and setup voices"""
        with cls._lock:
            # Set device
            cuda_available = torch.cuda.is_available()
            logger.info(f"CUDA available: {cuda_available}")
            if cuda_available:
                try:
                    # Test CUDA device
                    test_tensor = torch.zeros(1).cuda()
                    logger.info("CUDA test successful")
                    model_path = os.path.join(
                        settings.model_dir, settings.pytorch_model_path
                    )
                    cls._device = "cuda"
                except Exception as e:
                    logger.error(f"CUDA test failed: {e}")
                    cls._device = "cpu"
            else:
                cls._device = "cpu"
                model_path = os.path.join(settings.model_dir, settings.onnx_model_path)
            logger.info(f"Initializing model on {cls._device}")
            logger.info(f"Model dir: {settings.model_dir}")
            logger.info(f"Model path: {model_path}")
            logger.info(f"Files in model dir: {os.listdir(settings.model_dir)}")

            # Initialize model first
            model = cls.initialize(settings.model_dir, model_path=model_path)
            if model is None:
                raise RuntimeError(f"Failed to initialize {cls._device.upper()} model")
            cls._instance = model

            # Setup voices directory
            os.makedirs(cls.VOICES_DIR, exist_ok=True)

            # Copy base voices to local directory
            base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
            if os.path.exists(base_voices_dir):
                for file in os.listdir(base_voices_dir):
                    if file.endswith(".pt"):
                        voice_name = file[:-3]
                        voice_path = os.path.join(cls.VOICES_DIR, file)
                        if not os.path.exists(voice_path):
                            try:
                                logger.info(
                                    f"Copying base voice {voice_name} to voices directory"
                                )
                                base_path = os.path.join(base_voices_dir, file)
                                voicepack = torch.load(
                                    base_path,
                                    map_location=cls._device,
                                    weights_only=True,
                                )
                                torch.save(voicepack, voice_path)
                            except Exception as e:
                                logger.error(
                                    f"Error copying voice {voice_name}: {str(e)}"
                                )

            # Count voices in directory
            voice_count = len(
                [f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")]
            )

            # Now that model and voices are ready, do warmup
            try:
                with open(
                    os.path.join(
                        os.path.dirname(os.path.dirname(__file__)),
                        "core",
                        "don_quixote.txt",
                    )
                ) as f:
                    warmup_text = f.read()
            except Exception as e:
                logger.warning(f"Failed to load warmup text: {e}")
                warmup_text = "This is a warmup text that will be split into chunks for processing."

            # Use warmup service after model is fully initialized
            from .warmup import WarmupService

            warmup = WarmupService()

            # Load and warm up voices
            loaded_voices = warmup.load_voices()
            await warmup.warmup_voices(warmup_text, loaded_voices)

            logger.info("Model warm-up complete")

            # Count voices in directory
            voice_count = len(
                [f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")]
            )
            return voice_count

    @classmethod
    @abstractmethod
    def initialize(cls, model_dir: str, model_path: str = None):
        """Initialize the model"""
        pass

    @classmethod
    @abstractmethod
    def process_text(cls, text: str, language: str) -> Tuple[str, List[int]]:
        """Process text into phonemes and tokens

        Args:
            text: Input text
            language: Language code

        Returns:
            tuple[str, list[int]]: Phonemes and token IDs
        """
        pass

    @classmethod
    @abstractmethod
    def generate_from_text(
        cls, text: str, voicepack: torch.Tensor, language: str, speed: float
    ) -> Tuple[np.ndarray, str]:
        """Generate audio from text

        Args:
            text: Input text
            voicepack: Voice tensor
            language: Language code
            speed: Speed factor

        Returns:
            tuple[np.ndarray, str]: Generated audio samples and phonemes
        """
        pass

    @classmethod
    @abstractmethod
    def generate_from_tokens(
        cls, tokens: List[int], voicepack: torch.Tensor, speed: float
    ) -> np.ndarray:
        """Generate audio from tokens

        Args:
            tokens: Token IDs
            voicepack: Voice tensor
            speed: Speed factor

        Returns:
            np.ndarray: Generated audio samples
        """
        pass

    @classmethod
    def get_device(cls):
        """Get the current device"""
        if cls._device is None:
            raise RuntimeError("Model not initialized. Call setup() first.")
        return cls._device
@@ -1,167 +0,0 @@

import os

import numpy as np
import torch
from loguru import logger
from onnxruntime import (
    ExecutionMode,
    GraphOptimizationLevel,
    InferenceSession,
    SessionOptions,
)

from ..core.config import settings
from .text_processing import phonemize, tokenize
from .tts_base import TTSBaseModel


class TTSCPUModel(TTSBaseModel):
    _instance = None
    _onnx_session = None
    _device = "cpu"

    @classmethod
    def get_instance(cls):
        """Get the model instance"""
        if cls._onnx_session is None:
            raise RuntimeError("ONNX model not initialized. Call initialize() first.")
        return cls._onnx_session

    @classmethod
    def initialize(cls, model_dir: str, model_path: str = None):
        """Initialize ONNX model for CPU inference"""
        if cls._onnx_session is None:
            try:
                # Try loading ONNX model
                onnx_path = os.path.join(model_dir, settings.onnx_model_path)
                if not os.path.exists(onnx_path):
                    logger.error(f"ONNX model not found at {onnx_path}")
                    return None

                logger.info(f"Loading ONNX model from {onnx_path}")

                # Configure ONNX session for optimal performance
                session_options = SessionOptions()

                # Set optimization level
                if settings.onnx_optimization_level == "all":
                    session_options.graph_optimization_level = (
                        GraphOptimizationLevel.ORT_ENABLE_ALL
                    )
                elif settings.onnx_optimization_level == "basic":
                    session_options.graph_optimization_level = (
                        GraphOptimizationLevel.ORT_ENABLE_BASIC
                    )
                else:
                    session_options.graph_optimization_level = (
                        GraphOptimizationLevel.ORT_DISABLE_ALL
                    )

                # Configure threading
                session_options.intra_op_num_threads = settings.onnx_num_threads
                session_options.inter_op_num_threads = settings.onnx_inter_op_threads

                # Set execution mode
                session_options.execution_mode = (
                    ExecutionMode.ORT_PARALLEL
                    if settings.onnx_execution_mode == "parallel"
                    else ExecutionMode.ORT_SEQUENTIAL
                )

                # Enable/disable memory pattern optimization
                session_options.enable_mem_pattern = settings.onnx_memory_pattern

                # Configure CPU provider options
                provider_options = {
                    "CPUExecutionProvider": {
                        "arena_extend_strategy": settings.onnx_arena_extend_strategy,
                        "cpu_memory_arena_cfg": "cpu:0",
                    }
                }

                session = InferenceSession(
                    onnx_path,
                    sess_options=session_options,
                    providers=["CPUExecutionProvider"],
                    provider_options=[provider_options],
                )
                cls._onnx_session = session
                return session
            except Exception as e:
                logger.error(f"Failed to initialize ONNX model: {e}")
                return None
        return cls._onnx_session

    @classmethod
    def process_text(cls, text: str, language: str) -> tuple[str, list[int]]:
        """Process text into phonemes and tokens

        Args:
            text: Input text
            language: Language code

        Returns:
            tuple[str, list[int]]: Phonemes and token IDs
        """
        phonemes = phonemize(text, language)
        tokens = tokenize(phonemes)
        tokens = [0] + tokens + [0]  # Add start/end tokens
        return phonemes, tokens

    @classmethod
    def generate_from_text(
        cls, text: str, voicepack: torch.Tensor, language: str, speed: float
    ) -> tuple[np.ndarray, str]:
        """Generate audio from text

        Args:
            text: Input text
            voicepack: Voice tensor
            language: Language code
            speed: Speed factor

        Returns:
            tuple[np.ndarray, str]: Generated audio samples and phonemes
        """
        if cls._onnx_session is None:
            raise RuntimeError("ONNX model not initialized")

        # Process text
        phonemes, tokens = cls.process_text(text, language)

        # Generate audio
        audio = cls.generate_from_tokens(tokens, voicepack, speed)

        return audio, phonemes

    @classmethod
    def generate_from_tokens(
        cls, tokens: list[int], voicepack: torch.Tensor, speed: float
    ) -> np.ndarray:
        """Generate audio from tokens

        Args:
            tokens: Token IDs
            voicepack: Voice tensor
            speed: Speed factor

        Returns:
            np.ndarray: Generated audio samples
        """
        if cls._onnx_session is None:
            raise RuntimeError("ONNX model not initialized")

        # Pre-allocate and prepare inputs
        tokens_input = np.array([tokens], dtype=np.int64)
        style_input = voicepack[
            len(tokens) - 2
        ].numpy()  # Already has correct dimensions
        speed_input = np.full(
            1, speed, dtype=np.float32
        )  # More efficient than ones * speed

        # Run inference with optimized inputs
        result = cls._onnx_session.run(
            None, {"tokens": tokens_input, "style": style_input, "speed": speed_input}
        )
        return result[0]
@@ -1,262 +0,0 @@
import os
import time

import numpy as np
import torch
from ..builds.models import build_model
from loguru import logger

from ..core.config import settings
from .text_processing import phonemize, tokenize
from .tts_base import TTSBaseModel


# @torch.no_grad()
# def forward(model, tokens, ref_s, speed):
#     """Forward pass through the model"""
#     device = ref_s.device
#     tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
#     input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
#     text_mask = length_to_mask(input_lengths).to(device)
#     bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
#     d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
#     s = ref_s[:, 128:]
#     d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
#     x, _ = model.predictor.lstm(d)
#     duration = model.predictor.duration_proj(x)
#     duration = torch.sigmoid(duration).sum(axis=-1) / speed
#     pred_dur = torch.round(duration).clamp(min=1).long()
#     pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
#     c_frame = 0
#     for i in range(pred_aln_trg.size(0)):
#         pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
#         c_frame += pred_dur[0, i].item()
#     en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
#     F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
#     t_en = model.text_encoder(tokens, input_lengths, text_mask)
#     asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
#     return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
@torch.no_grad()
def forward(model, tokens, ref_s, speed):
    """Forward pass through the model with moderate memory management"""
    device = ref_s.device

    try:
        # Initial tensor setup with proper device placement
        tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
        text_mask = length_to_mask(input_lengths).to(device)

        # Split and clone reference signals with explicit device placement
        s_content = ref_s[:, 128:].clone().to(device)
        s_ref = ref_s[:, :128].clone().to(device)

        # BERT and encoder pass
        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

        # Predictor forward pass
        d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
        x, _ = model.predictor.lstm(d)

        # Duration prediction
        duration = model.predictor.duration_proj(x)
        duration = torch.sigmoid(duration).sum(axis=-1) / speed
        pred_dur = torch.round(duration).clamp(min=1).long()
        # Only cleanup large intermediates
        del duration, x

        # Alignment matrix construction
        pred_aln_trg = torch.zeros(input_lengths.item(), pred_dur.sum().item(), device=device)
        c_frame = 0
        for i in range(pred_aln_trg.size(0)):
            pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
            c_frame += pred_dur[0, i].item()
        pred_aln_trg = pred_aln_trg.unsqueeze(0)

        # Matrix multiplications with selective cleanup
        en = d.transpose(-1, -2) @ pred_aln_trg
        del d  # Free large intermediate tensor

        F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
        del en  # Free large intermediate tensor

        # Final text encoding and decoding
        t_en = model.text_encoder(tokens, input_lengths, text_mask)
        asr = t_en @ pred_aln_trg
        del t_en  # Free large intermediate tensor

        # Final decoding and transfer to CPU
        output = model.decoder(asr, F0_pred, N_pred, s_ref)
        result = output.squeeze().cpu().numpy()

        return result

    finally:
        # Let PyTorch handle most cleanup automatically
        # Only explicitly free the largest tensors
        del pred_aln_trg, asr


# def length_to_mask(lengths):
#     """Create attention mask from lengths"""
#     mask = (
#         torch.arange(lengths.max())
#         .unsqueeze(0)
#         .expand(lengths.shape[0], -1)
#         .type_as(lengths)
#     )
#     mask = torch.gt(mask + 1, lengths.unsqueeze(1))
#     return mask


def length_to_mask(lengths):
    """Create attention mask from lengths - possibly optimized version"""
    max_len = lengths.max()
    # Create mask directly on the same device as lengths
    mask = torch.arange(max_len, device=lengths.device)[None, :].expand(
        lengths.shape[0], -1
    )
    # Avoid type_as by using the correct dtype from the start
    if lengths.dtype != mask.dtype:
        mask = mask.to(dtype=lengths.dtype)
    # Fuse operations using broadcasting
    return mask + 1 > lengths[:, None]

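A quick illustration of what length_to_mask returns, assuming it is called from the same module with a toy batch of lengths; True marks padded positions.

# Toy demonstration (not part of the diff)
import torch

lengths = torch.tensor([3, 5])
mask = length_to_mask(lengths)
print(mask)
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])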
class TTSGPUModel(TTSBaseModel):
    _instance = None
    _device = "cuda"

    @classmethod
    def get_instance(cls):
        """Get the model instance"""
        if cls._instance is None:
            raise RuntimeError("GPU model not initialized. Call initialize() first.")
        return cls._instance

    @classmethod
    def initialize(cls, model_dir: str, model_path: str):
        """Initialize PyTorch model for GPU inference"""
        if cls._instance is None and torch.cuda.is_available():
            try:
                logger.info("Initializing GPU model")
                model_path = os.path.join(model_dir, settings.pytorch_model_path)
                model = build_model(model_path, cls._device)
                cls._instance = model
                return model
            except Exception as e:
                logger.error(f"Failed to initialize GPU model: {e}")
                return None
        return cls._instance

    @classmethod
    def process_text(cls, text: str, language: str) -> tuple[str, list[int]]:
        """Process text into phonemes and tokens

        Args:
            text: Input text
            language: Language code

        Returns:
            tuple[str, list[int]]: Phonemes and token IDs
        """
        phonemes = phonemize(text, language)
        tokens = tokenize(phonemes)
        return phonemes, tokens

    @classmethod
    def generate_from_text(
        cls, text: str, voicepack: torch.Tensor, language: str, speed: float
    ) -> tuple[np.ndarray, str]:
        """Generate audio from text

        Args:
            text: Input text
            voicepack: Voice tensor
            language: Language code
            speed: Speed factor

        Returns:
            tuple[np.ndarray, str]: Generated audio samples and phonemes
        """
        if cls._instance is None:
            raise RuntimeError("GPU model not initialized")

        # Process text
        phonemes, tokens = cls.process_text(text, language)

        # Generate audio
        audio = cls.generate_from_tokens(tokens, voicepack, speed)

        return audio, phonemes

    @classmethod
    def generate_from_tokens(
        cls, tokens: list[int], voicepack: torch.Tensor, speed: float
    ) -> np.ndarray:
        """Generate audio from tokens with moderate memory management

        Args:
            tokens: Token IDs
            voicepack: Voice tensor
            speed: Speed factor

        Returns:
            np.ndarray: Generated audio samples
        """
        if cls._instance is None:
            raise RuntimeError("GPU model not initialized")

        try:
            device = cls._device

            # Check memory pressure
            if torch.cuda.is_available():
                memory_allocated = torch.cuda.memory_allocated(device) / 1e9  # Convert to GB
                if memory_allocated > 2.0:  # 2GB limit
                    logger.info(
                        f"Memory usage above 2GB threshold: {memory_allocated:.2f}GB. "
                        f"Clearing cache"
                    )
                    torch.cuda.empty_cache()
                    import gc
                    gc.collect()

            # Get reference style with proper device placement
            ref_s = voicepack[len(tokens)].clone().to(device)

            # Generate audio
            audio = forward(cls._instance, tokens, ref_s, speed)

            return audio

        except RuntimeError as e:
            if "out of memory" in str(e):
                # On OOM, do a full cleanup and retry
                if torch.cuda.is_available():
                    logger.warning("Out of memory detected, performing full cleanup")
                    torch.cuda.synchronize()
                    torch.cuda.empty_cache()
                    import gc
                    gc.collect()

                    # Log memory stats after cleanup
                    memory_allocated = torch.cuda.memory_allocated(device)
                    memory_reserved = torch.cuda.memory_reserved(device)
                    logger.info(
                        f"Memory after OOM cleanup: "
                        f"Allocated: {memory_allocated / 1e9:.2f}GB, "
                        f"Reserved: {memory_reserved / 1e9:.2f}GB"
                    )

                    # Retry generation
                    ref_s = voicepack[len(tokens)].clone().to(device)
                    audio = forward(cls._instance, tokens, ref_s, speed)
                    return audio
            raise

        finally:
            # Only synchronize at the top level, no empty_cache
            if torch.cuda.is_available():
                torch.cuda.synchronize()

@@ -1,8 +0,0 @@
import torch

if torch.cuda.is_available():
    from .tts_gpu import TTSGPUModel as TTSModel
else:
    from .tts_cpu import TTSCPUModel as TTSModel

__all__ = ["TTSModel"]

@@ -1,120 +1,114 @@
"""TTS service using model and voice managers."""

import io
import os
import re
import time
from functools import lru_cache
from typing import List, Optional, Tuple
from typing import List, Tuple

import torch

import aiofiles.os
import numpy as np
import scipy.io.wavfile as wavfile
import torch
from loguru import logger

from ..core.config import settings
from ..inference.model_manager import get_manager as get_model_manager
from ..inference.voice_manager import get_manager as get_voice_manager
from .audio import AudioNormalizer, AudioService
from .text_processing import chunker, normalize_text
from .tts_model import TTSModel


class TTSService:
    """Text-to-speech service."""

    def __init__(self, output_dir: str = None):
        """Initialize service.

        Args:
            output_dir: Optional output directory for saving audio
        """
        self.output_dir = output_dir
        self.model = TTSModel.get_instance()
        self.model_manager = get_model_manager()
        self.voice_manager = get_voice_manager()
        self._initialized = False
        self._initialization_error = None

    @staticmethod
    @lru_cache(maxsize=3)  # Cache up to 3 most recently used voices
    def _load_voice(voice_path: str) -> torch.Tensor:
        """Load and cache a voice model"""
        return torch.load(
            voice_path, map_location=TTSModel.get_device(), weights_only=True
        )

    async def ensure_initialized(self):
        """Ensure model is initialized."""
        if self._initialized:
            return
        if self._initialization_error:
            raise self._initialization_error

    def _get_voice_path(self, voice_name: str) -> Optional[str]:
        """Get the path to a voice file"""
        voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice_name}.pt")
        return voice_path if os.path.exists(voice_path) else None

        try:
            # Determine model path based on hardware
            if settings.use_gpu and torch.cuda.is_available():
                model_path = os.path.join(settings.model_dir, settings.pytorch_model_path)
                backend_type = 'pytorch_gpu'
            else:
                model_path = os.path.join(settings.model_dir, settings.onnx_model_path)
                backend_type = 'onnx_cpu'

            # Initialize model
            await self.model_manager.load_model(model_path, backend_type)
            logger.info(f"Initialized model on {backend_type} backend")
            self._initialized = True

        except Exception as e:
            logger.error(f"Failed to initialize model: {e}")
            self._initialization_error = RuntimeError(f"Model initialization failed: {e}")
            raise self._initialization_error

    def _generate_audio(
        self, text: str, voice: str, speed: float, stitch_long_output: bool = True
    ) -> Tuple[torch.Tensor, float]:
        """Generate complete audio and return with processing time"""
        audio, processing_time = self._generate_audio_internal(
            text, voice, speed, stitch_long_output
        )
        return audio, processing_time

    def _generate_audio_internal(
        self, text: str, voice: str, speed: float, stitch_long_output: bool = True
    ) -> Tuple[torch.Tensor, float]:
        """Generate audio and measure processing time"""
    async def generate_audio(
        self, text: str, voice: str, speed: float = 1.0
    ) -> Tuple[np.ndarray, float]:
        """Generate audio for text.

        Args:
            text: Input text
            voice: Voice name
            speed: Speed multiplier

        Returns:
            Audio samples and processing time
        """
        await self.ensure_initialized()
        start_time = time.time()

        try:
            # Normalize text once at the start
            if not text:
                raise ValueError("Text is empty after preprocessing")
            # Normalize text
            normalized = normalize_text(text)
            if not normalized:
                raise ValueError("Text is empty after preprocessing")
            text = str(normalized)

            # Check voice exists
            voice_path = self._get_voice_path(voice)
            if not voice_path:
                raise ValueError(f"Voice not found: {voice}")

            # Load voice using cached loader
            voicepack = self._load_voice(voice_path)

            # For non-streaming, preprocess all chunks first
            if stitch_long_output:
                # Preprocess all chunks to phonemes/tokens
                chunks_data = []
                for chunk in chunker.split_text(text):
                    try:
                        phonemes, tokens = TTSModel.process_text(chunk, voice[0])
                        chunks_data.append((chunk, tokens))
                    except Exception as e:
                        logger.error(
                            f"Failed to process chunk: '{chunk}'. Error: {str(e)}"
                        )
            # Process text into chunks
            audio_chunks = []
            for chunk in chunker.split_text(text):
                try:
                    # Process text

                    sequences = process_text(chunk)
                    if not sequences:
                        continue

                if not chunks_data:
                    raise ValueError("No chunks were processed successfully")
                    # Generate audio
                    chunk_audio = await self.model_manager.generate(
                        sequences[0],
                        voice,
                        speed=speed
                    )
                    if chunk_audio is not None:
                        audio_chunks.append(chunk_audio)
                except Exception as e:
                    logger.error(f"Failed to process chunk: '{chunk}'. Error: {str(e)}")
                    continue

                # Generate audio for all chunks
                audio_chunks = []
                for chunk, tokens in chunks_data:
                    try:
                        chunk_audio = TTSModel.generate_from_tokens(
                            tokens, voicepack, speed
                        )
                        if chunk_audio is not None:
                            audio_chunks.append(chunk_audio)
                        else:
                            logger.error(f"No audio generated for chunk: '{chunk}'")
                    except Exception as e:
                        logger.error(
                            f"Failed to generate audio for chunk: '{chunk}'. Error: {str(e)}"
                        )
                        continue

                if not audio_chunks:
                    raise ValueError("No audio chunks were generated successfully")

                # Concatenate all chunks
                audio = (
                    np.concatenate(audio_chunks)
                    if len(audio_chunks) > 1
                    else audio_chunks[0]
                )
            else:
                # Process single chunk
                phonemes, tokens = TTSModel.process_text(text, voice[0])
                audio = TTSModel.generate_from_tokens(tokens, voicepack, speed)
            if not audio_chunks:
                raise ValueError("No audio chunks were generated successfully")

            # Combine chunks
            audio = np.concatenate(audio_chunks) if len(audio_chunks) > 1 else audio_chunks[0]
            processing_time = time.time() - start_time
            return audio, processing_time

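A minimal usage sketch of the refactored async generate_audio API. The import path, voice name, and text are assumptions; actual initialization depends on the model and voice managers' configuration.

# Hypothetical usage sketch (paths and voice name are placeholders)
import asyncio

from api.src.services.tts_service import TTSService  # assumed import path


async def main():
    service = TTSService()
    audio, seconds = await service.generate_audio(
        "Hello from the new backend.", voice="af", speed=1.0
    )
    print(f"Generated {len(audio)} samples in {seconds:.2f}s")


asyncio.run(main())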
@@ -126,144 +120,103 @@ class TTSService:
        self,
        text: str,
        voice: str,
        speed: float,
        speed: float = 1.0,
        output_format: str = "wav",
        silent=False,
    ):
        """Generate and yield audio chunks as they're generated for real-time streaming"""
        try:
            stream_start = time.time()
            # Create normalizer for consistent audio levels
            stream_normalizer = AudioNormalizer()
        """Generate and stream audio chunks.

        Args:
            text: Input text
            voice: Voice name
            speed: Speed multiplier
            output_format: Output audio format

        Yields:
            Audio chunks as bytes
        """
        await self.ensure_initialized()

        # Input validation and preprocessing
        if not text:
            raise ValueError("Text is empty")
            preprocess_start = time.time()
        try:
            # Setup audio processing
            stream_normalizer = AudioNormalizer()

            # Normalize text
            normalized = normalize_text(text)
            if not normalized:
                raise ValueError("Text is empty after preprocessing")
            text = str(normalized)
            logger.debug(
                f"Text preprocessing took: {(time.time() - preprocess_start)*1000:.1f}ms"
            )

            # Voice validation and loading
            voice_start = time.time()
            voice_path = self._get_voice_path(voice)
            if not voice_path:
                raise ValueError(f"Voice not found: {voice}")
            voicepack = self._load_voice(voice_path)
            logger.debug(
                f"Voice loading took: {(time.time() - voice_start)*1000:.1f}ms"
            )

            # Process chunks as they're generated
            # Process chunks
            is_first = True
            chunks_processed = 0

            # Process chunks as they come from generator
            chunk_gen = chunker.split_text(text)
            current_chunk = next(chunk_gen, None)

            while current_chunk is not None:
                next_chunk = next(chunk_gen, None)  # Peek at next chunk
                chunks_processed += 1
                next_chunk = next(chunk_gen, None)
                try:
                    # Process text and generate audio
                    phonemes, tokens = TTSModel.process_text(current_chunk, voice[0])
                    chunk_audio = TTSModel.generate_from_tokens(
                        tokens, voicepack, speed
                    )

                    if chunk_audio is not None:
                        # Convert chunk with proper streaming header handling
                        chunk_bytes = AudioService.convert_audio(
                            chunk_audio,
                            24000,
                            output_format,
                            is_first_chunk=is_first,
                            normalizer=stream_normalizer,
                            is_last_chunk=(next_chunk is None),  # Last if no next chunk
                            stream=True  # Ensure proper streaming format handling
                    # Process text
                    from ..text_processing import process_text
                    sequences = process_text(current_chunk)
                    if sequences:
                        # Generate audio
                        chunk_audio = await self.model_manager.generate(
                            sequences[0],
                            voice,
                            speed=speed
                        )

                        yield chunk_bytes
                        is_first = False
                    else:
                        logger.error(f"No audio generated for chunk: '{current_chunk}'")
                        if chunk_audio is not None:
                            # Convert to bytes
                            chunk_bytes = AudioService.convert_audio(
                                chunk_audio,
                                24000,
                                output_format,
                                is_first_chunk=is_first,
                                normalizer=stream_normalizer,
                                is_last_chunk=(next_chunk is None),
                                stream=True
                            )
                            yield chunk_bytes
                            is_first = False

                except Exception as e:
                    logger.error(
                        f"Failed to generate audio for chunk: '{current_chunk}'. Error: {str(e)}"
                    )
                    logger.error(f"Failed to generate audio for chunk: '{current_chunk}'. Error: {str(e)}")

                current_chunk = next_chunk  # Move to next chunk
                current_chunk = next_chunk

        except Exception as e:
            logger.error(f"Error in audio generation stream: {str(e)}")
            raise

    def _save_audio(self, audio: torch.Tensor, filepath: str):
        """Save audio to file"""
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        wavfile.write(filepath, 24000, audio)

    def _audio_to_bytes(self, audio: torch.Tensor) -> bytes:
        """Convert audio tensor to WAV bytes"""
        buffer = io.BytesIO()
        wavfile.write(buffer, 24000, audio)
        return buffer.getvalue()

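A sketch of how a caller might consume the streaming generator above. The method name generate_audio_stream is assumed from the surrounding class (the def line itself is outside this hunk), and the output path and text are placeholders.

# Hypothetical consumer sketch (method name and arguments assumed)
async def save_stream(service, path="out.wav"):
    with open(path, "wb") as f:
        async for chunk in service.generate_audio_stream(
            "Streaming sentence by sentence.", voice="af", output_format="wav"
        ):
            f.write(chunk)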
    async def combine_voices(self, voices: List[str]) -> str:
        """Combine multiple voices into a new voice"""
        if len(voices) < 2:
            raise ValueError("At least 2 voices are required for combination")

        # Load voices
        t_voices: List[torch.Tensor] = []
        v_name: List[str] = []

        for voice in voices:
            try:
                voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice}.pt")
                voicepack = torch.load(
                    voice_path, map_location=TTSModel.get_device(), weights_only=True
                )
                t_voices.append(voicepack)
                v_name.append(voice)
            except Exception as e:
                raise ValueError(f"Failed to load voice {voice}: {str(e)}")

        # Combine voices
        try:
            f: str = "_".join(v_name)
            v = torch.mean(torch.stack(t_voices), dim=0)
            combined_path = os.path.join(TTSModel.VOICES_DIR, f"{f}.pt")

            # Save combined voice
            try:
                torch.save(v, combined_path)
            except Exception as e:
                raise RuntimeError(
                    f"Failed to save combined voice to {combined_path}: {str(e)}"
                )

            return f

        except Exception as e:
            if not isinstance(e, (ValueError, RuntimeError)):
                raise RuntimeError(f"Error combining voices: {str(e)}")
            raise
        """Combine multiple voices.

        Args:
            voices: List of voice names

        Returns:
            Name of combined voice
        """
        await self.ensure_initialized()
        return await self.voice_manager.combine_voices(voices)

    async def list_voices(self) -> List[str]:
        """List all available voices"""
        voices = []
        try:
            it = await aiofiles.os.scandir(TTSModel.VOICES_DIR)
            for entry in it:
                if entry.name.endswith(".pt"):
                    voices.append(entry.name[:-3])  # Remove .pt extension
        except Exception as e:
            logger.error(f"Error listing voices: {str(e)}")
        return sorted(voices)
        """List available voices.

        Returns:
            List of voice names
        """
        return await self.voice_manager.list_voices()

    def _audio_to_bytes(self, audio: np.ndarray) -> bytes:
        """Convert audio to WAV bytes.

        Args:
            audio: Audio samples

        Returns:
            WAV bytes
        """
        buffer = io.BytesIO()
        wavfile.write(buffer, 24000, audio)
        return buffer.getvalue()
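A short sketch of the new manager-backed voice API from a caller's perspective; the voice names are placeholders and the combined voice name is whatever combine_voices returns.

# Hypothetical caller sketch (voice names are placeholders)
async def demo(service):
    available = await service.list_voices()
    print(available)
    combined = await service.combine_voices(["voice_a", "voice_b"])
    audio, _ = await service.generate_audio("Blended voice test.", voice=combined)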
26
api/src/structures/model_schemas.py
Normal file
@@ -0,0 +1,26 @@
"""Model and voice configuration schemas."""

from pydantic import BaseModel


class ModelConfig(BaseModel):
    """Model configuration."""
    optimization_level: str = "all"  # all, basic, none
    num_threads: int = 4
    inter_op_threads: int = 4
    execution_mode: str = "parallel"  # parallel, sequential
    memory_pattern: bool = True
    arena_extend_strategy: str = "kNextPowerOfTwo"

    class Config:
        frozen = True  # Make config immutable


class VoiceConfig(BaseModel):
    """Voice configuration."""
    use_cache: bool = True
    cache_size: int = 3  # Number of voices to cache
    validate_on_load: bool = True  # Whether to validate voices when loading

    class Config:
        frozen = True  # Make config immutable
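For context, a sketch of how ModelConfig's ONNX-oriented fields could be mapped onto onnxruntime session options. This mapping is not part of the diff; the import path and model path are assumptions.

# Sketch (assumed import path and model path; mapping not from this commit)
import onnxruntime as ort

from api.src.structures.model_schemas import ModelConfig  # assumed path

config = ModelConfig(num_threads=8, execution_mode="sequential")

opts = ort.SessionOptions()
opts.intra_op_num_threads = config.num_threads
opts.inter_op_num_threads = config.inter_op_threads
opts.enable_mem_pattern = config.memory_pattern
opts.execution_mode = (
    ort.ExecutionMode.ORT_PARALLEL
    if config.execution_mode == "parallel"
    else ort.ExecutionMode.ORT_SEQUENTIAL
)
opts.graph_optimization_level = (
    ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    if config.optimization_level == "all"
    else ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
)

session = ort.InferenceSession(
    "models/kokoro.onnx",  # placeholder path
    sess_options=opts,
    providers=[("CPUExecutionProvider", {"arena_extend_strategy": config.arena_extend_strategy})],
)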