mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00
Refactor audio processing and text normalization: compare the absolute value of each sample against the amplitude threshold when locating the non-silent region, pass MP3 container options (disabling the Xing VBR header) to the streaming audio writer, and improve text normalization by stripping leading/trailing spaces and replacing special characters to prevent audio artifacts.
This commit is contained in:
parent
543cbecc1a
commit
ab8ab7d749
6 changed files with 80 additions and 15 deletions
|
@ -80,12 +80,12 @@ class AudioNormalizer:
|
||||||
non_silent_index_start, non_silent_index_end = None, None
|
non_silent_index_start, non_silent_index_end = None, None
|
||||||
|
|
||||||
for X in range(0, len(audio_data)):
|
for X in range(0, len(audio_data)):
|
||||||
if audio_data[X] > amplitude_threshold:
|
if abs(audio_data[X]) > amplitude_threshold:
|
||||||
non_silent_index_start = X
|
non_silent_index_start = X
|
||||||
break
|
break
|
||||||
|
|
||||||
for X in range(len(audio_data) - 1, -1, -1):
|
for X in range(len(audio_data) - 1, -1, -1):
|
||||||
if audio_data[X] > amplitude_threshold:
|
if abs(audio_data[X]) > amplitude_threshold:
|
||||||
non_silent_index_end = X
|
non_silent_index_end = X
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
|
@ -32,19 +32,29 @@ class StreamingAudioWriter:
|
||||||
if self.format in ["wav", "flac", "mp3", "pcm", "aac", "opus"]:
|
if self.format in ["wav", "flac", "mp3", "pcm", "aac", "opus"]:
|
||||||
if self.format != "pcm":
|
if self.format != "pcm":
|
||||||
self.output_buffer = BytesIO()
|
self.output_buffer = BytesIO()
|
||||||
|
container_options = {}
|
||||||
|
# Try disabling Xing VBR header for MP3 to fix iOS timeline reading issues
|
||||||
|
if self.format == 'mp3':
|
||||||
|
# Disable Xing VBR header
|
||||||
|
container_options = {'write_xing': '0'}
|
||||||
|
logger.debug("Disabling Xing VBR header for MP3 encoding.")
|
||||||
|
|
||||||
self.container = av.open(
|
self.container = av.open(
|
||||||
self.output_buffer,
|
self.output_buffer,
|
||||||
mode="w",
|
mode="w",
|
||||||
format=self.format if self.format != "aac" else "adts",
|
format=self.format if self.format != "aac" else "adts",
|
||||||
|
options=container_options # Pass options here
|
||||||
)
|
)
|
||||||
self.stream = self.container.add_stream(
|
self.stream = self.container.add_stream(
|
||||||
codec_map[self.format],
|
codec_map[self.format],
|
||||||
sample_rate=self.sample_rate,
|
rate=self.sample_rate, # Correct parameter name is 'rate'
|
||||||
layout="mono" if self.channels == 1 else "stereo",
|
layout="mono" if self.channels == 1 else "stereo",
|
||||||
)
|
)
|
||||||
self.stream.bit_rate = 128000
|
# Set bit_rate only for codecs where it's applicable and useful
|
||||||
|
if self.format in ['mp3', 'aac', 'opus']:
|
||||||
|
self.stream.bit_rate = 128000 # Example bitrate, can be configured
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported format: {format}")
|
raise ValueError(f"Unsupported format: {self.format}") # Use self.format here
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
if hasattr(self, "container"):
|
if hasattr(self, "container"):
|
||||||
|
@ -65,12 +75,18 @@ class StreamingAudioWriter:
|
||||||
|
|
||||||
if finalize:
|
if finalize:
|
||||||
if self.format != "pcm":
|
if self.format != "pcm":
|
||||||
|
# Flush stream encoder
|
||||||
packets = self.stream.encode(None)
|
packets = self.stream.encode(None)
|
||||||
for packet in packets:
|
for packet in packets:
|
||||||
self.container.mux(packet)
|
self.container.mux(packet)
|
||||||
|
|
||||||
|
# Closing the container handles writing the trailer and finalizing the file.
|
||||||
|
# No explicit flush method is available or needed here.
|
||||||
|
logger.debug("Muxed final packets.")
|
||||||
|
|
||||||
|
# Get the final bytes from the buffer *before* closing it
|
||||||
data = self.output_buffer.getvalue()
|
data = self.output_buffer.getvalue()
|
||||||
self.close()
|
self.close() # Close container and buffer
|
||||||
return data
|
return data
|
||||||
|
|
||||||
if audio_data is None or len(audio_data) == 0:
|
if audio_data is None or len(audio_data) == 0:
|
||||||
|
|
|
@ -391,6 +391,7 @@ def handle_time(t: re.Match[str]) -> str:
|
||||||
|
|
||||||
def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
|
def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
|
||||||
"""Normalize text for TTS processing"""
|
"""Normalize text for TTS processing"""
|
||||||
|
|
||||||
# Handle email addresses first if enabled
|
# Handle email addresses first if enabled
|
||||||
if normalization_options.email_normalization:
|
if normalization_options.email_normalization:
|
||||||
text = EMAIL_PATTERN.sub(handle_email, text)
|
text = EMAIL_PATTERN.sub(handle_email, text)
|
||||||
|
@ -415,7 +416,7 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st
|
||||||
text,
|
text,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Replace quotes and brackets
|
# Replace quotes and brackets (additional cleanup)
|
||||||
text = text.replace(chr(8216), "'").replace(chr(8217), "'")
|
text = text.replace(chr(8216), "'").replace(chr(8217), "'")
|
||||||
text = text.replace("«", chr(8220)).replace("»", chr(8221))
|
text = text.replace("«", chr(8220)).replace("»", chr(8221))
|
||||||
text = text.replace(chr(8220), '"').replace(chr(8221), '"')
|
text = text.replace(chr(8220), '"').replace(chr(8221), '"')
|
||||||
|
@ -435,6 +436,27 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st
|
||||||
text = re.sub(r" +", " ", text)
|
text = re.sub(r" +", " ", text)
|
||||||
text = re.sub(r"(?<=\n) +(?=\n)", "", text)
|
text = re.sub(r"(?<=\n) +(?=\n)", "", text)
|
||||||
|
|
||||||
|
# Handle special characters that might cause audio artifacts first
|
||||||
|
# Replace newlines with spaces (or pauses if needed)
|
||||||
|
text = text.replace('\n', ' ')
|
||||||
|
text = text.replace('\r', ' ')
|
||||||
|
|
||||||
|
# Handle other problematic symbols
|
||||||
|
text = text.replace('~', '') # Remove tilde
|
||||||
|
text = text.replace('@', ' at ') # At symbol
|
||||||
|
text = text.replace('#', ' number ') # Hash/pound
|
||||||
|
text = text.replace('$', ' dollar ') # Dollar sign (if not handled by money pattern)
|
||||||
|
text = text.replace('%', ' percent ') # Percent sign
|
||||||
|
text = text.replace('^', '') # Caret
|
||||||
|
text = text.replace('&', ' and ') # Ampersand
|
||||||
|
text = text.replace('*', '') # Asterisk
|
||||||
|
text = text.replace('_', ' ') # Underscore to space
|
||||||
|
text = text.replace('|', ' ') # Pipe to space
|
||||||
|
text = text.replace('\\', ' ') # Backslash to space
|
||||||
|
text = text.replace('/', ' slash ') # Forward slash to the word "slash" (NOTE: no URL check on this line; URLs are presumably handled earlier - confirm)
|
||||||
|
text = text.replace('=', ' equals ') # Equals sign
|
||||||
|
text = text.replace('+', ' plus ') # Plus sign
|
||||||
|
|
||||||
# Handle titles and abbreviations
|
# Handle titles and abbreviations
|
||||||
text = re.sub(r"\bD[Rr]\.(?= [A-Z])", "Doctor", text)
|
text = re.sub(r"\bD[Rr]\.(?= [A-Z])", "Doctor", text)
|
||||||
text = re.sub(r"\b(?:Mr\.|MR\.(?= [A-Z]))", "Mister", text)
|
text = re.sub(r"\b(?:Mr\.|MR\.(?= [A-Z]))", "Mister", text)
|
||||||
|
|
|
@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
|
||||||
import phonemizer
|
import phonemizer
|
||||||
|
|
||||||
from .normalizer import normalize_text
|
from .normalizer import normalize_text
|
||||||
|
from ...structures.schemas import NormalizationOptions
|
||||||
|
|
||||||
phonemizers = {}
|
phonemizers = {}
|
||||||
|
|
||||||
|
@ -95,8 +96,20 @@ def phonemize(text: str, language: str = "a", normalize: bool = True) -> str:
|
||||||
Phonemized text
|
Phonemized text
|
||||||
"""
|
"""
|
||||||
global phonemizers
|
global phonemizers
|
||||||
|
|
||||||
|
# Strip input text first to remove problematic leading/trailing spaces
|
||||||
|
text = text.strip()
|
||||||
|
|
||||||
if normalize:
|
if normalize:
|
||||||
text = normalize_text(text)
|
# Create default normalization options and normalize text
|
||||||
|
normalization_options = NormalizationOptions()
|
||||||
|
text = normalize_text(text, normalization_options)
|
||||||
|
# Strip again after normalization
|
||||||
|
text = text.strip()
|
||||||
|
|
||||||
if language not in phonemizers:
|
if language not in phonemizers:
|
||||||
phonemizers[language] = create_phonemizer(language)
|
phonemizers[language] = create_phonemizer(language)
|
||||||
return phonemizers[language].phonemize(text)
|
|
||||||
|
result = phonemizers[language].phonemize(text)
|
||||||
|
# Final strip to ensure no leading/trailing spaces in phonemes
|
||||||
|
return result.strip()
|
||||||
|
|
|
@ -30,6 +30,12 @@ def process_text_chunk(
|
||||||
List of token IDs
|
List of token IDs
|
||||||
"""
|
"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Strip input text to remove any leading/trailing spaces that could cause artifacts
|
||||||
|
text = text.strip()
|
||||||
|
|
||||||
|
if not text:
|
||||||
|
return []
|
||||||
|
|
||||||
if skip_phonemize:
|
if skip_phonemize:
|
||||||
# Input is already phonemes, just tokenize
|
# Input is already phonemes, just tokenize
|
||||||
|
@ -43,6 +49,8 @@ def process_text_chunk(
|
||||||
|
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
phonemes = phonemize(text, language, normalize=False) # Already normalized
|
phonemes = phonemize(text, language, normalize=False) # Already normalized
|
||||||
|
# Strip phonemes result to ensure no extra spaces
|
||||||
|
phonemes = phonemes.strip()
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
|
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
|
@ -114,6 +122,10 @@ def get_sentence_info(
|
||||||
if not sentence:
|
if not sentence:
|
||||||
continue
|
continue
|
||||||
full = sentence + punct
|
full = sentence + punct
|
||||||
|
# Strip the full sentence to remove any leading/trailing spaces before processing
|
||||||
|
full = full.strip()
|
||||||
|
if not full: # Skip if empty after stripping
|
||||||
|
continue
|
||||||
tokens = process_text_chunk(full)
|
tokens = process_text_chunk(full)
|
||||||
results.append((full, tokens, len(tokens)))
|
results.append((full, tokens, len(tokens)))
|
||||||
return results
|
return results
|
||||||
|
@ -162,7 +174,7 @@ async def smart_split(
|
||||||
if count > max_tokens:
|
if count > max_tokens:
|
||||||
# Yield current chunk if any
|
# Yield current chunk if any
|
||||||
if current_chunk:
|
if current_chunk:
|
||||||
chunk_text = " ".join(current_chunk)
|
chunk_text = " ".join(current_chunk).strip() # Strip after joining
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
|
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
|
||||||
|
@ -201,7 +213,7 @@ async def smart_split(
|
||||||
else:
|
else:
|
||||||
# Yield clause chunk if we have one
|
# Yield clause chunk if we have one
|
||||||
if clause_chunk:
|
if clause_chunk:
|
||||||
chunk_text = " ".join(clause_chunk)
|
chunk_text = " ".join(clause_chunk).strip() # Strip after joining
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)"
|
f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)"
|
||||||
|
@ -213,7 +225,7 @@ async def smart_split(
|
||||||
|
|
||||||
# Don't forget last clause chunk
|
# Don't forget last clause chunk
|
||||||
if clause_chunk:
|
if clause_chunk:
|
||||||
chunk_text = " ".join(clause_chunk)
|
chunk_text = " ".join(clause_chunk).strip() # Strip after joining
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)"
|
f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)"
|
||||||
|
@ -227,7 +239,7 @@ async def smart_split(
|
||||||
):
|
):
|
||||||
# If we have a good sized chunk and adding next sentence exceeds target,
|
# If we have a good sized chunk and adding next sentence exceeds target,
|
||||||
# yield current chunk and start new one
|
# yield current chunk and start new one
|
||||||
chunk_text = " ".join(current_chunk)
|
chunk_text = " ".join(current_chunk).strip() # Strip after joining
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
|
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
|
||||||
|
@ -252,7 +264,7 @@ async def smart_split(
|
||||||
else:
|
else:
|
||||||
# Yield current chunk and start new one
|
# Yield current chunk and start new one
|
||||||
if current_chunk:
|
if current_chunk:
|
||||||
chunk_text = " ".join(current_chunk)
|
chunk_text = " ".join(current_chunk).strip() # Strip after joining
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
|
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
|
||||||
|
@ -264,7 +276,7 @@ async def smart_split(
|
||||||
|
|
||||||
# Don't forget the last chunk
|
# Don't forget the last chunk
|
||||||
if current_chunk:
|
if current_chunk:
|
||||||
chunk_text = " ".join(current_chunk)
|
chunk_text = " ".join(current_chunk).strip() # Strip after joining
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
|
f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
|
||||||
|
|
|
@ -23,6 +23,8 @@ def tokenize(phonemes: str) -> list[int]:
|
||||||
Returns:
|
Returns:
|
||||||
List of token IDs
|
List of token IDs
|
||||||
"""
|
"""
|
||||||
|
# Strip phonemes to remove leading/trailing spaces that could cause artifacts
|
||||||
|
phonemes = phonemes.strip()
|
||||||
return [i for i in map(VOCAB.get, phonemes) if i is not None]
|
return [i for i in map(VOCAB.get, phonemes) if i is not None]
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue