mirror of https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00

Commit d50214d3be (parent 4a24be1605): Enable ONNX GPU support in Docker configurations and refactor model file handling

19 changed files with 123 additions and 4321 deletions
@@ -7,6 +7,9 @@
 []()
 [](https://huggingface.co/hexgrad/Kokoro-82M/tree/c3b0d86e2a980e027ef71c28819ea02e351c2667) [](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero)
 
+> [!INFO]
+> Pre-release. Not fully tested
+
 Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model
 - OpenAI-compatible Speech endpoint, with inline voice combination functionality
 - NVIDIA GPU accelerated or CPU Onnx inference
@@ -9,37 +9,22 @@ class Settings(BaseSettings):
     host: str = "0.0.0.0"
     port: int = 8880
 
-    # TTS Settings
+    # Application Settings
     output_dir: str = "output"
     output_dir_size_limit_mb: float = 500.0  # Maximum size of output directory in MB
     default_voice: str = "af"
     use_gpu: bool = False  # Whether to use GPU acceleration if available
     use_onnx: bool = True  # Whether to use ONNX runtime
 
     # Container absolute paths
     model_dir: str = "/app/api/src/models"  # Absolute path in container
     voices_dir: str = "/app/api/src/voices"  # Absolute path in container
 
-    # Model filenames
-    pytorch_model_file: str = "kokoro-v0_19.pth"
-    onnx_model_file: str = "kokoro-v0_19.onnx"
+    # Audio Settings
     sample_rate: int = 24000
     max_chunk_size: int = 300  # Maximum size of text chunks for processing
     gap_trim_ms: int = 250  # Amount to trim from streaming chunk ends in milliseconds
 
-    # ONNX Optimization Settings
-    onnx_num_threads: int = 4  # Number of threads for intra-op parallelism
-    onnx_inter_op_threads: int = 4  # Number of threads for inter-op parallelism
-    onnx_execution_mode: str = "parallel"  # parallel or sequential
-    onnx_optimization_level: str = "all"  # all, basic, or disabled
-    onnx_memory_pattern: bool = True  # Enable memory pattern optimization
-    onnx_arena_extend_strategy: str = "kNextPowerOfTwo"  # Memory allocation strategy
-
-    # ONNX GPU Settings
-    onnx_device_id: int = 0  # GPU device ID to use
-    onnx_gpu_mem_limit: float = 0.7  # Limit GPU memory usage to 70%
-    onnx_cudnn_conv_algo_search: str = "EXHAUSTIVE"  # CUDNN convolution algorithm search
-    onnx_do_copy_in_default_stream: bool = True  # Copy in default CUDA stream
-
     class Config:
         env_file = ".env"
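Since Settings extends pydantic's BaseSettings with `env_file = ".env"`, each field above can be overridden per container through environment variables or a .env file. A minimal sketch of that behavior (only a few fields reproduced; the override value is illustrative):

```python
# Sketch: overriding Settings fields via environment variables (pydantic-settings).
# Field names mirror the class above; the override value is an example only.
import os

from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    use_gpu: bool = False   # Whether to use GPU acceleration if available
    use_onnx: bool = True   # Whether to use ONNX runtime
    model_dir: str = "/app/api/src/models"

    class Config:
        env_file = ".env"


os.environ["USE_GPU"] = "true"   # e.g. set by a GPU-enabled Docker image
settings = Settings()
print(settings.use_gpu, settings.use_onnx)  # True True
```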
@@ -11,7 +11,7 @@ class ONNXCPUConfig(BaseModel):
     instance_timeout: int = Field(300, description="Session timeout in seconds")
 
     # Runtime settings
-    num_threads: int = Field(8, description="Number of threads for parallel operations")
+    num_threads: int = Field(4, description="Number of threads for parallel operations")
     inter_op_threads: int = Field(4, description="Number of threads for operator parallelism")
     execution_mode: str = Field("parallel", description="ONNX execution mode")
     optimization_level: str = Field("all", description="ONNX optimization level")
@@ -77,6 +77,10 @@ class ModelConfig(BaseModel):
     cache_voices: bool = Field(True, description="Whether to cache voice tensors")
     voice_cache_size: int = Field(2, description="Maximum number of cached voices")
 
+    # Model filenames
+    pytorch_model_file: str = Field("kokoro-v0_19.pth", description="PyTorch model filename")
+    onnx_model_file: str = Field("kokoro-v0_19.onnx", description="ONNX model filename")
+
     # Backend-specific configs
     onnx_cpu: ONNXCPUConfig = Field(default_factory=ONNXCPUConfig)
     onnx_gpu: ONNXGPUConfig = Field(default_factory=ONNXGPUConfig)
@@ -17,26 +17,27 @@ from .pytorch_gpu import PyTorchGPUBackend
 from .session_pool import CPUSessionPool, StreamingSessionPool
 
 
-# Global singleton instance and state
+# Global singleton instance and lock for thread-safe initialization
 _manager_instance = None
 _manager_lock = asyncio.Lock()
-_loaded_models = {}
-_backends = {}
 
 
 class ModelManager:
     """Manages model loading and inference across backends."""
 
+    # Class-level state for shared resources
+    _loaded_models = {}
+    _backends = {}
+
     def __init__(self, config: Optional[ModelConfig] = None):
         """Initialize model manager.
 
         Args:
             config: Optional configuration
 
+        Note:
+            This should not be called directly. Use get_manager() instead.
         """
         self._config = config or model_config
-        global _loaded_models, _backends
-        self._loaded_models = _loaded_models
-        self._backends = _backends
 
         # Initialize session pools
         self._session_pools = {
@@ -293,10 +294,20 @@ async def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
 
     Returns:
         ModelManager instance
+
+    Thread Safety:
+        This function is thread-safe and ensures only one instance is created
+        even under concurrent access.
     """
     global _manager_instance
 
+    # Fast path - return existing instance without lock
+    if _manager_instance is not None:
+        return _manager_instance
+
+    # Slow path - create new instance with lock
     async with _manager_lock:
+        # Double-check pattern
         if _manager_instance is None:
             _manager_instance = ModelManager(config)
             await _manager_instance.initialize()
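The fast-path check plus the re-check under the lock is the classic double-checked-locking pattern, adapted to asyncio. A self-contained sketch of the same idea with a stand-in resource class (all names here are illustrative, not the project's):

```python
# Sketch of the async double-checked-locking singleton used by get_manager().
# `HeavyResource` and its initialize() are illustrative stand-ins.
import asyncio

_instance = None
_lock = asyncio.Lock()


class HeavyResource:
    async def initialize(self) -> None:
        await asyncio.sleep(0.1)  # stand-in for expensive async setup


async def get_instance() -> HeavyResource:
    global _instance
    if _instance is not None:        # fast path: no lock once initialized
        return _instance
    async with _lock:                # slow path: serialize first-time creation
        if _instance is None:        # double-check after acquiring the lock
            obj = HeavyResource()
            await obj.initialize()
            _instance = obj
    return _instance


async def main() -> None:
    a, b = await asyncio.gather(get_instance(), get_instance())
    assert a is b  # concurrent callers share one fully-initialized instance


asyncio.run(main())
```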
@@ -75,21 +75,17 @@ def create_provider_options(is_gpu: bool = False) -> Dict:
     if is_gpu:
         config = model_config.onnx_gpu
         return {
-            "CUDAExecutionProvider": {
             "device_id": config.device_id,
             "arena_extend_strategy": config.arena_extend_strategy,
             "gpu_mem_limit": int(config.gpu_mem_limit * torch.cuda.get_device_properties(0).total_memory),
             "cudnn_conv_algo_search": config.cudnn_conv_algo_search,
             "do_copy_in_default_stream": config.do_copy_in_default_stream
             }
-        }
     else:
         return {
-            "CPUExecutionProvider": {
             "arena_extend_strategy": model_config.onnx_cpu.arena_extend_strategy,
             "cpu_memory_arena_cfg": "cpu:0"
             }
-        }
 
 
 class BaseSessionPool:
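The function now returns a flat options dict for a single provider rather than a mapping keyed by provider name; onnxruntime accepts such options when providers are passed as (name, options) pairs. A sketch of how options like these are typically handed to an InferenceSession (the model path and literal option values are illustrative, mirroring the defaults above):

```python
# Sketch: passing flat per-provider options (as returned by create_provider_options
# above) to an onnxruntime InferenceSession. Path and values are illustrative.
import onnxruntime as ort

sess_options = ort.SessionOptions()
sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

use_cuda = "CUDAExecutionProvider" in ort.get_available_providers()
if use_cuda:
    provider = "CUDAExecutionProvider"
    provider_options = {                      # mirrors the GPU branch above
        "device_id": 0,
        "arena_extend_strategy": "kNextPowerOfTwo",
        "cudnn_conv_algo_search": "EXHAUSTIVE",
        "do_copy_in_default_stream": 1,
    }
else:
    provider = "CPUExecutionProvider"
    provider_options = {                      # mirrors the CPU branch above
        "arena_extend_strategy": "kNextPowerOfTwo",
        "cpu_memory_arena_cfg": "cpu:0",
    }

session = ort.InferenceSession(
    "/app/api/src/models/kokoro-v0_19.onnx",  # model_dir + onnx_model_file
    sess_options=sess_options,
    providers=[(provider, provider_options)],  # (name, options) pairs
)
```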
@@ -73,12 +73,13 @@ class VoiceManager:
         return voice
 
     def _manage_cache(self) -> None:
-        """Manage voice cache size."""
+        """Manage voice cache size using simple LRU."""
         if len(self._voice_cache) >= self._config.cache_size:
-            # Remove oldest voice
+            # Remove least recently used voice
            oldest = next(iter(self._voice_cache))
            del self._voice_cache[oldest]
-            logger.debug(f"Removed from voice cache: {oldest}")
+            torch.cuda.empty_cache()  # Clean up GPU memory if needed
+            logger.debug(f"Removed LRU voice from cache: {oldest}")
 
     async def combine_voices(self, voices: List[str], device: str = "cpu") -> str:
         """Combine multiple voices into a new voice.
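One caveat worth noting: `next(iter(self._voice_cache))` evicts the oldest *inserted* key, which only matches LRU if entries are moved to the end of the dict whenever they are reused. A minimal sketch of that distinction using OrderedDict (names and values are illustrative):

```python
# Sketch: insertion-ordered eviction vs. true LRU. next(iter(...)) pops the oldest
# *inserted* key; moving keys to the end on every hit makes the same rule LRU.
from collections import OrderedDict

CACHE_SIZE = 2
cache: "OrderedDict[str, str]" = OrderedDict()


def get_voice(name: str) -> str:
    if name in cache:
        cache.move_to_end(name)                 # mark as most recently used
        return cache[name]
    if len(cache) >= CACHE_SIZE:
        evicted, _ = cache.popitem(last=False)  # least recently used entry
        print(f"Removed LRU voice from cache: {evicted}")
    cache[name] = f"<tensor for {name}>"        # stand-in for a loaded voice tensor
    return cache[name]


get_voice("af"); get_voice("am"); get_voice("af")  # "af" becomes most recent
get_voice("bf")                                    # evicts "am", not "af"
```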
@@ -1,4 +1,3 @@
-
 """
 FastAPI OpenAI Compatible API
 """
@@ -14,6 +13,7 @@ from fastapi.middleware.cors import CORSMiddleware
 from loguru import logger
 
 from .core.config import settings
+from .core.model_config import model_config
 from .routers.development import router as dev_router
 from .routers.openai_compatible import router as openai_router
 from .services.tts_service import TTSService
@@ -63,8 +63,8 @@ async def lifespan(app: FastAPI):
     # Get backend and initialize model
     backend = model_manager.get_backend(backend_type)
 
-    # Use model path directly from settings
-    model_file = settings.pytorch_model_file if not settings.use_onnx else settings.onnx_model_file
+    # Use model path from model_config
+    model_file = model_config.pytorch_model_file if not settings.use_onnx else model_config.onnx_model_file
     model_path = os.path.join(settings.model_dir, model_file)
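With the filenames now living on model_config and the runtime flags on settings, startup path resolution combines the two; a short sketch using the default values from the hunks above:

```python
# Sketch: resolving the model path from the defaults shown in the config hunks above.
import os

model_dir = "/app/api/src/models"          # settings.model_dir
use_onnx = True                            # settings.use_onnx
pytorch_model_file = "kokoro-v0_19.pth"    # model_config.pytorch_model_file
onnx_model_file = "kokoro-v0_19.onnx"      # model_config.onnx_model_file

model_file = pytorch_model_file if not use_onnx else onnx_model_file
model_path = os.path.join(model_dir, model_file)
print(model_path)  # /app/api/src/models/kokoro-v0_19.onnx
```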
@@ -130,12 +130,12 @@ async def create_speech(
                 },
             )
         else:
-            # Generate complete audio
-            audio, _ = tts_service._generate_audio(
+            # Generate complete audio using public interface
+            audio, _ = await tts_service.generate_audio(
                 text=request.input,
                 voice=voice_to_use,
                 speed=request.speed,
-                stitch_long_output=True,
+                stitch_long_output=True
             )
 
             # Convert to requested format
@@ -153,14 +153,37 @@ async def create_speech(
             )
 
     except ValueError as e:
-        logger.error(f"Invalid request: {str(e)}")
+        # Handle validation errors
+        logger.warning(f"Invalid request: {str(e)}")
         raise HTTPException(
-            status_code=400, detail={"error": "Invalid request", "message": str(e)}
+            status_code=400,
+            detail={
+                "error": "validation_error",
+                "message": str(e),
+                "type": "invalid_request_error"
+            }
+        )
+    except RuntimeError as e:
+        # Handle runtime/processing errors
+        logger.error(f"Processing error: {str(e)}")
+        raise HTTPException(
+            status_code=500,
+            detail={
+                "error": "processing_error",
+                "message": "Failed to process audio generation request",
+                "type": "server_error"
+            }
         )
     except Exception as e:
-        logger.error(f"Error generating speech: {str(e)}")
+        # Handle unexpected errors
+        logger.error(f"Unexpected error in speech generation: {str(e)}")
         raise HTTPException(
-            status_code=500, detail={"error": "Server error", "message": str(e)}
+            status_code=500,
+            detail={
+                "error": "server_error",
+                "message": "An unexpected error occurred",
+                "type": "server_error"
+            }
         )
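Because FastAPI serializes HTTPException.detail as the `detail` field of the JSON error body, clients should now see structured payloads roughly like the following (a sketch; the example messages vary with the actual failure):

```python
# Sketch: the error bodies a client receives for the structured exceptions above.
# FastAPI wraps HTTPException.detail in a top-level "detail" key; values are examples.
validation_error_response = {      # HTTP 400
    "detail": {
        "error": "validation_error",
        "message": "Text is empty after preprocessing",
        "type": "invalid_request_error",
    }
}

processing_error_response = {      # HTTP 500
    "detail": {
        "error": "processing_error",
        "message": "Failed to process audio generation request",
        "type": "server_error",
    }
}
```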
@@ -173,7 +196,14 @@ async def list_voices():
         return {"voices": voices}
     except Exception as e:
         logger.error(f"Error listing voices: {str(e)}")
-        raise HTTPException(status_code=500, detail=str(e))
+        raise HTTPException(
+            status_code=500,
+            detail={
+                "error": "server_error",
+                "message": "Failed to retrieve voice list",
+                "type": "server_error"
+            }
+        )
 
 
 @router.post("/audio/voices/combine")
|
||||||
return {"voices": voices, "voice": combined_voice}
|
return {"voices": voices, "voice": combined_voice}
|
||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.error(f"Invalid voice combination request: {str(e)}")
|
logger.warning(f"Invalid voice combination request: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400, detail={"error": "Invalid request", "message": str(e)}
|
status_code=400,
|
||||||
|
detail={
|
||||||
|
"error": "validation_error",
|
||||||
|
"message": str(e),
|
||||||
|
"type": "invalid_request_error"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.error(f"Voice combination processing error: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail={
|
||||||
|
"error": "processing_error",
|
||||||
|
"message": "Failed to process voice combination request",
|
||||||
|
"type": "server_error"
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Server error during voice combination: {str(e)}")
|
logger.error(f"Unexpected error in voice combination: {str(e)}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500, detail={"error": "Server error", "message": "Server error"}
|
status_code=500,
|
||||||
|
detail={
|
||||||
|
"error": "server_error",
|
||||||
|
"message": "An unexpected error occurred",
|
||||||
|
"type": "server_error"
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
|
@@ -51,7 +51,7 @@ class TTSService:
         return service
 
     async def generate_audio(
-        self, text: str, voice: str, speed: float = 1.0
+        self, text: str, voice: str, speed: float = 1.0, stitch_long_output: bool = True
     ) -> Tuple[np.ndarray, float]:
         """Generate audio for text.
 
@@ -59,9 +59,14 @@
             text: Input text
             voice: Voice name
             speed: Speed multiplier
+            stitch_long_output: Whether to stitch together long outputs
 
         Returns:
             Audio samples and processing time
+
+        Raises:
+            ValueError: If text is empty after preprocessing or no chunks generated
+            RuntimeError: If audio generation fails
         """
         start_time = time.time()
         voice_tensor = None
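A sketch of how a caller uses the widened public signature (how the TTSService instance is obtained is left out here; the voice name is just the default from Settings):

```python
# Sketch: calling the public generate_audio() with the new stitch_long_output flag.
# `tts_service` is assumed to be an initialized TTSService instance.
async def synthesize(tts_service, text: str) -> None:
    audio, processing_time = await tts_service.generate_audio(
        text=text,
        voice="af",                # default voice from Settings
        speed=1.0,
        stitch_long_output=True,   # chunk and stitch long inputs
    )
    print(f"{len(audio)} samples generated in {processing_time:.2f}s")
```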
@@ -10,7 +10,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
     && apt-get clean \
     && rm -rf /var/lib/apt/lists/*
 
-# Install uv
+# Install uv for speed and glory
 COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
 
 # Create non-root user

@@ -44,6 +44,7 @@ ENV PYTHONPATH=/app
 ENV PATH="/app/.venv/bin:$PATH"
 ENV UV_LINK_MODE=copy
 ENV USE_GPU=false
+ENV USE_ONNX=true
 
 # Run FastAPI server
 CMD ["uv", "run", "python", "-m", "uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]
(deleted: the CPU build's pyproject.toml, kokoro-fastapi-cpu)
@@ -1,22 +0,0 @@
-[project]
-name = "kokoro-fastapi-cpu"
-version = "0.1.0"
-description = "FastAPI TTS Service - CPU Version"
-readme = "../README.md"
-requires-python = ">=3.10"
-dependencies = [
-    # Core ML/DL for CPU
-    "torch>=2.5.1",
-    "transformers==4.47.1",
-]
-
-[tool.uv.workspace]
-members = ["../shared"]
-
-[tool.uv.sources]
-torch = { index = "pytorch-cpu" }
-
-[[tool.uv.index]]
-name = "pytorch-cpu"
-url = "https://download.pytorch.org/whl/cpu"
-explicit = true
(deleted: the CPU build's requirements.lock, generated)
@@ -1,229 +0,0 @@
The file was autogenerated by uv via:
    uv pip compile pyproject.toml ../shared/pyproject.toml --output-file requirements.lock
Its 229 lines pinned: aiofiles==23.2.1, annotated-types==0.7.0, anyio==4.8.0, attrs==24.3.0, babel==2.16.0, certifi==2024.12.14, cffi==1.17.1, charset-normalizer==3.4.1, click==8.1.8, clldutils==3.21.0, colorama==0.4.6, coloredlogs==15.0.1, colorlog==6.9.0, csvw==3.5.1, dlinfo==1.2.1, exceptiongroup==1.2.2, fastapi==0.115.6, filelock==3.16.1, flatbuffers==24.12.23, fsspec==2024.12.0, greenlet==3.1.1, h11==0.14.0, huggingface-hub==0.27.1, humanfriendly==10.0, idna==3.10, isodate==0.7.2, jinja2==3.1.5, joblib==1.4.2, jsonschema==4.23.0, jsonschema-specifications==2024.10.1, language-tags==1.2.0, loguru==0.7.3, lxml==5.3.0, markdown==3.7, markupsafe==3.0.2, mpmath==1.3.0, munch==4.0.0, networkx==3.4.2, numpy==2.2.1, onnxruntime==1.20.1, packaging==24.2, phonemizer==3.3.0, protobuf==5.29.3, pycparser==2.22, pydantic==2.10.4, pydantic-core==2.27.2, pydantic-settings==2.7.0, pylatexenc==2.10, pyparsing==3.2.1, pyreadline3==3.5.4, python-dateutil==2.9.0.post0, python-dotenv==1.0.1, pyyaml==6.0.2, rdflib==7.1.2, referencing==0.35.1, regex==2024.11.6, requests==2.32.3, rfc3986==1.5.0, rpds-py==0.22.3, safetensors==0.5.2, scipy==1.14.1, segments==2.2.1, six==1.17.0, sniffio==1.3.1, soundfile==0.13.0, sqlalchemy==2.0.27, starlette==0.41.3, sympy==1.13.1, tabulate==0.9.0, tiktoken==0.8.0, tokenizers==0.21.0, torch==2.5.1+cpu, tqdm==4.67.1, transformers==4.47.1, typing-extensions==4.12.2, uritemplate==4.1.1, urllib3==2.3.0, uvicorn==0.34.0, win32-setctime==1.2.0.
docker/cpu/uv.lock (generated, 1841 lines): file diff suppressed because it is too large.
@@ -1,8 +1,11 @@
-FROM pytorch/pytorch:2.1.0-cuda11.8-cudnn8-runtime
+FROM nvidia/cuda:12.3.2-cudnn9-runtime-ubuntu22.04
 
 # Set non-interactive frontend
 ENV DEBIAN_FRONTEND=noninteractive
 
-# Install dependencies
+# Install Python and other dependencies
 RUN apt-get update && apt-get install -y --no-install-recommends \
+    python3.10 \
+    python3.10-venv \
     espeak-ng \
     git \
     libsndfile1 \
@@ -27,15 +30,15 @@ WORKDIR /app
 # Copy dependency files
 COPY --chown=appuser:appuser pyproject.toml ./pyproject.toml
 
-# Install dependencies
+# Install dependencies with GPU extras
 RUN --mount=type=cache,target=/root/.cache/uv \
     uv venv && \
-    uv sync --extra gpu --no-install-project
+    uv sync --extra gpu
 
 # Copy project files
 COPY --chown=appuser:appuser api ./api
 
-# Install project
+# Install project with GPU extras
 RUN --mount=type=cache,target=/root/.cache/uv \
     uv sync --extra gpu
(deleted: the GPU build's pyproject.toml, kokoro-fastapi-gpu)
@@ -1,22 +0,0 @@
-[project]
-name = "kokoro-fastapi-gpu"
-version = "0.1.0"
-description = "FastAPI TTS Service - GPU Version"
-readme = "../README.md"
-requires-python = ">=3.10"
-dependencies = [
-    # Core ML/DL for GPU
-    "torch==2.5.1+cu121",
-    "transformers==4.47.1",
-]
-
-[tool.uv.workspace]
-members = ["../shared"]
-
-[tool.uv.sources]
-torch = { index = "pytorch-cuda" }
-
-[[tool.uv.index]]
-name = "pytorch-cuda"
-url = "https://download.pytorch.org/whl/cu121"
-explicit = true
(deleted: the GPU build's requirements.lock, generated)
@@ -1,229 +0,0 @@
The file was autogenerated by uv via:
    uv pip compile pyproject.toml ../shared/pyproject.toml --output-file requirements.lock
Its 229 lines pinned the same dependency set as the CPU lockfile above, except torch==2.5.1+cu121 (via kokoro-fastapi-gpu) in place of torch==2.5.1+cpu.
docker/gpu/uv.lock (generated, 1914 lines): file diff suppressed because it is too large.
Binary file not shown.
@@ -16,7 +16,6 @@ dependencies = [
     # ML/DL Base
     "numpy>=1.26.0",
     "scipy==1.14.1",
-    "onnxruntime==1.20.1",
     # Audio processing
     "soundfile==0.13.0",
     # Text processing
@@ -39,9 +38,11 @@ dependencies = [
 [project.optional-dependencies]
 gpu = [
     "torch==2.5.1+cu121",
+    "onnxruntime-gpu==1.20.1",
 ]
 cpu = [
     "torch==2.5.1",
+    "onnxruntime==1.20.1",
 ]
 test = [
     "pytest==8.0.0",
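Since the gpu extra now pulls in onnxruntime-gpu while the cpu extra keeps plain onnxruntime, a quick runtime check can confirm whether the installed build actually exposes the CUDA provider; a sketch:

```python
# Sketch: verifying that the installed onnxruntime build (onnxruntime-gpu from the
# `gpu` extra) exposes the CUDA execution provider before enabling GPU inference.
import onnxruntime as ort

available = ort.get_available_providers()
use_gpu = "CUDAExecutionProvider" in available
print(f"ONNX providers: {available}; GPU inference possible: {use_gpu}")
```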