- SQLAlchemy integration for TTS queue management

- Model pre-loading and database initialization in the FastAPI app lifespan.
This commit is contained in:
remsky 2024-12-30 13:21:17 -07:00
parent 5afb9e9be8
commit 60a19bde43
9 changed files with 186 additions and 187 deletions

View file

@ -27,7 +27,8 @@ RUN pip3 install --no-cache-dir \
uvicorn==0.34.0 \
pydantic==2.10.4 \
pydantic-settings==2.7.0 \
python-dotenv==1.0.1
python-dotenv==1.0.1 \
sqlalchemy==2.0.27
# Set working directory
WORKDIR /app

View file

@ -0,0 +1,26 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from pathlib import Path
import os
from .models import Base
# SQLite file backing the TTS queue: "output/queue.db" one level above
# this module's directory.
DB_PATH = Path(__file__).parent.parent / "output" / "queue.db"

# The directory must exist before SQLite can create the database file in it.
DB_PATH.parent.mkdir(parents=True, exist_ok=True)

SQLALCHEMY_DATABASE_URL = f"sqlite:///{DB_PATH}"

# check_same_thread=False turns off sqlite3's check that a connection is only
# used from the thread that created it.
engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},
)

# Session factory bound to the engine above; sessions neither autoflush nor
# autocommit.
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
def init_db():
    """Create the tables registered on ``Base.metadata`` if they are missing.

    Runs against the module-level ``engine``.
    """
    metadata = Base.metadata
    metadata.create_all(bind=engine)
def get_db():
    """Yield a fresh database session and close it when the caller is done.

    Written as a generator so the ``finally`` clause closes the session
    even if the consumer raises while holding it.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()

View file

@ -0,0 +1,19 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, Float, Boolean, DateTime, Enum as SQLEnum
from sqlalchemy.ext.declarative import declarative_base
from ..models.schemas import TTSStatus
# Declarative base class; ORM models subclass it so their tables are
# collected on Base.metadata.
Base = declarative_base()
class TTSQueue(Base):
    """ORM mapping for the ``tts_queue`` table.

    Each row holds the inputs of one TTS request together with its status
    and, once available, the output file path and processing time.
    """

    __tablename__ = "tts_queue"

    # Row identity.
    id = Column(Integer, primary_key=True, index=True)

    # Synthesis inputs.
    text = Column(String, nullable=False)
    voice = Column(String, default="af")
    speed = Column(Float, default=1.0)
    stitch_long_output = Column(Boolean, default=True)

    # Processing state and results; the result columns stay NULL until set.
    status = Column(SQLEnum(TTSStatus), default=TTSStatus.PENDING)
    output_file = Column(String, nullable=True)
    processing_time = Column(Float, nullable=True)

    # Insert-time timestamp produced by datetime.utcnow.
    created_at = Column(DateTime, default=datetime.utcnow)

View file

@ -1,153 +1,52 @@
import sqlite3
import os
from pathlib import Path
from typing import Optional, Tuple
DB_PATH = Path(__file__).parent.parent / "output" / "queue.db"
from typing import Optional
from sqlalchemy.orm import Session
from .models import TTSQueue
from .database import init_db
from ..models.schemas import TTSStatus
class QueueDB:
def __init__(self, db_path: str = str(DB_PATH)):
self.db_path = db_path
os.makedirs(os.path.dirname(db_path), exist_ok=True)
self._init_db()
def _init_db(self):
"""Initialize the database with required tables"""
conn = sqlite3.connect(self.db_path)
c = conn.cursor()
c.execute("""
CREATE TABLE IF NOT EXISTS tts_queue
(id INTEGER PRIMARY KEY AUTOINCREMENT,
text TEXT NOT NULL,
voice TEXT DEFAULT 'af',
stitch_long_output BOOLEAN DEFAULT 1,
status TEXT DEFAULT 'pending',
output_file TEXT,
processing_time REAL,
speed REAL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)
""")
conn.commit()
conn.close()
def _ensure_table_if_needed(self, conn: sqlite3.Connection):
"""Create table if it doesn't exist, only called for write operations"""
c = conn.cursor()
c.execute("""
CREATE TABLE IF NOT EXISTS tts_queue
(id INTEGER PRIMARY KEY AUTOINCREMENT,
text TEXT NOT NULL,
voice TEXT DEFAULT 'af',
stitch_long_output BOOLEAN DEFAULT 1,
status TEXT DEFAULT 'pending',
output_file TEXT,
processing_time REAL,
speed REAL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)
""")
conn.commit()
def __init__(self, db: Session):
    """Wrap an existing SQLAlchemy session for queue operations.

    Args:
        db: Session stored on the instance as ``self.db``.
    """
    self.db = db
    # Runs on every construction to make sure the tables exist.
    init_db()
def add_request(self, text: str, voice: str, speed: float, stitch_long_output: bool = True) -> int:
"""Add a new TTS request to the queue"""
conn = sqlite3.connect(self.db_path)
try:
c = conn.cursor()
c.execute(
"INSERT INTO tts_queue (text, voice, speed, stitch_long_output) VALUES (?, ?, ?, ?)",
(text, voice, speed, stitch_long_output)
)
request_id = c.lastrowid
conn.commit()
return request_id
except sqlite3.OperationalError: # Table doesn't exist
self._ensure_table_if_needed(conn)
c = conn.cursor()
c.execute(
"INSERT INTO tts_queue (text, voice, speed, stitch_long_output) VALUES (?, ?, ?, ?)",
(text, voice, speed, stitch_long_output)
)
request_id = c.lastrowid
conn.commit()
return request_id
finally:
conn.close()
db_item = TTSQueue(
text=text,
voice=voice,
speed=speed,
stitch_long_output=stitch_long_output
)
self.db.add(db_item)
self.db.commit()
self.db.refresh(db_item)
return db_item.id
def get_next_pending(self) -> Optional[Tuple[int, str, float, str]]:
def get_next_pending(self) -> Optional[TTSQueue]:
"""Get the next pending request"""
conn = sqlite3.connect(self.db_path)
try:
c = conn.cursor()
c.execute(
'SELECT id, text, voice, speed, stitch_long_output FROM tts_queue WHERE status = "pending" ORDER BY created_at ASC LIMIT 1'
)
return c.fetchone()
except sqlite3.OperationalError: # Table doesn't exist
return None
finally:
conn.close()
return self.db.query(TTSQueue)\
.filter(TTSQueue.status == TTSStatus.PENDING)\
.order_by(TTSQueue.created_at)\
.first()
def update_status(
self,
request_id: int,
status: str,
status: TTSStatus,
output_file: Optional[str] = None,
processing_time: Optional[float] = None,
):
"""Update request status, output file, and processing time"""
conn = sqlite3.connect(self.db_path)
try:
c = conn.cursor()
if output_file and processing_time is not None:
c.execute(
"UPDATE tts_queue SET status = ?, output_file = ?, processing_time = ? WHERE id = ?",
(status, output_file, processing_time, request_id),
)
elif output_file:
c.execute(
"UPDATE tts_queue SET status = ?, output_file = ? WHERE id = ?",
(status, output_file, request_id),
)
else:
c.execute(
"UPDATE tts_queue SET status = ? WHERE id = ?", (status, request_id)
)
conn.commit()
except sqlite3.OperationalError: # Table doesn't exist
self._ensure_table_if_needed(conn)
# Retry the update
c = conn.cursor()
if output_file and processing_time is not None:
c.execute(
"UPDATE tts_queue SET status = ?, output_file = ?, processing_time = ? WHERE id = ?",
(status, output_file, processing_time, request_id),
)
elif output_file:
c.execute(
"UPDATE tts_queue SET status = ?, output_file = ? WHERE id = ?",
(status, output_file, request_id),
)
else:
c.execute(
"UPDATE tts_queue SET status = ? WHERE id = ?", (status, request_id)
)
conn.commit()
finally:
conn.close()
request = self.db.query(TTSQueue).filter(TTSQueue.id == request_id).first()
if request:
request.status = status
if output_file:
request.output_file = output_file
if processing_time is not None:
request.processing_time = processing_time
self.db.commit()
def get_status(
self, request_id: int
) -> Optional[Tuple[str, Optional[str], Optional[float]]]:
"""Get status, output file, and processing time for a request"""
conn = sqlite3.connect(self.db_path)
try:
c = conn.cursor()
c.execute(
"SELECT status, output_file, processing_time FROM tts_queue WHERE id = ?",
(request_id,),
)
return c.fetchone()
except sqlite3.OperationalError: # Table doesn't exist
return None
finally:
conn.close()
def get_status(self, request_id: int) -> Optional[TTSQueue]:
    """Fetch the full queue row for ``request_id``.

    Returns:
        The matching ``TTSQueue`` row, or ``None`` when no row has that id.
    """
    by_id = self.db.query(TTSQueue).filter(TTSQueue.id == request_id)
    return by_id.first()

View file

@ -1,15 +1,32 @@
import uvicorn
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from .core.config import settings
from .routers import tts_router
from .database.database import init_db
from .services.tts import TTSModel
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup initialization for the FastAPI application.

    Before yielding control to the app, this creates the database tables,
    loads the TTS model and preloads the default "af" voicepack. There is
    no code after the ``yield``, so nothing runs at shutdown.
    """
    print("Initializing database and preloading models...")

    # Database tables first.
    init_db()

    # Then the model itself, followed by the (optional) default voice.
    TTSModel.get_instance()
    TTSModel.get_voicepack("af")

    print("Initialization complete!")
    yield
# Initialize FastAPI app
app = FastAPI(
title=settings.api_title,
description=settings.api_description,
version=settings.api_version,
lifespan=lifespan
)
# Add CORS middleware

View file

@ -1,6 +1,29 @@
from pydantic import BaseModel
from datetime import datetime
from pydantic import BaseModel, Field
from typing import Optional
from enum import Enum
class TTSStatus(str, Enum):
    """States a queued TTS request can be in.

    The enum also derives from ``str``, so each member compares equal to
    its plain string value (for example ``TTSStatus.PENDING == "pending"``).
    """

    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
class TTSQueueModel(BaseModel):
    """Pydantic representation of a TTS queue entry.

    ``from_attributes`` is enabled so the model can be populated from
    attribute-bearing objects instead of only from dicts.
    """

    # Pydantic v2 configuration. The nested ``class Config`` form is
    # deprecated in v2 and emits a deprecation warning; a plain dict is
    # accepted here, so no extra import is required.
    model_config = {"from_attributes": True}

    id: Optional[int] = None
    text: str
    voice: str = "af"
    speed: float = 1.0
    stitch_long_output: bool = True
    status: TTSStatus = TTSStatus.PENDING
    output_file: Optional[str] = None
    processing_time: Optional[float] = None
    # Filled in at model creation time when the caller does not supply it.
    created_at: datetime = Field(default_factory=datetime.utcnow)
class TTSRequest(BaseModel):

View file

@ -1,7 +1,9 @@
import os
from fastapi import APIRouter, HTTPException, Response
from fastapi import APIRouter, HTTPException, Response, Depends
from sqlalchemy.orm import Session
from ..models.schemas import TTSRequest, TTSResponse, VoicesResponse
from ..services.tts import TTSService
from ..database.database import get_db
router = APIRouter(
prefix="/tts",
@ -9,19 +11,20 @@ router = APIRouter(
responses={404: {"description": "Not found"}},
)
# Initialize TTS service
tts_service = TTSService()
def get_tts_service(db: Session = Depends(get_db)) -> TTSService:
    """FastAPI dependency that builds a ``TTSService`` around the injected session.

    Args:
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        A new ``TTSService`` constructed with that session.
    """
    # NOTE(review): TTSService.__init__ calls self._start_worker(), and a new
    # service is built each time this dependency is resolved. Confirm that
    # this does not start an additional background worker per request.
    service = TTSService(db)
    return service
@router.get("/voices", response_model=VoicesResponse)
async def get_voices():
async def get_voices(tts_service: TTSService = Depends(get_tts_service)):
"""List all available voices"""
voices = tts_service.list_voices()
return {"voices": voices, "default": "af"}
@router.post("", response_model=TTSResponse)
async def create_tts(request: TTSRequest):
async def create_tts(request: TTSRequest, tts_service: TTSService = Depends(get_tts_service)):
"""Submit text for TTS generation"""
# Validate voice exists
voices = tts_service.list_voices()
@ -47,37 +50,35 @@ async def create_tts(request: TTSRequest):
@router.get("/{request_id}", response_model=TTSResponse)
async def get_status(request_id: int):
async def get_status(request_id: int, tts_service: TTSService = Depends(get_tts_service)):
"""Check the status of a TTS request"""
status = tts_service.get_request_status(request_id)
if not status:
request = tts_service.get_request_status(request_id)
if not request:
raise HTTPException(status_code=404, detail="Request not found")
status_str, output_file, processing_time = status
return {
"request_id": request_id,
"status": status_str,
"output_file": output_file,
"processing_time": processing_time,
"request_id": request.id,
"status": request.status,
"output_file": request.output_file,
"processing_time": request.processing_time,
}
@router.get("/file/{request_id}")
async def get_file(request_id: int):
async def get_file(request_id: int, tts_service: TTSService = Depends(get_tts_service)):
"""Download the generated audio file"""
status = tts_service.get_request_status(request_id)
if not status:
request = tts_service.get_request_status(request_id)
if not request:
raise HTTPException(status_code=404, detail="Request not found")
status_str, output_file, _ = status
if status_str != "completed":
if request.status != "completed":
raise HTTPException(status_code=400, detail="Audio generation not complete")
if not output_file or not os.path.exists(output_file):
if not request.output_file or not os.path.exists(request.output_file):
raise HTTPException(status_code=404, detail="Audio file not found")
# Read file and ensure it's closed after
with open(output_file, "rb") as f:
with open(request.output_file, "rb") as f:
content = f.read()
return Response(

View file

@ -2,7 +2,10 @@ import os
import threading
import time
import io
from typing import Optional, Tuple, List
from typing import Optional, List, Tuple
from sqlalchemy.orm import Session, sessionmaker
from ..models.schemas import TTSStatus
from ..database.models import TTSQueue
import numpy as np
import torch
import scipy.io.wavfile as wavfile
@ -45,11 +48,12 @@ class TTSModel:
class TTSService:
def __init__(self, output_dir: str = None):
def __init__(self, db: Session, output_dir: str = None):
if output_dir is None:
output_dir = os.path.join(os.path.dirname(__file__), "..", "output")
self.output_dir = output_dir
self.db = QueueDB()
self.db = QueueDB(db)
self.engine = db.get_bind() # Get engine from session
os.makedirs(output_dir, exist_ok=True)
self._start_worker()
@ -146,31 +150,41 @@ class TTSService:
def _process_queue(self):
"""Background worker that processes the queue"""
# Create a new session for the background worker
Session = sessionmaker(bind=self.engine)
while True:
next_request = self.db.get_next_pending()
if next_request:
request_id, text, voice, speed, stitch_long_output = next_request
try:
# Generate audio and measure time
audio, processing_time = self._generate_audio(text, voice, speed, stitch_long_output)
# Create a new session for each iteration
with Session() as session:
db = QueueDB(session)
request = db.get_next_pending()
if request:
try:
# Generate audio and measure time
audio, processing_time = self._generate_audio(
request.text,
request.voice,
request.speed,
request.stitch_long_output
)
# Save to file
output_file = os.path.abspath(os.path.join(
self.output_dir, f"speech_{request_id}.wav"
))
self._save_audio(audio, output_file)
# Save to file
output_file = os.path.abspath(os.path.join(
self.output_dir, f"speech_{request.id}.wav"
))
self._save_audio(audio, output_file)
# Update status with processing time
self.db.update_status(
request_id,
"completed",
output_file=output_file,
processing_time=processing_time,
)
# Update status with processing time
db.update_status(
request.id,
TTSStatus.COMPLETED,
output_file=output_file,
processing_time=processing_time,
)
except Exception as e:
print(f"Error processing request {request_id}: {str(e)}")
self.db.update_status(request_id, "failed")
except Exception as e:
print(f"Error processing request {request.id}: {str(e)}")
db.update_status(request.id, TTSStatus.FAILED)
time.sleep(1) # Prevent busy waiting
@ -190,8 +204,6 @@ class TTSService:
"""Create a new TTS request and return the request ID"""
return self.db.add_request(text, voice, speed, stitch_long_output)
def get_request_status(
self, request_id: int
) -> Optional[Tuple[str, Optional[str], Optional[float]]]:
"""Get the status, output file path, and processing time for a request"""
def get_request_status(self, request_id: int) -> Optional[TTSQueue]:
"""Get the full request details"""
return self.db.get_status(request_id)

View file

@ -3,6 +3,7 @@ fastapi==0.115.6
uvicorn==0.34.0
pydantic==2.10.4
python-dotenv==1.0.1
sqlalchemy==2.0.27
# ML/DL
torch==2.5.1