From e67264f7896111f8b423381e1a883ecee8fb2295 Mon Sep 17 00:00:00 2001
From: CodePothunter
Date: Mon, 10 Mar 2025 13:26:55 +0800
Subject: [PATCH] Fix BUGs of streaming non-wav format audio; improve robustness of releasing audio container

Refactor StreamingAudioWriter to improve audio encoding reliability

- Restructure audio encoding logic for better error handling
- Create a new method `_create_container()` to manage container creation
- Improve handling of different audio formats and encoding scenarios
- Add error logging for audio chunk encoding failures
- Simplify container and stream management in write_chunk method
---
NOTE(review): this patch was recovered from a copy whose newlines and
indentation had been collapsed. Tokens and their order are unchanged; the
indentation widths and the positions of unchanged blank context lines are
reconstructed to match the hunk headers and should be confirmed against the
target file before applying.
NOTE(review): in the new write_chunk, the container is created before the
`if finalize:` early return, so the finalize path returns without calling
container.close() -- confirm whether that is intended.

 api/src/services/streaming_audio_writer.py | 105 +++++++++++++--------
 1 file changed, 67 insertions(+), 38 deletions(-)

diff --git a/api/src/services/streaming_audio_writer.py b/api/src/services/streaming_audio_writer.py
index 2bcb0f3..7bf94f4 100644
--- a/api/src/services/streaming_audio_writer.py
+++ b/api/src/services/streaming_audio_writer.py
@@ -2,7 +2,7 @@
 
 import struct
 from io import BytesIO
-from typing import Optional
+from typing import Optional, Dict
 
 import numpy as np
 import soundfile as sf
@@ -18,18 +18,35 @@ class StreamingAudioWriter:
         self.sample_rate = sample_rate
         self.channels = channels
         self.bytes_written = 0
-        self.pts=0
-
-        codec_map = {"wav":"pcm_s16le","mp3":"mp3","opus":"libopus","flac":"flac", "aac":"aac"}
+        self.pts = 0
+
         # Format-specific setup
-        if self.format in ["wav", "opus","flac","mp3","aac","pcm"]:
-            if self.format != "pcm":
-                self.output_buffer = BytesIO()
-                self.container = av.open(self.output_buffer, mode="w", format=self.format)
-                self.stream = self.container.add_stream(codec_map[self.format],sample_rate=self.sample_rate,layout='mono' if self.channels == 1 else 'stereo')
-                self.stream.bit_rate = 96000
-        else:
+        if self.format not in ["wav", "opus", "flac", "mp3", "aac", "pcm"]:
             raise ValueError(f"Unsupported format: {format}")
+
+        # Codec mapping
+        self.codec_map = {
+            "wav": "pcm_s16le",
+            "mp3": "mp3",
+            "opus": "libopus",
+            "flac": "flac",
+            "aac": "aac"
+        }
+
+    def _create_container(self):
+        """Create a new container for each write operation"""
+        if self.format == "pcm":
+            return None, None
+
+        buffer = BytesIO()
+        container = av.open(buffer, mode="w", format=self.format)
+        stream = container.add_stream(
+            self.codec_map[self.format],
+            sample_rate=self.sample_rate,
+            layout='mono' if self.channels == 1 else 'stereo'
+        )
+        stream.bit_rate = 96000
+        return container, buffer
 
     def write_chunk(
         self, audio_data: Optional[np.ndarray] = None, finalize: bool = False
@@ -40,38 +57,50 @@ class StreamingAudioWriter:
             audio_data: Audio data to write, or None if finalizing
             finalize: Whether this is the final write to close the stream
         """
-
-        if finalize:
-            if self.format != "pcm":
-                # Flush encoder buffers
-                for packet in self.stream.encode(None):
-                    self.container.mux(packet)
-                self.container.close()
-                data = self.output_buffer.getvalue()
-                self.output_buffer.seek(0)
-                self.output_buffer.truncate(0)
-                return data
-            return b""
-
-        if audio_data is None or len(audio_data) == 0:
-            return b""
-
+        # Handle PCM format separately as it doesn't use PyAV
         if self.format == "pcm":
+            if finalize or audio_data is None or len(audio_data) == 0:
+                return b""
             return audio_data.tobytes()
-        else:
-            frame = av.AudioFrame.from_ndarray(audio_data.reshape(1, -1), format='s16', layout='mono' if self.channels == 1 else 'stereo')
+
+        # Handle empty input
+        if not finalize and (audio_data is None or len(audio_data) == 0):
+            return b""
+
+        try:
+            # Create a new container for this operation
+            container, buffer = self._create_container()
+            stream = container.streams[0]
+
+            if finalize:
+                # Just return empty bytes for finalize in the new design
+                return b""
+
+            # Create audio frame
+            frame = av.AudioFrame.from_ndarray(
+                audio_data.reshape(1, -1),
+                format='s16',
+                layout='mono' if self.channels == 1 else 'stereo'
+            )
             frame.sample_rate = self.sample_rate
 
             frame.pts = self.pts
             self.pts += frame.samples
 
-            encoded_data = b""
-            for packet in self.stream.encode(frame):
-                self.container.mux(packet)
-                # Get the encoded data from the buffer
-                encoded_data = self.output_buffer.getvalue()
-                # Clear the buffer for next write
-                self.output_buffer.seek(0)
-                self.output_buffer.truncate(0)
+            # Encode the frame
+            for packet in stream.encode(frame):
+                container.mux(packet)
+
+            # Flush any remaining packets
+            for packet in stream.encode(None):
+                container.mux(packet)
+
+            # Close the container and get the data
+            container.close()
+            data = buffer.getvalue()
 
-            return encoded_data
+            return data
+
+        except Exception as e:
+            logger.error(f"Error encoding audio chunk: {e}")
+            return b""