mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
Enhance TTS service to handle pauses and trailing newlines in text processing. Updated smart_split to preserve newlines and added logic for generating silence chunks during pauses. Improved error handling and logging for audio processing.
This commit is contained in:
parent
f1fa340494
commit
b31f79d8d7
2 changed files with 177 additions and 65 deletions
|
@ -110,8 +110,17 @@ def get_sentence_info(
|
|||
if not sentence:
|
||||
continue
|
||||
|
||||
full = sentence + punct
|
||||
tokens = process_text_chunk(full)
|
||||
# Check if the original text segment ended with newline(s) before punctuation
|
||||
original_segment = sentences[i]
|
||||
trailing_newlines = ""
|
||||
match = re.search(r"(\n+)$", original_segment)
|
||||
if match:
|
||||
trailing_newlines = match.group(1)
|
||||
|
||||
full = sentence + punct + trailing_newlines # Append trailing newlines
|
||||
# Tokenize without the trailing newlines for accurate TTS processing
|
||||
tokens = process_text_chunk(sentence + punct)
|
||||
# Store the full text including newlines for later check
|
||||
results.append((full, tokens, len(tokens)))
|
||||
|
||||
return results
|
||||
|
@ -161,28 +170,35 @@ async def smart_split(
|
|||
if count > max_tokens:
|
||||
# Yield current chunk if any
|
||||
if current_chunk:
|
||||
chunk_text = " ".join(current_chunk)
|
||||
# Join with space, but preserve original trailing newline of the last sentence if present
|
||||
last_sentence_original = current_chunk[-1]
|
||||
chunk_text_joined = " ".join(current_chunk)
|
||||
if last_sentence_original.endswith("\n"):
|
||||
chunk_text_joined += "\n" # Preserve the newline marker
|
||||
|
||||
chunk_count += 1
|
||||
logger.debug(
|
||||
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
|
||||
f"Yielding text chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({current_count} tokens)"
|
||||
)
|
||||
yield chunk_text, current_tokens
|
||||
yield chunk_text_joined, current_tokens, None # Pass the text with potential trailing newline
|
||||
current_chunk = []
|
||||
current_tokens = []
|
||||
current_count = 0
|
||||
|
||||
# Split long sentence on commas
|
||||
clauses = re.split(r"([,])", sentence)
|
||||
clause_chunk = []
|
||||
# Split long sentence on commas (simple approach)
|
||||
# Keep original sentence text ('sentence' now includes potential trailing newline)
|
||||
clauses = re.split(r"([,])", sentence.rstrip('\n')) # Split text part only
|
||||
trailing_newline_in_sentence = "\n" if sentence.endswith("\n") else ""
|
||||
clause_chunk = [] # Stores original clause text including potential trailing newline
|
||||
clause_tokens = []
|
||||
clause_count = 0
|
||||
|
||||
for j in range(0, len(clauses), 2):
|
||||
clause = clauses[j].strip()
|
||||
# clause = clauses[j].strip() # Don't strip here to preserve internal structure
|
||||
clause = clauses[j]
|
||||
comma = clauses[j + 1] if j + 1 < len(clauses) else ""
|
||||
|
||||
if not clause:
|
||||
continue
|
||||
if not clause.strip(): # Check if clause is just whitespace
|
||||
|
||||
full_clause = clause + comma
|
||||
|
||||
|
@ -200,77 +216,95 @@ async def smart_split(
|
|||
else:
|
||||
# Yield clause chunk if we have one
|
||||
if clause_chunk:
|
||||
chunk_text = " ".join(clause_chunk)
|
||||
# Join with space, preserve last clause's potential trailing newline
|
||||
last_clause_original = clause_chunk[-1]
|
||||
chunk_text_joined = " ".join(clause_chunk)
|
||||
if last_clause_original.endswith("\n"):
|
||||
chunk_text_joined += "\n"
|
||||
|
||||
chunk_count += 1
|
||||
logger.debug(
|
||||
f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)"
|
||||
f"Yielding clause chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({clause_count} tokens)"
|
||||
)
|
||||
yield chunk_text, clause_tokens
|
||||
clause_chunk = [full_clause]
|
||||
clause_tokens = tokens
|
||||
clause_count = count
|
||||
yield chunk_text_joined, clause_tokens, None
|
||||
# Start new clause chunk with original text
|
||||
clause_chunk = [full_clause + (trailing_newline_in_sentence if j == len(clauses) - 2 else "")]
|
||||
clause_tokens = clause_token_list
|
||||
clause_count = clause_token_count
|
||||
|
||||
# Don't forget last clause chunk
|
||||
if clause_chunk:
|
||||
chunk_text = " ".join(clause_chunk)
|
||||
chunk_count += 1
|
||||
logger.debug(
|
||||
f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)"
|
||||
)
|
||||
yield chunk_text, clause_tokens
|
||||
# Don't forget last clause chunk
|
||||
if clause_chunk:
|
||||
# Join with space, preserve last clause's potential trailing newline
|
||||
last_clause_original = clause_chunk[-1]
|
||||
chunk_text_joined = " ".join(clause_chunk)
|
||||
# The trailing newline logic was added when creating the chunk above
|
||||
#if last_clause_original.endswith("\n"):
|
||||
# chunk_text_joined += "\n"
|
||||
|
||||
# Regular sentence handling
|
||||
elif (
|
||||
chunk_count += 1
|
||||
logger.debug(
|
||||
f"Yielding final clause chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({clause_count} tokens)"
|
||||
)
|
||||
yield chunk_text_joined, clause_tokens, None
|
||||
current_count >= settings.target_min_tokens
|
||||
and current_count + count > settings.target_max_tokens
|
||||
):
|
||||
# If we have a good sized chunk and adding next sentence exceeds target,
|
||||
# yield current chunk and start new one
|
||||
chunk_text = " ".join(current_chunk)
|
||||
# Yield current chunk and start new one
|
||||
last_sentence_original = current_chunk[-1]
|
||||
chunk_text_joined = " ".join(current_chunk)
|
||||
if last_sentence_original.endswith("\n"):
|
||||
chunk_text_joined += "\n"
|
||||
chunk_count += 1
|
||||
logger.info(
|
||||
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
|
||||
f"Yielding text chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({current_count} tokens)"
|
||||
)
|
||||
yield chunk_text, current_tokens
|
||||
current_chunk = [sentence]
|
||||
yield chunk_text_joined, current_tokens, None
|
||||
current_chunk = [sentence] # sentence includes potential trailing newline
|
||||
current_tokens = tokens
|
||||
current_count = count
|
||||
elif current_count + count <= settings.target_max_tokens:
|
||||
# Keep building chunk while under target max
|
||||
current_chunk.append(sentence)
|
||||
# Keep building chunk
|
||||
current_chunk.append(sentence) # sentence includes potential trailing newline
|
||||
current_tokens.extend(tokens)
|
||||
current_count += count
|
||||
elif (
|
||||
current_count + count <= max_tokens
|
||||
and current_count < settings.target_min_tokens
|
||||
):
|
||||
# Only exceed target max if we haven't reached minimum size yet
|
||||
current_chunk.append(sentence)
|
||||
# Exceed target max only if below min size
|
||||
current_chunk.append(sentence) # sentence includes potential trailing newline
|
||||
current_tokens.extend(tokens)
|
||||
current_count += count
|
||||
else:
|
||||
# Yield current chunk and start new one
|
||||
if current_chunk:
|
||||
chunk_text = " ".join(current_chunk)
|
||||
last_sentence_original = current_chunk[-1]
|
||||
chunk_text_joined = " ".join(current_chunk)
|
||||
if last_sentence_original.endswith("\n"):
|
||||
chunk_text_joined += "\n"
|
||||
chunk_count += 1
|
||||
logger.info(
|
||||
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
|
||||
f"Yielding text chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({current_count} tokens)"
|
||||
)
|
||||
yield chunk_text, current_tokens
|
||||
current_chunk = [sentence]
|
||||
yield chunk_text_joined, current_tokens, None
|
||||
current_chunk = [sentence] # sentence includes potential trailing newline
|
||||
current_tokens = tokens
|
||||
current_count = count
|
||||
|
||||
# Don't forget the last chunk
|
||||
# Yield any remaining text chunk
|
||||
if current_chunk:
|
||||
chunk_text = " ".join(current_chunk)
|
||||
last_sentence_original = current_chunk[-1]
|
||||
chunk_text_joined = " ".join(current_chunk)
|
||||
if last_sentence_original.endswith("\n"):
|
||||
chunk_text_joined += "\n"
|
||||
chunk_count += 1
|
||||
logger.info(
|
||||
f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
|
||||
f"Yielding final text chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({current_count} tokens)"
|
||||
)
|
||||
yield chunk_text, current_tokens
|
||||
yield chunk_text_joined, current_tokens, None
|
||||
|
||||
total_time = time.time() - start_time
|
||||
logger.info(
|
||||
f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks"
|
||||
)
|
||||
)
|
|
@ -280,18 +280,56 @@ class TTSService:
|
|||
f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream"
|
||||
)
|
||||
|
||||
# Process text in chunks with smart splitting
|
||||
async for chunk_text, tokens in smart_split(
|
||||
# Process text in chunks with smart splitting, handling pauses
|
||||
async for chunk_text, tokens, pause_duration_s in smart_split(
|
||||
text,
|
||||
lang_code=pipeline_lang_code,
|
||||
normalization_options=normalization_options,
|
||||
):
|
||||
try:
|
||||
# Process audio for chunk
|
||||
async for chunk_data in self._process_chunk(
|
||||
chunk_text, # Pass text for Kokoro V1
|
||||
tokens, # Pass tokens for legacy backends
|
||||
voice_name, # Pass voice name
|
||||
if pause_duration_s is not None and pause_duration_s > 0:
|
||||
# --- Handle Pause Chunk ---
|
||||
try:
|
||||
logger.debug(f"Generating {pause_duration_s}s silence chunk")
|
||||
silence_samples = int(pause_duration_s * settings.sample_rate)
|
||||
# Use float32 zeros as AudioService will normalize later
|
||||
silence_audio = np.zeros(silence_samples, dtype=np.float32)
|
||||
pause_chunk = AudioChunk(audio=silence_audio, word_timestamps=[]) # Empty timestamps for silence
|
||||
|
||||
# Convert silence chunk to the target format using AudioService
|
||||
if output_format:
|
||||
formatted_pause_chunk = await AudioService.convert_audio(
|
||||
pause_chunk,
|
||||
output_format,
|
||||
writer,
|
||||
speed=1.0, # Speed doesn't affect silence
|
||||
chunk_text="", # No text for silence
|
||||
is_last_chunk=False, # Not the final chunk
|
||||
trim_audio=False, # Don't trim silence
|
||||
normalizer=stream_normalizer,
|
||||
)
|
||||
if formatted_pause_chunk.output:
|
||||
yield formatted_pause_chunk
|
||||
else:
|
||||
# If no output format (raw audio), yield the raw chunk
|
||||
# Ensure normalization happens if needed (AudioService handles this)
|
||||
pause_chunk.audio = stream_normalizer.normalize(pause_chunk.audio)
|
||||
yield pause_chunk # Yield raw silence chunk
|
||||
|
||||
# Update offset based on silence duration
|
||||
current_offset += pause_duration_s
|
||||
chunk_index += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process pause chunk: {str(e)}")
|
||||
continue
|
||||
elif tokens or chunk_text:
|
||||
# --- Handle Text Chunk ---
|
||||
try:
|
||||
# Process audio for the text chunk
|
||||
async for chunk_data in self._process_chunk(
|
||||
chunk_text, # Pass text for Kokoro V1
|
||||
tokens, # Pass tokens for legacy backends
|
||||
voice_name, # Pass voice name
|
||||
voice_path, # Pass voice path
|
||||
speed,
|
||||
writer,
|
||||
|
@ -307,19 +345,48 @@ class TTSService:
|
|||
timestamp.start_time += current_offset
|
||||
timestamp.end_time += current_offset
|
||||
|
||||
current_offset += len(chunk_data.audio) / 24000
|
||||
|
||||
if chunk_data.output is not None:
|
||||
yield chunk_data
|
||||
|
||||
else:
|
||||
logger.warning(
|
||||
f"No audio generated for chunk: '{chunk_text[:100]}...'"
|
||||
)
|
||||
# If no output format (raw audio), yield the raw chunk
|
||||
# Ensure normalization happens if needed (AudioService handles this)
|
||||
pause_chunk.audio = stream_normalizer.normalize(pause_chunk.audio)
|
||||
if len(pause_chunk.audio) > 0: # Only yield if silence is not zero length
|
||||
yield pause_chunk # Yield raw silence chunk
|
||||
|
||||
# Update offset based on silence duration
|
||||
f"No audio generated for chunk: '{chunk_text.strip()[:100]}...'"
|
||||
chunk_index += 1
|
||||
except Exception as e:
|
||||
# --- Add pause after newline ---
|
||||
# Check the original chunk_text passed from smart_split for trailing newline
|
||||
if chunk_text.endswith('\n'):
|
||||
newline_pause_s = 0.5
|
||||
try:
|
||||
logger.debug(f"Adding {newline_pause_s}s pause after newline.")
|
||||
silence_samples = int(newline_pause_s * settings.sample_rate)
|
||||
silence_audio = np.zeros(silence_samples, dtype=np.float32)
|
||||
pause_chunk = AudioChunk(audio=silence_audio, word_timestamps=[])
|
||||
|
||||
if output_format:
|
||||
formatted_pause_chunk = await AudioService.convert_audio(
|
||||
pause_chunk, output_format, writer, speed=1.0, chunk_text="",
|
||||
is_last_chunk=False, trim_audio=False, normalizer=stream_normalizer,
|
||||
)
|
||||
if formatted_pause_chunk.output:
|
||||
yield formatted_pause_chunk
|
||||
else:
|
||||
pause_chunk.audio = stream_normalizer.normalize(pause_chunk.audio)
|
||||
if len(pause_chunk.audio) > 0:
|
||||
yield pause_chunk
|
||||
|
||||
current_offset += newline_pause_s # Add newline pause to offset
|
||||
|
||||
except Exception as pause_e:
|
||||
logger.error(f"Failed to process newline pause chunk: {str(pause_e)}")
|
||||
# -------------------------------
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}"
|
||||
f"Failed to process audio for chunk: '{chunk_text.strip()[:100]}...'. Error: {str(e)}"
|
||||
)
|
||||
continue
|
||||
|
||||
|
@ -374,7 +441,18 @@ class TTSService:
|
|||
output_format=None,
|
||||
):
|
||||
if len(audio_stream_data.audio) > 0:
|
||||
audio_data_chunks.append(audio_stream_data)
|
||||
# Ensure we only append chunks with actual audio data
|
||||
# Raw silence chunks generated for pauses will have audio data (zeros)
|
||||
# Formatted silence chunks might have empty audio but non-empty output
|
||||
if len(audio_stream_data.audio) > 0 or (output_format and audio_stream_data.output):
|
||||
audio_data_chunks.append(audio_stream_data)
|
||||
|
||||
if not audio_data_chunks:
|
||||
# Handle cases where only pauses were present or generation failed
|
||||
logger.warning("No valid audio chunks generated.")
|
||||
# Return an empty AudioChunk or raise an error? Returning empty for now.
|
||||
return AudioChunk(audio=np.array([], dtype=np.int16), word_timestamps=[])
|
||||
|
||||
|
||||
combined_audio_data = AudioChunk.combine(audio_data_chunks)
|
||||
return combined_audio_data
|
||||
|
@ -456,4 +534,4 @@ class TTSService:
|
|||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in phoneme audio generation: {str(e)}")
|
||||
raise
|
||||
raise
|
Loading…
Add table
Reference in a new issue