mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00
Enhance TTS text processing: Implement pause tag handling in smart_split, allowing for better audio chunk generation with pauses. Update related tests to validate new functionality and ensure compatibility with existing features.
This commit is contained in:
parent
ab8ab7d749
commit
84d2a4d806
3 changed files with 282 additions and 167 deletions
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from typing import AsyncGenerator, Dict, List, Tuple
|
from typing import AsyncGenerator, Dict, List, Tuple, Optional
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
@ -13,7 +13,11 @@ from .phonemizer import phonemize
|
||||||
from .vocabulary import tokenize
|
from .vocabulary import tokenize
|
||||||
|
|
||||||
# Pre-compiled regex patterns for performance
|
# Pre-compiled regex patterns for performance
|
||||||
CUSTOM_PHONEMES = re.compile(r"(\[([^\]]|\n)*?\])(\(\/([^\/)]|\n)*?\/\))")
|
# Updated regex to be more strict and avoid matching isolated brackets
|
||||||
|
# Only matches complete patterns like [word](/ipa/) and prevents catastrophic backtracking
|
||||||
|
CUSTOM_PHONEMES = re.compile(r"(\[[^\[\]]*?\])(\(\/[^\/\(\)]*?\/\))")
|
||||||
|
# Pattern to find pause tags like [pause:0.5s]
|
||||||
|
PAUSE_TAG_PATTERN = re.compile(r"\[pause:(\d+(?:\.\d+)?)s\]", re.IGNORECASE)
|
||||||
|
|
||||||
|
|
||||||
def process_text_chunk(
|
def process_text_chunk(
|
||||||
|
@ -142,148 +146,189 @@ async def smart_split(
|
||||||
max_tokens: int = settings.absolute_max_tokens,
|
max_tokens: int = settings.absolute_max_tokens,
|
||||||
lang_code: str = "a",
|
lang_code: str = "a",
|
||||||
normalization_options: NormalizationOptions = NormalizationOptions(),
|
normalization_options: NormalizationOptions = NormalizationOptions(),
|
||||||
) -> AsyncGenerator[Tuple[str, List[int]], None]:
|
) -> AsyncGenerator[Tuple[str, List[int], Optional[float]], None]:
|
||||||
"""Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens."""
|
"""Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Tuple of (text_chunk, tokens, pause_duration_s).
|
||||||
|
If pause_duration_s is not None, it's a pause chunk with empty text/tokens.
|
||||||
|
Otherwise, it's a text chunk containing the original text.
|
||||||
|
"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
logger.info(f"Starting smart split for {len(text)} chars")
|
logger.info(f"Starting smart split for {len(text)} chars")
|
||||||
|
|
||||||
custom_phoneme_list = {}
|
# --- Step 1: Split by Pause Tags FIRST ---
|
||||||
|
# This operates on the raw input text
|
||||||
|
parts = PAUSE_TAG_PATTERN.split(text)
|
||||||
|
logger.debug(f"Split raw text into {len(parts)} parts by pause tags.")
|
||||||
|
|
||||||
# Normalize text
|
part_idx = 0
|
||||||
if settings.advanced_text_normalization and normalization_options.normalize:
|
while part_idx < len(parts):
|
||||||
if lang_code in ["a", "b", "en-us", "en-gb"]:
|
text_part_raw = parts[part_idx] # This part is raw text
|
||||||
text = CUSTOM_PHONEMES.sub(
|
part_idx += 1
|
||||||
lambda s: handle_custom_phonemes(s, custom_phoneme_list), text
|
|
||||||
)
|
|
||||||
text = normalize_text(text, normalization_options)
|
|
||||||
else:
|
|
||||||
logger.info(
|
|
||||||
"Skipping text normalization as it is only supported for english"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Process all sentences
|
# --- Process Text Part ---
|
||||||
sentences = get_sentence_info(text, custom_phoneme_list, lang_code=lang_code)
|
if text_part_raw and text_part_raw.strip(): # Only process if the part is not empty string
|
||||||
|
# Strip leading and trailing spaces to prevent pause tag splitting artifacts
|
||||||
|
text_part_raw = text_part_raw.strip()
|
||||||
|
|
||||||
current_chunk = []
|
# Apply the original smart_split logic to this text part
|
||||||
current_tokens = []
|
custom_phoneme_list = {}
|
||||||
current_count = 0
|
|
||||||
|
|
||||||
for sentence, tokens, count in sentences:
|
# Normalize text (original logic)
|
||||||
# Handle sentences that exceed max tokens
|
processed_text = text_part_raw
|
||||||
if count > max_tokens:
|
if settings.advanced_text_normalization and normalization_options.normalize:
|
||||||
# Yield current chunk if any
|
if lang_code in ["a", "b", "en-us", "en-gb"]:
|
||||||
if current_chunk:
|
processed_text = CUSTOM_PHONEMES.sub(
|
||||||
chunk_text = " ".join(current_chunk).strip() # Strip after joining
|
lambda s: handle_custom_phonemes(s, custom_phoneme_list), processed_text
|
||||||
chunk_count += 1
|
)
|
||||||
logger.debug(
|
processed_text = normalize_text(processed_text, normalization_options)
|
||||||
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
|
|
||||||
)
|
|
||||||
yield chunk_text, current_tokens
|
|
||||||
current_chunk = []
|
|
||||||
current_tokens = []
|
|
||||||
current_count = 0
|
|
||||||
|
|
||||||
# Split long sentence on commas
|
|
||||||
clauses = re.split(r"([,])", sentence)
|
|
||||||
clause_chunk = []
|
|
||||||
clause_tokens = []
|
|
||||||
clause_count = 0
|
|
||||||
|
|
||||||
for j in range(0, len(clauses), 2):
|
|
||||||
clause = clauses[j].strip()
|
|
||||||
comma = clauses[j + 1] if j + 1 < len(clauses) else ""
|
|
||||||
|
|
||||||
if not clause:
|
|
||||||
continue
|
|
||||||
|
|
||||||
full_clause = clause + comma
|
|
||||||
|
|
||||||
tokens = process_text_chunk(full_clause)
|
|
||||||
count = len(tokens)
|
|
||||||
|
|
||||||
# If adding clause keeps us under max and not optimal yet
|
|
||||||
if (
|
|
||||||
clause_count + count <= max_tokens
|
|
||||||
and clause_count + count <= settings.target_max_tokens
|
|
||||||
):
|
|
||||||
clause_chunk.append(full_clause)
|
|
||||||
clause_tokens.extend(tokens)
|
|
||||||
clause_count += count
|
|
||||||
else:
|
else:
|
||||||
# Yield clause chunk if we have one
|
logger.info(
|
||||||
if clause_chunk:
|
"Skipping text normalization as it is only supported for english"
|
||||||
chunk_text = " ".join(clause_chunk).strip() # Strip after joining
|
)
|
||||||
|
|
||||||
|
# Process all sentences (original logic)
|
||||||
|
sentences = get_sentence_info(processed_text, custom_phoneme_list, lang_code=lang_code)
|
||||||
|
|
||||||
|
current_chunk = []
|
||||||
|
current_tokens = []
|
||||||
|
current_count = 0
|
||||||
|
|
||||||
|
for sentence, tokens, count in sentences:
|
||||||
|
# Handle sentences that exceed max tokens (original logic)
|
||||||
|
if count > max_tokens:
|
||||||
|
# Yield current chunk if any
|
||||||
|
if current_chunk:
|
||||||
|
chunk_text = " ".join(current_chunk).strip()
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)"
|
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)"
|
||||||
)
|
)
|
||||||
yield chunk_text, clause_tokens
|
yield chunk_text, current_tokens, None
|
||||||
clause_chunk = [full_clause]
|
current_chunk = []
|
||||||
clause_tokens = tokens
|
current_tokens = []
|
||||||
clause_count = count
|
current_count = 0
|
||||||
|
|
||||||
# Don't forget last clause chunk
|
# Split long sentence on commas (original logic)
|
||||||
if clause_chunk:
|
clauses = re.split(r"([,])", sentence)
|
||||||
chunk_text = " ".join(clause_chunk).strip() # Strip after joining
|
clause_chunk = []
|
||||||
chunk_count += 1
|
clause_tokens = []
|
||||||
logger.debug(
|
clause_count = 0
|
||||||
f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)"
|
|
||||||
)
|
|
||||||
yield chunk_text, clause_tokens
|
|
||||||
|
|
||||||
# Regular sentence handling
|
for j in range(0, len(clauses), 2):
|
||||||
elif (
|
clause = clauses[j].strip()
|
||||||
current_count >= settings.target_min_tokens
|
comma = clauses[j + 1] if j + 1 < len(clauses) else ""
|
||||||
and current_count + count > settings.target_max_tokens
|
|
||||||
):
|
if not clause:
|
||||||
# If we have a good sized chunk and adding next sentence exceeds target,
|
continue
|
||||||
# yield current chunk and start new one
|
|
||||||
chunk_text = " ".join(current_chunk).strip() # Strip after joining
|
full_clause = clause + comma
|
||||||
chunk_count += 1
|
|
||||||
logger.info(
|
tokens = process_text_chunk(full_clause)
|
||||||
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
|
count = len(tokens)
|
||||||
)
|
|
||||||
yield chunk_text, current_tokens
|
# If adding clause keeps us under max and not optimal yet
|
||||||
current_chunk = [sentence]
|
if (
|
||||||
current_tokens = tokens
|
clause_count + count <= max_tokens
|
||||||
current_count = count
|
and clause_count + count <= settings.target_max_tokens
|
||||||
elif current_count + count <= settings.target_max_tokens:
|
):
|
||||||
# Keep building chunk while under target max
|
clause_chunk.append(full_clause)
|
||||||
current_chunk.append(sentence)
|
clause_tokens.extend(tokens)
|
||||||
current_tokens.extend(tokens)
|
clause_count += count
|
||||||
current_count += count
|
else:
|
||||||
elif (
|
# Yield clause chunk if we have one
|
||||||
current_count + count <= max_tokens
|
if clause_chunk:
|
||||||
and current_count < settings.target_min_tokens
|
chunk_text = " ".join(clause_chunk).strip()
|
||||||
):
|
chunk_count += 1
|
||||||
# Only exceed target max if we haven't reached minimum size yet
|
logger.debug(
|
||||||
current_chunk.append(sentence)
|
f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({clause_count} tokens)"
|
||||||
current_tokens.extend(tokens)
|
)
|
||||||
current_count += count
|
yield chunk_text, clause_tokens, None
|
||||||
else:
|
clause_chunk = [full_clause]
|
||||||
# Yield current chunk and start new one
|
clause_tokens = tokens
|
||||||
|
clause_count = count
|
||||||
|
|
||||||
|
# Don't forget last clause chunk
|
||||||
|
if clause_chunk:
|
||||||
|
chunk_text = " ".join(clause_chunk).strip()
|
||||||
|
chunk_count += 1
|
||||||
|
logger.debug(
|
||||||
|
f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({clause_count} tokens)"
|
||||||
|
)
|
||||||
|
yield chunk_text, clause_tokens, None
|
||||||
|
|
||||||
|
# Regular sentence handling (original logic)
|
||||||
|
elif (
|
||||||
|
current_count >= settings.target_min_tokens
|
||||||
|
and current_count + count > settings.target_max_tokens
|
||||||
|
):
|
||||||
|
# If we have a good sized chunk and adding next sentence exceeds target,
|
||||||
|
# yield current chunk and start new one
|
||||||
|
chunk_text = " ".join(current_chunk).strip()
|
||||||
|
chunk_count += 1
|
||||||
|
logger.info(
|
||||||
|
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)"
|
||||||
|
)
|
||||||
|
yield chunk_text, current_tokens, None
|
||||||
|
current_chunk = [sentence]
|
||||||
|
current_tokens = tokens
|
||||||
|
current_count = count
|
||||||
|
elif current_count + count <= settings.target_max_tokens:
|
||||||
|
# Keep building chunk while under target max
|
||||||
|
current_chunk.append(sentence)
|
||||||
|
current_tokens.extend(tokens)
|
||||||
|
current_count += count
|
||||||
|
elif (
|
||||||
|
current_count + count <= max_tokens
|
||||||
|
and current_count < settings.target_min_tokens
|
||||||
|
):
|
||||||
|
# Only exceed target max if we haven't reached minimum size yet
|
||||||
|
current_chunk.append(sentence)
|
||||||
|
current_tokens.extend(tokens)
|
||||||
|
current_count += count
|
||||||
|
else:
|
||||||
|
# Yield current chunk and start new one
|
||||||
|
if current_chunk:
|
||||||
|
chunk_text = " ".join(current_chunk).strip()
|
||||||
|
chunk_count += 1
|
||||||
|
logger.info(
|
||||||
|
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)"
|
||||||
|
)
|
||||||
|
yield chunk_text, current_tokens, None
|
||||||
|
current_chunk = [sentence]
|
||||||
|
current_tokens = tokens
|
||||||
|
current_count = count
|
||||||
|
|
||||||
|
# Don't forget the last chunk for this text part
|
||||||
if current_chunk:
|
if current_chunk:
|
||||||
chunk_text = " ".join(current_chunk).strip() # Strip after joining
|
chunk_text = " ".join(current_chunk).strip()
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
|
f"Yielding final chunk {chunk_count} for part: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)"
|
||||||
)
|
)
|
||||||
yield chunk_text, current_tokens
|
yield chunk_text, current_tokens, None
|
||||||
current_chunk = [sentence]
|
|
||||||
current_tokens = tokens
|
|
||||||
current_count = count
|
|
||||||
|
|
||||||
# Don't forget the last chunk
|
# --- Handle Pause Part ---
|
||||||
if current_chunk:
|
# Check if the next part is a pause duration string
|
||||||
chunk_text = " ".join(current_chunk).strip() # Strip after joining
|
if part_idx < len(parts):
|
||||||
chunk_count += 1
|
duration_str = parts[part_idx]
|
||||||
logger.info(
|
# Check if it looks like a valid number string captured by the regex group
|
||||||
f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
|
if re.fullmatch(r"\d+(?:\.\d+)?", duration_str):
|
||||||
)
|
part_idx += 1 # Consume the duration string as it's been processed
|
||||||
yield chunk_text, current_tokens
|
try:
|
||||||
|
duration = float(duration_str)
|
||||||
|
if duration > 0:
|
||||||
|
chunk_count += 1
|
||||||
|
logger.info(f"Yielding pause chunk {chunk_count}: {duration}s")
|
||||||
|
yield "", [], duration # Yield pause chunk
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
# This case should be rare if re.fullmatch passed, but handle anyway
|
||||||
|
logger.warning(f"Could not parse valid-looking pause duration: {duration_str}")
|
||||||
|
|
||||||
|
# --- End of parts loop ---
|
||||||
total_time = time.time() - start_time
|
total_time = time.time() - start_time
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks"
|
f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks (including pauses)"
|
||||||
)
|
)
|
||||||
|
|
|
@ -280,48 +280,88 @@ class TTSService:
|
||||||
f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream"
|
f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Process text in chunks with smart splitting
|
# Process text in chunks with smart splitting, handling pause tags
|
||||||
async for chunk_text, tokens in smart_split(
|
async for chunk_text, tokens, pause_duration_s in smart_split(
|
||||||
text,
|
text,
|
||||||
lang_code=pipeline_lang_code,
|
lang_code=pipeline_lang_code,
|
||||||
normalization_options=normalization_options,
|
normalization_options=normalization_options,
|
||||||
):
|
):
|
||||||
try:
|
if pause_duration_s is not None and pause_duration_s > 0:
|
||||||
# Process audio for chunk
|
# --- Handle Pause Chunk ---
|
||||||
async for chunk_data in self._process_chunk(
|
try:
|
||||||
chunk_text, # Pass text for Kokoro V1
|
logger.debug(f"Generating {pause_duration_s}s silence chunk")
|
||||||
tokens, # Pass tokens for legacy backends
|
silence_samples = int(pause_duration_s * 24000) # 24kHz sample rate
|
||||||
voice_name, # Pass voice name
|
# Create proper silence as int16 zeros to avoid normalization artifacts
|
||||||
voice_path, # Pass voice path
|
silence_audio = np.zeros(silence_samples, dtype=np.int16)
|
||||||
speed,
|
pause_chunk = AudioChunk(audio=silence_audio, word_timestamps=[]) # Empty timestamps for silence
|
||||||
writer,
|
|
||||||
output_format,
|
|
||||||
is_first=(chunk_index == 0),
|
|
||||||
is_last=False, # We'll update the last chunk later
|
|
||||||
normalizer=stream_normalizer,
|
|
||||||
lang_code=pipeline_lang_code, # Pass lang_code
|
|
||||||
return_timestamps=return_timestamps,
|
|
||||||
):
|
|
||||||
if chunk_data.word_timestamps is not None:
|
|
||||||
for timestamp in chunk_data.word_timestamps:
|
|
||||||
timestamp.start_time += current_offset
|
|
||||||
timestamp.end_time += current_offset
|
|
||||||
|
|
||||||
current_offset += len(chunk_data.audio) / 24000
|
# Format and yield the silence chunk
|
||||||
|
if output_format:
|
||||||
if chunk_data.output is not None:
|
formatted_pause_chunk = await AudioService.convert_audio(
|
||||||
yield chunk_data
|
pause_chunk, output_format, writer, speed=speed, chunk_text="",
|
||||||
|
is_last_chunk=False, trim_audio=False, normalizer=stream_normalizer,
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
f"No audio generated for chunk: '{chunk_text[:100]}...'"
|
|
||||||
)
|
)
|
||||||
chunk_index += 1
|
if formatted_pause_chunk.output:
|
||||||
except Exception as e:
|
yield formatted_pause_chunk
|
||||||
logger.error(
|
else: # Raw audio mode
|
||||||
f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}"
|
# For raw audio mode, silence is already in the correct format (int16)
|
||||||
)
|
# Skip normalization to avoid any potential artifacts
|
||||||
continue
|
if len(pause_chunk.audio) > 0:
|
||||||
|
yield pause_chunk
|
||||||
|
|
||||||
|
# Update offset based on silence duration
|
||||||
|
current_offset += pause_duration_s
|
||||||
|
chunk_index += 1 # Count pause as a yielded chunk
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to process pause chunk: {str(e)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
elif tokens or chunk_text.strip(): # Process if there are tokens OR non-whitespace text
|
||||||
|
# --- Handle Text Chunk ---
|
||||||
|
try:
|
||||||
|
# Process audio for chunk
|
||||||
|
async for chunk_data in self._process_chunk(
|
||||||
|
chunk_text, # Pass text for Kokoro V1
|
||||||
|
tokens, # Pass tokens for legacy backends
|
||||||
|
voice_name, # Pass voice name
|
||||||
|
voice_path, # Pass voice path
|
||||||
|
speed,
|
||||||
|
writer,
|
||||||
|
output_format,
|
||||||
|
is_first=(chunk_index == 0),
|
||||||
|
is_last=False, # We'll update the last chunk later
|
||||||
|
normalizer=stream_normalizer,
|
||||||
|
lang_code=pipeline_lang_code, # Pass lang_code
|
||||||
|
return_timestamps=return_timestamps,
|
||||||
|
):
|
||||||
|
if chunk_data.word_timestamps is not None:
|
||||||
|
for timestamp in chunk_data.word_timestamps:
|
||||||
|
timestamp.start_time += current_offset
|
||||||
|
timestamp.end_time += current_offset
|
||||||
|
|
||||||
|
# Update offset based on the actual duration of the generated audio chunk
|
||||||
|
chunk_duration = 0
|
||||||
|
if chunk_data.audio is not None and len(chunk_data.audio) > 0:
|
||||||
|
chunk_duration = len(chunk_data.audio) / 24000
|
||||||
|
current_offset += chunk_duration
|
||||||
|
|
||||||
|
# Yield the processed chunk (either formatted or raw)
|
||||||
|
if chunk_data.output is not None:
|
||||||
|
yield chunk_data
|
||||||
|
elif chunk_data.audio is not None and len(chunk_data.audio) > 0:
|
||||||
|
yield chunk_data
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"No audio generated for chunk: '{chunk_text[:100]}...'"
|
||||||
|
)
|
||||||
|
|
||||||
|
chunk_index += 1 # Increment chunk index after processing text
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
# Only finalize if we successfully processed at least one chunk
|
# Only finalize if we successfully processed at least one chunk
|
||||||
if chunk_index > 0:
|
if chunk_index > 0:
|
||||||
|
|
|
@ -67,7 +67,7 @@ async def test_smart_split_short_text():
|
||||||
"""Test smart splitting with text under max tokens."""
|
"""Test smart splitting with text under max tokens."""
|
||||||
text = "This is a short test sentence."
|
text = "This is a short test sentence."
|
||||||
chunks = []
|
chunks = []
|
||||||
async for chunk_text, chunk_tokens in smart_split(text):
|
async for chunk_text, chunk_tokens, _ in smart_split(text):
|
||||||
chunks.append((chunk_text, chunk_tokens))
|
chunks.append((chunk_text, chunk_tokens))
|
||||||
|
|
||||||
assert len(chunks) == 1
|
assert len(chunks) == 1
|
||||||
|
@ -82,7 +82,7 @@ async def test_smart_split_long_text():
|
||||||
text = ". ".join(["This is test sentence number " + str(i) for i in range(20)])
|
text = ". ".join(["This is test sentence number " + str(i) for i in range(20)])
|
||||||
|
|
||||||
chunks = []
|
chunks = []
|
||||||
async for chunk_text, chunk_tokens in smart_split(text):
|
async for chunk_text, chunk_tokens, _ in smart_split(text):
|
||||||
chunks.append((chunk_text, chunk_tokens))
|
chunks.append((chunk_text, chunk_tokens))
|
||||||
|
|
||||||
assert len(chunks) > 1
|
assert len(chunks) > 1
|
||||||
|
@ -98,12 +98,13 @@ async def test_smart_split_with_punctuation():
|
||||||
text = "First sentence! Second sentence? Third sentence; Fourth sentence: Fifth sentence."
|
text = "First sentence! Second sentence? Third sentence; Fourth sentence: Fifth sentence."
|
||||||
|
|
||||||
chunks = []
|
chunks = []
|
||||||
async for chunk_text, chunk_tokens in smart_split(text):
|
async for chunk_text, chunk_tokens, _ in smart_split(text):
|
||||||
chunks.append(chunk_text)
|
chunks.append(chunk_text)
|
||||||
|
|
||||||
# Verify punctuation is preserved
|
# Verify punctuation is preserved
|
||||||
assert all(any(p in chunk for p in "!?;:.") for chunk in chunks)
|
assert all(any(p in chunk for p in "!?;:.") for chunk in chunks)
|
||||||
|
|
||||||
|
|
||||||
def test_process_text_chunk_chinese_phonemes():
|
def test_process_text_chunk_chinese_phonemes():
|
||||||
"""Test processing with Chinese pinyin phonemes."""
|
"""Test processing with Chinese pinyin phonemes."""
|
||||||
pinyin = "nǐ hǎo lì" # Example pinyin sequence with tones
|
pinyin = "nǐ hǎo lì" # Example pinyin sequence with tones
|
||||||
|
@ -125,12 +126,13 @@ def test_get_sentence_info_chinese():
|
||||||
assert count == len(tokens)
|
assert count == len(tokens)
|
||||||
assert count > 0
|
assert count > 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_smart_split_chinese_short():
|
async def test_smart_split_chinese_short():
|
||||||
"""Test Chinese smart splitting with short text."""
|
"""Test Chinese smart splitting with short text."""
|
||||||
text = "这是一句话。"
|
text = "这是一句话。"
|
||||||
chunks = []
|
chunks = []
|
||||||
async for chunk_text, chunk_tokens in smart_split(text, lang_code="z"):
|
async for chunk_text, chunk_tokens, _ in smart_split(text, lang_code="z"):
|
||||||
chunks.append((chunk_text, chunk_tokens))
|
chunks.append((chunk_text, chunk_tokens))
|
||||||
|
|
||||||
assert len(chunks) == 1
|
assert len(chunks) == 1
|
||||||
|
@ -144,7 +146,7 @@ async def test_smart_split_chinese_long():
|
||||||
text = "。".join([f"测试句子 {i}" for i in range(20)])
|
text = "。".join([f"测试句子 {i}" for i in range(20)])
|
||||||
|
|
||||||
chunks = []
|
chunks = []
|
||||||
async for chunk_text, chunk_tokens in smart_split(text, lang_code="z"):
|
async for chunk_text, chunk_tokens, _ in smart_split(text, lang_code="z"):
|
||||||
chunks.append((chunk_text, chunk_tokens))
|
chunks.append((chunk_text, chunk_tokens))
|
||||||
|
|
||||||
assert len(chunks) > 1
|
assert len(chunks) > 1
|
||||||
|
@ -160,8 +162,36 @@ async def test_smart_split_chinese_punctuation():
|
||||||
text = "第一句!第二问?第三句;第四句:第五句。"
|
text = "第一句!第二问?第三句;第四句:第五句。"
|
||||||
|
|
||||||
chunks = []
|
chunks = []
|
||||||
async for chunk_text, _ in smart_split(text, lang_code="z"):
|
async for chunk_text, _, _ in smart_split(text, lang_code="z"):
|
||||||
chunks.append(chunk_text)
|
chunks.append(chunk_text)
|
||||||
|
|
||||||
# Verify Chinese punctuation is preserved
|
# Verify Chinese punctuation is preserved
|
||||||
assert all(any(p in chunk for p in "!?;:。") for chunk in chunks)
|
assert all(any(p in chunk for p in "!?;:。") for chunk in chunks)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_smart_split_with_pause():
|
||||||
|
"""Test smart splitting with pause tags."""
|
||||||
|
text = "Hello world [pause:2.5s] How are you?"
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
async for chunk_text, chunk_tokens, pause_duration in smart_split(text):
|
||||||
|
chunks.append((chunk_text, chunk_tokens, pause_duration))
|
||||||
|
|
||||||
|
# Should have 3 chunks: text, pause, text
|
||||||
|
assert len(chunks) == 3
|
||||||
|
|
||||||
|
# First chunk: text
|
||||||
|
assert chunks[0][2] is None # No pause
|
||||||
|
assert "Hello world" in chunks[0][0]
|
||||||
|
assert len(chunks[0][1]) > 0
|
||||||
|
|
||||||
|
# Second chunk: pause
|
||||||
|
assert chunks[1][2] == 2.5 # 2.5 second pause
|
||||||
|
assert chunks[1][0] == "" # Empty text
|
||||||
|
assert len(chunks[1][1]) == 0 # No tokens
|
||||||
|
|
||||||
|
# Third chunk: text
|
||||||
|
assert chunks[2][2] is None # No pause
|
||||||
|
assert "How are you?" in chunks[2][0]
|
||||||
|
assert len(chunks[2][1]) > 0
|
Loading…
Add table
Reference in a new issue