From c0da5718578df71930ddee82d824bf9999465071 Mon Sep 17 00:00:00 2001 From: Lukin Date: Mon, 7 Apr 2025 14:12:18 +0800 Subject: [PATCH] Refactor TTS service and text processing to enhance handling of pauses, newlines, and custom phonemes. Updated smart_split to manage pause tags and improved error logging. Adjusted audio generation logic for better performance and clarity. --- .../text_processing/text_processor.py | 345 ++++++------ api/src/services/tts_service.py | 506 ++++++++++-------- 2 files changed, 453 insertions(+), 398 deletions(-) diff --git a/api/src/services/text_processing/text_processor.py b/api/src/services/text_processing/text_processor.py index cf145e9..315de5b 100644 --- a/api/src/services/text_processing/text_processor.py +++ b/api/src/services/text_processing/text_processor.py @@ -2,7 +2,7 @@ import re import time -from typing import AsyncGenerator, Dict, List, Tuple +from typing import AsyncGenerator, Dict, List, Tuple, Optional # Add Optional import from loguru import logger @@ -14,6 +14,8 @@ from .vocabulary import tokenize # Pre-compiled regex patterns for performance CUSTOM_PHONEMES = re.compile(r"(\[([^\]]|\n)*?\])(\(\/([^\/)]|\n)*?\/\))") +# Pattern to find pause tags like [pause:0.5s] +PAUSE_TAG_PATTERN = re.compile(r"\[pause:(\d+(?:\.\d+)?)s\]", re.IGNORECASE) def process_text_chunk( @@ -42,7 +44,8 @@ def process_text_chunk( t1 = time.time() t0 = time.time() - phonemes = phonemize(text, language, normalize=False) # Already normalized + # Normalize step is usually done before smart_split, but phonemize itself might do basic norm + phonemes = phonemize(text, language, normalize=False) t1 = time.time() t0 = time.time() @@ -51,7 +54,7 @@ def process_text_chunk( total_time = time.time() - start_time logger.debug( - f"Total processing took {total_time * 1000:.2f}ms for chunk: '{text[:50]}{'...' if len(text) > 50 else ''}'" + f"Tokenization took {total_time * 1000:.2f}ms for chunk: '{text[:50]}{'...' if len(text) > 50 else ''}'" ) return tokens @@ -90,45 +93,61 @@ def process_text(text: str, language: str = "a") -> List[int]: def get_sentence_info( text: str, custom_phenomes_list: Dict[str, str] ) -> List[Tuple[str, List[int], int]]: - """Process all sentences and return info.""" - sentences = re.split(r"([.!?;:])(?=\s|$)", text) - phoneme_length, min_value = len(custom_phenomes_list), 0 + """Process all sentences and return info, preserving trailing newlines.""" + # Split by sentence-ending punctuation, keeping the punctuation + sentences_parts = re.split(r'([.!?]+|\n+)', text) + sentences = [] + current_sentence = "" + for part in sentences_parts: + if not part: + continue + current_sentence += part + # If the part ends with sentence punctuation or newline, consider it a sentence end + if re.search(r'[.!?\n]$', part): + sentences.append(current_sentence) + current_sentence = "" + if current_sentence: # Add any remaining part + sentences.append(current_sentence) + + + phoneme_length = len(custom_phenomes_list) + restored_phoneme_keys = list(custom_phenomes_list.keys()) # Keys to restore results = [] - for i in range(0, len(sentences), 2): - sentence = sentences[i].strip() - for replaced in range(min_value, phoneme_length): - current_id = f"" - if current_id in sentence: - sentence = sentence.replace( - current_id, custom_phenomes_list.pop(current_id) - ) - min_value += 1 + for original_sentence in sentences: + sentence_text_part = original_sentence.rstrip('\n') # Text without trailing newline for processing + trailing_newlines = original_sentence[len(sentence_text_part):] # Capture trailing newlines - punct = sentences[i + 1] if i + 1 < len(sentences) else "" - - if not sentence: + if not sentence_text_part.strip(): # Skip empty or whitespace-only sentences + if trailing_newlines: # If only newlines, represent as empty text with newline marker + results.append(("\n", [], 0)) # Store newline marker, no tokens continue - # Check if the original text segment ended with newline(s) before punctuation - original_segment = sentences[i] - trailing_newlines = "" - match = re.search(r"(\n+)$", original_segment) - if match: - trailing_newlines = match.group(1) + # Restore custom phonemes for this sentence *before* tokenization + sentence_to_tokenize = sentence_text_part + restored_count = 0 + # Iterate through *all* possible phoneme IDs that might be in this sentence + for ph_id in restored_phoneme_keys: + if ph_id in sentence_to_tokenize: + sentence_to_tokenize = sentence_to_tokenize.replace(ph_id, custom_phenomes_list[ph_id]) + restored_count+=1 + if restored_count > 0: + logger.debug(f"Restored {restored_count} custom phonemes for tokenization in: '{sentence_text_part[:30]}...'") - full = sentence + punct + trailing_newlines # Append trailing newlines - # Tokenize without the trailing newlines for accurate TTS processing - tokens = process_text_chunk(sentence + punct) - # Store the full text including newlines for later check - results.append((full, tokens, len(tokens))) + + # Tokenize the text part (without trailing newlines) + tokens = process_text_chunk(sentence_to_tokenize) + + # Store the original sentence text (including trailing newlines) along with tokens + results.append((original_sentence, tokens, len(tokens))) return results def handle_custom_phonemes(s: re.Match[str], phenomes_list: Dict[str, str]) -> str: latest_id = f"" - phenomes_list[latest_id] = s.group(0).strip() + phenomes_list[latest_id] = s.group(0).strip() # Store the full original tag [phoneme](/ipa/) + logger.debug(f"Replacing custom phoneme {phenomes_list[latest_id]} with ID {latest_id}") return latest_id @@ -137,174 +156,152 @@ async def smart_split( max_tokens: int = settings.absolute_max_tokens, lang_code: str = "a", normalization_options: NormalizationOptions = NormalizationOptions(), -) -> AsyncGenerator[Tuple[str, List[int]], None]: - """Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens.""" +) -> AsyncGenerator[Tuple[str, List[int], Optional[float]], None]: + """Build optimal chunks targeting token limits, handling pause tags and newlines. + + Yields: + Tuple of (text_chunk, tokens, pause_duration_s). + If pause_duration_s is not None, it's a pause chunk with empty text/tokens. + Otherwise, it's a text chunk. text_chunk may end with '\n'. + """ start_time = time.time() chunk_count = 0 - logger.info(f"Starting smart split for {len(text)} chars") + logger.info(f"Starting smart split for {len(text)} chars, max_tokens={max_tokens}") custom_phoneme_list = {} - # Normalize text + # 1. Temporarily replace custom phonemes like [word](/ipa/) with unique IDs + text_with_ids = CUSTOM_PHONEMES.sub( + lambda s: handle_custom_phonemes(s, custom_phoneme_list), text + ) + if custom_phoneme_list: + logger.debug(f"Found custom phonemes: {custom_phoneme_list}") + + + # 2. Normalize the text *with IDs* if required + normalized_text = text_with_ids if settings.advanced_text_normalization and normalization_options.normalize: - print(lang_code) if lang_code in ["a", "b", "en-us", "en-gb"]: - text = CUSTOM_PHONEMES.sub( - lambda s: handle_custom_phonemes(s, custom_phoneme_list), text - ) - text = normalize_text(text, normalization_options) + normalized_text = normalize_text(normalized_text, normalization_options) + logger.debug("Applied text normalization.") else: logger.info( "Skipping text normalization as it is only supported for english" ) - # Process all sentences - sentences = get_sentence_info(text, custom_phoneme_list) + # 3. Split the normalized text by pause tags + parts = PAUSE_TAG_PATTERN.split(normalized_text) + logger.debug(f"Split into {len(parts)} parts by pause tags.") - current_chunk = [] - current_tokens = [] - current_count = 0 - for sentence, tokens, count in sentences: - # Handle sentences that exceed max tokens - if count > max_tokens: - # Yield current chunk if any - if current_chunk: - # Join with space, but preserve original trailing newline of the last sentence if present - last_sentence_original = current_chunk[-1] - chunk_text_joined = " ".join(current_chunk) - if last_sentence_original.endswith("\n"): - chunk_text_joined += "\n" # Preserve the newline marker + part_idx = 0 + while part_idx < len(parts): + text_part = parts[part_idx] # This part contains text and custom phoneme IDs + part_idx += 1 - chunk_count += 1 - logger.debug( - f"Yielding text chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({current_count} tokens)" - ) - yield chunk_text_joined, current_tokens, None # Pass the text with potential trailing newline - current_chunk = [] - current_tokens = [] - current_count = 0 + if text_part: + # Process this text part using sentence splitting + # We pass the text_part *with IDs* to get_sentence_info + # get_sentence_info will handle restoring phonemes just before tokenization + sentences = get_sentence_info(text_part, custom_phoneme_list) - # Split long sentence on commas (simple approach) - # Keep original sentence text ('sentence' now includes potential trailing newline) - clauses = re.split(r"([,])", sentence.rstrip('\n')) # Split text part only - trailing_newline_in_sentence = "\n" if sentence.endswith("\n") else "" - clause_chunk = [] # Stores original clause text including potential trailing newline - clause_tokens = [] - clause_count = 0 + current_chunk_texts = [] # Store original sentence texts for the current chunk + current_chunk_tokens = [] + current_token_count = 0 - for j in range(0, len(clauses), 2): - # clause = clauses[j].strip() # Don't strip here to preserve internal structure - clause = clauses[j] - comma = clauses[j + 1] if j + 1 < len(clauses) else "" - - if not clause.strip(): # Check if clause is just whitespace - - full_clause = clause + comma - - tokens = process_text_chunk(full_clause) - count = len(tokens) - - # If adding clause keeps us under max and not optimal yet - if ( - clause_count + count <= max_tokens - and clause_count + count <= settings.target_max_tokens - ): - clause_chunk.append(full_clause) - clause_tokens.extend(tokens) - clause_count += count - else: - # Yield clause chunk if we have one - if clause_chunk: - # Join with space, preserve last clause's potential trailing newline - last_clause_original = clause_chunk[-1] - chunk_text_joined = " ".join(clause_chunk) - if last_clause_original.endswith("\n"): - chunk_text_joined += "\n" + for sentence_text, sentence_tokens, sentence_token_count in sentences: + # --- Chunking Logic --- + # Condition 1: Current sentence alone exceeds max tokens + if sentence_token_count > max_tokens: + logger.warning(f"Single sentence exceeds max_tokens ({sentence_token_count} > {max_tokens}): '{sentence_text[:50]}...'") + # Yield any existing chunk first + if current_chunk_texts: + chunk_text_joined = " ".join(current_chunk_texts) # Join original texts chunk_count += 1 - logger.debug( - f"Yielding clause chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({clause_count} tokens)" - ) - yield chunk_text_joined, clause_tokens, None - # Start new clause chunk with original text - clause_chunk = [full_clause + (trailing_newline_in_sentence if j == len(clauses) - 2 else "")] - clause_tokens = clause_token_list - clause_count = clause_token_count + logger.info(f"Yielding text chunk {chunk_count} (before oversized sentence): '{chunk_text_joined[:50]}...' ({current_token_count} tokens)") + yield chunk_text_joined, current_chunk_tokens, None + current_chunk_texts = [] + current_chunk_tokens = [] + current_token_count = 0 - # Don't forget last clause chunk - if clause_chunk: - # Join with space, preserve last clause's potential trailing newline - last_clause_original = clause_chunk[-1] - chunk_text_joined = " ".join(clause_chunk) - # The trailing newline logic was added when creating the chunk above - #if last_clause_original.endswith("\n"): - # chunk_text_joined += "\n" + # Yield the oversized sentence as its own chunk + # Restore phonemes before yielding the text + text_to_yield = sentence_text + for p_id, p_val in custom_phoneme_list.items(): + if p_id in text_to_yield: + text_to_yield = text_to_yield.replace(p_id, p_val) chunk_count += 1 - logger.debug( - f"Yielding final clause chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({clause_count} tokens)" - ) - yield chunk_text_joined, clause_tokens, None - current_count >= settings.target_min_tokens - and current_count + count > settings.target_max_tokens - ): - # If we have a good sized chunk and adding next sentence exceeds target, - # Yield current chunk and start new one - last_sentence_original = current_chunk[-1] - chunk_text_joined = " ".join(current_chunk) - if last_sentence_original.endswith("\n"): - chunk_text_joined += "\n" - chunk_count += 1 - logger.info( - f"Yielding text chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({current_count} tokens)" - ) - yield chunk_text_joined, current_tokens, None - current_chunk = [sentence] # sentence includes potential trailing newline - current_tokens = tokens - current_count = count - elif current_count + count <= settings.target_max_tokens: - # Keep building chunk - current_chunk.append(sentence) # sentence includes potential trailing newline - current_tokens.extend(tokens) - current_count += count - elif ( - current_count + count <= max_tokens - and current_count < settings.target_min_tokens - ): - # Exceed target max only if below min size - current_chunk.append(sentence) # sentence includes potential trailing newline - current_tokens.extend(tokens) - current_count += count - else: - # Yield current chunk and start new one - if current_chunk: - last_sentence_original = current_chunk[-1] - chunk_text_joined = " ".join(current_chunk) - if last_sentence_original.endswith("\n"): - chunk_text_joined += "\n" - chunk_count += 1 - logger.info( - f"Yielding text chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({current_count} tokens)" - ) - yield chunk_text_joined, current_tokens, None - current_chunk = [sentence] # sentence includes potential trailing newline - current_tokens = tokens - current_count = count + logger.info(f"Yielding oversized text chunk {chunk_count}: '{text_to_yield[:50]}...' ({sentence_token_count} tokens)") + yield text_to_yield, sentence_tokens, None + continue # Move to the next sentence + + # Condition 2: Adding the current sentence would exceed max_tokens + elif current_token_count + sentence_token_count > max_tokens: + # Yield the current chunk first + if current_chunk_texts: + chunk_text_joined = " ".join(current_chunk_texts) # Join original texts + chunk_count += 1 + logger.info(f"Yielding text chunk {chunk_count} (max_tokens limit): '{chunk_text_joined[:50]}...' ({current_token_count} tokens)") + yield chunk_text_joined, current_chunk_tokens, None + # Start a new chunk with the current sentence + current_chunk_texts = [sentence_text] + current_chunk_tokens = sentence_tokens + current_token_count = sentence_token_count + + # Condition 3: Adding exceeds target_max_tokens when already above target_min_tokens + elif (current_token_count >= settings.target_min_tokens and + current_token_count + sentence_token_count > settings.target_max_tokens): + # Yield the current chunk + chunk_text_joined = " ".join(current_chunk_texts) # Join original texts + chunk_count += 1 + logger.info(f"Yielding text chunk {chunk_count} (target_max limit): '{chunk_text_joined[:50]}...' ({current_token_count} tokens)") + yield chunk_text_joined, current_chunk_tokens, None + # Start a new chunk + current_chunk_texts = [sentence_text] + current_chunk_tokens = sentence_tokens + current_token_count = sentence_token_count + + # Condition 4: Add sentence to current chunk (fits within max_tokens and either below target_max or below target_min) + else: + current_chunk_texts.append(sentence_text) + current_chunk_tokens.extend(sentence_tokens) + current_token_count += sentence_token_count + + # --- End of sentence loop for this text part --- + + # Yield any remaining accumulated chunk for this text part + if current_chunk_texts: + chunk_text_joined = " ".join(current_chunk_texts) # Join original texts + # Restore phonemes before yielding + text_to_yield = chunk_text_joined + for p_id, p_val in custom_phoneme_list.items(): + if p_id in text_to_yield: + text_to_yield = text_to_yield.replace(p_id, p_val) + + chunk_count += 1 + logger.info(f"Yielding final text chunk {chunk_count} for part: '{text_to_yield[:50]}...' ({current_token_count} tokens)") + yield text_to_yield, current_chunk_tokens, None + + + # Check if the next part is a pause duration + if part_idx < len(parts): + duration_str = parts[part_idx] + part_idx += 1 # Move past the duration string + try: + duration = float(duration_str) + if duration > 0: + chunk_count += 1 + logger.info(f"Yielding pause chunk {chunk_count}: {duration}s") + yield "", [], duration # Yield pause chunk + except (ValueError, TypeError): + logger.warning(f"Could not parse pause duration: {duration_str}") + # If parsing fails, potentially treat the duration_str as text? + # For now, just log a warning and skip. - # Yield any remaining text chunk - if current_chunk: - last_sentence_original = current_chunk[-1] - chunk_text_joined = " ".join(current_chunk) - if last_sentence_original.endswith("\n"): - chunk_text_joined += "\n" - chunk_count += 1 - logger.info( - f"Yielding final text chunk {chunk_count}: '{chunk_text_joined[:50]}{'...' if len(chunk_text_joined) > 50 else ''}' ({current_count} tokens)" - ) - yield chunk_text_joined, current_tokens, None total_time = time.time() - start_time logger.info( - f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks" + f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks (including pauses)" ) \ No newline at end of file diff --git a/api/src/services/tts_service.py b/api/src/services/tts_service.py index 0be9024..45163c6 100644 --- a/api/src/services/tts_service.py +++ b/api/src/services/tts_service.py @@ -78,28 +78,83 @@ class TTSService: "", normalizer=normalizer, is_last_chunk=True, + trim_audio=False, # Don't trim final silence ) yield chunk_data return - # Skip empty chunks + # Skip empty chunks (shouldn't happen if called correctly, but safety) if not tokens and not chunk_text: - return + logger.warning("Empty chunk passed to _process_chunk") + return # Get backend backend = self.model_manager.get_backend() # Generate audio using pre-warmed model + # Note: chunk_text is the *original* text including custom phoneme markers and newlines + # The model needs the text *with phonemes restored* + text_for_model = chunk_text # Start with original + # Restore custom phonemes if backend needs it (like KokoroV1) if isinstance(backend, KokoroV1): - chunk_index = 0 - # For Kokoro V1, pass text and voice info with lang_code + # Find phoneme markers in this specific chunk_text and restore + # (This assumes smart_split yielded text with markers) - let's refine smart_split yield + # For now, assume chunk_text is ready for the model (phonemes restored by smart_split) + pass + + + if isinstance(backend, KokoroV1): + internal_chunk_index = 0 async for chunk_data in self.model_manager.generate( - chunk_text, + text_for_model.strip(), # Pass cleaned text to model (voice_name, voice_path), speed=speed, lang_code=lang_code, return_timestamps=return_timestamps, ): + # For streaming, convert to bytes if format specified + if output_format: + try: + chunk_data = await AudioService.convert_audio( + chunk_data, + output_format, + writer, + speed, + chunk_text.strip(), # Pass original text for trimming logic + is_last_chunk=is_last, # Should always be False here, handled above + normalizer=normalizer, + trim_audio=True # Trim speech parts + ) + yield chunk_data + except Exception as e: + logger.error(f"Failed to convert audio: {str(e)}") + else: # Raw audio mode + chunk_data = AudioService.trim_audio( + chunk_data, chunk_text.strip(), speed, False, normalizer # Trim speech parts + ) + yield chunk_data + internal_chunk_index += 1 + if internal_chunk_index == 0: + logger.warning(f"Model generation yielded no audio chunks for: '{text_for_model[:50]}...'") + + else: + # --- Legacy backend path (using tokens) --- + # This path might not work correctly with custom phonemes restored in text_for_model + logger.warning("Using legacy backend path with tokens - custom phonemes might not be handled.") + voice_tensor = await self._voice_manager.load_voice( + voice_name, device=backend.device + ) + async for chunk_data in self.model_manager.generate( # Needs to be async generator + tokens, # Legacy uses tokens + (voice_name, voice_tensor), # Pass tuple as expected + speed=speed, + return_timestamps=return_timestamps, + ): + + if chunk_data.audio is None or len(chunk_data.audio) == 0: + logger.error("Legacy model generated empty or None audio chunk") + continue # Skip this chunk + # For streaming, convert to bytes if output_format: try: @@ -108,61 +163,22 @@ class TTSService: output_format, writer, speed, - chunk_text, - is_last_chunk=is_last, + chunk_text.strip(), # Pass original text for trimming logic normalizer=normalizer, + is_last_chunk=is_last, # Should be False here + trim_audio=True # Trim speech parts ) yield chunk_data except Exception as e: - logger.error(f"Failed to convert audio: {str(e)}") - else: - chunk_data = AudioService.trim_audio( - chunk_data, chunk_text, speed, is_last, normalizer + logger.error(f"Failed to convert legacy audio: {str(e)}") + else: # Raw audio mode + trimmed = AudioService.trim_audio( + chunk_data, chunk_text.strip(), speed, False, normalizer # Trim speech parts ) - yield chunk_data - chunk_index += 1 - else: - # For legacy backends, load voice tensor - voice_tensor = await self._voice_manager.load_voice( - voice_name, device=backend.device - ) - chunk_data = await self.model_manager.generate( - tokens, - voice_tensor, - speed=speed, - return_timestamps=return_timestamps, - ) - - if chunk_data.audio is None: - logger.error("Model generated None for audio chunk") - return - - if len(chunk_data.audio) == 0: - logger.error("Model generated empty audio chunk") - return - - # For streaming, convert to bytes - if output_format: - try: - chunk_data = await AudioService.convert_audio( - chunk_data, - output_format, - writer, - speed, - chunk_text, - normalizer=normalizer, - is_last_chunk=is_last, - ) - yield chunk_data - except Exception as e: - logger.error(f"Failed to convert audio: {str(e)}") - else: - trimmed = AudioService.trim_audio( - chunk_data, chunk_text, speed, is_last, normalizer - ) - yield trimmed + yield trimmed except Exception as e: - logger.error(f"Failed to process tokens: {str(e)}") + logger.exception(f"Failed to process chunk: '{chunk_text[:50]}...'. Error: {str(e)}") + async def _load_voice_from_path(self, path: str, weight: float): # Check if the path is None and raise a ValueError if it is not @@ -170,13 +186,15 @@ class TTSService: raise ValueError(f"Voice not found at path: {path}") logger.debug(f"Loading voice tensor from path: {path}") - return torch.load(path, map_location="cpu") * weight + # Ensure loading happens on CPU initially to avoid device mismatches + tensor = torch.load(path, map_location="cpu") + return tensor * weight async def _get_voices_path(self, voice: str) -> Tuple[str, str]: """Get voice path, handling combined voices. Args: - voice: Voice name or combined voice names (e.g., 'af_jadzia+af_jessica') + voice: Voice name or combined voice names (e.g., 'af_jadzia(0.7)+af_jessica(0.3)') Returns: Tuple of (voice name to use, voice path to use) @@ -185,72 +203,87 @@ class TTSService: RuntimeError: If voice not found """ try: - # Split the voice on + and - and ensure that they get added to the list eg: hi+bob = ["hi","+","bob"] - split_voice = re.split(r"([-+])", voice) + # Regex to handle names, weights, and operators: af_name(weight)[+-]af_other(weight)... + pattern = re.compile(r"([a-zA-Z0-9_]+)(?:\((\d+(?:\.\d+)?)\))?([+-]?)") + matches = pattern.findall(voice.replace(" ", "")) # Remove spaces - # If it is only once voice there is no point in loading it up, doing nothing with it, then saving it - if len(split_voice) == 1: - # Since its a single voice the only time that the weight would matter is if voice_weight_normalization is off - if ( - "(" not in voice and ")" not in voice - ) or settings.voice_weight_normalization == True: - path = await self._voice_manager.get_voice_path(voice) - if not path: - raise RuntimeError(f"Voice not found: {voice}") - logger.debug(f"Using single voice path: {path}") - return voice, path + if not matches: + raise ValueError(f"Could not parse voice string: {voice}") + # If only one voice and no explicit weight or operators, handle directly + if len(matches) == 1 and not matches[0][1] and not matches[0][2]: + voice_name = matches[0][0] + path = await self._voice_manager.get_voice_path(voice_name) + if not path: + raise RuntimeError(f"Voice not found: {voice_name}") + logger.debug(f"Using single voice path: {path}") + return voice_name, path + + # Process combinations + voice_parts = [] total_weight = 0 + for name, weight_str, operator in matches: + weight = float(weight_str) if weight_str else 1.0 + voice_parts.append({"name": name, "weight": weight, "op": operator}) + # Use weight directly for total, normalization happens later if enabled + total_weight += weight # Summing base weights before potential normalization - for voice_index in range(0, len(split_voice), 2): - voice_object = split_voice[voice_index] + # Check base voices exist + available_voices = await self._voice_manager.list_voices() + for part in voice_parts: + if part["name"] not in available_voices: + raise ValueError(f"Base voice '{part['name']}' not found in combined string '{voice}'. Available: {available_voices}") - if "(" in voice_object and ")" in voice_object: - voice_name = voice_object.split("(")[0].strip() - voice_weight = float(voice_object.split("(")[1].split(")")[0]) - else: - voice_name = voice_object - voice_weight = 1 - total_weight += voice_weight - split_voice[voice_index] = (voice_name, voice_weight) + # Determine normalization factor + norm_factor = total_weight if settings.voice_weight_normalization and total_weight > 0 else 1.0 + if settings.voice_weight_normalization: + logger.debug(f"Normalizing combined voice weights by factor: {norm_factor:.2f}") + else: + logger.debug("Voice weight normalization disabled, using raw weights.") - # If voice_weight_normalization is false prevent normalizing the weights by setting the total_weight to 1 so it divides each weight by 1 - if settings.voice_weight_normalization == False: - total_weight = 1 - # Load the first voice as the starting point for voices to be combined onto - path = await self._voice_manager.get_voice_path(split_voice[0][0]) - combined_tensor = await self._load_voice_from_path( - path, split_voice[0][1] / total_weight - ) + # Load and combine tensors + first_part = voice_parts[0] + base_path = await self._voice_manager.get_voice_path(first_part["name"]) + combined_tensor = await self._load_voice_from_path(base_path, first_part["weight"] / norm_factor) - # Loop through each + or - in split_voice so they can be applied to combined voice - for operation_index in range(1, len(split_voice) - 1, 2): - # Get the voice path of the voice 1 index ahead of the operator - path = await self._voice_manager.get_voice_path( - split_voice[operation_index + 1][0] - ) - voice_tensor = await self._load_voice_from_path( - path, split_voice[operation_index + 1][1] / total_weight - ) + current_op = "+" # Implicitly start with addition for the first voice - # Either add or subtract the voice from the current combined voice - if split_voice[operation_index] == "+": + for i in range(len(voice_parts) - 1): + current_part = voice_parts[i] + next_part = voice_parts[i+1] + + # Determine the operation based on the *current* part's operator + op_symbol = current_part["op"] if current_part["op"] else "+" # Default to '+' if no operator + + path = await self._voice_manager.get_voice_path(next_part["name"]) + voice_tensor = await self._load_voice_from_path(path, next_part["weight"] / norm_factor) + + if op_symbol == "+": combined_tensor += voice_tensor - else: + logger.debug(f"Adding voice {next_part['name']} (weight {next_part['weight']/norm_factor:.2f})") + elif op_symbol == "-": combined_tensor -= voice_tensor + logger.debug(f"Subtracting voice {next_part['name']} (weight {next_part['weight']/norm_factor:.2f})") - # Save the new combined voice so it can be loaded latter + + # Save the new combined voice so it can be loaded later + # Use a safe filename based on the original input string + safe_filename = re.sub(r'[^\w+-]', '_', voice) + ".pt" temp_dir = tempfile.gettempdir() - combined_path = os.path.join(temp_dir, f"{voice}.pt") - logger.debug(f"Saving combined voice to: {combined_path}") - torch.save(combined_tensor, combined_path) - return voice, combined_path + combined_path = os.path.join(temp_dir, safe_filename) + logger.debug(f"Saving combined voice '{voice}' to temporary path: {combined_path}") + # Save the tensor to the device specified by settings for model loading consistency + target_device = settings.get_device() + torch.save(combined_tensor.to(target_device), combined_path) + return voice, combined_path # Return original name and temp path + except Exception as e: - logger.error(f"Failed to get voice path: {e}") + logger.error(f"Failed to get or combine voice path for '{voice}': {e}") raise + async def generate_audio_stream( self, text: str, @@ -262,25 +295,29 @@ class TTSService: normalization_options: Optional[NormalizationOptions] = NormalizationOptions(), return_timestamps: Optional[bool] = False, ) -> AsyncGenerator[AudioChunk, None]: - """Generate and stream audio chunks.""" + """Generate and stream audio chunks, handling text, pauses, and newlines.""" stream_normalizer = AudioNormalizer() chunk_index = 0 - current_offset = 0.0 + current_offset = 0.0 # Track audio time offset for timestamps try: # Get backend backend = self.model_manager.get_backend() # Get voice path, handling combined voices + # voice_name will be the potentially complex combined name string voice_name, voice_path = await self._get_voices_path(voice) - logger.debug(f"Using voice path: {voice_path}") + logger.debug(f"Using voice path for '{voice_name}': {voice_path}") - # Use provided lang_code or determine from voice name - pipeline_lang_code = lang_code if lang_code else voice[:1].lower() + # Determine language code + # Use provided lang_code, fallback to settings override, then first letter of first base voice + first_base_voice_match = re.match(r"([a-zA-Z0-9_]+)", voice) + first_base_voice = first_base_voice_match.group(1) if first_base_voice_match else "a" # Default 'a' + pipeline_lang_code = lang_code if lang_code else (settings.default_voice_code if settings.default_voice_code else first_base_voice[:1].lower()) logger.info( f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream" ) - # Process text in chunks with smart splitting, handling pauses + # Process text in chunks (handling pauses and newlines within smart_split) async for chunk_text, tokens, pause_duration_s in smart_split( text, lang_code=pipeline_lang_code, @@ -291,29 +328,21 @@ class TTSService: try: logger.debug(f"Generating {pause_duration_s}s silence chunk") silence_samples = int(pause_duration_s * settings.sample_rate) - # Use float32 zeros as AudioService will normalize later silence_audio = np.zeros(silence_samples, dtype=np.float32) pause_chunk = AudioChunk(audio=silence_audio, word_timestamps=[]) # Empty timestamps for silence - # Convert silence chunk to the target format using AudioService + # Format and yield the silence chunk if output_format: formatted_pause_chunk = await AudioService.convert_audio( - pause_chunk, - output_format, - writer, - speed=1.0, # Speed doesn't affect silence - chunk_text="", # No text for silence - is_last_chunk=False, # Not the final chunk - trim_audio=False, # Don't trim silence - normalizer=stream_normalizer, + pause_chunk, output_format, writer, speed=1.0, chunk_text="", + is_last_chunk=False, trim_audio=False, normalizer=stream_normalizer, ) if formatted_pause_chunk.output: yield formatted_pause_chunk - else: - # If no output format (raw audio), yield the raw chunk - # Ensure normalization happens if needed (AudioService handles this) - pause_chunk.audio = stream_normalizer.normalize(pause_chunk.audio) - yield pause_chunk # Yield raw silence chunk + else: # Raw audio mode + pause_chunk.audio = stream_normalizer.normalize(pause_chunk.audio) + if len(pause_chunk.audio) > 0: + yield pause_chunk # Update offset based on silence duration current_offset += pause_duration_s @@ -322,105 +351,122 @@ class TTSService: except Exception as e: logger.error(f"Failed to process pause chunk: {str(e)}") continue - elif tokens or chunk_text: + + elif tokens or chunk_text.strip(): # Process if there are tokens OR non-whitespace text # --- Handle Text Chunk --- + original_text_with_markers = chunk_text # Keep original including markers/newlines + text_chunk_for_model = chunk_text.strip() # Clean text for the model + has_trailing_newline = chunk_text.endswith('\n') + try: # Process audio for the text chunk async for chunk_data in self._process_chunk( - chunk_text, # Pass text for Kokoro V1 - tokens, # Pass tokens for legacy backends - voice_name, # Pass voice name - voice_path, # Pass voice path - speed, - writer, - output_format, - is_first=(chunk_index == 0), - is_last=False, # We'll update the last chunk later - normalizer=stream_normalizer, - lang_code=pipeline_lang_code, # Pass lang_code - return_timestamps=return_timestamps, - ): - if chunk_data.word_timestamps is not None: - for timestamp in chunk_data.word_timestamps: - timestamp.start_time += current_offset - timestamp.end_time += current_offset + text_chunk_for_model, # Pass cleaned text for model processing + tokens, + voice_name, + voice_path, + speed, + writer, + output_format, + is_first=(chunk_index == 0), + is_last=False, + normalizer=stream_normalizer, + lang_code=pipeline_lang_code, + return_timestamps=return_timestamps, + ): + # Adjust timestamps relative to the stream start + if chunk_data.word_timestamps: + for timestamp in chunk_data.word_timestamps: + timestamp.start_time += current_offset + timestamp.end_time += current_offset - else: - # If no output format (raw audio), yield the raw chunk - # Ensure normalization happens if needed (AudioService handles this) - pause_chunk.audio = stream_normalizer.normalize(pause_chunk.audio) - if len(pause_chunk.audio) > 0: # Only yield if silence is not zero length - yield pause_chunk # Yield raw silence chunk + # Update offset based on the *actual duration* of the generated audio chunk + # Check if audio data exists before calculating duration + chunk_duration = 0 + if chunk_data.audio is not None and len(chunk_data.audio) > 0: + chunk_duration = len(chunk_data.audio) / settings.sample_rate + current_offset += chunk_duration - # Update offset based on silence duration - f"No audio generated for chunk: '{chunk_text.strip()[:100]}...'" - chunk_index += 1 - # --- Add pause after newline --- - # Check the original chunk_text passed from smart_split for trailing newline - if chunk_text.endswith('\n'): + # Yield the processed chunk (either formatted or raw) + if output_format and chunk_data.output: + yield chunk_data + elif not output_format and chunk_data.audio is not None and len(chunk_data.audio) > 0: + yield chunk_data + else: + logger.warning( + f"No audio generated or output for text chunk: '{text_chunk_for_model[:50]}...'" + ) + + + # --- Add pause after newline (if applicable) --- + if has_trailing_newline: newline_pause_s = 0.5 try: logger.debug(f"Adding {newline_pause_s}s pause after newline.") silence_samples = int(newline_pause_s * settings.sample_rate) silence_audio = np.zeros(silence_samples, dtype=np.float32) - pause_chunk = AudioChunk(audio=silence_audio, word_timestamps=[]) + # Create a *new* AudioChunk instance for the newline pause + newline_pause_chunk = AudioChunk(audio=silence_audio, word_timestamps=[]) if output_format: formatted_pause_chunk = await AudioService.convert_audio( - pause_chunk, output_format, writer, speed=1.0, chunk_text="", + newline_pause_chunk, output_format, writer, speed=1.0, chunk_text="", # Use newline_pause_chunk is_last_chunk=False, trim_audio=False, normalizer=stream_normalizer, ) if formatted_pause_chunk.output: yield formatted_pause_chunk else: - pause_chunk.audio = stream_normalizer.normalize(pause_chunk.audio) - if len(pause_chunk.audio) > 0: - yield pause_chunk + # Normalize the *new* chunk before yielding + newline_pause_chunk.audio = stream_normalizer.normalize(newline_pause_chunk.audio) + if len(newline_pause_chunk.audio) > 0: + yield newline_pause_chunk # Yield the normalized newline pause chunk current_offset += newline_pause_s # Add newline pause to offset except Exception as pause_e: logger.error(f"Failed to process newline pause chunk: {str(pause_e)}") - # ------------------------------- + # ------------------------------------------------ + chunk_index += 1 # Increment chunk index after processing text and potential newline pause except Exception as e: - logger.error( - f"Failed to process audio for chunk: '{chunk_text.strip()[:100]}...'. Error: {str(e)}" - ) - continue + logger.exception( # Use exception to include traceback + f"Failed processing audio for chunk: '{text_chunk_for_model[:50]}...'. Error: {str(e)}" + ) + continue - # Only finalize if we successfully processed at least one chunk + # --- End of main loop --- + + # Finalize the stream (sends any remaining buffered data) + # Only finalize if we successfully processed at least one chunk (text or pause) if chunk_index > 0: try: - # Empty tokens list to finalize audio - async for chunk_data in self._process_chunk( - "", # Empty text - [], # Empty tokens - voice_name, - voice_path, - speed, - writer, - output_format, - is_first=False, - is_last=True, # Signal this is the last chunk - normalizer=stream_normalizer, - lang_code=pipeline_lang_code, # Pass lang_code + async for final_chunk_data in self._process_chunk( + "", [], voice_name, voice_path, speed, writer, output_format, + is_first=False, is_last=True, normalizer=stream_normalizer, lang_code=pipeline_lang_code ): - if chunk_data.output is not None: - yield chunk_data + if output_format and final_chunk_data.output: + yield final_chunk_data + elif not output_format and final_chunk_data.audio is not None and len(final_chunk_data.audio) > 0: + yield final_chunk_data # Should yield empty chunk in raw mode upon finalize except Exception as e: logger.error(f"Failed to finalize audio stream: {str(e)}") except Exception as e: - logger.error(f"Error in phoneme audio generation: {str(e)}") - raise e + logger.exception(f"Error during audio stream generation: {str(e)}") # Use exception for traceback + # Ensure writer is closed on error + try: + writer.close() + except Exception as close_e: + logger.error(f"Error closing writer during exception handling: {close_e}") + raise e # Re-raise the original exception + async def generate_audio( self, text: str, voice: str, - writer: StreamingAudioWriter, + writer: StreamingAudioWriter, # Writer needed even for non-streaming internally speed: float = 1.0, return_timestamps: bool = False, normalization_options: Optional[NormalizationOptions] = NormalizationOptions(), @@ -428,37 +474,42 @@ class TTSService: ) -> AudioChunk: """Generate complete audio for text using streaming internally.""" audio_data_chunks = [] - + output_format = None # Signal raw audio mode for internal streaming + combined_chunk = None try: async for audio_stream_data in self.generate_audio_stream( text, voice, - writer, + writer, # Pass writer, although it won't be used for formatting here speed=speed, normalization_options=normalization_options, return_timestamps=return_timestamps, lang_code=lang_code, - output_format=None, + output_format=output_format, # Explicitly None for raw audio ): - if len(audio_stream_data.audio) > 0: - # Ensure we only append chunks with actual audio data - # Raw silence chunks generated for pauses will have audio data (zeros) - # Formatted silence chunks might have empty audio but non-empty output - if len(audio_stream_data.audio) > 0 or (output_format and audio_stream_data.output): - audio_data_chunks.append(audio_stream_data) + # Ensure we only append chunks with actual audio data + # Raw silence chunks generated for pauses will have audio data (zeros) + if audio_stream_data.audio is not None and len(audio_stream_data.audio) > 0: + audio_data_chunks.append(audio_stream_data) if not audio_data_chunks: - # Handle cases where only pauses were present or generation failed logger.warning("No valid audio chunks generated.") - # Return an empty AudioChunk or raise an error? Returning empty for now. - return AudioChunk(audio=np.array([], dtype=np.int16), word_timestamps=[]) + combined_chunk = AudioChunk(audio=np.array([], dtype=np.int16), word_timestamps=[]) + else: + combined_chunk = AudioChunk.combine(audio_data_chunks) - - combined_audio_data = AudioChunk.combine(audio_data_chunks) - return combined_audio_data + return combined_chunk except Exception as e: - logger.error(f"Error in audio generation: {str(e)}") - raise + logger.error(f"Error in combined audio generation: {str(e)}") + raise # Re-raise after logging + finally: + # Explicitly close the writer if it was passed, though it shouldn't hold resources in raw mode + try: + writer.close() + except Exception: + pass # Ignore errors during cleanup + + async def combine_voices(self, voices: List[str]) -> torch.Tensor: """Combine multiple voices. @@ -495,38 +546,45 @@ class TTSService: try: # Get backend and voice path backend = self.model_manager.get_backend() + # Use _get_voices_path to handle potential combined voice names passed here too voice_name, voice_path = await self._get_voices_path(voice) if isinstance(backend, KokoroV1): # For Kokoro V1, use generate_from_tokens with raw phonemes - result = None - # Use provided lang_code or determine from voice name - pipeline_lang_code = lang_code if lang_code else voice[:1].lower() + result_audio = None + # Determine language code + first_base_voice_match = re.match(r"([a-zA-Z0-9_]+)", voice_name) + first_base_voice = first_base_voice_match.group(1) if first_base_voice_match else "a" + pipeline_lang_code = lang_code if lang_code else (settings.default_voice_code if settings.default_voice_code else first_base_voice[:1].lower()) + logger.info( - f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme pipeline" + f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme generation" ) - try: - # Use backend's pipeline management - for r in backend._get_pipeline( - pipeline_lang_code - ).generate_from_tokens( - tokens=phonemes, # Pass raw phonemes string - voice=voice_path, - speed=speed, - ): - if r.audio is not None: - result = r - break - except Exception as e: - logger.error(f"Failed to generate from phonemes: {e}") - raise RuntimeError(f"Phoneme generation failed: {e}") + # Use backend's pipeline management and iterate through potential chunks + full_audio_list = [] + async for r in backend.generate_from_tokens( # generate_from_tokens is now async + tokens=phonemes, # Pass raw phonemes string + voice=(voice_name, voice_path), # Pass tuple + speed=speed, + lang_code=pipeline_lang_code, + ): + if r is not None and len(r) > 0: + # r is directly the numpy array chunk + full_audio_list.append(r) - if result is None or result.audio is None: - raise ValueError("No audio generated") + + if not full_audio_list: + raise ValueError("No audio generated from phonemes") + + # Combine chunks if necessary + result_audio = np.concatenate(full_audio_list) if len(full_audio_list) > 1 else full_audio_list[0] processing_time = time.time() - start_time - return result.audio.numpy(), processing_time + # Normalize the final audio before returning + normalizer = AudioNormalizer() + normalized_audio = normalizer.normalize(result_audio) + return normalized_audio, processing_time else: raise ValueError( "Phoneme generation only supported with Kokoro V1 backend"