mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
Refactor TTS service to improve voice combination logic and error handling. Updated voice parsing to support combined voices with weights, enhanced normalization handling, and streamlined the audio generation process. Improved logging for easier debugging and removed unnecessary comments for clarity.
This commit is contained in:
parent
7a838ab3e8
commit
66201494d0
1 changed file with 83 additions and 136 deletions
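
As context for the parsing change below, here is a minimal standalone sketch of the approach the new _get_voices_path takes: split the voice string on +/- (keeping the operators), then read an optional (weight) suffix from each name. The helper name parse_combined_voice is made up for illustration and the weights are examples, not part of the codebase.

    import re

    def parse_combined_voice(voice: str):
        """Turn 'af_jadzia(0.7)+af_jessica(0.3)' into [('af_jadzia', 0.7), '+', ('af_jessica', 0.3)]."""
        parts = re.split(r"([-+])", voice)  # the capture group keeps the +/- operators in the list
        for i in range(0, len(parts), 2):   # every other element is a voice token
            token = parts[i]
            if "(" in token and ")" in token:
                name = token.split("(")[0].strip()
                weight = float(token.split("(")[1].split(")")[0])
            else:
                name, weight = token, 1.0
            parts[i] = (name, weight)
        return parts

    print(parse_combined_voice("af_jadzia(0.7)+af_jessica(0.3)"))
    # [('af_jadzia', 0.7), '+', ('af_jessica', 0.3)]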
@@ -103,7 +103,7 @@ class TTSService:
                 if isinstance(backend, KokoroV1):
                     internal_chunk_index = 0
                     async for chunk_data in self.model_manager.generate(
-                        text_for_model.strip(),  # Pass cleaned text to model
+                        chunk_text,
                         (voice_name, voice_path),
                         speed=speed,
                         lang_code=lang_code,
@@ -189,7 +189,7 @@ class TTSService:
         """Get voice path, handling combined voices.
 
         Args:
-            voice: Voice name or combined voice names (e.g., 'af_jadzia(0.7)+af_jessica(0.3)')
+            voice: Voice name or combined voice names (e.g., 'af_jadzia+af_jessica')
 
         Returns:
             Tuple of (voice name to use, voice path to use)
@@ -198,84 +198,70 @@ class TTSService:
             RuntimeError: If voice not found
         """
         try:
-            # Regex to handle names, weights, and operators: af_name(weight)[+-]af_other(weight)...
-            pattern = re.compile(r"([a-zA-Z0-9_]+)(?:\((\d+(?:\.\d+)?)\))?([+-]?)")
-            matches = pattern.findall(voice.replace(" ", ""))  # Remove spaces
-
-            if not matches:
-                raise ValueError(f"Could not parse voice string: {voice}")
-
-            # If only one voice and no explicit weight or operators, handle directly
-            if len(matches) == 1 and not matches[0][1] and not matches[0][2]:
-                voice_name = matches[0][0]
-                path = await self._voice_manager.get_voice_path(voice_name)
-                if not path:
-                    raise RuntimeError(f"Voice not found: {voice_name}")
-                logger.debug(f"Using single voice path: {path}")
-                return voice_name, path
-
-            # Process combinations
-            voice_parts = []
+            # Split the voice on + and - and ensure that they get added to the list eg: hi+bob = ["hi","+","bob"]
+            split_voice = re.split(r"([-+])", voice)
+
+            # If it is only once voice there is no point in loading it up, doing nothing with it, then saving it
+            if len(split_voice) == 1:
+                # Since its a single voice the only time that the weight would matter is if voice_weight_normalization is off
+                if (
+                    "(" not in voice and ")" not in voice
+                ) or settings.voice_weight_normalization == True:
+                    path = await self._voice_manager.get_voice_path(voice)
+                    if not path:
+                        raise RuntimeError(f"Voice not found: {voice}")
+                    logger.debug(f"Using single voice path: {path}")
+                    return voice, path
+
             total_weight = 0
-            for name, weight_str, operator in matches:
-                weight = float(weight_str) if weight_str else 1.0
-                voice_parts.append({"name": name, "weight": weight, "op": operator})
-                # Use weight directly for total, normalization happens later if enabled
-                total_weight += weight  # Summing base weights before potential normalization
-
-            # Check base voices exist
-            available_voices = await self._voice_manager.list_voices()
-            for part in voice_parts:
-                if part["name"] not in available_voices:
-                    raise ValueError(f"Base voice '{part['name']}' not found in combined string '{voice}'. Available: {available_voices}")
-
-            # Determine normalization factor
-            norm_factor = total_weight if settings.voice_weight_normalization and total_weight > 0 else 1.0
-            if settings.voice_weight_normalization:
-                logger.debug(f"Normalizing combined voice weights by factor: {norm_factor:.2f}")
-            else:
-                logger.debug("Voice weight normalization disabled, using raw weights.")
+
+            for voice_index in range(0, len(split_voice), 2):
+                voice_object = split_voice[voice_index]
+
+                if "(" in voice_object and ")" in voice_object:
+                    voice_name = voice_object.split("(")[0].strip()
+                    voice_weight = float(voice_object.split("(")[1].split(")")[0])
+                else:
+                    voice_name = voice_object
+                    voice_weight = 1
+
+                total_weight += voice_weight
+                split_voice[voice_index] = (voice_name, voice_weight)
+
+            # If voice_weight_normalization is false prevent normalizing the weights by setting the total_weight to 1 so it divides each weight by 1
+            if settings.voice_weight_normalization == False:
+                total_weight = 1
 
-            # Load and combine tensors
-            first_part = voice_parts[0]
-            base_path = await self._voice_manager.get_voice_path(first_part["name"])
-            combined_tensor = await self._load_voice_from_path(base_path, first_part["weight"] / norm_factor)
-
-            current_op = "+"  # Implicitly start with addition for the first voice
-
-            for i in range(len(voice_parts) - 1):
-                current_part = voice_parts[i]
-                next_part = voice_parts[i + 1]
-
-                # Determine the operation based on the *current* part's operator
-                op_symbol = current_part["op"] if current_part["op"] else "+"  # Default to '+' if no operator
-
-                path = await self._voice_manager.get_voice_path(next_part["name"])
-                voice_tensor = await self._load_voice_from_path(path, next_part["weight"] / norm_factor)
-
-                if op_symbol == "+":
+            # Load the first voice as the starting point for voices to be combined onto
+            path = await self._voice_manager.get_voice_path(split_voice[0][0])
+            combined_tensor = await self._load_voice_from_path(
+                path, split_voice[0][1] / total_weight
+            )
+
+            # Loop through each + or - in split_voice so they can be applied to combined voice
+            for operation_index in range(1, len(split_voice) - 1, 2):
+                # Get the voice path of the voice 1 index ahead of the operator
+                path = await self._voice_manager.get_voice_path(
+                    split_voice[operation_index + 1][0]
+                )
+                voice_tensor = await self._load_voice_from_path(
+                    path, split_voice[operation_index + 1][1] / total_weight
+                )
+
+                # Either add or subtract the voice from the current combined voice
+                if split_voice[operation_index] == "+":
                     combined_tensor += voice_tensor
-                    logger.debug(f"Adding voice {next_part['name']} (weight {next_part['weight']/norm_factor:.2f})")
-                elif op_symbol == "-":
+                else:
                     combined_tensor -= voice_tensor
-                    logger.debug(f"Subtracting voice {next_part['name']} (weight {next_part['weight']/norm_factor:.2f})")
 
-            # Save the new combined voice so it can be loaded later
-            # Use a safe filename based on the original input string
-            safe_filename = re.sub(r'[^\w+-]', '_', voice) + ".pt"
+            # Save the new combined voice so it can be loaded latter
             temp_dir = tempfile.gettempdir()
-            combined_path = os.path.join(temp_dir, safe_filename)
-            logger.debug(f"Saving combined voice '{voice}' to temporary path: {combined_path}")
-            # Save the tensor to the device specified by settings for model loading consistency
-            target_device = settings.get_device()
-            torch.save(combined_tensor.to(target_device), combined_path)
-            return voice, combined_path  # Return original name and temp path
-
+            combined_path = os.path.join(temp_dir, f"{voice}.pt")
+            logger.debug(f"Saving combined voice to: {combined_path}")
+            torch.save(combined_tensor, combined_path)
+            return voice, combined_path
         except Exception as e:
-            logger.error(f"Failed to get or combine voice path for '{voice}': {e}")
+            logger.error(f"Failed to get voice path: {e}")
             raise
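
A rough numeric sketch of the weight-normalization behaviour in the hunk above. The toy tensors stand in for real voice embeddings, combine_weighted is a made-up helper, and the real code applies the scaling inside _load_voice_from_path rather than by multiplication here: with normalization on, each weight is divided by the sum of all weights; with it off, the divisor is forced to 1 so raw weights are applied.

    import torch

    def combine_weighted(tensors, weights, normalize=True):
        # The divisor mirrors total_weight in the diff: the weight sum, or 1 when normalization is off.
        total = sum(weights) if normalize else 1.0
        combined = tensors[0] * (weights[0] / total)
        for tensor, weight in zip(tensors[1:], weights[1:]):
            combined += tensor * (weight / total)
        return combined

    a, b = torch.ones(3), torch.full((3,), 2.0)
    print(combine_weighted([a, b], [0.7, 0.3]))         # 0.7*a + 0.3*b -> tensor([1.3000, 1.3000, 1.3000])
    print(combine_weighted([a, b], [2.0, 1.0], False))  # raw weights   -> tensor([4., 4., 4.])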
@@ -462,7 +448,7 @@ class TTSService:
         self,
         text: str,
         voice: str,
-        writer: StreamingAudioWriter,  # Writer needed even for non-streaming internally
+        writer: StreamingAudioWriter,
         speed: float = 1.0,
         return_timestamps: bool = False,
         normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
@@ -470,54 +456,26 @@ class TTSService:
     ) -> AudioChunk:
         """Generate complete audio for text using streaming internally."""
         audio_data_chunks = []
-        output_format = None  # Signal raw audio mode for internal streaming
-        combined_chunk = None
-        try:
-            # Pass a dummy writer if none provided, as generate_audio_stream requires one
-            # Although in raw mode (output_format=None), it shouldn't be heavily used for formatting
-            internal_writer = writer if writer else StreamingAudioWriter(format='wav', sample_rate=settings.sample_rate)
 
+        try:
             async for audio_stream_data in self.generate_audio_stream(
                 text,
                 voice,
-                internal_writer,  # Pass the writer instance
+                writer,
                 speed=speed,
                 normalization_options=normalization_options,
-                return_timestamps=return_timestamps,  # Pass this down
+                return_timestamps=return_timestamps,
                 lang_code=lang_code,
-                output_format=output_format,  # Explicitly None for raw audio
+                output_format=None,
             ):
-                # Ensure we only append chunks with actual audio data
-                # Raw silence chunks generated for pauses will have audio data (zeros)
-                if audio_stream_data.audio is not None and len(audio_stream_data.audio) > 0:
-                    # Ensure timestamps are preserved if requested
-                    if return_timestamps and not audio_stream_data.word_timestamps:
-                        audio_stream_data.word_timestamps = []  # Initialize if needed
+                if len(audio_stream_data.audio) > 0:
                     audio_data_chunks.append(audio_stream_data)
 
-            if not audio_data_chunks:
-                logger.warning("No valid audio chunks generated.")
-                combined_chunk = AudioChunk(audio=np.array([], dtype=np.int16), word_timestamps=[])
-            else:
-                combined_chunk = AudioChunk.combine(audio_data_chunks)
-                # Ensure the combined audio is int16 before returning, as downstream expects this raw format.
-                if combined_chunk.audio.dtype != np.int16:
-                    logger.warning(f"Combined audio dtype is {combined_chunk.audio.dtype}, converting to int16.")
-                    # Assuming normalization happened, scale from float [-1, 1] to int16
-                    if np.issubdtype(combined_chunk.audio.dtype, np.floating):
-                        combined_chunk.audio = np.clip(combined_chunk.audio * 32767, -32768, 32767).astype(np.int16)
-                    else:
-                        # If it's another type, attempt direct conversion (might be lossy)
-                        combined_chunk.audio = combined_chunk.audio.astype(np.int16)
-
-            return combined_chunk
+            combined_audio_data = AudioChunk.combine(audio_data_chunks)
+            return combined_audio_data
         except Exception as e:
-            logger.error(f"Error in combined audio generation: {str(e)}")
-            raise  # Re-raise after logging
-        # Removed finally block that closed the writer prematurely
-        # The caller is now responsible for closing the writer after final conversion.
+            logger.error(f"Error in audio generation: {str(e)}")
+            raise
 
     async def combine_voices(self, voices: List[str]) -> torch.Tensor:
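
The streamlined generate_audio above now simply drains its own stream and combines the chunks. Below is a self-contained sketch of that accumulate-then-combine pattern; fake_stream and np.concatenate are stand-ins for generate_audio_stream and AudioChunk.combine, not the real service API.

    import asyncio
    import numpy as np

    async def fake_stream(n_chunks=3):
        # Stand-in for generate_audio_stream yielding raw int16 chunks.
        for i in range(n_chunks):
            await asyncio.sleep(0)  # pretend to do async work
            yield np.full(4, i, dtype=np.int16)

    async def generate_full_audio():
        chunks = []
        async for chunk in fake_stream():
            if len(chunk) > 0:           # skip empty chunks, as the new code does
                chunks.append(chunk)
        return np.concatenate(chunks)    # stand-in for AudioChunk.combine(...)

    print(asyncio.run(generate_full_audio()))
    # [0 0 0 0 1 1 1 1 2 2 2 2]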
@@ -555,49 +513,38 @@ class TTSService:
         try:
             # Get backend and voice path
             backend = self.model_manager.get_backend()
-            # Use _get_voices_path to handle potential combined voice names passed here too
             voice_name, voice_path = await self._get_voices_path(voice)
 
             if isinstance(backend, KokoroV1):
                 # For Kokoro V1, use generate_from_tokens with raw phonemes
-                result_audio = None
+                result = None
 
-                # Determine language code
-                first_base_voice_match = re.match(r"([a-zA-Z0-9_]+)", voice_name)
-                first_base_voice = first_base_voice_match.group(1) if first_base_voice_match else "a"
-                pipeline_lang_code = lang_code if lang_code else (settings.default_voice_code if settings.default_voice_code else first_base_voice[:1].lower())
+                # Use provided lang_code or determine from voice name
+                pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
 
                 logger.info(
-                    f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme generation"
+                    f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme pipeline"
                 )
 
-                # Use backend's pipeline management and iterate through potential chunks
-                full_audio_list = []
-                async for r in backend.generate_from_tokens(  # generate_from_tokens is now async
-                    tokens=phonemes,  # Pass raw phonemes string
-                    voice=(voice_name, voice_path),  # Pass tuple
-                    speed=speed,
-                    lang_code=pipeline_lang_code,
-                ):
-                    if r is not None and len(r) > 0:
-                        # r is directly the numpy array chunk
-                        full_audio_list.append(r)
+                try:
+                    # Use backend's pipeline management
+                    for r in backend._get_pipeline(
+                        pipeline_lang_code
+                    ).generate_from_tokens(
+                        tokens=phonemes,  # Pass raw phonemes string
+                        voice=voice_path,
+                        speed=speed,
+                    ):
+                        if r.audio is not None:
+                            result = r
+                            break
+                except Exception as e:
+                    logger.error(f"Failed to generate from phonemes: {e}")
+                    raise RuntimeError(f"Phoneme generation failed: {e}")
 
-                if not full_audio_list:
-                    raise ValueError("No audio generated from phonemes")
-
-                # Combine chunks if necessary
-                result_audio = np.concatenate(full_audio_list) if len(full_audio_list) > 1 else full_audio_list[0]
+                if result is None or result.audio is None:
+                    raise ValueError("No audio generated")
 
                 processing_time = time.time() - start_time
-
-                # Normalize the final audio before returning
-                normalizer = AudioNormalizer()
-                normalized_audio = normalizer.normalize(result_audio)
-                # Return as int16 for consistency
-                if normalized_audio.dtype != np.int16:
-                    normalized_audio = np.clip(normalized_audio * 32767, -32768, 32767).astype(np.int16)
-
-                return normalized_audio, processing_time
+                return result.audio.numpy(), processing_time
             else:
                 raise ValueError(
                     "Phoneme generation only supported with Kokoro V1 backend"
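
The phoneme path above also simplifies language-code selection: an explicit lang_code wins, otherwise the first letter of the voice name is used (the old code fell back to settings.default_voice_code first). A trivial sketch of the new fallback, with resolve_lang_code as a made-up helper name and a voice name taken from the docstring earlier in the diff:

    def resolve_lang_code(voice, lang_code=None):
        # New behaviour: explicit lang_code wins, otherwise the voice prefix decides.
        return lang_code if lang_code else voice[:1].lower()

    print(resolve_lang_code("af_jadzia"))       # -> 'a'
    print(resolve_lang_code("af_jadzia", "b"))  # -> 'b'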