Mirror of https://github.com/remsky/Kokoro-FastAPI.git (synced 2025-08-05 16:48:53 +00:00)
Refactor model loading and configuration: adjust the model loading device, add async streaming examples, and remove the unused warmup service.
parent 21bf810f97
commit 4a24be1605
21 changed files with 929 additions and 484 deletions
|
@ -367,7 +367,7 @@ async def build_model(path, device):
|
||||||
decoder=decoder.to(device).eval(),
|
decoder=decoder.to(device).eval(),
|
||||||
text_encoder=text_encoder.to(device).eval(),
|
text_encoder=text_encoder.to(device).eval(),
|
||||||
)
|
)
|
||||||
weights = await load_model_weights(path, device='cpu')
|
weights = await load_model_weights(path, device=device)
|
||||||
for key, state_dict in weights['net'].items():
|
for key, state_dict in weights['net'].items():
|
||||||
assert key in model, key
|
assert key in model, key
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -15,9 +15,9 @@ class Settings(BaseSettings):
|
||||||
default_voice: str = "af"
|
default_voice: str = "af"
|
||||||
use_gpu: bool = False # Whether to use GPU acceleration if available
|
use_gpu: bool = False # Whether to use GPU acceleration if available
|
||||||
use_onnx: bool = True # Whether to use ONNX runtime
|
use_onnx: bool = True # Whether to use ONNX runtime
|
||||||
# Paths relative to api directory
|
# Container absolute paths
|
||||||
model_dir: str = "src/models" # Model directory relative to api/
|
model_dir: str = "/app/api/src/models" # Absolute path in container
|
||||||
voices_dir: str = "src/voices" # Voices directory relative to api/
|
voices_dir: str = "/app/api/src/voices" # Absolute path in container
|
||||||
|
|
||||||
# Model filenames
|
# Model filenames
|
||||||
pytorch_model_file: str = "kokoro-v0_19.pth"
|
pytorch_model_file: str = "kokoro-v0_19.pth"
|
||||||
|
|
|
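Since Settings is a pydantic BaseSettings model, the container-absolute defaults introduced above can still be overridden for a local (non-Docker) run. A minimal sketch, assuming the import path api.src.core.config and the field names shown in this hunk:

    # Sketch: overriding the container paths when running outside Docker.
    # The import path is an assumption; the field names come from the diff above.
    import os
    from api.src.core.config import Settings

    local_settings = Settings(
        model_dir=os.path.abspath("api/src/models"),   # instead of /app/api/src/models
        voices_dir=os.path.abspath("api/src/voices"),  # instead of /app/api/src/voices
        use_gpu=False,
        use_onnx=True,
    )
    print(local_settings.model_dir, local_settings.voices_dir)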
@ -6,6 +6,11 @@ from pydantic import BaseModel, Field
|
||||||
class ONNXCPUConfig(BaseModel):
|
class ONNXCPUConfig(BaseModel):
|
||||||
"""ONNX CPU runtime configuration."""
|
"""ONNX CPU runtime configuration."""
|
||||||
|
|
||||||
|
# Session pooling
|
||||||
|
max_instances: int = Field(4, description="Maximum concurrent model instances")
|
||||||
|
instance_timeout: int = Field(300, description="Session timeout in seconds")
|
||||||
|
|
||||||
|
# Runtime settings
|
||||||
num_threads: int = Field(8, description="Number of threads for parallel operations")
|
num_threads: int = Field(8, description="Number of threads for parallel operations")
|
||||||
inter_op_threads: int = Field(4, description="Number of threads for operator parallelism")
|
inter_op_threads: int = Field(4, description="Number of threads for operator parallelism")
|
||||||
execution_mode: str = Field("parallel", description="ONNX execution mode")
|
execution_mode: str = Field("parallel", description="ONNX execution mode")
|
||||||
|
@ -20,9 +25,14 @@ class ONNXCPUConfig(BaseModel):
|
||||||
class ONNXGPUConfig(ONNXCPUConfig):
|
class ONNXGPUConfig(ONNXCPUConfig):
|
||||||
"""ONNX GPU-specific configuration."""
|
"""ONNX GPU-specific configuration."""
|
||||||
|
|
||||||
|
# CUDA settings
|
||||||
device_id: int = Field(0, description="CUDA device ID")
|
device_id: int = Field(0, description="CUDA device ID")
|
||||||
gpu_mem_limit: float = Field(0.7, description="Fraction of GPU memory to use")
|
gpu_mem_limit: float = Field(0.7, description="Fraction of GPU memory to use")
|
||||||
cudnn_conv_algo_search: str = Field("EXHAUSTIVE", description="CuDNN convolution algorithm search")
|
cudnn_conv_algo_search: str = Field("EXHAUSTIVE", description="CuDNN convolution algorithm search")
|
||||||
|
|
||||||
|
# Stream management
|
||||||
|
cuda_streams: int = Field(2, description="Number of CUDA streams for inference")
|
||||||
|
stream_timeout: int = Field(60, description="Stream timeout in seconds")
|
||||||
do_copy_in_default_stream: bool = Field(True, description="Copy in default CUDA stream")
|
do_copy_in_default_stream: bool = Field(True, description="Copy in default CUDA stream")
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
@ -32,8 +42,6 @@ class ONNXGPUConfig(ONNXCPUConfig):
|
||||||
class PyTorchCPUConfig(BaseModel):
|
class PyTorchCPUConfig(BaseModel):
|
||||||
"""PyTorch CPU backend configuration."""
|
"""PyTorch CPU backend configuration."""
|
||||||
|
|
||||||
max_batch_size: int = Field(32, description="Maximum batch size for batched inference")
|
|
||||||
stream_buffer_size: int = Field(8, description="Size of stream buffer")
|
|
||||||
memory_threshold: float = Field(0.8, description="Memory threshold for cleanup")
|
memory_threshold: float = Field(0.8, description="Memory threshold for cleanup")
|
||||||
retry_on_oom: bool = Field(True, description="Whether to retry on OOM errors")
|
retry_on_oom: bool = Field(True, description="Whether to retry on OOM errors")
|
||||||
num_threads: int = Field(8, description="Number of threads for parallel operations")
|
num_threads: int = Field(8, description="Number of threads for parallel operations")
|
||||||
|
@ -49,18 +57,11 @@ class PyTorchGPUConfig(BaseModel):
|
||||||
device_id: int = Field(0, description="CUDA device ID")
|
device_id: int = Field(0, description="CUDA device ID")
|
||||||
use_fp16: bool = Field(True, description="Whether to use FP16 precision")
|
use_fp16: bool = Field(True, description="Whether to use FP16 precision")
|
||||||
use_triton: bool = Field(True, description="Whether to use Triton for CUDA kernels")
|
use_triton: bool = Field(True, description="Whether to use Triton for CUDA kernels")
|
||||||
max_batch_size: int = Field(32, description="Maximum batch size for batched inference")
|
|
||||||
stream_buffer_size: int = Field(8, description="Size of CUDA stream buffer")
|
|
||||||
memory_threshold: float = Field(0.8, description="Memory threshold for cleanup")
|
memory_threshold: float = Field(0.8, description="Memory threshold for cleanup")
|
||||||
retry_on_oom: bool = Field(True, description="Whether to retry on OOM errors")
|
retry_on_oom: bool = Field(True, description="Whether to retry on OOM errors")
|
||||||
sync_cuda: bool = Field(True, description="Whether to synchronize CUDA operations")
|
sync_cuda: bool = Field(True, description="Whether to synchronize CUDA operations")
|
||||||
|
cuda_streams: int = Field(2, description="Number of CUDA streams for inference")
|
||||||
class Config:
|
stream_timeout: int = Field(60, description="Stream timeout in seconds")
|
||||||
frozen = True
|
|
||||||
"""PyTorch CPU-specific configuration."""
|
|
||||||
|
|
||||||
num_threads: int = Field(8, description="Number of threads for parallel operations")
|
|
||||||
pin_memory: bool = Field(True, description="Whether to pin memory for faster CPU-GPU transfer")
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
frozen = True
|
frozen = True
|
||||||
|
@ -74,7 +75,7 @@ class ModelConfig(BaseModel):
|
||||||
device_type: str = Field("auto", description="Device type ('cpu', 'gpu', or 'auto')")
|
device_type: str = Field("auto", description="Device type ('cpu', 'gpu', or 'auto')")
|
||||||
cache_models: bool = Field(True, description="Whether to cache loaded models")
|
cache_models: bool = Field(True, description="Whether to cache loaded models")
|
||||||
cache_voices: bool = Field(True, description="Whether to cache voice tensors")
|
cache_voices: bool = Field(True, description="Whether to cache voice tensors")
|
||||||
voice_cache_size: int = Field(10, description="Maximum number of cached voices")
|
voice_cache_size: int = Field(2, description="Maximum number of cached voices")
|
||||||
|
|
||||||
# Backend-specific configs
|
# Backend-specific configs
|
||||||
onnx_cpu: ONNXCPUConfig = Field(default_factory=ONNXCPUConfig)
|
onnx_cpu: ONNXCPUConfig = Field(default_factory=ONNXCPUConfig)
|
||||||
|
|
|
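For reference, the pooling knobs added to this file can be exercised directly. The sketch below assumes the import path api.src.core.model_config and an onnx_gpu field analogous to the onnx_cpu field shown above:

    # Sketch: a ModelConfig using the session-pool fields added in this commit.
    from api.src.core.model_config import ModelConfig, ONNXCPUConfig, ONNXGPUConfig

    config = ModelConfig(
        device_type="auto",
        voice_cache_size=2,  # matches the smaller default in the diff
        onnx_cpu=ONNXCPUConfig(max_instances=2, instance_timeout=120),
        onnx_gpu=ONNXGPUConfig(cuda_streams=2, stream_timeout=60),  # field assumed by analogy
    )
    print(config.onnx_cpu.max_instances, config.onnx_gpu.cuda_streams)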
@ -1,5 +1,6 @@
|
||||||
"""Model management and caching."""
|
"""Model management and caching."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
@ -13,6 +14,14 @@ from .onnx_cpu import ONNXCPUBackend
|
||||||
from .onnx_gpu import ONNXGPUBackend
|
from .onnx_gpu import ONNXGPUBackend
|
||||||
from .pytorch_cpu import PyTorchCPUBackend
|
from .pytorch_cpu import PyTorchCPUBackend
|
||||||
from .pytorch_gpu import PyTorchGPUBackend
|
from .pytorch_gpu import PyTorchGPUBackend
|
||||||
|
from .session_pool import CPUSessionPool, StreamingSessionPool
|
||||||
|
|
||||||
|
|
||||||
|
# Global singleton instance and state
|
||||||
|
_manager_instance = None
|
||||||
|
_manager_lock = asyncio.Lock()
|
||||||
|
_loaded_models = {}
|
||||||
|
_backends = {}
|
||||||
|
|
||||||
|
|
||||||
class ModelManager:
|
class ModelManager:
|
||||||
|
@ -25,65 +34,60 @@ class ModelManager:
|
||||||
config: Optional configuration
|
config: Optional configuration
|
||||||
"""
|
"""
|
||||||
self._config = config or model_config
|
self._config = config or model_config
|
||||||
self._backends: Dict[str, BaseModelBackend] = {}
|
global _loaded_models, _backends
|
||||||
self._current_backend: Optional[str] = None
|
self._loaded_models = _loaded_models
|
||||||
self._initialize_backends()
|
self._backends = _backends
|
||||||
|
|
||||||
def _initialize_backends(self) -> None:
|
# Initialize session pools
|
||||||
"""Initialize available backends based on settings."""
|
self._session_pools = {
|
||||||
has_gpu = settings.use_gpu and torch.cuda.is_available()
|
'onnx_cpu': CPUSessionPool(),
|
||||||
|
'onnx_gpu': StreamingSessionPool()
|
||||||
|
}
|
||||||
|
|
||||||
|
# Initialize locks
|
||||||
|
self._backend_locks: Dict[str, asyncio.Lock] = {}
|
||||||
|
|
||||||
|
def _determine_device(self) -> str:
|
||||||
|
"""Determine device based on settings."""
|
||||||
|
if settings.use_gpu and torch.cuda.is_available():
|
||||||
|
return "cuda"
|
||||||
|
return "cpu"
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
"""Initialize backends."""
|
||||||
|
if self._backends:
|
||||||
|
logger.debug("Using existing backend instances")
|
||||||
|
return
|
||||||
|
|
||||||
|
device = self._determine_device()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if has_gpu:
|
if device == "cuda":
|
||||||
if settings.use_onnx:
|
if settings.use_onnx:
|
||||||
# ONNX GPU primary
|
|
||||||
self._backends['onnx_gpu'] = ONNXGPUBackend()
|
self._backends['onnx_gpu'] = ONNXGPUBackend()
|
||||||
self._current_backend = 'onnx_gpu'
|
self._current_backend = 'onnx_gpu'
|
||||||
logger.info("Initialized ONNX GPU backend")
|
logger.info("Initialized new ONNX GPU backend")
|
||||||
|
|
||||||
# PyTorch GPU fallback
|
|
||||||
self._backends['pytorch_gpu'] = PyTorchGPUBackend()
|
|
||||||
logger.info("Initialized PyTorch GPU backend")
|
|
||||||
else:
|
else:
|
||||||
# PyTorch GPU primary
|
|
||||||
self._backends['pytorch_gpu'] = PyTorchGPUBackend()
|
self._backends['pytorch_gpu'] = PyTorchGPUBackend()
|
||||||
self._current_backend = 'pytorch_gpu'
|
self._current_backend = 'pytorch_gpu'
|
||||||
logger.info("Initialized PyTorch GPU backend")
|
logger.info("Initialized new PyTorch GPU backend")
|
||||||
|
|
||||||
# ONNX GPU fallback
|
|
||||||
self._backends['onnx_gpu'] = ONNXGPUBackend()
|
|
||||||
logger.info("Initialized ONNX GPU backend")
|
|
||||||
else:
|
else:
|
||||||
self._initialize_cpu_backends()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to initialize GPU backends: {e}")
|
|
||||||
# Fallback to CPU if GPU fails
|
|
||||||
self._initialize_cpu_backends()
|
|
||||||
|
|
||||||
def _initialize_cpu_backends(self) -> None:
|
|
||||||
"""Initialize CPU backends based on settings."""
|
|
||||||
try:
|
|
||||||
if settings.use_onnx:
|
if settings.use_onnx:
|
||||||
# ONNX CPU primary
|
|
||||||
self._backends['onnx_cpu'] = ONNXCPUBackend()
|
self._backends['onnx_cpu'] = ONNXCPUBackend()
|
||||||
self._current_backend = 'onnx_cpu'
|
self._current_backend = 'onnx_cpu'
|
||||||
logger.info("Initialized ONNX CPU backend")
|
logger.info("Initialized new ONNX CPU backend")
|
||||||
|
|
||||||
# PyTorch CPU fallback
|
|
||||||
self._backends['pytorch_cpu'] = PyTorchCPUBackend()
|
|
||||||
logger.info("Initialized PyTorch CPU backend")
|
|
||||||
else:
|
else:
|
||||||
# PyTorch CPU primary
|
|
||||||
self._backends['pytorch_cpu'] = PyTorchCPUBackend()
|
self._backends['pytorch_cpu'] = PyTorchCPUBackend()
|
||||||
self._current_backend = 'pytorch_cpu'
|
self._current_backend = 'pytorch_cpu'
|
||||||
logger.info("Initialized PyTorch CPU backend")
|
logger.info("Initialized new PyTorch CPU backend")
|
||||||
|
|
||||||
|
# Initialize locks for each backend
|
||||||
|
for backend in self._backends:
|
||||||
|
self._backend_locks[backend] = asyncio.Lock()
|
||||||
|
|
||||||
# ONNX CPU fallback
|
|
||||||
self._backends['onnx_cpu'] = ONNXCPUBackend()
|
|
||||||
logger.info("Initialized ONNX CPU backend")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to initialize CPU backends: {e}")
|
logger.error(f"Failed to initialize backend: {e}")
|
||||||
raise RuntimeError("No backends available")
|
raise RuntimeError("Failed to initialize backend")
|
||||||
|
|
||||||
def get_backend(self, backend_type: Optional[str] = None) -> BaseModelBackend:
|
def get_backend(self, backend_type: Optional[str] = None) -> BaseModelBackend:
|
||||||
"""Get specified backend.
|
"""Get specified backend.
|
||||||
|
@ -154,17 +158,40 @@ class ModelManager:
|
||||||
if backend_type is None:
|
if backend_type is None:
|
||||||
backend_type = self._determine_backend(abs_path)
|
backend_type = self._determine_backend(abs_path)
|
||||||
|
|
||||||
|
# Get backend lock
|
||||||
|
lock = self._backend_locks[backend_type]
|
||||||
|
|
||||||
|
async with lock:
|
||||||
backend = self.get_backend(backend_type)
|
backend = self.get_backend(backend_type)
|
||||||
|
|
||||||
|
# For ONNX backends, use session pool
|
||||||
|
if backend_type.startswith('onnx'):
|
||||||
|
pool = self._session_pools[backend_type]
|
||||||
|
backend._session = await pool.get_session(abs_path)
|
||||||
|
self._loaded_models[backend_type] = abs_path
|
||||||
|
logger.info(f"Fetched model instance from {backend_type} pool")
|
||||||
|
|
||||||
|
# For PyTorch backends, load normally
|
||||||
|
else:
|
||||||
|
# Check if model is already loaded
|
||||||
|
if (backend_type in self._loaded_models and
|
||||||
|
self._loaded_models[backend_type] == abs_path and
|
||||||
|
backend.is_loaded):
|
||||||
|
logger.info(f"Fetching existing model instance from {backend_type}")
|
||||||
|
return
|
||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
await backend.load_model(abs_path)
|
await backend.load_model(abs_path)
|
||||||
logger.info(f"Loaded model on {backend_type} backend")
|
self._loaded_models[backend_type] = abs_path
|
||||||
|
logger.info(f"Initialized new model instance on {backend_type}")
|
||||||
|
|
||||||
# Run warmup if voice provided
|
# Run warmup if voice provided
|
||||||
if warmup_voice is not None:
|
if warmup_voice is not None:
|
||||||
await self._warmup_inference(backend, warmup_voice)
|
await self._warmup_inference(backend, warmup_voice)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
# Clear cached path on failure
|
||||||
|
self._loaded_models.pop(backend_type, None)
|
||||||
raise RuntimeError(f"Failed to load model: {e}")
|
raise RuntimeError(f"Failed to load model: {e}")
|
||||||
|
|
||||||
async def _warmup_inference(self, backend: BaseModelBackend, voice: torch.Tensor) -> None:
|
async def _warmup_inference(self, backend: BaseModelBackend, voice: torch.Tensor) -> None:
|
||||||
|
@ -188,7 +215,7 @@ class ModelManager:
|
||||||
|
|
||||||
# Run inference
|
# Run inference
|
||||||
backend.generate(tokens, voice, speed=1.0)
|
backend.generate(tokens, voice, speed=1.0)
|
||||||
logger.info("Completed warmup inference")
|
logger.debug("Completed warmup inference")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Warmup inference failed: {e}")
|
logger.warning(f"Warmup inference failed: {e}")
|
||||||
|
@ -221,16 +248,23 @@ class ModelManager:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Generate audio using provided voice tensor
|
# Generate audio using provided voice tensor
|
||||||
|
# No lock needed here since inference is thread-safe
|
||||||
return backend.generate(tokens, voice, speed)
|
return backend.generate(tokens, voice, speed)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Generation failed: {e}")
|
raise RuntimeError(f"Generation failed: {e}")
|
||||||
|
|
||||||
def unload_all(self) -> None:
|
def unload_all(self) -> None:
|
||||||
"""Unload models from all backends."""
|
"""Unload models from all backends and clear cache."""
|
||||||
|
# Clean up session pools
|
||||||
|
for pool in self._session_pools.values():
|
||||||
|
pool.cleanup()
|
||||||
|
|
||||||
|
# Unload PyTorch backends
|
||||||
for backend in self._backends.values():
|
for backend in self._backends.values():
|
||||||
backend.unload()
|
backend.unload()
|
||||||
logger.info("Unloaded all models")
|
|
||||||
|
self._loaded_models.clear()
|
||||||
|
logger.info("Unloaded all models and cleared cache")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def available_backends(self) -> list[str]:
|
def available_backends(self) -> list[str]:
|
||||||
|
@ -251,12 +285,8 @@ class ModelManager:
|
||||||
return self._current_backend
|
return self._current_backend
|
||||||
|
|
||||||
|
|
||||||
# Module-level instance
|
async def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
|
||||||
_manager: Optional[ModelManager] = None
|
"""Get global model manager instance.
|
||||||
|
|
||||||
|
|
||||||
def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
|
|
||||||
"""Get or create global model manager instance.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config: Optional model configuration
|
config: Optional model configuration
|
||||||
|
@ -264,7 +294,10 @@ def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
|
||||||
Returns:
|
Returns:
|
||||||
ModelManager instance
|
ModelManager instance
|
||||||
"""
|
"""
|
||||||
global _manager
|
global _manager_instance
|
||||||
if _manager is None:
|
|
||||||
_manager = ModelManager(config)
|
async with _manager_lock:
|
||||||
return _manager
|
if _manager_instance is None:
|
||||||
|
_manager_instance = ModelManager(config)
|
||||||
|
await _manager_instance.initialize()
|
||||||
|
return _manager_instance
|
|
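Because get_manager() is now an async, lock-guarded singleton that also runs initialize(), callers must await it. A rough usage sketch, with import paths and the model filename assumed from the settings shown earlier in this commit:

    # Sketch: obtaining the singleton managers and loading a model with a warmup voice.
    import asyncio
    from api.src.inference.model_manager import get_manager
    from api.src.inference.voice_manager import get_manager as get_voice_manager

    async def main() -> None:
        model_manager = await get_manager()       # created and initialized once, under the lock
        voice_manager = await get_voice_manager()
        warmup = await voice_manager.load_voice("af", device="cpu")  # "af" is the default voice
        await model_manager.load_model(
            "/app/api/src/models/kokoro-v0_19.pth",  # pytorch_model_file under model_dir
            warmup,
            "pytorch_cpu",
        )
        print(model_manager.available_backends, model_manager.current_backend)

    asyncio.run(main())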
@ -1,20 +1,16 @@
|
||||||
"""CPU-based ONNX inference backend."""
|
"""CPU-based ONNX inference backend."""
|
||||||
|
|
||||||
from typing import Dict, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from onnxruntime import (
|
from onnxruntime import InferenceSession
|
||||||
ExecutionMode,
|
|
||||||
GraphOptimizationLevel,
|
|
||||||
InferenceSession,
|
|
||||||
SessionOptions
|
|
||||||
)
|
|
||||||
|
|
||||||
from ..core import paths
|
from ..core import paths
|
||||||
from ..core.model_config import model_config
|
from ..core.model_config import model_config
|
||||||
from .base import BaseModelBackend
|
from .base import BaseModelBackend
|
||||||
|
from .session_pool import create_session_options, create_provider_options
|
||||||
|
|
||||||
|
|
||||||
class ONNXCPUBackend(BaseModelBackend):
|
class ONNXCPUBackend(BaseModelBackend):
|
||||||
|
@ -47,8 +43,8 @@ class ONNXCPUBackend(BaseModelBackend):
|
||||||
logger.info(f"Loading ONNX model: {model_path}")
|
logger.info(f"Loading ONNX model: {model_path}")
|
||||||
|
|
||||||
# Configure session
|
# Configure session
|
||||||
options = self._create_session_options()
|
options = create_session_options(is_gpu=False)
|
||||||
provider_options = self._create_provider_options()
|
provider_options = create_provider_options(is_gpu=False)
|
||||||
|
|
||||||
# Create session
|
# Create session
|
||||||
self._session = InferenceSession(
|
self._session = InferenceSession(
|
||||||
|
@ -84,9 +80,9 @@ class ONNXCPUBackend(BaseModelBackend):
|
||||||
raise RuntimeError("Model not loaded")
|
raise RuntimeError("Model not loaded")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Prepare inputs
|
# Prepare inputs with start/end tokens
|
||||||
tokens_input = np.array([tokens], dtype=np.int64)
|
tokens_input = np.array([[0, *tokens, 0]], dtype=np.int64) # Add start/end tokens
|
||||||
style_input = voice[len(tokens)].numpy()
|
style_input = voice[len(tokens) + 2].numpy() # Adjust index for start/end tokens
|
||||||
speed_input = np.full(1, speed, dtype=np.float32)
|
speed_input = np.full(1, speed, dtype=np.float32)
|
||||||
|
|
||||||
# Run inference
|
# Run inference
|
||||||
|
@ -104,52 +100,6 @@ class ONNXCPUBackend(BaseModelBackend):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Generation failed: {e}")
|
raise RuntimeError(f"Generation failed: {e}")
|
||||||
|
|
||||||
def _create_session_options(self) -> SessionOptions:
|
|
||||||
"""Create ONNX session options.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Configured session options
|
|
||||||
"""
|
|
||||||
options = SessionOptions()
|
|
||||||
config = model_config.onnx_cpu
|
|
||||||
|
|
||||||
# Set optimization level
|
|
||||||
if config.optimization_level == "all":
|
|
||||||
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
|
|
||||||
elif config.optimization_level == "basic":
|
|
||||||
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
|
|
||||||
else:
|
|
||||||
options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
|
|
||||||
|
|
||||||
# Configure threading
|
|
||||||
options.intra_op_num_threads = config.num_threads
|
|
||||||
options.inter_op_num_threads = config.inter_op_threads
|
|
||||||
|
|
||||||
# Set execution mode
|
|
||||||
options.execution_mode = (
|
|
||||||
ExecutionMode.ORT_PARALLEL
|
|
||||||
if config.execution_mode == "parallel"
|
|
||||||
else ExecutionMode.ORT_SEQUENTIAL
|
|
||||||
)
|
|
||||||
|
|
||||||
# Configure memory optimization
|
|
||||||
options.enable_mem_pattern = config.memory_pattern
|
|
||||||
|
|
||||||
return options
|
|
||||||
|
|
||||||
def _create_provider_options(self) -> Dict:
|
|
||||||
"""Create CPU provider options.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Provider configuration
|
|
||||||
"""
|
|
||||||
return {
|
|
||||||
"CPUExecutionProvider": {
|
|
||||||
"arena_extend_strategy": model_config.onnx_cpu.arena_extend_strategy,
|
|
||||||
"cpu_memory_arena_cfg": "cpu:0"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
def unload(self) -> None:
|
def unload(self) -> None:
|
||||||
"""Unload model and free resources."""
|
"""Unload model and free resources."""
|
||||||
if self._session is not None:
|
if self._session is not None:
|
||||||
|
|
|
@ -1,20 +1,16 @@
|
||||||
"""GPU-based ONNX inference backend."""
|
"""GPU-based ONNX inference backend."""
|
||||||
|
|
||||||
from typing import Dict, Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from onnxruntime import (
|
from onnxruntime import InferenceSession
|
||||||
ExecutionMode,
|
|
||||||
GraphOptimizationLevel,
|
|
||||||
InferenceSession,
|
|
||||||
SessionOptions
|
|
||||||
)
|
|
||||||
|
|
||||||
from ..core import paths
|
from ..core import paths
|
||||||
from ..core.model_config import model_config
|
from ..core.model_config import model_config
|
||||||
from .base import BaseModelBackend
|
from .base import BaseModelBackend
|
||||||
|
from .session_pool import create_session_options, create_provider_options
|
||||||
|
|
||||||
|
|
||||||
class ONNXGPUBackend(BaseModelBackend):
|
class ONNXGPUBackend(BaseModelBackend):
|
||||||
|
@ -28,6 +24,9 @@ class ONNXGPUBackend(BaseModelBackend):
|
||||||
self._device = "cuda"
|
self._device = "cuda"
|
||||||
self._session: Optional[InferenceSession] = None
|
self._session: Optional[InferenceSession] = None
|
||||||
|
|
||||||
|
# Configure GPU
|
||||||
|
torch.cuda.set_device(model_config.onnx_gpu.device_id)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_loaded(self) -> bool:
|
def is_loaded(self) -> bool:
|
||||||
"""Check if model is loaded."""
|
"""Check if model is loaded."""
|
||||||
|
@ -49,8 +48,8 @@ class ONNXGPUBackend(BaseModelBackend):
|
||||||
logger.info(f"Loading ONNX model on GPU: {model_path}")
|
logger.info(f"Loading ONNX model on GPU: {model_path}")
|
||||||
|
|
||||||
# Configure session
|
# Configure session
|
||||||
options = self._create_session_options()
|
options = create_session_options(is_gpu=True)
|
||||||
provider_options = self._create_provider_options()
|
provider_options = create_provider_options(is_gpu=True)
|
||||||
|
|
||||||
# Create session with CUDA provider
|
# Create session with CUDA provider
|
||||||
self._session = InferenceSession(
|
self._session = InferenceSession(
|
||||||
|
@ -87,8 +86,8 @@ class ONNXGPUBackend(BaseModelBackend):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Prepare inputs
|
# Prepare inputs
|
||||||
tokens_input = np.array([tokens], dtype=np.int64)
|
tokens_input = np.array([[0, *tokens, 0]], dtype=np.int64) # Add start/end tokens
|
||||||
style_input = voice[len(tokens)].cpu().numpy() # Move to CPU for ONNX
|
style_input = voice[len(tokens) + 2].cpu().numpy() # Move to CPU for ONNX
|
||||||
speed_input = np.full(1, speed, dtype=np.float32)
|
speed_input = np.full(1, speed, dtype=np.float32)
|
||||||
|
|
||||||
# Run inference
|
# Run inference
|
||||||
|
@ -104,62 +103,15 @@ class ONNXGPUBackend(BaseModelBackend):
|
||||||
return result[0]
|
return result[0]
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
if "out of memory" in str(e).lower():
|
||||||
|
# Clear CUDA cache and retry
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
return self.generate(tokens, voice, speed)
|
||||||
raise RuntimeError(f"Generation failed: {e}")
|
raise RuntimeError(f"Generation failed: {e}")
|
||||||
|
|
||||||
def _create_session_options(self) -> SessionOptions:
|
|
||||||
"""Create ONNX session options.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Configured session options
|
|
||||||
"""
|
|
||||||
options = SessionOptions()
|
|
||||||
config = model_config.onnx_gpu
|
|
||||||
|
|
||||||
# Set optimization level
|
|
||||||
if config.optimization_level == "all":
|
|
||||||
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
|
|
||||||
elif config.optimization_level == "basic":
|
|
||||||
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
|
|
||||||
else:
|
|
||||||
options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
|
|
||||||
|
|
||||||
# Configure threading
|
|
||||||
options.intra_op_num_threads = config.num_threads
|
|
||||||
options.inter_op_num_threads = config.inter_op_threads
|
|
||||||
|
|
||||||
# Set execution mode
|
|
||||||
options.execution_mode = (
|
|
||||||
ExecutionMode.ORT_PARALLEL
|
|
||||||
if config.execution_mode == "parallel"
|
|
||||||
else ExecutionMode.ORT_SEQUENTIAL
|
|
||||||
)
|
|
||||||
|
|
||||||
# Configure memory optimization
|
|
||||||
options.enable_mem_pattern = config.memory_pattern
|
|
||||||
|
|
||||||
return options
|
|
||||||
|
|
||||||
def _create_provider_options(self) -> Dict:
|
|
||||||
"""Create CUDA provider options.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Provider configuration
|
|
||||||
"""
|
|
||||||
config = model_config.onnx_gpu
|
|
||||||
return {
|
|
||||||
"CUDAExecutionProvider": {
|
|
||||||
"device_id": config.device_id,
|
|
||||||
"arena_extend_strategy": config.arena_extend_strategy,
|
|
||||||
"gpu_mem_limit": int(config.gpu_mem_limit * torch.cuda.get_device_properties(0).total_memory),
|
|
||||||
"cudnn_conv_algo_search": config.cudnn_conv_algo_search,
|
|
||||||
"do_copy_in_default_stream": config.do_copy_in_default_stream
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
def unload(self) -> None:
|
def unload(self) -> None:
|
||||||
"""Unload model and free resources."""
|
"""Unload model and free resources."""
|
||||||
if self._session is not None:
|
if self._session is not None:
|
||||||
del self._session
|
del self._session
|
||||||
self._session = None
|
self._session = None
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
|
@ -13,8 +13,37 @@ from ..core.model_config import model_config
|
||||||
from .base import BaseModelBackend
|
from .base import BaseModelBackend
|
||||||
|
|
||||||
|
|
||||||
|
class CUDAStreamManager:
|
||||||
|
"""CUDA stream manager."""
|
||||||
|
|
||||||
|
def __init__(self, num_streams: int):
|
||||||
|
"""Initialize stream manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_streams: Number of CUDA streams
|
||||||
|
"""
|
||||||
|
self.streams = [torch.cuda.Stream() for _ in range(num_streams)]
|
||||||
|
self._current = 0
|
||||||
|
|
||||||
|
def get_next_stream(self) -> torch.cuda.Stream:
|
||||||
|
"""Get next available stream.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CUDA stream
|
||||||
|
"""
|
||||||
|
stream = self.streams[self._current]
|
||||||
|
self._current = (self._current + 1) % len(self.streams)
|
||||||
|
return stream
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def forward(model: torch.nn.Module, tokens: list[int], ref_s: torch.Tensor, speed: float) -> np.ndarray:
|
def forward(
|
||||||
|
model: torch.nn.Module,
|
||||||
|
tokens: list[int],
|
||||||
|
ref_s: torch.Tensor,
|
||||||
|
speed: float,
|
||||||
|
stream: Optional[torch.cuda.Stream] = None
|
||||||
|
) -> np.ndarray:
|
||||||
"""Forward pass through model.
|
"""Forward pass through model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -22,12 +51,15 @@ def forward(model: torch.nn.Module, tokens: list[int], ref_s: torch.Tensor, spee
|
||||||
tokens: Input tokens
|
tokens: Input tokens
|
||||||
ref_s: Reference signal (shape: [1, n_features])
|
ref_s: Reference signal (shape: [1, n_features])
|
||||||
speed: Speed multiplier
|
speed: Speed multiplier
|
||||||
|
stream: Optional CUDA stream
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Generated audio
|
Generated audio
|
||||||
"""
|
"""
|
||||||
device = ref_s.device
|
device = ref_s.device
|
||||||
|
|
||||||
|
# Use provided stream or default
|
||||||
|
with torch.cuda.stream(stream) if stream else torch.cuda.default_stream():
|
||||||
# Initial tensor setup
|
# Initial tensor setup
|
||||||
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
|
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
|
||||||
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
|
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
|
||||||
|
@ -74,6 +106,11 @@ def forward(model: torch.nn.Module, tokens: list[int], ref_s: torch.Tensor, spee
|
||||||
|
|
||||||
# Generate output
|
# Generate output
|
||||||
output = model.decoder(asr, F0_pred, N_pred, s_ref)
|
output = model.decoder(asr, F0_pred, N_pred, s_ref)
|
||||||
|
|
||||||
|
# Ensure operation completion if using custom stream
|
||||||
|
if stream:
|
||||||
|
stream.synchronize()
|
||||||
|
|
||||||
return output.squeeze().cpu().numpy()
|
return output.squeeze().cpu().numpy()
|
||||||
|
|
||||||
|
|
||||||
|
@ -92,9 +129,10 @@ class PyTorchGPUBackend(BaseModelBackend):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize GPU backend."""
|
"""Initialize GPU backend."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if not torch.cuda.is_available():
|
from ..core.config import settings
|
||||||
raise RuntimeError("CUDA not available")
|
if not (settings.use_gpu and torch.cuda.is_available()):
|
||||||
self._device = "cuda"
|
raise RuntimeError("GPU backend requires GPU support and use_gpu=True")
|
||||||
|
self._device = "cuda" # Device is enforced by backend selection in model_manager
|
||||||
self._model: Optional[torch.nn.Module] = None
|
self._model: Optional[torch.nn.Module] = None
|
||||||
|
|
||||||
# Configure GPU settings
|
# Configure GPU settings
|
||||||
|
@ -103,6 +141,9 @@ class PyTorchGPUBackend(BaseModelBackend):
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
torch.cuda.set_device(config.device_id)
|
torch.cuda.set_device(config.device_id)
|
||||||
|
|
||||||
|
# Initialize stream manager
|
||||||
|
self._stream_manager = CUDAStreamManager(config.cuda_streams)
|
||||||
|
|
||||||
async def load_model(self, path: str) -> None:
|
async def load_model(self, path: str) -> None:
|
||||||
"""Load PyTorch model.
|
"""Load PyTorch model.
|
||||||
|
|
||||||
|
@ -154,8 +195,11 @@ class PyTorchGPUBackend(BaseModelBackend):
|
||||||
if ref_s.dim() == 1:
|
if ref_s.dim() == 1:
|
||||||
ref_s = ref_s.unsqueeze(0) # Add batch dimension if needed
|
ref_s = ref_s.unsqueeze(0) # Add batch dimension if needed
|
||||||
|
|
||||||
# Generate audio
|
# Get next available stream
|
||||||
return forward(self._model, tokens, ref_s, speed)
|
stream = self._stream_manager.get_next_stream()
|
||||||
|
|
||||||
|
# Generate audio using stream
|
||||||
|
return forward(self._model, tokens, ref_s, speed, stream)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Generation failed: {e}")
|
logger.error(f"Generation failed: {e}")
|
||||||
|
|
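The CUDAStreamManager added to pytorch_gpu.py above hands out streams round-robin so consecutive generate() calls can overlap on the GPU. A standalone sketch of that rotation using plain PyTorch (guarded so it only runs where CUDA is available):

    # Sketch: round-robin CUDA stream selection, mirroring CUDAStreamManager above.
    import torch

    if torch.cuda.is_available():
        streams = [torch.cuda.Stream() for _ in range(2)]
        current = 0
        for step in range(4):
            stream = streams[current]
            current = (current + 1) % len(streams)  # rotate to the next stream
            with torch.cuda.stream(stream):         # queue this step's work on the chosen stream
                a = torch.randn(1024, 1024, device="cuda")
                b = torch.randn(1024, 1024, device="cuda")
                c = a @ b
            stream.synchronize()                    # wait for the queued work to finish
            print(step, float(c.mean()))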
api/src/inference/session_pool.py (new file, 271 lines)
|
@ -0,0 +1,271 @@
|
||||||
|
"""Session pooling for model inference."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, Optional, Set
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from loguru import logger
|
||||||
|
from onnxruntime import (
|
||||||
|
ExecutionMode,
|
||||||
|
GraphOptimizationLevel,
|
||||||
|
InferenceSession,
|
||||||
|
SessionOptions
|
||||||
|
)
|
||||||
|
|
||||||
|
from ..core import paths
|
||||||
|
from ..core.model_config import model_config
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SessionInfo:
|
||||||
|
"""Session information."""
|
||||||
|
session: InferenceSession
|
||||||
|
last_used: float
|
||||||
|
stream_id: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
def create_session_options(is_gpu: bool = False) -> SessionOptions:
|
||||||
|
"""Create ONNX session options.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
is_gpu: Whether to use GPU configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured session options
|
||||||
|
"""
|
||||||
|
options = SessionOptions()
|
||||||
|
config = model_config.onnx_gpu if is_gpu else model_config.onnx_cpu
|
||||||
|
|
||||||
|
# Set optimization level
|
||||||
|
if config.optimization_level == "all":
|
||||||
|
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||||
|
elif config.optimization_level == "basic":
|
||||||
|
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
|
||||||
|
else:
|
||||||
|
options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
|
||||||
|
|
||||||
|
# Configure threading
|
||||||
|
options.intra_op_num_threads = config.num_threads
|
||||||
|
options.inter_op_num_threads = config.inter_op_threads
|
||||||
|
|
||||||
|
# Set execution mode
|
||||||
|
options.execution_mode = (
|
||||||
|
ExecutionMode.ORT_PARALLEL
|
||||||
|
if config.execution_mode == "parallel"
|
||||||
|
else ExecutionMode.ORT_SEQUENTIAL
|
||||||
|
)
|
||||||
|
|
||||||
|
# Configure memory optimization
|
||||||
|
options.enable_mem_pattern = config.memory_pattern
|
||||||
|
|
||||||
|
return options
|
||||||
|
|
||||||
|
|
||||||
|
def create_provider_options(is_gpu: bool = False) -> Dict:
|
||||||
|
"""Create provider options.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
is_gpu: Whether to use GPU configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Provider configuration
|
||||||
|
"""
|
||||||
|
if is_gpu:
|
||||||
|
config = model_config.onnx_gpu
|
||||||
|
return {
|
||||||
|
"CUDAExecutionProvider": {
|
||||||
|
"device_id": config.device_id,
|
||||||
|
"arena_extend_strategy": config.arena_extend_strategy,
|
||||||
|
"gpu_mem_limit": int(config.gpu_mem_limit * torch.cuda.get_device_properties(0).total_memory),
|
||||||
|
"cudnn_conv_algo_search": config.cudnn_conv_algo_search,
|
||||||
|
"do_copy_in_default_stream": config.do_copy_in_default_stream
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"CPUExecutionProvider": {
|
||||||
|
"arena_extend_strategy": model_config.onnx_cpu.arena_extend_strategy,
|
||||||
|
"cpu_memory_arena_cfg": "cpu:0"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class BaseSessionPool:
|
||||||
|
"""Base session pool implementation."""
|
||||||
|
|
||||||
|
def __init__(self, max_size: int, timeout: int):
|
||||||
|
"""Initialize session pool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_size: Maximum number of concurrent sessions
|
||||||
|
timeout: Session timeout in seconds
|
||||||
|
"""
|
||||||
|
self._max_size = max_size
|
||||||
|
self._timeout = timeout
|
||||||
|
self._sessions: Dict[str, SessionInfo] = {}
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def get_session(self, model_path: str) -> InferenceSession:
|
||||||
|
"""Get session from pool.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to model file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ONNX inference session
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If no sessions available
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
# Clean expired sessions
|
||||||
|
self._cleanup_expired()
|
||||||
|
|
||||||
|
# Check if session exists and is valid
|
||||||
|
if model_path in self._sessions:
|
||||||
|
session_info = self._sessions[model_path]
|
||||||
|
session_info.last_used = time.time()
|
||||||
|
return session_info.session
|
||||||
|
|
||||||
|
# Check if we can create new session
|
||||||
|
if len(self._sessions) >= self._max_size:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Maximum number of sessions reached ({self._max_size})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create new session
|
||||||
|
session = await self._create_session(model_path)
|
||||||
|
self._sessions[model_path] = SessionInfo(
|
||||||
|
session=session,
|
||||||
|
last_used=time.time()
|
||||||
|
)
|
||||||
|
return session
|
||||||
|
|
||||||
|
def _cleanup_expired(self) -> None:
|
||||||
|
"""Remove expired sessions."""
|
||||||
|
current_time = time.time()
|
||||||
|
expired = [
|
||||||
|
path for path, info in self._sessions.items()
|
||||||
|
if current_time - info.last_used > self._timeout
|
||||||
|
]
|
||||||
|
for path in expired:
|
||||||
|
logger.info(f"Removing expired session: {path}")
|
||||||
|
del self._sessions[path]
|
||||||
|
|
||||||
|
async def _create_session(self, model_path: str) -> InferenceSession:
|
||||||
|
"""Create new session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to model file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ONNX inference session
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def cleanup(self) -> None:
|
||||||
|
"""Clean up all sessions."""
|
||||||
|
self._sessions.clear()
|
||||||
|
|
||||||
|
|
||||||
|
class StreamingSessionPool(BaseSessionPool):
|
||||||
|
"""GPU session pool with CUDA streams."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize GPU session pool."""
|
||||||
|
config = model_config.onnx_gpu
|
||||||
|
super().__init__(config.cuda_streams, config.stream_timeout)
|
||||||
|
self._available_streams: Set[int] = set(range(config.cuda_streams))
|
||||||
|
|
||||||
|
async def get_session(self, model_path: str) -> InferenceSession:
|
||||||
|
"""Get session with CUDA stream.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to model file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ONNX inference session
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If no streams available
|
||||||
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
# Clean expired sessions
|
||||||
|
self._cleanup_expired()
|
||||||
|
|
||||||
|
# Try to find existing session
|
||||||
|
if model_path in self._sessions:
|
||||||
|
session_info = self._sessions[model_path]
|
||||||
|
session_info.last_used = time.time()
|
||||||
|
return session_info.session
|
||||||
|
|
||||||
|
# Get available stream
|
||||||
|
if not self._available_streams:
|
||||||
|
raise RuntimeError("No CUDA streams available")
|
||||||
|
stream_id = self._available_streams.pop()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create new session
|
||||||
|
session = await self._create_session(model_path)
|
||||||
|
self._sessions[model_path] = SessionInfo(
|
||||||
|
session=session,
|
||||||
|
last_used=time.time(),
|
||||||
|
stream_id=stream_id
|
||||||
|
)
|
||||||
|
return session
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
# Return stream to pool on failure
|
||||||
|
self._available_streams.add(stream_id)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _cleanup_expired(self) -> None:
|
||||||
|
"""Remove expired sessions and return streams."""
|
||||||
|
current_time = time.time()
|
||||||
|
expired = [
|
||||||
|
path for path, info in self._sessions.items()
|
||||||
|
if current_time - info.last_used > self._timeout
|
||||||
|
]
|
||||||
|
for path in expired:
|
||||||
|
info = self._sessions[path]
|
||||||
|
if info.stream_id is not None:
|
||||||
|
self._available_streams.add(info.stream_id)
|
||||||
|
logger.info(f"Removing expired session: {path}")
|
||||||
|
del self._sessions[path]
|
||||||
|
|
||||||
|
async def _create_session(self, model_path: str) -> InferenceSession:
|
||||||
|
"""Create new session with CUDA provider."""
|
||||||
|
abs_path = await paths.get_model_path(model_path)
|
||||||
|
options = create_session_options(is_gpu=True)
|
||||||
|
provider_options = create_provider_options(is_gpu=True)
|
||||||
|
|
||||||
|
return InferenceSession(
|
||||||
|
abs_path,
|
||||||
|
sess_options=options,
|
||||||
|
providers=["CUDAExecutionProvider"],
|
||||||
|
provider_options=[provider_options]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CPUSessionPool(BaseSessionPool):
|
||||||
|
"""CPU session pool."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize CPU session pool."""
|
||||||
|
config = model_config.onnx_cpu
|
||||||
|
super().__init__(config.max_instances, config.instance_timeout)
|
||||||
|
|
||||||
|
async def _create_session(self, model_path: str) -> InferenceSession:
|
||||||
|
"""Create new session with CPU provider."""
|
||||||
|
abs_path = await paths.get_model_path(model_path)
|
||||||
|
options = create_session_options(is_gpu=False)
|
||||||
|
provider_options = create_provider_options(is_gpu=False)
|
||||||
|
|
||||||
|
return InferenceSession(
|
||||||
|
abs_path,
|
||||||
|
sess_options=options,
|
||||||
|
providers=["CPUExecutionProvider"],
|
||||||
|
provider_options=[provider_options]
|
||||||
|
)
|
|
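A rough usage sketch for the new session pool; the import path is assumed from the file location, the model filename is a placeholder, and the returned object is a standard onnxruntime InferenceSession:

    # Sketch: fetching a pooled CPU session and inspecting its inputs.
    import asyncio
    from api.src.inference.session_pool import CPUSessionPool

    async def main() -> None:
        pool = CPUSessionPool()                                  # sized by model_config.onnx_cpu
        session = await pool.get_session("kokoro-v0_19.onnx")    # placeholder model path
        print([inp.name for inp in session.get_inputs()])        # standard InferenceSession API
        pool.cleanup()                                           # drop all cached sessions

    asyncio.run(main())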
@ -1,11 +1,10 @@
|
||||||
"""Voice pack management and caching."""
|
"""Voice pack management and caching."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from ..core import paths
|
from ..core import paths
|
||||||
from ..core.config import settings
|
from ..core.config import settings
|
||||||
|
@ -13,7 +12,7 @@ from ..structures.model_schemas import VoiceConfig
|
||||||
|
|
||||||
|
|
||||||
class VoiceManager:
|
class VoiceManager:
|
||||||
"""Manages voice loading, caching, and operations."""
|
"""Manages voice loading and operations."""
|
||||||
|
|
||||||
def __init__(self, config: Optional[VoiceConfig] = None):
|
def __init__(self, config: Optional[VoiceConfig] = None):
|
||||||
"""Initialize voice manager.
|
"""Initialize voice manager.
|
||||||
|
@ -33,15 +32,8 @@ class VoiceManager:
|
||||||
Returns:
|
Returns:
|
||||||
Path to voice file if exists, None otherwise
|
Path to voice file if exists, None otherwise
|
||||||
"""
|
"""
|
||||||
# Get api directory path (two levels up from inference)
|
|
||||||
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||||
|
|
||||||
# Construct voice path relative to api directory
|
|
||||||
voice_path = os.path.join(api_dir, settings.voices_dir, f"{voice_name}.pt")
|
voice_path = os.path.join(api_dir, settings.voices_dir, f"{voice_name}.pt")
|
||||||
|
|
||||||
# Ensure voices directory exists
|
|
||||||
os.makedirs(os.path.dirname(voice_path), exist_ok=True)
|
|
||||||
|
|
||||||
return voice_path if os.path.exists(voice_path) else None
|
return voice_path if os.path.exists(voice_path) else None
|
||||||
|
|
||||||
async def load_voice(self, voice_name: str, device: str = "cpu") -> torch.Tensor:
|
async def load_voice(self, voice_name: str, device: str = "cpu") -> torch.Tensor:
|
||||||
|
@ -66,9 +58,11 @@ class VoiceManager:
|
||||||
if self._config.use_cache and cache_key in self._voice_cache:
|
if self._config.use_cache and cache_key in self._voice_cache:
|
||||||
return self._voice_cache[cache_key]
|
return self._voice_cache[cache_key]
|
||||||
|
|
||||||
try:
|
|
||||||
# Load voice tensor
|
# Load voice tensor
|
||||||
|
try:
|
||||||
voice = await paths.load_voice_tensor(voice_path, device=device)
|
voice = await paths.load_voice_tensor(voice_path, device=device)
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Failed to load voice {voice_name}: {e}")
|
||||||
|
|
||||||
# Cache if enabled
|
# Cache if enabled
|
||||||
if self._config.use_cache:
|
if self._config.use_cache:
|
||||||
|
@ -78,9 +72,6 @@ class VoiceManager:
|
||||||
|
|
||||||
return voice
|
return voice
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"Failed to load voice {voice_name}: {e}")
|
|
||||||
|
|
||||||
def _manage_cache(self) -> None:
|
def _manage_cache(self) -> None:
|
||||||
"""Manage voice cache size."""
|
"""Manage voice cache size."""
|
||||||
if len(self._voice_cache) >= self._config.cache_size:
|
if len(self._voice_cache) >= self._config.cache_size:
|
||||||
|
@ -123,14 +114,14 @@ class VoiceManager:
|
||||||
# Get api directory path
|
# Get api directory path
|
||||||
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||||
voices_dir = os.path.join(api_dir, settings.voices_dir)
|
voices_dir = os.path.join(api_dir, settings.voices_dir)
|
||||||
|
|
||||||
# Ensure voices directory exists
|
|
||||||
os.makedirs(voices_dir, exist_ok=True)
|
os.makedirs(voices_dir, exist_ok=True)
|
||||||
|
|
||||||
# Save combined voice
|
# Save combined voice
|
||||||
combined_path = os.path.join(voices_dir, f"{combined_name}.pt")
|
combined_path = os.path.join(voices_dir, f"{combined_name}.pt")
|
||||||
try:
|
try:
|
||||||
torch.save(combined_tensor, combined_path)
|
torch.save(combined_tensor, combined_path)
|
||||||
|
# Cache the new combined voice
|
||||||
|
self._voice_cache[f"{combined_path}_{device}"] = combined_tensor
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Failed to save combined voice: {e}")
|
raise RuntimeError(f"Failed to save combined voice: {e}")
|
||||||
|
|
||||||
|
@ -147,17 +138,13 @@ class VoiceManager:
|
||||||
"""
|
"""
|
||||||
voices = []
|
voices = []
|
||||||
try:
|
try:
|
||||||
# Get api directory path
|
|
||||||
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||||
voices_dir = os.path.join(api_dir, settings.voices_dir)
|
voices_dir = os.path.join(api_dir, settings.voices_dir)
|
||||||
|
|
||||||
# Ensure voices directory exists
|
|
||||||
os.makedirs(voices_dir, exist_ok=True)
|
os.makedirs(voices_dir, exist_ok=True)
|
||||||
|
|
||||||
# List voice files
|
|
||||||
for entry in os.listdir(voices_dir):
|
for entry in os.listdir(voices_dir):
|
||||||
if entry.endswith(".pt"):
|
if entry.endswith(".pt"):
|
||||||
voices.append(entry[:-3]) # Remove .pt extension
|
voices.append(entry[:-3])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error listing voices: {e}")
|
logger.error(f"Error listing voices: {e}")
|
||||||
return sorted(voices)
|
return sorted(voices)
|
||||||
|
@ -174,11 +161,8 @@ class VoiceManager:
|
||||||
try:
|
try:
|
||||||
if not os.path.exists(voice_path):
|
if not os.path.exists(voice_path):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Try loading voice
|
|
||||||
voice = torch.load(voice_path, map_location="cpu")
|
voice = torch.load(voice_path, map_location="cpu")
|
||||||
return isinstance(voice, torch.Tensor)
|
return isinstance(voice, torch.Tensor)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -195,12 +179,12 @@ class VoiceManager:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# Module-level instance
|
# Global singleton instance
|
||||||
_manager: Optional[VoiceManager] = None
|
_manager_instance = None
|
||||||
|
|
||||||
|
|
||||||
def get_manager(config: Optional[VoiceConfig] = None) -> VoiceManager:
|
async def get_manager(config: Optional[VoiceConfig] = None) -> VoiceManager:
|
||||||
"""Get or create global voice manager instance.
|
"""Get global voice manager instance.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config: Optional voice configuration
|
config: Optional voice configuration
|
||||||
|
@ -208,7 +192,7 @@ def get_manager(config: Optional[VoiceConfig] = None) -> VoiceManager:
|
||||||
Returns:
|
Returns:
|
||||||
VoiceManager instance
|
VoiceManager instance
|
||||||
"""
|
"""
|
||||||
global _manager
|
global _manager_instance
|
||||||
if _manager is None:
|
if _manager_instance is None:
|
||||||
_manager = VoiceManager(config)
|
_manager_instance = VoiceManager(config)
|
||||||
return _manager
|
return _manager_instance
|
|
@ -1,10 +1,13 @@
|
||||||
|
|
||||||
"""
|
"""
|
||||||
FastAPI OpenAI Compatible API
|
FastAPI OpenAI Compatible API
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
import torch
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
@ -41,19 +44,59 @@ setup_logger()
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
"""Lifespan context manager for model initialization"""
|
"""Lifespan context manager for model initialization"""
|
||||||
|
from .inference.model_manager import get_manager
|
||||||
|
from .inference.voice_manager import get_manager as get_voice_manager
|
||||||
|
|
||||||
logger.info("Loading TTS model and voice packs...")
|
logger.info("Loading TTS model and voice packs...")
|
||||||
|
|
||||||
# Initialize service
|
try:
|
||||||
service = TTSService()
|
# Initialize managers globally
|
||||||
await service.ensure_initialized()
|
model_manager = await get_manager()
|
||||||
|
voice_manager = await get_voice_manager()
|
||||||
|
|
||||||
# Get available voices
|
# Determine backend type based on settings
|
||||||
voices = await service.list_voices()
|
if settings.use_gpu and torch.cuda.is_available():
|
||||||
|
backend_type = 'pytorch_gpu' if not settings.use_onnx else 'onnx_gpu'
|
||||||
|
else:
|
||||||
|
backend_type = 'pytorch_cpu' if not settings.use_onnx else 'onnx_cpu'
|
||||||
|
|
||||||
|
# Get backend and initialize model
|
||||||
|
backend = model_manager.get_backend(backend_type)
|
||||||
|
|
||||||
|
# Use model path directly from settings
|
||||||
|
model_file = settings.pytorch_model_file if not settings.use_onnx else settings.onnx_model_file
|
||||||
|
model_path = os.path.join(settings.model_dir, model_file)
|
||||||
|
|
||||||
|
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
raise RuntimeError(f"Model file not found: {model_path}")
|
||||||
|
|
||||||
|
# Pre-cache default voice and use for warmup
|
||||||
|
warmup_voice = await voice_manager.load_voice(settings.default_voice, device=backend.device)
|
||||||
|
logger.info(f"Pre-cached voice {settings.default_voice} for warmup")
|
||||||
|
|
||||||
|
# Initialize model with warmup voice
|
||||||
|
await model_manager.load_model(model_path, warmup_voice, backend_type)
|
||||||
|
|
||||||
|
# Pre-cache common voices in background
|
||||||
|
common_voices = ['af', 'af_bella', 'af_sarah', 'af_nicole']
|
||||||
|
for voice_name in common_voices:
|
||||||
|
try:
|
||||||
|
await voice_manager.load_voice(voice_name, device=backend.device)
|
||||||
|
logger.debug(f"Pre-cached voice {voice_name}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to pre-cache voice {voice_name}: {e}")
|
||||||
|
|
||||||
|
# Get available voices for startup message
|
||||||
|
voices = await voice_manager.list_voices()
|
||||||
voicepack_count = len(voices)
|
voicepack_count = len(voices)
|
||||||
|
|
||||||
# Get device info from model manager
|
# Get device info for startup message
|
||||||
device = "GPU" if settings.use_gpu else "CPU"
|
device = "GPU" if settings.use_gpu else "CPU"
|
||||||
model = "ONNX" if settings.use_onnx else "PyTorch"
|
model = "ONNX" if settings.use_onnx else "PyTorch"
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize model: {e}")
|
||||||
|
raise
|
||||||
boundary = "░" * 2*12
|
boundary = "░" * 2*12
|
||||||
startup_msg = f"""
|
startup_msg = f"""
|
||||||
|
|
||||||
|
|
|
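The commit message mentions async streaming examples; a client-side sketch in that spirit follows. The endpoint path, port, and payload fields are assumptions based on the OpenAI-compatible API this server mirrors, so they should be checked against the repository's own examples:

    # Sketch: streaming audio from the OpenAI-compatible speech endpoint with httpx.
    import asyncio
    import httpx

    async def stream_speech(text: str, out_path: str = "out.mp3") -> None:
        payload = {"model": "kokoro", "input": text, "voice": "af",
                   "response_format": "mp3", "stream": True}        # field names assumed
        async with httpx.AsyncClient(base_url="http://localhost:8880") as client:  # port assumed
            async with client.stream("POST", "/v1/audio/speech", json=payload) as resp:
                resp.raise_for_status()
                with open(out_path, "wb") as out:
                    async for chunk in resp.aiter_bytes():
                        out.write(chunk)

    asyncio.run(stream_speech("Hello from the streaming endpoint."))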
@ -1,7 +1,7 @@
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Response
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from ..services.audio import AudioService
|
from ..services.audio import AudioService
|
||||||
|
@ -16,9 +16,9 @@ from ..structures.text_schemas import (
|
||||||
router = APIRouter(tags=["text processing"])
|
router = APIRouter(tags=["text processing"])
|
||||||
|
|
||||||
|
|
||||||
def get_tts_service() -> TTSService:
|
async def get_tts_service() -> TTSService:
|
||||||
"""Dependency to get TTSService instance"""
|
"""Dependency to get TTSService instance"""
|
||||||
return TTSService()
|
return await TTSService.create() # Create service with properly initialized managers
|
||||||
|
|
||||||
|
|
||||||
@router.post("/text/phonemize", response_model=PhonemeResponse, tags=["deprecated"])
|
@router.post("/text/phonemize", response_model=PhonemeResponse, tags=["deprecated"])
|
||||||
|
@ -82,9 +82,6 @@ async def generate_from_phonemes(
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Ensure service is initialized
|
|
||||||
await tts_service.ensure_initialized()
|
|
||||||
|
|
||||||
# Validate voice exists
|
# Validate voice exists
|
||||||
available_voices = await tts_service.list_voices()
|
available_voices = await tts_service.list_voices()
|
||||||
if request.voice not in available_voices:
|
if request.voice not in available_voices:
|
||||||
|
|
|
@@ -1,6 +1,6 @@
 from typing import AsyncGenerator, List, Union

-from fastapi import APIRouter, Depends, Header, HTTPException, Response, Request
+from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
 from fastapi.responses import StreamingResponse
 from loguru import logger

@@ -13,10 +13,28 @@ router = APIRouter(
     responses={404: {"description": "Not found"}},
 )

+# Global TTSService instance with lock
+_tts_service = None
+_init_lock = None
+
-def get_tts_service() -> TTSService:
-    """Dependency to get TTSService instance with database session"""
-    return TTSService()  # Initialize TTSService with default settings
+async def get_tts_service() -> TTSService:
+    """Get global TTSService instance"""
+    global _tts_service, _init_lock
+
+    # Create lock if needed
+    if _init_lock is None:
+        import asyncio
+        _init_lock = asyncio.Lock()
+
+    # Initialize service if needed
+    if _tts_service is None:
+        async with _init_lock:
+            # Double check pattern
+            if _tts_service is None:
+                _tts_service = await TTSService.create()
+                logger.info("Created global TTSService instance")
+
+    return _tts_service


 async def process_voices(
@@ -78,11 +96,13 @@ async def stream_audio_chunks(
 async def create_speech(
     request: OpenAISpeechRequest,
     client_request: Request,
-    tts_service: TTSService = Depends(get_tts_service),
     x_raw_response: str = Header(None, alias="x-raw-response"),
 ):
     """OpenAI-compatible endpoint for text-to-speech"""
     try:
+        # Get global service instance
+        tts_service = await get_tts_service()
+
         # Process voice combination and validate
         voice_to_use = await process_voices(request.voice, tts_service)

@@ -145,9 +165,10 @@ async def create_speech(


 @router.get("/audio/voices")
-async def list_voices(tts_service: TTSService = Depends(get_tts_service)):
+async def list_voices():
     """List all available voices for text-to-speech"""
     try:
+        tts_service = await get_tts_service()
         voices = await tts_service.list_voices()
         return {"voices": voices}
     except Exception as e:
@@ -156,9 +177,7 @@ async def list_voices(tts_service: TTSService = Depends(get_tts_service)):


 @router.post("/audio/voices/combine")
-async def combine_voices(
-    request: Union[str, List[str]], tts_service: TTSService = Depends(get_tts_service)
-):
+async def combine_voices(request: Union[str, List[str]]):
     """Combine multiple voices into a new voice.

     Args:
@@ -174,6 +193,7 @@ async def combine_voices(
     - 500: Server error (file system issues, combination failed)
     """
     try:
+        tts_service = await get_tts_service()
         combined_voice = await process_voices(request, tts_service)
         voices = await tts_service.list_voices()
         return {"voices": voices, "voice": combined_voice}
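The router change above swaps a per-request TTSService() for one module-level instance guarded by an asyncio.Lock with a double check, so concurrent first requests share a single initialization. A minimal, self-contained sketch of that pattern in isolation (DummyService and get_service are illustrative names, not the project's API):

# Sketch only: DummyService stands in for an expensive-to-initialize service.
import asyncio
from typing import Optional


class DummyService:
    @classmethod
    async def create(cls) -> "DummyService":
        await asyncio.sleep(0)  # stands in for real async initialization
        return cls()


_service: Optional[DummyService] = None
_init_lock: Optional[asyncio.Lock] = None


async def get_service() -> DummyService:
    global _service, _init_lock
    if _init_lock is None:
        _init_lock = asyncio.Lock()
    if _service is None:
        async with _init_lock:
            # Double check: another coroutine may have finished creation
            # while this one was waiting on the lock
            if _service is None:
                _service = await DummyService.create()
    return _service


async def main() -> None:
    a, b = await asyncio.gather(get_service(), get_service())
    assert a is b  # both callers share the same instance


if __name__ == "__main__":
    asyncio.run(main())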
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
 import phonemizer

 from .normalizer import normalize_text
-
+phonemizers = {}

 class PhonemizerBackend(ABC):
     """Abstract base class for phonemization backends"""
@@ -91,8 +91,9 @@ def phonemize(text: str, language: str = "a", normalize: bool = True) -> str:
     Returns:
         Phonemized text
     """
+    global phonemizers
     if normalize:
         text = normalize_text(text)
-    phonemizer = create_phonemizer(language)
-    return phonemizer.phonemize(text)
+    if language not in phonemizers:
+        phonemizers[language] = create_phonemizer(language)
+    return phonemizers[language].phonemize(text)
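The phonemize() change above replaces a backend that was re-created on every call with a module-level cache keyed by language code. A small sketch of the same caching shape, with an illustrative Backend class standing in for the real phonemizer wrapper:

# Sketch only: Backend and create_backend are stand-ins for the real factory.
from typing import Dict


class Backend:
    def __init__(self, language: str) -> None:
        self.language = language

    def phonemize(self, text: str) -> str:
        return f"[{self.language}] {text}"  # placeholder transformation


def create_backend(language: str) -> Backend:
    return Backend(language)


_backends: Dict[str, Backend] = {}


def phonemize(text: str, language: str = "a") -> str:
    # Create the backend once per language, then reuse it
    if language not in _backends:
        _backends[language] = create_backend(language)
    return _backends[language].phonemize(text)


print(phonemize("hello"))  # creates the backend for "a"
print(phonemize("world"))  # reuses the cached backend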
@@ -1,9 +1,8 @@
 """TTS service using model and voice managers."""

 import io
-import os
 import time
-from typing import List, Tuple
+from typing import List, Tuple, Optional

 import numpy as np
 import scipy.io.wavfile as wavfile
@@ -17,9 +16,14 @@ from .audio import AudioNormalizer, AudioService
 from .text_processing import chunker, normalize_text, process_text

+import asyncio
+

 class TTSService:
     """Text-to-speech service."""

+    # Limit concurrent chunk processing
+    _chunk_semaphore = asyncio.Semaphore(4)
+
     def __init__(self, output_dir: str = None):
         """Initialize service.

@@ -27,53 +31,24 @@ class TTSService:
             output_dir: Optional output directory for saving audio
         """
         self.output_dir = output_dir
-        self.model_manager = get_model_manager()
-        self.voice_manager = get_voice_manager()
-        self._initialized = False
-        self._initialization_error = None
-
-    async def ensure_initialized(self):
-        """Ensure model is initialized."""
-        if self._initialized:
-            return
-        if self._initialization_error:
-            raise self._initialization_error
-
-        try:
-            # Get api directory path (one level up from src)
-            api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
-
-            # Determine model file and backend based on hardware
-            if settings.use_gpu and torch.cuda.is_available():
-                model_file = settings.pytorch_model_file
-                backend_type = 'pytorch_gpu'
-            else:
-                model_file = settings.onnx_model_file
-                backend_type = 'onnx_cpu'
-
-            # Construct model path relative to api directory
-            model_path = os.path.join(api_dir, settings.model_dir, model_file)
-
-            # Ensure model directory exists
-            os.makedirs(os.path.dirname(model_path), exist_ok=True)
-
-            if not os.path.exists(model_path):
-                raise RuntimeError(f"Model file not found: {model_path}")
-
-            # Load default voice for warmup
-            backend = self.model_manager.get_backend(backend_type)
-            warmup_voice = await self.voice_manager.load_voice(settings.default_voice, device=backend.device)
-            logger.info(f"Loaded voice {settings.default_voice} for warmup")
-
-            # Initialize model with warmup voice
-            await self.model_manager.load_model(model_path, warmup_voice, backend_type)
-            logger.info(f"Initialized model on {backend_type} backend")
-            self._initialized = True
-
-        except Exception as e:
-            logger.error(f"Failed to initialize model: {e}")
-            self._initialization_error = RuntimeError(f"Model initialization failed: {e}")
-            raise self._initialization_error
+        self.model_manager = None
+        self._voice_manager = None
+
+    @classmethod
+    async def create(cls, output_dir: str = None) -> 'TTSService':
+        """Create and initialize TTSService instance.
+
+        Args:
+            output_dir: Optional output directory for saving audio
+
+        Returns:
+            Initialized TTSService instance
+        """
+        service = cls(output_dir)
+        # Initialize managers
+        service.model_manager = await get_model_manager()
+        service._voice_manager = await get_voice_manager()
+        return service

     async def generate_audio(
         self, text: str, voice: str, speed: float = 1.0
@@ -88,8 +63,8 @@ class TTSService:
         Returns:
             Audio samples and processing time
         """
-        await self.ensure_initialized()
         start_time = time.time()
+        voice_tensor = None

         try:
             # Normalize text
@@ -98,31 +73,40 @@
                 raise ValueError("Text is empty after preprocessing")
             text = str(normalized)

-            # Process text into chunks
-            audio_chunks = []
-            for chunk in chunker.split_text(text):
-                try:
-                    # Convert chunk to token IDs
-                    tokens = process_text(chunk)
-                    if not tokens:
-                        continue
-
-                    # Get backend and load voice
-                    backend = self.model_manager.get_backend()
-                    voice_tensor = await self.voice_manager.load_voice(voice, device=backend.device)
-
-                    # Generate audio
-                    chunk_audio = await self.model_manager.generate(
-                        tokens,
-                        voice_tensor,
-                        speed=speed
-                    )
-                    if chunk_audio is not None:
-                        audio_chunks.append(chunk_audio)
-                except Exception as e:
-                    logger.error(f"Failed to process chunk: '{chunk}'. Error: {str(e)}")
-                    continue
+            # Get backend and load voice
+            backend = self.model_manager.get_backend()
+            voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device)
+
+            # Get all chunks upfront
+            chunks = list(chunker.split_text(text))
+            if not chunks:
+                raise ValueError("No text chunks to process")
+
+            # Process chunk with concurrency control
+            async def process_chunk(chunk: str) -> Optional[np.ndarray]:
+                async with self._chunk_semaphore:
+                    try:
+                        tokens = process_text(chunk)
+                        if not tokens:
+                            return None
+
+                        # Generate audio
+                        return await self.model_manager.generate(
+                            tokens,
+                            voice_tensor,
+                            speed=speed
+                        )
+                    except Exception as e:
+                        logger.error(f"Failed to process chunk: '{chunk}'. Error: {str(e)}")
+                        return None
+
+            # Process all chunks concurrently
+            chunk_results = await asyncio.gather(*[
+                process_chunk(chunk) for chunk in chunks
+            ])
+
+            # Filter out None results and combine
+            audio_chunks = [chunk for chunk in chunk_results if chunk is not None]
             if not audio_chunks:
                 raise ValueError("No audio chunks were generated successfully")

@@ -134,6 +118,11 @@ class TTSService:
         except Exception as e:
             logger.error(f"Error in audio generation: {str(e)}")
             raise
+        finally:
+            # Always clean up voice tensor
+            if voice_tensor is not None:
+                del voice_tensor
+                torch.cuda.empty_cache()

     async def generate_audio_stream(
         self,
@@ -153,32 +142,33 @@
         Yields:
             Audio chunks as bytes
         """
-        await self.ensure_initialized()
-
-        try:
-            # Setup audio processing
-            stream_normalizer = AudioNormalizer()
+        # Setup audio processing
+        stream_normalizer = AudioNormalizer()
+        voice_tensor = None
+
+        try:
             # Normalize text
             normalized = normalize_text(text)
             if not normalized:
                 raise ValueError("Text is empty after preprocessing")
             text = str(normalized)

-            # Process chunks
-            is_first = True
-            chunk_gen = chunker.split_text(text)
-            current_chunk = next(chunk_gen, None)
-
-            while current_chunk is not None:
-                next_chunk = next(chunk_gen, None)
-                try:
-                    # Convert chunk to token IDs
-                    tokens = process_text(current_chunk)
-                    if tokens:
-                        # Get backend and load voice
-                        backend = self.model_manager.get_backend()
-                        voice_tensor = await self.voice_manager.load_voice(voice, device=backend.device)
+            # Get backend and load voice
+            backend = self.model_manager.get_backend()
+            voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device)
+
+            # Get all chunks upfront
+            chunks = list(chunker.split_text(text))
+            if not chunks:
+                raise ValueError("No text chunks to process")
+
+            # Process chunk with concurrency control
+            async def process_chunk(chunk: str, is_first: bool, is_last: bool) -> Optional[bytes]:
+                async with self._chunk_semaphore:
+                    try:
+                        tokens = process_text(chunk)
+                        if not tokens:
+                            return None

                         # Generate audio
                         chunk_audio = await self.model_manager.generate(
@@ -189,26 +179,38 @@

                         if chunk_audio is not None:
                             # Convert to bytes
-                            chunk_bytes = AudioService.convert_audio(
+                            return AudioService.convert_audio(
                                 chunk_audio,
                                 24000,
                                 output_format,
                                 is_first_chunk=is_first,
                                 normalizer=stream_normalizer,
-                                is_last_chunk=(next_chunk is None),
+                                is_last_chunk=is_last,
                                 stream=True
                             )
-                        yield chunk_bytes
-                        is_first = False
-
                     except Exception as e:
-                        logger.error(f"Failed to generate audio for chunk: '{current_chunk}'. Error: {str(e)}")
+                        logger.error(f"Failed to generate audio for chunk: '{chunk}'. Error: {str(e)}")
+                        return None

-                current_chunk = next_chunk
+            # Create tasks for all chunks
+            tasks = [
+                process_chunk(chunk, i == 0, i == len(chunks) - 1)
+                for i, chunk in enumerate(chunks)
+            ]
+
+            # Process chunks concurrently and yield results in order
+            for chunk_bytes in await asyncio.gather(*tasks):
+                if chunk_bytes is not None:
+                    yield chunk_bytes

         except Exception as e:
             logger.error(f"Error in audio generation stream: {str(e)}")
             raise
+        finally:
+            # Always clean up voice tensor
+            if voice_tensor is not None:
+                del voice_tensor
+                torch.cuda.empty_cache()

     async def combine_voices(self, voices: List[str]) -> str:
         """Combine multiple voices.
@@ -219,8 +221,7 @@ class TTSService:
         Returns:
             Name of combined voice
         """
-        await self.ensure_initialized()
-        return await self.voice_manager.combine_voices(voices)
+        return await self._voice_manager.combine_voices(voices)

     async def list_voices(self) -> List[str]:
         """List available voices.
@@ -228,7 +229,7 @@
         Returns:
             List of voice names
         """
-        return await self.voice_manager.list_voices()
+        return await self._voice_manager.list_voices()

     def _audio_to_bytes(self, audio: np.ndarray) -> bytes:
         """Convert audio to WAV bytes.
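The rewritten generate_audio and generate_audio_stream above fan chunk synthesis out with asyncio.gather while a class-level Semaphore(4) caps how many chunks run at once, then drop failed (None) chunks and keep the remaining results in their original order. A compact, self-contained sketch of that control flow, with a fake synthesize_chunk standing in for model_manager.generate:

# Sketch only: synthesize_chunk fakes the model call; the real service uses
# model_manager.generate() behind a class-level semaphore instead.
import asyncio
from typing import List, Optional


async def synthesize(text: str, max_concurrency: int = 4) -> List[bytes]:
    semaphore = asyncio.Semaphore(max_concurrency)
    chunks = [c.strip() for c in text.split(".") if c.strip()]

    async def synthesize_chunk(chunk: str) -> Optional[bytes]:
        async with semaphore:          # cap how many chunks run at once
            await asyncio.sleep(0.01)  # stand-in for model inference
            return chunk.encode("utf-8")

    # gather() preserves input order, so the output stays chunk-aligned
    results = await asyncio.gather(*[synthesize_chunk(c) for c in chunks])
    return [r for r in results if r is not None]


if __name__ == "__main__":
    print(asyncio.run(synthesize("First sentence. Second sentence. Third.")))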
@@ -1,60 +0,0 @@
-import os
-from typing import List, Tuple
-
-import torch
-from loguru import logger
-
-from ..core.config import settings
-from .tts_model import TTSModel
-from .tts_service import TTSService
-
-
-class WarmupService:
-    """Service for warming up TTS models and voice caches"""
-
-    def __init__(self):
-        """Initialize warmup service and ensure model is ready"""
-        # Initialize model if not already initialized
-        if TTSModel._instance is None:
-            TTSModel.initialize(settings.model_dir)
-        self.tts_service = TTSService()
-
-    def load_voices(self) -> List[Tuple[str, torch.Tensor]]:
-        """Load and cache voices up to LRU limit"""
-        # Get all voices sorted by filename length (shorter names first, usually base voices)
-        voice_files = sorted(
-            [f for f in os.listdir(TTSModel.VOICES_DIR) if f.endswith(".pt")], key=len
-        )
-
-        n_voices_cache = 1
-        loaded_voices = []
-        for voice_file in voice_files[:n_voices_cache]:
-            try:
-                voice_path = os.path.join(TTSModel.VOICES_DIR, voice_file)
-                # load using service, lru cache
-                voicepack = self.tts_service._load_voice(voice_path)
-                loaded_voices.append(
-                    (voice_file[:-3], voicepack)
-                )  # Store name and tensor
-                # voicepack = torch.load(voice_path, map_location=TTSModel.get_device(), weights_only=True)
-                # logger.info(f"Loaded voice {voice_file[:-3]} into cache")
-            except Exception as e:
-                logger.error(f"Failed to load voice {voice_file}: {e}")
-        logger.info(f"Pre-loaded {len(loaded_voices)} voices into cache")
-        return loaded_voices
-
-    async def warmup_voices(
-        self, warmup_text: str, loaded_voices: List[Tuple[str, torch.Tensor]]
-    ):
-        """Warm up voice inference and streaming"""
-        n_warmups = 1
-        for voice_name, _ in loaded_voices[:n_warmups]:
-            try:
-                logger.info(f"Running warmup inference on voice {voice_name}")
-                async for _ in self.tts_service.generate_audio_stream(
-                    warmup_text, voice_name, 1.0, "pcm"
-                ):
-                    pass  # Process all chunks to properly warm up
-                logger.info(f"Completed warmup for voice {voice_name}")
-            except Exception as e:
-                logger.warning(f"Warmup failed for voice {voice_name}: {e}")
@@ -10,7 +10,7 @@ services:
     ports:
       - "8880:8880"
     environment:
-      - PYTHONPATH=/app
+      - PYTHONPATH=/app:/app/api
      - USE_GPU=true
      - USE_ONNX=false
      - PYTHONUNBUFFERED=1
@@ -25,9 +25,7 @@ def main() -> None:
 def stream_to_speakers() -> None:
     import pyaudio

-    player_stream = pyaudio.PyAudio().open(
-        format=pyaudio.paInt16, channels=1, rate=24000, output=True
-    )
+    player_stream = pyaudio.PyAudio().open(format=pyaudio.paInt16, channels=1, rate=24000, output=True)

     start_time = time.time()
examples/simul_file_test.py (new file, 53 lines)
@@ -0,0 +1,53 @@
+#!/usr/bin/env rye run python
+import asyncio
+import time
+from pathlib import Path
+from openai import AsyncOpenAI
+
+# Initialize async client
+openai = AsyncOpenAI(base_url="http://localhost:8880/v1", api_key="not-needed-for-local")
+
+async def save_to_file(text: str, file_id: int) -> None:
+    """Save TTS output to file asynchronously"""
+    speech_file_path = Path(__file__).parent / f"speech_{file_id}.mp3"
+
+    start_time = time.time()
+    print(f"Starting file {file_id}")
+
+    try:
+        # Use streaming endpoint with mp3 format
+        async with openai.audio.speech.with_streaming_response.create(
+            model="kokoro",
+            voice="af_bella",
+            input=text,
+            response_format="mp3"
+        ) as response:
+            print(f"File {file_id} - Time to first byte: {int((time.time() - start_time) * 1000)}ms")
+
+            # Open file in binary write mode
+            with open(speech_file_path, 'wb') as f:
+                async for chunk in response.iter_bytes():
+                    f.write(chunk)
+
+            print(f"File {file_id} completed in {int((time.time() - start_time) * 1000)}ms")
+    except Exception as e:
+        print(f"Error processing file {file_id}: {e}")
+
+async def main() -> None:
+    # Different text samples for variety
+    texts = [
+        "The quick brown fox jumped over the lazy dogs. I see skies of blue and clouds of white",
+        "I see skies of blue and clouds of white. I see skies of blue and clouds of white",
+    ]
+
+    # Create tasks for saving to files
+    file_tasks = [
+        save_to_file(text, i)
+        for i, text in enumerate(texts)
+    ]
+
+    # Run file tasks concurrently
+    await asyncio.gather(*file_tasks)
+
+if __name__ == "__main__":
+    asyncio.run(main())
examples/simul_openai_streaming_audio.py (new file, 91 lines)
@@ -0,0 +1,91 @@
+#!/usr/bin/env rye run python
+import asyncio
+import time
+from pathlib import Path
+import pyaudio
+from openai import AsyncOpenAI
+
+# Initialize async client
+openai = AsyncOpenAI(base_url="http://localhost:8880/v1", api_key="not-needed-for-local")
+
+# Create a shared PyAudio instance
+p = pyaudio.PyAudio()
+
+async def stream_to_speakers(text: str, stream_id: int) -> None:
+    """Stream TTS audio to speakers asynchronously"""
+    player_stream = p.open(
+        format=pyaudio.paInt16,
+        channels=1,
+        rate=24000,
+        output=True
+    )
+
+    start_time = time.time()
+    print(f"Starting stream {stream_id}")
+
+    try:
+        async with openai.audio.speech.with_streaming_response.create(
+            model="kokoro",
+            voice="af_bella",
+            response_format="pcm",
+            input=text
+        ) as response:
+            print(f"Stream {stream_id} - Time to first byte: {int((time.time() - start_time) * 1000)}ms")
+
+            async for chunk in response.iter_bytes(chunk_size=1024):
+                player_stream.write(chunk)
+                # Small sleep to allow other coroutines to run
+                await asyncio.sleep(0.001)
+
+            print(f"Stream {stream_id} completed in {int((time.time() - start_time) * 1000)}ms")
+
+    finally:
+        player_stream.stop_stream()
+        player_stream.close()
+
+async def save_to_file(text: str, file_id: int) -> None:
+    """Save TTS output to file asynchronously"""
+    speech_file_path = Path(__file__).parent / f"speech_{file_id}.mp3"
+
+    async with openai.audio.speech.with_streaming_response.create(
+        model="kokoro",
+        voice="af_bella",
+        input=text
+    ) as response:
+        # Open file in binary write mode
+        with open(speech_file_path, 'wb') as f:
+            async for chunk in response.iter_bytes():
+                f.write(chunk)
+        print(f"File {file_id} saved to {speech_file_path}")
+
+async def main() -> None:
+    # Different text samples for variety
+    texts = [
+        "The quick brown fox jumped over the lazy dogs. I see skies of blue and clouds of white",
+        "I see skies of blue and clouds of white. I see skies of blue and clouds of white",
+    ]
+
+    # Create tasks for streaming to speakers
+    speaker_tasks = [
+        stream_to_speakers(text, i)
+        for i, text in enumerate(texts)
+    ]
+
+    # Create tasks for saving to files
+    file_tasks = [
+        save_to_file(text, i)
+        for i, text in enumerate(texts)
+    ]
+
+    # Combine all tasks
+    all_tasks = speaker_tasks + file_tasks
+
+    # Run all tasks concurrently
+    try:
+        await asyncio.gather(*all_tasks)
+    finally:
+        # Clean up PyAudio
+        p.terminate()
+
+if __name__ == "__main__":
+    asyncio.run(main())
examples/simul_speaker_test.py (new file, 66 lines)
@@ -0,0 +1,66 @@
+#!/usr/bin/env rye run python
+import asyncio
+import time
+import pyaudio
+from openai import AsyncOpenAI
+
+# Initialize async client
+openai = AsyncOpenAI(base_url="http://localhost:8880/v1", api_key="not-needed-for-local")
+
+# Create a shared PyAudio instance
+p = pyaudio.PyAudio()
+
+async def stream_to_speakers(text: str, stream_id: int) -> None:
+    """Stream TTS audio to speakers asynchronously"""
+    player_stream = p.open(
+        format=pyaudio.paInt16,
+        channels=1,
+        rate=24000,
+        output=True
+    )
+
+    start_time = time.time()
+    print(f"Starting stream {stream_id}")
+
+    try:
+        async with openai.audio.speech.with_streaming_response.create(
+            model="kokoro",
+            voice="af_bella",
+            response_format="pcm",
+            input=text
+        ) as response:
+            print(f"Stream {stream_id} - Time to first byte: {int((time.time() - start_time) * 1000)}ms")
+
+            async for chunk in response.iter_bytes(chunk_size=1024):
+                player_stream.write(chunk)
+                # Small sleep to allow other coroutines to run
+                await asyncio.sleep(0.001)
+
+            print(f"Stream {stream_id} completed in {int((time.time() - start_time) * 1000)}ms")
+
+    finally:
+        player_stream.stop_stream()
+        player_stream.close()
+
+async def main() -> None:
+    # Different text samples for variety
+    texts = [
+        "The quick brown fox jumped over the lazy dogs. I see skies of blue and clouds of white",
+        "I see skies of blue and clouds of white. I see skies of blue and clouds of white",
+    ]
+
+    # Create tasks for streaming to speakers
+    speaker_tasks = [
+        stream_to_speakers(text, i)
+        for i, text in enumerate(texts)
+    ]
+
+    # Run speaker tasks concurrently
+    try:
+        await asyncio.gather(*speaker_tasks)
+    finally:
+        # Clean up PyAudio
+        p.terminate()
+
+if __name__ == "__main__":
+    asyncio.run(main())
Binary file not shown.