mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
Added normalization options
This commit is contained in:
parent
8ea8e68b61
commit
09de389b29
5 changed files with 40 additions and 24 deletions
|
@ -3,6 +3,7 @@
|
|||
import io
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from typing import AsyncGenerator, Dict, List, Union
|
||||
|
||||
|
@ -138,6 +139,7 @@ async def stream_audio_chunks(
|
|||
speed=request.speed,
|
||||
output_format=request.response_format,
|
||||
lang_code=request.lang_code or request.voice[0],
|
||||
normalization_options=request.normalization_options
|
||||
):
|
||||
# Check if client is still connected
|
||||
is_disconnected = client_request.is_disconnected
|
||||
|
|
|
@ -8,6 +8,8 @@ import re
|
|||
from functools import lru_cache
|
||||
import inflect
|
||||
|
||||
from ...structures.schemas import NormalizationOptions
|
||||
|
||||
# Constants
|
||||
VALID_TLDS = [
|
||||
"com",
|
||||
|
@ -112,10 +114,10 @@ def split_num(num: re.Match[str]) -> str:
|
|||
return f"{left} {right}{s}"
|
||||
|
||||
def handle_units(u: re.Match[str]) -> str:
|
||||
"""Converts units to their full form"""
|
||||
unit_string=u.group(6).strip()
|
||||
unit=unit_string
|
||||
|
||||
print(unit)
|
||||
if unit_string.lower() in VALID_UNITS:
|
||||
unit=VALID_UNITS[unit_string.lower()].split(" ")
|
||||
|
||||
|
@ -213,24 +215,19 @@ def handle_url(u: re.Match[str]) -> str:
|
|||
return re.sub(r"\s+", " ", url).strip()
|
||||
|
||||
|
||||
def normalize_urls(text: str) -> str:
    """Rewrite email addresses and URLs before the rest of text normalization.

    Every EMAIL_PATTERN match is replaced via handle_email first; the
    URL_PATTERN / handle_url substitution then runs on that result, so the
    URL pass only ever sees text whose email matches were already rewritten.
    """
    emails_handled = EMAIL_PATTERN.sub(handle_email, text)
    return URL_PATTERN.sub(handle_url, emails_handled)
|
||||
|
||||
|
||||
def normalize_text(text: str) -> str:
|
||||
def normalize_text(text: str,normalization_options: NormalizationOptions) -> str:
|
||||
"""Normalize text for TTS processing"""
|
||||
# Pre-process URLs first
|
||||
text = normalize_urls(text)
|
||||
# Handle email addresses first if enabled
|
||||
if normalization_options.email_normalization:
|
||||
text = EMAIL_PATTERN.sub(handle_email, text)
|
||||
|
||||
# Pre-process numbers with units
|
||||
text=UNIT_PATTERN.sub(handle_units,text)
|
||||
# Handle URLs if enabled
|
||||
if normalization_options.url_normalization:
|
||||
text = URL_PATTERN.sub(handle_url, text)
|
||||
|
||||
# Pre-process numbers with units if enabled
|
||||
if normalization_options.unit_normalization:
|
||||
text=UNIT_PATTERN.sub(handle_units,text)
|
||||
|
||||
# Replace quotes and brackets
|
||||
text = text.replace(chr(8216), "'").replace(chr(8217), "'")
|
||||
|
@ -261,12 +258,14 @@ def normalize_text(text: str) -> str:
|
|||
text = re.sub(
|
||||
r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)", split_num, text
|
||||
)
|
||||
|
||||
text = re.sub(r"(?<=\d),(?=\d)", "", text)
|
||||
text = re.sub(
|
||||
r"(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b",
|
||||
handle_money,
|
||||
text,
|
||||
)
|
||||
|
||||
text = re.sub(r"\d*\.\d+", handle_decimal, text)
|
||||
|
||||
# Handle various formatting
|
||||
|
|
|
@ -10,7 +10,7 @@ from ...core.config import settings
|
|||
from .normalizer import normalize_text
|
||||
from .phonemizer import phonemize
|
||||
from .vocabulary import tokenize
|
||||
|
||||
from ...structures.schemas import NormalizationOptions
|
||||
|
||||
def process_text_chunk(
|
||||
text: str, language: str = "a", skip_phonemize: bool = False
|
||||
|
@ -87,8 +87,6 @@ def process_text(text: str, language: str = "a") -> List[int]:
|
|||
|
||||
def get_sentence_info(text: str) -> List[Tuple[str, List[int], int]]:
|
||||
"""Process all sentences and return info."""
|
||||
if settings.advanced_text_normalization:
|
||||
text=normalize_text(text)
|
||||
sentences = re.split(r"([.!?;:])(?=\s|$)", text)
|
||||
results = []
|
||||
for i in range(0, len(sentences), 2):
|
||||
|
@ -106,13 +104,19 @@ def get_sentence_info(text: str) -> List[Tuple[str, List[int], int]]:
|
|||
|
||||
|
||||
async def smart_split(
|
||||
text: str, max_tokens: int = settings.absolute_max_tokens
|
||||
text: str,
|
||||
max_tokens: int = settings.absolute_max_tokens,
|
||||
normalization_options: NormalizationOptions = NormalizationOptions()
|
||||
) -> AsyncGenerator[Tuple[str, List[int]], None]:
|
||||
"""Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens."""
|
||||
start_time = time.time()
|
||||
chunk_count = 0
|
||||
logger.info(f"Starting smart split for {len(text)} chars")
|
||||
|
||||
# Normalize text
|
||||
if settings.advanced_text_normalization and normalization_options.normalize:
|
||||
text=normalize_text(text,normalization_options)
|
||||
|
||||
# Process all sentences
|
||||
sentences = get_sentence_info(text)
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ from ..inference.voice_manager import get_manager as get_voice_manager
|
|||
from .audio import AudioNormalizer, AudioService
|
||||
from .text_processing import tokenize
|
||||
from .text_processing.text_processor import process_text_chunk, smart_split
|
||||
|
||||
from ..structures.schemas import NormalizationOptions
|
||||
|
||||
class TTSService:
|
||||
"""Text-to-speech service."""
|
||||
|
@ -238,6 +238,7 @@ class TTSService:
|
|||
speed: float = 1.0,
|
||||
output_format: str = "wav",
|
||||
lang_code: Optional[str] = None,
|
||||
normalization_options: Optional[NormalizationOptions] = NormalizationOptions()
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
"""Generate and stream audio chunks."""
|
||||
stream_normalizer = AudioNormalizer()
|
||||
|
@ -258,7 +259,7 @@ class TTSService:
|
|||
)
|
||||
|
||||
# Process text in chunks with smart splitting
|
||||
async for chunk_text, tokens in smart_split(text):
|
||||
async for chunk_text, tokens in smart_split(text,normalization_options=normalization_options):
|
||||
try:
|
||||
# Process audio for chunk
|
||||
async for result in self._process_chunk(
|
||||
|
|
|
@ -36,7 +36,13 @@ class CaptionedSpeechResponse(BaseModel):
|
|||
audio: bytes = Field(..., description="The generated audio data")
|
||||
words: List[WordTimestamp] = Field(..., description="Word-level timestamps")
|
||||
|
||||
|
||||
class NormalizationOptions(BaseModel):
    """Options for the normalization system"""

    # Overall on/off flag for text normalization; enabled by default.
    normalize: bool = Field(
        default=True,
        description="Normalizes input text to make it easier for the model to say",
    )

    # Unit expansion is the only option that is disabled by default.
    unit_normalization: bool = Field(
        default=False,
        description="Transforms units like 10KB to 10 kilobytes",
    )

    # URL rewriting; enabled by default.
    url_normalization: bool = Field(
        default=True,
        description="Changes urls so they can be properly pronouced by kokoro",
    )

    # Email rewriting; enabled by default.
    email_normalization: bool = Field(
        default=True,
        description="Changes emails so they can be properly pronouced by kokoro",
    )
|
||||
|
||||
class OpenAISpeechRequest(BaseModel):
|
||||
"""Request schema for OpenAI-compatible speech endpoint"""
|
||||
|
||||
|
@ -71,6 +77,10 @@ class OpenAISpeechRequest(BaseModel):
|
|||
default=None,
|
||||
description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
|
||||
)
|
||||
normalization_options: Optional[NormalizationOptions] = Field(
|
||||
default= NormalizationOptions(),
|
||||
description= "Options for the normalization system"
|
||||
)
|
||||
|
||||
|
||||
class CaptionedSpeechRequest(BaseModel):
|
||||
|
|
Loading…
Add table
Reference in a new issue