diff --git a/api/src/routers/openai_compatible.py b/api/src/routers/openai_compatible.py
index 5508d65..3be678a 100644
--- a/api/src/routers/openai_compatible.py
+++ b/api/src/routers/openai_compatible.py
@@ -3,6 +3,7 @@
 import io
 import json
 import os
+import re
 import tempfile
 from typing import AsyncGenerator, Dict, List, Union
 
@@ -138,6 +139,7 @@ async def stream_audio_chunks(
             speed=request.speed,
             output_format=request.response_format,
             lang_code=request.lang_code or request.voice[0],
+            normalization_options=request.normalization_options
         ):
             # Check if client is still connected
             is_disconnected = client_request.is_disconnected
diff --git a/api/src/services/text_processing/normalizer.py b/api/src/services/text_processing/normalizer.py
index ca26ffb..3bc9021 100644
--- a/api/src/services/text_processing/normalizer.py
+++ b/api/src/services/text_processing/normalizer.py
@@ -8,6 +8,8 @@ import re
 from functools import lru_cache
 import inflect
 
+from ...structures.schemas import NormalizationOptions
+
 # Constants
 VALID_TLDS = [
     "com",
@@ -112,10 +114,10 @@ def split_num(num: re.Match[str]) -> str:
     return f"{left} {right}{s}"
 
 
 def handle_units(u: re.Match[str]) -> str:
+    """Converts units to their full form"""
     unit_string=u.group(6).strip()
     unit=unit_string
-    print(unit)
 
     if unit_string.lower() in VALID_UNITS:
         unit=VALID_UNITS[unit_string.lower()].split(" ")
@@ -213,24 +215,19 @@ def handle_url(u: re.Match[str]) -> str:
     return re.sub(r"\s+", " ", url).strip()
 
 
-def normalize_urls(text: str) -> str:
-    """Pre-process URLs before other text normalization"""
-    # Handle email addresses first
-    text = EMAIL_PATTERN.sub(handle_email, text)
-
-    # Handle URLs
-    text = URL_PATTERN.sub(handle_url, text)
-
-    return text
-
-
-def normalize_text(text: str) -> str:
+def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
     """Normalize text for TTS processing"""
-    # Pre-process URLs first
-    text = normalize_urls(text)
+    # Handle email addresses first if enabled
+    if normalization_options.email_normalization:
+        text = EMAIL_PATTERN.sub(handle_email, text)
 
-    # Pre-process numbers with units
-    text=UNIT_PATTERN.sub(handle_units,text)
+    # Handle URLs if enabled
+    if normalization_options.url_normalization:
+        text = URL_PATTERN.sub(handle_url, text)
+
+    # Pre-process numbers with units if enabled
+    if normalization_options.unit_normalization:
+        text = UNIT_PATTERN.sub(handle_units, text)
 
     # Replace quotes and brackets
     text = text.replace(chr(8216), "'").replace(chr(8217), "'")
@@ -261,12 +258,14 @@ def normalize_text(text: str) -> str:
     text = re.sub(
         r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)",
         split_num,
         text,
     )
diff --git a/api/src/services/text_processing/text_processor.py b/api/src/services/text_processing/text_processor.py
--- a/api/src/services/text_processing/text_processor.py
+++ b/api/src/services/text_processing/text_processor.py
@@ ... @@
 def get_sentence_info(text: str) -> List[Tuple[str, List[int], int]]:
     """Process all sentences and return info."""
-    if settings.advanced_text_normalization:
-        text=normalize_text(text)
     sentences = re.split(r"([.!?;:])(?=\s|$)", text)
     results = []
     for i in range(0, len(sentences), 2):
@@ -106,13 +104,19 @@ def get_sentence_info(text: str) -> List[Tuple[str, List[int], int]]:
 
 
 async def smart_split(
-    text: str, max_tokens: int = settings.absolute_max_tokens
+    text: str,
+    max_tokens: int = settings.absolute_max_tokens,
+    normalization_options: NormalizationOptions = NormalizationOptions()
 ) -> AsyncGenerator[Tuple[str, List[int]], None]:
     """Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens."""
     start_time = time.time()
     chunk_count = 0
     logger.info(f"Starting smart split for {len(text)} chars")
 
+    # Normalize text
+    if settings.advanced_text_normalization and normalization_options.normalize:
+        text = normalize_text(text, normalization_options)
+
     # Process all sentences
     sentences = get_sentence_info(text)
 
diff --git a/api/src/services/tts_service.py b/api/src/services/tts_service.py
index 3d533d9..ba4dcc4 100644
--- a/api/src/services/tts_service.py
+++ b/api/src/services/tts_service.py
@@ -18,7 +18,7 @@ from ..inference.voice_manager import get_manager as get_voice_manager
 from .audio import AudioNormalizer, AudioService
 from .text_processing import tokenize
 from .text_processing.text_processor import process_text_chunk, smart_split
-
+from ..structures.schemas import NormalizationOptions
 
 class TTSService:
     """Text-to-speech service."""
@@ -238,6 +238,7 @@ class TTSService:
         speed: float = 1.0,
         output_format: str = "wav",
         lang_code: Optional[str] = None,
+        normalization_options: Optional[NormalizationOptions] = NormalizationOptions()
     ) -> AsyncGenerator[bytes, None]:
         """Generate and stream audio chunks."""
         stream_normalizer = AudioNormalizer()
@@ -258,7 +259,7 @@ class TTSService:
             )
 
             # Process text in chunks with smart splitting
-            async for chunk_text, tokens in smart_split(text):
+            async for chunk_text, tokens in smart_split(text, normalization_options=normalization_options):
                 try:
                     # Process audio for chunk
                     async for result in self._process_chunk(
diff --git a/api/src/structures/schemas.py b/api/src/structures/schemas.py
index 4e76a69..491ae60 100644
--- a/api/src/structures/schemas.py
+++ b/api/src/structures/schemas.py
@@ -36,7 +36,13 @@ class CaptionedSpeechResponse(BaseModel):
     audio: bytes = Field(..., description="The generated audio data")
     words: List[WordTimestamp] = Field(..., description="Word-level timestamps")
 
-
+class NormalizationOptions(BaseModel):
+    """Options for the normalization system"""
+    normalize: bool = Field(default=True, description="Normalizes input text to make it easier for the model to say")
+    unit_normalization: bool = Field(default=False, description="Transforms units like 10KB to 10 kilobytes")
+    url_normalization: bool = Field(default=True, description="Changes URLs so they can be properly pronounced by Kokoro")
+    email_normalization: bool = Field(default=True, description="Changes emails so they can be properly pronounced by Kokoro")
+
 class OpenAISpeechRequest(BaseModel):
     """Request schema for OpenAI-compatible speech endpoint"""
 
@@ -71,6 +77,10 @@ class OpenAISpeechRequest(BaseModel):
         default=None,
         description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
     )
+    normalization_options: Optional[NormalizationOptions] = Field(
+        default=NormalizationOptions(),
+        description="Options for the normalization system"
+    )
 
 
 class CaptionedSpeechRequest(BaseModel):
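
Below is a minimal usage sketch for the new normalization_options field added by this diff to the OpenAI-compatible speech request. The host, port, endpoint path, voice name, and file handling are illustrative assumptions, not taken from the diff; only the normalization_options payload mirrors the new NormalizationOptions schema.

# example_client.py - hypothetical client call exercising normalization_options.
# Host, port, endpoint path, and voice name are assumptions for illustration.
import requests

response = requests.post(
    "http://localhost:8880/v1/audio/speech",
    json={
        "model": "kokoro",
        "input": "Email support@example.com about the 10KB report at https://example.com/q3.",
        "voice": "af_bella",
        "response_format": "mp3",
        # New field from this diff; unit_normalization is off by default.
        "normalization_options": {
            "normalize": True,
            "url_normalization": True,
            "email_normalization": True,
            "unit_normalization": True,
        },
    },
)
response.raise_for_status()

# Write the returned audio bytes to disk.
with open("speech.mp3", "wb") as f:
    f.write(response.content)

Setting "normalize" to false skips normalize_text entirely (it also stays off when settings.advanced_text_normalization is disabled), while the three sub-flags toggle the email, URL, and unit passes individually.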