feat: enabled support for stitching long outputs in TTS requests

This commit is contained in:
remsky 2024-12-30 06:16:18 -07:00
parent 30581129c0
commit 79d5332c8a
5 changed files with 164 additions and 31 deletions

View file

@ -21,6 +21,7 @@ class QueueDB:
(id INTEGER PRIMARY KEY AUTOINCREMENT, (id INTEGER PRIMARY KEY AUTOINCREMENT,
text TEXT NOT NULL, text TEXT NOT NULL,
voice TEXT DEFAULT 'af', voice TEXT DEFAULT 'af',
stitch_long_output BOOLEAN DEFAULT 1,
status TEXT DEFAULT 'pending', status TEXT DEFAULT 'pending',
output_file TEXT, output_file TEXT,
processing_time REAL, processing_time REAL,
@ -37,6 +38,7 @@ class QueueDB:
(id INTEGER PRIMARY KEY AUTOINCREMENT, (id INTEGER PRIMARY KEY AUTOINCREMENT,
text TEXT NOT NULL, text TEXT NOT NULL,
voice TEXT DEFAULT 'af', voice TEXT DEFAULT 'af',
stitch_long_output BOOLEAN DEFAULT 1,
status TEXT DEFAULT 'pending', status TEXT DEFAULT 'pending',
output_file TEXT, output_file TEXT,
processing_time REAL, processing_time REAL,
@ -44,13 +46,14 @@ class QueueDB:
""") """)
conn.commit() conn.commit()
def add_request(self, text: str, voice: str) -> int: def add_request(self, text: str, voice: str, stitch_long_output: bool = True) -> int:
"""Add a new TTS request to the queue""" """Add a new TTS request to the queue"""
conn = sqlite3.connect(self.db_path) conn = sqlite3.connect(self.db_path)
try: try:
c = conn.cursor() c = conn.cursor()
c.execute( c.execute(
"INSERT INTO tts_queue (text, voice) VALUES (?, ?)", (text, voice) "INSERT INTO tts_queue (text, voice, stitch_long_output) VALUES (?, ?, ?)",
(text, voice, stitch_long_output)
) )
request_id = c.lastrowid request_id = c.lastrowid
conn.commit() conn.commit()
@ -59,7 +62,8 @@ class QueueDB:
self._ensure_table_if_needed(conn) self._ensure_table_if_needed(conn)
c = conn.cursor() c = conn.cursor()
c.execute( c.execute(
"INSERT INTO tts_queue (text, voice) VALUES (?, ?)", (text, voice) "INSERT INTO tts_queue (text, voice, stitch_long_output) VALUES (?, ?, ?)",
(text, voice, stitch_long_output)
) )
request_id = c.lastrowid request_id = c.lastrowid
conn.commit() conn.commit()
@ -73,7 +77,7 @@ class QueueDB:
try: try:
c = conn.cursor() c = conn.cursor()
c.execute( c.execute(
'SELECT id, text, voice FROM tts_queue WHERE status = "pending" ORDER BY created_at ASC LIMIT 1' 'SELECT id, text, voice, stitch_long_output FROM tts_queue WHERE status = "pending" ORDER BY created_at ASC LIMIT 1'
) )
return c.fetchone() return c.fetchone()
except sqlite3.OperationalError: # Table doesn't exist except sqlite3.OperationalError: # Table doesn't exist

View file

@ -7,6 +7,7 @@ class TTSRequest(BaseModel):
text: str text: str
voice: str = "af" # Default voice voice: str = "af" # Default voice
local: bool = False # Whether to save file locally or return bytes local: bool = False # Whether to save file locally or return bytes
stitch_long_output: bool = True # Whether to stitch together long outputs
class TTSResponse(BaseModel): class TTSResponse(BaseModel):
@ -19,3 +20,4 @@ class TTSResponse(BaseModel):
class VoicesResponse(BaseModel): class VoicesResponse(BaseModel):
voices: list[str] voices: list[str]
default: str default: str

View file

@ -32,7 +32,11 @@ async def create_tts(request: TTSRequest):
) )
# Queue the request # Queue the request
request_id = tts_service.create_tts_request(request.text, request.voice) request_id = tts_service.create_tts_request(
request.text,
request.voice,
request.stitch_long_output
)
return { return {
"request_id": request_id, "request_id": request_id,
"status": "pending", "status": "pending",

View file

@ -2,11 +2,12 @@ import os
import threading import threading
import time import time
import io import io
from typing import Optional, Tuple from typing import Optional, Tuple, List
import numpy as np
import torch import torch
import scipy.io.wavfile as wavfile import scipy.io.wavfile as wavfile
from models import build_model from models import build_model
from kokoro import generate from kokoro import generate, phonemize, tokenize
from ..database.queue import QueueDB from ..database.queue import QueueDB
@ -57,7 +58,52 @@ class TTSService:
self.worker = threading.Thread(target=self._process_queue, daemon=True) self.worker = threading.Thread(target=self._process_queue, daemon=True)
self.worker.start() self.worker.start()
def _generate_audio(self, text: str, voice: str) -> Tuple[torch.Tensor, float]: def _find_boundary(self, text: str, max_tokens: int, voice: str, margin: int = 50) -> int:
"""Find the closest sentence/clause boundary within token limit"""
# Try different boundary markers in order of preference
for marker in ['. ', '; ', ', ']:
# Look for the last occurrence of marker before max_tokens
test_text = text[:max_tokens + margin] # Look a bit beyond the limit
last_idx = test_text.rfind(marker)
if last_idx != -1:
# Verify this boundary is within our token limit
candidate = text[:last_idx + len(marker)].strip()
ps = phonemize(candidate, voice[0])
tokens = tokenize(ps)
if len(tokens) <= max_tokens:
return last_idx + len(marker)
# If no good boundary found, find last whitespace within limit
test_text = text[:max_tokens]
last_space = test_text.rfind(' ')
return last_space if last_space != -1 else max_tokens
def _split_text(self, text: str, voice: str) -> List[str]:
"""Split text into chunks that respect token limits and try to maintain sentence structure"""
MAX_TOKENS = 450 # Leave wider margin from 510 limit to account for tokenizer differences
chunks = []
remaining = text
while remaining:
# If remaining text is within limit, add it as final chunk
ps = phonemize(remaining, voice[0])
tokens = tokenize(ps)
if len(tokens) <= MAX_TOKENS:
chunks.append(remaining.strip())
break
# Find best boundary position
split_pos = self._find_boundary(remaining, MAX_TOKENS, voice)
# Add chunk and continue with remaining text
chunks.append(remaining[:split_pos].strip())
remaining = remaining[split_pos:].strip()
return chunks
def _generate_audio(self, text: str, voice: str, stitch_long_output: bool = True) -> Tuple[torch.Tensor, float]:
"""Generate audio and measure processing time""" """Generate audio and measure processing time"""
start_time = time.time() start_time = time.time()
@ -65,8 +111,24 @@ class TTSService:
model, device = TTSModel.get_instance() model, device = TTSModel.get_instance()
voicepack = TTSModel.get_voicepack(voice) voicepack = TTSModel.get_voicepack(voice)
# Generate audio # Generate audio with or without stitching
audio, _ = generate(model, text, voicepack, lang=voice[0]) if stitch_long_output:
# Split text if needed and generate audio for each chunk
chunks = self._split_text(text, voice)
audio_chunks = []
for chunk in chunks:
chunk_audio, _ = generate(model, chunk, voicepack, lang=voice[0])
audio_chunks.append(chunk_audio)
# Concatenate audio chunks
if len(audio_chunks) > 1:
audio = np.concatenate(audio_chunks)
else:
audio = audio_chunks[0]
else:
# Generate single chunk without splitting
audio, _ = generate(model, text, voicepack, lang=voice[0])
processing_time = time.time() - start_time processing_time = time.time() - start_time
return audio, processing_time return audio, processing_time
@ -87,15 +149,15 @@ class TTSService:
while True: while True:
next_request = self.db.get_next_pending() next_request = self.db.get_next_pending()
if next_request: if next_request:
request_id, text, voice = next_request request_id, text, voice, stitch_long_output = next_request
try: try:
# Generate audio and measure time # Generate audio and measure time
audio, processing_time = self._generate_audio(text, voice) audio, processing_time = self._generate_audio(text, voice, stitch_long_output)
# Save to file # Save to file
output_file = os.path.join( output_file = os.path.abspath(os.path.join(
self.output_dir, f"speech_{request_id}.wav" self.output_dir, f"speech_{request_id}.wav"
) ))
self._save_audio(audio, output_file) self._save_audio(audio, output_file)
# Update status with processing time # Update status with processing time
@ -124,9 +186,9 @@ class TTSService:
print(f"Error listing voices: {str(e)}") print(f"Error listing voices: {str(e)}")
return voices return voices
def create_tts_request(self, text: str, voice: str = "af") -> int: def create_tts_request(self, text: str, voice: str = "af", stitch_long_output: bool = True) -> int:
"""Create a new TTS request and return the request ID""" """Create a new TTS request and return the request ID"""
return self.db.add_request(text, voice) return self.db.add_request(text, voice, stitch_long_output)
def get_request_status( def get_request_status(
self, request_id: int self, request_id: int

View file

@ -79,13 +79,19 @@ def main():
text = f.read() text = f.read()
# Create range of sizes up to full text # Create range of sizes up to full text
sizes = [100, 250, 500, 750, 1000, 1500, 2000, 3000, 4000, 5000, 6000, 7000, len(text)] sizes = [100, 250, 500, 750, 1000, 1500, 2000, 3000, 4000, 5000, 6000, 7000]
# Process chunks # Process chunks
results = [] results = []
import random
for size in sizes: for size in sizes:
# Get chunk and count tokens # Get random starting point ensuring we have enough text left
chunk = text[:size] max_start = len(text) - size
if max_start > 0:
start = random.randint(0, max_start)
chunk = text[start:start + size]
else:
chunk = text[:size]
num_tokens = count_tokens(chunk) num_tokens = count_tokens(chunk)
print(f"\nProcessing chunk with {num_tokens} tokens ({size} chars):") print(f"\nProcessing chunk with {num_tokens} tokens ({size} chars):")
@ -106,23 +112,78 @@ def main():
# Create DataFrame for plotting # Create DataFrame for plotting
df = pd.DataFrame(results) df = pd.DataFrame(results)
# Set the style
sns.set_theme(style="darkgrid", palette="husl", font_scale=1.1)
# Common plot settings
def setup_plot(fig, ax, title):
    """Apply the shared look (grid, fonts, ticks, spines) to one plot.

    The axis labels keep whatever text they currently have; only their font
    properties are restyled. Returns ``(fig, ax)`` unchanged so the call can
    be chained.
    """
    label_style = {'fontsize': 12, 'fontweight': 'medium'}

    # Dashed, slightly transparent grid
    ax.grid(True, linestyle='--', alpha=0.7)

    # Title, then both axis labels re-set with their existing text
    ax.set_title(title, pad=20, fontsize=16, fontweight='bold')
    current_xlabel = ax.get_xlabel()
    ax.set_xlabel(current_xlabel, **label_style)
    current_ylabel = ax.get_ylabel()
    ax.set_ylabel(current_ylabel, **label_style)

    # Tick label size
    ax.tick_params(labelsize=10)

    # Thin grey border on every side
    for edge in ax.spines.values():
        edge.set_color('#666666')
        edge.set_linewidth(0.5)

    return fig, ax
# Plot 1: Processing Time vs Output Length # Plot 1: Processing Time vs Output Length
plt.figure(figsize=(12, 8)) fig, ax = plt.subplots(figsize=(12, 8))
sns.scatterplot(data=df, x='output_length', y='processing_time')
sns.regplot(data=df, x='output_length', y='processing_time', scatter=False) # Create scatter plot with custom styling
plt.title('Processing Time vs Output Length') scatter = sns.scatterplot(data=df, x='output_length', y='processing_time',
plt.xlabel('Output Audio Length (seconds)') s=100, alpha=0.6, color='#2ecc71')
plt.ylabel('Processing Time (seconds)')
# Add regression line with confidence interval
sns.regplot(data=df, x='output_length', y='processing_time',
scatter=False, color='#e74c3c', line_kws={'linewidth': 2})
# Calculate correlation
corr = df['output_length'].corr(df['processing_time'])
# Add correlation annotation
plt.text(0.05, 0.95, f'Correlation: {corr:.2f}',
transform=ax.transAxes, fontsize=10,
bbox=dict(facecolor='white', edgecolor='none', alpha=0.7))
setup_plot(fig, ax, 'Processing Time vs Output Length')
ax.set_xlabel('Output Audio Length (seconds)')
ax.set_ylabel('Processing Time (seconds)')
plt.savefig('examples/time_vs_output.png', dpi=300, bbox_inches='tight') plt.savefig('examples/time_vs_output.png', dpi=300, bbox_inches='tight')
plt.close() plt.close()
# Plot 2: Processing Time vs Token Count # Plot 2: Processing Time vs Token Count
plt.figure(figsize=(12, 8)) fig, ax = plt.subplots(figsize=(12, 8))
sns.scatterplot(data=df, x='tokens', y='processing_time')
sns.regplot(data=df, x='tokens', y='processing_time', scatter=False) # Create scatter plot with custom styling
plt.title('Processing Time vs Token Count') scatter = sns.scatterplot(data=df, x='tokens', y='processing_time',
plt.xlabel('Number of Input Tokens') s=100, alpha=0.6, color='#3498db')
plt.ylabel('Processing Time (seconds)')
# Add regression line with confidence interval
sns.regplot(data=df, x='tokens', y='processing_time',
scatter=False, color='#e74c3c', line_kws={'linewidth': 2})
# Calculate correlation
corr = df['tokens'].corr(df['processing_time'])
# Add correlation annotation
plt.text(0.05, 0.95, f'Correlation: {corr:.2f}',
transform=ax.transAxes, fontsize=10,
bbox=dict(facecolor='white', edgecolor='none', alpha=0.7))
setup_plot(fig, ax, 'Processing Time vs Token Count')
ax.set_xlabel('Number of Input Tokens')
ax.set_ylabel('Processing Time (seconds)')
plt.savefig('examples/time_vs_tokens.png', dpi=300, bbox_inches='tight') plt.savefig('examples/time_vs_tokens.png', dpi=300, bbox_inches='tight')
plt.close() plt.close()