Refactor TTS service to improve voice combination logic and error handling. Update voice parsing to support combined voices with weights, improve normalization handling, and streamline the audio generation process. Improve logging for easier debugging and remove unnecessary comments for clarity.

Lukin 2025-04-08 11:38:07 +08:00
parent 7a838ab3e8
commit 66201494d0
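For context, the new parsing logic splits a combined-voice string on '+' and '-' while keeping the operators, as the diff's own comment describes (e.g. "hi+bob" becomes ["hi", "+", "bob"]). A quick illustration of that split step; the voice names and weights below are example values, not strings taken from the commit:

import re

# Keeping the capturing group in the pattern makes re.split return the
# operators alongside the voice tokens.
print(re.split(r"([-+])", "af_jadzia+af_jessica"))
# -> ['af_jadzia', '+', 'af_jessica']
print(re.split(r"([-+])", "af_jadzia(0.75)-af_jessica(0.25)"))
# -> ['af_jadzia(0.75)', '-', 'af_jessica(0.25)']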


@@ -103,7 +103,7 @@ class TTSService:
             if isinstance(backend, KokoroV1):
                 internal_chunk_index = 0
                 async for chunk_data in self.model_manager.generate(
-                    text_for_model.strip(),  # Pass cleaned text to model
+                    chunk_text,
                     (voice_name, voice_path),
                     speed=speed,
                     lang_code=lang_code,
@@ -189,7 +189,7 @@ class TTSService:
         """Get voice path, handling combined voices.
 
         Args:
-            voice: Voice name or combined voice names (e.g., 'af_jadzia(0.7)+af_jessica(0.3)')
+            voice: Voice name or combined voice names (e.g., 'af_jadzia+af_jessica')
 
         Returns:
             Tuple of (voice name to use, voice path to use)
@@ -198,84 +198,70 @@ class TTSService:
             RuntimeError: If voice not found
         """
         try:
-            # Regex to handle names, weights, and operators: af_name(weight)[+-]af_other(weight)...
-            pattern = re.compile(r"([a-zA-Z0-9_]+)(?:\((\d+(?:\.\d+)?)\))?([+-]?)")
-            matches = pattern.findall(voice.replace(" ", ""))  # Remove spaces
-
-            if not matches:
-                raise ValueError(f"Could not parse voice string: {voice}")
-
-            # If only one voice and no explicit weight or operators, handle directly
-            if len(matches) == 1 and not matches[0][1] and not matches[0][2]:
-                voice_name = matches[0][0]
-                path = await self._voice_manager.get_voice_path(voice_name)
-                if not path:
-                    raise RuntimeError(f"Voice not found: {voice_name}")
-                logger.debug(f"Using single voice path: {path}")
-                return voice_name, path
-
-            # Process combinations
-            voice_parts = []
+            # Split the voice on + and - and ensure that they get added to the list eg: hi+bob = ["hi","+","bob"]
+            split_voice = re.split(r"([-+])", voice)
+
+            # If it is only once voice there is no point in loading it up, doing nothing with it, then saving it
+            if len(split_voice) == 1:
+                # Since its a single voice the only time that the weight would matter is if voice_weight_normalization is off
+                if (
+                    "(" not in voice and ")" not in voice
+                ) or settings.voice_weight_normalization == True:
+                    path = await self._voice_manager.get_voice_path(voice)
+                    if not path:
+                        raise RuntimeError(f"Voice not found: {voice}")
+                    logger.debug(f"Using single voice path: {path}")
+                    return voice, path
+
             total_weight = 0
-            for name, weight_str, operator in matches:
-                weight = float(weight_str) if weight_str else 1.0
-                voice_parts.append({"name": name, "weight": weight, "op": operator})
-                # Use weight directly for total, normalization happens later if enabled
-                total_weight += weight  # Summing base weights before potential normalization
-
-            # Check base voices exist
-            available_voices = await self._voice_manager.list_voices()
-            for part in voice_parts:
-                if part["name"] not in available_voices:
-                    raise ValueError(f"Base voice '{part['name']}' not found in combined string '{voice}'. Available: {available_voices}")
-
-            # Determine normalization factor
-            norm_factor = total_weight if settings.voice_weight_normalization and total_weight > 0 else 1.0
-            if settings.voice_weight_normalization:
-                logger.debug(f"Normalizing combined voice weights by factor: {norm_factor:.2f}")
-            else:
-                logger.debug("Voice weight normalization disabled, using raw weights.")
-
-            # Load and combine tensors
-            first_part = voice_parts[0]
-            base_path = await self._voice_manager.get_voice_path(first_part["name"])
-            combined_tensor = await self._load_voice_from_path(base_path, first_part["weight"] / norm_factor)
-
-            current_op = "+"  # Implicitly start with addition for the first voice
-
-            for i in range(len(voice_parts) - 1):
-                current_part = voice_parts[i]
-                next_part = voice_parts[i+1]
-
-                # Determine the operation based on the *current* part's operator
-                op_symbol = current_part["op"] if current_part["op"] else "+"  # Default to '+' if no operator
-
-                path = await self._voice_manager.get_voice_path(next_part["name"])
-                voice_tensor = await self._load_voice_from_path(path, next_part["weight"] / norm_factor)
-
-                if op_symbol == "+":
+
+            for voice_index in range(0, len(split_voice), 2):
+                voice_object = split_voice[voice_index]
+
+                if "(" in voice_object and ")" in voice_object:
+                    voice_name = voice_object.split("(")[0].strip()
+                    voice_weight = float(voice_object.split("(")[1].split(")")[0])
+                else:
+                    voice_name = voice_object
+                    voice_weight = 1
+
+                total_weight += voice_weight
+                split_voice[voice_index] = (voice_name, voice_weight)
+
+            # If voice_weight_normalization is false prevent normalizing the weights by setting the total_weight to 1 so it divides each weight by 1
+            if settings.voice_weight_normalization == False:
+                total_weight = 1
+
+            # Load the first voice as the starting point for voices to be combined onto
+            path = await self._voice_manager.get_voice_path(split_voice[0][0])
+            combined_tensor = await self._load_voice_from_path(
+                path, split_voice[0][1] / total_weight
+            )
+
+            # Loop through each + or - in split_voice so they can be applied to combined voice
+            for operation_index in range(1, len(split_voice) - 1, 2):
+                # Get the voice path of the voice 1 index ahead of the operator
+                path = await self._voice_manager.get_voice_path(
+                    split_voice[operation_index + 1][0]
+                )
+                voice_tensor = await self._load_voice_from_path(
+                    path, split_voice[operation_index + 1][1] / total_weight
+                )
+
+                # Either add or subtract the voice from the current combined voice
+                if split_voice[operation_index] == "+":
                     combined_tensor += voice_tensor
-                    logger.debug(f"Adding voice {next_part['name']} (weight {next_part['weight']/norm_factor:.2f})")
-                elif op_symbol == "-":
+                else:
                     combined_tensor -= voice_tensor
-                    logger.debug(f"Subtracting voice {next_part['name']} (weight {next_part['weight']/norm_factor:.2f})")
 
-            # Save the new combined voice so it can be loaded later
-            # Use a safe filename based on the original input string
-            safe_filename = re.sub(r'[^\w+-]', '_', voice) + ".pt"
+            # Save the new combined voice so it can be loaded latter
             temp_dir = tempfile.gettempdir()
-            combined_path = os.path.join(temp_dir, safe_filename)
-            logger.debug(f"Saving combined voice '{voice}' to temporary path: {combined_path}")
-            # Save the tensor to the device specified by settings for model loading consistency
-            target_device = settings.get_device()
-            torch.save(combined_tensor.to(target_device), combined_path)
-            return voice, combined_path  # Return original name and temp path
-
+            combined_path = os.path.join(temp_dir, f"{voice}.pt")
+            logger.debug(f"Saving combined voice to: {combined_path}")
+            torch.save(combined_tensor, combined_path)
+            return voice, combined_path
         except Exception as e:
-            logger.error(f"Failed to get or combine voice path for '{voice}': {e}")
+            logger.error(f"Failed to get voice path: {e}")
             raise
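A standalone sketch of how the new weight handling plays out for a weighted combination. This mirrors the added parsing for illustration only; it is not the service implementation, and the voice names and weights are example values:

import re

def parse_combined_voice(voice: str, normalize: bool = True):
    # Mirror of the added logic: split on +/- keeping the operators,
    # pull optional "(weight)" suffixes, and sum the weights.
    parts = re.split(r"([-+])", voice)
    total_weight = 0.0
    for i in range(0, len(parts), 2):  # even indices are voice tokens
        token = parts[i]
        if "(" in token and ")" in token:
            name = token.split("(")[0].strip()
            weight = float(token.split("(")[1].split(")")[0])
        else:
            name, weight = token, 1.0
        total_weight += weight
        parts[i] = (name, weight)
    if not normalize:  # corresponds to voice_weight_normalization being disabled
        total_weight = 1.0
    # Each voice tensor is scaled by weight / total_weight before being
    # added to or subtracted from the running combination.
    scaled = [(name, weight / total_weight) for name, weight in parts[::2]]
    return scaled, parts[1::2]

print(parse_combined_voice("af_jadzia(0.75)+af_jessica(0.25)"))
# -> ([('af_jadzia', 0.75), ('af_jessica', 0.25)], ['+'])
print(parse_combined_voice("af_jadzia+af_jessica", normalize=False))
# -> ([('af_jadzia', 1.0), ('af_jessica', 1.0)], ['+'])

With normalization enabled (the default), the raw weights only set proportions; with it disabled, each weight acts as an absolute gain on its voice tensor.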
@@ -462,7 +448,7 @@ class TTSService:
         self,
         text: str,
         voice: str,
-        writer: StreamingAudioWriter,  # Writer needed even for non-streaming internally
+        writer: StreamingAudioWriter,
         speed: float = 1.0,
         return_timestamps: bool = False,
         normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
@@ -470,54 +456,26 @@ class TTSService:
     ) -> AudioChunk:
         """Generate complete audio for text using streaming internally."""
         audio_data_chunks = []
-        output_format = None  # Signal raw audio mode for internal streaming
-        combined_chunk = None
+
         try:
-            # Pass a dummy writer if none provided, as generate_audio_stream requires one
-            # Although in raw mode (output_format=None), it shouldn't be heavily used for formatting
-            internal_writer = writer if writer else StreamingAudioWriter(format='wav', sample_rate=settings.sample_rate)
-
             async for audio_stream_data in self.generate_audio_stream(
                 text,
                 voice,
-                internal_writer,  # Pass the writer instance
+                writer,
                 speed=speed,
                 normalization_options=normalization_options,
-                return_timestamps=return_timestamps,  # Pass this down
+                return_timestamps=return_timestamps,
                 lang_code=lang_code,
-                output_format=output_format,  # Explicitly None for raw audio
+                output_format=None,
             ):
-                # Ensure we only append chunks with actual audio data
-                # Raw silence chunks generated for pauses will have audio data (zeros)
-                if audio_stream_data.audio is not None and len(audio_stream_data.audio) > 0:
-                    # Ensure timestamps are preserved if requested
-                    if return_timestamps and not audio_stream_data.word_timestamps:
-                        audio_stream_data.word_timestamps = []  # Initialize if needed
+                if len(audio_stream_data.audio) > 0:
                     audio_data_chunks.append(audio_stream_data)
 
-            if not audio_data_chunks:
-                logger.warning("No valid audio chunks generated.")
-                combined_chunk = AudioChunk(audio=np.array([], dtype=np.int16), word_timestamps=[])
-            else:
-                combined_chunk = AudioChunk.combine(audio_data_chunks)
-                # Ensure the combined audio is int16 before returning, as downstream expects this raw format.
-                if combined_chunk.audio.dtype != np.int16:
-                    logger.warning(f"Combined audio dtype is {combined_chunk.audio.dtype}, converting to int16.")
-                    # Assuming normalization happened, scale from float [-1, 1] to int16
-                    if np.issubdtype(combined_chunk.audio.dtype, np.floating):
-                        combined_chunk.audio = np.clip(combined_chunk.audio * 32767, -32768, 32767).astype(np.int16)
-                    else:
-                        # If it's another type, attempt direct conversion (might be lossy)
-                        combined_chunk.audio = combined_chunk.audio.astype(np.int16)
-            return combined_chunk
-
+            combined_audio_data = AudioChunk.combine(audio_data_chunks)
+            return combined_audio_data
         except Exception as e:
-            logger.error(f"Error in combined audio generation: {str(e)}")
-            raise  # Re-raise after logging
-        # Removed finally block that closed the writer prematurely
-        # The caller is now responsible for closing the writer after final conversion.
+            logger.error(f"Error in audio generation: {str(e)}")
+            raise
 
     async def combine_voices(self, voices: List[str]) -> torch.Tensor:
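The simplified generate_audio path now just accumulates non-empty chunks from the internal stream and combines them once at the end. Below is a minimal sketch of that accumulate-then-combine shape, using a hypothetical chunk generator and plain NumPy concatenation; the service itself combines AudioChunk objects via AudioChunk.combine rather than raw arrays:

import asyncio
from typing import AsyncIterator

import numpy as np

async def fake_chunk_stream() -> AsyncIterator[np.ndarray]:
    # Hypothetical stand-in for generate_audio_stream yielding raw int16 audio.
    for n in (2400, 0, 4800):  # the empty chunk is filtered out below
        yield np.zeros(n, dtype=np.int16)

async def collect_audio() -> np.ndarray:
    chunks = []
    async for audio in fake_chunk_stream():
        if len(audio) > 0:  # same guard as the new code path
            chunks.append(audio)
    # The service uses AudioChunk.combine here; raw concatenation just
    # illustrates the accumulate-then-combine pattern.
    return np.concatenate(chunks)

print(len(asyncio.run(collect_audio())))  # 7200 samples in this sketch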
@@ -555,49 +513,38 @@ class TTSService:
         try:
             # Get backend and voice path
             backend = self.model_manager.get_backend()
-            # Use _get_voices_path to handle potential combined voice names passed here too
             voice_name, voice_path = await self._get_voices_path(voice)
 
             if isinstance(backend, KokoroV1):
                 # For Kokoro V1, use generate_from_tokens with raw phonemes
-                result_audio = None
-
-                # Determine language code
-                first_base_voice_match = re.match(r"([a-zA-Z0-9_]+)", voice_name)
-                first_base_voice = first_base_voice_match.group(1) if first_base_voice_match else "a"
-                pipeline_lang_code = lang_code if lang_code else (settings.default_voice_code if settings.default_voice_code else first_base_voice[:1].lower())
-
+                result = None
+                # Use provided lang_code or determine from voice name
+                pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
                 logger.info(
-                    f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme generation"
+                    f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme pipeline"
                 )
 
-                # Use backend's pipeline management and iterate through potential chunks
-                full_audio_list = []
-                async for r in backend.generate_from_tokens(  # generate_from_tokens is now async
-                    tokens=phonemes,  # Pass raw phonemes string
-                    voice=(voice_name, voice_path),  # Pass tuple
-                    speed=speed,
-                    lang_code=pipeline_lang_code,
-                ):
-                    if r is not None and len(r) > 0:
-                        # r is directly the numpy array chunk
-                        full_audio_list.append(r)
-
-                if not full_audio_list:
-                    raise ValueError("No audio generated from phonemes")
-
-                # Combine chunks if necessary
-                result_audio = np.concatenate(full_audio_list) if len(full_audio_list) > 1 else full_audio_list[0]
+                try:
+                    # Use backend's pipeline management
+                    for r in backend._get_pipeline(
+                        pipeline_lang_code
+                    ).generate_from_tokens(
+                        tokens=phonemes,  # Pass raw phonemes string
+                        voice=voice_path,
+                        speed=speed,
+                    ):
+                        if r.audio is not None:
+                            result = r
+                            break
+                except Exception as e:
+                    logger.error(f"Failed to generate from phonemes: {e}")
+                    raise RuntimeError(f"Phoneme generation failed: {e}")
+
+                if result is None or result.audio is None:
+                    raise ValueError("No audio generated")
 
                 processing_time = time.time() - start_time
-                # Normalize the final audio before returning
-                normalizer = AudioNormalizer()
-                normalized_audio = normalizer.normalize(result_audio)
-                # Return as int16 for consistency
-                if normalized_audio.dtype != np.int16:
-                    normalized_audio = np.clip(normalized_audio * 32767, -32768, 32767).astype(np.int16)
-                return normalized_audio, processing_time
+                return result.audio.numpy(), processing_time
             else:
                 raise ValueError(
                     "Phoneme generation only supported with Kokoro V1 backend"