mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00
Merge pull request #350 from fireblade2534/master
Fix custom phenomes and make them much more robust
This commit is contained in:
commit
f8c89161f6
3 changed files with 39 additions and 39 deletions
|
@ -497,4 +497,4 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st
|
|||
|
||||
text = re.sub(r"\s{2,}", " ", text)
|
||||
|
||||
return text.strip()
|
||||
return text
|
||||
|
|
|
@ -15,7 +15,7 @@ from .vocabulary import tokenize
|
|||
# Pre-compiled regex patterns for performance
|
||||
# Updated regex to be more strict and avoid matching isolated brackets
|
||||
# Only matches complete patterns like [word](/ipa/) and prevents catastrophic backtracking
|
||||
CUSTOM_PHONEMES = re.compile(r"(\[[^\[\]]*?\])(\(\/[^\/\(\)]*?\/\))")
|
||||
CUSTOM_PHONEMES = re.compile(r"(\[[^\[\]]*?\]\(\/[^\/\(\)]*?\/\))")
|
||||
# Pattern to find pause tags like [pause:0.5s]
|
||||
PAUSE_TAG_PATTERN = re.compile(r"\[pause:(\d+(?:\.\d+)?)s\]", re.IGNORECASE)
|
||||
|
||||
|
@ -100,7 +100,7 @@ def process_text(text: str, language: str = "a") -> List[int]:
|
|||
|
||||
|
||||
def get_sentence_info(
|
||||
text: str, custom_phenomes_list: Dict[str, str], lang_code: str = "a"
|
||||
text: str, lang_code: str = "a"
|
||||
) -> List[Tuple[str, List[int], int]]:
|
||||
"""Process all sentences and return info"""
|
||||
# Detect Chinese text
|
||||
|
@ -110,18 +110,10 @@ def get_sentence_info(
|
|||
sentences = re.split(r"([,。!?;])+", text)
|
||||
else:
|
||||
sentences = re.split(r"([.!?;:])(?=\s|$)", text)
|
||||
phoneme_length, min_value = len(custom_phenomes_list), 0
|
||||
|
||||
results = []
|
||||
for i in range(0, len(sentences), 2):
|
||||
sentence = sentences[i].strip()
|
||||
for replaced in range(min_value, phoneme_length):
|
||||
current_id = f"</|custom_phonemes_{replaced}|/>"
|
||||
if current_id in sentence:
|
||||
sentence = sentence.replace(
|
||||
current_id, custom_phenomes_list.pop(current_id)
|
||||
)
|
||||
min_value += 1
|
||||
punct = sentences[i + 1] if i + 1 < len(sentences) else ""
|
||||
if not sentence:
|
||||
continue
|
||||
|
@ -173,24 +165,23 @@ async def smart_split(
|
|||
# Strip leading and trailing spaces to prevent pause tag splitting artifacts
|
||||
text_part_raw = text_part_raw.strip()
|
||||
|
||||
# Apply the original smart_split logic to this text part
|
||||
custom_phoneme_list = {}
|
||||
|
||||
# Normalize text (original logic)
|
||||
processed_text = text_part_raw
|
||||
if settings.advanced_text_normalization and normalization_options.normalize:
|
||||
if lang_code in ["a", "b", "en-us", "en-gb"]:
|
||||
processed_text = CUSTOM_PHONEMES.sub(
|
||||
lambda s: handle_custom_phonemes(s, custom_phoneme_list), processed_text
|
||||
)
|
||||
processed_text = normalize_text(processed_text, normalization_options)
|
||||
processed_text = CUSTOM_PHONEMES.split(processed_text)
|
||||
for index in range(0, len(processed_text), 2):
|
||||
processed_text[index] = normalize_text(processed_text[index], normalization_options)
|
||||
|
||||
|
||||
processed_text = "".join(processed_text).strip()
|
||||
else:
|
||||
logger.info(
|
||||
"Skipping text normalization as it is only supported for english"
|
||||
)
|
||||
|
||||
# Process all sentences (original logic)
|
||||
sentences = get_sentence_info(processed_text, custom_phoneme_list, lang_code=lang_code)
|
||||
sentences = get_sentence_info(processed_text, lang_code=lang_code)
|
||||
|
||||
current_chunk = []
|
||||
current_tokens = []
|
||||
|
|
|
@ -34,7 +34,7 @@ def test_process_text_chunk_phonemes():
|
|||
def test_get_sentence_info():
|
||||
"""Test sentence splitting and info extraction."""
|
||||
text = "This is sentence one. This is sentence two! What about three?"
|
||||
results = get_sentence_info(text, {})
|
||||
results = get_sentence_info(text)
|
||||
|
||||
assert len(results) == 3
|
||||
for sentence, tokens, count in results:
|
||||
|
@ -44,24 +44,6 @@ def test_get_sentence_info():
|
|||
assert count == len(tokens)
|
||||
assert count > 0
|
||||
|
||||
|
||||
def test_get_sentence_info_phenomoes():
|
||||
"""Test sentence splitting and info extraction."""
|
||||
text = (
|
||||
"This is sentence one. This is </|custom_phonemes_0|/> two! What about three?"
|
||||
)
|
||||
results = get_sentence_info(text, {"</|custom_phonemes_0|/>": r"sˈɛntᵊns"})
|
||||
|
||||
assert len(results) == 3
|
||||
assert "sˈɛntᵊns" in results[1][0]
|
||||
for sentence, tokens, count in results:
|
||||
assert isinstance(sentence, str)
|
||||
assert isinstance(tokens, list)
|
||||
assert isinstance(count, int)
|
||||
assert count == len(tokens)
|
||||
assert count > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_split_short_text():
|
||||
"""Test smart splitting with text under max tokens."""
|
||||
|
@ -74,6 +56,33 @@ async def test_smart_split_short_text():
|
|||
assert isinstance(chunks[0][0], str)
|
||||
assert isinstance(chunks[0][1], list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_custom_phenomes():
|
||||
"""Test smart splitting with text under max tokens."""
|
||||
text = "This is a short test sentence. [Kokoro](/kˈOkəɹO/) has a feature called custom phenomes. This is made possible by [Misaki](/misˈɑki/), the custom phenomizer that [Kokoro](/kˈOkəɹO/) version 1.0 uses"
|
||||
chunks = []
|
||||
async for chunk_text, chunk_tokens, pause_duration in smart_split(text):
|
||||
chunks.append((chunk_text, chunk_tokens, pause_duration))
|
||||
|
||||
# Should have 1 chunks: text
|
||||
assert len(chunks) == 1
|
||||
|
||||
# First chunk: text
|
||||
assert chunks[0][2] is None # No pause
|
||||
assert "This is a short test sentence. [Kokoro](/kˈOkəɹO/) has a feature called custom phenomes. This is made possible by [Misaki](/misˈɑki/), the custom phenomizer that [Kokoro](/kˈOkəɹO/) version one uses" in chunks[0][0]
|
||||
assert len(chunks[0][1]) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_split_only_phenomes():
|
||||
"""Test input that is entirely made of phenome annotations."""
|
||||
text = "[Kokoro](/kˈOkəɹO/) [Misaki 1.2](/misˈɑki/) [Test](/tɛst/)"
|
||||
chunks = []
|
||||
async for chunk_text, chunk_tokens, pause_duration in smart_split(text, max_tokens=10):
|
||||
chunks.append((chunk_text, chunk_tokens, pause_duration))
|
||||
|
||||
assert len(chunks) == 1
|
||||
assert "[Kokoro](/kˈOkəɹO/) [Misaki 1.2](/misˈɑki/) [Test](/tɛst/)" in chunks[0][0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_split_long_text():
|
||||
|
@ -116,7 +125,7 @@ def test_process_text_chunk_chinese_phonemes():
|
|||
def test_get_sentence_info_chinese():
|
||||
"""Test Chinese sentence splitting and info extraction."""
|
||||
text = "这是一个句子。这是第二个句子!第三个问题?"
|
||||
results = get_sentence_info(text, {}, lang_code="z")
|
||||
results = get_sentence_info(text, lang_code="z")
|
||||
|
||||
assert len(results) == 3
|
||||
for sentence, tokens, count in results:
|
||||
|
|
Loading…
Add table
Reference in a new issue