Mirror of https://github.com/remsky/Kokoro-FastAPI.git, synced 2025-08-05 16:48:53 +00:00

Merge pre-release: Update CI workflow for uv

commit 5045cf968e
69 changed files with 9202 additions and 762 deletions
@@ -7,6 +7,7 @@ omit =
     MagicMock/*
     test_*.py
     examples/*
+    src/builds/*

 [report]
 exclude_lines =
71 .github/workflows/ci.yml vendored
@@ -1,51 +1,32 @@
-# name: CI
+name: CI

-# on:
-#   push:
-#     branches: [ "develop", "master" ]
-#   pull_request:
-#     branches: [ "develop", "master" ]
+on:
+  push:
+    branches: [ "master", "pre-release" ]
+  pull_request:
+    branches: [ "master", "pre-release" ]

-# jobs:
-#   test:
-#     runs-on: ubuntu-latest
-#     strategy:
-#       matrix:
-#         python-version: ["3.9", "3.10", "3.11"]
-#       fail-fast: false
+jobs:
+  test:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.10"]
+      fail-fast: false

-# steps:
-#   - uses: actions/checkout@v4
+    steps:
+      - uses: actions/checkout@v4

-#   - name: Set up Python ${{ matrix.python-version }}
-#     uses: actions/setup-python@v5
-#     with:
-#       python-version: ${{ matrix.python-version }}

-#   - name: Set up pip cache
-#     uses: actions/cache@v3
-#     with:
-#       path: ~/.cache/pip
-#       key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements*.txt') }}
-#       restore-keys: |
-#         ${{ runner.os }}-pip-
+      - name: Install uv
+        uses: astral-sh/setup-uv@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+          enable-cache: true

-#   - name: Install PyTorch CPU
-#     run: |
-#       python -m pip install --upgrade pip
-#       pip install torch --index-url https://download.pytorch.org/whl/cpu
+      - name: Install dependencies
+        run: |
+          uv pip install -e .[test,cpu]

-#   - name: Install dependencies
-#     run: |
-#       pip install ruff pytest-cov
-#       pip install -r requirements.txt
-#       pip install -r requirements-test.txt

-#   - name: Lint with ruff
-#     run: |
-#       ruff check .

-#   - name: Test with pytest
-#     run: |
-#       pytest --asyncio-mode=auto --cov=api --cov-report=term-missing
+      - name: Run Tests
+        run: |
+          uv run pytest api/tests/ --asyncio-mode=auto --cov=api --cov-report=term-missing
124 .github/workflows/docker-publish.yml vendored
@@ -1,7 +1,9 @@
-name: Docker Build and Publish
+name: Docker Build, Slim, and Publish

 on:
   push:
     branches:
       - master
     tags: [ 'v*.*.*' ]
+  # Allow manual trigger from GitHub UI
+  workflow_dispatch:

@@ -16,6 +18,7 @@ jobs:
     permissions:
       contents: read
       packages: write
+      actions: write

     steps:
       - name: Checkout repository

@@ -28,67 +31,76 @@ jobs:
           username: ${{ github.actor }}
           password: ${{ secrets.GITHUB_TOKEN }}

-      # Extract metadata for GPU image
-      - name: Extract metadata (tags, labels) for GPU Docker
-        id: meta-gpu
-        uses: docker/metadata-action@v5
-        with:
-          images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
-          tags: |
-            type=semver,pattern=v{{version}}
-            type=semver,pattern=v{{major}}.{{minor}}
-            type=semver,pattern=v{{major}}
-            type=raw,value=latest
+      # Set up image names (converting to lowercase)
+      - name: Set image names
+        run: |
+          echo "GPU_IMAGE_NAME=${{ env.REGISTRY }}/$(echo ${{ env.IMAGE_NAME }} | tr '[:upper:]' '[:lower:]')-gpu" >> $GITHUB_ENV
+          echo "CPU_IMAGE_NAME=${{ env.REGISTRY }}/$(echo ${{ env.IMAGE_NAME }} | tr '[:upper:]' '[:lower:]')-cpu" >> $GITHUB_ENV
+          echo "UI_IMAGE_NAME=${{ env.REGISTRY }}/$(echo ${{ env.IMAGE_NAME }} | tr '[:upper:]' '[:lower:]')-ui" >> $GITHUB_ENV

-      # Extract metadata for CPU image
-      - name: Extract metadata (tags, labels) for CPU Docker
-        id: meta-cpu
-        uses: docker/metadata-action@v5
-        with:
-          images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
-          flavor: |
-            suffix=-cpu
-          tags: |
-            type=semver,pattern=v{{version}}
-            type=semver,pattern=v{{major}}.{{minor}}
-            type=semver,pattern=v{{major}}
-            type=raw,value=latest

-      # Build and push GPU version
-      - name: Build and push GPU Docker image
+      # Build GPU version
+      - name: Build GPU Docker image
         uses: docker/build-push-action@v5
         with:
           context: .
-          file: ./Dockerfile
-          push: true
-          tags: ${{ steps.meta-gpu.outputs.tags }}
-          labels: ${{ steps.meta-gpu.outputs.labels }}
+          file: ./docker/gpu/Dockerfile
+          push: false
+          load: true
+          tags: ${{ env.GPU_IMAGE_NAME }}:v0.1.0
+          build-args: |
+            DOCKER_BUILDKIT=1
+          platforms: linux/amd64

+      # Slim GPU version
+      - name: Slim GPU Docker image
+        uses: kitabisa/docker-slim-action@v1
+        env:
+          DSLIM_HTTP_PROBE: false
+        with:
+          target: ${{ env.GPU_IMAGE_NAME }}:v0.1.0
+          tag: v0.1.0-slim

+      # Push GPU versions
+      - name: Push GPU Docker images
+        run: |
+          docker push ${{ env.GPU_IMAGE_NAME }}:v0.1.0
+          docker push ${{ env.GPU_IMAGE_NAME }}:v0.1.0-slim
+          docker tag ${{ env.GPU_IMAGE_NAME }}:v0.1.0 ${{ env.GPU_IMAGE_NAME }}:latest
+          docker tag ${{ env.GPU_IMAGE_NAME }}:v0.1.0-slim ${{ env.GPU_IMAGE_NAME }}:latest-slim
+          docker push ${{ env.GPU_IMAGE_NAME }}:latest
+          docker push ${{ env.GPU_IMAGE_NAME }}:latest-slim

-      # Build and push CPU version
-      - name: Build and push CPU Docker image
+      # Build CPU version
+      - name: Build CPU Docker image
         uses: docker/build-push-action@v5
         with:
           context: .
-          file: ./Dockerfile.cpu
-          push: true
-          tags: ${{ steps.meta-cpu.outputs.tags }}
-          labels: ${{ steps.meta-cpu.outputs.labels }}
+          file: ./docker/cpu/Dockerfile
+          push: false
+          load: true
+          tags: ${{ env.CPU_IMAGE_NAME }}:v0.1.0
+          build-args: |
+            DOCKER_BUILDKIT=1
+          platforms: linux/amd64

-      # Extract metadata for UI image
-      - name: Extract metadata (tags, labels) for UI Docker
-        id: meta-ui
-        uses: docker/metadata-action@v5
+      # Slim CPU version
+      - name: Slim CPU Docker image
+        uses: kitabisa/docker-slim-action@v1
+        env:
+          DSLIM_HTTP_PROBE: false
         with:
-          images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
-          flavor: |
-            suffix=-ui
-          tags: |
-            type=semver,pattern=v{{version}}
-            type=semver,pattern=v{{major}}.{{minor}}
-            type=semver,pattern=v{{major}}
-            type=raw,value=latest
+          target: ${{ env.CPU_IMAGE_NAME }}:v0.1.0
+          tag: v0.1.0-slim

+      # Push CPU versions
+      - name: Push CPU Docker images
+        run: |
+          docker push ${{ env.CPU_IMAGE_NAME }}:v0.1.0
+          docker push ${{ env.CPU_IMAGE_NAME }}:v0.1.0-slim
+          docker tag ${{ env.CPU_IMAGE_NAME }}:v0.1.0 ${{ env.CPU_IMAGE_NAME }}:latest
+          docker tag ${{ env.CPU_IMAGE_NAME }}:v0.1.0-slim ${{ env.CPU_IMAGE_NAME }}:latest-slim
+          docker push ${{ env.CPU_IMAGE_NAME }}:latest
+          docker push ${{ env.CPU_IMAGE_NAME }}:latest-slim

       # Build and push UI version
       - name: Build and push UI Docker image

@@ -97,8 +109,11 @@ jobs:
           context: ./ui
           file: ./ui/Dockerfile
           push: true
-          tags: ${{ steps.meta-ui.outputs.tags }}
-          labels: ${{ steps.meta-ui.outputs.labels }}
+          tags: |
+            ${{ env.UI_IMAGE_NAME }}:v0.1.0
+            ${{ env.UI_IMAGE_NAME }}:latest
+          build-args: |
+            DOCKER_BUILDKIT=1
+          platforms: linux/amd64

       create-release:

@@ -108,13 +123,16 @@ jobs:
     if: startsWith(github.ref, 'refs/tags/')
     permissions:
       contents: write
+      packages: write
     steps:
       - name: Checkout code
         uses: actions/checkout@v4

       - name: Create Release
         uses: softprops/action-gh-release@v1
+        env:
+          IS_PRERELEASE: ${{ contains(github.ref, '-pre') }}
         with:
           generate_release_notes: true
           draft: false
-          prerelease: false
+          prerelease: ${{ contains(github.ref, '-pre') }}
71 .gitignore vendored
@@ -2,51 +2,78 @@
.git

# Python
__pycache__
__pycache__/
*.pyc
*.pyo
*.pyd
*.pt
.Python
*.py[cod]
*$py.class
.Python
.pytest_cache
.coverage
.coveragerc

# Python package build artifacts
*.egg-info/
*.egg
dist/
build/

# Environment
# .env
.venv
.venv/
env/
venv/
ENV/

# IDE
.idea
.vscode
.idea/
.vscode/
*.swp
*.swo

# Project specific
*examples/*.wav
*examples/*.pcm
*examples/*.mp3
*examples/*.flac
*examples/*.acc
*examples/*.ogg
# Model files
*.pt
*.pth
*.tar*

# Voice files
api/src/voices/af_bella.pt
api/src/voices/af_nicole.pt
api/src/voices/af_sarah.pt
api/src/voices/af_sky.pt
api/src/voices/af.pt
api/src/voices/am_adam.pt
api/src/voices/am_michael.pt
api/src/voices/bf_emma.pt
api/src/voices/bf_isabella.pt
api/src/voices/bm_george.pt
api/src/voices/bm_lewis.pt

# Audio files
examples/*.wav
examples/*.pcm
examples/*.mp3
examples/*.flac
examples/*.acc
examples/*.ogg
examples/speech.mp3
examples/phoneme_examples/output/example_1.wav
examples/phoneme_examples/output/example_2.wav
examples/phoneme_examples/output/example_3.wav

# Other project files
Kokoro-82M/
ui/data
tests/
*.md
*.txt
requirements.txt
ui/data/
EXTERNAL_UV_DOCUMENTATION*

# Docker
Dockerfile*
docker-compose*

*.egg-info
*.pt
*.wav
*.tar*
examples/assorted_checks/River_of_Teet_-_Sarah_Gailey.epub
examples/ebook_test/chapter_to_audio.py
examples/ebook_test/chapters_to_audio.py
examples/ebook_test/parse_epub.py
examples/ebook_test/River_of_Teet_-_Sarah_Gailey.epub
examples/ebook_test/River_of_Teet_-_Sarah_Gailey.txt
1 .python-version Normal file
@@ -0,0 +1 @@
3.10
@@ -1,11 +1,12 @@
line-length = 88

exclude = ["examples"]

[lint]
select = ["I"]

[lint.isort]
combine-as-imports = true
force-wrap-aliases = true
length-sort = true
split-on-trailing-comma = true
section-order = ["future", "standard-library", "third-party", "first-party", "local-folder"]
11 CHANGELOG.md
@@ -2,6 +2,17 @@

 Notable changes to this project will be documented in this file.

+## [v0.1.0] - 2025-01-13
+### Changed
+- Major Docker improvements:
+  - Baked model directly into Dockerfile for improved deployment reliability
+  - Switched to uv for dependency management
+  - Streamlined container builds and reduced image sizes
+- Dependency Management:
+  - Migrated from pip/poetry to uv for faster, more reliable package management
+  - Added uv.lock for deterministic builds
+  - Updated dependency resolution strategy
+
 ## [v0.0.5post1] - 2025-01-11
 ### Fixed
 - Docker image tagging and versioning improvements (-gpu, -cpu, -ui)
44 Dockerfile
@@ -1,44 +0,0 @@
FROM nvidia/cuda:12.1.0-base-ubuntu22.04

# Install base system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    python3-pip \
    python3-dev \
    espeak-ng \
    git \
    libsndfile1 \
    curl \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

# Install PyTorch with CUDA support first
RUN pip3 install --no-cache-dir torch==2.5.1 --extra-index-url https://download.pytorch.org/whl/cu121

# Install all other dependencies from requirements.txt
COPY requirements.txt .
RUN pip3 install --no-cache-dir -r requirements.txt

# Set working directory
WORKDIR /app

# Create non-root user
RUN useradd -m -u 1000 appuser

# Create model directory and set ownership
RUN mkdir -p /app/Kokoro-82M && \
    chown -R appuser:appuser /app

# Switch to non-root user
USER appuser

# Run with Python unbuffered output for live logging
ENV PYTHONUNBUFFERED=1

# Copy only necessary application code
COPY --chown=appuser:appuser api /app/api

# Set Python path (app first for our imports, then model dir for model imports)
ENV PYTHONPATH=/app:/app/Kokoro-82M

# Run FastAPI server with debug logging and reload
CMD ["uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]
44 Dockerfile.cpu

@@ -1,44 +0,0 @@
FROM ubuntu:22.04

# Install base system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    python3-pip \
    python3-dev \
    espeak-ng \
    git \
    libsndfile1 \
    curl \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

# Install PyTorch CPU version and ONNX runtime
RUN pip3 install --no-cache-dir torch==2.5.1 --extra-index-url https://download.pytorch.org/whl/cpu

# Install all other dependencies from requirements.txt
COPY requirements.txt .
RUN pip3 install --no-cache-dir -r requirements.txt

# Copy application code and model
COPY . /app/

# Set working directory
WORKDIR /app

# Run with Python unbuffered output for live logging
ENV PYTHONUNBUFFERED=1

# Create non-root user
RUN useradd -m -u 1000 appuser

# Create directories and set permissions
RUN mkdir -p /app/Kokoro-82M && \
    chown -R appuser:appuser /app

# Switch to non-root user
USER appuser

# Set Python path (app first for our imports, then model dir for model imports)
ENV PYTHONPATH=/app:/app/Kokoro-82M

# Run FastAPI server with debug logging and reload
CMD ["uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]
70 MigrationWorkingNotes.md Normal file
@@ -0,0 +1,70 @@
# UV Setup
Deprecated notes for myself

## Structure
```
docker/
├── cpu/
│   ├── pyproject.toml     # CPU deps (torch CPU)
│   └── requirements.lock  # CPU lockfile
├── gpu/
│   ├── pyproject.toml     # GPU deps (torch CUDA)
│   └── requirements.lock  # GPU lockfile
└── shared/
    └── pyproject.toml     # Common deps
```

## Regenerate Lock Files

### CPU
```bash
cd docker/cpu
uv pip compile pyproject.toml ../shared/pyproject.toml --output-file requirements.lock
```

### GPU
```bash
cd docker/gpu
uv pip compile pyproject.toml ../shared/pyproject.toml --output-file requirements.lock
```

## Local Dev Setup

### CPU
```bash
cd docker/cpu
uv venv
.venv\Scripts\activate  # Windows
uv pip sync requirements.lock
```

### GPU
```bash
cd docker/gpu
uv venv
.venv\Scripts\activate  # Windows
uv pip sync requirements.lock --extra-index-url https://download.pytorch.org/whl/cu121 --index-strategy unsafe-best-match
```

### Run Server
```bash
# From project root with venv active:
uvicorn api.src.main:app --reload
```

## Docker

### CPU
```bash
cd docker/cpu
docker compose up
```

### GPU
```bash
cd docker/gpu
docker compose up
```

## Known Issues
- Module imports: run the server from the project root
- PyTorch CUDA: always use --extra-index-url and --index-strategy for the GPU env
64 README.md
@@ -2,9 +2,9 @@
 <img src="githubbanner.png" alt="Kokoro TTS Banner">
 </p>

-# Kokoro TTS API
+# <sub><sub>_`FastKoko`_ </sub></sub>
-[]()
-[]()
-[]()
+[](https://huggingface.co/hexgrad/Kokoro-82M/tree/c3b0d86e2a980e027ef71c28819ea02e351c2667) [](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero) [](https://www.buymeacoffee.com/remsky)

 Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model

@@ -24,14 +24,30 @@ Dockerized FastAPI wrapper for [Kokor
 The service can be accessed through either the API endpoints or the Gradio web interface.

 1. Install prerequisites:
-   - Install [Docker Desktop](https://www.docker.com/products/docker-desktop/) + [Git](https://git-scm.com/downloads)
-   - Clone and start the service:
+   - Install [Docker Desktop](https://www.docker.com/products/docker-desktop/)
+   - Clone the repository:
     ```bash
     git clone https://github.com/remsky/Kokoro-FastAPI.git
     cd Kokoro-FastAPI
-    docker compose up --build # for GPU
-    #docker compose -f docker-compose.cpu.yml up --build # for CPU
     ```

+2. Start the service:
+
+   - Using Docker Compose (full setup including UI):
+    ```bash
+    cd docker/gpu  # OR
+    # cd docker/cpu  # Run this or the above
+    docker compose up --build
+    ```
+   - OR running the API alone using Docker (model + voice packs baked in):
+    ```bash
+    docker run -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-cpu:latest  # CPU
+    docker run --gpus all -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-gpu:latest  # Nvidia GPU
+    # Minified versions are available with the `:latest-slim` tag.
+    ```

 2. Run locally as an OpenAI-Compatible Speech Endpoint
 ```python
 from openai import OpenAI
 ```
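The hunk is cut off just after the import. For orientation, a minimal sketch of what a request against such an OpenAI-compatible endpoint typically looks like; the port comes from the `docker run` lines above, while the model id, voice, and output path are illustrative assumptions, not taken from this commit:

```python
from openai import OpenAI

# Assumptions: server reachable on localhost:8880; "kokoro" and "af_bella"
# are placeholder model/voice names -- check the deployed API for real values.
client = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed")

with client.audio.speech.with_streaming_response.create(
    model="kokoro",
    voice="af_bella",
    input="Hello from Kokoro!",
) as response:
    response.stream_to_file("output.mp3")  # hypothetical output path
```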
@@ -167,6 +183,21 @@ If you only want the API, just comment out everything in the docker-compose.yml
 Currently, voices created via the API are accessible here, but voice combination/creation has not yet been added

 *Note: Recent updates for streaming could lead to temporary glitches. If so, pull from the most recent stable release v0.0.2 to restore*

+### Disabling Local Saving
+
+You can disable local saving of audio files and hide the file view in the UI by setting the `DISABLE_LOCAL_SAVING` environment variable to `true`. This is useful when running the service on a server where you don't want to store generated audio files locally.
+
+When using Docker Compose:
+```yaml
+environment:
+  - DISABLE_LOCAL_SAVING=true
+```
+
+When running the Docker image directly:
+```bash
+docker run -p 7860:7860 -e DISABLE_LOCAL_SAVING=true ghcr.io/remsky/kokoro-fastapi-ui:latest
+```
 </details>

 <details>

@@ -320,6 +351,27 @@ See `examples/phoneme_examples/generate_phonemes.py` for a sample script.

 ## Known Issues

+<details>
+<summary>Versioning & Development</summary>
+
+I'm doing what I can to keep things stable, but we are on an early and rapid set of build cycles here.
+If you run into trouble, you may have to roll back a version on the release tags, or build from source and/or troubleshoot and submit a PR. I'll leave this up as the last known stable point:
+
+`v0.0.5post1`
+
+Free and open source is a community effort, and I love working on this project, though there's only really so many hours in a day. If you'd like to support the work, feel free to open a PR, buy me a coffee, or report any bugs/features/etc you find during use.
+
+<a href="https://www.buymeacoffee.com/remsky" target="_blank">
+  <img
+    src="https://cdn.buymeacoffee.com/buttons/v2/default-violet.png"
+    alt="Buy Me A Coffee"
+    style="height: 30px !important;width: 110px !important;"
+  >
+</a>
+
+</details>

 <details>
 <summary>Linux GPU Permissions</summary>
0 api/src/builds/__init__.py Normal file

26 api/src/builds/config.json Normal file
@@ -0,0 +1,26 @@
{
  "decoder": {
    "type": "istftnet",
    "upsample_kernel_sizes": [20, 12],
    "upsample_rates": [10, 6],
    "gen_istft_hop_size": 5,
    "gen_istft_n_fft": 20,
    "resblock_dilation_sizes": [
      [1, 3, 5],
      [1, 3, 5],
      [1, 3, 5]
    ],
    "resblock_kernel_sizes": [3, 7, 11],
    "upsample_initial_channel": 512
  },
  "dim_in": 64,
  "dropout": 0.2,
  "hidden_dim": 512,
  "max_conv_dim": 512,
  "max_dur": 50,
  "multispeaker": true,
  "n_layer": 3,
  "n_mels": 80,
  "n_token": 178,
  "style_dim": 128
}
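A minimal sketch of how this config plausibly feeds the `Decoder` defined in `api/src/builds/istftnet.py` below; the key-to-argument mapping is my reading of the two files, not something this commit ships as code:

```python
import json

from api.src.builds.istftnet import Decoder  # module added in this commit

with open("api/src/builds/config.json") as f:
    config = json.load(f)

dec_cfg = config["decoder"]
decoder = Decoder(
    dim_in=config["hidden_dim"],  # assumed mapping: hidden_dim -> dim_in
    style_dim=config["style_dim"],
    resblock_kernel_sizes=dec_cfg["resblock_kernel_sizes"],
    upsample_rates=dec_cfg["upsample_rates"],
    upsample_initial_channel=dec_cfg["upsample_initial_channel"],
    resblock_dilation_sizes=dec_cfg["resblock_dilation_sizes"],
    upsample_kernel_sizes=dec_cfg["upsample_kernel_sizes"],
    gen_istft_n_fft=dec_cfg["gen_istft_n_fft"],
    gen_istft_hop_size=dec_cfg["gen_istft_hop_size"],
)
```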
524 api/src/builds/istftnet.py Normal file
@@ -0,0 +1,524 @@
# https://github.com/yl4579/StyleTTS2/blob/main/Modules/istftnet.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.signal import get_window
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm, weight_norm


# https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py
def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)

def get_padding(kernel_size, dilation=1):
    return int((kernel_size*dilation - dilation)/2)

LRELU_SLOPE = 0.1

class AdaIN1d(nn.Module):
    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm1d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features*2)

    def forward(self, x, s):
        h = self.fc(s)
        h = h.view(h.size(0), h.size(1), 1)
        gamma, beta = torch.chunk(h, chunks=2, dim=1)
        return (1 + gamma) * self.norm(x) + beta

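`AdaIN1d` is the style-conditioning primitive used throughout this decoder: instance-normalize the features, then scale and shift each channel with `gamma`/`beta` predicted from the style vector `s`. A quick shape check with the class above (sizes are arbitrary illustrations):

```python
import torch

adain = AdaIN1d(style_dim=128, num_features=64)
x = torch.randn(2, 64, 50)  # (batch, channels, frames)
s = torch.randn(2, 128)     # (batch, style_dim)
print(adain(x, s).shape)    # torch.Size([2, 64, 50]) -- shape is preserved
```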
class AdaINResBlock1(torch.nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
        super(AdaINResBlock1, self).__init__()
        self.convs1 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                               padding=get_padding(kernel_size, dilation[2])))
        ])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1)))
        ])
        self.convs2.apply(init_weights)

        self.adain1 = nn.ModuleList([
            AdaIN1d(style_dim, channels),
            AdaIN1d(style_dim, channels),
            AdaIN1d(style_dim, channels),
        ])

        self.adain2 = nn.ModuleList([
            AdaIN1d(style_dim, channels),
            AdaIN1d(style_dim, channels),
            AdaIN1d(style_dim, channels),
        ])

        self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
        self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])

    def forward(self, x, s):
        for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
            xt = n1(x, s)
            xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2)  # Snake1D
            xt = c1(xt)
            xt = n2(xt, s)
            xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2)  # Snake1D
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)

class TorchSTFT(torch.nn.Module):
    def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
        super().__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = torch.from_numpy(get_window(window, win_length, fftbins=True).astype(np.float32))

    def transform(self, input_data):
        forward_transform = torch.stft(
            input_data,
            self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device),
            return_complex=True)

        return torch.abs(forward_transform), torch.angle(forward_transform)

    def inverse(self, magnitude, phase):
        inverse_transform = torch.istft(
            magnitude * torch.exp(phase * 1j),
            self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device))

        return inverse_transform.unsqueeze(-2)  # unsqueeze to stay consistent with conv_transpose1d implementation

    def forward(self, input_data):
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction

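Because `inverse` rebuilds the complex spectrum as `magnitude * exp(i * phase)`, a transform/inverse round trip should approximately reproduce the input signal. A small sanity sketch using the class above (signal length is arbitrary):

```python
import torch

stft = TorchSTFT(filter_length=800, hop_length=200, win_length=800)
audio = torch.randn(1, 4000)                  # (batch, samples) test signal
mag, phase = stft.transform(audio)            # magnitude and phase spectrograms
recon = stft.inverse(mag, phase).squeeze(-2)  # back to (batch, samples)
print(torch.allclose(audio, recon, atol=1e-3))  # True, up to numerical error
```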
class SineGen(torch.nn.Module):
    """ Definition of sine generator
    SineGen(samp_rate, harmonic_num = 0,
            sine_amp = 0.1, noise_std = 0.003,
            voiced_threshold = 0,
            flag_for_pulse=False)
    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    flag_for_pulse: this SineGen is used inside PulseGen (default False)
    Note: when flag_for_pulse is True, the first time step of a voiced
        segment is always sin(np.pi) or cos(0)
    """

    def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0,
                 flag_for_pulse=False):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold
        self.flag_for_pulse = flag_for_pulse
        self.upsample_scale = upsample_scale

    def _f02uv(self, f0):
        # generate uv signal
        uv = (f0 > self.voiced_threshold).type(torch.float32)
        return uv

    def _f02sine(self, f0_values):
        """ f0_values: (batchsize, length, dim)
            where dim indicates fundamental tone and overtones
        """
        # convert to F0 in rad. The integer part n can be ignored
        # because 2 * np.pi * n doesn't affect phase
        rad_values = (f0_values / self.sampling_rate) % 1

        # initial phase noise (no noise for fundamental component)
        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2],
                              device=f0_values.device)
        rand_ini[:, 0] = 0
        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini

        # instantaneous phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
        if not self.flag_for_pulse:
            # # for normal case

            # # To prevent torch.cumsum numerical overflow,
            # # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
            # # Buffer tmp_over_one_idx indicates the time step to add -1.
            # # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
            # tmp_over_one = torch.cumsum(rad_values, 1) % 1
            # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
            # cumsum_shift = torch.zeros_like(rad_values)
            # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0

            # phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
            rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
                                                         scale_factor=1/self.upsample_scale,
                                                         mode="linear").transpose(1, 2)

            # tmp_over_one = torch.cumsum(rad_values, 1) % 1
            # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
            # cumsum_shift = torch.zeros_like(rad_values)
            # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0

            phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
            phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
                                                    scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
            sines = torch.sin(phase)

        else:
            # If necessary, make sure that the first time step of every
            # voiced segments is sin(pi) or cos(0)
            # This is used for pulse-train generation

            # identify the last time step in unvoiced segments
            uv = self._f02uv(f0_values)
            uv_1 = torch.roll(uv, shifts=-1, dims=1)
            uv_1[:, -1, :] = 1
            u_loc = (uv < 1) * (uv_1 > 0)

            # get the instantaneous phase
            tmp_cumsum = torch.cumsum(rad_values, dim=1)
            # different batch needs to be processed differently
            for idx in range(f0_values.shape[0]):
                temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
                temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
                # stores the accumulation of i.phase within
                # each voiced segments
                tmp_cumsum[idx, :, :] = 0
                tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum

            # rad_values - tmp_cumsum: remove the accumulation of i.phase
            # within the previous voiced segment.
            i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)

            # get the sines
            sines = torch.cos(i_phase * 2 * np.pi)
        return sines

    def forward(self, f0):
        """ sine_tensor, uv = forward(f0)
        input F0: tensor(batchsize=1, length, dim=1)
            f0 for unvoiced steps should be 0
        output sine_tensor: tensor(batchsize=1, length, dim)
        output uv: tensor(batchsize=1, length, 1)
        """
        f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
                             device=f0.device)
        # fundamental component
        fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))

        # generate sine waveforms
        sine_waves = self._f02sine(fn) * self.sine_amp

        # generate uv signal
        # uv = torch.ones(f0.shape)
        # uv = uv * (f0 > self.voiced_threshold)
        uv = self._f02uv(f0)

        # noise: for unvoiced should be similar to sine_amp
        #        std = self.sine_amp/3 -> max value ~ self.sine_amp
        #        for voiced regions is self.noise_std
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sine_waves)

        # first: set the unvoiced part to 0 by uv
        # then: additive noise
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise

class SourceModuleHnNSF(torch.nn.Module):
    """ SourceModule for hn-nsf
    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0)
    sampling_rate: sampling rate in Hz
    harmonic_num: number of harmonics above F0 (default: 0)
    sine_amp: amplitude of sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003)
        note that amplitude of noise in unvoiced is decided
        by sine_amp
    voiced_threshold: threshold to set U/V given F0 (default: 0)
    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
    F0_sampled (batchsize, length, 1)
    Sine_source (batchsize, length, 1)
    noise_source (batchsize, length 1)
    uv (batchsize, length, 1)
    """

    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0):
        super(SourceModuleHnNSF, self).__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        # to produce sine waveforms
        self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
                                 sine_amp, add_noise_std, voiced_threshod)

        # to merge source harmonics into a single excitation
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x):
        """
        Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
        F0_sampled (batchsize, length, 1)
        Sine_source (batchsize, length, 1)
        noise_source (batchsize, length 1)
        """
        # source for harmonic branch
        with torch.no_grad():
            sine_wavs, uv, _ = self.l_sin_gen(x)
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))

        # source for noise branch, in the same shape as uv
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return sine_merge, noise, uv


def padDiff(x):
    return F.pad(F.pad(x, (0,0,-1,1), 'constant', 0) - x, (0,0,0,-1), 'constant', 0)

class Generator(torch.nn.Module):
    def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size):
        super(Generator, self).__init__()

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        resblock = AdaINResBlock1

        self.m_source = SourceModuleHnNSF(
            sampling_rate=24000,
            upsample_scale=np.prod(upsample_rates) * gen_istft_hop_size,
            harmonic_num=8, voiced_threshod=10)
        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * gen_istft_hop_size)
        self.noise_convs = nn.ModuleList()
        self.noise_res = nn.ModuleList()

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(weight_norm(
                ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
                                k, u, padding=(k-u)//2)))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel//(2**(i+1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(resblock(ch, k, d, style_dim))

            c_cur = upsample_initial_channel // (2 ** (i + 1))

            if i + 1 < len(upsample_rates):
                stride_f0 = np.prod(upsample_rates[i + 1:])
                self.noise_convs.append(Conv1d(
                    gen_istft_n_fft + 2, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
                self.noise_res.append(resblock(c_cur, 7, [1,3,5], style_dim))
            else:
                self.noise_convs.append(Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1))
                self.noise_res.append(resblock(c_cur, 11, [1,3,5], style_dim))

        self.post_n_fft = gen_istft_n_fft
        self.conv_post = weight_norm(Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
        self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
        self.stft = TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)

    def forward(self, x, s, f0):
        with torch.no_grad():
            f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t

            har_source, noi_source, uv = self.m_source(f0)
            har_source = har_source.transpose(1, 2).squeeze(1)
            har_spec, har_phase = self.stft.transform(har_source)
            har = torch.cat([har_spec, har_phase], dim=1)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x_source = self.noise_convs[i](har)
            x_source = self.noise_res[i](x_source, s)

            x = self.ups[i](x)
            if i == self.num_upsamples - 1:
                x = self.reflection_pad(x)

            x = x + x_source
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i*self.num_kernels+j](x, s)
                else:
                    xs += self.resblocks[i*self.num_kernels+j](x, s)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        spec = torch.exp(x[:, :self.post_n_fft // 2 + 1, :])
        phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
        return self.stft.inverse(spec, phase)

    def fw_phase(self, x, s):
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i*self.num_kernels+j](x, s)
                else:
                    xs += self.resblocks[i*self.num_kernels+j](x, s)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.reflection_pad(x)
        x = self.conv_post(x)
        spec = torch.exp(x[:, :self.post_n_fft // 2 + 1, :])
        phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
        return spec, phase

    def remove_weight_norm(self):
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)  # NOTE: no conv_pre is defined on this Generator (upstream quirk); calling this method would raise AttributeError
        remove_weight_norm(self.conv_post)

class AdainResBlk1d(nn.Module):
    def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
                 upsample='none', dropout_p=0.0):
        super().__init__()
        self.actv = actv
        self.upsample_type = upsample
        self.upsample = UpSample1d(upsample)
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out, style_dim)
        self.dropout = nn.Dropout(dropout_p)

        if upsample == 'none':
            self.pool = nn.Identity()
        else:
            self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))

    def _build_weights(self, dim_in, dim_out, style_dim):
        self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
        self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
        self.norm1 = AdaIN1d(style_dim, dim_in)
        self.norm2 = AdaIN1d(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))

    def _shortcut(self, x):
        x = self.upsample(x)
        if self.learned_sc:
            x = self.conv1x1(x)
        return x

    def _residual(self, x, s):
        x = self.norm1(x, s)
        x = self.actv(x)
        x = self.pool(x)
        x = self.conv1(self.dropout(x))
        x = self.norm2(x, s)
        x = self.actv(x)
        x = self.conv2(self.dropout(x))
        return x

    def forward(self, x, s):
        out = self._residual(x, s)
        out = (out + self._shortcut(x)) / np.sqrt(2)
        return out

class UpSample1d(nn.Module):
    def __init__(self, layer_type):
        super().__init__()
        self.layer_type = layer_type

    def forward(self, x):
        if self.layer_type == 'none':
            return x
        else:
            return F.interpolate(x, scale_factor=2, mode='nearest')


class Decoder(nn.Module):
    def __init__(self, dim_in=512, F0_channel=512, style_dim=64, dim_out=80,
                 resblock_kernel_sizes=[3, 7, 11],
                 upsample_rates=[10, 6],
                 upsample_initial_channel=512,
                 resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                 upsample_kernel_sizes=[20, 12],
                 gen_istft_n_fft=20, gen_istft_hop_size=5):
        super().__init__()

        self.decode = nn.ModuleList()

        self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)

        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))

        self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))

        self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))

        self.asr_res = nn.Sequential(
            weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
        )

        self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates,
                                   upsample_initial_channel, resblock_dilation_sizes,
                                   upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size)

    def forward(self, asr, F0_curve, N, s):
        F0 = self.F0_conv(F0_curve.unsqueeze(1))
        N = self.N_conv(N.unsqueeze(1))

        x = torch.cat([asr, F0, N], axis=1)
        x = self.encode(x, s)

        asr_res = self.asr_res(asr)

        res = True
        for block in self.decode:
            if res:
                x = torch.cat([x, asr_res, F0, N], axis=1)
            x = block(x, s)
            if block.upsample_type != "none":
                res = False

        x = self.generator(x, s, F0_curve)
        return x
151 api/src/builds/kokoro.py Normal file
@@ -0,0 +1,151 @@
import re

import phonemizer
import torch


def split_num(num):
    num = num.group()
    if '.' in num:
        return num
    elif ':' in num:
        h, m = [int(n) for n in num.split(':')]
        if m == 0:
            return f"{h} o'clock"
        elif m < 10:
            return f'{h} oh {m}'
        return f'{h} {m}'
    year = int(num[:4])
    if year < 1100 or year % 1000 < 10:
        return num
    left, right = num[:2], int(num[2:4])
    s = 's' if num.endswith('s') else ''
    if 100 <= year % 1000 <= 999:
        if right == 0:
            return f'{left} hundred{s}'
        elif right < 10:
            return f'{left} oh {right}{s}'
    return f'{left} {right}{s}'


def flip_money(m):
    m = m.group()
    bill = 'dollar' if m[0] == '$' else 'pound'
    if m[-1].isalpha():
        return f'{m[1:]} {bill}s'
    elif '.' not in m:
        s = '' if m[1:] == '1' else 's'
        return f'{m[1:]} {bill}{s}'
    b, c = m[1:].split('.')
    s = '' if b == '1' else 's'
    c = int(c.ljust(2, '0'))
    coins = f"cent{'' if c == 1 else 's'}" if m[0] == '$' else ('penny' if c == 1 else 'pence')
    return f'{b} {bill}{s} and {c} {coins}'


def point_num(num):
    a, b = num.group().split('.')
    return ' point '.join([a, ' '.join(b)])

def normalize_text(text):
    text = text.replace(chr(8216), "'").replace(chr(8217), "'")
    text = text.replace('«', chr(8220)).replace('»', chr(8221))
    text = text.replace(chr(8220), '"').replace(chr(8221), '"')
    text = text.replace('(', '«').replace(')', '»')
    for a, b in zip('、。!,:;?', ',.!,:;?'):
        text = text.replace(a, b+' ')
    text = re.sub(r'[^\S \n]', ' ', text)
    text = re.sub(r'  +', ' ', text)
    text = re.sub(r'(?<=\n) +(?=\n)', '', text)
    text = re.sub(r'\bD[Rr]\.(?= [A-Z])', 'Doctor', text)
    text = re.sub(r'\b(?:Mr\.|MR\.(?= [A-Z]))', 'Mister', text)
    text = re.sub(r'\b(?:Ms\.|MS\.(?= [A-Z]))', 'Miss', text)
    text = re.sub(r'\b(?:Mrs\.|MRS\.(?= [A-Z]))', 'Mrs', text)
    text = re.sub(r'\betc\.(?! [A-Z])', 'etc', text)
    text = re.sub(r'(?i)\b(y)eah?\b', r"\1e'a", text)
    text = re.sub(r'\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)', split_num, text)
    text = re.sub(r'(?<=\d),(?=\d)', '', text)
    text = re.sub(r'(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b', flip_money, text)
    text = re.sub(r'\d*\.\d+', point_num, text)
    text = re.sub(r'(?<=\d)-(?=\d)', ' to ', text)
    text = re.sub(r'(?<=\d)S', ' S', text)
    text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
    text = re.sub(r"(?<=X')S\b", 's', text)
    text = re.sub(r'(?:[A-Za-z]\.){2,} [a-z]', lambda m: m.group().replace('.', '-'), text)
    text = re.sub(r'(?i)(?<=[A-Z])\.(?=[A-Z])', '-', text)
    return text.strip()

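A few hand-traced examples of what these rules produce; the outputs were worked out from the regexes above, so treat them as illustrative rather than authoritative:

```python
print(normalize_text("It costs $5 at 2:30"))  # "It costs 5 dollars at 2 30"
print(normalize_text("Back in 1984"))         # "Back in 19 84"
print(normalize_text("pi is 3.14"))           # "pi is 3 point 1 4"
```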
def get_vocab():
    _pad = "$"
    _punctuation = ';:,.!?¡¿—…"«»“” '
    _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
    _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
    symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
    dicts = {}
    for i in range(len(symbols)):
        dicts[symbols[i]] = i
    return dicts


VOCAB = get_vocab()

def tokenize(ps):
    return [i for i in map(VOCAB.get, ps) if i is not None]

phonemizers = dict(
    a=phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True),
    b=phonemizer.backend.EspeakBackend(language='en-gb', preserve_punctuation=True, with_stress=True),
)

def phonemize(text, lang, norm=True):
    if norm:
        text = normalize_text(text)
    ps = phonemizers[lang].phonemize([text])
    ps = ps[0] if ps else ''
    # https://en.wiktionary.org/wiki/kokoro#English
    ps = ps.replace('kəkˈoːɹoʊ', 'kˈoʊkəɹoʊ').replace('kəkˈɔːɹəʊ', 'kˈəʊkəɹəʊ')
    ps = ps.replace('ʲ', 'j').replace('r', 'ɹ').replace('x', 'k').replace('ɬ', 'l')
    ps = re.sub(r'(?<=[a-zɹː])(?=hˈʌndɹɪd)', ' ', ps)
    ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»“” ]|$)', 'z', ps)
    if lang == 'a':
        ps = re.sub(r'(?<=nˈaɪn)ti(?!ː)', 'di', ps)
    ps = ''.join(filter(lambda p: p in VOCAB, ps))
    return ps.strip()

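`phonemize` normalizes the text, runs the espeak backend for the requested dialect ('a' = en-us, 'b' = en-gb), applies the IPA substitutions above, and filters to vocabulary symbols; `tokenize` then maps each symbol to its `VOCAB` id. A short sketch (requires espeak-ng on the system, which the Dockerfiles above install; the printed IPA is indicative only):

```python
ps = phonemize("Hello world", lang="a")  # e.g. something like "həlˈoʊ wˈɜːld"
tokens = tokenize(ps)                    # list of integer ids, ready for the model
print(ps, tokens)
```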
def length_to_mask(lengths):
    mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
    mask = torch.gt(mask+1, lengths.unsqueeze(1))
    return mask

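`length_to_mask` converts a batch of sequence lengths into a boolean padding mask in which `True` marks padded positions. A tiny worked example:

```python
import torch

print(length_to_mask(torch.LongTensor([3, 5])))
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])
```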
@torch.no_grad()
def forward(model, tokens, ref_s, speed):
    device = ref_s.device
    tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
    input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
    text_mask = length_to_mask(input_lengths).to(device)
    bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
    d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
    s = ref_s[:, 128:]
    d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
    x, _ = model.predictor.lstm(d)
    duration = model.predictor.duration_proj(x)
    duration = torch.sigmoid(duration).sum(axis=-1) / speed
    pred_dur = torch.round(duration).clamp(min=1).long()
    pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
    c_frame = 0
    for i in range(pred_aln_trg.size(0)):
        pred_aln_trg[i, c_frame:c_frame + pred_dur[0, i].item()] = 1
        c_frame += pred_dur[0, i].item()
    en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
    F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
    t_en = model.text_encoder(tokens, input_lengths, text_mask)
    asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
    return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()


def generate(model, text, voicepack, lang='a', speed=1, ps=None):
    ps = ps or phonemize(text, lang)
    tokens = tokenize(ps)
    if not tokens:
        return None
    elif len(tokens) > 510:
        tokens = tokens[:510]
        print('Truncated to 510 tokens')
    ref_s = voicepack[len(tokens)]
    out = forward(model, tokens, ref_s, speed)
    ps = ''.join(next(k for k, v in VOCAB.items() if i == v) for i in tokens)
    return out, ps
375 api/src/builds/models.py Normal file
@@ -0,0 +1,375 @@
# https://github.com/yl4579/StyleTTS2/blob/main/models.py
import json
import os
import os.path as osp
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from munch import Munch
from torch.nn.utils import spectral_norm, weight_norm

from .istftnet import AdaIN1d, Decoder
from .plbert import load_plbert


class LinearNorm(torch.nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super(LinearNorm, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, x):
        return self.linear_layer(x)

class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        x = x.transpose(1, -1)
        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)

class TextEncoder(nn.Module):
    def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
        super().__init__()
        self.embedding = nn.Embedding(n_symbols, channels)

        padding = (kernel_size - 1) // 2
        self.cnn = nn.ModuleList()
        for _ in range(depth):
            self.cnn.append(nn.Sequential(
                weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
                LayerNorm(channels),
                actv,
                nn.Dropout(0.2),
            ))
        # self.cnn = nn.Sequential(*self.cnn)

        self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True)

    def forward(self, x, input_lengths, m):
        x = self.embedding(x)  # [B, T, emb]
        x = x.transpose(1, 2)  # [B, emb, T]
        m = m.to(input_lengths.device).unsqueeze(1)
        x.masked_fill_(m, 0.0)

        for c in self.cnn:
            x = c(x)
            x.masked_fill_(m, 0.0)

        x = x.transpose(1, 2)  # [B, T, chn]

        input_lengths = input_lengths.cpu().numpy()
        x = nn.utils.rnn.pack_padded_sequence(
            x, input_lengths, batch_first=True, enforce_sorted=False)

        self.lstm.flatten_parameters()
        x, _ = self.lstm(x)
        x, _ = nn.utils.rnn.pad_packed_sequence(
            x, batch_first=True)

        x = x.transpose(-1, -2)
        x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])

        x_pad[:, :, :x.shape[-1]] = x
        x = x_pad.to(x.device)

        x.masked_fill_(m, 0.0)

        return x

    def inference(self, x):
        x = self.embedding(x)
        x = x.transpose(1, 2)
        x = self.cnn(x)  # NOTE: self.cnn is a ModuleList; this path assumes the commented-out nn.Sequential variant above
        x = x.transpose(1, 2)
        self.lstm.flatten_parameters()
        x, _ = self.lstm(x)
        return x

    def length_to_mask(self, lengths):
        mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
        mask = torch.gt(mask+1, lengths.unsqueeze(1))
        return mask

class UpSample1d(nn.Module):
    def __init__(self, layer_type):
        super().__init__()
        self.layer_type = layer_type

    def forward(self, x):
        if self.layer_type == 'none':
            return x
        else:
            return F.interpolate(x, scale_factor=2, mode='nearest')


class AdainResBlk1d(nn.Module):
    def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
                 upsample='none', dropout_p=0.0):
        super().__init__()
        self.actv = actv
        self.upsample_type = upsample
        self.upsample = UpSample1d(upsample)
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out, style_dim)
        self.dropout = nn.Dropout(dropout_p)

        if upsample == 'none':
            self.pool = nn.Identity()
        else:
            self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))

    def _build_weights(self, dim_in, dim_out, style_dim):
        self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
        self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
        self.norm1 = AdaIN1d(style_dim, dim_in)
        self.norm2 = AdaIN1d(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))

    def _shortcut(self, x):
        x = self.upsample(x)
        if self.learned_sc:
            x = self.conv1x1(x)
        return x

    def _residual(self, x, s):
        x = self.norm1(x, s)
        x = self.actv(x)
        x = self.pool(x)
        x = self.conv1(self.dropout(x))
        x = self.norm2(x, s)
        x = self.actv(x)
        x = self.conv2(self.dropout(x))
        return x

    def forward(self, x, s):
        out = self._residual(x, s)
        out = (out + self._shortcut(x)) / np.sqrt(2)
        return out

class AdaLayerNorm(nn.Module):
    def __init__(self, style_dim, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.fc = nn.Linear(style_dim, channels*2)

    def forward(self, x, s):
        x = x.transpose(-1, -2)
        x = x.transpose(1, -1)

        h = self.fc(s)
        h = h.view(h.size(0), h.size(1), 1)
        gamma, beta = torch.chunk(h, chunks=2, dim=1)
        gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)

        x = F.layer_norm(x, (self.channels,), eps=self.eps)
        x = (1 + gamma) * x + beta
        return x.transpose(1, -1).transpose(-1, -2)

class ProsodyPredictor(nn.Module):

    def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
        super().__init__()

        self.text_encoder = DurationEncoder(sty_dim=style_dim,
                                            d_model=d_hid,
                                            nlayers=nlayers,
                                            dropout=dropout)

        self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
        self.duration_proj = LinearNorm(d_hid, max_dur)

        self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
        self.F0 = nn.ModuleList()
        self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
        self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
        self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))

        self.N = nn.ModuleList()
        self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
        self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
        self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))

        self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
        self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)

    def forward(self, texts, style, text_lengths, alignment, m):
        d = self.text_encoder(texts, style, text_lengths, m)

        batch_size = d.shape[0]
        text_size = d.shape[1]

        # predict duration
        input_lengths = text_lengths.cpu().numpy()
        x = nn.utils.rnn.pack_padded_sequence(
            d, input_lengths, batch_first=True, enforce_sorted=False)

        m = m.to(text_lengths.device).unsqueeze(1)

        self.lstm.flatten_parameters()
        x, _ = self.lstm(x)
        x, _ = nn.utils.rnn.pad_packed_sequence(
            x, batch_first=True)

        x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])

        x_pad[:, :x.shape[1], :] = x
        x = x_pad.to(x.device)

        duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))

        en = (d.transpose(-1, -2) @ alignment)

        return duration.squeeze(-1), en

    def F0Ntrain(self, x, s):
        x, _ = self.shared(x.transpose(-1, -2))

        F0 = x.transpose(-1, -2)
        for block in self.F0:
            F0 = block(F0, s)
        F0 = self.F0_proj(F0)

        N = x.transpose(-1, -2)
        for block in self.N:
            N = block(N, s)
        N = self.N_proj(N)

        return F0.squeeze(1), N.squeeze(1)

    def length_to_mask(self, lengths):
        mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
        mask = torch.gt(mask+1, lengths.unsqueeze(1))
        return mask
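For intuition, length_to_mask turns a batch of sequence lengths into a padding mask, with True marking positions past each sequence's end. For lengths [3, 5] it yields:

    # tensor([[False, False, False,  True,  True],
    #         [False, False, False, False, False]])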
class DurationEncoder(nn.Module):

    def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
        super().__init__()
        self.lstms = nn.ModuleList()
        for _ in range(nlayers):
            self.lstms.append(nn.LSTM(d_model + sty_dim,
                                      d_model // 2,
                                      num_layers=1,
                                      batch_first=True,
                                      bidirectional=True,
                                      dropout=dropout))
            self.lstms.append(AdaLayerNorm(sty_dim, d_model))

        self.dropout = dropout
        self.d_model = d_model
        self.sty_dim = sty_dim

    def forward(self, x, style, text_lengths, m):
        masks = m.to(text_lengths.device)

        x = x.permute(2, 0, 1)
        s = style.expand(x.shape[0], x.shape[1], -1)
        x = torch.cat([x, s], axis=-1)
        x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)

        x = x.transpose(0, 1)
        input_lengths = text_lengths.cpu().numpy()
        x = x.transpose(-1, -2)

        for block in self.lstms:
            if isinstance(block, AdaLayerNorm):
                x = block(x.transpose(-1, -2), style).transpose(-1, -2)
                x = torch.cat([x, s.permute(1, -1, 0)], axis=1)
                x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
            else:
                x = x.transpose(-1, -2)
                x = nn.utils.rnn.pack_padded_sequence(
                    x, input_lengths, batch_first=True, enforce_sorted=False)
                block.flatten_parameters()
                x, _ = block(x)
                x, _ = nn.utils.rnn.pad_packed_sequence(
                    x, batch_first=True)
                x = F.dropout(x, p=self.dropout, training=self.training)
                x = x.transpose(-1, -2)

                x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])

                x_pad[:, :, :x.shape[-1]] = x
                x = x_pad.to(x.device)

        return x.transpose(-1, -2)

    # NOTE: inference() references self.embedding, self.pos_encoder and
    # self.transformer_encoder, none of which this class defines; it is
    # unused leftover carried over from upstream StyleTTS2.
    def inference(self, x, style):
        x = self.embedding(x.transpose(-1, -2)) * np.sqrt(self.d_model)
        style = style.expand(x.shape[0], x.shape[1], -1)
        x = torch.cat([x, style], axis=-1)
        src = self.pos_encoder(x)
        output = self.transformer_encoder(src).transpose(0, 1)
        return output

    def length_to_mask(self, lengths):
        mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
        mask = torch.gt(mask+1, lengths.unsqueeze(1))
        return mask
# https://github.com/yl4579/StyleTTS2/blob/main/utils.py
def recursive_munch(d):
    if isinstance(d, dict):
        return Munch((k, recursive_munch(v)) for k, v in d.items())
    elif isinstance(d, list):
        return [recursive_munch(v) for v in d]
    else:
        return d
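In short, recursive_munch converts nested dicts (including dicts inside lists) into attribute-accessible Munch objects, which is how build_model below reads config.json:

    cfg = recursive_munch({'decoder': {'type': 'istftnet'}, 'layers': [{'dim': 64}]})
    cfg.decoder.type   # 'istftnet'
    cfg.layers[0].dim  # 64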
def build_model(path, device):
    config = Path(__file__).parent / 'config.json'
    assert config.exists(), f'Config path incorrect: config.json not found at {config}'
    with open(config, 'r') as r:
        args = recursive_munch(json.load(r))
    assert args.decoder.type == 'istftnet', f'Unknown decoder type: {args.decoder.type}'
    decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
                      resblock_kernel_sizes=args.decoder.resblock_kernel_sizes,
                      upsample_rates=args.decoder.upsample_rates,
                      upsample_initial_channel=args.decoder.upsample_initial_channel,
                      resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
                      upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
                      gen_istft_n_fft=args.decoder.gen_istft_n_fft, gen_istft_hop_size=args.decoder.gen_istft_hop_size)
    text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
    predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
    bert = load_plbert()
    bert_encoder = nn.Linear(bert.config.hidden_size, args.hidden_dim)
    for parent in [bert, bert_encoder, predictor, decoder, text_encoder]:
        for child in parent.children():
            if isinstance(child, nn.RNNBase):
                child.flatten_parameters()
    model = Munch(
        bert=bert.to(device).eval(),
        bert_encoder=bert_encoder.to(device).eval(),
        predictor=predictor.to(device).eval(),
        decoder=decoder.to(device).eval(),
        text_encoder=text_encoder.to(device).eval(),
    )
    for key, state_dict in torch.load(path, map_location='cpu', weights_only=True)['net'].items():
        assert key in model, key
        try:
            model[key].load_state_dict(state_dict)
        except Exception:
            # Checkpoints saved from DataParallel prefix keys with "module.";
            # strip the prefix and retry with strict=False.
            state_dict = {k[7:]: v for k, v in state_dict.items()}
            model[key].load_state_dict(state_dict, strict=False)
    return model
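Typical use is a one-liner (a sketch; the checkpoint name matches the pytorch_model_path default in the config below):

    model = build_model('kokoro-v0_19.pth', device='cpu')
    model.predictor  # ProsodyPredictor, alongside model.bert, model.bert_encoder, model.decoder, model.text_encoder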
16
api/src/builds/plbert.py
Normal file
@@ -0,0 +1,16 @@
# https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py
from transformers import AlbertConfig, AlbertModel


class CustomAlbert(AlbertModel):
    def forward(self, *args, **kwargs):
        # Call the original forward method
        outputs = super().forward(*args, **kwargs)
        # Only return the last_hidden_state
        return outputs.last_hidden_state


def load_plbert():
    plbert_config = {'vocab_size': 178, 'hidden_size': 768, 'num_attention_heads': 12, 'intermediate_size': 2048, 'max_position_embeddings': 512, 'num_hidden_layers': 12, 'dropout': 0.1}
    albert_base_configuration = AlbertConfig(**plbert_config)
    bert = CustomAlbert(albert_base_configuration)
    return bert
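A minimal usage sketch (the 178-symbol vocab size suggests phoneme-token ids; the shapes are illustrative assumptions):

    bert = load_plbert()
    ids = torch.randint(0, 178, (1, 32))  # hypothetical batch of token ids
    hidden = bert(input_ids=ids)          # (1, 32, 768): last_hidden_state only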
@@ -13,7 +13,7 @@ class Settings(BaseSettings):
     output_dir: str = "output"
     output_dir_size_limit_mb: float = 500.0  # Maximum size of output directory in MB
     default_voice: str = "af"
-    model_dir: str = "/app/Kokoro-82M"  # Base directory for model files
+    model_dir: str = "/app/api/model_files"  # Base directory for model files
     pytorch_model_path: str = "kokoro-v0_19.pth"
     onnx_model_path: str = "kokoro-v0_19.onnx"
     voices_dir: str = "voices"
@@ -1,7 +1,7 @@
 import re
 
-import torch
 import phonemizer
+import torch
 
 
 def split_num(num):
@@ -6,15 +6,15 @@ import sys
 from contextlib import asynccontextmanager
 
 import uvicorn
-from loguru import logger
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
+from loguru import logger
 
 from .core.config import settings
-from .services.tts_model import TTSModel
 from .routers.development import router as dev_router
-from .services.tts_service import TTSService
 from .routers.openai_compatible import router as openai_router
+from .services.tts_model import TTSModel
+from .services.tts_service import TTSService
 
 
 def setup_logger():
@@ -47,7 +47,7 @@ async def lifespan(app: FastAPI):
     # Initialize the main model with warm-up
     voicepack_count = await TTSModel.setup()
     # boundary = "█████╗"*9
-    boundary = "░" * 24
+    boundary = "░" * 2*12
     startup_msg = f"""
 
 {boundary}
@@ -1,18 +1,18 @@
 from typing import List
 
 import numpy as np
+from fastapi import APIRouter, Depends, HTTPException, Response
 from loguru import logger
-from fastapi import Depends, Response, APIRouter, HTTPException
 
 from ..services.audio import AudioService
+from ..services.text_processing import phonemize, tokenize
 from ..services.tts_model import TTSModel
 from ..services.tts_service import TTSService
 from ..structures.text_schemas import (
+    GenerateFromPhonemesRequest,
     PhonemeRequest,
     PhonemeResponse,
-    GenerateFromPhonemesRequest,
 )
-from ..services.text_processing import tokenize, phonemize
 
 router = APIRouter(tags=["text processing"])
 
@@ -1,12 +1,12 @@
-from typing import List, Union, AsyncGenerator
+from typing import AsyncGenerator, List, Union
 
-from loguru import logger
-from fastapi import Header, Depends, Response, APIRouter, HTTPException
+from fastapi import APIRouter, Depends, Header, HTTPException, Response, Request
 from fastapi.responses import StreamingResponse
+from loguru import logger
 
 from ..services.audio import AudioService
-from ..structures.schemas import OpenAISpeechRequest
 from ..services.tts_service import TTSService
+from ..structures.schemas import OpenAISpeechRequest
 
 router = APIRouter(
     tags=["OpenAI Compatible TTS"],
@@ -49,22 +49,35 @@ async def process_voices(
 
 
 async def stream_audio_chunks(
-    tts_service: TTSService, request: OpenAISpeechRequest
+    tts_service: TTSService,
+    request: OpenAISpeechRequest,
+    client_request: Request
 ) -> AsyncGenerator[bytes, None]:
-    """Stream audio chunks as they're generated"""
+    """Stream audio chunks as they're generated with client disconnect handling"""
     voice_to_use = await process_voices(request.voice, tts_service)
-    async for chunk in tts_service.generate_audio_stream(
-        text=request.input,
-        voice=voice_to_use,
-        speed=request.speed,
-        output_format=request.response_format,
-    ):
-        yield chunk
 
+    try:
+        async for chunk in tts_service.generate_audio_stream(
+            text=request.input,
+            voice=voice_to_use,
+            speed=request.speed,
+            output_format=request.response_format,
+        ):
+            # Check if client is still connected
+            if await client_request.is_disconnected():
+                logger.info("Client disconnected, stopping audio generation")
+                break
+            yield chunk
+    except Exception as e:
+        logger.error(f"Error in audio streaming: {str(e)}")
+        # Let the exception propagate to trigger cleanup
+        raise
 
 
 @router.post("/audio/speech")
 async def create_speech(
     request: OpenAISpeechRequest,
+    client_request: Request,
     tts_service: TTSService = Depends(get_tts_service),
     x_raw_response: str = Header(None, alias="x-raw-response"),
 ):
@@ -87,7 +100,7 @@ async def create_speech(
     if request.stream:
         # Stream audio chunks as they're generated
         return StreamingResponse(
-            stream_audio_chunks(tts_service, request),
+            stream_audio_chunks(tts_service, request, client_request),
            media_type=content_type,
             headers={
                 "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
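For reference, a client can consume the streaming endpoint roughly like this (a hedged sketch: the "/v1" mount prefix, voice id, and exact payload fields are assumptions based on the schema above; the port matches the docker-compose files below):

    import httpx

    payload = {
        "input": "Hello world",
        "voice": "af",
        "response_format": "mp3",
        "stream": True,
    }
    with httpx.stream("POST", "http://localhost:8880/v1/audio/speech", json=payload, timeout=None) as r:
        with open("speech.mp3", "wb") as out:
            for chunk in r.iter_bytes():
                out.write(chunk)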
@@ -3,8 +3,8 @@
 from io import BytesIO
 
 import numpy as np
-import soundfile as sf
 import scipy.io.wavfile as wavfile
+import soundfile as sf
 from loguru import logger
 
 from ..core.config import settings
@@ -22,20 +22,19 @@ class AudioNormalizer:
     def normalize(
         self, audio_data: np.ndarray, is_last_chunk: bool = False
     ) -> np.ndarray:
-        """Normalize audio data to int16 range and trim chunk boundaries"""
-        # Convert to float32 if not already
+        """Convert audio data to int16 range and trim chunk boundaries"""
         if len(audio_data) == 0:
             raise ValueError("Audio data cannot be empty")
 
+        # Simple float32 to int16 conversion
         audio_float = audio_data.astype(np.float32)
 
-        # Normalize to [-1, 1] range first
-        if np.max(np.abs(audio_float)) > 0:
-            audio_float = audio_float / np.max(np.abs(audio_float))
-
-        # Trim end of non-final chunks to reduce gaps
+        # Trim for non-final chunks
         if not is_last_chunk and len(audio_float) > self.samples_to_trim:
-            audio_float = audio_float[: -self.samples_to_trim]
-
-        # Scale to int16 range
-        return (audio_float * self.int16_max).astype(np.int16)
+            audio_float = audio_float[:-self.samples_to_trim]
 
+        # Direct scaling like the non-streaming version
+        return (audio_float * 32767).astype(np.int16)
 
 
 class AudioService:
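Dropping the per-chunk peak normalization matters for streaming: normalizing each chunk to its own max would make quiet chunks as loud as loud ones, so the new code scales every chunk by the same constant instead (a toy illustration, assuming int16_max was 32767):

    import numpy as np

    quiet = np.full(100, 0.1, dtype=np.float32)
    loud = np.full(100, 0.9, dtype=np.float32)
    # Old behavior: both peak-normalized to 1.0 before scaling -> identical int16 amplitude (32767)
    # New behavior: one fixed scale preserves relative loudness between chunks
    (quiet * 32767).astype(np.int16)[0]  # 3276
    (loud * 32767).astype(np.int16)[0]   # 29490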
@@ -1,6 +1,6 @@
 from .normalizer import normalize_text
 from .phonemizer import EspeakBackend, PhonemizerBackend, phonemize
-from .vocabulary import VOCAB, tokenize, decode_tokens
+from .vocabulary import VOCAB, decode_tokens, tokenize
 
 __all__ = [
     "normalize_text",
@@ -5,19 +5,20 @@ import torch
 from loguru import logger
 from onnxruntime import (
     ExecutionMode,
-    SessionOptions,
-    InferenceSession,
     GraphOptimizationLevel,
+    InferenceSession,
+    SessionOptions,
 )
 
-from .tts_base import TTSBaseModel
 from ..core.config import settings
-from .text_processing import tokenize, phonemize
+from .text_processing import phonemize, tokenize
+from .tts_base import TTSBaseModel
 
 
 class TTSCPUModel(TTSBaseModel):
     _instance = None
     _onnx_session = None
+    _device = "cpu"
 
     @classmethod
     def get_instance(cls):
@ -30,64 +31,65 @@ class TTSCPUModel(TTSBaseModel):
|
|||
def initialize(cls, model_dir: str, model_path: str = None):
|
||||
"""Initialize ONNX model for CPU inference"""
|
||||
if cls._onnx_session is None:
|
||||
# Try loading ONNX model
|
||||
onnx_path = os.path.join(model_dir, settings.onnx_model_path)
|
||||
if os.path.exists(onnx_path):
|
||||
try:
|
||||
# Try loading ONNX model
|
||||
onnx_path = os.path.join(model_dir, settings.onnx_model_path)
|
||||
if not os.path.exists(onnx_path):
|
||||
logger.error(f"ONNX model not found at {onnx_path}")
|
||||
return None
|
||||
|
||||
logger.info(f"Loading ONNX model from {onnx_path}")
|
||||
else:
|
||||
logger.error(f"ONNX model not found at {onnx_path}")
|
||||
return None
|
||||
|
||||
if not onnx_path:
|
||||
return None
|
||||
# Configure ONNX session for optimal performance
|
||||
session_options = SessionOptions()
|
||||
|
||||
# Configure ONNX session for optimal performance
|
||||
session_options = SessionOptions()
|
||||
# Set optimization level
|
||||
if settings.onnx_optimization_level == "all":
|
||||
session_options.graph_optimization_level = (
|
||||
GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
)
|
||||
elif settings.onnx_optimization_level == "basic":
|
||||
session_options.graph_optimization_level = (
|
||||
GraphOptimizationLevel.ORT_ENABLE_BASIC
|
||||
)
|
||||
else:
|
||||
session_options.graph_optimization_level = (
|
||||
GraphOptimizationLevel.ORT_DISABLE_ALL
|
||||
)
|
||||
|
||||
# Set optimization level
|
||||
if settings.onnx_optimization_level == "all":
|
||||
session_options.graph_optimization_level = (
|
||||
GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
)
|
||||
elif settings.onnx_optimization_level == "basic":
|
||||
session_options.graph_optimization_level = (
|
||||
GraphOptimizationLevel.ORT_ENABLE_BASIC
|
||||
)
|
||||
else:
|
||||
session_options.graph_optimization_level = (
|
||||
GraphOptimizationLevel.ORT_DISABLE_ALL
|
||||
# Configure threading
|
||||
session_options.intra_op_num_threads = settings.onnx_num_threads
|
||||
session_options.inter_op_num_threads = settings.onnx_inter_op_threads
|
||||
|
||||
# Set execution mode
|
||||
session_options.execution_mode = (
|
||||
ExecutionMode.ORT_PARALLEL
|
||||
if settings.onnx_execution_mode == "parallel"
|
||||
else ExecutionMode.ORT_SEQUENTIAL
|
||||
)
|
||||
|
||||
# Configure threading
|
||||
session_options.intra_op_num_threads = settings.onnx_num_threads
|
||||
session_options.inter_op_num_threads = settings.onnx_inter_op_threads
|
||||
# Enable/disable memory pattern optimization
|
||||
session_options.enable_mem_pattern = settings.onnx_memory_pattern
|
||||
|
||||
# Set execution mode
|
||||
session_options.execution_mode = (
|
||||
ExecutionMode.ORT_PARALLEL
|
||||
if settings.onnx_execution_mode == "parallel"
|
||||
else ExecutionMode.ORT_SEQUENTIAL
|
||||
)
|
||||
|
||||
# Enable/disable memory pattern optimization
|
||||
session_options.enable_mem_pattern = settings.onnx_memory_pattern
|
||||
|
||||
# Configure CPU provider options
|
||||
provider_options = {
|
||||
"CPUExecutionProvider": {
|
||||
"arena_extend_strategy": settings.onnx_arena_extend_strategy,
|
||||
"cpu_memory_arena_cfg": "cpu:0",
|
||||
# Configure CPU provider options
|
||||
provider_options = {
|
||||
"CPUExecutionProvider": {
|
||||
"arena_extend_strategy": settings.onnx_arena_extend_strategy,
|
||||
"cpu_memory_arena_cfg": "cpu:0",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
session = InferenceSession(
|
||||
onnx_path,
|
||||
sess_options=session_options,
|
||||
providers=["CPUExecutionProvider"],
|
||||
provider_options=[provider_options],
|
||||
)
|
||||
cls._onnx_session = session
|
||||
return session
|
||||
session = InferenceSession(
|
||||
onnx_path,
|
||||
sess_options=session_options,
|
||||
providers=["CPUExecutionProvider"],
|
||||
provider_options=[provider_options],
|
||||
)
|
||||
cls._onnx_session = session
|
||||
return session
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize ONNX model: {e}")
|
||||
return None
|
||||
return cls._onnx_session
|
||||
|
||||
@classmethod
|
||||
|
|
|
@@ -3,12 +3,12 @@ import time
 
 import numpy as np
 import torch
+from builds.models import build_model
 from loguru import logger
-from models import build_model
 
-from .tts_base import TTSBaseModel
 from ..core.config import settings
-from .text_processing import tokenize, phonemize
+from .text_processing import phonemize, tokenize
+from .tts_base import TTSBaseModel
 
 
 # @torch.no_grad()
@@ -2,19 +2,19 @@ import io
 import os
 import re
 import time
-from typing import List, Tuple, Optional
 from functools import lru_cache
+from typing import List, Optional, Tuple
 
-import numpy as np
-import torch
 import aiofiles.os
+import numpy as np
 import scipy.io.wavfile as wavfile
+import torch
 from loguru import logger
 
-from .audio import AudioService, AudioNormalizer
-from .tts_model import TTSModel
 from ..core.config import settings
+from .audio import AudioNormalizer, AudioService
 from .text_processing import chunker, normalize_text
+from .tts_model import TTSModel
 
 
 class TTSService:
@@ -4,9 +4,9 @@ from typing import List, Tuple
 import torch
 from loguru import logger
 
+from ..core.config import settings
 from .tts_model import TTSModel
 from .tts_service import TTSService
-from ..core.config import settings
 
 
 class WarmupService:
@@ -1,7 +1,7 @@
 from enum import Enum
-from typing import List, Union, Literal
+from typing import List, Literal, Union
 
-from pydantic import Field, BaseModel
+from pydantic import BaseModel, Field
 
 
 class VoiceCombineRequest(BaseModel):
@@ -1,4 +1,4 @@
-from pydantic import Field, BaseModel
+from pydantic import BaseModel, Field
 
 
 class PhonemeRequest(BaseModel):
BIN
api/src/voices/af_irulan.pt
Normal file
Binary file not shown.
|
@@ -1,11 +1,11 @@
 import os
-import sys
 import shutil
-from unittest.mock import Mock, MagicMock, patch
+import sys
+from unittest.mock import MagicMock, Mock, patch
 
+import aiofiles.threadpool
 import numpy as np
 import pytest
-import aiofiles.threadpool
 
 
 def cleanup_mock_dirs():
@@ -32,77 +32,7 @@ def cleanup():
     cleanup_mock_dirs()
 
 
-# Create mock torch module
-mock_torch = Mock()
-mock_torch.cuda = Mock()
-mock_torch.cuda.is_available = Mock(return_value=False)
-
-
-# Create a mock tensor class that supports basic operations
-class MockTensor:
-    def __init__(self, data):
-        self.data = data
-        if isinstance(data, (list, tuple)):
-            self.shape = [len(data)]
-        elif isinstance(data, MockTensor):
-            self.shape = data.shape
-        else:
-            self.shape = getattr(data, "shape", [1])
-
-    def __getitem__(self, idx):
-        if isinstance(self.data, (list, tuple)):
-            if isinstance(idx, slice):
-                return MockTensor(self.data[idx])
-            return self.data[idx]
-        return self
-
-    def max(self):
-        if isinstance(self.data, (list, tuple)):
-            max_val = max(self.data)
-            return MockTensor(max_val)
-        return 5  # Default for testing
-
-    def item(self):
-        if isinstance(self.data, (list, tuple)):
-            return max(self.data)
-        if isinstance(self.data, (int, float)):
-            return self.data
-        return 5  # Default for testing
-
-    def cuda(self):
-        """Support cuda conversion"""
-        return self
-
-    def any(self):
-        if isinstance(self.data, (list, tuple)):
-            return any(self.data)
-        return False
-
-    def all(self):
-        if isinstance(self.data, (list, tuple)):
-            return all(self.data)
-        return True
-
-    def unsqueeze(self, dim):
-        return self
-
-    def expand(self, *args):
-        return self
-
-    def type_as(self, other):
-        return self
-
-
-# Add tensor operations to mock torch
-mock_torch.tensor = lambda x: MockTensor(x)
-mock_torch.zeros = lambda *args: MockTensor(
-    [0] * (args[0] if isinstance(args[0], int) else args[0][0])
-)
-mock_torch.arange = lambda x: MockTensor(list(range(x)))
-mock_torch.gt = lambda x, y: MockTensor([False] * x.shape[0])
-
-# Mock modules before they're imported
-sys.modules["torch"] = mock_torch
-sys.modules["transformers"] = Mock()
-sys.modules["phonemizer"] = Mock()
-sys.modules["models"] = Mock()
|
@@ -5,7 +5,7 @@ from unittest.mock import patch
 import numpy as np
 import pytest
 
-from api.src.services.audio import AudioService, AudioNormalizer
+from api.src.services.audio import AudioNormalizer, AudioService
 
 
 @pytest.fixture(autouse=True)
@@ -1,10 +1,10 @@
 import asyncio
-from unittest.mock import Mock, AsyncMock
+from unittest.mock import AsyncMock, Mock
 
 import pytest
 import pytest_asyncio
-from httpx import AsyncClient
 from fastapi.testclient import TestClient
+from httpx import AsyncClient
 
 from ..src.main import app
 
@@ -7,8 +7,8 @@ import pytest
 import pytest_asyncio
 from httpx import AsyncClient
 
-from .conftest import MockTTSModel
 from ..src.main import app
+from .conftest import MockTTSModel
 
 
 @pytest_asyncio.fixture
@@ -1,15 +1,15 @@
 """Tests for TTS model implementations"""
 
 import os
-from unittest.mock import MagicMock, patch
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import numpy as np
-import torch
 import pytest
+import torch
 
+from api.src.services.tts_base import TTSBaseModel
 from api.src.services.tts_cpu import TTSCPUModel
 from api.src.services.tts_gpu import TTSGPUModel, length_to_mask
-from api.src.services.tts_base import TTSBaseModel
 
 
 # Base Model Tests
@@ -27,16 +27,30 @@ def test_get_device_error():
 @patch("os.listdir")
 @patch("torch.load")
 @patch("torch.save")
+@patch("api.src.services.tts_base.settings")
+@patch("api.src.services.warmup.WarmupService")
 async def test_setup_cuda_available(
-    mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available
+    mock_warmup_class, mock_settings, mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available
 ):
     """Test setup with CUDA available"""
     TTSBaseModel._device = None
-    mock_cuda_available.return_value = True
+    # Mock CUDA as unavailable since we're using CPU PyTorch
+    mock_cuda_available.return_value = False
     mock_exists.return_value = True
     mock_load.return_value = torch.zeros(1)
     mock_listdir.return_value = ["voice1.pt", "voice2.pt"]
     mock_join.return_value = "/mocked/path"
 
+    # Configure mock settings
+    mock_settings.model_dir = "/mock/model/dir"
+    mock_settings.onnx_model_path = "model.onnx"
+    mock_settings.voices_dir = "voices"
+
+    # Configure mock warmup service
+    mock_warmup = MagicMock()
+    mock_warmup.load_voices.return_value = [torch.zeros(1)]
+    mock_warmup.warmup_voices = AsyncMock()
+    mock_warmup_class.return_value = mock_warmup
+
     # Create mock model
     mock_model = MagicMock()
@@ -49,7 +63,7 @@ async def test_setup_cuda_available(
     TTSBaseModel._instance = mock_model
 
     voice_count = await TTSBaseModel.setup()
-    assert TTSBaseModel._device == "cuda"
+    assert TTSBaseModel._device == "cpu"
     assert voice_count == 2
 
 
@@ -60,8 +74,10 @@ async def test_setup_cuda_available(
 @patch("os.listdir")
 @patch("torch.load")
 @patch("torch.save")
+@patch("api.src.services.tts_base.settings")
+@patch("api.src.services.warmup.WarmupService")
 async def test_setup_cuda_unavailable(
-    mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available
+    mock_warmup_class, mock_settings, mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available
 ):
     """Test setup with CUDA unavailable"""
     TTSBaseModel._device = None
@@ -70,6 +86,17 @@ async def test_setup_cuda_unavailable(
     mock_load.return_value = torch.zeros(1)
     mock_listdir.return_value = ["voice1.pt", "voice2.pt"]
     mock_join.return_value = "/mocked/path"
 
+    # Configure mock settings
+    mock_settings.model_dir = "/mock/model/dir"
+    mock_settings.onnx_model_path = "model.onnx"
+    mock_settings.voices_dir = "voices"
+
+    # Configure mock warmup service
+    mock_warmup = MagicMock()
+    mock_warmup.load_voices.return_value = [torch.zeros(1)]
+    mock_warmup.warmup_voices = AsyncMock()
+    mock_warmup_class.return_value = mock_warmup
+
     # Create mock model
     mock_model = MagicMock()
@@ -4,8 +4,8 @@ import os
 from unittest.mock import MagicMock, call, patch
 
 import numpy as np
-import torch
 import pytest
+import torch
 from onnxruntime import InferenceSession
 
 from api.src.core.config import settings
@ -1,79 +0,0 @@
|
|||
name: kokoro-fastapi
|
||||
services:
|
||||
model-fetcher:
|
||||
image: datamachines/git-lfs:latest
|
||||
volumes:
|
||||
- ./Kokoro-82M:/app/Kokoro-82M
|
||||
working_dir: /app/Kokoro-82M
|
||||
command: >
|
||||
sh -c "
|
||||
mkdir -p /app/Kokoro-82M;
|
||||
cd /app/Kokoro-82M;
|
||||
rm -f .git/index.lock;
|
||||
if [ -z \"$(ls -A .)\" ]; then
|
||||
git clone https://huggingface.co/hexgrad/Kokoro-82M .
|
||||
touch .cloned;
|
||||
else
|
||||
rm -f .git/index.lock && \
|
||||
git checkout main && \
|
||||
git pull origin main && \
|
||||
touch .cloned;
|
||||
fi;
|
||||
tail -f /dev/null
|
||||
"
|
||||
healthcheck:
|
||||
test: ["CMD", "test", "-f", ".cloned"]
|
||||
interval: 5s
|
||||
timeout: 2s
|
||||
retries: 300
|
||||
start_period: 1s
|
||||
|
||||
kokoro-tts:
|
||||
image: ghcr.io/remsky/kokoro-fastapi-cpu:v0.0.5post1
|
||||
# Uncomment below (and comment out above) to build from source instead of using the released image
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.cpu
|
||||
volumes:
|
||||
- ./api/src:/app/api/src
|
||||
- ./Kokoro-82M:/app/Kokoro-82M
|
||||
ports:
|
||||
- "8880:8880"
|
||||
environment:
|
||||
- PYTHONPATH=/app:/app/Kokoro-82M
|
||||
# ONNX Optimization Settings for vectorized operations
|
||||
- ONNX_NUM_THREADS=8 # Maximize core usage for vectorized ops
|
||||
- ONNX_INTER_OP_THREADS=4 # Higher inter-op for parallel matrix operations
|
||||
- ONNX_EXECUTION_MODE=parallel
|
||||
- ONNX_OPTIMIZATION_LEVEL=all
|
||||
- ONNX_MEMORY_PATTERN=true
|
||||
- ONNX_ARENA_EXTEND_STRATEGY=kNextPowerOfTwo
|
||||
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8880/health"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 30
|
||||
start_period: 30s
|
||||
depends_on:
|
||||
model-fetcher:
|
||||
condition: service_healthy
|
||||
|
||||
|
||||
# Gradio UI service [Comment out everything below if you don't need it]
|
||||
gradio-ui:
|
||||
image: ghcr.io/remsky/kokoro-fastapi-ui:v0.0.5post1
|
||||
# Uncomment below (and comment out above) to build from source instead of using the released image
|
||||
# build:
|
||||
# context: ./ui
|
||||
ports:
|
||||
- "7860:7860"
|
||||
volumes:
|
||||
- ./ui/data:/app/ui/data
|
||||
- ./ui/app.py:/app/app.py # Mount app.py for hot reload
|
||||
environment:
|
||||
- GRADIO_WATCH=True # Enable hot reloading
|
||||
- PYTHONUNBUFFERED=1 # Ensure Python output is not buffered
|
||||
depends_on:
|
||||
kokoro-tts:
|
||||
condition: service_healthy
|
|
@ -1,80 +0,0 @@
|
|||
name: kokoro-fastapi
|
||||
services:
|
||||
model-fetcher:
|
||||
image: datamachines/git-lfs:latest
|
||||
environment:
|
||||
- SKIP_MODEL_FETCH=${SKIP_MODEL_FETCH:-false}
|
||||
volumes:
|
||||
- ./Kokoro-82M:/app/Kokoro-82M
|
||||
working_dir: /app/Kokoro-82M
|
||||
command: >
|
||||
sh -c "
|
||||
if [ \"$$SKIP_MODEL_FETCH\" = \"true\" ]; then
|
||||
echo 'Skipping model fetch...' && touch .cloned;
|
||||
else
|
||||
rm -f .git/index.lock;
|
||||
if [ -z \"$(ls -A .)\" ]; then
|
||||
git clone https://huggingface.co/hexgrad/Kokoro-82M .
|
||||
touch .cloned;
|
||||
else
|
||||
rm -f .git/index.lock && \
|
||||
git checkout main && \
|
||||
git pull origin main && \
|
||||
touch .cloned;
|
||||
fi;
|
||||
fi;
|
||||
tail -f /dev/null
|
||||
"
|
||||
healthcheck:
|
||||
test: ["CMD", "test", "-f", ".cloned"]
|
||||
interval: 5s
|
||||
timeout: 2s
|
||||
retries: 300
|
||||
start_period: 1s
|
||||
|
||||
kokoro-tts:
|
||||
image: ghcr.io/remsky/kokoro-fastapi-gpu:v0.0.5post1
|
||||
# Uncomment below (and comment out above) to build from source instead of using the released image
|
||||
# build:
|
||||
# context: .
|
||||
volumes:
|
||||
- ./api/src:/app/api/src
|
||||
- ./Kokoro-82M:/app/Kokoro-82M
|
||||
ports:
|
||||
- "8880:8880"
|
||||
environment:
|
||||
- PYTHONPATH=/app:/app/Kokoro-82M
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8880/health"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 30
|
||||
start_period: 30s
|
||||
depends_on:
|
||||
model-fetcher:
|
||||
condition: service_healthy
|
||||
|
||||
# Gradio UI service [Comment out everything below if you don't need it]
|
||||
gradio-ui:
|
||||
image: ghcr.io/remsky/kokoro-fastapi-ui:v0.0.5post1
|
||||
# Uncomment below (and comment out above) to build from source instead of using the released image
|
||||
# build:
|
||||
# context: ./ui
|
||||
ports:
|
||||
- "7860:7860"
|
||||
volumes:
|
||||
- ./ui/data:/app/ui/data
|
||||
- ./ui/app.py:/app/app.py # Mount app.py for hot reload
|
||||
environment:
|
||||
- GRADIO_WATCH=True # Enable hot reloading
|
||||
- PYTHONUNBUFFERED=1 # Ensure Python output is not buffered
|
||||
depends_on:
|
||||
kokoro-tts:
|
||||
condition: service_healthy
|
62
docker/cpu/Dockerfile
Normal file
|
@ -0,0 +1,62 @@
|
|||
FROM python:3.10-slim
|
||||
|
||||
# Install dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
espeak-ng \
|
||||
git \
|
||||
libsndfile1 \
|
||||
curl \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install uv
|
||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
|
||||
|
||||
# Create non-root user
|
||||
RUN useradd -m -u 1000 appuser
|
||||
|
||||
# Create directories and set ownership
|
||||
RUN mkdir -p /app/api/model_files && \
|
||||
mkdir -p /app/api/src/voices && \
|
||||
chown -R appuser:appuser /app
|
||||
|
||||
USER appuser
|
||||
|
||||
# Download and extract models
|
||||
WORKDIR /app/api/model_files
|
||||
RUN curl -L -o model.tar.gz https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.0.1/kokoro-82m-onnx.tar.gz && \
|
||||
tar xzf model.tar.gz && \
|
||||
rm model.tar.gz
|
||||
|
||||
# Download and extract voice models
|
||||
WORKDIR /app/api/src/voices
|
||||
RUN curl -L -o voices.tar.gz https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.0.1/voice-models.tar.gz && \
|
||||
tar xzf voices.tar.gz && \
|
||||
rm voices.tar.gz
|
||||
|
||||
# Switch back to app directory
|
||||
WORKDIR /app
|
||||
|
||||
# Copy dependency files
|
||||
COPY --chown=appuser:appuser pyproject.toml ./pyproject.toml
|
||||
|
||||
# Install dependencies
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv venv && \
|
||||
uv sync --extra cpu --no-install-project
|
||||
|
||||
# Copy project files
|
||||
COPY --chown=appuser:appuser api ./api
|
||||
|
||||
# Install project
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv sync --extra cpu
|
||||
|
||||
# Set environment variables
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
ENV PYTHONPATH=/app:/app/Kokoro-82M
|
||||
ENV PATH="/app/.venv/bin:$PATH"
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
# Run FastAPI server
|
||||
CMD ["uv", "run", "python", "-m", "uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]
|
37
docker/cpu/docker-compose.yml
Normal file
|
@ -0,0 +1,37 @@
|
|||
name: kokoro-tts
|
||||
services:
|
||||
kokoro-tts:
|
||||
image: ghcr.io/remsky/kokoro-fastapi-cpu:latest
|
||||
# Uncomment below (and comment out above) to build from source instead of using the released image
|
||||
# build:
|
||||
# context: ../..
|
||||
# dockerfile: docker/cpu/Dockerfile
|
||||
volumes:
|
||||
- ../../api/src:/app/api/src
|
||||
ports:
|
||||
- "8880:8880"
|
||||
environment:
|
||||
- PYTHONPATH=/app:/app/Kokoro-82M
|
||||
# ONNX Optimization Settings for vectorized operations
|
||||
- ONNX_NUM_THREADS=8 # Maximize core usage for vectorized ops
|
||||
- ONNX_INTER_OP_THREADS=4 # Higher inter-op for parallel matrix operations
|
||||
- ONNX_EXECUTION_MODE=parallel
|
||||
- ONNX_OPTIMIZATION_LEVEL=all
|
||||
- ONNX_MEMORY_PATTERN=true
|
||||
- ONNX_ARENA_EXTEND_STRATEGY=kNextPowerOfTwo
|
||||
|
||||
# Gradio UI service [Comment out everything below if you don't need it]
|
||||
gradio-ui:
|
||||
image: ghcr.io/remsky/kokoro-fastapi:latest-ui
|
||||
# Uncomment below (and comment out above) to build from source instead of using the released image
|
||||
# build:
|
||||
# context: ../../ui
|
||||
ports:
|
||||
- "7860:7860"
|
||||
volumes:
|
||||
- ../../ui/data:/app/ui/data
|
||||
- ../../ui/app.py:/app/app.py # Mount app.py for hot reload
|
||||
environment:
|
||||
- GRADIO_WATCH=True # Enable hot reloading
|
||||
- PYTHONUNBUFFERED=1 # Ensure Python output is not buffered
|
||||
- DISABLE_LOCAL_SAVING=false # Set to 'true' to disable local saving and hide file view
|
22
docker/cpu/pyproject.toml
Normal file
|
@ -0,0 +1,22 @@
|
|||
[project]
|
||||
name = "kokoro-fastapi-cpu"
|
||||
version = "0.1.0"
|
||||
description = "FastAPI TTS Service - CPU Version"
|
||||
readme = "../README.md"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
# Core ML/DL for CPU
|
||||
"torch>=2.5.1",
|
||||
"transformers==4.47.1",
|
||||
]
|
||||
|
||||
[tool.uv.workspace]
|
||||
members = ["../shared"]
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = { index = "pytorch-cpu" }
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cpu"
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
explicit = true
|
229
docker/cpu/requirements.lock
Normal file
|
@ -0,0 +1,229 @@
|
|||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile pyproject.toml ../shared/pyproject.toml --output-file requirements.lock
|
||||
aiofiles==23.2.1
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
anyio==4.8.0
|
||||
# via starlette
|
||||
attrs==24.3.0
|
||||
# via
|
||||
# clldutils
|
||||
# csvw
|
||||
# jsonschema
|
||||
# phonemizer
|
||||
# referencing
|
||||
babel==2.16.0
|
||||
# via csvw
|
||||
certifi==2024.12.14
|
||||
# via requests
|
||||
cffi==1.17.1
|
||||
# via soundfile
|
||||
charset-normalizer==3.4.1
|
||||
# via requests
|
||||
click==8.1.8
|
||||
# via
|
||||
# kokoro-fastapi (../shared/pyproject.toml)
|
||||
# uvicorn
|
||||
clldutils==3.21.0
|
||||
# via segments
|
||||
colorama==0.4.6
|
||||
# via
|
||||
# click
|
||||
# colorlog
|
||||
# csvw
|
||||
# loguru
|
||||
# tqdm
|
||||
coloredlogs==15.0.1
|
||||
# via onnxruntime
|
||||
colorlog==6.9.0
|
||||
# via clldutils
|
||||
csvw==3.5.1
|
||||
# via segments
|
||||
dlinfo==1.2.1
|
||||
# via phonemizer
|
||||
exceptiongroup==1.2.2
|
||||
# via anyio
|
||||
fastapi==0.115.6
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
filelock==3.16.1
|
||||
# via
|
||||
# huggingface-hub
|
||||
# torch
|
||||
# transformers
|
||||
flatbuffers==24.12.23
|
||||
# via onnxruntime
|
||||
fsspec==2024.12.0
|
||||
# via
|
||||
# huggingface-hub
|
||||
# torch
|
||||
greenlet==3.1.1
|
||||
# via sqlalchemy
|
||||
h11==0.14.0
|
||||
# via uvicorn
|
||||
huggingface-hub==0.27.1
|
||||
# via
|
||||
# tokenizers
|
||||
# transformers
|
||||
humanfriendly==10.0
|
||||
# via coloredlogs
|
||||
idna==3.10
|
||||
# via
|
||||
# anyio
|
||||
# requests
|
||||
isodate==0.7.2
|
||||
# via
|
||||
# csvw
|
||||
# rdflib
|
||||
jinja2==3.1.5
|
||||
# via torch
|
||||
joblib==1.4.2
|
||||
# via phonemizer
|
||||
jsonschema==4.23.0
|
||||
# via csvw
|
||||
jsonschema-specifications==2024.10.1
|
||||
# via jsonschema
|
||||
language-tags==1.2.0
|
||||
# via csvw
|
||||
loguru==0.7.3
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
lxml==5.3.0
|
||||
# via clldutils
|
||||
markdown==3.7
|
||||
# via clldutils
|
||||
markupsafe==3.0.2
|
||||
# via
|
||||
# clldutils
|
||||
# jinja2
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
munch==4.0.0
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
networkx==3.4.2
|
||||
# via torch
|
||||
numpy==2.2.1
|
||||
# via
|
||||
# kokoro-fastapi (../shared/pyproject.toml)
|
||||
# onnxruntime
|
||||
# scipy
|
||||
# soundfile
|
||||
# transformers
|
||||
onnxruntime==1.20.1
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
packaging==24.2
|
||||
# via
|
||||
# huggingface-hub
|
||||
# onnxruntime
|
||||
# transformers
|
||||
phonemizer==3.3.0
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
protobuf==5.29.3
|
||||
# via onnxruntime
|
||||
pycparser==2.22
|
||||
# via cffi
|
||||
pydantic==2.10.4
|
||||
# via
|
||||
# kokoro-fastapi (../shared/pyproject.toml)
|
||||
# fastapi
|
||||
# pydantic-settings
|
||||
pydantic-core==2.27.2
|
||||
# via pydantic
|
||||
pydantic-settings==2.7.0
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
pylatexenc==2.10
|
||||
# via clldutils
|
||||
pyparsing==3.2.1
|
||||
# via rdflib
|
||||
pyreadline3==3.5.4
|
||||
# via humanfriendly
|
||||
python-dateutil==2.9.0.post0
|
||||
# via
|
||||
# clldutils
|
||||
# csvw
|
||||
python-dotenv==1.0.1
|
||||
# via
|
||||
# kokoro-fastapi (../shared/pyproject.toml)
|
||||
# pydantic-settings
|
||||
pyyaml==6.0.2
|
||||
# via
|
||||
# huggingface-hub
|
||||
# transformers
|
||||
rdflib==7.1.2
|
||||
# via csvw
|
||||
referencing==0.35.1
|
||||
# via
|
||||
# jsonschema
|
||||
# jsonschema-specifications
|
||||
regex==2024.11.6
|
||||
# via
|
||||
# kokoro-fastapi (../shared/pyproject.toml)
|
||||
# segments
|
||||
# tiktoken
|
||||
# transformers
|
||||
requests==2.32.3
|
||||
# via
|
||||
# kokoro-fastapi (../shared/pyproject.toml)
|
||||
# csvw
|
||||
# huggingface-hub
|
||||
# tiktoken
|
||||
# transformers
|
||||
rfc3986==1.5.0
|
||||
# via csvw
|
||||
rpds-py==0.22.3
|
||||
# via
|
||||
# jsonschema
|
||||
# referencing
|
||||
safetensors==0.5.2
|
||||
# via transformers
|
||||
scipy==1.14.1
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
segments==2.2.1
|
||||
# via phonemizer
|
||||
six==1.17.0
|
||||
# via python-dateutil
|
||||
sniffio==1.3.1
|
||||
# via anyio
|
||||
soundfile==0.13.0
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
sqlalchemy==2.0.27
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
starlette==0.41.3
|
||||
# via fastapi
|
||||
sympy==1.13.1
|
||||
# via
|
||||
# onnxruntime
|
||||
# torch
|
||||
tabulate==0.9.0
|
||||
# via clldutils
|
||||
tiktoken==0.8.0
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
tokenizers==0.21.0
|
||||
# via transformers
|
||||
torch==2.5.1+cpu
|
||||
# via kokoro-fastapi-cpu (pyproject.toml)
|
||||
tqdm==4.67.1
|
||||
# via
|
||||
# kokoro-fastapi (../shared/pyproject.toml)
|
||||
# huggingface-hub
|
||||
# transformers
|
||||
transformers==4.47.1
|
||||
# via kokoro-fastapi-cpu (pyproject.toml)
|
||||
typing-extensions==4.12.2
|
||||
# via
|
||||
# anyio
|
||||
# fastapi
|
||||
# huggingface-hub
|
||||
# phonemizer
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# sqlalchemy
|
||||
# torch
|
||||
# uvicorn
|
||||
uritemplate==4.1.1
|
||||
# via csvw
|
||||
urllib3==2.3.0
|
||||
# via requests
|
||||
uvicorn==0.34.0
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
win32-setctime==1.2.0
|
||||
# via loguru
|
1841
docker/cpu/uv.lock
generated
Normal file
File diff suppressed because it is too large
64
docker/gpu/Dockerfile
Normal file
|
@ -0,0 +1,64 @@
|
|||
FROM nvidia/cuda:12.1.0-base-ubuntu22.04
|
||||
|
||||
# Install Python and other dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
python3.10 \
|
||||
python3.10-venv \
|
||||
espeak-ng \
|
||||
git \
|
||||
libsndfile1 \
|
||||
curl \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install uv
|
||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
|
||||
|
||||
# Create non-root user
|
||||
RUN useradd -m -u 1000 appuser
|
||||
|
||||
# Create directories and set ownership
|
||||
RUN mkdir -p /app/api/model_files && \
|
||||
mkdir -p /app/api/src/voices && \
|
||||
chown -R appuser:appuser /app
|
||||
|
||||
USER appuser
|
||||
|
||||
# Download and extract models
|
||||
WORKDIR /app/api/model_files
|
||||
RUN curl -L -o model.tar.gz https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.0.1/kokoro-82m-pytorch.tar.gz && \
|
||||
tar xzf model.tar.gz && \
|
||||
rm model.tar.gz
|
||||
|
||||
# Download and extract voice models
|
||||
WORKDIR /app/api/src/voices
|
||||
RUN curl -L -o voices.tar.gz https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.0.1/voice-models.tar.gz && \
|
||||
tar xzf voices.tar.gz && \
|
||||
rm voices.tar.gz
|
||||
|
||||
# Switch back to app directory
|
||||
WORKDIR /app
|
||||
|
||||
# Copy dependency files
|
||||
COPY --chown=appuser:appuser pyproject.toml ./pyproject.toml
|
||||
|
||||
# Install dependencies
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv venv && \
|
||||
uv sync --extra gpu --no-install-project
|
||||
|
||||
# Copy project files
|
||||
COPY --chown=appuser:appuser api ./api
|
||||
|
||||
# Install project
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv sync --extra gpu
|
||||
|
||||
# Set environment variables
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
ENV PYTHONPATH=/app:/app/Kokoro-82M
|
||||
ENV PATH="/app/.venv/bin:$PATH"
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
# Run FastAPI server
|
||||
CMD ["uv", "run", "python", "-m", "uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]
|
37
docker/gpu/docker-compose.yml
Normal file
|
@ -0,0 +1,37 @@
|
|||
name: kokoro-tts
|
||||
services:
|
||||
kokoro-tts:
|
||||
image: ghcr.io/remsky/kokoro-fastapi-gpu:latest
|
||||
# Uncomment below (and comment out above) to build from source instead of using the released image
|
||||
# build:
|
||||
# context: ../..
|
||||
# dockerfile: docker/gpu/Dockerfile
|
||||
volumes:
|
||||
- ../../api/src:/app/api/src # Mount src for development
|
||||
ports:
|
||||
- "8880:8880"
|
||||
environment:
|
||||
- PYTHONPATH=/app:/app/Kokoro-82M
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
|
||||
# Gradio UI service
|
||||
gradio-ui:
|
||||
image: ghcr.io/remsky/kokoro-fastapi-ui:latest
|
||||
# Uncomment below to build from source instead of using the released image
|
||||
# build:
|
||||
# context: ../../ui
|
||||
ports:
|
||||
- "7860:7860"
|
||||
volumes:
|
||||
- ../../ui/data:/app/ui/data
|
||||
- ../../ui/app.py:/app/app.py # Mount app.py for hot reload
|
||||
environment:
|
||||
- GRADIO_WATCH=1 # Enable hot reloading
|
||||
- PYTHONUNBUFFERED=1 # Ensure Python output is not buffered
|
||||
- DISABLE_LOCAL_SAVING=false # Set to 'true' to disable local saving and hide file view
|
22
docker/gpu/pyproject.toml
Normal file
|
@ -0,0 +1,22 @@
|
|||
[project]
|
||||
name = "kokoro-fastapi-gpu"
|
||||
version = "0.1.0"
|
||||
description = "FastAPI TTS Service - GPU Version"
|
||||
readme = "../README.md"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
# Core ML/DL for GPU
|
||||
"torch==2.5.1+cu121",
|
||||
"transformers==4.47.1",
|
||||
]
|
||||
|
||||
[tool.uv.workspace]
|
||||
members = ["../shared"]
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = { index = "pytorch-cuda" }
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cuda"
|
||||
url = "https://download.pytorch.org/whl/cu121"
|
||||
explicit = true
|
229
docker/gpu/requirements.lock
Normal file
|
@ -0,0 +1,229 @@
|
|||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile pyproject.toml ../shared/pyproject.toml --output-file requirements.lock
|
||||
aiofiles==23.2.1
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
anyio==4.8.0
|
||||
# via starlette
|
||||
attrs==24.3.0
|
||||
# via
|
||||
# clldutils
|
||||
# csvw
|
||||
# jsonschema
|
||||
# phonemizer
|
||||
# referencing
|
||||
babel==2.16.0
|
||||
# via csvw
|
||||
certifi==2024.12.14
|
||||
# via requests
|
||||
cffi==1.17.1
|
||||
# via soundfile
|
||||
charset-normalizer==3.4.1
|
||||
# via requests
|
||||
click==8.1.8
|
||||
# via
|
||||
# kokoro-fastapi (../shared/pyproject.toml)
|
||||
# uvicorn
|
||||
clldutils==3.21.0
|
||||
# via segments
|
||||
colorama==0.4.6
|
||||
# via
|
||||
# click
|
||||
# colorlog
|
||||
# csvw
|
||||
# loguru
|
||||
# tqdm
|
||||
coloredlogs==15.0.1
|
||||
# via onnxruntime
|
||||
colorlog==6.9.0
|
||||
# via clldutils
|
||||
csvw==3.5.1
|
||||
# via segments
|
||||
dlinfo==1.2.1
|
||||
# via phonemizer
|
||||
exceptiongroup==1.2.2
|
||||
# via anyio
|
||||
fastapi==0.115.6
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
filelock==3.16.1
|
||||
# via
|
||||
# huggingface-hub
|
||||
# torch
|
||||
# transformers
|
||||
flatbuffers==24.12.23
|
||||
# via onnxruntime
|
||||
fsspec==2024.12.0
|
||||
# via
|
||||
# huggingface-hub
|
||||
# torch
|
||||
greenlet==3.1.1
|
||||
# via sqlalchemy
|
||||
h11==0.14.0
|
||||
# via uvicorn
|
||||
huggingface-hub==0.27.1
|
||||
# via
|
||||
# tokenizers
|
||||
# transformers
|
||||
humanfriendly==10.0
|
||||
# via coloredlogs
|
||||
idna==3.10
|
||||
# via
|
||||
# anyio
|
||||
# requests
|
||||
isodate==0.7.2
|
||||
# via
|
||||
# csvw
|
||||
# rdflib
|
||||
jinja2==3.1.5
|
||||
# via torch
|
||||
joblib==1.4.2
|
||||
# via phonemizer
|
||||
jsonschema==4.23.0
|
||||
# via csvw
|
||||
jsonschema-specifications==2024.10.1
|
||||
# via jsonschema
|
||||
language-tags==1.2.0
|
||||
# via csvw
|
||||
loguru==0.7.3
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
lxml==5.3.0
|
||||
# via clldutils
|
||||
markdown==3.7
|
||||
# via clldutils
|
||||
markupsafe==3.0.2
|
||||
# via
|
||||
# clldutils
|
||||
# jinja2
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
munch==4.0.0
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
networkx==3.4.2
|
||||
# via torch
|
||||
numpy==2.2.1
|
||||
# via
|
||||
# kokoro-fastapi (../shared/pyproject.toml)
|
||||
# onnxruntime
|
||||
# scipy
|
||||
# soundfile
|
||||
# transformers
|
||||
onnxruntime==1.20.1
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
packaging==24.2
|
||||
# via
|
||||
# huggingface-hub
|
||||
# onnxruntime
|
||||
# transformers
|
||||
phonemizer==3.3.0
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
protobuf==5.29.3
|
||||
# via onnxruntime
|
||||
pycparser==2.22
|
||||
# via cffi
|
||||
pydantic==2.10.4
|
||||
# via
|
||||
# kokoro-fastapi (../shared/pyproject.toml)
|
||||
# fastapi
|
||||
# pydantic-settings
|
||||
pydantic-core==2.27.2
|
||||
# via pydantic
|
||||
pydantic-settings==2.7.0
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
pylatexenc==2.10
|
||||
# via clldutils
|
||||
pyparsing==3.2.1
|
||||
# via rdflib
|
||||
pyreadline3==3.5.4
|
||||
# via humanfriendly
|
||||
python-dateutil==2.9.0.post0
|
||||
# via
|
||||
# clldutils
|
||||
# csvw
|
||||
python-dotenv==1.0.1
|
||||
# via
|
||||
# kokoro-fastapi (../shared/pyproject.toml)
|
||||
# pydantic-settings
|
||||
pyyaml==6.0.2
|
||||
# via
|
||||
# huggingface-hub
|
||||
# transformers
|
||||
rdflib==7.1.2
|
||||
# via csvw
|
||||
referencing==0.35.1
|
||||
# via
|
||||
# jsonschema
|
||||
# jsonschema-specifications
|
||||
regex==2024.11.6
|
||||
# via
|
||||
# kokoro-fastapi (../shared/pyproject.toml)
|
||||
# segments
|
||||
# tiktoken
|
||||
# transformers
|
||||
requests==2.32.3
|
||||
# via
|
||||
# kokoro-fastapi (../shared/pyproject.toml)
|
||||
# csvw
|
||||
# huggingface-hub
|
||||
# tiktoken
|
||||
# transformers
|
||||
rfc3986==1.5.0
|
||||
# via csvw
|
||||
rpds-py==0.22.3
|
||||
# via
|
||||
# jsonschema
|
||||
# referencing
|
||||
safetensors==0.5.2
|
||||
# via transformers
|
||||
scipy==1.14.1
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
segments==2.2.1
|
||||
# via phonemizer
|
||||
six==1.17.0
|
||||
# via python-dateutil
|
||||
sniffio==1.3.1
|
||||
# via anyio
|
||||
soundfile==0.13.0
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
sqlalchemy==2.0.27
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
starlette==0.41.3
|
||||
# via fastapi
|
||||
sympy==1.13.1
|
||||
# via
|
||||
# onnxruntime
|
||||
# torch
|
||||
tabulate==0.9.0
|
||||
# via clldutils
|
||||
tiktoken==0.8.0
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
tokenizers==0.21.0
|
||||
# via transformers
|
||||
torch==2.5.1+cu121
|
||||
# via kokoro-fastapi-gpu (pyproject.toml)
|
||||
tqdm==4.67.1
|
||||
# via
|
||||
# kokoro-fastapi (../shared/pyproject.toml)
|
||||
# huggingface-hub
|
||||
# transformers
|
||||
transformers==4.47.1
|
||||
# via kokoro-fastapi-gpu (pyproject.toml)
|
||||
typing-extensions==4.12.2
|
||||
# via
|
||||
# anyio
|
||||
# fastapi
|
||||
# huggingface-hub
|
||||
# phonemizer
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# sqlalchemy
|
||||
# torch
|
||||
# uvicorn
|
||||
uritemplate==4.1.1
|
||||
# via csvw
|
||||
urllib3==2.3.0
|
||||
# via requests
|
||||
uvicorn==0.34.0
|
||||
# via kokoro-fastapi (../shared/pyproject.toml)
|
||||
win32-setctime==1.2.0
|
||||
# via loguru
|
1914
docker/gpu/uv.lock
generated
Normal file
File diff suppressed because it is too large
44
docker/shared/pyproject.toml
Normal file
|
@ -0,0 +1,44 @@
|
|||
[project]
|
||||
name = "kokoro-fastapi"
|
||||
version = "0.1.0"
|
||||
description = "FastAPI TTS Service"
|
||||
readme = "../README.md"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
# Core dependencies
|
||||
"fastapi==0.115.6",
|
||||
"uvicorn==0.34.0",
|
||||
"click>=8.0.0",
|
||||
"pydantic==2.10.4",
|
||||
"pydantic-settings==2.7.0",
|
||||
"python-dotenv==1.0.1",
|
||||
"sqlalchemy==2.0.27",
|
||||
|
||||
# ML/DL Base
|
||||
"numpy>=1.26.0",
|
||||
"scipy==1.14.1",
|
||||
"onnxruntime==1.20.1",
|
||||
|
||||
# Audio processing
|
||||
"soundfile==0.13.0",
|
||||
|
||||
# Text processing
|
||||
"phonemizer==3.3.0",
|
||||
"regex==2024.11.6",
|
||||
|
||||
# Utilities
|
||||
"aiofiles==23.2.1",
|
||||
"tqdm==4.67.1",
|
||||
"requests==2.32.3",
|
||||
"munch==4.0.0",
|
||||
"tiktoken==0.8.0",
|
||||
"loguru==0.7.3",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
test = [
|
||||
"pytest==8.0.0",
|
||||
"httpx==0.26.0",
|
||||
"pytest-asyncio==0.23.5",
|
||||
"ruff==0.9.1",
|
||||
]
|
|
@@ -9,7 +9,7 @@ sqlalchemy==2.0.27
 
 # ML/DL
 transformers==4.47.1
-numpy==2.2.1
+numpy>=1.26.0  # Version managed by PyTorch dependencies
 scipy==1.14.1
 onnxruntime==1.20.1
 
@@ -21,7 +21,7 @@ phonemizer==3.3.0
 regex==2024.11.6
 
 # Utilities
-aiofiles==24.1.0
+aiofiles==23.2.1  # Last version before Windows path handling changes
 tqdm==4.67.1
 requests==2.32.3
 munch==4.0.0
243
docs/requirements.txt
Normal file
|
@ -0,0 +1,243 @@
# This file was autogenerated by uv via the following command:
#    uv pip compile docs/requirements.in --universal --output-file docs/requirements.txt
aiofiles==23.2.1
    # via -r docs/requirements.in
annotated-types==0.7.0
    # via pydantic
anyio==4.8.0
    # via
    #   httpx
    #   starlette
attrs==24.3.0
    # via
    #   clldutils
    #   csvw
    #   jsonschema
    #   phonemizer
    #   referencing
babel==2.16.0
    # via csvw
certifi==2024.12.14
    # via
    #   httpcore
    #   httpx
    #   requests
cffi==1.17.1
    # via soundfile
charset-normalizer==3.4.1
    # via requests
click==8.1.8
    # via uvicorn
clldutils==3.21.0
    # via segments
colorama==0.4.6
    # via
    #   click
    #   colorlog
    #   csvw
    #   loguru
    #   pytest
    #   tqdm
coloredlogs==15.0.1
    # via onnxruntime
colorlog==6.9.0
    # via clldutils
csvw==3.5.1
    # via segments
dlinfo==1.2.1
    # via phonemizer
exceptiongroup==1.2.2 ; python_full_version < '3.11'
    # via
    #   anyio
    #   pytest
fastapi==0.115.6
    # via -r docs/requirements.in
filelock==3.16.1
    # via
    #   huggingface-hub
    #   transformers
flatbuffers==24.12.23
    # via onnxruntime
fsspec==2024.12.0
    # via huggingface-hub
greenlet==3.1.1 ; platform_machine == 'AMD64' or platform_machine == 'WIN32' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'ppc64le' or platform_machine == 'win32' or platform_machine == 'x86_64'
    # via sqlalchemy
h11==0.14.0
    # via
    #   httpcore
    #   uvicorn
httpcore==1.0.7
    # via httpx
httpx==0.26.0
    # via -r docs/requirements.in
huggingface-hub==0.27.1
    # via
    #   tokenizers
    #   transformers
humanfriendly==10.0
    # via coloredlogs
idna==3.10
    # via
    #   anyio
    #   httpx
    #   requests
iniconfig==2.0.0
    # via pytest
isodate==0.7.2
    # via
    #   csvw
    #   rdflib
joblib==1.4.2
    # via phonemizer
jsonschema==4.23.0
    # via csvw
jsonschema-specifications==2024.10.1
    # via jsonschema
language-tags==1.2.0
    # via csvw
loguru==0.7.3
    # via -r docs/requirements.in
lxml==5.3.0
    # via clldutils
markdown==3.7
    # via clldutils
markupsafe==3.0.2
    # via clldutils
mpmath==1.3.0
    # via sympy
munch==4.0.0
    # via -r docs/requirements.in
numpy==2.2.1
    # via
    #   -r docs/requirements.in
    #   onnxruntime
    #   scipy
    #   soundfile
    #   transformers
onnxruntime==1.20.1
    # via -r docs/requirements.in
packaging==24.2
    # via
    #   huggingface-hub
    #   onnxruntime
    #   pytest
    #   transformers
phonemizer==3.3.0
    # via -r docs/requirements.in
pluggy==1.5.0
    # via pytest
protobuf==5.29.3
    # via onnxruntime
pycparser==2.22
    # via cffi
pydantic==2.10.4
    # via
    #   -r docs/requirements.in
    #   fastapi
    #   pydantic-settings
pydantic-core==2.27.2
    # via pydantic
pydantic-settings==2.7.0
    # via -r docs/requirements.in
pylatexenc==2.10
    # via clldutils
pyparsing==3.2.1
    # via rdflib
pyreadline3==3.5.4 ; sys_platform == 'win32'
    # via humanfriendly
pytest==8.0.0
    # via
    #   -r docs/requirements.in
    #   pytest-asyncio
pytest-asyncio==0.23.5
    # via -r docs/requirements.in
python-dateutil==2.9.0.post0
    # via
    #   clldutils
    #   csvw
python-dotenv==1.0.1
    # via
    #   -r docs/requirements.in
    #   pydantic-settings
pyyaml==6.0.2
    # via
    #   huggingface-hub
    #   transformers
rdflib==7.1.2
    # via csvw
referencing==0.35.1
    # via
    #   jsonschema
    #   jsonschema-specifications
regex==2024.11.6
    # via
    #   -r docs/requirements.in
    #   segments
    #   tiktoken
    #   transformers
requests==2.32.3
    # via
    #   -r docs/requirements.in
    #   csvw
    #   huggingface-hub
    #   tiktoken
    #   transformers
rfc3986==1.5.0
    # via csvw
rpds-py==0.22.3
    # via
    #   jsonschema
    #   referencing
safetensors==0.5.2
    # via transformers
scipy==1.14.1
    # via -r docs/requirements.in
segments==2.2.1
    # via phonemizer
six==1.17.0
    # via python-dateutil
sniffio==1.3.1
    # via
    #   anyio
    #   httpx
soundfile==0.13.0
    # via -r docs/requirements.in
sqlalchemy==2.0.27
    # via -r docs/requirements.in
starlette==0.41.3
    # via fastapi
sympy==1.13.3
    # via onnxruntime
tabulate==0.9.0
    # via clldutils
tiktoken==0.8.0
    # via -r docs/requirements.in
tokenizers==0.21.0
    # via transformers
tomli==2.2.1 ; python_full_version < '3.11'
    # via pytest
tqdm==4.67.1
    # via
    #   -r docs/requirements.in
    #   huggingface-hub
    #   transformers
transformers==4.47.1
    # via -r docs/requirements.in
typing-extensions==4.12.2
    # via
    #   anyio
    #   fastapi
    #   huggingface-hub
    #   phonemizer
    #   pydantic
    #   pydantic-core
    #   sqlalchemy
    #   uvicorn
uritemplate==4.1.1
    # via csvw
urllib3==2.3.0
    # via requests
uvicorn==0.34.0
    # via -r docs/requirements.in
win32-setctime==1.2.0 ; sys_platform == 'win32'
    # via loguru
2 examples/requirements.txt Normal file
@@ -0,0 +1,2 @@
openai>=1.0.0
pyaudio>=0.2.13
Binary file not shown.
@@ -8,7 +8,7 @@ import requests
 import sounddevice as sd
 
 
-def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"):
+def play_streaming_tts(text: str, output_file: str = None, voice: str = "af_sky"):
     """Stream TTS audio and play it back in real-time"""
 
     print("\nStarting TTS stream request...")
90 pyproject.toml Normal file
@@ -0,0 +1,90 @@
[project]
name = "kokoro-fastapi"
version = "0.1.0"
description = "FastAPI TTS Service"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
    # Core dependencies
    "fastapi==0.115.6",
    "uvicorn==0.34.0",
    "click>=8.0.0",
    "pydantic==2.10.4",
    "pydantic-settings==2.7.0",
    "python-dotenv==1.0.1",
    "sqlalchemy==2.0.27",
    # ML/DL Base
    "numpy>=1.26.0",
    "scipy==1.14.1",
    "onnxruntime==1.20.1",
    # Audio processing
    "soundfile==0.13.0",
    # Text processing
    "phonemizer==3.3.0",
    "regex==2024.11.6",
    # Utilities
    "aiofiles==23.2.1",
    "tqdm==4.67.1",
    "requests==2.32.3",
    "munch==4.0.0",
    "tiktoken==0.8.0",
    "loguru==0.7.3",
    "transformers==4.47.1",
    "openai>=1.59.6",
    "ebooklib>=0.18",
    "html2text>=2024.2.26",
]

[project.optional-dependencies]
gpu = [
    "torch==2.5.1+cu121",
]
cpu = [
    "torch==2.5.1+cpu",
]
test = [
    "pytest==8.0.0",
    "pytest-cov==4.1.0",
    "httpx==0.26.0",
    "pytest-asyncio==0.23.5",
    "gradio>=5",
    "openai>=1.59.6",
]

[tool.uv]
conflicts = [
    [
        { extra = "cpu" },
        { extra = "gpu" },
    ],
]

[tool.uv.sources]
torch = [
    { index = "pytorch-cpu", extra = "cpu" },
    { index = "pytorch-cuda", extra = "gpu" },
]

[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true

[[tool.uv.index]]
name = "pytorch-cuda"
url = "https://download.pytorch.org/whl/cu121"
explicit = true

[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[tool.setuptools]
package-dir = {"" = "api/src"}
packages.find = {where = ["api/src"], namespaces = true}

[tool.pytest.ini_options]
testpaths = ["api/tests", "ui/tests"]
python_files = ["test_*.py"]
addopts = "--cov=api --cov=ui --cov-report=term-missing --cov-config=.coveragerc"
asyncio_mode = "strict"
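The [tool.uv] conflicts table above declares the cpu and gpu extras as mutually exclusive, so uv will refuse to resolve both into one environment. A minimal sketch of inspecting that configuration, assuming Python 3.11+ (for the stdlib tomllib) and that pyproject.toml sits in the current directory; illustrative only, not part of the repository:

# Sketch: read pyproject.toml and show the extras and the uv conflict groups.
# Assumes Python 3.11+ and the file in the working directory.
import tomllib

with open("pyproject.toml", "rb") as f:
    config = tomllib.load(f)

print(sorted(config["project"]["optional-dependencies"]))  # ['cpu', 'gpu', 'test']
for group in config["tool"]["uv"]["conflicts"]:
    print("mutually exclusive extras:", [entry["extra"] for entry in group])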
@@ -1,14 +0,0 @@
-# Core dependencies for testing
-fastapi==0.115.6
-uvicorn==0.34.0
-pydantic==2.10.4
-pydantic-settings==2.7.0
-python-dotenv==1.0.1
-sqlalchemy==2.0.27
-
-# Testing
-pytest==8.0.0
-httpx==0.26.0
-pytest-asyncio==0.23.5
-pytest-cov==6.0.0
-gradio==4.19.2
@@ -1,9 +1,3 @@
-import warnings
-
-# Filter out Gradio Dropdown warnings about values not in choices
-#TODO: Warning continues to be displayed, though it isn't breaking anything
-warnings.filterwarnings('ignore', category=UserWarning, module='gradio.components.dropdown')
-
 from lib.interface import create_interface
 
 if __name__ == "__main__":
@@ -36,15 +36,18 @@ def check_api_status() -> Tuple[bool, List[str]]:
 
 
 def text_to_speech(
-    text: str, voice_id: str, format: str, speed: float
+    text: str, voice_id: str | list, format: str, speed: float
 ) -> Optional[str]:
     """Generate speech from text using TTS API."""
     if not text.strip():
         return None
 
+    # Handle multiple voices
+    voice_str = voice_id if isinstance(voice_id, str) else "+".join(voice_id)
+
     # Create output filename
     timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
-    output_filename = f"output_{timestamp}_voice-{voice_id}_speed-{speed}.{format}"
+    output_filename = f"output_{timestamp}_voice-{voice_str}_speed-{speed}.{format}"
     output_path = os.path.join(OUTPUTS_DIR, output_filename)
 
     try:
@@ -53,7 +56,7 @@ def text_to_speech(
             json={
                 "model": "kokoro",
                 "input": text,
-                "voice": voice_id,
+                "voice": voice_str,
                 "response_format": format,
                 "speed": float(speed),
             },
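The voice_str join above is what lets the UI pass either a single voice or a combined voice to the API. A hedged sketch of the resulting request payload; the base URL, the /v1/audio/speech path, and the voice names are assumptions for illustration, not taken from this diff:

# Sketch of the payload produced by the "+"-join convention above.
import requests

def build_voice(voice_id):
    # Strings pass through; lists are joined with "+", as in the diff.
    return voice_id if isinstance(voice_id, str) else "+".join(voice_id)

payload = {
    "model": "kokoro",
    "input": "Hello there!",
    "voice": build_voice(["af_sky", "af_bella"]),  # -> "af_sky+af_bella"
    "response_format": "mp3",
    "speed": 1.0,
}
# Endpoint assumed for illustration:
# response = requests.post("http://localhost:8880/v1/audio/speech",
#                          json=payload, timeout=300)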
@@ -5,54 +5,78 @@ import gradio as gr
 from .. import files
 
 
-def create_input_column() -> Tuple[gr.Column, dict]:
+def create_input_column(disable_local_saving: bool = False) -> Tuple[gr.Column, dict]:
     """Create the input column with text input and file handling."""
     with gr.Column(scale=1) as col:
-        with gr.Tabs() as tabs:
-            # Set first tab as selected by default
-            tabs.selected = 0
-            # Direct Input Tab
-            with gr.TabItem("Direct Input"):
-                text_input = gr.Textbox(
-                    label="Text to speak", placeholder="Enter text here...", lines=4
-                )
-                text_submit = gr.Button("Generate Speech", variant="primary", size="lg")
-
-            # File Input Tab
-            with gr.TabItem("From File"):
-                # Existing files dropdown
-                input_files_list = gr.Dropdown(
-                    label="Select Existing File",
-                    choices=files.list_input_files(),
-                    value=None,
-                )
-
-                # Simple file upload
-                file_upload = gr.File(
-                    label="Upload Text File (.txt)", file_types=[".txt"]
-                )
-
-                file_preview = gr.Textbox(
-                    label="File Content Preview", interactive=False, lines=4
-                )
-
-                with gr.Row():
-                    file_submit = gr.Button(
-                        "Generate Speech", variant="primary", size="lg"
-                    )
-                    clear_files = gr.Button(
-                        "Clear Files", variant="secondary", size="lg"
-                    )
-
-    components = {
-        "tabs": tabs,
-        "text_input": text_input,
-        "file_select": input_files_list,
-        "file_upload": file_upload,
-        "file_preview": file_preview,
-        "text_submit": text_submit,
-        "file_submit": file_submit,
-        "clear_files": clear_files,
-    }
+        text_input = gr.Textbox(
+            label="Text to speak", placeholder="Enter text here...", lines=4
+        )
+
+        # Always show file upload but handle differently based on disable_local_saving
+        file_upload = gr.File(
+            label="Upload Text File (.txt)", file_types=[".txt"]
+        )
+
+        if not disable_local_saving:
+            # Show full interface with tabs when saving is enabled
+            with gr.Tabs() as tabs:
+                # Set first tab as selected by default
+                tabs.selected = 0
+                # Direct Input Tab
+                with gr.TabItem("Direct Input"):
+                    text_submit_direct = gr.Button("Generate Speech", variant="primary", size="lg")
+
+                # File Input Tab
+                with gr.TabItem("From File"):
+                    # Existing files dropdown
+                    input_files_list = gr.Dropdown(
+                        label="Select Existing File",
+                        choices=files.list_input_files(),
+                        value=None,
+                    )
+
+                    file_preview = gr.Textbox(
+                        label="File Content Preview", interactive=False, lines=4
+                    )
+
+                    with gr.Row():
+                        file_submit = gr.Button(
+                            "Generate Speech", variant="primary", size="lg"
+                        )
+                        clear_files = gr.Button(
+                            "Clear Files", variant="secondary", size="lg"
+                        )
+        else:
+            # Just show the generate button when saving is disabled
+            text_submit_direct = gr.Button("Generate Speech", variant="primary", size="lg")
+            tabs = None
+            input_files_list = None
+            file_preview = None
+            file_submit = None
+            clear_files = None
+
+        # Initialize components based on disable_local_saving
+        if disable_local_saving:
+            components = {
+                "tabs": None,
+                "text_input": text_input,
+                "text_submit": text_submit_direct,
+                "file_select": None,
+                "file_upload": file_upload,  # Keep file upload even when saving is disabled
+                "file_preview": None,
+                "file_submit": None,
+                "clear_files": None,
+            }
+        else:
+            components = {
+                "tabs": tabs,
+                "text_input": text_input,
+                "text_submit": text_submit_direct,
+                "file_select": input_files_list,
+                "file_upload": file_upload,
+                "file_preview": file_preview,
+                "file_submit": file_submit,
+                "clear_files": clear_files,
+            }
 
     return col, components
@@ -20,10 +20,10 @@ def create_model_column(voice_ids: Optional[list] = None) -> Tuple[gr.Column, di
 
     voice_input = gr.Dropdown(
         choices=voice_ids,
-        label="Voice",
-        value=voice_ids[0] if voice_ids else None,  # Set default value to first item if available
+        label="Voice(s)",
+        value=voice_ids[0] if voice_ids else None,
         interactive=True,
+        allow_custom_value=True,  # Allow temporary values during updates
+        multiselect=True,
     )
     format_input = gr.Dropdown(
         choices=config.AUDIO_FORMATS, label="Audio Format", value="mp3"
@@ -5,34 +5,43 @@ import gradio as gr
 from .. import files
 
 
-def create_output_column() -> Tuple[gr.Column, dict]:
+def create_output_column(disable_local_saving: bool = False) -> Tuple[gr.Column, dict]:
     """Create the output column with audio player and file list."""
     with gr.Column(scale=1) as col:
         gr.Markdown("### Latest Output")
-        audio_output = gr.Audio(label="Generated Speech", type="filepath")
+        audio_output = gr.Audio(
+            label="Generated Speech",
+            type="filepath",
+            waveform_options={"waveform_color": "#4C87AB"}
+        )
 
-        gr.Markdown("### Generated Files")
-        # Initialize dropdown with empty choices first
+        # Create file-related components with visible=False when local saving is disabled
+        gr.Markdown("### Generated Files", visible=not disable_local_saving)
         output_files = gr.Dropdown(
             label="Previous Outputs",
-            choices=[],
+            choices=files.list_output_files() if not disable_local_saving else [],
             value=None,
             allow_custom_value=True,
             interactive=True,
+            visible=not disable_local_saving,
         )
-        # Then update choices after component creation
-        output_files.choices = files.list_output_files()
 
-        play_btn = gr.Button("▶️ Play Selected", size="sm")
+        play_btn = gr.Button(
+            "▶️ Play Selected",
+            size="sm",
+            visible=not disable_local_saving,
+        )
 
         selected_audio = gr.Audio(
-            label="Selected Output", type="filepath", visible=False
+            label="Selected Output",
+            type="filepath",
+            visible=False,  # Always initially hidden
         )
 
         clear_outputs = gr.Button(
             "⚠️ Delete All Previously Generated Output Audio 🗑️",
             size="sm",
             variant="secondary",
+            visible=not disable_local_saving,
         )
 
         components = {
@@ -11,12 +11,14 @@ def list_input_files() -> List[str]:
 
 
 def list_output_files() -> List[str]:
-    """List all output audio files."""
-    # Just return filenames since paths will be different inside/outside container
-    return [
-        f for f in os.listdir(OUTPUTS_DIR)
+    """List all output audio files, sorted by most recent first."""
+    files = [
+        os.path.join(OUTPUTS_DIR, f)
+        for f in os.listdir(OUTPUTS_DIR)
         if any(f.endswith(ext) for ext in AUDIO_FORMATS)
     ]
+    # Sort files by modification time, most recent first
+    return sorted(files, key=os.path.getmtime, reverse=True)
 
 
 def read_text_file(filename: str) -> str:
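Returning full paths sorted by modification time is what lets the output dropdown show the newest file first. A self-contained sketch of that ordering, using a temporary directory and hypothetical file names:

# Demonstrates newest-first ordering by modification time, mirroring the
# sorted(..., key=os.path.getmtime, reverse=True) call above.
import os
import tempfile
import time

with tempfile.TemporaryDirectory() as outputs_dir:
    older = os.path.join(outputs_dir, "older.mp3")
    open(older, "w").close()
    time.sleep(0.05)  # ensure distinct modification times
    newer = os.path.join(outputs_dir, "newer.mp3")
    open(newer, "w").close()

    listed = sorted(
        (os.path.join(outputs_dir, f) for f in os.listdir(outputs_dir)),
        key=os.path.getmtime,
        reverse=True,
    )
    assert listed[0] == newer  # most recent first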
@@ -1,11 +1,12 @@
 import os
 import shutil
 
 import gradio as gr
 
 from . import api, files
 
 
-def setup_event_handlers(components: dict):
+def setup_event_handlers(components: dict, disable_local_saving: bool = False):
     """Set up all event handlers for the UI components."""
 
     def refresh_status():
@@ -57,27 +58,37 @@
 
     def handle_file_upload(file):
         if file is None:
-            return gr.update(choices=files.list_input_files())
+            return "" if disable_local_saving else [gr.update(choices=files.list_input_files())]
 
         try:
-            # Copy file to inputs directory
-            filename = os.path.basename(file.name)
-            target_path = os.path.join(files.INPUTS_DIR, filename)
+            # Read the file content
+            with open(file.name, 'r', encoding='utf-8') as f:
+                text_content = f.read()
 
-            # Handle duplicate filenames
-            base, ext = os.path.splitext(filename)
-            counter = 1
-            while os.path.exists(target_path):
-                new_name = f"{base}_{counter}{ext}"
-                target_path = os.path.join(files.INPUTS_DIR, new_name)
-                counter += 1
+            if disable_local_saving:
+                # When saving is disabled, put content directly in text input
+                # Normalize whitespace by replacing newlines with spaces
+                normalized_text = ' '.join(text_content.split())
+                return normalized_text
+            else:
+                # When saving is enabled, save file and update dropdown
+                filename = os.path.basename(file.name)
+                target_path = os.path.join(files.INPUTS_DIR, filename)
 
-            shutil.copy2(file.name, target_path)
+                # Handle duplicate filenames
+                base, ext = os.path.splitext(filename)
+                counter = 1
+                while os.path.exists(target_path):
+                    new_name = f"{base}_{counter}{ext}"
+                    target_path = os.path.join(files.INPUTS_DIR, new_name)
+                    counter += 1
+
+                shutil.copy2(file.name, target_path)
+                return [gr.update(choices=files.list_input_files())]
 
         except Exception as e:
-            print(f"Error uploading file: {e}")
-
-            return gr.update(choices=files.list_input_files())
+            print(f"Error handling file: {e}")
+            return "" if disable_local_saving else [gr.update(choices=files.list_input_files())]
 
     def generate_from_text(text, voice, format, speed):
         """Generate speech from direct text input"""
@@ -90,18 +101,20 @@
             gr.Warning("Please enter text in the input box")
             return [None, gr.update(choices=files.list_output_files())]
 
-        files.save_text(text)
+        # Only save text if local saving is enabled
+        if not disable_local_saving:
+            files.save_text(text)
+
         result = api.text_to_speech(text, voice, format, speed)
         if result is None:
             gr.Warning("Failed to generate speech. Please try again.")
             return [None, gr.update(choices=files.list_output_files())]
 
-        # Update list and select the newly generated file
-        output_files = files.list_output_files()
-        last_file = output_files[-1] if output_files else None
         return [
             result,
-            gr.update(choices=output_files, value=last_file),
+            gr.update(
+                choices=files.list_output_files(), value=os.path.basename(result)
+            ),
         ]
 
     def generate_from_file(selected_file, voice, format, speed):
@@ -121,19 +134,16 @@
             gr.Warning("Failed to generate speech. Please try again.")
             return [None, gr.update(choices=files.list_output_files())]
 
-        # Update list and select the newly generated file
-        output_files = files.list_output_files()
-        last_file = output_files[-1] if output_files else None
         return [
             result,
-            gr.update(choices=output_files, value=last_file),
+            gr.update(
+                choices=files.list_output_files(), value=os.path.basename(result)
+            ),
         ]
 
-    def play_selected(filename):
-        if filename:
-            file_path = os.path.join(files.OUTPUTS_DIR, filename)
-            if os.path.exists(file_path):
-                return gr.update(value=file_path, visible=True)
+    def play_selected(file_path):
+        if file_path and os.path.exists(file_path):
+            return gr.update(value=file_path, visible=True)
         return gr.update(visible=False)
 
     def clear_files(voice, format, speed):
@@ -165,45 +175,7 @@
         outputs=[components["model"]["status_btn"], components["model"]["voice"]],
     )
 
-    components["input"]["file_select"].change(
-        fn=handle_file_select,
-        inputs=[components["input"]["file_select"]],
-        outputs=[components["input"]["file_preview"]],
-    )
-
-    components["input"]["file_upload"].upload(
-        fn=handle_file_upload,
-        inputs=[components["input"]["file_upload"]],
-        outputs=[components["input"]["file_select"]],
-    )
-
-    components["output"]["play_btn"].click(
-        fn=play_selected,
-        inputs=[components["output"]["output_files"]],
-        outputs=[components["output"]["selected_audio"]],
-    )
-
-    # Connect clear files button
-    components["input"]["clear_files"].click(
-        fn=clear_files,
-        inputs=[
-            components["model"]["voice"],
-            components["model"]["format"],
-            components["model"]["speed"],
-        ],
-        outputs=[
-            components["input"]["file_select"],
-            components["input"]["file_upload"],
-            components["input"]["file_preview"],
-            components["output"]["audio_output"],
-            components["output"]["output_files"],
-            components["model"]["voice"],
-            components["model"]["format"],
-            components["model"]["speed"],
-        ],
-    )
-
-    # Connect submit buttons for each tab
+    # Connect text submit button (always present)
     components["input"]["text_submit"].click(
         fn=generate_from_text,
         inputs=[
@@ -218,26 +190,70 @@
         ],
     )
 
-    # Connect clear outputs button
-    components["output"]["clear_outputs"].click(
-        fn=clear_outputs,
-        outputs=[
-            components["output"]["audio_output"],
-            components["output"]["output_files"],
-            components["output"]["selected_audio"],
-        ],
-    )
-
-    components["input"]["file_submit"].click(
-        fn=generate_from_file,
-        inputs=[
-            components["input"]["file_select"],
-            components["model"]["voice"],
-            components["model"]["format"],
-            components["model"]["speed"],
-        ],
-        outputs=[
-            components["output"]["audio_output"],
-            components["output"]["output_files"],
-        ],
-    )
+    # Only connect file-related handlers if components exist
+    if components["input"]["file_select"] is not None:
+        components["input"]["file_select"].change(
+            fn=handle_file_select,
+            inputs=[components["input"]["file_select"]],
+            outputs=[components["input"]["file_preview"]],
+        )
+
+    if components["input"]["file_upload"] is not None:
+        # File upload handler - output depends on disable_local_saving
+        components["input"]["file_upload"].upload(
+            fn=handle_file_upload,
+            inputs=[components["input"]["file_upload"]],
+            outputs=[components["input"]["text_input"] if disable_local_saving else components["input"]["file_select"]],
+        )
+
+    if components["output"]["play_btn"] is not None:
+        components["output"]["play_btn"].click(
+            fn=play_selected,
+            inputs=[components["output"]["output_files"]],
+            outputs=[components["output"]["selected_audio"]],
+        )
+
+    if components["input"]["clear_files"] is not None:
+        components["input"]["clear_files"].click(
+            fn=clear_files,
+            inputs=[
+                components["model"]["voice"],
+                components["model"]["format"],
+                components["model"]["speed"],
+            ],
+            outputs=[
+                components["input"]["file_select"],
+                components["input"]["file_upload"],
+                components["input"]["file_preview"],
+                components["output"]["audio_output"],
+                components["output"]["output_files"],
+                components["model"]["voice"],
+                components["model"]["format"],
+                components["model"]["speed"],
+            ],
+        )
+
+    if components["output"]["clear_outputs"] is not None:
+        components["output"]["clear_outputs"].click(
+            fn=clear_outputs,
+            outputs=[
+                components["output"]["audio_output"],
+                components["output"]["output_files"],
+                components["output"]["selected_audio"],
+            ],
+        )
+
+    if components["input"]["file_submit"] is not None:
+        components["input"]["file_submit"].click(
+            fn=generate_from_file,
+            inputs=[
+                components["input"]["file_select"],
+                components["model"]["voice"],
+                components["model"]["format"],
+                components["model"]["speed"],
+            ],
+            outputs=[
+                components["output"]["audio_output"],
+                components["output"]["output_files"],
+            ],
+        )
@@ -1,4 +1,5 @@
 import gradio as gr
+import os
 
 from . import api
 from .handlers import setup_event_handlers
@@ -10,6 +11,9 @@ def create_interface():
     # Skip initial status check - let the timer handle it
     is_available, available_voices = False, []
 
+    # Check if local saving is disabled
+    disable_local_saving = os.getenv("DISABLE_LOCAL_SAVING", "false").lower() == "true"
+
     with gr.Blocks(title="Kokoro TTS Demo", theme=gr.themes.Monochrome()) as demo:
         gr.HTML(
             value='<div style="display: flex; gap: 0;">'
@@ -22,11 +26,11 @@
         # Main interface
         with gr.Row():
             # Create columns
-            input_col, input_components = create_input_column()
+            input_col, input_components = create_input_column(disable_local_saving)
             model_col, model_components = create_model_column(
                 available_voices
             )  # Pass initial voices
-            output_col, output_components = create_output_column()
+            output_col, output_components = create_output_column(disable_local_saving)
 
         # Collect all components
         components = {
@@ -36,7 +40,7 @@
         }
 
         # Set up event handlers
-        setup_event_handlers(components)
+        setup_event_handlers(components, disable_local_saving)
 
         # Add periodic status check with Timer
         def update_status():
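Note that the flag parsing above treats only the literal string "true" (case-insensitively) as enabled. A small self-contained sketch of that convention:

# Only "true"/"True"/"TRUE" enable the flag; "1", "yes", etc. do not.
import os

def local_saving_disabled() -> bool:
    return os.getenv("DISABLE_LOCAL_SAVING", "false").lower() == "true"

os.environ["DISABLE_LOCAL_SAVING"] = "True"
assert local_saving_disabled()

os.environ["DISABLE_LOCAL_SAVING"] = "1"  # truthy-looking, but not "true"
assert not local_saving_disabled()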
@@ -106,24 +106,54 @@
 
 def test_text_to_speech_api_params(mock_response, tmp_path):
     """Test correct API parameters are sent"""
-    with patch("requests.post") as mock_post, patch(
-        "ui.lib.api.OUTPUTS_DIR", str(tmp_path)
-    ), patch("builtins.open", mock_open()):
-        mock_post.return_value = mock_response({})
-        api.text_to_speech("test text", "voice1", "mp3", 1.5)
+    test_cases = [
+        # Single voice as string
+        ("voice1", "voice1"),
+        # Multiple voices as list
+        (["voice1", "voice2"], "voice1+voice2"),
+        # Single voice as list
+        (["voice1"], "voice1"),
+    ]
 
-    mock_post.assert_called_once()
-    args, kwargs = mock_post.call_args
+    for input_voice, expected_voice in test_cases:
+        with patch("requests.post") as mock_post, patch(
+            "ui.lib.api.OUTPUTS_DIR", str(tmp_path)
+        ), patch("builtins.open", mock_open()):
+            mock_post.return_value = mock_response({})
+            api.text_to_speech("test text", input_voice, "mp3", 1.5)
 
-    # Check request body
-    assert kwargs["json"] == {
-        "model": "kokoro",
-        "input": "test text",
-        "voice": "voice1",
-        "response_format": "mp3",
-        "speed": 1.5,
-    }
+            mock_post.assert_called_once()
+            args, kwargs = mock_post.call_args
 
-    # Check headers and timeout
-    assert kwargs["headers"] == {"Content-Type": "application/json"}
-    assert kwargs["timeout"] == 300
+            # Check request body
+            assert kwargs["json"] == {
+                "model": "kokoro",
+                "input": "test text",
+                "voice": expected_voice,
+                "response_format": "mp3",
+                "speed": 1.5,
+            }
+
+            # Check headers and timeout
+            assert kwargs["headers"] == {"Content-Type": "application/json"}
+            assert kwargs["timeout"] == 300
 
 
 def test_text_to_speech_output_filename(mock_response, tmp_path):
+    """Test output filename contains correct voice identifier"""
+    test_cases = [
+        # Single voice
+        ("voice1", lambda f: "voice-voice1" in f),
+        # Multiple voices
+        (["voice1", "voice2"], lambda f: "voice-voice1+voice2" in f),
+    ]
+
+    for input_voice, filename_check in test_cases:
+        with patch("requests.post", return_value=mock_response({})), patch(
+            "ui.lib.api.OUTPUTS_DIR", str(tmp_path)
+        ), patch("builtins.open", mock_open()) as mock_file:
+            result = api.text_to_speech("test text", input_voice, "mp3", 1.0)
+
+            assert result is not None
+            assert filename_check(result), f"Expected voice pattern not found in filename: {result}"
+            mock_file.assert_called_once()
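The looped tests above could also be expressed with pytest.mark.parametrize, which reports each case as its own test item. A hedged sketch of that pattern; it exercises only the "+"-join convention, not the mocked HTTP call:

# Parametrized restatement of the voice-format cases above.
import pytest

@pytest.mark.parametrize(
    "input_voice,expected_voice",
    [
        ("voice1", "voice1"),
        (["voice1", "voice2"], "voice1+voice2"),
        (["voice1"], "voice1"),
    ],
)
def test_voice_join_convention(input_voice, expected_voice):
    voice_str = input_voice if isinstance(input_voice, str) else "+".join(input_voice)
    assert voice_str == expected_voice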
@@ -36,8 +36,10 @@ def test_model_column_default_values():
     expected_choices = [(voice_id, voice_id) for voice_id in voice_ids]
     assert components["voice"].choices == expected_choices
     # Value is not converted to tuple format for the value property
-    assert components["voice"].value == voice_ids[0]
+    assert components["voice"].value == [voice_ids[0]]
     assert components["voice"].interactive is True
+    assert components["voice"].multiselect is True
+    assert components["voice"].label == "Voice(s)"
 
     # Test format dropdown
     # Gradio Dropdown converts choices to (value, label) tuples
@@ -136,7 +136,7 @@ def test_interface_components_presence():
 
     required_components = {
         "Text to speak",
-        "Voice",
+        "Voice(s)",
         "Audio Format",
         "Speed",
         "Generated Speech",