mirror of https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00

commit 9a9bc4aca9 (parent a578d22084)
added support for mps on mac with apple silicon

7 changed files with 89 additions and 10 deletions
.gitignore (vendored): 3 additions

@@ -70,3 +70,6 @@ examples/speech.mp3
 examples/phoneme_examples/output/*.wav
 examples/assorted_checks/benchmarks/output_audio/*
 uv.lock
+
+# Mac MPS virtualenv for dual testing
+.venv-mps
@@ -1,4 +1,5 @@
 from pydantic_settings import BaseSettings
+import torch


 class Settings(BaseSettings):
@@ -15,6 +16,7 @@ class Settings(BaseSettings):
     default_voice: str = "af_heart"
     default_voice_code: str | None = None  # If set, overrides the first letter of voice name, though api call param still takes precedence
     use_gpu: bool = True  # Whether to use GPU acceleration if available
+    device_type: str | None = None  # Will be auto-detected if None, can be "cuda", "mps", or "cpu"
     allow_local_voice_saving: bool = (
         False  # Whether to allow saving combined voices locally
     )
@@ -29,7 +31,7 @@ class Settings(BaseSettings):
     target_min_tokens: int = 175  # Target minimum tokens per chunk
     target_max_tokens: int = 250  # Target maximum tokens per chunk
     absolute_max_tokens: int = 450  # Absolute maximum tokens per chunk
-    advanced_text_normalization: bool = True  # Preproesses the text before misiki which leads
+    advanced_text_normalization: bool = True  # Preprocesses the text before misaki, which leads

     gap_trim_ms: int = 1  # Base amount to trim from streaming chunk ends in milliseconds
     dynamic_gap_trim_padding_ms: int = 410  # Padding to add to dynamic gap trim
@@ -50,5 +52,21 @@ class Settings(BaseSettings):
     class Config:
         env_file = ".env"

+    def get_device(self) -> str:
+        """Get the appropriate device based on settings and availability"""
+        if not self.use_gpu:
+            return "cpu"
+
+        if self.device_type:
+            return self.device_type
+
+        # Auto-detect device
+        if torch.backends.mps.is_available():
+            return "mps"
+        elif torch.cuda.is_available():
+            return "cuda"
+        return "cpu"
+
+

 settings = Settings()
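The resolution order in get_device() is: use_gpu off always wins, then an explicit device_type, then MPS is preferred over CUDA. Because Settings is a pydantic BaseSettings, device_type can also be forced from the environment, which is what the Mac start script below relies on. A minimal sketch of that override (the import path is an assumption for illustration, not shown in the diff):

# Hedged sketch: pydantic-settings maps environment variables to fields
# case-insensitively, so DEVICE_TYPE=mps short-circuits auto-detection.
import os

os.environ["USE_GPU"] = "true"
os.environ["DEVICE_TYPE"] = "mps"

from api.src.core.config import Settings  # import path assumed

print(Settings().get_device())  # -> "mps", even on a machine with CUDA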
@@ -21,7 +21,7 @@ class KokoroV1(BaseModelBackend):
         """Initialize backend with environment-based configuration."""
         super().__init__()
         # Strictly respect settings.use_gpu
-        self._device = "cuda" if settings.use_gpu else "cpu"
+        self._device = settings.get_device()
         self._model: Optional[KModel] = None
         self._pipelines: Dict[str, KPipeline] = {}  # Store pipelines by lang_code

@@ -48,9 +48,14 @@ class KokoroV1(BaseModelBackend):
             # Load model and let KModel handle device mapping
             self._model = KModel(config=config_path, model=model_path).eval()
-            # Move to CUDA if needed
-            if self._device == "cuda":
+            # For MPS, manually move ISTFT layers to CPU while keeping rest on MPS
+            if self._device == "mps":
+                logger.info("Moving model to MPS device with CPU fallback for unsupported operations")
+                self._model = self._model.to(torch.device("mps"))
+            elif self._device == "cuda":
                 self._model = self._model.cuda()
             else:
                 self._model = self._model.cpu()

         except FileNotFoundError as e:
             raise e
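The "CPU fallback for unsupported operations" in the log line is not implemented in this hunk; it comes from PyTorch's PYTORCH_ENABLE_MPS_FALLBACK=1 flag, which the start script below exports. A standalone sketch of the pattern (not the project's code):

# The fallback flag should be in the environment before torch initializes;
# without it, ops that lack Metal kernels raise NotImplementedError on MPS.
import os
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import torch

model = torch.nn.Linear(4, 4).eval()
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = model.to(device)

with torch.no_grad():
    out = model(torch.randn(1, 4, device=device))
print(out.device)  # mps on Apple Silicon, cpu elsewhere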
@@ -277,7 +282,7 @@ class KokoroV1(BaseModelBackend):
                             continue
                         if not token.text or not token.text.strip():
                             continue

                         start_time = float(token.start_ts) + current_offset
                         end_time = float(token.end_ts) + current_offset
                         word_timestamps.append(
@@ -295,8 +300,8 @@ class KokoroV1(BaseModelBackend):
                     logger.error(
                         f"Failed to process timestamps for chunk: {e}"
                     )

                     yield AudioChunk(result.audio.numpy(), word_timestamps=word_timestamps)
                 else:
                     logger.warning("No audio in chunk")
@@ -318,6 +323,7 @@ class KokoroV1(BaseModelBackend):
         if self._device == "cuda":
             memory_gb = torch.cuda.memory_allocated() / 1e9
             return memory_gb > model_config.pytorch_gpu.memory_threshold
+        # MPS doesn't provide memory management APIs
         return False

     def _clear_memory(self) -> None:
@@ -325,6 +331,10 @@ class KokoroV1(BaseModelBackend):
         if self._device == "cuda":
             torch.cuda.empty_cache()
             torch.cuda.synchronize()
+        elif self._device == "mps":
+            # Empty cache if available (future-proofing)
+            if hasattr(torch.mps, 'empty_cache'):
+                torch.mps.empty_cache()

     def unload(self) -> None:
         """Unload model and free resources."""
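Both memory hunks hinge on API availability: torch.cuda has mature memory APIs, while the MPS equivalents only landed in newer PyTorch releases, hence the hasattr guard. The same caution as a standalone sketch (the helper name is illustrative, not project API):

import torch

def clear_accelerator_cache(device: str) -> None:
    # Illustrative helper: return cached allocator memory to the OS.
    if device == "cuda":
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    elif device == "mps" and hasattr(torch.mps, "empty_cache"):
        # Guarded because torch.mps.empty_cache() only exists on newer builds.
        torch.mps.empty_cache()

Newer PyTorch also exposes torch.mps.current_allocated_memory(), a rough analog of torch.cuda.memory_allocated() that a later revision could use to give MPS the same threshold check CUDA gets.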
@@ -19,7 +19,7 @@ class VoiceManager:
     def __init__(self):
         """Initialize voice manager."""
         # Strictly respect settings.use_gpu
-        self._device = "cuda" if settings.use_gpu else "cpu"
+        self._device = settings.get_device()
         self._voices: Dict[str, torch.Tensor] = {}

     async def get_voice_path(self, voice_name: str) -> str:
@@ -85,7 +85,12 @@ async def lifespan(app: FastAPI):
 {boundary}
 """
     startup_msg += f"\nModel warmed up on {device}: {model}"
-    startup_msg += f"CUDA: {torch.cuda.is_available()}"
+    if device == "mps":
+        startup_msg += "\nUsing Apple Metal Performance Shaders (MPS)"
+    elif device == "cuda":
+        startup_msg += f"\nCUDA: {torch.cuda.is_available()}"
+    else:
+        startup_msg += "\nRunning on CPU"
     startup_msg += f"\n{voicepack_count} voice packs loaded"

     # Add web player info if enabled
@@ -4,6 +4,7 @@ from datetime import datetime

 import psutil
 from fastapi import APIRouter
+import torch

 try:
     import GPUtil
@@ -113,7 +114,14 @@ async def get_system_info():

     # GPU Info if available
     gpu_info = None
-    if GPU_AVAILABLE:
+    if torch.backends.mps.is_available():
+        gpu_info = {
+            "type": "MPS",
+            "available": True,
+            "device": "Apple Silicon",
+            "backend": "Metal"
+        }
+    elif GPU_AVAILABLE:
         try:
             gpus = GPUtil.getGPUs()
             gpu_info = [
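GPUtil only enumerates NVIDIA GPUs, so the MPS branch has to run first and assemble its own report. The same probe order as a standalone sketch (the helper name and returned fields beyond those in the hunk are illustrative):

import torch

def probe_gpu():
    # MPS first: GPUtil cannot see Apple GPUs.
    if torch.backends.mps.is_available():
        return {"type": "MPS", "available": True,
                "device": "Apple Silicon", "backend": "Metal"}
    try:
        import GPUtil
        # GPUtil GPU objects carry name, load, memory stats, etc.
        return [{"name": gpu.name, "load": gpu.load} for gpu in GPUtil.getGPUs()]
    except ImportError:
        return None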
start-gpu_mac.sh (new executable file): 35 additions

@@ -0,0 +1,35 @@
+#!/bin/bash
+
+# Get project root directory
+PROJECT_ROOT=$(pwd)
+
+# Create mps-specific venv directory
+VENV_DIR="$PROJECT_ROOT/.venv-mps"
+if [ ! -d "$VENV_DIR" ]; then
+    echo "Creating MPS-specific virtual environment..."
+    python3 -m venv "$VENV_DIR"
+fi
+
+# Set other environment variables
+export USE_GPU=true
+export USE_ONNX=false
+export PYTHONPATH=$PROJECT_ROOT:$PROJECT_ROOT/api
+export MODEL_DIR=src/models
+export VOICES_DIR=src/voices/v1_0
+export WEB_PLAYER_PATH=$PROJECT_ROOT/web
+
+# Set environment variables
+export USE_GPU=true
+export USE_ONNX=false
+export PYTHONPATH=$PROJECT_ROOT:$PROJECT_ROOT/api
+export MODEL_DIR=src/models
+export VOICES_DIR=src/voices/v1_0
+export WEB_PLAYER_PATH=$PROJECT_ROOT/web
+
+export DEVICE_TYPE=mps
+# Enable MPS fallback for unsupported operations
+export PYTORCH_ENABLE_MPS_FALLBACK=1
+
+# Run FastAPI with GPU extras using uv run
+uv pip install -e .
+uv run --no-sync uvicorn api.src.main:app --host 0.0.0.0 --port 8881