Mirror of https://github.com/remsky/Kokoro-FastAPI.git (synced 2025-04-13)
feat: enabled support for stitching long outputs in TTS requests

Commit 79d5332c8a (parent 30581129c0)
5 changed files with 164 additions and 31 deletions
@@ -21,6 +21,7 @@ class QueueDB:
                     (id INTEGER PRIMARY KEY AUTOINCREMENT,
                      text TEXT NOT NULL,
                      voice TEXT DEFAULT 'af',
+                     stitch_long_output BOOLEAN DEFAULT 1,
                      status TEXT DEFAULT 'pending',
                      output_file TEXT,
                      processing_time REAL,
@@ -37,6 +38,7 @@ class QueueDB:
                     (id INTEGER PRIMARY KEY AUTOINCREMENT,
                      text TEXT NOT NULL,
                      voice TEXT DEFAULT 'af',
+                     stitch_long_output BOOLEAN DEFAULT 1,
                      status TEXT DEFAULT 'pending',
                      output_file TEXT,
                      processing_time REAL,
@@ -44,13 +46,14 @@ class QueueDB:
                 """)
             conn.commit()

-    def add_request(self, text: str, voice: str) -> int:
+    def add_request(self, text: str, voice: str, stitch_long_output: bool = True) -> int:
         """Add a new TTS request to the queue"""
         conn = sqlite3.connect(self.db_path)
         try:
             c = conn.cursor()
             c.execute(
-                "INSERT INTO tts_queue (text, voice) VALUES (?, ?)", (text, voice)
+                "INSERT INTO tts_queue (text, voice, stitch_long_output) VALUES (?, ?, ?)",
+                (text, voice, stitch_long_output)
             )
             request_id = c.lastrowid
             conn.commit()
@@ -59,7 +62,8 @@ class QueueDB:
             self._ensure_table_if_needed(conn)
             c = conn.cursor()
             c.execute(
-                "INSERT INTO tts_queue (text, voice) VALUES (?, ?)", (text, voice)
+                "INSERT INTO tts_queue (text, voice, stitch_long_output) VALUES (?, ?, ?)",
+                (text, voice, stitch_long_output)
             )
             request_id = c.lastrowid
             conn.commit()
@@ -73,7 +77,7 @@ class QueueDB:
         try:
             c = conn.cursor()
             c.execute(
-                'SELECT id, text, voice FROM tts_queue WHERE status = "pending" ORDER BY created_at ASC LIMIT 1'
+                'SELECT id, text, voice, stitch_long_output FROM tts_queue WHERE status = "pending" ORDER BY created_at ASC LIMIT 1'
             )
             return c.fetchone()
         except sqlite3.OperationalError:  # Table doesn't exist
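These hunks touch the queue database layer (the QueueDB class, imported later in this commit as ..database.queue): the tts_queue table gains a stitch_long_output column defaulting to 1, add_request accepts and stores the flag, and get_next_pending now selects it. A minimal usage sketch; the import path, constructor argument, and example strings are assumptions for illustration, only the add_request and get_next_pending changes come from the diff:

    from database.queue import QueueDB    # import path assumed

    db = QueueDB("tts_queue.db")          # constructor argument assumed
    req_id = db.add_request("a long passage of text", voice="af", stitch_long_output=True)
    row = db.get_next_pending()
    # row is now a 4-tuple (id, text, voice, stitch_long_output) rather than the old 3-tuple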
@@ -7,6 +7,7 @@ class TTSRequest(BaseModel):
     text: str
     voice: str = "af"  # Default voice
     local: bool = False  # Whether to save file locally or return bytes
+    stitch_long_output: bool = True  # Whether to stitch together long outputs


 class TTSResponse(BaseModel):
@@ -19,3 +20,4 @@ class TTSResponse(BaseModel):
 class VoicesResponse(BaseModel):
     voices: list[str]
     default: str
+
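In the request schema the new field defaults to True, so clients that omit it get stitching enabled automatically. A quick check of the Pydantic default (the model import is not shown in this hunk):

    req = TTSRequest(text="Hello world")
    print(req.stitch_long_output)   # True unless the client sends "stitch_long_output": false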
@@ -32,7 +32,11 @@ async def create_tts(request: TTSRequest):
        )

    # Queue the request
-    request_id = tts_service.create_tts_request(request.text, request.voice)
+    request_id = tts_service.create_tts_request(
+        request.text,
+        request.voice,
+        request.stitch_long_output
+    )
    return {
        "request_id": request_id,
        "status": "pending",
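The route handler simply forwards the new flag from the request body to the service layer. A client-side sketch; the host, port, and route path are assumptions, since the router prefix is not visible in this hunk:

    import requests   # hypothetical client, not part of this commit

    resp = requests.post(
        "http://localhost:8000/tts",   # endpoint URL assumed for illustration
        json={"text": "a passage long enough to need splitting", "voice": "af", "stitch_long_output": True},
    )
    print(resp.json())   # per this handler: {"request_id": ..., "status": "pending", ...}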
@@ -2,11 +2,12 @@ import os
 import threading
 import time
 import io
-from typing import Optional, Tuple
+from typing import Optional, Tuple, List
+import numpy as np
 import torch
 import scipy.io.wavfile as wavfile
 from models import build_model
-from kokoro import generate
+from kokoro import generate, phonemize, tokenize
 from ..database.queue import QueueDB


@@ -57,7 +58,52 @@ class TTSService:
         self.worker = threading.Thread(target=self._process_queue, daemon=True)
         self.worker.start()

-    def _generate_audio(self, text: str, voice: str) -> Tuple[torch.Tensor, float]:
+    def _find_boundary(self, text: str, max_tokens: int, voice: str, margin: int = 50) -> int:
+        """Find the closest sentence/clause boundary within token limit"""
+        # Try different boundary markers in order of preference
+        for marker in ['. ', '; ', ', ']:
+            # Look for the last occurrence of marker before max_tokens
+            test_text = text[:max_tokens + margin]  # Look a bit beyond the limit
+            last_idx = test_text.rfind(marker)
+
+            if last_idx != -1:
+                # Verify this boundary is within our token limit
+                candidate = text[:last_idx + len(marker)].strip()
+                ps = phonemize(candidate, voice[0])
+                tokens = tokenize(ps)
+
+                if len(tokens) <= max_tokens:
+                    return last_idx + len(marker)
+
+        # If no good boundary found, find last whitespace within limit
+        test_text = text[:max_tokens]
+        last_space = test_text.rfind(' ')
+        return last_space if last_space != -1 else max_tokens
+
+    def _split_text(self, text: str, voice: str) -> List[str]:
+        """Split text into chunks that respect token limits and try to maintain sentence structure"""
+        MAX_TOKENS = 450  # Leave wider margin from 510 limit to account for tokenizer differences
+        chunks = []
+        remaining = text
+
+        while remaining:
+            # If remaining text is within limit, add it as final chunk
+            ps = phonemize(remaining, voice[0])
+            tokens = tokenize(ps)
+            if len(tokens) <= MAX_TOKENS:
+                chunks.append(remaining.strip())
+                break
+
+            # Find best boundary position
+            split_pos = self._find_boundary(remaining, MAX_TOKENS, voice)
+
+            # Add chunk and continue with remaining text
+            chunks.append(remaining[:split_pos].strip())
+            remaining = remaining[split_pos:].strip()
+
+        return chunks
+
+    def _generate_audio(self, text: str, voice: str, stitch_long_output: bool = True) -> Tuple[torch.Tensor, float]:
         """Generate audio and measure processing time"""
         start_time = time.time()

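The new _split_text is greedy: while the remaining text phonemizes to more than 450 tokens (a safety margin under the 510-token limit mentioned in the code), _find_boundary looks for the last '. ', then '; ', then ', ' inside the allowed window, and falls back to the last space. A toy illustration of that boundary preference using plain string search only; the 40-character window is an arbitrary stand-in for the token limit, since kokoro's phonemize/tokenize are not involved here:

    text = "First sentence. Second clause; third part, and a tail without punctuation"
    cut = text[:40].rfind('. ')        # prefer the last sentence boundary inside the window
    print(text[:cut + 2].strip())      # -> First sentence.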
@@ -65,8 +111,24 @@ class TTSService:
         model, device = TTSModel.get_instance()
         voicepack = TTSModel.get_voicepack(voice)

-        # Generate audio
-        audio, _ = generate(model, text, voicepack, lang=voice[0])
+        # Generate audio with or without stitching
+        if stitch_long_output:
+            # Split text if needed and generate audio for each chunk
+            chunks = self._split_text(text, voice)
+            audio_chunks = []
+
+            for chunk in chunks:
+                chunk_audio, _ = generate(model, chunk, voicepack, lang=voice[0])
+                audio_chunks.append(chunk_audio)
+
+            # Concatenate audio chunks
+            if len(audio_chunks) > 1:
+                audio = np.concatenate(audio_chunks)
+            else:
+                audio = audio_chunks[0]
+        else:
+            # Generate single chunk without splitting
+            audio, _ = generate(model, text, voicepack, lang=voice[0])

         processing_time = time.time() - start_time
         return audio, processing_time
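Stitching itself is just np.concatenate over the per-chunk waveforms, so the chunks play back to back in order. A self-contained sketch of that step; the sample rate and dtype are assumptions about what generate() returns, not something this diff states:

    import numpy as np

    chunk_a = np.zeros(24000, dtype=np.float32)     # stand-in for ~1 s of generated audio
    chunk_b = np.zeros(12000, dtype=np.float32)     # stand-in for ~0.5 s
    stitched = np.concatenate([chunk_a, chunk_b])   # 36000 samples, one continuous waveform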
@@ -87,15 +149,15 @@ class TTSService:
         while True:
             next_request = self.db.get_next_pending()
             if next_request:
-                request_id, text, voice = next_request
+                request_id, text, voice, stitch_long_output = next_request
                 try:
                     # Generate audio and measure time
-                    audio, processing_time = self._generate_audio(text, voice)
+                    audio, processing_time = self._generate_audio(text, voice, stitch_long_output)

                     # Save to file
-                    output_file = os.path.join(
+                    output_file = os.path.abspath(os.path.join(
                         self.output_dir, f"speech_{request_id}.wav"
-                    )
+                    ))
                     self._save_audio(audio, output_file)

                     # Update status with processing time
@@ -124,9 +186,9 @@ class TTSService:
             print(f"Error listing voices: {str(e)}")
         return voices

-    def create_tts_request(self, text: str, voice: str = "af") -> int:
+    def create_tts_request(self, text: str, voice: str = "af", stitch_long_output: bool = True) -> int:
         """Create a new TTS request and return the request ID"""
-        return self.db.add_request(text, voice)
+        return self.db.add_request(text, voice, stitch_long_output)

     def get_request_status(
         self, request_id: int
@@ -79,13 +79,19 @@ def main():
        text = f.read()

    # Create range of sizes up to full text
-    sizes = [100, 250, 500, 750, 1000, 1500, 2000, 3000, 4000, 5000, 6000, 7000, len(text)]
+    sizes = [100, 250, 500, 750, 1000, 1500, 2000, 3000, 4000, 5000, 6000, 7000]

    # Process chunks
    results = []
+    import random
    for size in sizes:
-        # Get chunk and count tokens
-        chunk = text[:size]
+        # Get random starting point ensuring we have enough text left
+        max_start = len(text) - size
+        if max_start > 0:
+            start = random.randint(0, max_start)
+            chunk = text[start:start + size]
+        else:
+            chunk = text[:size]
        num_tokens = count_tokens(chunk)

        print(f"\nProcessing chunk with {num_tokens} tokens ({size} chars):")
@@ -106,23 +112,78 @@ def main():
    # Create DataFrame for plotting
    df = pd.DataFrame(results)

+    # Set the style
+    sns.set_theme(style="darkgrid", palette="husl", font_scale=1.1)
+
+    # Common plot settings
+    def setup_plot(fig, ax, title):
+        # Improve grid
+        ax.grid(True, linestyle='--', alpha=0.7)
+
+        # Set title and labels with better fonts
+        ax.set_title(title, pad=20, fontsize=16, fontweight='bold')
+        ax.set_xlabel(ax.get_xlabel(), fontsize=12, fontweight='medium')
+        ax.set_ylabel(ax.get_ylabel(), fontsize=12, fontweight='medium')
+
+        # Improve tick labels
+        ax.tick_params(labelsize=10)
+
+        # Add subtle spines
+        for spine in ax.spines.values():
+            spine.set_color('#666666')
+            spine.set_linewidth(0.5)
+
+        return fig, ax
+
    # Plot 1: Processing Time vs Output Length
-    plt.figure(figsize=(12, 8))
-    sns.scatterplot(data=df, x='output_length', y='processing_time')
-    sns.regplot(data=df, x='output_length', y='processing_time', scatter=False)
-    plt.title('Processing Time vs Output Length')
-    plt.xlabel('Output Audio Length (seconds)')
-    plt.ylabel('Processing Time (seconds)')
+    fig, ax = plt.subplots(figsize=(12, 8))
+
+    # Create scatter plot with custom styling
+    scatter = sns.scatterplot(data=df, x='output_length', y='processing_time',
+                              s=100, alpha=0.6, color='#2ecc71')
+
+    # Add regression line with confidence interval
+    sns.regplot(data=df, x='output_length', y='processing_time',
+                scatter=False, color='#e74c3c', line_kws={'linewidth': 2})
+
+    # Calculate correlation
+    corr = df['output_length'].corr(df['processing_time'])
+
+    # Add correlation annotation
+    plt.text(0.05, 0.95, f'Correlation: {corr:.2f}',
+             transform=ax.transAxes, fontsize=10,
+             bbox=dict(facecolor='white', edgecolor='none', alpha=0.7))
+
+    setup_plot(fig, ax, 'Processing Time vs Output Length')
+    ax.set_xlabel('Output Audio Length (seconds)')
+    ax.set_ylabel('Processing Time (seconds)')
+
    plt.savefig('examples/time_vs_output.png', dpi=300, bbox_inches='tight')
    plt.close()

    # Plot 2: Processing Time vs Token Count
-    plt.figure(figsize=(12, 8))
-    sns.scatterplot(data=df, x='tokens', y='processing_time')
-    sns.regplot(data=df, x='tokens', y='processing_time', scatter=False)
-    plt.title('Processing Time vs Token Count')
-    plt.xlabel('Number of Input Tokens')
-    plt.ylabel('Processing Time (seconds)')
+    fig, ax = plt.subplots(figsize=(12, 8))
+
+    # Create scatter plot with custom styling
+    scatter = sns.scatterplot(data=df, x='tokens', y='processing_time',
+                              s=100, alpha=0.6, color='#3498db')
+
+    # Add regression line with confidence interval
+    sns.regplot(data=df, x='tokens', y='processing_time',
+                scatter=False, color='#e74c3c', line_kws={'linewidth': 2})
+
+    # Calculate correlation
+    corr = df['tokens'].corr(df['processing_time'])
+
+    # Add correlation annotation
+    plt.text(0.05, 0.95, f'Correlation: {corr:.2f}',
+             transform=ax.transAxes, fontsize=10,
+             bbox=dict(facecolor='white', edgecolor='none', alpha=0.7))
+
+    setup_plot(fig, ax, 'Processing Time vs Token Count')
+    ax.set_xlabel('Number of Input Tokens')
+    ax.set_ylabel('Processing Time (seconds)')
+
    plt.savefig('examples/time_vs_tokens.png', dpi=300, bbox_inches='tight')
    plt.close()

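Both plots now annotate the correlation between processing time and the x-axis variable. pandas' Series.corr defaults to the Pearson coefficient, so the printed value measures how linear the relationship is; a tiny standalone check:

    import pandas as pd

    s1 = pd.Series([1.0, 2.0, 3.0, 4.0])
    s2 = pd.Series([2.1, 3.9, 6.2, 8.1])
    print(f"{s1.corr(s2):.2f}")   # close to 1.00 for a nearly linear relationship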