Enhance TTS service to handle pauses and trailing newlines in text processing. Update smart_split to preserve newlines and add logic for generating silence chunks during pauses. Improve error handling and logging for audio processing.
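Illustratively, the new smart_split contract and newline handling look like this (a sketch of intent, not code copied from the diff):

    # Sketch: how trailing newlines now flow through the pipeline (illustrative only).
    text = "First paragraph.\nSecond paragraph."
    # smart_split now yields 3-tuples (chunk_text, tokens, pause_duration_s);
    # text chunks carry None as the pause duration and keep their trailing newline:
    #   ("First paragraph.\n", [...tokens...], None)
    #   ("Second paragraph.",  [...tokens...], None)
    # The streaming path then inserts a 0.5 s silence chunk after any chunk_text
    # that ends with "\n".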

Lukin 2025-04-07 13:21:49 +08:00
parent f1fa340494
commit b31f79d8d7
2 changed files with 177 additions and 65 deletions


@@ -110,8 +110,17 @@ def get_sentence_info(
if not sentence:
continue
full = sentence + punct
tokens = process_text_chunk(full)
# Check if the original text segment ended with newline(s) before punctuation
original_segment = sentences[i]
trailing_newlines = ""
match = re.search(r"(\n+)$", original_segment)
if match:
trailing_newlines = match.group(1)
full = sentence + punct + trailing_newlines # Append trailing newlines
# Tokenize without the trailing newlines for accurate TTS processing
tokens = process_text_chunk(sentence + punct)
# Store the full text including newlines for later check
results.append((full, tokens, len(tokens)))
return results
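The trailing-newline capture above reduces to this pattern (a standalone sketch; split_trailing_newlines is a hypothetical helper, not part of the diff):

    import re

    def split_trailing_newlines(segment: str) -> tuple[str, str]:
        """Return (text_without_trailing_newlines, trailing_newlines)."""
        match = re.search(r"(\n+)$", segment)
        trailing = match.group(1) if match else ""
        return segment[: len(segment) - len(trailing)], trailing

    # Example: split_trailing_newlines("Hello.\n\n") -> ("Hello.", "\n\n")
    # Tokens are computed on the text part only, while "full" keeps the newlines.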
@@ -161,28 +170,35 @@ async def smart_split(
if count > max_tokens:
# Yield current chunk if any
if current_chunk:
chunk_text = " ".join(current_chunk)
# Join with space, but preserve original trailing newline of the last sentence if present
last_sentence_original = current_chunk[-1]
chunk_text_joined = " ".join(current_chunk)
if last_sentence_original.endswith("\n"):
chunk_text_joined += "\n" # Preserve the newline marker
chunk_count += 1
logger.debug(
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
f"Yielding text chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({current_count} tokens)"
)
yield chunk_text, current_tokens
yield chunk_text_joined, current_tokens, None # Pass the text with potential trailing newline
current_chunk = []
current_tokens = []
current_count = 0
# Split long sentence on commas
clauses = re.split(r"([,])", sentence)
clause_chunk = []
# Split long sentence on commas (simple approach)
# Keep original sentence text ('sentence' now includes potential trailing newline)
clauses = re.split(r"([,])", sentence.rstrip('\n')) # Split text part only
trailing_newline_in_sentence = "\n" if sentence.endswith("\n") else ""
clause_chunk = [] # Stores original clause text including potential trailing newline
clause_tokens = []
clause_count = 0
for j in range(0, len(clauses), 2):
clause = clauses[j].strip()
# clause = clauses[j].strip() # Don't strip here to preserve internal structure
clause = clauses[j]
comma = clauses[j + 1] if j + 1 < len(clauses) else ""
if not clause:
continue
if not clause.strip(): # Check if clause is just whitespace
continue
full_clause = clause + comma
@@ -200,77 +216,95 @@ async def smart_split(
else:
# Yield clause chunk if we have one
if clause_chunk:
chunk_text = " ".join(clause_chunk)
# Join with space, preserve last clause's potential trailing newline
last_clause_original = clause_chunk[-1]
chunk_text_joined = " ".join(clause_chunk)
if last_clause_original.endswith("\n"):
chunk_text_joined += "\n"
chunk_count += 1
logger.debug(
f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)"
f"Yielding clause chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({clause_count} tokens)"
)
yield chunk_text, clause_tokens
clause_chunk = [full_clause]
clause_tokens = tokens
clause_count = count
yield chunk_text_joined, clause_tokens, None
# Start new clause chunk with original text
clause_chunk = [full_clause + (trailing_newline_in_sentence if j == len(clauses) - 2 else "")]
clause_tokens = clause_token_list
clause_count = clause_token_count
# Don't forget last clause chunk
if clause_chunk:
chunk_text = " ".join(clause_chunk)
chunk_count += 1
logger.debug(
f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)"
)
yield chunk_text, clause_tokens
# Don't forget last clause chunk
if clause_chunk:
# Join with space, preserve last clause's potential trailing newline
last_clause_original = clause_chunk[-1]
chunk_text_joined = " ".join(clause_chunk)
# The trailing newline logic was added when creating the chunk above
#if last_clause_original.endswith("\n"):
# chunk_text_joined += "\n"
chunk_count += 1
logger.debug(
f"Yielding final clause chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({clause_count} tokens)"
)
yield chunk_text_joined, clause_tokens, None
# Regular sentence handling
elif (
current_count >= settings.target_min_tokens
and current_count + count > settings.target_max_tokens
):
# If we have a good sized chunk and adding next sentence exceeds target,
# yield current chunk and start new one
chunk_text = " ".join(current_chunk)
# Yield current chunk and start new one
last_sentence_original = current_chunk[-1]
chunk_text_joined = " ".join(current_chunk)
if last_sentence_original.endswith("\n"):
chunk_text_joined += "\n"
chunk_count += 1
logger.info(
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
f"Yielding text chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({current_count} tokens)"
)
yield chunk_text, current_tokens
current_chunk = [sentence]
yield chunk_text_joined, current_tokens, None
current_chunk = [sentence] # sentence includes potential trailing newline
current_tokens = tokens
current_count = count
elif current_count + count <= settings.target_max_tokens:
# Keep building chunk while under target max
current_chunk.append(sentence)
# Keep building chunk
current_chunk.append(sentence) # sentence includes potential trailing newline
current_tokens.extend(tokens)
current_count += count
elif (
current_count + count <= max_tokens
and current_count < settings.target_min_tokens
):
# Only exceed target max if we haven't reached minimum size yet
current_chunk.append(sentence)
# Exceed target max only if below min size
current_chunk.append(sentence) # sentence includes potential trailing newline
current_tokens.extend(tokens)
current_count += count
else:
# Yield current chunk and start new one
if current_chunk:
chunk_text = " ".join(current_chunk)
last_sentence_original = current_chunk[-1]
chunk_text_joined = " ".join(current_chunk)
if last_sentence_original.endswith("\n"):
chunk_text_joined += "\n"
chunk_count += 1
logger.info(
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
f"Yielding text chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({current_count} tokens)"
)
yield chunk_text, current_tokens
current_chunk = [sentence]
yield chunk_text_joined, current_tokens, None
current_chunk = [sentence] # sentence includes potential trailing newline
current_tokens = tokens
current_count = count
# Don't forget the last chunk
# Yield any remaining text chunk
if current_chunk:
chunk_text = " ".join(current_chunk)
last_sentence_original = current_chunk[-1]
chunk_text_joined = " ".join(current_chunk)
if last_sentence_original.endswith("\n"):
chunk_text_joined += "\n"
chunk_count += 1
logger.info(
f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
f"Yielding final text chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({current_count} tokens)"
)
yield chunk_text, current_tokens
yield chunk_text_joined, current_tokens, None
total_time = time.time() - start_time
logger.info(
f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks"
)
)
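Downstream consumers now unpack a 3-tuple; a minimal sketch of the new contract (the print calls are placeholder actions, not service code):

    async def consume(text: str) -> None:
        # smart_split is the generator changed above; lang_code and
        # normalization_options are omitted here for brevity.
        async for chunk_text, tokens, pause_duration_s in smart_split(text):
            if pause_duration_s is not None and pause_duration_s > 0:
                print(f"silence: {pause_duration_s}s")  # placeholder action
            elif tokens or chunk_text:
                print(f"speech: {chunk_text!r} ({len(tokens)} tokens)")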


@@ -280,18 +280,56 @@ class TTSService:
f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream"
)
# Process text in chunks with smart splitting
async for chunk_text, tokens in smart_split(
# Process text in chunks with smart splitting, handling pauses
async for chunk_text, tokens, pause_duration_s in smart_split(
text,
lang_code=pipeline_lang_code,
normalization_options=normalization_options,
):
try:
# Process audio for chunk
async for chunk_data in self._process_chunk(
chunk_text, # Pass text for Kokoro V1
tokens, # Pass tokens for legacy backends
voice_name, # Pass voice name
if pause_duration_s is not None and pause_duration_s > 0:
# --- Handle Pause Chunk ---
try:
logger.debug(f"Generating {pause_duration_s}s silence chunk")
silence_samples = int(pause_duration_s * settings.sample_rate)
# Use float32 zeros as AudioService will normalize later
silence_audio = np.zeros(silence_samples, dtype=np.float32)
pause_chunk = AudioChunk(audio=silence_audio, word_timestamps=[]) # Empty timestamps for silence
# Convert silence chunk to the target format using AudioService
if output_format:
formatted_pause_chunk = await AudioService.convert_audio(
pause_chunk,
output_format,
writer,
speed=1.0, # Speed doesn't affect silence
chunk_text="", # No text for silence
is_last_chunk=False, # Not the final chunk
trim_audio=False, # Don't trim silence
normalizer=stream_normalizer,
)
if formatted_pause_chunk.output:
yield formatted_pause_chunk
else:
# If no output format (raw audio), yield the raw chunk
# Ensure normalization happens if needed (AudioService handles this)
pause_chunk.audio = stream_normalizer.normalize(pause_chunk.audio)
if len(pause_chunk.audio) > 0: # Only yield if silence is not zero length
yield pause_chunk # Yield raw silence chunk
# Update offset based on silence duration
current_offset += pause_duration_s
chunk_index += 1
except Exception as e:
logger.error(f"Failed to process pause chunk: {str(e)}")
continue
elif tokens or chunk_text:
# --- Handle Text Chunk ---
try:
# Process audio for the text chunk
async for chunk_data in self._process_chunk(
chunk_text, # Pass text for Kokoro V1
tokens, # Pass tokens for legacy backends
voice_name, # Pass voice name
voice_path, # Pass voice path
speed,
writer,
@@ -307,19 +345,48 @@ class TTSService:
timestamp.start_time += current_offset
timestamp.end_time += current_offset
current_offset += len(chunk_data.audio) / 24000
if chunk_data.output is not None:
yield chunk_data
else:
logger.warning(
f"No audio generated for chunk: '{chunk_text[:100]}...'"
f"No audio generated for chunk: '{chunk_text.strip()[:100]}...'"
)
chunk_index += 1
# --- Add pause after newline ---
# Check the original chunk_text passed from smart_split for trailing newline
if chunk_text.endswith('\n'):
newline_pause_s = 0.5
try:
logger.debug(f"Adding {newline_pause_s}s pause after newline.")
silence_samples = int(newline_pause_s * settings.sample_rate)
silence_audio = np.zeros(silence_samples, dtype=np.float32)
pause_chunk = AudioChunk(audio=silence_audio, word_timestamps=[])
if output_format:
formatted_pause_chunk = await AudioService.convert_audio(
pause_chunk, output_format, writer, speed=1.0, chunk_text="",
is_last_chunk=False, trim_audio=False, normalizer=stream_normalizer,
)
if formatted_pause_chunk.output:
yield formatted_pause_chunk
else:
pause_chunk.audio = stream_normalizer.normalize(pause_chunk.audio)
if len(pause_chunk.audio) > 0:
yield pause_chunk
current_offset += newline_pause_s # Add newline pause to offset
except Exception as pause_e:
logger.error(f"Failed to process newline pause chunk: {str(pause_e)}")
# -------------------------------
except Exception as e:
logger.error(
f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}"
f"Failed to process audio for chunk: '{chunk_text.strip()[:100]}...'. Error: {str(e)}"
)
continue
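For reference, the silence generation used in both pause paths comes down to a couple of lines (sketch; the 24 kHz rate is an assumption taken from the current_offset math above, the service reads settings.sample_rate):

    import numpy as np

    sample_rate = 24000          # assumed; settings.sample_rate in the service
    newline_pause_s = 0.5
    silence = np.zeros(int(newline_pause_s * sample_rate), dtype=np.float32)
    # 0.5 s at 24 kHz -> 12000 zero-valued float32 samples; float32 keeps later
    # normalization and format conversion identical to regular audio chunks.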
@@ -374,7 +441,18 @@ class TTSService:
output_format=None,
):
if len(audio_stream_data.audio) > 0:
audio_data_chunks.append(audio_stream_data)
# Ensure we only append chunks with actual audio data
# Raw silence chunks generated for pauses will have audio data (zeros)
# Formatted silence chunks might have empty audio but non-empty output
if len(audio_stream_data.audio) > 0 or (output_format and audio_stream_data.output):
audio_data_chunks.append(audio_stream_data)
if not audio_data_chunks:
# Handle cases where only pauses were present or generation failed
logger.warning("No valid audio chunks generated.")
# Return an empty AudioChunk or raise an error? Returning empty for now.
return AudioChunk(audio=np.array([], dtype=np.int16), word_timestamps=[])
combined_audio_data = AudioChunk.combine(audio_data_chunks)
return combined_audio_data
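The effect of the new guard, as a self-contained sketch (AudioChunkStub is a stand-in for the service's AudioChunk; the real combine step is AudioChunk.combine):

    import numpy as np
    from dataclasses import dataclass, field

    @dataclass
    class AudioChunkStub:  # stand-in for the service's AudioChunk
        audio: np.ndarray
        word_timestamps: list = field(default_factory=list)

    def combine_or_empty(chunks: list[AudioChunkStub]) -> AudioChunkStub:
        if not chunks:
            # Only pauses were present or generation failed: hand back an
            # empty chunk instead of combining an empty list.
            return AudioChunkStub(audio=np.array([], dtype=np.int16))
        return AudioChunkStub(audio=np.concatenate([c.audio for c in chunks]))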
@@ -456,4 +534,4 @@ class TTSService:
except Exception as e:
logger.error(f"Error in phoneme audio generation: {str(e)}")
raise
raise