Refactor: Consolidate PyTorch CPU and GPU backends into a single PyTorchBackend class; remove obsolete files

This commit is contained in:
remsky 2025-01-25 13:33:42 -07:00
parent 3547d95ee6
commit 00497f8872
4 changed files with 258 additions and 426 deletions
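In outline, the commit replaces two device-specific classes with a single class that resolves its device once, at construction time; a minimal sketch of the pattern (names taken from the diffs below; the full implementation is in the third file):

class PyTorchBackend(BaseModelBackend):
    def __init__(self):
        super().__init__()
        # One class, one decision point: device is picked from settings at init.
        self._device = "cuda" if settings.use_gpu and torch.cuda.is_available() else "cpu"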

View file

@@ -4,8 +4,7 @@ from .base import BaseModelBackend
 from .model_manager import ModelManager, get_manager
 from .onnx_cpu import ONNXCPUBackend
 from .onnx_gpu import ONNXGPUBackend
-from .pytorch_cpu import PyTorchCPUBackend
-from .pytorch_gpu import PyTorchGPUBackend
+from .pytorch_backend import PyTorchBackend
 
 __all__ = [
     'BaseModelBackend',
@@ -13,6 +12,5 @@ __all__ = [
     'get_manager',
     'ONNXCPUBackend',
    'ONNXGPUBackend',
-    'PyTorchCPUBackend',
-    'PyTorchGPUBackend',
+    'PyTorchBackend',
 ]
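With the consolidated export, callers import one class regardless of device; a hypothetical usage sketch (the model path, token IDs, and voice-pack shape below are placeholders, not values from this repo):

import torch

from .pytorch_backend import PyTorchBackend

async def demo() -> None:
    backend = PyTorchBackend()              # resolves "cuda" vs "cpu" from settings
    await backend.load_model("model.pth")   # placeholder path
    tokens = [10, 20, 30]                   # placeholder token IDs
    voice = torch.zeros(512, 256)           # placeholder voice pack; row len(tokens) holds the reference style
    audio = backend.generate(tokens, voice, speed=1.0)  # numpy array of samples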

View file

@@ -12,8 +12,7 @@ from ..core.model_config import ModelConfig, model_config
 from .base import BaseModelBackend
 from .onnx_cpu import ONNXCPUBackend
 from .onnx_gpu import ONNXGPUBackend
-from .pytorch_cpu import PyTorchCPUBackend
-from .pytorch_gpu import PyTorchGPUBackend
+from .pytorch_backend import PyTorchBackend
 from .session_pool import CPUSessionPool, StreamingSessionPool
@@ -63,18 +62,18 @@ class ModelManager:
                 self._current_backend = 'onnx_gpu'
                 logger.info("Initialized new ONNX GPU backend")
             else:
-                self._backends['pytorch_gpu'] = PyTorchGPUBackend()
-                self._current_backend = 'pytorch_gpu'
-                logger.info("Initialized new PyTorch GPU backend")
+                self._backends['pytorch'] = PyTorchBackend()
+                self._current_backend = 'pytorch'
+                logger.info("Initialized new PyTorch backend on GPU")
         else:
             if settings.use_onnx:
                 self._backends['onnx_cpu'] = ONNXCPUBackend()
                 self._current_backend = 'onnx_cpu'
                 logger.info("Initialized new ONNX CPU backend")
             else:
-                self._backends['pytorch_cpu'] = PyTorchCPUBackend()
-                self._current_backend = 'pytorch_cpu'
-                logger.info("Initialized new PyTorch CPU backend")
+                self._backends['pytorch'] = PyTorchBackend()
+                self._current_backend = 'pytorch'
+                logger.info("Initialized new PyTorch backend on CPU")
 
         # Initialize locks for each backend
         for backend in self._backends:
@@ -95,10 +94,10 @@ class ModelManager:
         """
         try:
            # Determine backend type based on settings
-            if settings.use_gpu and torch.cuda.is_available():
-                backend_type = 'pytorch_gpu' if not settings.use_onnx else 'onnx_gpu'
+            if settings.use_onnx:
+                backend_type = 'onnx_gpu' if settings.use_gpu and torch.cuda.is_available() else 'onnx_cpu'
             else:
-                backend_type = 'pytorch_cpu' if not settings.use_onnx else 'onnx_cpu'
+                backend_type = 'pytorch'
 
             # Get backend
             backend = self.get_backend(backend_type)
@@ -167,13 +166,10 @@ class ModelManager:
         Returns:
             Backend type to use
         """
-        has_gpu = settings.use_gpu and torch.cuda.is_available()
-
         # If ONNX is preferred or model is ONNX format
         if settings.use_onnx or model_path.lower().endswith('.onnx'):
-            return 'onnx_gpu' if has_gpu else 'onnx_cpu'
-        else:
-            return 'pytorch_gpu' if has_gpu else 'pytorch_cpu'
+            return 'onnx_gpu' if settings.use_gpu and torch.cuda.is_available() else 'onnx_cpu'
+        return 'pytorch'
 
     async def load_model(
         self,
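Taken together, these hunks reduce backend selection to a small decision table: ONNX keeps separate CPU/GPU entries, while PyTorch collapses to a single key and defers the device choice to the backend itself. A hypothetical helper restating that table (pick_backend and its arguments are illustrative, not part of this codebase):

def pick_backend(use_onnx: bool, gpu_available: bool) -> str:
    if use_onnx:
        return 'onnx_gpu' if gpu_available else 'onnx_cpu'
    return 'pytorch'  # PyTorchBackend resolves cuda/cpu internally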

View file

@@ -1,225 +1,244 @@
-"""GPU-based PyTorch inference backend."""
-
-import gc
-from typing import Optional
-
-import numpy as np
-import torch
-from loguru import logger
-
-from ..builds.models import build_model
-from ..core import paths
-from ..core.model_config import model_config
-from .base import BaseModelBackend
-
-
-class CUDAStreamManager:
-    """CUDA stream manager."""
-
-    def __init__(self, num_streams: int):
-        """Initialize stream manager.
-
-        Args:
-            num_streams: Number of CUDA streams
-        """
-        self.streams = [torch.cuda.Stream() for _ in range(num_streams)]
-        self._current = 0
-
-    def get_next_stream(self) -> torch.cuda.Stream:
-        """Get next available stream.
-
-        Returns:
-            CUDA stream
-        """
-        stream = self.streams[self._current]
-        self._current = (self._current + 1) % len(self.streams)
-        return stream
-
-
-@torch.no_grad()
-def forward(
-    model: torch.nn.Module,
-    tokens: list[int],
-    ref_s: torch.Tensor,
-    speed: float,
-    stream: Optional[torch.cuda.Stream] = None
-) -> np.ndarray:
-    """Forward pass through model.
-
-    Args:
-        model: PyTorch model
-        tokens: Input tokens
-        ref_s: Reference signal (shape: [1, n_features])
-        speed: Speed multiplier
-        stream: Optional CUDA stream
-
-    Returns:
-        Generated audio
-    """
-    device = ref_s.device
-
-    # Use provided stream or default
-    with torch.cuda.stream(stream) if stream else torch.cuda.default_stream():
-        # Initial tensor setup
-        tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
-        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
-        text_mask = length_to_mask(input_lengths).to(device)
-
-        # Split reference signals (style_dim=128 from config)
-        style_dim = 128
-        s_ref = ref_s[:, :style_dim].clone().to(device)
-        s_content = ref_s[:, style_dim:].clone().to(device)
-
-        # BERT and encoder pass
-        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
-        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
-
-        # Predictor forward pass
-        d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
-        x, _ = model.predictor.lstm(d)
-
-        # Duration prediction
-        duration = model.predictor.duration_proj(x)
-        duration = torch.sigmoid(duration).sum(axis=-1) / speed
-        pred_dur = torch.round(duration).clamp(min=1).long()
-        del duration, x
-
-        # Alignment matrix construction
-        pred_aln_trg = torch.zeros(input_lengths.item(), pred_dur.sum().item(), device=device)
-        c_frame = 0
-        for i in range(pred_aln_trg.size(0)):
-            pred_aln_trg[i, c_frame:c_frame + pred_dur[0, i].item()] = 1
-            c_frame += pred_dur[0, i].item()
-        pred_aln_trg = pred_aln_trg.unsqueeze(0)
-
-        # Matrix multiplications
-        en = d.transpose(-1, -2) @ pred_aln_trg
-        del d
-        F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
-        del en
-
-        # Final text encoding and decoding
-        t_en = model.text_encoder(tokens, input_lengths, text_mask)
-        asr = t_en @ pred_aln_trg
-        del t_en
-
-        # Generate output
-        output = model.decoder(asr, F0_pred, N_pred, s_ref)
-
-    # Ensure operation completion if using custom stream
-    if stream:
-        stream.synchronize()
-
-    return output.squeeze().cpu().numpy()
-
-
-def length_to_mask(lengths: torch.Tensor) -> torch.Tensor:
-    """Create attention mask from lengths."""
-    max_len = lengths.max()
-    mask = torch.arange(max_len, device=lengths.device)[None, :].expand(lengths.shape[0], -1)
-    if lengths.dtype != mask.dtype:
-        mask = mask.to(dtype=lengths.dtype)
-    return mask + 1 > lengths[:, None]
-
-
-class PyTorchGPUBackend(BaseModelBackend):
-    """PyTorch GPU inference backend."""
-
-    def __init__(self):
-        """Initialize GPU backend."""
-        super().__init__()
-        from ..core.config import settings
-        if not (settings.use_gpu and torch.cuda.is_available()):
-            raise RuntimeError("GPU backend requires GPU support and use_gpu=True")
-        self._device = "cuda"  # Device is enforced by backend selection in model_manager
-        self._model: Optional[torch.nn.Module] = None
-
-        # Configure GPU settings
-        config = model_config.pytorch_gpu
-        if config.sync_cuda:
-            torch.cuda.synchronize()
-        torch.cuda.set_device(config.device_id)
-
-        # Initialize stream manager
-        self._stream_manager = CUDAStreamManager(config.cuda_streams)
-
-    async def load_model(self, path: str) -> None:
-        """Load PyTorch model.
-
-        Args:
-            path: Path to model file
-
-        Raises:
-            RuntimeError: If model loading fails
-        """
-        try:
-            # Get verified model path
-            model_path = await paths.get_model_path(path)
-            logger.info(f"Loading PyTorch model: {model_path}")
-            self._model = await build_model(model_path, self._device)
-        except Exception as e:
-            raise RuntimeError(f"Failed to load PyTorch model: {e}")
-
-    def generate(
-        self,
-        tokens: list[int],
-        voice: torch.Tensor,
-        speed: float = 1.0
-    ) -> np.ndarray:
-        """Generate audio using GPU model.
-
-        Args:
-            tokens: Input token IDs
-            voice: Voice embedding tensor
-            speed: Speed multiplier
-
-        Returns:
-            Generated audio samples
-
-        Raises:
-            RuntimeError: If generation fails
-        """
-        if not self.is_loaded:
-            raise RuntimeError("Model not loaded")
-
-        try:
-            # Check memory and cleanup if needed
-            if self._check_memory():
-                self._clear_memory()
-
-            # Get reference style from voice pack
-            ref_s = voice[len(tokens)].clone().to(self._device)
-            if ref_s.dim() == 1:
-                ref_s = ref_s.unsqueeze(0)  # Add batch dimension if needed
-
-            # Get next available stream
-            stream = self._stream_manager.get_next_stream()
-
-            # Generate audio using stream
-            return forward(self._model, tokens, ref_s, speed, stream)
-        except Exception as e:
-            logger.error(f"Generation failed: {e}")
-            if model_config.pytorch_gpu.retry_on_oom and "out of memory" in str(e).lower():
-                self._clear_memory()
-                return self.generate(tokens, voice, speed)  # Retry once
-            raise
-        finally:
-            if model_config.pytorch_gpu.sync_cuda:
-                torch.cuda.synchronize()
-
-    def _check_memory(self) -> bool:
-        """Check if memory usage is above threshold."""
-        if torch.cuda.is_available():
-            memory_gb = torch.cuda.memory_allocated() / 1e9
-            return memory_gb > model_config.pytorch_gpu.memory_threshold
-        return False
-
-    def _clear_memory(self) -> None:
-        """Clear GPU memory."""
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-            gc.collect()
"""PyTorch inference backend with environment-based configuration."""
import gc
from typing import Optional
from contextlib import nullcontext
from typing import Optional
import numpy as np
import torch
from loguru import logger
from ..builds.models import build_model
from ..core import paths
from ..core.model_config import model_config
from ..core.config import settings
from .base import BaseModelBackend
class CUDAStreamManager:
"""CUDA stream manager for GPU operations."""
def __init__(self, num_streams: int):
"""Initialize stream manager.
Args:
num_streams: Number of CUDA streams
"""
self.streams = [torch.cuda.Stream() for _ in range(num_streams)]
self._current = 0
def get_next_stream(self) -> torch.cuda.Stream:
"""Get next available stream.
Returns:
CUDA stream
"""
stream = self.streams[self._current]
self._current = (self._current + 1) % len(self.streams)
return stream
+
+
+@torch.no_grad()
+def forward(
+    model: torch.nn.Module,
+    tokens: list[int],
+    ref_s: torch.Tensor,
+    speed: float,
+    stream: Optional[torch.cuda.Stream] = None,
+) -> np.ndarray:
+    """Forward pass through model.
+
+    Args:
+        model: PyTorch model
+        tokens: Input tokens
+        ref_s: Reference signal
+        speed: Speed multiplier
+        stream: Optional CUDA stream (GPU only)
+
+    Returns:
+        Generated audio
+    """
+    device = ref_s.device
+
+    # Use provided stream or default for GPU
+    context = (
+        torch.cuda.stream(stream) if stream and device.type == "cuda" else nullcontext()
+    )
+    with context:
+        # Initial tensor setup
+        tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
+        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
+        text_mask = length_to_mask(input_lengths).to(device)
+
+        # Split reference signals
+        style_dim = 128
+        s_ref = ref_s[:, :style_dim].clone().to(device)
+        s_content = ref_s[:, style_dim:].clone().to(device)
+
+        # BERT and encoder pass
+        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
+        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
+
+        # Predictor forward pass
+        d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
+        x, _ = model.predictor.lstm(d)
+
+        # Duration prediction
+        duration = model.predictor.duration_proj(x)
+        duration = torch.sigmoid(duration).sum(axis=-1) / speed
+        pred_dur = torch.round(duration).clamp(min=1).long()
+        del duration, x
+
+        # Alignment matrix construction
+        pred_aln_trg = torch.zeros(
+            input_lengths.item(), pred_dur.sum().item(), device=device
+        )
+        c_frame = 0
+        for i in range(pred_aln_trg.size(0)):
+            pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
+            c_frame += pred_dur[0, i].item()
+        pred_aln_trg = pred_aln_trg.unsqueeze(0)
+
+        # Matrix multiplications
+        en = d.transpose(-1, -2) @ pred_aln_trg
+        del d
+        F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
+        del en
+
+        # Final text encoding and decoding
+        t_en = model.text_encoder(tokens, input_lengths, text_mask)
+        asr = t_en @ pred_aln_trg
+        del t_en
+
+        # Generate output
+        output = model.decoder(asr, F0_pred, N_pred, s_ref)
+
+    # Ensure operation completion if using custom stream
+    if stream and device.type == "cuda":
+        stream.synchronize()
+
+    return output.squeeze().cpu().numpy()
+
+
+def length_to_mask(lengths: torch.Tensor) -> torch.Tensor:
+    """Create attention mask from lengths."""
+    max_len = lengths.max()
+    mask = torch.arange(max_len, device=lengths.device)[None, :].expand(
+        lengths.shape[0], -1
+    )
+    if lengths.dtype != mask.dtype:
+        mask = mask.to(dtype=lengths.dtype)
+    return mask + 1 > lengths[:, None]
+
+
+class PyTorchBackend(BaseModelBackend):
+    """PyTorch inference backend with environment-based configuration."""
+
+    def __init__(self):
+        """Initialize backend based on environment configuration."""
+        super().__init__()
+
+        # Configure device based on settings
+        self._device = (
+            "cuda" if settings.use_gpu and torch.cuda.is_available() else "cpu"
+        )
+        self._model: Optional[torch.nn.Module] = None
+
+        # Apply device-specific configurations
+        if self._device == "cuda":
+            config = model_config.pytorch_gpu
+            if config.sync_cuda:
+                torch.cuda.synchronize()
+            torch.cuda.set_device(config.device_id)
+            self._stream_manager = CUDAStreamManager(config.cuda_streams)
+        else:
+            config = model_config.pytorch_cpu
+            if config.num_threads > 0:
+                torch.set_num_threads(config.num_threads)
+            if config.pin_memory:
+                torch.set_default_tensor_type(torch.FloatTensor)
+
+    async def load_model(self, path: str) -> None:
+        """Load PyTorch model.
+
+        Args:
+            path: Path to model file
+
+        Raises:
+            RuntimeError: If model loading fails
+        """
+        try:
+            # Get verified model path
+            model_path = await paths.get_model_path(path)
+            logger.info(f"Loading PyTorch model on {self._device}: {model_path}")
+            self._model = await build_model(model_path, self._device)
+        except Exception as e:
+            raise RuntimeError(f"Failed to load PyTorch model: {e}")
+
+    def generate(
+        self, tokens: list[int], voice: torch.Tensor, speed: float = 1.0
+    ) -> np.ndarray:
+        """Generate audio using model.
+
+        Args:
+            tokens: Input token IDs
+            voice: Voice embedding tensor
+            speed: Speed multiplier
+
+        Returns:
+            Generated audio samples
+
+        Raises:
+            RuntimeError: If generation fails
+        """
+        if not self.is_loaded:
+            raise RuntimeError("Model not loaded")
+
+        try:
+            # Memory management for GPU
+            if self._device == "cuda":
+                if self._check_memory():
+                    self._clear_memory()
+                stream = self._stream_manager.get_next_stream()
+            else:
+                stream = None
+
+            # Get reference style from voice pack
+            ref_s = voice[len(tokens)].clone().to(self._device)
+            if ref_s.dim() == 1:
+                ref_s = ref_s.unsqueeze(0)
+
+            # Generate audio
+            return forward(self._model, tokens, ref_s, speed, stream)
+        except Exception as e:
+            logger.error(f"Generation failed: {e}")
+            if (
+                self._device == "cuda"
+                and model_config.pytorch_gpu.retry_on_oom
+                and "out of memory" in str(e).lower()
+            ):
+                self._clear_memory()
+                return self.generate(tokens, voice, speed)
+            raise
+        finally:
+            if self._device == "cuda" and model_config.pytorch_gpu.sync_cuda:
+                torch.cuda.synchronize()
+
+    def _check_memory(self) -> bool:
+        """Check if memory usage is above threshold."""
+        if self._device == "cuda":
+            memory_gb = torch.cuda.memory_allocated() / 1e9
+            return memory_gb > model_config.pytorch_gpu.memory_threshold
+        return False
+
+    def _clear_memory(self) -> None:
+        """Clear device memory."""
+        if self._device == "cuda":
+            torch.cuda.empty_cache()
+            gc.collect()
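The nullcontext dispatch above is what lets a single forward function serve both devices; a self-contained sketch of the same pattern, independent of this repo's classes:

from contextlib import nullcontext

import torch

def run_on(device: str) -> torch.Tensor:
    # Use a dedicated CUDA stream on GPU; a no-op context on CPU.
    stream = torch.cuda.Stream() if device == "cuda" else None
    ctx = torch.cuda.stream(stream) if stream is not None else nullcontext()
    with ctx:
        x = torch.ones(3, device=device) * 2  # work queued on the stream (or run eagerly on CPU)
    if stream is not None:
        stream.synchronize()  # wait for queued kernels before reading the result
    return x.cpu()

print(run_on("cuda" if torch.cuda.is_available() else "cpu"))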

View file

@@ -1,181 +0,0 @@
-"""CPU-based PyTorch inference backend."""
-
-import gc
-from typing import Optional
-
-import numpy as np
-import torch
-from loguru import logger
-
-from ..builds.models import build_model
-from ..core import paths
-from ..core.model_config import model_config
-from .base import BaseModelBackend
-
-
-@torch.no_grad()
-def forward(model: torch.nn.Module, tokens: list[int], ref_s: torch.Tensor, speed: float) -> np.ndarray:
-    """Forward pass through model with memory management.
-
-    Args:
-        model: PyTorch model
-        tokens: Input tokens
-        ref_s: Reference signal
-        speed: Speed multiplier
-
-    Returns:
-        Generated audio
-    """
-    device = ref_s.device
-    pred_aln_trg = None
-    asr = None
-
-    try:
-        # Initial tensor setup
-        tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
-        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
-        text_mask = length_to_mask(input_lengths).to(device)
-
-        # Split reference signals
-        s_content = ref_s[:, 128:].clone().to(device)
-        s_ref = ref_s[:, :128].clone().to(device)
-
-        # BERT and encoder pass
-        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
-        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
-
-        # Predictor forward pass
-        d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
-        x, _ = model.predictor.lstm(d)
-
-        # Duration prediction
-        duration = model.predictor.duration_proj(x)
-        duration = torch.sigmoid(duration).sum(axis=-1) / speed
-        pred_dur = torch.round(duration).clamp(min=1).long()
-        del duration, x  # Free large intermediates
-
-        # Alignment matrix construction
-        pred_aln_trg = torch.zeros(
-            input_lengths.item(),
-            pred_dur.sum().item(),
-            device=device
-        )
-        c_frame = 0
-        for i in range(pred_aln_trg.size(0)):
-            pred_aln_trg[i, c_frame:c_frame + pred_dur[0, i].item()] = 1
-            c_frame += pred_dur[0, i].item()
-        pred_aln_trg = pred_aln_trg.unsqueeze(0)
-
-        # Matrix multiplications with cleanup
-        en = d.transpose(-1, -2) @ pred_aln_trg
-        del d
-        F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
-        del en
-
-        # Final text encoding and decoding
-        t_en = model.text_encoder(tokens, input_lengths, text_mask)
-        asr = t_en @ pred_aln_trg
-        del t_en
-
-        # Generate output
-        output = model.decoder(asr, F0_pred, N_pred, s_ref)
-        result = output.squeeze().cpu().numpy()
-        return result
-    finally:
-        # Clean up largest tensors if they were created
-        if pred_aln_trg is not None:
-            del pred_aln_trg
-        if asr is not None:
-            del asr
-
-
-def length_to_mask(lengths: torch.Tensor) -> torch.Tensor:
-    """Create attention mask from lengths.
-
-    Args:
-        lengths: Sequence lengths
-
-    Returns:
-        Boolean mask tensor
-    """
-    max_len = lengths.max()
-    mask = torch.arange(max_len, device=lengths.device)[None, :].expand(
-        lengths.shape[0], -1
-    )
-    if lengths.dtype != mask.dtype:
-        mask = mask.to(dtype=lengths.dtype)
-    return mask + 1 > lengths[:, None]
-
-
-class PyTorchCPUBackend(BaseModelBackend):
-    """PyTorch CPU inference backend."""
-
-    def __init__(self):
-        """Initialize CPU backend."""
-        super().__init__()
-        self._device = "cpu"
-        self._model: Optional[torch.nn.Module] = None
-
-        # Configure PyTorch CPU settings
-        config = model_config.pytorch_cpu
-        if config.num_threads > 0:
-            torch.set_num_threads(config.num_threads)
-        if config.pin_memory:
-            torch.set_default_tensor_type(torch.FloatTensor)
-
-    async def load_model(self, path: str) -> None:
-        """Load PyTorch model.
-
-        Args:
-            path: Path to model file
-
-        Raises:
-            RuntimeError: If model loading fails
-        """
-        try:
-            # Get verified model path
-            model_path = await paths.get_model_path(path)
-            logger.info(f"Loading PyTorch model on CPU: {model_path}")
-            self._model = await build_model(model_path, self._device)
-        except Exception as e:
-            raise RuntimeError(f"Failed to load PyTorch model: {e}")
-
-    def generate(
-        self,
-        tokens: list[int],
-        voice: torch.Tensor,
-        speed: float = 1.0
-    ) -> np.ndarray:
-        """Generate audio using CPU model.
-
-        Args:
-            tokens: Input token IDs
-            voice: Voice embedding tensor
-            speed: Speed multiplier
-
-        Returns:
-            Generated audio samples
-
-        Raises:
-            RuntimeError: If generation fails
-        """
-        if not self.is_loaded:
-            raise RuntimeError("Model not loaded")
-
-        try:
-            # Prepare input
-            ref_s = voice[len(tokens)].clone()
-
-            # Generate audio
-            return forward(self._model, tokens, ref_s, speed)
-        except Exception as e:
-            raise RuntimeError(f"Generation failed: {e}")
-        finally:
-            # Clean up memory
-            gc.collect()
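For reference, the padding-mask convention shared by the deleted file and the new backend, worked on a tiny input (True marks positions past each sequence's length):

lengths = torch.tensor([3, 5])
print(length_to_mask(lengths))
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])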