Refactor TTS service and chunker: update comments and remove unused code

remsky 2025-01-28 20:41:57 -07:00
parent b25ba5e7e6
commit 355ec54f78
3 changed files with 136 additions and 158 deletions

View file

@@ -1,74 +1,31 @@
-"""Text chunking module for TTS processing"""
-from typing import List, AsyncGenerator
-
-from . import semchunk_slim
-
-
-async def fallback_split(text: str, max_chars: int = 400) -> List[str]:
-    """Emergency length control - only used if chunks are too long"""
-    words = text.split()
-    chunks = []
-    current = []
-    current_len = 0
-    for word in words:
-        # Always include at least one word per chunk
-        if not current:
-            current.append(word)
-            current_len = len(word)
-            continue
-        # Check if adding word would exceed limit
-        if current_len + len(word) + 1 <= max_chars:
-            current.append(word)
-            current_len += len(word) + 1
-        else:
-            chunks.append(" ".join(current))
-            current = [word]
-            current_len = len(word)
-    if current:
-        chunks.append(" ".join(current))
-    return chunks
-
-
-async def split_text(text: str, max_chunk: int = None) -> AsyncGenerator[str, None]:
-    """Split text into TTS-friendly chunks
-
-    Args:
-        text: Text to split into chunks
-        max_chunk: Maximum chunk size (defaults to 400)
-
-    Yields:
-        Text chunks suitable for TTS processing
-    """
-    if max_chunk is None:
-        max_chunk = 400
-    if not isinstance(text, str):
-        text = str(text) if text is not None else ""
-    text = text.strip()
-    if not text:
-        return
-
-    # Initialize chunker targeting ~300 chars to allow for expansion
-    chunker = semchunk_slim.chunkerify(
-        lambda t: len(t) // 5,  # Simple length-based target
-        chunk_size=60  # Target ~300 chars
-    )
-
-    # Get initial chunks
-    chunks = chunker(text)
-
-    # Process chunks
-    for chunk in chunks:
-        chunk = chunk.strip()
-        if not chunk:
-            continue
-        # Use fallback for any chunks that are too long
-        if len(chunk) > max_chunk:
-            for subchunk in await fallback_split(chunk, max_chunk):
-                yield subchunk
-        else:
-            yield chunk
+# """Text chunking module for TTS processing"""
+# from typing import List, AsyncGenerator
+
+# async def fallback_split(text: str, max_chars: int = 400) -> List[str]:
+#     """Emergency length control - only used if chunks are too long"""
+#     words = text.split()
+#     chunks = []
+#     current = []
+#     current_len = 0
+#     for word in words:
+#         # Always include at least one word per chunk
+#         if not current:
+#             current.append(word)
+#             current_len = len(word)
+#             continue
+#         # Check if adding word would exceed limit
+#         if current_len + len(word) + 1 <= max_chars:
+#             current.append(word)
+#             current_len += len(word) + 1
+#         else:
+#             chunks.append(" ".join(current))
+#             current = [word]
+#             current_len = len(word)
+#     if current:
+#         chunks.append(" ".join(current))
+#     return chunks

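For orientation: the chunker retired above exposed a character-based async generator (split_text, yielding plain text chunks), while the splitter introduced in the next file (smart_split) yields (text, token_ids) pairs sized by token count. A minimal sketch of the two calling conventions, assuming both generators are importable; consume_old and consume_new are hypothetical callers, not part of the repository:

from typing import List, Tuple

async def consume_old(text: str) -> List[str]:
    # Old convention (this file, now commented out): chunks sized by characters.
    return [chunk async for chunk in split_text(text, max_chunk=400)]

async def consume_new(text: str) -> List[Tuple[str, List[int]]]:
    # New convention (next file): chunks sized by tokens, paired with their token IDs.
    return [(chunk, tokens) async for chunk, tokens in smart_split(text, max_tokens=500)]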
View file

@@ -8,6 +8,11 @@ from .phonemizer import phonemize
 from .normalizer import normalize_text
 from .vocabulary import tokenize
 
+# Constants for chunk size optimization
+TARGET_MIN_TOKENS = 300
+TARGET_MAX_TOKENS = 400
+ABSOLUTE_MAX_TOKENS = 500
+
 
 def process_text_chunk(text: str, language: str = "a") -> List[int]:
     """Process a chunk of text through normalization, phonemization, and tokenization.
@@ -43,6 +48,15 @@ def process_text_chunk(text: str, language: str = "a") -> List[int]:
     return tokens
 
 
+def is_chunk_size_optimal(token_count: int) -> bool:
+    """Check if chunk size is within optimal range."""
+    return TARGET_MIN_TOKENS <= token_count <= TARGET_MAX_TOKENS
+
+
+async def yield_chunk(text: str, tokens: List[int], chunk_count: int) -> Tuple[str, List[int]]:
+    """Yield a chunk with consistent logging."""
+    logger.info(f"Yielding chunk {chunk_count}: '{text[:50]}...' ({len(tokens)} tokens)")
+    return text, tokens
+
 def process_text(text: str, language: str = "a") -> List[int]:
     """Process text into token IDs.
@@ -62,116 +76,123 @@ def process_text(text: str, language: str = "a") -> List[int]:
     return process_text_chunk(text, language)
 
 
-async def smart_split(text: str, max_tokens: int = 500) -> AsyncGenerator[Tuple[str, List[int]], None]:
-    """Split text into semantically meaningful chunks while respecting token limits.
-
-    Args:
-        text: Input text to split
-        max_tokens: Maximum tokens per chunk
-
-    Yields:
-        Tuples of (text chunk, token IDs) where token count is <= max_tokens
-    """
-    start_time = time.time()
-    chunk_count = 0
-    total_chars = len(text)
-    logger.info(f"Starting text split for {total_chars} characters with {max_tokens} max tokens")
-
-    # Split on major punctuation first
-    sentences = re.split(r'([.!?;:])', text)
-
-    current_chunk = []
-    current_token_count = 0
-
-    for i in range(0, len(sentences), 2):
-        # Get sentence and its punctuation (if any)
-        sentence = sentences[i].strip()
-        punct = sentences[i + 1] if i + 1 < len(sentences) else ""
-
-        if not sentence:
-            continue
-
-        # Process sentence to get token count
-        sentence_with_punct = sentence + punct
-        tokens = process_text_chunk(sentence_with_punct)
-        token_count = len(tokens)
-        logger.debug(f"Sentence '{sentence_with_punct[:50]}...' has {token_count} tokens")
-
-        # If this single sentence is too long, split on commas
-        if token_count > max_tokens:
-            logger.debug(f"Sentence exceeds token limit, splitting on commas")
-            clause_splits = re.split(r'([,])', sentence_with_punct)
-            for j in range(0, len(clause_splits), 2):
-                clause = clause_splits[j].strip()
-                comma = clause_splits[j + 1] if j + 1 < len(clause_splits) else ""
-
-                if not clause:
-                    continue
-
-                clause_with_punct = clause + comma
-                clause_tokens = process_text_chunk(clause_with_punct)
-
-                # If still too long, do a hard split on words
-                if len(clause_tokens) > max_tokens:
-                    logger.debug(f"Clause exceeds token limit, splitting on words")
-                    words = clause_with_punct.split()
-                    temp_chunk = []
-                    temp_tokens = []
-                    for word in words:
-                        word_tokens = process_text_chunk(word)
-                        if len(temp_tokens) + len(word_tokens) > max_tokens:
-                            if temp_chunk:  # Don't yield empty chunks
-                                chunk_text = " ".join(temp_chunk)
-                                chunk_count += 1
-                                logger.info(f"Yielding word-split chunk {chunk_count}: '{chunk_text[:50]}...' ({len(temp_tokens)} tokens)")
-                                yield chunk_text, temp_tokens
-                            temp_chunk = [word]
-                            temp_tokens = word_tokens
-                        else:
-                            temp_chunk.append(word)
-                            temp_tokens.extend(word_tokens)
-                    if temp_chunk:  # Don't forget the last chunk
-                        chunk_text = " ".join(temp_chunk)
-                        chunk_count += 1
-                        logger.info(f"Yielding final word-split chunk {chunk_count}: '{chunk_text[:50]}...' ({len(temp_tokens)} tokens)")
-                        yield chunk_text, temp_tokens
-                else:
-                    # Check if adding this clause would exceed the limit
-                    if current_token_count + len(clause_tokens) > max_tokens:
-                        if current_chunk:  # Don't yield empty chunks
-                            chunk_text = " ".join(current_chunk)
-                            chunk_count += 1
-                            logger.info(f"Yielding clause-split chunk {chunk_count}: '{chunk_text[:50]}...' ({current_token_count} tokens)")
-                            yield chunk_text, process_text_chunk(chunk_text)
-                        current_chunk = [clause_with_punct]
-                        current_token_count = len(clause_tokens)
-                    else:
-                        current_chunk.append(clause_with_punct)
-                        current_token_count += len(clause_tokens)
-        else:
-            # Check if adding this sentence would exceed the limit
-            if current_token_count + token_count > max_tokens:
-                if current_chunk:  # Don't yield empty chunks
-                    chunk_text = " ".join(current_chunk)
-                    chunk_count += 1
-                    logger.info(f"Yielding sentence-split chunk {chunk_count}: '{chunk_text[:50]}...' ({current_token_count} tokens)")
-                    yield chunk_text, process_text_chunk(chunk_text)
-                current_chunk = [sentence_with_punct]
-                current_token_count = token_count
-            else:
-                current_chunk.append(sentence_with_punct)
-                current_token_count += token_count
+# Target token ranges
+TARGET_MIN = 300
+TARGET_MAX = 400
+ABSOLUTE_MAX = 500
+
+
+def get_sentence_info(text: str) -> List[Tuple[str, List[int], int]]:
+    """Process all sentences and return info."""
+    sentences = re.split(r'([.!?;:])', text)
+    results = []
+    for i in range(0, len(sentences), 2):
+        sentence = sentences[i].strip()
+        punct = sentences[i + 1] if i + 1 < len(sentences) else ""
+        if not sentence:
+            continue
+        full = sentence + punct
+        tokens = process_text_chunk(full)
+        results.append((full, tokens, len(tokens)))
+    return results
+
+
+async def smart_split(text: str, max_tokens: int = ABSOLUTE_MAX) -> AsyncGenerator[Tuple[str, List[int]], None]:
+    """Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens."""
+    start_time = time.time()
+    chunk_count = 0
+    logger.info(f"Starting smart split for {len(text)} chars")
+
+    # Process all sentences
+    sentences = get_sentence_info(text)
+
+    current_chunk = []
+    current_tokens = []
+    current_count = 0
+
+    for sentence, tokens, count in sentences:
+        # Handle sentences that exceed max tokens
+        if count > max_tokens:
+            # Yield current chunk if any
+            if current_chunk:
+                chunk_text = " ".join(current_chunk)
+                chunk_count += 1
+                logger.info(f"Yielding chunk {chunk_count}: '{chunk_text[:50]}...' ({current_count} tokens)")
+                yield chunk_text, current_tokens
+                current_chunk = []
+                current_tokens = []
+                current_count = 0
+
+            # Split long sentence on commas
+            clauses = re.split(r'([,])', sentence)
+            clause_chunk = []
+            clause_tokens = []
+            clause_count = 0
+
+            for j in range(0, len(clauses), 2):
+                clause = clauses[j].strip()
+                comma = clauses[j + 1] if j + 1 < len(clauses) else ""
+
+                if not clause:
+                    continue
+
+                full_clause = clause + comma
+                tokens = process_text_chunk(full_clause)
+                count = len(tokens)
+
+                # If adding clause keeps us under max and not optimal yet
+                if clause_count + count <= max_tokens and clause_count + count <= TARGET_MAX:
+                    clause_chunk.append(full_clause)
+                    clause_tokens.extend(tokens)
+                    clause_count += count
+                else:
+                    # Yield clause chunk if we have one
+                    if clause_chunk:
+                        chunk_text = " ".join(clause_chunk)
+                        chunk_count += 1
+                        logger.info(f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}...' ({clause_count} tokens)")
+                        yield chunk_text, clause_tokens
+                    clause_chunk = [full_clause]
+                    clause_tokens = tokens
+                    clause_count = count
+
+            # Don't forget last clause chunk
+            if clause_chunk:
+                chunk_text = " ".join(clause_chunk)
+                chunk_count += 1
+                logger.info(f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}...' ({clause_count} tokens)")
+                yield chunk_text, clause_tokens
+
+        # Regular sentence handling
+        elif current_count + count <= TARGET_MAX:
+            # Keep building chunk while under target max
+            current_chunk.append(sentence)
+            current_tokens.extend(tokens)
+            current_count += count
+        elif current_count + count <= max_tokens:
+            # Accept slightly larger chunk if needed
+            current_chunk.append(sentence)
+            current_tokens.extend(tokens)
+            current_count += count
+        else:
+            # Yield current chunk and start new one
+            if current_chunk:
+                chunk_text = " ".join(current_chunk)
+                chunk_count += 1
+                logger.info(f"Yielding chunk {chunk_count}: '{chunk_text[:50]}...' ({current_count} tokens)")
+                yield chunk_text, current_tokens
+            current_chunk = [sentence]
+            current_tokens = tokens
+            current_count = count
 
     # Don't forget the last chunk
     if current_chunk:
         chunk_text = " ".join(current_chunk)
         chunk_count += 1
-        logger.info(f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}...' ({current_token_count} tokens)")
-        yield chunk_text, process_text_chunk(chunk_text)
+        logger.info(f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}...' ({current_count} tokens)")
+        yield chunk_text, current_tokens
 
     total_time = time.time() - start_time
-    logger.info(f"Text splitting completed in {total_time*1000:.2f}ms, produced {chunk_count} chunks")
+    logger.info(f"Split completed in {total_time*1000:.2f}ms, produced {chunk_count} chunks")

View file

@@ -119,7 +119,7 @@ class TTSService:
             try:
                 # Process audio for chunk
                 result = await self._process_chunk(
-                    tokens,
+                    tokens,  # Now always a flat List[int]
                     voice_tensor,
                     speed,
                     output_format,