From 00497f8872aaaae46f14e8373e02de2c743d3adf Mon Sep 17 00:00:00 2001
From: remsky
Date: Sat, 25 Jan 2025 13:33:42 -0700
Subject: [PATCH] Refactor: Consolidate PyTorch CPU and GPU backends into a single PyTorchBackend class; remove obsolete files

---
 api/src/inference/__init__.py                 |   6 +-
 api/src/inference/model_manager.py            |  28 +-
 .../{pytorch_gpu.py => pytorch_backend.py}    | 469 +++++++++---------
 api/src/inference/pytorch_cpu.py              | 181 -------
 4 files changed, 258 insertions(+), 426 deletions(-)
 rename api/src/inference/{pytorch_gpu.py => pytorch_backend.py} (63%)
 delete mode 100644 api/src/inference/pytorch_cpu.py

diff --git a/api/src/inference/__init__.py b/api/src/inference/__init__.py
index 98dc42c..d469b69 100644
--- a/api/src/inference/__init__.py
+++ b/api/src/inference/__init__.py
@@ -4,8 +4,7 @@ from .base import BaseModelBackend
 from .model_manager import ModelManager, get_manager
 from .onnx_cpu import ONNXCPUBackend
 from .onnx_gpu import ONNXGPUBackend
-from .pytorch_cpu import PyTorchCPUBackend
-from .pytorch_gpu import PyTorchGPUBackend
+from .pytorch_backend import PyTorchBackend
 
 __all__ = [
     'BaseModelBackend',
@@ -13,6 +12,5 @@ __all__ = [
     'get_manager',
     'ONNXCPUBackend',
     'ONNXGPUBackend',
-    'PyTorchCPUBackend',
-    'PyTorchGPUBackend',
+    'PyTorchBackend',
 ]
\ No newline at end of file
diff --git a/api/src/inference/model_manager.py b/api/src/inference/model_manager.py
index a89fc64..5920843 100644
--- a/api/src/inference/model_manager.py
+++ b/api/src/inference/model_manager.py
@@ -12,8 +12,7 @@ from ..core.model_config import ModelConfig, model_config
 from .base import BaseModelBackend
 from .onnx_cpu import ONNXCPUBackend
 from .onnx_gpu import ONNXGPUBackend
-from .pytorch_cpu import PyTorchCPUBackend
-from .pytorch_gpu import PyTorchGPUBackend
+from .pytorch_backend import PyTorchBackend
 from .session_pool import CPUSessionPool, StreamingSessionPool
 
 
@@ -63,18 +62,18 @@ class ModelManager:
                 self._current_backend = 'onnx_gpu'
                 logger.info("Initialized new ONNX GPU backend")
             else:
-                self._backends['pytorch_gpu'] = PyTorchGPUBackend()
-                self._current_backend = 'pytorch_gpu'
-                logger.info("Initialized new PyTorch GPU backend")
+                self._backends['pytorch'] = PyTorchBackend()
+                self._current_backend = 'pytorch'
+                logger.info("Initialized new PyTorch backend on GPU")
         else:
             if settings.use_onnx:
                 self._backends['onnx_cpu'] = ONNXCPUBackend()
                 self._current_backend = 'onnx_cpu'
                 logger.info("Initialized new ONNX CPU backend")
             else:
-                self._backends['pytorch_cpu'] = PyTorchCPUBackend()
-                self._current_backend = 'pytorch_cpu'
-                logger.info("Initialized new PyTorch CPU backend")
+                self._backends['pytorch'] = PyTorchBackend()
+                self._current_backend = 'pytorch'
+                logger.info("Initialized new PyTorch backend on CPU")
 
         # Initialize locks for each backend
         for backend in self._backends:
@@ -95,10 +94,10 @@ class ModelManager:
         """
         try:
             # Determine backend type based on settings
-            if settings.use_gpu and torch.cuda.is_available():
-                backend_type = 'pytorch_gpu' if not settings.use_onnx else 'onnx_gpu'
+            if settings.use_onnx:
+                backend_type = 'onnx_gpu' if settings.use_gpu and torch.cuda.is_available() else 'onnx_cpu'
             else:
-                backend_type = 'pytorch_cpu' if not settings.use_onnx else 'onnx_cpu'
+                backend_type = 'pytorch'
 
             # Get backend
             backend = self.get_backend(backend_type)
@@ -167,13 +166,10 @@ class ModelManager:
         Returns:
             Backend type to use
         """
-        has_gpu = settings.use_gpu and torch.cuda.is_available()
-
         # If ONNX is preferred or model is ONNX format
         if settings.use_onnx or model_path.lower().endswith('.onnx'):
-            return 'onnx_gpu' if has_gpu else 'onnx_cpu'
-        else:
-            return 'pytorch_gpu' if has_gpu else 'pytorch_cpu'
+            return 'onnx_gpu' if settings.use_gpu and torch.cuda.is_available() else 'onnx_cpu'
+        return 'pytorch'
 
     async def load_model(
         self,
diff --git a/api/src/inference/pytorch_gpu.py b/api/src/inference/pytorch_backend.py
similarity index 63%
rename from api/src/inference/pytorch_gpu.py
rename to api/src/inference/pytorch_backend.py
index fc07a0d..4ad5095 100644
--- a/api/src/inference/pytorch_gpu.py
+++ b/api/src/inference/pytorch_backend.py
@@ -1,225 +1,244 @@
-"""GPU-based PyTorch inference backend."""
-
-import gc
-from typing import Optional
-
-import numpy as np
-import torch
-from loguru import logger
-
-from ..builds.models import build_model
-from ..core import paths
-from ..core.model_config import model_config
-from .base import BaseModelBackend
-
-
-class CUDAStreamManager:
-    """CUDA stream manager."""
-
-    def __init__(self, num_streams: int):
-        """Initialize stream manager.
-
-        Args:
-            num_streams: Number of CUDA streams
-        """
-        self.streams = [torch.cuda.Stream() for _ in range(num_streams)]
-        self._current = 0
-
-    def get_next_stream(self) -> torch.cuda.Stream:
-        """Get next available stream.
-
-        Returns:
-            CUDA stream
-        """
-        stream = self.streams[self._current]
-        self._current = (self._current + 1) % len(self.streams)
-        return stream
-
-
-@torch.no_grad()
-def forward(
-    model: torch.nn.Module,
-    tokens: list[int],
-    ref_s: torch.Tensor,
-    speed: float,
-    stream: Optional[torch.cuda.Stream] = None
-) -> np.ndarray:
-    """Forward pass through model.
-
-    Args:
-        model: PyTorch model
-        tokens: Input tokens
-        ref_s: Reference signal (shape: [1, n_features])
-        speed: Speed multiplier
-        stream: Optional CUDA stream
-
-    Returns:
-        Generated audio
-    """
-    device = ref_s.device
-
-    # Use provided stream or default
-    with torch.cuda.stream(stream) if stream else torch.cuda.default_stream():
-        # Initial tensor setup
-        tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
-        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
-        text_mask = length_to_mask(input_lengths).to(device)
-
-        # Split reference signals (style_dim=128 from config)
-        style_dim = 128
-        s_ref = ref_s[:, :style_dim].clone().to(device)
-        s_content = ref_s[:, style_dim:].clone().to(device)
-
-        # BERT and encoder pass
-        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
-        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
-
-        # Predictor forward pass
-        d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
-        x, _ = model.predictor.lstm(d)
-
-        # Duration prediction
-        duration = model.predictor.duration_proj(x)
-        duration = torch.sigmoid(duration).sum(axis=-1) / speed
-        pred_dur = torch.round(duration).clamp(min=1).long()
-        del duration, x
-
-        # Alignment matrix construction
-        pred_aln_trg = torch.zeros(input_lengths.item(), pred_dur.sum().item(), device=device)
-        c_frame = 0
-        for i in range(pred_aln_trg.size(0)):
-            pred_aln_trg[i, c_frame:c_frame + pred_dur[0, i].item()] = 1
-            c_frame += pred_dur[0, i].item()
-        pred_aln_trg = pred_aln_trg.unsqueeze(0)
-
-        # Matrix multiplications
-        en = d.transpose(-1, -2) @ pred_aln_trg
-        del d
-
-        F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
-        del en
-
-        # Final text encoding and decoding
-        t_en = model.text_encoder(tokens, input_lengths, text_mask)
-        asr = t_en @ pred_aln_trg
-        del t_en
-
-        # Generate output
-        output = model.decoder(asr, F0_pred, N_pred, s_ref)
-
-        # Ensure operation completion if using custom stream
-        if stream:
-            stream.synchronize()
-
-        return output.squeeze().cpu().numpy()
-
-
-def length_to_mask(lengths: torch.Tensor) -> torch.Tensor:
-    """Create attention mask from lengths."""
-    max_len = lengths.max()
-    mask = torch.arange(max_len, device=lengths.device)[None, :].expand(lengths.shape[0], -1)
-    if lengths.dtype != mask.dtype:
-        mask = mask.to(dtype=lengths.dtype)
-    return mask + 1 > lengths[:, None]
-
-
-class PyTorchGPUBackend(BaseModelBackend):
-    """PyTorch GPU inference backend."""
-
-    def __init__(self):
-        """Initialize GPU backend."""
-        super().__init__()
-        from ..core.config import settings
-        if not (settings.use_gpu and torch.cuda.is_available()):
-            raise RuntimeError("GPU backend requires GPU support and use_gpu=True")
-        self._device = "cuda"  # Device is enforced by backend selection in model_manager
-        self._model: Optional[torch.nn.Module] = None
-
-        # Configure GPU settings
-        config = model_config.pytorch_gpu
-        if config.sync_cuda:
-            torch.cuda.synchronize()
-        torch.cuda.set_device(config.device_id)
-
-        # Initialize stream manager
-        self._stream_manager = CUDAStreamManager(config.cuda_streams)
-
-    async def load_model(self, path: str) -> None:
-        """Load PyTorch model.
-
-        Args:
-            path: Path to model file
-
-        Raises:
-            RuntimeError: If model loading fails
-        """
-        try:
-            # Get verified model path
-            model_path = await paths.get_model_path(path)
-
-            logger.info(f"Loading PyTorch model: {model_path}")
-            self._model = await build_model(model_path, self._device)
-
-        except Exception as e:
-            raise RuntimeError(f"Failed to load PyTorch model: {e}")
-
-    def generate(
-        self,
-        tokens: list[int],
-        voice: torch.Tensor,
-        speed: float = 1.0
-    ) -> np.ndarray:
-        """Generate audio using GPU model.
-
-        Args:
-            tokens: Input token IDs
-            voice: Voice embedding tensor
-            speed: Speed multiplier
-
-        Returns:
-            Generated audio samples
-
-        Raises:
-            RuntimeError: If generation fails
-        """
-        if not self.is_loaded:
-            raise RuntimeError("Model not loaded")
-
-        try:
-            # Check memory and cleanup if needed
-            if self._check_memory():
-                self._clear_memory()
-
-            # Get reference style from voice pack
-            ref_s = voice[len(tokens)].clone().to(self._device)
-            if ref_s.dim() == 1:
-                ref_s = ref_s.unsqueeze(0)  # Add batch dimension if needed
-
-            # Get next available stream
-            stream = self._stream_manager.get_next_stream()
-
-            # Generate audio using stream
-            return forward(self._model, tokens, ref_s, speed, stream)
-
-        except Exception as e:
-            logger.error(f"Generation failed: {e}")
-            if model_config.pytorch_gpu.retry_on_oom and "out of memory" in str(e).lower():
-                self._clear_memory()
-                return self.generate(tokens, voice, speed)  # Retry once
-            raise
-        finally:
-            if model_config.pytorch_gpu.sync_cuda:
-                torch.cuda.synchronize()
-
-    def _check_memory(self) -> bool:
-        """Check if memory usage is above threshold."""
-        if torch.cuda.is_available():
-            memory_gb = torch.cuda.memory_allocated() / 1e9
-            return memory_gb > model_config.pytorch_gpu.memory_threshold
-        return False
-
-    def _clear_memory(self) -> None:
-        """Clear GPU memory."""
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-            gc.collect()
\ No newline at end of file
+"""PyTorch inference backend with environment-based configuration."""
+
+import gc
+from typing import Optional
+from contextlib import nullcontext
+from typing import Optional
+
+import numpy as np
+import torch
+from loguru import logger
+
+from ..builds.models import build_model
+from ..core import paths
+from ..core.model_config import model_config
+from ..core.config import settings
+from .base import BaseModelBackend
+
+
+class CUDAStreamManager:
+    """CUDA stream manager for GPU operations."""
+
+    def __init__(self, num_streams: int):
+        """Initialize stream manager.
+
+        Args:
+            num_streams: Number of CUDA streams
+        """
+        self.streams = [torch.cuda.Stream() for _ in range(num_streams)]
+        self._current = 0
+
+    def get_next_stream(self) -> torch.cuda.Stream:
+        """Get next available stream.
+
+        Returns:
+            CUDA stream
+        """
+        stream = self.streams[self._current]
+        self._current = (self._current + 1) % len(self.streams)
+        return stream
+
+
+@torch.no_grad()
+def forward(
+    model: torch.nn.Module,
+    tokens: list[int],
+    ref_s: torch.Tensor,
+    speed: float,
+    stream: Optional[torch.cuda.Stream] = None,
+) -> np.ndarray:
+    """Forward pass through model.
+
+    Args:
+        model: PyTorch model
+        tokens: Input tokens
+        ref_s: Reference signal
+        speed: Speed multiplier
+        stream: Optional CUDA stream (GPU only)
+
+    Returns:
+        Generated audio
+    """
+    device = ref_s.device
+
+    # Use provided stream or default for GPU
+    context = (
+        torch.cuda.stream(stream) if stream and device.type == "cuda" else nullcontext()
+    )
+
+    with context:
+        # Initial tensor setup
+        tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
+        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
+        text_mask = length_to_mask(input_lengths).to(device)
+
+        # Split reference signals
+        style_dim = 128
+        s_ref = ref_s[:, :style_dim].clone().to(device)
+        s_content = ref_s[:, style_dim:].clone().to(device)
+
+        # BERT and encoder pass
+        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
+        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
+
+        # Predictor forward pass
+        d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
+        x, _ = model.predictor.lstm(d)
+
+        # Duration prediction
+        duration = model.predictor.duration_proj(x)
+        duration = torch.sigmoid(duration).sum(axis=-1) / speed
+        pred_dur = torch.round(duration).clamp(min=1).long()
+        del duration, x
+
+        # Alignment matrix construction
+        pred_aln_trg = torch.zeros(
+            input_lengths.item(), pred_dur.sum().item(), device=device
+        )
+        c_frame = 0
+        for i in range(pred_aln_trg.size(0)):
+            pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
+            c_frame += pred_dur[0, i].item()
+        pred_aln_trg = pred_aln_trg.unsqueeze(0)
+
+        # Matrix multiplications
+        en = d.transpose(-1, -2) @ pred_aln_trg
+        del d
+
+        F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
+        del en
+
+        # Final text encoding and decoding
+        t_en = model.text_encoder(tokens, input_lengths, text_mask)
+        asr = t_en @ pred_aln_trg
+        del t_en
+
+        # Generate output
+        output = model.decoder(asr, F0_pred, N_pred, s_ref)
+
+        # Ensure operation completion if using custom stream
+        if stream and device.type == "cuda":
+            stream.synchronize()
+
+        return output.squeeze().cpu().numpy()
+
+
+def length_to_mask(lengths: torch.Tensor) -> torch.Tensor:
+    """Create attention mask from lengths."""
+    max_len = lengths.max()
+    mask = torch.arange(max_len, device=lengths.device)[None, :].expand(
+        lengths.shape[0], -1
+    )
+    if lengths.dtype != mask.dtype:
+        mask = mask.to(dtype=lengths.dtype)
+    return mask + 1 > lengths[:, None]
+
+
+class PyTorchBackend(BaseModelBackend):
+    """PyTorch inference backend with environment-based configuration."""
+
+    def __init__(self):
+        """Initialize backend based on environment configuration."""
+        super().__init__()
+
+        # Configure device based on settings
+        self._device = (
+            "cuda" if settings.use_gpu and torch.cuda.is_available() else "cpu"
+        )
+        self._model: Optional[torch.nn.Module] = None
+
+        # Apply device-specific configurations
+        if self._device == "cuda":
+            config = model_config.pytorch_gpu
+            if config.sync_cuda:
+                torch.cuda.synchronize()
+            torch.cuda.set_device(config.device_id)
+            self._stream_manager = CUDAStreamManager(config.cuda_streams)
+        else:
+            config = model_config.pytorch_cpu
+            if config.num_threads > 0:
+                torch.set_num_threads(config.num_threads)
+            if config.pin_memory:
+                torch.set_default_tensor_type(torch.FloatTensor)
+
+    async def load_model(self, path: str) -> None:
+        """Load PyTorch model.
+
+        Args:
+            path: Path to model file
+
+        Raises:
+            RuntimeError: If model loading fails
+        """
+        try:
+            # Get verified model path
+            model_path = await paths.get_model_path(path)
+
+            logger.info(f"Loading PyTorch model on {self._device}: {model_path}")
+            self._model = await build_model(model_path, self._device)
+
+        except Exception as e:
+            raise RuntimeError(f"Failed to load PyTorch model: {e}")
+
+    def generate(
+        self, tokens: list[int], voice: torch.Tensor, speed: float = 1.0
+    ) -> np.ndarray:
+        """Generate audio using model.
+
+        Args:
+            tokens: Input token IDs
+            voice: Voice embedding tensor
+            speed: Speed multiplier
+
+        Returns:
+            Generated audio samples
+
+        Raises:
+            RuntimeError: If generation fails
+        """
+        if not self.is_loaded:
+            raise RuntimeError("Model not loaded")
+
+        try:
+            # Memory management for GPU
+            if self._device == "cuda":
+                if self._check_memory():
+                    self._clear_memory()
+                stream = self._stream_manager.get_next_stream()
+            else:
+                stream = None
+
+            # Get reference style from voice pack
+            ref_s = voice[len(tokens)].clone().to(self._device)
+            if ref_s.dim() == 1:
+                ref_s = ref_s.unsqueeze(0)
+
+            # Generate audio
+            return forward(self._model, tokens, ref_s, speed, stream)
+
+        except Exception as e:
+            logger.error(f"Generation failed: {e}")
+            if (
+                self._device == "cuda"
+                and model_config.pytorch_gpu.retry_on_oom
+                and "out of memory" in str(e).lower()
+            ):
+                self._clear_memory()
+                return self.generate(tokens, voice, speed)
+            raise
+        finally:
+            if self._device == "cuda" and model_config.pytorch_gpu.sync_cuda:
+                torch.cuda.synchronize()
+
+    def _check_memory(self) -> bool:
+        """Check if memory usage is above threshold."""
+        if self._device == "cuda":
+            memory_gb = torch.cuda.memory_allocated() / 1e9
+            return memory_gb > model_config.pytorch_gpu.memory_threshold
+        return False
+
+    def _clear_memory(self) -> None:
+        """Clear device memory."""
+        if self._device == "cuda":
+            torch.cuda.empty_cache()
+            gc.collect()
diff --git a/api/src/inference/pytorch_cpu.py b/api/src/inference/pytorch_cpu.py
deleted file mode 100644
index b1d4328..0000000
--- a/api/src/inference/pytorch_cpu.py
+++ /dev/null
@@ -1,181 +0,0 @@
-"""CPU-based PyTorch inference backend."""
-
-import gc
-from typing import Optional
-
-import numpy as np
-import torch
-from loguru import logger
-
-from ..builds.models import build_model
-from ..core import paths
-from ..core.model_config import model_config
-from .base import BaseModelBackend
-
-
-@torch.no_grad()
-def forward(model: torch.nn.Module, tokens: list[int], ref_s: torch.Tensor, speed: float) -> np.ndarray:
-    """Forward pass through model with memory management.
-
-    Args:
-        model: PyTorch model
-        tokens: Input tokens
-        ref_s: Reference signal
-        speed: Speed multiplier
-
-    Returns:
-        Generated audio
-    """
-    device = ref_s.device
-    pred_aln_trg = None
-    asr = None
-
-    try:
-        # Initial tensor setup
-        tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
-        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
-        text_mask = length_to_mask(input_lengths).to(device)
-
-        # Split reference signals
-        s_content = ref_s[:, 128:].clone().to(device)
-        s_ref = ref_s[:, :128].clone().to(device)
-
-        # BERT and encoder pass
-        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
-        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
-
-        # Predictor forward pass
-        d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
-        x, _ = model.predictor.lstm(d)
-
-        # Duration prediction
-        duration = model.predictor.duration_proj(x)
-        duration = torch.sigmoid(duration).sum(axis=-1) / speed
-        pred_dur = torch.round(duration).clamp(min=1).long()
-        del duration, x  # Free large intermediates
-
-        # Alignment matrix construction
-        pred_aln_trg = torch.zeros(
-            input_lengths.item(),
-            pred_dur.sum().item(),
-            device=device
-        )
-        c_frame = 0
-        for i in range(pred_aln_trg.size(0)):
-            pred_aln_trg[i, c_frame:c_frame + pred_dur[0, i].item()] = 1
-            c_frame += pred_dur[0, i].item()
-        pred_aln_trg = pred_aln_trg.unsqueeze(0)
-
-        # Matrix multiplications with cleanup
-        en = d.transpose(-1, -2) @ pred_aln_trg
-        del d
-
-        F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
-        del en
-
-        # Final text encoding and decoding
-        t_en = model.text_encoder(tokens, input_lengths, text_mask)
-        asr = t_en @ pred_aln_trg
-        del t_en
-
-        # Generate output
-        output = model.decoder(asr, F0_pred, N_pred, s_ref)
-        result = output.squeeze().cpu().numpy()
-
-        return result
-
-    finally:
-        # Clean up largest tensors if they were created
-        if pred_aln_trg is not None:
-            del pred_aln_trg
-        if asr is not None:
-            del asr
-
-
-def length_to_mask(lengths: torch.Tensor) -> torch.Tensor:
-    """Create attention mask from lengths.
-
-    Args:
-        lengths: Sequence lengths
-
-    Returns:
-        Boolean mask tensor
-    """
-    max_len = lengths.max()
-    mask = torch.arange(max_len, device=lengths.device)[None, :].expand(
-        lengths.shape[0], -1
-    )
-    if lengths.dtype != mask.dtype:
-        mask = mask.to(dtype=lengths.dtype)
-    return mask + 1 > lengths[:, None]
-
-
-class PyTorchCPUBackend(BaseModelBackend):
-    """PyTorch CPU inference backend."""
-
-    def __init__(self):
-        """Initialize CPU backend."""
-        super().__init__()
-        self._device = "cpu"
-        self._model: Optional[torch.nn.Module] = None
-
-        # Configure PyTorch CPU settings
-        config = model_config.pytorch_cpu
-        if config.num_threads > 0:
-            torch.set_num_threads(config.num_threads)
-        if config.pin_memory:
-            torch.set_default_tensor_type(torch.FloatTensor)
-
-    async def load_model(self, path: str) -> None:
-        """Load PyTorch model.
-
-        Args:
-            path: Path to model file
-
-        Raises:
-            RuntimeError: If model loading fails
-        """
-        try:
-            # Get verified model path
-            model_path = await paths.get_model_path(path)
-
-            logger.info(f"Loading PyTorch model on CPU: {model_path}")
-            self._model = await build_model(model_path, self._device)
-
-        except Exception as e:
-            raise RuntimeError(f"Failed to load PyTorch model: {e}")
-
-    def generate(
-        self,
-        tokens: list[int],
-        voice: torch.Tensor,
-        speed: float = 1.0
-    ) -> np.ndarray:
-        """Generate audio using CPU model.
-
-        Args:
-            tokens: Input token IDs
-            voice: Voice embedding tensor
-            speed: Speed multiplier
-
-        Returns:
-            Generated audio samples
-
-        Raises:
-            RuntimeError: If generation fails
-        """
-        if not self.is_loaded:
-            raise RuntimeError("Model not loaded")
-
-        try:
-            # Prepare input
-            ref_s = voice[len(tokens)].clone()
-
-            # Generate audio
-            return forward(self._model, tokens, ref_s, speed)
-
-        except Exception as e:
-            raise RuntimeError(f"Generation failed: {e}")
-        finally:
-            # Clean up memory
-            gc.collect()
\ No newline at end of file