Refactor model loading and configuration: adjust model loading device, add async streaming examples, and remove unused warmup service.

This commit is contained in:
remsky 2025-01-22 02:33:29 -07:00
parent 21bf810f97
commit 4a24be1605
21 changed files with 929 additions and 484 deletions

View file

@@ -367,7 +367,7 @@ async def build_model(path, device):
        decoder=decoder.to(device).eval(),
        text_encoder=text_encoder.to(device).eval(),
    )
-   weights = await load_model_weights(path, device='cpu')
+   weights = await load_model_weights(path, device=device)
    for key, state_dict in weights['net'].items():
        assert key in model, key
        try:
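The device change above means the checkpoint tensors are materialized directly on the target device instead of always on CPU and then copied. A minimal sketch of what such a loader can look like, assuming a plain torch checkpoint; the name load_model_weights comes from the diff, the body here is illustrative:

import torch

async def load_model_weights(path: str, device: str = "cpu") -> dict:
    # map_location controls where the tensors are materialized; passing the
    # target device avoids loading on CPU and copying to the GPU afterwards.
    return torch.load(path, map_location=device, weights_only=True)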

View file

@@ -15,9 +15,9 @@ class Settings(BaseSettings):
    default_voice: str = "af"
    use_gpu: bool = False  # Whether to use GPU acceleration if available
    use_onnx: bool = True  # Whether to use ONNX runtime
-   # Paths relative to api directory
-   model_dir: str = "src/models"  # Model directory relative to api/
-   voices_dir: str = "src/voices"  # Voices directory relative to api/
+   # Container absolute paths
+   model_dir: str = "/app/api/src/models"  # Absolute path in container
+   voices_dir: str = "/app/api/src/voices"  # Absolute path in container
    # Model filenames
    pytorch_model_file: str = "kokoro-v0_19.pth"
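For reference, a minimal sketch of how such absolute container paths behave in a settings class, assuming pydantic v2 with the pydantic-settings package; this is not the project's full Settings definition:

from pydantic_settings import BaseSettings

class Settings(BaseSettings):
    # Absolute paths inside the container; still overridable via environment
    # variables, e.g. MODEL_DIR=/data/models.
    model_dir: str = "/app/api/src/models"
    voices_dir: str = "/app/api/src/voices"

settings = Settings()
print(settings.model_dir)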

View file

@@ -6,6 +6,11 @@ from pydantic import BaseModel, Field
class ONNXCPUConfig(BaseModel):
    """ONNX CPU runtime configuration."""
+   # Session pooling
+   max_instances: int = Field(4, description="Maximum concurrent model instances")
+   instance_timeout: int = Field(300, description="Session timeout in seconds")
+   # Runtime settings
    num_threads: int = Field(8, description="Number of threads for parallel operations")
    inter_op_threads: int = Field(4, description="Number of threads for operator parallelism")
    execution_mode: str = Field("parallel", description="ONNX execution mode")
@@ -20,9 +25,14 @@ class ONNXCPUConfig(BaseModel):
class ONNXGPUConfig(ONNXCPUConfig):
    """ONNX GPU-specific configuration."""
+   # CUDA settings
    device_id: int = Field(0, description="CUDA device ID")
    gpu_mem_limit: float = Field(0.7, description="Fraction of GPU memory to use")
    cudnn_conv_algo_search: str = Field("EXHAUSTIVE", description="CuDNN convolution algorithm search")
+   # Stream management
+   cuda_streams: int = Field(2, description="Number of CUDA streams for inference")
+   stream_timeout: int = Field(60, description="Stream timeout in seconds")
    do_copy_in_default_stream: bool = Field(True, description="Copy in default CUDA stream")
    class Config:
@@ -32,8 +42,6 @@ class ONNXGPUConfig(ONNXCPUConfig):
class PyTorchCPUConfig(BaseModel):
    """PyTorch CPU backend configuration."""
-   max_batch_size: int = Field(32, description="Maximum batch size for batched inference")
-   stream_buffer_size: int = Field(8, description="Size of stream buffer")
    memory_threshold: float = Field(0.8, description="Memory threshold for cleanup")
    retry_on_oom: bool = Field(True, description="Whether to retry on OOM errors")
    num_threads: int = Field(8, description="Number of threads for parallel operations")
@@ -49,18 +57,11 @@ class PyTorchGPUConfig(BaseModel):
    device_id: int = Field(0, description="CUDA device ID")
    use_fp16: bool = Field(True, description="Whether to use FP16 precision")
    use_triton: bool = Field(True, description="Whether to use Triton for CUDA kernels")
-   max_batch_size: int = Field(32, description="Maximum batch size for batched inference")
-   stream_buffer_size: int = Field(8, description="Size of CUDA stream buffer")
    memory_threshold: float = Field(0.8, description="Memory threshold for cleanup")
    retry_on_oom: bool = Field(True, description="Whether to retry on OOM errors")
    sync_cuda: bool = Field(True, description="Whether to synchronize CUDA operations")
+   cuda_streams: int = Field(2, description="Number of CUDA streams for inference")
+   stream_timeout: int = Field(60, description="Stream timeout in seconds")
-   class Config:
-       frozen = True
-   """PyTorch CPU-specific configuration."""
-   num_threads: int = Field(8, description="Number of threads for parallel operations")
-   pin_memory: bool = Field(True, description="Whether to pin memory for faster CPU-GPU transfer")
    class Config:
        frozen = True
@@ -74,7 +75,7 @@ class ModelConfig(BaseModel):
    device_type: str = Field("auto", description="Device type ('cpu', 'gpu', or 'auto')")
    cache_models: bool = Field(True, description="Whether to cache loaded models")
    cache_voices: bool = Field(True, description="Whether to cache voice tensors")
-   voice_cache_size: int = Field(10, description="Maximum number of cached voices")
+   voice_cache_size: int = Field(2, description="Maximum number of cached voices")
    # Backend-specific configs
    onnx_cpu: ONNXCPUConfig = Field(default_factory=ONNXCPUConfig)
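The new pooling and stream fields end up nested inside ModelConfig, so callers read them through the backend-specific sub-configs. A small illustrative sketch of the resulting structure (field names taken from the diff, class bodies trimmed):

from pydantic import BaseModel, Field

class ONNXCPUConfig(BaseModel):
    max_instances: int = Field(4)       # session pool size
    instance_timeout: int = Field(300)  # idle eviction, in seconds

class ModelConfig(BaseModel):
    voice_cache_size: int = Field(2)
    onnx_cpu: ONNXCPUConfig = Field(default_factory=ONNXCPUConfig)

config = ModelConfig()
print(config.onnx_cpu.max_instances, config.voice_cache_size)  # 4 2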

View file

@ -1,5 +1,6 @@
"""Model management and caching.""" """Model management and caching."""
import asyncio
from typing import Dict, Optional from typing import Dict, Optional
import torch import torch
@ -13,6 +14,14 @@ from .onnx_cpu import ONNXCPUBackend
from .onnx_gpu import ONNXGPUBackend from .onnx_gpu import ONNXGPUBackend
from .pytorch_cpu import PyTorchCPUBackend from .pytorch_cpu import PyTorchCPUBackend
from .pytorch_gpu import PyTorchGPUBackend from .pytorch_gpu import PyTorchGPUBackend
from .session_pool import CPUSessionPool, StreamingSessionPool
# Global singleton instance and state
_manager_instance = None
_manager_lock = asyncio.Lock()
_loaded_models = {}
_backends = {}
class ModelManager: class ModelManager:
@ -25,65 +34,60 @@ class ModelManager:
config: Optional configuration config: Optional configuration
""" """
self._config = config or model_config self._config = config or model_config
self._backends: Dict[str, BaseModelBackend] = {} global _loaded_models, _backends
self._current_backend: Optional[str] = None self._loaded_models = _loaded_models
self._initialize_backends() self._backends = _backends
def _initialize_backends(self) -> None: # Initialize session pools
"""Initialize available backends based on settings.""" self._session_pools = {
has_gpu = settings.use_gpu and torch.cuda.is_available() 'onnx_cpu': CPUSessionPool(),
'onnx_gpu': StreamingSessionPool()
}
# Initialize locks
self._backend_locks: Dict[str, asyncio.Lock] = {}
def _determine_device(self) -> str:
"""Determine device based on settings."""
if settings.use_gpu and torch.cuda.is_available():
return "cuda"
return "cpu"
async def initialize(self) -> None:
"""Initialize backends."""
if self._backends:
logger.debug("Using existing backend instances")
return
device = self._determine_device()
try: try:
if has_gpu: if device == "cuda":
if settings.use_onnx: if settings.use_onnx:
# ONNX GPU primary
self._backends['onnx_gpu'] = ONNXGPUBackend() self._backends['onnx_gpu'] = ONNXGPUBackend()
self._current_backend = 'onnx_gpu' self._current_backend = 'onnx_gpu'
logger.info("Initialized ONNX GPU backend") logger.info("Initialized new ONNX GPU backend")
# PyTorch GPU fallback
self._backends['pytorch_gpu'] = PyTorchGPUBackend()
logger.info("Initialized PyTorch GPU backend")
else: else:
# PyTorch GPU primary
self._backends['pytorch_gpu'] = PyTorchGPUBackend() self._backends['pytorch_gpu'] = PyTorchGPUBackend()
self._current_backend = 'pytorch_gpu' self._current_backend = 'pytorch_gpu'
logger.info("Initialized PyTorch GPU backend") logger.info("Initialized new PyTorch GPU backend")
# ONNX GPU fallback
self._backends['onnx_gpu'] = ONNXGPUBackend()
logger.info("Initialized ONNX GPU backend")
else: else:
self._initialize_cpu_backends()
except Exception as e:
logger.error(f"Failed to initialize GPU backends: {e}")
# Fallback to CPU if GPU fails
self._initialize_cpu_backends()
def _initialize_cpu_backends(self) -> None:
"""Initialize CPU backends based on settings."""
try:
if settings.use_onnx: if settings.use_onnx:
# ONNX CPU primary
self._backends['onnx_cpu'] = ONNXCPUBackend() self._backends['onnx_cpu'] = ONNXCPUBackend()
self._current_backend = 'onnx_cpu' self._current_backend = 'onnx_cpu'
logger.info("Initialized ONNX CPU backend") logger.info("Initialized new ONNX CPU backend")
# PyTorch CPU fallback
self._backends['pytorch_cpu'] = PyTorchCPUBackend()
logger.info("Initialized PyTorch CPU backend")
else: else:
# PyTorch CPU primary
self._backends['pytorch_cpu'] = PyTorchCPUBackend() self._backends['pytorch_cpu'] = PyTorchCPUBackend()
self._current_backend = 'pytorch_cpu' self._current_backend = 'pytorch_cpu'
logger.info("Initialized PyTorch CPU backend") logger.info("Initialized new PyTorch CPU backend")
# Initialize locks for each backend
for backend in self._backends:
self._backend_locks[backend] = asyncio.Lock()
# ONNX CPU fallback
self._backends['onnx_cpu'] = ONNXCPUBackend()
logger.info("Initialized ONNX CPU backend")
except Exception as e: except Exception as e:
logger.error(f"Failed to initialize CPU backends: {e}") logger.error(f"Failed to initialize backend: {e}")
raise RuntimeError("No backends available") raise RuntimeError("Failed to initialize backend")
def get_backend(self, backend_type: Optional[str] = None) -> BaseModelBackend: def get_backend(self, backend_type: Optional[str] = None) -> BaseModelBackend:
"""Get specified backend. """Get specified backend.
@ -154,17 +158,40 @@ class ModelManager:
if backend_type is None: if backend_type is None:
backend_type = self._determine_backend(abs_path) backend_type = self._determine_backend(abs_path)
# Get backend lock
lock = self._backend_locks[backend_type]
async with lock:
backend = self.get_backend(backend_type) backend = self.get_backend(backend_type)
# For ONNX backends, use session pool
if backend_type.startswith('onnx'):
pool = self._session_pools[backend_type]
backend._session = await pool.get_session(abs_path)
self._loaded_models[backend_type] = abs_path
logger.info(f"Fetched model instance from {backend_type} pool")
# For PyTorch backends, load normally
else:
# Check if model is already loaded
if (backend_type in self._loaded_models and
self._loaded_models[backend_type] == abs_path and
backend.is_loaded):
logger.info(f"Fetching existing model instance from {backend_type}")
return
# Load model # Load model
await backend.load_model(abs_path) await backend.load_model(abs_path)
logger.info(f"Loaded model on {backend_type} backend") self._loaded_models[backend_type] = abs_path
logger.info(f"Initialized new model instance on {backend_type}")
# Run warmup if voice provided # Run warmup if voice provided
if warmup_voice is not None: if warmup_voice is not None:
await self._warmup_inference(backend, warmup_voice) await self._warmup_inference(backend, warmup_voice)
except Exception as e: except Exception as e:
# Clear cached path on failure
self._loaded_models.pop(backend_type, None)
raise RuntimeError(f"Failed to load model: {e}") raise RuntimeError(f"Failed to load model: {e}")
async def _warmup_inference(self, backend: BaseModelBackend, voice: torch.Tensor) -> None: async def _warmup_inference(self, backend: BaseModelBackend, voice: torch.Tensor) -> None:
@ -188,7 +215,7 @@ class ModelManager:
# Run inference # Run inference
backend.generate(tokens, voice, speed=1.0) backend.generate(tokens, voice, speed=1.0)
logger.info("Completed warmup inference") logger.debug("Completed warmup inference")
except Exception as e: except Exception as e:
logger.warning(f"Warmup inference failed: {e}") logger.warning(f"Warmup inference failed: {e}")
@ -221,16 +248,23 @@ class ModelManager:
try: try:
# Generate audio using provided voice tensor # Generate audio using provided voice tensor
# No lock needed here since inference is thread-safe
return backend.generate(tokens, voice, speed) return backend.generate(tokens, voice, speed)
except Exception as e: except Exception as e:
raise RuntimeError(f"Generation failed: {e}") raise RuntimeError(f"Generation failed: {e}")
def unload_all(self) -> None: def unload_all(self) -> None:
"""Unload models from all backends.""" """Unload models from all backends and clear cache."""
# Clean up session pools
for pool in self._session_pools.values():
pool.cleanup()
# Unload PyTorch backends
for backend in self._backends.values(): for backend in self._backends.values():
backend.unload() backend.unload()
logger.info("Unloaded all models")
self._loaded_models.clear()
logger.info("Unloaded all models and cleared cache")
@property @property
def available_backends(self) -> list[str]: def available_backends(self) -> list[str]:
@ -251,12 +285,8 @@ class ModelManager:
return self._current_backend return self._current_backend
# Module-level instance async def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
_manager: Optional[ModelManager] = None """Get global model manager instance.
def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
"""Get or create global model manager instance.
Args: Args:
config: Optional model configuration config: Optional model configuration
@ -264,7 +294,10 @@ def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
Returns: Returns:
ModelManager instance ModelManager instance
""" """
global _manager global _manager_instance
if _manager is None:
_manager = ModelManager(config) async with _manager_lock:
return _manager if _manager_instance is None:
_manager_instance = ModelManager(config)
await _manager_instance.initialize()
return _manager_instance
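get_manager is now an async singleton guarded by a module-level asyncio.Lock, so concurrent callers share one fully initialized instance. A standalone sketch of that pattern (class and function names here are illustrative, not the module's; the module-level Lock assumes Python 3.10+, where locks no longer bind an event loop at construction):

import asyncio
from typing import Optional

class Manager:
    async def initialize(self) -> None:
        # Stand-in for backend/model setup
        await asyncio.sleep(0)

_instance: Optional[Manager] = None
_lock = asyncio.Lock()

async def get_manager() -> Manager:
    # One shared, initialized instance; the lock prevents two concurrent
    # callers from both constructing and initializing it.
    global _instance
    async with _lock:
        if _instance is None:
            _instance = Manager()
            await _instance.initialize()
        return _instance

async def main() -> None:
    a, b = await asyncio.gather(get_manager(), get_manager())
    assert a is b  # both callers get the same instance

asyncio.run(main())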

View file

@ -1,20 +1,16 @@
"""CPU-based ONNX inference backend.""" """CPU-based ONNX inference backend."""
from typing import Dict, Optional from typing import Optional
import numpy as np import numpy as np
import torch import torch
from loguru import logger from loguru import logger
from onnxruntime import ( from onnxruntime import InferenceSession
ExecutionMode,
GraphOptimizationLevel,
InferenceSession,
SessionOptions
)
from ..core import paths from ..core import paths
from ..core.model_config import model_config from ..core.model_config import model_config
from .base import BaseModelBackend from .base import BaseModelBackend
from .session_pool import create_session_options, create_provider_options
class ONNXCPUBackend(BaseModelBackend): class ONNXCPUBackend(BaseModelBackend):
@ -47,8 +43,8 @@ class ONNXCPUBackend(BaseModelBackend):
logger.info(f"Loading ONNX model: {model_path}") logger.info(f"Loading ONNX model: {model_path}")
# Configure session # Configure session
options = self._create_session_options() options = create_session_options(is_gpu=False)
provider_options = self._create_provider_options() provider_options = create_provider_options(is_gpu=False)
# Create session # Create session
self._session = InferenceSession( self._session = InferenceSession(
@ -84,9 +80,9 @@ class ONNXCPUBackend(BaseModelBackend):
raise RuntimeError("Model not loaded") raise RuntimeError("Model not loaded")
try: try:
# Prepare inputs # Prepare inputs with start/end tokens
tokens_input = np.array([tokens], dtype=np.int64) tokens_input = np.array([[0, *tokens, 0]], dtype=np.int64) # Add start/end tokens
style_input = voice[len(tokens)].numpy() style_input = voice[len(tokens) + 2].numpy() # Adjust index for start/end tokens
speed_input = np.full(1, speed, dtype=np.float32) speed_input = np.full(1, speed, dtype=np.float32)
# Run inference # Run inference
@ -104,52 +100,6 @@ class ONNXCPUBackend(BaseModelBackend):
except Exception as e: except Exception as e:
raise RuntimeError(f"Generation failed: {e}") raise RuntimeError(f"Generation failed: {e}")
def _create_session_options(self) -> SessionOptions:
"""Create ONNX session options.
Returns:
Configured session options
"""
options = SessionOptions()
config = model_config.onnx_cpu
# Set optimization level
if config.optimization_level == "all":
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
elif config.optimization_level == "basic":
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
else:
options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
# Configure threading
options.intra_op_num_threads = config.num_threads
options.inter_op_num_threads = config.inter_op_threads
# Set execution mode
options.execution_mode = (
ExecutionMode.ORT_PARALLEL
if config.execution_mode == "parallel"
else ExecutionMode.ORT_SEQUENTIAL
)
# Configure memory optimization
options.enable_mem_pattern = config.memory_pattern
return options
def _create_provider_options(self) -> Dict:
"""Create CPU provider options.
Returns:
Provider configuration
"""
return {
"CPUExecutionProvider": {
"arena_extend_strategy": model_config.onnx_cpu.arena_extend_strategy,
"cpu_memory_arena_cfg": "cpu:0"
}
}
def unload(self) -> None: def unload(self) -> None:
"""Unload model and free resources.""" """Unload model and free resources."""
if self._session is not None: if self._session is not None:
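The generation change pads the token sequence with start/end tokens and shifts the style index to match the padded length. A hedged sketch of that input preparation; the ONNX input names and the helper function are assumptions, only the padding and indexing mirror the diff:

import numpy as np
import torch

def prepare_inputs(tokens: list, voice: torch.Tensor, speed: float) -> dict:
    tokens_input = np.array([[0, *tokens, 0]], dtype=np.int64)  # add start/end tokens
    style_input = voice[len(tokens) + 2].numpy()                # index matches padded length
    speed_input = np.full(1, speed, dtype=np.float32)
    return {"tokens": tokens_input, "style": style_input, "speed": speed_input}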

View file

@ -1,20 +1,16 @@
"""GPU-based ONNX inference backend.""" """GPU-based ONNX inference backend."""
from typing import Dict, Optional from typing import Optional
import numpy as np import numpy as np
import torch import torch
from loguru import logger from loguru import logger
from onnxruntime import ( from onnxruntime import InferenceSession
ExecutionMode,
GraphOptimizationLevel,
InferenceSession,
SessionOptions
)
from ..core import paths from ..core import paths
from ..core.model_config import model_config from ..core.model_config import model_config
from .base import BaseModelBackend from .base import BaseModelBackend
from .session_pool import create_session_options, create_provider_options
class ONNXGPUBackend(BaseModelBackend): class ONNXGPUBackend(BaseModelBackend):
@ -28,6 +24,9 @@ class ONNXGPUBackend(BaseModelBackend):
self._device = "cuda" self._device = "cuda"
self._session: Optional[InferenceSession] = None self._session: Optional[InferenceSession] = None
# Configure GPU
torch.cuda.set_device(model_config.onnx_gpu.device_id)
@property @property
def is_loaded(self) -> bool: def is_loaded(self) -> bool:
"""Check if model is loaded.""" """Check if model is loaded."""
@ -49,8 +48,8 @@ class ONNXGPUBackend(BaseModelBackend):
logger.info(f"Loading ONNX model on GPU: {model_path}") logger.info(f"Loading ONNX model on GPU: {model_path}")
# Configure session # Configure session
options = self._create_session_options() options = create_session_options(is_gpu=True)
provider_options = self._create_provider_options() provider_options = create_provider_options(is_gpu=True)
# Create session with CUDA provider # Create session with CUDA provider
self._session = InferenceSession( self._session = InferenceSession(
@ -87,8 +86,8 @@ class ONNXGPUBackend(BaseModelBackend):
try: try:
# Prepare inputs # Prepare inputs
tokens_input = np.array([tokens], dtype=np.int64) tokens_input = np.array([[0, *tokens, 0]], dtype=np.int64) # Add start/end tokens
style_input = voice[len(tokens)].cpu().numpy() # Move to CPU for ONNX style_input = voice[len(tokens) + 2].cpu().numpy() # Move to CPU for ONNX
speed_input = np.full(1, speed, dtype=np.float32) speed_input = np.full(1, speed, dtype=np.float32)
# Run inference # Run inference
@ -104,62 +103,15 @@ class ONNXGPUBackend(BaseModelBackend):
return result[0] return result[0]
except Exception as e: except Exception as e:
if "out of memory" in str(e).lower():
# Clear CUDA cache and retry
torch.cuda.empty_cache()
return self.generate(tokens, voice, speed)
raise RuntimeError(f"Generation failed: {e}") raise RuntimeError(f"Generation failed: {e}")
def _create_session_options(self) -> SessionOptions:
"""Create ONNX session options.
Returns:
Configured session options
"""
options = SessionOptions()
config = model_config.onnx_gpu
# Set optimization level
if config.optimization_level == "all":
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
elif config.optimization_level == "basic":
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
else:
options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
# Configure threading
options.intra_op_num_threads = config.num_threads
options.inter_op_num_threads = config.inter_op_threads
# Set execution mode
options.execution_mode = (
ExecutionMode.ORT_PARALLEL
if config.execution_mode == "parallel"
else ExecutionMode.ORT_SEQUENTIAL
)
# Configure memory optimization
options.enable_mem_pattern = config.memory_pattern
return options
def _create_provider_options(self) -> Dict:
"""Create CUDA provider options.
Returns:
Provider configuration
"""
config = model_config.onnx_gpu
return {
"CUDAExecutionProvider": {
"device_id": config.device_id,
"arena_extend_strategy": config.arena_extend_strategy,
"gpu_mem_limit": int(config.gpu_mem_limit * torch.cuda.get_device_properties(0).total_memory),
"cudnn_conv_algo_search": config.cudnn_conv_algo_search,
"do_copy_in_default_stream": config.do_copy_in_default_stream
}
}
def unload(self) -> None: def unload(self) -> None:
"""Unload model and free resources.""" """Unload model and free resources."""
if self._session is not None: if self._session is not None:
del self._session del self._session
self._session = None self._session = None
if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()

View file

@ -13,8 +13,37 @@ from ..core.model_config import model_config
from .base import BaseModelBackend from .base import BaseModelBackend
class CUDAStreamManager:
"""CUDA stream manager."""
def __init__(self, num_streams: int):
"""Initialize stream manager.
Args:
num_streams: Number of CUDA streams
"""
self.streams = [torch.cuda.Stream() for _ in range(num_streams)]
self._current = 0
def get_next_stream(self) -> torch.cuda.Stream:
"""Get next available stream.
Returns:
CUDA stream
"""
stream = self.streams[self._current]
self._current = (self._current + 1) % len(self.streams)
return stream
@torch.no_grad() @torch.no_grad()
def forward(model: torch.nn.Module, tokens: list[int], ref_s: torch.Tensor, speed: float) -> np.ndarray: def forward(
model: torch.nn.Module,
tokens: list[int],
ref_s: torch.Tensor,
speed: float,
stream: Optional[torch.cuda.Stream] = None
) -> np.ndarray:
"""Forward pass through model. """Forward pass through model.
Args: Args:
@ -22,12 +51,15 @@ def forward(model: torch.nn.Module, tokens: list[int], ref_s: torch.Tensor, spee
tokens: Input tokens tokens: Input tokens
ref_s: Reference signal (shape: [1, n_features]) ref_s: Reference signal (shape: [1, n_features])
speed: Speed multiplier speed: Speed multiplier
stream: Optional CUDA stream
Returns: Returns:
Generated audio Generated audio
""" """
device = ref_s.device device = ref_s.device
# Use provided stream or default
with torch.cuda.stream(stream) if stream else torch.cuda.default_stream():
# Initial tensor setup # Initial tensor setup
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device) tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device) input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
@ -74,6 +106,11 @@ def forward(model: torch.nn.Module, tokens: list[int], ref_s: torch.Tensor, spee
# Generate output # Generate output
output = model.decoder(asr, F0_pred, N_pred, s_ref) output = model.decoder(asr, F0_pred, N_pred, s_ref)
# Ensure operation completion if using custom stream
if stream:
stream.synchronize()
return output.squeeze().cpu().numpy() return output.squeeze().cpu().numpy()
@ -92,9 +129,10 @@ class PyTorchGPUBackend(BaseModelBackend):
def __init__(self): def __init__(self):
"""Initialize GPU backend.""" """Initialize GPU backend."""
super().__init__() super().__init__()
if not torch.cuda.is_available(): from ..core.config import settings
raise RuntimeError("CUDA not available") if not (settings.use_gpu and torch.cuda.is_available()):
self._device = "cuda" raise RuntimeError("GPU backend requires GPU support and use_gpu=True")
self._device = "cuda" # Device is enforced by backend selection in model_manager
self._model: Optional[torch.nn.Module] = None self._model: Optional[torch.nn.Module] = None
# Configure GPU settings # Configure GPU settings
@ -103,6 +141,9 @@ class PyTorchGPUBackend(BaseModelBackend):
torch.cuda.synchronize() torch.cuda.synchronize()
torch.cuda.set_device(config.device_id) torch.cuda.set_device(config.device_id)
# Initialize stream manager
self._stream_manager = CUDAStreamManager(config.cuda_streams)
async def load_model(self, path: str) -> None: async def load_model(self, path: str) -> None:
"""Load PyTorch model. """Load PyTorch model.
@ -154,8 +195,11 @@ class PyTorchGPUBackend(BaseModelBackend):
if ref_s.dim() == 1: if ref_s.dim() == 1:
ref_s = ref_s.unsqueeze(0) # Add batch dimension if needed ref_s = ref_s.unsqueeze(0) # Add batch dimension if needed
# Generate audio # Get next available stream
return forward(self._model, tokens, ref_s, speed) stream = self._stream_manager.get_next_stream()
# Generate audio using stream
return forward(self._model, tokens, ref_s, speed, stream)
except Exception as e: except Exception as e:
logger.error(f"Generation failed: {e}") logger.error(f"Generation failed: {e}")

View file

@ -0,0 +1,271 @@
"""Session pooling for model inference."""
import asyncio
import time
from dataclasses import dataclass
from typing import Dict, Optional, Set
import torch
from loguru import logger
from onnxruntime import (
ExecutionMode,
GraphOptimizationLevel,
InferenceSession,
SessionOptions
)
from ..core import paths
from ..core.model_config import model_config
@dataclass
class SessionInfo:
"""Session information."""
session: InferenceSession
last_used: float
stream_id: Optional[int] = None
def create_session_options(is_gpu: bool = False) -> SessionOptions:
"""Create ONNX session options.
Args:
is_gpu: Whether to use GPU configuration
Returns:
Configured session options
"""
options = SessionOptions()
config = model_config.onnx_gpu if is_gpu else model_config.onnx_cpu
# Set optimization level
if config.optimization_level == "all":
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
elif config.optimization_level == "basic":
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
else:
options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
# Configure threading
options.intra_op_num_threads = config.num_threads
options.inter_op_num_threads = config.inter_op_threads
# Set execution mode
options.execution_mode = (
ExecutionMode.ORT_PARALLEL
if config.execution_mode == "parallel"
else ExecutionMode.ORT_SEQUENTIAL
)
# Configure memory optimization
options.enable_mem_pattern = config.memory_pattern
return options
def create_provider_options(is_gpu: bool = False) -> Dict:
"""Create provider options.
Args:
is_gpu: Whether to use GPU configuration
Returns:
Provider configuration
"""
if is_gpu:
config = model_config.onnx_gpu
return {
"CUDAExecutionProvider": {
"device_id": config.device_id,
"arena_extend_strategy": config.arena_extend_strategy,
"gpu_mem_limit": int(config.gpu_mem_limit * torch.cuda.get_device_properties(0).total_memory),
"cudnn_conv_algo_search": config.cudnn_conv_algo_search,
"do_copy_in_default_stream": config.do_copy_in_default_stream
}
}
else:
return {
"CPUExecutionProvider": {
"arena_extend_strategy": model_config.onnx_cpu.arena_extend_strategy,
"cpu_memory_arena_cfg": "cpu:0"
}
}
class BaseSessionPool:
"""Base session pool implementation."""
def __init__(self, max_size: int, timeout: int):
"""Initialize session pool.
Args:
max_size: Maximum number of concurrent sessions
timeout: Session timeout in seconds
"""
self._max_size = max_size
self._timeout = timeout
self._sessions: Dict[str, SessionInfo] = {}
self._lock = asyncio.Lock()
async def get_session(self, model_path: str) -> InferenceSession:
"""Get session from pool.
Args:
model_path: Path to model file
Returns:
ONNX inference session
Raises:
RuntimeError: If no sessions available
"""
async with self._lock:
# Clean expired sessions
self._cleanup_expired()
# Check if session exists and is valid
if model_path in self._sessions:
session_info = self._sessions[model_path]
session_info.last_used = time.time()
return session_info.session
# Check if we can create new session
if len(self._sessions) >= self._max_size:
raise RuntimeError(
f"Maximum number of sessions reached ({self._max_size})"
)
# Create new session
session = await self._create_session(model_path)
self._sessions[model_path] = SessionInfo(
session=session,
last_used=time.time()
)
return session
def _cleanup_expired(self) -> None:
"""Remove expired sessions."""
current_time = time.time()
expired = [
path for path, info in self._sessions.items()
if current_time - info.last_used > self._timeout
]
for path in expired:
logger.info(f"Removing expired session: {path}")
del self._sessions[path]
async def _create_session(self, model_path: str) -> InferenceSession:
"""Create new session.
Args:
model_path: Path to model file
Returns:
ONNX inference session
"""
raise NotImplementedError
def cleanup(self) -> None:
"""Clean up all sessions."""
self._sessions.clear()
class StreamingSessionPool(BaseSessionPool):
"""GPU session pool with CUDA streams."""
def __init__(self):
"""Initialize GPU session pool."""
config = model_config.onnx_gpu
super().__init__(config.cuda_streams, config.stream_timeout)
self._available_streams: Set[int] = set(range(config.cuda_streams))
async def get_session(self, model_path: str) -> InferenceSession:
"""Get session with CUDA stream.
Args:
model_path: Path to model file
Returns:
ONNX inference session
Raises:
RuntimeError: If no streams available
"""
async with self._lock:
# Clean expired sessions
self._cleanup_expired()
# Try to find existing session
if model_path in self._sessions:
session_info = self._sessions[model_path]
session_info.last_used = time.time()
return session_info.session
# Get available stream
if not self._available_streams:
raise RuntimeError("No CUDA streams available")
stream_id = self._available_streams.pop()
try:
# Create new session
session = await self._create_session(model_path)
self._sessions[model_path] = SessionInfo(
session=session,
last_used=time.time(),
stream_id=stream_id
)
return session
except Exception:
# Return stream to pool on failure
self._available_streams.add(stream_id)
raise
def _cleanup_expired(self) -> None:
"""Remove expired sessions and return streams."""
current_time = time.time()
expired = [
path for path, info in self._sessions.items()
if current_time - info.last_used > self._timeout
]
for path in expired:
info = self._sessions[path]
if info.stream_id is not None:
self._available_streams.add(info.stream_id)
logger.info(f"Removing expired session: {path}")
del self._sessions[path]
async def _create_session(self, model_path: str) -> InferenceSession:
"""Create new session with CUDA provider."""
abs_path = await paths.get_model_path(model_path)
options = create_session_options(is_gpu=True)
provider_options = create_provider_options(is_gpu=True)
return InferenceSession(
abs_path,
sess_options=options,
providers=["CUDAExecutionProvider"],
provider_options=[provider_options]
)
class CPUSessionPool(BaseSessionPool):
"""CPU session pool."""
def __init__(self):
"""Initialize CPU session pool."""
config = model_config.onnx_cpu
super().__init__(config.max_instances, config.instance_timeout)
async def _create_session(self, model_path: str) -> InferenceSession:
"""Create new session with CPU provider."""
abs_path = await paths.get_model_path(model_path)
options = create_session_options(is_gpu=False)
provider_options = create_provider_options(is_gpu=False)
return InferenceSession(
abs_path,
sess_options=options,
providers=["CPUExecutionProvider"],
provider_options=[provider_options]
)
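A usage sketch for the new pool, assuming the module is importable as api.src.inference.session_pool (as the relative imports in the diff suggest) and that an ONNX model exists at the illustrative path:

import asyncio

from api.src.inference.session_pool import CPUSessionPool  # assumed import path

async def main() -> None:
    pool = CPUSessionPool()
    # Repeated requests for the same model path reuse one InferenceSession;
    # idle sessions are evicted after onnx_cpu.instance_timeout seconds.
    first = await pool.get_session("kokoro-v0_19.onnx")   # illustrative filename
    second = await pool.get_session("kokoro-v0_19.onnx")
    assert first is second
    pool.cleanup()

asyncio.run(main())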

View file

@ -1,11 +1,10 @@
"""Voice pack management and caching.""" """Voice pack management and caching."""
import os import os
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional
import torch import torch
from loguru import logger from loguru import logger
from pydantic import BaseModel
from ..core import paths from ..core import paths
from ..core.config import settings from ..core.config import settings
@ -13,7 +12,7 @@ from ..structures.model_schemas import VoiceConfig
class VoiceManager: class VoiceManager:
"""Manages voice loading, caching, and operations.""" """Manages voice loading and operations."""
def __init__(self, config: Optional[VoiceConfig] = None): def __init__(self, config: Optional[VoiceConfig] = None):
"""Initialize voice manager. """Initialize voice manager.
@ -33,15 +32,8 @@ class VoiceManager:
Returns: Returns:
Path to voice file if exists, None otherwise Path to voice file if exists, None otherwise
""" """
# Get api directory path (two levels up from inference)
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
# Construct voice path relative to api directory
voice_path = os.path.join(api_dir, settings.voices_dir, f"{voice_name}.pt") voice_path = os.path.join(api_dir, settings.voices_dir, f"{voice_name}.pt")
# Ensure voices directory exists
os.makedirs(os.path.dirname(voice_path), exist_ok=True)
return voice_path if os.path.exists(voice_path) else None return voice_path if os.path.exists(voice_path) else None
async def load_voice(self, voice_name: str, device: str = "cpu") -> torch.Tensor: async def load_voice(self, voice_name: str, device: str = "cpu") -> torch.Tensor:
@ -66,9 +58,11 @@ class VoiceManager:
if self._config.use_cache and cache_key in self._voice_cache: if self._config.use_cache and cache_key in self._voice_cache:
return self._voice_cache[cache_key] return self._voice_cache[cache_key]
try:
# Load voice tensor # Load voice tensor
try:
voice = await paths.load_voice_tensor(voice_path, device=device) voice = await paths.load_voice_tensor(voice_path, device=device)
except Exception as e:
raise RuntimeError(f"Failed to load voice {voice_name}: {e}")
# Cache if enabled # Cache if enabled
if self._config.use_cache: if self._config.use_cache:
@ -78,9 +72,6 @@ class VoiceManager:
return voice return voice
except Exception as e:
raise RuntimeError(f"Failed to load voice {voice_name}: {e}")
def _manage_cache(self) -> None: def _manage_cache(self) -> None:
"""Manage voice cache size.""" """Manage voice cache size."""
if len(self._voice_cache) >= self._config.cache_size: if len(self._voice_cache) >= self._config.cache_size:
@ -123,14 +114,14 @@ class VoiceManager:
# Get api directory path # Get api directory path
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
voices_dir = os.path.join(api_dir, settings.voices_dir) voices_dir = os.path.join(api_dir, settings.voices_dir)
# Ensure voices directory exists
os.makedirs(voices_dir, exist_ok=True) os.makedirs(voices_dir, exist_ok=True)
# Save combined voice # Save combined voice
combined_path = os.path.join(voices_dir, f"{combined_name}.pt") combined_path = os.path.join(voices_dir, f"{combined_name}.pt")
try: try:
torch.save(combined_tensor, combined_path) torch.save(combined_tensor, combined_path)
# Cache the new combined voice
self._voice_cache[f"{combined_path}_{device}"] = combined_tensor
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to save combined voice: {e}") raise RuntimeError(f"Failed to save combined voice: {e}")
@ -147,17 +138,13 @@ class VoiceManager:
""" """
voices = [] voices = []
try: try:
# Get api directory path
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
voices_dir = os.path.join(api_dir, settings.voices_dir) voices_dir = os.path.join(api_dir, settings.voices_dir)
# Ensure voices directory exists
os.makedirs(voices_dir, exist_ok=True) os.makedirs(voices_dir, exist_ok=True)
# List voice files
for entry in os.listdir(voices_dir): for entry in os.listdir(voices_dir):
if entry.endswith(".pt"): if entry.endswith(".pt"):
voices.append(entry[:-3]) # Remove .pt extension voices.append(entry[:-3])
except Exception as e: except Exception as e:
logger.error(f"Error listing voices: {e}") logger.error(f"Error listing voices: {e}")
return sorted(voices) return sorted(voices)
@ -174,11 +161,8 @@ class VoiceManager:
try: try:
if not os.path.exists(voice_path): if not os.path.exists(voice_path):
return False return False
# Try loading voice
voice = torch.load(voice_path, map_location="cpu") voice = torch.load(voice_path, map_location="cpu")
return isinstance(voice, torch.Tensor) return isinstance(voice, torch.Tensor)
except Exception: except Exception:
return False return False
@ -195,12 +179,12 @@ class VoiceManager:
} }
# Module-level instance # Global singleton instance
_manager: Optional[VoiceManager] = None _manager_instance = None
def get_manager(config: Optional[VoiceConfig] = None) -> VoiceManager: async def get_manager(config: Optional[VoiceConfig] = None) -> VoiceManager:
"""Get or create global voice manager instance. """Get global voice manager instance.
Args: Args:
config: Optional voice configuration config: Optional voice configuration
@ -208,7 +192,7 @@ def get_manager(config: Optional[VoiceConfig] = None) -> VoiceManager:
Returns: Returns:
VoiceManager instance VoiceManager instance
""" """
global _manager global _manager_instance
if _manager is None: if _manager_instance is None:
_manager = VoiceManager(config) _manager_instance = VoiceManager(config)
return _manager return _manager_instance
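With voice_cache_size now defaulting to 2, the eviction in _manage_cache matters more. A minimal illustrative sketch of a size-bounded cache of that shape (not the manager's actual implementation):

from typing import Dict, Optional
import torch

class BoundedVoiceCache:
    def __init__(self, max_size: int = 2):  # matches the new voice_cache_size default
        self.max_size = max_size
        self._cache: Dict[str, torch.Tensor] = {}

    def put(self, key: str, tensor: torch.Tensor) -> None:
        if len(self._cache) >= self.max_size:
            oldest = next(iter(self._cache))  # dicts preserve insertion order
            del self._cache[oldest]
        self._cache[key] = tensor

    def get(self, key: str) -> Optional[torch.Tensor]:
        return self._cache.get(key)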

View file

@ -1,10 +1,13 @@
""" """
FastAPI OpenAI Compatible API FastAPI OpenAI Compatible API
""" """
import os
import sys import sys
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import torch
import uvicorn import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
@ -41,19 +44,59 @@ setup_logger()
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
"""Lifespan context manager for model initialization""" """Lifespan context manager for model initialization"""
from .inference.model_manager import get_manager
from .inference.voice_manager import get_manager as get_voice_manager
logger.info("Loading TTS model and voice packs...") logger.info("Loading TTS model and voice packs...")
# Initialize service try:
service = TTSService() # Initialize managers globally
await service.ensure_initialized() model_manager = await get_manager()
voice_manager = await get_voice_manager()
# Get available voices # Determine backend type based on settings
voices = await service.list_voices() if settings.use_gpu and torch.cuda.is_available():
backend_type = 'pytorch_gpu' if not settings.use_onnx else 'onnx_gpu'
else:
backend_type = 'pytorch_cpu' if not settings.use_onnx else 'onnx_cpu'
# Get backend and initialize model
backend = model_manager.get_backend(backend_type)
# Use model path directly from settings
model_file = settings.pytorch_model_file if not settings.use_onnx else settings.onnx_model_file
model_path = os.path.join(settings.model_dir, model_file)
if not os.path.exists(model_path):
raise RuntimeError(f"Model file not found: {model_path}")
# Pre-cache default voice and use for warmup
warmup_voice = await voice_manager.load_voice(settings.default_voice, device=backend.device)
logger.info(f"Pre-cached voice {settings.default_voice} for warmup")
# Initialize model with warmup voice
await model_manager.load_model(model_path, warmup_voice, backend_type)
# Pre-cache common voices in background
common_voices = ['af', 'af_bella', 'af_sarah', 'af_nicole']
for voice_name in common_voices:
try:
await voice_manager.load_voice(voice_name, device=backend.device)
logger.debug(f"Pre-cached voice {voice_name}")
except Exception as e:
logger.warning(f"Failed to pre-cache voice {voice_name}: {e}")
# Get available voices for startup message
voices = await voice_manager.list_voices()
voicepack_count = len(voices) voicepack_count = len(voices)
# Get device info from model manager # Get device info for startup message
device = "GPU" if settings.use_gpu else "CPU" device = "GPU" if settings.use_gpu else "CPU"
model = "ONNX" if settings.use_onnx else "PyTorch" model = "ONNX" if settings.use_onnx else "PyTorch"
except Exception as e:
logger.error(f"Failed to initialize model: {e}")
raise
boundary = "" * 2*12 boundary = "" * 2*12
startup_msg = f""" startup_msg = f"""
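The startup logic now lives in FastAPI's lifespan hook so the managers are built once before any request is served. A reduced sketch of that pattern with a stand-in manager (the real code above wires in get_manager, get_voice_manager and model warmup):

from contextlib import asynccontextmanager
from fastapi import FastAPI

class DummyManager:                      # stand-in for the model manager
    def unload_all(self) -> None:
        print("released models")

async def get_manager() -> DummyManager:
    return DummyManager()

@asynccontextmanager
async def lifespan(app: FastAPI):
    app.state.model_manager = await get_manager()  # initialize once at startup
    yield                                          # application handles requests here
    app.state.model_manager.unload_all()           # clean up on shutdown

app = FastAPI(lifespan=lifespan)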

View file

@@ -1,7 +1,7 @@
from typing import List
import numpy as np
-from fastapi import APIRouter, Depends, HTTPException, Response
+from fastapi import APIRouter, Depends, HTTPException, Request, Response
from loguru import logger
from ..services.audio import AudioService
@@ -16,9 +16,9 @@ from ..structures.text_schemas import (
router = APIRouter(tags=["text processing"])
-def get_tts_service() -> TTSService:
+async def get_tts_service() -> TTSService:
    """Dependency to get TTSService instance"""
-   return TTSService()
+   return await TTSService.create()  # Create service with properly initialized managers
@router.post("/text/phonemize", response_model=PhonemeResponse, tags=["deprecated"])
@@ -82,9 +82,6 @@ async def generate_from_phonemes(
    )
    try:
-       # Ensure service is initialized
-       await tts_service.ensure_initialized()
        # Validate voice exists
        available_voices = await tts_service.list_voices()
        if request.voice not in available_voices:

View file

@ -1,6 +1,6 @@
from typing import AsyncGenerator, List, Union from typing import AsyncGenerator, List, Union
from fastapi import APIRouter, Depends, Header, HTTPException, Response, Request from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from loguru import logger from loguru import logger
@ -13,10 +13,28 @@ router = APIRouter(
responses={404: {"description": "Not found"}}, responses={404: {"description": "Not found"}},
) )
# Global TTSService instance with lock
_tts_service = None
_init_lock = None
def get_tts_service() -> TTSService: async def get_tts_service() -> TTSService:
"""Dependency to get TTSService instance with database session""" """Get global TTSService instance"""
return TTSService() # Initialize TTSService with default settings global _tts_service, _init_lock
# Create lock if needed
if _init_lock is None:
import asyncio
_init_lock = asyncio.Lock()
# Initialize service if needed
if _tts_service is None:
async with _init_lock:
# Double check pattern
if _tts_service is None:
_tts_service = await TTSService.create()
logger.info("Created global TTSService instance")
return _tts_service
async def process_voices( async def process_voices(
@ -78,11 +96,13 @@ async def stream_audio_chunks(
async def create_speech( async def create_speech(
request: OpenAISpeechRequest, request: OpenAISpeechRequest,
client_request: Request, client_request: Request,
tts_service: TTSService = Depends(get_tts_service),
x_raw_response: str = Header(None, alias="x-raw-response"), x_raw_response: str = Header(None, alias="x-raw-response"),
): ):
"""OpenAI-compatible endpoint for text-to-speech""" """OpenAI-compatible endpoint for text-to-speech"""
try: try:
# Get global service instance
tts_service = await get_tts_service()
# Process voice combination and validate # Process voice combination and validate
voice_to_use = await process_voices(request.voice, tts_service) voice_to_use = await process_voices(request.voice, tts_service)
@ -145,9 +165,10 @@ async def create_speech(
@router.get("/audio/voices") @router.get("/audio/voices")
async def list_voices(tts_service: TTSService = Depends(get_tts_service)): async def list_voices():
"""List all available voices for text-to-speech""" """List all available voices for text-to-speech"""
try: try:
tts_service = await get_tts_service()
voices = await tts_service.list_voices() voices = await tts_service.list_voices()
return {"voices": voices} return {"voices": voices}
except Exception as e: except Exception as e:
@ -156,9 +177,7 @@ async def list_voices(tts_service: TTSService = Depends(get_tts_service)):
@router.post("/audio/voices/combine") @router.post("/audio/voices/combine")
async def combine_voices( async def combine_voices(request: Union[str, List[str]]):
request: Union[str, List[str]], tts_service: TTSService = Depends(get_tts_service)
):
"""Combine multiple voices into a new voice. """Combine multiple voices into a new voice.
Args: Args:
@ -174,6 +193,7 @@ async def combine_voices(
- 500: Server error (file system issues, combination failed) - 500: Server error (file system issues, combination failed)
""" """
try: try:
tts_service = await get_tts_service()
combined_voice = await process_voices(request, tts_service) combined_voice = await process_voices(request, tts_service)
voices = await tts_service.list_voices() voices = await tts_service.list_voices()
return {"voices": voices, "voice": combined_voice} return {"voices": voices, "voice": combined_voice}

View file

@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
import phonemizer
from .normalizer import normalize_text
+phonemizers = {}
class PhonemizerBackend(ABC):
    """Abstract base class for phonemization backends"""
@@ -91,8 +91,9 @@ def phonemize(text: str, language: str = "a", normalize: bool = True) -> str:
    Returns:
        Phonemized text
    """
+   global phonemizers
    if normalize:
        text = normalize_text(text)
-   phonemizer = create_phonemizer(language)
-   return phonemizer.phonemize(text)
+   if language not in phonemizers:
+       phonemizers[language]=create_phonemizer(language)
+   return phonemizers[language].phonemize(text)

View file

@ -1,9 +1,8 @@
"""TTS service using model and voice managers.""" """TTS service using model and voice managers."""
import io import io
import os
import time import time
from typing import List, Tuple from typing import List, Tuple, Optional
import numpy as np import numpy as np
import scipy.io.wavfile as wavfile import scipy.io.wavfile as wavfile
@ -17,9 +16,14 @@ from .audio import AudioNormalizer, AudioService
from .text_processing import chunker, normalize_text, process_text from .text_processing import chunker, normalize_text, process_text
import asyncio
class TTSService: class TTSService:
"""Text-to-speech service.""" """Text-to-speech service."""
# Limit concurrent chunk processing
_chunk_semaphore = asyncio.Semaphore(4)
def __init__(self, output_dir: str = None): def __init__(self, output_dir: str = None):
"""Initialize service. """Initialize service.
@ -27,53 +31,24 @@ class TTSService:
output_dir: Optional output directory for saving audio output_dir: Optional output directory for saving audio
""" """
self.output_dir = output_dir self.output_dir = output_dir
self.model_manager = get_model_manager() self.model_manager = None
self.voice_manager = get_voice_manager() self._voice_manager = None
self._initialized = False
self._initialization_error = None
async def ensure_initialized(self): @classmethod
"""Ensure model is initialized.""" async def create(cls, output_dir: str = None) -> 'TTSService':
if self._initialized: """Create and initialize TTSService instance.
return
if self._initialization_error:
raise self._initialization_error
try: Args:
# Get api directory path (one level up from src) output_dir: Optional output directory for saving audio
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
# Determine model file and backend based on hardware Returns:
if settings.use_gpu and torch.cuda.is_available(): Initialized TTSService instance
model_file = settings.pytorch_model_file """
backend_type = 'pytorch_gpu' service = cls(output_dir)
else: # Initialize managers
model_file = settings.onnx_model_file service.model_manager = await get_model_manager()
backend_type = 'onnx_cpu' service._voice_manager = await get_voice_manager()
return service
# Construct model path relative to api directory
model_path = os.path.join(api_dir, settings.model_dir, model_file)
# Ensure model directory exists
os.makedirs(os.path.dirname(model_path), exist_ok=True)
if not os.path.exists(model_path):
raise RuntimeError(f"Model file not found: {model_path}")
# Load default voice for warmup
backend = self.model_manager.get_backend(backend_type)
warmup_voice = await self.voice_manager.load_voice(settings.default_voice, device=backend.device)
logger.info(f"Loaded voice {settings.default_voice} for warmup")
# Initialize model with warmup voice
await self.model_manager.load_model(model_path, warmup_voice, backend_type)
logger.info(f"Initialized model on {backend_type} backend")
self._initialized = True
except Exception as e:
logger.error(f"Failed to initialize model: {e}")
self._initialization_error = RuntimeError(f"Model initialization failed: {e}")
raise self._initialization_error
async def generate_audio( async def generate_audio(
self, text: str, voice: str, speed: float = 1.0 self, text: str, voice: str, speed: float = 1.0
@ -88,8 +63,8 @@ class TTSService:
Returns: Returns:
Audio samples and processing time Audio samples and processing time
""" """
await self.ensure_initialized()
start_time = time.time() start_time = time.time()
voice_tensor = None
try: try:
# Normalize text # Normalize text
@ -98,31 +73,40 @@ class TTSService:
raise ValueError("Text is empty after preprocessing") raise ValueError("Text is empty after preprocessing")
text = str(normalized) text = str(normalized)
# Process text into chunks
audio_chunks = []
for chunk in chunker.split_text(text):
try:
# Convert chunk to token IDs
tokens = process_text(chunk)
if not tokens:
continue
# Get backend and load voice # Get backend and load voice
backend = self.model_manager.get_backend() backend = self.model_manager.get_backend()
voice_tensor = await self.voice_manager.load_voice(voice, device=backend.device) voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device)
# Get all chunks upfront
chunks = list(chunker.split_text(text))
if not chunks:
raise ValueError("No text chunks to process")
# Process chunk with concurrency control
async def process_chunk(chunk: str) -> Optional[np.ndarray]:
async with self._chunk_semaphore:
try:
tokens = process_text(chunk)
if not tokens:
return None
# Generate audio # Generate audio
chunk_audio = await self.model_manager.generate( return await self.model_manager.generate(
tokens, tokens,
voice_tensor, voice_tensor,
speed=speed speed=speed
) )
if chunk_audio is not None:
audio_chunks.append(chunk_audio)
except Exception as e: except Exception as e:
logger.error(f"Failed to process chunk: '{chunk}'. Error: {str(e)}") logger.error(f"Failed to process chunk: '{chunk}'. Error: {str(e)}")
continue return None
# Process all chunks concurrently
chunk_results = await asyncio.gather(*[
process_chunk(chunk) for chunk in chunks
])
+ # Filter out None results and combine
+ audio_chunks = [chunk for chunk in chunk_results if chunk is not None]
if not audio_chunks:
raise ValueError("No audio chunks were generated successfully")

@@ -134,6 +118,11 @@ class TTSService:
except Exception as e:
logger.error(f"Error in audio generation: {str(e)}")
raise
+ finally:
+ # Always clean up voice tensor
+ if voice_tensor is not None:
+ del voice_tensor
+ torch.cuda.empty_cache()
async def generate_audio_stream(
self,

@@ -153,32 +142,33 @@ class TTSService:
Yields:
Audio chunks as bytes
"""
- await self.ensure_initialized()
- try:
# Setup audio processing
stream_normalizer = AudioNormalizer()
+ voice_tensor = None
+ try:
# Normalize text
normalized = normalize_text(text)
if not normalized:
raise ValueError("Text is empty after preprocessing")
text = str(normalized)
- # Process chunks
- is_first = True
- chunk_gen = chunker.split_text(text)
- current_chunk = next(chunk_gen, None)
- while current_chunk is not None:
- next_chunk = next(chunk_gen, None)
- try:
- # Convert chunk to token IDs
- tokens = process_text(current_chunk)
- if tokens:
# Get backend and load voice
backend = self.model_manager.get_backend()
- voice_tensor = await self.voice_manager.load_voice(voice, device=backend.device)
+ voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device)
+ # Get all chunks upfront
+ chunks = list(chunker.split_text(text))
+ if not chunks:
+ raise ValueError("No text chunks to process")
+ # Process chunk with concurrency control
+ async def process_chunk(chunk: str, is_first: bool, is_last: bool) -> Optional[bytes]:
+ async with self._chunk_semaphore:
+ try:
+ tokens = process_text(chunk)
+ if not tokens:
+ return None
# Generate audio
chunk_audio = await self.model_manager.generate(

@@ -189,26 +179,38 @@ class TTSService:
if chunk_audio is not None:
# Convert to bytes
- chunk_bytes = AudioService.convert_audio(
+ return AudioService.convert_audio(
chunk_audio,
24000,
output_format,
is_first_chunk=is_first,
normalizer=stream_normalizer,
- is_last_chunk=(next_chunk is None),
+ is_last_chunk=is_last,
stream=True
)
- yield chunk_bytes
- is_first = False
except Exception as e:
- logger.error(f"Failed to generate audio for chunk: '{current_chunk}'. Error: {str(e)}")
+ logger.error(f"Failed to generate audio for chunk: '{chunk}'. Error: {str(e)}")
+ return None
- current_chunk = next_chunk
+ # Create tasks for all chunks
+ tasks = [
+ process_chunk(chunk, i==0, i==len(chunks)-1)
+ for i, chunk in enumerate(chunks)
+ ]
+ # Process chunks concurrently and yield results in order
+ for chunk_bytes in await asyncio.gather(*tasks):
+ if chunk_bytes is not None:
+ yield chunk_bytes
except Exception as e:
logger.error(f"Error in audio generation stream: {str(e)}")
raise
+ finally:
+ # Always clean up voice tensor
+ if voice_tensor is not None:
+ del voice_tensor
+ torch.cuda.empty_cache()
async def combine_voices(self, voices: List[str]) -> str:
"""Combine multiple voices.

@@ -219,8 +221,7 @@ class TTSService:
Returns:
Name of combined voice
"""
- await self.ensure_initialized()
- return await self.voice_manager.combine_voices(voices)
+ return await self._voice_manager.combine_voices(voices)
async def list_voices(self) -> List[str]:
"""List available voices.

@@ -228,7 +229,7 @@ class TTSService:
Returns:
List of voice names
"""
- return await self.voice_manager.list_voices()
+ return await self._voice_manager.list_voices()
def _audio_to_bytes(self, audio: np.ndarray) -> bytes:
"""Convert audio to WAV bytes.

View file

@@ -1,60 +0,0 @@
import os
from typing import List, Tuple
import torch
from loguru import logger
from ..core.config import settings
from .tts_model import TTSModel
from .tts_service import TTSService
class WarmupService:
"""Service for warming up TTS models and voice caches"""
def __init__(self):
"""Initialize warmup service and ensure model is ready"""
# Initialize model if not already initialized
if TTSModel._instance is None:
TTSModel.initialize(settings.model_dir)
self.tts_service = TTSService()
def load_voices(self) -> List[Tuple[str, torch.Tensor]]:
"""Load and cache voices up to LRU limit"""
# Get all voices sorted by filename length (shorter names first, usually base voices)
voice_files = sorted(
[f for f in os.listdir(TTSModel.VOICES_DIR) if f.endswith(".pt")], key=len
)
n_voices_cache = 1
loaded_voices = []
for voice_file in voice_files[:n_voices_cache]:
try:
voice_path = os.path.join(TTSModel.VOICES_DIR, voice_file)
# load using service, lru cache
voicepack = self.tts_service._load_voice(voice_path)
loaded_voices.append(
(voice_file[:-3], voicepack)
) # Store name and tensor
# voicepack = torch.load(voice_path, map_location=TTSModel.get_device(), weights_only=True)
# logger.info(f"Loaded voice {voice_file[:-3]} into cache")
except Exception as e:
logger.error(f"Failed to load voice {voice_file}: {e}")
logger.info(f"Pre-loaded {len(loaded_voices)} voices into cache")
return loaded_voices
async def warmup_voices(
self, warmup_text: str, loaded_voices: List[Tuple[str, torch.Tensor]]
):
"""Warm up voice inference and streaming"""
n_warmups = 1
for voice_name, _ in loaded_voices[:n_warmups]:
try:
logger.info(f"Running warmup inference on voice {voice_name}")
async for _ in self.tts_service.generate_audio_stream(
warmup_text, voice_name, 1.0, "pcm"
):
pass # Process all chunks to properly warm up
logger.info(f"Completed warmup for voice {voice_name}")
except Exception as e:
logger.warning(f"Warmup failed for voice {voice_name}: {e}")

View file

@@ -10,7 +10,7 @@ services:
ports:
- "8880:8880"
environment:
- - PYTHONPATH=/app
+ - PYTHONPATH=/app:/app/api
- USE_GPU=true
- USE_ONNX=false
- PYTHONUNBUFFERED=1
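Since the settings now use absolute container paths and the compose file extends PYTHONPATH, a quick sanity check inside the container can confirm both roots are visible. This is a hypothetical helper, not part of the repo; the module name src.core.config is assumed from the api layout.

import importlib.util
import os
import sys

def check_container_paths() -> None:
    # Expect "/app:/app/api" after the compose change
    print("PYTHONPATH:", os.environ.get("PYTHONPATH"))
    print("sys.path entries under /app:", [p for p in sys.path if p.startswith("/app")])
    # Should resolve once /app/api is on the path (module name is an assumption)
    print("src.core.config importable:", importlib.util.find_spec("src.core.config") is not None)

if __name__ == "__main__":
    check_container_paths()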

View file

@@ -25,9 +25,7 @@ def main() -> None:
def stream_to_speakers() -> None:
import pyaudio
- player_stream = pyaudio.PyAudio().open(
- format=pyaudio.paInt16, channels=1, rate=24000, output=True
- )
+ player_stream = pyaudio.PyAudio().open(format=pyaudio.paInt16, channels=1, rate=24000, output=True)
start_time = time.time()

View file

@@ -0,0 +1,53 @@
#!/usr/bin/env rye run python
import asyncio
import time
from pathlib import Path
from openai import AsyncOpenAI
# Initialize async client
openai = AsyncOpenAI(base_url="http://localhost:8880/v1", api_key="not-needed-for-local")
async def save_to_file(text: str, file_id: int) -> None:
"""Save TTS output to file asynchronously"""
speech_file_path = Path(__file__).parent / f"speech_{file_id}.mp3"
start_time = time.time()
print(f"Starting file {file_id}")
try:
# Use streaming endpoint with mp3 format
async with openai.audio.speech.with_streaming_response.create(
model="kokoro",
voice="af_bella",
input=text,
response_format="mp3"
) as response:
print(f"File {file_id} - Time to first byte: {int((time.time() - start_time) * 1000)}ms")
# Open file in binary write mode
with open(speech_file_path, 'wb') as f:
async for chunk in response.iter_bytes():
f.write(chunk)
print(f"File {file_id} completed in {int((time.time() - start_time) * 1000)}ms")
except Exception as e:
print(f"Error processing file {file_id}: {e}")
async def main() -> None:
# Different text samples for variety
texts = [
"The quick brown fox jumped over the lazy dogs. I see skies of blue and clouds of white",
"I see skies of blue and clouds of white. I see skies of blue and clouds of white",
]
# Create tasks for saving to files
file_tasks = [
save_to_file(text, i)
for i, text in enumerate(texts)
]
# Run file tasks concurrently
await asyncio.gather(*file_tasks)
if __name__ == "__main__":
asyncio.run(main())
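As a rough way to see what the concurrency buys, the sketch below times a sequential run against an asyncio.gather run of the same worker coroutine. It assumes it is pasted into the script above so save_to_file is in scope; nothing here is part of the repo.

import asyncio
import time
from typing import Awaitable, Callable, List

async def time_runs(worker: Callable[[str, int], Awaitable[None]], texts: List[str]) -> None:
    # Sequential baseline
    start = time.time()
    for i, text in enumerate(texts):
        await worker(text, i)
    sequential = time.time() - start

    # Concurrent run of the same coroutine
    start = time.time()
    await asyncio.gather(*(worker(text, i) for i, text in enumerate(texts)))
    concurrent = time.time() - start

    print(f"sequential: {sequential:.2f}s  concurrent: {concurrent:.2f}s")

# Example, inside the script above: asyncio.run(time_runs(save_to_file, texts))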

View file

@@ -0,0 +1,91 @@
#!/usr/bin/env rye run python
import asyncio
import time
from pathlib import Path
import pyaudio
from openai import AsyncOpenAI
# Initialize async client
openai = AsyncOpenAI(base_url="http://localhost:8880/v1", api_key="not-needed-for-local")
# Create a shared PyAudio instance
p = pyaudio.PyAudio()
async def stream_to_speakers(text: str, stream_id: int) -> None:
"""Stream TTS audio to speakers asynchronously"""
player_stream = p.open(
format=pyaudio.paInt16,
channels=1,
rate=24000,
output=True
)
start_time = time.time()
print(f"Starting stream {stream_id}")
try:
async with openai.audio.speech.with_streaming_response.create(
model="kokoro",
voice="af_bella",
response_format="pcm",
input=text
) as response:
print(f"Stream {stream_id} - Time to first byte: {int((time.time() - start_time) * 1000)}ms")
async for chunk in response.iter_bytes(chunk_size=1024):
player_stream.write(chunk)
# Small sleep to allow other coroutines to run
await asyncio.sleep(0.001)
print(f"Stream {stream_id} completed in {int((time.time() - start_time) * 1000)}ms")
finally:
player_stream.stop_stream()
player_stream.close()
async def save_to_file(text: str, file_id: int) -> None:
"""Save TTS output to file asynchronously"""
speech_file_path = Path(__file__).parent / f"speech_{file_id}.mp3"
async with openai.audio.speech.with_streaming_response.create(
model="kokoro",
voice="af_bella",
input=text
) as response:
# Open file in binary write mode
with open(speech_file_path, 'wb') as f:
async for chunk in response.iter_bytes():
f.write(chunk)
print(f"File {file_id} saved to {speech_file_path}")
async def main() -> None:
# Different text samples for variety
texts = [
"The quick brown fox jumped over the lazy dogs. I see skies of blue and clouds of white",
"I see skies of blue and clouds of white. I see skies of blue and clouds of white",
]
# Create tasks for streaming to speakers
speaker_tasks = [
stream_to_speakers(text, i)
for i, text in enumerate(texts)
]
# Create tasks for saving to files
file_tasks = [
save_to_file(text, i)
for i, text in enumerate(texts)
]
# Combine all tasks
all_tasks = speaker_tasks + file_tasks
# Run all tasks concurrently
try:
await asyncio.gather(*all_tasks)
finally:
# Clean up PyAudio
p.terminate()
if __name__ == "__main__":
asyncio.run(main())
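player_stream.write() is a blocking call, which is why these streaming examples add asyncio.sleep(0.001) so other coroutines get a turn. An alternative sketch, not what the scripts above do, is to hand the blocking write to a worker thread with asyncio.to_thread (Python 3.9+):

import asyncio
from typing import AsyncIterator

import pyaudio

async def play_pcm(chunks: AsyncIterator[bytes], rate: int = 24000) -> None:
    """Play an async iterator of 16-bit mono PCM chunks without blocking the event loop."""
    p = pyaudio.PyAudio()
    stream = p.open(format=pyaudio.paInt16, channels=1, rate=rate, output=True)
    try:
        async for chunk in chunks:
            # The blocking write runs in a thread, so other streams keep making progress
            await asyncio.to_thread(stream.write, chunk)
    finally:
        stream.stop_stream()
        stream.close()
        p.terminate()

Inside stream_to_speakers this would replace the write loop, e.g. await play_pcm(response.iter_bytes(chunk_size=1024)).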

View file

@@ -0,0 +1,66 @@
#!/usr/bin/env rye run python
import asyncio
import time
import pyaudio
from openai import AsyncOpenAI
# Initialize async client
openai = AsyncOpenAI(base_url="http://localhost:8880/v1", api_key="not-needed-for-local")
# Create a shared PyAudio instance
p = pyaudio.PyAudio()
async def stream_to_speakers(text: str, stream_id: int) -> None:
"""Stream TTS audio to speakers asynchronously"""
player_stream = p.open(
format=pyaudio.paInt16,
channels=1,
rate=24000,
output=True
)
start_time = time.time()
print(f"Starting stream {stream_id}")
try:
async with openai.audio.speech.with_streaming_response.create(
model="kokoro",
voice="af_bella",
response_format="pcm",
input=text
) as response:
print(f"Stream {stream_id} - Time to first byte: {int((time.time() - start_time) * 1000)}ms")
async for chunk in response.iter_bytes(chunk_size=1024):
player_stream.write(chunk)
# Small sleep to allow other coroutines to run
await asyncio.sleep(0.001)
print(f"Stream {stream_id} completed in {int((time.time() - start_time) * 1000)}ms")
finally:
player_stream.stop_stream()
player_stream.close()
async def main() -> None:
# Different text samples for variety
texts = [
"The quick brown fox jumped over the lazy dogs. I see skies of blue and clouds of white",
"I see skies of blue and clouds of white. I see skies of blue and clouds of white",
]
# Create tasks for streaming to speakers
speaker_tasks = [
stream_to_speakers(text, i)
for i, text in enumerate(texts)
]
# Run speaker tasks concurrently
try:
await asyncio.gather(*speaker_tasks)
finally:
# Clean up PyAudio
p.terminate()
if __name__ == "__main__":
asyncio.run(main())

Binary file not shown.