diff --git a/docker/gpu/Dockerfile b/docker/gpu/Dockerfile index 3134f4e..f1cae1a 100644 --- a/docker/gpu/Dockerfile +++ b/docker/gpu/Dockerfile @@ -44,6 +44,10 @@ COPY --chown=appuser:appuser docker/scripts/download_model.* ./ RUN --mount=type=cache,target=/root/.cache/uv \ uv sync --extra gpu +COPY --chown=appuser:appuser docker/scripts/ /app/docker/scripts/ +RUN chmod +x docker/scripts/entrypoint.sh +RUN chmod +x docker/scripts/download_model.sh + # Set environment variables ENV PYTHONUNBUFFERED=1 ENV PYTHONPATH=/app @@ -55,13 +59,5 @@ ENV USE_ONNX=false ENV DOWNLOAD_PTH=true ENV DOWNLOAD_ONNX=false -# Download models based on environment variables -RUN if [ "$DOWNLOAD_PTH" = "true" ]; then \ - python download_model.py --type pth; \ - fi && \ - if [ "$DOWNLOAD_ONNX" = "true" ]; then \ - python download_model.py --type onnx; \ - fi - # Run FastAPI server -CMD ["uv", "run", "python", "-m", "uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"] +CMD ["/app/docker/scripts/entrypoint.sh"] diff --git a/docker/scripts/download_model.py b/docker/scripts/download_model.py index bc808df..c27f802 100644 --- a/docker/scripts/download_model.py +++ b/docker/scripts/download_model.py @@ -6,7 +6,7 @@ import requests from pathlib import Path from typing import List -def download_file(url: str, output_dir: Path, model_type: str) -> bool: +def download_file(url: str, output_dir: Path, model_type: str, overwrite:str) -> bool: """Download a file from URL to the specified directory. Returns: @@ -19,6 +19,10 @@ def download_file(url: str, output_dir: Path, model_type: str) -> bool: output_path = output_dir / filename + if os.path.exists(output_path): + print(f"{filename} exists. Canceling download") + return True + print(f"Downloading {filename}...") try: response = requests.get(url, stream=True) @@ -52,6 +56,7 @@ def main() -> int: parser = argparse.ArgumentParser(description='Download model files') parser.add_argument('--type', choices=['pth', 'onnx'], required=True, help='Model type to download (pth or onnx)') + parser.add_argument('--overwrite', action='store_true', help='Overwite existing files') parser.add_argument('urls', nargs='*', help='Optional model URLs to download') args = parser.parse_args() @@ -65,7 +70,8 @@ def main() -> int: # Default models if no arguments provided default_models = { 'pth': [ - "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth" + "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth", + "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19-half.pth" ], 'onnx': [ "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.onnx", @@ -79,7 +85,7 @@ def main() -> int: # Download all models success = True for model_url in models_to_download: - if not download_file(model_url, models_dir, args.type): + if not download_file(model_url, models_dir, args.type,args.overwrite): success = False if success: diff --git a/docker/scripts/download_model.sh b/docker/scripts/download_model.sh index 926a831..24513e3 100644 --- a/docker/scripts/download_model.sh +++ b/docker/scripts/download_model.sh @@ -77,6 +77,7 @@ mkdir -p "$MODELS_DIR" if [ "$MODEL_TYPE" = "pth" ]; then DEFAULT_MODELS=( "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth" + "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19-half.pth" ) else DEFAULT_MODELS=( diff --git a/docker/scripts/entrypoint.sh b/docker/scripts/entrypoint.sh new file mode 100644 index 0000000..993e846 --- /dev/null +++ b/docker/scripts/entrypoint.sh @@ -0,0 +1,12 @@ +#!/bin/bash +set -e + +if [ "$DOWNLOAD_PTH" = "true" ]; then + python docker/scripts/download_model.py --type pth +fi + +if [ "$DOWNLOAD_ONNX" = "true" ]; then + python docker/scripts/download_model.py --type onnx +fi + +exec uv run python -m uvicorn api.src.main:app --host 0.0.0.0 --port 8880 --log-level debug \ No newline at end of file