mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
Fix BUGs of streaming non-wav format audio; improve robustness of releasing audio container
Refactor StreamingAudioWriter to improve audio encoding reliability - Restructure audio encoding logic for better error handling - Create a new method `_create_container()` to manage container creation - Improve handling of different audio formats and encoding scenarios - Add error logging for audio chunk encoding failures - Simplify container and stream management in write_chunk method
This commit is contained in:
parent
fbdedfb131
commit
e67264f789
1 changed files with 67 additions and 38 deletions
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
import struct
|
import struct
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Optional
|
from typing import Optional, Dict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
@ -18,18 +18,35 @@ class StreamingAudioWriter:
|
||||||
self.sample_rate = sample_rate
|
self.sample_rate = sample_rate
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.bytes_written = 0
|
self.bytes_written = 0
|
||||||
self.pts=0
|
self.pts = 0
|
||||||
|
|
||||||
codec_map = {"wav":"pcm_s16le","mp3":"mp3","opus":"libopus","flac":"flac", "aac":"aac"}
|
|
||||||
# Format-specific setup
|
# Format-specific setup
|
||||||
if self.format in ["wav", "opus","flac","mp3","aac","pcm"]:
|
if self.format not in ["wav", "opus", "flac", "mp3", "aac", "pcm"]:
|
||||||
if self.format != "pcm":
|
|
||||||
self.output_buffer = BytesIO()
|
|
||||||
self.container = av.open(self.output_buffer, mode="w", format=self.format)
|
|
||||||
self.stream = self.container.add_stream(codec_map[self.format],sample_rate=self.sample_rate,layout='mono' if self.channels == 1 else 'stereo')
|
|
||||||
self.stream.bit_rate = 96000
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported format: {format}")
|
raise ValueError(f"Unsupported format: {format}")
|
||||||
|
|
||||||
|
# Codec mapping
|
||||||
|
self.codec_map = {
|
||||||
|
"wav": "pcm_s16le",
|
||||||
|
"mp3": "mp3",
|
||||||
|
"opus": "libopus",
|
||||||
|
"flac": "flac",
|
||||||
|
"aac": "aac"
|
||||||
|
}
|
||||||
|
|
||||||
|
def _create_container(self):
|
||||||
|
"""Create a new container for each write operation"""
|
||||||
|
if self.format == "pcm":
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
buffer = BytesIO()
|
||||||
|
container = av.open(buffer, mode="w", format=self.format)
|
||||||
|
stream = container.add_stream(
|
||||||
|
self.codec_map[self.format],
|
||||||
|
sample_rate=self.sample_rate,
|
||||||
|
layout='mono' if self.channels == 1 else 'stereo'
|
||||||
|
)
|
||||||
|
stream.bit_rate = 96000
|
||||||
|
return container, buffer
|
||||||
|
|
||||||
def write_chunk(
|
def write_chunk(
|
||||||
self, audio_data: Optional[np.ndarray] = None, finalize: bool = False
|
self, audio_data: Optional[np.ndarray] = None, finalize: bool = False
|
||||||
|
@ -40,38 +57,50 @@ class StreamingAudioWriter:
|
||||||
audio_data: Audio data to write, or None if finalizing
|
audio_data: Audio data to write, or None if finalizing
|
||||||
finalize: Whether this is the final write to close the stream
|
finalize: Whether this is the final write to close the stream
|
||||||
"""
|
"""
|
||||||
|
# Handle PCM format separately as it doesn't use PyAV
|
||||||
if finalize:
|
|
||||||
if self.format != "pcm":
|
|
||||||
# Flush encoder buffers
|
|
||||||
for packet in self.stream.encode(None):
|
|
||||||
self.container.mux(packet)
|
|
||||||
self.container.close()
|
|
||||||
data = self.output_buffer.getvalue()
|
|
||||||
self.output_buffer.seek(0)
|
|
||||||
self.output_buffer.truncate(0)
|
|
||||||
return data
|
|
||||||
return b""
|
|
||||||
|
|
||||||
if audio_data is None or len(audio_data) == 0:
|
|
||||||
return b""
|
|
||||||
|
|
||||||
if self.format == "pcm":
|
if self.format == "pcm":
|
||||||
|
if finalize or audio_data is None or len(audio_data) == 0:
|
||||||
|
return b""
|
||||||
return audio_data.tobytes()
|
return audio_data.tobytes()
|
||||||
else:
|
|
||||||
frame = av.AudioFrame.from_ndarray(audio_data.reshape(1, -1), format='s16', layout='mono' if self.channels == 1 else 'stereo')
|
# Handle empty input
|
||||||
|
if not finalize and (audio_data is None or len(audio_data) == 0):
|
||||||
|
return b""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create a new container for this operation
|
||||||
|
container, buffer = self._create_container()
|
||||||
|
stream = container.streams[0]
|
||||||
|
|
||||||
|
if finalize:
|
||||||
|
# Just return empty bytes for finalize in the new design
|
||||||
|
return b""
|
||||||
|
|
||||||
|
# Create audio frame
|
||||||
|
frame = av.AudioFrame.from_ndarray(
|
||||||
|
audio_data.reshape(1, -1),
|
||||||
|
format='s16',
|
||||||
|
layout='mono' if self.channels == 1 else 'stereo'
|
||||||
|
)
|
||||||
frame.sample_rate = self.sample_rate
|
frame.sample_rate = self.sample_rate
|
||||||
frame.pts = self.pts
|
frame.pts = self.pts
|
||||||
self.pts += frame.samples
|
self.pts += frame.samples
|
||||||
|
|
||||||
encoded_data = b""
|
# Encode the frame
|
||||||
for packet in self.stream.encode(frame):
|
for packet in stream.encode(frame):
|
||||||
self.container.mux(packet)
|
container.mux(packet)
|
||||||
# Get the encoded data from the buffer
|
|
||||||
encoded_data = self.output_buffer.getvalue()
|
# Flush any remaining packets
|
||||||
# Clear the buffer for next write
|
for packet in stream.encode(None):
|
||||||
self.output_buffer.seek(0)
|
container.mux(packet)
|
||||||
self.output_buffer.truncate(0)
|
|
||||||
|
# Close the container and get the data
|
||||||
|
container.close()
|
||||||
|
data = buffer.getvalue()
|
||||||
|
|
||||||
return encoded_data
|
return data
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error encoding audio chunk: {e}")
|
||||||
|
return b""
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue