mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
Added normalization options
This commit is contained in:
parent
8ea8e68b61
commit
09de389b29
5 changed files with 40 additions and 24 deletions
|
@ -3,6 +3,7 @@
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import tempfile
|
import tempfile
|
||||||
from typing import AsyncGenerator, Dict, List, Union
|
from typing import AsyncGenerator, Dict, List, Union
|
||||||
|
|
||||||
|
@ -138,6 +139,7 @@ async def stream_audio_chunks(
|
||||||
speed=request.speed,
|
speed=request.speed,
|
||||||
output_format=request.response_format,
|
output_format=request.response_format,
|
||||||
lang_code=request.lang_code or request.voice[0],
|
lang_code=request.lang_code or request.voice[0],
|
||||||
|
normalization_options=request.normalization_options
|
||||||
):
|
):
|
||||||
# Check if client is still connected
|
# Check if client is still connected
|
||||||
is_disconnected = client_request.is_disconnected
|
is_disconnected = client_request.is_disconnected
|
||||||
|
|
|
@ -8,6 +8,8 @@ import re
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
import inflect
|
import inflect
|
||||||
|
|
||||||
|
from ...structures.schemas import NormalizationOptions
|
||||||
|
|
||||||
# Constants
|
# Constants
|
||||||
VALID_TLDS = [
|
VALID_TLDS = [
|
||||||
"com",
|
"com",
|
||||||
|
@ -112,10 +114,10 @@ def split_num(num: re.Match[str]) -> str:
|
||||||
return f"{left} {right}{s}"
|
return f"{left} {right}{s}"
|
||||||
|
|
||||||
def handle_units(u: re.Match[str]) -> str:
|
def handle_units(u: re.Match[str]) -> str:
|
||||||
|
"""Converts units to their full form"""
|
||||||
unit_string=u.group(6).strip()
|
unit_string=u.group(6).strip()
|
||||||
unit=unit_string
|
unit=unit_string
|
||||||
|
|
||||||
print(unit)
|
|
||||||
if unit_string.lower() in VALID_UNITS:
|
if unit_string.lower() in VALID_UNITS:
|
||||||
unit=VALID_UNITS[unit_string.lower()].split(" ")
|
unit=VALID_UNITS[unit_string.lower()].split(" ")
|
||||||
|
|
||||||
|
@ -213,24 +215,19 @@ def handle_url(u: re.Match[str]) -> str:
|
||||||
return re.sub(r"\s+", " ", url).strip()
|
return re.sub(r"\s+", " ", url).strip()
|
||||||
|
|
||||||
|
|
||||||
def normalize_urls(text: str) -> str:
|
def normalize_text(text: str,normalization_options: NormalizationOptions) -> str:
|
||||||
"""Pre-process URLs before other text normalization"""
|
|
||||||
# Handle email addresses first
|
|
||||||
text = EMAIL_PATTERN.sub(handle_email, text)
|
|
||||||
|
|
||||||
# Handle URLs
|
|
||||||
text = URL_PATTERN.sub(handle_url, text)
|
|
||||||
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_text(text: str) -> str:
|
|
||||||
"""Normalize text for TTS processing"""
|
"""Normalize text for TTS processing"""
|
||||||
# Pre-process URLs first
|
# Handle email addresses first if enabled
|
||||||
text = normalize_urls(text)
|
if normalization_options.email_normalization:
|
||||||
|
text = EMAIL_PATTERN.sub(handle_email, text)
|
||||||
|
|
||||||
# Pre-process numbers with units
|
# Handle URLs if enabled
|
||||||
text=UNIT_PATTERN.sub(handle_units,text)
|
if normalization_options.url_normalization:
|
||||||
|
text = URL_PATTERN.sub(handle_url, text)
|
||||||
|
|
||||||
|
# Pre-process numbers with units if enabled
|
||||||
|
if normalization_options.unit_normalization:
|
||||||
|
text=UNIT_PATTERN.sub(handle_units,text)
|
||||||
|
|
||||||
# Replace quotes and brackets
|
# Replace quotes and brackets
|
||||||
text = text.replace(chr(8216), "'").replace(chr(8217), "'")
|
text = text.replace(chr(8216), "'").replace(chr(8217), "'")
|
||||||
|
@ -261,12 +258,14 @@ def normalize_text(text: str) -> str:
|
||||||
text = re.sub(
|
text = re.sub(
|
||||||
r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)", split_num, text
|
r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)", split_num, text
|
||||||
)
|
)
|
||||||
|
|
||||||
text = re.sub(r"(?<=\d),(?=\d)", "", text)
|
text = re.sub(r"(?<=\d),(?=\d)", "", text)
|
||||||
text = re.sub(
|
text = re.sub(
|
||||||
r"(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b",
|
r"(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b",
|
||||||
handle_money,
|
handle_money,
|
||||||
text,
|
text,
|
||||||
)
|
)
|
||||||
|
|
||||||
text = re.sub(r"\d*\.\d+", handle_decimal, text)
|
text = re.sub(r"\d*\.\d+", handle_decimal, text)
|
||||||
|
|
||||||
# Handle various formatting
|
# Handle various formatting
|
||||||
|
|
|
@ -10,7 +10,7 @@ from ...core.config import settings
|
||||||
from .normalizer import normalize_text
|
from .normalizer import normalize_text
|
||||||
from .phonemizer import phonemize
|
from .phonemizer import phonemize
|
||||||
from .vocabulary import tokenize
|
from .vocabulary import tokenize
|
||||||
|
from ...structures.schemas import NormalizationOptions
|
||||||
|
|
||||||
def process_text_chunk(
|
def process_text_chunk(
|
||||||
text: str, language: str = "a", skip_phonemize: bool = False
|
text: str, language: str = "a", skip_phonemize: bool = False
|
||||||
|
@ -87,8 +87,6 @@ def process_text(text: str, language: str = "a") -> List[int]:
|
||||||
|
|
||||||
def get_sentence_info(text: str) -> List[Tuple[str, List[int], int]]:
|
def get_sentence_info(text: str) -> List[Tuple[str, List[int], int]]:
|
||||||
"""Process all sentences and return info."""
|
"""Process all sentences and return info."""
|
||||||
if settings.advanced_text_normalization:
|
|
||||||
text=normalize_text(text)
|
|
||||||
sentences = re.split(r"([.!?;:])(?=\s|$)", text)
|
sentences = re.split(r"([.!?;:])(?=\s|$)", text)
|
||||||
results = []
|
results = []
|
||||||
for i in range(0, len(sentences), 2):
|
for i in range(0, len(sentences), 2):
|
||||||
|
@ -106,13 +104,19 @@ def get_sentence_info(text: str) -> List[Tuple[str, List[int], int]]:
|
||||||
|
|
||||||
|
|
||||||
async def smart_split(
|
async def smart_split(
|
||||||
text: str, max_tokens: int = settings.absolute_max_tokens
|
text: str,
|
||||||
|
max_tokens: int = settings.absolute_max_tokens,
|
||||||
|
normalization_options: NormalizationOptions = NormalizationOptions()
|
||||||
) -> AsyncGenerator[Tuple[str, List[int]], None]:
|
) -> AsyncGenerator[Tuple[str, List[int]], None]:
|
||||||
"""Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens."""
|
"""Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens."""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
logger.info(f"Starting smart split for {len(text)} chars")
|
logger.info(f"Starting smart split for {len(text)} chars")
|
||||||
|
|
||||||
|
# Normalize text
|
||||||
|
if settings.advanced_text_normalization and normalization_options.normalize:
|
||||||
|
text=normalize_text(text,normalization_options)
|
||||||
|
|
||||||
# Process all sentences
|
# Process all sentences
|
||||||
sentences = get_sentence_info(text)
|
sentences = get_sentence_info(text)
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ from ..inference.voice_manager import get_manager as get_voice_manager
|
||||||
from .audio import AudioNormalizer, AudioService
|
from .audio import AudioNormalizer, AudioService
|
||||||
from .text_processing import tokenize
|
from .text_processing import tokenize
|
||||||
from .text_processing.text_processor import process_text_chunk, smart_split
|
from .text_processing.text_processor import process_text_chunk, smart_split
|
||||||
|
from ..structures.schemas import NormalizationOptions
|
||||||
|
|
||||||
class TTSService:
|
class TTSService:
|
||||||
"""Text-to-speech service."""
|
"""Text-to-speech service."""
|
||||||
|
@ -238,6 +238,7 @@ class TTSService:
|
||||||
speed: float = 1.0,
|
speed: float = 1.0,
|
||||||
output_format: str = "wav",
|
output_format: str = "wav",
|
||||||
lang_code: Optional[str] = None,
|
lang_code: Optional[str] = None,
|
||||||
|
normalization_options: Optional[NormalizationOptions] = NormalizationOptions()
|
||||||
) -> AsyncGenerator[bytes, None]:
|
) -> AsyncGenerator[bytes, None]:
|
||||||
"""Generate and stream audio chunks."""
|
"""Generate and stream audio chunks."""
|
||||||
stream_normalizer = AudioNormalizer()
|
stream_normalizer = AudioNormalizer()
|
||||||
|
@ -258,7 +259,7 @@ class TTSService:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Process text in chunks with smart splitting
|
# Process text in chunks with smart splitting
|
||||||
async for chunk_text, tokens in smart_split(text):
|
async for chunk_text, tokens in smart_split(text,normalization_options=normalization_options):
|
||||||
try:
|
try:
|
||||||
# Process audio for chunk
|
# Process audio for chunk
|
||||||
async for result in self._process_chunk(
|
async for result in self._process_chunk(
|
||||||
|
|
|
@ -36,7 +36,13 @@ class CaptionedSpeechResponse(BaseModel):
|
||||||
audio: bytes = Field(..., description="The generated audio data")
|
audio: bytes = Field(..., description="The generated audio data")
|
||||||
words: List[WordTimestamp] = Field(..., description="Word-level timestamps")
|
words: List[WordTimestamp] = Field(..., description="Word-level timestamps")
|
||||||
|
|
||||||
|
class NormalizationOptions(BaseModel):
|
||||||
|
"""Options for the normalization system"""
|
||||||
|
normalize: bool = Field(default=True, description="Normalizes input text to make it easier for the model to say")
|
||||||
|
unit_normalization: bool = Field(default=False,description="Transforms units like 10KB to 10 kilobytes")
|
||||||
|
url_normalization: bool = Field(default=True, description="Changes urls so they can be properly pronouced by kokoro")
|
||||||
|
email_normalization: bool = Field(default=True, description="Changes emails so they can be properly pronouced by kokoro")
|
||||||
|
|
||||||
class OpenAISpeechRequest(BaseModel):
|
class OpenAISpeechRequest(BaseModel):
|
||||||
"""Request schema for OpenAI-compatible speech endpoint"""
|
"""Request schema for OpenAI-compatible speech endpoint"""
|
||||||
|
|
||||||
|
@ -71,6 +77,10 @@ class OpenAISpeechRequest(BaseModel):
|
||||||
default=None,
|
default=None,
|
||||||
description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
|
description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
|
||||||
)
|
)
|
||||||
|
normalization_options: Optional[NormalizationOptions] = Field(
|
||||||
|
default= NormalizationOptions(),
|
||||||
|
description= "Options for the normalization system"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CaptionedSpeechRequest(BaseModel):
|
class CaptionedSpeechRequest(BaseModel):
|
||||||
|
|
Loading…
Add table
Reference in a new issue