Refactor TTS service to improve voice combination logic and error handling. Update voice parsing to support combined voices with weights, improve normalization handling, and streamline the audio generation process. Improve logging for easier debugging and remove unnecessary comments for clarity.
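For context, a combined voice is written as 'name(weight)' terms joined with '+' or '-', e.g. 'af_jadzia(0.7)+af_jessica(0.3)'. Below is a minimal sketch of the weight arithmetic; the weights and flag value are hypothetical, and only the behaviour of the voice_weight_normalization setting is taken from the diff that follows.

# Illustrative only: how per-voice weights scale before the voice tensors are mixed.
weights = {"af_jadzia": 2.0, "af_jessica": 1.0}  # hypothetical weights, as if parsed from 'af_jadzia(2)+af_jessica(1)'
voice_weight_normalization = True                # settings flag referenced in the diff

total = sum(weights.values()) if voice_weight_normalization else 1.0
scaled = {name: w / total for name, w in weights.items()}
# normalization on:  af_jadzia -> 0.667, af_jessica -> 0.333 (scaled weights sum to 1)
# normalization off: raw weights 2.0 and 1.0 are applied unchanged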

Lukin 2025-04-08 11:38:07 +08:00
parent 7a838ab3e8
commit 66201494d0


@@ -103,7 +103,7 @@ class TTSService:
if isinstance(backend, KokoroV1):
internal_chunk_index = 0
async for chunk_data in self.model_manager.generate(
text_for_model.strip(), # Pass cleaned text to model
chunk_text,
(voice_name, voice_path),
speed=speed,
lang_code=lang_code,
@@ -189,7 +189,7 @@ class TTSService:
"""Get voice path, handling combined voices.
Args:
voice: Voice name or combined voice names (e.g., 'af_jadzia(0.7)+af_jessica(0.3)')
voice: Voice name or combined voice names (e.g., 'af_jadzia+af_jessica')
Returns:
Tuple of (voice name to use, voice path to use)
@@ -198,84 +198,70 @@ class TTSService:
RuntimeError: If voice not found
"""
try:
# Regex to handle names, weights, and operators: af_name(weight)[+-]af_other(weight)...
pattern = re.compile(r"([a-zA-Z0-9_]+)(?:\((\d+(?:\.\d+)?)\))?([+-]?)")
matches = pattern.findall(voice.replace(" ", "")) # Remove spaces
# Split the voice on + and - and ensure that they get added to the list eg: hi+bob = ["hi","+","bob"]
split_voice = re.split(r"([-+])", voice)
if not matches:
raise ValueError(f"Could not parse voice string: {voice}")
# If it is only one voice there is no point in loading it up, doing nothing with it, then saving it
if len(split_voice) == 1:
# Since it's a single voice, the only time the weight would matter is if voice_weight_normalization is off
if (
"(" not in voice and ")" not in voice
) or settings.voice_weight_normalization == True:
path = await self._voice_manager.get_voice_path(voice)
if not path:
raise RuntimeError(f"Voice not found: {voice}")
logger.debug(f"Using single voice path: {path}")
return voice, path
# If only one voice and no explicit weight or operators, handle directly
if len(matches) == 1 and not matches[0][1] and not matches[0][2]:
voice_name = matches[0][0]
path = await self._voice_manager.get_voice_path(voice_name)
if not path:
raise RuntimeError(f"Voice not found: {voice_name}")
logger.debug(f"Using single voice path: {path}")
return voice_name, path
# Process combinations
voice_parts = []
total_weight = 0
for name, weight_str, operator in matches:
weight = float(weight_str) if weight_str else 1.0
voice_parts.append({"name": name, "weight": weight, "op": operator})
# Use weight directly for total, normalization happens later if enabled
total_weight += weight # Summing base weights before potential normalization
# Check base voices exist
available_voices = await self._voice_manager.list_voices()
for part in voice_parts:
if part["name"] not in available_voices:
raise ValueError(f"Base voice '{part['name']}' not found in combined string '{voice}'. Available: {available_voices}")
for voice_index in range(0, len(split_voice), 2):
voice_object = split_voice[voice_index]
if "(" in voice_object and ")" in voice_object:
voice_name = voice_object.split("(")[0].strip()
voice_weight = float(voice_object.split("(")[1].split(")")[0])
else:
voice_name = voice_object
voice_weight = 1
# Determine normalization factor
norm_factor = total_weight if settings.voice_weight_normalization and total_weight > 0 else 1.0
if settings.voice_weight_normalization:
logger.debug(f"Normalizing combined voice weights by factor: {norm_factor:.2f}")
else:
logger.debug("Voice weight normalization disabled, using raw weights.")
total_weight += voice_weight
split_voice[voice_index] = (voice_name, voice_weight)
# If voice_weight_normalization is false, prevent normalizing the weights by setting total_weight to 1 so each weight is divided by 1
if settings.voice_weight_normalization == False:
total_weight = 1
# Load and combine tensors
first_part = voice_parts[0]
base_path = await self._voice_manager.get_voice_path(first_part["name"])
combined_tensor = await self._load_voice_from_path(base_path, first_part["weight"] / norm_factor)
# Load the first voice as the starting point for voices to be combined onto
path = await self._voice_manager.get_voice_path(split_voice[0][0])
combined_tensor = await self._load_voice_from_path(
path, split_voice[0][1] / total_weight
)
current_op = "+" # Implicitly start with addition for the first voice
# Loop through each + or - in split_voice so they can be applied to combined voice
for operation_index in range(1, len(split_voice) - 1, 2):
# Get the voice path of the voice 1 index ahead of the operator
path = await self._voice_manager.get_voice_path(
split_voice[operation_index + 1][0]
)
voice_tensor = await self._load_voice_from_path(
path, split_voice[operation_index + 1][1] / total_weight
)
for i in range(len(voice_parts) - 1):
current_part = voice_parts[i]
next_part = voice_parts[i+1]
# Determine the operation based on the *current* part's operator
op_symbol = current_part["op"] if current_part["op"] else "+" # Default to '+' if no operator
path = await self._voice_manager.get_voice_path(next_part["name"])
voice_tensor = await self._load_voice_from_path(path, next_part["weight"] / norm_factor)
if op_symbol == "+":
# Either add or subtract the voice from the current combined voice
if split_voice[operation_index] == "+":
combined_tensor += voice_tensor
logger.debug(f"Adding voice {next_part['name']} (weight {next_part['weight']/norm_factor:.2f})")
elif op_symbol == "-":
else:
combined_tensor -= voice_tensor
logger.debug(f"Subtracting voice {next_part['name']} (weight {next_part['weight']/norm_factor:.2f})")
# Save the new combined voice so it can be loaded later
# Use a safe filename based on the original input string
safe_filename = re.sub(r'[^\w+-]', '_', voice) + ".pt"
# Save the new combined voice so it can be loaded later
temp_dir = tempfile.gettempdir()
combined_path = os.path.join(temp_dir, safe_filename)
logger.debug(f"Saving combined voice '{voice}' to temporary path: {combined_path}")
# Save the tensor to the device specified by settings for model loading consistency
target_device = settings.get_device()
torch.save(combined_tensor.to(target_device), combined_path)
return voice, combined_path # Return original name and temp path
combined_path = os.path.join(temp_dir, f"{voice}.pt")
logger.debug(f"Saving combined voice to: {combined_path}")
torch.save(combined_tensor, combined_path)
return voice, combined_path
except Exception as e:
logger.error(f"Failed to get or combine voice path for '{voice}': {e}")
logger.error(f"Failed to get voice path: {e}")
raise
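As a reference for the parsing in the hunk above, here is a minimal standalone sketch of the split-based approach, with tensor loading and the settings object left out; the weights in the example call are hypothetical.

import re

def parse_combined_voice(voice: str, normalize: bool = True):
    # Split on + and -, keeping the operators: "a(2)+b" -> ["a(2)", "+", "b"]
    parts = re.split(r"([-+])", voice)
    entries = []
    for i in range(0, len(parts), 2):
        token = parts[i]
        if "(" in token and ")" in token:
            name = token.split("(")[0].strip()
            weight = float(token.split("(")[1].split(")")[0])
        else:
            name, weight = token, 1.0
        op = "+" if i == 0 else parts[i - 1]  # the first voice is implicitly added
        entries.append((op, name, weight))
    # With normalization disabled, dividing by 1.0 leaves the raw weights untouched.
    total = sum(w for _, _, w in entries) if normalize else 1.0
    return [(op, name, w / total) for op, name, w in entries]

print(parse_combined_voice("af_jadzia(2)-af_jessica(1)"))
# [('+', 'af_jadzia', 0.667), ('-', 'af_jessica', 0.333)]  (approximately)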
@@ -462,7 +448,7 @@ class TTSService:
self,
text: str,
voice: str,
writer: StreamingAudioWriter, # Writer needed even for non-streaming internally
writer: StreamingAudioWriter,
speed: float = 1.0,
return_timestamps: bool = False,
normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
@@ -470,54 +456,26 @@ class TTSService:
) -> AudioChunk:
"""Generate complete audio for text using streaming internally."""
audio_data_chunks = []
output_format = None # Signal raw audio mode for internal streaming
combined_chunk = None
try:
# Pass a dummy writer if none provided, as generate_audio_stream requires one
# Although in raw mode (output_format=None), it shouldn't be heavily used for formatting
internal_writer = writer if writer else StreamingAudioWriter(format='wav', sample_rate=settings.sample_rate)
try:
async for audio_stream_data in self.generate_audio_stream(
text,
voice,
internal_writer, # Pass the writer instance
writer,
speed=speed,
normalization_options=normalization_options,
return_timestamps=return_timestamps, # Pass this down
return_timestamps=return_timestamps,
lang_code=lang_code,
output_format=output_format, # Explicitly None for raw audio
output_format=None,
):
# Ensure we only append chunks with actual audio data
# Raw silence chunks generated for pauses will have audio data (zeros)
if audio_stream_data.audio is not None and len(audio_stream_data.audio) > 0:
# Ensure timestamps are preserved if requested
if return_timestamps and not audio_stream_data.word_timestamps:
audio_stream_data.word_timestamps = [] # Initialize if needed
if len(audio_stream_data.audio) > 0:
audio_data_chunks.append(audio_stream_data)
if not audio_data_chunks:
logger.warning("No valid audio chunks generated.")
combined_chunk = AudioChunk(audio=np.array([], dtype=np.int16), word_timestamps=[])
else:
combined_chunk = AudioChunk.combine(audio_data_chunks)
# Ensure the combined audio is int16 before returning, as downstream expects this raw format.
if combined_chunk.audio.dtype != np.int16:
logger.warning(f"Combined audio dtype is {combined_chunk.audio.dtype}, converting to int16.")
# Assuming normalization happened, scale from float [-1, 1] to int16
if np.issubdtype(combined_chunk.audio.dtype, np.floating):
combined_chunk.audio = np.clip(combined_chunk.audio * 32767, -32768, 32767).astype(np.int16)
else:
# If it's another type, attempt direct conversion (might be lossy)
combined_chunk.audio = combined_chunk.audio.astype(np.int16)
return combined_chunk
combined_audio_data = AudioChunk.combine(audio_data_chunks)
return combined_audio_data
except Exception as e:
logger.error(f"Error in combined audio generation: {str(e)}")
raise # Re-raise after logging
# Removed finally block that closed the writer prematurely
# The caller is now responsible for closing the writer after final conversion.
logger.error(f"Error in audio generation: {str(e)}")
raise
async def combine_voices(self, voices: List[str]) -> torch.Tensor:
@@ -555,49 +513,38 @@ class TTSService:
try:
# Get backend and voice path
backend = self.model_manager.get_backend()
# Use _get_voices_path to handle potential combined voice names passed here too
voice_name, voice_path = await self._get_voices_path(voice)
if isinstance(backend, KokoroV1):
# For Kokoro V1, use generate_from_tokens with raw phonemes
result_audio = None
# Determine language code
first_base_voice_match = re.match(r"([a-zA-Z0-9_]+)", voice_name)
first_base_voice = first_base_voice_match.group(1) if first_base_voice_match else "a"
pipeline_lang_code = lang_code if lang_code else (settings.default_voice_code if settings.default_voice_code else first_base_voice[:1].lower())
result = None
# Use provided lang_code or determine from voice name
pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
logger.info(
f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme generation"
f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme pipeline"
)
# Use backend's pipeline management and iterate through potential chunks
full_audio_list = []
async for r in backend.generate_from_tokens( # generate_from_tokens is now async
tokens=phonemes, # Pass raw phonemes string
voice=(voice_name, voice_path), # Pass tuple
speed=speed,
lang_code=pipeline_lang_code,
):
if r is not None and len(r) > 0:
# r is directly the numpy array chunk
full_audio_list.append(r)
try:
# Use backend's pipeline management
for r in backend._get_pipeline(
pipeline_lang_code
).generate_from_tokens(
tokens=phonemes, # Pass raw phonemes string
voice=voice_path,
speed=speed,
):
if r.audio is not None:
result = r
break
except Exception as e:
logger.error(f"Failed to generate from phonemes: {e}")
raise RuntimeError(f"Phoneme generation failed: {e}")
if not full_audio_list:
raise ValueError("No audio generated from phonemes")
# Combine chunks if necessary
result_audio = np.concatenate(full_audio_list) if len(full_audio_list) > 1 else full_audio_list[0]
if result is None or result.audio is None:
raise ValueError("No audio generated")
processing_time = time.time() - start_time
# Normalize the final audio before returning
normalizer = AudioNormalizer()
normalized_audio = normalizer.normalize(result_audio)
# Return as int16 for consistency
if normalized_audio.dtype != np.int16:
normalized_audio = np.clip(normalized_audio * 32767, -32768, 32767).astype(np.int16)
return normalized_audio, processing_time
return result.audio.numpy(), processing_time
else:
raise ValueError(
"Phoneme generation only supported with Kokoro V1 backend"