diff --git a/.gitignore b/.gitignore
index 98b9187..aebbfa7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,3 +14,5 @@ env/
 .Python
+.coverage
+
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 44c98bf..4194878 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,8 +2,12 @@ Notable changes to this project will be documented in this file.
 
-## 2024-01-09
+## 2025-01-02
+- Audio Format Support:
+  - Added comprehensive audio format conversion support (mp3, wav, opus, flac, pcm)
+
+## 2025-01-01
 ### Added
 - Gradio Web Interface:
   - Added simple web UI utility for audio generation from input or txt file
diff --git a/README.md b/README.md
index 639c13d..ff794f6 100644
--- a/README.md
+++ b/README.md
@@ -3,8 +3,8 @@

# Kokoro TTS API -[![Tests](https://img.shields.io/badge/tests-81%20passed-darkgreen)]() -[![Coverage](https://img.shields.io/badge/coverage-76%25-darkgreen)]() +[![Tests](https://img.shields.io/badge/tests-89%20passed-darkgreen)]() +[![Coverage](https://img.shields.io/badge/coverage-80%25-darkgreen)]() [![Tested at Model Commit](https://img.shields.io/badge/last--tested--model--commit-a67f113-blue)](https://huggingface.co/hexgrad/Kokoro-82M/tree/c3b0d86e2a980e027ef71c28819ea02e351c2667) Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model @@ -14,8 +14,7 @@ Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokor - automatic chunking/stitching for long texts - simple audio generation web ui utility -
-Quick Start +## Quick Start The service can be accessed through either the API endpoints or the Gradio web interface. @@ -48,9 +47,10 @@ The service can be accessed through either the API endpoints or the Gradio web i

[image: Voice Analysis Comparison]

-
+ +## Features
-OpenAI-Compatible Speech Endpoint +OpenAI-Compatible Speech Endpoint ```python # Using OpenAI's Python library @@ -98,7 +98,10 @@ python examples/test_all_voices.py # Test all available voices
-Voice Combination +Voice Combination + +- Averages model weights of any existing voicepacks +- Saves generated voicepacks for future use Combine voices and generate audio: ```python @@ -129,7 +132,23 @@ response = requests.post(
-Gradio Web Utility
+Multiple Output Audio Formats
+
+- mp3
+- wav
+- opus
+- flac
+- pcm
+
+(aac is recognized but not yet supported; requesting it currently returns an error)
+
+![Audio Format Comparison](examples/benchmarks/format_comparison.png)
+
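For reference, an output format is selected through the `response_format` parameter of the OpenAI-compatible endpoint. A minimal sketch, mirroring the `examples/test_audio_formats.py` script added in this diff (the local base URL and placeholder API key are that script's assumptions):

```python
import openai

client = openai.OpenAI(
    api_key="notneeded",  # the local endpoint does not check API keys
    base_url="http://localhost:8880/v1",
)

response = client.audio.speech.create(
    model="kokoro",
    voice="af",
    input="Hello world!",
    response_format="opus",  # any of: wav, mp3, opus, flac, pcm
)

# The response body is the encoded audio itself
with open("hello.opus", "wb") as f:
    f.write(response.content)
```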
+Gradio Web Utility Access the interactive web UI at http://localhost:7860 after starting the service. Features include: - Voice/format/speed selection @@ -141,9 +160,9 @@ If you only want the API, just comment out everything in the docker-compose.yml Currently, voices created via the API are accessible here, but voice combination/creation has not yet been added
- +## Processing Details
-Performance Benchmarks +Performance Benchmarks Benchmarking was performed on generation via the local API using text lengths up to feature-length books (~1.5 hours output), measuring processing time and realtime factor. Tests were run on: - Windows 11 Home w/ WSL2 @@ -163,7 +182,7 @@ Key Performance Metrics: - Average Processing Rate: 137.67 tokens/second (cl100k_base)
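For clarity, the realtime factor measured here is simply output audio duration divided by wall-clock generation time; a tiny illustration with purely hypothetical numbers:

```python
# Hypothetical numbers, only to illustrate the metric's definition
audio_seconds = 90.0       # duration of the generated speech
processing_seconds = 1.5   # wall-clock time spent generating it
print(f"{audio_seconds / processing_seconds:.0f}x realtime")  # -> 60x realtime
```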
-GPU Vs. CPU +GPU Vs. CPU ```bash # GPU: Requires NVIDIA GPU with CUDA 12.1 support @@ -172,35 +191,29 @@ docker compose up --build # CPU: ~10x slower than GPU inference docker compose -f docker-compose.cpu.yml up --build ``` -
-
-Features - -- OpenAI-compatible API endpoints (with optional Gradio Web UI) -- GPU-accelerated inference (if desired) -- Multiple audio formats: mp3, wav, opus, flac, (aac & pcm not implemented) -- Natural Boundary Detection: - - Automatically splits and stitches at sentence boundaries to reduce artifacts and maintain performacne -- Voice Combination: - - Averages model weights of any existing voicepacks - - Saves generated voicepacks for future use - - *Note: CPU Inference is currently a very basic implementation, and not heavily tested* +
+
+Natural Boundary Detection
+
+- Automatically splits and stitches at sentence boundaries
+- Helps reduce artifacts and enables long-form processing, since the base model is currently configured for only ~30s of output
+
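The splitting idea can be sketched in a few lines. This is illustrative only (a hypothetical helper, not the service's actual implementation, which lives in the API's text-processing code):

```python
import re

def chunk_text(text: str, max_chars: int = 300) -> list[str]:
    """Illustrative sentence-boundary chunking (hypothetical helper)."""
    sentences = re.split(r"(?<=[.!?])\s+", text.strip())
    chunks, current = [], ""
    for sentence in sentences:
        # Start a new chunk once adding this sentence would exceed the budget
        if current and len(current) + len(sentence) + 1 > max_chars:
            chunks.append(current)
            current = sentence
        else:
            current = f"{current} {sentence}".strip()
    if current:
        chunks.append(current)
    return chunks
```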
+ +## Model and License +
-Model +Model This API uses the [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) model from HuggingFace. Visit the model page for more details about training, architecture, and capabilities. I have no affiliation with any of their work, and produced this wrapper for ease of use and personal projects.
-
-License - +License This project is licensed under the Apache License 2.0 - see below for details: - The Kokoro model weights are licensed under Apache 2.0 (see [model page](https://huggingface.co/hexgrad/Kokoro-82M)) @@ -209,3 +222,6 @@ This project is licensed under the Apache License 2.0 - see below for details: The full Apache 2.0 license text can be found at: https://www.apache.org/licenses/LICENSE-2.0
+
+
+
diff --git a/api/src/services/audio.py b/api/src/services/audio.py
index ce2bccd..b8cc708 100644
--- a/api/src/services/audio.py
+++ b/api/src/services/audio.py
@@ -30,7 +30,9 @@ class AudioService:
         if output_format == "wav":
             logger.info("Writing to WAV format...")
             # Ensure audio_data is in int16 format for WAV
-            audio_data_wav = (audio_data / np.abs(audio_data).max() * np.iinfo(np.int16).max).astype(np.int16)  # Normalize
+            audio_data_wav = (
+                audio_data / np.abs(audio_data).max() * np.iinfo(np.int16).max
+            ).astype(np.int16)  # Normalize
             sf.write(buffer, audio_data_wav, sample_rate, format="WAV")
         elif output_format == "mp3":
             logger.info("Converting to MP3 format...")
@@ -45,7 +47,9 @@ class AudioService:
         elif output_format == "pcm":
             logger.info("Extracting PCM data...")
             # Ensure audio_data is in int16 format for PCM
-            audio_data_pcm = (audio_data / np.abs(audio_data).max() * np.iinfo(np.int16).max).astype(np.int16)  # Normalize
+            audio_data_pcm = (
+                audio_data / np.abs(audio_data).max() * np.iinfo(np.int16).max
+            ).astype(np.int16)  # Normalize
             buffer.write(audio_data_pcm.tobytes())
         else:
             raise ValueError(
diff --git a/api/tests/test_audio_service.py b/api/tests/test_audio_service.py
index ac0780e..32f4300 100644
--- a/api/tests/test_audio_service.py
+++ b/api/tests/test_audio_service.py
@@ -51,15 +51,19 @@ def test_convert_to_flac(sample_audio):
 def test_convert_to_aac_raises_error(sample_audio):
     """Test that converting to AAC raises an error"""
     audio_data, sample_rate = sample_audio
-    with pytest.raises(ValueError, match="AAC format is not currently supported"):
+    with pytest.raises(
+        ValueError,
+        match="Format aac not supported. Supported formats are: wav, mp3, opus, flac, pcm.",
+    ):
         AudioService.convert_audio(audio_data, sample_rate, "aac")
 
 
-def test_convert_to_pcm_raises_error(sample_audio):
-    """Test that converting to PCM raises an error"""
+def test_convert_to_pcm(sample_audio):
+    """Test converting to PCM format"""
     audio_data, sample_rate = sample_audio
-    with pytest.raises(ValueError, match="PCM format is not currently supported"):
-        AudioService.convert_audio(audio_data, sample_rate, "pcm")
+    result = AudioService.convert_audio(audio_data, sample_rate, "pcm")
+    assert isinstance(result, bytes)
+    assert len(result) > 0
 
 
 def test_convert_to_invalid_format_raises_error(sample_audio):
@@ -67,3 +71,51 @@ def test_convert_to_invalid_format_raises_error(sample_audio):
     audio_data, sample_rate = sample_audio
     with pytest.raises(ValueError, match="Format invalid not supported"):
         AudioService.convert_audio(audio_data, sample_rate, "invalid")
+
+
+def test_normalization_wav(sample_audio):
+    """Test that WAV output is properly normalized to int16 range"""
+    audio_data, sample_rate = sample_audio
+    # Create audio data outside int16 range
+    large_audio = audio_data * 1e5
+    result = AudioService.convert_audio(large_audio, sample_rate, "wav")
+    assert isinstance(result, bytes)
+    assert len(result) > 0
+
+
+def test_normalization_pcm(sample_audio):
+    """Test that PCM output is properly normalized to int16 range"""
+    audio_data, sample_rate = sample_audio
+    # Create audio data outside int16 range
+    large_audio = audio_data * 1e5
+    result = AudioService.convert_audio(large_audio, sample_rate, "pcm")
+    assert isinstance(result, bytes)
+    assert len(result) > 0
+
+
+def test_invalid_audio_data():
+    """Test handling of invalid audio data"""
+    invalid_audio = np.array([])  # Empty array
+    sample_rate = 24000
+    with pytest.raises(ValueError):
+        AudioService.convert_audio(invalid_audio, sample_rate, "wav")
+
+
+def test_different_sample_rates(sample_audio):
+    """Test converting audio with different sample rates"""
+    audio_data, _ = sample_audio
+    sample_rates = [8000, 16000, 44100, 48000]
+
+    for rate in sample_rates:
+        result = AudioService.convert_audio(audio_data, rate, "wav")
+        assert isinstance(result, bytes)
+        assert len(result) > 0
+
+
+def test_buffer_position_after_conversion(sample_audio):
+    """Test that buffer position is reset after writing"""
+    audio_data, sample_rate = sample_audio
+    result = AudioService.convert_audio(audio_data, sample_rate, "wav")
+    # Convert again to ensure buffer was properly reset
+    result2 = AudioService.convert_audio(audio_data, sample_rate, "wav")
+    assert len(result) == len(result2)
diff --git a/api/tests/test_tts_service.py b/api/tests/test_tts_service.py
index 8616c5f..d2a138b 100644
--- a/api/tests/test_tts_service.py
+++ b/api/tests/test_tts_service.py
@@ -4,6 +4,7 @@ import os
 from unittest.mock import MagicMock, call, patch
 
 import numpy as np
+import torch
 import pytest
 
 from api.src.services.tts import TTSModel, TTSService
@@ -119,6 +120,78 @@ def test_generate_audio_no_chunks(
         tts_service._generate_audio("Test text", "af", 1.0)
 
 
+@patch("torch.load")
+@patch("torch.save")
+@patch("torch.stack")
+@patch("torch.mean")
+@patch("os.path.exists")
+def test_combine_voices(
+    mock_exists, mock_mean, mock_stack, mock_save, mock_load, tts_service
+):
+    """Test combining multiple voices"""
+    # Setup mocks
+    mock_exists.return_value = True
+    mock_load.return_value = torch.tensor([1.0, 2.0])
+    mock_stack.return_value = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
+    mock_mean.return_value = torch.tensor([2.0, 3.0])
+
+    # Test combining two voices
+    result = tts_service.combine_voices(["voice1", "voice2"])
+
+    assert result == "voice1_voice2"
+    mock_stack.assert_called_once()
+    mock_mean.assert_called_once()
+    mock_save.assert_called_once()
+
+
+def test_combine_voices_invalid_input(tts_service):
+    """Test combining voices with invalid input"""
+    # Test with empty list
+    with pytest.raises(ValueError, match="At least 2 voices are required"):
+        tts_service.combine_voices([])
+
+    # Test with single voice
+    with pytest.raises(ValueError, match="At least 2 voices are required"):
+        tts_service.combine_voices(["voice1"])
+
+
+@patch("os.makedirs")
+@patch("os.path.exists")
+@patch("os.listdir")
+@patch("torch.load")
+@patch("torch.save")
+@patch("os.path.join")
+def test_ensure_voices(
+    mock_join,
+    mock_save,
+    mock_load,
+    mock_listdir,
+    mock_exists,
+    mock_makedirs,
+    tts_service,
+):
+    """Test voice directory initialization"""
+    # Setup mocks
+    mock_exists.side_effect = [
+        True,
+        False,
+        False,
+    ]  # base_dir exists, voice files don't exist
+    mock_listdir.return_value = ["voice1.pt", "voice2.pt"]
+    mock_load.return_value = MagicMock()
+    mock_join.return_value = "/fake/path"
+
+    # Test voice directory initialization
+    tts_service._ensure_voices()
+
+    # Verify directory was created
+    mock_makedirs.assert_called_once()
+
+    # Verify voices were loaded and saved
+    assert mock_load.call_count == len(mock_listdir.return_value)
+    assert mock_save.call_count == len(mock_listdir.return_value)
+
+
 @patch("api.src.services.tts.TTSModel.get_instance")
 @patch("os.path.exists")
 @patch("api.src.services.tts.normalize_text")
@@ -236,7 +309,6 @@ def test_generate_audio_without_stitching(
         "Test text", "af", 1.0, stitch_long_output=False
     )
     assert isinstance(audio, np.ndarray)
-    assert isinstance(processing_time, float)
     assert len(audio) > 0
     mock_generate.assert_called_once()
diff --git a/examples/benchmarks/format_comparison.png b/examples/benchmarks/format_comparison.png
new file mode 100644
index 0000000..95ac515
Binary files /dev/null and b/examples/benchmarks/format_comparison.png differ
diff --git a/examples/test_audio_formats.py b/examples/test_audio_formats.py
new file mode 100644
index 0000000..e126dec
--- /dev/null
+++ b/examples/test_audio_formats.py
@@ -0,0 +1,284 @@
+"""Test script to generate and analyze different audio formats"""
+
+import os
+import time
+from pathlib import Path
+
+import numpy as np
+import openai
+import requests
+import soundfile as sf
+import matplotlib.pyplot as plt
+from scipy.io import wavfile
+
+SAMPLE_TEXT = """
+That is the germ of my great discovery. But you are wrong to say that we cannot move about in Time.
+"""
+
+# Configure OpenAI client
+client = openai.OpenAI(
+    timeout=60,
+    api_key="notneeded",  # API key not required for our endpoint
+    base_url="http://localhost:8880/v1",  # Point to our local server with v1 prefix
+)
+
+
+def setup_plot(fig, ax, title):
+    """Configure plot styling"""
+    # Improve grid
+    ax.grid(True, linestyle="--", alpha=0.3, color="#ffffff")
+
+    # Set title and labels with better fonts and more padding
+    ax.set_title(title, pad=40, fontsize=16, fontweight="bold", color="#ffffff")
+    ax.set_xlabel(ax.get_xlabel(), fontsize=14, fontweight="medium", color="#ffffff")
+    ax.set_ylabel(ax.get_ylabel(), fontsize=14, fontweight="medium", color="#ffffff")
+
+    # Improve tick labels
+    ax.tick_params(labelsize=12, colors="#ffffff")
+
+    # Style spines
+    for spine in ax.spines.values():
+        spine.set_color("#ffffff")
+        spine.set_alpha(0.3)
+        spine.set_linewidth(0.5)
+
+    # Set background colors
+    ax.set_facecolor("#1a1a2e")
+    fig.patch.set_facecolor("#1a1a2e")
+
+    return fig, ax
+
+
+def plot_format_comparison(stats: list, output_dir: str):
+    """Plot audio format comparison"""
+    plt.style.use("dark_background")
+
+    # Create figure with subplots
+    fig = plt.figure(figsize=(18, 16))  # Taller figure to accommodate bottom legend
+    fig.patch.set_facecolor("#1a1a2e")
+
+    # Create subplot grid with balanced spacing for waveforms
+    gs_waves = plt.GridSpec(
+        len(stats), 1, left=0.15, right=0.85, top=0.9, bottom=0.35, hspace=0.4
+    )
+
+    # Plot waveforms for each format
+    for i, stat in enumerate(stats):
+        format_name = stat["format"].upper()
+        try:
+            # Handle PCM format differently
+            if stat["format"] == "pcm":
+                # Read raw PCM data (16-bit mono)
+                with open(
+                    os.path.join(output_dir, f"test_audio.{stat['format']}"), "rb"
+                ) as f:
+                    raw_data = f.read()
+                data = np.frombuffer(raw_data, dtype=np.int16)
+                data = data.astype(np.float32) / 32768.0  # Convert to float [-1, 1]
+                sr = 24000
+            else:
+                # Read other formats with soundfile
+                data, sr = sf.read(
+                    os.path.join(output_dir, f"test_audio.{stat['format']}")
+                )
+
+            # Plot waveform
+            ax = plt.subplot(gs_waves[i])
+            t = np.arange(len(data)) / sr
+            plt.plot(t, data / np.max(np.abs(data)), linewidth=0.5, color="#ff2a6d")
+            ax.set_xlabel("Time (seconds)")
+            ax.set_ylabel("")
+            ax.set_ylim(-1.1, 1.1)
+            setup_plot(fig, ax, f"Waveform: {format_name}")
+        except Exception as e:
+            print(f"Error plotting waveform for {format_name}: {e}")
+
+    # Colors for formats
+    colors = ["#ff2a6d", "#05d9e8", "#d1f7ff", "#ff9e00", "#8c1eff"]
+
+    # Create three subplots for metrics with more space at bottom for legend
+    gs_bottom = plt.GridSpec(
+        1,
+        3,
+        left=0.15,
+        right=0.85,
+        bottom=0.15,
+        top=0.25,  # More bottom space for legend
+        wspace=0.3,
+    )
+
+    # File Size subplot
+    ax1 = plt.subplot(gs_bottom[0])
+    metrics1 = [("File Size", [s["file_size_kb"] for s in stats], "KB")]
+
+    # Duration and Gen Time subplot
+    ax2 = plt.subplot(gs_bottom[1])
+    metrics2 = [
+        ("Duration", [s["duration_seconds"] for s in stats], "s"),
+        ("Gen Time", [s["generation_time"] for s in stats], "s"),
+    ]
+
+    # Sample Rate subplot
+    ax3 = plt.subplot(gs_bottom[2])
+    metrics3 = [("Sample Rate", [s["sample_rate"] / 1000 for s in stats], "kHz")]
+
+    def plot_grouped_bars(ax, metrics, show_legend=True):
+        n_groups = len(metrics)
+        n_formats = len(stats)
+        # Use wider bars for time metrics
+        bar_width = 0.175 if metrics == metrics2 else 0.1
+
+        indices = np.arange(n_groups)
+
+        # Get max value for y-axis scaling
+        max_val = max(max(m[1]) for m in metrics)
+
+        for i, (stat, color) in enumerate(zip(stats, colors)):
+            values = [m[1][i] for m in metrics]
+            # Reduce spacing between bars for time metrics
+            spacing = 1.1 if metrics == metrics2 else 1.0
+            offset = (i - n_formats / 2 + 0.5) * bar_width * spacing
+            bars = ax.bar(
+                indices + offset,
+                values,
+                bar_width,
+                label=stat["format"].upper(),
+                color=color,
+                alpha=0.8,
+            )
+
+            # Add value labels on top of bars
+            for bar in bars:
+                height = bar.get_height()
+                ax.text(
+                    bar.get_x() + bar.get_width() / 2.0,
+                    height,
+                    f"{height:.1f}",
+                    ha="center",
+                    va="bottom",
+                    color="white",
+                    fontsize=10,
+                )
+
+        ax.set_xticks(indices)
+        ax.set_xticklabels([f"{m[0]}\n({m[2]})" for m in metrics])
+
+        # Set y-axis limits with some padding
+        ax.set_ylim(0, max_val * 1.2)
+
+        if show_legend:
+            # Place legend at the bottom
+            ax.legend(
+                bbox_to_anchor=(1.8, -0.8),
+                loc="center",
+                facecolor="#1a1a2e",
+                edgecolor="#ffffff",
+                ncol=len(stats),
+            )  # Show all formats in one row
+
+    # Plot all three subplots with shared legend
+    plot_grouped_bars(ax1, metrics1, show_legend=True)
+    plot_grouped_bars(ax2, metrics2, show_legend=False)
+    plot_grouped_bars(ax3, metrics3, show_legend=False)
+
+    # Style all subplots
+    setup_plot(fig, ax1, "File Size")
+    setup_plot(fig, ax2, "Time Metrics")
+    setup_plot(fig, ax3, "Sample Rate")
+
+    # Add y-axis labels
+    ax1.set_ylabel("Value")
+    ax2.set_ylabel("Value")
+    ax3.set_ylabel("Value")
+
+    # Save the plot
+    plt.savefig(os.path.join(output_dir, "format_comparison.png"), dpi=300)
+    print(f"\nSaved format comparison plot to {output_dir}/format_comparison.png")
+
+
+def get_audio_stats(file_path: str) -> dict:
+    """Get audio file statistics"""
+    file_size = os.path.getsize(file_path)
+    file_size_kb = file_size / 1024  # Convert to KB
+
+    try:
+        # Try reading with soundfile first
+        data, sample_rate = sf.read(file_path)
+        duration = len(data) / sample_rate
+        channels = 1 if len(data.shape) == 1 else data.shape[1]
+
+        # Calculate audio statistics
+        stats = {
+            "format": Path(file_path).suffix[1:],
+            "file_size_kb": round(file_size_kb, 2),
+            "duration_seconds": round(duration, 2),
+            "sample_rate": sample_rate,
+            "channels": channels,
+            "min_amplitude": float(np.min(data)),
+            "max_amplitude": float(np.max(data)),
+            "mean_amplitude": float(np.mean(np.abs(data))),
+            "rms_amplitude": float(np.sqrt(np.mean(np.square(data)))),
+        }
+        return stats
+    except Exception:
+        # For PCM, read raw bytes and estimate duration
+        with open(file_path, "rb") as f:
+            data = f.read()
+        # Assuming 16-bit PCM mono at 24kHz
+        samples = len(data) // 2  # 2 bytes per sample
+        duration = samples / 24000
+        return {
+            "format": "pcm",
+            "file_size_kb": round(file_size_kb, 2),
+            "duration_seconds": round(duration, 2),
+            "sample_rate": 24000,
+            "channels": 1,
+            "note": "PCM stats are estimated from raw bytes",
+        }
+
+
+def main():
+    """Generate and analyze audio in different formats"""
+    # Create output directory
+    output_dir = Path(__file__).parent / "output" / "test_formats"
+    output_dir.mkdir(exist_ok=True, parents=True)
+
+    # First generate audio in each format using the API
+    voice = "af"  # Using default voice
+    formats = ["wav", "mp3", "opus", "flac", "pcm"]
+    stats = []
+
+    for fmt in formats:
+        output_path = output_dir / f"test_audio.{fmt}"
+        print(f"\nGenerating {fmt.upper()} audio...")
+
+        # Generate and save
+        start_time = time.time()
+        response = client.audio.speech.create(
+            model="kokoro", voice=voice, input=SAMPLE_TEXT, response_format=fmt
+        )
+        generation_time = time.time() - start_time
+
+        with open(output_path, "wb") as f:
+            f.write(response.content)
+
+        # Get stats
+        file_stats = get_audio_stats(str(output_path))
+        file_stats["generation_time"] = round(generation_time, 3)
+        stats.append(file_stats)
+
+    # Generate comparison plot
+    plot_format_comparison(stats, str(output_dir))
+
+    # Print detailed statistics
+    print("\nDetailed Audio Statistics:")
+    print("=" * 100)
+    for stat in stats:
+        print(f"\n{stat['format'].upper()} Format:")
+        for key, value in sorted(stat.items()):
+            if key not in ["format"]:  # Skip format as it's in the header
+                print(f"  {key}: {value}")
+
+
+if __name__ == "__main__":
+    main()
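Because pcm output is headerless, players need the sample rate and sample width supplied out of band. A minimal sketch for wrapping a saved `.pcm` file into a playable WAV, under the same 24 kHz, 16-bit mono assumption the script above makes:

```python
import numpy as np
import soundfile as sf

SAMPLE_RATE = 24000  # assumed: the API emits 16-bit mono PCM at 24 kHz

with open("test_audio.pcm", "rb") as f:
    pcm = np.frombuffer(f.read(), dtype=np.int16)

# Write a WAV container around the raw samples
sf.write("test_audio_from_pcm.wav", pcm, SAMPLE_RATE, subtype="PCM_16")
```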