Mirror of https://github.com/remsky/Kokoro-FastAPI.git, synced 2025-08-05 16:48:53 +00:00
Refactor TTS service and chunker: update comments and remove unused code
This commit is contained in:
parent b25ba5e7e6
commit 355ec54f78
3 changed files with 136 additions and 158 deletions
@@ -1,74 +1,31 @@
-"""Text chunking module for TTS processing"""
+# """Text chunking module for TTS processing"""
 
-from typing import List, AsyncGenerator
-from . import semchunk_slim
+# from typing import List, AsyncGenerator
 
-async def fallback_split(text: str, max_chars: int = 400) -> List[str]:
-    """Emergency length control - only used if chunks are too long"""
-    words = text.split()
-    chunks = []
-    current = []
-    current_len = 0
+# async def fallback_split(text: str, max_chars: int = 400) -> List[str]:
+#     """Emergency length control - only used if chunks are too long"""
+#     words = text.split()
+#     chunks = []
+#     current = []
+#     current_len = 0
 
-    for word in words:
-        # Always include at least one word per chunk
-        if not current:
-            current.append(word)
-            current_len = len(word)
-            continue
+#     for word in words:
+#         # Always include at least one word per chunk
+#         if not current:
+#             current.append(word)
+#             current_len = len(word)
+#             continue
 
-        # Check if adding word would exceed limit
-        if current_len + len(word) + 1 <= max_chars:
-            current.append(word)
-            current_len += len(word) + 1
-        else:
-            chunks.append(" ".join(current))
-            current = [word]
-            current_len = len(word)
+#         # Check if adding word would exceed limit
+#         if current_len + len(word) + 1 <= max_chars:
+#             current.append(word)
+#             current_len += len(word) + 1
+#         else:
+#             chunks.append(" ".join(current))
+#             current = [word]
+#             current_len = len(word)
 
-    if current:
-        chunks.append(" ".join(current))
+#     if current:
+#         chunks.append(" ".join(current))
 
-    return chunks
-
-
-async def split_text(text: str, max_chunk: int = None) -> AsyncGenerator[str, None]:
-    """Split text into TTS-friendly chunks
-
-    Args:
-        text: Text to split into chunks
-        max_chunk: Maximum chunk size (defaults to 400)
-
-    Yields:
-        Text chunks suitable for TTS processing
-    """
-    if max_chunk is None:
-        max_chunk = 400
-
-    if not isinstance(text, str):
-        text = str(text) if text is not None else ""
-
-    text = text.strip()
-    if not text:
-        return
-
-    # Initialize chunker targeting ~300 chars to allow for expansion
-    chunker = semchunk_slim.chunkerify(
-        lambda t: len(t) // 5,  # Simple length-based target
-        chunk_size=60  # Target ~300 chars
-    )
-
-    # Get initial chunks
-    chunks = chunker(text)
-
-    # Process chunks
-    for chunk in chunks:
-        chunk = chunk.strip()
-        if not chunk:
-            continue
-
-        # Use fallback for any chunks that are too long
-        if len(chunk) > max_chunk:
-            for subchunk in await fallback_split(chunk, max_chunk):
-                yield subchunk
-        else:
-            yield chunk
+#     return chunks
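For reference, fallback_split (retained above only as commented-out code) is a greedy word packer: it keeps appending words to the current chunk until the next word plus a joining space would exceed max_chars, then starts a new chunk. A lightly condensed, standalone restatement so the behaviour can be checked in isolation; the sample text and max_chars value are illustrative:

import asyncio
from typing import List

async def fallback_split(text: str, max_chars: int = 400) -> List[str]:
    """Greedy word packing, as in the commented-out chunker above."""
    words = text.split()
    chunks: List[str] = []
    current: List[str] = []
    current_len = 0
    for word in words:
        if not current:
            # Always include at least one word per chunk
            current, current_len = [word], len(word)
            continue
        if current_len + len(word) + 1 <= max_chars:  # +1 for the joining space
            current.append(word)
            current_len += len(word) + 1
        else:
            chunks.append(" ".join(current))
            current, current_len = [word], len(word)
    if current:
        chunks.append(" ".join(current))
    return chunks

# With max_chars=9, two 4-letter words plus one space fit per chunk:
print(asyncio.run(fallback_split("aaaa bbbb cccc dddd", max_chars=9)))
# ['aaaa bbbb', 'cccc dddd']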
@@ -8,6 +8,11 @@ from .phonemizer import phonemize
 from .normalizer import normalize_text
 from .vocabulary import tokenize
 
+# Constants for chunk size optimization
+TARGET_MIN_TOKENS = 300
+TARGET_MAX_TOKENS = 400
+ABSOLUTE_MAX_TOKENS = 500
+
 def process_text_chunk(text: str, language: str = "a") -> List[int]:
     """Process a chunk of text through normalization, phonemization, and tokenization.
@@ -43,6 +48,15 @@ def process_text_chunk(text: str, language: str = "a") -> List[int]:
 
     return tokens
 
+def is_chunk_size_optimal(token_count: int) -> bool:
+    """Check if chunk size is within optimal range."""
+    return TARGET_MIN_TOKENS <= token_count <= TARGET_MAX_TOKENS
+
+async def yield_chunk(text: str, tokens: List[int], chunk_count: int) -> Tuple[str, List[int]]:
+    """Yield a chunk with consistent logging."""
+    logger.info(f"Yielding chunk {chunk_count}: '{text[:50]}...' ({len(tokens)} tokens)")
+    return text, tokens
+
 def process_text(text: str, language: str = "a") -> List[int]:
     """Process text into token IDs.
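The constants and helpers added in the two hunks above encode the new chunking policy: chunks should land in a 300 to 400 token window, with 500 (ABSOLUTE_MAX_TOKENS) as the hard cap. A standalone restatement of is_chunk_size_optimal showing that the window is inclusive at both ends:

TARGET_MIN_TOKENS = 300
TARGET_MAX_TOKENS = 400

def is_chunk_size_optimal(token_count: int) -> bool:
    """Check if chunk size is within optimal range."""
    return TARGET_MIN_TOKENS <= token_count <= TARGET_MAX_TOKENS

assert is_chunk_size_optimal(300) is True    # lower bound is inclusive
assert is_chunk_size_optimal(400) is True    # upper bound is inclusive
assert is_chunk_size_optimal(299) is False
assert is_chunk_size_optimal(401) is False   # over target, though still under the 500-token hard cap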
@@ -62,116 +76,123 @@ def process_text(text: str, language: str = "a") -> List[int]:
 
     return process_text_chunk(text, language)
 
-async def smart_split(text: str, max_tokens: int = 500) -> AsyncGenerator[Tuple[str, List[int]], None]:
-    """Split text into semantically meaningful chunks while respecting token limits.
-
-    Args:
-        text: Input text to split
-        max_tokens: Maximum tokens per chunk
-
-    Yields:
-        Tuples of (text chunk, token IDs) where token count is <= max_tokens
-    """
-    start_time = time.time()
-    chunk_count = 0
-    total_chars = len(text)
-    logger.info(f"Starting text split for {total_chars} characters with {max_tokens} max tokens")
-
-    # Split on major punctuation first
+# Target token ranges
+TARGET_MIN = 300
+TARGET_MAX = 400
+ABSOLUTE_MAX = 500
+
+def get_sentence_info(text: str) -> List[Tuple[str, List[int], int]]:
+    """Process all sentences and return info."""
     sentences = re.split(r'([.!?;:])', text)
 
-    current_chunk = []
-    current_token_count = 0
+    results = []
 
     for i in range(0, len(sentences), 2):
         # Get sentence and its punctuation (if any)
         sentence = sentences[i].strip()
         punct = sentences[i + 1] if i + 1 < len(sentences) else ""
 
         if not sentence:
             continue
 
-        # Process sentence to get token count
-        sentence_with_punct = sentence + punct
-        tokens = process_text_chunk(sentence_with_punct)
-        token_count = len(tokens)
-        logger.debug(f"Sentence '{sentence_with_punct[:50]}...' has {token_count} tokens")
+        full = sentence + punct
+        tokens = process_text_chunk(full)
+        results.append((full, tokens, len(tokens)))
 
-        # If this single sentence is too long, split on commas
-        if token_count > max_tokens:
-            logger.debug(f"Sentence exceeds token limit, splitting on commas")
-            clause_splits = re.split(r'([,])', sentence_with_punct)
-            for j in range(0, len(clause_splits), 2):
-                clause = clause_splits[j].strip()
-                comma = clause_splits[j + 1] if j + 1 < len(clause_splits) else ""
+    return results
+
+async def smart_split(text: str, max_tokens: int = ABSOLUTE_MAX) -> AsyncGenerator[Tuple[str, List[int]], None]:
+    """Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens."""
+    start_time = time.time()
+    chunk_count = 0
+    logger.info(f"Starting smart split for {len(text)} chars")
+
+    # Process all sentences
+    sentences = get_sentence_info(text)
+
+    current_chunk = []
+    current_tokens = []
+    current_count = 0
+
+    for sentence, tokens, count in sentences:
+        # Handle sentences that exceed max tokens
+        if count > max_tokens:
+            # Yield current chunk if any
+            if current_chunk:
+                chunk_text = " ".join(current_chunk)
+                chunk_count += 1
+                logger.info(f"Yielding chunk {chunk_count}: '{chunk_text[:50]}...' ({current_count} tokens)")
+                yield chunk_text, current_tokens
+                current_chunk = []
+                current_tokens = []
+                current_count = 0
+
+            # Split long sentence on commas
+            clauses = re.split(r'([,])', sentence)
+            clause_chunk = []
+            clause_tokens = []
+            clause_count = 0
+
+            for j in range(0, len(clauses), 2):
+                clause = clauses[j].strip()
+                comma = clauses[j + 1] if j + 1 < len(clauses) else ""
 
                 if not clause:
                     continue
 
-                clause_with_punct = clause + comma
-                clause_tokens = process_text_chunk(clause_with_punct)
+                full_clause = clause + comma
+                tokens = process_text_chunk(full_clause)
+                count = len(tokens)
 
-                # If still too long, do a hard split on words
-                if len(clause_tokens) > max_tokens:
-                    logger.debug(f"Clause exceeds token limit, splitting on words")
-                    words = clause_with_punct.split()
-                    temp_chunk = []
-                    temp_tokens = []
-
-                    for word in words:
-                        word_tokens = process_text_chunk(word)
-                        if len(temp_tokens) + len(word_tokens) > max_tokens:
-                            if temp_chunk:  # Don't yield empty chunks
-                                chunk_text = " ".join(temp_chunk)
-                                chunk_count += 1
-                                logger.info(f"Yielding word-split chunk {chunk_count}: '{chunk_text[:50]}...' ({len(temp_tokens)} tokens)")
-                                yield chunk_text, temp_tokens
-                            temp_chunk = [word]
-                            temp_tokens = word_tokens
-                        else:
-                            temp_chunk.append(word)
-                            temp_tokens.extend(word_tokens)
-
-                    if temp_chunk:  # Don't forget the last chunk
-                        chunk_text = " ".join(temp_chunk)
-                        chunk_count += 1
-                        logger.info(f"Yielding final word-split chunk {chunk_count}: '{chunk_text[:50]}...' ({len(temp_tokens)} tokens)")
-                        yield chunk_text, temp_tokens
+                # If adding clause keeps us under max and not optimal yet
+                if clause_count + count <= max_tokens and clause_count + count <= TARGET_MAX:
+                    clause_chunk.append(full_clause)
+                    clause_tokens.extend(tokens)
+                    clause_count += count
                 else:
-                    # Check if adding this clause would exceed the limit
-                    if current_token_count + len(clause_tokens) > max_tokens:
-                        if current_chunk:  # Don't yield empty chunks
-                            chunk_text = " ".join(current_chunk)
-                            chunk_count += 1
-                            logger.info(f"Yielding clause-split chunk {chunk_count}: '{chunk_text[:50]}...' ({current_token_count} tokens)")
-                            yield chunk_text, process_text_chunk(chunk_text)
-                        current_chunk = [clause_with_punct]
-                        current_token_count = len(clause_tokens)
-                    else:
-                        current_chunk.append(clause_with_punct)
-                        current_token_count += len(clause_tokens)
+                    # Yield clause chunk if we have one
+                    if clause_chunk:
+                        chunk_text = " ".join(clause_chunk)
+                        chunk_count += 1
+                        logger.info(f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}...' ({clause_count} tokens)")
+                        yield chunk_text, clause_tokens
+                    clause_chunk = [full_clause]
+                    clause_tokens = tokens
+                    clause_count = count
+
+            # Don't forget last clause chunk
+            if clause_chunk:
+                chunk_text = " ".join(clause_chunk)
+                chunk_count += 1
+                logger.info(f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}...' ({clause_count} tokens)")
+                yield chunk_text, clause_tokens
+
+        # Regular sentence handling
+        elif current_count + count <= TARGET_MAX:
+            # Keep building chunk while under target max
+            current_chunk.append(sentence)
+            current_tokens.extend(tokens)
+            current_count += count
+        elif current_count + count <= max_tokens:
+            # Accept slightly larger chunk if needed
+            current_chunk.append(sentence)
+            current_tokens.extend(tokens)
+            current_count += count
         else:
-            # Check if adding this sentence would exceed the limit
-            if current_token_count + token_count > max_tokens:
-                if current_chunk:  # Don't yield empty chunks
-                    chunk_text = " ".join(current_chunk)
-                    chunk_count += 1
-                    logger.info(f"Yielding sentence-split chunk {chunk_count}: '{chunk_text[:50]}...' ({current_token_count} tokens)")
-                    yield chunk_text, process_text_chunk(chunk_text)
-                current_chunk = [sentence_with_punct]
-                current_token_count = token_count
-            else:
-                current_chunk.append(sentence_with_punct)
-                current_token_count += token_count
+            # Yield current chunk and start new one
+            if current_chunk:
+                chunk_text = " ".join(current_chunk)
+                chunk_count += 1
+                logger.info(f"Yielding chunk {chunk_count}: '{chunk_text[:50]}...' ({current_count} tokens)")
+                yield chunk_text, current_tokens
+            current_chunk = [sentence]
+            current_tokens = tokens
+            current_count = count
 
     # Don't forget the last chunk
     if current_chunk:
         chunk_text = " ".join(current_chunk)
         chunk_count += 1
-        logger.info(f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}...' ({current_token_count} tokens)")
-        yield chunk_text, process_text_chunk(chunk_text)
+        logger.info(f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}...' ({current_count} tokens)")
+        yield chunk_text, current_tokens
 
     total_time = time.time() - start_time
-    logger.info(f"Text splitting completed in {total_time*1000:.2f}ms, produced {chunk_count} chunks")
+    logger.info(f"Split completed in {total_time*1000:.2f}ms, produced {chunk_count} chunks")
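Like split_text before it, smart_split is an async generator, but it now yields (chunk text, token IDs) pairs assembled from whole sentences, so callers get the tokens without re-tokenizing. A minimal consumer sketch; the import path and sample text are illustrative and not taken from this commit:

import asyncio

# Hypothetical import path; adjust to wherever smart_split lives in this repo.
from text_processing import smart_split

async def main() -> None:
    text = "A long article to be synthesized aloud. " * 200
    async for chunk_text, tokens in smart_split(text):
        # Each yielded chunk targets the 300 to 400 token window and never
        # exceeds the 500-token cap, so it can go straight to the model.
        print(f"{len(tokens):4d} tokens | {chunk_text[:40]!r}")

asyncio.run(main())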
@@ -119,7 +119,7 @@ class TTSService:
             try:
                 # Process audio for chunk
                 result = await self._process_chunk(
-                    tokens,
+                    tokens,  # Now always a flat List[int]
                     voice_tensor,
                     speed,
                     output_format,
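The service-side change is only the comment recording that _process_chunk now always receives a flat List[int] of token IDs per chunk. Presumably the tokens could previously arrive in a nested, per-sentence shape; the helper below is purely hypothetical and only illustrates the difference in shape:

from typing import List

def to_flat_tokens(token_batches: List[List[int]]) -> List[int]:
    """Hypothetical helper: flatten nested per-sentence token lists into the
    single flat List[int] shape that the comment above describes."""
    return [token for batch in token_batches for token in batch]

# Nested shape, e.g. [[12, 7, 99], [4, 58]]  ->  flat shape [12, 7, 99, 4, 58]
assert to_flat_tokens([[12, 7, 99], [4, 58]]) == [12, 7, 99, 4, 58]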