Add async audio processing and semantic chunking support; flattened static audio trimming

remsky 2025-01-24 04:06:47 -07:00
parent 31b5e33408
commit ee1f7cde18
8 changed files with 180 additions and 60 deletions

View file

@@ -77,7 +77,7 @@ class ModelConfig(BaseModel):
     voice_cache_size: int = Field(2, description="Maximum number of cached voices")

     # Model filenames
-    pytorch_model_file: str = Field("kokoro-v0_19.pth", description="PyTorch model filename")
+    pytorch_model_file: str = Field("kokoro-v0_19-half.pth", description="PyTorch model filename")
     onnx_model_file: str = Field("kokoro-v0_19.onnx", description="ONNX model filename")

     # Backend-specific configs

View file

@@ -138,7 +138,7 @@ async def generate_from_phonemes(
             torch.cuda.empty_cache()

         # Convert to WAV bytes
-        wav_bytes = AudioService.convert_audio(
+        wav_bytes = await AudioService.convert_audio(
             audio, 24000, "wav", is_first_chunk=True, is_last_chunk=True, stream=False,
         )

View file

@@ -218,7 +218,7 @@ async def create_speech(
             )

             # Convert to requested format
-            content = AudioService.convert_audio(
+            content = await AudioService.convert_audio(
                 audio, 24000, request.response_format, is_first_chunk=True, stream=False
             )

View file

@@ -20,21 +20,26 @@ class AudioNormalizer:
         self.sample_rate = 24000  # Sample rate of the audio
         self.samples_to_trim = int(self.chunk_trim_ms * self.sample_rate / 1000)

-    def normalize(
-        self, audio_data: np.ndarray, is_last_chunk: bool = False
-    ) -> np.ndarray:
-        """Convert audio data to int16 range and trim chunk boundaries"""
+    async def normalize(self, audio_data: np.ndarray) -> np.ndarray:
+        """Convert audio data to int16 range and trim silence from start and end
+
+        Args:
+            audio_data: Input audio data as numpy array
+
+        Returns:
+            Normalized and trimmed audio data
+        """
         if len(audio_data) == 0:
             raise ValueError("Audio data cannot be empty")
-        # Simple float32 to int16 conversion
+
+        # Convert to float32 for processing
         audio_float = audio_data.astype(np.float32)
-        # Trim for non-final chunks
-        if not is_last_chunk and len(audio_float) > self.samples_to_trim:
-            audio_float = audio_float[:-self.samples_to_trim]
-        # Direct scaling like the non-streaming version
+
+        # Trim start and end if enough samples
+        if len(audio_float) > (2 * self.samples_to_trim):
+            audio_float = audio_float[self.samples_to_trim:-self.samples_to_trim]
+
+        # Scale to int16 range
         return (audio_float * 32767).astype(np.int16)
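
Note: trimming is now symmetric and applied to every sufficiently long chunk rather than keyed to is_last_chunk. With samples_to_trim = int(chunk_trim_ms * 24000 / 1000), an assumed chunk_trim_ms of 50 (the default is defined elsewhere in the class) would remove 1200 samples (50 ms) from each end of any chunk longer than 2400 samples.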
@@ -59,7 +64,7 @@ class AudioService:
     }

     @staticmethod
-    def convert_audio(
+    async def convert_audio(
         audio_data: np.ndarray,
         sample_rate: int,
         output_format: str,
@@ -99,9 +104,7 @@ class AudioService:
         # Always normalize audio to ensure proper amplitude scaling
         if normalizer is None:
             normalizer = AudioNormalizer()
-        normalized_audio = normalizer.normalize(
-            audio_data, is_last_chunk=is_last_chunk
-        )
+        normalized_audio = await normalizer.normalize(audio_data)

         if output_format == "pcm":
             # Raw 16-bit PCM samples, no header
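
Since convert_audio and normalize are now coroutines, every call site must await them, as the router and TTSService hunks in this commit show. A minimal sketch of the new call pattern, assuming AudioNormalizer is imported from this module:

    import asyncio
    import numpy as np

    async def demo():
        normalizer = AudioNormalizer()
        audio = np.zeros(24000, dtype=np.float32)  # one second of silence at 24 kHz
        pcm = await normalizer.normalize(audio)    # int16 samples, trimmed at both ends
        print(pcm.dtype, pcm.shape)

    asyncio.run(demo())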

View file

@@ -1,53 +1,74 @@
-"""Text chunking service"""
+"""Text chunking module for TTS processing"""

 import re
+from typing import List, AsyncGenerator
+
+from . import semchunk_slim

-from ...core.config import settings

-def split_text(text: str, max_chunk=None):
-    """Split text into chunks on natural pause points
+
+async def fallback_split(text: str, max_chars: int = 400) -> List[str]:
+    """Emergency length control - only used if chunks are too long"""
+    words = text.split()
+    chunks = []
+    current = []
+    current_len = 0
+
+    for word in words:
+        # Always include at least one word per chunk
+        if not current:
+            current.append(word)
+            current_len = len(word)
+            continue
+
+        # Check if adding word would exceed limit
+        if current_len + len(word) + 1 <= max_chars:
+            current.append(word)
+            current_len += len(word) + 1
+        else:
+            chunks.append(" ".join(current))
+            current = [word]
+            current_len = len(word)
+
+    if current:
+        chunks.append(" ".join(current))
+
+    return chunks
+
+
+async def split_text(text: str, max_chunk: int = None) -> AsyncGenerator[str, None]:
+    """Split text into TTS-friendly chunks

     Args:
         text: Text to split into chunks
-        max_chunk: Maximum chunk size (defaults to settings.max_chunk_size)
+        max_chunk: Maximum chunk size (defaults to 400)
+
+    Yields:
+        Text chunks suitable for TTS processing
     """
     if max_chunk is None:
-        max_chunk = settings.max_chunk_size
+        max_chunk = 400

     if not isinstance(text, str):
         text = str(text) if text is not None else ""

     text = text.strip()
     if not text:
         return

-    # First split into sentences
-    sentences = re.split(r"(?<=[.!?])\s+", text)
-
-    for sentence in sentences:
-        sentence = sentence.strip()
-        if not sentence:
+    # Initialize chunker targeting ~300 chars to allow for expansion
+    chunker = semchunk_slim.chunkerify(
+        lambda t: len(t) // 5,  # Simple length-based target
+        chunk_size=60  # Target ~300 chars
+    )
+
+    # Get initial chunks
+    chunks = chunker(text)
+
+    # Process chunks
+    for chunk in chunks:
+        chunk = chunk.strip()
+        if not chunk:
             continue
-
-        # For medium-length sentences, split on punctuation
-        if len(sentence) > max_chunk:  # Lower threshold for more consistent sizes
-            # First try splitting on semicolons and colons
-            parts = re.split(r"(?<=[;:])\s+", sentence)
-            for part in parts:
-                part = part.strip()
-                if not part:
-                    continue
-                # If part is still long, split on commas
-                if len(part) > max_chunk:
-                    subparts = re.split(r"(?<=,)\s+", part)
-                    for subpart in subparts:
-                        subpart = subpart.strip()
-                        if subpart:
-                            yield subpart
-                else:
-                    yield part
+
+        # Use fallback for any chunks that are too long
+        if len(chunk) > max_chunk:
+            for subchunk in await fallback_split(chunk, max_chunk):
+                yield subchunk
         else:
-            yield sentence
+            yield chunk
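
The new split_text is an async generator, so callers collect chunks with async for, exactly as the TTSService hunks below do. A minimal usage sketch, assuming split_text is imported from this module:

    import asyncio

    async def demo():
        text = "First sentence. Second sentence; with a clause, and a bit more text."
        chunks = [chunk async for chunk in split_text(text)]
        print(chunks)

    asyncio.run(demo())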

View file

@@ -0,0 +1,89 @@
+from __future__ import annotations
+
+import re
+from typing import Callable
+
+# Prioritize sentence boundaries for TTS
+_NON_WHITESPACE_SEMANTIC_SPLITTERS = (
+    '.', '!', '?',  # Primary - sentence boundaries
+    ';', ':',  # Secondary - major clause boundaries
+    ',',  # Tertiary - minor clause boundaries
+    '(', ')', '[', ']', '“', '”', '‘', '’', "'", '"', '`',  # Other punctuation
+    '—', '…',  # Dashes and ellipsis
+    '/', '\\', '–', '&', '-',  # Word joiners
+)
+"""Semantic splitters ordered by priority for TTS chunking"""
+
+
+def _split_text(text: str) -> tuple[str, bool, list[str]]:
+    """Split text using the most semantically meaningful splitter possible."""
+    splitter_is_whitespace = True
+
+    # Try splitting at, in order:
+    # - Newlines (natural paragraph breaks)
+    # - Spaces (if no other splits possible)
+    # - Semantic splitters (prioritizing sentence boundaries)
+    if '\n' in text or '\r' in text:
+        splitter = max(re.findall(r'[\r\n]+', text))
+    elif re.search(r'\s', text):
+        splitter = max(re.findall(r'\s+', text))
+    else:
+        # Find first semantic splitter present
+        for splitter in _NON_WHITESPACE_SEMANTIC_SPLITTERS:
+            if splitter in text:
+                splitter_is_whitespace = False
+                break
+        else:
+            return '', splitter_is_whitespace, list(text)
+
+    return splitter, splitter_is_whitespace, text.split(splitter)
+
+
+class Chunker:
+    def __init__(self, chunk_size: int, token_counter: Callable[[str], int]) -> None:
+        self.chunk_size = chunk_size
+        self.token_counter = token_counter
+
+    def __call__(self, text: str) -> list[str]:
+        """Split text into chunks based on semantic boundaries."""
+        if not isinstance(text, str):
+            text = str(text) if text is not None else ""
+
+        text = text.strip()
+        if not text:
+            return []
+
+        # Split the text
+        splitter, _, splits = _split_text(text)
+
+        chunks = []
+        current_chunk = []
+        current_len = 0
+
+        for split in splits:
+            split = split.strip()
+            if not split:
+                continue
+
+            # Check if adding this split would exceed chunk size
+            split_len = self.token_counter(split)
+            if current_len + split_len <= self.chunk_size:
+                current_chunk.append(split)
+                current_len += split_len
+            else:
+                # Save current chunk if it exists
+                if current_chunk:
+                    chunks.append(splitter.join(current_chunk))
+                # Start new chunk with current split
+                current_chunk = [split]
+                current_len = split_len
+
+        # Add final chunk if it exists
+        if current_chunk:
+            chunks.append(splitter.join(current_chunk))
+
+        return chunks
+
+
+def chunkerify(token_counter: Callable[[str], int], chunk_size: int) -> Chunker:
+    """Create a chunker with the specified token counter and chunk size."""
+    return Chunker(chunk_size=chunk_size, token_counter=token_counter)
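
A rough usage sketch of the slim chunker as split_text wires it up: the token counter approximates five characters per token, so chunk_size=60 caps chunks at roughly 60 * 5 = 300 characters, leaving headroom under the 400-character fallback limit:

    chunker = chunkerify(lambda t: len(t) // 5, chunk_size=60)
    chunks = chunker("One sentence. Another sentence! A third, with clauses; and more.")
    # A short input like this stays in one chunk; longer text is split at the
    # strongest boundary available (newlines, then whitespace, then punctuation).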

View file

@@ -82,8 +82,11 @@ class TTSService:
             backend = self.model_manager.get_backend()
             voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device)

-            # Get all chunks upfront
-            chunks = list(chunker.split_text(text))
+            # Get chunks using async generator
+            chunks = []
+            async for chunk in chunker.split_text(text):
+                chunks.append(chunk)
+
             if not chunks:
                 raise ValueError("No text chunks to process")
@@ -162,8 +165,11 @@ class TTSService:
             backend = self.model_manager.get_backend()
             voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device)

-            # Get all chunks upfront
-            chunks = list(chunker.split_text(text))
+            # Get chunks using async generator
+            chunks = []
+            async for chunk in chunker.split_text(text):
+                chunks.append(chunk)
+
             if not chunks:
                 raise ValueError("No text chunks to process")
@@ -184,7 +190,7 @@ class TTSService:
             if chunk_audio is not None:
                 # Convert to bytes
-                return AudioService.convert_audio(
+                return await AudioService.convert_audio(
                     chunk_audio,
                     24000,
                     output_format,

View file

@@ -34,6 +34,7 @@ dependencies = [
     "html2text>=2024.2.26",
     "pydub>=0.25.1",
     "matplotlib>=3.10.0",
+    "semchunk>=3.0.1"
 ]

 [project.optional-dependencies]