Mirror of https://github.com/remsky/Kokoro-FastAPI.git (synced 2025-08-05 16:48:53 +00:00)
Refactor TTS service and text processing to enhance handling of pauses, newlines, and custom phonemes. Updated smart_split to manage pause tags and improved error logging. Adjusted audio generation logic for better performance and clarity.
This commit is contained in:
parent b31f79d8d7
commit c0da571857

2 changed files with 453 additions and 398 deletions
@@ -2,7 +2,7 @@
import re
import time
from typing import AsyncGenerator, Dict, List, Tuple
from typing import AsyncGenerator, Dict, List, Tuple, Optional # Add Optional import

from loguru import logger

@@ -14,6 +14,8 @@ from .vocabulary import tokenize
# Pre-compiled regex patterns for performance
CUSTOM_PHONEMES = re.compile(r"(\[([^\]]|\n)*?\])(\(\/([^\/)]|\n)*?\/\))")
# Pattern to find pause tags like [pause:0.5s]
PAUSE_TAG_PATTERN = re.compile(r"\[pause:(\d+(?:\.\d+)?)s\]", re.IGNORECASE)


def process_text_chunk(
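A quick sketch of how the pause-tag pattern above behaves when used with re.split, which is how the pause handling later in this commit consumes it; the sample text is illustrative:

import re

PAUSE_TAG_PATTERN = re.compile(r"\[pause:(\d+(?:\.\d+)?)s\]", re.IGNORECASE)

# re.split with a single capturing group alternates text segments and captured durations
parts = PAUSE_TAG_PATTERN.split("Hello there.[pause:1.5s]Back again.")
assert parts == ["Hello there.", "1.5", "Back again."]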
@@ -42,7 +44,8 @@ def process_text_chunk(
t1 = time.time()

t0 = time.time()
phonemes = phonemize(text, language, normalize=False) # Already normalized
# Normalize step is usually done before smart_split, but phonemize itself might do basic norm
phonemes = phonemize(text, language, normalize=False)
t1 = time.time()

t0 = time.time()
@@ -51,7 +54,7 @@
total_time = time.time() - start_time
logger.debug(
f"Total processing took {total_time * 1000:.2f}ms for chunk: '{text[:50]}{'...' if len(text) > 50 else ''}'"
f"Tokenization took {total_time * 1000:.2f}ms for chunk: '{text[:50]}{'...' if len(text) > 50 else ''}'"
)

return tokens
@@ -90,45 +93,61 @@ def process_text(text: str, language: str = "a") -> List[int]:
def get_sentence_info(
text: str, custom_phenomes_list: Dict[str, str]
) -> List[Tuple[str, List[int], int]]:
"""Process all sentences and return info."""
sentences = re.split(r"([.!?;:])(?=\s|$)", text)
phoneme_length, min_value = len(custom_phenomes_list), 0
"""Process all sentences and return info, preserving trailing newlines."""
# Split by sentence-ending punctuation, keeping the punctuation
sentences_parts = re.split(r'([.!?]+|\n+)', text)
sentences = []
current_sentence = ""
for part in sentences_parts:
if not part:
continue
current_sentence += part
# If the part ends with sentence punctuation or newline, consider it a sentence end
if re.search(r'[.!?\n]$', part):
sentences.append(current_sentence)
current_sentence = ""
if current_sentence: # Add any remaining part
sentences.append(current_sentence)


phoneme_length = len(custom_phenomes_list)
restored_phoneme_keys = list(custom_phenomes_list.keys()) # Keys to restore

results = []
for i in range(0, len(sentences), 2):
sentence = sentences[i].strip()
for replaced in range(min_value, phoneme_length):
current_id = f"</|custom_phonemes_{replaced}|/>"
if current_id in sentence:
sentence = sentence.replace(
current_id, custom_phenomes_list.pop(current_id)
)
min_value += 1
for original_sentence in sentences:
sentence_text_part = original_sentence.rstrip('\n') # Text without trailing newline for processing
trailing_newlines = original_sentence[len(sentence_text_part):] # Capture trailing newlines

punct = sentences[i + 1] if i + 1 < len(sentences) else ""

if not sentence:
if not sentence_text_part.strip(): # Skip empty or whitespace-only sentences
if trailing_newlines: # If only newlines, represent as empty text with newline marker
results.append(("\n", [], 0)) # Store newline marker, no tokens
continue

# Check if the original text segment ended with newline(s) before punctuation
original_segment = sentences[i]
trailing_newlines = ""
match = re.search(r"(\n+)$", original_segment)
if match:
trailing_newlines = match.group(1)
# Restore custom phonemes for this sentence *before* tokenization
sentence_to_tokenize = sentence_text_part
restored_count = 0
# Iterate through *all* possible phoneme IDs that might be in this sentence
for ph_id in restored_phoneme_keys:
if ph_id in sentence_to_tokenize:
sentence_to_tokenize = sentence_to_tokenize.replace(ph_id, custom_phenomes_list[ph_id])
restored_count+=1
if restored_count > 0:
logger.debug(f"Restored {restored_count} custom phonemes for tokenization in: '{sentence_text_part[:30]}...'")

full = sentence + punct + trailing_newlines # Append trailing newlines
# Tokenize without the trailing newlines for accurate TTS processing
tokens = process_text_chunk(sentence + punct)
# Store the full text including newlines for later check
results.append((full, tokens, len(tokens)))

# Tokenize the text part (without trailing newlines)
tokens = process_text_chunk(sentence_to_tokenize)

# Store the original sentence text (including trailing newlines) along with tokens
results.append((original_sentence, tokens, len(tokens)))

return results


def handle_custom_phonemes(s: re.Match[str], phenomes_list: Dict[str, str]) -> str:
latest_id = f"</|custom_phonemes_{len(phenomes_list)}|/>"
phenomes_list[latest_id] = s.group(0).strip()
phenomes_list[latest_id] = s.group(0).strip() # Store the full original tag [phoneme](/ipa/)
logger.debug(f"Replacing custom phoneme {phenomes_list[latest_id]} with ID {latest_id}")
return latest_id

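The splitting logic introduced in get_sentence_info above can be summarised by the following standalone sketch; the function name and sample text are illustrative, while the regexes and the re-attach loop mirror the diff:

import re

def split_sentences(text: str) -> list[str]:
    # Split on sentence punctuation or newline runs, keeping the delimiters,
    # then glue each delimiter back onto the text that preceded it.
    parts = re.split(r"([.!?]+|\n+)", text)
    sentences, current = [], ""
    for part in parts:
        if not part:
            continue
        current += part
        if re.search(r"[.!?\n]$", part):
            sentences.append(current)
            current = ""
    if current:
        sentences.append(current)
    return sentences

print(split_sentences("First line.\nSecond sentence! Third"))
# ['First line.', '\n', 'Second sentence!', ' Third']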
@@ -137,174 +156,152 @@ async def smart_split(
max_tokens: int = settings.absolute_max_tokens,
|
||||
lang_code: str = "a",
|
||||
normalization_options: NormalizationOptions = NormalizationOptions(),
|
||||
) -> AsyncGenerator[Tuple[str, List[int]], None]:
|
||||
"""Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens."""
|
||||
) -> AsyncGenerator[Tuple[str, List[int], Optional[float]], None]:
|
||||
"""Build optimal chunks targeting token limits, handling pause tags and newlines.
|
||||
|
||||
Yields:
|
||||
Tuple of (text_chunk, tokens, pause_duration_s).
|
||||
If pause_duration_s is not None, it's a pause chunk with empty text/tokens.
|
||||
Otherwise, it's a text chunk. text_chunk may end with '\n'.
|
||||
"""
|
||||
start_time = time.time()
|
||||
chunk_count = 0
|
||||
logger.info(f"Starting smart split for {len(text)} chars")
|
||||
logger.info(f"Starting smart split for {len(text)} chars, max_tokens={max_tokens}")
|
||||
|
||||
custom_phoneme_list = {}
|
||||
|
||||
# Normalize text
|
||||
# 1. Temporarily replace custom phonemes like [word](/ipa/) with unique IDs
|
||||
text_with_ids = CUSTOM_PHONEMES.sub(
|
||||
lambda s: handle_custom_phonemes(s, custom_phoneme_list), text
|
||||
)
|
||||
if custom_phoneme_list:
|
||||
logger.debug(f"Found custom phonemes: {custom_phoneme_list}")
|
||||
|
||||
|
||||
# 2. Normalize the text *with IDs* if required
|
||||
normalized_text = text_with_ids
|
||||
if settings.advanced_text_normalization and normalization_options.normalize:
|
||||
print(lang_code)
|
||||
if lang_code in ["a", "b", "en-us", "en-gb"]:
|
||||
text = CUSTOM_PHONEMES.sub(
|
||||
lambda s: handle_custom_phonemes(s, custom_phoneme_list), text
|
||||
)
|
||||
text = normalize_text(text, normalization_options)
|
||||
normalized_text = normalize_text(normalized_text, normalization_options)
|
||||
logger.debug("Applied text normalization.")
|
||||
else:
|
||||
logger.info(
|
||||
"Skipping text normalization as it is only supported for english"
|
||||
)
|
||||
|
||||
# Process all sentences
|
||||
sentences = get_sentence_info(text, custom_phoneme_list)
|
||||
# 3. Split the normalized text by pause tags
|
||||
parts = PAUSE_TAG_PATTERN.split(normalized_text)
|
||||
logger.debug(f"Split into {len(parts)} parts by pause tags.")
|
||||
|
||||
current_chunk = []
|
||||
current_tokens = []
|
||||
current_count = 0
|
||||
|
||||
for sentence, tokens, count in sentences:
|
||||
# Handle sentences that exceed max tokens
|
||||
if count > max_tokens:
|
||||
# Yield current chunk if any
|
||||
if current_chunk:
|
||||
# Join with space, but preserve original trailing newline of the last sentence if present
|
||||
last_sentence_original = current_chunk[-1]
|
||||
chunk_text_joined = " ".join(current_chunk)
|
||||
if last_sentence_original.endswith("\n"):
|
||||
chunk_text_joined += "\n" # Preserve the newline marker
|
||||
part_idx = 0
|
||||
while part_idx < len(parts):
|
||||
text_part = parts[part_idx] # This part contains text and custom phoneme IDs
|
||||
part_idx += 1
|
||||
|
||||
chunk_count += 1
|
||||
logger.debug(
|
||||
f"Yielding text chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({current_count} tokens)"
|
||||
)
|
||||
yield chunk_text_joined, current_tokens, None # Pass the text with potential trailing newline
|
||||
current_chunk = []
|
||||
current_tokens = []
|
||||
current_count = 0
|
||||
if text_part:
|
||||
# Process this text part using sentence splitting
|
||||
# We pass the text_part *with IDs* to get_sentence_info
|
||||
# get_sentence_info will handle restoring phonemes just before tokenization
|
||||
sentences = get_sentence_info(text_part, custom_phoneme_list)
|
||||
|
||||
# Split long sentence on commas (simple approach)
|
||||
# Keep original sentence text ('sentence' now includes potential trailing newline)
|
||||
clauses = re.split(r"([,])", sentence.rstrip('\n')) # Split text part only
|
||||
trailing_newline_in_sentence = "\n" if sentence.endswith("\n") else ""
|
||||
clause_chunk = [] # Stores original clause text including potential trailing newline
|
||||
clause_tokens = []
|
||||
clause_count = 0
|
||||
current_chunk_texts = [] # Store original sentence texts for the current chunk
|
||||
current_chunk_tokens = []
|
||||
current_token_count = 0
|
||||
|
||||
for j in range(0, len(clauses), 2):
|
||||
# clause = clauses[j].strip() # Don't strip here to preserve internal structure
|
||||
clause = clauses[j]
|
||||
comma = clauses[j + 1] if j + 1 < len(clauses) else ""
|
||||
|
||||
if not clause.strip(): # Check if clause is just whitespace
|
||||
|
||||
full_clause = clause + comma
|
||||
|
||||
tokens = process_text_chunk(full_clause)
|
||||
count = len(tokens)
|
||||
|
||||
# If adding clause keeps us under max and not optimal yet
|
||||
if (
|
||||
clause_count + count <= max_tokens
|
||||
and clause_count + count <= settings.target_max_tokens
|
||||
):
|
||||
clause_chunk.append(full_clause)
|
||||
clause_tokens.extend(tokens)
|
||||
clause_count += count
|
||||
else:
|
||||
# Yield clause chunk if we have one
|
||||
if clause_chunk:
|
||||
# Join with space, preserve last clause's potential trailing newline
|
||||
last_clause_original = clause_chunk[-1]
|
||||
chunk_text_joined = " ".join(clause_chunk)
|
||||
if last_clause_original.endswith("\n"):
|
||||
chunk_text_joined += "\n"
|
||||
for sentence_text, sentence_tokens, sentence_token_count in sentences:
|
||||
# --- Chunking Logic ---
|
||||
|
||||
# Condition 1: Current sentence alone exceeds max tokens
|
||||
if sentence_token_count > max_tokens:
|
||||
logger.warning(f"Single sentence exceeds max_tokens ({sentence_token_count} > {max_tokens}): '{sentence_text[:50]}...'")
|
||||
# Yield any existing chunk first
|
||||
if current_chunk_texts:
|
||||
chunk_text_joined = " ".join(current_chunk_texts) # Join original texts
|
||||
chunk_count += 1
|
||||
logger.debug(
|
||||
f"Yielding clause chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({clause_count} tokens)"
|
||||
)
|
||||
yield chunk_text_joined, clause_tokens, None
|
||||
# Start new clause chunk with original text
|
||||
clause_chunk = [full_clause + (trailing_newline_in_sentence if j == len(clauses) - 2 else "")]
|
||||
clause_tokens = clause_token_list
|
||||
clause_count = clause_token_count
|
||||
logger.info(f"Yielding text chunk {chunk_count} (before oversized sentence): '{chunk_text_joined[:50]}...' ({current_token_count} tokens)")
|
||||
yield chunk_text_joined, current_chunk_tokens, None
|
||||
current_chunk_texts = []
|
||||
current_chunk_tokens = []
|
||||
current_token_count = 0
|
||||
|
||||
# Don't forget last clause chunk
|
||||
if clause_chunk:
|
||||
# Join with space, preserve last clause's potential trailing newline
|
||||
last_clause_original = clause_chunk[-1]
|
||||
chunk_text_joined = " ".join(clause_chunk)
|
||||
# The trailing newline logic was added when creating the chunk above
|
||||
#if last_clause_original.endswith("\n"):
|
||||
# chunk_text_joined += "\n"
|
||||
# Yield the oversized sentence as its own chunk
|
||||
# Restore phonemes before yielding the text
|
||||
text_to_yield = sentence_text
|
||||
for p_id, p_val in custom_phoneme_list.items():
|
||||
if p_id in text_to_yield:
|
||||
text_to_yield = text_to_yield.replace(p_id, p_val)
|
||||
|
||||
chunk_count += 1
|
||||
logger.debug(
|
||||
f"Yielding final clause chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({clause_count} tokens)"
|
||||
)
|
||||
yield chunk_text_joined, clause_tokens, None
|
||||
current_count >= settings.target_min_tokens
|
||||
and current_count + count > settings.target_max_tokens
|
||||
):
|
||||
# If we have a good sized chunk and adding next sentence exceeds target,
|
||||
# Yield current chunk and start new one
|
||||
last_sentence_original = current_chunk[-1]
|
||||
chunk_text_joined = " ".join(current_chunk)
|
||||
if last_sentence_original.endswith("\n"):
|
||||
chunk_text_joined += "\n"
|
||||
chunk_count += 1
|
||||
logger.info(
|
||||
f"Yielding text chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({current_count} tokens)"
|
||||
)
|
||||
yield chunk_text_joined, current_tokens, None
|
||||
current_chunk = [sentence] # sentence includes potential trailing newline
|
||||
current_tokens = tokens
|
||||
current_count = count
|
||||
elif current_count + count <= settings.target_max_tokens:
|
||||
# Keep building chunk
|
||||
current_chunk.append(sentence) # sentence includes potential trailing newline
|
||||
current_tokens.extend(tokens)
|
||||
current_count += count
|
||||
elif (
|
||||
current_count + count <= max_tokens
|
||||
and current_count < settings.target_min_tokens
|
||||
):
|
||||
# Exceed target max only if below min size
|
||||
current_chunk.append(sentence) # sentence includes potential trailing newline
|
||||
current_tokens.extend(tokens)
|
||||
current_count += count
|
||||
else:
|
||||
# Yield current chunk and start new one
|
||||
if current_chunk:
|
||||
last_sentence_original = current_chunk[-1]
|
||||
chunk_text_joined = " ".join(current_chunk)
|
||||
if last_sentence_original.endswith("\n"):
|
||||
chunk_text_joined += "\n"
|
||||
chunk_count += 1
|
||||
logger.info(
|
||||
f"Yielding text chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({current_count} tokens)"
|
||||
)
|
||||
yield chunk_text_joined, current_tokens, None
|
||||
current_chunk = [sentence] # sentence includes potential trailing newline
|
||||
current_tokens = tokens
|
||||
current_count = count
|
||||
logger.info(f"Yielding oversized text chunk {chunk_count}: '{text_to_yield[:50]}...' ({sentence_token_count} tokens)")
|
||||
yield text_to_yield, sentence_tokens, None
|
||||
continue # Move to the next sentence
|
||||
|
||||
# Condition 2: Adding the current sentence would exceed max_tokens
|
||||
elif current_token_count + sentence_token_count > max_tokens:
|
||||
# Yield the current chunk first
|
||||
if current_chunk_texts:
|
||||
chunk_text_joined = " ".join(current_chunk_texts) # Join original texts
|
||||
chunk_count += 1
|
||||
logger.info(f"Yielding text chunk {chunk_count} (max_tokens limit): '{chunk_text_joined[:50]}...' ({current_token_count} tokens)")
|
||||
yield chunk_text_joined, current_chunk_tokens, None
|
||||
# Start a new chunk with the current sentence
|
||||
current_chunk_texts = [sentence_text]
|
||||
current_chunk_tokens = sentence_tokens
|
||||
current_token_count = sentence_token_count
|
||||
|
||||
# Condition 3: Adding exceeds target_max_tokens when already above target_min_tokens
|
||||
elif (current_token_count >= settings.target_min_tokens and
|
||||
current_token_count + sentence_token_count > settings.target_max_tokens):
|
||||
# Yield the current chunk
|
||||
chunk_text_joined = " ".join(current_chunk_texts) # Join original texts
|
||||
chunk_count += 1
|
||||
logger.info(f"Yielding text chunk {chunk_count} (target_max limit): '{chunk_text_joined[:50]}...' ({current_token_count} tokens)")
|
||||
yield chunk_text_joined, current_chunk_tokens, None
|
||||
# Start a new chunk
|
||||
current_chunk_texts = [sentence_text]
|
||||
current_chunk_tokens = sentence_tokens
|
||||
current_token_count = sentence_token_count
|
||||
|
||||
# Condition 4: Add sentence to current chunk (fits within max_tokens and either below target_max or below target_min)
|
||||
else:
|
||||
current_chunk_texts.append(sentence_text)
|
||||
current_chunk_tokens.extend(sentence_tokens)
|
||||
current_token_count += sentence_token_count
|
||||
|
||||
# --- End of sentence loop for this text part ---
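Conditions 2-4 above reduce to a small packing rule; this sketch uses made-up limits (the real ones come from settings) and bare token counts instead of full sentences:

MAX_TOKENS, TARGET_MIN, TARGET_MAX = 500, 175, 250  # illustrative limits only

def pack_sentences(sentence_token_counts: list[int]) -> list[list[int]]:
    chunks, current, count = [], [], 0
    for n in sentence_token_counts:
        # Flush when the sentence would break max_tokens, or would break
        # target_max while the chunk already holds at least target_min tokens.
        if current and (count + n > MAX_TOKENS or
                        (count >= TARGET_MIN and count + n > TARGET_MAX)):
            chunks.append(current)
            current, count = [], 0
        current.append(n)
        count += n
    if current:
        chunks.append(current)
    return chunks

print(pack_sentences([120, 80, 90, 60, 300]))  # [[120, 80], [90, 60, 300]]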
|
||||
|
||||
# Yield any remaining accumulated chunk for this text part
|
||||
if current_chunk_texts:
|
||||
chunk_text_joined = " ".join(current_chunk_texts) # Join original texts
|
||||
# Restore phonemes before yielding
|
||||
text_to_yield = chunk_text_joined
|
||||
for p_id, p_val in custom_phoneme_list.items():
|
||||
if p_id in text_to_yield:
|
||||
text_to_yield = text_to_yield.replace(p_id, p_val)
|
||||
|
||||
chunk_count += 1
|
||||
logger.info(f"Yielding final text chunk {chunk_count} for part: '{text_to_yield[:50]}...' ({current_token_count} tokens)")
|
||||
yield text_to_yield, current_chunk_tokens, None
|
||||
|
||||
|
||||
# Check if the next part is a pause duration
|
||||
if part_idx < len(parts):
|
||||
duration_str = parts[part_idx]
|
||||
part_idx += 1 # Move past the duration string
|
||||
try:
|
||||
duration = float(duration_str)
|
||||
if duration > 0:
|
||||
chunk_count += 1
|
||||
logger.info(f"Yielding pause chunk {chunk_count}: {duration}s")
|
||||
yield "", [], duration # Yield pause chunk
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"Could not parse pause duration: {duration_str}")
|
||||
# If parsing fails, potentially treat the duration_str as text?
|
||||
# For now, just log a warning and skip.
|
||||
|
||||
# Yield any remaining text chunk
|
||||
if current_chunk:
|
||||
last_sentence_original = current_chunk[-1]
|
||||
chunk_text_joined = " ".join(current_chunk)
|
||||
if last_sentence_original.endswith("\n"):
|
||||
chunk_text_joined += "\n"
|
||||
chunk_count += 1
|
||||
logger.info(
|
||||
f"Yielding final text chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({current_count} tokens)"
|
||||
)
|
||||
yield chunk_text_joined, current_tokens, None
|
||||
|
||||
total_time = time.time() - start_time
|
||||
logger.info(
|
||||
f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks"
|
||||
f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks (including pauses)"
|
||||
)
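A minimal consumer of the new three-element yield, matching the docstring at the top of smart_split; the two helpers are placeholders for the audio-side handling, not functions from this repository:

import asyncio

def insert_silence(seconds: float) -> None:               # placeholder
    print(f"silence: {seconds}s")

def render_audio(text: str, tokens: list[int]) -> None:   # placeholder
    print(f"speak: {text!r} ({len(tokens)} tokens)")

async def synthesize(text: str) -> None:
    async for chunk_text, tokens, pause_s in smart_split(text, lang_code="a"):
        if pause_s is not None:
            insert_silence(pause_s)
        elif chunk_text.strip():
            render_audio(chunk_text, tokens)

# asyncio.run(synthesize("Hello.[pause:0.5s]World.\nNew paragraph."))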
@@ -78,28 +78,83 @@ class TTSService:
"",
|
||||
normalizer=normalizer,
|
||||
is_last_chunk=True,
|
||||
trim_audio=False, # Don't trim final silence
|
||||
)
|
||||
yield chunk_data
|
||||
return
|
||||
|
||||
# Skip empty chunks
|
||||
# Skip empty chunks (shouldn't happen if called correctly, but safety)
|
||||
if not tokens and not chunk_text:
|
||||
return
|
||||
logger.warning("Empty chunk passed to _process_chunk")
|
||||
return
|
||||
|
||||
# Get backend
|
||||
backend = self.model_manager.get_backend()
|
||||
|
||||
# Generate audio using pre-warmed model
|
||||
# Note: chunk_text is the *original* text including custom phoneme markers and newlines
|
||||
# The model needs the text *with phonemes restored*
|
||||
text_for_model = chunk_text # Start with original
|
||||
# Restore custom phonemes if backend needs it (like KokoroV1)
|
||||
if isinstance(backend, KokoroV1):
|
||||
chunk_index = 0
|
||||
# For Kokoro V1, pass text and voice info with lang_code
|
||||
# Find phoneme markers in this specific chunk_text and restore
|
||||
# (This assumes smart_split yielded text with markers) - let's refine smart_split yield
|
||||
# For now, assume chunk_text is ready for the model (phonemes restored by smart_split)
|
||||
pass
|
||||
|
||||
|
||||
if isinstance(backend, KokoroV1):
|
||||
internal_chunk_index = 0
|
||||
async for chunk_data in self.model_manager.generate(
|
||||
chunk_text,
|
||||
text_for_model.strip(), # Pass cleaned text to model
|
||||
(voice_name, voice_path),
|
||||
speed=speed,
|
||||
lang_code=lang_code,
|
||||
return_timestamps=return_timestamps,
|
||||
):
|
||||
# For streaming, convert to bytes if format specified
|
||||
if output_format:
|
||||
try:
|
||||
chunk_data = await AudioService.convert_audio(
|
||||
chunk_data,
|
||||
output_format,
|
||||
writer,
|
||||
speed,
|
||||
chunk_text.strip(), # Pass original text for trimming logic
|
||||
is_last_chunk=is_last, # Should always be False here, handled above
|
||||
normalizer=normalizer,
|
||||
trim_audio=True # Trim speech parts
|
||||
)
|
||||
yield chunk_data
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert audio: {str(e)}")
|
||||
else: # Raw audio mode
|
||||
chunk_data = AudioService.trim_audio(
|
||||
chunk_data, chunk_text.strip(), speed, False, normalizer # Trim speech parts
|
||||
)
|
||||
yield chunk_data
|
||||
internal_chunk_index += 1
|
||||
if internal_chunk_index == 0:
|
||||
logger.warning(f"Model generation yielded no audio chunks for: '{text_for_model[:50]}...'")
|
||||
|
||||
else:
|
||||
# --- Legacy backend path (using tokens) ---
|
||||
# This path might not work correctly with custom phonemes restored in text_for_model
|
||||
logger.warning("Using legacy backend path with tokens - custom phonemes might not be handled.")
|
||||
voice_tensor = await self._voice_manager.load_voice(
|
||||
voice_name, device=backend.device
|
||||
)
|
||||
async for chunk_data in self.model_manager.generate( # Needs to be async generator
|
||||
tokens, # Legacy uses tokens
|
||||
(voice_name, voice_tensor), # Pass tuple as expected
|
||||
speed=speed,
|
||||
return_timestamps=return_timestamps,
|
||||
):
|
||||
|
||||
if chunk_data.audio is None or len(chunk_data.audio) == 0:
|
||||
logger.error("Legacy model generated empty or None audio chunk")
|
||||
continue # Skip this chunk
|
||||
|
||||
# For streaming, convert to bytes
|
||||
if output_format:
|
||||
try:
@@ -108,61 +163,22 @@ class TTSService:
output_format,
|
||||
writer,
|
||||
speed,
|
||||
chunk_text,
|
||||
is_last_chunk=is_last,
|
||||
chunk_text.strip(), # Pass original text for trimming logic
|
||||
normalizer=normalizer,
|
||||
is_last_chunk=is_last, # Should be False here
|
||||
trim_audio=True # Trim speech parts
|
||||
)
|
||||
yield chunk_data
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert audio: {str(e)}")
|
||||
else:
|
||||
chunk_data = AudioService.trim_audio(
|
||||
chunk_data, chunk_text, speed, is_last, normalizer
|
||||
logger.error(f"Failed to convert legacy audio: {str(e)}")
|
||||
else: # Raw audio mode
|
||||
trimmed = AudioService.trim_audio(
|
||||
chunk_data, chunk_text.strip(), speed, False, normalizer # Trim speech parts
|
||||
)
|
||||
yield chunk_data
|
||||
chunk_index += 1
|
||||
else:
|
||||
# For legacy backends, load voice tensor
|
||||
voice_tensor = await self._voice_manager.load_voice(
|
||||
voice_name, device=backend.device
|
||||
)
|
||||
chunk_data = await self.model_manager.generate(
|
||||
tokens,
|
||||
voice_tensor,
|
||||
speed=speed,
|
||||
return_timestamps=return_timestamps,
|
||||
)
|
||||
|
||||
if chunk_data.audio is None:
|
||||
logger.error("Model generated None for audio chunk")
|
||||
return
|
||||
|
||||
if len(chunk_data.audio) == 0:
|
||||
logger.error("Model generated empty audio chunk")
|
||||
return
|
||||
|
||||
# For streaming, convert to bytes
|
||||
if output_format:
|
||||
try:
|
||||
chunk_data = await AudioService.convert_audio(
|
||||
chunk_data,
|
||||
output_format,
|
||||
writer,
|
||||
speed,
|
||||
chunk_text,
|
||||
normalizer=normalizer,
|
||||
is_last_chunk=is_last,
|
||||
)
|
||||
yield chunk_data
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert audio: {str(e)}")
|
||||
else:
|
||||
trimmed = AudioService.trim_audio(
|
||||
chunk_data, chunk_text, speed, is_last, normalizer
|
||||
)
|
||||
yield trimmed
|
||||
yield trimmed
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process tokens: {str(e)}")
|
||||
logger.exception(f"Failed to process chunk: '{chunk_text[:50]}...'. Error: {str(e)}")
|
||||
|
||||
|
||||
async def _load_voice_from_path(self, path: str, weight: float):
|
||||
# Check if the path is None and raise a ValueError if it is not
@@ -170,13 +186,15 @@ class TTSService:
raise ValueError(f"Voice not found at path: {path}")

logger.debug(f"Loading voice tensor from path: {path}")
return torch.load(path, map_location="cpu") * weight
# Ensure loading happens on CPU initially to avoid device mismatches
tensor = torch.load(path, map_location="cpu")
return tensor * weight

async def _get_voices_path(self, voice: str) -> Tuple[str, str]:
"""Get voice path, handling combined voices.

Args:
voice: Voice name or combined voice names (e.g., 'af_jadzia+af_jessica')
voice: Voice name or combined voice names (e.g., 'af_jadzia(0.7)+af_jessica(0.3)')

Returns:
Tuple of (voice name to use, voice path to use)
@@ -185,72 +203,87 @@
RuntimeError: If voice not found
|
||||
"""
|
||||
try:
|
||||
# Split the voice on + and - and ensure that they get added to the list eg: hi+bob = ["hi","+","bob"]
|
||||
split_voice = re.split(r"([-+])", voice)
|
||||
# Regex to handle names, weights, and operators: af_name(weight)[+-]af_other(weight)...
|
||||
pattern = re.compile(r"([a-zA-Z0-9_]+)(?:\((\d+(?:\.\d+)?)\))?([+-]?)")
|
||||
matches = pattern.findall(voice.replace(" ", "")) # Remove spaces
|
||||
|
||||
# If it is only once voice there is no point in loading it up, doing nothing with it, then saving it
|
||||
if len(split_voice) == 1:
|
||||
# Since its a single voice the only time that the weight would matter is if voice_weight_normalization is off
|
||||
if (
|
||||
"(" not in voice and ")" not in voice
|
||||
) or settings.voice_weight_normalization == True:
|
||||
path = await self._voice_manager.get_voice_path(voice)
|
||||
if not path:
|
||||
raise RuntimeError(f"Voice not found: {voice}")
|
||||
logger.debug(f"Using single voice path: {path}")
|
||||
return voice, path
|
||||
if not matches:
|
||||
raise ValueError(f"Could not parse voice string: {voice}")
|
||||
|
||||
# If only one voice and no explicit weight or operators, handle directly
|
||||
if len(matches) == 1 and not matches[0][1] and not matches[0][2]:
|
||||
voice_name = matches[0][0]
|
||||
path = await self._voice_manager.get_voice_path(voice_name)
|
||||
if not path:
|
||||
raise RuntimeError(f"Voice not found: {voice_name}")
|
||||
logger.debug(f"Using single voice path: {path}")
|
||||
return voice_name, path
|
||||
|
||||
# Process combinations
|
||||
voice_parts = []
|
||||
total_weight = 0
|
||||
for name, weight_str, operator in matches:
|
||||
weight = float(weight_str) if weight_str else 1.0
|
||||
voice_parts.append({"name": name, "weight": weight, "op": operator})
|
||||
# Use weight directly for total, normalization happens later if enabled
|
||||
total_weight += weight # Summing base weights before potential normalization
|
||||
|
||||
for voice_index in range(0, len(split_voice), 2):
|
||||
voice_object = split_voice[voice_index]
|
||||
# Check base voices exist
|
||||
available_voices = await self._voice_manager.list_voices()
|
||||
for part in voice_parts:
|
||||
if part["name"] not in available_voices:
|
||||
raise ValueError(f"Base voice '{part['name']}' not found in combined string '{voice}'. Available: {available_voices}")
|
||||
|
||||
if "(" in voice_object and ")" in voice_object:
|
||||
voice_name = voice_object.split("(")[0].strip()
|
||||
voice_weight = float(voice_object.split("(")[1].split(")")[0])
|
||||
else:
|
||||
voice_name = voice_object
|
||||
voice_weight = 1
|
||||
|
||||
total_weight += voice_weight
|
||||
split_voice[voice_index] = (voice_name, voice_weight)
|
||||
# Determine normalization factor
|
||||
norm_factor = total_weight if settings.voice_weight_normalization and total_weight > 0 else 1.0
|
||||
if settings.voice_weight_normalization:
|
||||
logger.debug(f"Normalizing combined voice weights by factor: {norm_factor:.2f}")
|
||||
else:
|
||||
logger.debug("Voice weight normalization disabled, using raw weights.")
|
||||
|
||||
# If voice_weight_normalization is false prevent normalizing the weights by setting the total_weight to 1 so it divides each weight by 1
|
||||
if settings.voice_weight_normalization == False:
|
||||
total_weight = 1
|
||||
|
||||
# Load the first voice as the starting point for voices to be combined onto
|
||||
path = await self._voice_manager.get_voice_path(split_voice[0][0])
|
||||
combined_tensor = await self._load_voice_from_path(
|
||||
path, split_voice[0][1] / total_weight
|
||||
)
|
||||
# Load and combine tensors
|
||||
first_part = voice_parts[0]
|
||||
base_path = await self._voice_manager.get_voice_path(first_part["name"])
|
||||
combined_tensor = await self._load_voice_from_path(base_path, first_part["weight"] / norm_factor)
|
||||
|
||||
# Loop through each + or - in split_voice so they can be applied to combined voice
|
||||
for operation_index in range(1, len(split_voice) - 1, 2):
|
||||
# Get the voice path of the voice 1 index ahead of the operator
|
||||
path = await self._voice_manager.get_voice_path(
|
||||
split_voice[operation_index + 1][0]
|
||||
)
|
||||
voice_tensor = await self._load_voice_from_path(
|
||||
path, split_voice[operation_index + 1][1] / total_weight
|
||||
)
|
||||
current_op = "+" # Implicitly start with addition for the first voice
|
||||
|
||||
# Either add or subtract the voice from the current combined voice
|
||||
if split_voice[operation_index] == "+":
|
||||
for i in range(len(voice_parts) - 1):
|
||||
current_part = voice_parts[i]
|
||||
next_part = voice_parts[i+1]
|
||||
|
||||
# Determine the operation based on the *current* part's operator
|
||||
op_symbol = current_part["op"] if current_part["op"] else "+" # Default to '+' if no operator
|
||||
|
||||
path = await self._voice_manager.get_voice_path(next_part["name"])
|
||||
voice_tensor = await self._load_voice_from_path(path, next_part["weight"] / norm_factor)
|
||||
|
||||
if op_symbol == "+":
|
||||
combined_tensor += voice_tensor
|
||||
else:
|
||||
logger.debug(f"Adding voice {next_part['name']} (weight {next_part['weight']/norm_factor:.2f})")
|
||||
elif op_symbol == "-":
|
||||
combined_tensor -= voice_tensor
|
||||
logger.debug(f"Subtracting voice {next_part['name']} (weight {next_part['weight']/norm_factor:.2f})")
|
||||
|
||||
# Save the new combined voice so it can be loaded latter
|
||||
|
||||
# Save the new combined voice so it can be loaded later
|
||||
# Use a safe filename based on the original input string
|
||||
safe_filename = re.sub(r'[^\w+-]', '_', voice) + ".pt"
|
||||
temp_dir = tempfile.gettempdir()
|
||||
combined_path = os.path.join(temp_dir, f"{voice}.pt")
|
||||
logger.debug(f"Saving combined voice to: {combined_path}")
|
||||
torch.save(combined_tensor, combined_path)
|
||||
return voice, combined_path
|
||||
combined_path = os.path.join(temp_dir, safe_filename)
|
||||
logger.debug(f"Saving combined voice '{voice}' to temporary path: {combined_path}")
|
||||
# Save the tensor to the device specified by settings for model loading consistency
|
||||
target_device = settings.get_device()
|
||||
torch.save(combined_tensor.to(target_device), combined_path)
|
||||
return voice, combined_path # Return original name and temp path
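The parsing and weighting introduced in this method can be checked in isolation; the regex and the normalization idea come from the diff, while the concrete voice names and weights are just sample input:

import re

pattern = re.compile(r"([a-zA-Z0-9_]+)(?:\((\d+(?:\.\d+)?)\))?([+-]?)")
matches = pattern.findall("af_jadzia(0.7)+af_jessica(0.3)".replace(" ", ""))
# [('af_jadzia', '0.7', '+'), ('af_jessica', '0.3', '')]

parts = [{"name": n, "weight": float(w) if w else 1.0, "op": op} for n, w, op in matches]
total = sum(p["weight"] for p in parts)                 # 1.0
norm = total if total > 0 else 1.0                      # assumes normalization is enabled
weights = {p["name"]: p["weight"] / norm for p in parts}
# {'af_jadzia': 0.7, 'af_jessica': 0.3}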
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get voice path: {e}")
|
||||
logger.error(f"Failed to get or combine voice path for '{voice}': {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def generate_audio_stream(
|
||||
self,
|
||||
text: str,
@@ -262,25 +295,29 @@ class TTSService:
normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
|
||||
return_timestamps: Optional[bool] = False,
|
||||
) -> AsyncGenerator[AudioChunk, None]:
|
||||
"""Generate and stream audio chunks."""
|
||||
"""Generate and stream audio chunks, handling text, pauses, and newlines."""
|
||||
stream_normalizer = AudioNormalizer()
|
||||
chunk_index = 0
|
||||
current_offset = 0.0
|
||||
current_offset = 0.0 # Track audio time offset for timestamps
|
||||
try:
|
||||
# Get backend
|
||||
backend = self.model_manager.get_backend()
|
||||
|
||||
# Get voice path, handling combined voices
|
||||
# voice_name will be the potentially complex combined name string
|
||||
voice_name, voice_path = await self._get_voices_path(voice)
|
||||
logger.debug(f"Using voice path: {voice_path}")
|
||||
logger.debug(f"Using voice path for '{voice_name}': {voice_path}")
|
||||
|
||||
# Use provided lang_code or determine from voice name
|
||||
pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
|
||||
# Determine language code
|
||||
# Use provided lang_code, fallback to settings override, then first letter of first base voice
|
||||
first_base_voice_match = re.match(r"([a-zA-Z0-9_]+)", voice)
|
||||
first_base_voice = first_base_voice_match.group(1) if first_base_voice_match else "a" # Default 'a'
|
||||
pipeline_lang_code = lang_code if lang_code else (settings.default_voice_code if settings.default_voice_code else first_base_voice[:1].lower())
|
||||
logger.info(
|
||||
f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream"
|
||||
)
|
||||
|
||||
# Process text in chunks with smart splitting, handling pauses
|
||||
# Process text in chunks (handling pauses and newlines within smart_split)
|
||||
async for chunk_text, tokens, pause_duration_s in smart_split(
|
||||
text,
|
||||
lang_code=pipeline_lang_code,
@@ -291,29 +328,21 @@ class TTSService:
try:
|
||||
logger.debug(f"Generating {pause_duration_s}s silence chunk")
|
||||
silence_samples = int(pause_duration_s * settings.sample_rate)
|
||||
# Use float32 zeros as AudioService will normalize later
|
||||
silence_audio = np.zeros(silence_samples, dtype=np.float32)
|
||||
pause_chunk = AudioChunk(audio=silence_audio, word_timestamps=[]) # Empty timestamps for silence
|
||||
|
||||
# Convert silence chunk to the target format using AudioService
|
||||
# Format and yield the silence chunk
|
||||
if output_format:
|
||||
formatted_pause_chunk = await AudioService.convert_audio(
|
||||
pause_chunk,
|
||||
output_format,
|
||||
writer,
|
||||
speed=1.0, # Speed doesn't affect silence
|
||||
chunk_text="", # No text for silence
|
||||
is_last_chunk=False, # Not the final chunk
|
||||
trim_audio=False, # Don't trim silence
|
||||
normalizer=stream_normalizer,
|
||||
pause_chunk, output_format, writer, speed=1.0, chunk_text="",
|
||||
is_last_chunk=False, trim_audio=False, normalizer=stream_normalizer,
|
||||
)
|
||||
if formatted_pause_chunk.output:
|
||||
yield formatted_pause_chunk
|
||||
else:
|
||||
# If no output format (raw audio), yield the raw chunk
|
||||
# Ensure normalization happens if needed (AudioService handles this)
|
||||
pause_chunk.audio = stream_normalizer.normalize(pause_chunk.audio)
|
||||
yield pause_chunk # Yield raw silence chunk
|
||||
else: # Raw audio mode
|
||||
pause_chunk.audio = stream_normalizer.normalize(pause_chunk.audio)
|
||||
if len(pause_chunk.audio) > 0:
|
||||
yield pause_chunk
|
||||
|
||||
# Update offset based on silence duration
|
||||
current_offset += pause_duration_s
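The silence generation used for pause chunks amounts to the following; 24000 Hz is an assumed sample rate here, the real value is read from settings:

import numpy as np

def make_silence(duration_s: float, sample_rate: int = 24000) -> np.ndarray:
    samples = int(duration_s * sample_rate)
    return np.zeros(samples, dtype=np.float32)

silence = make_silence(0.5)   # 12000 zero samples for a 0.5 s pause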
@@ -322,105 +351,122 @@ class TTSService:
except Exception as e:
|
||||
logger.error(f"Failed to process pause chunk: {str(e)}")
|
||||
continue
|
||||
elif tokens or chunk_text:
|
||||
|
||||
elif tokens or chunk_text.strip(): # Process if there are tokens OR non-whitespace text
|
||||
# --- Handle Text Chunk ---
|
||||
original_text_with_markers = chunk_text # Keep original including markers/newlines
|
||||
text_chunk_for_model = chunk_text.strip() # Clean text for the model
|
||||
has_trailing_newline = chunk_text.endswith('\n')
|
||||
|
||||
try:
|
||||
# Process audio for the text chunk
|
||||
async for chunk_data in self._process_chunk(
|
||||
chunk_text, # Pass text for Kokoro V1
|
||||
tokens, # Pass tokens for legacy backends
|
||||
voice_name, # Pass voice name
|
||||
voice_path, # Pass voice path
|
||||
speed,
|
||||
writer,
|
||||
output_format,
|
||||
is_first=(chunk_index == 0),
|
||||
is_last=False, # We'll update the last chunk later
|
||||
normalizer=stream_normalizer,
|
||||
lang_code=pipeline_lang_code, # Pass lang_code
|
||||
return_timestamps=return_timestamps,
|
||||
):
|
||||
if chunk_data.word_timestamps is not None:
|
||||
for timestamp in chunk_data.word_timestamps:
|
||||
timestamp.start_time += current_offset
|
||||
timestamp.end_time += current_offset
|
||||
text_chunk_for_model, # Pass cleaned text for model processing
|
||||
tokens,
|
||||
voice_name,
|
||||
voice_path,
|
||||
speed,
|
||||
writer,
|
||||
output_format,
|
||||
is_first=(chunk_index == 0),
|
||||
is_last=False,
|
||||
normalizer=stream_normalizer,
|
||||
lang_code=pipeline_lang_code,
|
||||
return_timestamps=return_timestamps,
|
||||
):
|
||||
# Adjust timestamps relative to the stream start
|
||||
if chunk_data.word_timestamps:
|
||||
for timestamp in chunk_data.word_timestamps:
|
||||
timestamp.start_time += current_offset
|
||||
timestamp.end_time += current_offset
|
||||
|
||||
else:
|
||||
# If no output format (raw audio), yield the raw chunk
|
||||
# Ensure normalization happens if needed (AudioService handles this)
|
||||
pause_chunk.audio = stream_normalizer.normalize(pause_chunk.audio)
|
||||
if len(pause_chunk.audio) > 0: # Only yield if silence is not zero length
|
||||
yield pause_chunk # Yield raw silence chunk
|
||||
# Update offset based on the *actual duration* of the generated audio chunk
|
||||
# Check if audio data exists before calculating duration
|
||||
chunk_duration = 0
|
||||
if chunk_data.audio is not None and len(chunk_data.audio) > 0:
|
||||
chunk_duration = len(chunk_data.audio) / settings.sample_rate
|
||||
current_offset += chunk_duration
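The running-offset bookkeeping above can be sketched as follows; WordTimestamp is a stand-in dataclass, not the repository's timestamp model:

from dataclasses import dataclass

@dataclass
class WordTimestamp:      # stand-in for the real timestamp model
    word: str
    start_time: float
    end_time: float

def shift_into_stream_time(stamps, num_samples, sample_rate, offset):
    # Move per-chunk timestamps into stream time, then advance the offset
    # by the chunk's actual audio duration.
    for ts in stamps:
        ts.start_time += offset
        ts.end_time += offset
    return offset + num_samples / sample_rate

offset = 0.0
stamps = [WordTimestamp("hello", 0.0, 0.4)]
offset = shift_into_stream_time(stamps, num_samples=24000, sample_rate=24000, offset=offset)
# stamps[0] now spans 0.0-0.4 in stream time and offset == 1.0 for the next chunk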
|
||||
|
||||
# Update offset based on silence duration
|
||||
f"No audio generated for chunk: '{chunk_text.strip()[:100]}...'"
|
||||
chunk_index += 1
|
||||
# --- Add pause after newline ---
|
||||
# Check the original chunk_text passed from smart_split for trailing newline
|
||||
if chunk_text.endswith('\n'):
|
||||
# Yield the processed chunk (either formatted or raw)
|
||||
if output_format and chunk_data.output:
|
||||
yield chunk_data
|
||||
elif not output_format and chunk_data.audio is not None and len(chunk_data.audio) > 0:
|
||||
yield chunk_data
|
||||
else:
|
||||
logger.warning(
|
||||
f"No audio generated or output for text chunk: '{text_chunk_for_model[:50]}...'"
|
||||
)
|
||||
|
||||
|
||||
# --- Add pause after newline (if applicable) ---
|
||||
if has_trailing_newline:
|
||||
newline_pause_s = 0.5
|
||||
try:
|
||||
logger.debug(f"Adding {newline_pause_s}s pause after newline.")
|
||||
silence_samples = int(newline_pause_s * settings.sample_rate)
|
||||
silence_audio = np.zeros(silence_samples, dtype=np.float32)
|
||||
pause_chunk = AudioChunk(audio=silence_audio, word_timestamps=[])
|
||||
# Create a *new* AudioChunk instance for the newline pause
|
||||
newline_pause_chunk = AudioChunk(audio=silence_audio, word_timestamps=[])
|
||||
|
||||
if output_format:
|
||||
formatted_pause_chunk = await AudioService.convert_audio(
|
||||
pause_chunk, output_format, writer, speed=1.0, chunk_text="",
|
||||
newline_pause_chunk, output_format, writer, speed=1.0, chunk_text="", # Use newline_pause_chunk
|
||||
is_last_chunk=False, trim_audio=False, normalizer=stream_normalizer,
|
||||
)
|
||||
if formatted_pause_chunk.output:
|
||||
yield formatted_pause_chunk
|
||||
else:
|
||||
pause_chunk.audio = stream_normalizer.normalize(pause_chunk.audio)
|
||||
if len(pause_chunk.audio) > 0:
|
||||
yield pause_chunk
|
||||
# Normalize the *new* chunk before yielding
|
||||
newline_pause_chunk.audio = stream_normalizer.normalize(newline_pause_chunk.audio)
|
||||
if len(newline_pause_chunk.audio) > 0:
|
||||
yield newline_pause_chunk # Yield the normalized newline pause chunk
|
||||
|
||||
current_offset += newline_pause_s # Add newline pause to offset
|
||||
|
||||
except Exception as pause_e:
|
||||
logger.error(f"Failed to process newline pause chunk: {str(pause_e)}")
|
||||
# -------------------------------
|
||||
# ------------------------------------------------
|
||||
|
||||
chunk_index += 1 # Increment chunk index after processing text and potential newline pause
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to process audio for chunk: '{chunk_text.strip()[:100]}...'. Error: {str(e)}"
|
||||
)
|
||||
continue
|
||||
logger.exception( # Use exception to include traceback
|
||||
f"Failed processing audio for chunk: '{text_chunk_for_model[:50]}...'. Error: {str(e)}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Only finalize if we successfully processed at least one chunk
|
||||
# --- End of main loop ---
|
||||
|
||||
# Finalize the stream (sends any remaining buffered data)
|
||||
# Only finalize if we successfully processed at least one chunk (text or pause)
|
||||
if chunk_index > 0:
|
||||
try:
|
||||
# Empty tokens list to finalize audio
|
||||
async for chunk_data in self._process_chunk(
|
||||
"", # Empty text
|
||||
[], # Empty tokens
|
||||
voice_name,
|
||||
voice_path,
|
||||
speed,
|
||||
writer,
|
||||
output_format,
|
||||
is_first=False,
|
||||
is_last=True, # Signal this is the last chunk
|
||||
normalizer=stream_normalizer,
|
||||
lang_code=pipeline_lang_code, # Pass lang_code
|
||||
async for final_chunk_data in self._process_chunk(
|
||||
"", [], voice_name, voice_path, speed, writer, output_format,
|
||||
is_first=False, is_last=True, normalizer=stream_normalizer, lang_code=pipeline_lang_code
|
||||
):
|
||||
if chunk_data.output is not None:
|
||||
yield chunk_data
|
||||
if output_format and final_chunk_data.output:
|
||||
yield final_chunk_data
|
||||
elif not output_format and final_chunk_data.audio is not None and len(final_chunk_data.audio) > 0:
|
||||
yield final_chunk_data # Should yield empty chunk in raw mode upon finalize
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to finalize audio stream: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in phoneme audio generation: {str(e)}")
|
||||
raise e
|
||||
logger.exception(f"Error during audio stream generation: {str(e)}") # Use exception for traceback
|
||||
# Ensure writer is closed on error
|
||||
try:
|
||||
writer.close()
|
||||
except Exception as close_e:
|
||||
logger.error(f"Error closing writer during exception handling: {close_e}")
|
||||
raise e # Re-raise the original exception
|
||||
|
||||
|
||||
async def generate_audio(
|
||||
self,
|
||||
text: str,
|
||||
voice: str,
|
||||
writer: StreamingAudioWriter,
|
||||
writer: StreamingAudioWriter, # Writer needed even for non-streaming internally
|
||||
speed: float = 1.0,
|
||||
return_timestamps: bool = False,
|
||||
normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
@@ -428,37 +474,42 @@ class TTSService:
) -> AudioChunk:
|
||||
"""Generate complete audio for text using streaming internally."""
|
||||
audio_data_chunks = []
|
||||
|
||||
output_format = None # Signal raw audio mode for internal streaming
|
||||
combined_chunk = None
|
||||
try:
|
||||
async for audio_stream_data in self.generate_audio_stream(
|
||||
text,
|
||||
voice,
|
||||
writer,
|
||||
writer, # Pass writer, although it won't be used for formatting here
|
||||
speed=speed,
|
||||
normalization_options=normalization_options,
|
||||
return_timestamps=return_timestamps,
|
||||
lang_code=lang_code,
|
||||
output_format=None,
|
||||
output_format=output_format, # Explicitly None for raw audio
|
||||
):
|
||||
if len(audio_stream_data.audio) > 0:
|
||||
# Ensure we only append chunks with actual audio data
|
||||
# Raw silence chunks generated for pauses will have audio data (zeros)
|
||||
# Formatted silence chunks might have empty audio but non-empty output
|
||||
if len(audio_stream_data.audio) > 0 or (output_format and audio_stream_data.output):
|
||||
audio_data_chunks.append(audio_stream_data)
|
||||
# Ensure we only append chunks with actual audio data
|
||||
# Raw silence chunks generated for pauses will have audio data (zeros)
|
||||
if audio_stream_data.audio is not None and len(audio_stream_data.audio) > 0:
|
||||
audio_data_chunks.append(audio_stream_data)
|
||||
|
||||
if not audio_data_chunks:
|
||||
# Handle cases where only pauses were present or generation failed
|
||||
logger.warning("No valid audio chunks generated.")
|
||||
# Return an empty AudioChunk or raise an error? Returning empty for now.
|
||||
return AudioChunk(audio=np.array([], dtype=np.int16), word_timestamps=[])
|
||||
combined_chunk = AudioChunk(audio=np.array([], dtype=np.int16), word_timestamps=[])
|
||||
else:
|
||||
combined_chunk = AudioChunk.combine(audio_data_chunks)
|
||||
|
||||
|
||||
combined_audio_data = AudioChunk.combine(audio_data_chunks)
|
||||
return combined_audio_data
|
||||
return combined_chunk
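Stripped of error handling, the non-streaming path above does roughly this; np.concatenate stands in for AudioChunk.combine, which in the real service also merges word timestamps:

import numpy as np

async def collect_full_audio(service, text: str, voice: str, writer) -> np.ndarray:
    pieces = []
    # output_format=None puts the stream into raw-audio mode
    async for chunk in service.generate_audio_stream(text, voice, writer, output_format=None):
        if chunk.audio is not None and len(chunk.audio) > 0:
            pieces.append(chunk.audio)
    return np.concatenate(pieces) if pieces else np.array([], dtype=np.float32)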
|
||||
except Exception as e:
|
||||
logger.error(f"Error in audio generation: {str(e)}")
|
||||
raise
|
||||
logger.error(f"Error in combined audio generation: {str(e)}")
|
||||
raise # Re-raise after logging
|
||||
finally:
|
||||
# Explicitly close the writer if it was passed, though it shouldn't hold resources in raw mode
|
||||
try:
|
||||
writer.close()
|
||||
except Exception:
|
||||
pass # Ignore errors during cleanup
|
||||
|
||||
|
||||
|
||||
async def combine_voices(self, voices: List[str]) -> torch.Tensor:
|
||||
"""Combine multiple voices.
@@ -495,38 +546,45 @@ class TTSService:
try:
|
||||
# Get backend and voice path
|
||||
backend = self.model_manager.get_backend()
|
||||
# Use _get_voices_path to handle potential combined voice names passed here too
|
||||
voice_name, voice_path = await self._get_voices_path(voice)
|
||||
|
||||
if isinstance(backend, KokoroV1):
|
||||
# For Kokoro V1, use generate_from_tokens with raw phonemes
|
||||
result = None
|
||||
# Use provided lang_code or determine from voice name
|
||||
pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
|
||||
result_audio = None
|
||||
# Determine language code
|
||||
first_base_voice_match = re.match(r"([a-zA-Z0-9_]+)", voice_name)
|
||||
first_base_voice = first_base_voice_match.group(1) if first_base_voice_match else "a"
|
||||
pipeline_lang_code = lang_code if lang_code else (settings.default_voice_code if settings.default_voice_code else first_base_voice[:1].lower())
|
||||
|
||||
logger.info(
|
||||
f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme pipeline"
|
||||
f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme generation"
|
||||
)
|
||||
|
||||
try:
|
||||
# Use backend's pipeline management
|
||||
for r in backend._get_pipeline(
|
||||
pipeline_lang_code
|
||||
).generate_from_tokens(
|
||||
tokens=phonemes, # Pass raw phonemes string
|
||||
voice=voice_path,
|
||||
speed=speed,
|
||||
):
|
||||
if r.audio is not None:
|
||||
result = r
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate from phonemes: {e}")
|
||||
raise RuntimeError(f"Phoneme generation failed: {e}")
|
||||
# Use backend's pipeline management and iterate through potential chunks
|
||||
full_audio_list = []
|
||||
async for r in backend.generate_from_tokens( # generate_from_tokens is now async
|
||||
tokens=phonemes, # Pass raw phonemes string
|
||||
voice=(voice_name, voice_path), # Pass tuple
|
||||
speed=speed,
|
||||
lang_code=pipeline_lang_code,
|
||||
):
|
||||
if r is not None and len(r) > 0:
|
||||
# r is directly the numpy array chunk
|
||||
full_audio_list.append(r)
|
||||
|
||||
if result is None or result.audio is None:
|
||||
raise ValueError("No audio generated")
|
||||
|
||||
if not full_audio_list:
|
||||
raise ValueError("No audio generated from phonemes")
|
||||
|
||||
# Combine chunks if necessary
|
||||
result_audio = np.concatenate(full_audio_list) if len(full_audio_list) > 1 else full_audio_list[0]
|
||||
|
||||
processing_time = time.time() - start_time
|
||||
return result.audio.numpy(), processing_time
|
||||
# Normalize the final audio before returning
|
||||
normalizer = AudioNormalizer()
|
||||
normalized_audio = normalizer.normalize(result_audio)
|
||||
return normalized_audio, processing_time
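As a rough picture of the last two steps above, the per-chunk arrays are concatenated once and then peak-normalized; the real AudioNormalizer may do more (e.g. int16 conversion), so this is only an approximation:

import numpy as np

def finalize(chunks: list[np.ndarray], headroom: float = 0.95) -> np.ndarray:
    if not chunks:
        return np.array([], dtype=np.float32)
    audio = np.concatenate(chunks) if len(chunks) > 1 else chunks[0]
    peak = float(np.max(np.abs(audio)))
    return audio if peak == 0.0 else (audio / peak) * headroom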
|
||||
else:
|
||||
raise ValueError(
|
||||
"Phoneme generation only supported with Kokoro V1 backend"