mirror of https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00

simplified some normalization and added more tests

This commit is contained in:
parent dbb66ff1e1
commit d7d90cdc9d

7 changed files with 77 additions and 29 deletions
@@ -47,12 +47,12 @@ class StreamingAudioWriter:
         )
         self.stream = self.container.add_stream(
             codec_map[self.format],
-            rate=self.sample_rate,  # Correct parameter name is 'rate'
+            rate=self.sample_rate,
             layout="mono" if self.channels == 1 else "stereo",
         )
         # Set bit_rate only for codecs where it's applicable and useful
         if self.format in ['mp3', 'aac', 'opus']:
-            self.stream.bit_rate = 128000  # Example bitrate, can be configured
+            self.stream.bit_rate = 128000
         else:
             raise ValueError(f"Unsupported format: {self.format}")  # Use self.format here
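For orientation, a minimal sketch of the PyAV pattern this constructor uses, assuming a recent PyAV release; the in-memory buffer, codec, and sample rate below are illustrative stand-ins, not values taken from the repo:

    import io

    import av  # PyAV

    buffer = io.BytesIO()
    container = av.open(buffer, mode="w", format="mp3")
    # Mirrors the add_stream call above: codec name, sample rate, channel layout
    stream = container.add_stream("mp3", rate=24000, layout="mono")
    stream.bit_rate = 128000  # the fixed bitrate the diff keeps for mp3/aac/opus
    container.close()  # real code would encode audio frames before closing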
@@ -134,6 +134,23 @@ VALID_UNITS = {
     "px": "pixel",  # CSS units
 }
 
+SYMBOL_REPLACEMENTS = {
+    '~': ' ',
+    '@': ' at ',
+    '#': ' number ',
+    '$': ' dollar ',
+    '%': ' percent ',
+    '^': ' ',
+    '&': ' and ',
+    '*': ' ',
+    '_': ' ',
+    '|': ' ',
+    '\\': ' ',
+    '/': ' slash ',
+    '=': ' equals ',
+    '+': ' plus ',
+}
+
 MONEY_UNITS = {"$": ("dollar", "cent"), "£": ("pound", "pence"), "€": ("euro", "cent")}
 
 # Pre-compiled regex patterns for performance
@@ -464,20 +481,9 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
     text = re.sub(r"\d*\.\d+", handle_decimal, text)
 
-    # Handle other problematic symbols AFTER money/number processing
-    text = text.replace('~', '')  # Remove tilde
-    text = text.replace('@', ' at ')  # At symbol
-    text = text.replace('#', ' number ')  # Hash/pound
-    text = text.replace('$', ' dollar ')  # Dollar sign (if not handled by money pattern)
-    text = text.replace('%', ' percent ')  # Percent sign
-    text = text.replace('^', '')  # Caret
-    text = text.replace('&', ' and ')  # Ampersand
-    text = text.replace('*', '')  # Asterisk
-    text = text.replace('_', ' ')  # Underscore to space
-    text = text.replace('|', ' ')  # Pipe to space
-    text = text.replace('\\', ' ')  # Backslash to space
-    text = text.replace('/', ' slash ')  # Forward slash to space (unless in URLs)
-    text = text.replace('=', ' equals ')  # Equals sign
-    text = text.replace('+', ' plus ')  # Plus sign
+    if normalization_options.replace_remaining_symbols:
+        for symbol, replacement in SYMBOL_REPLACEMENTS.items():
+            text = text.replace(symbol, replacement)
 
     # Handle various formatting
     text = re.sub(r"(?<=\d)-(?=\d)", " to ", text)
@@ -489,4 +495,6 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
     )
     text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)
 
+    text = re.sub(r"\s{2,}", " ", text)
+
     return text.strip()
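Net effect of the two normalize_text hunks: fourteen inline text.replace calls collapse into one loop over the module-level table, now gated behind the new replace_remaining_symbols option. A self-contained sketch of the resulting behavior (the dict is copied from the hunk above; the final collapse mirrors the re.sub(r"\s{2,}", " ", text) step the commit adds):

    import re

    SYMBOL_REPLACEMENTS = {
        '~': ' ', '@': ' at ', '#': ' number ', '$': ' dollar ',
        '%': ' percent ', '^': ' ', '&': ' and ', '*': ' ',
        '_': ' ', '|': ' ', '\\': ' ', '/': ' slash ',
        '=': ' equals ', '+': ' plus ',
    }

    def replace_remaining_symbols(text: str) -> str:
        # Substitute each leftover symbol, then collapse the doubled spaces
        for symbol, replacement in SYMBOL_REPLACEMENTS.items():
            text = text.replace(symbol, replacement)
        return re.sub(r"\s{2,}", " ", text).strip()

    print(replace_remaining_symbols("products @ store & more"))
    # products at store and more

One behavioral nit worth noting: the old code deleted '~', '^', and '*' outright, while the new table maps them to a space, which the trailing whitespace collapse then absorbs.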
@@ -84,13 +84,12 @@ def create_phonemizer(language: str = "a") -> PhonemizerBackend:
     return EspeakBackend(lang_map[language])
 
 
-def phonemize(text: str, language: str = "a", normalize: bool = True) -> str:
+def phonemize(text: str, language: str = "a") -> str:
     """Convert text to phonemes
 
     Args:
         text: Text to convert to phonemes
         language: Language code ('a' for US English, 'b' for British English)
-        normalize: Whether to normalize text before phonemization
 
     Returns:
         Phonemized text
@@ -100,13 +99,6 @@ def phonemize(text: str, language: str = "a", normalize: bool = True) -> str:
     # Strip input text first to remove problematic leading/trailing spaces
     text = text.strip()
 
-    if normalize:
-        # Create default normalization options and normalize text
-        normalization_options = NormalizationOptions()
-        text = normalize_text(text, normalization_options)
-        # Strip again after normalization
-        text = text.strip()
-
     if language not in phonemizers:
         phonemizers[language] = create_phonemizer(language)
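With normalize gone from the signature, callers must run normalization themselves before phonemizing. A hedged sketch of the new call order; the import paths here are assumptions for illustration, not confirmed repo paths:

    from api.src.services.text_processing.normalizer import normalize_text  # assumed path
    from api.src.services.text_processing.phonemizer import phonemize  # assumed path
    from api.src.structures.schemas import NormalizationOptions  # assumed path

    # Normalize first, then phonemize the already-clean text
    text = normalize_text("It costs $50!", NormalizationOptions())
    phonemes = phonemize(text, language="a")  # no normalize= kwarg anymore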
@@ -52,7 +52,7 @@ def process_text_chunk(
         t1 = time.time()
 
         t0 = time.time()
-        phonemes = phonemize(text, language, normalize=False)  # Already normalized
+        phonemes = phonemize(text, language)
         # Strip phonemes result to ensure no extra spaces
         phonemes = phonemes.strip()
         t1 = time.time()
@@ -102,11 +102,11 @@ def process_text(text: str, language: str = "a") -> List[int]:
 def get_sentence_info(
     text: str, custom_phenomes_list: Dict[str, str], lang_code: str = "a"
 ) -> List[Tuple[str, List[int], int]]:
-    """Process all sentences and return info, 支持中文分句"""
-    # 判断是否为中文
+    """Process all sentences and return info"""
+    # Detect Chinese text
     is_chinese = lang_code.startswith("z") or re.search(r"[\u4e00-\u9fff]", text)
     if is_chinese:
-        # 按中文标点断句
+        # Split using Chinese punctuation
         sentences = re.split(r"([,。!?;])+", text)
     else:
         sentences = re.split(r"([.!?;:])(?=\s|$)", text)
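A quick standalone illustration of the two splitting branches; because each delimiter is a captured group, re.split keeps it in the result list so downstream code can re-attach punctuation:

    import re

    zh = "你好,世界!今天怎么样?"
    print(re.split(r"([,。!?;])+", zh))
    # ['你好', ',', '世界', '!', '今天怎么样', '?', '']

    en = "Hello. How are you?"
    print(re.split(r"([.!?;:])(?=\s|$)", en))
    # ['Hello', '.', ' How are you', '?', '']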
@@ -1,3 +1,4 @@
+from email.policy import default
 from enum import Enum
 from typing import List, Literal, Optional, Union
 
@@ -66,6 +67,10 @@ class NormalizationOptions(BaseModel):
         default=True,
         description="Changes phone numbers so they can be properly pronouced by kokoro",
     )
+    replace_remaining_symbols: bool = Field(
+        default=True,
+        description="Replaces the remaining symbols after normalization with their words"
+    )
 
 
 class OpenAISpeechRequest(BaseModel):
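For illustration only, a trimmed pydantic stand-in showing just the field this commit adds (the real NormalizationOptions carries many more options):

    from pydantic import BaseModel, Field

    class NormalizationOptions(BaseModel):
        # Only the new field; defaults to on, so existing clients keep the behavior
        replace_remaining_symbols: bool = Field(
            default=True,
            description="Replaces the remaining symbols after normalization with their words",
        )

    opts = NormalizationOptions(replace_remaining_symbols=False)  # per-request opt-out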
@@ -175,6 +175,13 @@ def test_money():
         == "The plant cost two hundred thousand dollars and eighty cents."
     )
 
+    assert (
+        normalize_text(
+            "Your shopping spree cost $674.03!", normalization_options=NormalizationOptions()
+        )
+        == "Your shopping spree cost six hundred and seventy-four dollars and three cents!"
+    )
+
     assert (
         normalize_text(
             "€30.2 is in euros", normalization_options=NormalizationOptions()
@@ -315,3 +322,12 @@ def test_non_url_text():
         normalize_text("It costs $50.", normalization_options=NormalizationOptions())
         == "It costs fifty dollars."
     )
+
+def test_remaining_symbol():
+    """Test that remaining symbols are replaced"""
+    assert (
+        normalize_text(
+            "I love buying products @ good store here & @ other store", normalization_options=NormalizationOptions()
+        )
+        == "I love buying products at good store here and at other store"
+    )
@@ -194,4 +194,31 @@ async def test_smart_split_with_pause():
     # Third chunk: text
     assert chunks[2][2] is None  # No pause
     assert "How are you?" in chunks[2][0]
     assert len(chunks[2][1]) > 0
+
+@pytest.mark.asyncio
+async def test_smart_split_with_two_pause():
+    """Test smart splitting with two pause tags."""
+    text = "[pause:0.5s][pause:1.67s]0.5"
+
+    chunks = []
+    async for chunk_text, chunk_tokens, pause_duration in smart_split(text):
+        chunks.append((chunk_text, chunk_tokens, pause_duration))
+
+    # Should have 3 chunks: pause, pause, text
+    assert len(chunks) == 3
+
+    # First chunk: pause
+    assert chunks[0][2] == 0.5  # 0.5 second pause
+    assert chunks[0][0] == ""  # Empty text
+    assert len(chunks[0][1]) == 0
+
+    # Second chunk: pause
+    assert chunks[1][2] == 1.67  # 1.67 second pause
+    assert chunks[1][0] == ""  # Empty text
+    assert len(chunks[1][1]) == 0  # No tokens
+
+    # Third chunk: text
+    assert chunks[2][2] is None  # No pause
+    assert "zero point five" in chunks[2][0]
+    assert len(chunks[2][1]) > 0
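The new test pins down the streaming contract the first pause test relied on: each [pause:Ns] tag yields its own chunk with empty text, no tokens, and the pause duration, while leftover text is normalized ("0.5" reads as "zero point five") before tokenization. A hedged consumer-side sketch, assuming smart_split is importable from the repo's text-processing module:

    import asyncio

    from api.src.services.text_processing import smart_split  # assumed import path

    async def demo() -> None:
        async for chunk_text, chunk_tokens, pause_duration in smart_split(
            "[pause:0.5s][pause:1.67s]0.5"
        ):
            if pause_duration is not None:
                print(f"insert {pause_duration}s of silence")
            else:
                print(f"speak {chunk_text!r} ({len(chunk_tokens)} tokens)")

    asyncio.run(demo())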