diff --git a/.gitignore b/.gitignore
index 98b9187..aebbfa7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,3 +14,5 @@ env/
.Python
+.coverage
+
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 44c98bf..4194878 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,8 +2,12 @@
Notable changes to this project will be documented in this file.
-## 2024-01-09
+## 2025-01-02
+### Added
+- Audio Format Support:
+  - Added comprehensive audio format conversion support (mp3, wav, opus, flac, pcm)
+
+## 2025-01-01
### Added
- Gradio Web Interface:
- Added simple web UI utility for audio generation from input or txt file
diff --git a/README.md b/README.md
index 639c13d..ff794f6 100644
--- a/README.md
+++ b/README.md
@@ -3,8 +3,8 @@
# Kokoro TTS API
-[]()
-[]()
+[]()
+[]()
[](https://huggingface.co/hexgrad/Kokoro-82M/tree/c3b0d86e2a980e027ef71c28819ea02e351c2667)
Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model
@@ -14,8 +14,7 @@ Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokor
- automatic chunking/stitching for long texts
- simple audio generation web ui utility
-
-Quick Start
+## Quick Start
The service can be accessed through either the API endpoints or the Gradio web interface.
@@ -48,9 +47,10 @@ The service can be accessed through either the API endpoints or the Gradio web i
-
+
+## Features
-OpenAI-Compatible Speech Endpoint
+OpenAI-Compatible Speech Endpoint
```python
# Using OpenAI's Python library
@@ -98,7 +98,10 @@ python examples/test_all_voices.py # Test all available voices
-Voice Combination
+Voice Combination
+
+- Averages model weights of any existing voicepacks
+- Saves generated voicepacks for future use
Combine voices and generate audio:
```python
@@ -129,7 +132,23 @@ response = requests.post(
-Gradio Web Utility
+Multiple Output Audio Formats
+
+- mp3
+- wav
+- opus
+- flac
+- aac (not currently supported)
+- pcm
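+
+For example, using the OpenAI Python client (see `examples/test_audio_formats.py`; this sketch assumes the service is running locally on port 8880):
+
+```python
+import openai
+
+client = openai.OpenAI(base_url="http://localhost:8880/v1", api_key="notneeded")
+
+# Request the same text in each currently supported format
+for fmt in ["wav", "mp3", "opus", "flac", "pcm"]:
+    response = client.audio.speech.create(
+        model="kokoro", voice="af", input="Hello world!", response_format=fmt
+    )
+    with open(f"speech.{fmt}", "wb") as f:
+        f.write(response.content)
+```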
+
+Gradio Web Utility
Access the interactive web UI at http://localhost:7860 after starting the service. Features include:
- Voice/format/speed selection
@@ -141,9 +160,9 @@ If you only want the API, just comment out everything in the docker-compose.yml
Currently, voices created via the API are accessible here, but voice combination/creation has not yet been added
-
+## Processing Details
-Performance Benchmarks
+Performance Benchmarks
Benchmarking was performed on generation via the local API using text lengths up to feature-length books (~1.5 hours output), measuring processing time and realtime factor. Tests were run on:
- Windows 11 Home w/ WSL2
@@ -163,7 +182,7 @@ Key Performance Metrics:
- Average Processing Rate: 137.67 tokens/second (cl100k_base)
-GPU Vs. CPU
+GPU Vs. CPU
```bash
# GPU: Requires NVIDIA GPU with CUDA 12.1 support
@@ -172,35 +191,29 @@ docker compose up --build
# CPU: ~10x slower than GPU inference
docker compose -f docker-compose.cpu.yml up --build
```
-
-
-Features
-
-- OpenAI-compatible API endpoints (with optional Gradio Web UI)
-- GPU-accelerated inference (if desired)
-- Multiple audio formats: mp3, wav, opus, flac, (aac & pcm not implemented)
-- Natural Boundary Detection:
- - Automatically splits and stitches at sentence boundaries to reduce artifacts and maintain performacne
-- Voice Combination:
- - Averages model weights of any existing voicepacks
- - Saves generated voicepacks for future use
-
-
*Note: CPU Inference is currently a very basic implementation, and not heavily tested*
+
+
+Natural Boundary Detection
+
+- Automatically splits and stitches at sentence boundaries
+- Helps reduce artifacts and enables long-form processing, since the base model is currently configured for only ~30 seconds of output (see the sketch below)
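+
+A rough, illustrative sketch of the idea (not the service's exact implementation): split the text at sentence boundaries, pack sentences into chunks under a length budget, synthesize each chunk, then stitch the audio back together.
+
+```python
+import re
+
+def chunk_text(text: str, max_chars: int = 300) -> list[str]:
+    """Split on sentence-ending punctuation, then pack sentences into chunks."""
+    sentences = re.split(r"(?<=[.!?])\s+", text.strip())
+    chunks, current = [], ""
+    for sentence in sentences:
+        if current and len(current) + len(sentence) + 1 > max_chars:
+            chunks.append(current)
+            current = sentence
+        else:
+            current = f"{current} {sentence}".strip()
+    if current:
+        chunks.append(current)
+    return chunks
+
+# Each chunk is synthesized separately; the resulting audio segments are
+# concatenated to produce a single long-form output.
+```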
+
+
+## Model and License
+
-Model
+Model
This API uses the [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) model from HuggingFace.
Visit the model page for more details about training, architecture, and capabilities. I have no affiliation with any of their work, and produced this wrapper for ease of use and personal projects.
-
-License
-
+License
This project is licensed under the Apache License 2.0 - see below for details:
- The Kokoro model weights are licensed under Apache 2.0 (see [model page](https://huggingface.co/hexgrad/Kokoro-82M))
@@ -209,3 +222,6 @@ This project is licensed under the Apache License 2.0 - see below for details:
The full Apache 2.0 license text can be found at: https://www.apache.org/licenses/LICENSE-2.0
+
diff --git a/api/src/services/audio.py b/api/src/services/audio.py
index ce2bccd..b8cc708 100644
--- a/api/src/services/audio.py
+++ b/api/src/services/audio.py
@@ -30,7 +30,9 @@ class AudioService:
if output_format == "wav":
logger.info("Writing to WAV format...")
# Ensure audio_data is in int16 format for WAV
- audio_data_wav = (audio_data / np.abs(audio_data).max() * np.iinfo(np.int16).max).astype(np.int16) # Normalize
+ audio_data_wav = (
+ audio_data / np.abs(audio_data).max() * np.iinfo(np.int16).max
+ ).astype(np.int16) # Normalize
sf.write(buffer, audio_data_wav, sample_rate, format="WAV")
elif output_format == "mp3":
logger.info("Converting to MP3 format...")
@@ -45,7 +47,9 @@ class AudioService:
elif output_format == "pcm":
logger.info("Extracting PCM data...")
# Ensure audio_data is in int16 format for PCM
- audio_data_pcm = (audio_data / np.abs(audio_data).max() * np.iinfo(np.int16).max).astype(np.int16) # Normalize
+ audio_data_pcm = (
+ audio_data / np.abs(audio_data).max() * np.iinfo(np.int16).max
+ ).astype(np.int16) # Normalize
buffer.write(audio_data_pcm.tobytes())
else:
raise ValueError(
diff --git a/api/tests/test_audio_service.py b/api/tests/test_audio_service.py
index ac0780e..32f4300 100644
--- a/api/tests/test_audio_service.py
+++ b/api/tests/test_audio_service.py
@@ -51,15 +51,19 @@ def test_convert_to_flac(sample_audio):
def test_convert_to_aac_raises_error(sample_audio):
"""Test that converting to AAC raises an error"""
audio_data, sample_rate = sample_audio
- with pytest.raises(ValueError, match="AAC format is not currently supported"):
+ with pytest.raises(
+ ValueError,
+ match="Format aac not supported. Supported formats are: wav, mp3, opus, flac, pcm.",
+ ):
AudioService.convert_audio(audio_data, sample_rate, "aac")
-def test_convert_to_pcm_raises_error(sample_audio):
- """Test that converting to PCM raises an error"""
+def test_convert_to_pcm(sample_audio):
+ """Test converting to PCM format"""
audio_data, sample_rate = sample_audio
- with pytest.raises(ValueError, match="PCM format is not currently supported"):
- AudioService.convert_audio(audio_data, sample_rate, "pcm")
+ result = AudioService.convert_audio(audio_data, sample_rate, "pcm")
+ assert isinstance(result, bytes)
+ assert len(result) > 0
def test_convert_to_invalid_format_raises_error(sample_audio):
@@ -67,3 +71,51 @@ def test_convert_to_invalid_format_raises_error(sample_audio):
audio_data, sample_rate = sample_audio
with pytest.raises(ValueError, match="Format invalid not supported"):
AudioService.convert_audio(audio_data, sample_rate, "invalid")
+
+
+def test_normalization_wav(sample_audio):
+ """Test that WAV output is properly normalized to int16 range"""
+ audio_data, sample_rate = sample_audio
+ # Create audio data outside int16 range
+ large_audio = audio_data * 1e5
+ result = AudioService.convert_audio(large_audio, sample_rate, "wav")
+ assert isinstance(result, bytes)
+ assert len(result) > 0
+
+
+def test_normalization_pcm(sample_audio):
+ """Test that PCM output is properly normalized to int16 range"""
+ audio_data, sample_rate = sample_audio
+ # Create audio data outside int16 range
+ large_audio = audio_data * 1e5
+ result = AudioService.convert_audio(large_audio, sample_rate, "pcm")
+ assert isinstance(result, bytes)
+ assert len(result) > 0
+
+
+def test_invalid_audio_data():
+ """Test handling of invalid audio data"""
+ invalid_audio = np.array([]) # Empty array
+ sample_rate = 24000
+ with pytest.raises(ValueError):
+ AudioService.convert_audio(invalid_audio, sample_rate, "wav")
+
+
+def test_different_sample_rates(sample_audio):
+ """Test converting audio with different sample rates"""
+ audio_data, _ = sample_audio
+ sample_rates = [8000, 16000, 44100, 48000]
+
+ for rate in sample_rates:
+ result = AudioService.convert_audio(audio_data, rate, "wav")
+ assert isinstance(result, bytes)
+ assert len(result) > 0
+
+
+def test_buffer_position_after_conversion(sample_audio):
+ """Test that buffer position is reset after writing"""
+ audio_data, sample_rate = sample_audio
+ result = AudioService.convert_audio(audio_data, sample_rate, "wav")
+ # Convert again to ensure buffer was properly reset
+ result2 = AudioService.convert_audio(audio_data, sample_rate, "wav")
+ assert len(result) == len(result2)
diff --git a/api/tests/test_tts_service.py b/api/tests/test_tts_service.py
index 8616c5f..d2a138b 100644
--- a/api/tests/test_tts_service.py
+++ b/api/tests/test_tts_service.py
@@ -4,6 +4,7 @@ import os
from unittest.mock import MagicMock, call, patch
import numpy as np
+import torch
import pytest
from api.src.services.tts import TTSModel, TTSService
@@ -119,6 +120,78 @@ def test_generate_audio_no_chunks(
tts_service._generate_audio("Test text", "af", 1.0)
+@patch("torch.load")
+@patch("torch.save")
+@patch("torch.stack")
+@patch("torch.mean")
+@patch("os.path.exists")
+def test_combine_voices(
+ mock_exists, mock_mean, mock_stack, mock_save, mock_load, tts_service
+):
+ """Test combining multiple voices"""
+ # Setup mocks
+ mock_exists.return_value = True
+ mock_load.return_value = torch.tensor([1.0, 2.0])
+ mock_stack.return_value = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
+ mock_mean.return_value = torch.tensor([2.0, 3.0])
+
+ # Test combining two voices
+ result = tts_service.combine_voices(["voice1", "voice2"])
+
+ assert result == "voice1_voice2"
+ mock_stack.assert_called_once()
+ mock_mean.assert_called_once()
+ mock_save.assert_called_once()
+
+
+def test_combine_voices_invalid_input(tts_service):
+ """Test combining voices with invalid input"""
+ # Test with empty list
+ with pytest.raises(ValueError, match="At least 2 voices are required"):
+ tts_service.combine_voices([])
+
+ # Test with single voice
+ with pytest.raises(ValueError, match="At least 2 voices are required"):
+ tts_service.combine_voices(["voice1"])
+
+
+@patch("os.makedirs")
+@patch("os.path.exists")
+@patch("os.listdir")
+@patch("torch.load")
+@patch("torch.save")
+@patch("os.path.join")
+def test_ensure_voices(
+ mock_join,
+ mock_save,
+ mock_load,
+ mock_listdir,
+ mock_exists,
+ mock_makedirs,
+ tts_service,
+):
+ """Test voice directory initialization"""
+ # Setup mocks
+ mock_exists.side_effect = [
+ True,
+ False,
+ False,
+ ] # base_dir exists, voice files don't exist
+ mock_listdir.return_value = ["voice1.pt", "voice2.pt"]
+ mock_load.return_value = MagicMock()
+ mock_join.return_value = "/fake/path"
+
+ # Test voice directory initialization
+ tts_service._ensure_voices()
+
+ # Verify directory was created
+ mock_makedirs.assert_called_once()
+
+ # Verify voices were loaded and saved
+ assert mock_load.call_count == len(mock_listdir.return_value)
+ assert mock_save.call_count == len(mock_listdir.return_value)
+
+
@patch("api.src.services.tts.TTSModel.get_instance")
@patch("os.path.exists")
@patch("api.src.services.tts.normalize_text")
@@ -236,7 +309,6 @@ def test_generate_audio_without_stitching(
"Test text", "af", 1.0, stitch_long_output=False
)
assert isinstance(audio, np.ndarray)
- assert isinstance(processing_time, float)
assert len(audio) > 0
mock_generate.assert_called_once()
diff --git a/examples/benchmarks/format_comparison.png b/examples/benchmarks/format_comparison.png
new file mode 100644
index 0000000..95ac515
Binary files /dev/null and b/examples/benchmarks/format_comparison.png differ
diff --git a/examples/test_audio_formats.py b/examples/test_audio_formats.py
new file mode 100644
index 0000000..e126dec
--- /dev/null
+++ b/examples/test_audio_formats.py
@@ -0,0 +1,284 @@
+"""Test script to generate and analyze different audio formats"""
+
+import os
+import time
+from pathlib import Path
+
+import numpy as np
+import openai
+import soundfile as sf
+import matplotlib.pyplot as plt
+
+SAMPLE_TEXT = """
+That is the germ of my great discovery. But you are wrong to say that we cannot move about in Time.
+"""
+
+# Configure OpenAI client
+client = openai.OpenAI(
+ timeout=60,
+ api_key="notneeded", # API key not required for our endpoint
+ base_url="http://localhost:8880/v1", # Point to our local server with v1 prefix
+)
+
+
+def setup_plot(fig, ax, title):
+ """Configure plot styling"""
+ # Improve grid
+ ax.grid(True, linestyle="--", alpha=0.3, color="#ffffff")
+
+ # Set title and labels with better fonts and more padding
+ ax.set_title(title, pad=40, fontsize=16, fontweight="bold", color="#ffffff")
+ ax.set_xlabel(ax.get_xlabel(), fontsize=14, fontweight="medium", color="#ffffff")
+ ax.set_ylabel(ax.get_ylabel(), fontsize=14, fontweight="medium", color="#ffffff")
+
+ # Improve tick labels
+ ax.tick_params(labelsize=12, colors="#ffffff")
+
+ # Style spines
+ for spine in ax.spines.values():
+ spine.set_color("#ffffff")
+ spine.set_alpha(0.3)
+ spine.set_linewidth(0.5)
+
+ # Set background colors
+ ax.set_facecolor("#1a1a2e")
+ fig.patch.set_facecolor("#1a1a2e")
+
+ return fig, ax
+
+
+def plot_format_comparison(stats: list, output_dir: str):
+ """Plot audio format comparison"""
+ plt.style.use("dark_background")
+
+ # Create figure with subplots
+ fig = plt.figure(figsize=(18, 16)) # Taller figure to accommodate bottom legend
+ fig.patch.set_facecolor("#1a1a2e")
+
+ # Create subplot grid with balanced spacing for waveforms
+ gs_waves = plt.GridSpec(
+ len(stats), 1, left=0.15, right=0.85, top=0.9, bottom=0.35, hspace=0.4
+ )
+
+ # Plot waveforms for each format
+ for i, stat in enumerate(stats):
+ format_name = stat["format"].upper()
+ try:
+ # Handle PCM format differently
+ if stat["format"] == "pcm":
+ # Read raw PCM data (16-bit mono)
+ with open(
+ os.path.join(output_dir, f"test_audio.{stat['format']}"), "rb"
+ ) as f:
+ raw_data = f.read()
+ data = np.frombuffer(raw_data, dtype=np.int16)
+ data = data.astype(np.float32) / 32768.0 # Convert to float [-1, 1]
+ sr = 24000
+ else:
+ # Read other formats with soundfile
+ data, sr = sf.read(
+ os.path.join(output_dir, f"test_audio.{stat['format']}")
+ )
+
+ # Plot waveform
+ ax = plt.subplot(gs_waves[i])
+ time = np.arange(len(data)) / sr
+ plt.plot(time, data / np.max(np.abs(data)), linewidth=0.5, color="#ff2a6d")
+ ax.set_xlabel("Time (seconds)")
+ ax.set_ylabel("")
+ ax.set_ylim(-1.1, 1.1)
+ setup_plot(fig, ax, f"Waveform: {format_name}")
+ except Exception as e:
+ print(f"Error plotting waveform for {format_name}: {e}")
+
+ # Colors for formats
+ colors = ["#ff2a6d", "#05d9e8", "#d1f7ff", "#ff9e00", "#8c1eff"]
+
+ # Create three subplots for metrics with more space at bottom for legend
+ gs_bottom = plt.GridSpec(
+ 1,
+ 3,
+ left=0.15,
+ right=0.85,
+ bottom=0.15,
+ top=0.25, # More bottom space for legend
+ wspace=0.3,
+ )
+
+ # File Size subplot
+ ax1 = plt.subplot(gs_bottom[0])
+ metrics1 = [("File Size", [s["file_size_kb"] for s in stats], "KB")]
+
+ # Duration and Gen Time subplot
+ ax2 = plt.subplot(gs_bottom[1])
+ metrics2 = [
+ ("Duration", [s["duration_seconds"] for s in stats], "s"),
+ ("Gen Time", [s["generation_time"] for s in stats], "s"),
+ ]
+
+ # Sample Rate subplot
+ ax3 = plt.subplot(gs_bottom[2])
+ metrics3 = [("Sample Rate", [s["sample_rate"] / 1000 for s in stats], "kHz")]
+
+ def plot_grouped_bars(ax, metrics, show_legend=True):
+ n_groups = len(metrics)
+ n_formats = len(stats)
+ # Use wider bars for time metrics
+ bar_width = 0.175 if metrics == metrics2 else 0.1
+
+ indices = np.arange(n_groups)
+
+ # Get max value for y-axis scaling
+ max_val = max(max(m[1]) for m in metrics)
+
+ for i, (stat, color) in enumerate(zip(stats, colors)):
+ values = [m[1][i] for m in metrics]
+ # Reduce spacing between bars for time metrics
+ spacing = 1.1 if metrics == metrics2 else 1.0
+ offset = (i - n_formats / 2 + 0.5) * bar_width * spacing
+ bars = ax.bar(
+ indices + offset,
+ values,
+ bar_width,
+ label=stat["format"].upper(),
+ color=color,
+ alpha=0.8,
+ )
+
+ # Add value labels on top of bars
+ for bar in bars:
+ height = bar.get_height()
+ ax.text(
+ bar.get_x() + bar.get_width() / 2.0,
+ height,
+ f"{height:.1f}",
+ ha="center",
+ va="bottom",
+ color="white",
+ fontsize=10,
+ )
+
+ ax.set_xticks(indices)
+ ax.set_xticklabels([f"{m[0]}\n({m[2]})" for m in metrics])
+
+ # Set y-axis limits with some padding
+ ax.set_ylim(0, max_val * 1.2)
+
+ if show_legend:
+ # Place legend at the bottom
+ ax.legend(
+ bbox_to_anchor=(1.8, -0.8),
+ loc="center",
+ facecolor="#1a1a2e",
+ edgecolor="#ffffff",
+ ncol=len(stats),
+ ) # Show all formats in one row
+
+ # Plot all three subplots with shared legend
+ plot_grouped_bars(ax1, metrics1, show_legend=True)
+ plot_grouped_bars(ax2, metrics2, show_legend=False)
+ plot_grouped_bars(ax3, metrics3, show_legend=False)
+
+ # Style all subplots
+ setup_plot(fig, ax1, "File Size")
+ setup_plot(fig, ax2, "Time Metrics")
+ setup_plot(fig, ax3, "Sample Rate")
+
+ # Add y-axis labels
+ ax1.set_ylabel("Value")
+ ax2.set_ylabel("Value")
+ ax3.set_ylabel("Value")
+
+ # Save the plot
+ plt.savefig(os.path.join(output_dir, "format_comparison.png"), dpi=300)
+ print(f"\nSaved format comparison plot to {output_dir}/format_comparison.png")
+
+
+def get_audio_stats(file_path: str) -> dict:
+ """Get audio file statistics"""
+ file_size = os.path.getsize(file_path)
+ file_size_kb = file_size / 1024 # Convert to KB
+
+ try:
+ # Try reading with soundfile first
+ data, sample_rate = sf.read(file_path)
+ duration = len(data) / sample_rate
+ channels = 1 if len(data.shape) == 1 else data.shape[1]
+
+ # Calculate audio statistics
+ stats = {
+ "format": Path(file_path).suffix[1:],
+ "file_size_kb": round(file_size_kb, 2),
+ "duration_seconds": round(duration, 2),
+ "sample_rate": sample_rate,
+ "channels": channels,
+ "min_amplitude": float(np.min(data)),
+ "max_amplitude": float(np.max(data)),
+ "mean_amplitude": float(np.mean(np.abs(data))),
+ "rms_amplitude": float(np.sqrt(np.mean(np.square(data)))),
+ }
+ return stats
+    except Exception:
+        # soundfile cannot read raw PCM; estimate duration from the raw bytes instead
+ with open(file_path, "rb") as f:
+ data = f.read()
+ # Assuming 16-bit PCM mono at 24kHz
+ samples = len(data) // 2 # 2 bytes per sample
+ duration = samples / 24000
+ return {
+ "format": "pcm",
+ "file_size_kb": round(file_size_kb, 2),
+ "duration_seconds": round(duration, 2),
+ "sample_rate": 24000,
+ "channels": 1,
+ "note": "PCM stats are estimated from raw bytes",
+ }
+
+
+def main():
+ """Generate and analyze audio in different formats"""
+ # Create output directory
+ output_dir = Path(__file__).parent / "output" / "test_formats"
+ output_dir.mkdir(exist_ok=True, parents=True)
+
+ # First generate audio in each format using the API
+ voice = "af" # Using default voice
+ formats = ["wav", "mp3", "opus", "flac", "pcm"]
+ stats = []
+
+ for fmt in formats:
+ output_path = output_dir / f"test_audio.{fmt}"
+ print(f"\nGenerating {fmt.upper()} audio...")
+
+ # Generate and save
+ start_time = time.time()
+ response = client.audio.speech.create(
+ model="kokoro", voice=voice, input=SAMPLE_TEXT, response_format=fmt
+ )
+ generation_time = time.time() - start_time
+
+ with open(output_path, "wb") as f:
+ f.write(response.content)
+
+ # Get stats
+ file_stats = get_audio_stats(str(output_path))
+ file_stats["generation_time"] = round(generation_time, 3)
+ stats.append(file_stats)
+
+ # Generate comparison plot
+ plot_format_comparison(stats, str(output_dir))
+
+ # Print detailed statistics
+ print("\nDetailed Audio Statistics:")
+ print("=" * 100)
+ for stat in stats:
+ print(f"\n{stat['format'].upper()} Format:")
+ for key, value in sorted(stat.items()):
+ if key not in ["format"]: # Skip format as it's in the header
+ print(f" {key}: {value}")
+
+
+if __name__ == "__main__":
+ main()