feat(text): add Chinese punctuation-based sentence splitting for better TTS

Author: jiaohuix
Date:   2025-05-26 15:30:03 +08:00
Parent: ce22f60344
Commit: 9c279f2b5e

2 changed files with 41 additions and 25 deletions

@@ -11,7 +11,7 @@ from typing import List, Optional, Union
 import inflect
 from numpy import number
-from text_to_num import text2num
+# from text_to_num import text2num
 from torch import mul
 from ...structures.schemas import NormalizationOptions

@@ -88,32 +88,48 @@ def process_text(text: str, language: str = "a") -> List[int]:
 def get_sentence_info(
-    text: str, custom_phenomes_list: Dict[str, str]
+    text: str, custom_phenomes_list: Dict[str, str], lang_code: str = "a"
 ) -> List[Tuple[str, List[int], int]]:
-    """Process all sentences and return info."""
-    sentences = re.split(r"([.!?;:])(?=\s|$)", text)
+    """Process all sentences and return info; supports Chinese sentence splitting."""
+    # Detect whether the text is Chinese
+    is_chinese = lang_code.startswith("zh") or re.search(r"[\u4e00-\u9fff]", text)
+    if is_chinese:
+        # Split on Chinese punctuation marks
+        sentences = re.split(r"([,。!?;])", text)
+        # Re-attach each punctuation mark to its sentence
+        merged = []
+        for i in range(0, len(sentences) - 1, 2):
+            merged.append(sentences[i] + sentences[i + 1])
+        if len(sentences) % 2 == 1:
+            merged.append(sentences[-1])
+        sentences = merged
+    else:
+        sentences = re.split(r"([.!?;:])(?=\s|$)", text)
     phoneme_length, min_value = len(custom_phenomes_list), 0
 
     results = []
-    for i in range(0, len(sentences), 2):
-        sentence = sentences[i].strip()
-        for replaced in range(min_value, phoneme_length):
-            current_id = f"</|custom_phonemes_{replaced}|/>"
-            if current_id in sentence:
-                sentence = sentence.replace(
-                    current_id, custom_phenomes_list.pop(current_id)
-                )
-                min_value += 1
-
-        punct = sentences[i + 1] if i + 1 < len(sentences) else ""
-
-        if not sentence:
-            continue
-
-        full = sentence + punct
-        tokens = process_text_chunk(full)
-        results.append((full, tokens, len(tokens)))
+    if is_chinese:
+        for sentence in sentences:
+            sentence = sentence.strip()
+            if not sentence:
+                continue
+            tokens = process_text_chunk(sentence)
+            results.append((sentence, tokens, len(tokens)))
+    else:
+        for i in range(0, len(sentences), 2):
+            sentence = sentences[i].strip()
+            for replaced in range(min_value, phoneme_length):
+                current_id = f"</|custom_phonemes_{replaced}|/>"
+                if current_id in sentence:
+                    sentence = sentence.replace(
+                        current_id, custom_phenomes_list.pop(current_id)
+                    )
+                    min_value += 1
+            punct = sentences[i + 1] if i + 1 < len(sentences) else ""
+            if not sentence:
+                continue
+            full = sentence + punct
+            tokens = process_text_chunk(full)
+            results.append((full, tokens, len(tokens)))
     return results
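For reference, a standalone sketch of the split-and-merge step above, run on a sample sentence (the sample text and the commented values are illustrative, not from the repo):

import re

# Mirrors the Chinese branch of get_sentence_info: split on fullwidth
# punctuation while capturing the delimiters, then glue each delimiter
# back onto the sentence it terminates.
text = "你好,今天天气很好。我们去公园吧!"
parts = re.split(r"([,。!?;])", text)
# parts == ['你好', ',', '今天天气很好', '。', '我们去公园吧', '!', '']

merged = []
for i in range(0, len(parts) - 1, 2):
    merged.append(parts[i] + parts[i + 1])
if len(parts) % 2 == 1:
    merged.append(parts[-1])  # trailing remainder (empty here)
# merged == ['你好,', '今天天气很好。', '我们去公园吧!', '']
# The empty trailing piece is dropped later by the `if not sentence` check.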
@@ -150,7 +166,7 @@ async def smart_split(
     )
 
     # Process all sentences
-    sentences = get_sentence_info(text, custom_phoneme_list)
+    sentences = get_sentence_info(text, custom_phoneme_list, lang_code=lang_code)
 
     current_chunk = []
     current_tokens = []
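A minimal sketch of the language dispatch that smart_split now feeds into get_sentence_info; the helper name is_chinese is hypothetical (the diff inlines this exact expression as a local variable):

import re

def is_chinese(text: str, lang_code: str = "a") -> bool:
    # True if the caller asked for Chinese, or the text contains
    # CJK unified ideographs (U+4E00..U+9FFF).
    return lang_code.startswith("zh") or bool(re.search(r"[\u4e00-\u9fff]", text))

print(is_chinese("Hello world."))           # False -> English split path
print(is_chinese("你好世界"))                # True  -> detected from content
print(is_chinese("hello", lang_code="zh"))  # True  -> forced by lang_code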