Added normalization options

This commit is contained in:
Fireblade 2025-02-11 19:09:35 -05:00
parent 8ea8e68b61
commit 09de389b29
5 changed files with 40 additions and 24 deletions

View file

@ -3,6 +3,7 @@
import io import io
import json import json
import os import os
import re
import tempfile import tempfile
from typing import AsyncGenerator, Dict, List, Union from typing import AsyncGenerator, Dict, List, Union
@ -138,6 +139,7 @@ async def stream_audio_chunks(
speed=request.speed, speed=request.speed,
output_format=request.response_format, output_format=request.response_format,
lang_code=request.lang_code or request.voice[0], lang_code=request.lang_code or request.voice[0],
normalization_options=request.normalization_options
): ):
# Check if client is still connected # Check if client is still connected
is_disconnected = client_request.is_disconnected is_disconnected = client_request.is_disconnected

View file

@ -8,6 +8,8 @@ import re
from functools import lru_cache from functools import lru_cache
import inflect import inflect
from ...structures.schemas import NormalizationOptions
# Constants # Constants
VALID_TLDS = [ VALID_TLDS = [
"com", "com",
@ -112,10 +114,10 @@ def split_num(num: re.Match[str]) -> str:
return f"{left} {right}{s}" return f"{left} {right}{s}"
def handle_units(u: re.Match[str]) -> str: def handle_units(u: re.Match[str]) -> str:
"""Converts units to their full form"""
unit_string=u.group(6).strip() unit_string=u.group(6).strip()
unit=unit_string unit=unit_string
print(unit)
if unit_string.lower() in VALID_UNITS: if unit_string.lower() in VALID_UNITS:
unit=VALID_UNITS[unit_string.lower()].split(" ") unit=VALID_UNITS[unit_string.lower()].split(" ")
@ -213,24 +215,19 @@ def handle_url(u: re.Match[str]) -> str:
return re.sub(r"\s+", " ", url).strip() return re.sub(r"\s+", " ", url).strip()
def normalize_urls(text: str) -> str: def normalize_text(text: str,normalization_options: NormalizationOptions) -> str:
"""Pre-process URLs before other text normalization"""
# Handle email addresses first
text = EMAIL_PATTERN.sub(handle_email, text)
# Handle URLs
text = URL_PATTERN.sub(handle_url, text)
return text
def normalize_text(text: str) -> str:
"""Normalize text for TTS processing""" """Normalize text for TTS processing"""
# Pre-process URLs first # Handle email addresses first if enabled
text = normalize_urls(text) if normalization_options.email_normalization:
text = EMAIL_PATTERN.sub(handle_email, text)
# Pre-process numbers with units # Handle URLs if enabled
text=UNIT_PATTERN.sub(handle_units,text) if normalization_options.url_normalization:
text = URL_PATTERN.sub(handle_url, text)
# Pre-process numbers with units if enabled
if normalization_options.unit_normalization:
text=UNIT_PATTERN.sub(handle_units,text)
# Replace quotes and brackets # Replace quotes and brackets
text = text.replace(chr(8216), "'").replace(chr(8217), "'") text = text.replace(chr(8216), "'").replace(chr(8217), "'")
@ -261,12 +258,14 @@ def normalize_text(text: str) -> str:
text = re.sub( text = re.sub(
r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)", split_num, text r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)", split_num, text
) )
text = re.sub(r"(?<=\d),(?=\d)", "", text) text = re.sub(r"(?<=\d),(?=\d)", "", text)
text = re.sub( text = re.sub(
r"(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b", r"(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b",
handle_money, handle_money,
text, text,
) )
text = re.sub(r"\d*\.\d+", handle_decimal, text) text = re.sub(r"\d*\.\d+", handle_decimal, text)
# Handle various formatting # Handle various formatting

View file

@ -10,7 +10,7 @@ from ...core.config import settings
from .normalizer import normalize_text from .normalizer import normalize_text
from .phonemizer import phonemize from .phonemizer import phonemize
from .vocabulary import tokenize from .vocabulary import tokenize
from ...structures.schemas import NormalizationOptions
def process_text_chunk( def process_text_chunk(
text: str, language: str = "a", skip_phonemize: bool = False text: str, language: str = "a", skip_phonemize: bool = False
@ -87,8 +87,6 @@ def process_text(text: str, language: str = "a") -> List[int]:
def get_sentence_info(text: str) -> List[Tuple[str, List[int], int]]: def get_sentence_info(text: str) -> List[Tuple[str, List[int], int]]:
"""Process all sentences and return info.""" """Process all sentences and return info."""
if settings.advanced_text_normalization:
text=normalize_text(text)
sentences = re.split(r"([.!?;:])(?=\s|$)", text) sentences = re.split(r"([.!?;:])(?=\s|$)", text)
results = [] results = []
for i in range(0, len(sentences), 2): for i in range(0, len(sentences), 2):
@ -106,13 +104,19 @@ def get_sentence_info(text: str) -> List[Tuple[str, List[int], int]]:
async def smart_split( async def smart_split(
text: str, max_tokens: int = settings.absolute_max_tokens text: str,
max_tokens: int = settings.absolute_max_tokens,
normalization_options: NormalizationOptions = NormalizationOptions()
) -> AsyncGenerator[Tuple[str, List[int]], None]: ) -> AsyncGenerator[Tuple[str, List[int]], None]:
"""Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens.""" """Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens."""
start_time = time.time() start_time = time.time()
chunk_count = 0 chunk_count = 0
logger.info(f"Starting smart split for {len(text)} chars") logger.info(f"Starting smart split for {len(text)} chars")
# Normalize text
if settings.advanced_text_normalization and normalization_options.normalize:
text=normalize_text(text,normalization_options)
# Process all sentences # Process all sentences
sentences = get_sentence_info(text) sentences = get_sentence_info(text)

View file

@ -18,7 +18,7 @@ from ..inference.voice_manager import get_manager as get_voice_manager
from .audio import AudioNormalizer, AudioService from .audio import AudioNormalizer, AudioService
from .text_processing import tokenize from .text_processing import tokenize
from .text_processing.text_processor import process_text_chunk, smart_split from .text_processing.text_processor import process_text_chunk, smart_split
from ..structures.schemas import NormalizationOptions
class TTSService: class TTSService:
"""Text-to-speech service.""" """Text-to-speech service."""
@ -238,6 +238,7 @@ class TTSService:
speed: float = 1.0, speed: float = 1.0,
output_format: str = "wav", output_format: str = "wav",
lang_code: Optional[str] = None, lang_code: Optional[str] = None,
normalization_options: Optional[NormalizationOptions] = NormalizationOptions()
) -> AsyncGenerator[bytes, None]: ) -> AsyncGenerator[bytes, None]:
"""Generate and stream audio chunks.""" """Generate and stream audio chunks."""
stream_normalizer = AudioNormalizer() stream_normalizer = AudioNormalizer()
@ -258,7 +259,7 @@ class TTSService:
) )
# Process text in chunks with smart splitting # Process text in chunks with smart splitting
async for chunk_text, tokens in smart_split(text): async for chunk_text, tokens in smart_split(text,normalization_options=normalization_options):
try: try:
# Process audio for chunk # Process audio for chunk
async for result in self._process_chunk( async for result in self._process_chunk(

View file

@ -36,7 +36,13 @@ class CaptionedSpeechResponse(BaseModel):
audio: bytes = Field(..., description="The generated audio data") audio: bytes = Field(..., description="The generated audio data")
words: List[WordTimestamp] = Field(..., description="Word-level timestamps") words: List[WordTimestamp] = Field(..., description="Word-level timestamps")
class NormalizationOptions(BaseModel):
    """Options controlling the text-normalization pipeline applied before TTS."""

    # Master switch: when False, normalization is skipped entirely (see smart_split).
    normalize: bool = Field(default=True, description="Normalizes input text to make it easier for the model to say")
    # Off by default: unit expansion (e.g. 10KB -> 10 kilobytes) can be lossy/ambiguous.
    unit_normalization: bool = Field(default=False, description="Transforms units like 10KB to 10 kilobytes")
    # Rewrites URLs into a speakable form before other normalization steps run.
    url_normalization: bool = Field(default=True, description="Changes urls so they can be properly pronounced by kokoro")
    # Rewrites email addresses into a speakable form.
    email_normalization: bool = Field(default=True, description="Changes emails so they can be properly pronounced by kokoro")
class OpenAISpeechRequest(BaseModel): class OpenAISpeechRequest(BaseModel):
"""Request schema for OpenAI-compatible speech endpoint""" """Request schema for OpenAI-compatible speech endpoint"""
@ -71,6 +77,10 @@ class OpenAISpeechRequest(BaseModel):
default=None, default=None,
description="Optional language code to use for text processing. If not provided, will use first letter of voice name.", description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
) )
normalization_options: Optional[NormalizationOptions] = Field(
default= NormalizationOptions(),
description= "Options for the normalization system"
)
class CaptionedSpeechRequest(BaseModel): class CaptionedSpeechRequest(BaseModel):