mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00
feat: enabled support for stitching long outputs in TTS requests
This commit is contained in:
parent
30581129c0
commit
79d5332c8a
5 changed files with 164 additions and 31 deletions
|
@ -21,6 +21,7 @@ class QueueDB:
|
||||||
(id INTEGER PRIMARY KEY AUTOINCREMENT,
|
(id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
text TEXT NOT NULL,
|
text TEXT NOT NULL,
|
||||||
voice TEXT DEFAULT 'af',
|
voice TEXT DEFAULT 'af',
|
||||||
|
stitch_long_output BOOLEAN DEFAULT 1,
|
||||||
status TEXT DEFAULT 'pending',
|
status TEXT DEFAULT 'pending',
|
||||||
output_file TEXT,
|
output_file TEXT,
|
||||||
processing_time REAL,
|
processing_time REAL,
|
||||||
|
@ -37,6 +38,7 @@ class QueueDB:
|
||||||
(id INTEGER PRIMARY KEY AUTOINCREMENT,
|
(id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
text TEXT NOT NULL,
|
text TEXT NOT NULL,
|
||||||
voice TEXT DEFAULT 'af',
|
voice TEXT DEFAULT 'af',
|
||||||
|
stitch_long_output BOOLEAN DEFAULT 1,
|
||||||
status TEXT DEFAULT 'pending',
|
status TEXT DEFAULT 'pending',
|
||||||
output_file TEXT,
|
output_file TEXT,
|
||||||
processing_time REAL,
|
processing_time REAL,
|
||||||
|
@ -44,13 +46,14 @@ class QueueDB:
|
||||||
""")
|
""")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
def add_request(self, text: str, voice: str) -> int:
|
def add_request(self, text: str, voice: str, stitch_long_output: bool = True) -> int:
|
||||||
"""Add a new TTS request to the queue"""
|
"""Add a new TTS request to the queue"""
|
||||||
conn = sqlite3.connect(self.db_path)
|
conn = sqlite3.connect(self.db_path)
|
||||||
try:
|
try:
|
||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
c.execute(
|
c.execute(
|
||||||
"INSERT INTO tts_queue (text, voice) VALUES (?, ?)", (text, voice)
|
"INSERT INTO tts_queue (text, voice, stitch_long_output) VALUES (?, ?, ?)",
|
||||||
|
(text, voice, stitch_long_output)
|
||||||
)
|
)
|
||||||
request_id = c.lastrowid
|
request_id = c.lastrowid
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
@ -59,7 +62,8 @@ class QueueDB:
|
||||||
self._ensure_table_if_needed(conn)
|
self._ensure_table_if_needed(conn)
|
||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
c.execute(
|
c.execute(
|
||||||
"INSERT INTO tts_queue (text, voice) VALUES (?, ?)", (text, voice)
|
"INSERT INTO tts_queue (text, voice, stitch_long_output) VALUES (?, ?, ?)",
|
||||||
|
(text, voice, stitch_long_output)
|
||||||
)
|
)
|
||||||
request_id = c.lastrowid
|
request_id = c.lastrowid
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
@ -73,7 +77,7 @@ class QueueDB:
|
||||||
try:
|
try:
|
||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
c.execute(
|
c.execute(
|
||||||
'SELECT id, text, voice FROM tts_queue WHERE status = "pending" ORDER BY created_at ASC LIMIT 1'
|
'SELECT id, text, voice, stitch_long_output FROM tts_queue WHERE status = "pending" ORDER BY created_at ASC LIMIT 1'
|
||||||
)
|
)
|
||||||
return c.fetchone()
|
return c.fetchone()
|
||||||
except sqlite3.OperationalError: # Table doesn't exist
|
except sqlite3.OperationalError: # Table doesn't exist
|
||||||
|
|
|
@ -7,6 +7,7 @@ class TTSRequest(BaseModel):
|
||||||
text: str
|
text: str
|
||||||
voice: str = "af" # Default voice
|
voice: str = "af" # Default voice
|
||||||
local: bool = False # Whether to save file locally or return bytes
|
local: bool = False # Whether to save file locally or return bytes
|
||||||
|
stitch_long_output: bool = True # Whether to stitch together long outputs
|
||||||
|
|
||||||
|
|
||||||
class TTSResponse(BaseModel):
|
class TTSResponse(BaseModel):
|
||||||
|
@ -19,3 +20,4 @@ class TTSResponse(BaseModel):
|
||||||
class VoicesResponse(BaseModel):
|
class VoicesResponse(BaseModel):
|
||||||
voices: list[str]
|
voices: list[str]
|
||||||
default: str
|
default: str
|
||||||
|
|
||||||
|
|
|
@ -32,7 +32,11 @@ async def create_tts(request: TTSRequest):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Queue the request
|
# Queue the request
|
||||||
request_id = tts_service.create_tts_request(request.text, request.voice)
|
request_id = tts_service.create_tts_request(
|
||||||
|
request.text,
|
||||||
|
request.voice,
|
||||||
|
request.stitch_long_output
|
||||||
|
)
|
||||||
return {
|
return {
|
||||||
"request_id": request_id,
|
"request_id": request_id,
|
||||||
"status": "pending",
|
"status": "pending",
|
||||||
|
|
|
@ -2,11 +2,12 @@ import os
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import io
|
import io
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple, List
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import scipy.io.wavfile as wavfile
|
import scipy.io.wavfile as wavfile
|
||||||
from models import build_model
|
from models import build_model
|
||||||
from kokoro import generate
|
from kokoro import generate, phonemize, tokenize
|
||||||
from ..database.queue import QueueDB
|
from ..database.queue import QueueDB
|
||||||
|
|
||||||
|
|
||||||
|
@ -57,7 +58,52 @@ class TTSService:
|
||||||
self.worker = threading.Thread(target=self._process_queue, daemon=True)
|
self.worker = threading.Thread(target=self._process_queue, daemon=True)
|
||||||
self.worker.start()
|
self.worker.start()
|
||||||
|
|
||||||
def _generate_audio(self, text: str, voice: str) -> Tuple[torch.Tensor, float]:
|
def _find_boundary(self, text: str, max_tokens: int, voice: str, margin: int = 50) -> int:
|
||||||
|
"""Find the closest sentence/clause boundary within token limit"""
|
||||||
|
# Try different boundary markers in order of preference
|
||||||
|
for marker in ['. ', '; ', ', ']:
|
||||||
|
# Look for the last occurrence of marker before max_tokens
|
||||||
|
test_text = text[:max_tokens + margin] # Look a bit beyond the limit
|
||||||
|
last_idx = test_text.rfind(marker)
|
||||||
|
|
||||||
|
if last_idx != -1:
|
||||||
|
# Verify this boundary is within our token limit
|
||||||
|
candidate = text[:last_idx + len(marker)].strip()
|
||||||
|
ps = phonemize(candidate, voice[0])
|
||||||
|
tokens = tokenize(ps)
|
||||||
|
|
||||||
|
if len(tokens) <= max_tokens:
|
||||||
|
return last_idx + len(marker)
|
||||||
|
|
||||||
|
# If no good boundary found, find last whitespace within limit
|
||||||
|
test_text = text[:max_tokens]
|
||||||
|
last_space = test_text.rfind(' ')
|
||||||
|
return last_space if last_space != -1 else max_tokens
|
||||||
|
|
||||||
|
def _split_text(self, text: str, voice: str) -> List[str]:
|
||||||
|
"""Split text into chunks that respect token limits and try to maintain sentence structure"""
|
||||||
|
MAX_TOKENS = 450 # Leave wider margin from 510 limit to account for tokenizer differences
|
||||||
|
chunks = []
|
||||||
|
remaining = text
|
||||||
|
|
||||||
|
while remaining:
|
||||||
|
# If remaining text is within limit, add it as final chunk
|
||||||
|
ps = phonemize(remaining, voice[0])
|
||||||
|
tokens = tokenize(ps)
|
||||||
|
if len(tokens) <= MAX_TOKENS:
|
||||||
|
chunks.append(remaining.strip())
|
||||||
|
break
|
||||||
|
|
||||||
|
# Find best boundary position
|
||||||
|
split_pos = self._find_boundary(remaining, MAX_TOKENS, voice)
|
||||||
|
|
||||||
|
# Add chunk and continue with remaining text
|
||||||
|
chunks.append(remaining[:split_pos].strip())
|
||||||
|
remaining = remaining[split_pos:].strip()
|
||||||
|
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
def _generate_audio(self, text: str, voice: str, stitch_long_output: bool = True) -> Tuple[torch.Tensor, float]:
|
||||||
"""Generate audio and measure processing time"""
|
"""Generate audio and measure processing time"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
@ -65,8 +111,24 @@ class TTSService:
|
||||||
model, device = TTSModel.get_instance()
|
model, device = TTSModel.get_instance()
|
||||||
voicepack = TTSModel.get_voicepack(voice)
|
voicepack = TTSModel.get_voicepack(voice)
|
||||||
|
|
||||||
# Generate audio
|
# Generate audio with or without stitching
|
||||||
audio, _ = generate(model, text, voicepack, lang=voice[0])
|
if stitch_long_output:
|
||||||
|
# Split text if needed and generate audio for each chunk
|
||||||
|
chunks = self._split_text(text, voice)
|
||||||
|
audio_chunks = []
|
||||||
|
|
||||||
|
for chunk in chunks:
|
||||||
|
chunk_audio, _ = generate(model, chunk, voicepack, lang=voice[0])
|
||||||
|
audio_chunks.append(chunk_audio)
|
||||||
|
|
||||||
|
# Concatenate audio chunks
|
||||||
|
if len(audio_chunks) > 1:
|
||||||
|
audio = np.concatenate(audio_chunks)
|
||||||
|
else:
|
||||||
|
audio = audio_chunks[0]
|
||||||
|
else:
|
||||||
|
# Generate single chunk without splitting
|
||||||
|
audio, _ = generate(model, text, voicepack, lang=voice[0])
|
||||||
|
|
||||||
processing_time = time.time() - start_time
|
processing_time = time.time() - start_time
|
||||||
return audio, processing_time
|
return audio, processing_time
|
||||||
|
@ -87,15 +149,15 @@ class TTSService:
|
||||||
while True:
|
while True:
|
||||||
next_request = self.db.get_next_pending()
|
next_request = self.db.get_next_pending()
|
||||||
if next_request:
|
if next_request:
|
||||||
request_id, text, voice = next_request
|
request_id, text, voice, stitch_long_output = next_request
|
||||||
try:
|
try:
|
||||||
# Generate audio and measure time
|
# Generate audio and measure time
|
||||||
audio, processing_time = self._generate_audio(text, voice)
|
audio, processing_time = self._generate_audio(text, voice, stitch_long_output)
|
||||||
|
|
||||||
# Save to file
|
# Save to file
|
||||||
output_file = os.path.join(
|
output_file = os.path.abspath(os.path.join(
|
||||||
self.output_dir, f"speech_{request_id}.wav"
|
self.output_dir, f"speech_{request_id}.wav"
|
||||||
)
|
))
|
||||||
self._save_audio(audio, output_file)
|
self._save_audio(audio, output_file)
|
||||||
|
|
||||||
# Update status with processing time
|
# Update status with processing time
|
||||||
|
@ -124,9 +186,9 @@ class TTSService:
|
||||||
print(f"Error listing voices: {str(e)}")
|
print(f"Error listing voices: {str(e)}")
|
||||||
return voices
|
return voices
|
||||||
|
|
||||||
def create_tts_request(self, text: str, voice: str = "af") -> int:
|
def create_tts_request(self, text: str, voice: str = "af", stitch_long_output: bool = True) -> int:
|
||||||
"""Create a new TTS request and return the request ID"""
|
"""Create a new TTS request and return the request ID"""
|
||||||
return self.db.add_request(text, voice)
|
return self.db.add_request(text, voice, stitch_long_output)
|
||||||
|
|
||||||
def get_request_status(
|
def get_request_status(
|
||||||
self, request_id: int
|
self, request_id: int
|
||||||
|
|
|
@ -79,13 +79,19 @@ def main():
|
||||||
text = f.read()
|
text = f.read()
|
||||||
|
|
||||||
# Create range of sizes up to full text
|
# Create range of sizes up to full text
|
||||||
sizes = [100, 250, 500, 750, 1000, 1500, 2000, 3000, 4000, 5000, 6000, 7000, len(text)]
|
sizes = [100, 250, 500, 750, 1000, 1500, 2000, 3000, 4000, 5000, 6000, 7000]
|
||||||
|
|
||||||
# Process chunks
|
# Process chunks
|
||||||
results = []
|
results = []
|
||||||
|
import random
|
||||||
for size in sizes:
|
for size in sizes:
|
||||||
# Get chunk and count tokens
|
# Get random starting point ensuring we have enough text left
|
||||||
chunk = text[:size]
|
max_start = len(text) - size
|
||||||
|
if max_start > 0:
|
||||||
|
start = random.randint(0, max_start)
|
||||||
|
chunk = text[start:start + size]
|
||||||
|
else:
|
||||||
|
chunk = text[:size]
|
||||||
num_tokens = count_tokens(chunk)
|
num_tokens = count_tokens(chunk)
|
||||||
|
|
||||||
print(f"\nProcessing chunk with {num_tokens} tokens ({size} chars):")
|
print(f"\nProcessing chunk with {num_tokens} tokens ({size} chars):")
|
||||||
|
@ -106,23 +112,78 @@ def main():
|
||||||
# Create DataFrame for plotting
|
# Create DataFrame for plotting
|
||||||
df = pd.DataFrame(results)
|
df = pd.DataFrame(results)
|
||||||
|
|
||||||
|
# Set the style
|
||||||
|
sns.set_theme(style="darkgrid", palette="husl", font_scale=1.1)
|
||||||
|
|
||||||
|
# Common plot settings
|
||||||
|
def setup_plot(fig, ax, title):
|
||||||
|
# Improve grid
|
||||||
|
ax.grid(True, linestyle='--', alpha=0.7)
|
||||||
|
|
||||||
|
# Set title and labels with better fonts
|
||||||
|
ax.set_title(title, pad=20, fontsize=16, fontweight='bold')
|
||||||
|
ax.set_xlabel(ax.get_xlabel(), fontsize=12, fontweight='medium')
|
||||||
|
ax.set_ylabel(ax.get_ylabel(), fontsize=12, fontweight='medium')
|
||||||
|
|
||||||
|
# Improve tick labels
|
||||||
|
ax.tick_params(labelsize=10)
|
||||||
|
|
||||||
|
# Add subtle spines
|
||||||
|
for spine in ax.spines.values():
|
||||||
|
spine.set_color('#666666')
|
||||||
|
spine.set_linewidth(0.5)
|
||||||
|
|
||||||
|
return fig, ax
|
||||||
|
|
||||||
# Plot 1: Processing Time vs Output Length
|
# Plot 1: Processing Time vs Output Length
|
||||||
plt.figure(figsize=(12, 8))
|
fig, ax = plt.subplots(figsize=(12, 8))
|
||||||
sns.scatterplot(data=df, x='output_length', y='processing_time')
|
|
||||||
sns.regplot(data=df, x='output_length', y='processing_time', scatter=False)
|
# Create scatter plot with custom styling
|
||||||
plt.title('Processing Time vs Output Length')
|
scatter = sns.scatterplot(data=df, x='output_length', y='processing_time',
|
||||||
plt.xlabel('Output Audio Length (seconds)')
|
s=100, alpha=0.6, color='#2ecc71')
|
||||||
plt.ylabel('Processing Time (seconds)')
|
|
||||||
|
# Add regression line with confidence interval
|
||||||
|
sns.regplot(data=df, x='output_length', y='processing_time',
|
||||||
|
scatter=False, color='#e74c3c', line_kws={'linewidth': 2})
|
||||||
|
|
||||||
|
# Calculate correlation
|
||||||
|
corr = df['output_length'].corr(df['processing_time'])
|
||||||
|
|
||||||
|
# Add correlation annotation
|
||||||
|
plt.text(0.05, 0.95, f'Correlation: {corr:.2f}',
|
||||||
|
transform=ax.transAxes, fontsize=10,
|
||||||
|
bbox=dict(facecolor='white', edgecolor='none', alpha=0.7))
|
||||||
|
|
||||||
|
setup_plot(fig, ax, 'Processing Time vs Output Length')
|
||||||
|
ax.set_xlabel('Output Audio Length (seconds)')
|
||||||
|
ax.set_ylabel('Processing Time (seconds)')
|
||||||
|
|
||||||
plt.savefig('examples/time_vs_output.png', dpi=300, bbox_inches='tight')
|
plt.savefig('examples/time_vs_output.png', dpi=300, bbox_inches='tight')
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
# Plot 2: Processing Time vs Token Count
|
# Plot 2: Processing Time vs Token Count
|
||||||
plt.figure(figsize=(12, 8))
|
fig, ax = plt.subplots(figsize=(12, 8))
|
||||||
sns.scatterplot(data=df, x='tokens', y='processing_time')
|
|
||||||
sns.regplot(data=df, x='tokens', y='processing_time', scatter=False)
|
# Create scatter plot with custom styling
|
||||||
plt.title('Processing Time vs Token Count')
|
scatter = sns.scatterplot(data=df, x='tokens', y='processing_time',
|
||||||
plt.xlabel('Number of Input Tokens')
|
s=100, alpha=0.6, color='#3498db')
|
||||||
plt.ylabel('Processing Time (seconds)')
|
|
||||||
|
# Add regression line with confidence interval
|
||||||
|
sns.regplot(data=df, x='tokens', y='processing_time',
|
||||||
|
scatter=False, color='#e74c3c', line_kws={'linewidth': 2})
|
||||||
|
|
||||||
|
# Calculate correlation
|
||||||
|
corr = df['tokens'].corr(df['processing_time'])
|
||||||
|
|
||||||
|
# Add correlation annotation
|
||||||
|
plt.text(0.05, 0.95, f'Correlation: {corr:.2f}',
|
||||||
|
transform=ax.transAxes, fontsize=10,
|
||||||
|
bbox=dict(facecolor='white', edgecolor='none', alpha=0.7))
|
||||||
|
|
||||||
|
setup_plot(fig, ax, 'Processing Time vs Token Count')
|
||||||
|
ax.set_xlabel('Number of Input Tokens')
|
||||||
|
ax.set_ylabel('Processing Time (seconds)')
|
||||||
|
|
||||||
plt.savefig('examples/time_vs_tokens.png', dpi=300, bbox_inches='tight')
|
plt.savefig('examples/time_vs_tokens.png', dpi=300, bbox_inches='tight')
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue