mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00
Add async audio processing and semantic chunking support; flattened static audio trimming
This commit is contained in:
parent
31b5e33408
commit
ee1f7cde18
8 changed files with 180 additions and 60 deletions
|
@ -77,7 +77,7 @@ class ModelConfig(BaseModel):
|
|||
voice_cache_size: int = Field(2, description="Maximum number of cached voices")
|
||||
|
||||
# Model filenames
|
||||
pytorch_model_file: str = Field("kokoro-v0_19.pth", description="PyTorch model filename")
|
||||
pytorch_model_file: str = Field("kokoro-v0_19-half.pth", description="PyTorch model filename")
|
||||
onnx_model_file: str = Field("kokoro-v0_19.onnx", description="ONNX model filename")
|
||||
|
||||
# Backend-specific configs
|
||||
|
|
|
@ -138,7 +138,7 @@ async def generate_from_phonemes(
|
|||
torch.cuda.empty_cache()
|
||||
|
||||
# Convert to WAV bytes
|
||||
wav_bytes = AudioService.convert_audio(
|
||||
wav_bytes = await AudioService.convert_audio(
|
||||
audio, 24000, "wav", is_first_chunk=True, is_last_chunk=True, stream=False,
|
||||
)
|
||||
|
||||
|
|
|
@ -218,7 +218,7 @@ async def create_speech(
|
|||
)
|
||||
|
||||
# Convert to requested format
|
||||
content = AudioService.convert_audio(
|
||||
content = await AudioService.convert_audio(
|
||||
audio, 24000, request.response_format, is_first_chunk=True, stream=False
|
||||
)
|
||||
|
||||
|
|
|
@ -20,21 +20,26 @@ class AudioNormalizer:
|
|||
self.sample_rate = 24000 # Sample rate of the audio
|
||||
self.samples_to_trim = int(self.chunk_trim_ms * self.sample_rate / 1000)
|
||||
|
||||
def normalize(
|
||||
self, audio_data: np.ndarray, is_last_chunk: bool = False
|
||||
) -> np.ndarray:
|
||||
"""Convert audio data to int16 range and trim chunk boundaries"""
|
||||
async def normalize(self, audio_data: np.ndarray) -> np.ndarray:
|
||||
"""Convert audio data to int16 range and trim silence from start and end
|
||||
|
||||
Args:
|
||||
audio_data: Input audio data as numpy array
|
||||
|
||||
Returns:
|
||||
Normalized and trimmed audio data
|
||||
"""
|
||||
if len(audio_data) == 0:
|
||||
raise ValueError("Audio data cannot be empty")
|
||||
|
||||
# Simple float32 to int16 conversion
|
||||
# Convert to float32 for processing
|
||||
audio_float = audio_data.astype(np.float32)
|
||||
|
||||
# Trim for non-final chunks
|
||||
if not is_last_chunk and len(audio_float) > self.samples_to_trim:
|
||||
audio_float = audio_float[:-self.samples_to_trim]
|
||||
# Trim start and end if enough samples
|
||||
if len(audio_float) > (2 * self.samples_to_trim):
|
||||
audio_float = audio_float[self.samples_to_trim:-self.samples_to_trim]
|
||||
|
||||
# Direct scaling like the non-streaming version
|
||||
# Scale to int16 range
|
||||
return (audio_float * 32767).astype(np.int16)
|
||||
|
||||
|
||||
|
@ -59,7 +64,7 @@ class AudioService:
|
|||
}
|
||||
|
||||
@staticmethod
|
||||
def convert_audio(
|
||||
async def convert_audio(
|
||||
audio_data: np.ndarray,
|
||||
sample_rate: int,
|
||||
output_format: str,
|
||||
|
@ -99,9 +104,7 @@ class AudioService:
|
|||
# Always normalize audio to ensure proper amplitude scaling
|
||||
if normalizer is None:
|
||||
normalizer = AudioNormalizer()
|
||||
normalized_audio = normalizer.normalize(
|
||||
audio_data, is_last_chunk=is_last_chunk
|
||||
)
|
||||
normalized_audio = await normalizer.normalize(audio_data)
|
||||
|
||||
if output_format == "pcm":
|
||||
# Raw 16-bit PCM samples, no header
|
||||
|
|
|
@ -1,53 +1,74 @@
|
|||
"""Text chunking service"""
|
||||
"""Text chunking module for TTS processing"""
|
||||
|
||||
import re
|
||||
from typing import List, AsyncGenerator
|
||||
from . import semchunk_slim
|
||||
|
||||
from ...core.config import settings
|
||||
|
||||
|
||||
def split_text(text: str, max_chunk=None):
|
||||
"""Split text into chunks on natural pause points
|
||||
async def fallback_split(text: str, max_chars: int = 400) -> List[str]:
    """Emergency length control - only used if chunks are too long.

    Splits ``text`` on whitespace and greedily packs the resulting words,
    joined by a single space, into chunks of at most ``max_chars``
    characters. A word that is by itself longer than ``max_chars`` is cut
    into ``max_chars``-sized slices so that it cannot produce an oversize
    chunk.

    Args:
        text: Text to split.
        max_chars: Maximum length of each returned chunk. For values below
            1 no slicing is performed and every word becomes its own chunk.

    Returns:
        The chunks in their original order; an empty list when ``text``
        contains no words.
    """
    # Break up words that could never fit, so the packing loop below only
    # sees pieces that respect the limit.
    pieces: List[str] = []
    for word in text.split():
        if 0 < max_chars < len(word):
            pieces.extend(
                word[start:start + max_chars]
                for start in range(0, len(word), max_chars)
            )
        else:
            pieces.append(word)

    chunks: List[str] = []
    current: List[str] = []
    current_len = 0

    for piece in pieces:
        # The "+ 1" accounts for the space that will join the piece to the
        # chunk. A chunk always receives at least one piece.
        if current and current_len + len(piece) + 1 > max_chars:
            chunks.append(" ".join(current))
            current, current_len = [], 0
        current_len += len(piece) + (1 if current else 0)
        current.append(piece)

    if current:
        chunks.append(" ".join(current))

    return chunks
|
||||
|
||||
async def split_text(text: str, max_chunk: int = None) -> AsyncGenerator[str, None]:
|
||||
"""Split text into TTS-friendly chunks
|
||||
|
||||
Args:
|
||||
text: Text to split into chunks
|
||||
max_chunk: Maximum chunk size (defaults to settings.max_chunk_size)
|
||||
max_chunk: Maximum chunk size (defaults to 400)
|
||||
|
||||
Yields:
|
||||
Text chunks suitable for TTS processing
|
||||
"""
|
||||
if max_chunk is None:
|
||||
max_chunk = settings.max_chunk_size
|
||||
|
||||
max_chunk = 400
|
||||
|
||||
if not isinstance(text, str):
|
||||
text = str(text) if text is not None else ""
|
||||
|
||||
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return
|
||||
|
||||
# First split into sentences
|
||||
sentences = re.split(r"(?<=[.!?])\s+", text)
|
||||
|
||||
for sentence in sentences:
|
||||
sentence = sentence.strip()
|
||||
if not sentence:
|
||||
|
||||
# Initialize chunker targeting ~300 chars to allow for expansion
|
||||
chunker = semchunk_slim.chunkerify(
|
||||
lambda t: len(t) // 5, # Simple length-based target
|
||||
chunk_size=60 # Target ~300 chars
|
||||
)
|
||||
|
||||
# Get initial chunks
|
||||
chunks = chunker(text)
|
||||
|
||||
# Process chunks
|
||||
for chunk in chunks:
|
||||
chunk = chunk.strip()
|
||||
if not chunk:
|
||||
continue
|
||||
|
||||
# For medium-length sentences, split on punctuation
|
||||
if len(sentence) > max_chunk: # Lower threshold for more consistent sizes
|
||||
# First try splitting on semicolons and colons
|
||||
parts = re.split(r"(?<=[;:])\s+", sentence)
|
||||
|
||||
for part in parts:
|
||||
part = part.strip()
|
||||
if not part:
|
||||
continue
|
||||
|
||||
# If part is still long, split on commas
|
||||
if len(part) > max_chunk:
|
||||
subparts = re.split(r"(?<=,)\s+", part)
|
||||
for subpart in subparts:
|
||||
subpart = subpart.strip()
|
||||
if subpart:
|
||||
yield subpart
|
||||
else:
|
||||
yield part
|
||||
|
||||
# Use fallback for any chunks that are too long
|
||||
if len(chunk) > max_chunk:
|
||||
for subchunk in await fallback_split(chunk, max_chunk):
|
||||
yield subchunk
|
||||
else:
|
||||
yield sentence
|
||||
yield chunk
|
||||
|
|
89
api/src/services/text_processing/semchunk_slim.py
Normal file
89
api/src/services/text_processing/semchunk_slim.py
Normal file
|
@ -0,0 +1,89 @@
|
|||
from __future__ import annotations
|
||||
import re
|
||||
from typing import Callable
|
||||
|
||||
# Prioritize sentence boundaries for TTS
|
||||
_NON_WHITESPACE_SEMANTIC_SPLITTERS = (
|
||||
'.', '!', '?', # Primary - sentence boundaries
|
||||
';', ':', # Secondary - major clause boundaries
|
||||
',', # Tertiary - minor clause boundaries
|
||||
'(', ')', '[', ']', '"', '"', "'", "'", "'", '"', '`', # Other punctuation
|
||||
'—', '…', # Dashes and ellipsis
|
||||
'/', '\\', '–', '&', '-', # Word joiners
|
||||
)
|
||||
"""Semantic splitters ordered by priority for TTS chunking"""
|
||||
|
||||
def _split_text(text: str) -> tuple[str, bool, list[str]]:
|
||||
"""Split text using the most semantically meaningful splitter possible."""
|
||||
|
||||
splitter_is_whitespace = True
|
||||
|
||||
# Try splitting at, in order:
|
||||
# - Newlines (natural paragraph breaks)
|
||||
# - Spaces (if no other splits possible)
|
||||
# - Semantic splitters (prioritizing sentence boundaries)
|
||||
if '\n' in text or '\r' in text:
|
||||
splitter = max(re.findall(r'[\r\n]+', text))
|
||||
|
||||
elif re.search(r'\s', text):
|
||||
splitter = max(re.findall(r'\s+', text))
|
||||
|
||||
else:
|
||||
# Find first semantic splitter present
|
||||
for splitter in _NON_WHITESPACE_SEMANTIC_SPLITTERS:
|
||||
if splitter in text:
|
||||
splitter_is_whitespace = False
|
||||
break
|
||||
else:
|
||||
return '', splitter_is_whitespace, list(text)
|
||||
|
||||
return splitter, splitter_is_whitespace, text.split(splitter)
|
||||
|
||||
class Chunker:
    """Greedy chunker that packs semantic splits up to a token budget.

    Attributes are set directly from the constructor arguments:
    ``chunk_size`` is the budget and ``token_counter`` is the callable
    used to measure a piece of text against that budget.
    """

    def __init__(self, chunk_size: int, token_counter: Callable[[str], int]) -> None:
        self.chunk_size = chunk_size
        self.token_counter = token_counter

    def __call__(self, text: str) -> list[str]:
        """Split text into chunks based on semantic boundaries.

        The text is split once with ``_split_text`` and the stripped,
        non-empty splits are merged back together (joined with the same
        separator) for as long as the merged text stays within
        ``chunk_size`` according to ``token_counter``.

        The budget is checked against the joined candidate text rather
        than a running sum of per-split counts: counters are not
        necessarily additive (a counter such as ``len(t) // 5`` is 0 for
        every short word), so a sum can badly underestimate the chunk.

        A single split that exceeds ``chunk_size`` on its own is emitted
        unchanged; it is not subdivided here.

        Args:
            text: Text to chunk. ``None`` is treated as empty and other
                non-string values are converted with ``str``.

        Returns:
            The chunks in order; an empty list for empty or
            whitespace-only input.
        """
        if not isinstance(text, str):
            text = str(text) if text is not None else ""

        text = text.strip()
        if not text:
            return []

        splitter, _, splits = _split_text(text)

        chunks: list[str] = []
        current = ""

        for piece in (part.strip() for part in splits):
            if not piece:
                continue

            # Every chunk receives at least one piece, even an oversize one.
            if not current:
                current = piece
                continue

            candidate = f"{current}{splitter}{piece}"
            if self.token_counter(candidate) <= self.chunk_size:
                current = candidate
            else:
                chunks.append(current)
                current = piece

        if current:
            chunks.append(current)

        return chunks
|
||||
|
||||
def chunkerify(token_counter: Callable[[str], int], chunk_size: int) -> Chunker:
    """Build a ``Chunker`` from a token counter and a chunk size.

    Args:
        token_counter: Callable passed to the chunker to measure text.
        chunk_size: Budget passed to the chunker.

    Returns:
        The newly constructed ``Chunker``.
    """
    chunker = Chunker(token_counter=token_counter, chunk_size=chunk_size)
    return chunker
|
|
@ -82,8 +82,11 @@ class TTSService:
|
|||
backend = self.model_manager.get_backend()
|
||||
voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device)
|
||||
|
||||
# Get all chunks upfront
|
||||
chunks = list(chunker.split_text(text))
|
||||
# Get chunks using async generator
|
||||
chunks = []
|
||||
async for chunk in chunker.split_text(text):
|
||||
chunks.append(chunk)
|
||||
|
||||
if not chunks:
|
||||
raise ValueError("No text chunks to process")
|
||||
|
||||
|
@ -162,8 +165,11 @@ class TTSService:
|
|||
backend = self.model_manager.get_backend()
|
||||
voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device)
|
||||
|
||||
# Get all chunks upfront
|
||||
chunks = list(chunker.split_text(text))
|
||||
# Get chunks using async generator
|
||||
chunks = []
|
||||
async for chunk in chunker.split_text(text):
|
||||
chunks.append(chunk)
|
||||
|
||||
if not chunks:
|
||||
raise ValueError("No text chunks to process")
|
||||
|
||||
|
@ -184,7 +190,7 @@ class TTSService:
|
|||
|
||||
if chunk_audio is not None:
|
||||
# Convert to bytes
|
||||
return AudioService.convert_audio(
|
||||
return await AudioService.convert_audio(
|
||||
chunk_audio,
|
||||
24000,
|
||||
output_format,
|
||||
|
|
|
@ -34,6 +34,7 @@ dependencies = [
|
|||
"html2text>=2024.2.26",
|
||||
"pydub>=0.25.1",
|
||||
"matplotlib>=3.10.0",
|
||||
"semchunk>=3.0.1"
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
|
Loading…
Add table
Reference in a new issue