From 60a19bde43c298567aebbe3d2066ce3973d9958e Mon Sep 17 00:00:00 2001 From: remsky Date: Mon, 30 Dec 2024 13:21:17 -0700 Subject: [PATCH] - SQLAlchemy integration for TTS queue management - Model pre-loading and database initialization in the FastAPI app lifespan. --- Dockerfile | 3 +- api/src/database/database.py | 26 ++++++ api/src/database/models.py | 19 ++++ api/src/database/queue.py | 171 +++++++---------------------------- api/src/main.py | 17 ++++ api/src/models/schemas.py | 27 +++++- api/src/routers/tts.py | 41 +++++---- api/src/services/tts.py | 68 ++++++++------ requirements.txt | 1 + 9 files changed, 186 insertions(+), 187 deletions(-) create mode 100644 api/src/database/database.py create mode 100644 api/src/database/models.py diff --git a/Dockerfile b/Dockerfile index 5a2cf93..c5fddde 100644 --- a/Dockerfile +++ b/Dockerfile @@ -27,7 +27,8 @@ RUN pip3 install --no-cache-dir \ uvicorn==0.34.0 \ pydantic==2.10.4 \ pydantic-settings==2.7.0 \ - python-dotenv==1.0.1 + python-dotenv==1.0.1 \ + sqlalchemy==2.0.27 # Set working directory WORKDIR /app diff --git a/api/src/database/database.py b/api/src/database/database.py new file mode 100644 index 0000000..093cbd0 --- /dev/null +++ b/api/src/database/database.py @@ -0,0 +1,26 @@ +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from pathlib import Path +import os + +from .models import Base + +DB_PATH = Path(__file__).parent.parent / "output" / "queue.db" +os.makedirs(os.path.dirname(DB_PATH), exist_ok=True) + +SQLALCHEMY_DATABASE_URL = f"sqlite:///{DB_PATH}" +engine = create_engine(SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}) + +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +def init_db(): + """Create tables if they don't exist""" + Base.metadata.create_all(bind=engine) + +def get_db(): + """Get database session""" + db = SessionLocal() + try: + yield db + finally: + db.close() diff --git a/api/src/database/models.py b/api/src/database/models.py new file mode 100644 index 0000000..3ed9167 --- /dev/null +++ b/api/src/database/models.py @@ -0,0 +1,19 @@ +from datetime import datetime +from sqlalchemy import Column, Integer, String, Float, Boolean, DateTime, Enum as SQLEnum +from sqlalchemy.ext.declarative import declarative_base +from ..models.schemas import TTSStatus + +Base = declarative_base() + +class TTSQueue(Base): + __tablename__ = "tts_queue" + + id = Column(Integer, primary_key=True, index=True) + text = Column(String, nullable=False) + voice = Column(String, default="af") + speed = Column(Float, default=1.0) + stitch_long_output = Column(Boolean, default=True) + status = Column(SQLEnum(TTSStatus), default=TTSStatus.PENDING) + output_file = Column(String, nullable=True) + processing_time = Column(Float, nullable=True) + created_at = Column(DateTime, default=datetime.utcnow) diff --git a/api/src/database/queue.py b/api/src/database/queue.py index 1e18e60..8e4aae5 100644 --- a/api/src/database/queue.py +++ b/api/src/database/queue.py @@ -1,153 +1,52 @@ -import sqlite3 -import os -from pathlib import Path -from typing import Optional, Tuple - -DB_PATH = Path(__file__).parent.parent / "output" / "queue.db" +from typing import Optional +from sqlalchemy.orm import Session +from .models import TTSQueue +from .database import init_db +from ..models.schemas import TTSStatus class QueueDB: - def __init__(self, db_path: str = str(DB_PATH)): - self.db_path = db_path - os.makedirs(os.path.dirname(db_path), exist_ok=True) - self._init_db() - - def _init_db(self): - """Initialize the database with required tables""" - conn = sqlite3.connect(self.db_path) - c = conn.cursor() - c.execute(""" - CREATE TABLE IF NOT EXISTS tts_queue - (id INTEGER PRIMARY KEY AUTOINCREMENT, - text TEXT NOT NULL, - voice TEXT DEFAULT 'af', - stitch_long_output BOOLEAN DEFAULT 1, - status TEXT DEFAULT 'pending', - output_file TEXT, - processing_time REAL, - speed REAL, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP) - """) - conn.commit() - conn.close() - - def _ensure_table_if_needed(self, conn: sqlite3.Connection): - """Create table if it doesn't exist, only called for write operations""" - c = conn.cursor() - c.execute(""" - CREATE TABLE IF NOT EXISTS tts_queue - (id INTEGER PRIMARY KEY AUTOINCREMENT, - text TEXT NOT NULL, - voice TEXT DEFAULT 'af', - stitch_long_output BOOLEAN DEFAULT 1, - status TEXT DEFAULT 'pending', - output_file TEXT, - processing_time REAL, - speed REAL, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP) - """) - conn.commit() + def __init__(self, db: Session): + self.db = db + init_db() # Ensure tables exist def add_request(self, text: str, voice: str, speed: float, stitch_long_output: bool = True) -> int: """Add a new TTS request to the queue""" - conn = sqlite3.connect(self.db_path) - try: - c = conn.cursor() - c.execute( - "INSERT INTO tts_queue (text, voice, speed, stitch_long_output) VALUES (?, ?, ?, ?)", - (text, voice, speed, stitch_long_output) - ) - request_id = c.lastrowid - conn.commit() - return request_id - except sqlite3.OperationalError: # Table doesn't exist - self._ensure_table_if_needed(conn) - c = conn.cursor() - c.execute( - "INSERT INTO tts_queue (text, voice, speed, stitch_long_output) VALUES (?, ?, ?, ?)", - (text, voice, speed, stitch_long_output) - ) - request_id = c.lastrowid - conn.commit() - return request_id - finally: - conn.close() + db_item = TTSQueue( + text=text, + voice=voice, + speed=speed, + stitch_long_output=stitch_long_output + ) + self.db.add(db_item) + self.db.commit() + self.db.refresh(db_item) + return db_item.id - def get_next_pending(self) -> Optional[Tuple[int, str, float, str]]: + def get_next_pending(self) -> Optional[TTSQueue]: """Get the next pending request""" - conn = sqlite3.connect(self.db_path) - try: - c = conn.cursor() - c.execute( - 'SELECT id, text, voice, speed, stitch_long_output FROM tts_queue WHERE status = "pending" ORDER BY created_at ASC LIMIT 1' - ) - return c.fetchone() - except sqlite3.OperationalError: # Table doesn't exist - return None - finally: - conn.close() + return self.db.query(TTSQueue)\ + .filter(TTSQueue.status == TTSStatus.PENDING)\ + .order_by(TTSQueue.created_at)\ + .first() def update_status( self, request_id: int, - status: str, + status: TTSStatus, output_file: Optional[str] = None, processing_time: Optional[float] = None, ): """Update request status, output file, and processing time""" - conn = sqlite3.connect(self.db_path) - try: - c = conn.cursor() - if output_file and processing_time is not None: - c.execute( - "UPDATE tts_queue SET status = ?, output_file = ?, processing_time = ? WHERE id = ?", - (status, output_file, processing_time, request_id), - ) - elif output_file: - c.execute( - "UPDATE tts_queue SET status = ?, output_file = ? WHERE id = ?", - (status, output_file, request_id), - ) - else: - c.execute( - "UPDATE tts_queue SET status = ? WHERE id = ?", (status, request_id) - ) - conn.commit() - except sqlite3.OperationalError: # Table doesn't exist - self._ensure_table_if_needed(conn) - # Retry the update - c = conn.cursor() - if output_file and processing_time is not None: - c.execute( - "UPDATE tts_queue SET status = ?, output_file = ?, processing_time = ? WHERE id = ?", - (status, output_file, processing_time, request_id), - ) - elif output_file: - c.execute( - "UPDATE tts_queue SET status = ?, output_file = ? WHERE id = ?", - (status, output_file, request_id), - ) - else: - c.execute( - "UPDATE tts_queue SET status = ? WHERE id = ?", (status, request_id) - ) - conn.commit() - finally: - conn.close() + request = self.db.query(TTSQueue).filter(TTSQueue.id == request_id).first() + if request: + request.status = status + if output_file: + request.output_file = output_file + if processing_time is not None: + request.processing_time = processing_time + self.db.commit() - def get_status( - self, request_id: int - ) -> Optional[Tuple[str, Optional[str], Optional[float]]]: - """Get status, output file, and processing time for a request""" - conn = sqlite3.connect(self.db_path) - try: - c = conn.cursor() - c.execute( - "SELECT status, output_file, processing_time FROM tts_queue WHERE id = ?", - (request_id,), - ) - return c.fetchone() - except sqlite3.OperationalError: # Table doesn't exist - return None - finally: - conn.close() + def get_status(self, request_id: int) -> Optional[TTSQueue]: + """Get full request details by ID""" + return self.db.query(TTSQueue).filter(TTSQueue.id == request_id).first() diff --git a/api/src/main.py b/api/src/main.py index 10b75ef..a9bfc01 100644 --- a/api/src/main.py +++ b/api/src/main.py @@ -1,15 +1,32 @@ import uvicorn +from contextlib import asynccontextmanager from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from .core.config import settings from .routers import tts_router +from .database.database import init_db +from .services.tts import TTSModel + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Lifespan context manager for database and model initialization""" + print("Initializing database and preloading models...") + init_db() # Initialize database tables + + # Preload TTS model and default voice + TTSModel.get_instance() # This loads the model + TTSModel.get_voicepack("af") # Preload default voice, optional + print("Initialization complete!") + + yield # Initialize FastAPI app app = FastAPI( title=settings.api_title, description=settings.api_description, version=settings.api_version, + lifespan=lifespan ) # Add CORS middleware diff --git a/api/src/models/schemas.py b/api/src/models/schemas.py index 4c02f45..3458d14 100644 --- a/api/src/models/schemas.py +++ b/api/src/models/schemas.py @@ -1,6 +1,29 @@ -from pydantic import BaseModel - +from datetime import datetime +from pydantic import BaseModel, Field from typing import Optional +from enum import Enum + + +class TTSStatus(str, Enum): + PENDING = "pending" + PROCESSING = "processing" + COMPLETED = "completed" + FAILED = "failed" + + +class TTSQueueModel(BaseModel): + id: Optional[int] = None + text: str + voice: str = "af" + speed: float = 1.0 + stitch_long_output: bool = True + status: TTSStatus = TTSStatus.PENDING + output_file: Optional[str] = None + processing_time: Optional[float] = None + created_at: datetime = Field(default_factory=datetime.utcnow) + + class Config: + from_attributes = True class TTSRequest(BaseModel): diff --git a/api/src/routers/tts.py b/api/src/routers/tts.py index 7bfa023..3113439 100644 --- a/api/src/routers/tts.py +++ b/api/src/routers/tts.py @@ -1,7 +1,9 @@ import os -from fastapi import APIRouter, HTTPException, Response +from fastapi import APIRouter, HTTPException, Response, Depends +from sqlalchemy.orm import Session from ..models.schemas import TTSRequest, TTSResponse, VoicesResponse from ..services.tts import TTSService +from ..database.database import get_db router = APIRouter( prefix="/tts", @@ -9,19 +11,20 @@ router = APIRouter( responses={404: {"description": "Not found"}}, ) -# Initialize TTS service -tts_service = TTSService() +def get_tts_service(db: Session = Depends(get_db)) -> TTSService: + """Dependency to get TTSService instance with database session""" + return TTSService(db) @router.get("/voices", response_model=VoicesResponse) -async def get_voices(): +async def get_voices(tts_service: TTSService = Depends(get_tts_service)): """List all available voices""" voices = tts_service.list_voices() return {"voices": voices, "default": "af"} @router.post("", response_model=TTSResponse) -async def create_tts(request: TTSRequest): +async def create_tts(request: TTSRequest, tts_service: TTSService = Depends(get_tts_service)): """Submit text for TTS generation""" # Validate voice exists voices = tts_service.list_voices() @@ -47,37 +50,35 @@ async def create_tts(request: TTSRequest): @router.get("/{request_id}", response_model=TTSResponse) -async def get_status(request_id: int): +async def get_status(request_id: int, tts_service: TTSService = Depends(get_tts_service)): """Check the status of a TTS request""" - status = tts_service.get_request_status(request_id) - if not status: + request = tts_service.get_request_status(request_id) + if not request: raise HTTPException(status_code=404, detail="Request not found") - status_str, output_file, processing_time = status return { - "request_id": request_id, - "status": status_str, - "output_file": output_file, - "processing_time": processing_time, + "request_id": request.id, + "status": request.status, + "output_file": request.output_file, + "processing_time": request.processing_time, } @router.get("/file/{request_id}") -async def get_file(request_id: int): +async def get_file(request_id: int, tts_service: TTSService = Depends(get_tts_service)): """Download the generated audio file""" - status = tts_service.get_request_status(request_id) - if not status: + request = tts_service.get_request_status(request_id) + if not request: raise HTTPException(status_code=404, detail="Request not found") - status_str, output_file, _ = status - if status_str != "completed": + if request.status != "completed": raise HTTPException(status_code=400, detail="Audio generation not complete") - if not output_file or not os.path.exists(output_file): + if not request.output_file or not os.path.exists(request.output_file): raise HTTPException(status_code=404, detail="Audio file not found") # Read file and ensure it's closed after - with open(output_file, "rb") as f: + with open(request.output_file, "rb") as f: content = f.read() return Response( diff --git a/api/src/services/tts.py b/api/src/services/tts.py index dcf1c6b..02aa4e6 100644 --- a/api/src/services/tts.py +++ b/api/src/services/tts.py @@ -2,7 +2,10 @@ import os import threading import time import io -from typing import Optional, Tuple, List +from typing import Optional, List, Tuple +from sqlalchemy.orm import Session, sessionmaker +from ..models.schemas import TTSStatus +from ..database.models import TTSQueue import numpy as np import torch import scipy.io.wavfile as wavfile @@ -45,11 +48,12 @@ class TTSModel: class TTSService: - def __init__(self, output_dir: str = None): + def __init__(self, db: Session, output_dir: str = None): if output_dir is None: output_dir = os.path.join(os.path.dirname(__file__), "..", "output") self.output_dir = output_dir - self.db = QueueDB() + self.db = QueueDB(db) + self.engine = db.get_bind() # Get engine from session os.makedirs(output_dir, exist_ok=True) self._start_worker() @@ -146,31 +150,41 @@ class TTSService: def _process_queue(self): """Background worker that processes the queue""" + # Create a new session for the background worker + Session = sessionmaker(bind=self.engine) + while True: - next_request = self.db.get_next_pending() - if next_request: - request_id, text, voice, speed, stitch_long_output = next_request - try: - # Generate audio and measure time - audio, processing_time = self._generate_audio(text, voice, speed, stitch_long_output) + # Create a new session for each iteration + with Session() as session: + db = QueueDB(session) + request = db.get_next_pending() + if request: + try: + # Generate audio and measure time + audio, processing_time = self._generate_audio( + request.text, + request.voice, + request.speed, + request.stitch_long_output + ) - # Save to file - output_file = os.path.abspath(os.path.join( - self.output_dir, f"speech_{request_id}.wav" - )) - self._save_audio(audio, output_file) + # Save to file + output_file = os.path.abspath(os.path.join( + self.output_dir, f"speech_{request.id}.wav" + )) + self._save_audio(audio, output_file) - # Update status with processing time - self.db.update_status( - request_id, - "completed", - output_file=output_file, - processing_time=processing_time, - ) + # Update status with processing time + db.update_status( + request.id, + TTSStatus.COMPLETED, + output_file=output_file, + processing_time=processing_time, + ) - except Exception as e: - print(f"Error processing request {request_id}: {str(e)}") - self.db.update_status(request_id, "failed") + except Exception as e: + print(f"Error processing request {request.id}: {str(e)}") + db.update_status(request.id, TTSStatus.FAILED) time.sleep(1) # Prevent busy waiting @@ -190,8 +204,6 @@ class TTSService: """Create a new TTS request and return the request ID""" return self.db.add_request(text, voice, speed, stitch_long_output) - def get_request_status( - self, request_id: int - ) -> Optional[Tuple[str, Optional[str], Optional[float]]]: - """Get the status, output file path, and processing time for a request""" + def get_request_status(self, request_id: int) -> Optional[TTSQueue]: + """Get the full request details""" return self.db.get_status(request_id) diff --git a/requirements.txt b/requirements.txt index 57f3b49..a5a314d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ fastapi==0.115.6 uvicorn==0.34.0 pydantic==2.10.4 python-dotenv==1.0.1 +sqlalchemy==2.0.27 # ML/DL torch==2.5.1