#!/usr/bin/env python3
"""Download and prepare Kokoro model for Docker build."""
|
|
|
|
|
2025-01-28 13:52:57 -07:00
|
|
|
import argparse
|
2025-02-04 03:37:56 -07:00
|
|
|
import json
|
|
|
|
import os
|
|
|
|
import shutil
|
2025-01-28 13:52:57 -07:00
|
|
|
from pathlib import Path
|
|
|
|
|
2025-02-04 03:37:56 -07:00
|
|
|
import torch
|
|
|
|
from huggingface_hub import hf_hub_download
|
|
|
|
from loguru import logger
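
# Typical invocation during a Docker build stage (the script name and output
# path here are illustrative; pass whatever directory your Dockerfile copies from):
#   python download_model.py --version v1_0 --output /tmp/kokoro_model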


def download_model(version: str, output_dir: str) -> None:
    """Download model files from HuggingFace.

    Args:
        version: Model version to download
        output_dir: Directory to save model files
    """
    try:
        logger.info(f"Downloading Kokoro model version {version}")

        # Create output directory
        os.makedirs(output_dir, exist_ok=True)

        # Download model files
        model_file = hf_hub_download(
            repo_id="hexgrad/Kokoro-82M",
            filename=f"kokoro-{version}.pth"
        )
        config_file = hf_hub_download(
            repo_id="hexgrad/Kokoro-82M",
            filename="config.json"
        )
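
        # hf_hub_download returns the path of the file inside the local
        # Hugging Face cache, so the artifacts are copied into output_dir below.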
        # Copy to output directory
        shutil.copy2(model_file, os.path.join(output_dir, "model.pt"))
        shutil.copy2(config_file, os.path.join(output_dir, "config.json"))
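
        # The checkpoint is stored locally as model.pt rather than under its
        # upstream kokoro-<version>.pth name.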

        # Verify files
        model_path = os.path.join(output_dir, "model.pt")
        config_path = os.path.join(output_dir, "config.json")

        if not os.path.exists(model_path):
            raise RuntimeError(f"Model file not found: {model_path}")
        if not os.path.exists(config_path):
            raise RuntimeError(f"Config file not found: {config_path}")

        # Load and verify model
        logger.info("Verifying model files...")
        with open(config_path) as f:
            config = json.load(f)
            logger.info(f"Loaded config: {config}")
        model = torch.load(model_path, map_location="cpu")
        logger.info(f"Loaded model with keys: {model.keys()}")

        logger.info(f"✓ Model files prepared in {output_dir}")

    except Exception as e:
        logger.error(f"Failed to download model: {e}")
        raise


def main():
    """Main entry point."""
    parser = argparse.ArgumentParser(description="Download Kokoro model for Docker build")
    parser.add_argument(
        "--version",
        default="v1_0",
        help="Model version to download"
    )
    parser.add_argument(
        "--output",
        required=True,
        help="Output directory for model files"
    )
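
    # Note: the version string is interpolated into the checkpoint filename
    # (kokoro-{version}.pth) fetched from hexgrad/Kokoro-82M in download_model().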

    args = parser.parse_args()
    download_model(args.version, args.output)


if __name__ == "__main__":
    main()