Mirror of https://github.com/remsky/Kokoro-FastAPI.git
Synced 2025-08-05 16:48:53 +00:00

Enable ONNX GPU support in Docker configurations and refactor model file handling

This commit is contained in: parent 4a24be1605, commit d50214d3be
19 changed files with 123 additions and 4321 deletions
@@ -7,6 +7,9 @@
 [](https://huggingface.co/hexgrad/Kokoro-82M/tree/c3b0d86e2a980e027ef71c28819ea02e351c2667) [](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero)

+> [!INFO]
+> Pre-release. Not fully tested
+
 Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model
 - OpenAI-compatible Speech endpoint, with inline voice combination functionality
 - NVIDIA GPU-accelerated or CPU ONNX inference
@@ -9,37 +9,22 @@ class Settings(BaseSettings):
     host: str = "0.0.0.0"
     port: int = 8880

-    # TTS Settings
+    # Application Settings
     output_dir: str = "output"
     output_dir_size_limit_mb: float = 500.0  # Maximum size of output directory in MB
     default_voice: str = "af"
     use_gpu: bool = False  # Whether to use GPU acceleration if available
     use_onnx: bool = True  # Whether to use ONNX runtime

     # Container absolute paths
     model_dir: str = "/app/api/src/models"  # Absolute path in container
     voices_dir: str = "/app/api/src/voices"  # Absolute path in container

-    # Model filenames
-    pytorch_model_file: str = "kokoro-v0_19.pth"
-    onnx_model_file: str = "kokoro-v0_19.onnx"
     # Audio Settings
     sample_rate: int = 24000
     max_chunk_size: int = 300  # Maximum size of text chunks for processing
     gap_trim_ms: int = 250  # Amount to trim from streaming chunk ends in milliseconds

-    # ONNX Optimization Settings
-    onnx_num_threads: int = 4  # Number of threads for intra-op parallelism
-    onnx_inter_op_threads: int = 4  # Number of threads for inter-op parallelism
-    onnx_execution_mode: str = "parallel"  # parallel or sequential
-    onnx_optimization_level: str = "all"  # all, basic, or disabled
-    onnx_memory_pattern: bool = True  # Enable memory pattern optimization
-    onnx_arena_extend_strategy: str = "kNextPowerOfTwo"  # Memory allocation strategy

-    # ONNX GPU Settings
-    onnx_device_id: int = 0  # GPU device ID to use
-    onnx_gpu_mem_limit: float = 0.7  # Limit GPU memory usage to 70%
-    onnx_cudnn_conv_algo_search: str = "EXHAUSTIVE"  # CUDNN convolution algorithm search
-    onnx_do_copy_in_default_stream: bool = True  # Copy in default CUDA stream

     class Config:
         env_file = ".env"
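Because `Settings` extends pydantic's `BaseSettings` and points `env_file` at `.env`, each field above can be overridden by an environment variable of the same name, which is how the Dockerfiles later in this commit flip backends. A minimal sketch, assuming the module is importable as `api.src.core.config`:

```python
# Hypothetical sketch: override Settings fields via environment variables.
import os

os.environ["USE_GPU"] = "true"    # matches the use_gpu field
os.environ["USE_ONNX"] = "false"  # matches the use_onnx field

from api.src.core.config import Settings  # import path assumed

settings = Settings()
assert settings.use_gpu is True
assert settings.use_onnx is False
```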
@@ -11,7 +11,7 @@ class ONNXCPUConfig(BaseModel):
     instance_timeout: int = Field(300, description="Session timeout in seconds")

     # Runtime settings
-    num_threads: int = Field(8, description="Number of threads for parallel operations")
+    num_threads: int = Field(4, description="Number of threads for parallel operations")
     inter_op_threads: int = Field(4, description="Number of threads for operator parallelism")
     execution_mode: str = Field("parallel", description="ONNX execution mode")
     optimization_level: str = Field("all", description="ONNX optimization level")

@@ -77,6 +77,10 @@ class ModelConfig(BaseModel):
     cache_voices: bool = Field(True, description="Whether to cache voice tensors")
     voice_cache_size: int = Field(2, description="Maximum number of cached voices")

+    # Model filenames
+    pytorch_model_file: str = Field("kokoro-v0_19.pth", description="PyTorch model filename")
+    onnx_model_file: str = Field("kokoro-v0_19.onnx", description="ONNX model filename")
+
     # Backend-specific configs
     onnx_cpu: ONNXCPUConfig = Field(default_factory=ONNXCPUConfig)
     onnx_gpu: ONNXGPUConfig = Field(default_factory=ONNXGPUConfig)
@@ -17,26 +17,27 @@ from .pytorch_gpu import PyTorchGPUBackend
 from .session_pool import CPUSessionPool, StreamingSessionPool


-# Global singleton instance and state
+# Global singleton instance and lock for thread-safe initialization
 _manager_instance = None
 _manager_lock = asyncio.Lock()
 _loaded_models = {}
 _backends = {}


 class ModelManager:
     """Manages model loading and inference across backends."""

     # Class-level state for shared resources
     _loaded_models = {}
     _backends = {}

     def __init__(self, config: Optional[ModelConfig] = None):
         """Initialize model manager.

         Args:
             config: Optional configuration

         Note:
             This should not be called directly. Use get_manager() instead.
         """
         self._config = config or model_config
         global _loaded_models, _backends
         self._loaded_models = _loaded_models
         self._backends = _backends

         # Initialize session pools
         self._session_pools = {

@@ -293,10 +294,20 @@ async def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:

     Returns:
         ModelManager instance

     Thread Safety:
         This function is thread-safe and ensures only one instance is created
         even under concurrent access.
     """
     global _manager_instance

     # Fast path - return existing instance without lock
     if _manager_instance is not None:
         return _manager_instance

     # Slow path - create new instance with lock
     async with _manager_lock:
         # Double-check pattern
         if _manager_instance is None:
             _manager_instance = ModelManager(config)
             await _manager_instance.initialize()
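The fast path skips the lock entirely once the singleton exists, and the double-check inside `async with` prevents two coroutines that both miss the fast path from each building an instance. Note that `asyncio.Lock` serializes concurrent coroutines on one event loop rather than OS threads. A usage sketch:

```python
# Sketch only: assumes get_manager is imported from the manager module above.
import asyncio

async def main():
    # First call takes the slow path: acquire the lock, construct, initialize.
    manager = await get_manager()
    # Later calls return the same object via the lock-free fast path.
    assert (await get_manager()) is manager

asyncio.run(main())
```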
@@ -75,20 +75,16 @@ def create_provider_options(is_gpu: bool = False) -> Dict:
     if is_gpu:
         config = model_config.onnx_gpu
         return {
-            "CUDAExecutionProvider": {
-                "device_id": config.device_id,
-                "arena_extend_strategy": config.arena_extend_strategy,
-                "gpu_mem_limit": int(config.gpu_mem_limit * torch.cuda.get_device_properties(0).total_memory),
-                "cudnn_conv_algo_search": config.cudnn_conv_algo_search,
-                "do_copy_in_default_stream": config.do_copy_in_default_stream
-            }
+            "device_id": config.device_id,
+            "arena_extend_strategy": config.arena_extend_strategy,
+            "gpu_mem_limit": int(config.gpu_mem_limit * torch.cuda.get_device_properties(0).total_memory),
+            "cudnn_conv_algo_search": config.cudnn_conv_algo_search,
+            "do_copy_in_default_stream": config.do_copy_in_default_stream
         }
     else:
         return {
-            "CPUExecutionProvider": {
-                "arena_extend_strategy": model_config.onnx_cpu.arena_extend_strategy,
-                "cpu_memory_arena_cfg": "cpu:0"
-            }
+            "arena_extend_strategy": model_config.onnx_cpu.arena_extend_strategy,
+            "cpu_memory_arena_cfg": "cpu:0"
         }
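Flattening the nested dicts matches how onnxruntime consumes them: provider names go in `providers` and a parallel list of flat option dicts goes in `provider_options`, so the provider name no longer belongs inside the options themselves. A hedged sketch of the call site (model path assumed):

```python
import onnxruntime as ort

# One options dict per entry in providers. Note gpu_mem_limit was already
# converted to bytes above by multiplying the 0.7 fraction by total device memory.
session = ort.InferenceSession(
    "kokoro-v0_19.onnx",
    providers=["CUDAExecutionProvider"],
    provider_options=[create_provider_options(is_gpu=True)],
)
```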
@@ -73,12 +73,13 @@ class VoiceManager:
         return voice

     def _manage_cache(self) -> None:
-        """Manage voice cache size."""
+        """Manage voice cache size using simple LRU."""
         if len(self._voice_cache) >= self._config.cache_size:
-            # Remove oldest voice
+            # Remove least recently used voice
             oldest = next(iter(self._voice_cache))
             del self._voice_cache[oldest]
-            logger.debug(f"Removed from voice cache: {oldest}")
+            torch.cuda.empty_cache()  # Clean up GPU memory if needed
+            logger.debug(f"Removed LRU voice from cache: {oldest}")

     async def combine_voices(self, voices: List[str], device: str = "cpu") -> str:
         """Combine multiple voices into a new voice.
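Evicting `next(iter(...))` removes the first key in insertion order, so the cache behaves as a true LRU only if entries are re-ordered on access. A minimal sketch of that idea with `OrderedDict` (names hypothetical):

```python
from collections import OrderedDict

voice_cache: OrderedDict = OrderedDict()

def mark_used(name: str) -> None:
    voice_cache.move_to_end(name)  # most recently used moves to the end

def evict_lru() -> None:
    oldest, _ = voice_cache.popitem(last=False)  # first item = least recently used
    print(f"Removed LRU voice from cache: {oldest}")
```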
@@ -1,4 +1,3 @@
-
 """
 FastAPI OpenAI Compatible API
 """

@@ -14,6 +13,7 @@ from fastapi.middleware.cors import CORSMiddleware
 from loguru import logger

 from .core.config import settings
+from .core.model_config import model_config
 from .routers.development import router as dev_router
 from .routers.openai_compatible import router as openai_router
 from .services.tts_service import TTSService

@@ -63,8 +63,8 @@ async def lifespan(app: FastAPI):
     # Get backend and initialize model
     backend = model_manager.get_backend(backend_type)

-    # Use model path directly from settings
-    model_file = settings.pytorch_model_file if not settings.use_onnx else settings.onnx_model_file
+    # Use model path from model_config
+    model_file = model_config.pytorch_model_file if not settings.use_onnx else model_config.onnx_model_file
     model_path = os.path.join(settings.model_dir, model_file)
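With the defaults shown earlier (`use_onnx = True`, `model_dir = "/app/api/src/models"`), the selection resolves as in this small illustration:

```python
import os

use_onnx = True  # settings.use_onnx default
model_file = "kokoro-v0_19.onnx" if use_onnx else "kokoro-v0_19.pth"
model_path = os.path.join("/app/api/src/models", model_file)
# model_path == "/app/api/src/models/kokoro-v0_19.onnx"
```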
@@ -130,12 +130,12 @@ async def create_speech(
             },
         )
     else:
-        # Generate complete audio
-        audio, _ = tts_service._generate_audio(
+        # Generate complete audio using public interface
+        audio, _ = await tts_service.generate_audio(
             text=request.input,
             voice=voice_to_use,
             speed=request.speed,
-            stitch_long_output=True,
+            stitch_long_output=True
         )

         # Convert to requested format
@@ -153,14 +153,37 @@ async def create_speech(
         )

     except ValueError as e:
-        logger.error(f"Invalid request: {str(e)}")
+        # Handle validation errors
+        logger.warning(f"Invalid request: {str(e)}")
         raise HTTPException(
-            status_code=400, detail={"error": "Invalid request", "message": str(e)}
+            status_code=400,
+            detail={
+                "error": "validation_error",
+                "message": str(e),
+                "type": "invalid_request_error"
+            }
         )
+    except RuntimeError as e:
+        # Handle runtime/processing errors
+        logger.error(f"Processing error: {str(e)}")
+        raise HTTPException(
+            status_code=500,
+            detail={
+                "error": "processing_error",
+                "message": "Failed to process audio generation request",
+                "type": "server_error"
+            }
+        )
     except Exception as e:
-        logger.error(f"Error generating speech: {str(e)}")
+        # Handle unexpected errors
+        logger.error(f"Unexpected error in speech generation: {str(e)}")
         raise HTTPException(
-            status_code=500, detail={"error": "Server error", "message": str(e)}
+            status_code=500,
+            detail={
+                "error": "server_error",
+                "message": "An unexpected error occurred",
+                "type": "server_error"
+            }
         )
@@ -173,7 +196,14 @@ async def list_voices():
         return {"voices": voices}
     except Exception as e:
         logger.error(f"Error listing voices: {str(e)}")
-        raise HTTPException(status_code=500, detail=str(e))
+        raise HTTPException(
+            status_code=500,
+            detail={
+                "error": "server_error",
+                "message": "Failed to retrieve voice list",
+                "type": "server_error"
+            }
+        )


 @router.post("/audio/voices/combine")
@@ -199,13 +229,32 @@ async def combine_voices(request: Union[str, List[str]]):
         return {"voices": voices, "voice": combined_voice}

     except ValueError as e:
-        logger.error(f"Invalid voice combination request: {str(e)}")
+        logger.warning(f"Invalid voice combination request: {str(e)}")
         raise HTTPException(
-            status_code=400, detail={"error": "Invalid request", "message": str(e)}
+            status_code=400,
+            detail={
+                "error": "validation_error",
+                "message": str(e),
+                "type": "invalid_request_error"
+            }
         )
+    except RuntimeError as e:
+        logger.error(f"Voice combination processing error: {str(e)}")
+        raise HTTPException(
+            status_code=500,
+            detail={
+                "error": "processing_error",
+                "message": "Failed to process voice combination request",
+                "type": "server_error"
+            }
+        )

     except Exception as e:
-        logger.error(f"Server error during voice combination: {str(e)}")
+        logger.error(f"Unexpected error in voice combination: {str(e)}")
         raise HTTPException(
-            status_code=500, detail={"error": "Server error", "message": "Server error"}
+            status_code=500,
+            detail={
+                "error": "server_error",
+                "message": "An unexpected error occurred",
+                "type": "server_error"
+            }
         )
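The structured `detail` payloads give clients a stable shape to branch on. A hedged client-side sketch (endpoint path and port assumed from the README and Dockerfile):

```python
import requests

resp = requests.post(
    "http://localhost:8880/v1/audio/speech",  # path assumed
    json={"input": "Hello", "voice": "af", "speed": 1.0},
)
if resp.status_code != 200:
    detail = resp.json()["detail"]
    # e.g. {"error": "validation_error", "message": "...", "type": "invalid_request_error"}
    print(f"{detail['error']}: {detail['message']}")
```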
@@ -51,7 +51,7 @@ class TTSService:
         return service

     async def generate_audio(
-        self, text: str, voice: str, speed: float = 1.0
+        self, text: str, voice: str, speed: float = 1.0, stitch_long_output: bool = True
     ) -> Tuple[np.ndarray, float]:
         """Generate audio for text.

@@ -59,9 +59,14 @@ class TTSService:
             text: Input text
             voice: Voice name
             speed: Speed multiplier
+            stitch_long_output: Whether to stitch together long outputs

         Returns:
             Audio samples and processing time

+        Raises:
+            ValueError: If text is empty after preprocessing or no chunks generated
+            RuntimeError: If audio generation fails
         """
         start_time = time.time()
         voice_tensor = None
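With `stitch_long_output` now part of the public signature, the router call shown earlier becomes a plain awaited call. A hypothetical usage:

```python
# Sketch only: assumes an initialized TTSService instance is passed in.
async def demo(tts_service) -> None:
    audio, seconds = await tts_service.generate_audio(
        text="A passage long enough to be split into several chunks...",
        voice="af",
        speed=1.0,
        stitch_long_output=True,  # chunk long text and stitch the audio back together
    )
    print(f"{len(audio)} samples in {seconds:.2f}s")
```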
@@ -10,7 +10,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
     && apt-get clean \
     && rm -rf /var/lib/apt/lists/*

-# Install uv
+# Install uv for speed and glory
 COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/

 # Create non-root user

@@ -44,6 +44,7 @@ ENV PYTHONPATH=/app
 ENV PATH="/app/.venv/bin:$PATH"
 ENV UV_LINK_MODE=copy
 ENV USE_GPU=false
+ENV USE_ONNX=true

 # Run FastAPI server
 CMD ["uv", "run", "python", "-m", "uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]
@@ -1,22 +0,0 @@
[project]
name = "kokoro-fastapi-cpu"
version = "0.1.0"
description = "FastAPI TTS Service - CPU Version"
readme = "../README.md"
requires-python = ">=3.10"
dependencies = [
    # Core ML/DL for CPU
    "torch>=2.5.1",
    "transformers==4.47.1",
]

[tool.uv.workspace]
members = ["../shared"]

[tool.uv.sources]
torch = { index = "pytorch-cpu" }

[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true
@@ -1,229 +0,0 @@
# This file was autogenerated by uv via the following command:
#    uv pip compile pyproject.toml ../shared/pyproject.toml --output-file requirements.lock
aiofiles==23.2.1
    # via kokoro-fastapi (../shared/pyproject.toml)
annotated-types==0.7.0
    # via pydantic
anyio==4.8.0
    # via starlette
attrs==24.3.0
    # via
    #   clldutils
    #   csvw
    #   jsonschema
    #   phonemizer
    #   referencing
babel==2.16.0
    # via csvw
certifi==2024.12.14
    # via requests
cffi==1.17.1
    # via soundfile
charset-normalizer==3.4.1
    # via requests
click==8.1.8
    # via
    #   kokoro-fastapi (../shared/pyproject.toml)
    #   uvicorn
clldutils==3.21.0
    # via segments
colorama==0.4.6
    # via
    #   click
    #   colorlog
    #   csvw
    #   loguru
    #   tqdm
coloredlogs==15.0.1
    # via onnxruntime
colorlog==6.9.0
    # via clldutils
csvw==3.5.1
    # via segments
dlinfo==1.2.1
    # via phonemizer
exceptiongroup==1.2.2
    # via anyio
fastapi==0.115.6
    # via kokoro-fastapi (../shared/pyproject.toml)
filelock==3.16.1
    # via
    #   huggingface-hub
    #   torch
    #   transformers
flatbuffers==24.12.23
    # via onnxruntime
fsspec==2024.12.0
    # via
    #   huggingface-hub
    #   torch
greenlet==3.1.1
    # via sqlalchemy
h11==0.14.0
    # via uvicorn
huggingface-hub==0.27.1
    # via
    #   tokenizers
    #   transformers
humanfriendly==10.0
    # via coloredlogs
idna==3.10
    # via
    #   anyio
    #   requests
isodate==0.7.2
    # via
    #   csvw
    #   rdflib
jinja2==3.1.5
    # via torch
joblib==1.4.2
    # via phonemizer
jsonschema==4.23.0
    # via csvw
jsonschema-specifications==2024.10.1
    # via jsonschema
language-tags==1.2.0
    # via csvw
loguru==0.7.3
    # via kokoro-fastapi (../shared/pyproject.toml)
lxml==5.3.0
    # via clldutils
markdown==3.7
    # via clldutils
markupsafe==3.0.2
    # via
    #   clldutils
    #   jinja2
mpmath==1.3.0
    # via sympy
munch==4.0.0
    # via kokoro-fastapi (../shared/pyproject.toml)
networkx==3.4.2
    # via torch
numpy==2.2.1
    # via
    #   kokoro-fastapi (../shared/pyproject.toml)
    #   onnxruntime
    #   scipy
    #   soundfile
    #   transformers
onnxruntime==1.20.1
    # via kokoro-fastapi (../shared/pyproject.toml)
packaging==24.2
    # via
    #   huggingface-hub
    #   onnxruntime
    #   transformers
phonemizer==3.3.0
    # via kokoro-fastapi (../shared/pyproject.toml)
protobuf==5.29.3
    # via onnxruntime
pycparser==2.22
    # via cffi
pydantic==2.10.4
    # via
    #   kokoro-fastapi (../shared/pyproject.toml)
    #   fastapi
    #   pydantic-settings
pydantic-core==2.27.2
    # via pydantic
pydantic-settings==2.7.0
    # via kokoro-fastapi (../shared/pyproject.toml)
pylatexenc==2.10
    # via clldutils
pyparsing==3.2.1
    # via rdflib
pyreadline3==3.5.4
    # via humanfriendly
python-dateutil==2.9.0.post0
    # via
    #   clldutils
    #   csvw
python-dotenv==1.0.1
    # via
    #   kokoro-fastapi (../shared/pyproject.toml)
    #   pydantic-settings
pyyaml==6.0.2
    # via
    #   huggingface-hub
    #   transformers
rdflib==7.1.2
    # via csvw
referencing==0.35.1
    # via
    #   jsonschema
    #   jsonschema-specifications
regex==2024.11.6
    # via
    #   kokoro-fastapi (../shared/pyproject.toml)
    #   segments
    #   tiktoken
    #   transformers
requests==2.32.3
    # via
    #   kokoro-fastapi (../shared/pyproject.toml)
    #   csvw
    #   huggingface-hub
    #   tiktoken
    #   transformers
rfc3986==1.5.0
    # via csvw
rpds-py==0.22.3
    # via
    #   jsonschema
    #   referencing
safetensors==0.5.2
    # via transformers
scipy==1.14.1
    # via kokoro-fastapi (../shared/pyproject.toml)
segments==2.2.1
    # via phonemizer
six==1.17.0
    # via python-dateutil
sniffio==1.3.1
    # via anyio
soundfile==0.13.0
    # via kokoro-fastapi (../shared/pyproject.toml)
sqlalchemy==2.0.27
    # via kokoro-fastapi (../shared/pyproject.toml)
starlette==0.41.3
    # via fastapi
sympy==1.13.1
    # via
    #   onnxruntime
    #   torch
tabulate==0.9.0
    # via clldutils
tiktoken==0.8.0
    # via kokoro-fastapi (../shared/pyproject.toml)
tokenizers==0.21.0
    # via transformers
torch==2.5.1+cpu
    # via kokoro-fastapi-cpu (pyproject.toml)
tqdm==4.67.1
    # via
    #   kokoro-fastapi (../shared/pyproject.toml)
    #   huggingface-hub
    #   transformers
transformers==4.47.1
    # via kokoro-fastapi-cpu (pyproject.toml)
typing-extensions==4.12.2
    # via
    #   anyio
    #   fastapi
    #   huggingface-hub
    #   phonemizer
    #   pydantic
    #   pydantic-core
    #   sqlalchemy
    #   torch
    #   uvicorn
uritemplate==4.1.1
    # via csvw
urllib3==2.3.0
    # via requests
uvicorn==0.34.0
    # via kokoro-fastapi (../shared/pyproject.toml)
win32-setctime==1.2.0
    # via loguru
docker/cpu/uv.lock (generated): 1841 lines changed
File diff suppressed because it is too large
@@ -1,8 +1,11 @@
-FROM pytorch/pytorch:2.1.0-cuda11.8-cudnn8-runtime
+FROM nvidia/cuda:12.3.2-cudnn9-runtime-ubuntu22.04

 # Set non-interactive frontend
 ENV DEBIAN_FRONTEND=noninteractive

-# Install dependencies
+# Install Python and other dependencies
 RUN apt-get update && apt-get install -y --no-install-recommends \
+    python3.10 \
+    python3.10-venv \
     espeak-ng \
     git \
     libsndfile1 \

@@ -27,15 +30,15 @@ WORKDIR /app
 # Copy dependency files
 COPY --chown=appuser:appuser pyproject.toml ./pyproject.toml

-# Install dependencies
+# Install dependencies with GPU extras
 RUN --mount=type=cache,target=/root/.cache/uv \
     uv venv && \
-    uv sync --extra gpu --no-install-project
+    uv sync --extra gpu

 # Copy project files
 COPY --chown=appuser:appuser api ./api

-# Install project
+# Install project with GPU extras
 RUN --mount=type=cache,target=/root/.cache/uv \
     uv sync --extra gpu
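Because the base image switches from `pytorch/pytorch` (which bundles Python and torch) to a bare `nvidia/cuda` runtime, Python comes from apt and torch from the uv GPU extra. A quick in-container smoke test, as an assumption-laden sketch:

```python
import torch

# True only if the cu121 wheel and the image's CUDA runtime line up.
print(torch.cuda.is_available())
print(torch.version.cuda)  # e.g. "12.1" for the cu121 wheel
```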
@@ -1,22 +0,0 @@
[project]
name = "kokoro-fastapi-gpu"
version = "0.1.0"
description = "FastAPI TTS Service - GPU Version"
readme = "../README.md"
requires-python = ">=3.10"
dependencies = [
    # Core ML/DL for GPU
    "torch==2.5.1+cu121",
    "transformers==4.47.1",
]

[tool.uv.workspace]
members = ["../shared"]

[tool.uv.sources]
torch = { index = "pytorch-cuda" }

[[tool.uv.index]]
name = "pytorch-cuda"
url = "https://download.pytorch.org/whl/cu121"
explicit = true
@@ -1,229 +0,0 @@
# This file was autogenerated by uv via the following command:
#    uv pip compile pyproject.toml ../shared/pyproject.toml --output-file requirements.lock
#
# Identical to the deleted docker/cpu/requirements.lock above, except for the
# CUDA torch build and the project name in the "via" annotations:
torch==2.5.1+cu121
    # via kokoro-fastapi-gpu (pyproject.toml)
transformers==4.47.1
    # via kokoro-fastapi-gpu (pyproject.toml)
docker/gpu/uv.lock (generated): 1914 lines changed
File diff suppressed because it is too large
Binary file not shown.
@@ -16,7 +16,6 @@ dependencies = [
     # ML/DL Base
     "numpy>=1.26.0",
     "scipy==1.14.1",
-    "onnxruntime==1.20.1",
     # Audio processing
     "soundfile==0.13.0",
     # Text processing

@@ -39,9 +38,11 @@ dependencies = [
 [project.optional-dependencies]
 gpu = [
     "torch==2.5.1+cu121",
+    "onnxruntime-gpu==1.20.1",
 ]
 cpu = [
     "torch==2.5.1",
+    "onnxruntime==1.20.1",
 ]
 test = [
     "pytest==8.0.0",
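Moving `onnxruntime` out of the base dependencies and into the `cpu`/`gpu` extras means each image installs exactly one runtime flavor. A hedged way to verify which one is active:

```python
import onnxruntime as ort

# The gpu extra (onnxruntime-gpu) exposes CUDAExecutionProvider;
# the cpu extra lists only CPUExecutionProvider.
print(ort.get_available_providers())
```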