mirror of https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00

Merge pull request #350 from fireblade2534/master

Fix custom phenomes and make them much more robust

Commit f8c89161f6

3 changed files with 39 additions and 39 deletions
@@ -497,4 +497,4 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
     text = re.sub(r"\s{2,}", " ", text)

-    return text.strip()
+    return text

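A note on why the trailing `.strip()` was likely safe to drop: the reworked `smart_split` below strips the re-joined string once at the end, and stripping inside `normalize_text` would now be harmful, because each plain-text segment between custom phoneme tags is normalized on its own and a per-segment strip would eat the spaces next to a tag. A minimal sketch of that failure mode, using a hypothetical stripping normalizer as a stand-in for the real `normalize_text`:

    import re

    # Same single-group pattern the PR introduces (next hunk).
    CUSTOM_PHONEMES = re.compile(r"(\[[^\[\]]*?\]\(\/[^\/\(\)]*?\/\))")

    def stripping_normalizer(segment: str) -> str:
        # Hypothetical stand-in: a normalizer that still strips its output.
        return segment.strip()

    parts = CUSTOM_PHONEMES.split("Hello there [Kokoro](/kˈOkəɹO/) speaking.")
    rejoined = "".join(
        stripping_normalizer(p) if i % 2 == 0 else p for i, p in enumerate(parts)
    )
    # rejoined == "Hello there[Kokoro](/kˈOkəɹO/)speaking."
    # The spaces around the tag are lost, so stripping now happens once,
    # after the segments are joined back together.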
@@ -15,7 +15,7 @@ from .vocabulary import tokenize
 # Pre-compiled regex patterns for performance
 # Updated regex to be more strict and avoid matching isolated brackets
 # Only matches complete patterns like [word](/ipa/) and prevents catastrophic backtracking
-CUSTOM_PHONEMES = re.compile(r"(\[[^\[\]]*?\])(\(\/[^\/\(\)]*?\/\))")
+CUSTOM_PHONEMES = re.compile(r"(\[[^\[\]]*?\]\(\/[^\/\(\)]*?\/\))")
 # Pattern to find pause tags like [pause:0.5s]
 PAUSE_TAG_PATTERN = re.compile(r"\[pause:(\d+(?:\.\d+)?)s\]", re.IGNORECASE)

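The practical difference between the two patterns: with a single capture group around the whole `[word](/ipa/)` tag, `re.split` returns each tag as one standalone item at an odd index of the result, which is how the reworked `smart_split` below consumes it. The old two-group version would break each tag into two pieces and ruin that even/odd layout. A quick illustration (the sample sentence is just an example):

    import re

    OLD = re.compile(r"(\[[^\[\]]*?\])(\(\/[^\/\(\)]*?\/\))")
    NEW = re.compile(r"(\[[^\[\]]*?\]\(\/[^\/\(\)]*?\/\))")

    text = "Hello [Kokoro](/kˈOkəɹO/) world"

    print(OLD.split(text))
    # ['Hello ', '[Kokoro]', '(/kˈOkəɹO/)', ' world']  -> two pieces per tag

    print(NEW.split(text))
    # ['Hello ', '[Kokoro](/kˈOkəɹO/)', ' world']      -> plain text at even indices,
    #                                                     whole tags at odd indices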
@@ -100,7 +100,7 @@ def process_text(text: str, language: str = "a") -> List[int]:


 def get_sentence_info(
-    text: str, custom_phenomes_list: Dict[str, str], lang_code: str = "a"
+    text: str, lang_code: str = "a"
 ) -> List[Tuple[str, List[int], int]]:
     """Process all sentences and return info"""
     # Detect Chinese text
@@ -110,18 +110,10 @@ def get_sentence_info(
         sentences = re.split(r"([,。!?;])+", text)
     else:
         sentences = re.split(r"([.!?;:])(?=\s|$)", text)
-    phoneme_length, min_value = len(custom_phenomes_list), 0

     results = []
     for i in range(0, len(sentences), 2):
         sentence = sentences[i].strip()
-        for replaced in range(min_value, phoneme_length):
-            current_id = f"</|custom_phonemes_{replaced}|/>"
-            if current_id in sentence:
-                sentence = sentence.replace(
-                    current_id, custom_phenomes_list.pop(current_id)
-                )
-                min_value += 1
         punct = sentences[i + 1] if i + 1 < len(sentences) else ""
         if not sentence:
             continue

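For context on the `range(0, len(sentences), 2)` loop that survives this change: the split pattern keeps its punctuation capture group, so `re.split` yields alternating sentence/punctuation entries, with sentence text at even indices and the punctuation that ended it at `sentences[i + 1]`. Roughly:

    import re

    parts = re.split(r"([.!?;:])(?=\s|$)", "One. Two! Three?")
    # ['One', '.', ' Two', '!', ' Three', '?', '']
    # Even indices hold the sentence text, odd indices the punctuation that ended it;
    # the empty trailing entry is skipped by the `if not sentence: continue` check.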
@@ -173,24 +165,23 @@ async def smart_split(
         # Strip leading and trailing spaces to prevent pause tag splitting artifacts
         text_part_raw = text_part_raw.strip()

-        # Apply the original smart_split logic to this text part
-        custom_phoneme_list = {}

         # Normalize text (original logic)
         processed_text = text_part_raw
         if settings.advanced_text_normalization and normalization_options.normalize:
             if lang_code in ["a", "b", "en-us", "en-gb"]:
-                processed_text = CUSTOM_PHONEMES.sub(
-                    lambda s: handle_custom_phonemes(s, custom_phoneme_list), processed_text
-                )
-                processed_text = normalize_text(processed_text, normalization_options)
+                processed_text = CUSTOM_PHONEMES.split(processed_text)
+                for index in range(0, len(processed_text), 2):
+                    processed_text[index] = normalize_text(processed_text[index], normalization_options)
+
+                processed_text = "".join(processed_text).strip()
             else:
                 logger.info(
                     "Skipping text normalization as it is only supported for english"
                 )

         # Process all sentences (original logic)
-        sentences = get_sentence_info(processed_text, custom_phoneme_list, lang_code=lang_code)
+        sentences = get_sentence_info(processed_text, lang_code=lang_code)

         current_chunk = []
         current_tokens = []

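Pulled out of the diff context, the new normalization flow amounts to: split on the phoneme pattern, normalize only the plain-text segments, re-join, then strip once. A self-contained sketch of that flow, with a toy stand-in normalizer (`normalize_around_phonemes` and the lambda are illustrative names, not part of the repo; the real code calls `normalize_text` with its `NormalizationOptions`):

    import re

    CUSTOM_PHONEMES = re.compile(r"(\[[^\[\]]*?\]\(\/[^\/\(\)]*?\/\))")

    def normalize_around_phonemes(text: str, normalize) -> str:
        # Even indices are plain text, odd indices are [word](/ipa/) tags,
        # so only the plain text ever reaches the normalizer.
        parts = CUSTOM_PHONEMES.split(text)
        for i in range(0, len(parts), 2):
            parts[i] = normalize(parts[i])
        return "".join(parts).strip()

    # Toy normalizer just to show the effect; the real one expands numbers, etc.
    result = normalize_around_phonemes(
        "[Kokoro](/kˈOkəɹO/) version 1.0 is here.",
        lambda s: s.replace("1.0", "one point zero"),
    )
    # "[Kokoro](/kˈOkəɹO/) version one point zero is here."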
@@ -34,7 +34,7 @@ def test_process_text_chunk_phonemes():
 def test_get_sentence_info():
     """Test sentence splitting and info extraction."""
     text = "This is sentence one. This is sentence two! What about three?"
-    results = get_sentence_info(text, {})
+    results = get_sentence_info(text)

     assert len(results) == 3
     for sentence, tokens, count in results:
@@ -44,24 +44,6 @@ def test_get_sentence_info():
     assert count == len(tokens)
     assert count > 0

-
-def test_get_sentence_info_phenomoes():
-    """Test sentence splitting and info extraction."""
-    text = (
-        "This is sentence one. This is </|custom_phonemes_0|/> two! What about three?"
-    )
-    results = get_sentence_info(text, {"</|custom_phonemes_0|/>": r"sˈɛntᵊns"})
-
-    assert len(results) == 3
-    assert "sˈɛntᵊns" in results[1][0]
-    for sentence, tokens, count in results:
-        assert isinstance(sentence, str)
-        assert isinstance(tokens, list)
-        assert isinstance(count, int)
-        assert count == len(tokens)
-        assert count > 0
-
 @pytest.mark.asyncio
 async def test_smart_split_short_text():
     """Test smart splitting with text under max tokens."""
@@ -74,6 +56,33 @@ async def test_smart_split_short_text():
     assert isinstance(chunks[0][0], str)
     assert isinstance(chunks[0][1], list)

+@pytest.mark.asyncio
+async def test_smart_custom_phenomes():
+    """Test smart splitting with text under max tokens."""
+    text = "This is a short test sentence. [Kokoro](/kˈOkəɹO/) has a feature called custom phenomes. This is made possible by [Misaki](/misˈɑki/), the custom phenomizer that [Kokoro](/kˈOkəɹO/) version 1.0 uses"
+    chunks = []
+    async for chunk_text, chunk_tokens, pause_duration in smart_split(text):
+        chunks.append((chunk_text, chunk_tokens, pause_duration))
+
+    # Should have 1 chunks: text
+    assert len(chunks) == 1
+
+    # First chunk: text
+    assert chunks[0][2] is None  # No pause
+    assert "This is a short test sentence. [Kokoro](/kˈOkəɹO/) has a feature called custom phenomes. This is made possible by [Misaki](/misˈɑki/), the custom phenomizer that [Kokoro](/kˈOkəɹO/) version one uses" in chunks[0][0]
+    assert len(chunks[0][1]) > 0
+
+@pytest.mark.asyncio
+async def test_smart_split_only_phenomes():
+    """Test input that is entirely made of phenome annotations."""
+    text = "[Kokoro](/kˈOkəɹO/) [Misaki 1.2](/misˈɑki/) [Test](/tɛst/)"
+    chunks = []
+    async for chunk_text, chunk_tokens, pause_duration in smart_split(text, max_tokens=10):
+        chunks.append((chunk_text, chunk_tokens, pause_duration))
+
+    assert len(chunks) == 1
+    assert "[Kokoro](/kˈOkəɹO/) [Misaki 1.2](/misˈɑki/) [Test](/tɛst/)" in chunks[0][0]
+

 @pytest.mark.asyncio
 async def test_smart_split_long_text():
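To run just the new tests locally, selecting them by name substring should work; this is a generic pytest invocation, not something the PR itself prescribes:

    import pytest

    # Runs only the tests whose names contain "phenomes"
    # (test_smart_custom_phenomes and test_smart_split_only_phenomes above).
    pytest.main(["-k", "phenomes", "-v"])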
@@ -116,7 +125,7 @@ def test_process_text_chunk_chinese_phonemes():
 def test_get_sentence_info_chinese():
     """Test Chinese sentence splitting and info extraction."""
     text = "这是一个句子。这是第二个句子!第三个问题?"
-    results = get_sentence_info(text, {}, lang_code="z")
+    results = get_sentence_info(text, lang_code="z")

     assert len(results) == 3
     for sentence, tokens, count in results: