Mirror of https://github.com/remsky/Kokoro-FastAPI.git
Synced 2025-08-05 16:48:53 +00:00
Enhance TTS service to handle pauses and trailing newlines in text processing. Updated smart_split to preserve newlines and added logic for generating silence chunks during pauses. Improved error handling and logging for audio processing.
parent f1fa340494
commit b31f79d8d7
2 changed files with 177 additions and 65 deletions
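For orientation before the diff: the pause handling added here builds silence as a zero-filled float32 buffer whose length is the pause duration multiplied by the sample rate. A minimal sketch of that arithmetic, assuming the 24 kHz rate that the old code hard-coded (make_silence is a hypothetical helper for illustration, not code from this commit):

import numpy as np

SAMPLE_RATE = 24000  # assumed; the new code reads settings.sample_rate

def make_silence(duration_s: float) -> np.ndarray:
    # Zero-filled float32 buffer, matching the pause chunks added below
    return np.zeros(int(duration_s * SAMPLE_RATE), dtype=np.float32)

# The 0.5 s pause inserted after a newline is 12,000 zero samples at 24 kHz
assert len(make_silence(0.5)) == 12000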
@@ -110,8 +110,17 @@ def get_sentence_info(
         if not sentence:
             continue
 
-        full = sentence + punct
-        tokens = process_text_chunk(full)
+        # Check if the original text segment ended with newline(s) before punctuation
+        original_segment = sentences[i]
+        trailing_newlines = ""
+        match = re.search(r"(\n+)$", original_segment)
+        if match:
+            trailing_newlines = match.group(1)
+
+        full = sentence + punct + trailing_newlines  # Append trailing newlines
+        # Tokenize without the trailing newlines for accurate TTS processing
+        tokens = process_text_chunk(sentence + punct)
+        # Store the full text including newlines for later check
         results.append((full, tokens, len(tokens)))
 
     return results
@@ -161,28 +170,35 @@ async def smart_split(
         if count > max_tokens:
             # Yield current chunk if any
             if current_chunk:
-                chunk_text = " ".join(current_chunk)
+                # Join with space, but preserve original trailing newline of the last sentence if present
+                last_sentence_original = current_chunk[-1]
+                chunk_text_joined = " ".join(current_chunk)
+                if last_sentence_original.endswith("\n"):
+                    chunk_text_joined += "\n"  # Preserve the newline marker
+
                 chunk_count += 1
                 logger.debug(
-                    f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
+                    f"Yielding text chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({current_count} tokens)"
                 )
-                yield chunk_text, current_tokens
+                yield chunk_text_joined, current_tokens, None  # Pass the text with potential trailing newline
                 current_chunk = []
                 current_tokens = []
                 current_count = 0
 
-            # Split long sentence on commas
-            clauses = re.split(r"([,])", sentence)
-            clause_chunk = []
+            # Split long sentence on commas (simple approach)
+            # Keep original sentence text ('sentence' now includes potential trailing newline)
+            clauses = re.split(r"([,])", sentence.rstrip('\n'))  # Split text part only
+            trailing_newline_in_sentence = "\n" if sentence.endswith("\n") else ""
+            clause_chunk = []  # Stores original clause text including potential trailing newline
             clause_tokens = []
             clause_count = 0
 
             for j in range(0, len(clauses), 2):
-                clause = clauses[j].strip()
+                # clause = clauses[j].strip() # Don't strip here to preserve internal structure
+                clause = clauses[j]
                 comma = clauses[j + 1] if j + 1 < len(clauses) else ""
 
-                if not clause:
+                if not clause.strip():  # Check if clause is just whitespace
                     continue
 
                 full_clause = clause + comma
 
@@ -200,75 +216,93 @@ async def smart_split(
                 else:
                     # Yield clause chunk if we have one
                     if clause_chunk:
-                        chunk_text = " ".join(clause_chunk)
+                        # Join with space, preserve last clause's potential trailing newline
+                        last_clause_original = clause_chunk[-1]
+                        chunk_text_joined = " ".join(clause_chunk)
+                        if last_clause_original.endswith("\n"):
+                            chunk_text_joined += "\n"
                         chunk_count += 1
                         logger.debug(
-                            f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)"
+                            f"Yielding clause chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({clause_count} tokens)"
                         )
-                        yield chunk_text, clause_tokens
-                    clause_chunk = [full_clause]
-                    clause_tokens = tokens
-                    clause_count = count
+                        yield chunk_text_joined, clause_tokens, None
+                    # Start new clause chunk with original text
+                    clause_chunk = [full_clause + (trailing_newline_in_sentence if j == len(clauses) - 2 else "")]
+                    clause_tokens = clause_token_list
+                    clause_count = clause_token_count
 
             # Don't forget last clause chunk
             if clause_chunk:
-                chunk_text = " ".join(clause_chunk)
+                # Join with space, preserve last clause's potential trailing newline
+                last_clause_original = clause_chunk[-1]
+                chunk_text_joined = " ".join(clause_chunk)
+                # The trailing newline logic was added when creating the chunk above
+                #if last_clause_original.endswith("\n"):
+                #    chunk_text_joined += "\n"
                 chunk_count += 1
                 logger.debug(
-                    f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)"
+                    f"Yielding final clause chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({clause_count} tokens)"
                 )
-                yield chunk_text, clause_tokens
+                yield chunk_text_joined, clause_tokens, None
 
-        # Regular sentence handling
         elif (
             current_count >= settings.target_min_tokens
             and current_count + count > settings.target_max_tokens
         ):
             # If we have a good sized chunk and adding next sentence exceeds target,
-            # yield current chunk and start new one
-            chunk_text = " ".join(current_chunk)
+            # Yield current chunk and start new one
+            last_sentence_original = current_chunk[-1]
+            chunk_text_joined = " ".join(current_chunk)
+            if last_sentence_original.endswith("\n"):
+                chunk_text_joined += "\n"
             chunk_count += 1
             logger.info(
-                f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
+                f"Yielding text chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({current_count} tokens)"
            )
-            yield chunk_text, current_tokens
-            current_chunk = [sentence]
+            yield chunk_text_joined, current_tokens, None
+            current_chunk = [sentence]  # sentence includes potential trailing newline
             current_tokens = tokens
             current_count = count
         elif current_count + count <= settings.target_max_tokens:
-            # Keep building chunk while under target max
-            current_chunk.append(sentence)
+            # Keep building chunk
+            current_chunk.append(sentence)  # sentence includes potential trailing newline
             current_tokens.extend(tokens)
             current_count += count
         elif (
             current_count + count <= max_tokens
             and current_count < settings.target_min_tokens
         ):
-            # Only exceed target max if we haven't reached minimum size yet
-            current_chunk.append(sentence)
+            # Exceed target max only if below min size
+            current_chunk.append(sentence)  # sentence includes potential trailing newline
             current_tokens.extend(tokens)
             current_count += count
         else:
             # Yield current chunk and start new one
             if current_chunk:
-                chunk_text = " ".join(current_chunk)
+                last_sentence_original = current_chunk[-1]
+                chunk_text_joined = " ".join(current_chunk)
+                if last_sentence_original.endswith("\n"):
+                    chunk_text_joined += "\n"
                 chunk_count += 1
                 logger.info(
-                    f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
+                    f"Yielding text chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({current_count} tokens)"
                 )
-                yield chunk_text, current_tokens
-            current_chunk = [sentence]
+                yield chunk_text_joined, current_tokens, None
+            current_chunk = [sentence]  # sentence includes potential trailing newline
             current_tokens = tokens
             current_count = count
 
-    # Don't forget the last chunk
+    # Yield any remaining text chunk
     if current_chunk:
-        chunk_text = " ".join(current_chunk)
+        last_sentence_original = current_chunk[-1]
+        chunk_text_joined = " ".join(current_chunk)
+        if last_sentence_original.endswith("\n"):
+            chunk_text_joined += "\n"
        chunk_count += 1
        logger.info(
-            f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
+            f"Yielding final text chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({current_count} tokens)"
        )
-        yield chunk_text, current_tokens
+        yield chunk_text_joined, current_tokens, None
 
     total_time = time.time() - start_time
     logger.info(
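Bridging the two files: smart_split (above) keeps a sentence's trailing newlines on the yielded chunk text, and TTSService (below) turns that trailing-newline marker into a 0.5 s pause. A small illustration of the newline detection introduced in get_sentence_info (split_trailing_newlines is a hypothetical helper for demonstration, not code from this commit):

import re

def split_trailing_newlines(segment: str):
    # Same regex as the get_sentence_info change: capture trailing newlines
    match = re.search(r"(\n+)$", segment)
    trailing = match.group(1) if match else ""
    return segment[: len(segment) - len(trailing)], trailing

assert split_trailing_newlines("Hello.\n\n") == ("Hello.", "\n\n")
assert split_trailing_newlines("Hello.") == ("Hello.", "")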
@@ -280,14 +280,52 @@ class TTSService:
                 f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream"
             )
 
-            # Process text in chunks with smart splitting
-            async for chunk_text, tokens in smart_split(
+            # Process text in chunks with smart splitting, handling pauses
+            async for chunk_text, tokens, pause_duration_s in smart_split(
                 text,
                 lang_code=pipeline_lang_code,
                 normalization_options=normalization_options,
             ):
+                if pause_duration_s is not None and pause_duration_s > 0:
+                    # --- Handle Pause Chunk ---
                     try:
-                    # Process audio for chunk
+                        logger.debug(f"Generating {pause_duration_s}s silence chunk")
+                        silence_samples = int(pause_duration_s * settings.sample_rate)
+                        # Use float32 zeros as AudioService will normalize later
+                        silence_audio = np.zeros(silence_samples, dtype=np.float32)
+                        pause_chunk = AudioChunk(audio=silence_audio, word_timestamps=[])  # Empty timestamps for silence
+
+                        # Convert silence chunk to the target format using AudioService
+                        if output_format:
+                            formatted_pause_chunk = await AudioService.convert_audio(
+                                pause_chunk,
+                                output_format,
+                                writer,
+                                speed=1.0,  # Speed doesn't affect silence
+                                chunk_text="",  # No text for silence
+                                is_last_chunk=False,  # Not the final chunk
+                                trim_audio=False,  # Don't trim silence
+                                normalizer=stream_normalizer,
+                            )
+                            if formatted_pause_chunk.output:
+                                yield formatted_pause_chunk
+                        else:
+                            # If no output format (raw audio), yield the raw chunk
+                            # Ensure normalization happens if needed (AudioService handles this)
+                            pause_chunk.audio = stream_normalizer.normalize(pause_chunk.audio)
+                            yield pause_chunk  # Yield raw silence chunk
+
+                        # Update offset based on silence duration
+                        current_offset += pause_duration_s
+                        chunk_index += 1
+
+                    except Exception as e:
+                        logger.error(f"Failed to process pause chunk: {str(e)}")
+                        continue
+                elif tokens or chunk_text:
+                    # --- Handle Text Chunk ---
+                    try:
+                        # Process audio for the text chunk
                         async for chunk_data in self._process_chunk(
                             chunk_text,  # Pass text for Kokoro V1
                             tokens,  # Pass tokens for legacy backends
@@ -307,19 +345,48 @@ class TTSService:
                                     timestamp.start_time += current_offset
                                     timestamp.end_time += current_offset
 
-                            current_offset += len(chunk_data.audio) / 24000
-
-                            if chunk_data.output is not None:
-                                yield chunk_data
-
                             else:
-                                logger.warning(
-                                    f"No audio generated for chunk: '{chunk_text[:100]}...'"
-                                )
+                                # If no output format (raw audio), yield the raw chunk
+                                # Ensure normalization happens if needed (AudioService handles this)
+                                pause_chunk.audio = stream_normalizer.normalize(pause_chunk.audio)
+                                if len(pause_chunk.audio) > 0:  # Only yield if silence is not zero length
+                                    yield pause_chunk  # Yield raw silence chunk
+
+                                # Update offset based on silence duration
+                                f"No audio generated for chunk: '{chunk_text.strip()[:100]}...'"
                             chunk_index += 1
+                        # --- Add pause after newline ---
+                        # Check the original chunk_text passed from smart_split for trailing newline
+                        if chunk_text.endswith('\n'):
+                            newline_pause_s = 0.5
+                            try:
+                                logger.debug(f"Adding {newline_pause_s}s pause after newline.")
+                                silence_samples = int(newline_pause_s * settings.sample_rate)
+                                silence_audio = np.zeros(silence_samples, dtype=np.float32)
+                                pause_chunk = AudioChunk(audio=silence_audio, word_timestamps=[])
+
+                                if output_format:
+                                    formatted_pause_chunk = await AudioService.convert_audio(
+                                        pause_chunk, output_format, writer, speed=1.0, chunk_text="",
+                                        is_last_chunk=False, trim_audio=False, normalizer=stream_normalizer,
+                                    )
+                                    if formatted_pause_chunk.output:
+                                        yield formatted_pause_chunk
+                                else:
+                                    pause_chunk.audio = stream_normalizer.normalize(pause_chunk.audio)
+                                    if len(pause_chunk.audio) > 0:
+                                        yield pause_chunk
+
+                                current_offset += newline_pause_s  # Add newline pause to offset
+
+                            except Exception as pause_e:
+                                logger.error(f"Failed to process newline pause chunk: {str(pause_e)}")
+                            # -------------------------------
 
                     except Exception as e:
                         logger.error(
-                            f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}"
+                            f"Failed to process audio for chunk: '{chunk_text.strip()[:100]}...'. Error: {str(e)}"
                         )
                         continue
@@ -374,8 +441,19 @@ class TTSService:
                 output_format=None,
             ):
-                if len(audio_stream_data.audio) > 0:
+                # Ensure we only append chunks with actual audio data
+                # Raw silence chunks generated for pauses will have audio data (zeros)
+                # Formatted silence chunks might have empty audio but non-empty output
+                if len(audio_stream_data.audio) > 0 or (output_format and audio_stream_data.output):
                     audio_data_chunks.append(audio_stream_data)
 
+            if not audio_data_chunks:
+                # Handle cases where only pauses were present or generation failed
+                logger.warning("No valid audio chunks generated.")
+                # Return an empty AudioChunk or raise an error? Returning empty for now.
+                return AudioChunk(audio=np.array([], dtype=np.int16), word_timestamps=[])
+
             combined_audio_data = AudioChunk.combine(audio_data_chunks)
             return combined_audio_data
         except Exception as e:
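Taken together, smart_split now yields three-tuples of (chunk_text, tokens, pause_duration_s) instead of pairs. A hedged sketch of a consumer, mirroring the TTSService loop above (assumes smart_split is imported from this repo's text-processing module and make_silence from the sketch near the top; the lang code and control flow are illustrative only):

async def consume(text: str):
    # pause_duration_s is None for ordinary text chunks and a positive
    # float when smart_split emits an explicit pause chunk.
    async for chunk_text, tokens, pause_duration_s in smart_split(
        text,
        lang_code="a",  # assumed Kokoro-style language code
        normalization_options=None,
    ):
        if pause_duration_s is not None and pause_duration_s > 0:
            # Pause chunk: emit silence without calling the TTS backend
            yield make_silence(pause_duration_s)
        elif tokens or chunk_text:
            # Text chunk: synthesize chunk_text (Kokoro V1) or tokens (legacy)
            ...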