Mirror of https://github.com/remsky/Kokoro-FastAPI.git (synced 2025-04-13 09:39:17 +00:00)

Commit ee5be65596: 8 changed files with 486 additions and 67 deletions
.gitignore (vendored): 2 changes

@@ -14,3 +14,5 @@ env/
 .Python
+.coverage
CHANGELOG.md:

@@ -2,8 +2,12 @@
 
 Notable changes to this project will be documented in this file.
 
-## 2024-01-09
+## 2025-01-02
+- Audio Format Support:
+  - Added comprehensive audio format conversion support (mp3, wav, opus, flac)
+
+## 2025-01-01
 ### Added
 - Gradio Web Interface:
   - Added simple web UI utility for audio generation from input or txt file
README.md: 74 changes

@@ -3,8 +3,8 @@
 </p>
 
 # Kokoro TTS API
-[]()
-[]()
+[]()
+[]()
 [](https://huggingface.co/hexgrad/Kokoro-82M/tree/c3b0d86e2a980e027ef71c28819ea02e351c2667)
 
 Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model
@@ -14,8 +14,7 @@ Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokor
 - automatic chunking/stitching for long texts
 - simple audio generation web ui utility
 
-<details open>
-<summary><b>Quick Start</b></summary>
+## Quick Start
 
 The service can be accessed through either the API endpoints or the Gradio web interface.
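For orientation, a raw HTTP version of a speech request might look like the sketch below. The base URL, model name, and default voice are taken from examples/test_audio_formats.py later in this commit; the payload shape is assumed to follow the OpenAI speech API.

```python
import requests

# Sketch: one speech request saved to disk. Assumes the service is listening
# on localhost:8880, as in the bundled examples.
response = requests.post(
    "http://localhost:8880/v1/audio/speech",
    json={
        "model": "kokoro",
        "input": "Hello world!",
        "voice": "af",             # default voice used in the examples
        "response_format": "mp3",  # or wav, opus, flac, pcm
    },
)
response.raise_for_status()
with open("output.mp3", "wb") as f:
    f.write(response.content)
```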
@@ -48,9 +47,10 @@ The service can be accessed through either the API endpoints or the Gradio web i
 <p align="center">
 <img src="ui\GradioScreenShot.png" width="80%" alt="Voice Analysis Comparison" style="border: 2px solid #333; padding: 10px;">
 </p>
-</details>
 
 ## Features
 <details>
-<summary><b>OpenAI-Compatible Speech Endpoint</b></summary>
+<summary>OpenAI-Compatible Speech Endpoint</summary>
 
 ```python
 # Using OpenAI's Python library
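The hunk truncates the snippet; a complete call in the same style, consistent with the client setup in examples/test_audio_formats.py added later in this commit, might read:

```python
# Using OpenAI's Python library (sketch; mirrors examples/test_audio_formats.py)
import openai

client = openai.OpenAI(
    api_key="notneeded",                  # the local server does not check keys
    base_url="http://localhost:8880/v1",  # local server with the v1 prefix
)

response = client.audio.speech.create(
    model="kokoro",
    voice="af",
    input="Hello world!",
    response_format="mp3",
)
with open("speech.mp3", "wb") as f:
    f.write(response.content)
```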
@@ -98,7 +98,10 @@ python examples/test_all_voices.py # Test all available voices
 </details>
 
 <details>
-<summary><b>Voice Combination</b></summary>
+<summary>Voice Combination</summary>
+
+- Averages model weights of any existing voicepacks
+- Saves generated voicepacks for future use
 
 Combine voices and generate audio:
 ```python
@@ -129,7 +132,23 @@ response = requests.post(
 </details>
 
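The requests.post( fragment in the hunk context above is all that survives of the combination example. A sketch of such a request, where the URL path and payload shape are assumptions rather than something this diff shows:

```python
import requests

# Hypothetical combination request: the endpoint path and body format are
# assumptions; only requests.post(...) is visible in the hunk above.
response = requests.post(
    "http://localhost:8880/v1/audio/voices/combine",
    json=["voice1", "voice2"],  # voicepacks to average, as in the tests
)
print(response.json())  # e.g. the name of the saved combined voicepack
```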
 <details>
-<summary><b>Gradio Web Utility</b></summary>
+<summary>Multiple Output Audio Formats</summary>
+
+- mp3
+- wav
+- opus
+- flac
+- aac
+- pcm
+
+<p align="center">
+<img src="examples/benchmarks/format_comparison.png" width="80%" alt="Audio Format Comparison" style="border: 2px solid #333; padding: 10px;">
+</p>
+
+</details>
+
+<details>
+<summary>Gradio Web Utility</summary>
 
 Access the interactive web UI at http://localhost:7860 after starting the service. Features include:
 - Voice/format/speed selection
@@ -141,9 +160,9 @@ If you only want the API, just comment out everything in the docker-compose.yml
 Currently, voices created via the API are accessible here, but voice combination/creation has not yet been added
 </details>
 
 ## Processing Details
 <details>
-<summary><b>Performance Benchmarks</b></summary>
+<summary>Performance Benchmarks</summary>
 
 Benchmarking was performed on generation via the local API using text lengths up to feature-length books (~1.5 hours output), measuring processing time and realtime factor. Tests were run on:
 - Windows 11 Home w/ WSL2
@@ -163,7 +182,7 @@ Key Performance Metrics:
 - Average Processing Rate: 137.67 tokens/second (cl100k_base)
 </details>
 <details>
-<summary><b>GPU Vs. CPU<b></summary>
+<summary>GPU Vs. CPU</summary>
 
 ```bash
 # GPU: Requires NVIDIA GPU with CUDA 12.1 support
@@ -172,35 +191,29 @@ docker compose up --build
 # CPU: ~10x slower than GPU inference
 docker compose -f docker-compose.cpu.yml up --build
 ```
 </details>
-<details>
-<summary><b>Features</b></summary>
-
-- OpenAI-compatible API endpoints (with optional Gradio Web UI)
-- GPU-accelerated inference (if desired)
-- Multiple audio formats: mp3, wav, opus, flac, (aac & pcm not implemented)
-- Natural Boundary Detection:
-  - Automatically splits and stitches at sentence boundaries to reduce artifacts and maintain performacne
-- Voice Combination:
-  - Averages model weights of any existing voicepacks
-  - Saves generated voicepacks for future use
-
-*Note: CPU Inference is currently a very basic implementation, and not heavily tested*
-
-</details>
-
+<details>
+<summary>Natural Boundary Detection</summary>
+
+- Automatically splits and stitches at sentence boundaries
+- Helps to reduce artifacts and allow long form processing as the base model is only currently configured for approximately 30s output
+</details>
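To make the chunking idea concrete, here is an illustrative sketch (not the service's actual implementation) of sentence-boundary splitting, with a character budget standing in for the model's ~30s output limit:

```python
import re

def chunk_by_sentence(text: str, max_chars: int = 300) -> list[str]:
    """Split text at sentence boundaries, packing sentences into chunks.

    Illustrative only; the real service's splitting logic may differ.
    max_chars is a stand-in for the model's ~30s output budget.
    """
    # Naive sentence split on ., !, ? followed by whitespace.
    sentences = re.split(r"(?<=[.!?])\s+", text.strip())
    chunks, current = [], ""
    for sentence in sentences:
        if current and len(current) + len(sentence) + 1 > max_chars:
            chunks.append(current)
            current = sentence
        else:
            current = f"{current} {sentence}".strip()
    if current:
        chunks.append(current)
    return chunks

# Each chunk is synthesized separately and the audio is concatenated (stitched).
```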
 
 ## Model and License
 
 <details open>
-<summary><b>Model</b></summary>
+<summary>Model</summary>
 
 This API uses the [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) model from HuggingFace.
 
 Visit the model page for more details about training, architecture, and capabilities. I have no affiliation with any of their work, and produced this wrapper for ease of use and personal projects.
 </details>
 
 <details>
-<summary><b>License</b></summary>
+<summary>License</summary>
 
 This project is licensed under the Apache License 2.0 - see below for details:
 
 - The Kokoro model weights are licensed under Apache 2.0 (see [model page](https://huggingface.co/hexgrad/Kokoro-82M))
@@ -209,3 +222,6 @@ This project is licensed under the Apache License 2.0 - see below for details:
 
 The full Apache 2.0 license text can be found at: https://www.apache.org/licenses/LICENSE-2.0
 </details>
+
+
+
api/src/services/audio.py:

@@ -4,7 +4,6 @@ from io import BytesIO
 
 import numpy as np
 import soundfile as sf
-import scipy.io.wavfile as wavfile
 from loguru import logger
 
@@ -20,7 +19,7 @@
         Args:
             audio_data: Numpy array of audio samples
             sample_rate: Sample rate of the audio
-            output_format: Target format (wav, mp3, etc.)
+            output_format: Target format (wav, mp3, opus, flac, pcm)
 
         Returns:
             Bytes of the converted audio
@@ -30,46 +29,36 @@
         try:
             if output_format == "wav":
                 logger.info("Writing to WAV format...")
-                wavfile.write(buffer, sample_rate, audio_data)
-                return buffer.getvalue()
-
+                # Ensure audio_data is in int16 format for WAV
+                audio_data_wav = (
+                    audio_data / np.abs(audio_data).max() * np.iinfo(np.int16).max
+                ).astype(np.int16)  # Normalize
+                sf.write(buffer, audio_data_wav, sample_rate, format="WAV")
             elif output_format == "mp3":
-                # For MP3, we need to convert to WAV first
                 logger.info("Converting to MP3 format...")
-                wav_buffer = BytesIO()
-                wavfile.write(wav_buffer, sample_rate, audio_data)
-                wav_buffer.seek(0)
-
-                # Convert WAV to MP3 using soundfile
-                buffer = BytesIO()
-                sf.write(buffer, audio_data, sample_rate, format="mp3")
-                return buffer.getvalue()
-
+                # soundfile can write MP3 if ffmpeg or libsox is installed
+                sf.write(buffer, audio_data, sample_rate, format="MP3")
             elif output_format == "opus":
                 logger.info("Converting to Opus format...")
-                sf.write(buffer, audio_data, sample_rate, format="ogg", subtype="opus")
-                return buffer.getvalue()
-
+                sf.write(buffer, audio_data, sample_rate, format="OGG", subtype="OPUS")
             elif output_format == "flac":
                 logger.info("Converting to FLAC format...")
-                sf.write(buffer, audio_data, sample_rate, format="flac")
-                return buffer.getvalue()
-
-            elif output_format == "aac":
-                raise ValueError(
-                    "AAC format is not currently supported. Please use wav, mp3, opus, or flac."
-                )
+                sf.write(buffer, audio_data, sample_rate, format="FLAC")
             elif output_format == "pcm":
-                raise ValueError(
-                    "PCM format is not currently supported. Please use wav, mp3, opus, or flac."
-                )
+                logger.info("Extracting PCM data...")
+                # Ensure audio_data is in int16 format for PCM
+                audio_data_pcm = (
+                    audio_data / np.abs(audio_data).max() * np.iinfo(np.int16).max
+                ).astype(np.int16)  # Normalize
+                buffer.write(audio_data_pcm.tobytes())
             else:
                 raise ValueError(
-                    f"Format {output_format} not supported. Supported formats are: wav, mp3, opus, flac."
+                    f"Format {output_format} not supported. Supported formats are: wav, mp3, opus, flac, pcm."
                 )
 
             buffer.seek(0)
             return buffer.getvalue()
 
         except Exception as e:
             logger.error(f"Error converting audio to {output_format}: {str(e)}")
             raise ValueError(f"Failed to convert audio to {output_format}: {str(e)}")
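As a quick orientation for the new conversion path, a minimal caller might look like this sketch; the module path api.src.services.audio mirrors the import style used by the repo's tests and is an assumption here:

```python
# Sketch: convert one second of a 440 Hz sine tone to MP3 bytes.
# The import path is assumed from the tests' api.src.services.* layout.
import numpy as np

from api.src.services.audio import AudioService

sample_rate = 24000  # the service's output rate, per the examples
t = np.linspace(0, 1.0, sample_rate, endpoint=False)
audio = 0.5 * np.sin(2 * np.pi * 440.0 * t).astype(np.float32)

mp3_bytes = AudioService.convert_audio(audio, sample_rate, "mp3")
with open("tone.mp3", "wb") as f:
    f.write(mp3_bytes)
```

Note that the WAV and PCM branches rescale the signal so its peak maps to np.iinfo(np.int16).max (32767) before casting to int16.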
api/tests/test_audio_service.py:

@@ -51,15 +51,19 @@ def test_convert_to_flac(sample_audio):
 def test_convert_to_aac_raises_error(sample_audio):
     """Test that converting to AAC raises an error"""
     audio_data, sample_rate = sample_audio
-    with pytest.raises(ValueError, match="AAC format is not currently supported"):
+    with pytest.raises(
+        ValueError,
+        match="Format aac not supported. Supported formats are: wav, mp3, opus, flac, pcm.",
+    ):
         AudioService.convert_audio(audio_data, sample_rate, "aac")
 
 
-def test_convert_to_pcm_raises_error(sample_audio):
-    """Test that converting to PCM raises an error"""
+def test_convert_to_pcm(sample_audio):
+    """Test converting to PCM format"""
     audio_data, sample_rate = sample_audio
-    with pytest.raises(ValueError, match="PCM format is not currently supported"):
-        AudioService.convert_audio(audio_data, sample_rate, "pcm")
+    result = AudioService.convert_audio(audio_data, sample_rate, "pcm")
+    assert isinstance(result, bytes)
+    assert len(result) > 0
 
 
 def test_convert_to_invalid_format_raises_error(sample_audio):
@@ -67,3 +71,51 @@ def test_convert_to_invalid_format_raises_error(sample_audio):
     audio_data, sample_rate = sample_audio
     with pytest.raises(ValueError, match="Format invalid not supported"):
         AudioService.convert_audio(audio_data, sample_rate, "invalid")
+
+
+def test_normalization_wav(sample_audio):
+    """Test that WAV output is properly normalized to int16 range"""
+    audio_data, sample_rate = sample_audio
+    # Create audio data outside int16 range
+    large_audio = audio_data * 1e5
+    result = AudioService.convert_audio(large_audio, sample_rate, "wav")
+    assert isinstance(result, bytes)
+    assert len(result) > 0
+
+
+def test_normalization_pcm(sample_audio):
+    """Test that PCM output is properly normalized to int16 range"""
+    audio_data, sample_rate = sample_audio
+    # Create audio data outside int16 range
+    large_audio = audio_data * 1e5
+    result = AudioService.convert_audio(large_audio, sample_rate, "pcm")
+    assert isinstance(result, bytes)
+    assert len(result) > 0
+
+
+def test_invalid_audio_data():
+    """Test handling of invalid audio data"""
+    invalid_audio = np.array([])  # Empty array
+    sample_rate = 24000
+    with pytest.raises(ValueError):
+        AudioService.convert_audio(invalid_audio, sample_rate, "wav")
+
+
+def test_different_sample_rates(sample_audio):
+    """Test converting audio with different sample rates"""
+    audio_data, _ = sample_audio
+    sample_rates = [8000, 16000, 44100, 48000]
+
+    for rate in sample_rates:
+        result = AudioService.convert_audio(audio_data, rate, "wav")
+        assert isinstance(result, bytes)
+        assert len(result) > 0
+
+
+def test_buffer_position_after_conversion(sample_audio):
+    """Test that buffer position is reset after writing"""
+    audio_data, sample_rate = sample_audio
+    result = AudioService.convert_audio(audio_data, sample_rate, "wav")
+    # Convert again to ensure buffer was properly reset
+    result2 = AudioService.convert_audio(audio_data, sample_rate, "wav")
+    assert len(result) == len(result2)
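These tests depend on a sample_audio fixture that this diff never shows; a plausible definition (hypothetical, not the repo's actual conftest) is:

```python
import numpy as np
import pytest

@pytest.fixture
def sample_audio():
    """Hypothetical fixture: one second of a 440 Hz sine wave at 24 kHz."""
    sample_rate = 24000
    t = np.linspace(0, 1.0, sample_rate, endpoint=False)
    audio_data = np.sin(2 * np.pi * 440.0 * t).astype(np.float32)
    return audio_data, sample_rate
```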
api/tests/test_tts_service.py:

@@ -4,6 +4,7 @@ import os
 from unittest.mock import MagicMock, call, patch
 
 import numpy as np
 import torch
+import pytest
 
 from api.src.services.tts import TTSModel, TTSService
@@ -119,6 +120,78 @@ def test_generate_audio_no_chunks(
     tts_service._generate_audio("Test text", "af", 1.0)
 
 
+@patch("torch.load")
+@patch("torch.save")
+@patch("torch.stack")
+@patch("torch.mean")
+@patch("os.path.exists")
+def test_combine_voices(
+    mock_exists, mock_mean, mock_stack, mock_save, mock_load, tts_service
+):
+    """Test combining multiple voices"""
+    # Setup mocks
+    mock_exists.return_value = True
+    mock_load.return_value = torch.tensor([1.0, 2.0])
+    mock_stack.return_value = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
+    mock_mean.return_value = torch.tensor([2.0, 3.0])
+
+    # Test combining two voices
+    result = tts_service.combine_voices(["voice1", "voice2"])
+
+    assert result == "voice1_voice2"
+    mock_stack.assert_called_once()
+    mock_mean.assert_called_once()
+    mock_save.assert_called_once()
+
+
+def test_combine_voices_invalid_input(tts_service):
+    """Test combining voices with invalid input"""
+    # Test with empty list
+    with pytest.raises(ValueError, match="At least 2 voices are required"):
+        tts_service.combine_voices([])
+
+    # Test with single voice
+    with pytest.raises(ValueError, match="At least 2 voices are required"):
+        tts_service.combine_voices(["voice1"])
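The mocks pin down the combination contract: each voicepack is loaded with torch.load, the tensors are stacked and averaged, and the result is saved under an underscore-joined name. A sketch of a method satisfying the test (inferred from the mocks, not the service's actual code):

```python
import os

import torch

def combine_voices(voice_names: list[str], voices_dir: str = "voices") -> str:
    """Average the weights of existing voicepacks (sketch inferred from the tests)."""
    if len(voice_names) < 2:
        raise ValueError("At least 2 voices are required")
    # Load each voicepack tensor from disk.
    packs = [
        torch.load(os.path.join(voices_dir, f"{name}.pt")) for name in voice_names
    ]
    # Element-wise mean of the stacked voicepacks.
    combined = torch.mean(torch.stack(packs), dim=0)
    combined_name = "_".join(voice_names)  # e.g. "voice1_voice2"
    torch.save(combined, os.path.join(voices_dir, f"{combined_name}.pt"))
    return combined_name
```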
+
+
+@patch("os.makedirs")
+@patch("os.path.exists")
+@patch("os.listdir")
+@patch("torch.load")
+@patch("torch.save")
+@patch("os.path.join")
+def test_ensure_voices(
+    mock_join,
+    mock_save,
+    mock_load,
+    mock_listdir,
+    mock_exists,
+    mock_makedirs,
+    tts_service,
+):
+    """Test voice directory initialization"""
+    # Setup mocks
+    mock_exists.side_effect = [
+        True,
+        False,
+        False,
+    ]  # base_dir exists, voice files don't exist
+    mock_listdir.return_value = ["voice1.pt", "voice2.pt"]
+    mock_load.return_value = MagicMock()
+    mock_join.return_value = "/fake/path"
+
+    # Test voice directory initialization
+    tts_service._ensure_voices()
+
+    # Verify directory was created
+    mock_makedirs.assert_called_once()
+
+    # Verify voices were loaded and saved
+    assert mock_load.call_count == len(mock_listdir.return_value)
+    assert mock_save.call_count == len(mock_listdir.return_value)
+
+
 @patch("api.src.services.tts.TTSModel.get_instance")
 @patch("os.path.exists")
 @patch("api.src.services.tts.normalize_text")
@@ -236,7 +309,6 @@ def test_generate_audio_without_stitching(
         "Test text", "af", 1.0, stitch_long_output=False
     )
     assert isinstance(audio, np.ndarray)
     assert isinstance(processing_time, float)
     assert len(audio) > 0
     mock_generate.assert_called_once()
examples/benchmarks/format_comparison.png: new binary file, 764 KiB (not shown)
examples/test_audio_formats.py: new file, 284 lines

@@ -0,0 +1,284 @@
"""Test script to generate and analyze different audio formats"""

import os
import time
from pathlib import Path

import numpy as np
import openai
import requests
import soundfile as sf
import matplotlib.pyplot as plt
from scipy.io import wavfile

SAMPLE_TEXT = """
That is the germ of my great discovery. But you are wrong to say that we cannot move about in Time.
"""

# Configure OpenAI client
client = openai.OpenAI(
    timeout=60,
    api_key="notneeded",  # API key not required for our endpoint
    base_url="http://localhost:8880/v1",  # Point to our local server with v1 prefix
)


def setup_plot(fig, ax, title):
    """Configure plot styling"""
    # Improve grid
    ax.grid(True, linestyle="--", alpha=0.3, color="#ffffff")

    # Set title and labels with better fonts and more padding
    ax.set_title(title, pad=40, fontsize=16, fontweight="bold", color="#ffffff")
    ax.set_xlabel(ax.get_xlabel(), fontsize=14, fontweight="medium", color="#ffffff")
    ax.set_ylabel(ax.get_ylabel(), fontsize=14, fontweight="medium", color="#ffffff")

    # Improve tick labels
    ax.tick_params(labelsize=12, colors="#ffffff")

    # Style spines
    for spine in ax.spines.values():
        spine.set_color("#ffffff")
        spine.set_alpha(0.3)
        spine.set_linewidth(0.5)

    # Set background colors
    ax.set_facecolor("#1a1a2e")
    fig.patch.set_facecolor("#1a1a2e")

    return fig, ax


def plot_format_comparison(stats: list, output_dir: str):
    """Plot audio format comparison"""
    plt.style.use("dark_background")

    # Create figure with subplots
    fig = plt.figure(figsize=(18, 16))  # Taller figure to accommodate bottom legend
    fig.patch.set_facecolor("#1a1a2e")

    # Create subplot grid with balanced spacing for waveforms
    gs_waves = plt.GridSpec(
        len(stats), 1, left=0.15, right=0.85, top=0.9, bottom=0.35, hspace=0.4
    )

    # Plot waveforms for each format
    for i, stat in enumerate(stats):
        format_name = stat["format"].upper()
        try:
            # Handle PCM format differently
            if stat["format"] == "pcm":
                # Read raw PCM data (16-bit mono)
                with open(
                    os.path.join(output_dir, f"test_audio.{stat['format']}"), "rb"
                ) as f:
                    raw_data = f.read()
                data = np.frombuffer(raw_data, dtype=np.int16)
                data = data.astype(np.float32) / 32768.0  # Convert to float [-1, 1]
                sr = 24000
            else:
                # Read other formats with soundfile
                data, sr = sf.read(
                    os.path.join(output_dir, f"test_audio.{stat['format']}")
                )

            # Plot waveform
            ax = plt.subplot(gs_waves[i])
            time = np.arange(len(data)) / sr
            plt.plot(time, data / np.max(np.abs(data)), linewidth=0.5, color="#ff2a6d")
            ax.set_xlabel("Time (seconds)")
            ax.set_ylabel("")
            ax.set_ylim(-1.1, 1.1)
            setup_plot(fig, ax, f"Waveform: {format_name}")
        except Exception as e:
            print(f"Error plotting waveform for {format_name}: {e}")

    # Colors for formats
    colors = ["#ff2a6d", "#05d9e8", "#d1f7ff", "#ff9e00", "#8c1eff"]

    # Create three subplots for metrics with more space at bottom for legend
    gs_bottom = plt.GridSpec(
        1,
        3,
        left=0.15,
        right=0.85,
        bottom=0.15,
        top=0.25,  # More bottom space for legend
        wspace=0.3,
    )

    # File Size subplot
    ax1 = plt.subplot(gs_bottom[0])
    metrics1 = [("File Size", [s["file_size_kb"] for s in stats], "KB")]

    # Duration and Gen Time subplot
    ax2 = plt.subplot(gs_bottom[1])
    metrics2 = [
        ("Duration", [s["duration_seconds"] for s in stats], "s"),
        ("Gen Time", [s["generation_time"] for s in stats], "s"),
    ]

    # Sample Rate subplot
    ax3 = plt.subplot(gs_bottom[2])
    metrics3 = [("Sample Rate", [s["sample_rate"] / 1000 for s in stats], "kHz")]

    def plot_grouped_bars(ax, metrics, show_legend=True):
        n_groups = len(metrics)
        n_formats = len(stats)
        # Use wider bars for time metrics
        bar_width = 0.175 if metrics == metrics2 else 0.1

        indices = np.arange(n_groups)

        # Get max value for y-axis scaling
        max_val = max(max(m[1]) for m in metrics)

        for i, (stat, color) in enumerate(zip(stats, colors)):
            values = [m[1][i] for m in metrics]
            # Reduce spacing between bars for time metrics
            spacing = 1.1 if metrics == metrics2 else 1.0
            offset = (i - n_formats / 2 + 0.5) * bar_width * spacing
            bars = ax.bar(
                indices + offset,
                values,
                bar_width,
                label=stat["format"].upper(),
                color=color,
                alpha=0.8,
            )

            # Add value labels on top of bars
            for bar in bars:
                height = bar.get_height()
                ax.text(
                    bar.get_x() + bar.get_width() / 2.0,
                    height,
                    f"{height:.1f}",
                    ha="center",
                    va="bottom",
                    color="white",
                    fontsize=10,
                )

        ax.set_xticks(indices)
        ax.set_xticklabels([f"{m[0]}\n({m[2]})" for m in metrics])

        # Set y-axis limits with some padding
        ax.set_ylim(0, max_val * 1.2)

        if show_legend:
            # Place legend at the bottom
            ax.legend(
                bbox_to_anchor=(1.8, -0.8),
                loc="center",
                facecolor="#1a1a2e",
                edgecolor="#ffffff",
                ncol=len(stats),
            )  # Show all formats in one row

    # Plot all three subplots with shared legend
    plot_grouped_bars(ax1, metrics1, show_legend=True)
    plot_grouped_bars(ax2, metrics2, show_legend=False)
    plot_grouped_bars(ax3, metrics3, show_legend=False)

    # Style all subplots
    setup_plot(fig, ax1, "File Size")
    setup_plot(fig, ax2, "Time Metrics")
    setup_plot(fig, ax3, "Sample Rate")

    # Add y-axis labels
    ax1.set_ylabel("Value")
    ax2.set_ylabel("Value")
    ax3.set_ylabel("Value")

    # Save the plot
    plt.savefig(os.path.join(output_dir, "format_comparison.png"), dpi=300)
    print(f"\nSaved format comparison plot to {output_dir}/format_comparison.png")


def get_audio_stats(file_path: str) -> dict:
    """Get audio file statistics"""
    file_size = os.path.getsize(file_path)
    file_size_kb = file_size / 1024  # Convert to KB

    try:
        # Try reading with soundfile first
        data, sample_rate = sf.read(file_path)
        duration = len(data) / sample_rate
        channels = 1 if len(data.shape) == 1 else data.shape[1]

        # Calculate audio statistics
        stats = {
            "format": Path(file_path).suffix[1:],
            "file_size_kb": round(file_size_kb, 2),
            "duration_seconds": round(duration, 2),
            "sample_rate": sample_rate,
            "channels": channels,
            "min_amplitude": float(np.min(data)),
            "max_amplitude": float(np.max(data)),
            "mean_amplitude": float(np.mean(np.abs(data))),
            "rms_amplitude": float(np.sqrt(np.mean(np.square(data)))),
        }
        return stats
    except Exception:
        # For PCM, read raw bytes and estimate duration
        with open(file_path, "rb") as f:
            data = f.read()
        # Assuming 16-bit PCM mono at 24kHz
        samples = len(data) // 2  # 2 bytes per sample
        duration = samples / 24000
        return {
            "format": "pcm",
            "file_size_kb": round(file_size_kb, 2),
            "duration_seconds": round(duration, 2),
            "sample_rate": 24000,
            "channels": 1,
            "note": "PCM stats are estimated from raw bytes",
        }


def main():
    """Generate and analyze audio in different formats"""
    # Create output directory
    output_dir = Path(__file__).parent / "output" / "test_formats"
    output_dir.mkdir(exist_ok=True, parents=True)

    # First generate audio in each format using the API
    voice = "af"  # Using default voice
    formats = ["wav", "mp3", "opus", "flac", "pcm"]
    stats = []

    for fmt in formats:
        output_path = output_dir / f"test_audio.{fmt}"
        print(f"\nGenerating {fmt.upper()} audio...")

        # Generate and save
        start_time = time.time()
        response = client.audio.speech.create(
            model="kokoro", voice=voice, input=SAMPLE_TEXT, response_format=fmt
        )
        generation_time = time.time() - start_time

        with open(output_path, "wb") as f:
            f.write(response.content)

        # Get stats
        file_stats = get_audio_stats(str(output_path))
        file_stats["generation_time"] = round(generation_time, 3)
        stats.append(file_stats)

    # Generate comparison plot
    plot_format_comparison(stats, str(output_dir))

    # Print detailed statistics
    print("\nDetailed Audio Statistics:")
    print("=" * 100)
    for stat in stats:
        print(f"\n{stat['format'].upper()} Format:")
        for key, value in sorted(stat.items()):
            if key not in ["format"]:  # Skip format as it's in the header
                print(f"  {key}: {value}")


if __name__ == "__main__":
    main()
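With the service running, `python examples/test_audio_formats.py` generates the five test files under examples/output/test_formats/ and saves the format_comparison.png plot next to them.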