Fixes auto downloading models and adds the option to the python script to overwrite existing files

This commit is contained in:
Fireblade 2025-01-28 16:44:20 -05:00
parent 9867fc398f
commit 66ebd0e33c
4 changed files with 27 additions and 12 deletions

View file

@ -44,6 +44,10 @@ COPY --chown=appuser:appuser docker/scripts/download_model.* ./
RUN --mount=type=cache,target=/root/.cache/uv \
uv sync --extra gpu
COPY --chown=appuser:appuser docker/scripts/ /app/docker/scripts/
RUN chmod +x docker/scripts/entrypoint.sh
RUN chmod +x docker/scripts/download_model.sh
# Set environment variables
ENV PYTHONUNBUFFERED=1
ENV PYTHONPATH=/app
@ -55,13 +59,5 @@ ENV USE_ONNX=false
ENV DOWNLOAD_PTH=true
ENV DOWNLOAD_ONNX=false
# Download models based on environment variables
RUN if [ "$DOWNLOAD_PTH" = "true" ]; then \
python download_model.py --type pth; \
fi && \
if [ "$DOWNLOAD_ONNX" = "true" ]; then \
python download_model.py --type onnx; \
fi
# Run FastAPI server
CMD ["uv", "run", "python", "-m", "uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]
CMD ["/app/docker/scripts/entrypoint.sh"]

View file

@ -6,7 +6,7 @@ import requests
from pathlib import Path
from typing import List
def download_file(url: str, output_dir: Path, model_type: str) -> bool:
def download_file(url: str, output_dir: Path, model_type: str, overwrite:str) -> bool:
"""Download a file from URL to the specified directory.
Returns:
@ -19,6 +19,10 @@ def download_file(url: str, output_dir: Path, model_type: str) -> bool:
output_path = output_dir / filename
if os.path.exists(output_path):
print(f"{filename} exists. Canceling download")
return True
print(f"Downloading {filename}...")
try:
response = requests.get(url, stream=True)
@ -52,6 +56,7 @@ def main() -> int:
parser = argparse.ArgumentParser(description='Download model files')
parser.add_argument('--type', choices=['pth', 'onnx'], required=True,
help='Model type to download (pth or onnx)')
parser.add_argument('--overwrite', action='store_true', help='Overwite existing files')
parser.add_argument('urls', nargs='*', help='Optional model URLs to download')
args = parser.parse_args()
@ -65,7 +70,8 @@ def main() -> int:
# Default models if no arguments provided
default_models = {
'pth': [
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth"
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth",
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19-half.pth"
],
'onnx': [
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.onnx",
@ -79,7 +85,7 @@ def main() -> int:
# Download all models
success = True
for model_url in models_to_download:
if not download_file(model_url, models_dir, args.type):
if not download_file(model_url, models_dir, args.type,args.overwrite):
success = False
if success:

View file

@ -77,6 +77,7 @@ mkdir -p "$MODELS_DIR"
if [ "$MODEL_TYPE" = "pth" ]; then
DEFAULT_MODELS=(
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth"
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19-half.pth"
)
else
DEFAULT_MODELS=(

View file

@ -0,0 +1,12 @@
#!/bin/bash
set -e
if [ "$DOWNLOAD_PTH" = "true" ]; then
python docker/scripts/download_model.py --type pth
fi
if [ "$DOWNLOAD_ONNX" = "true" ]; then
python docker/scripts/download_model.py --type onnx
fi
exec uv run python -m uvicorn api.src.main:app --host 0.0.0.0 --port 8880 --log-level debug