- SQLAlchemy integration for TTS queue management

- Model pre-loading and database initialization in the FastAPI app lifespan.
This commit is contained in:
remsky 2024-12-30 13:21:17 -07:00
parent 5afb9e9be8
commit 60a19bde43
9 changed files with 186 additions and 187 deletions

View file

@ -27,7 +27,8 @@ RUN pip3 install --no-cache-dir \
uvicorn==0.34.0 \ uvicorn==0.34.0 \
pydantic==2.10.4 \ pydantic==2.10.4 \
pydantic-settings==2.7.0 \ pydantic-settings==2.7.0 \
python-dotenv==1.0.1 python-dotenv==1.0.1 \
sqlalchemy==2.0.27
# Set working directory # Set working directory
WORKDIR /app WORKDIR /app

View file

@ -0,0 +1,26 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from pathlib import Path
import os
from .models import Base
DB_PATH = Path(__file__).parent.parent / "output" / "queue.db"
os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
SQLALCHEMY_DATABASE_URL = f"sqlite:///{DB_PATH}"
engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False})
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def init_db():
"""Create tables if they don't exist"""
Base.metadata.create_all(bind=engine)
def get_db():
"""Get database session"""
db = SessionLocal()
try:
yield db
finally:
db.close()

View file

@ -0,0 +1,19 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, Float, Boolean, DateTime, Enum as SQLEnum
from sqlalchemy.ext.declarative import declarative_base
from ..models.schemas import TTSStatus
Base = declarative_base()
class TTSQueue(Base):
__tablename__ = "tts_queue"
id = Column(Integer, primary_key=True, index=True)
text = Column(String, nullable=False)
voice = Column(String, default="af")
speed = Column(Float, default=1.0)
stitch_long_output = Column(Boolean, default=True)
status = Column(SQLEnum(TTSStatus), default=TTSStatus.PENDING)
output_file = Column(String, nullable=True)
processing_time = Column(Float, nullable=True)
created_at = Column(DateTime, default=datetime.utcnow)

View file

@ -1,153 +1,52 @@
import sqlite3 from typing import Optional
import os from sqlalchemy.orm import Session
from pathlib import Path from .models import TTSQueue
from typing import Optional, Tuple from .database import init_db
from ..models.schemas import TTSStatus
DB_PATH = Path(__file__).parent.parent / "output" / "queue.db"
class QueueDB: class QueueDB:
def __init__(self, db_path: str = str(DB_PATH)): def __init__(self, db: Session):
self.db_path = db_path self.db = db
os.makedirs(os.path.dirname(db_path), exist_ok=True) init_db() # Ensure tables exist
self._init_db()
def _init_db(self):
"""Initialize the database with required tables"""
conn = sqlite3.connect(self.db_path)
c = conn.cursor()
c.execute("""
CREATE TABLE IF NOT EXISTS tts_queue
(id INTEGER PRIMARY KEY AUTOINCREMENT,
text TEXT NOT NULL,
voice TEXT DEFAULT 'af',
stitch_long_output BOOLEAN DEFAULT 1,
status TEXT DEFAULT 'pending',
output_file TEXT,
processing_time REAL,
speed REAL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)
""")
conn.commit()
conn.close()
def _ensure_table_if_needed(self, conn: sqlite3.Connection):
"""Create table if it doesn't exist, only called for write operations"""
c = conn.cursor()
c.execute("""
CREATE TABLE IF NOT EXISTS tts_queue
(id INTEGER PRIMARY KEY AUTOINCREMENT,
text TEXT NOT NULL,
voice TEXT DEFAULT 'af',
stitch_long_output BOOLEAN DEFAULT 1,
status TEXT DEFAULT 'pending',
output_file TEXT,
processing_time REAL,
speed REAL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)
""")
conn.commit()
def add_request(self, text: str, voice: str, speed: float, stitch_long_output: bool = True) -> int: def add_request(self, text: str, voice: str, speed: float, stitch_long_output: bool = True) -> int:
"""Add a new TTS request to the queue""" """Add a new TTS request to the queue"""
conn = sqlite3.connect(self.db_path) db_item = TTSQueue(
try: text=text,
c = conn.cursor() voice=voice,
c.execute( speed=speed,
"INSERT INTO tts_queue (text, voice, speed, stitch_long_output) VALUES (?, ?, ?, ?)", stitch_long_output=stitch_long_output
(text, voice, speed, stitch_long_output)
) )
request_id = c.lastrowid self.db.add(db_item)
conn.commit() self.db.commit()
return request_id self.db.refresh(db_item)
except sqlite3.OperationalError: # Table doesn't exist return db_item.id
self._ensure_table_if_needed(conn)
c = conn.cursor()
c.execute(
"INSERT INTO tts_queue (text, voice, speed, stitch_long_output) VALUES (?, ?, ?, ?)",
(text, voice, speed, stitch_long_output)
)
request_id = c.lastrowid
conn.commit()
return request_id
finally:
conn.close()
def get_next_pending(self) -> Optional[Tuple[int, str, float, str]]: def get_next_pending(self) -> Optional[TTSQueue]:
"""Get the next pending request""" """Get the next pending request"""
conn = sqlite3.connect(self.db_path) return self.db.query(TTSQueue)\
try: .filter(TTSQueue.status == TTSStatus.PENDING)\
c = conn.cursor() .order_by(TTSQueue.created_at)\
c.execute( .first()
'SELECT id, text, voice, speed, stitch_long_output FROM tts_queue WHERE status = "pending" ORDER BY created_at ASC LIMIT 1'
)
return c.fetchone()
except sqlite3.OperationalError: # Table doesn't exist
return None
finally:
conn.close()
def update_status( def update_status(
self, self,
request_id: int, request_id: int,
status: str, status: TTSStatus,
output_file: Optional[str] = None, output_file: Optional[str] = None,
processing_time: Optional[float] = None, processing_time: Optional[float] = None,
): ):
"""Update request status, output file, and processing time""" """Update request status, output file, and processing time"""
conn = sqlite3.connect(self.db_path) request = self.db.query(TTSQueue).filter(TTSQueue.id == request_id).first()
try: if request:
c = conn.cursor() request.status = status
if output_file and processing_time is not None: if output_file:
c.execute( request.output_file = output_file
"UPDATE tts_queue SET status = ?, output_file = ?, processing_time = ? WHERE id = ?", if processing_time is not None:
(status, output_file, processing_time, request_id), request.processing_time = processing_time
) self.db.commit()
elif output_file:
c.execute(
"UPDATE tts_queue SET status = ?, output_file = ? WHERE id = ?",
(status, output_file, request_id),
)
else:
c.execute(
"UPDATE tts_queue SET status = ? WHERE id = ?", (status, request_id)
)
conn.commit()
except sqlite3.OperationalError: # Table doesn't exist
self._ensure_table_if_needed(conn)
# Retry the update
c = conn.cursor()
if output_file and processing_time is not None:
c.execute(
"UPDATE tts_queue SET status = ?, output_file = ?, processing_time = ? WHERE id = ?",
(status, output_file, processing_time, request_id),
)
elif output_file:
c.execute(
"UPDATE tts_queue SET status = ?, output_file = ? WHERE id = ?",
(status, output_file, request_id),
)
else:
c.execute(
"UPDATE tts_queue SET status = ? WHERE id = ?", (status, request_id)
)
conn.commit()
finally:
conn.close()
def get_status( def get_status(self, request_id: int) -> Optional[TTSQueue]:
self, request_id: int """Get full request details by ID"""
) -> Optional[Tuple[str, Optional[str], Optional[float]]]: return self.db.query(TTSQueue).filter(TTSQueue.id == request_id).first()
"""Get status, output file, and processing time for a request"""
conn = sqlite3.connect(self.db_path)
try:
c = conn.cursor()
c.execute(
"SELECT status, output_file, processing_time FROM tts_queue WHERE id = ?",
(request_id,),
)
return c.fetchone()
except sqlite3.OperationalError: # Table doesn't exist
return None
finally:
conn.close()

View file

@ -1,15 +1,32 @@
import uvicorn import uvicorn
from contextlib import asynccontextmanager
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from .core.config import settings from .core.config import settings
from .routers import tts_router from .routers import tts_router
from .database.database import init_db
from .services.tts import TTSModel
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Lifespan context manager for database and model initialization"""
print("Initializing database and preloading models...")
init_db() # Initialize database tables
# Preload TTS model and default voice
TTSModel.get_instance() # This loads the model
TTSModel.get_voicepack("af") # Preload default voice, optional
print("Initialization complete!")
yield
# Initialize FastAPI app # Initialize FastAPI app
app = FastAPI( app = FastAPI(
title=settings.api_title, title=settings.api_title,
description=settings.api_description, description=settings.api_description,
version=settings.api_version, version=settings.api_version,
lifespan=lifespan
) )
# Add CORS middleware # Add CORS middleware

View file

@ -1,6 +1,29 @@
from pydantic import BaseModel from datetime import datetime
from pydantic import BaseModel, Field
from typing import Optional from typing import Optional
from enum import Enum
class TTSStatus(str, Enum):
PENDING = "pending"
PROCESSING = "processing"
COMPLETED = "completed"
FAILED = "failed"
class TTSQueueModel(BaseModel):
id: Optional[int] = None
text: str
voice: str = "af"
speed: float = 1.0
stitch_long_output: bool = True
status: TTSStatus = TTSStatus.PENDING
output_file: Optional[str] = None
processing_time: Optional[float] = None
created_at: datetime = Field(default_factory=datetime.utcnow)
class Config:
from_attributes = True
class TTSRequest(BaseModel): class TTSRequest(BaseModel):

View file

@ -1,7 +1,9 @@
import os import os
from fastapi import APIRouter, HTTPException, Response from fastapi import APIRouter, HTTPException, Response, Depends
from sqlalchemy.orm import Session
from ..models.schemas import TTSRequest, TTSResponse, VoicesResponse from ..models.schemas import TTSRequest, TTSResponse, VoicesResponse
from ..services.tts import TTSService from ..services.tts import TTSService
from ..database.database import get_db
router = APIRouter( router = APIRouter(
prefix="/tts", prefix="/tts",
@ -9,19 +11,20 @@ router = APIRouter(
responses={404: {"description": "Not found"}}, responses={404: {"description": "Not found"}},
) )
# Initialize TTS service def get_tts_service(db: Session = Depends(get_db)) -> TTSService:
tts_service = TTSService() """Dependency to get TTSService instance with database session"""
return TTSService(db)
@router.get("/voices", response_model=VoicesResponse) @router.get("/voices", response_model=VoicesResponse)
async def get_voices(): async def get_voices(tts_service: TTSService = Depends(get_tts_service)):
"""List all available voices""" """List all available voices"""
voices = tts_service.list_voices() voices = tts_service.list_voices()
return {"voices": voices, "default": "af"} return {"voices": voices, "default": "af"}
@router.post("", response_model=TTSResponse) @router.post("", response_model=TTSResponse)
async def create_tts(request: TTSRequest): async def create_tts(request: TTSRequest, tts_service: TTSService = Depends(get_tts_service)):
"""Submit text for TTS generation""" """Submit text for TTS generation"""
# Validate voice exists # Validate voice exists
voices = tts_service.list_voices() voices = tts_service.list_voices()
@ -47,37 +50,35 @@ async def create_tts(request: TTSRequest):
@router.get("/{request_id}", response_model=TTSResponse) @router.get("/{request_id}", response_model=TTSResponse)
async def get_status(request_id: int): async def get_status(request_id: int, tts_service: TTSService = Depends(get_tts_service)):
"""Check the status of a TTS request""" """Check the status of a TTS request"""
status = tts_service.get_request_status(request_id) request = tts_service.get_request_status(request_id)
if not status: if not request:
raise HTTPException(status_code=404, detail="Request not found") raise HTTPException(status_code=404, detail="Request not found")
status_str, output_file, processing_time = status
return { return {
"request_id": request_id, "request_id": request.id,
"status": status_str, "status": request.status,
"output_file": output_file, "output_file": request.output_file,
"processing_time": processing_time, "processing_time": request.processing_time,
} }
@router.get("/file/{request_id}") @router.get("/file/{request_id}")
async def get_file(request_id: int): async def get_file(request_id: int, tts_service: TTSService = Depends(get_tts_service)):
"""Download the generated audio file""" """Download the generated audio file"""
status = tts_service.get_request_status(request_id) request = tts_service.get_request_status(request_id)
if not status: if not request:
raise HTTPException(status_code=404, detail="Request not found") raise HTTPException(status_code=404, detail="Request not found")
status_str, output_file, _ = status if request.status != "completed":
if status_str != "completed":
raise HTTPException(status_code=400, detail="Audio generation not complete") raise HTTPException(status_code=400, detail="Audio generation not complete")
if not output_file or not os.path.exists(output_file): if not request.output_file or not os.path.exists(request.output_file):
raise HTTPException(status_code=404, detail="Audio file not found") raise HTTPException(status_code=404, detail="Audio file not found")
# Read file and ensure it's closed after # Read file and ensure it's closed after
with open(output_file, "rb") as f: with open(request.output_file, "rb") as f:
content = f.read() content = f.read()
return Response( return Response(

View file

@ -2,7 +2,10 @@ import os
import threading import threading
import time import time
import io import io
from typing import Optional, Tuple, List from typing import Optional, List, Tuple
from sqlalchemy.orm import Session, sessionmaker
from ..models.schemas import TTSStatus
from ..database.models import TTSQueue
import numpy as np import numpy as np
import torch import torch
import scipy.io.wavfile as wavfile import scipy.io.wavfile as wavfile
@ -45,11 +48,12 @@ class TTSModel:
class TTSService: class TTSService:
def __init__(self, output_dir: str = None): def __init__(self, db: Session, output_dir: str = None):
if output_dir is None: if output_dir is None:
output_dir = os.path.join(os.path.dirname(__file__), "..", "output") output_dir = os.path.join(os.path.dirname(__file__), "..", "output")
self.output_dir = output_dir self.output_dir = output_dir
self.db = QueueDB() self.db = QueueDB(db)
self.engine = db.get_bind() # Get engine from session
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
self._start_worker() self._start_worker()
@ -146,31 +150,41 @@ class TTSService:
def _process_queue(self): def _process_queue(self):
"""Background worker that processes the queue""" """Background worker that processes the queue"""
# Create a new session for the background worker
Session = sessionmaker(bind=self.engine)
while True: while True:
next_request = self.db.get_next_pending() # Create a new session for each iteration
if next_request: with Session() as session:
request_id, text, voice, speed, stitch_long_output = next_request db = QueueDB(session)
request = db.get_next_pending()
if request:
try: try:
# Generate audio and measure time # Generate audio and measure time
audio, processing_time = self._generate_audio(text, voice, speed, stitch_long_output) audio, processing_time = self._generate_audio(
request.text,
request.voice,
request.speed,
request.stitch_long_output
)
# Save to file # Save to file
output_file = os.path.abspath(os.path.join( output_file = os.path.abspath(os.path.join(
self.output_dir, f"speech_{request_id}.wav" self.output_dir, f"speech_{request.id}.wav"
)) ))
self._save_audio(audio, output_file) self._save_audio(audio, output_file)
# Update status with processing time # Update status with processing time
self.db.update_status( db.update_status(
request_id, request.id,
"completed", TTSStatus.COMPLETED,
output_file=output_file, output_file=output_file,
processing_time=processing_time, processing_time=processing_time,
) )
except Exception as e: except Exception as e:
print(f"Error processing request {request_id}: {str(e)}") print(f"Error processing request {request.id}: {str(e)}")
self.db.update_status(request_id, "failed") db.update_status(request.id, TTSStatus.FAILED)
time.sleep(1) # Prevent busy waiting time.sleep(1) # Prevent busy waiting
@ -190,8 +204,6 @@ class TTSService:
"""Create a new TTS request and return the request ID""" """Create a new TTS request and return the request ID"""
return self.db.add_request(text, voice, speed, stitch_long_output) return self.db.add_request(text, voice, speed, stitch_long_output)
def get_request_status( def get_request_status(self, request_id: int) -> Optional[TTSQueue]:
self, request_id: int """Get the full request details"""
) -> Optional[Tuple[str, Optional[str], Optional[float]]]:
"""Get the status, output file path, and processing time for a request"""
return self.db.get_status(request_id) return self.db.get_status(request_id)

View file

@ -3,6 +3,7 @@ fastapi==0.115.6
uvicorn==0.34.0 uvicorn==0.34.0
pydantic==2.10.4 pydantic==2.10.4
python-dotenv==1.0.1 python-dotenv==1.0.1
sqlalchemy==2.0.27
# ML/DL # ML/DL
torch==2.5.1 torch==2.5.1