added support for mps on mac with apple silicon

This commit is contained in:
Cong Nguyen 2025-03-10 11:58:45 +11:00
parent a578d22084
commit 9a9bc4aca9
7 changed files with 89 additions and 10 deletions

3
.gitignore vendored
View file

@ -70,3 +70,6 @@ examples/speech.mp3
examples/phoneme_examples/output/*.wav
examples/assorted_checks/benchmarks/output_audio/*
uv.lock
# Mac MPS virtualenv for dual testing
.venv-mps

View file

@ -1,4 +1,5 @@
from pydantic_settings import BaseSettings
import torch
class Settings(BaseSettings):
@ -15,6 +16,7 @@ class Settings(BaseSettings):
default_voice: str = "af_heart"
default_voice_code: str | None = None # If set, overrides the first letter of voice name, though api call param still takes precedence
use_gpu: bool = True # Whether to use GPU acceleration if available
device_type: str | None = None # Will be auto-detected if None, can be "cuda", "mps", or "cpu"
allow_local_voice_saving: bool = (
False # Whether to allow saving combined voices locally
)
@ -29,7 +31,7 @@ class Settings(BaseSettings):
target_min_tokens: int = 175 # Target minimum tokens per chunk
target_max_tokens: int = 250 # Target maximum tokens per chunk
absolute_max_tokens: int = 450 # Absolute maximum tokens per chunk
advanced_text_normalization: bool = True # Preproesses the text before misiki which leads
advanced_text_normalization: bool = True # Preproesses the text before misiki which leads
gap_trim_ms: int = 1 # Base amount to trim from streaming chunk ends in milliseconds
dynamic_gap_trim_padding_ms: int = 410 # Padding to add to dynamic gap trim
@ -50,5 +52,21 @@ class Settings(BaseSettings):
class Config:
env_file = ".env"
def get_device(self) -> str:
"""Get the appropriate device based on settings and availability"""
if not self.use_gpu:
return "cpu"
if self.device_type:
return self.device_type
# Auto-detect device
if torch.backends.mps.is_available():
return "mps"
elif torch.cuda.is_available():
return "cuda"
return "cpu"
settings = Settings()

View file

@ -21,7 +21,7 @@ class KokoroV1(BaseModelBackend):
"""Initialize backend with environment-based configuration."""
super().__init__()
# Strictly respect settings.use_gpu
self._device = "cuda" if settings.use_gpu else "cpu"
self._device = settings.get_device()
self._model: Optional[KModel] = None
self._pipelines: Dict[str, KPipeline] = {} # Store pipelines by lang_code
@ -48,9 +48,14 @@ class KokoroV1(BaseModelBackend):
# Load model and let KModel handle device mapping
self._model = KModel(config=config_path, model=model_path).eval()
# Move to CUDA if needed
if self._device == "cuda":
# For MPS, manually move ISTFT layers to CPU while keeping rest on MPS
if self._device == "mps":
logger.info("Moving model to MPS device with CPU fallback for unsupported operations")
self._model = self._model.to(torch.device("mps"))
elif self._device == "cuda":
self._model = self._model.cuda()
else:
self._model = self._model.cpu()
except FileNotFoundError as e:
raise e
@ -277,7 +282,7 @@ class KokoroV1(BaseModelBackend):
continue
if not token.text or not token.text.strip():
continue
start_time = float(token.start_ts) + current_offset
end_time = float(token.end_ts) + current_offset
word_timestamps.append(
@ -295,8 +300,8 @@ class KokoroV1(BaseModelBackend):
logger.error(
f"Failed to process timestamps for chunk: {e}"
)
yield AudioChunk(result.audio.numpy(),word_timestamps=word_timestamps)
else:
logger.warning("No audio in chunk")
@ -318,6 +323,7 @@ class KokoroV1(BaseModelBackend):
if self._device == "cuda":
memory_gb = torch.cuda.memory_allocated() / 1e9
return memory_gb > model_config.pytorch_gpu.memory_threshold
# MPS doesn't provide memory management APIs
return False
def _clear_memory(self) -> None:
@ -325,6 +331,10 @@ class KokoroV1(BaseModelBackend):
if self._device == "cuda":
torch.cuda.empty_cache()
torch.cuda.synchronize()
elif self._device == "mps":
# Empty cache if available (future-proofing)
if hasattr(torch.mps, 'empty_cache'):
torch.mps.empty_cache()
def unload(self) -> None:
"""Unload model and free resources."""

View file

@ -19,7 +19,7 @@ class VoiceManager:
def __init__(self):
"""Initialize voice manager."""
# Strictly respect settings.use_gpu
self._device = "cuda" if settings.use_gpu else "cpu"
self._device = settings.get_device()
self._voices: Dict[str, torch.Tensor] = {}
async def get_voice_path(self, voice_name: str) -> str:

View file

@ -85,7 +85,12 @@ async def lifespan(app: FastAPI):
{boundary}
"""
startup_msg += f"\nModel warmed up on {device}: {model}"
startup_msg += f"CUDA: {torch.cuda.is_available()}"
if device == "mps":
startup_msg += "\nUsing Apple Metal Performance Shaders (MPS)"
elif device == "cuda":
startup_msg += f"\nCUDA: {torch.cuda.is_available()}"
else:
startup_msg += "\nRunning on CPU"
startup_msg += f"\n{voicepack_count} voice packs loaded"
# Add web player info if enabled

View file

@ -4,6 +4,7 @@ from datetime import datetime
import psutil
from fastapi import APIRouter
import torch
try:
import GPUtil
@ -113,7 +114,14 @@ async def get_system_info():
# GPU Info if available
gpu_info = None
if GPU_AVAILABLE:
if torch.backends.mps.is_available():
gpu_info = {
"type": "MPS",
"available": True,
"device": "Apple Silicon",
"backend": "Metal"
}
elif GPU_AVAILABLE:
try:
gpus = GPUtil.getGPUs()
gpu_info = [

35
start-gpu_mac.sh Executable file
View file

@ -0,0 +1,35 @@
#!/bin/bash
# Get project root directory
PROJECT_ROOT=$(pwd)
# Create mps-specific venv directory
VENV_DIR="$PROJECT_ROOT/.venv-mps"
if [ ! -d "$VENV_DIR" ]; then
echo "Creating MPS-specific virtual environment..."
python3 -m venv "$VENV_DIR"
fi
# Set other environment variables
export USE_GPU=true
export USE_ONNX=false
export PYTHONPATH=$PROJECT_ROOT:$PROJECT_ROOT/api
export MODEL_DIR=src/models
export VOICES_DIR=src/voices/v1_0
export WEB_PLAYER_PATH=$PROJECT_ROOT/web
# Set environment variables
export USE_GPU=true
export USE_ONNX=false
export PYTHONPATH=$PROJECT_ROOT:$PROJECT_ROOT/api
export MODEL_DIR=src/models
export VOICES_DIR=src/voices/v1_0
export WEB_PLAYER_PATH=$PROJECT_ROOT/web
export DEVICE_TYPE=mps
# Enable MPS fallback for unsupported operations
export PYTORCH_ENABLE_MPS_FALLBACK=1
# Run FastAPI with GPU extras using uv run
uv pip install -e .
uv run --no-sync uvicorn api.src.main:app --host 0.0.0.0 --port 8881