Refactor inference architecture: remove legacy TTS model, add ONNX and PyTorch backends, and introduce model configuration schemas

remsky 2025-01-20 22:42:29 -07:00
parent 83c55ca735
commit ab28a62e86
16 changed files with 1606 additions and 813 deletions
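
The refactor replaces the singleton TTSModel classes with a small inference package: a ModelManager that owns per-device ONNX and PyTorch backends, a VoiceManager for voice packs, and Pydantic schemas for their configuration. A minimal sketch of the resulting call flow, assuming a top-level api.src import path, an illustrative model filename, and the process_text helper referenced by the manager's warmup code:

# Sketch only: exercises the manager API added in this commit.
# The import paths, the "model.onnx" filename, and the "af" voice name are assumptions.
import asyncio

from api.src.inference import get_manager
from api.src.text_processing import process_text  # assumed location, per model_manager's relative import

async def main() -> None:
    manager = get_manager()                      # global ModelManager; backends initialize on construction
    await manager.load_model("model.onnx")       # backend picked from the file extension (ONNX vs PyTorch)
    tokens = process_text("Hello world.")[0]     # token IDs for the first chunk
    audio = await manager.generate(tokens, "af", speed=1.0)
    print(len(audio))

asyncio.run(main())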

3
.gitignore vendored

@@ -18,7 +18,8 @@ __pycache__/
*.egg
dist/
build/
*.onnx
*.pth
# Environment
# .env
.venv/

198
api/src/core/paths.py Normal file

@@ -0,0 +1,198 @@
"""Async file and path operations."""
import io
import os
from pathlib import Path
from typing import List, Optional, AsyncIterator, Callable, Set
import aiofiles
import aiofiles.os
import torch
from loguru import logger
from .config import settings
async def _find_file(
filename: str,
search_paths: List[str],
filter_fn: Optional[Callable[[str], bool]] = None
) -> str:
"""Find file in search paths.
Args:
filename: Name of file to find
search_paths: List of paths to search in
filter_fn: Optional function to filter files
Returns:
Absolute path to file
Raises:
RuntimeError: If file not found
"""
if os.path.isabs(filename) and await aiofiles.os.path.exists(filename):
return filename
for path in search_paths:
full_path = os.path.join(path, filename)
if await aiofiles.os.path.exists(full_path):
if filter_fn is None or filter_fn(full_path):
return full_path
raise RuntimeError(f"File not found: {filename} in paths: {search_paths}")
async def _scan_directories(
search_paths: List[str],
filter_fn: Optional[Callable[[str], bool]] = None
) -> Set[str]:
"""Scan directories for files.
Args:
search_paths: List of paths to scan
filter_fn: Optional function to filter files
Returns:
Set of matching filenames
"""
results = set()
for path in search_paths:
if not await aiofiles.os.path.exists(path):
continue
try:
# Get directory entries first
entries = await aiofiles.os.scandir(path)
# Then process entries after await completes
for entry in entries:
if filter_fn is None or filter_fn(entry.name):
results.add(entry.name)
except Exception as e:
logger.warning(f"Error scanning {path}: {e}")
return results
async def get_model_path(model_name: str) -> str:
"""Get path to model file.
Args:
model_name: Name of model file
Returns:
Absolute path to model file
Raises:
RuntimeError: If model not found
"""
search_paths = [
settings.model_dir,
os.path.join(os.path.dirname(__file__), "..", "..", "..", "models")
]
return await _find_file(model_name, search_paths)
async def get_voice_path(voice_name: str) -> str:
"""Get path to voice file.
Args:
voice_name: Name of voice file (without .pt extension)
Returns:
Absolute path to voice file
Raises:
RuntimeError: If voice not found
"""
voice_file = f"{voice_name}.pt"
search_paths = [
os.path.join(settings.model_dir, "..", settings.voices_dir),
os.path.join(os.path.dirname(__file__), "..", settings.voices_dir)
]
return await _find_file(voice_file, search_paths)
async def list_voices() -> List[str]:
"""List available voice files.
Returns:
List of voice names (without .pt extension)
"""
search_paths = [
os.path.join(settings.model_dir, "..", settings.voices_dir),
os.path.join(os.path.dirname(__file__), "..", settings.voices_dir)
]
def filter_voice_files(name: str) -> bool:
return name.endswith('.pt')
voices = await _scan_directories(search_paths, filter_voice_files)
return sorted([name[:-3] for name in voices]) # Remove .pt extension
async def load_voice_tensor(voice_path: str, device: str = "cpu") -> torch.Tensor:
"""Load voice tensor from file.
Args:
voice_path: Path to voice file
device: Device to load tensor to
Returns:
Voice tensor
Raises:
RuntimeError: If file cannot be read
"""
try:
async with aiofiles.open(voice_path, 'rb') as f:
data = await f.read()
return torch.load(
io.BytesIO(data),
map_location=device,
weights_only=True
)
except Exception as e:
raise RuntimeError(f"Failed to load voice tensor from {voice_path}: {e}")
async def save_voice_tensor(tensor: torch.Tensor, voice_path: str) -> None:
"""Save voice tensor to file.
Args:
tensor: Voice tensor to save
voice_path: Path to save voice file
Raises:
RuntimeError: If file cannot be written
"""
try:
buffer = io.BytesIO()
torch.save(tensor, buffer)
async with aiofiles.open(voice_path, 'wb') as f:
await f.write(buffer.getvalue())
except Exception as e:
raise RuntimeError(f"Failed to save voice tensor to {voice_path}: {e}")
async def read_file(path: str) -> str:
"""Read text file asynchronously.
Args:
path: Path to file
Returns:
File contents as string
Raises:
RuntimeError: If file cannot be read
"""
try:
async with aiofiles.open(path, 'r', encoding='utf-8') as f:
return await f.read()
except Exception as e:
raise RuntimeError(f"Failed to read file {path}: {e}")

20
api/src/inference/__init__.py Normal file

@@ -0,0 +1,20 @@
"""Inference backends and model management."""
from .base import BaseModelBackend
from .model_manager import ModelManager, get_manager
from .onnx_cpu import ONNXCPUBackend
from .onnx_gpu import ONNXGPUBackend
from .pytorch_cpu import PyTorchCPUBackend
from .pytorch_gpu import PyTorchGPUBackend
from ..structures.model_schemas import ModelConfig
__all__ = [
'BaseModelBackend',
'ModelManager',
'get_manager',
'ModelConfig',
'ONNXCPUBackend',
'ONNXGPUBackend',
'PyTorchCPUBackend',
'PyTorchGPUBackend'
]

97
api/src/inference/base.py Normal file

@@ -0,0 +1,97 @@
"""Base interfaces for model inference."""
from abc import ABC, abstractmethod
from typing import List, Optional
import numpy as np
import torch
class ModelBackend(ABC):
"""Abstract base class for model inference backends."""
@abstractmethod
async def load_model(self, path: str) -> None:
"""Load model from path.
Args:
path: Path to model file
Raises:
RuntimeError: If model loading fails
"""
pass
@abstractmethod
def generate(
self,
tokens: List[int],
voice: torch.Tensor,
speed: float = 1.0
) -> np.ndarray:
"""Generate audio from tokens.
Args:
tokens: Input token IDs
voice: Voice embedding tensor
speed: Speed multiplier
Returns:
Generated audio samples
Raises:
RuntimeError: If generation fails
"""
pass
@abstractmethod
def unload(self) -> None:
"""Unload model and free resources."""
pass
@property
@abstractmethod
def is_loaded(self) -> bool:
"""Check if model is loaded.
Returns:
True if model is loaded, False otherwise
"""
pass
@property
@abstractmethod
def device(self) -> str:
"""Get device model is running on.
Returns:
Device string ('cpu' or 'cuda')
"""
pass
class BaseModelBackend(ModelBackend):
"""Base implementation of model backend."""
def __init__(self):
"""Initialize base backend."""
self._model: Optional[torch.nn.Module] = None
self._device: str = "cpu"
@property
def is_loaded(self) -> bool:
"""Check if model is loaded."""
return self._model is not None
@property
def device(self) -> str:
"""Get device model is running on."""
return self._device
def unload(self) -> None:
"""Unload model and free resources."""
if self._model is not None:
del self._model
self._model = None
if torch.cuda.is_available():
torch.cuda.empty_cache()

251
api/src/inference/model_manager.py Normal file

@@ -0,0 +1,251 @@
"""Model management and caching."""
import os
from typing import Dict, List, Optional, Union
import torch
from loguru import logger
from pydantic import BaseModel
from .base import BaseModelBackend
from .voice_manager import get_manager as get_voice_manager
from .onnx_cpu import ONNXCPUBackend
from .onnx_gpu import ONNXGPUBackend
from .pytorch_cpu import PyTorchCPUBackend
from .pytorch_gpu import PyTorchGPUBackend
from ..core import paths
from ..core.config import settings
from ..structures.model_schemas import ModelConfig
class ModelManager:
"""Manages model loading and inference across backends."""
def __init__(self, config: Optional[ModelConfig] = None):
"""Initialize model manager.
Args:
config: Optional configuration
"""
self._config = config or ModelConfig()
self._backends: Dict[str, BaseModelBackend] = {}
self._current_backend: Optional[str] = None
self._voice_manager = get_voice_manager()
self._initialize_backends()
def _initialize_backends(self) -> None:
"""Initialize available backends."""
"""Initialize available backends."""
# Initialize GPU backends if available
if settings.use_gpu and torch.cuda.is_available():
try:
# PyTorch GPU
self._backends['pytorch_gpu'] = PyTorchGPUBackend()
self._current_backend = 'pytorch_gpu'
logger.info("Initialized PyTorch GPU backend")
# ONNX GPU
self._backends['onnx_gpu'] = ONNXGPUBackend()
logger.info("Initialized ONNX GPU backend")
except Exception as e:
logger.error(f"Failed to initialize GPU backends: {e}")
# Fallback to CPU if GPU fails
self._initialize_cpu_backends()
else:
self._initialize_cpu_backends()
def _initialize_cpu_backends(self) -> None:
"""Initialize CPU backends."""
try:
# PyTorch CPU
self._backends['pytorch_cpu'] = PyTorchCPUBackend()
self._current_backend = 'pytorch_cpu'
logger.info("Initialized PyTorch CPU backend")
# ONNX CPU
self._backends['onnx_cpu'] = ONNXCPUBackend()
logger.info("Initialized ONNX CPU backend")
except Exception as e:
logger.error(f"Failed to initialize CPU backends: {e}")
raise RuntimeError("No backends available")
def get_backend(self, backend_type: Optional[str] = None) -> BaseModelBackend:
"""Get specified backend.
Args:
backend_type: Backend type ('pytorch_cpu', 'pytorch_gpu', 'onnx_cpu', 'onnx_gpu'),
uses default if None
Returns:
Model backend instance
Raises:
ValueError: If backend type is invalid
RuntimeError: If no backends are available
"""
if not self._backends:
raise RuntimeError("No backends available")
if backend_type is None:
backend_type = self._current_backend
if backend_type not in self._backends:
raise ValueError(
f"Invalid backend type: {backend_type}. "
f"Available backends: {', '.join(self._backends.keys())}"
)
return self._backends[backend_type]
def _determine_backend(self, model_path: str) -> str:
"""Determine appropriate backend based on model file.
Args:
model_path: Path to model file
Returns:
Backend type to use
"""
is_onnx = model_path.lower().endswith('.onnx')
has_gpu = settings.use_gpu and torch.cuda.is_available()
if is_onnx:
return 'onnx_gpu' if has_gpu else 'onnx_cpu'
else:
return 'pytorch_gpu' if has_gpu else 'pytorch_cpu'
async def load_model(
self,
model_path: str,
backend_type: Optional[str] = None
) -> None:
"""Load model on specified backend.
Args:
model_path: Path to model file
backend_type: Backend to load on, uses default if None
Raises:
RuntimeError: If model loading fails
"""
try:
# Get absolute model path
abs_path = await paths.get_model_path(model_path)
# Auto-determine backend if not specified
if backend_type is None:
backend_type = self._determine_backend(abs_path)
backend = self.get_backend(backend_type)
# Load model and run warmup
await backend.load_model(abs_path)
logger.info(f"Loaded model on {backend_type} backend")
await self._warmup_inference(backend)
except Exception as e:
raise RuntimeError(f"Failed to load model: {e}")
async def _warmup_inference(self, backend: BaseModelBackend) -> None:
"""Run warmup inference to initialize model."""
try:
# Import here to avoid circular imports
from ..text_processing import process_text
# Load default voice for warmup
voice = await self._voice_manager.load_voice(settings.default_voice, device=backend.device)
logger.info(f"Loaded voice {settings.default_voice} for warmup")
# Use real text
text = "Testing text to speech synthesis."
logger.info(f"Running warmup inference with voice: af")
# Process through pipeline
sequences = process_text(text)
if not sequences:
raise ValueError("Text processing failed")
# Run inference
backend.generate(sequences[0], voice, speed=1.0)
except Exception as e:
logger.warning(f"Warmup inference failed: {e}")
raise
async def generate(
self,
tokens: list[int],
voice_name: str,
speed: float = 1.0,
backend_type: Optional[str] = None
) -> torch.Tensor:
"""Generate audio using specified backend.
Args:
tokens: Input token IDs
voice_name: Name of voice to use
speed: Speed multiplier
backend_type: Backend to use, uses default if None
Returns:
Generated audio tensor
Raises:
RuntimeError: If generation fails
"""
backend = self.get_backend(backend_type)
if not backend.is_loaded:
raise RuntimeError("Model not loaded")
try:
# Load voice using voice manager
voice = await self._voice_manager.load_voice(voice_name, device=backend.device)
# Generate audio
return backend.generate(tokens, voice, speed)
except Exception as e:
raise RuntimeError(f"Generation failed: {e}")
def unload_all(self) -> None:
"""Unload models from all backends."""
for backend in self._backends.values():
backend.unload()
logger.info("Unloaded all models")
@property
def available_backends(self) -> list[str]:
"""Get list of available backends.
Returns:
List of backend names
"""
return list(self._backends.keys())
@property
def current_backend(self) -> str:
"""Get current default backend.
Returns:
Backend name
"""
return self._current_backend
# Module-level instance
_manager: Optional[ModelManager] = None
def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
"""Get or create global model manager instance.
Args:
config: Optional model configuration
Returns:
ModelManager instance
"""
global _manager
if _manager is None:
_manager = ModelManager(config)
return _manager

154
api/src/inference/onnx_cpu.py Normal file

@@ -0,0 +1,154 @@
"""CPU-based ONNX inference backend."""
from typing import Dict, Optional
import numpy as np
import torch
from loguru import logger
from onnxruntime import (
ExecutionMode,
GraphOptimizationLevel,
InferenceSession,
SessionOptions
)
from ..core import paths
from ..core.config import settings
from ..structures.model_schemas import ONNXConfig
from .base import BaseModelBackend
class ONNXCPUBackend(BaseModelBackend):
"""ONNX-based CPU inference backend."""
def __init__(self):
"""Initialize CPU backend."""
super().__init__()
self._device = "cpu"
self._session: Optional[InferenceSession] = None
self._config = ONNXConfig(
optimization_level=settings.onnx_optimization_level,
num_threads=settings.onnx_num_threads,
inter_op_threads=settings.onnx_inter_op_threads,
execution_mode=settings.onnx_execution_mode,
memory_pattern=settings.onnx_memory_pattern,
arena_extend_strategy=settings.onnx_arena_extend_strategy
)
async def load_model(self, path: str) -> None:
"""Load ONNX model.
Args:
path: Path to model file
Raises:
RuntimeError: If model loading fails
"""
try:
# Get verified model path
model_path = await paths.get_model_path(path)
logger.info(f"Loading ONNX model: {model_path}")
# Configure session
options = self._create_session_options()
provider_options = self._create_provider_options()
# Create session
self._session = InferenceSession(
model_path,
sess_options=options,
providers=["CPUExecutionProvider"],
provider_options=[provider_options]
)
except Exception as e:
raise RuntimeError(f"Failed to load ONNX model: {e}")
def generate(
self,
tokens: list[int],
voice: torch.Tensor,
speed: float = 1.0
) -> np.ndarray:
"""Generate audio using ONNX model.
Args:
tokens: Input token IDs
voice: Voice embedding tensor
speed: Speed multiplier
Returns:
Generated audio samples
Raises:
RuntimeError: If generation fails
"""
if not self.is_loaded:
raise RuntimeError("Model not loaded")
try:
# Prepare inputs
tokens_input = np.array([tokens], dtype=np.int64)
style_input = voice[len(tokens)].numpy()
speed_input = np.full(1, speed, dtype=np.float32)
# Run inference
result = self._session.run(
None,
{
"tokens": tokens_input,
"style": style_input,
"speed": speed_input
}
)
return result[0]
except Exception as e:
raise RuntimeError(f"Generation failed: {e}")
def _create_session_options(self) -> SessionOptions:
"""Create ONNX session options.
Returns:
Configured session options
"""
options = SessionOptions()
# Set optimization level
if self._config.optimization_level == "all":
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
elif self._config.optimization_level == "basic":
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
else:
options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
# Configure threading
options.intra_op_num_threads = self._config.num_threads
options.inter_op_num_threads = self._config.inter_op_threads
# Set execution mode
options.execution_mode = (
ExecutionMode.ORT_PARALLEL
if self._config.execution_mode == "parallel"
else ExecutionMode.ORT_SEQUENTIAL
)
# Configure memory optimization
options.enable_mem_pattern = self._config.memory_pattern
return options
def _create_provider_options(self) -> Dict:
"""Create CPU provider options.
Returns:
Provider configuration
"""
return {
"CPUExecutionProvider": {
"arena_extend_strategy": self._config.arena_extend_strategy,
"cpu_memory_arena_cfg": "cpu:0"
}
}

163
api/src/inference/onnx_gpu.py Normal file

@@ -0,0 +1,163 @@
"""GPU-based ONNX inference backend."""
from typing import Dict, Optional
import numpy as np
import torch
from loguru import logger
from onnxruntime import (
ExecutionMode,
GraphOptimizationLevel,
InferenceSession,
SessionOptions
)
from ..core import paths
from ..core.config import settings
from ..structures.model_schemas import ONNXGPUConfig
from .base import BaseModelBackend
class ONNXGPUBackend(BaseModelBackend):
"""ONNX-based GPU inference backend."""
def __init__(self):
"""Initialize GPU backend."""
super().__init__()
if not torch.cuda.is_available():
raise RuntimeError("CUDA not available")
self._device = "cuda"
self._session: Optional[InferenceSession] = None
self._config = ONNXGPUConfig(
optimization_level=settings.onnx_optimization_level,
num_threads=settings.onnx_num_threads,
inter_op_threads=settings.onnx_inter_op_threads,
execution_mode=settings.onnx_execution_mode,
memory_pattern=settings.onnx_memory_pattern,
arena_extend_strategy=settings.onnx_arena_extend_strategy,
device_id=0,
gpu_mem_limit=0.7,
cudnn_conv_algo_search="EXHAUSTIVE",
do_copy_in_default_stream=True
)
async def load_model(self, path: str) -> None:
"""Load ONNX model.
Args:
path: Path to model file
Raises:
RuntimeError: If model loading fails
"""
try:
# Get verified model path
model_path = await paths.get_model_path(path)
logger.info(f"Loading ONNX model on GPU: {model_path}")
# Configure session
options = self._create_session_options()
provider_options = self._create_provider_options()
# Create session with CUDA provider
self._session = InferenceSession(
model_path,
sess_options=options,
providers=["CUDAExecutionProvider"],
provider_options=[provider_options]
)
except Exception as e:
raise RuntimeError(f"Failed to load ONNX model: {e}")
def generate(
self,
tokens: list[int],
voice: torch.Tensor,
speed: float = 1.0
) -> np.ndarray:
"""Generate audio using ONNX model.
Args:
tokens: Input token IDs
voice: Voice embedding tensor
speed: Speed multiplier
Returns:
Generated audio samples
Raises:
RuntimeError: If generation fails
"""
if not self.is_loaded:
raise RuntimeError("Model not loaded")
try:
# Prepare inputs
tokens_input = np.array([tokens], dtype=np.int64)
style_input = voice[len(tokens)].cpu().numpy() # Move to CPU for ONNX
speed_input = np.full(1, speed, dtype=np.float32)
# Run inference
result = self._session.run(
None,
{
"tokens": tokens_input,
"style": style_input,
"speed": speed_input
}
)
return result[0]
except Exception as e:
raise RuntimeError(f"Generation failed: {e}")
def _create_session_options(self) -> SessionOptions:
"""Create ONNX session options.
Returns:
Configured session options
"""
options = SessionOptions()
# Set optimization level
if self._config.optimization_level == "all":
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
elif self._config.optimization_level == "basic":
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
else:
options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
# Configure threading
options.intra_op_num_threads = self._config.num_threads
options.inter_op_num_threads = self._config.inter_op_threads
# Set execution mode
options.execution_mode = (
ExecutionMode.ORT_PARALLEL
if self._config.execution_mode == "parallel"
else ExecutionMode.ORT_SEQUENTIAL
)
# Configure memory optimization
options.enable_mem_pattern = self._config.memory_pattern
return options
def _create_provider_options(self) -> Dict:
"""Create CUDA provider options.
Returns:
Provider configuration
"""
return {
"CUDAExecutionProvider": {
"device_id": self._config.device_id,
"arena_extend_strategy": self._config.arena_extend_strategy,
"gpu_mem_limit": int(self._config.gpu_mem_limit * torch.cuda.get_device_properties(0).total_memory),
"cudnn_conv_algo_search": self._config.cudnn_conv_algo_search,
"do_copy_in_default_stream": self._config.do_copy_in_default_stream
}
}

181
api/src/inference/pytorch_cpu.py Normal file

@@ -0,0 +1,181 @@
"""CPU-based PyTorch inference backend."""
import gc
from typing import Optional
import numpy as np
import torch
from loguru import logger
from ..builds.models import build_model
from ..core import paths
from ..structures.model_schemas import PyTorchCPUConfig
from .base import BaseModelBackend
@torch.no_grad()
def forward(model: torch.nn.Module, tokens: list[int], ref_s: torch.Tensor, speed: float) -> np.ndarray:
"""Forward pass through model with memory management.
Args:
model: PyTorch model
tokens: Input tokens
ref_s: Reference signal
speed: Speed multiplier
Returns:
Generated audio
"""
device = ref_s.device
pred_aln_trg = None
asr = None
try:
# Initial tensor setup
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
text_mask = length_to_mask(input_lengths).to(device)
# Split reference signals
s_content = ref_s[:, 128:].clone().to(device)
s_ref = ref_s[:, :128].clone().to(device)
# BERT and encoder pass
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
# Predictor forward pass
d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
x, _ = model.predictor.lstm(d)
# Duration prediction
duration = model.predictor.duration_proj(x)
duration = torch.sigmoid(duration).sum(axis=-1) / speed
pred_dur = torch.round(duration).clamp(min=1).long()
del duration, x # Free large intermediates
# Alignment matrix construction
pred_aln_trg = torch.zeros(
input_lengths.item(),
pred_dur.sum().item(),
device=device
)
c_frame = 0
for i in range(pred_aln_trg.size(0)):
pred_aln_trg[i, c_frame:c_frame + pred_dur[0, i].item()] = 1
c_frame += pred_dur[0, i].item()
pred_aln_trg = pred_aln_trg.unsqueeze(0)
# Matrix multiplications with cleanup
en = d.transpose(-1, -2) @ pred_aln_trg
del d
F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
del en
# Final text encoding and decoding
t_en = model.text_encoder(tokens, input_lengths, text_mask)
asr = t_en @ pred_aln_trg
del t_en
# Generate output
output = model.decoder(asr, F0_pred, N_pred, s_ref)
result = output.squeeze().cpu().numpy()
return result
finally:
# Clean up largest tensors if they were created
if pred_aln_trg is not None:
del pred_aln_trg
if asr is not None:
del asr
def length_to_mask(lengths: torch.Tensor) -> torch.Tensor:
"""Create attention mask from lengths.
Args:
lengths: Sequence lengths
Returns:
Boolean mask tensor
"""
max_len = lengths.max()
mask = torch.arange(max_len, device=lengths.device)[None, :].expand(
lengths.shape[0], -1
)
if lengths.dtype != mask.dtype:
mask = mask.to(dtype=lengths.dtype)
return mask + 1 > lengths[:, None]
class PyTorchCPUBackend(BaseModelBackend):
"""PyTorch CPU inference backend."""
def __init__(self):
"""Initialize CPU backend."""
super().__init__()
self._device = "cpu"
self._model: Optional[torch.nn.Module] = None
self._config = PyTorchCPUConfig()
# Configure PyTorch CPU settings
if self._config.num_threads > 0:
torch.set_num_threads(self._config.num_threads)
if self._config.pin_memory:
torch.set_default_tensor_type(torch.FloatTensor)
async def load_model(self, path: str) -> None:
"""Load PyTorch model.
Args:
path: Path to model file
Raises:
RuntimeError: If model loading fails
"""
try:
# Get verified model path
model_path = await paths.get_model_path(path)
logger.info(f"Loading PyTorch model on CPU: {model_path}")
self._model = await build_model(model_path, self._device)
except Exception as e:
raise RuntimeError(f"Failed to load PyTorch model: {e}")
def generate(
self,
tokens: list[int],
voice: torch.Tensor,
speed: float = 1.0
) -> np.ndarray:
"""Generate audio using CPU model.
Args:
tokens: Input token IDs
voice: Voice embedding tensor
speed: Speed multiplier
Returns:
Generated audio samples
Raises:
RuntimeError: If generation fails
"""
if not self.is_loaded:
raise RuntimeError("Model not loaded")
try:
# Prepare input
ref_s = voice[len(tokens)].clone()
# Generate audio
return forward(self._model, tokens, ref_s, speed)
except Exception as e:
raise RuntimeError(f"Generation failed: {e}")
finally:
# Clean up memory
gc.collect()

170
api/src/inference/pytorch_gpu.py Normal file

@@ -0,0 +1,170 @@
"""GPU-based PyTorch inference backend."""
import gc
from typing import Optional
import numpy as np
import torch
from loguru import logger
from ..builds.models import build_model
from ..core import paths
from ..structures.model_schemas import PyTorchConfig
from .base import BaseModelBackend
@torch.no_grad()
def forward(model: torch.nn.Module, tokens: list[int], ref_s: torch.Tensor, speed: float) -> np.ndarray:
"""Forward pass through model.
Args:
model: PyTorch model
tokens: Input tokens
ref_s: Reference signal (shape: [1, n_features])
speed: Speed multiplier
Returns:
Generated audio
"""
device = ref_s.device
# Initial tensor setup
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
text_mask = length_to_mask(input_lengths).to(device)
# Split reference signals (style_dim=128 from config)
style_dim = 128
s_ref = ref_s[:, :style_dim].clone().to(device)
s_content = ref_s[:, style_dim:].clone().to(device)
# BERT and encoder pass
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
# Predictor forward pass
d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
x, _ = model.predictor.lstm(d)
# Duration prediction
duration = model.predictor.duration_proj(x)
duration = torch.sigmoid(duration).sum(axis=-1) / speed
pred_dur = torch.round(duration).clamp(min=1).long()
del duration, x
# Alignment matrix construction
pred_aln_trg = torch.zeros(input_lengths.item(), pred_dur.sum().item(), device=device)
c_frame = 0
for i in range(pred_aln_trg.size(0)):
pred_aln_trg[i, c_frame:c_frame + pred_dur[0, i].item()] = 1
c_frame += pred_dur[0, i].item()
pred_aln_trg = pred_aln_trg.unsqueeze(0)
# Matrix multiplications
en = d.transpose(-1, -2) @ pred_aln_trg
del d
F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
del en
# Final text encoding and decoding
t_en = model.text_encoder(tokens, input_lengths, text_mask)
asr = t_en @ pred_aln_trg
del t_en
# Generate output
output = model.decoder(asr, F0_pred, N_pred, s_ref)
return output.squeeze().cpu().numpy()
def length_to_mask(lengths: torch.Tensor) -> torch.Tensor:
"""Create attention mask from lengths."""
max_len = lengths.max()
mask = torch.arange(max_len, device=lengths.device)[None, :].expand(lengths.shape[0], -1)
if lengths.dtype != mask.dtype:
mask = mask.to(dtype=lengths.dtype)
return mask + 1 > lengths[:, None]
class PyTorchGPUBackend(BaseModelBackend):
"""PyTorch GPU inference backend."""
def __init__(self):
"""Initialize GPU backend."""
super().__init__()
if not torch.cuda.is_available():
raise RuntimeError("CUDA not available")
self._device = "cuda"
self._model: Optional[torch.nn.Module] = None
self._config = PyTorchConfig()
async def load_model(self, path: str) -> None:
"""Load PyTorch model.
Args:
path: Path to model file
Raises:
RuntimeError: If model loading fails
"""
try:
# Get verified model path
model_path = await paths.get_model_path(path)
logger.info(f"Loading PyTorch model: {model_path}")
self._model = await build_model(model_path, self._device)
except Exception as e:
raise RuntimeError(f"Failed to load PyTorch model: {e}")
def generate(
self,
tokens: list[int],
voice: torch.Tensor,
speed: float = 1.0
) -> np.ndarray:
"""Generate audio using GPU model.
Args:
tokens: Input token IDs
voice: Voice embedding tensor
speed: Speed multiplier
Returns:
Generated audio samples
Raises:
RuntimeError: If generation fails
"""
if not self.is_loaded:
raise RuntimeError("Model not loaded")
try:
# Check memory and cleanup if needed
if self._check_memory():
self._clear_memory()
# Get reference style from voice pack
ref_s = voice[len(tokens)].clone().to(self._device)
if ref_s.dim() == 1:
ref_s = ref_s.unsqueeze(0) # Add batch dimension if needed
# Generate audio
return forward(self._model, tokens, ref_s, speed)
except Exception as e:
logger.error(f"Generation failed: {e}")
raise
def _check_memory(self) -> bool:
"""Check if memory usage is above threshold."""
if torch.cuda.is_available():
memory_gb = torch.cuda.memory_allocated() / 1e9
return memory_gb > self._config.memory_threshold
return False
def _clear_memory(self) -> None:
"""Clear GPU memory."""
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()

191
api/src/inference/voice_manager.py Normal file

@@ -0,0 +1,191 @@
"""Voice pack management and caching."""
import os
from typing import Dict, List, Optional, Union
import torch
from loguru import logger
from pydantic import BaseModel
from ..core import paths
from ..core.config import settings
from ..structures.model_schemas import VoiceConfig
class VoiceManager:
"""Manages voice loading, caching, and operations."""
def __init__(self, config: Optional[VoiceConfig] = None):
"""Initialize voice manager.
Args:
config: Optional voice configuration
"""
self._config = config or VoiceConfig()
self._voice_cache: Dict[str, torch.Tensor] = {}
def get_voice_path(self, voice_name: str) -> Optional[str]:
"""Get path to voice file.
Args:
voice_name: Name of voice
Returns:
Path to voice file if exists, None otherwise
"""
voice_path = os.path.join(settings.voices_dir, f"{voice_name}.pt")
return voice_path if os.path.exists(voice_path) else None
async def load_voice(self, voice_name: str, device: str = "cpu") -> torch.Tensor:
"""Load voice tensor.
Args:
voice_name: Name of voice to load
device: Device to load voice on
Returns:
Voice tensor
Raises:
RuntimeError: If voice loading fails
"""
voice_path = self.get_voice_path(voice_name)
if not voice_path:
raise RuntimeError(f"Voice not found: {voice_name}")
# Check cache first
cache_key = f"{voice_path}_{device}"
if self._config.use_cache and cache_key in self._voice_cache:
return self._voice_cache[cache_key]
try:
# Load voice tensor
voice = await paths.load_voice_tensor(voice_path, device=device)
# Cache if enabled
if self._config.use_cache:
self._manage_cache()
self._voice_cache[cache_key] = voice
logger.debug(f"Cached voice: {voice_name} on {device}")
return voice
except Exception as e:
raise RuntimeError(f"Failed to load voice {voice_name}: {e}")
def _manage_cache(self) -> None:
"""Manage voice cache size."""
if len(self._voice_cache) >= self._config.cache_size:
# Remove oldest voice
oldest = next(iter(self._voice_cache))
del self._voice_cache[oldest]
logger.debug(f"Removed from voice cache: {oldest}")
async def combine_voices(self, voices: List[str], device: str = "cpu") -> str:
"""Combine multiple voices into a new voice.
Args:
voices: List of voice names to combine
device: Device to load voices on
Returns:
Name of combined voice
Raises:
ValueError: If fewer than 2 voices provided
RuntimeError: If voice combination fails
"""
if len(voices) < 2:
raise ValueError("At least 2 voices are required for combination")
# Load voices
voice_tensors: List[torch.Tensor] = []
for voice in voices:
try:
voice_tensor = await self.load_voice(voice, device)
voice_tensors.append(voice_tensor)
except Exception as e:
raise RuntimeError(f"Failed to load voice {voice}: {e}")
try:
# Combine voices
combined_name = "_".join(voices)
combined_tensor = torch.mean(torch.stack(voice_tensors), dim=0)
# Save combined voice
combined_path = os.path.join(settings.voices_dir, f"{combined_name}.pt")
try:
torch.save(combined_tensor, combined_path)
except Exception as e:
raise RuntimeError(f"Failed to save combined voice: {e}")
return combined_name
except Exception as e:
raise RuntimeError(f"Failed to combine voices: {e}")
async def list_voices(self) -> List[str]:
"""List available voices.
Returns:
List of voice names
"""
voices = []
try:
for entry in os.listdir(settings.voices_dir):
if entry.endswith(".pt"):
voices.append(entry[:-3]) # Remove .pt extension
except Exception as e:
logger.error(f"Error listing voices: {e}")
return sorted(voices)
def validate_voice(self, voice_path: str) -> bool:
"""Validate voice file.
Args:
voice_path: Path to voice file
Returns:
True if valid, False otherwise
"""
try:
if not os.path.exists(voice_path):
return False
# Try loading voice
voice = torch.load(voice_path, map_location="cpu")
return isinstance(voice, torch.Tensor)
except Exception:
return False
@property
def cache_info(self) -> Dict[str, int]:
"""Get cache statistics.
Returns:
Dictionary with cache info
"""
return {
'size': len(self._voice_cache),
'max_size': self._config.cache_size
}
# Module-level instance
_manager: Optional[VoiceManager] = None
def get_manager(config: Optional[VoiceConfig] = None) -> VoiceManager:
"""Get or create global voice manager instance.
Args:
config: Optional voice configuration
Returns:
VoiceManager instance
"""
global _manager
if _manager is None:
_manager = VoiceManager(config)
return _manager

175
api/src/services/tts_base.py

@@ -1,175 +0,0 @@
import os
import threading
from abc import ABC, abstractmethod
from typing import List, Tuple
import numpy as np
import torch
from loguru import logger
from ..core.config import settings
class TTSBaseModel(ABC):
_instance = None
_lock = threading.Lock()
_device = None
VOICES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "voices")
@classmethod
async def setup(cls):
"""Initialize model and setup voices"""
with cls._lock:
# Set device
cuda_available = torch.cuda.is_available()
logger.info(f"CUDA available: {cuda_available}")
if cuda_available:
try:
# Test CUDA device
test_tensor = torch.zeros(1).cuda()
logger.info("CUDA test successful")
model_path = os.path.join(
settings.model_dir, settings.pytorch_model_path
)
cls._device = "cuda"
except Exception as e:
logger.error(f"CUDA test failed: {e}")
cls._device = "cpu"
else:
cls._device = "cpu"
model_path = os.path.join(settings.model_dir, settings.onnx_model_path)
logger.info(f"Initializing model on {cls._device}")
logger.info(f"Model dir: {settings.model_dir}")
logger.info(f"Model path: {model_path}")
logger.info(f"Files in model dir: {os.listdir(settings.model_dir)}")
# Initialize model first
model = cls.initialize(settings.model_dir, model_path=model_path)
if model is None:
raise RuntimeError(f"Failed to initialize {cls._device.upper()} model")
cls._instance = model
# Setup voices directory
os.makedirs(cls.VOICES_DIR, exist_ok=True)
# Copy base voices to local directory
base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
if os.path.exists(base_voices_dir):
for file in os.listdir(base_voices_dir):
if file.endswith(".pt"):
voice_name = file[:-3]
voice_path = os.path.join(cls.VOICES_DIR, file)
if not os.path.exists(voice_path):
try:
logger.info(
f"Copying base voice {voice_name} to voices directory"
)
base_path = os.path.join(base_voices_dir, file)
voicepack = torch.load(
base_path,
map_location=cls._device,
weights_only=True,
)
torch.save(voicepack, voice_path)
except Exception as e:
logger.error(
f"Error copying voice {voice_name}: {str(e)}"
)
# Count voices in directory
voice_count = len(
[f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")]
)
# Now that model and voices are ready, do warmup
try:
with open(
os.path.join(
os.path.dirname(os.path.dirname(__file__)),
"core",
"don_quixote.txt",
)
) as f:
warmup_text = f.read()
except Exception as e:
logger.warning(f"Failed to load warmup text: {e}")
warmup_text = "This is a warmup text that will be split into chunks for processing."
# Use warmup service after model is fully initialized
from .warmup import WarmupService
warmup = WarmupService()
# Load and warm up voices
loaded_voices = warmup.load_voices()
await warmup.warmup_voices(warmup_text, loaded_voices)
logger.info("Model warm-up complete")
# Count voices in directory
voice_count = len(
[f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")]
)
return voice_count
@classmethod
@abstractmethod
def initialize(cls, model_dir: str, model_path: str = None):
"""Initialize the model"""
pass
@classmethod
@abstractmethod
def process_text(cls, text: str, language: str) -> Tuple[str, List[int]]:
"""Process text into phonemes and tokens
Args:
text: Input text
language: Language code
Returns:
tuple[str, list[int]]: Phonemes and token IDs
"""
pass
@classmethod
@abstractmethod
def generate_from_text(
cls, text: str, voicepack: torch.Tensor, language: str, speed: float
) -> Tuple[np.ndarray, str]:
"""Generate audio from text
Args:
text: Input text
voicepack: Voice tensor
language: Language code
speed: Speed factor
Returns:
tuple[np.ndarray, str]: Generated audio samples and phonemes
"""
pass
@classmethod
@abstractmethod
def generate_from_tokens(
cls, tokens: List[int], voicepack: torch.Tensor, speed: float
) -> np.ndarray:
"""Generate audio from tokens
Args:
tokens: Token IDs
voicepack: Voice tensor
speed: Speed factor
Returns:
np.ndarray: Generated audio samples
"""
pass
@classmethod
def get_device(cls):
"""Get the current device"""
if cls._device is None:
raise RuntimeError("Model not initialized. Call setup() first.")
return cls._device

167
api/src/services/tts_cpu.py

@@ -1,167 +0,0 @@
import os
import numpy as np
import torch
from loguru import logger
from onnxruntime import (
ExecutionMode,
GraphOptimizationLevel,
InferenceSession,
SessionOptions,
)
from ..core.config import settings
from .text_processing import phonemize, tokenize
from .tts_base import TTSBaseModel
class TTSCPUModel(TTSBaseModel):
_instance = None
_onnx_session = None
_device = "cpu"
@classmethod
def get_instance(cls):
"""Get the model instance"""
if cls._onnx_session is None:
raise RuntimeError("ONNX model not initialized. Call initialize() first.")
return cls._onnx_session
@classmethod
def initialize(cls, model_dir: str, model_path: str = None):
"""Initialize ONNX model for CPU inference"""
if cls._onnx_session is None:
try:
# Try loading ONNX model
onnx_path = os.path.join(model_dir, settings.onnx_model_path)
if not os.path.exists(onnx_path):
logger.error(f"ONNX model not found at {onnx_path}")
return None
logger.info(f"Loading ONNX model from {onnx_path}")
# Configure ONNX session for optimal performance
session_options = SessionOptions()
# Set optimization level
if settings.onnx_optimization_level == "all":
session_options.graph_optimization_level = (
GraphOptimizationLevel.ORT_ENABLE_ALL
)
elif settings.onnx_optimization_level == "basic":
session_options.graph_optimization_level = (
GraphOptimizationLevel.ORT_ENABLE_BASIC
)
else:
session_options.graph_optimization_level = (
GraphOptimizationLevel.ORT_DISABLE_ALL
)
# Configure threading
session_options.intra_op_num_threads = settings.onnx_num_threads
session_options.inter_op_num_threads = settings.onnx_inter_op_threads
# Set execution mode
session_options.execution_mode = (
ExecutionMode.ORT_PARALLEL
if settings.onnx_execution_mode == "parallel"
else ExecutionMode.ORT_SEQUENTIAL
)
# Enable/disable memory pattern optimization
session_options.enable_mem_pattern = settings.onnx_memory_pattern
# Configure CPU provider options
provider_options = {
"CPUExecutionProvider": {
"arena_extend_strategy": settings.onnx_arena_extend_strategy,
"cpu_memory_arena_cfg": "cpu:0",
}
}
session = InferenceSession(
onnx_path,
sess_options=session_options,
providers=["CPUExecutionProvider"],
provider_options=[provider_options],
)
cls._onnx_session = session
return session
except Exception as e:
logger.error(f"Failed to initialize ONNX model: {e}")
return None
return cls._onnx_session
@classmethod
def process_text(cls, text: str, language: str) -> tuple[str, list[int]]:
"""Process text into phonemes and tokens
Args:
text: Input text
language: Language code
Returns:
tuple[str, list[int]]: Phonemes and token IDs
"""
phonemes = phonemize(text, language)
tokens = tokenize(phonemes)
tokens = [0] + tokens + [0] # Add start/end tokens
return phonemes, tokens
@classmethod
def generate_from_text(
cls, text: str, voicepack: torch.Tensor, language: str, speed: float
) -> tuple[np.ndarray, str]:
"""Generate audio from text
Args:
text: Input text
voicepack: Voice tensor
language: Language code
speed: Speed factor
Returns:
tuple[np.ndarray, str]: Generated audio samples and phonemes
"""
if cls._onnx_session is None:
raise RuntimeError("ONNX model not initialized")
# Process text
phonemes, tokens = cls.process_text(text, language)
# Generate audio
audio = cls.generate_from_tokens(tokens, voicepack, speed)
return audio, phonemes
@classmethod
def generate_from_tokens(
cls, tokens: list[int], voicepack: torch.Tensor, speed: float
) -> np.ndarray:
"""Generate audio from tokens
Args:
tokens: Token IDs
voicepack: Voice tensor
speed: Speed factor
Returns:
np.ndarray: Generated audio samples
"""
if cls._onnx_session is None:
raise RuntimeError("ONNX model not initialized")
# Pre-allocate and prepare inputs
tokens_input = np.array([tokens], dtype=np.int64)
style_input = voicepack[
len(tokens) - 2
].numpy() # Already has correct dimensions
speed_input = np.full(
1, speed, dtype=np.float32
) # More efficient than ones * speed
# Run inference with optimized inputs
result = cls._onnx_session.run(
None, {"tokens": tokens_input, "style": style_input, "speed": speed_input}
)
return result[0]

262
api/src/services/tts_gpu.py

@@ -1,262 +0,0 @@
import os
import time
import numpy as np
import torch
from ..builds.models import build_model
from loguru import logger
from ..core.config import settings
from .text_processing import phonemize, tokenize
from .tts_base import TTSBaseModel
# @torch.no_grad()
# def forward(model, tokens, ref_s, speed):
# """Forward pass through the model"""
# device = ref_s.device
# tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
# input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
# text_mask = length_to_mask(input_lengths).to(device)
# bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
# d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
# s = ref_s[:, 128:]
# d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
# x, _ = model.predictor.lstm(d)
# duration = model.predictor.duration_proj(x)
# duration = torch.sigmoid(duration).sum(axis=-1) / speed
# pred_dur = torch.round(duration).clamp(min=1).long()
# pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
# c_frame = 0
# for i in range(pred_aln_trg.size(0)):
# pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
# c_frame += pred_dur[0, i].item()
# en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
# F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
# t_en = model.text_encoder(tokens, input_lengths, text_mask)
# asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
# return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
@torch.no_grad()
def forward(model, tokens, ref_s, speed):
"""Forward pass through the model with moderate memory management"""
device = ref_s.device
try:
# Initial tensor setup with proper device placement
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
text_mask = length_to_mask(input_lengths).to(device)
# Split and clone reference signals with explicit device placement
s_content = ref_s[:, 128:].clone().to(device)
s_ref = ref_s[:, :128].clone().to(device)
# BERT and encoder pass
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
# Predictor forward pass
d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
x, _ = model.predictor.lstm(d)
# Duration prediction
duration = model.predictor.duration_proj(x)
duration = torch.sigmoid(duration).sum(axis=-1) / speed
pred_dur = torch.round(duration).clamp(min=1).long()
# Only cleanup large intermediates
del duration, x
# Alignment matrix construction
pred_aln_trg = torch.zeros(input_lengths.item(), pred_dur.sum().item(), device=device)
c_frame = 0
for i in range(pred_aln_trg.size(0)):
pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
c_frame += pred_dur[0, i].item()
pred_aln_trg = pred_aln_trg.unsqueeze(0)
# Matrix multiplications with selective cleanup
en = d.transpose(-1, -2) @ pred_aln_trg
del d # Free large intermediate tensor
F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
del en # Free large intermediate tensor
# Final text encoding and decoding
t_en = model.text_encoder(tokens, input_lengths, text_mask)
asr = t_en @ pred_aln_trg
del t_en # Free large intermediate tensor
# Final decoding and transfer to CPU
output = model.decoder(asr, F0_pred, N_pred, s_ref)
result = output.squeeze().cpu().numpy()
return result
finally:
# Let PyTorch handle most cleanup automatically
# Only explicitly free the largest tensors
del pred_aln_trg, asr
# def length_to_mask(lengths):
# """Create attention mask from lengths"""
# mask = (
# torch.arange(lengths.max())
# .unsqueeze(0)
# .expand(lengths.shape[0], -1)
# .type_as(lengths)
# )
# mask = torch.gt(mask + 1, lengths.unsqueeze(1))
# return mask
def length_to_mask(lengths):
"""Create attention mask from lengths - possibly optimized version"""
max_len = lengths.max()
# Create mask directly on the same device as lengths
mask = torch.arange(max_len, device=lengths.device)[None, :].expand(
lengths.shape[0], -1
)
# Avoid type_as by using the correct dtype from the start
if lengths.dtype != mask.dtype:
mask = mask.to(dtype=lengths.dtype)
# Fuse operations using broadcasting
return mask + 1 > lengths[:, None]
class TTSGPUModel(TTSBaseModel):
_instance = None
_device = "cuda"
@classmethod
def get_instance(cls):
"""Get the model instance"""
if cls._instance is None:
raise RuntimeError("GPU model not initialized. Call initialize() first.")
return cls._instance
@classmethod
def initialize(cls, model_dir: str, model_path: str):
"""Initialize PyTorch model for GPU inference"""
if cls._instance is None and torch.cuda.is_available():
try:
logger.info("Initializing GPU model")
model_path = os.path.join(model_dir, settings.pytorch_model_path)
model = build_model(model_path, cls._device)
cls._instance = model
return model
except Exception as e:
logger.error(f"Failed to initialize GPU model: {e}")
return None
return cls._instance
@classmethod
def process_text(cls, text: str, language: str) -> tuple[str, list[int]]:
"""Process text into phonemes and tokens
Args:
text: Input text
language: Language code
Returns:
tuple[str, list[int]]: Phonemes and token IDs
"""
phonemes = phonemize(text, language)
tokens = tokenize(phonemes)
return phonemes, tokens
@classmethod
def generate_from_text(
cls, text: str, voicepack: torch.Tensor, language: str, speed: float
) -> tuple[np.ndarray, str]:
"""Generate audio from text
Args:
text: Input text
voicepack: Voice tensor
language: Language code
speed: Speed factor
Returns:
tuple[np.ndarray, str]: Generated audio samples and phonemes
"""
if cls._instance is None:
raise RuntimeError("GPU model not initialized")
# Process text
phonemes, tokens = cls.process_text(text, language)
# Generate audio
audio = cls.generate_from_tokens(tokens, voicepack, speed)
return audio, phonemes
@classmethod
def generate_from_tokens(
cls, tokens: list[int], voicepack: torch.Tensor, speed: float
) -> np.ndarray:
"""Generate audio from tokens with moderate memory management
Args:
tokens: Token IDs
voicepack: Voice tensor
speed: Speed factor
Returns:
np.ndarray: Generated audio samples
"""
if cls._instance is None:
raise RuntimeError("GPU model not initialized")
try:
device = cls._device
# Check memory pressure
if torch.cuda.is_available():
memory_allocated = torch.cuda.memory_allocated(device) / 1e9 # Convert to GB
if memory_allocated > 2.0: # 2GB limit
logger.info(
f"Memory usage above 2GB threshold:{memory_allocated:.2f}GB "
f"Clearing cache"
)
torch.cuda.empty_cache()
import gc
gc.collect()
# Get reference style with proper device placement
ref_s = voicepack[len(tokens)].clone().to(device)
# Generate audio
audio = forward(cls._instance, tokens, ref_s, speed)
return audio
except RuntimeError as e:
if "out of memory" in str(e):
# On OOM, do a full cleanup and retry
if torch.cuda.is_available():
logger.warning("Out of memory detected, performing full cleanup")
torch.cuda.synchronize()
torch.cuda.empty_cache()
import gc
gc.collect()
# Log memory stats after cleanup
memory_allocated = torch.cuda.memory_allocated(device)
memory_reserved = torch.cuda.memory_reserved(device)
logger.info(
f"Memory after OOM cleanup: "
f"Allocated: {memory_allocated / 1e9:.2f}GB, "
f"Reserved: {memory_reserved / 1e9:.2f}GB"
)
# Retry generation
ref_s = voicepack[len(tokens)].clone().to(device)
audio = forward(cls._instance, tokens, ref_s, speed)
return audio
raise
finally:
# Only synchronize at the top level, no empty_cache
if torch.cuda.is_available():
torch.cuda.synchronize()

8
api/src/services/tts_model.py

@@ -1,8 +0,0 @@
import torch
if torch.cuda.is_available():
from .tts_gpu import TTSGPUModel as TTSModel
else:
from .tts_cpu import TTSCPUModel as TTSModel
__all__ = ["TTSModel"]

api/src/services/tts_service.py

@@ -1,120 +1,114 @@
"""TTS service using model and voice managers."""
import io import io
import os import os
import re
import time import time
from functools import lru_cache from typing import List, Tuple
from typing import List, Optional, Tuple
import torch
import aiofiles.os
import numpy as np import numpy as np
import scipy.io.wavfile as wavfile import scipy.io.wavfile as wavfile
import torch
from loguru import logger from loguru import logger
from ..core.config import settings from ..core.config import settings
from ..inference.model_manager import get_manager as get_model_manager
from ..inference.voice_manager import get_manager as get_voice_manager
from .audio import AudioNormalizer, AudioService from .audio import AudioNormalizer, AudioService
from .text_processing import chunker, normalize_text from .text_processing import chunker, normalize_text
from .tts_model import TTSModel
class TTSService: class TTSService:
"""Text-to-speech service."""
def __init__(self, output_dir: str = None): def __init__(self, output_dir: str = None):
"""Initialize service.
Args:
output_dir: Optional output directory for saving audio
"""
self.output_dir = output_dir self.output_dir = output_dir
self.model = TTSModel.get_instance() self.model_manager = get_model_manager()
self.voice_manager = get_voice_manager()
self._initialized = False
self._initialization_error = None
@staticmethod async def ensure_initialized(self):
@lru_cache(maxsize=3) # Cache up to 3 most recently used voices """Ensure model is initialized."""
def _load_voice(voice_path: str) -> torch.Tensor: if self._initialized:
"""Load and cache a voice model""" return
return torch.load( if self._initialization_error:
voice_path, map_location=TTSModel.get_device(), weights_only=True raise self._initialization_error
)
def _get_voice_path(self, voice_name: str) -> Optional[str]: try:
"""Get the path to a voice file""" # Determine model path based on hardware
voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice_name}.pt") if settings.use_gpu and torch.cuda.is_available():
return voice_path if os.path.exists(voice_path) else None model_path = os.path.join(settings.model_dir, settings.pytorch_model_path)
backend_type = 'pytorch_gpu'
else:
model_path = os.path.join(settings.model_dir, settings.onnx_model_path)
backend_type = 'onnx_cpu'
def _generate_audio( # Initialize model
self, text: str, voice: str, speed: float, stitch_long_output: bool = True await self.model_manager.load_model(model_path, backend_type)
) -> Tuple[torch.Tensor, float]: logger.info(f"Initialized model on {backend_type} backend")
"""Generate complete audio and return with processing time""" self._initialized = True
audio, processing_time = self._generate_audio_internal(
text, voice, speed, stitch_long_output
)
return audio, processing_time
def _generate_audio_internal( except Exception as e:
self, text: str, voice: str, speed: float, stitch_long_output: bool = True logger.error(f"Failed to initialize model: {e}")
) -> Tuple[torch.Tensor, float]: self._initialization_error = RuntimeError(f"Model initialization failed: {e}")
"""Generate audio and measure processing time""" raise self._initialization_error
async def generate_audio(
self, text: str, voice: str, speed: float = 1.0
) -> Tuple[np.ndarray, float]:
"""Generate audio for text.
Args:
text: Input text
voice: Voice name
speed: Speed multiplier
Returns:
Audio samples and processing time
"""
await self.ensure_initialized()
start_time = time.time() start_time = time.time()
try: try:
# Normalize text once at the start # Normalize text
if not text:
raise ValueError("Text is empty after preprocessing")
normalized = normalize_text(text) normalized = normalize_text(text)
if not normalized: if not normalized:
raise ValueError("Text is empty after preprocessing") raise ValueError("Text is empty after preprocessing")
text = str(normalized) text = str(normalized)
# Check voice exists # Process text into chunks
voice_path = self._get_voice_path(voice) audio_chunks = []
if not voice_path:
raise ValueError(f"Voice not found: {voice}")
# Load voice using cached loader
voicepack = self._load_voice(voice_path)
# For non-streaming, preprocess all chunks first
if stitch_long_output:
# Preprocess all chunks to phonemes/tokens
chunks_data = []
for chunk in chunker.split_text(text): for chunk in chunker.split_text(text):
try: try:
phonemes, tokens = TTSModel.process_text(chunk, voice[0]) # Process text
chunks_data.append((chunk, tokens))
except Exception as e: sequences = process_text(chunk)
logger.error( if not sequences:
f"Failed to process chunk: '{chunk}'. Error: {str(e)}"
)
continue continue
if not chunks_data: # Generate audio
raise ValueError("No chunks were processed successfully") chunk_audio = await self.model_manager.generate(
sequences[0],
# Generate audio for all chunks voice,
audio_chunks = [] speed=speed
for chunk, tokens in chunks_data:
try:
chunk_audio = TTSModel.generate_from_tokens(
tokens, voicepack, speed
) )
if chunk_audio is not None: if chunk_audio is not None:
audio_chunks.append(chunk_audio) audio_chunks.append(chunk_audio)
else:
logger.error(f"No audio generated for chunk: '{chunk}'")
except Exception as e: except Exception as e:
logger.error( logger.error(f"Failed to process chunk: '{chunk}'. Error: {str(e)}")
f"Failed to generate audio for chunk: '{chunk}'. Error: {str(e)}"
)
continue continue
if not audio_chunks: if not audio_chunks:
raise ValueError("No audio chunks were generated successfully") raise ValueError("No audio chunks were generated successfully")
# Concatenate all chunks # Combine chunks
audio = ( audio = np.concatenate(audio_chunks) if len(audio_chunks) > 1 else audio_chunks[0]
np.concatenate(audio_chunks)
if len(audio_chunks) > 1
else audio_chunks[0]
)
else:
# Process single chunk
phonemes, tokens = TTSModel.process_text(text, voice[0])
audio = TTSModel.generate_from_tokens(tokens, voicepack, speed)
processing_time = time.time() - start_time processing_time = time.time() - start_time
return audio, processing_time return audio, processing_time
@@ -126,144 +120,103 @@ class TTSService:
self, self,
text: str, text: str,
voice: str, voice: str,
speed: float, speed: float = 1.0,
output_format: str = "wav", output_format: str = "wav",
silent=False,
): ):
"""Generate and yield audio chunks as they're generated for real-time streaming""" """Generate and stream audio chunks.
Args:
text: Input text
voice: Voice name
speed: Speed multiplier
output_format: Output audio format
Yields:
Audio chunks as bytes
"""
await self.ensure_initialized()
try: try:
stream_start = time.time() # Setup audio processing
# Create normalizer for consistent audio levels
stream_normalizer = AudioNormalizer() stream_normalizer = AudioNormalizer()
# Input validation and preprocessing # Normalize text
if not text:
raise ValueError("Text is empty")
preprocess_start = time.time()
normalized = normalize_text(text) normalized = normalize_text(text)
if not normalized: if not normalized:
raise ValueError("Text is empty after preprocessing") raise ValueError("Text is empty after preprocessing")
text = str(normalized) text = str(normalized)
logger.debug(
f"Text preprocessing took: {(time.time() - preprocess_start)*1000:.1f}ms"
)
# Voice validation and loading # Process chunks
voice_start = time.time()
voice_path = self._get_voice_path(voice)
if not voice_path:
raise ValueError(f"Voice not found: {voice}")
voicepack = self._load_voice(voice_path)
logger.debug(
f"Voice loading took: {(time.time() - voice_start)*1000:.1f}ms"
)
# Process chunks as they're generated
is_first = True is_first = True
chunks_processed = 0
# Process chunks as they come from generator
chunk_gen = chunker.split_text(text) chunk_gen = chunker.split_text(text)
current_chunk = next(chunk_gen, None) current_chunk = next(chunk_gen, None)
while current_chunk is not None: while current_chunk is not None:
next_chunk = next(chunk_gen, None) # Peek at next chunk next_chunk = next(chunk_gen, None)
chunks_processed += 1
try: try:
# Process text and generate audio # Process text
phonemes, tokens = TTSModel.process_text(current_chunk, voice[0]) from ..text_processing import process_text
chunk_audio = TTSModel.generate_from_tokens( sequences = process_text(current_chunk)
tokens, voicepack, speed if sequences:
# Generate audio
chunk_audio = await self.model_manager.generate(
sequences[0],
voice,
speed=speed
) )
if chunk_audio is not None: if chunk_audio is not None:
# Convert chunk with proper streaming header handling # Convert to bytes
chunk_bytes = AudioService.convert_audio( chunk_bytes = AudioService.convert_audio(
chunk_audio, chunk_audio,
24000, 24000,
output_format, output_format,
is_first_chunk=is_first, is_first_chunk=is_first,
normalizer=stream_normalizer, normalizer=stream_normalizer,
is_last_chunk=(next_chunk is None), # Last if no next chunk is_last_chunk=(next_chunk is None),
stream=True # Ensure proper streaming format handling stream=True
) )
yield chunk_bytes yield chunk_bytes
is_first = False is_first = False
else:
logger.error(f"No audio generated for chunk: '{current_chunk}'")
except Exception as e: except Exception as e:
logger.error( logger.error(f"Failed to generate audio for chunk: '{current_chunk}'. Error: {str(e)}")
f"Failed to generate audio for chunk: '{current_chunk}'. Error: {str(e)}"
)
current_chunk = next_chunk # Move to next chunk current_chunk = next_chunk
except Exception as e: except Exception as e:
logger.error(f"Error in audio generation stream: {str(e)}") logger.error(f"Error in audio generation stream: {str(e)}")
raise raise
def _save_audio(self, audio: torch.Tensor, filepath: str): async def combine_voices(self, voices: List[str]) -> str:
"""Save audio to file""" """Combine multiple voices.
os.makedirs(os.path.dirname(filepath), exist_ok=True)
wavfile.write(filepath, 24000, audio)
def _audio_to_bytes(self, audio: torch.Tensor) -> bytes: Args:
"""Convert audio tensor to WAV bytes""" voices: List of voice names
Returns:
Name of combined voice
"""
await self.ensure_initialized()
return await self.voice_manager.combine_voices(voices)
async def list_voices(self) -> List[str]:
"""List available voices.
Returns:
List of voice names
"""
return await self.voice_manager.list_voices()
def _audio_to_bytes(self, audio: np.ndarray) -> bytes:
"""Convert audio to WAV bytes.
Args:
audio: Audio samples
Returns:
WAV bytes
"""
buffer = io.BytesIO() buffer = io.BytesIO()
wavfile.write(buffer, 24000, audio) wavfile.write(buffer, 24000, audio)
return buffer.getvalue() return buffer.getvalue()
async def combine_voices(self, voices: List[str]) -> str:
"""Combine multiple voices into a new voice"""
if len(voices) < 2:
raise ValueError("At least 2 voices are required for combination")
# Load voices
t_voices: List[torch.Tensor] = []
v_name: List[str] = []
for voice in voices:
try:
voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice}.pt")
voicepack = torch.load(
voice_path, map_location=TTSModel.get_device(), weights_only=True
)
t_voices.append(voicepack)
v_name.append(voice)
except Exception as e:
raise ValueError(f"Failed to load voice {voice}: {str(e)}")
# Combine voices
try:
f: str = "_".join(v_name)
v = torch.mean(torch.stack(t_voices), dim=0)
combined_path = os.path.join(TTSModel.VOICES_DIR, f"{f}.pt")
# Save combined voice
try:
torch.save(v, combined_path)
except Exception as e:
raise RuntimeError(
f"Failed to save combined voice to {combined_path}: {str(e)}"
)
return f
except Exception as e:
if not isinstance(e, (ValueError, RuntimeError)):
raise RuntimeError(f"Error combining voices: {str(e)}")
raise
async def list_voices(self) -> List[str]:
"""List all available voices"""
voices = []
try:
it = await aiofiles.os.scandir(TTSModel.VOICES_DIR)
for entry in it:
if entry.name.endswith(".pt"):
voices.append(entry.name[:-3]) # Remove .pt extension
except Exception as e:
logger.error(f"Error listing voices: {str(e)}")
return sorted(voices)

26
api/src/structures/model_schemas.py Normal file

@@ -0,0 +1,26 @@
"""Model and voice configuration schemas."""
from pydantic import BaseModel
class ModelConfig(BaseModel):
"""Model configuration."""
optimization_level: str = "all" # all, basic, none
num_threads: int = 4
inter_op_threads: int = 4
execution_mode: str = "parallel" # parallel, sequential
memory_pattern: bool = True
arena_extend_strategy: str = "kNextPowerOfTwo"
class Config:
frozen = True # Make config immutable
class VoiceConfig(BaseModel):
"""Voice configuration."""
use_cache: bool = True
cache_size: int = 3 # Number of voices to cache
validate_on_load: bool = True # Whether to validate voices when loading
class Config:
frozen = True # Make config immutable
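
Both schemas are frozen Pydantic models, so overrides are supplied once at construction and handed to the managers' get_manager() factories. A short sketch under the same assumed api.src import path as above:

# Sketch: supplying the new config schemas to the managers (import paths are assumptions).
from api.src.inference import ModelConfig, get_manager
from api.src.inference.voice_manager import get_manager as get_voice_manager
from api.src.structures.model_schemas import VoiceConfig

model_manager = get_manager(ModelConfig(num_threads=8, execution_mode="sequential"))
voice_manager = get_voice_manager(VoiceConfig(cache_size=5, validate_on_load=False))

print(model_manager.available_backends)  # e.g. ['pytorch_cpu', 'onnx_cpu'] on a CPU-only host
print(voice_manager.cache_info)          # {'size': 0, 'max_size': 5}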