Mirror of https://github.com/remsky/Kokoro-FastAPI.git, synced 2025-04-13 09:39:17 +00:00

Commit ee5be65596: 8 changed files with 486 additions and 67 deletions
.gitignore (vendored): 2 changes

@@ -14,3 +14,5 @@ env/
 .Python
+
+.coverage
CHANGELOG.md

@@ -2,8 +2,12 @@
 Notable changes to this project will be documented in this file.
 
-## 2024-01-09
+## 2025-01-02
+
+- Audio Format Support:
+  - Added comprehensive audio format conversion support (mp3, wav, opus, flac)
+
+## 2025-01-01
 ### Added
 - Gradio Web Interface:
   - Added simple web UI utility for audio generation from input or txt file
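The new formats can also be exercised over raw HTTP. A hedged sketch: the `/v1/audio/speech` path and JSON body follow the OpenAI-compatible convention this wrapper exposes (the same route the OpenAI client in `examples/test_audio_formats.py` below resolves to), so treat the exact payload keys as assumptions rather than documented API:

```python
# Sketch only: raw HTTP against the local server, one request per new format.
# Endpoint path and body keys assume the OpenAI-compatible speech API shape.
import requests

for fmt in ("mp3", "wav", "opus", "flac"):
    response = requests.post(
        "http://localhost:8880/v1/audio/speech",
        json={
            "model": "kokoro",
            "voice": "af",
            "input": "Hello!",
            "response_format": fmt,
        },
    )
    with open(f"hello.{fmt}", "wb") as f:
        f.write(response.content)
```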
README.md: 74 changes

@@ -3,8 +3,8 @@
 </p>
 
 # Kokoro TTS API
 [badge]()
 [badge]()
 [badge](https://huggingface.co/hexgrad/Kokoro-82M/tree/c3b0d86e2a980e027ef71c28819ea02e351c2667)
 
 Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model
@@ -14,8 +14,7 @@ Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model
 - automatic chunking/stitching for long texts
 - simple audio generation web ui utility
 
-<details open>
-<summary><b>Quick Start</b></summary>
+## Quick Start
 
 The service can be accessed through either the API endpoints or the Gradio web interface.
@@ -48,9 +47,10 @@ The service can be accessed through either the API endpoints or the Gradio web interface.
 <p align="center">
 <img src="ui\GradioScreenShot.png" width="80%" alt="Voice Analysis Comparison" style="border: 2px solid #333; padding: 10px;">
 </p>
-</details>
+
+## Features
 <details>
-<summary><b>OpenAI-Compatible Speech Endpoint</b></summary>
+<summary>OpenAI-Compatible Speech Endpoint</summary>
 
 ```python
 # Using OpenAI's Python library
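That Python block is truncated by the hunk. A minimal sketch of the full call, based directly on the client configuration used in `examples/test_audio_formats.py` from this same commit (local server at `http://localhost:8880/v1`, model `kokoro`, voice `af`):

```python
# Sketch only: mirrors the client setup in examples/test_audio_formats.py.
import openai

client = openai.OpenAI(
    api_key="notneeded",                  # the local endpoint ignores the key
    base_url="http://localhost:8880/v1",  # local server with v1 prefix
)

response = client.audio.speech.create(
    model="kokoro", voice="af", input="Hello world!", response_format="mp3"
)
with open("speech.mp3", "wb") as f:
    f.write(response.content)
```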
@@ -98,7 +98,10 @@ python examples/test_all_voices.py # Test all available voices
 </details>
 
 <details>
-<summary><b>Voice Combination</b></summary>
+<summary>Voice Combination</summary>
+
+- Averages model weights of any existing voicepacks
+- Saves generated voicepacks for future use
 
 Combine voices and generate audio:
 ```python
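Voice combination is plain tensor averaging. A minimal sketch of the idea, with a hypothetical helper name and file layout (not the service's actual implementation), mirroring the torch.load / torch.stack / torch.mean / torch.save flow that the new unit tests further down mock out:

```python
# Minimal sketch of voicepack averaging (hypothetical paths and helper name).
import torch

def combine_voicepacks(paths: list[str], out_path: str) -> None:
    packs = [torch.load(p) for p in paths]            # each .pt holds a voice tensor
    combined = torch.mean(torch.stack(packs), dim=0)  # element-wise average
    torch.save(combined, out_path)

combine_voicepacks(["voices/af.pt", "voices/bm_lewis.pt"], "voices/af_bm_lewis.pt")
```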
@@ -129,7 +132,23 @@ response = requests.post(
 </details>
 
 <details>
-<summary><b>Gradio Web Utility</b></summary>
+<summary>Multiple Output Audio Formats</summary>
+
+- mp3
+- wav
+- opus
+- flac
+- aac
+- pcm
+
+<p align="center">
+<img src="examples/benchmarks/format_comparison.png" width="80%" alt="Audio Format Comparison" style="border: 2px solid #333; padding: 10px;">
+</p>
+
+</details>
+
+<details>
+<summary>Gradio Web Utility</summary>
 
 Access the interactive web UI at http://localhost:7860 after starting the service. Features include:
 - Voice/format/speed selection
@@ -141,9 +160,9 @@ If you only want the API, just comment out everything in the docker-compose.yml
 Currently, voices created via the API are accessible here, but voice combination/creation has not yet been added
 </details>
 
+## Processing Details
 <details>
-<summary><b>Performance Benchmarks</b></summary>
+<summary>Performance Benchmarks</summary>
 
 Benchmarking was performed on generation via the local API using text lengths up to feature-length books (~1.5 hours output), measuring processing time and realtime factor. Tests were run on:
 - Windows 11 Home w/ WSL2
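The realtime factor reported in these benchmarks is, as commonly defined for TTS, seconds of audio produced per second of wall-clock processing (higher is faster). A quick illustration of the arithmetic with made-up numbers:

```python
# Realtime factor = seconds of audio produced per second of processing.
audio_seconds = 5400       # ~1.5 hour output (example value, not a measurement)
processing_seconds = 600   # wall-clock generation time (example value)
realtime_factor = audio_seconds / processing_seconds
print(f"{realtime_factor:.1f}x realtime")  # 9.0x realtime
```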
@@ -163,7 +182,7 @@ Key Performance Metrics:
 - Average Processing Rate: 137.67 tokens/second (cl100k_base)
 </details>
 <details>
-<summary><b>GPU Vs. CPU<b></summary>
+<summary>GPU Vs. CPU</summary>
 
 ```bash
 # GPU: Requires NVIDIA GPU with CUDA 12.1 support
@@ -172,35 +191,29 @@ docker compose up --build
 
 # CPU: ~10x slower than GPU inference
 docker compose -f docker-compose.cpu.yml up --build
 ```
-</details>
-<details>
-<summary><b>Features</b></summary>
-
-- OpenAI-compatible API endpoints (with optional Gradio Web UI)
-- GPU-accelerated inference (if desired)
-- Multiple audio formats: mp3, wav, opus, flac, (aac & pcm not implemented)
-- Natural Boundary Detection:
-  - Automatically splits and stitches at sentence boundaries to reduce artifacts and maintain performance
-- Voice Combination:
-  - Averages model weights of any existing voicepacks
-  - Saves generated voicepacks for future use
-
 *Note: CPU Inference is currently a very basic implementation, and not heavily tested*
 
 </details>
 
+<details>
+<summary>Natural Boundary Detection</summary>
+
+- Automatically splits and stitches at sentence boundaries
+- Helps to reduce artifacts and allow long-form processing, as the base model is currently configured for only about 30s of output
+</details>
+
+## Model and License
+
 <details open>
-<summary><b>Model</b></summary>
+<summary>Model</summary>
 
 This API uses the [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) model from HuggingFace.
 
 Visit the model page for more details about training, architecture, and capabilities. I have no affiliation with any of their work, and produced this wrapper for ease of use and personal projects.
 </details>
 
 <details>
-<summary><b>License</b></summary>
+<summary>License</summary>
 
 This project is licensed under the Apache License 2.0 - see below for details:
 
 - The Kokoro model weights are licensed under Apache 2.0 (see [model page](https://huggingface.co/hexgrad/Kokoro-82M))

@@ -209,3 +222,6 @@ This project is licensed under the Apache License 2.0 - see below for details:
 
 The full Apache 2.0 license text can be found at: https://www.apache.org/licenses/LICENSE-2.0
 </details>
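The Natural Boundary Detection section above describes splitting input at sentence boundaries before generation, then stitching the resulting chunks. A rough sketch of that style of chunking (hypothetical helper, not the service's actual implementation):

```python
# Illustrative sketch of sentence-boundary chunking (hypothetical helper,
# not the service's actual splitting logic).
import re

def chunk_sentences(text: str, max_chars: int = 300) -> list[str]:
    """Greedily pack whole sentences into chunks below a size budget."""
    sentences = re.split(r"(?<=[.!?])\s+", text.strip())
    chunks, current = [], ""
    for sentence in sentences:
        if current and len(current) + len(sentence) + 1 > max_chars:
            chunks.append(current)
            current = sentence
        else:
            current = f"{current} {sentence}".strip()
    if current:
        chunks.append(current)
    return chunks
```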
api/src/services/audio.py

@@ -4,7 +4,6 @@ from io import BytesIO
 
 import numpy as np
 import soundfile as sf
-import scipy.io.wavfile as wavfile
 from loguru import logger

@@ -20,7 +19,7 @@ class AudioService:
         Args:
             audio_data: Numpy array of audio samples
             sample_rate: Sample rate of the audio
-            output_format: Target format (wav, mp3, etc.)
+            output_format: Target format (wav, mp3, opus, flac, pcm)
 
         Returns:
             Bytes of the converted audio

@@ -30,46 +29,36 @@ class AudioService:
         try:
             if output_format == "wav":
                 logger.info("Writing to WAV format...")
-                wavfile.write(buffer, sample_rate, audio_data)
-                return buffer.getvalue()
+                # Ensure audio_data is in int16 format for WAV
+                audio_data_wav = (
+                    audio_data / np.abs(audio_data).max() * np.iinfo(np.int16).max
+                ).astype(np.int16)  # Normalize
+                sf.write(buffer, audio_data_wav, sample_rate, format="WAV")
             elif output_format == "mp3":
-                # For MP3, we need to convert to WAV first
                 logger.info("Converting to MP3 format...")
-                wav_buffer = BytesIO()
-                wavfile.write(wav_buffer, sample_rate, audio_data)
-                wav_buffer.seek(0)
-
-                # Convert WAV to MP3 using soundfile
-                buffer = BytesIO()
-                sf.write(buffer, audio_data, sample_rate, format="mp3")
-                return buffer.getvalue()
-
+                # soundfile can write MP3 if ffmpeg or libsox is installed
+                sf.write(buffer, audio_data, sample_rate, format="MP3")
             elif output_format == "opus":
                 logger.info("Converting to Opus format...")
-                sf.write(buffer, audio_data, sample_rate, format="ogg", subtype="opus")
-                return buffer.getvalue()
-
+                sf.write(buffer, audio_data, sample_rate, format="OGG", subtype="OPUS")
             elif output_format == "flac":
                 logger.info("Converting to FLAC format...")
-                sf.write(buffer, audio_data, sample_rate, format="flac")
-                return buffer.getvalue()
-
-            elif output_format == "aac":
-                raise ValueError(
-                    "AAC format is not currently supported. Please use wav, mp3, opus, or flac."
-                )
+                sf.write(buffer, audio_data, sample_rate, format="FLAC")
             elif output_format == "pcm":
-                raise ValueError(
-                    "PCM format is not currently supported. Please use wav, mp3, opus, or flac."
-                )
+                logger.info("Extracting PCM data...")
+                # Ensure audio_data is in int16 format for PCM
+                audio_data_pcm = (
+                    audio_data / np.abs(audio_data).max() * np.iinfo(np.int16).max
+                ).astype(np.int16)  # Normalize
+                buffer.write(audio_data_pcm.tobytes())
             else:
                 raise ValueError(
-                    f"Format {output_format} not supported. Supported formats are: wav, mp3, opus, flac."
+                    f"Format {output_format} not supported. Supported formats are: wav, mp3, opus, flac, pcm."
                 )
 
+            buffer.seek(0)
+            return buffer.getvalue()
+
         except Exception as e:
             logger.error(f"Error converting audio to {output_format}: {str(e)}")
             raise ValueError(f"Failed to convert audio to {output_format}: {str(e)}")
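Both new branches normalize with the same expression, `audio_data / np.abs(audio_data).max() * np.iinfo(np.int16).max`, which rescales whatever float range comes out of the model to full int16 swing. The arithmetic in isolation, with example values:

```python
# Standalone illustration of the int16 normalization (example values only).
import numpy as np

audio = np.array([-0.5, 0.25, 1.0], dtype=np.float32)  # float model output
scaled = (audio / np.abs(audio).max() * np.iinfo(np.int16).max).astype(np.int16)
print(scaled)  # [-16383   8191  32767]: the peak sample maps to int16 max
```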
api/tests/test_audio_service.py

@@ -51,15 +51,19 @@ def test_convert_to_flac(sample_audio):
 def test_convert_to_aac_raises_error(sample_audio):
     """Test that converting to AAC raises an error"""
     audio_data, sample_rate = sample_audio
-    with pytest.raises(ValueError, match="AAC format is not currently supported"):
+    with pytest.raises(
+        ValueError,
+        match="Format aac not supported. Supported formats are: wav, mp3, opus, flac, pcm.",
+    ):
         AudioService.convert_audio(audio_data, sample_rate, "aac")
 
 
-def test_convert_to_pcm_raises_error(sample_audio):
-    """Test that converting to PCM raises an error"""
+def test_convert_to_pcm(sample_audio):
+    """Test converting to PCM format"""
     audio_data, sample_rate = sample_audio
-    with pytest.raises(ValueError, match="PCM format is not currently supported"):
-        AudioService.convert_audio(audio_data, sample_rate, "pcm")
+    result = AudioService.convert_audio(audio_data, sample_rate, "pcm")
+    assert isinstance(result, bytes)
+    assert len(result) > 0
 
 
 def test_convert_to_invalid_format_raises_error(sample_audio):

@@ -67,3 +71,51 @@ def test_convert_to_invalid_format_raises_error(sample_audio):
     audio_data, sample_rate = sample_audio
     with pytest.raises(ValueError, match="Format invalid not supported"):
         AudioService.convert_audio(audio_data, sample_rate, "invalid")
+
+
+def test_normalization_wav(sample_audio):
+    """Test that WAV output is properly normalized to int16 range"""
+    audio_data, sample_rate = sample_audio
+    # Create audio data outside int16 range
+    large_audio = audio_data * 1e5
+    result = AudioService.convert_audio(large_audio, sample_rate, "wav")
+    assert isinstance(result, bytes)
+    assert len(result) > 0
+
+
+def test_normalization_pcm(sample_audio):
+    """Test that PCM output is properly normalized to int16 range"""
+    audio_data, sample_rate = sample_audio
+    # Create audio data outside int16 range
+    large_audio = audio_data * 1e5
+    result = AudioService.convert_audio(large_audio, sample_rate, "pcm")
+    assert isinstance(result, bytes)
+    assert len(result) > 0
+
+
+def test_invalid_audio_data():
+    """Test handling of invalid audio data"""
+    invalid_audio = np.array([])  # Empty array
+    sample_rate = 24000
+    with pytest.raises(ValueError):
+        AudioService.convert_audio(invalid_audio, sample_rate, "wav")
+
+
+def test_different_sample_rates(sample_audio):
+    """Test converting audio with different sample rates"""
+    audio_data, _ = sample_audio
+    sample_rates = [8000, 16000, 44100, 48000]
+
+    for rate in sample_rates:
+        result = AudioService.convert_audio(audio_data, rate, "wav")
+        assert isinstance(result, bytes)
+        assert len(result) > 0
+
+
+def test_buffer_position_after_conversion(sample_audio):
+    """Test that buffer position is reset after writing"""
+    audio_data, sample_rate = sample_audio
+    result = AudioService.convert_audio(audio_data, sample_rate, "wav")
+    # Convert again to ensure buffer was properly reset
+    result2 = AudioService.convert_audio(audio_data, sample_rate, "wav")
+    assert len(result) == len(result2)
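Raw PCM responses carry no container header, which is why the tests above only assert on byte length rather than decoding them. A small sketch of wrapping such bytes into a playable WAV, using only the standard library (assumes 16-bit mono at 24 kHz, the parameters assumed elsewhere in this commit):

```python
# Sketch: wrap headerless PCM bytes in a WAV container for playback.
import wave

def pcm_to_wav(pcm_bytes: bytes, out_path: str, sample_rate: int = 24000) -> None:
    with wave.open(out_path, "wb") as wav_file:
        wav_file.setnchannels(1)       # mono
        wav_file.setsampwidth(2)       # 16-bit samples
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(pcm_bytes)
```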
api/tests/test_tts_service.py

@@ -4,6 +4,7 @@ import os
 from unittest.mock import MagicMock, call, patch
 
 import numpy as np
+import torch
 import pytest
 
 from api.src.services.tts import TTSModel, TTSService

@@ -119,6 +120,78 @@ def test_generate_audio_no_chunks(
     tts_service._generate_audio("Test text", "af", 1.0)
 
 
+@patch("torch.load")
+@patch("torch.save")
+@patch("torch.stack")
+@patch("torch.mean")
+@patch("os.path.exists")
+def test_combine_voices(
+    mock_exists, mock_mean, mock_stack, mock_save, mock_load, tts_service
+):
+    """Test combining multiple voices"""
+    # Setup mocks
+    mock_exists.return_value = True
+    mock_load.return_value = torch.tensor([1.0, 2.0])
+    mock_stack.return_value = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
+    mock_mean.return_value = torch.tensor([2.0, 3.0])
+
+    # Test combining two voices
+    result = tts_service.combine_voices(["voice1", "voice2"])
+
+    assert result == "voice1_voice2"
+    mock_stack.assert_called_once()
+    mock_mean.assert_called_once()
+    mock_save.assert_called_once()
+
+
+def test_combine_voices_invalid_input(tts_service):
+    """Test combining voices with invalid input"""
+    # Test with empty list
+    with pytest.raises(ValueError, match="At least 2 voices are required"):
+        tts_service.combine_voices([])
+
+    # Test with single voice
+    with pytest.raises(ValueError, match="At least 2 voices are required"):
+        tts_service.combine_voices(["voice1"])
+
+
+@patch("os.makedirs")
+@patch("os.path.exists")
+@patch("os.listdir")
+@patch("torch.load")
+@patch("torch.save")
+@patch("os.path.join")
+def test_ensure_voices(
+    mock_join,
+    mock_save,
+    mock_load,
+    mock_listdir,
+    mock_exists,
+    mock_makedirs,
+    tts_service,
+):
+    """Test voice directory initialization"""
+    # Setup mocks
+    mock_exists.side_effect = [
+        True,
+        False,
+        False,
+    ]  # base_dir exists, voice files don't exist
+    mock_listdir.return_value = ["voice1.pt", "voice2.pt"]
+    mock_load.return_value = MagicMock()
+    mock_join.return_value = "/fake/path"
+
+    # Test voice directory initialization
+    tts_service._ensure_voices()
+
+    # Verify directory was created
+    mock_makedirs.assert_called_once()
+
+    # Verify voices were loaded and saved
+    assert mock_load.call_count == len(mock_listdir.return_value)
+    assert mock_save.call_count == len(mock_listdir.return_value)
+
+
 @patch("api.src.services.tts.TTSModel.get_instance")
 @patch("os.path.exists")
 @patch("api.src.services.tts.normalize_text")

@@ -236,7 +309,6 @@ def test_generate_audio_without_stitching(
         "Test text", "af", 1.0, stitch_long_output=False
     )
     assert isinstance(audio, np.ndarray)
-    assert isinstance(processing_time, float)
     assert len(audio) > 0
     mock_generate.assert_called_once()
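A reading aid for test_ensure_voices above: unittest.mock applies stacked @patch decorators bottom-up, so the mock arguments arrive in reverse decorator order (os.path.join, the lowest decorator, becomes the first parameter). A tiny self-contained illustration:

```python
# Stacked @patch decorators are applied bottom-up: the lowest decorator
# produces the first mock argument.
from unittest.mock import patch

@patch("os.path.exists")  # applied last -> second argument
@patch("os.makedirs")     # applied first -> first argument
def demo(mock_makedirs, mock_exists):
    print(mock_makedirs, mock_exists)

demo()  # both names are MagicMock instances inside the call
```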
examples/benchmarks/format_comparison.png: new binary file, 764 KiB (not shown)
examples/test_audio_formats.py: new file, 284 lines

"""Test script to generate and analyze different audio formats"""

import os
import time
from pathlib import Path

import numpy as np
import openai
import requests
import soundfile as sf
import matplotlib.pyplot as plt
from scipy.io import wavfile

SAMPLE_TEXT = """
That is the germ of my great discovery. But you are wrong to say that we cannot move about in Time.
"""

# Configure OpenAI client
client = openai.OpenAI(
    timeout=60,
    api_key="notneeded",  # API key not required for our endpoint
    base_url="http://localhost:8880/v1",  # Point to our local server with v1 prefix
)


def setup_plot(fig, ax, title):
    """Configure plot styling"""
    # Improve grid
    ax.grid(True, linestyle="--", alpha=0.3, color="#ffffff")

    # Set title and labels with better fonts and more padding
    ax.set_title(title, pad=40, fontsize=16, fontweight="bold", color="#ffffff")
    ax.set_xlabel(ax.get_xlabel(), fontsize=14, fontweight="medium", color="#ffffff")
    ax.set_ylabel(ax.get_ylabel(), fontsize=14, fontweight="medium", color="#ffffff")

    # Improve tick labels
    ax.tick_params(labelsize=12, colors="#ffffff")

    # Style spines
    for spine in ax.spines.values():
        spine.set_color("#ffffff")
        spine.set_alpha(0.3)
        spine.set_linewidth(0.5)

    # Set background colors
    ax.set_facecolor("#1a1a2e")
    fig.patch.set_facecolor("#1a1a2e")

    return fig, ax


def plot_format_comparison(stats: list, output_dir: str):
    """Plot audio format comparison"""
    plt.style.use("dark_background")

    # Create figure with subplots
    fig = plt.figure(figsize=(18, 16))  # Taller figure to accommodate bottom legend
    fig.patch.set_facecolor("#1a1a2e")

    # Create subplot grid with balanced spacing for waveforms
    gs_waves = plt.GridSpec(
        len(stats), 1, left=0.15, right=0.85, top=0.9, bottom=0.35, hspace=0.4
    )

    # Plot waveforms for each format
    for i, stat in enumerate(stats):
        format_name = stat["format"].upper()
        try:
            # Handle PCM format differently
            if stat["format"] == "pcm":
                # Read raw PCM data (16-bit mono)
                with open(
                    os.path.join(output_dir, f"test_audio.{stat['format']}"), "rb"
                ) as f:
                    raw_data = f.read()
                data = np.frombuffer(raw_data, dtype=np.int16)
                data = data.astype(np.float32) / 32768.0  # Convert to float [-1, 1]
                sr = 24000
            else:
                # Read other formats with soundfile
                data, sr = sf.read(
                    os.path.join(output_dir, f"test_audio.{stat['format']}")
                )

            # Plot waveform
            ax = plt.subplot(gs_waves[i])
            time = np.arange(len(data)) / sr  # (shadows the time module locally; harmless here)
            plt.plot(time, data / np.max(np.abs(data)), linewidth=0.5, color="#ff2a6d")
            ax.set_xlabel("Time (seconds)")
            ax.set_ylabel("")
            ax.set_ylim(-1.1, 1.1)
            setup_plot(fig, ax, f"Waveform: {format_name}")
        except Exception as e:
            print(f"Error plotting waveform for {format_name}: {e}")

    # Colors for formats
    colors = ["#ff2a6d", "#05d9e8", "#d1f7ff", "#ff9e00", "#8c1eff"]

    # Create three subplots for metrics with more space at bottom for legend
    gs_bottom = plt.GridSpec(
        1,
        3,
        left=0.15,
        right=0.85,
        bottom=0.15,
        top=0.25,  # More bottom space for legend
        wspace=0.3,
    )

    # File Size subplot
    ax1 = plt.subplot(gs_bottom[0])
    metrics1 = [("File Size", [s["file_size_kb"] for s in stats], "KB")]

    # Duration and Gen Time subplot
    ax2 = plt.subplot(gs_bottom[1])
    metrics2 = [
        ("Duration", [s["duration_seconds"] for s in stats], "s"),
        ("Gen Time", [s["generation_time"] for s in stats], "s"),
    ]

    # Sample Rate subplot
    ax3 = plt.subplot(gs_bottom[2])
    metrics3 = [("Sample Rate", [s["sample_rate"] / 1000 for s in stats], "kHz")]

    def plot_grouped_bars(ax, metrics, show_legend=True):
        n_groups = len(metrics)
        n_formats = len(stats)
        # Use wider bars for time metrics
        bar_width = 0.175 if metrics == metrics2 else 0.1

        indices = np.arange(n_groups)

        # Get max value for y-axis scaling
        max_val = max(max(m[1]) for m in metrics)

        for i, (stat, color) in enumerate(zip(stats, colors)):
            values = [m[1][i] for m in metrics]
            # Reduce spacing between bars for time metrics
            spacing = 1.1 if metrics == metrics2 else 1.0
            offset = (i - n_formats / 2 + 0.5) * bar_width * spacing
            bars = ax.bar(
                indices + offset,
                values,
                bar_width,
                label=stat["format"].upper(),
                color=color,
                alpha=0.8,
            )

            # Add value labels on top of bars
            for bar in bars:
                height = bar.get_height()
                ax.text(
                    bar.get_x() + bar.get_width() / 2.0,
                    height,
                    f"{height:.1f}",
                    ha="center",
                    va="bottom",
                    color="white",
                    fontsize=10,
                )

        ax.set_xticks(indices)
        ax.set_xticklabels([f"{m[0]}\n({m[2]})" for m in metrics])

        # Set y-axis limits with some padding
        ax.set_ylim(0, max_val * 1.2)

        if show_legend:
            # Place legend at the bottom
            ax.legend(
                bbox_to_anchor=(1.8, -0.8),
                loc="center",
                facecolor="#1a1a2e",
                edgecolor="#ffffff",
                ncol=len(stats),
            )  # Show all formats in one row

    # Plot all three subplots with shared legend
    plot_grouped_bars(ax1, metrics1, show_legend=True)
    plot_grouped_bars(ax2, metrics2, show_legend=False)
    plot_grouped_bars(ax3, metrics3, show_legend=False)

    # Style all subplots
    setup_plot(fig, ax1, "File Size")
    setup_plot(fig, ax2, "Time Metrics")
    setup_plot(fig, ax3, "Sample Rate")

    # Add y-axis labels
    ax1.set_ylabel("Value")
    ax2.set_ylabel("Value")
    ax3.set_ylabel("Value")

    # Save the plot
    plt.savefig(os.path.join(output_dir, "format_comparison.png"), dpi=300)
    print(f"\nSaved format comparison plot to {output_dir}/format_comparison.png")


def get_audio_stats(file_path: str) -> dict:
    """Get audio file statistics"""
    file_size = os.path.getsize(file_path)
    file_size_kb = file_size / 1024  # Convert to KB

    try:
        # Try reading with soundfile first
        data, sample_rate = sf.read(file_path)
        duration = len(data) / sample_rate
        channels = 1 if len(data.shape) == 1 else data.shape[1]

        # Calculate audio statistics
        stats = {
            "format": Path(file_path).suffix[1:],
            "file_size_kb": round(file_size_kb, 2),
            "duration_seconds": round(duration, 2),
            "sample_rate": sample_rate,
            "channels": channels,
            "min_amplitude": float(np.min(data)),
            "max_amplitude": float(np.max(data)),
            "mean_amplitude": float(np.mean(np.abs(data))),
            "rms_amplitude": float(np.sqrt(np.mean(np.square(data)))),
        }
        return stats
    except Exception:
        # For PCM, read raw bytes and estimate duration
        with open(file_path, "rb") as f:
            data = f.read()
        # Assuming 16-bit PCM mono at 24kHz
        samples = len(data) // 2  # 2 bytes per sample
        duration = samples / 24000
        return {
            "format": "pcm",
            "file_size_kb": round(file_size_kb, 2),
            "duration_seconds": round(duration, 2),
            "sample_rate": 24000,
            "channels": 1,
            "note": "PCM stats are estimated from raw bytes",
        }


def main():
    """Generate and analyze audio in different formats"""
    # Create output directory
    output_dir = Path(__file__).parent / "output" / "test_formats"
    output_dir.mkdir(exist_ok=True, parents=True)

    # First generate audio in each format using the API
    voice = "af"  # Using default voice
    formats = ["wav", "mp3", "opus", "flac", "pcm"]
    stats = []

    for fmt in formats:
        output_path = output_dir / f"test_audio.{fmt}"
        print(f"\nGenerating {fmt.upper()} audio...")

        # Generate and save
        start_time = time.time()
        response = client.audio.speech.create(
            model="kokoro", voice=voice, input=SAMPLE_TEXT, response_format=fmt
        )
        generation_time = time.time() - start_time

        with open(output_path, "wb") as f:
            f.write(response.content)

        # Get stats
        file_stats = get_audio_stats(str(output_path))
        file_stats["generation_time"] = round(generation_time, 3)
        stats.append(file_stats)

    # Generate comparison plot
    plot_format_comparison(stats, str(output_dir))

    # Print detailed statistics
    print("\nDetailed Audio Statistics:")
    print("=" * 100)
    for stat in stats:
        print(f"\n{stat['format'].upper()} Format:")
        for key, value in sorted(stat.items()):
            if key not in ["format"]:  # Skip format as it's in the header
                print(f"  {key}: {value}")


if __name__ == "__main__":
    main()
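With the service running, the script can be invoked directly with `python examples/test_audio_formats.py`; per its own paths, it writes test_audio.{wav,mp3,opus,flac,pcm} and format_comparison.png under examples/output/test_formats/.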