From 355ec54f78d66f55ba79b948dc2297f286df0fc8 Mon Sep 17 00:00:00 2001
From: remsky
Date: Tue, 28 Jan 2025 20:41:57 -0700
Subject: [PATCH] Refactor TTS text processing: rewrite smart_split around
 sentence boundaries and retire the legacy chunker

smart_split in text_processor.py now pre-tokenizes whole sentences and
builds chunks targeting 300-400 tokens, falling back to comma splits only
when a single sentence exceeds the hard cap. The superseded chunker.py
module is commented out and its semchunk_slim dependency is dropped;
tts_service.py gains a clarifying comment.

---
 api/src/services/text_processing/chunker.py  |  93 +++-----
 .../text_processing/text_processor.py        | 199 ++++++++++--------
 api/src/services/tts_service.py              |   2 +-
 3 files changed, 136 insertions(+), 158 deletions(-)

diff --git a/api/src/services/text_processing/chunker.py b/api/src/services/text_processing/chunker.py
index 5ffbc00..80e67f0 100644
--- a/api/src/services/text_processing/chunker.py
+++ b/api/src/services/text_processing/chunker.py
@@ -1,74 +1,31 @@
-"""Text chunking module for TTS processing"""
+# """Text chunking module for TTS processing"""
 
-from typing import List, AsyncGenerator
-from . import semchunk_slim
+# from typing import List, AsyncGenerator
 
-async def fallback_split(text: str, max_chars: int = 400) -> List[str]:
-    """Emergency length control - only used if chunks are too long"""
-    words = text.split()
-    chunks = []
-    current = []
-    current_len = 0
+# async def fallback_split(text: str, max_chars: int = 400) -> List[str]:
+#     """Emergency length control - only used if chunks are too long"""
+#     words = text.split()
+#     chunks = []
+#     current = []
+#     current_len = 0
 
-    for word in words:
-        # Always include at least one word per chunk
-        if not current:
-            current.append(word)
-            current_len = len(word)
-            continue
+#     for word in words:
+#         # Always include at least one word per chunk
+#         if not current:
+#             current.append(word)
+#             current_len = len(word)
+#             continue
 
-        # Check if adding word would exceed limit
-        if current_len + len(word) + 1 <= max_chars:
-            current.append(word)
-            current_len += len(word) + 1
-        else:
-            chunks.append(" ".join(current))
-            current = [word]
-            current_len = len(word)
+#         # Check if adding word would exceed limit
+#         if current_len + len(word) + 1 <= max_chars:
+#             current.append(word)
+#             current_len += len(word) + 1
+#         else:
+#             chunks.append(" ".join(current))
+#             current = [word]
+#             current_len = len(word)
 
-    if current:
-        chunks.append(" ".join(current))
+#     if current:
+#         chunks.append(" ".join(current))
 
-    return chunks
-
-async def split_text(text: str, max_chunk: int = None) -> AsyncGenerator[str, None]:
-    """Split text into TTS-friendly chunks
-
-    Args:
-        text: Text to split into chunks
-        max_chunk: Maximum chunk size (defaults to 400)
-
-    Yields:
-        Text chunks suitable for TTS processing
-    """
-    if max_chunk is None:
-        max_chunk = 400
-
-    if not isinstance(text, str):
-        text = str(text) if text is not None else ""
-
-    text = text.strip()
-    if not text:
-        return
-
-    # Initialize chunker targeting ~300 chars to allow for expansion
-    chunker = semchunk_slim.chunkerify(
-        lambda t: len(t) // 5,  # Simple length-based target
-        chunk_size=60  # Target ~300 chars
-    )
-
-    # Get initial chunks
-    chunks = chunker(text)
-
-    # Process chunks
-    for chunk in chunks:
-        chunk = chunk.strip()
-        if not chunk:
-            continue
-
-        # Use fallback for any chunks that are too long
-        if len(chunk) > max_chunk:
-            for subchunk in await fallback_split(chunk, max_chunk):
-                yield subchunk
-        else:
-            yield chunk
+#     return chunks

diff --git a/api/src/services/text_processing/text_processor.py b/api/src/services/text_processing/text_processor.py
index 8aa0055..74ae2c4 100644
--- a/api/src/services/text_processing/text_processor.py
+++ b/api/src/services/text_processing/text_processor.py
@@ -8,6 +8,11 @@ from .phonemizer import phonemize
 from .normalizer import normalize_text
 from .vocabulary import tokenize
+
+# Chunk size targets (in tokens): aim for 300-400 per chunk, hard cap at 500
+TARGET_MIN_TOKENS = 300
+TARGET_MAX_TOKENS = 400
+ABSOLUTE_MAX_TOKENS = 500
 
 def process_text_chunk(text: str, language: str = "a") -> List[int]:
     """Process a chunk of text through normalization, phonemization, and tokenization.
 
@@ -62,116 +67,113 @@ def process_text(text: str, language: str = "a") -> List[int]:
     return process_text_chunk(text, language)
 
-async def smart_split(text: str, max_tokens: int = 500) -> AsyncGenerator[Tuple[str, List[int]], None]:
-    """Split text into semantically meaningful chunks while respecting token limits.
-
-    Args:
-        text: Input text to split
-        max_tokens: Maximum tokens per chunk
-
-    Yields:
-        Tuples of (text chunk, token IDs) where token count is <= max_tokens
-    """
-    start_time = time.time()
-    chunk_count = 0
-    total_chars = len(text)
-    logger.info(f"Starting text split for {total_chars} characters with {max_tokens} max tokens")
-
-    # Split on major punctuation first
+
+def get_sentence_info(text: str) -> List[Tuple[str, List[int], int]]:
+    """Split text into sentences and return (text, token IDs, token count) tuples."""
     sentences = re.split(r'([.!?;:])', text)
-
-    current_chunk = []
-    current_token_count = 0
+    results = []
 
     for i in range(0, len(sentences), 2):
-        # Get sentence and its punctuation (if any)
         sentence = sentences[i].strip()
         punct = sentences[i + 1] if i + 1 < len(sentences) else ""
 
         if not sentence:
            continue
 
-        # Process sentence to get token count
-        sentence_with_punct = sentence + punct
-        tokens = process_text_chunk(sentence_with_punct)
-        token_count = len(tokens)
-        logger.debug(f"Sentence '{sentence_with_punct[:50]}...' has {token_count} tokens")
+        full = sentence + punct
+        tokens = process_text_chunk(full)
+        results.append((full, tokens, len(tokens)))
 
-        # If this single sentence is too long, split on commas
-        if token_count > max_tokens:
-            logger.debug(f"Sentence exceeds token limit, splitting on commas")
-            clause_splits = re.split(r'([,])', sentence_with_punct)
-            for j in range(0, len(clause_splits), 2):
-                clause = clause_splits[j].strip()
-                comma = clause_splits[j + 1] if j + 1 < len(clause_splits) else ""
+    return results
+
+async def smart_split(text: str, max_tokens: int = ABSOLUTE_MAX_TOKENS) -> AsyncGenerator[Tuple[str, List[int]], None]:
+    """Build chunks of roughly 300-400 tokens, splitting on sentence
+    boundaries first and commas second, within the max_tokens cap."""
+    start_time = time.time()
+    chunk_count = 0
+    logger.info(f"Starting smart split for {len(text)} chars")
+
+    # Pre-tokenize every sentence once
+    sentences = get_sentence_info(text)
+
+    current_chunk = []
+    current_tokens = []
+    current_count = 0
+
+    for sentence, tokens, count in sentences:
+        # Handle sentences that exceed max tokens
+        if count > max_tokens:
+            # Yield the current chunk, if any, before splitting on commas
+            if current_chunk:
+                chunk_text = " ".join(current_chunk)
+                chunk_count += 1
+                logger.info(f"Yielding chunk {chunk_count}: '{chunk_text[:50]}...' ({current_count} tokens)")
+                yield chunk_text, current_tokens
+                current_chunk = []
+                current_tokens = []
+                current_count = 0
+
+            # Split long sentence on commas
+            clauses = re.split(r'([,])', sentence)
+            clause_chunk = []
+            clause_tokens = []
+            clause_count = 0
+
+            for j in range(0, len(clauses), 2):
+                clause = clauses[j].strip()
+                comma = clauses[j + 1] if j + 1 < len(clauses) else ""
 
                 if not clause:
                     continue
 
-                clause_with_punct = clause + comma
-                clause_tokens = process_text_chunk(clause_with_punct)
+                full_clause = clause + comma
+                clause_ids = process_text_chunk(full_clause)
+                clause_len = len(clause_ids)
 
-                # If still too long, do a hard split on words
-                if len(clause_tokens) > max_tokens:
-                    logger.debug(f"Clause exceeds token limit, splitting on words")
-                    words = clause_with_punct.split()
-                    temp_chunk = []
-                    temp_tokens = []
-
-                    for word in words:
-                        word_tokens = process_text_chunk(word)
-                        if len(temp_tokens) + len(word_tokens) > max_tokens:
-                            if temp_chunk:  # Don't yield empty chunks
-                                chunk_text = " ".join(temp_chunk)
-                                chunk_count += 1
-                                logger.info(f"Yielding word-split chunk {chunk_count}: '{chunk_text[:50]}...' ({len(temp_tokens)} tokens)")
-                                yield chunk_text, temp_tokens
-                            temp_chunk = [word]
-                            temp_tokens = word_tokens
-                        else:
-                            temp_chunk.append(word)
-                            temp_tokens.extend(word_tokens)
-
-                    if temp_chunk:  # Don't forget the last chunk
-                        chunk_text = " ".join(temp_chunk)
-                        chunk_count += 1
-                        logger.info(f"Yielding final word-split chunk {chunk_count}: '{chunk_text[:50]}...' ({len(temp_tokens)} tokens)")
-                        yield chunk_text, temp_tokens
-
+                # Keep adding clauses while the chunk stays inside both limits
+                if clause_count + clause_len <= min(max_tokens, TARGET_MAX_TOKENS):
+                    clause_chunk.append(full_clause)
+                    clause_tokens.extend(clause_ids)
+                    clause_count += clause_len
                 else:
-                    # Check if adding this clause would exceed the limit
-                    if current_token_count + len(clause_tokens) > max_tokens:
-                        if current_chunk:  # Don't yield empty chunks
-                            chunk_text = " ".join(current_chunk)
-                            chunk_count += 1
-                            logger.info(f"Yielding clause-split chunk {chunk_count}: '{chunk_text[:50]}...' ({current_token_count} tokens)")
-                            yield chunk_text, process_text_chunk(chunk_text)
-                        current_chunk = [clause_with_punct]
-                        current_token_count = len(clause_tokens)
-                    else:
-                        current_chunk.append(clause_with_punct)
-                        current_token_count += len(clause_tokens)
-
+                    # Yield the accumulated clause chunk, if any
+                    if clause_chunk:
+                        chunk_text = " ".join(clause_chunk)
+                        chunk_count += 1
+                        logger.info(f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}...' ({clause_count} tokens)")
+                        yield chunk_text, clause_tokens
+                    # Note: a single comma-free clause longer than max_tokens
+                    # still becomes its own oversized chunk here
+                    clause_chunk = [full_clause]
+                    clause_tokens = clause_ids
+                    clause_count = clause_len
+
+            # Don't forget the last clause chunk
+            if clause_chunk:
+                chunk_text = " ".join(clause_chunk)
+                chunk_count += 1
+                logger.info(f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}...' ({clause_count} tokens)")
+                yield chunk_text, clause_tokens
+
+        # Regular sentence handling: grow the chunk while it still fits
+        # (we aim for TARGET_MAX_TOKENS but accept anything under max_tokens)
+        elif current_count + count <= max_tokens:
+            current_chunk.append(sentence)
+            current_tokens.extend(tokens)
+            current_count += count
         else:
-            # Check if adding this sentence would exceed the limit
-            if current_token_count + token_count > max_tokens:
-                if current_chunk:  # Don't yield empty chunks
-                    chunk_text = " ".join(current_chunk)
-                    chunk_count += 1
-                    logger.info(f"Yielding sentence-split chunk {chunk_count}: '{chunk_text[:50]}...' ({current_token_count} tokens)")
-                    yield chunk_text, process_text_chunk(chunk_text)
-                current_chunk = [sentence_with_punct]
-                current_token_count = token_count
-            else:
-                current_chunk.append(sentence_with_punct)
-                current_token_count += token_count
+            # Chunk is full: yield it and start a new one with this sentence
+            if current_chunk:
+                chunk_text = " ".join(current_chunk)
+                chunk_count += 1
+                logger.info(f"Yielding chunk {chunk_count}: '{chunk_text[:50]}...' ({current_count} tokens)")
+                yield chunk_text, current_tokens
+            current_chunk = [sentence]
+            current_tokens = tokens
+            current_count = count
 
     # Don't forget the last chunk
     if current_chunk:
         chunk_text = " ".join(current_chunk)
         chunk_count += 1
-        logger.info(f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}...' ({current_token_count} tokens)")
-        yield chunk_text, process_text_chunk(chunk_text)
+        logger.info(f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}...' ({current_count} tokens)")
+        yield chunk_text, current_tokens
 
     total_time = time.time() - start_time
-    logger.info(f"Text splitting completed in {total_time*1000:.2f}ms, produced {chunk_count} chunks")
+    logger.info(f"Split completed in {total_time*1000:.2f}ms, produced {chunk_count} chunks")

diff --git a/api/src/services/tts_service.py b/api/src/services/tts_service.py
index e6ef24d..5fbffd5 100644
--- a/api/src/services/tts_service.py
+++ b/api/src/services/tts_service.py
@@ -119,7 +119,7 @@ class TTSService:
         try:
             # Process audio for chunk
             result = await self._process_chunk(
-                tokens,
+                tokens,  # Now always a flat List[int]
                 voice_tensor,
                 speed,
                 output_format,
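
Two illustrative sketches for reviewers follow; neither is part of the patch.

First, get_sentence_info walks the re.split result with a stride of two
because the capturing group makes re.split return each delimiter as an
odd-indexed element. A standalone example (the sample string is invented):

    import re

    # With a capturing group in the pattern, re.split interleaves the text
    # segments with the punctuation that split them.
    parts = re.split(r'([.!?;:])', "First. Second! Third")
    # parts == ['First', '.', ' Second', '!', ' Third']

    for i in range(0, len(parts), 2):
        sentence = parts[i].strip()
        punct = parts[i + 1] if i + 1 < len(parts) else ""
        if sentence:
            print(sentence + punct)  # -> "First." "Second!" "Third"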
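
Second, a minimal consumer sketch for the new smart_split. The import path
and demo text are assumptions; the signature and the yielded pairs match
the patch:

    import asyncio

    # Hypothetical import path; adjust to wherever the package actually lives.
    from api.src.services.text_processing.text_processor import smart_split

    async def demo() -> None:
        text = "Hello there. This is a longer sentence, with clauses, to chunk. Done!"
        async for chunk_text, tokens in smart_split(text, max_tokens=500):
            # Each pair is (chunk text, flat token ID list), ready to pass
            # straight to TTSService._process_chunk.
            print(f"{len(tokens):4d} tokens | {chunk_text}")

    if __name__ == "__main__":
        asyncio.run(demo())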