Refactor model loading and configuration: adjust model loading device, add async streaming examples, and remove unused warmup service.

remsky 2025-01-22 02:33:29 -07:00
parent 21bf810f97
commit 4a24be1605
21 changed files with 929 additions and 484 deletions

View file

@ -367,7 +367,7 @@ async def build_model(path, device):
decoder=decoder.to(device).eval(),
text_encoder=text_encoder.to(device).eval(),
)
weights = await load_model_weights(path, device='cpu')
weights = await load_model_weights(path, device=device)
for key, state_dict in weights['net'].items():
assert key in model, key
try:
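The change above loads checkpoint weights straight onto the selected device instead of always staging them on CPU. A minimal sketch of what a `load_model_weights(path, device=...)` helper could look like (the real implementation is not shown in this diff, so this is an assumption):

```python
import asyncio
import torch

async def load_model_weights(path: str, device: str = "cpu") -> dict:
    """Sketch only: load a checkpoint directly onto the target device."""
    # torch.load is blocking, so run it in a worker thread;
    # map_location places tensors on `device` without a CPU round-trip.
    return await asyncio.to_thread(torch.load, path, map_location=device)
```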

View file

@ -15,9 +15,9 @@ class Settings(BaseSettings):
default_voice: str = "af"
use_gpu: bool = False # Whether to use GPU acceleration if available
use_onnx: bool = True # Whether to use ONNX runtime
# Paths relative to api directory
model_dir: str = "src/models" # Model directory relative to api/
voices_dir: str = "src/voices" # Voices directory relative to api/
# Container absolute paths
model_dir: str = "/app/api/src/models" # Absolute path in container
voices_dir: str = "/app/api/src/voices" # Absolute path in container
# Model filenames
pytorch_model_file: str = "kokoro-v0_19.pth"
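With container-absolute paths, callers can join the model directory and filename directly instead of first computing the `api/` directory; a small sketch using the values shown above:

```python
import os

# Values mirror the settings above; adjust for a non-container deployment.
model_dir = "/app/api/src/models"
pytorch_model_file = "kokoro-v0_19.pth"

model_path = os.path.join(model_dir, pytorch_model_file)
if not os.path.exists(model_path):
    raise RuntimeError(f"Model file not found: {model_path}")
```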

View file

@ -6,6 +6,11 @@ from pydantic import BaseModel, Field
class ONNXCPUConfig(BaseModel):
"""ONNX CPU runtime configuration."""
# Session pooling
max_instances: int = Field(4, description="Maximum concurrent model instances")
instance_timeout: int = Field(300, description="Session timeout in seconds")
# Runtime settings
num_threads: int = Field(8, description="Number of threads for parallel operations")
inter_op_threads: int = Field(4, description="Number of threads for operator parallelism")
execution_mode: str = Field("parallel", description="ONNX execution mode")
@ -20,9 +25,14 @@ class ONNXCPUConfig(BaseModel):
class ONNXGPUConfig(ONNXCPUConfig):
"""ONNX GPU-specific configuration."""
# CUDA settings
device_id: int = Field(0, description="CUDA device ID")
gpu_mem_limit: float = Field(0.7, description="Fraction of GPU memory to use")
cudnn_conv_algo_search: str = Field("EXHAUSTIVE", description="CuDNN convolution algorithm search")
# Stream management
cuda_streams: int = Field(2, description="Number of CUDA streams for inference")
stream_timeout: int = Field(60, description="Stream timeout in seconds")
do_copy_in_default_stream: bool = Field(True, description="Copy in default CUDA stream")
class Config:
@ -32,8 +42,6 @@ class ONNXGPUConfig(ONNXCPUConfig):
class PyTorchCPUConfig(BaseModel):
"""PyTorch CPU backend configuration."""
max_batch_size: int = Field(32, description="Maximum batch size for batched inference")
stream_buffer_size: int = Field(8, description="Size of stream buffer")
memory_threshold: float = Field(0.8, description="Memory threshold for cleanup")
retry_on_oom: bool = Field(True, description="Whether to retry on OOM errors")
num_threads: int = Field(8, description="Number of threads for parallel operations")
@ -49,18 +57,11 @@ class PyTorchGPUConfig(BaseModel):
device_id: int = Field(0, description="CUDA device ID")
use_fp16: bool = Field(True, description="Whether to use FP16 precision")
use_triton: bool = Field(True, description="Whether to use Triton for CUDA kernels")
max_batch_size: int = Field(32, description="Maximum batch size for batched inference")
stream_buffer_size: int = Field(8, description="Size of CUDA stream buffer")
memory_threshold: float = Field(0.8, description="Memory threshold for cleanup")
retry_on_oom: bool = Field(True, description="Whether to retry on OOM errors")
sync_cuda: bool = Field(True, description="Whether to synchronize CUDA operations")
class Config:
frozen = True
"""PyTorch CPU-specific configuration."""
num_threads: int = Field(8, description="Number of threads for parallel operations")
pin_memory: bool = Field(True, description="Whether to pin memory for faster CPU-GPU transfer")
cuda_streams: int = Field(2, description="Number of CUDA streams for inference")
stream_timeout: int = Field(60, description="Stream timeout in seconds")
class Config:
frozen = True
@ -74,7 +75,7 @@ class ModelConfig(BaseModel):
device_type: str = Field("auto", description="Device type ('cpu', 'gpu', or 'auto')")
cache_models: bool = Field(True, description="Whether to cache loaded models")
cache_voices: bool = Field(True, description="Whether to cache voice tensors")
voice_cache_size: int = Field(10, description="Maximum number of cached voices")
voice_cache_size: int = Field(2, description="Maximum number of cached voices")
# Backend-specific configs
onnx_cpu: ONNXCPUConfig = Field(default_factory=ONNXCPUConfig)
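The new `max_instances`/`instance_timeout` and `cuda_streams`/`stream_timeout` fields size the session pools introduced in this commit, while the default voice cache shrinks from 10 to 2. A hedged sketch of reading these knobs (the import path is an assumption):

```python
# Sketch only: the exact module path may differ in the repository.
from api.src.core.model_config import model_config

pool_size = model_config.onnx_cpu.max_instances    # CPU session pool size (default 4)
stream_count = model_config.onnx_gpu.cuda_streams  # GPU streaming pool size (default 2)
voice_cache = model_config.voice_cache_size        # now defaults to 2 cached voices
```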

View file

@ -1,5 +1,6 @@
"""Model management and caching."""
import asyncio
from typing import Dict, Optional
import torch
@ -13,11 +14,19 @@ from .onnx_cpu import ONNXCPUBackend
from .onnx_gpu import ONNXGPUBackend
from .pytorch_cpu import PyTorchCPUBackend
from .pytorch_gpu import PyTorchGPUBackend
from .session_pool import CPUSessionPool, StreamingSessionPool
# Global singleton instance and state
_manager_instance = None
_manager_lock = asyncio.Lock()
_loaded_models = {}
_backends = {}
class ModelManager:
"""Manages model loading and inference across backends."""
def __init__(self, config: Optional[ModelConfig] = None):
"""Initialize model manager.
@ -25,65 +34,60 @@ class ModelManager:
config: Optional configuration
"""
self._config = config or model_config
self._backends: Dict[str, BaseModelBackend] = {}
self._current_backend: Optional[str] = None
self._initialize_backends()
global _loaded_models, _backends
self._loaded_models = _loaded_models
self._backends = _backends
# Initialize session pools
self._session_pools = {
'onnx_cpu': CPUSessionPool(),
'onnx_gpu': StreamingSessionPool()
}
# Initialize locks
self._backend_locks: Dict[str, asyncio.Lock] = {}
def _initialize_backends(self) -> None:
"""Initialize available backends based on settings."""
has_gpu = settings.use_gpu and torch.cuda.is_available()
def _determine_device(self) -> str:
"""Determine device based on settings."""
if settings.use_gpu and torch.cuda.is_available():
return "cuda"
return "cpu"
async def initialize(self) -> None:
"""Initialize backends."""
if self._backends:
logger.debug("Using existing backend instances")
return
device = self._determine_device()
try:
if has_gpu:
if device == "cuda":
if settings.use_onnx:
# ONNX GPU primary
self._backends['onnx_gpu'] = ONNXGPUBackend()
self._current_backend = 'onnx_gpu'
logger.info("Initialized ONNX GPU backend")
# PyTorch GPU fallback
self._backends['pytorch_gpu'] = PyTorchGPUBackend()
logger.info("Initialized PyTorch GPU backend")
logger.info("Initialized new ONNX GPU backend")
else:
# PyTorch GPU primary
self._backends['pytorch_gpu'] = PyTorchGPUBackend()
self._current_backend = 'pytorch_gpu'
logger.info("Initialized PyTorch GPU backend")
logger.info("Initialized new PyTorch GPU backend")
else:
if settings.use_onnx:
self._backends['onnx_cpu'] = ONNXCPUBackend()
self._current_backend = 'onnx_cpu'
logger.info("Initialized new ONNX CPU backend")
else:
self._backends['pytorch_cpu'] = PyTorchCPUBackend()
self._current_backend = 'pytorch_cpu'
logger.info("Initialized new PyTorch CPU backend")
# ONNX GPU fallback
self._backends['onnx_gpu'] = ONNXGPUBackend()
logger.info("Initialized ONNX GPU backend")
else:
self._initialize_cpu_backends()
except Exception as e:
logger.error(f"Failed to initialize GPU backends: {e}")
# Fallback to CPU if GPU fails
self._initialize_cpu_backends()
def _initialize_cpu_backends(self) -> None:
"""Initialize CPU backends based on settings."""
try:
if settings.use_onnx:
# ONNX CPU primary
self._backends['onnx_cpu'] = ONNXCPUBackend()
self._current_backend = 'onnx_cpu'
logger.info("Initialized ONNX CPU backend")
# Initialize locks for each backend
for backend in self._backends:
self._backend_locks[backend] = asyncio.Lock()
# PyTorch CPU fallback
self._backends['pytorch_cpu'] = PyTorchCPUBackend()
logger.info("Initialized PyTorch CPU backend")
else:
# PyTorch CPU primary
self._backends['pytorch_cpu'] = PyTorchCPUBackend()
self._current_backend = 'pytorch_cpu'
logger.info("Initialized PyTorch CPU backend")
# ONNX CPU fallback
self._backends['onnx_cpu'] = ONNXCPUBackend()
logger.info("Initialized ONNX CPU backend")
except Exception as e:
logger.error(f"Failed to initialize CPU backends: {e}")
raise RuntimeError("No backends available")
logger.error(f"Failed to initialize backend: {e}")
raise RuntimeError("Failed to initialize backend")
def get_backend(self, backend_type: Optional[str] = None) -> BaseModelBackend:
"""Get specified backend.
@ -154,19 +158,42 @@ class ModelManager:
if backend_type is None:
backend_type = self._determine_backend(abs_path)
backend = self.get_backend(backend_type)
# Get backend lock
lock = self._backend_locks[backend_type]
# Load model
await backend.load_model(abs_path)
logger.info(f"Loaded model on {backend_type} backend")
# Run warmup if voice provided
if warmup_voice is not None:
await self._warmup_inference(backend, warmup_voice)
async with lock:
backend = self.get_backend(backend_type)
# For ONNX backends, use session pool
if backend_type.startswith('onnx'):
pool = self._session_pools[backend_type]
backend._session = await pool.get_session(abs_path)
self._loaded_models[backend_type] = abs_path
logger.info(f"Fetched model instance from {backend_type} pool")
# For PyTorch backends, load normally
else:
# Check if model is already loaded
if (backend_type in self._loaded_models and
self._loaded_models[backend_type] == abs_path and
backend.is_loaded):
logger.info(f"Fetching existing model instance from {backend_type}")
return
# Load model
await backend.load_model(abs_path)
self._loaded_models[backend_type] = abs_path
logger.info(f"Initialized new model instance on {backend_type}")
# Run warmup if voice provided
if warmup_voice is not None:
await self._warmup_inference(backend, warmup_voice)
except Exception as e:
# Clear cached path on failure
self._loaded_models.pop(backend_type, None)
raise RuntimeError(f"Failed to load model: {e}")
async def _warmup_inference(self, backend: BaseModelBackend, voice: torch.Tensor) -> None:
"""Run warmup inference to initialize model.
@ -188,7 +215,7 @@ class ModelManager:
# Run inference
backend.generate(tokens, voice, speed=1.0)
logger.info("Completed warmup inference")
logger.debug("Completed warmup inference")
except Exception as e:
logger.warning(f"Warmup inference failed: {e}")
@ -221,16 +248,23 @@ class ModelManager:
try:
# Generate audio using provided voice tensor
# No lock needed here since inference is thread-safe
return backend.generate(tokens, voice, speed)
except Exception as e:
raise RuntimeError(f"Generation failed: {e}")
def unload_all(self) -> None:
"""Unload models from all backends."""
"""Unload models from all backends and clear cache."""
# Clean up session pools
for pool in self._session_pools.values():
pool.cleanup()
# Unload PyTorch backends
for backend in self._backends.values():
backend.unload()
logger.info("Unloaded all models")
self._loaded_models.clear()
logger.info("Unloaded all models and cleared cache")
@property
def available_backends(self) -> list[str]:
@ -251,12 +285,8 @@ class ModelManager:
return self._current_backend
# Module-level instance
_manager: Optional[ModelManager] = None
def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
"""Get or create global model manager instance.
async def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
"""Get global model manager instance.
Args:
config: Optional model configuration
@ -264,7 +294,10 @@ def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
Returns:
ModelManager instance
"""
global _manager
if _manager is None:
_manager = ModelManager(config)
return _manager
global _manager_instance
async with _manager_lock:
if _manager_instance is None:
_manager_instance = ModelManager(config)
await _manager_instance.initialize()
return _manager_instance
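`get_manager()` is now a coroutine guarded by an `asyncio.Lock`, so call sites must await it. A usage sketch (the import path and model filename are assumptions, not taken from this diff):

```python
import asyncio

async def warm_start() -> None:
    # Sketch only: module path and model filename are placeholders.
    from api.src.inference.model_manager import get_manager

    manager = await get_manager()        # creates and initializes the singleton once
    backend = manager.get_backend()      # e.g. 'onnx_cpu' on a CPU-only host
    await manager.load_model("/app/api/src/models/kokoro-v0_19.onnx")
    print(manager.available_backends)

asyncio.run(warm_start())
```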

View file

@ -1,20 +1,16 @@
"""CPU-based ONNX inference backend."""
from typing import Dict, Optional
from typing import Optional
import numpy as np
import torch
from loguru import logger
from onnxruntime import (
ExecutionMode,
GraphOptimizationLevel,
InferenceSession,
SessionOptions
)
from onnxruntime import InferenceSession
from ..core import paths
from ..core.model_config import model_config
from .base import BaseModelBackend
from .session_pool import create_session_options, create_provider_options
class ONNXCPUBackend(BaseModelBackend):
@ -47,8 +43,8 @@ class ONNXCPUBackend(BaseModelBackend):
logger.info(f"Loading ONNX model: {model_path}")
# Configure session
options = self._create_session_options()
provider_options = self._create_provider_options()
options = create_session_options(is_gpu=False)
provider_options = create_provider_options(is_gpu=False)
# Create session
self._session = InferenceSession(
@ -84,9 +80,9 @@ class ONNXCPUBackend(BaseModelBackend):
raise RuntimeError("Model not loaded")
try:
# Prepare inputs
tokens_input = np.array([tokens], dtype=np.int64)
style_input = voice[len(tokens)].numpy()
# Prepare inputs with start/end tokens
tokens_input = np.array([[0, *tokens, 0]], dtype=np.int64) # Add start/end tokens
style_input = voice[len(tokens) + 2].numpy() # Adjust index for start/end tokens
speed_input = np.full(1, speed, dtype=np.float32)
# Run inference
@ -104,52 +100,6 @@ class ONNXCPUBackend(BaseModelBackend):
except Exception as e:
raise RuntimeError(f"Generation failed: {e}")
def _create_session_options(self) -> SessionOptions:
"""Create ONNX session options.
Returns:
Configured session options
"""
options = SessionOptions()
config = model_config.onnx_cpu
# Set optimization level
if config.optimization_level == "all":
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
elif config.optimization_level == "basic":
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
else:
options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
# Configure threading
options.intra_op_num_threads = config.num_threads
options.inter_op_num_threads = config.inter_op_threads
# Set execution mode
options.execution_mode = (
ExecutionMode.ORT_PARALLEL
if config.execution_mode == "parallel"
else ExecutionMode.ORT_SEQUENTIAL
)
# Configure memory optimization
options.enable_mem_pattern = config.memory_pattern
return options
def _create_provider_options(self) -> Dict:
"""Create CPU provider options.
Returns:
Provider configuration
"""
return {
"CPUExecutionProvider": {
"arena_extend_strategy": model_config.onnx_cpu.arena_extend_strategy,
"cpu_memory_arena_cfg": "cpu:0"
}
}
def unload(self) -> None:
"""Unload model and free resources."""
if self._session is not None:

View file

@ -1,20 +1,16 @@
"""GPU-based ONNX inference backend."""
from typing import Dict, Optional
from typing import Optional
import numpy as np
import torch
from loguru import logger
from onnxruntime import (
ExecutionMode,
GraphOptimizationLevel,
InferenceSession,
SessionOptions
)
from onnxruntime import InferenceSession
from ..core import paths
from ..core.model_config import model_config
from .base import BaseModelBackend
from .session_pool import create_session_options, create_provider_options
class ONNXGPUBackend(BaseModelBackend):
@ -27,6 +23,9 @@ class ONNXGPUBackend(BaseModelBackend):
raise RuntimeError("CUDA not available")
self._device = "cuda"
self._session: Optional[InferenceSession] = None
# Configure GPU
torch.cuda.set_device(model_config.onnx_gpu.device_id)
@property
def is_loaded(self) -> bool:
@ -49,8 +48,8 @@ class ONNXGPUBackend(BaseModelBackend):
logger.info(f"Loading ONNX model on GPU: {model_path}")
# Configure session
options = self._create_session_options()
provider_options = self._create_provider_options()
options = create_session_options(is_gpu=True)
provider_options = create_provider_options(is_gpu=True)
# Create session with CUDA provider
self._session = InferenceSession(
@ -87,8 +86,8 @@ class ONNXGPUBackend(BaseModelBackend):
try:
# Prepare inputs
tokens_input = np.array([tokens], dtype=np.int64)
style_input = voice[len(tokens)].cpu().numpy() # Move to CPU for ONNX
tokens_input = np.array([[0, *tokens, 0]], dtype=np.int64) # Add start/end tokens
style_input = voice[len(tokens) + 2].cpu().numpy() # Move to CPU for ONNX
speed_input = np.full(1, speed, dtype=np.float32)
# Run inference
@ -104,62 +103,15 @@ class ONNXGPUBackend(BaseModelBackend):
return result[0]
except Exception as e:
if "out of memory" in str(e).lower():
# Clear CUDA cache and retry
torch.cuda.empty_cache()
return self.generate(tokens, voice, speed)
raise RuntimeError(f"Generation failed: {e}")
def _create_session_options(self) -> SessionOptions:
"""Create ONNX session options.
Returns:
Configured session options
"""
options = SessionOptions()
config = model_config.onnx_gpu
# Set optimization level
if config.optimization_level == "all":
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
elif config.optimization_level == "basic":
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
else:
options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
# Configure threading
options.intra_op_num_threads = config.num_threads
options.inter_op_num_threads = config.inter_op_threads
# Set execution mode
options.execution_mode = (
ExecutionMode.ORT_PARALLEL
if config.execution_mode == "parallel"
else ExecutionMode.ORT_SEQUENTIAL
)
# Configure memory optimization
options.enable_mem_pattern = config.memory_pattern
return options
def _create_provider_options(self) -> Dict:
"""Create CUDA provider options.
Returns:
Provider configuration
"""
config = model_config.onnx_gpu
return {
"CUDAExecutionProvider": {
"device_id": config.device_id,
"arena_extend_strategy": config.arena_extend_strategy,
"gpu_mem_limit": int(config.gpu_mem_limit * torch.cuda.get_device_properties(0).total_memory),
"cudnn_conv_algo_search": config.cudnn_conv_algo_search,
"do_copy_in_default_stream": config.do_copy_in_default_stream
}
}
def unload(self) -> None:
"""Unload model and free resources."""
if self._session is not None:
del self._session
self._session = None
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.empty_cache()

View file

@ -13,8 +13,37 @@ from ..core.model_config import model_config
from .base import BaseModelBackend
class CUDAStreamManager:
"""CUDA stream manager."""
def __init__(self, num_streams: int):
"""Initialize stream manager.
Args:
num_streams: Number of CUDA streams
"""
self.streams = [torch.cuda.Stream() for _ in range(num_streams)]
self._current = 0
def get_next_stream(self) -> torch.cuda.Stream:
"""Get next available stream.
Returns:
CUDA stream
"""
stream = self.streams[self._current]
self._current = (self._current + 1) % len(self.streams)
return stream
@torch.no_grad()
def forward(model: torch.nn.Module, tokens: list[int], ref_s: torch.Tensor, speed: float) -> np.ndarray:
def forward(
model: torch.nn.Module,
tokens: list[int],
ref_s: torch.Tensor,
speed: float,
stream: Optional[torch.cuda.Stream] = None
) -> np.ndarray:
"""Forward pass through model.
Args:
@ -22,59 +51,67 @@ def forward(model: torch.nn.Module, tokens: list[int], ref_s: torch.Tensor, spee
tokens: Input tokens
ref_s: Reference signal (shape: [1, n_features])
speed: Speed multiplier
stream: Optional CUDA stream
Returns:
Generated audio
"""
device = ref_s.device
# Initial tensor setup
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
text_mask = length_to_mask(input_lengths).to(device)
# Use provided stream or default
with torch.cuda.stream(stream) if stream else torch.cuda.default_stream():
# Initial tensor setup
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
text_mask = length_to_mask(input_lengths).to(device)
# Split reference signals (style_dim=128 from config)
style_dim = 128
s_ref = ref_s[:, :style_dim].clone().to(device)
s_content = ref_s[:, style_dim:].clone().to(device)
# Split reference signals (style_dim=128 from config)
style_dim = 128
s_ref = ref_s[:, :style_dim].clone().to(device)
s_content = ref_s[:, style_dim:].clone().to(device)
# BERT and encoder pass
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
# BERT and encoder pass
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
# Predictor forward pass
d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
x, _ = model.predictor.lstm(d)
# Predictor forward pass
d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
x, _ = model.predictor.lstm(d)
# Duration prediction
duration = model.predictor.duration_proj(x)
duration = torch.sigmoid(duration).sum(axis=-1) / speed
pred_dur = torch.round(duration).clamp(min=1).long()
del duration, x
# Duration prediction
duration = model.predictor.duration_proj(x)
duration = torch.sigmoid(duration).sum(axis=-1) / speed
pred_dur = torch.round(duration).clamp(min=1).long()
del duration, x
# Alignment matrix construction
pred_aln_trg = torch.zeros(input_lengths.item(), pred_dur.sum().item(), device=device)
c_frame = 0
for i in range(pred_aln_trg.size(0)):
pred_aln_trg[i, c_frame:c_frame + pred_dur[0, i].item()] = 1
c_frame += pred_dur[0, i].item()
pred_aln_trg = pred_aln_trg.unsqueeze(0)
# Alignment matrix construction
pred_aln_trg = torch.zeros(input_lengths.item(), pred_dur.sum().item(), device=device)
c_frame = 0
for i in range(pred_aln_trg.size(0)):
pred_aln_trg[i, c_frame:c_frame + pred_dur[0, i].item()] = 1
c_frame += pred_dur[0, i].item()
pred_aln_trg = pred_aln_trg.unsqueeze(0)
# Matrix multiplications
en = d.transpose(-1, -2) @ pred_aln_trg
del d
F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
del en
# Matrix multiplications
en = d.transpose(-1, -2) @ pred_aln_trg
del d
F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
del en
# Final text encoding and decoding
t_en = model.text_encoder(tokens, input_lengths, text_mask)
asr = t_en @ pred_aln_trg
del t_en
# Final text encoding and decoding
t_en = model.text_encoder(tokens, input_lengths, text_mask)
asr = t_en @ pred_aln_trg
del t_en
# Generate output
output = model.decoder(asr, F0_pred, N_pred, s_ref)
return output.squeeze().cpu().numpy()
# Generate output
output = model.decoder(asr, F0_pred, N_pred, s_ref)
# Ensure operation completion if using custom stream
if stream:
stream.synchronize()
return output.squeeze().cpu().numpy()
def length_to_mask(lengths: torch.Tensor) -> torch.Tensor:
@ -92,9 +129,10 @@ class PyTorchGPUBackend(BaseModelBackend):
def __init__(self):
"""Initialize GPU backend."""
super().__init__()
if not torch.cuda.is_available():
raise RuntimeError("CUDA not available")
self._device = "cuda"
from ..core.config import settings
if not (settings.use_gpu and torch.cuda.is_available()):
raise RuntimeError("GPU backend requires GPU support and use_gpu=True")
self._device = "cuda" # Device is enforced by backend selection in model_manager
self._model: Optional[torch.nn.Module] = None
# Configure GPU settings
@ -102,6 +140,9 @@ class PyTorchGPUBackend(BaseModelBackend):
if config.sync_cuda:
torch.cuda.synchronize()
torch.cuda.set_device(config.device_id)
# Initialize stream manager
self._stream_manager = CUDAStreamManager(config.cuda_streams)
async def load_model(self, path: str) -> None:
"""Load PyTorch model.
@ -154,8 +195,11 @@ class PyTorchGPUBackend(BaseModelBackend):
if ref_s.dim() == 1:
ref_s = ref_s.unsqueeze(0) # Add batch dimension if needed
# Generate audio
return forward(self._model, tokens, ref_s, speed)
# Get next available stream
stream = self._stream_manager.get_next_stream()
# Generate audio using stream
return forward(self._model, tokens, ref_s, speed, stream)
except Exception as e:
logger.error(f"Generation failed: {e}")

View file

@ -0,0 +1,271 @@
"""Session pooling for model inference."""
import asyncio
import time
from dataclasses import dataclass
from typing import Dict, Optional, Set
import torch
from loguru import logger
from onnxruntime import (
ExecutionMode,
GraphOptimizationLevel,
InferenceSession,
SessionOptions
)
from ..core import paths
from ..core.model_config import model_config
@dataclass
class SessionInfo:
"""Session information."""
session: InferenceSession
last_used: float
stream_id: Optional[int] = None
def create_session_options(is_gpu: bool = False) -> SessionOptions:
"""Create ONNX session options.
Args:
is_gpu: Whether to use GPU configuration
Returns:
Configured session options
"""
options = SessionOptions()
config = model_config.onnx_gpu if is_gpu else model_config.onnx_cpu
# Set optimization level
if config.optimization_level == "all":
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
elif config.optimization_level == "basic":
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
else:
options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
# Configure threading
options.intra_op_num_threads = config.num_threads
options.inter_op_num_threads = config.inter_op_threads
# Set execution mode
options.execution_mode = (
ExecutionMode.ORT_PARALLEL
if config.execution_mode == "parallel"
else ExecutionMode.ORT_SEQUENTIAL
)
# Configure memory optimization
options.enable_mem_pattern = config.memory_pattern
return options
def create_provider_options(is_gpu: bool = False) -> Dict:
"""Create provider options.
Args:
is_gpu: Whether to use GPU configuration
Returns:
Provider configuration
"""
if is_gpu:
config = model_config.onnx_gpu
return {
"CUDAExecutionProvider": {
"device_id": config.device_id,
"arena_extend_strategy": config.arena_extend_strategy,
"gpu_mem_limit": int(config.gpu_mem_limit * torch.cuda.get_device_properties(0).total_memory),
"cudnn_conv_algo_search": config.cudnn_conv_algo_search,
"do_copy_in_default_stream": config.do_copy_in_default_stream
}
}
else:
return {
"CPUExecutionProvider": {
"arena_extend_strategy": model_config.onnx_cpu.arena_extend_strategy,
"cpu_memory_arena_cfg": "cpu:0"
}
}
class BaseSessionPool:
"""Base session pool implementation."""
def __init__(self, max_size: int, timeout: int):
"""Initialize session pool.
Args:
max_size: Maximum number of concurrent sessions
timeout: Session timeout in seconds
"""
self._max_size = max_size
self._timeout = timeout
self._sessions: Dict[str, SessionInfo] = {}
self._lock = asyncio.Lock()
async def get_session(self, model_path: str) -> InferenceSession:
"""Get session from pool.
Args:
model_path: Path to model file
Returns:
ONNX inference session
Raises:
RuntimeError: If no sessions available
"""
async with self._lock:
# Clean expired sessions
self._cleanup_expired()
# Check if session exists and is valid
if model_path in self._sessions:
session_info = self._sessions[model_path]
session_info.last_used = time.time()
return session_info.session
# Check if we can create new session
if len(self._sessions) >= self._max_size:
raise RuntimeError(
f"Maximum number of sessions reached ({self._max_size})"
)
# Create new session
session = await self._create_session(model_path)
self._sessions[model_path] = SessionInfo(
session=session,
last_used=time.time()
)
return session
def _cleanup_expired(self) -> None:
"""Remove expired sessions."""
current_time = time.time()
expired = [
path for path, info in self._sessions.items()
if current_time - info.last_used > self._timeout
]
for path in expired:
logger.info(f"Removing expired session: {path}")
del self._sessions[path]
async def _create_session(self, model_path: str) -> InferenceSession:
"""Create new session.
Args:
model_path: Path to model file
Returns:
ONNX inference session
"""
raise NotImplementedError
def cleanup(self) -> None:
"""Clean up all sessions."""
self._sessions.clear()
class StreamingSessionPool(BaseSessionPool):
"""GPU session pool with CUDA streams."""
def __init__(self):
"""Initialize GPU session pool."""
config = model_config.onnx_gpu
super().__init__(config.cuda_streams, config.stream_timeout)
self._available_streams: Set[int] = set(range(config.cuda_streams))
async def get_session(self, model_path: str) -> InferenceSession:
"""Get session with CUDA stream.
Args:
model_path: Path to model file
Returns:
ONNX inference session
Raises:
RuntimeError: If no streams available
"""
async with self._lock:
# Clean expired sessions
self._cleanup_expired()
# Try to find existing session
if model_path in self._sessions:
session_info = self._sessions[model_path]
session_info.last_used = time.time()
return session_info.session
# Get available stream
if not self._available_streams:
raise RuntimeError("No CUDA streams available")
stream_id = self._available_streams.pop()
try:
# Create new session
session = await self._create_session(model_path)
self._sessions[model_path] = SessionInfo(
session=session,
last_used=time.time(),
stream_id=stream_id
)
return session
except Exception:
# Return stream to pool on failure
self._available_streams.add(stream_id)
raise
def _cleanup_expired(self) -> None:
"""Remove expired sessions and return streams."""
current_time = time.time()
expired = [
path for path, info in self._sessions.items()
if current_time - info.last_used > self._timeout
]
for path in expired:
info = self._sessions[path]
if info.stream_id is not None:
self._available_streams.add(info.stream_id)
logger.info(f"Removing expired session: {path}")
del self._sessions[path]
async def _create_session(self, model_path: str) -> InferenceSession:
"""Create new session with CUDA provider."""
abs_path = await paths.get_model_path(model_path)
options = create_session_options(is_gpu=True)
provider_options = create_provider_options(is_gpu=True)
return InferenceSession(
abs_path,
sess_options=options,
providers=["CUDAExecutionProvider"],
provider_options=[provider_options]
)
class CPUSessionPool(BaseSessionPool):
"""CPU session pool."""
def __init__(self):
"""Initialize CPU session pool."""
config = model_config.onnx_cpu
super().__init__(config.max_instances, config.instance_timeout)
async def _create_session(self, model_path: str) -> InferenceSession:
"""Create new session with CPU provider."""
abs_path = await paths.get_model_path(model_path)
options = create_session_options(is_gpu=False)
provider_options = create_provider_options(is_gpu=False)
return InferenceSession(
abs_path,
sess_options=options,
providers=["CPUExecutionProvider"],
provider_options=[provider_options]
)
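A usage sketch for the new pools (import path, model filename, and ONNX input names are assumptions; the actual graph inputs are not shown in this diff):

```python
import asyncio
import numpy as np

async def run_once() -> None:
    # Sketch only: module path and filename are placeholders.
    from api.src.inference.session_pool import CPUSessionPool

    pool = CPUSessionPool()                          # sized by onnx_cpu.max_instances
    session = await pool.get_session("kokoro-v0_19.onnx")

    tokens = np.array([[0, 1, 2, 3, 0]], dtype=np.int64)
    # Input/output names depend on the exported graph, so this call is illustrative:
    # audio = session.run(None, {"tokens": tokens, "style": ..., "speed": ...})[0]

    pool.cleanup()                                   # drop cached sessions when done

asyncio.run(run_once())
```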

View file

@ -1,11 +1,10 @@
"""Voice pack management and caching."""
import os
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional
import torch
from loguru import logger
from pydantic import BaseModel
from ..core import paths
from ..core.config import settings
@ -13,7 +12,7 @@ from ..structures.model_schemas import VoiceConfig
class VoiceManager:
"""Manages voice loading, caching, and operations."""
"""Manages voice loading and operations."""
def __init__(self, config: Optional[VoiceConfig] = None):
"""Initialize voice manager.
@ -33,15 +32,8 @@ class VoiceManager:
Returns:
Path to voice file if exists, None otherwise
"""
# Get api directory path (two levels up from inference)
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
# Construct voice path relative to api directory
voice_path = os.path.join(api_dir, settings.voices_dir, f"{voice_name}.pt")
# Ensure voices directory exists
os.makedirs(os.path.dirname(voice_path), exist_ok=True)
return voice_path if os.path.exists(voice_path) else None
async def load_voice(self, voice_name: str, device: str = "cpu") -> torch.Tensor:
@ -66,21 +58,20 @@ class VoiceManager:
if self._config.use_cache and cache_key in self._voice_cache:
return self._voice_cache[cache_key]
# Load voice tensor
try:
# Load voice tensor
voice = await paths.load_voice_tensor(voice_path, device=device)
# Cache if enabled
if self._config.use_cache:
self._manage_cache()
self._voice_cache[cache_key] = voice
logger.debug(f"Cached voice: {voice_name} on {device}")
return voice
except Exception as e:
raise RuntimeError(f"Failed to load voice {voice_name}: {e}")
# Cache if enabled
if self._config.use_cache:
self._manage_cache()
self._voice_cache[cache_key] = voice
logger.debug(f"Cached voice: {voice_name} on {device}")
return voice
def _manage_cache(self) -> None:
"""Manage voice cache size."""
if len(self._voice_cache) >= self._config.cache_size:
@ -123,14 +114,14 @@ class VoiceManager:
# Get api directory path
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
voices_dir = os.path.join(api_dir, settings.voices_dir)
# Ensure voices directory exists
os.makedirs(voices_dir, exist_ok=True)
# Save combined voice
combined_path = os.path.join(voices_dir, f"{combined_name}.pt")
try:
torch.save(combined_tensor, combined_path)
# Cache the new combined voice
self._voice_cache[f"{combined_path}_{device}"] = combined_tensor
except Exception as e:
raise RuntimeError(f"Failed to save combined voice: {e}")
@ -147,17 +138,13 @@ class VoiceManager:
"""
voices = []
try:
# Get api directory path
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
voices_dir = os.path.join(api_dir, settings.voices_dir)
# Ensure voices directory exists
os.makedirs(voices_dir, exist_ok=True)
# List voice files
for entry in os.listdir(voices_dir):
if entry.endswith(".pt"):
voices.append(entry[:-3]) # Remove .pt extension
voices.append(entry[:-3])
except Exception as e:
logger.error(f"Error listing voices: {e}")
return sorted(voices)
@ -174,11 +161,8 @@ class VoiceManager:
try:
if not os.path.exists(voice_path):
return False
# Try loading voice
voice = torch.load(voice_path, map_location="cpu")
return isinstance(voice, torch.Tensor)
except Exception:
return False
@ -195,12 +179,12 @@ class VoiceManager:
}
# Module-level instance
_manager: Optional[VoiceManager] = None
# Global singleton instance
_manager_instance = None
def get_manager(config: Optional[VoiceConfig] = None) -> VoiceManager:
"""Get or create global voice manager instance.
async def get_manager(config: Optional[VoiceConfig] = None) -> VoiceManager:
"""Get global voice manager instance.
Args:
config: Optional voice configuration
@ -208,7 +192,7 @@ def get_manager(config: Optional[VoiceConfig] = None) -> VoiceManager:
Returns:
VoiceManager instance
"""
global _manager
if _manager is None:
_manager = VoiceManager(config)
return _manager
global _manager_instance
if _manager_instance is None:
_manager_instance = VoiceManager(config)
return _manager_instance

View file

@ -1,10 +1,13 @@
"""
FastAPI OpenAI Compatible API
"""
import os
import sys
from contextlib import asynccontextmanager
import torch
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
@ -41,19 +44,59 @@ setup_logger()
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Lifespan context manager for model initialization"""
from .inference.model_manager import get_manager
from .inference.voice_manager import get_manager as get_voice_manager
logger.info("Loading TTS model and voice packs...")
# Initialize service
service = TTSService()
await service.ensure_initialized()
# Get available voices
voices = await service.list_voices()
voicepack_count = len(voices)
try:
# Initialize managers globally
model_manager = await get_manager()
voice_manager = await get_voice_manager()
# Get device info from model manager
device = "GPU" if settings.use_gpu else "CPU"
model = "ONNX" if settings.use_onnx else "PyTorch"
# Determine backend type based on settings
if settings.use_gpu and torch.cuda.is_available():
backend_type = 'pytorch_gpu' if not settings.use_onnx else 'onnx_gpu'
else:
backend_type = 'pytorch_cpu' if not settings.use_onnx else 'onnx_cpu'
# Get backend and initialize model
backend = model_manager.get_backend(backend_type)
# Use model path directly from settings
model_file = settings.pytorch_model_file if not settings.use_onnx else settings.onnx_model_file
model_path = os.path.join(settings.model_dir, model_file)
if not os.path.exists(model_path):
raise RuntimeError(f"Model file not found: {model_path}")
# Pre-cache default voice and use for warmup
warmup_voice = await voice_manager.load_voice(settings.default_voice, device=backend.device)
logger.info(f"Pre-cached voice {settings.default_voice} for warmup")
# Initialize model with warmup voice
await model_manager.load_model(model_path, warmup_voice, backend_type)
# Pre-cache common voices in background
common_voices = ['af', 'af_bella', 'af_sarah', 'af_nicole']
for voice_name in common_voices:
try:
await voice_manager.load_voice(voice_name, device=backend.device)
logger.debug(f"Pre-cached voice {voice_name}")
except Exception as e:
logger.warning(f"Failed to pre-cache voice {voice_name}: {e}")
# Get available voices for startup message
voices = await voice_manager.list_voices()
voicepack_count = len(voices)
# Get device info for startup message
device = "GPU" if settings.use_gpu else "CPU"
model = "ONNX" if settings.use_onnx else "PyTorch"
except Exception as e:
logger.error(f"Failed to initialize model: {e}")
raise
boundary = "░" * 2*12

startup_msg = f"""

View file

@ -1,7 +1,7 @@
from typing import List
import numpy as np
from fastapi import APIRouter, Depends, HTTPException, Response
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from loguru import logger
from ..services.audio import AudioService
@ -16,9 +16,9 @@ from ..structures.text_schemas import (
router = APIRouter(tags=["text processing"])
def get_tts_service() -> TTSService:
async def get_tts_service() -> TTSService:
"""Dependency to get TTSService instance"""
return TTSService()
return await TTSService.create() # Create service with properly initialized managers
@router.post("/text/phonemize", response_model=PhonemeResponse, tags=["deprecated"])
@ -82,9 +82,6 @@ async def generate_from_phonemes(
)
try:
# Ensure service is initialized
await tts_service.ensure_initialized()
# Validate voice exists
available_voices = await tts_service.list_voices()
if request.voice not in available_voices:

View file

@ -1,6 +1,6 @@
from typing import AsyncGenerator, List, Union
from fastapi import APIRouter, Depends, Header, HTTPException, Response, Request
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
from loguru import logger
@ -13,10 +13,28 @@ router = APIRouter(
responses={404: {"description": "Not found"}},
)
# Global TTSService instance with lock
_tts_service = None
_init_lock = None
def get_tts_service() -> TTSService:
"""Dependency to get TTSService instance with database session"""
return TTSService() # Initialize TTSService with default settings
async def get_tts_service() -> TTSService:
"""Get global TTSService instance"""
global _tts_service, _init_lock
# Create lock if needed
if _init_lock is None:
import asyncio
_init_lock = asyncio.Lock()
# Initialize service if needed
if _tts_service is None:
async with _init_lock:
# Double check pattern
if _tts_service is None:
_tts_service = await TTSService.create()
logger.info("Created global TTSService instance")
return _tts_service
async def process_voices(
@ -78,11 +96,13 @@ async def stream_audio_chunks(
async def create_speech(
request: OpenAISpeechRequest,
client_request: Request,
tts_service: TTSService = Depends(get_tts_service),
x_raw_response: str = Header(None, alias="x-raw-response"),
):
"""OpenAI-compatible endpoint for text-to-speech"""
try:
# Get global service instance
tts_service = await get_tts_service()
# Process voice combination and validate
voice_to_use = await process_voices(request.voice, tts_service)
@ -145,9 +165,10 @@ async def create_speech(
@router.get("/audio/voices")
async def list_voices(tts_service: TTSService = Depends(get_tts_service)):
async def list_voices():
"""List all available voices for text-to-speech"""
try:
tts_service = await get_tts_service()
voices = await tts_service.list_voices()
return {"voices": voices}
except Exception as e:
@ -156,9 +177,7 @@ async def list_voices(tts_service: TTSService = Depends(get_tts_service)):
@router.post("/audio/voices/combine")
async def combine_voices(
request: Union[str, List[str]], tts_service: TTSService = Depends(get_tts_service)
):
async def combine_voices(request: Union[str, List[str]]):
"""Combine multiple voices into a new voice.
Args:
@ -174,6 +193,7 @@ async def combine_voices(
- 500: Server error (file system issues, combination failed)
"""
try:
tts_service = await get_tts_service()
combined_voice = await process_voices(request, tts_service)
voices = await tts_service.list_voices()
return {"voices": voices, "voice": combined_voice}

View file

@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
import phonemizer
from .normalizer import normalize_text
phonemizers = {}
class PhonemizerBackend(ABC):
"""Abstract base class for phonemization backends"""
@ -91,8 +91,9 @@ def phonemize(text: str, language: str = "a", normalize: bool = True) -> str:
Returns:
Phonemized text
"""
global phonemizers
if normalize:
text = normalize_text(text)
phonemizer = create_phonemizer(language)
return phonemizer.phonemize(text)
if language not in phonemizers:
phonemizers[language] = create_phonemizer(language)
return phonemizers[language].phonemize(text)
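The change memoizes one phonemizer backend per language instead of rebuilding it on every call. The same pattern in isolation, with the backend construction stubbed out (the stub is not the project's `create_phonemizer`):

```python
_backends: dict = {}

def get_cached_backend(language: str):
    """Build the backend once per language, then reuse it."""
    if language not in _backends:
        _backends[language] = object()   # stands in for create_phonemizer(language)
    return _backends[language]

assert get_cached_backend("a") is get_cached_backend("a")   # constructed once, then reused
```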

View file

@ -1,9 +1,8 @@
"""TTS service using model and voice managers."""
import io
import os
import time
from typing import List, Tuple
from typing import List, Tuple, Optional
import numpy as np
import scipy.io.wavfile as wavfile
@ -17,9 +16,14 @@ from .audio import AudioNormalizer, AudioService
from .text_processing import chunker, normalize_text, process_text
import asyncio
class TTSService:
"""Text-to-speech service."""
# Limit concurrent chunk processing
_chunk_semaphore = asyncio.Semaphore(4)
def __init__(self, output_dir: str = None):
"""Initialize service.
@ -27,53 +31,24 @@ class TTSService:
output_dir: Optional output directory for saving audio
"""
self.output_dir = output_dir
self.model_manager = get_model_manager()
self.voice_manager = get_voice_manager()
self._initialized = False
self._initialization_error = None
self.model_manager = None
self._voice_manager = None
async def ensure_initialized(self):
"""Ensure model is initialized."""
if self._initialized:
return
if self._initialization_error:
raise self._initialization_error
try:
# Get api directory path (one level up from src)
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
@classmethod
async def create(cls, output_dir: str = None) -> 'TTSService':
"""Create and initialize TTSService instance.
Args:
output_dir: Optional output directory for saving audio
# Determine model file and backend based on hardware
if settings.use_gpu and torch.cuda.is_available():
model_file = settings.pytorch_model_file
backend_type = 'pytorch_gpu'
else:
model_file = settings.onnx_model_file
backend_type = 'onnx_cpu'
# Construct model path relative to api directory
model_path = os.path.join(api_dir, settings.model_dir, model_file)
# Ensure model directory exists
os.makedirs(os.path.dirname(model_path), exist_ok=True)
if not os.path.exists(model_path):
raise RuntimeError(f"Model file not found: {model_path}")
# Load default voice for warmup
backend = self.model_manager.get_backend(backend_type)
warmup_voice = await self.voice_manager.load_voice(settings.default_voice, device=backend.device)
logger.info(f"Loaded voice {settings.default_voice} for warmup")
# Initialize model with warmup voice
await self.model_manager.load_model(model_path, warmup_voice, backend_type)
logger.info(f"Initialized model on {backend_type} backend")
self._initialized = True
except Exception as e:
logger.error(f"Failed to initialize model: {e}")
self._initialization_error = RuntimeError(f"Model initialization failed: {e}")
raise self._initialization_error
Returns:
Initialized TTSService instance
"""
service = cls(output_dir)
# Initialize managers
service.model_manager = await get_model_manager()
service._voice_manager = await get_voice_manager()
return service
async def generate_audio(
self, text: str, voice: str, speed: float = 1.0
@ -88,8 +63,8 @@ class TTSService:
Returns:
Audio samples and processing time
"""
await self.ensure_initialized()
start_time = time.time()
voice_tensor = None
try:
# Normalize text
@ -98,31 +73,40 @@ class TTSService:
raise ValueError("Text is empty after preprocessing")
text = str(normalized)
# Process text into chunks
audio_chunks = []
for chunk in chunker.split_text(text):
try:
# Convert chunk to token IDs
tokens = process_text(chunk)
if not tokens:
continue
# Get backend and load voice
backend = self.model_manager.get_backend()
voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device)
# Get backend and load voice
backend = self.model_manager.get_backend()
voice_tensor = await self.voice_manager.load_voice(voice, device=backend.device)
# Generate audio
chunk_audio = await self.model_manager.generate(
tokens,
voice_tensor,
speed=speed
)
if chunk_audio is not None:
audio_chunks.append(chunk_audio)
except Exception as e:
logger.error(f"Failed to process chunk: '{chunk}'. Error: {str(e)}")
continue
# Get all chunks upfront
chunks = list(chunker.split_text(text))
if not chunks:
raise ValueError("No text chunks to process")
# Process chunk with concurrency control
async def process_chunk(chunk: str) -> Optional[np.ndarray]:
async with self._chunk_semaphore:
try:
tokens = process_text(chunk)
if not tokens:
return None
# Generate audio
return await self.model_manager.generate(
tokens,
voice_tensor,
speed=speed
)
except Exception as e:
logger.error(f"Failed to process chunk: '{chunk}'. Error: {str(e)}")
return None
# Process all chunks concurrently
chunk_results = await asyncio.gather(*[
process_chunk(chunk) for chunk in chunks
])
# Filter out None results and combine
audio_chunks = [chunk for chunk in chunk_results if chunk is not None]
if not audio_chunks:
raise ValueError("No audio chunks were generated successfully")
@ -134,6 +118,11 @@ class TTSService:
except Exception as e:
logger.error(f"Error in audio generation: {str(e)}")
raise
finally:
# Always clean up voice tensor
if voice_tensor is not None:
del voice_tensor
torch.cuda.empty_cache()
async def generate_audio_stream(
self,
@ -153,33 +142,34 @@ class TTSService:
Yields:
Audio chunks as bytes
"""
await self.ensure_initialized()
# Setup audio processing
stream_normalizer = AudioNormalizer()
voice_tensor = None
try:
# Setup audio processing
stream_normalizer = AudioNormalizer()
# Normalize text
normalized = normalize_text(text)
if not normalized:
raise ValueError("Text is empty after preprocessing")
text = str(normalized)
# Process chunks
is_first = True
chunk_gen = chunker.split_text(text)
current_chunk = next(chunk_gen, None)
# Get backend and load voice
backend = self.model_manager.get_backend()
voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device)
# Get all chunks upfront
chunks = list(chunker.split_text(text))
if not chunks:
raise ValueError("No text chunks to process")
# Process chunk with concurrency control
async def process_chunk(chunk: str, is_first: bool, is_last: bool) -> Optional[bytes]:
async with self._chunk_semaphore:
try:
tokens = process_text(chunk)
if not tokens:
return None
while current_chunk is not None:
next_chunk = next(chunk_gen, None)
try:
# Convert chunk to token IDs
tokens = process_text(current_chunk)
if tokens:
# Get backend and load voice
backend = self.model_manager.get_backend()
voice_tensor = await self.voice_manager.load_voice(voice, device=backend.device)
# Generate audio
chunk_audio = await self.model_manager.generate(
tokens,
@ -189,26 +179,38 @@ class TTSService:
if chunk_audio is not None:
# Convert to bytes
chunk_bytes = AudioService.convert_audio(
return AudioService.convert_audio(
chunk_audio,
24000,
output_format,
is_first_chunk=is_first,
normalizer=stream_normalizer,
is_last_chunk=(next_chunk is None),
is_last_chunk=is_last,
stream=True
)
yield chunk_bytes
is_first = False
except Exception as e:
logger.error(f"Failed to generate audio for chunk: '{chunk}'. Error: {str(e)}")
return None
except Exception as e:
logger.error(f"Failed to generate audio for chunk: '{current_chunk}'. Error: {str(e)}")
# Create tasks for all chunks
tasks = [
process_chunk(chunk, i==0, i==len(chunks)-1)
for i, chunk in enumerate(chunks)
]
current_chunk = next_chunk
# Process chunks concurrently and yield results in order
for chunk_bytes in await asyncio.gather(*tasks):
if chunk_bytes is not None:
yield chunk_bytes
except Exception as e:
logger.error(f"Error in audio generation stream: {str(e)}")
raise
finally:
# Always clean up voice tensor
if voice_tensor is not None:
del voice_tensor
torch.cuda.empty_cache()
async def combine_voices(self, voices: List[str]) -> str:
"""Combine multiple voices.
@ -219,8 +221,7 @@ class TTSService:
Returns:
Name of combined voice
"""
await self.ensure_initialized()
return await self.voice_manager.combine_voices(voices)
return await self._voice_manager.combine_voices(voices)
async def list_voices(self) -> List[str]:
"""List available voices.
@ -228,7 +229,7 @@ class TTSService:
Returns:
List of voice names
"""
return await self.voice_manager.list_voices()
return await self._voice_manager.list_voices()
def _audio_to_bytes(self, audio: np.ndarray) -> bytes:
"""Convert audio to WAV bytes.

View file

@ -1,60 +0,0 @@
import os
from typing import List, Tuple
import torch
from loguru import logger
from ..core.config import settings
from .tts_model import TTSModel
from .tts_service import TTSService
class WarmupService:
"""Service for warming up TTS models and voice caches"""
def __init__(self):
"""Initialize warmup service and ensure model is ready"""
# Initialize model if not already initialized
if TTSModel._instance is None:
TTSModel.initialize(settings.model_dir)
self.tts_service = TTSService()
def load_voices(self) -> List[Tuple[str, torch.Tensor]]:
"""Load and cache voices up to LRU limit"""
# Get all voices sorted by filename length (shorter names first, usually base voices)
voice_files = sorted(
[f for f in os.listdir(TTSModel.VOICES_DIR) if f.endswith(".pt")], key=len
)
n_voices_cache = 1
loaded_voices = []
for voice_file in voice_files[:n_voices_cache]:
try:
voice_path = os.path.join(TTSModel.VOICES_DIR, voice_file)
# load using service, lru cache
voicepack = self.tts_service._load_voice(voice_path)
loaded_voices.append(
(voice_file[:-3], voicepack)
) # Store name and tensor
# voicepack = torch.load(voice_path, map_location=TTSModel.get_device(), weights_only=True)
# logger.info(f"Loaded voice {voice_file[:-3]} into cache")
except Exception as e:
logger.error(f"Failed to load voice {voice_file}: {e}")
logger.info(f"Pre-loaded {len(loaded_voices)} voices into cache")
return loaded_voices
async def warmup_voices(
self, warmup_text: str, loaded_voices: List[Tuple[str, torch.Tensor]]
):
"""Warm up voice inference and streaming"""
n_warmups = 1
for voice_name, _ in loaded_voices[:n_warmups]:
try:
logger.info(f"Running warmup inference on voice {voice_name}")
async for _ in self.tts_service.generate_audio_stream(
warmup_text, voice_name, 1.0, "pcm"
):
pass # Process all chunks to properly warm up
logger.info(f"Completed warmup for voice {voice_name}")
except Exception as e:
logger.warning(f"Warmup failed for voice {voice_name}: {e}")

View file

@ -10,7 +10,7 @@ services:
ports:
- "8880:8880"
environment:
- PYTHONPATH=/app
- PYTHONPATH=/app:/app/api
- USE_GPU=true
- USE_ONNX=false
- PYTHONUNBUFFERED=1

View file

@ -25,9 +25,7 @@ def main() -> None:
def stream_to_speakers() -> None:
import pyaudio
player_stream = pyaudio.PyAudio().open(
format=pyaudio.paInt16, channels=1, rate=24000, output=True
)
player_stream = pyaudio.PyAudio().open(format=pyaudio.paInt16, channels=1, rate=24000, output=True)
start_time = time.time()

View file

@ -0,0 +1,53 @@
#!/usr/bin/env rye run python
import asyncio
import time
from pathlib import Path
from openai import AsyncOpenAI
# Initialize async client
openai = AsyncOpenAI(base_url="http://localhost:8880/v1", api_key="not-needed-for-local")
async def save_to_file(text: str, file_id: int) -> None:
"""Save TTS output to file asynchronously"""
speech_file_path = Path(__file__).parent / f"speech_{file_id}.mp3"
start_time = time.time()
print(f"Starting file {file_id}")
try:
# Use streaming endpoint with mp3 format
async with openai.audio.speech.with_streaming_response.create(
model="kokoro",
voice="af_bella",
input=text,
response_format="mp3"
) as response:
print(f"File {file_id} - Time to first byte: {int((time.time() - start_time) * 1000)}ms")
# Open file in binary write mode
with open(speech_file_path, 'wb') as f:
async for chunk in response.iter_bytes():
f.write(chunk)
print(f"File {file_id} completed in {int((time.time() - start_time) * 1000)}ms")
except Exception as e:
print(f"Error processing file {file_id}: {e}")
async def main() -> None:
# Different text samples for variety
texts = [
"The quick brown fox jumped over the lazy dogs. I see skies of blue and clouds of white",
"I see skies of blue and clouds of white. I see skies of blue and clouds of white",
]
# Create tasks for saving to files
file_tasks = [
save_to_file(text, i)
for i, text in enumerate(texts)
]
# Run file tasks concurrently
await asyncio.gather(*file_tasks)
if __name__ == "__main__":
asyncio.run(main())

View file

@ -0,0 +1,91 @@
#!/usr/bin/env rye run python
import asyncio
import time
from pathlib import Path
import pyaudio
from openai import AsyncOpenAI
# Initialize async client
openai = AsyncOpenAI(base_url="http://localhost:8880/v1", api_key="not-needed-for-local")
# Create a shared PyAudio instance
p = pyaudio.PyAudio()
async def stream_to_speakers(text: str, stream_id: int) -> None:
"""Stream TTS audio to speakers asynchronously"""
player_stream = p.open(
format=pyaudio.paInt16,
channels=1,
rate=24000,
output=True
)
start_time = time.time()
print(f"Starting stream {stream_id}")
try:
async with openai.audio.speech.with_streaming_response.create(
model="kokoro",
voice="af_bella",
response_format="pcm",
input=text
) as response:
print(f"Stream {stream_id} - Time to first byte: {int((time.time() - start_time) * 1000)}ms")
async for chunk in response.iter_bytes(chunk_size=1024):
player_stream.write(chunk)
# Small sleep to allow other coroutines to run
await asyncio.sleep(0.001)
print(f"Stream {stream_id} completed in {int((time.time() - start_time) * 1000)}ms")
finally:
player_stream.stop_stream()
player_stream.close()
async def save_to_file(text: str, file_id: int) -> None:
"""Save TTS output to file asynchronously"""
speech_file_path = Path(__file__).parent / f"speech_{file_id}.mp3"
async with openai.audio.speech.with_streaming_response.create(
model="kokoro",
voice="af_bella",
input=text
) as response:
# Open file in binary write mode
with open(speech_file_path, 'wb') as f:
async for chunk in response.iter_bytes():
f.write(chunk)
print(f"File {file_id} saved to {speech_file_path}")
async def main() -> None:
# Different text samples for variety
texts = [
"The quick brown fox jumped over the lazy dogs. I see skies of blue and clouds of white",
"I see skies of blue and clouds of white. I see skies of blue and clouds of white",
]
# Create tasks for streaming to speakers
speaker_tasks = [
stream_to_speakers(text, i)
for i, text in enumerate(texts)
]
# Create tasks for saving to files
file_tasks = [
save_to_file(text, i)
for i, text in enumerate(texts)
]
# Combine all tasks
all_tasks = speaker_tasks + file_tasks
# Run all tasks concurrently
try:
await asyncio.gather(*all_tasks)
finally:
# Clean up PyAudio
p.terminate()
if __name__ == "__main__":
asyncio.run(main())

View file

@ -0,0 +1,66 @@
#!/usr/bin/env rye run python
import asyncio
import time
import pyaudio
from openai import AsyncOpenAI
# Initialize async client
openai = AsyncOpenAI(base_url="http://localhost:8880/v1", api_key="not-needed-for-local")
# Create a shared PyAudio instance
p = pyaudio.PyAudio()
async def stream_to_speakers(text: str, stream_id: int) -> None:
"""Stream TTS audio to speakers asynchronously"""
player_stream = p.open(
format=pyaudio.paInt16,
channels=1,
rate=24000,
output=True
)
start_time = time.time()
print(f"Starting stream {stream_id}")
try:
async with openai.audio.speech.with_streaming_response.create(
model="kokoro",
voice="af_bella",
response_format="pcm",
input=text
) as response:
print(f"Stream {stream_id} - Time to first byte: {int((time.time() - start_time) * 1000)}ms")
async for chunk in response.iter_bytes(chunk_size=1024):
player_stream.write(chunk)
# Small sleep to allow other coroutines to run
await asyncio.sleep(0.001)
print(f"Stream {stream_id} completed in {int((time.time() - start_time) * 1000)}ms")
finally:
player_stream.stop_stream()
player_stream.close()
async def main() -> None:
# Different text samples for variety
texts = [
"The quick brown fox jumped over the lazy dogs. I see skies of blue and clouds of white",
"I see skies of blue and clouds of white. I see skies of blue and clouds of white",
]
# Create tasks for streaming to speakers
speaker_tasks = [
stream_to_speakers(text, i)
for i, text in enumerate(texts)
]
# Run speaker tasks concurrently
try:
await asyncio.gather(*speaker_tasks)
finally:
# Clean up PyAudio
p.terminate()
if __name__ == "__main__":
asyncio.run(main())

Binary file not shown.