2024-12-31 01:52:16 -07:00
|
|
|
"""
|
|
|
|
FastAPI OpenAI Compatible API
|
|
|
|
"""
|
2024-12-31 01:57:00 -07:00
|
|
|
|
2025-01-09 07:20:14 -07:00
|
|
|
import sys
|
2025-01-09 18:41:44 -07:00
|
|
|
from contextlib import asynccontextmanager
|
2024-12-31 02:55:51 -07:00
|
|
|
|
|
|
|
import uvicorn
|
2024-12-30 04:17:50 -07:00
|
|
|
from fastapi import FastAPI
|
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
2025-01-13 20:15:46 -07:00
|
|
|
from loguru import logger
|
2024-12-30 04:17:50 -07:00
|
|
|
|
|
|
|
from .core.config import settings
|
2025-01-09 18:41:44 -07:00
|
|
|
from .routers.development import router as dev_router
|
2024-12-31 02:55:51 -07:00
|
|
|
from .routers.openai_compatible import router as openai_router
|
2025-01-13 20:15:46 -07:00
|
|
|
from .services.tts_model import TTSModel
|
|
|
|
from .services.tts_service import TTSService
|
2025-01-09 07:20:14 -07:00
|
|
|
|
|
|
|
|
|
|
|
def setup_logger():
    """Install the application's loguru configuration.

    Every previously registered loguru handler is dropped and replaced by a
    single colorized handler on stdout at INFO level. Each record shows an
    ``hh:mm:ss AM/PM`` timestamp, the level name padded to 8 characters and
    the message. Finally the ERROR level is recolored red.
    """
    timestamp_part = "<fg #2E8B57>{time:hh:mm:ss A}</fg #2E8B57> | "
    level_part = "{level: <8} | "
    stdout_handler = dict(
        sink=sys.stdout,
        format=timestamp_part + level_part + "{message}",
        colorize=True,
        level="INFO",
    )

    # Start from a clean slate (removes loguru's default handler as well).
    logger.remove()
    logger.configure(handlers=[stdout_handler])
    logger.level("ERROR", color="<red>")


# Configure logging at import time so every later log call uses this setup.
setup_logger()
|
2024-12-31 01:52:16 -07:00
|
|
|
|
2024-12-30 13:21:17 -07:00
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context manager for model initialization.

    FastAPI runs the code before ``yield`` once at application startup:

    * ``TTSModel.setup()`` is awaited to initialize and warm up the model;
      its return value is reported in the banner as the voice-pack count.
    * A startup banner is logged, naming the device returned by
      ``TTSModel.get_device()`` and the number of voice packs loaded.

    No code follows the ``yield``, so no shutdown/cleanup happens here.

    Args:
        app: The FastAPI application being started; not used by this body.
    """
    logger.info("Loading TTS model and voice packs...")

    # Initialize the main model with warm-up
    voicepack_count = await TTSModel.setup()
    # Unused alternative border style, kept for reference:
    # boundary = "█████╗"*9
    # 24-character border line used to frame the banner sections.
    boundary = "░" * 2*12
    startup_msg = f"""

{boundary}

╔═╗┌─┐┌─┐┌┬┐
╠╣ ├─┤└─┐ │
╚ ┴ ┴└─┘ ┴
╦╔═┌─┐┬┌─┌─┐
╠╩╗│ │├┴┐│ │
╩ ╩└─┘┴ ┴└─┘

{boundary}
"""
    # TODO: Improve CPU warmup, threads, memory, etc
    # Append runtime details, then a closing border line.
    startup_msg += f"\nModel warmed up on {TTSModel.get_device()}"
    startup_msg += f"\n{voicepack_count} voice packs loaded\n"
    startup_msg += f"\n{boundary}\n"
    logger.info(startup_msg)

    # Hand control back to FastAPI for the lifetime of the application.
    yield
|
2024-12-30 04:17:50 -07:00
|
|
|
|
2024-12-31 01:57:00 -07:00
|
|
|
|
2024-12-30 04:17:50 -07:00
|
|
|
# Build the application object; model loading and the startup banner are
# handled by the ``lifespan`` context manager.
app = FastAPI(
    lifespan=lifespan,
    title=settings.api_title,
    version=settings.api_version,
    description=settings.api_description,
    openapi_url="/openapi.json",  # stated explicitly so the schema stays enabled
)

# Accept cross-origin requests from any origin, with any method and header.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; confirm this is intended for non-local deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Register routers (order preserved): OpenAI-compatible routes under /v1,
# then the development routes with no extra prefix.
app.include_router(openai_router, prefix="/v1")
app.include_router(dev_router)
# The legacy text router is NOT mounted; re-enable the line below to restore it.
# app.include_router(text_router)
|
2024-12-30 04:17:50 -07:00
|
|
|
|
2024-12-31 01:57:00 -07:00
|
|
|
|
2024-12-30 04:17:50 -07:00
|
|
|
# Liveness probe: always returns a constant payload and does not inspect the
# model or any other service.
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    # The docstring is left verbatim because FastAPI publishes it as the
    # OpenAPI description of this route.
    return dict(status="healthy")
|
|
|
|
|
2024-12-31 01:57:00 -07:00
|
|
|
|
2024-12-31 01:52:16 -07:00
|
|
|
@app.get("/v1/test")
async def test_endpoint():
    """Test endpoint to verify routing"""
    # Constant payload confirming that requests under /v1 reach the app.
    # The docstring is kept verbatim since FastAPI uses it in the OpenAPI docs.
    return dict(status="ok")
|
2024-12-30 04:17:50 -07:00
|
|
|
|
2024-12-31 01:57:00 -07:00
|
|
|
|
2024-12-30 04:17:50 -07:00
|
|
|
if __name__ == "__main__":
    # Development entry point. The app is given as an import string because
    # uvicorn requires that form when auto-reload is enabled.
    uvicorn.run(
        "api.src.main:app",
        host=settings.host,
        port=settings.port,
        reload=True,
    )
|