Mirror of https://github.com/remsky/Kokoro-FastAPI.git (synced 2025-04-13 09:39:17 +00:00)
Refactor model loading and configuration: adjust model loading device, add async streaming examples, and remove the unused warmup service.
This commit is contained in:
parent 21bf810f97
commit 4a24be1605

21 changed files with 929 additions and 484 deletions

@@ -367,7 +367,7 @@ async def build_model(path, device):
        decoder=decoder.to(device).eval(),
        text_encoder=text_encoder.to(device).eval(),
    )
    weights = await load_model_weights(path, device='cpu')
    weights = await load_model_weights(path, device=device)
    for key, state_dict in weights['net'].items():
        assert key in model, key
        try:
@@ -15,9 +15,9 @@ class Settings(BaseSettings):
    default_voice: str = "af"
    use_gpu: bool = False  # Whether to use GPU acceleration if available
    use_onnx: bool = True  # Whether to use ONNX runtime
    # Paths relative to api directory
    model_dir: str = "src/models"  # Model directory relative to api/
    voices_dir: str = "src/voices"  # Voices directory relative to api/
    # Container absolute paths
    model_dir: str = "/app/api/src/models"  # Absolute path in container
    voices_dir: str = "/app/api/src/voices"  # Absolute path in container

    # Model filenames
    pytorch_model_file: str = "kokoro-v0_19.pth"
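
For readers unfamiliar with pydantic settings classes: fields like use_gpu and model_dir above can be overridden through environment variables at container start. A minimal, illustrative sketch only (field names mirror this diff; the import location and env-var values are assumptions, not repo code):

# Illustrative sketch: how pydantic BaseSettings resolves the fields shown above from
# environment variables. Class and field names mirror this diff; everything else
# (env var names, values, pydantic version) is an assumption.
import os

from pydantic_settings import BaseSettings  # pydantic v2; in v1 use `from pydantic import BaseSettings`


class Settings(BaseSettings):
    default_voice: str = "af"
    use_gpu: bool = False
    use_onnx: bool = True
    model_dir: str = "/app/api/src/models"
    voices_dir: str = "/app/api/src/voices"
    pytorch_model_file: str = "kokoro-v0_19.pth"


os.environ["USE_GPU"] = "true"   # e.g. set by a GPU docker-compose file
settings = Settings()
print(settings.use_gpu, settings.model_dir)  # True /app/api/src/models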
@@ -6,6 +6,11 @@ from pydantic import BaseModel, Field
class ONNXCPUConfig(BaseModel):
    """ONNX CPU runtime configuration."""

    # Session pooling
    max_instances: int = Field(4, description="Maximum concurrent model instances")
    instance_timeout: int = Field(300, description="Session timeout in seconds")

    # Runtime settings
    num_threads: int = Field(8, description="Number of threads for parallel operations")
    inter_op_threads: int = Field(4, description="Number of threads for operator parallelism")
    execution_mode: str = Field("parallel", description="ONNX execution mode")

@@ -20,9 +25,14 @@ class ONNXCPUConfig(BaseModel):
class ONNXGPUConfig(ONNXCPUConfig):
    """ONNX GPU-specific configuration."""

    # CUDA settings
    device_id: int = Field(0, description="CUDA device ID")
    gpu_mem_limit: float = Field(0.7, description="Fraction of GPU memory to use")
    cudnn_conv_algo_search: str = Field("EXHAUSTIVE", description="CuDNN convolution algorithm search")

    # Stream management
    cuda_streams: int = Field(2, description="Number of CUDA streams for inference")
    stream_timeout: int = Field(60, description="Stream timeout in seconds")
    do_copy_in_default_stream: bool = Field(True, description="Copy in default CUDA stream")

    class Config:
@@ -32,8 +42,6 @@ class ONNXGPUConfig(ONNXCPUConfig):
class PyTorchCPUConfig(BaseModel):
    """PyTorch CPU backend configuration."""

    max_batch_size: int = Field(32, description="Maximum batch size for batched inference")
    stream_buffer_size: int = Field(8, description="Size of stream buffer")
    memory_threshold: float = Field(0.8, description="Memory threshold for cleanup")
    retry_on_oom: bool = Field(True, description="Whether to retry on OOM errors")
    num_threads: int = Field(8, description="Number of threads for parallel operations")
@@ -49,18 +57,11 @@ class PyTorchGPUConfig(BaseModel):
    device_id: int = Field(0, description="CUDA device ID")
    use_fp16: bool = Field(True, description="Whether to use FP16 precision")
    use_triton: bool = Field(True, description="Whether to use Triton for CUDA kernels")
    max_batch_size: int = Field(32, description="Maximum batch size for batched inference")
    stream_buffer_size: int = Field(8, description="Size of CUDA stream buffer")
    memory_threshold: float = Field(0.8, description="Memory threshold for cleanup")
    retry_on_oom: bool = Field(True, description="Whether to retry on OOM errors")
    sync_cuda: bool = Field(True, description="Whether to synchronize CUDA operations")

    class Config:
        frozen = True
    """PyTorch CPU-specific configuration."""

    num_threads: int = Field(8, description="Number of threads for parallel operations")
    pin_memory: bool = Field(True, description="Whether to pin memory for faster CPU-GPU transfer")
    cuda_streams: int = Field(2, description="Number of CUDA streams for inference")
    stream_timeout: int = Field(60, description="Stream timeout in seconds")

    class Config:
        frozen = True
@@ -74,7 +75,7 @@ class ModelConfig(BaseModel):
    device_type: str = Field("auto", description="Device type ('cpu', 'gpu', or 'auto')")
    cache_models: bool = Field(True, description="Whether to cache loaded models")
    cache_voices: bool = Field(True, description="Whether to cache voice tensors")
    voice_cache_size: int = Field(10, description="Maximum number of cached voices")
    voice_cache_size: int = Field(2, description="Maximum number of cached voices")

    # Backend-specific configs
    onnx_cpu: ONNXCPUConfig = Field(default_factory=ONNXCPUConfig)
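
As an aside on how these pydantic configs compose: ModelConfig nests the backend-specific models via default_factory, so each backend block keeps its own defaults and individual fields (such as the reduced voice_cache_size) can be overridden at construction time. A minimal sketch of that composition, not the project's actual config loader; the override values are made up:

# Minimal sketch only. It shows how the nested pydantic models in this diff compose
# and how a single field can be overridden; values below are illustrative.
from pydantic import BaseModel, Field


class ONNXCPUConfig(BaseModel):
    max_instances: int = Field(4, description="Maximum concurrent model instances")
    num_threads: int = Field(8, description="Number of threads for parallel operations")


class ModelConfig(BaseModel):
    device_type: str = Field("auto", description="Device type ('cpu', 'gpu', or 'auto')")
    voice_cache_size: int = Field(2, description="Maximum number of cached voices")
    onnx_cpu: ONNXCPUConfig = Field(default_factory=ONNXCPUConfig)


config = ModelConfig(voice_cache_size=4, onnx_cpu={"num_threads": 16})
print(config.onnx_cpu.num_threads)  # 16; max_instances keeps its default of 4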
@@ -1,5 +1,6 @@
"""Model management and caching."""

import asyncio
from typing import Dict, Optional

import torch

@@ -13,11 +14,19 @@ from .onnx_cpu import ONNXCPUBackend
from .onnx_gpu import ONNXGPUBackend
from .pytorch_cpu import PyTorchCPUBackend
from .pytorch_gpu import PyTorchGPUBackend
from .session_pool import CPUSessionPool, StreamingSessionPool


# Global singleton instance and state
_manager_instance = None
_manager_lock = asyncio.Lock()
_loaded_models = {}
_backends = {}


class ModelManager:
    """Manages model loading and inference across backends."""

    def __init__(self, config: Optional[ModelConfig] = None):
        """Initialize model manager.
@ -25,65 +34,60 @@ class ModelManager:
|
|||
config: Optional configuration
|
||||
"""
|
||||
self._config = config or model_config
|
||||
self._backends: Dict[str, BaseModelBackend] = {}
|
||||
self._current_backend: Optional[str] = None
|
||||
self._initialize_backends()
|
||||
global _loaded_models, _backends
|
||||
self._loaded_models = _loaded_models
|
||||
self._backends = _backends
|
||||
|
||||
# Initialize session pools
|
||||
self._session_pools = {
|
||||
'onnx_cpu': CPUSessionPool(),
|
||||
'onnx_gpu': StreamingSessionPool()
|
||||
}
|
||||
|
||||
# Initialize locks
|
||||
self._backend_locks: Dict[str, asyncio.Lock] = {}
|
||||
|
||||
def _initialize_backends(self) -> None:
|
||||
"""Initialize available backends based on settings."""
|
||||
has_gpu = settings.use_gpu and torch.cuda.is_available()
|
||||
def _determine_device(self) -> str:
|
||||
"""Determine device based on settings."""
|
||||
if settings.use_gpu and torch.cuda.is_available():
|
||||
return "cuda"
|
||||
return "cpu"
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize backends."""
|
||||
if self._backends:
|
||||
logger.debug("Using existing backend instances")
|
||||
return
|
||||
|
||||
device = self._determine_device()
|
||||
|
||||
try:
|
||||
if has_gpu:
|
||||
if device == "cuda":
|
||||
if settings.use_onnx:
|
||||
# ONNX GPU primary
|
||||
self._backends['onnx_gpu'] = ONNXGPUBackend()
|
||||
self._current_backend = 'onnx_gpu'
|
||||
logger.info("Initialized ONNX GPU backend")
|
||||
|
||||
# PyTorch GPU fallback
|
||||
self._backends['pytorch_gpu'] = PyTorchGPUBackend()
|
||||
logger.info("Initialized PyTorch GPU backend")
|
||||
logger.info("Initialized new ONNX GPU backend")
|
||||
else:
|
||||
# PyTorch GPU primary
|
||||
self._backends['pytorch_gpu'] = PyTorchGPUBackend()
|
||||
self._current_backend = 'pytorch_gpu'
|
||||
logger.info("Initialized PyTorch GPU backend")
|
||||
logger.info("Initialized new PyTorch GPU backend")
|
||||
else:
|
||||
if settings.use_onnx:
|
||||
self._backends['onnx_cpu'] = ONNXCPUBackend()
|
||||
self._current_backend = 'onnx_cpu'
|
||||
logger.info("Initialized new ONNX CPU backend")
|
||||
else:
|
||||
self._backends['pytorch_cpu'] = PyTorchCPUBackend()
|
||||
self._current_backend = 'pytorch_cpu'
|
||||
logger.info("Initialized new PyTorch CPU backend")
|
||||
|
||||
# ONNX GPU fallback
|
||||
self._backends['onnx_gpu'] = ONNXGPUBackend()
|
||||
logger.info("Initialized ONNX GPU backend")
|
||||
else:
|
||||
self._initialize_cpu_backends()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize GPU backends: {e}")
|
||||
# Fallback to CPU if GPU fails
|
||||
self._initialize_cpu_backends()
|
||||
|
||||
def _initialize_cpu_backends(self) -> None:
|
||||
"""Initialize CPU backends based on settings."""
|
||||
try:
|
||||
if settings.use_onnx:
|
||||
# ONNX CPU primary
|
||||
self._backends['onnx_cpu'] = ONNXCPUBackend()
|
||||
self._current_backend = 'onnx_cpu'
|
||||
logger.info("Initialized ONNX CPU backend")
|
||||
# Initialize locks for each backend
|
||||
for backend in self._backends:
|
||||
self._backend_locks[backend] = asyncio.Lock()
|
||||
|
||||
# PyTorch CPU fallback
|
||||
self._backends['pytorch_cpu'] = PyTorchCPUBackend()
|
||||
logger.info("Initialized PyTorch CPU backend")
|
||||
else:
|
||||
# PyTorch CPU primary
|
||||
self._backends['pytorch_cpu'] = PyTorchCPUBackend()
|
||||
self._current_backend = 'pytorch_cpu'
|
||||
logger.info("Initialized PyTorch CPU backend")
|
||||
|
||||
# ONNX CPU fallback
|
||||
self._backends['onnx_cpu'] = ONNXCPUBackend()
|
||||
logger.info("Initialized ONNX CPU backend")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize CPU backends: {e}")
|
||||
raise RuntimeError("No backends available")
|
||||
logger.error(f"Failed to initialize backend: {e}")
|
||||
raise RuntimeError("Failed to initialize backend")
|
||||
|
||||
def get_backend(self, backend_type: Optional[str] = None) -> BaseModelBackend:
|
||||
"""Get specified backend.
|
||||
|
@ -154,19 +158,42 @@ class ModelManager:
|
|||
if backend_type is None:
|
||||
backend_type = self._determine_backend(abs_path)
|
||||
|
||||
backend = self.get_backend(backend_type)
|
||||
# Get backend lock
|
||||
lock = self._backend_locks[backend_type]
|
||||
|
||||
# Load model
|
||||
await backend.load_model(abs_path)
|
||||
logger.info(f"Loaded model on {backend_type} backend")
|
||||
|
||||
# Run warmup if voice provided
|
||||
if warmup_voice is not None:
|
||||
await self._warmup_inference(backend, warmup_voice)
|
||||
async with lock:
|
||||
backend = self.get_backend(backend_type)
|
||||
|
||||
# For ONNX backends, use session pool
|
||||
if backend_type.startswith('onnx'):
|
||||
pool = self._session_pools[backend_type]
|
||||
backend._session = await pool.get_session(abs_path)
|
||||
self._loaded_models[backend_type] = abs_path
|
||||
logger.info(f"Fetched model instance from {backend_type} pool")
|
||||
|
||||
# For PyTorch backends, load normally
|
||||
else:
|
||||
# Check if model is already loaded
|
||||
if (backend_type in self._loaded_models and
|
||||
self._loaded_models[backend_type] == abs_path and
|
||||
backend.is_loaded):
|
||||
logger.info(f"Fetching existing model instance from {backend_type}")
|
||||
return
|
||||
|
||||
# Load model
|
||||
await backend.load_model(abs_path)
|
||||
self._loaded_models[backend_type] = abs_path
|
||||
logger.info(f"Initialized new model instance on {backend_type}")
|
||||
|
||||
# Run warmup if voice provided
|
||||
if warmup_voice is not None:
|
||||
await self._warmup_inference(backend, warmup_voice)
|
||||
|
||||
except Exception as e:
|
||||
# Clear cached path on failure
|
||||
self._loaded_models.pop(backend_type, None)
|
||||
raise RuntimeError(f"Failed to load model: {e}")
|
||||
|
||||
|
||||
async def _warmup_inference(self, backend: BaseModelBackend, voice: torch.Tensor) -> None:
|
||||
"""Run warmup inference to initialize model.
|
||||
|
||||
|
@ -188,7 +215,7 @@ class ModelManager:
|
|||
|
||||
# Run inference
|
||||
backend.generate(tokens, voice, speed=1.0)
|
||||
logger.info("Completed warmup inference")
|
||||
logger.debug("Completed warmup inference")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Warmup inference failed: {e}")
|
||||
|
@ -221,16 +248,23 @@ class ModelManager:
|
|||
|
||||
try:
|
||||
# Generate audio using provided voice tensor
|
||||
# No lock needed here since inference is thread-safe
|
||||
return backend.generate(tokens, voice, speed)
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Generation failed: {e}")
|
||||
|
||||
def unload_all(self) -> None:
|
||||
"""Unload models from all backends."""
|
||||
"""Unload models from all backends and clear cache."""
|
||||
# Clean up session pools
|
||||
for pool in self._session_pools.values():
|
||||
pool.cleanup()
|
||||
|
||||
# Unload PyTorch backends
|
||||
for backend in self._backends.values():
|
||||
backend.unload()
|
||||
logger.info("Unloaded all models")
|
||||
|
||||
self._loaded_models.clear()
|
||||
logger.info("Unloaded all models and cleared cache")
|
||||
|
||||
@property
|
||||
def available_backends(self) -> list[str]:
|
||||
|
@@ -251,12 +285,8 @@ class ModelManager:
        return self._current_backend


# Module-level instance
_manager: Optional[ModelManager] = None


def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
    """Get or create global model manager instance.
async def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
    """Get global model manager instance.

    Args:
        config: Optional model configuration

@@ -264,7 +294,10 @@ def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
    Returns:
        ModelManager instance
    """
    global _manager
    if _manager is None:
        _manager = ModelManager(config)
    return _manager
    global _manager_instance

    async with _manager_lock:
        if _manager_instance is None:
            _manager_instance = ModelManager(config)
            await _manager_instance.initialize()
        return _manager_instance
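
The old synchronous get_manager() becomes an async accessor guarded by a module-level asyncio.Lock, so concurrent callers share a single, fully initialized ModelManager. A stand-alone sketch of the same pattern with generic names (HeavyResource is a placeholder, not repo code):

# Stand-alone sketch of the async, lock-guarded singleton pattern used above.
# `HeavyResource` and its initialize() are placeholders, not part of the repo.
import asyncio
from typing import Optional


class HeavyResource:
    async def initialize(self) -> None:
        await asyncio.sleep(0.1)  # pretend to load a model


_instance: Optional[HeavyResource] = None
_lock = asyncio.Lock()


async def get_instance() -> HeavyResource:
    """Return the shared instance, initializing it exactly once."""
    global _instance
    async with _lock:                 # serialize first-time initialization
        if _instance is None:
            obj = HeavyResource()
            await obj.initialize()    # safe to await while holding the lock
            _instance = obj
    return _instance


async def main() -> None:
    # Concurrent callers all receive the same initialized object.
    a, b = await asyncio.gather(get_instance(), get_instance())
    assert a is b


asyncio.run(main())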
@ -1,20 +1,16 @@
|
|||
"""CPU-based ONNX inference backend."""
|
||||
|
||||
from typing import Dict, Optional
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
from onnxruntime import (
|
||||
ExecutionMode,
|
||||
GraphOptimizationLevel,
|
||||
InferenceSession,
|
||||
SessionOptions
|
||||
)
|
||||
from onnxruntime import InferenceSession
|
||||
|
||||
from ..core import paths
|
||||
from ..core.model_config import model_config
|
||||
from .base import BaseModelBackend
|
||||
from .session_pool import create_session_options, create_provider_options
|
||||
|
||||
|
||||
class ONNXCPUBackend(BaseModelBackend):
|
||||
|
@ -47,8 +43,8 @@ class ONNXCPUBackend(BaseModelBackend):
|
|||
logger.info(f"Loading ONNX model: {model_path}")
|
||||
|
||||
# Configure session
|
||||
options = self._create_session_options()
|
||||
provider_options = self._create_provider_options()
|
||||
options = create_session_options(is_gpu=False)
|
||||
provider_options = create_provider_options(is_gpu=False)
|
||||
|
||||
# Create session
|
||||
self._session = InferenceSession(
|
||||
|
@ -84,9 +80,9 @@ class ONNXCPUBackend(BaseModelBackend):
|
|||
raise RuntimeError("Model not loaded")
|
||||
|
||||
try:
|
||||
# Prepare inputs
|
||||
tokens_input = np.array([tokens], dtype=np.int64)
|
||||
style_input = voice[len(tokens)].numpy()
|
||||
# Prepare inputs with start/end tokens
|
||||
tokens_input = np.array([[0, *tokens, 0]], dtype=np.int64) # Add start/end tokens
|
||||
style_input = voice[len(tokens) + 2].numpy() # Adjust index for start/end tokens
|
||||
speed_input = np.full(1, speed, dtype=np.float32)
|
||||
|
||||
# Run inference
|
||||
|
@ -104,52 +100,6 @@ class ONNXCPUBackend(BaseModelBackend):
|
|||
except Exception as e:
|
||||
raise RuntimeError(f"Generation failed: {e}")
|
||||
|
||||
def _create_session_options(self) -> SessionOptions:
|
||||
"""Create ONNX session options.
|
||||
|
||||
Returns:
|
||||
Configured session options
|
||||
"""
|
||||
options = SessionOptions()
|
||||
config = model_config.onnx_cpu
|
||||
|
||||
# Set optimization level
|
||||
if config.optimization_level == "all":
|
||||
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
elif config.optimization_level == "basic":
|
||||
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
|
||||
else:
|
||||
options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
|
||||
|
||||
# Configure threading
|
||||
options.intra_op_num_threads = config.num_threads
|
||||
options.inter_op_num_threads = config.inter_op_threads
|
||||
|
||||
# Set execution mode
|
||||
options.execution_mode = (
|
||||
ExecutionMode.ORT_PARALLEL
|
||||
if config.execution_mode == "parallel"
|
||||
else ExecutionMode.ORT_SEQUENTIAL
|
||||
)
|
||||
|
||||
# Configure memory optimization
|
||||
options.enable_mem_pattern = config.memory_pattern
|
||||
|
||||
return options
|
||||
|
||||
def _create_provider_options(self) -> Dict:
|
||||
"""Create CPU provider options.
|
||||
|
||||
Returns:
|
||||
Provider configuration
|
||||
"""
|
||||
return {
|
||||
"CPUExecutionProvider": {
|
||||
"arena_extend_strategy": model_config.onnx_cpu.arena_extend_strategy,
|
||||
"cpu_memory_arena_cfg": "cpu:0"
|
||||
}
|
||||
}
|
||||
|
||||
def unload(self) -> None:
|
||||
"""Unload model and free resources."""
|
||||
if self._session is not None:
|
||||
|
|
|
@ -1,20 +1,16 @@
|
|||
"""GPU-based ONNX inference backend."""
|
||||
|
||||
from typing import Dict, Optional
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
from onnxruntime import (
|
||||
ExecutionMode,
|
||||
GraphOptimizationLevel,
|
||||
InferenceSession,
|
||||
SessionOptions
|
||||
)
|
||||
from onnxruntime import InferenceSession
|
||||
|
||||
from ..core import paths
|
||||
from ..core.model_config import model_config
|
||||
from .base import BaseModelBackend
|
||||
from .session_pool import create_session_options, create_provider_options
|
||||
|
||||
|
||||
class ONNXGPUBackend(BaseModelBackend):
|
||||
|
@ -27,6 +23,9 @@ class ONNXGPUBackend(BaseModelBackend):
|
|||
raise RuntimeError("CUDA not available")
|
||||
self._device = "cuda"
|
||||
self._session: Optional[InferenceSession] = None
|
||||
|
||||
# Configure GPU
|
||||
torch.cuda.set_device(model_config.onnx_gpu.device_id)
|
||||
|
||||
@property
|
||||
def is_loaded(self) -> bool:
|
||||
|
@ -49,8 +48,8 @@ class ONNXGPUBackend(BaseModelBackend):
|
|||
logger.info(f"Loading ONNX model on GPU: {model_path}")
|
||||
|
||||
# Configure session
|
||||
options = self._create_session_options()
|
||||
provider_options = self._create_provider_options()
|
||||
options = create_session_options(is_gpu=True)
|
||||
provider_options = create_provider_options(is_gpu=True)
|
||||
|
||||
# Create session with CUDA provider
|
||||
self._session = InferenceSession(
|
||||
|
@ -87,8 +86,8 @@ class ONNXGPUBackend(BaseModelBackend):
|
|||
|
||||
try:
|
||||
# Prepare inputs
|
||||
tokens_input = np.array([tokens], dtype=np.int64)
|
||||
style_input = voice[len(tokens)].cpu().numpy() # Move to CPU for ONNX
|
||||
tokens_input = np.array([[0, *tokens, 0]], dtype=np.int64) # Add start/end tokens
|
||||
style_input = voice[len(tokens) + 2].cpu().numpy() # Move to CPU for ONNX
|
||||
speed_input = np.full(1, speed, dtype=np.float32)
|
||||
|
||||
# Run inference
|
||||
|
@ -104,62 +103,15 @@ class ONNXGPUBackend(BaseModelBackend):
|
|||
return result[0]
|
||||
|
||||
except Exception as e:
|
||||
if "out of memory" in str(e).lower():
|
||||
# Clear CUDA cache and retry
|
||||
torch.cuda.empty_cache()
|
||||
return self.generate(tokens, voice, speed)
|
||||
raise RuntimeError(f"Generation failed: {e}")
|
||||
|
||||
def _create_session_options(self) -> SessionOptions:
|
||||
"""Create ONNX session options.
|
||||
|
||||
Returns:
|
||||
Configured session options
|
||||
"""
|
||||
options = SessionOptions()
|
||||
config = model_config.onnx_gpu
|
||||
|
||||
# Set optimization level
|
||||
if config.optimization_level == "all":
|
||||
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
elif config.optimization_level == "basic":
|
||||
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
|
||||
else:
|
||||
options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
|
||||
|
||||
# Configure threading
|
||||
options.intra_op_num_threads = config.num_threads
|
||||
options.inter_op_num_threads = config.inter_op_threads
|
||||
|
||||
# Set execution mode
|
||||
options.execution_mode = (
|
||||
ExecutionMode.ORT_PARALLEL
|
||||
if config.execution_mode == "parallel"
|
||||
else ExecutionMode.ORT_SEQUENTIAL
|
||||
)
|
||||
|
||||
# Configure memory optimization
|
||||
options.enable_mem_pattern = config.memory_pattern
|
||||
|
||||
return options
|
||||
|
||||
def _create_provider_options(self) -> Dict:
|
||||
"""Create CUDA provider options.
|
||||
|
||||
Returns:
|
||||
Provider configuration
|
||||
"""
|
||||
config = model_config.onnx_gpu
|
||||
return {
|
||||
"CUDAExecutionProvider": {
|
||||
"device_id": config.device_id,
|
||||
"arena_extend_strategy": config.arena_extend_strategy,
|
||||
"gpu_mem_limit": int(config.gpu_mem_limit * torch.cuda.get_device_properties(0).total_memory),
|
||||
"cudnn_conv_algo_search": config.cudnn_conv_algo_search,
|
||||
"do_copy_in_default_stream": config.do_copy_in_default_stream
|
||||
}
|
||||
}
|
||||
|
||||
def unload(self) -> None:
|
||||
"""Unload model and free resources."""
|
||||
if self._session is not None:
|
||||
del self._session
|
||||
self._session = None
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.empty_cache()
|
|
@@ -13,8 +13,37 @@ from ..core.model_config import model_config
from .base import BaseModelBackend


class CUDAStreamManager:
    """CUDA stream manager."""

    def __init__(self, num_streams: int):
        """Initialize stream manager.

        Args:
            num_streams: Number of CUDA streams
        """
        self.streams = [torch.cuda.Stream() for _ in range(num_streams)]
        self._current = 0

    def get_next_stream(self) -> torch.cuda.Stream:
        """Get next available stream.

        Returns:
            CUDA stream
        """
        stream = self.streams[self._current]
        self._current = (self._current + 1) % len(self.streams)
        return stream


@torch.no_grad()
def forward(model: torch.nn.Module, tokens: list[int], ref_s: torch.Tensor, speed: float) -> np.ndarray:
def forward(
    model: torch.nn.Module,
    tokens: list[int],
    ref_s: torch.Tensor,
    speed: float,
    stream: Optional[torch.cuda.Stream] = None
) -> np.ndarray:
    """Forward pass through model.

    Args:
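
CUDAStreamManager is a simple round-robin allocator: each generate() call runs its forward pass inside torch.cuda.stream(...) on the next stream in the ring, letting independent requests overlap on the GPU. A small usage sketch under the assumption of a CUDA-capable machine; do_work is a made-up stand-in for the model forward pass, not repo code:

# Illustrative round-robin stream usage, mirroring the manager above.
# Assumes CUDA is available; `do_work` is a placeholder GPU workload.
import torch


class CUDAStreamManager:
    def __init__(self, num_streams: int):
        self.streams = [torch.cuda.Stream() for _ in range(num_streams)]
        self._current = 0

    def get_next_stream(self) -> torch.cuda.Stream:
        stream = self.streams[self._current]
        self._current = (self._current + 1) % len(self.streams)
        return stream


def do_work(x: torch.Tensor) -> torch.Tensor:
    return (x @ x.T).relu()  # placeholder workload


if torch.cuda.is_available():
    manager = CUDAStreamManager(num_streams=2)
    inputs = [torch.randn(512, 512, device="cuda") for _ in range(4)]
    results = []
    for x in inputs:
        stream = manager.get_next_stream()
        with torch.cuda.stream(stream):   # enqueue work on alternating streams
            results.append(do_work(x))
    torch.cuda.synchronize()              # wait for all streams before reading results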
@ -22,59 +51,67 @@ def forward(model: torch.nn.Module, tokens: list[int], ref_s: torch.Tensor, spee
|
|||
tokens: Input tokens
|
||||
ref_s: Reference signal (shape: [1, n_features])
|
||||
speed: Speed multiplier
|
||||
stream: Optional CUDA stream
|
||||
|
||||
Returns:
|
||||
Generated audio
|
||||
"""
|
||||
device = ref_s.device
|
||||
|
||||
# Initial tensor setup
|
||||
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
|
||||
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
|
||||
text_mask = length_to_mask(input_lengths).to(device)
|
||||
# Use provided stream or default
|
||||
with torch.cuda.stream(stream) if stream else torch.cuda.default_stream():
|
||||
# Initial tensor setup
|
||||
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
|
||||
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
|
||||
text_mask = length_to_mask(input_lengths).to(device)
|
||||
|
||||
# Split reference signals (style_dim=128 from config)
|
||||
style_dim = 128
|
||||
s_ref = ref_s[:, :style_dim].clone().to(device)
|
||||
s_content = ref_s[:, style_dim:].clone().to(device)
|
||||
# Split reference signals (style_dim=128 from config)
|
||||
style_dim = 128
|
||||
s_ref = ref_s[:, :style_dim].clone().to(device)
|
||||
s_content = ref_s[:, style_dim:].clone().to(device)
|
||||
|
||||
# BERT and encoder pass
|
||||
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
|
||||
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
|
||||
# BERT and encoder pass
|
||||
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
|
||||
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
|
||||
|
||||
# Predictor forward pass
|
||||
d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
|
||||
x, _ = model.predictor.lstm(d)
|
||||
# Predictor forward pass
|
||||
d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
|
||||
x, _ = model.predictor.lstm(d)
|
||||
|
||||
# Duration prediction
|
||||
duration = model.predictor.duration_proj(x)
|
||||
duration = torch.sigmoid(duration).sum(axis=-1) / speed
|
||||
pred_dur = torch.round(duration).clamp(min=1).long()
|
||||
del duration, x
|
||||
# Duration prediction
|
||||
duration = model.predictor.duration_proj(x)
|
||||
duration = torch.sigmoid(duration).sum(axis=-1) / speed
|
||||
pred_dur = torch.round(duration).clamp(min=1).long()
|
||||
del duration, x
|
||||
|
||||
# Alignment matrix construction
|
||||
pred_aln_trg = torch.zeros(input_lengths.item(), pred_dur.sum().item(), device=device)
|
||||
c_frame = 0
|
||||
for i in range(pred_aln_trg.size(0)):
|
||||
pred_aln_trg[i, c_frame:c_frame + pred_dur[0, i].item()] = 1
|
||||
c_frame += pred_dur[0, i].item()
|
||||
pred_aln_trg = pred_aln_trg.unsqueeze(0)
|
||||
# Alignment matrix construction
|
||||
pred_aln_trg = torch.zeros(input_lengths.item(), pred_dur.sum().item(), device=device)
|
||||
c_frame = 0
|
||||
for i in range(pred_aln_trg.size(0)):
|
||||
pred_aln_trg[i, c_frame:c_frame + pred_dur[0, i].item()] = 1
|
||||
c_frame += pred_dur[0, i].item()
|
||||
pred_aln_trg = pred_aln_trg.unsqueeze(0)
|
||||
|
||||
# Matrix multiplications
|
||||
en = d.transpose(-1, -2) @ pred_aln_trg
|
||||
del d
|
||||
|
||||
F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
|
||||
del en
|
||||
# Matrix multiplications
|
||||
en = d.transpose(-1, -2) @ pred_aln_trg
|
||||
del d
|
||||
|
||||
F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
|
||||
del en
|
||||
|
||||
# Final text encoding and decoding
|
||||
t_en = model.text_encoder(tokens, input_lengths, text_mask)
|
||||
asr = t_en @ pred_aln_trg
|
||||
del t_en
|
||||
# Final text encoding and decoding
|
||||
t_en = model.text_encoder(tokens, input_lengths, text_mask)
|
||||
asr = t_en @ pred_aln_trg
|
||||
del t_en
|
||||
|
||||
# Generate output
|
||||
output = model.decoder(asr, F0_pred, N_pred, s_ref)
|
||||
return output.squeeze().cpu().numpy()
|
||||
# Generate output
|
||||
output = model.decoder(asr, F0_pred, N_pred, s_ref)
|
||||
|
||||
# Ensure operation completion if using custom stream
|
||||
if stream:
|
||||
stream.synchronize()
|
||||
|
||||
return output.squeeze().cpu().numpy()
|
||||
|
||||
|
||||
def length_to_mask(lengths: torch.Tensor) -> torch.Tensor:
|
||||
|
@ -92,9 +129,10 @@ class PyTorchGPUBackend(BaseModelBackend):
|
|||
def __init__(self):
|
||||
"""Initialize GPU backend."""
|
||||
super().__init__()
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError("CUDA not available")
|
||||
self._device = "cuda"
|
||||
from ..core.config import settings
|
||||
if not (settings.use_gpu and torch.cuda.is_available()):
|
||||
raise RuntimeError("GPU backend requires GPU support and use_gpu=True")
|
||||
self._device = "cuda" # Device is enforced by backend selection in model_manager
|
||||
self._model: Optional[torch.nn.Module] = None
|
||||
|
||||
# Configure GPU settings
|
||||
|
@ -102,6 +140,9 @@ class PyTorchGPUBackend(BaseModelBackend):
|
|||
if config.sync_cuda:
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.set_device(config.device_id)
|
||||
|
||||
# Initialize stream manager
|
||||
self._stream_manager = CUDAStreamManager(config.cuda_streams)
|
||||
|
||||
async def load_model(self, path: str) -> None:
|
||||
"""Load PyTorch model.
|
||||
|
@ -154,8 +195,11 @@ class PyTorchGPUBackend(BaseModelBackend):
|
|||
if ref_s.dim() == 1:
|
||||
ref_s = ref_s.unsqueeze(0) # Add batch dimension if needed
|
||||
|
||||
# Generate audio
|
||||
return forward(self._model, tokens, ref_s, speed)
|
||||
# Get next available stream
|
||||
stream = self._stream_manager.get_next_stream()
|
||||
|
||||
# Generate audio using stream
|
||||
return forward(self._model, tokens, ref_s, speed, stream)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Generation failed: {e}")
|
||||
|
|
api/src/inference/session_pool.py (new file, 271 lines)
|
@ -0,0 +1,271 @@
|
|||
"""Session pooling for model inference."""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Set
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
from onnxruntime import (
|
||||
ExecutionMode,
|
||||
GraphOptimizationLevel,
|
||||
InferenceSession,
|
||||
SessionOptions
|
||||
)
|
||||
|
||||
from ..core import paths
|
||||
from ..core.model_config import model_config
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionInfo:
|
||||
"""Session information."""
|
||||
session: InferenceSession
|
||||
last_used: float
|
||||
stream_id: Optional[int] = None
|
||||
|
||||
|
||||
def create_session_options(is_gpu: bool = False) -> SessionOptions:
|
||||
"""Create ONNX session options.
|
||||
|
||||
Args:
|
||||
is_gpu: Whether to use GPU configuration
|
||||
|
||||
Returns:
|
||||
Configured session options
|
||||
"""
|
||||
options = SessionOptions()
|
||||
config = model_config.onnx_gpu if is_gpu else model_config.onnx_cpu
|
||||
|
||||
# Set optimization level
|
||||
if config.optimization_level == "all":
|
||||
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
elif config.optimization_level == "basic":
|
||||
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
|
||||
else:
|
||||
options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
|
||||
|
||||
# Configure threading
|
||||
options.intra_op_num_threads = config.num_threads
|
||||
options.inter_op_num_threads = config.inter_op_threads
|
||||
|
||||
# Set execution mode
|
||||
options.execution_mode = (
|
||||
ExecutionMode.ORT_PARALLEL
|
||||
if config.execution_mode == "parallel"
|
||||
else ExecutionMode.ORT_SEQUENTIAL
|
||||
)
|
||||
|
||||
# Configure memory optimization
|
||||
options.enable_mem_pattern = config.memory_pattern
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def create_provider_options(is_gpu: bool = False) -> Dict:
|
||||
"""Create provider options.
|
||||
|
||||
Args:
|
||||
is_gpu: Whether to use GPU configuration
|
||||
|
||||
Returns:
|
||||
Provider configuration
|
||||
"""
|
||||
if is_gpu:
|
||||
config = model_config.onnx_gpu
|
||||
return {
|
||||
"CUDAExecutionProvider": {
|
||||
"device_id": config.device_id,
|
||||
"arena_extend_strategy": config.arena_extend_strategy,
|
||||
"gpu_mem_limit": int(config.gpu_mem_limit * torch.cuda.get_device_properties(0).total_memory),
|
||||
"cudnn_conv_algo_search": config.cudnn_conv_algo_search,
|
||||
"do_copy_in_default_stream": config.do_copy_in_default_stream
|
||||
}
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"CPUExecutionProvider": {
|
||||
"arena_extend_strategy": model_config.onnx_cpu.arena_extend_strategy,
|
||||
"cpu_memory_arena_cfg": "cpu:0"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class BaseSessionPool:
|
||||
"""Base session pool implementation."""
|
||||
|
||||
def __init__(self, max_size: int, timeout: int):
|
||||
"""Initialize session pool.
|
||||
|
||||
Args:
|
||||
max_size: Maximum number of concurrent sessions
|
||||
timeout: Session timeout in seconds
|
||||
"""
|
||||
self._max_size = max_size
|
||||
self._timeout = timeout
|
||||
self._sessions: Dict[str, SessionInfo] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def get_session(self, model_path: str) -> InferenceSession:
|
||||
"""Get session from pool.
|
||||
|
||||
Args:
|
||||
model_path: Path to model file
|
||||
|
||||
Returns:
|
||||
ONNX inference session
|
||||
|
||||
Raises:
|
||||
RuntimeError: If no sessions available
|
||||
"""
|
||||
async with self._lock:
|
||||
# Clean expired sessions
|
||||
self._cleanup_expired()
|
||||
|
||||
# Check if session exists and is valid
|
||||
if model_path in self._sessions:
|
||||
session_info = self._sessions[model_path]
|
||||
session_info.last_used = time.time()
|
||||
return session_info.session
|
||||
|
||||
# Check if we can create new session
|
||||
if len(self._sessions) >= self._max_size:
|
||||
raise RuntimeError(
|
||||
f"Maximum number of sessions reached ({self._max_size})"
|
||||
)
|
||||
|
||||
# Create new session
|
||||
session = await self._create_session(model_path)
|
||||
self._sessions[model_path] = SessionInfo(
|
||||
session=session,
|
||||
last_used=time.time()
|
||||
)
|
||||
return session
|
||||
|
||||
def _cleanup_expired(self) -> None:
|
||||
"""Remove expired sessions."""
|
||||
current_time = time.time()
|
||||
expired = [
|
||||
path for path, info in self._sessions.items()
|
||||
if current_time - info.last_used > self._timeout
|
||||
]
|
||||
for path in expired:
|
||||
logger.info(f"Removing expired session: {path}")
|
||||
del self._sessions[path]
|
||||
|
||||
async def _create_session(self, model_path: str) -> InferenceSession:
|
||||
"""Create new session.
|
||||
|
||||
Args:
|
||||
model_path: Path to model file
|
||||
|
||||
Returns:
|
||||
ONNX inference session
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""Clean up all sessions."""
|
||||
self._sessions.clear()
|
||||
|
||||
|
||||
class StreamingSessionPool(BaseSessionPool):
|
||||
"""GPU session pool with CUDA streams."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize GPU session pool."""
|
||||
config = model_config.onnx_gpu
|
||||
super().__init__(config.cuda_streams, config.stream_timeout)
|
||||
self._available_streams: Set[int] = set(range(config.cuda_streams))
|
||||
|
||||
async def get_session(self, model_path: str) -> InferenceSession:
|
||||
"""Get session with CUDA stream.
|
||||
|
||||
Args:
|
||||
model_path: Path to model file
|
||||
|
||||
Returns:
|
||||
ONNX inference session
|
||||
|
||||
Raises:
|
||||
RuntimeError: If no streams available
|
||||
"""
|
||||
async with self._lock:
|
||||
# Clean expired sessions
|
||||
self._cleanup_expired()
|
||||
|
||||
# Try to find existing session
|
||||
if model_path in self._sessions:
|
||||
session_info = self._sessions[model_path]
|
||||
session_info.last_used = time.time()
|
||||
return session_info.session
|
||||
|
||||
# Get available stream
|
||||
if not self._available_streams:
|
||||
raise RuntimeError("No CUDA streams available")
|
||||
stream_id = self._available_streams.pop()
|
||||
|
||||
try:
|
||||
# Create new session
|
||||
session = await self._create_session(model_path)
|
||||
self._sessions[model_path] = SessionInfo(
|
||||
session=session,
|
||||
last_used=time.time(),
|
||||
stream_id=stream_id
|
||||
)
|
||||
return session
|
||||
|
||||
except Exception:
|
||||
# Return stream to pool on failure
|
||||
self._available_streams.add(stream_id)
|
||||
raise
|
||||
|
||||
def _cleanup_expired(self) -> None:
|
||||
"""Remove expired sessions and return streams."""
|
||||
current_time = time.time()
|
||||
expired = [
|
||||
path for path, info in self._sessions.items()
|
||||
if current_time - info.last_used > self._timeout
|
||||
]
|
||||
for path in expired:
|
||||
info = self._sessions[path]
|
||||
if info.stream_id is not None:
|
||||
self._available_streams.add(info.stream_id)
|
||||
logger.info(f"Removing expired session: {path}")
|
||||
del self._sessions[path]
|
||||
|
||||
async def _create_session(self, model_path: str) -> InferenceSession:
|
||||
"""Create new session with CUDA provider."""
|
||||
abs_path = await paths.get_model_path(model_path)
|
||||
options = create_session_options(is_gpu=True)
|
||||
provider_options = create_provider_options(is_gpu=True)
|
||||
|
||||
return InferenceSession(
|
||||
abs_path,
|
||||
sess_options=options,
|
||||
providers=["CUDAExecutionProvider"],
|
||||
provider_options=[provider_options]
|
||||
)
|
||||
|
||||
|
||||
class CPUSessionPool(BaseSessionPool):
|
||||
"""CPU session pool."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize CPU session pool."""
|
||||
config = model_config.onnx_cpu
|
||||
super().__init__(config.max_instances, config.instance_timeout)
|
||||
|
||||
async def _create_session(self, model_path: str) -> InferenceSession:
|
||||
"""Create new session with CPU provider."""
|
||||
abs_path = await paths.get_model_path(model_path)
|
||||
options = create_session_options(is_gpu=False)
|
||||
provider_options = create_provider_options(is_gpu=False)
|
||||
|
||||
return InferenceSession(
|
||||
abs_path,
|
||||
sess_options=options,
|
||||
providers=["CPUExecutionProvider"],
|
||||
provider_options=[provider_options]
|
||||
)
|
|
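
The new session_pool module centralizes ONNX session creation: a pool keyed by model path hands out cached InferenceSession objects, evicts them after a timeout, and (for GPU) ties each session to a CUDA stream slot. A hedged usage sketch of the CPU pool; the import path and model path are assumptions, while get_session/cleanup come from the diff above:

# Hypothetical caller of the CPUSessionPool added in this commit. Run inside an
# event loop; adjust the import and model path to your checkout.
import asyncio

from api.src.inference.session_pool import CPUSessionPool


async def main() -> None:
    pool = CPUSessionPool()

    # First call creates the ONNX InferenceSession; later calls with the same
    # path reuse it until the instance timeout expires.
    session = await pool.get_session("api/src/models/kokoro-v0_19.onnx")
    print([inp.name for inp in session.get_inputs()])

    pool.cleanup()  # drop all cached sessions


asyncio.run(main())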
@@ -1,11 +1,10 @@
"""Voice pack management and caching."""

import os
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional

import torch
from loguru import logger
from pydantic import BaseModel

from ..core import paths
from ..core.config import settings

@@ -13,7 +12,7 @@ from ..structures.model_schemas import VoiceConfig


class VoiceManager:
    """Manages voice loading, caching, and operations."""
    """Manages voice loading and operations."""

    def __init__(self, config: Optional[VoiceConfig] = None):
        """Initialize voice manager.
@ -33,15 +32,8 @@ class VoiceManager:
|
|||
Returns:
|
||||
Path to voice file if exists, None otherwise
|
||||
"""
|
||||
# Get api directory path (two levels up from inference)
|
||||
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
# Construct voice path relative to api directory
|
||||
voice_path = os.path.join(api_dir, settings.voices_dir, f"{voice_name}.pt")
|
||||
|
||||
# Ensure voices directory exists
|
||||
os.makedirs(os.path.dirname(voice_path), exist_ok=True)
|
||||
|
||||
return voice_path if os.path.exists(voice_path) else None
|
||||
|
||||
async def load_voice(self, voice_name: str, device: str = "cpu") -> torch.Tensor:
|
||||
|
@ -66,21 +58,20 @@ class VoiceManager:
|
|||
if self._config.use_cache and cache_key in self._voice_cache:
|
||||
return self._voice_cache[cache_key]
|
||||
|
||||
# Load voice tensor
|
||||
try:
|
||||
# Load voice tensor
|
||||
voice = await paths.load_voice_tensor(voice_path, device=device)
|
||||
|
||||
# Cache if enabled
|
||||
if self._config.use_cache:
|
||||
self._manage_cache()
|
||||
self._voice_cache[cache_key] = voice
|
||||
logger.debug(f"Cached voice: {voice_name} on {device}")
|
||||
|
||||
return voice
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load voice {voice_name}: {e}")
|
||||
|
||||
# Cache if enabled
|
||||
if self._config.use_cache:
|
||||
self._manage_cache()
|
||||
self._voice_cache[cache_key] = voice
|
||||
logger.debug(f"Cached voice: {voice_name} on {device}")
|
||||
|
||||
return voice
|
||||
|
||||
def _manage_cache(self) -> None:
|
||||
"""Manage voice cache size."""
|
||||
if len(self._voice_cache) >= self._config.cache_size:
|
||||
|
@ -123,14 +114,14 @@ class VoiceManager:
|
|||
# Get api directory path
|
||||
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
voices_dir = os.path.join(api_dir, settings.voices_dir)
|
||||
|
||||
# Ensure voices directory exists
|
||||
os.makedirs(voices_dir, exist_ok=True)
|
||||
|
||||
# Save combined voice
|
||||
combined_path = os.path.join(voices_dir, f"{combined_name}.pt")
|
||||
try:
|
||||
torch.save(combined_tensor, combined_path)
|
||||
# Cache the new combined voice
|
||||
self._voice_cache[f"{combined_path}_{device}"] = combined_tensor
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to save combined voice: {e}")
|
||||
|
||||
|
@ -147,17 +138,13 @@ class VoiceManager:
|
|||
"""
|
||||
voices = []
|
||||
try:
|
||||
# Get api directory path
|
||||
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
voices_dir = os.path.join(api_dir, settings.voices_dir)
|
||||
|
||||
# Ensure voices directory exists
|
||||
os.makedirs(voices_dir, exist_ok=True)
|
||||
|
||||
# List voice files
|
||||
for entry in os.listdir(voices_dir):
|
||||
if entry.endswith(".pt"):
|
||||
voices.append(entry[:-3]) # Remove .pt extension
|
||||
voices.append(entry[:-3])
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing voices: {e}")
|
||||
return sorted(voices)
|
||||
|
@ -174,11 +161,8 @@ class VoiceManager:
|
|||
try:
|
||||
if not os.path.exists(voice_path):
|
||||
return False
|
||||
|
||||
# Try loading voice
|
||||
voice = torch.load(voice_path, map_location="cpu")
|
||||
return isinstance(voice, torch.Tensor)
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
@@ -195,12 +179,12 @@ class VoiceManager:
        }


# Module-level instance
_manager: Optional[VoiceManager] = None
# Global singleton instance
_manager_instance = None


def get_manager(config: Optional[VoiceConfig] = None) -> VoiceManager:
    """Get or create global voice manager instance.
async def get_manager(config: Optional[VoiceConfig] = None) -> VoiceManager:
    """Get global voice manager instance.

    Args:
        config: Optional voice configuration

@@ -208,7 +192,7 @@ def get_manager(config: Optional[VoiceConfig] = None) -> VoiceManager:
    Returns:
        VoiceManager instance
    """
    global _manager
    if _manager is None:
        _manager = VoiceManager(config)
    return _manager
    global _manager_instance
    if _manager_instance is None:
        _manager_instance = VoiceManager(config)
    return _manager_instance
|
@ -1,10 +1,13 @@
|
|||
|
||||
"""
|
||||
FastAPI OpenAI Compatible API
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import torch
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
@ -41,19 +44,59 @@ setup_logger()
|
|||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Lifespan context manager for model initialization"""
|
||||
from .inference.model_manager import get_manager
|
||||
from .inference.voice_manager import get_manager as get_voice_manager
|
||||
|
||||
logger.info("Loading TTS model and voice packs...")
|
||||
|
||||
# Initialize service
|
||||
service = TTSService()
|
||||
await service.ensure_initialized()
|
||||
|
||||
# Get available voices
|
||||
voices = await service.list_voices()
|
||||
voicepack_count = len(voices)
|
||||
try:
|
||||
# Initialize managers globally
|
||||
model_manager = await get_manager()
|
||||
voice_manager = await get_voice_manager()
|
||||
|
||||
# Get device info from model manager
|
||||
device = "GPU" if settings.use_gpu else "CPU"
|
||||
model = "ONNX" if settings.use_onnx else "PyTorch"
|
||||
# Determine backend type based on settings
|
||||
if settings.use_gpu and torch.cuda.is_available():
|
||||
backend_type = 'pytorch_gpu' if not settings.use_onnx else 'onnx_gpu'
|
||||
else:
|
||||
backend_type = 'pytorch_cpu' if not settings.use_onnx else 'onnx_cpu'
|
||||
|
||||
# Get backend and initialize model
|
||||
backend = model_manager.get_backend(backend_type)
|
||||
|
||||
# Use model path directly from settings
|
||||
model_file = settings.pytorch_model_file if not settings.use_onnx else settings.onnx_model_file
|
||||
model_path = os.path.join(settings.model_dir, model_file)
|
||||
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
raise RuntimeError(f"Model file not found: {model_path}")
|
||||
|
||||
# Pre-cache default voice and use for warmup
|
||||
warmup_voice = await voice_manager.load_voice(settings.default_voice, device=backend.device)
|
||||
logger.info(f"Pre-cached voice {settings.default_voice} for warmup")
|
||||
|
||||
# Initialize model with warmup voice
|
||||
await model_manager.load_model(model_path, warmup_voice, backend_type)
|
||||
|
||||
# Pre-cache common voices in background
|
||||
common_voices = ['af', 'af_bella', 'af_sarah', 'af_nicole']
|
||||
for voice_name in common_voices:
|
||||
try:
|
||||
await voice_manager.load_voice(voice_name, device=backend.device)
|
||||
logger.debug(f"Pre-cached voice {voice_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to pre-cache voice {voice_name}: {e}")
|
||||
|
||||
# Get available voices for startup message
|
||||
voices = await voice_manager.list_voices()
|
||||
voicepack_count = len(voices)
|
||||
|
||||
# Get device info for startup message
|
||||
device = "GPU" if settings.use_gpu else "CPU"
|
||||
model = "ONNX" if settings.use_onnx else "PyTorch"
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize model: {e}")
|
||||
raise
|
||||
boundary = "░" * 2*12
|
||||
startup_msg = f"""
|
||||
|
||||
|
|
|
@@ -1,7 +1,7 @@
from typing import List

import numpy as np
from fastapi import APIRouter, Depends, HTTPException, Response
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from loguru import logger

from ..services.audio import AudioService

@@ -16,9 +16,9 @@ from ..structures.text_schemas import (
router = APIRouter(tags=["text processing"])


def get_tts_service() -> TTSService:
async def get_tts_service() -> TTSService:
    """Dependency to get TTSService instance"""
    return TTSService()
    return await TTSService.create()  # Create service with properly initialized managers


@router.post("/text/phonemize", response_model=PhonemeResponse, tags=["deprecated"])
@ -82,9 +82,6 @@ async def generate_from_phonemes(
|
|||
)
|
||||
|
||||
try:
|
||||
# Ensure service is initialized
|
||||
await tts_service.ensure_initialized()
|
||||
|
||||
# Validate voice exists
|
||||
available_voices = await tts_service.list_voices()
|
||||
if request.voice not in available_voices:
|
||||
|
|
|
@@ -1,6 +1,6 @@
from typing import AsyncGenerator, List, Union

from fastapi import APIRouter, Depends, Header, HTTPException, Response, Request
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
from loguru import logger


@@ -13,10 +13,28 @@ router = APIRouter(
    responses={404: {"description": "Not found"}},
)

# Global TTSService instance with lock
_tts_service = None
_init_lock = None

def get_tts_service() -> TTSService:
    """Dependency to get TTSService instance with database session"""
    return TTSService()  # Initialize TTSService with default settings
async def get_tts_service() -> TTSService:
    """Get global TTSService instance"""
    global _tts_service, _init_lock

    # Create lock if needed
    if _init_lock is None:
        import asyncio
        _init_lock = asyncio.Lock()

    # Initialize service if needed
    if _tts_service is None:
        async with _init_lock:
            # Double check pattern
            if _tts_service is None:
                _tts_service = await TTSService.create()
                logger.info("Created global TTSService instance")

    return _tts_service


async def process_voices(
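
On the async streaming side this router exposes, a client can consume the OpenAI-compatible speech endpoint chunk by chunk. A hedged client-side sketch: the route (/v1/audio/speech), port, and payload fields are assumptions based on the OpenAI-style API, not verified against the repo's bundled examples:

# Hypothetical async streaming client. Endpoint URL, port, and payload fields are
# assumptions; adjust to match the running server.
import asyncio

import httpx


async def stream_speech(text: str, out_path: str = "out.mp3") -> None:
    payload = {
        "model": "kokoro",
        "input": text,
        "voice": "af",
        "response_format": "mp3",
        "stream": True,
    }
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", "http://localhost:8880/v1/audio/speech", json=payload) as resp:
            resp.raise_for_status()
            with open(out_path, "wb") as f:
                async for chunk in resp.aiter_bytes():   # audio arrives chunk by chunk
                    f.write(chunk)


asyncio.run(stream_speech("Hello from Kokoro."))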
@ -78,11 +96,13 @@ async def stream_audio_chunks(
|
|||
async def create_speech(
|
||||
request: OpenAISpeechRequest,
|
||||
client_request: Request,
|
||||
tts_service: TTSService = Depends(get_tts_service),
|
||||
x_raw_response: str = Header(None, alias="x-raw-response"),
|
||||
):
|
||||
"""OpenAI-compatible endpoint for text-to-speech"""
|
||||
try:
|
||||
# Get global service instance
|
||||
tts_service = await get_tts_service()
|
||||
|
||||
# Process voice combination and validate
|
||||
voice_to_use = await process_voices(request.voice, tts_service)
|
||||
|
||||
|
@ -145,9 +165,10 @@ async def create_speech(
|
|||
|
||||
|
||||
@router.get("/audio/voices")
|
||||
async def list_voices(tts_service: TTSService = Depends(get_tts_service)):
|
||||
async def list_voices():
|
||||
"""List all available voices for text-to-speech"""
|
||||
try:
|
||||
tts_service = await get_tts_service()
|
||||
voices = await tts_service.list_voices()
|
||||
return {"voices": voices}
|
||||
except Exception as e:
|
||||
|
@ -156,9 +177,7 @@ async def list_voices(tts_service: TTSService = Depends(get_tts_service)):
|
|||
|
||||
|
||||
@router.post("/audio/voices/combine")
|
||||
async def combine_voices(
|
||||
request: Union[str, List[str]], tts_service: TTSService = Depends(get_tts_service)
|
||||
):
|
||||
async def combine_voices(request: Union[str, List[str]]):
|
||||
"""Combine multiple voices into a new voice.
|
||||
|
||||
Args:
|
||||
|
@ -174,6 +193,7 @@ async def combine_voices(
|
|||
- 500: Server error (file system issues, combination failed)
|
||||
"""
|
||||
try:
|
||||
tts_service = await get_tts_service()
|
||||
combined_voice = await process_voices(request, tts_service)
|
||||
voices = await tts_service.list_voices()
|
||||
return {"voices": voices, "voice": combined_voice}
|
||||
|
|
|
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
import phonemizer

from .normalizer import normalize_text

phonemizers = {}

class PhonemizerBackend(ABC):
    """Abstract base class for phonemization backends"""

@@ -91,8 +91,9 @@ def phonemize(text: str, language: str = "a", normalize: bool = True) -> str:
    Returns:
        Phonemized text
    """
    global phonemizers
    if normalize:
        text = normalize_text(text)

    phonemizer = create_phonemizer(language)
    return phonemizer.phonemize(text)
    if language not in phonemizers:
        phonemizers[language]=create_phonemizer(language)
    return phonemizers[language].phonemize(text)
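
The change above memoizes one phonemizer backend per language code in a module-level dict, so repeated phonemize() calls stop paying the backend start-up cost. A self-contained sketch of the same memoization idea; the expensive constructor is simulated and is not the repo's create_phonemizer:

# Generic per-key memoization sketch mirroring the phonemizer cache above.
# `build_backend` simulates a slow constructor; it is not repo code.
import time
from typing import Callable, Dict

_backends: Dict[str, Callable[[str], str]] = {}


def build_backend(language: str) -> Callable[[str], str]:
    time.sleep(0.2)  # stand-in for slow backend initialization
    return lambda text: f"[{language}] {text.lower()}"


def phonemize(text: str, language: str = "a") -> str:
    if language not in _backends:          # first call per language pays the cost
        _backends[language] = build_backend(language)
    return _backends[language](text)       # later calls reuse the cached backend


print(phonemize("Hello"))      # slow: builds the "a" backend
print(phonemize("World"))      # fast: cache hit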
@@ -1,9 +1,8 @@
"""TTS service using model and voice managers."""

import io
import os
import time
from typing import List, Tuple
from typing import List, Tuple, Optional

import numpy as np
import scipy.io.wavfile as wavfile

@@ -17,9 +16,14 @@ from .audio import AudioNormalizer, AudioService
from .text_processing import chunker, normalize_text, process_text


import asyncio

class TTSService:
    """Text-to-speech service."""

    # Limit concurrent chunk processing
    _chunk_semaphore = asyncio.Semaphore(4)

    def __init__(self, output_dir: str = None):
        """Initialize service.
@ -27,53 +31,24 @@ class TTSService:
|
|||
output_dir: Optional output directory for saving audio
|
||||
"""
|
||||
self.output_dir = output_dir
|
||||
self.model_manager = get_model_manager()
|
||||
self.voice_manager = get_voice_manager()
|
||||
self._initialized = False
|
||||
self._initialization_error = None
|
||||
self.model_manager = None
|
||||
self._voice_manager = None
|
||||
|
||||
async def ensure_initialized(self):
|
||||
"""Ensure model is initialized."""
|
||||
if self._initialized:
|
||||
return
|
||||
if self._initialization_error:
|
||||
raise self._initialization_error
|
||||
|
||||
try:
|
||||
# Get api directory path (one level up from src)
|
||||
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
@classmethod
|
||||
async def create(cls, output_dir: str = None) -> 'TTSService':
|
||||
"""Create and initialize TTSService instance.
|
||||
|
||||
Args:
|
||||
output_dir: Optional output directory for saving audio
|
||||
|
||||
# Determine model file and backend based on hardware
|
||||
if settings.use_gpu and torch.cuda.is_available():
|
||||
model_file = settings.pytorch_model_file
|
||||
backend_type = 'pytorch_gpu'
|
||||
else:
|
||||
model_file = settings.onnx_model_file
|
||||
backend_type = 'onnx_cpu'
|
||||
|
||||
# Construct model path relative to api directory
|
||||
model_path = os.path.join(api_dir, settings.model_dir, model_file)
|
||||
|
||||
# Ensure model directory exists
|
||||
os.makedirs(os.path.dirname(model_path), exist_ok=True)
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
raise RuntimeError(f"Model file not found: {model_path}")
|
||||
|
||||
# Load default voice for warmup
|
||||
backend = self.model_manager.get_backend(backend_type)
|
||||
warmup_voice = await self.voice_manager.load_voice(settings.default_voice, device=backend.device)
|
||||
logger.info(f"Loaded voice {settings.default_voice} for warmup")
|
||||
|
||||
# Initialize model with warmup voice
|
||||
await self.model_manager.load_model(model_path, warmup_voice, backend_type)
|
||||
logger.info(f"Initialized model on {backend_type} backend")
|
||||
self._initialized = True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize model: {e}")
|
||||
self._initialization_error = RuntimeError(f"Model initialization failed: {e}")
|
||||
raise self._initialization_error
|
||||
Returns:
|
||||
Initialized TTSService instance
|
||||
"""
|
||||
service = cls(output_dir)
|
||||
# Initialize managers
|
||||
service.model_manager = await get_model_manager()
|
||||
service._voice_manager = await get_voice_manager()
|
||||
return service
|
||||
|
||||
    async def generate_audio(
        self, text: str, voice: str, speed: float = 1.0

@@ -88,8 +63,8 @@ class TTSService:
        Returns:
            Audio samples and processing time
        """
        await self.ensure_initialized()
        start_time = time.time()
        voice_tensor = None

        try:
            # Normalize text

@@ -98,31 +73,40 @@ class TTSService:
                raise ValueError("Text is empty after preprocessing")
            text = str(normalized)

            # Process text into chunks
            audio_chunks = []
            for chunk in chunker.split_text(text):
                try:
                    # Convert chunk to token IDs
                    tokens = process_text(chunk)
                    if not tokens:
                        continue
                    # Get backend and load voice
                    backend = self.model_manager.get_backend()
                    voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device)

            # Get backend and load voice
            backend = self.model_manager.get_backend()
            voice_tensor = await self.voice_manager.load_voice(voice, device=backend.device)

                    # Generate audio
                    chunk_audio = await self.model_manager.generate(
                        tokens,
                        voice_tensor,
                        speed=speed
                    )
                    if chunk_audio is not None:
                        audio_chunks.append(chunk_audio)
                except Exception as e:
                    logger.error(f"Failed to process chunk: '{chunk}'. Error: {str(e)}")
                    continue
            # Get all chunks upfront
            chunks = list(chunker.split_text(text))
            if not chunks:
                raise ValueError("No text chunks to process")

            # Process chunk with concurrency control
            async def process_chunk(chunk: str) -> Optional[np.ndarray]:
                async with self._chunk_semaphore:
                    try:
                        tokens = process_text(chunk)
                        if not tokens:
                            return None

                        # Generate audio
                        return await self.model_manager.generate(
                            tokens,
                            voice_tensor,
                            speed=speed
                        )
                    except Exception as e:
                        logger.error(f"Failed to process chunk: '{chunk}'. Error: {str(e)}")
                        return None

            # Process all chunks concurrently
            chunk_results = await asyncio.gather(*[
                process_chunk(chunk) for chunk in chunks
            ])

            # Filter out None results and combine
            audio_chunks = [chunk for chunk in chunk_results if chunk is not None]
            if not audio_chunks:
                raise ValueError("No audio chunks were generated successfully")
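Review note: the rewritten generate_audio path replaces the sequential per-chunk loop with bounded fan-out. A self-contained sketch of that pattern, stripped of the TTS specifics; the 4-way limit, the sample chunks, and the sleep stand-in for inference are illustrative assumptions.

import asyncio
from typing import List, Optional

async def process_all(chunks: List[str], limit: int = 4) -> List[str]:
    # Semaphore caps how many chunks are in flight at once.
    semaphore = asyncio.Semaphore(limit)

    async def process_one(chunk: str) -> Optional[str]:
        async with semaphore:
            try:
                await asyncio.sleep(0.01)  # stand-in for model inference
                return chunk.upper()
            except Exception:
                return None  # failed chunks are dropped, mirroring the service

    # gather() preserves input order, so results line up with the input chunks.
    results = await asyncio.gather(*[process_one(c) for c in chunks])
    return [r for r in results if r is not None]

if __name__ == "__main__":
    print(asyncio.run(process_all(["hello", "world"])))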
@@ -134,6 +118,11 @@ class TTSService:
        except Exception as e:
            logger.error(f"Error in audio generation: {str(e)}")
            raise
        finally:
            # Always clean up voice tensor
            if voice_tensor is not None:
                del voice_tensor
                torch.cuda.empty_cache()
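Review note: the release-in-finally pattern above could also be factored into a small context manager. A sketch under the assumption that voices are loaded onto a CUDA device and that returning cached blocks promptly matters more than the cost of empty_cache(); this helper is not part of the diff.

import contextlib

import torch

@contextlib.contextmanager
def borrowed_voice(voice_tensor: torch.Tensor):
    # Yield the tensor for the duration of generation, then drop the local
    # reference and return cached blocks to the CUDA allocator.
    try:
        yield voice_tensor
    finally:
        del voice_tensor
        if torch.cuda.is_available():
            torch.cuda.empty_cache()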
    async def generate_audio_stream(
        self,

@@ -153,33 +142,34 @@ class TTSService:
        Yields:
            Audio chunks as bytes
        """
        await self.ensure_initialized()

        # Setup audio processing
        stream_normalizer = AudioNormalizer()
        voice_tensor = None

        try:
            # Setup audio processing
            stream_normalizer = AudioNormalizer()

            # Normalize text
            normalized = normalize_text(text)
            if not normalized:
                raise ValueError("Text is empty after preprocessing")
            text = str(normalized)

            # Process chunks
            is_first = True
            chunk_gen = chunker.split_text(text)
            current_chunk = next(chunk_gen, None)
            # Get backend and load voice
            backend = self.model_manager.get_backend()
            voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device)

            # Get all chunks upfront
            chunks = list(chunker.split_text(text))
            if not chunks:
                raise ValueError("No text chunks to process")

            # Process chunk with concurrency control
            async def process_chunk(chunk: str, is_first: bool, is_last: bool) -> Optional[bytes]:
                async with self._chunk_semaphore:
                    try:
                        tokens = process_text(chunk)
                        if not tokens:
                            return None

            while current_chunk is not None:
                next_chunk = next(chunk_gen, None)
                try:
                    # Convert chunk to token IDs
                    tokens = process_text(current_chunk)
                    if tokens:
                        # Get backend and load voice
                        backend = self.model_manager.get_backend()
                        voice_tensor = await self.voice_manager.load_voice(voice, device=backend.device)

                        # Generate audio
                        chunk_audio = await self.model_manager.generate(
                            tokens,

@@ -189,26 +179,38 @@ class TTSService:

                        if chunk_audio is not None:
                            # Convert to bytes
                            chunk_bytes = AudioService.convert_audio(
                        return AudioService.convert_audio(
                            chunk_audio,
                            24000,
                            output_format,
                            is_first_chunk=is_first,
                            normalizer=stream_normalizer,
                            is_last_chunk=(next_chunk is None),
                            is_last_chunk=is_last,
                            stream=True
                        )
                            yield chunk_bytes
                            is_first = False
                    except Exception as e:
                        logger.error(f"Failed to generate audio for chunk: '{chunk}'. Error: {str(e)}")
                        return None

                except Exception as e:
                    logger.error(f"Failed to generate audio for chunk: '{current_chunk}'. Error: {str(e)}")
            # Create tasks for all chunks
            tasks = [
                process_chunk(chunk, i==0, i==len(chunks)-1)
                for i, chunk in enumerate(chunks)
            ]

                current_chunk = next_chunk
            # Process chunks concurrently and yield results in order
            for chunk_bytes in await asyncio.gather(*tasks):
                if chunk_bytes is not None:
                    yield chunk_bytes

        except Exception as e:
            logger.error(f"Error in audio generation stream: {str(e)}")
            raise
        finally:
            # Always clean up voice tensor
            if voice_tensor is not None:
                del voice_tensor
                torch.cuda.empty_cache()
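Review note: a minimal sketch of consuming this generator from calling code; the output file name and the choice of PCM are assumptions, and the positional argument order mirrors how the removed warmup service invoked the generator.

async def save_stream(service: "TTSService", text: str) -> None:
    # Write the yielded byte chunks straight to disk as they arrive.
    with open("stream_output.pcm", "wb") as f:
        async for chunk_bytes in service.generate_audio_stream(text, "af_bella", 1.0, "pcm"):
            f.write(chunk_bytes)

Note that the rewritten generator awaits asyncio.gather() before yielding anything, so chunks arrive in order but only after the slowest chunk finishes; callers that care about time to first byte may want to keep that trade-off in mind.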
    async def combine_voices(self, voices: List[str]) -> str:
        """Combine multiple voices.

@@ -219,8 +221,7 @@ class TTSService:
        Returns:
            Name of combined voice
        """
        await self.ensure_initialized()
        return await self.voice_manager.combine_voices(voices)
        return await self._voice_manager.combine_voices(voices)

    async def list_voices(self) -> List[str]:
        """List available voices.

@@ -228,7 +229,7 @@ class TTSService:
        Returns:
            List of voice names
        """
        return await self.voice_manager.list_voices()
        return await self._voice_manager.list_voices()

    def _audio_to_bytes(self, audio: np.ndarray) -> bytes:
        """Convert audio to WAV bytes.
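Review note: the body of _audio_to_bytes is truncated in this hunk. As a point of reference only, a generic version of such a conversion might look like the following; 24 kHz mono output and float input in [-1, 1] are assumptions, not taken from the repository.

import io
import wave

import numpy as np

def audio_to_wav_bytes(audio: np.ndarray, sample_rate: int = 24000) -> bytes:
    # Scale float samples in [-1, 1] to 16-bit PCM and wrap them in a WAV header.
    pcm = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(1)   # mono
        wav_file.setsampwidth(2)   # 16-bit samples
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(pcm.tobytes())
    return buffer.getvalue()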
@@ -1,60 +0,0 @@
import os
from typing import List, Tuple

import torch
from loguru import logger

from ..core.config import settings
from .tts_model import TTSModel
from .tts_service import TTSService


class WarmupService:
    """Service for warming up TTS models and voice caches"""

    def __init__(self):
        """Initialize warmup service and ensure model is ready"""
        # Initialize model if not already initialized
        if TTSModel._instance is None:
            TTSModel.initialize(settings.model_dir)
        self.tts_service = TTSService()

    def load_voices(self) -> List[Tuple[str, torch.Tensor]]:
        """Load and cache voices up to LRU limit"""
        # Get all voices sorted by filename length (shorter names first, usually base voices)
        voice_files = sorted(
            [f for f in os.listdir(TTSModel.VOICES_DIR) if f.endswith(".pt")], key=len
        )

        n_voices_cache = 1
        loaded_voices = []
        for voice_file in voice_files[:n_voices_cache]:
            try:
                voice_path = os.path.join(TTSModel.VOICES_DIR, voice_file)
                # load using service, lru cache
                voicepack = self.tts_service._load_voice(voice_path)
                loaded_voices.append(
                    (voice_file[:-3], voicepack)
                )  # Store name and tensor
                # voicepack = torch.load(voice_path, map_location=TTSModel.get_device(), weights_only=True)
                # logger.info(f"Loaded voice {voice_file[:-3]} into cache")
            except Exception as e:
                logger.error(f"Failed to load voice {voice_file}: {e}")
        logger.info(f"Pre-loaded {len(loaded_voices)} voices into cache")
        return loaded_voices

    async def warmup_voices(
        self, warmup_text: str, loaded_voices: List[Tuple[str, torch.Tensor]]
    ):
        """Warm up voice inference and streaming"""
        n_warmups = 1
        for voice_name, _ in loaded_voices[:n_warmups]:
            try:
                logger.info(f"Running warmup inference on voice {voice_name}")
                async for _ in self.tts_service.generate_audio_stream(
                    warmup_text, voice_name, 1.0, "pcm"
                ):
                    pass  # Process all chunks to properly warm up
                logger.info(f"Completed warmup for voice {voice_name}")
            except Exception as e:
                logger.warning(f"Warmup failed for voice {voice_name}: {e}")
@@ -10,7 +10,7 @@ services:
    ports:
      - "8880:8880"
    environment:
      - PYTHONPATH=/app
      - PYTHONPATH=/app:/app/api
      - USE_GPU=true
      - USE_ONNX=false
      - PYTHONUNBUFFERED=1
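Review note: the widened PYTHONPATH exposes the same package under two roots inside the container. A hypothetical sanity check one could run in the running container; the module names below assume the repository's api/src package layout and are not part of this diff.

import importlib
import os

print(os.environ.get("PYTHONPATH"))  # expected: /app:/app/api

# With /app on the path the package resolves as api.src.*,
# and with /app/api on the path it also resolves as src.*.
for name in ("api.src.core.config", "src.core.config"):
    module = importlib.import_module(name)
    print(name, "->", module.__file__)

The usual caveat with overlapping path roots applies: importing the same file under two names creates two distinct module objects.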
@@ -25,9 +25,7 @@ def main() -> None:
def stream_to_speakers() -> None:
    import pyaudio

    player_stream = pyaudio.PyAudio().open(
        format=pyaudio.paInt16, channels=1, rate=24000, output=True
    )
    player_stream = pyaudio.PyAudio().open(format=pyaudio.paInt16, channels=1, rate=24000, output=True)

    start_time = time.time()
examples/simul_file_test.py (new file, 53 lines)

@@ -0,0 +1,53 @@
#!/usr/bin/env rye run python
import asyncio
import time
from pathlib import Path
from openai import AsyncOpenAI

# Initialize async client
openai = AsyncOpenAI(base_url="http://localhost:8880/v1", api_key="not-needed-for-local")

async def save_to_file(text: str, file_id: int) -> None:
    """Save TTS output to file asynchronously"""
    speech_file_path = Path(__file__).parent / f"speech_{file_id}.mp3"

    start_time = time.time()
    print(f"Starting file {file_id}")

    try:
        # Use streaming endpoint with mp3 format
        async with openai.audio.speech.with_streaming_response.create(
            model="kokoro",
            voice="af_bella",
            input=text,
            response_format="mp3"
        ) as response:
            print(f"File {file_id} - Time to first byte: {int((time.time() - start_time) * 1000)}ms")

            # Open file in binary write mode
            with open(speech_file_path, 'wb') as f:
                async for chunk in response.iter_bytes():
                    f.write(chunk)

        print(f"File {file_id} completed in {int((time.time() - start_time) * 1000)}ms")
    except Exception as e:
        print(f"Error processing file {file_id}: {e}")

async def main() -> None:
    # Different text samples for variety
    texts = [
        "The quick brown fox jumped over the lazy dogs. I see skies of blue and clouds of white",
        "I see skies of blue and clouds of white. I see skies of blue and clouds of white",
    ]

    # Create tasks for saving to files
    file_tasks = [
        save_to_file(text, i)
        for i, text in enumerate(texts)
    ]

    # Run file tasks concurrently
    await asyncio.gather(*file_tasks)

if __name__ == "__main__":
    asyncio.run(main())
examples/simul_openai_streaming_audio.py (new file, 91 lines)

@@ -0,0 +1,91 @@
#!/usr/bin/env rye run python
import asyncio
import time
from pathlib import Path
import pyaudio
from openai import AsyncOpenAI

# Initialize async client
openai = AsyncOpenAI(base_url="http://localhost:8880/v1", api_key="not-needed-for-local")

# Create a shared PyAudio instance
p = pyaudio.PyAudio()

async def stream_to_speakers(text: str, stream_id: int) -> None:
    """Stream TTS audio to speakers asynchronously"""
    player_stream = p.open(
        format=pyaudio.paInt16,
        channels=1,
        rate=24000,
        output=True
    )

    start_time = time.time()
    print(f"Starting stream {stream_id}")

    try:
        async with openai.audio.speech.with_streaming_response.create(
            model="kokoro",
            voice="af_bella",
            response_format="pcm",
            input=text
        ) as response:
            print(f"Stream {stream_id} - Time to first byte: {int((time.time() - start_time) * 1000)}ms")

            async for chunk in response.iter_bytes(chunk_size=1024):
                player_stream.write(chunk)
                # Small sleep to allow other coroutines to run
                await asyncio.sleep(0.001)

        print(f"Stream {stream_id} completed in {int((time.time() - start_time) * 1000)}ms")

    finally:
        player_stream.stop_stream()
        player_stream.close()

async def save_to_file(text: str, file_id: int) -> None:
    """Save TTS output to file asynchronously"""
    speech_file_path = Path(__file__).parent / f"speech_{file_id}.mp3"

    async with openai.audio.speech.with_streaming_response.create(
        model="kokoro",
        voice="af_bella",
        input=text
    ) as response:
        # Open file in binary write mode
        with open(speech_file_path, 'wb') as f:
            async for chunk in response.iter_bytes():
                f.write(chunk)
        print(f"File {file_id} saved to {speech_file_path}")

async def main() -> None:
    # Different text samples for variety
    texts = [
        "The quick brown fox jumped over the lazy dogs. I see skies of blue and clouds of white",
        "I see skies of blue and clouds of white. I see skies of blue and clouds of white",
    ]

    # Create tasks for streaming to speakers
    speaker_tasks = [
        stream_to_speakers(text, i)
        for i, text in enumerate(texts)
    ]

    # Create tasks for saving to files
    file_tasks = [
        save_to_file(text, i)
        for i, text in enumerate(texts)
    ]

    # Combine all tasks
    all_tasks = speaker_tasks + file_tasks

    # Run all tasks concurrently
    try:
        await asyncio.gather(*all_tasks)
    finally:
        # Clean up PyAudio
        p.terminate()

if __name__ == "__main__":
    asyncio.run(main())
examples/simul_speaker_test.py (new file, 66 lines)

@@ -0,0 +1,66 @@
#!/usr/bin/env rye run python
import asyncio
import time
import pyaudio
from openai import AsyncOpenAI

# Initialize async client
openai = AsyncOpenAI(base_url="http://localhost:8880/v1", api_key="not-needed-for-local")

# Create a shared PyAudio instance
p = pyaudio.PyAudio()

async def stream_to_speakers(text: str, stream_id: int) -> None:
    """Stream TTS audio to speakers asynchronously"""
    player_stream = p.open(
        format=pyaudio.paInt16,
        channels=1,
        rate=24000,
        output=True
    )

    start_time = time.time()
    print(f"Starting stream {stream_id}")

    try:
        async with openai.audio.speech.with_streaming_response.create(
            model="kokoro",
            voice="af_bella",
            response_format="pcm",
            input=text
        ) as response:
            print(f"Stream {stream_id} - Time to first byte: {int((time.time() - start_time) * 1000)}ms")

            async for chunk in response.iter_bytes(chunk_size=1024):
                player_stream.write(chunk)
                # Small sleep to allow other coroutines to run
                await asyncio.sleep(0.001)

        print(f"Stream {stream_id} completed in {int((time.time() - start_time) * 1000)}ms")

    finally:
        player_stream.stop_stream()
        player_stream.close()

async def main() -> None:
    # Different text samples for variety
    texts = [
        "The quick brown fox jumped over the lazy dogs. I see skies of blue and clouds of white",
        "I see skies of blue and clouds of white. I see skies of blue and clouds of white",
    ]

    # Create tasks for streaming to speakers
    speaker_tasks = [
        stream_to_speakers(text, i)
        for i, text in enumerate(texts)
    ]

    # Run speaker tasks concurrently
    try:
        await asyncio.gather(*speaker_tasks)
    finally:
        # Clean up PyAudio
        p.terminate()

if __name__ == "__main__":
    asyncio.run(main())
Binary file not shown.