wip: fix mp3

Cong Nguyen 2025-04-01 18:05:37 +11:00
parent 64ced408b7
commit 5067a47ff9
2 changed files with 97 additions and 36 deletions


@@ -50,11 +50,15 @@ class StreamingAudioWriter:
         if finalize:
             if self.format != "pcm":
-                packets = self.stream.encode(None)
-                for packet in packets:
-                    self.container.mux(packet)
-            data=self.output_buffer.getvalue()
+                # Properly finalize the container
+                if hasattr(self, 'stream') and self.stream:
+                    packets = self.stream.encode(None)
+                    for packet in packets:
+                        self.container.mux(packet)
+                # Important: Call close() on the container
+                self.container.close()
+            data = self.output_buffer.getvalue() if hasattr(self, 'output_buffer') else b''
             self.close()
             return data
@@ -64,6 +68,20 @@ class StreamingAudioWriter:
         if self.format == "pcm":
             # Write raw bytes
             return audio_data.tobytes()
+        elif self.format == "mp3":
+            # For MP3, we need to handle streaming differently
+            frame = av.AudioFrame.from_ndarray(audio_data.reshape(1, -1), format='s16', layout='mono' if self.channels == 1 else 'stereo')
+            frame.sample_rate = self.sample_rate
+            frame.pts = self.pts
+            self.pts += frame.samples
+            packets = self.stream.encode(frame)
+            for packet in packets:
+                self.container.mux(packet)
+            # For MP3, just return an empty byte string to indicate processing occurred
+            # The complete MP3 file will be returned during finalization
+            return b''
         else:
             frame = av.AudioFrame.from_ndarray(audio_data.reshape(1, -1), format='s16', layout='mono' if self.channels == 1 else 'stereo')
             frame.sample_rate=self.sample_rate
@@ -80,4 +98,3 @@ class StreamingAudioWriter:
         self.output_buffer.seek(0)
         self.output_buffer.truncate(0)
         return data
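
Note: the MP3 branch above intentionally returns b'' on every streaming call. Encoded packets accumulate inside the in-memory container, and the complete file is only produced by the finalize path, which flushes the encoder with stream.encode(None) and then closes the container before the buffer is read. A minimal usage sketch of that contract follows (not part of this commit); the import path and the 24 kHz int16 test tone are assumptions, while the constructor, write_chunk, and close calls mirror the ones TTSService makes in the second diff below.

import numpy as np

from streaming_audio_writer import StreamingAudioWriter  # assumed import path

sample_rate = 24000  # assumed; the service reads this from its settings
tone = (0.2 * 32767 * np.sin(
    2 * np.pi * 440.0 * np.arange(sample_rate) / sample_rate
)).astype(np.int16)

writer = StreamingAudioWriter("mp3", sample_rate=sample_rate)

# Intermediate calls encode into the internal container and return b''.
for chunk in np.array_split(tone, 4):
    writer.write_chunk(chunk)

# Only the finalize call flushes the encoder, closes the container, and
# returns the complete MP3 file.
mp3_bytes = writer.write_chunk(finalize=True)
writer.close()

with open("finalize_demo.mp3", "wb") as f:
    f.write(mp3_bytes)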


@@ -245,6 +245,12 @@ class TTSService:
         stream_normalizer = AudioNormalizer()
         chunk_index = 0
         current_offset = 0.0
+
+        # For MP3 format, we'll collect all audio chunks and encode at the end
+        is_mp3_format = output_format.lower() == "mp3"
+        all_audio_chunks = []
+        all_word_timestamps = []
+
         try:
             # Get backend
             backend = self.model_manager.get_backend()
@@ -268,13 +274,15 @@ class TTSService:
                         voice_path,  # Pass voice path
                         speed,
                         writer,
-                        output_format,
+                        # For MP3, we don't pass the output format during processing
+                        None if is_mp3_format else output_format,
                         is_first=(chunk_index == 0),
                         is_last=False,  # We'll update the last chunk later
                         normalizer=stream_normalizer,
                         lang_code=pipeline_lang_code,  # Pass lang_code
                         return_timestamps=return_timestamps,
                     ):
+                        # Update timestamps with current offset
                         if chunk_data.word_timestamps is not None:
                             for timestamp in chunk_data.word_timestamps:
                                 timestamp.start_time += current_offset
@@ -282,42 +290,78 @@ class TTSService:
                         current_offset += len(chunk_data.audio) / 24000

-                        if chunk_data.output is not None:
-                            yield chunk_data
+                        if is_mp3_format:
+                            # Collect audio for MP3 final encoding
+                            if len(chunk_data.audio) > 0:
+                                all_audio_chunks.append(chunk_data.audio)
+                            if chunk_data.word_timestamps:
+                                all_word_timestamps.extend(chunk_data.word_timestamps)
                         else:
-                            logger.warning(f"No audio generated for chunk: '{chunk_text[:100]}...'")
+                            # For other formats, stream normally
+                            if chunk_data.output is not None:
+                                yield chunk_data
+                            else:
+                                logger.warning(f"No audio generated for chunk: '{chunk_text[:100]}...'")

                         chunk_index += 1
                 except Exception as e:
                     logger.error(f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}")
                     continue

             # Only finalize if we successfully processed at least one chunk
             if chunk_index > 0:
-                try:
-                    # Empty tokens list to finalize audio
-                    async for chunk_data in self._process_chunk(
-                        "",  # Empty text
-                        [],  # Empty tokens
-                        voice_name,
-                        voice_path,
-                        speed,
-                        writer,
-                        output_format,
-                        is_first=False,
-                        is_last=True,  # Signal this is the last chunk
-                        normalizer=stream_normalizer,
-                        lang_code=pipeline_lang_code,  # Pass lang_code
-                    ):
-                        if chunk_data.output is not None:
-                            yield chunk_data
-                except Exception as e:
-                    logger.error(f"Failed to finalize audio stream: {str(e)}")
+                if is_mp3_format:
+                    # For MP3, create a combined audio file
+                    if all_audio_chunks:
+                        try:
+                            # Combine all audio chunks
+                            combined_audio = np.concatenate(all_audio_chunks)
+
+                            # Create a fresh MP3 writer for final encoding
+                            mp3_writer = StreamingAudioWriter("mp3", sample_rate=settings.sample_rate)
+
+                            # Write all audio at once and finalize
+                            mp3_writer.write_chunk(combined_audio)
+                            mp3_data = mp3_writer.write_chunk(finalize=True)
+                            mp3_writer.close()
+
+                            # Create the final chunk with all audio and timestamps
+                            result_chunk = AudioChunk(
+                                audio=combined_audio,
+                                word_timestamps=all_word_timestamps,
+                                output=mp3_data
+                            )
+                            yield result_chunk
+                        except Exception as e:
+                            logger.error(f"Failed to create final MP3 file: {str(e)}")
+                            raise RuntimeError(f"MP3 encoding failed: {str(e)}")
+                else:
+                    # For other formats, just finalize the stream
+                    try:
+                        # Empty tokens list to finalize audio
+                        async for chunk_data in self._process_chunk(
+                            "",  # Empty text
+                            [],  # Empty tokens
+                            voice_name,
+                            voice_path,
+                            speed,
+                            writer,
+                            output_format,
+                            is_first=False,
+                            is_last=True,  # Signal this is the last chunk
+                            normalizer=stream_normalizer,
+                            lang_code=pipeline_lang_code,  # Pass lang_code
+                        ):
+                            if chunk_data.output is not None:
+                                yield chunk_data
+                    except Exception as e:
+                        logger.error(f"Failed to finalize audio stream: {str(e)}")

         except Exception as e:
-            logger.error(f"Error in phoneme audio generation: {str(e)}")
+            logger.error(f"Error in audio generation: {str(e)}")
             raise e

     async def generate_audio(
         self,
         text: str,
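
For reference, the collect-then-encode flow that TTSService now uses for MP3 can be sketched in isolation roughly as follows. This is an illustration, not the service code: the chunk list, the WordTimestamp stand-in, and the import path are assumptions, and the hard-coded 24000 stands in for settings.sample_rate. In the real service the chunks come from _process_chunk with per-chunk output encoding disabled.

from dataclasses import dataclass

import numpy as np

from streaming_audio_writer import StreamingAudioWriter  # assumed import path


@dataclass
class WordTimestamp:  # stand-in for the service's timestamp model
    word: str
    start_time: float
    end_time: float


sample_rate = 24000
chunks = [
    (np.zeros(12000, dtype=np.int16), [WordTimestamp("hello", 0.0, 0.4)]),
    (np.zeros(24000, dtype=np.int16), [WordTimestamp("world", 0.1, 0.5)]),
]

all_audio_chunks, all_word_timestamps = [], []
current_offset = 0.0

for audio, timestamps in chunks:
    # Shift per-chunk timestamps onto the global timeline, as the service does.
    for ts in timestamps:
        ts.start_time += current_offset
        ts.end_time += current_offset
    current_offset += len(audio) / sample_rate

    if len(audio) > 0:
        all_audio_chunks.append(audio)
    if timestamps:
        all_word_timestamps.extend(timestamps)

# One-shot encode with a fresh writer, then finalize to get the whole file.
combined_audio = np.concatenate(all_audio_chunks)
mp3_writer = StreamingAudioWriter("mp3", sample_rate=sample_rate)
mp3_writer.write_chunk(combined_audio)
mp3_data = mp3_writer.write_chunk(finalize=True)
mp3_writer.close()

print(f"{len(mp3_data)} bytes of MP3, {len(all_word_timestamps)} word timestamps")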