mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
Refactor TTS service to improve voice combination logic and error handling. Updated voice parsing to support combined voices with weights, enhanced normalization handling, and streamlined audio generation process. Improved logging for better debugging and removed unnecessary comments for clarity.
This commit is contained in:
parent
7a838ab3e8
commit
66201494d0
1 changed files with 83 additions and 136 deletions
|
@ -103,7 +103,7 @@ class TTSService:
|
|||
if isinstance(backend, KokoroV1):
|
||||
internal_chunk_index = 0
|
||||
async for chunk_data in self.model_manager.generate(
|
||||
text_for_model.strip(), # Pass cleaned text to model
|
||||
chunk_text,
|
||||
(voice_name, voice_path),
|
||||
speed=speed,
|
||||
lang_code=lang_code,
|
||||
|
@ -189,7 +189,7 @@ class TTSService:
|
|||
"""Get voice path, handling combined voices.
|
||||
|
||||
Args:
|
||||
voice: Voice name or combined voice names (e.g., 'af_jadzia(0.7)+af_jessica(0.3)')
|
||||
voice: Voice name or combined voice names (e.g., 'af_jadzia+af_jessica')
|
||||
|
||||
Returns:
|
||||
Tuple of (voice name to use, voice path to use)
|
||||
|
@ -198,84 +198,70 @@ class TTSService:
|
|||
RuntimeError: If voice not found
|
||||
"""
|
||||
try:
|
||||
# Regex to handle names, weights, and operators: af_name(weight)[+-]af_other(weight)...
|
||||
pattern = re.compile(r"([a-zA-Z0-9_]+)(?:\((\d+(?:\.\d+)?)\))?([+-]?)")
|
||||
matches = pattern.findall(voice.replace(" ", "")) # Remove spaces
|
||||
# Split the voice on + and - and ensure that they get added to the list eg: hi+bob = ["hi","+","bob"]
|
||||
split_voice = re.split(r"([-+])", voice)
|
||||
|
||||
if not matches:
|
||||
raise ValueError(f"Could not parse voice string: {voice}")
|
||||
# If it is only once voice there is no point in loading it up, doing nothing with it, then saving it
|
||||
if len(split_voice) == 1:
|
||||
# Since its a single voice the only time that the weight would matter is if voice_weight_normalization is off
|
||||
if (
|
||||
"(" not in voice and ")" not in voice
|
||||
) or settings.voice_weight_normalization == True:
|
||||
path = await self._voice_manager.get_voice_path(voice)
|
||||
if not path:
|
||||
raise RuntimeError(f"Voice not found: {voice}")
|
||||
logger.debug(f"Using single voice path: {path}")
|
||||
return voice, path
|
||||
|
||||
# If only one voice and no explicit weight or operators, handle directly
|
||||
if len(matches) == 1 and not matches[0][1] and not matches[0][2]:
|
||||
voice_name = matches[0][0]
|
||||
path = await self._voice_manager.get_voice_path(voice_name)
|
||||
if not path:
|
||||
raise RuntimeError(f"Voice not found: {voice_name}")
|
||||
logger.debug(f"Using single voice path: {path}")
|
||||
return voice_name, path
|
||||
|
||||
# Process combinations
|
||||
voice_parts = []
|
||||
total_weight = 0
|
||||
for name, weight_str, operator in matches:
|
||||
weight = float(weight_str) if weight_str else 1.0
|
||||
voice_parts.append({"name": name, "weight": weight, "op": operator})
|
||||
# Use weight directly for total, normalization happens later if enabled
|
||||
total_weight += weight # Summing base weights before potential normalization
|
||||
|
||||
# Check base voices exist
|
||||
available_voices = await self._voice_manager.list_voices()
|
||||
for part in voice_parts:
|
||||
if part["name"] not in available_voices:
|
||||
raise ValueError(f"Base voice '{part['name']}' not found in combined string '{voice}'. Available: {available_voices}")
|
||||
for voice_index in range(0, len(split_voice), 2):
|
||||
voice_object = split_voice[voice_index]
|
||||
|
||||
if "(" in voice_object and ")" in voice_object:
|
||||
voice_name = voice_object.split("(")[0].strip()
|
||||
voice_weight = float(voice_object.split("(")[1].split(")")[0])
|
||||
else:
|
||||
voice_name = voice_object
|
||||
voice_weight = 1
|
||||
|
||||
# Determine normalization factor
|
||||
norm_factor = total_weight if settings.voice_weight_normalization and total_weight > 0 else 1.0
|
||||
if settings.voice_weight_normalization:
|
||||
logger.debug(f"Normalizing combined voice weights by factor: {norm_factor:.2f}")
|
||||
else:
|
||||
logger.debug("Voice weight normalization disabled, using raw weights.")
|
||||
total_weight += voice_weight
|
||||
split_voice[voice_index] = (voice_name, voice_weight)
|
||||
|
||||
# If voice_weight_normalization is false prevent normalizing the weights by setting the total_weight to 1 so it divides each weight by 1
|
||||
if settings.voice_weight_normalization == False:
|
||||
total_weight = 1
|
||||
|
||||
# Load and combine tensors
|
||||
first_part = voice_parts[0]
|
||||
base_path = await self._voice_manager.get_voice_path(first_part["name"])
|
||||
combined_tensor = await self._load_voice_from_path(base_path, first_part["weight"] / norm_factor)
|
||||
# Load the first voice as the starting point for voices to be combined onto
|
||||
path = await self._voice_manager.get_voice_path(split_voice[0][0])
|
||||
combined_tensor = await self._load_voice_from_path(
|
||||
path, split_voice[0][1] / total_weight
|
||||
)
|
||||
|
||||
current_op = "+" # Implicitly start with addition for the first voice
|
||||
# Loop through each + or - in split_voice so they can be applied to combined voice
|
||||
for operation_index in range(1, len(split_voice) - 1, 2):
|
||||
# Get the voice path of the voice 1 index ahead of the operator
|
||||
path = await self._voice_manager.get_voice_path(
|
||||
split_voice[operation_index + 1][0]
|
||||
)
|
||||
voice_tensor = await self._load_voice_from_path(
|
||||
path, split_voice[operation_index + 1][1] / total_weight
|
||||
)
|
||||
|
||||
for i in range(len(voice_parts) - 1):
|
||||
current_part = voice_parts[i]
|
||||
next_part = voice_parts[i+1]
|
||||
|
||||
# Determine the operation based on the *current* part's operator
|
||||
op_symbol = current_part["op"] if current_part["op"] else "+" # Default to '+' if no operator
|
||||
|
||||
path = await self._voice_manager.get_voice_path(next_part["name"])
|
||||
voice_tensor = await self._load_voice_from_path(path, next_part["weight"] / norm_factor)
|
||||
|
||||
if op_symbol == "+":
|
||||
# Either add or subtract the voice from the current combined voice
|
||||
if split_voice[operation_index] == "+":
|
||||
combined_tensor += voice_tensor
|
||||
logger.debug(f"Adding voice {next_part['name']} (weight {next_part['weight']/norm_factor:.2f})")
|
||||
elif op_symbol == "-":
|
||||
else:
|
||||
combined_tensor -= voice_tensor
|
||||
logger.debug(f"Subtracting voice {next_part['name']} (weight {next_part['weight']/norm_factor:.2f})")
|
||||
|
||||
|
||||
# Save the new combined voice so it can be loaded later
|
||||
# Use a safe filename based on the original input string
|
||||
safe_filename = re.sub(r'[^\w+-]', '_', voice) + ".pt"
|
||||
# Save the new combined voice so it can be loaded latter
|
||||
temp_dir = tempfile.gettempdir()
|
||||
combined_path = os.path.join(temp_dir, safe_filename)
|
||||
logger.debug(f"Saving combined voice '{voice}' to temporary path: {combined_path}")
|
||||
# Save the tensor to the device specified by settings for model loading consistency
|
||||
target_device = settings.get_device()
|
||||
torch.save(combined_tensor.to(target_device), combined_path)
|
||||
return voice, combined_path # Return original name and temp path
|
||||
|
||||
combined_path = os.path.join(temp_dir, f"{voice}.pt")
|
||||
logger.debug(f"Saving combined voice to: {combined_path}")
|
||||
torch.save(combined_tensor, combined_path)
|
||||
return voice, combined_path
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get or combine voice path for '{voice}': {e}")
|
||||
logger.error(f"Failed to get voice path: {e}")
|
||||
raise
|
||||
|
||||
|
||||
|
@ -462,7 +448,7 @@ class TTSService:
|
|||
self,
|
||||
text: str,
|
||||
voice: str,
|
||||
writer: StreamingAudioWriter, # Writer needed even for non-streaming internally
|
||||
writer: StreamingAudioWriter,
|
||||
speed: float = 1.0,
|
||||
return_timestamps: bool = False,
|
||||
normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
|
||||
|
@ -470,54 +456,26 @@ class TTSService:
|
|||
) -> AudioChunk:
|
||||
"""Generate complete audio for text using streaming internally."""
|
||||
audio_data_chunks = []
|
||||
output_format = None # Signal raw audio mode for internal streaming
|
||||
combined_chunk = None
|
||||
try:
|
||||
# Pass a dummy writer if none provided, as generate_audio_stream requires one
|
||||
# Although in raw mode (output_format=None), it shouldn't be heavily used for formatting
|
||||
internal_writer = writer if writer else StreamingAudioWriter(format='wav', sample_rate=settings.sample_rate)
|
||||
|
||||
try:
|
||||
async for audio_stream_data in self.generate_audio_stream(
|
||||
text,
|
||||
voice,
|
||||
internal_writer, # Pass the writer instance
|
||||
writer,
|
||||
speed=speed,
|
||||
normalization_options=normalization_options,
|
||||
return_timestamps=return_timestamps, # Pass this down
|
||||
return_timestamps=return_timestamps,
|
||||
lang_code=lang_code,
|
||||
output_format=output_format, # Explicitly None for raw audio
|
||||
output_format=None,
|
||||
):
|
||||
# Ensure we only append chunks with actual audio data
|
||||
# Raw silence chunks generated for pauses will have audio data (zeros)
|
||||
if audio_stream_data.audio is not None and len(audio_stream_data.audio) > 0:
|
||||
# Ensure timestamps are preserved if requested
|
||||
if return_timestamps and not audio_stream_data.word_timestamps:
|
||||
audio_stream_data.word_timestamps = [] # Initialize if needed
|
||||
if len(audio_stream_data.audio) > 0:
|
||||
audio_data_chunks.append(audio_stream_data)
|
||||
|
||||
if not audio_data_chunks:
|
||||
logger.warning("No valid audio chunks generated.")
|
||||
combined_chunk = AudioChunk(audio=np.array([], dtype=np.int16), word_timestamps=[])
|
||||
|
||||
else:
|
||||
combined_chunk = AudioChunk.combine(audio_data_chunks)
|
||||
# Ensure the combined audio is int16 before returning, as downstream expects this raw format.
|
||||
if combined_chunk.audio.dtype != np.int16:
|
||||
logger.warning(f"Combined audio dtype is {combined_chunk.audio.dtype}, converting to int16.")
|
||||
# Assuming normalization happened, scale from float [-1, 1] to int16
|
||||
if np.issubdtype(combined_chunk.audio.dtype, np.floating):
|
||||
combined_chunk.audio = np.clip(combined_chunk.audio * 32767, -32768, 32767).astype(np.int16)
|
||||
else:
|
||||
# If it's another type, attempt direct conversion (might be lossy)
|
||||
combined_chunk.audio = combined_chunk.audio.astype(np.int16)
|
||||
|
||||
|
||||
return combined_chunk
|
||||
combined_audio_data = AudioChunk.combine(audio_data_chunks)
|
||||
return combined_audio_data
|
||||
except Exception as e:
|
||||
logger.error(f"Error in combined audio generation: {str(e)}")
|
||||
raise # Re-raise after logging
|
||||
# Removed finally block that closed the writer prematurely
|
||||
# The caller is now responsible for closing the writer after final conversion.
|
||||
logger.error(f"Error in audio generation: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
async def combine_voices(self, voices: List[str]) -> torch.Tensor:
|
||||
|
@ -555,49 +513,38 @@ class TTSService:
|
|||
try:
|
||||
# Get backend and voice path
|
||||
backend = self.model_manager.get_backend()
|
||||
# Use _get_voices_path to handle potential combined voice names passed here too
|
||||
voice_name, voice_path = await self._get_voices_path(voice)
|
||||
|
||||
if isinstance(backend, KokoroV1):
|
||||
# For Kokoro V1, use generate_from_tokens with raw phonemes
|
||||
result_audio = None
|
||||
# Determine language code
|
||||
first_base_voice_match = re.match(r"([a-zA-Z0-9_]+)", voice_name)
|
||||
first_base_voice = first_base_voice_match.group(1) if first_base_voice_match else "a"
|
||||
pipeline_lang_code = lang_code if lang_code else (settings.default_voice_code if settings.default_voice_code else first_base_voice[:1].lower())
|
||||
|
||||
result = None
|
||||
# Use provided lang_code or determine from voice name
|
||||
pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
|
||||
logger.info(
|
||||
f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme generation"
|
||||
f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme pipeline"
|
||||
)
|
||||
|
||||
# Use backend's pipeline management and iterate through potential chunks
|
||||
full_audio_list = []
|
||||
async for r in backend.generate_from_tokens( # generate_from_tokens is now async
|
||||
tokens=phonemes, # Pass raw phonemes string
|
||||
voice=(voice_name, voice_path), # Pass tuple
|
||||
speed=speed,
|
||||
lang_code=pipeline_lang_code,
|
||||
):
|
||||
if r is not None and len(r) > 0:
|
||||
# r is directly the numpy array chunk
|
||||
full_audio_list.append(r)
|
||||
try:
|
||||
# Use backend's pipeline management
|
||||
for r in backend._get_pipeline(
|
||||
pipeline_lang_code
|
||||
).generate_from_tokens(
|
||||
tokens=phonemes, # Pass raw phonemes string
|
||||
voice=voice_path,
|
||||
speed=speed,
|
||||
):
|
||||
if r.audio is not None:
|
||||
result = r
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate from phonemes: {e}")
|
||||
raise RuntimeError(f"Phoneme generation failed: {e}")
|
||||
|
||||
|
||||
if not full_audio_list:
|
||||
raise ValueError("No audio generated from phonemes")
|
||||
|
||||
# Combine chunks if necessary
|
||||
result_audio = np.concatenate(full_audio_list) if len(full_audio_list) > 1 else full_audio_list[0]
|
||||
if result is None or result.audio is None:
|
||||
raise ValueError("No audio generated")
|
||||
|
||||
processing_time = time.time() - start_time
|
||||
# Normalize the final audio before returning
|
||||
normalizer = AudioNormalizer()
|
||||
normalized_audio = normalizer.normalize(result_audio)
|
||||
# Return as int16 for consistency
|
||||
if normalized_audio.dtype != np.int16:
|
||||
normalized_audio = np.clip(normalized_audio * 32767, -32768, 32767).astype(np.int16)
|
||||
|
||||
return normalized_audio, processing_time
|
||||
return result.audio.numpy(), processing_time
|
||||
else:
|
||||
raise ValueError(
|
||||
"Phoneme generation only supported with Kokoro V1 backend"
|
||||
|
|
Loading…
Add table
Reference in a new issue