Refactor audio processing and text normalization: Update audio normalization to use absolute amplitude threshold, enhance streaming audio writer with MP3 container options, and improve text normalization by stripping spaces and handling special characters to prevent audio artifacts.

This commit is contained in:
Lukin 2025-05-30 22:52:58 +08:00
parent 543cbecc1a
commit ab8ab7d749
6 changed files with 80 additions and 15 deletions

View file

@ -80,12 +80,12 @@ class AudioNormalizer:
non_silent_index_start, non_silent_index_end = None, None
for X in range(0, len(audio_data)):
if audio_data[X] > amplitude_threshold:
if abs(audio_data[X]) > amplitude_threshold:
non_silent_index_start = X
break
for X in range(len(audio_data) - 1, -1, -1):
if audio_data[X] > amplitude_threshold:
if abs(audio_data[X]) > amplitude_threshold:
non_silent_index_end = X
break

View file

@ -32,19 +32,29 @@ class StreamingAudioWriter:
if self.format in ["wav", "flac", "mp3", "pcm", "aac", "opus"]:
if self.format != "pcm":
self.output_buffer = BytesIO()
container_options = {}
# Try disabling Xing VBR header for MP3 to fix iOS timeline reading issues
if self.format == 'mp3':
# Disable Xing VBR header
container_options = {'write_xing': '0'}
logger.debug("Disabling Xing VBR header for MP3 encoding.")
self.container = av.open(
self.output_buffer,
mode="w",
format=self.format if self.format != "aac" else "adts",
options=container_options # Pass options here
)
self.stream = self.container.add_stream(
codec_map[self.format],
sample_rate=self.sample_rate,
rate=self.sample_rate, # Correct parameter name is 'rate'
layout="mono" if self.channels == 1 else "stereo",
)
self.stream.bit_rate = 128000
# Set bit_rate only for codecs where it's applicable and useful
if self.format in ['mp3', 'aac', 'opus']:
self.stream.bit_rate = 128000 # Example bitrate, can be configured
else:
raise ValueError(f"Unsupported format: {format}")
raise ValueError(f"Unsupported format: {self.format}") # Use self.format here
def close(self):
if hasattr(self, "container"):
@ -65,12 +75,18 @@ class StreamingAudioWriter:
if finalize:
if self.format != "pcm":
# Flush stream encoder
packets = self.stream.encode(None)
for packet in packets:
self.container.mux(packet)
# Closing the container handles writing the trailer and finalizing the file.
# No explicit flush method is available or needed here.
logger.debug("Muxed final packets.")
# Get the final bytes from the buffer *before* closing it
data = self.output_buffer.getvalue()
self.close()
self.close() # Close container and buffer
return data
if audio_data is None or len(audio_data) == 0:

View file

@ -391,6 +391,7 @@ def handle_time(t: re.Match[str]) -> str:
def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
"""Normalize text for TTS processing"""
# Handle email addresses first if enabled
if normalization_options.email_normalization:
text = EMAIL_PATTERN.sub(handle_email, text)
@ -415,7 +416,7 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st
text,
)
# Replace quotes and brackets
# Replace quotes and brackets (additional cleanup)
text = text.replace(chr(8216), "'").replace(chr(8217), "'")
text = text.replace("«", chr(8220)).replace("»", chr(8221))
text = text.replace(chr(8220), '"').replace(chr(8221), '"')
@ -435,6 +436,27 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st
text = re.sub(r" +", " ", text)
text = re.sub(r"(?<=\n) +(?=\n)", "", text)
# Handle special characters that might cause audio artifacts first
# Replace newlines with spaces (or pauses if needed)
text = text.replace('\n', ' ')
text = text.replace('\r', ' ')
# Handle other problematic symbols
text = text.replace('~', '') # Remove tilde
text = text.replace('@', ' at ') # At symbol
text = text.replace('#', ' number ') # Hash/pound
text = text.replace('$', ' dollar ') # Dollar sign (if not handled by money pattern)
text = text.replace('%', ' percent ') # Percent sign
text = text.replace('^', '') # Caret
text = text.replace('&', ' and ') # Ampersand
text = text.replace('*', '') # Asterisk
text = text.replace('_', ' ') # Underscore to space
text = text.replace('|', ' ') # Pipe to space
text = text.replace('\\', ' ') # Backslash to space
text = text.replace('/', ' slash ') # Forward slash to space (unless in URLs)
text = text.replace('=', ' equals ') # Equals sign
text = text.replace('+', ' plus ') # Plus sign
# Handle titles and abbreviations
text = re.sub(r"\bD[Rr]\.(?= [A-Z])", "Doctor", text)
text = re.sub(r"\b(?:Mr\.|MR\.(?= [A-Z]))", "Mister", text)

View file

@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
import phonemizer
from .normalizer import normalize_text
from ...structures.schemas import NormalizationOptions
phonemizers = {}
@ -95,8 +96,20 @@ def phonemize(text: str, language: str = "a", normalize: bool = True) -> str:
Phonemized text
"""
global phonemizers
# Strip input text first to remove problematic leading/trailing spaces
text = text.strip()
if normalize:
text = normalize_text(text)
# Create default normalization options and normalize text
normalization_options = NormalizationOptions()
text = normalize_text(text, normalization_options)
# Strip again after normalization
text = text.strip()
if language not in phonemizers:
phonemizers[language] = create_phonemizer(language)
return phonemizers[language].phonemize(text)
result = phonemizers[language].phonemize(text)
# Final strip to ensure no leading/trailing spaces in phonemes
return result.strip()

View file

@ -30,6 +30,12 @@ def process_text_chunk(
List of token IDs
"""
start_time = time.time()
# Strip input text to remove any leading/trailing spaces that could cause artifacts
text = text.strip()
if not text:
return []
if skip_phonemize:
# Input is already phonemes, just tokenize
@ -43,6 +49,8 @@ def process_text_chunk(
t0 = time.time()
phonemes = phonemize(text, language, normalize=False) # Already normalized
# Strip phonemes result to ensure no extra spaces
phonemes = phonemes.strip()
t1 = time.time()
t0 = time.time()
@ -114,6 +122,10 @@ def get_sentence_info(
if not sentence:
continue
full = sentence + punct
# Strip the full sentence to remove any leading/trailing spaces before processing
full = full.strip()
if not full: # Skip if empty after stripping
continue
tokens = process_text_chunk(full)
results.append((full, tokens, len(tokens)))
return results
@ -162,7 +174,7 @@ async def smart_split(
if count > max_tokens:
# Yield current chunk if any
if current_chunk:
chunk_text = " ".join(current_chunk)
chunk_text = " ".join(current_chunk).strip() # Strip after joining
chunk_count += 1
logger.debug(
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
@ -201,7 +213,7 @@ async def smart_split(
else:
# Yield clause chunk if we have one
if clause_chunk:
chunk_text = " ".join(clause_chunk)
chunk_text = " ".join(clause_chunk).strip() # Strip after joining
chunk_count += 1
logger.debug(
f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)"
@ -213,7 +225,7 @@ async def smart_split(
# Don't forget last clause chunk
if clause_chunk:
chunk_text = " ".join(clause_chunk)
chunk_text = " ".join(clause_chunk).strip() # Strip after joining
chunk_count += 1
logger.debug(
f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)"
@ -227,7 +239,7 @@ async def smart_split(
):
# If we have a good sized chunk and adding next sentence exceeds target,
# yield current chunk and start new one
chunk_text = " ".join(current_chunk)
chunk_text = " ".join(current_chunk).strip() # Strip after joining
chunk_count += 1
logger.info(
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
@ -252,7 +264,7 @@ async def smart_split(
else:
# Yield current chunk and start new one
if current_chunk:
chunk_text = " ".join(current_chunk)
chunk_text = " ".join(current_chunk).strip() # Strip after joining
chunk_count += 1
logger.info(
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
@ -264,7 +276,7 @@ async def smart_split(
# Don't forget the last chunk
if current_chunk:
chunk_text = " ".join(current_chunk)
chunk_text = " ".join(current_chunk).strip() # Strip after joining
chunk_count += 1
logger.info(
f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"

View file

@ -23,6 +23,8 @@ def tokenize(phonemes: str) -> list[int]:
Returns:
List of token IDs
"""
# Strip phonemes to remove leading/trailing spaces that could cause artifacts
phonemes = phonemes.strip()
return [i for i in map(VOCAB.get, phonemes) if i is not None]