Kokoro-FastAPI/api/src/services/streaming_audio_writer.py

234 lines
8.4 KiB
Python
Raw Normal View History

"""Audio conversion service with proper streaming support"""
import struct
2025-02-09 18:32:17 -07:00
from io import BytesIO
from typing import Optional
import numpy as np
import soundfile as sf
from loguru import logger
from pydub import AudioSegment
2025-02-09 18:32:17 -07:00
class StreamingAudioWriter:
    """Incrementally convert audio chunks into a target container format.

    Supported ``format`` values (case-insensitive): ``wav``, ``ogg``,
    ``opus``, ``flac``, ``mp3``, ``aac`` and ``pcm``.

    Per-format behaviour of :meth:`write_chunk`:

    * ``wav``  - chunks are accumulated in ``self.buffer`` behind a 44-byte
      header and ``b""`` is returned; the complete file is only returned on
      finalize, once the size fields can be patched.
    * ``ogg`` / ``opus`` / ``flac`` - chunks are fed to a ``soundfile``
      writer backed by ``self.buffer``; each call returns the buffer's
      full contents so far.
    * ``mp3`` / ``aac`` - each chunk is encoded on its own through pydub
      and the encoded bytes are returned immediately.
    * ``pcm``  - the chunk's raw bytes are returned unchanged.
    """

    # Container/codec arguments handed to ``AudioSegment.export``.
    _PYDUB_EXPORT_ARGS = {
        "mp3": {"format": "mp3", "codec": "libmp3lame"},
        "aac": {"format": "adts", "codec": "aac"},
    }
    # Formats written through ``soundfile`` / through pydub respectively.
    _SOUNDFILE_FORMATS = ("ogg", "opus", "flac")
    _PYDUB_FORMATS = ("mp3", "aac")

    def __init__(self, format: str, sample_rate: int, channels: int = 1):
        """Prepare a writer for ``format``.

        Args:
            format: Target format name; lower-cased before use.
            sample_rate: Sample rate in Hz, written into headers/encoders.
            channels: Number of audio channels.

        Raises:
            ValueError: If ``format`` is not one of the supported names.
        """
        self.format = format.lower()
        self.sample_rate = sample_rate
        self.channels = channels
        self.bytes_written = 0  # WAV payload bytes written so far
        self.buffer = BytesIO()

        # Format-specific setup
        if self.format == "wav":
            self._write_wav_header_initial()
        elif self.format in ["ogg", "opus"]:
            # OGG container; the codec depends on the requested format.
            self.writer = sf.SoundFile(
                file=self.buffer,
                mode="w",
                samplerate=sample_rate,
                channels=channels,
                format="OGG",
                subtype="VORBIS" if self.format == "ogg" else "OPUS",
            )
        elif self.format == "flac":
            self.writer = sf.SoundFile(
                file=self.buffer,
                mode="w",
                samplerate=sample_rate,
                channels=channels,
                format="FLAC",
            )
        elif self.format in self._PYDUB_FORMATS:
            self.segments = []  # kept for compatibility; not used in this class
            self.total_duration = 0  # running total of len(segment) per chunk
            # Empty segment that incoming chunks are appended to before export.
            self.encoder = AudioSegment.silent(
                duration=0, frame_rate=self.sample_rate
            )
        elif self.format == "pcm":
            # Raw PCM needs no header or encoder state.
            pass
        else:
            raise ValueError(f"Unsupported format: {format}")

    def _write_wav_header_initial(self) -> None:
        """Write the 44-byte RIFF/WAVE header with zeroed size fields.

        The RIFF size (offset 4) and data size (offset 40) are placeholders
        that are patched when the stream is finalized.
        """
        bytes_per_frame = self.channels * 2  # header declares 16-bit samples
        self.buffer.write(
            struct.pack(
                "<4sL4s4sLHHLLHH4sL",
                b"RIFF",
                0,  # placeholder: RIFF chunk size
                b"WAVE",
                b"fmt ",
                16,  # fmt chunk size
                1,  # audio format 1 = PCM
                self.channels,
                self.sample_rate,
                self.sample_rate * bytes_per_frame,  # byte rate
                bytes_per_frame,  # block align
                16,  # bits per sample
                b"data",
                0,  # placeholder: data chunk size
            )
        )

    def _finalize_wav(self) -> bytes:
        """Patch the WAV size fields and return the whole buffered file."""
        self.buffer.seek(4)
        # RIFF chunk size = 36 header bytes after this field + payload.
        self.buffer.write(struct.pack("<L", self.bytes_written + 36))
        self.buffer.seek(40)
        self.buffer.write(struct.pack("<L", self.bytes_written))
        wav_bytes = self.buffer.getvalue()
        # Restore the write position to the end so later writes still append.
        self.buffer.seek(0, 2)
        return wav_bytes

    def _finalize(self) -> Optional[bytes]:
        """Run the format's finalization step.

        Returns:
            The final bytes, or ``None`` when this format has nothing to
            finalize and the caller should continue with normal chunk
            handling.
        """
        if self.format == "wav":
            return self._finalize_wav()

        if self.format in self._SOUNDFILE_FORMATS:
            self.writer.close()
            return self.buffer.getvalue()

        if self.format in self._PYDUB_FORMATS:
            pending = getattr(self, "encoder", None)
            # ``encoder`` is set to None after a final export, so guard
            # against None as well as an empty segment; otherwise a second
            # finalize (e.g. close() after finalize) fails on len(None).
            if pending is not None and len(pending) > 0:
                # NOTE(review): write_chunk resets the encoder after every
                # chunk, so this export only runs if audio was left in
                # ``self.encoder`` by other means. The option names below
                # are passed straight to ffmpeg - confirm they are accepted
                # by the installed build.
                parameters = []
                if self.format == "mp3":
                    parameters.extend(
                        [
                            "-q:a",
                            "2",
                            "-write_xing",
                            "1",  # XING header for MP3
                            "-id3v1",
                            "1",
                            "-id3v2",
                            "1",
                            "-write_vbr",
                            "1",
                            "-vbr_quality",
                            "2",
                        ]
                    )
                elif self.format == "aac":
                    parameters.extend(
                        [
                            "-q:a",
                            "2",
                            "-write_xing",
                            "0",
                            "-write_id3v1",
                            "0",
                            "-write_id3v2",
                            "0",
                        ]
                    )
                output_buffer = BytesIO()
                pending.export(
                    output_buffer,
                    **self._PYDUB_EXPORT_ARGS[self.format],
                    bitrate="192k",
                    parameters=parameters,
                )
                self.encoder = None
                return output_buffer.getvalue()

        return None

    def write_chunk(
        self, audio_data: Optional[np.ndarray] = None, finalize: bool = False
    ) -> bytes:
        """Write a chunk of audio data and return bytes in the target format.

        Args:
            audio_data: Samples to write, or None when only finalizing.
            finalize: Whether this is the final write that closes the stream.
                For wav/ogg/opus/flac the finalize step returns before
                ``audio_data`` is looked at, so data passed together with
                ``finalize=True`` is not written. For mp3/aac with nothing
                pending, and for pcm, processing continues as a normal
                chunk write.

        Returns:
            Encoded bytes for this call; ``b""`` when there is nothing to
            emit (empty input, or a WAV chunk that was only buffered).
        """
        if finalize:
            final_bytes = self._finalize()
            if final_bytes is not None:
                return final_bytes

        if audio_data is None or len(audio_data) == 0:
            return b""

        if self.format == "wav":
            # Serialize once; the same bytes feed the buffer and the counter.
            # NOTE(review): the header declares 16-bit PCM but the dtype is
            # not checked here - presumably callers pass int16; confirm.
            raw = audio_data.tobytes()
            self.buffer.write(raw)
            self.bytes_written += len(raw)
            return b""

        if self.format in self._SOUNDFILE_FORMATS:
            self.writer.write(audio_data)
            self.writer.flush()
            # NOTE(review): getvalue() returns everything written so far,
            # not just this chunk; the buffer is never truncated in this
            # class. Confirm callers expect cumulative output.
            return self.buffer.getvalue()

        if self.format in self._PYDUB_FORMATS:
            segment = AudioSegment(
                audio_data.tobytes(),
                frame_rate=self.sample_rate,
                sample_width=audio_data.dtype.itemsize,
                channels=self.channels,
            )
            self.total_duration += len(segment)
            self.encoder += segment

            # Per-chunk export: no XING header is requested for chunks.
            output_buffer = BytesIO()
            self.encoder.export(
                output_buffer,
                **self._PYDUB_EXPORT_ARGS[self.format],
                bitrate="192k",
                parameters=[
                    "-q:a",
                    "2",
                    "-write_xing",
                    "0",  # No XING headers for chunks
                ],
            )
            encoded_data = output_buffer.getvalue()

            # Start from an empty segment again to prevent memory growth.
            self.encoder = AudioSegment.silent(
                duration=0, frame_rate=self.sample_rate
            )
            return encoded_data

        if self.format == "pcm":
            return audio_data.tobytes()

        return b""

    def close(self) -> Optional[bytes]:
        """Finish the audio stream and return any remaining data.

        Returns:
            For wav, the complete file with patched header sizes; for
            ogg/opus/flac, the full buffer contents after closing the
            writer; for mp3/aac, the result of a finalizing
            :meth:`write_chunk` call; ``None`` for any other format.
        """
        if self.format == "wav":
            return self.write_chunk(finalize=True)

        if self.format in self._SOUNDFILE_FORMATS:
            self.writer.close()
            return self.buffer.getvalue()

        if self.format in self._PYDUB_FORMATS:
            return self.write_chunk(finalize=True)

        return None