mirror of https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00

commit d7d90cdc9d (parent dbb66ff1e1)

    simplified some normalization and added more tests

7 changed files with 77 additions and 29 deletions
@@ -47,12 +47,12 @@ class StreamingAudioWriter:
             )
             self.stream = self.container.add_stream(
                 codec_map[self.format],
-                rate=self.sample_rate, # Correct parameter name is 'rate'
+                rate=self.sample_rate,
                 layout="mono" if self.channels == 1 else "stereo",
             )
             # Set bit_rate only for codecs where it's applicable and useful
             if self.format in ['mp3', 'aac', 'opus']:
-                self.stream.bit_rate = 128000 # Example bitrate, can be configured
+                self.stream.bit_rate = 128000
             else:
                 raise ValueError(f"Unsupported format: {self.format}") # Use self.format here

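For reference, the setup this hunk trims is PyAV's audio stream creation. A minimal sketch of the same pattern, hedged: the output path and the 24000 Hz rate below are illustrative placeholders, not values taken from this page.

    import av

    # One mono audio stream in an MP3 container, with the fixed 128 kbps
    # bitrate the commit keeps (its explanatory comment noted it is merely
    # an example value).
    container = av.open("out.mp3", mode="w")
    stream = container.add_stream("mp3", rate=24000, layout="mono")
    stream.bit_rate = 128000
    container.close()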
@@ -134,6 +134,23 @@ VALID_UNITS = {
     "px": "pixel", # CSS units
 }

+SYMBOL_REPLACEMENTS = {
+    '~': ' ',
+    '@': ' at ',
+    '#': ' number ',
+    '$': ' dollar ',
+    '%': ' percent ',
+    '^': ' ',
+    '&': ' and ',
+    '*': ' ',
+    '_': ' ',
+    '|': ' ',
+    '\\': ' ',
+    '/': ' slash ',
+    '=': ' equals ',
+    '+': ' plus ',
+}
+
 MONEY_UNITS = {"$": ("dollar", "cent"), "£": ("pound", "pence"), "€": ("euro", "cent")}

 # Pre-compiled regex patterns for performance
@@ -464,20 +481,9 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
     text = re.sub(r"\d*\.\d+", handle_decimal, text)

     # Handle other problematic symbols AFTER money/number processing
-    text = text.replace('~', '') # Remove tilde
-    text = text.replace('@', ' at ') # At symbol
-    text = text.replace('#', ' number ') # Hash/pound
-    text = text.replace('$', ' dollar ') # Dollar sign (if not handled by money pattern)
-    text = text.replace('%', ' percent ') # Percent sign
-    text = text.replace('^', '') # Caret
-    text = text.replace('&', ' and ') # Ampersand
-    text = text.replace('*', '') # Asterisk
-    text = text.replace('_', ' ') # Underscore to space
-    text = text.replace('|', ' ') # Pipe to space
-    text = text.replace('\\', ' ') # Backslash to space
-    text = text.replace('/', ' slash ') # Forward slash to space (unless in URLs)
-    text = text.replace('=', ' equals ') # Equals sign
-    text = text.replace('+', ' plus ') # Plus sign
+    if normalization_options.replace_remaining_symbols:
+        for symbol, replacement in SYMBOL_REPLACEMENTS.items():
+            text = text.replace(symbol, replacement)

     # Handle various formatting
     text = re.sub(r"(?<=\d)-(?=\d)", " to ", text)
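The change swaps fourteen hand-written replace calls for one data-driven loop over the SYMBOL_REPLACEMENTS table added above. A standalone sketch of the new pass (the helper name is ours, not the repo's):

    SYMBOL_REPLACEMENTS = {
        '~': ' ', '@': ' at ', '#': ' number ', '$': ' dollar ',
        '%': ' percent ', '^': ' ', '&': ' and ', '*': ' ',
        '_': ' ', '|': ' ', '\\': ' ', '/': ' slash ',
        '=': ' equals ', '+': ' plus ',
    }

    def replace_remaining_symbols(text: str) -> str:
        # Plain character-for-word substitution; it runs after money and
        # number handling, so '$' here only catches leftover dollar signs.
        for symbol, replacement in SYMBOL_REPLACEMENTS.items():
            text = text.replace(symbol, replacement)
        return text

    print(replace_remaining_symbols("5 @ $2"))  # '5  at   dollar 2'

Note the behavior change: symbols that were previously deleted outright ('~', '^', '*') now become spaces, and every replacement is space-padded, which can leave runs of spaces.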
@@ -489,4 +495,6 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
     )
     text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)

+    text = re.sub(r"\s{2,}", " ", text)
+
     return text.strip()
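This new collapse pass mops up the double spaces the symbol table introduces. A quick check of the interplay:

    import re

    # Symbol replacement leaves double spaces ("a @ b" -> "a  at  b");
    # the new collapse pass squeezes them back down.
    print(re.sub(r"\s{2,}", " ", "a  at  b").strip())  # 'a at b'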
@@ -84,13 +84,12 @@ def create_phonemizer(language: str = "a") -> PhonemizerBackend:
     return EspeakBackend(lang_map[language])


-def phonemize(text: str, language: str = "a", normalize: bool = True) -> str:
+def phonemize(text: str, language: str = "a") -> str:
     """Convert text to phonemes

     Args:
         text: Text to convert to phonemes
         language: Language code ('a' for US English, 'b' for British English)
-        normalize: Whether to normalize text before phonemization

     Returns:
         Phonemized text
@@ -100,13 +99,6 @@ def phonemize(text: str, language: str = "a", normalize: bool = True) -> str:
     # Strip input text first to remove problematic leading/trailing spaces
     text = text.strip()

-    if normalize:
-        # Create default normalization options and normalize text
-        normalization_options = NormalizationOptions()
-        text = normalize_text(text, normalization_options)
-        # Strip again after normalization
-        text = text.strip()
-
     if language not in phonemizers:
         phonemizers[language] = create_phonemizer(language)

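Taken together with the signature change above, normalization now happens exactly once, upstream of phonemization, instead of optionally again inside phonemize(). A stub-based sketch of the new calling convention (both functions below are hypothetical stand-ins for the repo's normalize_text and espeak-backed phonemize):

    # Hypothetical stand-ins; the real functions live in the repo and do far
    # more. The point is the order of operations after this commit.
    def normalize_text(text: str) -> str:
        return text.replace("&", " and ").strip()  # stub for the full normalizer

    def phonemize(text: str, language: str = "a") -> str:
        return text.strip()  # stub; the real code calls espeak here

    text = normalize_text("salt & pepper")
    phonemes = phonemize(text)  # no normalize= flag anymore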
@@ -52,7 +52,7 @@ def process_text_chunk(
     t1 = time.time()

     t0 = time.time()
-    phonemes = phonemize(text, language, normalize=False) # Already normalized
+    phonemes = phonemize(text, language)
     # Strip phonemes result to ensure no extra spaces
     phonemes = phonemes.strip()
     t1 = time.time()
@@ -102,11 +102,11 @@ def process_text(text: str, language: str = "a") -> List[int]:
 def get_sentence_info(
     text: str, custom_phenomes_list: Dict[str, str], lang_code: str = "a"
 ) -> List[Tuple[str, List[int], int]]:
-    """Process all sentences and return info, 支持中文分句"""
-    # 判断是否为中文
+    """Process all sentences and return info"""
+    # Detect Chinese text
     is_chinese = lang_code.startswith("z") or re.search(r"[\u4e00-\u9fff]", text)
     if is_chinese:
-        # 按中文标点断句
+        # Split using Chinese punctuation
         sentences = re.split(r"([,。!?;])+", text)
     else:
         sentences = re.split(r"([.!?;:])(?=\s|$)", text)
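The language fork in get_sentence_info is self-contained enough to demo on its own. A sketch using the two split patterns from the hunk (the wrapper function is ours): CJK input is detected via the Unicode range and split on Chinese punctuation; everything else splits on Western sentence enders.

    import re

    def split_sentences(text: str, lang_code: str = "a"):
        is_chinese = lang_code.startswith("z") or re.search(r"[\u4e00-\u9fff]", text)
        if is_chinese:
            return re.split(r"([,。!?;])+", text)
        return re.split(r"([.!?;:])(?=\s|$)", text)

    print(split_sentences("你好,世界。再见!"))
    # ['你好', ',', '世界', '。', '再见', '!', '']
    print(split_sentences("Hi there. Bye!"))
    # ['Hi there', '.', ' Bye', '!', '']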
@@ -1,3 +1,4 @@
+from email.policy import default
 from enum import Enum
 from typing import List, Literal, Optional, Union

@@ -66,6 +67,10 @@ class NormalizationOptions(BaseModel):
         default=True,
         description="Changes phone numbers so they can be properly pronouced by kokoro",
     )
+    replace_remaining_symbols: bool = Field(
+        default=True,
+        description="Replaces the remaining symbols after normalization with their words"
+    )


 class OpenAISpeechRequest(BaseModel):
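A minimal sketch of the new toggle in isolation, assuming Pydantic's BaseModel/Field as the surrounding schema already uses:

    from pydantic import BaseModel, Field

    class NormalizationOptions(BaseModel):
        # Only the new field is reproduced here; the real model has many more.
        replace_remaining_symbols: bool = Field(
            default=True,
            description="Replaces the remaining symbols after normalization with their words",
        )

    opts = NormalizationOptions()
    print(opts.replace_remaining_symbols)  # True by default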
@@ -175,6 +175,13 @@ def test_money():
         == "The plant cost two hundred thousand dollars and eighty cents."
     )

+    assert (
+        normalize_text(
+            "Your shopping spree cost $674.03!", normalization_options=NormalizationOptions()
+        )
+        == "Your shopping spree cost six hundred and seventy-four dollars and three cents!"
+    )
+
     assert (
         normalize_text(
             "€30.2 is in euros", normalization_options=NormalizationOptions()
@@ -315,3 +322,12 @@ def test_non_url_text():
         normalize_text("It costs $50.", normalization_options=NormalizationOptions())
         == "It costs fifty dollars."
     )
+
+def test_remaining_symbol():
+    """Test that remaining symbols are replaced"""
+    assert (
+        normalize_text(
+            "I love buying products @ good store here & @ other store", normalization_options=NormalizationOptions()
+        )
+        == "I love buying products at good store here and at other store"
+    )
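The new test's expected string falls out of the two added passes combined, symbol replacement then whitespace collapse. Reproduced with just those two steps (a sketch, not the repo's full normalize_text):

    import re

    text = "I love buying products @ good store here & @ other store"
    for symbol, replacement in {'@': ' at ', '&': ' and '}.items():
        text = text.replace(symbol, replacement)
    text = re.sub(r"\s{2,}", " ", text).strip()
    print(text)  # I love buying products at good store here and at other store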
@@ -194,4 +194,31 @@ async def test_smart_split_with_pause():
     # Third chunk: text
     assert chunks[2][2] is None # No pause
     assert "How are you?" in chunks[2][0]
+    assert len(chunks[2][1]) > 0
+
+@pytest.mark.asyncio
+async def test_smart_split_with_two_pause():
+    """Test smart splitting with two pause tags."""
+    text = "[pause:0.5s][pause:1.67s]0.5"
+
+    chunks = []
+    async for chunk_text, chunk_tokens, pause_duration in smart_split(text):
+        chunks.append((chunk_text, chunk_tokens, pause_duration))
+
+    # Should have 3 chunks: pause, pause, text
+    assert len(chunks) == 3
+
+    # First chunk: pause
+    assert chunks[0][2] == 0.5 # 0.5 second pause
+    assert chunks[0][0] == "" # Empty text
+    assert len(chunks[0][1]) == 0
+
+    # Second chunk: pause
+    assert chunks[1][2] == 1.67 # 1.67 second pause
+    assert chunks[1][0] == "" # Empty text
+    assert len(chunks[1][1]) == 0 # No tokens
+
+    # Third chunk: text
+    assert chunks[2][2] is None # No pause
+    assert "zero point five" in chunks[2][0]
     assert len(chunks[2][1]) > 0
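The pause-tag convention these tests exercise can be sketched standalone. A hypothetical parser showing only the chunk ordering; the repo's smart_split also tokenizes and normalizes the text chunks (hence "zero point five" in the assertion above):

    import re

    PAUSE_TAG = re.compile(r"\[pause:(\d+(?:\.\d+)?)s\]")

    def split_pauses(text):
        # Yields (text, pause_seconds) in document order: pause tags become
        # empty-text chunks carrying a duration, plain text carries None.
        pos = 0
        for m in PAUSE_TAG.finditer(text):
            if m.start() > pos:
                yield (text[pos:m.start()], None)
            yield ("", float(m.group(1)))
            pos = m.end()
        if pos < len(text):
            yield (text[pos:], None)

    print(list(split_pauses("[pause:0.5s][pause:1.67s]0.5")))
    # [('', 0.5), ('', 1.67), ('0.5', None)]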