Refactor audio processing and cleanup: remove unused chunker, enhance StreamingAudioWriter for better MP3 handling, and improve text processing compatibility.

remsky 2025-01-27 20:23:42 -07:00
parent 8a60a2b90c
commit 75889e157d
6 changed files with 324 additions and 192 deletions


@@ -30,9 +30,6 @@ class AudioNormalizer:
         Returns:
             Normalized and trimmed audio data
         """
-        if len(audio_data) == 0:
-            raise ValueError("Audio data cannot be empty")
-
         # Convert to float32 for processing
         audio_float = audio_data.astype(np.float32)
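For orientation, here is a minimal, hypothetical sketch of the conversion step this hunk touches; the scaling into int16 range is an assumption about AudioNormalizer's behavior, not something this diff shows:

    import numpy as np

    def normalize_audio(audio_data: np.ndarray) -> np.ndarray:
        # Sketch only: scale the peak to 1.0, then map into the int16 range.
        audio_float = audio_data.astype(np.float32)
        peak = float(np.max(np.abs(audio_float))) if audio_float.size else 0.0
        if peak > 0:
            audio_float /= peak
        return (audio_float * 32767).astype(np.int16)

With the empty-input guard removed, zero-length arrays now flow through; downstream code (see write_chunk below) treats empty chunks as no-ops instead.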
@@ -102,17 +99,14 @@ class AudioService:
             )
             writer = AudioService._writers[writer_key]

-            # Write the current chunk
-            chunk_data = writer.write_chunk(normalized_audio)
-
-            # Handle last chunk and cleanup
-            if is_last_chunk:
-                final_data = writer.close()
-                if final_data:
-                    chunk_data += final_data
-                del AudioService._writers[writer_key]
-
-            return chunk_data
+            # Write chunk or finalize
+            if is_last_chunk:
+                chunk_data = writer.write_chunk(finalize=True)
+                del AudioService._writers[writer_key]
+            else:
+                chunk_data = writer.write_chunk(normalized_audio)
+
+            return chunk_data if chunk_data else b''

         except Exception as e:
             logger.error(f"Error converting audio stream to {output_format}: {str(e)}")


@@ -33,7 +33,9 @@ class StreamingAudioWriter:
         elif self.format == "mp3":
             # For MP3, we'll use pydub's incremental writer
             self.buffer = BytesIO()
-            self.encoder = AudioSegment.from_mono_audiosegments()
+            self.segments = []  # Store segments until we have enough data
+            # Initialize an empty AudioSegment as our encoder
+            self.encoder = AudioSegment.silent(duration=0, frame_rate=self.sample_rate)

     def _write_wav_header(self) -> bytes:
         """Write WAV header with correct streaming format"""
@@ -53,12 +55,45 @@
         header.write(struct.pack('<L', 0))  # Placeholder for data size
         return header.getvalue()

-    def write_chunk(self, audio_data: np.ndarray) -> bytes:
-        """Write a chunk of audio data and return bytes in the target format"""
+    def write_chunk(self, audio_data: Optional[np.ndarray] = None, finalize: bool = False) -> bytes:
+        """Write a chunk of audio data and return bytes in the target format.
+
+        Args:
+            audio_data: Audio data to write, or None if finalizing
+            finalize: Whether this is the final write to close the stream
+        """
         buffer = BytesIO()

+        if finalize:
+            if self.format == "wav":
+                # Write final WAV header with correct sizes
+                buffer.write(b'RIFF')
+                buffer.write(struct.pack('<L', self.bytes_written + 36))
+                buffer.write(b'WAVE')
+                buffer.write(b'fmt ')
+                buffer.write(struct.pack('<L', 16))
+                buffer.write(struct.pack('<H', 1))
+                buffer.write(struct.pack('<H', self.channels))
+                buffer.write(struct.pack('<L', self.sample_rate))
+                buffer.write(struct.pack('<L', self.sample_rate * self.channels * 2))
+                buffer.write(struct.pack('<H', self.channels * 2))
+                buffer.write(struct.pack('<H', 16))
+                buffer.write(b'data')
+                buffer.write(struct.pack('<L', self.bytes_written))
+            elif self.format == "ogg":
+                self.writer.close()
+            elif self.format == "mp3":
+                # Final export of any remaining audio
+                if hasattr(self, 'encoder') and len(self.encoder) > 0:
+                    self.encoder.export(buffer, format="mp3", bitrate="192k", parameters=["-q:a", "2"])
+                self.encoder = None
+            return buffer.getvalue()
+
+        if audio_data is None or len(audio_data) == 0:
+            return b''
+
         if self.format == "wav":
-            # For WAV, we write raw PCM after the first chunk
+            # For WAV, write raw PCM after the first chunk
             if self.bytes_written == 0:
                 buffer.write(self._write_wav_header())
             buffer.write(audio_data.tobytes())
@@ -83,8 +118,20 @@
                 sample_width=audio_data.dtype.itemsize,
                 channels=self.channels
             )
-            self.encoder += segment
-            self.encoder.export(buffer, format="mp3")
+
+            # Add segment to encoder
+            self.encoder = self.encoder + segment
+
+            # Export current state to buffer
+            self.encoder.export(buffer, format="mp3", bitrate="192k", parameters=["-q:a", "2"])
+
+            # Get the encoded data
+            encoded_data = buffer.getvalue()
+
+            # Reset encoder to prevent memory growth
+            self.encoder = AudioSegment.silent(duration=0, frame_rate=self.sample_rate)
+
+            return encoded_data

         return buffer.getvalue()
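The finalize branch writes the standard 44-byte PCM WAV header with the real sizes: the RIFF size field is bytes_written + 36 (everything after the 8-byte RIFF preamble) and the data size field is bytes_written. As a self-contained check of that arithmetic, assuming 16-bit PCM:

    import struct
    from io import BytesIO

    def final_wav_header(data_size: int, sample_rate: int = 24000, channels: int = 1) -> bytes:
        buf = BytesIO()
        buf.write(b'RIFF')
        buf.write(struct.pack('<L', data_size + 36))              # total file size minus 8
        buf.write(b'WAVE')
        buf.write(b'fmt ')
        buf.write(struct.pack('<L', 16))                          # fmt chunk size
        buf.write(struct.pack('<H', 1))                           # PCM format tag
        buf.write(struct.pack('<H', channels))
        buf.write(struct.pack('<L', sample_rate))
        buf.write(struct.pack('<L', sample_rate * channels * 2))  # byte rate (16-bit)
        buf.write(struct.pack('<H', channels * 2))                # block align
        buf.write(struct.pack('<H', 16))                          # bits per sample
        buf.write(b'data')
        buf.write(struct.pack('<L', data_size))
        return buf.getvalue()

    assert len(final_wav_header(0)) == 44

Since the corrected header is emitted at the end of the byte stream, it only helps consumers that buffer and reassemble the file; purely sequential players will already have read the placeholder header.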


@@ -1,28 +1,19 @@
 """Text processing pipeline."""

-from .chunker import split_text
 from .normalizer import normalize_text
 from .phonemizer import phonemize
 from .vocabulary import tokenize
+from .text_processor import process_text_chunk, smart_split


-def process_text(text: str, language: str = "a") -> list[int]:
-    """Process text through the full pipeline.
-
-    Args:
-        text: Input text
-        language: Language code ('a' for US English, 'b' for British English)
-
-    Returns:
-        List of token IDs
-
-    Note:
-        The pipeline:
-        1. Converts text to phonemes using phonemizer
-        2. Converts phonemes to token IDs using vocabulary
-    """
-    # Convert text to phonemes
-    phonemes = phonemize(text, language=language)
-
-    # Convert phonemes to token IDs
-    return tokenize(phonemes)
+def process_text(text: str) -> list[int]:
+    """Process text into token IDs (for backward compatibility)."""
+    return process_text_chunk(text)
+
+
+__all__ = [
+    'normalize_text',
+    'phonemize',
+    'tokenize',
+    'process_text',
+    'process_text_chunk',
+    'smart_split'
+]
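A hedged usage sketch of the package facade after this change; the import path is assumed for illustration:

    # Assumed path; adjust to wherever text_processing lives in the repo.
    from api.src.services.text_processing import process_text, process_text_chunk

    tokens = process_text("Hello world!")      # backward-compatible one-shot call
    same = process_text_chunk("Hello world!")  # the function it now delegates to

One behavioral note: the old process_text accepted a language argument, while the re-exported wrapper shown here drops it, so callers passing language='b' will now raise a TypeError; text_processor.process_text (added below) still accepts it.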


@@ -1,89 +0,0 @@
-from __future__ import annotations
-
-import re
-from typing import Callable
-
-# Prioritize sentence boundaries for TTS
-_NON_WHITESPACE_SEMANTIC_SPLITTERS = (
-    '.', '!', '?',                                           # Primary - sentence boundaries
-    ';', ':',                                                # Secondary - major clause boundaries
-    ',',                                                     # Tertiary - minor clause boundaries
-    '(', ')', '[', ']', '“', '”', '‘', '’', "'", '"', '`',   # Other punctuation
-    '—', '…',                                                # Dashes and ellipsis
-    '/', '\\', '–', '&', '-',                                # Word joiners
-)
-"""Semantic splitters ordered by priority for TTS chunking"""
-
-
-def _split_text(text: str) -> tuple[str, bool, list[str]]:
-    """Split text using the most semantically meaningful splitter possible."""
-    splitter_is_whitespace = True
-
-    # Try splitting at, in order:
-    # - Newlines (natural paragraph breaks)
-    # - Spaces (if no other splits possible)
-    # - Semantic splitters (prioritizing sentence boundaries)
-    if '\n' in text or '\r' in text:
-        splitter = max(re.findall(r'[\r\n]+', text))
-    elif re.search(r'\s', text):
-        splitter = max(re.findall(r'\s+', text))
-    else:
-        # Find first semantic splitter present
-        for splitter in _NON_WHITESPACE_SEMANTIC_SPLITTERS:
-            if splitter in text:
-                splitter_is_whitespace = False
-                break
-        else:
-            return '', splitter_is_whitespace, list(text)
-
-    return splitter, splitter_is_whitespace, text.split(splitter)
-
-
-class Chunker:
-    def __init__(self, chunk_size: int, token_counter: Callable[[str], int]) -> None:
-        self.chunk_size = chunk_size
-        self.token_counter = token_counter
-
-    def __call__(self, text: str) -> list[str]:
-        """Split text into chunks based on semantic boundaries."""
-        if not isinstance(text, str):
-            text = str(text) if text is not None else ""
-
-        text = text.strip()
-        if not text:
-            return []
-
-        # Split the text
-        splitter, _, splits = _split_text(text)
-
-        chunks = []
-        current_chunk = []
-        current_len = 0
-
-        for split in splits:
-            split = split.strip()
-            if not split:
-                continue
-
-            # Check if adding this split would exceed chunk size
-            split_len = self.token_counter(split)
-            if current_len + split_len <= self.chunk_size:
-                current_chunk.append(split)
-                current_len += split_len
-            else:
-                # Save current chunk if it exists
-                if current_chunk:
-                    chunks.append(splitter.join(current_chunk))
-                # Start new chunk with current split
-                current_chunk = [split]
-                current_len = split_len
-
-        # Add final chunk if it exists
-        if current_chunk:
-            chunks.append(splitter.join(current_chunk))
-
-        return chunks
-
-
-def chunkerify(token_counter: Callable[[str], int], chunk_size: int) -> Chunker:
-    """Create a chunker with the specified token counter and chunk size."""
-    return Chunker(chunk_size=chunk_size, token_counter=token_counter)
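For the record, the deleted module was driven roughly like this; the word-based token counter is a stand-in for the project's real tokenizer:

    # Assumes the deleted chunker.py definitions above are in scope.
    def count_tokens(text: str) -> int:
        return len(text.split())  # stand-in: words as a proxy for tokens

    chunker = chunkerify(count_tokens, chunk_size=50)
    for chunk in chunker("One sentence. Another! And a third, rather longer, sentence?"):
        print(chunk)

Its replacement, smart_split in the new text_processor module below, folds tokenization into the splitting loop instead of taking a token counter as a parameter.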


@@ -0,0 +1,177 @@
+"""Unified text processing for TTS with smart chunking."""
+
+import re
+import time
+from typing import AsyncGenerator, List, Tuple
+
+from loguru import logger
+
+from .phonemizer import phonemize
+from .normalizer import normalize_text
+from .vocabulary import tokenize
+
+
+def process_text_chunk(text: str, language: str = "a") -> List[int]:
+    """Process a chunk of text through normalization, phonemization, and tokenization.
+
+    Args:
+        text: Text chunk to process
+        language: Language code for phonemization
+
+    Returns:
+        List of token IDs
+    """
+    start_time = time.time()
+
+    # Normalize
+    t0 = time.time()
+    normalized = normalize_text(text)
+    t1 = time.time()
+    logger.debug(f"Normalization took {(t1-t0)*1000:.2f}ms for {len(text)} chars")
+
+    # Phonemize
+    t0 = time.time()
+    phonemes = phonemize(normalized, language, normalize=False)  # Already normalized
+    t1 = time.time()
+    logger.debug(f"Phonemization took {(t1-t0)*1000:.2f}ms for {len(normalized)} chars")
+
+    # Convert to token IDs
+    t0 = time.time()
+    tokens = tokenize(phonemes)
+    t1 = time.time()
+    logger.debug(f"Tokenization took {(t1-t0)*1000:.2f}ms for {len(phonemes)} chars")
+
+    total_time = time.time() - start_time
+    logger.debug(f"Total processing took {total_time*1000:.2f}ms for chunk: '{text[:50]}...'")
+
+    return tokens
+
+
+def process_text(text: str, language: str = "a") -> List[int]:
+    """Process text into token IDs.
+
+    Args:
+        text: Text to process
+        language: Language code for phonemization
+
+    Returns:
+        List of token IDs
+    """
+    if not isinstance(text, str):
+        text = str(text) if text is not None else ""
+
+    text = text.strip()
+    if not text:
+        return []
+
+    return process_text_chunk(text, language)
+
+
+async def smart_split(text: str, max_tokens: int = 500) -> AsyncGenerator[Tuple[str, List[int]], None]:
+    """Split text into semantically meaningful chunks while respecting token limits.
+
+    Args:
+        text: Input text to split
+        max_tokens: Maximum tokens per chunk
+
+    Yields:
+        Tuples of (text chunk, token IDs) where token count is <= max_tokens
+    """
+    start_time = time.time()
+    chunk_count = 0
+    total_chars = len(text)
+    logger.info(f"Starting text split for {total_chars} characters with {max_tokens} max tokens")
+
+    # Split on major punctuation first
+    sentences = re.split(r'([.!?;:])', text)
+
+    current_chunk = []
+    current_token_count = 0
+
+    for i in range(0, len(sentences), 2):
+        # Get sentence and its punctuation (if any)
+        sentence = sentences[i].strip()
+        punct = sentences[i + 1] if i + 1 < len(sentences) else ""
+
+        if not sentence:
+            continue
+
+        # Process sentence to get token count
+        sentence_with_punct = sentence + punct
+        tokens = process_text_chunk(sentence_with_punct)
+        token_count = len(tokens)
+        logger.debug(f"Sentence '{sentence_with_punct[:50]}...' has {token_count} tokens")
+
+        # If this single sentence is too long, split on commas
+        if token_count > max_tokens:
+            logger.debug("Sentence exceeds token limit, splitting on commas")
+            clause_splits = re.split(r'([,])', sentence_with_punct)
+            for j in range(0, len(clause_splits), 2):
+                clause = clause_splits[j].strip()
+                comma = clause_splits[j + 1] if j + 1 < len(clause_splits) else ""
+
+                if not clause:
+                    continue
+
+                clause_with_punct = clause + comma
+                clause_tokens = process_text_chunk(clause_with_punct)
+
+                # If still too long, do a hard split on words
+                if len(clause_tokens) > max_tokens:
+                    logger.debug("Clause exceeds token limit, splitting on words")
+                    words = clause_with_punct.split()
+                    temp_chunk = []
+                    temp_tokens = []
+
+                    for word in words:
+                        word_tokens = process_text_chunk(word)
+                        if len(temp_tokens) + len(word_tokens) > max_tokens:
+                            if temp_chunk:  # Don't yield empty chunks
+                                chunk_text = " ".join(temp_chunk)
+                                chunk_count += 1
+                                logger.info(f"Yielding word-split chunk {chunk_count}: '{chunk_text[:50]}...' ({len(temp_tokens)} tokens)")
+                                yield chunk_text, temp_tokens
+                            temp_chunk = [word]
+                            temp_tokens = word_tokens
+                        else:
+                            temp_chunk.append(word)
+                            temp_tokens.extend(word_tokens)
+
+                    if temp_chunk:  # Don't forget the last chunk
+                        chunk_text = " ".join(temp_chunk)
+                        chunk_count += 1
+                        logger.info(f"Yielding final word-split chunk {chunk_count}: '{chunk_text[:50]}...' ({len(temp_tokens)} tokens)")
+                        yield chunk_text, temp_tokens
+                else:
+                    # Check if adding this clause would exceed the limit
+                    if current_token_count + len(clause_tokens) > max_tokens:
+                        if current_chunk:  # Don't yield empty chunks
+                            chunk_text = " ".join(current_chunk)
+                            chunk_count += 1
+                            logger.info(f"Yielding clause-split chunk {chunk_count}: '{chunk_text[:50]}...' ({current_token_count} tokens)")
+                            yield chunk_text, process_text_chunk(chunk_text)
+                        current_chunk = [clause_with_punct]
+                        current_token_count = len(clause_tokens)
+                    else:
+                        current_chunk.append(clause_with_punct)
+                        current_token_count += len(clause_tokens)
+        else:
+            # Check if adding this sentence would exceed the limit
+            if current_token_count + token_count > max_tokens:
+                if current_chunk:  # Don't yield empty chunks
+                    chunk_text = " ".join(current_chunk)
+                    chunk_count += 1
+                    logger.info(f"Yielding sentence-split chunk {chunk_count}: '{chunk_text[:50]}...' ({current_token_count} tokens)")
+                    yield chunk_text, process_text_chunk(chunk_text)
+                current_chunk = [sentence_with_punct]
+                current_token_count = token_count
+            else:
+                current_chunk.append(sentence_with_punct)
+                current_token_count += token_count
+
+    # Don't forget the last chunk
+    if current_chunk:
+        chunk_text = " ".join(current_chunk)
+        chunk_count += 1
+        logger.info(f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}...' ({current_token_count} tokens)")
+        yield chunk_text, process_text_chunk(chunk_text)
+
+    total_time = time.time() - start_time
+    logger.info(f"Text splitting completed in {total_time*1000:.2f}ms, produced {chunk_count} chunks")


@@ -1,11 +1,10 @@
 """TTS service using model and voice managers."""

-import io
 import time
 from typing import List, Tuple, Optional, AsyncGenerator, Union
+import asyncio

 import numpy as np
-import scipy.io.wavfile as wavfile
 import torch
 from loguru import logger

@@ -13,10 +12,7 @@ from ..core.config import settings
 from ..inference.model_manager import get_manager as get_model_manager
 from ..inference.voice_manager import get_manager as get_voice_manager
 from .audio import AudioNormalizer, AudioService
-from .text_processing import chunker, normalize_text, process_text
-
-import asyncio
+from .text_processing.text_processor import process_text_chunk, smart_split


 class TTSService:
     """Text-to-speech service."""
@@ -40,7 +36,7 @@ class TTSService:
     async def _process_chunk(
         self,
-        chunk: str,
+        tokens: List[int],
         voice_tensor: torch.Tensor,
         speed: float,
         output_format: Optional[str] = None,
@@ -48,10 +44,21 @@
         is_last: bool = False,
         normalizer: Optional[AudioNormalizer] = None,
     ) -> Optional[Union[np.ndarray, bytes]]:
-        """Process a single text chunk into audio."""
+        """Process tokens into audio."""
         async with self._chunk_semaphore:
             try:
-                tokens = process_text(chunk)
+                # Handle stream finalization
+                if is_last:
+                    return await AudioService.convert_audio(
+                        np.array([0], dtype=np.float32),  # Dummy data for type checking
+                        24000,
+                        output_format,
+                        is_first_chunk=False,
+                        normalizer=normalizer,
+                        is_last_chunk=True
+                    )
+
+                # Skip empty chunks
                 if not tokens:
                     return None
@@ -63,10 +70,16 @@
                 )

                 if chunk_audio is None:
+                    logger.error("Model generated None for audio chunk")
+                    return None
+
+                if len(chunk_audio) == 0:
+                    logger.error("Model generated empty audio chunk")
                     return None

                 # For streaming, convert to bytes
                 if output_format:
+                    try:
                     return await AudioService.convert_audio(
                         chunk_audio,
                         24000,
@@ -75,11 +88,13 @@
                         normalizer=normalizer,
                         is_last_chunk=is_last
                     )
+                    except Exception as e:
+                        logger.error(f"Failed to convert audio: {str(e)}")
+                        return None

                 return chunk_audio

             except Exception as e:
-                logger.error(f"Failed to process chunk: '{chunk}'. Error: {str(e)}")
+                logger.error(f"Failed to process tokens: {str(e)}")
                 return None

     async def generate_audio_stream(
@@ -92,62 +107,59 @@ class TTSService:
         """Generate and stream audio chunks."""
         stream_normalizer = AudioNormalizer()
         voice_tensor = None
-        pending_results = {}
-        next_index = 0
+        chunk_index = 0

         try:
-            # Normalize text
-            normalized = normalize_text(text)
-            if not normalized:
-                raise ValueError("Text is empty after preprocessing")
-            text = str(normalized)
-
             # Get backend and load voice (should be fast if cached)
             backend = self.model_manager.get_backend()
             voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device)

-            # Process chunks with semaphore limiting concurrency
-            chunks = []
-            async for chunk in chunker.split_text(text):
-                chunks.append(chunk)
-
-            if not chunks:
-                raise ValueError("No text chunks to process")
-
-            # Create tasks for all chunks
-            tasks = [
-                asyncio.create_task(
-                    self._process_chunk(
-                        chunk,
-                        voice_tensor,
-                        speed,
-                        output_format,
-                        is_first=(i == 0),
-                        is_last=(i == len(chunks) - 1),
-                        normalizer=stream_normalizer
-                    )
-                )
-                for i, chunk in enumerate(chunks)
-            ]
-
-            # Process chunks and maintain order
-            for i, task in enumerate(tasks):
-                result = await task
-
-                if i == next_index and result is not None:
-                    # If this is the next chunk we need, yield it
-                    yield result
-                    next_index += 1
-
-                    # Check if we have any subsequent chunks ready
-                    while next_index in pending_results:
-                        result = pending_results.pop(next_index)
-                        if result is not None:
-                            yield result
-                            next_index += 1
-                else:
-                    # Store out-of-order result
-                    pending_results[i] = result
+            # Process text in chunks with smart splitting
+            async for chunk_text, tokens in smart_split(text):
+                try:
+                    # Process audio for chunk
+                    result = await self._process_chunk(
+                        tokens,
+                        voice_tensor,
+                        speed,
+                        output_format,
+                        is_first=(chunk_index == 0),
+                        is_last=False,  # We'll update the last chunk later
+                        normalizer=stream_normalizer
+                    )
+
+                    if result is not None:
+                        yield result
+                        chunk_index += 1
+                    else:
+                        logger.warning(f"No audio generated for chunk: '{chunk_text[:100]}...'")
+
+                except Exception as e:
+                    logger.error(f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}")
+                    continue
+
+            # Only finalize if we successfully processed at least one chunk
+            if chunk_index > 0:
+                try:
+                    # Empty tokens list to finalize audio
+                    final_result = await self._process_chunk(
+                        [],  # Empty tokens list
+                        voice_tensor,
+                        speed,
+                        output_format,
+                        is_first=False,
+                        is_last=True,
+                        normalizer=stream_normalizer
+                    )
+                    if final_result is not None:
+                        logger.debug("Yielding final chunk to finalize audio")
+                        yield final_result
+                    else:
+                        logger.warning("Final chunk processing returned None")
+                except Exception as e:
+                    logger.error(f"Failed to process final chunk: {str(e)}")
+            else:
+                logger.warning("No audio chunks were successfully processed")

         except Exception as e:
             logger.error(f"Error in audio generation stream: {str(e)}")