Mirror of https://github.com/remsky/Kokoro-FastAPI.git (synced 2025-04-13 09:39:17 +00:00)

Merge pull request #233 from rampadc/master
Commit 04b5dfa84c
7 changed files with 88 additions and 9 deletions
.gitignore (vendored): 3 additions

@@ -70,3 +70,6 @@ examples/speech.mp3
 examples/phoneme_examples/output/*.wav
 examples/assorted_checks/benchmarks/output_audio/*
 uv.lock
+
+# Mac MPS virtualenv for dual testing
+.venv-mps
Settings class:

@@ -1,4 +1,5 @@
 from pydantic_settings import BaseSettings
+import torch


 class Settings(BaseSettings):

@@ -15,6 +16,7 @@ class Settings(BaseSettings):
     default_voice: str = "af_heart"
     default_voice_code: str | None = None  # If set, overrides the first letter of voice name, though api call param still takes precedence
     use_gpu: bool = True  # Whether to use GPU acceleration if available
+    device_type: str | None = None  # Will be auto-detected if None, can be "cuda", "mps", or "cpu"
     allow_local_voice_saving: bool = (
         False  # Whether to allow saving combined voices locally
     )

@@ -51,5 +53,21 @@ class Settings(BaseSettings):
     class Config:
         env_file = ".env"

+    def get_device(self) -> str:
+        """Get the appropriate device based on settings and availability"""
+        if not self.use_gpu:
+            return "cpu"
+
+        if self.device_type:
+            return self.device_type
+
+        # Auto-detect device
+        if torch.backends.mps.is_available():
+            return "mps"
+        elif torch.cuda.is_available():
+            return "cuda"
+        return "cpu"
+

 settings = Settings()
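The new get_device() helper resolves the runtime device in a fixed priority order: use_gpu=False forces CPU; an explicit device_type (settable through the DEVICE_TYPE environment variable, since Settings reads its fields from the environment) wins next; otherwise MPS is preferred over CUDA, with CPU as the final fallback. A minimal sketch of that resolution order, using a hypothetical standalone resolve_device() in place of the real Settings method:

    # Sketch of the resolution order in Settings.get_device();
    # resolve_device() is a hypothetical standalone stand-in.
    import torch

    def resolve_device(use_gpu: bool = True, device_type: str | None = None) -> str:
        if not use_gpu:
            return "cpu"        # GPU disabled outright
        if device_type:
            return device_type  # explicit override, e.g. DEVICE_TYPE=mps
        if torch.backends.mps.is_available():
            return "mps"        # prefer Apple Silicon when present
        if torch.cuda.is_available():
            return "cuda"
        return "cpu"

    print(resolve_device())  # "mps" on Apple Silicon, "cuda" on NVIDIA, else "cpu"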
KokoroV1 backend:

@@ -21,7 +21,7 @@ class KokoroV1(BaseModelBackend):
         """Initialize backend with environment-based configuration."""
         super().__init__()
         # Strictly respect settings.use_gpu
-        self._device = "cuda" if settings.use_gpu else "cpu"
+        self._device = settings.get_device()
         self._model: Optional[KModel] = None
         self._pipelines: Dict[str, KPipeline] = {}  # Store pipelines by lang_code

@@ -48,9 +48,14 @@ class KokoroV1(BaseModelBackend):
             # Load model and let KModel handle device mapping
             self._model = KModel(config=config_path, model=model_path).eval()
-            # Move to CUDA if needed
-            if self._device == "cuda":
+            # For MPS, manually move ISTFT layers to CPU while keeping rest on MPS
+            if self._device == "mps":
+                logger.info("Moving model to MPS device with CPU fallback for unsupported operations")
+                self._model = self._model.to(torch.device("mps"))
+            elif self._device == "cuda":
                 self._model = self._model.cuda()
+            else:
+                self._model = self._model.cpu()
         except FileNotFoundError as e:
             raise e
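The MPS branch leans on PyTorch's CPU fallback for operators that Metal does not implement; the start-gpu_mac.sh script added below exports PYTORCH_ENABLE_MPS_FALLBACK=1 for exactly this reason. A minimal sketch of the same placement pattern, assuming a plain nn.Linear in place of the real KModel:

    # Sketch of MPS-aware placement; nn.Linear stands in for the real KModel.
    import os
    # The fallback flag should be set before torch dispatches an unsupported
    # MPS op; the start script exports it before launching the server.
    os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

    import torch
    import torch.nn as nn

    def place_model(model: nn.Module, device: str) -> nn.Module:
        if device == "mps":
            return model.to(torch.device("mps"))  # Apple Silicon GPU
        if device == "cuda":
            return model.cuda()                   # NVIDIA GPU
        return model.cpu()

    model = place_model(nn.Linear(4, 4).eval(), "cpu")  # or "mps"/"cuda"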
@@ -314,6 +319,7 @@ class KokoroV1(BaseModelBackend):
         if self._device == "cuda":
             memory_gb = torch.cuda.memory_allocated() / 1e9
             return memory_gb > model_config.pytorch_gpu.memory_threshold
+        # MPS doesn't provide memory management APIs
         return False

     def _clear_memory(self) -> None:

@@ -321,6 +327,10 @@ class KokoroV1(BaseModelBackend):
         if self._device == "cuda":
             torch.cuda.empty_cache()
             torch.cuda.synchronize()
+        elif self._device == "mps":
+            # Empty cache if available (future-proofing)
+            if hasattr(torch.mps, 'empty_cache'):
+                torch.mps.empty_cache()

     def unload(self) -> None:
         """Unload model and free resources."""
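The hasattr guard is there because torch.mps.empty_cache only appeared in newer PyTorch releases, while the CUDA equivalents have been available for a long time. A device-agnostic sketch of the same cleanup, assuming only the resolved device string is known:

    # Device-agnostic cache clearing; assumes only a device string.
    import torch

    def clear_device_cache(device: str) -> None:
        if device == "cuda":
            torch.cuda.empty_cache()  # release cached CUDA allocations
            torch.cuda.synchronize()  # wait for in-flight kernels
        elif device == "mps" and hasattr(torch.mps, "empty_cache"):
            torch.mps.empty_cache()   # guarded: absent in older PyTorch builds
        # CPU needs no explicit cache management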
VoiceManager:

@@ -19,7 +19,7 @@ class VoiceManager:
     def __init__(self):
         """Initialize voice manager."""
         # Strictly respect settings.use_gpu
-        self._device = "cuda" if settings.use_gpu else "cpu"
+        self._device = settings.get_device()
         self._voices: Dict[str, torch.Tensor] = {}

     async def get_voice_path(self, voice_name: str) -> str:
FastAPI lifespan (startup message):

@@ -85,7 +85,12 @@ async def lifespan(app: FastAPI):
     {boundary}
     """
     startup_msg += f"\nModel warmed up on {device}: {model}"
-    startup_msg += f"CUDA: {torch.cuda.is_available()}"
+    if device == "mps":
+        startup_msg += "\nUsing Apple Metal Performance Shaders (MPS)"
+    elif device == "cuda":
+        startup_msg += f"\nCUDA: {torch.cuda.is_available()}"
+    else:
+        startup_msg += "\nRunning on CPU"
     startup_msg += f"\n{voicepack_count} voice packs loaded"

     # Add web player info if enabled
Debug system-info endpoint:

@@ -4,6 +4,7 @@ from datetime import datetime

 import psutil
 from fastapi import APIRouter
+import torch

 try:
     import GPUtil

@@ -113,7 +114,14 @@ async def get_system_info():

     # GPU Info if available
     gpu_info = None
-    if GPU_AVAILABLE:
+    if torch.backends.mps.is_available():
+        gpu_info = {
+            "type": "MPS",
+            "available": True,
+            "device": "Apple Silicon",
+            "backend": "Metal"
+        }
+    elif GPU_AVAILABLE:
         try:
             gpus = GPUtil.getGPUs()
             gpu_info = [
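MPS exposes no query API comparable to GPUtil's NVIDIA probing, so on Apple Silicon the endpoint now returns a fixed descriptor instead of enumerating devices. A sketch of the detection branch in isolation, with gpu_available standing in for the router's GPU_AVAILABLE import flag:

    # Sketch of the GPU-info branch; gpu_available mirrors GPU_AVAILABLE.
    import torch

    def gpu_summary(gpu_available: bool):
        if torch.backends.mps.is_available():
            # No MPS query API: report a fixed descriptor.
            return {"type": "MPS", "available": True,
                    "device": "Apple Silicon", "backend": "Metal"}
        if gpu_available:
            import GPUtil  # optional dependency, NVIDIA only
            return [{"name": gpu.name, "load": gpu.load} for gpu in GPUtil.getGPUs()]
        return None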
start-gpu_mac.sh (new executable file): 35 additions

@@ -0,0 +1,35 @@
+#!/bin/bash
+
+# Get project root directory
+PROJECT_ROOT=$(pwd)
+
+# Create mps-specific venv directory
+VENV_DIR="$PROJECT_ROOT/.venv-mps"
+if [ ! -d "$VENV_DIR" ]; then
+    echo "Creating MPS-specific virtual environment..."
+    python3 -m venv "$VENV_DIR"
+fi
+
+# Set other environment variables
+export USE_GPU=true
+export USE_ONNX=false
+export PYTHONPATH=$PROJECT_ROOT:$PROJECT_ROOT/api
+export MODEL_DIR=src/models
+export VOICES_DIR=src/voices/v1_0
+export WEB_PLAYER_PATH=$PROJECT_ROOT/web
+
+# Set environment variables
+export USE_GPU=true
+export USE_ONNX=false
+export PYTHONPATH=$PROJECT_ROOT:$PROJECT_ROOT/api
+export MODEL_DIR=src/models
+export VOICES_DIR=src/voices/v1_0
+export WEB_PLAYER_PATH=$PROJECT_ROOT/web
+
+export DEVICE_TYPE=mps
+# Enable MPS fallback for unsupported operations
+export PYTORCH_ENABLE_MPS_FALLBACK=1
+
+# Run FastAPI with GPU extras using uv run
+uv pip install -e .
+uv run --no-sync uvicorn api.src.main:app --host 0.0.0.0 --port 8880
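Run the script from the repository root (it is committed as an executable file): it creates .venv-mps on first use, installs the project with uv pip install -e ., and serves the API on port 8880. Because it exports DEVICE_TYPE=mps explicitly, get_device() returns "mps" through the override branch rather than through auto-detection, and PYTORCH_ENABLE_MPS_FALLBACK=1 lets any operators unsupported on Metal fall back to the CPU.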