Mirror of https://github.com/remsky/Kokoro-FastAPI.git (synced 2025-04-13 09:39:17 +00:00)
Refactor: Consolidate PyTorch CPU and GPU backends into a single PyTorchBackend class; remove obsolete files
Commit 00497f8872 (parent 3547d95ee6)
4 changed files with 258 additions and 426 deletions
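The device decision that previously required separate CPU and GPU classes now lives inside the consolidated backend. A minimal sketch of that resolution, distilled from the PyTorchBackend.__init__ hunk further down; the resolve_device helper is an illustrative name, not code from this commit:

import torch

def resolve_device(use_gpu: bool) -> str:
    # The consolidated backend no longer raises when CUDA is unavailable;
    # it simply falls back to CPU.
    return "cuda" if use_gpu and torch.cuda.is_available() else "cpu"

print(resolve_device(use_gpu=True))  # "cuda" on a CUDA machine, otherwise "cpu"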
Backend package exports:

@@ -4,8 +4,7 @@ from .base import BaseModelBackend
 from .model_manager import ModelManager, get_manager
 from .onnx_cpu import ONNXCPUBackend
 from .onnx_gpu import ONNXGPUBackend
-from .pytorch_cpu import PyTorchCPUBackend
-from .pytorch_gpu import PyTorchGPUBackend
+from .pytorch_backend import PyTorchBackend
 
 __all__ = [
     'BaseModelBackend',
@@ -13,6 +12,5 @@ __all__ = [
     'get_manager',
     'ONNXCPUBackend',
     'ONNXGPUBackend',
-    'PyTorchCPUBackend',
-    'PyTorchGPUBackend',
+    'PyTorchBackend',
 ]
Model manager:

@@ -12,8 +12,7 @@ from ..core.model_config import ModelConfig, model_config
 from .base import BaseModelBackend
 from .onnx_cpu import ONNXCPUBackend
 from .onnx_gpu import ONNXGPUBackend
-from .pytorch_cpu import PyTorchCPUBackend
-from .pytorch_gpu import PyTorchGPUBackend
+from .pytorch_backend import PyTorchBackend
 from .session_pool import CPUSessionPool, StreamingSessionPool
 
 
@@ -63,18 +62,18 @@ class ModelManager:
                 self._current_backend = 'onnx_gpu'
                 logger.info("Initialized new ONNX GPU backend")
             else:
-                self._backends['pytorch_gpu'] = PyTorchGPUBackend()
-                self._current_backend = 'pytorch_gpu'
-                logger.info("Initialized new PyTorch GPU backend")
+                self._backends['pytorch'] = PyTorchBackend()
+                self._current_backend = 'pytorch'
+                logger.info("Initialized new PyTorch backend on GPU")
         else:
             if settings.use_onnx:
                 self._backends['onnx_cpu'] = ONNXCPUBackend()
                 self._current_backend = 'onnx_cpu'
                 logger.info("Initialized new ONNX CPU backend")
             else:
-                self._backends['pytorch_cpu'] = PyTorchCPUBackend()
-                self._current_backend = 'pytorch_cpu'
-                logger.info("Initialized new PyTorch CPU backend")
+                self._backends['pytorch'] = PyTorchBackend()
+                self._current_backend = 'pytorch'
+                logger.info("Initialized new PyTorch backend on CPU")
 
         # Initialize locks for each backend
         for backend in self._backends:
@@ -95,10 +94,10 @@ class ModelManager:
         """
         try:
             # Determine backend type based on settings
-            if settings.use_gpu and torch.cuda.is_available():
-                backend_type = 'pytorch_gpu' if not settings.use_onnx else 'onnx_gpu'
+            if settings.use_onnx:
+                backend_type = 'onnx_gpu' if settings.use_gpu and torch.cuda.is_available() else 'onnx_cpu'
             else:
-                backend_type = 'pytorch_cpu' if not settings.use_onnx else 'onnx_cpu'
+                backend_type = 'pytorch'
 
             # Get backend
             backend = self.get_backend(backend_type)
@@ -167,13 +166,10 @@ class ModelManager:
         Returns:
             Backend type to use
         """
-        has_gpu = settings.use_gpu and torch.cuda.is_available()
-
         # If ONNX is preferred or model is ONNX format
         if settings.use_onnx or model_path.lower().endswith('.onnx'):
-            return 'onnx_gpu' if has_gpu else 'onnx_cpu'
-        else:
-            return 'pytorch_gpu' if has_gpu else 'pytorch_cpu'
+            return 'onnx_gpu' if settings.use_gpu and torch.cuda.is_available() else 'onnx_cpu'
+        return 'pytorch'
 
     async def load_model(
         self,
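For clarity, the routing that the ModelManager hunks above boil down to, written as a self-contained sketch; the standalone determine_backend function and its parameters are illustrative, while the conditions and returned keys mirror the diff:

import torch

def determine_backend(model_path: str, use_onnx: bool, use_gpu: bool) -> str:
    # ONNX keeps separate CPU/GPU backends; PyTorch collapses to a single key.
    if use_onnx or model_path.lower().endswith('.onnx'):
        return 'onnx_gpu' if use_gpu and torch.cuda.is_available() else 'onnx_cpu'
    return 'pytorch'

print(determine_backend('model.pth', use_onnx=False, use_gpu=True))  # 'pytorch'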
PyTorch backend (consolidated from the former GPU backend):

@@ -1,7 +1,9 @@
-"""GPU-based PyTorch inference backend."""
+"""PyTorch inference backend with environment-based configuration."""
 
 import gc
 from typing import Optional
+from contextlib import nullcontext
+from typing import Optional
 
 import numpy as np
 import torch
@@ -10,11 +12,12 @@ from loguru import logger
 from ..builds.models import build_model
 from ..core import paths
 from ..core.model_config import model_config
+from ..core.config import settings
 from .base import BaseModelBackend
 
 
 class CUDAStreamManager:
-    """CUDA stream manager."""
+    """CUDA stream manager for GPU operations."""
 
     def __init__(self, num_streams: int):
         """Initialize stream manager.
@@ -42,30 +45,34 @@ def forward(
     tokens: list[int],
     ref_s: torch.Tensor,
     speed: float,
-    stream: Optional[torch.cuda.Stream] = None
+    stream: Optional[torch.cuda.Stream] = None,
 ) -> np.ndarray:
     """Forward pass through model.
 
     Args:
         model: PyTorch model
         tokens: Input tokens
-        ref_s: Reference signal (shape: [1, n_features])
+        ref_s: Reference signal
         speed: Speed multiplier
-        stream: Optional CUDA stream
+        stream: Optional CUDA stream (GPU only)
 
     Returns:
         Generated audio
     """
     device = ref_s.device
 
-    # Use provided stream or default
-    with torch.cuda.stream(stream) if stream else torch.cuda.default_stream():
+    # Use provided stream or default for GPU
+    context = (
+        torch.cuda.stream(stream) if stream and device.type == "cuda" else nullcontext()
+    )
+
+    with context:
         # Initial tensor setup
         tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
         input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
         text_mask = length_to_mask(input_lengths).to(device)
 
-        # Split reference signals (style_dim=128 from config)
+        # Split reference signals
         style_dim = 128
         s_ref = ref_s[:, :style_dim].clone().to(device)
         s_content = ref_s[:, style_dim:].clone().to(device)
@@ -85,10 +92,12 @@ def forward(
         del duration, x
 
         # Alignment matrix construction
-        pred_aln_trg = torch.zeros(input_lengths.item(), pred_dur.sum().item(), device=device)
+        pred_aln_trg = torch.zeros(
+            input_lengths.item(), pred_dur.sum().item(), device=device
+        )
         c_frame = 0
         for i in range(pred_aln_trg.size(0)):
-            pred_aln_trg[i, c_frame:c_frame + pred_dur[0, i].item()] = 1
+            pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
             c_frame += pred_dur[0, i].item()
         pred_aln_trg = pred_aln_trg.unsqueeze(0)
 
@@ -108,7 +117,7 @@ def forward(
         output = model.decoder(asr, F0_pred, N_pred, s_ref)
 
         # Ensure operation completion if using custom stream
-        if stream:
+        if stream and device.type == "cuda":
             stream.synchronize()
 
         return output.squeeze().cpu().numpy()
@@ -117,32 +126,40 @@ def forward(
 def length_to_mask(lengths: torch.Tensor) -> torch.Tensor:
     """Create attention mask from lengths."""
     max_len = lengths.max()
-    mask = torch.arange(max_len, device=lengths.device)[None, :].expand(lengths.shape[0], -1)
+    mask = torch.arange(max_len, device=lengths.device)[None, :].expand(
+        lengths.shape[0], -1
+    )
     if lengths.dtype != mask.dtype:
         mask = mask.to(dtype=lengths.dtype)
     return mask + 1 > lengths[:, None]
 
 
-class PyTorchGPUBackend(BaseModelBackend):
-    """PyTorch GPU inference backend."""
+class PyTorchBackend(BaseModelBackend):
+    """PyTorch inference backend with environment-based configuration."""
 
     def __init__(self):
-        """Initialize GPU backend."""
+        """Initialize backend based on environment configuration."""
         super().__init__()
-        from ..core.config import settings
-        if not (settings.use_gpu and torch.cuda.is_available()):
-            raise RuntimeError("GPU backend requires GPU support and use_gpu=True")
-        self._device = "cuda"  # Device is enforced by backend selection in model_manager
+        # Configure device based on settings
+        self._device = (
+            "cuda" if settings.use_gpu and torch.cuda.is_available() else "cpu"
+        )
         self._model: Optional[torch.nn.Module] = None
 
-        # Configure GPU settings
-        config = model_config.pytorch_gpu
-        if config.sync_cuda:
-            torch.cuda.synchronize()
-        torch.cuda.set_device(config.device_id)
-
-        # Initialize stream manager
-        self._stream_manager = CUDAStreamManager(config.cuda_streams)
+        # Apply device-specific configurations
+        if self._device == "cuda":
+            config = model_config.pytorch_gpu
+            if config.sync_cuda:
+                torch.cuda.synchronize()
+            torch.cuda.set_device(config.device_id)
+            self._stream_manager = CUDAStreamManager(config.cuda_streams)
+        else:
+            config = model_config.pytorch_cpu
+            if config.num_threads > 0:
+                torch.set_num_threads(config.num_threads)
+            if config.pin_memory:
+                torch.set_default_tensor_type(torch.FloatTensor)
 
     async def load_model(self, path: str) -> None:
         """Load PyTorch model.
@@ -157,19 +174,16 @@ class PyTorchGPUBackend(BaseModelBackend):
             # Get verified model path
             model_path = await paths.get_model_path(path)
 
-            logger.info(f"Loading PyTorch model: {model_path}")
+            logger.info(f"Loading PyTorch model on {self._device}: {model_path}")
             self._model = await build_model(model_path, self._device)
 
         except Exception as e:
             raise RuntimeError(f"Failed to load PyTorch model: {e}")
 
     def generate(
-        self,
-        tokens: list[int],
-        voice: torch.Tensor,
-        speed: float = 1.0
+        self, tokens: list[int], voice: torch.Tensor, speed: float = 1.0
     ) -> np.ndarray:
-        """Generate audio using GPU model.
+        """Generate audio using model.
 
         Args:
             tokens: Input token IDs
@@ -186,40 +200,45 @@ class PyTorchGPUBackend(BaseModelBackend):
             raise RuntimeError("Model not loaded")
 
         try:
-            # Check memory and cleanup if needed
-            if self._check_memory():
-                self._clear_memory()
+            # Memory management for GPU
+            if self._device == "cuda":
+                if self._check_memory():
+                    self._clear_memory()
+                stream = self._stream_manager.get_next_stream()
+            else:
+                stream = None
 
             # Get reference style from voice pack
             ref_s = voice[len(tokens)].clone().to(self._device)
             if ref_s.dim() == 1:
-                ref_s = ref_s.unsqueeze(0)  # Add batch dimension if needed
+                ref_s = ref_s.unsqueeze(0)
 
-            # Get next available stream
-            stream = self._stream_manager.get_next_stream()
-
-            # Generate audio using stream
+            # Generate audio
             return forward(self._model, tokens, ref_s, speed, stream)
 
         except Exception as e:
             logger.error(f"Generation failed: {e}")
-            if model_config.pytorch_gpu.retry_on_oom and "out of memory" in str(e).lower():
+            if (
+                self._device == "cuda"
+                and model_config.pytorch_gpu.retry_on_oom
+                and "out of memory" in str(e).lower()
+            ):
                 self._clear_memory()
-                return self.generate(tokens, voice, speed)  # Retry once
+                return self.generate(tokens, voice, speed)
            raise
         finally:
-            if model_config.pytorch_gpu.sync_cuda:
+            if self._device == "cuda" and model_config.pytorch_gpu.sync_cuda:
                 torch.cuda.synchronize()
 
     def _check_memory(self) -> bool:
         """Check if memory usage is above threshold."""
-        if torch.cuda.is_available():
+        if self._device == "cuda":
            memory_gb = torch.cuda.memory_allocated() / 1e9
            return memory_gb > model_config.pytorch_gpu.memory_threshold
        return False
 
     def _clear_memory(self) -> None:
-        """Clear GPU memory."""
-        if torch.cuda.is_available():
+        """Clear device memory."""
+        if self._device == "cuda":
             torch.cuda.empty_cache()
             gc.collect()
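The device-neutral trick in the rewritten forward() above is wrapping GPU work in torch.cuda.stream(...) only when a stream is supplied and the tensors live on CUDA, and falling back to contextlib.nullcontext() otherwise. A self-contained sketch of that pattern; the run_with_optional_stream helper and the dummy computation are illustrative, not repository code:

from contextlib import nullcontext
from typing import Optional

import torch

def run_with_optional_stream(x: torch.Tensor, stream: Optional["torch.cuda.Stream"] = None) -> torch.Tensor:
    # Use the supplied CUDA stream only when one exists and x is on the GPU.
    context = (
        torch.cuda.stream(stream)
        if stream is not None and x.device.type == "cuda"
        else nullcontext()
    )
    with context:
        y = x * 2  # stand-in for the model's forward pass
    if stream is not None and x.device.type == "cuda":
        stream.synchronize()  # ensure work queued on the side stream has finished
    return y

print(run_with_optional_stream(torch.ones(3)))  # CPU path runs under nullcontext()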
PyTorch CPU backend (file deleted):

@@ -1,181 +0,0 @@
-"""CPU-based PyTorch inference backend."""
-
-import gc
-from typing import Optional
-
-import numpy as np
-import torch
-from loguru import logger
-
-from ..builds.models import build_model
-from ..core import paths
-from ..core.model_config import model_config
-from .base import BaseModelBackend
-
-
-@torch.no_grad()
-def forward(model: torch.nn.Module, tokens: list[int], ref_s: torch.Tensor, speed: float) -> np.ndarray:
-    """Forward pass through model with memory management.
-
-    Args:
-        model: PyTorch model
-        tokens: Input tokens
-        ref_s: Reference signal
-        speed: Speed multiplier
-
-    Returns:
-        Generated audio
-    """
-    device = ref_s.device
-    pred_aln_trg = None
-    asr = None
-
-    try:
-        # Initial tensor setup
-        tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
-        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
-        text_mask = length_to_mask(input_lengths).to(device)
-
-        # Split reference signals
-        s_content = ref_s[:, 128:].clone().to(device)
-        s_ref = ref_s[:, :128].clone().to(device)
-
-        # BERT and encoder pass
-        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
-        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
-
-        # Predictor forward pass
-        d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
-        x, _ = model.predictor.lstm(d)
-
-        # Duration prediction
-        duration = model.predictor.duration_proj(x)
-        duration = torch.sigmoid(duration).sum(axis=-1) / speed
-        pred_dur = torch.round(duration).clamp(min=1).long()
-        del duration, x  # Free large intermediates
-
-        # Alignment matrix construction
-        pred_aln_trg = torch.zeros(
-            input_lengths.item(),
-            pred_dur.sum().item(),
-            device=device
-        )
-        c_frame = 0
-        for i in range(pred_aln_trg.size(0)):
-            pred_aln_trg[i, c_frame:c_frame + pred_dur[0, i].item()] = 1
-            c_frame += pred_dur[0, i].item()
-        pred_aln_trg = pred_aln_trg.unsqueeze(0)
-
-        # Matrix multiplications with cleanup
-        en = d.transpose(-1, -2) @ pred_aln_trg
-        del d
-
-        F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
-        del en
-
-        # Final text encoding and decoding
-        t_en = model.text_encoder(tokens, input_lengths, text_mask)
-        asr = t_en @ pred_aln_trg
-        del t_en
-
-        # Generate output
-        output = model.decoder(asr, F0_pred, N_pred, s_ref)
-        result = output.squeeze().cpu().numpy()
-
-        return result
-
-    finally:
-        # Clean up largest tensors if they were created
-        if pred_aln_trg is not None:
-            del pred_aln_trg
-        if asr is not None:
-            del asr
-
-
-def length_to_mask(lengths: torch.Tensor) -> torch.Tensor:
-    """Create attention mask from lengths.
-
-    Args:
-        lengths: Sequence lengths
-
-    Returns:
-        Boolean mask tensor
-    """
-    max_len = lengths.max()
-    mask = torch.arange(max_len, device=lengths.device)[None, :].expand(
-        lengths.shape[0], -1
-    )
-    if lengths.dtype != mask.dtype:
-        mask = mask.to(dtype=lengths.dtype)
-    return mask + 1 > lengths[:, None]
-
-
-class PyTorchCPUBackend(BaseModelBackend):
-    """PyTorch CPU inference backend."""
-
-    def __init__(self):
-        """Initialize CPU backend."""
-        super().__init__()
-        self._device = "cpu"
-        self._model: Optional[torch.nn.Module] = None
-
-        # Configure PyTorch CPU settings
-        config = model_config.pytorch_cpu
-        if config.num_threads > 0:
-            torch.set_num_threads(config.num_threads)
-        if config.pin_memory:
-            torch.set_default_tensor_type(torch.FloatTensor)
-
-    async def load_model(self, path: str) -> None:
-        """Load PyTorch model.
-
-        Args:
-            path: Path to model file
-
-        Raises:
-            RuntimeError: If model loading fails
-        """
-        try:
-            # Get verified model path
-            model_path = await paths.get_model_path(path)
-
-            logger.info(f"Loading PyTorch model on CPU: {model_path}")
-            self._model = await build_model(model_path, self._device)
-
-        except Exception as e:
-            raise RuntimeError(f"Failed to load PyTorch model: {e}")
-
-    def generate(
-        self,
-        tokens: list[int],
-        voice: torch.Tensor,
-        speed: float = 1.0
-    ) -> np.ndarray:
-        """Generate audio using CPU model.
-
-        Args:
-            tokens: Input token IDs
-            voice: Voice embedding tensor
-            speed: Speed multiplier
-
-        Returns:
-            Generated audio samples
-
-        Raises:
-            RuntimeError: If generation fails
-        """
-        if not self.is_loaded:
-            raise RuntimeError("Model not loaded")
-
-        try:
-            # Prepare input
-            ref_s = voice[len(tokens)].clone()
-
-            # Generate audio
-            return forward(self._model, tokens, ref_s, speed)
-
-        except Exception as e:
-            raise RuntimeError(f"Generation failed: {e}")
-        finally:
-            # Clean up memory
-            gc.collect()