Mirror of https://github.com/remsky/Kokoro-FastAPI.git (synced 2025-04-13 09:39:17 +00:00)

Commit 4c90a89545: Initial commit of Kokoro V1.0-only integration
Parent: 903bf91c81
35 changed files with 4560 additions and 1501 deletions
VERSION (2 changes)

@@ -1 +1 @@
-0.1.4
+v0.1.5-pre
@@ -12,14 +12,14 @@ class Settings(BaseSettings):
     # Application Settings
     output_dir: str = "output"
     output_dir_size_limit_mb: float = 500.0  # Maximum size of output directory in MB
-    default_voice: str = "af"
+    default_voice: str = "af_heart"
     use_gpu: bool = True  # Whether to use GPU acceleration if available
     use_onnx: bool = False  # Whether to use ONNX runtime
     allow_local_voice_saving: bool = False  # Whether to allow saving combined voices locally

     # Container absolute paths
     model_dir: str = "/app/api/src/models"  # Absolute path in container
-    voices_dir: str = "/app/api/src/voices"  # Absolute path in container
+    voices_dir: str = "/app/api/src/voices/v1_0"  # Absolute path in container

     # Audio Settings
     sample_rate: int = 24000
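Since `Settings` extends pydantic's `BaseSettings`, each field in the hunk above can be overridden through environment variables at container start. The sketch below is illustrative only and not part of this commit; the environment-variable names follow pydantic's default field-name mapping, and the import path depends on whether pydantic v1 or the separate pydantic-settings package is in use.

```python
# Illustrative sketch (not from this commit): overriding Settings fields via
# environment variables using pydantic's default field-name -> env-var mapping.
import os

os.environ["DEFAULT_VOICE"] = "af_heart"  # assumed env name for default_voice
os.environ["USE_GPU"] = "false"           # assumed env name for use_gpu

from pydantic import BaseSettings  # pydantic v1; v2 uses pydantic_settings.BaseSettings


class Settings(BaseSettings):
    default_voice: str = "af_heart"
    use_gpu: bool = True


settings = Settings()
print(settings.default_voice, settings.use_gpu)  # -> af_heart False
```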
@@ -1,47 +1,14 @@
-"""Model configuration schemas."""
+"""Model configuration for Kokoro V1."""

 from pydantic import BaseModel, Field


+class KokoroV1Config(BaseModel):
+    """Kokoro V1 configuration."""
+    languages: list[str] = ["en"]
+
+
-class ONNXCPUConfig(BaseModel):
-    """ONNX CPU runtime configuration."""
-
-    # Session pooling
-    max_instances: int = Field(4, description="Maximum concurrent model instances")
-    instance_timeout: int = Field(60, description="Session timeout in seconds")
-
-    # Runtime settings
-    num_threads: int = Field(8, description="Number of threads for parallel operations")
-    inter_op_threads: int = Field(4, description="Number of threads for operator parallelism")
-    execution_mode: str = Field("parallel", description="ONNX execution mode")
-    optimization_level: str = Field("all", description="ONNX optimization level")
-    memory_pattern: bool = Field(True, description="Enable memory pattern optimization")
-    arena_extend_strategy: str = Field("kNextPowerOfTwo", description="Memory arena strategy")
-
-    class Config:
-        frozen = True
-
-
-class ONNXGPUConfig(ONNXCPUConfig):
-    """ONNX GPU-specific configuration."""
-
-    # CUDA settings
-    device_id: int = Field(0, description="CUDA device ID")
-    gpu_mem_limit: float = Field(0.5, description="Fraction of GPU memory to use")
-    cudnn_conv_algo_search: str = Field("EXHAUSTIVE", description="CuDNN convolution algorithm search")
-
-    # Stream management
-    cuda_streams: int = Field(2, description="Number of CUDA streams for inference")
-    stream_timeout: int = Field(60, description="Stream timeout in seconds")
-    do_copy_in_default_stream: bool = Field(True, description="Copy in default CUDA stream")
-
-    class Config:
-        frozen = True
-
-
 class PyTorchCPUConfig(BaseModel):
     """PyTorch CPU backend configuration."""

@@ -70,48 +37,23 @@ class PyTorchGPUConfig(BaseModel):


 class ModelConfig(BaseModel):
-    """Model configuration."""
+    """Kokoro V1 model configuration."""

     # General settings
-    model_type: str = Field("pytorch", description="Model type ('pytorch' or 'onnx')")
-    device_type: str = Field("auto", description="Device type ('cpu', 'gpu', or 'auto')")
-    cache_models: bool = Field(True, description="Whether to cache loaded models")
+    device_type: str = Field("cpu", description="Device type ('cpu' or 'gpu')")
     cache_voices: bool = Field(True, description="Whether to cache voice tensors")
     voice_cache_size: int = Field(2, description="Maximum number of cached voices")

-    # Model filenames
+    # Model filename
     pytorch_kokoro_v1_file: str = Field("v1_0/kokoro-v1_0.pth", description="PyTorch Kokoro V1 model filename")
-    pytorch_model_file: str = Field("kokoro-v0_19-half.pth", description="PyTorch model filename")
-    onnx_model_file: str = Field("kokoro-v0_19.onnx", description="ONNX model filename")

-    # Backend-specific configs
-    onnx_cpu: ONNXCPUConfig = Field(default_factory=ONNXCPUConfig)
-    onnx_gpu: ONNXGPUConfig = Field(default_factory=ONNXGPUConfig)
+    # Backend configs
     pytorch_cpu: PyTorchCPUConfig = Field(default_factory=PyTorchCPUConfig)
     pytorch_gpu: PyTorchGPUConfig = Field(default_factory=PyTorchGPUConfig)

     class Config:
         frozen = True

-    def get_backend_config(self, backend_type: str):
-        """Get configuration for specific backend.
-
-        Args:
-            backend_type: Backend type ('pytorch_cpu', 'pytorch_gpu', 'onnx_cpu', 'onnx_gpu', 'kokoro_v1')
-
-        Returns:
-            Backend-specific configuration
-
-        Raises:
-            ValueError: If backend type is invalid
-        """
-        if backend_type not in {
-            'pytorch_cpu', 'pytorch_gpu', 'onnx_cpu', 'onnx_gpu', 'kokoro_v1'
-        }:
-            raise ValueError(f"Invalid backend type: {backend_type}")
-
-        return getattr(self, backend_type)


 # Global instance
 model_config = ModelConfig()
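The `frozen = True` inner `Config` makes `ModelConfig` immutable after construction, and the module-level `model_config` instance is what the rest of the code imports. A minimal sketch of that behaviour, illustrative rather than part of the commit (only two of the fields are reproduced):

```python
# Illustrative sketch: the frozen pydantic config rejects mutation after creation.
from pydantic import BaseModel, Field


class ModelConfig(BaseModel):
    device_type: str = Field("cpu", description="Device type ('cpu' or 'gpu')")
    pytorch_kokoro_v1_file: str = Field("v1_0/kokoro-v1_0.pth", description="PyTorch Kokoro V1 model filename")

    class Config:
        frozen = True


model_config = ModelConfig()
print(model_config.pytorch_kokoro_v1_file)  # -> v1_0/kokoro-v1_0.pth

try:
    model_config.device_type = "gpu"  # frozen model: assignment is rejected
except Exception as e:  # TypeError on pydantic v1, ValidationError on v2
    print(f"rejected: {e}")
```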
@@ -127,7 +127,7 @@ async def get_voice_path(voice_name: str) -> str:
     voice_file = f"{voice_name}.pt"

-    # Search in voice directory
+    # Search in voice directory/o
     search_paths = [voice_dir]
     logger.debug(f"Searching for voice in path: {voice_dir}")

@@ -2,15 +2,11 @@
 from .base import BaseModelBackend
 from .model_manager import ModelManager, get_manager
-from .onnx_cpu import ONNXCPUBackend
-from .onnx_gpu import ONNXGPUBackend
-from .pytorch_backend import PyTorchBackend
 from .kokoro_v1 import KokoroV1

 __all__ = [
     'BaseModelBackend',
     'ModelManager',
     'get_manager',
-    'ONNXCPUBackend',
-    'ONNXGPUBackend',
-    'PyTorchBackend',
     'KokoroV1',
 ]
@@ -1,14 +1,14 @@
-"""Base interfaces for model inference."""
+"""Base interface for Kokoro inference."""

 from abc import ABC, abstractmethod
-from typing import List, Optional
+from typing import AsyncGenerator, Optional, Union, Tuple

 import numpy as np
 import torch


 class ModelBackend(ABC):
-    """Abstract base class for model inference backends."""
+    """Abstract base class for model inference backend."""

     @abstractmethod
     async def load_model(self, path: str) -> None:
@@ -23,21 +23,21 @@ class ModelBackend(ABC):
         pass

     @abstractmethod
-    def generate(
+    async def generate(
         self,
-        tokens: List[int],
-        voice: torch.Tensor,
+        text: str,
+        voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
         speed: float = 1.0
-    ) -> np.ndarray:
-        """Generate audio from tokens.
+    ) -> AsyncGenerator[np.ndarray, None]:
+        """Generate audio from text.

         Args:
-            tokens: Input token IDs
-            voice: Voice embedding tensor
+            text: Input text to synthesize
+            voice: Either a voice path or tuple of (name, tensor/path)
             speed: Speed multiplier

-        Returns:
-            Generated audio samples
+        Yields:
+            Generated audio chunks

         Raises:
             RuntimeError: If generation fails
@@ -95,3 +95,4 @@ class BaseModelBackend(ModelBackend):
         self._model = None
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
+            torch.cuda.synchronize()
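With `generate` now an async generator that yields audio chunks, callers stream audio with `async for` instead of waiting on a single array. A hedged consumer sketch of that interface (the `backend` object and the final `concatenate` step are illustrative, not taken from the commit):

```python
# Illustrative consumer of the new async-generator interface (not part of the commit).
import asyncio
import numpy as np


async def synthesize(backend, text: str, voice_path: str, speed: float = 1.0) -> np.ndarray:
    """Collect streamed chunks from a ModelBackend.generate() call into one waveform."""
    chunks = []
    async for chunk in backend.generate(text, voice_path, speed=speed):
        chunks.append(chunk)  # each chunk is a numpy array of audio samples
    return np.concatenate(chunks) if chunks else np.array([], dtype=np.float32)

# audio = asyncio.run(synthesize(backend, "Hello world", "voices/af_heart.pt"))
```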
@@ -1,9 +1,7 @@
|
|||
"""PyTorch inference backend with environment-based configuration."""
|
||||
"""Clean Kokoro implementation with controlled resource management."""
|
||||
|
||||
import gc
|
||||
import os
|
||||
from typing import AsyncGenerator, Optional, List, Union, Tuple
|
||||
from contextlib import nullcontext
|
||||
from typing import AsyncGenerator, Optional, Union, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@@ -15,23 +13,19 @@ from ..core.config import settings
|
|||
from .base import BaseModelBackend
|
||||
from kokoro import KModel, KPipeline
|
||||
|
||||
|
||||
class KokoroV1(BaseModelBackend):
|
||||
"""Kokoro package based inference backend with environment-based configuration."""
|
||||
"""Kokoro backend with controlled resource management."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize backend based on environment configuration."""
|
||||
"""Initialize backend with environment-based configuration."""
|
||||
super().__init__()
|
||||
|
||||
# Configure device based on settings
|
||||
self._device = (
|
||||
"cuda" if settings.use_gpu and torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
# Strictly respect settings.use_gpu
|
||||
self._device = "cuda" if settings.use_gpu else "cpu"
|
||||
self._model: Optional[KModel] = None
|
||||
self._pipeline: Optional[KPipeline] = None
|
||||
|
||||
async def load_model(self, path: str) -> None:
|
||||
"""Load Kokoro model.
|
||||
"""Load pre-baked model.
|
||||
|
||||
Args:
|
||||
path: Path to model file
|
||||
|
@@ -42,8 +36,6 @@ class KokoroV1(BaseModelBackend):
|
|||
try:
|
||||
# Get verified model path
|
||||
model_path = await paths.get_model_path(path)
|
||||
|
||||
# Get config.json path from the same directory
|
||||
config_path = os.path.join(os.path.dirname(model_path), 'config.json')
|
||||
|
||||
if not os.path.exists(config_path):
|
||||
|
@@ -53,22 +45,36 @@
|
|||
logger.info(f"Config path: {config_path}")
|
||||
logger.info(f"Model path: {model_path}")
|
||||
|
||||
# Initialize model with config and weights
|
||||
self._model = KModel(config=config_path, model=model_path).to(self._device).eval()
|
||||
# Initialize pipeline with American English by default
|
||||
self._pipeline = KPipeline(lang_code='a', model=self._model, device=self._device)
|
||||
# Load model and let KModel handle device mapping
|
||||
self._model = KModel(
|
||||
config=config_path,
|
||||
model=model_path
|
||||
).eval()
|
||||
# Move to CUDA if needed
|
||||
if self._device == "cuda":
|
||||
self._model = self._model.cuda()
|
||||
|
||||
# Initialize pipeline with our model and device
|
||||
self._pipeline = KPipeline(
|
||||
lang_code='a',
|
||||
model=self._model, # Pass our model directly
|
||||
device=self._device # Match our device setting
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load Kokoro model: {e}")
|
||||
|
||||
async def generate(
|
||||
self, text: str, voice: Union[str, Tuple[str, Union[torch.Tensor, str]]], speed: float = 1.0
|
||||
self,
|
||||
text: str,
|
||||
voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
|
||||
speed: float = 1.0
|
||||
) -> AsyncGenerator[np.ndarray, None]:
|
||||
"""Generate audio using model.
|
||||
|
||||
Args:
|
||||
text: Input text to synthesize
|
||||
voice: Either a voice path string or a tuple of (voice_name, voice_tensor_or_path)
|
||||
voice: Either a voice path string or a tuple of (voice_name, voice_tensor/path)
|
||||
speed: Speed multiplier
|
||||
|
||||
Yields:
|
||||
|
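The `load_model` hunk above derives `config.json` from the weights file's directory and builds `KModel`/`KPipeline` explicitly so nothing is fetched at request time. A condensed sketch mirroring those calls (the paths are placeholders assembled from the settings shown earlier, and the `KModel`/`KPipeline` signatures are taken from the diff itself, not verified against a specific kokoro release):

```python
# Condensed sketch of the loading flow shown in the hunk above (paths are placeholders).
import os
import torch
from kokoro import KModel, KPipeline

model_path = "/app/api/src/models/v1_0/kokoro-v1_0.pth"
config_path = os.path.join(os.path.dirname(model_path), "config.json")  # same dir as weights

device = "cuda" if torch.cuda.is_available() else "cpu"
model = KModel(config=config_path, model=model_path).eval()
if device == "cuda":
    model = model.cuda()

# 'a' selects American English; the pipeline reuses the already-loaded model.
pipeline = KPipeline(lang_code="a", model=model, device=device)
```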
@@ -87,53 +93,38 @@
|
|||
self._clear_memory()
|
||||
|
||||
# Handle voice input
|
||||
if isinstance(voice, str):
|
||||
voice_path = voice # Voice path provided directly
|
||||
logger.debug(f"Using voice path directly: {voice_path}")
|
||||
# Get language code from first letter of voice name
|
||||
try:
|
||||
name = os.path.basename(voice_path)
|
||||
logger.debug(f"Voice basename: {name}")
|
||||
if name.endswith('.pt'):
|
||||
name = name[:-3]
|
||||
lang_code = name[0]
|
||||
logger.debug(f"Extracted language code: {lang_code}")
|
||||
except Exception as e:
|
||||
# Default to American English if we can't get language code
|
||||
logger.warning(f"Failed to extract language code: {e}, defaulting to 'a'")
|
||||
lang_code = 'a'
|
||||
else:
|
||||
# Unpack voice name and tensor/path
|
||||
voice_path: str
|
||||
if isinstance(voice, tuple):
|
||||
voice_name, voice_data = voice
|
||||
# If voice_data is a path, use it directly
|
||||
if isinstance(voice_data, str):
|
||||
voice_path = voice_data
|
||||
logger.debug(f"Using provided voice path: {voice_path}")
|
||||
else:
|
||||
# Save tensor to temporary file
|
||||
import tempfile
|
||||
temp_dir = tempfile.gettempdir()
|
||||
voice_path = os.path.join(temp_dir, f"{voice_name}.pt")
|
||||
logger.debug(f"Saving voice tensor to: {voice_path}")
|
||||
torch.save(voice_data, voice_path)
|
||||
# Get language code from voice name
|
||||
lang_code = voice_name[0]
|
||||
logger.debug(f"Using language code '{lang_code}' from voice name {voice_name}")
|
||||
# Save tensor with CPU mapping for portability
|
||||
torch.save(voice_data.cpu(), voice_path)
|
||||
else:
|
||||
voice_path = voice
|
||||
|
||||
# Update pipeline's language code if needed
|
||||
if self._pipeline.lang_code != lang_code:
|
||||
logger.debug(f"Creating pipeline with lang_code='{lang_code}'")
|
||||
self._pipeline = KPipeline(lang_code=lang_code, model=self._model, device=self._device)
|
||||
# Load voice tensor with proper device mapping
|
||||
voice_tensor = await paths.load_voice_tensor(voice_path, device=self._device)
|
||||
# Save back to a temporary file with proper device mapping
|
||||
import tempfile
|
||||
temp_dir = tempfile.gettempdir()
|
||||
temp_path = os.path.join(temp_dir, f"temp_voice_{os.path.basename(voice_path)}")
|
||||
await paths.save_voice_tensor(voice_tensor, temp_path)
|
||||
voice_path = temp_path
|
||||
|
||||
# Generate audio using pipeline
|
||||
# Generate using pipeline, force model to prevent downloads
|
||||
logger.debug(f"Generating audio for text: '{text[:100]}...'")
|
||||
for i, result in enumerate(self._pipeline(text, voice=voice_path, speed=speed)):
|
||||
logger.debug(f"Processing chunk {i+1}")
|
||||
for result in self._pipeline(text, voice=voice_path, speed=speed, model=self._model):
|
||||
if result.audio is not None:
|
||||
logger.debug(f"Got audio chunk {i+1} with shape: {result.audio.shape}")
|
||||
logger.debug(f"Got audio chunk with shape: {result.audio.shape}")
|
||||
yield result.audio.numpy()
|
||||
else:
|
||||
logger.warning(f"No audio in chunk {i+1}")
|
||||
logger.warning("No audio in chunk")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Generation failed: {e}")
|
||||
|
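When the caller passes a `(voice_name, tensor)` tuple rather than a path, the hunk above writes the tensor to a temporary `.pt` file (CPU-mapped for portability) so `KPipeline` can consume it as a voice path. A small illustrative sketch of that round trip, with placeholder names and a placeholder tensor standing in for a real voice pack:

```python
# Illustrative round trip: persist an in-memory voice tensor so it can be passed by path.
import os
import tempfile
import torch

voice_name = "af_heart"                    # placeholder voice name
voice_data = torch.randn(510, 1, 256)      # placeholder tensor, not a real voice pack

voice_path = os.path.join(tempfile.gettempdir(), f"{voice_name}.pt")
torch.save(voice_data.cpu(), voice_path)   # CPU mapping keeps the file loadable on any device

restored = torch.load(voice_path, map_location="cpu")
assert torch.equal(restored, voice_data.cpu())
```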
@@ -146,9 +137,6 @@ class KokoroV1(BaseModelBackend):
|
|||
async for chunk in self.generate(text, voice, speed):
|
||||
yield chunk
|
||||
raise
|
||||
finally:
|
||||
if self._device == "cuda" and model_config.pytorch_gpu.sync_cuda:
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def _check_memory(self) -> bool:
|
||||
"""Check if memory usage is above threshold."""
|
||||
|
@@ -161,7 +149,7 @@ class KokoroV1(BaseModelBackend):
|
|||
"""Clear device memory."""
|
||||
if self._device == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def unload(self) -> None:
|
||||
"""Unload model and free resources."""
|
||||
|
@@ -173,7 +161,7 @@ class KokoroV1(BaseModelBackend):
|
|||
self._pipeline = None
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
@property
|
||||
def is_loaded(self) -> bool:
|
||||
|
|
|
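Throughout the backend above, the pipeline language is inferred from the first letter of the voice name (for example `af_heart` implies `a`, American English), falling back to `'a'` when extraction fails. A tiny sketch of that convention; the helper name is invented for illustration:

```python
# Illustrative helper for the voice-name -> lang_code convention used above.
import os


def lang_code_from_voice(voice_path: str) -> str:
    """Return the pipeline language code implied by a voice file name."""
    name = os.path.basename(voice_path)
    if name.endswith(".pt"):
        name = name[:-3]
    return name[0] if name else "a"  # default to American English


print(lang_code_from_voice("/app/api/src/voices/v1_0/af_heart.pt"))  # -> 'a'
```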
@@ -1,377 +1,156 @@
|
|||
"""Model management and caching."""
|
||||
"""Kokoro V1 model management."""
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, Optional, Tuple, Union, AsyncGenerator
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from ..core import paths
|
||||
from ..core.config import settings
|
||||
from ..core.model_config import ModelConfig, model_config
|
||||
from ..core.config import settings
|
||||
from .base import BaseModelBackend
|
||||
from .onnx_cpu import ONNXCPUBackend
|
||||
from .onnx_gpu import ONNXGPUBackend
|
||||
from .pytorch_backend import PyTorchBackend
|
||||
from .kokoro_v1 import KokoroV1
|
||||
from .session_pool import CPUSessionPool, StreamingSessionPool
|
||||
|
||||
|
||||
# Global singleton instance and lock for thread-safe initialization
|
||||
_manager_instance = None
|
||||
_manager_lock = asyncio.Lock()
|
||||
|
||||
class ModelManager:
|
||||
"""Manages model loading and inference across backends."""
|
||||
# Class-level state for shared resources
|
||||
_loaded_models = {}
|
||||
_backends = {}
|
||||
"""Manages Kokoro V1 model loading and inference."""
|
||||
|
||||
# Singleton instance
|
||||
_instance = None
|
||||
|
||||
def __init__(self, config: Optional[ModelConfig] = None):
|
||||
"""Initialize model manager.
|
||||
Note:
|
||||
This should not be called directly. Use get_manager() instead.
|
||||
"""Initialize manager.
|
||||
|
||||
Args:
|
||||
config: Optional model configuration override
|
||||
"""
|
||||
self._config = config or model_config
|
||||
|
||||
# Initialize session pools
|
||||
self._session_pools = {
|
||||
'onnx_cpu': CPUSessionPool(),
|
||||
'onnx_gpu': StreamingSessionPool()
|
||||
}
|
||||
|
||||
# Initialize locks
|
||||
self._backend_locks: Dict[str, asyncio.Lock] = {}
|
||||
self._backend: Optional[KokoroV1] = None # Explicitly type as KokoroV1
|
||||
self._device: Optional[str] = None
|
||||
|
||||
def _determine_device(self) -> str:
|
||||
"""Determine device based on settings."""
|
||||
if settings.use_gpu and torch.cuda.is_available():
|
||||
return "cuda"
|
||||
return "cpu"
|
||||
return "cuda" if settings.use_gpu else "cpu"
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize backends."""
|
||||
if self._backends:
|
||||
logger.debug("Using existing backend instances")
|
||||
return
|
||||
|
||||
device = self._determine_device()
|
||||
|
||||
"""Initialize Kokoro V1 backend."""
|
||||
try:
|
||||
# First check if we should use Kokoro V1
|
||||
if model_config.pytorch_kokoro_v1_file:
|
||||
self._backends['kokoro_v1'] = KokoroV1()
|
||||
self._current_backend = 'kokoro_v1'
|
||||
logger.info(f"Initialized new Kokoro V1 backend on {device}")
|
||||
# Otherwise use legacy backends
|
||||
elif device == "cuda":
|
||||
if settings.use_onnx:
|
||||
self._backends['onnx_gpu'] = ONNXGPUBackend()
|
||||
self._current_backend = 'onnx_gpu'
|
||||
logger.info("Initialized new ONNX GPU backend")
|
||||
else:
|
||||
self._backends['pytorch'] = PyTorchBackend()
|
||||
self._current_backend = 'pytorch'
|
||||
logger.info("Initialized new PyTorch backend on GPU")
|
||||
else:
|
||||
if settings.use_onnx:
|
||||
self._backends['onnx_cpu'] = ONNXCPUBackend()
|
||||
self._current_backend = 'onnx_cpu'
|
||||
logger.info("Initialized new ONNX CPU backend")
|
||||
else:
|
||||
self._backends['pytorch'] = PyTorchBackend()
|
||||
self._current_backend = 'pytorch'
|
||||
logger.info("Initialized new PyTorch backend on CPU")
|
||||
|
||||
# Initialize locks for each backend
|
||||
for backend in self._backends:
|
||||
self._backend_locks[backend] = asyncio.Lock()
|
||||
self._device = self._determine_device()
|
||||
logger.info(f"Initializing Kokoro V1 on {self._device}")
|
||||
self._backend = KokoroV1()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize backend: {e}")
|
||||
raise RuntimeError("Failed to initialize backend")
|
||||
raise RuntimeError(f"Failed to initialize Kokoro V1: {e}")
|
||||
|
||||
async def initialize_with_warmup(self, voice_manager) -> tuple[str, str, int]:
|
||||
"""Initialize model with warmup and pre-cache voices.
|
||||
"""Initialize and warm up model.
|
||||
|
||||
Args:
|
||||
voice_manager: Voice manager instance for loading voices
|
||||
voice_manager: Voice manager instance for warmup
|
||||
|
||||
Returns:
|
||||
Tuple of (device type, model type, number of loaded voices)
|
||||
Tuple of (device, backend type, voice count)
|
||||
|
||||
Raises:
|
||||
RuntimeError: If initialization fails
|
||||
"""
|
||||
import time
|
||||
start = time.perf_counter()
|
||||
|
||||
try:
|
||||
# First check if we should use Kokoro V1
|
||||
if model_config.pytorch_kokoro_v1_file:
|
||||
backend_type = 'kokoro_v1'
|
||||
# Otherwise determine legacy backend type
|
||||
elif settings.use_onnx:
|
||||
backend_type = 'onnx_gpu' if settings.use_gpu and torch.cuda.is_available() else 'onnx_cpu'
|
||||
else:
|
||||
backend_type = 'pytorch'
|
||||
|
||||
# Get backend
|
||||
backend = self.get_backend(backend_type)
|
||||
|
||||
# Get and verify model path
|
||||
if backend_type == 'kokoro_v1':
|
||||
model_file = model_config.pytorch_kokoro_v1_file
|
||||
else:
|
||||
model_file = model_config.pytorch_model_file if not settings.use_onnx else model_config.onnx_model_file
|
||||
model_path = await paths.get_model_path(model_file)
|
||||
|
||||
if not await paths.verify_model_path(model_path):
|
||||
raise RuntimeError(f"Model file not found: {model_path}")
|
||||
|
||||
# Pre-cache default voice and use for warmup
|
||||
warmup_voice_tensor = await voice_manager.load_voice(
|
||||
settings.default_voice, device=backend.device)
|
||||
logger.info(f"Pre-cached voice {settings.default_voice} for warmup")
|
||||
|
||||
# For Kokoro V1, wrap voice in tuple with name
|
||||
if isinstance(backend, KokoroV1):
|
||||
warmup_voice = (settings.default_voice, warmup_voice_tensor)
|
||||
else:
|
||||
warmup_voice = warmup_voice_tensor
|
||||
|
||||
# Initialize model with warmup voice
|
||||
await self.load_model(model_path, warmup_voice, backend_type)
|
||||
|
||||
# Only pre-cache default voice to avoid memory bloat
|
||||
logger.info(f"Using {settings.default_voice} as warmup voice")
|
||||
|
||||
# Get available voices count
|
||||
voices = await voice_manager.list_voices()
|
||||
voicepack_count = len(voices)
|
||||
|
||||
# Get device info for return
|
||||
device = "GPU" if settings.use_gpu else "CPU"
|
||||
model = "Kokoro V1" if backend_type == 'kokoro_v1' else ("ONNX" if settings.use_onnx else "PyTorch")
|
||||
|
||||
return device, model, voicepack_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize model with warmup: {e}")
|
||||
raise RuntimeError(f"Failed to initialize model with warmup: {e}")
|
||||
|
||||
def get_backend(self, backend_type: Optional[str] = None) -> BaseModelBackend:
|
||||
"""Get specified backend.
|
||||
Args:
|
||||
backend_type: Backend type ('pytorch_cpu', 'pytorch_gpu', 'onnx_cpu', 'onnx_gpu', 'kokoro_v1'),
|
||||
uses default if None
|
||||
Returns:
|
||||
Model backend instance
|
||||
Raises:
|
||||
ValueError: If backend type is invalid
|
||||
RuntimeError: If no backends are available
|
||||
"""
|
||||
if not self._backends:
|
||||
raise RuntimeError("No backends available")
|
||||
|
||||
if backend_type is None:
|
||||
backend_type = self._current_backend
|
||||
|
||||
if backend_type not in self._backends:
|
||||
raise ValueError(
|
||||
f"Invalid backend type: {backend_type}. "
|
||||
f"Available backends: {', '.join(self._backends.keys())}"
|
||||
)
|
||||
|
||||
return self._backends[backend_type]
|
||||
|
||||
def _determine_backend(self, model_path: str) -> str:
|
||||
"""Determine appropriate backend based on model file and settings.
|
||||
Args:
|
||||
model_path: Path to model file
|
||||
Returns:
|
||||
Backend type to use
|
||||
"""
|
||||
# Check if it's a Kokoro V1 model
|
||||
if model_path.endswith(model_config.pytorch_kokoro_v1_file):
|
||||
return 'kokoro_v1'
|
||||
# Otherwise use legacy backend determination
|
||||
elif settings.use_onnx or model_path.lower().endswith('.onnx'):
|
||||
return 'onnx_gpu' if settings.use_gpu and torch.cuda.is_available() else 'onnx_cpu'
|
||||
return 'pytorch'
|
||||
|
||||
async def load_model(
|
||||
self,
|
||||
model_path: str,
|
||||
warmup_voice: Optional[Union[str, Tuple[str, torch.Tensor]]] = None,
|
||||
backend_type: Optional[str] = None
|
||||
) -> None:
|
||||
"""Load model on specified backend.
|
||||
Args:
|
||||
model_path: Path to model file
|
||||
warmup_voice: Optional voice tensor for warmup, skips warmup if None
|
||||
backend_type: Backend to load on, uses default if None
|
||||
Raises:
|
||||
RuntimeError: If model loading fails
|
||||
"""
|
||||
try:
|
||||
# Get absolute model path
|
||||
abs_path = await paths.get_model_path(model_path)
|
||||
|
||||
# Auto-determine backend if not specified
|
||||
if backend_type is None:
|
||||
backend_type = self._determine_backend(abs_path)
|
||||
|
||||
# Get backend lock
|
||||
lock = self._backend_locks[backend_type]
|
||||
|
||||
async with lock:
|
||||
backend = self.get_backend(backend_type)
|
||||
|
||||
# For ONNX backends, use session pool
|
||||
if backend_type.startswith('onnx'):
|
||||
pool = self._session_pools[backend_type]
|
||||
backend._session = await pool.get_session(abs_path)
|
||||
self._loaded_models[backend_type] = abs_path
|
||||
logger.info(f"Fetched model instance from {backend_type} pool")
|
||||
|
||||
# For PyTorch and Kokoro backends, load normally
|
||||
else:
|
||||
# Check if model is already loaded
|
||||
if (backend_type in self._loaded_models and
|
||||
self._loaded_models[backend_type] == abs_path and
|
||||
backend.is_loaded):
|
||||
logger.info(f"Fetching existing model instance from {backend_type}")
|
||||
return
|
||||
# Initialize backend
|
||||
await self.initialize()
|
||||
|
||||
# Load model
|
||||
await backend.load_model(abs_path)
|
||||
self._loaded_models[backend_type] = abs_path
|
||||
logger.info(f"Initialized new model instance on {backend_type}")
|
||||
model_path = self._config.pytorch_kokoro_v1_file
|
||||
await self.load_model(model_path)
|
||||
|
||||
# Run warmup if voice provided
|
||||
if warmup_voice is not None:
|
||||
await self._warmup_inference(backend, warmup_voice)
|
||||
# Use paths module to get voice path
|
||||
try:
|
||||
voices = await paths.list_voices()
|
||||
voice_path = await paths.get_voice_path(
|
||||
settings.default_voice)
|
||||
|
||||
# Warm up with short text
|
||||
warmup_text = "Warmup text for initialization."
|
||||
async for _ in self.generate(warmup_text, voice_path):
|
||||
pass
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to get default voice: {e}")
|
||||
|
||||
ms = int((time.perf_counter() - start) * 1000)
|
||||
logger.info(f"Warmup completed in {ms}ms")
|
||||
|
||||
return self._device, "kokoro_v1", len(voices)
|
||||
|
||||
except Exception as e:
|
||||
# Clear cached path on failure
|
||||
self._loaded_models.pop(backend_type, None)
|
||||
raise RuntimeError(f"Warmup failed: {e}")
|
||||
|
||||
def get_backend(self) -> BaseModelBackend:
|
||||
"""Get initialized backend.
|
||||
|
||||
Returns:
|
||||
Initialized backend instance
|
||||
|
||||
Raises:
|
||||
RuntimeError: If backend not initialized
|
||||
"""
|
||||
if not self._backend:
|
||||
raise RuntimeError("Backend not initialized")
|
||||
return self._backend
|
||||
|
||||
async def load_model(self, path: str) -> None:
|
||||
"""Load model using initialized backend.
|
||||
|
||||
Args:
|
||||
path: Path to model file
|
||||
|
||||
Raises:
|
||||
RuntimeError: If loading fails
|
||||
"""
|
||||
if not self._backend:
|
||||
raise RuntimeError("Backend not initialized")
|
||||
|
||||
try:
|
||||
await self._backend.load_model(path)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load model: {e}")
|
||||
|
||||
async def _warmup_inference(
|
||||
self,
|
||||
backend: BaseModelBackend,
|
||||
voice: Union[str, Tuple[str, torch.Tensor]]
|
||||
) -> None:
|
||||
"""Run warmup inference to initialize model.
|
||||
|
||||
Args:
|
||||
backend: Model backend to warm up
|
||||
voice: Voice path or (name, tensor) tuple
|
||||
"""
|
||||
try:
|
||||
# Use real text for warmup
|
||||
text = "Testing text to speech synthesis."
|
||||
|
||||
# Run inference
|
||||
if isinstance(backend, KokoroV1):
|
||||
async for _ in backend.generate(text, voice, speed=1.0):
|
||||
pass # Just run through the chunks
|
||||
else:
|
||||
# Import here to avoid circular imports
|
||||
from ..services.text_processing import process_text
|
||||
tokens = process_text(text)
|
||||
if not tokens:
|
||||
raise ValueError("Text processing failed")
|
||||
# For legacy backends, extract tensor if needed
|
||||
voice_tensor = voice[1] if isinstance(voice, tuple) else voice
|
||||
backend.generate(tokens, voice_tensor, speed=1.0)
|
||||
logger.debug("Completed warmup inference")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Warmup inference failed: {e}")
|
||||
raise
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
input_text: str,
|
||||
voice: Union[str, Tuple[str, torch.Tensor]],
|
||||
speed: float = 1.0,
|
||||
backend_type: Optional[str] = None
|
||||
) -> AsyncGenerator[np.ndarray, None]:
|
||||
"""Generate audio using specified backend.
|
||||
|
||||
Args:
|
||||
input_text: Input text to synthesize
|
||||
voice: Voice path or (name, tensor) tuple
|
||||
speed: Speed multiplier
|
||||
backend_type: Backend to use, uses default if None
|
||||
|
||||
Yields:
|
||||
Generated audio chunks
|
||||
async def generate(self, *args, **kwargs):
|
||||
"""Generate audio using initialized backend.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If generation fails
|
||||
"""
|
||||
backend = self.get_backend(backend_type)
|
||||
if not backend.is_loaded:
|
||||
raise RuntimeError("Model not loaded")
|
||||
if not self._backend:
|
||||
raise RuntimeError("Backend not initialized")
|
||||
|
||||
try:
|
||||
# Generate audio using provided voice
|
||||
# No lock needed here since inference is thread-safe
|
||||
if isinstance(backend, KokoroV1):
|
||||
async for chunk in backend.generate(input_text, voice, speed):
|
||||
async for chunk in self._backend.generate(*args, **kwargs):
|
||||
yield chunk
|
||||
else:
|
||||
# Import here to avoid circular imports
|
||||
from ..services.text_processing import process_text
|
||||
tokens = process_text(input_text)
|
||||
if not tokens:
|
||||
raise ValueError("Text processing failed")
|
||||
# For legacy backends, extract tensor if needed
|
||||
voice_tensor = voice[1] if isinstance(voice, tuple) else voice
|
||||
yield backend.generate(tokens, voice_tensor, speed)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Generation failed: {e}")
|
||||
|
||||
def unload_all(self) -> None:
|
||||
"""Unload models from all backends and clear cache."""
|
||||
# Clean up session pools
|
||||
for pool in self._session_pools.values():
|
||||
pool.cleanup()
|
||||
|
||||
# Unload all backends
|
||||
for backend in self._backends.values():
|
||||
backend.unload()
|
||||
|
||||
self._loaded_models.clear()
|
||||
logger.info("Unloaded all models and cleared cache")
|
||||
|
||||
@property
|
||||
def available_backends(self) -> list[str]:
|
||||
"""Get list of available backends."""
|
||||
return list(self._backends.keys())
|
||||
"""Unload model and free resources."""
|
||||
if self._backend:
|
||||
self._backend.unload()
|
||||
self._backend = None
|
||||
|
||||
@property
|
||||
def current_backend(self) -> str:
|
||||
"""Get current default backend."""
|
||||
return self._current_backend
|
||||
"""Get current backend type."""
|
||||
return "kokoro_v1"
|
||||
|
||||
|
||||
async def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
|
||||
"""Get global model manager instance.
|
||||
"""Get model manager instance.
|
||||
|
||||
Args:
|
||||
config: Optional model configuration
|
||||
config: Optional configuration override
|
||||
|
||||
Returns:
|
||||
ModelManager instance
|
||||
Thread Safety:
|
||||
This function should be thread-safe. Lemme know if it unravels on you
|
||||
"""
|
||||
global _manager_instance
|
||||
|
||||
# Fast path - return existing instance without lock
|
||||
if _manager_instance is not None:
|
||||
return _manager_instance
|
||||
|
||||
# Slow path - create new instance with lock
|
||||
async with _manager_lock:
|
||||
# Double-check pattern
|
||||
if _manager_instance is None:
|
||||
_manager_instance = ModelManager(config)
|
||||
await _manager_instance.initialize()
|
||||
return _manager_instance
|
||||
if ModelManager._instance is None:
|
||||
ModelManager._instance = ModelManager(config)
|
||||
return ModelManager._instance
|
||||
|
|
|
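The rewritten manager above exposes a much smaller surface: `get_manager()` returns a singleton, `initialize_with_warmup()` loads the V1 model and runs a short warmup utterance, and `generate()` simply streams from the single Kokoro backend. A hedged startup sketch built only from those signatures; the `voice_manager` object is assumed to come from the project's voice management module, and the surrounding app wiring is not shown in this diff:

```python
# Hedged startup sketch using the get_manager()/initialize_with_warmup() signatures above.
import asyncio


async def startup(voice_manager):
    manager = await get_manager()  # singleton ModelManager from the hunk above
    device, backend_type, voice_count = await manager.initialize_with_warmup(voice_manager)
    print(f"Warmed up {backend_type} on {device}; {voice_count} voices available")
    return manager

# asyncio.run(startup(voice_manager))  # voice_manager supplied by the caller
```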
@@ -1,115 +0,0 @@
|
|||
"""CPU-based ONNX inference backend."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
from onnxruntime import InferenceSession
|
||||
|
||||
from ..core import paths
|
||||
from ..core.model_config import model_config
|
||||
from .base import BaseModelBackend
|
||||
from .session_pool import create_session_options, create_provider_options
|
||||
|
||||
|
||||
class ONNXCPUBackend(BaseModelBackend):
|
||||
"""ONNX-based CPU inference backend."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize CPU backend."""
|
||||
super().__init__()
|
||||
self._device = "cpu"
|
||||
self._session: Optional[InferenceSession] = None
|
||||
|
||||
@property
|
||||
def is_loaded(self) -> bool:
|
||||
"""Check if model is loaded."""
|
||||
return self._session is not None
|
||||
|
||||
async def load_model(self, path: str) -> None:
|
||||
"""Load ONNX model.
|
||||
|
||||
Args:
|
||||
path: Path to model file
|
||||
|
||||
Raises:
|
||||
RuntimeError: If model loading fails
|
||||
"""
|
||||
try:
|
||||
# Get verified model path
|
||||
model_path = await paths.get_model_path(path)
|
||||
|
||||
logger.info(f"Loading ONNX model: {model_path}")
|
||||
|
||||
# Configure session
|
||||
options = create_session_options(is_gpu=False)
|
||||
provider_options = create_provider_options(is_gpu=False)
|
||||
|
||||
# Create session
|
||||
self._session = InferenceSession(
|
||||
model_path,
|
||||
sess_options=options,
|
||||
providers=["CPUExecutionProvider"],
|
||||
provider_options=[provider_options]
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load ONNX model: {e}")
|
||||
|
||||
def generate(
|
||||
self,
|
||||
tokens: list[int],
|
||||
voice: torch.Tensor,
|
||||
speed: float = 1.0
|
||||
) -> np.ndarray:
|
||||
"""Generate audio using ONNX model.
|
||||
|
||||
Args:
|
||||
tokens: Input token IDs
|
||||
voice: Voice embedding tensor
|
||||
speed: Speed multiplier
|
||||
|
||||
Returns:
|
||||
Generated audio samples
|
||||
|
||||
Raises:
|
||||
RuntimeError: If generation fails
|
||||
"""
|
||||
if not self.is_loaded:
|
||||
raise RuntimeError("Model not loaded")
|
||||
|
||||
try:
|
||||
# Prepare inputs with start/end tokens
|
||||
tokens_input = np.array([[0, *tokens, 0]], dtype=np.int64) # Add start/end tokens
|
||||
style_input = voice[len(tokens) + 2].numpy() # Adjust index for start/end tokens
|
||||
speed_input = np.full(1, speed, dtype=np.float32)
|
||||
|
||||
# Build base inputs
|
||||
inputs = {
|
||||
"style": style_input,
|
||||
"speed": speed_input
|
||||
}
|
||||
|
||||
# Try both possible token input names #TODO:
|
||||
for token_name in ["tokens", "input_ids"]:
|
||||
try:
|
||||
inputs[token_name] = tokens_input
|
||||
result = self._session.run(None, inputs)
|
||||
return result[0]
|
||||
except Exception:
|
||||
del inputs[token_name]
|
||||
continue
|
||||
|
||||
raise RuntimeError("Model does not accept either 'tokens' or 'input_ids' as input name")
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Generation failed: {e}")
|
||||
|
||||
def unload(self) -> None:
|
||||
"""Unload model and free resources."""
|
||||
if self._session is not None:
|
||||
del self._session
|
||||
self._session = None
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
|
@@ -1,119 +0,0 @@
|
|||
"""GPU-based ONNX inference backend."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
from onnxruntime import InferenceSession
|
||||
|
||||
from ..core import paths
|
||||
from ..core.model_config import model_config
|
||||
from .base import BaseModelBackend
|
||||
from .session_pool import create_session_options, create_provider_options
|
||||
|
||||
|
||||
class ONNXGPUBackend(BaseModelBackend):
|
||||
"""ONNX-based GPU inference backend."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize GPU backend."""
|
||||
super().__init__()
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError("CUDA not available")
|
||||
self._device = "cuda"
|
||||
self._session: Optional[InferenceSession] = None
|
||||
|
||||
# Configure GPU
|
||||
torch.cuda.set_device(model_config.onnx_gpu.device_id)
|
||||
|
||||
@property
|
||||
def is_loaded(self) -> bool:
|
||||
"""Check if model is loaded."""
|
||||
return self._session is not None
|
||||
|
||||
async def load_model(self, path: str) -> None:
|
||||
"""Load ONNX model.
|
||||
|
||||
Args:
|
||||
path: Path to model file
|
||||
|
||||
Raises:
|
||||
RuntimeError: If model loading fails
|
||||
"""
|
||||
try:
|
||||
# Get verified model path
|
||||
model_path = await paths.get_model_path(path)
|
||||
|
||||
logger.info(f"Loading ONNX model on GPU: {model_path}")
|
||||
|
||||
# Configure session
|
||||
options = create_session_options(is_gpu=True)
|
||||
provider_options = create_provider_options(is_gpu=True)
|
||||
|
||||
# Create session with CUDA provider
|
||||
self._session = InferenceSession(
|
||||
model_path,
|
||||
sess_options=options,
|
||||
providers=["CUDAExecutionProvider"],
|
||||
provider_options=[provider_options]
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load ONNX model: {e}")
|
||||
|
||||
def generate(
|
||||
self,
|
||||
tokens: list[int],
|
||||
voice: torch.Tensor,
|
||||
speed: float = 1.0
|
||||
) -> np.ndarray:
|
||||
"""Generate audio using ONNX model.
|
||||
|
||||
Args:
|
||||
tokens: Input token IDs
|
||||
voice: Voice embedding tensor
|
||||
speed: Speed multiplier
|
||||
|
||||
Returns:
|
||||
Generated audio samples
|
||||
|
||||
Raises:
|
||||
RuntimeError: If generation fails
|
||||
"""
|
||||
if not self.is_loaded:
|
||||
raise RuntimeError("Model not loaded")
|
||||
|
||||
try:
|
||||
# Prepare inputs
|
||||
tokens_input = np.array([[0, *tokens, 0]], dtype=np.int64) # Add start/end tokens
|
||||
# Use modulo to ensure index stays within voice tensor bounds
|
||||
style_idx = (len(tokens) + 2) % voice.size(0) # Add 2 for start/end tokens
|
||||
style_input = voice[style_idx].cpu().numpy() # Move to CPU for ONNX
|
||||
speed_input = np.full(1, speed, dtype=np.float32)
|
||||
|
||||
# Run inference
|
||||
result = self._session.run(
|
||||
None,
|
||||
{
|
||||
"tokens": tokens_input,
|
||||
"style": style_input,
|
||||
"speed": speed_input
|
||||
}
|
||||
)
|
||||
|
||||
return result[0]
|
||||
|
||||
except Exception as e:
|
||||
if "out of memory" in str(e).lower():
|
||||
# Clear CUDA cache and retry
|
||||
torch.cuda.empty_cache()
|
||||
return self.generate(tokens, voice, speed)
|
||||
raise RuntimeError(f"Generation failed: {e}")
|
||||
|
||||
def unload(self) -> None:
|
||||
"""Unload model and free resources."""
|
||||
if self._session is not None:
|
||||
del self._session
|
||||
self._session = None
|
||||
torch.cuda.empty_cache()
|
|
@@ -1,244 +0,0 @@
|
|||
"""PyTorch inference backend with environment-based configuration."""
|
||||
|
||||
import gc
|
||||
from typing import Optional
|
||||
from contextlib import nullcontext
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from ..builds.models import build_model
|
||||
from ..core import paths
|
||||
from ..core.model_config import model_config
|
||||
from ..core.config import settings
|
||||
from .base import BaseModelBackend
|
||||
|
||||
|
||||
class CUDAStreamManager:
|
||||
"""CUDA stream manager for GPU operations."""
|
||||
|
||||
def __init__(self, num_streams: int):
|
||||
"""Initialize stream manager.
|
||||
|
||||
Args:
|
||||
num_streams: Number of CUDA streams
|
||||
"""
|
||||
self.streams = [torch.cuda.Stream() for _ in range(num_streams)]
|
||||
self._current = 0
|
||||
|
||||
def get_next_stream(self) -> torch.cuda.Stream:
|
||||
"""Get next available stream.
|
||||
|
||||
Returns:
|
||||
CUDA stream
|
||||
"""
|
||||
stream = self.streams[self._current]
|
||||
self._current = (self._current + 1) % len(self.streams)
|
||||
return stream
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
model: torch.nn.Module,
|
||||
tokens: list[int],
|
||||
ref_s: torch.Tensor,
|
||||
speed: float,
|
||||
stream: Optional[torch.cuda.Stream] = None,
|
||||
) -> np.ndarray:
|
||||
"""Forward pass through model.
|
||||
|
||||
Args:
|
||||
model: PyTorch model
|
||||
tokens: Input tokens
|
||||
ref_s: Reference signal
|
||||
speed: Speed multiplier
|
||||
stream: Optional CUDA stream (GPU only)
|
||||
|
||||
Returns:
|
||||
Generated audio
|
||||
"""
|
||||
device = ref_s.device
|
||||
|
||||
# Use provided stream or default for GPU
|
||||
context = (
|
||||
torch.cuda.stream(stream) if stream and device.type == "cuda" else nullcontext()
|
||||
)
|
||||
|
||||
with context:
|
||||
# Initial tensor setup
|
||||
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
|
||||
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
|
||||
text_mask = length_to_mask(input_lengths).to(device)
|
||||
|
||||
# Split reference signals
|
||||
style_dim = 128
|
||||
s_ref = ref_s[:, :style_dim].clone().to(device)
|
||||
s_content = ref_s[:, style_dim:].clone().to(device)
|
||||
|
||||
# BERT and encoder pass
|
||||
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
|
||||
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
|
||||
|
||||
# Predictor forward pass
|
||||
d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
|
||||
x, _ = model.predictor.lstm(d)
|
||||
|
||||
# Duration prediction
|
||||
duration = model.predictor.duration_proj(x)
|
||||
duration = torch.sigmoid(duration).sum(axis=-1) / speed
|
||||
pred_dur = torch.round(duration).clamp(min=1).long()
|
||||
del duration, x
|
||||
|
||||
# Alignment matrix construction
|
||||
pred_aln_trg = torch.zeros(
|
||||
input_lengths.item(), pred_dur.sum().item(), device=device
|
||||
)
|
||||
c_frame = 0
|
||||
for i in range(pred_aln_trg.size(0)):
|
||||
pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
|
||||
c_frame += pred_dur[0, i].item()
|
||||
pred_aln_trg = pred_aln_trg.unsqueeze(0)
|
||||
|
||||
# Matrix multiplications
|
||||
en = d.transpose(-1, -2) @ pred_aln_trg
|
||||
del d
|
||||
|
||||
F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
|
||||
del en
|
||||
|
||||
# Final text encoding and decoding
|
||||
t_en = model.text_encoder(tokens, input_lengths, text_mask)
|
||||
asr = t_en @ pred_aln_trg
|
||||
del t_en
|
||||
|
||||
# Generate output
|
||||
output = model.decoder(asr, F0_pred, N_pred, s_ref)
|
||||
|
||||
# Ensure operation completion if using custom stream
|
||||
if stream and device.type == "cuda":
|
||||
stream.synchronize()
|
||||
|
||||
return output.squeeze().cpu().numpy()
|
||||
|
||||
|
||||
def length_to_mask(lengths: torch.Tensor) -> torch.Tensor:
|
||||
"""Create attention mask from lengths."""
|
||||
max_len = lengths.max()
|
||||
mask = torch.arange(max_len, device=lengths.device)[None, :].expand(
|
||||
lengths.shape[0], -1
|
||||
)
|
||||
if lengths.dtype != mask.dtype:
|
||||
mask = mask.to(dtype=lengths.dtype)
|
||||
return mask + 1 > lengths[:, None]
|
||||
|
||||
|
||||
class PyTorchBackend(BaseModelBackend):
|
||||
"""PyTorch inference backend with environment-based configuration."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize backend based on environment configuration."""
|
||||
super().__init__()
|
||||
|
||||
# Configure device based on settings
|
||||
self._device = (
|
||||
"cuda" if settings.use_gpu and torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
self._model: Optional[torch.nn.Module] = None
|
||||
|
||||
# Apply device-specific configurations
|
||||
if self._device == "cuda":
|
||||
config = model_config.pytorch_gpu
|
||||
if config.sync_cuda:
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.set_device(config.device_id)
|
||||
self._stream_manager = CUDAStreamManager(config.cuda_streams)
|
||||
else:
|
||||
config = model_config.pytorch_cpu
|
||||
if config.num_threads > 0:
|
||||
torch.set_num_threads(config.num_threads)
|
||||
if config.pin_memory:
|
||||
torch.set_default_tensor_type(torch.FloatTensor)
|
||||
|
||||
async def load_model(self, path: str) -> None:
|
||||
"""Load PyTorch model.
|
||||
|
||||
Args:
|
||||
path: Path to model file
|
||||
|
||||
Raises:
|
||||
RuntimeError: If model loading fails
|
||||
"""
|
||||
try:
|
||||
# Get verified model path
|
||||
model_path = await paths.get_model_path(path)
|
||||
|
||||
logger.info(f"Loading PyTorch model on {self._device}: {model_path}")
|
||||
self._model = await build_model(model_path, self._device)
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load PyTorch model: {e}")
|
||||
|
||||
def generate(
|
||||
self, tokens: list[int], voice: torch.Tensor, speed: float = 1.0
|
||||
) -> np.ndarray:
|
||||
"""Generate audio using model.
|
||||
|
||||
Args:
|
||||
tokens: Input token IDs
|
||||
voice: Voice embedding tensor
|
||||
speed: Speed multiplier
|
||||
|
||||
Returns:
|
||||
Generated audio samples
|
||||
|
||||
Raises:
|
||||
RuntimeError: If generation fails
|
||||
"""
|
||||
if not self.is_loaded:
|
||||
raise RuntimeError("Model not loaded")
|
||||
|
||||
try:
|
||||
# Memory management for GPU
|
||||
if self._device == "cuda":
|
||||
if self._check_memory():
|
||||
self._clear_memory()
|
||||
stream = self._stream_manager.get_next_stream()
|
||||
else:
|
||||
stream = None
|
||||
|
||||
# Get reference style from voice pack
|
||||
ref_s = voice[len(tokens)].clone().to(self._device)
|
||||
if ref_s.dim() == 1:
|
||||
ref_s = ref_s.unsqueeze(0)
|
||||
|
||||
# Generate audio
|
||||
return forward(self._model, tokens, ref_s, speed, stream)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Generation failed: {e}")
|
||||
if (
|
||||
self._device == "cuda"
|
||||
and model_config.pytorch_gpu.retry_on_oom
|
||||
and "out of memory" in str(e).lower()
|
||||
):
|
||||
self._clear_memory()
|
||||
return self.generate(tokens, voice, speed)
|
||||
raise
|
||||
finally:
|
||||
if self._device == "cuda" and model_config.pytorch_gpu.sync_cuda:
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def _check_memory(self) -> bool:
|
||||
"""Check if memory usage is above threshold."""
|
||||
if self._device == "cuda":
|
||||
memory_gb = torch.cuda.memory_allocated() / 1e9
|
||||
return memory_gb > model_config.pytorch_gpu.memory_threshold
|
||||
return False
|
||||
|
||||
def _clear_memory(self) -> None:
|
||||
"""Clear device memory."""
|
||||
if self._device == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
|
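The removed `forward()` above turns per-token durations into an alignment matrix: token i claims `pred_dur[0, i]` consecutive frames, and matrix products with that mask expand token-level features to frame level. A small worked sketch of just that expansion, using toy durations rather than real model output:

```python
# Worked toy example of the duration -> alignment expansion used in the removed forward().
import torch

pred_dur = torch.tensor([[2, 1, 3]])          # toy durations: token 0 -> 2 frames, 1 -> 1, 2 -> 3
num_tokens, num_frames = pred_dur.size(1), int(pred_dur.sum())

pred_aln_trg = torch.zeros(num_tokens, num_frames)
c_frame = 0
for i in range(num_tokens):
    d = int(pred_dur[0, i])
    pred_aln_trg[i, c_frame:c_frame + d] = 1  # token i owns d consecutive frames
    c_frame += d

token_feats = torch.randn(1, 8, num_tokens)            # toy (batch, channels, tokens) features
frame_feats = token_feats @ pred_aln_trg.unsqueeze(0)  # -> (batch, channels, frames)
print(pred_aln_trg)
print(frame_feats.shape)                                # torch.Size([1, 8, 6])
```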
@@ -1,272 +0,0 @@
|
|||
"""Session pooling for model inference."""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Set
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
from onnxruntime import (
|
||||
ExecutionMode,
|
||||
GraphOptimizationLevel,
|
||||
InferenceSession,
|
||||
SessionOptions
|
||||
)
|
||||
|
||||
from ..core import paths
|
||||
from ..core.model_config import model_config
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionInfo:
|
||||
"""Session information."""
|
||||
session: InferenceSession
|
||||
last_used: float
|
||||
stream_id: Optional[int] = None
|
||||
|
||||
|
||||
def create_session_options(is_gpu: bool = False) -> SessionOptions:
|
||||
"""Create ONNX session options.
|
||||
|
||||
Args:
|
||||
is_gpu: Whether to use GPU configuration
|
||||
|
||||
Returns:
|
||||
Configured session options
|
||||
"""
|
||||
options = SessionOptions()
|
||||
config = model_config.onnx_gpu if is_gpu else model_config.onnx_cpu
|
||||
|
||||
# Set optimization level
|
||||
if config.optimization_level == "all":
|
||||
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
elif config.optimization_level == "basic":
|
||||
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
|
||||
else:
|
||||
options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
|
||||
|
||||
# Configure threading
|
||||
options.intra_op_num_threads = config.num_threads
|
||||
options.inter_op_num_threads = config.inter_op_threads
|
||||
|
||||
# Set execution mode
|
||||
options.execution_mode = (
|
||||
ExecutionMode.ORT_PARALLEL
|
||||
if config.execution_mode == "parallel"
|
||||
else ExecutionMode.ORT_SEQUENTIAL
|
||||
)
|
||||
|
||||
# Configure memory optimization
|
||||
options.enable_mem_pattern = config.memory_pattern
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def create_provider_options(is_gpu: bool = False) -> Dict:
|
||||
"""Create provider options.
|
||||
|
||||
Args:
|
||||
is_gpu: Whether to use GPU configuration
|
||||
|
||||
Returns:
|
||||
Provider configuration
|
||||
"""
|
||||
if is_gpu:
|
||||
config = model_config.onnx_gpu
|
||||
return {
|
||||
"device_id": config.device_id,
|
||||
"arena_extend_strategy": config.arena_extend_strategy,
|
||||
"gpu_mem_limit": int(config.gpu_mem_limit * torch.cuda.get_device_properties(0).total_memory),
|
||||
"cudnn_conv_algo_search": config.cudnn_conv_algo_search,
|
||||
"do_copy_in_default_stream": config.do_copy_in_default_stream
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"arena_extend_strategy": model_config.onnx_cpu.arena_extend_strategy,
|
||||
"cpu_memory_arena_cfg": "cpu:0"
|
||||
}
|
||||
|
||||
|
||||
class BaseSessionPool:
|
||||
"""Base session pool implementation."""
|
||||
|
||||
def __init__(self, max_size: int, timeout: int):
|
||||
"""Initialize session pool.
|
||||
|
||||
Args:
|
||||
max_size: Maximum number of concurrent sessions
|
||||
timeout: Session timeout in seconds
|
||||
"""
|
||||
self._max_size = max_size
|
||||
self._timeout = timeout
|
||||
self._sessions: Dict[str, SessionInfo] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def get_session(self, model_path: str) -> InferenceSession:
|
||||
"""Get session from pool.
|
||||
|
||||
Args:
|
||||
model_path: Path to model file
|
||||
|
||||
Returns:
|
||||
ONNX inference session
|
||||
|
||||
Raises:
|
||||
RuntimeError: If no sessions available
|
||||
"""
|
||||
async with self._lock:
|
||||
# Clean expired sessions
|
||||
self._cleanup_expired()
|
||||
|
||||
# TODO: Change session tracking to use unique IDs instead of model paths
|
||||
# This would allow multiple instances of the same model
|
||||
|
||||
# Check if session exists and is valid
|
||||
if model_path in self._sessions:
|
||||
session_info = self._sessions[model_path]
|
||||
session_info.last_used = time.time()
|
||||
return session_info.session
|
||||
|
||||
# TODO: Modify session limit check to count instances per model path
|
||||
# Rather than total sessions across all models
|
||||
if len(self._sessions) >= self._max_size:
|
||||
raise RuntimeError(
|
||||
f"Maximum number of sessions reached ({self._max_size}). "
|
||||
"Try again later or reduce concurrent requests."
|
||||
)
|
||||
|
||||
# Create new session
|
||||
session = await self._create_session(model_path)
|
||||
self._sessions[model_path] = SessionInfo(
|
||||
session=session,
|
||||
last_used=time.time()
|
||||
)
|
||||
return session
|
||||
|
||||
def _cleanup_expired(self) -> None:
|
||||
"""Remove expired sessions."""
|
||||
current_time = time.time()
|
||||
expired = [
|
||||
path for path, info in self._sessions.items()
|
||||
if current_time - info.last_used > self._timeout
|
||||
]
|
||||
for path in expired:
|
||||
logger.info(f"Removing expired session: {path}")
|
||||
del self._sessions[path]
|
||||
|
||||
async def _create_session(self, model_path: str) -> InferenceSession:
|
||||
"""Create new session.
|
||||
|
||||
Args:
|
||||
model_path: Path to model file
|
||||
|
||||
Returns:
|
||||
ONNX inference session
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""Clean up all sessions."""
|
||||
self._sessions.clear()
|
||||
|
||||
|
||||
class StreamingSessionPool(BaseSessionPool):
|
||||
"""GPU session pool with CUDA streams."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize GPU session pool."""
|
||||
config = model_config.onnx_gpu
|
||||
super().__init__(config.cuda_streams, config.stream_timeout)
|
||||
self._available_streams: Set[int] = set(range(config.cuda_streams))
|
||||
|
||||
async def get_session(self, model_path: str) -> InferenceSession:
|
||||
"""Get session with CUDA stream.
|
||||
|
||||
Args:
|
||||
model_path: Path to model file
|
||||
|
||||
Returns:
|
||||
ONNX inference session
|
||||
|
||||
Raises:
|
||||
RuntimeError: If no streams available
|
||||
"""
|
||||
async with self._lock:
|
||||
# Clean expired sessions
|
||||
self._cleanup_expired()
|
||||
|
||||
# Try to find existing session
|
||||
if model_path in self._sessions:
|
||||
session_info = self._sessions[model_path]
|
||||
session_info.last_used = time.time()
|
||||
return session_info.session
|
||||
|
||||
# Get available stream
|
||||
if not self._available_streams:
|
||||
raise RuntimeError("No CUDA streams available")
|
||||
stream_id = self._available_streams.pop()
|
||||
|
||||
try:
|
||||
# Create new session
|
||||
session = await self._create_session(model_path)
|
||||
self._sessions[model_path] = SessionInfo(
|
||||
session=session,
|
||||
last_used=time.time(),
|
||||
stream_id=stream_id
|
||||
)
|
||||
return session
|
||||
|
||||
except Exception:
|
||||
# Return stream to pool on failure
|
||||
self._available_streams.add(stream_id)
|
||||
raise
|
||||
|
||||
def _cleanup_expired(self) -> None:
|
||||
"""Remove expired sessions and return streams."""
|
||||
current_time = time.time()
|
||||
expired = [
|
||||
path for path, info in self._sessions.items()
|
||||
if current_time - info.last_used > self._timeout
|
||||
]
|
||||
for path in expired:
|
||||
info = self._sessions[path]
|
||||
if info.stream_id is not None:
|
||||
self._available_streams.add(info.stream_id)
|
||||
logger.info(f"Removing expired session: {path}")
|
||||
del self._sessions[path]
|
||||
|
||||
async def _create_session(self, model_path: str) -> InferenceSession:
|
||||
"""Create new session with CUDA provider."""
|
||||
abs_path = await paths.get_model_path(model_path)
|
||||
options = create_session_options(is_gpu=True)
|
||||
provider_options = create_provider_options(is_gpu=True)
|
||||
|
||||
return InferenceSession(
|
||||
abs_path,
|
||||
sess_options=options,
|
||||
providers=["CUDAExecutionProvider"],
|
||||
provider_options=[provider_options]
|
||||
)
|
||||
|
||||
|
||||
class CPUSessionPool(BaseSessionPool):
|
||||
"""CPU session pool."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize CPU session pool."""
|
||||
config = model_config.onnx_cpu
|
||||
super().__init__(config.max_instances, config.instance_timeout)
|
||||
|
||||
async def _create_session(self, model_path: str) -> InferenceSession:
|
||||
"""Create new session with CPU provider."""
|
||||
abs_path = await paths.get_model_path(model_path)
|
||||
options = create_session_options(is_gpu=False)
|
||||
provider_options = create_provider_options(is_gpu=False)
|
||||
|
||||
return InferenceSession(
|
||||
abs_path,
|
||||
sess_options=options,
|
||||
providers=["CPUExecutionProvider"],
|
||||
provider_options=[provider_options]
|
||||
)
|
|
@@ -1,6 +1,5 @@
|
|||
"""Voice pack management and caching."""
|
||||
"""Voice management with controlled resource handling."""
|
||||
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
@@ -8,248 +7,110 @@ from loguru import logger
|
|||
|
||||
from ..core import paths
|
||||
from ..core.config import settings
|
||||
from ..core.model_config import model_config
|
||||
from ..structures.model_schemas import VoiceConfig
|
||||
|
||||
|
||||
class VoiceManager:
|
||||
"""Manages voice loading and operations."""
|
||||
"""Manages voice loading and caching with controlled resource usage."""
|
||||
|
||||
def __init__(self, config: Optional[VoiceConfig] = None):
|
||||
"""Initialize voice manager.
|
||||
# Singleton instance
|
||||
_instance = None
|
||||
|
||||
Args:
|
||||
config: Optional voice configuration
|
||||
"""
|
||||
self._config = config or VoiceConfig()
|
||||
self._voice_cache: Dict[str, torch.Tensor] = {}
|
||||
def __init__(self):
|
||||
"""Initialize voice manager."""
|
||||
# Strictly respect settings.use_gpu
|
||||
self._device = "cuda" if settings.use_gpu else "cpu"
|
||||
self._voices: Dict[str, torch.Tensor] = {}
|
||||
|
||||
def get_voice_path(self, voice_name: str) -> Optional[str]:
|
||||
async def get_voice_path(self, voice_name: str) -> str:
|
||||
"""Get path to voice file.
|
||||
|
||||
Args:
|
||||
voice_name: Name of voice
|
||||
|
||||
Returns:
|
||||
Path to voice file if exists, None otherwise
|
||||
Path to voice file
|
||||
|
||||
Raises:
|
||||
RuntimeError: If voice not found
|
||||
"""
|
||||
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
voices_dir = os.path.join(api_dir, settings.voices_dir)
|
||||
return await paths.get_voice_path(voice_name)
|
||||
|
||||
logger.debug(f"Looking for voice: {voice_name}")
|
||||
logger.debug(f"Base voices directory: {voices_dir}")
|
||||
|
||||
# Check v1_0 subdirectory first if using Kokoro V1
|
||||
if model_config.pytorch_kokoro_v1_file:
|
||||
v1_path = os.path.join(voices_dir, 'v1_0', f"{voice_name}.pt")
|
||||
logger.debug(f"Checking v1_0 path: {v1_path}")
|
||||
if os.path.exists(v1_path):
|
||||
logger.debug(f"Found voice in v1_0: {v1_path}")
|
||||
return v1_path
|
||||
|
||||
# Fall back to main voices directory
|
||||
voice_path = os.path.join(voices_dir, f"{voice_name}.pt")
|
||||
logger.debug(f"Checking main path: {voice_path}")
|
||||
if os.path.exists(voice_path):
|
||||
logger.debug(f"Found voice in main dir: {voice_path}")
|
||||
return voice_path
|
||||
|
||||
logger.debug(f"Voice not found: {voice_name}")
|
||||
return None
|
||||
|
||||
async def load_voice(self, voice_name: str, device: str = "cpu") -> torch.Tensor:
|
||||
async def load_voice(self, voice_name: str, device: Optional[str] = None) -> torch.Tensor:
|
||||
"""Load voice tensor.
|
||||
|
||||
Args:
|
||||
voice_name: Name of voice to load
|
||||
device: Device to load voice on
|
||||
device: Optional override for target device
|
||||
|
||||
Returns:
|
||||
Voice tensor
|
||||
|
||||
Raises:
|
||||
RuntimeError: If voice loading fails
|
||||
RuntimeError: If voice not found
|
||||
"""
|
||||
# Check if it's a combined voice request
|
||||
if "+" in voice_name:
|
||||
voices = [v.strip() for v in voice_name.split("+") if v.strip()]
|
||||
if len(voices) < 2:
|
||||
raise RuntimeError(f"Invalid combined voice name: {voice_name}")
|
||||
|
||||
# Load and combine voices
|
||||
voice_tensors = []
|
||||
for voice in voices:
|
||||
try:
|
||||
voice_tensor = await self.load_voice(voice, device)
|
||||
voice_tensors.append(voice_tensor)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load base voice {voice}: {e}")
|
||||
|
||||
return torch.mean(torch.stack(voice_tensors), dim=0)
|
||||
|
||||
# Handle single voice
|
||||
voice_path = self.get_voice_path(voice_name)
|
||||
if not voice_path:
|
||||
raise RuntimeError(f"Voice not found: {voice_name}")
|
||||
|
||||
# Check cache
|
||||
cache_key = f"{voice_path}_{device}"
|
||||
if self._config.use_cache and cache_key in self._voice_cache:
|
||||
logger.debug(f"Using cached voice: {voice_name} from {voice_path}")
|
||||
return self._voice_cache[cache_key]
|
||||
|
||||
# Load voice tensor
|
||||
try:
|
||||
logger.debug(f"Loading voice tensor from: {voice_path}")
|
||||
voice = await paths.load_voice_tensor(voice_path, device=device)
|
||||
voice_path = await self.get_voice_path(voice_name)
|
||||
target_device = device or self._device
|
||||
voice = await paths.load_voice_tensor(voice_path, target_device)
|
||||
self._voices[voice_name] = voice
|
||||
return voice
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load voice {voice_name}: {e}")
|
||||
|
||||
# Cache if enabled
|
||||
if self._config.use_cache:
|
||||
self._manage_cache()
|
||||
self._voice_cache[cache_key] = voice
|
||||
logger.debug(f"Cached voice: {voice_name} on {device} from {voice_path}")
|
||||
|
||||
return voice
|
||||
|
||||
def _manage_cache(self) -> None:
|
||||
"""Manage voice cache size using simple LRU."""
|
||||
if len(self._voice_cache) >= self._config.cache_size:
|
||||
# Remove least recently used voice
|
||||
oldest = next(iter(self._voice_cache))
|
||||
del self._voice_cache[oldest]
|
||||
torch.cuda.empty_cache() # Clean up GPU memory if needed
|
||||
logger.debug(f"Removed LRU voice from cache: {oldest}")
|
||||
|
||||
async def combine_voices(self, voices: List[str], device: str = "cpu") -> str:
|
||||
"""Combine multiple voices into a new voice.
|
||||
async def combine_voices(self, voices: List[str], device: Optional[str] = None) -> str:
|
||||
"""Combine multiple voices.
|
||||
|
||||
Args:
|
||||
voices: List of voice names to combine
|
||||
device: Device to load voices on
|
||||
device: Optional override for target device
|
||||
|
||||
Returns:
|
||||
Name of combined voice
|
||||
|
||||
Raises:
|
||||
ValueError: If fewer than 2 voices provided
|
||||
RuntimeError: If voice combination fails
|
||||
RuntimeError: If any voice not found
|
||||
"""
|
||||
if len(voices) < 2:
|
||||
raise ValueError("At least 2 voices are required for combination")
|
||||
raise ValueError("Need at least 2 voices to combine")
|
||||
|
||||
# Create combined name using + as separator
|
||||
target_device = device or self._device
|
||||
voice_tensors = []
|
||||
for name in voices:
|
||||
voice = await self.load_voice(name, target_device)
|
||||
voice_tensors.append(voice)
|
||||
|
||||
combined = torch.mean(torch.stack(voice_tensors), dim=0)
|
||||
combined_name = "+".join(voices)
|
||||
|
||||
# If saving is enabled, try to save the combination
|
||||
if settings.allow_local_voice_saving:
|
||||
try:
|
||||
# Load and combine voices
|
||||
combined_tensor = await self.load_voice(combined_name, device)
|
||||
|
||||
# Save to disk
|
||||
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
voices_dir = os.path.join(api_dir, settings.voices_dir)
|
||||
|
||||
# Save in v1_0 directory if using Kokoro V1
|
||||
if model_config.pytorch_kokoro_v1_file:
|
||||
voices_dir = os.path.join(voices_dir, 'v1_0')
|
||||
|
||||
os.makedirs(voices_dir, exist_ok=True)
|
||||
|
||||
combined_path = os.path.join(voices_dir, f"{combined_name}.pt")
|
||||
try:
|
||||
torch.save(combined_tensor, combined_path)
|
||||
# Cache with path-based key
|
||||
self._voice_cache[f"{combined_path}_{device}"] = combined_tensor
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to save combined voice: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save combined voice: {e}")
|
||||
# Continue without saving - will be combined on-the-fly when needed
|
||||
self._voices[combined_name] = combined
|
||||
|
||||
return combined_name
|
||||
|
||||
async def list_voices(self) -> List[str]:
|
||||
"""List available voices.
|
||||
"""List available voice names.
|
||||
|
||||
Returns:
|
||||
List of voice names
|
||||
"""
|
||||
voices = set() # Use set to avoid duplicates
|
||||
try:
|
||||
# Get voices from disk
|
||||
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
voices_dir = os.path.join(api_dir, settings.voices_dir)
|
||||
os.makedirs(voices_dir, exist_ok=True)
|
||||
return await paths.list_voices()
|
||||
|
||||
# Check v1_0 subdirectory if using Kokoro V1
|
||||
if model_config.pytorch_kokoro_v1_file:
|
||||
v1_dir = os.path.join(voices_dir, 'v1_0')
|
||||
logger.debug(f"Checking v1_0 directory: {v1_dir}")
|
||||
if os.path.exists(v1_dir):
|
||||
for entry in os.listdir(v1_dir):
|
||||
if entry.endswith(".pt"):
|
||||
voices.add(entry[:-3])
|
||||
logger.debug(f"Found v1_0 voice: {entry[:-3]}")
|
||||
else:
|
||||
# Check main voices directory
|
||||
for entry in os.listdir(voices_dir):
|
||||
if entry.endswith(".pt"):
|
||||
voices.add(entry[:-3])
|
||||
logger.debug(f"Found main voice: {entry[:-3]}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing voices: {e}")
|
||||
return sorted(list(voices))
|
||||
|
||||
def validate_voice(self, voice_path: str) -> bool:
|
||||
"""Validate voice file.
|
||||
|
||||
Args:
|
||||
voice_path: Path to voice file
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
if not os.path.exists(voice_path):
|
||||
return False
|
||||
voice = torch.load(voice_path, map_location="cpu", weights_only=False)
|
||||
return isinstance(voice, torch.Tensor)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@property
|
||||
def cache_info(self) -> Dict[str, int]:
|
||||
"""Get cache statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with cache info
|
||||
Dict with cache statistics
|
||||
"""
|
||||
return {
|
||||
'size': len(self._voice_cache),
|
||||
'max_size': self._config.cache_size
|
||||
"loaded_voices": len(self._voices),
|
||||
"device": self._device
|
||||
}
|
||||
|
||||
|
||||
# Global singleton instance and lock
|
||||
_manager_instance = None
|
||||
|
||||
|
||||
async def get_manager(config: Optional[VoiceConfig] = None) -> VoiceManager:
|
||||
"""Get global voice manager instance.
|
||||
|
||||
Args:
|
||||
config: Optional voice configuration
|
||||
async def get_manager() -> VoiceManager:
|
||||
"""Get voice manager instance.
|
||||
|
||||
Returns:
|
||||
VoiceManager instance
|
||||
"""
|
||||
global _manager_instance
|
||||
|
||||
if _manager_instance is None:
|
||||
_manager_instance = VoiceManager(config)
|
||||
return _manager_instance
|
||||
if VoiceManager._instance is None:
|
||||
VoiceManager._instance = VoiceManager()
|
||||
return VoiceManager._instance
|
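For orientation, a rough usage sketch of the new V1 voice API introduced here, not code from the commit. The voice name v0_af_irulan matches the voice file added below; the second voice name and the import path are made up for illustration.

import asyncio

from api.src.inference.voice_manager import get_manager  # module path assumed


async def demo():
    manager = await get_manager()
    print(await manager.list_voices())  # names of .pt files under the configured voices dir

    # Single voice, loaded onto the manager's device ("cuda" if settings.use_gpu else "cpu")
    voice = await manager.load_voice("v0_af_irulan")
    print(voice.shape)

    # Combined voice: tensors are stacked and averaged, and the name joins the parts with "+"
    name = await manager.combine_voices(["v0_af_irulan", "hypothetical_other_voice"])
    assert name == "v0_af_irulan+hypothetical_other_voice"


asyncio.run(demo())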
@ -55,23 +55,23 @@ async def lifespan(app: FastAPI):
    logger.info("Loading TTS model and voice packs...")

    try:
        # Initialize managers globally
        # Initialize managers
        model_manager = await get_manager()
        voice_manager = await get_voice_manager()

        # Initialize model with warmup and get status
        device, model, voicepack_count = await model_manager.initialize_with_warmup(voice_manager)
        device, model, voicepack_count = await model_manager\
            .initialize_with_warmup(voice_manager)

    except FileNotFoundError:
        logger.error("""
Model files not found! You need to either:
Model files not found! You need to download the Kokoro V1 model:

1. Download models using the scripts:
   GPU: python docker/scripts/download_model.py --type pth
   CPU: python docker/scripts/download_model.py --type onnx
1. Download model using the script:
   python docker/scripts/download_model.py --version v1_0 --output api/src/models/v1_0

2. Set environment variables in docker-compose:
   GPU: DOWNLOAD_PTH=true
   CPU: DOWNLOAD_ONNX=true
2. Or set environment variable in docker-compose:
   DOWNLOAD_MODEL=true
""")
        raise
    except Exception as e:
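The error text above doubles as setup instructions. Below is a minimal sketch of the same pre-flight check, using the path quoted in the message and the model.pt filename produced by the download script later in this diff; the helper itself is illustrative and not part of the commit.

from pathlib import Path

MODEL_DIR = Path("api/src/models/v1_0")  # path quoted in the error message


def ensure_model_downloaded() -> None:
    """Fail fast with the same guidance the lifespan handler logs."""
    if not (MODEL_DIR / "model.pt").exists():
        raise FileNotFoundError(
            "Kokoro V1 model missing. Run "
            "`python docker/scripts/download_model.py --version v1_0 --output api/src/models/v1_0` "
            "or set DOWNLOAD_MODEL=true in docker-compose."
        )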
@ -68,6 +68,7 @@ def get_model_name(model: str) -> str:
    extension = ".onnx" if settings.use_onnx else ".pth"
    return base_name + extension


async def process_voices(
    voice_input: Union[str, List[str]], tts_service: TTSService
) -> str:
@ -161,7 +161,7 @@ class TTSService:
        # Load and combine voices
        voice_tensors = []
        for v in voices:
            path = self._voice_manager.get_voice_path(v)
            path = await self._voice_manager.get_voice_path(v)
            if not path:
                raise RuntimeError(f"Voice not found: {v}")
            logger.debug(f"Loading voice tensor from: {path}")

@ -181,7 +181,7 @@ class TTSService:
            return voice, combined_path
        else:
            # Single voice
            path = self._voice_manager.get_voice_path(voice)
            path = await self._voice_manager.get_voice_path(voice)
            if not path:
                raise RuntimeError(f"Voice not found: {voice}")
            logger.debug(f"Using single voice path: {path}")
Binary file not shown.
BIN
api/src/voices/v1_0/v0_af_irulan.pt
Normal file
Binary file not shown.
@ -1,103 +1,86 @@
#!/usr/bin/env python3
import os
import sys
"""Download and prepare Kokoro model for Docker build."""

import argparse
import requests
import json
import os
import shutil
from pathlib import Path
from typing import List

def download_file(url: str, output_dir: Path, model_type: str, overwrite:str) -> bool:
    """Download a file from URL to the specified directory.
import torch
from huggingface_hub import hf_hub_download
from loguru import logger

    Returns:
        bool: True if download succeeded, False otherwise

def download_model(version: str, output_dir: str) -> None:
    """Download model files from HuggingFace.

    Args:
        version: Model version to download
        output_dir: Directory to save model files
    """
    filename = os.path.basename(url)
    if not filename.endswith(f'.{model_type}'):
        print(f"Warning: {filename} is not a .{model_type} file", file=sys.stderr)
        return False

    output_path = output_dir / filename

    if os.path.exists(output_path):
        print(f"{filename} exists. Canceling download")
        return True

    print(f"Downloading {filename}...")
    try:
        response = requests.get(url, stream=True)
        response.raise_for_status()
        logger.info(f"Downloading Kokoro model version {version}")

        # Create output directory
        os.makedirs(output_dir, exist_ok=True)

        # Download model files
        model_file = hf_hub_download(
            repo_id="hexgrad/Kokoro-82M",
            filename=f"kokoro-{version}.pth"
        )
        config_file = hf_hub_download(
            repo_id="hexgrad/Kokoro-82M",
            filename="config.json"
        )

        # Copy to output directory
        shutil.copy2(model_file, os.path.join(output_dir, "model.pt"))
        shutil.copy2(config_file, os.path.join(output_dir, "config.json"))

        # Verify files
        model_path = os.path.join(output_dir, "model.pt")
        config_path = os.path.join(output_dir, "config.json")

        if not os.path.exists(model_path):
            raise RuntimeError(f"Model file not found: {model_path}")
        if not os.path.exists(config_path):
            raise RuntimeError(f"Config file not found: {config_path}")

        # Load and verify model
        logger.info("Verifying model files...")
        with open(config_path) as f:
            config = json.load(f)
            logger.info(f"Loaded config: {config}")

        model = torch.load(model_path, map_location="cpu")
        logger.info(f"Loaded model with keys: {model.keys()}")

        logger.info(f"✓ Model files prepared in {output_dir}")

        with open(output_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
        print(f"Successfully downloaded {filename}")
        return True
    except Exception as e:
        print(f"Error downloading {filename}: {e}", file=sys.stderr)
        return False
        logger.error(f"Failed to download model: {e}")
        raise

def find_project_root() -> Path:
    """Find project root by looking for api directory."""
    max_steps = 5
    current = Path(__file__).resolve()
    for _ in range(max_steps):
        if (current / 'api').is_dir():
            return current
        current = current.parent
    raise RuntimeError("Could not find project root (no api directory found)")

def main() -> int:
    """Download models to the project.
def main():
    """Main entry point."""
    parser = argparse.ArgumentParser(description="Download Kokoro model for Docker build")
    parser.add_argument(
        "--version",
        default="v1_0",
        help="Model version to download"
    )
    parser.add_argument(
        "--output",
        required=True,
        help="Output directory for model files"
    )

    Returns:
        int: Exit code (0 for success, 1 for failure)
    """
    parser = argparse.ArgumentParser(description='Download model files')
    parser.add_argument('--type', choices=['pth', 'onnx'], required=True,
                        help='Model type to download (pth or onnx)')
    parser.add_argument('--overwrite', action='store_true', help='Overwite existing files')
    parser.add_argument('urls', nargs='*', help='Optional model URLs to download')
    args = parser.parse_args()
    download_model(args.version, args.output)

    try:
        # Find project root and ensure models directory exists
        project_root = find_project_root()
        models_dir = project_root / 'api' / 'src' / 'models'
        print(f"Downloading models to {models_dir}")
        models_dir.mkdir(exist_ok=True)

        # Default models if no arguments provided
        default_models = {
            'pth': [
                "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth",
                "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19-half.pth"
            ],
            'onnx': [
                "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.onnx",
                "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19_fp16.onnx"
            ]
        }

        # Use provided models or default
        models_to_download = args.urls if args.urls else default_models[args.type]

        # Download all models
        success = True
        for model_url in models_to_download:
            if not download_file(model_url, models_dir, args.type,args.overwrite):
                success = False

        if success:
            print(f"{args.type.upper()} model download complete!")
            return 0
        else:
            print("Some downloads failed", file=sys.stderr)
            return 1

    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        return 1

if __name__ == "__main__":
    sys.exit(main())
    main()
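For reference, how the rewritten downloader is invoked; the CLI line repeats the one from the lifespan error message, and the Python import path is assumed for illustration only.

# CLI:
#   python docker/scripts/download_model.py --version v1_0 --output api/src/models/v1_0
#
# Programmatically (import path assumed):
from docker.scripts.download_model import download_model

download_model("v1_0", "api/src/models/v1_0")
# Produces api/src/models/v1_0/model.pt and api/src/models/v1_0/config.json,
# fetched from the hexgrad/Kokoro-82M repo on HuggingFace and sanity-checked with torch.load.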
@ -5,7 +5,7 @@ PROJECT_ROOT=$(pwd)

# Set environment variables
export USE_GPU=false
export USE_ONNX=true
export USE_ONNX=false
export PYTHONPATH=$PROJECT_ROOT:$PROJECT_ROOT/api
export MODEL_DIR=$PROJECT_ROOT/api/src/models
export VOICES_DIR=$PROJECT_ROOT/api/src/voices
@ -375,8 +375,6 @@ export class AudioService {
        this.sourceBuffer = null;
        this.serverDownloadPath = null;
        this.pendingOperations = [];

        window.location.reload();
    }

    cleanup() {