Refactor: Consolidate PyTorch CPU and GPU backends into a single PyTorchBackend class; remove obsolete files

remsky 2025-01-25 13:33:42 -07:00
parent 3547d95ee6
commit 00497f8872
4 changed files with 258 additions and 426 deletions
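Note: the consolidation replaces the separate PyTorchCPUBackend and PyTorchGPUBackend classes with one PyTorchBackend that resolves its device from the runtime environment. A minimal sketch of that selection pattern, assuming only a boolean use_gpu setting as used in the hunks below:

import torch

def resolve_device(use_gpu: bool) -> str:
    # Prefer CUDA only when it is both requested and actually available;
    # otherwise fall back to CPU, as the consolidated backend does.
    return "cuda" if use_gpu and torch.cuda.is_available() else "cpu"

print(resolve_device(use_gpu=True))  # prints "cpu" on a machine without CUDA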

View file

@@ -4,8 +4,7 @@ from .base import BaseModelBackend
 from .model_manager import ModelManager, get_manager
 from .onnx_cpu import ONNXCPUBackend
 from .onnx_gpu import ONNXGPUBackend
-from .pytorch_cpu import PyTorchCPUBackend
-from .pytorch_gpu import PyTorchGPUBackend
+from .pytorch_backend import PyTorchBackend

 __all__ = [
     'BaseModelBackend',
@@ -13,6 +12,5 @@ __all__ = [
     'get_manager',
     'ONNXCPUBackend',
     'ONNXGPUBackend',
-    'PyTorchCPUBackend',
-    'PyTorchGPUBackend',
+    'PyTorchBackend',
 ]

View file

@@ -12,8 +12,7 @@ from ..core.model_config import ModelConfig, model_config
 from .base import BaseModelBackend
 from .onnx_cpu import ONNXCPUBackend
 from .onnx_gpu import ONNXGPUBackend
-from .pytorch_cpu import PyTorchCPUBackend
-from .pytorch_gpu import PyTorchGPUBackend
+from .pytorch_backend import PyTorchBackend
 from .session_pool import CPUSessionPool, StreamingSessionPool
@@ -63,18 +62,18 @@ class ModelManager:
                 self._current_backend = 'onnx_gpu'
                 logger.info("Initialized new ONNX GPU backend")
             else:
-                self._backends['pytorch_gpu'] = PyTorchGPUBackend()
-                self._current_backend = 'pytorch_gpu'
-                logger.info("Initialized new PyTorch GPU backend")
+                self._backends['pytorch'] = PyTorchBackend()
+                self._current_backend = 'pytorch'
+                logger.info("Initialized new PyTorch backend on GPU")
         else:
             if settings.use_onnx:
                 self._backends['onnx_cpu'] = ONNXCPUBackend()
                 self._current_backend = 'onnx_cpu'
                 logger.info("Initialized new ONNX CPU backend")
             else:
-                self._backends['pytorch_cpu'] = PyTorchCPUBackend()
-                self._current_backend = 'pytorch_cpu'
-                logger.info("Initialized new PyTorch CPU backend")
+                self._backends['pytorch'] = PyTorchBackend()
+                self._current_backend = 'pytorch'
+                logger.info("Initialized new PyTorch backend on CPU")

         # Initialize locks for each backend
         for backend in self._backends:
@@ -95,10 +94,10 @@ class ModelManager:
         """
         try:
             # Determine backend type based on settings
-            if settings.use_gpu and torch.cuda.is_available():
-                backend_type = 'pytorch_gpu' if not settings.use_onnx else 'onnx_gpu'
+            if settings.use_onnx:
+                backend_type = 'onnx_gpu' if settings.use_gpu and torch.cuda.is_available() else 'onnx_cpu'
             else:
-                backend_type = 'pytorch_cpu' if not settings.use_onnx else 'onnx_cpu'
+                backend_type = 'pytorch'

             # Get backend
             backend = self.get_backend(backend_type)
@@ -167,13 +166,10 @@ class ModelManager:
         Returns:
             Backend type to use
         """
-        has_gpu = settings.use_gpu and torch.cuda.is_available()
-
         # If ONNX is preferred or model is ONNX format
         if settings.use_onnx or model_path.lower().endswith('.onnx'):
-            return 'onnx_gpu' if has_gpu else 'onnx_cpu'
-        else:
-            return 'pytorch_gpu' if has_gpu else 'pytorch_cpu'
+            return 'onnx_gpu' if settings.use_gpu and torch.cuda.is_available() else 'onnx_cpu'
+        return 'pytorch'

     async def load_model(
         self,
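With the manager now registering PyTorch under the single 'pytorch' key, backend selection reduces to a format check, with the GPU check remaining only for ONNX. A standalone sketch of the simplified rule in _determine_backend above; the free function and its arguments below are illustrative, not part of the repo:

def determine_backend(model_path: str, use_onnx: bool, has_gpu: bool) -> str:
    # ONNX keeps separate CPU/GPU backends; PyTorch no longer does.
    if use_onnx or model_path.lower().endswith(".onnx"):
        return "onnx_gpu" if has_gpu else "onnx_cpu"
    return "pytorch"

assert determine_backend("model.onnx", use_onnx=False, has_gpu=False) == "onnx_cpu"
assert determine_backend("model.pth", use_onnx=False, has_gpu=True) == "pytorch"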

View file

@@ -1,225 +1,244 @@
-"""GPU-based PyTorch inference backend."""
+"""PyTorch inference backend with environment-based configuration."""

 import gc
 from typing import Optional
+from contextlib import nullcontext

 import numpy as np
 import torch
 from loguru import logger

 from ..builds.models import build_model
 from ..core import paths
 from ..core.model_config import model_config
+from ..core.config import settings
 from .base import BaseModelBackend


 class CUDAStreamManager:
-    """CUDA stream manager."""
+    """CUDA stream manager for GPU operations."""

     def __init__(self, num_streams: int):
         """Initialize stream manager.

         Args:
             num_streams: Number of CUDA streams
         """
         self.streams = [torch.cuda.Stream() for _ in range(num_streams)]
         self._current = 0

     def get_next_stream(self) -> torch.cuda.Stream:
         """Get next available stream.

         Returns:
             CUDA stream
         """
         stream = self.streams[self._current]
         self._current = (self._current + 1) % len(self.streams)
         return stream


 @torch.no_grad()
 def forward(
     model: torch.nn.Module,
     tokens: list[int],
     ref_s: torch.Tensor,
     speed: float,
-    stream: Optional[torch.cuda.Stream] = None
+    stream: Optional[torch.cuda.Stream] = None,
 ) -> np.ndarray:
     """Forward pass through model.

     Args:
         model: PyTorch model
         tokens: Input tokens
-        ref_s: Reference signal (shape: [1, n_features])
+        ref_s: Reference signal
         speed: Speed multiplier
-        stream: Optional CUDA stream
+        stream: Optional CUDA stream (GPU only)

     Returns:
         Generated audio
     """
     device = ref_s.device

-    # Use provided stream or default
-    with torch.cuda.stream(stream) if stream else torch.cuda.default_stream():
+    # Use provided stream or default for GPU
+    context = (
+        torch.cuda.stream(stream) if stream and device.type == "cuda" else nullcontext()
+    )
+
+    with context:
         # Initial tensor setup
         tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
         input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
         text_mask = length_to_mask(input_lengths).to(device)

-        # Split reference signals (style_dim=128 from config)
+        # Split reference signals
         style_dim = 128
         s_ref = ref_s[:, :style_dim].clone().to(device)
         s_content = ref_s[:, style_dim:].clone().to(device)

         # BERT and encoder pass
         bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
         d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

         # Predictor forward pass
         d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
         x, _ = model.predictor.lstm(d)

         # Duration prediction
         duration = model.predictor.duration_proj(x)
         duration = torch.sigmoid(duration).sum(axis=-1) / speed
         pred_dur = torch.round(duration).clamp(min=1).long()
         del duration, x

         # Alignment matrix construction
-        pred_aln_trg = torch.zeros(input_lengths.item(), pred_dur.sum().item(), device=device)
+        pred_aln_trg = torch.zeros(
+            input_lengths.item(), pred_dur.sum().item(), device=device
+        )
         c_frame = 0
         for i in range(pred_aln_trg.size(0)):
-            pred_aln_trg[i, c_frame:c_frame + pred_dur[0, i].item()] = 1
+            pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
             c_frame += pred_dur[0, i].item()
         pred_aln_trg = pred_aln_trg.unsqueeze(0)

         # Matrix multiplications
         en = d.transpose(-1, -2) @ pred_aln_trg
         del d

         F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
         del en

         # Final text encoding and decoding
         t_en = model.text_encoder(tokens, input_lengths, text_mask)
         asr = t_en @ pred_aln_trg
         del t_en

         # Generate output
         output = model.decoder(asr, F0_pred, N_pred, s_ref)

         # Ensure operation completion if using custom stream
-        if stream:
+        if stream and device.type == "cuda":
             stream.synchronize()

         return output.squeeze().cpu().numpy()


 def length_to_mask(lengths: torch.Tensor) -> torch.Tensor:
     """Create attention mask from lengths."""
     max_len = lengths.max()
-    mask = torch.arange(max_len, device=lengths.device)[None, :].expand(lengths.shape[0], -1)
+    mask = torch.arange(max_len, device=lengths.device)[None, :].expand(
+        lengths.shape[0], -1
+    )
     if lengths.dtype != mask.dtype:
         mask = mask.to(dtype=lengths.dtype)
     return mask + 1 > lengths[:, None]


-class PyTorchGPUBackend(BaseModelBackend):
-    """PyTorch GPU inference backend."""
+class PyTorchBackend(BaseModelBackend):
+    """PyTorch inference backend with environment-based configuration."""

     def __init__(self):
-        """Initialize GPU backend."""
+        """Initialize backend based on environment configuration."""
         super().__init__()
-        from ..core.config import settings
-        if not (settings.use_gpu and torch.cuda.is_available()):
-            raise RuntimeError("GPU backend requires GPU support and use_gpu=True")
-        self._device = "cuda"  # Device is enforced by backend selection in model_manager
-        self._model: Optional[torch.nn.Module] = None

-        # Configure GPU settings
-        config = model_config.pytorch_gpu
-        if config.sync_cuda:
-            torch.cuda.synchronize()
-        torch.cuda.set_device(config.device_id)
-
-        # Initialize stream manager
-        self._stream_manager = CUDAStreamManager(config.cuda_streams)
+        # Configure device based on settings
+        self._device = (
+            "cuda" if settings.use_gpu and torch.cuda.is_available() else "cpu"
+        )
+        self._model: Optional[torch.nn.Module] = None
+
+        # Apply device-specific configurations
+        if self._device == "cuda":
+            config = model_config.pytorch_gpu
+            if config.sync_cuda:
+                torch.cuda.synchronize()
+            torch.cuda.set_device(config.device_id)
+            self._stream_manager = CUDAStreamManager(config.cuda_streams)
+        else:
+            config = model_config.pytorch_cpu
+            if config.num_threads > 0:
+                torch.set_num_threads(config.num_threads)
+            if config.pin_memory:
+                torch.set_default_tensor_type(torch.FloatTensor)

     async def load_model(self, path: str) -> None:
         """Load PyTorch model.

         Args:
             path: Path to model file

         Raises:
             RuntimeError: If model loading fails
         """
         try:
             # Get verified model path
             model_path = await paths.get_model_path(path)

-            logger.info(f"Loading PyTorch model: {model_path}")
+            logger.info(f"Loading PyTorch model on {self._device}: {model_path}")
             self._model = await build_model(model_path, self._device)

         except Exception as e:
             raise RuntimeError(f"Failed to load PyTorch model: {e}")

     def generate(
-        self,
-        tokens: list[int],
-        voice: torch.Tensor,
-        speed: float = 1.0
+        self, tokens: list[int], voice: torch.Tensor, speed: float = 1.0
     ) -> np.ndarray:
-        """Generate audio using GPU model.
+        """Generate audio using model.

         Args:
             tokens: Input token IDs
             voice: Voice embedding tensor
             speed: Speed multiplier

         Returns:
             Generated audio samples

         Raises:
             RuntimeError: If generation fails
         """
         if not self.is_loaded:
             raise RuntimeError("Model not loaded")

         try:
-            # Check memory and cleanup if needed
-            if self._check_memory():
-                self._clear_memory()
+            # Memory management for GPU
+            if self._device == "cuda":
+                if self._check_memory():
+                    self._clear_memory()
+                stream = self._stream_manager.get_next_stream()
+            else:
+                stream = None

             # Get reference style from voice pack
             ref_s = voice[len(tokens)].clone().to(self._device)
             if ref_s.dim() == 1:
-                ref_s = ref_s.unsqueeze(0)  # Add batch dimension if needed
+                ref_s = ref_s.unsqueeze(0)

-            # Get next available stream
-            stream = self._stream_manager.get_next_stream()
-
-            # Generate audio using stream
+            # Generate audio
             return forward(self._model, tokens, ref_s, speed, stream)

         except Exception as e:
             logger.error(f"Generation failed: {e}")
-            if model_config.pytorch_gpu.retry_on_oom and "out of memory" in str(e).lower():
+            if (
+                self._device == "cuda"
+                and model_config.pytorch_gpu.retry_on_oom
+                and "out of memory" in str(e).lower()
+            ):
                 self._clear_memory()
-                return self.generate(tokens, voice, speed)  # Retry once
+                return self.generate(tokens, voice, speed)
             raise
         finally:
-            if model_config.pytorch_gpu.sync_cuda:
+            if self._device == "cuda" and model_config.pytorch_gpu.sync_cuda:
                 torch.cuda.synchronize()

     def _check_memory(self) -> bool:
         """Check if memory usage is above threshold."""
-        if torch.cuda.is_available():
+        if self._device == "cuda":
             memory_gb = torch.cuda.memory_allocated() / 1e9
             return memory_gb > model_config.pytorch_gpu.memory_threshold
         return False

     def _clear_memory(self) -> None:
-        """Clear GPU memory."""
-        if torch.cuda.is_available():
+        """Clear device memory."""
+        if self._device == "cuda":
             torch.cuda.empty_cache()
-            gc.collect()
+        gc.collect()
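For reference, a hedged usage sketch of the consolidated backend defined above; the import path, checkpoint path, and tensor shapes below are placeholders, not values taken from this repo:

import asyncio
import torch

from inference import PyTorchBackend  # package path is an assumption

async def demo():
    backend = PyTorchBackend()                      # device resolved from settings at construction
    await backend.load_model("path/to/model.pth")   # placeholder checkpoint path
    tokens = [12, 45, 78]                           # placeholder token IDs
    voice = torch.randn(512, 1, 256)                # placeholder voice pack, indexed by len(tokens)
    audio = backend.generate(tokens, voice, speed=1.0)
    print(audio.shape)

asyncio.run(demo())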

View file

@@ -1,181 +0,0 @@
-"""CPU-based PyTorch inference backend."""
-
-import gc
-from typing import Optional
-
-import numpy as np
-import torch
-from loguru import logger
-
-from ..builds.models import build_model
-from ..core import paths
-from ..core.model_config import model_config
-from .base import BaseModelBackend
-
-
-@torch.no_grad()
-def forward(model: torch.nn.Module, tokens: list[int], ref_s: torch.Tensor, speed: float) -> np.ndarray:
-    """Forward pass through model with memory management.
-
-    Args:
-        model: PyTorch model
-        tokens: Input tokens
-        ref_s: Reference signal
-        speed: Speed multiplier
-
-    Returns:
-        Generated audio
-    """
-    device = ref_s.device
-    pred_aln_trg = None
-    asr = None
-
-    try:
-        # Initial tensor setup
-        tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
-        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
-        text_mask = length_to_mask(input_lengths).to(device)
-
-        # Split reference signals
-        s_content = ref_s[:, 128:].clone().to(device)
-        s_ref = ref_s[:, :128].clone().to(device)
-
-        # BERT and encoder pass
-        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
-        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
-
-        # Predictor forward pass
-        d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
-        x, _ = model.predictor.lstm(d)
-
-        # Duration prediction
-        duration = model.predictor.duration_proj(x)
-        duration = torch.sigmoid(duration).sum(axis=-1) / speed
-        pred_dur = torch.round(duration).clamp(min=1).long()
-        del duration, x  # Free large intermediates
-
-        # Alignment matrix construction
-        pred_aln_trg = torch.zeros(
-            input_lengths.item(),
-            pred_dur.sum().item(),
-            device=device
-        )
-        c_frame = 0
-        for i in range(pred_aln_trg.size(0)):
-            pred_aln_trg[i, c_frame:c_frame + pred_dur[0, i].item()] = 1
-            c_frame += pred_dur[0, i].item()
-        pred_aln_trg = pred_aln_trg.unsqueeze(0)
-
-        # Matrix multiplications with cleanup
-        en = d.transpose(-1, -2) @ pred_aln_trg
-        del d
-
-        F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
-        del en
-
-        # Final text encoding and decoding
-        t_en = model.text_encoder(tokens, input_lengths, text_mask)
-        asr = t_en @ pred_aln_trg
-        del t_en
-
-        # Generate output
-        output = model.decoder(asr, F0_pred, N_pred, s_ref)
-        result = output.squeeze().cpu().numpy()
-        return result
-    finally:
-        # Clean up largest tensors if they were created
-        if pred_aln_trg is not None:
-            del pred_aln_trg
-        if asr is not None:
-            del asr
-
-
-def length_to_mask(lengths: torch.Tensor) -> torch.Tensor:
-    """Create attention mask from lengths.
-
-    Args:
-        lengths: Sequence lengths
-
-    Returns:
-        Boolean mask tensor
-    """
-    max_len = lengths.max()
-    mask = torch.arange(max_len, device=lengths.device)[None, :].expand(
-        lengths.shape[0], -1
-    )
-    if lengths.dtype != mask.dtype:
-        mask = mask.to(dtype=lengths.dtype)
-    return mask + 1 > lengths[:, None]
-
-
-class PyTorchCPUBackend(BaseModelBackend):
-    """PyTorch CPU inference backend."""
-
-    def __init__(self):
-        """Initialize CPU backend."""
-        super().__init__()
-        self._device = "cpu"
-        self._model: Optional[torch.nn.Module] = None
-
-        # Configure PyTorch CPU settings
-        config = model_config.pytorch_cpu
-        if config.num_threads > 0:
-            torch.set_num_threads(config.num_threads)
-        if config.pin_memory:
-            torch.set_default_tensor_type(torch.FloatTensor)
-
-    async def load_model(self, path: str) -> None:
-        """Load PyTorch model.
-
-        Args:
-            path: Path to model file
-
-        Raises:
-            RuntimeError: If model loading fails
-        """
-        try:
-            # Get verified model path
-            model_path = await paths.get_model_path(path)
-
-            logger.info(f"Loading PyTorch model on CPU: {model_path}")
-            self._model = await build_model(model_path, self._device)
-
-        except Exception as e:
-            raise RuntimeError(f"Failed to load PyTorch model: {e}")
-
-    def generate(
-        self,
-        tokens: list[int],
-        voice: torch.Tensor,
-        speed: float = 1.0
-    ) -> np.ndarray:
-        """Generate audio using CPU model.
-
-        Args:
-            tokens: Input token IDs
-            voice: Voice embedding tensor
-            speed: Speed multiplier
-
-        Returns:
-            Generated audio samples
-
-        Raises:
-            RuntimeError: If generation fails
-        """
-        if not self.is_loaded:
-            raise RuntimeError("Model not loaded")
-
-        try:
-            # Prepare input
-            ref_s = voice[len(tokens)].clone()
-
-            # Generate audio
-            return forward(self._model, tokens, ref_s, speed)
-
-        except Exception as e:
-            raise RuntimeError(f"Generation failed: {e}")
-        finally:
-            # Clean up memory
-            gc.collect()