Mirror of https://github.com/remsky/Kokoro-FastAPI.git, synced 2025-04-13 09:39:17 +00:00

Refactor inference architecture: remove legacy TTS model, add ONNX and PyTorch backends, and introduce model configuration schemas

parent 83c55ca735
commit ab28a62e86

16 changed files with 1606 additions and 813 deletions
.gitignore (vendored, 3 changes)

@@ -18,7 +18,8 @@ __pycache__/
 *.egg
 dist/
 build/
+*.onnx
+*.pth
 # Environment
 # .env
 .venv/
api/src/core/paths.py (new file, 198 lines)

@@ -0,0 +1,198 @@
"""Async file and path operations."""

import io
import os
from pathlib import Path
from typing import List, Optional, AsyncIterator, Callable, Set

import aiofiles
import aiofiles.os
import torch
from loguru import logger

from .config import settings


async def _find_file(
    filename: str,
    search_paths: List[str],
    filter_fn: Optional[Callable[[str], bool]] = None
) -> str:
    """Find file in search paths.

    Args:
        filename: Name of file to find
        search_paths: List of paths to search in
        filter_fn: Optional function to filter files

    Returns:
        Absolute path to file

    Raises:
        RuntimeError: If file not found
    """
    if os.path.isabs(filename) and await aiofiles.os.path.exists(filename):
        return filename

    for path in search_paths:
        full_path = os.path.join(path, filename)
        if await aiofiles.os.path.exists(full_path):
            if filter_fn is None or filter_fn(full_path):
                return full_path

    raise RuntimeError(f"File not found: {filename} in paths: {search_paths}")


async def _scan_directories(
    search_paths: List[str],
    filter_fn: Optional[Callable[[str], bool]] = None
) -> Set[str]:
    """Scan directories for files.

    Args:
        search_paths: List of paths to scan
        filter_fn: Optional function to filter files

    Returns:
        Set of matching filenames
    """
    results = set()

    for path in search_paths:
        if not await aiofiles.os.path.exists(path):
            continue

        try:
            # Get directory entries first
            entries = await aiofiles.os.scandir(path)
            # Then process entries after await completes
            for entry in entries:
                if filter_fn is None or filter_fn(entry.name):
                    results.add(entry.name)
        except Exception as e:
            logger.warning(f"Error scanning {path}: {e}")

    return results


async def get_model_path(model_name: str) -> str:
    """Get path to model file.

    Args:
        model_name: Name of model file

    Returns:
        Absolute path to model file

    Raises:
        RuntimeError: If model not found
    """
    search_paths = [
        settings.model_dir,
        os.path.join(os.path.dirname(__file__), "..", "..", "..", "models")
    ]

    return await _find_file(model_name, search_paths)


async def get_voice_path(voice_name: str) -> str:
    """Get path to voice file.

    Args:
        voice_name: Name of voice file (without .pt extension)

    Returns:
        Absolute path to voice file

    Raises:
        RuntimeError: If voice not found
    """
    voice_file = f"{voice_name}.pt"

    search_paths = [
        os.path.join(settings.model_dir, "..", settings.voices_dir),
        os.path.join(os.path.dirname(__file__), "..", settings.voices_dir)
    ]

    return await _find_file(voice_file, search_paths)


async def list_voices() -> List[str]:
    """List available voice files.

    Returns:
        List of voice names (without .pt extension)
    """
    search_paths = [
        os.path.join(settings.model_dir, "..", settings.voices_dir),
        os.path.join(os.path.dirname(__file__), "..", settings.voices_dir)
    ]

    def filter_voice_files(name: str) -> bool:
        return name.endswith('.pt')

    voices = await _scan_directories(search_paths, filter_voice_files)
    return sorted([name[:-3] for name in voices])  # Remove .pt extension


async def load_voice_tensor(voice_path: str, device: str = "cpu") -> torch.Tensor:
    """Load voice tensor from file.

    Args:
        voice_path: Path to voice file
        device: Device to load tensor to

    Returns:
        Voice tensor

    Raises:
        RuntimeError: If file cannot be read
    """
    try:
        async with aiofiles.open(voice_path, 'rb') as f:
            data = await f.read()
            return torch.load(
                io.BytesIO(data),
                map_location=device,
                weights_only=True
            )
    except Exception as e:
        raise RuntimeError(f"Failed to load voice tensor from {voice_path}: {e}")


async def save_voice_tensor(tensor: torch.Tensor, voice_path: str) -> None:
    """Save voice tensor to file.

    Args:
        tensor: Voice tensor to save
        voice_path: Path to save voice file

    Raises:
        RuntimeError: If file cannot be written
    """
    try:
        buffer = io.BytesIO()
        torch.save(tensor, buffer)
        async with aiofiles.open(voice_path, 'wb') as f:
            await f.write(buffer.getvalue())
    except Exception as e:
        raise RuntimeError(f"Failed to save voice tensor to {voice_path}: {e}")


async def read_file(path: str) -> str:
    """Read text file asynchronously.

    Args:
        path: Path to file

    Returns:
        File contents as string

    Raises:
        RuntimeError: If file cannot be read
    """
    try:
        async with aiofiles.open(path, 'r', encoding='utf-8') as f:
            return await f.read()
    except Exception as e:
        raise RuntimeError(f"Failed to read file {path}: {e}")
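As a rough usage sketch only (not part of the commit): the helpers above are coroutines, so a caller awaits them from an event loop. The model file name, import path, and voice index below are illustrative assumptions.

# Hypothetical call site for api/src/core/paths.py; file names and the
# "api.src.core" import path are assumptions for illustration.
import asyncio
from api.src.core import paths

async def main():
    model_path = await paths.get_model_path("kokoro-v0_19.onnx")  # resolved against settings.model_dir
    voices = await paths.list_voices()                            # names come back without the .pt extension
    voice = await paths.load_voice_tensor(
        await paths.get_voice_path(voices[0]), device="cpu"
    )
    print(model_path, len(voices), voice.shape)

asyncio.run(main())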
api/src/inference/__init__.py (new file, 20 lines)

@@ -0,0 +1,20 @@
"""Inference backends and model management."""

from .base import BaseModelBackend
from .model_manager import ModelManager, get_manager
from .onnx_cpu import ONNXCPUBackend
from .onnx_gpu import ONNXGPUBackend
from .pytorch_cpu import PyTorchCPUBackend
from .pytorch_gpu import PyTorchGPUBackend
from ..structures.model_schemas import ModelConfig

__all__ = [
    'BaseModelBackend',
    'ModelManager',
    'get_manager',
    'ModelConfig',
    'ONNXCPUBackend',
    'ONNXGPUBackend',
    'PyTorchCPUBackend',
    'PyTorchGPUBackend'
]
api/src/inference/base.py (new file, 97 lines)

@@ -0,0 +1,97 @@
"""Base interfaces for model inference."""

from abc import ABC, abstractmethod
from typing import List, Optional

import numpy as np
import torch


class ModelBackend(ABC):
    """Abstract base class for model inference backends."""

    @abstractmethod
    async def load_model(self, path: str) -> None:
        """Load model from path.

        Args:
            path: Path to model file

        Raises:
            RuntimeError: If model loading fails
        """
        pass

    @abstractmethod
    def generate(
        self,
        tokens: List[int],
        voice: torch.Tensor,
        speed: float = 1.0
    ) -> np.ndarray:
        """Generate audio from tokens.

        Args:
            tokens: Input token IDs
            voice: Voice embedding tensor
            speed: Speed multiplier

        Returns:
            Generated audio samples

        Raises:
            RuntimeError: If generation fails
        """
        pass

    @abstractmethod
    def unload(self) -> None:
        """Unload model and free resources."""
        pass

    @property
    @abstractmethod
    def is_loaded(self) -> bool:
        """Check if model is loaded.

        Returns:
            True if model is loaded, False otherwise
        """
        pass

    @property
    @abstractmethod
    def device(self) -> str:
        """Get device model is running on.

        Returns:
            Device string ('cpu' or 'cuda')
        """
        pass


class BaseModelBackend(ModelBackend):
    """Base implementation of model backend."""

    def __init__(self):
        """Initialize base backend."""
        self._model: Optional[torch.nn.Module] = None
        self._device: str = "cpu"

    @property
    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        return self._model is not None

    @property
    def device(self) -> str:
        """Get device model is running on."""
        return self._device

    def unload(self) -> None:
        """Unload model and free resources."""
        if self._model is not None:
            del self._model
            self._model = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
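For orientation, a concrete backend only has to provide load_model and generate; is_loaded, device, and unload come from BaseModelBackend. The sketch below is a hypothetical "EchoBackend" used purely to show the required surface, not a backend in this commit.

# Minimal sketch of a custom backend on top of BaseModelBackend; the class and
# the "api.src.inference.base" import path are illustrative assumptions.
import numpy as np
import torch
from api.src.inference.base import BaseModelBackend

class EchoBackend(BaseModelBackend):
    async def load_model(self, path: str) -> None:
        # A real backend would deserialize weights from `path` here.
        self._model = torch.nn.Identity()

    def generate(self, tokens, voice, speed=1.0):
        # Return silence proportional to the token count instead of real audio.
        return np.zeros(len(tokens) * 240, dtype=np.float32)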
api/src/inference/model_manager.py (new file, 251 lines)

@@ -0,0 +1,251 @@
"""Model management and caching."""

import os
from typing import Dict, List, Optional, Union

import torch
from loguru import logger
from pydantic import BaseModel

from .base import BaseModelBackend
from .voice_manager import get_manager as get_voice_manager
from .onnx_cpu import ONNXCPUBackend
from .onnx_gpu import ONNXGPUBackend
from .pytorch_cpu import PyTorchCPUBackend
from .pytorch_gpu import PyTorchGPUBackend
from ..core import paths
from ..core.config import settings
from ..structures.model_schemas import ModelConfig


class ModelManager:
    """Manages model loading and inference across backends."""

    def __init__(self, config: Optional[ModelConfig] = None):
        """Initialize model manager.

        Args:
            config: Optional configuration
        """
        self._config = config or ModelConfig()
        self._backends: Dict[str, BaseModelBackend] = {}
        self._current_backend: Optional[str] = None
        self._voice_manager = get_voice_manager()
        self._initialize_backends()

    def _initialize_backends(self) -> None:
        """Initialize available backends."""
        # Initialize GPU backends if available
        if settings.use_gpu and torch.cuda.is_available():
            try:
                # PyTorch GPU
                self._backends['pytorch_gpu'] = PyTorchGPUBackend()
                self._current_backend = 'pytorch_gpu'
                logger.info("Initialized PyTorch GPU backend")

                # ONNX GPU
                self._backends['onnx_gpu'] = ONNXGPUBackend()
                logger.info("Initialized ONNX GPU backend")
            except Exception as e:
                logger.error(f"Failed to initialize GPU backends: {e}")
                # Fallback to CPU if GPU fails
                self._initialize_cpu_backends()
        else:
            self._initialize_cpu_backends()

    def _initialize_cpu_backends(self) -> None:
        """Initialize CPU backends."""
        try:
            # PyTorch CPU
            self._backends['pytorch_cpu'] = PyTorchCPUBackend()
            self._current_backend = 'pytorch_cpu'
            logger.info("Initialized PyTorch CPU backend")

            # ONNX CPU
            self._backends['onnx_cpu'] = ONNXCPUBackend()
            logger.info("Initialized ONNX CPU backend")
        except Exception as e:
            logger.error(f"Failed to initialize CPU backends: {e}")
            raise RuntimeError("No backends available")

    def get_backend(self, backend_type: Optional[str] = None) -> BaseModelBackend:
        """Get specified backend.

        Args:
            backend_type: Backend type ('pytorch_cpu', 'pytorch_gpu', 'onnx_cpu', 'onnx_gpu'),
                uses default if None

        Returns:
            Model backend instance

        Raises:
            ValueError: If backend type is invalid
            RuntimeError: If no backends are available
        """
        if not self._backends:
            raise RuntimeError("No backends available")

        if backend_type is None:
            backend_type = self._current_backend

        if backend_type not in self._backends:
            raise ValueError(
                f"Invalid backend type: {backend_type}. "
                f"Available backends: {', '.join(self._backends.keys())}"
            )

        return self._backends[backend_type]

    def _determine_backend(self, model_path: str) -> str:
        """Determine appropriate backend based on model file.

        Args:
            model_path: Path to model file

        Returns:
            Backend type to use
        """
        is_onnx = model_path.lower().endswith('.onnx')
        has_gpu = settings.use_gpu and torch.cuda.is_available()

        if is_onnx:
            return 'onnx_gpu' if has_gpu else 'onnx_cpu'
        else:
            return 'pytorch_gpu' if has_gpu else 'pytorch_cpu'

    async def load_model(
        self,
        model_path: str,
        backend_type: Optional[str] = None
    ) -> None:
        """Load model on specified backend.

        Args:
            model_path: Path to model file
            backend_type: Backend to load on, uses default if None

        Raises:
            RuntimeError: If model loading fails
        """
        try:
            # Get absolute model path
            abs_path = await paths.get_model_path(model_path)

            # Auto-determine backend if not specified
            if backend_type is None:
                backend_type = self._determine_backend(abs_path)

            backend = self.get_backend(backend_type)

            # Load model and run warmup
            await backend.load_model(abs_path)
            logger.info(f"Loaded model on {backend_type} backend")
            await self._warmup_inference(backend)

        except Exception as e:
            raise RuntimeError(f"Failed to load model: {e}")

    async def _warmup_inference(self, backend: BaseModelBackend) -> None:
        """Run warmup inference to initialize model."""
        try:
            # Import here to avoid circular imports
            from ..text_processing import process_text

            # Load default voice for warmup
            voice = await self._voice_manager.load_voice(settings.default_voice, device=backend.device)
            logger.info(f"Loaded voice {settings.default_voice} for warmup")

            # Use real text
            text = "Testing text to speech synthesis."
            logger.info(f"Running warmup inference with voice: af")

            # Process through pipeline
            sequences = process_text(text)
            if not sequences:
                raise ValueError("Text processing failed")

            # Run inference
            backend.generate(sequences[0], voice, speed=1.0)

        except Exception as e:
            logger.warning(f"Warmup inference failed: {e}")
            raise

    async def generate(
        self,
        tokens: list[int],
        voice_name: str,
        speed: float = 1.0,
        backend_type: Optional[str] = None
    ) -> torch.Tensor:
        """Generate audio using specified backend.

        Args:
            tokens: Input token IDs
            voice_name: Name of voice to use
            speed: Speed multiplier
            backend_type: Backend to use, uses default if None

        Returns:
            Generated audio tensor

        Raises:
            RuntimeError: If generation fails
        """
        backend = self.get_backend(backend_type)
        if not backend.is_loaded:
            raise RuntimeError("Model not loaded")

        try:
            # Load voice using voice manager
            voice = await self._voice_manager.load_voice(voice_name, device=backend.device)

            # Generate audio
            return backend.generate(tokens, voice, speed)

        except Exception as e:
            raise RuntimeError(f"Generation failed: {e}")

    def unload_all(self) -> None:
        """Unload models from all backends."""
        for backend in self._backends.values():
            backend.unload()
        logger.info("Unloaded all models")

    @property
    def available_backends(self) -> list[str]:
        """Get list of available backends.

        Returns:
            List of backend names
        """
        return list(self._backends.keys())

    @property
    def current_backend(self) -> str:
        """Get current default backend.

        Returns:
            Backend name
        """
        return self._current_backend


# Module-level instance
_manager: Optional[ModelManager] = None


def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
    """Get or create global model manager instance.

    Args:
        config: Optional model configuration

    Returns:
        ModelManager instance
    """
    global _manager
    if _manager is None:
        _manager = ModelManager(config)
    return _manager
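A rough call-site sketch for the manager (not part of the commit): the backend is chosen from the model file extension, and generate() looks the voice up by name through the voice manager. The model file name, token IDs, and "af" voice below are placeholder assumptions.

# Hypothetical usage of ModelManager; names and token IDs are placeholders.
import asyncio
from api.src.inference.model_manager import get_manager

async def synthesize():
    manager = get_manager()
    await manager.load_model("kokoro-v0_19.pth")   # .onnx would route to an ONNX backend instead
    tokens = [12, 45, 7, 89]                       # placeholder token IDs from the text pipeline
    audio = await manager.generate(tokens, voice_name="af", speed=1.0)
    return audio

asyncio.run(synthesize())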
api/src/inference/onnx_cpu.py (new file, 154 lines)

@@ -0,0 +1,154 @@
"""CPU-based ONNX inference backend."""

from typing import Dict, Optional

import numpy as np
import torch
from loguru import logger
from onnxruntime import (
    ExecutionMode,
    GraphOptimizationLevel,
    InferenceSession,
    SessionOptions
)

from ..core import paths
from ..core.config import settings
from ..structures.model_schemas import ONNXConfig
from .base import BaseModelBackend


class ONNXCPUBackend(BaseModelBackend):
    """ONNX-based CPU inference backend."""

    def __init__(self):
        """Initialize CPU backend."""
        super().__init__()
        self._device = "cpu"
        self._session: Optional[InferenceSession] = None
        self._config = ONNXConfig(
            optimization_level=settings.onnx_optimization_level,
            num_threads=settings.onnx_num_threads,
            inter_op_threads=settings.onnx_inter_op_threads,
            execution_mode=settings.onnx_execution_mode,
            memory_pattern=settings.onnx_memory_pattern,
            arena_extend_strategy=settings.onnx_arena_extend_strategy
        )

    async def load_model(self, path: str) -> None:
        """Load ONNX model.

        Args:
            path: Path to model file

        Raises:
            RuntimeError: If model loading fails
        """
        try:
            # Get verified model path
            model_path = await paths.get_model_path(path)

            logger.info(f"Loading ONNX model: {model_path}")

            # Configure session
            options = self._create_session_options()
            provider_options = self._create_provider_options()

            # Create session
            self._session = InferenceSession(
                model_path,
                sess_options=options,
                providers=["CPUExecutionProvider"],
                provider_options=[provider_options]
            )

        except Exception as e:
            raise RuntimeError(f"Failed to load ONNX model: {e}")

    def generate(
        self,
        tokens: list[int],
        voice: torch.Tensor,
        speed: float = 1.0
    ) -> np.ndarray:
        """Generate audio using ONNX model.

        Args:
            tokens: Input token IDs
            voice: Voice embedding tensor
            speed: Speed multiplier

        Returns:
            Generated audio samples

        Raises:
            RuntimeError: If generation fails
        """
        if not self.is_loaded:
            raise RuntimeError("Model not loaded")

        try:
            # Prepare inputs
            tokens_input = np.array([tokens], dtype=np.int64)
            style_input = voice[len(tokens)].numpy()
            speed_input = np.full(1, speed, dtype=np.float32)

            # Run inference
            result = self._session.run(
                None,
                {
                    "tokens": tokens_input,
                    "style": style_input,
                    "speed": speed_input
                }
            )

            return result[0]

        except Exception as e:
            raise RuntimeError(f"Generation failed: {e}")

    def _create_session_options(self) -> SessionOptions:
        """Create ONNX session options.

        Returns:
            Configured session options
        """
        options = SessionOptions()

        # Set optimization level
        if self._config.optimization_level == "all":
            options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
        elif self._config.optimization_level == "basic":
            options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
        else:
            options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL

        # Configure threading
        options.intra_op_num_threads = self._config.num_threads
        options.inter_op_num_threads = self._config.inter_op_threads

        # Set execution mode
        options.execution_mode = (
            ExecutionMode.ORT_PARALLEL
            if self._config.execution_mode == "parallel"
            else ExecutionMode.ORT_SEQUENTIAL
        )

        # Configure memory optimization
        options.enable_mem_pattern = self._config.memory_pattern

        return options

    def _create_provider_options(self) -> Dict:
        """Create CPU provider options.

        Returns:
            Provider configuration
        """
        return {
            "CPUExecutionProvider": {
                "arena_extend_strategy": self._config.arena_extend_strategy,
                "cpu_memory_arena_cfg": "cpu:0"
            }
        }
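For reference, the ONNX backends feed the session three named inputs: "tokens" (int64), "style", and "speed" (float32). The standalone sketch below mirrors that call outside the backend; the model file name and the 256-dimension style vector size are assumptions for illustration, not guaranteed by this commit.

# Standalone sketch of the same ONNX Runtime call the backend makes; the model
# path and style vector shape are illustrative assumptions.
import numpy as np
from onnxruntime import InferenceSession

session = InferenceSession("kokoro-v0_19.onnx", providers=["CPUExecutionProvider"])
tokens = np.array([[12, 45, 7, 89]], dtype=np.int64)   # batch of one token sequence
style = np.zeros((1, 256), dtype=np.float32)            # placeholder voice/style embedding
speed = np.full(1, 1.0, dtype=np.float32)
audio = session.run(None, {"tokens": tokens, "style": style, "speed": speed})[0]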
api/src/inference/onnx_gpu.py (new file, 163 lines)

@@ -0,0 +1,163 @@
"""GPU-based ONNX inference backend."""

from typing import Dict, Optional

import numpy as np
import torch
from loguru import logger
from onnxruntime import (
    ExecutionMode,
    GraphOptimizationLevel,
    InferenceSession,
    SessionOptions
)

from ..core import paths
from ..core.config import settings
from ..structures.model_schemas import ONNXGPUConfig
from .base import BaseModelBackend


class ONNXGPUBackend(BaseModelBackend):
    """ONNX-based GPU inference backend."""

    def __init__(self):
        """Initialize GPU backend."""
        super().__init__()
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA not available")
        self._device = "cuda"
        self._session: Optional[InferenceSession] = None
        self._config = ONNXGPUConfig(
            optimization_level=settings.onnx_optimization_level,
            num_threads=settings.onnx_num_threads,
            inter_op_threads=settings.onnx_inter_op_threads,
            execution_mode=settings.onnx_execution_mode,
            memory_pattern=settings.onnx_memory_pattern,
            arena_extend_strategy=settings.onnx_arena_extend_strategy,
            device_id=0,
            gpu_mem_limit=0.7,
            cudnn_conv_algo_search="EXHAUSTIVE",
            do_copy_in_default_stream=True
        )

    async def load_model(self, path: str) -> None:
        """Load ONNX model.

        Args:
            path: Path to model file

        Raises:
            RuntimeError: If model loading fails
        """
        try:
            # Get verified model path
            model_path = await paths.get_model_path(path)

            logger.info(f"Loading ONNX model on GPU: {model_path}")

            # Configure session
            options = self._create_session_options()
            provider_options = self._create_provider_options()

            # Create session with CUDA provider
            self._session = InferenceSession(
                model_path,
                sess_options=options,
                providers=["CUDAExecutionProvider"],
                provider_options=[provider_options]
            )

        except Exception as e:
            raise RuntimeError(f"Failed to load ONNX model: {e}")

    def generate(
        self,
        tokens: list[int],
        voice: torch.Tensor,
        speed: float = 1.0
    ) -> np.ndarray:
        """Generate audio using ONNX model.

        Args:
            tokens: Input token IDs
            voice: Voice embedding tensor
            speed: Speed multiplier

        Returns:
            Generated audio samples

        Raises:
            RuntimeError: If generation fails
        """
        if not self.is_loaded:
            raise RuntimeError("Model not loaded")

        try:
            # Prepare inputs
            tokens_input = np.array([tokens], dtype=np.int64)
            style_input = voice[len(tokens)].cpu().numpy()  # Move to CPU for ONNX
            speed_input = np.full(1, speed, dtype=np.float32)

            # Run inference
            result = self._session.run(
                None,
                {
                    "tokens": tokens_input,
                    "style": style_input,
                    "speed": speed_input
                }
            )

            return result[0]

        except Exception as e:
            raise RuntimeError(f"Generation failed: {e}")

    def _create_session_options(self) -> SessionOptions:
        """Create ONNX session options.

        Returns:
            Configured session options
        """
        options = SessionOptions()

        # Set optimization level
        if self._config.optimization_level == "all":
            options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
        elif self._config.optimization_level == "basic":
            options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
        else:
            options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL

        # Configure threading
        options.intra_op_num_threads = self._config.num_threads
        options.inter_op_num_threads = self._config.inter_op_threads

        # Set execution mode
        options.execution_mode = (
            ExecutionMode.ORT_PARALLEL
            if self._config.execution_mode == "parallel"
            else ExecutionMode.ORT_SEQUENTIAL
        )

        # Configure memory optimization
        options.enable_mem_pattern = self._config.memory_pattern

        return options

    def _create_provider_options(self) -> Dict:
        """Create CUDA provider options.

        Returns:
            Provider configuration
        """
        return {
            "CUDAExecutionProvider": {
                "device_id": self._config.device_id,
                "arena_extend_strategy": self._config.arena_extend_strategy,
                "gpu_mem_limit": int(self._config.gpu_mem_limit * torch.cuda.get_device_properties(0).total_memory),
                "cudnn_conv_algo_search": self._config.cudnn_conv_algo_search,
                "do_copy_in_default_stream": self._config.do_copy_in_default_stream
            }
        }
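Note on the provider options above: the CUDA execution provider's gpu_mem_limit is expressed in bytes, so the 0.7 fraction is scaled by the device's total memory. A rough worked example (the 24 GiB card is only an illustration):

# Sketch of the gpu_mem_limit arithmetic; 24 GiB is an example device size.
total_memory = 24 * 1024**3                # bytes, as reported by torch.cuda.get_device_properties(0).total_memory
gpu_mem_limit = int(0.7 * total_memory)    # ~16.8 GiB arena cap handed to CUDAExecutionProvider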
api/src/inference/pytorch_cpu.py (new file, 181 lines)

@@ -0,0 +1,181 @@
"""CPU-based PyTorch inference backend."""

import gc
from typing import Optional

import numpy as np
import torch
from loguru import logger

from ..builds.models import build_model
from ..core import paths
from ..structures.model_schemas import PyTorchCPUConfig
from .base import BaseModelBackend


@torch.no_grad()
def forward(model: torch.nn.Module, tokens: list[int], ref_s: torch.Tensor, speed: float) -> np.ndarray:
    """Forward pass through model with memory management.

    Args:
        model: PyTorch model
        tokens: Input tokens
        ref_s: Reference signal
        speed: Speed multiplier

    Returns:
        Generated audio
    """
    device = ref_s.device
    pred_aln_trg = None
    asr = None

    try:
        # Initial tensor setup
        tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
        text_mask = length_to_mask(input_lengths).to(device)

        # Split reference signals
        s_content = ref_s[:, 128:].clone().to(device)
        s_ref = ref_s[:, :128].clone().to(device)

        # BERT and encoder pass
        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

        # Predictor forward pass
        d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
        x, _ = model.predictor.lstm(d)

        # Duration prediction
        duration = model.predictor.duration_proj(x)
        duration = torch.sigmoid(duration).sum(axis=-1) / speed
        pred_dur = torch.round(duration).clamp(min=1).long()
        del duration, x  # Free large intermediates

        # Alignment matrix construction
        pred_aln_trg = torch.zeros(
            input_lengths.item(),
            pred_dur.sum().item(),
            device=device
        )
        c_frame = 0
        for i in range(pred_aln_trg.size(0)):
            pred_aln_trg[i, c_frame:c_frame + pred_dur[0, i].item()] = 1
            c_frame += pred_dur[0, i].item()
        pred_aln_trg = pred_aln_trg.unsqueeze(0)

        # Matrix multiplications with cleanup
        en = d.transpose(-1, -2) @ pred_aln_trg
        del d

        F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
        del en

        # Final text encoding and decoding
        t_en = model.text_encoder(tokens, input_lengths, text_mask)
        asr = t_en @ pred_aln_trg
        del t_en

        # Generate output
        output = model.decoder(asr, F0_pred, N_pred, s_ref)
        result = output.squeeze().cpu().numpy()

        return result

    finally:
        # Clean up largest tensors if they were created
        if pred_aln_trg is not None:
            del pred_aln_trg
        if asr is not None:
            del asr


def length_to_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Create attention mask from lengths.

    Args:
        lengths: Sequence lengths

    Returns:
        Boolean mask tensor
    """
    max_len = lengths.max()
    mask = torch.arange(max_len, device=lengths.device)[None, :].expand(
        lengths.shape[0], -1
    )
    if lengths.dtype != mask.dtype:
        mask = mask.to(dtype=lengths.dtype)
    return mask + 1 > lengths[:, None]


class PyTorchCPUBackend(BaseModelBackend):
    """PyTorch CPU inference backend."""

    def __init__(self):
        """Initialize CPU backend."""
        super().__init__()
        self._device = "cpu"
        self._model: Optional[torch.nn.Module] = None
        self._config = PyTorchCPUConfig()

        # Configure PyTorch CPU settings
        if self._config.num_threads > 0:
            torch.set_num_threads(self._config.num_threads)
        if self._config.pin_memory:
            torch.set_default_tensor_type(torch.FloatTensor)

    async def load_model(self, path: str) -> None:
        """Load PyTorch model.

        Args:
            path: Path to model file

        Raises:
            RuntimeError: If model loading fails
        """
        try:
            # Get verified model path
            model_path = await paths.get_model_path(path)

            logger.info(f"Loading PyTorch model on CPU: {model_path}")
            self._model = await build_model(model_path, self._device)

        except Exception as e:
            raise RuntimeError(f"Failed to load PyTorch model: {e}")

    def generate(
        self,
        tokens: list[int],
        voice: torch.Tensor,
        speed: float = 1.0
    ) -> np.ndarray:
        """Generate audio using CPU model.

        Args:
            tokens: Input token IDs
            voice: Voice embedding tensor
            speed: Speed multiplier

        Returns:
            Generated audio samples

        Raises:
            RuntimeError: If generation fails
        """
        if not self.is_loaded:
            raise RuntimeError("Model not loaded")

        try:
            # Prepare input
            ref_s = voice[len(tokens)].clone()

            # Generate audio
            return forward(self._model, tokens, ref_s, speed)

        except Exception as e:
            raise RuntimeError(f"Generation failed: {e}")
        finally:
            # Clean up memory
            gc.collect()
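A small worked example of the length_to_mask helper defined above: positions past a sequence's length come back True (i.e. they are the masked/padding positions).

# Worked example for length_to_mask (as defined in this file).
import torch

lengths = torch.LongTensor([3, 5])
mask = length_to_mask(lengths)
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])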
api/src/inference/pytorch_gpu.py (new file, 170 lines)

@@ -0,0 +1,170 @@
"""GPU-based PyTorch inference backend."""

import gc
from typing import Optional

import numpy as np
import torch
from loguru import logger

from ..builds.models import build_model
from ..core import paths
from ..structures.model_schemas import PyTorchConfig
from .base import BaseModelBackend


@torch.no_grad()
def forward(model: torch.nn.Module, tokens: list[int], ref_s: torch.Tensor, speed: float) -> np.ndarray:
    """Forward pass through model.

    Args:
        model: PyTorch model
        tokens: Input tokens
        ref_s: Reference signal (shape: [1, n_features])
        speed: Speed multiplier

    Returns:
        Generated audio
    """
    device = ref_s.device

    # Initial tensor setup
    tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
    input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
    text_mask = length_to_mask(input_lengths).to(device)

    # Split reference signals (style_dim=128 from config)
    style_dim = 128
    s_ref = ref_s[:, :style_dim].clone().to(device)
    s_content = ref_s[:, style_dim:].clone().to(device)

    # BERT and encoder pass
    bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
    d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

    # Predictor forward pass
    d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
    x, _ = model.predictor.lstm(d)

    # Duration prediction
    duration = model.predictor.duration_proj(x)
    duration = torch.sigmoid(duration).sum(axis=-1) / speed
    pred_dur = torch.round(duration).clamp(min=1).long()
    del duration, x

    # Alignment matrix construction
    pred_aln_trg = torch.zeros(input_lengths.item(), pred_dur.sum().item(), device=device)
    c_frame = 0
    for i in range(pred_aln_trg.size(0)):
        pred_aln_trg[i, c_frame:c_frame + pred_dur[0, i].item()] = 1
        c_frame += pred_dur[0, i].item()
    pred_aln_trg = pred_aln_trg.unsqueeze(0)

    # Matrix multiplications
    en = d.transpose(-1, -2) @ pred_aln_trg
    del d

    F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
    del en

    # Final text encoding and decoding
    t_en = model.text_encoder(tokens, input_lengths, text_mask)
    asr = t_en @ pred_aln_trg
    del t_en

    # Generate output
    output = model.decoder(asr, F0_pred, N_pred, s_ref)
    return output.squeeze().cpu().numpy()


def length_to_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Create attention mask from lengths."""
    max_len = lengths.max()
    mask = torch.arange(max_len, device=lengths.device)[None, :].expand(lengths.shape[0], -1)
    if lengths.dtype != mask.dtype:
        mask = mask.to(dtype=lengths.dtype)
    return mask + 1 > lengths[:, None]


class PyTorchGPUBackend(BaseModelBackend):
    """PyTorch GPU inference backend."""

    def __init__(self):
        """Initialize GPU backend."""
        super().__init__()
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA not available")
        self._device = "cuda"
        self._model: Optional[torch.nn.Module] = None
        self._config = PyTorchConfig()

    async def load_model(self, path: str) -> None:
        """Load PyTorch model.

        Args:
            path: Path to model file

        Raises:
            RuntimeError: If model loading fails
        """
        try:
            # Get verified model path
            model_path = await paths.get_model_path(path)

            logger.info(f"Loading PyTorch model: {model_path}")
            self._model = await build_model(model_path, self._device)

        except Exception as e:
            raise RuntimeError(f"Failed to load PyTorch model: {e}")

    def generate(
        self,
        tokens: list[int],
        voice: torch.Tensor,
        speed: float = 1.0
    ) -> np.ndarray:
        """Generate audio using GPU model.

        Args:
            tokens: Input token IDs
            voice: Voice embedding tensor
            speed: Speed multiplier

        Returns:
            Generated audio samples

        Raises:
            RuntimeError: If generation fails
        """
        if not self.is_loaded:
            raise RuntimeError("Model not loaded")

        try:
            # Check memory and cleanup if needed
            if self._check_memory():
                self._clear_memory()

            # Get reference style from voice pack
            ref_s = voice[len(tokens)].clone().to(self._device)
            if ref_s.dim() == 1:
                ref_s = ref_s.unsqueeze(0)  # Add batch dimension if needed

            # Generate audio
            return forward(self._model, tokens, ref_s, speed)

        except Exception as e:
            logger.error(f"Generation failed: {e}")
            raise

    def _check_memory(self) -> bool:
        """Check if memory usage is above threshold."""
        if torch.cuda.is_available():
            memory_gb = torch.cuda.memory_allocated() / 1e9
            return memory_gb > self._config.memory_threshold
        return False

    def _clear_memory(self) -> None:
        """Clear GPU memory."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            gc.collect()
api/src/inference/voice_manager.py (new file, 191 lines)

@@ -0,0 +1,191 @@
"""Voice pack management and caching."""

import os
from typing import Dict, List, Optional, Union

import torch
from loguru import logger
from pydantic import BaseModel

from ..core import paths
from ..core.config import settings
from ..structures.model_schemas import VoiceConfig


class VoiceManager:
    """Manages voice loading, caching, and operations."""

    def __init__(self, config: Optional[VoiceConfig] = None):
        """Initialize voice manager.

        Args:
            config: Optional voice configuration
        """
        self._config = config or VoiceConfig()
        self._voice_cache: Dict[str, torch.Tensor] = {}

    def get_voice_path(self, voice_name: str) -> Optional[str]:
        """Get path to voice file.

        Args:
            voice_name: Name of voice

        Returns:
            Path to voice file if exists, None otherwise
        """
        voice_path = os.path.join(settings.voices_dir, f"{voice_name}.pt")
        return voice_path if os.path.exists(voice_path) else None

    async def load_voice(self, voice_name: str, device: str = "cpu") -> torch.Tensor:
        """Load voice tensor.

        Args:
            voice_name: Name of voice to load
            device: Device to load voice on

        Returns:
            Voice tensor

        Raises:
            RuntimeError: If voice loading fails
        """
        voice_path = self.get_voice_path(voice_name)
        if not voice_path:
            raise RuntimeError(f"Voice not found: {voice_name}")

        # Check cache first
        cache_key = f"{voice_path}_{device}"
        if self._config.use_cache and cache_key in self._voice_cache:
            return self._voice_cache[cache_key]

        try:
            # Load voice tensor
            voice = await paths.load_voice_tensor(voice_path, device=device)

            # Cache if enabled
            if self._config.use_cache:
                self._manage_cache()
                self._voice_cache[cache_key] = voice
                logger.debug(f"Cached voice: {voice_name} on {device}")

            return voice

        except Exception as e:
            raise RuntimeError(f"Failed to load voice {voice_name}: {e}")

    def _manage_cache(self) -> None:
        """Manage voice cache size."""
        if len(self._voice_cache) >= self._config.cache_size:
            # Remove oldest voice
            oldest = next(iter(self._voice_cache))
            del self._voice_cache[oldest]
            logger.debug(f"Removed from voice cache: {oldest}")

    async def combine_voices(self, voices: List[str], device: str = "cpu") -> str:
        """Combine multiple voices into a new voice.

        Args:
            voices: List of voice names to combine
            device: Device to load voices on

        Returns:
            Name of combined voice

        Raises:
            ValueError: If fewer than 2 voices provided
            RuntimeError: If voice combination fails
        """
        if len(voices) < 2:
            raise ValueError("At least 2 voices are required for combination")

        # Load voices
        voice_tensors: List[torch.Tensor] = []
        for voice in voices:
            try:
                voice_tensor = await self.load_voice(voice, device)
                voice_tensors.append(voice_tensor)
            except Exception as e:
                raise RuntimeError(f"Failed to load voice {voice}: {e}")

        try:
            # Combine voices
            combined_name = "_".join(voices)
            combined_tensor = torch.mean(torch.stack(voice_tensors), dim=0)

            # Save combined voice
            combined_path = os.path.join(settings.voices_dir, f"{combined_name}.pt")
            try:
                torch.save(combined_tensor, combined_path)
            except Exception as e:
                raise RuntimeError(f"Failed to save combined voice: {e}")

            return combined_name

        except Exception as e:
            raise RuntimeError(f"Failed to combine voices: {e}")

    async def list_voices(self) -> List[str]:
        """List available voices.

        Returns:
            List of voice names
        """
        voices = []
        try:
            for entry in os.listdir(settings.voices_dir):
                if entry.endswith(".pt"):
                    voices.append(entry[:-3])  # Remove .pt extension
        except Exception as e:
            logger.error(f"Error listing voices: {e}")
        return sorted(voices)

    def validate_voice(self, voice_path: str) -> bool:
        """Validate voice file.

        Args:
            voice_path: Path to voice file

        Returns:
            True if valid, False otherwise
        """
        try:
            if not os.path.exists(voice_path):
                return False

            # Try loading voice
            voice = torch.load(voice_path, map_location="cpu")
            return isinstance(voice, torch.Tensor)

        except Exception:
            return False

    @property
    def cache_info(self) -> Dict[str, int]:
        """Get cache statistics.

        Returns:
            Dictionary with cache info
        """
        return {
            'size': len(self._voice_cache),
            'max_size': self._config.cache_size
        }


# Module-level instance
_manager: Optional[VoiceManager] = None


def get_manager(config: Optional[VoiceConfig] = None) -> VoiceManager:
    """Get or create global voice manager instance.

    Args:
        config: Optional voice configuration

    Returns:
        VoiceManager instance
    """
    global _manager
    if _manager is None:
        _manager = VoiceManager(config)
    return _manager
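A rough call-site sketch for VoiceManager (not part of the commit): combine_voices averages the stacked voice tensors and writes the result back into settings.voices_dir under the joined name. The voice names below are illustrative placeholders.

# Hypothetical usage of VoiceManager; voice names are placeholders.
import asyncio
from api.src.inference.voice_manager import get_manager

async def demo():
    voices = get_manager()
    available = await voices.list_voices()
    blended = await voices.combine_voices(["af_bella", "af_sarah"])  # saved as "af_bella_af_sarah.pt"
    tensor = await voices.load_voice(blended, device="cpu")
    print(available, tensor.shape)

asyncio.run(demo())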
(deleted file)

@@ -1,175 +0,0 @@
import os
import threading
from abc import ABC, abstractmethod
from typing import List, Tuple

import numpy as np
import torch
from loguru import logger

from ..core.config import settings


class TTSBaseModel(ABC):
    _instance = None
    _lock = threading.Lock()
    _device = None
    VOICES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "voices")

    @classmethod
    async def setup(cls):
        """Initialize model and setup voices"""
        with cls._lock:
            # Set device
            cuda_available = torch.cuda.is_available()
            logger.info(f"CUDA available: {cuda_available}")
            if cuda_available:
                try:
                    # Test CUDA device
                    test_tensor = torch.zeros(1).cuda()
                    logger.info("CUDA test successful")
                    model_path = os.path.join(
                        settings.model_dir, settings.pytorch_model_path
                    )
                    cls._device = "cuda"
                except Exception as e:
                    logger.error(f"CUDA test failed: {e}")
                    cls._device = "cpu"
            else:
                cls._device = "cpu"
                model_path = os.path.join(settings.model_dir, settings.onnx_model_path)
            logger.info(f"Initializing model on {cls._device}")
            logger.info(f"Model dir: {settings.model_dir}")
            logger.info(f"Model path: {model_path}")
            logger.info(f"Files in model dir: {os.listdir(settings.model_dir)}")

            # Initialize model first
            model = cls.initialize(settings.model_dir, model_path=model_path)
            if model is None:
                raise RuntimeError(f"Failed to initialize {cls._device.upper()} model")
            cls._instance = model

            # Setup voices directory
            os.makedirs(cls.VOICES_DIR, exist_ok=True)

            # Copy base voices to local directory
            base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
            if os.path.exists(base_voices_dir):
                for file in os.listdir(base_voices_dir):
                    if file.endswith(".pt"):
                        voice_name = file[:-3]
                        voice_path = os.path.join(cls.VOICES_DIR, file)
                        if not os.path.exists(voice_path):
                            try:
                                logger.info(
                                    f"Copying base voice {voice_name} to voices directory"
                                )
                                base_path = os.path.join(base_voices_dir, file)
                                voicepack = torch.load(
                                    base_path,
                                    map_location=cls._device,
                                    weights_only=True,
                                )
                                torch.save(voicepack, voice_path)
                            except Exception as e:
                                logger.error(
                                    f"Error copying voice {voice_name}: {str(e)}"
                                )

            # Count voices in directory
            voice_count = len(
                [f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")]
            )

            # Now that model and voices are ready, do warmup
            try:
                with open(
                    os.path.join(
                        os.path.dirname(os.path.dirname(__file__)),
                        "core",
                        "don_quixote.txt",
                    )
                ) as f:
                    warmup_text = f.read()
            except Exception as e:
                logger.warning(f"Failed to load warmup text: {e}")
                warmup_text = "This is a warmup text that will be split into chunks for processing."

            # Use warmup service after model is fully initialized
            from .warmup import WarmupService

            warmup = WarmupService()

            # Load and warm up voices
            loaded_voices = warmup.load_voices()
            await warmup.warmup_voices(warmup_text, loaded_voices)

            logger.info("Model warm-up complete")

            # Count voices in directory
            voice_count = len(
                [f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")]
            )
            return voice_count

    @classmethod
    @abstractmethod
    def initialize(cls, model_dir: str, model_path: str = None):
        """Initialize the model"""
        pass

    @classmethod
    @abstractmethod
    def process_text(cls, text: str, language: str) -> Tuple[str, List[int]]:
        """Process text into phonemes and tokens

        Args:
            text: Input text
            language: Language code

        Returns:
            tuple[str, list[int]]: Phonemes and token IDs
        """
        pass

    @classmethod
    @abstractmethod
    def generate_from_text(
        cls, text: str, voicepack: torch.Tensor, language: str, speed: float
    ) -> Tuple[np.ndarray, str]:
        """Generate audio from text

        Args:
            text: Input text
            voicepack: Voice tensor
            language: Language code
            speed: Speed factor

        Returns:
            tuple[np.ndarray, str]: Generated audio samples and phonemes
        """
        pass

    @classmethod
    @abstractmethod
    def generate_from_tokens(
        cls, tokens: List[int], voicepack: torch.Tensor, speed: float
    ) -> np.ndarray:
        """Generate audio from tokens

        Args:
            tokens: Token IDs
            voicepack: Voice tensor
            speed: Speed factor

        Returns:
            np.ndarray: Generated audio samples
        """
        pass

    @classmethod
    def get_device(cls):
        """Get the current device"""
        if cls._device is None:
            raise RuntimeError("Model not initialized. Call setup() first.")
        return cls._device
@@ -1,167 +0,0 @@
import os

import numpy as np
import torch
from loguru import logger
from onnxruntime import (
    ExecutionMode,
    GraphOptimizationLevel,
    InferenceSession,
    SessionOptions,
)

from ..core.config import settings
from .text_processing import phonemize, tokenize
from .tts_base import TTSBaseModel


class TTSCPUModel(TTSBaseModel):
    _instance = None
    _onnx_session = None
    _device = "cpu"

    @classmethod
    def get_instance(cls):
        """Get the model instance"""
        if cls._onnx_session is None:
            raise RuntimeError("ONNX model not initialized. Call initialize() first.")
        return cls._onnx_session

    @classmethod
    def initialize(cls, model_dir: str, model_path: str = None):
        """Initialize ONNX model for CPU inference"""
        if cls._onnx_session is None:
            try:
                # Try loading ONNX model
                onnx_path = os.path.join(model_dir, settings.onnx_model_path)
                if not os.path.exists(onnx_path):
                    logger.error(f"ONNX model not found at {onnx_path}")
                    return None

                logger.info(f"Loading ONNX model from {onnx_path}")

                # Configure ONNX session for optimal performance
                session_options = SessionOptions()

                # Set optimization level
                if settings.onnx_optimization_level == "all":
                    session_options.graph_optimization_level = (
                        GraphOptimizationLevel.ORT_ENABLE_ALL
                    )
                elif settings.onnx_optimization_level == "basic":
                    session_options.graph_optimization_level = (
                        GraphOptimizationLevel.ORT_ENABLE_BASIC
                    )
                else:
                    session_options.graph_optimization_level = (
                        GraphOptimizationLevel.ORT_DISABLE_ALL
                    )

                # Configure threading
                session_options.intra_op_num_threads = settings.onnx_num_threads
                session_options.inter_op_num_threads = settings.onnx_inter_op_threads

                # Set execution mode
                session_options.execution_mode = (
                    ExecutionMode.ORT_PARALLEL
                    if settings.onnx_execution_mode == "parallel"
                    else ExecutionMode.ORT_SEQUENTIAL
                )

                # Enable/disable memory pattern optimization
                session_options.enable_mem_pattern = settings.onnx_memory_pattern

                # Configure CPU provider options
                provider_options = {
                    "CPUExecutionProvider": {
                        "arena_extend_strategy": settings.onnx_arena_extend_strategy,
                        "cpu_memory_arena_cfg": "cpu:0",
                    }
                }

                session = InferenceSession(
                    onnx_path,
                    sess_options=session_options,
                    providers=["CPUExecutionProvider"],
                    provider_options=[provider_options],
                )
                cls._onnx_session = session
                return session
            except Exception as e:
                logger.error(f"Failed to initialize ONNX model: {e}")
                return None
        return cls._onnx_session

    @classmethod
    def process_text(cls, text: str, language: str) -> tuple[str, list[int]]:
        """Process text into phonemes and tokens

        Args:
            text: Input text
            language: Language code

        Returns:
            tuple[str, list[int]]: Phonemes and token IDs
        """
        phonemes = phonemize(text, language)
        tokens = tokenize(phonemes)
        tokens = [0] + tokens + [0]  # Add start/end tokens
        return phonemes, tokens

    @classmethod
    def generate_from_text(
        cls, text: str, voicepack: torch.Tensor, language: str, speed: float
    ) -> tuple[np.ndarray, str]:
        """Generate audio from text

        Args:
            text: Input text
            voicepack: Voice tensor
            language: Language code
            speed: Speed factor

        Returns:
            tuple[np.ndarray, str]: Generated audio samples and phonemes
        """
        if cls._onnx_session is None:
            raise RuntimeError("ONNX model not initialized")

        # Process text
        phonemes, tokens = cls.process_text(text, language)

        # Generate audio
        audio = cls.generate_from_tokens(tokens, voicepack, speed)

        return audio, phonemes

    @classmethod
    def generate_from_tokens(
        cls, tokens: list[int], voicepack: torch.Tensor, speed: float
    ) -> np.ndarray:
        """Generate audio from tokens

        Args:
            tokens: Token IDs
            voicepack: Voice tensor
            speed: Speed factor

        Returns:
            np.ndarray: Generated audio samples
        """
        if cls._onnx_session is None:
            raise RuntimeError("ONNX model not initialized")

        # Pre-allocate and prepare inputs
        tokens_input = np.array([tokens], dtype=np.int64)
        style_input = voicepack[
            len(tokens) - 2
        ].numpy()  # Already has correct dimensions
        speed_input = np.full(
            1, speed, dtype=np.float32
        )  # More efficient than ones * speed

        # Run inference with optimized inputs
        result = cls._onnx_session.run(
            None, {"tokens": tokens_input, "style": style_input, "speed": speed_input}
        )
        return result[0]
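As a standalone sketch of what the branching session setup above resolves to when the defaults are in effect ("all" optimization, 4 threads, parallel execution, matching the ModelConfig defaults introduced later in this commit), the equivalent direct onnxruntime calls would look roughly like this; the model filename is a placeholder, since the real path comes from settings.

    from onnxruntime import (
        ExecutionMode,
        GraphOptimizationLevel,
        InferenceSession,
        SessionOptions,
    )

    # Resolved defaults: full graph optimization, 4 intra/inter-op threads, parallel mode.
    opts = SessionOptions()
    opts.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
    opts.intra_op_num_threads = 4
    opts.inter_op_num_threads = 4
    opts.execution_mode = ExecutionMode.ORT_PARALLEL
    opts.enable_mem_pattern = True

    session = InferenceSession(
        "kokoro-v0_19.onnx",  # placeholder path; the service resolves this from settings
        sess_options=opts,
        providers=["CPUExecutionProvider"],
    )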
@@ -1,262 +0,0 @@
import os
import time

import numpy as np
import torch
from ..builds.models import build_model
from loguru import logger

from ..core.config import settings
from .text_processing import phonemize, tokenize
from .tts_base import TTSBaseModel


# @torch.no_grad()
# def forward(model, tokens, ref_s, speed):
#     """Forward pass through the model"""
#     device = ref_s.device
#     tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
#     input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
#     text_mask = length_to_mask(input_lengths).to(device)
#     bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
#     d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
#     s = ref_s[:, 128:]
#     d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
#     x, _ = model.predictor.lstm(d)
#     duration = model.predictor.duration_proj(x)
#     duration = torch.sigmoid(duration).sum(axis=-1) / speed
#     pred_dur = torch.round(duration).clamp(min=1).long()
#     pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
#     c_frame = 0
#     for i in range(pred_aln_trg.size(0)):
#         pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
#         c_frame += pred_dur[0, i].item()
#     en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
#     F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
#     t_en = model.text_encoder(tokens, input_lengths, text_mask)
#     asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
#     return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
@torch.no_grad()
def forward(model, tokens, ref_s, speed):
    """Forward pass through the model with moderate memory management"""
    device = ref_s.device

    try:
        # Initial tensor setup with proper device placement
        tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
        text_mask = length_to_mask(input_lengths).to(device)

        # Split and clone reference signals with explicit device placement
        s_content = ref_s[:, 128:].clone().to(device)
        s_ref = ref_s[:, :128].clone().to(device)

        # BERT and encoder pass
        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

        # Predictor forward pass
        d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
        x, _ = model.predictor.lstm(d)

        # Duration prediction
        duration = model.predictor.duration_proj(x)
        duration = torch.sigmoid(duration).sum(axis=-1) / speed
        pred_dur = torch.round(duration).clamp(min=1).long()
        # Only cleanup large intermediates
        del duration, x

        # Alignment matrix construction
        pred_aln_trg = torch.zeros(input_lengths.item(), pred_dur.sum().item(), device=device)
        c_frame = 0
        for i in range(pred_aln_trg.size(0)):
            pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
            c_frame += pred_dur[0, i].item()
        pred_aln_trg = pred_aln_trg.unsqueeze(0)

        # Matrix multiplications with selective cleanup
        en = d.transpose(-1, -2) @ pred_aln_trg
        del d  # Free large intermediate tensor

        F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
        del en  # Free large intermediate tensor

        # Final text encoding and decoding
        t_en = model.text_encoder(tokens, input_lengths, text_mask)
        asr = t_en @ pred_aln_trg
        del t_en  # Free large intermediate tensor

        # Final decoding and transfer to CPU
        output = model.decoder(asr, F0_pred, N_pred, s_ref)
        result = output.squeeze().cpu().numpy()

        return result

    finally:
        # Let PyTorch handle most cleanup automatically
        # Only explicitly free the largest tensors
        del pred_aln_trg, asr


# def length_to_mask(lengths):
#     """Create attention mask from lengths"""
#     mask = (
#         torch.arange(lengths.max())
#         .unsqueeze(0)
#         .expand(lengths.shape[0], -1)
#         .type_as(lengths)
#     )
#     mask = torch.gt(mask + 1, lengths.unsqueeze(1))
#     return mask


def length_to_mask(lengths):
    """Create attention mask from lengths - possibly optimized version"""
    max_len = lengths.max()
    # Create mask directly on the same device as lengths
    mask = torch.arange(max_len, device=lengths.device)[None, :].expand(
        lengths.shape[0], -1
    )
    # Avoid type_as by using the correct dtype from the start
    if lengths.dtype != mask.dtype:
        mask = mask.to(dtype=lengths.dtype)
    # Fuse operations using broadcasting
    return mask + 1 > lengths[:, None]


class TTSGPUModel(TTSBaseModel):
    _instance = None
    _device = "cuda"

    @classmethod
    def get_instance(cls):
        """Get the model instance"""
        if cls._instance is None:
            raise RuntimeError("GPU model not initialized. Call initialize() first.")
        return cls._instance

    @classmethod
    def initialize(cls, model_dir: str, model_path: str):
        """Initialize PyTorch model for GPU inference"""
        if cls._instance is None and torch.cuda.is_available():
            try:
                logger.info("Initializing GPU model")
                model_path = os.path.join(model_dir, settings.pytorch_model_path)
                model = build_model(model_path, cls._device)
                cls._instance = model
                return model
            except Exception as e:
                logger.error(f"Failed to initialize GPU model: {e}")
                return None
        return cls._instance

    @classmethod
    def process_text(cls, text: str, language: str) -> tuple[str, list[int]]:
        """Process text into phonemes and tokens

        Args:
            text: Input text
            language: Language code

        Returns:
            tuple[str, list[int]]: Phonemes and token IDs
        """
        phonemes = phonemize(text, language)
        tokens = tokenize(phonemes)
        return phonemes, tokens

    @classmethod
    def generate_from_text(
        cls, text: str, voicepack: torch.Tensor, language: str, speed: float
    ) -> tuple[np.ndarray, str]:
        """Generate audio from text

        Args:
            text: Input text
            voicepack: Voice tensor
            language: Language code
            speed: Speed factor

        Returns:
            tuple[np.ndarray, str]: Generated audio samples and phonemes
        """
        if cls._instance is None:
            raise RuntimeError("GPU model not initialized")

        # Process text
        phonemes, tokens = cls.process_text(text, language)

        # Generate audio
        audio = cls.generate_from_tokens(tokens, voicepack, speed)

        return audio, phonemes

    @classmethod
    def generate_from_tokens(
        cls, tokens: list[int], voicepack: torch.Tensor, speed: float
    ) -> np.ndarray:
        """Generate audio from tokens with moderate memory management

        Args:
            tokens: Token IDs
            voicepack: Voice tensor
            speed: Speed factor

        Returns:
            np.ndarray: Generated audio samples
        """
        if cls._instance is None:
            raise RuntimeError("GPU model not initialized")

        try:
            device = cls._device

            # Check memory pressure
            if torch.cuda.is_available():
                memory_allocated = torch.cuda.memory_allocated(device) / 1e9  # Convert to GB
                if memory_allocated > 2.0:  # 2GB limit
                    logger.info(
                        f"Memory usage above 2GB threshold: {memory_allocated:.2f}GB. "
                        f"Clearing cache"
                    )
                    torch.cuda.empty_cache()
                    import gc
                    gc.collect()

            # Get reference style with proper device placement
            ref_s = voicepack[len(tokens)].clone().to(device)

            # Generate audio
            audio = forward(cls._instance, tokens, ref_s, speed)

            return audio

        except RuntimeError as e:
            if "out of memory" in str(e):
                # On OOM, do a full cleanup and retry
                if torch.cuda.is_available():
                    logger.warning("Out of memory detected, performing full cleanup")
                    torch.cuda.synchronize()
                    torch.cuda.empty_cache()
                    import gc
                    gc.collect()

                    # Log memory stats after cleanup
                    memory_allocated = torch.cuda.memory_allocated(device)
                    memory_reserved = torch.cuda.memory_reserved(device)
                    logger.info(
                        f"Memory after OOM cleanup: "
                        f"Allocated: {memory_allocated / 1e9:.2f}GB, "
                        f"Reserved: {memory_reserved / 1e9:.2f}GB"
                    )

                    # Retry generation
                    ref_s = voicepack[len(tokens)].clone().to(device)
                    audio = forward(cls._instance, tokens, ref_s, speed)
                    return audio
            raise

        finally:
            # Only synchronize at the top level, no empty_cache
            if torch.cuda.is_available():
                torch.cuda.synchronize()
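As a quick worked check of the reworked length_to_mask above (True marks padded positions past each sequence length), here is a self-contained snippet that reproduces the helper and prints its output for two lengths; the example inputs are arbitrary.

    import torch

    def length_to_mask(lengths):
        # Same logic as the function above: positions at or beyond each length are masked.
        max_len = lengths.max()
        mask = torch.arange(max_len, device=lengths.device)[None, :].expand(
            lengths.shape[0], -1
        )
        if lengths.dtype != mask.dtype:
            mask = mask.to(dtype=lengths.dtype)
        return mask + 1 > lengths[:, None]

    lengths = torch.tensor([3, 5])
    print(length_to_mask(lengths))
    # tensor([[False, False, False,  True,  True],
    #         [False, False, False, False, False]])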
@@ -1,8 +0,0 @@
import torch

if torch.cuda.is_available():
    from .tts_gpu import TTSGPUModel as TTSModel
else:
    from .tts_cpu import TTSCPUModel as TTSModel

__all__ = ["TTSModel"]
@@ -1,120 +1,114 @@
+"""TTS service using model and voice managers."""
+
 import io
 import os
-import re
 import time
-from functools import lru_cache
-from typing import List, Optional, Tuple
+from typing import List, Tuple
 
-import aiofiles.os
+import torch
 import numpy as np
 import scipy.io.wavfile as wavfile
-import torch
 from loguru import logger
 
 from ..core.config import settings
+from ..inference.model_manager import get_manager as get_model_manager
+from ..inference.voice_manager import get_manager as get_voice_manager
 from .audio import AudioNormalizer, AudioService
 from .text_processing import chunker, normalize_text
-from .tts_model import TTSModel
 
 
 class TTSService:
+    """Text-to-speech service."""
+
     def __init__(self, output_dir: str = None):
+        """Initialize service.
+
+        Args:
+            output_dir: Optional output directory for saving audio
+        """
         self.output_dir = output_dir
-        self.model = TTSModel.get_instance()
+        self.model_manager = get_model_manager()
+        self.voice_manager = get_voice_manager()
+        self._initialized = False
+        self._initialization_error = None
 
-    @staticmethod
-    @lru_cache(maxsize=3)  # Cache up to 3 most recently used voices
-    def _load_voice(voice_path: str) -> torch.Tensor:
-        """Load and cache a voice model"""
-        return torch.load(
-            voice_path, map_location=TTSModel.get_device(), weights_only=True
-        )
-
-    def _get_voice_path(self, voice_name: str) -> Optional[str]:
-        """Get the path to a voice file"""
-        voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice_name}.pt")
-        return voice_path if os.path.exists(voice_path) else None
+    async def ensure_initialized(self):
+        """Ensure model is initialized."""
+        if self._initialized:
+            return
+        if self._initialization_error:
+            raise self._initialization_error
+
+        try:
+            # Determine model path based on hardware
+            if settings.use_gpu and torch.cuda.is_available():
+                model_path = os.path.join(settings.model_dir, settings.pytorch_model_path)
+                backend_type = 'pytorch_gpu'
+            else:
+                model_path = os.path.join(settings.model_dir, settings.onnx_model_path)
+                backend_type = 'onnx_cpu'
+
+            # Initialize model
+            await self.model_manager.load_model(model_path, backend_type)
+            logger.info(f"Initialized model on {backend_type} backend")
+            self._initialized = True
+
+        except Exception as e:
+            logger.error(f"Failed to initialize model: {e}")
+            self._initialization_error = RuntimeError(f"Model initialization failed: {e}")
+            raise self._initialization_error
 
-    def _generate_audio(
-        self, text: str, voice: str, speed: float, stitch_long_output: bool = True
-    ) -> Tuple[torch.Tensor, float]:
-        """Generate complete audio and return with processing time"""
-        audio, processing_time = self._generate_audio_internal(
-            text, voice, speed, stitch_long_output
-        )
-        return audio, processing_time
-
-    def _generate_audio_internal(
-        self, text: str, voice: str, speed: float, stitch_long_output: bool = True
-    ) -> Tuple[torch.Tensor, float]:
-        """Generate audio and measure processing time"""
+    async def generate_audio(
+        self, text: str, voice: str, speed: float = 1.0
+    ) -> Tuple[np.ndarray, float]:
+        """Generate audio for text.
+
+        Args:
+            text: Input text
+            voice: Voice name
+            speed: Speed multiplier
+
+        Returns:
+            Audio samples and processing time
+        """
+        await self.ensure_initialized()
         start_time = time.time()
 
         try:
-            # Normalize text once at the start
-            if not text:
-                raise ValueError("Text is empty after preprocessing")
+            # Normalize text
             normalized = normalize_text(text)
             if not normalized:
                 raise ValueError("Text is empty after preprocessing")
             text = str(normalized)
 
-            # Check voice exists
-            voice_path = self._get_voice_path(voice)
-            if not voice_path:
-                raise ValueError(f"Voice not found: {voice}")
-
-            # Load voice using cached loader
-            voicepack = self._load_voice(voice_path)
-
-            # For non-streaming, preprocess all chunks first
-            if stitch_long_output:
-                # Preprocess all chunks to phonemes/tokens
-                chunks_data = []
-                for chunk in chunker.split_text(text):
-                    try:
-                        phonemes, tokens = TTSModel.process_text(chunk, voice[0])
-                        chunks_data.append((chunk, tokens))
-                    except Exception as e:
-                        logger.error(
-                            f"Failed to process chunk: '{chunk}'. Error: {str(e)}"
-                        )
-                        continue
-
-                if not chunks_data:
-                    raise ValueError("No chunks were processed successfully")
-
-                # Generate audio for all chunks
-                audio_chunks = []
-                for chunk, tokens in chunks_data:
-                    try:
-                        chunk_audio = TTSModel.generate_from_tokens(
-                            tokens, voicepack, speed
-                        )
-                        if chunk_audio is not None:
-                            audio_chunks.append(chunk_audio)
-                        else:
-                            logger.error(f"No audio generated for chunk: '{chunk}'")
-                    except Exception as e:
-                        logger.error(
-                            f"Failed to generate audio for chunk: '{chunk}'. Error: {str(e)}"
-                        )
-                        continue
-
-                if not audio_chunks:
-                    raise ValueError("No audio chunks were generated successfully")
-
-                # Concatenate all chunks
-                audio = (
-                    np.concatenate(audio_chunks)
-                    if len(audio_chunks) > 1
-                    else audio_chunks[0]
-                )
-            else:
-                # Process single chunk
-                phonemes, tokens = TTSModel.process_text(text, voice[0])
-                audio = TTSModel.generate_from_tokens(tokens, voicepack, speed)
-
+            # Process text into chunks
+            audio_chunks = []
+            for chunk in chunker.split_text(text):
+                try:
+                    # Process text
+                    sequences = process_text(chunk)
+                    if not sequences:
+                        continue
+
+                    # Generate audio
+                    chunk_audio = await self.model_manager.generate(
+                        sequences[0],
+                        voice,
+                        speed=speed
+                    )
+                    if chunk_audio is not None:
+                        audio_chunks.append(chunk_audio)
+                except Exception as e:
+                    logger.error(f"Failed to process chunk: '{chunk}'. Error: {str(e)}")
+                    continue
+
+            if not audio_chunks:
+                raise ValueError("No audio chunks were generated successfully")
+
+            # Combine chunks
+            audio = np.concatenate(audio_chunks) if len(audio_chunks) > 1 else audio_chunks[0]
             processing_time = time.time() - start_time
             return audio, processing_time
@@ -126,144 +120,103 @@ class TTSService:
         self,
         text: str,
         voice: str,
-        speed: float,
+        speed: float = 1.0,
         output_format: str = "wav",
-        silent=False,
     ):
-        """Generate and yield audio chunks as they're generated for real-time streaming"""
+        """Generate and stream audio chunks.
+
+        Args:
+            text: Input text
+            voice: Voice name
+            speed: Speed multiplier
+            output_format: Output audio format
+
+        Yields:
+            Audio chunks as bytes
+        """
+        await self.ensure_initialized()
+
         try:
-            stream_start = time.time()
-            # Create normalizer for consistent audio levels
+            # Setup audio processing
            stream_normalizer = AudioNormalizer()
 
-            # Input validation and preprocessing
-            if not text:
-                raise ValueError("Text is empty")
-            preprocess_start = time.time()
+            # Normalize text
             normalized = normalize_text(text)
             if not normalized:
                 raise ValueError("Text is empty after preprocessing")
             text = str(normalized)
-            logger.debug(
-                f"Text preprocessing took: {(time.time() - preprocess_start)*1000:.1f}ms"
-            )
 
-            # Voice validation and loading
-            voice_start = time.time()
-            voice_path = self._get_voice_path(voice)
-            if not voice_path:
-                raise ValueError(f"Voice not found: {voice}")
-            voicepack = self._load_voice(voice_path)
-            logger.debug(
-                f"Voice loading took: {(time.time() - voice_start)*1000:.1f}ms"
-            )
-
-            # Process chunks as they're generated
+            # Process chunks
             is_first = True
-            chunks_processed = 0
-
-            # Process chunks as they come from generator
             chunk_gen = chunker.split_text(text)
             current_chunk = next(chunk_gen, None)
 
             while current_chunk is not None:
-                next_chunk = next(chunk_gen, None)  # Peek at next chunk
-                chunks_processed += 1
+                next_chunk = next(chunk_gen, None)
                 try:
-                    # Process text and generate audio
-                    phonemes, tokens = TTSModel.process_text(current_chunk, voice[0])
-                    chunk_audio = TTSModel.generate_from_tokens(
-                        tokens, voicepack, speed
-                    )
-
-                    if chunk_audio is not None:
-                        # Convert chunk with proper streaming header handling
-                        chunk_bytes = AudioService.convert_audio(
-                            chunk_audio,
-                            24000,
-                            output_format,
-                            is_first_chunk=is_first,
-                            normalizer=stream_normalizer,
-                            is_last_chunk=(next_chunk is None),  # Last if no next chunk
-                            stream=True  # Ensure proper streaming format handling
-                        )
-
-                        yield chunk_bytes
-                        is_first = False
-                    else:
-                        logger.error(f"No audio generated for chunk: '{current_chunk}'")
+                    # Process text
+                    from ..text_processing import process_text
+                    sequences = process_text(current_chunk)
+                    if sequences:
+                        # Generate audio
+                        chunk_audio = await self.model_manager.generate(
+                            sequences[0],
+                            voice,
+                            speed=speed
+                        )
+
+                        if chunk_audio is not None:
+                            # Convert to bytes
+                            chunk_bytes = AudioService.convert_audio(
+                                chunk_audio,
+                                24000,
+                                output_format,
+                                is_first_chunk=is_first,
+                                normalizer=stream_normalizer,
+                                is_last_chunk=(next_chunk is None),
+                                stream=True
+                            )
+                            yield chunk_bytes
+                            is_first = False
 
                 except Exception as e:
-                    logger.error(
-                        f"Failed to generate audio for chunk: '{current_chunk}'. Error: {str(e)}"
-                    )
+                    logger.error(f"Failed to generate audio for chunk: '{current_chunk}'. Error: {str(e)}")
 
-                current_chunk = next_chunk  # Move to next chunk
+                current_chunk = next_chunk
 
         except Exception as e:
             logger.error(f"Error in audio generation stream: {str(e)}")
             raise
 
-    def _save_audio(self, audio: torch.Tensor, filepath: str):
-        """Save audio to file"""
-        os.makedirs(os.path.dirname(filepath), exist_ok=True)
-        wavfile.write(filepath, 24000, audio)
-
-    def _audio_to_bytes(self, audio: torch.Tensor) -> bytes:
-        """Convert audio tensor to WAV bytes"""
-        buffer = io.BytesIO()
-        wavfile.write(buffer, 24000, audio)
-        return buffer.getvalue()
-
     async def combine_voices(self, voices: List[str]) -> str:
-        """Combine multiple voices into a new voice"""
-        if len(voices) < 2:
-            raise ValueError("At least 2 voices are required for combination")
-
-        # Load voices
-        t_voices: List[torch.Tensor] = []
-        v_name: List[str] = []
-
-        for voice in voices:
-            try:
-                voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice}.pt")
-                voicepack = torch.load(
-                    voice_path, map_location=TTSModel.get_device(), weights_only=True
-                )
-                t_voices.append(voicepack)
-                v_name.append(voice)
-            except Exception as e:
-                raise ValueError(f"Failed to load voice {voice}: {str(e)}")
-
-        # Combine voices
-        try:
-            f: str = "_".join(v_name)
-            v = torch.mean(torch.stack(t_voices), dim=0)
-            combined_path = os.path.join(TTSModel.VOICES_DIR, f"{f}.pt")
-
-            # Save combined voice
-            try:
-                torch.save(v, combined_path)
-            except Exception as e:
-                raise RuntimeError(
-                    f"Failed to save combined voice to {combined_path}: {str(e)}"
-                )
-
-            return f
-
-        except Exception as e:
-            if not isinstance(e, (ValueError, RuntimeError)):
-                raise RuntimeError(f"Error combining voices: {str(e)}")
-            raise
+        """Combine multiple voices.
+
+        Args:
+            voices: List of voice names
+
+        Returns:
+            Name of combined voice
+        """
+        await self.ensure_initialized()
+        return await self.voice_manager.combine_voices(voices)
 
     async def list_voices(self) -> List[str]:
-        """List all available voices"""
-        voices = []
-        try:
-            it = await aiofiles.os.scandir(TTSModel.VOICES_DIR)
-            for entry in it:
-                if entry.name.endswith(".pt"):
-                    voices.append(entry.name[:-3])  # Remove .pt extension
-        except Exception as e:
-            logger.error(f"Error listing voices: {str(e)}")
-        return sorted(voices)
+        """List available voices.
+
+        Returns:
+            List of voice names
+        """
+        return await self.voice_manager.list_voices()
+
+    def _audio_to_bytes(self, audio: np.ndarray) -> bytes:
+        """Convert audio to WAV bytes.
+
+        Args:
+            audio: Audio samples
+
+        Returns:
+            WAV bytes
+        """
+        buffer = io.BytesIO()
+        wavfile.write(buffer, 24000, audio)
+        return buffer.getvalue()
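For context, a hypothetical caller of the refactored service might look like the following; the import path is assumed from the repo layout, the voice name and output path are placeholders, and error handling is omitted for brevity.

    import asyncio

    import scipy.io.wavfile as wavfile

    from api.src.services.tts_service import TTSService  # path assumed from repo layout


    async def main():
        service = TTSService(output_dir="output")
        audio, seconds = await service.generate_audio(
            "Hello from the refactored backend.", voice="af_bella", speed=1.0
        )
        print(f"Generated {len(audio)} samples in {seconds:.2f}s")
        wavfile.write("output/hello.wav", 24000, audio)


    asyncio.run(main())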
26
api/src/structures/model_schemas.py
Normal file
@@ -0,0 +1,26 @@
"""Model and voice configuration schemas."""

from pydantic import BaseModel


class ModelConfig(BaseModel):
    """Model configuration."""
    optimization_level: str = "all"  # all, basic, none
    num_threads: int = 4
    inter_op_threads: int = 4
    execution_mode: str = "parallel"  # parallel, sequential
    memory_pattern: bool = True
    arena_extend_strategy: str = "kNextPowerOfTwo"

    class Config:
        frozen = True  # Make config immutable


class VoiceConfig(BaseModel):
    """Voice configuration."""
    use_cache: bool = True
    cache_size: int = 3  # Number of voices to cache
    validate_on_load: bool = True  # Whether to validate voices when loading

    class Config:
        frozen = True  # Make config immutable
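As a rough illustration of what the frozen flag buys (assuming the pydantic class-Config style written above), constructing a ModelConfig works normally, while mutating it afterwards is rejected; the exact exception type depends on the installed pydantic version, so the snippet catches broadly.

    from pydantic import BaseModel


    class ModelConfig(BaseModel):
        """Reproduced from the schema above, trimmed to two fields."""
        optimization_level: str = "all"
        num_threads: int = 4

        class Config:
            frozen = True  # instances are immutable (and hashable)


    config = ModelConfig(num_threads=8)
    print(config.optimization_level, config.num_threads)  # -> all 8

    try:
        config.num_threads = 2  # rejected because the model is frozen
    except Exception as e:
        print(f"Rejected: {e}")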