Mirror of https://github.com/remsky/Kokoro-FastAPI.git, synced 2025-04-13 09:39:17 +00:00
Refactor TTS service to improve filename safety and audio chunk handling. Update the filename regex to allow additional characters, enhance silence-chunk creation for AudioService, and ensure the final audio output is consistently int16. Remove the premature writer closure from the finalization path, delegating that responsibility to the caller.
This commit is contained in: parent 4b334beff4, commit 207d709de1
2 changed files with 61 additions and 22 deletions
@@ -270,7 +270,7 @@ class TTSService:
         # Save the new combined voice so it can be loaded later
         # Use a safe filename based on the original input string
-        safe_filename = re.sub(r'[^\w+-]', '_', voice) + ".pt"
+        safe_filename = re.sub(r'[^\w+-.\(\)]', '_', voice) + ".pt"  # Allow weights in filename
         temp_dir = tempfile.gettempdir()
         combined_path = os.path.join(temp_dir, safe_filename)
         logger.debug(f"Saving combined voice '{voice}' to temporary path: {combined_path}")
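Why the character class changed: combined-voice names can embed per-voice weights such as (0.7), which the old pattern mangled. A minimal before/after sketch (the voice string is a made-up example, not taken from this commit):

import re

voice = "af_bella(0.7)+af_sky(0.3)"  # hypothetical combined-voice name with weights

# Old pattern: '.', '(' and ')' match [^\w+-] and are replaced with '_'
old = re.sub(r'[^\w+-]', '_', voice) + ".pt"
print(old)  # af_bella_0_7_+af_sky_0_3_.pt

# New pattern keeps '.', '(' and ')', so the weights survive in the filename
new = re.sub(r'[^\w+-.\(\)]', '_', voice) + ".pt"
print(new)  # af_bella(0.7)+af_sky(0.3).pt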
@@ -328,6 +328,7 @@ class TTSService:
                 try:
                     logger.debug(f"Generating {pause_duration_s}s silence chunk")
                     silence_samples = int(pause_duration_s * settings.sample_rate)
+                    # Create silence appropriate for AudioService (float32)
                     silence_audio = np.zeros(silence_samples, dtype=np.float32)
                     pause_chunk = AudioChunk(audio=silence_audio, word_timestamps=[])  # Empty timestamps for silence
 
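The silence sizing is plain sample arithmetic; a small sketch assuming a 24 kHz sample rate (an assumption for illustration; the real value comes from settings.sample_rate):

import numpy as np

sample_rate = 24000        # assumed value; the service reads settings.sample_rate
pause_duration_s = 0.5     # e.g. the pause the diff inserts after a trailing newline

silence_samples = int(pause_duration_s * sample_rate)        # 12000 samples
silence_audio = np.zeros(silence_samples, dtype=np.float32)  # float32, as AudioService expects

print(silence_samples, silence_audio.dtype)  # 12000 float32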
@@ -340,13 +341,14 @@ class TTSService:
                         if formatted_pause_chunk.output:
                             yield formatted_pause_chunk
                     else:  # Raw audio mode
+                        # Normalize to int16 for raw output consistency
                         pause_chunk.audio = stream_normalizer.normalize(pause_chunk.audio)
                         if len(pause_chunk.audio) > 0:
                             yield pause_chunk
 
                     # Update offset based on silence duration
                     current_offset += pause_duration_s
-                    chunk_index += 1
+                    chunk_index += 1  # Count pause as a yielded chunk
 
                 except Exception as e:
                     logger.error(f"Failed to process pause chunk: {str(e)}")
@@ -368,8 +370,8 @@ class TTSService:
                         speed,
                         writer,
                         output_format,
-                        is_first=(chunk_index == 0),
-                        is_last=False,
+                        is_first=(chunk_index == 0),  # Check if this is the very first *audio* chunk
+                        is_last=False,  # is_last is handled separately after the loop
                         normalizer=stream_normalizer,
                         lang_code=pipeline_lang_code,
                         return_timestamps=return_timestamps,
@@ -381,7 +383,6 @@ class TTSService:
                         timestamp.end_time += current_offset
 
                 # Update offset based on the *actual duration* of the generated audio chunk
-                # Check if audio data exists before calculating duration
                 chunk_duration = 0
                 if chunk_data.audio is not None and len(chunk_data.audio) > 0:
                     chunk_duration = len(chunk_data.audio) / settings.sample_rate
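The offset bookkeeping here is the inverse of the silence sizing above: duration is recovered from the sample count. A worked example, again assuming a 24 kHz sample rate:

sample_rate = 24000     # assumed settings.sample_rate

# A generated chunk of 36000 samples lasts 1.5 s ...
chunk_samples = 36000
chunk_duration = chunk_samples / sample_rate  # 1.5

# ... so word timestamps in later chunks are shifted by the accumulated offset.
current_offset = 0.0
current_offset += chunk_duration
print(current_offset)  # 1.5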
@@ -397,7 +398,6 @@ class TTSService:
                         f"No audio generated or output for text chunk: '{text_chunk_for_model[:50]}...'"
                     )
 
-
             # --- Add pause after newline (if applicable) ---
             if has_trailing_newline:
                 newline_pause_s = 0.5
@@ -445,20 +445,21 @@ class TTSService:
                     "", [], voice_name, voice_path, speed, writer, output_format,
                     is_first=False, is_last=True, normalizer=stream_normalizer, lang_code=pipeline_lang_code
                 ):
+                    # Yield final formatted chunk or raw empty chunk
                     if output_format and final_chunk_data.output:
                         yield final_chunk_data
-                    elif not output_format and final_chunk_data.audio is not None and len(final_chunk_data.audio) > 0:
-                        yield final_chunk_data  # Should yield empty chunk in raw mode upon finalize
+                    elif not output_format:  # Raw mode: Finalize yields empty chunk signal
+                        yield final_chunk_data  # Yields empty AudioChunk
             except Exception as e:
                 logger.error(f"Failed to finalize audio stream: {str(e)}")
 
         except Exception as e:
             logger.exception(f"Error during audio stream generation: {str(e)}")  # Use exception for traceback
-            # Ensure writer is closed on error
-            try:
-                writer.close()
-            except Exception as close_e:
-                logger.error(f"Error closing writer during exception handling: {close_e}")
+            # Ensure writer is closed on error - moved to caller (e.g., route handler)
+            # try:
+            #     writer.close()
+            # except Exception as close_e:
+            #     logger.error(f"Error closing writer during exception handling: {close_e}")
             raise e  # Re-raise the original exception
 
 
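With the in-service close removed, the writer's lifetime belongs to the caller. A hypothetical caller-side shape, assuming only what the diff shows (generate_audio_stream taking text, voice, writer and an output_format keyword); the function name and response wiring are illustrative, not from this commit:

async def stream_speech(tts_service, text: str, voice: str, writer):
    """Sketch only: the caller, not TTSService, now closes the writer."""
    try:
        async for chunk in tts_service.generate_audio_stream(
            text, voice, writer, output_format="wav"
        ):
            if chunk.output:
                yield chunk.output  # formatted bytes for the response body
    finally:
        # The service no longer closes the writer on error or completion;
        # closing it exactly once is the caller's responsibility.
        writer.close()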
@@ -477,38 +478,51 @@ class TTSService:
         output_format = None  # Signal raw audio mode for internal streaming
         combined_chunk = None
         try:
+            # Pass a dummy writer if none provided, as generate_audio_stream requires one
+            # Although in raw mode (output_format=None), it shouldn't be heavily used for formatting
+            internal_writer = writer if writer else StreamingAudioWriter(format='wav', sample_rate=settings.sample_rate)
+
             async for audio_stream_data in self.generate_audio_stream(
                 text,
                 voice,
-                writer,  # Pass writer, although it won't be used for formatting here
+                internal_writer,  # Pass the writer instance
                 speed=speed,
                 normalization_options=normalization_options,
-                return_timestamps=return_timestamps,
+                return_timestamps=return_timestamps,  # Pass this down
                 lang_code=lang_code,
                 output_format=output_format,  # Explicitly None for raw audio
             ):
                 # Ensure we only append chunks with actual audio data
                 # Raw silence chunks generated for pauses will have audio data (zeros)
                 if audio_stream_data.audio is not None and len(audio_stream_data.audio) > 0:
+                    # Ensure timestamps are preserved if requested
+                    if return_timestamps and not audio_stream_data.word_timestamps:
+                        audio_stream_data.word_timestamps = []  # Initialize if needed
                     audio_data_chunks.append(audio_stream_data)
 
             if not audio_data_chunks:
                 logger.warning("No valid audio chunks generated.")
                 combined_chunk = AudioChunk(audio=np.array([], dtype=np.int16), word_timestamps=[])
 
             else:
                 combined_chunk = AudioChunk.combine(audio_data_chunks)
+                # Ensure the combined audio is int16 before returning, as downstream expects this raw format.
+                if combined_chunk.audio.dtype != np.int16:
+                    logger.warning(f"Combined audio dtype is {combined_chunk.audio.dtype}, converting to int16.")
+                    # Assuming normalization happened, scale from float [-1, 1] to int16
+                    if np.issubdtype(combined_chunk.audio.dtype, np.floating):
+                        combined_chunk.audio = np.clip(combined_chunk.audio * 32767, -32768, 32767).astype(np.int16)
+                    else:
+                        # If it's another type, attempt direct conversion (might be lossy)
+                        combined_chunk.audio = combined_chunk.audio.astype(np.int16)
+
             return combined_chunk
         except Exception as e:
             logger.error(f"Error in combined audio generation: {str(e)}")
             raise  # Re-raise after logging
-        finally:
-            # Explicitly close the writer if it was passed, though it shouldn't hold resources in raw mode
-            try:
-                writer.close()
-            except Exception:
-                pass  # Ignore errors during cleanup
+        # Removed finally block that closed the writer prematurely
+        # The caller is now responsible for closing the writer after final conversion.
 
 
 
     async def combine_voices(self, voices: List[str]) -> torch.Tensor:
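Why the new conversion clips before casting: float samples at or slightly beyond ±1.0 would otherwise wrap around when cast to int16. A small demonstration:

import numpy as np

audio = np.array([0.0, 0.5, 1.0, 1.0001, -1.2], dtype=np.float32)

# Casting out-of-range values directly wraps around (platform-dependent garbage).
unsafe = (audio * 32767).astype(np.int16)

# Clipping first pins out-of-range samples to the int16 limits instead.
safe = np.clip(audio * 32767, -32768, 32767).astype(np.int16)
print(safe)  # [     0  16383  32767  32767 -32768]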
@@ -584,6 +598,10 @@ class TTSService:
                 # Normalize the final audio before returning
                 normalizer = AudioNormalizer()
                 normalized_audio = normalizer.normalize(result_audio)
+                # Return as int16 for consistency
+                if normalized_audio.dtype != np.int16:
+                    normalized_audio = np.clip(normalized_audio * 32767, -32768, 32767).astype(np.int16)
+
                 return normalized_audio, processing_time
             else:
                 raise ValueError(
run-tests.sh (new executable file, 21 lines added)
@@ -0,0 +1,21 @@
+#!/bin/bash
+
+# Get project root directory
+PROJECT_ROOT=$(pwd)
+
+# Set environment variables
+export USE_GPU=false
+export USE_ONNX=false
+export PYTHONPATH=$PROJECT_ROOT:$PROJECT_ROOT/api
+export MODEL_DIR=src/models
+export VOICES_DIR=src/voices/v1_0
+export WEB_PLAYER_PATH=$PROJECT_ROOT/web
+# Set the espeak-ng data path to your location
+export ESPEAK_DATA_PATH=/usr/lib/x86_64-linux-gnu/espeak-ng-data
+
+# Run FastAPI with CPU extras using uv run
+# Note: espeak may still require manual installation,
+uv pip install -e ".[test,cpu]"
+uv run --no-sync python docker/scripts/download_model.py --output api/src/models/v1_0
+
+uv run pytest api/tests/ --asyncio-mode=auto --cov=api --cov-report=term-missing
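The script should be run from the repository root, since it derives PROJECT_ROOT from pwd, and it assumes uv is installed; per its own comment, espeak-ng may still need to be installed separately.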