Refactor smart_split to improve handling of custom phonemes and normalization: phoneme tags are replaced with IDs only when normalization is applied, and yielded chunks have the original tags restored. Improve logging and error handling, ensure downstream processing handles both the ID and original tag formats, and streamline the text processing logic for performance and maintainability.

Lukin 2025-04-08 12:00:27 +08:00
parent e42f7fcb67
commit 3e23fb0cf0

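Before the diff, a minimal sketch of the tag-to-ID round trip this refactor relies on. The regex and helper names below are illustrative stand-ins (the module's real CUSTOM_PHONEMES pattern and handle_custom_phonemes differ in detail); only the ID format mirrors the code under review.

    import re
    from typing import Dict

    # Illustrative stand-in for the module's CUSTOM_PHONEMES pattern: matches [word](/ipa/)
    CUSTOM_PHONEMES_DEMO = re.compile(r"\[([^\]]+)\]\(\/([^\/)]+)\/\)")

    def replace_with_ids(text: str, phoneme_map: Dict[str, str]) -> str:
        # Mirrors handle_custom_phonemes: each tag is swapped for a stable ID
        def _sub(m: re.Match) -> str:
            tag_id = f"</|custom_phonemes_{len(phoneme_map)}|/>"
            phoneme_map[tag_id] = m.group(0).strip()
            return tag_id
        return CUSTOM_PHONEMES_DEMO.sub(_sub, text)

    def restore_tags(text: str, phoneme_map: Dict[str, str]) -> str:
        # Inverse step applied to each chunk before it is yielded
        for tag_id, original_tag in phoneme_map.items():
            text = text.replace(tag_id, original_tag)
        return text

    phoneme_map: Dict[str, str] = {}
    with_ids = replace_with_ids("Say [kokoro](/kˈOkəɹO/) please.", phoneme_map)
    assert restore_tags(with_ids, phoneme_map) == "Say [kokoro](/kˈOkəɹO/) please."

As the diff's comments suggest, the IDs exist only to shield the tags while normalize_text runs; when normalization is skipped the original tags pass through untouched.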

@@ -156,12 +156,13 @@ def get_sentence_info(
 def handle_custom_phonemes(s: re.Match[str], phenomes_list: Dict[str, str]) -> str:
+    # Stores the *original tag* like "[word](/ipa/)" mapped to the ID
+    original_tag = s.group(0).strip()
     latest_id = f"</|custom_phonemes_{len(phenomes_list)}|/>"
-    phenomes_list[latest_id] = s.group(0).strip() # Store the full original tag [phoneme](/ipa/)
-    logger.debug(f"Replacing custom phoneme {phenomes_list[latest_id]} with ID {latest_id}")
+    phenomes_list[latest_id] = original_tag
+    logger.debug(f"Replacing custom phoneme tag '{original_tag}' with ID {latest_id}")
     return latest_id
 async def smart_split(
     text: str,
     max_tokens: int = settings.absolute_max_tokens,
@@ -173,75 +174,89 @@ async def smart_split(
     Yields:
         Tuple of (text_chunk, tokens, pause_duration_s).
         If pause_duration_s is not None, it's a pause chunk with empty text/tokens.
-        Otherwise, it's a text chunk. text_chunk may end with '\n'.
+        Otherwise, it's a text chunk containing original formatting (incl. custom phoneme tags).
     """
     start_time = time.time()
     chunk_count = 0
-    logger.info(f"Starting smart split for {len(text)} chars, max_tokens={max_tokens}")
+    logger.info(f"Starting smart split for {len(text)} chars, max_tokens={max_tokens}, lang_code={lang_code}")
-    custom_phoneme_list = {}
-    # 1. Temporarily replace custom phonemes like [word](/ipa/) with unique IDs
-    text_with_ids = CUSTOM_PHONEMES.sub(
-        lambda s: handle_custom_phonemes(s, custom_phoneme_list), text
+    # --- Determine if normalization and ID replacement are needed ---
+    apply_normalization = (
+        settings.advanced_text_normalization
+        and normalization_options.normalize
+        and lang_code in ["a", "b", "en-us", "en-gb"] # Normalization only for English
     )
-    if custom_phoneme_list:
-        logger.debug(f"Found custom phonemes: {custom_phoneme_list}")
+    use_ids = apply_normalization # Only use IDs if we are normalizing
+    logger.debug(f"Normalization active: {apply_normalization}. Using ID replacement: {use_ids}")
+    custom_phoneme_map = {} # Map ID -> Original Tag OR empty if use_ids is False
+    processed_text = text # Start with original text
-    # 2. Normalize the text *with IDs* if required
-    normalized_text = text_with_ids
-    if settings.advanced_text_normalization and normalization_options.normalize:
-        if lang_code in ["a", "b", "en-us", "en-gb"]:
-            normalized_text = normalize_text(normalized_text, normalization_options)
-            logger.debug("Applied text normalization.")
-        else:
-            logger.info(
-                "Skipping text normalization as it is only supported for english"
-            )
+    # --- Step 1: Optional ID Replacement ---
+    if use_ids:
+        processed_text = CUSTOM_PHONEMES.sub(
+            lambda s: handle_custom_phonemes(s, custom_phoneme_map), text
+        )
+        if custom_phoneme_map:
+            logger.debug(f"Found and replaced custom phonemes with IDs: {custom_phoneme_map}")
-    # 3. Split the normalized text by pause tags
-    parts = PAUSE_TAG_PATTERN.split(normalized_text)
+    # --- Step 2: Optional Normalization ---
+    if apply_normalization:
+        processed_text = normalize_text(processed_text, normalization_options)
+        logger.debug("Applied text normalization.")
+    # --- Step 3: Split by Pause Tags ---
+    # This operates on `processed_text` which either has IDs or original tags
+    parts = PAUSE_TAG_PATTERN.split(processed_text)
     logger.debug(f"Split into {len(parts)} parts by pause tags.")
     part_idx = 0
     while part_idx < len(parts):
-        text_part = parts[part_idx] # This part contains text and custom phoneme IDs
+        text_part = parts[part_idx] # This part contains text (with IDs or original tags)
         part_idx += 1
         if text_part:
-            # Process this text part using sentence splitting
-            # We pass the text_part *with IDs* to get_sentence_info
-            # get_sentence_info will handle restoring phonemes just before tokenization
-            sentences = get_sentence_info(text_part, custom_phoneme_list)
+            # --- Process Text Part ---
+            # get_sentence_info MUST be able to handle BOTH inputs with IDs (using custom_phoneme_map)
+            # AND inputs with original [word](/ipa/) tags (when custom_phoneme_map is empty)
+            # It needs to extract IPA phonemes correctly in both cases for tokenization.
+            # Crucially, it should return the *original format* sentence text (with IDs or tags)
+            try:
+                sentences = get_sentence_info(text_part, custom_phoneme_map)
+            except Exception as e:
+                logger.error(f"get_sentence_info failed for part '{text_part[:50]}...': {e}", exc_info=True)
+                continue # Skip this part if sentence processing fails
-            current_chunk_texts = [] # Store original sentence texts for the current chunk
+            current_chunk_texts = [] # Store original format sentence texts for the current chunk
             current_chunk_tokens = []
            current_token_count = 0
-            for sentence_text, sentence_tokens, sentence_token_count in sentences:
-                # --- Chunking Logic ---
+            for sentence_text_original_format, sentence_tokens, sentence_token_count in sentences:
+                # --- Chunking Logic (remains the same) ---
                 # Condition 1: Current sentence alone exceeds max tokens
                 if sentence_token_count > max_tokens:
-                    logger.warning(f"Single sentence exceeds max_tokens ({sentence_token_count} > {max_tokens}): '{sentence_text[:50]}...'")
+                    logger.warning(f"Single sentence exceeds max_tokens ({sentence_token_count} > {max_tokens}): '{sentence_text_original_format[:50]}...'")
                     # Yield any existing chunk first
                     if current_chunk_texts:
-                        chunk_text_joined = " ".join(current_chunk_texts) # Join original texts
+                        chunk_text_to_yield = " ".join(current_chunk_texts)
+                        # Restore original tags IF we used IDs
+                        if use_ids:
+                            for p_id, original_tag_val in custom_phoneme_map.items():
+                                chunk_text_to_yield = chunk_text_to_yield.replace(p_id, original_tag_val)
                         chunk_count += 1
-                        logger.info(f"Yielding text chunk {chunk_count} (before oversized sentence): '{chunk_text_joined[:50]}...' ({current_token_count} tokens)")
-                        yield chunk_text_joined, current_chunk_tokens, None
+                        logger.info(f"Yielding text chunk {chunk_count} (before oversized): '{chunk_text_to_yield[:50]}...' ({current_token_count} tokens)")
+                        yield chunk_text_to_yield, current_chunk_tokens, None
                         current_chunk_texts = []
                         current_chunk_tokens = []
                        current_token_count = 0
                     # Yield the oversized sentence as its own chunk
-                    # Restore phonemes before yielding the text
-                    text_to_yield = sentence_text
-                    for p_id, p_val in custom_phoneme_list.items():
-                        if p_id in text_to_yield:
-                            text_to_yield = text_to_yield.replace(p_id, p_val)
+                    text_to_yield = sentence_text_original_format
+                    # Restore original tags IF we used IDs
+                    if use_ids:
+                        for p_id, original_tag_val in custom_phoneme_map.items():
+                            text_to_yield = text_to_yield.replace(p_id, original_tag_val)
                     chunk_count += 1
                     logger.info(f"Yielding oversized text chunk {chunk_count}: '{text_to_yield[:50]}...' ({sentence_token_count} tokens)")
@@ -252,12 +267,15 @@ async def smart_split(
                 elif current_token_count + sentence_token_count > max_tokens:
                     # Yield the current chunk first
                     if current_chunk_texts:
-                        chunk_text_joined = " ".join(current_chunk_texts) # Join original texts
+                        chunk_text_to_yield = " ".join(current_chunk_texts)
+                        if use_ids:
+                            for p_id, original_tag_val in custom_phoneme_map.items():
+                                chunk_text_to_yield = chunk_text_to_yield.replace(p_id, original_tag_val)
                         chunk_count += 1
-                        logger.info(f"Yielding text chunk {chunk_count} (max_tokens limit): '{chunk_text_joined[:50]}...' ({current_token_count} tokens)")
-                        yield chunk_text_joined, current_chunk_tokens, None
+                        logger.info(f"Yielding text chunk {chunk_count} (max_tokens limit): '{chunk_text_to_yield[:50]}...' ({current_token_count} tokens)")
+                        yield chunk_text_to_yield, current_chunk_tokens, None
                     # Start a new chunk with the current sentence
-                    current_chunk_texts = [sentence_text]
+                    current_chunk_texts = [sentence_text_original_format]
                     current_chunk_tokens = sentence_tokens
                     current_token_count = sentence_token_count
@@ -265,18 +283,21 @@ async def smart_split(
                 elif (current_token_count >= settings.target_min_tokens and
                       current_token_count + sentence_token_count > settings.target_max_tokens):
                     # Yield the current chunk
-                    chunk_text_joined = " ".join(current_chunk_texts) # Join original texts
+                    chunk_text_to_yield = " ".join(current_chunk_texts)
+                    if use_ids:
+                        for p_id, original_tag_val in custom_phoneme_map.items():
+                            chunk_text_to_yield = chunk_text_to_yield.replace(p_id, original_tag_val)
                     chunk_count += 1
-                    logger.info(f"Yielding text chunk {chunk_count} (target_max limit): '{chunk_text_joined[:50]}...' ({current_token_count} tokens)")
-                    yield chunk_text_joined, current_chunk_tokens, None
+                    logger.info(f"Yielding text chunk {chunk_count} (target_max limit): '{chunk_text_to_yield[:50]}...' ({current_token_count} tokens)")
+                    yield chunk_text_to_yield, current_chunk_tokens, None
                     # Start a new chunk
-                    current_chunk_texts = [sentence_text]
+                    current_chunk_texts = [sentence_text_original_format]
                     current_chunk_tokens = sentence_tokens
                     current_token_count = sentence_token_count
-                # Condition 4: Add sentence to current chunk (fits within max_tokens and either below target_max or below target_min)
+                # Condition 4: Add sentence to current chunk
                 else:
-                    current_chunk_texts.append(sentence_text)
+                    current_chunk_texts.append(sentence_text_original_format)
                     current_chunk_tokens.extend(sentence_tokens)
                     current_token_count += sentence_token_count
@@ -284,18 +305,17 @@ async def smart_split(
             # Yield any remaining accumulated chunk for this text part
             if current_chunk_texts:
-                chunk_text_joined = " ".join(current_chunk_texts) # Join original texts
-                # Restore phonemes before yielding
-                text_to_yield = chunk_text_joined
-                for p_id, p_val in custom_phoneme_list.items():
-                    if p_id in text_to_yield:
-                        text_to_yield = text_to_yield.replace(p_id, p_val)
+                chunk_text_to_yield = " ".join(current_chunk_texts)
+                # Restore original tags IF we used IDs
+                if use_ids:
+                    for p_id, original_tag_val in custom_phoneme_map.items():
+                        chunk_text_to_yield = chunk_text_to_yield.replace(p_id, original_tag_val)
                 chunk_count += 1
-                logger.info(f"Yielding final text chunk {chunk_count} for part: '{text_to_yield[:50]}...' ({current_token_count} tokens)")
-                yield text_to_yield, current_chunk_tokens, None
+                logger.info(f"Yielding final text chunk {chunk_count} for part: '{chunk_text_to_yield[:50]}...' ({current_token_count} tokens)")
+                yield chunk_text_to_yield, current_chunk_tokens, None
         # --- Handle Pause Part ---
         # Check if the next part is a pause duration
         if part_idx < len(parts):
             duration_str = parts[part_idx]
@@ -308,10 +328,9 @@ async def smart_split(
yield "", [], duration # Yield pause chunk
except (ValueError, TypeError):
logger.warning(f"Could not parse pause duration: {duration_str}")
# If parsing fails, potentially treat the duration_str as text?
# For now, just log a warning and skip.
# Treat as text if parsing fails? For now, just log and skip.
# --- End of parts loop ---
total_time = time.time() - start_time
logger.info(
f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks (including pauses)"