Fix BUGs of streaming non-wav format audio; improve robustness of releasing audio container

Refactor StreamingAudioWriter to improve audio encoding reliability

- Restructure audio encoding logic for better error handling
- Create a new method `_create_container()` to manage container creation
- Improve handling of different audio formats and encoding scenarios
- Add error logging for audio chunk encoding failures
- Simplify container and stream management in write_chunk method
This commit is contained in:
CodePothunter 2025-03-10 13:26:55 +08:00
parent fbdedfb131
commit e67264f789

View file

@ -2,7 +2,7 @@
import struct import struct
from io import BytesIO from io import BytesIO
from typing import Optional from typing import Optional, Dict
import numpy as np import numpy as np
import soundfile as sf import soundfile as sf
@ -18,19 +18,36 @@ class StreamingAudioWriter:
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.channels = channels self.channels = channels
self.bytes_written = 0 self.bytes_written = 0
self.pts=0 self.pts = 0
codec_map = {"wav":"pcm_s16le","mp3":"mp3","opus":"libopus","flac":"flac", "aac":"aac"}
# Format-specific setup # Format-specific setup
if self.format in ["wav", "opus","flac","mp3","aac","pcm"]: if self.format not in ["wav", "opus", "flac", "mp3", "aac", "pcm"]:
if self.format != "pcm":
self.output_buffer = BytesIO()
self.container = av.open(self.output_buffer, mode="w", format=self.format)
self.stream = self.container.add_stream(codec_map[self.format],sample_rate=self.sample_rate,layout='mono' if self.channels == 1 else 'stereo')
self.stream.bit_rate = 96000
else:
raise ValueError(f"Unsupported format: {format}") raise ValueError(f"Unsupported format: {format}")
# Codec mapping
self.codec_map = {
"wav": "pcm_s16le",
"mp3": "mp3",
"opus": "libopus",
"flac": "flac",
"aac": "aac"
}
def _create_container(self):
"""Create a new container for each write operation"""
if self.format == "pcm":
return None, None
buffer = BytesIO()
container = av.open(buffer, mode="w", format=self.format)
stream = container.add_stream(
self.codec_map[self.format],
sample_rate=self.sample_rate,
layout='mono' if self.channels == 1 else 'stereo'
)
stream.bit_rate = 96000
return container, buffer
def write_chunk( def write_chunk(
self, audio_data: Optional[np.ndarray] = None, finalize: bool = False self, audio_data: Optional[np.ndarray] = None, finalize: bool = False
) -> bytes: ) -> bytes:
@ -40,38 +57,50 @@ class StreamingAudioWriter:
audio_data: Audio data to write, or None if finalizing audio_data: Audio data to write, or None if finalizing
finalize: Whether this is the final write to close the stream finalize: Whether this is the final write to close the stream
""" """
# Handle PCM format separately as it doesn't use PyAV
if self.format == "pcm":
if finalize or audio_data is None or len(audio_data) == 0:
return b""
return audio_data.tobytes()
# Handle empty input
if not finalize and (audio_data is None or len(audio_data) == 0):
return b""
try:
# Create a new container for this operation
container, buffer = self._create_container()
stream = container.streams[0]
if finalize: if finalize:
if self.format != "pcm": # Just return empty bytes for finalize in the new design
# Flush encoder buffers
for packet in self.stream.encode(None):
self.container.mux(packet)
self.container.close()
data = self.output_buffer.getvalue()
self.output_buffer.seek(0)
self.output_buffer.truncate(0)
return data
return b"" return b""
if audio_data is None or len(audio_data) == 0: # Create audio frame
return b"" frame = av.AudioFrame.from_ndarray(
audio_data.reshape(1, -1),
if self.format == "pcm": format='s16',
return audio_data.tobytes() layout='mono' if self.channels == 1 else 'stereo'
else: )
frame = av.AudioFrame.from_ndarray(audio_data.reshape(1, -1), format='s16', layout='mono' if self.channels == 1 else 'stereo')
frame.sample_rate = self.sample_rate frame.sample_rate = self.sample_rate
frame.pts = self.pts frame.pts = self.pts
self.pts += frame.samples self.pts += frame.samples
encoded_data = b"" # Encode the frame
for packet in self.stream.encode(frame): for packet in stream.encode(frame):
self.container.mux(packet) container.mux(packet)
# Get the encoded data from the buffer
encoded_data = self.output_buffer.getvalue()
# Clear the buffer for next write
self.output_buffer.seek(0)
self.output_buffer.truncate(0)
return encoded_data # Flush any remaining packets
for packet in stream.encode(None):
container.mux(packet)
# Close the container and get the data
container.close()
data = buffer.getvalue()
return data
except Exception as e:
logger.error(f"Error encoding audio chunk: {e}")
return b""