Enhance TTS text processing: implement pause tag handling in smart_split, allowing audio chunks to be generated with explicit pauses. Update related tests to validate the new functionality and ensure compatibility with existing features.

Lukin 2025-05-30 23:06:41 +08:00
parent ab8ab7d749
commit 84d2a4d806
3 changed files with 282 additions and 167 deletions
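
At a glance, the change makes smart_split yield three-tuples instead of two-tuples. A minimal consumer sketch (the input string and printed output are illustrative, not from this commit; the import path for smart_split is omitted because the diff does not show the module name):

    # Hedged sketch: pause chunks arrive as ("", [], duration),
    # text chunks as (text, tokens, None).
    async def consume(text: str) -> None:
        async for chunk_text, tokens, pause_duration_s in smart_split(text):
            if pause_duration_s is not None:
                print(f"pause for {pause_duration_s}s")  # e.g. 2.5
            else:
                print(f"speak {chunk_text!r} ({len(tokens)} tokens)")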

View file

@@ -2,7 +2,7 @@
 import re
 import time
-from typing import AsyncGenerator, Dict, List, Tuple
+from typing import AsyncGenerator, Dict, List, Tuple, Optional
 from loguru import logger
@@ -13,7 +13,11 @@ from .phonemizer import phonemize
 from .vocabulary import tokenize
 
 # Pre-compiled regex patterns for performance
-CUSTOM_PHONEMES = re.compile(r"(\[([^\]]|\n)*?\])(\(\/([^\/)]|\n)*?\/\))")
+# Updated regex to be more strict and avoid matching isolated brackets
+# Only matches complete patterns like [word](/ipa/) and prevents catastrophic backtracking
+CUSTOM_PHONEMES = re.compile(r"(\[[^\[\]]*?\])(\(\/[^\/\(\)]*?\/\))")
+# Pattern to find pause tags like [pause:0.5s]
+PAUSE_TAG_PATTERN = re.compile(r"\[pause:(\d+(?:\.\d+)?)s\]", re.IGNORECASE)
 
 
 def process_text_chunk(
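
The tightened CUSTOM_PHONEMES pattern only fires on a complete [word](/ipa/) pair, so a bare bracket tag such as the new pause tags is left alone. A quick check of that assumption (the example phoneme string is illustrative):

    import re

    CUSTOM_PHONEMES = re.compile(r"(\[[^\[\]]*?\])(\(\/[^\/\(\)]*?\/\))")
    assert CUSTOM_PHONEMES.search("[kokoro](/kˈOkəɹO/)")   # complete pattern matches
    assert CUSTOM_PHONEMES.search("[pause:0.5s]") is None  # isolated brackets do not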
@@ -142,148 +146,189 @@ async def smart_split(
     max_tokens: int = settings.absolute_max_tokens,
     lang_code: str = "a",
     normalization_options: NormalizationOptions = NormalizationOptions(),
-) -> AsyncGenerator[Tuple[str, List[int]], None]:
-    """Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens."""
+) -> AsyncGenerator[Tuple[str, List[int], Optional[float]], None]:
+    """Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens.
+
+    Yields:
+        Tuple of (text_chunk, tokens, pause_duration_s).
+        If pause_duration_s is not None, it's a pause chunk with empty text/tokens.
+        Otherwise, it's a text chunk containing the original text.
+    """
     start_time = time.time()
     chunk_count = 0
     logger.info(f"Starting smart split for {len(text)} chars")
 
-    custom_phoneme_list = {}
+    # --- Step 1: Split by Pause Tags FIRST ---
+    # This operates on the raw input text
+    parts = PAUSE_TAG_PATTERN.split(text)
+    logger.debug(f"Split raw text into {len(parts)} parts by pause tags.")
 
-    # Normalize text
-    if settings.advanced_text_normalization and normalization_options.normalize:
-        if lang_code in ["a", "b", "en-us", "en-gb"]:
-            text = CUSTOM_PHONEMES.sub(
-                lambda s: handle_custom_phonemes(s, custom_phoneme_list), text
-            )
-            text = normalize_text(text, normalization_options)
-        else:
-            logger.info(
-                "Skipping text normalization as it is only supported for english"
-            )
-
-    # Process all sentences
-    sentences = get_sentence_info(text, custom_phoneme_list, lang_code=lang_code)
-
-    current_chunk = []
-    current_tokens = []
-    current_count = 0
-
-    for sentence, tokens, count in sentences:
-        # Handle sentences that exceed max tokens
-        if count > max_tokens:
-            # Yield current chunk if any
-            if current_chunk:
-                chunk_text = " ".join(current_chunk).strip()  # Strip after joining
-                chunk_count += 1
-                logger.debug(
-                    f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
-                )
-                yield chunk_text, current_tokens
-                current_chunk = []
-                current_tokens = []
-                current_count = 0
-
-            # Split long sentence on commas
-            clauses = re.split(r"([,])", sentence)
-            clause_chunk = []
-            clause_tokens = []
-            clause_count = 0
-
-            for j in range(0, len(clauses), 2):
-                clause = clauses[j].strip()
-                comma = clauses[j + 1] if j + 1 < len(clauses) else ""
-
-                if not clause:
-                    continue
-
-                full_clause = clause + comma
-
-                tokens = process_text_chunk(full_clause)
-                count = len(tokens)
-
-                # If adding clause keeps us under max and not optimal yet
-                if (
-                    clause_count + count <= max_tokens
-                    and clause_count + count <= settings.target_max_tokens
-                ):
-                    clause_chunk.append(full_clause)
-                    clause_tokens.extend(tokens)
-                    clause_count += count
-                else:
-                    # Yield clause chunk if we have one
-                    if clause_chunk:
-                        chunk_text = " ".join(clause_chunk).strip()  # Strip after joining
-                        chunk_count += 1
-                        logger.debug(
-                            f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)"
-                        )
-                        yield chunk_text, clause_tokens
-                    clause_chunk = [full_clause]
-                    clause_tokens = tokens
-                    clause_count = count
-
-            # Don't forget last clause chunk
-            if clause_chunk:
-                chunk_text = " ".join(clause_chunk).strip()  # Strip after joining
-                chunk_count += 1
-                logger.debug(
-                    f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)"
-                )
-                yield chunk_text, clause_tokens
-
-        # Regular sentence handling
-        elif (
-            current_count >= settings.target_min_tokens
-            and current_count + count > settings.target_max_tokens
-        ):
-            # If we have a good sized chunk and adding next sentence exceeds target,
-            # yield current chunk and start new one
-            chunk_text = " ".join(current_chunk).strip()  # Strip after joining
-            chunk_count += 1
-            logger.info(
-                f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
-            )
-            yield chunk_text, current_tokens
-            current_chunk = [sentence]
-            current_tokens = tokens
-            current_count = count
-        elif current_count + count <= settings.target_max_tokens:
-            # Keep building chunk while under target max
-            current_chunk.append(sentence)
-            current_tokens.extend(tokens)
-            current_count += count
-        elif (
-            current_count + count <= max_tokens
-            and current_count < settings.target_min_tokens
-        ):
-            # Only exceed target max if we haven't reached minimum size yet
-            current_chunk.append(sentence)
-            current_tokens.extend(tokens)
-            current_count += count
-        else:
-            # Yield current chunk and start new one
-            if current_chunk:
-                chunk_text = " ".join(current_chunk).strip()  # Strip after joining
-                chunk_count += 1
-                logger.info(
-                    f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
-                )
-                yield chunk_text, current_tokens
-            current_chunk = [sentence]
-            current_tokens = tokens
-            current_count = count
-
-    # Don't forget the last chunk
-    if current_chunk:
-        chunk_text = " ".join(current_chunk).strip()  # Strip after joining
-        chunk_count += 1
-        logger.info(
-            f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
-        )
-        yield chunk_text, current_tokens
+    part_idx = 0
+    while part_idx < len(parts):
+        text_part_raw = parts[part_idx]  # This part is raw text
+        part_idx += 1
+
+        # --- Process Text Part ---
+        if text_part_raw and text_part_raw.strip():  # Only process if the part is not empty string
+            # Strip leading and trailing spaces to prevent pause tag splitting artifacts
+            text_part_raw = text_part_raw.strip()
+
+            # Apply the original smart_split logic to this text part
+            custom_phoneme_list = {}
+
+            # Normalize text (original logic)
+            processed_text = text_part_raw
+            if settings.advanced_text_normalization and normalization_options.normalize:
+                if lang_code in ["a", "b", "en-us", "en-gb"]:
+                    processed_text = CUSTOM_PHONEMES.sub(
+                        lambda s: handle_custom_phonemes(s, custom_phoneme_list), processed_text
+                    )
+                    processed_text = normalize_text(processed_text, normalization_options)
+                else:
+                    logger.info(
+                        "Skipping text normalization as it is only supported for english"
+                    )
+
+            # Process all sentences (original logic)
+            sentences = get_sentence_info(processed_text, custom_phoneme_list, lang_code=lang_code)
+
+            current_chunk = []
+            current_tokens = []
+            current_count = 0
+
+            for sentence, tokens, count in sentences:
+                # Handle sentences that exceed max tokens (original logic)
+                if count > max_tokens:
+                    # Yield current chunk if any
+                    if current_chunk:
+                        chunk_text = " ".join(current_chunk).strip()
+                        chunk_count += 1
+                        logger.debug(
+                            f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)"
+                        )
+                        yield chunk_text, current_tokens, None
+                        current_chunk = []
+                        current_tokens = []
+                        current_count = 0
+
+                    # Split long sentence on commas (original logic)
+                    clauses = re.split(r"([,])", sentence)
+                    clause_chunk = []
+                    clause_tokens = []
+                    clause_count = 0
+
+                    for j in range(0, len(clauses), 2):
+                        clause = clauses[j].strip()
+                        comma = clauses[j + 1] if j + 1 < len(clauses) else ""
+
+                        if not clause:
+                            continue
+
+                        full_clause = clause + comma
+
+                        tokens = process_text_chunk(full_clause)
+                        count = len(tokens)
+
+                        # If adding clause keeps us under max and not optimal yet
+                        if (
+                            clause_count + count <= max_tokens
+                            and clause_count + count <= settings.target_max_tokens
+                        ):
+                            clause_chunk.append(full_clause)
+                            clause_tokens.extend(tokens)
+                            clause_count += count
+                        else:
+                            # Yield clause chunk if we have one
+                            if clause_chunk:
+                                chunk_text = " ".join(clause_chunk).strip()
+                                chunk_count += 1
+                                logger.debug(
+                                    f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({clause_count} tokens)"
+                                )
+                                yield chunk_text, clause_tokens, None
+                            clause_chunk = [full_clause]
+                            clause_tokens = tokens
+                            clause_count = count
+
+                    # Don't forget last clause chunk
+                    if clause_chunk:
+                        chunk_text = " ".join(clause_chunk).strip()
+                        chunk_count += 1
+                        logger.debug(
+                            f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({clause_count} tokens)"
+                        )
+                        yield chunk_text, clause_tokens, None
+
+                # Regular sentence handling (original logic)
+                elif (
+                    current_count >= settings.target_min_tokens
+                    and current_count + count > settings.target_max_tokens
+                ):
+                    # If we have a good sized chunk and adding next sentence exceeds target,
+                    # yield current chunk and start new one
+                    chunk_text = " ".join(current_chunk).strip()
+                    chunk_count += 1
+                    logger.info(
+                        f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)"
+                    )
+                    yield chunk_text, current_tokens, None
+                    current_chunk = [sentence]
+                    current_tokens = tokens
+                    current_count = count
+                elif current_count + count <= settings.target_max_tokens:
+                    # Keep building chunk while under target max
+                    current_chunk.append(sentence)
+                    current_tokens.extend(tokens)
+                    current_count += count
+                elif (
+                    current_count + count <= max_tokens
+                    and current_count < settings.target_min_tokens
+                ):
+                    # Only exceed target max if we haven't reached minimum size yet
+                    current_chunk.append(sentence)
+                    current_tokens.extend(tokens)
+                    current_count += count
+                else:
+                    # Yield current chunk and start new one
+                    if current_chunk:
+                        chunk_text = " ".join(current_chunk).strip()
+                        chunk_count += 1
+                        logger.info(
+                            f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)"
+                        )
+                        yield chunk_text, current_tokens, None
+                    current_chunk = [sentence]
+                    current_tokens = tokens
+                    current_count = count
+
+            # Don't forget the last chunk for this text part
+            if current_chunk:
+                chunk_text = " ".join(current_chunk).strip()
+                chunk_count += 1
+                logger.info(
+                    f"Yielding final chunk {chunk_count} for part: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)"
+                )
+                yield chunk_text, current_tokens, None
+
+        # --- Handle Pause Part ---
+        # Check if the next part is a pause duration string
+        if part_idx < len(parts):
+            duration_str = parts[part_idx]
+            # Check if it looks like a valid number string captured by the regex group
+            if re.fullmatch(r"\d+(?:\.\d+)?", duration_str):
+                part_idx += 1  # Consume the duration string as it's been processed
+                try:
+                    duration = float(duration_str)
+                    if duration > 0:
+                        chunk_count += 1
+                        logger.info(f"Yielding pause chunk {chunk_count}: {duration}s")
+                        yield "", [], duration  # Yield pause chunk
+                except (ValueError, TypeError):
+                    # This case should be rare if re.fullmatch passed, but handle anyway
+                    logger.warning(f"Could not parse valid-looking pause duration: {duration_str}")
+
+    # --- End of parts loop ---
 
     total_time = time.time() - start_time
     logger.info(
-        f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks"
+        f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks (including pauses)"
     )
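
The while-loop above leans on a property of re.split worth spelling out: because PAUSE_TAG_PATTERN contains exactly one capturing group (the duration digits), re.split returns text parts alternating with the captured duration strings. A small sketch of that assumption:

    import re

    PAUSE_TAG_PATTERN = re.compile(r"\[pause:(\d+(?:\.\d+)?)s\]", re.IGNORECASE)
    parts = PAUSE_TAG_PATTERN.split("Hello world [pause:2.5s] How are you?")
    print(parts)  # ['Hello world ', '2.5', ' How are you?']

Adjacent tags produce empty text parts between the captured durations, which is why the text branch skips any part that is empty after stripping.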

View file

@@ -280,48 +280,88 @@ class TTSService:
                 f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream"
             )
 
-            # Process text in chunks with smart splitting
-            async for chunk_text, tokens in smart_split(
+            # Process text in chunks with smart splitting, handling pause tags
+            async for chunk_text, tokens, pause_duration_s in smart_split(
                 text,
                 lang_code=pipeline_lang_code,
                 normalization_options=normalization_options,
             ):
-                try:
-                    # Process audio for chunk
-                    async for chunk_data in self._process_chunk(
-                        chunk_text,  # Pass text for Kokoro V1
-                        tokens,  # Pass tokens for legacy backends
-                        voice_name,  # Pass voice name
-                        voice_path,  # Pass voice path
-                        speed,
-                        writer,
-                        output_format,
-                        is_first=(chunk_index == 0),
-                        is_last=False,  # We'll update the last chunk later
-                        normalizer=stream_normalizer,
-                        lang_code=pipeline_lang_code,  # Pass lang_code
-                        return_timestamps=return_timestamps,
-                    ):
-                        if chunk_data.word_timestamps is not None:
-                            for timestamp in chunk_data.word_timestamps:
-                                timestamp.start_time += current_offset
-                                timestamp.end_time += current_offset
-                        current_offset += len(chunk_data.audio) / 24000
-
-                        if chunk_data.output is not None:
-                            yield chunk_data
-                        else:
-                            logger.warning(
-                                f"No audio generated for chunk: '{chunk_text[:100]}...'"
-                            )
-                    chunk_index += 1
-                except Exception as e:
-                    logger.error(
-                        f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}"
-                    )
-                    continue
+                if pause_duration_s is not None and pause_duration_s > 0:
+                    # --- Handle Pause Chunk ---
+                    try:
+                        logger.debug(f"Generating {pause_duration_s}s silence chunk")
+                        silence_samples = int(pause_duration_s * 24000)  # 24kHz sample rate
+                        # Create proper silence as int16 zeros to avoid normalization artifacts
+                        silence_audio = np.zeros(silence_samples, dtype=np.int16)
+                        pause_chunk = AudioChunk(audio=silence_audio, word_timestamps=[])  # Empty timestamps for silence
+
+                        # Format and yield the silence chunk
+                        if output_format:
+                            formatted_pause_chunk = await AudioService.convert_audio(
+                                pause_chunk, output_format, writer, speed=speed, chunk_text="",
+                                is_last_chunk=False, trim_audio=False, normalizer=stream_normalizer,
+                            )
+                            if formatted_pause_chunk.output:
+                                yield formatted_pause_chunk
+                        else:  # Raw audio mode
+                            # For raw audio mode, silence is already in the correct format (int16)
+                            # Skip normalization to avoid any potential artifacts
+                            if len(pause_chunk.audio) > 0:
+                                yield pause_chunk
+
+                        # Update offset based on silence duration
+                        current_offset += pause_duration_s
+                        chunk_index += 1  # Count pause as a yielded chunk
+                    except Exception as e:
+                        logger.error(f"Failed to process pause chunk: {str(e)}")
+                        continue
+
+                elif tokens or chunk_text.strip():  # Process if there are tokens OR non-whitespace text
+                    # --- Handle Text Chunk ---
+                    try:
+                        # Process audio for chunk
+                        async for chunk_data in self._process_chunk(
+                            chunk_text,  # Pass text for Kokoro V1
+                            tokens,  # Pass tokens for legacy backends
+                            voice_name,  # Pass voice name
+                            voice_path,  # Pass voice path
+                            speed,
+                            writer,
+                            output_format,
+                            is_first=(chunk_index == 0),
+                            is_last=False,  # We'll update the last chunk later
+                            normalizer=stream_normalizer,
+                            lang_code=pipeline_lang_code,  # Pass lang_code
+                            return_timestamps=return_timestamps,
+                        ):
+                            if chunk_data.word_timestamps is not None:
+                                for timestamp in chunk_data.word_timestamps:
+                                    timestamp.start_time += current_offset
+                                    timestamp.end_time += current_offset
+
+                            # Update offset based on the actual duration of the generated audio chunk
+                            chunk_duration = 0
+                            if chunk_data.audio is not None and len(chunk_data.audio) > 0:
+                                chunk_duration = len(chunk_data.audio) / 24000
+                            current_offset += chunk_duration
+
+                            # Yield the processed chunk (either formatted or raw)
+                            if chunk_data.output is not None:
+                                yield chunk_data
+                            elif chunk_data.audio is not None and len(chunk_data.audio) > 0:
+                                yield chunk_data
+                            else:
+                                logger.warning(
+                                    f"No audio generated for chunk: '{chunk_text[:100]}...'"
+                                )
+                        chunk_index += 1  # Increment chunk index after processing text
+                    except Exception as e:
+                        logger.error(
+                            f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}"
+                        )
+                        continue
 
             # Only finalize if we successfully processed at least one chunk
             if chunk_index > 0:
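
The silence math in the pause branch is plain sample arithmetic at the service's fixed 24 kHz rate; a quick sketch (the numbers are illustrative):

    import numpy as np

    SAMPLE_RATE = 24000  # matches the 24kHz assumption in the service code
    pause_duration_s = 0.5
    silence = np.zeros(int(pause_duration_s * SAMPLE_RATE), dtype=np.int16)
    assert len(silence) == 12000  # 0.5s of int16 silence

Generating the zeros as int16 rather than float also means the silence already matches the raw-audio output format, which is the artifact-avoidance point noted in the diff's comments.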

View file

@@ -67,7 +67,7 @@ async def test_smart_split_short_text():
     """Test smart splitting with text under max tokens."""
     text = "This is a short test sentence."
     chunks = []
-    async for chunk_text, chunk_tokens in smart_split(text):
+    async for chunk_text, chunk_tokens, _ in smart_split(text):
         chunks.append((chunk_text, chunk_tokens))
 
     assert len(chunks) == 1
@@ -82,7 +82,7 @@ async def test_smart_split_long_text():
     text = ". ".join(["This is test sentence number " + str(i) for i in range(20)])
 
     chunks = []
-    async for chunk_text, chunk_tokens in smart_split(text):
+    async for chunk_text, chunk_tokens, _ in smart_split(text):
         chunks.append((chunk_text, chunk_tokens))
 
     assert len(chunks) > 1
@@ -98,12 +98,13 @@ async def test_smart_split_with_punctuation():
     text = "First sentence! Second sentence? Third sentence; Fourth sentence: Fifth sentence."
 
     chunks = []
-    async for chunk_text, chunk_tokens in smart_split(text):
+    async for chunk_text, chunk_tokens, _ in smart_split(text):
         chunks.append(chunk_text)
 
     # Verify punctuation is preserved
     assert all(any(p in chunk for p in "!?;:.") for chunk in chunks)
 
+
 def test_process_text_chunk_chinese_phonemes():
     """Test processing with Chinese pinyin phonemes."""
     pinyin = "nǐ hǎo lì"  # Example pinyin sequence with tones
@@ -125,12 +126,13 @@ def test_get_sentence_info_chinese():
         assert count == len(tokens)
         assert count > 0
 
+
 @pytest.mark.asyncio
 async def test_smart_split_chinese_short():
     """Test Chinese smart splitting with short text."""
     text = "这是一句话。"
 
     chunks = []
-    async for chunk_text, chunk_tokens in smart_split(text, lang_code="z"):
+    async for chunk_text, chunk_tokens, _ in smart_split(text, lang_code="z"):
         chunks.append((chunk_text, chunk_tokens))
 
     assert len(chunks) == 1
@@ -144,7 +146,7 @@ async def test_smart_split_chinese_long():
     text = "".join([f"测试句子 {i}" for i in range(20)])
 
     chunks = []
-    async for chunk_text, chunk_tokens in smart_split(text, lang_code="z"):
+    async for chunk_text, chunk_tokens, _ in smart_split(text, lang_code="z"):
         chunks.append((chunk_text, chunk_tokens))
 
     assert len(chunks) > 1
@@ -160,8 +162,36 @@ async def test_smart_split_chinese_punctuation():
     text = "第一句!第二问?第三句;第四句:第五句。"
 
     chunks = []
-    async for chunk_text, _ in smart_split(text, lang_code="z"):
+    async for chunk_text, _, _ in smart_split(text, lang_code="z"):
         chunks.append(chunk_text)
 
     # Verify Chinese punctuation is preserved
     assert all(any(p in chunk for p in "!?;:。") for chunk in chunks)
+
+
+@pytest.mark.asyncio
+async def test_smart_split_with_pause():
+    """Test smart splitting with pause tags."""
+    text = "Hello world [pause:2.5s] How are you?"
+
+    chunks = []
+    async for chunk_text, chunk_tokens, pause_duration in smart_split(text):
+        chunks.append((chunk_text, chunk_tokens, pause_duration))
+
+    # Should have 3 chunks: text, pause, text
+    assert len(chunks) == 3
+
+    # First chunk: text
+    assert chunks[0][2] is None  # No pause
+    assert "Hello world" in chunks[0][0]
+    assert len(chunks[0][1]) > 0
+
+    # Second chunk: pause
+    assert chunks[1][2] == 2.5  # 2.5 second pause
+    assert chunks[1][0] == ""  # Empty text
+    assert len(chunks[1][1]) == 0  # No tokens
+
+    # Third chunk: text
+    assert chunks[2][2] is None  # No pause
+    assert "How are you?" in chunks[2][0]
+    assert len(chunks[2][1]) > 0
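
One edge case the new test does not cover is back-to-back pause tags. Given the split behavior sketched earlier, a hypothetical companion test (not part of this commit) could look like:

    @pytest.mark.asyncio
    async def test_smart_split_adjacent_pauses():
        """Hypothetical: adjacent pause tags should yield consecutive pause chunks."""
        text = "Before [pause:1s][pause:0.5s] after."
        durations = []
        async for chunk_text, _, pause_duration in smart_split(text):
            if pause_duration is not None:
                durations.append(pause_duration)
        # The empty text part between the two tags is skipped, so only
        # the two pause chunks should carry durations.
        assert durations == [1.0, 0.5]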