From 84d2a4d806ecb3bf49021a36fe27b7eaa792c8c0 Mon Sep 17 00:00:00 2001 From: Lukin Date: Fri, 30 May 2025 23:06:41 +0800 Subject: [PATCH] Enhance TTS text processing: Implement pause tag handling in smart_split, allowing for better audio chunk generation with pauses. Update related tests to validate new functionality and ensure compatibility with existing features. --- .../text_processing/text_processor.py | 293 ++++++++++-------- api/src/services/tts_service.py | 112 ++++--- api/tests/test_text_processor.py | 44 ++- 3 files changed, 282 insertions(+), 167 deletions(-) diff --git a/api/src/services/text_processing/text_processor.py b/api/src/services/text_processing/text_processor.py index 3d90325..c5a442d 100644 --- a/api/src/services/text_processing/text_processor.py +++ b/api/src/services/text_processing/text_processor.py @@ -2,7 +2,7 @@ import re import time -from typing import AsyncGenerator, Dict, List, Tuple +from typing import AsyncGenerator, Dict, List, Tuple, Optional from loguru import logger @@ -13,7 +13,11 @@ from .phonemizer import phonemize from .vocabulary import tokenize # Pre-compiled regex patterns for performance -CUSTOM_PHONEMES = re.compile(r"(\[([^\]]|\n)*?\])(\(\/([^\/)]|\n)*?\/\))") +# Updated regex to be more strict and avoid matching isolated brackets +# Only matches complete patterns like [word](/ipa/) and prevents catastrophic backtracking +CUSTOM_PHONEMES = re.compile(r"(\[[^\[\]]*?\])(\(\/[^\/\(\)]*?\/\))") +# Pattern to find pause tags like [pause:0.5s] +PAUSE_TAG_PATTERN = re.compile(r"\[pause:(\d+(?:\.\d+)?)s\]", re.IGNORECASE) def process_text_chunk( @@ -142,148 +146,189 @@ async def smart_split( max_tokens: int = settings.absolute_max_tokens, lang_code: str = "a", normalization_options: NormalizationOptions = NormalizationOptions(), -) -> AsyncGenerator[Tuple[str, List[int]], None]: - """Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens.""" +) -> AsyncGenerator[Tuple[str, List[int], Optional[float]], None]: + """Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens. + + Yields: + Tuple of (text_chunk, tokens, pause_duration_s). + If pause_duration_s is not None, it's a pause chunk with empty text/tokens. + Otherwise, it's a text chunk containing the original text. 
+ """ start_time = time.time() chunk_count = 0 logger.info(f"Starting smart split for {len(text)} chars") - custom_phoneme_list = {} + # --- Step 1: Split by Pause Tags FIRST --- + # This operates on the raw input text + parts = PAUSE_TAG_PATTERN.split(text) + logger.debug(f"Split raw text into {len(parts)} parts by pause tags.") - # Normalize text - if settings.advanced_text_normalization and normalization_options.normalize: - if lang_code in ["a", "b", "en-us", "en-gb"]: - text = CUSTOM_PHONEMES.sub( - lambda s: handle_custom_phonemes(s, custom_phoneme_list), text - ) - text = normalize_text(text, normalization_options) - else: - logger.info( - "Skipping text normalization as it is only supported for english" - ) + part_idx = 0 + while part_idx < len(parts): + text_part_raw = parts[part_idx] # This part is raw text + part_idx += 1 - # Process all sentences - sentences = get_sentence_info(text, custom_phoneme_list, lang_code=lang_code) + # --- Process Text Part --- + if text_part_raw and text_part_raw.strip(): # Only process if the part is not empty string + # Strip leading and trailing spaces to prevent pause tag splitting artifacts + text_part_raw = text_part_raw.strip() - current_chunk = [] - current_tokens = [] - current_count = 0 + # Apply the original smart_split logic to this text part + custom_phoneme_list = {} - for sentence, tokens, count in sentences: - # Handle sentences that exceed max tokens - if count > max_tokens: - # Yield current chunk if any - if current_chunk: - chunk_text = " ".join(current_chunk).strip() # Strip after joining - chunk_count += 1 - logger.debug( - f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)" - ) - yield chunk_text, current_tokens - current_chunk = [] - current_tokens = [] - current_count = 0 - - # Split long sentence on commas - clauses = re.split(r"([,])", sentence) - clause_chunk = [] - clause_tokens = [] - clause_count = 0 - - for j in range(0, len(clauses), 2): - clause = clauses[j].strip() - comma = clauses[j + 1] if j + 1 < len(clauses) else "" - - if not clause: - continue - - full_clause = clause + comma - - tokens = process_text_chunk(full_clause) - count = len(tokens) - - # If adding clause keeps us under max and not optimal yet - if ( - clause_count + count <= max_tokens - and clause_count + count <= settings.target_max_tokens - ): - clause_chunk.append(full_clause) - clause_tokens.extend(tokens) - clause_count += count + # Normalize text (original logic) + processed_text = text_part_raw + if settings.advanced_text_normalization and normalization_options.normalize: + if lang_code in ["a", "b", "en-us", "en-gb"]: + processed_text = CUSTOM_PHONEMES.sub( + lambda s: handle_custom_phonemes(s, custom_phoneme_list), processed_text + ) + processed_text = normalize_text(processed_text, normalization_options) else: - # Yield clause chunk if we have one - if clause_chunk: - chunk_text = " ".join(clause_chunk).strip() # Strip after joining + logger.info( + "Skipping text normalization as it is only supported for english" + ) + + # Process all sentences (original logic) + sentences = get_sentence_info(processed_text, custom_phoneme_list, lang_code=lang_code) + + current_chunk = [] + current_tokens = [] + current_count = 0 + + for sentence, tokens, count in sentences: + # Handle sentences that exceed max tokens (original logic) + if count > max_tokens: + # Yield current chunk if any + if current_chunk: + chunk_text = " ".join(current_chunk).strip() chunk_count += 1 logger.debug( - 
f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)" + f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)" ) - yield chunk_text, clause_tokens - clause_chunk = [full_clause] - clause_tokens = tokens - clause_count = count + yield chunk_text, current_tokens, None + current_chunk = [] + current_tokens = [] + current_count = 0 - # Don't forget last clause chunk - if clause_chunk: - chunk_text = " ".join(clause_chunk).strip() # Strip after joining - chunk_count += 1 - logger.debug( - f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)" - ) - yield chunk_text, clause_tokens + # Split long sentence on commas (original logic) + clauses = re.split(r"([,])", sentence) + clause_chunk = [] + clause_tokens = [] + clause_count = 0 - # Regular sentence handling - elif ( - current_count >= settings.target_min_tokens - and current_count + count > settings.target_max_tokens - ): - # If we have a good sized chunk and adding next sentence exceeds target, - # yield current chunk and start new one - chunk_text = " ".join(current_chunk).strip() # Strip after joining - chunk_count += 1 - logger.info( - f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)" - ) - yield chunk_text, current_tokens - current_chunk = [sentence] - current_tokens = tokens - current_count = count - elif current_count + count <= settings.target_max_tokens: - # Keep building chunk while under target max - current_chunk.append(sentence) - current_tokens.extend(tokens) - current_count += count - elif ( - current_count + count <= max_tokens - and current_count < settings.target_min_tokens - ): - # Only exceed target max if we haven't reached minimum size yet - current_chunk.append(sentence) - current_tokens.extend(tokens) - current_count += count - else: - # Yield current chunk and start new one + for j in range(0, len(clauses), 2): + clause = clauses[j].strip() + comma = clauses[j + 1] if j + 1 < len(clauses) else "" + + if not clause: + continue + + full_clause = clause + comma + + tokens = process_text_chunk(full_clause) + count = len(tokens) + + # If adding clause keeps us under max and not optimal yet + if ( + clause_count + count <= max_tokens + and clause_count + count <= settings.target_max_tokens + ): + clause_chunk.append(full_clause) + clause_tokens.extend(tokens) + clause_count += count + else: + # Yield clause chunk if we have one + if clause_chunk: + chunk_text = " ".join(clause_chunk).strip() + chunk_count += 1 + logger.debug( + f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({clause_count} tokens)" + ) + yield chunk_text, clause_tokens, None + clause_chunk = [full_clause] + clause_tokens = tokens + clause_count = count + + # Don't forget last clause chunk + if clause_chunk: + chunk_text = " ".join(clause_chunk).strip() + chunk_count += 1 + logger.debug( + f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' 
if len(processed_text) > 50 else ''}' ({clause_count} tokens)" + ) + yield chunk_text, clause_tokens, None + + # Regular sentence handling (original logic) + elif ( + current_count >= settings.target_min_tokens + and current_count + count > settings.target_max_tokens + ): + # If we have a good sized chunk and adding next sentence exceeds target, + # yield current chunk and start new one + chunk_text = " ".join(current_chunk).strip() + chunk_count += 1 + logger.info( + f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)" + ) + yield chunk_text, current_tokens, None + current_chunk = [sentence] + current_tokens = tokens + current_count = count + elif current_count + count <= settings.target_max_tokens: + # Keep building chunk while under target max + current_chunk.append(sentence) + current_tokens.extend(tokens) + current_count += count + elif ( + current_count + count <= max_tokens + and current_count < settings.target_min_tokens + ): + # Only exceed target max if we haven't reached minimum size yet + current_chunk.append(sentence) + current_tokens.extend(tokens) + current_count += count + else: + # Yield current chunk and start new one + if current_chunk: + chunk_text = " ".join(current_chunk).strip() + chunk_count += 1 + logger.info( + f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)" + ) + yield chunk_text, current_tokens, None + current_chunk = [sentence] + current_tokens = tokens + current_count = count + + # Don't forget the last chunk for this text part if current_chunk: - chunk_text = " ".join(current_chunk).strip() # Strip after joining + chunk_text = " ".join(current_chunk).strip() chunk_count += 1 logger.info( - f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)" + f"Yielding final chunk {chunk_count} for part: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)" ) - yield chunk_text, current_tokens - current_chunk = [sentence] - current_tokens = tokens - current_count = count + yield chunk_text, current_tokens, None - # Don't forget the last chunk - if current_chunk: - chunk_text = " ".join(current_chunk).strip() # Strip after joining - chunk_count += 1 - logger.info( - f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}{'...' 
if len(text) > 50 else ''}' ({current_count} tokens)" - ) - yield chunk_text, current_tokens + # --- Handle Pause Part --- + # Check if the next part is a pause duration string + if part_idx < len(parts): + duration_str = parts[part_idx] + # Check if it looks like a valid number string captured by the regex group + if re.fullmatch(r"\d+(?:\.\d+)?", duration_str): + part_idx += 1 # Consume the duration string as it's been processed + try: + duration = float(duration_str) + if duration > 0: + chunk_count += 1 + logger.info(f"Yielding pause chunk {chunk_count}: {duration}s") + yield "", [], duration # Yield pause chunk + except (ValueError, TypeError): + # This case should be rare if re.fullmatch passed, but handle anyway + logger.warning(f"Could not parse valid-looking pause duration: {duration_str}") + # --- End of parts loop --- total_time = time.time() - start_time logger.info( - f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks" + f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks (including pauses)" ) diff --git a/api/src/services/tts_service.py b/api/src/services/tts_service.py index 0a69b85..399600e 100644 --- a/api/src/services/tts_service.py +++ b/api/src/services/tts_service.py @@ -280,48 +280,88 @@ class TTSService: f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream" ) - # Process text in chunks with smart splitting - async for chunk_text, tokens in smart_split( + # Process text in chunks with smart splitting, handling pause tags + async for chunk_text, tokens, pause_duration_s in smart_split( text, lang_code=pipeline_lang_code, normalization_options=normalization_options, ): - try: - # Process audio for chunk - async for chunk_data in self._process_chunk( - chunk_text, # Pass text for Kokoro V1 - tokens, # Pass tokens for legacy backends - voice_name, # Pass voice name - voice_path, # Pass voice path - speed, - writer, - output_format, - is_first=(chunk_index == 0), - is_last=False, # We'll update the last chunk later - normalizer=stream_normalizer, - lang_code=pipeline_lang_code, # Pass lang_code - return_timestamps=return_timestamps, - ): - if chunk_data.word_timestamps is not None: - for timestamp in chunk_data.word_timestamps: - timestamp.start_time += current_offset - timestamp.end_time += current_offset + if pause_duration_s is not None and pause_duration_s > 0: + # --- Handle Pause Chunk --- + try: + logger.debug(f"Generating {pause_duration_s}s silence chunk") + silence_samples = int(pause_duration_s * 24000) # 24kHz sample rate + # Create proper silence as int16 zeros to avoid normalization artifacts + silence_audio = np.zeros(silence_samples, dtype=np.int16) + pause_chunk = AudioChunk(audio=silence_audio, word_timestamps=[]) # Empty timestamps for silence - current_offset += len(chunk_data.audio) / 24000 - - if chunk_data.output is not None: - yield chunk_data - - else: - logger.warning( - f"No audio generated for chunk: '{chunk_text[:100]}...'" + # Format and yield the silence chunk + if output_format: + formatted_pause_chunk = await AudioService.convert_audio( + pause_chunk, output_format, writer, speed=speed, chunk_text="", + is_last_chunk=False, trim_audio=False, normalizer=stream_normalizer, ) - chunk_index += 1 - except Exception as e: - logger.error( - f"Failed to process audio for chunk: '{chunk_text[:100]}...'. 
Error: {str(e)}" - ) - continue + if formatted_pause_chunk.output: + yield formatted_pause_chunk + else: # Raw audio mode + # For raw audio mode, silence is already in the correct format (int16) + # Skip normalization to avoid any potential artifacts + if len(pause_chunk.audio) > 0: + yield pause_chunk + + # Update offset based on silence duration + current_offset += pause_duration_s + chunk_index += 1 # Count pause as a yielded chunk + + except Exception as e: + logger.error(f"Failed to process pause chunk: {str(e)}") + continue + + elif tokens or chunk_text.strip(): # Process if there are tokens OR non-whitespace text + # --- Handle Text Chunk --- + try: + # Process audio for chunk + async for chunk_data in self._process_chunk( + chunk_text, # Pass text for Kokoro V1 + tokens, # Pass tokens for legacy backends + voice_name, # Pass voice name + voice_path, # Pass voice path + speed, + writer, + output_format, + is_first=(chunk_index == 0), + is_last=False, # We'll update the last chunk later + normalizer=stream_normalizer, + lang_code=pipeline_lang_code, # Pass lang_code + return_timestamps=return_timestamps, + ): + if chunk_data.word_timestamps is not None: + for timestamp in chunk_data.word_timestamps: + timestamp.start_time += current_offset + timestamp.end_time += current_offset + + # Update offset based on the actual duration of the generated audio chunk + chunk_duration = 0 + if chunk_data.audio is not None and len(chunk_data.audio) > 0: + chunk_duration = len(chunk_data.audio) / 24000 + current_offset += chunk_duration + + # Yield the processed chunk (either formatted or raw) + if chunk_data.output is not None: + yield chunk_data + elif chunk_data.audio is not None and len(chunk_data.audio) > 0: + yield chunk_data + else: + logger.warning( + f"No audio generated for chunk: '{chunk_text[:100]}...'" + ) + + chunk_index += 1 # Increment chunk index after processing text + except Exception as e: + logger.error( + f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}" + ) + continue # Only finalize if we successfully processed at least one chunk if chunk_index > 0: diff --git a/api/tests/test_text_processor.py b/api/tests/test_text_processor.py index 6ff8282..95c0259 100644 --- a/api/tests/test_text_processor.py +++ b/api/tests/test_text_processor.py @@ -67,7 +67,7 @@ async def test_smart_split_short_text(): """Test smart splitting with text under max tokens.""" text = "This is a short test sentence." chunks = [] - async for chunk_text, chunk_tokens in smart_split(text): + async for chunk_text, chunk_tokens, _ in smart_split(text): chunks.append((chunk_text, chunk_tokens)) assert len(chunks) == 1 @@ -82,7 +82,7 @@ async def test_smart_split_long_text(): text = ". ".join(["This is test sentence number " + str(i) for i in range(20)]) chunks = [] - async for chunk_text, chunk_tokens in smart_split(text): + async for chunk_text, chunk_tokens, _ in smart_split(text): chunks.append((chunk_text, chunk_tokens)) assert len(chunks) > 1 @@ -98,12 +98,13 @@ async def test_smart_split_with_punctuation(): text = "First sentence! Second sentence? Third sentence; Fourth sentence: Fifth sentence." 
chunks = [] - async for chunk_text, chunk_tokens in smart_split(text): + async for chunk_text, chunk_tokens, _ in smart_split(text): chunks.append(chunk_text) # Verify punctuation is preserved assert all(any(p in chunk for p in "!?;:.") for chunk in chunks) + def test_process_text_chunk_chinese_phonemes(): """Test processing with Chinese pinyin phonemes.""" pinyin = "nǐ hǎo lì" # Example pinyin sequence with tones @@ -125,12 +126,13 @@ def test_get_sentence_info_chinese(): assert count == len(tokens) assert count > 0 + @pytest.mark.asyncio async def test_smart_split_chinese_short(): """Test Chinese smart splitting with short text.""" text = "这是一句话。" chunks = [] - async for chunk_text, chunk_tokens in smart_split(text, lang_code="z"): + async for chunk_text, chunk_tokens, _ in smart_split(text, lang_code="z"): chunks.append((chunk_text, chunk_tokens)) assert len(chunks) == 1 @@ -144,7 +146,7 @@ async def test_smart_split_chinese_long(): text = "。".join([f"测试句子 {i}" for i in range(20)]) chunks = [] - async for chunk_text, chunk_tokens in smart_split(text, lang_code="z"): + async for chunk_text, chunk_tokens, _ in smart_split(text, lang_code="z"): chunks.append((chunk_text, chunk_tokens)) assert len(chunks) > 1 @@ -160,8 +162,36 @@ async def test_smart_split_chinese_punctuation(): text = "第一句!第二问?第三句;第四句:第五句。" chunks = [] - async for chunk_text, _ in smart_split(text, lang_code="z"): + async for chunk_text, _, _ in smart_split(text, lang_code="z"): chunks.append(chunk_text) # Verify Chinese punctuation is preserved - assert all(any(p in chunk for p in "!?;:。") for chunk in chunks) \ No newline at end of file + assert all(any(p in chunk for p in "!?;:。") for chunk in chunks) + + +@pytest.mark.asyncio +async def test_smart_split_with_pause(): + """Test smart splitting with pause tags.""" + text = "Hello world [pause:2.5s] How are you?" + + chunks = [] + async for chunk_text, chunk_tokens, pause_duration in smart_split(text): + chunks.append((chunk_text, chunk_tokens, pause_duration)) + + # Should have 3 chunks: text, pause, text + assert len(chunks) == 3 + + # First chunk: text + assert chunks[0][2] is None # No pause + assert "Hello world" in chunks[0][0] + assert len(chunks[0][1]) > 0 + + # Second chunk: pause + assert chunks[1][2] == 2.5 # 2.5 second pause + assert chunks[1][0] == "" # Empty text + assert len(chunks[1][1]) == 0 # No tokens + + # Third chunk: text + assert chunks[2][2] is None # No pause + assert "How are you?" in chunks[2][0] + assert len(chunks[2][1]) > 0 \ No newline at end of file
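
Reviewer note (not part of the patch): a minimal usage sketch of the new three-tuple
contract yielded by smart_split, to make the consumer side of the change concrete.
The import path and the demo() wrapper are assumptions for illustration only; the
silence synthesis just mirrors the int16 / 24 kHz handling shown in tts_service.py above.

    import asyncio
    import numpy as np

    # Assumed import path; adjust to the package layout of the running service.
    from api.src.services.text_processing.text_processor import smart_split, PAUSE_TAG_PATTERN

    # PAUSE_TAG_PATTERN.split alternates raw text with captured durations, e.g.
    # PAUSE_TAG_PATTERN.split("Hi [pause:0.5s] there") -> ['Hi ', '0.5', ' there']

    async def demo():
        text = "Hello world. [pause:1.5s] How are you today?"
        async for chunk_text, tokens, pause_s in smart_split(text):
            if pause_s is not None:
                # Pause chunk: empty text/tokens; synthesize silence at 24 kHz.
                silence = np.zeros(int(pause_s * 24000), dtype=np.int16)
                print(f"pause chunk: {pause_s}s ({len(silence)} samples)")
            else:
                # Text chunk: hand text/tokens to the TTS backend (printed here).
                print(f"text chunk ({len(tokens)} tokens): {chunk_text!r}")

    asyncio.run(demo())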