Enhance TTS service to handle pauses and trailing newlines in text processing. Update smart_split to preserve newlines and add logic for generating silence chunks during pauses. Improve error handling and logging for audio processing.

This commit is contained in:
Lukin 2025-04-07 13:21:49 +08:00
parent f1fa340494
commit b31f79d8d7
2 changed files with 177 additions and 65 deletions

View file

@@ -110,8 +110,17 @@ def get_sentence_info(
if not sentence: if not sentence:
continue continue
full = sentence + punct # Check if the original text segment ended with newline(s) before punctuation
tokens = process_text_chunk(full) original_segment = sentences[i]
trailing_newlines = ""
match = re.search(r"(\n+)$", original_segment)
if match:
trailing_newlines = match.group(1)
full = sentence + punct + trailing_newlines # Append trailing newlines
# Tokenize without the trailing newlines for accurate TTS processing
tokens = process_text_chunk(sentence + punct)
# Store the full text including newlines for later check
results.append((full, tokens, len(tokens))) results.append((full, tokens, len(tokens)))
return results return results
@@ -161,28 +170,35 @@ async def smart_split(
if count > max_tokens: if count > max_tokens:
# Yield current chunk if any # Yield current chunk if any
if current_chunk: if current_chunk:
chunk_text = " ".join(current_chunk) # Join with space, but preserve original trailing newline of the last sentence if present
last_sentence_original = current_chunk[-1]
chunk_text_joined = " ".join(current_chunk)
if last_sentence_original.endswith("\n"):
chunk_text_joined += "\n" # Preserve the newline marker
chunk_count += 1 chunk_count += 1
logger.debug( logger.debug(
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)" f"Yielding text chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({current_count} tokens)"
) )
yield chunk_text, current_tokens yield chunk_text_joined, current_tokens, None # Pass the text with potential trailing newline
current_chunk = [] current_chunk = []
current_tokens = [] current_tokens = []
current_count = 0 current_count = 0
# Split long sentence on commas # Split long sentence on commas (simple approach)
clauses = re.split(r"([,])", sentence) # Keep original sentence text ('sentence' now includes potential trailing newline)
clause_chunk = [] clauses = re.split(r"([,])", sentence.rstrip('\n')) # Split text part only
trailing_newline_in_sentence = "\n" if sentence.endswith("\n") else ""
clause_chunk = [] # Stores original clause text including potential trailing newline
clause_tokens = [] clause_tokens = []
clause_count = 0 clause_count = 0
for j in range(0, len(clauses), 2): for j in range(0, len(clauses), 2):
clause = clauses[j].strip() # clause = clauses[j].strip() # Don't strip here to preserve internal structure
clause = clauses[j]
comma = clauses[j + 1] if j + 1 < len(clauses) else "" comma = clauses[j + 1] if j + 1 < len(clauses) else ""
if not clause: if not clause.strip(): # Check if clause is just whitespace
continue
full_clause = clause + comma full_clause = clause + comma
@@ -200,77 +216,95 @@ async def smart_split(
else: else:
# Yield clause chunk if we have one # Yield clause chunk if we have one
if clause_chunk: if clause_chunk:
chunk_text = " ".join(clause_chunk) # Join with space, preserve last clause's potential trailing newline
last_clause_original = clause_chunk[-1]
chunk_text_joined = " ".join(clause_chunk)
if last_clause_original.endswith("\n"):
chunk_text_joined += "\n"
chunk_count += 1 chunk_count += 1
logger.debug( logger.debug(
f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)" f"Yielding clause chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({clause_count} tokens)"
) )
yield chunk_text, clause_tokens yield chunk_text_joined, clause_tokens, None
clause_chunk = [full_clause] # Start new clause chunk with original text
clause_tokens = tokens clause_chunk = [full_clause + (trailing_newline_in_sentence if j == len(clauses) - 2 else "")]
clause_count = count clause_tokens = clause_token_list
clause_count = clause_token_count
# Don't forget last clause chunk # Don't forget last clause chunk
if clause_chunk: if clause_chunk:
chunk_text = " ".join(clause_chunk) # Join with space, preserve last clause's potential trailing newline
chunk_count += 1 last_clause_original = clause_chunk[-1]
logger.debug( chunk_text_joined = " ".join(clause_chunk)
f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)" # The trailing newline logic was added when creating the chunk above
) #if last_clause_original.endswith("\n"):
yield chunk_text, clause_tokens # chunk_text_joined += "\n"
# Regular sentence handling chunk_count += 1
elif ( logger.debug(
f"Yielding final clause chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({clause_count} tokens)"
)
yield chunk_text_joined, clause_tokens, None
current_count >= settings.target_min_tokens current_count >= settings.target_min_tokens
and current_count + count > settings.target_max_tokens and current_count + count > settings.target_max_tokens
): ):
# If we have a good sized chunk and adding next sentence exceeds target, # If we have a good sized chunk and adding next sentence exceeds target,
# yield current chunk and start new one # Yield current chunk and start new one
chunk_text = " ".join(current_chunk) last_sentence_original = current_chunk[-1]
chunk_text_joined = " ".join(current_chunk)
if last_sentence_original.endswith("\n"):
chunk_text_joined += "\n"
chunk_count += 1 chunk_count += 1
logger.info( logger.info(
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)" f"Yielding text chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({current_count} tokens)"
) )
yield chunk_text, current_tokens yield chunk_text_joined, current_tokens, None
current_chunk = [sentence] current_chunk = [sentence] # sentence includes potential trailing newline
current_tokens = tokens current_tokens = tokens
current_count = count current_count = count
elif current_count + count <= settings.target_max_tokens: elif current_count + count <= settings.target_max_tokens:
# Keep building chunk while under target max # Keep building chunk
current_chunk.append(sentence) current_chunk.append(sentence) # sentence includes potential trailing newline
current_tokens.extend(tokens) current_tokens.extend(tokens)
current_count += count current_count += count
elif ( elif (
current_count + count <= max_tokens current_count + count <= max_tokens
and current_count < settings.target_min_tokens and current_count < settings.target_min_tokens
): ):
# Only exceed target max if we haven't reached minimum size yet # Exceed target max only if below min size
current_chunk.append(sentence) current_chunk.append(sentence) # sentence includes potential trailing newline
current_tokens.extend(tokens) current_tokens.extend(tokens)
current_count += count current_count += count
else: else:
# Yield current chunk and start new one # Yield current chunk and start new one
if current_chunk: if current_chunk:
chunk_text = " ".join(current_chunk) last_sentence_original = current_chunk[-1]
chunk_text_joined = " ".join(current_chunk)
if last_sentence_original.endswith("\n"):
chunk_text_joined += "\n"
chunk_count += 1 chunk_count += 1
logger.info( logger.info(
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)" f"Yielding text chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({current_count} tokens)"
) )
yield chunk_text, current_tokens yield chunk_text_joined, current_tokens, None
current_chunk = [sentence] current_chunk = [sentence] # sentence includes potential trailing newline
current_tokens = tokens current_tokens = tokens
current_count = count current_count = count
# Don't forget the last chunk # Yield any remaining text chunk
if current_chunk: if current_chunk:
chunk_text = " ".join(current_chunk) last_sentence_original = current_chunk[-1]
chunk_text_joined = " ".join(current_chunk)
if last_sentence_original.endswith("\n"):
chunk_text_joined += "\n"
chunk_count += 1 chunk_count += 1
logger.info( logger.info(
f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)" f"Yielding final text chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({current_count} tokens)"
) )
yield chunk_text, current_tokens yield chunk_text_joined, current_tokens, None
total_time = time.time() - start_time total_time = time.time() - start_time
logger.info( logger.info(
f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks" f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks"
) )

View file

@@ -280,18 +280,56 @@ class TTSService:
f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream" f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream"
) )
# Process text in chunks with smart splitting # Process text in chunks with smart splitting, handling pauses
async for chunk_text, tokens in smart_split( async for chunk_text, tokens, pause_duration_s in smart_split(
text, text,
lang_code=pipeline_lang_code, lang_code=pipeline_lang_code,
normalization_options=normalization_options, normalization_options=normalization_options,
): ):
try: if pause_duration_s is not None and pause_duration_s > 0:
# Process audio for chunk # --- Handle Pause Chunk ---
async for chunk_data in self._process_chunk( try:
chunk_text, # Pass text for Kokoro V1 logger.debug(f"Generating {pause_duration_s}s silence chunk")
tokens, # Pass tokens for legacy backends silence_samples = int(pause_duration_s * settings.sample_rate)
voice_name, # Pass voice name # Use float32 zeros as AudioService will normalize later
silence_audio = np.zeros(silence_samples, dtype=np.float32)
pause_chunk = AudioChunk(audio=silence_audio, word_timestamps=[]) # Empty timestamps for silence
# Convert silence chunk to the target format using AudioService
if output_format:
formatted_pause_chunk = await AudioService.convert_audio(
pause_chunk,
output_format,
writer,
speed=1.0, # Speed doesn't affect silence
chunk_text="", # No text for silence
is_last_chunk=False, # Not the final chunk
trim_audio=False, # Don't trim silence
normalizer=stream_normalizer,
)
if formatted_pause_chunk.output:
yield formatted_pause_chunk
else:
# If no output format (raw audio), yield the raw chunk
# Ensure normalization happens if needed (AudioService handles this)
pause_chunk.audio = stream_normalizer.normalize(pause_chunk.audio)
yield pause_chunk # Yield raw silence chunk
# Update offset based on silence duration
current_offset += pause_duration_s
chunk_index += 1
except Exception as e:
logger.error(f"Failed to process pause chunk: {str(e)}")
continue
elif tokens or chunk_text:
# --- Handle Text Chunk ---
try:
# Process audio for the text chunk
async for chunk_data in self._process_chunk(
chunk_text, # Pass text for Kokoro V1
tokens, # Pass tokens for legacy backends
voice_name, # Pass voice name
voice_path, # Pass voice path voice_path, # Pass voice path
speed, speed,
writer, writer,
@@ -307,19 +345,48 @@ class TTSService:
timestamp.start_time += current_offset timestamp.start_time += current_offset
timestamp.end_time += current_offset timestamp.end_time += current_offset
current_offset += len(chunk_data.audio) / 24000
if chunk_data.output is not None:
yield chunk_data
else: else:
logger.warning( # If no output format (raw audio), yield the raw chunk
f"No audio generated for chunk: '{chunk_text[:100]}...'" # Ensure normalization happens if needed (AudioService handles this)
) pause_chunk.audio = stream_normalizer.normalize(pause_chunk.audio)
if len(pause_chunk.audio) > 0: # Only yield if silence is not zero length
yield pause_chunk # Yield raw silence chunk
# Update offset based on silence duration
f"No audio generated for chunk: '{chunk_text.strip()[:100]}...'"
chunk_index += 1 chunk_index += 1
except Exception as e: # --- Add pause after newline ---
# Check the original chunk_text passed from smart_split for trailing newline
if chunk_text.endswith('\n'):
newline_pause_s = 0.5
try:
logger.debug(f"Adding {newline_pause_s}s pause after newline.")
silence_samples = int(newline_pause_s * settings.sample_rate)
silence_audio = np.zeros(silence_samples, dtype=np.float32)
pause_chunk = AudioChunk(audio=silence_audio, word_timestamps=[])
if output_format:
formatted_pause_chunk = await AudioService.convert_audio(
pause_chunk, output_format, writer, speed=1.0, chunk_text="",
is_last_chunk=False, trim_audio=False, normalizer=stream_normalizer,
)
if formatted_pause_chunk.output:
yield formatted_pause_chunk
else:
pause_chunk.audio = stream_normalizer.normalize(pause_chunk.audio)
if len(pause_chunk.audio) > 0:
yield pause_chunk
current_offset += newline_pause_s # Add newline pause to offset
except Exception as pause_e:
logger.error(f"Failed to process newline pause chunk: {str(pause_e)}")
# -------------------------------
except Exception as e:
logger.error( logger.error(
f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}" f"Failed to process audio for chunk: '{chunk_text.strip()[:100]}...'. Error: {str(e)}"
) )
continue continue
@@ -374,7 +441,18 @@ class TTSService:
output_format=None, output_format=None,
): ):
if len(audio_stream_data.audio) > 0: if len(audio_stream_data.audio) > 0:
audio_data_chunks.append(audio_stream_data) # Ensure we only append chunks with actual audio data
# Raw silence chunks generated for pauses will have audio data (zeros)
# Formatted silence chunks might have empty audio but non-empty output
if len(audio_stream_data.audio) > 0 or (output_format and audio_stream_data.output):
audio_data_chunks.append(audio_stream_data)
if not audio_data_chunks:
# Handle cases where only pauses were present or generation failed
logger.warning("No valid audio chunks generated.")
# Return an empty AudioChunk or raise an error? Returning empty for now.
return AudioChunk(audio=np.array([], dtype=np.int16), word_timestamps=[])
combined_audio_data = AudioChunk.combine(audio_data_chunks) combined_audio_data = AudioChunk.combine(audio_data_chunks)
return combined_audio_data return combined_audio_data
@@ -456,4 +534,4 @@ class TTSService:
except Exception as e: except Exception as e:
logger.error(f"Error in phoneme audio generation: {str(e)}") logger.error(f"Error in phoneme audio generation: {str(e)}")
raise raise