Kokoro-FastAPI/api/src/services/streaming_audio_writer.py
2025-04-04 16:50:46 -06:00

84 lines
2.8 KiB
Python

"""Audio conversion service with proper streaming support"""
import struct
from io import BytesIO
from typing import Optional
import av
import numpy as np
import soundfile as sf
from loguru import logger
from pydub import AudioSegment
class StreamingAudioWriter:
"""Handles streaming audio format conversions"""
def __init__(self, format: str, sample_rate: int, channels: int = 1):
self.format = format.lower()
self.sample_rate = sample_rate
self.channels = channels
self.bytes_written = 0
self.pts=0
codec_map = {"wav":"pcm_s16le","mp3":"mp3","opus":"libopus","flac":"flac", "aac":"aac"}
# Format-specific setup
if self.format in ["wav","flac","mp3","pcm","aac","opus"]:
if self.format != "pcm":
self.output_buffer = BytesIO()
self.container = av.open(self.output_buffer, mode="w", format=self.format if self.format != "aac" else "adts")
self.stream = self.container.add_stream(codec_map[self.format],sample_rate=self.sample_rate,layout='mono' if self.channels == 1 else 'stereo')
self.stream.bit_rate = 128000
else:
raise ValueError(f"Unsupported format: {format}")
def close(self):
if hasattr(self, "container"):
self.container.close()
if hasattr(self, "output_buffer"):
self.output_buffer.close()
def write_chunk(
self, audio_data: Optional[np.ndarray] = None, finalize: bool = False
) -> bytes:
"""Write a chunk of audio data and return bytes in the target format.
Args:
audio_data: Audio data to write, or None if finalizing
finalize: Whether this is the final write to close the stream
"""
if finalize:
if self.format != "pcm":
packets = self.stream.encode(None)
for packet in packets:
self.container.mux(packet)
data=self.output_buffer.getvalue()
self.close()
return data
if audio_data is None or len(audio_data) == 0:
return b""
if self.format == "pcm":
# Write raw bytes
return audio_data.tobytes()
else:
frame = av.AudioFrame.from_ndarray(audio_data.reshape(1, -1), format='s16', layout='mono' if self.channels == 1 else 'stereo')
frame.sample_rate=self.sample_rate
frame.pts = self.pts
self.pts += frame.samples
packets = self.stream.encode(frame)
for packet in packets:
self.container.mux(packet)
data = self.output_buffer.getvalue()
self.output_buffer.seek(0)
self.output_buffer.truncate(0)
return data