From 2a6d2ae483eeec0db7fe29b3796111cc4df6cdb4 Mon Sep 17 00:00:00 2001
From: Fireblade2534
Date: Thu, 26 Jun 2025 00:16:06 +0000
Subject: [PATCH] Fix custom phenomes and make them more robust

---
 .../services/text_processing/normalizer.py |  2 +-
 .../text_processing/text_processor.py      | 27 ++++------
 api/tests/test_text_processor.py           | 49 +++++++++++--------
 3 files changed, 39 insertions(+), 39 deletions(-)

diff --git a/api/src/services/text_processing/normalizer.py b/api/src/services/text_processing/normalizer.py
index 7908318..1b1c9f7 100644
--- a/api/src/services/text_processing/normalizer.py
+++ b/api/src/services/text_processing/normalizer.py
@@ -497,4 +497,4 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st
 
     text = re.sub(r"\s{2,}", " ", text)
 
-    return text.strip()
+    return text
diff --git a/api/src/services/text_processing/text_processor.py b/api/src/services/text_processing/text_processor.py
index 39b0d6c..483618f 100644
--- a/api/src/services/text_processing/text_processor.py
+++ b/api/src/services/text_processing/text_processor.py
@@ -15,7 +15,7 @@ from .vocabulary import tokenize
 # Pre-compiled regex patterns for performance
 # Updated regex to be more strict and avoid matching isolated brackets
 # Only matches complete patterns like [word](/ipa/) and prevents catastrophic backtracking
-CUSTOM_PHONEMES = re.compile(r"(\[[^\[\]]*?\])(\(\/[^\/\(\)]*?\/\))")
+CUSTOM_PHONEMES = re.compile(r"(\[[^\[\]]*?\]\(\/[^\/\(\)]*?\/\))")
 # Pattern to find pause tags like [pause:0.5s]
 PAUSE_TAG_PATTERN = re.compile(r"\[pause:(\d+(?:\.\d+)?)s\]", re.IGNORECASE)
 
@@ -100,7 +100,7 @@ def process_text(text: str, language: str = "a") -> List[int]:
 
 
 def get_sentence_info(
-    text: str, custom_phenomes_list: Dict[str, str], lang_code: str = "a"
+    text: str, lang_code: str = "a"
 ) -> List[Tuple[str, List[int], int]]:
     """Process all sentences and return info"""
     # Detect Chinese text
@@ -110,18 +110,10 @@ def get_sentence_info(
         sentences = re.split(r"([,。!?;])+", text)
     else:
         sentences = re.split(r"([.!?;:])(?=\s|$)", text)
-    phoneme_length, min_value = len(custom_phenomes_list), 0
 
     results = []
     for i in range(0, len(sentences), 2):
         sentence = sentences[i].strip()
-        for replaced in range(min_value, phoneme_length):
-            current_id = f""
-            if current_id in sentence:
-                sentence = sentence.replace(
-                    current_id, custom_phenomes_list.pop(current_id)
-                )
-                min_value += 1
         punct = sentences[i + 1] if i + 1 < len(sentences) else ""
         if not sentence:
             continue
@@ -173,24 +165,23 @@ async def smart_split(
             # Strip leading and trailing spaces to prevent pause tag splitting artifacts
             text_part_raw = text_part_raw.strip()
 
-            # Apply the original smart_split logic to this text part
-            custom_phoneme_list = {}
-
             # Normalize text (original logic)
            processed_text = text_part_raw
             if settings.advanced_text_normalization and normalization_options.normalize:
                 if lang_code in ["a", "b", "en-us", "en-gb"]:
-                    processed_text = CUSTOM_PHONEMES.sub(
-                        lambda s: handle_custom_phonemes(s, custom_phoneme_list), processed_text
-                    )
-                    processed_text = normalize_text(processed_text, normalization_options)
+                    processed_text = CUSTOM_PHONEMES.split(processed_text)
+                    for index in range(0, len(processed_text), 2):
+                        processed_text[index] = normalize_text(processed_text[index], normalization_options)
+
+
+                    processed_text = "".join(processed_text).strip()
                 else:
                     logger.info(
                         "Skipping text normalization as it is only supported for english"
                     )
 
             # Process all sentences (original logic)
-            sentences = get_sentence_info(processed_text, custom_phoneme_list, lang_code=lang_code)
+            sentences = get_sentence_info(processed_text, lang_code=lang_code)
 
             current_chunk = []
             current_tokens = []
diff --git a/api/tests/test_text_processor.py b/api/tests/test_text_processor.py
index 7495d1d..3fc8a87 100644
--- a/api/tests/test_text_processor.py
+++ b/api/tests/test_text_processor.py
@@ -34,7 +34,7 @@ def test_process_text_chunk_phonemes():
 def test_get_sentence_info():
     """Test sentence splitting and info extraction."""
     text = "This is sentence one. This is sentence two! What about three?"
-    results = get_sentence_info(text, {})
+    results = get_sentence_info(text)
 
     assert len(results) == 3
     for sentence, tokens, count in results:
@@ -44,24 +44,6 @@ def test_get_sentence_info():
         assert count == len(tokens)
         assert count > 0
 
-
-def test_get_sentence_info_phenomoes():
-    """Test sentence splitting and info extraction."""
-    text = (
-        "This is sentence one. This is two! What about three?"
-    )
-    results = get_sentence_info(text, {"": r"sˈɛntᵊns"})
-
-    assert len(results) == 3
-    assert "sˈɛntᵊns" in results[1][0]
-    for sentence, tokens, count in results:
-        assert isinstance(sentence, str)
-        assert isinstance(tokens, list)
-        assert isinstance(count, int)
-        assert count == len(tokens)
-        assert count > 0
-
-
 @pytest.mark.asyncio
 async def test_smart_split_short_text():
     """Test smart splitting with text under max tokens."""
@@ -74,6 +56,33 @@ async def test_smart_split_short_text():
     assert isinstance(chunks[0][0], str)
     assert isinstance(chunks[0][1], list)
 
+@pytest.mark.asyncio
+async def test_smart_custom_phenomes():
+    """Test smart splitting with text under max tokens."""
+    text = "This is a short test sentence. [Kokoro](/kˈOkəɹO/) has a feature called custom phenomes. This is made possible by [Misaki](/misˈɑki/), the custom phenomizer that [Kokoro](/kˈOkəɹO/) version 1.0 uses"
+    chunks = []
+    async for chunk_text, chunk_tokens, pause_duration in smart_split(text):
+        chunks.append((chunk_text, chunk_tokens, pause_duration))
+
+    # Should have 1 chunks: text
+    assert len(chunks) == 1
+
+    # First chunk: text
+    assert chunks[0][2] is None  # No pause
+    assert "This is a short test sentence. [Kokoro](/kˈOkəɹO/) has a feature called custom phenomes. This is made possible by [Misaki](/misˈɑki/), the custom phenomizer that [Kokoro](/kˈOkəɹO/) version one uses" in chunks[0][0]
+    assert len(chunks[0][1]) > 0
+
+@pytest.mark.asyncio
+async def test_smart_split_only_phenomes():
+    """Test input that is entirely made of phenome annotations."""
+    text = "[Kokoro](/kˈOkəɹO/) [Misaki 1.2](/misˈɑki/) [Test](/tɛst/)"
+    chunks = []
+    async for chunk_text, chunk_tokens, pause_duration in smart_split(text, max_tokens=10):
+        chunks.append((chunk_text, chunk_tokens, pause_duration))
+
+    assert len(chunks) == 1
+    assert "[Kokoro](/kˈOkəɹO/) [Misaki 1.2](/misˈɑki/) [Test](/tɛst/)" in chunks[0][0]
+
 
 @pytest.mark.asyncio
 async def test_smart_split_long_text():
@@ -116,7 +125,7 @@ def test_process_text_chunk_chinese_phonemes():
 def test_get_sentence_info_chinese():
     """Test Chinese sentence splitting and info extraction."""
     text = "这是一个句子。这是第二个句子!第三个问题?"
-    results = get_sentence_info(text, {}, lang_code="z")
+    results = get_sentence_info(text, lang_code="z")
 
     assert len(results) == 3
     for sentence, tokens, count in results:
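
Reviewer note on the mechanism: the patched CUSTOM_PHONEMES pattern wraps the whole [word](/ipa/) span in a single capturing group, so re.split() returns the plain text at even indices and the phoneme markup at odd indices. smart_split now normalizes only the even indices and rejoins the list, which is why the markup survives normalization verbatim. It also explains why normalize_text no longer strips its return value: stripping each segment separately would delete the spaces adjacent to the markup, so the single .strip() now happens after the rejoin. Below is a minimal standalone sketch of that split/normalize/rejoin step; fake_normalize is a stand-in for the real normalize_text and is not part of the patch.

import re

# Same single-group pattern as the patched CUSTOM_PHONEMES constant.
CUSTOM_PHONEMES = re.compile(r"(\[[^\[\]]*?\]\(\/[^\/\(\)]*?\/\))")

def fake_normalize(text: str) -> str:
    """Stand-in for normalize_text(); mimics the '1.0' -> 'one' rewrite the new test expects."""
    return text.replace("1.0", "one")

text = "[Kokoro](/kˈOkəɹO/) version 1.0 uses [Misaki](/misˈɑki/)"
parts = CUSTOM_PHONEMES.split(text)
# parts alternates plain text (even indices) with phoneme markup (odd indices):
# ['', '[Kokoro](/kˈOkəɹO/)', ' version 1.0 uses ', '[Misaki](/misˈɑki/)', '']
for index in range(0, len(parts), 2):
    parts[index] = fake_normalize(parts[index])

print("".join(parts).strip())
# [Kokoro](/kˈOkəɹO/) version one uses [Misaki](/misˈɑki/)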