mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
174 lines
5 KiB
Python
174 lines
5 KiB
Python
"""
|
|
FastAPI OpenAI Compatible API
|
|
"""
|
|
|
|
import os
import secrets
import sys
from contextlib import asynccontextmanager
from pathlib import Path

import torch
import uvicorn
from fastapi import Depends, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from loguru import logger

from .core.config import settings
from .routers.debug import router as debug_router
from .routers.development import router as dev_router
from .routers.openai_compatible import router as openai_router
from .routers.web_player import router as web_router
|
|
|
|
|
|
def setup_logger():
    """Install a single colorized stdout handler on the loguru logger.

    Replaces loguru's default handler with one using a custom timestamp,
    level, and module:line format, at DEBUG verbosity.
    """
    log_format = (
        "<fg #2E8B57>{time:hh:mm:ss A}</fg #2E8B57> | "
        "{level: <8} | "
        "<fg #4169E1>{module}:{line}</fg #4169E1> | "
        "{message}"
    )
    stdout_handler = {
        "sink": sys.stdout,
        "format": log_format,
        "colorize": True,
        "level": "DEBUG",
    }
    # Drop loguru's default handler before installing ours.
    logger.remove()
    logger.configure(handlers=[stdout_handler])
    logger.level("ERROR", color="<red>")
|
|
|
|
|
|
# Configure logger as a module-level side effect so logging is ready
# before the FastAPI app and routers below are constructed.
setup_logger()

# HTTP Basic security scheme; actual enforcement (and the option to skip
# auth entirely) lives in get_http_credentials below.
security = HTTPBasic()
|
|
|
|
def get_http_credentials(credentials: HTTPBasicCredentials = Depends(security)):
    """Conditionally verify HTTP Basic Auth credentials.

    Expected credentials are read from the HTTP_USERNAME / HTTP_PASSWORD
    environment variables. If either is unset or empty, authentication is
    skipped entirely and the dependency returns None (open access).

    Returns:
        The authenticated username, or None when auth is not configured.

    Raises:
        HTTPException: 401 with a WWW-Authenticate header when the supplied
            credentials do not match the configured ones.
    """
    username = os.getenv("HTTP_USERNAME")
    password = os.getenv("HTTP_PASSWORD")

    # Skip authentication if credentials not configured
    if not username or not password:
        return None

    # Use constant-time comparison (secrets.compare_digest) instead of `!=`
    # so the check does not leak information via timing differences.
    user_ok = secrets.compare_digest(
        credentials.username.encode("utf-8"), username.encode("utf-8")
    )
    pass_ok = secrets.compare_digest(
        credentials.password.encode("utf-8"), password.encode("utf-8")
    )
    if not (user_ok and pass_ok):
        raise HTTPException(
            status_code=401,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Basic"},
        )
    return credentials.username
|
|
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context manager for model initialization.

    Runs once at application startup: cleans stale temp files, initializes
    the model and voice managers, warms up the model, and logs a startup
    banner. Yields for the app's lifetime; no explicit shutdown work is done
    after the yield.
    """
    # Imported lazily so the (heavy) inference stack is only loaded when the
    # app actually starts, not when this module is imported.
    from .inference.model_manager import get_manager
    from .inference.voice_manager import get_manager as get_voice_manager
    from .services.temp_manager import cleanup_temp_files

    # Clean old temp files on startup
    await cleanup_temp_files()

    logger.info("Loading TTS model and voice packs...")

    try:
        # Initialize managers
        model_manager = await get_manager()
        voice_manager = await get_voice_manager()

        # Initialize model with warmup and get status
        # (device is e.g. "mps"/"cuda"/CPU — see branches below; `model` is
        # whatever the manager reports and is only used in the log message)
        device, model, voicepack_count = await model_manager.initialize_with_warmup(
            voice_manager
        )

    except Exception as e:
        # Fail hard: the server is useless without a warmed-up model.
        logger.error(f"Failed to initialize model: {e}")
        raise

    # Decorative banner for the startup log message.
    boundary = "░" * 2 * 12
    startup_msg = f"""

{boundary}

╔═╗┌─┐┌─┐┌┬┐
╠╣ ├─┤└─┐ │
╚ ┴ ┴└─┘ ┴
╦╔═┌─┐┬┌─┌─┐
╠╩╗│ │├┴┐│ │
╩ ╩└─┘┴ ┴└─┘

{boundary}
"""
    startup_msg += f"\nModel warmed up on {device}: {model}"
    if device == "mps":
        startup_msg += "\nUsing Apple Metal Performance Shaders (MPS)"
    elif device == "cuda":
        startup_msg += f"\nCUDA: {torch.cuda.is_available()}"
    else:
        startup_msg += "\nRunning on CPU"
    startup_msg += f"\n{voicepack_count} voice packs loaded"

    # Add web player info if enabled
    if settings.enable_web_player:
        startup_msg += (
            f"\n\nBeta Web Player: http://{settings.host}:{settings.port}/web/"
        )
        startup_msg += f"\nor http://localhost:{settings.port}/web/"
    else:
        startup_msg += "\n\nWeb Player: disabled"

    startup_msg += f"\n{boundary}\n"
    logger.info(startup_msg)

    # Hand control to the application; nothing to tear down afterwards.
    yield
|
|
|
|
|
|
# Initialize FastAPI app
app = FastAPI(
    title=settings.api_title,
    description=settings.api_description,
    version=settings.api_version,
    lifespan=lifespan,  # model warmup runs via the lifespan context manager above
    openapi_url="/openapi.json",  # Explicitly enable OpenAPI schema
)

# Add CORS middleware if enabled
if settings.cors_enabled:
    app.add_middleware(
        CORSMiddleware,
        allow_origins=settings.cors_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

# Include routers — every router shares the same optional HTTP Basic auth
# dependency (a no-op unless HTTP_USERNAME/HTTP_PASSWORD are configured).
app.include_router(openai_router, prefix="/v1", dependencies=[Depends(get_http_credentials)])
app.include_router(dev_router, dependencies=[Depends(get_http_credentials)])  # Development endpoints
app.include_router(debug_router, dependencies=[Depends(get_http_credentials)])  # Debug endpoints
if settings.enable_web_player:
    app.include_router(web_router, prefix="/web", dependencies=[Depends(get_http_credentials)])  # Web player static files
|
|
|
|
|
|
# Health check endpoint
|
|
@app.get("/health")
|
|
async def health_check():
|
|
"""Health check endpoint"""
|
|
return {"status": "healthy"}
|
|
|
|
|
|
@app.get("/v1/test")
|
|
async def test_endpoint():
|
|
"""Test endpoint to verify routing"""
|
|
return {"status": "ok"}
|
|
|
|
|
|
if __name__ == "__main__":
|
|
uvicorn.run("api.src.main:app", host=settings.host, port=settings.port, reload=True)
|