Improve text normalize to keep original timestamps

This commit is contained in:
fondoger 2025-03-30 21:31:17 +08:00
parent d0c13f6401
commit 88f19d7751
2 changed files with 37 additions and 21 deletions

View file

@ -276,7 +276,9 @@ class KokoroV1(BaseModelBackend):
]
):
continue
if not token.text or not token.text.strip():
# token.start_ts may be None
if not token.text or not token.text.strip() or token.start_ts is None or token.end_ts is None:
continue
start_time = float(token.start_ts) + current_offset

View file

@ -10,6 +10,7 @@ import inflect
from numpy import number
from torch import mul
from ...structures.schemas import NormalizationOptions
from misaki import en
from text_to_num import text2num
@ -90,10 +91,23 @@ URL_PATTERN = re.compile(
UNIT_PATTERN = re.compile(r"((?<!\w)([+-]?)(\d{1,3}(,\d{3})*|\d+)(\.\d+)?)\s*(" + "|".join(sorted(list(VALID_UNITS.keys()),reverse=True)) + r"""){1}(?=[^\w\d]{1}|\b)""",re.IGNORECASE)
TIME_PATTERN = re.compile(r"([0-9]{2} ?: ?[0-9]{2}( ?: ?[0-9]{2})?)( ?(pm|am)\b)?", re.IGNORECASE)
TIME_PATTERN = re.compile(r"([0-9]{1,2} ?: ?[0-9]{2}( ?: ?[0-9]{2})?)( ?(pm|am)\b)?", re.IGNORECASE)
INFLECT_ENGINE=inflect.engine()
g2p = en.G2P(trf=False, british=False, fallback=None)
def sound_like(text: str, sound_like: str) -> str:
"""
Convert a string into a sound-alike format
Kokoro supports embedding phonemes in the text, and the token timestamps is based on the original text.
- Original Input Text: '[Misaki](/misˈɑki/) is a G2P engine designed for [Kokoro](/kˈOkəɹO/) models.'
- Text For Timestamps: 'Misaki is a G2P engine designed for Kokoro models.'
"""
phonemes, _ = g2p(sound_like)
return f"[{text}](/{phonemes}/)"
def split_num(num: re.Match[str]) -> str:
"""Handle number splitting for various formats"""
num = num.group()
@ -116,7 +130,7 @@ def split_num(num: re.Match[str]) -> str:
return f"{left} hundred{s}"
elif right < 10:
return f"{left} oh {right}{s}"
return f"{left} {right}{s}"
return sound_like(num, f"{left} {right}{s}")
def handle_units(u: re.Match[str]) -> str:
"""Converts units to their full form"""
@ -134,7 +148,7 @@ def handle_units(u: re.Match[str]) -> str:
number=u.group(1).strip()
unit[0]=INFLECT_ENGINE.no(unit[0],number)
return " ".join(unit)
return sound_like(u.group(), " ".join(unit))
def conditional_int(number: float, threshold: float = 0.00001):
if abs(round(number) - number) < threshold:
@ -164,12 +178,12 @@ def handle_money(m: re.Match[str]) -> str:
text_number = f"{INFLECT_ENGINE.number_to_words(int(round(number)))} {INFLECT_ENGINE.plural(bill, count=number)} and {INFLECT_ENGINE.number_to_words(sub_number)} {INFLECT_ENGINE.plural(coin, count=sub_number)}"
return text_number
return sound_like(m.group(), text_number)
def handle_decimal(num: re.Match[str]) -> str:
"""Convert decimal numbers to spoken form"""
a, b = num.group().split(".")
return " point ".join([a, " ".join(b)])
return sound_like(num.group(), " point ".join([a, " ".join(b)]))
def handle_email(m: re.Match[str]) -> str:
@ -179,7 +193,7 @@ def handle_email(m: re.Match[str]) -> str:
if len(parts) == 2:
user, domain = parts
domain = domain.replace(".", " dot ")
return f"{user} at {domain}"
return sound_like(email, f"{user} at {domain}")
return email
@ -227,34 +241,34 @@ def handle_url(u: re.Match[str]) -> str:
url = url.replace("/", " slash ") # Handle any remaining slashes
# Clean up extra spaces
return re.sub(r"\s+", " ", url).strip()
return sound_like(u.group(), re.sub(r"\s+", " ", url).strip())
def handle_phone_number(p: re.Match[str]) -> str:
p=list(p.groups())
g=list(p.groups())
country_code=""
if p[0] is not None:
p[0]=p[0].replace("+","")
country_code += INFLECT_ENGINE.number_to_words(p[0])
if g[0] is not None:
g[0]=g[0].replace("+","")
country_code += INFLECT_ENGINE.number_to_words(g[0])
area_code=INFLECT_ENGINE.number_to_words(p[2].replace("(","").replace(")",""),group=1,comma="")
area_code=INFLECT_ENGINE.number_to_words(g[2].replace("(","").replace(")",""),group=1,comma="")
telephone_prefix=INFLECT_ENGINE.number_to_words(p[3],group=1,comma="")
telephone_prefix=INFLECT_ENGINE.number_to_words(g[3],group=1,comma="")
line_number=INFLECT_ENGINE.number_to_words(p[4],group=1,comma="")
line_number=INFLECT_ENGINE.number_to_words(g[4],group=1,comma="")
return ",".join([country_code,area_code,telephone_prefix,line_number])
return sound_like(p.group(), ",".join([country_code,area_code,telephone_prefix,line_number]))
def handle_time(t: re.Match[str]) -> str:
t=t.groups()
g = t.groups()
numbers = " ".join([INFLECT_ENGINE.number_to_words(X.strip()) for X in t[0].split(":")])
numbers = " ".join([INFLECT_ENGINE.number_to_words(X.strip()) for X in g[0].split(":")])
half=""
if t[2] is not None:
half=t[2].strip()
if g[2] is not None:
half=g[2].strip()
return numbers + half
return sound_like(t.group(), numbers + half)
def normalize_text(text: str,normalization_options: NormalizationOptions) -> str:
"""Normalize text for TTS processing"""