Added normalization options

This commit is contained in:
Fireblade 2025-02-11 19:09:35 -05:00
parent 8ea8e68b61
commit 09de389b29
5 changed files with 40 additions and 24 deletions

View file

@ -3,6 +3,7 @@
import io
import json
import os
import re
import tempfile
from typing import AsyncGenerator, Dict, List, Union
@ -138,6 +139,7 @@ async def stream_audio_chunks(
speed=request.speed,
output_format=request.response_format,
lang_code=request.lang_code or request.voice[0],
normalization_options=request.normalization_options
):
# Check if client is still connected
is_disconnected = client_request.is_disconnected

View file

@ -8,6 +8,8 @@ import re
from functools import lru_cache
import inflect
from ...structures.schemas import NormalizationOptions
# Constants
VALID_TLDS = [
"com",
@ -112,10 +114,10 @@ def split_num(num: re.Match[str]) -> str:
return f"{left} {right}{s}"
def handle_units(u: re.Match[str]) -> str:
"""Converts units to their full form"""
unit_string=u.group(6).strip()
unit=unit_string
print(unit)
if unit_string.lower() in VALID_UNITS:
unit=VALID_UNITS[unit_string.lower()].split(" ")
@ -213,23 +215,18 @@ def handle_url(u: re.Match[str]) -> str:
return re.sub(r"\s+", " ", url).strip()
def normalize_urls(text: str) -> str:
"""Pre-process URLs before other text normalization"""
# Handle email addresses first
def normalize_text(text: str,normalization_options: NormalizationOptions) -> str:
"""Normalize text for TTS processing"""
# Handle email addresses first if enabled
if normalization_options.email_normalization:
text = EMAIL_PATTERN.sub(handle_email, text)
# Handle URLs
# Handle URLs if enabled
if normalization_options.url_normalization:
text = URL_PATTERN.sub(handle_url, text)
return text
def normalize_text(text: str) -> str:
"""Normalize text for TTS processing"""
# Pre-process URLs first
text = normalize_urls(text)
# Pre-process numbers with units
# Pre-process numbers with units if enabled
if normalization_options.unit_normalization:
text=UNIT_PATTERN.sub(handle_units,text)
# Replace quotes and brackets
@ -261,12 +258,14 @@ def normalize_text(text: str) -> str:
text = re.sub(
r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)", split_num, text
)
text = re.sub(r"(?<=\d),(?=\d)", "", text)
text = re.sub(
r"(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b",
handle_money,
text,
)
text = re.sub(r"\d*\.\d+", handle_decimal, text)
# Handle various formatting

View file

@ -10,7 +10,7 @@ from ...core.config import settings
from .normalizer import normalize_text
from .phonemizer import phonemize
from .vocabulary import tokenize
from ...structures.schemas import NormalizationOptions
def process_text_chunk(
text: str, language: str = "a", skip_phonemize: bool = False
@ -87,8 +87,6 @@ def process_text(text: str, language: str = "a") -> List[int]:
def get_sentence_info(text: str) -> List[Tuple[str, List[int], int]]:
"""Process all sentences and return info."""
if settings.advanced_text_normalization:
text=normalize_text(text)
sentences = re.split(r"([.!?;:])(?=\s|$)", text)
results = []
for i in range(0, len(sentences), 2):
@ -106,13 +104,19 @@ def get_sentence_info(text: str) -> List[Tuple[str, List[int], int]]:
async def smart_split(
text: str, max_tokens: int = settings.absolute_max_tokens
text: str,
max_tokens: int = settings.absolute_max_tokens,
normalization_options: NormalizationOptions = NormalizationOptions()
) -> AsyncGenerator[Tuple[str, List[int]], None]:
"""Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens."""
start_time = time.time()
chunk_count = 0
logger.info(f"Starting smart split for {len(text)} chars")
# Normalize text
if settings.advanced_text_normalization and normalization_options.normalize:
text=normalize_text(text,normalization_options)
# Process all sentences
sentences = get_sentence_info(text)

View file

@ -18,7 +18,7 @@ from ..inference.voice_manager import get_manager as get_voice_manager
from .audio import AudioNormalizer, AudioService
from .text_processing import tokenize
from .text_processing.text_processor import process_text_chunk, smart_split
from ..structures.schemas import NormalizationOptions
class TTSService:
"""Text-to-speech service."""
@ -238,6 +238,7 @@ class TTSService:
speed: float = 1.0,
output_format: str = "wav",
lang_code: Optional[str] = None,
normalization_options: Optional[NormalizationOptions] = NormalizationOptions()
) -> AsyncGenerator[bytes, None]:
"""Generate and stream audio chunks."""
stream_normalizer = AudioNormalizer()
@ -258,7 +259,7 @@ class TTSService:
)
# Process text in chunks with smart splitting
async for chunk_text, tokens in smart_split(text):
async for chunk_text, tokens in smart_split(text,normalization_options=normalization_options):
try:
# Process audio for chunk
async for result in self._process_chunk(

View file

@ -36,6 +36,12 @@ class CaptionedSpeechResponse(BaseModel):
audio: bytes = Field(..., description="The generated audio data")
words: List[WordTimestamp] = Field(..., description="Word-level timestamps")
class NormalizationOptions(BaseModel):
    """Options controlling the text-normalization pre-processing step.

    Each flag independently enables one normalization pass applied to the
    input text before TTS synthesis; ``normalize`` is the master switch.
    """

    # Master switch: when False, no normalization passes run at all.
    normalize: bool = Field(
        default=True,
        description="Normalizes input text to make it easier for the model to say",
    )
    # Expands unit abbreviations, e.g. "10KB" -> "10 kilobytes". Off by
    # default because unit expansion can misfire on ambiguous abbreviations.
    unit_normalization: bool = Field(
        default=False,
        description="Transforms units like 10KB to 10 kilobytes",
    )
    # Rewrites URLs into a speakable form.
    url_normalization: bool = Field(
        default=True,
        description="Changes URLs so they can be properly pronounced by kokoro",
    )
    # Rewrites email addresses into a speakable form.
    email_normalization: bool = Field(
        default=True,
        description="Changes emails so they can be properly pronounced by kokoro",
    )
class OpenAISpeechRequest(BaseModel):
"""Request schema for OpenAI-compatible speech endpoint"""
@ -71,6 +77,10 @@ class OpenAISpeechRequest(BaseModel):
default=None,
description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
)
normalization_options: Optional[NormalizationOptions] = Field(
default= NormalizationOptions(),
description= "Options for the normalization system"
)
class CaptionedSpeechRequest(BaseModel):