Kokoro-FastAPI/api/src/main.py

175 lines
5 KiB
Python
Raw Normal View History

"""
FastAPI OpenAI Compatible API
"""
import os
import secrets
import sys
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional

import torch
import uvicorn
from fastapi import Depends, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from loguru import logger

from .core.config import settings
from .routers.debug import router as debug_router
from .routers.development import router as dev_router
from .routers.openai_compatible import router as openai_router
from .routers.web_player import router as web_router


def setup_logger():
    """Install a single colorized stdout handler on the loguru logger.

    Clears the logger's current handlers, registers one DEBUG-level handler
    that writes to ``sys.stdout`` using the custom format below, and sets the
    color of the ERROR level to red.
    """
    # Fields are separated by " | ": time, padded level, module:line, message.
    log_format = " | ".join(
        (
            "<fg #2E8B57>{time:hh:mm:ss A}</fg #2E8B57>",
            "{level: <8}",
            "<fg #4169E1>{module}:{line}</fg #4169E1>",
            "{message}",
        )
    )
    stdout_handler = {
        "sink": sys.stdout,
        "format": log_format,
        "colorize": True,
        "level": "DEBUG",
    }

    logger.remove()
    logger.configure(handlers=[stdout_handler])
    logger.level("ERROR", color="<red>")
# Configure the loguru logger once, at module import time, so the handler and
# format defined in setup_logger() are in place before anything below logs.
setup_logger()
# auto_error=False makes HTTPBasic hand ``None`` to the dependency when the
# request carries no (or a non-Basic) Authorization header, instead of
# rejecting it with 401 on its own. Without this, requests were refused even
# when no credentials were configured and authentication should be skipped.
security = HTTPBasic(auto_error=False)


def get_http_credentials(
    credentials: Optional[HTTPBasicCredentials] = Depends(security),
) -> Optional[str]:
    """Conditionally verify HTTP Basic Auth credentials.

    Authentication is only enforced when both the ``HTTP_USERNAME`` and
    ``HTTP_PASSWORD`` environment variables are set to non-empty values.

    Args:
        credentials: Credentials parsed from the request by ``security``, or
            ``None`` when the request did not supply HTTP Basic credentials.

    Returns:
        ``None`` when authentication is not configured, otherwise the
        authenticated username.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Basic`` header when
            authentication is configured and the credentials are missing or
            do not match.
    """
    expected_username = os.getenv("HTTP_USERNAME")
    expected_password = os.getenv("HTTP_PASSWORD")

    # Skip authentication if credentials are not configured
    if not expected_username or not expected_password:
        return None

    if credentials is not None:
        # Compare as bytes with a constant-time function so response timing
        # does not leak how much of the username/password matched. Both
        # comparisons always run (no short-circuit between them).
        username_ok = secrets.compare_digest(
            credentials.username.encode("utf8"),
            expected_username.encode("utf8"),
        )
        password_ok = secrets.compare_digest(
            credentials.password.encode("utf8"),
            expected_password.encode("utf8"),
        )
        if username_ok and password_ok:
            return credentials.username

    raise HTTPException(
        status_code=401,
        detail="Incorrect username or password",
        headers={"WWW-Authenticate": "Basic"},
    )
2025-02-09 18:32:17 -07:00
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context manager for model initialization.

    Before yielding control to the application it removes old temp files,
    obtains the model and voice managers, warms the model up, and logs a
    startup summary (device, model, voice pack count, web player URL).
    Nothing is executed after the ``yield``.

    Args:
        app: The FastAPI application this lifespan is attached to. It is not
            used inside the function.

    Raises:
        Exception: Any error raised while obtaining the managers or warming up
            the model is logged and then re-raised.
    """
    # Imported inside the function, as in the original, rather than at module
    # import time.
    from .inference.model_manager import get_manager
    from .inference.voice_manager import get_manager as get_voice_manager
    from .services.temp_manager import cleanup_temp_files

    # Clean old temp files on startup
    await cleanup_temp_files()

    logger.info("Loading TTS model and voice packs...")

    try:
        # Initialize managers
        model_manager = await get_manager()
        voice_manager = await get_voice_manager()

        # Initialize model with warmup and get status
        device, model, voicepack_count = await model_manager.initialize_with_warmup(
            voice_manager
        )
    except Exception as e:
        logger.error(f"Failed to initialize model: {e}")
        raise

    # NOTE(review): the boundary literal was an empty string in the reviewed
    # copy (most likely a non-ASCII glyph stripped during extraction), which
    # made every boundary line blank. A block character is used here; any
    # banner art that sat between the two opening boundaries also appears to
    # be lost - confirm both against the upstream file.
    boundary = "\u2591" * 2 * 12

    # Collect the message line by line and join once at the end.
    lines = [
        "",
        boundary,
        boundary,
        "",
        f"Model warmed up on {device}: {model}",
    ]

    if device == "mps":
        lines.append("Using Apple Metal Performance Shaders (MPS)")
    elif device == "cuda":
        lines.append(f"CUDA: {torch.cuda.is_available()}")
    else:
        lines.append("Running on CPU")

    lines.append(f"{voicepack_count} voice packs loaded")

    # Add web player info if enabled
    lines.append("")
    if settings.enable_web_player:
        lines.append(
            f"Beta Web Player: http://{settings.host}:{settings.port}/web/"
        )
        lines.append(f"or http://localhost:{settings.port}/web/")
    else:
        lines.append("Web Player: disabled")

    # Closing boundary followed by a trailing newline
    lines.extend([boundary, ""])

    logger.info("\n".join(lines))

    yield
# Build the FastAPI application from the configured metadata and attach the
# startup lifespan handler.
app = FastAPI(
    lifespan=lifespan,
    title=settings.api_title,
    description=settings.api_description,
    version=settings.api_version,
    openapi_url="/openapi.json",  # Explicitly enable OpenAPI schema
)
# CORS is opt-in: the middleware is only installed when enabled in settings.
if settings.cors_enabled:
    _cors_options = {
        "allow_origins": settings.cors_origins,
        "allow_credentials": True,
        "allow_methods": ["*"],
        "allow_headers": ["*"],
    }
    app.add_middleware(CORSMiddleware, **_cors_options)
# Register the routers; each one is guarded by the HTTP Basic Auth dependency.
_auth_dependencies = [Depends(get_http_credentials)]

# OpenAI-compatible endpoints
app.include_router(
    openai_router,
    prefix="/v1",
    dependencies=_auth_dependencies,
)
# Development endpoints
app.include_router(dev_router, dependencies=_auth_dependencies)
# Debug endpoints
app.include_router(debug_router, dependencies=_auth_dependencies)
# Web player static files (only when the web player is enabled)
if settings.enable_web_player:
    app.include_router(
        web_router,
        prefix="/web",
        dependencies=_auth_dependencies,
    )
# Health check endpoint (registered directly on the app)
@app.get("/health")
async def health_check():
    """Return a fixed payload reporting the service as healthy."""
    payload = {"status": "healthy"}
    return payload
@app.get("/v1/test")
async def test_endpoint():
    """Return a fixed payload; used to verify that routing works."""
    payload = {"status": "ok"}
    return payload
if __name__ == "__main__":
    # Start uvicorn with auto-reload on the configured host and port.
    uvicorn.run(
        "api.src.main:app",
        host=settings.host,
        port=settings.port,
        reload=True,
    )