Enhance model inference: update documentation, add model download scripts for PyTorch and ONNX, and refactor configuration handling

remsky 2025-01-21 21:44:21 -07:00
parent ab28a62e86
commit 21bf810f97
25 changed files with 774 additions and 309 deletions


@ -337,11 +337,13 @@ def recursive_munch(d):
else:
return d
def build_model(path, device):
async def build_model(path, device):
from ..core.paths import load_json, load_model_weights
config = Path(__file__).parent / 'config.json'
assert config.exists(), f'Config path incorrect: config.json not found at {config}'
with open(config, 'r') as r:
args = recursive_munch(json.load(r))
args = recursive_munch(await load_json(config))
assert args.decoder.type == 'istftnet', f'Unknown decoder type: {args.decoder.type}'
decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
@ -365,7 +367,8 @@ def build_model(path, device):
decoder=decoder.to(device).eval(),
text_encoder=text_encoder.to(device).eval(),
)
for key, state_dict in torch.load(path, map_location='cpu', weights_only=True)['net'].items():
weights = await load_model_weights(path, device='cpu')
for key, state_dict in weights['net'].items():
assert key in model, key
try:
model[key].load_state_dict(state_dict)


@ -13,10 +13,15 @@ class Settings(BaseSettings):
output_dir: str = "output"
output_dir_size_limit_mb: float = 500.0 # Maximum size of output directory in MB
default_voice: str = "af"
model_dir: str = "/app/models" # Base directory for model files
pytorch_model_path: str = "kokoro-v0_19.pth"
onnx_model_path: str = "kokoro-v0_19.onnx"
voices_dir: str = "voices"
use_gpu: bool = False # Whether to use GPU acceleration if available
use_onnx: bool = True # Whether to use ONNX runtime
# Paths relative to api directory
model_dir: str = "src/models" # Model directory relative to api/
voices_dir: str = "src/voices" # Voices directory relative to api/
# Model filenames
pytorch_model_file: str = "kokoro-v0_19.pth"
onnx_model_file: str = "kokoro-v0_19.onnx"
sample_rate: int = 24000
max_chunk_size: int = 300 # Maximum size of text chunks for processing
gap_trim_ms: int = 250 # Amount to trim from streaming chunk ends in milliseconds
@ -28,6 +33,12 @@ class Settings(BaseSettings):
onnx_optimization_level: str = "all" # all, basic, or disabled
onnx_memory_pattern: bool = True # Enable memory pattern optimization
onnx_arena_extend_strategy: str = "kNextPowerOfTwo" # Memory allocation strategy
# ONNX GPU Settings
onnx_device_id: int = 0 # GPU device ID to use
onnx_gpu_mem_limit: float = 0.7 # Limit GPU memory usage to 70%
onnx_cudnn_conv_algo_search: str = "EXHAUSTIVE" # CUDNN convolution algorithm search
onnx_do_copy_in_default_stream: bool = True # Copy in default CUDA stream
class Config:
env_file = ".env"
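Since Settings extends pydantic's BaseSettings with env_file = ".env", every field above can be overridden per deployment through an environment variable of the same name, which is how the Docker Compose files below toggle USE_GPU and USE_ONNX. A minimal sketch, assuming the module lives at api.src.core.config and exposes a module-level settings instance:

import os

# Hypothetical overrides: they must be set before the settings module is imported,
# since the instance is created at import time.
os.environ["USE_GPU"] = "false"   # stay on CPU backends
os.environ["USE_ONNX"] = "true"   # prefer the ONNX runtime over PyTorch

from api.src.core.config import settings  # import path is an assumption

print(settings.use_onnx, settings.model_dir, settings.onnx_model_file)
# -> True src/models kokoro-v0_19.onnx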


@ -0,0 +1,109 @@
"""Model configuration schemas."""
from pydantic import BaseModel, Field
class ONNXCPUConfig(BaseModel):
"""ONNX CPU runtime configuration."""
num_threads: int = Field(8, description="Number of threads for parallel operations")
inter_op_threads: int = Field(4, description="Number of threads for operator parallelism")
execution_mode: str = Field("parallel", description="ONNX execution mode")
optimization_level: str = Field("all", description="ONNX optimization level")
memory_pattern: bool = Field(True, description="Enable memory pattern optimization")
arena_extend_strategy: str = Field("kNextPowerOfTwo", description="Memory arena strategy")
class Config:
frozen = True
class ONNXGPUConfig(ONNXCPUConfig):
"""ONNX GPU-specific configuration."""
device_id: int = Field(0, description="CUDA device ID")
gpu_mem_limit: float = Field(0.7, description="Fraction of GPU memory to use")
cudnn_conv_algo_search: str = Field("EXHAUSTIVE", description="CuDNN convolution algorithm search")
do_copy_in_default_stream: bool = Field(True, description="Copy in default CUDA stream")
class Config:
frozen = True
class PyTorchCPUConfig(BaseModel):
"""PyTorch CPU backend configuration."""
max_batch_size: int = Field(32, description="Maximum batch size for batched inference")
stream_buffer_size: int = Field(8, description="Size of stream buffer")
memory_threshold: float = Field(0.8, description="Memory threshold for cleanup")
retry_on_oom: bool = Field(True, description="Whether to retry on OOM errors")
num_threads: int = Field(8, description="Number of threads for parallel operations")
pin_memory: bool = Field(True, description="Whether to pin memory for faster CPU-GPU transfer")
class Config:
frozen = True
class PyTorchGPUConfig(BaseModel):
"""PyTorch GPU backend configuration."""
device_id: int = Field(0, description="CUDA device ID")
use_fp16: bool = Field(True, description="Whether to use FP16 precision")
use_triton: bool = Field(True, description="Whether to use Triton for CUDA kernels")
max_batch_size: int = Field(32, description="Maximum batch size for batched inference")
stream_buffer_size: int = Field(8, description="Size of CUDA stream buffer")
memory_threshold: float = Field(0.8, description="Memory threshold for cleanup")
retry_on_oom: bool = Field(True, description="Whether to retry on OOM errors")
sync_cuda: bool = Field(True, description="Whether to synchronize CUDA operations")
class Config:
frozen = True
"""PyTorch CPU-specific configuration."""
num_threads: int = Field(8, description="Number of threads for parallel operations")
pin_memory: bool = Field(True, description="Whether to pin memory for faster CPU-GPU transfer")
class Config:
frozen = True
class ModelConfig(BaseModel):
"""Model configuration."""
# General settings
model_type: str = Field("pytorch", description="Model type ('pytorch' or 'onnx')")
device_type: str = Field("auto", description="Device type ('cpu', 'gpu', or 'auto')")
cache_models: bool = Field(True, description="Whether to cache loaded models")
cache_voices: bool = Field(True, description="Whether to cache voice tensors")
voice_cache_size: int = Field(10, description="Maximum number of cached voices")
# Backend-specific configs
onnx_cpu: ONNXCPUConfig = Field(default_factory=ONNXCPUConfig)
onnx_gpu: ONNXGPUConfig = Field(default_factory=ONNXGPUConfig)
pytorch_cpu: PyTorchCPUConfig = Field(default_factory=PyTorchCPUConfig)
pytorch_gpu: PyTorchGPUConfig = Field(default_factory=PyTorchGPUConfig)
class Config:
frozen = True
def get_backend_config(self, backend_type: str):
"""Get configuration for specific backend.
Args:
backend_type: Backend type ('pytorch_cpu', 'pytorch_gpu', 'onnx_cpu', 'onnx_gpu')
Returns:
Backend-specific configuration
Raises:
ValueError: If backend type is invalid
"""
if backend_type not in {
'pytorch_cpu', 'pytorch_gpu', 'onnx_cpu', 'onnx_gpu'
}:
raise ValueError(f"Invalid backend type: {backend_type}")
return getattr(self, backend_type)
# Global instance
model_config = ModelConfig()
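The module-level model_config instance replaces the per-backend ONNXConfig/PyTorchConfig objects that the backend classes previously built from settings (see the onnx_cpu/onnx_gpu/pytorch diffs below). A minimal usage sketch, assuming the file lives at api/src/core/model_config.py as the later imports suggest:

from api.src.core.model_config import model_config  # assumed absolute import

# Attribute access and get_backend_config() return the same frozen sub-config.
cpu_cfg = model_config.get_backend_config("onnx_cpu")
assert cpu_cfg is model_config.onnx_cpu
print(cpu_cfg.num_threads, cpu_cfg.optimization_level)  # 8 all

# Unknown backend names fail loudly instead of silently falling back.
try:
    model_config.get_backend_config("tensorrt")
except ValueError as err:
    print(err)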


@ -1,9 +1,10 @@
"""Async file and path operations."""
import io
import json
import os
from pathlib import Path
from typing import List, Optional, AsyncIterator, Callable, Set
from typing import List, Optional, AsyncIterator, Callable, Set, Dict, Any
import aiofiles
import aiofiles.os
@ -87,10 +88,18 @@ async def get_model_path(model_name: str) -> str:
Raises:
RuntimeError: If model not found
"""
search_paths = [
settings.model_dir,
os.path.join(os.path.dirname(__file__), "..", "..", "..", "models")
]
# Get api directory path (two levels up from core)
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
# Construct model directory path relative to api directory
model_dir = os.path.join(api_dir, settings.model_dir)
# Ensure model directory exists
os.makedirs(model_dir, exist_ok=True)
# Search in model directory
search_paths = [model_dir]
logger.debug(f"Searching for model in path: {model_dir}")
return await _find_file(model_name, search_paths)
@ -107,12 +116,20 @@ async def get_voice_path(voice_name: str) -> str:
Raises:
RuntimeError: If voice not found
"""
# Get api directory path
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
# Construct voice directory path relative to api directory
voice_dir = os.path.join(api_dir, settings.voices_dir)
# Ensure voice directory exists
os.makedirs(voice_dir, exist_ok=True)
voice_file = f"{voice_name}.pt"
search_paths = [
os.path.join(settings.model_dir, "..", settings.voices_dir),
os.path.join(os.path.dirname(__file__), "..", settings.voices_dir)
]
# Search in voice directory
search_paths = [voice_dir]
logger.debug(f"Searching for voice in path: {voice_dir}")
return await _find_file(voice_file, search_paths)
@ -123,10 +140,18 @@ async def list_voices() -> List[str]:
Returns:
List of voice names (without .pt extension)
"""
search_paths = [
os.path.join(settings.model_dir, "..", settings.voices_dir),
os.path.join(os.path.dirname(__file__), "..", settings.voices_dir)
]
# Get api directory path
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
# Construct voice directory path relative to api directory
voice_dir = os.path.join(api_dir, settings.voices_dir)
# Ensure voice directory exists
os.makedirs(voice_dir, exist_ok=True)
# Search in voice directory
search_paths = [voice_dir]
logger.debug(f"Scanning for voices in path: {voice_dir}")
def filter_voice_files(name: str) -> bool:
return name.endswith('.pt')
@ -179,6 +204,51 @@ async def save_voice_tensor(tensor: torch.Tensor, voice_path: str) -> None:
raise RuntimeError(f"Failed to save voice tensor to {voice_path}: {e}")
async def load_json(path: str) -> dict:
"""Load JSON file asynchronously.
Args:
path: Path to JSON file
Returns:
Parsed JSON data
Raises:
RuntimeError: If file cannot be read or parsed
"""
try:
async with aiofiles.open(path, 'r', encoding='utf-8') as f:
content = await f.read()
return json.loads(content)
except Exception as e:
raise RuntimeError(f"Failed to load JSON file {path}: {e}")
async def load_model_weights(path: str, device: str = "cpu") -> dict:
"""Load model weights asynchronously.
Args:
path: Path to model file (.pth or .onnx)
device: Device to load model to
Returns:
Model weights
Raises:
RuntimeError: If file cannot be read
"""
try:
async with aiofiles.open(path, 'rb') as f:
data = await f.read()
return torch.load(
io.BytesIO(data),
map_location=device,
weights_only=True
)
except Exception as e:
raise RuntimeError(f"Failed to load model weights from {path}: {e}")
async def read_file(path: str) -> str:
"""Read text file asynchronously.


@ -1,4 +1,4 @@
"""Inference backends and model management."""
"""Model inference package."""
from .base import BaseModelBackend
from .model_manager import ModelManager, get_manager
@ -6,15 +6,13 @@ from .onnx_cpu import ONNXCPUBackend
from .onnx_gpu import ONNXGPUBackend
from .pytorch_cpu import PyTorchCPUBackend
from .pytorch_gpu import PyTorchGPUBackend
from ..structures.model_schemas import ModelConfig
__all__ = [
'BaseModelBackend',
'ModelManager',
'get_manager',
'ModelConfig',
'ONNXCPUBackend',
'ONNXGPUBackend',
'ONNXGPUBackend',
'PyTorchCPUBackend',
'PyTorchGPUBackend'
'PyTorchGPUBackend',
]


@ -1,21 +1,18 @@
"""Model management and caching."""
import os
from typing import Dict, List, Optional, Union
from typing import Dict, Optional
import torch
from loguru import logger
from pydantic import BaseModel
from ..core import paths
from ..core.config import settings
from ..core.model_config import ModelConfig, model_config
from .base import BaseModelBackend
from .voice_manager import get_manager as get_voice_manager
from .onnx_cpu import ONNXCPUBackend
from .onnx_gpu import ONNXGPUBackend
from .pytorch_cpu import PyTorchCPUBackend
from .pytorch_gpu import PyTorchGPUBackend
from ..core import paths
from ..core.config import settings
from ..structures.model_schemas import ModelConfig
class ModelManager:
@ -27,44 +24,63 @@ class ModelManager:
Args:
config: Optional configuration
"""
self._config = config or ModelConfig()
self._config = config or model_config
self._backends: Dict[str, BaseModelBackend] = {}
self._current_backend: Optional[str] = None
self._voice_manager = get_voice_manager()
self._initialize_backends()
def _initialize_backends(self) -> None:
"""Initialize available backends."""
"""Initialize available backends."""
# Initialize GPU backends if available
if settings.use_gpu and torch.cuda.is_available():
try:
# PyTorch GPU
self._backends['pytorch_gpu'] = PyTorchGPUBackend()
self._current_backend = 'pytorch_gpu'
logger.info("Initialized PyTorch GPU backend")
# ONNX GPU
self._backends['onnx_gpu'] = ONNXGPUBackend()
logger.info("Initialized ONNX GPU backend")
except Exception as e:
logger.error(f"Failed to initialize GPU backends: {e}")
# Fallback to CPU if GPU fails
"""Initialize available backends based on settings."""
has_gpu = settings.use_gpu and torch.cuda.is_available()
try:
if has_gpu:
if settings.use_onnx:
# ONNX GPU primary
self._backends['onnx_gpu'] = ONNXGPUBackend()
self._current_backend = 'onnx_gpu'
logger.info("Initialized ONNX GPU backend")
# PyTorch GPU fallback
self._backends['pytorch_gpu'] = PyTorchGPUBackend()
logger.info("Initialized PyTorch GPU backend")
else:
# PyTorch GPU primary
self._backends['pytorch_gpu'] = PyTorchGPUBackend()
self._current_backend = 'pytorch_gpu'
logger.info("Initialized PyTorch GPU backend")
# ONNX GPU fallback
self._backends['onnx_gpu'] = ONNXGPUBackend()
logger.info("Initialized ONNX GPU backend")
else:
self._initialize_cpu_backends()
else:
except Exception as e:
logger.error(f"Failed to initialize GPU backends: {e}")
# Fallback to CPU if GPU fails
self._initialize_cpu_backends()
def _initialize_cpu_backends(self) -> None:
"""Initialize CPU backends."""
"""Initialize CPU backends based on settings."""
try:
# PyTorch CPU
self._backends['pytorch_cpu'] = PyTorchCPUBackend()
self._current_backend = 'pytorch_cpu'
logger.info("Initialized PyTorch CPU backend")
# ONNX CPU
self._backends['onnx_cpu'] = ONNXCPUBackend()
logger.info("Initialized ONNX CPU backend")
if settings.use_onnx:
# ONNX CPU primary
self._backends['onnx_cpu'] = ONNXCPUBackend()
self._current_backend = 'onnx_cpu'
logger.info("Initialized ONNX CPU backend")
# PyTorch CPU fallback
self._backends['pytorch_cpu'] = PyTorchCPUBackend()
logger.info("Initialized PyTorch CPU backend")
else:
# PyTorch CPU primary
self._backends['pytorch_cpu'] = PyTorchCPUBackend()
self._current_backend = 'pytorch_cpu'
logger.info("Initialized PyTorch CPU backend")
# ONNX CPU fallback
self._backends['onnx_cpu'] = ONNXCPUBackend()
logger.info("Initialized ONNX CPU backend")
except Exception as e:
logger.error(f"Failed to initialize CPU backends: {e}")
raise RuntimeError("No backends available")
@ -98,7 +114,7 @@ class ModelManager:
return self._backends[backend_type]
def _determine_backend(self, model_path: str) -> str:
"""Determine appropriate backend based on model file.
"""Determine appropriate backend based on model file and settings.
Args:
model_path: Path to model file
@ -106,10 +122,10 @@ class ModelManager:
Returns:
Backend type to use
"""
is_onnx = model_path.lower().endswith('.onnx')
has_gpu = settings.use_gpu and torch.cuda.is_available()
if is_onnx:
# If ONNX is preferred or model is ONNX format
if settings.use_onnx or model_path.lower().endswith('.onnx'):
return 'onnx_gpu' if has_gpu else 'onnx_cpu'
else:
return 'pytorch_gpu' if has_gpu else 'pytorch_cpu'
@ -117,12 +133,14 @@ class ModelManager:
async def load_model(
self,
model_path: str,
warmup_voice: Optional[torch.Tensor] = None,
backend_type: Optional[str] = None
) -> None:
"""Load model on specified backend.
Args:
model_path: Path to model file
warmup_voice: Optional voice tensor for warmup, skips warmup if None
backend_type: Backend to load on, uses default if None
Raises:
@ -138,35 +156,39 @@ class ModelManager:
backend = self.get_backend(backend_type)
# Load model and run warmup
# Load model
await backend.load_model(abs_path)
logger.info(f"Loaded model on {backend_type} backend")
await self._warmup_inference(backend)
# Run warmup if voice provided
if warmup_voice is not None:
await self._warmup_inference(backend, warmup_voice)
except Exception as e:
raise RuntimeError(f"Failed to load model: {e}")
async def _warmup_inference(self, backend: BaseModelBackend) -> None:
"""Run warmup inference to initialize model."""
async def _warmup_inference(self, backend: BaseModelBackend, voice: torch.Tensor) -> None:
"""Run warmup inference to initialize model.
Args:
backend: Model backend to warm up
voice: Voice tensor already loaded on correct device
"""
try:
# Import here to avoid circular imports
from ..text_processing import process_text
# Load default voice for warmup
voice = await self._voice_manager.load_voice(settings.default_voice, device=backend.device)
logger.info(f"Loaded voice {settings.default_voice} for warmup")
from ..services.text_processing import process_text
# Use real text
text = "Testing text to speech synthesis."
logger.info(f"Running warmup inference with voice: af")
# Process through pipeline
sequences = process_text(text)
if not sequences:
tokens = process_text(text)
if not tokens:
raise ValueError("Text processing failed")
# Run inference
backend.generate(sequences[0], voice, speed=1.0)
backend.generate(tokens, voice, speed=1.0)
logger.info("Completed warmup inference")
except Exception as e:
logger.warning(f"Warmup inference failed: {e}")
@ -175,7 +197,7 @@ class ModelManager:
async def generate(
self,
tokens: list[int],
voice_name: str,
voice: torch.Tensor,
speed: float = 1.0,
backend_type: Optional[str] = None
) -> torch.Tensor:
@ -183,7 +205,7 @@ class ModelManager:
Args:
tokens: Input token IDs
voice_name: Name of voice to use
voice: Voice tensor already loaded on correct device
speed: Speed multiplier
backend_type: Backend to use, uses default if None
@ -198,10 +220,7 @@ class ModelManager:
raise RuntimeError("Model not loaded")
try:
# Load voice using voice manager
voice = await self._voice_manager.load_voice(voice_name, device=backend.device)
# Generate audio
# Generate audio using provided voice tensor
return backend.generate(tokens, voice, speed)
except Exception as e:
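The calling convention changes here: load_model now takes an optional warmup voice tensor, and generate expects a voice tensor rather than a voice name, so callers resolve voices through the voice manager themselves. A sketch of the new flow using only names visible in this diff; the absolute import paths and the placeholder text and voice name are assumptions:

import asyncio

from api.src.inference import ModelManager
from api.src.inference.voice_manager import get_manager as get_voice_manager
from api.src.services.text_processing import process_text


async def synthesize(text: str):
    manager = ModelManager()         # backends picked from settings.use_gpu / use_onnx
    voices = get_voice_manager()

    backend = manager.get_backend()  # default backend chosen at init
    voice = await voices.load_voice("af", device=backend.device)

    # Warmup is now opt-in: pass a voice tensor, or None to skip it entirely.
    await manager.load_model("kokoro-v0_19.pth", warmup_voice=voice)

    tokens = process_text(text)      # text -> phonemes -> token IDs
    return await manager.generate(tokens, voice, speed=1.0)


audio = asyncio.run(synthesize("Testing text to speech synthesis."))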


@ -13,8 +13,7 @@ from onnxruntime import (
)
from ..core import paths
from ..core.config import settings
from ..structures.model_schemas import ONNXConfig
from ..core.model_config import model_config
from .base import BaseModelBackend
@ -26,14 +25,11 @@ class ONNXCPUBackend(BaseModelBackend):
super().__init__()
self._device = "cpu"
self._session: Optional[InferenceSession] = None
self._config = ONNXConfig(
optimization_level=settings.onnx_optimization_level,
num_threads=settings.onnx_num_threads,
inter_op_threads=settings.onnx_inter_op_threads,
execution_mode=settings.onnx_execution_mode,
memory_pattern=settings.onnx_memory_pattern,
arena_extend_strategy=settings.onnx_arena_extend_strategy
)
@property
def is_loaded(self) -> bool:
"""Check if model is loaded."""
return self._session is not None
async def load_model(self, path: str) -> None:
"""Load ONNX model.
@ -115,28 +111,29 @@ class ONNXCPUBackend(BaseModelBackend):
Configured session options
"""
options = SessionOptions()
config = model_config.onnx_cpu
# Set optimization level
if self._config.optimization_level == "all":
if config.optimization_level == "all":
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
elif self._config.optimization_level == "basic":
elif config.optimization_level == "basic":
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
else:
options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
# Configure threading
options.intra_op_num_threads = self._config.num_threads
options.inter_op_num_threads = self._config.inter_op_threads
options.intra_op_num_threads = config.num_threads
options.inter_op_num_threads = config.inter_op_threads
# Set execution mode
options.execution_mode = (
ExecutionMode.ORT_PARALLEL
if self._config.execution_mode == "parallel"
if config.execution_mode == "parallel"
else ExecutionMode.ORT_SEQUENTIAL
)
# Configure memory optimization
options.enable_mem_pattern = self._config.memory_pattern
options.enable_mem_pattern = config.memory_pattern
return options
@ -148,7 +145,15 @@ class ONNXCPUBackend(BaseModelBackend):
"""
return {
"CPUExecutionProvider": {
"arena_extend_strategy": self._config.arena_extend_strategy,
"arena_extend_strategy": model_config.onnx_cpu.arena_extend_strategy,
"cpu_memory_arena_cfg": "cpu:0"
}
}
}
def unload(self) -> None:
"""Unload model and free resources."""
if self._session is not None:
del self._session
self._session = None
if torch.cuda.is_available():
torch.cuda.empty_cache()


@ -13,8 +13,7 @@ from onnxruntime import (
)
from ..core import paths
from ..core.config import settings
from ..structures.model_schemas import ONNXGPUConfig
from ..core.model_config import model_config
from .base import BaseModelBackend
@ -28,18 +27,11 @@ class ONNXGPUBackend(BaseModelBackend):
raise RuntimeError("CUDA not available")
self._device = "cuda"
self._session: Optional[InferenceSession] = None
self._config = ONNXGPUConfig(
optimization_level=settings.onnx_optimization_level,
num_threads=settings.onnx_num_threads,
inter_op_threads=settings.onnx_inter_op_threads,
execution_mode=settings.onnx_execution_mode,
memory_pattern=settings.onnx_memory_pattern,
arena_extend_strategy=settings.onnx_arena_extend_strategy,
device_id=0,
gpu_mem_limit=0.7,
cudnn_conv_algo_search="EXHAUSTIVE",
do_copy_in_default_stream=True
)
@property
def is_loaded(self) -> bool:
"""Check if model is loaded."""
return self._session is not None
async def load_model(self, path: str) -> None:
"""Load ONNX model.
@ -121,28 +113,29 @@ class ONNXGPUBackend(BaseModelBackend):
Configured session options
"""
options = SessionOptions()
config = model_config.onnx_gpu
# Set optimization level
if self._config.optimization_level == "all":
if config.optimization_level == "all":
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
elif self._config.optimization_level == "basic":
elif config.optimization_level == "basic":
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
else:
options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
# Configure threading
options.intra_op_num_threads = self._config.num_threads
options.inter_op_num_threads = self._config.inter_op_threads
options.intra_op_num_threads = config.num_threads
options.inter_op_num_threads = config.inter_op_threads
# Set execution mode
options.execution_mode = (
ExecutionMode.ORT_PARALLEL
if self._config.execution_mode == "parallel"
if config.execution_mode == "parallel"
else ExecutionMode.ORT_SEQUENTIAL
)
# Configure memory optimization
options.enable_mem_pattern = self._config.memory_pattern
options.enable_mem_pattern = config.memory_pattern
return options
@ -152,12 +145,21 @@ class ONNXGPUBackend(BaseModelBackend):
Returns:
Provider configuration
"""
config = model_config.onnx_gpu
return {
"CUDAExecutionProvider": {
"device_id": self._config.device_id,
"arena_extend_strategy": self._config.arena_extend_strategy,
"gpu_mem_limit": int(self._config.gpu_mem_limit * torch.cuda.get_device_properties(0).total_memory),
"cudnn_conv_algo_search": self._config.cudnn_conv_algo_search,
"do_copy_in_default_stream": self._config.do_copy_in_default_stream
"device_id": config.device_id,
"arena_extend_strategy": config.arena_extend_strategy,
"gpu_mem_limit": int(config.gpu_mem_limit * torch.cuda.get_device_properties(0).total_memory),
"cudnn_conv_algo_search": config.cudnn_conv_algo_search,
"do_copy_in_default_stream": config.do_copy_in_default_stream
}
}
}
def unload(self) -> None:
"""Unload model and free resources."""
if self._session is not None:
del self._session
self._session = None
if torch.cuda.is_available():
torch.cuda.empty_cache()


@ -9,7 +9,7 @@ from loguru import logger
from ..builds.models import build_model
from ..core import paths
from ..structures.model_schemas import PyTorchCPUConfig
from ..core.model_config import model_config
from .base import BaseModelBackend
@ -118,12 +118,12 @@ class PyTorchCPUBackend(BaseModelBackend):
super().__init__()
self._device = "cpu"
self._model: Optional[torch.nn.Module] = None
self._config = PyTorchCPUConfig()
# Configure PyTorch CPU settings
if self._config.num_threads > 0:
torch.set_num_threads(self._config.num_threads)
if self._config.pin_memory:
config = model_config.pytorch_cpu
if config.num_threads > 0:
torch.set_num_threads(config.num_threads)
if config.pin_memory:
torch.set_default_tensor_type(torch.FloatTensor)
async def load_model(self, path: str) -> None:


@ -9,7 +9,7 @@ from loguru import logger
from ..builds.models import build_model
from ..core import paths
from ..structures.model_schemas import PyTorchConfig
from ..core.model_config import model_config
from .base import BaseModelBackend
@ -96,7 +96,12 @@ class PyTorchGPUBackend(BaseModelBackend):
raise RuntimeError("CUDA not available")
self._device = "cuda"
self._model: Optional[torch.nn.Module] = None
self._config = PyTorchConfig()
# Configure GPU settings
config = model_config.pytorch_gpu
if config.sync_cuda:
torch.cuda.synchronize()
torch.cuda.set_device(config.device_id)
async def load_model(self, path: str) -> None:
"""Load PyTorch model.
@ -154,13 +159,19 @@ class PyTorchGPUBackend(BaseModelBackend):
except Exception as e:
logger.error(f"Generation failed: {e}")
if model_config.pytorch_gpu.retry_on_oom and "out of memory" in str(e).lower():
self._clear_memory()
return self.generate(tokens, voice, speed) # Retry once
raise
finally:
if model_config.pytorch_gpu.sync_cuda:
torch.cuda.synchronize()
def _check_memory(self) -> bool:
"""Check if memory usage is above threshold."""
if torch.cuda.is_available():
memory_gb = torch.cuda.memory_allocated() / 1e9
return memory_gb > self._config.memory_threshold
return memory_gb > model_config.pytorch_gpu.memory_threshold
return False
def _clear_memory(self) -> None:


@ -33,7 +33,15 @@ class VoiceManager:
Returns:
Path to voice file if exists, None otherwise
"""
voice_path = os.path.join(settings.voices_dir, f"{voice_name}.pt")
# Get api directory path (two levels up from inference)
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
# Construct voice path relative to api directory
voice_path = os.path.join(api_dir, settings.voices_dir, f"{voice_name}.pt")
# Ensure voices directory exists
os.makedirs(os.path.dirname(voice_path), exist_ok=True)
return voice_path if os.path.exists(voice_path) else None
async def load_voice(self, voice_name: str, device: str = "cpu") -> torch.Tensor:
@ -112,8 +120,15 @@ class VoiceManager:
combined_name = "_".join(voices)
combined_tensor = torch.mean(torch.stack(voice_tensors), dim=0)
# Get api directory path
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
voices_dir = os.path.join(api_dir, settings.voices_dir)
# Ensure voices directory exists
os.makedirs(voices_dir, exist_ok=True)
# Save combined voice
combined_path = os.path.join(settings.voices_dir, f"{combined_name}.pt")
combined_path = os.path.join(voices_dir, f"{combined_name}.pt")
try:
torch.save(combined_tensor, combined_path)
except Exception as e:
@ -132,7 +147,15 @@ class VoiceManager:
"""
voices = []
try:
for entry in os.listdir(settings.voices_dir):
# Get api directory path
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
voices_dir = os.path.join(api_dir, settings.voices_dir)
# Ensure voices directory exists
os.makedirs(voices_dir, exist_ok=True)
# List voice files
for entry in os.listdir(voices_dir):
if entry.endswith(".pt"):
voices.append(entry[:-3]) # Remove .pt extension
except Exception as e:
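The combined-voice path above simply averages the component voice tensors and saves the result under a name joined with underscores. A torch-only sketch of that operation, independent of the manager; the tensor shapes and voice names are hypothetical stand-ins for whatever the .pt voice packs contain:

import torch

# Hypothetical stand-ins for two loaded voice packs of identical shape.
voice_a = torch.randn(511, 1, 256)
voice_b = torch.randn(511, 1, 256)

combined = torch.mean(torch.stack([voice_a, voice_b]), dim=0)
torch.save(combined, "af_bella_af_sarah.pt")  # "_".join(names) + ".pt"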


@ -13,7 +13,6 @@ from loguru import logger
from .core.config import settings
from .routers.development import router as dev_router
from .routers.openai_compatible import router as openai_router
from .services.tts_model import TTSModel
from .services.tts_service import TTSService
@ -44,25 +43,32 @@ async def lifespan(app: FastAPI):
"""Lifespan context manager for model initialization"""
logger.info("Loading TTS model and voice packs...")
# Initialize the main model with warm-up
voicepack_count = await TTSModel.setup()
# boundary = "█████╗"*9
# Initialize service
service = TTSService()
await service.ensure_initialized()
# Get available voices
voices = await service.list_voices()
voicepack_count = len(voices)
# Get device info from model manager
device = "GPU" if settings.use_gpu else "CPU"
model = "ONNX" if settings.use_onnx else "PyTorch"
boundary = "" * 2*12
startup_msg = f"""
{boundary}
{boundary}
"""
# TODO: Improve CPU warmup, threads, memory, etc
startup_msg += f"\nModel warmed up on {TTSModel.get_device()}"
startup_msg += f"\nModel warmed up on {device}: {model}"
startup_msg += f"\n{voicepack_count} voice packs loaded\n"
startup_msg += f"\n{boundary}\n"
logger.info(startup_msg)


@ -6,7 +6,6 @@ from loguru import logger
from ..services.audio import AudioService
from ..services.text_processing import phonemize, tokenize
from ..services.tts_model import TTSModel
from ..services.tts_service import TTSService
from ..structures.text_schemas import (
GenerateFromPhonemesRequest,
@ -82,27 +81,34 @@ async def generate_from_phonemes(
detail={"error": "Invalid request", "message": "Phonemes cannot be empty"},
)
# Validate voice exists
voice_path = tts_service._get_voice_path(request.voice)
if not voice_path:
raise HTTPException(
status_code=400,
detail={
"error": "Invalid request",
"message": f"Voice not found: {request.voice}",
},
)
try:
# Load voice
voicepack = tts_service._load_voice(voice_path)
# Ensure service is initialized
await tts_service.ensure_initialized()
# Validate voice exists
available_voices = await tts_service.list_voices()
if request.voice not in available_voices:
raise HTTPException(
status_code=400,
detail={
"error": "Invalid request",
"message": f"Voice not found: {request.voice}",
},
)
# Convert phonemes to tokens
tokens = tokenize(request.phonemes)
tokens = [0] + tokens + [0] # Add start/end tokens
# Generate audio directly from tokens
audio = TTSModel.generate_from_tokens(tokens, voicepack, request.speed)
audio = await tts_service.model_manager.generate(
tokens,
request.voice,
speed=request.speed
)
if audio is None:
raise ValueError("Failed to generate audio")
# Convert to WAV bytes
wav_bytes = AudioService.convert_audio(


@ -1,13 +1,28 @@
from .normalizer import normalize_text
from .phonemizer import EspeakBackend, PhonemizerBackend, phonemize
from .vocabulary import VOCAB, decode_tokens, tokenize
"""Text processing pipeline."""
__all__ = [
"normalize_text",
"phonemize",
"tokenize",
"decode_tokens",
"VOCAB",
"PhonemizerBackend",
"EspeakBackend",
]
from .chunker import split_text
from .normalizer import normalize_text
from .phonemizer import phonemize
from .vocabulary import tokenize
def process_text(text: str, language: str = "a") -> list[int]:
"""Process text through the full pipeline.
Args:
text: Input text
language: Language code ('a' for US English, 'b' for British English)
Returns:
List of token IDs
Note:
The pipeline:
1. Converts text to phonemes using phonemizer
2. Converts phonemes to token IDs using vocabulary
"""
# Convert text to phonemes
phonemes = phonemize(text, language=language)
# Convert phonemes to token IDs
return tokenize(phonemes)
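For reference, the pipeline collapses to a single call; the import path is assumed:

from api.src.services.text_processing import process_text

tokens = process_text("Hello, world!", language="a")  # US English
print(len(tokens), tokens[:8])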


@ -5,17 +5,16 @@ import os
import time
from typing import List, Tuple
import torch
import numpy as np
import scipy.io.wavfile as wavfile
import torch
from loguru import logger
from ..core.config import settings
from ..inference.model_manager import get_manager as get_model_manager
from ..inference.voice_manager import get_manager as get_voice_manager
from .audio import AudioNormalizer, AudioService
from .text_processing import chunker, normalize_text
from .text_processing import chunker, normalize_text, process_text
class TTSService:
@ -41,16 +40,33 @@ class TTSService:
raise self._initialization_error
try:
# Determine model path based on hardware
# Get api directory path (one level up from src)
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
# Determine model file and backend based on hardware
if settings.use_gpu and torch.cuda.is_available():
model_path = os.path.join(settings.model_dir, settings.pytorch_model_path)
model_file = settings.pytorch_model_file
backend_type = 'pytorch_gpu'
else:
model_path = os.path.join(settings.model_dir, settings.onnx_model_path)
model_file = settings.onnx_model_file
backend_type = 'onnx_cpu'
# Initialize model
await self.model_manager.load_model(model_path, backend_type)
# Construct model path relative to api directory
model_path = os.path.join(api_dir, settings.model_dir, model_file)
# Ensure model directory exists
os.makedirs(os.path.dirname(model_path), exist_ok=True)
if not os.path.exists(model_path):
raise RuntimeError(f"Model file not found: {model_path}")
# Load default voice for warmup
backend = self.model_manager.get_backend(backend_type)
warmup_voice = await self.voice_manager.load_voice(settings.default_voice, device=backend.device)
logger.info(f"Loaded voice {settings.default_voice} for warmup")
# Initialize model with warmup voice
await self.model_manager.load_model(model_path, warmup_voice, backend_type)
logger.info(f"Initialized model on {backend_type} backend")
self._initialized = True
@ -86,16 +102,19 @@ class TTSService:
audio_chunks = []
for chunk in chunker.split_text(text):
try:
# Process text
sequences = process_text(chunk)
if not sequences:
# Convert chunk to token IDs
tokens = process_text(chunk)
if not tokens:
continue
# Get backend and load voice
backend = self.model_manager.get_backend()
voice_tensor = await self.voice_manager.load_voice(voice, device=backend.device)
# Generate audio
chunk_audio = await self.model_manager.generate(
sequences[0],
voice,
tokens,
voice_tensor,
speed=speed
)
if chunk_audio is not None:
@ -154,14 +173,17 @@ class TTSService:
while current_chunk is not None:
next_chunk = next(chunk_gen, None)
try:
# Process text
from ..text_processing import process_text
sequences = process_text(current_chunk)
if sequences:
# Convert chunk to token IDs
tokens = process_text(current_chunk)
if tokens:
# Get backend and load voice
backend = self.model_manager.get_backend()
voice_tensor = await self.voice_manager.load_voice(voice, device=backend.device)
# Generate audio
chunk_audio = await self.model_manager.generate(
sequences[0],
voice,
tokens,
voice_tensor,
speed=speed
)


@ -1,26 +1,13 @@
"""Model and voice configuration schemas."""
"""Voice configuration schemas."""
from pydantic import BaseModel
class ModelConfig(BaseModel):
"""Model configuration."""
optimization_level: str = "all" # all, basic, none
num_threads: int = 4
inter_op_threads: int = 4
execution_mode: str = "parallel" # parallel, sequential
memory_pattern: bool = True
arena_extend_strategy: str = "kNextPowerOfTwo"
class Config:
frozen = True # Make config immutable
from pydantic import BaseModel, Field
class VoiceConfig(BaseModel):
"""Voice configuration."""
use_cache: bool = True
cache_size: int = 3 # Number of voices to cache
validate_on_load: bool = True # Whether to validate voices when loading
use_cache: bool = Field(True, description="Whether to cache loaded voices")
cache_size: int = Field(3, description="Number of voices to cache")
validate_on_load: bool = Field(True, description="Whether to validate voices when loading")
class Config:
frozen = True # Make config immutable


@ -17,29 +17,10 @@ COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
RUN useradd -m -u 1000 appuser
# Create directories and set ownership
RUN mkdir -p /app/models && \
mkdir -p /app/api/src/voices && \
RUN mkdir -p /app/api/src/voices && \
chown -R appuser:appuser /app
USER appuser
# Download and extract models
WORKDIR /app/models
RUN set -x && \
curl -L -o model.tar.gz https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.0.1/kokoro-82m-onnx.tar.gz && \
echo "Downloaded model.tar.gz:" && ls -lh model.tar.gz && \
tar xzf model.tar.gz && \
echo "Contents after extraction:" && ls -lhR && \
rm model.tar.gz && \
echo "Final contents:" && ls -lhR
# Download and extract voice models
WORKDIR /app/api/src/voices
RUN curl -L -o voices.tar.gz https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.0.1/voice-models.tar.gz && \
tar xzf voices.tar.gz && \
rm voices.tar.gz
# Switch back to app directory
WORKDIR /app
# Copy dependency files
@ -59,9 +40,10 @@ RUN --mount=type=cache,target=/root/.cache/uv \
# Set environment variables
ENV PYTHONUNBUFFERED=1
ENV PYTHONPATH=/app:/app/models
ENV PYTHONPATH=/app
ENV PATH="/app/.venv/bin:$PATH"
ENV UV_LINK_MODE=copy
ENV USE_GPU=false
# Run FastAPI server
CMD ["uv", "run", "python", "-m", "uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]


@ -6,12 +6,11 @@ services:
context: ../..
dockerfile: docker/cpu/Dockerfile
volumes:
- ../../api/src:/app/api/src
- ../../api/src/voices:/app/api/src/voices
- ../../api:/app/api
ports:
- "8880:8880"
environment:
- PYTHONPATH=/app:/app/models
- PYTHONPATH=/app:/app/api
# ONNX Optimization Settings for vectorized operations
- ONNX_NUM_THREADS=8 # Maximize core usage for vectorized ops
- ONNX_INTER_OP_THREADS=4 # Higher inter-op for parallel matrix operations
@ -20,20 +19,20 @@ services:
- ONNX_MEMORY_PATTERN=true
- ONNX_ARENA_EXTEND_STRATEGY=kNextPowerOfTwo
# Gradio UI service [Comment out everything below if you don't need it]
gradio-ui:
image: ghcr.io/remsky/kokoro-fastapi-ui:v0.1.0
# Uncomment below (and comment out above) to build from source instead of using the released image
build:
context: ../../ui
ports:
- "7860:7860"
volumes:
- ../../ui/data:/app/ui/data
- ../../ui/app.py:/app/app.py # Mount app.py for hot reload
environment:
- GRADIO_WATCH=True # Enable hot reloading
- PYTHONUNBUFFERED=1 # Ensure Python output is not buffered
- DISABLE_LOCAL_SAVING=false # Set to 'true' to disable local saving and hide file view
- API_HOST=kokoro-tts # Set TTS service URL
- API_PORT=8880 # Set TTS service PORT
# # Gradio UI service [Comment out everything below if you don't need it]
# gradio-ui:
# image: ghcr.io/remsky/kokoro-fastapi-ui:v0.1.0
# # Uncomment below (and comment out above) to build from source instead of using the released image
# build:
# context: ../../ui
# ports:
# - "7860:7860"
# volumes:
# - ../../ui/data:/app/ui/data
# - ../../ui/app.py:/app/app.py # Mount app.py for hot reload
# environment:
# - GRADIO_WATCH=True # Enable hot reloading
# - PYTHONUNBUFFERED=1 # Ensure Python output is not buffered
# - DISABLE_LOCAL_SAVING=false # Set to 'true' to disable local saving and hide file view
# - API_HOST=kokoro-tts # Set TTS service URL
# - API_PORT=8880 # Set TTS service PORT

docker/cpu/download_onnx.py Executable file

@ -0,0 +1,53 @@
#!/usr/bin/env python3
import os
import sys
import requests
from pathlib import Path
from typing import List
def download_file(url: str, output_dir: Path) -> None:
"""Download a file from URL to the specified directory."""
filename = os.path.basename(url)
output_path = output_dir / filename
print(f"Downloading {filename}...")
response = requests.get(url, stream=True)
response.raise_for_status()
with open(output_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
def find_project_root() -> Path:
"""Find project root by looking for api directory."""
max_steps = 5
current = Path(__file__).resolve()
for _ in range(max_steps):
if (current / 'api').is_dir():
return current
current = current.parent
raise RuntimeError("Could not find project root (no api directory found)")
def main(custom_models: List[str] = None):
# Always use top-level models directory relative to project root
project_root = find_project_root()
models_dir = project_root / 'api' / 'src' / 'models'
models_dir.mkdir(exist_ok=True, parents=True)
# Default ONNX model if no arguments provided
default_models = [
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.onnx",
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19_fp16.onnx"
]
# Use provided models or default
models_to_download = custom_models if custom_models else default_models
for model_url in models_to_download:
try:
download_file(model_url, models_dir)
except Exception as e:
print(f"Error downloading {model_url}: {e}")
if __name__ == "__main__":
main(sys.argv[1:] if len(sys.argv) > 1 else None)
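Run with no arguments, the script downloads the default kokoro-v0_19 ONNX releases into api/src/models/ (created if missing); any URLs passed on the command line replace that default list. The shell variant below, and the PyTorch downloaders under docker/gpu/, follow the same pattern.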

docker/cpu/download_onnx.sh Executable file

@ -0,0 +1,32 @@
#!/bin/bash
# Ensure models directory exists
mkdir -p api/src/models
# Function to download a file
download_file() {
local url="$1"
local filename=$(basename "$url")
echo "Downloading $filename..."
curl -L "$url" -o "api/src/models/$filename"
}
# Default ONNX model if no arguments provided
DEFAULT_MODELS=(
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.onnx"
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19_fp16.onnx"
)
# Use provided models or default
if [ $# -gt 0 ]; then
MODELS=("$@")
else
MODELS=("${DEFAULT_MODELS[@]}")
fi
# Download all models
for model in "${MODELS[@]}"; do
download_file "$model"
done
echo "ONNX model download complete!"


@ -1,9 +1,8 @@
FROM nvidia/cuda:12.1.0-base-ubuntu22.04
# Install Python and other dependencies
FROM pytorch/pytorch:2.1.0-cuda11.8-cudnn8-runtime
# Set non-interactive frontend
ENV DEBIAN_FRONTEND=noninteractive
# Install dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
python3.10 \
python3.10-venv \
espeak-ng \
git \
libsndfile1 \
@ -19,25 +18,10 @@ COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
RUN useradd -m -u 1000 appuser
# Create directories and set ownership
RUN mkdir -p /app/models && \
mkdir -p /app/api/src/voices && \
RUN mkdir -p /app/api/src/voices && \
chown -R appuser:appuser /app
USER appuser
# Download and extract models
WORKDIR /app/models
RUN curl -L -o model.tar.gz https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.0.1/kokoro-82m-pytorch.tar.gz && \
tar xzf model.tar.gz && \
rm model.tar.gz
# Download and extract voice models
WORKDIR /app/api/src/voices
RUN curl -L -o voices.tar.gz https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.0.1/voice-models.tar.gz && \
tar xzf voices.tar.gz && \
rm voices.tar.gz
# Switch back to app directory
WORKDIR /app
# Copy dependency files
@ -57,9 +41,10 @@ RUN --mount=type=cache,target=/root/.cache/uv \
# Set environment variables
ENV PYTHONUNBUFFERED=1
ENV PYTHONPATH=/app:/app/models
ENV PYTHONPATH=/app
ENV PATH="/app/.venv/bin:$PATH"
ENV UV_LINK_MODE=copy
ENV USE_GPU=true
# Run FastAPI server
CMD ["uv", "run", "python", "-m", "uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]


@ -6,12 +6,14 @@ services:
context: ../..
dockerfile: docker/gpu/Dockerfile
volumes:
- ../../api/src:/app/api/src # Mount src for development
- ../../api/src/voices:/app/api/src/voices # Mount voices for persistence
- ../../api:/app/api
ports:
- "8880:8880"
environment:
- PYTHONPATH=/app:/app/models
- PYTHONPATH=/app
- USE_GPU=true
- USE_ONNX=false
- PYTHONUNBUFFERED=1
deploy:
resources:
reservations:
@ -20,20 +22,20 @@ services:
count: 1
capabilities: [gpu]
# Gradio UI service
gradio-ui:
image: ghcr.io/remsky/kokoro-fastapi-ui:v0.1.0
# Uncomment below to build from source instead of using the released image
# build:
# context: ../../ui
ports:
- "7860:7860"
volumes:
- ../../ui/data:/app/ui/data
- ../../ui/app.py:/app/app.py # Mount app.py for hot reload
environment:
- GRADIO_WATCH=1 # Enable hot reloading
- PYTHONUNBUFFERED=1 # Ensure Python output is not buffered
- DISABLE_LOCAL_SAVING=false # Set to 'true' to disable local saving and hide file view
- API_HOST=kokoro-tts # Set TTS service URL
- API_PORT=8880 # Set TTS service PORT
# # Gradio UI service
# gradio-ui:
# image: ghcr.io/remsky/kokoro-fastapi-ui:v0.1.0
# # Uncomment below to build from source instead of using the released image
# # build:
# # context: ../../ui
# ports:
# - "7860:7860"
# volumes:
# - ../../ui/data:/app/ui/data
# - ../../ui/app.py:/app/app.py # Mount app.py for hot reload
# environment:
# - GRADIO_WATCH=1 # Enable hot reloading
# - PYTHONUNBUFFERED=1 # Ensure Python output is not buffered
# - DISABLE_LOCAL_SAVING=false # Set to 'true' to disable local saving and hide file view
# - API_HOST=kokoro-tts # Set TTS service URL
# - API_PORT=8880 # Set TTS service PORT

docker/gpu/download_pth.py Executable file

@ -0,0 +1,57 @@
#!/usr/bin/env python3
import os
import sys
import requests
from pathlib import Path
from typing import List
def download_file(url: str, output_dir: Path) -> None:
"""Download a file from URL to the specified directory."""
filename = os.path.basename(url)
if not filename.endswith('.pth'):
print(f"Warning: {filename} is not a .pth file")
return
output_path = output_dir / filename
print(f"Downloading {filename}...")
response = requests.get(url, stream=True)
response.raise_for_status()
with open(output_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
def find_project_root() -> Path:
"""Find project root by looking for api directory."""
max_steps = 5
current = Path(__file__).resolve()
for _ in range(max_steps):
if (current / 'api').is_dir():
return current
current = current.parent
raise RuntimeError("Could not find project root (no api directory found)")
def main(custom_models: List[str] = None):
# Find project root and ensure models directory exists
project_root = find_project_root()
models_dir = project_root / 'api' / 'src' / 'models'
print(f"Downloading models to {models_dir}")
models_dir.mkdir(exist_ok=True)
# Default PTH model if no arguments provided
default_models = [
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth"
]
# Use provided models or default
models_to_download = custom_models if custom_models else default_models
for model_url in models_to_download:
try:
download_file(model_url, models_dir)
except Exception as e:
print(f"Error downloading {model_url}: {e}")
if __name__ == "__main__":
main(sys.argv[1:] if len(sys.argv) > 1 else None)

docker/gpu/download_pth.sh Executable file

@ -0,0 +1,31 @@
#!/bin/bash
# Ensure models directory exists
mkdir -p api/src/models
# Function to download a file
download_file() {
local url="$1"
local filename=$(basename "$url")
echo "Downloading $filename..."
curl -L "$url" -o "api/src/models/$filename"
}
# Default PTH model if no arguments provided
DEFAULT_MODELS=(
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth"
)
# Use provided models or default
if [ $# -gt 0 ]; then
MODELS=("$@")
else
MODELS=("${DEFAULT_MODELS[@]}")
fi
# Download all models
for model in "${MODELS[@]}"; do
download_file "$model"
done
echo "PyTorch model download complete!"

uv.lock generated

@ -2,9 +2,17 @@ version = 1
requires-python = ">=3.10"
resolution-markers = [
"python_full_version == '3.11.*'",
"(python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux')",
"(python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform == 'darwin')",
"python_full_version < '3.11'",
"(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')",
"(python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform == 'darwin')",
"python_full_version >= '3.13'",
"(python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'linux')",
"(python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.13' and sys_platform == 'darwin')",
"python_full_version == '3.12.*'",
"(python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'linux')",
"(python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.12.*' and sys_platform == 'darwin')",
]
conflicts = [[
{ package = "kokoro-fastapi", extra = "cpu" },
@ -798,6 +806,7 @@ dependencies = [
{ name = "phonemizer" },
{ name = "pydantic" },
{ name = "pydantic-settings" },
{ name = "pydub" },
{ name = "python-dotenv" },
{ name = "regex" },
{ name = "requests" },
@ -812,7 +821,8 @@ dependencies = [
[package.optional-dependencies]
cpu = [
{ name = "torch", version = "2.5.1+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" } },
{ name = "torch", version = "2.5.1", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" },
{ name = "torch", version = "2.5.1+cpu", source = { registry = "https://download.pytorch.org/whl/cpu" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
gpu = [
{ name = "torch", version = "2.5.1+cu121", source = { registry = "https://download.pytorch.org/whl/cu121" } },
@ -824,7 +834,6 @@ test = [
{ name = "pytest" },
{ name = "pytest-asyncio" },
{ name = "pytest-cov" },
{ name = "ruff" },
]
[package.metadata]
@ -845,18 +854,18 @@ requires-dist = [
{ name = "phonemizer", specifier = "==3.3.0" },
{ name = "pydantic", specifier = "==2.10.4" },
{ name = "pydantic-settings", specifier = "==2.7.0" },
{ name = "pydub", specifier = ">=0.25.1" },
{ name = "pytest", marker = "extra == 'test'", specifier = "==8.0.0" },
{ name = "pytest-asyncio", marker = "extra == 'test'", specifier = "==0.23.5" },
{ name = "pytest-cov", marker = "extra == 'test'", specifier = "==4.1.0" },
{ name = "python-dotenv", specifier = "==1.0.1" },
{ name = "regex", specifier = "==2024.11.6" },
{ name = "requests", specifier = "==2.32.3" },
{ name = "ruff", marker = "extra == 'test'", specifier = ">=0.2.2" },
{ name = "scipy", specifier = "==1.14.1" },
{ name = "soundfile", specifier = "==0.13.0" },
{ name = "sqlalchemy", specifier = "==2.0.27" },
{ name = "tiktoken", specifier = "==0.8.0" },
{ name = "torch", marker = "extra == 'cpu'", specifier = "==2.5.1+cpu", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "kokoro-fastapi", extra = "cpu" } },
{ name = "torch", marker = "extra == 'cpu'", specifier = "==2.5.1", index = "https://download.pytorch.org/whl/cpu", conflict = { package = "kokoro-fastapi", extra = "cpu" } },
{ name = "torch", marker = "extra == 'gpu'", specifier = "==2.5.1+cu121", index = "https://download.pytorch.org/whl/cu121", conflict = { package = "kokoro-fastapi", extra = "gpu" } },
{ name = "tqdm", specifier = "==4.67.1" },
{ name = "transformers", specifier = "==4.47.1" },
@ -2334,24 +2343,52 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/68/4f/12207897848a653d03ebbf6775a29d949408ded5f99b2d87198bc5c93508/tomlkit-0.12.0-py3-none-any.whl", hash = "sha256:926f1f37a1587c7a4f6c7484dae538f1345d96d793d9adab5d3675957b1d0766", size = 37334 },
]
[[package]]
name = "torch"
version = "2.5.1"
source = { registry = "https://download.pytorch.org/whl/cpu" }
resolution-markers = [
"(python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform == 'darwin')",
"(python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform == 'darwin')",
"(python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.13' and sys_platform == 'darwin')",
"(python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.12.*' and sys_platform == 'darwin')",
]
dependencies = [
{ name = "filelock", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" },
{ name = "fsspec", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" },
{ name = "jinja2", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" },
{ name = "networkx", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" },
{ name = "setuptools", marker = "(python_full_version >= '3.12' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12' and sys_platform == 'darwin')" },
{ name = "sympy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" },
{ name = "typing-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" },
]
wheels = [
{ url = "https://download.pytorch.org/whl/cpu/torch-2.5.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:269b10c34430aa8e9643dbe035dc525c4a9b1d671cd3dbc8ecbcaed280ae322d" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.5.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:23d062bf70776a3d04dbe74db950db2a5245e1ba4f27208a87f0d743b0d06e86" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.5.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5b3203f191bc40783c99488d2e776dcf93ac431a59491d627a1ca5b3ae20b22" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.5.1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:31f8c39660962f9ae4eeec995e3049b5492eb7360dd4f07377658ef4d728fa4c" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.5.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:36d1be99281b6f602d9639bd0af3ee0006e7aab16f6718d86f709d395b6f262c" },
{ url = "https://download.pytorch.org/whl/cpu/torch-2.5.1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:8c712df61101964eb11910a846514011f0b6f5920c55dbf567bff8a34163d5b1" },
]
[[package]]
name = "torch"
version = "2.5.1+cpu"
source = { registry = "https://download.pytorch.org/whl/cpu" }
resolution-markers = [
"python_full_version == '3.11.*'",
"python_full_version < '3.11'",
"python_full_version >= '3.13'",
"python_full_version == '3.12.*'",
"(python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux')",
"(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')",
"(python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'linux')",
"(python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'linux')",
]
dependencies = [
{ name = "filelock" },
{ name = "fsspec" },
{ name = "jinja2" },
{ name = "networkx" },
{ name = "setuptools", marker = "python_full_version >= '3.12'" },
{ name = "sympy" },
{ name = "typing-extensions" },
{ name = "filelock", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "fsspec", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "jinja2", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "networkx", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "setuptools", marker = "(python_full_version >= '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12' and sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "sympy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "typing-extensions", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
wheels = [
{ url = "https://download.pytorch.org/whl/cpu/torch-2.5.1%2Bcpu-cp310-cp310-linux_x86_64.whl", hash = "sha256:7f91a2200e352745d70e22396bd501448e28350fbdbd8d8b1c83037e25451150" },