Mirror of https://github.com/remsky/Kokoro-FastAPI.git (synced 2025-08-05 16:48:53 +00:00)

Make the code cleaner and add tests

parent 9c279f2b5e
commit b89da1ff28
4 changed files with 83 additions and 38 deletions
@@ -75,7 +75,7 @@ def create_phonemizer(language: str = "a") -> PhonemizerBackend:
         Phonemizer backend instance
     """
     # Map language codes to espeak language codes
-    lang_map = {"a": "en-us", "b": "en-gb"}
+    lang_map = {"a": "en-us", "b": "en-gb", "z": "z"}
 
     if language not in lang_map:
         raise ValueError(f"Unsupported language code: {language}")
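The effect of the extra "z" entry is that the Mandarin language code now passes the lang_map check instead of raising; whether the underlying espeak backend actually accepts the mapped "z" code is not shown in this hunk. A minimal sketch of that behavior, with the import path given only as a guess:

# Sketch of the behavior change (not part of the diff); the import path below is an assumption.
from api.src.services.text_processing.phonemizer import create_phonemizer

create_phonemizer("z")   # now passes the lang_map validation instead of raising
create_phonemizer("x")   # unknown codes still raise ValueError("Unsupported language code: x")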
@@ -92,44 +92,30 @@ def get_sentence_info(
 ) -> List[Tuple[str, List[int], int]]:
     """Process all sentences and return info, 支持中文分句"""
     # 判断是否为中文
-    is_chinese = lang_code.startswith("zh") or re.search(r"[\u4e00-\u9fff]", text)
+    is_chinese = lang_code.startswith("z") or re.search(r"[\u4e00-\u9fff]", text)
     if is_chinese:
         # 按中文标点断句
-        sentences = re.split(r"([,。!?;])", text)
-        # 合并标点
-        merged = []
-        for i in range(0, len(sentences)-1, 2):
-            merged.append(sentences[i] + sentences[i+1])
-        if len(sentences) % 2 == 1:
-            merged.append(sentences[-1])
-        sentences = merged
+        sentences = re.split(r"([,。!?;])+", text)
     else:
         sentences = re.split(r"([.!?;:])(?=\s|$)", text)
     phoneme_length, min_value = len(custom_phenomes_list), 0
 
     results = []
-    if is_chinese:
-        for sentence in sentences:
-            sentence = sentence.strip()
-            if not sentence:
-                continue
-            tokens = process_text_chunk(sentence)
-            results.append((sentence, tokens, len(tokens)))
-    else:
-        for i in range(0, len(sentences), 2):
-            sentence = sentences[i].strip()
-            for replaced in range(min_value, phoneme_length):
-                current_id = f"</|custom_phonemes_{replaced}|/>"
-                if current_id in sentence:
-                    sentence = sentence.replace(
-                        current_id, custom_phenomes_list.pop(current_id)
-                    )
-                    min_value += 1
-            punct = sentences[i + 1] if i + 1 < len(sentences) else ""
-            if not sentence:
-                continue
-            full = sentence + punct
-            tokens = process_text_chunk(full)
-            results.append((full, tokens, len(tokens)))
+    for i in range(0, len(sentences), 2):
+        sentence = sentences[i].strip()
+        for replaced in range(min_value, phoneme_length):
+            current_id = f"</|custom_phonemes_{replaced}|/>"
+            if current_id in sentence:
+                sentence = sentence.replace(
+                    current_id, custom_phenomes_list.pop(current_id)
+                )
+                min_value += 1
+        punct = sentences[i + 1] if i + 1 < len(sentences) else ""
+        if not sentence:
+            continue
+        full = sentence + punct
+        tokens = process_text_chunk(full)
+        results.append((full, tokens, len(tokens)))
     return results
 
 
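The rewritten function keeps a single merge loop for both languages: re.split with a capturing group returns alternating sentence / punctuation elements, so range(0, len(sentences), 2) walks the sentence slots and re-attaches the captured punctuation. A standalone sketch of that split-and-merge step (not part of the diff), using the sample sentence from the new tests:

import re

text = "这是一个句子。这是第二个句子!第三个问题?"
parts = re.split(r"([,。!?;])+", text)
# parts alternates sentence / punctuation, with a trailing empty string:
# ['这是一个句子', '。', '这是第二个句子', '!', '第三个问题', '?', '']

merged = []
for i in range(0, len(parts), 2):
    sentence = parts[i].strip()
    if not sentence:
        continue
    punct = parts[i + 1] if i + 1 < len(parts) else ""
    merged.append(sentence + punct)

print(merged)  # ['这是一个句子。', '这是第二个句子!', '第三个问题?']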
@@ -154,7 +140,6 @@ async def smart_split(
 
     # Normalize text
     if settings.advanced_text_normalization and normalization_options.normalize:
-        print(lang_code)
         if lang_code in ["a", "b", "en-us", "en-gb"]:
             text = CUSTOM_PHONEMES.sub(
                 lambda s: handle_custom_phonemes(s, custom_phoneme_list), text
@@ -103,3 +103,65 @@ async def test_smart_split_with_punctuation():
 
     # Verify punctuation is preserved
     assert all(any(p in chunk for p in "!?;:.") for chunk in chunks)
+
+
+def test_process_text_chunk_chinese_phonemes():
+    """Test processing with Chinese pinyin phonemes."""
+    pinyin = "nǐ hǎo lì" # Example pinyin sequence with tones
+    tokens = process_text_chunk(pinyin, skip_phonemize=True, language="z")
+    assert isinstance(tokens, list)
+    assert len(tokens) > 0
+
+
+def test_get_sentence_info_chinese():
+    """Test Chinese sentence splitting and info extraction."""
+    text = "这是一个句子。这是第二个句子!第三个问题?"
+    results = get_sentence_info(text, {}, lang_code="z")
+
+    assert len(results) == 3
+    for sentence, tokens, count in results:
+        assert isinstance(sentence, str)
+        assert isinstance(tokens, list)
+        assert isinstance(count, int)
+        assert count == len(tokens)
+        assert count > 0
+
+
+@pytest.mark.asyncio
+async def test_smart_split_chinese_short():
+    """Test Chinese smart splitting with short text."""
+    text = "这是一句话。"
+    chunks = []
+    async for chunk_text, chunk_tokens in smart_split(text, lang_code="z"):
+        chunks.append((chunk_text, chunk_tokens))
+
+    assert len(chunks) == 1
+    assert isinstance(chunks[0][0], str)
+    assert isinstance(chunks[0][1], list)
+
+
+@pytest.mark.asyncio
+async def test_smart_split_chinese_long():
+    """Test Chinese smart splitting with longer text."""
+    text = "。".join([f"测试句子 {i}" for i in range(20)])
+
+    chunks = []
+    async for chunk_text, chunk_tokens in smart_split(text, lang_code="z"):
+        chunks.append((chunk_text, chunk_tokens))
+
+    assert len(chunks) > 1
+    for chunk_text, chunk_tokens in chunks:
+        assert isinstance(chunk_text, str)
+        assert isinstance(chunk_tokens, list)
+        assert len(chunk_tokens) > 0
+
+
+@pytest.mark.asyncio
+async def test_smart_split_chinese_punctuation():
+    """Test Chinese smart splitting with punctuation preservation."""
+    text = "第一句!第二问?第三句;第四句:第五句。"
+
+    chunks = []
+    async for chunk_text, _ in smart_split(text, lang_code="z"):
+        chunks.append(chunk_text)
+
+    # Verify Chinese punctuation is preserved
+    assert all(any(p in chunk for p in "!?;:。") for chunk in chunks)
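The new tests can be selected by name with a keyword filter; the async ones rely on pytest-asyncio, which the existing @pytest.mark.asyncio markers already assume. A minimal sketch, equivalent to running pytest -k chinese -v from the command line (the exact test-file path is not shown in this diff, so none is given):

# Programmatic pytest invocation selecting tests whose names contain "chinese".
import sys
import pytest

sys.exit(pytest.main(["-k", "chinese", "-v"]))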
@@ -3,9 +3,7 @@ import json
 
 import requests
 
-text = """the administration has offered up a platter of repression for more than a year and is still slated to lose $400 million.
-
-Columbia is the largest private landowner in New York City and boasts an endowment of $14.8 billion;"""
+text = """奶酪芝士很浓郁!臭豆腐芝士有争议?陈年奶酪价格昂贵。"""
 
 
 Type = "wav"
@@ -15,7 +13,7 @@ response = requests.post(
     json={
         "model": "kokoro",
         "input": text,
-        "voice": "af_heart+af_sky",
+        "voice": "zf_xiaobei",
         "speed": 1.0,
         "response_format": Type,
         "stream": False,
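Taken together, the last two hunks turn the example script into a Mandarin request using the zf_xiaobei voice. A runnable sketch of what the full script would look like after this change; the server URL and the output-file handling are assumptions, since only the changed payload fields appear in the diff:

import requests

text = """奶酪芝士很浓郁!臭豆腐芝士有争议?陈年奶酪价格昂贵。"""

Type = "wav"

response = requests.post(
    "http://localhost:8880/v1/audio/speech",  # assumed default Kokoro-FastAPI endpoint
    json={
        "model": "kokoro",
        "input": text,
        "voice": "zf_xiaobei",
        "speed": 1.0,
        "response_format": Type,
        "stream": False,
    },
)

with open(f"output.{Type}", "wb") as f:  # assumed output handling, not shown in the diff
    f.write(response.content)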