simplified some normalization and added more tests

Fireblade2534 2025-06-12 16:00:06 +00:00
parent dbb66ff1e1
commit d7d90cdc9d
7 changed files with 77 additions and 29 deletions

View file

@@ -47,12 +47,12 @@ class StreamingAudioWriter:
)
self.stream = self.container.add_stream(
codec_map[self.format],
rate=self.sample_rate, # Correct parameter name is 'rate'
rate=self.sample_rate,
layout="mono" if self.channels == 1 else "stereo",
)
# Set bit_rate only for codecs where it's applicable and useful
if self.format in ['mp3', 'aac', 'opus']:
self.stream.bit_rate = 128000 # Example bitrate, can be configured
self.stream.bit_rate = 128000
else:
raise ValueError(f"Unsupported format: {self.format}") # Use self.format here
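Note: the trimmed stream setup corresponds to roughly this standalone PyAV sketch (format, sample rate, and channel count are assumptions here; in the class they come from the constructor):

import io

import av

# Standalone sketch of the stream setup above (mp3 / 24 kHz / mono assumed)
buffer = io.BytesIO()
container = av.open(buffer, mode="w", format="mp3")
stream = container.add_stream("mp3", rate=24000, layout="mono")
stream.bit_rate = 128000  # same default bitrate as in the diff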

View file

@@ -134,6 +134,23 @@ VALID_UNITS = {
"px": "pixel", # CSS units
}
SYMBOL_REPLACEMENTS = {
'~': ' ',
'@': ' at ',
'#': ' number ',
'$': ' dollar ',
'%': ' percent ',
'^': ' ',
'&': ' and ',
'*': ' ',
'_': ' ',
'|': ' ',
'\\': ' ',
'/': ' slash ',
'=': ' equals ',
'+': ' plus ',
}
MONEY_UNITS = {"$": ("dollar", "cent"), "£": ("pound", "pence"), "€": ("euro", "cent")}
# Pre-compiled regex patterns for performance
@@ -464,20 +481,9 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
text = re.sub(r"\d*\.\d+", handle_decimal, text)
# Handle other problematic symbols AFTER money/number processing
text = text.replace('~', '') # Remove tilde
text = text.replace('@', ' at ') # At symbol
text = text.replace('#', ' number ') # Hash/pound
text = text.replace('$', ' dollar ') # Dollar sign (if not handled by money pattern)
text = text.replace('%', ' percent ') # Percent sign
text = text.replace('^', '') # Caret
text = text.replace('&', ' and ') # Ampersand
text = text.replace('*', '') # Asterisk
text = text.replace('_', ' ') # Underscore to space
text = text.replace('|', ' ') # Pipe to space
text = text.replace('\\', ' ') # Backslash to space
text = text.replace('/', ' slash ') # Forward slash to space (unless in URLs)
text = text.replace('=', ' equals ') # Equals sign
text = text.replace('+', ' plus ') # Plus sign
if normalization_options.replace_remaining_symbols:
for symbol, replacement in SYMBOL_REPLACEMENTS.items():
text = text.replace(symbol, replacement)
# Handle various formatting
text = re.sub(r"(?<=\d)-(?=\d)", " to ", text)
@@ -489,4 +495,6 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
)
text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)
text = re.sub(r"\s{2,}", " ", text)
return text.strip()
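Note: the net effect of the two hunks above is that the hard-coded replace chain becomes a data-driven pass over SYMBOL_REPLACEMENTS that callers can disable. A rough usage sketch, with the expected output taken from the new test further down (imports assumed):

# Sketch; assumes normalize_text and NormalizationOptions are imported
opts_on = NormalizationOptions()  # replace_remaining_symbols defaults to True
print(normalize_text("good store here & @ other store", opts_on))
# -> "good store here and at other store" (per test_remaining_symbol below)

opts_off = NormalizationOptions(replace_remaining_symbols=False)
print(normalize_text("a | b", opts_off))  # '|' is left alone when the pass is off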

View file

@@ -84,13 +84,12 @@ def create_phonemizer(language: str = "a") -> PhonemizerBackend:
return EspeakBackend(lang_map[language])
def phonemize(text: str, language: str = "a", normalize: bool = True) -> str:
def phonemize(text: str, language: str = "a") -> str:
"""Convert text to phonemes
Args:
text: Text to convert to phonemes
language: Language code ('a' for US English, 'b' for British English)
normalize: Whether to normalize text before phonemization
Returns:
Phonemized text
@@ -100,13 +99,6 @@ def phonemize(text: str, language: str = "a", normalize: bool = True) -> str:
# Strip input text first to remove problematic leading/trailing spaces
text = text.strip()
if normalize:
# Create default normalization options and normalize text
normalization_options = NormalizationOptions()
text = normalize_text(text, normalization_options)
# Strip again after normalization
text = text.strip()
if language not in phonemizers:
phonemizers[language] = create_phonemizer(language)
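Note: with the normalize flag gone, normalization is now the caller's responsibility. The calling convention looks roughly like this (imports assumed):

# Sketch; phonemize no longer normalizes, so callers do it first if needed
opts = NormalizationOptions()
clean = normalize_text("It costs $50.", opts).strip()
phonemes = phonemize(clean, language="a")  # 'a' = US English per the docstring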

View file

@@ -52,7 +52,7 @@ def process_text_chunk(
t1 = time.time()
t0 = time.time()
phonemes = phonemize(text, language, normalize=False) # Already normalized
phonemes = phonemize(text, language)
# Strip phonemes result to ensure no extra spaces
phonemes = phonemes.strip()
t1 = time.time()
@@ -102,11 +102,11 @@ def process_text(text: str, language: str = "a") -> List[int]:
def get_sentence_info(
text: str, custom_phenomes_list: Dict[str, str], lang_code: str = "a"
) -> List[Tuple[str, List[int], int]]:
"""Process all sentences and return info, 支持中文分句"""
# 判断是否为中文
"""Process all sentences and return info"""
# Detect Chinese text
is_chinese = lang_code.startswith("z") or re.search(r"[\u4e00-\u9fff]", text)
if is_chinese:
# 按中文标点断句
# Split using Chinese punctuation
sentences = re.split(r"([,。!?;])+", text)
else:
sentences = re.split(r"([.!?;:])(?=\s|$)", text)
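Note: because the punctuation sits in a capturing group, re.split returns sentence text and terminators interleaved, which is what the downstream pairing logic relies on. For example:

import re

parts = re.split(r"([,。!?;])+", "你好,世界!今天怎么样?")
# -> ['你好', ',', '世界', '!', '今天怎么样', '?', '']
# alternating (sentence, punctuation), plus a trailing empty string to discard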

View file

@@ -1,3 +1,4 @@
from email.policy import default
from enum import Enum
from typing import List, Literal, Optional, Union
@@ -66,6 +67,10 @@ class NormalizationOptions(BaseModel):
default=True,
description="Changes phone numbers so they can be properly pronouced by kokoro",
)
replace_remaining_symbols: bool = Field(
default=True,
description="Replaces the remaining symbols after normalization with their words"
)
class OpenAISpeechRequest(BaseModel):
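Note: the field defaults to True, so the new symbol pass is on unless a request explicitly opts out. A minimal check of the model behavior (plain pydantic semantics, nothing repo-specific assumed):

# Default is on; opting out is explicit
assert NormalizationOptions().replace_remaining_symbols is True
assert NormalizationOptions(replace_remaining_symbols=False).replace_remaining_symbols is False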

View file

@@ -175,6 +175,13 @@ def test_money():
== "The plant cost two hundred thousand dollars and eighty cents."
)
assert (
normalize_text(
"Your shopping spree cost $674.03!", normalization_options=NormalizationOptions()
)
== "Your shopping spree cost six hundred and seventy-four dollars and three cents!"
)
assert (
normalize_text(
"€30.2 is in euros", normalization_options=NormalizationOptions()
@@ -315,3 +322,12 @@ def test_non_url_text():
normalize_text("It costs $50.", normalization_options=NormalizationOptions())
== "It costs fifty dollars."
)
def test_remaining_symbol():
"""Test that remaining symbols are replaced"""
assert (
normalize_text(
"I love buying products @ good store here & @ other store", normalization_options=NormalizationOptions()
)
== "I love buying products at good store here and at other store"
)
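Note: further spot checks implied by SYMBOL_REPLACEMENTS would follow the same shape (hypothetical, not part of this commit; the exact output depends on the surrounding normalization steps):

# Hypothetical additional assertion in the style of test_remaining_symbol
assert (
    normalize_text("left/right", normalization_options=NormalizationOptions())
    == "left slash right"
)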

View file

@@ -194,4 +194,31 @@ async def test_smart_split_with_pause():
# Third chunk: text
assert chunks[2][2] is None # No pause
assert "How are you?" in chunks[2][0]
assert len(chunks[2][1]) > 0
@pytest.mark.asyncio
async def test_smart_split_with_two_pause():
"""Test smart splitting with two pause tags."""
text = "[pause:0.5s][pause:1.67s]0.5"
chunks = []
async for chunk_text, chunk_tokens, pause_duration in smart_split(text):
chunks.append((chunk_text, chunk_tokens, pause_duration))
# Should have 3 chunks: pause, pause, text
assert len(chunks) == 3
# First chunk: pause
assert chunks[0][2] == 0.5 # 0.5 second pause
assert chunks[0][0] == "" # Empty text
assert len(chunks[0][1]) == 0
# Second chunk: pause
assert chunks[1][2] == 1.67 # 1.67 second pause
assert chunks[1][0] == "" # Empty text
assert len(chunks[1][1]) == 0 # No tokens
# Third chunk: text
assert chunks[2][2] is None # No pause
assert "zero point five" in chunks[2][0]
assert len(chunks[2][1]) > 0
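Note: these tests pin down smart_split's contract: an async generator yielding (text, tokens, pause_seconds) triples, where pause chunks carry empty text and no tokens, and text chunks have pause_seconds set to None. A consumer sketch (import path for smart_split assumed):

import asyncio

async def consume(text: str):
    # Sketch of a consumer honoring the contract exercised above
    async for chunk_text, chunk_tokens, pause_duration in smart_split(text):
        if pause_duration is not None:
            await asyncio.sleep(pause_duration)  # or emit silence downstream
        else:
            ...  # synthesize from chunk_tokens / chunk_text here

# asyncio.run(consume("[pause:0.5s]Hello there. How are you?"))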