mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00
-add email handling, minor additional URL processing, tests
This commit is contained in:
parent
1625082724
commit
a0a85f5ef0
2 changed files with 94 additions and 31 deletions
|
@ -1,12 +1,29 @@
|
||||||
|
"""
|
||||||
|
Text normalization module for TTS processing.
|
||||||
|
Handles various text formats including URLs, emails, numbers, money, and special characters.
|
||||||
|
Converts them into a format suitable for text-to-speech processing.
|
||||||
|
"""
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
|
||||||
valid_tlds=["com", "org", "net", "edu", "gov", "mil", "int", "biz", "info", "name",
|
# Constants
|
||||||
|
VALID_TLDS = [
|
||||||
|
"com", "org", "net", "edu", "gov", "mil", "int", "biz", "info", "name",
|
||||||
"pro", "coop", "museum", "travel", "jobs", "mobi", "tel", "asia", "cat",
|
"pro", "coop", "museum", "travel", "jobs", "mobi", "tel", "asia", "cat",
|
||||||
"xxx", "aero", "arpa", "bg", "br", "ca", "cn", "de", "es", "eu", "fr",
|
"xxx", "aero", "arpa", "bg", "br", "ca", "cn", "de", "es", "eu", "fr",
|
||||||
"in", "it", "jp", "mx", "nl", "ru", "uk", "us", "io"]
|
"in", "it", "jp", "mx", "nl", "ru", "uk", "us", "io"
|
||||||
|
]
|
||||||
|
|
||||||
def split_num(num: re.Match) -> str:
|
# Pre-compiled regex patterns for performance
|
||||||
|
EMAIL_PATTERN = re.compile(r"\b[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-z]{2,}\b", re.IGNORECASE)
|
||||||
|
URL_PATTERN = re.compile(
|
||||||
|
r"(https?://|www\.|)+(localhost|[a-zA-Z0-9.-]+(\.(?:" +
|
||||||
|
"|".join(VALID_TLDS) + "))+|[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3})(:[0-9]+)?([/?][^\s]*)?",
|
||||||
|
re.IGNORECASE
|
||||||
|
)
|
||||||
|
|
||||||
|
def split_num(num: re.Match[str]) -> str:
|
||||||
"""Handle number splitting for various formats"""
|
"""Handle number splitting for various formats"""
|
||||||
num = num.group()
|
num = num.group()
|
||||||
if "." in num:
|
if "." in num:
|
||||||
|
@ -30,7 +47,7 @@ def split_num(num: re.Match) -> str:
|
||||||
return f"{left} oh {right}{s}"
|
return f"{left} oh {right}{s}"
|
||||||
return f"{left} {right}{s}"
|
return f"{left} {right}{s}"
|
||||||
|
|
||||||
def handle_money(m: re.Match) -> str:
|
def handle_money(m: re.Match[str]) -> str:
|
||||||
"""Convert money expressions to spoken form"""
|
"""Convert money expressions to spoken form"""
|
||||||
m = m.group()
|
m = m.group()
|
||||||
bill = "dollar" if m[0] == "$" else "pound"
|
bill = "dollar" if m[0] == "$" else "pound"
|
||||||
|
@ -49,32 +66,56 @@ def handle_money(m: re.Match) -> str:
|
||||||
)
|
)
|
||||||
return f"{b} {bill}{s} and {c} {coins}"
|
return f"{b} {bill}{s} and {c} {coins}"
|
||||||
|
|
||||||
def handle_decimal(num: re.Match) -> str:
|
def handle_decimal(num: re.Match[str]) -> str:
|
||||||
"""Convert decimal numbers to spoken form"""
|
"""Convert decimal numbers to spoken form"""
|
||||||
a, b = num.group().split(".")
|
a, b = num.group().split(".")
|
||||||
return " point ".join([a, " ".join(b)])
|
return " point ".join([a, " ".join(b)])
|
||||||
|
|
||||||
def handle_url(u: re.Match) -> str:
|
def handle_email(m: re.Match[str]) -> str:
|
||||||
|
"""Convert email addresses into speakable format"""
|
||||||
|
email = m.group(0)
|
||||||
|
parts = email.split('@')
|
||||||
|
if len(parts) == 2:
|
||||||
|
user, domain = parts
|
||||||
|
domain = domain.replace('.', ' dot ')
|
||||||
|
return f"{user} at {domain}"
|
||||||
|
return email
|
||||||
|
|
||||||
|
def handle_url(u: re.Match[str]) -> str:
|
||||||
"""Make URLs speakable by converting special characters to spoken words"""
|
"""Make URLs speakable by converting special characters to spoken words"""
|
||||||
if not u:
|
if not u:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
url = u.group(0).strip()
|
url = u.group(0).strip()
|
||||||
# Handle common URL prefixes
|
|
||||||
url = re.sub(r'^https?://', lambda a : 'https ' if 'https' in a.group() else 'http', url, flags=re.IGNORECASE)
|
# Handle protocol first
|
||||||
|
url = re.sub(r'^https?://', lambda a: 'https ' if 'https' in a.group() else 'http ', url, flags=re.IGNORECASE)
|
||||||
url = re.sub(r'^www\.', 'www ', url, flags=re.IGNORECASE)
|
url = re.sub(r'^www\.', 'www ', url, flags=re.IGNORECASE)
|
||||||
|
|
||||||
# Replace symbols with words
|
# Handle port numbers before other replacements
|
||||||
|
url = re.sub(r':(\d+)(?=/|$)', lambda m: f" colon {m.group(1)}", url)
|
||||||
|
|
||||||
url = url.replace(":", " colon ")
|
# Split into domain and path
|
||||||
|
parts = url.split('/', 1)
|
||||||
|
domain = parts[0]
|
||||||
|
path = parts[1] if len(parts) > 1 else ''
|
||||||
|
|
||||||
|
# Handle dots in domain
|
||||||
|
domain = domain.replace('.', ' dot ')
|
||||||
|
|
||||||
|
# Reconstruct URL
|
||||||
|
if path:
|
||||||
|
url = f"{domain} slash {path}"
|
||||||
|
else:
|
||||||
|
url = domain
|
||||||
|
|
||||||
|
# Replace remaining symbols with words
|
||||||
url = url.replace("-", " dash ")
|
url = url.replace("-", " dash ")
|
||||||
url = url.replace("_", " underscore ")
|
url = url.replace("_", " underscore ")
|
||||||
url = url.replace("/", " slash ")
|
|
||||||
url = url.replace(".", " dot ")
|
|
||||||
url = url.replace("@", " at ")
|
|
||||||
url = url.replace("?", " question-mark ")
|
url = url.replace("?", " question-mark ")
|
||||||
url = url.replace("=", " equals ")
|
url = url.replace("=", " equals ")
|
||||||
url = url.replace("&", " ampersand ")
|
url = url.replace("&", " ampersand ")
|
||||||
|
url = url.replace(":", " colon ") # Handle any remaining colons
|
||||||
|
|
||||||
# Clean up extra spaces
|
# Clean up extra spaces
|
||||||
return re.sub(r'\s+', ' ', url).strip()
|
return re.sub(r'\s+', ' ', url).strip()
|
||||||
|
@ -82,20 +123,17 @@ def handle_url(u: re.Match) -> str:
|
||||||
|
|
||||||
def normalize_urls(text: str) -> str:
|
def normalize_urls(text: str) -> str:
|
||||||
"""Pre-process URLs before other text normalization"""
|
"""Pre-process URLs before other text normalization"""
|
||||||
url_patterns = [
|
# Handle email addresses first
|
||||||
r"(https?://|www\.|)+(localhost|[a-zA-Z0-9.-]+(\.(?:" + "|".join(valid_tlds) + "))+|[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3})(:[0-9]+)?([/?][^\s]*)?", # URLs with http(s), raw ip, www, or domain.tld
|
text = EMAIL_PATTERN.sub(handle_email, text)
|
||||||
r"\b[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-z]{2,}\b" # Email addresses
|
|
||||||
]
|
|
||||||
|
|
||||||
for pattern in url_patterns:
|
# Handle URLs
|
||||||
text = re.sub(pattern, handle_url, text, flags=re.IGNORECASE)
|
text = URL_PATTERN.sub(handle_url, text)
|
||||||
|
|
||||||
return text
|
return text
|
||||||
|
|
||||||
def normalize_text(text: str) -> str:
|
def normalize_text(text: str) -> str:
|
||||||
"""Normalize text for TTS processing"""
|
"""Normalize text for TTS processing"""
|
||||||
# Pre-process URLs first
|
# Pre-process URLs first
|
||||||
|
|
||||||
text = normalize_urls(text)
|
text = normalize_urls(text)
|
||||||
|
|
||||||
# Replace quotes and brackets
|
# Replace quotes and brackets
|
||||||
|
|
|
@ -3,19 +3,44 @@
|
||||||
import pytest
|
import pytest
|
||||||
from api.src.services.text_processing.normalizer import normalize_text
|
from api.src.services.text_processing.normalizer import normalize_text
|
||||||
|
|
||||||
def test_urls():
|
def test_url_protocols():
|
||||||
"""Test URL handling"""
|
"""Test URL protocol handling"""
|
||||||
# URLs with http/https
|
assert normalize_text("Check out https://example.com") == "Check out https example dot com"
|
||||||
assert normalize_text("Check out https://example.com") == "Check out http example dot com"
|
assert normalize_text("Visit http://site.com") == "Visit http site dot com"
|
||||||
assert normalize_text("Visit http://site.com/docs") == "Visit http site dot com slash docs"
|
assert normalize_text("Go to https://test.org/path") == "Go to https test dot org slash path"
|
||||||
|
|
||||||
# URLs with www
|
def test_url_www():
|
||||||
|
"""Test www prefix handling"""
|
||||||
assert normalize_text("Go to www.example.com") == "Go to www example dot com"
|
assert normalize_text("Go to www.example.com") == "Go to www example dot com"
|
||||||
|
assert normalize_text("Visit www.test.org/docs") == "Visit www test dot org slash docs"
|
||||||
|
assert normalize_text("Check www.site.com?q=test") == "Check www site dot com question-mark q equals test"
|
||||||
|
|
||||||
# Email addresses
|
def test_url_localhost():
|
||||||
|
"""Test localhost URL handling"""
|
||||||
|
assert normalize_text("Running on localhost:7860") == "Running on localhost colon 78 60"
|
||||||
|
assert normalize_text("Server at localhost:8080/api") == "Server at localhost colon 80 80 slash api"
|
||||||
|
assert normalize_text("Test localhost:3000/test?v=1") == "Test localhost colon 3000 slash test question-mark v equals 1"
|
||||||
|
|
||||||
|
def test_url_ip_addresses():
|
||||||
|
"""Test IP address URL handling"""
|
||||||
|
assert normalize_text("Access 0.0.0.0:9090/test") == "Access 0 dot 0 dot 0 dot 0 colon 90 90 slash test"
|
||||||
|
assert normalize_text("API at 192.168.1.1:8000") == "API at 192 dot 168 dot 1 dot 1 colon 8000"
|
||||||
|
assert normalize_text("Server 127.0.0.1") == "Server 127 dot 0 dot 0 dot 1"
|
||||||
|
|
||||||
|
def test_url_raw_domains():
|
||||||
|
"""Test raw domain handling"""
|
||||||
|
assert normalize_text("Visit google.com/search") == "Visit google dot com slash search"
|
||||||
|
assert normalize_text("Go to example.com/path?q=test") == "Go to example dot com slash path question-mark q equals test"
|
||||||
|
assert normalize_text("Check docs.test.com") == "Check docs dot test dot com"
|
||||||
|
|
||||||
|
def test_url_email_addresses():
|
||||||
|
"""Test email address handling"""
|
||||||
assert normalize_text("Email me at user@example.com") == "Email me at user at example dot com"
|
assert normalize_text("Email me at user@example.com") == "Email me at user at example dot com"
|
||||||
|
assert normalize_text("Contact admin@test.org") == "Contact admin at test dot org"
|
||||||
|
assert normalize_text("Send to test.user@site.com") == "Send to test dot user at site dot com"
|
||||||
|
|
||||||
# Normal text should be unaffected, other than downstream normalization
|
def test_non_url_text():
|
||||||
|
"""Test that non-URL text is unaffected"""
|
||||||
assert normalize_text("This is not.a.url text") == "This is not-a-url text"
|
assert normalize_text("This is not.a.url text") == "This is not-a-url text"
|
||||||
assert normalize_text("Hello, how are you today?") == "Hello, how are you today?"
|
assert normalize_text("Hello, how are you today?") == "Hello, how are you today?"
|
||||||
assert normalize_text("It costs $50.") == "It costs 50 dollars."
|
assert normalize_text("It costs $50.") == "It costs 50 dollars."
|
||||||
|
|
Loading…
Add table
Reference in a new issue