mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00
Refactor TTS service to improve audio chunk handling and filename safety. Removed unnecessary comments, adjusted text processing for legacy backends, and enhanced error handling during audio stream generation. Updated filename regex to restrict allowed characters for safer filenames.
This commit is contained in:
parent
88b9349198
commit
7a838ab3e8
2 changed files with 9 additions and 14 deletions
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from typing import AsyncGenerator, Dict, List, Tuple, Optional # Add Optional import
|
from typing import AsyncGenerator, Dict, List, Tuple, Optional
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
|
|
@ -78,7 +78,6 @@ class TTSService:
|
||||||
"",
|
"",
|
||||||
normalizer=normalizer,
|
normalizer=normalizer,
|
||||||
is_last_chunk=True,
|
is_last_chunk=True,
|
||||||
trim_audio=False, # Don't trim final silence
|
|
||||||
)
|
)
|
||||||
yield chunk_data
|
yield chunk_data
|
||||||
return
|
return
|
||||||
|
@ -92,9 +91,7 @@ class TTSService:
|
||||||
backend = self.model_manager.get_backend()
|
backend = self.model_manager.get_backend()
|
||||||
|
|
||||||
# Generate audio using pre-warmed model
|
# Generate audio using pre-warmed model
|
||||||
# Note: chunk_text is the *original* text including custom phoneme markers and newlines
|
|
||||||
# The model needs the text *with phonemes restored*
|
|
||||||
text_for_model = chunk_text # Start with original
|
|
||||||
# Restore custom phonemes if backend needs it (like KokoroV1)
|
# Restore custom phonemes if backend needs it (like KokoroV1)
|
||||||
if isinstance(backend, KokoroV1):
|
if isinstance(backend, KokoroV1):
|
||||||
# Find phoneme markers in this specific chunk_text and restore
|
# Find phoneme markers in this specific chunk_text and restore
|
||||||
|
@ -120,7 +117,7 @@ class TTSService:
|
||||||
output_format,
|
output_format,
|
||||||
writer,
|
writer,
|
||||||
speed,
|
speed,
|
||||||
chunk_text.strip(), # Pass original text for trimming logic
|
chunk_text,
|
||||||
is_last_chunk=is_last, # Should always be False here, handled above
|
is_last_chunk=is_last, # Should always be False here, handled above
|
||||||
normalizer=normalizer,
|
normalizer=normalizer,
|
||||||
trim_audio=True # Trim speech parts
|
trim_audio=True # Trim speech parts
|
||||||
|
@ -138,9 +135,7 @@ class TTSService:
|
||||||
logger.warning(f"Model generation yielded no audio chunks for: '{text_for_model[:50]}...'")
|
logger.warning(f"Model generation yielded no audio chunks for: '{text_for_model[:50]}...'")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# --- Legacy backend path (using tokens) ---
|
# For legacy backends, load voice tensor
|
||||||
# This path might not work correctly with custom phonemes restored in text_for_model
|
|
||||||
logger.warning("Using legacy backend path with tokens - custom phonemes might not be handled.")
|
|
||||||
voice_tensor = await self._voice_manager.load_voice(
|
voice_tensor = await self._voice_manager.load_voice(
|
||||||
voice_name, device=backend.device
|
voice_name, device=backend.device
|
||||||
)
|
)
|
||||||
|
@ -270,7 +265,7 @@ class TTSService:
|
||||||
|
|
||||||
# Save the new combined voice so it can be loaded later
|
# Save the new combined voice so it can be loaded later
|
||||||
# Use a safe filename based on the original input string
|
# Use a safe filename based on the original input string
|
||||||
safe_filename = re.sub(r'[^\w+-.\(\)]', '_', voice) + ".pt" # Allow weights in filename
|
safe_filename = re.sub(r'[^\w+-]', '_', voice) + ".pt"
|
||||||
temp_dir = tempfile.gettempdir()
|
temp_dir = tempfile.gettempdir()
|
||||||
combined_path = os.path.join(temp_dir, safe_filename)
|
combined_path = os.path.join(temp_dir, safe_filename)
|
||||||
logger.debug(f"Saving combined voice '{voice}' to temporary path: {combined_path}")
|
logger.debug(f"Saving combined voice '{voice}' to temporary path: {combined_path}")
|
||||||
|
@ -456,10 +451,10 @@ class TTSService:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception(f"Error during audio stream generation: {str(e)}") # Use exception for traceback
|
logger.exception(f"Error during audio stream generation: {str(e)}") # Use exception for traceback
|
||||||
# Ensure writer is closed on error - moved to caller (e.g., route handler)
|
# Ensure writer is closed on error - moved to caller (e.g., route handler)
|
||||||
# try:
|
try:
|
||||||
# writer.close()
|
writer.close()
|
||||||
# except Exception as close_e:
|
except Exception as close_e:
|
||||||
# logger.error(f"Error closing writer during exception handling: {close_e}")
|
logger.error(f"Error closing writer during exception handling: {close_e}")
|
||||||
raise e # Re-raise the original exception
|
raise e # Re-raise the original exception
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue