mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
Add StreamingAudioWriter class for audio format conversions and remove deprecated migration notes
This commit is contained in:
parent
409a9e9af3
commit
8a60a2b90c
4 changed files with 141 additions and 180 deletions
|
@ -1,70 +0,0 @@
|
|||
# UV Setup
|
||||
Deprecated notes for myself
|
||||
## Structure
|
||||
```
|
||||
docker/
|
||||
├── cpu/
|
||||
│ ├── pyproject.toml # CPU deps (torch CPU)
|
||||
│ └── requirements.lock # CPU lockfile
|
||||
├── gpu/
|
||||
│ ├── pyproject.toml # GPU deps (torch CUDA)
|
||||
│ └── requirements.lock # GPU lockfile
|
||||
└── shared/
|
||||
└── pyproject.toml # Common deps
|
||||
```
|
||||
|
||||
## Regenerate Lock Files
|
||||
|
||||
### CPU
|
||||
```bash
|
||||
cd docker/cpu
|
||||
uv pip compile pyproject.toml ../shared/pyproject.toml --output-file requirements.lock
|
||||
```
|
||||
|
||||
### GPU
|
||||
```bash
|
||||
cd docker/gpu
|
||||
uv pip compile pyproject.toml ../shared/pyproject.toml --output-file requirements.lock
|
||||
```
|
||||
|
||||
## Local Dev Setup
|
||||
|
||||
### CPU
|
||||
```bash
|
||||
cd docker/cpu
|
||||
uv venv
|
||||
.venv\Scripts\activate # Windows
|
||||
uv pip sync requirements.lock
|
||||
```
|
||||
|
||||
### GPU
|
||||
```bash
|
||||
cd docker/gpu
|
||||
uv venv
|
||||
.venv\Scripts\activate # Windows
|
||||
uv pip sync requirements.lock --extra-index-url https://download.pytorch.org/whl/cu121 --index-strategy unsafe-best-match
|
||||
```
|
||||
|
||||
### Run Server
|
||||
```bash
|
||||
# From project root with venv active:
|
||||
uvicorn api.src.main:app --reload
|
||||
```
|
||||
|
||||
## Docker
|
||||
|
||||
### CPU
|
||||
```bash
|
||||
cd docker/cpu
|
||||
docker compose up
|
||||
```
|
||||
|
||||
### GPU
|
||||
```bash
|
||||
cd docker/gpu
|
||||
docker compose up
|
||||
```
|
||||
|
||||
## Known Issues
|
||||
- Module imports: Run server from project root
|
||||
- PyTorch CUDA: Always use --extra-index-url and --index-strategy for GPU env
|
|
@ -10,7 +10,7 @@ from loguru import logger
|
|||
from pydub import AudioSegment
|
||||
|
||||
from ..core.config import settings
|
||||
|
||||
from .streaming_audio_writer import StreamingAudioWriter
|
||||
|
||||
class AudioNormalizer:
|
||||
"""Handles audio normalization state for a single stream"""
|
||||
|
@ -45,7 +45,7 @@ class AudioNormalizer:
|
|||
|
||||
|
||||
class AudioService:
|
||||
"""Service for audio format conversions"""
|
||||
"""Service for audio format conversions with streaming support"""
|
||||
|
||||
# Default audio format settings balanced for speed and compression
|
||||
DEFAULT_SETTINGS = {
|
||||
|
@ -64,6 +64,8 @@ class AudioService:
|
|||
},
|
||||
}
|
||||
|
||||
_writers = {}
|
||||
|
||||
@staticmethod
|
||||
async def convert_audio(
|
||||
audio_data: np.ndarray,
|
||||
|
@ -72,127 +74,46 @@ class AudioService:
|
|||
is_first_chunk: bool = True,
|
||||
is_last_chunk: bool = False,
|
||||
normalizer: AudioNormalizer = None,
|
||||
format_settings: dict = None,
|
||||
stream: bool = True,
|
||||
) -> bytes:
|
||||
"""Convert audio data to specified format
|
||||
"""Convert audio data to specified format with streaming support
|
||||
|
||||
Args:
|
||||
audio_data: Numpy array of audio samples
|
||||
sample_rate: Sample rate of the audio
|
||||
output_format: Target format (wav, mp3, opus, flac, pcm)
|
||||
is_first_chunk: Whether this is the first chunk of a stream
|
||||
normalizer: Optional AudioNormalizer instance for consistent normalization across chunks
|
||||
format_settings: Optional dict of format-specific settings to override defaults
|
||||
Example: {
|
||||
"mp3": {
|
||||
"bitrate_mode": "VARIABLE",
|
||||
"compression_level": 0.8
|
||||
}
|
||||
}
|
||||
Default settings balance speed and compression:
|
||||
optimized for localhost @ 0.0
|
||||
- MP3: constant bitrate, no compression (0.0)
|
||||
- OPUS: no compression (0.0)
|
||||
- FLAC: no compression (0.0)
|
||||
output_format: Target format (wav, mp3, ogg, pcm)
|
||||
is_first_chunk: Whether this is the first chunk
|
||||
is_last_chunk: Whether this is the last chunk
|
||||
normalizer: Optional AudioNormalizer instance for consistent normalization
|
||||
|
||||
Returns:
|
||||
Bytes of the converted audio
|
||||
Bytes of the converted audio chunk
|
||||
"""
|
||||
buffer = BytesIO()
|
||||
|
||||
try:
|
||||
# Always normalize audio to ensure proper amplitude scaling
|
||||
if normalizer is None:
|
||||
normalizer = AudioNormalizer()
|
||||
normalized_audio = await normalizer.normalize(audio_data)
|
||||
|
||||
if output_format == "pcm":
|
||||
# Raw 16-bit PCM samples, no header
|
||||
buffer.write(normalized_audio.tobytes())
|
||||
elif output_format == "wav":
|
||||
# Write the WAV header ourselves so that we can specify a "fake" data size.
|
||||
# This is necessary for streaming responses to work properly: if we simply
|
||||
# concatenated individual WAV files then the initial chunk's header length
|
||||
# would be shorter than the full file length and subsequent chunks' RIFF
|
||||
# headers would appear in the middle of the audio data.
|
||||
if is_first_chunk:
|
||||
# Modified from Python stdlib's wave.py module:
|
||||
buffer.write(b'RIFF')
|
||||
buffer.write(struct.pack('<L4s4sLHHLLHH4s',
|
||||
0xFFFFFFFF, # total size (set to max)
|
||||
b'WAVE',
|
||||
b'fmt ',
|
||||
16,
|
||||
1, # PCM format
|
||||
1, # channels
|
||||
sample_rate,
|
||||
sample_rate * 2, # byte rate
|
||||
2, # block align
|
||||
16, # bits per sample
|
||||
b'data'
|
||||
))
|
||||
buffer.write(struct.pack('<L', 0xFFFFFFFF)) # data size (set to max)
|
||||
# write raw PCM data
|
||||
buffer.write(normalized_audio.tobytes())
|
||||
elif output_format == "mp3":
|
||||
# MP3 format with proper framing
|
||||
settings = format_settings.get("mp3", {}) if format_settings else {}
|
||||
settings = {**AudioService.DEFAULT_SETTINGS["mp3"], **settings}
|
||||
sf.write(
|
||||
buffer, normalized_audio, sample_rate, format="MP3", **settings
|
||||
)
|
||||
elif output_format == "opus":
|
||||
# Opus format in OGG container
|
||||
settings = format_settings.get("opus", {}) if format_settings else {}
|
||||
settings = {**AudioService.DEFAULT_SETTINGS["opus"], **settings}
|
||||
sf.write(
|
||||
buffer,
|
||||
normalized_audio,
|
||||
sample_rate,
|
||||
format="OGG",
|
||||
subtype="OPUS",
|
||||
**settings,
|
||||
)
|
||||
elif output_format == "flac":
|
||||
# FLAC format with proper framing
|
||||
if is_first_chunk:
|
||||
logger.info("Starting FLAC stream...")
|
||||
settings = format_settings.get("flac", {}) if format_settings else {}
|
||||
settings = {**AudioService.DEFAULT_SETTINGS["flac"], **settings}
|
||||
sf.write(
|
||||
buffer,
|
||||
normalized_audio,
|
||||
sample_rate,
|
||||
format="FLAC",
|
||||
subtype="PCM_16",
|
||||
**settings,
|
||||
)
|
||||
elif output_format == "aac":
|
||||
# Convert numpy array directly to AAC using pydub
|
||||
audio_segment = AudioSegment(
|
||||
normalized_audio.tobytes(),
|
||||
frame_rate=sample_rate,
|
||||
sample_width=normalized_audio.dtype.itemsize,
|
||||
channels=1 if len(normalized_audio.shape) == 1 else normalized_audio.shape[1]
|
||||
# Get or create format-specific writer
|
||||
writer_key = f"{output_format}_{sample_rate}"
|
||||
if is_first_chunk or writer_key not in AudioService._writers:
|
||||
AudioService._writers[writer_key] = StreamingAudioWriter(
|
||||
output_format, sample_rate
|
||||
)
|
||||
writer = AudioService._writers[writer_key]
|
||||
|
||||
settings = format_settings.get("aac", {}) if format_settings else {}
|
||||
settings = {**AudioService.DEFAULT_SETTINGS["aac"], **settings}
|
||||
# Write the current chunk
|
||||
chunk_data = writer.write_chunk(normalized_audio)
|
||||
|
||||
audio_segment.export(
|
||||
buffer,
|
||||
format="adts", # ADTS is a common AAC container format
|
||||
bitrate=settings["bitrate"]
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Format {output_format} not supported. Supported formats are: wav, mp3, opus, flac, pcm, aac."
|
||||
)
|
||||
# Handle last chunk and cleanup
|
||||
if is_last_chunk:
|
||||
final_data = writer.close()
|
||||
if final_data:
|
||||
chunk_data += final_data
|
||||
del AudioService._writers[writer_key]
|
||||
|
||||
buffer.seek(0)
|
||||
return buffer.getvalue()
|
||||
return chunk_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error converting audio to {output_format}: {str(e)}")
|
||||
raise ValueError(f"Failed to convert audio to {output_format}: {str(e)}")
|
||||
logger.error(f"Error converting audio stream to {output_format}: {str(e)}")
|
||||
raise ValueError(f"Failed to convert audio stream to {output_format}: {str(e)}")
|
||||
|
|
111
api/src/services/streaming_audio_writer.py
Normal file
111
api/src/services/streaming_audio_writer.py
Normal file
|
@ -0,0 +1,111 @@
|
|||
"""Audio conversion service with proper streaming support"""
|
||||
|
||||
from io import BytesIO
|
||||
import struct
|
||||
from typing import Generator, Optional
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
from loguru import logger
|
||||
from pydub import AudioSegment
|
||||
|
||||
class StreamingAudioWriter:
|
||||
"""Handles streaming audio format conversions"""
|
||||
|
||||
def __init__(self, format: str, sample_rate: int, channels: int = 1):
|
||||
self.format = format.lower()
|
||||
self.sample_rate = sample_rate
|
||||
self.channels = channels
|
||||
self.bytes_written = 0
|
||||
|
||||
# Format-specific setup
|
||||
if self.format == "wav":
|
||||
self._write_wav_header()
|
||||
elif self.format == "ogg":
|
||||
self.writer = sf.SoundFile(
|
||||
file=BytesIO(),
|
||||
mode='w',
|
||||
samplerate=sample_rate,
|
||||
channels=channels,
|
||||
format='OGG',
|
||||
subtype='VORBIS'
|
||||
)
|
||||
elif self.format == "mp3":
|
||||
# For MP3, we'll use pydub's incremental writer
|
||||
self.buffer = BytesIO()
|
||||
self.encoder = AudioSegment.from_mono_audiosegments()
|
||||
|
||||
def _write_wav_header(self) -> bytes:
|
||||
"""Write WAV header with correct streaming format"""
|
||||
header = BytesIO()
|
||||
header.write(b'RIFF')
|
||||
header.write(struct.pack('<L', 0)) # Placeholder for file size
|
||||
header.write(b'WAVE')
|
||||
header.write(b'fmt ')
|
||||
header.write(struct.pack('<L', 16)) # fmt chunk size
|
||||
header.write(struct.pack('<H', 1)) # PCM format
|
||||
header.write(struct.pack('<H', self.channels))
|
||||
header.write(struct.pack('<L', self.sample_rate))
|
||||
header.write(struct.pack('<L', self.sample_rate * self.channels * 2)) # Byte rate
|
||||
header.write(struct.pack('<H', self.channels * 2)) # Block align
|
||||
header.write(struct.pack('<H', 16)) # Bits per sample
|
||||
header.write(b'data')
|
||||
header.write(struct.pack('<L', 0)) # Placeholder for data size
|
||||
return header.getvalue()
|
||||
|
||||
def write_chunk(self, audio_data: np.ndarray) -> bytes:
|
||||
"""Write a chunk of audio data and return bytes in the target format"""
|
||||
buffer = BytesIO()
|
||||
|
||||
if self.format == "wav":
|
||||
# For WAV, we write raw PCM after the first chunk
|
||||
if self.bytes_written == 0:
|
||||
buffer.write(self._write_wav_header())
|
||||
buffer.write(audio_data.tobytes())
|
||||
self.bytes_written += len(audio_data.tobytes())
|
||||
|
||||
elif self.format == "ogg":
|
||||
# OGG/Vorbis handles streaming naturally
|
||||
self.writer.write(audio_data)
|
||||
self.writer.flush()
|
||||
buffer = self.writer.file
|
||||
buffer.seek(0, 2) # Seek to end
|
||||
chunk = buffer.getvalue()
|
||||
buffer.seek(0)
|
||||
buffer.truncate()
|
||||
return chunk
|
||||
|
||||
elif self.format == "mp3":
|
||||
# Convert chunk to AudioSegment and encode
|
||||
segment = AudioSegment(
|
||||
audio_data.tobytes(),
|
||||
frame_rate=self.sample_rate,
|
||||
sample_width=audio_data.dtype.itemsize,
|
||||
channels=self.channels
|
||||
)
|
||||
self.encoder += segment
|
||||
self.encoder.export(buffer, format="mp3")
|
||||
|
||||
return buffer.getvalue()
|
||||
|
||||
def close(self) -> Optional[bytes]:
|
||||
"""Finish the audio file and return any remaining data"""
|
||||
if self.format == "wav":
|
||||
# Update WAV header with final file size
|
||||
buffer = BytesIO()
|
||||
buffer.write(b'RIFF')
|
||||
buffer.write(struct.pack('<L', self.bytes_written + 36)) # File size
|
||||
buffer.write(b'WAVE')
|
||||
# ... rest of header ...
|
||||
buffer.write(struct.pack('<L', self.bytes_written)) # Data size
|
||||
return buffer.getvalue()
|
||||
|
||||
elif self.format == "ogg":
|
||||
self.writer.close()
|
||||
return None
|
||||
|
||||
elif self.format == "mp3":
|
||||
# Flush any remaining MP3 frames
|
||||
buffer = BytesIO()
|
||||
self.encoder.export(buffer, format="mp3")
|
||||
return buffer.getvalue()
|
|
@ -73,8 +73,7 @@ class TTSService:
|
|||
output_format,
|
||||
is_first_chunk=is_first,
|
||||
normalizer=normalizer,
|
||||
is_last_chunk=is_last,
|
||||
stream=True
|
||||
is_last_chunk=is_last
|
||||
)
|
||||
|
||||
return chunk_audio
|
||||
|
|
Loading…
Add table
Reference in a new issue