feat: enable support for stitching long outputs in TTS requests

remsky 2024-12-30 06:16:18 -07:00
parent 30581129c0
commit 79d5332c8a
5 changed files with 164 additions and 31 deletions

@@ -21,6 +21,7 @@ class QueueDB:
(id INTEGER PRIMARY KEY AUTOINCREMENT,
text TEXT NOT NULL,
voice TEXT DEFAULT 'af',
stitch_long_output BOOLEAN DEFAULT 1,
status TEXT DEFAULT 'pending',
output_file TEXT,
processing_time REAL,
@@ -37,6 +38,7 @@ class QueueDB:
(id INTEGER PRIMARY KEY AUTOINCREMENT,
text TEXT NOT NULL,
voice TEXT DEFAULT 'af',
stitch_long_output BOOLEAN DEFAULT 1,
status TEXT DEFAULT 'pending',
output_file TEXT,
processing_time REAL,
@@ -44,13 +46,14 @@ class QueueDB:
""")
conn.commit()
def add_request(self, text: str, voice: str) -> int:
def add_request(self, text: str, voice: str, stitch_long_output: bool = True) -> int:
"""Add a new TTS request to the queue"""
conn = sqlite3.connect(self.db_path)
try:
c = conn.cursor()
c.execute(
"INSERT INTO tts_queue (text, voice) VALUES (?, ?)", (text, voice)
"INSERT INTO tts_queue (text, voice, stitch_long_output) VALUES (?, ?, ?)",
(text, voice, stitch_long_output)
)
request_id = c.lastrowid
conn.commit()
@@ -59,7 +62,8 @@ class QueueDB:
self._ensure_table_if_needed(conn)
c = conn.cursor()
c.execute(
"INSERT INTO tts_queue (text, voice) VALUES (?, ?)", (text, voice)
"INSERT INTO tts_queue (text, voice, stitch_long_output) VALUES (?, ?, ?)",
(text, voice, stitch_long_output)
)
request_id = c.lastrowid
conn.commit()
@@ -73,7 +77,7 @@ class QueueDB:
try:
c = conn.cursor()
c.execute(
'SELECT id, text, voice FROM tts_queue WHERE status = "pending" ORDER BY created_at ASC LIMIT 1'
'SELECT id, text, voice, stitch_long_output FROM tts_queue WHERE status = "pending" ORDER BY created_at ASC LIMIT 1'
)
return c.fetchone()
except sqlite3.OperationalError: # Table doesn't exist
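For reference, a minimal sketch of how the new flag round-trips through the queue, assuming a fresh database file (so the stitch_long_output column exists) and that QueueDB takes the database path as its constructor argument; the import path is illustrative:

# Illustrative usage; module path, constructor signature, and filename are assumptions.
from database.queue import QueueDB

db = QueueDB("tts_queue.db")
req_id = db.add_request("A very long article ...", voice="af", stitch_long_output=True)

row = db.get_next_pending()
if row is not None:
    request_id, text, voice, stitch_long_output = row
    stitch_long_output = bool(stitch_long_output)  # sqlite3 returns the BOOLEAN column as 0/1

If the table is created with CREATE TABLE IF NOT EXISTS (as the _ensure_table_if_needed name suggests), a database created before this change will still lack the new column and would need an ALTER TABLE or a fresh file.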

@@ -7,6 +7,7 @@ class TTSRequest(BaseModel):
text: str
voice: str = "af" # Default voice
local: bool = False # Whether to save file locally or return bytes
stitch_long_output: bool = True # Whether to stitch together long outputs
class TTSResponse(BaseModel):
@@ -19,3 +20,4 @@ class TTSResponse(BaseModel):
class VoicesResponse(BaseModel):
voices: list[str]
default: str
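A quick illustration of the extended request model; the new field defaults to True, so existing clients get stitching unless they explicitly opt out (values are made up):

req = TTSRequest(text="Hello world")
assert req.stitch_long_output is True
req = TTSRequest(text="Hello world", voice="af", stitch_long_output=False)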

@@ -32,7 +32,11 @@ async def create_tts(request: TTSRequest):
)
# Queue the request
request_id = tts_service.create_tts_request(request.text, request.voice)
request_id = tts_service.create_tts_request(
request.text,
request.voice,
request.stitch_long_output
)
return {
"request_id": request_id,
"status": "pending",

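Assuming the route is mounted at something like /tts (the decorator is not shown in this hunk), a client that opts out of stitching could look like the following; host, port, and path are assumptions, while the payload fields come from TTSRequest:

# Hypothetical client call; the endpoint URL is an assumption.
import requests

resp = requests.post(
    "http://localhost:8000/tts",
    json={"text": "Long article text ...", "voice": "af", "stitch_long_output": False},
)
print(resp.json())  # expected shape: {"request_id": ..., "status": "pending", ...}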
@@ -2,11 +2,12 @@ import os
import threading
import time
import io
from typing import Optional, Tuple
from typing import Optional, Tuple, List
import numpy as np
import torch
import scipy.io.wavfile as wavfile
from models import build_model
from kokoro import generate
from kokoro import generate, phonemize, tokenize
from ..database.queue import QueueDB
@@ -57,7 +58,52 @@ class TTSService:
self.worker = threading.Thread(target=self._process_queue, daemon=True)
self.worker.start()
def _generate_audio(self, text: str, voice: str) -> Tuple[torch.Tensor, float]:
def _find_boundary(self, text: str, max_tokens: int, voice: str, margin: int = 50) -> int:
"""Find the closest sentence/clause boundary within token limit"""
# Try different boundary markers in order of preference
for marker in ['. ', '; ', ', ']:
# Look for the last occurrence of marker before max_tokens
test_text = text[:max_tokens + margin] # Look a bit beyond the limit
last_idx = test_text.rfind(marker)
if last_idx != -1:
# Verify this boundary is within our token limit
candidate = text[:last_idx + len(marker)].strip()
ps = phonemize(candidate, voice[0])
tokens = tokenize(ps)
if len(tokens) <= max_tokens:
return last_idx + len(marker)
# If no good boundary found, find last whitespace within limit
test_text = text[:max_tokens]
last_space = test_text.rfind(' ')
return last_space if last_space != -1 else max_tokens
def _split_text(self, text: str, voice: str) -> List[str]:
"""Split text into chunks that respect token limits and try to maintain sentence structure"""
MAX_TOKENS = 450 # Leave wider margin from 510 limit to account for tokenizer differences
chunks = []
remaining = text
while remaining:
# If remaining text is within limit, add it as final chunk
ps = phonemize(remaining, voice[0])
tokens = tokenize(ps)
if len(tokens) <= MAX_TOKENS:
chunks.append(remaining.strip())
break
# Find best boundary position
split_pos = self._find_boundary(remaining, MAX_TOKENS, voice)
# Add chunk and continue with remaining text
chunks.append(remaining[:split_pos].strip())
remaining = remaining[split_pos:].strip()
return chunks
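The same splitting strategy, sketched with a plain whitespace tokenizer so it can be run without the phonemizer; this illustrates the boundary-preference logic only and is not the service code, which counts phoneme tokens via phonemize/tokenize:

# Standalone sketch: a naive word count stands in for phoneme tokens.
from typing import List

def split_plain(text: str, max_tokens: int = 450) -> List[str]:
    def count(s: str) -> int:
        return len(s.split())  # stand-in for len(tokenize(phonemize(s, lang)))

    chunks, remaining = [], text
    while remaining:
        if count(remaining) <= max_tokens:
            chunks.append(remaining.strip())
            break
        window = remaining[: max_tokens * 6]  # rough character window, mirrors the margin idea
        cut = -1
        for marker in ['. ', '; ', ', ']:  # same boundary preference as _find_boundary
            idx = window.rfind(marker)
            if idx != -1 and count(remaining[: idx + len(marker)]) <= max_tokens:
                cut = idx + len(marker)
                break
        if cut == -1:  # fall back to the last whitespace, then to a hard cut
            cut = window.rfind(' ')
        if cut <= 0:
            cut = len(window)
        chunks.append(remaining[:cut].strip())
        remaining = remaining[cut:].strip()
    return chunks

chunks = split_plain("First sentence. Second clause; third item, and more words here. " * 200)
print(len(chunks), max(len(c.split()) for c in chunks))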
def _generate_audio(self, text: str, voice: str, stitch_long_output: bool = True) -> Tuple[torch.Tensor, float]:
"""Generate audio and measure processing time"""
start_time = time.time()
@@ -65,8 +111,24 @@ class TTSService:
model, device = TTSModel.get_instance()
voicepack = TTSModel.get_voicepack(voice)
# Generate audio
audio, _ = generate(model, text, voicepack, lang=voice[0])
# Generate audio with or without stitching
if stitch_long_output:
# Split text if needed and generate audio for each chunk
chunks = self._split_text(text, voice)
audio_chunks = []
for chunk in chunks:
chunk_audio, _ = generate(model, chunk, voicepack, lang=voice[0])
audio_chunks.append(chunk_audio)
# Concatenate audio chunks
if len(audio_chunks) > 1:
audio = np.concatenate(audio_chunks)
else:
audio = audio_chunks[0]
else:
# Generate single chunk without splitting
audio, _ = generate(model, text, voicepack, lang=voice[0])
processing_time = time.time() - start_time
return audio, processing_time
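A minimal way to exercise both paths once a service instance exists (service and long_text are placeholder names; assumes the model and voicepack load successfully):

# Both calls return (audio, processing_time).
# With stitching disabled, the input should already fit the ~510-token limit noted
# above, otherwise generating it as a single chunk may fail or truncate.
audio, seconds = service._generate_audio(long_text, "af", stitch_long_output=True)
short_audio, short_seconds = service._generate_audio("A short sentence.", "af", stitch_long_output=False)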
@@ -87,15 +149,15 @@
while True:
next_request = self.db.get_next_pending()
if next_request:
request_id, text, voice = next_request
request_id, text, voice, stitch_long_output = next_request
try:
# Generate audio and measure time
audio, processing_time = self._generate_audio(text, voice)
audio, processing_time = self._generate_audio(text, voice, stitch_long_output)
# Save to file
output_file = os.path.join(
output_file = os.path.abspath(os.path.join(
self.output_dir, f"speech_{request_id}.wav"
)
))
self._save_audio(audio, output_file)
# Update status with processing time
@@ -124,9 +186,9 @@
print(f"Error listing voices: {str(e)}")
return voices
def create_tts_request(self, text: str, voice: str = "af") -> int:
def create_tts_request(self, text: str, voice: str = "af", stitch_long_output: bool = True) -> int:
"""Create a new TTS request and return the request ID"""
return self.db.add_request(text, voice)
return self.db.add_request(text, voice, stitch_long_output)
def get_request_status(
self, request_id: int

@@ -79,13 +79,19 @@ def main():
text = f.read()
# Create range of sizes up to full text
sizes = [100, 250, 500, 750, 1000, 1500, 2000, 3000, 4000, 5000, 6000, 7000, len(text)]
sizes = [100, 250, 500, 750, 1000, 1500, 2000, 3000, 4000, 5000, 6000, 7000]
# Process chunks
results = []
import random
for size in sizes:
# Get chunk and count tokens
chunk = text[:size]
# Get random starting point ensuring we have enough text left
max_start = len(text) - size
if max_start > 0:
start = random.randint(0, max_start)
chunk = text[start:start + size]
else:
chunk = text[:size]
num_tokens = count_tokens(chunk)
print(f"\nProcessing chunk with {num_tokens} tokens ({size} chars):")
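Because the starting offsets are now random, successive benchmark runs sample different text; if repeatability matters, the sampler can be pinned before the loop (optional, not part of this change):

random.seed(42)  # hypothetical addition: makes the sampled chunk offsets repeatable across runs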
@@ -106,23 +112,78 @@
# Create DataFrame for plotting
df = pd.DataFrame(results)
# Set the style
sns.set_theme(style="darkgrid", palette="husl", font_scale=1.1)
# Common plot settings
def setup_plot(fig, ax, title):
# Improve grid
ax.grid(True, linestyle='--', alpha=0.7)
# Set title and labels with better fonts
ax.set_title(title, pad=20, fontsize=16, fontweight='bold')
ax.set_xlabel(ax.get_xlabel(), fontsize=12, fontweight='medium')
ax.set_ylabel(ax.get_ylabel(), fontsize=12, fontweight='medium')
# Improve tick labels
ax.tick_params(labelsize=10)
# Add subtle spines
for spine in ax.spines.values():
spine.set_color('#666666')
spine.set_linewidth(0.5)
return fig, ax
# Plot 1: Processing Time vs Output Length
plt.figure(figsize=(12, 8))
sns.scatterplot(data=df, x='output_length', y='processing_time')
sns.regplot(data=df, x='output_length', y='processing_time', scatter=False)
plt.title('Processing Time vs Output Length')
plt.xlabel('Output Audio Length (seconds)')
plt.ylabel('Processing Time (seconds)')
fig, ax = plt.subplots(figsize=(12, 8))
# Create scatter plot with custom styling
scatter = sns.scatterplot(data=df, x='output_length', y='processing_time',
s=100, alpha=0.6, color='#2ecc71')
# Add regression line with confidence interval
sns.regplot(data=df, x='output_length', y='processing_time',
scatter=False, color='#e74c3c', line_kws={'linewidth': 2})
# Calculate correlation
corr = df['output_length'].corr(df['processing_time'])
# Add correlation annotation
plt.text(0.05, 0.95, f'Correlation: {corr:.2f}',
transform=ax.transAxes, fontsize=10,
bbox=dict(facecolor='white', edgecolor='none', alpha=0.7))
setup_plot(fig, ax, 'Processing Time vs Output Length')
ax.set_xlabel('Output Audio Length (seconds)')
ax.set_ylabel('Processing Time (seconds)')
plt.savefig('examples/time_vs_output.png', dpi=300, bbox_inches='tight')
plt.close()
# Plot 2: Processing Time vs Token Count
plt.figure(figsize=(12, 8))
sns.scatterplot(data=df, x='tokens', y='processing_time')
sns.regplot(data=df, x='tokens', y='processing_time', scatter=False)
plt.title('Processing Time vs Token Count')
plt.xlabel('Number of Input Tokens')
plt.ylabel('Processing Time (seconds)')
fig, ax = plt.subplots(figsize=(12, 8))
# Create scatter plot with custom styling
scatter = sns.scatterplot(data=df, x='tokens', y='processing_time',
s=100, alpha=0.6, color='#3498db')
# Add regression line with confidence interval
sns.regplot(data=df, x='tokens', y='processing_time',
scatter=False, color='#e74c3c', line_kws={'linewidth': 2})
# Calculate correlation
corr = df['tokens'].corr(df['processing_time'])
# Add correlation annotation
plt.text(0.05, 0.95, f'Correlation: {corr:.2f}',
transform=ax.transAxes, fontsize=10,
bbox=dict(facecolor='white', edgecolor='none', alpha=0.7))
setup_plot(fig, ax, 'Processing Time vs Token Count')
ax.set_xlabel('Number of Input Tokens')
ax.set_ylabel('Processing Time (seconds)')
plt.savefig('examples/time_vs_tokens.png', dpi=300, bbox_inches='tight')
plt.close()
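If the benchmark runs on a headless machine, selecting a non-interactive backend before pyplot is imported avoids display errors; this is an optional environment note, not part of the commit:

import matplotlib
matplotlib.use("Agg")  # must run before matplotlib.pyplot is imported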