diff --git a/api/src/services/audio.py b/api/src/services/audio.py
index 89bed40..bfe419b 100644
--- a/api/src/services/audio.py
+++ b/api/src/services/audio.py
@@ -30,9 +30,6 @@ class AudioNormalizer:
         Returns:
             Normalized and trimmed audio data
         """
-        if len(audio_data) == 0:
-            raise ValueError("Audio data cannot be empty")
-
         # Convert to float32 for processing
         audio_float = audio_data.astype(np.float32)
 
@@ -102,17 +99,14 @@ class AudioService:
             )
             writer = AudioService._writers[writer_key]
 
-            # Write the current chunk
-            chunk_data = writer.write_chunk(normalized_audio)
-
-            # Handle last chunk and cleanup
+            # Write chunk or finalize
             if is_last_chunk:
-                final_data = writer.close()
-                if final_data:
-                    chunk_data += final_data
+                chunk_data = writer.write_chunk(finalize=True)
                 del AudioService._writers[writer_key]
-
-            return chunk_data
+            else:
+                chunk_data = writer.write_chunk(normalized_audio)
+
+            return chunk_data if chunk_data else b''
 
         except Exception as e:
             logger.error(f"Error converting audio stream to {output_format}: {str(e)}")
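The hunk above collapses the old `write_chunk()` / `close()` pair into a single entry point on the writer. A minimal sketch of the call pattern the new code relies on (hypothetical `push_chunk` helper and duck-typed `writer`, not part of the diff):

```python
import numpy as np

def push_chunk(writer, normalized_audio: np.ndarray, is_last_chunk: bool) -> bytes:
    """Drive a StreamingAudioWriter-like object the way convert_audio now does."""
    if is_last_chunk:
        # A single finalize call both flushes encoder state and returns trailing bytes.
        return writer.write_chunk(finalize=True) or b''
    return writer.write_chunk(normalized_audio) or b''
```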
diff --git a/api/src/services/streaming_audio_writer.py b/api/src/services/streaming_audio_writer.py
index 5a1501e..54e2ee5 100644
--- a/api/src/services/streaming_audio_writer.py
+++ b/api/src/services/streaming_audio_writer.py
@@ -33,7 +33,9 @@ class StreamingAudioWriter:
         elif self.format == "mp3":
             # For MP3, we'll use pydub's incremental writer
             self.buffer = BytesIO()
-            self.encoder = AudioSegment.from_mono_audiosegments()
+            self.segments = []  # Store segments until we have enough data
+            # Initialize an empty AudioSegment as our encoder
+            self.encoder = AudioSegment.silent(duration=0, frame_rate=self.sample_rate)
 
     def _write_wav_header(self) -> bytes:
         """Write WAV header with correct streaming format"""
@@ -53,12 +55,45 @@ class StreamingAudioWriter:
         header.write(struct.pack('<L', 0))  # Data size placeholder
         return header.getvalue()
 
-    def write_chunk(self, audio_data: np.ndarray) -> bytes:
-        """Write a chunk of audio data and return bytes in the target format"""
+    def write_chunk(self, audio_data: Optional[np.ndarray] = None, finalize: bool = False) -> bytes:
+        """Write a chunk of audio data and return bytes in the target format.
+
+        Args:
+            audio_data: Audio data to write, or None if finalizing
+            finalize: Whether this is the final write to close the stream
+        """
         buffer = BytesIO()
 
+        if finalize:
+            if self.format == "wav":
+                # Write final WAV header with correct sizes
+                buffer.write(b'RIFF')
+                buffer.write(struct.pack('<L', self.bytes_written + 36))
+                buffer.write(b'WAVE')
+                buffer.write(b'fmt ')
+                buffer.write(struct.pack('<L', 16))
+                buffer.write(struct.pack('<H', 1))
+                buffer.write(struct.pack('<H', self.channels))
+                buffer.write(struct.pack('<L', self.sample_rate))
+                buffer.write(struct.pack('<L', self.sample_rate * self.channels * 2))
+                buffer.write(struct.pack('<H', self.channels * 2))
+                buffer.write(struct.pack('<H', 16))
+                buffer.write(b'data')
+                buffer.write(struct.pack('<L', self.bytes_written))
+                return buffer.getvalue()
+            elif self.format == "mp3":
+                # Export any remaining audio
+                if len(self.encoder) > 0:
+                    self.encoder.export(buffer, format="mp3", bitrate="192k", parameters=["-q:a", "2"])
+                self.encoder = None
+                return buffer.getvalue()
+
+        if audio_data is None or len(audio_data) == 0:
+            return b''
+
         if self.format == "wav":
-            # For WAV, we write raw PCM after the first chunk
+            # For WAV, write raw PCM after the first chunk
            if self.bytes_written == 0:
                 buffer.write(self._write_wav_header())
             buffer.write(audio_data.tobytes())
@@ -83,8 +118,20 @@ class StreamingAudioWriter:
                 sample_width=audio_data.dtype.itemsize,
                 channels=self.channels
             )
-            self.encoder += segment
-            self.encoder.export(buffer, format="mp3")
+
+            # Add segment to encoder
+            self.encoder = self.encoder + segment
+
+            # Export current state to buffer
+            self.encoder.export(buffer, format="mp3", bitrate="192k", parameters=["-q:a", "2"])
+
+            # Get the encoded data
+            encoded_data = buffer.getvalue()
+
+            # Reset encoder to prevent memory growth
+            self.encoder = AudioSegment.silent(duration=0, frame_rate=self.sample_rate)
+
+            return encoded_data
 
         return buffer.getvalue()
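The finalize branch above rewrites the RIFF and data chunk sizes once the stream length is known. For reference, the standard 44-byte PCM WAV header it emits can be built standalone (illustrative helper assuming 16-bit samples; field layout follows the WAV spec, not this repo's API):

```python
import struct

def wav_header(bytes_written: int, sample_rate: int = 24000, channels: int = 1) -> bytes:
    """44-byte PCM WAV header; bytes_written is the raw PCM payload size."""
    return b''.join([
        b'RIFF', struct.pack('<L', bytes_written + 36), b'WAVE',
        b'fmt ', struct.pack('<L', 16),                 # fmt chunk size
        struct.pack('<H', 1),                           # PCM format tag
        struct.pack('<H', channels),
        struct.pack('<L', sample_rate),
        struct.pack('<L', sample_rate * channels * 2),  # byte rate at 16-bit
        struct.pack('<H', channels * 2),                # block align
        struct.pack('<H', 16),                          # bits per sample
        b'data', struct.pack('<L', bytes_written),
    ])

assert len(wav_header(0)) == 44
```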
diff --git a/api/src/services/text_processing/__init__.py b/api/src/services/text_processing/__init__.py
index c9b3eb4..6ff9ad9 100644
--- a/api/src/services/text_processing/__init__.py
+++ b/api/src/services/text_processing/__init__.py
@@ -1,28 +1,19 @@
 """Text processing pipeline."""
 
-from .chunker import split_text
 from .normalizer import normalize_text
 from .phonemizer import phonemize
 from .vocabulary import tokenize
+from .text_processor import process_text_chunk, smart_split
 
+def process_text(text: str) -> list[int]:
+    """Process text into token IDs (for backward compatibility)."""
+    return process_text_chunk(text)
 
-def process_text(text: str, language: str = "a") -> list[int]:
-    """Process text through the full pipeline.
-
-    Args:
-        text: Input text
-        language: Language code ('a' for US English, 'b' for British English)
-
-    Returns:
-        List of token IDs
-
-    Note:
-        The pipeline:
-        1. Converts text to phonemes using phonemizer
-        2. Converts phonemes to token IDs using vocabulary
-    """
-    # Convert text to phonemes
-    phonemes = phonemize(text, language=language)
-
-    # Convert phonemes to token IDs
-    return tokenize(phonemes)
+__all__ = [
+    'normalize_text',
+    'phonemize',
+    'tokenize',
+    'process_text',
+    'process_text_chunk',
+    'smart_split'
+]
diff --git a/api/src/services/text_processing/semchunk_slim.py b/api/src/services/text_processing/semchunk_slim.py
deleted file mode 100644
index eb73a78..0000000
--- a/api/src/services/text_processing/semchunk_slim.py
+++ /dev/null
@@ -1,89 +0,0 @@
-from __future__ import annotations
-import re
-from typing import Callable
-
-# Prioritize sentence boundaries for TTS
-_NON_WHITESPACE_SEMANTIC_SPLITTERS = (
-    '.', '!', '?',  # Primary - sentence boundaries
-    ';', ':',  # Secondary - major clause boundaries
-    ',',  # Tertiary - minor clause boundaries
-    '(', ')', '[', ']', '"', '"', "'", "'", "'", '"', '`',  # Other punctuation
-    '—', '…',  # Dashes and ellipsis
-    '/', '\\', '–', '&', '-',  # Word joiners
-)
-"""Semantic splitters ordered by priority for TTS chunking"""
-
-def _split_text(text: str) -> tuple[str, bool, list[str]]:
-    """Split text using the most semantically meaningful splitter possible."""
-
-    splitter_is_whitespace = True
-
-    # Try splitting at, in order:
-    # - Newlines (natural paragraph breaks)
-    # - Spaces (if no other splits possible)
-    # - Semantic splitters (prioritizing sentence boundaries)
-    if '\n' in text or '\r' in text:
-        splitter = max(re.findall(r'[\r\n]+', text))
-
-    elif re.search(r'\s', text):
-        splitter = max(re.findall(r'\s+', text))
-
-    else:
-        # Find first semantic splitter present
-        for splitter in _NON_WHITESPACE_SEMANTIC_SPLITTERS:
-            if splitter in text:
-                splitter_is_whitespace = False
-                break
-        else:
-            return '', splitter_is_whitespace, list(text)
-
-    return splitter, splitter_is_whitespace, text.split(splitter)
-
-class Chunker:
-    def __init__(self, chunk_size: int, token_counter: Callable[[str], int]) -> None:
-        self.chunk_size = chunk_size
-        self.token_counter = token_counter
-
-    def __call__(self, text: str) -> list[str]:
-        """Split text into chunks based on semantic boundaries."""
-        if not isinstance(text, str):
-            text = str(text) if text is not None else ""
-
-        text = text.strip()
-        if not text:
-            return []
-
-        # Split the text
-        splitter, _, splits = _split_text(text)
-
-        chunks = []
-        current_chunk = []
-        current_len = 0
-
-        for split in splits:
-            split = split.strip()
-            if not split:
-                continue
-
-            # Check if adding this split would exceed chunk size
-            split_len = self.token_counter(split)
-            if current_len + split_len <= self.chunk_size:
-                current_chunk.append(split)
-                current_len += split_len
-            else:
-                # Save current chunk if it exists
-                if current_chunk:
-                    chunks.append(splitter.join(current_chunk))
-                # Start new chunk with current split
-                current_chunk = [split]
-                current_len = split_len
-
-        # Add final chunk if it exists
-        if current_chunk:
-            chunks.append(splitter.join(current_chunk))
-
-        return chunks
-
-def chunkerify(token_counter: Callable[[str], int], chunk_size: int) -> Chunker:
-    """Create a chunker with the specified token counter and chunk size."""
-    return Chunker(chunk_size=chunk_size, token_counter=token_counter)
\ No newline at end of file
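The deleted `Chunker` above and the `smart_split` added below share the same greedy packing rule: accumulate splits until the next one would cross the budget, then flush. A self-contained sketch of that core (hypothetical names; a word count stands in for the real token counter):

```python
from typing import Callable, List

def pack(splits: List[str], budget: int,
         count: Callable[[str], int] = lambda s: len(s.split())) -> List[str]:
    """Greedily join splits into chunks of at most `budget` counted units."""
    chunks, cur, cur_len = [], [], 0
    for s in splits:
        n = count(s)
        if cur and cur_len + n > budget:
            chunks.append(" ".join(cur))
            cur, cur_len = [], 0
        cur.append(s)
        cur_len += n
    if cur:
        chunks.append(" ".join(cur))
    return chunks

print(pack(["one two", "three four five", "six"], budget=4))
# ['one two', 'three four five six']
```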
TTS with smart chunking.""" + +import re +import time +from typing import AsyncGenerator, List, Tuple +from loguru import logger +from .phonemizer import phonemize +from .normalizer import normalize_text +from .vocabulary import tokenize + +def process_text_chunk(text: str, language: str = "a") -> List[int]: + """Process a chunk of text through normalization, phonemization, and tokenization. + + Args: + text: Text chunk to process + language: Language code for phonemization + + Returns: + List of token IDs + """ + start_time = time.time() + + # Normalize + t0 = time.time() + normalized = normalize_text(text) + t1 = time.time() + logger.debug(f"Normalization took {(t1-t0)*1000:.2f}ms for {len(text)} chars") + + # Phonemize + t0 = time.time() + phonemes = phonemize(normalized, language, normalize=False) # Already normalized + t1 = time.time() + logger.debug(f"Phonemization took {(t1-t0)*1000:.2f}ms for {len(normalized)} chars") + + # Convert to token IDs + t0 = time.time() + tokens = tokenize(phonemes) + t1 = time.time() + logger.debug(f"Tokenization took {(t1-t0)*1000:.2f}ms for {len(phonemes)} chars") + + total_time = time.time() - start_time + logger.debug(f"Total processing took {total_time*1000:.2f}ms for chunk: '{text[:50]}...'") + + return tokens + +def process_text(text: str, language: str = "a") -> List[int]: + """Process text into token IDs. + + Args: + text: Text to process + language: Language code for phonemization + + Returns: + List of token IDs + """ + if not isinstance(text, str): + text = str(text) if text is not None else "" + + text = text.strip() + if not text: + return [] + + return process_text_chunk(text, language) + +async def smart_split(text: str, max_tokens: int = 500) -> AsyncGenerator[Tuple[str, List[int]], None]: + """Split text into semantically meaningful chunks while respecting token limits. + + Args: + text: Input text to split + max_tokens: Maximum tokens per chunk + + Yields: + Tuples of (text chunk, token IDs) where token count is <= max_tokens + """ + start_time = time.time() + chunk_count = 0 + total_chars = len(text) + logger.info(f"Starting text split for {total_chars} characters with {max_tokens} max tokens") + + # Split on major punctuation first + sentences = re.split(r'([.!?;:])', text) + + current_chunk = [] + current_token_count = 0 + + for i in range(0, len(sentences), 2): + # Get sentence and its punctuation (if any) + sentence = sentences[i].strip() + punct = sentences[i + 1] if i + 1 < len(sentences) else "" + + if not sentence: + continue + + # Process sentence to get token count + sentence_with_punct = sentence + punct + tokens = process_text_chunk(sentence_with_punct) + token_count = len(tokens) + logger.debug(f"Sentence '{sentence_with_punct[:50]}...' 
has {token_count} tokens") + + # If this single sentence is too long, split on commas + if token_count > max_tokens: + logger.debug(f"Sentence exceeds token limit, splitting on commas") + clause_splits = re.split(r'([,])', sentence_with_punct) + for j in range(0, len(clause_splits), 2): + clause = clause_splits[j].strip() + comma = clause_splits[j + 1] if j + 1 < len(clause_splits) else "" + + if not clause: + continue + + clause_with_punct = clause + comma + clause_tokens = process_text_chunk(clause_with_punct) + + # If still too long, do a hard split on words + if len(clause_tokens) > max_tokens: + logger.debug(f"Clause exceeds token limit, splitting on words") + words = clause_with_punct.split() + temp_chunk = [] + temp_tokens = [] + + for word in words: + word_tokens = process_text_chunk(word) + if len(temp_tokens) + len(word_tokens) > max_tokens: + if temp_chunk: # Don't yield empty chunks + chunk_text = " ".join(temp_chunk) + chunk_count += 1 + logger.info(f"Yielding word-split chunk {chunk_count}: '{chunk_text[:50]}...' ({len(temp_tokens)} tokens)") + yield chunk_text, temp_tokens + temp_chunk = [word] + temp_tokens = word_tokens + else: + temp_chunk.append(word) + temp_tokens.extend(word_tokens) + + if temp_chunk: # Don't forget the last chunk + chunk_text = " ".join(temp_chunk) + chunk_count += 1 + logger.info(f"Yielding final word-split chunk {chunk_count}: '{chunk_text[:50]}...' ({len(temp_tokens)} tokens)") + yield chunk_text, temp_tokens + + else: + # Check if adding this clause would exceed the limit + if current_token_count + len(clause_tokens) > max_tokens: + if current_chunk: # Don't yield empty chunks + chunk_text = " ".join(current_chunk) + chunk_count += 1 + logger.info(f"Yielding clause-split chunk {chunk_count}: '{chunk_text[:50]}...' ({current_token_count} tokens)") + yield chunk_text, process_text_chunk(chunk_text) + current_chunk = [clause_with_punct] + current_token_count = len(clause_tokens) + else: + current_chunk.append(clause_with_punct) + current_token_count += len(clause_tokens) + + else: + # Check if adding this sentence would exceed the limit + if current_token_count + token_count > max_tokens: + if current_chunk: # Don't yield empty chunks + chunk_text = " ".join(current_chunk) + chunk_count += 1 + logger.info(f"Yielding sentence-split chunk {chunk_count}: '{chunk_text[:50]}...' ({current_token_count} tokens)") + yield chunk_text, process_text_chunk(chunk_text) + current_chunk = [sentence_with_punct] + current_token_count = token_count + else: + current_chunk.append(sentence_with_punct) + current_token_count += token_count + + # Don't forget the last chunk + if current_chunk: + chunk_text = " ".join(current_chunk) + chunk_count += 1 + logger.info(f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}...' 
({current_token_count} tokens)") + yield chunk_text, process_text_chunk(chunk_text) + + total_time = time.time() - start_time + logger.info(f"Text splitting completed in {total_time*1000:.2f}ms, produced {chunk_count} chunks") diff --git a/api/src/services/tts_service.py b/api/src/services/tts_service.py index b8b023f..e6ef24d 100644 --- a/api/src/services/tts_service.py +++ b/api/src/services/tts_service.py @@ -1,11 +1,10 @@ """TTS service using model and voice managers.""" -import io import time from typing import List, Tuple, Optional, AsyncGenerator, Union +import asyncio import numpy as np -import scipy.io.wavfile as wavfile import torch from loguru import logger @@ -13,10 +12,7 @@ from ..core.config import settings from ..inference.model_manager import get_manager as get_model_manager from ..inference.voice_manager import get_manager as get_voice_manager from .audio import AudioNormalizer, AudioService -from .text_processing import chunker, normalize_text, process_text - - -import asyncio +from .text_processing.text_processor import process_text_chunk, smart_split class TTSService: """Text-to-speech service.""" @@ -40,7 +36,7 @@ class TTSService: async def _process_chunk( self, - chunk: str, + tokens: List[int], voice_tensor: torch.Tensor, speed: float, output_format: Optional[str] = None, @@ -48,13 +44,24 @@ class TTSService: is_last: bool = False, normalizer: Optional[AudioNormalizer] = None, ) -> Optional[Union[np.ndarray, bytes]]: - """Process a single text chunk into audio.""" + """Process tokens into audio.""" async with self._chunk_semaphore: try: - tokens = process_text(chunk) + # Handle stream finalization + if is_last: + return await AudioService.convert_audio( + np.array([0], dtype=np.float32), # Dummy data for type checking + 24000, + output_format, + is_first_chunk=False, + normalizer=normalizer, + is_last_chunk=True + ) + + # Skip empty chunks if not tokens: return None - + # Generate audio using pre-warmed model chunk_audio = await self.model_manager.generate( tokens, @@ -63,23 +70,31 @@ class TTSService: ) if chunk_audio is None: + logger.error("Model generated None for audio chunk") + return None + + if len(chunk_audio) == 0: + logger.error("Model generated empty audio chunk") return None # For streaming, convert to bytes if output_format: - return await AudioService.convert_audio( - chunk_audio, - 24000, - output_format, - is_first_chunk=is_first, - normalizer=normalizer, - is_last_chunk=is_last - ) + try: + return await AudioService.convert_audio( + chunk_audio, + 24000, + output_format, + is_first_chunk=is_first, + normalizer=normalizer, + is_last_chunk=is_last + ) + except Exception as e: + logger.error(f"Failed to convert audio: {str(e)}") + return None return chunk_audio - except Exception as e: - logger.error(f"Failed to process chunk: '{chunk}'. 
Error: {str(e)}") + logger.error(f"Failed to process tokens: {str(e)}") return None async def generate_audio_stream( @@ -92,62 +107,59 @@ class TTSService: """Generate and stream audio chunks.""" stream_normalizer = AudioNormalizer() voice_tensor = None - pending_results = {} - next_index = 0 + chunk_index = 0 try: - # Normalize text - normalized = normalize_text(text) - if not normalized: - raise ValueError("Text is empty after preprocessing") - text = str(normalized) - # Get backend and load voice (should be fast if cached) backend = self.model_manager.get_backend() voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device) - # Process chunks with semaphore limiting concurrency - chunks = [] - async for chunk in chunker.split_text(text): - chunks.append(chunk) - - if not chunks: - raise ValueError("No text chunks to process") - - # Create tasks for all chunks - tasks = [ - asyncio.create_task( - self._process_chunk( - chunk, + # Process text in chunks with smart splitting + async for chunk_text, tokens in smart_split(text): + try: + # Process audio for chunk + result = await self._process_chunk( + tokens, voice_tensor, speed, output_format, - is_first=(i == 0), - is_last=(i == len(chunks) - 1), + is_first=(chunk_index == 0), + is_last=False, # We'll update the last chunk later normalizer=stream_normalizer ) - ) - for i, chunk in enumerate(chunks) - ] - - # Process chunks and maintain order - for i, task in enumerate(tasks): - result = await task - - if i == next_index and result is not None: - # If this is the next chunk we need, yield it - yield result - next_index += 1 - # Check if we have any subsequent chunks ready - while next_index in pending_results: - result = pending_results.pop(next_index) - if result is not None: - yield result - next_index += 1 - else: - # Store out-of-order result - pending_results[i] = result + if result is not None: + yield result + chunk_index += 1 + else: + logger.warning(f"No audio generated for chunk: '{chunk_text[:100]}...'") + + except Exception as e: + logger.error(f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}") + continue + + # Only finalize if we successfully processed at least one chunk + if chunk_index > 0: + try: + # Empty tokens list to finalize audio + final_result = await self._process_chunk( + [], # Empty tokens list + voice_tensor, + speed, + output_format, + is_first=False, + is_last=True, + normalizer=stream_normalizer + ) + if final_result is not None: + logger.debug("Yielding final chunk to finalize audio") + yield final_result + else: + logger.warning("Final chunk processing returned None") + except Exception as e: + logger.error(f"Failed to process final chunk: {str(e)}") + else: + logger.warning("No audio chunks were successfully processed") except Exception as e: logger.error(f"Error in audio generation stream: {str(e)}")