Fix bugs in streaming non-WAV audio formats; improve robustness when releasing the audio container

Refactor StreamingAudioWriter to improve audio encoding reliability

- Restructure audio encoding logic for better error handling
- Create a new method `_create_container()` to manage container creation
- Improve handling of different audio formats and encoding scenarios
- Add error logging for audio chunk encoding failures
- Simplify container and stream management in write_chunk method
This commit is contained in:
CodePothunter 2025-03-10 13:26:55 +08:00
parent fbdedfb131
commit e67264f789

View file

@ -2,7 +2,7 @@
import struct import struct
from io import BytesIO from io import BytesIO
from typing import Optional from typing import Optional, Dict
import numpy as np import numpy as np
import soundfile as sf import soundfile as sf
@ -18,18 +18,35 @@ class StreamingAudioWriter:
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.channels = channels self.channels = channels
self.bytes_written = 0 self.bytes_written = 0
self.pts=0 self.pts = 0
codec_map = {"wav":"pcm_s16le","mp3":"mp3","opus":"libopus","flac":"flac", "aac":"aac"}
# Format-specific setup # Format-specific setup
if self.format in ["wav", "opus","flac","mp3","aac","pcm"]: if self.format not in ["wav", "opus", "flac", "mp3", "aac", "pcm"]:
if self.format != "pcm":
self.output_buffer = BytesIO()
self.container = av.open(self.output_buffer, mode="w", format=self.format)
self.stream = self.container.add_stream(codec_map[self.format],sample_rate=self.sample_rate,layout='mono' if self.channels == 1 else 'stereo')
self.stream.bit_rate = 96000
else:
raise ValueError(f"Unsupported format: {format}") raise ValueError(f"Unsupported format: {format}")
# Codec mapping
self.codec_map = {
"wav": "pcm_s16le",
"mp3": "mp3",
"opus": "libopus",
"flac": "flac",
"aac": "aac"
}
def _create_container(self):
"""Create a new container for each write operation"""
if self.format == "pcm":
return None, None
buffer = BytesIO()
container = av.open(buffer, mode="w", format=self.format)
stream = container.add_stream(
self.codec_map[self.format],
sample_rate=self.sample_rate,
layout='mono' if self.channels == 1 else 'stereo'
)
stream.bit_rate = 96000
return container, buffer
def write_chunk( def write_chunk(
self, audio_data: Optional[np.ndarray] = None, finalize: bool = False self, audio_data: Optional[np.ndarray] = None, finalize: bool = False
@ -40,38 +57,50 @@ class StreamingAudioWriter:
audio_data: Audio data to write, or None if finalizing audio_data: Audio data to write, or None if finalizing
finalize: Whether this is the final write to close the stream finalize: Whether this is the final write to close the stream
""" """
# Handle PCM format separately as it doesn't use PyAV
if finalize:
if self.format != "pcm":
# Flush encoder buffers
for packet in self.stream.encode(None):
self.container.mux(packet)
self.container.close()
data = self.output_buffer.getvalue()
self.output_buffer.seek(0)
self.output_buffer.truncate(0)
return data
return b""
if audio_data is None or len(audio_data) == 0:
return b""
if self.format == "pcm": if self.format == "pcm":
if finalize or audio_data is None or len(audio_data) == 0:
return b""
return audio_data.tobytes() return audio_data.tobytes()
else:
frame = av.AudioFrame.from_ndarray(audio_data.reshape(1, -1), format='s16', layout='mono' if self.channels == 1 else 'stereo') # Handle empty input
if not finalize and (audio_data is None or len(audio_data) == 0):
return b""
try:
# Create a new container for this operation
container, buffer = self._create_container()
stream = container.streams[0]
if finalize:
# Just return empty bytes for finalize in the new design
return b""
# Create audio frame
frame = av.AudioFrame.from_ndarray(
audio_data.reshape(1, -1),
format='s16',
layout='mono' if self.channels == 1 else 'stereo'
)
frame.sample_rate = self.sample_rate frame.sample_rate = self.sample_rate
frame.pts = self.pts frame.pts = self.pts
self.pts += frame.samples self.pts += frame.samples
encoded_data = b"" # Encode the frame
for packet in self.stream.encode(frame): for packet in stream.encode(frame):
self.container.mux(packet) container.mux(packet)
# Get the encoded data from the buffer
encoded_data = self.output_buffer.getvalue() # Flush any remaining packets
# Clear the buffer for next write for packet in stream.encode(None):
self.output_buffer.seek(0) container.mux(packet)
self.output_buffer.truncate(0)
# Close the container and get the data
container.close()
data = buffer.getvalue()
return encoded_data return data
except Exception as e:
logger.error(f"Error encoding audio chunk: {e}")
return b""