Refactor: Consolidate PyTorch CPU and GPU backends into a single PyTorchBackend class; remove obsolete files

remsky 2025-01-25 13:33:42 -07:00
parent 3547d95ee6
commit 00497f8872
4 changed files with 258 additions and 426 deletions
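
At a glance, the change collapses ModelManager's four-way backend registry into three entries: device selection for PyTorch now happens inside the backend itself rather than through separate CPU/GPU classes. An illustrative before/after of the registry keys (summary only, names taken from the diffs below):

    # Backend registry keys managed by ModelManager (illustrative summary)
    BEFORE = {"onnx_cpu", "onnx_gpu", "pytorch_cpu", "pytorch_gpu"}
    AFTER = {"onnx_cpu", "onnx_gpu", "pytorch"}  # PyTorchBackend resolves cuda vs. cpu itself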

File: __init__.py (package exports)

@@ -4,8 +4,7 @@ from .base import BaseModelBackend
 from .model_manager import ModelManager, get_manager
 from .onnx_cpu import ONNXCPUBackend
 from .onnx_gpu import ONNXGPUBackend
-from .pytorch_cpu import PyTorchCPUBackend
-from .pytorch_gpu import PyTorchGPUBackend
+from .pytorch_backend import PyTorchBackend
 
 __all__ = [
     'BaseModelBackend',
@@ -13,6 +12,5 @@ __all__ = [
     'get_manager',
     'ONNXCPUBackend',
     'ONNXGPUBackend',
-    'PyTorchCPUBackend',
-    'PyTorchGPUBackend',
+    'PyTorchBackend',
 ]

File: model_manager.py

@@ -12,8 +12,7 @@ from ..core.model_config import ModelConfig, model_config
 from .base import BaseModelBackend
 from .onnx_cpu import ONNXCPUBackend
 from .onnx_gpu import ONNXGPUBackend
-from .pytorch_cpu import PyTorchCPUBackend
-from .pytorch_gpu import PyTorchGPUBackend
+from .pytorch_backend import PyTorchBackend
 from .session_pool import CPUSessionPool, StreamingSessionPool
@@ -63,18 +62,18 @@ class ModelManager:
                 self._current_backend = 'onnx_gpu'
                 logger.info("Initialized new ONNX GPU backend")
             else:
-                self._backends['pytorch_gpu'] = PyTorchGPUBackend()
-                self._current_backend = 'pytorch_gpu'
-                logger.info("Initialized new PyTorch GPU backend")
+                self._backends['pytorch'] = PyTorchBackend()
+                self._current_backend = 'pytorch'
+                logger.info("Initialized new PyTorch backend on GPU")
         else:
             if settings.use_onnx:
                 self._backends['onnx_cpu'] = ONNXCPUBackend()
                 self._current_backend = 'onnx_cpu'
                 logger.info("Initialized new ONNX CPU backend")
             else:
-                self._backends['pytorch_cpu'] = PyTorchCPUBackend()
-                self._current_backend = 'pytorch_cpu'
-                logger.info("Initialized new PyTorch CPU backend")
+                self._backends['pytorch'] = PyTorchBackend()
+                self._current_backend = 'pytorch'
+                logger.info("Initialized new PyTorch backend on CPU")
 
         # Initialize locks for each backend
         for backend in self._backends:
@@ -95,10 +94,10 @@ class ModelManager:
         """
         try:
             # Determine backend type based on settings
-            if settings.use_gpu and torch.cuda.is_available():
-                backend_type = 'pytorch_gpu' if not settings.use_onnx else 'onnx_gpu'
+            if settings.use_onnx:
+                backend_type = 'onnx_gpu' if settings.use_gpu and torch.cuda.is_available() else 'onnx_cpu'
             else:
-                backend_type = 'pytorch_cpu' if not settings.use_onnx else 'onnx_cpu'
+                backend_type = 'pytorch'
 
             # Get backend
             backend = self.get_backend(backend_type)
@@ -167,13 +166,10 @@ class ModelManager:
         Returns:
             Backend type to use
         """
-        has_gpu = settings.use_gpu and torch.cuda.is_available()
-
        # If ONNX is preferred or model is ONNX format
         if settings.use_onnx or model_path.lower().endswith('.onnx'):
-            return 'onnx_gpu' if has_gpu else 'onnx_cpu'
-        else:
-            return 'pytorch_gpu' if has_gpu else 'pytorch_cpu'
+            return 'onnx_gpu' if settings.use_gpu and torch.cuda.is_available() else 'onnx_cpu'
+        return 'pytorch'
 
     async def load_model(
         self,
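
Condensed, the selection rules above reduce to a single decision: ONNX keeps separate CPU and GPU backends, while all PyTorch inference routes to the one 'pytorch' key. A standalone sketch of that logic (illustrative only, not part of the commit; it assumes nothing beyond the boolean settings flags shown in the diff):

    import torch


    def determine_backend(use_onnx: bool, use_gpu: bool, model_path: str = "") -> str:
        """Illustrative mirror of ModelManager's new backend selection."""
        if use_onnx or model_path.lower().endswith(".onnx"):
            # ONNX still distinguishes CPU and GPU backends
            return "onnx_gpu" if use_gpu and torch.cuda.is_available() else "onnx_cpu"
        # PyTorch is one backend; PyTorchBackend resolves the device at construction
        return "pytorch"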

File: pytorch_gpu.py → pytorch_backend.py

@@ -1,7 +1,9 @@
-"""GPU-based PyTorch inference backend."""
+"""PyTorch inference backend with environment-based configuration."""
 
 import gc
+from contextlib import nullcontext
 from typing import Optional
 
 import numpy as np
 import torch
@@ -10,11 +12,12 @@ from loguru import logger
 from ..builds.models import build_model
 from ..core import paths
 from ..core.model_config import model_config
+from ..core.config import settings
 from .base import BaseModelBackend
 
 
 class CUDAStreamManager:
-    """CUDA stream manager."""
+    """CUDA stream manager for GPU operations."""
 
     def __init__(self, num_streams: int):
         """Initialize stream manager.
@@ -42,30 +45,34 @@ def forward(
     tokens: list[int],
     ref_s: torch.Tensor,
     speed: float,
-    stream: Optional[torch.cuda.Stream] = None
+    stream: Optional[torch.cuda.Stream] = None,
 ) -> np.ndarray:
     """Forward pass through model.
 
     Args:
         model: PyTorch model
         tokens: Input tokens
-        ref_s: Reference signal (shape: [1, n_features])
+        ref_s: Reference signal
         speed: Speed multiplier
-        stream: Optional CUDA stream
+        stream: Optional CUDA stream (GPU only)
 
     Returns:
         Generated audio
     """
     device = ref_s.device
 
-    # Use provided stream or default
-    with torch.cuda.stream(stream) if stream else torch.cuda.default_stream():
+    # Use provided stream or default for GPU
+    context = (
+        torch.cuda.stream(stream) if stream and device.type == "cuda" else nullcontext()
+    )
+    with context:
         # Initial tensor setup
         tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
         input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
         text_mask = length_to_mask(input_lengths).to(device)
 
-        # Split reference signals (style_dim=128 from config)
+        # Split reference signals
         style_dim = 128
         s_ref = ref_s[:, :style_dim].clone().to(device)
         s_content = ref_s[:, style_dim:].clone().to(device)
@@ -85,10 +92,12 @@ def forward(
         del duration, x
 
         # Alignment matrix construction
-        pred_aln_trg = torch.zeros(input_lengths.item(), pred_dur.sum().item(), device=device)
+        pred_aln_trg = torch.zeros(
+            input_lengths.item(), pred_dur.sum().item(), device=device
+        )
         c_frame = 0
         for i in range(pred_aln_trg.size(0)):
-            pred_aln_trg[i, c_frame:c_frame + pred_dur[0, i].item()] = 1
+            pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
             c_frame += pred_dur[0, i].item()
         pred_aln_trg = pred_aln_trg.unsqueeze(0)
@@ -108,7 +117,7 @@ def forward(
         output = model.decoder(asr, F0_pred, N_pred, s_ref)
 
         # Ensure operation completion if using custom stream
-        if stream:
+        if stream and device.type == "cuda":
             stream.synchronize()
 
         return output.squeeze().cpu().numpy()
@@ -117,32 +126,40 @@ def forward(
 def length_to_mask(lengths: torch.Tensor) -> torch.Tensor:
     """Create attention mask from lengths."""
     max_len = lengths.max()
-    mask = torch.arange(max_len, device=lengths.device)[None, :].expand(lengths.shape[0], -1)
+    mask = torch.arange(max_len, device=lengths.device)[None, :].expand(
+        lengths.shape[0], -1
+    )
     if lengths.dtype != mask.dtype:
         mask = mask.to(dtype=lengths.dtype)
     return mask + 1 > lengths[:, None]
 
 
-class PyTorchGPUBackend(BaseModelBackend):
-    """PyTorch GPU inference backend."""
+class PyTorchBackend(BaseModelBackend):
+    """PyTorch inference backend with environment-based configuration."""
 
     def __init__(self):
-        """Initialize GPU backend."""
+        """Initialize backend based on environment configuration."""
         super().__init__()
-        from ..core.config import settings
-        if not (settings.use_gpu and torch.cuda.is_available()):
-            raise RuntimeError("GPU backend requires GPU support and use_gpu=True")
-        self._device = "cuda"  # Device is enforced by backend selection in model_manager
+
+        # Configure device based on settings
+        self._device = (
+            "cuda" if settings.use_gpu and torch.cuda.is_available() else "cpu"
+        )
         self._model: Optional[torch.nn.Module] = None
 
-        # Configure GPU settings
-        config = model_config.pytorch_gpu
-        if config.sync_cuda:
-            torch.cuda.synchronize()
-        torch.cuda.set_device(config.device_id)
-
-        # Initialize stream manager
-        self._stream_manager = CUDAStreamManager(config.cuda_streams)
+        # Apply device-specific configurations
+        if self._device == "cuda":
+            config = model_config.pytorch_gpu
+            if config.sync_cuda:
+                torch.cuda.synchronize()
+            torch.cuda.set_device(config.device_id)
+            self._stream_manager = CUDAStreamManager(config.cuda_streams)
+        else:
+            config = model_config.pytorch_cpu
+            if config.num_threads > 0:
+                torch.set_num_threads(config.num_threads)
+            if config.pin_memory:
+                torch.set_default_tensor_type(torch.FloatTensor)
 
     async def load_model(self, path: str) -> None:
         """Load PyTorch model.
@@ -157,19 +174,16 @@ class PyTorchGPUBackend(BaseModelBackend):
             # Get verified model path
             model_path = await paths.get_model_path(path)
 
-            logger.info(f"Loading PyTorch model: {model_path}")
+            logger.info(f"Loading PyTorch model on {self._device}: {model_path}")
             self._model = await build_model(model_path, self._device)
 
         except Exception as e:
             raise RuntimeError(f"Failed to load PyTorch model: {e}")
 
     def generate(
-        self,
-        tokens: list[int],
-        voice: torch.Tensor,
-        speed: float = 1.0
+        self, tokens: list[int], voice: torch.Tensor, speed: float = 1.0
     ) -> np.ndarray:
-        """Generate audio using GPU model.
+        """Generate audio using model.
 
         Args:
             tokens: Input token IDs
@@ -186,40 +200,45 @@ class PyTorchGPUBackend(BaseModelBackend):
             raise RuntimeError("Model not loaded")
 
         try:
-            # Check memory and cleanup if needed
-            if self._check_memory():
-                self._clear_memory()
+            # Memory management for GPU
+            if self._device == "cuda":
+                if self._check_memory():
+                    self._clear_memory()
+                stream = self._stream_manager.get_next_stream()
+            else:
+                stream = None
 
             # Get reference style from voice pack
             ref_s = voice[len(tokens)].clone().to(self._device)
             if ref_s.dim() == 1:
-                ref_s = ref_s.unsqueeze(0)  # Add batch dimension if needed
+                ref_s = ref_s.unsqueeze(0)
 
-            # Get next available stream
-            stream = self._stream_manager.get_next_stream()
-
-            # Generate audio using stream
+            # Generate audio
             return forward(self._model, tokens, ref_s, speed, stream)
 
         except Exception as e:
             logger.error(f"Generation failed: {e}")
-            if model_config.pytorch_gpu.retry_on_oom and "out of memory" in str(e).lower():
+            if (
+                self._device == "cuda"
+                and model_config.pytorch_gpu.retry_on_oom
+                and "out of memory" in str(e).lower()
+            ):
                 self._clear_memory()
-                return self.generate(tokens, voice, speed)  # Retry once
+                return self.generate(tokens, voice, speed)
             raise
         finally:
-            if model_config.pytorch_gpu.sync_cuda:
+            if self._device == "cuda" and model_config.pytorch_gpu.sync_cuda:
                 torch.cuda.synchronize()
 
     def _check_memory(self) -> bool:
         """Check if memory usage is above threshold."""
-        if torch.cuda.is_available():
+        if self._device == "cuda":
             memory_gb = torch.cuda.memory_allocated() / 1e9
             return memory_gb > model_config.pytorch_gpu.memory_threshold
         return False
 
     def _clear_memory(self) -> None:
-        """Clear GPU memory."""
-        if torch.cuda.is_available():
+        """Clear device memory."""
+        if self._device == "cuda":
             torch.cuda.empty_cache()
         gc.collect()
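
End to end, the consolidated backend is driven the same way on CPU and GPU. A minimal usage sketch (the absolute import path, model filename, and input tensors are placeholders and not part of the commit; load_model is async as shown above):

    import asyncio

    import torch

    # Hypothetical absolute import; the package itself uses the relative form
    # shown in the diff (from .pytorch_backend import PyTorchBackend).
    from inference.pytorch_backend import PyTorchBackend


    async def synthesize(model_path: str, voice_pack: torch.Tensor, tokens: list[int]):
        backend = PyTorchBackend()            # resolves cuda vs. cpu from settings
        await backend.load_model(model_path)  # builds the model on the chosen device
        # generate() uses CUDA streams internally on GPU and a nullcontext on CPU
        return backend.generate(tokens, voice_pack, speed=1.0)


    # asyncio.run(synthesize("model.pth", voice_pack, token_ids))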

File: pytorch_cpu.py (deleted)

@@ -1,181 +0,0 @@
-"""CPU-based PyTorch inference backend."""
-
-import gc
-from typing import Optional
-
-import numpy as np
-import torch
-from loguru import logger
-
-from ..builds.models import build_model
-from ..core import paths
-from ..core.model_config import model_config
-from .base import BaseModelBackend
-
-
-@torch.no_grad()
-def forward(model: torch.nn.Module, tokens: list[int], ref_s: torch.Tensor, speed: float) -> np.ndarray:
-    """Forward pass through model with memory management.
-
-    Args:
-        model: PyTorch model
-        tokens: Input tokens
-        ref_s: Reference signal
-        speed: Speed multiplier
-
-    Returns:
-        Generated audio
-    """
-    device = ref_s.device
-    pred_aln_trg = None
-    asr = None
-
-    try:
-        # Initial tensor setup
-        tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
-        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
-        text_mask = length_to_mask(input_lengths).to(device)
-
-        # Split reference signals
-        s_content = ref_s[:, 128:].clone().to(device)
-        s_ref = ref_s[:, :128].clone().to(device)
-
-        # BERT and encoder pass
-        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
-        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
-
-        # Predictor forward pass
-        d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
-        x, _ = model.predictor.lstm(d)
-
-        # Duration prediction
-        duration = model.predictor.duration_proj(x)
-        duration = torch.sigmoid(duration).sum(axis=-1) / speed
-        pred_dur = torch.round(duration).clamp(min=1).long()
-        del duration, x  # Free large intermediates
-
-        # Alignment matrix construction
-        pred_aln_trg = torch.zeros(
-            input_lengths.item(),
-            pred_dur.sum().item(),
-            device=device
-        )
-        c_frame = 0
-        for i in range(pred_aln_trg.size(0)):
-            pred_aln_trg[i, c_frame:c_frame + pred_dur[0, i].item()] = 1
-            c_frame += pred_dur[0, i].item()
-        pred_aln_trg = pred_aln_trg.unsqueeze(0)
-
-        # Matrix multiplications with cleanup
-        en = d.transpose(-1, -2) @ pred_aln_trg
-        del d
-        F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
-        del en
-
-        # Final text encoding and decoding
-        t_en = model.text_encoder(tokens, input_lengths, text_mask)
-        asr = t_en @ pred_aln_trg
-        del t_en
-
-        # Generate output
-        output = model.decoder(asr, F0_pred, N_pred, s_ref)
-        result = output.squeeze().cpu().numpy()
-        return result
-    finally:
-        # Clean up largest tensors if they were created
-        if pred_aln_trg is not None:
-            del pred_aln_trg
-        if asr is not None:
-            del asr
-
-
-def length_to_mask(lengths: torch.Tensor) -> torch.Tensor:
-    """Create attention mask from lengths.
-
-    Args:
-        lengths: Sequence lengths
-
-    Returns:
-        Boolean mask tensor
-    """
-    max_len = lengths.max()
-    mask = torch.arange(max_len, device=lengths.device)[None, :].expand(
-        lengths.shape[0], -1
-    )
-    if lengths.dtype != mask.dtype:
-        mask = mask.to(dtype=lengths.dtype)
-    return mask + 1 > lengths[:, None]
-
-
-class PyTorchCPUBackend(BaseModelBackend):
-    """PyTorch CPU inference backend."""
-
-    def __init__(self):
-        """Initialize CPU backend."""
-        super().__init__()
-        self._device = "cpu"
-        self._model: Optional[torch.nn.Module] = None
-
-        # Configure PyTorch CPU settings
-        config = model_config.pytorch_cpu
-        if config.num_threads > 0:
-            torch.set_num_threads(config.num_threads)
-        if config.pin_memory:
-            torch.set_default_tensor_type(torch.FloatTensor)
-
-    async def load_model(self, path: str) -> None:
-        """Load PyTorch model.
-
-        Args:
-            path: Path to model file
-
-        Raises:
-            RuntimeError: If model loading fails
-        """
-        try:
-            # Get verified model path
-            model_path = await paths.get_model_path(path)
-
-            logger.info(f"Loading PyTorch model on CPU: {model_path}")
-            self._model = await build_model(model_path, self._device)
-
-        except Exception as e:
-            raise RuntimeError(f"Failed to load PyTorch model: {e}")
-
-    def generate(
-        self,
-        tokens: list[int],
-        voice: torch.Tensor,
-        speed: float = 1.0
-    ) -> np.ndarray:
-        """Generate audio using CPU model.
-
-        Args:
-            tokens: Input token IDs
-            voice: Voice embedding tensor
-            speed: Speed multiplier
-
-        Returns:
-            Generated audio samples
-
-        Raises:
-            RuntimeError: If generation fails
-        """
-        if not self.is_loaded:
-            raise RuntimeError("Model not loaded")
-
-        try:
-            # Prepare input
-            ref_s = voice[len(tokens)].clone()
-
-            # Generate audio
-            return forward(self._model, tokens, ref_s, speed)
-
-        except Exception as e:
-            raise RuntimeError(f"Generation failed: {e}")
-        finally:
-            # Clean up memory
-            gc.collect()