Mirror of https://github.com/remsky/Kokoro-FastAPI.git (synced 2025-08-05 16:48:53 +00:00)

Commit 21bf810f97 (parent ab28a62e86)
Enhance model inference: update documentation, add model download scripts for PyTorch and ONNX, and refactor configuration handling

25 changed files with 774 additions and 309 deletions
@@ -337,11 +337,13 @@ def recursive_munch(d):
     else:
         return d
 
-def build_model(path, device):
+async def build_model(path, device):
+    from ..core.paths import load_json, load_model_weights
+
     config = Path(__file__).parent / 'config.json'
     assert config.exists(), f'Config path incorrect: config.json not found at {config}'
-    with open(config, 'r') as r:
-        args = recursive_munch(json.load(r))
+    args = recursive_munch(await load_json(config))
     assert args.decoder.type == 'istftnet', f'Unknown decoder type: {args.decoder.type}'
     decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
             resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
@@ -365,7 +367,8 @@ def build_model(path, device):
         decoder=decoder.to(device).eval(),
         text_encoder=text_encoder.to(device).eval(),
     )
-    for key, state_dict in torch.load(path, map_location='cpu', weights_only=True)['net'].items():
+    weights = await load_model_weights(path, device='cpu')
+    for key, state_dict in weights['net'].items():
         assert key in model, key
         try:
             model[key].load_state_dict(state_dict)
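The hunks above turn the model builder into a coroutine and route file access through the new async path helpers. A minimal sketch of the new call pattern — the absolute import path and model path are assumptions, not code from this commit:

import asyncio
from api.src.builds.models import build_model  # module path assumed from the relative imports elsewhere in this commit

async def main() -> None:
    # build_model is now a coroutine, so it must be awaited (model path illustrative)
    model = await build_model("api/src/models/kokoro-v0_19.pth", device="cpu")
    print(type(model))

asyncio.run(main())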
@@ -13,10 +13,15 @@ class Settings(BaseSettings):
     output_dir: str = "output"
     output_dir_size_limit_mb: float = 500.0  # Maximum size of output directory in MB
     default_voice: str = "af"
-    model_dir: str = "/app/models"  # Base directory for model files
-    pytorch_model_path: str = "kokoro-v0_19.pth"
-    onnx_model_path: str = "kokoro-v0_19.onnx"
-    voices_dir: str = "voices"
+    use_gpu: bool = False  # Whether to use GPU acceleration if available
+    use_onnx: bool = True  # Whether to use ONNX runtime
+    # Paths relative to api directory
+    model_dir: str = "src/models"  # Model directory relative to api/
+    voices_dir: str = "src/voices"  # Voices directory relative to api/
+
+    # Model filenames
+    pytorch_model_file: str = "kokoro-v0_19.pth"
+    onnx_model_file: str = "kokoro-v0_19.onnx"
     sample_rate: int = 24000
     max_chunk_size: int = 300  # Maximum size of text chunks for processing
     gap_trim_ms: int = 250  # Amount to trim from streaming chunk ends in milliseconds
@@ -28,6 +33,12 @@ class Settings(BaseSettings):
     onnx_optimization_level: str = "all"  # all, basic, or disabled
     onnx_memory_pattern: bool = True  # Enable memory pattern optimization
     onnx_arena_extend_strategy: str = "kNextPowerOfTwo"  # Memory allocation strategy
 
+    # ONNX GPU Settings
+    onnx_device_id: int = 0  # GPU device ID to use
+    onnx_gpu_mem_limit: float = 0.7  # Limit GPU memory usage to 70%
+    onnx_cudnn_conv_algo_search: str = "EXHAUSTIVE"  # CUDNN convolution algorithm search
+    onnx_do_copy_in_default_stream: bool = True  # Copy in default CUDA stream
+
     class Config:
         env_file = ".env"
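Because Settings is a pydantic BaseSettings with env_file = ".env", each new field can be overridden through environment variables. A hedged sketch (field names from the hunk above; the import path is an assumption):

import os

os.environ["USE_ONNX"] = "false"   # prefer the PyTorch backend
os.environ["USE_GPU"] = "true"     # allow CUDA if available

from api.src.core.config import Settings  # assumed module path

settings = Settings()
print(settings.use_onnx, settings.use_gpu)
print(settings.model_dir, settings.pytorch_model_file)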
api/src/core/model_config.py — new file (109 lines)
@@ -0,0 +1,109 @@
"""Model configuration schemas."""

from pydantic import BaseModel, Field


class ONNXCPUConfig(BaseModel):
    """ONNX CPU runtime configuration."""

    num_threads: int = Field(8, description="Number of threads for parallel operations")
    inter_op_threads: int = Field(4, description="Number of threads for operator parallelism")
    execution_mode: str = Field("parallel", description="ONNX execution mode")
    optimization_level: str = Field("all", description="ONNX optimization level")
    memory_pattern: bool = Field(True, description="Enable memory pattern optimization")
    arena_extend_strategy: str = Field("kNextPowerOfTwo", description="Memory arena strategy")

    class Config:
        frozen = True


class ONNXGPUConfig(ONNXCPUConfig):
    """ONNX GPU-specific configuration."""

    device_id: int = Field(0, description="CUDA device ID")
    gpu_mem_limit: float = Field(0.7, description="Fraction of GPU memory to use")
    cudnn_conv_algo_search: str = Field("EXHAUSTIVE", description="CuDNN convolution algorithm search")
    do_copy_in_default_stream: bool = Field(True, description="Copy in default CUDA stream")

    class Config:
        frozen = True


class PyTorchCPUConfig(BaseModel):
    """PyTorch CPU backend configuration."""

    max_batch_size: int = Field(32, description="Maximum batch size for batched inference")
    stream_buffer_size: int = Field(8, description="Size of stream buffer")
    memory_threshold: float = Field(0.8, description="Memory threshold for cleanup")
    retry_on_oom: bool = Field(True, description="Whether to retry on OOM errors")
    num_threads: int = Field(8, description="Number of threads for parallel operations")
    pin_memory: bool = Field(True, description="Whether to pin memory for faster CPU-GPU transfer")

    class Config:
        frozen = True


class PyTorchGPUConfig(BaseModel):
    """PyTorch GPU backend configuration."""

    device_id: int = Field(0, description="CUDA device ID")
    use_fp16: bool = Field(True, description="Whether to use FP16 precision")
    use_triton: bool = Field(True, description="Whether to use Triton for CUDA kernels")
    max_batch_size: int = Field(32, description="Maximum batch size for batched inference")
    stream_buffer_size: int = Field(8, description="Size of CUDA stream buffer")
    memory_threshold: float = Field(0.8, description="Memory threshold for cleanup")
    retry_on_oom: bool = Field(True, description="Whether to retry on OOM errors")
    sync_cuda: bool = Field(True, description="Whether to synchronize CUDA operations")

    class Config:
        frozen = True

    """PyTorch CPU-specific configuration."""

    num_threads: int = Field(8, description="Number of threads for parallel operations")
    pin_memory: bool = Field(True, description="Whether to pin memory for faster CPU-GPU transfer")

    class Config:
        frozen = True


class ModelConfig(BaseModel):
    """Model configuration."""

    # General settings
    model_type: str = Field("pytorch", description="Model type ('pytorch' or 'onnx')")
    device_type: str = Field("auto", description="Device type ('cpu', 'gpu', or 'auto')")
    cache_models: bool = Field(True, description="Whether to cache loaded models")
    cache_voices: bool = Field(True, description="Whether to cache voice tensors")
    voice_cache_size: int = Field(10, description="Maximum number of cached voices")

    # Backend-specific configs
    onnx_cpu: ONNXCPUConfig = Field(default_factory=ONNXCPUConfig)
    onnx_gpu: ONNXGPUConfig = Field(default_factory=ONNXGPUConfig)
    pytorch_cpu: PyTorchCPUConfig = Field(default_factory=PyTorchCPUConfig)
    pytorch_gpu: PyTorchGPUConfig = Field(default_factory=PyTorchGPUConfig)

    class Config:
        frozen = True

    def get_backend_config(self, backend_type: str):
        """Get configuration for specific backend.

        Args:
            backend_type: Backend type ('pytorch_cpu', 'pytorch_gpu', 'onnx_cpu', 'onnx_gpu')

        Returns:
            Backend-specific configuration

        Raises:
            ValueError: If backend type is invalid
        """
        if backend_type not in {
            'pytorch_cpu', 'pytorch_gpu', 'onnx_cpu', 'onnx_gpu'
        }:
            raise ValueError(f"Invalid backend type: {backend_type}")

        return getattr(self, backend_type)


# Global instance
model_config = ModelConfig()
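A short usage sketch for the new module (import path as in the file header above; the invalid-name example is illustrative):

from api.src.core.model_config import model_config

# Per-backend settings are plain pydantic models hanging off the global instance
cpu_cfg = model_config.get_backend_config("onnx_cpu")
print(cpu_cfg.num_threads, cpu_cfg.optimization_level)

# Unknown backend names are rejected
try:
    model_config.get_backend_config("tensorrt")
except ValueError as exc:
    print(exc)  # Invalid backend type: tensorrt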
@@ -1,9 +1,10 @@
 """Async file and path operations."""
 
 import io
+import json
 import os
 from pathlib import Path
-from typing import List, Optional, AsyncIterator, Callable, Set
+from typing import List, Optional, AsyncIterator, Callable, Set, Dict, Any
 
 import aiofiles
 import aiofiles.os
@@ -87,10 +88,18 @@ async def get_model_path(model_name: str) -> str:
     Raises:
         RuntimeError: If model not found
     """
-    search_paths = [
-        settings.model_dir,
-        os.path.join(os.path.dirname(__file__), "..", "..", "..", "models")
-    ]
+    # Get api directory path (two levels up from core)
+    api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
+
+    # Construct model directory path relative to api directory
+    model_dir = os.path.join(api_dir, settings.model_dir)
+
+    # Ensure model directory exists
+    os.makedirs(model_dir, exist_ok=True)
+
+    # Search in model directory
+    search_paths = [model_dir]
+    logger.debug(f"Searching for model in path: {model_dir}")
 
     return await _find_file(model_name, search_paths)
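The nested-dirname idiom above recurs throughout this commit: each os.path.dirname call strips one path component, so from api/src/core/paths.py three calls yield the api/ directory, which the new relative model_dir and voices_dir settings are joined against. A small illustration (the container paths shown are assumptions about the layout):

import os

file_path = "/app/api/src/core/paths.py"             # stands in for __file__
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(file_path)))
print(api_dir)                                        # /app/api
print(os.path.join(api_dir, "src/models"))            # /app/api/src/models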
@@ -107,12 +116,20 @@ async def get_voice_path(voice_name: str) -> str:
     Raises:
         RuntimeError: If voice not found
     """
+    # Get api directory path
+    api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
+
+    # Construct voice directory path relative to api directory
+    voice_dir = os.path.join(api_dir, settings.voices_dir)
+
+    # Ensure voice directory exists
+    os.makedirs(voice_dir, exist_ok=True)
+
     voice_file = f"{voice_name}.pt"
 
-    search_paths = [
-        os.path.join(settings.model_dir, "..", settings.voices_dir),
-        os.path.join(os.path.dirname(__file__), "..", settings.voices_dir)
-    ]
+    # Search in voice directory
+    search_paths = [voice_dir]
+    logger.debug(f"Searching for voice in path: {voice_dir}")
 
     return await _find_file(voice_file, search_paths)
@@ -123,10 +140,18 @@ async def list_voices() -> List[str]:
     Returns:
         List of voice names (without .pt extension)
     """
-    search_paths = [
-        os.path.join(settings.model_dir, "..", settings.voices_dir),
-        os.path.join(os.path.dirname(__file__), "..", settings.voices_dir)
-    ]
+    # Get api directory path
+    api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
+
+    # Construct voice directory path relative to api directory
+    voice_dir = os.path.join(api_dir, settings.voices_dir)
+
+    # Ensure voice directory exists
+    os.makedirs(voice_dir, exist_ok=True)
+
+    # Search in voice directory
+    search_paths = [voice_dir]
+    logger.debug(f"Scanning for voices in path: {voice_dir}")
 
     def filter_voice_files(name: str) -> bool:
         return name.endswith('.pt')
@@ -179,6 +204,51 @@ async def save_voice_tensor(tensor: torch.Tensor, voice_path: str) -> None:
         raise RuntimeError(f"Failed to save voice tensor to {voice_path}: {e}")
 
 
+async def load_json(path: str) -> dict:
+    """Load JSON file asynchronously.
+
+    Args:
+        path: Path to JSON file
+
+    Returns:
+        Parsed JSON data
+
+    Raises:
+        RuntimeError: If file cannot be read or parsed
+    """
+    try:
+        async with aiofiles.open(path, 'r', encoding='utf-8') as f:
+            content = await f.read()
+        return json.loads(content)
+    except Exception as e:
+        raise RuntimeError(f"Failed to load JSON file {path}: {e}")
+
+
+async def load_model_weights(path: str, device: str = "cpu") -> dict:
+    """Load model weights asynchronously.
+
+    Args:
+        path: Path to model file (.pth or .onnx)
+        device: Device to load model to
+
+    Returns:
+        Model weights
+
+    Raises:
+        RuntimeError: If file cannot be read
+    """
+    try:
+        async with aiofiles.open(path, 'rb') as f:
+            data = await f.read()
+        return torch.load(
+            io.BytesIO(data),
+            map_location=device,
+            weights_only=True
+        )
+    except Exception as e:
+        raise RuntimeError(f"Failed to load model weights from {path}: {e}")
+
+
 async def read_file(path: str) -> str:
     """Read text file asynchronously.
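A minimal sketch of how the async build path shown earlier is expected to use these helpers (the absolute import path and file names are assumptions):

import asyncio
from api.src.core.paths import load_json, load_model_weights  # assumed import path

async def main() -> None:
    config = await load_json("api/src/builds/config.json")
    weights = await load_model_weights("api/src/models/kokoro-v0_19.pth", device="cpu")
    print(sorted(config)[:5], list(weights["net"])[:3])

asyncio.run(main())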
@@ -1,4 +1,4 @@
-"""Inference backends and model management."""
+"""Model inference package."""
 
 from .base import BaseModelBackend
 from .model_manager import ModelManager, get_manager
@@ -6,15 +6,13 @@ from .onnx_cpu import ONNXCPUBackend
 from .onnx_gpu import ONNXGPUBackend
 from .pytorch_cpu import PyTorchCPUBackend
 from .pytorch_gpu import PyTorchGPUBackend
-from ..structures.model_schemas import ModelConfig
 
 __all__ = [
     'BaseModelBackend',
     'ModelManager',
     'get_manager',
-    'ModelConfig',
     'ONNXCPUBackend',
     'ONNXGPUBackend',
     'PyTorchCPUBackend',
-    'PyTorchGPUBackend'
+    'PyTorchGPUBackend',
 ]
@@ -1,21 +1,18 @@
 """Model management and caching."""
 
-import os
-from typing import Dict, List, Optional, Union
+from typing import Dict, Optional
 
 import torch
 from loguru import logger
-from pydantic import BaseModel
 
+from ..core import paths
+from ..core.config import settings
+from ..core.model_config import ModelConfig, model_config
 from .base import BaseModelBackend
-from .voice_manager import get_manager as get_voice_manager
 from .onnx_cpu import ONNXCPUBackend
 from .onnx_gpu import ONNXGPUBackend
 from .pytorch_cpu import PyTorchCPUBackend
 from .pytorch_gpu import PyTorchGPUBackend
-from ..core import paths
-from ..core.config import settings
-from ..structures.model_schemas import ModelConfig
 
 
 class ModelManager:
@@ -27,44 +24,63 @@ class ModelManager:
         Args:
             config: Optional configuration
         """
-        self._config = config or ModelConfig()
+        self._config = config or model_config
         self._backends: Dict[str, BaseModelBackend] = {}
         self._current_backend: Optional[str] = None
-        self._voice_manager = get_voice_manager()
         self._initialize_backends()
 
     def _initialize_backends(self) -> None:
-        """Initialize available backends."""
-        # Initialize GPU backends if available
-        if settings.use_gpu and torch.cuda.is_available():
-            try:
-                # PyTorch GPU
-                self._backends['pytorch_gpu'] = PyTorchGPUBackend()
-                self._current_backend = 'pytorch_gpu'
-                logger.info("Initialized PyTorch GPU backend")
-
-                # ONNX GPU
-                self._backends['onnx_gpu'] = ONNXGPUBackend()
-                logger.info("Initialized ONNX GPU backend")
-            except Exception as e:
-                logger.error(f"Failed to initialize GPU backends: {e}")
-                # Fallback to CPU if GPU fails
-                self._initialize_cpu_backends()
-        else:
-            self._initialize_cpu_backends()
+        """Initialize available backends based on settings."""
+        has_gpu = settings.use_gpu and torch.cuda.is_available()
+
+        try:
+            if has_gpu:
+                if settings.use_onnx:
+                    # ONNX GPU primary
+                    self._backends['onnx_gpu'] = ONNXGPUBackend()
+                    self._current_backend = 'onnx_gpu'
+                    logger.info("Initialized ONNX GPU backend")
+
+                    # PyTorch GPU fallback
+                    self._backends['pytorch_gpu'] = PyTorchGPUBackend()
+                    logger.info("Initialized PyTorch GPU backend")
+                else:
+                    # PyTorch GPU primary
+                    self._backends['pytorch_gpu'] = PyTorchGPUBackend()
+                    self._current_backend = 'pytorch_gpu'
+                    logger.info("Initialized PyTorch GPU backend")
+
+                    # ONNX GPU fallback
+                    self._backends['onnx_gpu'] = ONNXGPUBackend()
+                    logger.info("Initialized ONNX GPU backend")
+            else:
+                self._initialize_cpu_backends()
+        except Exception as e:
+            logger.error(f"Failed to initialize GPU backends: {e}")
+            # Fallback to CPU if GPU fails
+            self._initialize_cpu_backends()
 
     def _initialize_cpu_backends(self) -> None:
-        """Initialize CPU backends."""
+        """Initialize CPU backends based on settings."""
         try:
-            # PyTorch CPU
-            self._backends['pytorch_cpu'] = PyTorchCPUBackend()
-            self._current_backend = 'pytorch_cpu'
-            logger.info("Initialized PyTorch CPU backend")
-
-            # ONNX CPU
-            self._backends['onnx_cpu'] = ONNXCPUBackend()
-            logger.info("Initialized ONNX CPU backend")
+            if settings.use_onnx:
+                # ONNX CPU primary
+                self._backends['onnx_cpu'] = ONNXCPUBackend()
+                self._current_backend = 'onnx_cpu'
+                logger.info("Initialized ONNX CPU backend")
+
+                # PyTorch CPU fallback
+                self._backends['pytorch_cpu'] = PyTorchCPUBackend()
+                logger.info("Initialized PyTorch CPU backend")
+            else:
+                # PyTorch CPU primary
+                self._backends['pytorch_cpu'] = PyTorchCPUBackend()
+                self._current_backend = 'pytorch_cpu'
+                logger.info("Initialized PyTorch CPU backend")
+
+                # ONNX CPU fallback
+                self._backends['onnx_cpu'] = ONNXCPUBackend()
+                logger.info("Initialized ONNX CPU backend")
         except Exception as e:
            logger.error(f"Failed to initialize CPU backends: {e}")
            raise RuntimeError("No backends available")
@@ -98,7 +114,7 @@ class ModelManager:
         return self._backends[backend_type]
 
     def _determine_backend(self, model_path: str) -> str:
-        """Determine appropriate backend based on model file.
+        """Determine appropriate backend based on model file and settings.
 
         Args:
             model_path: Path to model file
@@ -106,10 +122,10 @@ class ModelManager:
         Returns:
             Backend type to use
         """
-        is_onnx = model_path.lower().endswith('.onnx')
         has_gpu = settings.use_gpu and torch.cuda.is_available()
 
-        if is_onnx:
+        # If ONNX is preferred or model is ONNX format
+        if settings.use_onnx or model_path.lower().endswith('.onnx'):
             return 'onnx_gpu' if has_gpu else 'onnx_cpu'
         else:
             return 'pytorch_gpu' if has_gpu else 'pytorch_cpu'
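The resulting selection rule, written out as a standalone sketch (same logic as _determine_backend above, with the inputs made explicit):

def pick_backend(use_onnx: bool, model_path: str, has_gpu: bool) -> str:
    """Mirror of the backend choice: ONNX wins if preferred or if the file is .onnx."""
    if use_onnx or model_path.lower().endswith(".onnx"):
        return "onnx_gpu" if has_gpu else "onnx_cpu"
    return "pytorch_gpu" if has_gpu else "pytorch_cpu"

assert pick_backend(True, "kokoro-v0_19.onnx", has_gpu=False) == "onnx_cpu"
assert pick_backend(False, "kokoro-v0_19.pth", has_gpu=True) == "pytorch_gpu"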
@@ -117,12 +133,14 @@ class ModelManager:
     async def load_model(
         self,
         model_path: str,
+        warmup_voice: Optional[torch.Tensor] = None,
         backend_type: Optional[str] = None
     ) -> None:
         """Load model on specified backend.
 
         Args:
             model_path: Path to model file
+            warmup_voice: Optional voice tensor for warmup, skips warmup if None
             backend_type: Backend to load on, uses default if None
 
         Raises:
@@ -138,35 +156,39 @@ class ModelManager:
 
             backend = self.get_backend(backend_type)
 
-            # Load model and run warmup
+            # Load model
             await backend.load_model(abs_path)
             logger.info(f"Loaded model on {backend_type} backend")
-            await self._warmup_inference(backend)
+
+            # Run warmup if voice provided
+            if warmup_voice is not None:
+                await self._warmup_inference(backend, warmup_voice)
 
         except Exception as e:
             raise RuntimeError(f"Failed to load model: {e}")
 
-    async def _warmup_inference(self, backend: BaseModelBackend) -> None:
-        """Run warmup inference to initialize model."""
+    async def _warmup_inference(self, backend: BaseModelBackend, voice: torch.Tensor) -> None:
+        """Run warmup inference to initialize model.
+
+        Args:
+            backend: Model backend to warm up
+            voice: Voice tensor already loaded on correct device
+        """
         try:
             # Import here to avoid circular imports
-            from ..text_processing import process_text
+            from ..services.text_processing import process_text
 
-            # Load default voice for warmup
-            voice = await self._voice_manager.load_voice(settings.default_voice, device=backend.device)
-            logger.info(f"Loaded voice {settings.default_voice} for warmup")
-
             # Use real text
             text = "Testing text to speech synthesis."
-            logger.info(f"Running warmup inference with voice: af")
 
             # Process through pipeline
-            sequences = process_text(text)
-            if not sequences:
+            tokens = process_text(text)
+            if not tokens:
                 raise ValueError("Text processing failed")
 
             # Run inference
-            backend.generate(sequences[0], voice, speed=1.0)
+            backend.generate(tokens, voice, speed=1.0)
+            logger.info("Completed warmup inference")
 
         except Exception as e:
             logger.warning(f"Warmup inference failed: {e}")
@@ -175,7 +197,7 @@ class ModelManager:
     async def generate(
         self,
         tokens: list[int],
-        voice_name: str,
+        voice: torch.Tensor,
         speed: float = 1.0,
         backend_type: Optional[str] = None
     ) -> torch.Tensor:
@@ -183,7 +205,7 @@ class ModelManager:
 
         Args:
             tokens: Input token IDs
-            voice_name: Name of voice to use
+            voice: Voice tensor already loaded on correct device
             speed: Speed multiplier
             backend_type: Backend to use, uses default if None
 
@@ -198,10 +220,7 @@ class ModelManager:
             raise RuntimeError("Model not loaded")
 
         try:
-            # Load voice using voice manager
-            voice = await self._voice_manager.load_voice(voice_name, device=backend.device)
-
-            # Generate audio
+            # Generate audio using provided voice tensor
             return backend.generate(tokens, voice, speed)
 
         except Exception as e:
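Callers now own voice loading. A hedged sketch of the new call sequence, assuming an already-initialized manager and voice manager (the function is illustrative, not part of this commit):

import torch
from api.src.inference import ModelManager  # exported by the package __init__ above

async def synthesize(manager: ModelManager, voices, tokens: list[int]) -> torch.Tensor:
    """Sketch: load the voice tensor onto the backend's device, then generate."""
    backend = manager.get_backend()  # assumes a model was already loaded via load_model()
    voice = await voices.load_voice("af", device=backend.device)
    return await manager.generate(tokens, voice, speed=1.0)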
@@ -13,8 +13,7 @@ from onnxruntime import (
 )
 
 from ..core import paths
-from ..core.config import settings
-from ..structures.model_schemas import ONNXConfig
+from ..core.model_config import model_config
 from .base import BaseModelBackend
 
 
@@ -26,14 +25,11 @@ class ONNXCPUBackend(BaseModelBackend):
         super().__init__()
         self._device = "cpu"
         self._session: Optional[InferenceSession] = None
-        self._config = ONNXConfig(
-            optimization_level=settings.onnx_optimization_level,
-            num_threads=settings.onnx_num_threads,
-            inter_op_threads=settings.onnx_inter_op_threads,
-            execution_mode=settings.onnx_execution_mode,
-            memory_pattern=settings.onnx_memory_pattern,
-            arena_extend_strategy=settings.onnx_arena_extend_strategy
-        )
+
+    @property
+    def is_loaded(self) -> bool:
+        """Check if model is loaded."""
+        return self._session is not None
 
     async def load_model(self, path: str) -> None:
         """Load ONNX model.
@@ -115,28 +111,29 @@ class ONNXCPUBackend(BaseModelBackend):
         Configured session options
         """
         options = SessionOptions()
+        config = model_config.onnx_cpu
 
         # Set optimization level
-        if self._config.optimization_level == "all":
+        if config.optimization_level == "all":
             options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
-        elif self._config.optimization_level == "basic":
+        elif config.optimization_level == "basic":
             options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
         else:
             options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
 
         # Configure threading
-        options.intra_op_num_threads = self._config.num_threads
-        options.inter_op_num_threads = self._config.inter_op_threads
+        options.intra_op_num_threads = config.num_threads
+        options.inter_op_num_threads = config.inter_op_threads
 
         # Set execution mode
         options.execution_mode = (
             ExecutionMode.ORT_PARALLEL
-            if self._config.execution_mode == "parallel"
+            if config.execution_mode == "parallel"
             else ExecutionMode.ORT_SEQUENTIAL
         )
 
         # Configure memory optimization
-        options.enable_mem_pattern = self._config.memory_pattern
+        options.enable_mem_pattern = config.memory_pattern
 
         return options
 
@@ -148,7 +145,15 @@ class ONNXCPUBackend(BaseModelBackend):
         """
         return {
             "CPUExecutionProvider": {
-                "arena_extend_strategy": self._config.arena_extend_strategy,
+                "arena_extend_strategy": model_config.onnx_cpu.arena_extend_strategy,
                 "cpu_memory_arena_cfg": "cpu:0"
             }
         }
+
+    def unload(self) -> None:
+        """Unload model and free resources."""
+        if self._session is not None:
+            del self._session
+            self._session = None
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
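For reference, a hedged sketch of how ONNX Runtime consumes options like the ones assembled above — this is illustrative glue, not code from this commit, and the model path is assumed:

from onnxruntime import (
    ExecutionMode,
    GraphOptimizationLevel,
    InferenceSession,
    SessionOptions,
)

options = SessionOptions()
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
options.intra_op_num_threads = 8
options.inter_op_num_threads = 4
options.execution_mode = ExecutionMode.ORT_PARALLEL
options.enable_mem_pattern = True

session = InferenceSession(
    "api/src/models/kokoro-v0_19.onnx",
    sess_options=options,
    providers=["CPUExecutionProvider"],
    provider_options=[{"arena_extend_strategy": "kNextPowerOfTwo"}],
)
print([i.name for i in session.get_inputs()])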
@@ -13,8 +13,7 @@ from onnxruntime import (
 )
 
 from ..core import paths
-from ..core.config import settings
-from ..structures.model_schemas import ONNXGPUConfig
+from ..core.model_config import model_config
 from .base import BaseModelBackend
 
 
@@ -28,18 +27,11 @@ class ONNXGPUBackend(BaseModelBackend):
             raise RuntimeError("CUDA not available")
         self._device = "cuda"
         self._session: Optional[InferenceSession] = None
-        self._config = ONNXGPUConfig(
-            optimization_level=settings.onnx_optimization_level,
-            num_threads=settings.onnx_num_threads,
-            inter_op_threads=settings.onnx_inter_op_threads,
-            execution_mode=settings.onnx_execution_mode,
-            memory_pattern=settings.onnx_memory_pattern,
-            arena_extend_strategy=settings.onnx_arena_extend_strategy,
-            device_id=0,
-            gpu_mem_limit=0.7,
-            cudnn_conv_algo_search="EXHAUSTIVE",
-            do_copy_in_default_stream=True
-        )
+
+    @property
+    def is_loaded(self) -> bool:
+        """Check if model is loaded."""
+        return self._session is not None
 
     async def load_model(self, path: str) -> None:
         """Load ONNX model.
@@ -121,28 +113,29 @@ class ONNXGPUBackend(BaseModelBackend):
         Configured session options
         """
         options = SessionOptions()
+        config = model_config.onnx_gpu
 
         # Set optimization level
-        if self._config.optimization_level == "all":
+        if config.optimization_level == "all":
             options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
-        elif self._config.optimization_level == "basic":
+        elif config.optimization_level == "basic":
             options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
         else:
             options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
 
         # Configure threading
-        options.intra_op_num_threads = self._config.num_threads
-        options.inter_op_num_threads = self._config.inter_op_threads
+        options.intra_op_num_threads = config.num_threads
+        options.inter_op_num_threads = config.inter_op_threads
 
         # Set execution mode
         options.execution_mode = (
             ExecutionMode.ORT_PARALLEL
-            if self._config.execution_mode == "parallel"
+            if config.execution_mode == "parallel"
             else ExecutionMode.ORT_SEQUENTIAL
         )
 
         # Configure memory optimization
-        options.enable_mem_pattern = self._config.memory_pattern
+        options.enable_mem_pattern = config.memory_pattern
 
         return options
 
@@ -152,12 +145,21 @@ class ONNXGPUBackend(BaseModelBackend):
         Returns:
             Provider configuration
         """
+        config = model_config.onnx_gpu
         return {
             "CUDAExecutionProvider": {
-                "device_id": self._config.device_id,
-                "arena_extend_strategy": self._config.arena_extend_strategy,
-                "gpu_mem_limit": int(self._config.gpu_mem_limit * torch.cuda.get_device_properties(0).total_memory),
-                "cudnn_conv_algo_search": self._config.cudnn_conv_algo_search,
-                "do_copy_in_default_stream": self._config.do_copy_in_default_stream
+                "device_id": config.device_id,
+                "arena_extend_strategy": config.arena_extend_strategy,
+                "gpu_mem_limit": int(config.gpu_mem_limit * torch.cuda.get_device_properties(0).total_memory),
+                "cudnn_conv_algo_search": config.cudnn_conv_algo_search,
+                "do_copy_in_default_stream": config.do_copy_in_default_stream
             }
         }
+
+    def unload(self) -> None:
+        """Unload model and free resources."""
+        if self._session is not None:
+            del self._session
+            self._session = None
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
@@ -9,7 +9,7 @@ from loguru import logger
 
 from ..builds.models import build_model
 from ..core import paths
-from ..structures.model_schemas import PyTorchCPUConfig
+from ..core.model_config import model_config
 from .base import BaseModelBackend
 
 
@@ -118,12 +118,12 @@ class PyTorchCPUBackend(BaseModelBackend):
         super().__init__()
         self._device = "cpu"
         self._model: Optional[torch.nn.Module] = None
-        self._config = PyTorchCPUConfig()
 
         # Configure PyTorch CPU settings
-        if self._config.num_threads > 0:
-            torch.set_num_threads(self._config.num_threads)
-        if self._config.pin_memory:
+        config = model_config.pytorch_cpu
+        if config.num_threads > 0:
+            torch.set_num_threads(config.num_threads)
+        if config.pin_memory:
             torch.set_default_tensor_type(torch.FloatTensor)
 
     async def load_model(self, path: str) -> None:
@@ -9,7 +9,7 @@ from loguru import logger
 
 from ..builds.models import build_model
 from ..core import paths
-from ..structures.model_schemas import PyTorchConfig
+from ..core.model_config import model_config
 from .base import BaseModelBackend
 
 
@@ -96,7 +96,12 @@ class PyTorchGPUBackend(BaseModelBackend):
             raise RuntimeError("CUDA not available")
         self._device = "cuda"
         self._model: Optional[torch.nn.Module] = None
-        self._config = PyTorchConfig()
+
+        # Configure GPU settings
+        config = model_config.pytorch_gpu
+        if config.sync_cuda:
+            torch.cuda.synchronize()
+        torch.cuda.set_device(config.device_id)
 
     async def load_model(self, path: str) -> None:
         """Load PyTorch model.
@@ -154,13 +159,19 @@ class PyTorchGPUBackend(BaseModelBackend):
 
         except Exception as e:
             logger.error(f"Generation failed: {e}")
+            if model_config.pytorch_gpu.retry_on_oom and "out of memory" in str(e).lower():
+                self._clear_memory()
+                return self.generate(tokens, voice, speed)  # Retry once
             raise
+        finally:
+            if model_config.pytorch_gpu.sync_cuda:
+                torch.cuda.synchronize()
 
     def _check_memory(self) -> bool:
         """Check if memory usage is above threshold."""
         if torch.cuda.is_available():
             memory_gb = torch.cuda.memory_allocated() / 1e9
-            return memory_gb > self._config.memory_threshold
+            return memory_gb > model_config.pytorch_gpu.memory_threshold
         return False
 
     def _clear_memory(self) -> None:
@@ -33,7 +33,15 @@ class VoiceManager:
         Returns:
             Path to voice file if exists, None otherwise
         """
-        voice_path = os.path.join(settings.voices_dir, f"{voice_name}.pt")
+        # Get api directory path (two levels up from inference)
+        api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
+
+        # Construct voice path relative to api directory
+        voice_path = os.path.join(api_dir, settings.voices_dir, f"{voice_name}.pt")
+
+        # Ensure voices directory exists
+        os.makedirs(os.path.dirname(voice_path), exist_ok=True)
+
         return voice_path if os.path.exists(voice_path) else None
 
     async def load_voice(self, voice_name: str, device: str = "cpu") -> torch.Tensor:
@@ -112,8 +120,15 @@ class VoiceManager:
         combined_name = "_".join(voices)
         combined_tensor = torch.mean(torch.stack(voice_tensors), dim=0)
 
+        # Get api directory path
+        api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
+        voices_dir = os.path.join(api_dir, settings.voices_dir)
+
+        # Ensure voices directory exists
+        os.makedirs(voices_dir, exist_ok=True)
+
         # Save combined voice
-        combined_path = os.path.join(settings.voices_dir, f"{combined_name}.pt")
+        combined_path = os.path.join(voices_dir, f"{combined_name}.pt")
         try:
             torch.save(combined_tensor, combined_path)
         except Exception as e:
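Voice combination stays a simple element-wise mean of the stacked voice tensors. A standalone sketch of the same operation (file and voice names are illustrative):

import torch

voice_paths = ["api/src/voices/af.pt", "api/src/voices/am_adam.pt"]  # illustrative
tensors = [torch.load(p, map_location="cpu", weights_only=True) for p in voice_paths]

# Same math as combine_voices above: stack along a new dim, then average it away
combined = torch.mean(torch.stack(tensors), dim=0)
torch.save(combined, "api/src/voices/af_am_adam.pt")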
@@ -132,7 +147,15 @@ class VoiceManager:
         """
         voices = []
         try:
-            for entry in os.listdir(settings.voices_dir):
+            # Get api directory path
+            api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
+            voices_dir = os.path.join(api_dir, settings.voices_dir)
+
+            # Ensure voices directory exists
+            os.makedirs(voices_dir, exist_ok=True)
+
+            # List voice files
+            for entry in os.listdir(voices_dir):
                 if entry.endswith(".pt"):
                     voices.append(entry[:-3])  # Remove .pt extension
         except Exception as e:
@@ -13,7 +13,6 @@ from loguru import logger
 from .core.config import settings
 from .routers.development import router as dev_router
 from .routers.openai_compatible import router as openai_router
-from .services.tts_model import TTSModel
 from .services.tts_service import TTSService
 
 
@@ -44,25 +43,32 @@ async def lifespan(app: FastAPI):
     """Lifespan context manager for model initialization"""
     logger.info("Loading TTS model and voice packs...")
 
-    # Initialize the main model with warm-up
-    voicepack_count = await TTSModel.setup()
-    # boundary = "█████╗"*9
+    # Initialize service
+    service = TTSService()
+    await service.ensure_initialized()
+
+    # Get available voices
+    voices = await service.list_voices()
+    voicepack_count = len(voices)
+
+    # Get device info from model manager
+    device = "GPU" if settings.use_gpu else "CPU"
+    model = "ONNX" if settings.use_onnx else "PyTorch"
     boundary = "░" * 2*12
     startup_msg = f"""
 
{boundary}
 
    ╔═╗┌─┐┌─┐┌┬┐
    ╠╣ ├─┤└─┐ │
    ╚  ┴ ┴└─┘ ┴
    ╦╔═┌─┐┬┌─┌─┐
    ╠╩╗│ │├┴┐│ │
    ╩ ╩└─┘┴ ┴└─┘
 
{boundary}
    """
-    # TODO: Improve CPU warmup, threads, memory, etc
-    startup_msg += f"\nModel warmed up on {TTSModel.get_device()}"
+    startup_msg += f"\nModel warmed up on {device}: {model}"
     startup_msg += f"\n{voicepack_count} voice packs loaded\n"
     startup_msg += f"\n{boundary}\n"
     logger.info(startup_msg)
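The lifespan body above reduces to this startup flow; as a hedged standalone sketch (the absolute import path is assumed):

import asyncio
from api.src.services.tts_service import TTSService  # imported as .services.tts_service in main.py above

async def startup() -> None:
    service = TTSService()
    await service.ensure_initialized()   # loads the model and runs the warmup voice
    voices = await service.list_voices()
    print(f"{len(voices)} voice packs loaded")

asyncio.run(startup())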
@@ -6,7 +6,6 @@ from loguru import logger
 
 from ..services.audio import AudioService
 from ..services.text_processing import phonemize, tokenize
-from ..services.tts_model import TTSModel
 from ..services.tts_service import TTSService
 from ..structures.text_schemas import (
     GenerateFromPhonemesRequest,
@@ -82,27 +81,34 @@ async def generate_from_phonemes(
             detail={"error": "Invalid request", "message": "Phonemes cannot be empty"},
         )
 
-    # Validate voice exists
-    voice_path = tts_service._get_voice_path(request.voice)
-    if not voice_path:
-        raise HTTPException(
-            status_code=400,
-            detail={
-                "error": "Invalid request",
-                "message": f"Voice not found: {request.voice}",
-            },
-        )
-
     try:
-        # Load voice
-        voicepack = tts_service._load_voice(voice_path)
+        # Ensure service is initialized
+        await tts_service.ensure_initialized()
+
+        # Validate voice exists
+        available_voices = await tts_service.list_voices()
+        if request.voice not in available_voices:
+            raise HTTPException(
+                status_code=400,
+                detail={
+                    "error": "Invalid request",
+                    "message": f"Voice not found: {request.voice}",
+                },
+            )
 
         # Convert phonemes to tokens
         tokens = tokenize(request.phonemes)
         tokens = [0] + tokens + [0]  # Add start/end tokens
 
         # Generate audio directly from tokens
-        audio = TTSModel.generate_from_tokens(tokens, voicepack, request.speed)
+        audio = await tts_service.model_manager.generate(
+            tokens,
+            request.voice,
+            speed=request.speed
+        )
+
+        if audio is None:
+            raise ValueError("Failed to generate audio")
 
         # Convert to WAV bytes
         wav_bytes = AudioService.convert_audio(
@@ -1,13 +1,28 @@
-from .normalizer import normalize_text
-from .phonemizer import EspeakBackend, PhonemizerBackend, phonemize
-from .vocabulary import VOCAB, decode_tokens, tokenize
-
-__all__ = [
-    "normalize_text",
-    "phonemize",
-    "tokenize",
-    "decode_tokens",
-    "VOCAB",
-    "PhonemizerBackend",
-    "EspeakBackend",
-]
+"""Text processing pipeline."""
+
+from .chunker import split_text
+from .normalizer import normalize_text
+from .phonemizer import phonemize
+from .vocabulary import tokenize
+
+
+def process_text(text: str, language: str = "a") -> list[int]:
+    """Process text through the full pipeline.
+
+    Args:
+        text: Input text
+        language: Language code ('a' for US English, 'b' for British English)
+
+    Returns:
+        List of token IDs
+
+    Note:
+        The pipeline:
+        1. Converts text to phonemes using phonemizer
+        2. Converts phonemes to token IDs using vocabulary
+    """
+    # Convert text to phonemes
+    phonemes = phonemize(text, language=language)
+
+    # Convert phonemes to token IDs
+    return tokenize(phonemes)
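A quick usage sketch of the new pipeline entry point (language codes per the docstring above; the absolute import path is an assumption):

from api.src.services.text_processing import process_text  # assumed absolute import path

tokens = process_text("Testing text to speech synthesis.", language="a")  # 'a' = US English
print(len(tokens), tokens[:10])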
@@ -5,17 +5,16 @@ import os
 import time
 from typing import List, Tuple
 
-import torch
-
 import numpy as np
 import scipy.io.wavfile as wavfile
+import torch
 from loguru import logger
 
 from ..core.config import settings
 from ..inference.model_manager import get_manager as get_model_manager
 from ..inference.voice_manager import get_manager as get_voice_manager
 from .audio import AudioNormalizer, AudioService
-from .text_processing import chunker, normalize_text
+from .text_processing import chunker, normalize_text, process_text
 
 
 class TTSService:
@@ -41,16 +40,33 @@ class TTSService:
             raise self._initialization_error
 
         try:
-            # Determine model path based on hardware
+            # Get api directory path (one level up from src)
+            api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
+
+            # Determine model file and backend based on hardware
             if settings.use_gpu and torch.cuda.is_available():
-                model_path = os.path.join(settings.model_dir, settings.pytorch_model_path)
+                model_file = settings.pytorch_model_file
                 backend_type = 'pytorch_gpu'
             else:
-                model_path = os.path.join(settings.model_dir, settings.onnx_model_path)
+                model_file = settings.onnx_model_file
                 backend_type = 'onnx_cpu'
 
-            # Initialize model
-            await self.model_manager.load_model(model_path, backend_type)
+            # Construct model path relative to api directory
+            model_path = os.path.join(api_dir, settings.model_dir, model_file)
+
+            # Ensure model directory exists
+            os.makedirs(os.path.dirname(model_path), exist_ok=True)
+
+            if not os.path.exists(model_path):
+                raise RuntimeError(f"Model file not found: {model_path}")
+
+            # Load default voice for warmup
+            backend = self.model_manager.get_backend(backend_type)
+            warmup_voice = await self.voice_manager.load_voice(settings.default_voice, device=backend.device)
+            logger.info(f"Loaded voice {settings.default_voice} for warmup")
+
+            # Initialize model with warmup voice
+            await self.model_manager.load_model(model_path, warmup_voice, backend_type)
             logger.info(f"Initialized model on {backend_type} backend")
             self._initialized = True
@@ -86,16 +102,19 @@ class TTSService:
         audio_chunks = []
         for chunk in chunker.split_text(text):
             try:
-                # Process text
-                sequences = process_text(chunk)
-                if not sequences:
+                # Convert chunk to token IDs
+                tokens = process_text(chunk)
+                if not tokens:
                     continue
 
+                # Get backend and load voice
+                backend = self.model_manager.get_backend()
+                voice_tensor = await self.voice_manager.load_voice(voice, device=backend.device)
+
                 # Generate audio
                 chunk_audio = await self.model_manager.generate(
-                    sequences[0],
-                    voice,
+                    tokens,
+                    voice_tensor,
                     speed=speed
                 )
                 if chunk_audio is not None:
@@ -154,14 +173,17 @@ class TTSService:
         while current_chunk is not None:
             next_chunk = next(chunk_gen, None)
             try:
-                # Process text
-                from ..text_processing import process_text
-                sequences = process_text(current_chunk)
-                if sequences:
+                # Convert chunk to token IDs
+                tokens = process_text(current_chunk)
+                if tokens:
+                    # Get backend and load voice
+                    backend = self.model_manager.get_backend()
+                    voice_tensor = await self.voice_manager.load_voice(voice, device=backend.device)
+
                     # Generate audio
                     chunk_audio = await self.model_manager.generate(
-                        sequences[0],
-                        voice,
+                        tokens,
+                        voice_tensor,
                         speed=speed
                     )
@@ -1,26 +1,13 @@
-"""Model and voice configuration schemas."""
+"""Voice configuration schemas."""
 
-from pydantic import BaseModel
-
-
-class ModelConfig(BaseModel):
-    """Model configuration."""
-    optimization_level: str = "all"  # all, basic, none
-    num_threads: int = 4
-    inter_op_threads: int = 4
-    execution_mode: str = "parallel"  # parallel, sequential
-    memory_pattern: bool = True
-    arena_extend_strategy: str = "kNextPowerOfTwo"
-
-    class Config:
-        frozen = True  # Make config immutable
+from pydantic import BaseModel, Field
 
 
 class VoiceConfig(BaseModel):
     """Voice configuration."""
-    use_cache: bool = True
-    cache_size: int = 3  # Number of voices to cache
-    validate_on_load: bool = True  # Whether to validate voices when loading
+    use_cache: bool = Field(True, description="Whether to cache loaded voices")
+    cache_size: int = Field(3, description="Number of voices to cache")
+    validate_on_load: bool = Field(True, description="Whether to validate voices when loading")
 
     class Config:
         frozen = True  # Make config immutable

@@ -17,29 +17,10 @@ COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
 RUN useradd -m -u 1000 appuser
 
 # Create directories and set ownership
-RUN mkdir -p /app/models && \
-    mkdir -p /app/api/src/voices && \
+RUN mkdir -p /app/api/src/voices && \
     chown -R appuser:appuser /app
 
 USER appuser
-
-# Download and extract models
-WORKDIR /app/models
-RUN set -x && \
-    curl -L -o model.tar.gz https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.0.1/kokoro-82m-onnx.tar.gz && \
-    echo "Downloaded model.tar.gz:" && ls -lh model.tar.gz && \
-    tar xzf model.tar.gz && \
-    echo "Contents after extraction:" && ls -lhR && \
-    rm model.tar.gz && \
-    echo "Final contents:" && ls -lhR
-
-# Download and extract voice models
-WORKDIR /app/api/src/voices
-RUN curl -L -o voices.tar.gz https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.0.1/voice-models.tar.gz && \
-    tar xzf voices.tar.gz && \
-    rm voices.tar.gz
-
-# Switch back to app directory
 WORKDIR /app
 
 # Copy dependency files

@@ -59,9 +40,10 @@ RUN --mount=type=cache,target=/root/.cache/uv \
 
 # Set environment variables
 ENV PYTHONUNBUFFERED=1
-ENV PYTHONPATH=/app:/app/models
+ENV PYTHONPATH=/app
 ENV PATH="/app/.venv/bin:$PATH"
 ENV UV_LINK_MODE=copy
+ENV USE_GPU=false
 
 # Run FastAPI server
 CMD ["uv", "run", "python", "-m", "uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]

@@ -6,12 +6,11 @@ services:
       context: ../..
       dockerfile: docker/cpu/Dockerfile
     volumes:
-      - ../../api/src:/app/api/src
-      - ../../api/src/voices:/app/api/src/voices
+      - ../../api:/app/api
     ports:
       - "8880:8880"
     environment:
-      - PYTHONPATH=/app:/app/models
+      - PYTHONPATH=/app:/app/api
       # ONNX Optimization Settings for vectorized operations
       - ONNX_NUM_THREADS=8  # Maximize core usage for vectorized ops
       - ONNX_INTER_OP_THREADS=4  # Higher inter-op for parallel matrix operations

@@ -20,20 +19,20 @@ services:
       - ONNX_MEMORY_PATTERN=true
       - ONNX_ARENA_EXTEND_STRATEGY=kNextPowerOfTwo
 
-  # Gradio UI service [Comment out everything below if you don't need it]
-  gradio-ui:
-    image: ghcr.io/remsky/kokoro-fastapi-ui:v0.1.0
-    # Uncomment below (and comment out above) to build from source instead of using the released image
-    build:
-      context: ../../ui
-    ports:
-      - "7860:7860"
-    volumes:
-      - ../../ui/data:/app/ui/data
-      - ../../ui/app.py:/app/app.py  # Mount app.py for hot reload
-    environment:
-      - GRADIO_WATCH=True  # Enable hot reloading
-      - PYTHONUNBUFFERED=1  # Ensure Python output is not buffered
-      - DISABLE_LOCAL_SAVING=false  # Set to 'true' to disable local saving and hide file view
-      - API_HOST=kokoro-tts  # Set TTS service URL
-      - API_PORT=8880  # Set TTS service PORT
+  # # Gradio UI service [Comment out everything below if you don't need it]
+  # gradio-ui:
+  #   image: ghcr.io/remsky/kokoro-fastapi-ui:v0.1.0
+  #   # Uncomment below (and comment out above) to build from source instead of using the released image
+  #   build:
+  #     context: ../../ui
+  #   ports:
+  #     - "7860:7860"
+  #   volumes:
+  #     - ../../ui/data:/app/ui/data
+  #     - ../../ui/app.py:/app/app.py  # Mount app.py for hot reload
+  #   environment:
+  #     - GRADIO_WATCH=True  # Enable hot reloading
+  #     - PYTHONUNBUFFERED=1  # Ensure Python output is not buffered
+  #     - DISABLE_LOCAL_SAVING=false  # Set to 'true' to disable local saving and hide file view
+  #     - API_HOST=kokoro-tts  # Set TTS service URL
+  #     - API_PORT=8880  # Set TTS service PORT

docker/cpu/download_onnx.py (new executable file, 53 lines)
@@ -0,0 +1,53 @@
+#!/usr/bin/env python3
+import os
+import sys
+import requests
+from pathlib import Path
+from typing import List
+
+def download_file(url: str, output_dir: Path) -> None:
+    """Download a file from URL to the specified directory."""
+    filename = os.path.basename(url)
+    output_path = output_dir / filename
+
+    print(f"Downloading {filename}...")
+    response = requests.get(url, stream=True)
+    response.raise_for_status()
+
+    with open(output_path, 'wb') as f:
+        for chunk in response.iter_content(chunk_size=8192):
+            f.write(chunk)
+
+def find_project_root() -> Path:
+    """Find project root by looking for api directory."""
+    max_steps = 5
+    current = Path(__file__).resolve()
+    for _ in range(max_steps):
+        if (current / 'api').is_dir():
+            return current
+        current = current.parent
+    raise RuntimeError("Could not find project root (no api directory found)")
+
+def main(custom_models: List[str] = None):
+    # Always use top-level models directory relative to project root
+    project_root = find_project_root()
+    models_dir = project_root / 'api' / 'src' / 'models'
+    models_dir.mkdir(exist_ok=True, parents=True)
+
+    # Default ONNX model if no arguments provided
+    default_models = [
+        "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.onnx",
+        "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19_fp16.onnx"
+    ]
+
+    # Use provided models or default
+    models_to_download = custom_models if custom_models else default_models
+
+    for model_url in models_to_download:
+        try:
+            download_file(model_url, models_dir)
+        except Exception as e:
+            print(f"Error downloading {model_url}: {e}")
+
+if __name__ == "__main__":
+    main(sys.argv[1:] if len(sys.argv) > 1 else None)
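
A small invocation sketch (not part of the commit): running the downloader above from the repository root. The URL passed here is one of the script's own defaults, so listing it explicitly is only illustrative; any other URL would be an assumption.

    # Hypothetical call of docker/cpu/download_onnx.py from the repo root;
    # assumes `python` is on PATH and the `requests` dependency is installed.
    import subprocess

    subprocess.run(
        [
            "python",
            "docker/cpu/download_onnx.py",
            "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.onnx",
        ],
        check=True,  # raise CalledProcessError if the download script exits non-zero
    )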

docker/cpu/download_onnx.sh (new executable file, 32 lines)
@@ -0,0 +1,32 @@
+#!/bin/bash
+
+# Ensure models directory exists
+mkdir -p api/src/models
+
+# Function to download a file
+download_file() {
+    local url="$1"
+    local filename=$(basename "$url")
+    echo "Downloading $filename..."
+    curl -L "$url" -o "api/src/models/$filename"
+}
+
+# Default ONNX model if no arguments provided
+DEFAULT_MODELS=(
+    "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.onnx"
+    "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19_fp16.onnx"
+)
+
+# Use provided models or default
+if [ $# -gt 0 ]; then
+    MODELS=("$@")
+else
+    MODELS=("${DEFAULT_MODELS[@]}")
+fi
+
+# Download all models
+for model in "${MODELS[@]}"; do
+    download_file "$model"
+done
+
+echo "ONNX model download complete!"

@@ -1,9 +1,8 @@
-FROM nvidia/cuda:12.1.0-base-ubuntu22.04
-
-# Install Python and other dependencies
+FROM pytorch/pytorch:2.1.0-cuda11.8-cudnn8-runtime
+# Set non-interactive frontend
+ENV DEBIAN_FRONTEND=noninteractive
+# Install dependencies
 RUN apt-get update && apt-get install -y --no-install-recommends \
-    python3.10 \
-    python3.10-venv \
     espeak-ng \
     git \
     libsndfile1 \

@@ -19,25 +18,10 @@ COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
 RUN useradd -m -u 1000 appuser
 
 # Create directories and set ownership
-RUN mkdir -p /app/models && \
-    mkdir -p /app/api/src/voices && \
+RUN mkdir -p /app/api/src/voices && \
     chown -R appuser:appuser /app
 
 USER appuser
-
-# Download and extract models
-WORKDIR /app/models
-RUN curl -L -o model.tar.gz https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.0.1/kokoro-82m-pytorch.tar.gz && \
-    tar xzf model.tar.gz && \
-    rm model.tar.gz
-
-# Download and extract voice models
-WORKDIR /app/api/src/voices
-RUN curl -L -o voices.tar.gz https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.0.1/voice-models.tar.gz && \
-    tar xzf voices.tar.gz && \
-    rm voices.tar.gz
-
-# Switch back to app directory
 WORKDIR /app
 
 # Copy dependency files

@@ -57,9 +41,10 @@ RUN --mount=type=cache,target=/root/.cache/uv \
 
 # Set environment variables
 ENV PYTHONUNBUFFERED=1
-ENV PYTHONPATH=/app:/app/models
+ENV PYTHONPATH=/app
 ENV PATH="/app/.venv/bin:$PATH"
 ENV UV_LINK_MODE=copy
+ENV USE_GPU=true
 
 # Run FastAPI server
 CMD ["uv", "run", "python", "-m", "uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]

@@ -6,12 +6,14 @@ services:
       context: ../..
       dockerfile: docker/gpu/Dockerfile
     volumes:
-      - ../../api/src:/app/api/src  # Mount src for development
-      - ../../api/src/voices:/app/api/src/voices  # Mount voices for persistence
+      - ../../api:/app/api
     ports:
       - "8880:8880"
     environment:
-      - PYTHONPATH=/app:/app/models
+      - PYTHONPATH=/app
+      - USE_GPU=true
+      - USE_ONNX=false
+      - PYTHONUNBUFFERED=1
     deploy:
       resources:
         reservations:

@@ -20,20 +22,20 @@ services:
               count: 1
               capabilities: [gpu]
 
-  # Gradio UI service
-  gradio-ui:
-    image: ghcr.io/remsky/kokoro-fastapi-ui:v0.1.0
-    # Uncomment below to build from source instead of using the released image
-    # build:
-    #   context: ../../ui
-    ports:
-      - "7860:7860"
-    volumes:
-      - ../../ui/data:/app/ui/data
-      - ../../ui/app.py:/app/app.py  # Mount app.py for hot reload
-    environment:
-      - GRADIO_WATCH=1  # Enable hot reloading
-      - PYTHONUNBUFFERED=1  # Ensure Python output is not buffered
-      - DISABLE_LOCAL_SAVING=false  # Set to 'true' to disable local saving and hide file view
-      - API_HOST=kokoro-tts  # Set TTS service URL
-      - API_PORT=8880  # Set TTS service PORT
+  # # Gradio UI service
+  # gradio-ui:
+  #   image: ghcr.io/remsky/kokoro-fastapi-ui:v0.1.0
+  #   # Uncomment below to build from source instead of using the released image
+  #   # build:
+  #   #   context: ../../ui
+  #   ports:
+  #     - "7860:7860"
+  #   volumes:
+  #     - ../../ui/data:/app/ui/data
+  #     - ../../ui/app.py:/app/app.py  # Mount app.py for hot reload
+  #   environment:
+  #     - GRADIO_WATCH=1  # Enable hot reloading
+  #     - PYTHONUNBUFFERED=1  # Ensure Python output is not buffered
+  #     - DISABLE_LOCAL_SAVING=false  # Set to 'true' to disable local saving and hide file view
+  #     - API_HOST=kokoro-tts  # Set TTS service URL
+  #     - API_PORT=8880  # Set TTS service PORT

docker/gpu/download_pth.py (new executable file, 57 lines)
@@ -0,0 +1,57 @@
+#!/usr/bin/env python3
+import os
+import sys
+import requests
+from pathlib import Path
+from typing import List
+
+def download_file(url: str, output_dir: Path) -> None:
+    """Download a file from URL to the specified directory."""
+    filename = os.path.basename(url)
+    if not filename.endswith('.pth'):
+        print(f"Warning: {filename} is not a .pth file")
+        return
+
+    output_path = output_dir / filename
+
+    print(f"Downloading {filename}...")
+    response = requests.get(url, stream=True)
+    response.raise_for_status()
+
+    with open(output_path, 'wb') as f:
+        for chunk in response.iter_content(chunk_size=8192):
+            f.write(chunk)
+
+def find_project_root() -> Path:
+    """Find project root by looking for api directory."""
+    max_steps = 5
+    current = Path(__file__).resolve()
+    for _ in range(max_steps):
+        if (current / 'api').is_dir():
+            return current
+        current = current.parent
+    raise RuntimeError("Could not find project root (no api directory found)")
+
+def main(custom_models: List[str] = None):
+    # Find project root and ensure models directory exists
+    project_root = find_project_root()
+    models_dir = project_root / 'api' / 'src' / 'models'
+    print(f"Downloading models to {models_dir}")
+    models_dir.mkdir(exist_ok=True)
+
+    # Default PTH model if no arguments provided
+    default_models = [
+        "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth"
+    ]
+
+    # Use provided models or default
+    models_to_download = custom_models if custom_models else default_models
+
+    for model_url in models_to_download:
+        try:
+            download_file(model_url, models_dir)
+        except Exception as e:
+            print(f"Error downloading {model_url}: {e}")
+
+if __name__ == "__main__":
+    main(sys.argv[1:] if len(sys.argv) > 1 else None)
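
As an optional follow-up (not part of this commit), a hedged sanity check that the downloaded checkpoint deserializes; the path mirrors the default target directory used by the script above and assumes torch is installed.

    # Optional sanity check, assuming download_pth.py has already run with its
    # default URL; adjust the path if a different model file was requested.
    from pathlib import Path
    import torch

    ckpt = Path("api/src/models/kokoro-v0_19.pth")
    if ckpt.exists():
        state = torch.load(ckpt, map_location="cpu", weights_only=True)
        print(f"Checkpoint loaded; top-level keys: {list(state)[:5]}")
    else:
        print("Checkpoint not found; run docker/gpu/download_pth.py first.")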

docker/gpu/download_pth.sh (new executable file, 31 lines)
@@ -0,0 +1,31 @@
+#!/bin/bash
+
+# Ensure models directory exists
+mkdir -p api/src/models
+
+# Function to download a file
+download_file() {
+    local url="$1"
+    local filename=$(basename "$url")
+    echo "Downloading $filename..."
+    curl -L "$url" -o "api/src/models/$filename"
+}
+
+# Default PTH model if no arguments provided
+DEFAULT_MODELS=(
+    "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth"
+)
+
+# Use provided models or default
+if [ $# -gt 0 ]; then
+    MODELS=("$@")
+else
+    MODELS=("${DEFAULT_MODELS[@]}")
+fi
+
+# Download all models
+for model in "${MODELS[@]}"; do
+    download_file "$model"
+done
+
+echo "PyTorch model download complete!"

uv.lock (generated, 67 lines changed)
@@ -2,9 +2,17 @@ version = 1
 requires-python = ">=3.10"
 resolution-markers = [
     "python_full_version == '3.11.*'",
+    "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux')",
+    "(python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform == 'darwin')",
     "python_full_version < '3.11'",
+    "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')",
+    "(python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform == 'darwin')",
     "python_full_version >= '3.13'",
+    "(python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'linux')",
+    "(python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.13' and sys_platform == 'darwin')",
     "python_full_version == '3.12.*'",
+    "(python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'linux')",
+    "(python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.12.*' and sys_platform == 'darwin')",
 ]
 conflicts = [[
     { package = "kokoro-fastapi", extra = "cpu" },

@@ -798,6 +806,7 @@ dependencies = [
     { name = "phonemizer" },
     { name = "pydantic" },
     { name = "pydantic-settings" },
+    { name = "pydub" },
     { name = "python-dotenv" },
     { name = "regex" },
     { name = "requests" },

@@ -812,7 +821,8 @@ dependencies = [
 
 [package.optional-dependencies]
 cpu = [
-    { name = "torch", version = "2.5.1+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" } },
+    { name = "torch", version = "2.5.1", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" },
+    { name = "torch", version = "2.5.1+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
 ]
 gpu = [
     { name = "torch", version = "2.5.1+cu121", source = { registry = "https://download.pytorch.org/whl/cu121" } },

@@ -824,7 +834,6 @@ test = [
     { name = "pytest" },
     { name = "pytest-asyncio" },
     { name = "pytest-cov" },
-    { name = "ruff" },
 ]
 
 [package.metadata]

@@ -845,18 +854,18 @@ requires-dist = [
     { name = "phonemizer", specifier = "==3.3.0" },
     { name = "pydantic", specifier = "==2.10.4" },
     { name = "pydantic-settings", specifier = "==2.7.0" },
+    { name = "pydub", specifier = ">=0.25.1" },
     { name = "pytest", marker = "extra == 'test'", specifier = "==8.0.0" },
     { name = "pytest-asyncio", marker = "extra == 'test'", specifier = "==0.23.5" },
     { name = "pytest-cov", marker = "extra == 'test'", specifier = "==4.1.0" },
     { name = "python-dotenv", specifier = "==1.0.1" },
     { name = "regex", specifier = "==2024.11.6" },
     { name = "requests", specifier = "==2.32.3" },
-    { name = "ruff", marker = "extra == 'test'", specifier = ">=0.2.2" },
     { name = "scipy", specifier = "==1.14.1" },
     { name = "soundfile", specifier = "==0.13.0" },
     { name = "sqlalchemy", specifier = "==2.0.27" },
     { name = "tiktoken", specifier = "==0.8.0" },
-    { name = "torch", marker = "extra == 'cpu'", specifier = "==2.5.1+cpu", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "kokoro-fastapi", extra = "cpu" } },
+    { name = "torch", marker = "extra == 'cpu'", specifier = "==2.5.1", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "kokoro-fastapi", extra = "cpu" } },
     { name = "torch", marker = "extra == 'gpu'", specifier = "==2.5.1+cu121", index = "https://download.pytorch.org/whl/cu121", conflict = { package = "kokoro-fastapi", extra = "gpu" } },
     { name = "tqdm", specifier = "==4.67.1" },
     { name = "transformers", specifier = "==4.47.1" },

@@ -2334,24 +2343,52 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/68/4f/12207897848a653d03ebbf6775a29d949408ded5f99b2d87198bc5c93508/tomlkit-0.12.0-py3-none-any.whl", hash = "sha256:926f1f37a1587c7a4f6c7484dae538f1345d96d793d9adab5d3675957b1d0766", size = 37334 },
 ]
 
+[[package]]
+name = "torch"
+version = "2.5.1"
+source = { registry = "https://download.pytorch.org/whl/cpu" }
+resolution-markers = [
+    "(python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform == 'darwin')",
+    "(python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform == 'darwin')",
+    "(python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.13' and sys_platform == 'darwin')",
+    "(python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.12.*' and sys_platform == 'darwin')",
+]
+dependencies = [
+    { name = "filelock", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" },
+    { name = "fsspec", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" },
+    { name = "jinja2", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" },
+    { name = "networkx", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" },
+    { name = "setuptools", marker = "(python_full_version >= '3.12' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12' and sys_platform == 'darwin')" },
+    { name = "sympy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" },
+    { name = "typing-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" },
+]
+wheels = [
+    { url = "https://download.pytorch.org/whl/cpu/torch-2.5.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:269b10c34430aa8e9643dbe035dc525c4a9b1d671cd3dbc8ecbcaed280ae322d" },
+    { url = "https://download.pytorch.org/whl/cpu/torch-2.5.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:23d062bf70776a3d04dbe74db950db2a5245e1ba4f27208a87f0d743b0d06e86" },
+    { url = "https://download.pytorch.org/whl/cpu/torch-2.5.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5b3203f191bc40783c99488d2e776dcf93ac431a59491d627a1ca5b3ae20b22" },
+    { url = "https://download.pytorch.org/whl/cpu/torch-2.5.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:31f8c39660962f9ae4eeec995e3049b5492eb7360dd4f07377658ef4d728fa4c" },
+    { url = "https://download.pytorch.org/whl/cpu/torch-2.5.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:36d1be99281b6f602d9639bd0af3ee0006e7aab16f6718d86f709d395b6f262c" },
+    { url = "https://download.pytorch.org/whl/cpu/torch-2.5.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:8c712df61101964eb11910a846514011f0b6f5920c55dbf567bff8a34163d5b1" },
+]
+
 [[package]]
 name = "torch"
 version = "2.5.1+cpu"
 source = { registry = "https://download.pytorch.org/whl/cpu" }
 resolution-markers = [
-    "python_full_version == '3.11.*'",
-    "python_full_version < '3.11'",
-    "python_full_version >= '3.13'",
-    "python_full_version == '3.12.*'",
+    "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux')",
+    "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')",
+    "(python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'linux')",
+    "(python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'linux')",
 ]
 dependencies = [
-    { name = "filelock" },
-    { name = "fsspec" },
-    { name = "jinja2" },
-    { name = "networkx" },
-    { name = "setuptools", marker = "python_full_version >= '3.12'" },
-    { name = "sympy" },
-    { name = "typing-extensions" },
+    { name = "filelock", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
+    { name = "fsspec", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
+    { name = "jinja2", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
+    { name = "networkx", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
+    { name = "setuptools", marker = "(python_full_version >= '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12' and sys_platform != 'darwin' and sys_platform != 'linux')" },
+    { name = "sympy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
+    { name = "typing-extensions", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
 ]
 wheels = [
     { url = "https://download.pytorch.org/whl/cpu/torch-2.5.1%2Bcpu-cp310-cp310-linux_x86_64.whl", hash = "sha256:7f91a2200e352745d70e22396bd501448e28350fbdbd8d8b1c83037e25451150" },