From ee1f7cde1819f5275ea61d56777aec13c55e3eae Mon Sep 17 00:00:00 2001 From: remsky Date: Fri, 24 Jan 2025 04:06:47 -0700 Subject: [PATCH] Add async audio processing and semantic chunking support; flattened static audio trimming --- api/src/core/model_config.py | 2 +- api/src/routers/development.py | 2 +- api/src/routers/openai_compatible.py | 2 +- api/src/services/audio.py | 29 +++--- api/src/services/text_processing/chunker.py | 99 +++++++++++-------- .../services/text_processing/semchunk_slim.py | 89 +++++++++++++++++ api/src/services/tts_service.py | 16 ++- pyproject.toml | 1 + 8 files changed, 180 insertions(+), 60 deletions(-) create mode 100644 api/src/services/text_processing/semchunk_slim.py diff --git a/api/src/core/model_config.py b/api/src/core/model_config.py index 3f1e00b..ac21935 100644 --- a/api/src/core/model_config.py +++ b/api/src/core/model_config.py @@ -77,7 +77,7 @@ class ModelConfig(BaseModel): voice_cache_size: int = Field(2, description="Maximum number of cached voices") # Model filenames - pytorch_model_file: str = Field("kokoro-v0_19.pth", description="PyTorch model filename") + pytorch_model_file: str = Field("kokoro-v0_19-half.pth", description="PyTorch model filename") onnx_model_file: str = Field("kokoro-v0_19.onnx", description="ONNX model filename") # Backend-specific configs diff --git a/api/src/routers/development.py b/api/src/routers/development.py index df1b638..58e7bb8 100644 --- a/api/src/routers/development.py +++ b/api/src/routers/development.py @@ -138,7 +138,7 @@ async def generate_from_phonemes( torch.cuda.empty_cache() # Convert to WAV bytes - wav_bytes = AudioService.convert_audio( + wav_bytes = await AudioService.convert_audio( audio, 24000, "wav", is_first_chunk=True, is_last_chunk=True, stream=False, ) diff --git a/api/src/routers/openai_compatible.py b/api/src/routers/openai_compatible.py index e65dd1c..5908a56 100644 --- a/api/src/routers/openai_compatible.py +++ b/api/src/routers/openai_compatible.py @@ -218,7 +218,7 @@ async def create_speech( ) # Convert to requested format - content = AudioService.convert_audio( + content = await AudioService.convert_audio( audio, 24000, request.response_format, is_first_chunk=True, stream=False ) diff --git a/api/src/services/audio.py b/api/src/services/audio.py index 4a45608..566511a 100644 --- a/api/src/services/audio.py +++ b/api/src/services/audio.py @@ -20,21 +20,26 @@ class AudioNormalizer: self.sample_rate = 24000 # Sample rate of the audio self.samples_to_trim = int(self.chunk_trim_ms * self.sample_rate / 1000) - def normalize( - self, audio_data: np.ndarray, is_last_chunk: bool = False - ) -> np.ndarray: - """Convert audio data to int16 range and trim chunk boundaries""" + async def normalize(self, audio_data: np.ndarray) -> np.ndarray: + """Convert audio data to int16 range and trim silence from start and end + + Args: + audio_data: Input audio data as numpy array + + Returns: + Normalized and trimmed audio data + """ if len(audio_data) == 0: raise ValueError("Audio data cannot be empty") - # Simple float32 to int16 conversion + # Convert to float32 for processing audio_float = audio_data.astype(np.float32) - # Trim for non-final chunks - if not is_last_chunk and len(audio_float) > self.samples_to_trim: - audio_float = audio_float[:-self.samples_to_trim] + # Trim start and end if enough samples + if len(audio_float) > (2 * self.samples_to_trim): + audio_float = audio_float[self.samples_to_trim:-self.samples_to_trim] - # Direct scaling like the non-streaming version + # Scale to int16 range return (audio_float * 32767).astype(np.int16) @@ -59,7 +64,7 @@ class AudioService: } @staticmethod - def convert_audio( + async def convert_audio( audio_data: np.ndarray, sample_rate: int, output_format: str, @@ -99,9 +104,7 @@ class AudioService: # Always normalize audio to ensure proper amplitude scaling if normalizer is None: normalizer = AudioNormalizer() - normalized_audio = normalizer.normalize( - audio_data, is_last_chunk=is_last_chunk - ) + normalized_audio = await normalizer.normalize(audio_data) if output_format == "pcm": # Raw 16-bit PCM samples, no header diff --git a/api/src/services/text_processing/chunker.py b/api/src/services/text_processing/chunker.py index 2bbda79..5ffbc00 100644 --- a/api/src/services/text_processing/chunker.py +++ b/api/src/services/text_processing/chunker.py @@ -1,53 +1,74 @@ -"""Text chunking service""" +"""Text chunking module for TTS processing""" -import re +from typing import List, AsyncGenerator +from . import semchunk_slim -from ...core.config import settings - - -def split_text(text: str, max_chunk=None): - """Split text into chunks on natural pause points +async def fallback_split(text: str, max_chars: int = 400) -> List[str]: + """Emergency length control - only used if chunks are too long""" + words = text.split() + chunks = [] + current = [] + current_len = 0 + + for word in words: + # Always include at least one word per chunk + if not current: + current.append(word) + current_len = len(word) + continue + + # Check if adding word would exceed limit + if current_len + len(word) + 1 <= max_chars: + current.append(word) + current_len += len(word) + 1 + else: + chunks.append(" ".join(current)) + current = [word] + current_len = len(word) + + if current: + chunks.append(" ".join(current)) + + return chunks +async def split_text(text: str, max_chunk: int = None) -> AsyncGenerator[str, None]: + """Split text into TTS-friendly chunks + Args: text: Text to split into chunks - max_chunk: Maximum chunk size (defaults to settings.max_chunk_size) + max_chunk: Maximum chunk size (defaults to 400) + + Yields: + Text chunks suitable for TTS processing """ if max_chunk is None: - max_chunk = settings.max_chunk_size - + max_chunk = 400 + if not isinstance(text, str): text = str(text) if text is not None else "" - + text = text.strip() if not text: return - - # First split into sentences - sentences = re.split(r"(?<=[.!?])\s+", text) - - for sentence in sentences: - sentence = sentence.strip() - if not sentence: + + # Initialize chunker targeting ~300 chars to allow for expansion + chunker = semchunk_slim.chunkerify( + lambda t: len(t) // 5, # Simple length-based target + chunk_size=60 # Target ~300 chars + ) + + # Get initial chunks + chunks = chunker(text) + + # Process chunks + for chunk in chunks: + chunk = chunk.strip() + if not chunk: continue - - # For medium-length sentences, split on punctuation - if len(sentence) > max_chunk: # Lower threshold for more consistent sizes - # First try splitting on semicolons and colons - parts = re.split(r"(?<=[;:])\s+", sentence) - - for part in parts: - part = part.strip() - if not part: - continue - - # If part is still long, split on commas - if len(part) > max_chunk: - subparts = re.split(r"(?<=,)\s+", part) - for subpart in subparts: - subpart = subpart.strip() - if subpart: - yield subpart - else: - yield part + + # Use fallback for any chunks that are too long + if len(chunk) > max_chunk: + for subchunk in await fallback_split(chunk, max_chunk): + yield subchunk else: - yield sentence + yield chunk diff --git a/api/src/services/text_processing/semchunk_slim.py b/api/src/services/text_processing/semchunk_slim.py new file mode 100644 index 0000000..eb73a78 --- /dev/null +++ b/api/src/services/text_processing/semchunk_slim.py @@ -0,0 +1,89 @@ +from __future__ import annotations +import re +from typing import Callable + +# Prioritize sentence boundaries for TTS +_NON_WHITESPACE_SEMANTIC_SPLITTERS = ( + '.', '!', '?', # Primary - sentence boundaries + ';', ':', # Secondary - major clause boundaries + ',', # Tertiary - minor clause boundaries + '(', ')', '[', ']', '"', '"', "'", "'", "'", '"', '`', # Other punctuation + '—', '…', # Dashes and ellipsis + '/', '\\', '–', '&', '-', # Word joiners +) +"""Semantic splitters ordered by priority for TTS chunking""" + +def _split_text(text: str) -> tuple[str, bool, list[str]]: + """Split text using the most semantically meaningful splitter possible.""" + + splitter_is_whitespace = True + + # Try splitting at, in order: + # - Newlines (natural paragraph breaks) + # - Spaces (if no other splits possible) + # - Semantic splitters (prioritizing sentence boundaries) + if '\n' in text or '\r' in text: + splitter = max(re.findall(r'[\r\n]+', text)) + + elif re.search(r'\s', text): + splitter = max(re.findall(r'\s+', text)) + + else: + # Find first semantic splitter present + for splitter in _NON_WHITESPACE_SEMANTIC_SPLITTERS: + if splitter in text: + splitter_is_whitespace = False + break + else: + return '', splitter_is_whitespace, list(text) + + return splitter, splitter_is_whitespace, text.split(splitter) + +class Chunker: + def __init__(self, chunk_size: int, token_counter: Callable[[str], int]) -> None: + self.chunk_size = chunk_size + self.token_counter = token_counter + + def __call__(self, text: str) -> list[str]: + """Split text into chunks based on semantic boundaries.""" + if not isinstance(text, str): + text = str(text) if text is not None else "" + + text = text.strip() + if not text: + return [] + + # Split the text + splitter, _, splits = _split_text(text) + + chunks = [] + current_chunk = [] + current_len = 0 + + for split in splits: + split = split.strip() + if not split: + continue + + # Check if adding this split would exceed chunk size + split_len = self.token_counter(split) + if current_len + split_len <= self.chunk_size: + current_chunk.append(split) + current_len += split_len + else: + # Save current chunk if it exists + if current_chunk: + chunks.append(splitter.join(current_chunk)) + # Start new chunk with current split + current_chunk = [split] + current_len = split_len + + # Add final chunk if it exists + if current_chunk: + chunks.append(splitter.join(current_chunk)) + + return chunks + +def chunkerify(token_counter: Callable[[str], int], chunk_size: int) -> Chunker: + """Create a chunker with the specified token counter and chunk size.""" + return Chunker(chunk_size=chunk_size, token_counter=token_counter) \ No newline at end of file diff --git a/api/src/services/tts_service.py b/api/src/services/tts_service.py index da4fc24..a261fc6 100644 --- a/api/src/services/tts_service.py +++ b/api/src/services/tts_service.py @@ -82,8 +82,11 @@ class TTSService: backend = self.model_manager.get_backend() voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device) - # Get all chunks upfront - chunks = list(chunker.split_text(text)) + # Get chunks using async generator + chunks = [] + async for chunk in chunker.split_text(text): + chunks.append(chunk) + if not chunks: raise ValueError("No text chunks to process") @@ -162,8 +165,11 @@ class TTSService: backend = self.model_manager.get_backend() voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device) - # Get all chunks upfront - chunks = list(chunker.split_text(text)) + # Get chunks using async generator + chunks = [] + async for chunk in chunker.split_text(text): + chunks.append(chunk) + if not chunks: raise ValueError("No text chunks to process") @@ -184,7 +190,7 @@ class TTSService: if chunk_audio is not None: # Convert to bytes - return AudioService.convert_audio( + return await AudioService.convert_audio( chunk_audio, 24000, output_format, diff --git a/pyproject.toml b/pyproject.toml index 87bb9dd..c815c23 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "html2text>=2024.2.26", "pydub>=0.25.1", "matplotlib>=3.10.0", + "semchunk>=3.0.1" ] [project.optional-dependencies]