Refactor TTS service and chunker: update comments and remove unused code

remsky 2025-01-28 20:41:57 -07:00
parent b25ba5e7e6
commit 355ec54f78
3 changed files with 136 additions and 158 deletions

View file

@@ -1,74 +1,31 @@
"""Text chunking module for TTS processing"""
# """Text chunking module for TTS processing"""
from typing import List, AsyncGenerator
from . import semchunk_slim
# from typing import List, AsyncGenerator
async def fallback_split(text: str, max_chars: int = 400) -> List[str]:
"""Emergency length control - only used if chunks are too long"""
words = text.split()
chunks = []
current = []
current_len = 0
# async def fallback_split(text: str, max_chars: int = 400) -> List[str]:
# """Emergency length control - only used if chunks are too long"""
# words = text.split()
# chunks = []
# current = []
# current_len = 0
for word in words:
# Always include at least one word per chunk
if not current:
current.append(word)
current_len = len(word)
continue
# for word in words:
# # Always include at least one word per chunk
# if not current:
# current.append(word)
# current_len = len(word)
# continue
# Check if adding word would exceed limit
if current_len + len(word) + 1 <= max_chars:
current.append(word)
current_len += len(word) + 1
else:
chunks.append(" ".join(current))
current = [word]
current_len = len(word)
# # Check if adding word would exceed limit
# if current_len + len(word) + 1 <= max_chars:
# current.append(word)
# current_len += len(word) + 1
# else:
# chunks.append(" ".join(current))
# current = [word]
# current_len = len(word)
if current:
chunks.append(" ".join(current))
# if current:
# chunks.append(" ".join(current))
return chunks
async def split_text(text: str, max_chunk: int = None) -> AsyncGenerator[str, None]:
"""Split text into TTS-friendly chunks
Args:
text: Text to split into chunks
max_chunk: Maximum chunk size (defaults to 400)
Yields:
Text chunks suitable for TTS processing
"""
if max_chunk is None:
max_chunk = 400
if not isinstance(text, str):
text = str(text) if text is not None else ""
text = text.strip()
if not text:
return
# Initialize chunker targeting ~300 chars to allow for expansion
chunker = semchunk_slim.chunkerify(
lambda t: len(t) // 5, # Simple length-based target
chunk_size=60 # Target ~300 chars
)
# Get initial chunks
chunks = chunker(text)
# Process chunks
for chunk in chunks:
chunk = chunk.strip()
if not chunk:
continue
# Use fallback for any chunks that are too long
if len(chunk) > max_chunk:
for subchunk in await fallback_split(chunk, max_chunk):
yield subchunk
else:
yield chunk
# return chunks
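
For reference, the fallback_split being commented out above is a greedy word packer: it accumulates words until adding the next one (plus a joining space) would exceed max_chars, then closes the chunk and starts a new one. A minimal standalone sketch of that behavior, written synchronously for clarity (the function name is illustrative and not part of the codebase):

def greedy_word_split(text: str, max_chars: int = 400) -> list[str]:
    """Greedy word packing, mirroring the commented-out fallback_split."""
    chunks: list[str] = []
    current: list[str] = []
    current_len = 0
    for word in text.split():
        if not current:
            # Always keep at least one word per chunk, even an oversized one.
            current, current_len = [word], len(word)
        elif current_len + len(word) + 1 <= max_chars:
            # The +1 accounts for the space added when joining.
            current.append(word)
            current_len += len(word) + 1
        else:
            chunks.append(" ".join(current))
            current, current_len = [word], len(word)
    if current:
        chunks.append(" ".join(current))
    return chunks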

View file

@@ -8,6 +8,11 @@ from .phonemizer import phonemize
from .normalizer import normalize_text
from .vocabulary import tokenize
# Constants for chunk size optimization
TARGET_MIN_TOKENS = 300
TARGET_MAX_TOKENS = 400
ABSOLUTE_MAX_TOKENS = 500
def process_text_chunk(text: str, language: str = "a") -> List[int]:
"""Process a chunk of text through normalization, phonemization, and tokenization.
@@ -43,6 +48,15 @@ def process_text_chunk(text: str, language: str = "a") -> List[int]:
return tokens
def is_chunk_size_optimal(token_count: int) -> bool:
"""Check if chunk size is within optimal range."""
return TARGET_MIN_TOKENS <= token_count <= TARGET_MAX_TOKENS
async def yield_chunk(text: str, tokens: List[int], chunk_count: int) -> Tuple[str, List[int]]:
"""Yield a chunk with consistent logging."""
logger.info(f"Yielding chunk {chunk_count}: '{text[:50]}...' ({len(tokens)} tokens)")
return text, tokens
def process_text(text: str, language: str = "a") -> List[int]:
"""Process text into token IDs.
@@ -62,116 +76,123 @@ def process_text(text: str, language: str = "a") -> List[int]:
return process_text_chunk(text, language)
async def smart_split(text: str, max_tokens: int = 500) -> AsyncGenerator[Tuple[str, List[int]], None]:
"""Split text into semantically meaningful chunks while respecting token limits.
Args:
text: Input text to split
max_tokens: Maximum tokens per chunk
Yields:
Tuples of (text chunk, token IDs) where token count is <= max_tokens
"""
start_time = time.time()
chunk_count = 0
total_chars = len(text)
logger.info(f"Starting text split for {total_chars} characters with {max_tokens} max tokens")
# Split on major punctuation first
# Target token ranges
TARGET_MIN = 300
TARGET_MAX = 400
ABSOLUTE_MAX = 500
def get_sentence_info(text: str) -> List[Tuple[str, List[int], int]]:
"""Process all sentences and return info."""
sentences = re.split(r'([.!?;:])', text)
current_chunk = []
current_token_count = 0
results = []
for i in range(0, len(sentences), 2):
# Get sentence and its punctuation (if any)
sentence = sentences[i].strip()
punct = sentences[i + 1] if i + 1 < len(sentences) else ""
if not sentence:
continue
# Process sentence to get token count
sentence_with_punct = sentence + punct
tokens = process_text_chunk(sentence_with_punct)
token_count = len(tokens)
logger.debug(f"Sentence '{sentence_with_punct[:50]}...' has {token_count} tokens")
full = sentence + punct
tokens = process_text_chunk(full)
results.append((full, tokens, len(tokens)))
# If this single sentence is too long, split on commas
if token_count > max_tokens:
logger.debug(f"Sentence exceeds token limit, splitting on commas")
clause_splits = re.split(r'([,])', sentence_with_punct)
for j in range(0, len(clause_splits), 2):
clause = clause_splits[j].strip()
comma = clause_splits[j + 1] if j + 1 < len(clause_splits) else ""
return results
async def smart_split(text: str, max_tokens: int = ABSOLUTE_MAX) -> AsyncGenerator[Tuple[str, List[int]], None]:
"""Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens."""
start_time = time.time()
chunk_count = 0
logger.info(f"Starting smart split for {len(text)} chars")
# Process all sentences
sentences = get_sentence_info(text)
current_chunk = []
current_tokens = []
current_count = 0
for sentence, tokens, count in sentences:
# Handle sentences that exceed max tokens
if count > max_tokens:
# Yield current chunk if any
if current_chunk:
chunk_text = " ".join(current_chunk)
chunk_count += 1
logger.info(f"Yielding chunk {chunk_count}: '{chunk_text[:50]}...' ({current_count} tokens)")
yield chunk_text, current_tokens
current_chunk = []
current_tokens = []
current_count = 0
# Split long sentence on commas
clauses = re.split(r'([,])', sentence)
clause_chunk = []
clause_tokens = []
clause_count = 0
for j in range(0, len(clauses), 2):
clause = clauses[j].strip()
comma = clauses[j + 1] if j + 1 < len(clauses) else ""
if not clause:
continue
clause_with_punct = clause + comma
clause_tokens = process_text_chunk(clause_with_punct)
full_clause = clause + comma
tokens = process_text_chunk(full_clause)
count = len(tokens)
# If still too long, do a hard split on words
if len(clause_tokens) > max_tokens:
logger.debug(f"Clause exceeds token limit, splitting on words")
words = clause_with_punct.split()
temp_chunk = []
temp_tokens = []
for word in words:
word_tokens = process_text_chunk(word)
if len(temp_tokens) + len(word_tokens) > max_tokens:
if temp_chunk: # Don't yield empty chunks
chunk_text = " ".join(temp_chunk)
chunk_count += 1
logger.info(f"Yielding word-split chunk {chunk_count}: '{chunk_text[:50]}...' ({len(temp_tokens)} tokens)")
yield chunk_text, temp_tokens
temp_chunk = [word]
temp_tokens = word_tokens
else:
temp_chunk.append(word)
temp_tokens.extend(word_tokens)
if temp_chunk: # Don't forget the last chunk
chunk_text = " ".join(temp_chunk)
chunk_count += 1
logger.info(f"Yielding final word-split chunk {chunk_count}: '{chunk_text[:50]}...' ({len(temp_tokens)} tokens)")
yield chunk_text, temp_tokens
# If adding clause keeps us under max and not optimal yet
if clause_count + count <= max_tokens and clause_count + count <= TARGET_MAX:
clause_chunk.append(full_clause)
clause_tokens.extend(tokens)
clause_count += count
else:
# Check if adding this clause would exceed the limit
if current_token_count + len(clause_tokens) > max_tokens:
if current_chunk: # Don't yield empty chunks
chunk_text = " ".join(current_chunk)
chunk_count += 1
logger.info(f"Yielding clause-split chunk {chunk_count}: '{chunk_text[:50]}...' ({current_token_count} tokens)")
yield chunk_text, process_text_chunk(chunk_text)
current_chunk = [clause_with_punct]
current_token_count = len(clause_tokens)
else:
current_chunk.append(clause_with_punct)
current_token_count += len(clause_tokens)
# Yield clause chunk if we have one
if clause_chunk:
chunk_text = " ".join(clause_chunk)
chunk_count += 1
logger.info(f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}...' ({clause_count} tokens)")
yield chunk_text, clause_tokens
clause_chunk = [full_clause]
clause_tokens = tokens
clause_count = count
# Don't forget last clause chunk
if clause_chunk:
chunk_text = " ".join(clause_chunk)
chunk_count += 1
logger.info(f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}...' ({clause_count} tokens)")
yield chunk_text, clause_tokens
# Regular sentence handling
elif current_count + count <= TARGET_MAX:
# Keep building chunk while under target max
current_chunk.append(sentence)
current_tokens.extend(tokens)
current_count += count
elif current_count + count <= max_tokens:
# Accept slightly larger chunk if needed
current_chunk.append(sentence)
current_tokens.extend(tokens)
current_count += count
else:
# Check if adding this sentence would exceed the limit
if current_token_count + token_count > max_tokens:
if current_chunk: # Don't yield empty chunks
chunk_text = " ".join(current_chunk)
chunk_count += 1
logger.info(f"Yielding sentence-split chunk {chunk_count}: '{chunk_text[:50]}...' ({current_token_count} tokens)")
yield chunk_text, process_text_chunk(chunk_text)
current_chunk = [sentence_with_punct]
current_token_count = token_count
else:
current_chunk.append(sentence_with_punct)
current_token_count += token_count
# Yield current chunk and start new one
if current_chunk:
chunk_text = " ".join(current_chunk)
chunk_count += 1
logger.info(f"Yielding chunk {chunk_count}: '{chunk_text[:50]}...' ({current_count} tokens)")
yield chunk_text, current_tokens
current_chunk = [sentence]
current_tokens = tokens
current_count = count
# Don't forget the last chunk
if current_chunk:
chunk_text = " ".join(current_chunk)
chunk_count += 1
logger.info(f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}...' ({current_token_count} tokens)")
yield chunk_text, process_text_chunk(chunk_text)
logger.info(f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}...' ({current_count} tokens)")
yield chunk_text, current_tokens
total_time = time.time() - start_time
logger.info(f"Text splitting completed in {total_time*1000:.2f}ms, produced {chunk_count} chunks")
logger.info(f"Split completed in {total_time*1000:.2f}ms, produced {chunk_count} chunks")

View file

@@ -119,7 +119,7 @@ class TTSService:
try:
# Process audio for chunk
result = await self._process_chunk(
tokens,
tokens, # Now always a flat List[int]
voice_tensor,
speed,
output_format,
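
Because smart_split now yields (text, tokens) pairs where tokens is already a flat List[int], a caller can hand each list straight to chunk processing with no flattening step, which is what the updated comment above records. A hypothetical driver, heavily simplified and not part of this diff (it assumes smart_split is importable from the text-processing module shown earlier):

import asyncio

async def demo(text: str) -> None:
    # tokens arrives as a flat List[int]; there are no nested lists to unwrap.
    async for chunk_text, tokens in smart_split(text):
        print(f"{len(tokens):4d} tokens | {chunk_text[:60]!r}")

# asyncio.run(demo("Some long input text ..."))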