Kokoro-FastAPI/api/src/services/tts.py
remsky 60a19bde43 - SQLAlchemy integration for TTS queue management
- Model pre-loading and database initialization in the FastAPI app lifespan.
2024-12-30 13:21:17 -07:00

209 lines
8 KiB
Python

import os
import threading
import time
import io
from typing import Optional, List, Tuple
from sqlalchemy.orm import Session, sessionmaker
from ..models.schemas import TTSStatus
from ..database.models import TTSQueue
import numpy as np
import torch
import scipy.io.wavfile as wavfile
from models import build_model
from kokoro import generate, phonemize, tokenize
from ..database.queue import QueueDB
class TTSModel:
    """Process-wide singleton holding the Kokoro model and cached voicepacks.

    The model is built lazily on first use (double-checked locking so
    concurrent callers build it exactly once); voicepacks are loaded from
    ``voices/<name>.pt`` at most once each and cached by name.
    """

    _instance = None          # (model, device) tuple once initialized
    _lock = threading.Lock()  # guards first-time model construction
    _voicepacks = {}          # voice name -> loaded voicepack tensor

    @classmethod
    def get_instance(cls):
        """Return the shared ``(model, device)`` pair, building it on first call."""
        if cls._instance is not None:
            return cls._instance
        with cls._lock:
            # Re-check under the lock: another thread may have won the race.
            if cls._instance is None:
                device = "cuda" if torch.cuda.is_available() else "cpu"
                print(f"Initializing model on {device}")
                model = build_model("kokoro-v0_19.pth", device)
                cls._instance = (model, device)
        return cls._instance

    @classmethod
    def get_voicepack(cls, voice_name: str) -> torch.Tensor:
        """Return the voicepack for *voice_name*, loading and caching it if needed.

        On a load failure, falls back to the default "af" voice; re-raises
        only when the default itself cannot be loaded.
        """
        _, device = cls.get_instance()
        if voice_name in cls._voicepacks:
            return cls._voicepacks[voice_name]
        try:
            voicepack = torch.load(
                f"voices/{voice_name}.pt", map_location=device, weights_only=True
            )
        except Exception as e:
            print(f"Error loading voice {voice_name}: {str(e)}")
            if voice_name != "af":
                return cls.get_voicepack("af")
            raise
        cls._voicepacks[voice_name] = voicepack
        return voicepack
class TTSService:
    """Queue-backed text-to-speech service.

    Requests are persisted through :class:`QueueDB`; a daemon worker thread
    drains the queue, synthesizes audio with the shared :class:`TTSModel`,
    writes 24 kHz WAV files into ``output_dir`` and records status back in
    the database.
    """

    def __init__(self, db: Session, output_dir: str = None):
        """Create the service and start its background worker.

        Args:
            db: SQLAlchemy session used for the request-facing ``QueueDB``.
            output_dir: Directory for generated WAV files. Defaults to
                ``../output`` relative to this module.
        """
        if output_dir is None:
            output_dir = os.path.join(os.path.dirname(__file__), "..", "output")
        self.output_dir = output_dir
        self.db = QueueDB(db)
        # The worker thread must not share the request-scoped session, so
        # keep the engine around to mint fresh sessions per loop iteration.
        self.engine = db.get_bind()  # Get engine from session
        os.makedirs(output_dir, exist_ok=True)
        self._start_worker()

    def _start_worker(self):
        """Start the daemon thread that processes the TTS queue."""
        self.worker = threading.Thread(target=self._process_queue, daemon=True)
        self.worker.start()

    def _find_boundary(self, text: str, max_tokens: int, voice: str, margin: int = 50) -> int:
        """Find the closest sentence/clause boundary within the token limit.

        Tries sentence, clause, and comma boundaries in order of preference,
        verifying each candidate against the tokenizer; falls back to the
        last whitespace (or a hard cut at ``max_tokens``) when none fits.

        Returns:
            Character index at which ``text`` should be split.
        """
        # Try different boundary markers in order of preference.
        for marker in ['. ', '; ', ', ']:
            # Look a bit beyond the limit so a boundary slightly past
            # max_tokens characters is still considered.
            test_text = text[:max_tokens + margin]
            last_idx = test_text.rfind(marker)
            if last_idx != -1:
                # Verify this boundary is actually within the token limit.
                candidate = text[:last_idx + len(marker)].strip()
                ps = phonemize(candidate, voice[0])
                tokens = tokenize(ps)
                if len(tokens) <= max_tokens:
                    return last_idx + len(marker)
        # No acceptable punctuation boundary: break at the last space,
        # or hard-cut at max_tokens as a last resort.
        last_space = text[:max_tokens].rfind(' ')
        return last_space if last_space != -1 else max_tokens

    def _split_text(self, text: str, voice: str) -> List[str]:
        """Split ``text`` into chunks that respect the model's token limit.

        Tries to preserve sentence/clause structure via ``_find_boundary``.
        Returns an empty list for empty/whitespace-only input.
        """
        # Wide margin under the 510-token hard limit to absorb tokenizer
        # differences between the probe here and generation later.
        MAX_TOKENS = 450
        chunks = []
        remaining = text
        while remaining:
            # If the remaining text fits, it becomes the final chunk.
            ps = phonemize(remaining, voice[0])
            tokens = tokenize(ps)
            if len(tokens) <= MAX_TOKENS:
                chunks.append(remaining.strip())
                break
            # Otherwise split at the best boundary and continue.
            split_pos = self._find_boundary(remaining, MAX_TOKENS, voice)
            chunks.append(remaining[:split_pos].strip())
            remaining = remaining[split_pos:].strip()
        return chunks

    def _generate_audio(self, text: str, voice: str, speed: float, stitch_long_output: bool = True) -> Tuple[torch.Tensor, float]:
        """Synthesize ``text`` and return ``(audio, processing_time_seconds)``.

        Args:
            text: Input text to speak.
            voice: Voice name; its first character selects the language.
            speed: Playback speed multiplier passed to the generator.
            stitch_long_output: When True, split long text into token-safe
                chunks and concatenate the resulting audio.

        Raises:
            ValueError: If splitting produces no chunks (empty input text).
        """
        start_time = time.time()
        model, device = TTSModel.get_instance()
        voicepack = TTSModel.get_voicepack(voice)
        if stitch_long_output:
            chunks = self._split_text(text, voice)
            if not chunks:
                # Empty/whitespace-only text would otherwise crash below
                # with an opaque IndexError on audio_chunks[0].
                raise ValueError("No audio chunks generated from input text")
            audio_chunks = []
            for chunk in chunks:
                chunk_audio, _ = generate(model, chunk, voicepack, lang=voice[0], speed=speed)
                audio_chunks.append(chunk_audio)
            # Concatenate only when there is more than one chunk.
            audio = audio_chunks[0] if len(audio_chunks) == 1 else np.concatenate(audio_chunks)
        else:
            # BUGFIX: this branch previously dropped the caller's `speed`
            # argument; pass it through exactly like the stitched path.
            audio, _ = generate(model, text, voicepack, lang=voice[0], speed=speed)
        processing_time = time.time() - start_time
        return audio, processing_time

    def _save_audio(self, audio: torch.Tensor, filepath: str):
        """Write ``audio`` to ``filepath`` as a 24 kHz WAV, creating dirs as needed."""
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        wavfile.write(filepath, 24000, audio)

    def _audio_to_bytes(self, audio: torch.Tensor) -> bytes:
        """Return ``audio`` encoded as in-memory 24 kHz WAV bytes."""
        buffer = io.BytesIO()
        wavfile.write(buffer, 24000, audio)
        return buffer.getvalue()

    def _process_queue(self):
        """Background worker loop: synthesize pending requests one at a time.

        Uses a fresh SQLAlchemy session per iteration so the long-lived
        worker never holds a stale session. Failures are recorded as
        FAILED on the request and never kill the loop.
        """
        # Session factory bound to the same engine as the request session.
        Session = sessionmaker(bind=self.engine)
        while True:
            with Session() as session:
                db = QueueDB(session)
                request = db.get_next_pending()
                if request:
                    try:
                        # Generate audio and measure processing time.
                        audio, processing_time = self._generate_audio(
                            request.text,
                            request.voice,
                            request.speed,
                            request.stitch_long_output
                        )
                        # Persist the WAV next to other outputs.
                        output_file = os.path.abspath(os.path.join(
                            self.output_dir, f"speech_{request.id}.wav"
                        ))
                        self._save_audio(audio, output_file)
                        db.update_status(
                            request.id,
                            TTSStatus.COMPLETED,
                            output_file=output_file,
                            processing_time=processing_time,
                        )
                    except Exception as e:
                        # Record the failure but keep the worker alive.
                        print(f"Error processing request {request.id}: {str(e)}")
                        db.update_status(request.id, TTSStatus.FAILED)
            time.sleep(1)  # Prevent busy waiting

    def list_voices(self) -> list[str]:
        """Return available voice names (``voices/*.pt`` minus the extension)."""
        try:
            return [f[:-3] for f in os.listdir("voices") if f.endswith(".pt")]
        except Exception as e:
            # Best effort: a missing/unreadable voices dir yields an empty list.
            print(f"Error listing voices: {str(e)}")
            return []

    def create_tts_request(self, text: str, voice: str = "af", speed: float = 1.0, stitch_long_output: bool = True) -> int:
        """Enqueue a new TTS request and return its database ID."""
        return self.db.add_request(text, voice, speed, stitch_long_output)

    def get_request_status(self, request_id: int) -> Optional[TTSQueue]:
        """Return the full queue row for ``request_id``, or None if unknown."""
        return self.db.get_status(request_id)