Enhance TTS text processing: implement pause tag handling in smart_split so audio chunks can be generated with explicit pauses. Update related tests to validate the new behavior and confirm compatibility with existing features.

Lukin 2025-05-30 23:06:41 +08:00
parent ab8ab7d749
commit 84d2a4d806
3 changed files with 282 additions and 167 deletions
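
For orientation, here is a minimal sketch (not part of the diff) of the new contract: input text may carry tags like [pause:0.5s], and the updated smart_split yields three-tuples in which a non-None third element marks a pause chunk. The consumer loop below is hypothetical; only the tuple shape comes from this commit.

# Hypothetical consumer of the updated generator; imports and event-loop setup omitted.
async def preview_chunks():
    text = "Welcome. [pause:0.5s] Let's begin."
    async for chunk_text, tokens, pause_s in smart_split(text):
        if pause_s is not None:
            # Pause chunk: empty text and tokens, pause_s is the duration in seconds.
            print(f"pause for {pause_s}s")
        else:
            # Text chunk: synthesize as before.
            print(f"speak: {chunk_text!r} ({len(tokens)} tokens)")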


@@ -2,7 +2,7 @@
 import re
 import time
-from typing import AsyncGenerator, Dict, List, Tuple
+from typing import AsyncGenerator, Dict, List, Tuple, Optional
 from loguru import logger
@@ -13,7 +13,11 @@ from .phonemizer import phonemize
 from .vocabulary import tokenize
 # Pre-compiled regex patterns for performance
-CUSTOM_PHONEMES = re.compile(r"(\[([^\]]|\n)*?\])(\(\/([^\/)]|\n)*?\/\))")
+# Updated regex to be more strict and avoid matching isolated brackets
+# Only matches complete patterns like [word](/ipa/) and prevents catastrophic backtracking
+CUSTOM_PHONEMES = re.compile(r"(\[[^\[\]]*?\])(\(\/[^\/\(\)]*?\/\))")
+# Pattern to find pause tags like [pause:0.5s]
+PAUSE_TAG_PATTERN = re.compile(r"\[pause:(\d+(?:\.\d+)?)s\]", re.IGNORECASE)
 def process_text_chunk(
@@ -142,49 +146,72 @@ async def smart_split(
     max_tokens: int = settings.absolute_max_tokens,
     lang_code: str = "a",
     normalization_options: NormalizationOptions = NormalizationOptions(),
-) -> AsyncGenerator[Tuple[str, List[int]], None]:
+) -> AsyncGenerator[Tuple[str, List[int], Optional[float]], None]:
-    """Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens."""
+    """Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens.
+    Yields:
+        Tuple of (text_chunk, tokens, pause_duration_s).
+        If pause_duration_s is not None, it's a pause chunk with empty text/tokens.
+        Otherwise, it's a text chunk containing the original text.
+    """
     start_time = time.time()
     chunk_count = 0
     logger.info(f"Starting smart split for {len(text)} chars")
+    # --- Step 1: Split by Pause Tags FIRST ---
+    # This operates on the raw input text
+    parts = PAUSE_TAG_PATTERN.split(text)
+    logger.debug(f"Split raw text into {len(parts)} parts by pause tags.")
+    part_idx = 0
+    while part_idx < len(parts):
+        text_part_raw = parts[part_idx]  # This part is raw text
+        part_idx += 1
+        # --- Process Text Part ---
+        if text_part_raw and text_part_raw.strip():  # Only process if the part is not empty string
+            # Strip leading and trailing spaces to prevent pause tag splitting artifacts
+            text_part_raw = text_part_raw.strip()
+            # Apply the original smart_split logic to this text part
             custom_phoneme_list = {}
-    # Normalize text
+            # Normalize text (original logic)
+            processed_text = text_part_raw
             if settings.advanced_text_normalization and normalization_options.normalize:
                 if lang_code in ["a", "b", "en-us", "en-gb"]:
-            text = CUSTOM_PHONEMES.sub(
+                    processed_text = CUSTOM_PHONEMES.sub(
-                lambda s: handle_custom_phonemes(s, custom_phoneme_list), text
+                        lambda s: handle_custom_phonemes(s, custom_phoneme_list), processed_text
                     )
-            text = normalize_text(text, normalization_options)
+                    processed_text = normalize_text(processed_text, normalization_options)
                 else:
                     logger.info(
                         "Skipping text normalization as it is only supported for english"
                     )
-    # Process all sentences
+            # Process all sentences (original logic)
-    sentences = get_sentence_info(text, custom_phoneme_list, lang_code=lang_code)
+            sentences = get_sentence_info(processed_text, custom_phoneme_list, lang_code=lang_code)
             current_chunk = []
             current_tokens = []
            current_count = 0
             for sentence, tokens, count in sentences:
-        # Handle sentences that exceed max tokens
+                # Handle sentences that exceed max tokens (original logic)
                 if count > max_tokens:
                     # Yield current chunk if any
                     if current_chunk:
-                chunk_text = " ".join(current_chunk).strip()  # Strip after joining
+                        chunk_text = " ".join(current_chunk).strip()
                         chunk_count += 1
                         logger.debug(
-                    f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
+                            f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)"
                         )
-                yield chunk_text, current_tokens
+                        yield chunk_text, current_tokens, None
                         current_chunk = []
                         current_tokens = []
                         current_count = 0
-            # Split long sentence on commas
+                    # Split long sentence on commas (original logic)
                     clauses = re.split(r"([,])", sentence)
                     clause_chunk = []
                     clause_tokens = []
@@ -213,38 +240,38 @@ async def smart_split(
                         else:
                             # Yield clause chunk if we have one
                             if clause_chunk:
-                        chunk_text = " ".join(clause_chunk).strip()  # Strip after joining
+                                chunk_text = " ".join(clause_chunk).strip()
                                 chunk_count += 1
                                 logger.debug(
-                            f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)"
+                                    f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({clause_count} tokens)"
                                 )
-                        yield chunk_text, clause_tokens
+                                yield chunk_text, clause_tokens, None
                             clause_chunk = [full_clause]
                             clause_tokens = tokens
                             clause_count = count
                     # Don't forget last clause chunk
                     if clause_chunk:
-                chunk_text = " ".join(clause_chunk).strip()  # Strip after joining
+                        chunk_text = " ".join(clause_chunk).strip()
                         chunk_count += 1
                         logger.debug(
-                    f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)"
+                            f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({clause_count} tokens)"
                         )
-                yield chunk_text, clause_tokens
+                        yield chunk_text, clause_tokens, None
-        # Regular sentence handling
+                # Regular sentence handling (original logic)
                 elif (
                     current_count >= settings.target_min_tokens
                     and current_count + count > settings.target_max_tokens
                 ):
                     # If we have a good sized chunk and adding next sentence exceeds target,
                     # yield current chunk and start new one
-            chunk_text = " ".join(current_chunk).strip()  # Strip after joining
+                    chunk_text = " ".join(current_chunk).strip()
                     chunk_count += 1
                     logger.info(
-                f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
+                        f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)"
                     )
-            yield chunk_text, current_tokens
+                    yield chunk_text, current_tokens, None
                     current_chunk = [sentence]
                     current_tokens = tokens
                     current_count = count
@@ -264,26 +291,44 @@ async def smart_split(
                 else:
                     # Yield current chunk and start new one
                     if current_chunk:
-                chunk_text = " ".join(current_chunk).strip()  # Strip after joining
+                        chunk_text = " ".join(current_chunk).strip()
                         chunk_count += 1
                         logger.info(
-                    f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
+                            f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)"
                         )
-                yield chunk_text, current_tokens
+                        yield chunk_text, current_tokens, None
                     current_chunk = [sentence]
                     current_tokens = tokens
                     current_count = count
-    # Don't forget the last chunk
+            # Don't forget the last chunk for this text part
             if current_chunk:
-        chunk_text = " ".join(current_chunk).strip()  # Strip after joining
+                chunk_text = " ".join(current_chunk).strip()
                 chunk_count += 1
                 logger.info(
-            f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
+                    f"Yielding final chunk {chunk_count} for part: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)"
                 )
-        yield chunk_text, current_tokens
+                yield chunk_text, current_tokens, None
+        # --- Handle Pause Part ---
+        # Check if the next part is a pause duration string
+        if part_idx < len(parts):
+            duration_str = parts[part_idx]
+            # Check if it looks like a valid number string captured by the regex group
+            if re.fullmatch(r"\d+(?:\.\d+)?", duration_str):
+                part_idx += 1  # Consume the duration string as it's been processed
+                try:
+                    duration = float(duration_str)
+                    if duration > 0:
+                        chunk_count += 1
+                        logger.info(f"Yielding pause chunk {chunk_count}: {duration}s")
+                        yield "", [], duration  # Yield pause chunk
+                except (ValueError, TypeError):
+                    # This case should be rare if re.fullmatch passed, but handle anyway
+                    logger.warning(f"Could not parse valid-looking pause duration: {duration_str}")
+    # --- End of parts loop ---
     total_time = time.time() - start_time
     logger.info(
-        f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks"
+        f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks (including pauses)"
     )
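
The while loop above relies on a standard property of re.split: because the duration in PAUSE_TAG_PATTERN is a capturing group, the captured strings are interleaved with the surrounding text parts, which is why the element following a text part is validated with re.fullmatch before being parsed as a duration. A minimal, standalone illustration (assumes nothing beyond the pattern shown in the diff):

import re

PAUSE_TAG_PATTERN = re.compile(r"\[pause:(\d+(?:\.\d+)?)s\]", re.IGNORECASE)

parts = PAUSE_TAG_PATTERN.split("Hello world [pause:2.5s] How are you?")
# The capturing group makes re.split keep the matched duration between the text parts:
# ['Hello world ', '2.5', ' How are you?']
assert re.fullmatch(r"\d+(?:\.\d+)?", parts[1])  # the guard used before float() above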


@@ -280,12 +280,45 @@ class TTSService:
                 f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream"
             )
-            # Process text in chunks with smart splitting
+            # Process text in chunks with smart splitting, handling pause tags
-            async for chunk_text, tokens in smart_split(
+            async for chunk_text, tokens, pause_duration_s in smart_split(
                 text,
                 lang_code=pipeline_lang_code,
                 normalization_options=normalization_options,
             ):
+                if pause_duration_s is not None and pause_duration_s > 0:
+                    # --- Handle Pause Chunk ---
+                    try:
+                        logger.debug(f"Generating {pause_duration_s}s silence chunk")
+                        silence_samples = int(pause_duration_s * 24000)  # 24kHz sample rate
+                        # Create proper silence as int16 zeros to avoid normalization artifacts
+                        silence_audio = np.zeros(silence_samples, dtype=np.int16)
+                        pause_chunk = AudioChunk(audio=silence_audio, word_timestamps=[])  # Empty timestamps for silence
+                        # Format and yield the silence chunk
+                        if output_format:
+                            formatted_pause_chunk = await AudioService.convert_audio(
+                                pause_chunk, output_format, writer, speed=speed, chunk_text="",
+                                is_last_chunk=False, trim_audio=False, normalizer=stream_normalizer,
+                            )
+                            if formatted_pause_chunk.output:
+                                yield formatted_pause_chunk
+                        else:  # Raw audio mode
+                            # For raw audio mode, silence is already in the correct format (int16)
+                            # Skip normalization to avoid any potential artifacts
+                            if len(pause_chunk.audio) > 0:
+                                yield pause_chunk
+                        # Update offset based on silence duration
+                        current_offset += pause_duration_s
+                        chunk_index += 1  # Count pause as a yielded chunk
+                    except Exception as e:
+                        logger.error(f"Failed to process pause chunk: {str(e)}")
+                    continue
+                elif tokens or chunk_text.strip():  # Process if there are tokens OR non-whitespace text
+                    # --- Handle Text Chunk ---
                     try:
                         # Process audio for chunk
                         async for chunk_data in self._process_chunk(
@@ -307,16 +340,23 @@ class TTSService:
                                     timestamp.start_time += current_offset
                                     timestamp.end_time += current_offset
-                        current_offset += len(chunk_data.audio) / 24000
+                            # Update offset based on the actual duration of the generated audio chunk
+                            chunk_duration = 0
+                            if chunk_data.audio is not None and len(chunk_data.audio) > 0:
+                                chunk_duration = len(chunk_data.audio) / 24000
+                            current_offset += chunk_duration
+                            # Yield the processed chunk (either formatted or raw)
                             if chunk_data.output is not None:
                                 yield chunk_data
+                            elif chunk_data.audio is not None and len(chunk_data.audio) > 0:
+                                yield chunk_data
                             else:
                                 logger.warning(
                                     f"No audio generated for chunk: '{chunk_text[:100]}...'"
                                 )
-                        chunk_index += 1
+                            chunk_index += 1  # Increment chunk index after processing text
                     except Exception as e:
                         logger.error(
                             f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}"


@@ -67,7 +67,7 @@ async def test_smart_split_short_text():
     """Test smart splitting with text under max tokens."""
     text = "This is a short test sentence."
     chunks = []
-    async for chunk_text, chunk_tokens in smart_split(text):
+    async for chunk_text, chunk_tokens, _ in smart_split(text):
         chunks.append((chunk_text, chunk_tokens))
     assert len(chunks) == 1
@@ -82,7 +82,7 @@ async def test_smart_split_long_text():
     text = ". ".join(["This is test sentence number " + str(i) for i in range(20)])
     chunks = []
-    async for chunk_text, chunk_tokens in smart_split(text):
+    async for chunk_text, chunk_tokens, _ in smart_split(text):
         chunks.append((chunk_text, chunk_tokens))
     assert len(chunks) > 1
@@ -98,12 +98,13 @@ async def test_smart_split_with_punctuation():
     text = "First sentence! Second sentence? Third sentence; Fourth sentence: Fifth sentence."
     chunks = []
-    async for chunk_text, chunk_tokens in smart_split(text):
+    async for chunk_text, chunk_tokens, _ in smart_split(text):
         chunks.append(chunk_text)
     # Verify punctuation is preserved
     assert all(any(p in chunk for p in "!?;:.") for chunk in chunks)
 def test_process_text_chunk_chinese_phonemes():
     """Test processing with Chinese pinyin phonemes."""
     pinyin = "nǐ hǎo lì"  # Example pinyin sequence with tones
@@ -125,12 +126,13 @@ def test_get_sentence_info_chinese():
     assert count == len(tokens)
     assert count > 0
 @pytest.mark.asyncio
 async def test_smart_split_chinese_short():
     """Test Chinese smart splitting with short text."""
     text = "这是一句话。"
     chunks = []
-    async for chunk_text, chunk_tokens in smart_split(text, lang_code="z"):
+    async for chunk_text, chunk_tokens, _ in smart_split(text, lang_code="z"):
         chunks.append((chunk_text, chunk_tokens))
     assert len(chunks) == 1
@@ -144,7 +146,7 @@ async def test_smart_split_chinese_long():
     text = "".join([f"测试句子 {i}" for i in range(20)])
     chunks = []
-    async for chunk_text, chunk_tokens in smart_split(text, lang_code="z"):
+    async for chunk_text, chunk_tokens, _ in smart_split(text, lang_code="z"):
         chunks.append((chunk_text, chunk_tokens))
     assert len(chunks) > 1
@@ -160,8 +162,36 @@ async def test_smart_split_chinese_punctuation():
     text = "第一句!第二问?第三句;第四句:第五句。"
     chunks = []
-    async for chunk_text, _ in smart_split(text, lang_code="z"):
+    async for chunk_text, _, _ in smart_split(text, lang_code="z"):
         chunks.append(chunk_text)
     # Verify Chinese punctuation is preserved
     assert all(any(p in chunk for p in "!?;:。") for chunk in chunks)
+@pytest.mark.asyncio
+async def test_smart_split_with_pause():
+    """Test smart splitting with pause tags."""
+    text = "Hello world [pause:2.5s] How are you?"
+    chunks = []
+    async for chunk_text, chunk_tokens, pause_duration in smart_split(text):
+        chunks.append((chunk_text, chunk_tokens, pause_duration))
+    # Should have 3 chunks: text, pause, text
+    assert len(chunks) == 3
+    # First chunk: text
+    assert chunks[0][2] is None  # No pause
+    assert "Hello world" in chunks[0][0]
+    assert len(chunks[0][1]) > 0
+    # Second chunk: pause
+    assert chunks[1][2] == 2.5  # 2.5 second pause
+    assert chunks[1][0] == ""  # Empty text
+    assert len(chunks[1][1]) == 0  # No tokens
+    # Third chunk: text
+    assert chunks[2][2] is None  # No pause
+    assert "How are you?" in chunks[2][0]
+    assert len(chunks[2][1]) > 0
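
Note that the splitter only yields a pause chunk when the parsed duration is greater than zero, so a tag like [pause:0s] is silently dropped. A hypothetical follow-up test in the same style (not part of this commit) could assert that behavior:

@pytest.mark.asyncio
async def test_smart_split_with_zero_pause():
    """A zero-length pause tag should be dropped rather than yielded."""
    text = "Hello world [pause:0s] How are you?"
    chunks = []
    async for chunk_text, chunk_tokens, pause_duration in smart_split(text):
        chunks.append((chunk_text, chunk_tokens, pause_duration))
    # Only the two text chunks remain; the 0s pause never reaches the caller.
    assert len(chunks) == 2
    assert all(pause is None for _, _, pause in chunks)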