Compare commits


No commits in common. "master" and "v0.0.5post1" have entirely different histories.

260 changed files with 6786 additions and 15297 deletions


@ -7,7 +7,6 @@ omit =
MagicMock/*
test_*.py
examples/*
src/builds/*
[report]
exclude_lines =

5
.gitattributes vendored

@ -1,5 +0,0 @@
* text=auto
*.py text eol=lf
*.sh text eol=lf
*.yml text eol=lf

15
.github/FUNDING.yml vendored

@ -1,15 +0,0 @@
# These are supported funding model platforms
github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
polar: # Replace with a single Polar username
buy_me_a_coffee: remsky
thanks_dev: # Replace with a single thanks.dev username
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']


@ -1,23 +0,0 @@
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''
---
**Describe the bug**
A clear and concise description of what the bug is.
**Screenshots or console output**
If applicable, add screenshots to help explain your problem. When doing so, please ensure you include the first command that triggered the trace and/or the command that started your build; otherwise it is difficult to diagnose.
**Branch / Deployment used**
Let us know if it's the master branch or the stable branch indicated in the readme, as well as whether you're running it locally, in the cloud, via docker compose (CPU or GPU), or via direct docker run commands. Please include the exact commands used in the latter cases.
**Operating System**
Include the platform, the version numbers of your Docker install, etc., and whether it's GPU (Nvidia or other) or CPU, on Mac, Linux, Windows, etc.
**Additional context**
Add any other context about the problem here.


@ -1,17 +0,0 @@
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ''
assignees: ''
---
**Describe the feature you'd like**
A clear and concise description of what you want to happen. Is it a quality-of-life improvement, or something new entirely?
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered. Consider whether it could be submitted as a PR, or whether you'd need a hand to do so.
**Additional context**
Add any other context or screenshots about the feature request here.


@ -1,39 +1,51 @@
name: CI
on:
push:
branches: [ "master", "pre-release" ]
pull_request:
branches: [ "master", "pre-release" ]
jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10"]
fail-fast: false
# name: CI
steps:
- uses: actions/checkout@v4
# on:
# push:
# branches: [ "develop", "master" ]
# pull_request:
# branches: [ "develop", "master" ]
# Match Dockerfile dependencies
- name: Install Dependencies
run: |
sudo apt-get update
sudo apt-get install -y --no-install-recommends \
espeak-ng \
git \
libsndfile1 \
curl \
ffmpeg
# jobs:
# test:
# runs-on: ubuntu-latest
# strategy:
# matrix:
# python-version: ["3.9", "3.10", "3.11"]
# fail-fast: false
- name: Install uv
uses: astral-sh/setup-uv@v5
with:
python-version: ${{ matrix.python-version }}
enable-cache: true
- name: Install dependencies
run: |
uv pip install -e .[test,cpu]
- name: Run Tests
run: |
uv run pytest api/tests/ --asyncio-mode=auto --cov=api --cov-report=term-missing
# steps:
# - uses: actions/checkout@v4
# - name: Set up Python ${{ matrix.python-version }}
# uses: actions/setup-python@v5
# with:
# python-version: ${{ matrix.python-version }}
# - name: Set up pip cache
# uses: actions/cache@v3
# with:
# path: ~/.cache/pip
# key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements*.txt') }}
# restore-keys: |
# ${{ runner.os }}-pip-
# - name: Install PyTorch CPU
# run: |
# python -m pip install --upgrade pip
# pip install torch --index-url https://download.pytorch.org/whl/cpu
# - name: Install dependencies
# run: |
# pip install ruff pytest-cov
# pip install -r requirements.txt
# pip install -r requirements-test.txt
# - name: Lint with ruff
# run: |
# ruff check .
# - name: Test with pytest
# run: |
# pytest --asyncio-mode=auto --cov=api --cov-report=term-missing

120
.github/workflows/docker-publish.yml vendored Normal file

@ -0,0 +1,120 @@
name: Docker Build and Publish
on:
push:
tags: [ 'v*.*.*' ]
# Allow manual trigger from GitHub UI
workflow_dispatch:
env:
REGISTRY: ghcr.io
IMAGE_NAME: ${{ github.repository }}
jobs:
build:
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Log in to the Container registry
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
# Extract metadata for GPU image
- name: Extract metadata (tags, labels) for GPU Docker
id: meta-gpu
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
tags: |
type=semver,pattern=v{{version}}
type=semver,pattern=v{{major}}.{{minor}}
type=semver,pattern=v{{major}}
type=raw,value=latest
# Extract metadata for CPU image
- name: Extract metadata (tags, labels) for CPU Docker
id: meta-cpu
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
flavor: |
suffix=-cpu
tags: |
type=semver,pattern=v{{version}}
type=semver,pattern=v{{major}}.{{minor}}
type=semver,pattern=v{{major}}
type=raw,value=latest
# Build and push GPU version
- name: Build and push GPU Docker image
uses: docker/build-push-action@v5
with:
context: .
file: ./Dockerfile
push: true
tags: ${{ steps.meta-gpu.outputs.tags }}
labels: ${{ steps.meta-gpu.outputs.labels }}
platforms: linux/amd64
# Build and push CPU version
- name: Build and push CPU Docker image
uses: docker/build-push-action@v5
with:
context: .
file: ./Dockerfile.cpu
push: true
tags: ${{ steps.meta-cpu.outputs.tags }}
labels: ${{ steps.meta-cpu.outputs.labels }}
platforms: linux/amd64
# Extract metadata for UI image
- name: Extract metadata (tags, labels) for UI Docker
id: meta-ui
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
flavor: |
suffix=-ui
tags: |
type=semver,pattern=v{{version}}
type=semver,pattern=v{{major}}.{{minor}}
type=semver,pattern=v{{major}}
type=raw,value=latest
# Build and push UI version
- name: Build and push UI Docker image
uses: docker/build-push-action@v5
with:
context: ./ui
file: ./ui/Dockerfile
push: true
tags: ${{ steps.meta-ui.outputs.tags }}
labels: ${{ steps.meta-ui.outputs.labels }}
platforms: linux/amd64
create-release:
needs: build
runs-on: ubuntu-latest
# Only run this job if we're pushing a tag
if: startsWith(github.ref, 'refs/tags/')
permissions:
contents: write
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Create Release
uses: softprops/action-gh-release@v1
with:
generate_release_notes: true
draft: false
prerelease: false


@ -1,110 +0,0 @@
name: Create Release and Publish Docker Images
on:
push:
branches:
- release # Trigger when commits are pushed to the release branch (e.g., after merging master)
paths-ignore:
- '**.md'
- 'docs/**'
jobs:
prepare-release:
runs-on: ubuntu-latest
outputs:
version: ${{ steps.get-version.outputs.version }}
version_tag: ${{ steps.get-version.outputs.version_tag }}
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Get version from VERSION file
id: get-version
run: |
VERSION_PLAIN=$(cat VERSION)
echo "version=${VERSION_PLAIN}" >> $GITHUB_OUTPUT
echo "version_tag=v${VERSION_PLAIN}" >> $GITHUB_OUTPUT # Add 'v' prefix for tag
build-images:
needs: prepare-release
runs-on: ubuntu-latest
permissions:
packages: write # Needed to push images to GHCR
env:
DOCKER_BUILDKIT: 1
BUILDKIT_STEP_LOG_MAX_SIZE: 10485760
# This environment variable will override the VERSION variable in docker-bake.hcl
VERSION: ${{ needs.prepare-release.outputs.version_tag }} # Use tag version (vX.Y.Z) for bake
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0 # Needed to check for existing tags
- name: Check if tag already exists
run: |
TAG_NAME="${{ needs.prepare-release.outputs.version_tag }}"
echo "Checking for existing tag: $TAG_NAME"
# Fetch tags explicitly just in case checkout didn't get them all
git fetch --tags
if git rev-parse "$TAG_NAME" >/dev/null 2>&1; then
echo "::error::Tag $TAG_NAME already exists. Please increment the version in the VERSION file."
exit 1
else
echo "Tag $TAG_NAME does not exist. Proceeding with release."
fi
- name: Free disk space # Optional: Keep as needed for large builds
run: |
echo "Listing current disk space"
df -h
echo "Cleaning up disk space..."
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache
docker system prune -af
echo "Disk space after cleanup"
df -h
- name: Set up QEMU
uses: docker/setup-qemu-action@v3 # Use v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3 # Use v3
with:
driver-opts: |
image=moby/buildkit:latest
network=host
- name: Log in to GitHub Container Registry
uses: docker/login-action@v3 # Use v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Build and push images using Docker Bake
run: |
echo "Building and pushing images for version ${{ needs.prepare-release.outputs.version_tag }}"
# The VERSION env var above sets the tag for the bake file targets
docker buildx bake --push
create-release:
needs: [prepare-release, build-images]
runs-on: ubuntu-latest
permissions:
contents: write # Needed to create releases
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0 # Fetch all history for release notes generation
- name: Create GitHub Release
uses: softprops/action-gh-release@v2 # Use v2
with:
tag_name: ${{ needs.prepare-release.outputs.version_tag }} # Use vX.Y.Z tag
name: Release ${{ needs.prepare-release.outputs.version_tag }}
generate_release_notes: true # Auto-generate release notes
draft: false # Publish immediately
prerelease: false # Mark as a stable release
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

55
.github/workflows/sync-develop.yml vendored Normal file

@ -0,0 +1,55 @@
# name: Sync develop with master
# on:
# push:
# branches:
# - master
# jobs:
# sync-develop:
# runs-on: ubuntu-latest
# permissions:
# contents: write
# issues: write
# steps:
# - name: Checkout repository
# uses: actions/checkout@v4
# with:
# fetch-depth: 0
# ref: develop
# - name: Configure Git
# run: |
# git config user.name "GitHub Actions"
# git config user.email "actions@github.com"
# - name: Merge master into develop
# run: |
# git fetch origin master:master
# git merge --no-ff origin/master -m "chore: Merge master into develop branch"
# - name: Push changes
# run: |
# if ! git push origin develop; then
# echo "Failed to push to develop branch"
# exit 1
# fi
# - name: Handle Failure
# if: failure()
# uses: actions/github-script@v7
# with:
# script: |
# const issueBody = `Automatic merge from master to develop failed.
# Please resolve this manually
# Workflow run: ${process.env.GITHUB_SERVER_URL}/${process.env.GITHUB_REPOSITORY}/actions/runs/${process.env.GITHUB_RUN_ID}`;
# await github.rest.issues.create({
# owner: context.repo.owner,
# repo: context.repo.repo,
# title: '🔄 Automatic master to develop merge failed',
# body: issueBody,
# labels: ['merge-failed', 'automation']
# });

65
.gitignore vendored

@ -2,74 +2,51 @@
.git
# Python
__pycache__/
__pycache__
*.pyc
*.pyo
*.pyd
*.pt
.Python
*.py[cod]
*$py.class
.Python
.pytest_cache
.coverage
.coveragerc
# Python package build artifacts
*.egg-info/
*.egg
dist/
build/
*.onnx
*.pth
# Environment
# .env
.venv/
.venv
env/
venv/
ENV/
# IDE
.idea/
.vscode/
.idea
.vscode
*.swp
*.swo
# Project specific
# Model files
*examples/*.wav
*examples/*.pcm
*examples/*.mp3
*examples/*.flac
*examples/*.acc
*examples/*.ogg
*.pth
*.tar*
# Other project files
.env
Kokoro-82M/
ui/data/
EXTERNAL_UV_DOCUMENTATION*
app
api/temp_files/
ui/data
tests/
*.md
*.txt
requirements.txt
# Docker
Dockerfile*
docker-compose*
examples/ebook_test/chapter_to_audio.py
examples/ebook_test/chapters_to_audio.py
examples/ebook_test/parse_epub.py
api/src/voices/af_jadzia.pt
examples/assorted_checks/test_combinations/output/*
examples/assorted_checks/test_openai/output/*
# Audio files
examples/*.wav
examples/*.pcm
examples/*.mp3
examples/*.flac
examples/*.acc
examples/*.ogg
examples/speech.mp3
examples/phoneme_examples/output/*.wav
examples/assorted_checks/benchmarks/output_audio/*
uv.lock
# Mac MPS virtualenv for dual testing
.venv-mps
*.egg-info
*.pt
*.wav
*.tar*


@ -1 +0,0 @@
3.10


@ -1,12 +1,11 @@
line-length = 88
exclude = ["examples"]
[lint]
select = ["I"]
[lint.isort]
combine-as-imports = true
force-wrap-aliases = true
length-sort = true
split-on-trailing-comma = true
section-order = ["future", "standard-library", "third-party", "first-party", "local-folder"]


@ -2,123 +2,6 @@
Notable changes to this project will be documented in this file.
## [v0.3.0] - 2025-04-04
### Added
- Apple Silicon (MPS) acceleration support for macOS users.
- Voice subtraction capability for creating unique voice effects.
- Windows PowerShell start scripts (`start-cpu.ps1`, `start-gpu.ps1`).
- Automatic model downloading integrated into all start scripts.
- Example Helm chart values for Azure AKS and Nvidia GPU Operator deployments.
- `CONTRIBUTING.md` guidelines for developers.
### Changed
- Version bump of underlying Kokoro and Misaki libraries
- Default API port reverted to 8880.
- Docker containers now run as a non-root user for enhanced security.
- Improved text normalization for numbers, currency, and time formats.
- Updated and improved Helm chart configurations and documentation.
- Enhanced temporary file management with better error tracking.
- Web UI dependencies (Siriwave) are now served locally.
- Standardized environment variable handling across shell/PowerShell scripts.
### Fixed
- Corrected an issue preventing download links from being returned when `streaming=false`.
- Resolved errors in Windows PowerShell scripts related to virtual environment activation order.
- Addressed potential segfaults during inference.
- Fixed various Helm chart issues related to health checks, ingress, and default values.
- Corrected audio quality degradation caused by incorrect bitrate settings in some cases.
- Ensured custom phonemes provided in input text are preserved.
- Fixed a 'MediaSource' error affecting playback stability in the web player.
### Removed
- Obsolete GitHub Actions build workflow, build and publish now occurs on merge to `Release` branch
## [v0.2.0post1] - 2025-02-07
- Fix: Building Kokoro from source with adjustments, to avoid CUDA lock
- Fixed ARM64 compatibility on Spacy dep to avoid emulation slowdown
- Added g++ for Japanese language support
- Temporarily disabled Vietnamese language support due to ARM64 compatibility issues
## [v0.2.0-pre] - 2025-02-06
### Added
- Complete Model Overhaul:
- Upgraded to Kokoro v1.0 model architecture
- Pre-installed multi-language support from Misaki:
- English (en), Japanese (ja), Korean (ko), Chinese (zh), Vietnamese (vi)
- All voice packs included for supported languages, along with the original versions.
- Enhanced Audio Generation Features:
- Per-word timestamped caption generation
- Phoneme-based audio generation capabilities
- Detailed phoneme generation
- Web UI Improvements:
- Improved voice mixing with weighted combinations
- Text file upload support
- Enhanced formatting and user interface
- Cleaner UI (in progress)
- Integration with https://github.com/hexgrad/kokoro and https://github.com/hexgrad/misaki packages
### Removed
- Deprecated support for Kokoro v0.19 model
### Changed
- Combine Voices endpoint now returns a .pt file, with generation combinations generated on the fly otherwise
## [v0.1.4] - 2025-01-30
### Added
- Smart Chunking System:
- New text_processor with smart_split for improved sentence boundary detection
- Dynamically adjusts chunk sizes based on sentence structure, using phoneme/token information in an initial pass
- Should avoid ever going over the 510 limit per chunk, while preserving natural cadence
- Web UI Added (To Be Replacing Gradio):
- Integrated streaming with tempfile generation
- Download links available in X-Download-Path header
- Configurable cleanup triggers for temp files
- Debug Endpoints:
- /debug/threads for thread information and stack traces
- /debug/storage for temp file and output directory monitoring
- /debug/system for system resource information
- /debug/session_pools for ONNX/CUDA session status
- Automated Model Management:
- Auto-download from releases page
- Included download scripts for manual installation
- Pre-packaged voice models in repository
### Changed
- Significant architectural improvements:
- Multi-model architecture support
- Enhanced concurrency handling
- Improved streaming header management
- Better resource/session pool management
## [v0.1.2] - 2025-01-23
### Structural Improvements
- Models can be manually downloaded and placed in api/src/models, or fetched with the included script
- TTSGPU/TPSCPU/STTSService classes replaced with a ModelManager service
- CPU/GPU support for each of ONNX and PyTorch (Note: only PyTorch GPU and ONNX CPU/GPU have been tested)
- Should make it possible to incorporate new models or architectures as they become available, in a more modular way
- Converted a number of internal processes to async handling to improve concurrency
- Improving separation of concerns towards a plug-in and modular structure, making PRs and new features easier
### Web UI (test release)
- An integrated simple web UI has been added on the FastAPI server directly
- This can be disabled via core/config.py or ENV variables if desired.
- Simplifies deployments, utility testing, aesthetics, etc
- Looking to deprecate/collaborate/hand off the Gradio UI
## [v0.1.0] - 2025-01-13
### Changed
- Major Docker improvements:
- Baked model directly into Dockerfile for improved deployment reliability
- Switched to uv for dependency management
- Streamlined container builds and reduced image sizes
- Dependency Management:
- Migrated from pip/poetry to uv for faster, more reliable package management
- Added uv.lock for deterministic builds
- Updated dependency resolution strategy
## [v0.0.5post1] - 2025-01-11
### Fixed
- Docker image tagging and versioning improvements (-gpu, -cpu, -ui)


@ -1,86 +0,0 @@
# Contributing to Kokoro-FastAPI
Community involvement in making this project better is always appreciated.
## Development Setup
We use `uv` for managing Python environments and dependencies, and `ruff` for linting and formatting.
1. **Clone the repository:**
```bash
git clone https://github.com/remsky/Kokoro-FastAPI.git
cd Kokoro-FastAPI
```
2. **Install `uv`:**
Follow the instructions on the [official `uv` documentation](https://docs.astral.sh/uv/install/).
3. **Create a virtual environment and install dependencies:**
It's recommended to use a virtual environment. `uv` can create one for you. Install the base dependencies along with the `test` and `cpu` extras (needed for running tests locally).
```bash
# Create and activate a virtual environment (e.g., named .venv)
uv venv
source .venv/bin/activate # On Linux/macOS
# .venv\Scripts\activate # On Windows
# Install dependencies including test requirements
uv pip install -e ".[test,cpu]"
```
*Note: If you have an NVIDIA GPU and want to test GPU-specific features locally, you can install `.[test,gpu]` instead, ensuring you have the correct CUDA toolkit installed.*
*Note: If running via uv locally, you will have to install espeak and handle any pathing issues that arise. The Docker images handle this automatically*
4. **Install `ruff` (if not already installed globally):**
While `ruff` might be included via dependencies, installing it explicitly ensures you have it available.
```bash
uv pip install ruff
```
## Running Tests
Before submitting changes, please ensure all tests pass, as this is an automated requirement. The tests are run using `pytest`.
```bash
# Make sure your virtual environment is activated
uv run pytest
```
*Note: The CI workflow runs tests using `uv run pytest api/tests/ --asyncio-mode=auto --cov=api --cov-report=term-missing`. Running `uv run pytest` locally should cover the essential checks.*
## Testing with Docker Compose
In addition to local `pytest` runs, test your changes using Docker Compose to ensure they work correctly within the containerized environment. If you aren't able to test on CUDA hardware, make a note of it so it can be tested by another maintainer.
```bash
docker compose -f docker/cpu/docker-compose.yml up --build
# or
docker compose -f docker/gpu/docker-compose.yml up --build
```
This command will build the Docker images (if they've changed) and start the services defined in the respective compose file. Verify the application starts correctly and test the relevant functionality.
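For a quick smoke test once a container is up, something like the following can exercise the running service. This is only a sketch: it assumes the default port 8880 and uses the voices and speech endpoints shown in the README.

```python
# Minimal smoke test against a locally running container (assumes default port 8880).
import requests

BASE_URL = "http://localhost:8880"

# List available voices; a 200 response means the API is up.
voices = requests.get(f"{BASE_URL}/v1/audio/voices").json()["voices"]
print(f"Service is up, {len(voices)} voices available")

# Generate a short clip with the first voice to exercise the speech endpoint.
audio = requests.post(
    f"{BASE_URL}/v1/audio/speech",
    json={"input": "Smoke test.", "voice": voices[0], "response_format": "mp3"},
)
audio.raise_for_status()
with open("smoke_test.mp3", "wb") as f:
    f.write(audio.content)
```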
## Code Formatting and Linting
We use `ruff` to maintain code quality and consistency. Please format and lint your code before committing.
1. **Format the code:**
```bash
# Make sure your virtual environment is activated
ruff format .
```
2. **Lint the code (and apply automatic fixes):**
```bash
# Make sure your virtual environment is activated
ruff check . --fix
```
Review any changes made by `--fix` and address any remaining linting errors manually.
## Submitting Changes
0. Clone the repo
1. Create a new branch for your feature or bug fix.
2. Make your changes, following setup, testing, and formatting guidelines above.
3. Please try to keep your changes in line with the current design, and modular. Large-scale changes will take longer to review and integrate, and have less chance of being approved outright.
4. Push your branch to your fork.
5. Open a Pull Request against the `master` branch of the main repository.
Thank you for contributing!

44
Dockerfile Normal file

@ -0,0 +1,44 @@
FROM nvidia/cuda:12.1.0-base-ubuntu22.04
# Install base system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
python3-pip \
python3-dev \
espeak-ng \
git \
libsndfile1 \
curl \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
# Install PyTorch with CUDA support first
RUN pip3 install --no-cache-dir torch==2.5.1 --extra-index-url https://download.pytorch.org/whl/cu121
# Install all other dependencies from requirements.txt
COPY requirements.txt .
RUN pip3 install --no-cache-dir -r requirements.txt
# Set working directory
WORKDIR /app
# Create non-root user
RUN useradd -m -u 1000 appuser
# Create model directory and set ownership
RUN mkdir -p /app/Kokoro-82M && \
chown -R appuser:appuser /app
# Switch to non-root user
USER appuser
# Run with Python unbuffered output for live logging
ENV PYTHONUNBUFFERED=1
# Copy only necessary application code
COPY --chown=appuser:appuser api /app/api
# Set Python path (app first for our imports, then model dir for model imports)
ENV PYTHONPATH=/app:/app/Kokoro-82M
# Run FastAPI server with debug logging and reload
CMD ["uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]

44
Dockerfile.cpu Normal file

@ -0,0 +1,44 @@
FROM ubuntu:22.04
# Install base system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
python3-pip \
python3-dev \
espeak-ng \
git \
libsndfile1 \
curl \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
# Install PyTorch CPU version and ONNX runtime
RUN pip3 install --no-cache-dir torch==2.5.1 --extra-index-url https://download.pytorch.org/whl/cpu
# Install all other dependencies from requirements.txt
COPY requirements.txt .
RUN pip3 install --no-cache-dir -r requirements.txt
# Copy application code and model
COPY . /app/
# Set working directory
WORKDIR /app
# Run with Python unbuffered output for live logging
ENV PYTHONUNBUFFERED=1
# Create non-root user
RUN useradd -m -u 1000 appuser
# Create directories and set permissions
RUN mkdir -p /app/Kokoro-82M && \
chown -R appuser:appuser /app
# Switch to non-root user
USER appuser
# Set Python path (app first for our imports, then model dir for model imports)
ENV PYTHONPATH=/app:/app/Kokoro-82M
# Run FastAPI server with debug logging and reload
CMD ["uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]

1
Kokoro-82M Submodule

@ -0,0 +1 @@
Subproject commit c97b7bbc3e60f447383c79b2f94fee861ff156ac

201
LICENSE

@ -1,201 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

448
README.md

@ -2,142 +2,58 @@
<img src="githubbanner.png" alt="Kokoro TTS Banner">
</p>
# <sub><sub>_`FastKoko`_ </sub></sub>
[![Tests](https://img.shields.io/badge/tests-69-darkgreen)]()
[![Coverage](https://img.shields.io/badge/coverage-54%25-tan)]()
[![Try on Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Try%20on-Spaces-blue)](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero)
[![Kokoro](https://img.shields.io/badge/kokoro-0.9.2-BB5420)](https://github.com/hexgrad/kokoro)
[![Misaki](https://img.shields.io/badge/misaki-0.9.3-B8860B)](https://github.com/hexgrad/misaki)
[![Tested at Model Commit](https://img.shields.io/badge/last--tested--model--commit-1.0::9901c2b-blue)](https://huggingface.co/hexgrad/Kokoro-82M/commit/9901c2b79161b6e898b7ea857ae5298f47b8b0d6)
# Kokoro TTS API
[![Tests](https://img.shields.io/badge/tests-117%20passed-darkgreen)]()
[![Coverage](https://img.shields.io/badge/coverage-75%25-darkgreen)]()
[![Tested at Model Commit](https://img.shields.io/badge/last--tested--model--commit-a67f113-blue)](https://huggingface.co/hexgrad/Kokoro-82M/tree/c3b0d86e2a980e027ef71c28819ea02e351c2667) [![Try on Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Try%20on-Spaces-blue)](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero)
Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model
- Multi-language support (English, Japanese, Korean, Chinese, _Vietnamese soon_)
- OpenAI-compatible Speech endpoint, NVIDIA GPU accelerated or CPU inference with PyTorch
- ONNX support coming soon, see v0.1.5 and earlier for legacy ONNX support in the interim
- Debug endpoints for monitoring system stats, integrated web UI on localhost:8880/web
- Phoneme-based audio generation, phoneme generation
- Per-word timestamped caption generation
- Voice mixing with weighted combinations
### Integration Guides
[![Helm Chart](https://img.shields.io/badge/Helm%20Chart-black?style=flat&logo=helm&logoColor=white)](https://github.com/remsky/Kokoro-FastAPI/wiki/Setup-Kubernetes) [![DigitalOcean](https://img.shields.io/badge/DigitalOcean-black?style=flat&logo=digitalocean&logoColor=white)](https://github.com/remsky/Kokoro-FastAPI/wiki/Integrations-DigitalOcean) [![SillyTavern](https://img.shields.io/badge/SillyTavern-black?style=flat&color=red)](https://github.com/remsky/Kokoro-FastAPI/wiki/Integrations-SillyTavern)
[![OpenWebUI](https://img.shields.io/badge/OpenWebUI-black?style=flat&color=white)](https://github.com/remsky/Kokoro-FastAPI/wiki/Integrations-OpenWebUi)
## Get Started
<details>
<summary>Quickest Start (docker run)</summary>
- OpenAI-compatible Speech endpoint, with inline voice combination functionality
- NVIDIA GPU accelerated or CPU Onnx inference
- very fast generation time
- 100x+ real time speed via HF A100
- 35-50x+ real time speed via 4060Ti
- 5x+ real time speed via M3 Pro CPU
- streaming support w/ variable chunking to control latency & artifacts
- simple audio generation web ui utility
- (new) phoneme endpoints for conversion and generation
Pre-built images are available to run, with ARM/multi-arch support and baked-in models
Refer to the core/config.py file for a full list of variables which can be managed via the environment
## Quick Start
```bash
# the `latest` tag can be used, though it may have some unexpected bonus features which impact stability.
# Named versions should be pinned for your regular usage.
# Feedback/testing is always welcome
# The service can be accessed through either the API endpoints or the Gradio web interface.
docker run -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-cpu:latest # CPU, or:
docker run --gpus all -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-gpu:latest #NVIDIA GPU
```
</details>
<details>
<summary>Quick Start (docker compose) </summary>
1. Install prerequisites, and start the service using Docker Compose (Full setup including UI):
- Install [Docker](https://www.docker.com/products/docker-desktop/)
- Clone the repository:
1. Install prerequisites:
- Install [Docker Desktop](https://www.docker.com/products/docker-desktop/) + [Git](https://git-scm.com/downloads)
- Clone and start the service:
```bash
git clone https://github.com/remsky/Kokoro-FastAPI.git
cd Kokoro-FastAPI
cd docker/gpu # For GPU support
# or cd docker/cpu # For CPU support
docker compose up --build
# *Note for Apple Silicon (M1/M2) users:
# The current GPU build relies on CUDA, which is not supported on Apple Silicon.
# If you are on an M1/M2/M3 Mac, please use the `docker/cpu` setup.
# MPS (Apple's GPU acceleration) support is planned but not yet available.
# Models will auto-download, but if needed you can manually download:
python docker/scripts/download_model.py --output api/src/models/v1_0
# Or run directly via UV:
./start-gpu.sh # For GPU support
./start-cpu.sh # For CPU support
```
</details>
<details>
<summary>Direct Run (via uv) </summary>
1. Install prerequisites:
- Install [astral-uv](https://docs.astral.sh/uv/)
- Install [espeak-ng](https://github.com/espeak-ng/espeak-ng) in your system if you want it available as a fallback for unknown words/sounds. The upstream libraries may attempt to handle this, but results have varied.
- Clone the repository:
```bash
git clone https://github.com/remsky/Kokoro-FastAPI.git
cd Kokoro-FastAPI
```
Run the [model download script](https://github.com/remsky/Kokoro-FastAPI/blob/master/docker/scripts/download_model.py) if you haven't already
Start directly via UV (with hot-reload)
Linux and macOS
```bash
./start-cpu.sh OR
./start-gpu.sh
```
Windows
```powershell
.\start-cpu.ps1 OR
.\start-gpu.ps1
```
</details>
<details open>
<summary> Up and Running? </summary>
Run locally as an OpenAI-Compatible Speech Endpoint
2. Run locally as an OpenAI-Compatible Speech Endpoint
```python
from openai import OpenAI
client = OpenAI(
base_url="http://localhost:8880/v1", api_key="not-needed"
base_url="http://localhost:8880/v1",
api_key="not-needed"
)
with client.audio.speech.with_streaming_response.create(
response = client.audio.speech.create(
model="kokoro",
voice="af_sky+af_bella", #single or multiple voicepack combo
input="Hello world!"
) as response:
input="Hello world!",
response_format="mp3"
)
response.stream_to_file("output.mp3")
```
- The API will be available at http://localhost:8880
- API Documentation: http://localhost:8880/docs
- Web Interface: http://localhost:8880/web
<div align="center" style="display: flex; justify-content: center; gap: 10px;">
<img src="assets/docs-screenshot.png" width="42%" alt="API Documentation" style="border: 2px solid #333; padding: 10px;">
<img src="assets/webui-screenshot.png" width="42%" alt="Web UI Screenshot" style="border: 2px solid #333; padding: 10px;">
</div>
</details>
or visit http://localhost:7860
<p align="center">
<img src="ui\GradioScreenShot.png" width="80%" alt="Voice Analysis Comparison" style="border: 2px solid #333; padding: 10px;">
</p>
## Features
<details>
<summary>OpenAI-Compatible Speech Endpoint</summary>
@ -146,8 +62,8 @@ with client.audio.speech.with_streaming_response.create(
from openai import OpenAI
client = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed")
response = client.audio.speech.create(
model="kokoro",
voice="af_bella+af_sky", # see /api/src/core/openai_mappings.json to customize
model="kokoro", # Not used but required for compatibility, also accepts library defaults
voice="af_bella+af_sky",
input="Hello world!",
response_format="mp3"
)
@ -166,7 +82,7 @@ voices = response.json()["voices"]
response = requests.post(
"http://localhost:8880/v1/audio/speech",
json={
"model": "kokoro",
"model": "kokoro", # Not used but required for compatibility
"input": "Hello world!",
"voice": "af_bella",
"response_format": "mp3", # Supported: mp3, wav, opus, flac
@ -189,10 +105,9 @@ python examples/assorted_checks/test_voices/test_all_voices.py # Test all availa
<details>
<summary>Voice Combination</summary>
- Weighted voice combinations using ratios (e.g., "af_bella(2)+af_heart(1)" for 67%/33% mix)
- Ratios are automatically normalized to sum to 100%
- Available through any endpoint by adding weights in parentheses
- Averages model weights of any existing voicepacks
- Saves generated voicepacks for future use
- (new) Available through any endpoint, simply concatenate desired packs with "+"
Combine voices and generate audio:
```python
@ -200,46 +115,22 @@ import requests
response = requests.get("http://localhost:8880/v1/audio/voices")
voices = response.json()["voices"]
# Example 1: Simple voice combination (50%/50% mix)
response = requests.post(
"http://localhost:8880/v1/audio/speech",
json={
"input": "Hello world!",
"voice": "af_bella+af_sky", # Equal weights
"response_format": "mp3"
}
)
# Example 2: Weighted voice combination (67%/33% mix)
response = requests.post(
"http://localhost:8880/v1/audio/speech",
json={
"input": "Hello world!",
"voice": "af_bella(2)+af_sky(1)", # 2:1 ratio = 67%/33%
"response_format": "mp3"
}
)
# Example 3: Download combined voice as .pt file
# Create combined voice (saves locally on server)
response = requests.post(
"http://localhost:8880/v1/audio/voices/combine",
json="af_bella(2)+af_sky(1)" # 2:1 ratio = 67%/33%
json=[voices[0], voices[1]]
)
combined_voice = response.json()["voice"]
# Save the .pt file
with open("combined_voice.pt", "wb") as f:
f.write(response.content)
# Use the downloaded voice file
# Generate audio with combined voice (or, simply pass multiple directly with `+` )
response = requests.post(
"http://localhost:8880/v1/audio/speech",
json={
"input": "Hello world!",
"voice": "combined_voice", # Use the saved voice file
"voice": combined_voice, # or skip the above step with f"{voices[0]}+{voices[1]}"
"response_format": "mp3"
}
)
```
<p align="center">
<img src="assets/voice_analysis.png" width="80%" alt="Voice Analysis Comparison" style="border: 2px solid #333; padding: 10px;">
@ -253,7 +144,7 @@ response = requests.post(
- wav
- opus
- flac
- m4a
- aac
- pcm
<p align="center">
@ -262,6 +153,21 @@ response = requests.post(
</details>
<details>
<summary>Gradio Web Utility</summary>
Access the interactive web UI at http://localhost:7860 after starting the service. Features include:
- Voice/format/speed selection
- Audio playback and download
- Text file or direct input
If you only want the API, just comment out everything in the docker-compose.yml under and including `gradio-ui`
Currently, voices created via the API are accessible here, but voice combination/creation has not yet been added
*Note: Recent updates for streaming could lead to temporary glitches. If so, pull from the most recent stable release v0.0.2 to restore*
</details>
<details>
<summary>Streaming Support</summary>
@ -269,7 +175,7 @@ response = requests.post(
# OpenAI-compatible streaming
from openai import OpenAI
client = OpenAI(
base_url="http://localhost:8880/v1", api_key="not-needed")
base_url="http://localhost:8880", api_key="not-needed")
# Stream to file
with client.audio.speech.with_streaming_response.create(
@ -351,23 +257,20 @@ Benchmarking was performed on generation via the local API using text lengths up
</p>
Key Performance Metrics:
- Realtime Speed: Ranges between 35x-100x (generation time to output audio length)
- Realtime Speed: Ranges between 25-50x (generation time to output audio length)
- Average Processing Rate: 137.67 tokens/second (cl100k_base)
</details>
<details>
<summary>GPU Vs. CPU</summary>
```bash
# GPU: Requires NVIDIA GPU with CUDA 12.8 support (~35x-100x realtime speed)
cd docker/gpu
docker compose up --build
# CPU: PyTorch CPU inference
cd docker/cpu
# GPU: Requires NVIDIA GPU with CUDA 12.1 support (~35x realtime speed)
docker compose up --build
# CPU: ONNX optimized inference (~2.4x realtime speed)
docker compose -f docker-compose.cpu.yml up --build
```
*Note: Overall speed may have reduced somewhat with the structural changes to accommodate streaming. Looking into it*
*Note: Overall speed may have reduced somewhat with the structural changes to accomodate streaming. Looking into it*
</details>
<details>
@ -375,80 +278,6 @@ docker compose up --build
- Automatically splits and stitches at sentence boundaries
- Helps to reduce artifacts and allow long-form processing, as the base model is currently only configured for approximately 30s of output
The model is capable of processing up to a 510 phonemized-token chunk at a time; however, this can often lead to 'rushed' speech or other artifacts. An additional layer of chunking is applied in the server that creates flexible chunks governed by `TARGET_MIN_TOKENS`, `TARGET_MAX_TOKENS`, and `ABSOLUTE_MAX_TOKENS`, which are configurable via environment variables and set to 175, 250, and 450 by default.
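As a rough illustration only, these targets map onto the `target_min_tokens`, `target_max_tokens`, and `absolute_max_tokens` fields of the pydantic settings, so they can be overridden through the environment before the server reads its config. The `api.src.core.config` import path below is an assumption, not a documented interface:

```python
# Sketch only: the settings module path is assumed, not documented here.
import os

# Smaller chunks trade throughput for lower first-chunk latency; the absolute
# ceiling should stay under the model's 510-token limit.
os.environ["TARGET_MIN_TOKENS"] = "100"
os.environ["TARGET_MAX_TOKENS"] = "200"
os.environ["ABSOLUTE_MAX_TOKENS"] = "400"

from api.src.core.config import settings  # pydantic BaseSettings picks up the overrides

print(settings.target_min_tokens, settings.target_max_tokens, settings.absolute_max_tokens)
```

In a Docker setup the same variables would normally be passed with `-e` flags or through the compose file's `environment:` section.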
</details>
<details>
<summary>Timestamped Captions & Phonemes</summary>
Generate audio with word-level timestamps without streaming:
```python
import requests
import base64
import json
response = requests.post(
"http://localhost:8880/dev/captioned_speech",
json={
"model": "kokoro",
"input": "Hello world!",
"voice": "af_bella",
"speed": 1.0,
"response_format": "mp3",
"stream": False,
},
stream=False
)
with open("output.mp3","wb") as f:
audio_json=json.loads(response.content)
# Decode base 64 stream to bytes
chunk_audio=base64.b64decode(audio_json["audio"].encode("utf-8"))
# Process streaming chunks
f.write(chunk_audio)
# Print word level timestamps
print(audio_json["timestamps"])
```
Generate audio with word-level timestamps with streaming:
```python
import requests
import base64
import json
response = requests.post(
"http://localhost:8880/dev/captioned_speech",
json={
"model": "kokoro",
"input": "Hello world!",
"voice": "af_bella",
"speed": 1.0,
"response_format": "mp3",
"stream": True,
},
stream=True
)
f=open("output.mp3","wb")
for chunk in response.iter_lines(decode_unicode=True):
if chunk:
chunk_json=json.loads(chunk)
# Decode base 64 stream to bytes
chunk_audio=base64.b64decode(chunk_json["audio"].encode("utf-8"))
# Process streaming chunks
f.write(chunk_audio)
# Print word level timestamps
print(chunk_json["timestamps"])
```
</details>
<details>
@ -458,161 +287,36 @@ Convert text to phonemes and/or generate audio directly from phonemes:
```python
import requests
def get_phonemes(text: str, language: str = "a"):
"""Get phonemes and tokens for input text"""
# Convert text to phonemes
response = requests.post(
"http://localhost:8880/dev/phonemize",
json={"text": text, "language": language} # "a" for American English
json={
"text": "Hello world!",
"language": "a" # "a" for American English
}
)
response.raise_for_status()
result = response.json()
return result["phonemes"], result["tokens"]
phonemes = result["phonemes"] # Phoneme string e.g ðɪs ɪz ˈoʊnli ɐ tˈɛst
tokens = result["tokens"] # Token IDs including start/end tokens
def generate_audio_from_phonemes(phonemes: str, voice: str = "af_bella"):
"""Generate audio from phonemes"""
# Generate audio from phonemes
response = requests.post(
"http://localhost:8880/dev/generate_from_phonemes",
json={"phonemes": phonemes, "voice": voice},
headers={"Accept": "audio/wav"}
json={
"phonemes": phonemes,
"voice": "af_bella",
"speed": 1.0
}
)
if response.status_code != 200:
print(f"Error: {response.text}")
return None
return response.content
# Example usage
text = "Hello world!"
try:
# Convert text to phonemes
phonemes, tokens = get_phonemes(text)
print(f"Phonemes: {phonemes}") # e.g. ðɪs ɪz ˈoʊnli ɐ tˈɛst
print(f"Tokens: {tokens}") # Token IDs including start/end tokens
# Generate and save audio
if audio_bytes := generate_audio_from_phonemes(phonemes):
# Save WAV audio
with open("speech.wav", "wb") as f:
f.write(audio_bytes)
print(f"Generated {len(audio_bytes)} bytes of audio")
except Exception as e:
print(f"Error: {e}")
f.write(response.content)
```
See `examples/phoneme_examples/generate_phonemes.py` for a sample script.
</details>
<details>
<summary>Debug Endpoints</summary>
Monitor system state and resource usage with these endpoints:
- `/debug/threads` - Get thread information and stack traces
- `/debug/storage` - Monitor temp file and output directory usage
- `/debug/system` - Get system information (CPU, memory, GPU)
- `/debug/session_pools` - View ONNX session and CUDA stream status
Useful for debugging resource exhaustion or performance issues.
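A minimal sketch for polling these endpoints, assuming the default port and JSON responses:

```python
# Illustrative only: iterate over the debug endpoints listed above.
import requests

BASE_URL = "http://localhost:8880"

for path in ("/debug/threads", "/debug/storage", "/debug/system", "/debug/session_pools"):
    resp = requests.get(f"{BASE_URL}{path}")
    resp.raise_for_status()
    print(path, resp.json())  # payload structure varies per endpoint
```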
</details>
## Known Issues & Troubleshooting
<details>
<summary>Missing words & Missing some timestamps</summary>
The API will automatically apply text normalization to the input text, which may incorrectly remove or change some phrases. This can be disabled by adding `"normalization_options":{"normalize": false}` to your request JSON:
```python
import requests
response = requests.post(
"http://localhost:8880/v1/audio/speech",
json={
"input": "Hello world!",
"voice": "af_heart",
"response_format": "pcm",
"normalization_options":
{
"normalize": False
}
},
stream=True
)
for chunk in response.iter_content(chunk_size=1024):
if chunk:
# Process streaming chunks
pass
```
</details>
<details>
<summary>Versioning & Development</summary>
**Branching Strategy:**
* **`release` branch:** Contains the latest stable build, recommended for production use. Docker images tagged with specific versions (e.g., `v0.3.0`) are built from this branch.
* **`master` branch:** Used for active development. It may contain experimental features, ongoing changes, or fixes not yet in a stable release. Use this branch if you want the absolute latest code, but be aware it might be less stable. The `latest` Docker tag often points to builds from this branch.
Note: This is a *development*-focused project at its core.
If you run into trouble, you may have to roll back to an earlier release tag, or build from source and/or troubleshoot and submit a PR.
Free and open source is a community effort, and there are only so many hours in a day. If you'd like to support the work, feel free to open a PR, buy me a coffee, or report any bugs/features/etc. you find during use.
<a href="https://www.buymeacoffee.com/remsky" target="_blank">
<img
src="https://cdn.buymeacoffee.com/buttons/v2/default-violet.png"
alt="Buy Me A Coffee"
style="height: 30px !important;width: 110px !important;"
>
</a>
</details>
<details>
<summary>Linux GPU Permissions</summary>
Some Linux users may encounter GPU permission issues when running as non-root.
Can't guarantee anything, but here are some common solutions; consider your security requirements carefully.
### Option 1: Container Groups (Likely the best option)
```yaml
services:
kokoro-tts:
# ... existing config ...
group_add:
- "video"
- "render"
```
### Option 2: Host System Groups
```yaml
services:
kokoro-tts:
# ... existing config ...
user: "${UID}:${GID}"
group_add:
- "video"
```
Note: May require adding host user to groups: `sudo usermod -aG docker,video $USER` and system restart.
### Option 3: Device Permissions (Use with caution)
```yaml
services:
kokoro-tts:
# ... existing config ...
devices:
- /dev/nvidia0:/dev/nvidia0
- /dev/nvidiactl:/dev/nvidiactl
- /dev/nvidia-uvm:/dev/nvidia-uvm
```
⚠️ Warning: Reduces system security. Use only in development environments.
Prerequisites: NVIDIA GPU, drivers, and container toolkit must be properly configured.
Visit [NVIDIA Container Toolkit installation](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) for more detailed information
</details>
## Model and License
<details open>


@ -1 +0,0 @@
0.3.0


@ -1,172 +0,0 @@
{
"istftnet": {
"upsample_kernel_sizes": [
20,
12
],
"upsample_rates": [
10,
6
],
"gen_istft_hop_size": 5,
"gen_istft_n_fft": 20,
"resblock_dilation_sizes": [
[
1,
3,
5
],
[
1,
3,
5
],
[
1,
3,
5
]
],
"resblock_kernel_sizes": [
3,
7,
11
],
"upsample_initial_channel": 512
},
"dim_in": 64,
"dropout": 0.2,
"hidden_dim": 512,
"max_conv_dim": 512,
"max_dur": 50,
"multispeaker": true,
"n_layer": 3,
"n_mels": 80,
"n_token": 178,
"style_dim": 128,
"text_encoder_kernel_size": 5,
"plbert": {
"hidden_size": 768,
"num_attention_heads": 12,
"intermediate_size": 2048,
"max_position_embeddings": 512,
"num_hidden_layers": 12,
"dropout": 0.1
},
"vocab": {
";": 1,
":": 2,
",": 3,
".": 4,
"!": 5,
"?": 6,
"—": 9,
"…": 10,
"\"": 11,
"(": 12,
")": 13,
"“": 14,
"”": 15,
" ": 16,
"̃": 17,
"ʣ": 18,
"ʥ": 19,
"ʦ": 20,
"ʨ": 21,
"ᵝ": 22,
"ꭧ": 23,
"A": 24,
"I": 25,
"O": 31,
"Q": 33,
"S": 35,
"T": 36,
"W": 39,
"Y": 41,
"ᵊ": 42,
"a": 43,
"b": 44,
"c": 45,
"d": 46,
"e": 47,
"f": 48,
"h": 50,
"i": 51,
"j": 52,
"k": 53,
"l": 54,
"m": 55,
"n": 56,
"o": 57,
"p": 58,
"q": 59,
"r": 60,
"s": 61,
"t": 62,
"u": 63,
"v": 64,
"w": 65,
"x": 66,
"y": 67,
"z": 68,
"ɑ": 69,
"ɐ": 70,
"ɒ": 71,
"æ": 72,
"β": 75,
"ɔ": 76,
"ɕ": 77,
"ç": 78,
"ɖ": 80,
"ð": 81,
"ʤ": 82,
"ə": 83,
"ɚ": 85,
"ɛ": 86,
"ɜ": 87,
"ɟ": 90,
"ɡ": 92,
"ɥ": 99,
"ɨ": 101,
"ɪ": 102,
"ʝ": 103,
"ɯ": 110,
"ɰ": 111,
"ŋ": 112,
"ɳ": 113,
"ɲ": 114,
"ɴ": 115,
"ø": 116,
"ɸ": 118,
"θ": 119,
"œ": 120,
"ɹ": 123,
"ɾ": 125,
"ɻ": 126,
"ʁ": 128,
"ɽ": 129,
"ʂ": 130,
"ʃ": 131,
"ʈ": 132,
"ʧ": 133,
"ʊ": 135,
"ʋ": 136,
"ʌ": 138,
"ɣ": 139,
"ɤ": 140,
"χ": 142,
"ʎ": 143,
"ʒ": 147,
"ʔ": 148,
"ˈ": 156,
"ˌ": 157,
"ː": 158,
"ʰ": 162,
"ʲ": 164,
"↓": 169,
"→": 171,
"↗": 172,
"↘": 173,
"ᵻ": 177
}
}


@ -1,4 +1,3 @@
import torch
from pydantic_settings import BaseSettings
@ -10,76 +9,28 @@ class Settings(BaseSettings):
host: str = "0.0.0.0"
port: int = 8880
# Application Settings
# TTS Settings
output_dir: str = "output"
output_dir_size_limit_mb: float = 500.0 # Maximum size of output directory in MB
default_voice: str = "af_heart"
default_voice_code: str | None = (
None # If set, overrides the first letter of voice name, though api call param still takes precedence
)
use_gpu: bool = True # Whether to use GPU acceleration if available
device_type: str | None = (
None # Will be auto-detected if None, can be "cuda", "mps", or "cpu"
)
allow_local_voice_saving: bool = (
False # Whether to allow saving combined voices locally
)
# Container absolute paths
model_dir: str = "/app/api/src/models" # Absolute path in container
voices_dir: str = "/app/api/src/voices/v1_0" # Absolute path in container
# Audio Settings
default_voice: str = "af"
model_dir: str = "/app/Kokoro-82M" # Base directory for model files
pytorch_model_path: str = "kokoro-v0_19.pth"
onnx_model_path: str = "kokoro-v0_19.onnx"
voices_dir: str = "voices"
sample_rate: int = 24000
# Text Processing Settings
target_min_tokens: int = 175 # Target minimum tokens per chunk
target_max_tokens: int = 250 # Target maximum tokens per chunk
absolute_max_tokens: int = 450 # Absolute maximum tokens per chunk
advanced_text_normalization: bool = True  # Preprocesses the text before misaki
voice_weight_normalization: bool = (
True # Normalize the voice weights so they add up to 1
)
max_chunk_size: int = 300 # Maximum size of text chunks for processing
gap_trim_ms: int = 250 # Amount to trim from streaming chunk ends in milliseconds
gap_trim_ms: int = (
1 # Base amount to trim from streaming chunk ends in milliseconds
)
dynamic_gap_trim_padding_ms: int = 410 # Padding to add to dynamic gap trim
dynamic_gap_trim_padding_char_multiplier: dict[str, float] = {
".": 1,
"!": 0.9,
"?": 1,
",": 0.8,
}
# Web Player Settings
enable_web_player: bool = True # Whether to serve the web player UI
web_player_path: str = "web" # Path to web player static files
cors_origins: list[str] = ["*"] # CORS origins for web player
cors_enabled: bool = True # Whether to enable CORS
# Temp File Settings for WEB Ui
temp_file_dir: str = "api/temp_files" # Directory for temporary audio files (relative to project root)
max_temp_dir_size_mb: int = 2048 # Maximum size of temp directory (2GB)
max_temp_dir_age_hours: int = 1 # Remove temp files older than 1 hour
max_temp_dir_count: int = 3 # Maximum number of temp files to keep
# ONNX Optimization Settings
onnx_num_threads: int = 4 # Number of threads for intra-op parallelism
onnx_inter_op_threads: int = 4 # Number of threads for inter-op parallelism
onnx_execution_mode: str = "parallel" # parallel or sequential
onnx_optimization_level: str = "all" # all, basic, or disabled
onnx_memory_pattern: bool = True # Enable memory pattern optimization
onnx_arena_extend_strategy: str = "kNextPowerOfTwo" # Memory allocation strategy
class Config:
env_file = ".env"
def get_device(self) -> str:
"""Get the appropriate device based on settings and availability"""
if not self.use_gpu:
return "cpu"
if self.device_type:
return self.device_type
# Auto-detect device
if torch.backends.mps.is_available():
return "mps"
elif torch.cuda.is_available():
return "cuda"
return "cpu"
settings = Settings()

185
api/src/core/kokoro.py Normal file
View file

@ -0,0 +1,185 @@
import re
import torch
import phonemizer
def split_num(num):
num = num.group()
if "." in num:
return num
elif ":" in num:
h, m = [int(n) for n in num.split(":")]
if m == 0:
return f"{h} o'clock"
elif m < 10:
return f"{h} oh {m}"
return f"{h} {m}"
year = int(num[:4])
if year < 1100 or year % 1000 < 10:
return num
left, right = num[:2], int(num[2:4])
s = "s" if num.endswith("s") else ""
if 100 <= year % 1000 <= 999:
if right == 0:
return f"{left} hundred{s}"
elif right < 10:
return f"{left} oh {right}{s}"
return f"{left} {right}{s}"
def flip_money(m):
m = m.group()
bill = "dollar" if m[0] == "$" else "pound"
if m[-1].isalpha():
return f"{m[1:]} {bill}s"
elif "." not in m:
s = "" if m[1:] == "1" else "s"
return f"{m[1:]} {bill}{s}"
b, c = m[1:].split(".")
s = "" if b == "1" else "s"
c = int(c.ljust(2, "0"))
coins = (
f"cent{'' if c == 1 else 's'}"
if m[0] == "$"
else ("penny" if c == 1 else "pence")
)
return f"{b} {bill}{s} and {c} {coins}"
def point_num(num):
a, b = num.group().split(".")
return " point ".join([a, " ".join(b)])
def normalize_text(text):
text = text.replace(chr(8216), "'").replace(chr(8217), "'")
text = text.replace("«", chr(8220)).replace("»", chr(8221))
text = text.replace(chr(8220), '"').replace(chr(8221), '"')
text = text.replace("(", "«").replace(")", "»")
for a, b in zip("、。!,:;?", ",.!,:;?"):
text = text.replace(a, b + " ")
text = re.sub(r"[^\S \n]", " ", text)
text = re.sub(r" +", " ", text)
text = re.sub(r"(?<=\n) +(?=\n)", "", text)
text = re.sub(r"\bD[Rr]\.(?= [A-Z])", "Doctor", text)
text = re.sub(r"\b(?:Mr\.|MR\.(?= [A-Z]))", "Mister", text)
text = re.sub(r"\b(?:Ms\.|MS\.(?= [A-Z]))", "Miss", text)
text = re.sub(r"\b(?:Mrs\.|MRS\.(?= [A-Z]))", "Mrs", text)
text = re.sub(r"\betc\.(?! [A-Z])", "etc", text)
text = re.sub(r"(?i)\b(y)eah?\b", r"\1e'a", text)
text = re.sub(
r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)", split_num, text
)
text = re.sub(r"(?<=\d),(?=\d)", "", text)
text = re.sub(
r"(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b",
flip_money,
text,
)
text = re.sub(r"\d*\.\d+", point_num, text)
text = re.sub(r"(?<=\d)-(?=\d)", " to ", text)
text = re.sub(r"(?<=\d)S", " S", text)
text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
text = re.sub(r"(?<=X')S\b", "s", text)
text = re.sub(
r"(?:[A-Za-z]\.){2,} [a-z]", lambda m: m.group().replace(".", "-"), text
)
text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)
return text.strip()
def get_vocab():
_pad = "$"
_punctuation = ';:,.!?¡¿—…"«»“” '
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'"
symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
dicts = {}
for i in range(len((symbols))):
dicts[symbols[i]] = i
return dicts
VOCAB = get_vocab()
def tokenize(ps):
return [i for i in map(VOCAB.get, ps) if i is not None]
phonemizers = dict(
a=phonemizer.backend.EspeakBackend(
language="en-us", preserve_punctuation=True, with_stress=True
),
b=phonemizer.backend.EspeakBackend(
language="en-gb", preserve_punctuation=True, with_stress=True
),
)
def phonemize(text, lang, norm=True):
if norm:
text = normalize_text(text)
ps = phonemizers[lang].phonemize([text])
ps = ps[0] if ps else ""
# https://en.wiktionary.org/wiki/kokoro#English
ps = ps.replace("kəkˈoːɹoʊ", "kˈoʊkəɹoʊ").replace("kəkˈɔːɹəʊ", "kˈəʊkəɹəʊ")
ps = ps.replace("ʲ", "j").replace("r", "ɹ").replace("x", "k").replace("ɬ", "l")
ps = re.sub(r"(?<=[a-zɹː])(?=hˈʌndɹɪd)", " ", ps)
ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»“” ]|$)', "z", ps)
if lang == "a":
ps = re.sub(r"(?<=nˈaɪn)ti(?!ː)", "di", ps)
ps = "".join(filter(lambda p: p in VOCAB, ps))
return ps.strip()
def length_to_mask(lengths):
mask = (
torch.arange(lengths.max())
.unsqueeze(0)
.expand(lengths.shape[0], -1)
.type_as(lengths)
)
mask = torch.gt(mask + 1, lengths.unsqueeze(1))
return mask
@torch.no_grad()
def forward(model, tokens, ref_s, speed):
device = ref_s.device
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
text_mask = length_to_mask(input_lengths).to(device)
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
s = ref_s[:, 128:]
d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
x, _ = model.predictor.lstm(d)
duration = model.predictor.duration_proj(x)
duration = torch.sigmoid(duration).sum(axis=-1) / speed
pred_dur = torch.round(duration).clamp(min=1).long()
pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
c_frame = 0
for i in range(pred_aln_trg.size(0)):
pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
c_frame += pred_dur[0, i].item()
en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
t_en = model.text_encoder(tokens, input_lengths, text_mask)
asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
def generate(model, text, voicepack, lang="a", speed=1):
ps = phonemize(text, lang)
tokens = tokenize(ps)
if not tokens:
return None
elif len(tokens) > 510:
tokens = tokens[:510]
print("Truncated to 510 tokens")
ref_s = voicepack[len(tokens)]
out = forward(model, tokens, ref_s, speed)
ps = "".join(next(k for k, v in VOCAB.items() if i == v) for i in tokens)
return out, ps

View file

@ -1,50 +0,0 @@
"""Model configuration for Kokoro V1.
This module provides model-specific configuration settings that complement the application-level
settings in config.py. While config.py handles general application settings (API, paths, etc.),
this module focuses on memory management and model file paths.
"""
from pydantic import BaseModel, Field
class KokoroV1Config(BaseModel):
"""Kokoro V1 configuration."""
languages: list[str] = ["en"]
class Config:
frozen = True
class PyTorchConfig(BaseModel):
"""PyTorch backend configuration."""
memory_threshold: float = Field(0.8, description="Memory threshold for cleanup")
retry_on_oom: bool = Field(True, description="Whether to retry on OOM errors")
class Config:
frozen = True
class ModelConfig(BaseModel):
"""Kokoro V1 model configuration."""
# General settings
cache_voices: bool = Field(True, description="Whether to cache voice tensors")
voice_cache_size: int = Field(2, description="Maximum number of cached voices")
# Model filename
pytorch_kokoro_v1_file: str = Field(
"v1_0/kokoro-v1_0.pth", description="PyTorch Kokoro V1 model filename"
)
# Backend config
pytorch_gpu: PyTorchConfig = Field(default_factory=PyTorchConfig)
class Config:
frozen = True
# Global instance
model_config = ModelConfig()

View file

@ -1,18 +0,0 @@
{
"models": {
"tts-1": "kokoro-v1_0",
"tts-1-hd": "kokoro-v1_0",
"kokoro": "kokoro-v1_0"
},
"voices": {
"alloy": "am_v0adam",
"ash": "af_v0nicole",
"coral": "bf_v0emma",
"echo": "af_v0bella",
"fable": "af_sarah",
"onyx": "bm_george",
"nova": "bf_isabella",
"sage": "am_michael",
"shimmer": "af_sky"
}
}

View file

@ -1,413 +0,0 @@
"""Async file and path operations."""
import io
import json
import os
from pathlib import Path
from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Set
import aiofiles
import aiofiles.os
import torch
from loguru import logger
from .config import settings
async def _find_file(
filename: str,
search_paths: List[str],
filter_fn: Optional[Callable[[str], bool]] = None,
) -> str:
"""Find file in search paths.
Args:
filename: Name of file to find
search_paths: List of paths to search in
filter_fn: Optional function to filter files
Returns:
Absolute path to file
Raises:
RuntimeError: If file not found
"""
if os.path.isabs(filename) and await aiofiles.os.path.exists(filename):
return filename
for path in search_paths:
full_path = os.path.join(path, filename)
if await aiofiles.os.path.exists(full_path):
if filter_fn is None or filter_fn(full_path):
return full_path
raise FileNotFoundError(f"File not found: {filename} in paths: {search_paths}")
async def _scan_directories(
search_paths: List[str], filter_fn: Optional[Callable[[str], bool]] = None
) -> Set[str]:
"""Scan directories for files.
Args:
search_paths: List of paths to scan
filter_fn: Optional function to filter files
Returns:
Set of matching filenames
"""
results = set()
for path in search_paths:
if not await aiofiles.os.path.exists(path):
continue
try:
# Get directory entries first
entries = await aiofiles.os.scandir(path)
# Then process entries after await completes
for entry in entries:
if filter_fn is None or filter_fn(entry.name):
results.add(entry.name)
except Exception as e:
logger.warning(f"Error scanning {path}: {e}")
return results
async def get_model_path(model_name: str) -> str:
"""Get path to model file.
Args:
model_name: Name of model file
Returns:
Absolute path to model file
Raises:
RuntimeError: If model not found
"""
# Get api directory path (two levels up from core)
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
# Construct model directory path relative to api directory
model_dir = os.path.join(api_dir, settings.model_dir)
# Ensure model directory exists
os.makedirs(model_dir, exist_ok=True)
# Search in model directory
search_paths = [model_dir]
logger.debug(f"Searching for model in path: {model_dir}")
return await _find_file(model_name, search_paths)
async def get_voice_path(voice_name: str) -> str:
"""Get path to voice file.
Args:
voice_name: Name of voice file (without .pt extension)
Returns:
Absolute path to voice file
Raises:
RuntimeError: If voice not found
"""
# Get api directory path
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
# Construct voice directory path relative to api directory
voice_dir = os.path.join(api_dir, settings.voices_dir)
# Ensure voice directory exists
os.makedirs(voice_dir, exist_ok=True)
voice_file = f"{voice_name}.pt"
# Search in voice directory
search_paths = [voice_dir]
logger.debug(f"Searching for voice in path: {voice_dir}")
return await _find_file(voice_file, search_paths)
async def list_voices() -> List[str]:
"""List available voice files.
Returns:
List of voice names (without .pt extension)
"""
# Get api directory path
api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
# Construct voice directory path relative to api directory
voice_dir = os.path.join(api_dir, settings.voices_dir)
# Ensure voice directory exists
os.makedirs(voice_dir, exist_ok=True)
# Search in voice directory
search_paths = [voice_dir]
logger.debug(f"Scanning for voices in path: {voice_dir}")
def filter_voice_files(name: str) -> bool:
return name.endswith(".pt")
voices = await _scan_directories(search_paths, filter_voice_files)
return sorted([name[:-3] for name in voices]) # Remove .pt extension
async def load_voice_tensor(
voice_path: str, device: str = "cpu", weights_only=False
) -> torch.Tensor:
"""Load voice tensor from file.
Args:
voice_path: Path to voice file
device: Device to load tensor to
Returns:
Voice tensor
Raises:
RuntimeError: If file cannot be read
"""
try:
async with aiofiles.open(voice_path, "rb") as f:
data = await f.read()
return torch.load(
io.BytesIO(data), map_location=device, weights_only=weights_only
)
except Exception as e:
raise RuntimeError(f"Failed to load voice tensor from {voice_path}: {e}")
async def save_voice_tensor(tensor: torch.Tensor, voice_path: str) -> None:
"""Save voice tensor to file.
Args:
tensor: Voice tensor to save
voice_path: Path to save voice file
Raises:
RuntimeError: If file cannot be written
"""
try:
buffer = io.BytesIO()
torch.save(tensor, buffer)
async with aiofiles.open(voice_path, "wb") as f:
await f.write(buffer.getvalue())
except Exception as e:
raise RuntimeError(f"Failed to save voice tensor to {voice_path}: {e}")
async def load_json(path: str) -> dict:
"""Load JSON file asynchronously.
Args:
path: Path to JSON file
Returns:
Parsed JSON data
Raises:
RuntimeError: If file cannot be read or parsed
"""
try:
async with aiofiles.open(path, "r", encoding="utf-8") as f:
content = await f.read()
return json.loads(content)
except Exception as e:
raise RuntimeError(f"Failed to load JSON file {path}: {e}")
async def load_model_weights(path: str, device: str = "cpu") -> dict:
"""Load model weights asynchronously.
Args:
path: Path to model file (.pth or .onnx)
device: Device to load model to
Returns:
Model weights
Raises:
RuntimeError: If file cannot be read
"""
try:
async with aiofiles.open(path, "rb") as f:
data = await f.read()
return torch.load(io.BytesIO(data), map_location=device, weights_only=True)
except Exception as e:
raise RuntimeError(f"Failed to load model weights from {path}: {e}")
async def read_file(path: str) -> str:
"""Read text file asynchronously.
Args:
path: Path to file
Returns:
File contents as string
Raises:
RuntimeError: If file cannot be read
"""
try:
async with aiofiles.open(path, "r", encoding="utf-8") as f:
return await f.read()
except Exception as e:
raise RuntimeError(f"Failed to read file {path}: {e}")
async def read_bytes(path: str) -> bytes:
"""Read file as bytes asynchronously.
Args:
path: Path to file
Returns:
File contents as bytes
Raises:
RuntimeError: If file cannot be read
"""
try:
async with aiofiles.open(path, "rb") as f:
return await f.read()
except Exception as e:
raise RuntimeError(f"Failed to read file {path}: {e}")
async def get_web_file_path(filename: str) -> str:
"""Get path to web static file.
Args:
filename: Name of file in web directory
Returns:
Absolute path to file
Raises:
RuntimeError: If file not found
"""
# Get project root directory (four levels up from core to get to project root)
root_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
)
# Construct web directory path relative to project root
web_dir = os.path.join("/app", settings.web_player_path)
# Search in web directory
search_paths = [web_dir]
logger.debug(f"Searching for web file in path: {web_dir}")
return await _find_file(filename, search_paths)
async def get_content_type(path: str) -> str:
"""Get content type for file.
Args:
path: Path to file
Returns:
Content type string
"""
ext = os.path.splitext(path)[1].lower()
return {
".html": "text/html",
".js": "application/javascript",
".css": "text/css",
".png": "image/png",
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".gif": "image/gif",
".svg": "image/svg+xml",
".ico": "image/x-icon",
}.get(ext, "application/octet-stream")
async def verify_model_path(model_path: str) -> bool:
"""Verify model file exists at path."""
return await aiofiles.os.path.exists(model_path)
async def cleanup_temp_files() -> None:
"""Clean up old temp files on startup"""
try:
if not await aiofiles.os.path.exists(settings.temp_file_dir):
await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)
return
entries = await aiofiles.os.scandir(settings.temp_file_dir)
for entry in entries:
if entry.is_file():
stat = await aiofiles.os.stat(entry.path)
max_age = stat.st_mtime + (settings.max_temp_dir_age_hours * 3600)
if max_age < stat.st_mtime:
try:
await aiofiles.os.remove(entry.path)
logger.info(f"Cleaned up old temp file: {entry.name}")
except Exception as e:
logger.warning(
f"Failed to delete old temp file {entry.name}: {e}"
)
except Exception as e:
logger.warning(f"Error cleaning temp files: {e}")
async def get_temp_file_path(filename: str) -> str:
"""Get path to temporary audio file.
Args:
filename: Name of temp file
Returns:
Absolute path to temp file
Raises:
RuntimeError: If temp directory does not exist
"""
temp_path = os.path.join(settings.temp_file_dir, filename)
# Ensure temp directory exists
if not await aiofiles.os.path.exists(settings.temp_file_dir):
await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)
return temp_path
async def list_temp_files() -> List[str]:
"""List temporary audio files.
Returns:
List of temp file names
"""
if not await aiofiles.os.path.exists(settings.temp_file_dir):
return []
entries = await aiofiles.os.scandir(settings.temp_file_dir)
return [entry.name for entry in entries if entry.is_file()]
async def get_temp_dir_size() -> int:
"""Get total size of temp directory in bytes.
Returns:
Size in bytes
"""
if not await aiofiles.os.path.exists(settings.temp_file_dir):
return 0
total = 0
entries = await aiofiles.os.scandir(settings.temp_file_dir)
for entry in entries:
if entry.is_file():
stat = await aiofiles.os.stat(entry.path)
total += stat.st_size
return total

View file

@ -1,12 +0,0 @@
"""Model inference package."""
from .base import BaseModelBackend
from .kokoro_v1 import KokoroV1
from .model_manager import ModelManager, get_manager
__all__ = [
"BaseModelBackend",
"ModelManager",
"get_manager",
"KokoroV1",
]

View file

@ -1,127 +0,0 @@
"""Base interface for Kokoro inference."""
from abc import ABC, abstractmethod
from typing import AsyncGenerator, List, Optional, Tuple, Union
import numpy as np
import torch
class AudioChunk:
"""Class for audio chunks returned by model backends"""
def __init__(
self,
audio: np.ndarray,
word_timestamps: Optional[List] = [],
output: Optional[Union[bytes, np.ndarray]] = b"",
):
self.audio = audio
self.word_timestamps = word_timestamps
self.output = output
@staticmethod
def combine(audio_chunk_list: List):
output = AudioChunk(
audio_chunk_list[0].audio, audio_chunk_list[0].word_timestamps
)
for audio_chunk in audio_chunk_list[1:]:
output.audio = np.concatenate(
(output.audio, audio_chunk.audio), dtype=np.int16
)
if output.word_timestamps is not None:
output.word_timestamps += audio_chunk.word_timestamps
return output
class ModelBackend(ABC):
"""Abstract base class for model inference backend."""
@abstractmethod
async def load_model(self, path: str) -> None:
"""Load model from path.
Args:
path: Path to model file
Raises:
RuntimeError: If model loading fails
"""
pass
@abstractmethod
async def generate(
self,
text: str,
voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
speed: float = 1.0,
) -> AsyncGenerator[AudioChunk, None]:
"""Generate audio from text.
Args:
text: Input text to synthesize
voice: Either a voice path or tuple of (name, tensor/path)
speed: Speed multiplier
Yields:
Generated audio chunks
Raises:
RuntimeError: If generation fails
"""
pass
@abstractmethod
def unload(self) -> None:
"""Unload model and free resources."""
pass
@property
@abstractmethod
def is_loaded(self) -> bool:
"""Check if model is loaded.
Returns:
True if model is loaded, False otherwise
"""
pass
@property
@abstractmethod
def device(self) -> str:
"""Get device model is running on.
Returns:
Device string ('cpu' or 'cuda')
"""
pass
class BaseModelBackend(ModelBackend):
"""Base implementation of model backend."""
def __init__(self):
"""Initialize base backend."""
self._model: Optional[torch.nn.Module] = None
self._device: str = "cpu"
@property
def is_loaded(self) -> bool:
"""Check if model is loaded."""
return self._model is not None
@property
def device(self) -> str:
"""Get device model is running on."""
return self._device
def unload(self) -> None:
"""Unload model and free resources."""
if self._model is not None:
del self._model
self._model = None
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()

View file

@ -1,370 +0,0 @@
"""Clean Kokoro implementation with controlled resource management."""
import os
from typing import AsyncGenerator, Dict, Optional, Tuple, Union
import numpy as np
import torch
from kokoro import KModel, KPipeline
from loguru import logger
from ..core import paths
from ..core.config import settings
from ..core.model_config import model_config
from ..structures.schemas import WordTimestamp
from .base import AudioChunk, BaseModelBackend
class KokoroV1(BaseModelBackend):
"""Kokoro backend with controlled resource management."""
def __init__(self):
"""Initialize backend with environment-based configuration."""
super().__init__()
# Strictly respect settings.use_gpu
self._device = settings.get_device()
self._model: Optional[KModel] = None
self._pipelines: Dict[str, KPipeline] = {} # Store pipelines by lang_code
async def load_model(self, path: str) -> None:
"""Load pre-baked model.
Args:
path: Path to model file
Raises:
RuntimeError: If model loading fails
"""
try:
# Get verified model path
model_path = await paths.get_model_path(path)
config_path = os.path.join(os.path.dirname(model_path), "config.json")
if not os.path.exists(config_path):
raise RuntimeError(f"Config file not found: {config_path}")
logger.info(f"Loading Kokoro model on {self._device}")
logger.info(f"Config path: {config_path}")
logger.info(f"Model path: {model_path}")
# Load model and let KModel handle device mapping
self._model = KModel(config=config_path, model=model_path).eval()
# For MPS, manually move ISTFT layers to CPU while keeping rest on MPS
if self._device == "mps":
logger.info(
"Moving model to MPS device with CPU fallback for unsupported operations"
)
self._model = self._model.to(torch.device("mps"))
elif self._device == "cuda":
self._model = self._model.cuda()
else:
self._model = self._model.cpu()
except FileNotFoundError as e:
raise e
except Exception as e:
raise RuntimeError(f"Failed to load Kokoro model: {e}")
def _get_pipeline(self, lang_code: str) -> KPipeline:
"""Get or create pipeline for language code.
Args:
lang_code: Language code to use
Returns:
KPipeline instance for the language
"""
if not self._model:
raise RuntimeError("Model not loaded")
if lang_code not in self._pipelines:
logger.info(f"Creating new pipeline for language code: {lang_code}")
self._pipelines[lang_code] = KPipeline(
lang_code=lang_code, model=self._model, device=self._device
)
return self._pipelines[lang_code]
async def generate_from_tokens(
self,
tokens: str,
voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
speed: float = 1.0,
lang_code: Optional[str] = None,
) -> AsyncGenerator[np.ndarray, None]:
"""Generate audio from phoneme tokens.
Args:
tokens: Input phoneme tokens to synthesize
voice: Either a voice path string or a tuple of (voice_name, voice_tensor/path)
speed: Speed multiplier
lang_code: Optional language code override
Yields:
Generated audio chunks
Raises:
RuntimeError: If generation fails
"""
if not self.is_loaded:
raise RuntimeError("Model not loaded")
try:
# Memory management for GPU
if self._device == "cuda":
if self._check_memory():
self._clear_memory()
# Handle voice input
voice_path: str
voice_name: str
if isinstance(voice, tuple):
voice_name, voice_data = voice
if isinstance(voice_data, str):
voice_path = voice_data
else:
# Save tensor to temporary file
import tempfile
temp_dir = tempfile.gettempdir()
voice_path = os.path.join(temp_dir, f"{voice_name}.pt")
# Save tensor with CPU mapping for portability
torch.save(voice_data.cpu(), voice_path)
else:
voice_path = voice
voice_name = os.path.splitext(os.path.basename(voice_path))[0]
# Load voice tensor with proper device mapping
voice_tensor = await paths.load_voice_tensor(
voice_path, device=self._device
)
# Save back to a temporary file with proper device mapping
import tempfile
temp_dir = tempfile.gettempdir()
temp_path = os.path.join(
temp_dir, f"temp_voice_{os.path.basename(voice_path)}"
)
await paths.save_voice_tensor(voice_tensor, temp_path)
voice_path = temp_path
# Use provided lang_code, settings voice code override, or first letter of voice name
if lang_code: # api is given priority
pipeline_lang_code = lang_code
elif settings.default_voice_code: # settings is next priority
pipeline_lang_code = settings.default_voice_code
else: # voice name is default/fallback
pipeline_lang_code = voice_name[0].lower()
pipeline = self._get_pipeline(pipeline_lang_code)
logger.debug(
f"Generating audio from tokens with lang_code '{pipeline_lang_code}': '{tokens[:100]}{'...' if len(tokens) > 100 else ''}'"
)
for result in pipeline.generate_from_tokens(
tokens=tokens, voice=voice_path, speed=speed, model=self._model
):
if result.audio is not None:
logger.debug(f"Got audio chunk with shape: {result.audio.shape}")
yield result.audio.numpy()
else:
logger.warning("No audio in chunk")
except Exception as e:
logger.error(f"Generation failed: {e}")
if (
self._device == "cuda"
and model_config.pytorch_gpu.retry_on_oom
and "out of memory" in str(e).lower()
):
self._clear_memory()
async for chunk in self.generate_from_tokens(
tokens, voice, speed, lang_code
):
yield chunk
raise
async def generate(
self,
text: str,
voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
speed: float = 1.0,
lang_code: Optional[str] = None,
return_timestamps: Optional[bool] = False,
) -> AsyncGenerator[AudioChunk, None]:
"""Generate audio using model.
Args:
text: Input text to synthesize
voice: Either a voice path string or a tuple of (voice_name, voice_tensor/path)
speed: Speed multiplier
lang_code: Optional language code override
Yields:
Generated audio chunks
Raises:
RuntimeError: If generation fails
"""
if not self.is_loaded:
raise RuntimeError("Model not loaded")
try:
# Memory management for GPU
if self._device == "cuda":
if self._check_memory():
self._clear_memory()
# Handle voice input
voice_path: str
voice_name: str
if isinstance(voice, tuple):
voice_name, voice_data = voice
if isinstance(voice_data, str):
voice_path = voice_data
else:
# Save tensor to temporary file
import tempfile
temp_dir = tempfile.gettempdir()
voice_path = os.path.join(temp_dir, f"{voice_name}.pt")
# Save tensor with CPU mapping for portability
torch.save(voice_data.cpu(), voice_path)
else:
voice_path = voice
voice_name = os.path.splitext(os.path.basename(voice_path))[0]
# Load voice tensor with proper device mapping
voice_tensor = await paths.load_voice_tensor(
voice_path, device=self._device
)
# Save back to a temporary file with proper device mapping
import tempfile
temp_dir = tempfile.gettempdir()
temp_path = os.path.join(
temp_dir, f"temp_voice_{os.path.basename(voice_path)}"
)
await paths.save_voice_tensor(voice_tensor, temp_path)
voice_path = temp_path
# Use provided lang_code, settings voice code override, or first letter of voice name
pipeline_lang_code = (
lang_code
if lang_code
else (
settings.default_voice_code
if settings.default_voice_code
else voice_name[0].lower()
)
)
pipeline = self._get_pipeline(pipeline_lang_code)
logger.debug(
f"Generating audio for text with lang_code '{pipeline_lang_code}': '{text[:100]}{'...' if len(text) > 100 else ''}'"
)
for result in pipeline(
text, voice=voice_path, speed=speed, model=self._model
):
if result.audio is not None:
logger.debug(f"Got audio chunk with shape: {result.audio.shape}")
word_timestamps = None
if (
return_timestamps
and hasattr(result, "tokens")
and result.tokens
):
word_timestamps = []
current_offset = 0.0
logger.debug(
f"Processing chunk timestamps with {len(result.tokens)} tokens"
)
if result.pred_dur is not None:
try:
# Add timestamps with offset
for token in result.tokens:
if not all(
hasattr(token, attr)
for attr in [
"text",
"start_ts",
"end_ts",
]
):
continue
if not token.text or not token.text.strip():
continue
start_time = float(token.start_ts) + current_offset
end_time = float(token.end_ts) + current_offset
word_timestamps.append(
WordTimestamp(
word=str(token.text).strip(),
start_time=start_time,
end_time=end_time,
)
)
logger.debug(
f"Added timestamp for word '{token.text}': {start_time:.3f}s - {end_time:.3f}s"
)
except Exception as e:
logger.error(
f"Failed to process timestamps for chunk: {e}"
)
yield AudioChunk(
result.audio.numpy(), word_timestamps=word_timestamps
)
else:
logger.warning("No audio in chunk")
except Exception as e:
logger.error(f"Generation failed: {e}")
if (
self._device == "cuda"
and model_config.pytorch_gpu.retry_on_oom
and "out of memory" in str(e).lower()
):
self._clear_memory()
async for chunk in self.generate(text, voice, speed, lang_code):
yield chunk
raise
def _check_memory(self) -> bool:
"""Check if memory usage is above threshold."""
if self._device == "cuda":
memory_gb = torch.cuda.memory_allocated() / 1e9
return memory_gb > model_config.pytorch_gpu.memory_threshold
# MPS doesn't provide memory management APIs
return False
def _clear_memory(self) -> None:
"""Clear device memory."""
if self._device == "cuda":
torch.cuda.empty_cache()
torch.cuda.synchronize()
elif self._device == "mps":
# Empty cache if available (future-proofing)
if hasattr(torch.mps, "empty_cache"):
torch.mps.empty_cache()
def unload(self) -> None:
"""Unload model and free resources."""
if self._model is not None:
del self._model
self._model = None
for pipeline in self._pipelines.values():
del pipeline
self._pipelines.clear()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
@property
def is_loaded(self) -> bool:
"""Check if model is loaded."""
return self._model is not None
@property
def device(self) -> str:
"""Get device model is running on."""
return self._device

View file

@ -1,171 +0,0 @@
"""Kokoro V1 model management."""
from typing import Optional
from loguru import logger
from ..core import paths
from ..core.config import settings
from ..core.model_config import ModelConfig, model_config
from .base import BaseModelBackend
from .kokoro_v1 import KokoroV1
class ModelManager:
"""Manages Kokoro V1 model loading and inference."""
# Singleton instance
_instance = None
def __init__(self, config: Optional[ModelConfig] = None):
"""Initialize manager.
Args:
config: Optional model configuration override
"""
self._config = config or model_config
self._backend: Optional[KokoroV1] = None # Explicitly type as KokoroV1
self._device: Optional[str] = None
def _determine_device(self) -> str:
"""Determine device based on settings."""
return "cuda" if settings.use_gpu else "cpu"
async def initialize(self) -> None:
"""Initialize Kokoro V1 backend."""
try:
self._device = self._determine_device()
logger.info(f"Initializing Kokoro V1 on {self._device}")
self._backend = KokoroV1()
except Exception as e:
raise RuntimeError(f"Failed to initialize Kokoro V1: {e}")
async def initialize_with_warmup(self, voice_manager) -> tuple[str, str, int]:
"""Initialize and warm up model.
Args:
voice_manager: Voice manager instance for warmup
Returns:
Tuple of (device, backend type, voice count)
Raises:
RuntimeError: If initialization fails
"""
import time
start = time.perf_counter()
try:
# Initialize backend
await self.initialize()
# Load model
model_path = self._config.pytorch_kokoro_v1_file
await self.load_model(model_path)
# Use paths module to get voice path
try:
voices = await paths.list_voices()
voice_path = await paths.get_voice_path(settings.default_voice)
# Warm up with short text
warmup_text = "Warmup text for initialization."
# Use default voice name for warmup
voice_name = settings.default_voice
logger.debug(f"Using default voice '{voice_name}' for warmup")
async for _ in self.generate(warmup_text, (voice_name, voice_path)):
pass
except Exception as e:
raise RuntimeError(f"Failed to get default voice: {e}")
ms = int((time.perf_counter() - start) * 1000)
logger.info(f"Warmup completed in {ms}ms")
return self._device, "kokoro_v1", len(voices)
except FileNotFoundError as e:
logger.error("""
Model files not found! You need to download the Kokoro V1 model:
1. Download model using the script:
python docker/scripts/download_model.py --output api/src/models/v1_0
2. Or set environment variable in docker-compose:
DOWNLOAD_MODEL=true
""")
exit(0)
except Exception as e:
raise RuntimeError(f"Warmup failed: {e}")
def get_backend(self) -> BaseModelBackend:
"""Get initialized backend.
Returns:
Initialized backend instance
Raises:
RuntimeError: If backend not initialized
"""
if not self._backend:
raise RuntimeError("Backend not initialized")
return self._backend
async def load_model(self, path: str) -> None:
"""Load model using initialized backend.
Args:
path: Path to model file
Raises:
RuntimeError: If loading fails
"""
if not self._backend:
raise RuntimeError("Backend not initialized")
try:
await self._backend.load_model(path)
except FileNotFoundError as e:
raise e
except Exception as e:
raise RuntimeError(f"Failed to load model: {e}")
async def generate(self, *args, **kwargs):
"""Generate audio using initialized backend.
Raises:
RuntimeError: If generation fails
"""
if not self._backend:
raise RuntimeError("Backend not initialized")
try:
async for chunk in self._backend.generate(*args, **kwargs):
yield chunk
except Exception as e:
raise RuntimeError(f"Generation failed: {e}")
def unload_all(self) -> None:
"""Unload model and free resources."""
if self._backend:
self._backend.unload()
self._backend = None
@property
def current_backend(self) -> str:
"""Get current backend type."""
return "kokoro_v1"
async def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
"""Get model manager instance.
Args:
config: Optional configuration override
Returns:
ModelManager instance
"""
if ModelManager._instance is None:
ModelManager._instance = ModelManager(config)
return ModelManager._instance

View file

@ -1,115 +0,0 @@
"""Voice management with controlled resource handling."""
from typing import Dict, List, Optional
import aiofiles
import torch
from loguru import logger
from ..core import paths
from ..core.config import settings
class VoiceManager:
"""Manages voice loading and caching with controlled resource usage."""
# Singleton instance
_instance = None
def __init__(self):
"""Initialize voice manager."""
# Strictly respect settings.use_gpu
self._device = settings.get_device()
self._voices: Dict[str, torch.Tensor] = {}
async def get_voice_path(self, voice_name: str) -> str:
"""Get path to voice file.
Args:
voice_name: Name of voice
Returns:
Path to voice file
Raises:
RuntimeError: If voice not found
"""
return await paths.get_voice_path(voice_name)
async def load_voice(
self, voice_name: str, device: Optional[str] = None
) -> torch.Tensor:
"""Load voice tensor.
Args:
voice_name: Name of voice to load
device: Optional override for target device
Returns:
Voice tensor
Raises:
RuntimeError: If voice not found
"""
try:
voice_path = await self.get_voice_path(voice_name)
target_device = device or self._device
voice = await paths.load_voice_tensor(voice_path, target_device)
self._voices[voice_name] = voice
return voice
except Exception as e:
raise RuntimeError(f"Failed to load voice {voice_name}: {e}")
async def combine_voices(
self, voices: List[str], device: Optional[str] = None
) -> torch.Tensor:
"""Combine multiple voices.
Args:
voices: List of voice names to combine
device: Optional override for target device
Returns:
Combined voice tensor
Raises:
RuntimeError: If any voice not found
"""
if len(voices) < 2:
raise ValueError("Need at least 2 voices to combine")
target_device = device or self._device
voice_tensors = []
for name in voices:
voice = await self.load_voice(name, target_device)
voice_tensors.append(voice)
combined = torch.mean(torch.stack(voice_tensors), dim=0)
return combined
async def list_voices(self) -> List[str]:
"""List available voice names.
Returns:
List of voice names
"""
return await paths.list_voices()
def cache_info(self) -> Dict[str, int]:
"""Get cache statistics.
Returns:
Dict with cache statistics
"""
return {"loaded_voices": len(self._voices), "device": self._device}
async def get_manager() -> VoiceManager:
"""Get voice manager instance.
Returns:
VoiceManager instance
"""
if VoiceManager._instance is None:
VoiceManager._instance = VoiceManager()
return VoiceManager._instance

View file

@ -2,22 +2,19 @@
FastAPI OpenAI Compatible API
"""
import os
import sys
from contextlib import asynccontextmanager
from pathlib import Path
import torch
import uvicorn
from loguru import logger
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from .core.config import settings
from .routers.debug import router as debug_router
from .services.tts_model import TTSModel
from .routers.development import router as dev_router
from .services.tts_service import TTSService
from .routers.openai_compatible import router as openai_router
from .routers.web_player import router as web_router
def setup_logger():
@ -28,10 +25,9 @@ def setup_logger():
"sink": sys.stdout,
"format": "<fg #2E8B57>{time:hh:mm:ss A}</fg #2E8B57> | "
"{level: <8} | "
"<fg #4169E1>{module}:{line}</fg #4169E1> | "
"{message}",
"colorize": True,
"level": "DEBUG",
"level": "INFO",
},
],
}
@ -43,34 +39,15 @@ def setup_logger():
# Configure logger
setup_logger()
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Lifespan context manager for model initialization"""
from .inference.model_manager import get_manager
from .inference.voice_manager import get_manager as get_voice_manager
from .services.temp_manager import cleanup_temp_files
# Clean old temp files on startup
await cleanup_temp_files()
logger.info("Loading TTS model and voice packs...")
try:
# Initialize managers
model_manager = await get_manager()
voice_manager = await get_voice_manager()
# Initialize model with warmup and get status
device, model, voicepack_count = await model_manager.initialize_with_warmup(
voice_manager
)
except Exception as e:
logger.error(f"Failed to initialize model: {e}")
raise
boundary = "" * 2 * 12
# Initialize the main model with warm-up
voicepack_count = await TTSModel.setup()
# boundary = "█████╗"*9
boundary = "" * 24
startup_msg = f"""
{boundary}
@ -84,24 +61,9 @@ async def lifespan(app: FastAPI):
{boundary}
"""
startup_msg += f"\nModel warmed up on {device}: {model}"
if device == "mps":
startup_msg += "\nUsing Apple Metal Performance Shaders (MPS)"
elif device == "cuda":
startup_msg += f"\nCUDA: {torch.cuda.is_available()}"
else:
startup_msg += "\nRunning on CPU"
startup_msg += f"\n{voicepack_count} voice packs loaded"
# Add web player info if enabled
if settings.enable_web_player:
startup_msg += (
f"\n\nBeta Web Player: http://{settings.host}:{settings.port}/web/"
)
startup_msg += f"\nor http://localhost:{settings.port}/web/"
else:
startup_msg += "\n\nWeb Player: disabled"
# TODO: Improve CPU warmup, threads, memory, etc
startup_msg += f"\nModel warmed up on {TTSModel.get_device()}"
startup_msg += f"\n{voicepack_count} voice packs loaded\n"
startup_msg += f"\n{boundary}\n"
logger.info(startup_msg)
@ -117,11 +79,10 @@ app = FastAPI(
openapi_url="/openapi.json", # Explicitly enable OpenAPI schema
)
# Add CORS middleware if enabled
if settings.cors_enabled:
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
@ -129,10 +90,8 @@ if settings.cors_enabled:
# Include routers
app.include_router(openai_router, prefix="/v1")
app.include_router(dev_router) # Development endpoints
app.include_router(debug_router) # Debug endpoints
if settings.enable_web_player:
app.include_router(web_router, prefix="/web") # Web player static files
app.include_router(dev_router) # New development endpoints
# app.include_router(text_router) # Deprecated but still live for backwards compatibility
# Health check endpoint

View file

@ -1,150 +0,0 @@
{
"istftnet": {
"upsample_kernel_sizes": [20, 12],
"upsample_rates": [10, 6],
"gen_istft_hop_size": 5,
"gen_istft_n_fft": 20,
"resblock_dilation_sizes": [
[1, 3, 5],
[1, 3, 5],
[1, 3, 5]
],
"resblock_kernel_sizes": [3, 7, 11],
"upsample_initial_channel": 512
},
"dim_in": 64,
"dropout": 0.2,
"hidden_dim": 512,
"max_conv_dim": 512,
"max_dur": 50,
"multispeaker": true,
"n_layer": 3,
"n_mels": 80,
"n_token": 178,
"style_dim": 128,
"text_encoder_kernel_size": 5,
"plbert": {
"hidden_size": 768,
"num_attention_heads": 12,
"intermediate_size": 2048,
"max_position_embeddings": 512,
"num_hidden_layers": 12,
"dropout": 0.1
},
"vocab": {
";": 1,
":": 2,
",": 3,
".": 4,
"!": 5,
"?": 6,
"—": 9,
"…": 10,
"\"": 11,
"(": 12,
")": 13,
"“": 14,
"”": 15,
" ": 16,
"\u0303": 17,
"ʣ": 18,
"ʥ": 19,
"ʦ": 20,
"ʨ": 21,
"ᵝ": 22,
"\uAB67": 23,
"A": 24,
"I": 25,
"O": 31,
"Q": 33,
"S": 35,
"T": 36,
"W": 39,
"Y": 41,
"ᵊ": 42,
"a": 43,
"b": 44,
"c": 45,
"d": 46,
"e": 47,
"f": 48,
"h": 50,
"i": 51,
"j": 52,
"k": 53,
"l": 54,
"m": 55,
"n": 56,
"o": 57,
"p": 58,
"q": 59,
"r": 60,
"s": 61,
"t": 62,
"u": 63,
"v": 64,
"w": 65,
"x": 66,
"y": 67,
"z": 68,
"ɑ": 69,
"ɐ": 70,
"ɒ": 71,
"æ": 72,
"β": 75,
"ɔ": 76,
"ɕ": 77,
"ç": 78,
"ɖ": 80,
"ð": 81,
"ʤ": 82,
"ə": 83,
"ɚ": 85,
"ɛ": 86,
"ɜ": 87,
"ɟ": 90,
"ɡ": 92,
"ɥ": 99,
"ɨ": 101,
"ɪ": 102,
"ʝ": 103,
"ɯ": 110,
"ɰ": 111,
"ŋ": 112,
"ɳ": 113,
"ɲ": 114,
"ɴ": 115,
"ø": 116,
"ɸ": 118,
"θ": 119,
"œ": 120,
"ɹ": 123,
"ɾ": 125,
"ɻ": 126,
"ʁ": 128,
"ɽ": 129,
"ʂ": 130,
"ʃ": 131,
"ʈ": 132,
"ʧ": 133,
"ʊ": 135,
"ʋ": 136,
"ʌ": 138,
"ɣ": 139,
"ɤ": 140,
"χ": 142,
"ʎ": 143,
"ʒ": 147,
"ʔ": 148,
"ˈ": 156,
"ˌ": 157,
"ː": 158,
"ʰ": 162,
"ʲ": 164,
"↓": 169,
"→": 171,
"↗": 172,
"↘": 173,
"ᵻ": 177
}
}

View file

@ -1,209 +0,0 @@
import threading
import time
from datetime import datetime
import psutil
import torch
from fastapi import APIRouter
try:
import GPUtil
GPU_AVAILABLE = True
except ImportError:
GPU_AVAILABLE = False
router = APIRouter(tags=["debug"])
@router.get("/debug/threads")
async def get_thread_info():
process = psutil.Process()
current_threads = threading.enumerate()
# Get per-thread CPU times
thread_details = []
for thread in current_threads:
thread_info = {
"name": thread.name,
"id": thread.ident,
"alive": thread.is_alive(),
"daemon": thread.daemon,
}
thread_details.append(thread_info)
return {
"total_threads": process.num_threads(),
"active_threads": len(current_threads),
"thread_names": [t.name for t in current_threads],
"thread_details": thread_details,
"memory_mb": process.memory_info().rss / 1024 / 1024,
}
@router.get("/debug/storage")
async def get_storage_info():
# Get disk partitions
partitions = psutil.disk_partitions()
storage_info = []
for partition in partitions:
try:
usage = psutil.disk_usage(partition.mountpoint)
storage_info.append(
{
"device": partition.device,
"mountpoint": partition.mountpoint,
"fstype": partition.fstype,
"total_gb": usage.total / (1024**3),
"used_gb": usage.used / (1024**3),
"free_gb": usage.free / (1024**3),
"percent_used": usage.percent,
}
)
except PermissionError:
continue
return {"storage_info": storage_info}
@router.get("/debug/system")
async def get_system_info():
process = psutil.Process()
# CPU Info
cpu_info = {
"cpu_count": psutil.cpu_count(),
"cpu_percent": psutil.cpu_percent(interval=1),
"per_cpu_percent": psutil.cpu_percent(interval=1, percpu=True),
"load_avg": psutil.getloadavg(),
}
# Memory Info
virtual_memory = psutil.virtual_memory()
swap_memory = psutil.swap_memory()
memory_info = {
"virtual": {
"total_gb": virtual_memory.total / (1024**3),
"available_gb": virtual_memory.available / (1024**3),
"used_gb": virtual_memory.used / (1024**3),
"percent": virtual_memory.percent,
},
"swap": {
"total_gb": swap_memory.total / (1024**3),
"used_gb": swap_memory.used / (1024**3),
"free_gb": swap_memory.free / (1024**3),
"percent": swap_memory.percent,
},
}
# Process Info
process_info = {
"pid": process.pid,
"status": process.status(),
"create_time": datetime.fromtimestamp(process.create_time()).isoformat(),
"cpu_percent": process.cpu_percent(),
"memory_percent": process.memory_percent(),
}
# Network Info
network_info = {
"connections": len(process.net_connections()),
"network_io": psutil.net_io_counters()._asdict(),
}
# GPU Info if available
gpu_info = None
if torch.backends.mps.is_available():
gpu_info = {
"type": "MPS",
"available": True,
"device": "Apple Silicon",
"backend": "Metal",
}
elif GPU_AVAILABLE:
try:
gpus = GPUtil.getGPUs()
gpu_info = [
{
"id": gpu.id,
"name": gpu.name,
"load": gpu.load,
"memory": {
"total": gpu.memoryTotal,
"used": gpu.memoryUsed,
"free": gpu.memoryFree,
"percent": (gpu.memoryUsed / gpu.memoryTotal) * 100,
},
"temperature": gpu.temperature,
}
for gpu in gpus
]
except Exception:
gpu_info = "GPU information unavailable"
return {
"cpu": cpu_info,
"memory": memory_info,
"process": process_info,
"network": network_info,
"gpu": gpu_info,
}
@router.get("/debug/session_pools")
async def get_session_pool_info():
"""Get information about ONNX session pools."""
from ..inference.model_manager import get_manager
manager = await get_manager()
pools = manager._session_pools
current_time = time.time()
pool_info = {}
# Get CPU pool info
if "onnx_cpu" in pools:
cpu_pool = pools["onnx_cpu"]
pool_info["cpu"] = {
"active_sessions": len(cpu_pool._sessions),
"max_sessions": cpu_pool._max_size,
"sessions": [
{"model": path, "age_seconds": current_time - info.last_used}
for path, info in cpu_pool._sessions.items()
],
}
# Get GPU pool info
if "onnx_gpu" in pools:
gpu_pool = pools["onnx_gpu"]
pool_info["gpu"] = {
"active_sessions": len(gpu_pool._sessions),
"max_streams": gpu_pool._max_size,
"available_streams": len(gpu_pool._available_streams),
"sessions": [
{
"model": path,
"age_seconds": current_time - info.last_used,
"stream_id": info.stream_id,
}
for path, info in gpu_pool._sessions.items()
],
}
# Add GPU memory info if available
if GPU_AVAILABLE:
try:
gpus = GPUtil.getGPUs()
if gpus:
gpu = gpus[0] # Assume first GPU
pool_info["gpu"]["memory"] = {
"total_mb": gpu.memoryTotal,
"used_mb": gpu.memoryUsed,
"free_mb": gpu.memoryFree,
"percent_used": (gpu.memoryUsed / gpu.memoryTotal) * 100,
}
except Exception:
pass
return pool_info

View file

@ -1,49 +1,35 @@
import base64
import json
import os
import re
from pathlib import Path
from typing import AsyncGenerator, List, Tuple, Union
from typing import List
import numpy as np
import torch
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
from kokoro import KPipeline
from loguru import logger
from fastapi import Depends, Response, APIRouter, HTTPException
from ..core.config import settings
from ..inference.base import AudioChunk
from ..services.audio import AudioNormalizer, AudioService
from ..services.streaming_audio_writer import StreamingAudioWriter
from ..services.temp_manager import TempFileWriter
from ..services.text_processing import smart_split
from ..services.audio import AudioService
from ..services.tts_model import TTSModel
from ..services.tts_service import TTSService
from ..structures import CaptionedSpeechRequest, CaptionedSpeechResponse, WordTimestamp
from ..structures.custom_responses import JSONStreamingResponse
from ..structures.text_schemas import (
GenerateFromPhonemesRequest,
PhonemeRequest,
PhonemeResponse,
GenerateFromPhonemesRequest,
)
from .openai_compatible import process_and_validate_voices, stream_audio_chunks
from ..services.text_processing import tokenize, phonemize
router = APIRouter(tags=["text processing"])
async def get_tts_service() -> TTSService:
def get_tts_service() -> TTSService:
"""Dependency to get TTSService instance"""
return (
await TTSService.create()
) # Create service with properly initialized managers
return TTSService()
@router.post("/text/phonemize", response_model=PhonemeResponse, tags=["deprecated"])
@router.post("/dev/phonemize", response_model=PhonemeResponse)
async def phonemize_text(request: PhonemeRequest) -> PhonemeResponse:
"""Convert text to phonemes using Kokoro's quiet mode.
"""Convert text to phonemes and tokens
Args:
request: Request containing text and language
tts_service: Injected TTSService instance
Returns:
Phonemes and token IDs
@ -52,17 +38,16 @@ async def phonemize_text(request: PhonemeRequest) -> PhonemeResponse:
if not request.text:
raise ValueError("Text cannot be empty")
# Initialize Kokoro pipeline in quiet mode (no model)
pipeline = KPipeline(lang_code=request.language, model=False)
# Get first result from pipeline (we only need one since we're not chunking)
for result in pipeline(request.text):
# result.graphemes = original text
# result.phonemes = phonemized text
# result.tokens = token objects (if available)
return PhonemeResponse(phonemes=result.phonemes, tokens=[])
# Get phonemes
phonemes = phonemize(request.text, request.language)
if not phonemes:
raise ValueError("Failed to generate phonemes")
# Get tokens
tokens = tokenize(phonemes)
tokens = [0] + tokens + [0] # Add start/end tokens
return PhonemeResponse(phonemes=phonemes, tokens=tokens)
except ValueError as e:
logger.error(f"Error in phoneme generation: {str(e)}")
raise HTTPException(
@ -75,338 +60,71 @@ async def phonemize_text(request: PhonemeRequest) -> PhonemeResponse:
)
@router.post("/text/generate_from_phonemes", tags=["deprecated"])
@router.post("/dev/generate_from_phonemes")
async def generate_from_phonemes(
request: GenerateFromPhonemesRequest,
client_request: Request,
tts_service: TTSService = Depends(get_tts_service),
) -> StreamingResponse:
"""Generate audio directly from phonemes using Kokoro's phoneme format"""
try:
# Basic validation
if not isinstance(request.phonemes, str):
raise ValueError("Phonemes must be a string")
) -> Response:
"""Generate audio directly from phonemes
Args:
request: Request containing phonemes and generation parameters
tts_service: Injected TTSService instance
Returns:
WAV audio bytes
"""
# Validate phonemes first
if not request.phonemes:
raise ValueError("Phonemes cannot be empty")
# Create streaming audio writer and normalizer
writer = StreamingAudioWriter(format="wav", sample_rate=24000, channels=1)
normalizer = AudioNormalizer()
async def generate_chunks():
try:
# Generate audio from phonemes
chunk_audio, _ = await tts_service.generate_from_phonemes(
phonemes=request.phonemes, # Pass complete phoneme string
voice=request.voice,
speed=1.0,
raise HTTPException(
status_code=400,
detail={"error": "Invalid request", "message": "Phonemes cannot be empty"},
)
if chunk_audio is not None:
# Normalize audio before writing
normalized_audio = normalizer.normalize(chunk_audio)
# Write chunk and yield bytes
chunk_bytes = writer.write_chunk(normalized_audio)
if chunk_bytes:
yield chunk_bytes
# Validate voice exists
voice_path = tts_service._get_voice_path(request.voice)
if not voice_path:
raise HTTPException(
status_code=400,
detail={
"error": "Invalid request",
"message": f"Voice not found: {request.voice}",
},
)
# Finalize and yield remaining bytes
final_bytes = writer.write_chunk(finalize=True)
if final_bytes:
yield final_bytes
writer.close()
else:
raise ValueError("Failed to generate audio data")
try:
# Load voice
voicepack = tts_service._load_voice(voice_path)
except Exception as e:
logger.error(f"Error in audio generation: {str(e)}")
# Clean up writer on error
writer.close()
# Re-raise the original exception
raise
# Convert phonemes to tokens
tokens = tokenize(request.phonemes)
tokens = [0] + tokens + [0] # Add start/end tokens
return StreamingResponse(
generate_chunks(),
# Generate audio directly from tokens
audio = TTSModel.generate_from_tokens(tokens, voicepack, request.speed)
# Convert to WAV bytes
wav_bytes = AudioService.convert_audio(
audio, 24000, "wav", is_first_chunk=True, is_last_chunk=True, stream=False
)
return Response(
content=wav_bytes,
media_type="audio/wav",
headers={
"Content-Disposition": "attachment; filename=speech.wav",
"X-Accel-Buffering": "no",
"Cache-Control": "no-cache",
"Transfer-Encoding": "chunked",
},
)
except ValueError as e:
logger.error(f"Error generating audio: {str(e)}")
logger.error(f"Invalid request: {str(e)}")
raise HTTPException(
status_code=400,
detail={
"error": "validation_error",
"message": str(e),
"type": "invalid_request_error",
},
status_code=400, detail={"error": "Invalid request", "message": str(e)}
)
except Exception as e:
logger.error(f"Error generating audio: {str(e)}")
raise HTTPException(
status_code=500,
detail={
"error": "processing_error",
"message": str(e),
"type": "server_error",
},
)
@router.post("/dev/captioned_speech")
async def create_captioned_speech(
request: CaptionedSpeechRequest,
client_request: Request,
x_raw_response: str = Header(None, alias="x-raw-response"),
tts_service: TTSService = Depends(get_tts_service),
):
"""Generate audio with word-level timestamps using streaming approach"""
try:
# model_name = get_model_name(request.model)
tts_service = await get_tts_service()
voice_name = await process_and_validate_voices(request.voice, tts_service)
# Set content type based on format
content_type = {
"mp3": "audio/mpeg",
"opus": "audio/opus",
"m4a": "audio/mp4",
"flac": "audio/flac",
"wav": "audio/wav",
"pcm": "audio/pcm",
}.get(request.response_format, f"audio/{request.response_format}")
writer = StreamingAudioWriter(request.response_format, sample_rate=24000)
# Check if streaming is requested (default for OpenAI client)
if request.stream:
# Create generator but don't start it yet
generator = stream_audio_chunks(
tts_service, request, client_request, writer
)
# If download link requested, wrap generator with temp file writer
if request.return_download_link:
from ..services.temp_manager import TempFileWriter
temp_writer = TempFileWriter(request.response_format)
await temp_writer.__aenter__() # Initialize temp file
# Get download path immediately after temp file creation
download_path = temp_writer.download_path
# Create response headers with download path
headers = {
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
"X-Accel-Buffering": "no",
"Cache-Control": "no-cache",
"Transfer-Encoding": "chunked",
"X-Download-Path": download_path,
}
# Create async generator for streaming
async def dual_output():
try:
# Write chunks to temp file and stream
async for chunk_data in generator:
# The timestamp accumulator is only used when word-level timestamps are generated but no audio is returned.
timestamp_acumulator = []
if chunk_data.output: # Skip empty chunks
await temp_writer.write(chunk_data.output)
base64_chunk = base64.b64encode(
chunk_data.output
).decode("utf-8")
# Add any chunks that may be in the accumulator into the returned word_timestamps
if chunk_data.word_timestamps is not None:
chunk_data.word_timestamps = (
timestamp_acumulator + chunk_data.word_timestamps
)
timestamp_acumulator = []
else:
chunk_data.word_timestamps = []
yield CaptionedSpeechResponse(
audio=base64_chunk,
audio_format=content_type,
timestamps=chunk_data.word_timestamps,
)
else:
if (
chunk_data.word_timestamps is not None
and len(chunk_data.word_timestamps) > 0
):
timestamp_acumulator += chunk_data.word_timestamps
# Finalize the temp file
await temp_writer.finalize()
except Exception as e:
logger.error(f"Error in dual output streaming: {e}")
await temp_writer.__aexit__(type(e), e, e.__traceback__)
raise
finally:
# Ensure temp writer is closed
if not temp_writer._finalized:
await temp_writer.__aexit__(None, None, None)
writer.close()
# Stream with temp file writing
return JSONStreamingResponse(
dual_output(), media_type="application/json", headers=headers
)
async def single_output():
try:
# The timestamp accumulator is only used when word-level timestamps are generated but no audio is returned.
timestamp_acumulator = []
# Stream chunks
async for chunk_data in generator:
if chunk_data.output: # Skip empty chunks
# Encode the chunk bytes into base 64
base64_chunk = base64.b64encode(chunk_data.output).decode(
"utf-8"
)
# Merge any timestamps held in the accumulator into the returned word_timestamps
if chunk_data.word_timestamps is not None:
chunk_data.word_timestamps = (
timestamp_acumulator + chunk_data.word_timestamps
)
else:
chunk_data.word_timestamps = []
timestamp_acumulator = []
yield CaptionedSpeechResponse(
audio=base64_chunk,
audio_format=content_type,
timestamps=chunk_data.word_timestamps,
)
else:
if (
chunk_data.word_timestamps is not None
and len(chunk_data.word_timestamps) > 0
):
timestamp_acumulator += chunk_data.word_timestamps
except Exception as e:
logger.error(f"Error in single output streaming: {e}")
writer.close()
raise
# Standard streaming without download link
return JSONStreamingResponse(
single_output(),
media_type="application/json",
headers={
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
"X-Accel-Buffering": "no",
"Cache-Control": "no-cache",
"Transfer-Encoding": "chunked",
},
)
else:
# Generate complete audio using public interface
audio_data = await tts_service.generate_audio(
text=request.input,
voice=voice_name,
writer=writer,
speed=request.speed,
return_timestamps=request.return_timestamps,
normalization_options=request.normalization_options,
lang_code=request.lang_code,
)
audio_data = await AudioService.convert_audio(
audio_data,
request.response_format,
writer,
is_last_chunk=False,
trim_audio=False,
)
# Convert to requested format with proper finalization
final = await AudioService.convert_audio(
AudioChunk(np.array([], dtype=np.int16)),
request.response_format,
writer,
is_last_chunk=True,
)
output = audio_data.output + final.output
base64_output = base64.b64encode(output).decode("utf-8")
content = CaptionedSpeechResponse(
audio=base64_output,
audio_format=content_type,
timestamps=audio_data.word_timestamps,
).model_dump()
writer.close()
return JSONResponse(
content=content,
media_type="application/json",
headers={
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
"Cache-Control": "no-cache", # Prevent caching
},
)
except ValueError as e:
# Handle validation errors
logger.warning(f"Invalid request: {str(e)}")
try:
writer.close()
except:
pass
raise HTTPException(
status_code=400,
detail={
"error": "validation_error",
"message": str(e),
"type": "invalid_request_error",
},
)
except RuntimeError as e:
# Handle runtime/processing errors
logger.error(f"Processing error: {str(e)}")
try:
writer.close()
except:
pass
raise HTTPException(
status_code=500,
detail={
"error": "processing_error",
"message": str(e),
"type": "server_error",
},
)
except Exception as e:
# Handle unexpected errors
logger.error(f"Unexpected error in captioned speech generation: {str(e)}")
try:
writer.close()
except:
pass
raise HTTPException(
status_code=500,
detail={
"error": "processing_error",
"message": str(e),
"type": "server_error",
},
status_code=500, detail={"error": "Server error", "message": str(e)}
)
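For reference, a minimal client sketch for this endpoint in the non-streaming case; the host/port, the "af_bella" voice, and the "word" field name on each timestamp are assumptions for illustration, not part of this diff.

# Hypothetical client for /dev/captioned_speech with stream=False.
import base64
import requests

payload = {
    "model": "kokoro",
    "input": "Hello world, this is a captioned test.",
    "voice": "af_bella",            # placeholder voice name
    "response_format": "mp3",
    "speed": 1.0,
    "stream": False,                # ask for one JSON body instead of JSON lines
    "return_timestamps": True,
}

resp = requests.post("http://localhost:8880/dev/captioned_speech", json=payload)
resp.raise_for_status()
body = resp.json()

# The audio field is base64-encoded bytes in the requested format.
with open("speech.mp3", "wb") as f:
    f.write(base64.b64decode(body["audio"]))

# Word-level timestamps are returned alongside the audio; start_time/end_time
# are seconds, the "word" key is assumed from the schema.
for ts in body.get("timestamps") or []:
    print(ts.get("word"), ts["start_time"], ts["end_time"])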

View file

@ -1,197 +1,77 @@
"""OpenAI-compatible router for text-to-speech"""
from typing import List, Union, AsyncGenerator
import io
import json
import os
import re
import tempfile
from typing import AsyncGenerator, Dict, List, Tuple, Union
from urllib import response
import aiofiles
import numpy as np
import torch
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from fastapi.responses import FileResponse, StreamingResponse
from loguru import logger
from fastapi import Header, Depends, Response, APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from ..core.config import settings
from ..inference.base import AudioChunk
from ..services.audio import AudioService
from ..services.streaming_audio_writer import StreamingAudioWriter
from ..structures.schemas import OpenAISpeechRequest
from ..services.tts_service import TTSService
from ..structures import OpenAISpeechRequest
from ..structures.schemas import CaptionedSpeechRequest
# Load OpenAI mappings
def load_openai_mappings() -> Dict:
"""Load OpenAI voice and model mappings from JSON"""
api_dir = os.path.dirname(os.path.dirname(__file__))
mapping_path = os.path.join(api_dir, "core", "openai_mappings.json")
try:
with open(mapping_path, "r") as f:
return json.load(f)
except Exception as e:
logger.error(f"Failed to load OpenAI mappings: {e}")
return {"models": {}, "voices": {}}
# Global mappings
_openai_mappings = load_openai_mappings()
router = APIRouter(
tags=["OpenAI Compatible TTS"],
responses={404: {"description": "Not found"}},
)
# Global TTSService instance with lock
_tts_service = None
_init_lock = None
def get_tts_service() -> TTSService:
"""Dependency to get TTSService instance with database session"""
return TTSService() # Initialize TTSService with default settings
async def get_tts_service() -> TTSService:
"""Get global TTSService instance"""
global _tts_service, _init_lock
# Create lock if needed
if _init_lock is None:
import asyncio
_init_lock = asyncio.Lock()
# Initialize service if needed
if _tts_service is None:
async with _init_lock:
# Double check pattern
if _tts_service is None:
_tts_service = await TTSService.create()
logger.info("Created global TTSService instance")
return _tts_service
def get_model_name(model: str) -> str:
"""Get internal model name from OpenAI model name"""
base_name = _openai_mappings["models"].get(model)
if not base_name:
raise ValueError(f"Unsupported model: {model}")
return base_name + ".pth"
async def process_and_validate_voices(
async def process_voices(
voice_input: Union[str, List[str]], tts_service: TTSService
) -> str:
"""Process voice input, handling both string and list formats
Returns:
Voice name to use (with weights if specified)
"""
voices = []
"""Process voice input into a combined voice, handling both string and list formats"""
# Convert input to list of voices
if isinstance(voice_input, str):
voice_input = voice_input.replace(" ", "").strip()
if voice_input[-1] in "+-" or voice_input[0] in "+-":
raise ValueError(f"Voice combination contains empty combine items")
if re.search(r"[+-]{2,}", voice_input) is not None:
raise ValueError(f"Voice combination contains empty combine items")
voices = re.split(r"([-+])", voice_input)
voices = [v.strip() for v in voice_input.split("+") if v.strip()]
else:
voices = [x for item in voice_input for x in (item, "+")][:-1]
voices = voice_input
if not voices:
raise ValueError("No voices provided")
# Check if all voices exist
available_voices = await tts_service.list_voices()
for voice_index in range(0, len(voices), 2):
mapped_voice = voices[voice_index].split("(")
mapped_voice = list(map(str.strip, mapped_voice))
if len(mapped_voice) > 2:
for voice in voices:
if voice not in available_voices:
raise ValueError(
f"Voice '{voices[voice_index]}' contains too many weight items"
f"Voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}"
)
if mapped_voice.count(")") > 1:
raise ValueError(
f"Voice '{voices[voice_index]}' contains too many weight items"
)
# If single voice, return it directly
if len(voices) == 1:
return voices[0]
mapped_voice[0] = _openai_mappings["voices"].get(
mapped_voice[0], mapped_voice[0]
)
if mapped_voice[0] not in available_voices:
raise ValueError(
f"Voice '{mapped_voice[0]}' not found. Available voices: {', '.join(sorted(available_voices))}"
)
voices[voice_index] = "(".join(mapped_voice)
return "".join(voices)
# Otherwise combine voices
return await tts_service.combine_voices(voices=voices)
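A quick sketch of how the newer process_and_validate_voices path parses the weighted combination syntax; the voice names are placeholders and only the re.split behaviour shown here is taken from the code above.

import re

# The newer parser keeps the +/- operators so downstream code can weight and mix voices;
# the older process_voices path simply split on "+".
voice_input = "af_bella(2)+af_sky(1)-af_nicole"   # hypothetical voice names
parts = re.split(r"([-+])", voice_input.replace(" ", ""))
print(parts)   # ['af_bella(2)', '+', 'af_sky(1)', '-', 'af_nicole']

# Voices sit at even indices and operators at odd indices, which is what the
# range(0, len(voices), 2) validation loop above relies on.
voices = parts[0::2]
operators = parts[1::2]
print(voices, operators)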
async def stream_audio_chunks(
tts_service: TTSService,
request: Union[OpenAISpeechRequest, CaptionedSpeechRequest],
client_request: Request,
writer: StreamingAudioWriter,
) -> AsyncGenerator[AudioChunk, None]:
"""Stream audio chunks as they're generated with client disconnect handling"""
voice_name = await process_and_validate_voices(request.voice, tts_service)
unique_properties = {"return_timestamps": False}
if hasattr(request, "return_timestamps"):
unique_properties["return_timestamps"] = request.return_timestamps
try:
async for chunk_data in tts_service.generate_audio_stream(
tts_service: TTSService, request: OpenAISpeechRequest
) -> AsyncGenerator[bytes, None]:
"""Stream audio chunks as they're generated"""
voice_to_use = await process_voices(request.voice, tts_service)
async for chunk in tts_service.generate_audio_stream(
text=request.input,
voice=voice_name,
writer=writer,
voice=voice_to_use,
speed=request.speed,
output_format=request.response_format,
lang_code=request.lang_code,
normalization_options=request.normalization_options,
return_timestamps=unique_properties["return_timestamps"],
):
# Check if client is still connected
is_disconnected = client_request.is_disconnected
if callable(is_disconnected):
is_disconnected = await is_disconnected()
if is_disconnected:
logger.info("Client disconnected, stopping audio generation")
break
yield chunk_data
except Exception as e:
logger.error(f"Error in audio streaming: {str(e)}")
# Let the exception propagate to trigger cleanup
raise
yield chunk
@router.post("/audio/speech")
async def create_speech(
request: OpenAISpeechRequest,
client_request: Request,
tts_service: TTSService = Depends(get_tts_service),
x_raw_response: str = Header(None, alias="x-raw-response"),
):
"""OpenAI-compatible endpoint for text-to-speech"""
# Validate model before processing request
if request.model not in _openai_mappings["models"]:
raise HTTPException(
status_code=400,
detail={
"error": "invalid_model",
"message": f"Unsupported model: {request.model}",
"type": "invalid_request_error",
},
)
try:
# model_name = get_model_name(request.model)
tts_service = await get_tts_service()
voice_name = await process_and_validate_voices(request.voice, tts_service)
# Process voice combination and validate
voice_to_use = await process_voices(request.voice, tts_service)
# Set content type based on format
content_type = {
@ -203,460 +83,96 @@ async def create_speech(
"pcm": "audio/pcm",
}.get(request.response_format, f"audio/{request.response_format}")
writer = StreamingAudioWriter(request.response_format, sample_rate=24000)
# Check if streaming is requested (default for OpenAI client)
if request.stream:
# Create generator but don't start it yet
generator = stream_audio_chunks(
tts_service, request, client_request, writer
)
# If download link requested, wrap generator with temp file writer
if request.return_download_link:
from ..services.temp_manager import TempFileWriter
# Use download_format if specified, otherwise use response_format
output_format = request.download_format or request.response_format
temp_writer = TempFileWriter(output_format)
await temp_writer.__aenter__() # Initialize temp file
# Get download path immediately after temp file creation
download_path = temp_writer.download_path
# Create response headers with download path
headers = {
"Content-Disposition": f"attachment; filename=speech.{output_format}",
"X-Accel-Buffering": "no",
"Cache-Control": "no-cache",
"Transfer-Encoding": "chunked",
"X-Download-Path": download_path,
}
# Add header to indicate if temp file writing is available
if temp_writer._write_error:
headers["X-Download-Status"] = "unavailable"
# Create async generator for streaming
async def dual_output():
try:
# Write chunks to temp file and stream
async for chunk_data in generator:
if chunk_data.output: # Skip empty chunks
await temp_writer.write(chunk_data.output)
# if return_json:
# yield chunk, chunk_data
# else:
yield chunk_data.output
# Finalize the temp file
await temp_writer.finalize()
except Exception as e:
logger.error(f"Error in dual output streaming: {e}")
await temp_writer.__aexit__(type(e), e, e.__traceback__)
raise
finally:
# Ensure temp writer is closed
if not temp_writer._finalized:
await temp_writer.__aexit__(None, None, None)
writer.close()
# Stream with temp file writing
# Stream audio chunks as they're generated
return StreamingResponse(
dual_output(), media_type=content_type, headers=headers
)
async def single_output():
try:
# Stream chunks
async for chunk_data in generator:
if chunk_data.output: # Skip empty chunks
yield chunk_data.output
except Exception as e:
logger.error(f"Error in single output streaming: {e}")
writer.close()
raise
# Standard streaming without download link
return StreamingResponse(
single_output(),
stream_audio_chunks(tts_service, request),
media_type=content_type,
headers={
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
"X-Accel-Buffering": "no",
"Cache-Control": "no-cache",
"Transfer-Encoding": "chunked",
"X-Accel-Buffering": "no", # Disable proxy buffering
"Cache-Control": "no-cache", # Prevent caching
"Transfer-Encoding": "chunked", # Enable chunked transfer encoding
},
)
else:
# Generate complete audio
audio, _ = tts_service._generate_audio(
text=request.input,
voice=voice_to_use,
speed=request.speed,
stitch_long_output=True,
)
# Convert to requested format
content = AudioService.convert_audio(
audio, 24000, request.response_format, is_first_chunk=True, stream=False
)
return Response(
content=content,
media_type=content_type,
headers={
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
"Cache-Control": "no-cache", # Prevent caching
}
# Generate complete audio using public interface
audio_data = await tts_service.generate_audio(
text=request.input,
voice=voice_name,
writer=writer,
speed=request.speed,
normalization_options=request.normalization_options,
lang_code=request.lang_code,
)
audio_data = await AudioService.convert_audio(
audio_data,
request.response_format,
writer,
is_last_chunk=False,
trim_audio=False,
)
# Convert to requested format with proper finalization
final = await AudioService.convert_audio(
AudioChunk(np.array([], dtype=np.int16)),
request.response_format,
writer,
is_last_chunk=True,
)
output = audio_data.output + final.output
if request.return_download_link:
from ..services.temp_manager import TempFileWriter
# Use download_format if specified, otherwise use response_format
output_format = request.download_format or request.response_format
temp_writer = TempFileWriter(output_format)
await temp_writer.__aenter__() # Initialize temp file
# Get download path immediately after temp file creation
download_path = temp_writer.download_path
headers["X-Download-Path"] = download_path
try:
# Write chunks to temp file
logger.info("Writing chunks to tempory file for download")
await temp_writer.write(output)
# Finalize the temp file
await temp_writer.finalize()
except Exception as e:
logger.error(f"Error in dual output: {e}")
await temp_writer.__aexit__(type(e), e, e.__traceback__)
raise
finally:
# Ensure temp writer is closed
if not temp_writer._finalized:
await temp_writer.__aexit__(None, None, None)
writer.close()
return Response(
content=output,
media_type=content_type,
headers=headers,
},
)
except ValueError as e:
# Handle validation errors
logger.warning(f"Invalid request: {str(e)}")
try:
writer.close()
except:
pass
logger.error(f"Invalid request: {str(e)}")
raise HTTPException(
status_code=400,
detail={
"error": "validation_error",
"message": str(e),
"type": "invalid_request_error",
},
)
except RuntimeError as e:
# Handle runtime/processing errors
logger.error(f"Processing error: {str(e)}")
try:
writer.close()
except:
pass
raise HTTPException(
status_code=500,
detail={
"error": "processing_error",
"message": str(e),
"type": "server_error",
},
status_code=400, detail={"error": "Invalid request", "message": str(e)}
)
except Exception as e:
# Handle unexpected errors
logger.error(f"Unexpected error in speech generation: {str(e)}")
try:
writer.close()
except:
pass
logger.error(f"Error generating speech: {str(e)}")
raise HTTPException(
status_code=500,
detail={
"error": "processing_error",
"message": str(e),
"type": "server_error",
},
)
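A usage sketch for the OpenAI-compatible speech route using the official client; it assumes the router is mounted under /v1 on localhost:8880 and that an "af_bella" voice is installed.

# Hypothetical client call against the local server; the API key is unused.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed")

with client.audio.speech.with_streaming_response.create(
    model="kokoro",
    voice="af_bella",            # placeholder voice name
    input="Streaming text to speech from the local server.",
    response_format="mp3",
) as response:
    response.stream_to_file("speech.mp3")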
@router.get("/download/{filename}")
async def download_audio_file(filename: str):
"""Download a generated audio file from temp storage"""
try:
from ..core.paths import _find_file, get_content_type
# Search for file in temp directory
file_path = await _find_file(
filename=filename, search_paths=[settings.temp_file_dir]
)
# Get content type from path helper
content_type = await get_content_type(file_path)
return FileResponse(
file_path,
media_type=content_type,
filename=filename,
headers={
"Cache-Control": "no-cache",
"Content-Disposition": f"attachment; filename={filename}",
},
)
except Exception as e:
logger.error(f"Error serving download file {filename}: {e}")
raise HTTPException(
status_code=500,
detail={
"error": "server_error",
"message": "Failed to serve audio file",
"type": "server_error",
},
)
@router.get("/models")
async def list_models():
"""List all available models"""
try:
# Create standard model list
models = [
{
"id": "tts-1",
"object": "model",
"created": 1686935002,
"owned_by": "kokoro",
},
{
"id": "tts-1-hd",
"object": "model",
"created": 1686935002,
"owned_by": "kokoro",
},
{
"id": "kokoro",
"object": "model",
"created": 1686935002,
"owned_by": "kokoro",
},
]
return {"object": "list", "data": models}
except Exception as e:
logger.error(f"Error listing models: {str(e)}")
raise HTTPException(
status_code=500,
detail={
"error": "server_error",
"message": "Failed to retrieve model list",
"type": "server_error",
},
)
@router.get("/models/{model}")
async def retrieve_model(model: str):
"""Retrieve a specific model"""
try:
# Define available models
models = {
"tts-1": {
"id": "tts-1",
"object": "model",
"created": 1686935002,
"owned_by": "kokoro",
},
"tts-1-hd": {
"id": "tts-1-hd",
"object": "model",
"created": 1686935002,
"owned_by": "kokoro",
},
"kokoro": {
"id": "kokoro",
"object": "model",
"created": 1686935002,
"owned_by": "kokoro",
},
}
# Check if requested model exists
if model not in models:
raise HTTPException(
status_code=404,
detail={
"error": "model_not_found",
"message": f"Model '{model}' not found",
"type": "invalid_request_error",
},
)
# Return the specific model
return models[model]
except HTTPException:
raise
except Exception as e:
logger.error(f"Error retrieving model {model}: {str(e)}")
raise HTTPException(
status_code=500,
detail={
"error": "server_error",
"message": "Failed to retrieve model information",
"type": "server_error",
},
status_code=500, detail={"error": "Server error", "message": str(e)}
)
@router.get("/audio/voices")
async def list_voices():
async def list_voices(tts_service: TTSService = Depends(get_tts_service)):
"""List all available voices for text-to-speech"""
try:
tts_service = await get_tts_service()
voices = await tts_service.list_voices()
return {"voices": voices}
except Exception as e:
logger.error(f"Error listing voices: {str(e)}")
raise HTTPException(
status_code=500,
detail={
"error": "server_error",
"message": "Failed to retrieve voice list",
"type": "server_error",
},
)
raise HTTPException(status_code=500, detail=str(e))
@router.post("/audio/voices/combine")
async def combine_voices(request: Union[str, List[str]]):
"""Combine multiple voices into a new voice and return the .pt file.
async def combine_voices(
request: Union[str, List[str]], tts_service: TTSService = Depends(get_tts_service)
):
"""Combine multiple voices into a new voice.
Args:
request: Either a string with voices separated by + (e.g. "voice1+voice2")
or a list of voice names to combine
Returns:
FileResponse with the combined voice .pt file
Dict with combined voice name and list of all available voices
Raises:
HTTPException:
- 400: Invalid request (wrong number of voices, voice not found)
- 500: Server error (file system issues, combination failed)
"""
# Check if local voice saving is allowed
if not settings.allow_local_voice_saving:
raise HTTPException(
status_code=403,
detail={
"error": "permission_denied",
"message": "Local voice saving is disabled",
"type": "permission_error",
},
)
try:
# Convert input to list of voices
if isinstance(request, str):
# Check if it's an OpenAI voice name
mapped_voice = _openai_mappings["voices"].get(request)
if mapped_voice:
request = mapped_voice
voices = [v.strip() for v in request.split("+") if v.strip()]
else:
# For list input, map each voice if it's an OpenAI voice name
voices = [_openai_mappings["voices"].get(v, v) for v in request]
voices = [v.strip() for v in voices if v.strip()]
if not voices:
raise ValueError("No voices provided")
# For multiple voices, validate base voices exist
tts_service = await get_tts_service()
available_voices = await tts_service.list_voices()
for voice in voices:
if voice not in available_voices:
raise ValueError(
f"Base voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}"
)
# Combine voices
combined_tensor = await tts_service.combine_voices(voices=voices)
combined_name = "+".join(voices)
# Save to temp file
temp_dir = tempfile.gettempdir()
voice_path = os.path.join(temp_dir, f"{combined_name}.pt")
buffer = io.BytesIO()
torch.save(combined_tensor, buffer)
async with aiofiles.open(voice_path, "wb") as f:
await f.write(buffer.getvalue())
return FileResponse(
voice_path,
media_type="application/octet-stream",
filename=f"{combined_name}.pt",
headers={
"Content-Disposition": f"attachment; filename={combined_name}.pt",
"Cache-Control": "no-cache",
},
)
combined_voice = await process_voices(request, tts_service)
voices = await tts_service.list_voices()
return {"voices": voices, "voice": combined_voice}
except ValueError as e:
logger.warning(f"Invalid voice combination request: {str(e)}")
logger.error(f"Invalid voice combination request: {str(e)}")
raise HTTPException(
status_code=400,
detail={
"error": "validation_error",
"message": str(e),
"type": "invalid_request_error",
},
)
except RuntimeError as e:
logger.error(f"Voice combination processing error: {str(e)}")
raise HTTPException(
status_code=500,
detail={
"error": "processing_error",
"message": "Failed to process voice combination request",
"type": "server_error",
},
status_code=400, detail={"error": "Invalid request", "message": str(e)}
)
except Exception as e:
logger.error(f"Unexpected error in voice combination: {str(e)}")
logger.error(f"Server error during voice combination: {str(e)}")
raise HTTPException(
status_code=500,
detail={
"error": "server_error",
"message": "An unexpected error occurred",
"type": "server_error",
},
status_code=500, detail={"error": "Server error", "message": "Server error"}
)
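A small sketch of exercising the newer combine endpoint, which streams back the combined voice tensor as a .pt attachment; the host/port, /v1 prefix, and voice names are placeholders, and the route also requires allow_local_voice_saving to be enabled.

import requests

# The body may be either "voice1+voice2" or a JSON list of voice names.
resp = requests.post(
    "http://localhost:8880/v1/audio/voices/combine",
    json="af_bella+af_sky",          # hypothetical voice names
)
resp.raise_for_status()
with open("af_bella+af_sky.pt", "wb") as f:
    f.write(resp.content)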

View file

@ -1,49 +0,0 @@
"""Web player router with async file serving."""
from fastapi import APIRouter, HTTPException
from fastapi.responses import Response
from loguru import logger
from ..core.config import settings
from ..core.paths import get_content_type, get_web_file_path, read_bytes
router = APIRouter(
tags=["Web Player"],
responses={404: {"description": "Not found"}},
)
@router.get("/{filename:path}")
async def serve_web_file(filename: str):
"""Serve web player static files asynchronously."""
if not settings.enable_web_player:
raise HTTPException(status_code=404, detail="Web player is disabled")
try:
# Default to index.html for root path
if filename == "" or filename == "/":
filename = "index.html"
# Get file path
file_path = await get_web_file_path(filename)
# Read file content
content = await read_bytes(file_path)
# Get content type
content_type = await get_content_type(file_path)
return Response(
content=content,
media_type=content_type,
headers={
"Cache-Control": "no-cache", # Prevent caching during development
},
)
except RuntimeError as e:
logger.warning(f"Web file not found: {filename}")
raise HTTPException(status_code=404, detail=str(e))
except Exception as e:
logger.error(f"Error serving web file {filename}: {e}")
raise HTTPException(status_code=500, detail="Internal server error")

View file

@ -1,122 +1,45 @@
"""Audio conversion service"""
import math
import struct
import time
from io import BytesIO
from typing import Tuple
import numpy as np
import scipy.io.wavfile as wavfile
import soundfile as sf
import scipy.io.wavfile as wavfile
from loguru import logger
from pydub import AudioSegment
from torch import norm
from ..core.config import settings
from ..inference.base import AudioChunk
from .streaming_audio_writer import StreamingAudioWriter
class AudioNormalizer:
"""Handles audio normalization state for a single stream"""
def __init__(self):
self.int16_max = np.iinfo(np.int16).max
self.chunk_trim_ms = settings.gap_trim_ms
self.sample_rate = 24000 # Sample rate of the audio
self.samples_to_trim = int(self.chunk_trim_ms * self.sample_rate / 1000)
self.samples_to_pad_start = int(50 * self.sample_rate / 1000)
def find_first_last_non_silent(
self,
audio_data: np.ndarray,
chunk_text: str,
speed: float,
silence_threshold_db: int = -45,
is_last_chunk: bool = False,
) -> tuple[int, int]:
"""Finds the indices of the first and last non-silent samples in audio data.
def normalize(
self, audio_data: np.ndarray, is_last_chunk: bool = False
) -> np.ndarray:
"""Normalize audio data to int16 range and trim chunk boundaries"""
# Convert to float32 if not already
audio_float = audio_data.astype(np.float32)
Args:
audio_data: Input audio data as numpy array
chunk_text: The text sent to the model to generate the resulting speech
speed: The speaking speed of the voice
silence_threshold_db: How quiet audio has to be to be considered silent
is_last_chunk: Whether this is the last chunk
# Normalize to [-1, 1] range first
if np.max(np.abs(audio_float)) > 0:
audio_float = audio_float / np.max(np.abs(audio_float))
Returns:
A tuple with the start and end indices of the non-silent portion
"""
# Trim end of non-final chunks to reduce gaps
if not is_last_chunk and len(audio_float) > self.samples_to_trim:
audio_float = audio_float[: -self.samples_to_trim]
pad_multiplier = 1
split_character = chunk_text.strip()
if len(split_character) > 0:
split_character = split_character[-1]
if split_character in settings.dynamic_gap_trim_padding_char_multiplier:
pad_multiplier = settings.dynamic_gap_trim_padding_char_multiplier[
split_character
]
if not is_last_chunk:
samples_to_pad_end = max(
int(
(
settings.dynamic_gap_trim_padding_ms
* self.sample_rate
* pad_multiplier
)
/ 1000
)
- self.samples_to_pad_start,
0,
)
else:
samples_to_pad_end = self.samples_to_pad_start
# Convert dBFS threshold to amplitude
amplitude_threshold = np.iinfo(audio_data.dtype).max * (
10 ** (silence_threshold_db / 20)
)
# Find the first samples above the silence threshold at the start and end of the audio
non_silent_index_start, non_silent_index_end = None, None
for X in range(0, len(audio_data)):
if audio_data[X] > amplitude_threshold:
non_silent_index_start = X
break
for X in range(len(audio_data) - 1, -1, -1):
if audio_data[X] > amplitude_threshold:
non_silent_index_end = X
break
# Handle the case where the entire audio is silent
if non_silent_index_start is None or non_silent_index_end is None:
return 0, len(audio_data)
return max(non_silent_index_start - self.samples_to_pad_start, 0), min(
non_silent_index_end + math.ceil(samples_to_pad_end / speed),
len(audio_data),
)
def normalize(self, audio_data: np.ndarray) -> np.ndarray:
"""Convert audio data to int16 range
Args:
audio_data: Input audio data as numpy array
Returns:
Normalized audio data
"""
if audio_data.dtype != np.int16:
# Scale directly to int16 range with clipping
return np.clip(audio_data * 32767, -32768, 32767).astype(np.int16)
return audio_data
# Scale to int16 range
return (audio_float * self.int16_max).astype(np.int16)
class AudioService:
"""Service for audio format conversions with streaming support"""
# Supported formats
SUPPORTED_FORMATS = {"wav", "mp3", "opus", "flac", "aac", "pcm"}
"""Service for audio format conversions"""
# Default audio format settings balanced for speed and compression
DEFAULT_SETTINGS = {
@ -130,119 +53,110 @@ class AudioService:
"flac": {
"compression_level": 0.0, # Light compression, still fast
},
"aac": {
"bitrate": "192k", # Default AAC bitrate
},
}
@staticmethod
async def convert_audio(
audio_chunk: AudioChunk,
def convert_audio(
audio_data: np.ndarray,
sample_rate: int,
output_format: str,
writer: StreamingAudioWriter,
speed: float = 1,
chunk_text: str = "",
is_first_chunk: bool = True,
is_last_chunk: bool = False,
trim_audio: bool = True,
normalizer: AudioNormalizer = None,
) -> AudioChunk:
"""Convert audio data to specified format with streaming support
format_settings: dict = None,
stream: bool = True,
) -> bytes:
"""Convert audio data to specified format
Args:
audio_data: Numpy array of audio samples
output_format: Target format (wav, mp3, ogg, pcm)
writer: The StreamingAudioWriter to use
speed: The speaking speed of the voice
chunk_text: The text sent to the model to generate the resulting speech
is_last_chunk: Whether this is the last chunk
trim_audio: Whether audio should be trimmed
normalizer: Optional AudioNormalizer instance for consistent normalization
sample_rate: Sample rate of the audio
output_format: Target format (wav, mp3, opus, flac, pcm)
is_first_chunk: Whether this is the first chunk of a stream
normalizer: Optional AudioNormalizer instance for consistent normalization across chunks
format_settings: Optional dict of format-specific settings to override defaults
Example: {
"mp3": {
"bitrate_mode": "VARIABLE",
"compression_level": 0.8
}
}
Default settings balance speed and compression:
optimized for localhost @ 0.0
- MP3: constant bitrate, no compression (0.0)
- OPUS: no compression (0.0)
- FLAC: no compression (0.0)
Returns:
Bytes of the converted audio chunk
Bytes of the converted audio
"""
buffer = BytesIO()
try:
# Validate format
if output_format not in AudioService.SUPPORTED_FORMATS:
raise ValueError(f"Format {output_format} not supported")
# Always normalize audio to ensure proper amplitude scaling
if normalizer is None:
normalizer = AudioNormalizer()
audio_chunk.audio = normalizer.normalize(audio_chunk.audio)
if trim_audio == True:
audio_chunk = AudioService.trim_audio(
audio_chunk, chunk_text, speed, is_last_chunk, normalizer
normalized_audio = normalizer.normalize(
audio_data, is_last_chunk=is_last_chunk
)
# Write audio data first
if len(audio_chunk.audio) > 0:
chunk_data = writer.write_chunk(audio_chunk.audio)
if output_format == "pcm":
# Raw 16-bit PCM samples, no header
buffer.write(normalized_audio.tobytes())
elif output_format == "wav":
# WAV format with headers
sf.write(
buffer,
normalized_audio,
sample_rate,
format="WAV",
subtype="PCM_16",
)
elif output_format == "mp3":
# MP3 format with proper framing
settings = format_settings.get("mp3", {}) if format_settings else {}
settings = {**AudioService.DEFAULT_SETTINGS["mp3"], **settings}
sf.write(
buffer, normalized_audio, sample_rate, format="MP3", **settings
)
elif output_format == "opus":
# Opus format in OGG container
settings = format_settings.get("opus", {}) if format_settings else {}
settings = {**AudioService.DEFAULT_SETTINGS["opus"], **settings}
sf.write(
buffer,
normalized_audio,
sample_rate,
format="OGG",
subtype="OPUS",
**settings,
)
elif output_format == "flac":
# FLAC format with proper framing
if is_first_chunk:
logger.info("Starting FLAC stream...")
settings = format_settings.get("flac", {}) if format_settings else {}
settings = {**AudioService.DEFAULT_SETTINGS["flac"], **settings}
sf.write(
buffer,
normalized_audio,
sample_rate,
format="FLAC",
subtype="PCM_16",
**settings,
)
elif output_format == "aac":
raise ValueError(
"Format aac not currently supported. Supported formats are: wav, mp3, opus, flac, pcm."
)
else:
raise ValueError(
f"Format {output_format} not supported. Supported formats are: wav, mp3, opus, flac, pcm, aac."
)
# Then finalize if this is the last chunk
if is_last_chunk:
final_data = writer.write_chunk(finalize=True)
if final_data:
audio_chunk.output = final_data
return audio_chunk
if chunk_data:
audio_chunk.output = chunk_data
return audio_chunk
buffer.seek(0)
return buffer.getvalue()
except Exception as e:
logger.error(f"Error converting audio stream to {output_format}: {str(e)}")
raise ValueError(
f"Failed to convert audio stream to {output_format}: {str(e)}"
)
@staticmethod
def trim_audio(
audio_chunk: AudioChunk,
chunk_text: str = "",
speed: float = 1,
is_last_chunk: bool = False,
normalizer: AudioNormalizer = None,
) -> AudioChunk:
"""Trim silence from start and end
Args:
audio_data: Input audio data as numpy array
chunk_text: The text sent to the model to generate the resulting speech
speed: The speaking speed of the voice
is_last_chunk: Whether this is the last chunk
normalizer: Optional AudioNormalizer instance for consistent normalization
Returns:
Trimmed audio data
"""
if normalizer is None:
normalizer = AudioNormalizer()
audio_chunk.audio = normalizer.normalize(audio_chunk.audio)
trimed_samples = 0
# Trim start and end if enough samples
if len(audio_chunk.audio) > (2 * normalizer.samples_to_trim):
audio_chunk.audio = audio_chunk.audio[
normalizer.samples_to_trim : -normalizer.samples_to_trim
]
trimed_samples += normalizer.samples_to_trim
# Find non silent portion and trim
start_index, end_index = normalizer.find_first_last_non_silent(
audio_chunk.audio, chunk_text, speed, is_last_chunk=is_last_chunk
)
audio_chunk.audio = audio_chunk.audio[start_index:end_index]
trimed_samples += start_index
if audio_chunk.word_timestamps is not None:
for timestamp in audio_chunk.word_timestamps:
timestamp.start_time -= trimed_samples / 24000
timestamp.end_time -= trimed_samples / 24000
return audio_chunk
logger.error(f"Error converting audio to {output_format}: {str(e)}")
raise ValueError(f"Failed to convert audio to {output_format}: {str(e)}")

View file

@ -1,100 +0,0 @@
"""Audio conversion service with proper streaming support"""
import struct
from io import BytesIO
from typing import Optional
import av
import numpy as np
import soundfile as sf
from loguru import logger
from pydub import AudioSegment
class StreamingAudioWriter:
"""Handles streaming audio format conversions"""
def __init__(self, format: str, sample_rate: int, channels: int = 1):
self.format = format.lower()
self.sample_rate = sample_rate
self.channels = channels
self.bytes_written = 0
self.pts = 0
codec_map = {
"wav": "pcm_s16le",
"mp3": "mp3",
"opus": "libopus",
"flac": "flac",
"aac": "aac",
}
# Format-specific setup
if self.format in ["wav", "flac", "mp3", "pcm", "aac", "opus"]:
if self.format != "pcm":
self.output_buffer = BytesIO()
self.container = av.open(
self.output_buffer,
mode="w",
format=self.format if self.format != "aac" else "adts",
)
self.stream = self.container.add_stream(
codec_map[self.format],
sample_rate=self.sample_rate,
layout="mono" if self.channels == 1 else "stereo",
)
self.stream.bit_rate = 128000
else:
raise ValueError(f"Unsupported format: {format}")
def close(self):
if hasattr(self, "container"):
self.container.close()
if hasattr(self, "output_buffer"):
self.output_buffer.close()
def write_chunk(
self, audio_data: Optional[np.ndarray] = None, finalize: bool = False
) -> bytes:
"""Write a chunk of audio data and return bytes in the target format.
Args:
audio_data: Audio data to write, or None if finalizing
finalize: Whether this is the final write to close the stream
"""
if finalize:
if self.format != "pcm":
packets = self.stream.encode(None)
for packet in packets:
self.container.mux(packet)
data = self.output_buffer.getvalue()
self.close()
return data
if audio_data is None or len(audio_data) == 0:
return b""
if self.format == "pcm":
# Write raw bytes
return audio_data.tobytes()
else:
frame = av.AudioFrame.from_ndarray(
audio_data.reshape(1, -1),
format="s16",
layout="mono" if self.channels == 1 else "stereo",
)
frame.sample_rate = self.sample_rate
frame.pts = self.pts
self.pts += frame.samples
packets = self.stream.encode(frame)
for packet in packets:
self.container.mux(packet)
data = self.output_buffer.getvalue()
self.output_buffer.seek(0)
self.output_buffer.truncate(0)
return data
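A minimal sketch of driving the writer on its own, assuming PyAV is installed: encode two chunks of a synthetic 24 kHz tone to MP3, then flush the container with finalize=True.

import numpy as np

# Hypothetical standalone use of StreamingAudioWriter.
writer = StreamingAudioWriter("mp3", sample_rate=24000)
t = np.linspace(0, 1.0, 24000, endpoint=False)
tone = (np.sin(2 * np.pi * 440 * t) * 32767 * 0.3).astype(np.int16)

encoded = b""
for chunk in (tone[:12000], tone[12000:]):
    encoded += writer.write_chunk(chunk)
encoded += writer.write_chunk(finalize=True)   # flushes the encoder and closes the container

with open("tone.mp3", "wb") as f:
    f.write(encoded)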

View file

@ -1,170 +0,0 @@
"""Temporary file writer for audio downloads"""
import os
import tempfile
from typing import List, Optional
import aiofiles
from fastapi import HTTPException
from loguru import logger
from ..core.config import settings
async def cleanup_temp_files() -> None:
"""Clean up old temp files"""
try:
if not await aiofiles.os.path.exists(settings.temp_file_dir):
await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)
return
# Get all temp files with stats
files = []
total_size = 0
# Use os.scandir for sync iteration, but aiofiles.os.stat for async stats
for entry in os.scandir(settings.temp_file_dir):
if entry.is_file():
stat = await aiofiles.os.stat(entry.path)
files.append((entry.path, stat.st_mtime, stat.st_size))
total_size += stat.st_size
# Sort by modification time (oldest first)
files.sort(key=lambda x: x[1])
# Remove files if:
# 1. They're too old
# 2. We have too many files
# 3. Directory is too large
current_time = (await aiofiles.os.stat(settings.temp_file_dir)).st_mtime
max_age = settings.max_temp_dir_age_hours * 3600
for path, mtime, size in files:
should_delete = False
# Check age
if current_time - mtime > max_age:
should_delete = True
logger.info(f"Deleting old temp file: {path}")
# Check count limit
elif len(files) > settings.max_temp_dir_count:
should_delete = True
logger.info(f"Deleting excess temp file: {path}")
# Check size limit
elif total_size > settings.max_temp_dir_size_mb * 1024 * 1024:
should_delete = True
logger.info(f"Deleting to reduce directory size: {path}")
if should_delete:
try:
await aiofiles.os.remove(path)
total_size -= size
logger.info(f"Deleted temp file: {path}")
except Exception as e:
logger.warning(f"Failed to delete temp file {path}: {e}")
except Exception as e:
logger.warning(f"Error during temp file cleanup: {e}")
class TempFileWriter:
"""Handles writing audio chunks to a temp file"""
def __init__(self, format: str):
"""Initialize temp file writer
Args:
format: Audio format extension (mp3, wav, etc)
"""
self.format = format
self.temp_file = None
self._finalized = False
self._write_error = False # Flag to track if we've had a write error
async def __aenter__(self):
"""Async context manager entry"""
try:
# Clean up old files first
await cleanup_temp_files()
# Create temp file with proper extension
await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)
temp = tempfile.NamedTemporaryFile(
dir=settings.temp_file_dir,
delete=False,
suffix=f".{self.format}",
mode="wb",
)
self.temp_file = await aiofiles.open(temp.name, mode="wb")
self.temp_path = temp.name
temp.close() # Close sync file, we'll use async version
# Generate download path immediately
self.download_path = f"/download/{os.path.basename(self.temp_path)}"
except Exception as e:
# Handle permission issues or other errors gracefully
logger.error(f"Failed to create temp file: {e}")
self._write_error = True
# Set a placeholder path so the API can still function
self.temp_path = f"unavailable_{self.format}"
self.download_path = f"/download/{self.temp_path}"
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit"""
try:
if self.temp_file and not self._finalized:
await self.temp_file.close()
self._finalized = True
except Exception as e:
logger.error(f"Error closing temp file: {e}")
self._write_error = True
async def write(self, chunk: bytes) -> None:
"""Write a chunk of audio data
Args:
chunk: Audio data bytes to write
"""
if self._finalized:
raise RuntimeError("Cannot write to finalized temp file")
# Skip writing if we've already encountered an error
if self._write_error or not self.temp_file:
return
try:
await self.temp_file.write(chunk)
await self.temp_file.flush()
except Exception as e:
# Handle permission issues or other errors gracefully
logger.error(f"Failed to write to temp file: {e}")
self._write_error = True
async def finalize(self) -> str:
"""Close temp file and return download path
Returns:
Path to use for downloading the temp file
"""
if self._finalized:
raise RuntimeError("Temp file already finalized")
# Skip finalizing if we've already encountered an error
if self._write_error or not self.temp_file:
self._finalized = True
return self.download_path
try:
await self.temp_file.close()
self._finalized = True
except Exception as e:
# Handle permission issues or other errors gracefully
logger.error(f"Failed to finalize temp file: {e}")
self._write_error = True
self._finalized = True
return self.download_path
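A usage sketch for the temp writer; it needs an asyncio event loop and the chunk bytes are placeholders.

import asyncio

async def save_chunks(chunks):
    # Hypothetical example: persist streamed audio bytes and report the download path.
    async with TempFileWriter("mp3") as temp_writer:
        for chunk in chunks:
            await temp_writer.write(chunk)
        download_path = await temp_writer.finalize()
    return download_path

print(asyncio.run(save_chunks([b"chunk-one", b"chunk-two"])))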

View file

@ -1,21 +1,13 @@
"""Text processing pipeline."""
from .normalizer import normalize_text
from .phonemizer import phonemize
from .text_processor import process_text_chunk, smart_split
from .vocabulary import tokenize
def process_text(text: str) -> list[int]:
"""Process text into token IDs (for backward compatibility)."""
return process_text_chunk(text)
from .phonemizer import EspeakBackend, PhonemizerBackend, phonemize
from .vocabulary import VOCAB, tokenize, decode_tokens
__all__ = [
"normalize_text",
"phonemize",
"tokenize",
"process_text",
"process_text_chunk",
"smart_split",
"decode_tokens",
"VOCAB",
"PhonemizerBackend",
"EspeakBackend",
]

View file

@ -0,0 +1,53 @@
"""Text chunking service"""
import re
from ...core.config import settings
def split_text(text: str, max_chunk=None):
"""Split text into chunks on natural pause points
Args:
text: Text to split into chunks
max_chunk: Maximum chunk size (defaults to settings.max_chunk_size)
"""
if max_chunk is None:
max_chunk = settings.max_chunk_size
if not isinstance(text, str):
text = str(text) if text is not None else ""
text = text.strip()
if not text:
return
# First split into sentences
sentences = re.split(r"(?<=[.!?])\s+", text)
for sentence in sentences:
sentence = sentence.strip()
if not sentence:
continue
# For medium-length sentences, split on punctuation
if len(sentence) > max_chunk: # Lower threshold for more consistent sizes
# First try splitting on semicolons and colons
parts = re.split(r"(?<=[;:])\s+", sentence)
for part in parts:
part = part.strip()
if not part:
continue
# If part is still long, split on commas
if len(part) > max_chunk:
subparts = re.split(r"(?<=,)\s+", part)
for subpart in subparts:
subpart = subpart.strip()
if subpart:
yield subpart
else:
yield part
else:
yield sentence
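Example of the generator in use, with max_chunk lowered so the sentence, clause, and comma splits are visible:

text = (
    "Kokoro reads long passages. It splits them at sentence boundaries; "
    "very long sentences fall back to colons, semicolons, and commas."
)
for chunk in split_text(text, max_chunk=40):
    print(repr(chunk))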

View file

@ -4,17 +4,8 @@ Handles various text formats including URLs, emails, numbers, money, and special
Converts them into a format suitable for text-to-speech processing.
"""
import math
import re
from functools import lru_cache
from typing import List, Optional, Union
import inflect
from numpy import number
from text_to_num import text2num
from torch import mul
from ...structures.schemas import NormalizationOptions
# Constants
VALID_TLDS = [
@ -57,85 +48,8 @@ VALID_TLDS = [
"uk",
"us",
"io",
"co",
]
VALID_UNITS = {
"m": "meter",
"cm": "centimeter",
"mm": "millimeter",
"km": "kilometer",
"in": "inch",
"ft": "foot",
"yd": "yard",
"mi": "mile", # Length
"g": "gram",
"kg": "kilogram",
"mg": "milligram", # Mass
"s": "second",
"ms": "millisecond",
"min": "minutes",
"h": "hour", # Time
"l": "liter",
"ml": "mililiter",
"cl": "centiliter",
"dl": "deciliter", # Volume
"kph": "kilometer per hour",
"mph": "mile per hour",
"mi/h": "mile per hour",
"m/s": "meter per second",
"km/h": "kilometer per hour",
"mm/s": "milimeter per second",
"cm/s": "centimeter per second",
"ft/s": "feet per second",
"cm/h": "centimeter per day", # Speed
"°c": "degree celsius",
"c": "degree celsius",
"°f": "degree fahrenheit",
"f": "degree fahrenheit",
"k": "kelvin", # Temperature
"pa": "pascal",
"kpa": "kilopascal",
"mpa": "megapascal",
"atm": "atmosphere", # Pressure
"hz": "hertz",
"khz": "kilohertz",
"mhz": "megahertz",
"ghz": "gigahertz", # Frequency
"v": "volt",
"kv": "kilovolt",
"mv": "mergavolt", # Voltage
"a": "amp",
"ma": "megaamp",
"ka": "kiloamp", # Current
"w": "watt",
"kw": "kilowatt",
"mw": "megawatt", # Power
"j": "joule",
"kj": "kilojoule",
"mj": "megajoule", # Energy
"Ω": "ohm",
"": "kiloohm",
"": "megaohm", # Resistance (Ohm)
"f": "farad",
"µf": "microfarad",
"nf": "nanofarad",
"pf": "picofarad", # Capacitance
"b": "bit",
"kb": "kilobit",
"mb": "megabit",
"gb": "gigabit",
"tb": "terabit",
"pb": "petabit", # Data size
"kbps": "kilobit per second",
"mbps": "megabit per second",
"gbps": "gigabit per second",
"tbps": "terabit per second",
"px": "pixel", # CSS units
}
MONEY_UNITS = {"$": ("dollar", "cent"), "£": ("pound", "pence"), "€": ("euro", "cent")}
# Pre-compiled regex patterns for performance
EMAIL_PATTERN = re.compile(
r"\b[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-z]{2,}\b", re.IGNORECASE
@ -147,133 +61,50 @@ URL_PATTERN = re.compile(
re.IGNORECASE,
)
UNIT_PATTERN = re.compile(
r"((?<!\w)([+-]?)(\d{1,3}(,\d{3})*|\d+)(\.\d+)?)\s*("
+ "|".join(sorted(list(VALID_UNITS.keys()), reverse=True))
+ r"""){1}(?=[^\w\d]{1}|\b)""",
re.IGNORECASE,
)
TIME_PATTERN = re.compile(
r"([0-9]{1,2} ?: ?[0-9]{2}( ?: ?[0-9]{2})?)( ?(pm|am)\b)?", re.IGNORECASE
)
MONEY_PATTERN = re.compile(
r"(-?)(["
+ "".join(MONEY_UNITS.keys())
+ r"])(\d+(?:\.\d+)?)((?: hundred| thousand| (?:[bm]|tr|quadr)illion|k|m|b|t)*)\b",
re.IGNORECASE,
)
NUMBER_PATTERN = re.compile(
r"(-?)(\d+(?:\.\d+)?)((?: hundred| thousand| (?:[bm]|tr|quadr)illion|k|m|b)*)\b",
re.IGNORECASE,
)
INFLECT_ENGINE = inflect.engine()
def handle_units(u: re.Match[str]) -> str:
"""Converts units to their full form"""
unit_string = u.group(6).strip()
unit = unit_string
if unit_string.lower() in VALID_UNITS:
unit = VALID_UNITS[unit_string.lower()].split(" ")
# Handles the B vs b case
if unit[0].endswith("bit"):
b_case = unit_string[min(1, len(unit_string) - 1)]
if b_case == "B":
unit[0] = unit[0][:-3] + "byte"
number = u.group(1).strip()
unit[0] = INFLECT_ENGINE.no(unit[0], number)
return " ".join(unit)
def conditional_int(number: float, threshold: float = 0.00001):
if abs(round(number) - number) < threshold:
return int(round(number))
return number
def translate_multiplier(multiplier: str) -> str:
"""Translate multiplier abrevations to words"""
multiplier_translation = {
"k": "thousand",
"m": "million",
"b": "billion",
"t": "trillion",
}
if multiplier.lower() in multiplier_translation:
return multiplier_translation[multiplier.lower()]
return multiplier.strip()
def split_four_digit(number: float):
part1 = str(conditional_int(number))[:2]
part2 = str(conditional_int(number))[2:]
return f"{INFLECT_ENGINE.number_to_words(part1)} {INFLECT_ENGINE.number_to_words(part2)}"
def handle_numbers(n: re.Match[str]) -> str:
number = n.group(2)
try:
number = float(number)
except:
return n.group()
if n.group(1) == "-":
number *= -1
multiplier = translate_multiplier(n.group(3))
number = conditional_int(number)
if multiplier != "":
multiplier = f" {multiplier}"
else:
if (
number % 1 == 0
and len(str(number)) == 4
and number > 1500
and number % 1000 > 9
):
return split_four_digit(number)
return f"{INFLECT_ENGINE.number_to_words(number)}{multiplier}"
def split_num(num: re.Match[str]) -> str:
"""Handle number splitting for various formats"""
num = num.group()
if "." in num:
return num
elif ":" in num:
h, m = [int(n) for n in num.split(":")]
if m == 0:
return f"{h} o'clock"
elif m < 10:
return f"{h} oh {m}"
return f"{h} {m}"
year = int(num[:4])
if year < 1100 or year % 1000 < 10:
return num
left, right = num[:2], int(num[2:4])
s = "s" if num.endswith("s") else ""
if 100 <= year % 1000 <= 999:
if right == 0:
return f"{left} hundred{s}"
elif right < 10:
return f"{left} oh {right}{s}"
return f"{left} {right}{s}"
def handle_money(m: re.Match[str]) -> str:
"""Convert money expressions to spoken form"""
bill, coin = MONEY_UNITS[m.group(2)]
number = m.group(3)
try:
number = float(number)
except:
return m.group()
if m.group(1) == "-":
number *= -1
multiplier = translate_multiplier(m.group(4))
if multiplier != "":
multiplier = f" {multiplier}"
if number % 1 == 0 or multiplier != "":
text_number = f"{INFLECT_ENGINE.number_to_words(conditional_int(number))}{multiplier} {INFLECT_ENGINE.plural(bill, count=number)}"
else:
sub_number = int(str(number).split(".")[-1].ljust(2, "0"))
text_number = f"{INFLECT_ENGINE.number_to_words(int(math.floor(number)))} {INFLECT_ENGINE.plural(bill, count=number)} and {INFLECT_ENGINE.number_to_words(sub_number)} {INFLECT_ENGINE.plural(coin, count=sub_number)}"
return text_number
m = m.group()
bill = "dollar" if m[0] == "$" else "pound"
if m[-1].isalpha():
return f"{m[1:]} {bill}s"
elif "." not in m:
s = "" if m[1:] == "1" else "s"
return f"{m[1:]} {bill}{s}"
b, c = m[1:].split(".")
s = "" if b == "1" else "s"
c = int(c.ljust(2, "0"))
coins = (
f"cent{'' if c == 1 else 's'}"
if m[0] == "$"
else ("penny" if c == 1 else "pence")
)
return f"{b} {bill}{s} and {c} {coins}"
def handle_decimal(num: re.Match[str]) -> str:
@ -340,96 +171,32 @@ def handle_url(u: re.Match[str]) -> str:
return re.sub(r"\s+", " ", url).strip()
def handle_phone_number(p: re.Match[str]) -> str:
p = list(p.groups())
country_code = ""
if p[0] is not None:
p[0] = p[0].replace("+", "")
country_code += INFLECT_ENGINE.number_to_words(p[0])
area_code = INFLECT_ENGINE.number_to_words(
p[2].replace("(", "").replace(")", ""), group=1, comma=""
)
telephone_prefix = INFLECT_ENGINE.number_to_words(p[3], group=1, comma="")
line_number = INFLECT_ENGINE.number_to_words(p[4], group=1, comma="")
return ",".join([country_code, area_code, telephone_prefix, line_number])
def handle_time(t: re.Match[str]) -> str:
t = t.groups()
time_parts = t[0].split(":")
numbers = []
numbers.append(INFLECT_ENGINE.number_to_words(time_parts[0].strip()))
minute_number = INFLECT_ENGINE.number_to_words(time_parts[1].strip())
if int(time_parts[1]) < 10:
if int(time_parts[1]) != 0:
numbers.append(f"oh {minute_number}")
else:
numbers.append(minute_number)
half = ""
if len(time_parts) > 2:
seconds_number = INFLECT_ENGINE.number_to_words(time_parts[2].strip())
second_word = INFLECT_ENGINE.plural("second", int(time_parts[2].strip()))
numbers.append(f"and {seconds_number} {second_word}")
else:
if t[2] is not None:
half = " " + t[2].strip()
else:
if int(time_parts[1]) == 0:
numbers.append("o'clock")
return " ".join(numbers) + half
def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
"""Normalize text for TTS processing"""
# Handle email addresses first if enabled
if normalization_options.email_normalization:
def normalize_urls(text: str) -> str:
"""Pre-process URLs before other text normalization"""
# Handle email addresses first
text = EMAIL_PATTERN.sub(handle_email, text)
# Handle URLs if enabled
if normalization_options.url_normalization:
# Handle URLs
text = URL_PATTERN.sub(handle_url, text)
# Pre-process numbers with units if enabled
if normalization_options.unit_normalization:
text = UNIT_PATTERN.sub(handle_units, text)
return text
# Replace optional pluralization
if normalization_options.optional_pluralization_normalization:
text = re.sub(r"\(s\)", "s", text)
# Replace phone numbers:
if normalization_options.phone_normalization:
text = re.sub(
r"(\+?\d{1,2})?([ .-]?)(\(?\d{3}\)?)[\s.-](\d{3})[\s.-](\d{4})",
handle_phone_number,
text,
)
def normalize_text(text: str) -> str:
"""Normalize text for TTS processing"""
# Pre-process URLs first
text = normalize_urls(text)
# Replace quotes and brackets
text = text.replace(chr(8216), "'").replace(chr(8217), "'")
text = text.replace("«", chr(8220)).replace("»", chr(8221))
text = text.replace(chr(8220), '"').replace(chr(8221), '"')
text = text.replace("(", "«").replace(")", "»")
# Handle CJK punctuation and some non standard chars
for a, b in zip("、。!,:;?", ",.!,:;?-"):
# Handle CJK punctuation
for a, b in zip("、。!,:;?", ",.!,:;?"):
text = text.replace(a, b + " ")
# Handle simple time in the format of HH:MM:SS (am/pm)
text = TIME_PATTERN.sub(
handle_time,
text,
)
# Clean up whitespace
text = re.sub(r"[^\S \n]", " ", text)
text = re.sub(r" +", " ", text)
@ -446,15 +213,15 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st
text = re.sub(r"(?i)\b(y)eah?\b", r"\1e'a", text)
# Handle numbers and money
text = re.sub(
r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)", split_num, text
)
text = re.sub(r"(?<=\d),(?=\d)", "", text)
text = MONEY_PATTERN.sub(
text = re.sub(
r"(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b",
handle_money,
text,
)
text = NUMBER_PATTERN.sub(handle_numbers, text)
text = re.sub(r"\d*\.\d+", handle_decimal, text)
# Handle various formatting
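To sanity-check these handlers end-to-end, a small driver sketch; it uses the newer normalize_text(text, normalization_options) signature, and the exact rendered strings are assumptions that depend on which version of the normalizer is active.

samples = [
    "The meeting is at 10:30 and the demo costs $5.25.",
    "Contact support@example.com or visit https://example.com/docs.",
    "The drive is 5km at 60mph.",
]
for s in samples:
    # Expected flavour of output (approximate): times, money, units, and URLs read out in words.
    print(normalize_text(s, NormalizationOptions()))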

View file

@ -5,8 +5,6 @@ import phonemizer
from .normalizer import normalize_text
phonemizers = {}
class PhonemizerBackend(ABC):
"""Abstract base class for phonemization backends"""
@ -36,7 +34,6 @@ class EspeakBackend(PhonemizerBackend):
self.backend = phonemizer.backend.EspeakBackend(
language=language, preserve_punctuation=True, with_stress=True
)
self.language = language
def phonemize(self, text: str) -> str:
@ -94,9 +91,8 @@ def phonemize(text: str, language: str = "a", normalize: bool = True) -> str:
Returns:
Phonemized text
"""
global phonemizers
if normalize:
text = normalize_text(text)
if language not in phonemizers:
phonemizers[language] = create_phonemizer(language)
return phonemizers[language].phonemize(text)
phonemizer = create_phonemizer(language)
return phonemizer.phonemize(text)
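Example usage of the phonemize helper; language "a" is the American English code used elsewhere in the codebase, espeak-ng must be installed, and the exact phoneme string depends on the espeak build.

# Works the same whether the backend is cached (newer code) or recreated per call (older code).
print(phonemize("Hello world!", language="a"))   # e.g. something like "həlˈoʊ wˈɜːld!"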

View file

@ -1,276 +0,0 @@
"""Unified text processing for TTS with smart chunking."""
import re
import time
from typing import AsyncGenerator, Dict, List, Tuple
from loguru import logger
from ...core.config import settings
from ...structures.schemas import NormalizationOptions
from .normalizer import normalize_text
from .phonemizer import phonemize
from .vocabulary import tokenize
# Pre-compiled regex patterns for performance
CUSTOM_PHONEMES = re.compile(r"(\[([^\]]|\n)*?\])(\(\/([^\/)]|\n)*?\/\))")
def process_text_chunk(
text: str, language: str = "a", skip_phonemize: bool = False
) -> List[int]:
"""Process a chunk of text through normalization, phonemization, and tokenization.
Args:
text: Text chunk to process
language: Language code for phonemization
skip_phonemize: If True, treat input as phonemes and skip normalization/phonemization
Returns:
List of token IDs
"""
start_time = time.time()
if skip_phonemize:
# Input is already phonemes, just tokenize
t0 = time.time()
tokens = tokenize(text)
t1 = time.time()
else:
# Normal text processing pipeline
t0 = time.time()
t1 = time.time()
t0 = time.time()
phonemes = phonemize(text, language, normalize=False) # Already normalized
t1 = time.time()
t0 = time.time()
tokens = tokenize(phonemes)
t1 = time.time()
total_time = time.time() - start_time
logger.debug(
f"Total processing took {total_time * 1000:.2f}ms for chunk: '{text[:50]}{'...' if len(text) > 50 else ''}'"
)
return tokens
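A quick sanity check of the full chunk pipeline (normalize, phonemize, tokenize); it requires espeak-ng, and the exact token IDs depend on the vocabulary file.

tokens = process_text_chunk("Hello there, world.", language="a")
print(len(tokens), tokens[:10])   # token count and a peek at the first IDs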
async def yield_chunk(
text: str, tokens: List[int], chunk_count: int
) -> Tuple[str, List[int]]:
"""Yield a chunk with consistent logging."""
logger.debug(
f"Yielding chunk {chunk_count}: '{text[:50]}{'...' if len(text) > 50 else ''}' ({len(tokens)} tokens)"
)
return text, tokens
def process_text(text: str, language: str = "a") -> List[int]:
"""Process text into token IDs.
Args:
text: Text to process
language: Language code for phonemization
Returns:
List of token IDs
"""
if not isinstance(text, str):
text = str(text) if text is not None else ""
text = text.strip()
if not text:
return []
return process_text_chunk(text, language)
def get_sentence_info(
text: str, custom_phenomes_list: Dict[str, str]
) -> List[Tuple[str, List[int], int]]:
"""Process all sentences and return info."""
sentences = re.split(r"([.!?;:])(?=\s|$)", text)
phoneme_length, min_value = len(custom_phenomes_list), 0
results = []
for i in range(0, len(sentences), 2):
sentence = sentences[i].strip()
for replaced in range(min_value, phoneme_length):
current_id = f"</|custom_phonemes_{replaced}|/>"
if current_id in sentence:
sentence = sentence.replace(
current_id, custom_phenomes_list.pop(current_id)
)
min_value += 1
punct = sentences[i + 1] if i + 1 < len(sentences) else ""
if not sentence:
continue
full = sentence + punct
tokens = process_text_chunk(full)
results.append((full, tokens, len(tokens)))
return results
def handle_custom_phonemes(s: re.Match[str], phenomes_list: Dict[str, str]) -> str:
latest_id = f"</|custom_phonemes_{len(phenomes_list)}|/>"
phenomes_list[latest_id] = s.group(0).strip()
return latest_id
async def smart_split(
text: str,
max_tokens: int = settings.absolute_max_tokens,
lang_code: str = "a",
normalization_options: NormalizationOptions = NormalizationOptions(),
) -> AsyncGenerator[Tuple[str, List[int]], None]:
"""Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens."""
start_time = time.time()
chunk_count = 0
logger.info(f"Starting smart split for {len(text)} chars")
custom_phoneme_list = {}
# Normalize text
if settings.advanced_text_normalization and normalization_options.normalize:
logger.debug(f"smart_split using lang_code: {lang_code}")
if lang_code in ["a", "b", "en-us", "en-gb"]:
text = CUSTOM_PHONEMES.sub(
lambda s: handle_custom_phonemes(s, custom_phoneme_list), text
)
text = normalize_text(text, normalization_options)
else:
logger.info(
"Skipping text normalization as it is only supported for english"
)
# Process all sentences
sentences = get_sentence_info(text, custom_phoneme_list)
current_chunk = []
current_tokens = []
current_count = 0
for sentence, tokens, count in sentences:
# Handle sentences that exceed max tokens
if count > max_tokens:
# Yield current chunk if any
if current_chunk:
chunk_text = " ".join(current_chunk)
chunk_count += 1
logger.debug(
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
)
yield chunk_text, current_tokens
current_chunk = []
current_tokens = []
current_count = 0
# Split long sentence on commas
clauses = re.split(r"([,])", sentence)
clause_chunk = []
clause_tokens = []
clause_count = 0
for j in range(0, len(clauses), 2):
clause = clauses[j].strip()
comma = clauses[j + 1] if j + 1 < len(clauses) else ""
if not clause:
continue
full_clause = clause + comma
tokens = process_text_chunk(full_clause)
count = len(tokens)
# If adding clause keeps us under max and not optimal yet
if (
clause_count + count <= max_tokens
and clause_count + count <= settings.target_max_tokens
):
clause_chunk.append(full_clause)
clause_tokens.extend(tokens)
clause_count += count
else:
# Yield clause chunk if we have one
if clause_chunk:
chunk_text = " ".join(clause_chunk)
chunk_count += 1
logger.debug(
f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)"
)
yield chunk_text, clause_tokens
clause_chunk = [full_clause]
clause_tokens = tokens
clause_count = count
# Don't forget last clause chunk
if clause_chunk:
chunk_text = " ".join(clause_chunk)
chunk_count += 1
logger.debug(
f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)"
)
yield chunk_text, clause_tokens
# Regular sentence handling
elif (
current_count >= settings.target_min_tokens
and current_count + count > settings.target_max_tokens
):
# If we have a good sized chunk and adding next sentence exceeds target,
# yield current chunk and start new one
chunk_text = " ".join(current_chunk)
chunk_count += 1
logger.info(
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
)
yield chunk_text, current_tokens
current_chunk = [sentence]
current_tokens = tokens
current_count = count
elif current_count + count <= settings.target_max_tokens:
# Keep building chunk while under target max
current_chunk.append(sentence)
current_tokens.extend(tokens)
current_count += count
elif (
current_count + count <= max_tokens
and current_count < settings.target_min_tokens
):
# Only exceed target max if we haven't reached minimum size yet
current_chunk.append(sentence)
current_tokens.extend(tokens)
current_count += count
else:
# Yield current chunk and start new one
if current_chunk:
chunk_text = " ".join(current_chunk)
chunk_count += 1
logger.info(
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
)
yield chunk_text, current_tokens
current_chunk = [sentence]
current_tokens = tokens
current_count = count
# Don't forget the last chunk
if current_chunk:
chunk_text = " ".join(current_chunk)
chunk_count += 1
logger.info(
f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
)
yield chunk_text, current_tokens
total_time = time.time() - start_time
logger.info(
f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks"
)
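To make the chunking flow above concrete, a brief consumer sketch for smart_split; the import path and the sample text are assumptions, and a running event loop is required because smart_split is an async generator.

# Hypothetical consumer of smart_split (illustrative, not part of the repository).
import asyncio

from api.src.services.text_processing.text_processor import smart_split  # assumed path

async def preview_chunks(text: str) -> None:
    async for chunk_text, tokens in smart_split(text):
        # Each yielded pair is the chunk's display text plus its token IDs,
        # kept between the configured target sizes where possible.
        print(f"{len(tokens):4d} tokens | {chunk_text[:60]!r}")

asyncio.run(preview_chunks("This is a long document. " * 100))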

View file

@ -0,0 +1,172 @@
import os
import threading
from abc import ABC, abstractmethod
from typing import List, Tuple
import numpy as np
import torch
from loguru import logger
from ..core.config import settings
class TTSBaseModel(ABC):
_instance = None
_lock = threading.Lock()
_device = None
VOICES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "voices")
@classmethod
async def setup(cls):
"""Initialize model and setup voices"""
with cls._lock:
# Set device
cuda_available = torch.cuda.is_available()
logger.info(f"CUDA available: {cuda_available}")
if cuda_available:
try:
# Test CUDA device
test_tensor = torch.zeros(1).cuda()
logger.info("CUDA test successful")
model_path = os.path.join(
settings.model_dir, settings.pytorch_model_path
)
cls._device = "cuda"
except Exception as e:
logger.error(f"CUDA test failed: {e}")
cls._device = "cpu"
else:
cls._device = "cpu"
model_path = os.path.join(settings.model_dir, settings.onnx_model_path)
logger.info(f"Initializing model on {cls._device}")
# Initialize model first
model = cls.initialize(settings.model_dir, model_path=model_path)
if model is None:
raise RuntimeError(f"Failed to initialize {cls._device.upper()} model")
cls._instance = model
# Setup voices directory
os.makedirs(cls.VOICES_DIR, exist_ok=True)
# Copy base voices to local directory
base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
if os.path.exists(base_voices_dir):
for file in os.listdir(base_voices_dir):
if file.endswith(".pt"):
voice_name = file[:-3]
voice_path = os.path.join(cls.VOICES_DIR, file)
if not os.path.exists(voice_path):
try:
logger.info(
f"Copying base voice {voice_name} to voices directory"
)
base_path = os.path.join(base_voices_dir, file)
voicepack = torch.load(
base_path,
map_location=cls._device,
weights_only=True,
)
torch.save(voicepack, voice_path)
except Exception as e:
logger.error(
f"Error copying voice {voice_name}: {str(e)}"
)
# Count voices in directory
voice_count = len(
[f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")]
)
# Now that model and voices are ready, do warmup
try:
with open(
os.path.join(
os.path.dirname(os.path.dirname(__file__)),
"core",
"don_quixote.txt",
)
) as f:
warmup_text = f.read()
except Exception as e:
logger.warning(f"Failed to load warmup text: {e}")
warmup_text = "This is a warmup text that will be split into chunks for processing."
# Use warmup service after model is fully initialized
from .warmup import WarmupService
warmup = WarmupService()
# Load and warm up voices
loaded_voices = warmup.load_voices()
await warmup.warmup_voices(warmup_text, loaded_voices)
logger.info("Model warm-up complete")
# Count voices in directory
voice_count = len(
[f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")]
)
return voice_count
@classmethod
@abstractmethod
def initialize(cls, model_dir: str, model_path: str = None):
"""Initialize the model"""
pass
@classmethod
@abstractmethod
def process_text(cls, text: str, language: str) -> Tuple[str, List[int]]:
"""Process text into phonemes and tokens
Args:
text: Input text
language: Language code
Returns:
tuple[str, list[int]]: Phonemes and token IDs
"""
pass
@classmethod
@abstractmethod
def generate_from_text(
cls, text: str, voicepack: torch.Tensor, language: str, speed: float
) -> Tuple[np.ndarray, str]:
"""Generate audio from text
Args:
text: Input text
voicepack: Voice tensor
language: Language code
speed: Speed factor
Returns:
tuple[np.ndarray, str]: Generated audio samples and phonemes
"""
pass
@classmethod
@abstractmethod
def generate_from_tokens(
cls, tokens: List[int], voicepack: torch.Tensor, speed: float
) -> np.ndarray:
"""Generate audio from tokens
Args:
tokens: Token IDs
voicepack: Voice tensor
speed: Speed factor
Returns:
np.ndarray: Generated audio samples
"""
pass
@classmethod
def get_device(cls):
"""Get the current device"""
if cls._device is None:
raise RuntimeError("Model not initialized. Call setup() first.")
return cls._device
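For orientation, a skeletal (hypothetical) subclass showing the contract the abstract methods above impose on a backend; the class name and placeholder return values are not from the repository.

# Hypothetical backend skeleton satisfying TTSBaseModel's abstract interface.
from typing import List, Optional, Tuple

import numpy as np
import torch

class DummyTTSModel(TTSBaseModel):
    @classmethod
    def initialize(cls, model_dir: str, model_path: Optional[str] = None):
        cls._instance = object()  # stand-in for a real model handle
        return cls._instance

    @classmethod
    def process_text(cls, text: str, language: str) -> Tuple[str, List[int]]:
        # A real backend would phonemize and tokenize here.
        return text, [0] * len(text)

    @classmethod
    def generate_from_text(cls, text, voicepack, language, speed) -> Tuple[np.ndarray, str]:
        phonemes, tokens = cls.process_text(text, language)
        return cls.generate_from_tokens(tokens, voicepack, speed), phonemes

    @classmethod
    def generate_from_tokens(cls, tokens: List[int], voicepack: torch.Tensor, speed: float) -> np.ndarray:
        # One second of silence at 24 kHz as placeholder audio.
        return np.zeros(24000, dtype=np.float32)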

165
api/src/services/tts_cpu.py Normal file
View file

@ -0,0 +1,165 @@
import os
import numpy as np
import torch
from loguru import logger
from onnxruntime import (
ExecutionMode,
SessionOptions,
InferenceSession,
GraphOptimizationLevel,
)
from .tts_base import TTSBaseModel
from ..core.config import settings
from .text_processing import tokenize, phonemize
class TTSCPUModel(TTSBaseModel):
_instance = None
_onnx_session = None
@classmethod
def get_instance(cls):
"""Get the model instance"""
if cls._onnx_session is None:
raise RuntimeError("ONNX model not initialized. Call initialize() first.")
return cls._onnx_session
@classmethod
def initialize(cls, model_dir: str, model_path: str = None):
"""Initialize ONNX model for CPU inference"""
if cls._onnx_session is None:
# Try loading ONNX model
onnx_path = os.path.join(model_dir, settings.onnx_model_path)
if os.path.exists(onnx_path):
logger.info(f"Loading ONNX model from {onnx_path}")
else:
logger.error(f"ONNX model not found at {onnx_path}")
return None
if not onnx_path:
return None
# Configure ONNX session for optimal performance
session_options = SessionOptions()
# Set optimization level
if settings.onnx_optimization_level == "all":
session_options.graph_optimization_level = (
GraphOptimizationLevel.ORT_ENABLE_ALL
)
elif settings.onnx_optimization_level == "basic":
session_options.graph_optimization_level = (
GraphOptimizationLevel.ORT_ENABLE_BASIC
)
else:
session_options.graph_optimization_level = (
GraphOptimizationLevel.ORT_DISABLE_ALL
)
# Configure threading
session_options.intra_op_num_threads = settings.onnx_num_threads
session_options.inter_op_num_threads = settings.onnx_inter_op_threads
# Set execution mode
session_options.execution_mode = (
ExecutionMode.ORT_PARALLEL
if settings.onnx_execution_mode == "parallel"
else ExecutionMode.ORT_SEQUENTIAL
)
# Enable/disable memory pattern optimization
session_options.enable_mem_pattern = settings.onnx_memory_pattern
# Configure CPU provider options
provider_options = {
"CPUExecutionProvider": {
"arena_extend_strategy": settings.onnx_arena_extend_strategy,
"cpu_memory_arena_cfg": "cpu:0",
}
}
session = InferenceSession(
onnx_path,
sess_options=session_options,
providers=["CPUExecutionProvider"],
provider_options=[provider_options],
)
cls._onnx_session = session
return session
return cls._onnx_session
@classmethod
def process_text(cls, text: str, language: str) -> tuple[str, list[int]]:
"""Process text into phonemes and tokens
Args:
text: Input text
language: Language code
Returns:
tuple[str, list[int]]: Phonemes and token IDs
"""
phonemes = phonemize(text, language)
tokens = tokenize(phonemes)
tokens = [0] + tokens + [0] # Add start/end tokens
return phonemes, tokens
@classmethod
def generate_from_text(
cls, text: str, voicepack: torch.Tensor, language: str, speed: float
) -> tuple[np.ndarray, str]:
"""Generate audio from text
Args:
text: Input text
voicepack: Voice tensor
language: Language code
speed: Speed factor
Returns:
tuple[np.ndarray, str]: Generated audio samples and phonemes
"""
if cls._onnx_session is None:
raise RuntimeError("ONNX model not initialized")
# Process text
phonemes, tokens = cls.process_text(text, language)
# Generate audio
audio = cls.generate_from_tokens(tokens, voicepack, speed)
return audio, phonemes
@classmethod
def generate_from_tokens(
cls, tokens: list[int], voicepack: torch.Tensor, speed: float
) -> np.ndarray:
"""Generate audio from tokens
Args:
tokens: Token IDs
voicepack: Voice tensor
speed: Speed factor
Returns:
np.ndarray: Generated audio samples
"""
if cls._onnx_session is None:
raise RuntimeError("ONNX model not initialized")
# Pre-allocate and prepare inputs
tokens_input = np.array([tokens], dtype=np.int64)
style_input = voicepack[
len(tokens) - 2
].numpy() # Already has correct dimensions
speed_input = np.full(
1, speed, dtype=np.float32
) # More efficient than ones * speed
# Run inference with optimized inputs
result = cls._onnx_session.run(
None, {"tokens": tokens_input, "style": style_input, "speed": speed_input}
)
return result[0]
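A hedged end-to-end sketch of the CPU path defined above; the model directory, voice file, and output path are assumptions, and the 24 kHz sample rate matches the rate used elsewhere in this diff.

# Illustrative CPU inference walk-through (paths and voice file are assumptions).
import scipy.io.wavfile as wavfile
import torch

session = TTSCPUModel.initialize(model_dir="models", model_path=None)
if session is None:
    raise SystemExit("ONNX model not found; check settings.onnx_model_path")

voicepack = torch.load("voices/af.pt", map_location="cpu", weights_only=True)
audio, phonemes = TTSCPUModel.generate_from_text(
    "Hello from the CPU backend.", voicepack, language="a", speed=1.0
)
wavfile.write("output.wav", 24000, audio)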

262
api/src/services/tts_gpu.py Normal file
View file

@ -0,0 +1,262 @@
import os
import time
import numpy as np
import torch
from loguru import logger
from models import build_model
from .tts_base import TTSBaseModel
from ..core.config import settings
from .text_processing import tokenize, phonemize
# @torch.no_grad()
# def forward(model, tokens, ref_s, speed):
# """Forward pass through the model"""
# device = ref_s.device
# tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
# input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
# text_mask = length_to_mask(input_lengths).to(device)
# bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
# d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
# s = ref_s[:, 128:]
# d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
# x, _ = model.predictor.lstm(d)
# duration = model.predictor.duration_proj(x)
# duration = torch.sigmoid(duration).sum(axis=-1) / speed
# pred_dur = torch.round(duration).clamp(min=1).long()
# pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
# c_frame = 0
# for i in range(pred_aln_trg.size(0)):
# pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
# c_frame += pred_dur[0, i].item()
# en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
# F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
# t_en = model.text_encoder(tokens, input_lengths, text_mask)
# asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
# return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
@torch.no_grad()
def forward(model, tokens, ref_s, speed):
"""Forward pass through the model with moderate memory management"""
device = ref_s.device
try:
# Initial tensor setup with proper device placement
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
text_mask = length_to_mask(input_lengths).to(device)
# Split and clone reference signals with explicit device placement
s_content = ref_s[:, 128:].clone().to(device)
s_ref = ref_s[:, :128].clone().to(device)
# BERT and encoder pass
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
# Predictor forward pass
d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
x, _ = model.predictor.lstm(d)
# Duration prediction
duration = model.predictor.duration_proj(x)
duration = torch.sigmoid(duration).sum(axis=-1) / speed
pred_dur = torch.round(duration).clamp(min=1).long()
# Only cleanup large intermediates
del duration, x
# Alignment matrix construction
pred_aln_trg = torch.zeros(input_lengths.item(), pred_dur.sum().item(), device=device)
c_frame = 0
for i in range(pred_aln_trg.size(0)):
pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
c_frame += pred_dur[0, i].item()
pred_aln_trg = pred_aln_trg.unsqueeze(0)
# Matrix multiplications with selective cleanup
en = d.transpose(-1, -2) @ pred_aln_trg
del d # Free large intermediate tensor
F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
del en # Free large intermediate tensor
# Final text encoding and decoding
t_en = model.text_encoder(tokens, input_lengths, text_mask)
asr = t_en @ pred_aln_trg
del t_en # Free large intermediate tensor
# Final decoding and transfer to CPU
output = model.decoder(asr, F0_pred, N_pred, s_ref)
result = output.squeeze().cpu().numpy()
return result
finally:
# Let PyTorch handle most cleanup automatically
# Only explicitly free the largest tensors
del pred_aln_trg, asr
# def length_to_mask(lengths):
# """Create attention mask from lengths"""
# mask = (
# torch.arange(lengths.max())
# .unsqueeze(0)
# .expand(lengths.shape[0], -1)
# .type_as(lengths)
# )
# mask = torch.gt(mask + 1, lengths.unsqueeze(1))
# return mask
def length_to_mask(lengths):
"""Create attention mask from lengths - possibly optimized version"""
max_len = lengths.max()
# Create mask directly on the same device as lengths
mask = torch.arange(max_len, device=lengths.device)[None, :].expand(
lengths.shape[0], -1
)
# Avoid type_as by using the correct dtype from the start
if lengths.dtype != mask.dtype:
mask = mask.to(dtype=lengths.dtype)
# Fuse operations using broadcasting
return mask + 1 > lengths[:, None]
class TTSGPUModel(TTSBaseModel):
_instance = None
_device = "cuda"
@classmethod
def get_instance(cls):
"""Get the model instance"""
if cls._instance is None:
raise RuntimeError("GPU model not initialized. Call initialize() first.")
return cls._instance
@classmethod
def initialize(cls, model_dir: str, model_path: str):
"""Initialize PyTorch model for GPU inference"""
if cls._instance is None and torch.cuda.is_available():
try:
logger.info("Initializing GPU model")
model_path = os.path.join(model_dir, settings.pytorch_model_path)
model = build_model(model_path, cls._device)
cls._instance = model
return model
except Exception as e:
logger.error(f"Failed to initialize GPU model: {e}")
return None
return cls._instance
@classmethod
def process_text(cls, text: str, language: str) -> tuple[str, list[int]]:
"""Process text into phonemes and tokens
Args:
text: Input text
language: Language code
Returns:
tuple[str, list[int]]: Phonemes and token IDs
"""
phonemes = phonemize(text, language)
tokens = tokenize(phonemes)
return phonemes, tokens
@classmethod
def generate_from_text(
cls, text: str, voicepack: torch.Tensor, language: str, speed: float
) -> tuple[np.ndarray, str]:
"""Generate audio from text
Args:
text: Input text
voicepack: Voice tensor
language: Language code
speed: Speed factor
Returns:
tuple[np.ndarray, str]: Generated audio samples and phonemes
"""
if cls._instance is None:
raise RuntimeError("GPU model not initialized")
# Process text
phonemes, tokens = cls.process_text(text, language)
# Generate audio
audio = cls.generate_from_tokens(tokens, voicepack, speed)
return audio, phonemes
@classmethod
def generate_from_tokens(
cls, tokens: list[int], voicepack: torch.Tensor, speed: float
) -> np.ndarray:
"""Generate audio from tokens with moderate memory management
Args:
tokens: Token IDs
voicepack: Voice tensor
speed: Speed factor
Returns:
np.ndarray: Generated audio samples
"""
if cls._instance is None:
raise RuntimeError("GPU model not initialized")
try:
device = cls._device
# Check memory pressure
if torch.cuda.is_available():
memory_allocated = torch.cuda.memory_allocated(device) / 1e9 # Convert to GB
if memory_allocated > 2.0: # 2GB limit
logger.info(
f"Memory usage above 2GB threshold: {memory_allocated:.2f}GB; "
f"clearing cache"
)
torch.cuda.empty_cache()
import gc
gc.collect()
# Get reference style with proper device placement
ref_s = voicepack[len(tokens)].clone().to(device)
# Generate audio
audio = forward(cls._instance, tokens, ref_s, speed)
return audio
except RuntimeError as e:
if "out of memory" in str(e):
# On OOM, do a full cleanup and retry
if torch.cuda.is_available():
logger.warning("Out of memory detected, performing full cleanup")
torch.cuda.synchronize()
torch.cuda.empty_cache()
import gc
gc.collect()
# Log memory stats after cleanup
memory_allocated = torch.cuda.memory_allocated(device)
memory_reserved = torch.cuda.memory_reserved(device)
logger.info(
f"Memory after OOM cleanup: "
f"Allocated: {memory_allocated / 1e9:.2f}GB, "
f"Reserved: {memory_reserved / 1e9:.2f}GB"
)
# Retry generation
ref_s = voicepack[len(tokens)].clone().to(device)
audio = forward(cls._instance, tokens, ref_s, speed)
return audio
raise
finally:
# Only synchronize at the top level, no empty_cache
if torch.cuda.is_available():
torch.cuda.synchronize()
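A quick, illustrative check of the length_to_mask helper above: padded positions come back True (masked), valid positions False. This is not a repository test, just a sketch of the expected behaviour.

# Illustrative check of length_to_mask semantics.
import torch

lengths = torch.tensor([3, 5])
mask = length_to_mask(lengths)
# Expected:
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])
assert mask.shape == (2, 5)
assert mask[0].tolist() == [False, False, False, True, True]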

View file

@ -0,0 +1,8 @@
import torch
if torch.cuda.is_available():
from .tts_gpu import TTSGPUModel as TTSModel
else:
from .tts_cpu import TTSCPUModel as TTSModel
__all__ = ["TTSModel"]

View file

@ -1,459 +1,269 @@
"""TTS service using model and voice managers."""
import asyncio
import io
import os
import re
import tempfile
import time
from typing import AsyncGenerator, List, Optional, Tuple, Union
from typing import List, Tuple, Optional
from functools import lru_cache
import numpy as np
import torch
from kokoro import KPipeline
import aiofiles.os
import scipy.io.wavfile as wavfile
from loguru import logger
from .audio import AudioService, AudioNormalizer
from .tts_model import TTSModel
from ..core.config import settings
from ..inference.base import AudioChunk
from ..inference.kokoro_v1 import KokoroV1
from ..inference.model_manager import get_manager as get_model_manager
from ..inference.voice_manager import get_manager as get_voice_manager
from ..structures.schemas import NormalizationOptions
from .audio import AudioNormalizer, AudioService
from .streaming_audio_writer import StreamingAudioWriter
from .text_processing import tokenize
from .text_processing.text_processor import process_text_chunk, smart_split
from .text_processing import chunker, normalize_text
class TTSService:
"""Text-to-speech service."""
# Limit concurrent chunk processing
_chunk_semaphore = asyncio.Semaphore(4)
def __init__(self, output_dir: str = None):
"""Initialize service."""
self.output_dir = output_dir
self.model_manager = None
self._voice_manager = None
self.model = TTSModel.get_instance()
@classmethod
async def create(cls, output_dir: str = None) -> "TTSService":
"""Create and initialize TTSService instance."""
service = cls(output_dir)
service.model_manager = await get_model_manager()
service._voice_manager = await get_voice_manager()
return service
@staticmethod
@lru_cache(maxsize=3) # Cache up to 3 most recently used voices
def _load_voice(voice_path: str) -> torch.Tensor:
"""Load and cache a voice model"""
return torch.load(
voice_path, map_location=TTSModel.get_device(), weights_only=True
)
def _get_voice_path(self, voice_name: str) -> Optional[str]:
"""Get the path to a voice file"""
voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice_name}.pt")
return voice_path if os.path.exists(voice_path) else None
def _generate_audio(
self, text: str, voice: str, speed: float, stitch_long_output: bool = True
) -> Tuple[torch.Tensor, float]:
"""Generate complete audio and return with processing time"""
audio, processing_time = self._generate_audio_internal(
text, voice, speed, stitch_long_output
)
return audio, processing_time
def _generate_audio_internal(
self, text: str, voice: str, speed: float, stitch_long_output: bool = True
) -> Tuple[torch.Tensor, float]:
"""Generate audio and measure processing time"""
start_time = time.time()
async def _process_chunk(
self,
chunk_text: str,
tokens: List[int],
voice_name: str,
voice_path: str,
speed: float,
writer: StreamingAudioWriter,
output_format: Optional[str] = None,
is_first: bool = False,
is_last: bool = False,
normalizer: Optional[AudioNormalizer] = None,
lang_code: Optional[str] = None,
return_timestamps: Optional[bool] = False,
) -> AsyncGenerator[AudioChunk, None]:
"""Process tokens into audio."""
async with self._chunk_semaphore:
try:
# Handle stream finalization
if is_last:
# Skip format conversion for raw audio mode
if not output_format:
yield AudioChunk(np.array([], dtype=np.int16), output=b"")
return
chunk_data = await AudioService.convert_audio(
AudioChunk(
np.array([], dtype=np.float32)
), # Dummy data for type checking
output_format,
writer,
speed,
"",
normalizer=normalizer,
is_last_chunk=True,
)
yield chunk_data
return
# Normalize text once at the start
if not text:
raise ValueError("Text is empty after preprocessing")
normalized = normalize_text(text)
if not normalized:
raise ValueError("Text is empty after preprocessing")
text = str(normalized)
# Skip empty chunks
if not tokens and not chunk_text:
return
# Check voice exists
voice_path = self._get_voice_path(voice)
if not voice_path:
raise ValueError(f"Voice not found: {voice}")
# Get backend
backend = self.model_manager.get_backend()
# Load voice using cached loader
voicepack = self._load_voice(voice_path)
# Generate audio using pre-warmed model
if isinstance(backend, KokoroV1):
chunk_index = 0
# For Kokoro V1, pass text and voice info with lang_code
async for chunk_data in self.model_manager.generate(
chunk_text,
(voice_name, voice_path),
speed=speed,
lang_code=lang_code,
return_timestamps=return_timestamps,
):
# For streaming, convert to bytes
if output_format:
# For non-streaming, preprocess all chunks first
if stitch_long_output:
# Preprocess all chunks to phonemes/tokens
chunks_data = []
for chunk in chunker.split_text(text):
try:
chunk_data = await AudioService.convert_audio(
chunk_data,
output_format,
writer,
speed,
chunk_text,
is_last_chunk=is_last,
normalizer=normalizer,
)
yield chunk_data
phonemes, tokens = TTSModel.process_text(chunk, voice[0])
chunks_data.append((chunk, tokens))
except Exception as e:
logger.error(f"Failed to convert audio: {str(e)}")
else:
chunk_data = AudioService.trim_audio(
chunk_data, chunk_text, speed, is_last, normalizer
)
yield chunk_data
chunk_index += 1
else:
# For legacy backends, load voice tensor
voice_tensor = await self._voice_manager.load_voice(
voice_name, device=backend.device
)
chunk_data = await self.model_manager.generate(
tokens,
voice_tensor,
speed=speed,
return_timestamps=return_timestamps,
logger.error(
f"Failed to process chunk: '{chunk}'. Error: {str(e)}"
)
continue
if chunk_data.audio is None:
logger.error("Model generated None for audio chunk")
return
if not chunks_data:
raise ValueError("No chunks were processed successfully")
if len(chunk_data.audio) == 0:
logger.error("Model generated empty audio chunk")
return
# For streaming, convert to bytes
if output_format:
# Generate audio for all chunks
audio_chunks = []
for chunk, tokens in chunks_data:
try:
chunk_data = await AudioService.convert_audio(
chunk_data,
output_format,
writer,
speed,
chunk_text,
normalizer=normalizer,
is_last_chunk=is_last,
chunk_audio = TTSModel.generate_from_tokens(
tokens, voicepack, speed
)
yield chunk_data
except Exception as e:
logger.error(f"Failed to convert audio: {str(e)}")
if chunk_audio is not None:
audio_chunks.append(chunk_audio)
else:
trimmed = AudioService.trim_audio(
chunk_data, chunk_text, speed, is_last, normalizer
)
yield trimmed
logger.error(f"No audio generated for chunk: '{chunk}'")
except Exception as e:
logger.error(f"Failed to process tokens: {str(e)}")
logger.error(
f"Failed to generate audio for chunk: '{chunk}'. Error: {str(e)}"
)
continue
async def _load_voice_from_path(self, path: str, weight: float):
# Check if the path is None and raise a ValueError if it is not
if not path:
raise ValueError(f"Voice not found at path: {path}")
if not audio_chunks:
raise ValueError("No audio chunks were generated successfully")
logger.debug(f"Loading voice tensor from path: {path}")
return torch.load(path, map_location="cpu") * weight
async def _get_voices_path(self, voice: str) -> Tuple[str, str]:
"""Get voice path, handling combined voices.
Args:
voice: Voice name or combined voice names (e.g., 'af_jadzia+af_jessica')
Returns:
Tuple of (voice name to use, voice path to use)
Raises:
RuntimeError: If voice not found
"""
try:
# Split the voice on + and - so each part is kept in the list, e.g. "hi+bob" -> ["hi", "+", "bob"]
split_voice = re.split(r"([-+])", voice)
# If it is only one voice there is no point in loading it up, doing nothing with it, then saving it
if len(split_voice) == 1:
# Since it's a single voice, the only time the weight would matter is if voice_weight_normalization is off
if (
"(" not in voice and ")" not in voice
) or settings.voice_weight_normalization == True:
path = await self._voice_manager.get_voice_path(voice)
if not path:
raise RuntimeError(f"Voice not found: {voice}")
logger.debug(f"Using single voice path: {path}")
return voice, path
total_weight = 0
for voice_index in range(0, len(split_voice), 2):
voice_object = split_voice[voice_index]
if "(" in voice_object and ")" in voice_object:
voice_name = voice_object.split("(")[0].strip()
voice_weight = float(voice_object.split("(")[1].split(")")[0])
# Concatenate all chunks
audio = (
np.concatenate(audio_chunks)
if len(audio_chunks) > 1
else audio_chunks[0]
)
else:
voice_name = voice_object
voice_weight = 1
# Process single chunk
phonemes, tokens = TTSModel.process_text(text, voice[0])
audio = TTSModel.generate_from_tokens(tokens, voicepack, speed)
total_weight += voice_weight
split_voice[voice_index] = (voice_name, voice_weight)
processing_time = time.time() - start_time
return audio, processing_time
# If voice_weight_normalization is false, skip normalization by setting total_weight to 1 (so each weight is divided by 1)
if settings.voice_weight_normalization == False:
total_weight = 1
# Load the first voice as the starting point for voices to be combined onto
path = await self._voice_manager.get_voice_path(split_voice[0][0])
combined_tensor = await self._load_voice_from_path(
path, split_voice[0][1] / total_weight
)
# Loop through each + or - in split_voice so they can be applied to combined voice
for operation_index in range(1, len(split_voice) - 1, 2):
# Get the voice path of the voice 1 index ahead of the operator
path = await self._voice_manager.get_voice_path(
split_voice[operation_index + 1][0]
)
voice_tensor = await self._load_voice_from_path(
path, split_voice[operation_index + 1][1] / total_weight
)
# Either add or subtract the voice from the current combined voice
if split_voice[operation_index] == "+":
combined_tensor += voice_tensor
else:
combined_tensor -= voice_tensor
# Save the new combined voice so it can be loaded later
temp_dir = tempfile.gettempdir()
combined_path = os.path.join(temp_dir, f"{voice}.pt")
logger.debug(f"Saving combined voice to: {combined_path}")
torch.save(combined_tensor, combined_path)
return voice, combined_path
except Exception as e:
logger.error(f"Failed to get voice path: {e}")
logger.error(f"Error in audio generation: {str(e)}")
raise
async def generate_audio_stream(
self,
text: str,
voice: str,
writer: StreamingAudioWriter,
speed: float = 1.0,
speed: float,
output_format: str = "wav",
lang_code: Optional[str] = None,
normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
return_timestamps: Optional[bool] = False,
) -> AsyncGenerator[AudioChunk, None]:
"""Generate and stream audio chunks."""
silent=False,
):
"""Generate and yield audio chunks as they're generated for real-time streaming"""
try:
stream_start = time.time()
# Create normalizer for consistent audio levels
stream_normalizer = AudioNormalizer()
chunk_index = 0
current_offset = 0.0
try:
# Get backend
backend = self.model_manager.get_backend()
# Get voice path, handling combined voices
voice_name, voice_path = await self._get_voices_path(voice)
logger.debug(f"Using voice path: {voice_path}")
# Use provided lang_code or determine from voice name
pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
logger.info(
f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream"
# Input validation and preprocessing
if not text:
raise ValueError("Text is empty")
preprocess_start = time.time()
normalized = normalize_text(text)
if not normalized:
raise ValueError("Text is empty after preprocessing")
text = str(normalized)
logger.debug(
f"Text preprocessing took: {(time.time() - preprocess_start)*1000:.1f}ms"
)
# Process text in chunks with smart splitting
async for chunk_text, tokens in smart_split(
text,
lang_code=pipeline_lang_code,
normalization_options=normalization_options,
):
# Voice validation and loading
voice_start = time.time()
voice_path = self._get_voice_path(voice)
if not voice_path:
raise ValueError(f"Voice not found: {voice}")
voicepack = self._load_voice(voice_path)
logger.debug(
f"Voice loading took: {(time.time() - voice_start)*1000:.1f}ms"
)
# Process chunks as they're generated
is_first = True
chunks_processed = 0
# Process chunks as they come from generator
chunk_gen = chunker.split_text(text)
current_chunk = next(chunk_gen, None)
while current_chunk is not None:
next_chunk = next(chunk_gen, None) # Peek at next chunk
chunks_processed += 1
try:
# Process audio for chunk
async for chunk_data in self._process_chunk(
chunk_text, # Pass text for Kokoro V1
tokens, # Pass tokens for legacy backends
voice_name, # Pass voice name
voice_path, # Pass voice path
speed,
writer,
# Process text and generate audio
phonemes, tokens = TTSModel.process_text(current_chunk, voice[0])
chunk_audio = TTSModel.generate_from_tokens(
tokens, voicepack, speed
)
if chunk_audio is not None:
# Convert chunk with proper streaming header handling
chunk_bytes = AudioService.convert_audio(
chunk_audio,
24000,
output_format,
is_first=(chunk_index == 0),
is_last=False, # We'll update the last chunk later
is_first_chunk=is_first,
normalizer=stream_normalizer,
lang_code=pipeline_lang_code, # Pass lang_code
return_timestamps=return_timestamps,
):
if chunk_data.word_timestamps is not None:
for timestamp in chunk_data.word_timestamps:
timestamp.start_time += current_offset
timestamp.end_time += current_offset
current_offset += len(chunk_data.audio) / 24000
if chunk_data.output is not None:
yield chunk_data
else:
logger.warning(
f"No audio generated for chunk: '{chunk_text[:100]}...'"
is_last_chunk=(next_chunk is None), # Last if no next chunk
stream=True # Ensure proper streaming format handling
)
chunk_index += 1
yield chunk_bytes
is_first = False
else:
logger.error(f"No audio generated for chunk: '{current_chunk}'")
except Exception as e:
logger.error(
f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}"
f"Failed to generate audio for chunk: '{current_chunk}'. Error: {str(e)}"
)
continue
# Only finalize if we successfully processed at least one chunk
if chunk_index > 0:
try:
# Empty tokens list to finalize audio
async for chunk_data in self._process_chunk(
"", # Empty text
[], # Empty tokens
voice_name,
voice_path,
speed,
writer,
output_format,
is_first=False,
is_last=True, # Signal this is the last chunk
normalizer=stream_normalizer,
lang_code=pipeline_lang_code, # Pass lang_code
):
if chunk_data.output is not None:
yield chunk_data
except Exception as e:
logger.error(f"Failed to finalize audio stream: {str(e)}")
current_chunk = next_chunk # Move to next chunk
except Exception as e:
logger.error(f"Error in phoneme audio generation: {str(e)}")
raise e
async def generate_audio(
self,
text: str,
voice: str,
writer: StreamingAudioWriter,
speed: float = 1.0,
return_timestamps: bool = False,
normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
lang_code: Optional[str] = None,
) -> AudioChunk:
"""Generate complete audio for text using streaming internally."""
audio_data_chunks = []
try:
async for audio_stream_data in self.generate_audio_stream(
text,
voice,
writer,
speed=speed,
normalization_options=normalization_options,
return_timestamps=return_timestamps,
lang_code=lang_code,
output_format=None,
):
if len(audio_stream_data.audio) > 0:
audio_data_chunks.append(audio_stream_data)
combined_audio_data = AudioChunk.combine(audio_data_chunks)
return combined_audio_data
except Exception as e:
logger.error(f"Error in audio generation: {str(e)}")
logger.error(f"Error in audio generation stream: {str(e)}")
raise
async def combine_voices(self, voices: List[str]) -> torch.Tensor:
"""Combine multiple voices.
def _save_audio(self, audio: torch.Tensor, filepath: str):
"""Save audio to file"""
os.makedirs(os.path.dirname(filepath), exist_ok=True)
wavfile.write(filepath, 24000, audio)
Returns:
Combined voice tensor
"""
def _audio_to_bytes(self, audio: torch.Tensor) -> bytes:
"""Convert audio tensor to WAV bytes"""
buffer = io.BytesIO()
wavfile.write(buffer, 24000, audio)
return buffer.getvalue()
return await self._voice_manager.combine_voices(voices)
async def combine_voices(self, voices: List[str]) -> str:
"""Combine multiple voices into a new voice"""
if len(voices) < 2:
raise ValueError("At least 2 voices are required for combination")
# Load voices
t_voices: List[torch.Tensor] = []
v_name: List[str] = []
for voice in voices:
try:
voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice}.pt")
voicepack = torch.load(
voice_path, map_location=TTSModel.get_device(), weights_only=True
)
t_voices.append(voicepack)
v_name.append(voice)
except Exception as e:
raise ValueError(f"Failed to load voice {voice}: {str(e)}")
# Combine voices
try:
f: str = "_".join(v_name)
v = torch.mean(torch.stack(t_voices), dim=0)
combined_path = os.path.join(TTSModel.VOICES_DIR, f"{f}.pt")
# Save combined voice
try:
torch.save(v, combined_path)
except Exception as e:
raise RuntimeError(
f"Failed to save combined voice to {combined_path}: {str(e)}"
)
return f
except Exception as e:
if not isinstance(e, (ValueError, RuntimeError)):
raise RuntimeError(f"Error combining voices: {str(e)}")
raise
async def list_voices(self) -> List[str]:
"""List available voices."""
return await self._voice_manager.list_voices()
async def generate_from_phonemes(
self,
phonemes: str,
voice: str,
speed: float = 1.0,
lang_code: Optional[str] = None,
) -> Tuple[np.ndarray, float]:
"""Generate audio directly from phonemes.
Args:
phonemes: Phonemes in Kokoro format
voice: Voice name
speed: Speed multiplier
lang_code: Optional language code override
Returns:
Tuple of (audio array, processing time)
"""
start_time = time.time()
"""List all available voices"""
voices = []
try:
# Get backend and voice path
backend = self.model_manager.get_backend()
voice_name, voice_path = await self._get_voices_path(voice)
if isinstance(backend, KokoroV1):
# For Kokoro V1, use generate_from_tokens with raw phonemes
result = None
# Use provided lang_code or determine from voice name
pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
logger.info(
f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme pipeline"
)
try:
# Use backend's pipeline management
for r in backend._get_pipeline(
pipeline_lang_code
).generate_from_tokens(
tokens=phonemes, # Pass raw phonemes string
voice=voice_path,
speed=speed,
):
if r.audio is not None:
result = r
break
it = await aiofiles.os.scandir(TTSModel.VOICES_DIR)
for entry in it:
if entry.name.endswith(".pt"):
voices.append(entry.name[:-3]) # Remove .pt extension
except Exception as e:
logger.error(f"Failed to generate from phonemes: {e}")
raise RuntimeError(f"Phoneme generation failed: {e}")
if result is None or result.audio is None:
raise ValueError("No audio generated")
processing_time = time.time() - start_time
return result.audio.numpy(), processing_time
else:
raise ValueError(
"Phoneme generation only supported with Kokoro V1 backend"
)
except Exception as e:
logger.error(f"Error in phoneme audio generation: {str(e)}")
raise
logger.error(f"Error listing voices: {str(e)}")
return sorted(voices)
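To illustrate the combined-voice syntax parsed by _get_voices_path above, a hedged sketch; the voice names are placeholders and only methods defined in this file are used.

# Illustrative use of weighted voice combination (voice names are placeholders).
import asyncio

async def demo() -> None:
    service = await TTSService.create()
    # Explicit weights in parentheses: 70% af_jadzia blended with 30% af_jessica.
    name, path = await service._get_voices_path("af_jadzia(0.7)+af_jessica(0.3)")
    # A bare '+' averages the voices when voice_weight_normalization is enabled.
    name2, path2 = await service._get_voices_path("af_jadzia+af_jessica")
    print(name, path)

asyncio.run(demo())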

View file

@ -0,0 +1,60 @@
import os
from typing import List, Tuple
import torch
from loguru import logger
from .tts_model import TTSModel
from .tts_service import TTSService
from ..core.config import settings
class WarmupService:
"""Service for warming up TTS models and voice caches"""
def __init__(self):
"""Initialize warmup service and ensure model is ready"""
# Initialize model if not already initialized
if TTSModel._instance is None:
TTSModel.initialize(settings.model_dir)
self.tts_service = TTSService()
def load_voices(self) -> List[Tuple[str, torch.Tensor]]:
"""Load and cache voices up to LRU limit"""
# Get all voices sorted by filename length (shorter names first, usually base voices)
voice_files = sorted(
[f for f in os.listdir(TTSModel.VOICES_DIR) if f.endswith(".pt")], key=len
)
n_voices_cache = 1
loaded_voices = []
for voice_file in voice_files[:n_voices_cache]:
try:
voice_path = os.path.join(TTSModel.VOICES_DIR, voice_file)
# load using service, lru cache
voicepack = self.tts_service._load_voice(voice_path)
loaded_voices.append(
(voice_file[:-3], voicepack)
) # Store name and tensor
# voicepack = torch.load(voice_path, map_location=TTSModel.get_device(), weights_only=True)
# logger.info(f"Loaded voice {voice_file[:-3]} into cache")
except Exception as e:
logger.error(f"Failed to load voice {voice_file}: {e}")
logger.info(f"Pre-loaded {len(loaded_voices)} voices into cache")
return loaded_voices
async def warmup_voices(
self, warmup_text: str, loaded_voices: List[Tuple[str, torch.Tensor]]
):
"""Warm up voice inference and streaming"""
n_warmups = 1
for voice_name, _ in loaded_voices[:n_warmups]:
try:
logger.info(f"Running warmup inference on voice {voice_name}")
async for _ in self.tts_service.generate_audio_stream(
warmup_text, voice_name, 1.0, "pcm"
):
pass # Process all chunks to properly warm up
logger.info(f"Completed warmup for voice {voice_name}")
except Exception as e:
logger.warning(f"Warmup failed for voice {voice_name}: {e}")

View file

@ -1,17 +1,3 @@
from .schemas import (
CaptionedSpeechRequest,
CaptionedSpeechResponse,
OpenAISpeechRequest,
TTSStatus,
VoiceCombineRequest,
WordTimestamp,
)
from .schemas import OpenAISpeechRequest
__all__ = [
"OpenAISpeechRequest",
"CaptionedSpeechRequest",
"CaptionedSpeechResponse",
"WordTimestamp",
"TTSStatus",
"VoiceCombineRequest",
]
__all__ = ["OpenAISpeechRequest"]

View file

@ -1,50 +0,0 @@
import json
import typing
from collections.abc import AsyncIterable, Iterable
from pydantic import BaseModel
from starlette.background import BackgroundTask
from starlette.concurrency import iterate_in_threadpool
from starlette.responses import JSONResponse, StreamingResponse
class JSONStreamingResponse(StreamingResponse, JSONResponse):
"""StreamingResponse that also render with JSON."""
def __init__(
self,
content: Iterable | AsyncIterable,
status_code: int = 200,
headers: dict[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
) -> None:
if isinstance(content, AsyncIterable):
self._content_iterable: AsyncIterable = content
else:
self._content_iterable = iterate_in_threadpool(content)
async def body_iterator() -> AsyncIterable[bytes]:
async for content_ in self._content_iterable:
if isinstance(content_, BaseModel):
content_ = content_.model_dump()
yield self.render(content_)
self.body_iterator = body_iterator()
self.status_code = status_code
if media_type is not None:
self.media_type = media_type
self.background = background
self.init_headers(headers)
def render(self, content: typing.Any) -> bytes:
return (
json.dumps(
content,
ensure_ascii=False,
allow_nan=False,
indent=None,
separators=(",", ":"),
)
+ "\n"
).encode("utf-8")

View file

@ -1,16 +0,0 @@
"""Voice configuration schemas."""
from pydantic import BaseModel, Field
class VoiceConfig(BaseModel):
"""Voice configuration."""
use_cache: bool = Field(True, description="Whether to cache loaded voices")
cache_size: int = Field(3, description="Number of voices to cache")
validate_on_load: bool = Field(
True, description="Whether to validate voices when loading"
)
class Config:
frozen = True # Make config immutable

View file

@ -1,7 +1,7 @@
from enum import Enum
from typing import List, Literal, Optional, Union
from typing import List, Union, Literal
from pydantic import BaseModel, Field
from pydantic import Field, BaseModel
class VoiceCombineRequest(BaseModel):
@ -22,108 +22,11 @@ class TTSStatus(str, Enum):
# OpenAI-compatible schemas
class WordTimestamp(BaseModel):
"""Word-level timestamp information"""
word: str = Field(..., description="The word or token")
start_time: float = Field(..., description="Start time in seconds")
end_time: float = Field(..., description="End time in seconds")
class CaptionedSpeechResponse(BaseModel):
"""Response schema for captioned speech endpoint"""
audio: str = Field(..., description="The generated audio data encoded in base 64")
audio_format: str = Field(..., description="The format of the output audio")
timestamps: Optional[List[WordTimestamp]] = Field(
..., description="Word-level timestamps"
)
class NormalizationOptions(BaseModel):
"""Options for the normalization system"""
normalize: bool = Field(
default=True,
description="Normalizes input text to make it easier for the model to say",
)
unit_normalization: bool = Field(
default=False, description="Transforms units like 10KB to 10 kilobytes"
)
url_normalization: bool = Field(
default=True,
description="Changes urls so they can be properly pronounced by kokoro",
)
email_normalization: bool = Field(
default=True,
description="Changes emails so they can be properly pronouced by kokoro",
)
optional_pluralization_normalization: bool = Field(
default=True,
description="Replaces (s) with s so some words get pronounced correctly",
)
phone_normalization: bool = Field(
default=True,
description="Changes phone numbers so they can be properly pronouced by kokoro",
)
class OpenAISpeechRequest(BaseModel):
"""Request schema for OpenAI-compatible speech endpoint"""
model: str = Field(
default="kokoro",
description="The model to use for generation. Supported models: tts-1, tts-1-hd, kokoro",
)
model: Literal["tts-1", "tts-1-hd", "kokoro"] = "kokoro"
input: str = Field(..., description="The text to generate audio for")
voice: str = Field(
default="af_heart",
description="The voice to use for generation. Can be a base voice or a combined voice name.",
)
response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field(
default="mp3",
description="The format to return audio in. Supported formats: mp3, opus, flac, wav, pcm. PCM format returns raw 16-bit samples without headers. AAC is not currently supported.",
)
download_format: Optional[Literal["mp3", "opus", "aac", "flac", "wav", "pcm"]] = (
Field(
default=None,
description="Optional different format for the final download. If not provided, uses response_format.",
)
)
speed: float = Field(
default=1.0,
ge=0.25,
le=4.0,
description="The speed of the generated audio. Select a value from 0.25 to 4.0.",
)
stream: bool = Field(
default=True, # Default to streaming for OpenAI compatibility
description="If true (default), audio will be streamed as it's generated. Each chunk will be a complete sentence.",
)
return_download_link: bool = Field(
default=False,
description="If true, returns a download link in X-Download-Path header after streaming completes",
)
lang_code: Optional[str] = Field(
default=None,
description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
)
normalization_options: Optional[NormalizationOptions] = Field(
default=NormalizationOptions(),
description="Options for the normalization system",
)
class CaptionedSpeechRequest(BaseModel):
"""Request schema for captioned speech endpoint"""
model: str = Field(
default="kokoro",
description="The model to use for generation. Supported models: tts-1, tts-1-hd, kokoro",
)
input: str = Field(..., description="The text to generate audio for")
voice: str = Field(
default="af_heart",
default="af",
description="The voice to use for generation. Can be a base voice or a combined voice name.",
)
response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field(
@ -140,19 +43,3 @@ class CaptionedSpeechRequest(BaseModel):
default=True, # Default to streaming for OpenAI compatibility
description="If true (default), audio will be streamed as it's generated. Each chunk will be a complete sentence.",
)
return_timestamps: bool = Field(
default=True,
description="If true (default), returns word-level timestamps in the response",
)
return_download_link: bool = Field(
default=False,
description="If true, returns a download link in X-Download-Path header after streaming completes",
)
lang_code: Optional[str] = Field(
default=None,
description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
)
normalization_options: Optional[NormalizationOptions] = Field(
default=NormalizationOptions(),
description="Options for the normalization system",
)
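For reference, a request body matching the OpenAISpeechRequest fields above; the host, port, and route are assumptions based on the OpenAI-compatible naming and are not confirmed by this diff.

# Illustrative client call; the endpoint URL is an assumption.
import requests

payload = {
    "model": "kokoro",
    "input": "Hello there!",
    "voice": "af_heart",
    "response_format": "mp3",
    "speed": 1.0,
    "stream": False,
}
resp = requests.post("http://localhost:8880/v1/audio/speech", json=payload)
resp.raise_for_status()
with open("speech.mp3", "wb") as f:
    f.write(resp.content)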

View file

@ -1,6 +1,4 @@
from typing import List, Optional, Union
from pydantic import BaseModel, Field, field_validator
from pydantic import Field, BaseModel
class PhonemeRequest(BaseModel):
@ -13,29 +11,9 @@ class PhonemeResponse(BaseModel):
tokens: list[int]
class StitchOptions(BaseModel):
"""Options for stitching audio chunks together"""
gap_method: str = Field(
default="static_trim",
description="Method to handle gaps between chunks. Currently only 'static_trim' supported.",
)
trim_ms: int = Field(
default=0,
ge=0,
description="Milliseconds to trim from chunk boundaries when using static_trim",
)
@field_validator("gap_method")
@classmethod
def validate_gap_method(cls, v: str) -> str:
if v != "static_trim":
raise ValueError("Currently only 'static_trim' gap method is supported")
return v
class GenerateFromPhonemesRequest(BaseModel):
"""Simple request for phoneme-to-speech generation"""
phonemes: str = Field(..., description="Phoneme string to synthesize")
phonemes: str
voice: str = Field(..., description="Voice ID to use for generation")
speed: float = Field(
default=1.0, ge=0.1, le=5.0, description="Speed factor for generation"
)

Binary files not shown.

Some files were not shown because too many files have changed in this diff.