mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00
Refactor audio processing and cleanup: remove unused chunker, enhance StreamingAudioWriter for better MP3 handling, and improve text processing compatibility.
This commit is contained in:
parent
8a60a2b90c
commit
75889e157d
6 changed files with 324 additions and 192 deletions
|
@ -30,9 +30,6 @@ class AudioNormalizer:
|
||||||
Returns:
|
Returns:
|
||||||
Normalized and trimmed audio data
|
Normalized and trimmed audio data
|
||||||
"""
|
"""
|
||||||
if len(audio_data) == 0:
|
|
||||||
raise ValueError("Audio data cannot be empty")
|
|
||||||
|
|
||||||
# Convert to float32 for processing
|
# Convert to float32 for processing
|
||||||
audio_float = audio_data.astype(np.float32)
|
audio_float = audio_data.astype(np.float32)
|
||||||
|
|
||||||
|
@ -102,17 +99,14 @@ class AudioService:
|
||||||
)
|
)
|
||||||
writer = AudioService._writers[writer_key]
|
writer = AudioService._writers[writer_key]
|
||||||
|
|
||||||
# Write the current chunk
|
# Write chunk or finalize
|
||||||
|
if is_last_chunk:
|
||||||
|
chunk_data = writer.write_chunk(finalize=True)
|
||||||
|
del AudioService._writers[writer_key]
|
||||||
|
else:
|
||||||
chunk_data = writer.write_chunk(normalized_audio)
|
chunk_data = writer.write_chunk(normalized_audio)
|
||||||
|
|
||||||
# Handle last chunk and cleanup
|
return chunk_data if chunk_data else b''
|
||||||
if is_last_chunk:
|
|
||||||
final_data = writer.close()
|
|
||||||
if final_data:
|
|
||||||
chunk_data += final_data
|
|
||||||
del AudioService._writers[writer_key]
|
|
||||||
|
|
||||||
return chunk_data
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error converting audio stream to {output_format}: {str(e)}")
|
logger.error(f"Error converting audio stream to {output_format}: {str(e)}")
|
||||||
|
|
|
@ -33,7 +33,9 @@ class StreamingAudioWriter:
|
||||||
elif self.format == "mp3":
|
elif self.format == "mp3":
|
||||||
# For MP3, we'll use pydub's incremental writer
|
# For MP3, we'll use pydub's incremental writer
|
||||||
self.buffer = BytesIO()
|
self.buffer = BytesIO()
|
||||||
self.encoder = AudioSegment.from_mono_audiosegments()
|
self.segments = [] # Store segments until we have enough data
|
||||||
|
# Initialize an empty AudioSegment as our encoder
|
||||||
|
self.encoder = AudioSegment.silent(duration=0, frame_rate=self.sample_rate)
|
||||||
|
|
||||||
def _write_wav_header(self) -> bytes:
|
def _write_wav_header(self) -> bytes:
|
||||||
"""Write WAV header with correct streaming format"""
|
"""Write WAV header with correct streaming format"""
|
||||||
|
@ -53,12 +55,45 @@ class StreamingAudioWriter:
|
||||||
header.write(struct.pack('<L', 0)) # Placeholder for data size
|
header.write(struct.pack('<L', 0)) # Placeholder for data size
|
||||||
return header.getvalue()
|
return header.getvalue()
|
||||||
|
|
||||||
def write_chunk(self, audio_data: np.ndarray) -> bytes:
|
def write_chunk(self, audio_data: Optional[np.ndarray] = None, finalize: bool = False) -> bytes:
|
||||||
"""Write a chunk of audio data and return bytes in the target format"""
|
"""Write a chunk of audio data and return bytes in the target format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
audio_data: Audio data to write, or None if finalizing
|
||||||
|
finalize: Whether this is the final write to close the stream
|
||||||
|
"""
|
||||||
buffer = BytesIO()
|
buffer = BytesIO()
|
||||||
|
|
||||||
|
if finalize:
|
||||||
if self.format == "wav":
|
if self.format == "wav":
|
||||||
# For WAV, we write raw PCM after the first chunk
|
# Write final WAV header with correct sizes
|
||||||
|
buffer.write(b'RIFF')
|
||||||
|
buffer.write(struct.pack('<L', self.bytes_written + 36))
|
||||||
|
buffer.write(b'WAVE')
|
||||||
|
buffer.write(b'fmt ')
|
||||||
|
buffer.write(struct.pack('<L', 16))
|
||||||
|
buffer.write(struct.pack('<H', 1))
|
||||||
|
buffer.write(struct.pack('<H', self.channels))
|
||||||
|
buffer.write(struct.pack('<L', self.sample_rate))
|
||||||
|
buffer.write(struct.pack('<L', self.sample_rate * self.channels * 2))
|
||||||
|
buffer.write(struct.pack('<H', self.channels * 2))
|
||||||
|
buffer.write(struct.pack('<H', 16))
|
||||||
|
buffer.write(b'data')
|
||||||
|
buffer.write(struct.pack('<L', self.bytes_written))
|
||||||
|
elif self.format == "ogg":
|
||||||
|
self.writer.close()
|
||||||
|
elif self.format == "mp3":
|
||||||
|
# Final export of any remaining audio
|
||||||
|
if hasattr(self, 'encoder') and len(self.encoder) > 0:
|
||||||
|
self.encoder.export(buffer, format="mp3", bitrate="192k", parameters=["-q:a", "2"])
|
||||||
|
self.encoder = None
|
||||||
|
return buffer.getvalue()
|
||||||
|
|
||||||
|
if audio_data is None or len(audio_data) == 0:
|
||||||
|
return b''
|
||||||
|
|
||||||
|
if self.format == "wav":
|
||||||
|
# For WAV, write raw PCM after the first chunk
|
||||||
if self.bytes_written == 0:
|
if self.bytes_written == 0:
|
||||||
buffer.write(self._write_wav_header())
|
buffer.write(self._write_wav_header())
|
||||||
buffer.write(audio_data.tobytes())
|
buffer.write(audio_data.tobytes())
|
||||||
|
@ -83,8 +118,20 @@ class StreamingAudioWriter:
|
||||||
sample_width=audio_data.dtype.itemsize,
|
sample_width=audio_data.dtype.itemsize,
|
||||||
channels=self.channels
|
channels=self.channels
|
||||||
)
|
)
|
||||||
self.encoder += segment
|
|
||||||
self.encoder.export(buffer, format="mp3")
|
# Add segment to encoder
|
||||||
|
self.encoder = self.encoder + segment
|
||||||
|
|
||||||
|
# Export current state to buffer
|
||||||
|
self.encoder.export(buffer, format="mp3", bitrate="192k", parameters=["-q:a", "2"])
|
||||||
|
|
||||||
|
# Get the encoded data
|
||||||
|
encoded_data = buffer.getvalue()
|
||||||
|
|
||||||
|
# Reset encoder to prevent memory growth
|
||||||
|
self.encoder = AudioSegment.silent(duration=0, frame_rate=self.sample_rate)
|
||||||
|
|
||||||
|
return encoded_data
|
||||||
|
|
||||||
return buffer.getvalue()
|
return buffer.getvalue()
|
||||||
|
|
||||||
|
|
|
@ -1,28 +1,19 @@
|
||||||
"""Text processing pipeline."""
|
"""Text processing pipeline."""
|
||||||
|
|
||||||
from .chunker import split_text
|
|
||||||
from .normalizer import normalize_text
|
from .normalizer import normalize_text
|
||||||
from .phonemizer import phonemize
|
from .phonemizer import phonemize
|
||||||
from .vocabulary import tokenize
|
from .vocabulary import tokenize
|
||||||
|
from .text_processor import process_text_chunk, smart_split
|
||||||
|
|
||||||
|
def process_text(text: str) -> list[int]:
|
||||||
|
"""Process text into token IDs (for backward compatibility)."""
|
||||||
|
return process_text_chunk(text)
|
||||||
|
|
||||||
def process_text(text: str, language: str = "a") -> list[int]:
|
__all__ = [
|
||||||
"""Process text through the full pipeline.
|
'normalize_text',
|
||||||
|
'phonemize',
|
||||||
Args:
|
'tokenize',
|
||||||
text: Input text
|
'process_text',
|
||||||
language: Language code ('a' for US English, 'b' for British English)
|
'process_text_chunk',
|
||||||
|
'smart_split'
|
||||||
Returns:
|
]
|
||||||
List of token IDs
|
|
||||||
|
|
||||||
Note:
|
|
||||||
The pipeline:
|
|
||||||
1. Converts text to phonemes using phonemizer
|
|
||||||
2. Converts phonemes to token IDs using vocabulary
|
|
||||||
"""
|
|
||||||
# Convert text to phonemes
|
|
||||||
phonemes = phonemize(text, language=language)
|
|
||||||
|
|
||||||
# Convert phonemes to token IDs
|
|
||||||
return tokenize(phonemes)
|
|
||||||
|
|
|
@ -1,89 +0,0 @@
|
||||||
from __future__ import annotations
|
|
||||||
import re
|
|
||||||
from typing import Callable
|
|
||||||
|
|
||||||
# Prioritize sentence boundaries for TTS
|
|
||||||
_NON_WHITESPACE_SEMANTIC_SPLITTERS = (
|
|
||||||
'.', '!', '?', # Primary - sentence boundaries
|
|
||||||
';', ':', # Secondary - major clause boundaries
|
|
||||||
',', # Tertiary - minor clause boundaries
|
|
||||||
'(', ')', '[', ']', '"', '"', "'", "'", "'", '"', '`', # Other punctuation
|
|
||||||
'—', '…', # Dashes and ellipsis
|
|
||||||
'/', '\\', '–', '&', '-', # Word joiners
|
|
||||||
)
|
|
||||||
"""Semantic splitters ordered by priority for TTS chunking"""
|
|
||||||
|
|
||||||
def _split_text(text: str) -> tuple[str, bool, list[str]]:
|
|
||||||
"""Split text using the most semantically meaningful splitter possible."""
|
|
||||||
|
|
||||||
splitter_is_whitespace = True
|
|
||||||
|
|
||||||
# Try splitting at, in order:
|
|
||||||
# - Newlines (natural paragraph breaks)
|
|
||||||
# - Spaces (if no other splits possible)
|
|
||||||
# - Semantic splitters (prioritizing sentence boundaries)
|
|
||||||
if '\n' in text or '\r' in text:
|
|
||||||
splitter = max(re.findall(r'[\r\n]+', text))
|
|
||||||
|
|
||||||
elif re.search(r'\s', text):
|
|
||||||
splitter = max(re.findall(r'\s+', text))
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Find first semantic splitter present
|
|
||||||
for splitter in _NON_WHITESPACE_SEMANTIC_SPLITTERS:
|
|
||||||
if splitter in text:
|
|
||||||
splitter_is_whitespace = False
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
return '', splitter_is_whitespace, list(text)
|
|
||||||
|
|
||||||
return splitter, splitter_is_whitespace, text.split(splitter)
|
|
||||||
|
|
||||||
class Chunker:
|
|
||||||
def __init__(self, chunk_size: int, token_counter: Callable[[str], int]) -> None:
|
|
||||||
self.chunk_size = chunk_size
|
|
||||||
self.token_counter = token_counter
|
|
||||||
|
|
||||||
def __call__(self, text: str) -> list[str]:
|
|
||||||
"""Split text into chunks based on semantic boundaries."""
|
|
||||||
if not isinstance(text, str):
|
|
||||||
text = str(text) if text is not None else ""
|
|
||||||
|
|
||||||
text = text.strip()
|
|
||||||
if not text:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# Split the text
|
|
||||||
splitter, _, splits = _split_text(text)
|
|
||||||
|
|
||||||
chunks = []
|
|
||||||
current_chunk = []
|
|
||||||
current_len = 0
|
|
||||||
|
|
||||||
for split in splits:
|
|
||||||
split = split.strip()
|
|
||||||
if not split:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check if adding this split would exceed chunk size
|
|
||||||
split_len = self.token_counter(split)
|
|
||||||
if current_len + split_len <= self.chunk_size:
|
|
||||||
current_chunk.append(split)
|
|
||||||
current_len += split_len
|
|
||||||
else:
|
|
||||||
# Save current chunk if it exists
|
|
||||||
if current_chunk:
|
|
||||||
chunks.append(splitter.join(current_chunk))
|
|
||||||
# Start new chunk with current split
|
|
||||||
current_chunk = [split]
|
|
||||||
current_len = split_len
|
|
||||||
|
|
||||||
# Add final chunk if it exists
|
|
||||||
if current_chunk:
|
|
||||||
chunks.append(splitter.join(current_chunk))
|
|
||||||
|
|
||||||
return chunks
|
|
||||||
|
|
||||||
def chunkerify(token_counter: Callable[[str], int], chunk_size: int) -> Chunker:
|
|
||||||
"""Create a chunker with the specified token counter and chunk size."""
|
|
||||||
return Chunker(chunk_size=chunk_size, token_counter=token_counter)
|
|
177
api/src/services/text_processing/text_processor.py
Normal file
177
api/src/services/text_processing/text_processor.py
Normal file
|
@ -0,0 +1,177 @@
|
||||||
|
"""Unified text processing for TTS with smart chunking."""
|
||||||
|
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from typing import AsyncGenerator, List, Tuple
|
||||||
|
from loguru import logger
|
||||||
|
from .phonemizer import phonemize
|
||||||
|
from .normalizer import normalize_text
|
||||||
|
from .vocabulary import tokenize
|
||||||
|
|
||||||
|
def process_text_chunk(text: str, language: str = "a") -> List[int]:
|
||||||
|
"""Process a chunk of text through normalization, phonemization, and tokenization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text chunk to process
|
||||||
|
language: Language code for phonemization
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of token IDs
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Normalize
|
||||||
|
t0 = time.time()
|
||||||
|
normalized = normalize_text(text)
|
||||||
|
t1 = time.time()
|
||||||
|
logger.debug(f"Normalization took {(t1-t0)*1000:.2f}ms for {len(text)} chars")
|
||||||
|
|
||||||
|
# Phonemize
|
||||||
|
t0 = time.time()
|
||||||
|
phonemes = phonemize(normalized, language, normalize=False) # Already normalized
|
||||||
|
t1 = time.time()
|
||||||
|
logger.debug(f"Phonemization took {(t1-t0)*1000:.2f}ms for {len(normalized)} chars")
|
||||||
|
|
||||||
|
# Convert to token IDs
|
||||||
|
t0 = time.time()
|
||||||
|
tokens = tokenize(phonemes)
|
||||||
|
t1 = time.time()
|
||||||
|
logger.debug(f"Tokenization took {(t1-t0)*1000:.2f}ms for {len(phonemes)} chars")
|
||||||
|
|
||||||
|
total_time = time.time() - start_time
|
||||||
|
logger.debug(f"Total processing took {total_time*1000:.2f}ms for chunk: '{text[:50]}...'")
|
||||||
|
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
def process_text(text: str, language: str = "a") -> List[int]:
|
||||||
|
"""Process text into token IDs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to process
|
||||||
|
language: Language code for phonemization
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of token IDs
|
||||||
|
"""
|
||||||
|
if not isinstance(text, str):
|
||||||
|
text = str(text) if text is not None else ""
|
||||||
|
|
||||||
|
text = text.strip()
|
||||||
|
if not text:
|
||||||
|
return []
|
||||||
|
|
||||||
|
return process_text_chunk(text, language)
|
||||||
|
|
||||||
|
async def smart_split(text: str, max_tokens: int = 500) -> AsyncGenerator[Tuple[str, List[int]], None]:
|
||||||
|
"""Split text into semantically meaningful chunks while respecting token limits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Input text to split
|
||||||
|
max_tokens: Maximum tokens per chunk
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Tuples of (text chunk, token IDs) where token count is <= max_tokens
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
chunk_count = 0
|
||||||
|
total_chars = len(text)
|
||||||
|
logger.info(f"Starting text split for {total_chars} characters with {max_tokens} max tokens")
|
||||||
|
|
||||||
|
# Split on major punctuation first
|
||||||
|
sentences = re.split(r'([.!?;:])', text)
|
||||||
|
|
||||||
|
current_chunk = []
|
||||||
|
current_token_count = 0
|
||||||
|
|
||||||
|
for i in range(0, len(sentences), 2):
|
||||||
|
# Get sentence and its punctuation (if any)
|
||||||
|
sentence = sentences[i].strip()
|
||||||
|
punct = sentences[i + 1] if i + 1 < len(sentences) else ""
|
||||||
|
|
||||||
|
if not sentence:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Process sentence to get token count
|
||||||
|
sentence_with_punct = sentence + punct
|
||||||
|
tokens = process_text_chunk(sentence_with_punct)
|
||||||
|
token_count = len(tokens)
|
||||||
|
logger.debug(f"Sentence '{sentence_with_punct[:50]}...' has {token_count} tokens")
|
||||||
|
|
||||||
|
# If this single sentence is too long, split on commas
|
||||||
|
if token_count > max_tokens:
|
||||||
|
logger.debug(f"Sentence exceeds token limit, splitting on commas")
|
||||||
|
clause_splits = re.split(r'([,])', sentence_with_punct)
|
||||||
|
for j in range(0, len(clause_splits), 2):
|
||||||
|
clause = clause_splits[j].strip()
|
||||||
|
comma = clause_splits[j + 1] if j + 1 < len(clause_splits) else ""
|
||||||
|
|
||||||
|
if not clause:
|
||||||
|
continue
|
||||||
|
|
||||||
|
clause_with_punct = clause + comma
|
||||||
|
clause_tokens = process_text_chunk(clause_with_punct)
|
||||||
|
|
||||||
|
# If still too long, do a hard split on words
|
||||||
|
if len(clause_tokens) > max_tokens:
|
||||||
|
logger.debug(f"Clause exceeds token limit, splitting on words")
|
||||||
|
words = clause_with_punct.split()
|
||||||
|
temp_chunk = []
|
||||||
|
temp_tokens = []
|
||||||
|
|
||||||
|
for word in words:
|
||||||
|
word_tokens = process_text_chunk(word)
|
||||||
|
if len(temp_tokens) + len(word_tokens) > max_tokens:
|
||||||
|
if temp_chunk: # Don't yield empty chunks
|
||||||
|
chunk_text = " ".join(temp_chunk)
|
||||||
|
chunk_count += 1
|
||||||
|
logger.info(f"Yielding word-split chunk {chunk_count}: '{chunk_text[:50]}...' ({len(temp_tokens)} tokens)")
|
||||||
|
yield chunk_text, temp_tokens
|
||||||
|
temp_chunk = [word]
|
||||||
|
temp_tokens = word_tokens
|
||||||
|
else:
|
||||||
|
temp_chunk.append(word)
|
||||||
|
temp_tokens.extend(word_tokens)
|
||||||
|
|
||||||
|
if temp_chunk: # Don't forget the last chunk
|
||||||
|
chunk_text = " ".join(temp_chunk)
|
||||||
|
chunk_count += 1
|
||||||
|
logger.info(f"Yielding final word-split chunk {chunk_count}: '{chunk_text[:50]}...' ({len(temp_tokens)} tokens)")
|
||||||
|
yield chunk_text, temp_tokens
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Check if adding this clause would exceed the limit
|
||||||
|
if current_token_count + len(clause_tokens) > max_tokens:
|
||||||
|
if current_chunk: # Don't yield empty chunks
|
||||||
|
chunk_text = " ".join(current_chunk)
|
||||||
|
chunk_count += 1
|
||||||
|
logger.info(f"Yielding clause-split chunk {chunk_count}: '{chunk_text[:50]}...' ({current_token_count} tokens)")
|
||||||
|
yield chunk_text, process_text_chunk(chunk_text)
|
||||||
|
current_chunk = [clause_with_punct]
|
||||||
|
current_token_count = len(clause_tokens)
|
||||||
|
else:
|
||||||
|
current_chunk.append(clause_with_punct)
|
||||||
|
current_token_count += len(clause_tokens)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Check if adding this sentence would exceed the limit
|
||||||
|
if current_token_count + token_count > max_tokens:
|
||||||
|
if current_chunk: # Don't yield empty chunks
|
||||||
|
chunk_text = " ".join(current_chunk)
|
||||||
|
chunk_count += 1
|
||||||
|
logger.info(f"Yielding sentence-split chunk {chunk_count}: '{chunk_text[:50]}...' ({current_token_count} tokens)")
|
||||||
|
yield chunk_text, process_text_chunk(chunk_text)
|
||||||
|
current_chunk = [sentence_with_punct]
|
||||||
|
current_token_count = token_count
|
||||||
|
else:
|
||||||
|
current_chunk.append(sentence_with_punct)
|
||||||
|
current_token_count += token_count
|
||||||
|
|
||||||
|
# Don't forget the last chunk
|
||||||
|
if current_chunk:
|
||||||
|
chunk_text = " ".join(current_chunk)
|
||||||
|
chunk_count += 1
|
||||||
|
logger.info(f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}...' ({current_token_count} tokens)")
|
||||||
|
yield chunk_text, process_text_chunk(chunk_text)
|
||||||
|
|
||||||
|
total_time = time.time() - start_time
|
||||||
|
logger.info(f"Text splitting completed in {total_time*1000:.2f}ms, produced {chunk_count} chunks")
|
|
@ -1,11 +1,10 @@
|
||||||
"""TTS service using model and voice managers."""
|
"""TTS service using model and voice managers."""
|
||||||
|
|
||||||
import io
|
|
||||||
import time
|
import time
|
||||||
from typing import List, Tuple, Optional, AsyncGenerator, Union
|
from typing import List, Tuple, Optional, AsyncGenerator, Union
|
||||||
|
import asyncio
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import scipy.io.wavfile as wavfile
|
|
||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
@ -13,10 +12,7 @@ from ..core.config import settings
|
||||||
from ..inference.model_manager import get_manager as get_model_manager
|
from ..inference.model_manager import get_manager as get_model_manager
|
||||||
from ..inference.voice_manager import get_manager as get_voice_manager
|
from ..inference.voice_manager import get_manager as get_voice_manager
|
||||||
from .audio import AudioNormalizer, AudioService
|
from .audio import AudioNormalizer, AudioService
|
||||||
from .text_processing import chunker, normalize_text, process_text
|
from .text_processing.text_processor import process_text_chunk, smart_split
|
||||||
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
class TTSService:
|
class TTSService:
|
||||||
"""Text-to-speech service."""
|
"""Text-to-speech service."""
|
||||||
|
@ -40,7 +36,7 @@ class TTSService:
|
||||||
|
|
||||||
async def _process_chunk(
|
async def _process_chunk(
|
||||||
self,
|
self,
|
||||||
chunk: str,
|
tokens: List[int],
|
||||||
voice_tensor: torch.Tensor,
|
voice_tensor: torch.Tensor,
|
||||||
speed: float,
|
speed: float,
|
||||||
output_format: Optional[str] = None,
|
output_format: Optional[str] = None,
|
||||||
|
@ -48,10 +44,21 @@ class TTSService:
|
||||||
is_last: bool = False,
|
is_last: bool = False,
|
||||||
normalizer: Optional[AudioNormalizer] = None,
|
normalizer: Optional[AudioNormalizer] = None,
|
||||||
) -> Optional[Union[np.ndarray, bytes]]:
|
) -> Optional[Union[np.ndarray, bytes]]:
|
||||||
"""Process a single text chunk into audio."""
|
"""Process tokens into audio."""
|
||||||
async with self._chunk_semaphore:
|
async with self._chunk_semaphore:
|
||||||
try:
|
try:
|
||||||
tokens = process_text(chunk)
|
# Handle stream finalization
|
||||||
|
if is_last:
|
||||||
|
return await AudioService.convert_audio(
|
||||||
|
np.array([0], dtype=np.float32), # Dummy data for type checking
|
||||||
|
24000,
|
||||||
|
output_format,
|
||||||
|
is_first_chunk=False,
|
||||||
|
normalizer=normalizer,
|
||||||
|
is_last_chunk=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Skip empty chunks
|
||||||
if not tokens:
|
if not tokens:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -63,10 +70,16 @@ class TTSService:
|
||||||
)
|
)
|
||||||
|
|
||||||
if chunk_audio is None:
|
if chunk_audio is None:
|
||||||
|
logger.error("Model generated None for audio chunk")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if len(chunk_audio) == 0:
|
||||||
|
logger.error("Model generated empty audio chunk")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# For streaming, convert to bytes
|
# For streaming, convert to bytes
|
||||||
if output_format:
|
if output_format:
|
||||||
|
try:
|
||||||
return await AudioService.convert_audio(
|
return await AudioService.convert_audio(
|
||||||
chunk_audio,
|
chunk_audio,
|
||||||
24000,
|
24000,
|
||||||
|
@ -75,11 +88,13 @@ class TTSService:
|
||||||
normalizer=normalizer,
|
normalizer=normalizer,
|
||||||
is_last_chunk=is_last
|
is_last_chunk=is_last
|
||||||
)
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to convert audio: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
return chunk_audio
|
return chunk_audio
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to process chunk: '{chunk}'. Error: {str(e)}")
|
logger.error(f"Failed to process tokens: {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def generate_audio_stream(
|
async def generate_audio_stream(
|
||||||
|
@ -92,62 +107,59 @@ class TTSService:
|
||||||
"""Generate and stream audio chunks."""
|
"""Generate and stream audio chunks."""
|
||||||
stream_normalizer = AudioNormalizer()
|
stream_normalizer = AudioNormalizer()
|
||||||
voice_tensor = None
|
voice_tensor = None
|
||||||
pending_results = {}
|
chunk_index = 0
|
||||||
next_index = 0
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Normalize text
|
|
||||||
normalized = normalize_text(text)
|
|
||||||
if not normalized:
|
|
||||||
raise ValueError("Text is empty after preprocessing")
|
|
||||||
text = str(normalized)
|
|
||||||
|
|
||||||
# Get backend and load voice (should be fast if cached)
|
# Get backend and load voice (should be fast if cached)
|
||||||
backend = self.model_manager.get_backend()
|
backend = self.model_manager.get_backend()
|
||||||
voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device)
|
voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device)
|
||||||
|
|
||||||
# Process chunks with semaphore limiting concurrency
|
# Process text in chunks with smart splitting
|
||||||
chunks = []
|
async for chunk_text, tokens in smart_split(text):
|
||||||
async for chunk in chunker.split_text(text):
|
try:
|
||||||
chunks.append(chunk)
|
# Process audio for chunk
|
||||||
|
result = await self._process_chunk(
|
||||||
if not chunks:
|
tokens,
|
||||||
raise ValueError("No text chunks to process")
|
|
||||||
|
|
||||||
# Create tasks for all chunks
|
|
||||||
tasks = [
|
|
||||||
asyncio.create_task(
|
|
||||||
self._process_chunk(
|
|
||||||
chunk,
|
|
||||||
voice_tensor,
|
voice_tensor,
|
||||||
speed,
|
speed,
|
||||||
output_format,
|
output_format,
|
||||||
is_first=(i == 0),
|
is_first=(chunk_index == 0),
|
||||||
is_last=(i == len(chunks) - 1),
|
is_last=False, # We'll update the last chunk later
|
||||||
normalizer=stream_normalizer
|
normalizer=stream_normalizer
|
||||||
)
|
)
|
||||||
)
|
|
||||||
for i, chunk in enumerate(chunks)
|
|
||||||
]
|
|
||||||
|
|
||||||
# Process chunks and maintain order
|
|
||||||
for i, task in enumerate(tasks):
|
|
||||||
result = await task
|
|
||||||
|
|
||||||
if i == next_index and result is not None:
|
|
||||||
# If this is the next chunk we need, yield it
|
|
||||||
yield result
|
|
||||||
next_index += 1
|
|
||||||
|
|
||||||
# Check if we have any subsequent chunks ready
|
|
||||||
while next_index in pending_results:
|
|
||||||
result = pending_results.pop(next_index)
|
|
||||||
if result is not None:
|
if result is not None:
|
||||||
yield result
|
yield result
|
||||||
next_index += 1
|
chunk_index += 1
|
||||||
else:
|
else:
|
||||||
# Store out-of-order result
|
logger.warning(f"No audio generated for chunk: '{chunk_text[:100]}...'")
|
||||||
pending_results[i] = result
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Only finalize if we successfully processed at least one chunk
|
||||||
|
if chunk_index > 0:
|
||||||
|
try:
|
||||||
|
# Empty tokens list to finalize audio
|
||||||
|
final_result = await self._process_chunk(
|
||||||
|
[], # Empty tokens list
|
||||||
|
voice_tensor,
|
||||||
|
speed,
|
||||||
|
output_format,
|
||||||
|
is_first=False,
|
||||||
|
is_last=True,
|
||||||
|
normalizer=stream_normalizer
|
||||||
|
)
|
||||||
|
if final_result is not None:
|
||||||
|
logger.debug("Yielding final chunk to finalize audio")
|
||||||
|
yield final_result
|
||||||
|
else:
|
||||||
|
logger.warning("Final chunk processing returned None")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to process final chunk: {str(e)}")
|
||||||
|
else:
|
||||||
|
logger.warning("No audio chunks were successfully processed")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in audio generation stream: {str(e)}")
|
logger.error(f"Error in audio generation stream: {str(e)}")
|
||||||
|
|
Loading…
Add table
Reference in a new issue