Merge pull request #102 from remsky/v0.1.4

V0.1.4: Improved web UI streaming headers
remsky 2025-01-30 23:00:11 -07:00 committed by GitHub
commit 5ddeba26d8
132 changed files with 10327 additions and 10288 deletions


@ -15,11 +15,16 @@ jobs:
steps:
- uses: actions/checkout@v4
# Add FFmpeg installation step
- name: Install FFmpeg
# Match Dockerfile dependencies
- name: Install Dependencies
run: |
sudo apt-get update
sudo apt-get install -y ffmpeg
sudo apt-get install -y --no-install-recommends \
espeak-ng \
git \
libsndfile1 \
curl \
ffmpeg
- name: Install uv
uses: astral-sh/setup-uv@v5

.gitignore

@ -18,7 +18,8 @@ __pycache__/
*.egg
dist/
build/
*.onnx
*.pth
# Environment
# .env
.venv/
@ -38,6 +39,26 @@ ENV/
*.pth
*.tar*
# Other project files
.env
Kokoro-82M/
ui/data/
EXTERNAL_UV_DOCUMENTATION*
app
api/temp_files/
# Docker
Dockerfile*
docker-compose*
examples/ebook_test/chapter_to_audio.py
examples/ebook_test/chapters_to_audio.py
examples/ebook_test/parse_epub.py
api/src/voices/af_jadzia.pt
examples/assorted_checks/test_combinations/output/*
examples/assorted_checks/test_openai/output/*
# Audio files
examples/*.wav
examples/*.pcm
@ -46,22 +67,5 @@ examples/*.flac
examples/*.acc
examples/*.ogg
examples/speech.mp3
examples/phoneme_examples/output/example_1.wav
examples/phoneme_examples/output/example_2.wav
examples/phoneme_examples/output/example_3.wav
# Other project files
Kokoro-82M/
ui/data/
EXTERNAL_UV_DOCUMENTATION*
app
# Docker
Dockerfile*
docker-compose*
examples/assorted_checks/River_of_Teet_-_Sarah_Gailey.epub
examples/ebook_test/chapter_to_audio.py
examples/ebook_test/chapters_to_audio.py
examples/ebook_test/parse_epub.py
examples/ebook_test/River_of_Teet_-_Sarah_Gailey.epub
examples/ebook_test/River_of_Teet_-_Sarah_Gailey.txt
examples/phoneme_examples/output/*.wav
examples/assorted_checks/benchmarks/output_audio/*


@ -2,6 +2,50 @@
Notable changes to this project will be documented in this file.
## [v0.1.4] - 2025-01-30
### Added
- Smart Chunking System:
- New text_processor with smart_split for improved sentence boundary detection
- Dynamically adjusts chunk sizes based on sentence structure, using phoneme/token information in an initial pass
- Should never exceed the 510-token limit per chunk, while preserving natural cadence
- Web UI added (will replace Gradio):
- Integrated streaming with tempfile generation
- Download links available in X-Download-Path header (see the sketch after this list)
- Configurable cleanup triggers for temp files
- Debug Endpoints:
- /debug/threads for thread information and stack traces
- /debug/storage for temp file and output directory monitoring
- /debug/system for system resource information
- /debug/session_pools for ONNX/CUDA session status
- Automated Model Management:
- Auto-download from releases page
- Included download scripts for manual installation
- Pre-packaged voice models in repository
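A minimal client sketch of the X-Download-Path behavior described above, assuming the header is returned on the OpenAI-compatible streaming speech route documented in the README below; whether this particular call surfaces the header is an assumption:

```python
import requests

BASE = "http://localhost:8880"

# Stream audio and capture the X-Download-Path header that the web UI
# uses to build its download links (assumption: this route sets the header).
with requests.post(
    f"{BASE}/v1/audio/speech",
    json={
        "model": "kokoro",
        "voice": "af_bella",
        "input": "Hello world!",
        "response_format": "mp3",
    },
    stream=True,
) as resp:
    resp.raise_for_status()
    download_path = resp.headers.get("X-Download-Path")  # None if not exposed here
    with open("speech.mp3", "wb") as f:
        for chunk in resp.iter_content(chunk_size=8192):
            f.write(chunk)

print("Server-side download path:", download_path)
```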
### Changed
- Significant architectural improvements:
- Multi-model architecture support
- Enhanced concurrency handling
- Improved streaming header management
- Better resource/session pool management
## [v0.1.2] - 2025-01-23
### Structural Improvements
- Models can be manually downloaded and placed in api/src/models, or fetched with the included script
- TTSGPU/TPSCPU/STTSService classes replaced with a ModelManager service
- CPU/GPU support for each of ONNX/PyTorch (Note: only PyTorch GPU and ONNX CPU/GPU have been tested)
- Should make it easier to add new models or architectures in a more modular way
- Converted a number of internal processes to async handling to improve concurrency
- Improved separation of concerns toward a plug-in, modular structure, making PRs and new features easier
### Web UI (test release)
- An integrated simple web UI has been added on the FastAPI server directly
- This can be disabled via core/config.py or ENV variables if desired.
- Simplifies deployments, utility testing, aesthetics, etc.
- Looking to deprecate/collaborate/hand off the Gradio UI
## [v0.1.0] - 2025-01-13
### Changed
- Major Docker improvements:

@ -1 +0,0 @@
Subproject commit c97b7bbc3e60f447383c79b2f94fee861ff156ac


@ -1,70 +0,0 @@
# UV Setup
Deprecated notes for myself
## Structure
```
docker/
├── cpu/
│ ├── pyproject.toml # CPU deps (torch CPU)
│ └── requirements.lock # CPU lockfile
├── gpu/
│ ├── pyproject.toml # GPU deps (torch CUDA)
│ └── requirements.lock # GPU lockfile
└── shared/
└── pyproject.toml # Common deps
```
## Regenerate Lock Files
### CPU
```bash
cd docker/cpu
uv pip compile pyproject.toml ../shared/pyproject.toml --output-file requirements.lock
```
### GPU
```bash
cd docker/gpu
uv pip compile pyproject.toml ../shared/pyproject.toml --output-file requirements.lock
```
## Local Dev Setup
### CPU
```bash
cd docker/cpu
uv venv
.venv\Scripts\activate # Windows
uv pip sync requirements.lock
```
### GPU
```bash
cd docker/gpu
uv venv
.venv\Scripts\activate # Windows
uv pip sync requirements.lock --extra-index-url https://download.pytorch.org/whl/cu121 --index-strategy unsafe-best-match
```
### Run Server
```bash
# From project root with venv active:
uvicorn api.src.main:app --reload
```
## Docker
### CPU
```bash
cd docker/cpu
docker compose up
```
### GPU
```bash
cd docker/gpu
docker compose up
```
## Known Issues
- Module imports: Run server from project root
- PyTorch CUDA: Always use --extra-index-url and --index-strategy for GPU env

README.md

@ -3,74 +3,83 @@
</p>
# <sub><sub>_`FastKoko`_ </sub></sub>
[![Tests](https://img.shields.io/badge/tests-117%20passed-darkgreen)]()
[![Coverage](https://img.shields.io/badge/coverage-60%25-grey)]()
[![Tests](https://img.shields.io/badge/tests-100%20passed-darkgreen)]()
[![Coverage](https://img.shields.io/badge/coverage-49%25-grey)]()
[![Tested at Model Commit](https://img.shields.io/badge/last--tested--model--commit-a67f113-blue)](https://huggingface.co/hexgrad/Kokoro-82M/tree/c3b0d86e2a980e027ef71c28819ea02e351c2667) [![Try on Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Try%20on-Spaces-blue)](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero)
> Support for Kokoro-82M v1.0 coming very soon!
Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model
- OpenAI-compatible Speech endpoint, with inline voice combination functionality
- NVIDIA GPU accelerated or CPU Onnx inference
- OpenAI-compatible Speech endpoint, with inline voice combination and voice/model name mappings for strict client systems
- NVIDIA GPU accelerated or CPU inference (ONNX, PyTorch)
- very fast generation time
- 35x-100x+ real time speed via 4060Ti+
- 5x+ real time speed via M3 Pro CPU
- streaming support w/ variable chunking to control latency & artifacts
- phoneme, simple audio generation web ui utility
- Runs on an 80 MB-300 MB model (CUDA container adds ~5 GB on disk due to drivers)
- ~35x-100x+ real time speed via 4060Ti+
- ~5x+ real time speed via M3 Pro CPU
- streaming support & tempfile generation
- phoneme based dev endpoints
- (new) Integrated web UI on localhost:8880/web
- (new) Debug endpoints for monitoring threads, storage, and session pools
> [!Tip]
> You can try the new beta version from the `v0.1.2-pre` branch now:
<table>
<tr>
<td>
<img src="https://github.com/user-attachments/assets/440162eb-1918-4999-ab2b-e2730990efd0" width="100%" alt="Voice Analysis Comparison" style="border: 2px solid #333; padding: 5px;">
</td>
<td>
<ul>
<li>Integrated web UI (on localhost:8880/web)</li>
<li>Better concurrency handling, baked in models and voices</li>
<li>Voice name/model mappings to OAI standard</li>
<pre> # with:
docker run -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-cpu:latest # CPU
docker run --gpus all -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-gpu:latest # Nvidia GPU
</pre>
</ul>
</td>
</tr>
</table>
<details open>
<summary>Quick Start</summary>
## Get Started
The service can be accessed through either the API endpoints or the Gradio web interface.
<details >
<summary>Quickest Start (docker run)</summary>
Pre-built images are available, with arm/multi-arch support and baked-in models.
Refer to core/config.py for a full list of variables that can be managed via the environment.
```bash
docker run -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-cpu:v0.1.4 # CPU, or:
docker run --gpus all -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-gpu:v0.1.4 # NVIDIA GPU
```
Once running, access:
- API Documentation: http://localhost:8880/docs
- Web Interface: http://localhost:8880/web
<div align="center" style="display: flex; justify-content: center; gap: 20px;">
<img src="assets/docs-screenshot.png" width="48%" alt="API Documentation" style="border: 2px solid #333; padding: 10px;">
<img src="assets/webui-screenshot.png" width="48%" alt="Web UI Screenshot" style="border: 2px solid #333; padding: 10px;">
</div>
</details>
<details>
<summary>Quick Start (docker compose) </summary>
1. Install prerequisites, and start the service using Docker Compose (Full setup including UI):
- Install [Docker Desktop](https://www.docker.com/products/docker-desktop/)
- Install [Docker](https://www.docker.com/products/docker-desktop/)
- Clone the repository:
```bash
git clone https://github.com/remsky/Kokoro-FastAPI.git
cd Kokoro-FastAPI
# * Switch to stable branch if any issues *
git checkout v0.0.5post1-stable
cd docker/gpu # OR
# cd docker/cpu # Run this or the above
docker compose up --build
# if you are missing any models, run:
# python ../scripts/download_model.py --type pth # for GPU
# python ../scripts/download_model.py --type onnx # for CPU
```
Or directly via UV:
```bash
./start-cpu.sh
./start-gpu.sh
```
Once started:
- The API will be available at http://localhost:8880
- The UI can be accessed at http://localhost:7860
__Or__ running the API alone using Docker (model + voice packs baked in) (Most Recent):
```bash
docker run -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-cpu:v0.1.0post1 # CPU
docker run --gpus all -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-gpu:v0.1.0post1 # Nvidia GPU
```
4. Run locally as an OpenAI-Compatible Speech Endpoint
- The *Web UI* can be tested at http://localhost:8880/web
- The Gradio UI (deprecating) can be accessed at http://localhost:7860
2. Run locally as an OpenAI-Compatible Speech Endpoint
```python
from openai import OpenAI
client = OpenAI(
@ -87,12 +96,49 @@ The service can be accessed through either the API endpoints or the Gradio web i
response.stream_to_file("output.mp3")
```
</details>
<summary>Direct Run (via uv) </summary>
or visit http://localhost:7860
<p align="center">
<img src="ui\GradioScreenShot.png" width="80%" alt="Voice Analysis Comparison" style="border: 2px solid #333; padding: 10px;">
</p>
1. Install prerequisites:
- Install [astral-uv](https://docs.astral.sh/uv/)
- Clone the repository:
```bash
git clone https://github.com/remsky/Kokoro-FastAPI.git
cd Kokoro-FastAPI
# if you are missing any models, run:
# python ../scripts/download_model.py --type pth # for GPU
# python ../scripts/download_model.py --type onnx # for CPU
```
Start directly via UV (with hot-reload):
```bash
./start-cpu.sh # or:
./start-gpu.sh
```
Once started:
- The API will be available at http://localhost:8880
- The *Web UI* can be tested at http://localhost:8880/web
- The Gradio UI (deprecating) can be accessed at http://localhost:7860
2. Run locally as an OpenAI-Compatible Speech Endpoint
```python
from openai import OpenAI
client = OpenAI(
base_url="http://localhost:8880/v1",
api_key="not-needed"
)
with client.audio.speech.with_streaming_response.create(
model="kokoro",
voice="af_sky+af_bella", #single or multiple voicepack combo
input="Hello world!",
response_format="mp3"
) as response:
response.stream_to_file("output.mp3")
```
</details>
## Features
@ -104,8 +150,8 @@ The service can be accessed through either the API endpoints or the Gradio web i
from openai import OpenAI
client = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed")
response = client.audio.speech.create(
model="kokoro", # Not used but required for compatibility, also accepts library defaults
voice="af_bella+af_sky",
model="kokoro",
voice="af_bella+af_sky", # see /api/src/core/openai_mappings.json to customize
input="Hello world!",
response_format="mp3"
)
@ -124,7 +170,7 @@ voices = response.json()["voices"]
response = requests.post(
"http://localhost:8880/v1/audio/speech",
json={
"model": "kokoro", # Not used but required for compatibility
"model": "kokoro",
"input": "Hello world!",
"voice": "af_bella",
"response_format": "mp3", # Supported: mp3, wav, opus, flac
@ -207,13 +253,12 @@ If you only want the API, just comment out everything in the docker-compose.yml
Currently, voices created via the API are accessible here, but voice combination/creation has not yet been added
Running the UI Docker Service
Running the UI Docker Service [deprecating]
- If you only want to run the Gradio web interface separately and connect it to an existing API service:
```bash
docker run -p 7860:7860 \
-e API_HOST=<api-hostname-or-ip> \
-e API_PORT=8880 \
ghcr.io/remsky/kokoro-fastapi-ui:v0.1.0
```
- Replace `<api-hostname-or-ip>` with:
@ -232,7 +277,7 @@ environment:
When running the Docker image directly:
```bash
docker run -p 7860:7860 -e DISABLE_LOCAL_SAVING=true ghcr.io/remsky/kokoro-fastapi-ui:latest
docker run -p 7860:7860 -e DISABLE_LOCAL_SAVING=true ghcr.io/remsky/kokoro-fastapi-ui:v0.1.4
```
</details>
@ -243,7 +288,7 @@ docker run -p 7860:7860 -e DISABLE_LOCAL_SAVING=true ghcr.io/remsky/kokoro-fasta
# OpenAI-compatible streaming
from openai import OpenAI
client = OpenAI(
base_url="http://localhost:8880", api_key="not-needed")
base_url="http://localhost:8880/v1", api_key="not-needed")
# Stream to file
with client.audio.speech.with_streaming_response.create(
@ -325,17 +370,17 @@ Benchmarking was performed on generation via the local API using text lengths up
</p>
Key Performance Metrics:
- Realtime Speed: Ranges between 25-50x (generation time to output audio length)
- Realtime Speed: Ranges between 35x-100x (generation time to output audio length)
- Average Processing Rate: 137.67 tokens/second (cl100k_base)
</details>
<details>
<summary>GPU Vs. CPU</summary>
```bash
# GPU: Requires NVIDIA GPU with CUDA 12.1 support (~35x realtime speed)
# GPU: Requires NVIDIA GPU with CUDA 12.1 support (~35x-100x realtime speed)
docker compose up --build
# CPU: ONNX optimized inference (~2.4x realtime speed)
# CPU: ONNX optimized inference (~5x+ realtime speed on M3 Pro)
docker compose -f docker-compose.cpu.yml up --build
```
*Note: Overall speed may have decreased somewhat with the structural changes to accommodate streaming. Looking into it.*
@ -355,36 +400,61 @@ Convert text to phonemes and/or generate audio directly from phonemes:
```python
import requests
# Convert text to phonemes
response = requests.post(
"http://localhost:8880/dev/phonemize",
json={
"text": "Hello world!",
"language": "a" # "a" for American English
}
)
result = response.json()
phonemes = result["phonemes"] # Phoneme string e.g ðɪs ɪz ˈoʊnli ɐ tˈɛst
tokens = result["tokens"] # Token IDs including start/end tokens
def get_phonemes(text: str, language: str = "a"):
"""Get phonemes and tokens for input text"""
response = requests.post(
"http://localhost:8880/dev/phonemize",
json={"text": text, "language": language} # "a" for American English
)
response.raise_for_status()
result = response.json()
return result["phonemes"], result["tokens"]
# Generate audio from phonemes
response = requests.post(
"http://localhost:8880/dev/generate_from_phonemes",
json={
"phonemes": phonemes,
"voice": "af_bella",
"speed": 1.0
}
)
def generate_audio_from_phonemes(phonemes: str, voice: str = "af_bella"):
"""Generate audio from phonemes"""
response = requests.post(
"http://localhost:8880/dev/generate_from_phonemes",
json={"phonemes": phonemes, "voice": voice},
headers={"Accept": "audio/wav"}
)
if response.status_code != 200:
print(f"Error: {response.text}")
return None
return response.content
# Save WAV audio
with open("speech.wav", "wb") as f:
f.write(response.content)
# Example usage
text = "Hello world!"
try:
# Convert text to phonemes
phonemes, tokens = get_phonemes(text)
print(f"Phonemes: {phonemes}") # e.g. ðɪs ɪz ˈoʊnli ɐ tˈɛst
print(f"Tokens: {tokens}") # Token IDs including start/end tokens
# Generate and save audio
if audio_bytes := generate_audio_from_phonemes(phonemes):
with open("speech.wav", "wb") as f:
f.write(audio_bytes)
print(f"Generated {len(audio_bytes)} bytes of audio")
except Exception as e:
print(f"Error: {e}")
```
See `examples/phoneme_examples/generate_phonemes.py` for a sample script.
</details>
<details>
<summary>Debug Endpoints</summary>
Monitor system state and resource usage with these endpoints:
- `/debug/threads` - Get thread information and stack traces
- `/debug/storage` - Monitor temp file and output directory usage
- `/debug/system` - Get system information (CPU, memory, GPU)
- `/debug/session_pools` - View ONNX session and CUDA stream status
Useful for debugging resource exhaustion or performance issues; a minimal polling sketch follows.
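A small sketch that polls the endpoints listed above (the response payloads are not documented here, so the example only prints the raw JSON):

```python
import requests

BASE = "http://localhost:8880"

# Print whatever JSON each debug endpoint returns.
for route in ("/debug/threads", "/debug/storage", "/debug/system", "/debug/session_pools"):
    resp = requests.get(f"{BASE}{route}", timeout=10)
    resp.raise_for_status()
    print(route, "->", resp.json())
```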
</details>
## Known Issues
<details>


@ -337,11 +337,13 @@ def recursive_munch(d):
else:
return d
def build_model(path, device):
async def build_model(path, device):
from ..core.paths import load_json, load_model_weights
config = Path(__file__).parent / 'config.json'
assert config.exists(), f'Config path incorrect: config.json not found at {config}'
with open(config, 'r') as r:
args = recursive_munch(json.load(r))
args = recursive_munch(await load_json(config))
assert args.decoder.type == 'istftnet', f'Unknown decoder type: {args.decoder.type}'
decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
@ -365,7 +367,8 @@ def build_model(path, device):
decoder=decoder.to(device).eval(),
text_encoder=text_encoder.to(device).eval(),
)
for key, state_dict in torch.load(path, map_location='cpu', weights_only=True)['net'].items():
weights = await load_model_weights(path, device=device)
for key, state_dict in weights['net'].items():
assert key in model, key
try:
model[key].load_state_dict(state_dict)


@ -9,25 +9,34 @@ class Settings(BaseSettings):
host: str = "0.0.0.0"
port: int = 8880
# TTS Settings
# Application Settings
output_dir: str = "output"
output_dir_size_limit_mb: float = 500.0 # Maximum size of output directory in MB
default_voice: str = "af"
model_dir: str = "/app/models" # Base directory for model files
pytorch_model_path: str = "kokoro-v0_19.pth"
onnx_model_path: str = "kokoro-v0_19.onnx"
voices_dir: str = "voices"
use_gpu: bool = True # Whether to use GPU acceleration if available
use_onnx: bool = False # Whether to use ONNX runtime
allow_local_voice_saving: bool = False # Whether to allow saving combined voices locally
# Container absolute paths
model_dir: str = "/app/api/src/models" # Absolute path in container
voices_dir: str = "/app/api/src/voices" # Absolute path in container
# Audio Settings
sample_rate: int = 24000
max_chunk_size: int = 300 # Maximum size of text chunks for processing
max_chunk_size: int = 400 # Maximum size of text chunks for processing
gap_trim_ms: int = 250 # Amount to trim from streaming chunk ends in milliseconds
# ONNX Optimization Settings
onnx_num_threads: int = 4 # Number of threads for intra-op parallelism
onnx_inter_op_threads: int = 4 # Number of threads for inter-op parallelism
onnx_execution_mode: str = "parallel" # parallel or sequential
onnx_optimization_level: str = "all" # all, basic, or disabled
onnx_memory_pattern: bool = True # Enable memory pattern optimization
onnx_arena_extend_strategy: str = "kNextPowerOfTwo" # Memory allocation strategy
# Web Player Settings
enable_web_player: bool = True # Whether to serve the web player UI
web_player_path: str = "web" # Path to web player static files
cors_origins: list[str] = ["*"] # CORS origins for web player
cors_enabled: bool = True # Whether to enable CORS
# Temp File Settings for WEB Ui
temp_file_dir: str = "api/temp_files" # Directory for temporary audio files (relative to project root)
max_temp_dir_size_mb: int = 2048 # Maximum size of temp directory (2GB)
max_temp_dir_age_hours: int = 1 # Remove temp files older than 1 hour
max_temp_dir_count: int = 3 # Maximum number of temp files to keep
class Config:
env_file = ".env"
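Since Settings is a pydantic BaseSettings class with env_file = ".env", any field above can be overridden via environment variables or a .env file. A small sketch, assuming pydantic's default case-insensitive field-name mapping and the api.src.core.config module path implied by the rest of the diff:

```python
import os

# Override selected settings before the application reads them; pydantic
# BaseSettings picks up matching environment variables (and .env) on instantiation.
os.environ["USE_ONNX"] = "true"             # prefer the ONNX runtime
os.environ["MAX_CHUNK_SIZE"] = "300"        # smaller text chunks
os.environ["MAX_TEMP_DIR_SIZE_MB"] = "1024"

from api.src.core.config import settings    # module path assumed from the repo layout

print(settings.use_onnx, settings.max_chunk_size, settings.max_temp_dir_size_mb)
```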


@ -0,0 +1,113 @@
"""Model configuration schemas."""
from pydantic import BaseModel, Field
class ONNXCPUConfig(BaseModel):
"""ONNX CPU runtime configuration."""
# Session pooling
max_instances: int = Field(4, description="Maximum concurrent model instances")
instance_timeout: int = Field(60, description="Session timeout in seconds")
# Runtime settings
num_threads: int = Field(8, description="Number of threads for parallel operations")
inter_op_threads: int = Field(4, description="Number of threads for operator parallelism")
execution_mode: str = Field("parallel", description="ONNX execution mode")
optimization_level: str = Field("all", description="ONNX optimization level")
memory_pattern: bool = Field(True, description="Enable memory pattern optimization")
arena_extend_strategy: str = Field("kNextPowerOfTwo", description="Memory arena strategy")
class Config:
frozen = True
class ONNXGPUConfig(ONNXCPUConfig):
"""ONNX GPU-specific configuration."""
# CUDA settings
device_id: int = Field(0, description="CUDA device ID")
gpu_mem_limit: float = Field(0.5, description="Fraction of GPU memory to use")
cudnn_conv_algo_search: str = Field("EXHAUSTIVE", description="CuDNN convolution algorithm search")
# Stream management
cuda_streams: int = Field(2, description="Number of CUDA streams for inference")
stream_timeout: int = Field(60, description="Stream timeout in seconds")
do_copy_in_default_stream: bool = Field(True, description="Copy in default CUDA stream")
class Config:
frozen = True
class PyTorchCPUConfig(BaseModel):
"""PyTorch CPU backend configuration."""
memory_threshold: float = Field(0.8, description="Memory threshold for cleanup")
retry_on_oom: bool = Field(True, description="Whether to retry on OOM errors")
num_threads: int = Field(8, description="Number of threads for parallel operations")
pin_memory: bool = Field(True, description="Whether to pin memory for faster CPU-GPU transfer")
class Config:
frozen = True
class PyTorchGPUConfig(BaseModel):
"""PyTorch GPU backend configuration."""
device_id: int = Field(0, description="CUDA device ID")
use_triton: bool = Field(True, description="Whether to use Triton for CUDA kernels")
memory_threshold: float = Field(0.8, description="Memory threshold for cleanup")
retry_on_oom: bool = Field(True, description="Whether to retry on OOM errors")
sync_cuda: bool = Field(True, description="Whether to synchronize CUDA operations")
cuda_streams: int = Field(2, description="Number of CUDA streams for inference")
stream_timeout: int = Field(60, description="Stream timeout in seconds")
class Config:
frozen = True
class ModelConfig(BaseModel):
"""Model configuration."""
# General settings
model_type: str = Field("pytorch", description="Model type ('pytorch' or 'onnx')")
device_type: str = Field("auto", description="Device type ('cpu', 'gpu', or 'auto')")
cache_models: bool = Field(True, description="Whether to cache loaded models")
cache_voices: bool = Field(True, description="Whether to cache voice tensors")
voice_cache_size: int = Field(2, description="Maximum number of cached voices")
# Model filenames
pytorch_model_file: str = Field("kokoro-v0_19-half.pth", description="PyTorch model filename")
onnx_model_file: str = Field("kokoro-v0_19.onnx", description="ONNX model filename")
# Backend-specific configs
onnx_cpu: ONNXCPUConfig = Field(default_factory=ONNXCPUConfig)
onnx_gpu: ONNXGPUConfig = Field(default_factory=ONNXGPUConfig)
pytorch_cpu: PyTorchCPUConfig = Field(default_factory=PyTorchCPUConfig)
pytorch_gpu: PyTorchGPUConfig = Field(default_factory=PyTorchGPUConfig)
class Config:
frozen = True
def get_backend_config(self, backend_type: str):
"""Get configuration for specific backend.
Args:
backend_type: Backend type ('pytorch_cpu', 'pytorch_gpu', 'onnx_cpu', 'onnx_gpu')
Returns:
Backend-specific configuration
Raises:
ValueError: If backend type is invalid
"""
if backend_type not in {
'pytorch_cpu', 'pytorch_gpu', 'onnx_cpu', 'onnx_gpu'
}:
raise ValueError(f"Invalid backend type: {backend_type}")
return getattr(self, backend_type)
# Global instance
model_config = ModelConfig()
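A usage sketch for the configuration above, reading backend-specific values through get_backend_config (names and defaults are taken from the classes shown here; the import path is assumed from the repo layout):

```python
from api.src.core.model_config import model_config  # assumed module path

# Look up the ONNX CPU settings defined above.
onnx_cpu = model_config.get_backend_config("onnx_cpu")
print(onnx_cpu.num_threads, onnx_cpu.execution_mode)  # 8, "parallel" by default

# Unknown backend names raise ValueError.
try:
    model_config.get_backend_config("tensorrt")
except ValueError as err:
    print(err)
```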


@ -0,0 +1,18 @@
{
"models": {
"tts-1": "kokoro-v0_19",
"tts-1-hd": "kokoro-v0_19",
"kokoro": "kokoro-v0_19"
},
"voices": {
"alloy": "am_adam",
"ash": "af_nicole",
"coral": "bf_emma",
"echo": "af_bella",
"fable": "af_sarah",
"onyx": "bm_george",
"nova": "bf_isabella",
"sage": "am_michael",
"shimmer": "af_sky"
}
}
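A sketch of how an OpenAI-style model/voice name could be resolved to its Kokoro equivalent using this mapping file; the lookup below is illustrative, not the service's actual code path:

```python
import json
from pathlib import Path

# Load the mapping file shown above.
mappings = json.loads(Path("api/src/core/openai_mappings.json").read_text())

def resolve(model: str, voice: str) -> tuple:
    """Map OpenAI-style names to Kokoro names, falling back to the input."""
    return (
        mappings["models"].get(model, model),
        mappings["voices"].get(voice, voice),
    )

print(resolve("tts-1-hd", "shimmer"))  # ('kokoro-v0_19', 'af_sky')
```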

api/src/core/paths.py

@ -0,0 +1,414 @@
"""Async file and path operations."""
import io
import json
import os
import time
from pathlib import Path
from typing import List, Optional, AsyncIterator, Callable, Set, Dict, Any
import aiofiles
import aiofiles.os
import torch
from loguru import logger
from .config import settings
async def _find_file(
filename: str,
search_paths: List[str],
filter_fn: Optional[Callable[[str], bool]] = None
) -> str:
"""Find file in search paths.
Args:
filename: Name of file to find
search_paths: List of paths to search in
filter_fn: Optional function to filter files
Returns:
Absolute path to file
Raises:
RuntimeError: If file not found
"""
if os.path.isabs(filename) and await aiofiles.os.path.exists(filename):
return filename
for path in search_paths:
full_path = os.path.join(path, filename)
if await aiofiles.os.path.exists(full_path):
if filter_fn is None or filter_fn(full_path):
return full_path
raise RuntimeError(f"File not found: {filename} in paths: {search_paths}")
async def _scan_directories(
search_paths: List[str],
filter_fn: Optional[Callable[[str], bool]] = None
) -> Set[str]:
"""Scan directories for files.
Args:
search_paths: List of paths to scan
filter_fn: Optional function to filter files
Returns:
Set of matching filenames
"""
results = set()
for path in search_paths:
if not await aiofiles.os.path.exists(path):
continue
try:
# Get directory entries first
entries = await aiofiles.os.scandir(path)
# Then process entries after await completes
for entry in entries:
if filter_fn is None or filter_fn(entry.name):
results.add(entry.name)
except Exception as e:
logger.warning(f"Error scanning {path}: {e}")
return results
async def get_model_path(model_name: str) -> str:
"""Get path to model file.
Args:
model_name: Name of model file
Returns:
Absolute path to model file
Raises:
RuntimeError: If model not found
"""
# Get api directory path (two levels up from core)
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
# Construct model directory path relative to api directory
model_dir = os.path.join(api_dir, settings.model_dir)
# Ensure model directory exists
os.makedirs(model_dir, exist_ok=True)
# Search in model directory
search_paths = [model_dir]
logger.debug(f"Searching for model in path: {model_dir}")
return await _find_file(model_name, search_paths)
async def get_voice_path(voice_name: str) -> str:
"""Get path to voice file.
Args:
voice_name: Name of voice file (without .pt extension)
Returns:
Absolute path to voice file
Raises:
RuntimeError: If voice not found
"""
# Get api directory path
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
# Construct voice directory path relative to api directory
voice_dir = os.path.join(api_dir, settings.voices_dir)
# Ensure voice directory exists
os.makedirs(voice_dir, exist_ok=True)
voice_file = f"{voice_name}.pt"
# Search in voice directory
search_paths = [voice_dir]
logger.debug(f"Searching for voice in path: {voice_dir}")
return await _find_file(voice_file, search_paths)
async def list_voices() -> List[str]:
"""List available voice files.
Returns:
List of voice names (without .pt extension)
"""
# Get api directory path
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
# Construct voice directory path relative to api directory
voice_dir = os.path.join(api_dir, settings.voices_dir)
# Ensure voice directory exists
os.makedirs(voice_dir, exist_ok=True)
# Search in voice directory
search_paths = [voice_dir]
logger.debug(f"Scanning for voices in path: {voice_dir}")
def filter_voice_files(name: str) -> bool:
return name.endswith('.pt')
voices = await _scan_directories(search_paths, filter_voice_files)
return sorted([name[:-3] for name in voices]) # Remove .pt extension
async def load_voice_tensor(voice_path: str, device: str = "cpu") -> torch.Tensor:
"""Load voice tensor from file.
Args:
voice_path: Path to voice file
device: Device to load tensor to
Returns:
Voice tensor
Raises:
RuntimeError: If file cannot be read
"""
try:
async with aiofiles.open(voice_path, 'rb') as f:
data = await f.read()
return torch.load(
io.BytesIO(data),
map_location=device,
weights_only=True
)
except Exception as e:
raise RuntimeError(f"Failed to load voice tensor from {voice_path}: {e}")
async def save_voice_tensor(tensor: torch.Tensor, voice_path: str) -> None:
"""Save voice tensor to file.
Args:
tensor: Voice tensor to save
voice_path: Path to save voice file
Raises:
RuntimeError: If file cannot be written
"""
try:
buffer = io.BytesIO()
torch.save(tensor, buffer)
async with aiofiles.open(voice_path, 'wb') as f:
await f.write(buffer.getvalue())
except Exception as e:
raise RuntimeError(f"Failed to save voice tensor to {voice_path}: {e}")
async def load_json(path: str) -> dict:
"""Load JSON file asynchronously.
Args:
path: Path to JSON file
Returns:
Parsed JSON data
Raises:
RuntimeError: If file cannot be read or parsed
"""
try:
async with aiofiles.open(path, 'r', encoding='utf-8') as f:
content = await f.read()
return json.loads(content)
except Exception as e:
raise RuntimeError(f"Failed to load JSON file {path}: {e}")
async def load_model_weights(path: str, device: str = "cpu") -> dict:
"""Load model weights asynchronously.
Args:
path: Path to model file (.pth or .onnx)
device: Device to load model to
Returns:
Model weights
Raises:
RuntimeError: If file cannot be read
"""
try:
async with aiofiles.open(path, 'rb') as f:
data = await f.read()
return torch.load(
io.BytesIO(data),
map_location=device,
weights_only=True
)
except Exception as e:
raise RuntimeError(f"Failed to load model weights from {path}: {e}")
async def read_file(path: str) -> str:
"""Read text file asynchronously.
Args:
path: Path to file
Returns:
File contents as string
Raises:
RuntimeError: If file cannot be read
"""
try:
async with aiofiles.open(path, 'r', encoding='utf-8') as f:
return await f.read()
except Exception as e:
raise RuntimeError(f"Failed to read file {path}: {e}")
async def read_bytes(path: str) -> bytes:
"""Read file as bytes asynchronously.
Args:
path: Path to file
Returns:
File contents as bytes
Raises:
RuntimeError: If file cannot be read
"""
try:
async with aiofiles.open(path, 'rb') as f:
return await f.read()
except Exception as e:
raise RuntimeError(f"Failed to read file {path}: {e}")
async def get_web_file_path(filename: str) -> str:
"""Get path to web static file.
Args:
filename: Name of file in web directory
Returns:
Absolute path to file
Raises:
RuntimeError: If file not found
"""
# Get project root directory (four levels up from core to get to project root)
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
# Construct web directory path relative to project root
web_dir = os.path.join("/app", settings.web_player_path)
# Search in web directory
search_paths = [web_dir]
logger.debug(f"Searching for web file in path: {web_dir}")
return await _find_file(filename, search_paths)
async def get_content_type(path: str) -> str:
"""Get content type for file.
Args:
path: Path to file
Returns:
Content type string
"""
ext = os.path.splitext(path)[1].lower()
return {
'.html': 'text/html',
'.js': 'application/javascript',
'.css': 'text/css',
'.png': 'image/png',
'.jpg': 'image/jpeg',
'.jpeg': 'image/jpeg',
'.gif': 'image/gif',
'.svg': 'image/svg+xml',
'.ico': 'image/x-icon',
}.get(ext, 'application/octet-stream')
async def verify_model_path(model_path: str) -> bool:
"""Verify model file exists at path."""
return await aiofiles.os.path.exists(model_path)
async def cleanup_temp_files() -> None:
"""Clean up old temp files on startup"""
try:
if not await aiofiles.os.path.exists(settings.temp_file_dir):
await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)
return
entries = await aiofiles.os.scandir(settings.temp_file_dir)
for entry in entries:
if entry.is_file():
stat = await aiofiles.os.stat(entry.path)
max_age = stat.st_mtime + (settings.max_temp_dir_age_hours * 3600)
if time.time() > max_age: # file is older than the configured age limit
try:
await aiofiles.os.remove(entry.path)
logger.info(f"Cleaned up old temp file: {entry.name}")
except Exception as e:
logger.warning(f"Failed to delete old temp file {entry.name}: {e}")
except Exception as e:
logger.warning(f"Error cleaning temp files: {e}")
async def get_temp_file_path(filename: str) -> str:
"""Get path to temporary audio file.
Args:
filename: Name of temp file
Returns:
Absolute path to temp file
Raises:
RuntimeError: If temp directory does not exist
"""
temp_path = os.path.join(settings.temp_file_dir, filename)
# Ensure temp directory exists
if not await aiofiles.os.path.exists(settings.temp_file_dir):
await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)
return temp_path
async def list_temp_files() -> List[str]:
"""List temporary audio files.
Returns:
List of temp file names
"""
if not await aiofiles.os.path.exists(settings.temp_file_dir):
return []
entries = await aiofiles.os.scandir(settings.temp_file_dir)
return [entry.name for entry in entries if entry.is_file()]
async def get_temp_dir_size() -> int:
"""Get total size of temp directory in bytes.
Returns:
Size in bytes
"""
if not await aiofiles.os.path.exists(settings.temp_file_dir):
return 0
total = 0
entries = await aiofiles.os.scandir(settings.temp_file_dir)
for entry in entries:
if entry.is_file():
stat = await aiofiles.os.stat(entry.path)
total += stat.st_size
return total


@ -0,0 +1,16 @@
"""Model inference package."""
from .base import BaseModelBackend
from .model_manager import ModelManager, get_manager
from .onnx_cpu import ONNXCPUBackend
from .onnx_gpu import ONNXGPUBackend
from .pytorch_backend import PyTorchBackend
__all__ = [
'BaseModelBackend',
'ModelManager',
'get_manager',
'ONNXCPUBackend',
'ONNXGPUBackend',
'PyTorchBackend',
]

api/src/inference/base.py

@ -0,0 +1,97 @@
"""Base interfaces for model inference."""
from abc import ABC, abstractmethod
from typing import List, Optional
import numpy as np
import torch
class ModelBackend(ABC):
"""Abstract base class for model inference backends."""
@abstractmethod
async def load_model(self, path: str) -> None:
"""Load model from path.
Args:
path: Path to model file
Raises:
RuntimeError: If model loading fails
"""
pass
@abstractmethod
def generate(
self,
tokens: List[int],
voice: torch.Tensor,
speed: float = 1.0
) -> np.ndarray:
"""Generate audio from tokens.
Args:
tokens: Input token IDs
voice: Voice embedding tensor
speed: Speed multiplier
Returns:
Generated audio samples
Raises:
RuntimeError: If generation fails
"""
pass
@abstractmethod
def unload(self) -> None:
"""Unload model and free resources."""
pass
@property
@abstractmethod
def is_loaded(self) -> bool:
"""Check if model is loaded.
Returns:
True if model is loaded, False otherwise
"""
pass
@property
@abstractmethod
def device(self) -> str:
"""Get device model is running on.
Returns:
Device string ('cpu' or 'cuda')
"""
pass
class BaseModelBackend(ModelBackend):
"""Base implementation of model backend."""
def __init__(self):
"""Initialize base backend."""
self._model: Optional[torch.nn.Module] = None
self._device: str = "cpu"
@property
def is_loaded(self) -> bool:
"""Check if model is loaded."""
return self._model is not None
@property
def device(self) -> str:
"""Get device model is running on."""
return self._device
def unload(self) -> None:
"""Unload model and free resources."""
if self._model is not None:
del self._model
self._model = None
if torch.cuda.is_available():
torch.cuda.empty_cache()


@ -0,0 +1,339 @@
"""Model management and caching."""
import asyncio
from typing import Dict, Optional, Tuple
import torch
from loguru import logger
from ..core import paths
from ..core.config import settings
from ..core.model_config import ModelConfig, model_config
from .base import BaseModelBackend
from .onnx_cpu import ONNXCPUBackend
from .onnx_gpu import ONNXGPUBackend
from .pytorch_backend import PyTorchBackend
from .session_pool import CPUSessionPool, StreamingSessionPool
# Global singleton instance and lock for thread-safe initialization
_manager_instance = None
_manager_lock = asyncio.Lock()
class ModelManager:
"""Manages model loading and inference across backends."""
# Class-level state for shared resources
_loaded_models = {}
_backends = {}
def __init__(self, config: Optional[ModelConfig] = None):
"""Initialize model manager.
Note:
This should not be called directly. Use get_manager() instead.
"""
self._config = config or model_config
# Initialize session pools
self._session_pools = {
'onnx_cpu': CPUSessionPool(),
'onnx_gpu': StreamingSessionPool()
}
# Initialize locks
self._backend_locks: Dict[str, asyncio.Lock] = {}
def _determine_device(self) -> str:
"""Determine device based on settings."""
if settings.use_gpu and torch.cuda.is_available():
return "cuda"
return "cpu"
async def initialize(self) -> None:
"""Initialize backends."""
if self._backends:
logger.debug("Using existing backend instances")
return
device = self._determine_device()
try:
if device == "cuda":
if settings.use_onnx:
self._backends['onnx_gpu'] = ONNXGPUBackend()
self._current_backend = 'onnx_gpu'
logger.info("Initialized new ONNX GPU backend")
else:
self._backends['pytorch'] = PyTorchBackend()
self._current_backend = 'pytorch'
logger.info("Initialized new PyTorch backend on GPU")
else:
if settings.use_onnx:
self._backends['onnx_cpu'] = ONNXCPUBackend()
self._current_backend = 'onnx_cpu'
logger.info("Initialized new ONNX CPU backend")
else:
self._backends['pytorch'] = PyTorchBackend()
self._current_backend = 'pytorch'
logger.info("Initialized new PyTorch backend on CPU")
# Initialize locks for each backend
for backend in self._backends:
self._backend_locks[backend] = asyncio.Lock()
except Exception as e:
logger.error(f"Failed to initialize backend: {e}")
raise RuntimeError("Failed to initialize backend")
async def initialize_with_warmup(self, voice_manager) -> tuple[str, str, int]:
"""Initialize model with warmup and pre-cache voices.
Args:
voice_manager: Voice manager instance for loading voices
Returns:
Tuple of (device type, model type, number of loaded voices)
Raises:
RuntimeError: If initialization fails
"""
try:
# Determine backend type based on settings
if settings.use_onnx:
backend_type = 'onnx_gpu' if settings.use_gpu and torch.cuda.is_available() else 'onnx_cpu'
else:
backend_type = 'pytorch'
# Get backend
backend = self.get_backend(backend_type)
# Get and verify model path
model_file = model_config.pytorch_model_file if not settings.use_onnx else model_config.onnx_model_file
model_path = await paths.get_model_path(model_file)
if not await paths.verify_model_path(model_path):
raise RuntimeError(f"Model file not found: {model_path}")
# Pre-cache default voice and use for warmup
warmup_voice = await voice_manager.load_voice(
settings.default_voice, device=backend.device)
logger.info(f"Pre-cached voice {settings.default_voice} for warmup")
# Initialize model with warmup voice
await self.load_model(model_path, warmup_voice, backend_type)
# Only pre-cache default voice to avoid memory bloat
logger.info(f"Using {settings.default_voice} as warmup voice")
# Get available voices count
voices = await voice_manager.list_voices()
voicepack_count = len(voices)
# Get device info for return
device = "GPU" if settings.use_gpu else "CPU"
model = "ONNX" if settings.use_onnx else "PyTorch"
return device, model, voicepack_count
except Exception as e:
logger.error(f"Failed to initialize model with warmup: {e}")
raise RuntimeError(f"Failed to initialize model with warmup: {e}")
def get_backend(self, backend_type: Optional[str] = None) -> BaseModelBackend:
"""Get specified backend.
Args:
backend_type: Backend type ('pytorch_cpu', 'pytorch_gpu', 'onnx_cpu', 'onnx_gpu'),
uses default if None
Returns:
Model backend instance
Raises:
ValueError: If backend type is invalid
RuntimeError: If no backends are available
"""
if not self._backends:
raise RuntimeError("No backends available")
if backend_type is None:
backend_type = self._current_backend
if backend_type not in self._backends:
raise ValueError(
f"Invalid backend type: {backend_type}. "
f"Available backends: {', '.join(self._backends.keys())}"
)
return self._backends[backend_type]
def _determine_backend(self, model_path: str) -> str:
"""Determine appropriate backend based on model file and settings.
Args:
model_path: Path to model file
Returns:
Backend type to use
"""
# If ONNX is preferred or model is ONNX format
if settings.use_onnx or model_path.lower().endswith('.onnx'):
return 'onnx_gpu' if settings.use_gpu and torch.cuda.is_available() else 'onnx_cpu'
return 'pytorch'
async def load_model(
self,
model_path: str,
warmup_voice: Optional[torch.Tensor] = None,
backend_type: Optional[str] = None
) -> None:
"""Load model on specified backend.
Args:
model_path: Path to model file
warmup_voice: Optional voice tensor for warmup, skips warmup if None
backend_type: Backend to load on, uses default if None
Raises:
RuntimeError: If model loading fails
"""
try:
# Get absolute model path
abs_path = await paths.get_model_path(model_path)
# Auto-determine backend if not specified
if backend_type is None:
backend_type = self._determine_backend(abs_path)
# Get backend lock
lock = self._backend_locks[backend_type]
async with lock:
backend = self.get_backend(backend_type)
# For ONNX backends, use session pool
if backend_type.startswith('onnx'):
pool = self._session_pools[backend_type]
backend._session = await pool.get_session(abs_path)
self._loaded_models[backend_type] = abs_path
logger.info(f"Fetched model instance from {backend_type} pool")
# For PyTorch backends, load normally
else:
# Check if model is already loaded
if (backend_type in self._loaded_models and
self._loaded_models[backend_type] == abs_path and
backend.is_loaded):
logger.info(f"Fetching existing model instance from {backend_type}")
return
# Load model
await backend.load_model(abs_path)
self._loaded_models[backend_type] = abs_path
logger.info(f"Initialized new model instance on {backend_type}")
# Run warmup if voice provided
if warmup_voice is not None:
await self._warmup_inference(backend, warmup_voice)
except Exception as e:
# Clear cached path on failure
self._loaded_models.pop(backend_type, None)
raise RuntimeError(f"Failed to load model: {e}")
async def _warmup_inference(self, backend: BaseModelBackend, voice: torch.Tensor) -> None:
"""Run warmup inference to initialize model.
Args:
backend: Model backend to warm up
voice: Voice tensor already loaded on correct device
"""
try:
# Import here to avoid circular imports
from ..services.text_processing import process_text
# Use real text
text = "Testing text to speech synthesis."
# Process through pipeline
tokens = process_text(text)
if not tokens:
raise ValueError("Text processing failed")
# Run inference
backend.generate(tokens, voice, speed=1.0)
logger.debug("Completed warmup inference")
except Exception as e:
logger.warning(f"Warmup inference failed: {e}")
raise
async def generate(
self,
tokens: list[int],
voice: torch.Tensor,
speed: float = 1.0,
backend_type: Optional[str] = None
) -> torch.Tensor:
"""Generate audio using specified backend.
Args:
tokens: Input token IDs
voice: Voice tensor already loaded on correct device
speed: Speed multiplier
backend_type: Backend to use, uses default if None
Returns:
Generated audio tensor
Raises:
RuntimeError: If generation fails
"""
backend = self.get_backend(backend_type)
if not backend.is_loaded:
raise RuntimeError("Model not loaded")
try:
# Generate audio using provided voice tensor
# No lock needed here since inference is thread-safe
return backend.generate(tokens, voice, speed)
except Exception as e:
raise RuntimeError(f"Generation failed: {e}")
def unload_all(self) -> None:
"""Unload models from all backends and clear cache."""
# Clean up session pools
for pool in self._session_pools.values():
pool.cleanup()
# Unload PyTorch backends
for backend in self._backends.values():
backend.unload()
self._loaded_models.clear()
logger.info("Unloaded all models and cleared cache")
@property
def available_backends(self) -> list[str]:
"""Get list of available backends.
"""
return list(self._backends.keys())
@property
def current_backend(self) -> str:
"""Get current default backend.
"""
return self._current_backend
async def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
"""Get global model manager instance.
Args:
config: Optional model configuration
Returns:
ModelManager instance
Thread Safety:
This function should be thread-safe. Lemme know if it unravels on you
"""
global _manager_instance
# Fast path - return existing instance without lock
if _manager_instance is not None:
return _manager_instance
# Slow path - create new instance with lock
async with _manager_lock:
# Double-check pattern
if _manager_instance is None:
_manager_instance = ModelManager(config)
await _manager_instance.initialize()
return _manager_instance
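A usage sketch for the singleton accessor above, assuming the service's dependencies are installed and the default settings apply; real callers in this codebase go through initialize_with_warmup with a voice manager:

```python
import asyncio

from api.src.inference import get_manager  # exported by the package __init__ shown above

async def main() -> None:
    manager = await get_manager()  # first call creates and initializes the singleton
    print("current backend:", manager.current_backend)
    print("available backends:", manager.available_backends)

asyncio.run(main())
```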


@ -0,0 +1,115 @@
"""CPU-based ONNX inference backend."""
from typing import Optional
import numpy as np
import torch
from loguru import logger
from onnxruntime import InferenceSession
from ..core import paths
from ..core.model_config import model_config
from .base import BaseModelBackend
from .session_pool import create_session_options, create_provider_options
class ONNXCPUBackend(BaseModelBackend):
"""ONNX-based CPU inference backend."""
def __init__(self):
"""Initialize CPU backend."""
super().__init__()
self._device = "cpu"
self._session: Optional[InferenceSession] = None
@property
def is_loaded(self) -> bool:
"""Check if model is loaded."""
return self._session is not None
async def load_model(self, path: str) -> None:
"""Load ONNX model.
Args:
path: Path to model file
Raises:
RuntimeError: If model loading fails
"""
try:
# Get verified model path
model_path = await paths.get_model_path(path)
logger.info(f"Loading ONNX model: {model_path}")
# Configure session
options = create_session_options(is_gpu=False)
provider_options = create_provider_options(is_gpu=False)
# Create session
self._session = InferenceSession(
model_path,
sess_options=options,
providers=["CPUExecutionProvider"],
provider_options=[provider_options]
)
except Exception as e:
raise RuntimeError(f"Failed to load ONNX model: {e}")
def generate(
self,
tokens: list[int],
voice: torch.Tensor,
speed: float = 1.0
) -> np.ndarray:
"""Generate audio using ONNX model.
Args:
tokens: Input token IDs
voice: Voice embedding tensor
speed: Speed multiplier
Returns:
Generated audio samples
Raises:
RuntimeError: If generation fails
"""
if not self.is_loaded:
raise RuntimeError("Model not loaded")
try:
# Prepare inputs with start/end tokens
tokens_input = np.array([[0, *tokens, 0]], dtype=np.int64) # Add start/end tokens
style_input = voice[len(tokens) + 2].numpy() # Adjust index for start/end tokens
speed_input = np.full(1, speed, dtype=np.float32)
# Build base inputs
inputs = {
"style": style_input,
"speed": speed_input
}
# Try both possible token input names #TODO:
for token_name in ["tokens", "input_ids"]:
try:
inputs[token_name] = tokens_input
result = self._session.run(None, inputs)
return result[0]
except Exception:
del inputs[token_name]
continue
raise RuntimeError("Model does not accept either 'tokens' or 'input_ids' as input name")
except Exception as e:
raise RuntimeError(f"Generation failed: {e}")
def unload(self) -> None:
"""Unload model and free resources."""
if self._session is not None:
del self._session
self._session = None
if torch.cuda.is_available():
torch.cuda.empty_cache()


@ -0,0 +1,119 @@
"""GPU-based ONNX inference backend."""
from typing import Optional
import numpy as np
import torch
from loguru import logger
from onnxruntime import InferenceSession
from ..core import paths
from ..core.model_config import model_config
from .base import BaseModelBackend
from .session_pool import create_session_options, create_provider_options
class ONNXGPUBackend(BaseModelBackend):
"""ONNX-based GPU inference backend."""
def __init__(self):
"""Initialize GPU backend."""
super().__init__()
if not torch.cuda.is_available():
raise RuntimeError("CUDA not available")
self._device = "cuda"
self._session: Optional[InferenceSession] = None
# Configure GPU
torch.cuda.set_device(model_config.onnx_gpu.device_id)
@property
def is_loaded(self) -> bool:
"""Check if model is loaded."""
return self._session is not None
async def load_model(self, path: str) -> None:
"""Load ONNX model.
Args:
path: Path to model file
Raises:
RuntimeError: If model loading fails
"""
try:
# Get verified model path
model_path = await paths.get_model_path(path)
logger.info(f"Loading ONNX model on GPU: {model_path}")
# Configure session
options = create_session_options(is_gpu=True)
provider_options = create_provider_options(is_gpu=True)
# Create session with CUDA provider
self._session = InferenceSession(
model_path,
sess_options=options,
providers=["CUDAExecutionProvider"],
provider_options=[provider_options]
)
except Exception as e:
raise RuntimeError(f"Failed to load ONNX model: {e}")
def generate(
self,
tokens: list[int],
voice: torch.Tensor,
speed: float = 1.0
) -> np.ndarray:
"""Generate audio using ONNX model.
Args:
tokens: Input token IDs
voice: Voice embedding tensor
speed: Speed multiplier
Returns:
Generated audio samples
Raises:
RuntimeError: If generation fails
"""
if not self.is_loaded:
raise RuntimeError("Model not loaded")
try:
# Prepare inputs
tokens_input = np.array([[0, *tokens, 0]], dtype=np.int64) # Add start/end tokens
# Use modulo to ensure index stays within voice tensor bounds
style_idx = (len(tokens) + 2) % voice.size(0) # Add 2 for start/end tokens
style_input = voice[style_idx].cpu().numpy() # Move to CPU for ONNX
speed_input = np.full(1, speed, dtype=np.float32)
# Run inference
result = self._session.run(
None,
{
"tokens": tokens_input,
"style": style_input,
"speed": speed_input
}
)
return result[0]
except Exception as e:
if "out of memory" in str(e).lower():
# Clear CUDA cache and retry
torch.cuda.empty_cache()
return self.generate(tokens, voice, speed)
raise RuntimeError(f"Generation failed: {e}")
def unload(self) -> None:
"""Unload model and free resources."""
if self._session is not None:
del self._session
self._session = None
torch.cuda.empty_cache()


@ -0,0 +1,244 @@
"""PyTorch inference backend with environment-based configuration."""
import gc
from contextlib import nullcontext
from typing import Optional
import numpy as np
import torch
from loguru import logger
from ..builds.models import build_model
from ..core import paths
from ..core.model_config import model_config
from ..core.config import settings
from .base import BaseModelBackend
class CUDAStreamManager:
"""CUDA stream manager for GPU operations."""
def __init__(self, num_streams: int):
"""Initialize stream manager.
Args:
num_streams: Number of CUDA streams
"""
self.streams = [torch.cuda.Stream() for _ in range(num_streams)]
self._current = 0
def get_next_stream(self) -> torch.cuda.Stream:
"""Get next available stream.
Returns:
CUDA stream
"""
stream = self.streams[self._current]
self._current = (self._current + 1) % len(self.streams)
return stream
@torch.no_grad()
def forward(
model: torch.nn.Module,
tokens: list[int],
ref_s: torch.Tensor,
speed: float,
stream: Optional[torch.cuda.Stream] = None,
) -> np.ndarray:
"""Forward pass through model.
Args:
model: PyTorch model
tokens: Input tokens
ref_s: Reference signal
speed: Speed multiplier
stream: Optional CUDA stream (GPU only)
Returns:
Generated audio
"""
device = ref_s.device
# Use provided stream or default for GPU
context = (
torch.cuda.stream(stream) if stream and device.type == "cuda" else nullcontext()
)
with context:
# Initial tensor setup
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
text_mask = length_to_mask(input_lengths).to(device)
# Split reference signals
style_dim = 128
s_ref = ref_s[:, :style_dim].clone().to(device)
s_content = ref_s[:, style_dim:].clone().to(device)
# BERT and encoder pass
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
# Predictor forward pass
d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
x, _ = model.predictor.lstm(d)
# Duration prediction
duration = model.predictor.duration_proj(x)
duration = torch.sigmoid(duration).sum(axis=-1) / speed
pred_dur = torch.round(duration).clamp(min=1).long()
del duration, x
# Alignment matrix construction
pred_aln_trg = torch.zeros(
input_lengths.item(), pred_dur.sum().item(), device=device
)
c_frame = 0
for i in range(pred_aln_trg.size(0)):
pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
c_frame += pred_dur[0, i].item()
pred_aln_trg = pred_aln_trg.unsqueeze(0)
# Matrix multiplications
en = d.transpose(-1, -2) @ pred_aln_trg
del d
F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
del en
# Final text encoding and decoding
t_en = model.text_encoder(tokens, input_lengths, text_mask)
asr = t_en @ pred_aln_trg
del t_en
# Generate output
output = model.decoder(asr, F0_pred, N_pred, s_ref)
# Ensure operation completion if using custom stream
if stream and device.type == "cuda":
stream.synchronize()
return output.squeeze().cpu().numpy()
def length_to_mask(lengths: torch.Tensor) -> torch.Tensor:
"""Create attention mask from lengths."""
max_len = lengths.max()
mask = torch.arange(max_len, device=lengths.device)[None, :].expand(
lengths.shape[0], -1
)
if lengths.dtype != mask.dtype:
mask = mask.to(dtype=lengths.dtype)
return mask + 1 > lengths[:, None]
class PyTorchBackend(BaseModelBackend):
"""PyTorch inference backend with environment-based configuration."""
def __init__(self):
"""Initialize backend based on environment configuration."""
super().__init__()
# Configure device based on settings
self._device = (
"cuda" if settings.use_gpu and torch.cuda.is_available() else "cpu"
)
self._model: Optional[torch.nn.Module] = None
# Apply device-specific configurations
if self._device == "cuda":
config = model_config.pytorch_gpu
if config.sync_cuda:
torch.cuda.synchronize()
torch.cuda.set_device(config.device_id)
self._stream_manager = CUDAStreamManager(config.cuda_streams)
else:
config = model_config.pytorch_cpu
if config.num_threads > 0:
torch.set_num_threads(config.num_threads)
if config.pin_memory:
torch.set_default_tensor_type(torch.FloatTensor)
async def load_model(self, path: str) -> None:
"""Load PyTorch model.
Args:
path: Path to model file
Raises:
RuntimeError: If model loading fails
"""
try:
# Get verified model path
model_path = await paths.get_model_path(path)
logger.info(f"Loading PyTorch model on {self._device}: {model_path}")
self._model = await build_model(model_path, self._device)
except Exception as e:
raise RuntimeError(f"Failed to load PyTorch model: {e}")
def generate(
self, tokens: list[int], voice: torch.Tensor, speed: float = 1.0
) -> np.ndarray:
"""Generate audio using model.
Args:
tokens: Input token IDs
voice: Voice embedding tensor
speed: Speed multiplier
Returns:
Generated audio samples
Raises:
RuntimeError: If generation fails
"""
if not self.is_loaded:
raise RuntimeError("Model not loaded")
try:
# Memory management for GPU
if self._device == "cuda":
if self._check_memory():
self._clear_memory()
stream = self._stream_manager.get_next_stream()
else:
stream = None
# Get reference style from voice pack
ref_s = voice[len(tokens)].clone().to(self._device)
if ref_s.dim() == 1:
ref_s = ref_s.unsqueeze(0)
# Generate audio
return forward(self._model, tokens, ref_s, speed, stream)
except Exception as e:
logger.error(f"Generation failed: {e}")
if (
self._device == "cuda"
and model_config.pytorch_gpu.retry_on_oom
and "out of memory" in str(e).lower()
):
self._clear_memory()
return self.generate(tokens, voice, speed)
raise
finally:
if self._device == "cuda" and model_config.pytorch_gpu.sync_cuda:
torch.cuda.synchronize()
def _check_memory(self) -> bool:
"""Check if memory usage is above threshold."""
if self._device == "cuda":
memory_gb = torch.cuda.memory_allocated() / 1e9
return memory_gb > model_config.pytorch_gpu.memory_threshold
return False
def _clear_memory(self) -> None:
"""Clear device memory."""
if self._device == "cuda":
torch.cuda.empty_cache()
gc.collect()
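
For reference, a minimal driver for the backend above (illustrative only; the model filename is a placeholder, and how the backend instance and voice tensor are obtained is assumed rather than shown in this diff):

import asyncio
import torch

async def demo_pytorch_backend(backend, voice_pack: torch.Tensor, tokens: list[int]) -> None:
    # backend is assumed to be an already-constructed PyTorchBackend;
    # voice_pack is a voice tensor as returned by the voice manager.
    await backend.load_model("kokoro-v0_19.pth")  # placeholder model filename
    samples = backend.generate(tokens, voice_pack, speed=1.0)
    print(f"generated {samples.shape[0]} samples at 24 kHz")

# asyncio.run(demo_pytorch_backend(backend, voice_pack, tokens))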

View file

@ -0,0 +1,272 @@
"""Session pooling for model inference."""
import asyncio
import time
from dataclasses import dataclass
from typing import Dict, Optional, Set
import torch
from loguru import logger
from onnxruntime import (
ExecutionMode,
GraphOptimizationLevel,
InferenceSession,
SessionOptions
)
from ..core import paths
from ..core.model_config import model_config
@dataclass
class SessionInfo:
"""Session information."""
session: InferenceSession
last_used: float
stream_id: Optional[int] = None
def create_session_options(is_gpu: bool = False) -> SessionOptions:
"""Create ONNX session options.
Args:
is_gpu: Whether to use GPU configuration
Returns:
Configured session options
"""
options = SessionOptions()
config = model_config.onnx_gpu if is_gpu else model_config.onnx_cpu
# Set optimization level
if config.optimization_level == "all":
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
elif config.optimization_level == "basic":
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
else:
options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
# Configure threading
options.intra_op_num_threads = config.num_threads
options.inter_op_num_threads = config.inter_op_threads
# Set execution mode
options.execution_mode = (
ExecutionMode.ORT_PARALLEL
if config.execution_mode == "parallel"
else ExecutionMode.ORT_SEQUENTIAL
)
# Configure memory optimization
options.enable_mem_pattern = config.memory_pattern
return options
def create_provider_options(is_gpu: bool = False) -> Dict:
"""Create provider options.
Args:
is_gpu: Whether to use GPU configuration
Returns:
Provider configuration
"""
if is_gpu:
config = model_config.onnx_gpu
return {
"device_id": config.device_id,
"arena_extend_strategy": config.arena_extend_strategy,
"gpu_mem_limit": int(config.gpu_mem_limit * torch.cuda.get_device_properties(0).total_memory),
"cudnn_conv_algo_search": config.cudnn_conv_algo_search,
"do_copy_in_default_stream": config.do_copy_in_default_stream
}
else:
return {
"arena_extend_strategy": model_config.onnx_cpu.arena_extend_strategy,
"cpu_memory_arena_cfg": "cpu:0"
}
class BaseSessionPool:
"""Base session pool implementation."""
def __init__(self, max_size: int, timeout: int):
"""Initialize session pool.
Args:
max_size: Maximum number of concurrent sessions
timeout: Session timeout in seconds
"""
self._max_size = max_size
self._timeout = timeout
self._sessions: Dict[str, SessionInfo] = {}
self._lock = asyncio.Lock()
async def get_session(self, model_path: str) -> InferenceSession:
"""Get session from pool.
Args:
model_path: Path to model file
Returns:
ONNX inference session
Raises:
RuntimeError: If no sessions available
"""
async with self._lock:
# Clean expired sessions
self._cleanup_expired()
# TODO: Change session tracking to use unique IDs instead of model paths
# This would allow multiple instances of the same model
# Check if session exists and is valid
if model_path in self._sessions:
session_info = self._sessions[model_path]
session_info.last_used = time.time()
return session_info.session
# TODO: Modify session limit check to count instances per model path
# Rather than total sessions across all models
if len(self._sessions) >= self._max_size:
raise RuntimeError(
f"Maximum number of sessions reached ({self._max_size}). "
"Try again later or reduce concurrent requests."
)
# Create new session
session = await self._create_session(model_path)
self._sessions[model_path] = SessionInfo(
session=session,
last_used=time.time()
)
return session
def _cleanup_expired(self) -> None:
"""Remove expired sessions."""
current_time = time.time()
expired = [
path for path, info in self._sessions.items()
if current_time - info.last_used > self._timeout
]
for path in expired:
logger.info(f"Removing expired session: {path}")
del self._sessions[path]
async def _create_session(self, model_path: str) -> InferenceSession:
"""Create new session.
Args:
model_path: Path to model file
Returns:
ONNX inference session
"""
raise NotImplementedError
def cleanup(self) -> None:
"""Clean up all sessions."""
self._sessions.clear()
class StreamingSessionPool(BaseSessionPool):
"""GPU session pool with CUDA streams."""
def __init__(self):
"""Initialize GPU session pool."""
config = model_config.onnx_gpu
super().__init__(config.cuda_streams, config.stream_timeout)
self._available_streams: Set[int] = set(range(config.cuda_streams))
async def get_session(self, model_path: str) -> InferenceSession:
"""Get session with CUDA stream.
Args:
model_path: Path to model file
Returns:
ONNX inference session
Raises:
RuntimeError: If no streams available
"""
async with self._lock:
# Clean expired sessions
self._cleanup_expired()
# Try to find existing session
if model_path in self._sessions:
session_info = self._sessions[model_path]
session_info.last_used = time.time()
return session_info.session
# Get available stream
if not self._available_streams:
raise RuntimeError("No CUDA streams available")
stream_id = self._available_streams.pop()
try:
# Create new session
session = await self._create_session(model_path)
self._sessions[model_path] = SessionInfo(
session=session,
last_used=time.time(),
stream_id=stream_id
)
return session
except Exception:
# Return stream to pool on failure
self._available_streams.add(stream_id)
raise
def _cleanup_expired(self) -> None:
"""Remove expired sessions and return streams."""
current_time = time.time()
expired = [
path for path, info in self._sessions.items()
if current_time - info.last_used > self._timeout
]
for path in expired:
info = self._sessions[path]
if info.stream_id is not None:
self._available_streams.add(info.stream_id)
logger.info(f"Removing expired session: {path}")
del self._sessions[path]
async def _create_session(self, model_path: str) -> InferenceSession:
"""Create new session with CUDA provider."""
abs_path = await paths.get_model_path(model_path)
options = create_session_options(is_gpu=True)
provider_options = create_provider_options(is_gpu=True)
return InferenceSession(
abs_path,
sess_options=options,
providers=["CUDAExecutionProvider"],
provider_options=[provider_options]
)
class CPUSessionPool(BaseSessionPool):
"""CPU session pool."""
def __init__(self):
"""Initialize CPU session pool."""
config = model_config.onnx_cpu
super().__init__(config.max_instances, config.instance_timeout)
async def _create_session(self, model_path: str) -> InferenceSession:
"""Create new session with CPU provider."""
abs_path = await paths.get_model_path(model_path)
options = create_session_options(is_gpu=False)
provider_options = create_provider_options(is_gpu=False)
return InferenceSession(
abs_path,
sess_options=options,
providers=["CPUExecutionProvider"],
provider_options=[provider_options]
)
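
A sketch of exercising the CPU pool above directly (illustrative only; the ONNX filename is a placeholder, and real callers go through the model manager):

import asyncio

async def demo_cpu_pool() -> None:
    pool = CPUSessionPool()
    session = await pool.get_session("kokoro-v0_19.onnx")  # placeholder path
    print("model inputs:", [inp.name for inp in session.get_inputs()])
    pool.cleanup()

# asyncio.run(demo_cpu_pool())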

View file

@ -0,0 +1,215 @@
"""Voice pack management and caching."""
import os
from typing import Dict, List, Optional
import torch
from loguru import logger
from ..core import paths
from ..core.config import settings
from ..structures.model_schemas import VoiceConfig
class VoiceManager:
"""Manages voice loading and operations."""
def __init__(self, config: Optional[VoiceConfig] = None):
"""Initialize voice manager.
Args:
config: Optional voice configuration
"""
self._config = config or VoiceConfig()
self._voice_cache: Dict[str, torch.Tensor] = {}
def get_voice_path(self, voice_name: str) -> Optional[str]:
"""Get path to voice file.
Args:
voice_name: Name of voice
Returns:
Path to voice file if exists, None otherwise
"""
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
voice_path = os.path.join(api_dir, settings.voices_dir, f"{voice_name}.pt")
return voice_path if os.path.exists(voice_path) else None
async def load_voice(self, voice_name: str, device: str = "cpu") -> torch.Tensor:
"""Load voice tensor.
Args:
voice_name: Name of voice to load
device: Device to load voice on
Returns:
Voice tensor
Raises:
RuntimeError: If voice loading fails
"""
# Check if it's a combined voice request
if "+" in voice_name:
voices = [v.strip() for v in voice_name.split("+") if v.strip()]
if len(voices) < 2:
raise RuntimeError(f"Invalid combined voice name: {voice_name}")
# Load and combine voices
voice_tensors = []
for voice in voices:
try:
voice_tensor = await self.load_voice(voice, device)
voice_tensors.append(voice_tensor)
except Exception as e:
raise RuntimeError(f"Failed to load base voice {voice}: {e}")
return torch.mean(torch.stack(voice_tensors), dim=0)
# Handle single voice
voice_path = self.get_voice_path(voice_name)
if not voice_path:
raise RuntimeError(f"Voice not found: {voice_name}")
# Check cache
cache_key = f"{voice_path}_{device}"
if self._config.use_cache and cache_key in self._voice_cache:
return self._voice_cache[cache_key]
# Load voice tensor
try:
voice = await paths.load_voice_tensor(voice_path, device=device)
except Exception as e:
raise RuntimeError(f"Failed to load voice {voice_name}: {e}")
# Cache if enabled
if self._config.use_cache:
self._manage_cache()
self._voice_cache[cache_key] = voice
logger.debug(f"Cached voice: {voice_name} on {device}")
return voice
def _manage_cache(self) -> None:
"""Manage voice cache size using simple LRU."""
if len(self._voice_cache) >= self._config.cache_size:
# Remove least recently used voice
oldest = next(iter(self._voice_cache))
del self._voice_cache[oldest]
torch.cuda.empty_cache() # Clean up GPU memory if needed
logger.debug(f"Removed LRU voice from cache: {oldest}")
async def combine_voices(self, voices: List[str], device: str = "cpu") -> str:
"""Combine multiple voices into a new voice.
Args:
voices: List of voice names to combine
device: Device to load voices on
Returns:
Name of combined voice
Raises:
ValueError: If fewer than 2 voices provided
RuntimeError: If voice combination fails
"""
if len(voices) < 2:
raise ValueError("At least 2 voices are required for combination")
# Create combined name using + as separator
combined_name = "+".join(voices)
# If saving is enabled, try to save the combination
if settings.allow_local_voice_saving:
try:
# Load and combine voices
combined_tensor = await self.load_voice(combined_name, device)
# Save to disk
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
voices_dir = os.path.join(api_dir, settings.voices_dir)
os.makedirs(voices_dir, exist_ok=True)
combined_path = os.path.join(voices_dir, f"{combined_name}.pt")
try:
torch.save(combined_tensor, combined_path)
# Cache with path-based key
self._voice_cache[f"{combined_path}_{device}"] = combined_tensor
except Exception as e:
raise RuntimeError(f"Failed to save combined voice: {e}")
except Exception as e:
logger.warning(f"Failed to save combined voice: {e}")
# Continue without saving - will be combined on-the-fly when needed
return combined_name
async def list_voices(self) -> List[str]:
"""List available voices.
Returns:
List of voice names
"""
voices = set() # Use set to avoid duplicates
try:
# Get voices from disk
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
voices_dir = os.path.join(api_dir, settings.voices_dir)
os.makedirs(voices_dir, exist_ok=True)
for entry in os.listdir(voices_dir):
if entry.endswith(".pt"):
voices.add(entry[:-3])
except Exception as e:
logger.error(f"Error listing voices: {e}")
return sorted(list(voices))
def validate_voice(self, voice_path: str) -> bool:
"""Validate voice file.
Args:
voice_path: Path to voice file
Returns:
True if valid, False otherwise
"""
try:
if not os.path.exists(voice_path):
return False
voice = torch.load(voice_path, map_location="cpu")
return isinstance(voice, torch.Tensor)
except Exception:
return False
@property
def cache_info(self) -> Dict[str, int]:
"""Get cache statistics.
Returns:
Dictionary with cache info
"""
return {
'size': len(self._voice_cache),
'max_size': self._config.cache_size
}
# Global singleton instance
_manager_instance = None
async def get_manager(config: Optional[VoiceConfig] = None) -> VoiceManager:
"""Get global voice manager instance.
Args:
config: Optional voice configuration
Returns:
VoiceManager instance
"""
global _manager_instance
if _manager_instance is None:
_manager_instance = VoiceManager(config)
return _manager_instance
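
A short sketch of the voice manager above in isolation (illustrative only; the voice names are placeholders and must exist under settings.voices_dir):

import asyncio

async def demo_voice_manager() -> None:
    manager = await get_manager()
    print("available:", await manager.list_voices())
    single = await manager.load_voice("af_bella", device="cpu")            # placeholder name
    blended = await manager.load_voice("af_bella+af_sky", device="cpu")    # "+" averages the packs
    print(single.shape, blended.shape)

# asyncio.run(demo_voice_manager())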

View file

@ -2,19 +2,22 @@
FastAPI OpenAI Compatible API
"""
import os
import sys
from contextlib import asynccontextmanager
from pathlib import Path
import torch
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from .core.config import settings
from .routers.web_player import router as web_router
from .routers.development import router as dev_router
from .routers.openai_compatible import router as openai_router
from .services.tts_model import TTSModel
from .services.tts_service import TTSService
from .routers.debug import router as debug_router
def setup_logger():
@ -42,11 +45,39 @@ setup_logger()
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Lifespan context manager for model initialization"""
from .inference.model_manager import get_manager
from .inference.voice_manager import get_manager as get_voice_manager
from .services.temp_manager import cleanup_temp_files
# Clean old temp files on startup
await cleanup_temp_files()
logger.info("Loading TTS model and voice packs...")
# Initialize the main model with warm-up
voicepack_count = await TTSModel.setup()
# boundary = "█████╗"*9
try:
# Initialize managers globally
model_manager = await get_manager()
voice_manager = await get_voice_manager()
# Initialize model with warmup and get status
device, model, voicepack_count = await model_manager.initialize_with_warmup(voice_manager)
except FileNotFoundError:
logger.error("""
Model files not found! You need to either:
1. Download models using the scripts:
GPU: python docker/scripts/download_model.py --type pth
CPU: python docker/scripts/download_model.py --type onnx
2. Set environment variables in docker-compose:
GPU: DOWNLOAD_PTH=true
CPU: DOWNLOAD_ONNX=true
""")
raise
except Exception as e:
logger.error(f"Failed to initialize model: {e}")
raise
boundary = "" * 2*12
startup_msg = f"""
@ -54,16 +85,22 @@ async def lifespan(app: FastAPI):
{boundary}
"""
# TODO: Improve CPU warmup, threads, memory, etc
startup_msg += f"\nModel warmed up on {TTSModel.get_device()}"
startup_msg += f"\n{voicepack_count} voice packs loaded\n"
startup_msg += f"\nModel warmed up on {device}: {model}"
startup_msg += f"\n{voicepack_count} voice packs loaded"
# Add web player info if enabled
if settings.enable_web_player:
startup_msg += f"\n\nBeta Web Player: http://{settings.host}:{settings.port}/web/"
else:
startup_msg += "\n\nWeb Player: disabled"
startup_msg += f"\n{boundary}\n"
logger.info(startup_msg)
@ -79,19 +116,22 @@ app = FastAPI(
openapi_url="/openapi.json", # Explicitly enable OpenAPI schema
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Add CORS middleware if enabled
if settings.cors_enabled:
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Include routers
app.include_router(openai_router, prefix="/v1")
app.include_router(dev_router) # New development endpoints
# app.include_router(text_router) # Deprecated but still live for backwards compatibility
app.include_router(dev_router) # Development endpoints
app.include_router(debug_router) # Debug endpoints
if settings.enable_web_player:
app.include_router(web_router, prefix="/web") # Web player static files
# Health check endpoint

188
api/src/routers/debug.py Normal file
View file

@ -0,0 +1,188 @@
from fastapi import APIRouter
import psutil
import threading
import time
from datetime import datetime
try:
import GPUtil
GPU_AVAILABLE = True
except ImportError:
GPU_AVAILABLE = False
router = APIRouter(tags=["debug"])
@router.get("/debug/threads")
async def get_thread_info():
process = psutil.Process()
current_threads = threading.enumerate()
# Get per-thread CPU times
thread_details = []
for thread in current_threads:
thread_info = {
"name": thread.name,
"id": thread.ident,
"alive": thread.is_alive(),
"daemon": thread.daemon
}
thread_details.append(thread_info)
return {
"total_threads": process.num_threads(),
"active_threads": len(current_threads),
"thread_names": [t.name for t in current_threads],
"thread_details": thread_details,
"memory_mb": process.memory_info().rss / 1024 / 1024
}
@router.get("/debug/storage")
async def get_storage_info():
# Get disk partitions
partitions = psutil.disk_partitions()
storage_info = []
for partition in partitions:
try:
usage = psutil.disk_usage(partition.mountpoint)
storage_info.append({
"device": partition.device,
"mountpoint": partition.mountpoint,
"fstype": partition.fstype,
"total_gb": usage.total / (1024**3),
"used_gb": usage.used / (1024**3),
"free_gb": usage.free / (1024**3),
"percent_used": usage.percent
})
except PermissionError:
continue
return {
"storage_info": storage_info
}
@router.get("/debug/system")
async def get_system_info():
process = psutil.Process()
# CPU Info
cpu_info = {
"cpu_count": psutil.cpu_count(),
"cpu_percent": psutil.cpu_percent(interval=1),
"per_cpu_percent": psutil.cpu_percent(interval=1, percpu=True),
"load_avg": psutil.getloadavg()
}
# Memory Info
virtual_memory = psutil.virtual_memory()
swap_memory = psutil.swap_memory()
memory_info = {
"virtual": {
"total_gb": virtual_memory.total / (1024**3),
"available_gb": virtual_memory.available / (1024**3),
"used_gb": virtual_memory.used / (1024**3),
"percent": virtual_memory.percent
},
"swap": {
"total_gb": swap_memory.total / (1024**3),
"used_gb": swap_memory.used / (1024**3),
"free_gb": swap_memory.free / (1024**3),
"percent": swap_memory.percent
}
}
# Process Info
process_info = {
"pid": process.pid,
"status": process.status(),
"create_time": datetime.fromtimestamp(process.create_time()).isoformat(),
"cpu_percent": process.cpu_percent(),
"memory_percent": process.memory_percent(),
}
# Network Info
network_info = {
"connections": len(process.net_connections()),
"network_io": psutil.net_io_counters()._asdict()
}
# GPU Info if available
gpu_info = None
if GPU_AVAILABLE:
try:
gpus = GPUtil.getGPUs()
gpu_info = [{
"id": gpu.id,
"name": gpu.name,
"load": gpu.load,
"memory": {
"total": gpu.memoryTotal,
"used": gpu.memoryUsed,
"free": gpu.memoryFree,
"percent": (gpu.memoryUsed / gpu.memoryTotal) * 100
},
"temperature": gpu.temperature
} for gpu in gpus]
except Exception:
gpu_info = "GPU information unavailable"
return {
"cpu": cpu_info,
"memory": memory_info,
"process": process_info,
"network": network_info,
"gpu": gpu_info
}
@router.get("/debug/session_pools")
async def get_session_pool_info():
"""Get information about ONNX session pools."""
from ..inference.model_manager import get_manager
manager = await get_manager()
pools = manager._session_pools
current_time = time.time()
pool_info = {}
# Get CPU pool info
if 'onnx_cpu' in pools:
cpu_pool = pools['onnx_cpu']
pool_info['cpu'] = {
"active_sessions": len(cpu_pool._sessions),
"max_sessions": cpu_pool._max_size,
"sessions": [{
"model": path,
"age_seconds": current_time - info.last_used
} for path, info in cpu_pool._sessions.items()]
}
# Get GPU pool info
if 'onnx_gpu' in pools:
gpu_pool = pools['onnx_gpu']
pool_info['gpu'] = {
"active_sessions": len(gpu_pool._sessions),
"max_streams": gpu_pool._max_size,
"available_streams": len(gpu_pool._available_streams),
"sessions": [{
"model": path,
"age_seconds": current_time - info.last_used,
"stream_id": info.stream_id
} for path, info in gpu_pool._sessions.items()]
}
# Add GPU memory info if available
if GPU_AVAILABLE:
try:
gpus = GPUtil.getGPUs()
if gpus:
gpu = gpus[0] # Assume first GPU
pool_info['gpu']['memory'] = {
"total_mb": gpu.memoryTotal,
"used_mb": gpu.memoryUsed,
"free_mb": gpu.memoryFree,
"percent_used": (gpu.memoryUsed / gpu.memoryTotal) * 100
}
except Exception:
pass
return pool_info
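
The debug endpoints above can be polled with any HTTP client; a quick sketch (host and port are assumptions for a local run):

import requests

BASE_URL = "http://localhost:8880"  # assumed host/port

def dump_debug_info() -> None:
    for path in ("/debug/threads", "/debug/storage", "/debug/system", "/debug/session_pools"):
        resp = requests.get(f"{BASE_URL}{path}", timeout=10)
        resp.raise_for_status()
        data = resp.json()
        print(f"{path}: {list(data.keys()) if isinstance(data, dict) else data}")

# dump_debug_info()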

View file

@ -1,12 +1,15 @@
from typing import List
import numpy as np
from fastapi import APIRouter, Depends, HTTPException, Response
import torch
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
from loguru import logger
from ..services.audio import AudioService
from ..services.text_processing import phonemize, tokenize
from ..services.tts_model import TTSModel
from ..services.audio import AudioService, AudioNormalizer
from ..services.streaming_audio_writer import StreamingAudioWriter
from ..services.text_processing import phonemize, smart_split
from ..services.text_processing.vocabulary import tokenize
from ..services.tts_service import TTSService
from ..structures.text_schemas import (
GenerateFromPhonemesRequest,
@ -17,12 +20,10 @@ from ..structures.text_schemas import (
router = APIRouter(tags=["text processing"])
def get_tts_service() -> TTSService:
async def get_tts_service() -> TTSService:
"""Dependency to get TTSService instance"""
return TTSService()
return await TTSService.create() # Create service with properly initialized managers
@router.post("/text/phonemize", response_model=PhonemeResponse, tags=["deprecated"])
@router.post("/dev/phonemize", response_model=PhonemeResponse)
async def phonemize_text(request: PhonemeRequest) -> PhonemeResponse:
"""Convert text to phonemes and tokens
@ -43,10 +44,8 @@ async def phonemize_text(request: PhonemeRequest) -> PhonemeResponse:
if not phonemes:
raise ValueError("Failed to generate phonemes")
# Get tokens
# Get tokens (without adding start/end tokens to match process_text behavior)
tokens = tokenize(phonemes)
tokens = [0] + tokens + [0] # Add start/end tokens
return PhonemeResponse(phonemes=phonemes, tokens=tokens)
except ValueError as e:
logger.error(f"Error in phoneme generation: {str(e)}")
@ -58,73 +57,95 @@ async def phonemize_text(request: PhonemeRequest) -> PhonemeResponse:
raise HTTPException(
status_code=500, detail={"error": "Server error", "message": str(e)}
)
@router.post("/text/generate_from_phonemes", tags=["deprecated"])
@router.post("/dev/generate_from_phonemes")
async def generate_from_phonemes(
request: GenerateFromPhonemesRequest,
client_request: Request,
tts_service: TTSService = Depends(get_tts_service),
) -> Response:
"""Generate audio directly from phonemes
Args:
request: Request containing phonemes and generation parameters
tts_service: Injected TTSService instance
Returns:
WAV audio bytes
"""
# Validate phonemes first
if not request.phonemes:
raise HTTPException(
status_code=400,
detail={"error": "Invalid request", "message": "Phonemes cannot be empty"},
)
# Validate voice exists
voice_path = tts_service._get_voice_path(request.voice)
if not voice_path:
raise HTTPException(
status_code=400,
detail={
"error": "Invalid request",
"message": f"Voice not found: {request.voice}",
},
)
) -> StreamingResponse:
"""Generate audio directly from phonemes with proper streaming"""
try:
# Load voice
voicepack = tts_service._load_voice(voice_path)
# Basic validation
if not isinstance(request.phonemes, str):
raise ValueError("Phonemes must be a string")
if not request.phonemes:
raise ValueError("Phonemes cannot be empty")
# Convert phonemes to tokens
tokens = tokenize(request.phonemes)
tokens = [0] + tokens + [0] # Add start/end tokens
# Create streaming audio writer and normalizer
writer = StreamingAudioWriter(format="wav", sample_rate=24000, channels=1)
normalizer = AudioNormalizer()
# Generate audio directly from tokens
audio = TTSModel.generate_from_tokens(tokens, voicepack, request.speed)
async def generate_chunks():
try:
has_data = False
# Process phonemes in chunks
async for chunk_text, _ in smart_split(request.phonemes):
# Check if client is still connected
is_disconnected = client_request.is_disconnected
if callable(is_disconnected):
is_disconnected = await is_disconnected()
if is_disconnected:
logger.info("Client disconnected, stopping audio generation")
break
# Convert to WAV bytes
wav_bytes = AudioService.convert_audio(
audio, 24000, "wav", is_first_chunk=True, is_last_chunk=True, stream=False
)
chunk_audio, _ = await tts_service.generate_from_phonemes(
phonemes=chunk_text,
voice=request.voice,
speed=1.0
)
if chunk_audio is not None:
has_data = True
# Normalize audio before writing
normalized_audio = await normalizer.normalize(chunk_audio)
# Write chunk and yield bytes
chunk_bytes = writer.write_chunk(normalized_audio)
if chunk_bytes:
yield chunk_bytes
return Response(
content=wav_bytes,
if not has_data:
raise ValueError("Failed to generate any audio data")
# Finalize and yield remaining bytes if we still have a connection
if not (callable(is_disconnected) and await is_disconnected()):
final_bytes = writer.write_chunk(finalize=True)
if final_bytes:
yield final_bytes
except Exception as e:
logger.error(f"Error in audio chunk generation: {str(e)}")
# Clean up writer on error
writer.write_chunk(finalize=True)
# Re-raise the original exception
raise
return StreamingResponse(
generate_chunks(),
media_type="audio/wav",
headers={
"Content-Disposition": "attachment; filename=speech.wav",
"X-Accel-Buffering": "no",
"Cache-Control": "no-cache",
},
"Transfer-Encoding": "chunked"
}
)
except ValueError as e:
logger.error(f"Invalid request: {str(e)}")
logger.error(f"Error generating audio: {str(e)}")
raise HTTPException(
status_code=400, detail={"error": "Invalid request", "message": str(e)}
status_code=400,
detail={
"error": "validation_error",
"message": str(e),
"type": "invalid_request_error"
}
)
except Exception as e:
logger.error(f"Error generating audio: {str(e)}")
raise HTTPException(
status_code=500, detail={"error": "Server error", "message": str(e)}
status_code=500,
detail={
"error": "processing_error",
"message": str(e),
"type": "server_error"
}
)
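
A rough client-side sketch of the development endpoints above (illustrative only; host/port, the voice name, and the exact request fields are assumptions based on the schema names used in the route signatures):

import requests

BASE_URL = "http://localhost:8880"  # assumed host/port

def phonemize_then_speak(text: str) -> None:
    # Assumed request body for PhonemeRequest: text + language
    ph = requests.post(f"{BASE_URL}/dev/phonemize", json={"text": text, "language": "a"}).json()
    print("phonemes:", ph["phonemes"][:80])

    # Assumed request body for GenerateFromPhonemesRequest: phonemes + voice + speed
    with requests.post(
        f"{BASE_URL}/dev/generate_from_phonemes",
        json={"phonemes": ph["phonemes"], "voice": "af_bella", "speed": 1.0},
        stream=True,
    ) as resp:
        resp.raise_for_status()
        with open("phoneme_speech.wav", "wb") as out:
            for chunk in resp.iter_content(chunk_size=8192):
                out.write(chunk)

# phonemize_then_speak("Hello world")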

View file

@ -1,23 +1,72 @@
from typing import AsyncGenerator, List, Union
"""OpenAI-compatible router for text-to-speech"""
from fastapi import APIRouter, Depends, Header, HTTPException, Response, Request
from fastapi.responses import StreamingResponse
import json
import os
from typing import AsyncGenerator, Dict, List, Union
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from fastapi.responses import StreamingResponse, FileResponse
from loguru import logger
from ..services.audio import AudioService
from ..services.tts_service import TTSService
from ..structures.schemas import OpenAISpeechRequest
from ..core.config import settings
# Load OpenAI mappings
def load_openai_mappings() -> Dict:
"""Load OpenAI voice and model mappings from JSON"""
api_dir = os.path.dirname(os.path.dirname(__file__))
mapping_path = os.path.join(api_dir, "core", "openai_mappings.json")
try:
with open(mapping_path, 'r') as f:
return json.load(f)
except Exception as e:
logger.error(f"Failed to load OpenAI mappings: {e}")
return {"models": {}, "voices": {}}
# Global mappings
_openai_mappings = load_openai_mappings()
router = APIRouter(
tags=["OpenAI Compatible TTS"],
responses={404: {"description": "Not found"}},
)
# Global TTSService instance with lock
_tts_service = None
_init_lock = None
def get_tts_service() -> TTSService:
"""Dependency to get TTSService instance with database session"""
return TTSService() # Initialize TTSService with default settings
async def get_tts_service() -> TTSService:
"""Get global TTSService instance"""
global _tts_service, _init_lock
# Create lock if needed
if _init_lock is None:
import asyncio
_init_lock = asyncio.Lock()
# Initialize service if needed
if _tts_service is None:
async with _init_lock:
# Double check pattern
if _tts_service is None:
_tts_service = await TTSService.create()
logger.info("Created global TTSService instance")
return _tts_service
def get_model_name(model: str) -> str:
"""Get internal model name from OpenAI model name"""
base_name = _openai_mappings["models"].get(model)
if not base_name:
raise ValueError(f"Unsupported model: {model}")
# Add extension based on runtime config
extension = ".onnx" if settings.use_onnx else ".pth"
return base_name + extension
async def process_voices(
voice_input: Union[str, List[str]], tts_service: TTSService
@ -25,26 +74,37 @@ async def process_voices(
"""Process voice input into a combined voice, handling both string and list formats"""
# Convert input to list of voices
if isinstance(voice_input, str):
# Check if it's an OpenAI voice name
mapped_voice = _openai_mappings["voices"].get(voice_input)
if mapped_voice:
voice_input = mapped_voice
voices = [v.strip() for v in voice_input.split("+") if v.strip()]
else:
voices = voice_input
# For list input, map each voice if it's an OpenAI voice name
voices = [_openai_mappings["voices"].get(v, v) for v in voice_input]
voices = [v.strip() for v in voices if v.strip()]
if not voices:
raise ValueError("No voices provided")
# Check if all voices exist
# If single voice, validate and return it
if len(voices) == 1:
available_voices = await tts_service.list_voices()
if voices[0] not in available_voices:
raise ValueError(
f"Voice '{voices[0]}' not found. Available voices: {', '.join(sorted(available_voices))}"
)
return voices[0]
# For multiple voices, validate base voices exist
available_voices = await tts_service.list_voices()
for voice in voices:
if voice not in available_voices:
raise ValueError(
f"Voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}"
f"Base voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}"
)
# If single voice, return it directly
if len(voices) == 1:
return voices[0]
# Otherwise combine voices
# Combine voices
return await tts_service.combine_voices(voices=voices)
@ -64,7 +124,10 @@ async def stream_audio_chunks(
output_format=request.response_format,
):
# Check if client is still connected
if await client_request.is_disconnected():
is_disconnected = client_request.is_disconnected
if callable(is_disconnected):
is_disconnected = await is_disconnected()
if is_disconnected:
logger.info("Client disconnected, stopping audio generation")
break
yield chunk
@ -78,12 +141,23 @@ async def stream_audio_chunks(
async def create_speech(
request: OpenAISpeechRequest,
client_request: Request,
tts_service: TTSService = Depends(get_tts_service),
x_raw_response: str = Header(None, alias="x-raw-response"),
):
"""OpenAI-compatible endpoint for text-to-speech"""
# Validate model before processing request
if request.model not in _openai_mappings["models"]:
raise HTTPException(
status_code=400,
detail={
"error": "invalid_model",
"message": f"Unsupported model: {request.model}",
"type": "invalid_request_error"
}
)
try:
# Process voice combination and validate
# model_name = get_model_name(request.model)
tts_service = await get_tts_service()
voice_to_use = await process_voices(request.voice, tts_service)
# Set content type based on format
@ -98,29 +172,79 @@ async def create_speech(
# Check if streaming is requested (default for OpenAI client)
if request.stream:
# Stream audio chunks as they're generated
# Create generator but don't start it yet
generator = stream_audio_chunks(tts_service, request, client_request)
# If download link requested, wrap generator with temp file writer
if request.return_download_link:
from ..services.temp_manager import TempFileWriter
temp_writer = TempFileWriter(request.response_format)
await temp_writer.__aenter__() # Initialize temp file
# Get download path immediately after temp file creation
download_path = temp_writer.download_path
# Create response headers with download path
headers = {
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
"X-Accel-Buffering": "no",
"Cache-Control": "no-cache",
"Transfer-Encoding": "chunked",
"X-Download-Path": download_path
}
# Create async generator for streaming
async def dual_output():
try:
# Write chunks to temp file and stream
async for chunk in generator:
if chunk: # Skip empty chunks
await temp_writer.write(chunk)
yield chunk
# Finalize the temp file
await temp_writer.finalize()
except Exception as e:
logger.error(f"Error in dual output streaming: {e}")
await temp_writer.__aexit__(type(e), e, e.__traceback__)
raise
finally:
# Ensure temp writer is closed
if not temp_writer._finalized:
await temp_writer.__aexit__(None, None, None)
# Stream with temp file writing
return StreamingResponse(
dual_output(),
media_type=content_type,
headers=headers
)
# Standard streaming without download link
return StreamingResponse(
stream_audio_chunks(tts_service, request, client_request),
generator,
media_type=content_type,
headers={
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
"X-Accel-Buffering": "no", # Disable proxy buffering
"Cache-Control": "no-cache", # Prevent caching
"Transfer-Encoding": "chunked", # Enable chunked transfer encoding
},
"X-Accel-Buffering": "no",
"Cache-Control": "no-cache",
"Transfer-Encoding": "chunked"
}
)
else:
# Generate complete audio
audio, _ = tts_service._generate_audio(
# Generate complete audio using public interface
audio, _ = await tts_service.generate_audio(
text=request.input,
voice=voice_to_use,
speed=request.speed,
stitch_long_output=True,
speed=request.speed
)
# Convert to requested format
content = AudioService.convert_audio(
audio, 24000, request.response_format, is_first_chunk=True, stream=False
# Convert to requested format with proper finalization
content = await AudioService.convert_audio(
audio, 24000, request.response_format,
is_first_chunk=True,
is_last_chunk=True
)
return Response(
@ -133,32 +257,98 @@ async def create_speech(
)
except ValueError as e:
logger.error(f"Invalid request: {str(e)}")
# Handle validation errors
logger.warning(f"Invalid request: {str(e)}")
raise HTTPException(
status_code=400, detail={"error": "Invalid request", "message": str(e)}
status_code=400,
detail={
"error": "validation_error",
"message": str(e),
"type": "invalid_request_error"
}
)
except RuntimeError as e:
# Handle runtime/processing errors
logger.error(f"Processing error: {str(e)}")
raise HTTPException(
status_code=500,
detail={
"error": "processing_error",
"message": str(e),
"type": "server_error"
}
)
except Exception as e:
logger.error(f"Error generating speech: {str(e)}")
# Handle unexpected errors
logger.error(f"Unexpected error in speech generation: {str(e)}")
raise HTTPException(
status_code=500, detail={"error": "Server error", "message": str(e)}
status_code=500,
detail={
"error": "processing_error",
"message": str(e),
"type": "server_error"
}
)
@router.get("/download/{filename}")
async def download_audio_file(filename: str):
"""Download a generated audio file from temp storage"""
try:
from ..core.paths import _find_file, get_content_type
# Search for file in temp directory
file_path = await _find_file(
filename=filename,
search_paths=[settings.temp_file_dir]
)
# Get content type from path helper
content_type = await get_content_type(file_path)
return FileResponse(
file_path,
media_type=content_type,
filename=filename,
headers={
"Cache-Control": "no-cache",
"Content-Disposition": f"attachment; filename={filename}"
}
)
except Exception as e:
logger.error(f"Error serving download file {filename}: {e}")
raise HTTPException(
status_code=500,
detail={
"error": "server_error",
"message": "Failed to serve audio file",
"type": "server_error"
}
)
@router.get("/audio/voices")
async def list_voices(tts_service: TTSService = Depends(get_tts_service)):
async def list_voices():
"""List all available voices for text-to-speech"""
try:
tts_service = await get_tts_service()
voices = await tts_service.list_voices()
return {"voices": voices}
except Exception as e:
logger.error(f"Error listing voices: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(
status_code=500,
detail={
"error": "server_error",
"message": "Failed to retrieve voice list",
"type": "server_error"
}
)
@router.post("/audio/voices/combine")
async def combine_voices(
request: Union[str, List[str]], tts_service: TTSService = Depends(get_tts_service)
):
async def combine_voices(request: Union[str, List[str]]):
"""Combine multiple voices into a new voice.
Args:
@ -174,18 +364,38 @@ async def combine_voices(
- 500: Server error (file system issues, combination failed)
"""
try:
tts_service = await get_tts_service()
combined_voice = await process_voices(request, tts_service)
voices = await tts_service.list_voices()
return {"voices": voices, "voice": combined_voice}
except ValueError as e:
logger.error(f"Invalid voice combination request: {str(e)}")
logger.warning(f"Invalid voice combination request: {str(e)}")
raise HTTPException(
status_code=400, detail={"error": "Invalid request", "message": str(e)}
status_code=400,
detail={
"error": "validation_error",
"message": str(e),
"type": "invalid_request_error"
}
)
except RuntimeError as e:
logger.error(f"Voice combination processing error: {str(e)}")
raise HTTPException(
status_code=500,
detail={
"error": "processing_error",
"message": "Failed to process voice combination request",
"type": "server_error"
}
)
except Exception as e:
logger.error(f"Server error during voice combination: {str(e)}")
logger.error(f"Unexpected error in voice combination: {str(e)}")
raise HTTPException(
status_code=500, detail={"error": "Server error", "message": "Server error"}
status_code=500,
detail={
"error": "server_error",
"message": "An unexpected error occurred",
"type": "server_error"
}
)
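
End to end, the streaming-plus-download flow above can be exercised like this (illustrative only; host/port and the model/voice names are assumptions, and openai_mappings.json decides which aliases are actually accepted):

import requests

BASE_URL = "http://localhost:8880"  # assumed host/port

def stream_speech(text: str) -> None:
    payload = {
        "model": "tts-1",               # assumed OpenAI-style alias from openai_mappings.json
        "voice": "af_bella",            # placeholder voice name
        "input": text,
        "response_format": "mp3",
        "stream": True,
        "return_download_link": True,
    }
    with requests.post(f"{BASE_URL}/v1/audio/speech", json=payload, stream=True) as resp:
        resp.raise_for_status()
        download_path = resp.headers.get("X-Download-Path")
        with open("speech.mp3", "wb") as out:
            for chunk in resp.iter_content(chunk_size=8192):
                out.write(chunk)

    # The router is mounted under /v1 (see main.py above), so the download route
    # is /v1/download/<filename>; X-Download-Path itself carries /download/<filename>.
    if download_path:
        saved = requests.get(f"{BASE_URL}/v1{download_path}")
        print(f"server copy at {download_path}: {len(saved.content)} bytes")

# stream_speech("Hello from the streaming endpoint")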

View file

@ -0,0 +1,48 @@
"""Web player router with async file serving."""
from fastapi import APIRouter, HTTPException
from fastapi.responses import Response
from loguru import logger
from ..core.config import settings
from ..core.paths import get_web_file_path, read_bytes, get_content_type
router = APIRouter(
tags=["Web Player"],
responses={404: {"description": "Not found"}},
)
@router.get("/{filename:path}")
async def serve_web_file(filename: str):
"""Serve web player static files asynchronously."""
if not settings.enable_web_player:
raise HTTPException(status_code=404, detail="Web player is disabled")
try:
# Default to index.html for root path
if filename == "" or filename == "/":
filename = "index.html"
# Get file path
file_path = await get_web_file_path(filename)
# Read file content
content = await read_bytes(file_path)
# Get content type
content_type = await get_content_type(file_path)
return Response(
content=content,
media_type=content_type,
headers={
"Cache-Control": "no-cache", # Prevent caching during development
}
)
except RuntimeError as e:
logger.warning(f"Web file not found: {filename}")
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
logger.error(f"Error serving web file {filename}: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
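
A one-line check that the player above is being served (host and port assumed; the router is mounted at /web in main.py):

import requests

resp = requests.get("http://localhost:8880/web/")  # assumed host/port
print(resp.status_code, resp.headers.get("content-type"))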

View file

@ -10,37 +10,41 @@ from loguru import logger
from pydub import AudioSegment
from ..core.config import settings
from .streaming_audio_writer import StreamingAudioWriter
class AudioNormalizer:
"""Handles audio normalization state for a single stream"""
def __init__(self):
self.int16_max = np.iinfo(np.int16).max
self.chunk_trim_ms = settings.gap_trim_ms
self.sample_rate = 24000 # Sample rate of the audio
self.samples_to_trim = int(self.chunk_trim_ms * self.sample_rate / 1000)
def normalize(
self, audio_data: np.ndarray, is_last_chunk: bool = False
) -> np.ndarray:
"""Convert audio data to int16 range and trim chunk boundaries"""
if len(audio_data) == 0:
raise ValueError("Audio data cannot be empty")
async def normalize(self, audio_data: np.ndarray) -> np.ndarray:
"""Convert audio data to int16 range and trim silence from start and end
Args:
audio_data: Input audio data as numpy array
# Simple float32 to int16 conversion
audio_float = audio_data.astype(np.float32)
Returns:
Normalized and trimmed audio data
"""
if len(audio_data) == 0:
raise ValueError("Empty audio data")
# Trim start and end if enough samples
if len(audio_data) > (2 * self.samples_to_trim):
audio_data = audio_data[self.samples_to_trim:-self.samples_to_trim]
# Trim for non-final chunks
if not is_last_chunk and len(audio_float) > self.samples_to_trim:
audio_float = audio_float[:-self.samples_to_trim]
# Direct scaling like the non-streaming version
return (audio_float * 32767).astype(np.int16)
# Scale directly to int16 range with clipping
return np.clip(audio_data * 32767, -32768, 32767).astype(np.int16)
class AudioService:
"""Service for audio format conversions"""
"""Service for audio format conversions with streaming support"""
# Supported formats
SUPPORTED_FORMATS = {"wav", "mp3", "opus", "flac", "aac", "pcm", "ogg"}
# Default audio format settings balanced for speed and compression
DEFAULT_SETTINGS = {
@ -59,137 +63,60 @@ class AudioService:
},
}
_writers = {}
@staticmethod
def convert_audio(
async def convert_audio(
audio_data: np.ndarray,
sample_rate: int,
output_format: str,
is_first_chunk: bool = True,
is_last_chunk: bool = False,
normalizer: AudioNormalizer = None,
format_settings: dict = None,
stream: bool = True,
) -> bytes:
"""Convert audio data to specified format
"""Convert audio data to specified format with streaming support
Args:
audio_data: Numpy array of audio samples
sample_rate: Sample rate of the audio
output_format: Target format (wav, mp3, opus, flac, pcm)
is_first_chunk: Whether this is the first chunk of a stream
normalizer: Optional AudioNormalizer instance for consistent normalization across chunks
format_settings: Optional dict of format-specific settings to override defaults
Example: {
"mp3": {
"bitrate_mode": "VARIABLE",
"compression_level": 0.8
}
}
Default settings balance speed and compression:
optimized for localhost @ 0.0
- MP3: constant bitrate, no compression (0.0)
- OPUS: no compression (0.0)
- FLAC: no compression (0.0)
output_format: Target format (wav, mp3, ogg, pcm)
is_first_chunk: Whether this is the first chunk
is_last_chunk: Whether this is the last chunk
normalizer: Optional AudioNormalizer instance for consistent normalization
Returns:
Bytes of the converted audio
Bytes of the converted audio chunk
"""
buffer = BytesIO()
try:
# Validate format
if output_format not in AudioService.SUPPORTED_FORMATS:
raise ValueError(f"Format {output_format} not supported")
# Always normalize audio to ensure proper amplitude scaling
if normalizer is None:
normalizer = AudioNormalizer()
normalized_audio = normalizer.normalize(
audio_data, is_last_chunk=is_last_chunk
)
normalized_audio = await normalizer.normalize(audio_data)
if output_format == "pcm":
# Raw 16-bit PCM samples, no header
buffer.write(normalized_audio.tobytes())
elif output_format == "wav":
# Write the WAV header ourselves so that we can specify a "fake" data size.
# This is necessary for streaming responses to work properly: if we simply
# concatenated individual WAV files then the initial chunk's header length
# would be shorter than the full file length and subsequent chunks' RIFF
# headers would appear in the middle of the audio data.
if is_first_chunk:
# Modified from Python stdlib's wave.py module:
buffer.write(b'RIFF')
buffer.write(struct.pack('<L4s4sLHHLLHH4s',
0xFFFFFFFF, # total size (set to max)
b'WAVE',
b'fmt ',
16,
1, # PCM format
1, # channels
sample_rate,
sample_rate * 2, # byte rate
2, # block align
16, # bits per sample
b'data'
))
buffer.write(struct.pack('<L', 0xFFFFFFFF)) # data size (set to max)
# write raw PCM data
buffer.write(normalized_audio.tobytes())
elif output_format == "mp3":
# MP3 format with proper framing
settings = format_settings.get("mp3", {}) if format_settings else {}
settings = {**AudioService.DEFAULT_SETTINGS["mp3"], **settings}
sf.write(
buffer, normalized_audio, sample_rate, format="MP3", **settings
)
elif output_format == "opus":
# Opus format in OGG container
settings = format_settings.get("opus", {}) if format_settings else {}
settings = {**AudioService.DEFAULT_SETTINGS["opus"], **settings}
sf.write(
buffer,
normalized_audio,
sample_rate,
format="OGG",
subtype="OPUS",
**settings,
)
elif output_format == "flac":
# FLAC format with proper framing
if is_first_chunk:
logger.info("Starting FLAC stream...")
settings = format_settings.get("flac", {}) if format_settings else {}
settings = {**AudioService.DEFAULT_SETTINGS["flac"], **settings}
sf.write(
buffer,
normalized_audio,
sample_rate,
format="FLAC",
subtype="PCM_16",
**settings,
)
elif output_format == "aac":
# Convert numpy array directly to AAC using pydub
audio_segment = AudioSegment(
normalized_audio.tobytes(),
frame_rate=sample_rate,
sample_width=normalized_audio.dtype.itemsize,
channels=1 if len(normalized_audio.shape) == 1 else normalized_audio.shape[1]
)
settings = format_settings.get("aac", {}) if format_settings else {}
settings = {**AudioService.DEFAULT_SETTINGS["aac"], **settings}
audio_segment.export(
buffer,
format="adts", # ADTS is a common AAC container format
bitrate=settings["bitrate"]
)
else:
raise ValueError(
f"Format {output_format} not supported. Supported formats are: wav, mp3, opus, flac, pcm, aac."
# Get or create format-specific writer
writer_key = f"{output_format}_{sample_rate}"
if is_first_chunk or writer_key not in AudioService._writers:
AudioService._writers[writer_key] = StreamingAudioWriter(
output_format, sample_rate
)
writer = AudioService._writers[writer_key]
buffer.seek(0)
return buffer.getvalue()
# Write audio data first
if len(normalized_audio) > 0:
chunk_data = writer.write_chunk(normalized_audio)
# Then finalize if this is the last chunk
if is_last_chunk:
final_data = writer.write_chunk(finalize=True)
del AudioService._writers[writer_key]
return final_data if final_data else b''
return chunk_data if chunk_data else b''
except Exception as e:
logger.error(f"Error converting audio to {output_format}: {str(e)}")
raise ValueError(f"Failed to convert audio to {output_format}: {str(e)}")
logger.error(f"Error converting audio stream to {output_format}: {str(e)}")
raise ValueError(f"Failed to convert audio stream to {output_format}: {str(e)}")

View file

@ -0,0 +1,207 @@
"""Audio conversion service with proper streaming support"""
from io import BytesIO
import struct
from typing import Optional
import numpy as np
import soundfile as sf
from loguru import logger
from pydub import AudioSegment
class StreamingAudioWriter:
"""Handles streaming audio format conversions"""
def __init__(self, format: str, sample_rate: int, channels: int = 1):
self.format = format.lower()
self.sample_rate = sample_rate
self.channels = channels
self.bytes_written = 0
self.buffer = BytesIO()
# Format-specific setup
if self.format == "wav":
self._write_wav_header_initial()
elif self.format in ["ogg", "opus"]:
# For OGG/Opus, write to memory buffer
self.writer = sf.SoundFile(
file=self.buffer,
mode='w',
samplerate=sample_rate,
channels=channels,
format='OGG',
subtype='VORBIS' if self.format == "ogg" else "OPUS"
)
elif self.format == "flac":
# For FLAC, write to memory buffer
self.writer = sf.SoundFile(
file=self.buffer,
mode='w',
samplerate=sample_rate,
channels=channels,
format='FLAC'
)
elif self.format in ["mp3", "aac"]:
# For MP3/AAC, we'll use pydub's incremental writer
self.segments = [] # Store segments until we have enough data
self.total_duration = 0 # Track total duration in milliseconds
# Initialize an empty AudioSegment as our encoder
self.encoder = AudioSegment.silent(duration=0, frame_rate=self.sample_rate)
elif self.format == "pcm":
# PCM doesn't need initialization, we'll write raw bytes
pass
else:
raise ValueError(f"Unsupported format: {format}")
def _write_wav_header_initial(self) -> None:
"""Write initial WAV header with placeholders"""
self.buffer.write(b'RIFF')
self.buffer.write(struct.pack('<L', 0)) # Placeholder for file size
self.buffer.write(b'WAVE')
self.buffer.write(b'fmt ')
self.buffer.write(struct.pack('<L', 16)) # fmt chunk size
self.buffer.write(struct.pack('<H', 1)) # PCM format
self.buffer.write(struct.pack('<H', self.channels))
self.buffer.write(struct.pack('<L', self.sample_rate))
self.buffer.write(struct.pack('<L', self.sample_rate * self.channels * 2)) # Byte rate
self.buffer.write(struct.pack('<H', self.channels * 2)) # Block align
self.buffer.write(struct.pack('<H', 16)) # Bits per sample
self.buffer.write(b'data')
self.buffer.write(struct.pack('<L', 0)) # Placeholder for data size
def write_chunk(self, audio_data: Optional[np.ndarray] = None, finalize: bool = False) -> bytes:
"""Write a chunk of audio data and return bytes in the target format.
Args:
audio_data: Audio data to write, or None if finalizing
finalize: Whether this is the final write to close the stream
"""
output_buffer = BytesIO()
if finalize:
if self.format == "wav":
# Calculate actual file and data sizes
file_size = self.bytes_written + 36 # RIFF header bytes
data_size = self.bytes_written
# Seek to the beginning to overwrite the placeholders
self.buffer.seek(4)
self.buffer.write(struct.pack('<L', file_size))
self.buffer.seek(40)
self.buffer.write(struct.pack('<L', data_size))
self.buffer.seek(0)
return self.buffer.read()
elif self.format in ["ogg", "opus", "flac"]:
self.writer.close()
return self.buffer.getvalue()
elif self.format in ["mp3", "aac"]:
if hasattr(self, 'encoder') and len(self.encoder) > 0:
format_args = {
"mp3": {"format": "mp3", "codec": "libmp3lame"},
"aac": {"format": "adts", "codec": "aac"}
}[self.format]
parameters = []
if self.format == "mp3":
parameters.extend([
"-q:a", "2",
"-write_xing", "1", # XING header for MP3
"-id3v1", "1",
"-id3v2", "1",
"-write_vbr", "1",
"-vbr_quality", "2"
])
elif self.format == "aac":
parameters.extend([
"-q:a", "2",
"-write_xing", "0",
"-write_id3v1", "0",
"-write_id3v2", "0"
])
self.encoder.export(
output_buffer,
**format_args,
bitrate="192k",
parameters=parameters
)
self.encoder = None
return output_buffer.getvalue()
if audio_data is None or len(audio_data) == 0:
return b''
if self.format == "wav":
# Write raw PCM data
self.buffer.write(audio_data.tobytes())
self.bytes_written += len(audio_data.tobytes())
return b''
elif self.format in ["ogg", "opus", "flac"]:
# Write to soundfile buffer
self.writer.write(audio_data)
self.writer.flush()
return self.buffer.getvalue()
elif self.format in ["mp3", "aac"]:
# Convert chunk to AudioSegment and encode
segment = AudioSegment(
audio_data.tobytes(),
frame_rate=self.sample_rate,
sample_width=audio_data.dtype.itemsize,
channels=self.channels
)
# Track total duration
self.total_duration += len(segment)
# Add segment to encoder
self.encoder += segment
# Export current state to buffer without final metadata
format_args = {
"mp3": {"format": "mp3", "codec": "libmp3lame"},
"aac": {"format": "adts", "codec": "aac"}
}[self.format]
# For chunks, export without duration metadata or XING headers
self.encoder.export(output_buffer, **format_args, bitrate="192k", parameters=[
"-q:a", "2",
"-write_xing", "0" # No XING headers for chunks
])
# Get the encoded data
encoded_data = output_buffer.getvalue()
# Reset encoder to prevent memory growth
self.encoder = AudioSegment.silent(duration=0, frame_rate=self.sample_rate)
return encoded_data
elif self.format == "pcm":
# Write raw bytes
return audio_data.tobytes()
return b''
def close(self) -> Optional[bytes]:
"""Finish the audio file and return any remaining data"""
if self.format == "wav":
# Re-finalize WAV file by updating headers
self.buffer.seek(0)
file_content = self.write_chunk(finalize=True)
return file_content
elif self.format in ["ogg", "opus", "flac"]:
# Finalize other formats
self.writer.close()
return self.buffer.getvalue()
elif self.format in ["mp3", "aac"]:
# Finalize MP3/AAC
final_data = self.write_chunk(finalize=True)
return final_data
return None
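
A standalone sketch of the writer above with synthetic int16 audio (for WAV the per-chunk calls return empty bytes and finalize emits the whole file, as implemented above):

import numpy as np

def demo_streaming_writer() -> None:
    writer = StreamingAudioWriter(format="wav", sample_rate=24000, channels=1)
    body = b""
    for _ in range(3):
        pcm = (np.random.uniform(-0.3, 0.3, 24000) * 32767).astype(np.int16)
        body += writer.write_chunk(pcm)
    body += writer.write_chunk(finalize=True)
    with open("demo.wav", "wb") as out:
        out.write(body)

# demo_streaming_writer()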

View file

@ -0,0 +1,139 @@
"""Temporary file writer for audio downloads"""
import os
import tempfile
from typing import Optional, List
import aiofiles
import aiofiles.os
from fastapi import HTTPException
from loguru import logger
from ..core.config import settings
async def cleanup_temp_files() -> None:
"""Clean up old temp files"""
try:
if not await aiofiles.os.path.exists(settings.temp_file_dir):
await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)
return
# Get all temp files with stats
files = []
total_size = 0
# Use os.scandir for sync iteration, but aiofiles.os.stat for async stats
for entry in os.scandir(settings.temp_file_dir):
if entry.is_file():
stat = await aiofiles.os.stat(entry.path)
files.append((entry.path, stat.st_mtime, stat.st_size))
total_size += stat.st_size
# Sort by modification time (oldest first)
files.sort(key=lambda x: x[1])
# Remove files if:
# 1. They're too old
# 2. We have too many files
# 3. Directory is too large
current_time = (await aiofiles.os.stat(settings.temp_file_dir)).st_mtime
max_age = settings.max_temp_dir_age_hours * 3600
for path, mtime, size in files:
should_delete = False
# Check age
if current_time - mtime > max_age:
should_delete = True
logger.info(f"Deleting old temp file: {path}")
# Check count limit
elif len(files) > settings.max_temp_dir_count:
should_delete = True
logger.info(f"Deleting excess temp file: {path}")
# Check size limit
elif total_size > settings.max_temp_dir_size_mb * 1024 * 1024:
should_delete = True
logger.info(f"Deleting to reduce directory size: {path}")
if should_delete:
try:
await aiofiles.os.remove(path)
total_size -= size
logger.info(f"Deleted temp file: {path}")
except Exception as e:
logger.warning(f"Failed to delete temp file {path}: {e}")
except Exception as e:
logger.warning(f"Error during temp file cleanup: {e}")
class TempFileWriter:
"""Handles writing audio chunks to a temp file"""
def __init__(self, format: str):
"""Initialize temp file writer
Args:
format: Audio format extension (mp3, wav, etc)
"""
self.format = format
self.temp_file = None
self._finalized = False
async def __aenter__(self):
"""Async context manager entry"""
# Clean up old files first
await cleanup_temp_files()
# Create temp file with proper extension
await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)
temp = tempfile.NamedTemporaryFile(
dir=settings.temp_file_dir,
delete=False,
suffix=f".{self.format}",
mode='wb'
)
self.temp_file = await aiofiles.open(temp.name, mode='wb')
self.temp_path = temp.name
temp.close() # Close sync file, we'll use async version
# Generate download path immediately
self.download_path = f"/download/{os.path.basename(self.temp_path)}"
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit"""
try:
if self.temp_file and not self._finalized:
await self.temp_file.close()
self._finalized = True
except Exception as e:
logger.error(f"Error closing temp file: {e}")
async def write(self, chunk: bytes) -> None:
"""Write a chunk of audio data
Args:
chunk: Audio data bytes to write
"""
if self._finalized:
raise RuntimeError("Cannot write to finalized temp file")
await self.temp_file.write(chunk)
await self.temp_file.flush()
async def finalize(self) -> str:
"""Close temp file and return download path
Returns:
Path to use for downloading the temp file
"""
if self._finalized:
raise RuntimeError("Temp file already finalized")
await self.temp_file.close()
self._finalized = True
return f"/download/{os.path.basename(self.temp_path)}"

View file

@ -1,13 +1,19 @@
"""Text processing pipeline."""
from .normalizer import normalize_text
from .phonemizer import EspeakBackend, PhonemizerBackend, phonemize
from .vocabulary import VOCAB, decode_tokens, tokenize
from .phonemizer import phonemize
from .vocabulary import tokenize
from .text_processor import process_text_chunk, smart_split
def process_text(text: str) -> list[int]:
"""Process text into token IDs (for backward compatibility)."""
return process_text_chunk(text)
__all__ = [
"normalize_text",
"phonemize",
"tokenize",
"decode_tokens",
"VOCAB",
"PhonemizerBackend",
"EspeakBackend",
'normalize_text',
'phonemize',
'tokenize',
'process_text',
'process_text_chunk',
'smart_split'
]
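
A sketch of consuming the smart_split generator exported here (illustrative only):

import asyncio

async def demo_smart_split(text: str) -> None:
    async for chunk_text, tokens in smart_split(text):
        print(f"{len(tokens):4d} tokens | {chunk_text[:60]!r}")

# asyncio.run(demo_smart_split("A long passage of text to be chunked..."))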

View file

@ -1,53 +0,0 @@
"""Text chunking service"""
import re
from ...core.config import settings
def split_text(text: str, max_chunk=None):
"""Split text into chunks on natural pause points
Args:
text: Text to split into chunks
max_chunk: Maximum chunk size (defaults to settings.max_chunk_size)
"""
if max_chunk is None:
max_chunk = settings.max_chunk_size
if not isinstance(text, str):
text = str(text) if text is not None else ""
text = text.strip()
if not text:
return
# First split into sentences
sentences = re.split(r"(?<=[.!?])\s+", text)
for sentence in sentences:
sentence = sentence.strip()
if not sentence:
continue
# For medium-length sentences, split on punctuation
if len(sentence) > max_chunk: # Lower threshold for more consistent sizes
# First try splitting on semicolons and colons
parts = re.split(r"(?<=[;:])\s+", sentence)
for part in parts:
part = part.strip()
if not part:
continue
# If part is still long, split on commas
if len(part) > max_chunk:
subparts = re.split(r"(?<=,)\s+", part)
for subpart in subparts:
subpart = subpart.strip()
if subpart:
yield subpart
else:
yield part
else:
yield sentence

View file

@ -0,0 +1,203 @@
"""Unified text processing for TTS with smart chunking."""
import re
import time
from typing import AsyncGenerator, List, Tuple
from loguru import logger
from .phonemizer import phonemize
from .normalizer import normalize_text
from .vocabulary import tokenize
# Target token ranges
TARGET_MIN = 200
TARGET_MAX = 350
ABSOLUTE_MAX = 500
def process_text_chunk(text: str, language: str = "a", skip_phonemize: bool = False) -> List[int]:
"""Process a chunk of text through normalization, phonemization, and tokenization.
Args:
text: Text chunk to process
language: Language code for phonemization
skip_phonemize: If True, treat input as phonemes and skip normalization/phonemization
Returns:
List of token IDs
"""
start_time = time.time()
if skip_phonemize:
# Input is already phonemes, just tokenize
t0 = time.time()
tokens = tokenize(text)
t1 = time.time()
else:
# Normal text processing pipeline
t0 = time.time()
normalized = normalize_text(text)
t1 = time.time()
t0 = time.time()
phonemes = phonemize(normalized, language, normalize=False) # Already normalized
t1 = time.time()
t0 = time.time()
tokens = tokenize(phonemes)
t1 = time.time()
total_time = time.time() - start_time
logger.debug(f"Total processing took {total_time*1000:.2f}ms for chunk: '{text[:50]}...'")
return tokens
async def yield_chunk(text: str, tokens: List[int], chunk_count: int) -> Tuple[str, List[int]]:
"""Yield a chunk with consistent logging."""
logger.debug(f"Yielding chunk {chunk_count}: '{text[:50]}...' ({len(tokens)} tokens)")
return text, tokens
def process_text(text: str, language: str = "a") -> List[int]:
"""Process text into token IDs.
Args:
text: Text to process
language: Language code for phonemization
Returns:
List of token IDs
"""
if not isinstance(text, str):
text = str(text) if text is not None else ""
text = text.strip()
if not text:
return []
return process_text_chunk(text, language)
def get_sentence_info(text: str) -> List[Tuple[str, List[int], int]]:
"""Process all sentences and return info."""
sentences = re.split(r'([.!?;:])', text)
results = []
for i in range(0, len(sentences), 2):
sentence = sentences[i].strip()
punct = sentences[i + 1] if i + 1 < len(sentences) else ""
if not sentence:
continue
full = sentence + punct
tokens = process_text_chunk(full)
results.append((full, tokens, len(tokens)))
return results
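# Illustrative behaviour of get_sentence_info (token lists and counts are
# hypothetical; they depend on the phonemizer backend actually loaded):
#     get_sentence_info("Hello there. How are you?")
#     -> [("Hello there.", [...], 12), ("How are you?", [...], 11)]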
async def smart_split(text: str, max_tokens: int = ABSOLUTE_MAX) -> AsyncGenerator[Tuple[str, List[int]], None]:
"""Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens."""
start_time = time.time()
chunk_count = 0
logger.info(f"Starting smart split for {len(text)} chars")
# Process all sentences
sentences = get_sentence_info(text)
current_chunk = []
current_tokens = []
current_count = 0
for sentence, tokens, count in sentences:
# Handle sentences that exceed max tokens
if count > max_tokens:
# Yield current chunk if any
if current_chunk:
chunk_text = " ".join(current_chunk)
chunk_count += 1
logger.debug(f"Yielding chunk {chunk_count}: '{chunk_text[:50]}...' ({current_count} tokens)")
yield chunk_text, current_tokens
current_chunk = []
current_tokens = []
current_count = 0
# Split long sentence on commas
clauses = re.split(r'([,])', sentence)
clause_chunk = []
clause_tokens = []
clause_count = 0
for j in range(0, len(clauses), 2):
clause = clauses[j].strip()
comma = clauses[j + 1] if j + 1 < len(clauses) else ""
if not clause:
continue
full_clause = clause + comma
tokens = process_text_chunk(full_clause)
count = len(tokens)
# If adding clause keeps us under max and not optimal yet
if clause_count + count <= max_tokens and clause_count + count <= TARGET_MAX:
clause_chunk.append(full_clause)
clause_tokens.extend(tokens)
clause_count += count
else:
# Yield clause chunk if we have one
if clause_chunk:
chunk_text = " ".join(clause_chunk)
chunk_count += 1
logger.debug(f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}...' ({clause_count} tokens)")
yield chunk_text, clause_tokens
clause_chunk = [full_clause]
clause_tokens = tokens
clause_count = count
# Don't forget last clause chunk
if clause_chunk:
chunk_text = " ".join(clause_chunk)
chunk_count += 1
logger.debug(f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}...' ({clause_count} tokens)")
yield chunk_text, clause_tokens
# Regular sentence handling
elif current_count >= TARGET_MIN and current_count + count > TARGET_MAX:
# If we have a good sized chunk and adding next sentence exceeds target,
# yield current chunk and start new one
chunk_text = " ".join(current_chunk)
chunk_count += 1
logger.info(f"Yielding chunk {chunk_count}: '{chunk_text[:50]}...' ({current_count} tokens)")
yield chunk_text, current_tokens
current_chunk = [sentence]
current_tokens = tokens
current_count = count
elif current_count + count <= TARGET_MAX:
# Keep building chunk while under target max
current_chunk.append(sentence)
current_tokens.extend(tokens)
current_count += count
elif current_count + count <= max_tokens and current_count < TARGET_MIN:
# Only exceed target max if we haven't reached minimum size yet
current_chunk.append(sentence)
current_tokens.extend(tokens)
current_count += count
else:
# Yield current chunk and start new one
if current_chunk:
chunk_text = " ".join(current_chunk)
chunk_count += 1
logger.info(f"Yielding chunk {chunk_count}: '{chunk_text[:50]}...' ({current_count} tokens)")
yield chunk_text, current_tokens
current_chunk = [sentence]
current_tokens = tokens
current_count = count
# Don't forget the last chunk
if current_chunk:
chunk_text = " ".join(current_chunk)
chunk_count += 1
logger.info(f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}...' ({current_count} tokens)")
yield chunk_text, current_tokens
total_time = time.time() - start_time
logger.info(f"Split completed in {total_time*1000:.2f}ms, produced {chunk_count} chunks")

View file

@ -1,175 +0,0 @@
import os
import threading
from abc import ABC, abstractmethod
from typing import List, Tuple
import numpy as np
import torch
from loguru import logger
from ..core.config import settings
class TTSBaseModel(ABC):
_instance = None
_lock = threading.Lock()
_device = None
VOICES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "voices")
@classmethod
async def setup(cls):
"""Initialize model and setup voices"""
with cls._lock:
# Set device
cuda_available = torch.cuda.is_available()
logger.info(f"CUDA available: {cuda_available}")
if cuda_available:
try:
# Test CUDA device
test_tensor = torch.zeros(1).cuda()
logger.info("CUDA test successful")
model_path = os.path.join(
settings.model_dir, settings.pytorch_model_path
)
cls._device = "cuda"
except Exception as e:
logger.error(f"CUDA test failed: {e}")
cls._device = "cpu"
else:
cls._device = "cpu"
model_path = os.path.join(settings.model_dir, settings.onnx_model_path)
logger.info(f"Initializing model on {cls._device}")
logger.info(f"Model dir: {settings.model_dir}")
logger.info(f"Model path: {model_path}")
logger.info(f"Files in model dir: {os.listdir(settings.model_dir)}")
# Initialize model first
model = cls.initialize(settings.model_dir, model_path=model_path)
if model is None:
raise RuntimeError(f"Failed to initialize {cls._device.upper()} model")
cls._instance = model
# Setup voices directory
os.makedirs(cls.VOICES_DIR, exist_ok=True)
# Copy base voices to local directory
base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
if os.path.exists(base_voices_dir):
for file in os.listdir(base_voices_dir):
if file.endswith(".pt"):
voice_name = file[:-3]
voice_path = os.path.join(cls.VOICES_DIR, file)
if not os.path.exists(voice_path):
try:
logger.info(
f"Copying base voice {voice_name} to voices directory"
)
base_path = os.path.join(base_voices_dir, file)
voicepack = torch.load(
base_path,
map_location=cls._device,
weights_only=True,
)
torch.save(voicepack, voice_path)
except Exception as e:
logger.error(
f"Error copying voice {voice_name}: {str(e)}"
)
# Count voices in directory
voice_count = len(
[f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")]
)
# Now that model and voices are ready, do warmup
try:
with open(
os.path.join(
os.path.dirname(os.path.dirname(__file__)),
"core",
"don_quixote.txt",
)
) as f:
warmup_text = f.read()
except Exception as e:
logger.warning(f"Failed to load warmup text: {e}")
warmup_text = "This is a warmup text that will be split into chunks for processing."
# Use warmup service after model is fully initialized
from .warmup import WarmupService
warmup = WarmupService()
# Load and warm up voices
loaded_voices = warmup.load_voices()
await warmup.warmup_voices(warmup_text, loaded_voices)
logger.info("Model warm-up complete")
# Count voices in directory
voice_count = len(
[f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")]
)
return voice_count
@classmethod
@abstractmethod
def initialize(cls, model_dir: str, model_path: str = None):
"""Initialize the model"""
pass
@classmethod
@abstractmethod
def process_text(cls, text: str, language: str) -> Tuple[str, List[int]]:
"""Process text into phonemes and tokens
Args:
text: Input text
language: Language code
Returns:
tuple[str, list[int]]: Phonemes and token IDs
"""
pass
@classmethod
@abstractmethod
def generate_from_text(
cls, text: str, voicepack: torch.Tensor, language: str, speed: float
) -> Tuple[np.ndarray, str]:
"""Generate audio from text
Args:
text: Input text
voicepack: Voice tensor
language: Language code
speed: Speed factor
Returns:
tuple[np.ndarray, str]: Generated audio samples and phonemes
"""
pass
@classmethod
@abstractmethod
def generate_from_tokens(
cls, tokens: List[int], voicepack: torch.Tensor, speed: float
) -> np.ndarray:
"""Generate audio from tokens
Args:
tokens: Token IDs
voicepack: Voice tensor
speed: Speed factor
Returns:
np.ndarray: Generated audio samples
"""
pass
@classmethod
def get_device(cls):
"""Get the current device"""
if cls._device is None:
raise RuntimeError("Model not initialized. Call setup() first.")
return cls._device

View file

@ -1,167 +0,0 @@
import os
import numpy as np
import torch
from loguru import logger
from onnxruntime import (
ExecutionMode,
GraphOptimizationLevel,
InferenceSession,
SessionOptions,
)
from ..core.config import settings
from .text_processing import phonemize, tokenize
from .tts_base import TTSBaseModel
class TTSCPUModel(TTSBaseModel):
_instance = None
_onnx_session = None
_device = "cpu"
@classmethod
def get_instance(cls):
"""Get the model instance"""
if cls._onnx_session is None:
raise RuntimeError("ONNX model not initialized. Call initialize() first.")
return cls._onnx_session
@classmethod
def initialize(cls, model_dir: str, model_path: str = None):
"""Initialize ONNX model for CPU inference"""
if cls._onnx_session is None:
try:
# Try loading ONNX model
onnx_path = os.path.join(model_dir, settings.onnx_model_path)
if not os.path.exists(onnx_path):
logger.error(f"ONNX model not found at {onnx_path}")
return None
logger.info(f"Loading ONNX model from {onnx_path}")
# Configure ONNX session for optimal performance
session_options = SessionOptions()
# Set optimization level
if settings.onnx_optimization_level == "all":
session_options.graph_optimization_level = (
GraphOptimizationLevel.ORT_ENABLE_ALL
)
elif settings.onnx_optimization_level == "basic":
session_options.graph_optimization_level = (
GraphOptimizationLevel.ORT_ENABLE_BASIC
)
else:
session_options.graph_optimization_level = (
GraphOptimizationLevel.ORT_DISABLE_ALL
)
# Configure threading
session_options.intra_op_num_threads = settings.onnx_num_threads
session_options.inter_op_num_threads = settings.onnx_inter_op_threads
# Set execution mode
session_options.execution_mode = (
ExecutionMode.ORT_PARALLEL
if settings.onnx_execution_mode == "parallel"
else ExecutionMode.ORT_SEQUENTIAL
)
# Enable/disable memory pattern optimization
session_options.enable_mem_pattern = settings.onnx_memory_pattern
# Configure CPU provider options
provider_options = {
"CPUExecutionProvider": {
"arena_extend_strategy": settings.onnx_arena_extend_strategy,
"cpu_memory_arena_cfg": "cpu:0",
}
}
session = InferenceSession(
onnx_path,
sess_options=session_options,
providers=["CPUExecutionProvider"],
provider_options=[provider_options],
)
cls._onnx_session = session
return session
except Exception as e:
logger.error(f"Failed to initialize ONNX model: {e}")
return None
return cls._onnx_session
@classmethod
def process_text(cls, text: str, language: str) -> tuple[str, list[int]]:
"""Process text into phonemes and tokens
Args:
text: Input text
language: Language code
Returns:
tuple[str, list[int]]: Phonemes and token IDs
"""
phonemes = phonemize(text, language)
tokens = tokenize(phonemes)
tokens = [0] + tokens + [0] # Add start/end tokens
return phonemes, tokens
@classmethod
def generate_from_text(
cls, text: str, voicepack: torch.Tensor, language: str, speed: float
) -> tuple[np.ndarray, str]:
"""Generate audio from text
Args:
text: Input text
voicepack: Voice tensor
language: Language code
speed: Speed factor
Returns:
tuple[np.ndarray, str]: Generated audio samples and phonemes
"""
if cls._onnx_session is None:
raise RuntimeError("ONNX model not initialized")
# Process text
phonemes, tokens = cls.process_text(text, language)
# Generate audio
audio = cls.generate_from_tokens(tokens, voicepack, speed)
return audio, phonemes
@classmethod
def generate_from_tokens(
cls, tokens: list[int], voicepack: torch.Tensor, speed: float
) -> np.ndarray:
"""Generate audio from tokens
Args:
tokens: Token IDs
voicepack: Voice tensor
speed: Speed factor
Returns:
np.ndarray: Generated audio samples
"""
if cls._onnx_session is None:
raise RuntimeError("ONNX model not initialized")
# Pre-allocate and prepare inputs
tokens_input = np.array([tokens], dtype=np.int64)
style_input = voicepack[
len(tokens) - 2
].numpy() # Already has correct dimensions
speed_input = np.full(
1, speed, dtype=np.float32
) # More efficient than ones * speed
# Run inference with optimized inputs
result = cls._onnx_session.run(
None, {"tokens": tokens_input, "style": style_input, "speed": speed_input}
)
return result[0]

View file

@ -1,262 +0,0 @@
import os
import time
import numpy as np
import torch
from ..builds.models import build_model
from loguru import logger
from ..core.config import settings
from .text_processing import phonemize, tokenize
from .tts_base import TTSBaseModel
# @torch.no_grad()
# def forward(model, tokens, ref_s, speed):
# """Forward pass through the model"""
# device = ref_s.device
# tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
# input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
# text_mask = length_to_mask(input_lengths).to(device)
# bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
# d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
# s = ref_s[:, 128:]
# d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
# x, _ = model.predictor.lstm(d)
# duration = model.predictor.duration_proj(x)
# duration = torch.sigmoid(duration).sum(axis=-1) / speed
# pred_dur = torch.round(duration).clamp(min=1).long()
# pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
# c_frame = 0
# for i in range(pred_aln_trg.size(0)):
# pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
# c_frame += pred_dur[0, i].item()
# en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
# F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
# t_en = model.text_encoder(tokens, input_lengths, text_mask)
# asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
# return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
@torch.no_grad()
def forward(model, tokens, ref_s, speed):
"""Forward pass through the model with moderate memory management"""
device = ref_s.device
try:
# Initial tensor setup with proper device placement
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
text_mask = length_to_mask(input_lengths).to(device)
# Split and clone reference signals with explicit device placement
s_content = ref_s[:, 128:].clone().to(device)
s_ref = ref_s[:, :128].clone().to(device)
# BERT and encoder pass
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
# Predictor forward pass
d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
x, _ = model.predictor.lstm(d)
# Duration prediction
duration = model.predictor.duration_proj(x)
duration = torch.sigmoid(duration).sum(axis=-1) / speed
pred_dur = torch.round(duration).clamp(min=1).long()
# Only cleanup large intermediates
del duration, x
# Alignment matrix construction
pred_aln_trg = torch.zeros(input_lengths.item(), pred_dur.sum().item(), device=device)
c_frame = 0
for i in range(pred_aln_trg.size(0)):
pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
c_frame += pred_dur[0, i].item()
pred_aln_trg = pred_aln_trg.unsqueeze(0)
# Matrix multiplications with selective cleanup
en = d.transpose(-1, -2) @ pred_aln_trg
del d # Free large intermediate tensor
F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
del en # Free large intermediate tensor
# Final text encoding and decoding
t_en = model.text_encoder(tokens, input_lengths, text_mask)
asr = t_en @ pred_aln_trg
del t_en # Free large intermediate tensor
# Final decoding and transfer to CPU
output = model.decoder(asr, F0_pred, N_pred, s_ref)
result = output.squeeze().cpu().numpy()
return result
finally:
# Let PyTorch handle most cleanup automatically
# Only explicitly free the largest tensors
del pred_aln_trg, asr
# def length_to_mask(lengths):
# """Create attention mask from lengths"""
# mask = (
# torch.arange(lengths.max())
# .unsqueeze(0)
# .expand(lengths.shape[0], -1)
# .type_as(lengths)
# )
# mask = torch.gt(mask + 1, lengths.unsqueeze(1))
# return mask
def length_to_mask(lengths):
"""Create attention mask from lengths - possibly optimized version"""
max_len = lengths.max()
# Create mask directly on the same device as lengths
mask = torch.arange(max_len, device=lengths.device)[None, :].expand(
lengths.shape[0], -1
)
# Avoid type_as by using the correct dtype from the start
if lengths.dtype != mask.dtype:
mask = mask.to(dtype=lengths.dtype)
# Fuse operations using broadcasting
return mask + 1 > lengths[:, None]
class TTSGPUModel(TTSBaseModel):
_instance = None
_device = "cuda"
@classmethod
def get_instance(cls):
"""Get the model instance"""
if cls._instance is None:
raise RuntimeError("GPU model not initialized. Call initialize() first.")
return cls._instance
@classmethod
def initialize(cls, model_dir: str, model_path: str):
"""Initialize PyTorch model for GPU inference"""
if cls._instance is None and torch.cuda.is_available():
try:
logger.info("Initializing GPU model")
model_path = os.path.join(model_dir, settings.pytorch_model_path)
model = build_model(model_path, cls._device)
cls._instance = model
return model
except Exception as e:
logger.error(f"Failed to initialize GPU model: {e}")
return None
return cls._instance
@classmethod
def process_text(cls, text: str, language: str) -> tuple[str, list[int]]:
"""Process text into phonemes and tokens
Args:
text: Input text
language: Language code
Returns:
tuple[str, list[int]]: Phonemes and token IDs
"""
phonemes = phonemize(text, language)
tokens = tokenize(phonemes)
return phonemes, tokens
@classmethod
def generate_from_text(
cls, text: str, voicepack: torch.Tensor, language: str, speed: float
) -> tuple[np.ndarray, str]:
"""Generate audio from text
Args:
text: Input text
voicepack: Voice tensor
language: Language code
speed: Speed factor
Returns:
tuple[np.ndarray, str]: Generated audio samples and phonemes
"""
if cls._instance is None:
raise RuntimeError("GPU model not initialized")
# Process text
phonemes, tokens = cls.process_text(text, language)
# Generate audio
audio = cls.generate_from_tokens(tokens, voicepack, speed)
return audio, phonemes
@classmethod
def generate_from_tokens(
cls, tokens: list[int], voicepack: torch.Tensor, speed: float
) -> np.ndarray:
"""Generate audio from tokens with moderate memory management
Args:
tokens: Token IDs
voicepack: Voice tensor
speed: Speed factor
Returns:
np.ndarray: Generated audio samples
"""
if cls._instance is None:
raise RuntimeError("GPU model not initialized")
try:
device = cls._device
# Check memory pressure
if torch.cuda.is_available():
memory_allocated = torch.cuda.memory_allocated(device) / 1e9 # Convert to GB
if memory_allocated > 2.0: # 2GB limit
logger.info(
f"Memory usage above 2GB threshold:{memory_allocated:.2f}GB "
f"Clearing cache"
)
torch.cuda.empty_cache()
import gc
gc.collect()
# Get reference style with proper device placement
ref_s = voicepack[len(tokens)].clone().to(device)
# Generate audio
audio = forward(cls._instance, tokens, ref_s, speed)
return audio
except RuntimeError as e:
if "out of memory" in str(e):
# On OOM, do a full cleanup and retry
if torch.cuda.is_available():
logger.warning("Out of memory detected, performing full cleanup")
torch.cuda.synchronize()
torch.cuda.empty_cache()
import gc
gc.collect()
# Log memory stats after cleanup
memory_allocated = torch.cuda.memory_allocated(device)
memory_reserved = torch.cuda.memory_reserved(device)
logger.info(
f"Memory after OOM cleanup: "
f"Allocated: {memory_allocated / 1e9:.2f}GB, "
f"Reserved: {memory_reserved / 1e9:.2f}GB"
)
# Retry generation
ref_s = voicepack[len(tokens)].clone().to(device)
audio = forward(cls._instance, tokens, ref_s, speed)
return audio
raise
finally:
# Only synchronize at the top level, no empty_cache
if torch.cuda.is_available():
torch.cuda.synchronize()

View file

@ -1,8 +0,0 @@
import torch
if torch.cuda.is_available():
from .tts_gpu import TTSGPUModel as TTSModel
else:
from .tts_cpu import TTSCPUModel as TTSModel
__all__ = ["TTSModel"]

View file

@ -1,120 +1,255 @@
import io
import os
import re
import time
from functools import lru_cache
from typing import List, Optional, Tuple
"""TTS service using model and voice managers."""
import time
from typing import List, Tuple, Optional, AsyncGenerator, Union
import asyncio
import aiofiles.os
import numpy as np
import scipy.io.wavfile as wavfile
import torch
from loguru import logger
from ..core.config import settings
from ..inference.model_manager import get_manager as get_model_manager
from ..inference.voice_manager import get_manager as get_voice_manager
from .audio import AudioNormalizer, AudioService
from .text_processing import chunker, normalize_text
from .tts_model import TTSModel
from .text_processing.text_processor import process_text_chunk, smart_split
from .text_processing import tokenize
class TTSService:
"""Text-to-speech service."""
# Limit concurrent chunk processing
_chunk_semaphore = asyncio.Semaphore(4)
def __init__(self, output_dir: str = None):
"""Initialize service."""
self.output_dir = output_dir
self.model = TTSModel.get_instance()
self.model_manager = None
self._voice_manager = None
@staticmethod
@lru_cache(maxsize=3) # Cache up to 3 most recently used voices
def _load_voice(voice_path: str) -> torch.Tensor:
"""Load and cache a voice model"""
return torch.load(
voice_path, map_location=TTSModel.get_device(), weights_only=True
)
@classmethod
async def create(cls, output_dir: str = None) -> 'TTSService':
"""Create and initialize TTSService instance."""
service = cls(output_dir)
service.model_manager = await get_model_manager()
service._voice_manager = await get_voice_manager()
return service
def _get_voice_path(self, voice_name: str) -> Optional[str]:
"""Get the path to a voice file"""
voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice_name}.pt")
return voice_path if os.path.exists(voice_path) else None
async def _process_chunk(
self,
tokens: List[int],
voice_tensor: torch.Tensor,
speed: float,
output_format: Optional[str] = None,
is_first: bool = False,
is_last: bool = False,
normalizer: Optional[AudioNormalizer] = None,
) -> Optional[Union[np.ndarray, bytes]]:
"""Process tokens into audio."""
async with self._chunk_semaphore:
try:
# Handle stream finalization
if is_last:
# Skip format conversion for raw audio mode
if not output_format:
return np.array([], dtype=np.float32)
return await AudioService.convert_audio(
np.array([0], dtype=np.float32), # Dummy data for type checking
24000,
output_format,
is_first_chunk=False,
normalizer=normalizer,
is_last_chunk=True
)
# Skip empty chunks
if not tokens:
return None
def _generate_audio(
self, text: str, voice: str, speed: float, stitch_long_output: bool = True
) -> Tuple[torch.Tensor, float]:
"""Generate complete audio and return with processing time"""
audio, processing_time = self._generate_audio_internal(
text, voice, speed, stitch_long_output
)
return audio, processing_time
# Generate audio using pre-warmed model
chunk_audio = await self.model_manager.generate(
tokens,
voice_tensor,
speed=speed
)
if chunk_audio is None:
logger.error("Model generated None for audio chunk")
return None
if len(chunk_audio) == 0:
logger.error("Model generated empty audio chunk")
return None
# For streaming, convert to bytes
if output_format:
try:
return await AudioService.convert_audio(
chunk_audio,
24000,
output_format,
is_first_chunk=is_first,
normalizer=normalizer,
is_last_chunk=is_last
)
except Exception as e:
logger.error(f"Failed to convert audio: {str(e)}")
return None
return chunk_audio
except Exception as e:
logger.error(f"Failed to process tokens: {str(e)}")
return None
def _generate_audio_internal(
self, text: str, voice: str, speed: float, stitch_long_output: bool = True
) -> Tuple[torch.Tensor, float]:
"""Generate audio and measure processing time"""
async def generate_audio_stream(
self,
text: str,
voice: str,
speed: float = 1.0,
output_format: str = "wav",
) -> AsyncGenerator[bytes, None]:
"""Generate and stream audio chunks."""
stream_normalizer = AudioNormalizer()
voice_tensor = None
chunk_index = 0
try:
# Get backend and load voice (should be fast if cached)
backend = self.model_manager.get_backend()
voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device)
# Process text in chunks with smart splitting
async for chunk_text, tokens in smart_split(text):
try:
# Process audio for chunk
result = await self._process_chunk(
tokens, # Now always a flat List[int]
voice_tensor,
speed,
output_format,
is_first=(chunk_index == 0),
is_last=False, # finalized separately by the empty-token call below
normalizer=stream_normalizer
)
if result is not None:
yield result
chunk_index += 1
else:
logger.warning(f"No audio generated for chunk: '{chunk_text[:100]}...'")
except Exception as e:
logger.error(f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}")
continue
# Only finalize if we successfully processed at least one chunk
if chunk_index > 0:
try:
# Empty tokens list to finalize audio
final_result = await self._process_chunk(
[], # Empty tokens list
voice_tensor,
speed,
output_format,
is_first=False,
is_last=True,
normalizer=stream_normalizer
)
if final_result is not None:
logger.debug("Yielding final chunk to finalize audio")
yield final_result
else:
logger.warning("Final chunk processing returned None")
except Exception as e:
logger.error(f"Failed to process final chunk: {str(e)}")
else:
logger.warning("No audio chunks were successfully processed")
except Exception as e:
logger.error(f"Error in audio generation stream: {str(e)}")
raise
finally:
if voice_tensor is not None:
del voice_tensor
torch.cuda.empty_cache()
async def generate_from_phonemes(
self, phonemes: str, voice: str, speed: float = 1.0
) -> Tuple[np.ndarray, float]:
"""Generate audio from phonemes.
Args:
phonemes: Phoneme string to synthesize
voice: Voice ID to use
speed: Speed multiplier
Returns:
Tuple of (audio array, processing time)
"""
start_time = time.time()
voice_tensor = None
try:
# Normalize text once at the start
if not text:
raise ValueError("Text is empty after preprocessing")
normalized = normalize_text(text)
if not normalized:
raise ValueError("Text is empty after preprocessing")
text = str(normalized)
# Get backend and load voice
backend = self.model_manager.get_backend()
voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device)
# Check voice exists
voice_path = self._get_voice_path(voice)
if not voice_path:
raise ValueError(f"Voice not found: {voice}")
# Convert phonemes to tokens
tokens = tokenize(phonemes)
if len(tokens) > 500: # Model context limit
raise ValueError(f"Phoneme sequence too long ({len(tokens)} tokens, max 500)")
tokens = [0] + tokens + [0] # Add start/end tokens
# Generate audio
audio = await self.model_manager.generate(
tokens,
voice_tensor,
speed=speed
)
if audio is None:
raise ValueError("Failed to generate audio")
# Load voice using cached loader
voicepack = self._load_voice(voice_path)
processing_time = time.time() - start_time
return audio, processing_time
# For non-streaming, preprocess all chunks first
if stitch_long_output:
# Preprocess all chunks to phonemes/tokens
chunks_data = []
for chunk in chunker.split_text(text):
try:
phonemes, tokens = TTSModel.process_text(chunk, voice[0])
chunks_data.append((chunk, tokens))
except Exception as e:
logger.error(
f"Failed to process chunk: '{chunk}'. Error: {str(e)}"
)
continue
except Exception as e:
logger.error(f"Error in phoneme audio generation: {str(e)}")
raise
finally:
if voice_tensor is not None:
del voice_tensor
torch.cuda.empty_cache()
if not chunks_data:
raise ValueError("No chunks were processed successfully")
async def generate_audio(
self, text: str, voice: str, speed: float = 1.0
) -> Tuple[np.ndarray, float]:
"""Generate complete audio for text using streaming internally."""
start_time = time.time()
chunks = []
try:
# Use streaming generator but collect all valid chunks
async for chunk in self.generate_audio_stream(
text, voice, speed, # Default to WAV for raw audio
):
if chunk is not None:
chunks.append(chunk)
# Generate audio for all chunks
audio_chunks = []
for chunk, tokens in chunks_data:
try:
chunk_audio = TTSModel.generate_from_tokens(
tokens, voicepack, speed
)
if chunk_audio is not None:
audio_chunks.append(chunk_audio)
else:
logger.error(f"No audio generated for chunk: '{chunk}'")
except Exception as e:
logger.error(
f"Failed to generate audio for chunk: '{chunk}'. Error: {str(e)}"
)
continue
if not chunks:
raise ValueError("No audio chunks were generated successfully")
if not audio_chunks:
raise ValueError("No audio chunks were generated successfully")
# Concatenate all chunks
audio = (
np.concatenate(audio_chunks)
if len(audio_chunks) > 1
else audio_chunks[0]
)
# Combine chunks, ensuring we have valid arrays
if len(chunks) == 1:
audio = chunks[0]
else:
# Process single chunk
phonemes, tokens = TTSModel.process_text(text, voice[0])
audio = TTSModel.generate_from_tokens(tokens, voicepack, speed)
# Filter out any zero-dimensional arrays
valid_chunks = [c for c in chunks if c.ndim > 0]
if not valid_chunks:
raise ValueError("No valid audio chunks to concatenate")
audio = np.concatenate(valid_chunks)
processing_time = time.time() - start_time
return audio, processing_time
@ -122,148 +257,10 @@ class TTSService:
logger.error(f"Error in audio generation: {str(e)}")
raise
async def generate_audio_stream(
self,
text: str,
voice: str,
speed: float,
output_format: str = "wav",
silent=False,
):
"""Generate and yield audio chunks as they're generated for real-time streaming"""
try:
stream_start = time.time()
# Create normalizer for consistent audio levels
stream_normalizer = AudioNormalizer()
# Input validation and preprocessing
if not text:
raise ValueError("Text is empty")
preprocess_start = time.time()
normalized = normalize_text(text)
if not normalized:
raise ValueError("Text is empty after preprocessing")
text = str(normalized)
logger.debug(
f"Text preprocessing took: {(time.time() - preprocess_start)*1000:.1f}ms"
)
# Voice validation and loading
voice_start = time.time()
voice_path = self._get_voice_path(voice)
if not voice_path:
raise ValueError(f"Voice not found: {voice}")
voicepack = self._load_voice(voice_path)
logger.debug(
f"Voice loading took: {(time.time() - voice_start)*1000:.1f}ms"
)
# Process chunks as they're generated
is_first = True
chunks_processed = 0
# Process chunks as they come from generator
chunk_gen = chunker.split_text(text)
current_chunk = next(chunk_gen, None)
while current_chunk is not None:
next_chunk = next(chunk_gen, None) # Peek at next chunk
chunks_processed += 1
try:
# Process text and generate audio
phonemes, tokens = TTSModel.process_text(current_chunk, voice[0])
chunk_audio = TTSModel.generate_from_tokens(
tokens, voicepack, speed
)
if chunk_audio is not None:
# Convert chunk with proper streaming header handling
chunk_bytes = AudioService.convert_audio(
chunk_audio,
24000,
output_format,
is_first_chunk=is_first,
normalizer=stream_normalizer,
is_last_chunk=(next_chunk is None), # Last if no next chunk
stream=True # Ensure proper streaming format handling
)
yield chunk_bytes
is_first = False
else:
logger.error(f"No audio generated for chunk: '{current_chunk}'")
except Exception as e:
logger.error(
f"Failed to generate audio for chunk: '{current_chunk}'. Error: {str(e)}"
)
current_chunk = next_chunk # Move to next chunk
except Exception as e:
logger.error(f"Error in audio generation stream: {str(e)}")
raise
def _save_audio(self, audio: torch.Tensor, filepath: str):
"""Save audio to file"""
os.makedirs(os.path.dirname(filepath), exist_ok=True)
wavfile.write(filepath, 24000, audio)
def _audio_to_bytes(self, audio: torch.Tensor) -> bytes:
"""Convert audio tensor to WAV bytes"""
buffer = io.BytesIO()
wavfile.write(buffer, 24000, audio)
return buffer.getvalue()
async def combine_voices(self, voices: List[str]) -> str:
"""Combine multiple voices into a new voice"""
if len(voices) < 2:
raise ValueError("At least 2 voices are required for combination")
# Load voices
t_voices: List[torch.Tensor] = []
v_name: List[str] = []
for voice in voices:
try:
voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice}.pt")
voicepack = torch.load(
voice_path, map_location=TTSModel.get_device(), weights_only=True
)
t_voices.append(voicepack)
v_name.append(voice)
except Exception as e:
raise ValueError(f"Failed to load voice {voice}: {str(e)}")
# Combine voices
try:
f: str = "_".join(v_name)
v = torch.mean(torch.stack(t_voices), dim=0)
combined_path = os.path.join(TTSModel.VOICES_DIR, f"{f}.pt")
# Save combined voice
try:
torch.save(v, combined_path)
except Exception as e:
raise RuntimeError(
f"Failed to save combined voice to {combined_path}: {str(e)}"
)
return f
except Exception as e:
if not isinstance(e, (ValueError, RuntimeError)):
raise RuntimeError(f"Error combining voices: {str(e)}")
raise
"""Combine multiple voices."""
return await self._voice_manager.combine_voices(voices)
async def list_voices(self) -> List[str]:
"""List all available voices"""
voices = []
try:
it = await aiofiles.os.scandir(TTSModel.VOICES_DIR)
for entry in it:
if entry.name.endswith(".pt"):
voices.append(entry.name[:-3]) # Remove .pt extension
except Exception as e:
logger.error(f"Error listing voices: {str(e)}")
return sorted(voices)
"""List available voices."""
return await self._voice_manager.list_voices()
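A minimal calling sketch for the manager-based service above, assuming an async entry point and an installed voice such as af_bella:
async def demo():
    service = await TTSService.create()
    # one-shot generation: returns (audio array, processing time in seconds)
    audio, seconds = await service.generate_audio("Hello world", voice="af_bella")
    # or stream encoded bytes as they are produced
    async for chunk in service.generate_audio_stream(
        "Hello world", voice="af_bella", output_format="mp3"
    ):
        ...  # write each chunk to a response body or a TempFileWriter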

View file

@ -1,60 +0,0 @@
import os
from typing import List, Tuple
import torch
from loguru import logger
from ..core.config import settings
from .tts_model import TTSModel
from .tts_service import TTSService
class WarmupService:
"""Service for warming up TTS models and voice caches"""
def __init__(self):
"""Initialize warmup service and ensure model is ready"""
# Initialize model if not already initialized
if TTSModel._instance is None:
TTSModel.initialize(settings.model_dir)
self.tts_service = TTSService()
def load_voices(self) -> List[Tuple[str, torch.Tensor]]:
"""Load and cache voices up to LRU limit"""
# Get all voices sorted by filename length (shorter names first, usually base voices)
voice_files = sorted(
[f for f in os.listdir(TTSModel.VOICES_DIR) if f.endswith(".pt")], key=len
)
n_voices_cache = 1
loaded_voices = []
for voice_file in voice_files[:n_voices_cache]:
try:
voice_path = os.path.join(TTSModel.VOICES_DIR, voice_file)
# load using service, lru cache
voicepack = self.tts_service._load_voice(voice_path)
loaded_voices.append(
(voice_file[:-3], voicepack)
) # Store name and tensor
# voicepack = torch.load(voice_path, map_location=TTSModel.get_device(), weights_only=True)
# logger.info(f"Loaded voice {voice_file[:-3]} into cache")
except Exception as e:
logger.error(f"Failed to load voice {voice_file}: {e}")
logger.info(f"Pre-loaded {len(loaded_voices)} voices into cache")
return loaded_voices
async def warmup_voices(
self, warmup_text: str, loaded_voices: List[Tuple[str, torch.Tensor]]
):
"""Warm up voice inference and streaming"""
n_warmups = 1
for voice_name, _ in loaded_voices[:n_warmups]:
try:
logger.info(f"Running warmup inference on voice {voice_name}")
async for _ in self.tts_service.generate_audio_stream(
warmup_text, voice_name, 1.0, "pcm"
):
pass # Process all chunks to properly warm up
logger.info(f"Completed warmup for voice {voice_name}")
except Exception as e:
logger.warning(f"Warmup failed for voice {voice_name}: {e}")

View file

@ -0,0 +1,13 @@
"""Voice configuration schemas."""
from pydantic import BaseModel, Field
class VoiceConfig(BaseModel):
"""Voice configuration."""
use_cache: bool = Field(True, description="Whether to cache loaded voices")
cache_size: int = Field(3, description="Number of voices to cache")
validate_on_load: bool = Field(True, description="Whether to validate voices when loading")
class Config:
frozen = True # Make config immutable
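A small instantiation sketch; the values shown simply make the declared defaults explicit:
config = VoiceConfig(use_cache=True, cache_size=3, validate_on_load=True)
# the model is frozen, so reassigning fields after creation is rejected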

View file

@ -23,7 +23,10 @@ class TTSStatus(str, Enum):
# OpenAI-compatible schemas
class OpenAISpeechRequest(BaseModel):
model: Literal["tts-1", "tts-1-hd", "kokoro"] = "kokoro"
model: str = Field(
default="kokoro",
description="The model to use for generation. Supported models: tts-1, tts-1-hd, kokoro"
)
input: str = Field(..., description="The text to generate audio for")
voice: str = Field(
default="af",
@ -43,3 +46,7 @@ class OpenAISpeechRequest(BaseModel):
default=True, # Default to streaming for OpenAI compatibility
description="If true (default), audio will be streamed as it's generated. Each chunk will be a complete sentence.",
)
return_download_link: bool = Field(
default=False,
description="If true, returns a download link in X-Download-Path header after streaming completes",
)
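A hedged client-side sketch of the download-link flow, assuming the API is reachable at http://localhost:8880 (host and port are deployment-specific) and that the requests package is available:
import requests

resp = requests.post(
    "http://localhost:8880/v1/audio/speech",  # base URL is an assumption
    json={
        "model": "kokoro",
        "input": "Hello world",
        "voice": "af",
        "response_format": "mp3",
        "stream": True,
        "return_download_link": True,
    },
    stream=True,
)
with open("speech.mp3", "wb") as out:
    for chunk in resp.iter_content(chunk_size=8192):
        out.write(chunk)
# temp-file path for later retrieval, e.g. "/download/<tempfile>.mp3"
download_path = resp.headers.get("X-Download-Path")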

View file

@ -1,4 +1,5 @@
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from typing import List, Union, Optional
class PhonemeRequest(BaseModel):
@ -11,9 +12,27 @@ class PhonemeResponse(BaseModel):
tokens: list[int]
class GenerateFromPhonemesRequest(BaseModel):
phonemes: str
voice: str = Field(..., description="Voice ID to use for generation")
speed: float = Field(
default=1.0, ge=0.1, le=5.0, description="Speed factor for generation"
class StitchOptions(BaseModel):
"""Options for stitching audio chunks together"""
gap_method: str = Field(
default="static_trim",
description="Method to handle gaps between chunks. Currently only 'static_trim' supported."
)
trim_ms: int = Field(
default=0,
ge=0,
description="Milliseconds to trim from chunk boundaries when using static_trim"
)
@field_validator('gap_method')
@classmethod
def validate_gap_method(cls, v: str) -> str:
if v != 'static_trim':
raise ValueError("Currently only 'static_trim' gap method is supported")
return v
class GenerateFromPhonemesRequest(BaseModel):
"""Simple request for phoneme-to-speech generation"""
phonemes: str = Field(..., description="Phoneme string to synthesize")
voice: str = Field(..., description="Voice ID to use for generation")

BIN
api/src/voices/am_gurney.pt Normal file

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View file

@ -1,126 +1,69 @@
import os
import shutil
import sys
from unittest.mock import MagicMock, Mock, patch
import aiofiles.threadpool
import numpy as np
import pytest
def cleanup_mock_dirs():
"""Clean up any MagicMock directories created during tests"""
mock_dir = "MagicMock"
if os.path.exists(mock_dir):
shutil.rmtree(mock_dir)
@pytest.fixture(autouse=True)
def setup_aiofiles():
"""Setup aiofiles mock wrapper"""
aiofiles.threadpool.wrap.register(MagicMock)(
lambda *args, **kwargs: aiofiles.threadpool.AsyncBufferedIOBase(*args, **kwargs)
)
yield
@pytest.fixture(autouse=True)
def cleanup():
"""Automatically clean up before and after each test"""
cleanup_mock_dirs()
yield
cleanup_mock_dirs()
# Mock modules before they're imported
sys.modules["transformers"] = Mock()
sys.modules["phonemizer"] = Mock()
sys.modules["models"] = Mock()
sys.modules["models.build_model"] = Mock()
sys.modules["kokoro"] = Mock()
sys.modules["kokoro.generate"] = Mock()
sys.modules["kokoro.phonemize"] = Mock()
sys.modules["kokoro.tokenize"] = Mock()
# Mock ONNX runtime
mock_onnx = Mock()
mock_onnx.InferenceSession = Mock()
mock_onnx.SessionOptions = Mock()
mock_onnx.GraphOptimizationLevel = Mock()
mock_onnx.ExecutionMode = Mock()
sys.modules["onnxruntime"] = mock_onnx
# Create mock settings module
mock_settings_module = Mock()
mock_settings = Mock()
mock_settings.model_dir = "/mock/model/dir"
mock_settings.onnx_model_path = "mock.onnx"
mock_settings_module.settings = mock_settings
sys.modules["api.src.core.config"] = mock_settings_module
class MockTTSModel:
_instance = None
_onnx_session = None
VOICES_DIR = "/mock/voices/dir"
def __init__(self):
self._initialized = False
@classmethod
def get_instance(cls):
if cls._instance is None:
cls._instance = cls()
return cls._instance
@classmethod
def initialize(cls, model_dir):
cls._onnx_session = Mock()
cls._onnx_session.run = Mock(return_value=[np.zeros(48000)])
cls._instance._initialized = True
return cls._onnx_session
@classmethod
def setup(cls):
if not cls._instance._initialized:
cls.initialize("/mock/model/dir")
return cls._instance
@classmethod
def generate_from_tokens(cls, tokens, voicepack, speed):
if not cls._instance._initialized:
raise RuntimeError("Model not initialized. Call setup() first.")
return np.zeros(48000)
@classmethod
def process_text(cls, text, language):
return "mock phonemes", [1, 2, 3]
@staticmethod
def get_device():
return "cpu"
import pytest_asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import torch
from pathlib import Path
import os
from api.src.services.tts_service import TTSService
from api.src.inference.voice_manager import VoiceManager
from api.src.inference.model_manager import ModelManager
from api.src.structures.model_schemas import VoiceConfig
@pytest.fixture
def mock_tts_service(monkeypatch):
"""Mock TTSService for testing"""
mock_service = Mock()
mock_service._get_voice_path.return_value = "/mock/path/voice.pt"
mock_service._load_voice.return_value = np.zeros((1, 192))
# Mock TTSModel.generate_from_tokens since we call it directly
mock_generate = Mock(return_value=np.zeros(48000))
monkeypatch.setattr(
"api.src.routers.development.TTSModel.generate_from_tokens", mock_generate
)
return mock_service
def mock_voice_tensor():
"""Load a real voice tensor for testing."""
voice_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'src/voices/af_bella.pt')
return torch.load(voice_path, map_location='cpu', weights_only=False)
@pytest.fixture
def mock_audio_service(monkeypatch):
"""Mock AudioService"""
mock_service = Mock()
mock_service.convert_audio.return_value = b"mock audio data"
monkeypatch.setattr("api.src.routers.development.AudioService", mock_service)
return mock_service
def mock_audio_output():
"""Load pre-generated test audio for consistent testing."""
test_audio_path = os.path.join(os.path.dirname(__file__), 'test_data/test_audio.npy')
return np.load(test_audio_path) # Return as numpy array instead of bytes
@pytest_asyncio.fixture
async def mock_model_manager(mock_audio_output):
"""Mock model manager for testing."""
manager = AsyncMock(spec=ModelManager)
manager.get_backend = MagicMock()
async def mock_generate(*args, **kwargs):
# Simulate successful audio generation
return np.random.rand(24000).astype(np.float32) # 1 second of random audio data
manager.generate = AsyncMock(side_effect=mock_generate)
return manager
@pytest_asyncio.fixture
async def mock_voice_manager(mock_voice_tensor):
"""Mock voice manager for testing."""
manager = AsyncMock(spec=VoiceManager)
manager.get_voice_path = MagicMock(return_value="/mock/path/voice.pt")
manager.load_voice = AsyncMock(return_value=mock_voice_tensor)
manager.list_voices = AsyncMock(return_value=["voice1", "voice2"])
manager.combine_voices = AsyncMock(return_value="voice1_voice2")
return manager
@pytest_asyncio.fixture
async def tts_service(mock_model_manager, mock_voice_manager):
"""Get mocked TTS service instance."""
service = TTSService()
service.model_manager = mock_model_manager
service._voice_manager = mock_voice_manager
return service
@pytest.fixture
def test_voice():
"""Return a test voice name."""
return "voice1"
@pytest.fixture(scope="session")
def event_loop():
"""Create an instance of the default event loop for the test session."""
import asyncio
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
yield loop
loop.close()

View file

@ -26,106 +26,161 @@ def sample_audio():
return np.sin(2 * np.pi * frequency * t).astype(np.float32), sample_rate
def test_convert_to_wav(sample_audio):
@pytest.mark.asyncio
async def test_convert_to_wav(sample_audio):
"""Test converting to WAV format"""
audio_data, sample_rate = sample_audio
result = AudioService.convert_audio(audio_data, sample_rate, "wav")
# Write and finalize in one step for WAV
result = await AudioService.convert_audio(
audio_data,
sample_rate,
"wav",
is_first_chunk=True,
is_last_chunk=True
)
assert isinstance(result, bytes)
assert len(result) > 0
# Check WAV header
assert result.startswith(b'RIFF')
assert b'WAVE' in result[:12]
def test_convert_to_mp3(sample_audio):
@pytest.mark.asyncio
async def test_convert_to_mp3(sample_audio):
"""Test converting to MP3 format"""
audio_data, sample_rate = sample_audio
result = AudioService.convert_audio(audio_data, sample_rate, "mp3")
result = await AudioService.convert_audio(audio_data, sample_rate, "mp3")
assert isinstance(result, bytes)
assert len(result) > 0
# Check MP3 header (ID3 or MPEG frame sync)
assert result.startswith(b'ID3') or result.startswith(b'\xff\xfb')
def test_convert_to_opus(sample_audio):
@pytest.mark.asyncio
async def test_convert_to_opus(sample_audio):
"""Test converting to Opus format"""
audio_data, sample_rate = sample_audio
result = AudioService.convert_audio(audio_data, sample_rate, "opus")
result = await AudioService.convert_audio(audio_data, sample_rate, "opus")
assert isinstance(result, bytes)
assert len(result) > 0
# Check OGG header
assert result.startswith(b'OggS')
def test_convert_to_flac(sample_audio):
@pytest.mark.asyncio
async def test_convert_to_flac(sample_audio):
"""Test converting to FLAC format"""
audio_data, sample_rate = sample_audio
result = AudioService.convert_audio(audio_data, sample_rate, "flac")
result = await AudioService.convert_audio(audio_data, sample_rate, "flac")
assert isinstance(result, bytes)
assert len(result) > 0
# Check FLAC header
assert result.startswith(b'fLaC')
def test_convert_to_aac(sample_audio):
@pytest.mark.asyncio
async def test_convert_to_aac(sample_audio):
"""Test converting to AAC format"""
audio_data, sample_rate = sample_audio
result = AudioService.convert_audio(audio_data, sample_rate, "aac")
result = await AudioService.convert_audio(audio_data, sample_rate, "aac")
assert isinstance(result, bytes)
assert len(result) > 0
# AAC files typically start with an ADTS header
assert result.startswith(b'\xff\xf1') or result.startswith(b'\xff\xf9')
# Check ADTS header (AAC)
assert result.startswith(b'\xff\xf0') or result.startswith(b'\xff\xf1')
def test_convert_to_pcm(sample_audio):
@pytest.mark.asyncio
async def test_convert_to_pcm(sample_audio):
"""Test converting to PCM format"""
audio_data, sample_rate = sample_audio
result = AudioService.convert_audio(audio_data, sample_rate, "pcm")
result = await AudioService.convert_audio(audio_data, sample_rate, "pcm")
assert isinstance(result, bytes)
assert len(result) > 0
# PCM is raw bytes, so no header to check
def test_convert_to_invalid_format_raises_error(sample_audio):
@pytest.mark.asyncio
async def test_convert_to_invalid_format_raises_error(sample_audio):
"""Test that converting to an invalid format raises an error"""
audio_data, sample_rate = sample_audio
with pytest.raises(ValueError, match="Format invalid not supported"):
AudioService.convert_audio(audio_data, sample_rate, "invalid")
await AudioService.convert_audio(audio_data, sample_rate, "invalid")
def test_normalization_wav(sample_audio):
@pytest.mark.asyncio
async def test_normalization_wav(sample_audio):
"""Test that WAV output is properly normalized to int16 range"""
audio_data, sample_rate = sample_audio
# Create audio data outside int16 range
large_audio = audio_data * 1e5
result = AudioService.convert_audio(large_audio, sample_rate, "wav")
# Write and finalize in one step for WAV
result = await AudioService.convert_audio(
large_audio,
sample_rate,
"wav",
is_first_chunk=True,
is_last_chunk=True
)
assert isinstance(result, bytes)
assert len(result) > 0
def test_normalization_pcm(sample_audio):
@pytest.mark.asyncio
async def test_normalization_pcm(sample_audio):
"""Test that PCM output is properly normalized to int16 range"""
audio_data, sample_rate = sample_audio
# Create audio data outside int16 range
large_audio = audio_data * 1e5
result = AudioService.convert_audio(large_audio, sample_rate, "pcm")
result = await AudioService.convert_audio(large_audio, sample_rate, "pcm")
assert isinstance(result, bytes)
assert len(result) > 0
def test_invalid_audio_data():
@pytest.mark.asyncio
async def test_invalid_audio_data():
"""Test handling of invalid audio data"""
invalid_audio = np.array([]) # Empty array
sample_rate = 24000
with pytest.raises(ValueError):
AudioService.convert_audio(invalid_audio, sample_rate, "wav")
await AudioService.convert_audio(invalid_audio, sample_rate, "wav")
def test_different_sample_rates(sample_audio):
@pytest.mark.asyncio
async def test_different_sample_rates(sample_audio):
"""Test converting audio with different sample rates"""
audio_data, _ = sample_audio
sample_rates = [8000, 16000, 44100, 48000]
for rate in sample_rates:
result = AudioService.convert_audio(audio_data, rate, "wav")
result = await AudioService.convert_audio(
audio_data,
rate,
"wav",
is_first_chunk=True,
is_last_chunk=True
)
assert isinstance(result, bytes)
assert len(result) > 0
def test_buffer_position_after_conversion(sample_audio):
@pytest.mark.asyncio
async def test_buffer_position_after_conversion(sample_audio):
"""Test that buffer position is reset after writing"""
audio_data, sample_rate = sample_audio
result = AudioService.convert_audio(audio_data, sample_rate, "wav")
# Write and finalize in one step for first conversion
result = await AudioService.convert_audio(
audio_data,
sample_rate,
"wav",
is_first_chunk=True,
is_last_chunk=True
)
# Convert again to ensure buffer was properly reset
result2 = AudioService.convert_audio(audio_data, sample_rate, "wav")
result2 = await AudioService.convert_audio(
audio_data,
sample_rate,
"wav",
is_first_chunk=True,
is_last_chunk=True
)
assert len(result) == len(result2)

View file

@ -1,46 +0,0 @@
"""Tests for text chunking service"""
from unittest.mock import patch
import pytest
from api.src.services.text_processing import chunker
@pytest.fixture(autouse=True)
def mock_settings():
"""Mock settings for all tests"""
with patch("api.src.services.text_processing.chunker.settings") as mock_settings:
mock_settings.max_chunk_size = 300
yield mock_settings
def test_split_text():
"""Test text splitting into sentences"""
text = "First sentence. Second sentence! Third sentence?"
sentences = list(chunker.split_text(text))
assert len(sentences) == 3
assert sentences[0] == "First sentence."
assert sentences[1] == "Second sentence!"
assert sentences[2] == "Third sentence?"
def test_split_text_empty():
"""Test splitting empty text"""
assert list(chunker.split_text("")) == []
def test_split_text_single_sentence():
"""Test splitting single sentence"""
text = "Just one sentence."
assert list(chunker.split_text(text)) == ["Just one sentence."]
def test_split_text_with_custom_chunk_size():
"""Test splitting with custom max chunk size"""
text = "First part, second part, third part."
chunks = list(chunker.split_text(text, max_chunk=15))
assert len(chunks) == 3
assert chunks[0] == "First part,"
assert chunks[1] == "second part,"
assert chunks[2] == "third part."

View file

@ -0,0 +1,20 @@
import numpy as np
import os
def generate_test_audio():
"""Generate test audio data - 1 second of 440Hz tone"""
# Create 1 second of silence at 24kHz
audio = np.zeros(24000, dtype=np.float32)
# Add a simple sine wave to make it non-zero
t = np.linspace(0, 1, 24000)
audio += 0.5 * np.sin(2 * np.pi * 440 * t) # 440 Hz tone at half amplitude
# Create test_data directory if it doesn't exist
os.makedirs('api/tests/test_data', exist_ok=True)
# Save the test audio
np.save('api/tests/test_data/test_audio.npy', audio)
if __name__ == '__main__':
generate_test_audio()

Binary file not shown.

View file

@ -1,402 +0,0 @@
import asyncio
from unittest.mock import AsyncMock, Mock
import pytest
import pytest_asyncio
from fastapi.testclient import TestClient
from httpx import AsyncClient
from ..src.main import app
# Create test client
client = TestClient(app)
# Create async client fixture
@pytest_asyncio.fixture
async def async_client():
async with AsyncClient(app=app, base_url="http://test") as ac:
yield ac
# Mock services
@pytest.fixture
def mock_tts_service(monkeypatch):
mock_service = Mock()
mock_service._generate_audio.return_value = (bytes([0, 1, 2, 3]), 1.0)
# Create proper async generator mock
async def mock_stream(*args, **kwargs):
for chunk in [b"chunk1", b"chunk2"]:
yield chunk
mock_service.generate_audio_stream = mock_stream
# Create async mocks
mock_service.list_voices = AsyncMock(
return_value=[
"af",
"bm_lewis",
"bf_isabella",
"bf_emma",
"af_sarah",
"af_bella",
"am_adam",
"am_michael",
"bm_george",
]
)
mock_service.combine_voices = AsyncMock()
monkeypatch.setattr(
"api.src.routers.openai_compatible.TTSService",
lambda *args, **kwargs: mock_service,
)
return mock_service
@pytest.fixture
def mock_audio_service(monkeypatch):
mock_service = Mock()
mock_service.convert_audio.return_value = b"converted mock audio data"
monkeypatch.setattr("api.src.routers.openai_compatible.AudioService", mock_service)
return mock_service
def test_health_check():
"""Test the health check endpoint"""
response = client.get("/health")
assert response.status_code == 200
assert response.json() == {"status": "healthy"}
@pytest.mark.asyncio
async def test_openai_speech_endpoint(
mock_tts_service, mock_audio_service, async_client
):
"""Test the OpenAI-compatible speech endpoint"""
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": "bm_lewis",
"response_format": "wav",
"speed": 1.0,
"stream": False, # Explicitly disable streaming
}
response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 200
assert response.headers["content-type"] == "audio/wav"
assert response.headers["content-disposition"] == "attachment; filename=speech.wav"
mock_tts_service._generate_audio.assert_called_once_with(
text="Hello world", voice="bm_lewis", speed=1.0, stitch_long_output=True
)
assert response.content == b"converted mock audio data"
@pytest.mark.asyncio
async def test_openai_speech_invalid_voice(mock_tts_service, async_client):
"""Test the OpenAI-compatible speech endpoint with invalid voice"""
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": "invalid_voice",
"response_format": "wav",
"speed": 1.0,
"stream": False, # Explicitly disable streaming
}
response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 400 # Bad request
assert "not found" in response.json()["detail"]["message"]
@pytest.mark.asyncio
async def test_openai_speech_invalid_speed(mock_tts_service, async_client):
"""Test the OpenAI-compatible speech endpoint with invalid speed"""
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": "af",
"response_format": "wav",
"speed": -1.0, # Invalid speed
"stream": False, # Explicitly disable streaming
}
response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 422 # Validation error
@pytest.mark.asyncio
async def test_openai_speech_generation_error(mock_tts_service, async_client):
"""Test error handling in speech generation"""
mock_tts_service._generate_audio.side_effect = Exception("Generation failed")
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": "af",
"response_format": "wav",
"speed": 1.0,
"stream": False, # Explicitly disable streaming
}
response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 500
assert "Generation failed" in response.json()["detail"]["message"]
@pytest.mark.asyncio
async def test_combine_voices_list_success(mock_tts_service, async_client):
"""Test successful voice combination using list format"""
test_voices = ["af_bella", "af_sarah"]
mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")
response = await async_client.post("/v1/audio/voices/combine", json=test_voices)
assert response.status_code == 200
assert response.json()["voice"] == "af_bella_af_sarah"
mock_tts_service.combine_voices.assert_called_once_with(voices=test_voices)
@pytest.mark.asyncio
async def test_combine_voices_string_success(mock_tts_service, async_client):
"""Test successful voice combination using string format with +"""
test_voices = "af_bella+af_sarah"
mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")
response = await async_client.post("/v1/audio/voices/combine", json=test_voices)
assert response.status_code == 200
assert response.json()["voice"] == "af_bella_af_sarah"
mock_tts_service.combine_voices.assert_called_once_with(
voices=["af_bella", "af_sarah"]
)
@pytest.mark.asyncio
async def test_combine_voices_single_voice(mock_tts_service, async_client):
"""Test combining single voice returns same voice"""
test_voices = ["af_bella"]
response = await async_client.post("/v1/audio/voices/combine", json=test_voices)
assert response.status_code == 200
assert response.json()["voice"] == "af_bella"
@pytest.mark.asyncio
async def test_combine_voices_empty_list(mock_tts_service, async_client):
"""Test combining empty voice list returns error"""
test_voices = []
response = await async_client.post("/v1/audio/voices/combine", json=test_voices)
assert response.status_code == 400
assert "No voices provided" in response.json()["detail"]["message"]
@pytest.mark.asyncio
async def test_combine_voices_error(mock_tts_service, async_client):
"""Test error handling in voice combination"""
test_voices = ["af_bella", "af_sarah"]
mock_tts_service.combine_voices = AsyncMock(
side_effect=Exception("Combination failed")
)
response = await async_client.post("/v1/audio/voices/combine", json=test_voices)
assert response.status_code == 500
assert "Server error" in response.json()["detail"]["message"]
@pytest.mark.asyncio
async def test_speech_with_combined_voice(
mock_tts_service, mock_audio_service, async_client
):
"""Test speech generation with combined voice using + syntax"""
mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": "af_bella+af_sarah",
"response_format": "wav",
"speed": 1.0,
"stream": False,
}
response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 200
assert response.headers["content-type"] == "audio/wav"
mock_tts_service._generate_audio.assert_called_once_with(
text="Hello world",
voice="af_bella_af_sarah",
speed=1.0,
stitch_long_output=True,
)
@pytest.mark.asyncio
async def test_speech_with_whitespace_in_voice(
mock_tts_service, mock_audio_service, async_client
):
"""Test speech generation with whitespace in voice combination"""
mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": " af_bella + af_sarah ",
"response_format": "wav",
"speed": 1.0,
"stream": False,
}
response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 200
assert response.headers["content-type"] == "audio/wav"
mock_tts_service.combine_voices.assert_called_once_with(
voices=["af_bella", "af_sarah"]
)
@pytest.mark.asyncio
async def test_speech_with_empty_voice_combination(mock_tts_service, async_client):
"""Test speech generation with empty voice combination"""
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": "+",
"response_format": "wav",
"speed": 1.0,
"stream": False,
}
response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 400
assert "No voices provided" in response.json()["detail"]["message"]
@pytest.mark.asyncio
async def test_speech_with_invalid_combined_voice(mock_tts_service, async_client):
"""Test speech generation with invalid voice combination"""
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": "invalid+combination",
"response_format": "wav",
"speed": 1.0,
"stream": False,
}
response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 400
assert "not found" in response.json()["detail"]["message"]
@pytest.mark.asyncio
async def test_speech_streaming_with_combined_voice(mock_tts_service, async_client):
"""Test streaming speech with combined voice using + syntax"""
mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": "af_bella+af_sarah",
"response_format": "mp3",
"stream": True,
}
# Create streaming mock
async def mock_stream(*args, **kwargs):
for chunk in [b"mp3header", b"mp3data"]:
yield chunk
mock_tts_service.generate_audio_stream = mock_stream
# Add streaming header
headers = {"x-raw-response": "stream"}
response = await async_client.post(
"/v1/audio/speech", json=test_request, headers=headers
)
assert response.status_code == 200
assert response.headers["content-type"] == "audio/mpeg"
assert response.headers["content-disposition"] == "attachment; filename=speech.mp3"
@pytest.mark.asyncio
async def test_openai_speech_pcm_streaming(mock_tts_service, async_client):
"""Test streaming PCM audio for real-time playback"""
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": "af",
"response_format": "pcm",
"stream": True,
}
# Create streaming mock for this test
async def mock_stream(*args, **kwargs):
for chunk in [b"chunk1", b"chunk2"]:
yield chunk
mock_tts_service.generate_audio_stream = mock_stream
# Add streaming header
headers = {"x-raw-response": "stream"}
response = await async_client.post(
"/v1/audio/speech", json=test_request, headers=headers
)
assert response.status_code == 200
assert response.headers["content-type"] == "audio/pcm"
@pytest.mark.asyncio
async def test_openai_speech_streaming_mp3(mock_tts_service, async_client):
"""Test streaming MP3 audio to file"""
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": "af",
"response_format": "mp3",
"stream": True,
}
# Create streaming mock for this test
async def mock_stream(*args, **kwargs):
for chunk in [b"mp3header", b"mp3data"]:
yield chunk
mock_tts_service.generate_audio_stream = mock_stream
# Add streaming header
headers = {"x-raw-response": "stream"}
response = await async_client.post(
"/v1/audio/speech", json=test_request, headers=headers
)
assert response.status_code == 200
assert response.headers["content-type"] == "audio/mpeg"
assert response.headers["content-disposition"] == "attachment; filename=speech.mp3"
@pytest.mark.asyncio
async def test_openai_speech_streaming_generator(mock_tts_service, async_client):
"""Test streaming with async generator"""
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": "af",
"response_format": "pcm",
"stream": True,
}
# Create streaming mock for this test
async def mock_stream(*args, **kwargs):
for chunk in [b"chunk1", b"chunk2"]:
yield chunk
mock_tts_service.generate_audio_stream = mock_stream
# Add streaming header
headers = {"x-raw-response": "stream"}
response = await async_client.post(
"/v1/audio/speech", json=test_request, headers=headers
)
assert response.status_code == 200
assert response.headers["content-type"] == "audio/pcm"

View file

@ -1,108 +0,0 @@
"""Tests for FastAPI application"""
from unittest.mock import MagicMock, call, patch
import pytest
from fastapi.testclient import TestClient
from api.src.main import app, lifespan
@pytest.fixture
def test_client():
"""Create a test client"""
return TestClient(app)
def test_health_check(test_client):
"""Test health check endpoint"""
response = test_client.get("/health")
assert response.status_code == 200
assert response.json() == {"status": "healthy"}
@pytest.mark.asyncio
@patch("api.src.main.TTSModel")
@patch("api.src.main.logger")
async def test_lifespan_successful_warmup(mock_logger, mock_tts_model):
"""Test successful model warmup in lifespan"""
# Mock file system for voice counting
mock_tts_model.VOICES_DIR = "/mock/voices"
# Create async mock
async def async_setup():
return 3
mock_tts_model.setup = MagicMock()
mock_tts_model.setup.side_effect = async_setup
mock_tts_model.get_device.return_value = "cuda"
with patch("os.listdir", return_value=["voice1.pt", "voice2.pt", "voice3.pt"]):
# Create an async generator from the lifespan context manager
async_gen = lifespan(MagicMock())
# Start the context manager
await async_gen.__aenter__()
# Verify the expected logging sequence
mock_logger.info.assert_any_call("Loading TTS model and voice packs...")
# Check for the startup message containing the required info
startup_calls = [call[0][0] for call in mock_logger.info.call_args_list]
startup_msg = next(msg for msg in startup_calls if "Model warmed up on" in msg)
assert "Model warmed up on" in startup_msg
assert "3 voice packs loaded" in startup_msg
# Verify model setup was called
mock_tts_model.setup.assert_called_once()
# Clean up
await async_gen.__aexit__(None, None, None)
@pytest.mark.asyncio
@patch("api.src.main.TTSModel")
@patch("api.src.main.logger")
async def test_lifespan_failed_warmup(mock_logger, mock_tts_model):
"""Test failed model warmup in lifespan"""
# Mock the model setup to fail
mock_tts_model.setup.side_effect = RuntimeError("Failed to initialize model")
# Create an async generator from the lifespan context manager
async_gen = lifespan(MagicMock())
# Verify the exception is raised
with pytest.raises(RuntimeError, match="Failed to initialize model"):
await async_gen.__aenter__()
# Verify the expected logging sequence
mock_logger.info.assert_called_with("Loading TTS model and voice packs...")
# Clean up
await async_gen.__aexit__(None, None, None)
@pytest.mark.asyncio
@patch("api.src.main.TTSModel")
async def test_lifespan_cuda_warmup(mock_tts_model):
"""Test model warmup specifically on CUDA"""
# Mock file system for voice counting
mock_tts_model.VOICES_DIR = "/mock/voices"
# Create async mock
async def async_setup():
return 2
mock_tts_model.setup = MagicMock()
mock_tts_model.setup.side_effect = async_setup
mock_tts_model.get_device.return_value = "cuda"
with patch("os.listdir", return_value=["voice1.pt", "voice2.pt"]):
# Create an async generator from the lifespan context manager
async_gen = lifespan(MagicMock())
await async_gen.__aenter__()
# Verify model setup was called
mock_tts_model.setup.assert_called_once()
# Clean up
await async_gen.__aexit__(None, None, None)

View file

@ -0,0 +1,412 @@
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from fastapi.testclient import TestClient
import numpy as np
import asyncio
from typing import AsyncGenerator
import os
import json
from api.src.main import app
from api.src.services.tts_service import TTSService
from api.src.core.config import settings
from api.src.routers.openai_compatible import (
load_openai_mappings,
get_tts_service,
stream_audio_chunks
)
from api.src.structures.schemas import OpenAISpeechRequest
client = TestClient(app)
@pytest.fixture
def test_voice():
"""Fixture providing a test voice name."""
return "test_voice"
@pytest.fixture
def mock_openai_mappings():
"""Mock OpenAI mappings for testing."""
with patch("api.src.routers.openai_compatible._openai_mappings", {
"models": {
"tts-1": "kokoro-v0_19",
"tts-1-hd": "kokoro-v0_19"
},
"voices": {
"alloy": "am_adam",
"nova": "bf_isabella"
}
}):
yield
@pytest.fixture
def mock_json_file(tmp_path):
"""Create a temporary mock JSON file."""
content = {
"models": {"test-model": "test-kokoro"},
"voices": {"test-voice": "test-internal"}
}
json_file = tmp_path / "test_mappings.json"
json_file.write_text(json.dumps(content))
return json_file
def test_load_openai_mappings(mock_json_file):
"""Test loading OpenAI mappings from JSON file"""
with patch("os.path.join", return_value=str(mock_json_file)):
mappings = load_openai_mappings()
assert "models" in mappings
assert "voices" in mappings
assert mappings["models"]["test-model"] == "test-kokoro"
assert mappings["voices"]["test-voice"] == "test-internal"
def test_load_openai_mappings_file_not_found():
"""Test handling of missing mappings file"""
with patch("os.path.join", return_value="/nonexistent/path"):
mappings = load_openai_mappings()
assert mappings == {"models": {}, "voices": {}}
@pytest.mark.asyncio
async def test_get_tts_service_initialization():
"""Test TTSService initialization"""
with patch("api.src.routers.openai_compatible._tts_service", None):
with patch("api.src.routers.openai_compatible._init_lock", None):
with patch("api.src.services.tts_service.TTSService.create") as mock_create:
mock_service = AsyncMock()
mock_create.return_value = mock_service
# Test concurrent access
async def get_service():
return await get_tts_service()
# Create multiple concurrent requests
tasks = [get_service() for _ in range(5)]
results = await asyncio.gather(*tasks)
# Verify service was created only once
mock_create.assert_called_once()
assert all(r == mock_service for r in results)
@pytest.mark.asyncio
async def test_stream_audio_chunks_client_disconnect():
"""Test handling of client disconnect during streaming"""
mock_request = MagicMock()
mock_request.is_disconnected = AsyncMock(return_value=True)
mock_service = AsyncMock()
async def mock_stream(*args, **kwargs):
for i in range(5):
yield b"chunk"
mock_service.generate_audio_stream = mock_stream
mock_service.list_voices.return_value = ["test_voice"]
request = OpenAISpeechRequest(
model="kokoro",
input="Test text",
voice="test_voice",
response_format="mp3",
stream=True,
speed=1.0
)
chunks = []
async for chunk in stream_audio_chunks(mock_service, request, mock_request):
chunks.append(chunk)
assert len(chunks) == 0 # Should stop immediately due to disconnect
def test_openai_voice_mapping(mock_tts_service, mock_openai_mappings):
"""Test OpenAI voice name mapping"""
mock_tts_service.list_voices.return_value = ["am_adam", "bf_isabella"]
response = client.post(
"/v1/audio/speech",
json={
"model": "tts-1",
"input": "Hello world",
"voice": "alloy", # OpenAI voice name
"response_format": "mp3",
"stream": False
}
)
assert response.status_code == 200
mock_tts_service.generate_audio.assert_called_once()
assert mock_tts_service.generate_audio.call_args[1]["voice"] == "am_adam"
def test_openai_voice_mapping_streaming(mock_tts_service, mock_openai_mappings, mock_audio_bytes):
"""Test OpenAI voice mapping in streaming mode"""
mock_tts_service.list_voices.return_value = ["am_adam", "bf_isabella"]
response = client.post(
"/v1/audio/speech",
json={
"model": "tts-1-hd",
"input": "Hello world",
"voice": "nova", # OpenAI voice name
"response_format": "mp3",
"stream": True
}
)
assert response.status_code == 200
content = b""
for chunk in response.iter_bytes():
content += chunk
assert content == mock_audio_bytes
def test_invalid_openai_model(mock_tts_service, mock_openai_mappings):
"""Test error handling for invalid OpenAI model"""
response = client.post(
"/v1/audio/speech",
json={
"model": "invalid-model",
"input": "Hello world",
"voice": "alloy",
"response_format": "mp3",
"stream": False
}
)
assert response.status_code == 400
error_response = response.json()
assert error_response["detail"]["error"] == "invalid_model"
assert "Unsupported model" in error_response["detail"]["message"]
@pytest.fixture
def mock_audio_bytes():
"""Mock audio bytes for testing."""
return b"mock audio data"
@pytest.fixture
def mock_tts_service(mock_audio_bytes):
"""Mock TTS service for testing."""
with patch("api.src.routers.openai_compatible.get_tts_service") as mock_get:
service = AsyncMock(spec=TTSService)
service.generate_audio.return_value = (np.zeros(1000), 0.1)
async def mock_stream(*args, **kwargs) -> AsyncGenerator[bytes, None]:
yield mock_audio_bytes
service.generate_audio_stream = mock_stream
service.list_voices.return_value = ["test_voice", "voice1", "voice2"]
service.combine_voices.return_value = "voice1_voice2"
mock_get.return_value = service
mock_get.side_effect = None
yield service
@patch('api.src.services.audio.AudioService.convert_audio')
def test_openai_speech_endpoint(mock_convert, mock_tts_service, test_voice, mock_audio_bytes):
"""Test the OpenAI-compatible speech endpoint with basic MP3 generation"""
# Configure mocks
mock_tts_service.generate_audio.return_value = (np.zeros(1000), 0.1)
mock_convert.return_value = mock_audio_bytes
response = client.post(
"/v1/audio/speech",
json={
"model": "kokoro",
"input": "Hello world",
"voice": test_voice,
"response_format": "mp3",
"stream": False
}
)
assert response.status_code == 200
assert response.headers["content-type"] == "audio/mpeg"
assert len(response.content) > 0
assert response.content == mock_audio_bytes
mock_tts_service.generate_audio.assert_called_once()
mock_convert.assert_called_once()
def test_openai_speech_streaming(mock_tts_service, test_voice, mock_audio_bytes):
"""Test the OpenAI-compatible speech endpoint with streaming"""
response = client.post(
"/v1/audio/speech",
json={
"model": "kokoro",
"input": "Hello world",
"voice": test_voice,
"response_format": "mp3",
"stream": True
}
)
assert response.status_code == 200
assert response.headers["content-type"] == "audio/mpeg"
assert "Transfer-Encoding" in response.headers
assert response.headers["Transfer-Encoding"] == "chunked"
content = b""
for chunk in response.iter_bytes():
content += chunk
assert content == mock_audio_bytes
def test_openai_speech_pcm_streaming(mock_tts_service, test_voice, mock_audio_bytes):
"""Test PCM streaming format"""
response = client.post(
"/v1/audio/speech",
json={
"model": "kokoro",
"input": "Hello world",
"voice": test_voice,
"response_format": "pcm",
"stream": True
}
)
assert response.status_code == 200
assert response.headers["content-type"] == "audio/pcm"
content = b""
for chunk in response.iter_bytes():
content += chunk
assert content == mock_audio_bytes
def test_openai_speech_invalid_voice(mock_tts_service):
"""Test error handling for invalid voice"""
mock_tts_service.generate_audio.side_effect = ValueError("Voice 'invalid_voice' not found")
response = client.post(
"/v1/audio/speech",
json={
"model": "kokoro",
"input": "Hello world",
"voice": "invalid_voice",
"response_format": "mp3",
"stream": False
}
)
assert response.status_code == 400
error_response = response.json()
assert error_response["detail"]["error"] == "validation_error"
assert "Voice 'invalid_voice' not found" in error_response["detail"]["message"]
assert error_response["detail"]["type"] == "invalid_request_error"
def test_openai_speech_empty_text(mock_tts_service, test_voice):
"""Test error handling for empty text"""
async def mock_error_stream(*args, **kwargs):
raise ValueError("Text is empty after preprocessing")
mock_tts_service.generate_audio = mock_error_stream
mock_tts_service.list_voices.return_value = ["test_voice"]
response = client.post(
"/v1/audio/speech",
json={
"model": "kokoro",
"input": "",
"voice": test_voice,
"response_format": "mp3",
"stream": False
}
)
assert response.status_code == 400
error_response = response.json()
assert error_response["detail"]["error"] == "validation_error"
assert "Text is empty after preprocessing" in error_response["detail"]["message"]
assert error_response["detail"]["type"] == "invalid_request_error"
def test_openai_speech_invalid_format(mock_tts_service, test_voice):
"""Test error handling for invalid format"""
response = client.post(
"/v1/audio/speech",
json={
"model": "kokoro",
"input": "Hello world",
"voice": test_voice,
"response_format": "invalid_format",
"stream": False
}
)
assert response.status_code == 422 # Validation error from Pydantic
def test_list_voices(mock_tts_service):
"""Test listing available voices"""
# Override the mock for this specific test
mock_tts_service.list_voices.return_value = ["voice1", "voice2"]
response = client.get("/v1/audio/voices")
assert response.status_code == 200
data = response.json()
assert "voices" in data
assert len(data["voices"]) == 2
assert "voice1" in data["voices"]
assert "voice2" in data["voices"]
def test_combine_voices(mock_tts_service):
"""Test combining voices endpoint"""
response = client.post(
"/v1/audio/voices/combine",
json="voice1+voice2"
)
assert response.status_code == 200
data = response.json()
assert "voice" in data
assert data["voice"] == "voice1_voice2"
def test_server_error(mock_tts_service, test_voice):
"""Test handling of server errors"""
async def mock_error_stream(*args, **kwargs):
raise RuntimeError("Internal server error")
mock_tts_service.generate_audio = mock_error_stream
mock_tts_service.list_voices.return_value = ["test_voice"]
response = client.post(
"/v1/audio/speech",
json={
"model": "kokoro",
"input": "Hello world",
"voice": test_voice,
"response_format": "mp3",
"stream": False
}
)
assert response.status_code == 500
error_response = response.json()
assert error_response["detail"]["error"] == "processing_error"
assert error_response["detail"]["type"] == "server_error"
def test_streaming_error(mock_tts_service, test_voice):
"""Test handling streaming errors"""
# Mock process_voices to raise the error
mock_tts_service.list_voices.side_effect = RuntimeError("Streaming failed")
response = client.post(
"/v1/audio/speech",
json={
"model": "kokoro",
"input": "Hello world",
"voice": test_voice,
"response_format": "mp3",
"stream": True
}
)
assert response.status_code == 500
error_data = response.json()
assert error_data["detail"]["error"] == "processing_error"
assert error_data["detail"]["type"] == "server_error"
assert "Streaming failed" in error_data["detail"]["message"]
@pytest.mark.asyncio
async def test_streaming_initialization_error():
"""Test handling of streaming initialization errors"""
mock_service = AsyncMock()
async def mock_error_stream(*args, **kwargs):
if False: # This makes it a proper generator
yield b""
raise RuntimeError("Failed to initialize stream")
mock_service.generate_audio_stream = mock_error_stream
mock_service.list_voices.return_value = ["test_voice"]
request = OpenAISpeechRequest(
model="kokoro",
input="Test text",
voice="test_voice",
response_format="mp3",
stream=True,
speed=1.0
)
with pytest.raises(RuntimeError) as exc:
async for _ in stream_audio_chunks(mock_service, request, MagicMock()):
pass
assert "Failed to initialize stream" in str(exc.value)

View file

@ -1,122 +0,0 @@
"""Tests for text processing endpoints"""
from unittest.mock import Mock, patch
import numpy as np
import pytest
import pytest_asyncio
from httpx import AsyncClient
from ..src.main import app
from .conftest import MockTTSModel
@pytest_asyncio.fixture
async def async_client():
async with AsyncClient(app=app, base_url="http://test") as ac:
yield ac
@pytest.mark.asyncio
async def test_phonemize_endpoint(async_client):
"""Test phoneme generation endpoint"""
with patch("api.src.routers.development.phonemize") as mock_phonemize, patch(
"api.src.routers.development.tokenize"
) as mock_tokenize:
# Setup mocks
mock_phonemize.return_value = "həlˈ"
mock_tokenize.return_value = [1, 2, 3]
# Test request
response = await async_client.post(
"/text/phonemize", json={"text": "hello", "language": "a"}
)
# Verify response
assert response.status_code == 200
result = response.json()
assert result["phonemes"] == "həlˈ"
assert result["tokens"] == [0, 1, 2, 3, 0] # Should add start/end tokens
@pytest.mark.asyncio
async def test_phonemize_empty_text(async_client):
"""Test phoneme generation with empty text"""
response = await async_client.post(
"/text/phonemize", json={"text": "", "language": "a"}
)
assert response.status_code == 500
assert "error" in response.json()["detail"]
@pytest.mark.asyncio
async def test_generate_from_phonemes(
async_client, mock_tts_service, mock_audio_service
):
"""Test audio generation from phonemes"""
with patch(
"api.src.routers.development.TTSService", return_value=mock_tts_service
):
response = await async_client.post(
"/text/generate_from_phonemes",
json={"phonemes": "həlˈ", "voice": "af_bella", "speed": 1.0},
)
assert response.status_code == 200
assert response.headers["content-type"] == "audio/wav"
assert (
response.headers["content-disposition"] == "attachment; filename=speech.wav"
)
assert response.content == b"mock audio data"
@pytest.mark.asyncio
async def test_generate_from_phonemes_invalid_voice(async_client, mock_tts_service):
"""Test audio generation with invalid voice"""
mock_tts_service._get_voice_path.return_value = None
with patch(
"api.src.routers.development.TTSService", return_value=mock_tts_service
):
response = await async_client.post(
"/text/generate_from_phonemes",
json={"phonemes": "həlˈ", "voice": "invalid_voice", "speed": 1.0},
)
assert response.status_code == 400
assert "Voice not found" in response.json()["detail"]["message"]
@pytest.mark.asyncio
async def test_generate_from_phonemes_invalid_speed(async_client, monkeypatch):
"""Test audio generation with invalid speed"""
# Mock TTSModel initialization
mock_model = Mock()
mock_model.generate_from_tokens = Mock(return_value=np.zeros(48000))
monkeypatch.setattr("api.src.services.tts_model.TTSModel._instance", mock_model)
monkeypatch.setattr(
"api.src.services.tts_model.TTSModel.get_instance",
Mock(return_value=mock_model),
)
response = await async_client.post(
"/text/generate_from_phonemes",
json={"phonemes": "həlˈ", "voice": "af_bella", "speed": -1.0},
)
assert response.status_code == 422 # Validation error
@pytest.mark.asyncio
async def test_generate_from_phonemes_empty_phonemes(async_client, mock_tts_service):
"""Test audio generation with empty phonemes"""
with patch(
"api.src.routers.development.TTSService", return_value=mock_tts_service
):
response = await async_client.post(
"/text/generate_from_phonemes",
json={"phonemes": "", "voice": "af_bella", "speed": 1.0},
)
assert response.status_code == 400
assert "Invalid request" in response.json()["detail"]["error"]

View file

@ -1,201 +0,0 @@
"""Tests for TTS model implementations"""
import os
from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
import torch
from api.src.services.tts_base import TTSBaseModel
from api.src.services.tts_cpu import TTSCPUModel
from api.src.services.tts_gpu import TTSGPUModel, length_to_mask
# Base Model Tests
def test_get_device_error():
"""Test get_device() raises error when not initialized"""
TTSBaseModel._device = None
with pytest.raises(RuntimeError, match="Model not initialized"):
TTSBaseModel.get_device()
@pytest.mark.asyncio
@patch("torch.cuda.is_available")
@patch("os.path.exists")
@patch("os.path.join")
@patch("os.listdir")
@patch("torch.load")
@patch("torch.save")
@patch("api.src.services.tts_base.settings")
@patch("api.src.services.warmup.WarmupService")
async def test_setup_cuda_available(
mock_warmup_class, mock_settings, mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available
):
"""Test setup with CUDA available"""
TTSBaseModel._device = None
# Mock CUDA as unavailable since we're using CPU PyTorch
mock_cuda_available.return_value = False
mock_exists.return_value = True
mock_load.return_value = torch.zeros(1)
mock_listdir.return_value = ["voice1.pt", "voice2.pt"]
mock_join.return_value = "/mocked/path"
# Configure mock settings
mock_settings.model_dir = "/mock/model/dir"
mock_settings.onnx_model_path = "model.onnx"
mock_settings.voices_dir = "voices"
# Configure mock warmup service
mock_warmup = MagicMock()
mock_warmup.load_voices.return_value = [torch.zeros(1)]
mock_warmup.warmup_voices = AsyncMock()
mock_warmup_class.return_value = mock_warmup
# Create mock model
mock_model = MagicMock()
mock_model.bert = MagicMock()
mock_model.process_text = MagicMock(return_value=("dummy", [1, 2, 3]))
mock_model.generate_from_tokens = MagicMock(return_value=np.zeros(1000))
# Mock initialize to return our mock model
TTSBaseModel.initialize = MagicMock(return_value=mock_model)
TTSBaseModel._instance = mock_model
voice_count = await TTSBaseModel.setup()
assert TTSBaseModel._device == "cpu"
assert voice_count == 2
@pytest.mark.asyncio
@patch("torch.cuda.is_available")
@patch("os.path.exists")
@patch("os.path.join")
@patch("os.listdir")
@patch("torch.load")
@patch("torch.save")
@patch("api.src.services.tts_base.settings")
@patch("api.src.services.warmup.WarmupService")
async def test_setup_cuda_unavailable(
mock_warmup_class, mock_settings, mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available
):
"""Test setup with CUDA unavailable"""
TTSBaseModel._device = None
mock_cuda_available.return_value = False
mock_exists.return_value = True
mock_load.return_value = torch.zeros(1)
mock_listdir.return_value = ["voice1.pt", "voice2.pt"]
mock_join.return_value = "/mocked/path"
# Configure mock settings
mock_settings.model_dir = "/mock/model/dir"
mock_settings.onnx_model_path = "model.onnx"
mock_settings.voices_dir = "voices"
# Configure mock warmup service
mock_warmup = MagicMock()
mock_warmup.load_voices.return_value = [torch.zeros(1)]
mock_warmup.warmup_voices = AsyncMock()
mock_warmup_class.return_value = mock_warmup
# Create mock model
mock_model = MagicMock()
mock_model.bert = MagicMock()
mock_model.process_text = MagicMock(return_value=("dummy", [1, 2, 3]))
mock_model.generate_from_tokens = MagicMock(return_value=np.zeros(1000))
# Mock initialize to return our mock model
TTSBaseModel.initialize = MagicMock(return_value=mock_model)
TTSBaseModel._instance = mock_model
voice_count = await TTSBaseModel.setup()
assert TTSBaseModel._device == "cpu"
assert voice_count == 2
# CPU Model Tests
def test_cpu_initialize_missing_model():
"""Test CPU initialize with missing model"""
TTSCPUModel._onnx_session = None # Reset the session
with patch("os.path.exists", return_value=False), patch(
"onnxruntime.InferenceSession", return_value=None
):
result = TTSCPUModel.initialize("dummy_dir")
assert result is None
def test_cpu_generate_uninitialized():
"""Test CPU generate methods with uninitialized model"""
TTSCPUModel._onnx_session = None
with pytest.raises(RuntimeError, match="ONNX model not initialized"):
TTSCPUModel.generate_from_text("test", torch.zeros(1), "en", 1.0)
with pytest.raises(RuntimeError, match="ONNX model not initialized"):
TTSCPUModel.generate_from_tokens([1, 2, 3], torch.zeros(1), 1.0)
def test_cpu_process_text():
"""Test CPU process_text functionality"""
with patch("api.src.services.tts_cpu.phonemize") as mock_phonemize, patch(
"api.src.services.tts_cpu.tokenize"
) as mock_tokenize:
mock_phonemize.return_value = "test phonemes"
mock_tokenize.return_value = [1, 2, 3]
phonemes, tokens = TTSCPUModel.process_text("test", "en")
assert phonemes == "test phonemes"
assert tokens == [0, 1, 2, 3, 0] # Should add start/end tokens
# GPU Model Tests
@patch("torch.cuda.is_available")
def test_gpu_initialize_cuda_unavailable(mock_cuda_available):
"""Test GPU initialize with CUDA unavailable"""
mock_cuda_available.return_value = False
TTSGPUModel._instance = None
result = TTSGPUModel.initialize("dummy_dir", "dummy_path")
assert result is None
@patch("api.src.services.tts_gpu.length_to_mask")
def test_gpu_length_to_mask(mock_length_to_mask):
"""Test length_to_mask function"""
# Setup mock return value
expected_mask = torch.tensor(
[[False, False, False, True, True], [False, False, False, False, False]]
)
mock_length_to_mask.return_value = expected_mask
# Call function with test input
lengths = torch.tensor([3, 5])
mask = mock_length_to_mask(lengths)
# Verify mock was called with correct input
mock_length_to_mask.assert_called_once()
assert torch.equal(mask, expected_mask)
def test_gpu_generate_uninitialized():
"""Test GPU generate methods with uninitialized model"""
TTSGPUModel._instance = None
with pytest.raises(RuntimeError, match="GPU model not initialized"):
TTSGPUModel.generate_from_text("test", torch.zeros(1), "en", 1.0)
with pytest.raises(RuntimeError, match="GPU model not initialized"):
TTSGPUModel.generate_from_tokens([1, 2, 3], torch.zeros(1), 1.0)
def test_gpu_process_text():
"""Test GPU process_text functionality"""
with patch("api.src.services.tts_gpu.phonemize") as mock_phonemize, patch(
"api.src.services.tts_gpu.tokenize"
) as mock_tokenize:
mock_phonemize.return_value = "test phonemes"
mock_tokenize.return_value = [1, 2, 3]
phonemes, tokens = TTSGPUModel.process_text("test", "en")
assert phonemes == "test phonemes"
assert tokens == [1, 2, 3] # GPU implementation doesn't add start/end tokens

View file

@ -1,260 +0,0 @@
"""Tests for TTSService"""
import os
from unittest.mock import MagicMock, call, patch
import numpy as np
import pytest
import torch
from onnxruntime import InferenceSession
from api.src.core.config import settings
from api.src.services.tts_cpu import TTSCPUModel
from api.src.services.tts_gpu import TTSGPUModel
from api.src.services.tts_model import TTSModel
from api.src.services.tts_service import TTSService
@pytest.fixture
def tts_service(monkeypatch):
"""Create a TTSService instance for testing"""
# Mock TTSModel initialization
mock_model = MagicMock()
mock_model.generate_from_tokens = MagicMock(return_value=np.zeros(48000))
mock_model.process_text = MagicMock(return_value=("mock phonemes", [1, 2, 3]))
# Set up model instance
monkeypatch.setattr("api.src.services.tts_model.TTSModel._instance", mock_model)
monkeypatch.setattr(
"api.src.services.tts_model.TTSModel.get_instance",
MagicMock(return_value=mock_model),
)
monkeypatch.setattr(
"api.src.services.tts_model.TTSModel.get_device", MagicMock(return_value="cpu")
)
return TTSService()
@pytest.fixture
def sample_audio():
"""Generate a simple sine wave for testing"""
sample_rate = 24000
duration = 0.1 # 100ms
t = np.linspace(0, duration, int(sample_rate * duration))
frequency = 440 # A4 note
return np.sin(2 * np.pi * frequency * t).astype(np.float32)
def test_audio_to_bytes(tts_service, sample_audio):
"""Test converting audio tensor to bytes"""
audio_bytes = tts_service._audio_to_bytes(sample_audio)
assert isinstance(audio_bytes, bytes)
assert len(audio_bytes) > 0
@pytest.mark.asyncio
async def test_list_voices(tts_service):
"""Test listing available voices"""
# Override list_voices for testing
# TODO: aiofiles vs aiofiles.os pathing is confusing here, so the test cheats
# by overriding list_voices; it works in the real world (for now)
async def mock_list_voices():
return ["voice1", "voice2"]
tts_service.list_voices = mock_list_voices
voices = await tts_service.list_voices()
assert len(voices) == 2
assert "voice1" in voices
assert "voice2" in voices
@pytest.mark.asyncio
async def test_list_voices_error(tts_service):
"""Test error handling in list_voices"""
# Override list_voices for testing
# TODO: See above.
async def mock_list_voices():
return []
tts_service.list_voices = mock_list_voices
voices = await tts_service.list_voices()
assert voices == []
def mock_model_setup(cuda_available=False):
"""Helper function to mock model setup"""
# Reset model state
TTSModel._instance = None
TTSModel._device = None
TTSModel._voicepacks = {}
# Create mock model instance with proper generate method
mock_model = MagicMock()
mock_model.generate.return_value = np.zeros(24000, dtype=np.float32)
TTSModel._instance = mock_model
# Set device based on CUDA availability
TTSModel._device = "cuda" if cuda_available else "cpu"
return 3 # Return voice count (including af.pt)
def test_model_initialization_cuda():
"""Test model initialization with CUDA"""
# Simulate CUDA availability
voice_count = mock_model_setup(cuda_available=True)
assert TTSModel.get_device() == "cuda"
assert voice_count == 3 # voice1.pt, voice2.pt, af.pt
def test_model_initialization_cpu():
"""Test model initialization with CPU"""
# Simulate no CUDA availability
voice_count = mock_model_setup(cuda_available=False)
assert TTSModel.get_device() == "cpu"
assert voice_count == 3 # voice1.pt, voice2.pt, af.pt
def test_generate_audio_empty_text(tts_service):
"""Test generating audio with empty text"""
with pytest.raises(ValueError, match="Text is empty after preprocessing"):
tts_service._generate_audio("", "af", 1.0)
@pytest.fixture(autouse=True)
def mock_settings():
"""Mock settings for all tests"""
with patch("api.src.services.text_processing.chunker.settings") as mock_settings:
mock_settings.max_chunk_size = 300
yield mock_settings
@patch("api.src.services.tts_model.TTSModel.get_instance")
@patch("api.src.services.tts_model.TTSModel.get_device")
@patch("os.path.exists")
@patch("kokoro.normalize_text")
@patch("kokoro.phonemize")
@patch("kokoro.tokenize")
@patch("kokoro.generate")
@patch("torch.load")
def test_generate_audio_phonemize_error(
mock_torch_load,
mock_generate,
mock_tokenize,
mock_phonemize,
mock_normalize,
mock_exists,
mock_get_device,
mock_instance,
tts_service,
):
"""Test handling phonemization error"""
mock_normalize.return_value = "Test text"
mock_phonemize.side_effect = Exception("Phonemization failed")
mock_instance.return_value = (
mock_generate,
"cpu",
) # Use the same mock for consistency
mock_get_device.return_value = "cpu"
mock_exists.return_value = True
mock_torch_load.return_value = torch.zeros((10, 24000))
mock_generate.return_value = (None, None)
with pytest.raises(ValueError, match="No chunks were processed successfully"):
tts_service._generate_audio("Test text", "af", 1.0)
@patch("api.src.services.tts_model.TTSModel.get_instance")
@patch("api.src.services.tts_model.TTSModel.get_device")
@patch("os.path.exists")
@patch("kokoro.normalize_text")
@patch("kokoro.phonemize")
@patch("kokoro.tokenize")
@patch("kokoro.generate")
@patch("torch.load")
def test_generate_audio_error(
mock_torch_load,
mock_generate,
mock_tokenize,
mock_phonemize,
mock_normalize,
mock_exists,
mock_get_device,
mock_instance,
tts_service,
):
"""Test handling generation error"""
mock_normalize.return_value = "Test text"
mock_phonemize.return_value = "Test text"
mock_tokenize.return_value = [1, 2] # Return integers instead of strings
mock_generate.side_effect = Exception("Generation failed")
mock_instance.return_value = (
mock_generate,
"cpu",
) # Use the same mock for consistency
mock_get_device.return_value = "cpu"
mock_exists.return_value = True
mock_torch_load.return_value = torch.zeros((10, 24000))
with pytest.raises(ValueError, match="No chunks were processed successfully"):
tts_service._generate_audio("Test text", "af", 1.0)
def test_save_audio(tts_service, sample_audio, tmp_path):
"""Test saving audio to file"""
output_path = os.path.join(tmp_path, "test_output.wav")
tts_service._save_audio(sample_audio, output_path)
assert os.path.exists(output_path)
assert os.path.getsize(output_path) > 0
@pytest.mark.asyncio
async def test_combine_voices(tts_service):
"""Test combining multiple voices"""
# Setup mocks for torch operations
with patch("torch.load", return_value=torch.tensor([1.0, 2.0])), patch(
"torch.stack", return_value=torch.tensor([[1.0, 2.0], [3.0, 4.0]])
), patch("torch.mean", return_value=torch.tensor([2.0, 3.0])), patch(
"torch.save"
), patch("os.path.exists", return_value=True):
# Test combining two voices
result = await tts_service.combine_voices(["voice1", "voice2"])
assert result == "voice1_voice2"
@pytest.mark.asyncio
async def test_combine_voices_invalid_input(tts_service):
"""Test combining voices with invalid input"""
# Test with empty list
with pytest.raises(ValueError, match="At least 2 voices are required"):
await tts_service.combine_voices([])
# Test with single voice
with pytest.raises(ValueError, match="At least 2 voices are required"):
await tts_service.combine_voices(["voice1"])
@patch("api.src.services.tts_service.TTSService._get_voice_path")
@patch("api.src.services.tts_model.TTSModel.get_instance")
def test_voicepack_loading_error(mock_get_instance, mock_get_voice_path):
"""Test voicepack loading error handling"""
mock_get_voice_path.return_value = None
mock_instance = MagicMock()
mock_instance.generate.return_value = np.zeros(24000, dtype=np.float32)
mock_get_instance.return_value = (mock_instance, "cpu")
TTSModel._voicepacks = {} # Reset voicepacks
service = TTSService()
with pytest.raises(ValueError, match="Voice not found: nonexistent_voice"):
service._generate_audio("test", "nonexistent_voice", 1.0)

View file

@ -0,0 +1,142 @@
# import pytest
# import numpy as np
# from unittest.mock import AsyncMock, patch
# @pytest.mark.asyncio
# async def test_generate_audio(tts_service, mock_audio_output, test_voice):
# """Test basic audio generation"""
# audio, processing_time = await tts_service.generate_audio(
# text="Hello world",
# voice=test_voice,
# speed=1.0
# )
# assert isinstance(audio, np.ndarray)
# assert audio == mock_audio_output.tobytes()
# assert processing_time > 0
# tts_service.model_manager.generate.assert_called_once()
# @pytest.mark.asyncio
# async def test_generate_audio_with_combined_voice(tts_service, mock_audio_output):
# """Test audio generation with a combined voice"""
# test_voices = ["voice1", "voice2"]
# combined_id = await tts_service._voice_manager.combine_voices(test_voices)
# audio, processing_time = await tts_service.generate_audio(
# text="Hello world",
# voice=combined_id,
# speed=1.0
# )
# assert isinstance(audio, np.ndarray)
# assert np.array_equal(audio, mock_audio_output)
# assert processing_time > 0
# @pytest.mark.asyncio
# async def test_generate_audio_stream(tts_service, mock_audio_output, test_voice):
# """Test streaming audio generation"""
# tts_service.model_manager.generate.return_value = mock_audio_output
# chunks = []
# async for chunk in tts_service.generate_audio_stream(
# text="Hello world",
# voice=test_voice,
# speed=1.0,
# output_format="pcm"
# ):
# assert isinstance(chunk, bytes)
# chunks.append(chunk)
# assert len(chunks) > 0
# tts_service.model_manager.generate.assert_called()
# @pytest.mark.asyncio
# async def test_empty_text(tts_service, test_voice):
# """Test handling empty text"""
# with pytest.raises(ValueError) as exc_info:
# await tts_service.generate_audio(
# text="",
# voice=test_voice,
# speed=1.0
# )
# assert "No audio chunks were generated successfully" in str(exc_info.value)
# @pytest.mark.asyncio
# async def test_invalid_voice(tts_service):
# """Test handling invalid voice"""
# tts_service._voice_manager.load_voice.side_effect = ValueError("Voice not found")
# with pytest.raises(ValueError) as exc_info:
# await tts_service.generate_audio(
# text="Hello world",
# voice="invalid_voice",
# speed=1.0
# )
# assert "Voice not found" in str(exc_info.value)
# @pytest.mark.asyncio
# async def test_model_generation_error(tts_service, test_voice):
# """Test handling model generation error"""
# # Make generate return None to simulate failed generation
# tts_service.model_manager.generate.return_value = None
# with pytest.raises(ValueError) as exc_info:
# await tts_service.generate_audio(
# text="Hello world",
# voice=test_voice,
# speed=1.0
# )
# assert "No audio chunks were generated successfully" in str(exc_info.value)
# @pytest.mark.asyncio
# async def test_streaming_generation_error(tts_service, test_voice):
# """Test handling streaming generation error"""
# # Make generate return None to simulate failed generation
# tts_service.model_manager.generate.return_value = None
# chunks = []
# async for chunk in tts_service.generate_audio_stream(
# text="Hello world",
# voice=test_voice,
# speed=1.0,
# output_format="pcm"
# ):
# chunks.append(chunk)
# # Should get no chunks if generation fails
# assert len(chunks) == 0
# @pytest.mark.asyncio
# async def test_list_voices(tts_service):
# """Test listing available voices"""
# voices = await tts_service.list_voices()
# assert len(voices) == 2
# assert "voice1" in voices
# assert "voice2" in voices
# tts_service._voice_manager.list_voices.assert_called_once()
# @pytest.mark.asyncio
# async def test_combine_voices(tts_service):
# """Test combining voices"""
# test_voices = ["voice1", "voice2"]
# combined_id = await tts_service.combine_voices(test_voices)
# assert combined_id == "voice1_voice2"
# tts_service._voice_manager.combine_voices.assert_called_once_with(test_voices)
# @pytest.mark.asyncio
# async def test_chunked_text_processing(tts_service, test_voice, mock_audio_output):
# """Test processing chunked text"""
# # Create text that will force chunking by exceeding max tokens
# long_text = "This is a test sentence." * 100 # Should be way over 500 tokens
# # Don't mock smart_split - let it actually split the text
# audio, processing_time = await tts_service.generate_audio(
# text=long_text,
# voice=test_voice,
# speed=1.0
# )
# # Should be called multiple times due to chunking
# assert tts_service.model_manager.generate.call_count > 1
# assert isinstance(audio, np.ndarray)
# assert processing_time > 0

View file

@ -0,0 +1,134 @@
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
import torch
from pathlib import Path
from ..src.inference.voice_manager import VoiceManager
from ..src.structures.model_schemas import VoiceConfig
@pytest.fixture
def mock_voice_tensor():
return torch.randn(10, 10) # Dummy tensor
@pytest.fixture
def voice_manager():
return VoiceManager(VoiceConfig())
@pytest.mark.asyncio
async def test_load_voice(voice_manager, mock_voice_tensor):
"""Test loading a single voice"""
with patch("api.src.core.paths.load_voice_tensor", new_callable=AsyncMock) as mock_load:
mock_load.return_value = mock_voice_tensor
with patch("os.path.exists", return_value=True):
voice = await voice_manager.load_voice("af_bella", "cpu")
assert torch.equal(voice, mock_voice_tensor)
@pytest.mark.asyncio
async def test_load_voice_not_found(voice_manager):
"""Test loading non-existent voice"""
with patch("os.path.exists", return_value=False):
with pytest.raises(RuntimeError, match="Voice not found: invalid_voice"):
await voice_manager.load_voice("invalid_voice", "cpu")
@pytest.mark.skip(reason="Local saving is optional and not critical to functionality")
@pytest.mark.asyncio
async def test_combine_voices_with_saving(voice_manager, mock_voice_tensor):
"""Test combining voices with local saving enabled"""
pass
@pytest.mark.asyncio
async def test_combine_voices_without_saving(voice_manager, mock_voice_tensor):
"""Test combining voices without local saving"""
with patch("api.src.core.paths.load_voice_tensor", new_callable=AsyncMock) as mock_load, \
patch("torch.save") as mock_save, \
patch("os.makedirs"), \
patch("os.path.exists", return_value=True):
# Setup mocks
mock_load.return_value = mock_voice_tensor
# Mock settings
with patch("api.src.core.config.settings") as mock_settings:
mock_settings.allow_local_voice_saving = False
mock_settings.voices_dir = "/mock/voices"
# Combine voices
combined = await voice_manager.combine_voices(["af_bella", "af_sarah"], "cpu")
assert combined == "af_bella+af_sarah" # Note: using + separator
# Verify voice was not saved
mock_save.assert_not_called()
@pytest.mark.asyncio
async def test_combine_voices_single_voice(voice_manager):
"""Test combining with single voice"""
with pytest.raises(ValueError, match="At least 2 voices are required"):
await voice_manager.combine_voices(["af_bella"], "cpu")
@pytest.mark.asyncio
async def test_list_voices(voice_manager):
"""Test listing available voices"""
with patch("os.listdir", return_value=["af_bella.pt", "af_sarah.pt", "af_bella+af_sarah.pt"]), \
patch("os.makedirs"):
voices = await voice_manager.list_voices()
assert len(voices) == 3
assert "af_bella" in voices
assert "af_sarah" in voices
assert "af_bella+af_sarah" in voices
@pytest.mark.asyncio
async def test_load_combined_voice(voice_manager, mock_voice_tensor):
"""Test loading a combined voice"""
with patch("api.src.core.paths.load_voice_tensor", new_callable=AsyncMock) as mock_load:
mock_load.return_value = mock_voice_tensor
with patch("os.path.exists", return_value=True):
voice = await voice_manager.load_voice("af_bella+af_sarah", "cpu")
assert torch.equal(voice, mock_voice_tensor)
def test_cache_management(mock_voice_tensor):
"""Test voice cache management"""
# Create voice manager with small cache size
config = VoiceConfig(cache_size=2)
voice_manager = VoiceManager(config)
# Add items to cache
voice_manager._voice_cache = {
"voice1_cpu": torch.randn(5, 5),
"voice2_cpu": torch.randn(5, 5),
"voice3_cpu": torch.randn(5, 5), # Add one more than cache size
}
# Try managing cache
voice_manager._manage_cache()
# Check cache size maintained
assert len(voice_manager._voice_cache) <= 2
@pytest.mark.asyncio
async def test_voice_loading_with_cache(voice_manager, mock_voice_tensor):
"""Test voice loading with cache enabled"""
with patch("api.src.core.paths.load_voice_tensor", new_callable=AsyncMock) as mock_load, \
patch("os.path.exists", return_value=True):
mock_load.return_value = mock_voice_tensor
# First load should hit disk
voice1 = await voice_manager.load_voice("af_bella", "cpu")
assert mock_load.call_count == 1
# Second load should hit cache
voice2 = await voice_manager.load_voice("af_bella", "cpu")
assert mock_load.call_count == 1 # Still 1
assert torch.equal(voice1, voice2)
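The manager-level combination covered here is the same path the public endpoint tested earlier exercises; a hedged curl sketch, with the voice names illustrative and the server assumed to be running locally:
# Combine two voices; the response should return the combined voice id (note the "+" separator)
curl -X POST http://localhost:8880/v1/audio/voices/combine \
  -H "Content-Type: application/json" \
  -d '"af_bella+af_sarah"'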

BIN
assets/beta_web_ui.png Normal file

Binary file not shown.

Size: 385 KiB

BIN
assets/docs-screenshot.png Normal file

Binary file not shown.

Size: 78 KiB

BIN
assets/webui-screenshot.png Normal file

Binary file not shown.

Size: 283 KiB

17
debug.http Normal file
View file

@ -0,0 +1,17 @@
### Get Thread Information
GET http://localhost:8880/debug/threads
Accept: application/json
### Get Storage Information
GET http://localhost:8880/debug/storage
Accept: application/json
### Get System Information
GET http://localhost:8880/debug/system
Accept: application/json
### Get Session Pool Status
# Shows active ONNX sessions, CUDA stream usage, and session ages
# Useful for debugging resource exhaustion issues
GET http://localhost:8880/debug/session_pools
Accept: application/json
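The same endpoints can be queried from the command line; a minimal sketch, assuming the server is running locally on the default port:
# Query the debug endpoints added in this release
curl -s http://localhost:8880/debug/threads
curl -s http://localhost:8880/debug/storage
curl -s http://localhost:8880/debug/system
curl -s http://localhost:8880/debug/session_pools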

36
docker/build.sh Executable file
View file

@ -0,0 +1,36 @@
#!/bin/bash
set -e
# Get version from argument or use default
VERSION=${1:-"latest"}
# GitHub Container Registry settings
REGISTRY="ghcr.io"
OWNER="remsky"
REPO="kokoro-fastapi"
# Create and use a new builder that supports multi-platform builds
docker buildx create --name multiplatform-builder --use || true
# Build CPU image with multi-platform support
echo "Building CPU image..."
docker buildx build --platform linux/amd64,linux/arm64 \
-t ${REGISTRY}/${OWNER}/${REPO}-cpu:${VERSION} \
-t ${REGISTRY}/${OWNER}/${REPO}-cpu:latest \
-f docker/cpu/Dockerfile \
--push .
# Build GPU image with multi-platform support
echo "Building GPU image..."
docker buildx build --platform linux/amd64,linux/arm64 \
-t ${REGISTRY}/${OWNER}/${REPO}-gpu:${VERSION} \
-t ${REGISTRY}/${OWNER}/${REPO}-gpu:latest \
-f docker/gpu/Dockerfile \
--push .
echo "Build complete!"
echo "Created images:"
echo "- ${REGISTRY}/${OWNER}/${REPO}-cpu:${VERSION} (linux/amd64, linux/arm64)"
echo "- ${REGISTRY}/${OWNER}/${REPO}-cpu:latest (linux/amd64, linux/arm64)"
echo "- ${REGISTRY}/${OWNER}/${REPO}-gpu:${VERSION} (linux/amd64, linux/arm64)"
echo "- ${REGISTRY}/${OWNER}/${REPO}-gpu:latest (linux/amd64, linux/arm64)"

View file

@ -1,4 +1,4 @@
FROM python:3.10-slim
FROM --platform=$BUILDPLATFORM python:3.10-slim
# Install dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
@ -10,36 +10,17 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
# Install uv
# Install uv for speed and glory
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
# Create non-root user
RUN useradd -m -u 1000 appuser
# Create directories and set ownership
RUN mkdir -p /app/models && \
mkdir -p /app/api/src/voices && \
RUN mkdir -p /app/api/src/voices && \
chown -R appuser:appuser /app
USER appuser
# Download and extract models
WORKDIR /app/models
RUN set -x && \
curl -L -o model.tar.gz https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.0.1/kokoro-82m-onnx.tar.gz && \
echo "Downloaded model.tar.gz:" && ls -lh model.tar.gz && \
tar xzf model.tar.gz && \
echo "Contents after extraction:" && ls -lhR && \
rm model.tar.gz && \
echo "Final contents:" && ls -lhR
# Download and extract voice models
WORKDIR /app/api/src/voices
RUN curl -L -o voices.tar.gz https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.0.1/voice-models.tar.gz && \
tar xzf voices.tar.gz && \
rm voices.tar.gz
# Switch back to app directory
WORKDIR /app
# Copy dependency files
@ -50,8 +31,10 @@ RUN --mount=type=cache,target=/root/.cache/uv \
uv venv && \
uv sync --extra cpu --no-install-project
# Copy project files
# Copy project files including models
COPY --chown=appuser:appuser api ./api
COPY --chown=appuser:appuser web ./web
COPY --chown=appuser:appuser docker/scripts/download_model.* ./
# Install project
RUN --mount=type=cache,target=/root/.cache/uv \
@ -59,9 +42,22 @@ RUN --mount=type=cache,target=/root/.cache/uv \
# Set environment variables
ENV PYTHONUNBUFFERED=1
ENV PYTHONPATH=/app:/app/models
ENV PYTHONPATH=/app
ENV PATH="/app/.venv/bin:$PATH"
ENV UV_LINK_MODE=copy
ENV USE_GPU=false
ENV USE_ONNX=true
ENV DOWNLOAD_ONNX=true
ENV DOWNLOAD_PTH=false
# Download models based on environment variables
RUN if [ "$DOWNLOAD_ONNX" = "true" ]; then \
python download_model.py --type onnx; \
fi && \
if [ "$DOWNLOAD_PTH" = "true" ]; then \
python download_model.py --type pth; \
fi
# Run FastAPI server
CMD ["uv", "run", "python", "-m", "uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]

View file

@ -6,12 +6,11 @@ services:
context: ../..
dockerfile: docker/cpu/Dockerfile
volumes:
- ../../api/src:/app/api/src
- ../../api/src/voices:/app/api/src/voices
- ../../api:/app/api
ports:
- "8880:8880"
environment:
- PYTHONPATH=/app:/app/models
- PYTHONPATH=/app:/app/api
# ONNX Optimization Settings for vectorized operations
- ONNX_NUM_THREADS=8 # Maximize core usage for vectorized ops
- ONNX_INTER_OP_THREADS=4 # Higher inter-op for parallel matrix operations
@ -20,20 +19,20 @@ services:
- ONNX_MEMORY_PATTERN=true
- ONNX_ARENA_EXTEND_STRATEGY=kNextPowerOfTwo
# Gradio UI service [Comment out everything below if you don't need it]
gradio-ui:
image: ghcr.io/remsky/kokoro-fastapi-ui:v0.1.0
# Uncomment below (and comment out above) to build from source instead of using the released image
build:
context: ../../ui
ports:
- "7860:7860"
volumes:
- ../../ui/data:/app/ui/data
- ../../ui/app.py:/app/app.py # Mount app.py for hot reload
environment:
- GRADIO_WATCH=True # Enable hot reloading
- PYTHONUNBUFFERED=1 # Ensure Python output is not buffered
- DISABLE_LOCAL_SAVING=false # Set to 'true' to disable local saving and hide file view
- API_HOST=kokoro-tts # Set TTS service URL
- API_PORT=8880 # Set TTS service PORT
# # Gradio UI service [Comment out everything below if you don't need it]
# gradio-ui:
# image: ghcr.io/remsky/kokoro-fastapi-ui:v0.1.0
# # Uncomment below (and comment out above) to build from source instead of using the released image
# build:
# context: ../../ui
# ports:
# - "7860:7860"
# volumes:
# - ../../ui/data:/app/ui/data
# - ../../ui/app.py:/app/app.py # Mount app.py for hot reload
# environment:
# - GRADIO_WATCH=True # Enable hot reloading
# - PYTHONUNBUFFERED=1 # Ensure Python output is not buffered
# - DISABLE_LOCAL_SAVING=false # Set to 'true' to disable local saving and hide file view
# - API_HOST=kokoro-tts # Set TTS service URL
# - API_PORT=8880 # Set TTS service PORT
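To run from source with the compose file above, something like the following should work (the file path is an assumption based on the dockerfile: docker/cpu/Dockerfile reference):
# Build and start the CPU service defined above; the Gradio UI block stays commented out
docker compose -f docker/cpu/docker-compose.yml up --build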

View file

@ -1,22 +0,0 @@
[project]
name = "kokoro-fastapi-cpu"
version = "0.1.0"
description = "FastAPI TTS Service - CPU Version"
readme = "../README.md"
requires-python = ">=3.10"
dependencies = [
# Core ML/DL for CPU
"torch>=2.5.1",
"transformers==4.47.1",
]
[tool.uv.workspace]
members = ["../shared"]
[tool.uv.sources]
torch = { index = "pytorch-cpu" }
[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true

View file

@ -1,229 +0,0 @@
# This file was autogenerated by uv via the following command:
# uv pip compile pyproject.toml ../shared/pyproject.toml --output-file requirements.lock
aiofiles==23.2.1
# via kokoro-fastapi (../shared/pyproject.toml)
annotated-types==0.7.0
# via pydantic
anyio==4.8.0
# via starlette
attrs==24.3.0
# via
# clldutils
# csvw
# jsonschema
# phonemizer
# referencing
babel==2.16.0
# via csvw
certifi==2024.12.14
# via requests
cffi==1.17.1
# via soundfile
charset-normalizer==3.4.1
# via requests
click==8.1.8
# via
# kokoro-fastapi (../shared/pyproject.toml)
# uvicorn
clldutils==3.21.0
# via segments
colorama==0.4.6
# via
# click
# colorlog
# csvw
# loguru
# tqdm
coloredlogs==15.0.1
# via onnxruntime
colorlog==6.9.0
# via clldutils
csvw==3.5.1
# via segments
dlinfo==1.2.1
# via phonemizer
exceptiongroup==1.2.2
# via anyio
fastapi==0.115.6
# via kokoro-fastapi (../shared/pyproject.toml)
filelock==3.16.1
# via
# huggingface-hub
# torch
# transformers
flatbuffers==24.12.23
# via onnxruntime
fsspec==2024.12.0
# via
# huggingface-hub
# torch
greenlet==3.1.1
# via sqlalchemy
h11==0.14.0
# via uvicorn
huggingface-hub==0.27.1
# via
# tokenizers
# transformers
humanfriendly==10.0
# via coloredlogs
idna==3.10
# via
# anyio
# requests
isodate==0.7.2
# via
# csvw
# rdflib
jinja2==3.1.5
# via torch
joblib==1.4.2
# via phonemizer
jsonschema==4.23.0
# via csvw
jsonschema-specifications==2024.10.1
# via jsonschema
language-tags==1.2.0
# via csvw
loguru==0.7.3
# via kokoro-fastapi (../shared/pyproject.toml)
lxml==5.3.0
# via clldutils
markdown==3.7
# via clldutils
markupsafe==3.0.2
# via
# clldutils
# jinja2
mpmath==1.3.0
# via sympy
munch==4.0.0
# via kokoro-fastapi (../shared/pyproject.toml)
networkx==3.4.2
# via torch
numpy==2.2.1
# via
# kokoro-fastapi (../shared/pyproject.toml)
# onnxruntime
# scipy
# soundfile
# transformers
onnxruntime==1.20.1
# via kokoro-fastapi (../shared/pyproject.toml)
packaging==24.2
# via
# huggingface-hub
# onnxruntime
# transformers
phonemizer==3.3.0
# via kokoro-fastapi (../shared/pyproject.toml)
protobuf==5.29.3
# via onnxruntime
pycparser==2.22
# via cffi
pydantic==2.10.4
# via
# kokoro-fastapi (../shared/pyproject.toml)
# fastapi
# pydantic-settings
pydantic-core==2.27.2
# via pydantic
pydantic-settings==2.7.0
# via kokoro-fastapi (../shared/pyproject.toml)
pylatexenc==2.10
# via clldutils
pyparsing==3.2.1
# via rdflib
pyreadline3==3.5.4
# via humanfriendly
python-dateutil==2.9.0.post0
# via
# clldutils
# csvw
python-dotenv==1.0.1
# via
# kokoro-fastapi (../shared/pyproject.toml)
# pydantic-settings
pyyaml==6.0.2
# via
# huggingface-hub
# transformers
rdflib==7.1.2
# via csvw
referencing==0.35.1
# via
# jsonschema
# jsonschema-specifications
regex==2024.11.6
# via
# kokoro-fastapi (../shared/pyproject.toml)
# segments
# tiktoken
# transformers
requests==2.32.3
# via
# kokoro-fastapi (../shared/pyproject.toml)
# csvw
# huggingface-hub
# tiktoken
# transformers
rfc3986==1.5.0
# via csvw
rpds-py==0.22.3
# via
# jsonschema
# referencing
safetensors==0.5.2
# via transformers
scipy==1.14.1
# via kokoro-fastapi (../shared/pyproject.toml)
segments==2.2.1
# via phonemizer
six==1.17.0
# via python-dateutil
sniffio==1.3.1
# via anyio
soundfile==0.13.0
# via kokoro-fastapi (../shared/pyproject.toml)
sqlalchemy==2.0.27
# via kokoro-fastapi (../shared/pyproject.toml)
starlette==0.41.3
# via fastapi
sympy==1.13.1
# via
# onnxruntime
# torch
tabulate==0.9.0
# via clldutils
tiktoken==0.8.0
# via kokoro-fastapi (../shared/pyproject.toml)
tokenizers==0.21.0
# via transformers
torch==2.5.1+cpu
# via kokoro-fastapi-cpu (pyproject.toml)
tqdm==4.67.1
# via
# kokoro-fastapi (../shared/pyproject.toml)
# huggingface-hub
# transformers
transformers==4.47.1
# via kokoro-fastapi-cpu (pyproject.toml)
typing-extensions==4.12.2
# via
# anyio
# fastapi
# huggingface-hub
# phonemizer
# pydantic
# pydantic-core
# sqlalchemy
# torch
# uvicorn
uritemplate==4.1.1
# via csvw
urllib3==2.3.0
# via requests
uvicorn==0.34.0
# via kokoro-fastapi (../shared/pyproject.toml)
win32-setctime==1.2.0
# via loguru

1841
docker/cpu/uv.lock generated

File diff suppressed because it is too large

View file

@ -1,4 +1,6 @@
FROM nvidia/cuda:12.1.0-base-ubuntu22.04
FROM --platform=$BUILDPLATFORM nvidia/cuda:12.3.2-cudnn9-runtime-ubuntu22.04
# Set non-interactive frontend
ENV DEBIAN_FRONTEND=noninteractive
# Install Python and other dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
@ -19,47 +21,43 @@ COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
RUN useradd -m -u 1000 appuser
# Create directories and set ownership
RUN mkdir -p /app/models && \
mkdir -p /app/api/src/voices && \
RUN mkdir -p /app/api/src/voices && \
chown -R appuser:appuser /app
USER appuser
# Download and extract models
WORKDIR /app/models
RUN curl -L -o model.tar.gz https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.0.1/kokoro-82m-pytorch.tar.gz && \
tar xzf model.tar.gz && \
rm model.tar.gz
# Download and extract voice models
WORKDIR /app/api/src/voices
RUN curl -L -o voices.tar.gz https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.0.1/voice-models.tar.gz && \
tar xzf voices.tar.gz && \
rm voices.tar.gz
# Switch back to app directory
WORKDIR /app
# Copy dependency files
COPY --chown=appuser:appuser pyproject.toml ./pyproject.toml
# Install dependencies
# Install dependencies with GPU extras
RUN --mount=type=cache,target=/root/.cache/uv \
uv venv && \
uv sync --extra gpu --no-install-project
uv sync --extra gpu
# Copy project files
# Copy project files including models
COPY --chown=appuser:appuser api ./api
COPY --chown=appuser:appuser web ./web
COPY --chown=appuser:appuser docker/scripts/download_model.* ./
# Install project
# Install project with GPU extras
RUN --mount=type=cache,target=/root/.cache/uv \
uv sync --extra gpu
COPY --chown=appuser:appuser docker/scripts/ /app/docker/scripts/
RUN chmod +x docker/scripts/entrypoint.sh
RUN chmod +x docker/scripts/download_model.sh
# Set environment variables
ENV PYTHONUNBUFFERED=1
ENV PYTHONPATH=/app:/app/models
ENV PYTHONPATH=/app
ENV PATH="/app/.venv/bin:$PATH"
ENV UV_LINK_MODE=copy
ENV USE_GPU=true
ENV USE_ONNX=false
ENV DOWNLOAD_PTH=true
ENV DOWNLOAD_ONNX=false
# Run FastAPI server
CMD ["uv", "run", "python", "-m", "uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]
CMD ["/app/docker/scripts/entrypoint.sh"]

View file

@ -6,12 +6,14 @@ services:
context: ../..
dockerfile: docker/gpu/Dockerfile
volumes:
- ../../api/src:/app/api/src # Mount src for development
- ../../api/src/voices:/app/api/src/voices # Mount voices for persistence
- ../../api:/app/api
ports:
- "8880:8880"
environment:
- PYTHONPATH=/app:/app/models
- PYTHONPATH=/app:/app/api
- USE_GPU=true
- USE_ONNX=false
- PYTHONUNBUFFERED=1
deploy:
resources:
reservations:
@ -20,20 +22,20 @@ services:
count: 1
capabilities: [gpu]
# Gradio UI service
gradio-ui:
image: ghcr.io/remsky/kokoro-fastapi-ui:v0.1.0
# Uncomment below to build from source instead of using the released image
# build:
# context: ../../ui
ports:
- "7860:7860"
volumes:
- ../../ui/data:/app/ui/data
- ../../ui/app.py:/app/app.py # Mount app.py for hot reload
environment:
- GRADIO_WATCH=1 # Enable hot reloading
- PYTHONUNBUFFERED=1 # Ensure Python output is not buffered
- DISABLE_LOCAL_SAVING=false # Set to 'true' to disable local saving and hide file view
- API_HOST=kokoro-tts # Set TTS service URL
- API_PORT=8880 # Set TTS service PORT
# # Gradio UI service
# gradio-ui:
# image: ghcr.io/remsky/kokoro-fastapi-ui:v0.1.0
# # Uncomment below to build from source instead of using the released image
# # build:
# # context: ../../ui
# ports:
# - "7860:7860"
# volumes:
# - ../../ui/data:/app/ui/data
# - ../../ui/app.py:/app/app.py # Mount app.py for hot reload
# environment:
# - GRADIO_WATCH=1 # Enable hot reloading
# - PYTHONUNBUFFERED=1 # Ensure Python output is not buffered
# - DISABLE_LOCAL_SAVING=false # Set to 'true' to disable local saving and hide file view
# - API_HOST=kokoro-tts # Set TTS service URL
# - API_PORT=8880 # Set TTS service PORT

View file

@ -1,22 +0,0 @@
[project]
name = "kokoro-fastapi-gpu"
version = "0.1.0"
description = "FastAPI TTS Service - GPU Version"
readme = "../README.md"
requires-python = ">=3.10"
dependencies = [
# Core ML/DL for GPU
"torch==2.5.1+cu121",
"transformers==4.47.1",
]
[tool.uv.workspace]
members = ["../shared"]
[tool.uv.sources]
torch = { index = "pytorch-cuda" }
[[tool.uv.index]]
name = "pytorch-cuda"
url = "https://download.pytorch.org/whl/cu121"
explicit = true

View file

@ -1,229 +0,0 @@
# This file was autogenerated by uv via the following command:
# uv pip compile pyproject.toml ../shared/pyproject.toml --output-file requirements.lock
aiofiles==23.2.1
# via kokoro-fastapi (../shared/pyproject.toml)
annotated-types==0.7.0
# via pydantic
anyio==4.8.0
# via starlette
attrs==24.3.0
# via
# clldutils
# csvw
# jsonschema
# phonemizer
# referencing
babel==2.16.0
# via csvw
certifi==2024.12.14
# via requests
cffi==1.17.1
# via soundfile
charset-normalizer==3.4.1
# via requests
click==8.1.8
# via
# kokoro-fastapi (../shared/pyproject.toml)
# uvicorn
clldutils==3.21.0
# via segments
colorama==0.4.6
# via
# click
# colorlog
# csvw
# loguru
# tqdm
coloredlogs==15.0.1
# via onnxruntime
colorlog==6.9.0
# via clldutils
csvw==3.5.1
# via segments
dlinfo==1.2.1
# via phonemizer
exceptiongroup==1.2.2
# via anyio
fastapi==0.115.6
# via kokoro-fastapi (../shared/pyproject.toml)
filelock==3.16.1
# via
# huggingface-hub
# torch
# transformers
flatbuffers==24.12.23
# via onnxruntime
fsspec==2024.12.0
# via
# huggingface-hub
# torch
greenlet==3.1.1
# via sqlalchemy
h11==0.14.0
# via uvicorn
huggingface-hub==0.27.1
# via
# tokenizers
# transformers
humanfriendly==10.0
# via coloredlogs
idna==3.10
# via
# anyio
# requests
isodate==0.7.2
# via
# csvw
# rdflib
jinja2==3.1.5
# via torch
joblib==1.4.2
# via phonemizer
jsonschema==4.23.0
# via csvw
jsonschema-specifications==2024.10.1
# via jsonschema
language-tags==1.2.0
# via csvw
loguru==0.7.3
# via kokoro-fastapi (../shared/pyproject.toml)
lxml==5.3.0
# via clldutils
markdown==3.7
# via clldutils
markupsafe==3.0.2
# via
# clldutils
# jinja2
mpmath==1.3.0
# via sympy
munch==4.0.0
# via kokoro-fastapi (../shared/pyproject.toml)
networkx==3.4.2
# via torch
numpy==2.2.1
# via
# kokoro-fastapi (../shared/pyproject.toml)
# onnxruntime
# scipy
# soundfile
# transformers
onnxruntime==1.20.1
# via kokoro-fastapi (../shared/pyproject.toml)
packaging==24.2
# via
# huggingface-hub
# onnxruntime
# transformers
phonemizer==3.3.0
# via kokoro-fastapi (../shared/pyproject.toml)
protobuf==5.29.3
# via onnxruntime
pycparser==2.22
# via cffi
pydantic==2.10.4
# via
# kokoro-fastapi (../shared/pyproject.toml)
# fastapi
# pydantic-settings
pydantic-core==2.27.2
# via pydantic
pydantic-settings==2.7.0
# via kokoro-fastapi (../shared/pyproject.toml)
pylatexenc==2.10
# via clldutils
pyparsing==3.2.1
# via rdflib
pyreadline3==3.5.4
# via humanfriendly
python-dateutil==2.9.0.post0
# via
# clldutils
# csvw
python-dotenv==1.0.1
# via
# kokoro-fastapi (../shared/pyproject.toml)
# pydantic-settings
pyyaml==6.0.2
# via
# huggingface-hub
# transformers
rdflib==7.1.2
# via csvw
referencing==0.35.1
# via
# jsonschema
# jsonschema-specifications
regex==2024.11.6
# via
# kokoro-fastapi (../shared/pyproject.toml)
# segments
# tiktoken
# transformers
requests==2.32.3
# via
# kokoro-fastapi (../shared/pyproject.toml)
# csvw
# huggingface-hub
# tiktoken
# transformers
rfc3986==1.5.0
# via csvw
rpds-py==0.22.3
# via
# jsonschema
# referencing
safetensors==0.5.2
# via transformers
scipy==1.14.1
# via kokoro-fastapi (../shared/pyproject.toml)
segments==2.2.1
# via phonemizer
six==1.17.0
# via python-dateutil
sniffio==1.3.1
# via anyio
soundfile==0.13.0
# via kokoro-fastapi (../shared/pyproject.toml)
sqlalchemy==2.0.27
# via kokoro-fastapi (../shared/pyproject.toml)
starlette==0.41.3
# via fastapi
sympy==1.13.1
# via
# onnxruntime
# torch
tabulate==0.9.0
# via clldutils
tiktoken==0.8.0
# via kokoro-fastapi (../shared/pyproject.toml)
tokenizers==0.21.0
# via transformers
torch==2.5.1+cu121
# via kokoro-fastapi-gpu (pyproject.toml)
tqdm==4.67.1
# via
# kokoro-fastapi (../shared/pyproject.toml)
# huggingface-hub
# transformers
transformers==4.47.1
# via kokoro-fastapi-gpu (pyproject.toml)
typing-extensions==4.12.2
# via
# anyio
# fastapi
# huggingface-hub
# phonemizer
# pydantic
# pydantic-core
# sqlalchemy
# torch
# uvicorn
uritemplate==4.1.1
# via csvw
urllib3==2.3.0
# via requests
uvicorn==0.34.0
# via kokoro-fastapi (../shared/pyproject.toml)
win32-setctime==1.2.0
# via loguru

1914
docker/gpu/uv.lock generated

File diff suppressed because it is too large

View file

@ -0,0 +1,103 @@
#!/usr/bin/env python3
import os
import sys
import argparse
import requests
from pathlib import Path
from typing import List
def download_file(url: str, output_dir: Path, model_type: str, overwrite: bool = False) -> bool:
"""Download a file from URL to the specified directory.
Returns:
bool: True if the download succeeded (or the file already exists), False otherwise
"""
filename = os.path.basename(url)
if not filename.endswith(f'.{model_type}'):
print(f"Warning: {filename} is not a .{model_type} file", file=sys.stderr)
return False
output_path = output_dir / filename
if output_path.exists() and not overwrite:
print(f"{filename} already exists. Skipping download (pass --overwrite to replace it)")
return True
print(f"Downloading {filename}...")
try:
response = requests.get(url, stream=True)
response.raise_for_status()
with open(output_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
print(f"Successfully downloaded {filename}")
return True
except Exception as e:
print(f"Error downloading {filename}: {e}", file=sys.stderr)
return False
def find_project_root() -> Path:
"""Find project root by looking for api directory."""
max_steps = 5
current = Path(__file__).resolve()
for _ in range(max_steps):
if (current / 'api').is_dir():
return current
current = current.parent
raise RuntimeError("Could not find project root (no api directory found)")
def main() -> int:
"""Download models to the project.
Returns:
int: Exit code (0 for success, 1 for failure)
"""
parser = argparse.ArgumentParser(description='Download model files')
parser.add_argument('--type', choices=['pth', 'onnx'], required=True,
help='Model type to download (pth or onnx)')
parser.add_argument('--overwrite', action='store_true', help='Overwrite existing files')
parser.add_argument('urls', nargs='*', help='Optional model URLs to download')
args = parser.parse_args()
try:
# Find project root and ensure models directory exists
project_root = find_project_root()
models_dir = project_root / 'api' / 'src' / 'models'
print(f"Downloading models to {models_dir}")
models_dir.mkdir(exist_ok=True)
# Default models if no arguments provided
default_models = {
'pth': [
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth",
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19-half.pth"
],
'onnx': [
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.onnx",
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19_fp16.onnx"
]
}
# Use provided models or default
models_to_download = args.urls if args.urls else default_models[args.type]
# Download all models
success = True
for model_url in models_to_download:
if not download_file(model_url, models_dir, args.type, args.overwrite):
success = False
if success:
print(f"{args.type.upper()} model download complete!")
return 0
else:
print("Some downloads failed", file=sys.stderr)
return 1
except Exception as e:
print(f"Error: {e}", file=sys.stderr)
return 1
if __name__ == "__main__":
sys.exit(main())

View file

@ -0,0 +1,110 @@
#!/bin/bash
# Find project root by looking for api directory
find_project_root() {
local current_dir="$PWD"
local max_steps=5
local steps=0
while [ $steps -lt $max_steps ]; do
if [ -d "$current_dir/api" ]; then
echo "$current_dir"
return 0
fi
current_dir="$(dirname "$current_dir")"
((steps++))
done
echo "Error: Could not find project root (no api directory found)" >&2
exit 1
}
# Function to download a file
download_file() {
local url="$1"
local output_dir="$2"
local model_type="$3"
local filename=$(basename "$url")
# Validate file extension
if [[ ! "$filename" =~ \.$model_type$ ]]; then
echo "Warning: $filename is not a .$model_type file" >&2
return 1
fi
echo "Downloading $filename..."
if curl -L "$url" -o "$output_dir/$filename"; then
echo "Successfully downloaded $filename"
return 0
else
echo "Error downloading $filename" >&2
return 1
fi
}
# Parse arguments
MODEL_TYPE=""
while [[ $# -gt 0 ]]; do
case $1 in
--type)
MODEL_TYPE="$2"
shift 2
;;
*)
# If no flag specified, treat remaining args as model URLs
break
;;
esac
done
# Validate model type
if [ "$MODEL_TYPE" != "pth" ] && [ "$MODEL_TYPE" != "onnx" ]; then
echo "Error: Must specify model type with --type (pth or onnx)" >&2
exit 1
fi
# Find project root and ensure models directory exists
PROJECT_ROOT=$(find_project_root)
if [ $? -ne 0 ]; then
exit 1
fi
MODELS_DIR="$PROJECT_ROOT/api/src/models"
echo "Downloading models to $MODELS_DIR"
mkdir -p "$MODELS_DIR"
# Default models if no arguments provided
if [ "$MODEL_TYPE" = "pth" ]; then
DEFAULT_MODELS=(
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth"
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19-half.pth"
)
else
DEFAULT_MODELS=(
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.onnx"
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19_fp16.onnx"
)
fi
# Use provided models or default
if [ $# -gt 0 ]; then
MODELS=("$@")
else
MODELS=("${DEFAULT_MODELS[@]}")
fi
# Download all models
success=true
for model in "${MODELS[@]}"; do
if ! download_file "$model" "$MODELS_DIR" "$MODEL_TYPE"; then
success=false
fi
done
if [ "$success" = true ]; then
echo "${MODEL_TYPE^^} model download complete!"
exit 0
else
echo "Some downloads failed" >&2
exit 1
fi

View file

@ -0,0 +1,12 @@
#!/bin/bash
set -e
if [ "$DOWNLOAD_PTH" = "true" ]; then
python docker/scripts/download_model.py --type pth
fi
if [ "$DOWNLOAD_ONNX" = "true" ]; then
python docker/scripts/download_model.py --type onnx
fi
exec uv run python -m uvicorn api.src.main:app --host 0.0.0.0 --port 8880 --log-level debug

View file

@ -0,0 +1,167 @@
# Streaming Audio Writer Analysis
This auto-generated document provides an in-depth technical analysis of the `StreamingAudioWriter` class, detailing the streaming and non-streaming paths, the supported formats, header management, and the challenges faced in the implementation.
## Overview
The `StreamingAudioWriter` class is designed to handle streaming audio format conversions efficiently. It supports various audio formats and provides methods to write audio data in chunks, finalize the stream, and manage audio headers meticulously to ensure compatibility and integrity of the resulting audio files.
## Supported Formats
The class supports the following audio formats:
- **WAV**
- **OGG**
- **Opus**
- **FLAC**
- **MP3**
- **AAC**
- **PCM**
## Initialization
Upon initialization, the class sets up format-specific configurations to prepare for audio data processing:
- **WAV**:
- Writes an initial WAV header with placeholders for file size (`RIFF` chunk) and data size (`data` chunk).
- Utilizes the `_write_wav_header` method to generate the header.
- **OGG/Opus/FLAC**:
- Uses `soundfile.SoundFile` to write audio data to a memory buffer (`BytesIO`).
- Configures the writer with appropriate format and subtype based on the specified format.
- **MP3/AAC**:
- Utilizes `pydub.AudioSegment` for incremental writing.
- Initializes an empty `AudioSegment` as the encoder to accumulate audio data.
- **PCM**:
- Prepares to write raw PCM bytes without additional headers.
Initialization ensures that each format is correctly configured to handle the specific requirements of streaming and finalizing audio data.
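As a rough sketch of that setup, assuming 24 kHz, 16-bit mono audio (the names, defaults, and return shape below are illustrative rather than the class's actual attributes, and the `OPUS` subtype requires a libsndfile build that supports it):

```python
from io import BytesIO

import soundfile as sf
from pydub import AudioSegment


def init_format_state(fmt: str, sample_rate: int = 24000) -> dict:
    """Illustrative per-format setup mirroring the behaviour described above."""
    if fmt == "wav":
        buf = BytesIO()
        # A header with zeroed RIFF/data sizes is written here first (see _write_wav_header).
        return {"buffer": buf, "bytes_written": 0}
    if fmt in ("ogg", "opus", "flac"):
        buf = BytesIO()
        container = "FLAC" if fmt == "flac" else "OGG"
        subtype = {"ogg": "VORBIS", "opus": "OPUS", "flac": "PCM_16"}[fmt]
        writer = sf.SoundFile(buf, mode="w", samplerate=sample_rate,
                              channels=1, format=container, subtype=subtype)
        return {"buffer": buf, "writer": writer}
    if fmt in ("mp3", "aac"):
        # An empty segment acts as the incremental encoder/accumulator.
        return {"encoder": AudioSegment.empty()}
    if fmt == "pcm":
        return {"buffer": BytesIO()}  # raw samples, no header
    raise ValueError(f"Unsupported format: {fmt}")
```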
## Streaming Path
### Writing Chunks
The `write_chunk` method handles the incoming audio data, processing it according to the specified format (a sketch of the MP3/AAC path follows this list):
- **WAV**:
- **First Chunk**: Writes the initial WAV header to the buffer.
- **Subsequent Chunks**: Writes raw PCM data directly after the header.
- Updates `bytes_written` to track the total size of audio data written.
- **OGG/Opus/FLAC**:
- Writes audio data to the `soundfile` buffer.
- Flushes the writer to ensure data integrity.
- Retrieves the current buffer contents and truncates the buffer for the next chunk.
- **MP3/AAC**:
- Converts incoming audio data (`np.ndarray`) to a `pydub.AudioSegment`.
- Accumulates segments in the encoder.
- Exports the current state to the output buffer without writing duration metadata or XING headers for chunks.
- Resets the encoder to prevent memory growth after exporting.
- **PCM**:
- Directly writes raw bytes from the audio data to the output buffer.
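As an illustration of the MP3/AAC path only, here is a hedged sketch of a single chunk write. It assumes float samples at 24 kHz mono; the exact ffmpeg flags and bitrate the class passes may differ:

```python
from io import BytesIO

import numpy as np
from pydub import AudioSegment


def write_mp3_chunk(encoder: AudioSegment, audio: np.ndarray,
                    output: BytesIO, sample_rate: int = 24000) -> AudioSegment:
    """Encode one chunk of float samples and append the bytes to the output buffer."""
    pcm16 = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
    segment = AudioSegment(pcm16.tobytes(), frame_rate=sample_rate,
                           sample_width=2, channels=1)
    encoder += segment

    # Export the current state; the XING/VBR header is suppressed for
    # intermediate chunks and only written during finalization.
    chunk_buf = BytesIO()
    encoder.export(chunk_buf, format="mp3", bitrate="192k",
                   parameters=["-write_xing", "0"])
    output.write(chunk_buf.getvalue())

    # Reset the accumulator so memory does not grow with stream length.
    return AudioSegment.empty()
```

A caller would reassign the returned value (`encoder = write_mp3_chunk(encoder, chunk, output)`) on every chunk.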
### Finalizing
Finalizing the audio stream involves ensuring that all audio data is correctly written and that headers are updated to reflect the accurate file and data sizes:
- **WAV**:
- Rewrites the `RIFF` and `data` chunks in the header with the actual file size (`bytes_written + 36`) and data size (`bytes_written`).
- Creates a new buffer with the complete WAV file by copying audio data from the original buffer starting at byte 44 (end of the initial header).
- **OGG/Opus/FLAC**:
- Closes the `soundfile` writer to flush all remaining data to the buffer.
- Returns the final buffer content, ensuring that all necessary headers and data are correctly written (see the sketch after this list).
- **MP3/AAC**:
- Exports any remaining audio data with proper headers and metadata, including duration and VBR quality for MP3.
- Writes ID3v1 and ID3v2 tags for MP3 formats.
- Performs final exports to ensure that all audio data is properly encoded and formatted.
- **PCM**:
- No finalization is needed as PCM involves raw data without headers.
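For the soundfile-backed formats, finalization amounts to a flush-and-close; a minimal sketch with illustrative names:

```python
from io import BytesIO

import soundfile as sf


def finalize_soundfile(writer: sf.SoundFile, buffer: BytesIO) -> bytes:
    """Close the writer so trailing OGG pages / FLAC frames are flushed, then drain the buffer."""
    writer.close()
    return buffer.getvalue()
```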
## Non-Streaming Path
The `StreamingAudioWriter` class is inherently designed for streaming audio data. However, it's essential to understand how it behaves when handling complete files versus streaming data:
### Full File Writing
- **Process**:
- Accumulate all audio data in memory or buffer.
- Write the complete file with accurate headers and data sizes upon finalization.
- **Advantages**:
- Simplifies header management since the total data size is known before writing.
- Reduces complexity in data handling and processing.
- **Disadvantages**:
- High memory consumption for large audio files.
- Delay in availability of audio data until the entire file is processed.
### Stream-to-File Writing
- **Process**:
- Incrementally write audio data in chunks.
- Update headers and finalize the file dynamically as data flows.
- **Advantages**:
- Lower memory usage as data is processed in smaller chunks.
- Immediate availability of audio data, suitable for real-time streaming applications.
- **Disadvantages**:
- Complex header management to accommodate dynamic data sizes.
- Increased likelihood of header synchronization issues, leading to potential file corruption.
**Challenges**:
- Balancing memory usage with processing speed.
- Ensuring consistent and accurate header updates during streaming operations.
## Header Management
### WAV Headers
WAV files utilize `RIFF` headers to describe file structure:
- **Initial Header**:
- Contains placeholders for file size and data size (`struct.pack('<L', 0)`).
- **Final Header**:
- Calculates and writes the actual file size (`bytes_written + 36`) and data size (`bytes_written`).
- Ensures that audio players can correctly interpret the file by having accurate header information.
**Technical Details**:
- The `_write_wav_header` method initializes the WAV header with placeholders.
- Upon finalization, the `write_chunk` method creates a new buffer, writes the correct sizes, and appends the audio data from the original buffer starting at byte 44 (end of the initial header), as sketched below.
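A minimal sketch of that placeholder-then-patch sequence, assuming 16-bit mono PCM at 24 kHz (offsets 4 and 40 are fixed by the canonical 44-byte header; the function names are illustrative):

```python
import struct
from io import BytesIO


def write_wav_header(buf: BytesIO, sample_rate: int = 24000,
                     channels: int = 1, bits: int = 16) -> None:
    """Write a 44-byte PCM WAV header with zeroed size fields."""
    byte_rate = sample_rate * channels * bits // 8
    block_align = channels * bits // 8
    buf.write(b"RIFF")
    buf.write(struct.pack("<L", 0))            # placeholder: file size - 8
    buf.write(b"WAVE")
    buf.write(b"fmt ")
    buf.write(struct.pack("<L", 16))           # fmt chunk size
    buf.write(struct.pack("<H", 1))            # PCM
    buf.write(struct.pack("<H", channels))
    buf.write(struct.pack("<L", sample_rate))
    buf.write(struct.pack("<L", byte_rate))
    buf.write(struct.pack("<H", block_align))
    buf.write(struct.pack("<H", bits))
    buf.write(b"data")
    buf.write(struct.pack("<L", 0))            # placeholder: data size


def finalize_wav(buf: BytesIO, bytes_written: int) -> bytes:
    """Patch the RIFF and data chunk sizes once the PCM byte count is known."""
    buf.seek(4)
    buf.write(struct.pack("<L", bytes_written + 36))   # total file size - 8
    buf.seek(40)
    buf.write(struct.pack("<L", bytes_written))        # raw PCM byte count
    return buf.getvalue()
```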
**Challenges**:
- Maintaining synchronization between audio data size and header placeholders.
- Ensuring that the header is correctly rewritten upon finalization to prevent file corruption.
### MP3/AAC Headers
MP3 and AAC formats require proper metadata and headers to ensure compatibility:
- **XING Headers (MP3)**:
- Essential for Variable Bit Rate (VBR) audio files.
- Control the quality and indexing of the MP3 file.
- **ID3 Tags (MP3)**:
- Provide metadata such as artist, title, and album information.
- **ADTS Headers (AAC)**:
- Describe the AAC frame headers necessary for decoding.
**Technical Details**:
- During finalization, the `write_chunk` method for MP3/AAC formats includes:
- Duration metadata (`-metadata duration`).
- VBR headers for MP3 (`-write_vbr`, `-vbr_quality`).
- ID3 tags for MP3 (`-write_id3v1`, `-write_id3v2`).
- Ensures that all remaining audio data is correctly encoded and formatted with the necessary headers (a sketch follows at the end of this section).
**Challenges**:
- Ensuring that metadata is accurately written during the finalization process.
- Managing VBR headers to maintain audio quality and file integrity.

View file

@ -68,7 +68,7 @@ def main():
# Initialize system monitor
monitor = SystemMonitor(interval=1.0) # 1 second interval
# Set prefix for output files (e.g. "gpu", "cpu", "onnx", etc.)
prefix = "cpu"
prefix = "gpu"
# Generate token sizes
if "gpu" in prefix:
token_sizes = generate_token_sizes(

View file

@ -1,23 +1,23 @@
=== Benchmark Statistics (with correct RTF) ===
Total tokens processed: 1800
Total audio generated (s): 568.53
Total test duration (s): 306.02
Average processing rate (tokens/s): 5.75
Average RTF: 0.55
Average Real Time Speed: 1.81
Total tokens processed: 1500
Total audio generated (s): 427.90
Total test duration (s): 10.84
Average processing rate (tokens/s): 133.35
Average RTF: 0.02
Average Real Time Speed: 41.67
=== Per-chunk Stats ===
Average chunk size (tokens): 600.00
Min chunk size (tokens): 300
Max chunk size (tokens): 900
Average processing time (s): 101.89
Average output length (s): 189.51
Average chunk size (tokens): 300.00
Min chunk size (tokens): 100
Max chunk size (tokens): 500
Average processing time (s): 2.13
Average output length (s): 85.58
=== Performance Ranges ===
Processing rate range (tokens/s): 5.30 - 6.26
RTF range: 0.51x - 0.59x
Real Time Speed range: 1.69x - 1.96x
Processing rate range (tokens/s): 102.04 - 159.74
RTF range: 0.02x - 0.03x
Real Time Speed range: 33.33x - 50.00x

View file

@ -2,616 +2,240 @@
"results": [
{
"tokens": 150,
"processing_time": 2.36,
"output_length": 45.9,
"rtf": 0.05,
"elapsed_time": 2.44626
"processing_time": 1.18,
"output_length": 43.7,
"rtf": 0.03,
"elapsed_time": 1.20302
},
{
"tokens": 300,
"processing_time": 4.94,
"output_length": 96.425,
"rtf": 0.05,
"elapsed_time": 7.46073
"processing_time": 2.27,
"output_length": 86.75,
"rtf": 0.03,
"elapsed_time": 3.49958
},
{
"tokens": 450,
"processing_time": 8.94,
"output_length": 143.1,
"rtf": 0.06,
"elapsed_time": 16.55036
"processing_time": 3.49,
"output_length": 125.9,
"rtf": 0.03,
"elapsed_time": 7.03862
},
{
"tokens": 600,
"processing_time": 19.78,
"output_length": 188.675,
"rtf": 0.1,
"elapsed_time": 36.69352
"processing_time": 4.64,
"output_length": 169.325,
"rtf": 0.03,
"elapsed_time": 11.71062
},
{
"tokens": 750,
"processing_time": 19.89,
"output_length": 236.7,
"rtf": 0.08,
"elapsed_time": 56.77695
"processing_time": 5.07,
"output_length": 212.3,
"rtf": 0.02,
"elapsed_time": 16.83186
},
{
"tokens": 900,
"processing_time": 16.83,
"output_length": 283.425,
"rtf": 0.06,
"elapsed_time": 73.8079
"processing_time": 6.66,
"output_length": 258.0,
"rtf": 0.03,
"elapsed_time": 23.54135
}
],
"system_metrics": [
{
"timestamp": "2025-01-06T00:43:20.888295",
"cpu_percent": 36.92,
"ram_percent": 68.6,
"ram_used_gb": 43.6395263671875,
"gpu_memory_used": 7022.0,
"relative_time": 0.09646010398864746
"timestamp": "2025-01-30T05:06:38.733338",
"cpu_percent": 0.0,
"ram_percent": 18.6,
"ram_used_gb": 5.284908294677734,
"gpu_memory_used": 1925.0,
"relative_time": 0.039948463439941406
},
{
"timestamp": "2025-01-06T00:43:21.983741",
"cpu_percent": 22.29,
"ram_percent": 68.6,
"ram_used_gb": 43.642677307128906,
"gpu_memory_used": 7021.0,
"relative_time": 1.1906661987304688
},
{
"timestamp": "2025-01-06T00:43:23.078293",
"cpu_percent": 27.39,
"ram_percent": 68.6,
"ram_used_gb": 43.61421203613281,
"gpu_memory_used": 7190.0,
"relative_time": 2.264479160308838
},
{
"timestamp": "2025-01-06T00:43:24.151445",
"cpu_percent": 20.28,
"ram_percent": 68.6,
"ram_used_gb": 43.65406036376953,
"gpu_memory_used": 7193.0,
"relative_time": 3.349093198776245
},
{
"timestamp": "2025-01-06T00:43:25.237021",
"cpu_percent": 23.03,
"ram_percent": 68.6,
"ram_used_gb": 43.647274017333984,
"gpu_memory_used": 7191.0,
"relative_time": 4.413560628890991
},
{
"timestamp": "2025-01-06T00:43:26.300255",
"cpu_percent": 23.62,
"ram_percent": 68.6,
"ram_used_gb": 43.642295837402344,
"gpu_memory_used": 7185.0,
"relative_time": 5.484973430633545
},
{
"timestamp": "2025-01-06T00:43:27.377319",
"cpu_percent": 46.04,
"ram_percent": 68.7,
"ram_used_gb": 43.7291374206543,
"gpu_memory_used": 7178.0,
"relative_time": 6.658120632171631
},
{
"timestamp": "2025-01-06T00:43:28.546053",
"cpu_percent": 29.79,
"ram_percent": 68.7,
"ram_used_gb": 43.73202133178711,
"gpu_memory_used": 7177.0,
"relative_time": 7.725939035415649
},
{
"timestamp": "2025-01-06T00:43:29.613327",
"cpu_percent": 18.19,
"ram_percent": 68.8,
"ram_used_gb": 43.791343688964844,
"gpu_memory_used": 7177.0,
"relative_time": 8.800285577774048
},
{
"timestamp": "2025-01-06T00:43:30.689097",
"cpu_percent": 22.29,
"ram_percent": 68.9,
"ram_used_gb": 43.81514358520508,
"gpu_memory_used": 7176.0,
"relative_time": 9.899119853973389
},
{
"timestamp": "2025-01-06T00:43:31.786443",
"cpu_percent": 32.59,
"ram_percent": 68.9,
"ram_used_gb": 43.834510803222656,
"gpu_memory_used": 7189.0,
"relative_time": 11.042734384536743
},
{
"timestamp": "2025-01-06T00:43:32.929720",
"cpu_percent": 42.48,
"ram_percent": 68.8,
"ram_used_gb": 43.77507019042969,
"gpu_memory_used": 7192.0,
"relative_time": 12.117269277572632
},
{
"timestamp": "2025-01-06T00:43:34.004481",
"cpu_percent": 26.33,
"ram_percent": 68.8,
"ram_used_gb": 43.77891159057617,
"gpu_memory_used": 7192.0,
"relative_time": 13.19830870628357
},
{
"timestamp": "2025-01-06T00:43:35.086024",
"cpu_percent": 26.53,
"ram_percent": 68.8,
"ram_used_gb": 43.77515411376953,
"gpu_memory_used": 7192.0,
"relative_time": 14.29457426071167
},
{
"timestamp": "2025-01-06T00:43:36.183496",
"cpu_percent": 40.33,
"ram_percent": 68.9,
"ram_used_gb": 43.81095886230469,
"gpu_memory_used": 7192.0,
"relative_time": 15.402768850326538
},
{
"timestamp": "2025-01-06T00:43:37.290635",
"cpu_percent": 43.6,
"ram_percent": 69.0,
"ram_used_gb": 43.87236022949219,
"gpu_memory_used": 7190.0,
"relative_time": 16.574281930923462
},
{
"timestamp": "2025-01-06T00:43:38.462164",
"cpu_percent": 85.74,
"ram_percent": 69.0,
"ram_used_gb": 43.864280700683594,
"gpu_memory_used": 6953.0,
"relative_time": 17.66074824333191
},
{
"timestamp": "2025-01-06T00:43:39.548295",
"cpu_percent": 23.88,
"ram_percent": 68.8,
"ram_used_gb": 43.75236129760742,
"gpu_memory_used": 4722.0,
"relative_time": 18.739423036575317
},
{
"timestamp": "2025-01-06T00:43:40.626692",
"cpu_percent": 59.24,
"ram_percent": 68.7,
"ram_used_gb": 43.720741271972656,
"gpu_memory_used": 4723.0,
"relative_time": 19.846031665802002
},
{
"timestamp": "2025-01-06T00:43:41.733597",
"cpu_percent": 41.74,
"ram_percent": 68.4,
"ram_used_gb": 43.53546142578125,
"gpu_memory_used": 4722.0,
"relative_time": 20.920310020446777
},
{
"timestamp": "2025-01-06T00:43:42.808191",
"cpu_percent": 35.43,
"ram_percent": 68.3,
"ram_used_gb": 43.424468994140625,
"gpu_memory_used": 4726.0,
"relative_time": 22.00457763671875
},
{
"timestamp": "2025-01-06T00:43:43.891669",
"cpu_percent": 43.81,
"ram_percent": 68.2,
"ram_used_gb": 43.38311004638672,
"gpu_memory_used": 4727.0,
"relative_time": 23.08402943611145
},
{
"timestamp": "2025-01-06T00:43:44.971246",
"cpu_percent": 58.13,
"ram_percent": 68.0,
"ram_used_gb": 43.27970886230469,
"gpu_memory_used": 4731.0,
"relative_time": 24.249765396118164
},
{
"timestamp": "2025-01-06T00:43:46.137626",
"cpu_percent": 66.76,
"ram_percent": 68.0,
"ram_used_gb": 43.23844528198242,
"gpu_memory_used": 4731.0,
"relative_time": 25.32853865623474
},
{
"timestamp": "2025-01-06T00:43:47.219723",
"cpu_percent": 27.95,
"ram_percent": 67.8,
"ram_used_gb": 43.106136322021484,
"gpu_memory_used": 4734.0,
"relative_time": 26.499221563339233
},
{
"timestamp": "2025-01-06T00:43:48.386913",
"cpu_percent": 73.13,
"ram_percent": 67.7,
"ram_used_gb": 43.049781799316406,
"gpu_memory_used": 4736.0,
"relative_time": 27.592528104782104
},
{
"timestamp": "2025-01-06T00:43:49.480407",
"cpu_percent": 50.63,
"ram_percent": 67.6,
"ram_used_gb": 43.007415771484375,
"gpu_memory_used": 4736.0,
"relative_time": 28.711266040802002
},
{
"timestamp": "2025-01-06T00:43:50.599220",
"cpu_percent": 92.36,
"ram_percent": 67.5,
"ram_used_gb": 42.9685173034668,
"gpu_memory_used": 4728.0,
"relative_time": 29.916289567947388
},
{
"timestamp": "2025-01-06T00:43:51.803667",
"cpu_percent": 83.07,
"ram_percent": 67.5,
"ram_used_gb": 42.96232986450195,
"gpu_memory_used": 4724.0,
"relative_time": 31.039498805999756
},
{
"timestamp": "2025-01-06T00:43:52.927208",
"cpu_percent": 90.61,
"ram_percent": 67.5,
"ram_used_gb": 42.96202850341797,
"gpu_memory_used": 5037.0,
"relative_time": 32.2381911277771
},
{
"timestamp": "2025-01-06T00:43:54.128135",
"cpu_percent": 89.47,
"ram_percent": 67.5,
"ram_used_gb": 42.94692611694336,
"gpu_memory_used": 5085.0,
"relative_time": 33.35147500038147
{
"timestamp": "2025-01-30T05:06:39.774003",
"cpu_percent": 13.37,
"ram_percent": 18.6,
"ram_used_gb": 5.2852630615234375,
"gpu_memory_used": 3047.0,
"relative_time": 1.0883615016937256
},
{
"timestamp": "2025-01-06T00:43:55.238967",
"cpu_percent": 60.01,
"ram_percent": 67.4,
"ram_used_gb": 42.88222122192383,
"gpu_memory_used": 5085.0,
"relative_time": 34.455963373184204
},
{
"timestamp": "2025-01-06T00:43:56.344164",
"cpu_percent": 62.12,
"ram_percent": 67.3,
"ram_used_gb": 42.81411361694336,
"gpu_memory_used": 5083.0,
"relative_time": 35.549962282180786
},
{
"timestamp": "2025-01-06T00:43:57.437566",
"cpu_percent": 53.56,
"ram_percent": 67.3,
"ram_used_gb": 42.83011245727539,
"gpu_memory_used": 5078.0,
"relative_time": 36.66783380508423
},
{
"timestamp": "2025-01-06T00:43:58.554923",
"cpu_percent": 80.27,
"ram_percent": 67.3,
"ram_used_gb": 42.79304504394531,
"gpu_memory_used": 5069.0,
"relative_time": 37.77330660820007
{
"timestamp": "2025-01-30T05:06:40.822449",
"cpu_percent": 13.68,
"ram_percent": 18.7,
"ram_used_gb": 5.303462982177734,
"gpu_memory_used": 3040.0,
"relative_time": 2.12058687210083
},
{
"timestamp": "2025-01-06T00:43:59.660456",
"cpu_percent": 72.33,
"ram_percent": 67.2,
"ram_used_gb": 42.727474212646484,
"gpu_memory_used": 5079.0,
"relative_time": 38.885955810546875
{
"timestamp": "2025-01-30T05:06:41.854375",
"cpu_percent": 15.39,
"ram_percent": 18.7,
"ram_used_gb": 5.306262969970703,
"gpu_memory_used": 3326.0,
"relative_time": 3.166278600692749
},
{
"timestamp": "2025-01-06T00:44:00.773867",
"cpu_percent": 59.29,
"ram_percent": 66.9,
"ram_used_gb": 42.566131591796875,
"gpu_memory_used": 5079.0,
"relative_time": 39.99704432487488
},
{
"timestamp": "2025-01-06T00:44:01.884399",
"cpu_percent": 43.52,
"ram_percent": 66.5,
"ram_used_gb": 42.32980728149414,
"gpu_memory_used": 5079.0,
"relative_time": 41.13008522987366
},
{
"timestamp": "2025-01-06T00:44:03.018905",
"cpu_percent": 84.46,
"ram_percent": 66.5,
"ram_used_gb": 42.28911590576172,
"gpu_memory_used": 5087.0,
"relative_time": 42.296770095825195
},
{
"timestamp": "2025-01-06T00:44:04.184606",
"cpu_percent": 88.27,
"ram_percent": 66.3,
"ram_used_gb": 42.16263961791992,
"gpu_memory_used": 5091.0,
"relative_time": 43.42832589149475
{
"timestamp": "2025-01-30T05:06:42.900882",
"cpu_percent": 14.19,
"ram_percent": 18.8,
"ram_used_gb": 5.337162017822266,
"gpu_memory_used": 2530.0,
"relative_time": 4.256956577301025
},
{
"timestamp": "2025-01-06T00:44:05.315967",
"cpu_percent": 80.91,
"ram_percent": 65.9,
"ram_used_gb": 41.9491081237793,
"gpu_memory_used": 5089.0,
"relative_time": 44.52496290206909
{
"timestamp": "2025-01-30T05:06:43.990792",
"cpu_percent": 12.63,
"ram_percent": 18.8,
"ram_used_gb": 5.333805084228516,
"gpu_memory_used": 3331.0,
"relative_time": 5.2854602336883545
},
{
"timestamp": "2025-01-06T00:44:06.412298",
"cpu_percent": 41.68,
"ram_percent": 65.6,
"ram_used_gb": 41.72716522216797,
"gpu_memory_used": 5090.0,
"relative_time": 45.679444313049316
{
"timestamp": "2025-01-30T05:06:45.019134",
"cpu_percent": 14.14,
"ram_percent": 18.8,
"ram_used_gb": 5.334297180175781,
"gpu_memory_used": 3332.0,
"relative_time": 6.351738929748535
},
{
"timestamp": "2025-01-06T00:44:07.566964",
"cpu_percent": 73.02,
"ram_percent": 65.5,
"ram_used_gb": 41.64710998535156,
"gpu_memory_used": 5091.0,
"relative_time": 46.81710481643677
{
"timestamp": "2025-01-30T05:06:46.085997",
"cpu_percent": 12.78,
"ram_percent": 18.8,
"ram_used_gb": 5.351467132568359,
"gpu_memory_used": 2596.0,
"relative_time": 7.392607688903809
},
{
"timestamp": "2025-01-06T00:44:08.704786",
"cpu_percent": 75.38,
"ram_percent": 65.4,
"ram_used_gb": 41.59475326538086,
"gpu_memory_used": 5097.0,
"relative_time": 47.91444158554077
{
"timestamp": "2025-01-30T05:06:47.127113",
"cpu_percent": 14.7,
"ram_percent": 18.9,
"ram_used_gb": 5.367542266845703,
"gpu_memory_used": 3341.0,
"relative_time": 8.441826343536377
},
{
"timestamp": "2025-01-06T00:44:09.802745",
"cpu_percent": 42.21,
"ram_percent": 65.2,
"ram_used_gb": 41.45526885986328,
"gpu_memory_used": 5111.0,
"relative_time": 49.04095649719238
{
"timestamp": "2025-01-30T05:06:48.176033",
"cpu_percent": 13.47,
"ram_percent": 18.9,
"ram_used_gb": 5.361263275146484,
"gpu_memory_used": 3339.0,
"relative_time": 9.500520706176758
},
{
"timestamp": "2025-01-06T00:44:10.928231",
"cpu_percent": 65.65,
"ram_percent": 64.4,
"ram_used_gb": 40.93437957763672,
"gpu_memory_used": 5111.0,
"relative_time": 50.14311861991882
{
"timestamp": "2025-01-30T05:06:49.234332",
"cpu_percent": 15.84,
"ram_percent": 18.9,
"ram_used_gb": 5.3612213134765625,
"gpu_memory_used": 3339.0,
"relative_time": 10.53744649887085
},
{
"timestamp": "2025-01-30T05:06:50.271159",
"cpu_percent": 14.89,
"ram_percent": 18.9,
"ram_used_gb": 5.379688262939453,
"gpu_memory_used": 3646.0,
"relative_time": 11.570110321044922
},
{
"timestamp": "2025-01-30T05:06:51.303841",
"cpu_percent": 15.71,
"ram_percent": 19.0,
"ram_used_gb": 5.390773773193359,
"gpu_memory_used": 3037.0,
"relative_time": 12.60651707649231
},
{
"timestamp": "2025-01-06T00:44:12.036249",
"cpu_percent": 28.51,
"ram_percent": 64.1,
"ram_used_gb": 40.749881744384766,
"gpu_memory_used": 5107.0,
"relative_time": 51.250269651412964
},
"timestamp": "2025-01-30T05:06:52.340383",
"cpu_percent": 15.46,
"ram_percent": 19.0,
"ram_used_gb": 5.389518737792969,
"gpu_memory_used": 3319.0,
"relative_time": 13.636165380477905
},
{
"timestamp": "2025-01-06T00:44:13.137586",
"cpu_percent": 52.99,
"ram_percent": 64.2,
"ram_used_gb": 40.84278869628906,
"gpu_memory_used": 5104.0,
"relative_time": 52.34805965423584
},
"timestamp": "2025-01-30T05:06:53.370342",
"cpu_percent": 13.12,
"ram_percent": 19.0,
"ram_used_gb": 5.391136169433594,
"gpu_memory_used": 3320.0,
"relative_time": 14.67578935623169
},
{
"timestamp": "2025-01-06T00:44:14.235248",
"cpu_percent": 34.55,
"ram_percent": 64.1,
"ram_used_gb": 40.7873420715332,
"gpu_memory_used": 5097.0,
"relative_time": 53.424301862716675
},
"timestamp": "2025-01-30T05:06:54.376175",
"cpu_percent": 14.98,
"ram_percent": 19.0,
"ram_used_gb": 5.390045166015625,
"gpu_memory_used": 3627.0,
"relative_time": 15.70747685432434
},
{
"timestamp": "2025-01-06T00:44:15.311386",
"cpu_percent": 39.07,
"ram_percent": 64.2,
"ram_used_gb": 40.860008239746094,
"gpu_memory_used": 5091.0,
"relative_time": 54.50679922103882
},
"timestamp": "2025-01-30T05:06:55.441172",
"cpu_percent": 13.45,
"ram_percent": 19.0,
"ram_used_gb": 5.394947052001953,
"gpu_memory_used": 1937.0,
"relative_time": 16.758784770965576
},
{
"timestamp": "2025-01-06T00:44:16.393626",
"cpu_percent": 31.02,
"ram_percent": 64.3,
"ram_used_gb": 40.884307861328125,
"gpu_memory_used": 5093.0,
"relative_time": 55.57431173324585
},
"timestamp": "2025-01-30T05:06:56.492442",
"cpu_percent": 17.03,
"ram_percent": 18.9,
"ram_used_gb": 5.361682891845703,
"gpu_memory_used": 3041.0,
"relative_time": 17.789713144302368
},
{
"timestamp": "2025-01-06T00:44:17.461449",
"cpu_percent": 24.53,
"ram_percent": 64.3,
"ram_used_gb": 40.89955520629883,
"gpu_memory_used": 5070.0,
"relative_time": 56.660638093948364
},
"timestamp": "2025-01-30T05:06:57.523536",
"cpu_percent": 13.76,
"ram_percent": 18.9,
"ram_used_gb": 5.360996246337891,
"gpu_memory_used": 3321.0,
"relative_time": 18.838542222976685
},
{
"timestamp": "2025-01-06T00:44:18.547558",
"cpu_percent": 19.93,
"ram_percent": 64.3,
"ram_used_gb": 40.92641830444336,
"gpu_memory_used": 5074.0,
"relative_time": 57.736456871032715
},
"timestamp": "2025-01-30T05:06:58.572158",
"cpu_percent": 15.94,
"ram_percent": 18.9,
"ram_used_gb": 5.3652801513671875,
"gpu_memory_used": 3323.0,
"relative_time": 19.86689043045044
},
{
"timestamp": "2025-01-06T00:44:19.624478",
"cpu_percent": 15.63,
"ram_percent": 64.3,
"ram_used_gb": 40.92564392089844,
"gpu_memory_used": 5082.0,
"relative_time": 58.81701683998108
},
"timestamp": "2025-01-30T05:06:59.600551",
"cpu_percent": 15.67,
"ram_percent": 18.9,
"ram_used_gb": 5.363399505615234,
"gpu_memory_used": 3630.0,
"relative_time": 20.89712619781494
},
{
"timestamp": "2025-01-06T00:44:20.705184",
"cpu_percent": 29.86,
"ram_percent": 64.4,
"ram_used_gb": 40.935394287109375,
"gpu_memory_used": 5082.0,
"relative_time": 59.88701677322388
},
"timestamp": "2025-01-30T05:07:00.631315",
"cpu_percent": 15.37,
"ram_percent": 18.9,
"ram_used_gb": 5.3663482666015625,
"gpu_memory_used": 3629.0,
"relative_time": 22.01374316215515
},
{
"timestamp": "2025-01-06T00:44:21.775463",
"cpu_percent": 43.55,
"ram_percent": 64.4,
"ram_used_gb": 40.9350471496582,
"gpu_memory_used": 5080.0,
"relative_time": 60.96005439758301
},
{
"timestamp": "2025-01-06T00:44:22.847939",
"cpu_percent": 26.66,
"ram_percent": 64.4,
"ram_used_gb": 40.94179916381836,
"gpu_memory_used": 5076.0,
"relative_time": 62.02673673629761
},
{
"timestamp": "2025-01-06T00:44:23.914337",
"cpu_percent": 22.46,
"ram_percent": 64.4,
"ram_used_gb": 40.9537467956543,
"gpu_memory_used": 5076.0,
"relative_time": 63.10581707954407
},
{
"timestamp": "2025-01-06T00:44:24.993313",
"cpu_percent": 28.07,
"ram_percent": 64.4,
"ram_used_gb": 40.94577407836914,
"gpu_memory_used": 5076.0,
"relative_time": 64.18998432159424
},
{
"timestamp": "2025-01-06T00:44:26.077028",
"cpu_percent": 26.1,
"ram_percent": 64.4,
"ram_used_gb": 40.98012161254883,
"gpu_memory_used": 5197.0,
"relative_time": 65.28782486915588
},
{
"timestamp": "2025-01-06T00:44:27.175228",
"cpu_percent": 35.17,
"ram_percent": 64.6,
"ram_used_gb": 41.0831184387207,
"gpu_memory_used": 5422.0,
"relative_time": 66.37566781044006
},
{
"timestamp": "2025-01-06T00:44:28.265025",
"cpu_percent": 55.14,
"ram_percent": 64.9,
"ram_used_gb": 41.25740432739258,
"gpu_memory_used": 5512.0,
"relative_time": 67.48023676872253
},
{
"timestamp": "2025-01-06T00:44:29.367776",
"cpu_percent": 53.84,
"ram_percent": 65.0,
"ram_used_gb": 41.36682891845703,
"gpu_memory_used": 5616.0,
"relative_time": 68.57096815109253
},
{
"timestamp": "2025-01-06T00:44:30.458301",
"cpu_percent": 33.42,
"ram_percent": 65.3,
"ram_used_gb": 41.5602912902832,
"gpu_memory_used": 5724.0,
"relative_time": 69.66709041595459
},
{
"timestamp": "2025-01-06T00:44:31.554329",
"cpu_percent": 50.81,
"ram_percent": 65.5,
"ram_used_gb": 41.66044616699219,
"gpu_memory_used": 5827.0,
"relative_time": 70.75874853134155
},
{
"timestamp": "2025-01-06T00:44:32.646414",
"cpu_percent": 34.34,
"ram_percent": 65.6,
"ram_used_gb": 41.739715576171875,
"gpu_memory_used": 5843.0,
"relative_time": 71.86718988418579
},
{
"timestamp": "2025-01-06T00:44:33.754223",
"cpu_percent": 44.32,
"ram_percent": 66.0,
"ram_used_gb": 42.005794525146484,
"gpu_memory_used": 5901.0,
"relative_time": 72.95793795585632
},
{
"timestamp": "2025-01-06T00:44:34.848852",
"cpu_percent": 48.36,
"ram_percent": 66.5,
"ram_used_gb": 42.3160514831543,
"gpu_memory_used": 5924.0,
"relative_time": 74.35109186172485
},
{
"timestamp": "2025-01-06T00:44:36.240235",
"cpu_percent": 58.06,
"ram_percent": 67.5,
"ram_used_gb": 42.95722198486328,
"gpu_memory_used": 5930.0,
"relative_time": 75.47581958770752
},
{
"timestamp": "2025-01-06T00:44:37.363208",
"cpu_percent": 46.82,
"ram_percent": 67.6,
"ram_used_gb": 42.97764587402344,
"gpu_memory_used": 6364.0,
"relative_time": 76.58708119392395
},
"timestamp": "2025-01-30T05:07:01.747500",
"cpu_percent": 13.79,
"ram_percent": 18.9,
"ram_used_gb": 5.367362976074219,
"gpu_memory_used": 3620.0,
"relative_time": 23.05113124847412
},
{
"timestamp": "2025-01-06T00:44:38.474408",
"cpu_percent": 50.93,
"ram_percent": 67.9,
"ram_used_gb": 43.1597900390625,
"gpu_memory_used": 6426.0,
"relative_time": 77.6842532157898
"timestamp": "2025-01-30T05:07:02.784828",
"cpu_percent": 10.16,
"ram_percent": 19.1,
"ram_used_gb": 5.443946838378906,
"gpu_memory_used": 1916.0,
"relative_time": 24.08937978744507
}
],
"test_duration": 82.49591493606567
"test_duration": 26.596059799194336
}

View file

@ -1,23 +1,23 @@
=== Benchmark Statistics (with correct RTF) ===
Total tokens processed: 3150
Total audio generated (s): 994.22
Total test duration (s): 73.81
Average processing rate (tokens/s): 49.36
Average RTF: 0.07
Average Real Time Speed: 15.00
Total audio generated (s): 895.98
Total test duration (s): 23.54
Average processing rate (tokens/s): 133.43
Average RTF: 0.03
Average Real Time Speed: 35.29
=== Per-chunk Stats ===
Average chunk size (tokens): 525.00
Min chunk size (tokens): 150
Max chunk size (tokens): 900
Average processing time (s): 12.12
Average output length (s): 165.70
Average processing time (s): 3.88
Average output length (s): 149.33
=== Performance Ranges ===
Processing rate range (tokens/s): 30.33 - 63.56
RTF range: 0.05x - 0.10x
Real Time Speed range: 10.00x - 20.00x
Processing rate range (tokens/s): 127.12 - 147.93
RTF range: 0.02x - 0.03x
Real Time Speed range: 33.33x - 50.00x

Six binary image files changed (not shown). Sizes before → after: 230 KiB → 230 KiB, 206 KiB → 260 KiB, 491 KiB → 392 KiB, 224 KiB → 235 KiB, 221 KiB → 265 KiB, 463 KiB → 429 KiB.
View file

@ -1,394 +1,138 @@
#!/usr/bin/env python3
import os
import argparse
from typing import Dict, List, Tuple, Optional
import time
import wave
from pathlib import Path
import numpy as np
import requests
import matplotlib.pyplot as plt
from scipy.io import wavfile
from openai import OpenAI
# Create output directory
output_dir = Path(__file__).parent / "output"
output_dir.mkdir(exist_ok=True)
def submit_combine_voices(
voices: List[str], base_url: str = "http://localhost:8880"
) -> Optional[str]:
"""Combine multiple voices into a new voice.
# Initialize OpenAI client
client = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed")
Args:
voices: List of voice names to combine (e.g. ["af_bella", "af_sarah"])
base_url: API base URL
# Test text that showcases voice characteristics
text = """The quick brown fox jumps over the lazy dog.
How vexingly quick daft zebras jump!
The five boxing wizards jump quickly."""
Returns:
Name of the combined voice (e.g. "af_bella_af_sarah") or None if error
"""
def generate_and_save_audio(voice: str, output_path: str):
"""Generate audio using specified voice and save to WAV file."""
print(f"\nGenerating audio for voice: {voice}")
start_time = time.time()
# Generate audio using streaming response
with client.audio.speech.with_streaming_response.create(
model="kokoro",
voice=voice,
response_format="wav",
input=text,
) as response:
# Save the audio stream to file
with open(output_path, "wb") as f:
for chunk in response.iter_bytes():
f.write(chunk)
duration = time.time() - start_time
print(f"Generated in {duration:.2f}s")
print(f"Saved to {output_path}")
return output_path
def analyze_audio(filepath: str):
"""Analyze audio file and return key characteristics."""
print(f"\nAnalyzing {filepath}")
try:
response = requests.post(f"{base_url}/v1/audio/voices/combine", json=voices)
print(f"Response status: {response.status_code}")
print(f"Raw response: {response.text}")
# Accept both 200 and 201 as success
if response.status_code not in [200, 201]:
try:
error = response.json()["detail"]["message"]
print(f"Error combining voices: {error}")
except:
print(f"Error combining voices: {response.text}")
return None
try:
data = response.json()
if "voices" in data:
print(f"Available voices: {', '.join(sorted(data['voices']))}")
return data["voice"]
except Exception as e:
print(f"Error parsing response: {e}")
return None
print(f"\nTrying to read {filepath}")
with wave.open(filepath, 'rb') as wf:
sample_rate = wf.getframerate()
samples = np.frombuffer(wf.readframes(wf.getnframes()), dtype=np.int16)
print(f"Successfully read file:")
print(f"Sample rate: {sample_rate}")
print(f"Samples shape: {samples.shape}")
print(f"Samples dtype: {samples.dtype}")
print(f"First few samples: {samples[:10]}")
except Exception as e:
print(f"Error: {e}")
return None
def generate_speech(
text: str,
voice: str,
base_url: str = "http://localhost:8880",
output_file: str = "output.mp3",
) -> bool:
"""Generate speech using specified voice.
Args:
text: Text to convert to speech
voice: Voice name to use
base_url: API base URL
output_file: Path to save audio file
Returns:
True if successful, False otherwise
"""
try:
response = requests.post(
f"{base_url}/v1/audio/speech",
json={
"input": text,
"voice": voice,
"speed": 1.0,
"response_format": "wav", # Use WAV for analysis
"stream": False,
},
)
if response.status_code != 200:
error = response.json().get("detail", {}).get("message", response.text)
print(f"Error generating speech: {error}")
return False
# Save the audio
os.makedirs(
os.path.dirname(output_file) if os.path.dirname(output_file) else ".",
exist_ok=True,
)
with open(output_file, "wb") as f:
f.write(response.content)
print(f"Saved audio to {output_file}")
return True
except Exception as e:
print(f"Error: {e}")
return False
def analyze_audio(filepath: str) -> Tuple[np.ndarray, int, dict]:
"""Analyze audio file and return samples, sample rate, and audio characteristics.
Args:
filepath: Path to audio file
Returns:
Tuple of (samples, sample_rate, characteristics)
"""
sample_rate, samples = wavfile.read(filepath)
print(f"Error reading file: {str(e)}")
raise
# Convert to float64 for calculations
samples = samples.astype(np.float64) / 32768.0 # Normalize 16-bit audio
# Convert to mono if stereo
if len(samples.shape) > 1:
samples = np.mean(samples, axis=1)
# Calculate basic stats
duration = len(samples) / sample_rate
max_amp = np.max(np.abs(samples))
rms = np.sqrt(np.mean(samples**2))
duration = len(samples) / sample_rate
# Zero crossing rate (helps identify voice characteristics)
zero_crossings = np.sum(np.abs(np.diff(np.signbit(samples)))) / len(samples)
# Simple frequency analysis
if len(samples) > 0:
# Use FFT to get frequency components
fft_result = np.fft.fft(samples)
freqs = np.fft.fftfreq(len(samples), 1 / sample_rate)
# Get positive frequencies only
pos_mask = freqs > 0
freqs = freqs[pos_mask]
magnitudes = np.abs(fft_result)[pos_mask]
# Find dominant frequencies (top 3)
top_indices = np.argsort(magnitudes)[-3:]
dominant_freqs = freqs[top_indices]
# Calculate spectral centroid (brightness of sound)
spectral_centroid = np.sum(freqs * magnitudes) / np.sum(magnitudes)
else:
dominant_freqs = []
spectral_centroid = 0
characteristics = {
"max_amplitude": max_amp,
"rms": rms,
"duration": duration,
"zero_crossing_rate": zero_crossings,
"dominant_frequencies": dominant_freqs,
"spectral_centroid": spectral_centroid,
# Calculate frequency characteristics
# Compute FFT
N = len(samples)
yf = np.fft.fft(samples)
xf = np.fft.fftfreq(N, 1 / sample_rate)[:N//2]
magnitude = 2.0/N * np.abs(yf[0:N//2])
# Calculate spectral centroid
spectral_centroid = np.sum(xf * magnitude) / np.sum(magnitude)
# Determine dominant frequencies
dominant_freqs = xf[magnitude.argsort()[-5:]][::-1].tolist()
return {
'samples': samples,
'sample_rate': sample_rate,
'duration': duration,
'max_amplitude': max_amp,
'rms': rms,
'spectral_centroid': spectral_centroid,
'dominant_frequencies': dominant_freqs
}
return samples, sample_rate, characteristics
def setup_plot(fig, ax, title):
"""Configure plot styling"""
# Improve grid
ax.grid(True, linestyle="--", alpha=0.3, color="#ffffff")
# Set title and labels with better fonts
ax.set_title(title, pad=20, fontsize=16, fontweight="bold", color="#ffffff")
ax.set_xlabel(ax.get_xlabel(), fontsize=14, fontweight="medium", color="#ffffff")
ax.set_ylabel(ax.get_ylabel(), fontsize=14, fontweight="medium", color="#ffffff")
# Improve tick labels
ax.tick_params(labelsize=12, colors="#ffffff")
# Style spines
for spine in ax.spines.values():
spine.set_color("#ffffff")
spine.set_alpha(0.3)
spine.set_linewidth(0.5)
# Set background colors
ax.set_facecolor("#1a1a2e")
fig.patch.set_facecolor("#1a1a2e")
return fig, ax
def plot_analysis(audio_files: Dict[str, str], output_dir: str):
"""Plot comprehensive voice analysis including waveforms and metrics comparison.
Args:
audio_files: Dictionary of label -> filepath
output_dir: Directory to save plot files
"""
# Set dark style
plt.style.use("dark_background")
# Create figure with subplots
fig = plt.figure(figsize=(15, 15))
fig.patch.set_facecolor("#1a1a2e")
num_files = len(audio_files)
# Create subplot grid with proper spacing for waveforms and metrics
total_rows = num_files + 2 # Add one more row for metrics
gs = plt.GridSpec(
total_rows, 2, height_ratios=[1.5] * num_files + [1, 1], hspace=0.4, wspace=0.3
)
# Analyze all files first
all_chars = {}
for i, (label, filepath) in enumerate(audio_files.items()):
samples, sample_rate, chars = analyze_audio(filepath)
all_chars[label] = chars
# Plot waveform spanning both columns
ax = plt.subplot(gs[i, :])
time = np.arange(len(samples)) / sample_rate
plt.plot(time, samples / chars["max_amplitude"], linewidth=0.5, color="#ff2a6d")
ax.set_xlabel("Time (seconds)")
ax.set_ylabel("Normalized Amplitude")
ax.set_ylim(-1.1, 1.1)
setup_plot(fig, ax, f"Waveform: {label}")
# Colors for voices
colors = ["#ff2a6d", "#05d9e8", "#d1f7ff"]
# Create metrics for each subplot
metrics = [
(
plt.subplot(gs[num_files, 0]),
[
(
"Volume",
[chars["rms"] * 100 for chars in all_chars.values()],
"RMS×100",
)
],
),
(
plt.subplot(gs[num_files, 1]),
[
(
"Brightness",
[chars["spectral_centroid"] / 1000 for chars in all_chars.values()],
"kHz",
)
],
),
(
plt.subplot(gs[num_files + 1, 0]),
[
(
"Voice Pitch",
[
min(chars["dominant_frequencies"])
for chars in all_chars.values()
],
"Hz",
)
],
),
(
plt.subplot(gs[num_files + 1, 1]),
[
(
"Texture",
[
chars["zero_crossing_rate"] * 1000
for chars in all_chars.values()
],
"ZCR×1000",
)
],
),
]
# Plot each metric
for i, (ax, metric_data) in enumerate(metrics):
n_voices = len(audio_files)
bar_width = 0.25
indices = np.array([0])
values = metric_data[0][1]
max_val = max(values)
for j, (voice, color) in enumerate(zip(audio_files.keys(), colors)):
offset = (j - n_voices / 2 + 0.5) * bar_width
bars = ax.bar(
indices + offset,
[values[j]],
bar_width,
label=voice,
color=color,
alpha=0.8,
)
# Add value labels on top of bars
for bar in bars:
height = bar.get_height()
ax.text(
bar.get_x() + bar.get_width() / 2.0,
height,
f"{height:.1f}",
ha="center",
va="bottom",
color="white",
fontsize=10,
)
ax.set_xticks(indices)
ax.set_xticklabels([f"{metric_data[0][0]}\n({metric_data[0][2]})"])
ax.set_ylim(0, max_val * 1.2)
ax.set_ylabel("Value")
# Only show legend on first metric plot
if i == 0:
ax.legend(
bbox_to_anchor=(1.05, 1),
loc="upper left",
facecolor="#1a1a2e",
edgecolor="#ffffff",
)
# Style the subplot
setup_plot(fig, ax, metric_data[0][0])
# Adjust the figure size and padding
fig.set_size_inches(15, 20)
plt.subplots_adjust(right=0.85, top=0.95, bottom=0.05, left=0.1)
plt.savefig(os.path.join(output_dir, "analysis_comparison.png"), dpi=300)
print(f"Saved analysis comparison to {output_dir}/analysis_comparison.png")
# Print detailed comparative analysis
print("\nDetailed Voice Analysis:")
for label, chars in all_chars.items():
print(f"\n{label}:")
print(f" Max Amplitude: {chars['max_amplitude']:.2f}")
print(f" RMS (loudness): {chars['rms']:.2f}")
print(f" Duration: {chars['duration']:.2f}s")
print(f" Zero Crossing Rate: {chars['zero_crossing_rate']:.3f}")
print(f" Spectral Centroid: {chars['spectral_centroid']:.0f}Hz")
print(
f" Dominant Frequencies: {', '.join(f'{f:.0f}Hz' for f in chars['dominant_frequencies'])}"
)
def plot_comparison(analyses, output_path):
"""Create comparison plot of the audio analyses."""
plt.style.use('dark_background')
fig = plt.figure(figsize=(15, 10))
fig.patch.set_facecolor('#1a1a2e')
# Plot waveforms
for i, (name, data) in enumerate(analyses.items()):
ax = plt.subplot(3, 1, i+1)
samples = data['samples']
time = np.arange(len(samples)) / data['sample_rate']
plt.plot(time, samples / data['max_amplitude'], linewidth=0.5, color='#ff2a6d')
plt.title(f"Waveform: {name}", color='white', pad=20)
plt.xlabel("Time (seconds)", color='white')
plt.ylabel("Normalized Amplitude", color='white')
plt.grid(True, alpha=0.3)
ax.set_facecolor('#1a1a2e')
plt.ylim(-1.1, 1.1)
plt.tight_layout()
plt.savefig(output_path, dpi=300, bbox_inches='tight')
print(f"\nSaved comparison plot to {output_path}")
def main():
parser = argparse.ArgumentParser(description="Kokoro Voice Analysis Demo")
parser.add_argument("--voices", nargs="+", type=str, help="Voices to combine")
parser.add_argument(
"--text",
type=str,
default="Hello! This is a test of combined voices.",
help="Text to speak",
)
parser.add_argument("--url", default="http://localhost:8880", help="API base URL")
parser.add_argument(
"--output-dir",
default="examples/assorted_checks/test_combinations/output",
help="Output directory for audio files",
)
args = parser.parse_args()
if not args.voices:
print("No voices provided, using default test voices")
args.voices = ["af_bella", "af_nicole"]
# Create output directory
os.makedirs(args.output_dir, exist_ok=True)
# Dictionary to store audio files for analysis
audio_files = {}
# Generate speech with individual voices
print("Generating speech with individual voices...")
for voice in args.voices:
output_file = os.path.join(args.output_dir, f"analysis_{voice}.wav")
if generate_speech(args.text, voice, args.url, output_file):
audio_files[voice] = output_file
# Generate speech with combined voice
print(f"\nCombining voices: {', '.join(args.voices)}")
combined_voice = submit_combine_voices(args.voices, args.url)
if combined_voice:
print(f"Successfully created combined voice: {combined_voice}")
output_file = os.path.join(
args.output_dir, f"analysis_combined_{combined_voice}.wav"
)
if generate_speech(args.text, combined_voice, args.url, output_file):
audio_files["combined"] = output_file
# Generate comparison plots
plot_analysis(audio_files, args.output_dir)
else:
print("Failed to combine voices")
# Generate audio for each voice
voices = {
'af_bella': output_dir / 'af_bella.wav',
'af_irulan': output_dir / 'af_irulan.wav',
'af_bella+af_irulan': output_dir / 'af_bella+af_irulan.wav'
}
for voice, path in voices.items():
generate_and_save_audio(voice, str(path))
# Analyze each audio file
analyses = {}
for name, path in voices.items():
analyses[name] = analyze_audio(str(path))
# Create comparison plot
plot_comparison(analyses, output_dir / 'voice_comparison.png')
if __name__ == "__main__":
main()

Some files were not shown because too many files have changed in this diff.