Mirror of https://github.com/remsky/Kokoro-FastAPI.git (synced 2025-08-05 16:48:53 +00:00)

Commit ee1f7cde18 (parent 31b5e33408)
"Add async audio processing and semantic chunking support; flattened static audio trimming"

8 changed files with 180 additions and 60 deletions
@@ -77,7 +77,7 @@ class ModelConfig(BaseModel):
     voice_cache_size: int = Field(2, description="Maximum number of cached voices")
 
     # Model filenames
-    pytorch_model_file: str = Field("kokoro-v0_19.pth", description="PyTorch model filename")
+    pytorch_model_file: str = Field("kokoro-v0_19-half.pth", description="PyTorch model filename")
     onnx_model_file: str = Field("kokoro-v0_19.onnx", description="ONNX model filename")
 
     # Backend-specific configs
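The default PyTorch checkpoint now points at a half-precision file. As a point of reference only, here is a minimal sketch (not part of this commit) of how such a checkpoint could be produced from the full-precision one, assuming the .pth file is a flat state dict:

import torch

state = torch.load("kokoro-v0_19.pth", map_location="cpu")
half = {k: (v.half() if torch.is_tensor(v) and v.is_floating_point() else v)
        for k, v in state.items()}
torch.save(half, "kokoro-v0_19-half.pth")  # roughly halves the file size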
@@ -138,7 +138,7 @@ async def generate_from_phonemes(
         torch.cuda.empty_cache()
 
     # Convert to WAV bytes
-    wav_bytes = AudioService.convert_audio(
+    wav_bytes = await AudioService.convert_audio(
         audio, 24000, "wav", is_first_chunk=True, is_last_chunk=True, stream=False,
     )
 
@@ -218,7 +218,7 @@ async def create_speech(
         )
 
         # Convert to requested format
-        content = AudioService.convert_audio(
+        content = await AudioService.convert_audio(
             audio, 24000, request.response_format, is_first_chunk=True, stream=False
         )
 
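Since convert_audio is now a coroutine (see the AudioService hunks below), every call site must await it from async code. A minimal usage sketch, assuming AudioService is imported from the audio service module:

import asyncio
import numpy as np

async def demo() -> bytes:
    audio = np.zeros(24000, dtype=np.float32)  # one second of silence at 24 kHz
    return await AudioService.convert_audio(
        audio, 24000, "wav", is_first_chunk=True, is_last_chunk=True, stream=False,
    )

# wav_bytes = asyncio.run(demo())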
@@ -20,21 +20,26 @@ class AudioNormalizer:
         self.sample_rate = 24000  # Sample rate of the audio
         self.samples_to_trim = int(self.chunk_trim_ms * self.sample_rate / 1000)
 
-    def normalize(
-        self, audio_data: np.ndarray, is_last_chunk: bool = False
-    ) -> np.ndarray:
-        """Convert audio data to int16 range and trim chunk boundaries"""
+    async def normalize(self, audio_data: np.ndarray) -> np.ndarray:
+        """Convert audio data to int16 range and trim silence from start and end
+
+        Args:
+            audio_data: Input audio data as numpy array
+
+        Returns:
+            Normalized and trimmed audio data
+        """
         if len(audio_data) == 0:
             raise ValueError("Audio data cannot be empty")
 
-        # Simple float32 to int16 conversion
+        # Convert to float32 for processing
         audio_float = audio_data.astype(np.float32)
 
-        # Trim for non-final chunks
-        if not is_last_chunk and len(audio_float) > self.samples_to_trim:
-            audio_float = audio_float[:-self.samples_to_trim]
+        # Trim start and end if enough samples
+        if len(audio_float) > (2 * self.samples_to_trim):
+            audio_float = audio_float[self.samples_to_trim:-self.samples_to_trim]
 
-        # Direct scaling like the non-streaming version
+        # Scale to int16 range
         return (audio_float * 32767).astype(np.int16)
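The normalizer now trims silence symmetrically from both ends of every chunk, instead of trimming only the tail of non-final chunks. A worked example of the sample arithmetic, assuming chunk_trim_ms = 50 (the actual value is set elsewhere in the class):

import numpy as np

sample_rate = 24000
chunk_trim_ms = 50  # assumed for illustration
samples_to_trim = int(chunk_trim_ms * sample_rate / 1000)  # 1200 samples

audio = np.random.uniform(-1, 1, sample_rate).astype(np.float32)  # 1 s of audio
if len(audio) > 2 * samples_to_trim:
    audio = audio[samples_to_trim:-samples_to_trim]  # symmetric trim
assert len(audio) == 24000 - 2 * 1200  # 21600 samples remain
pcm = (audio * 32767).astype(np.int16)  # scale to int16 range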
@@ -59,7 +64,7 @@ class AudioService:
     }
 
     @staticmethod
-    def convert_audio(
+    async def convert_audio(
         audio_data: np.ndarray,
         sample_rate: int,
         output_format: str,
@@ -99,9 +104,7 @@ class AudioService:
         # Always normalize audio to ensure proper amplitude scaling
         if normalizer is None:
             normalizer = AudioNormalizer()
-        normalized_audio = normalizer.normalize(
-            audio_data, is_last_chunk=is_last_chunk
-        )
+        normalized_audio = await normalizer.normalize(audio_data)
 
         if output_format == "pcm":
             # Raw 16-bit PCM samples, no header
@@ -1,19 +1,48 @@
-"""Text chunking service"""
+"""Text chunking module for TTS processing"""
 
-import re
+from typing import List, AsyncGenerator
 
-from ...core.config import settings
+from . import semchunk_slim
 
 
-def split_text(text: str, max_chunk=None):
-    """Split text into chunks on natural pause points
+async def fallback_split(text: str, max_chars: int = 400) -> List[str]:
+    """Emergency length control - only used if chunks are too long"""
+    words = text.split()
+    chunks = []
+    current = []
+    current_len = 0
+
+    for word in words:
+        # Always include at least one word per chunk
+        if not current:
+            current.append(word)
+            current_len = len(word)
+            continue
+
+        # Check if adding word would exceed limit
+        if current_len + len(word) + 1 <= max_chars:
+            current.append(word)
+            current_len += len(word) + 1
+        else:
+            chunks.append(" ".join(current))
+            current = [word]
+            current_len = len(word)
+
+    if current:
+        chunks.append(" ".join(current))
+
+    return chunks
+
+
+async def split_text(text: str, max_chunk: int = None) -> AsyncGenerator[str, None]:
+    """Split text into TTS-friendly chunks
 
     Args:
         text: Text to split into chunks
-        max_chunk: Maximum chunk size (defaults to settings.max_chunk_size)
+        max_chunk: Maximum chunk size (defaults to 400)
+
+    Yields:
+        Text chunks suitable for TTS processing
     """
     if max_chunk is None:
-        max_chunk = settings.max_chunk_size
+        max_chunk = 400
 
     if not isinstance(text, str):
         text = str(text) if text is not None else ""
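fallback_split packs words greedily, counting the joining space, and never emits an empty chunk. A quick sketch of its behavior (it is declared async purely for interface consistency, so it must be awaited; assumes fallback_split is imported from this module):

import asyncio

async def demo():
    print(await fallback_split("one two three four five", max_chars=9))
    # -> ['one two', 'three', 'four five']  (each chunk <= 9 chars, spaces included)

# asyncio.run(demo())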
@@ -22,32 +51,24 @@ def split_text(text, max_chunk=None):
     if not text:
         return
 
-    # First split into sentences
-    sentences = re.split(r"(?<=[.!?])\s+", text)
+    # Initialize chunker targeting ~300 chars to allow for expansion
+    chunker = semchunk_slim.chunkerify(
+        lambda t: len(t) // 5,  # Simple length-based target
+        chunk_size=60  # Target ~300 chars
+    )
 
-    for sentence in sentences:
-        sentence = sentence.strip()
-        if not sentence:
+    # Get initial chunks
+    chunks = chunker(text)
+
+    # Process chunks
+    for chunk in chunks:
+        chunk = chunk.strip()
+        if not chunk:
             continue
 
-        # For medium-length sentences, split on punctuation
-        if len(sentence) > max_chunk:  # Lower threshold for more consistent sizes
-            # First try splitting on semicolons and colons
-            parts = re.split(r"(?<=[;:])\s+", sentence)
-
-            for part in parts:
-                part = part.strip()
-                if not part:
-                    continue
-
-                # If part is still long, split on commas
-                if len(part) > max_chunk:
-                    subparts = re.split(r"(?<=,)\s+", part)
-                    for subpart in subparts:
-                        subpart = subpart.strip()
-                        if subpart:
-                            yield subpart
-                else:
-                    yield part
+        # Use fallback for any chunks that are too long
+        if len(chunk) > max_chunk:
+            for subchunk in await fallback_split(chunk, max_chunk):
+                yield subchunk
         else:
-            yield sentence
+            yield chunk
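Because split_text is now an async generator, callers iterate with async for rather than wrapping it in list(). A minimal consumption sketch, assuming this module is imported as chunker:

import asyncio

async def collect(text: str) -> list[str]:
    return [chunk async for chunk in chunker.split_text(text)]

# chunks = asyncio.run(collect("Some long article text ..."))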
api/src/services/text_processing/semchunk_slim.py (new file, 89 lines)
@@ -0,0 +1,89 @@
+from __future__ import annotations
+import re
+from typing import Callable
+
+# Prioritize sentence boundaries for TTS
+_NON_WHITESPACE_SEMANTIC_SPLITTERS = (
+    '.', '!', '?',  # Primary - sentence boundaries
+    ';', ':',  # Secondary - major clause boundaries
+    ',',  # Tertiary - minor clause boundaries
+    '(', ')', '[', ']', '"', '“', '”', "'", '‘', '’', '`',  # Other punctuation
+    '—', '…',  # Dashes and ellipsis
+    '/', '\\', '–', '&', '-',  # Word joiners
+)
+"""Semantic splitters ordered by priority for TTS chunking"""
+
+
+def _split_text(text: str) -> tuple[str, bool, list[str]]:
+    """Split text using the most semantically meaningful splitter possible."""
+    splitter_is_whitespace = True
+
+    # Try splitting at, in order:
+    # - Newlines (natural paragraph breaks)
+    # - Spaces (if no other splits possible)
+    # - Semantic splitters (prioritizing sentence boundaries)
+    if '\n' in text or '\r' in text:
+        splitter = max(re.findall(r'[\r\n]+', text))
+
+    elif re.search(r'\s', text):
+        splitter = max(re.findall(r'\s+', text))
+
+    else:
+        # Find first semantic splitter present
+        for splitter in _NON_WHITESPACE_SEMANTIC_SPLITTERS:
+            if splitter in text:
+                splitter_is_whitespace = False
+                break
+        else:
+            return '', splitter_is_whitespace, list(text)
+
+    return splitter, splitter_is_whitespace, text.split(splitter)
+
+
+class Chunker:
+    def __init__(self, chunk_size: int, token_counter: Callable[[str], int]) -> None:
+        self.chunk_size = chunk_size
+        self.token_counter = token_counter
+
+    def __call__(self, text: str) -> list[str]:
+        """Split text into chunks based on semantic boundaries."""
+        if not isinstance(text, str):
+            text = str(text) if text is not None else ""
+
+        text = text.strip()
+        if not text:
+            return []
+
+        # Split the text
+        splitter, _, splits = _split_text(text)
+
+        chunks = []
+        current_chunk = []
+        current_len = 0
+
+        for split in splits:
+            split = split.strip()
+            if not split:
+                continue
+
+            # Check if adding this split would exceed chunk size
+            split_len = self.token_counter(split)
+            if current_len + split_len <= self.chunk_size:
+                current_chunk.append(split)
+                current_len += split_len
+            else:
+                # Save current chunk if it exists
+                if current_chunk:
+                    chunks.append(splitter.join(current_chunk))
+                # Start new chunk with current split
+                current_chunk = [split]
+                current_len = split_len
+
+        # Add final chunk if it exists
+        if current_chunk:
+            chunks.append(splitter.join(current_chunk))
+
+        return chunks
+
+
+def chunkerify(token_counter: Callable[[str], int], chunk_size: int) -> Chunker:
+    """Create a chunker with the specified token counter and chunk size."""
+    return Chunker(chunk_size=chunk_size, token_counter=token_counter)
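The vendored module exposes the same chunkerify entry point the chunker uses. A short sketch driving it directly with the heuristic token counter from split_text (roughly 5 characters per "token", so chunk_size=60 targets about 300 characters; short inputs come back as a single chunk):

chunker = chunkerify(lambda t: len(t) // 5, chunk_size=60)
for chunk in chunker("First sentence. Second sentence! A third, longer one follows?"):
    print(len(chunk), chunk)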
@@ -82,8 +82,11 @@ class TTSService:
         backend = self.model_manager.get_backend()
         voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device)
 
-        # Get all chunks upfront
-        chunks = list(chunker.split_text(text))
+        # Get chunks using async generator
+        chunks = []
+        async for chunk in chunker.split_text(text):
+            chunks.append(chunk)
+
         if not chunks:
             raise ValueError("No text chunks to process")
 
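An equivalent, slightly tighter way to drain the generator inside these async methods is an async comprehension (PEP 530):

chunks = [chunk async for chunk in chunker.split_text(text)]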
@@ -162,8 +165,11 @@ class TTSService:
         backend = self.model_manager.get_backend()
         voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device)
 
-        # Get all chunks upfront
-        chunks = list(chunker.split_text(text))
+        # Get chunks using async generator
+        chunks = []
+        async for chunk in chunker.split_text(text):
+            chunks.append(chunk)
+
         if not chunks:
             raise ValueError("No text chunks to process")
 
@@ -184,7 +190,7 @@ class TTSService:
 
         if chunk_audio is not None:
             # Convert to bytes
-            return AudioService.convert_audio(
+            return await AudioService.convert_audio(
                 chunk_audio,
                 24000,
                 output_format,
@@ -34,6 +34,7 @@ dependencies = [
     "html2text>=2024.2.26",
     "pydub>=0.25.1",
     "matplotlib>=3.10.0",
+    "semchunk>=3.0.1"
 ]
 
 [project.optional-dependencies]