Mirror of https://github.com/remsky/Kokoro-FastAPI.git (synced 2025-08-05 16:48:53 +00:00)
Refactor TTS service and chunker: update comments and remove unused code
parent b25ba5e7e6
commit 355ec54f78
3 changed files with 136 additions and 158 deletions
@@ -1,74 +1,31 @@
-"""Text chunking module for TTS processing"""
+# """Text chunking module for TTS processing"""
 
-from typing import List, AsyncGenerator
-from . import semchunk_slim
+# from typing import List, AsyncGenerator
 
 
-async def fallback_split(text: str, max_chars: int = 400) -> List[str]:
-    """Emergency length control - only used if chunks are too long"""
-    words = text.split()
-    chunks = []
-    current = []
-    current_len = 0
+# async def fallback_split(text: str, max_chars: int = 400) -> List[str]:
+#     """Emergency length control - only used if chunks are too long"""
+#     words = text.split()
+#     chunks = []
+#     current = []
+#     current_len = 0
 
-    for word in words:
-        # Always include at least one word per chunk
-        if not current:
-            current.append(word)
-            current_len = len(word)
-            continue
+#     for word in words:
+#         # Always include at least one word per chunk
+#         if not current:
+#             current.append(word)
+#             current_len = len(word)
+#             continue
 
-        # Check if adding word would exceed limit
-        if current_len + len(word) + 1 <= max_chars:
-            current.append(word)
-            current_len += len(word) + 1
-        else:
-            chunks.append(" ".join(current))
-            current = [word]
-            current_len = len(word)
+#         # Check if adding word would exceed limit
+#         if current_len + len(word) + 1 <= max_chars:
+#             current.append(word)
+#             current_len += len(word) + 1
+#         else:
+#             chunks.append(" ".join(current))
+#             current = [word]
+#             current_len = len(word)
 
-    if current:
-        chunks.append(" ".join(current))
+#     if current:
+#         chunks.append(" ".join(current))
 
-    return chunks
+#     return chunks
 
 
-async def split_text(text: str, max_chunk: int = None) -> AsyncGenerator[str, None]:
-    """Split text into TTS-friendly chunks
-
-    Args:
-        text: Text to split into chunks
-        max_chunk: Maximum chunk size (defaults to 400)
-
-    Yields:
-        Text chunks suitable for TTS processing
-    """
-    if max_chunk is None:
-        max_chunk = 400
-
-    if not isinstance(text, str):
-        text = str(text) if text is not None else ""
-
-    text = text.strip()
-    if not text:
-        return
-
-    # Initialize chunker targeting ~300 chars to allow for expansion
-    chunker = semchunk_slim.chunkerify(
-        lambda t: len(t) // 5,  # Simple length-based target
-        chunk_size=60  # Target ~300 chars
-    )
-
-    # Get initial chunks
-    chunks = chunker(text)
-
-    # Process chunks
-    for chunk in chunks:
-        chunk = chunk.strip()
-        if not chunk:
-            continue
-
-        # Use fallback for any chunks that are too long
-        if len(chunk) > max_chunk:
-            for subchunk in await fallback_split(chunk, max_chunk):
-                yield subchunk
-        else:
-            yield chunk
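
The commented-out fallback_split above packs words greedily under a character budget, counting the joining space (hence the "+ 1"). The short, synchronous stand-in below is illustrative only and not part of the commit; it uses a deliberately tiny budget so the packing is visible:

# Illustration only: greedy word packing under a character budget,
# mirroring the rule used by the commented-out fallback_split.
def greedy_pack(text: str, max_chars: int = 12) -> list[str]:
    chunks, current, current_len = [], [], 0
    for word in text.split():
        if not current:
            # Always include at least one word per chunk
            current, current_len = [word], len(word)
        elif current_len + len(word) + 1 <= max_chars:
            # Word fits, including the space that will join it
            current.append(word)
            current_len += len(word) + 1
        else:
            chunks.append(" ".join(current))
            current, current_len = [word], len(word)
    if current:
        chunks.append(" ".join(current))
    return chunks

print(greedy_pack("one two three four five"))  # ['one two', 'three four', 'five']
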
@@ -8,6 +8,11 @@ from .phonemizer import phonemize
 from .normalizer import normalize_text
 from .vocabulary import tokenize
 
+# Constants for chunk size optimization
+TARGET_MIN_TOKENS = 300
+TARGET_MAX_TOKENS = 400
+ABSOLUTE_MAX_TOKENS = 500
+
 def process_text_chunk(text: str, language: str = "a") -> List[int]:
     """Process a chunk of text through normalization, phonemization, and tokenization.
 
@@ -43,6 +48,15 @@ def process_text_chunk(text: str, language: str = "a") -> List[int]:
 
     return tokens
 
+def is_chunk_size_optimal(token_count: int) -> bool:
+    """Check if chunk size is within optimal range."""
+    return TARGET_MIN_TOKENS <= token_count <= TARGET_MAX_TOKENS
+
+
+async def yield_chunk(text: str, tokens: List[int], chunk_count: int) -> Tuple[str, List[int]]:
+    """Yield a chunk with consistent logging."""
+    logger.info(f"Yielding chunk {chunk_count}: '{text[:50]}...' ({len(tokens)} tokens)")
+    return text, tokens
 
 def process_text(text: str, language: str = "a") -> List[int]:
     """Process text into token IDs.
 
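
A quick, self-contained check of the boundary behavior of the helper added above. The constants are copied from the hunk; the snippet itself is illustrative and not part of the commit:

# Illustration only: is_chunk_size_optimal is inclusive on both ends.
TARGET_MIN_TOKENS, TARGET_MAX_TOKENS = 300, 400

def is_chunk_size_optimal(token_count: int) -> bool:
    return TARGET_MIN_TOKENS <= token_count <= TARGET_MAX_TOKENS

assert is_chunk_size_optimal(300) and is_chunk_size_optimal(400)
assert not is_chunk_size_optimal(299) and not is_chunk_size_optimal(401)
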
@@ -62,116 +76,123 @@ def process_text(text: str, language: str = "a") -> List[int]:
 
     return process_text_chunk(text, language)
 
 
-async def smart_split(text: str, max_tokens: int = 500) -> AsyncGenerator[Tuple[str, List[int]], None]:
-    """Split text into semantically meaningful chunks while respecting token limits.
-
-    Args:
-        text: Input text to split
-        max_tokens: Maximum tokens per chunk
-
-    Yields:
-        Tuples of (text chunk, token IDs) where token count is <= max_tokens
-    """
-    start_time = time.time()
-    chunk_count = 0
-    total_chars = len(text)
-    logger.info(f"Starting text split for {total_chars} characters with {max_tokens} max tokens")
-
-    # Split on major punctuation first
-    sentences = re.split(r'([.!?;:])', text)
-    current_chunk = []
-    current_token_count = 0
-
-    for i in range(0, len(sentences), 2):
-        # Get sentence and its punctuation (if any)
-        sentence = sentences[i].strip()
-        punct = sentences[i + 1] if i + 1 < len(sentences) else ""
-
-        if not sentence:
-            continue
-
-        # Process sentence to get token count
-        sentence_with_punct = sentence + punct
-        tokens = process_text_chunk(sentence_with_punct)
-        token_count = len(tokens)
-        logger.debug(f"Sentence '{sentence_with_punct[:50]}...' has {token_count} tokens")
-
-        # If this single sentence is too long, split on commas
-        if token_count > max_tokens:
-            logger.debug(f"Sentence exceeds token limit, splitting on commas")
-            clause_splits = re.split(r'([,])', sentence_with_punct)
-            for j in range(0, len(clause_splits), 2):
-                clause = clause_splits[j].strip()
-                comma = clause_splits[j + 1] if j + 1 < len(clause_splits) else ""
-
-                if not clause:
-                    continue
-
-                clause_with_punct = clause + comma
-                clause_tokens = process_text_chunk(clause_with_punct)
-
-                # If still too long, do a hard split on words
-                if len(clause_tokens) > max_tokens:
-                    logger.debug(f"Clause exceeds token limit, splitting on words")
-                    words = clause_with_punct.split()
-                    temp_chunk = []
-                    temp_tokens = []
-
-                    for word in words:
-                        word_tokens = process_text_chunk(word)
-                        if len(temp_tokens) + len(word_tokens) > max_tokens:
-                            if temp_chunk:  # Don't yield empty chunks
-                                chunk_text = " ".join(temp_chunk)
-                                chunk_count += 1
-                                logger.info(f"Yielding word-split chunk {chunk_count}: '{chunk_text[:50]}...' ({len(temp_tokens)} tokens)")
-                                yield chunk_text, temp_tokens
-                            temp_chunk = [word]
-                            temp_tokens = word_tokens
-                        else:
-                            temp_chunk.append(word)
-                            temp_tokens.extend(word_tokens)
-
-                    if temp_chunk:  # Don't forget the last chunk
-                        chunk_text = " ".join(temp_chunk)
-                        chunk_count += 1
-                        logger.info(f"Yielding final word-split chunk {chunk_count}: '{chunk_text[:50]}...' ({len(temp_tokens)} tokens)")
-                        yield chunk_text, temp_tokens
-
-                else:
-                    # Check if adding this clause would exceed the limit
-                    if current_token_count + len(clause_tokens) > max_tokens:
-                        if current_chunk:  # Don't yield empty chunks
-                            chunk_text = " ".join(current_chunk)
-                            chunk_count += 1
-                            logger.info(f"Yielding clause-split chunk {chunk_count}: '{chunk_text[:50]}...' ({current_token_count} tokens)")
-                            yield chunk_text, process_text_chunk(chunk_text)
-                        current_chunk = [clause_with_punct]
-                        current_token_count = len(clause_tokens)
-                    else:
-                        current_chunk.append(clause_with_punct)
-                        current_token_count += len(clause_tokens)
-        else:
-            # Check if adding this sentence would exceed the limit
-            if current_token_count + token_count > max_tokens:
-                if current_chunk:  # Don't yield empty chunks
-                    chunk_text = " ".join(current_chunk)
-                    chunk_count += 1
-                    logger.info(f"Yielding sentence-split chunk {chunk_count}: '{chunk_text[:50]}...' ({current_token_count} tokens)")
-                    yield chunk_text, process_text_chunk(chunk_text)
-                current_chunk = [sentence_with_punct]
-                current_token_count = token_count
-            else:
-                current_chunk.append(sentence_with_punct)
-                current_token_count += token_count
-
-    # Don't forget the last chunk
-    if current_chunk:
-        chunk_text = " ".join(current_chunk)
-        chunk_count += 1
-        logger.info(f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}...' ({current_token_count} tokens)")
-        yield chunk_text, process_text_chunk(chunk_text)
-
-    total_time = time.time() - start_time
-    logger.info(f"Text splitting completed in {total_time*1000:.2f}ms, produced {chunk_count} chunks")
+# Target token ranges
+TARGET_MIN = 300
+TARGET_MAX = 400
+ABSOLUTE_MAX = 500
+
+
+def get_sentence_info(text: str) -> List[Tuple[str, List[int], int]]:
+    """Process all sentences and return info."""
+    sentences = re.split(r'([.!?;:])', text)
+    results = []
+    for i in range(0, len(sentences), 2):
+        sentence = sentences[i].strip()
+        punct = sentences[i + 1] if i + 1 < len(sentences) else ""
+        if not sentence:
+            continue
+        full = sentence + punct
+        tokens = process_text_chunk(full)
+        results.append((full, tokens, len(tokens)))
+    return results
+
+
+async def smart_split(text: str, max_tokens: int = ABSOLUTE_MAX) -> AsyncGenerator[Tuple[str, List[int]], None]:
+    """Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens."""
+    start_time = time.time()
+    chunk_count = 0
+    logger.info(f"Starting smart split for {len(text)} chars")
+
+    # Process all sentences
+    sentences = get_sentence_info(text)
+
+    current_chunk = []
+    current_tokens = []
+    current_count = 0
+
+    for sentence, tokens, count in sentences:
+        # Handle sentences that exceed max tokens
+        if count > max_tokens:
+            # Yield current chunk if any
+            if current_chunk:
+                chunk_text = " ".join(current_chunk)
+                chunk_count += 1
+                logger.info(f"Yielding chunk {chunk_count}: '{chunk_text[:50]}...' ({current_count} tokens)")
+                yield chunk_text, current_tokens
+                current_chunk = []
+                current_tokens = []
+                current_count = 0
+
+            # Split long sentence on commas
+            clauses = re.split(r'([,])', sentence)
+            clause_chunk = []
+            clause_tokens = []
+            clause_count = 0
+
+            for j in range(0, len(clauses), 2):
+                clause = clauses[j].strip()
+                comma = clauses[j + 1] if j + 1 < len(clauses) else ""
+
+                if not clause:
+                    continue
+
+                full_clause = clause + comma
+                tokens = process_text_chunk(full_clause)
+                count = len(tokens)
+
+                # If adding clause keeps us under max and not optimal yet
+                if clause_count + count <= max_tokens and clause_count + count <= TARGET_MAX:
+                    clause_chunk.append(full_clause)
+                    clause_tokens.extend(tokens)
+                    clause_count += count
+                else:
+                    # Yield clause chunk if we have one
+                    if clause_chunk:
+                        chunk_text = " ".join(clause_chunk)
+                        chunk_count += 1
+                        logger.info(f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}...' ({clause_count} tokens)")
+                        yield chunk_text, clause_tokens
+                    clause_chunk = [full_clause]
+                    clause_tokens = tokens
+                    clause_count = count
+
+            # Don't forget last clause chunk
+            if clause_chunk:
+                chunk_text = " ".join(clause_chunk)
+                chunk_count += 1
+                logger.info(f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}...' ({clause_count} tokens)")
+                yield chunk_text, clause_tokens
+
+        # Regular sentence handling
+        elif current_count + count <= TARGET_MAX:
+            # Keep building chunk while under target max
+            current_chunk.append(sentence)
+            current_tokens.extend(tokens)
+            current_count += count
+        elif current_count + count <= max_tokens:
+            # Accept slightly larger chunk if needed
+            current_chunk.append(sentence)
+            current_tokens.extend(tokens)
+            current_count += count
+        else:
+            # Yield current chunk and start new one
+            if current_chunk:
+                chunk_text = " ".join(current_chunk)
+                chunk_count += 1
+                logger.info(f"Yielding chunk {chunk_count}: '{chunk_text[:50]}...' ({current_count} tokens)")
+                yield chunk_text, current_tokens
+            current_chunk = [sentence]
+            current_tokens = tokens
+            current_count = count
+
+    # Don't forget the last chunk
+    if current_chunk:
+        chunk_text = " ".join(current_chunk)
+        chunk_count += 1
+        logger.info(f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}...' ({current_count} tokens)")
+        yield chunk_text, current_tokens
+
+    total_time = time.time() - start_time
+    logger.info(f"Split completed in {total_time*1000:.2f}ms, produced {chunk_count} chunks")
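
The new smart_split packs whole sentences greedily: a chunk keeps growing while the running token count stays within TARGET_MAX (400), may stretch up to max_tokens (500) when needed, and only sentences that individually exceed the limit fall back to comma splitting. The toy sketch below is not repository code; it uses invented per-sentence token counts to show the packing rule in isolation:

# Illustration only: greedy packing of per-sentence token counts, mirroring
# the accumulation rule in the new smart_split (no comma fallback here).
TARGET_MAX = 400
ABSOLUTE_MAX = 500

def pack_counts(sentence_token_counts, max_tokens=ABSOLUTE_MAX):
    chunks, current, total = [], [], 0
    for count in sentence_token_counts:
        # Both the "<= TARGET_MAX" branch and the "slightly larger" branch
        # add the sentence, so the effective rule is "fits within max_tokens".
        if total + count <= max_tokens:
            current.append(count)
            total += count
        else:
            if current:
                chunks.append(current)
            current, total = [count], count
    if current:
        chunks.append(current)
    return chunks

print(pack_counts([120, 150, 140, 200, 90]))  # [[120, 150, 140], [200, 90]]
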
@@ -119,7 +119,7 @@ class TTSService:
             try:
                 # Process audio for chunk
                 result = await self._process_chunk(
-                    tokens,
+                    tokens,  # Now always a flat List[int]
                     voice_tensor,
                     speed,
                     output_format,