mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
- SQLAlchemy integration for TTS queue management
- Model pre-loading and database initialization in the FastAPI app lifespan.
This commit is contained in:
parent
5afb9e9be8
commit
60a19bde43
9 changed files with 186 additions and 187 deletions
|
@ -27,7 +27,8 @@ RUN pip3 install --no-cache-dir \
|
|||
uvicorn==0.34.0 \
|
||||
pydantic==2.10.4 \
|
||||
pydantic-settings==2.7.0 \
|
||||
python-dotenv==1.0.1
|
||||
python-dotenv==1.0.1 \
|
||||
sqlalchemy==2.0.27
|
||||
|
||||
# Set working directory
|
||||
WORKDIR /app
|
||||
|
|
26
api/src/database/database.py
Normal file
26
api/src/database/database.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
from .models import Base
|
||||
|
||||
DB_PATH = Path(__file__).parent.parent / "output" / "queue.db"
|
||||
os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
|
||||
|
||||
SQLALCHEMY_DATABASE_URL = f"sqlite:///{DB_PATH}"
|
||||
engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})
|
||||
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
def init_db():
|
||||
"""Create tables if they don't exist"""
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
def get_db():
|
||||
"""Get database session"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
19
api/src/database/models.py
Normal file
19
api/src/database/models.py
Normal file
|
@ -0,0 +1,19 @@
|
|||
from datetime import datetime
|
||||
from sqlalchemy import Column, Integer, String, Float, Boolean, DateTime, Enum as SQLEnum
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from ..models.schemas import TTSStatus
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
class TTSQueue(Base):
|
||||
__tablename__ = "tts_queue"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
text = Column(String, nullable=False)
|
||||
voice = Column(String, default="af")
|
||||
speed = Column(Float, default=1.0)
|
||||
stitch_long_output = Column(Boolean, default=True)
|
||||
status = Column(SQLEnum(TTSStatus), default=TTSStatus.PENDING)
|
||||
output_file = Column(String, nullable=True)
|
||||
processing_time = Column(Float, nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
|
@ -1,153 +1,52 @@
|
|||
import sqlite3
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
|
||||
DB_PATH = Path(__file__).parent.parent / "output" / "queue.db"
|
||||
from typing import Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from .models import TTSQueue
|
||||
from .database import init_db
|
||||
from ..models.schemas import TTSStatus
|
||||
|
||||
|
||||
class QueueDB:
|
||||
def __init__(self, db_path: str = str(DB_PATH)):
|
||||
self.db_path = db_path
|
||||
os.makedirs(os.path.dirname(db_path), exist_ok=True)
|
||||
self._init_db()
|
||||
|
||||
def _init_db(self):
|
||||
"""Initialize the database with required tables"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
c = conn.cursor()
|
||||
c.execute("""
|
||||
CREATE TABLE IF NOT EXISTS tts_queue
|
||||
(id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
text TEXT NOT NULL,
|
||||
voice TEXT DEFAULT 'af',
|
||||
stitch_long_output BOOLEAN DEFAULT 1,
|
||||
status TEXT DEFAULT 'pending',
|
||||
output_file TEXT,
|
||||
processing_time REAL,
|
||||
speed REAL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)
|
||||
""")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
def _ensure_table_if_needed(self, conn: sqlite3.Connection):
|
||||
"""Create table if it doesn't exist, only called for write operations"""
|
||||
c = conn.cursor()
|
||||
c.execute("""
|
||||
CREATE TABLE IF NOT EXISTS tts_queue
|
||||
(id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
text TEXT NOT NULL,
|
||||
voice TEXT DEFAULT 'af',
|
||||
stitch_long_output BOOLEAN DEFAULT 1,
|
||||
status TEXT DEFAULT 'pending',
|
||||
output_file TEXT,
|
||||
processing_time REAL,
|
||||
speed REAL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)
|
||||
""")
|
||||
conn.commit()
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
init_db() # Ensure tables exist
|
||||
|
||||
def add_request(self, text: str, voice: str, speed: float, stitch_long_output: bool = True) -> int:
|
||||
"""Add a new TTS request to the queue"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
try:
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
"INSERT INTO tts_queue (text, voice, speed, stitch_long_output) VALUES (?, ?, ?, ?)",
|
||||
(text, voice, speed, stitch_long_output)
|
||||
)
|
||||
request_id = c.lastrowid
|
||||
conn.commit()
|
||||
return request_id
|
||||
except sqlite3.OperationalError: # Table doesn't exist
|
||||
self._ensure_table_if_needed(conn)
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
"INSERT INTO tts_queue (text, voice, speed, stitch_long_output) VALUES (?, ?, ?, ?)",
|
||||
(text, voice, speed, stitch_long_output)
|
||||
)
|
||||
request_id = c.lastrowid
|
||||
conn.commit()
|
||||
return request_id
|
||||
finally:
|
||||
conn.close()
|
||||
db_item = TTSQueue(
|
||||
text=text,
|
||||
voice=voice,
|
||||
speed=speed,
|
||||
stitch_long_output=stitch_long_output
|
||||
)
|
||||
self.db.add(db_item)
|
||||
self.db.commit()
|
||||
self.db.refresh(db_item)
|
||||
return db_item.id
|
||||
|
||||
def get_next_pending(self) -> Optional[Tuple[int, str, float, str]]:
|
||||
def get_next_pending(self) -> Optional[TTSQueue]:
|
||||
"""Get the next pending request"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
try:
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
'SELECT id, text, voice, speed, stitch_long_output FROM tts_queue WHERE status = "pending" ORDER BY created_at ASC LIMIT 1'
|
||||
)
|
||||
return c.fetchone()
|
||||
except sqlite3.OperationalError: # Table doesn't exist
|
||||
return None
|
||||
finally:
|
||||
conn.close()
|
||||
return self.db.query(TTSQueue)\
|
||||
.filter(TTSQueue.status == TTSStatus.PENDING)\
|
||||
.order_by(TTSQueue.created_at)\
|
||||
.first()
|
||||
|
||||
def update_status(
|
||||
self,
|
||||
request_id: int,
|
||||
status: str,
|
||||
status: TTSStatus,
|
||||
output_file: Optional[str] = None,
|
||||
processing_time: Optional[float] = None,
|
||||
):
|
||||
"""Update request status, output file, and processing time"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
try:
|
||||
c = conn.cursor()
|
||||
if output_file and processing_time is not None:
|
||||
c.execute(
|
||||
"UPDATE tts_queue SET status = ?, output_file = ?, processing_time = ? WHERE id = ?",
|
||||
(status, output_file, processing_time, request_id),
|
||||
)
|
||||
elif output_file:
|
||||
c.execute(
|
||||
"UPDATE tts_queue SET status = ?, output_file = ? WHERE id = ?",
|
||||
(status, output_file, request_id),
|
||||
)
|
||||
else:
|
||||
c.execute(
|
||||
"UPDATE tts_queue SET status = ? WHERE id = ?", (status, request_id)
|
||||
)
|
||||
conn.commit()
|
||||
except sqlite3.OperationalError: # Table doesn't exist
|
||||
self._ensure_table_if_needed(conn)
|
||||
# Retry the update
|
||||
c = conn.cursor()
|
||||
if output_file and processing_time is not None:
|
||||
c.execute(
|
||||
"UPDATE tts_queue SET status = ?, output_file = ?, processing_time = ? WHERE id = ?",
|
||||
(status, output_file, processing_time, request_id),
|
||||
)
|
||||
elif output_file:
|
||||
c.execute(
|
||||
"UPDATE tts_queue SET status = ?, output_file = ? WHERE id = ?",
|
||||
(status, output_file, request_id),
|
||||
)
|
||||
else:
|
||||
c.execute(
|
||||
"UPDATE tts_queue SET status = ? WHERE id = ?", (status, request_id)
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
request = self.db.query(TTSQueue).filter(TTSQueue.id == request_id).first()
|
||||
if request:
|
||||
request.status = status
|
||||
if output_file:
|
||||
request.output_file = output_file
|
||||
if processing_time is not None:
|
||||
request.processing_time = processing_time
|
||||
self.db.commit()
|
||||
|
||||
def get_status(
|
||||
self, request_id: int
|
||||
) -> Optional[Tuple[str, Optional[str], Optional[float]]]:
|
||||
"""Get status, output file, and processing time for a request"""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
try:
|
||||
c = conn.cursor()
|
||||
c.execute(
|
||||
"SELECT status, output_file, processing_time FROM tts_queue WHERE id = ?",
|
||||
(request_id,),
|
||||
)
|
||||
return c.fetchone()
|
||||
except sqlite3.OperationalError: # Table doesn't exist
|
||||
return None
|
||||
finally:
|
||||
conn.close()
|
||||
def get_status(self, request_id: int) -> Optional[TTSQueue]:
|
||||
"""Get full request details by ID"""
|
||||
return self.db.query(TTSQueue).filter(TTSQueue.id == request_id).first()
|
||||
|
|
|
@ -1,15 +1,32 @@
|
|||
import uvicorn
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from .core.config import settings
|
||||
from .routers import tts_router
|
||||
from .database.database import init_db
|
||||
from .services.tts import TTSModel
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Lifespan context manager for database and model initialization"""
|
||||
print("Initializing database and preloading models...")
|
||||
init_db() # Initialize database tables
|
||||
|
||||
# Preload TTS model and default voice
|
||||
TTSModel.get_instance() # This loads the model
|
||||
TTSModel.get_voicepack("af") # Preload default voice, optional
|
||||
print("Initialization complete!")
|
||||
|
||||
yield
|
||||
|
||||
# Initialize FastAPI app
|
||||
app = FastAPI(
|
||||
title=settings.api_title,
|
||||
description=settings.api_description,
|
||||
version=settings.api_version,
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# Add CORS middleware
|
||||
|
|
|
@ -1,6 +1,29 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class TTSStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
PROCESSING = "processing"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class TTSQueueModel(BaseModel):
|
||||
id: Optional[int] = None
|
||||
text: str
|
||||
voice: str = "af"
|
||||
speed: float = 1.0
|
||||
stitch_long_output: bool = True
|
||||
status: TTSStatus = TTSStatus.PENDING
|
||||
output_file: Optional[str] = None
|
||||
processing_time: Optional[float] = None
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class TTSRequest(BaseModel):
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
import os
|
||||
from fastapi import APIRouter, HTTPException, Response
|
||||
from fastapi import APIRouter, HTTPException, Response, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
from ..models.schemas import TTSRequest, TTSResponse, VoicesResponse
|
||||
from ..services.tts import TTSService
|
||||
from ..database.database import get_db
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/tts",
|
||||
|
@ -9,19 +11,20 @@ router = APIRouter(
|
|||
responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
# Initialize TTS service
|
||||
tts_service = TTSService()
|
||||
def get_tts_service(db: Session = Depends(get_db)) -> TTSService:
|
||||
"""Dependency to get TTSService instance with database session"""
|
||||
return TTSService(db)
|
||||
|
||||
|
||||
@router.get("/voices", response_model=VoicesResponse)
|
||||
async def get_voices():
|
||||
async def get_voices(tts_service: TTSService = Depends(get_tts_service)):
|
||||
"""List all available voices"""
|
||||
voices = tts_service.list_voices()
|
||||
return {"voices": voices, "default": "af"}
|
||||
|
||||
|
||||
@router.post("", response_model=TTSResponse)
|
||||
async def create_tts(request: TTSRequest):
|
||||
async def create_tts(request: TTSRequest, tts_service: TTSService = Depends(get_tts_service)):
|
||||
"""Submit text for TTS generation"""
|
||||
# Validate voice exists
|
||||
voices = tts_service.list_voices()
|
||||
|
@ -47,37 +50,35 @@ async def create_tts(request: TTSRequest):
|
|||
|
||||
|
||||
@router.get("/{request_id}", response_model=TTSResponse)
|
||||
async def get_status(request_id: int):
|
||||
async def get_status(request_id: int, tts_service: TTSService = Depends(get_tts_service)):
|
||||
"""Check the status of a TTS request"""
|
||||
status = tts_service.get_request_status(request_id)
|
||||
if not status:
|
||||
request = tts_service.get_request_status(request_id)
|
||||
if not request:
|
||||
raise HTTPException(status_code=404, detail="Request not found")
|
||||
|
||||
status_str, output_file, processing_time = status
|
||||
return {
|
||||
"request_id": request_id,
|
||||
"status": status_str,
|
||||
"output_file": output_file,
|
||||
"processing_time": processing_time,
|
||||
"request_id": request.id,
|
||||
"status": request.status,
|
||||
"output_file": request.output_file,
|
||||
"processing_time": request.processing_time,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/file/{request_id}")
|
||||
async def get_file(request_id: int):
|
||||
async def get_file(request_id: int, tts_service: TTSService = Depends(get_tts_service)):
|
||||
"""Download the generated audio file"""
|
||||
status = tts_service.get_request_status(request_id)
|
||||
if not status:
|
||||
request = tts_service.get_request_status(request_id)
|
||||
if not request:
|
||||
raise HTTPException(status_code=404, detail="Request not found")
|
||||
|
||||
status_str, output_file, _ = status
|
||||
if status_str != "completed":
|
||||
if request.status != "completed":
|
||||
raise HTTPException(status_code=400, detail="Audio generation not complete")
|
||||
|
||||
if not output_file or not os.path.exists(output_file):
|
||||
if not request.output_file or not os.path.exists(request.output_file):
|
||||
raise HTTPException(status_code=404, detail="Audio file not found")
|
||||
|
||||
# Read file and ensure it's closed after
|
||||
with open(output_file, "rb") as f:
|
||||
with open(request.output_file, "rb") as f:
|
||||
content = f.read()
|
||||
|
||||
return Response(
|
||||
|
|
|
@ -2,7 +2,10 @@ import os
|
|||
import threading
|
||||
import time
|
||||
import io
|
||||
from typing import Optional, Tuple, List
|
||||
from typing import Optional, List, Tuple
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from ..models.schemas import TTSStatus
|
||||
from ..database.models import TTSQueue
|
||||
import numpy as np
|
||||
import torch
|
||||
import scipy.io.wavfile as wavfile
|
||||
|
@ -45,11 +48,12 @@ class TTSModel:
|
|||
|
||||
|
||||
class TTSService:
|
||||
def __init__(self, output_dir: str = None):
|
||||
def __init__(self, db: Session, output_dir: str = None):
|
||||
if output_dir is None:
|
||||
output_dir = os.path.join(os.path.dirname(__file__), "..", "output")
|
||||
self.output_dir = output_dir
|
||||
self.db = QueueDB()
|
||||
self.db = QueueDB(db)
|
||||
self.engine = db.get_bind() # Get engine from session
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
self._start_worker()
|
||||
|
||||
|
@ -146,31 +150,41 @@ class TTSService:
|
|||
|
||||
def _process_queue(self):
|
||||
"""Background worker that processes the queue"""
|
||||
# Create a new session for the background worker
|
||||
Session = sessionmaker(bind=self.engine)
|
||||
|
||||
while True:
|
||||
next_request = self.db.get_next_pending()
|
||||
if next_request:
|
||||
request_id, text, voice, speed, stitch_long_output = next_request
|
||||
try:
|
||||
# Generate audio and measure time
|
||||
audio, processing_time = self._generate_audio(text, voice, speed, stitch_long_output)
|
||||
# Create a new session for each iteration
|
||||
with Session() as session:
|
||||
db = QueueDB(session)
|
||||
request = db.get_next_pending()
|
||||
if request:
|
||||
try:
|
||||
# Generate audio and measure time
|
||||
audio, processing_time = self._generate_audio(
|
||||
request.text,
|
||||
request.voice,
|
||||
request.speed,
|
||||
request.stitch_long_output
|
||||
)
|
||||
|
||||
# Save to file
|
||||
output_file = os.path.abspath(os.path.join(
|
||||
self.output_dir, f"speech_{request_id}.wav"
|
||||
))
|
||||
self._save_audio(audio, output_file)
|
||||
# Save to file
|
||||
output_file = os.path.abspath(os.path.join(
|
||||
self.output_dir, f"speech_{request.id}.wav"
|
||||
))
|
||||
self._save_audio(audio, output_file)
|
||||
|
||||
# Update status with processing time
|
||||
self.db.update_status(
|
||||
request_id,
|
||||
"completed",
|
||||
output_file=output_file,
|
||||
processing_time=processing_time,
|
||||
)
|
||||
# Update status with processing time
|
||||
db.update_status(
|
||||
request.id,
|
||||
TTSStatus.COMPLETED,
|
||||
output_file=output_file,
|
||||
processing_time=processing_time,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing request {request_id}: {str(e)}")
|
||||
self.db.update_status(request_id, "failed")
|
||||
except Exception as e:
|
||||
print(f"Error processing request {request.id}: {str(e)}")
|
||||
db.update_status(request.id, TTSStatus.FAILED)
|
||||
|
||||
time.sleep(1) # Prevent busy waiting
|
||||
|
||||
|
@ -190,8 +204,6 @@ class TTSService:
|
|||
"""Create a new TTS request and return the request ID"""
|
||||
return self.db.add_request(text, voice, speed, stitch_long_output)
|
||||
|
||||
def get_request_status(
|
||||
self, request_id: int
|
||||
) -> Optional[Tuple[str, Optional[str], Optional[float]]]:
|
||||
"""Get the status, output file path, and processing time for a request"""
|
||||
def get_request_status(self, request_id: int) -> Optional[TTSQueue]:
|
||||
"""Get the full request details"""
|
||||
return self.db.get_status(request_id)
|
||||
|
|
|
@ -3,6 +3,7 @@ fastapi==0.115.6
|
|||
uvicorn==0.34.0
|
||||
pydantic==2.10.4
|
||||
python-dotenv==1.0.1
|
||||
sqlalchemy==2.0.27
|
||||
|
||||
# ML/DL
|
||||
torch==2.5.1
|
||||
|
|
Loading…
Add table
Reference in a new issue