Merge pull request #102 from remsky/v0.1.4
V0.1.4: Improved web UI streaming headers
.github/workflows/ci.yml (vendored, 11 changes)
@@ -15,11 +15,16 @@ jobs:
    steps:
      - uses: actions/checkout@v4

      # Add FFmpeg installation step
      - name: Install FFmpeg
      # Match Dockerfile dependencies
      - name: Install Dependencies
        run: |
          sudo apt-get update
          sudo apt-get install -y ffmpeg
          sudo apt-get install -y --no-install-recommends \
            espeak-ng \
            git \
            libsndfile1 \
            curl \
            ffmpeg

      - name: Install uv
        uses: astral-sh/setup-uv@v5

.gitignore (vendored, 44 changes)
@@ -18,7 +18,8 @@ __pycache__/
*.egg
dist/
build/

*.onnx
*.pth
# Environment
# .env
.venv/

@@ -38,6 +39,26 @@ ENV/
*.pth
*.tar*


# Other project files
.env
Kokoro-82M/
ui/data/
EXTERNAL_UV_DOCUMENTATION*
app
api/temp_files/

# Docker
Dockerfile*
docker-compose*
examples/ebook_test/chapter_to_audio.py
examples/ebook_test/chapters_to_audio.py
examples/ebook_test/parse_epub.py
api/src/voices/af_jadzia.pt
examples/assorted_checks/test_combinations/output/*
examples/assorted_checks/test_openai/output/*


# Audio files
examples/*.wav
examples/*.pcm

@@ -46,22 +67,5 @@ examples/*.flac
examples/*.acc
examples/*.ogg
examples/speech.mp3
examples/phoneme_examples/output/example_1.wav
examples/phoneme_examples/output/example_2.wav
examples/phoneme_examples/output/example_3.wav

# Other project files
Kokoro-82M/
ui/data/
EXTERNAL_UV_DOCUMENTATION*
app

# Docker
Dockerfile*
docker-compose*
examples/assorted_checks/River_of_Teet_-_Sarah_Gailey.epub
examples/ebook_test/chapter_to_audio.py
examples/ebook_test/chapters_to_audio.py
examples/ebook_test/parse_epub.py
examples/ebook_test/River_of_Teet_-_Sarah_Gailey.epub
examples/ebook_test/River_of_Teet_-_Sarah_Gailey.txt
examples/phoneme_examples/output/*.wav
examples/assorted_checks/benchmarks/output_audio/*

CHANGELOG.md (44 changes)
@@ -2,6 +2,50 @@
Notable changes to this project will be documented in this file.

## [v0.1.4] - 2025-01-30
### Added
- Smart Chunking System (a minimal sketch follows this list):
  - New text_processor with smart_split for improved sentence boundary detection
  - Dynamically adjusts chunk sizes based on sentence structure, using phoneme/token information in an initial pass
  - Should avoid ever exceeding the 510-token limit per chunk, while preserving natural cadence
- Web UI added (to replace Gradio):
  - Integrated streaming with tempfile generation
  - Download links available in the X-Download-Path header
  - Configurable cleanup triggers for temp files
- Debug Endpoints:
  - /debug/threads for thread information and stack traces
  - /debug/storage for temp file and output directory monitoring
  - /debug/system for system resource information
  - /debug/session_pools for ONNX/CUDA session status
- Automated Model Management:
  - Auto-download from releases page
  - Included download scripts for manual installation
  - Pre-packaged voice models in repository
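For illustration only, here is a minimal sketch of the kind of token-budgeted sentence packing described in the Smart Chunking bullets above. It is not the project's actual `smart_split` implementation; the 510-token budget is taken from the bullet, and the token-counting callback is a placeholder assumption.

```python
import re
from typing import Callable, Iterator

def naive_smart_split(
    text: str,
    count_tokens: Callable[[str], int],  # placeholder for a phoneme/token length function
    max_tokens: int = 510,               # budget taken from the changelog entry above
) -> Iterator[str]:
    """Pack whole sentences into chunks that stay under max_tokens."""
    sentences = re.split(r"(?<=[.!?])\s+", text.strip())
    chunk, chunk_tokens = [], 0
    for sentence in sentences:
        n = count_tokens(sentence)
        # Start a new chunk if adding this sentence would exceed the budget
        if chunk and chunk_tokens + n > max_tokens:
            yield " ".join(chunk)
            chunk, chunk_tokens = [], 0
        chunk.append(sentence)
        chunk_tokens += n
    if chunk:
        yield " ".join(chunk)

# Example with a crude whitespace "tokenizer" standing in for real token counts
for piece in naive_smart_split("First sentence. Second sentence! Third?", lambda s: len(s.split())):
    print(piece)
```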
### Changed
- Significant architectural improvements:
  - Multi-model architecture support
  - Enhanced concurrency handling
  - Improved streaming header management
  - Better resource/session pool management


## [v0.1.2] - 2025-01-23
### Structural Improvements
- Models can be manually downloaded and placed in api/src/models, or fetched with the included script
- TTSGPU/TPSCPU/STTSService classes replaced with a ModelManager service
  - CPU/GPU support for each of ONNX and PyTorch (Note: only PyTorch GPU and ONNX CPU/GPU have been tested)
  - Should make it easier to add new models or architectures as they become available, in a more modular way
- Converted a number of internal processes to async handling to improve concurrency
- Improving separation of concerns towards a plug-in, modular structure, making PRs and new features easier


### Web UI (test release)
- A simple integrated web UI has been added directly on the FastAPI server
  - It can be disabled via core/config.py or ENV variables if desired
  - Simplifies deployments, utility testing, aesthetics, etc.
- Looking to deprecate/collaborate/hand off the Gradio UI


## [v0.1.0] - 2025-01-13
### Changed
- Major Docker improvements:

@@ -1 +0,0 @@
Subproject commit c97b7bbc3e60f447383c79b2f94fee861ff156ac

@@ -1,70 +0,0 @@
# UV Setup
Deprecated notes for myself

## Structure
```
docker/
├── cpu/
│   ├── pyproject.toml     # CPU deps (torch CPU)
│   └── requirements.lock  # CPU lockfile
├── gpu/
│   ├── pyproject.toml     # GPU deps (torch CUDA)
│   └── requirements.lock  # GPU lockfile
└── shared/
    └── pyproject.toml     # Common deps
```

## Regenerate Lock Files

### CPU
```bash
cd docker/cpu
uv pip compile pyproject.toml ../shared/pyproject.toml --output-file requirements.lock
```

### GPU
```bash
cd docker/gpu
uv pip compile pyproject.toml ../shared/pyproject.toml --output-file requirements.lock
```

## Local Dev Setup

### CPU
```bash
cd docker/cpu
uv venv
.venv\Scripts\activate  # Windows
uv pip sync requirements.lock
```

### GPU
```bash
cd docker/gpu
uv venv
.venv\Scripts\activate  # Windows
uv pip sync requirements.lock --extra-index-url https://download.pytorch.org/whl/cu121 --index-strategy unsafe-best-match
```

### Run Server
```bash
# From project root with venv active:
uvicorn api.src.main:app --reload
```

## Docker

### CPU
```bash
cd docker/cpu
docker compose up
```

### GPU
```bash
cd docker/gpu
docker compose up
```

## Known Issues
- Module imports: Run server from project root
- PyTorch CUDA: Always use --extra-index-url and --index-strategy for GPU env
README.md (244 changes)
@@ -3,74 +3,83 @@
</p>
|
||||
|
||||
# <sub><sub>_`FastKoko`_ </sub></sub>
|
||||
[](https://huggingface.co/hexgrad/Kokoro-82M/tree/c3b0d86e2a980e027ef71c28819ea02e351c2667) [](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero)
|
||||
|
||||
> Support for Kokoro-82M v1.0 coming very soon!
|
||||
|
||||
Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model
|
||||
- OpenAI-compatible Speech endpoint, with inline voice combination functionality
|
||||
- NVIDIA GPU accelerated or CPU Onnx inference
|
||||
- OpenAI-compatible Speech endpoint, with inline voice combination, and mapped naming/models for strict systems
|
||||
- NVIDIA GPU accelerated or CPU inference (ONNX, Pytorch)
|
||||
- very fast generation time
|
||||
- 35x-100x+ real time speed via 4060Ti+
|
||||
- 5x+ real time speed via M3 Pro CPU
|
||||
- streaming support w/ variable chunking to control latency & artifacts
|
||||
- phoneme, simple audio generation web ui utility
|
||||
- Runs on an 80mb-300mb model (CUDA container + 5gb on disk due to drivers)
|
||||
- ~35x-100x+ real time speed via 4060Ti+
|
||||
- ~5x+ real time speed via M3 Pro CPU
|
||||
- streaming support & tempfile generation
|
||||
- phoneme based dev endpoints
|
||||
- (new) Integrated web UI on localhost:8880/web
|
||||
- (new) Debug endpoints for monitoring threads, storage, and session pools
|
||||
|
||||
> [!Tip]
|
||||
> You can try the new beta version from the `v0.1.2-pre` branch now:
|
||||
<table>
|
||||
<tr>
|
||||
<td>
|
||||
<img src="https://github.com/user-attachments/assets/440162eb-1918-4999-ab2b-e2730990efd0" width="100%" alt="Voice Analysis Comparison" style="border: 2px solid #333; padding: 5px;">
|
||||
</td>
|
||||
<td>
|
||||
<ul>
|
||||
<li>Integrated web UI (on localhost:8880/web)</li>
|
||||
<li>Better concurrency handling, baked in models and voices</li>
|
||||
<li>Voice name/model mappings to OAI standard</li>
|
||||
<pre> # with:
|
||||
docker run -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-cpu:latest # CPU
|
||||
docker run --gpus all -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-gpu:latest # Nvidia GPU
|
||||
</pre>
|
||||
</ul>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
<details open>
|
||||
<summary>Quick Start</summary>
|
||||
## Get Started
|
||||
|
||||
The service can be accessed through either the API endpoints or the Gradio web interface.
|
||||
<details >
|
||||
<summary>Quickest Start (docker run)</summary>
|
||||
|
||||
|
||||
Pre-built images are available to run, with ARM/multi-arch support and baked-in models
|
||||
Refer to the core/config.py file for a full list of variables which can be managed via the environment
|
||||
|
||||
```bash
|
||||
|
||||
docker run -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-cpu:v0.1.4 # CPU, or:
|
||||
docker run --gpus all -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-gpu:v0.1.4 #NVIDIA GPU
|
||||
```
|
||||
|
||||
Once running, access:
- API Documentation: http://localhost:8880/docs
- Web Interface: http://localhost:8880/web

<div align="center" style="display: flex; justify-content: center; gap: 20px;">
  <img src="assets/docs-screenshot.png" width="48%" alt="API Documentation" style="border: 2px solid #333; padding: 10px;">
  <img src="assets/webui-screenshot.png" width="48%" alt="Web UI Screenshot" style="border: 2px solid #333; padding: 10px;">
</div>
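The changelog for this release notes that download links for generated audio are exposed via an `X-Download-Path` response header. A hedged sketch of reading it with `requests`, assuming default settings and that the speech endpoint sets this header when temp-file generation is enabled:

```python
import requests

response = requests.post(
    "http://localhost:8880/v1/audio/speech",
    json={"model": "kokoro", "voice": "af_bella", "input": "Hello world!", "response_format": "mp3"},
)
response.raise_for_status()

# Header presence depends on the temp-file settings in core/config.py
download_path = response.headers.get("X-Download-Path")
print("Download path:", download_path)

with open("output.mp3", "wb") as f:
    f.write(response.content)
```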
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Quick Start (docker compose) </summary>
|
||||
|
||||
1. Install prerequisites, and start the service using Docker Compose (Full setup including UI):
|
||||
- Install [Docker Desktop](https://www.docker.com/products/docker-desktop/)
|
||||
- Install [Docker](https://www.docker.com/products/docker-desktop/)
|
||||
-
|
||||
- Clone the repository:
|
||||
```bash
|
||||
git clone https://github.com/remsky/Kokoro-FastAPI.git
|
||||
cd Kokoro-FastAPI
|
||||
|
||||
# * Switch to stable branch if any issues *
|
||||
git checkout v0.0.5post1-stable
|
||||
|
||||
cd docker/gpu # OR
|
||||
cd docker/gpu # OR
|
||||
# cd docker/cpu # Run this or the above
|
||||
docker compose up --build
|
||||
docker compose up --build
|
||||
# if you are missing any models, run:
|
||||
# python ../scripts/download_model.py --type pth # for GPU
|
||||
# python ../scripts/download_model.py --type onnx # for CPU
|
||||
```
|
||||
|
||||
|
||||
```bash
|
||||
Or directly via UV
|
||||
./start-cpu.sh
|
||||
./start-gpu.sh
|
||||
```
|
||||
|
||||
Once started:
|
||||
- The API will be available at http://localhost:8880
|
||||
- The UI can be accessed at http://localhost:7860
|
||||
|
||||
__Or__ running the API alone using Docker (model + voice packs baked in) (Most Recent):
|
||||
|
||||
```bash
|
||||
docker run -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-cpu:v0.1.0post1 # CPU
|
||||
docker run --gpus all -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-gpu:v0.1.0post1 # Nvidia GPU
|
||||
```
|
||||
|
||||
|
||||
4. Run locally as an OpenAI-Compatible Speech Endpoint
|
||||
- The *Web UI* can be tested at http://localhost:8880/web
|
||||
- The Gradio UI (deprecating) can be accessed at http://localhost:7860
|
||||
|
||||
2. Run locally as an OpenAI-Compatible Speech Endpoint
|
||||
```python
|
||||
from openai import OpenAI
|
||||
client = OpenAI(
|
||||
|
@@ -87,12 +96,49 @@ The service can be accessed through either the API endpoints or the Gradio web i
|
|||
response.stream_to_file("output.mp3")
|
||||
|
||||
```
|
||||
</details>
|
||||
<summary>Direct Run (via uv) </summary>
|
||||
|
||||
or visit http://localhost:7860
|
||||
<p align="center">
|
||||
<img src="ui\GradioScreenShot.png" width="80%" alt="Voice Analysis Comparison" style="border: 2px solid #333; padding: 10px;">
|
||||
</p>
|
||||
1. Install prerequisites ():
|
||||
- Install [astral-uv](https://docs.astral.sh/uv/)
|
||||
- Clone the repository:
|
||||
```bash
|
||||
git clone https://github.com/remsky/Kokoro-FastAPI.git
|
||||
cd Kokoro-FastAPI
|
||||
|
||||
# if you are missing any models, run:
|
||||
# python ../scripts/download_model.py --type pth # for GPU
|
||||
# python ../scripts/download_model.py --type onnx # for CPU
|
||||
```
|
||||
|
||||
Start directly via UV (with hot-reload)
|
||||
```bash
|
||||
./start-cpu.sh OR
|
||||
./start-gpu.sh
|
||||
```
|
||||
|
||||
Once started:
|
||||
- The API will be available at http://localhost:8880
|
||||
- The *Web UI* can be tested at http://localhost:8880/web
|
||||
- The Gradio UI (deprecating) can be accessed at http://localhost:7860
|
||||
|
||||
2. Run locally as an OpenAI-Compatible Speech Endpoint
|
||||
```python
|
||||
from openai import OpenAI
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:8880/v1",
|
||||
api_key="not-needed"
|
||||
)
|
||||
|
||||
with client.audio.speech.with_streaming_response.create(
|
||||
model="kokoro",
|
||||
voice="af_sky+af_bella", #single or multiple voicepack combo
|
||||
input="Hello world!",
|
||||
response_format="mp3"
|
||||
) as response:
|
||||
response.stream_to_file("output.mp3")
|
||||
|
||||
```
|
||||
</details>
|
||||
|
||||
## Features
|
||||
|
@@ -104,8 +150,8 @@ The service can be accessed through either the API endpoints or the Gradio web i
|
|||
from openai import OpenAI
|
||||
client = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed")
|
||||
response = client.audio.speech.create(
|
||||
model="kokoro", # Not used but required for compatibility, also accepts library defaults
|
||||
voice="af_bella+af_sky",
|
||||
model="kokoro",
|
||||
voice="af_bella+af_sky", # see /api/src/core/openai_mappings.json to customize
|
||||
input="Hello world!",
|
||||
response_format="mp3"
|
||||
)
|
||||
|
@@ -124,7 +170,7 @@ voices = response.json()["voices"]
|
|||
response = requests.post(
|
||||
"http://localhost:8880/v1/audio/speech",
|
||||
json={
|
||||
"model": "kokoro", # Not used but required for compatibility
|
||||
"model": "kokoro",
|
||||
"input": "Hello world!",
|
||||
"voice": "af_bella",
|
||||
"response_format": "mp3", # Supported: mp3, wav, opus, flac
|
||||
|
@@ -207,13 +253,12 @@ If you only want the API, just comment out everything in the docker-compose.yml
|
|||
|
||||
Currently, voices created via the API are accessible here, but voice combination/creation has not yet been added
|
||||
|
||||
Running the UI Docker Service
|
||||
Running the UI Docker Service [deprecating]
|
||||
- If you only want to run the Gradio web interface separately and connect it to an existing API service:
|
||||
```bash
|
||||
docker run -p 7860:7860 \
|
||||
-e API_HOST=<api-hostname-or-ip> \
|
||||
-e API_PORT=8880 \
|
||||
ghcr.io/remsky/kokoro-fastapi-ui:v0.1.0
|
||||
```
|
||||
|
||||
- Replace `<api-hostname-or-ip>` with:
|
||||
|
@@ -232,7 +277,7 @@ environment:
|
|||
|
||||
When running the Docker image directly:
|
||||
```bash
|
||||
docker run -p 7860:7860 -e DISABLE_LOCAL_SAVING=true ghcr.io/remsky/kokoro-fastapi-ui:latest
|
||||
docker run -p 7860:7860 -e DISABLE_LOCAL_SAVING=true ghcr.io/remsky/kokoro-fastapi-ui:v0.1.4
|
||||
```
|
||||
</details>
|
||||
|
||||
|
@@ -243,7 +288,7 @@ docker run -p 7860:7860 -e DISABLE_LOCAL_SAVING=true ghcr.io/remsky/kokoro-fasta
|
|||
# OpenAI-compatible streaming
|
||||
from openai import OpenAI
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:8880", api_key="not-needed")
|
||||
base_url="http://localhost:8880/v1", api_key="not-needed")
|
||||
|
||||
# Stream to file
|
||||
with client.audio.speech.with_streaming_response.create(
|
||||
|
@@ -325,17 +370,17 @@ Benchmarking was performed on generation via the local API using text lengths up
|
|||
</p>
|
||||
|
||||
Key Performance Metrics:
|
||||
- Realtime Speed: Ranges between 25-50x (generation time to output audio length)
|
||||
- Realtime Speed: Ranges between 35x-100x (generation time to output audio length)
|
||||
- Average Processing Rate: 137.67 tokens/second (cl100k_base)
|
||||
</details>
|
||||
<details>
|
||||
<summary>GPU Vs. CPU</summary>
|
||||
|
||||
```bash
|
||||
# GPU: Requires NVIDIA GPU with CUDA 12.1 support (~35x realtime speed)
|
||||
# GPU: Requires NVIDIA GPU with CUDA 12.1 support (~35x-100x realtime speed)
|
||||
docker compose up --build
|
||||
|
||||
# CPU: ONNX optimized inference (~2.4x realtime speed)
|
||||
# CPU: ONNX optimized inference (~5x+ realtime speed on M3 Pro)
|
||||
docker compose -f docker-compose.cpu.yml up --build
|
||||
```
|
||||
*Note: Overall speed may have been reduced somewhat with the structural changes to accommodate streaming. Looking into it.*
|
||||
|
@@ -355,36 +400,61 @@ Convert text to phonemes and/or generate audio directly from phonemes:
|
|||
```python
|
||||
import requests
|
||||
|
||||
# Convert text to phonemes
|
||||
response = requests.post(
|
||||
"http://localhost:8880/dev/phonemize",
|
||||
json={
|
||||
"text": "Hello world!",
|
||||
"language": "a" # "a" for American English
|
||||
}
|
||||
)
|
||||
result = response.json()
|
||||
phonemes = result["phonemes"] # Phoneme string e.g ðɪs ɪz ˈoʊnli ɐ tˈɛst
|
||||
tokens = result["tokens"] # Token IDs including start/end tokens
|
||||
def get_phonemes(text: str, language: str = "a"):
|
||||
"""Get phonemes and tokens for input text"""
|
||||
response = requests.post(
|
||||
"http://localhost:8880/dev/phonemize",
|
||||
json={"text": text, "language": language} # "a" for American English
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
return result["phonemes"], result["tokens"]
|
||||
|
||||
# Generate audio from phonemes
|
||||
response = requests.post(
|
||||
"http://localhost:8880/dev/generate_from_phonemes",
|
||||
json={
|
||||
"phonemes": phonemes,
|
||||
"voice": "af_bella",
|
||||
"speed": 1.0
|
||||
}
|
||||
)
|
||||
def generate_audio_from_phonemes(phonemes: str, voice: str = "af_bella"):
|
||||
"""Generate audio from phonemes"""
|
||||
response = requests.post(
|
||||
"http://localhost:8880/dev/generate_from_phonemes",
|
||||
json={"phonemes": phonemes, "voice": voice},
|
||||
headers={"Accept": "audio/wav"}
|
||||
)
|
||||
if response.status_code != 200:
|
||||
print(f"Error: {response.text}")
|
||||
return None
|
||||
return response.content
|
||||
|
||||
# Save WAV audio
|
||||
with open("speech.wav", "wb") as f:
|
||||
f.write(response.content)
|
||||
# Example usage
|
||||
text = "Hello world!"
|
||||
try:
|
||||
# Convert text to phonemes
|
||||
phonemes, tokens = get_phonemes(text)
|
||||
print(f"Phonemes: {phonemes}") # e.g. ðɪs ɪz ˈoʊnli ɐ tˈɛst
|
||||
print(f"Tokens: {tokens}") # Token IDs including start/end tokens
|
||||
|
||||
# Generate and save audio
|
||||
if audio_bytes := generate_audio_from_phonemes(phonemes):
|
||||
with open("speech.wav", "wb") as f:
|
||||
f.write(audio_bytes)
|
||||
print(f"Generated {len(audio_bytes)} bytes of audio")
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
```
|
||||
|
||||
See `examples/phoneme_examples/generate_phonemes.py` for a sample script.
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>Debug Endpoints</summary>
|
||||
|
||||
Monitor system state and resource usage with these endpoints:

- `/debug/threads` - Get thread information and stack traces
- `/debug/storage` - Monitor temp file and output directory usage
- `/debug/system` - Get system information (CPU, memory, GPU)
- `/debug/session_pools` - View ONNX session and CUDA stream status

Useful for debugging resource exhaustion or performance issues.
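A minimal sketch of polling these endpoints, assuming the server is running locally on the default port:

```python
import requests

BASE = "http://localhost:8880"

for endpoint in ("/debug/threads", "/debug/storage", "/debug/system", "/debug/session_pools"):
    r = requests.get(f"{BASE}{endpoint}", timeout=5)
    print(endpoint, r.status_code)
    if r.ok:
        print(r.json())  # response shape depends on the endpoint
```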
|
||||
</details>
|
||||
|
||||
## Known Issues
|
||||
|
||||
<details>
|
||||
|
|
|
@@ -337,11 +337,13 @@ def recursive_munch(d):
|
|||
else:
|
||||
return d
|
||||
|
||||
def build_model(path, device):
|
||||
async def build_model(path, device):
|
||||
from ..core.paths import load_json, load_model_weights
|
||||
|
||||
config = Path(__file__).parent / 'config.json'
|
||||
assert config.exists(), f'Config path incorrect: config.json not found at {config}'
|
||||
with open(config, 'r') as r:
|
||||
args = recursive_munch(json.load(r))
|
||||
|
||||
args = recursive_munch(await load_json(config))
|
||||
assert args.decoder.type == 'istftnet', f'Unknown decoder type: {args.decoder.type}'
|
||||
decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
|
||||
resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
|
||||
|
@@ -365,7 +367,8 @@ def build_model(path, device):
|
|||
decoder=decoder.to(device).eval(),
|
||||
text_encoder=text_encoder.to(device).eval(),
|
||||
)
|
||||
for key, state_dict in torch.load(path, map_location='cpu', weights_only=True)['net'].items():
|
||||
weights = await load_model_weights(path, device=device)
|
||||
for key, state_dict in weights['net'].items():
|
||||
assert key in model, key
|
||||
try:
|
||||
model[key].load_state_dict(state_dict)
|
||||
|
|
|
@@ -9,25 +9,34 @@ class Settings(BaseSettings):
|
|||
host: str = "0.0.0.0"
|
||||
port: int = 8880
|
||||
|
||||
# TTS Settings
|
||||
# Application Settings
|
||||
output_dir: str = "output"
|
||||
output_dir_size_limit_mb: float = 500.0 # Maximum size of output directory in MB
|
||||
default_voice: str = "af"
|
||||
model_dir: str = "/app/models" # Base directory for model files
|
||||
pytorch_model_path: str = "kokoro-v0_19.pth"
|
||||
onnx_model_path: str = "kokoro-v0_19.onnx"
|
||||
voices_dir: str = "voices"
|
||||
use_gpu: bool = True # Whether to use GPU acceleration if available
|
||||
use_onnx: bool = False # Whether to use ONNX runtime
|
||||
allow_local_voice_saving: bool = False # Whether to allow saving combined voices locally
|
||||
|
||||
# Container absolute paths
|
||||
model_dir: str = "/app/api/src/models" # Absolute path in container
|
||||
voices_dir: str = "/app/api/src/voices" # Absolute path in container
|
||||
|
||||
# Audio Settings
|
||||
sample_rate: int = 24000
|
||||
max_chunk_size: int = 300 # Maximum size of text chunks for processing
|
||||
max_chunk_size: int = 400 # Maximum size of text chunks for processing
|
||||
gap_trim_ms: int = 250 # Amount to trim from streaming chunk ends in milliseconds
|
||||
|
||||
# ONNX Optimization Settings
|
||||
onnx_num_threads: int = 4 # Number of threads for intra-op parallelism
|
||||
onnx_inter_op_threads: int = 4 # Number of threads for inter-op parallelism
|
||||
onnx_execution_mode: str = "parallel" # parallel or sequential
|
||||
onnx_optimization_level: str = "all" # all, basic, or disabled
|
||||
onnx_memory_pattern: bool = True # Enable memory pattern optimization
|
||||
onnx_arena_extend_strategy: str = "kNextPowerOfTwo" # Memory allocation strategy
|
||||
# Web Player Settings
|
||||
enable_web_player: bool = True # Whether to serve the web player UI
|
||||
web_player_path: str = "web" # Path to web player static files
|
||||
cors_origins: list[str] = ["*"] # CORS origins for web player
|
||||
cors_enabled: bool = True # Whether to enable CORS
|
||||
|
||||
# Temp File Settings for WEB Ui
|
||||
temp_file_dir: str = "api/temp_files" # Directory for temporary audio files (relative to project root)
|
||||
max_temp_dir_size_mb: int = 2048 # Maximum size of temp directory (2GB)
|
||||
max_temp_dir_age_hours: int = 1 # Remove temp files older than 1 hour
|
||||
max_temp_dir_count: int = 3 # Maximum number of temp files to keep
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
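Because `Settings` is a pydantic `BaseSettings` class with `env_file = ".env"`, any field above can be overridden through environment variables or a `.env` file. A small sketch; the import path is an assumption based on this PR's layout:

```python
import os

# BaseSettings matches environment variables to field names (case-insensitive by default)
os.environ["USE_ONNX"] = "true"
os.environ["MAX_CHUNK_SIZE"] = "300"

from api.src.core.config import settings  # assumed module path; import after setting the environment

print(settings.use_onnx, settings.max_chunk_size)
```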
|
||||
|
|
api/src/core/model_config.py (new file, 113 lines)
@@ -0,0 +1,113 @@
"""Model configuration schemas."""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ONNXCPUConfig(BaseModel):
|
||||
"""ONNX CPU runtime configuration."""
|
||||
|
||||
# Session pooling
|
||||
max_instances: int = Field(4, description="Maximum concurrent model instances")
|
||||
instance_timeout: int = Field(60, description="Session timeout in seconds")
|
||||
|
||||
# Runtime settings
|
||||
num_threads: int = Field(8, description="Number of threads for parallel operations")
|
||||
inter_op_threads: int = Field(4, description="Number of threads for operator parallelism")
|
||||
execution_mode: str = Field("parallel", description="ONNX execution mode")
|
||||
optimization_level: str = Field("all", description="ONNX optimization level")
|
||||
memory_pattern: bool = Field(True, description="Enable memory pattern optimization")
|
||||
arena_extend_strategy: str = Field("kNextPowerOfTwo", description="Memory arena strategy")
|
||||
|
||||
class Config:
|
||||
frozen = True
|
||||
|
||||
|
||||
class ONNXGPUConfig(ONNXCPUConfig):
|
||||
"""ONNX GPU-specific configuration."""
|
||||
|
||||
# CUDA settings
|
||||
device_id: int = Field(0, description="CUDA device ID")
|
||||
gpu_mem_limit: float = Field(0.5, description="Fraction of GPU memory to use")
|
||||
cudnn_conv_algo_search: str = Field("EXHAUSTIVE", description="CuDNN convolution algorithm search")
|
||||
|
||||
# Stream management
|
||||
cuda_streams: int = Field(2, description="Number of CUDA streams for inference")
|
||||
stream_timeout: int = Field(60, description="Stream timeout in seconds")
|
||||
do_copy_in_default_stream: bool = Field(True, description="Copy in default CUDA stream")
|
||||
|
||||
class Config:
|
||||
frozen = True
|
||||
|
||||
|
||||
class PyTorchCPUConfig(BaseModel):
|
||||
"""PyTorch CPU backend configuration."""
|
||||
|
||||
memory_threshold: float = Field(0.8, description="Memory threshold for cleanup")
|
||||
retry_on_oom: bool = Field(True, description="Whether to retry on OOM errors")
|
||||
num_threads: int = Field(8, description="Number of threads for parallel operations")
|
||||
pin_memory: bool = Field(True, description="Whether to pin memory for faster CPU-GPU transfer")
|
||||
|
||||
class Config:
|
||||
frozen = True
|
||||
|
||||
|
||||
class PyTorchGPUConfig(BaseModel):
|
||||
"""PyTorch GPU backend configuration."""
|
||||
|
||||
device_id: int = Field(0, description="CUDA device ID")
|
||||
use_triton: bool = Field(True, description="Whether to use Triton for CUDA kernels")
|
||||
memory_threshold: float = Field(0.8, description="Memory threshold for cleanup")
|
||||
retry_on_oom: bool = Field(True, description="Whether to retry on OOM errors")
|
||||
sync_cuda: bool = Field(True, description="Whether to synchronize CUDA operations")
|
||||
cuda_streams: int = Field(2, description="Number of CUDA streams for inference")
|
||||
stream_timeout: int = Field(60, description="Stream timeout in seconds")
|
||||
|
||||
class Config:
|
||||
frozen = True
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
"""Model configuration."""
|
||||
|
||||
# General settings
|
||||
model_type: str = Field("pytorch", description="Model type ('pytorch' or 'onnx')")
|
||||
device_type: str = Field("auto", description="Device type ('cpu', 'gpu', or 'auto')")
|
||||
cache_models: bool = Field(True, description="Whether to cache loaded models")
|
||||
cache_voices: bool = Field(True, description="Whether to cache voice tensors")
|
||||
voice_cache_size: int = Field(2, description="Maximum number of cached voices")
|
||||
|
||||
# Model filenames
|
||||
pytorch_model_file: str = Field("kokoro-v0_19-half.pth", description="PyTorch model filename")
|
||||
onnx_model_file: str = Field("kokoro-v0_19.onnx", description="ONNX model filename")
|
||||
|
||||
# Backend-specific configs
|
||||
onnx_cpu: ONNXCPUConfig = Field(default_factory=ONNXCPUConfig)
|
||||
onnx_gpu: ONNXGPUConfig = Field(default_factory=ONNXGPUConfig)
|
||||
pytorch_cpu: PyTorchCPUConfig = Field(default_factory=PyTorchCPUConfig)
|
||||
pytorch_gpu: PyTorchGPUConfig = Field(default_factory=PyTorchGPUConfig)
|
||||
|
||||
class Config:
|
||||
frozen = True
|
||||
|
||||
def get_backend_config(self, backend_type: str):
|
||||
"""Get configuration for specific backend.
|
||||
|
||||
Args:
|
||||
backend_type: Backend type ('pytorch_cpu', 'pytorch_gpu', 'onnx_cpu', 'onnx_gpu')
|
||||
|
||||
Returns:
|
||||
Backend-specific configuration
|
||||
|
||||
Raises:
|
||||
ValueError: If backend type is invalid
|
||||
"""
|
||||
if backend_type not in {
|
||||
'pytorch_cpu', 'pytorch_gpu', 'onnx_cpu', 'onnx_gpu'
|
||||
}:
|
||||
raise ValueError(f"Invalid backend type: {backend_type}")
|
||||
|
||||
return getattr(self, backend_type)
|
||||
|
||||
|
||||
# Global instance
|
||||
model_config = ModelConfig()
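A short usage sketch of the global instance defined above; the import path is assumed from this PR's layout:

```python
from api.src.core.model_config import model_config  # assumed module path

# Returns the ONNXCPUConfig instance defined earlier in this file
onnx_cpu_cfg = model_config.get_backend_config("onnx_cpu")
print(onnx_cpu_cfg.num_threads, onnx_cpu_cfg.execution_mode)

try:
    model_config.get_backend_config("tensorrt")  # not one of the four supported backend types
except ValueError as e:
    print(e)
```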
|
api/src/core/openai_mappings.json (new file, 18 lines)
@@ -0,0 +1,18 @@
{
  "models": {
    "tts-1": "kokoro-v0_19",
    "tts-1-hd": "kokoro-v0_19",
    "kokoro": "kokoro-v0_19"
  },
  "voices": {
    "alloy": "am_adam",
    "ash": "af_nicole",
    "coral": "bf_emma",
    "echo": "af_bella",
    "fable": "af_sarah",
    "onyx": "bm_george",
    "nova": "bf_isabella",
    "sage": "am_michael",
    "shimmer": "af_sky"
  }
}
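These mappings let clients use OpenAI-style model and voice names against strict OpenAI-compatible tooling. A hedged example, assuming the API resolves names through this file as the README's `openai_mappings.json` note suggests:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed")

# "tts-1" resolves to kokoro-v0_19 and "alloy" to am_adam via the mapping above
with client.audio.speech.with_streaming_response.create(
    model="tts-1",
    voice="alloy",
    input="Hello world!",
    response_format="mp3",
) as response:
    response.stream_to_file("alloy.mp3")
```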
|
api/src/core/paths.py (new file, 414 lines)
@@ -0,0 +1,414 @@
"""Async file and path operations."""
|
||||
|
||||
import io
|
||||
import json
|
||||
import os
import time  # used below for temp-file age checks
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, AsyncIterator, Callable, Set, Dict, Any
|
||||
|
||||
import aiofiles
|
||||
import aiofiles.os
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from .config import settings
|
||||
|
||||
|
||||
async def _find_file(
|
||||
filename: str,
|
||||
search_paths: List[str],
|
||||
filter_fn: Optional[Callable[[str], bool]] = None
|
||||
) -> str:
|
||||
"""Find file in search paths.
|
||||
|
||||
Args:
|
||||
filename: Name of file to find
|
||||
search_paths: List of paths to search in
|
||||
filter_fn: Optional function to filter files
|
||||
|
||||
Returns:
|
||||
Absolute path to file
|
||||
|
||||
Raises:
|
||||
RuntimeError: If file not found
|
||||
"""
|
||||
if os.path.isabs(filename) and await aiofiles.os.path.exists(filename):
|
||||
return filename
|
||||
|
||||
for path in search_paths:
|
||||
full_path = os.path.join(path, filename)
|
||||
if await aiofiles.os.path.exists(full_path):
|
||||
if filter_fn is None or filter_fn(full_path):
|
||||
return full_path
|
||||
|
||||
raise RuntimeError(f"File not found: {filename} in paths: {search_paths}")
|
||||
|
||||
|
||||
async def _scan_directories(
|
||||
search_paths: List[str],
|
||||
filter_fn: Optional[Callable[[str], bool]] = None
|
||||
) -> Set[str]:
|
||||
"""Scan directories for files.
|
||||
|
||||
Args:
|
||||
search_paths: List of paths to scan
|
||||
filter_fn: Optional function to filter files
|
||||
|
||||
Returns:
|
||||
Set of matching filenames
|
||||
"""
|
||||
results = set()
|
||||
|
||||
for path in search_paths:
|
||||
if not await aiofiles.os.path.exists(path):
|
||||
continue
|
||||
|
||||
try:
|
||||
# Get directory entries first
|
||||
entries = await aiofiles.os.scandir(path)
|
||||
# Then process entries after await completes
|
||||
for entry in entries:
|
||||
if filter_fn is None or filter_fn(entry.name):
|
||||
results.add(entry.name)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error scanning {path}: {e}")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def get_model_path(model_name: str) -> str:
|
||||
"""Get path to model file.
|
||||
|
||||
Args:
|
||||
model_name: Name of model file
|
||||
|
||||
Returns:
|
||||
Absolute path to model file
|
||||
|
||||
Raises:
|
||||
RuntimeError: If model not found
|
||||
"""
|
||||
# Get api directory path (two levels up from core)
|
||||
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
# Construct model directory path relative to api directory
|
||||
model_dir = os.path.join(api_dir, settings.model_dir)
|
||||
|
||||
# Ensure model directory exists
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
# Search in model directory
|
||||
search_paths = [model_dir]
|
||||
logger.debug(f"Searching for model in path: {model_dir}")
|
||||
|
||||
return await _find_file(model_name, search_paths)
|
||||
|
||||
|
||||
async def get_voice_path(voice_name: str) -> str:
|
||||
"""Get path to voice file.
|
||||
|
||||
Args:
|
||||
voice_name: Name of voice file (without .pt extension)
|
||||
|
||||
Returns:
|
||||
Absolute path to voice file
|
||||
|
||||
Raises:
|
||||
RuntimeError: If voice not found
|
||||
"""
|
||||
# Get api directory path
|
||||
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
# Construct voice directory path relative to api directory
|
||||
voice_dir = os.path.join(api_dir, settings.voices_dir)
|
||||
|
||||
# Ensure voice directory exists
|
||||
os.makedirs(voice_dir, exist_ok=True)
|
||||
|
||||
voice_file = f"{voice_name}.pt"
|
||||
|
||||
# Search in voice directory
|
||||
search_paths = [voice_dir]
|
||||
logger.debug(f"Searching for voice in path: {voice_dir}")
|
||||
|
||||
return await _find_file(voice_file, search_paths)
|
||||
|
||||
|
||||
async def list_voices() -> List[str]:
|
||||
"""List available voice files.
|
||||
|
||||
Returns:
|
||||
List of voice names (without .pt extension)
|
||||
"""
|
||||
# Get api directory path
|
||||
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
# Construct voice directory path relative to api directory
|
||||
voice_dir = os.path.join(api_dir, settings.voices_dir)
|
||||
|
||||
# Ensure voice directory exists
|
||||
os.makedirs(voice_dir, exist_ok=True)
|
||||
|
||||
# Search in voice directory
|
||||
search_paths = [voice_dir]
|
||||
logger.debug(f"Scanning for voices in path: {voice_dir}")
|
||||
|
||||
def filter_voice_files(name: str) -> bool:
|
||||
return name.endswith('.pt')
|
||||
|
||||
voices = await _scan_directories(search_paths, filter_voice_files)
|
||||
return sorted([name[:-3] for name in voices]) # Remove .pt extension
|
||||
|
||||
|
||||
async def load_voice_tensor(voice_path: str, device: str = "cpu") -> torch.Tensor:
|
||||
"""Load voice tensor from file.
|
||||
|
||||
Args:
|
||||
voice_path: Path to voice file
|
||||
device: Device to load tensor to
|
||||
|
||||
Returns:
|
||||
Voice tensor
|
||||
|
||||
Raises:
|
||||
RuntimeError: If file cannot be read
|
||||
"""
|
||||
try:
|
||||
async with aiofiles.open(voice_path, 'rb') as f:
|
||||
data = await f.read()
|
||||
return torch.load(
|
||||
io.BytesIO(data),
|
||||
map_location=device,
|
||||
weights_only=True
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load voice tensor from {voice_path}: {e}")
|
||||
|
||||
|
||||
async def save_voice_tensor(tensor: torch.Tensor, voice_path: str) -> None:
|
||||
"""Save voice tensor to file.
|
||||
|
||||
Args:
|
||||
tensor: Voice tensor to save
|
||||
voice_path: Path to save voice file
|
||||
|
||||
Raises:
|
||||
RuntimeError: If file cannot be written
|
||||
"""
|
||||
try:
|
||||
buffer = io.BytesIO()
|
||||
torch.save(tensor, buffer)
|
||||
async with aiofiles.open(voice_path, 'wb') as f:
|
||||
await f.write(buffer.getvalue())
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to save voice tensor to {voice_path}: {e}")
|
||||
|
||||
|
||||
async def load_json(path: str) -> dict:
|
||||
"""Load JSON file asynchronously.
|
||||
|
||||
Args:
|
||||
path: Path to JSON file
|
||||
|
||||
Returns:
|
||||
Parsed JSON data
|
||||
|
||||
Raises:
|
||||
RuntimeError: If file cannot be read or parsed
|
||||
"""
|
||||
try:
|
||||
async with aiofiles.open(path, 'r', encoding='utf-8') as f:
|
||||
content = await f.read()
|
||||
return json.loads(content)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load JSON file {path}: {e}")
|
||||
|
||||
|
||||
async def load_model_weights(path: str, device: str = "cpu") -> dict:
|
||||
"""Load model weights asynchronously.
|
||||
|
||||
Args:
|
||||
path: Path to model file (.pth or .onnx)
|
||||
device: Device to load model to
|
||||
|
||||
Returns:
|
||||
Model weights
|
||||
|
||||
Raises:
|
||||
RuntimeError: If file cannot be read
|
||||
"""
|
||||
try:
|
||||
async with aiofiles.open(path, 'rb') as f:
|
||||
data = await f.read()
|
||||
return torch.load(
|
||||
io.BytesIO(data),
|
||||
map_location=device,
|
||||
weights_only=True
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load model weights from {path}: {e}")
|
||||
|
||||
|
||||
async def read_file(path: str) -> str:
|
||||
"""Read text file asynchronously.
|
||||
|
||||
Args:
|
||||
path: Path to file
|
||||
|
||||
Returns:
|
||||
File contents as string
|
||||
|
||||
Raises:
|
||||
RuntimeError: If file cannot be read
|
||||
"""
|
||||
try:
|
||||
async with aiofiles.open(path, 'r', encoding='utf-8') as f:
|
||||
return await f.read()
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to read file {path}: {e}")
|
||||
|
||||
|
||||
async def read_bytes(path: str) -> bytes:
|
||||
"""Read file as bytes asynchronously.
|
||||
|
||||
Args:
|
||||
path: Path to file
|
||||
|
||||
Returns:
|
||||
File contents as bytes
|
||||
|
||||
Raises:
|
||||
RuntimeError: If file cannot be read
|
||||
"""
|
||||
try:
|
||||
async with aiofiles.open(path, 'rb') as f:
|
||||
return await f.read()
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to read file {path}: {e}")
|
||||
|
||||
|
||||
async def get_web_file_path(filename: str) -> str:
|
||||
"""Get path to web static file.
|
||||
|
||||
Args:
|
||||
filename: Name of file in web directory
|
||||
|
||||
Returns:
|
||||
Absolute path to file
|
||||
|
||||
Raises:
|
||||
RuntimeError: If file not found
|
||||
"""
|
||||
# Get project root directory (four levels up from core to get to project root)
|
||||
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
|
||||
# Construct web directory path relative to project root
|
||||
web_dir = os.path.join("/app", settings.web_player_path)
|
||||
|
||||
# Search in web directory
|
||||
search_paths = [web_dir]
|
||||
logger.debug(f"Searching for web file in path: {web_dir}")
|
||||
|
||||
return await _find_file(filename, search_paths)
|
||||
|
||||
|
||||
async def get_content_type(path: str) -> str:
|
||||
"""Get content type for file.
|
||||
|
||||
Args:
|
||||
path: Path to file
|
||||
|
||||
Returns:
|
||||
Content type string
|
||||
"""
|
||||
ext = os.path.splitext(path)[1].lower()
|
||||
return {
|
||||
'.html': 'text/html',
|
||||
'.js': 'application/javascript',
|
||||
'.css': 'text/css',
|
||||
'.png': 'image/png',
|
||||
'.jpg': 'image/jpeg',
|
||||
'.jpeg': 'image/jpeg',
|
||||
'.gif': 'image/gif',
|
||||
'.svg': 'image/svg+xml',
|
||||
'.ico': 'image/x-icon',
|
||||
}.get(ext, 'application/octet-stream')
|
||||
|
||||
|
||||
async def verify_model_path(model_path: str) -> bool:
|
||||
"""Verify model file exists at path."""
|
||||
return await aiofiles.os.path.exists(model_path)
|
||||
|
||||
|
||||
async def cleanup_temp_files() -> None:
|
||||
"""Clean up old temp files on startup"""
|
||||
try:
|
||||
if not await aiofiles.os.path.exists(settings.temp_file_dir):
|
||||
await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)
|
||||
return
|
||||
|
||||
entries = await aiofiles.os.scandir(settings.temp_file_dir)
|
||||
for entry in entries:
|
||||
if entry.is_file():
|
||||
stat = await aiofiles.os.stat(entry.path)
|
||||
max_age = settings.max_temp_dir_age_hours * 3600
if time.time() - stat.st_mtime > max_age:  # delete files older than the configured max age
|
||||
try:
|
||||
await aiofiles.os.remove(entry.path)
|
||||
logger.info(f"Cleaned up old temp file: {entry.name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete old temp file {entry.name}: {e}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error cleaning temp files: {e}")
|
||||
|
||||
|
||||
async def get_temp_file_path(filename: str) -> str:
|
||||
"""Get path to temporary audio file.
|
||||
|
||||
Args:
|
||||
filename: Name of temp file
|
||||
|
||||
Returns:
|
||||
Absolute path to temp file
|
||||
|
||||
Raises:
|
||||
RuntimeError: If temp directory does not exist
|
||||
"""
|
||||
temp_path = os.path.join(settings.temp_file_dir, filename)
|
||||
|
||||
# Ensure temp directory exists
|
||||
if not await aiofiles.os.path.exists(settings.temp_file_dir):
|
||||
await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)
|
||||
|
||||
return temp_path
|
||||
|
||||
|
||||
async def list_temp_files() -> List[str]:
|
||||
"""List temporary audio files.
|
||||
|
||||
Returns:
|
||||
List of temp file names
|
||||
"""
|
||||
if not await aiofiles.os.path.exists(settings.temp_file_dir):
|
||||
return []
|
||||
|
||||
entries = await aiofiles.os.scandir(settings.temp_file_dir)
|
||||
return [entry.name for entry in entries if entry.is_file()]
|
||||
|
||||
|
||||
async def get_temp_dir_size() -> int:
|
||||
"""Get total size of temp directory in bytes.
|
||||
|
||||
Returns:
|
||||
Size in bytes
|
||||
"""
|
||||
if not await aiofiles.os.path.exists(settings.temp_file_dir):
|
||||
return 0
|
||||
|
||||
total = 0
|
||||
entries = await aiofiles.os.scandir(settings.temp_file_dir)
|
||||
for entry in entries:
|
||||
if entry.is_file():
|
||||
stat = await aiofiles.os.stat(entry.path)
|
||||
total += stat.st_size
|
||||
return total
|
api/src/inference/__init__.py (new file, 16 lines)
@@ -0,0 +1,16 @@
"""Model inference package."""
|
||||
|
||||
from .base import BaseModelBackend
|
||||
from .model_manager import ModelManager, get_manager
|
||||
from .onnx_cpu import ONNXCPUBackend
|
||||
from .onnx_gpu import ONNXGPUBackend
|
||||
from .pytorch_backend import PyTorchBackend
|
||||
|
||||
__all__ = [
|
||||
'BaseModelBackend',
|
||||
'ModelManager',
|
||||
'get_manager',
|
||||
'ONNXCPUBackend',
|
||||
'ONNXGPUBackend',
|
||||
'PyTorchBackend',
|
||||
]
|
api/src/inference/base.py (new file, 97 lines)
@@ -0,0 +1,97 @@
"""Base interfaces for model inference."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class ModelBackend(ABC):
|
||||
"""Abstract base class for model inference backends."""
|
||||
|
||||
@abstractmethod
|
||||
async def load_model(self, path: str) -> None:
|
||||
"""Load model from path.
|
||||
|
||||
Args:
|
||||
path: Path to model file
|
||||
|
||||
Raises:
|
||||
RuntimeError: If model loading fails
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def generate(
|
||||
self,
|
||||
tokens: List[int],
|
||||
voice: torch.Tensor,
|
||||
speed: float = 1.0
|
||||
) -> np.ndarray:
|
||||
"""Generate audio from tokens.
|
||||
|
||||
Args:
|
||||
tokens: Input token IDs
|
||||
voice: Voice embedding tensor
|
||||
speed: Speed multiplier
|
||||
|
||||
Returns:
|
||||
Generated audio samples
|
||||
|
||||
Raises:
|
||||
RuntimeError: If generation fails
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def unload(self) -> None:
|
||||
"""Unload model and free resources."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def is_loaded(self) -> bool:
|
||||
"""Check if model is loaded.
|
||||
|
||||
Returns:
|
||||
True if model is loaded, False otherwise
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def device(self) -> str:
|
||||
"""Get device model is running on.
|
||||
|
||||
Returns:
|
||||
Device string ('cpu' or 'cuda')
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class BaseModelBackend(ModelBackend):
|
||||
"""Base implementation of model backend."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize base backend."""
|
||||
self._model: Optional[torch.nn.Module] = None
|
||||
self._device: str = "cpu"
|
||||
|
||||
@property
|
||||
def is_loaded(self) -> bool:
|
||||
"""Check if model is loaded."""
|
||||
return self._model is not None
|
||||
|
||||
@property
|
||||
def device(self) -> str:
|
||||
"""Get device model is running on."""
|
||||
return self._device
|
||||
|
||||
def unload(self) -> None:
|
||||
"""Unload model and free resources."""
|
||||
if self._model is not None:
|
||||
del self._model
|
||||
self._model = None
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
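A minimal sketch of what a concrete backend built on `BaseModelBackend` could look like. This is hypothetical and for illustration only; the real backends follow in `onnx_cpu.py`, `onnx_gpu.py`, and `pytorch_backend.py`.

```python
import numpy as np
import torch

class DummyBackend(BaseModelBackend):
    """Toy backend that returns silence, for illustration only."""

    async def load_model(self, path: str) -> None:
        # A real backend would load weights from `path`; here we just mark the model as loaded
        self._model = torch.nn.Identity()

    def generate(self, tokens, voice: torch.Tensor, speed: float = 1.0) -> np.ndarray:
        if not self.is_loaded:
            raise RuntimeError("Model not loaded")
        # One second of silence at 24 kHz, shortened or stretched by the speed multiplier
        return np.zeros(int(24000 / speed), dtype=np.float32)
```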
|
api/src/inference/model_manager.py (new file, 339 lines)
@@ -0,0 +1,339 @@
"""Model management and caching."""
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from ..core import paths
|
||||
from ..core.config import settings
|
||||
from ..core.model_config import ModelConfig, model_config
|
||||
from .base import BaseModelBackend
|
||||
from .onnx_cpu import ONNXCPUBackend
|
||||
from .onnx_gpu import ONNXGPUBackend
|
||||
from .pytorch_backend import PyTorchBackend
|
||||
from .session_pool import CPUSessionPool, StreamingSessionPool
|
||||
|
||||
|
||||
# Global singleton instance and lock for thread-safe initialization
|
||||
_manager_instance = None
|
||||
_manager_lock = asyncio.Lock()
|
||||
|
||||
class ModelManager:
|
||||
"""Manages model loading and inference across backends."""
|
||||
# Class-level state for shared resources
|
||||
_loaded_models = {}
|
||||
_backends = {}
|
||||
def __init__(self, config: Optional[ModelConfig] = None):
|
||||
"""Initialize model manager.
|
||||
Note:
|
||||
This should not be called directly. Use get_manager() instead.
|
||||
"""
|
||||
self._config = config or model_config
|
||||
|
||||
# Initialize session pools
|
||||
self._session_pools = {
|
||||
'onnx_cpu': CPUSessionPool(),
|
||||
'onnx_gpu': StreamingSessionPool()
|
||||
}
|
||||
|
||||
# Initialize locks
|
||||
self._backend_locks: Dict[str, asyncio.Lock] = {}
|
||||
|
||||
def _determine_device(self) -> str:
|
||||
"""Determine device based on settings."""
|
||||
if settings.use_gpu and torch.cuda.is_available():
|
||||
return "cuda"
|
||||
return "cpu"
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize backends."""
|
||||
if self._backends:
|
||||
logger.debug("Using existing backend instances")
|
||||
return
|
||||
|
||||
device = self._determine_device()
|
||||
|
||||
try:
|
||||
if device == "cuda":
|
||||
if settings.use_onnx:
|
||||
self._backends['onnx_gpu'] = ONNXGPUBackend()
|
||||
self._current_backend = 'onnx_gpu'
|
||||
logger.info("Initialized new ONNX GPU backend")
|
||||
else:
|
||||
self._backends['pytorch'] = PyTorchBackend()
|
||||
self._current_backend = 'pytorch'
|
||||
logger.info("Initialized new PyTorch backend on GPU")
|
||||
else:
|
||||
if settings.use_onnx:
|
||||
self._backends['onnx_cpu'] = ONNXCPUBackend()
|
||||
self._current_backend = 'onnx_cpu'
|
||||
logger.info("Initialized new ONNX CPU backend")
|
||||
else:
|
||||
self._backends['pytorch'] = PyTorchBackend()
|
||||
self._current_backend = 'pytorch'
|
||||
logger.info("Initialized new PyTorch backend on CPU")
|
||||
|
||||
# Initialize locks for each backend
|
||||
for backend in self._backends:
|
||||
self._backend_locks[backend] = asyncio.Lock()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize backend: {e}")
|
||||
raise RuntimeError("Failed to initialize backend")
|
||||
|
||||
async def initialize_with_warmup(self, voice_manager) -> tuple[str, str, int]:
|
||||
"""Initialize model with warmup and pre-cache voices.
|
||||
Args:
|
||||
voice_manager: Voice manager instance for loading voices
|
||||
Returns:
|
||||
Tuple of (device type, model type, number of loaded voices)
|
||||
Raises:
|
||||
RuntimeError: If initialization fails
|
||||
"""
|
||||
try:
|
||||
# Determine backend type based on settings
|
||||
if settings.use_onnx:
|
||||
backend_type = 'onnx_gpu' if settings.use_gpu and torch.cuda.is_available() else 'onnx_cpu'
|
||||
else:
|
||||
backend_type = 'pytorch'
|
||||
|
||||
# Get backend
|
||||
backend = self.get_backend(backend_type)
|
||||
|
||||
# Get and verify model path
|
||||
model_file = model_config.pytorch_model_file if not settings.use_onnx else model_config.onnx_model_file
|
||||
model_path = await paths.get_model_path(model_file)
|
||||
|
||||
if not await paths.verify_model_path(model_path):
|
||||
raise RuntimeError(f"Model file not found: {model_path}")
|
||||
|
||||
# Pre-cache default voice and use for warmup
|
||||
warmup_voice = await voice_manager.load_voice(
|
||||
settings.default_voice, device=backend.device)
|
||||
logger.info(f"Pre-cached voice {settings.default_voice} for warmup")
|
||||
|
||||
# Initialize model with warmup voice
|
||||
await self.load_model(model_path, warmup_voice, backend_type)
|
||||
|
||||
# Only pre-cache default voice to avoid memory bloat
|
||||
logger.info(f"Using {settings.default_voice} as warmup voice")
|
||||
|
||||
# Get available voices count
|
||||
voices = await voice_manager.list_voices()
|
||||
voicepack_count = len(voices)
|
||||
|
||||
# Get device info for return
|
||||
device = "GPU" if settings.use_gpu else "CPU"
|
||||
model = "ONNX" if settings.use_onnx else "PyTorch"
|
||||
|
||||
return device, model, voicepack_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize model with warmup: {e}")
|
||||
raise RuntimeError(f"Failed to initialize model with warmup: {e}")
|
||||
|
||||
def get_backend(self, backend_type: Optional[str] = None) -> BaseModelBackend:
|
||||
"""Get specified backend.
|
||||
Args:
|
||||
backend_type: Backend type ('pytorch_cpu', 'pytorch_gpu', 'onnx_cpu', 'onnx_gpu'),
|
||||
uses default if None
|
||||
Returns:
|
||||
Model backend instance
|
||||
Raises:
|
||||
ValueError: If backend type is invalid
|
||||
RuntimeError: If no backends are available
|
||||
"""
|
||||
if not self._backends:
|
||||
raise RuntimeError("No backends available")
|
||||
|
||||
if backend_type is None:
|
||||
backend_type = self._current_backend
|
||||
|
||||
if backend_type not in self._backends:
|
||||
raise ValueError(
|
||||
f"Invalid backend type: {backend_type}. "
|
||||
f"Available backends: {', '.join(self._backends.keys())}"
|
||||
)
|
||||
|
||||
return self._backends[backend_type]
|
||||
|
||||
def _determine_backend(self, model_path: str) -> str:
|
||||
"""Determine appropriate backend based on model file and settings.
|
||||
Args:
|
||||
model_path: Path to model file
|
||||
Returns:
|
||||
Backend type to use
|
||||
"""
|
||||
# If ONNX is preferred or model is ONNX format
|
||||
if settings.use_onnx or model_path.lower().endswith('.onnx'):
|
||||
return 'onnx_gpu' if settings.use_gpu and torch.cuda.is_available() else 'onnx_cpu'
|
||||
return 'pytorch'
|
||||
|
||||
async def load_model(
|
||||
self,
|
||||
model_path: str,
|
||||
warmup_voice: Optional[torch.Tensor] = None,
|
||||
backend_type: Optional[str] = None
|
||||
) -> None:
|
||||
"""Load model on specified backend.
|
||||
Args:
|
||||
model_path: Path to model file
|
||||
warmup_voice: Optional voice tensor for warmup, skips warmup if None
|
||||
backend_type: Backend to load on, uses default if None
|
||||
Raises:
|
||||
RuntimeError: If model loading fails
|
||||
"""
|
||||
try:
|
||||
# Get absolute model path
|
||||
abs_path = await paths.get_model_path(model_path)
|
||||
|
||||
# Auto-determine backend if not specified
|
||||
if backend_type is None:
|
||||
backend_type = self._determine_backend(abs_path)
|
||||
|
||||
# Get backend lock
|
||||
lock = self._backend_locks[backend_type]
|
||||
|
||||
async with lock:
|
||||
backend = self.get_backend(backend_type)
|
||||
|
||||
# For ONNX backends, use session pool
|
||||
if backend_type.startswith('onnx'):
|
||||
pool = self._session_pools[backend_type]
|
||||
backend._session = await pool.get_session(abs_path)
|
||||
self._loaded_models[backend_type] = abs_path
|
||||
logger.info(f"Fetched model instance from {backend_type} pool")
|
||||
|
||||
# For PyTorch backends, load normally
|
||||
else:
|
||||
# Check if model is already loaded
|
||||
if (backend_type in self._loaded_models and
|
||||
self._loaded_models[backend_type] == abs_path and
|
||||
backend.is_loaded):
|
||||
logger.info(f"Fetching existing model instance from {backend_type}")
|
||||
return
|
||||
|
||||
# Load model
|
||||
await backend.load_model(abs_path)
|
||||
self._loaded_models[backend_type] = abs_path
|
||||
logger.info(f"Initialized new model instance on {backend_type}")
|
||||
|
||||
# Run warmup if voice provided
|
||||
if warmup_voice is not None:
|
||||
await self._warmup_inference(backend, warmup_voice)
|
||||
|
||||
except Exception as e:
|
||||
# Clear cached path on failure
|
||||
self._loaded_models.pop(backend_type, None)
|
||||
raise RuntimeError(f"Failed to load model: {e}")
|
||||
|
||||
async def _warmup_inference(self, backend: BaseModelBackend, voice: torch.Tensor) -> None:
|
||||
"""Run warmup inference to initialize model.
|
||||
|
||||
Args:
|
||||
backend: Model backend to warm up
|
||||
voice: Voice tensor already loaded on correct device
|
||||
"""
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from ..services.text_processing import process_text
|
||||
|
||||
# Use real text
|
||||
text = "Testing text to speech synthesis."
|
||||
|
||||
# Process through pipeline
|
||||
tokens = process_text(text)
|
||||
if not tokens:
|
||||
raise ValueError("Text processing failed")
|
||||
|
||||
# Run inference
|
||||
backend.generate(tokens, voice, speed=1.0)
|
||||
logger.debug("Completed warmup inference")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Warmup inference failed: {e}")
|
||||
raise
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
tokens: list[int],
|
||||
voice: torch.Tensor,
|
||||
speed: float = 1.0,
|
||||
backend_type: Optional[str] = None
|
||||
) -> torch.Tensor:
|
||||
"""Generate audio using specified backend.
|
||||
|
||||
Args:
|
||||
tokens: Input token IDs
|
||||
voice: Voice tensor already loaded on correct device
|
||||
speed: Speed multiplier
|
||||
backend_type: Backend to use, uses default if None
|
||||
|
||||
Returns:
|
||||
Generated audio tensor
|
||||
|
||||
Raises:
|
||||
RuntimeError: If generation fails
|
||||
"""
|
||||
backend = self.get_backend(backend_type)
|
||||
if not backend.is_loaded:
|
||||
raise RuntimeError("Model not loaded")
|
||||
|
||||
try:
|
||||
# Generate audio using provided voice tensor
|
||||
# No lock needed here since inference is thread-safe
|
||||
return backend.generate(tokens, voice, speed)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Generation failed: {e}")
|
||||
|
||||
def unload_all(self) -> None:
|
||||
"""Unload models from all backends and clear cache."""
|
||||
# Clean up session pools
|
||||
for pool in self._session_pools.values():
|
||||
pool.cleanup()
|
||||
|
||||
# Unload PyTorch backends
|
||||
for backend in self._backends.values():
|
||||
backend.unload()
|
||||
|
||||
self._loaded_models.clear()
|
||||
logger.info("Unloaded all models and cleared cache")
|
||||
|
||||
@property
|
||||
def available_backends(self) -> list[str]:
|
||||
"""Get list of available backends.
|
||||
"""
|
||||
return list(self._backends.keys())
|
||||
|
||||
@property
|
||||
def current_backend(self) -> str:
|
||||
"""Get current default backend.
|
||||
"""
|
||||
return self._current_backend
|
||||
|
||||
|
||||
async def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
|
||||
"""Get global model manager instance.
|
||||
Args:
|
||||
config: Optional model configuration
|
||||
Returns:
|
||||
ModelManager instance
|
||||
Thread Safety:
|
||||
This function should be thread-safe. Lemme know if it unravels on you
|
||||
"""
|
||||
global _manager_instance
|
||||
|
||||
# Fast path - return existing instance without lock
|
||||
if _manager_instance is not None:
|
||||
return _manager_instance
|
||||
|
||||
# Slow path - create new instance with lock
|
||||
async with _manager_lock:
|
||||
# Double-check pattern
|
||||
if _manager_instance is None:
|
||||
_manager_instance = ModelManager(config)
|
||||
await _manager_instance.initialize()
|
||||
return _manager_instance
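A hedged usage sketch of the manager singleton; the `voice_manager` argument is assumed to come from the voices service added elsewhere in this PR and is not constructed here.

```python
import asyncio

async def main(voice_manager):
    manager = await get_manager()  # backends are initialized on first call
    device, model_type, voices = await manager.initialize_with_warmup(voice_manager)
    print(f"{model_type} on {device}, {voices} voices available")

    # Later: generate audio from pre-processed tokens and a loaded voice tensor
    # audio = await manager.generate(tokens, voice_tensor, speed=1.0)

# asyncio.run(main(voice_manager))  # voice_manager construction is not shown in this diff
```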
|
||||
|
api/src/inference/onnx_cpu.py (new file, 115 lines)
@@ -0,0 +1,115 @@
"""CPU-based ONNX inference backend."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
from onnxruntime import InferenceSession
|
||||
|
||||
from ..core import paths
|
||||
from ..core.model_config import model_config
|
||||
from .base import BaseModelBackend
|
||||
from .session_pool import create_session_options, create_provider_options
|
||||
|
||||
|
||||
class ONNXCPUBackend(BaseModelBackend):
|
||||
"""ONNX-based CPU inference backend."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize CPU backend."""
|
||||
super().__init__()
|
||||
self._device = "cpu"
|
||||
self._session: Optional[InferenceSession] = None
|
||||
|
||||
@property
|
||||
def is_loaded(self) -> bool:
|
||||
"""Check if model is loaded."""
|
||||
return self._session is not None
|
||||
|
||||
async def load_model(self, path: str) -> None:
|
||||
"""Load ONNX model.
|
||||
|
||||
Args:
|
||||
path: Path to model file
|
||||
|
||||
Raises:
|
||||
RuntimeError: If model loading fails
|
||||
"""
|
||||
try:
|
||||
# Get verified model path
|
||||
model_path = await paths.get_model_path(path)
|
||||
|
||||
logger.info(f"Loading ONNX model: {model_path}")
|
||||
|
||||
# Configure session
|
||||
options = create_session_options(is_gpu=False)
|
||||
provider_options = create_provider_options(is_gpu=False)
|
||||
|
||||
# Create session
|
||||
self._session = InferenceSession(
|
||||
model_path,
|
||||
sess_options=options,
|
||||
providers=["CPUExecutionProvider"],
|
||||
provider_options=[provider_options]
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load ONNX model: {e}")
|
||||
|
||||
def generate(
|
||||
self,
|
||||
tokens: list[int],
|
||||
voice: torch.Tensor,
|
||||
speed: float = 1.0
|
||||
) -> np.ndarray:
|
||||
"""Generate audio using ONNX model.
|
||||
|
||||
Args:
|
||||
tokens: Input token IDs
|
||||
voice: Voice embedding tensor
|
||||
speed: Speed multiplier
|
||||
|
||||
Returns:
|
||||
Generated audio samples
|
||||
|
||||
Raises:
|
||||
RuntimeError: If generation fails
|
||||
"""
|
||||
if not self.is_loaded:
|
||||
raise RuntimeError("Model not loaded")
|
||||
|
||||
try:
|
||||
# Prepare inputs with start/end tokens
|
||||
tokens_input = np.array([[0, *tokens, 0]], dtype=np.int64) # Add start/end tokens
|
||||
style_input = voice[(len(tokens) + 2) % voice.size(0)].numpy()  # +2 for start/end tokens; modulo keeps the index within the voice pack
|
||||
speed_input = np.full(1, speed, dtype=np.float32)
|
||||
|
||||
# Build base inputs
|
||||
inputs = {
|
||||
"style": style_input,
|
||||
"speed": speed_input
|
||||
}
|
||||
|
||||
# Try both possible token input names (TODO: resolve the expected name from session metadata instead)
|
||||
for token_name in ["tokens", "input_ids"]:
|
||||
try:
|
||||
inputs[token_name] = tokens_input
|
||||
result = self._session.run(None, inputs)
|
||||
return result[0]
|
||||
except Exception:
|
||||
del inputs[token_name]
|
||||
continue
|
||||
|
||||
raise RuntimeError("Model does not accept either 'tokens' or 'input_ids' as input name")
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Generation failed: {e}")
|
||||
|
||||
def unload(self) -> None:
|
||||
"""Unload model and free resources."""
|
||||
if self._session is not None:
|
||||
del self._session
|
||||
self._session = None
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
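As an aside, the try/except fallback in `generate()` above could also be avoided by reading the model's declared inputs. A small sketch, assuming `session` is an already constructed `InferenceSession`:

```python
# Sketch: resolve the token input name from session metadata instead of trying both.
from onnxruntime import InferenceSession


def resolve_token_input(session: InferenceSession) -> str:
    """Return whichever of the known token input names the model declares."""
    input_names = {inp.name for inp in session.get_inputs()}
    for candidate in ("tokens", "input_ids"):
        if candidate in input_names:
            return candidate
    raise RuntimeError(f"Model declares none of the expected token inputs: {sorted(input_names)}")
```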
119
api/src/inference/onnx_gpu.py
Normal file
|
@ -0,0 +1,119 @@
|
|||
"""GPU-based ONNX inference backend."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
from onnxruntime import InferenceSession
|
||||
|
||||
from ..core import paths
|
||||
from ..core.model_config import model_config
|
||||
from .base import BaseModelBackend
|
||||
from .session_pool import create_session_options, create_provider_options
|
||||
|
||||
|
||||
class ONNXGPUBackend(BaseModelBackend):
|
||||
"""ONNX-based GPU inference backend."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize GPU backend."""
|
||||
super().__init__()
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError("CUDA not available")
|
||||
self._device = "cuda"
|
||||
self._session: Optional[InferenceSession] = None
|
||||
|
||||
# Configure GPU
|
||||
torch.cuda.set_device(model_config.onnx_gpu.device_id)
|
||||
|
||||
@property
|
||||
def is_loaded(self) -> bool:
|
||||
"""Check if model is loaded."""
|
||||
return self._session is not None
|
||||
|
||||
async def load_model(self, path: str) -> None:
|
||||
"""Load ONNX model.
|
||||
|
||||
Args:
|
||||
path: Path to model file
|
||||
|
||||
Raises:
|
||||
RuntimeError: If model loading fails
|
||||
"""
|
||||
try:
|
||||
# Get verified model path
|
||||
model_path = await paths.get_model_path(path)
|
||||
|
||||
logger.info(f"Loading ONNX model on GPU: {model_path}")
|
||||
|
||||
# Configure session
|
||||
options = create_session_options(is_gpu=True)
|
||||
provider_options = create_provider_options(is_gpu=True)
|
||||
|
||||
# Create session with CUDA provider
|
||||
self._session = InferenceSession(
|
||||
model_path,
|
||||
sess_options=options,
|
||||
providers=["CUDAExecutionProvider"],
|
||||
provider_options=[provider_options]
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load ONNX model: {e}")
|
||||
|
||||
def generate(
|
||||
self,
|
||||
tokens: list[int],
|
||||
voice: torch.Tensor,
|
||||
speed: float = 1.0
|
||||
) -> np.ndarray:
|
||||
"""Generate audio using ONNX model.
|
||||
|
||||
Args:
|
||||
tokens: Input token IDs
|
||||
voice: Voice embedding tensor
|
||||
speed: Speed multiplier
|
||||
|
||||
Returns:
|
||||
Generated audio samples
|
||||
|
||||
Raises:
|
||||
RuntimeError: If generation fails
|
||||
"""
|
||||
if not self.is_loaded:
|
||||
raise RuntimeError("Model not loaded")
|
||||
|
||||
try:
|
||||
# Prepare inputs
|
||||
tokens_input = np.array([[0, *tokens, 0]], dtype=np.int64) # Add start/end tokens
|
||||
# Use modulo to ensure index stays within voice tensor bounds
|
||||
style_idx = (len(tokens) + 2) % voice.size(0) # Add 2 for start/end tokens
|
||||
style_input = voice[style_idx].cpu().numpy() # Move to CPU for ONNX
|
||||
speed_input = np.full(1, speed, dtype=np.float32)
|
||||
|
||||
# Run inference
|
||||
result = self._session.run(
|
||||
None,
|
||||
{
|
||||
"tokens": tokens_input,
|
||||
"style": style_input,
|
||||
"speed": speed_input
|
||||
}
|
||||
)
|
||||
|
||||
return result[0]
|
||||
|
||||
except Exception as e:
|
||||
if "out of memory" in str(e).lower():
|
||||
# Clear CUDA cache and retry
|
||||
torch.cuda.empty_cache()
|
||||
return self.generate(tokens, voice, speed)
|
||||
raise RuntimeError(f"Generation failed: {e}")
|
||||
|
||||
def unload(self) -> None:
|
||||
"""Unload model and free resources."""
|
||||
if self._session is not None:
|
||||
del self._session
|
||||
self._session = None
|
||||
torch.cuda.empty_cache()
|
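For clarity, the style lookup above treats the voice pack as a tensor indexed by input length. A standalone illustration with assumed shapes:

```python
# Illustration only: selecting a style vector from a voice pack by token count.
# The pack shape [510, 1, 256] is an assumption inferred from the indexing above.
import torch

voice_pack = torch.randn(510, 1, 256)  # dummy voice pack
tokens = list(range(100))              # placeholder token IDs

style_idx = (len(tokens) + 2) % voice_pack.size(0)  # +2 for start/end tokens, modulo stays in bounds
style = voice_pack[style_idx].cpu().numpy()          # move to CPU for ONNX, as in generate()
print(style.shape)  # (1, 256) under these assumed shapes
```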
244
api/src/inference/pytorch_backend.py
Normal file
|
@ -0,0 +1,244 @@
|
|||
"""PyTorch inference backend with environment-based configuration."""
|
||||
|
||||
import gc
|
||||
from typing import Optional
|
||||
from contextlib import nullcontext
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from ..builds.models import build_model
|
||||
from ..core import paths
|
||||
from ..core.model_config import model_config
|
||||
from ..core.config import settings
|
||||
from .base import BaseModelBackend
|
||||
|
||||
|
||||
class CUDAStreamManager:
|
||||
"""CUDA stream manager for GPU operations."""
|
||||
|
||||
def __init__(self, num_streams: int):
|
||||
"""Initialize stream manager.
|
||||
|
||||
Args:
|
||||
num_streams: Number of CUDA streams
|
||||
"""
|
||||
self.streams = [torch.cuda.Stream() for _ in range(num_streams)]
|
||||
self._current = 0
|
||||
|
||||
def get_next_stream(self) -> torch.cuda.Stream:
|
||||
"""Get next available stream.
|
||||
|
||||
Returns:
|
||||
CUDA stream
|
||||
"""
|
||||
stream = self.streams[self._current]
|
||||
self._current = (self._current + 1) % len(self.streams)
|
||||
return stream
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
model: torch.nn.Module,
|
||||
tokens: list[int],
|
||||
ref_s: torch.Tensor,
|
||||
speed: float,
|
||||
stream: Optional[torch.cuda.Stream] = None,
|
||||
) -> np.ndarray:
|
||||
"""Forward pass through model.
|
||||
|
||||
Args:
|
||||
model: PyTorch model
|
||||
tokens: Input tokens
|
||||
ref_s: Reference style tensor from the voice pack
|
||||
speed: Speed multiplier
|
||||
stream: Optional CUDA stream (GPU only)
|
||||
|
||||
Returns:
|
||||
Generated audio
|
||||
"""
|
||||
device = ref_s.device
|
||||
|
||||
# Use provided stream or default for GPU
|
||||
context = (
|
||||
torch.cuda.stream(stream) if stream and device.type == "cuda" else nullcontext()
|
||||
)
|
||||
|
||||
with context:
|
||||
# Initial tensor setup
|
||||
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
|
||||
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
|
||||
text_mask = length_to_mask(input_lengths).to(device)
|
||||
|
||||
# Split reference signals
|
||||
style_dim = 128
|
||||
s_ref = ref_s[:, :style_dim].clone().to(device)
|
||||
s_content = ref_s[:, style_dim:].clone().to(device)
|
||||
|
||||
# BERT and encoder pass
|
||||
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
|
||||
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
|
||||
|
||||
# Predictor forward pass
|
||||
d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
|
||||
x, _ = model.predictor.lstm(d)
|
||||
|
||||
# Duration prediction
|
||||
duration = model.predictor.duration_proj(x)
|
||||
duration = torch.sigmoid(duration).sum(axis=-1) / speed
|
||||
pred_dur = torch.round(duration).clamp(min=1).long()
|
||||
del duration, x
|
||||
|
||||
# Alignment matrix construction
|
||||
pred_aln_trg = torch.zeros(
|
||||
input_lengths.item(), pred_dur.sum().item(), device=device
|
||||
)
|
||||
c_frame = 0
|
||||
for i in range(pred_aln_trg.size(0)):
|
||||
pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
|
||||
c_frame += pred_dur[0, i].item()
|
||||
pred_aln_trg = pred_aln_trg.unsqueeze(0)
|
||||
|
||||
# Matrix multiplications
|
||||
en = d.transpose(-1, -2) @ pred_aln_trg
|
||||
del d
|
||||
|
||||
F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
|
||||
del en
|
||||
|
||||
# Final text encoding and decoding
|
||||
t_en = model.text_encoder(tokens, input_lengths, text_mask)
|
||||
asr = t_en @ pred_aln_trg
|
||||
del t_en
|
||||
|
||||
# Generate output
|
||||
output = model.decoder(asr, F0_pred, N_pred, s_ref)
|
||||
|
||||
# Ensure operation completion if using custom stream
|
||||
if stream and device.type == "cuda":
|
||||
stream.synchronize()
|
||||
|
||||
return output.squeeze().cpu().numpy()
|
||||
|
||||
|
||||
def length_to_mask(lengths: torch.Tensor) -> torch.Tensor:
|
||||
"""Create attention mask from lengths."""
|
||||
max_len = lengths.max()
|
||||
mask = torch.arange(max_len, device=lengths.device)[None, :].expand(
|
||||
lengths.shape[0], -1
|
||||
)
|
||||
if lengths.dtype != mask.dtype:
|
||||
mask = mask.to(dtype=lengths.dtype)
|
||||
return mask + 1 > lengths[:, None]
|
||||
|
||||
|
||||
class PyTorchBackend(BaseModelBackend):
|
||||
"""PyTorch inference backend with environment-based configuration."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize backend based on environment configuration."""
|
||||
super().__init__()
|
||||
|
||||
# Configure device based on settings
|
||||
self._device = (
|
||||
"cuda" if settings.use_gpu and torch.cuda.is_available() else "cpu"
|
||||
)
|
||||
self._model: Optional[torch.nn.Module] = None
|
||||
|
||||
# Apply device-specific configurations
|
||||
if self._device == "cuda":
|
||||
config = model_config.pytorch_gpu
|
||||
if config.sync_cuda:
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.set_device(config.device_id)
|
||||
self._stream_manager = CUDAStreamManager(config.cuda_streams)
|
||||
else:
|
||||
config = model_config.pytorch_cpu
|
||||
if config.num_threads > 0:
|
||||
torch.set_num_threads(config.num_threads)
|
||||
if config.pin_memory:
|
||||
torch.set_default_tensor_type(torch.FloatTensor)
|
||||
|
||||
async def load_model(self, path: str) -> None:
|
||||
"""Load PyTorch model.
|
||||
|
||||
Args:
|
||||
path: Path to model file
|
||||
|
||||
Raises:
|
||||
RuntimeError: If model loading fails
|
||||
"""
|
||||
try:
|
||||
# Get verified model path
|
||||
model_path = await paths.get_model_path(path)
|
||||
|
||||
logger.info(f"Loading PyTorch model on {self._device}: {model_path}")
|
||||
self._model = await build_model(model_path, self._device)
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load PyTorch model: {e}")
|
||||
|
||||
def generate(
|
||||
self, tokens: list[int], voice: torch.Tensor, speed: float = 1.0
|
||||
) -> np.ndarray:
|
||||
"""Generate audio using model.
|
||||
|
||||
Args:
|
||||
tokens: Input token IDs
|
||||
voice: Voice embedding tensor
|
||||
speed: Speed multiplier
|
||||
|
||||
Returns:
|
||||
Generated audio samples
|
||||
|
||||
Raises:
|
||||
RuntimeError: If generation fails
|
||||
"""
|
||||
if not self.is_loaded:
|
||||
raise RuntimeError("Model not loaded")
|
||||
|
||||
try:
|
||||
# Memory management for GPU
|
||||
if self._device == "cuda":
|
||||
if self._check_memory():
|
||||
self._clear_memory()
|
||||
stream = self._stream_manager.get_next_stream()
|
||||
else:
|
||||
stream = None
|
||||
|
||||
# Get reference style from voice pack
|
||||
ref_s = voice[len(tokens)].clone().to(self._device)
|
||||
if ref_s.dim() == 1:
|
||||
ref_s = ref_s.unsqueeze(0)
|
||||
|
||||
# Generate audio
|
||||
return forward(self._model, tokens, ref_s, speed, stream)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Generation failed: {e}")
|
||||
if (
|
||||
self._device == "cuda"
|
||||
and model_config.pytorch_gpu.retry_on_oom
|
||||
and "out of memory" in str(e).lower()
|
||||
):
|
||||
self._clear_memory()
|
||||
return self.generate(tokens, voice, speed)
|
||||
raise
|
||||
finally:
|
||||
if self._device == "cuda" and model_config.pytorch_gpu.sync_cuda:
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def _check_memory(self) -> bool:
|
||||
"""Check if memory usage is above threshold."""
|
||||
if self._device == "cuda":
|
||||
memory_gb = torch.cuda.memory_allocated() / 1e9
|
||||
return memory_gb > model_config.pytorch_gpu.memory_threshold
|
||||
return False
|
||||
|
||||
def _clear_memory(self) -> None:
|
||||
"""Clear device memory."""
|
||||
if self._device == "cuda":
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
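A quick sanity check of `length_to_mask` defined above; the values follow directly from the implementation, with `True` marking padding positions:

```python
# Sketch: expected output of length_to_mask for two sequences of lengths 3 and 5.
import torch

lengths = torch.tensor([3, 5])
mask = length_to_mask(lengths)  # function defined in this file
print(mask)
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])
```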
272
api/src/inference/session_pool.py
Normal file
|
@ -0,0 +1,272 @@
|
|||
"""Session pooling for model inference."""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Set
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
from onnxruntime import (
|
||||
ExecutionMode,
|
||||
GraphOptimizationLevel,
|
||||
InferenceSession,
|
||||
SessionOptions
|
||||
)
|
||||
|
||||
from ..core import paths
|
||||
from ..core.model_config import model_config
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionInfo:
|
||||
"""Session information."""
|
||||
session: InferenceSession
|
||||
last_used: float
|
||||
stream_id: Optional[int] = None
|
||||
|
||||
|
||||
def create_session_options(is_gpu: bool = False) -> SessionOptions:
|
||||
"""Create ONNX session options.
|
||||
|
||||
Args:
|
||||
is_gpu: Whether to use GPU configuration
|
||||
|
||||
Returns:
|
||||
Configured session options
|
||||
"""
|
||||
options = SessionOptions()
|
||||
config = model_config.onnx_gpu if is_gpu else model_config.onnx_cpu
|
||||
|
||||
# Set optimization level
|
||||
if config.optimization_level == "all":
|
||||
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
elif config.optimization_level == "basic":
|
||||
options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
|
||||
else:
|
||||
options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
|
||||
|
||||
# Configure threading
|
||||
options.intra_op_num_threads = config.num_threads
|
||||
options.inter_op_num_threads = config.inter_op_threads
|
||||
|
||||
# Set execution mode
|
||||
options.execution_mode = (
|
||||
ExecutionMode.ORT_PARALLEL
|
||||
if config.execution_mode == "parallel"
|
||||
else ExecutionMode.ORT_SEQUENTIAL
|
||||
)
|
||||
|
||||
# Configure memory optimization
|
||||
options.enable_mem_pattern = config.memory_pattern
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def create_provider_options(is_gpu: bool = False) -> Dict:
|
||||
"""Create provider options.
|
||||
|
||||
Args:
|
||||
is_gpu: Whether to use GPU configuration
|
||||
|
||||
Returns:
|
||||
Provider configuration
|
||||
"""
|
||||
if is_gpu:
|
||||
config = model_config.onnx_gpu
|
||||
return {
|
||||
"device_id": config.device_id,
|
||||
"arena_extend_strategy": config.arena_extend_strategy,
|
||||
"gpu_mem_limit": int(config.gpu_mem_limit * torch.cuda.get_device_properties(0).total_memory),
|
||||
"cudnn_conv_algo_search": config.cudnn_conv_algo_search,
|
||||
"do_copy_in_default_stream": config.do_copy_in_default_stream
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"arena_extend_strategy": model_config.onnx_cpu.arena_extend_strategy,
|
||||
"cpu_memory_arena_cfg": "cpu:0"
|
||||
}
|
||||
|
||||
|
||||
class BaseSessionPool:
|
||||
"""Base session pool implementation."""
|
||||
|
||||
def __init__(self, max_size: int, timeout: int):
|
||||
"""Initialize session pool.
|
||||
|
||||
Args:
|
||||
max_size: Maximum number of concurrent sessions
|
||||
timeout: Session timeout in seconds
|
||||
"""
|
||||
self._max_size = max_size
|
||||
self._timeout = timeout
|
||||
self._sessions: Dict[str, SessionInfo] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def get_session(self, model_path: str) -> InferenceSession:
|
||||
"""Get session from pool.
|
||||
|
||||
Args:
|
||||
model_path: Path to model file
|
||||
|
||||
Returns:
|
||||
ONNX inference session
|
||||
|
||||
Raises:
|
||||
RuntimeError: If no sessions available
|
||||
"""
|
||||
async with self._lock:
|
||||
# Clean expired sessions
|
||||
self._cleanup_expired()
|
||||
|
||||
# TODO: Change session tracking to use unique IDs instead of model paths
|
||||
# This would allow multiple instances of the same model
|
||||
|
||||
# Check if session exists and is valid
|
||||
if model_path in self._sessions:
|
||||
session_info = self._sessions[model_path]
|
||||
session_info.last_used = time.time()
|
||||
return session_info.session
|
||||
|
||||
# TODO: Modify session limit check to count instances per model path
|
||||
# Rather than total sessions across all models
|
||||
if len(self._sessions) >= self._max_size:
|
||||
raise RuntimeError(
|
||||
f"Maximum number of sessions reached ({self._max_size}). "
|
||||
"Try again later or reduce concurrent requests."
|
||||
)
|
||||
|
||||
# Create new session
|
||||
session = await self._create_session(model_path)
|
||||
self._sessions[model_path] = SessionInfo(
|
||||
session=session,
|
||||
last_used=time.time()
|
||||
)
|
||||
return session
|
||||
|
||||
def _cleanup_expired(self) -> None:
|
||||
"""Remove expired sessions."""
|
||||
current_time = time.time()
|
||||
expired = [
|
||||
path for path, info in self._sessions.items()
|
||||
if current_time - info.last_used > self._timeout
|
||||
]
|
||||
for path in expired:
|
||||
logger.info(f"Removing expired session: {path}")
|
||||
del self._sessions[path]
|
||||
|
||||
async def _create_session(self, model_path: str) -> InferenceSession:
|
||||
"""Create new session.
|
||||
|
||||
Args:
|
||||
model_path: Path to model file
|
||||
|
||||
Returns:
|
||||
ONNX inference session
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""Clean up all sessions."""
|
||||
self._sessions.clear()
|
||||
|
||||
|
||||
class StreamingSessionPool(BaseSessionPool):
|
||||
"""GPU session pool with CUDA streams."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize GPU session pool."""
|
||||
config = model_config.onnx_gpu
|
||||
super().__init__(config.cuda_streams, config.stream_timeout)
|
||||
self._available_streams: Set[int] = set(range(config.cuda_streams))
|
||||
|
||||
async def get_session(self, model_path: str) -> InferenceSession:
|
||||
"""Get session with CUDA stream.
|
||||
|
||||
Args:
|
||||
model_path: Path to model file
|
||||
|
||||
Returns:
|
||||
ONNX inference session
|
||||
|
||||
Raises:
|
||||
RuntimeError: If no streams available
|
||||
"""
|
||||
async with self._lock:
|
||||
# Clean expired sessions
|
||||
self._cleanup_expired()
|
||||
|
||||
# Try to find existing session
|
||||
if model_path in self._sessions:
|
||||
session_info = self._sessions[model_path]
|
||||
session_info.last_used = time.time()
|
||||
return session_info.session
|
||||
|
||||
# Get available stream
|
||||
if not self._available_streams:
|
||||
raise RuntimeError("No CUDA streams available")
|
||||
stream_id = self._available_streams.pop()
|
||||
|
||||
try:
|
||||
# Create new session
|
||||
session = await self._create_session(model_path)
|
||||
self._sessions[model_path] = SessionInfo(
|
||||
session=session,
|
||||
last_used=time.time(),
|
||||
stream_id=stream_id
|
||||
)
|
||||
return session
|
||||
|
||||
except Exception:
|
||||
# Return stream to pool on failure
|
||||
self._available_streams.add(stream_id)
|
||||
raise
|
||||
|
||||
def _cleanup_expired(self) -> None:
|
||||
"""Remove expired sessions and return streams."""
|
||||
current_time = time.time()
|
||||
expired = [
|
||||
path for path, info in self._sessions.items()
|
||||
if current_time - info.last_used > self._timeout
|
||||
]
|
||||
for path in expired:
|
||||
info = self._sessions[path]
|
||||
if info.stream_id is not None:
|
||||
self._available_streams.add(info.stream_id)
|
||||
logger.info(f"Removing expired session: {path}")
|
||||
del self._sessions[path]
|
||||
|
||||
async def _create_session(self, model_path: str) -> InferenceSession:
|
||||
"""Create new session with CUDA provider."""
|
||||
abs_path = await paths.get_model_path(model_path)
|
||||
options = create_session_options(is_gpu=True)
|
||||
provider_options = create_provider_options(is_gpu=True)
|
||||
|
||||
return InferenceSession(
|
||||
abs_path,
|
||||
sess_options=options,
|
||||
providers=["CUDAExecutionProvider"],
|
||||
provider_options=[provider_options]
|
||||
)
|
||||
|
||||
|
||||
class CPUSessionPool(BaseSessionPool):
|
||||
"""CPU session pool."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize CPU session pool."""
|
||||
config = model_config.onnx_cpu
|
||||
super().__init__(config.max_instances, config.instance_timeout)
|
||||
|
||||
async def _create_session(self, model_path: str) -> InferenceSession:
|
||||
"""Create new session with CPU provider."""
|
||||
abs_path = await paths.get_model_path(model_path)
|
||||
options = create_session_options(is_gpu=False)
|
||||
provider_options = create_provider_options(is_gpu=False)
|
||||
|
||||
return InferenceSession(
|
||||
abs_path,
|
||||
sess_options=options,
|
||||
providers=["CPUExecutionProvider"],
|
||||
provider_options=[provider_options]
|
||||
)
|
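A hedged sketch of exercising the CPU pool directly; the model filename and import path are placeholders:

```python
# Sketch: fetch (or lazily create) a pooled CPU session. Paths are placeholders.
import asyncio

from api.src.inference.session_pool import CPUSessionPool  # assumed import path


async def main() -> None:
    pool = CPUSessionPool()
    session = await pool.get_session("model.onnx")  # placeholder model filename
    print([inp.name for inp in session.get_inputs()])
    pool.cleanup()


asyncio.run(main())
```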
215
api/src/inference/voice_manager.py
Normal file
|
@ -0,0 +1,215 @@
|
|||
"""Voice pack management and caching."""
|
||||
|
||||
import os
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from ..core import paths
|
||||
from ..core.config import settings
|
||||
from ..structures.model_schemas import VoiceConfig
|
||||
|
||||
|
||||
class VoiceManager:
|
||||
"""Manages voice loading and operations."""
|
||||
|
||||
def __init__(self, config: Optional[VoiceConfig] = None):
|
||||
"""Initialize voice manager.
|
||||
|
||||
Args:
|
||||
config: Optional voice configuration
|
||||
"""
|
||||
self._config = config or VoiceConfig()
|
||||
self._voice_cache: Dict[str, torch.Tensor] = {}
|
||||
|
||||
def get_voice_path(self, voice_name: str) -> Optional[str]:
|
||||
"""Get path to voice file.
|
||||
|
||||
Args:
|
||||
voice_name: Name of voice
|
||||
|
||||
Returns:
|
||||
Path to voice file if exists, None otherwise
|
||||
"""
|
||||
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
voice_path = os.path.join(api_dir, settings.voices_dir, f"{voice_name}.pt")
|
||||
return voice_path if os.path.exists(voice_path) else None
|
||||
|
||||
async def load_voice(self, voice_name: str, device: str = "cpu") -> torch.Tensor:
|
||||
"""Load voice tensor.
|
||||
|
||||
Args:
|
||||
voice_name: Name of voice to load
|
||||
device: Device to load voice on
|
||||
|
||||
Returns:
|
||||
Voice tensor
|
||||
|
||||
Raises:
|
||||
RuntimeError: If voice loading fails
|
||||
"""
|
||||
# Check if it's a combined voice request
|
||||
if "+" in voice_name:
|
||||
voices = [v.strip() for v in voice_name.split("+") if v.strip()]
|
||||
if len(voices) < 2:
|
||||
raise RuntimeError(f"Invalid combined voice name: {voice_name}")
|
||||
|
||||
# Load and combine voices
|
||||
voice_tensors = []
|
||||
for voice in voices:
|
||||
try:
|
||||
voice_tensor = await self.load_voice(voice, device)
|
||||
voice_tensors.append(voice_tensor)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load base voice {voice}: {e}")
|
||||
|
||||
return torch.mean(torch.stack(voice_tensors), dim=0)
|
||||
|
||||
# Handle single voice
|
||||
voice_path = self.get_voice_path(voice_name)
|
||||
if not voice_path:
|
||||
raise RuntimeError(f"Voice not found: {voice_name}")
|
||||
|
||||
# Check cache
|
||||
cache_key = f"{voice_path}_{device}"
|
||||
if self._config.use_cache and cache_key in self._voice_cache:
|
||||
return self._voice_cache[cache_key]
|
||||
|
||||
# Load voice tensor
|
||||
try:
|
||||
voice = await paths.load_voice_tensor(voice_path, device=device)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load voice {voice_name}: {e}")
|
||||
|
||||
# Cache if enabled
|
||||
if self._config.use_cache:
|
||||
self._manage_cache()
|
||||
self._voice_cache[cache_key] = voice
|
||||
logger.debug(f"Cached voice: {voice_name} on {device}")
|
||||
|
||||
return voice
|
||||
|
||||
def _manage_cache(self) -> None:
|
||||
"""Manage voice cache size using simple LRU."""
|
||||
if len(self._voice_cache) >= self._config.cache_size:
|
||||
# Remove least recently used voice
|
||||
oldest = next(iter(self._voice_cache))
|
||||
del self._voice_cache[oldest]
|
||||
torch.cuda.empty_cache() # Clean up GPU memory if needed
|
||||
logger.debug(f"Removed LRU voice from cache: {oldest}")
|
||||
|
||||
async def combine_voices(self, voices: List[str], device: str = "cpu") -> str:
|
||||
"""Combine multiple voices into a new voice.
|
||||
|
||||
Args:
|
||||
voices: List of voice names to combine
|
||||
device: Device to load voices on
|
||||
|
||||
Returns:
|
||||
Name of combined voice
|
||||
|
||||
Raises:
|
||||
ValueError: If fewer than 2 voices provided
|
||||
RuntimeError: If voice combination fails
|
||||
"""
|
||||
if len(voices) < 2:
|
||||
raise ValueError("At least 2 voices are required for combination")
|
||||
|
||||
# Create combined name using + as separator
|
||||
combined_name = "+".join(voices)
|
||||
|
||||
# If saving is enabled, try to save the combination
|
||||
if settings.allow_local_voice_saving:
|
||||
try:
|
||||
# Load and combine voices
|
||||
combined_tensor = await self.load_voice(combined_name, device)
|
||||
|
||||
# Save to disk
|
||||
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
voices_dir = os.path.join(api_dir, settings.voices_dir)
|
||||
os.makedirs(voices_dir, exist_ok=True)
|
||||
|
||||
combined_path = os.path.join(voices_dir, f"{combined_name}.pt")
|
||||
try:
|
||||
torch.save(combined_tensor, combined_path)
|
||||
# Cache with path-based key
|
||||
self._voice_cache[f"{combined_path}_{device}"] = combined_tensor
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to save combined voice: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save combined voice: {e}")
|
||||
# Continue without saving - will be combined on-the-fly when needed
|
||||
|
||||
return combined_name
|
||||
|
||||
async def list_voices(self) -> List[str]:
|
||||
"""List available voices.
|
||||
|
||||
Returns:
|
||||
List of voice names
|
||||
"""
|
||||
voices = set() # Use set to avoid duplicates
|
||||
try:
|
||||
# Get voices from disk
|
||||
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
voices_dir = os.path.join(api_dir, settings.voices_dir)
|
||||
os.makedirs(voices_dir, exist_ok=True)
|
||||
|
||||
for entry in os.listdir(voices_dir):
|
||||
if entry.endswith(".pt"):
|
||||
voices.add(entry[:-3])
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing voices: {e}")
|
||||
return sorted(list(voices))
|
||||
|
||||
def validate_voice(self, voice_path: str) -> bool:
|
||||
"""Validate voice file.
|
||||
|
||||
Args:
|
||||
voice_path: Path to voice file
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
if not os.path.exists(voice_path):
|
||||
return False
|
||||
voice = torch.load(voice_path, map_location="cpu")
|
||||
return isinstance(voice, torch.Tensor)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@property
|
||||
def cache_info(self) -> Dict[str, int]:
|
||||
"""Get cache statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with cache info
|
||||
"""
|
||||
return {
|
||||
'size': len(self._voice_cache),
|
||||
'max_size': self._config.cache_size
|
||||
}
|
||||
|
||||
|
||||
# Global singleton instance and lock
|
||||
_manager_instance = None
|
||||
|
||||
|
||||
async def get_manager(config: Optional[VoiceConfig] = None) -> VoiceManager:
|
||||
"""Get global voice manager instance.
|
||||
|
||||
Args:
|
||||
config: Optional voice configuration
|
||||
|
||||
Returns:
|
||||
VoiceManager instance
|
||||
"""
|
||||
global _manager_instance
|
||||
|
||||
if _manager_instance is None:
|
||||
_manager_instance = VoiceManager(config)
|
||||
return _manager_instance
|
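For reference, combining voices with the `+` syntax above reduces to a mean over the stacked voice packs. A standalone sketch with dummy tensors (shapes are assumptions):

```python
# Sketch: what "voice_a+voice_b" resolves to numerically. Shapes are assumptions.
import torch

voice_a = torch.randn(510, 1, 256)  # dummy voice packs
voice_b = torch.randn(510, 1, 256)

combined = torch.mean(torch.stack([voice_a, voice_b]), dim=0)
print(combined.shape)  # same shape as the inputs: torch.Size([510, 1, 256])
```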
|
@ -2,19 +2,22 @@
|
|||
FastAPI OpenAI Compatible API
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from loguru import logger
|
||||
|
||||
from .core.config import settings
|
||||
from .routers.web_player import router as web_router
|
||||
from .routers.development import router as dev_router
|
||||
from .routers.openai_compatible import router as openai_router
|
||||
from .services.tts_model import TTSModel
|
||||
from .services.tts_service import TTSService
|
||||
from .routers.debug import router as debug_router
|
||||
|
||||
|
||||
def setup_logger():
|
||||
|
@ -42,11 +45,39 @@ setup_logger()
|
|||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Lifespan context manager for model initialization"""
|
||||
from .inference.model_manager import get_manager
|
||||
from .inference.voice_manager import get_manager as get_voice_manager
|
||||
from .services.temp_manager import cleanup_temp_files
|
||||
|
||||
# Clean old temp files on startup
|
||||
await cleanup_temp_files()
|
||||
|
||||
logger.info("Loading TTS model and voice packs...")
|
||||
|
||||
# Initialize the main model with warm-up
|
||||
voicepack_count = await TTSModel.setup()
|
||||
# boundary = "█████╗"*9
|
||||
try:
|
||||
# Initialize managers globally
|
||||
model_manager = await get_manager()
|
||||
voice_manager = await get_voice_manager()
|
||||
|
||||
# Initialize model with warmup and get status
|
||||
device, model, voicepack_count = await model_manager.initialize_with_warmup(voice_manager)
|
||||
except FileNotFoundError:
|
||||
logger.error("""
|
||||
Model files not found! You need to either:
|
||||
|
||||
1. Download models using the scripts:
|
||||
GPU: python docker/scripts/download_model.py --type pth
|
||||
CPU: python docker/scripts/download_model.py --type onnx
|
||||
|
||||
2. Set environment variables in docker-compose:
|
||||
GPU: DOWNLOAD_PTH=true
|
||||
CPU: DOWNLOAD_ONNX=true
|
||||
""")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize model: {e}")
|
||||
raise
|
||||
|
||||
boundary = "░" * 2*12
|
||||
startup_msg = f"""
|
||||
|
||||
|
@ -54,16 +85,22 @@ async def lifespan(app: FastAPI):
|
|||
|
||||
╔═╗┌─┐┌─┐┌┬┐
|
||||
╠╣ ├─┤└─┐ │
|
||||
╚ ┴ ┴└─┘ ┴
|
||||
╚ ┴ ┴└─┘ ┴
|
||||
╦╔═┌─┐┬┌─┌─┐
|
||||
╠╩╗│ │├┴┐│ │
|
||||
╩ ╩└─┘┴ ┴└─┘
|
||||
|
||||
{boundary}
|
||||
"""
|
||||
# TODO: Improve CPU warmup, threads, memory, etc
|
||||
startup_msg += f"\nModel warmed up on {TTSModel.get_device()}"
|
||||
startup_msg += f"\n{voicepack_count} voice packs loaded\n"
|
||||
startup_msg += f"\nModel warmed up on {device}: {model}"
|
||||
startup_msg += f"\n{voicepack_count} voice packs loaded"
|
||||
|
||||
# Add web player info if enabled
|
||||
if settings.enable_web_player:
|
||||
startup_msg += f"\n\nBeta Web Player: http://{settings.host}:{settings.port}/web/"
|
||||
else:
|
||||
startup_msg += "\n\nWeb Player: disabled"
|
||||
|
||||
startup_msg += f"\n{boundary}\n"
|
||||
logger.info(startup_msg)
|
||||
|
||||
|
@ -79,19 +116,22 @@ app = FastAPI(
|
|||
openapi_url="/openapi.json", # Explicitly enable OpenAPI schema
|
||||
)
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
# Add CORS middleware if enabled
|
||||
if settings.cors_enabled:
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.cors_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Include routers
|
||||
app.include_router(openai_router, prefix="/v1")
|
||||
app.include_router(dev_router) # New development endpoints
|
||||
# app.include_router(text_router) # Deprecated but still live for backwards compatibility
|
||||
app.include_router(dev_router) # Development endpoints
|
||||
app.include_router(debug_router) # Debug endpoints
|
||||
if settings.enable_web_player:
|
||||
app.include_router(web_router, prefix="/web") # Web player static files
|
||||
|
||||
|
||||
# Health check endpoint
|
||||
|
|
188
api/src/routers/debug.py
Normal file
|
@ -0,0 +1,188 @@
|
|||
from fastapi import APIRouter
|
||||
import psutil
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
try:
|
||||
import GPUtil
|
||||
GPU_AVAILABLE = True
|
||||
except ImportError:
|
||||
GPU_AVAILABLE = False
|
||||
|
||||
router = APIRouter(tags=["debug"])
|
||||
|
||||
@router.get("/debug/threads")
|
||||
async def get_thread_info():
|
||||
process = psutil.Process()
|
||||
current_threads = threading.enumerate()
|
||||
|
||||
# Get per-thread CPU times
|
||||
thread_details = []
|
||||
for thread in current_threads:
|
||||
thread_info = {
|
||||
"name": thread.name,
|
||||
"id": thread.ident,
|
||||
"alive": thread.is_alive(),
|
||||
"daemon": thread.daemon
|
||||
}
|
||||
thread_details.append(thread_info)
|
||||
|
||||
return {
|
||||
"total_threads": process.num_threads(),
|
||||
"active_threads": len(current_threads),
|
||||
"thread_names": [t.name for t in current_threads],
|
||||
"thread_details": thread_details,
|
||||
"memory_mb": process.memory_info().rss / 1024 / 1024
|
||||
}
|
||||
|
||||
@router.get("/debug/storage")
|
||||
async def get_storage_info():
|
||||
# Get disk partitions
|
||||
partitions = psutil.disk_partitions()
|
||||
storage_info = []
|
||||
|
||||
for partition in partitions:
|
||||
try:
|
||||
usage = psutil.disk_usage(partition.mountpoint)
|
||||
storage_info.append({
|
||||
"device": partition.device,
|
||||
"mountpoint": partition.mountpoint,
|
||||
"fstype": partition.fstype,
|
||||
"total_gb": usage.total / (1024**3),
|
||||
"used_gb": usage.used / (1024**3),
|
||||
"free_gb": usage.free / (1024**3),
|
||||
"percent_used": usage.percent
|
||||
})
|
||||
except PermissionError:
|
||||
continue
|
||||
|
||||
return {
|
||||
"storage_info": storage_info
|
||||
}
|
||||
|
||||
@router.get("/debug/system")
|
||||
async def get_system_info():
|
||||
process = psutil.Process()
|
||||
|
||||
# CPU Info
|
||||
cpu_info = {
|
||||
"cpu_count": psutil.cpu_count(),
|
||||
"cpu_percent": psutil.cpu_percent(interval=1),
|
||||
"per_cpu_percent": psutil.cpu_percent(interval=1, percpu=True),
|
||||
"load_avg": psutil.getloadavg()
|
||||
}
|
||||
|
||||
# Memory Info
|
||||
virtual_memory = psutil.virtual_memory()
|
||||
swap_memory = psutil.swap_memory()
|
||||
memory_info = {
|
||||
"virtual": {
|
||||
"total_gb": virtual_memory.total / (1024**3),
|
||||
"available_gb": virtual_memory.available / (1024**3),
|
||||
"used_gb": virtual_memory.used / (1024**3),
|
||||
"percent": virtual_memory.percent
|
||||
},
|
||||
"swap": {
|
||||
"total_gb": swap_memory.total / (1024**3),
|
||||
"used_gb": swap_memory.used / (1024**3),
|
||||
"free_gb": swap_memory.free / (1024**3),
|
||||
"percent": swap_memory.percent
|
||||
}
|
||||
}
|
||||
|
||||
# Process Info
|
||||
process_info = {
|
||||
"pid": process.pid,
|
||||
"status": process.status(),
|
||||
"create_time": datetime.fromtimestamp(process.create_time()).isoformat(),
|
||||
"cpu_percent": process.cpu_percent(),
|
||||
"memory_percent": process.memory_percent(),
|
||||
}
|
||||
|
||||
# Network Info
|
||||
network_info = {
|
||||
"connections": len(process.net_connections()),
|
||||
"network_io": psutil.net_io_counters()._asdict()
|
||||
}
|
||||
|
||||
# GPU Info if available
|
||||
gpu_info = None
|
||||
if GPU_AVAILABLE:
|
||||
try:
|
||||
gpus = GPUtil.getGPUs()
|
||||
gpu_info = [{
|
||||
"id": gpu.id,
|
||||
"name": gpu.name,
|
||||
"load": gpu.load,
|
||||
"memory": {
|
||||
"total": gpu.memoryTotal,
|
||||
"used": gpu.memoryUsed,
|
||||
"free": gpu.memoryFree,
|
||||
"percent": (gpu.memoryUsed / gpu.memoryTotal) * 100
|
||||
},
|
||||
"temperature": gpu.temperature
|
||||
} for gpu in gpus]
|
||||
except Exception:
|
||||
gpu_info = "GPU information unavailable"
|
||||
|
||||
return {
|
||||
"cpu": cpu_info,
|
||||
"memory": memory_info,
|
||||
"process": process_info,
|
||||
"network": network_info,
|
||||
"gpu": gpu_info
|
||||
}
|
||||
|
||||
@router.get("/debug/session_pools")
|
||||
async def get_session_pool_info():
|
||||
"""Get information about ONNX session pools."""
|
||||
from ..inference.model_manager import get_manager
|
||||
|
||||
manager = await get_manager()
|
||||
pools = manager._session_pools
|
||||
current_time = time.time()
|
||||
|
||||
pool_info = {}
|
||||
|
||||
# Get CPU pool info
|
||||
if 'onnx_cpu' in pools:
|
||||
cpu_pool = pools['onnx_cpu']
|
||||
pool_info['cpu'] = {
|
||||
"active_sessions": len(cpu_pool._sessions),
|
||||
"max_sessions": cpu_pool._max_size,
|
||||
"sessions": [{
|
||||
"model": path,
|
||||
"age_seconds": current_time - info.last_used
|
||||
} for path, info in cpu_pool._sessions.items()]
|
||||
}
|
||||
|
||||
# Get GPU pool info
|
||||
if 'onnx_gpu' in pools:
|
||||
gpu_pool = pools['onnx_gpu']
|
||||
pool_info['gpu'] = {
|
||||
"active_sessions": len(gpu_pool._sessions),
|
||||
"max_streams": gpu_pool._max_size,
|
||||
"available_streams": len(gpu_pool._available_streams),
|
||||
"sessions": [{
|
||||
"model": path,
|
||||
"age_seconds": current_time - info.last_used,
|
||||
"stream_id": info.stream_id
|
||||
} for path, info in gpu_pool._sessions.items()]
|
||||
}
|
||||
|
||||
# Add GPU memory info if available
|
||||
if GPU_AVAILABLE:
|
||||
try:
|
||||
gpus = GPUtil.getGPUs()
|
||||
if gpus:
|
||||
gpu = gpus[0] # Assume first GPU
|
||||
pool_info['gpu']['memory'] = {
|
||||
"total_mb": gpu.memoryTotal,
|
||||
"used_mb": gpu.memoryUsed,
|
||||
"free_mb": gpu.memoryFree,
|
||||
"percent_used": (gpu.memoryUsed / gpu.memoryTotal) * 100
|
||||
}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return pool_info
|
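A hedged client-side check of the debug endpoints added above; host and port are assumptions, adjust to your deployment:

```python
# Sketch: poll the debug endpoints. The base URL is an assumption.
import requests

BASE_URL = "http://localhost:8880"  # assumed host/port

for path in ("/debug/threads", "/debug/storage", "/debug/system", "/debug/session_pools"):
    resp = requests.get(f"{BASE_URL}{path}", timeout=10)
    resp.raise_for_status()
    body = resp.json()
    print(path, "->", sorted(body.keys()) if isinstance(body, dict) else body)
```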
|
@ -1,12 +1,15 @@
|
|||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response
|
||||
import torch
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from loguru import logger
|
||||
|
||||
from ..services.audio import AudioService
|
||||
from ..services.text_processing import phonemize, tokenize
|
||||
from ..services.tts_model import TTSModel
|
||||
from ..services.audio import AudioService, AudioNormalizer
|
||||
from ..services.streaming_audio_writer import StreamingAudioWriter
|
||||
from ..services.text_processing import phonemize, smart_split
|
||||
from ..services.text_processing.vocabulary import tokenize
|
||||
from ..services.tts_service import TTSService
|
||||
from ..structures.text_schemas import (
|
||||
GenerateFromPhonemesRequest,
|
||||
|
@ -17,12 +20,10 @@ from ..structures.text_schemas import (
|
|||
router = APIRouter(tags=["text processing"])
|
||||
|
||||
|
||||
def get_tts_service() -> TTSService:
|
||||
async def get_tts_service() -> TTSService:
|
||||
"""Dependency to get TTSService instance"""
|
||||
return TTSService()
|
||||
return await TTSService.create() # Create service with properly initialized managers
|
||||
|
||||
|
||||
@router.post("/text/phonemize", response_model=PhonemeResponse, tags=["deprecated"])
|
||||
@router.post("/dev/phonemize", response_model=PhonemeResponse)
|
||||
async def phonemize_text(request: PhonemeRequest) -> PhonemeResponse:
|
||||
"""Convert text to phonemes and tokens
|
||||
|
@ -43,10 +44,8 @@ async def phonemize_text(request: PhonemeRequest) -> PhonemeResponse:
|
|||
if not phonemes:
|
||||
raise ValueError("Failed to generate phonemes")
|
||||
|
||||
# Get tokens
|
||||
# Get tokens (without adding start/end tokens to match process_text behavior)
|
||||
tokens = tokenize(phonemes)
|
||||
tokens = [0] + tokens + [0] # Add start/end tokens
|
||||
|
||||
return PhonemeResponse(phonemes=phonemes, tokens=tokens)
|
||||
except ValueError as e:
|
||||
logger.error(f"Error in phoneme generation: {str(e)}")
|
||||
|
@ -58,73 +57,95 @@ async def phonemize_text(request: PhonemeRequest) -> PhonemeResponse:
|
|||
raise HTTPException(
|
||||
status_code=500, detail={"error": "Server error", "message": str(e)}
|
||||
)
|
||||
|
||||
|
||||
@router.post("/text/generate_from_phonemes", tags=["deprecated"])
|
||||
@router.post("/dev/generate_from_phonemes")
|
||||
async def generate_from_phonemes(
|
||||
request: GenerateFromPhonemesRequest,
|
||||
client_request: Request,
|
||||
tts_service: TTSService = Depends(get_tts_service),
|
||||
) -> Response:
|
||||
"""Generate audio directly from phonemes
|
||||
|
||||
Args:
|
||||
request: Request containing phonemes and generation parameters
|
||||
tts_service: Injected TTSService instance
|
||||
|
||||
Returns:
|
||||
WAV audio bytes
|
||||
"""
|
||||
# Validate phonemes first
|
||||
if not request.phonemes:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"error": "Invalid request", "message": "Phonemes cannot be empty"},
|
||||
)
|
||||
|
||||
# Validate voice exists
|
||||
voice_path = tts_service._get_voice_path(request.voice)
|
||||
if not voice_path:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Invalid request",
|
||||
"message": f"Voice not found: {request.voice}",
|
||||
},
|
||||
)
|
||||
|
||||
) -> StreamingResponse:
|
||||
"""Generate audio directly from phonemes with proper streaming"""
|
||||
try:
|
||||
# Load voice
|
||||
voicepack = tts_service._load_voice(voice_path)
|
||||
# Basic validation
|
||||
if not isinstance(request.phonemes, str):
|
||||
raise ValueError("Phonemes must be a string")
|
||||
if not request.phonemes:
|
||||
raise ValueError("Phonemes cannot be empty")
|
||||
|
||||
# Convert phonemes to tokens
|
||||
tokens = tokenize(request.phonemes)
|
||||
tokens = [0] + tokens + [0] # Add start/end tokens
|
||||
# Create streaming audio writer and normalizer
|
||||
writer = StreamingAudioWriter(format="wav", sample_rate=24000, channels=1)
|
||||
normalizer = AudioNormalizer()
|
||||
|
||||
# Generate audio directly from tokens
|
||||
audio = TTSModel.generate_from_tokens(tokens, voicepack, request.speed)
|
||||
async def generate_chunks():
|
||||
try:
|
||||
has_data = False
|
||||
# Process phonemes in chunks
|
||||
async for chunk_text, _ in smart_split(request.phonemes):
|
||||
# Check if client is still connected
|
||||
is_disconnected = client_request.is_disconnected
|
||||
if callable(is_disconnected):
|
||||
is_disconnected = await is_disconnected()
|
||||
if is_disconnected:
|
||||
logger.info("Client disconnected, stopping audio generation")
|
||||
break
|
||||
|
||||
# Convert to WAV bytes
|
||||
wav_bytes = AudioService.convert_audio(
|
||||
audio, 24000, "wav", is_first_chunk=True, is_last_chunk=True, stream=False
|
||||
)
|
||||
chunk_audio, _ = await tts_service.generate_from_phonemes(
|
||||
phonemes=chunk_text,
|
||||
voice=request.voice,
|
||||
speed=1.0
|
||||
)
|
||||
if chunk_audio is not None:
|
||||
has_data = True
|
||||
# Normalize audio before writing
|
||||
normalized_audio = await normalizer.normalize(chunk_audio)
|
||||
# Write chunk and yield bytes
|
||||
chunk_bytes = writer.write_chunk(normalized_audio)
|
||||
if chunk_bytes:
|
||||
yield chunk_bytes
|
||||
|
||||
return Response(
|
||||
content=wav_bytes,
|
||||
if not has_data:
|
||||
raise ValueError("Failed to generate any audio data")
|
||||
|
||||
# Finalize and yield remaining bytes if we still have a connection
|
||||
if not (callable(is_disconnected) and await is_disconnected()):
|
||||
final_bytes = writer.write_chunk(finalize=True)
|
||||
if final_bytes:
|
||||
yield final_bytes
|
||||
except Exception as e:
|
||||
logger.error(f"Error in audio chunk generation: {str(e)}")
|
||||
# Clean up writer on error
|
||||
writer.write_chunk(finalize=True)
|
||||
# Re-raise the original exception
|
||||
raise
|
||||
|
||||
return StreamingResponse(
|
||||
generate_chunks(),
|
||||
media_type="audio/wav",
|
||||
headers={
|
||||
"Content-Disposition": "attachment; filename=speech.wav",
|
||||
"X-Accel-Buffering": "no",
|
||||
"Cache-Control": "no-cache",
|
||||
},
|
||||
"Transfer-Encoding": "chunked"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Invalid request: {str(e)}")
|
||||
logger.error(f"Error generating audio: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=400, detail={"error": "Invalid request", "message": str(e)}
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "validation_error",
|
||||
"message": str(e),
|
||||
"type": "invalid_request_error"
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating audio: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail={"error": "Server error", "message": str(e)}
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "processing_error",
|
||||
"message": str(e),
|
||||
"type": "server_error"
|
||||
}
|
||||
)
|
||||
|
|
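A hedged sketch of consuming the streaming phoneme endpoint above; the base URL, voice, and phoneme string are placeholders:

```python
# Sketch: stream WAV audio from /dev/generate_from_phonemes. All values are placeholders.
import requests

BASE_URL = "http://localhost:8880"  # assumed host/port
payload = {"phonemes": "hˈɛloʊ wˈɝld", "voice": "af_bella"}  # placeholder phonemes and voice

with requests.post(f"{BASE_URL}/dev/generate_from_phonemes", json=payload, stream=True, timeout=60) as resp:
    resp.raise_for_status()
    with open("phonemes.wav", "wb") as out:
        for chunk in resp.iter_content(chunk_size=8192):
            if chunk:
                out.write(chunk)
```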
|
@ -1,23 +1,72 @@
|
|||
from typing import AsyncGenerator, List, Union
|
||||
"""OpenAI-compatible router for text-to-speech"""
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Response, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
import json
|
||||
import os
|
||||
from typing import AsyncGenerator, Dict, List, Union
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
|
||||
from fastapi.responses import StreamingResponse, FileResponse
|
||||
from loguru import logger
|
||||
|
||||
from ..services.audio import AudioService
|
||||
from ..services.tts_service import TTSService
|
||||
from ..structures.schemas import OpenAISpeechRequest
|
||||
from ..core.config import settings
|
||||
|
||||
# Load OpenAI mappings
|
||||
def load_openai_mappings() -> Dict:
|
||||
"""Load OpenAI voice and model mappings from JSON"""
|
||||
api_dir = os.path.dirname(os.path.dirname(__file__))
|
||||
mapping_path = os.path.join(api_dir, "core", "openai_mappings.json")
|
||||
try:
|
||||
with open(mapping_path, 'r') as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load OpenAI mappings: {e}")
|
||||
return {"models": {}, "voices": {}}
|
||||
|
||||
# Global mappings
|
||||
_openai_mappings = load_openai_mappings()
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
tags=["OpenAI Compatible TTS"],
|
||||
responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
# Global TTSService instance with lock
|
||||
_tts_service = None
|
||||
_init_lock = None
|
||||
|
||||
def get_tts_service() -> TTSService:
|
||||
"""Dependency to get TTSService instance with database session"""
|
||||
return TTSService() # Initialize TTSService with default settings
|
||||
|
||||
async def get_tts_service() -> TTSService:
|
||||
"""Get global TTSService instance"""
|
||||
global _tts_service, _init_lock
|
||||
|
||||
# Create lock if needed
|
||||
if _init_lock is None:
|
||||
import asyncio
|
||||
_init_lock = asyncio.Lock()
|
||||
|
||||
# Initialize service if needed
|
||||
if _tts_service is None:
|
||||
async with _init_lock:
|
||||
# Double check pattern
|
||||
if _tts_service is None:
|
||||
_tts_service = await TTSService.create()
|
||||
logger.info("Created global TTSService instance")
|
||||
|
||||
return _tts_service
|
||||
|
||||
|
||||
def get_model_name(model: str) -> str:
|
||||
"""Get internal model name from OpenAI model name"""
|
||||
base_name = _openai_mappings["models"].get(model)
|
||||
if not base_name:
|
||||
raise ValueError(f"Unsupported model: {model}")
|
||||
# Add extension based on runtime config
|
||||
extension = ".onnx" if settings.use_onnx else ".pth"
|
||||
return base_name + extension
|
||||
|
||||
async def process_voices(
|
||||
voice_input: Union[str, List[str]], tts_service: TTSService
|
||||
|
@ -25,26 +74,37 @@ async def process_voices(
|
|||
"""Process voice input into a combined voice, handling both string and list formats"""
|
||||
# Convert input to list of voices
|
||||
if isinstance(voice_input, str):
|
||||
# Check if it's an OpenAI voice name
|
||||
mapped_voice = _openai_mappings["voices"].get(voice_input)
|
||||
if mapped_voice:
|
||||
voice_input = mapped_voice
|
||||
voices = [v.strip() for v in voice_input.split("+") if v.strip()]
|
||||
else:
|
||||
voices = voice_input
|
||||
# For list input, map each voice if it's an OpenAI voice name
|
||||
voices = [_openai_mappings["voices"].get(v, v) for v in voice_input]
|
||||
voices = [v.strip() for v in voices if v.strip()]
|
||||
|
||||
if not voices:
|
||||
raise ValueError("No voices provided")
|
||||
|
||||
# Check if all voices exist
|
||||
# If single voice, validate and return it
|
||||
if len(voices) == 1:
|
||||
available_voices = await tts_service.list_voices()
|
||||
if voices[0] not in available_voices:
|
||||
raise ValueError(
|
||||
f"Voice '{voices[0]}' not found. Available voices: {', '.join(sorted(available_voices))}"
|
||||
)
|
||||
return voices[0]
|
||||
|
||||
# For multiple voices, validate base voices exist
|
||||
available_voices = await tts_service.list_voices()
|
||||
for voice in voices:
|
||||
if voice not in available_voices:
|
||||
raise ValueError(
|
||||
f"Voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}"
|
||||
f"Base voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}"
|
||||
)
|
||||
|
||||
# If single voice, return it directly
|
||||
if len(voices) == 1:
|
||||
return voices[0]
|
||||
|
||||
# Otherwise combine voices
|
||||
# Combine voices
|
||||
return await tts_service.combine_voices(voices=voices)
|
||||
|
||||
|
||||
|
@ -64,7 +124,10 @@ async def stream_audio_chunks(
|
|||
output_format=request.response_format,
|
||||
):
|
||||
# Check if client is still connected
|
||||
if await client_request.is_disconnected():
|
||||
is_disconnected = client_request.is_disconnected
|
||||
if callable(is_disconnected):
|
||||
is_disconnected = await is_disconnected()
|
||||
if is_disconnected:
|
||||
logger.info("Client disconnected, stopping audio generation")
|
||||
break
|
||||
yield chunk
|
||||
|
@ -78,12 +141,23 @@ async def stream_audio_chunks(
|
|||
async def create_speech(
|
||||
request: OpenAISpeechRequest,
|
||||
client_request: Request,
|
||||
tts_service: TTSService = Depends(get_tts_service),
|
||||
x_raw_response: str = Header(None, alias="x-raw-response"),
|
||||
):
|
||||
"""OpenAI-compatible endpoint for text-to-speech"""
|
||||
# Validate model before processing request
|
||||
if request.model not in _openai_mappings["models"]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "invalid_model",
|
||||
"message": f"Unsupported model: {request.model}",
|
||||
"type": "invalid_request_error"
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
# Process voice combination and validate
|
||||
# model_name = get_model_name(request.model)
|
||||
tts_service = await get_tts_service()
|
||||
voice_to_use = await process_voices(request.voice, tts_service)
|
||||
|
||||
# Set content type based on format
|
||||
|
@ -98,29 +172,79 @@ async def create_speech(
|
|||
|
||||
# Check if streaming is requested (default for OpenAI client)
|
||||
if request.stream:
|
||||
# Stream audio chunks as they're generated
|
||||
# Create generator but don't start it yet
|
||||
generator = stream_audio_chunks(tts_service, request, client_request)
|
||||
|
||||
# If download link requested, wrap generator with temp file writer
|
||||
if request.return_download_link:
|
||||
from ..services.temp_manager import TempFileWriter
|
||||
|
||||
temp_writer = TempFileWriter(request.response_format)
|
||||
await temp_writer.__aenter__() # Initialize temp file
|
||||
|
||||
# Get download path immediately after temp file creation
|
||||
download_path = temp_writer.download_path
|
||||
|
||||
# Create response headers with download path
|
||||
headers = {
|
||||
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
|
||||
"X-Accel-Buffering": "no",
|
||||
"Cache-Control": "no-cache",
|
||||
"Transfer-Encoding": "chunked",
|
||||
"X-Download-Path": download_path
|
||||
}
|
||||
|
||||
# Create async generator for streaming
|
||||
async def dual_output():
|
||||
try:
|
||||
# Write chunks to temp file and stream
|
||||
async for chunk in generator:
|
||||
if chunk: # Skip empty chunks
|
||||
await temp_writer.write(chunk)
|
||||
yield chunk
|
||||
|
||||
# Finalize the temp file
|
||||
await temp_writer.finalize()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in dual output streaming: {e}")
|
||||
await temp_writer.__aexit__(type(e), e, e.__traceback__)
|
||||
raise
|
||||
finally:
|
||||
# Ensure temp writer is closed
|
||||
if not temp_writer._finalized:
|
||||
await temp_writer.__aexit__(None, None, None)
|
||||
|
||||
# Stream with temp file writing
|
||||
return StreamingResponse(
|
||||
dual_output(),
|
||||
media_type=content_type,
|
||||
headers=headers
|
||||
)
|
||||
|
||||
# Standard streaming without download link
|
||||
return StreamingResponse(
|
||||
stream_audio_chunks(tts_service, request, client_request),
|
||||
generator,
|
||||
media_type=content_type,
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
|
||||
"X-Accel-Buffering": "no", # Disable proxy buffering
|
||||
"Cache-Control": "no-cache", # Prevent caching
|
||||
"Transfer-Encoding": "chunked", # Enable chunked transfer encoding
|
||||
},
|
||||
"X-Accel-Buffering": "no",
|
||||
"Cache-Control": "no-cache",
|
||||
"Transfer-Encoding": "chunked"
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Generate complete audio
|
||||
audio, _ = tts_service._generate_audio(
|
||||
# Generate complete audio using public interface
|
||||
audio, _ = await tts_service.generate_audio(
|
||||
text=request.input,
|
||||
voice=voice_to_use,
|
||||
speed=request.speed,
|
||||
stitch_long_output=True,
|
||||
speed=request.speed
|
||||
)
|
||||
|
||||
# Convert to requested format
|
||||
content = AudioService.convert_audio(
|
||||
audio, 24000, request.response_format, is_first_chunk=True, stream=False
|
||||
# Convert to requested format with proper finalization
|
||||
content = await AudioService.convert_audio(
|
||||
audio, 24000, request.response_format,
|
||||
is_first_chunk=True,
|
||||
is_last_chunk=True
|
||||
)
|
||||
|
||||
return Response(
|
||||
|
@ -133,32 +257,98 @@ async def create_speech(
|
|||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Invalid request: {str(e)}")
|
||||
# Handle validation errors
|
||||
logger.warning(f"Invalid request: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=400, detail={"error": "Invalid request", "message": str(e)}
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "validation_error",
|
||||
"message": str(e),
|
||||
"type": "invalid_request_error"
|
||||
}
|
||||
)
|
||||
except RuntimeError as e:
|
||||
# Handle runtime/processing errors
|
||||
logger.error(f"Processing error: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "processing_error",
|
||||
"message": str(e),
|
||||
"type": "server_error"
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating speech: {str(e)}")
|
||||
# Handle unexpected errors
|
||||
logger.error(f"Unexpected error in speech generation: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail={"error": "Server error", "message": str(e)}
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "processing_error",
|
||||
"message": str(e),
|
||||
"type": "server_error"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
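As a quick illustration of the dual-output path above (streaming while writing to a temp file and exposing X-Download-Path), here is a minimal client sketch. It is not part of the diff: the host, port, /v1 prefix, voice name, and the return_download_link field are assumptions taken from this changeset.

# Hypothetical client for the streaming + download-link flow above.
# Host, port, route prefix, and payload fields are assumptions from this diff.
import requests

payload = {
    "model": "kokoro",
    "input": "Hello there, this is a streaming test.",
    "voice": "af_bella",              # placeholder voice name
    "response_format": "mp3",
    "stream": True,
    "return_download_link": True,     # triggers the dual_output() wrapper above
}
with requests.post("http://localhost:8880/v1/audio/speech", json=payload, stream=True) as r:
    r.raise_for_status()
    download_path = r.headers.get("X-Download-Path")  # e.g. /download/tmpXXXX.mp3
    with open("speech.mp3", "wb") as f:
        for chunk in r.iter_content(chunk_size=8192):
            if chunk:
                f.write(chunk)
print("Server-side copy advertised at:", download_path)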
@router.get("/download/{filename}")
|
||||
async def download_audio_file(filename: str):
|
||||
"""Download a generated audio file from temp storage"""
|
||||
try:
|
||||
from ..core.paths import _find_file, get_content_type
|
||||
|
||||
# Search for file in temp directory
|
||||
file_path = await _find_file(
|
||||
filename=filename,
|
||||
search_paths=[settings.temp_file_dir]
|
||||
)
|
||||
|
||||
# Get content type from path helper
|
||||
content_type = await get_content_type(file_path)
|
||||
|
||||
return FileResponse(
|
||||
file_path,
|
||||
media_type=content_type,
|
||||
filename=filename,
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Content-Disposition": f"attachment; filename={filename}"
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error serving download file {filename}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "server_error",
|
||||
"message": "Failed to serve audio file",
|
||||
"type": "server_error"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
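A follow-on sketch for the download route above: once a client has the X-Download-Path value, it can fetch the finalized temp file. The /v1 mount prefix and the filename are placeholders, since neither is pinned down in this hunk.

# Hypothetical retrieval of a file served by download_audio_file() above.
import requests

download_path = "/download/tmpabc123.mp3"   # placeholder; use the X-Download-Path header value
resp = requests.get("http://localhost:8880/v1" + download_path)
resp.raise_for_status()
with open("speech_full.mp3", "wb") as f:
    f.write(resp.content)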
@router.get("/audio/voices")
|
||||
async def list_voices(tts_service: TTSService = Depends(get_tts_service)):
|
||||
async def list_voices():
|
||||
"""List all available voices for text-to-speech"""
|
||||
try:
|
||||
tts_service = await get_tts_service()
|
||||
voices = await tts_service.list_voices()
|
||||
return {"voices": voices}
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing voices: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "server_error",
|
||||
"message": "Failed to retrieve voice list",
|
||||
"type": "server_error"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.post("/audio/voices/combine")
|
||||
async def combine_voices(
|
||||
request: Union[str, List[str]], tts_service: TTSService = Depends(get_tts_service)
|
||||
):
|
||||
async def combine_voices(request: Union[str, List[str]]):
|
||||
"""Combine multiple voices into a new voice.
|
||||
|
||||
Args:
|
||||
|
@ -174,18 +364,38 @@ async def combine_voices(
|
|||
- 500: Server error (file system issues, combination failed)
|
||||
"""
|
||||
try:
|
||||
tts_service = await get_tts_service()
|
||||
combined_voice = await process_voices(request, tts_service)
|
||||
voices = await tts_service.list_voices()
|
||||
return {"voices": voices, "voice": combined_voice}
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Invalid voice combination request: {str(e)}")
|
||||
logger.warning(f"Invalid voice combination request: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=400, detail={"error": "Invalid request", "message": str(e)}
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "validation_error",
|
||||
"message": str(e),
|
||||
"type": "invalid_request_error"
|
||||
}
|
||||
)
|
||||
except RuntimeError as e:
|
||||
logger.error(f"Voice combination processing error: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "processing_error",
|
||||
"message": "Failed to process voice combination request",
|
||||
"type": "server_error"
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Server error during voice combination: {str(e)}")
|
||||
logger.error(f"Unexpected error in voice combination: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail={"error": "Server error", "message": "Server error"}
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "server_error",
|
||||
"message": "An unexpected error occurred",
|
||||
"type": "server_error"
|
||||
}
|
||||
)
|
||||
|
|
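A hedged sketch of calling the voice-combination endpoint above: the handler accepts Union[str, List[str]], so a JSON list of voice names is one plausible request shape. The route prefix and voice names are placeholders.

# Hypothetical call to the /audio/voices/combine route above.
import requests

resp = requests.post(
    "http://localhost:8880/v1/audio/voices/combine",
    json=["af_bella", "af_sky"],   # Union[str, List[str]] per the handler signature
)
resp.raise_for_status()
data = resp.json()
print("Combined voice:", data["voice"])
print("Now available:", data["voices"])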
48
api/src/routers/web_player.py
Normal file
|
@ -0,0 +1,48 @@
|
|||
"""Web player router with async file serving."""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import Response
|
||||
from loguru import logger
|
||||
|
||||
from ..core.config import settings
|
||||
from ..core.paths import get_web_file_path, read_bytes, get_content_type
|
||||
|
||||
router = APIRouter(
|
||||
tags=["Web Player"],
|
||||
responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
@router.get("/{filename:path}")
|
||||
async def serve_web_file(filename: str):
|
||||
"""Serve web player static files asynchronously."""
|
||||
if not settings.enable_web_player:
|
||||
raise HTTPException(status_code=404, detail="Web player is disabled")
|
||||
|
||||
try:
|
||||
# Default to index.html for root path
|
||||
if filename == "" or filename == "/":
|
||||
filename = "index.html"
|
||||
|
||||
# Get file path
|
||||
file_path = await get_web_file_path(filename)
|
||||
|
||||
# Read file content
|
||||
content = await read_bytes(file_path)
|
||||
|
||||
# Get content type
|
||||
content_type = await get_content_type(file_path)
|
||||
|
||||
return Response(
|
||||
content=content,
|
||||
media_type=content_type,
|
||||
headers={
|
||||
"Cache-Control": "no-cache", # Prevent caching during development
|
||||
}
|
||||
)
|
||||
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"Web file not found: {filename}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error serving web file {filename}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
|
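For completeness, a tiny sketch of hitting the static route above once the app is running; the /web mount prefix is an assumption (the router's prefix is not shown in this hunk), and the check only prints the status and content type.

# Hypothetical smoke test for serve_web_file() above; prefix is assumed.
import requests

resp = requests.get("http://localhost:8880/web/")   # should fall back to index.html
print(resp.status_code, resp.headers.get("content-type"))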
@ -10,37 +10,41 @@ from loguru import logger
|
|||
from pydub import AudioSegment
|
||||
|
||||
from ..core.config import settings
|
||||
|
||||
from .streaming_audio_writer import StreamingAudioWriter
|
||||
|
||||
class AudioNormalizer:
|
||||
"""Handles audio normalization state for a single stream"""
|
||||
|
||||
def __init__(self):
|
||||
self.int16_max = np.iinfo(np.int16).max
|
||||
self.chunk_trim_ms = settings.gap_trim_ms
|
||||
self.sample_rate = 24000 # Sample rate of the audio
|
||||
self.samples_to_trim = int(self.chunk_trim_ms * self.sample_rate / 1000)
|
||||
|
||||
def normalize(
|
||||
self, audio_data: np.ndarray, is_last_chunk: bool = False
|
||||
) -> np.ndarray:
|
||||
"""Convert audio data to int16 range and trim chunk boundaries"""
|
||||
if len(audio_data) == 0:
|
||||
raise ValueError("Audio data cannot be empty")
|
||||
async def normalize(self, audio_data: np.ndarray) -> np.ndarray:
|
||||
"""Convert audio data to int16 range and trim silence from start and end
|
||||
|
||||
Args:
|
||||
audio_data: Input audio data as numpy array
|
||||
|
||||
# Simple float32 to int16 conversion
|
||||
audio_float = audio_data.astype(np.float32)
|
||||
Returns:
|
||||
Normalized and trimmed audio data
|
||||
"""
|
||||
if len(audio_data) == 0:
|
||||
raise ValueError("Empty audio data")
|
||||
|
||||
# Trim start and end if enough samples
|
||||
if len(audio_data) > (2 * self.samples_to_trim):
|
||||
audio_data = audio_data[self.samples_to_trim:-self.samples_to_trim]
|
||||
|
||||
# Trim for non-final chunks
|
||||
if not is_last_chunk and len(audio_float) > self.samples_to_trim:
|
||||
audio_float = audio_float[:-self.samples_to_trim]
|
||||
|
||||
# Direct scaling like the non-streaming version
|
||||
return (audio_float * 32767).astype(np.int16)
|
||||
# Scale directly to int16 range with clipping
|
||||
return np.clip(audio_data * 32767, -32768, 32767).astype(np.int16)
|
||||
|
||||
|
||||
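A small worked illustration of the scaling and clipping step above (not part of the diff): float samples in [-1.0, 1.0] map onto the int16 range, and anything outside that range is clipped rather than wrapped.

# Numeric illustration of the normalize() scaling above.
import numpy as np

samples = np.array([0.0, 0.5, 1.0, 1.2, -1.5], dtype=np.float32)
scaled = np.clip(samples * 32767, -32768, 32767).astype(np.int16)
print(scaled.tolist())   # [0, 16383, 32767, 32767, -32768]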
class AudioService:
|
||||
"""Service for audio format conversions"""
|
||||
"""Service for audio format conversions with streaming support"""
|
||||
|
||||
# Supported formats
|
||||
SUPPORTED_FORMATS = {"wav", "mp3", "opus", "flac", "aac", "pcm", "ogg"}
|
||||
|
||||
# Default audio format settings balanced for speed and compression
|
||||
DEFAULT_SETTINGS = {
|
||||
|
@ -59,137 +63,60 @@ class AudioService:
|
|||
},
|
||||
}
|
||||
|
||||
_writers = {}
|
||||
|
||||
@staticmethod
|
||||
def convert_audio(
|
||||
async def convert_audio(
|
||||
audio_data: np.ndarray,
|
||||
sample_rate: int,
|
||||
output_format: str,
|
||||
is_first_chunk: bool = True,
|
||||
is_last_chunk: bool = False,
|
||||
normalizer: AudioNormalizer = None,
|
||||
format_settings: dict = None,
|
||||
stream: bool = True,
|
||||
) -> bytes:
|
||||
"""Convert audio data to specified format
|
||||
"""Convert audio data to specified format with streaming support
|
||||
|
||||
Args:
|
||||
audio_data: Numpy array of audio samples
|
||||
sample_rate: Sample rate of the audio
|
||||
output_format: Target format (wav, mp3, opus, flac, pcm)
|
||||
is_first_chunk: Whether this is the first chunk of a stream
|
||||
normalizer: Optional AudioNormalizer instance for consistent normalization across chunks
|
||||
format_settings: Optional dict of format-specific settings to override defaults
|
||||
Example: {
|
||||
"mp3": {
|
||||
"bitrate_mode": "VARIABLE",
|
||||
"compression_level": 0.8
|
||||
}
|
||||
}
|
||||
Default settings balance speed and compression:
|
||||
optimized for localhost @ 0.0
|
||||
- MP3: constant bitrate, no compression (0.0)
|
||||
- OPUS: no compression (0.0)
|
||||
- FLAC: no compression (0.0)
|
||||
output_format: Target format (wav, mp3, ogg, pcm)
|
||||
is_first_chunk: Whether this is the first chunk
|
||||
is_last_chunk: Whether this is the last chunk
|
||||
normalizer: Optional AudioNormalizer instance for consistent normalization
|
||||
|
||||
Returns:
|
||||
Bytes of the converted audio
|
||||
Bytes of the converted audio chunk
|
||||
"""
|
||||
buffer = BytesIO()
|
||||
|
||||
try:
|
||||
# Validate format
|
||||
if output_format not in AudioService.SUPPORTED_FORMATS:
|
||||
raise ValueError(f"Format {output_format} not supported")
|
||||
|
||||
# Always normalize audio to ensure proper amplitude scaling
|
||||
if normalizer is None:
|
||||
normalizer = AudioNormalizer()
|
||||
normalized_audio = normalizer.normalize(
|
||||
audio_data, is_last_chunk=is_last_chunk
|
||||
)
|
||||
normalized_audio = await normalizer.normalize(audio_data)
|
||||
|
||||
if output_format == "pcm":
|
||||
# Raw 16-bit PCM samples, no header
|
||||
buffer.write(normalized_audio.tobytes())
|
||||
elif output_format == "wav":
|
||||
# Write the WAV header ourselves so that we can specify a "fake" data size.
|
||||
# This is necessary for streaming responses to work properly: if we simply
|
||||
# concatenated individual WAV files then the initial chunk's header length
|
||||
# would be shorter than the full file length and subsequent chunks' RIFF
|
||||
# headers would appear in the middle of the audio data.
|
||||
if is_first_chunk:
|
||||
# Modified from Python stdlib's wave.py module:
|
||||
buffer.write(b'RIFF')
|
||||
buffer.write(struct.pack('<L4s4sLHHLLHH4s',
|
||||
0xFFFFFFFF, # total size (set to max)
|
||||
b'WAVE',
|
||||
b'fmt ',
|
||||
16,
|
||||
1, # PCM format
|
||||
1, # channels
|
||||
sample_rate,
|
||||
sample_rate * 2, # byte rate
|
||||
2, # block align
|
||||
16, # bits per sample
|
||||
b'data'
|
||||
))
|
||||
buffer.write(struct.pack('<L', 0xFFFFFFFF)) # data size (set to max)
|
||||
# write raw PCM data
|
||||
buffer.write(normalized_audio.tobytes())
|
||||
elif output_format == "mp3":
|
||||
# MP3 format with proper framing
|
||||
settings = format_settings.get("mp3", {}) if format_settings else {}
|
||||
settings = {**AudioService.DEFAULT_SETTINGS["mp3"], **settings}
|
||||
sf.write(
|
||||
buffer, normalized_audio, sample_rate, format="MP3", **settings
|
||||
)
|
||||
elif output_format == "opus":
|
||||
# Opus format in OGG container
|
||||
settings = format_settings.get("opus", {}) if format_settings else {}
|
||||
settings = {**AudioService.DEFAULT_SETTINGS["opus"], **settings}
|
||||
sf.write(
|
||||
buffer,
|
||||
normalized_audio,
|
||||
sample_rate,
|
||||
format="OGG",
|
||||
subtype="OPUS",
|
||||
**settings,
|
||||
)
|
||||
elif output_format == "flac":
|
||||
# FLAC format with proper framing
|
||||
if is_first_chunk:
|
||||
logger.info("Starting FLAC stream...")
|
||||
settings = format_settings.get("flac", {}) if format_settings else {}
|
||||
settings = {**AudioService.DEFAULT_SETTINGS["flac"], **settings}
|
||||
sf.write(
|
||||
buffer,
|
||||
normalized_audio,
|
||||
sample_rate,
|
||||
format="FLAC",
|
||||
subtype="PCM_16",
|
||||
**settings,
|
||||
)
|
||||
elif output_format == "aac":
|
||||
# Convert numpy array directly to AAC using pydub
|
||||
audio_segment = AudioSegment(
|
||||
normalized_audio.tobytes(),
|
||||
frame_rate=sample_rate,
|
||||
sample_width=normalized_audio.dtype.itemsize,
|
||||
channels=1 if len(normalized_audio.shape) == 1 else normalized_audio.shape[1]
|
||||
)
|
||||
|
||||
settings = format_settings.get("aac", {}) if format_settings else {}
|
||||
settings = {**AudioService.DEFAULT_SETTINGS["aac"], **settings}
|
||||
|
||||
audio_segment.export(
|
||||
buffer,
|
||||
format="adts", # ADTS is a common AAC container format
|
||||
bitrate=settings["bitrate"]
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Format {output_format} not supported. Supported formats are: wav, mp3, opus, flac, pcm, aac."
|
||||
# Get or create format-specific writer
|
||||
writer_key = f"{output_format}_{sample_rate}"
|
||||
if is_first_chunk or writer_key not in AudioService._writers:
|
||||
AudioService._writers[writer_key] = StreamingAudioWriter(
|
||||
output_format, sample_rate
|
||||
)
|
||||
writer = AudioService._writers[writer_key]
|
||||
|
||||
buffer.seek(0)
|
||||
return buffer.getvalue()
|
||||
# Write audio data first
|
||||
if len(normalized_audio) > 0:
|
||||
chunk_data = writer.write_chunk(normalized_audio)
|
||||
|
||||
# Then finalize if this is the last chunk
|
||||
if is_last_chunk:
|
||||
final_data = writer.write_chunk(finalize=True)
|
||||
del AudioService._writers[writer_key]
|
||||
return final_data if final_data else b''
|
||||
|
||||
return chunk_data if chunk_data else b''
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error converting audio to {output_format}: {str(e)}")
|
||||
raise ValueError(f"Failed to convert audio to {output_format}: {str(e)}")
|
||||
logger.error(f"Error converting audio stream to {output_format}: {str(e)}")
|
||||
raise ValueError(f"Failed to convert audio stream to {output_format}: {str(e)}")
|
||||
|
|
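Putting the new convert_audio() flow together, here is a hedged usage sketch: each chunk is converted with a shared normalizer, is_first_chunk marks the call that creates the format writer, and a final call with is_last_chunk flushes it (mirroring the dummy-sample finalization used elsewhere in this diff). The import path and the synthetic chunks are assumptions.

# Hypothetical streaming use of AudioService.convert_audio as restructured above.
import asyncio
import numpy as np
from api.src.services.audio import AudioService, AudioNormalizer   # path assumed from this diff

async def encode_stream(chunks, fmt="mp3", sample_rate=24000):
    normalizer = AudioNormalizer()
    out = bytearray()
    for i, chunk in enumerate(chunks):
        out += await AudioService.convert_audio(
            chunk, sample_rate, fmt,
            is_first_chunk=(i == 0),
            is_last_chunk=False,
            normalizer=normalizer,
        )
    # Final flush: dummy sample with is_last_chunk=True, as the service layer does
    out += await AudioService.convert_audio(
        np.zeros(1, dtype=np.float32), sample_rate, fmt,
        is_first_chunk=False, is_last_chunk=True, normalizer=normalizer,
    )
    return bytes(out)

fake_chunks = [np.random.uniform(-1, 1, 24000).astype(np.float32) for _ in range(3)]
mp3_bytes = asyncio.run(encode_stream(fake_chunks))
print(len(mp3_bytes), "bytes of mp3")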
207
api/src/services/streaming_audio_writer.py
Normal file
|
@ -0,0 +1,207 @@
|
|||
"""Audio conversion service with proper streaming support"""
|
||||
|
||||
from io import BytesIO
|
||||
import struct
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
from loguru import logger
|
||||
from pydub import AudioSegment
|
||||
|
||||
class StreamingAudioWriter:
|
||||
"""Handles streaming audio format conversions"""
|
||||
|
||||
def __init__(self, format: str, sample_rate: int, channels: int = 1):
|
||||
self.format = format.lower()
|
||||
self.sample_rate = sample_rate
|
||||
self.channels = channels
|
||||
self.bytes_written = 0
|
||||
self.buffer = BytesIO()
|
||||
|
||||
# Format-specific setup
|
||||
if self.format == "wav":
|
||||
self._write_wav_header_initial()
|
||||
elif self.format in ["ogg", "opus"]:
|
||||
# For OGG/Opus, write to memory buffer
|
||||
self.writer = sf.SoundFile(
|
||||
file=self.buffer,
|
||||
mode='w',
|
||||
samplerate=sample_rate,
|
||||
channels=channels,
|
||||
format='OGG',
|
||||
subtype='VORBIS' if self.format == "ogg" else "OPUS"
|
||||
)
|
||||
elif self.format == "flac":
|
||||
# For FLAC, write to memory buffer
|
||||
self.writer = sf.SoundFile(
|
||||
file=self.buffer,
|
||||
mode='w',
|
||||
samplerate=sample_rate,
|
||||
channels=channels,
|
||||
format='FLAC'
|
||||
)
|
||||
elif self.format in ["mp3", "aac"]:
|
||||
# For MP3/AAC, we'll use pydub's incremental writer
|
||||
self.segments = [] # Store segments until we have enough data
|
||||
self.total_duration = 0 # Track total duration in milliseconds
|
||||
# Initialize an empty AudioSegment as our encoder
|
||||
self.encoder = AudioSegment.silent(duration=0, frame_rate=self.sample_rate)
|
||||
elif self.format == "pcm":
|
||||
# PCM doesn't need initialization, we'll write raw bytes
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"Unsupported format: {format}")
|
||||
|
||||
def _write_wav_header_initial(self) -> None:
|
||||
"""Write initial WAV header with placeholders"""
|
||||
self.buffer.write(b'RIFF')
|
||||
self.buffer.write(struct.pack('<L', 0)) # Placeholder for file size
|
||||
self.buffer.write(b'WAVE')
|
||||
self.buffer.write(b'fmt ')
|
||||
self.buffer.write(struct.pack('<L', 16)) # fmt chunk size
|
||||
self.buffer.write(struct.pack('<H', 1)) # PCM format
|
||||
self.buffer.write(struct.pack('<H', self.channels))
|
||||
self.buffer.write(struct.pack('<L', self.sample_rate))
|
||||
self.buffer.write(struct.pack('<L', self.sample_rate * self.channels * 2)) # Byte rate
|
||||
self.buffer.write(struct.pack('<H', self.channels * 2)) # Block align
|
||||
self.buffer.write(struct.pack('<H', 16)) # Bits per sample
|
||||
self.buffer.write(b'data')
|
||||
self.buffer.write(struct.pack('<L', 0)) # Placeholder for data size
|
||||
|
||||
def write_chunk(self, audio_data: Optional[np.ndarray] = None, finalize: bool = False) -> bytes:
|
||||
"""Write a chunk of audio data and return bytes in the target format.
|
||||
|
||||
Args:
|
||||
audio_data: Audio data to write, or None if finalizing
|
||||
finalize: Whether this is the final write to close the stream
|
||||
"""
|
||||
output_buffer = BytesIO()
|
||||
|
||||
if finalize:
|
||||
if self.format == "wav":
|
||||
# Calculate actual file and data sizes
|
||||
file_size = self.bytes_written + 36 # RIFF header bytes
|
||||
data_size = self.bytes_written
|
||||
|
||||
# Seek to the beginning to overwrite the placeholders
|
||||
self.buffer.seek(4)
|
||||
self.buffer.write(struct.pack('<L', file_size))
|
||||
self.buffer.seek(40)
|
||||
self.buffer.write(struct.pack('<L', data_size))
|
||||
|
||||
self.buffer.seek(0)
|
||||
return self.buffer.read()
|
||||
elif self.format in ["ogg", "opus", "flac"]:
|
||||
self.writer.close()
|
||||
return self.buffer.getvalue()
|
||||
elif self.format in ["mp3", "aac"]:
|
||||
if hasattr(self, 'encoder') and len(self.encoder) > 0:
|
||||
format_args = {
|
||||
"mp3": {"format": "mp3", "codec": "libmp3lame"},
|
||||
"aac": {"format": "adts", "codec": "aac"}
|
||||
}[self.format]
|
||||
|
||||
parameters = []
|
||||
if self.format == "mp3":
|
||||
parameters.extend([
|
||||
"-q:a", "2",
|
||||
"-write_xing", "1", # XING header for MP3
|
||||
"-id3v1", "1",
|
||||
"-id3v2", "1",
|
||||
"-write_vbr", "1",
|
||||
"-vbr_quality", "2"
|
||||
])
|
||||
elif self.format == "aac":
|
||||
parameters.extend([
|
||||
"-q:a", "2",
|
||||
"-write_xing", "0",
|
||||
"-write_id3v1", "0",
|
||||
"-write_id3v2", "0"
|
||||
])
|
||||
|
||||
self.encoder.export(
|
||||
output_buffer,
|
||||
**format_args,
|
||||
bitrate="192k",
|
||||
parameters=parameters
|
||||
)
|
||||
self.encoder = None
|
||||
|
||||
return output_buffer.getvalue()
|
||||
|
||||
if audio_data is None or len(audio_data) == 0:
|
||||
return b''
|
||||
|
||||
if self.format == "wav":
|
||||
# Write raw PCM data
|
||||
self.buffer.write(audio_data.tobytes())
|
||||
self.bytes_written += len(audio_data.tobytes())
|
||||
return b''
|
||||
|
||||
elif self.format in ["ogg", "opus", "flac"]:
|
||||
# Write to soundfile buffer
|
||||
self.writer.write(audio_data)
|
||||
self.writer.flush()
|
||||
return self.buffer.getvalue()
|
||||
|
||||
elif self.format in ["mp3", "aac"]:
|
||||
# Convert chunk to AudioSegment and encode
|
||||
segment = AudioSegment(
|
||||
audio_data.tobytes(),
|
||||
frame_rate=self.sample_rate,
|
||||
sample_width=audio_data.dtype.itemsize,
|
||||
channels=self.channels
|
||||
)
|
||||
|
||||
# Track total duration
|
||||
self.total_duration += len(segment)
|
||||
|
||||
# Add segment to encoder
|
||||
self.encoder += segment
|
||||
|
||||
# Export current state to buffer without final metadata
|
||||
format_args = {
|
||||
"mp3": {"format": "mp3", "codec": "libmp3lame"},
|
||||
"aac": {"format": "adts", "codec": "aac"}
|
||||
}[self.format]
|
||||
|
||||
# For chunks, export without duration metadata or XING headers
|
||||
self.encoder.export(output_buffer, **format_args, bitrate="192k", parameters=[
|
||||
"-q:a", "2",
|
||||
"-write_xing", "0" # No XING headers for chunks
|
||||
])
|
||||
|
||||
# Get the encoded data
|
||||
encoded_data = output_buffer.getvalue()
|
||||
|
||||
# Reset encoder to prevent memory growth
|
||||
self.encoder = AudioSegment.silent(duration=0, frame_rate=self.sample_rate)
|
||||
|
||||
return encoded_data
|
||||
|
||||
elif self.format == "pcm":
|
||||
# Write raw bytes
|
||||
return audio_data.tobytes()
|
||||
|
||||
return b''
|
||||
|
||||
def close(self) -> Optional[bytes]:
|
||||
"""Finish the audio file and return any remaining data"""
|
||||
if self.format == "wav":
|
||||
# Re-finalize WAV file by updating headers
|
||||
self.buffer.seek(0)
|
||||
file_content = self.write_chunk(finalize=True)
|
||||
return file_content
|
||||
|
||||
elif self.format in ["ogg", "opus", "flac"]:
|
||||
# Finalize other formats
|
||||
self.writer.close()
|
||||
return self.buffer.getvalue()
|
||||
|
||||
elif self.format in ["mp3", "aac"]:
|
||||
# Finalize MP3/AAC
|
||||
final_data = self.write_chunk(finalize=True)
|
||||
return final_data
|
||||
|
||||
return None
|
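A hedged sketch of driving StreamingAudioWriter directly, outside AudioService: feed int16 PCM chunks, collect whatever bytes each call returns (some formats buffer until later calls), then finalize. The int16 input assumption follows from sample_width being taken from the array dtype above.

# Hypothetical direct use of StreamingAudioWriter defined above.
import numpy as np
from api.src.services.streaming_audio_writer import StreamingAudioWriter   # path assumed

writer = StreamingAudioWriter("mp3", sample_rate=24000, channels=1)
encoded = bytearray()

for _ in range(3):
    chunk = (np.random.uniform(-1, 1, 24000) * 32767).astype(np.int16)   # int16 PCM expected
    encoded += writer.write_chunk(chunk)

encoded += writer.write_chunk(finalize=True)   # flush encoder / fix up headers
with open("stream_test.mp3", "wb") as f:
    f.write(bytes(encoded))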
139
api/src/services/temp_manager.py
Normal file
|
@ -0,0 +1,139 @@
|
|||
"""Temporary file writer for audio downloads"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Optional, List
|
||||
|
||||
import aiofiles
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from ..core.config import settings
|
||||
|
||||
|
||||
async def cleanup_temp_files() -> None:
|
||||
"""Clean up old temp files"""
|
||||
try:
|
||||
if not await aiofiles.os.path.exists(settings.temp_file_dir):
|
||||
await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)
|
||||
return
|
||||
|
||||
# Get all temp files with stats
|
||||
files = []
|
||||
total_size = 0
|
||||
|
||||
# Use os.scandir for sync iteration, but aiofiles.os.stat for async stats
|
||||
for entry in os.scandir(settings.temp_file_dir):
|
||||
if entry.is_file():
|
||||
stat = await aiofiles.os.stat(entry.path)
|
||||
files.append((entry.path, stat.st_mtime, stat.st_size))
|
||||
total_size += stat.st_size
|
||||
|
||||
# Sort by modification time (oldest first)
|
||||
files.sort(key=lambda x: x[1])
|
||||
|
||||
# Remove files if:
|
||||
# 1. They're too old
|
||||
# 2. We have too many files
|
||||
# 3. Directory is too large
|
||||
current_time = (await aiofiles.os.stat(settings.temp_file_dir)).st_mtime
|
||||
max_age = settings.max_temp_dir_age_hours * 3600
|
||||
|
||||
for path, mtime, size in files:
|
||||
should_delete = False
|
||||
|
||||
# Check age
|
||||
if current_time - mtime > max_age:
|
||||
should_delete = True
|
||||
logger.info(f"Deleting old temp file: {path}")
|
||||
|
||||
# Check count limit
|
||||
elif len(files) > settings.max_temp_dir_count:
|
||||
should_delete = True
|
||||
logger.info(f"Deleting excess temp file: {path}")
|
||||
|
||||
# Check size limit
|
||||
elif total_size > settings.max_temp_dir_size_mb * 1024 * 1024:
|
||||
should_delete = True
|
||||
logger.info(f"Deleting to reduce directory size: {path}")
|
||||
|
||||
if should_delete:
|
||||
try:
|
||||
await aiofiles.os.remove(path)
|
||||
total_size -= size
|
||||
logger.info(f"Deleted temp file: {path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete temp file {path}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error during temp file cleanup: {e}")
|
||||
|
||||
|
||||
class TempFileWriter:
|
||||
"""Handles writing audio chunks to a temp file"""
|
||||
|
||||
def __init__(self, format: str):
|
||||
"""Initialize temp file writer
|
||||
|
||||
Args:
|
||||
format: Audio format extension (mp3, wav, etc)
|
||||
"""
|
||||
self.format = format
|
||||
self.temp_file = None
|
||||
self._finalized = False
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Async context manager entry"""
|
||||
# Clean up old files first
|
||||
await cleanup_temp_files()
|
||||
|
||||
# Create temp file with proper extension
|
||||
await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)
|
||||
temp = tempfile.NamedTemporaryFile(
|
||||
dir=settings.temp_file_dir,
|
||||
delete=False,
|
||||
suffix=f".{self.format}",
|
||||
mode='wb'
|
||||
)
|
||||
self.temp_file = await aiofiles.open(temp.name, mode='wb')
|
||||
self.temp_path = temp.name
|
||||
temp.close() # Close sync file, we'll use async version
|
||||
|
||||
# Generate download path immediately
|
||||
self.download_path = f"/download/{os.path.basename(self.temp_path)}"
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Async context manager exit"""
|
||||
try:
|
||||
if self.temp_file and not self._finalized:
|
||||
await self.temp_file.close()
|
||||
self._finalized = True
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing temp file: {e}")
|
||||
|
||||
async def write(self, chunk: bytes) -> None:
|
||||
"""Write a chunk of audio data
|
||||
|
||||
Args:
|
||||
chunk: Audio data bytes to write
|
||||
"""
|
||||
if self._finalized:
|
||||
raise RuntimeError("Cannot write to finalized temp file")
|
||||
|
||||
await self.temp_file.write(chunk)
|
||||
await self.temp_file.flush()
|
||||
|
||||
async def finalize(self) -> str:
|
||||
"""Close temp file and return download path
|
||||
|
||||
Returns:
|
||||
Path to use for downloading the temp file
|
||||
"""
|
||||
if self._finalized:
|
||||
raise RuntimeError("Temp file already finalized")
|
||||
|
||||
await self.temp_file.close()
|
||||
self._finalized = True
|
||||
|
||||
return f"/download/{os.path.basename(self.temp_path)}"
|
|
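A hedged sketch of TempFileWriter as an async context manager, mirroring how the streaming route earlier in this diff wraps its generator; it assumes the project's settings (temp_file_dir and the cleanup limits) are configured.

# Hypothetical standalone use of TempFileWriter from this diff.
import asyncio
from api.src.services.temp_manager import TempFileWriter   # path assumed

async def save_chunks(chunks):
    async with TempFileWriter("mp3") as writer:
        for chunk in chunks:
            await writer.write(chunk)
        return await writer.finalize()   # e.g. "/download/tmpXXXX.mp3"

download_path = asyncio.run(save_chunks([b"fake-mp3-bytes"]))
print(download_path)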
@ -1,13 +1,19 @@
|
|||
"""Text processing pipeline."""
|
||||
|
||||
from .normalizer import normalize_text
|
||||
from .phonemizer import EspeakBackend, PhonemizerBackend, phonemize
|
||||
from .vocabulary import VOCAB, decode_tokens, tokenize
|
||||
from .phonemizer import phonemize
|
||||
from .vocabulary import tokenize
|
||||
from .text_processor import process_text_chunk, smart_split
|
||||
|
||||
def process_text(text: str) -> list[int]:
|
||||
"""Process text into token IDs (for backward compatibility)."""
|
||||
return process_text_chunk(text)
|
||||
|
||||
__all__ = [
|
||||
"normalize_text",
|
||||
"phonemize",
|
||||
"tokenize",
|
||||
"decode_tokens",
|
||||
"VOCAB",
|
||||
"PhonemizerBackend",
|
||||
"EspeakBackend",
|
||||
'normalize_text',
|
||||
'phonemize',
|
||||
'tokenize',
|
||||
'process_text',
|
||||
'process_text_chunk',
|
||||
'smart_split'
|
||||
]
|
||||
|
|
|
@ -1,53 +0,0 @@
|
|||
"""Text chunking service"""
|
||||
|
||||
import re
|
||||
|
||||
from ...core.config import settings
|
||||
|
||||
|
||||
def split_text(text: str, max_chunk=None):
|
||||
"""Split text into chunks on natural pause points
|
||||
|
||||
Args:
|
||||
text: Text to split into chunks
|
||||
max_chunk: Maximum chunk size (defaults to settings.max_chunk_size)
|
||||
"""
|
||||
if max_chunk is None:
|
||||
max_chunk = settings.max_chunk_size
|
||||
|
||||
if not isinstance(text, str):
|
||||
text = str(text) if text is not None else ""
|
||||
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return
|
||||
|
||||
# First split into sentences
|
||||
sentences = re.split(r"(?<=[.!?])\s+", text)
|
||||
|
||||
for sentence in sentences:
|
||||
sentence = sentence.strip()
|
||||
if not sentence:
|
||||
continue
|
||||
|
||||
# For medium-length sentences, split on punctuation
|
||||
if len(sentence) > max_chunk: # Lower threshold for more consistent sizes
|
||||
# First try splitting on semicolons and colons
|
||||
parts = re.split(r"(?<=[;:])\s+", sentence)
|
||||
|
||||
for part in parts:
|
||||
part = part.strip()
|
||||
if not part:
|
||||
continue
|
||||
|
||||
# If part is still long, split on commas
|
||||
if len(part) > max_chunk:
|
||||
subparts = re.split(r"(?<=,)\s+", part)
|
||||
for subpart in subparts:
|
||||
subpart = subpart.strip()
|
||||
if subpart:
|
||||
yield subpart
|
||||
else:
|
||||
yield part
|
||||
else:
|
||||
yield sentence
|
203
api/src/services/text_processing/text_processor.py
Normal file
|
@ -0,0 +1,203 @@
|
|||
"""Unified text processing for TTS with smart chunking."""
|
||||
|
||||
import re
|
||||
import time
|
||||
from typing import AsyncGenerator, List, Tuple
|
||||
from loguru import logger
|
||||
from .phonemizer import phonemize
|
||||
from .normalizer import normalize_text
|
||||
from .vocabulary import tokenize
|
||||
|
||||
# Target token ranges
|
||||
TARGET_MIN = 200
|
||||
TARGET_MAX = 350
|
||||
ABSOLUTE_MAX = 500
|
||||
|
||||
def process_text_chunk(text: str, language: str = "a", skip_phonemize: bool = False) -> List[int]:
|
||||
"""Process a chunk of text through normalization, phonemization, and tokenization.
|
||||
|
||||
Args:
|
||||
text: Text chunk to process
|
||||
language: Language code for phonemization
|
||||
skip_phonemize: If True, treat input as phonemes and skip normalization/phonemization
|
||||
|
||||
Returns:
|
||||
List of token IDs
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
if skip_phonemize:
|
||||
# Input is already phonemes, just tokenize
|
||||
t0 = time.time()
|
||||
tokens = tokenize(text)
|
||||
t1 = time.time()
|
||||
else:
|
||||
# Normal text processing pipeline
|
||||
t0 = time.time()
|
||||
normalized = normalize_text(text)
|
||||
t1 = time.time()
|
||||
|
||||
|
||||
t0 = time.time()
|
||||
phonemes = phonemize(normalized, language, normalize=False) # Already normalized
|
||||
t1 = time.time()
|
||||
|
||||
t0 = time.time()
|
||||
tokens = tokenize(phonemes)
|
||||
t1 = time.time()
|
||||
|
||||
total_time = time.time() - start_time
|
||||
logger.debug(f"Total processing took {total_time*1000:.2f}ms for chunk: '{text[:50]}...'")
|
||||
|
||||
return tokens
|
||||
|
||||
async def yield_chunk(text: str, tokens: List[int], chunk_count: int) -> Tuple[str, List[int]]:
|
||||
"""Yield a chunk with consistent logging."""
|
||||
logger.debug(f"Yielding chunk {chunk_count}: '{text[:50]}...' ({len(tokens)} tokens)")
|
||||
return text, tokens
|
||||
|
||||
def process_text(text: str, language: str = "a") -> List[int]:
|
||||
"""Process text into token IDs.
|
||||
|
||||
Args:
|
||||
text: Text to process
|
||||
language: Language code for phonemization
|
||||
|
||||
Returns:
|
||||
List of token IDs
|
||||
"""
|
||||
if not isinstance(text, str):
|
||||
text = str(text) if text is not None else ""
|
||||
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return []
|
||||
|
||||
return process_text_chunk(text, language)
|
||||
|
||||
|
||||
def get_sentence_info(text: str) -> List[Tuple[str, List[int], int]]:
|
||||
"""Process all sentences and return info."""
|
||||
sentences = re.split(r'([.!?;:])', text)
|
||||
results = []
|
||||
|
||||
for i in range(0, len(sentences), 2):
|
||||
sentence = sentences[i].strip()
|
||||
punct = sentences[i + 1] if i + 1 < len(sentences) else ""
|
||||
|
||||
if not sentence:
|
||||
continue
|
||||
|
||||
full = sentence + punct
|
||||
tokens = process_text_chunk(full)
|
||||
results.append((full, tokens, len(tokens)))
|
||||
|
||||
return results
|
||||
|
||||
async def smart_split(text: str, max_tokens: int = ABSOLUTE_MAX) -> AsyncGenerator[Tuple[str, List[int]], None]:
|
||||
"""Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens."""
|
||||
start_time = time.time()
|
||||
chunk_count = 0
|
||||
logger.info(f"Starting smart split for {len(text)} chars")
|
||||
|
||||
# Process all sentences
|
||||
sentences = get_sentence_info(text)
|
||||
|
||||
current_chunk = []
|
||||
current_tokens = []
|
||||
current_count = 0
|
||||
|
||||
for sentence, tokens, count in sentences:
|
||||
# Handle sentences that exceed max tokens
|
||||
if count > max_tokens:
|
||||
# Yield current chunk if any
|
||||
if current_chunk:
|
||||
chunk_text = " ".join(current_chunk)
|
||||
chunk_count += 1
|
||||
logger.debug(f"Yielding chunk {chunk_count}: '{chunk_text[:50]}...' ({current_count} tokens)")
|
||||
yield chunk_text, current_tokens
|
||||
current_chunk = []
|
||||
current_tokens = []
|
||||
current_count = 0
|
||||
|
||||
# Split long sentence on commas
|
||||
clauses = re.split(r'([,])', sentence)
|
||||
clause_chunk = []
|
||||
clause_tokens = []
|
||||
clause_count = 0
|
||||
|
||||
for j in range(0, len(clauses), 2):
|
||||
clause = clauses[j].strip()
|
||||
comma = clauses[j + 1] if j + 1 < len(clauses) else ""
|
||||
|
||||
if not clause:
|
||||
continue
|
||||
|
||||
full_clause = clause + comma
|
||||
tokens = process_text_chunk(full_clause)
|
||||
count = len(tokens)
|
||||
|
||||
# If adding clause keeps us under max and not optimal yet
|
||||
if clause_count + count <= max_tokens and clause_count + count <= TARGET_MAX:
|
||||
clause_chunk.append(full_clause)
|
||||
clause_tokens.extend(tokens)
|
||||
clause_count += count
|
||||
else:
|
||||
# Yield clause chunk if we have one
|
||||
if clause_chunk:
|
||||
chunk_text = " ".join(clause_chunk)
|
||||
chunk_count += 1
|
||||
logger.debug(f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}...' ({clause_count} tokens)")
|
||||
yield chunk_text, clause_tokens
|
||||
clause_chunk = [full_clause]
|
||||
clause_tokens = tokens
|
||||
clause_count = count
|
||||
|
||||
# Don't forget last clause chunk
|
||||
if clause_chunk:
|
||||
chunk_text = " ".join(clause_chunk)
|
||||
chunk_count += 1
|
||||
logger.debug(f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}...' ({clause_count} tokens)")
|
||||
yield chunk_text, clause_tokens
|
||||
|
||||
# Regular sentence handling
|
||||
elif current_count >= TARGET_MIN and current_count + count > TARGET_MAX:
|
||||
# If we have a good sized chunk and adding next sentence exceeds target,
|
||||
# yield current chunk and start new one
|
||||
chunk_text = " ".join(current_chunk)
|
||||
chunk_count += 1
|
||||
logger.info(f"Yielding chunk {chunk_count}: '{chunk_text[:50]}...' ({current_count} tokens)")
|
||||
yield chunk_text, current_tokens
|
||||
current_chunk = [sentence]
|
||||
current_tokens = tokens
|
||||
current_count = count
|
||||
elif current_count + count <= TARGET_MAX:
|
||||
# Keep building chunk while under target max
|
||||
current_chunk.append(sentence)
|
||||
current_tokens.extend(tokens)
|
||||
current_count += count
|
||||
elif current_count + count <= max_tokens and current_count < TARGET_MIN:
|
||||
# Only exceed target max if we haven't reached minimum size yet
|
||||
current_chunk.append(sentence)
|
||||
current_tokens.extend(tokens)
|
||||
current_count += count
|
||||
else:
|
||||
# Yield current chunk and start new one
|
||||
if current_chunk:
|
||||
chunk_text = " ".join(current_chunk)
|
||||
chunk_count += 1
|
||||
logger.info(f"Yielding chunk {chunk_count}: '{chunk_text[:50]}...' ({current_count} tokens)")
|
||||
yield chunk_text, current_tokens
|
||||
current_chunk = [sentence]
|
||||
current_tokens = tokens
|
||||
current_count = count
|
||||
|
||||
# Don't forget the last chunk
|
||||
if current_chunk:
|
||||
chunk_text = " ".join(current_chunk)
|
||||
chunk_count += 1
|
||||
logger.info(f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}...' ({current_count} tokens)")
|
||||
yield chunk_text, current_tokens
|
||||
|
||||
total_time = time.time() - start_time
|
||||
logger.info(f"Split completed in {total_time*1000:.2f}ms, produced {chunk_count} chunks")
|
|
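A hedged driver for smart_split() above, showing how the async generator is consumed; the sample text is arbitrary and the module path is taken from this diff (phonemization requires espeak-ng to be installed).

# Hypothetical driver for smart_split() defined above.
import asyncio
from api.src.services.text_processing.text_processor import smart_split   # path assumed

async def preview_chunks(text: str):
    async for chunk_text, tokens in smart_split(text):
        print(f"{len(tokens):4d} tokens | {chunk_text[:60]!r}")

asyncio.run(preview_chunks(
    "First sentence. Second sentence! A third, noticeably longer sentence, "
    "with clauses that the splitter may keep together or break apart."
))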
@ -1,175 +0,0 @@
|
|||
import os
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from ..core.config import settings
|
||||
|
||||
|
||||
class TTSBaseModel(ABC):
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
_device = None
|
||||
VOICES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "voices")
|
||||
|
||||
@classmethod
|
||||
async def setup(cls):
|
||||
"""Initialize model and setup voices"""
|
||||
with cls._lock:
|
||||
# Set device
|
||||
cuda_available = torch.cuda.is_available()
|
||||
logger.info(f"CUDA available: {cuda_available}")
|
||||
if cuda_available:
|
||||
try:
|
||||
# Test CUDA device
|
||||
test_tensor = torch.zeros(1).cuda()
|
||||
logger.info("CUDA test successful")
|
||||
model_path = os.path.join(
|
||||
settings.model_dir, settings.pytorch_model_path
|
||||
)
|
||||
cls._device = "cuda"
|
||||
except Exception as e:
|
||||
logger.error(f"CUDA test failed: {e}")
|
||||
cls._device = "cpu"
|
||||
else:
|
||||
cls._device = "cpu"
|
||||
model_path = os.path.join(settings.model_dir, settings.onnx_model_path)
|
||||
logger.info(f"Initializing model on {cls._device}")
|
||||
logger.info(f"Model dir: {settings.model_dir}")
|
||||
logger.info(f"Model path: {model_path}")
|
||||
logger.info(f"Files in model dir: {os.listdir(settings.model_dir)}")
|
||||
|
||||
# Initialize model first
|
||||
model = cls.initialize(settings.model_dir, model_path=model_path)
|
||||
if model is None:
|
||||
raise RuntimeError(f"Failed to initialize {cls._device.upper()} model")
|
||||
cls._instance = model
|
||||
|
||||
# Setup voices directory
|
||||
os.makedirs(cls.VOICES_DIR, exist_ok=True)
|
||||
|
||||
# Copy base voices to local directory
|
||||
base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
|
||||
if os.path.exists(base_voices_dir):
|
||||
for file in os.listdir(base_voices_dir):
|
||||
if file.endswith(".pt"):
|
||||
voice_name = file[:-3]
|
||||
voice_path = os.path.join(cls.VOICES_DIR, file)
|
||||
if not os.path.exists(voice_path):
|
||||
try:
|
||||
logger.info(
|
||||
f"Copying base voice {voice_name} to voices directory"
|
||||
)
|
||||
base_path = os.path.join(base_voices_dir, file)
|
||||
voicepack = torch.load(
|
||||
base_path,
|
||||
map_location=cls._device,
|
||||
weights_only=True,
|
||||
)
|
||||
torch.save(voicepack, voice_path)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error copying voice {voice_name}: {str(e)}"
|
||||
)
|
||||
|
||||
# Count voices in directory
|
||||
voice_count = len(
|
||||
[f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")]
|
||||
)
|
||||
|
||||
# Now that model and voices are ready, do warmup
|
||||
try:
|
||||
with open(
|
||||
os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)),
|
||||
"core",
|
||||
"don_quixote.txt",
|
||||
)
|
||||
) as f:
|
||||
warmup_text = f.read()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load warmup text: {e}")
|
||||
warmup_text = "This is a warmup text that will be split into chunks for processing."
|
||||
|
||||
# Use warmup service after model is fully initialized
|
||||
from .warmup import WarmupService
|
||||
|
||||
warmup = WarmupService()
|
||||
|
||||
# Load and warm up voices
|
||||
loaded_voices = warmup.load_voices()
|
||||
await warmup.warmup_voices(warmup_text, loaded_voices)
|
||||
|
||||
logger.info("Model warm-up complete")
|
||||
|
||||
# Count voices in directory
|
||||
voice_count = len(
|
||||
[f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")]
|
||||
)
|
||||
return voice_count
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def initialize(cls, model_dir: str, model_path: str = None):
|
||||
"""Initialize the model"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def process_text(cls, text: str, language: str) -> Tuple[str, List[int]]:
|
||||
"""Process text into phonemes and tokens
|
||||
|
||||
Args:
|
||||
text: Input text
|
||||
language: Language code
|
||||
|
||||
Returns:
|
||||
tuple[str, list[int]]: Phonemes and token IDs
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def generate_from_text(
|
||||
cls, text: str, voicepack: torch.Tensor, language: str, speed: float
|
||||
) -> Tuple[np.ndarray, str]:
|
||||
"""Generate audio from text
|
||||
|
||||
Args:
|
||||
text: Input text
|
||||
voicepack: Voice tensor
|
||||
language: Language code
|
||||
speed: Speed factor
|
||||
|
||||
Returns:
|
||||
tuple[np.ndarray, str]: Generated audio samples and phonemes
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def generate_from_tokens(
|
||||
cls, tokens: List[int], voicepack: torch.Tensor, speed: float
|
||||
) -> np.ndarray:
|
||||
"""Generate audio from tokens
|
||||
|
||||
Args:
|
||||
tokens: Token IDs
|
||||
voicepack: Voice tensor
|
||||
speed: Speed factor
|
||||
|
||||
Returns:
|
||||
np.ndarray: Generated audio samples
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_device(cls):
|
||||
"""Get the current device"""
|
||||
if cls._device is None:
|
||||
raise RuntimeError("Model not initialized. Call setup() first.")
|
||||
return cls._device
|
|
@ -1,167 +0,0 @@
|
|||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
from onnxruntime import (
|
||||
ExecutionMode,
|
||||
GraphOptimizationLevel,
|
||||
InferenceSession,
|
||||
SessionOptions,
|
||||
)
|
||||
|
||||
from ..core.config import settings
|
||||
from .text_processing import phonemize, tokenize
|
||||
from .tts_base import TTSBaseModel
|
||||
|
||||
|
||||
class TTSCPUModel(TTSBaseModel):
|
||||
_instance = None
|
||||
_onnx_session = None
|
||||
_device = "cpu"
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
"""Get the model instance"""
|
||||
if cls._onnx_session is None:
|
||||
raise RuntimeError("ONNX model not initialized. Call initialize() first.")
|
||||
return cls._onnx_session
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, model_dir: str, model_path: str = None):
|
||||
"""Initialize ONNX model for CPU inference"""
|
||||
if cls._onnx_session is None:
|
||||
try:
|
||||
# Try loading ONNX model
|
||||
onnx_path = os.path.join(model_dir, settings.onnx_model_path)
|
||||
if not os.path.exists(onnx_path):
|
||||
logger.error(f"ONNX model not found at {onnx_path}")
|
||||
return None
|
||||
|
||||
logger.info(f"Loading ONNX model from {onnx_path}")
|
||||
|
||||
# Configure ONNX session for optimal performance
|
||||
session_options = SessionOptions()
|
||||
|
||||
# Set optimization level
|
||||
if settings.onnx_optimization_level == "all":
|
||||
session_options.graph_optimization_level = (
|
||||
GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
)
|
||||
elif settings.onnx_optimization_level == "basic":
|
||||
session_options.graph_optimization_level = (
|
||||
GraphOptimizationLevel.ORT_ENABLE_BASIC
|
||||
)
|
||||
else:
|
||||
session_options.graph_optimization_level = (
|
||||
GraphOptimizationLevel.ORT_DISABLE_ALL
|
||||
)
|
||||
|
||||
# Configure threading
|
||||
session_options.intra_op_num_threads = settings.onnx_num_threads
|
||||
session_options.inter_op_num_threads = settings.onnx_inter_op_threads
|
||||
|
||||
# Set execution mode
|
||||
session_options.execution_mode = (
|
||||
ExecutionMode.ORT_PARALLEL
|
||||
if settings.onnx_execution_mode == "parallel"
|
||||
else ExecutionMode.ORT_SEQUENTIAL
|
||||
)
|
||||
|
||||
# Enable/disable memory pattern optimization
|
||||
session_options.enable_mem_pattern = settings.onnx_memory_pattern
|
||||
|
||||
# Configure CPU provider options
|
||||
provider_options = {
|
||||
"CPUExecutionProvider": {
|
||||
"arena_extend_strategy": settings.onnx_arena_extend_strategy,
|
||||
"cpu_memory_arena_cfg": "cpu:0",
|
||||
}
|
||||
}
|
||||
|
||||
session = InferenceSession(
|
||||
onnx_path,
|
||||
sess_options=session_options,
|
||||
providers=["CPUExecutionProvider"],
|
||||
provider_options=[provider_options],
|
||||
)
|
||||
cls._onnx_session = session
|
||||
return session
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize ONNX model: {e}")
|
||||
return None
|
||||
return cls._onnx_session
|
||||
|
||||
@classmethod
|
||||
def process_text(cls, text: str, language: str) -> tuple[str, list[int]]:
|
||||
"""Process text into phonemes and tokens
|
||||
|
||||
Args:
|
||||
text: Input text
|
||||
language: Language code
|
||||
|
||||
Returns:
|
||||
tuple[str, list[int]]: Phonemes and token IDs
|
||||
"""
|
||||
phonemes = phonemize(text, language)
|
||||
tokens = tokenize(phonemes)
|
||||
tokens = [0] + tokens + [0] # Add start/end tokens
|
||||
return phonemes, tokens
|
||||
|
||||
@classmethod
|
||||
def generate_from_text(
|
||||
cls, text: str, voicepack: torch.Tensor, language: str, speed: float
|
||||
) -> tuple[np.ndarray, str]:
|
||||
"""Generate audio from text
|
||||
|
||||
Args:
|
||||
text: Input text
|
||||
voicepack: Voice tensor
|
||||
language: Language code
|
||||
speed: Speed factor
|
||||
|
||||
Returns:
|
||||
tuple[np.ndarray, str]: Generated audio samples and phonemes
|
||||
"""
|
||||
if cls._onnx_session is None:
|
||||
raise RuntimeError("ONNX model not initialized")
|
||||
|
||||
# Process text
|
||||
phonemes, tokens = cls.process_text(text, language)
|
||||
|
||||
# Generate audio
|
||||
audio = cls.generate_from_tokens(tokens, voicepack, speed)
|
||||
|
||||
return audio, phonemes
|
||||
|
||||
@classmethod
|
||||
def generate_from_tokens(
|
||||
cls, tokens: list[int], voicepack: torch.Tensor, speed: float
|
||||
) -> np.ndarray:
|
||||
"""Generate audio from tokens
|
||||
|
||||
Args:
|
||||
tokens: Token IDs
|
||||
voicepack: Voice tensor
|
||||
speed: Speed factor
|
||||
|
||||
Returns:
|
||||
np.ndarray: Generated audio samples
|
||||
"""
|
||||
if cls._onnx_session is None:
|
||||
raise RuntimeError("ONNX model not initialized")
|
||||
|
||||
# Pre-allocate and prepare inputs
|
||||
tokens_input = np.array([tokens], dtype=np.int64)
|
||||
style_input = voicepack[
|
||||
len(tokens) - 2
|
||||
].numpy() # Already has correct dimensions
|
||||
speed_input = np.full(
|
||||
1, speed, dtype=np.float32
|
||||
) # More efficient than ones * speed
|
||||
|
||||
# Run inference with optimized inputs
|
||||
result = cls._onnx_session.run(
|
||||
None, {"tokens": tokens_input, "style": style_input, "speed": speed_input}
|
||||
)
|
||||
return result[0]
|
|
@ -1,262 +0,0 @@
|
|||
import os
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from ..builds.models import build_model
|
||||
from loguru import logger
|
||||
|
||||
from ..core.config import settings
|
||||
from .text_processing import phonemize, tokenize
|
||||
from .tts_base import TTSBaseModel
|
||||
|
||||
|
||||
# @torch.no_grad()
|
||||
# def forward(model, tokens, ref_s, speed):
|
||||
# """Forward pass through the model"""
|
||||
# device = ref_s.device
|
||||
# tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
|
||||
# input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
|
||||
# text_mask = length_to_mask(input_lengths).to(device)
|
||||
# bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
|
||||
# d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
|
||||
# s = ref_s[:, 128:]
|
||||
# d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
|
||||
# x, _ = model.predictor.lstm(d)
|
||||
# duration = model.predictor.duration_proj(x)
|
||||
# duration = torch.sigmoid(duration).sum(axis=-1) / speed
|
||||
# pred_dur = torch.round(duration).clamp(min=1).long()
|
||||
# pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
|
||||
# c_frame = 0
|
||||
# for i in range(pred_aln_trg.size(0)):
|
||||
# pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
|
||||
# c_frame += pred_dur[0, i].item()
|
||||
# en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
|
||||
# F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
|
||||
# t_en = model.text_encoder(tokens, input_lengths, text_mask)
|
||||
# asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
|
||||
# return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
|
||||
@torch.no_grad()
|
||||
def forward(model, tokens, ref_s, speed):
|
||||
"""Forward pass through the model with moderate memory management"""
|
||||
device = ref_s.device
|
||||
|
||||
try:
|
||||
# Initial tensor setup with proper device placement
|
||||
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
|
||||
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
|
||||
text_mask = length_to_mask(input_lengths).to(device)
|
||||
|
||||
# Split and clone reference signals with explicit device placement
|
||||
s_content = ref_s[:, 128:].clone().to(device)
|
||||
s_ref = ref_s[:, :128].clone().to(device)
|
||||
|
||||
# BERT and encoder pass
|
||||
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
|
||||
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
|
||||
|
||||
# Predictor forward pass
|
||||
d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
|
||||
x, _ = model.predictor.lstm(d)
|
||||
|
||||
# Duration prediction
|
||||
duration = model.predictor.duration_proj(x)
|
||||
duration = torch.sigmoid(duration).sum(axis=-1) / speed
|
||||
pred_dur = torch.round(duration).clamp(min=1).long()
|
||||
# Only cleanup large intermediates
|
||||
del duration, x
|
||||
|
||||
# Alignment matrix construction
|
||||
pred_aln_trg = torch.zeros(input_lengths.item(), pred_dur.sum().item(), device=device)
|
||||
c_frame = 0
|
||||
for i in range(pred_aln_trg.size(0)):
|
||||
pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
|
||||
c_frame += pred_dur[0, i].item()
|
||||
pred_aln_trg = pred_aln_trg.unsqueeze(0)
|
||||
|
||||
# Matrix multiplications with selective cleanup
|
||||
en = d.transpose(-1, -2) @ pred_aln_trg
|
||||
del d # Free large intermediate tensor
|
||||
|
||||
F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
|
||||
del en # Free large intermediate tensor
|
||||
|
||||
# Final text encoding and decoding
|
||||
t_en = model.text_encoder(tokens, input_lengths, text_mask)
|
||||
asr = t_en @ pred_aln_trg
|
||||
del t_en # Free large intermediate tensor
|
||||
|
||||
# Final decoding and transfer to CPU
|
||||
output = model.decoder(asr, F0_pred, N_pred, s_ref)
|
||||
result = output.squeeze().cpu().numpy()
|
||||
|
||||
return result
|
||||
|
||||
finally:
|
||||
# Let PyTorch handle most cleanup automatically
|
||||
# Only explicitly free the largest tensors
|
||||
del pred_aln_trg, asr
|
||||
|
||||
|
||||
# def length_to_mask(lengths):
|
||||
# """Create attention mask from lengths"""
|
||||
# mask = (
|
||||
# torch.arange(lengths.max())
|
||||
# .unsqueeze(0)
|
||||
# .expand(lengths.shape[0], -1)
|
||||
# .type_as(lengths)
|
||||
# )
|
||||
# mask = torch.gt(mask + 1, lengths.unsqueeze(1))
|
||||
# return mask
|
||||
|
||||
|
||||
def length_to_mask(lengths):
|
||||
"""Create attention mask from lengths - possibly optimized version"""
|
||||
max_len = lengths.max()
|
||||
# Create mask directly on the same device as lengths
|
||||
mask = torch.arange(max_len, device=lengths.device)[None, :].expand(
|
||||
lengths.shape[0], -1
|
||||
)
|
||||
# Avoid type_as by using the correct dtype from the start
|
||||
if lengths.dtype != mask.dtype:
|
||||
mask = mask.to(dtype=lengths.dtype)
|
||||
# Fuse operations using broadcasting
|
||||
return mask + 1 > lengths[:, None]
|
||||
|
||||
|
||||
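A small worked example of the optimized length_to_mask() above (illustrative only): positions past each sequence length come back True, i.e. the mask marks padding, which is why the forward pass above inverts it for the attention mask.

# Toy illustration of length_to_mask() above.
import torch

lengths = torch.tensor([2, 4])
max_len = lengths.max()
mask = torch.arange(max_len, device=lengths.device)[None, :].expand(lengths.shape[0], -1)
mask = mask.to(dtype=lengths.dtype)
print(mask + 1 > lengths[:, None])
# tensor([[False, False,  True,  True],
#         [False, False, False, False]])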
class TTSGPUModel(TTSBaseModel):
|
||||
_instance = None
|
||||
_device = "cuda"
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
"""Get the model instance"""
|
||||
if cls._instance is None:
|
||||
raise RuntimeError("GPU model not initialized. Call initialize() first.")
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, model_dir: str, model_path: str):
|
||||
"""Initialize PyTorch model for GPU inference"""
|
||||
if cls._instance is None and torch.cuda.is_available():
|
||||
try:
|
||||
logger.info("Initializing GPU model")
|
||||
model_path = os.path.join(model_dir, settings.pytorch_model_path)
|
||||
model = build_model(model_path, cls._device)
|
||||
cls._instance = model
|
||||
return model
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize GPU model: {e}")
|
||||
return None
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def process_text(cls, text: str, language: str) -> tuple[str, list[int]]:
|
||||
"""Process text into phonemes and tokens
|
||||
|
||||
Args:
|
||||
text: Input text
|
||||
language: Language code
|
||||
|
||||
Returns:
|
||||
tuple[str, list[int]]: Phonemes and token IDs
|
||||
"""
|
||||
phonemes = phonemize(text, language)
|
||||
tokens = tokenize(phonemes)
|
||||
return phonemes, tokens
|
||||
|
||||
@classmethod
|
||||
def generate_from_text(
|
||||
cls, text: str, voicepack: torch.Tensor, language: str, speed: float
|
||||
) -> tuple[np.ndarray, str]:
|
||||
"""Generate audio from text
|
||||
|
||||
Args:
|
||||
text: Input text
|
||||
voicepack: Voice tensor
|
||||
language: Language code
|
||||
speed: Speed factor
|
||||
|
||||
Returns:
|
||||
tuple[np.ndarray, str]: Generated audio samples and phonemes
|
||||
"""
|
||||
if cls._instance is None:
|
||||
raise RuntimeError("GPU model not initialized")
|
||||
|
||||
# Process text
|
||||
phonemes, tokens = cls.process_text(text, language)
|
||||
|
||||
# Generate audio
|
||||
audio = cls.generate_from_tokens(tokens, voicepack, speed)
|
||||
|
||||
return audio, phonemes
|
||||
|
||||
@classmethod
|
||||
def generate_from_tokens(
|
||||
cls, tokens: list[int], voicepack: torch.Tensor, speed: float
|
||||
) -> np.ndarray:
|
||||
"""Generate audio from tokens with moderate memory management
|
||||
|
||||
Args:
|
||||
tokens: Token IDs
|
||||
voicepack: Voice tensor
|
||||
speed: Speed factor
|
||||
|
||||
Returns:
|
||||
np.ndarray: Generated audio samples
|
||||
"""
|
||||
if cls._instance is None:
|
||||
raise RuntimeError("GPU model not initialized")
|
||||
|
||||
try:
|
||||
device = cls._device
|
||||
|
||||
# Check memory pressure
|
||||
if torch.cuda.is_available():
|
||||
memory_allocated = torch.cuda.memory_allocated(device) / 1e9 # Convert to GB
|
||||
if memory_allocated > 2.0: # 2GB limit
|
||||
logger.info(
|
||||
f"Memory usage above 2GB threshold:{memory_allocated:.2f}GB "
|
||||
f"Clearing cache"
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
import gc
|
||||
gc.collect()
|
||||
|
||||
# Get reference style with proper device placement
|
||||
ref_s = voicepack[len(tokens)].clone().to(device)
|
||||
|
||||
# Generate audio
|
||||
audio = forward(cls._instance, tokens, ref_s, speed)
|
||||
|
||||
return audio
|
||||
|
||||
except RuntimeError as e:
|
||||
if "out of memory" in str(e):
|
||||
# On OOM, do a full cleanup and retry
|
||||
if torch.cuda.is_available():
|
||||
logger.warning("Out of memory detected, performing full cleanup")
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.empty_cache()
|
||||
import gc
|
||||
gc.collect()
|
||||
|
||||
# Log memory stats after cleanup
|
||||
memory_allocated = torch.cuda.memory_allocated(device)
|
||||
memory_reserved = torch.cuda.memory_reserved(device)
|
||||
logger.info(
|
||||
f"Memory after OOM cleanup: "
|
||||
f"Allocated: {memory_allocated / 1e9:.2f}GB, "
|
||||
f"Reserved: {memory_reserved / 1e9:.2f}GB"
|
||||
)
|
||||
|
||||
# Retry generation
|
||||
ref_s = voicepack[len(tokens)].clone().to(device)
|
||||
audio = forward(cls._instance, tokens, ref_s, speed)
|
||||
return audio
|
||||
raise
|
||||
|
||||
finally:
|
||||
# Only synchronize at the top level, no empty_cache
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|

@ -1,8 +0,0 @@
import torch

if torch.cuda.is_available():
    from .tts_gpu import TTSGPUModel as TTSModel
else:
    from .tts_cpu import TTSCPUModel as TTSModel

__all__ = ["TTSModel"]

@ -1,120 +1,255 @@
import io
import os
import re
import time
from functools import lru_cache
from typing import List, Optional, Tuple
"""TTS service using model and voice managers."""

import time
from typing import List, Tuple, Optional, AsyncGenerator, Union
import asyncio

import aiofiles.os
import numpy as np
import scipy.io.wavfile as wavfile
import torch
from loguru import logger

from ..core.config import settings
from ..inference.model_manager import get_manager as get_model_manager
from ..inference.voice_manager import get_manager as get_voice_manager
from .audio import AudioNormalizer, AudioService
from .text_processing import chunker, normalize_text
from .tts_model import TTSModel

from .text_processing.text_processor import process_text_chunk, smart_split
from .text_processing import tokenize

class TTSService:
"""Text-to-speech service."""

# Limit concurrent chunk processing
_chunk_semaphore = asyncio.Semaphore(4)

def __init__(self, output_dir: str = None):
"""Initialize service."""
self.output_dir = output_dir
self.model = TTSModel.get_instance()
self.model_manager = None
self._voice_manager = None

@staticmethod
@lru_cache(maxsize=3)  # Cache up to 3 most recently used voices
def _load_voice(voice_path: str) -> torch.Tensor:
"""Load and cache a voice model"""
return torch.load(
voice_path, map_location=TTSModel.get_device(), weights_only=True
)
@classmethod
async def create(cls, output_dir: str = None) -> 'TTSService':
"""Create and initialize TTSService instance."""
service = cls(output_dir)
service.model_manager = await get_model_manager()
service._voice_manager = await get_voice_manager()
return service

def _get_voice_path(self, voice_name: str) -> Optional[str]:
"""Get the path to a voice file"""
voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice_name}.pt")
return voice_path if os.path.exists(voice_path) else None
async def _process_chunk(
self,
tokens: List[int],
voice_tensor: torch.Tensor,
speed: float,
output_format: Optional[str] = None,
is_first: bool = False,
is_last: bool = False,
normalizer: Optional[AudioNormalizer] = None,
) -> Optional[Union[np.ndarray, bytes]]:
"""Process tokens into audio."""
async with self._chunk_semaphore:
try:
# Handle stream finalization
if is_last:
# Skip format conversion for raw audio mode
if not output_format:
return np.array([], dtype=np.float32)

return await AudioService.convert_audio(
np.array([0], dtype=np.float32),  # Dummy data for type checking
24000,
output_format,
is_first_chunk=False,
normalizer=normalizer,
is_last_chunk=True
)

# Skip empty chunks
if not tokens:
return None

def _generate_audio(
self, text: str, voice: str, speed: float, stitch_long_output: bool = True
) -> Tuple[torch.Tensor, float]:
"""Generate complete audio and return with processing time"""
audio, processing_time = self._generate_audio_internal(
text, voice, speed, stitch_long_output
)
return audio, processing_time
# Generate audio using pre-warmed model
chunk_audio = await self.model_manager.generate(
tokens,
voice_tensor,
speed=speed
)

if chunk_audio is None:
logger.error("Model generated None for audio chunk")
return None

if len(chunk_audio) == 0:
logger.error("Model generated empty audio chunk")
return None

# For streaming, convert to bytes
if output_format:
try:
return await AudioService.convert_audio(
chunk_audio,
24000,
output_format,
is_first_chunk=is_first,
normalizer=normalizer,
is_last_chunk=is_last
)
except Exception as e:
logger.error(f"Failed to convert audio: {str(e)}")
return None

return chunk_audio
except Exception as e:
logger.error(f"Failed to process tokens: {str(e)}")
return None

def _generate_audio_internal(
self, text: str, voice: str, speed: float, stitch_long_output: bool = True
) -> Tuple[torch.Tensor, float]:
"""Generate audio and measure processing time"""
async def generate_audio_stream(
self,
text: str,
voice: str,
speed: float = 1.0,
output_format: str = "wav",
) -> AsyncGenerator[bytes, None]:
"""Generate and stream audio chunks."""
stream_normalizer = AudioNormalizer()
voice_tensor = None
chunk_index = 0

try:
# Get backend and load voice (should be fast if cached)
backend = self.model_manager.get_backend()
voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device)

# Process text in chunks with smart splitting
async for chunk_text, tokens in smart_split(text):
try:
# Process audio for chunk
result = await self._process_chunk(
tokens,  # Now always a flat List[int]
voice_tensor,
speed,
output_format,
is_first=(chunk_index == 0),
is_last=False,  # We'll update the last chunk later
normalizer=stream_normalizer
)

if result is not None:
yield result
chunk_index += 1
else:
logger.warning(f"No audio generated for chunk: '{chunk_text[:100]}...'")

except Exception as e:
logger.error(f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}")
continue

# Only finalize if we successfully processed at least one chunk
if chunk_index > 0:
try:
# Empty tokens list to finalize audio
final_result = await self._process_chunk(
[],  # Empty tokens list
voice_tensor,
speed,
output_format,
is_first=False,
is_last=True,
normalizer=stream_normalizer
)
if final_result is not None:
logger.debug("Yielding final chunk to finalize audio")
yield final_result
else:
logger.warning("Final chunk processing returned None")
except Exception as e:
logger.error(f"Failed to process final chunk: {str(e)}")
else:
logger.warning("No audio chunks were successfully processed")

except Exception as e:
logger.error(f"Error in audio generation stream: {str(e)}")
raise
finally:
if voice_tensor is not None:
del voice_tensor
torch.cuda.empty_cache()

async def generate_from_phonemes(
self, phonemes: str, voice: str, speed: float = 1.0
) -> Tuple[np.ndarray, float]:
"""Generate audio from phonemes.

Args:
phonemes: Phoneme string to synthesize
voice: Voice ID to use
speed: Speed multiplier

Returns:
Tuple of (audio array, processing time)
"""
start_time = time.time()
voice_tensor = None

try:
# Normalize text once at the start
if not text:
raise ValueError("Text is empty after preprocessing")
normalized = normalize_text(text)
if not normalized:
raise ValueError("Text is empty after preprocessing")
text = str(normalized)
# Get backend and load voice
backend = self.model_manager.get_backend()
voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device)

# Check voice exists
voice_path = self._get_voice_path(voice)
if not voice_path:
raise ValueError(f"Voice not found: {voice}")
# Convert phonemes to tokens
tokens = tokenize(phonemes)
if len(tokens) > 500:  # Model context limit
raise ValueError(f"Phoneme sequence too long ({len(tokens)} tokens, max 500)")

tokens = [0] + tokens + [0]  # Add start/end tokens

# Generate audio
audio = await self.model_manager.generate(
tokens,
voice_tensor,
speed=speed
)

if audio is None:
raise ValueError("Failed to generate audio")

# Load voice using cached loader
voicepack = self._load_voice(voice_path)
processing_time = time.time() - start_time
return audio, processing_time

# For non-streaming, preprocess all chunks first
if stitch_long_output:
# Preprocess all chunks to phonemes/tokens
chunks_data = []
for chunk in chunker.split_text(text):
try:
phonemes, tokens = TTSModel.process_text(chunk, voice[0])
chunks_data.append((chunk, tokens))
except Exception as e:
logger.error(
f"Failed to process chunk: '{chunk}'. Error: {str(e)}"
)
continue
except Exception as e:
logger.error(f"Error in phoneme audio generation: {str(e)}")
raise
finally:
if voice_tensor is not None:
del voice_tensor
torch.cuda.empty_cache()

if not chunks_data:
raise ValueError("No chunks were processed successfully")
async def generate_audio(
self, text: str, voice: str, speed: float = 1.0
) -> Tuple[np.ndarray, float]:
"""Generate complete audio for text using streaming internally."""
start_time = time.time()
chunks = []

try:
# Use streaming generator but collect all valid chunks
async for chunk in self.generate_audio_stream(
text, voice, speed,  # Default to WAV for raw audio
):
if chunk is not None:
chunks.append(chunk)

# Generate audio for all chunks
audio_chunks = []
for chunk, tokens in chunks_data:
try:
chunk_audio = TTSModel.generate_from_tokens(
tokens, voicepack, speed
)
if chunk_audio is not None:
audio_chunks.append(chunk_audio)
else:
logger.error(f"No audio generated for chunk: '{chunk}'")
except Exception as e:
logger.error(
f"Failed to generate audio for chunk: '{chunk}'. Error: {str(e)}"
)
continue
if not chunks:
raise ValueError("No audio chunks were generated successfully")

if not audio_chunks:
raise ValueError("No audio chunks were generated successfully")

# Concatenate all chunks
audio = (
np.concatenate(audio_chunks)
if len(audio_chunks) > 1
else audio_chunks[0]
)
# Combine chunks, ensuring we have valid arrays
if len(chunks) == 1:
audio = chunks[0]
else:
# Process single chunk
phonemes, tokens = TTSModel.process_text(text, voice[0])
audio = TTSModel.generate_from_tokens(tokens, voicepack, speed)

# Filter out any zero-dimensional arrays
valid_chunks = [c for c in chunks if c.ndim > 0]
if not valid_chunks:
raise ValueError("No valid audio chunks to concatenate")
audio = np.concatenate(valid_chunks)
processing_time = time.time() - start_time
return audio, processing_time

@ -122,148 +257,10 @@ class TTSService:
logger.error(f"Error in audio generation: {str(e)}")
raise

async def generate_audio_stream(
self,
text: str,
voice: str,
speed: float,
output_format: str = "wav",
silent=False,
):
"""Generate and yield audio chunks as they're generated for real-time streaming"""
try:
stream_start = time.time()
# Create normalizer for consistent audio levels
stream_normalizer = AudioNormalizer()

# Input validation and preprocessing
if not text:
raise ValueError("Text is empty")
preprocess_start = time.time()
normalized = normalize_text(text)
if not normalized:
raise ValueError("Text is empty after preprocessing")
text = str(normalized)
logger.debug(
f"Text preprocessing took: {(time.time() - preprocess_start)*1000:.1f}ms"
)

# Voice validation and loading
voice_start = time.time()
voice_path = self._get_voice_path(voice)
if not voice_path:
raise ValueError(f"Voice not found: {voice}")
voicepack = self._load_voice(voice_path)
logger.debug(
f"Voice loading took: {(time.time() - voice_start)*1000:.1f}ms"
)

# Process chunks as they're generated
is_first = True
chunks_processed = 0

# Process chunks as they come from generator
chunk_gen = chunker.split_text(text)
current_chunk = next(chunk_gen, None)

while current_chunk is not None:
next_chunk = next(chunk_gen, None)  # Peek at next chunk
chunks_processed += 1
try:
# Process text and generate audio
phonemes, tokens = TTSModel.process_text(current_chunk, voice[0])
chunk_audio = TTSModel.generate_from_tokens(
tokens, voicepack, speed
)

if chunk_audio is not None:
# Convert chunk with proper streaming header handling
chunk_bytes = AudioService.convert_audio(
chunk_audio,
24000,
output_format,
is_first_chunk=is_first,
normalizer=stream_normalizer,
is_last_chunk=(next_chunk is None),  # Last if no next chunk
stream=True  # Ensure proper streaming format handling
)

yield chunk_bytes
is_first = False
else:
logger.error(f"No audio generated for chunk: '{current_chunk}'")

except Exception as e:
logger.error(
f"Failed to generate audio for chunk: '{current_chunk}'. Error: {str(e)}"
)

current_chunk = next_chunk  # Move to next chunk

except Exception as e:
logger.error(f"Error in audio generation stream: {str(e)}")
raise

def _save_audio(self, audio: torch.Tensor, filepath: str):
"""Save audio to file"""
os.makedirs(os.path.dirname(filepath), exist_ok=True)
wavfile.write(filepath, 24000, audio)

def _audio_to_bytes(self, audio: torch.Tensor) -> bytes:
"""Convert audio tensor to WAV bytes"""
buffer = io.BytesIO()
wavfile.write(buffer, 24000, audio)
return buffer.getvalue()

async def combine_voices(self, voices: List[str]) -> str:
"""Combine multiple voices into a new voice"""
if len(voices) < 2:
raise ValueError("At least 2 voices are required for combination")

# Load voices
t_voices: List[torch.Tensor] = []
v_name: List[str] = []

for voice in voices:
try:
voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice}.pt")
voicepack = torch.load(
voice_path, map_location=TTSModel.get_device(), weights_only=True
)
t_voices.append(voicepack)
v_name.append(voice)
except Exception as e:
raise ValueError(f"Failed to load voice {voice}: {str(e)}")

# Combine voices
try:
f: str = "_".join(v_name)
v = torch.mean(torch.stack(t_voices), dim=0)
combined_path = os.path.join(TTSModel.VOICES_DIR, f"{f}.pt")

# Save combined voice
try:
torch.save(v, combined_path)
except Exception as e:
raise RuntimeError(
f"Failed to save combined voice to {combined_path}: {str(e)}"
)

return f

except Exception as e:
if not isinstance(e, (ValueError, RuntimeError)):
raise RuntimeError(f"Error combining voices: {str(e)}")
raise
"""Combine multiple voices."""
return await self._voice_manager.combine_voices(voices)

async def list_voices(self) -> List[str]:
"""List all available voices"""
voices = []
try:
it = await aiofiles.os.scandir(TTSModel.VOICES_DIR)
for entry in it:
if entry.name.endswith(".pt"):
voices.append(entry.name[:-3])  # Remove .pt extension
except Exception as e:
logger.error(f"Error listing voices: {str(e)}")
return sorted(voices)
"""List available voices."""
return await self._voice_manager.list_voices()

@ -1,60 +0,0 @@
import os
from typing import List, Tuple

import torch
from loguru import logger

from ..core.config import settings
from .tts_model import TTSModel
from .tts_service import TTSService


class WarmupService:
"""Service for warming up TTS models and voice caches"""

def __init__(self):
"""Initialize warmup service and ensure model is ready"""
# Initialize model if not already initialized
if TTSModel._instance is None:
TTSModel.initialize(settings.model_dir)
self.tts_service = TTSService()

def load_voices(self) -> List[Tuple[str, torch.Tensor]]:
"""Load and cache voices up to LRU limit"""
# Get all voices sorted by filename length (shorter names first, usually base voices)
voice_files = sorted(
[f for f in os.listdir(TTSModel.VOICES_DIR) if f.endswith(".pt")], key=len
)

n_voices_cache = 1
loaded_voices = []
for voice_file in voice_files[:n_voices_cache]:
try:
voice_path = os.path.join(TTSModel.VOICES_DIR, voice_file)
# load using service, lru cache
voicepack = self.tts_service._load_voice(voice_path)
loaded_voices.append(
(voice_file[:-3], voicepack)
)  # Store name and tensor
# voicepack = torch.load(voice_path, map_location=TTSModel.get_device(), weights_only=True)
# logger.info(f"Loaded voice {voice_file[:-3]} into cache")
except Exception as e:
logger.error(f"Failed to load voice {voice_file}: {e}")
logger.info(f"Pre-loaded {len(loaded_voices)} voices into cache")
return loaded_voices

async def warmup_voices(
self, warmup_text: str, loaded_voices: List[Tuple[str, torch.Tensor]]
):
"""Warm up voice inference and streaming"""
n_warmups = 1
for voice_name, _ in loaded_voices[:n_warmups]:
try:
logger.info(f"Running warmup inference on voice {voice_name}")
async for _ in self.tts_service.generate_audio_stream(
warmup_text, voice_name, 1.0, "pcm"
):
pass  # Process all chunks to properly warm up
logger.info(f"Completed warmup for voice {voice_name}")
except Exception as e:
logger.warning(f"Warmup failed for voice {voice_name}: {e}")

13
api/src/structures/model_schemas.py
Normal file
@ -0,0 +1,13 @@
"""Voice configuration schemas."""

from pydantic import BaseModel, Field


class VoiceConfig(BaseModel):
"""Voice configuration."""
use_cache: bool = Field(True, description="Whether to cache loaded voices")
cache_size: int = Field(3, description="Number of voices to cache")
validate_on_load: bool = Field(True, description="Whether to validate voices when loading")

class Config:
frozen = True  # Make config immutable

@ -23,7 +23,10 @@ class TTSStatus(str, Enum):

# OpenAI-compatible schemas
class OpenAISpeechRequest(BaseModel):
model: Literal["tts-1", "tts-1-hd", "kokoro"] = "kokoro"
model: str = Field(
default="kokoro",
description="The model to use for generation. Supported models: tts-1, tts-1-hd, kokoro"
)
input: str = Field(..., description="The text to generate audio for")
voice: str = Field(
default="af",

@ -43,3 +46,7 @@
default=True,  # Default to streaming for OpenAI compatibility
description="If true (default), audio will be streamed as it's generated. Each chunk will be a complete sentence.",
)
return_download_link: bool = Field(
default=False,
description="If true, returns a download link in X-Download-Path header after streaming completes",
)

@ -1,4 +1,5 @@
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from typing import List, Union, Optional


class PhonemeRequest(BaseModel):

@ -11,9 +12,27 @@ class PhonemeResponse(BaseModel):
tokens: list[int]


class GenerateFromPhonemesRequest(BaseModel):
phonemes: str
voice: str = Field(..., description="Voice ID to use for generation")
speed: float = Field(
default=1.0, ge=0.1, le=5.0, description="Speed factor for generation"
class StitchOptions(BaseModel):
"""Options for stitching audio chunks together"""
gap_method: str = Field(
default="static_trim",
description="Method to handle gaps between chunks. Currently only 'static_trim' supported."
)
trim_ms: int = Field(
default=0,
ge=0,
description="Milliseconds to trim from chunk boundaries when using static_trim"
)

@field_validator('gap_method')
@classmethod
def validate_gap_method(cls, v: str) -> str:
if v != 'static_trim':
raise ValueError("Currently only 'static_trim' gap method is supported")
return v


class GenerateFromPhonemesRequest(BaseModel):
"""Simple request for phoneme-to-speech generation"""
phonemes: str = Field(..., description="Phoneme string to synthesize")
voice: str = Field(..., description="Voice ID to use for generation")

BIN
api/src/voices/am_gurney.pt
Normal file
BIN
api/src/voices/v1_0/af_alloy.pt
Normal file
BIN
api/src/voices/v1_0/af_aoede.pt
Normal file
BIN
api/src/voices/v1_0/af_jessica.pt
Normal file
BIN
api/src/voices/v1_0/af_kore.pt
Normal file
BIN
api/src/voices/v1_0/af_nova.pt
Normal file
BIN
api/src/voices/v1_0/af_river.pt
Normal file
BIN
api/src/voices/v1_0/am_echo.pt
Normal file
BIN
api/src/voices/v1_0/am_eric.pt
Normal file
BIN
api/src/voices/v1_0/am_fenrir.pt
Normal file
BIN
api/src/voices/v1_0/am_liam.pt
Normal file
BIN
api/src/voices/v1_0/am_onyx.pt
Normal file
BIN
api/src/voices/v1_0/am_puck.pt
Normal file
BIN
api/src/voices/v1_0/bf_alice.pt
Normal file
BIN
api/src/voices/v1_0/bf_lily.pt
Normal file
BIN
api/src/voices/v1_0/bm_daniel.pt
Normal file
BIN
api/src/voices/v1_0/bm_fable.pt
Normal file

@ -1,126 +1,69 @@
import os
import shutil
import sys
from unittest.mock import MagicMock, Mock, patch

import aiofiles.threadpool
import numpy as np
import pytest


def cleanup_mock_dirs():
"""Clean up any MagicMock directories created during tests"""
mock_dir = "MagicMock"
if os.path.exists(mock_dir):
shutil.rmtree(mock_dir)


@pytest.fixture(autouse=True)
def setup_aiofiles():
"""Setup aiofiles mock wrapper"""
aiofiles.threadpool.wrap.register(MagicMock)(
lambda *args, **kwargs: aiofiles.threadpool.AsyncBufferedIOBase(*args, **kwargs)
)
yield


@pytest.fixture(autouse=True)
def cleanup():
"""Automatically clean up before and after each test"""
cleanup_mock_dirs()
yield
cleanup_mock_dirs()


# Mock modules before they're imported
sys.modules["transformers"] = Mock()
sys.modules["phonemizer"] = Mock()
sys.modules["models"] = Mock()
sys.modules["models.build_model"] = Mock()
sys.modules["kokoro"] = Mock()
sys.modules["kokoro.generate"] = Mock()
sys.modules["kokoro.phonemize"] = Mock()
sys.modules["kokoro.tokenize"] = Mock()

# Mock ONNX runtime
mock_onnx = Mock()
mock_onnx.InferenceSession = Mock()
mock_onnx.SessionOptions = Mock()
mock_onnx.GraphOptimizationLevel = Mock()
mock_onnx.ExecutionMode = Mock()
sys.modules["onnxruntime"] = mock_onnx

# Create mock settings module
mock_settings_module = Mock()
mock_settings = Mock()
mock_settings.model_dir = "/mock/model/dir"
mock_settings.onnx_model_path = "mock.onnx"
mock_settings_module.settings = mock_settings
sys.modules["api.src.core.config"] = mock_settings_module


class MockTTSModel:
_instance = None
_onnx_session = None
VOICES_DIR = "/mock/voices/dir"

def __init__(self):
self._initialized = False

@classmethod
def get_instance(cls):
if cls._instance is None:
cls._instance = cls()
return cls._instance

@classmethod
def initialize(cls, model_dir):
cls._onnx_session = Mock()
cls._onnx_session.run = Mock(return_value=[np.zeros(48000)])
cls._instance._initialized = True
return cls._onnx_session

@classmethod
def setup(cls):
if not cls._instance._initialized:
cls.initialize("/mock/model/dir")
return cls._instance

@classmethod
def generate_from_tokens(cls, tokens, voicepack, speed):
if not cls._instance._initialized:
raise RuntimeError("Model not initialized. Call setup() first.")
return np.zeros(48000)

@classmethod
def process_text(cls, text, language):
return "mock phonemes", [1, 2, 3]

@staticmethod
def get_device():
return "cpu"

import pytest_asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import torch
from pathlib import Path
import os
from api.src.services.tts_service import TTSService
from api.src.inference.voice_manager import VoiceManager
from api.src.inference.model_manager import ModelManager
from api.src.structures.model_schemas import VoiceConfig

@pytest.fixture
def mock_tts_service(monkeypatch):
"""Mock TTSService for testing"""
mock_service = Mock()
mock_service._get_voice_path.return_value = "/mock/path/voice.pt"
mock_service._load_voice.return_value = np.zeros((1, 192))

# Mock TTSModel.generate_from_tokens since we call it directly
mock_generate = Mock(return_value=np.zeros(48000))
monkeypatch.setattr(
"api.src.routers.development.TTSModel.generate_from_tokens", mock_generate
)

return mock_service

def mock_voice_tensor():
"""Load a real voice tensor for testing."""
voice_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'src/voices/af_bella.pt')
return torch.load(voice_path, map_location='cpu', weights_only=False)

@pytest.fixture
def mock_audio_service(monkeypatch):
"""Mock AudioService"""
mock_service = Mock()
mock_service.convert_audio.return_value = b"mock audio data"
monkeypatch.setattr("api.src.routers.development.AudioService", mock_service)
return mock_service
def mock_audio_output():
"""Load pre-generated test audio for consistent testing."""
test_audio_path = os.path.join(os.path.dirname(__file__), 'test_data/test_audio.npy')
return np.load(test_audio_path)  # Return as numpy array instead of bytes

@pytest_asyncio.fixture
async def mock_model_manager(mock_audio_output):
"""Mock model manager for testing."""
manager = AsyncMock(spec=ModelManager)
manager.get_backend = MagicMock()
async def mock_generate(*args, **kwargs):
# Simulate successful audio generation
return np.random.rand(24000).astype(np.float32)  # 1 second of random audio data
manager.generate = AsyncMock(side_effect=mock_generate)
return manager

@pytest_asyncio.fixture
async def mock_voice_manager(mock_voice_tensor):
"""Mock voice manager for testing."""
manager = AsyncMock(spec=VoiceManager)
manager.get_voice_path = MagicMock(return_value="/mock/path/voice.pt")
manager.load_voice = AsyncMock(return_value=mock_voice_tensor)
manager.list_voices = AsyncMock(return_value=["voice1", "voice2"])
manager.combine_voices = AsyncMock(return_value="voice1_voice2")
return manager

@pytest_asyncio.fixture
async def tts_service(mock_model_manager, mock_voice_manager):
"""Get mocked TTS service instance."""
service = TTSService()
service.model_manager = mock_model_manager
service._voice_manager = mock_voice_manager
return service

@pytest.fixture
def test_voice():
"""Return a test voice name."""
return "voice1"

@pytest.fixture(scope="session")
def event_loop():
"""Create an instance of the default event loop for the test session."""
import asyncio
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
yield loop
loop.close()

@ -26,106 +26,161 @@ def sample_audio():
return np.sin(2 * np.pi * frequency * t).astype(np.float32), sample_rate


def test_convert_to_wav(sample_audio):
@pytest.mark.asyncio
async def test_convert_to_wav(sample_audio):
"""Test converting to WAV format"""
audio_data, sample_rate = sample_audio
result = AudioService.convert_audio(audio_data, sample_rate, "wav")
# Write and finalize in one step for WAV
result = await AudioService.convert_audio(
audio_data,
sample_rate,
"wav",
is_first_chunk=True,
is_last_chunk=True
)
assert isinstance(result, bytes)
assert len(result) > 0
# Check WAV header
assert result.startswith(b'RIFF')
assert b'WAVE' in result[:12]


def test_convert_to_mp3(sample_audio):
@pytest.mark.asyncio
async def test_convert_to_mp3(sample_audio):
"""Test converting to MP3 format"""
audio_data, sample_rate = sample_audio
result = AudioService.convert_audio(audio_data, sample_rate, "mp3")
result = await AudioService.convert_audio(audio_data, sample_rate, "mp3")
assert isinstance(result, bytes)
assert len(result) > 0
# Check MP3 header (ID3 or MPEG frame sync)
assert result.startswith(b'ID3') or result.startswith(b'\xff\xfb')


def test_convert_to_opus(sample_audio):
@pytest.mark.asyncio
async def test_convert_to_opus(sample_audio):
"""Test converting to Opus format"""
audio_data, sample_rate = sample_audio
result = AudioService.convert_audio(audio_data, sample_rate, "opus")
result = await AudioService.convert_audio(audio_data, sample_rate, "opus")
assert isinstance(result, bytes)
assert len(result) > 0
# Check OGG header
assert result.startswith(b'OggS')


def test_convert_to_flac(sample_audio):
@pytest.mark.asyncio
async def test_convert_to_flac(sample_audio):
"""Test converting to FLAC format"""
audio_data, sample_rate = sample_audio
result = AudioService.convert_audio(audio_data, sample_rate, "flac")
result = await AudioService.convert_audio(audio_data, sample_rate, "flac")
assert isinstance(result, bytes)
assert len(result) > 0
# Check FLAC header
assert result.startswith(b'fLaC')


def test_convert_to_aac(sample_audio):
@pytest.mark.asyncio
async def test_convert_to_aac(sample_audio):
"""Test converting to AAC format"""
audio_data, sample_rate = sample_audio
result = AudioService.convert_audio(audio_data, sample_rate, "aac")
result = await AudioService.convert_audio(audio_data, sample_rate, "aac")
assert isinstance(result, bytes)
assert len(result) > 0
# AAC files typically start with an ADTS header
assert result.startswith(b'\xff\xf1') or result.startswith(b'\xff\xf9')
# Check ADTS header (AAC)
assert result.startswith(b'\xff\xf0') or result.startswith(b'\xff\xf1')


def test_convert_to_pcm(sample_audio):
@pytest.mark.asyncio
async def test_convert_to_pcm(sample_audio):
"""Test converting to PCM format"""
audio_data, sample_rate = sample_audio
result = AudioService.convert_audio(audio_data, sample_rate, "pcm")
result = await AudioService.convert_audio(audio_data, sample_rate, "pcm")
assert isinstance(result, bytes)
assert len(result) > 0
# PCM is raw bytes, so no header to check


def test_convert_to_invalid_format_raises_error(sample_audio):
@pytest.mark.asyncio
async def test_convert_to_invalid_format_raises_error(sample_audio):
"""Test that converting to an invalid format raises an error"""
audio_data, sample_rate = sample_audio
with pytest.raises(ValueError, match="Format invalid not supported"):
AudioService.convert_audio(audio_data, sample_rate, "invalid")
await AudioService.convert_audio(audio_data, sample_rate, "invalid")


def test_normalization_wav(sample_audio):
@pytest.mark.asyncio
async def test_normalization_wav(sample_audio):
"""Test that WAV output is properly normalized to int16 range"""
audio_data, sample_rate = sample_audio
# Create audio data outside int16 range
large_audio = audio_data * 1e5
result = AudioService.convert_audio(large_audio, sample_rate, "wav")
# Write and finalize in one step for WAV
result = await AudioService.convert_audio(
large_audio,
sample_rate,
"wav",
is_first_chunk=True,
is_last_chunk=True
)
assert isinstance(result, bytes)
assert len(result) > 0


def test_normalization_pcm(sample_audio):
@pytest.mark.asyncio
async def test_normalization_pcm(sample_audio):
"""Test that PCM output is properly normalized to int16 range"""
audio_data, sample_rate = sample_audio
# Create audio data outside int16 range
large_audio = audio_data * 1e5
result = AudioService.convert_audio(large_audio, sample_rate, "pcm")
result = await AudioService.convert_audio(large_audio, sample_rate, "pcm")
assert isinstance(result, bytes)
assert len(result) > 0


def test_invalid_audio_data():
@pytest.mark.asyncio
async def test_invalid_audio_data():
"""Test handling of invalid audio data"""
invalid_audio = np.array([])  # Empty array
sample_rate = 24000
with pytest.raises(ValueError):
AudioService.convert_audio(invalid_audio, sample_rate, "wav")
await AudioService.convert_audio(invalid_audio, sample_rate, "wav")


def test_different_sample_rates(sample_audio):
@pytest.mark.asyncio
async def test_different_sample_rates(sample_audio):
"""Test converting audio with different sample rates"""
audio_data, _ = sample_audio
sample_rates = [8000, 16000, 44100, 48000]

for rate in sample_rates:
result = AudioService.convert_audio(audio_data, rate, "wav")
result = await AudioService.convert_audio(
audio_data,
rate,
"wav",
is_first_chunk=True,
is_last_chunk=True
)
assert isinstance(result, bytes)
assert len(result) > 0


def test_buffer_position_after_conversion(sample_audio):
@pytest.mark.asyncio
async def test_buffer_position_after_conversion(sample_audio):
"""Test that buffer position is reset after writing"""
audio_data, sample_rate = sample_audio
result = AudioService.convert_audio(audio_data, sample_rate, "wav")
# Write and finalize in one step for first conversion
result = await AudioService.convert_audio(
audio_data,
sample_rate,
"wav",
is_first_chunk=True,
is_last_chunk=True
)
# Convert again to ensure buffer was properly reset
result2 = AudioService.convert_audio(audio_data, sample_rate, "wav")
result2 = await AudioService.convert_audio(
audio_data,
sample_rate,
"wav",
is_first_chunk=True,
is_last_chunk=True
)
assert len(result) == len(result2)

@ -1,46 +0,0 @@
"""Tests for text chunking service"""

from unittest.mock import patch

import pytest

from api.src.services.text_processing import chunker


@pytest.fixture(autouse=True)
def mock_settings():
"""Mock settings for all tests"""
with patch("api.src.services.text_processing.chunker.settings") as mock_settings:
mock_settings.max_chunk_size = 300
yield mock_settings


def test_split_text():
"""Test text splitting into sentences"""
text = "First sentence. Second sentence! Third sentence?"
sentences = list(chunker.split_text(text))
assert len(sentences) == 3
assert sentences[0] == "First sentence."
assert sentences[1] == "Second sentence!"
assert sentences[2] == "Third sentence?"


def test_split_text_empty():
"""Test splitting empty text"""
assert list(chunker.split_text("")) == []


def test_split_text_single_sentence():
"""Test splitting single sentence"""
text = "Just one sentence."
assert list(chunker.split_text(text)) == ["Just one sentence."]


def test_split_text_with_custom_chunk_size():
"""Test splitting with custom max chunk size"""
text = "First part, second part, third part."
chunks = list(chunker.split_text(text, max_chunk=15))
assert len(chunks) == 3
assert chunks[0] == "First part,"
assert chunks[1] == "second part,"
assert chunks[2] == "third part."

20
api/tests/test_data/generate_test_data.py
Normal file
@ -0,0 +1,20 @@
import numpy as np
import os

def generate_test_audio():
"""Generate test audio data - 1 second of 440Hz tone"""
# Create 1 second of silence at 24kHz
audio = np.zeros(24000, dtype=np.float32)

# Add a simple sine wave to make it non-zero
t = np.linspace(0, 1, 24000)
audio += 0.5 * np.sin(2 * np.pi * 440 * t)  # 440 Hz tone at half amplitude

# Create test_data directory if it doesn't exist
os.makedirs('api/tests/test_data', exist_ok=True)

# Save the test audio
np.save('api/tests/test_data/test_audio.npy', audio)

if __name__ == '__main__':
generate_test_audio()

BIN
api/tests/test_data/test_audio.npy
Normal file

@ -1,402 +0,0 @@
import asyncio
from unittest.mock import AsyncMock, Mock

import pytest
import pytest_asyncio
from fastapi.testclient import TestClient
from httpx import AsyncClient

from ..src.main import app

# Create test client
client = TestClient(app)


# Create async client fixture
@pytest_asyncio.fixture
async def async_client():
async with AsyncClient(app=app, base_url="http://test") as ac:
yield ac


# Mock services
@pytest.fixture
def mock_tts_service(monkeypatch):
mock_service = Mock()
mock_service._generate_audio.return_value = (bytes([0, 1, 2, 3]), 1.0)

# Create proper async generator mock
async def mock_stream(*args, **kwargs):
for chunk in [b"chunk1", b"chunk2"]:
yield chunk

mock_service.generate_audio_stream = mock_stream

# Create async mocks
mock_service.list_voices = AsyncMock(
return_value=[
"af",
"bm_lewis",
"bf_isabella",
"bf_emma",
"af_sarah",
"af_bella",
"am_adam",
"am_michael",
"bm_george",
]
)
mock_service.combine_voices = AsyncMock()
monkeypatch.setattr(
"api.src.routers.openai_compatible.TTSService",
lambda *args, **kwargs: mock_service,
)
return mock_service


@pytest.fixture
def mock_audio_service(monkeypatch):
mock_service = Mock()
mock_service.convert_audio.return_value = b"converted mock audio data"
monkeypatch.setattr("api.src.routers.openai_compatible.AudioService", mock_service)
return mock_service


def test_health_check():
"""Test the health check endpoint"""
response = client.get("/health")
assert response.status_code == 200
assert response.json() == {"status": "healthy"}


@pytest.mark.asyncio
async def test_openai_speech_endpoint(
mock_tts_service, mock_audio_service, async_client
):
"""Test the OpenAI-compatible speech endpoint"""
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": "bm_lewis",
"response_format": "wav",
"speed": 1.0,
"stream": False,  # Explicitly disable streaming
}
response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 200
assert response.headers["content-type"] == "audio/wav"
assert response.headers["content-disposition"] == "attachment; filename=speech.wav"
mock_tts_service._generate_audio.assert_called_once_with(
text="Hello world", voice="bm_lewis", speed=1.0, stitch_long_output=True
)
assert response.content == b"converted mock audio data"


@pytest.mark.asyncio
async def test_openai_speech_invalid_voice(mock_tts_service, async_client):
"""Test the OpenAI-compatible speech endpoint with invalid voice"""
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": "invalid_voice",
"response_format": "wav",
"speed": 1.0,
"stream": False,  # Explicitly disable streaming
}
response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 400  # Bad request
assert "not found" in response.json()["detail"]["message"]


@pytest.mark.asyncio
async def test_openai_speech_invalid_speed(mock_tts_service, async_client):
"""Test the OpenAI-compatible speech endpoint with invalid speed"""
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": "af",
"response_format": "wav",
"speed": -1.0,  # Invalid speed
"stream": False,  # Explicitly disable streaming
}
response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 422  # Validation error


@pytest.mark.asyncio
async def test_openai_speech_generation_error(mock_tts_service, async_client):
"""Test error handling in speech generation"""
mock_tts_service._generate_audio.side_effect = Exception("Generation failed")
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": "af",
"response_format": "wav",
"speed": 1.0,
"stream": False,  # Explicitly disable streaming
}
response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 500
assert "Generation failed" in response.json()["detail"]["message"]


@pytest.mark.asyncio
async def test_combine_voices_list_success(mock_tts_service, async_client):
"""Test successful voice combination using list format"""
test_voices = ["af_bella", "af_sarah"]
mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")

response = await async_client.post("/v1/audio/voices/combine", json=test_voices)

assert response.status_code == 200
assert response.json()["voice"] == "af_bella_af_sarah"
mock_tts_service.combine_voices.assert_called_once_with(voices=test_voices)


@pytest.mark.asyncio
async def test_combine_voices_string_success(mock_tts_service, async_client):
"""Test successful voice combination using string format with +"""
test_voices = "af_bella+af_sarah"
mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")

response = await async_client.post("/v1/audio/voices/combine", json=test_voices)

assert response.status_code == 200
assert response.json()["voice"] == "af_bella_af_sarah"
mock_tts_service.combine_voices.assert_called_once_with(
voices=["af_bella", "af_sarah"]
)


@pytest.mark.asyncio
async def test_combine_voices_single_voice(mock_tts_service, async_client):
"""Test combining single voice returns same voice"""
test_voices = ["af_bella"]
response = await async_client.post("/v1/audio/voices/combine", json=test_voices)
assert response.status_code == 200
assert response.json()["voice"] == "af_bella"


@pytest.mark.asyncio
async def test_combine_voices_empty_list(mock_tts_service, async_client):
"""Test combining empty voice list returns error"""
test_voices = []
response = await async_client.post("/v1/audio/voices/combine", json=test_voices)
assert response.status_code == 400
assert "No voices provided" in response.json()["detail"]["message"]


@pytest.mark.asyncio
async def test_combine_voices_error(mock_tts_service, async_client):
"""Test error handling in voice combination"""
test_voices = ["af_bella", "af_sarah"]
mock_tts_service.combine_voices = AsyncMock(
side_effect=Exception("Combination failed")
)

response = await async_client.post("/v1/audio/voices/combine", json=test_voices)
assert response.status_code == 500
assert "Server error" in response.json()["detail"]["message"]


@pytest.mark.asyncio
async def test_speech_with_combined_voice(
mock_tts_service, mock_audio_service, async_client
):
"""Test speech generation with combined voice using + syntax"""
mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")

test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": "af_bella+af_sarah",
"response_format": "wav",
"speed": 1.0,
"stream": False,
}

response = await async_client.post("/v1/audio/speech", json=test_request)

assert response.status_code == 200
assert response.headers["content-type"] == "audio/wav"
mock_tts_service._generate_audio.assert_called_once_with(
text="Hello world",
voice="af_bella_af_sarah",
speed=1.0,
stitch_long_output=True,
)


@pytest.mark.asyncio
async def test_speech_with_whitespace_in_voice(
mock_tts_service, mock_audio_service, async_client
):
"""Test speech generation with whitespace in voice combination"""
mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")

test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": " af_bella + af_sarah ",
"response_format": "wav",
"speed": 1.0,
"stream": False,
}

response = await async_client.post("/v1/audio/speech", json=test_request)

assert response.status_code == 200
assert response.headers["content-type"] == "audio/wav"
mock_tts_service.combine_voices.assert_called_once_with(
voices=["af_bella", "af_sarah"]
)


@pytest.mark.asyncio
async def test_speech_with_empty_voice_combination(mock_tts_service, async_client):
"""Test speech generation with empty voice combination"""
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": "+",
"response_format": "wav",
"speed": 1.0,
"stream": False,
}

response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 400
assert "No voices provided" in response.json()["detail"]["message"]


@pytest.mark.asyncio
async def test_speech_with_invalid_combined_voice(mock_tts_service, async_client):
"""Test speech generation with invalid voice combination"""
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": "invalid+combination",
"response_format": "wav",
"speed": 1.0,
"stream": False,
}

response = await async_client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 400
assert "not found" in response.json()["detail"]["message"]


@pytest.mark.asyncio
async def test_speech_streaming_with_combined_voice(mock_tts_service, async_client):
"""Test streaming speech with combined voice using + syntax"""
mock_tts_service.combine_voices = AsyncMock(return_value="af_bella_af_sarah")

test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": "af_bella+af_sarah",
"response_format": "mp3",
"stream": True,
}

# Create streaming mock
async def mock_stream(*args, **kwargs):
for chunk in [b"mp3header", b"mp3data"]:
yield chunk

mock_tts_service.generate_audio_stream = mock_stream

# Add streaming header
headers = {"x-raw-response": "stream"}
response = await async_client.post(
"/v1/audio/speech", json=test_request, headers=headers
)

assert response.status_code == 200
assert response.headers["content-type"] == "audio/mpeg"
assert response.headers["content-disposition"] == "attachment; filename=speech.mp3"


@pytest.mark.asyncio
async def test_openai_speech_pcm_streaming(mock_tts_service, async_client):
"""Test streaming PCM audio for real-time playback"""
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": "af",
"response_format": "pcm",
"stream": True,
}

# Create streaming mock for this test
async def mock_stream(*args, **kwargs):
for chunk in [b"chunk1", b"chunk2"]:
yield chunk

mock_tts_service.generate_audio_stream = mock_stream

# Add streaming header
headers = {"x-raw-response": "stream"}
response = await async_client.post(
"/v1/audio/speech", json=test_request, headers=headers
)

assert response.status_code == 200
assert response.headers["content-type"] == "audio/pcm"


@pytest.mark.asyncio
async def test_openai_speech_streaming_mp3(mock_tts_service, async_client):
"""Test streaming MP3 audio to file"""
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": "af",
"response_format": "mp3",
"stream": True,
}

# Create streaming mock for this test
async def mock_stream(*args, **kwargs):
for chunk in [b"mp3header", b"mp3data"]:
yield chunk

mock_tts_service.generate_audio_stream = mock_stream

# Add streaming header
headers = {"x-raw-response": "stream"}
response = await async_client.post(
"/v1/audio/speech", json=test_request, headers=headers
)

assert response.status_code == 200
assert response.headers["content-type"] == "audio/mpeg"
assert response.headers["content-disposition"] == "attachment; filename=speech.mp3"


@pytest.mark.asyncio
async def test_openai_speech_streaming_generator(mock_tts_service, async_client):
"""Test streaming with async generator"""
test_request = {
"model": "kokoro",
"input": "Hello world",
"voice": "af",
"response_format": "pcm",
"stream": True,
}

# Create streaming mock for this test
async def mock_stream(*args, **kwargs):
for chunk in [b"chunk1", b"chunk2"]:
yield chunk

mock_tts_service.generate_audio_stream = mock_stream

# Add streaming header
headers = {"x-raw-response": "stream"}
response = await async_client.post(
"/v1/audio/speech", json=test_request, headers=headers
)

assert response.status_code == 200
assert response.headers["content-type"] == "audio/pcm"

@ -1,108 +0,0 @@
|
|||
"""Tests for FastAPI application"""
|
||||
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from api.src.main import app, lifespan
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_client():
|
||||
"""Create a test client"""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def test_health_check(test_client):
|
||||
"""Test health check endpoint"""
|
||||
response = test_client.get("/health")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "healthy"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("api.src.main.TTSModel")
|
||||
@patch("api.src.main.logger")
|
||||
async def test_lifespan_successful_warmup(mock_logger, mock_tts_model):
|
||||
"""Test successful model warmup in lifespan"""
|
||||
# Mock file system for voice counting
|
||||
mock_tts_model.VOICES_DIR = "/mock/voices"
|
||||
|
||||
# Create async mock
|
||||
async def async_setup():
|
||||
return 3
|
||||
|
||||
mock_tts_model.setup = MagicMock()
|
||||
mock_tts_model.setup.side_effect = async_setup
|
||||
mock_tts_model.get_device.return_value = "cuda"
|
||||
|
||||
with patch("os.listdir", return_value=["voice1.pt", "voice2.pt", "voice3.pt"]):
|
||||
# Create an async generator from the lifespan context manager
|
||||
async_gen = lifespan(MagicMock())
|
||||
# Start the context manager
|
||||
await async_gen.__aenter__()
|
||||
|
||||
# Verify the expected logging sequence
|
||||
mock_logger.info.assert_any_call("Loading TTS model and voice packs...")
|
||||
|
||||
# Check for the startup message containing the required info
|
||||
startup_calls = [call[0][0] for call in mock_logger.info.call_args_list]
|
||||
startup_msg = next(msg for msg in startup_calls if "Model warmed up on" in msg)
|
||||
assert "Model warmed up on" in startup_msg
|
||||
assert "3 voice packs loaded" in startup_msg
|
||||
|
||||
# Verify model setup was called
|
||||
mock_tts_model.setup.assert_called_once()
|
||||
|
||||
# Clean up
|
||||
await async_gen.__aexit__(None, None, None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("api.src.main.TTSModel")
|
||||
@patch("api.src.main.logger")
|
||||
async def test_lifespan_failed_warmup(mock_logger, mock_tts_model):
|
||||
"""Test failed model warmup in lifespan"""
|
||||
# Mock the model setup to fail
|
||||
mock_tts_model.setup.side_effect = RuntimeError("Failed to initialize model")
|
||||
|
||||
# Create an async generator from the lifespan context manager
|
||||
async_gen = lifespan(MagicMock())
|
||||
|
||||
# Verify the exception is raised
|
||||
with pytest.raises(RuntimeError, match="Failed to initialize model"):
|
||||
await async_gen.__aenter__()
|
||||
|
||||
# Verify the expected logging sequence
|
||||
mock_logger.info.assert_called_with("Loading TTS model and voice packs...")
|
||||
|
||||
# Clean up
|
||||
await async_gen.__aexit__(None, None, None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("api.src.main.TTSModel")
|
||||
async def test_lifespan_cuda_warmup(mock_tts_model):
|
||||
"""Test model warmup specifically on CUDA"""
|
||||
# Mock file system for voice counting
|
||||
mock_tts_model.VOICES_DIR = "/mock/voices"
|
||||
|
||||
# Create async mock
|
||||
async def async_setup():
|
||||
return 2
|
||||
|
||||
mock_tts_model.setup = MagicMock()
|
||||
mock_tts_model.setup.side_effect = async_setup
|
||||
mock_tts_model.get_device.return_value = "cuda"
|
||||
|
||||
with patch("os.listdir", return_value=["voice1.pt", "voice2.pt"]):
|
||||
# Create an async generator from the lifespan context manager
|
||||
async_gen = lifespan(MagicMock())
|
||||
await async_gen.__aenter__()
|
||||
|
||||
# Verify model setup was called
|
||||
mock_tts_model.setup.assert_called_once()
|
||||
|
||||
# Clean up
|
||||
await async_gen.__aexit__(None, None, None)
412
api/tests/test_openai_endpoints.py
Normal file
@ -0,0 +1,412 @@
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from fastapi.testclient import TestClient
|
||||
import numpy as np
|
||||
import asyncio
|
||||
from typing import AsyncGenerator
|
||||
import os
|
||||
import json
|
||||
|
||||
from api.src.main import app
|
||||
from api.src.services.tts_service import TTSService
|
||||
from api.src.core.config import settings
|
||||
from api.src.routers.openai_compatible import (
|
||||
load_openai_mappings,
|
||||
get_tts_service,
|
||||
stream_audio_chunks
|
||||
)
|
||||
from api.src.structures.schemas import OpenAISpeechRequest
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
@pytest.fixture
|
||||
def test_voice():
|
||||
"""Fixture providing a test voice name."""
|
||||
return "test_voice"
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_mappings():
|
||||
"""Mock OpenAI mappings for testing."""
|
||||
with patch("api.src.routers.openai_compatible._openai_mappings", {
|
||||
"models": {
|
||||
"tts-1": "kokoro-v0_19",
|
||||
"tts-1-hd": "kokoro-v0_19"
|
||||
},
|
||||
"voices": {
|
||||
"alloy": "am_adam",
|
||||
"nova": "bf_isabella"
|
||||
}
|
||||
}):
|
||||
yield
|
||||
|
||||
@pytest.fixture
|
||||
def mock_json_file(tmp_path):
|
||||
"""Create a temporary mock JSON file."""
|
||||
content = {
|
||||
"models": {"test-model": "test-kokoro"},
|
||||
"voices": {"test-voice": "test-internal"}
|
||||
}
|
||||
json_file = tmp_path / "test_mappings.json"
|
||||
json_file.write_text(json.dumps(content))
|
||||
return json_file
|
||||
|
||||
def test_load_openai_mappings(mock_json_file):
|
||||
"""Test loading OpenAI mappings from JSON file"""
|
||||
with patch("os.path.join", return_value=str(mock_json_file)):
|
||||
mappings = load_openai_mappings()
|
||||
assert "models" in mappings
|
||||
assert "voices" in mappings
|
||||
assert mappings["models"]["test-model"] == "test-kokoro"
|
||||
assert mappings["voices"]["test-voice"] == "test-internal"
|
||||
|
||||
def test_load_openai_mappings_file_not_found():
|
||||
"""Test handling of missing mappings file"""
|
||||
with patch("os.path.join", return_value="/nonexistent/path"):
|
||||
mappings = load_openai_mappings()
|
||||
assert mappings == {"models": {}, "voices": {}}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tts_service_initialization():
|
||||
"""Test TTSService initialization"""
|
||||
with patch("api.src.routers.openai_compatible._tts_service", None):
|
||||
with patch("api.src.routers.openai_compatible._init_lock", None):
|
||||
with patch("api.src.services.tts_service.TTSService.create") as mock_create:
|
||||
mock_service = AsyncMock()
|
||||
mock_create.return_value = mock_service
|
||||
|
||||
# Test concurrent access
|
||||
async def get_service():
|
||||
return await get_tts_service()
|
||||
|
||||
# Create multiple concurrent requests
|
||||
tasks = [get_service() for _ in range(5)]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# Verify service was created only once
|
||||
mock_create.assert_called_once()
|
||||
assert all(r == mock_service for r in results)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_audio_chunks_client_disconnect():
|
||||
"""Test handling of client disconnect during streaming"""
|
||||
mock_request = MagicMock()
|
||||
mock_request.is_disconnected = AsyncMock(return_value=True)
|
||||
|
||||
mock_service = AsyncMock()
|
||||
async def mock_stream(*args, **kwargs):
|
||||
for i in range(5):
|
||||
yield b"chunk"
|
||||
mock_service.generate_audio_stream = mock_stream
|
||||
mock_service.list_voices.return_value = ["test_voice"]
|
||||
|
||||
request = OpenAISpeechRequest(
|
||||
model="kokoro",
|
||||
input="Test text",
|
||||
voice="test_voice",
|
||||
response_format="mp3",
|
||||
stream=True,
|
||||
speed=1.0
|
||||
)
|
||||
|
||||
chunks = []
|
||||
async for chunk in stream_audio_chunks(mock_service, request, mock_request):
|
||||
chunks.append(chunk)
|
||||
|
||||
assert len(chunks) == 0 # Should stop immediately due to disconnect
|
||||
|
||||
def test_openai_voice_mapping(mock_tts_service, mock_openai_mappings):
|
||||
"""Test OpenAI voice name mapping"""
|
||||
mock_tts_service.list_voices.return_value = ["am_adam", "bf_isabella"]
|
||||
|
||||
response = client.post(
|
||||
"/v1/audio/speech",
|
||||
json={
|
||||
"model": "tts-1",
|
||||
"input": "Hello world",
|
||||
"voice": "alloy", # OpenAI voice name
|
||||
"response_format": "mp3",
|
||||
"stream": False
|
||||
}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
mock_tts_service.generate_audio.assert_called_once()
|
||||
assert mock_tts_service.generate_audio.call_args[1]["voice"] == "am_adam"
|
||||
|
||||
def test_openai_voice_mapping_streaming(mock_tts_service, mock_openai_mappings, mock_audio_bytes):
|
||||
"""Test OpenAI voice mapping in streaming mode"""
|
||||
mock_tts_service.list_voices.return_value = ["am_adam", "bf_isabella"]
|
||||
|
||||
response = client.post(
|
||||
"/v1/audio/speech",
|
||||
json={
|
||||
"model": "tts-1-hd",
|
||||
"input": "Hello world",
|
||||
"voice": "nova", # OpenAI voice name
|
||||
"response_format": "mp3",
|
||||
"stream": True
|
||||
}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
content = b""
|
||||
for chunk in response.iter_bytes():
|
||||
content += chunk
|
||||
assert content == mock_audio_bytes
|
||||
|
||||
def test_invalid_openai_model(mock_tts_service, mock_openai_mappings):
|
||||
"""Test error handling for invalid OpenAI model"""
|
||||
response = client.post(
|
||||
"/v1/audio/speech",
|
||||
json={
|
||||
"model": "invalid-model",
|
||||
"input": "Hello world",
|
||||
"voice": "alloy",
|
||||
"response_format": "mp3",
|
||||
"stream": False
|
||||
}
|
||||
)
|
||||
assert response.status_code == 400
|
||||
error_response = response.json()
|
||||
assert error_response["detail"]["error"] == "invalid_model"
|
||||
assert "Unsupported model" in error_response["detail"]["message"]
|
||||
|
||||
@pytest.fixture
|
||||
def mock_audio_bytes():
|
||||
"""Mock audio bytes for testing."""
|
||||
return b"mock audio data"
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tts_service(mock_audio_bytes):
|
||||
"""Mock TTS service for testing."""
|
||||
with patch("api.src.routers.openai_compatible.get_tts_service") as mock_get:
|
||||
service = AsyncMock(spec=TTSService)
|
||||
service.generate_audio.return_value = (np.zeros(1000), 0.1)
|
||||
|
||||
async def mock_stream(*args, **kwargs) -> AsyncGenerator[bytes, None]:
|
||||
yield mock_audio_bytes
|
||||
|
||||
service.generate_audio_stream = mock_stream
|
||||
service.list_voices.return_value = ["test_voice", "voice1", "voice2"]
|
||||
service.combine_voices.return_value = "voice1_voice2"
|
||||
|
||||
mock_get.return_value = service
|
||||
mock_get.side_effect = None
|
||||
yield service
|
||||
|
||||
@patch('api.src.services.audio.AudioService.convert_audio')
|
||||
def test_openai_speech_endpoint(mock_convert, mock_tts_service, test_voice, mock_audio_bytes):
|
||||
"""Test the OpenAI-compatible speech endpoint with basic MP3 generation"""
|
||||
# Configure mocks
|
||||
mock_tts_service.generate_audio.return_value = (np.zeros(1000), 0.1)
|
||||
mock_convert.return_value = mock_audio_bytes
|
||||
|
||||
response = client.post(
|
||||
"/v1/audio/speech",
|
||||
json={
|
||||
"model": "kokoro",
|
||||
"input": "Hello world",
|
||||
"voice": test_voice,
|
||||
"response_format": "mp3",
|
||||
"stream": False
|
||||
}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "audio/mpeg"
|
||||
assert len(response.content) > 0
|
||||
assert response.content == mock_audio_bytes
|
||||
|
||||
mock_tts_service.generate_audio.assert_called_once()
|
||||
mock_convert.assert_called_once()
|
||||
|
||||
def test_openai_speech_streaming(mock_tts_service, test_voice, mock_audio_bytes):
|
||||
"""Test the OpenAI-compatible speech endpoint with streaming"""
|
||||
response = client.post(
|
||||
"/v1/audio/speech",
|
||||
json={
|
||||
"model": "kokoro",
|
||||
"input": "Hello world",
|
||||
"voice": test_voice,
|
||||
"response_format": "mp3",
|
||||
"stream": True
|
||||
}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "audio/mpeg"
|
||||
assert "Transfer-Encoding" in response.headers
|
||||
assert response.headers["Transfer-Encoding"] == "chunked"
|
||||
|
||||
content = b""
|
||||
for chunk in response.iter_bytes():
|
||||
content += chunk
|
||||
assert content == mock_audio_bytes
|
||||
|
||||
def test_openai_speech_pcm_streaming(mock_tts_service, test_voice, mock_audio_bytes):
|
||||
"""Test PCM streaming format"""
|
||||
response = client.post(
|
||||
"/v1/audio/speech",
|
||||
json={
|
||||
"model": "kokoro",
|
||||
"input": "Hello world",
|
||||
"voice": test_voice,
|
||||
"response_format": "pcm",
|
||||
"stream": True
|
||||
}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "audio/pcm"
|
||||
|
||||
content = b""
|
||||
for chunk in response.iter_bytes():
|
||||
content += chunk
|
||||
assert content == mock_audio_bytes
|
||||
|
||||
def test_openai_speech_invalid_voice(mock_tts_service):
|
||||
"""Test error handling for invalid voice"""
|
||||
mock_tts_service.generate_audio.side_effect = ValueError("Voice 'invalid_voice' not found")
|
||||
|
||||
response = client.post(
|
||||
"/v1/audio/speech",
|
||||
json={
|
||||
"model": "kokoro",
|
||||
"input": "Hello world",
|
||||
"voice": "invalid_voice",
|
||||
"response_format": "mp3",
|
||||
"stream": False
|
||||
}
|
||||
)
|
||||
assert response.status_code == 400
|
||||
error_response = response.json()
|
||||
assert error_response["detail"]["error"] == "validation_error"
|
||||
assert "Voice 'invalid_voice' not found" in error_response["detail"]["message"]
|
||||
assert error_response["detail"]["type"] == "invalid_request_error"
|
||||
|
||||
def test_openai_speech_empty_text(mock_tts_service, test_voice):
|
||||
"""Test error handling for empty text"""
|
||||
async def mock_error_stream(*args, **kwargs):
|
||||
raise ValueError("Text is empty after preprocessing")
|
||||
|
||||
mock_tts_service.generate_audio = mock_error_stream
|
||||
mock_tts_service.list_voices.return_value = ["test_voice"]
|
||||
|
||||
response = client.post(
|
||||
"/v1/audio/speech",
|
||||
json={
|
||||
"model": "kokoro",
|
||||
"input": "",
|
||||
"voice": test_voice,
|
||||
"response_format": "mp3",
|
||||
"stream": False
|
||||
}
|
||||
)
|
||||
assert response.status_code == 400
|
||||
error_response = response.json()
|
||||
assert error_response["detail"]["error"] == "validation_error"
|
||||
assert "Text is empty after preprocessing" in error_response["detail"]["message"]
|
||||
assert error_response["detail"]["type"] == "invalid_request_error"
|
||||
|
||||
def test_openai_speech_invalid_format(mock_tts_service, test_voice):
|
||||
"""Test error handling for invalid format"""
|
||||
response = client.post(
|
||||
"/v1/audio/speech",
|
||||
json={
|
||||
"model": "kokoro",
|
||||
"input": "Hello world",
|
||||
"voice": test_voice,
|
||||
"response_format": "invalid_format",
|
||||
"stream": False
|
||||
}
|
||||
)
|
||||
assert response.status_code == 422 # Validation error from Pydantic
|
||||
|
||||
def test_list_voices(mock_tts_service):
|
||||
"""Test listing available voices"""
|
||||
# Override the mock for this specific test
|
||||
mock_tts_service.list_voices.return_value = ["voice1", "voice2"]
|
||||
|
||||
response = client.get("/v1/audio/voices")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "voices" in data
|
||||
assert len(data["voices"]) == 2
|
||||
assert "voice1" in data["voices"]
|
||||
assert "voice2" in data["voices"]
|
||||
|
||||
def test_combine_voices(mock_tts_service):
|
||||
"""Test combining voices endpoint"""
|
||||
response = client.post(
|
||||
"/v1/audio/voices/combine",
|
||||
json="voice1+voice2"
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "voice" in data
|
||||
assert data["voice"] == "voice1_voice2"
|
||||
|
||||
def test_server_error(mock_tts_service, test_voice):
|
||||
"""Test handling of server errors"""
|
||||
async def mock_error_stream(*args, **kwargs):
|
||||
raise RuntimeError("Internal server error")
|
||||
|
||||
mock_tts_service.generate_audio = mock_error_stream
|
||||
mock_tts_service.list_voices.return_value = ["test_voice"]
|
||||
|
||||
response = client.post(
|
||||
"/v1/audio/speech",
|
||||
json={
|
||||
"model": "kokoro",
|
||||
"input": "Hello world",
|
||||
"voice": test_voice,
|
||||
"response_format": "mp3",
|
||||
"stream": False
|
||||
}
|
||||
)
|
||||
assert response.status_code == 500
|
||||
error_response = response.json()
|
||||
assert error_response["detail"]["error"] == "processing_error"
|
||||
assert error_response["detail"]["type"] == "server_error"
|
||||
|
||||
def test_streaming_error(mock_tts_service, test_voice):
|
||||
"""Test handling streaming errors"""
|
||||
# Mock process_voices to raise the error
|
||||
mock_tts_service.list_voices.side_effect = RuntimeError("Streaming failed")
|
||||
|
||||
response = client.post(
|
||||
"/v1/audio/speech",
|
||||
json={
|
||||
"model": "kokoro",
|
||||
"input": "Hello world",
|
||||
"voice": test_voice,
|
||||
"response_format": "mp3",
|
||||
"stream": True
|
||||
}
|
||||
)
|
||||
|
||||
assert response.status_code == 500
|
||||
error_data = response.json()
|
||||
assert error_data["detail"]["error"] == "processing_error"
|
||||
assert error_data["detail"]["type"] == "server_error"
|
||||
assert "Streaming failed" in error_data["detail"]["message"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_initialization_error():
|
||||
"""Test handling of streaming initialization errors"""
|
||||
mock_service = AsyncMock()
|
||||
async def mock_error_stream(*args, **kwargs):
|
||||
if False: # This makes it a proper generator
|
||||
yield b""
|
||||
raise RuntimeError("Failed to initialize stream")
|
||||
mock_service.generate_audio_stream = mock_error_stream
|
||||
mock_service.list_voices.return_value = ["test_voice"]
|
||||
|
||||
request = OpenAISpeechRequest(
|
||||
model="kokoro",
|
||||
input="Test text",
|
||||
voice="test_voice",
|
||||
response_format="mp3",
|
||||
stream=True,
|
||||
speed=1.0
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError) as exc:
|
||||
async for _ in stream_audio_chunks(mock_service, request, MagicMock()):
|
||||
pass
|
||||
assert "Failed to initialize stream" in str(exc.value)
@ -1,122 +0,0 @@
"""Tests for text processing endpoints"""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import AsyncClient
|
||||
|
||||
from ..src.main import app
|
||||
from .conftest import MockTTSModel
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def async_client():
|
||||
async with AsyncClient(app=app, base_url="http://test") as ac:
|
||||
yield ac
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_phonemize_endpoint(async_client):
|
||||
"""Test phoneme generation endpoint"""
|
||||
with patch("api.src.routers.development.phonemize") as mock_phonemize, patch(
|
||||
"api.src.routers.development.tokenize"
|
||||
) as mock_tokenize:
|
||||
# Setup mocks
|
||||
mock_phonemize.return_value = "həlˈoʊ"
|
||||
mock_tokenize.return_value = [1, 2, 3]
|
||||
|
||||
# Test request
|
||||
response = await async_client.post(
|
||||
"/text/phonemize", json={"text": "hello", "language": "a"}
|
||||
)
|
||||
|
||||
# Verify response
|
||||
assert response.status_code == 200
|
||||
result = response.json()
|
||||
assert result["phonemes"] == "həlˈoʊ"
|
||||
assert result["tokens"] == [0, 1, 2, 3, 0] # Should add start/end tokens
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_phonemize_empty_text(async_client):
|
||||
"""Test phoneme generation with empty text"""
|
||||
response = await async_client.post(
|
||||
"/text/phonemize", json={"text": "", "language": "a"}
|
||||
)
|
||||
|
||||
assert response.status_code == 500
|
||||
assert "error" in response.json()["detail"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_from_phonemes(
|
||||
async_client, mock_tts_service, mock_audio_service
|
||||
):
|
||||
"""Test audio generation from phonemes"""
|
||||
with patch(
|
||||
"api.src.routers.development.TTSService", return_value=mock_tts_service
|
||||
):
|
||||
response = await async_client.post(
|
||||
"/text/generate_from_phonemes",
|
||||
json={"phonemes": "həlˈoʊ", "voice": "af_bella", "speed": 1.0},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "audio/wav"
|
||||
assert (
|
||||
response.headers["content-disposition"] == "attachment; filename=speech.wav"
|
||||
)
|
||||
assert response.content == b"mock audio data"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_from_phonemes_invalid_voice(async_client, mock_tts_service):
|
||||
"""Test audio generation with invalid voice"""
|
||||
mock_tts_service._get_voice_path.return_value = None
|
||||
with patch(
|
||||
"api.src.routers.development.TTSService", return_value=mock_tts_service
|
||||
):
|
||||
response = await async_client.post(
|
||||
"/text/generate_from_phonemes",
|
||||
json={"phonemes": "həlˈoʊ", "voice": "invalid_voice", "speed": 1.0},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "Voice not found" in response.json()["detail"]["message"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_from_phonemes_invalid_speed(async_client, monkeypatch):
|
||||
"""Test audio generation with invalid speed"""
|
||||
# Mock TTSModel initialization
|
||||
mock_model = Mock()
|
||||
mock_model.generate_from_tokens = Mock(return_value=np.zeros(48000))
|
||||
monkeypatch.setattr("api.src.services.tts_model.TTSModel._instance", mock_model)
|
||||
monkeypatch.setattr(
|
||||
"api.src.services.tts_model.TTSModel.get_instance",
|
||||
Mock(return_value=mock_model),
|
||||
)
|
||||
|
||||
response = await async_client.post(
|
||||
"/text/generate_from_phonemes",
|
||||
json={"phonemes": "həlˈoʊ", "voice": "af_bella", "speed": -1.0},
|
||||
)
|
||||
|
||||
assert response.status_code == 422 # Validation error
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_from_phonemes_empty_phonemes(async_client, mock_tts_service):
|
||||
"""Test audio generation with empty phonemes"""
|
||||
with patch(
|
||||
"api.src.routers.development.TTSService", return_value=mock_tts_service
|
||||
):
|
||||
response = await async_client.post(
|
||||
"/text/generate_from_phonemes",
|
||||
json={"phonemes": "", "voice": "af_bella", "speed": 1.0},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "Invalid request" in response.json()["detail"]["error"]
@ -1,201 +0,0 @@
"""Tests for TTS model implementations"""
|
||||
|
||||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from api.src.services.tts_base import TTSBaseModel
|
||||
from api.src.services.tts_cpu import TTSCPUModel
|
||||
from api.src.services.tts_gpu import TTSGPUModel, length_to_mask
|
||||
|
||||
|
||||
# Base Model Tests
|
||||
def test_get_device_error():
|
||||
"""Test get_device() raises error when not initialized"""
|
||||
TTSBaseModel._device = None
|
||||
with pytest.raises(RuntimeError, match="Model not initialized"):
|
||||
TTSBaseModel.get_device()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("torch.cuda.is_available")
|
||||
@patch("os.path.exists")
|
||||
@patch("os.path.join")
|
||||
@patch("os.listdir")
|
||||
@patch("torch.load")
|
||||
@patch("torch.save")
|
||||
@patch("api.src.services.tts_base.settings")
|
||||
@patch("api.src.services.warmup.WarmupService")
|
||||
async def test_setup_cuda_available(
|
||||
mock_warmup_class, mock_settings, mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available
|
||||
):
|
||||
"""Test setup with CUDA available"""
|
||||
TTSBaseModel._device = None
|
||||
# Mock CUDA as unavailable since we're using CPU PyTorch
|
||||
mock_cuda_available.return_value = False
|
||||
mock_exists.return_value = True
|
||||
mock_load.return_value = torch.zeros(1)
|
||||
mock_listdir.return_value = ["voice1.pt", "voice2.pt"]
|
||||
mock_join.return_value = "/mocked/path"
|
||||
|
||||
# Configure mock settings
|
||||
mock_settings.model_dir = "/mock/model/dir"
|
||||
mock_settings.onnx_model_path = "model.onnx"
|
||||
mock_settings.voices_dir = "voices"
|
||||
|
||||
# Configure mock warmup service
|
||||
mock_warmup = MagicMock()
|
||||
mock_warmup.load_voices.return_value = [torch.zeros(1)]
|
||||
mock_warmup.warmup_voices = AsyncMock()
|
||||
mock_warmup_class.return_value = mock_warmup
|
||||
|
||||
# Create mock model
|
||||
mock_model = MagicMock()
|
||||
mock_model.bert = MagicMock()
|
||||
mock_model.process_text = MagicMock(return_value=("dummy", [1, 2, 3]))
|
||||
mock_model.generate_from_tokens = MagicMock(return_value=np.zeros(1000))
|
||||
|
||||
# Mock initialize to return our mock model
|
||||
TTSBaseModel.initialize = MagicMock(return_value=mock_model)
|
||||
TTSBaseModel._instance = mock_model
|
||||
|
||||
voice_count = await TTSBaseModel.setup()
|
||||
assert TTSBaseModel._device == "cpu"
|
||||
assert voice_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("torch.cuda.is_available")
|
||||
@patch("os.path.exists")
|
||||
@patch("os.path.join")
|
||||
@patch("os.listdir")
|
||||
@patch("torch.load")
|
||||
@patch("torch.save")
|
||||
@patch("api.src.services.tts_base.settings")
|
||||
@patch("api.src.services.warmup.WarmupService")
|
||||
async def test_setup_cuda_unavailable(
|
||||
mock_warmup_class, mock_settings, mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available
|
||||
):
|
||||
"""Test setup with CUDA unavailable"""
|
||||
TTSBaseModel._device = None
|
||||
mock_cuda_available.return_value = False
|
||||
mock_exists.return_value = True
|
||||
mock_load.return_value = torch.zeros(1)
|
||||
mock_listdir.return_value = ["voice1.pt", "voice2.pt"]
|
||||
mock_join.return_value = "/mocked/path"
|
||||
|
||||
# Configure mock settings
|
||||
mock_settings.model_dir = "/mock/model/dir"
|
||||
mock_settings.onnx_model_path = "model.onnx"
|
||||
mock_settings.voices_dir = "voices"
|
||||
|
||||
# Configure mock warmup service
|
||||
mock_warmup = MagicMock()
|
||||
mock_warmup.load_voices.return_value = [torch.zeros(1)]
|
||||
mock_warmup.warmup_voices = AsyncMock()
|
||||
mock_warmup_class.return_value = mock_warmup
|
||||
|
||||
# Create mock model
|
||||
mock_model = MagicMock()
|
||||
mock_model.bert = MagicMock()
|
||||
mock_model.process_text = MagicMock(return_value=("dummy", [1, 2, 3]))
|
||||
mock_model.generate_from_tokens = MagicMock(return_value=np.zeros(1000))
|
||||
|
||||
# Mock initialize to return our mock model
|
||||
TTSBaseModel.initialize = MagicMock(return_value=mock_model)
|
||||
TTSBaseModel._instance = mock_model
|
||||
|
||||
voice_count = await TTSBaseModel.setup()
|
||||
assert TTSBaseModel._device == "cpu"
|
||||
assert voice_count == 2
|
||||
|
||||
|
||||
# CPU Model Tests
|
||||
def test_cpu_initialize_missing_model():
|
||||
"""Test CPU initialize with missing model"""
|
||||
TTSCPUModel._onnx_session = None # Reset the session
|
||||
with patch("os.path.exists", return_value=False), patch(
|
||||
"onnxruntime.InferenceSession", return_value=None
|
||||
):
|
||||
result = TTSCPUModel.initialize("dummy_dir")
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_cpu_generate_uninitialized():
|
||||
"""Test CPU generate methods with uninitialized model"""
|
||||
TTSCPUModel._onnx_session = None
|
||||
|
||||
with pytest.raises(RuntimeError, match="ONNX model not initialized"):
|
||||
TTSCPUModel.generate_from_text("test", torch.zeros(1), "en", 1.0)
|
||||
|
||||
with pytest.raises(RuntimeError, match="ONNX model not initialized"):
|
||||
TTSCPUModel.generate_from_tokens([1, 2, 3], torch.zeros(1), 1.0)
|
||||
|
||||
|
||||
def test_cpu_process_text():
|
||||
"""Test CPU process_text functionality"""
|
||||
with patch("api.src.services.tts_cpu.phonemize") as mock_phonemize, patch(
|
||||
"api.src.services.tts_cpu.tokenize"
|
||||
) as mock_tokenize:
|
||||
mock_phonemize.return_value = "test phonemes"
|
||||
mock_tokenize.return_value = [1, 2, 3]
|
||||
|
||||
phonemes, tokens = TTSCPUModel.process_text("test", "en")
|
||||
assert phonemes == "test phonemes"
|
||||
assert tokens == [0, 1, 2, 3, 0] # Should add start/end tokens
|
||||
|
||||
|
||||
# GPU Model Tests
|
||||
@patch("torch.cuda.is_available")
|
||||
def test_gpu_initialize_cuda_unavailable(mock_cuda_available):
|
||||
"""Test GPU initialize with CUDA unavailable"""
|
||||
mock_cuda_available.return_value = False
|
||||
TTSGPUModel._instance = None
|
||||
|
||||
result = TTSGPUModel.initialize("dummy_dir", "dummy_path")
|
||||
assert result is None
|
||||
|
||||
|
||||
@patch("api.src.services.tts_gpu.length_to_mask")
|
||||
def test_gpu_length_to_mask(mock_length_to_mask):
|
||||
"""Test length_to_mask function"""
|
||||
# Setup mock return value
|
||||
expected_mask = torch.tensor(
|
||||
[[False, False, False, True, True], [False, False, False, False, False]]
|
||||
)
|
||||
mock_length_to_mask.return_value = expected_mask
|
||||
|
||||
# Call function with test input
|
||||
lengths = torch.tensor([3, 5])
|
||||
mask = mock_length_to_mask(lengths)
|
||||
|
||||
# Verify mock was called with correct input
|
||||
mock_length_to_mask.assert_called_once()
|
||||
assert torch.equal(mask, expected_mask)
|
||||
|
||||
|
||||
def test_gpu_generate_uninitialized():
|
||||
"""Test GPU generate methods with uninitialized model"""
|
||||
TTSGPUModel._instance = None
|
||||
|
||||
with pytest.raises(RuntimeError, match="GPU model not initialized"):
|
||||
TTSGPUModel.generate_from_text("test", torch.zeros(1), "en", 1.0)
|
||||
|
||||
with pytest.raises(RuntimeError, match="GPU model not initialized"):
|
||||
TTSGPUModel.generate_from_tokens([1, 2, 3], torch.zeros(1), 1.0)
|
||||
|
||||
|
||||
def test_gpu_process_text():
|
||||
"""Test GPU process_text functionality"""
|
||||
with patch("api.src.services.tts_gpu.phonemize") as mock_phonemize, patch(
|
||||
"api.src.services.tts_gpu.tokenize"
|
||||
) as mock_tokenize:
|
||||
mock_phonemize.return_value = "test phonemes"
|
||||
mock_tokenize.return_value = [1, 2, 3]
|
||||
|
||||
phonemes, tokens = TTSGPUModel.process_text("test", "en")
|
||||
assert phonemes == "test phonemes"
|
||||
assert tokens == [1, 2, 3] # GPU implementation doesn't add start/end tokens
@ -1,260 +0,0 @@
"""Tests for TTSService"""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from onnxruntime import InferenceSession
|
||||
|
||||
from api.src.core.config import settings
|
||||
from api.src.services.tts_cpu import TTSCPUModel
|
||||
from api.src.services.tts_gpu import TTSGPUModel
|
||||
from api.src.services.tts_model import TTSModel
|
||||
from api.src.services.tts_service import TTSService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tts_service(monkeypatch):
|
||||
"""Create a TTSService instance for testing"""
|
||||
# Mock TTSModel initialization
|
||||
mock_model = MagicMock()
|
||||
mock_model.generate_from_tokens = MagicMock(return_value=np.zeros(48000))
|
||||
mock_model.process_text = MagicMock(return_value=("mock phonemes", [1, 2, 3]))
|
||||
|
||||
# Set up model instance
|
||||
monkeypatch.setattr("api.src.services.tts_model.TTSModel._instance", mock_model)
|
||||
monkeypatch.setattr(
|
||||
"api.src.services.tts_model.TTSModel.get_instance",
|
||||
MagicMock(return_value=mock_model),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"api.src.services.tts_model.TTSModel.get_device", MagicMock(return_value="cpu")
|
||||
)
|
||||
|
||||
return TTSService()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_audio():
|
||||
"""Generate a simple sine wave for testing"""
|
||||
sample_rate = 24000
|
||||
duration = 0.1 # 100ms
|
||||
t = np.linspace(0, duration, int(sample_rate * duration))
|
||||
frequency = 440 # A4 note
|
||||
return np.sin(2 * np.pi * frequency * t).astype(np.float32)
|
||||
|
||||
|
||||
def test_audio_to_bytes(tts_service, sample_audio):
|
||||
"""Test converting audio tensor to bytes"""
|
||||
audio_bytes = tts_service._audio_to_bytes(sample_audio)
|
||||
assert isinstance(audio_bytes, bytes)
|
||||
assert len(audio_bytes) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_voices(tts_service):
|
||||
"""Test listing available voices"""
|
||||
|
||||
# Override list_voices for testing
|
||||
# # TODO:
|
||||
# Whatever aiofiles does here pathing aiofiles vs aiofiles.os
|
||||
# I am thoroughly confused by it.
|
||||
# Cheating the test as it seems to work in the real world (for now)
|
||||
async def mock_list_voices():
|
||||
return ["voice1", "voice2"]
|
||||
|
||||
tts_service.list_voices = mock_list_voices
|
||||
|
||||
voices = await tts_service.list_voices()
|
||||
assert len(voices) == 2
|
||||
assert "voice1" in voices
|
||||
assert "voice2" in voices
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_voices_error(tts_service):
|
||||
"""Test error handling in list_voices"""
|
||||
|
||||
# Override list_voices for testing
|
||||
# TODO: See above.
|
||||
async def mock_list_voices():
|
||||
return []
|
||||
|
||||
tts_service.list_voices = mock_list_voices
|
||||
|
||||
voices = await tts_service.list_voices()
|
||||
assert voices == []
|
||||
|
||||
|
||||
def mock_model_setup(cuda_available=False):
|
||||
"""Helper function to mock model setup"""
|
||||
# Reset model state
|
||||
TTSModel._instance = None
|
||||
TTSModel._device = None
|
||||
TTSModel._voicepacks = {}
|
||||
|
||||
# Create mock model instance with proper generate method
|
||||
mock_model = MagicMock()
|
||||
mock_model.generate.return_value = np.zeros(24000, dtype=np.float32)
|
||||
TTSModel._instance = mock_model
|
||||
|
||||
# Set device based on CUDA availability
|
||||
TTSModel._device = "cuda" if cuda_available else "cpu"
|
||||
|
||||
return 3 # Return voice count (including af.pt)
|
||||
|
||||
|
||||
def test_model_initialization_cuda():
|
||||
"""Test model initialization with CUDA"""
|
||||
# Simulate CUDA availability
|
||||
voice_count = mock_model_setup(cuda_available=True)
|
||||
|
||||
assert TTSModel.get_device() == "cuda"
|
||||
assert voice_count == 3 # voice1.pt, voice2.pt, af.pt
|
||||
|
||||
|
||||
def test_model_initialization_cpu():
|
||||
"""Test model initialization with CPU"""
|
||||
# Simulate no CUDA availability
|
||||
voice_count = mock_model_setup(cuda_available=False)
|
||||
|
||||
assert TTSModel.get_device() == "cpu"
|
||||
assert voice_count == 3 # voice1.pt, voice2.pt, af.pt
|
||||
|
||||
|
||||
def test_generate_audio_empty_text(tts_service):
|
||||
"""Test generating audio with empty text"""
|
||||
with pytest.raises(ValueError, match="Text is empty after preprocessing"):
|
||||
tts_service._generate_audio("", "af", 1.0)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_settings():
|
||||
"""Mock settings for all tests"""
|
||||
with patch("api.src.services.text_processing.chunker.settings") as mock_settings:
|
||||
mock_settings.max_chunk_size = 300
|
||||
yield mock_settings
|
||||
|
||||
|
||||
@patch("api.src.services.tts_model.TTSModel.get_instance")
|
||||
@patch("api.src.services.tts_model.TTSModel.get_device")
|
||||
@patch("os.path.exists")
|
||||
@patch("kokoro.normalize_text")
|
||||
@patch("kokoro.phonemize")
|
||||
@patch("kokoro.tokenize")
|
||||
@patch("kokoro.generate")
|
||||
@patch("torch.load")
|
||||
def test_generate_audio_phonemize_error(
|
||||
mock_torch_load,
|
||||
mock_generate,
|
||||
mock_tokenize,
|
||||
mock_phonemize,
|
||||
mock_normalize,
|
||||
mock_exists,
|
||||
mock_get_device,
|
||||
mock_instance,
|
||||
tts_service,
|
||||
):
|
||||
"""Test handling phonemization error"""
|
||||
mock_normalize.return_value = "Test text"
|
||||
mock_phonemize.side_effect = Exception("Phonemization failed")
|
||||
mock_instance.return_value = (
|
||||
mock_generate,
|
||||
"cpu",
|
||||
) # Use the same mock for consistency
|
||||
mock_get_device.return_value = "cpu"
|
||||
mock_exists.return_value = True
|
||||
mock_torch_load.return_value = torch.zeros((10, 24000))
|
||||
mock_generate.return_value = (None, None)
|
||||
|
||||
with pytest.raises(ValueError, match="No chunks were processed successfully"):
|
||||
tts_service._generate_audio("Test text", "af", 1.0)
|
||||
|
||||
|
||||
@patch("api.src.services.tts_model.TTSModel.get_instance")
|
||||
@patch("api.src.services.tts_model.TTSModel.get_device")
|
||||
@patch("os.path.exists")
|
||||
@patch("kokoro.normalize_text")
|
||||
@patch("kokoro.phonemize")
|
||||
@patch("kokoro.tokenize")
|
||||
@patch("kokoro.generate")
|
||||
@patch("torch.load")
|
||||
def test_generate_audio_error(
|
||||
mock_torch_load,
|
||||
mock_generate,
|
||||
mock_tokenize,
|
||||
mock_phonemize,
|
||||
mock_normalize,
|
||||
mock_exists,
|
||||
mock_get_device,
|
||||
mock_instance,
|
||||
tts_service,
|
||||
):
|
||||
"""Test handling generation error"""
|
||||
mock_normalize.return_value = "Test text"
|
||||
mock_phonemize.return_value = "Test text"
|
||||
mock_tokenize.return_value = [1, 2] # Return integers instead of strings
|
||||
mock_generate.side_effect = Exception("Generation failed")
|
||||
mock_instance.return_value = (
|
||||
mock_generate,
|
||||
"cpu",
|
||||
) # Use the same mock for consistency
|
||||
mock_get_device.return_value = "cpu"
|
||||
mock_exists.return_value = True
|
||||
mock_torch_load.return_value = torch.zeros((10, 24000))
|
||||
|
||||
with pytest.raises(ValueError, match="No chunks were processed successfully"):
|
||||
tts_service._generate_audio("Test text", "af", 1.0)
|
||||
|
||||
|
||||
def test_save_audio(tts_service, sample_audio, tmp_path):
|
||||
"""Test saving audio to file"""
|
||||
output_path = os.path.join(tmp_path, "test_output.wav")
|
||||
tts_service._save_audio(sample_audio, output_path)
|
||||
assert os.path.exists(output_path)
|
||||
assert os.path.getsize(output_path) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_combine_voices(tts_service):
|
||||
"""Test combining multiple voices"""
|
||||
# Setup mocks for torch operations
|
||||
with patch("torch.load", return_value=torch.tensor([1.0, 2.0])), patch(
|
||||
"torch.stack", return_value=torch.tensor([[1.0, 2.0], [3.0, 4.0]])
|
||||
), patch("torch.mean", return_value=torch.tensor([2.0, 3.0])), patch(
|
||||
"torch.save"
|
||||
), patch("os.path.exists", return_value=True):
|
||||
# Test combining two voices
|
||||
result = await tts_service.combine_voices(["voice1", "voice2"])
|
||||
|
||||
assert result == "voice1_voice2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_combine_voices_invalid_input(tts_service):
|
||||
"""Test combining voices with invalid input"""
|
||||
# Test with empty list
|
||||
with pytest.raises(ValueError, match="At least 2 voices are required"):
|
||||
await tts_service.combine_voices([])
|
||||
|
||||
# Test with single voice
|
||||
with pytest.raises(ValueError, match="At least 2 voices are required"):
|
||||
await tts_service.combine_voices(["voice1"])
|
||||
|
||||
|
||||
@patch("api.src.services.tts_service.TTSService._get_voice_path")
|
||||
@patch("api.src.services.tts_model.TTSModel.get_instance")
|
||||
def test_voicepack_loading_error(mock_get_instance, mock_get_voice_path):
|
||||
"""Test voicepack loading error handling"""
|
||||
mock_get_voice_path.return_value = None
|
||||
mock_instance = MagicMock()
|
||||
mock_instance.generate.return_value = np.zeros(24000, dtype=np.float32)
|
||||
mock_get_instance.return_value = (mock_instance, "cpu")
|
||||
|
||||
TTSModel._voicepacks = {} # Reset voicepacks
|
||||
|
||||
service = TTSService()
|
||||
with pytest.raises(ValueError, match="Voice not found: nonexistent_voice"):
|
||||
service._generate_audio("test", "nonexistent_voice", 1.0)
142
api/tests/test_tts_service_new.py
Normal file
@ -0,0 +1,142 @@
# import pytest
|
||||
# import numpy as np
|
||||
# from unittest.mock import AsyncMock, patch
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_generate_audio(tts_service, mock_audio_output, test_voice):
|
||||
# """Test basic audio generation"""
|
||||
# audio, processing_time = await tts_service.generate_audio(
|
||||
# text="Hello world",
|
||||
# voice=test_voice,
|
||||
# speed=1.0
|
||||
# )
|
||||
|
||||
# assert isinstance(audio, np.ndarray)
|
||||
# assert audio == mock_audio_output.tobytes()
|
||||
# assert processing_time > 0
|
||||
# tts_service.model_manager.generate.assert_called_once()
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_generate_audio_with_combined_voice(tts_service, mock_audio_output):
|
||||
# """Test audio generation with a combined voice"""
|
||||
# test_voices = ["voice1", "voice2"]
|
||||
# combined_id = await tts_service._voice_manager.combine_voices(test_voices)
|
||||
|
||||
# audio, processing_time = await tts_service.generate_audio(
|
||||
# text="Hello world",
|
||||
# voice=combined_id,
|
||||
# speed=1.0
|
||||
# )
|
||||
|
||||
# assert isinstance(audio, np.ndarray)
|
||||
# assert np.array_equal(audio, mock_audio_output)
|
||||
# assert processing_time > 0
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_generate_audio_stream(tts_service, mock_audio_output, test_voice):
|
||||
# """Test streaming audio generation"""
|
||||
# tts_service.model_manager.generate.return_value = mock_audio_output
|
||||
|
||||
# chunks = []
|
||||
# async for chunk in tts_service.generate_audio_stream(
|
||||
# text="Hello world",
|
||||
# voice=test_voice,
|
||||
# speed=1.0,
|
||||
# output_format="pcm"
|
||||
# ):
|
||||
# assert isinstance(chunk, bytes)
|
||||
# chunks.append(chunk)
|
||||
|
||||
# assert len(chunks) > 0
|
||||
# tts_service.model_manager.generate.assert_called()
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_empty_text(tts_service, test_voice):
|
||||
# """Test handling empty text"""
|
||||
# with pytest.raises(ValueError) as exc_info:
|
||||
# await tts_service.generate_audio(
|
||||
# text="",
|
||||
# voice=test_voice,
|
||||
# speed=1.0
|
||||
# )
|
||||
# assert "No audio chunks were generated successfully" in str(exc_info.value)
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_invalid_voice(tts_service):
|
||||
# """Test handling invalid voice"""
|
||||
# tts_service._voice_manager.load_voice.side_effect = ValueError("Voice not found")
|
||||
|
||||
# with pytest.raises(ValueError) as exc_info:
|
||||
# await tts_service.generate_audio(
|
||||
# text="Hello world",
|
||||
# voice="invalid_voice",
|
||||
# speed=1.0
|
||||
# )
|
||||
# assert "Voice not found" in str(exc_info.value)
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_model_generation_error(tts_service, test_voice):
|
||||
# """Test handling model generation error"""
|
||||
# # Make generate return None to simulate failed generation
|
||||
# tts_service.model_manager.generate.return_value = None
|
||||
|
||||
# with pytest.raises(ValueError) as exc_info:
|
||||
# await tts_service.generate_audio(
|
||||
# text="Hello world",
|
||||
# voice=test_voice,
|
||||
# speed=1.0
|
||||
# )
|
||||
# assert "No audio chunks were generated successfully" in str(exc_info.value)
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_streaming_generation_error(tts_service, test_voice):
|
||||
# """Test handling streaming generation error"""
|
||||
# # Make generate return None to simulate failed generation
|
||||
# tts_service.model_manager.generate.return_value = None
|
||||
|
||||
# chunks = []
|
||||
# async for chunk in tts_service.generate_audio_stream(
|
||||
# text="Hello world",
|
||||
# voice=test_voice,
|
||||
# speed=1.0,
|
||||
# output_format="pcm"
|
||||
# ):
|
||||
# chunks.append(chunk)
|
||||
|
||||
# # Should get no chunks if generation fails
|
||||
# assert len(chunks) == 0
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_list_voices(tts_service):
|
||||
# """Test listing available voices"""
|
||||
# voices = await tts_service.list_voices()
|
||||
# assert len(voices) == 2
|
||||
# assert "voice1" in voices
|
||||
# assert "voice2" in voices
|
||||
# tts_service._voice_manager.list_voices.assert_called_once()
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_combine_voices(tts_service):
|
||||
# """Test combining voices"""
|
||||
# test_voices = ["voice1", "voice2"]
|
||||
# combined_id = await tts_service.combine_voices(test_voices)
|
||||
# assert combined_id == "voice1_voice2"
|
||||
# tts_service._voice_manager.combine_voices.assert_called_once_with(test_voices)
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_chunked_text_processing(tts_service, test_voice, mock_audio_output):
|
||||
# """Test processing chunked text"""
|
||||
# # Create text that will force chunking by exceeding max tokens
|
||||
# long_text = "This is a test sentence." * 100 # Should be way over 500 tokens
|
||||
|
||||
# # Don't mock smart_split - let it actually split the text
|
||||
# audio, processing_time = await tts_service.generate_audio(
|
||||
# text=long_text,
|
||||
# voice=test_voice,
|
||||
# speed=1.0
|
||||
# )
|
||||
|
||||
# # Should be called multiple times due to chunking
|
||||
# assert tts_service.model_manager.generate.call_count > 1
|
||||
# assert isinstance(audio, np.ndarray)
|
||||
# assert processing_time > 0
134
api/tests/test_voice_manager.py
Normal file
@ -0,0 +1,134 @@
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
import torch
|
||||
from pathlib import Path
|
||||
|
||||
from ..src.inference.voice_manager import VoiceManager
|
||||
from ..src.structures.model_schemas import VoiceConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_voice_tensor():
|
||||
return torch.randn(10, 10) # Dummy tensor
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def voice_manager():
|
||||
return VoiceManager(VoiceConfig())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_voice(voice_manager, mock_voice_tensor):
|
||||
"""Test loading a single voice"""
|
||||
with patch("api.src.core.paths.load_voice_tensor", new_callable=AsyncMock) as mock_load:
|
||||
mock_load.return_value = mock_voice_tensor
|
||||
with patch("os.path.exists", return_value=True):
|
||||
voice = await voice_manager.load_voice("af_bella", "cpu")
|
||||
assert torch.equal(voice, mock_voice_tensor)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_voice_not_found(voice_manager):
|
||||
"""Test loading non-existent voice"""
|
||||
with patch("os.path.exists", return_value=False):
|
||||
with pytest.raises(RuntimeError, match="Voice not found: invalid_voice"):
|
||||
await voice_manager.load_voice("invalid_voice", "cpu")
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Local saving is optional and not critical to functionality")
|
||||
@pytest.mark.asyncio
|
||||
async def test_combine_voices_with_saving(voice_manager, mock_voice_tensor):
|
||||
"""Test combining voices with local saving enabled"""
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_combine_voices_without_saving(voice_manager, mock_voice_tensor):
|
||||
"""Test combining voices without local saving"""
|
||||
with patch("api.src.core.paths.load_voice_tensor", new_callable=AsyncMock) as mock_load, \
|
||||
patch("torch.save") as mock_save, \
|
||||
patch("os.makedirs"), \
|
||||
patch("os.path.exists", return_value=True):
|
||||
|
||||
# Setup mocks
|
||||
mock_load.return_value = mock_voice_tensor
|
||||
|
||||
# Mock settings
|
||||
with patch("api.src.core.config.settings") as mock_settings:
|
||||
mock_settings.allow_local_voice_saving = False
|
||||
mock_settings.voices_dir = "/mock/voices"
|
||||
|
||||
# Combine voices
|
||||
combined = await voice_manager.combine_voices(["af_bella", "af_sarah"], "cpu")
|
||||
assert combined == "af_bella+af_sarah" # Note: using + separator
|
||||
|
||||
# Verify voice was not saved
|
||||
mock_save.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_combine_voices_single_voice(voice_manager):
|
||||
"""Test combining with single voice"""
|
||||
with pytest.raises(ValueError, match="At least 2 voices are required"):
|
||||
await voice_manager.combine_voices(["af_bella"], "cpu")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_voices(voice_manager):
|
||||
"""Test listing available voices"""
|
||||
with patch("os.listdir", return_value=["af_bella.pt", "af_sarah.pt", "af_bella+af_sarah.pt"]), \
|
||||
patch("os.makedirs"):
|
||||
voices = await voice_manager.list_voices()
|
||||
assert len(voices) == 3
|
||||
assert "af_bella" in voices
|
||||
assert "af_sarah" in voices
|
||||
assert "af_bella+af_sarah" in voices
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_combined_voice(voice_manager, mock_voice_tensor):
|
||||
"""Test loading a combined voice"""
|
||||
with patch("api.src.core.paths.load_voice_tensor", new_callable=AsyncMock) as mock_load:
|
||||
mock_load.return_value = mock_voice_tensor
|
||||
with patch("os.path.exists", return_value=True):
|
||||
voice = await voice_manager.load_voice("af_bella+af_sarah", "cpu")
|
||||
assert torch.equal(voice, mock_voice_tensor)
|
||||
|
||||
|
||||
def test_cache_management(mock_voice_tensor):
|
||||
"""Test voice cache management"""
|
||||
# Create voice manager with small cache size
|
||||
config = VoiceConfig(cache_size=2)
|
||||
voice_manager = VoiceManager(config)
|
||||
|
||||
# Add items to cache
|
||||
voice_manager._voice_cache = {
|
||||
"voice1_cpu": torch.randn(5, 5),
|
||||
"voice2_cpu": torch.randn(5, 5),
|
||||
"voice3_cpu": torch.randn(5, 5), # Add one more than cache size
|
||||
}
|
||||
|
||||
# Try managing cache
|
||||
voice_manager._manage_cache()
|
||||
|
||||
# Check cache size maintained
|
||||
assert len(voice_manager._voice_cache) <= 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_voice_loading_with_cache(voice_manager, mock_voice_tensor):
|
||||
"""Test voice loading with cache enabled"""
|
||||
with patch("api.src.core.paths.load_voice_tensor", new_callable=AsyncMock) as mock_load, \
|
||||
patch("os.path.exists", return_value=True):
|
||||
|
||||
mock_load.return_value = mock_voice_tensor
|
||||
|
||||
# First load should hit disk
|
||||
voice1 = await voice_manager.load_voice("af_bella", "cpu")
|
||||
assert mock_load.call_count == 1
|
||||
|
||||
# Second load should hit cache
|
||||
voice2 = await voice_manager.load_voice("af_bella", "cpu")
|
||||
assert mock_load.call_count == 1 # Still 1
|
||||
|
||||
assert torch.equal(voice1, voice2)
BIN
assets/beta_web_ui.png
Normal file (binary image, 385 KiB)
BIN
assets/docs-screenshot.png
Normal file (binary image, 78 KiB)
BIN
assets/webui-screenshot.png
Normal file (binary image, 283 KiB)
17
debug.http
Normal file
@ -0,0 +1,17 @@
### Get Thread Information
GET http://localhost:8880/debug/threads
Accept: application/json

### Get Storage Information
GET http://localhost:8880/debug/storage
Accept: application/json

### Get System Information
GET http://localhost:8880/debug/system
Accept: application/json

### Get Session Pool Status
# Shows active ONNX sessions, CUDA stream usage, and session ages
# Useful for debugging resource exhaustion issues
GET http://localhost:8880/debug/session_pools
Accept: application/json
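A quick way to exercise these endpoints outside a REST-client editor is a short Python script. This is only a sketch: it assumes the API is reachable at http://localhost:8880, as in debug.http above, and that the requests package is installed.

import requests

BASE_URL = "http://localhost:8880"  # assumed local dev address, matching debug.http

# Hit each debug endpoint and print the top-level keys of the JSON response
for path in ("/debug/threads", "/debug/storage", "/debug/system", "/debug/session_pools"):
    response = requests.get(f"{BASE_URL}{path}", headers={"Accept": "application/json"})
    response.raise_for_status()
    payload = response.json()
    keys = list(payload) if isinstance(payload, dict) else payload
    print(f"{path}: {keys}")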
36
docker/build.sh
Executable file
@ -0,0 +1,36 @@
#!/bin/bash
set -e

# Get version from argument or use default
VERSION=${1:-"latest"}

# GitHub Container Registry settings
REGISTRY="ghcr.io"
OWNER="remsky"
REPO="kokoro-fastapi"

# Create and use a new builder that supports multi-platform builds
docker buildx create --name multiplatform-builder --use || true

# Build CPU image with multi-platform support
echo "Building CPU image..."
docker buildx build --platform linux/amd64,linux/arm64 \
    -t ${REGISTRY}/${OWNER}/${REPO}-cpu:${VERSION} \
    -t ${REGISTRY}/${OWNER}/${REPO}-cpu:latest \
    -f docker/cpu/Dockerfile \
    --push .

# Build GPU image with multi-platform support
echo "Building GPU image..."
docker buildx build --platform linux/amd64,linux/arm64 \
    -t ${REGISTRY}/${OWNER}/${REPO}-gpu:${VERSION} \
    -t ${REGISTRY}/${OWNER}/${REPO}-gpu:latest \
    -f docker/gpu/Dockerfile \
    --push .

echo "Build complete!"
echo "Created images:"
echo "- ${REGISTRY}/${OWNER}/${REPO}-cpu:${VERSION} (linux/amd64, linux/arm64)"
echo "- ${REGISTRY}/${OWNER}/${REPO}-cpu:latest (linux/amd64, linux/arm64)"
echo "- ${REGISTRY}/${OWNER}/${REPO}-gpu:${VERSION} (linux/amd64, linux/arm64)"
echo "- ${REGISTRY}/${OWNER}/${REPO}-gpu:latest (linux/amd64, linux/arm64)"
@ -1,4 +1,4 @@
FROM python:3.10-slim
|
||||
FROM --platform=$BUILDPLATFORM python:3.10-slim
|
||||
|
||||
# Install dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
|
@ -10,36 +10,17 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install uv
|
||||
# Install uv for speed and glory
|
||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
|
||||
|
||||
# Create non-root user
|
||||
RUN useradd -m -u 1000 appuser
|
||||
|
||||
# Create directories and set ownership
|
||||
RUN mkdir -p /app/models && \
|
||||
mkdir -p /app/api/src/voices && \
|
||||
RUN mkdir -p /app/api/src/voices && \
|
||||
chown -R appuser:appuser /app
|
||||
|
||||
USER appuser
|
||||
|
||||
# Download and extract models
|
||||
WORKDIR /app/models
|
||||
RUN set -x && \
|
||||
curl -L -o model.tar.gz https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.0.1/kokoro-82m-onnx.tar.gz && \
|
||||
echo "Downloaded model.tar.gz:" && ls -lh model.tar.gz && \
|
||||
tar xzf model.tar.gz && \
|
||||
echo "Contents after extraction:" && ls -lhR && \
|
||||
rm model.tar.gz && \
|
||||
echo "Final contents:" && ls -lhR
|
||||
|
||||
# Download and extract voice models
|
||||
WORKDIR /app/api/src/voices
|
||||
RUN curl -L -o voices.tar.gz https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.0.1/voice-models.tar.gz && \
|
||||
tar xzf voices.tar.gz && \
|
||||
rm voices.tar.gz
|
||||
|
||||
# Switch back to app directory
|
||||
WORKDIR /app
|
||||
|
||||
# Copy dependency files
|
||||
|
@ -50,8 +31,10 @@ RUN --mount=type=cache,target=/root/.cache/uv \
uv venv && \
|
||||
uv sync --extra cpu --no-install-project
|
||||
|
||||
# Copy project files
|
||||
# Copy project files including models
|
||||
COPY --chown=appuser:appuser api ./api
|
||||
COPY --chown=appuser:appuser web ./web
|
||||
COPY --chown=appuser:appuser docker/scripts/download_model.* ./
|
||||
|
||||
# Install project
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
|
@ -59,9 +42,22 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
# Set environment variables
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
ENV PYTHONPATH=/app:/app/models
|
||||
ENV PYTHONPATH=/app
|
||||
ENV PATH="/app/.venv/bin:$PATH"
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
ENV USE_GPU=false
|
||||
ENV USE_ONNX=true
|
||||
ENV DOWNLOAD_ONNX=true
|
||||
ENV DOWNLOAD_PTH=false
|
||||
|
||||
# Download models based on environment variables
|
||||
RUN if [ "$DOWNLOAD_ONNX" = "true" ]; then \
|
||||
python download_model.py --type onnx; \
|
||||
fi && \
|
||||
if [ "$DOWNLOAD_PTH" = "true" ]; then \
|
||||
python download_model.py --type pth; \
|
||||
fi
|
||||
|
||||
# Run FastAPI server
|
||||
CMD ["uv", "run", "python", "-m", "uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]
|
||||
|
|
|
@ -6,12 +6,11 @@ services:
context: ../..
|
||||
dockerfile: docker/cpu/Dockerfile
|
||||
volumes:
|
||||
- ../../api/src:/app/api/src
|
||||
- ../../api/src/voices:/app/api/src/voices
|
||||
- ../../api:/app/api
|
||||
ports:
|
||||
- "8880:8880"
|
||||
environment:
|
||||
- PYTHONPATH=/app:/app/models
|
||||
- PYTHONPATH=/app:/app/api
|
||||
# ONNX Optimization Settings for vectorized operations
|
||||
- ONNX_NUM_THREADS=8 # Maximize core usage for vectorized ops
|
||||
- ONNX_INTER_OP_THREADS=4 # Higher inter-op for parallel matrix operations
|
||||
|
@ -20,20 +19,20 @@ services:
- ONNX_MEMORY_PATTERN=true
|
||||
- ONNX_ARENA_EXTEND_STRATEGY=kNextPowerOfTwo
|
||||
|
||||
# Gradio UI service [Comment out everything below if you don't need it]
|
||||
gradio-ui:
|
||||
image: ghcr.io/remsky/kokoro-fastapi-ui:v0.1.0
|
||||
# Uncomment below (and comment out above) to build from source instead of using the released image
|
||||
build:
|
||||
context: ../../ui
|
||||
ports:
|
||||
- "7860:7860"
|
||||
volumes:
|
||||
- ../../ui/data:/app/ui/data
|
||||
- ../../ui/app.py:/app/app.py # Mount app.py for hot reload
|
||||
environment:
|
||||
- GRADIO_WATCH=True # Enable hot reloading
|
||||
- PYTHONUNBUFFERED=1 # Ensure Python output is not buffered
|
||||
- DISABLE_LOCAL_SAVING=false # Set to 'true' to disable local saving and hide file view
|
||||
- API_HOST=kokoro-tts # Set TTS service URL
|
||||
- API_PORT=8880 # Set TTS service PORT
|
||||
# # Gradio UI service [Comment out everything below if you don't need it]
|
||||
# gradio-ui:
|
||||
# image: ghcr.io/remsky/kokoro-fastapi-ui:v0.1.0
|
||||
# # Uncomment below (and comment out above) to build from source instead of using the released image
|
||||
# build:
|
||||
# context: ../../ui
|
||||
# ports:
|
||||
# - "7860:7860"
|
||||
# volumes:
|
||||
# - ../../ui/data:/app/ui/data
|
||||
# - ../../ui/app.py:/app/app.py # Mount app.py for hot reload
|
||||
# environment:
|
||||
# - GRADIO_WATCH=True # Enable hot reloading
|
||||
# - PYTHONUNBUFFERED=1 # Ensure Python output is not buffered
|
||||
# - DISABLE_LOCAL_SAVING=false # Set to 'true' to disable local saving and hide file view
|
||||
# - API_HOST=kokoro-tts # Set TTS service URL
|
||||
# - API_PORT=8880 # Set TTS service PORT
|
||||
|
|
|
@ -1,22 +0,0 @@
[project]
|
||||
name = "kokoro-fastapi-cpu"
|
||||
version = "0.1.0"
|
||||
description = "FastAPI TTS Service - CPU Version"
|
||||
readme = "../README.md"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
# Core ML/DL for CPU
|
||||
"torch>=2.5.1",
|
||||
"transformers==4.47.1",
|
||||
]
|
||||
|
||||
[tool.uv.workspace]
|
||||
members = ["../shared"]
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = { index = "pytorch-cpu" }
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cpu"
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
explicit = true
|
|
@ -1,229 +0,0 @@
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile pyproject.toml ../shared/pyproject.toml --output-file requirements.lock
|
||||
aiofiles==23.2.1
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
anyio==4.8.0
|
||||
# via starlette
|
||||
attrs==24.3.0
|
||||
# via
|
||||
# clldutils
|
||||
# csvw
|
||||
# jsonschema
|
||||
# phonemizer
|
||||
# referencing
|
||||
babel==2.16.0
|
||||
# via csvw
|
||||
certifi==2024.12.14
|
||||
# via requests
|
||||
cffi==1.17.1
|
||||
# via soundfile
|
||||
charset-normalizer==3.4.1
|
||||
# via requests
|
||||
click==8.1.8
|
||||
# via
|
||||
# kokoro-fastapi (../shared/pyproject.toml)
|
||||
# uvicorn
|
||||
clldutils==3.21.0
|
||||
# via segments
|
||||
colorama==0.4.6
|
||||
# via
|
||||
# click
|
||||
# colorlog
|
||||
# csvw
|
||||
# loguru
|
||||
# tqdm
|
||||
coloredlogs==15.0.1
|
||||
# via onnxruntime
|
||||
colorlog==6.9.0
|
||||
# via clldutils
|
||||
csvw==3.5.1
|
||||
# via segments
|
||||
dlinfo==1.2.1
|
||||
# via phonemizer
|
||||
exceptiongroup==1.2.2
|
||||
# via anyio
|
||||
fastapi==0.115.6
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
filelock==3.16.1
|
||||
# via
|
||||
# huggingface-hub
|
||||
# torch
|
||||
# transformers
|
||||
flatbuffers==24.12.23
|
||||
# via onnxruntime
|
||||
fsspec==2024.12.0
|
||||
# via
|
||||
# huggingface-hub
|
||||
# torch
|
||||
greenlet==3.1.1
|
||||
# via sqlalchemy
|
||||
h11==0.14.0
|
||||
# via uvicorn
|
||||
huggingface-hub==0.27.1
|
||||
# via
|
||||
# tokenizers
|
||||
# transformers
|
||||
humanfriendly==10.0
|
||||
# via coloredlogs
|
||||
idna==3.10
|
||||
# via
|
||||
# anyio
|
||||
# requests
|
||||
isodate==0.7.2
|
||||
# via
|
||||
# csvw
|
||||
# rdflib
|
||||
jinja2==3.1.5
|
||||
# via torch
|
||||
joblib==1.4.2
|
||||
# via phonemizer
|
||||
jsonschema==4.23.0
|
||||
# via csvw
|
||||
jsonschema-specifications==2024.10.1
|
||||
# via jsonschema
|
||||
language-tags==1.2.0
|
||||
# via csvw
|
||||
loguru==0.7.3
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
lxml==5.3.0
|
||||
# via clldutils
|
||||
markdown==3.7
|
||||
# via clldutils
|
||||
markupsafe==3.0.2
|
||||
# via
|
||||
# clldutils
|
||||
# jinja2
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
munch==4.0.0
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
networkx==3.4.2
|
||||
# via torch
|
||||
numpy==2.2.1
|
||||
# via
|
||||
# kokoro-fastapi (../shared/pyproject.toml)
|
||||
# onnxruntime
|
||||
# scipy
|
||||
# soundfile
|
||||
# transformers
|
||||
onnxruntime==1.20.1
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
packaging==24.2
|
||||
# via
|
||||
# huggingface-hub
|
||||
# onnxruntime
|
||||
# transformers
|
||||
phonemizer==3.3.0
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
protobuf==5.29.3
|
||||
# via onnxruntime
|
||||
pycparser==2.22
|
||||
# via cffi
|
||||
pydantic==2.10.4
|
||||
# via
|
||||
# kokoro-fastapi (../shared/pyproject.toml)
|
||||
# fastapi
|
||||
# pydantic-settings
|
||||
pydantic-core==2.27.2
|
||||
# via pydantic
|
||||
pydantic-settings==2.7.0
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
pylatexenc==2.10
|
||||
# via clldutils
|
||||
pyparsing==3.2.1
|
||||
# via rdflib
|
||||
pyreadline3==3.5.4
|
||||
# via humanfriendly
|
||||
python-dateutil==2.9.0.post0
|
||||
# via
|
||||
# clldutils
|
||||
# csvw
|
||||
python-dotenv==1.0.1
|
||||
# via
|
||||
# kokoro-fastapi (../shared/pyproject.toml)
|
||||
# pydantic-settings
|
||||
pyyaml==6.0.2
|
||||
# via
|
||||
# huggingface-hub
|
||||
# transformers
|
||||
rdflib==7.1.2
|
||||
# via csvw
|
||||
referencing==0.35.1
|
||||
# via
|
||||
# jsonschema
|
||||
# jsonschema-specifications
|
||||
regex==2024.11.6
|
||||
# via
|
||||
# kokoro-fastapi (../shared/pyproject.toml)
|
||||
# segments
|
||||
# tiktoken
|
||||
# transformers
|
||||
requests==2.32.3
|
||||
# via
|
||||
# kokoro-fastapi (../shared/pyproject.toml)
|
||||
# csvw
|
||||
# huggingface-hub
|
||||
# tiktoken
|
||||
# transformers
|
||||
rfc3986==1.5.0
|
||||
# via csvw
|
||||
rpds-py==0.22.3
|
||||
# via
|
||||
# jsonschema
|
||||
# referencing
|
||||
safetensors==0.5.2
|
||||
# via transformers
|
||||
scipy==1.14.1
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
segments==2.2.1
|
||||
# via phonemizer
|
||||
six==1.17.0
|
||||
# via python-dateutil
|
||||
sniffio==1.3.1
|
||||
# via anyio
|
||||
soundfile==0.13.0
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
sqlalchemy==2.0.27
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
starlette==0.41.3
|
||||
# via fastapi
|
||||
sympy==1.13.1
|
||||
# via
|
||||
# onnxruntime
|
||||
# torch
|
||||
tabulate==0.9.0
|
||||
# via clldutils
|
||||
tiktoken==0.8.0
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
tokenizers==0.21.0
|
||||
# via transformers
|
||||
torch==2.5.1+cpu
|
||||
# via kokoro-fastapi-cpu (pyproject.toml)
|
||||
tqdm==4.67.1
|
||||
# via
|
||||
# kokoro-fastapi (../shared/pyproject.toml)
|
||||
# huggingface-hub
|
||||
# transformers
|
||||
transformers==4.47.1
|
||||
# via kokoro-fastapi-cpu (pyproject.toml)
|
||||
typing-extensions==4.12.2
|
||||
# via
|
||||
# anyio
|
||||
# fastapi
|
||||
# huggingface-hub
|
||||
# phonemizer
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# sqlalchemy
|
||||
# torch
|
||||
# uvicorn
|
||||
uritemplate==4.1.1
|
||||
# via csvw
|
||||
urllib3==2.3.0
|
||||
# via requests
|
||||
uvicorn==0.34.0
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
win32-setctime==1.2.0
|
||||
# via loguru
|
1841
docker/cpu/uv.lock
generated
|
@ -1,4 +1,6 @@
|
|||
FROM nvidia/cuda:12.1.0-base-ubuntu22.04
|
||||
FROM --platform=$BUILDPLATFORM nvidia/cuda:12.3.2-cudnn9-runtime-ubuntu22.04
|
||||
# Set non-interactive frontend
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Install Python and other dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
|
@ -19,47 +21,43 @@ COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
|
|||
RUN useradd -m -u 1000 appuser
|
||||
|
||||
# Create directories and set ownership
|
||||
RUN mkdir -p /app/models && \
|
||||
mkdir -p /app/api/src/voices && \
|
||||
RUN mkdir -p /app/api/src/voices && \
|
||||
chown -R appuser:appuser /app
|
||||
|
||||
USER appuser
|
||||
|
||||
# Download and extract models
|
||||
WORKDIR /app/models
|
||||
RUN curl -L -o model.tar.gz https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.0.1/kokoro-82m-pytorch.tar.gz && \
|
||||
tar xzf model.tar.gz && \
|
||||
rm model.tar.gz
|
||||
|
||||
# Download and extract voice models
|
||||
WORKDIR /app/api/src/voices
|
||||
RUN curl -L -o voices.tar.gz https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.0.1/voice-models.tar.gz && \
|
||||
tar xzf voices.tar.gz && \
|
||||
rm voices.tar.gz
|
||||
|
||||
# Switch back to app directory
|
||||
WORKDIR /app
|
||||
|
||||
# Copy dependency files
|
||||
COPY --chown=appuser:appuser pyproject.toml ./pyproject.toml
|
||||
|
||||
# Install dependencies
|
||||
# Install dependencies with GPU extras
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv venv && \
|
||||
uv sync --extra gpu --no-install-project
|
||||
uv sync --extra gpu
|
||||
|
||||
# Copy project files
|
||||
# Copy project files including models
|
||||
COPY --chown=appuser:appuser api ./api
|
||||
COPY --chown=appuser:appuser web ./web
|
||||
COPY --chown=appuser:appuser docker/scripts/download_model.* ./
|
||||
|
||||
# Install project
|
||||
# Install project with GPU extras
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv sync --extra gpu
|
||||
|
||||
COPY --chown=appuser:appuser docker/scripts/ /app/docker/scripts/
|
||||
RUN chmod +x docker/scripts/entrypoint.sh
|
||||
RUN chmod +x docker/scripts/download_model.sh
|
||||
|
||||
# Set environment variables
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
ENV PYTHONPATH=/app:/app/models
|
||||
ENV PYTHONPATH=/app
|
||||
ENV PATH="/app/.venv/bin:$PATH"
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
ENV USE_GPU=true
|
||||
ENV USE_ONNX=false
|
||||
ENV DOWNLOAD_PTH=true
|
||||
ENV DOWNLOAD_ONNX=false
|
||||
|
||||
# Run FastAPI server
|
||||
CMD ["uv", "run", "python", "-m", "uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]
|
||||
CMD ["/app/docker/scripts/entrypoint.sh"]
|
||||
|
|
|
@ -6,12 +6,14 @@ services:
|
|||
context: ../..
|
||||
dockerfile: docker/gpu/Dockerfile
|
||||
volumes:
|
||||
- ../../api/src:/app/api/src # Mount src for development
|
||||
- ../../api/src/voices:/app/api/src/voices # Mount voices for persistence
|
||||
- ../../api:/app/api
|
||||
ports:
|
||||
- "8880:8880"
|
||||
environment:
|
||||
- PYTHONPATH=/app:/app/models
|
||||
- PYTHONPATH=/app:/app/api
|
||||
- USE_GPU=true
|
||||
- USE_ONNX=false
|
||||
- PYTHONUNBUFFERED=1
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
|
@ -20,20 +22,20 @@ services:
|
|||
count: 1
|
||||
capabilities: [gpu]
|
||||
|
||||
# Gradio UI service
|
||||
gradio-ui:
|
||||
image: ghcr.io/remsky/kokoro-fastapi-ui:v0.1.0
|
||||
# Uncomment below to build from source instead of using the released image
|
||||
# build:
|
||||
# context: ../../ui
|
||||
ports:
|
||||
- "7860:7860"
|
||||
volumes:
|
||||
- ../../ui/data:/app/ui/data
|
||||
- ../../ui/app.py:/app/app.py # Mount app.py for hot reload
|
||||
environment:
|
||||
- GRADIO_WATCH=1 # Enable hot reloading
|
||||
- PYTHONUNBUFFERED=1 # Ensure Python output is not buffered
|
||||
- DISABLE_LOCAL_SAVING=false # Set to 'true' to disable local saving and hide file view
|
||||
- API_HOST=kokoro-tts # Set TTS service URL
|
||||
- API_PORT=8880 # Set TTS service PORT
|
||||
# # Gradio UI service
|
||||
# gradio-ui:
|
||||
# image: ghcr.io/remsky/kokoro-fastapi-ui:v0.1.0
|
||||
# # Uncomment below to build from source instead of using the released image
|
||||
# # build:
|
||||
# # context: ../../ui
|
||||
# ports:
|
||||
# - "7860:7860"
|
||||
# volumes:
|
||||
# - ../../ui/data:/app/ui/data
|
||||
# - ../../ui/app.py:/app/app.py # Mount app.py for hot reload
|
||||
# environment:
|
||||
# - GRADIO_WATCH=1 # Enable hot reloading
|
||||
# - PYTHONUNBUFFERED=1 # Ensure Python output is not buffered
|
||||
# - DISABLE_LOCAL_SAVING=false # Set to 'true' to disable local saving and hide file view
|
||||
# - API_HOST=kokoro-tts # Set TTS service URL
|
||||
# - API_PORT=8880 # Set TTS service PORT
|
||||
|
|
|
@ -1,22 +0,0 @@
|
|||
[project]
|
||||
name = "kokoro-fastapi-gpu"
|
||||
version = "0.1.0"
|
||||
description = "FastAPI TTS Service - GPU Version"
|
||||
readme = "../README.md"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
# Core ML/DL for GPU
|
||||
"torch==2.5.1+cu121",
|
||||
"transformers==4.47.1",
|
||||
]
|
||||
|
||||
[tool.uv.workspace]
|
||||
members = ["../shared"]
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = { index = "pytorch-cuda" }
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cuda"
|
||||
url = "https://download.pytorch.org/whl/cu121"
|
||||
explicit = true
|
|
@ -1,229 +0,0 @@
|
|||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile pyproject.toml ../shared/pyproject.toml --output-file requirements.lock
|
||||
aiofiles==23.2.1
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
anyio==4.8.0
|
||||
# via starlette
|
||||
attrs==24.3.0
|
||||
# via
|
||||
# clldutils
|
||||
# csvw
|
||||
# jsonschema
|
||||
# phonemizer
|
||||
# referencing
|
||||
babel==2.16.0
|
||||
# via csvw
|
||||
certifi==2024.12.14
|
||||
# via requests
|
||||
cffi==1.17.1
|
||||
# via soundfile
|
||||
charset-normalizer==3.4.1
|
||||
# via requests
|
||||
click==8.1.8
|
||||
# via
|
||||
# kokoro-fastapi (../shared/pyproject.toml)
|
||||
# uvicorn
|
||||
clldutils==3.21.0
|
||||
# via segments
|
||||
colorama==0.4.6
|
||||
# via
|
||||
# click
|
||||
# colorlog
|
||||
# csvw
|
||||
# loguru
|
||||
# tqdm
|
||||
coloredlogs==15.0.1
|
||||
# via onnxruntime
|
||||
colorlog==6.9.0
|
||||
# via clldutils
|
||||
csvw==3.5.1
|
||||
# via segments
|
||||
dlinfo==1.2.1
|
||||
# via phonemizer
|
||||
exceptiongroup==1.2.2
|
||||
# via anyio
|
||||
fastapi==0.115.6
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
filelock==3.16.1
|
||||
# via
|
||||
# huggingface-hub
|
||||
# torch
|
||||
# transformers
|
||||
flatbuffers==24.12.23
|
||||
# via onnxruntime
|
||||
fsspec==2024.12.0
|
||||
# via
|
||||
# huggingface-hub
|
||||
# torch
|
||||
greenlet==3.1.1
|
||||
# via sqlalchemy
|
||||
h11==0.14.0
|
||||
# via uvicorn
|
||||
huggingface-hub==0.27.1
|
||||
# via
|
||||
# tokenizers
|
||||
# transformers
|
||||
humanfriendly==10.0
|
||||
# via coloredlogs
|
||||
idna==3.10
|
||||
# via
|
||||
# anyio
|
||||
# requests
|
||||
isodate==0.7.2
|
||||
# via
|
||||
# csvw
|
||||
# rdflib
|
||||
jinja2==3.1.5
|
||||
# via torch
|
||||
joblib==1.4.2
|
||||
# via phonemizer
|
||||
jsonschema==4.23.0
|
||||
# via csvw
|
||||
jsonschema-specifications==2024.10.1
|
||||
# via jsonschema
|
||||
language-tags==1.2.0
|
||||
# via csvw
|
||||
loguru==0.7.3
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
lxml==5.3.0
|
||||
# via clldutils
|
||||
markdown==3.7
|
||||
# via clldutils
|
||||
markupsafe==3.0.2
|
||||
# via
|
||||
# clldutils
|
||||
# jinja2
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
munch==4.0.0
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
networkx==3.4.2
|
||||
# via torch
|
||||
numpy==2.2.1
|
||||
# via
|
||||
# kokoro-fastapi (../shared/pyproject.toml)
|
||||
# onnxruntime
|
||||
# scipy
|
||||
# soundfile
|
||||
# transformers
|
||||
onnxruntime==1.20.1
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
packaging==24.2
|
||||
# via
|
||||
# huggingface-hub
|
||||
# onnxruntime
|
||||
# transformers
|
||||
phonemizer==3.3.0
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
protobuf==5.29.3
|
||||
# via onnxruntime
|
||||
pycparser==2.22
|
||||
# via cffi
|
||||
pydantic==2.10.4
|
||||
# via
|
||||
# kokoro-fastapi (../shared/pyproject.toml)
|
||||
# fastapi
|
||||
# pydantic-settings
|
||||
pydantic-core==2.27.2
|
||||
# via pydantic
|
||||
pydantic-settings==2.7.0
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
pylatexenc==2.10
|
||||
# via clldutils
|
||||
pyparsing==3.2.1
|
||||
# via rdflib
|
||||
pyreadline3==3.5.4
|
||||
# via humanfriendly
|
||||
python-dateutil==2.9.0.post0
|
||||
# via
|
||||
# clldutils
|
||||
# csvw
|
||||
python-dotenv==1.0.1
|
||||
# via
|
||||
# kokoro-fastapi (../shared/pyproject.toml)
|
||||
# pydantic-settings
|
||||
pyyaml==6.0.2
|
||||
# via
|
||||
# huggingface-hub
|
||||
# transformers
|
||||
rdflib==7.1.2
|
||||
# via csvw
|
||||
referencing==0.35.1
|
||||
# via
|
||||
# jsonschema
|
||||
# jsonschema-specifications
|
||||
regex==2024.11.6
|
||||
# via
|
||||
# kokoro-fastapi (../shared/pyproject.toml)
|
||||
# segments
|
||||
# tiktoken
|
||||
# transformers
|
||||
requests==2.32.3
|
||||
# via
|
||||
# kokoro-fastapi (../shared/pyproject.toml)
|
||||
# csvw
|
||||
# huggingface-hub
|
||||
# tiktoken
|
||||
# transformers
|
||||
rfc3986==1.5.0
|
||||
# via csvw
|
||||
rpds-py==0.22.3
|
||||
# via
|
||||
# jsonschema
|
||||
# referencing
|
||||
safetensors==0.5.2
|
||||
# via transformers
|
||||
scipy==1.14.1
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
segments==2.2.1
|
||||
# via phonemizer
|
||||
six==1.17.0
|
||||
# via python-dateutil
|
||||
sniffio==1.3.1
|
||||
# via anyio
|
||||
soundfile==0.13.0
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
sqlalchemy==2.0.27
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
starlette==0.41.3
|
||||
# via fastapi
|
||||
sympy==1.13.1
|
||||
# via
|
||||
# onnxruntime
|
||||
# torch
|
||||
tabulate==0.9.0
|
||||
# via clldutils
|
||||
tiktoken==0.8.0
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
tokenizers==0.21.0
|
||||
# via transformers
|
||||
torch==2.5.1+cu121
|
||||
# via kokoro-fastapi-gpu (pyproject.toml)
|
||||
tqdm==4.67.1
|
||||
# via
|
||||
# kokoro-fastapi (../shared/pyproject.toml)
|
||||
# huggingface-hub
|
||||
# transformers
|
||||
transformers==4.47.1
|
||||
# via kokoro-fastapi-gpu (pyproject.toml)
|
||||
typing-extensions==4.12.2
|
||||
# via
|
||||
# anyio
|
||||
# fastapi
|
||||
# huggingface-hub
|
||||
# phonemizer
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# sqlalchemy
|
||||
# torch
|
||||
# uvicorn
|
||||
uritemplate==4.1.1
|
||||
# via csvw
|
||||
urllib3==2.3.0
|
||||
# via requests
|
||||
uvicorn==0.34.0
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
win32-setctime==1.2.0
|
||||
# via loguru
|
1914
docker/gpu/uv.lock
generated
103
docker/scripts/download_model.py
Normal file
|
@ -0,0 +1,103 @@
|
|||
#!/usr/bin/env python3
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import requests
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
def download_file(url: str, output_dir: Path, model_type: str, overwrite: bool) -> bool:
|
||||
"""Download a file from URL to the specified directory.
|
||||
|
||||
Returns:
|
||||
bool: True if download succeeded, False otherwise
|
||||
"""
|
||||
filename = os.path.basename(url)
|
||||
if not filename.endswith(f'.{model_type}'):
|
||||
print(f"Warning: {filename} is not a .{model_type} file", file=sys.stderr)
|
||||
return False
|
||||
|
||||
output_path = output_dir / filename
|
||||
|
||||
if os.path.exists(output_path) and not overwrite:
print(f"{filename} already exists, skipping (pass --overwrite to re-download)")
return True
|
||||
|
||||
print(f"Downloading {filename}...")
|
||||
try:
|
||||
response = requests.get(url, stream=True)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(output_path, 'wb') as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
print(f"Successfully downloaded {filename}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Error downloading {filename}: {e}", file=sys.stderr)
|
||||
return False
|
||||
|
||||
def find_project_root() -> Path:
|
||||
"""Find project root by looking for api directory."""
|
||||
max_steps = 5
|
||||
current = Path(__file__).resolve()
|
||||
for _ in range(max_steps):
|
||||
if (current / 'api').is_dir():
|
||||
return current
|
||||
current = current.parent
|
||||
raise RuntimeError("Could not find project root (no api directory found)")
|
||||
|
||||
def main() -> int:
|
||||
"""Download models to the project.
|
||||
|
||||
Returns:
|
||||
int: Exit code (0 for success, 1 for failure)
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description='Download model files')
|
||||
parser.add_argument('--type', choices=['pth', 'onnx'], required=True,
|
||||
help='Model type to download (pth or onnx)')
|
||||
parser.add_argument('--overwrite', action='store_true', help='Overwrite existing files')
|
||||
parser.add_argument('urls', nargs='*', help='Optional model URLs to download')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
# Find project root and ensure models directory exists
|
||||
project_root = find_project_root()
|
||||
models_dir = project_root / 'api' / 'src' / 'models'
|
||||
print(f"Downloading models to {models_dir}")
|
||||
models_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Default models if no arguments provided
|
||||
default_models = {
|
||||
'pth': [
|
||||
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth",
|
||||
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19-half.pth"
|
||||
],
|
||||
'onnx': [
|
||||
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.onnx",
|
||||
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19_fp16.onnx"
|
||||
]
|
||||
}
|
||||
|
||||
# Use provided models or default
|
||||
models_to_download = args.urls if args.urls else default_models[args.type]
|
||||
|
||||
# Download all models
|
||||
success = True
|
||||
for model_url in models_to_download:
|
||||
if not download_file(model_url, models_dir, args.type, args.overwrite):
|
||||
success = False
|
||||
|
||||
if success:
|
||||
print(f"{args.type.upper()} model download complete!")
|
||||
return 0
|
||||
else:
|
||||
print("Some downloads failed", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
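For reference, running the script manually (outside Docker) would look like `python docker/scripts/download_model.py --type pth` for the default PyTorch checkpoints, or `python docker/scripts/download_model.py --type onnx <url>` to pull a specific release asset; downloaded files land in `api/src/models/`.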
|
110
docker/scripts/download_model.sh
Normal file
|
@ -0,0 +1,110 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Find project root by looking for api directory
|
||||
find_project_root() {
|
||||
local current_dir="$PWD"
|
||||
local max_steps=5
|
||||
local steps=0
|
||||
|
||||
while [ $steps -lt $max_steps ]; do
|
||||
if [ -d "$current_dir/api" ]; then
|
||||
echo "$current_dir"
|
||||
return 0
|
||||
fi
|
||||
current_dir="$(dirname "$current_dir")"
|
||||
((steps++))
|
||||
done
|
||||
|
||||
echo "Error: Could not find project root (no api directory found)" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Function to download a file
|
||||
download_file() {
|
||||
local url="$1"
|
||||
local output_dir="$2"
|
||||
local model_type="$3"
|
||||
local filename=$(basename "$url")
|
||||
|
||||
# Validate file extension
|
||||
if [[ ! "$filename" =~ \.$model_type$ ]]; then
|
||||
echo "Warning: $filename is not a .$model_type file" >&2
|
||||
return 1
|
||||
fi
|
||||
|
||||
echo "Downloading $filename..."
|
||||
if curl -L "$url" -o "$output_dir/$filename"; then
|
||||
echo "Successfully downloaded $filename"
|
||||
return 0
|
||||
else
|
||||
echo "Error downloading $filename" >&2
|
||||
return 1
|
||||
fi
|
||||
}
|
||||
|
||||
# Parse arguments
|
||||
MODEL_TYPE=""
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
--type)
|
||||
MODEL_TYPE="$2"
|
||||
shift 2
|
||||
;;
|
||||
*)
|
||||
# If no flag specified, treat remaining args as model URLs
|
||||
break
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Validate model type
|
||||
if [ "$MODEL_TYPE" != "pth" ] && [ "$MODEL_TYPE" != "onnx" ]; then
|
||||
echo "Error: Must specify model type with --type (pth or onnx)" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Find project root and ensure models directory exists
|
||||
PROJECT_ROOT=$(find_project_root)
|
||||
if [ $? -ne 0 ]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
MODELS_DIR="$PROJECT_ROOT/api/src/models"
|
||||
echo "Downloading models to $MODELS_DIR"
|
||||
mkdir -p "$MODELS_DIR"
|
||||
|
||||
# Default models if no arguments provided
|
||||
if [ "$MODEL_TYPE" = "pth" ]; then
|
||||
DEFAULT_MODELS=(
|
||||
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth"
|
||||
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19-half.pth"
|
||||
)
|
||||
else
|
||||
DEFAULT_MODELS=(
|
||||
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.onnx"
|
||||
"https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19_fp16.onnx"
|
||||
)
|
||||
fi
|
||||
|
||||
# Use provided models or default
|
||||
if [ $# -gt 0 ]; then
|
||||
MODELS=("$@")
|
||||
else
|
||||
MODELS=("${DEFAULT_MODELS[@]}")
|
||||
fi
|
||||
|
||||
# Download all models
|
||||
success=true
|
||||
for model in "${MODELS[@]}"; do
|
||||
if ! download_file "$model" "$MODELS_DIR" "$MODEL_TYPE"; then
|
||||
success=false
|
||||
fi
|
||||
done
|
||||
|
||||
if [ "$success" = true ]; then
|
||||
echo "${MODEL_TYPE^^} model download complete!"
|
||||
exit 0
|
||||
else
|
||||
echo "Some downloads failed" >&2
|
||||
exit 1
|
||||
fi
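The shell variant mirrors the Python script, e.g. `./docker/scripts/download_model.sh --type onnx` fetches the default ONNX releases into `api/src/models/`; any extra arguments are treated as explicit model URLs.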
|
12
docker/scripts/entrypoint.sh
Normal file
|
@ -0,0 +1,12 @@
#!/bin/bash
set -e

if [ "$DOWNLOAD_PTH" = "true" ]; then
    python docker/scripts/download_model.py --type pth
fi

if [ "$DOWNLOAD_ONNX" = "true" ]; then
    python docker/scripts/download_model.py --type onnx
fi

exec uv run python -m uvicorn api.src.main:app --host 0.0.0.0 --port 8880 --log-level debug
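Given the ENV defaults set in the GPU Dockerfile above (`DOWNLOAD_PTH=true`, `DOWNLOAD_ONNX=false`), this entrypoint fetches only the PyTorch weights on first start and then exec's uvicorn; overriding those variables at `docker run` time changes which models are pulled.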
|
167
docs/architecture/streaming_audio_writer_analysis.md
Normal file
|
@ -0,0 +1,167 @@
|
|||
# Streaming Audio Writer Analysis
|
||||
|
||||
This (auto-generated) document provides an in-depth technical analysis of the `StreamingAudioWriter` class, detailing the streaming and non-streaming paths, supported formats, header management, and challenges faced in the implementation.
|
||||
|
||||
## Overview
|
||||
|
||||
The `StreamingAudioWriter` class is designed to handle streaming audio format conversions efficiently. It supports various audio formats and provides methods to write audio data in chunks, finalize the stream, and manage audio headers meticulously to ensure compatibility and integrity of the resulting audio files.
|
||||
|
||||
## Supported Formats
|
||||
|
||||
The class supports the following audio formats:
|
||||
|
||||
- **WAV**
|
||||
- **OGG**
|
||||
- **Opus**
|
||||
- **FLAC**
|
||||
- **MP3**
|
||||
- **AAC**
|
||||
- **PCM**
|
||||
|
||||
## Initialization
|
||||
|
||||
Upon initialization, the class sets up format-specific configurations to prepare for audio data processing:
|
||||
|
||||
- **WAV**:
|
||||
- Writes an initial WAV header with placeholders for file size (`RIFF` chunk) and data size (`data` chunk).
|
||||
- Utilizes the `_write_wav_header` method to generate the header.
|
||||
|
||||
- **OGG/Opus/FLAC**:
|
||||
- Uses `soundfile.SoundFile` to write audio data to a memory buffer (`BytesIO`).
|
||||
- Configures the writer with appropriate format and subtype based on the specified format.
|
||||
|
||||
- **MP3/AAC**:
|
||||
- Utilizes `pydub.AudioSegment` for incremental writing.
|
||||
- Initializes an empty `AudioSegment` as the encoder to accumulate audio data.
|
||||
|
||||
- **PCM**:
|
||||
- Prepares to write raw PCM bytes without additional headers.
|
||||
|
||||
Initialization ensures that each format is correctly configured to handle the specific requirements of streaming and finalizing audio data.
|
||||
|
||||
## Streaming Path
|
||||
|
||||
### Writing Chunks
|
||||
|
||||
The `write_chunk` method handles the incoming audio data, processing it according to the specified format:
|
||||
|
||||
- **WAV**:
|
||||
- **First Chunk**: Writes the initial WAV header to the buffer.
|
||||
- **Subsequent Chunks**: Writes raw PCM data directly after the header.
|
||||
- Updates `bytes_written` to track the total size of audio data written.
|
||||
|
||||
- **OGG/Opus/FLAC**:
|
||||
- Writes audio data to the `soundfile` buffer.
|
||||
- Flushes the writer to ensure data integrity.
|
||||
- Retrieves the current buffer contents and truncates the buffer for the next chunk.
|
||||
|
||||
- **MP3/AAC**:
|
||||
- Converts incoming audio data (`np.ndarray`) to a `pydub.AudioSegment`.
|
||||
- Accumulates segments in the encoder.
|
||||
- Exports the current state to the output buffer without writing duration metadata or XING headers for chunks.
|
||||
- Resets the encoder to prevent memory growth after exporting.
|
||||
|
||||
- **PCM**:
|
||||
- Directly writes raw bytes from the audio data to the output buffer.
|
||||
|
||||
### Finalizing
|
||||
|
||||
Finalizing the audio stream involves ensuring that all audio data is correctly written and that headers are updated to reflect the accurate file and data sizes:
|
||||
|
||||
- **WAV**:
|
||||
- Rewrites the `RIFF` and `data` chunks in the header with the actual file size (`bytes_written + 36`) and data size (`bytes_written`).
|
||||
- Creates a new buffer with the complete WAV file by copying audio data from the original buffer starting at byte 44 (end of the initial header).
|
||||
|
||||
- **OGG/Opus/FLAC**:
|
||||
- Closes the `soundfile` writer to flush all remaining data to the buffer.
|
||||
- Returns the final buffer content, ensuring that all necessary headers and data are correctly written.
|
||||
|
||||
- **MP3/AAC**:
|
||||
- Exports any remaining audio data with proper headers and metadata, including duration and VBR quality for MP3.
|
||||
- Writes ID3v1 and ID3v2 tags for MP3 formats.
|
||||
- Performs final exports to ensure that all audio data is properly encoded and formatted.
|
||||
|
||||
- **PCM**:
|
||||
- No finalization is needed as PCM involves raw data without headers.
|
||||
|
||||
## Non-Streaming Path
|
||||
|
||||
The `StreamingAudioWriter` class is inherently designed for streaming audio data. However, it's essential to understand how it behaves when handling complete files versus streaming data:
|
||||
|
||||
### Full File Writing
|
||||
|
||||
- **Process**:
|
||||
- Accumulate all audio data in memory or buffer.
|
||||
- Write the complete file with accurate headers and data sizes upon finalization.
|
||||
|
||||
- **Advantages**:
|
||||
- Simplifies header management since the total data size is known before writing.
|
||||
- Reduces complexity in data handling and processing.
|
||||
|
||||
- **Disadvantages**:
|
||||
- High memory consumption for large audio files.
|
||||
- Delay in availability of audio data until the entire file is processed.
|
||||
|
||||
### Stream-to-File Writing
|
||||
|
||||
- **Process**:
|
||||
- Incrementally write audio data in chunks.
|
||||
- Update headers and finalize the file dynamically as data flows.
|
||||
|
||||
- **Advantages**:
|
||||
- Lower memory usage as data is processed in smaller chunks.
|
||||
- Immediate availability of audio data, suitable for real-time streaming applications.
|
||||
|
||||
- **Disadvantages**:
|
||||
- Complex header management to accommodate dynamic data sizes.
|
||||
- Increased likelihood of header synchronization issues, leading to potential file corruption.
|
||||
|
||||
**Challenges**:
|
||||
- Balancing memory usage with processing speed.
|
||||
- Ensuring consistent and accurate header updates during streaming operations.
|
||||
|
||||
## Header Management
|
||||
|
||||
### WAV Headers
|
||||
|
||||
WAV files utilize `RIFF` headers to describe file structure:
|
||||
|
||||
- **Initial Header**:
|
||||
- Contains placeholders for file size and data size (`struct.pack('<L', 0)`).
|
||||
|
||||
- **Final Header**:
|
||||
- Calculates and writes the actual file size (`bytes_written + 36`) and data size (`bytes_written`).
|
||||
- Ensures that audio players can correctly interpret the file by having accurate header information.
|
||||
|
||||
**Technical Details**:
|
||||
- The `_write_wav_header` method initializes the WAV header with placeholders.
|
||||
- Upon finalization, the `write_chunk` method creates a new buffer, writes the correct sizes, and appends the audio data from the original buffer starting at byte 44 (end of the initial header).
|
||||
|
||||
**Challenges**:
|
||||
- Maintaining synchronization between audio data size and header placeholders.
|
||||
- Ensuring that the header is correctly rewritten upon finalization to prevent file corruption.
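As a rough illustration of this placeholder-then-patch scheme (a standalone sketch, not the class's actual code, assuming 16-bit mono PCM at 24 kHz):

```python
import struct
from io import BytesIO

def make_wav_header(sample_rate: int = 24000, channels: int = 1, bits: int = 16) -> bytes:
    """44-byte RIFF/WAVE header with zeroed size fields, to be patched on finalize."""
    byte_rate = sample_rate * channels * bits // 8
    block_align = channels * bits // 8
    buf = BytesIO()
    buf.write(b"RIFF")
    buf.write(struct.pack("<L", 0))            # placeholder: total file size - 8
    buf.write(b"WAVEfmt ")
    buf.write(struct.pack("<LHHLLHH", 16, 1, channels, sample_rate,
                          byte_rate, block_align, bits))
    buf.write(b"data")
    buf.write(struct.pack("<L", 0))            # placeholder: data size
    return buf.getvalue()

def finalize_wav(pcm: bytes) -> bytes:
    """Patch the RIFF and data chunk sizes once the PCM byte count is known."""
    header = bytearray(make_wav_header())
    header[4:8] = struct.pack("<L", len(pcm) + 36)   # bytes_written + 36
    header[40:44] = struct.pack("<L", len(pcm))      # bytes_written
    return bytes(header) + pcm
```

A player that reads only the first 44 bytes of the streamed version sees zeroed sizes, which is exactly why the finalize step described above has to rewrite them.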
|
||||
|
||||
### MP3/AAC Headers
|
||||
|
||||
MP3 and AAC formats require proper metadata and headers to ensure compatibility:
|
||||
|
||||
- **XING Headers (MP3)**:
|
||||
- Essential for Variable Bit Rate (VBR) audio files.
|
||||
- Control the quality and indexing of the MP3 file.
|
||||
|
||||
- **ID3 Tags (MP3)**:
|
||||
- Provide metadata such as artist, title, and album information.
|
||||
|
||||
- **ADTS Headers (AAC)**:
|
||||
- Describe the AAC frame headers necessary for decoding.
|
||||
|
||||
**Technical Details**:
|
||||
- During finalization, the `write_chunk` method for MP3/AAC formats includes:
|
||||
- Duration metadata (`-metadata duration`).
|
||||
- VBR headers for MP3 (`-write_vbr`, `-vbr_quality`).
|
||||
- ID3 tags for MP3 (`-write_id3v1`, `-write_id3v2`).
|
||||
- Ensures that all remaining audio data is correctly encoded and formatted with the necessary headers.
|
||||
|
||||
**Challenges**:
|
||||
- Ensuring that metadata is accurately written during the finalization process.
|
||||
- Managing VBR headers to maintain audio quality and file integrity.
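For orientation, a stripped-down sketch of the pydub-based chunk path described above (assumed parameter values; the real writer also accumulates segments and emits the XING/ID3 metadata on finalize):

```python
from io import BytesIO

import numpy as np
from pydub import AudioSegment

def encode_chunk_mp3(samples: np.ndarray, sample_rate: int = 24000) -> bytes:
    """Encode one mono int16 chunk to MP3 bytes via pydub (ffmpeg under the hood)."""
    segment = AudioSegment(
        data=samples.astype(np.int16).tobytes(),
        frame_rate=sample_rate,
        sample_width=2,   # bytes per sample
        channels=1,
    )
    out = BytesIO()
    segment.export(out, format="mp3", bitrate="192k")
    return out.getvalue()
```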
|
|
@ -68,7 +68,7 @@ def main():
|
|||
# Initialize system monitor
|
||||
monitor = SystemMonitor(interval=1.0) # 1 second interval
|
||||
# Set prefix for output files (e.g. "gpu", "cpu", "onnx", etc.)
|
||||
prefix = "cpu"
|
||||
prefix = "gpu"
|
||||
# Generate token sizes
|
||||
if "gpu" in prefix:
|
||||
token_sizes = generate_token_sizes(
|
||||
|
|
|
@ -1,23 +1,23 @@
|
|||
=== Benchmark Statistics (with correct RTF) ===
|
||||
|
||||
Total tokens processed: 1800
|
||||
Total audio generated (s): 568.53
|
||||
Total test duration (s): 306.02
|
||||
Average processing rate (tokens/s): 5.75
|
||||
Average RTF: 0.55
|
||||
Average Real Time Speed: 1.81
|
||||
Total tokens processed: 1500
|
||||
Total audio generated (s): 427.90
|
||||
Total test duration (s): 10.84
|
||||
Average processing rate (tokens/s): 133.35
|
||||
Average RTF: 0.02
|
||||
Average Real Time Speed: 41.67
|
||||
|
||||
=== Per-chunk Stats ===
|
||||
|
||||
Average chunk size (tokens): 600.00
|
||||
Min chunk size (tokens): 300
|
||||
Max chunk size (tokens): 900
|
||||
Average processing time (s): 101.89
|
||||
Average output length (s): 189.51
|
||||
Average chunk size (tokens): 300.00
|
||||
Min chunk size (tokens): 100
|
||||
Max chunk size (tokens): 500
|
||||
Average processing time (s): 2.13
|
||||
Average output length (s): 85.58
|
||||
|
||||
=== Performance Ranges ===
|
||||
|
||||
Processing rate range (tokens/s): 5.30 - 6.26
|
||||
RTF range: 0.51x - 0.59x
|
||||
Real Time Speed range: 1.69x - 1.96x
|
||||
Processing rate range (tokens/s): 102.04 - 159.74
|
||||
RTF range: 0.02x - 0.03x
|
||||
Real Time Speed range: 33.33x - 50.00x
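As a sanity check on these figures (RTF being processing time over generated audio duration), the aggregate numbers for the new run land in the same range as the per-chunk averages reported above:

```python
processing_s = 10.84   # total test duration (s) from the stats above
audio_s = 427.90       # total audio generated (s)
rtf = processing_s / audio_s
print(f"aggregate RTF ≈ {rtf:.3f}, ≈ {1 / rtf:.0f}x real time")
# aggregate RTF ≈ 0.025, ≈ 39x real time (the file's 0.02 / 41.67x appear to be per-chunk averages)
```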
|
||||
|
||||
|
|
|
@ -2,616 +2,240 @@
|
|||
"results": [
|
||||
{
|
||||
"tokens": 150,
|
||||
"processing_time": 2.36,
|
||||
"output_length": 45.9,
|
||||
"rtf": 0.05,
|
||||
"elapsed_time": 2.44626
|
||||
"processing_time": 1.18,
|
||||
"output_length": 43.7,
|
||||
"rtf": 0.03,
|
||||
"elapsed_time": 1.20302
|
||||
},
|
||||
{
|
||||
"tokens": 300,
|
||||
"processing_time": 4.94,
|
||||
"output_length": 96.425,
|
||||
"rtf": 0.05,
|
||||
"elapsed_time": 7.46073
|
||||
"processing_time": 2.27,
|
||||
"output_length": 86.75,
|
||||
"rtf": 0.03,
|
||||
"elapsed_time": 3.49958
|
||||
},
|
||||
{
|
||||
"tokens": 450,
|
||||
"processing_time": 8.94,
|
||||
"output_length": 143.1,
|
||||
"rtf": 0.06,
|
||||
"elapsed_time": 16.55036
|
||||
"processing_time": 3.49,
|
||||
"output_length": 125.9,
|
||||
"rtf": 0.03,
|
||||
"elapsed_time": 7.03862
|
||||
},
|
||||
{
|
||||
"tokens": 600,
|
||||
"processing_time": 19.78,
|
||||
"output_length": 188.675,
|
||||
"rtf": 0.1,
|
||||
"elapsed_time": 36.69352
|
||||
"processing_time": 4.64,
|
||||
"output_length": 169.325,
|
||||
"rtf": 0.03,
|
||||
"elapsed_time": 11.71062
|
||||
},
|
||||
{
|
||||
"tokens": 750,
|
||||
"processing_time": 19.89,
|
||||
"output_length": 236.7,
|
||||
"rtf": 0.08,
|
||||
"elapsed_time": 56.77695
|
||||
"processing_time": 5.07,
|
||||
"output_length": 212.3,
|
||||
"rtf": 0.02,
|
||||
"elapsed_time": 16.83186
|
||||
},
|
||||
{
|
||||
"tokens": 900,
|
||||
"processing_time": 16.83,
|
||||
"output_length": 283.425,
|
||||
"rtf": 0.06,
|
||||
"elapsed_time": 73.8079
|
||||
"processing_time": 6.66,
|
||||
"output_length": 258.0,
|
||||
"rtf": 0.03,
|
||||
"elapsed_time": 23.54135
|
||||
}
|
||||
],
|
||||
"system_metrics": [
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:20.888295",
|
||||
"cpu_percent": 36.92,
|
||||
"ram_percent": 68.6,
|
||||
"ram_used_gb": 43.6395263671875,
|
||||
"gpu_memory_used": 7022.0,
|
||||
"relative_time": 0.09646010398864746
|
||||
"timestamp": "2025-01-30T05:06:38.733338",
|
||||
"cpu_percent": 0.0,
|
||||
"ram_percent": 18.6,
|
||||
"ram_used_gb": 5.284908294677734,
|
||||
"gpu_memory_used": 1925.0,
|
||||
"relative_time": 0.039948463439941406
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:21.983741",
|
||||
"cpu_percent": 22.29,
|
||||
"ram_percent": 68.6,
|
||||
"ram_used_gb": 43.642677307128906,
|
||||
"gpu_memory_used": 7021.0,
|
||||
"relative_time": 1.1906661987304688
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:23.078293",
|
||||
"cpu_percent": 27.39,
|
||||
"ram_percent": 68.6,
|
||||
"ram_used_gb": 43.61421203613281,
|
||||
"gpu_memory_used": 7190.0,
|
||||
"relative_time": 2.264479160308838
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:24.151445",
|
||||
"cpu_percent": 20.28,
|
||||
"ram_percent": 68.6,
|
||||
"ram_used_gb": 43.65406036376953,
|
||||
"gpu_memory_used": 7193.0,
|
||||
"relative_time": 3.349093198776245
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:25.237021",
|
||||
"cpu_percent": 23.03,
|
||||
"ram_percent": 68.6,
|
||||
"ram_used_gb": 43.647274017333984,
|
||||
"gpu_memory_used": 7191.0,
|
||||
"relative_time": 4.413560628890991
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:26.300255",
|
||||
"cpu_percent": 23.62,
|
||||
"ram_percent": 68.6,
|
||||
"ram_used_gb": 43.642295837402344,
|
||||
"gpu_memory_used": 7185.0,
|
||||
"relative_time": 5.484973430633545
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:27.377319",
|
||||
"cpu_percent": 46.04,
|
||||
"ram_percent": 68.7,
|
||||
"ram_used_gb": 43.7291374206543,
|
||||
"gpu_memory_used": 7178.0,
|
||||
"relative_time": 6.658120632171631
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:28.546053",
|
||||
"cpu_percent": 29.79,
|
||||
"ram_percent": 68.7,
|
||||
"ram_used_gb": 43.73202133178711,
|
||||
"gpu_memory_used": 7177.0,
|
||||
"relative_time": 7.725939035415649
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:29.613327",
|
||||
"cpu_percent": 18.19,
|
||||
"ram_percent": 68.8,
|
||||
"ram_used_gb": 43.791343688964844,
|
||||
"gpu_memory_used": 7177.0,
|
||||
"relative_time": 8.800285577774048
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:30.689097",
|
||||
"cpu_percent": 22.29,
|
||||
"ram_percent": 68.9,
|
||||
"ram_used_gb": 43.81514358520508,
|
||||
"gpu_memory_used": 7176.0,
|
||||
"relative_time": 9.899119853973389
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:31.786443",
|
||||
"cpu_percent": 32.59,
|
||||
"ram_percent": 68.9,
|
||||
"ram_used_gb": 43.834510803222656,
|
||||
"gpu_memory_used": 7189.0,
|
||||
"relative_time": 11.042734384536743
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:32.929720",
|
||||
"cpu_percent": 42.48,
|
||||
"ram_percent": 68.8,
|
||||
"ram_used_gb": 43.77507019042969,
|
||||
"gpu_memory_used": 7192.0,
|
||||
"relative_time": 12.117269277572632
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:34.004481",
|
||||
"cpu_percent": 26.33,
|
||||
"ram_percent": 68.8,
|
||||
"ram_used_gb": 43.77891159057617,
|
||||
"gpu_memory_used": 7192.0,
|
||||
"relative_time": 13.19830870628357
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:35.086024",
|
||||
"cpu_percent": 26.53,
|
||||
"ram_percent": 68.8,
|
||||
"ram_used_gb": 43.77515411376953,
|
||||
"gpu_memory_used": 7192.0,
|
||||
"relative_time": 14.29457426071167
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:36.183496",
|
||||
"cpu_percent": 40.33,
|
||||
"ram_percent": 68.9,
|
||||
"ram_used_gb": 43.81095886230469,
|
||||
"gpu_memory_used": 7192.0,
|
||||
"relative_time": 15.402768850326538
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:37.290635",
|
||||
"cpu_percent": 43.6,
|
||||
"ram_percent": 69.0,
|
||||
"ram_used_gb": 43.87236022949219,
|
||||
"gpu_memory_used": 7190.0,
|
||||
"relative_time": 16.574281930923462
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:38.462164",
|
||||
"cpu_percent": 85.74,
|
||||
"ram_percent": 69.0,
|
||||
"ram_used_gb": 43.864280700683594,
|
||||
"gpu_memory_used": 6953.0,
|
||||
"relative_time": 17.66074824333191
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:39.548295",
|
||||
"cpu_percent": 23.88,
|
||||
"ram_percent": 68.8,
|
||||
"ram_used_gb": 43.75236129760742,
|
||||
"gpu_memory_used": 4722.0,
|
||||
"relative_time": 18.739423036575317
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:40.626692",
|
||||
"cpu_percent": 59.24,
|
||||
"ram_percent": 68.7,
|
||||
"ram_used_gb": 43.720741271972656,
|
||||
"gpu_memory_used": 4723.0,
|
||||
"relative_time": 19.846031665802002
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:41.733597",
|
||||
"cpu_percent": 41.74,
|
||||
"ram_percent": 68.4,
|
||||
"ram_used_gb": 43.53546142578125,
|
||||
"gpu_memory_used": 4722.0,
|
||||
"relative_time": 20.920310020446777
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:42.808191",
|
||||
"cpu_percent": 35.43,
|
||||
"ram_percent": 68.3,
|
||||
"ram_used_gb": 43.424468994140625,
|
||||
"gpu_memory_used": 4726.0,
|
||||
"relative_time": 22.00457763671875
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:43.891669",
|
||||
"cpu_percent": 43.81,
|
||||
"ram_percent": 68.2,
|
||||
"ram_used_gb": 43.38311004638672,
|
||||
"gpu_memory_used": 4727.0,
|
||||
"relative_time": 23.08402943611145
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:44.971246",
|
||||
"cpu_percent": 58.13,
|
||||
"ram_percent": 68.0,
|
||||
"ram_used_gb": 43.27970886230469,
|
||||
"gpu_memory_used": 4731.0,
|
||||
"relative_time": 24.249765396118164
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:46.137626",
|
||||
"cpu_percent": 66.76,
|
||||
"ram_percent": 68.0,
|
||||
"ram_used_gb": 43.23844528198242,
|
||||
"gpu_memory_used": 4731.0,
|
||||
"relative_time": 25.32853865623474
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:47.219723",
|
||||
"cpu_percent": 27.95,
|
||||
"ram_percent": 67.8,
|
||||
"ram_used_gb": 43.106136322021484,
|
||||
"gpu_memory_used": 4734.0,
|
||||
"relative_time": 26.499221563339233
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:48.386913",
|
||||
"cpu_percent": 73.13,
|
||||
"ram_percent": 67.7,
|
||||
"ram_used_gb": 43.049781799316406,
|
||||
"gpu_memory_used": 4736.0,
|
||||
"relative_time": 27.592528104782104
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:49.480407",
|
||||
"cpu_percent": 50.63,
|
||||
"ram_percent": 67.6,
|
||||
"ram_used_gb": 43.007415771484375,
|
||||
"gpu_memory_used": 4736.0,
|
||||
"relative_time": 28.711266040802002
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:50.599220",
|
||||
"cpu_percent": 92.36,
|
||||
"ram_percent": 67.5,
|
||||
"ram_used_gb": 42.9685173034668,
|
||||
"gpu_memory_used": 4728.0,
|
||||
"relative_time": 29.916289567947388
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:51.803667",
|
||||
"cpu_percent": 83.07,
|
||||
"ram_percent": 67.5,
|
||||
"ram_used_gb": 42.96232986450195,
|
||||
"gpu_memory_used": 4724.0,
|
||||
"relative_time": 31.039498805999756
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:52.927208",
|
||||
"cpu_percent": 90.61,
|
||||
"ram_percent": 67.5,
|
||||
"ram_used_gb": 42.96202850341797,
|
||||
"gpu_memory_used": 5037.0,
|
||||
"relative_time": 32.2381911277771
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:54.128135",
|
||||
"cpu_percent": 89.47,
|
||||
"ram_percent": 67.5,
|
||||
"ram_used_gb": 42.94692611694336,
|
||||
"gpu_memory_used": 5085.0,
|
||||
"relative_time": 33.35147500038147
|
||||
{
|
||||
"timestamp": "2025-01-30T05:06:39.774003",
|
||||
"cpu_percent": 13.37,
|
||||
"ram_percent": 18.6,
|
||||
"ram_used_gb": 5.2852630615234375,
|
||||
"gpu_memory_used": 3047.0,
|
||||
"relative_time": 1.0883615016937256
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:55.238967",
|
||||
"cpu_percent": 60.01,
|
||||
"ram_percent": 67.4,
|
||||
"ram_used_gb": 42.88222122192383,
|
||||
"gpu_memory_used": 5085.0,
|
||||
"relative_time": 34.455963373184204
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:56.344164",
|
||||
"cpu_percent": 62.12,
|
||||
"ram_percent": 67.3,
|
||||
"ram_used_gb": 42.81411361694336,
|
||||
"gpu_memory_used": 5083.0,
|
||||
"relative_time": 35.549962282180786
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:57.437566",
|
||||
"cpu_percent": 53.56,
|
||||
"ram_percent": 67.3,
|
||||
"ram_used_gb": 42.83011245727539,
|
||||
"gpu_memory_used": 5078.0,
|
||||
"relative_time": 36.66783380508423
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:58.554923",
|
||||
"cpu_percent": 80.27,
|
||||
"ram_percent": 67.3,
|
||||
"ram_used_gb": 42.79304504394531,
|
||||
"gpu_memory_used": 5069.0,
|
||||
"relative_time": 37.77330660820007
|
||||
{
|
||||
"timestamp": "2025-01-30T05:06:40.822449",
|
||||
"cpu_percent": 13.68,
|
||||
"ram_percent": 18.7,
|
||||
"ram_used_gb": 5.303462982177734,
|
||||
"gpu_memory_used": 3040.0,
|
||||
"relative_time": 2.12058687210083
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:43:59.660456",
|
||||
"cpu_percent": 72.33,
|
||||
"ram_percent": 67.2,
|
||||
"ram_used_gb": 42.727474212646484,
|
||||
"gpu_memory_used": 5079.0,
|
||||
"relative_time": 38.885955810546875
|
||||
{
|
||||
"timestamp": "2025-01-30T05:06:41.854375",
|
||||
"cpu_percent": 15.39,
|
||||
"ram_percent": 18.7,
|
||||
"ram_used_gb": 5.306262969970703,
|
||||
"gpu_memory_used": 3326.0,
|
||||
"relative_time": 3.166278600692749
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:00.773867",
|
||||
"cpu_percent": 59.29,
|
||||
"ram_percent": 66.9,
|
||||
"ram_used_gb": 42.566131591796875,
|
||||
"gpu_memory_used": 5079.0,
|
||||
"relative_time": 39.99704432487488
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:01.884399",
|
||||
"cpu_percent": 43.52,
|
||||
"ram_percent": 66.5,
|
||||
"ram_used_gb": 42.32980728149414,
|
||||
"gpu_memory_used": 5079.0,
|
||||
"relative_time": 41.13008522987366
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:03.018905",
|
||||
"cpu_percent": 84.46,
|
||||
"ram_percent": 66.5,
|
||||
"ram_used_gb": 42.28911590576172,
|
||||
"gpu_memory_used": 5087.0,
|
||||
"relative_time": 42.296770095825195
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:04.184606",
|
||||
"cpu_percent": 88.27,
|
||||
"ram_percent": 66.3,
|
||||
"ram_used_gb": 42.16263961791992,
|
||||
"gpu_memory_used": 5091.0,
|
||||
"relative_time": 43.42832589149475
|
||||
{
|
||||
"timestamp": "2025-01-30T05:06:42.900882",
|
||||
"cpu_percent": 14.19,
|
||||
"ram_percent": 18.8,
|
||||
"ram_used_gb": 5.337162017822266,
|
||||
"gpu_memory_used": 2530.0,
|
||||
"relative_time": 4.256956577301025
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:05.315967",
|
||||
"cpu_percent": 80.91,
|
||||
"ram_percent": 65.9,
|
||||
"ram_used_gb": 41.9491081237793,
|
||||
"gpu_memory_used": 5089.0,
|
||||
"relative_time": 44.52496290206909
|
||||
{
|
||||
"timestamp": "2025-01-30T05:06:43.990792",
|
||||
"cpu_percent": 12.63,
|
||||
"ram_percent": 18.8,
|
||||
"ram_used_gb": 5.333805084228516,
|
||||
"gpu_memory_used": 3331.0,
|
||||
"relative_time": 5.2854602336883545
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:06.412298",
|
||||
"cpu_percent": 41.68,
|
||||
"ram_percent": 65.6,
|
||||
"ram_used_gb": 41.72716522216797,
|
||||
"gpu_memory_used": 5090.0,
|
||||
"relative_time": 45.679444313049316
|
||||
{
|
||||
"timestamp": "2025-01-30T05:06:45.019134",
|
||||
"cpu_percent": 14.14,
|
||||
"ram_percent": 18.8,
|
||||
"ram_used_gb": 5.334297180175781,
|
||||
"gpu_memory_used": 3332.0,
|
||||
"relative_time": 6.351738929748535
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:07.566964",
|
||||
"cpu_percent": 73.02,
|
||||
"ram_percent": 65.5,
|
||||
"ram_used_gb": 41.64710998535156,
|
||||
"gpu_memory_used": 5091.0,
|
||||
"relative_time": 46.81710481643677
|
||||
{
|
||||
"timestamp": "2025-01-30T05:06:46.085997",
|
||||
"cpu_percent": 12.78,
|
||||
"ram_percent": 18.8,
|
||||
"ram_used_gb": 5.351467132568359,
|
||||
"gpu_memory_used": 2596.0,
|
||||
"relative_time": 7.392607688903809
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:08.704786",
|
||||
"cpu_percent": 75.38,
|
||||
"ram_percent": 65.4,
|
||||
"ram_used_gb": 41.59475326538086,
|
||||
"gpu_memory_used": 5097.0,
|
||||
"relative_time": 47.91444158554077
|
||||
{
|
||||
"timestamp": "2025-01-30T05:06:47.127113",
|
||||
"cpu_percent": 14.7,
|
||||
"ram_percent": 18.9,
|
||||
"ram_used_gb": 5.367542266845703,
|
||||
"gpu_memory_used": 3341.0,
|
||||
"relative_time": 8.441826343536377
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:09.802745",
|
||||
"cpu_percent": 42.21,
|
||||
"ram_percent": 65.2,
|
||||
"ram_used_gb": 41.45526885986328,
|
||||
"gpu_memory_used": 5111.0,
|
||||
"relative_time": 49.04095649719238
|
||||
{
|
||||
"timestamp": "2025-01-30T05:06:48.176033",
|
||||
"cpu_percent": 13.47,
|
||||
"ram_percent": 18.9,
|
||||
"ram_used_gb": 5.361263275146484,
|
||||
"gpu_memory_used": 3339.0,
|
||||
"relative_time": 9.500520706176758
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:10.928231",
|
||||
"cpu_percent": 65.65,
|
||||
"ram_percent": 64.4,
|
||||
"ram_used_gb": 40.93437957763672,
|
||||
"gpu_memory_used": 5111.0,
|
||||
"relative_time": 50.14311861991882
|
||||
{
|
||||
"timestamp": "2025-01-30T05:06:49.234332",
|
||||
"cpu_percent": 15.84,
|
||||
"ram_percent": 18.9,
|
||||
"ram_used_gb": 5.3612213134765625,
|
||||
"gpu_memory_used": 3339.0,
|
||||
"relative_time": 10.53744649887085
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-30T05:06:50.271159",
|
||||
"cpu_percent": 14.89,
|
||||
"ram_percent": 18.9,
|
||||
"ram_used_gb": 5.379688262939453,
|
||||
"gpu_memory_used": 3646.0,
|
||||
"relative_time": 11.570110321044922
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-30T05:06:51.303841",
|
||||
"cpu_percent": 15.71,
|
||||
"ram_percent": 19.0,
|
||||
"ram_used_gb": 5.390773773193359,
|
||||
"gpu_memory_used": 3037.0,
|
||||
"relative_time": 12.60651707649231
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:12.036249",
|
||||
"cpu_percent": 28.51,
|
||||
"ram_percent": 64.1,
|
||||
"ram_used_gb": 40.749881744384766,
|
||||
"gpu_memory_used": 5107.0,
|
||||
"relative_time": 51.250269651412964
|
||||
},
|
||||
"timestamp": "2025-01-30T05:06:52.340383",
|
||||
"cpu_percent": 15.46,
|
||||
"ram_percent": 19.0,
|
||||
"ram_used_gb": 5.389518737792969,
|
||||
"gpu_memory_used": 3319.0,
|
||||
"relative_time": 13.636165380477905
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:13.137586",
|
||||
"cpu_percent": 52.99,
|
||||
"ram_percent": 64.2,
|
||||
"ram_used_gb": 40.84278869628906,
|
||||
"gpu_memory_used": 5104.0,
|
||||
"relative_time": 52.34805965423584
|
||||
},
|
||||
"timestamp": "2025-01-30T05:06:53.370342",
|
||||
"cpu_percent": 13.12,
|
||||
"ram_percent": 19.0,
|
||||
"ram_used_gb": 5.391136169433594,
|
||||
"gpu_memory_used": 3320.0,
|
||||
"relative_time": 14.67578935623169
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:14.235248",
|
||||
"cpu_percent": 34.55,
|
||||
"ram_percent": 64.1,
|
||||
"ram_used_gb": 40.7873420715332,
|
||||
"gpu_memory_used": 5097.0,
|
||||
"relative_time": 53.424301862716675
|
||||
},
|
||||
"timestamp": "2025-01-30T05:06:54.376175",
|
||||
"cpu_percent": 14.98,
|
||||
"ram_percent": 19.0,
|
||||
"ram_used_gb": 5.390045166015625,
|
||||
"gpu_memory_used": 3627.0,
|
||||
"relative_time": 15.70747685432434
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:15.311386",
|
||||
"cpu_percent": 39.07,
|
||||
"ram_percent": 64.2,
|
||||
"ram_used_gb": 40.860008239746094,
|
||||
"gpu_memory_used": 5091.0,
|
||||
"relative_time": 54.50679922103882
|
||||
},
|
||||
"timestamp": "2025-01-30T05:06:55.441172",
|
||||
"cpu_percent": 13.45,
|
||||
"ram_percent": 19.0,
|
||||
"ram_used_gb": 5.394947052001953,
|
||||
"gpu_memory_used": 1937.0,
|
||||
"relative_time": 16.758784770965576
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:16.393626",
|
||||
"cpu_percent": 31.02,
|
||||
"ram_percent": 64.3,
|
||||
"ram_used_gb": 40.884307861328125,
|
||||
"gpu_memory_used": 5093.0,
|
||||
"relative_time": 55.57431173324585
|
||||
},
|
||||
"timestamp": "2025-01-30T05:06:56.492442",
|
||||
"cpu_percent": 17.03,
|
||||
"ram_percent": 18.9,
|
||||
"ram_used_gb": 5.361682891845703,
|
||||
"gpu_memory_used": 3041.0,
|
||||
"relative_time": 17.789713144302368
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:17.461449",
|
||||
"cpu_percent": 24.53,
|
||||
"ram_percent": 64.3,
|
||||
"ram_used_gb": 40.89955520629883,
|
||||
"gpu_memory_used": 5070.0,
|
||||
"relative_time": 56.660638093948364
|
||||
},
|
||||
"timestamp": "2025-01-30T05:06:57.523536",
|
||||
"cpu_percent": 13.76,
|
||||
"ram_percent": 18.9,
|
||||
"ram_used_gb": 5.360996246337891,
|
||||
"gpu_memory_used": 3321.0,
|
||||
"relative_time": 18.838542222976685
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:18.547558",
|
||||
"cpu_percent": 19.93,
|
||||
"ram_percent": 64.3,
|
||||
"ram_used_gb": 40.92641830444336,
|
||||
"gpu_memory_used": 5074.0,
|
||||
"relative_time": 57.736456871032715
|
||||
},
|
||||
"timestamp": "2025-01-30T05:06:58.572158",
|
||||
"cpu_percent": 15.94,
|
||||
"ram_percent": 18.9,
|
||||
"ram_used_gb": 5.3652801513671875,
|
||||
"gpu_memory_used": 3323.0,
|
||||
"relative_time": 19.86689043045044
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:19.624478",
|
||||
"cpu_percent": 15.63,
|
||||
"ram_percent": 64.3,
|
||||
"ram_used_gb": 40.92564392089844,
|
||||
"gpu_memory_used": 5082.0,
|
||||
"relative_time": 58.81701683998108
|
||||
},
|
||||
"timestamp": "2025-01-30T05:06:59.600551",
|
||||
"cpu_percent": 15.67,
|
||||
"ram_percent": 18.9,
|
||||
"ram_used_gb": 5.363399505615234,
|
||||
"gpu_memory_used": 3630.0,
|
||||
"relative_time": 20.89712619781494
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:20.705184",
|
||||
"cpu_percent": 29.86,
|
||||
"ram_percent": 64.4,
|
||||
"ram_used_gb": 40.935394287109375,
|
||||
"gpu_memory_used": 5082.0,
|
||||
"relative_time": 59.88701677322388
|
||||
},
|
||||
"timestamp": "2025-01-30T05:07:00.631315",
|
||||
"cpu_percent": 15.37,
|
||||
"ram_percent": 18.9,
|
||||
"ram_used_gb": 5.3663482666015625,
|
||||
"gpu_memory_used": 3629.0,
|
||||
"relative_time": 22.01374316215515
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:21.775463",
|
||||
"cpu_percent": 43.55,
|
||||
"ram_percent": 64.4,
|
||||
"ram_used_gb": 40.9350471496582,
|
||||
"gpu_memory_used": 5080.0,
|
||||
"relative_time": 60.96005439758301
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:22.847939",
|
||||
"cpu_percent": 26.66,
|
||||
"ram_percent": 64.4,
|
||||
"ram_used_gb": 40.94179916381836,
|
||||
"gpu_memory_used": 5076.0,
|
||||
"relative_time": 62.02673673629761
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:23.914337",
|
||||
"cpu_percent": 22.46,
|
||||
"ram_percent": 64.4,
|
||||
"ram_used_gb": 40.9537467956543,
|
||||
"gpu_memory_used": 5076.0,
|
||||
"relative_time": 63.10581707954407
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:24.993313",
|
||||
"cpu_percent": 28.07,
|
||||
"ram_percent": 64.4,
|
||||
"ram_used_gb": 40.94577407836914,
|
||||
"gpu_memory_used": 5076.0,
|
||||
"relative_time": 64.18998432159424
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:26.077028",
|
||||
"cpu_percent": 26.1,
|
||||
"ram_percent": 64.4,
|
||||
"ram_used_gb": 40.98012161254883,
|
||||
"gpu_memory_used": 5197.0,
|
||||
"relative_time": 65.28782486915588
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:27.175228",
|
||||
"cpu_percent": 35.17,
|
||||
"ram_percent": 64.6,
|
||||
"ram_used_gb": 41.0831184387207,
|
||||
"gpu_memory_used": 5422.0,
|
||||
"relative_time": 66.37566781044006
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:28.265025",
|
||||
"cpu_percent": 55.14,
|
||||
"ram_percent": 64.9,
|
||||
"ram_used_gb": 41.25740432739258,
|
||||
"gpu_memory_used": 5512.0,
|
||||
"relative_time": 67.48023676872253
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:29.367776",
|
||||
"cpu_percent": 53.84,
|
||||
"ram_percent": 65.0,
|
||||
"ram_used_gb": 41.36682891845703,
|
||||
"gpu_memory_used": 5616.0,
|
||||
"relative_time": 68.57096815109253
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:30.458301",
|
||||
"cpu_percent": 33.42,
|
||||
"ram_percent": 65.3,
|
||||
"ram_used_gb": 41.5602912902832,
|
||||
"gpu_memory_used": 5724.0,
|
||||
"relative_time": 69.66709041595459
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:31.554329",
|
||||
"cpu_percent": 50.81,
|
||||
"ram_percent": 65.5,
|
||||
"ram_used_gb": 41.66044616699219,
|
||||
"gpu_memory_used": 5827.0,
|
||||
"relative_time": 70.75874853134155
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:32.646414",
|
||||
"cpu_percent": 34.34,
|
||||
"ram_percent": 65.6,
|
||||
"ram_used_gb": 41.739715576171875,
|
||||
"gpu_memory_used": 5843.0,
|
||||
"relative_time": 71.86718988418579
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:33.754223",
|
||||
"cpu_percent": 44.32,
|
||||
"ram_percent": 66.0,
|
||||
"ram_used_gb": 42.005794525146484,
|
||||
"gpu_memory_used": 5901.0,
|
||||
"relative_time": 72.95793795585632
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:34.848852",
|
||||
"cpu_percent": 48.36,
|
||||
"ram_percent": 66.5,
|
||||
"ram_used_gb": 42.3160514831543,
|
||||
"gpu_memory_used": 5924.0,
|
||||
"relative_time": 74.35109186172485
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:36.240235",
|
||||
"cpu_percent": 58.06,
|
||||
"ram_percent": 67.5,
|
||||
"ram_used_gb": 42.95722198486328,
|
||||
"gpu_memory_used": 5930.0,
|
||||
"relative_time": 75.47581958770752
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:37.363208",
|
||||
"cpu_percent": 46.82,
|
||||
"ram_percent": 67.6,
|
||||
"ram_used_gb": 42.97764587402344,
|
||||
"gpu_memory_used": 6364.0,
|
||||
"relative_time": 76.58708119392395
|
||||
},
|
||||
"timestamp": "2025-01-30T05:07:01.747500",
|
||||
"cpu_percent": 13.79,
|
||||
"ram_percent": 18.9,
|
||||
"ram_used_gb": 5.367362976074219,
|
||||
"gpu_memory_used": 3620.0,
|
||||
"relative_time": 23.05113124847412
|
||||
},
|
||||
{
|
||||
"timestamp": "2025-01-06T00:44:38.474408",
|
||||
"cpu_percent": 50.93,
|
||||
"ram_percent": 67.9,
|
||||
"ram_used_gb": 43.1597900390625,
|
||||
"gpu_memory_used": 6426.0,
|
||||
"relative_time": 77.6842532157898
|
||||
"timestamp": "2025-01-30T05:07:02.784828",
|
||||
"cpu_percent": 10.16,
|
||||
"ram_percent": 19.1,
|
||||
"ram_used_gb": 5.443946838378906,
|
||||
"gpu_memory_used": 1916.0,
|
||||
"relative_time": 24.08937978744507
|
||||
}
|
||||
],
|
||||
"test_duration": 82.49591493606567
|
||||
"test_duration": 26.596059799194336
|
||||
}
|
|
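The samples above follow a fixed schema (timestamp, cpu_percent, ram_percent, ram_used_gb, gpu_memory_used, relative_time). The collection code itself is not part of this diff; a minimal sketch of how samples with these fields could be gathered, assuming psutil and pynvml are available, looks like this:

```python
# Sketch only: field names match the benchmark JSON above; this is not the
# repository's monitoring code. Assumes psutil and pynvml are installed.
import time
import psutil
import pynvml

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
start = time.time()

def sample_system_metrics() -> dict:
    mem = psutil.virtual_memory()
    gpu = pynvml.nvmlDeviceGetMemoryInfo(handle)
    return {
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
        "cpu_percent": psutil.cpu_percent(),
        "ram_percent": mem.percent,
        "ram_used_gb": mem.used / 1024**3,
        "gpu_memory_used": gpu.used / 1024**2,  # MiB, as in the samples above
        "relative_time": time.time() - start,
    }
```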
@ -1,23 +1,23 @@
=== Benchmark Statistics (with correct RTF) ===

Total tokens processed: 3150
Total audio generated (s): 994.22
Total test duration (s): 73.81
Average processing rate (tokens/s): 49.36
Average RTF: 0.07
Average Real Time Speed: 15.00
Total audio generated (s): 895.98
Total test duration (s): 23.54
Average processing rate (tokens/s): 133.43
Average RTF: 0.03
Average Real Time Speed: 35.29

=== Per-chunk Stats ===

Average chunk size (tokens): 525.00
Min chunk size (tokens): 150
Max chunk size (tokens): 900
Average processing time (s): 12.12
Average output length (s): 165.70
Average processing time (s): 3.88
Average output length (s): 149.33

=== Performance Ranges ===

Processing rate range (tokens/s): 30.33 - 63.56
RTF range: 0.05x - 0.10x
Real Time Speed range: 10.00x - 20.00x
Processing rate range (tokens/s): 127.12 - 147.93
RTF range: 0.02x - 0.03x
Real Time Speed range: 33.33x - 50.00x
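For readers comparing the two runs: RTF is processing time divided by generated audio duration, and real-time speed is its reciprocal. The reported averages are computed per chunk, so they differ slightly from the ratio of the totals; a quick check against the totals of the updated run:

```python
# Totals taken from the updated benchmark statistics above
total_audio_s = 895.98
total_processing_s = 23.54

rtf = total_processing_s / total_audio_s              # ~0.026 (lower is better)
real_time_speed = total_audio_s / total_processing_s  # ~38x faster than real time
print(f"RTF: {rtf:.3f}, Real Time Speed: {real_time_speed:.2f}x")
```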
Image updated: 230 KiB before, 230 KiB after
Image updated: 206 KiB before, 260 KiB after
Image updated: 491 KiB before, 392 KiB after
Image updated: 224 KiB before, 235 KiB after
Image updated: 221 KiB before, 265 KiB after
Image updated: 463 KiB before, 429 KiB after
@ -1,394 +1,138 @@
#!/usr/bin/env python3
import os
import argparse
from typing import Dict, List, Tuple, Optional
import time
import wave
from pathlib import Path

import numpy as np
import requests
import matplotlib.pyplot as plt
from scipy.io import wavfile
from openai import OpenAI

# Create output directory
output_dir = Path(__file__).parent / "output"
output_dir.mkdir(exist_ok=True)

def submit_combine_voices(
    voices: List[str], base_url: str = "http://localhost:8880"
) -> Optional[str]:
    """Combine multiple voices into a new voice.
# Initialize OpenAI client
client = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed")

    Args:
        voices: List of voice names to combine (e.g. ["af_bella", "af_sarah"])
        base_url: API base URL
# Test text that showcases voice characteristics
text = """The quick brown fox jumps over the lazy dog.
How vexingly quick daft zebras jump!
The five boxing wizards jump quickly."""

    Returns:
        Name of the combined voice (e.g. "af_bella_af_sarah") or None if error
    """
def generate_and_save_audio(voice: str, output_path: str):
    """Generate audio using specified voice and save to WAV file."""
    print(f"\nGenerating audio for voice: {voice}")
    start_time = time.time()

    # Generate audio using streaming response
    with client.audio.speech.with_streaming_response.create(
        model="kokoro",
        voice=voice,
        response_format="wav",
        input=text,
    ) as response:
        # Save the audio stream to file
        with open(output_path, "wb") as f:
            for chunk in response.iter_bytes():
                f.write(chunk)

    duration = time.time() - start_time
    print(f"Generated in {duration:.2f}s")
    print(f"Saved to {output_path}")
    return output_path

def analyze_audio(filepath: str):
    """Analyze audio file and return key characteristics."""
    print(f"\nAnalyzing {filepath}")
    try:
        response = requests.post(f"{base_url}/v1/audio/voices/combine", json=voices)
        print(f"Response status: {response.status_code}")
        print(f"Raw response: {response.text}")

        # Accept both 200 and 201 as success
        if response.status_code not in [200, 201]:
            try:
                error = response.json()["detail"]["message"]
                print(f"Error combining voices: {error}")
            except:
                print(f"Error combining voices: {response.text}")
            return None

        try:
            data = response.json()
            if "voices" in data:
                print(f"Available voices: {', '.join(sorted(data['voices']))}")
            return data["voice"]
        except Exception as e:
            print(f"Error parsing response: {e}")
            return None
        print(f"\nTrying to read {filepath}")
        with wave.open(filepath, 'rb') as wf:
            sample_rate = wf.getframerate()
            samples = np.frombuffer(wf.readframes(wf.getnframes()), dtype=np.int16)
            print(f"Successfully read file:")
            print(f"Sample rate: {sample_rate}")
            print(f"Samples shape: {samples.shape}")
            print(f"Samples dtype: {samples.dtype}")
            print(f"First few samples: {samples[:10]}")
    except Exception as e:
        print(f"Error: {e}")
        return None

def generate_speech(
    text: str,
    voice: str,
    base_url: str = "http://localhost:8880",
    output_file: str = "output.mp3",
) -> bool:
    """Generate speech using specified voice.

    Args:
        text: Text to convert to speech
        voice: Voice name to use
        base_url: API base URL
        output_file: Path to save audio file

    Returns:
        True if successful, False otherwise
    """
    try:
        response = requests.post(
            f"{base_url}/v1/audio/speech",
            json={
                "input": text,
                "voice": voice,
                "speed": 1.0,
                "response_format": "wav",  # Use WAV for analysis
                "stream": False,
            },
        )

        if response.status_code != 200:
            error = response.json().get("detail", {}).get("message", response.text)
            print(f"Error generating speech: {error}")
            return False

        # Save the audio
        os.makedirs(
            os.path.dirname(output_file) if os.path.dirname(output_file) else ".",
            exist_ok=True,
        )
        with open(output_file, "wb") as f:
            f.write(response.content)
        print(f"Saved audio to {output_file}")
        return True

    except Exception as e:
        print(f"Error: {e}")
        return False


def analyze_audio(filepath: str) -> Tuple[np.ndarray, int, dict]:
    """Analyze audio file and return samples, sample rate, and audio characteristics.

    Args:
        filepath: Path to audio file

    Returns:
        Tuple of (samples, sample_rate, characteristics)
    """
    sample_rate, samples = wavfile.read(filepath)

        print(f"Error reading file: {str(e)}")
        raise

    # Convert to float64 for calculations
    samples = samples.astype(np.float64) / 32768.0  # Normalize 16-bit audio

    # Convert to mono if stereo
    if len(samples.shape) > 1:
        samples = np.mean(samples, axis=1)


    # Calculate basic stats
    duration = len(samples) / sample_rate
    max_amp = np.max(np.abs(samples))
    rms = np.sqrt(np.mean(samples**2))
    duration = len(samples) / sample_rate

    # Zero crossing rate (helps identify voice characteristics)
    zero_crossings = np.sum(np.abs(np.diff(np.signbit(samples)))) / len(samples)

    # Simple frequency analysis
    if len(samples) > 0:
        # Use FFT to get frequency components
        fft_result = np.fft.fft(samples)
        freqs = np.fft.fftfreq(len(samples), 1 / sample_rate)

        # Get positive frequencies only
        pos_mask = freqs > 0
        freqs = freqs[pos_mask]
        magnitudes = np.abs(fft_result)[pos_mask]

        # Find dominant frequencies (top 3)
        top_indices = np.argsort(magnitudes)[-3:]
        dominant_freqs = freqs[top_indices]

        # Calculate spectral centroid (brightness of sound)
        spectral_centroid = np.sum(freqs * magnitudes) / np.sum(magnitudes)
    else:
        dominant_freqs = []
        spectral_centroid = 0

    characteristics = {
        "max_amplitude": max_amp,
        "rms": rms,
        "duration": duration,
        "zero_crossing_rate": zero_crossings,
        "dominant_frequencies": dominant_freqs,
        "spectral_centroid": spectral_centroid,

    # Calculate frequency characteristics
    # Compute FFT
    N = len(samples)
    yf = np.fft.fft(samples)
    xf = np.fft.fftfreq(N, 1 / sample_rate)[:N//2]
    magnitude = 2.0/N * np.abs(yf[0:N//2])
    # Calculate spectral centroid
    spectral_centroid = np.sum(xf * magnitude) / np.sum(magnitude)
    # Determine dominant frequencies
    dominant_freqs = xf[magnitude.argsort()[-5:]][::-1].tolist()

    return {
        'samples': samples,
        'sample_rate': sample_rate,
        'duration': duration,
        'max_amplitude': max_amp,
        'rms': rms,
        'spectral_centroid': spectral_centroid,
        'dominant_frequencies': dominant_freqs
    }

    return samples, sample_rate, characteristics

def setup_plot(fig, ax, title):
    """Configure plot styling"""
    # Improve grid
    ax.grid(True, linestyle="--", alpha=0.3, color="#ffffff")

    # Set title and labels with better fonts
    ax.set_title(title, pad=20, fontsize=16, fontweight="bold", color="#ffffff")
    ax.set_xlabel(ax.get_xlabel(), fontsize=14, fontweight="medium", color="#ffffff")
    ax.set_ylabel(ax.get_ylabel(), fontsize=14, fontweight="medium", color="#ffffff")

    # Improve tick labels
    ax.tick_params(labelsize=12, colors="#ffffff")

    # Style spines
    for spine in ax.spines.values():
        spine.set_color("#ffffff")
        spine.set_alpha(0.3)
        spine.set_linewidth(0.5)

    # Set background colors
    ax.set_facecolor("#1a1a2e")
    fig.patch.set_facecolor("#1a1a2e")

    return fig, ax


def plot_analysis(audio_files: Dict[str, str], output_dir: str):
    """Plot comprehensive voice analysis including waveforms and metrics comparison.

    Args:
        audio_files: Dictionary of label -> filepath
        output_dir: Directory to save plot files
    """
    # Set dark style
    plt.style.use("dark_background")

    # Create figure with subplots
    fig = plt.figure(figsize=(15, 15))
    fig.patch.set_facecolor("#1a1a2e")
    num_files = len(audio_files)

    # Create subplot grid with proper spacing for waveforms and metrics
    total_rows = num_files + 2  # Add one more row for metrics
    gs = plt.GridSpec(
        total_rows, 2, height_ratios=[1.5] * num_files + [1, 1], hspace=0.4, wspace=0.3
    )

    # Analyze all files first
    all_chars = {}
    for i, (label, filepath) in enumerate(audio_files.items()):
        samples, sample_rate, chars = analyze_audio(filepath)
        all_chars[label] = chars

        # Plot waveform spanning both columns
        ax = plt.subplot(gs[i, :])
        time = np.arange(len(samples)) / sample_rate
        plt.plot(time, samples / chars["max_amplitude"], linewidth=0.5, color="#ff2a6d")
        ax.set_xlabel("Time (seconds)")
        ax.set_ylabel("Normalized Amplitude")
        ax.set_ylim(-1.1, 1.1)
        setup_plot(fig, ax, f"Waveform: {label}")

    # Colors for voices
    colors = ["#ff2a6d", "#05d9e8", "#d1f7ff"]

    # Create metrics for each subplot
    metrics = [
        (
            plt.subplot(gs[num_files, 0]),
            [
                (
                    "Volume",
                    [chars["rms"] * 100 for chars in all_chars.values()],
                    "RMS×100",
                )
            ],
        ),
        (
            plt.subplot(gs[num_files, 1]),
            [
                (
                    "Brightness",
                    [chars["spectral_centroid"] / 1000 for chars in all_chars.values()],
                    "kHz",
                )
            ],
        ),
        (
            plt.subplot(gs[num_files + 1, 0]),
            [
                (
                    "Voice Pitch",
                    [
                        min(chars["dominant_frequencies"])
                        for chars in all_chars.values()
                    ],
                    "Hz",
                )
            ],
        ),
        (
            plt.subplot(gs[num_files + 1, 1]),
            [
                (
                    "Texture",
                    [
                        chars["zero_crossing_rate"] * 1000
                        for chars in all_chars.values()
                    ],
                    "ZCR×1000",
                )
            ],
        ),
    ]

    # Plot each metric
    for i, (ax, metric_data) in enumerate(metrics):
        n_voices = len(audio_files)
        bar_width = 0.25
        indices = np.array([0])

        values = metric_data[0][1]
        max_val = max(values)

        for j, (voice, color) in enumerate(zip(audio_files.keys(), colors)):
            offset = (j - n_voices / 2 + 0.5) * bar_width
            bars = ax.bar(
                indices + offset,
                [values[j]],
                bar_width,
                label=voice,
                color=color,
                alpha=0.8,
            )

            # Add value labels on top of bars
            for bar in bars:
                height = bar.get_height()
                ax.text(
                    bar.get_x() + bar.get_width() / 2.0,
                    height,
                    f"{height:.1f}",
                    ha="center",
                    va="bottom",
                    color="white",
                    fontsize=10,
                )

        ax.set_xticks(indices)
        ax.set_xticklabels([f"{metric_data[0][0]}\n({metric_data[0][2]})"])
        ax.set_ylim(0, max_val * 1.2)
        ax.set_ylabel("Value")

        # Only show legend on first metric plot
        if i == 0:
            ax.legend(
                bbox_to_anchor=(1.05, 1),
                loc="upper left",
                facecolor="#1a1a2e",
                edgecolor="#ffffff",
            )

        # Style the subplot
        setup_plot(fig, ax, metric_data[0][0])

    # Adjust the figure size and padding
    fig.set_size_inches(15, 20)
    plt.subplots_adjust(right=0.85, top=0.95, bottom=0.05, left=0.1)
    plt.savefig(os.path.join(output_dir, "analysis_comparison.png"), dpi=300)
    print(f"Saved analysis comparison to {output_dir}/analysis_comparison.png")

    # Print detailed comparative analysis
    print("\nDetailed Voice Analysis:")
    for label, chars in all_chars.items():
        print(f"\n{label}:")
        print(f"  Max Amplitude: {chars['max_amplitude']:.2f}")
        print(f"  RMS (loudness): {chars['rms']:.2f}")
        print(f"  Duration: {chars['duration']:.2f}s")
        print(f"  Zero Crossing Rate: {chars['zero_crossing_rate']:.3f}")
        print(f"  Spectral Centroid: {chars['spectral_centroid']:.0f}Hz")
        print(
            f"  Dominant Frequencies: {', '.join(f'{f:.0f}Hz' for f in chars['dominant_frequencies'])}"
        )

def plot_comparison(analyses, output_path):
    """Create comparison plot of the audio analyses."""
    plt.style.use('dark_background')
    fig = plt.figure(figsize=(15, 10))
    fig.patch.set_facecolor('#1a1a2e')

    # Plot waveforms
    for i, (name, data) in enumerate(analyses.items()):
        ax = plt.subplot(3, 1, i+1)
        samples = data['samples']
        time = np.arange(len(samples)) / data['sample_rate']
        plt.plot(time, samples / data['max_amplitude'], linewidth=0.5, color='#ff2a6d')
        plt.title(f"Waveform: {name}", color='white', pad=20)
        plt.xlabel("Time (seconds)", color='white')
        plt.ylabel("Normalized Amplitude", color='white')
        plt.grid(True, alpha=0.3)
        ax.set_facecolor('#1a1a2e')
        plt.ylim(-1.1, 1.1)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"\nSaved comparison plot to {output_path}")

def main():
    parser = argparse.ArgumentParser(description="Kokoro Voice Analysis Demo")
    parser.add_argument("--voices", nargs="+", type=str, help="Voices to combine")
    parser.add_argument(
        "--text",
        type=str,
        default="Hello! This is a test of combined voices.",
        help="Text to speak",
    )
    parser.add_argument("--url", default="http://localhost:8880", help="API base URL")
    parser.add_argument(
        "--output-dir",
        default="examples/assorted_checks/test_combinations/output",
        help="Output directory for audio files",
    )
    args = parser.parse_args()

    if not args.voices:
        print("No voices provided, using default test voices")
        args.voices = ["af_bella", "af_nicole"]

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Dictionary to store audio files for analysis
    audio_files = {}

    # Generate speech with individual voices
    print("Generating speech with individual voices...")
    for voice in args.voices:
        output_file = os.path.join(args.output_dir, f"analysis_{voice}.wav")
        if generate_speech(args.text, voice, args.url, output_file):
            audio_files[voice] = output_file

    # Generate speech with combined voice
    print(f"\nCombining voices: {', '.join(args.voices)}")
    combined_voice = submit_combine_voices(args.voices, args.url)

    if combined_voice:
        print(f"Successfully created combined voice: {combined_voice}")
        output_file = os.path.join(
            args.output_dir, f"analysis_combined_{combined_voice}.wav"
        )
        if generate_speech(args.text, combined_voice, args.url, output_file):
            audio_files["combined"] = output_file

        # Generate comparison plots
        plot_analysis(audio_files, args.output_dir)
    else:
        print("Failed to combine voices")

# Generate audio for each voice
voices = {
    'af_bella': output_dir / 'af_bella.wav',
    'af_irulan': output_dir / 'af_irulan.wav',
    'af_bella+af_irulan': output_dir / 'af_bella+af_irulan.wav'
}

for voice, path in voices.items():
    generate_and_save_audio(voice, str(path))

# Analyze each audio file
analyses = {}
for name, path in voices.items():
    analyses[name] = analyze_audio(str(path))

# Create comparison plot
plot_comparison(analyses, output_dir / 'voice_comparison.png')

if __name__ == "__main__":
    main()
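For a standalone check of the same endpoint the script drives, a minimal sketch using the OpenAI-compatible client is shown below. The voice name and output path are illustrative only, and a server listening on localhost:8880 is assumed:

```python
from openai import OpenAI

# Point the client at the local Kokoro-FastAPI server (no real key required)
client = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed")

# Stream a short clip with a combined voice, mirroring generate_and_save_audio above
with client.audio.speech.with_streaming_response.create(
    model="kokoro",
    voice="af_bella+af_irulan",
    response_format="wav",
    input="A quick synthesis check.",
) as response:
    with open("check.wav", "wb") as f:
        for chunk in response.iter_bytes():
            f.write(chunk)
```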