Mirror of https://github.com/remsky/Kokoro-FastAPI.git (synced 2025-08-05 16:48:53 +00:00)

Merge pre-release: Update CI workflow for uv
Commit 5045cf968e
69 changed files with 9202 additions and 762 deletions
=== Coverage configuration (file name not captured) ===
@@ -7,6 +7,7 @@ omit =
     MagicMock/*
     test_*.py
     examples/*
+    src/builds/*

 [report]
 exclude_lines =
=== .github/workflows/ci.yml (vendored, 71 lines changed) ===
@@ -1,51 +1,32 @@
-# name: CI
-
-# on:
-#   push:
-#     branches: [ "develop", "master" ]
-#   pull_request:
-#     branches: [ "develop", "master" ]
-
-# jobs:
-#   test:
-#     runs-on: ubuntu-latest
-#     strategy:
-#       matrix:
-#         python-version: ["3.9", "3.10", "3.11"]
-#       fail-fast: false
-
-#     steps:
-#     - uses: actions/checkout@v4
-
-#     - name: Set up Python ${{ matrix.python-version }}
-#       uses: actions/setup-python@v5
-#       with:
-#         python-version: ${{ matrix.python-version }}
-
-#     - name: Set up pip cache
-#       uses: actions/cache@v3
-#       with:
-#         path: ~/.cache/pip
-#         key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements*.txt') }}
-#         restore-keys: |
-#           ${{ runner.os }}-pip-
-
-#     - name: Install PyTorch CPU
-#       run: |
-#         python -m pip install --upgrade pip
-#         pip install torch --index-url https://download.pytorch.org/whl/cpu
-
-#     - name: Install dependencies
-#       run: |
-#         pip install ruff pytest-cov
-#         pip install -r requirements.txt
-#         pip install -r requirements-test.txt
-
-#     - name: Lint with ruff
-#       run: |
-#         ruff check .
-
-#     - name: Test with pytest
-#       run: |
-#         pytest --asyncio-mode=auto --cov=api --cov-report=term-missing
+name: CI
+
+on:
+  push:
+    branches: [ "master", "pre-release" ]
+  pull_request:
+    branches: [ "master", "pre-release" ]
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.10"]
+      fail-fast: false
+
+    steps:
+    - uses: actions/checkout@v4
+
+    - name: Install uv
+      uses: astral-sh/setup-uv@v5
+      with:
+        python-version: ${{ matrix.python-version }}
+        enable-cache: true
+
+    - name: Install dependencies
+      run: |
+        uv pip install -e .[test,cpu]
+
+    - name: Run Tests
+      run: |
+        uv run pytest api/tests/ --asyncio-mode=auto --cov=api --cov-report=term-missing
=== .github/workflows/docker-publish.yml (vendored, 124 lines changed) ===
@@ -1,7 +1,9 @@
-name: Docker Build and Publish
+name: Docker Build, Slim, and Publish

 on:
   push:
+    branches:
+      - master
     tags: [ 'v*.*.*' ]
   # Allow manual trigger from GitHub UI
   workflow_dispatch:

@@ -16,6 +18,7 @@ jobs:
    permissions:
      contents: read
      packages: write
+     actions: write

    steps:
      - name: Checkout repository

@@ -28,67 +31,76 @@ jobs:
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}

-     # Extract metadata for GPU image
-     - name: Extract metadata (tags, labels) for GPU Docker
-       id: meta-gpu
-       uses: docker/metadata-action@v5
-       with:
-         images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
-         tags: |
-           type=semver,pattern=v{{version}}
-           type=semver,pattern=v{{major}}.{{minor}}
-           type=semver,pattern=v{{major}}
-           type=raw,value=latest
-
-     # Extract metadata for CPU image
-     - name: Extract metadata (tags, labels) for CPU Docker
-       id: meta-cpu
-       uses: docker/metadata-action@v5
-       with:
-         images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
-         flavor: |
-           suffix=-cpu
-         tags: |
-           type=semver,pattern=v{{version}}
-           type=semver,pattern=v{{major}}.{{minor}}
-           type=semver,pattern=v{{major}}
-           type=raw,value=latest
-
-     # Build and push GPU version
-     - name: Build and push GPU Docker image
+     # Set up image names (converting to lowercase)
+     - name: Set image names
+       run: |
+         echo "GPU_IMAGE_NAME=${{ env.REGISTRY }}/$(echo ${{ env.IMAGE_NAME }} | tr '[:upper:]' '[:lower:]')-gpu" >> $GITHUB_ENV
+         echo "CPU_IMAGE_NAME=${{ env.REGISTRY }}/$(echo ${{ env.IMAGE_NAME }} | tr '[:upper:]' '[:lower:]')-cpu" >> $GITHUB_ENV
+         echo "UI_IMAGE_NAME=${{ env.REGISTRY }}/$(echo ${{ env.IMAGE_NAME }} | tr '[:upper:]' '[:lower:]')-ui" >> $GITHUB_ENV
+
+     # Build GPU version
+     - name: Build GPU Docker image
       uses: docker/build-push-action@v5
       with:
         context: .
-        file: ./Dockerfile
-        push: true
-        tags: ${{ steps.meta-gpu.outputs.tags }}
-        labels: ${{ steps.meta-gpu.outputs.labels }}
+        file: ./docker/gpu/Dockerfile
+        push: false
+        load: true
+        tags: ${{ env.GPU_IMAGE_NAME }}:v0.1.0
+        build-args: |
+          DOCKER_BUILDKIT=1
        platforms: linux/amd64

-     # Build and push CPU version
-     - name: Build and push CPU Docker image
+     # Slim GPU version
+     - name: Slim GPU Docker image
+       uses: kitabisa/docker-slim-action@v1
+       env:
+         DSLIM_HTTP_PROBE: false
+       with:
+         target: ${{ env.GPU_IMAGE_NAME }}:v0.1.0
+         tag: v0.1.0-slim
+
+     # Push GPU versions
+     - name: Push GPU Docker images
+       run: |
+         docker push ${{ env.GPU_IMAGE_NAME }}:v0.1.0
+         docker push ${{ env.GPU_IMAGE_NAME }}:v0.1.0-slim
+         docker tag ${{ env.GPU_IMAGE_NAME }}:v0.1.0 ${{ env.GPU_IMAGE_NAME }}:latest
+         docker tag ${{ env.GPU_IMAGE_NAME }}:v0.1.0-slim ${{ env.GPU_IMAGE_NAME }}:latest-slim
+         docker push ${{ env.GPU_IMAGE_NAME }}:latest
+         docker push ${{ env.GPU_IMAGE_NAME }}:latest-slim
+
+     # Build CPU version
+     - name: Build CPU Docker image
       uses: docker/build-push-action@v5
       with:
         context: .
-        file: ./Dockerfile.cpu
-        push: true
-        tags: ${{ steps.meta-cpu.outputs.tags }}
-        labels: ${{ steps.meta-cpu.outputs.labels }}
+        file: ./docker/cpu/Dockerfile
+        push: false
+        load: true
+        tags: ${{ env.CPU_IMAGE_NAME }}:v0.1.0
+        build-args: |
+          DOCKER_BUILDKIT=1
        platforms: linux/amd64

-     # Extract metadata for UI image
-     - name: Extract metadata (tags, labels) for UI Docker
-       id: meta-ui
-       uses: docker/metadata-action@v5
-       with:
-         images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
-         flavor: |
-           suffix=-ui
-         tags: |
-           type=semver,pattern=v{{version}}
-           type=semver,pattern=v{{major}}.{{minor}}
-           type=semver,pattern=v{{major}}
-           type=raw,value=latest
+     # Slim CPU version
+     - name: Slim CPU Docker image
+       uses: kitabisa/docker-slim-action@v1
+       env:
+         DSLIM_HTTP_PROBE: false
+       with:
+         target: ${{ env.CPU_IMAGE_NAME }}:v0.1.0
+         tag: v0.1.0-slim
+
+     # Push CPU versions
+     - name: Push CPU Docker images
+       run: |
+         docker push ${{ env.CPU_IMAGE_NAME }}:v0.1.0
+         docker push ${{ env.CPU_IMAGE_NAME }}:v0.1.0-slim
+         docker tag ${{ env.CPU_IMAGE_NAME }}:v0.1.0 ${{ env.CPU_IMAGE_NAME }}:latest
+         docker tag ${{ env.CPU_IMAGE_NAME }}:v0.1.0-slim ${{ env.CPU_IMAGE_NAME }}:latest-slim
+         docker push ${{ env.CPU_IMAGE_NAME }}:latest
+         docker push ${{ env.CPU_IMAGE_NAME }}:latest-slim

      # Build and push UI version
      - name: Build and push UI Docker image

@@ -97,8 +109,11 @@ jobs:
         context: ./ui
         file: ./ui/Dockerfile
         push: true
-        tags: ${{ steps.meta-ui.outputs.tags }}
-        labels: ${{ steps.meta-ui.outputs.labels }}
+        tags: |
+          ${{ env.UI_IMAGE_NAME }}:v0.1.0
+          ${{ env.UI_IMAGE_NAME }}:latest
+        build-args: |
+          DOCKER_BUILDKIT=1
        platforms: linux/amd64

  create-release:

@@ -108,13 +123,16 @@ jobs:
    if: startsWith(github.ref, 'refs/tags/')
    permissions:
      contents: write
+     packages: write
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Create Release
        uses: softprops/action-gh-release@v1
+       env:
+         IS_PRERELEASE: ${{ contains(github.ref, '-pre') }}
        with:
          generate_release_notes: true
          draft: false
-         prerelease: false
+         prerelease: ${{ contains(github.ref, '-pre') }}
=== .gitignore (vendored, 71 lines changed) ===
@@ -2,51 +2,78 @@
 .git

 # Python
-__pycache__
+__pycache__/
 *.pyc
 *.pyo
 *.pyd
-*.pt
-.Python
 *.py[cod]
 *$py.class
+.Python
 .pytest_cache
 .coverage
 .coveragerc

+# Python package build artifacts
+*.egg-info/
+*.egg
+dist/
+build/
+
 # Environment
 # .env
-.venv
+.venv/
 env/
 venv/
 ENV/

 # IDE
-.idea
-.vscode
+.idea/
+.vscode/
 *.swp
 *.swo

 # Project specific
-*examples/*.wav
-*examples/*.pcm
-*examples/*.mp3
-*examples/*.flac
-*examples/*.acc
-*examples/*.ogg
+# Model files
+*.pt
+*.pth
+*.tar*

+# Voice files
+api/src/voices/af_bella.pt
+api/src/voices/af_nicole.pt
+api/src/voices/af_sarah.pt
+api/src/voices/af_sky.pt
+api/src/voices/af.pt
+api/src/voices/am_adam.pt
+api/src/voices/am_michael.pt
+api/src/voices/bf_emma.pt
+api/src/voices/bf_isabella.pt
+api/src/voices/bm_george.pt
+api/src/voices/bm_lewis.pt
+
+# Audio files
+examples/*.wav
+examples/*.pcm
+examples/*.mp3
+examples/*.flac
+examples/*.acc
+examples/*.ogg
+examples/speech.mp3
+examples/phoneme_examples/output/example_1.wav
+examples/phoneme_examples/output/example_2.wav
+examples/phoneme_examples/output/example_3.wav
+
+# Other project files
 Kokoro-82M/
-ui/data
-tests/
-*.md
-*.txt
-requirements.txt
+ui/data/
+EXTERNAL_UV_DOCUMENTATION*

 # Docker
 Dockerfile*
 docker-compose*
-*.egg-info
-*.pt
-*.wav
-*.tar*
+examples/assorted_checks/River_of_Teet_-_Sarah_Gailey.epub
+examples/ebook_test/chapter_to_audio.py
+examples/ebook_test/chapters_to_audio.py
+examples/ebook_test/parse_epub.py
+examples/ebook_test/River_of_Teet_-_Sarah_Gailey.epub
+examples/ebook_test/River_of_Teet_-_Sarah_Gailey.txt
=== .python-version (new file, 1 line) ===
3.10
=== Ruff configuration (file name not captured) ===
@@ -1,11 +1,12 @@
 line-length = 88

+exclude = ["examples"]
+
 [lint]
 select = ["I"]

 [lint.isort]
 combine-as-imports = true
 force-wrap-aliases = true
-length-sort = true
 split-on-trailing-comma = true
 section-order = ["future", "standard-library", "third-party", "first-party", "local-folder"]
=== CHANGELOG.md (11 lines changed) ===
@@ -2,6 +2,17 @@

 Notable changes to this project will be documented in this file.

+## [v0.1.0] - 2025-01-13
+### Changed
+- Major Docker improvements:
+  - Baked model directly into Dockerfile for improved deployment reliability
+  - Switched to uv for dependency management
+  - Streamlined container builds and reduced image sizes
+- Dependency Management:
+  - Migrated from pip/poetry to uv for faster, more reliable package management
+  - Added uv.lock for deterministic builds
+  - Updated dependency resolution strategy
+
 ## [v0.0.5post1] - 2025-01-11
 ### Fixed
 - Docker image tagging and versioning improvements (-gpu, -cpu, -ui)
=== Dockerfile (deleted, 44 lines) ===
FROM nvidia/cuda:12.1.0-base-ubuntu22.04

# Install base system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    python3-pip \
    python3-dev \
    espeak-ng \
    git \
    libsndfile1 \
    curl \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

# Install PyTorch with CUDA support first
RUN pip3 install --no-cache-dir torch==2.5.1 --extra-index-url https://download.pytorch.org/whl/cu121

# Install all other dependencies from requirements.txt
COPY requirements.txt .
RUN pip3 install --no-cache-dir -r requirements.txt

# Set working directory
WORKDIR /app

# Create non-root user
RUN useradd -m -u 1000 appuser

# Create model directory and set ownership
RUN mkdir -p /app/Kokoro-82M && \
    chown -R appuser:appuser /app

# Switch to non-root user
USER appuser

# Run with Python unbuffered output for live logging
ENV PYTHONUNBUFFERED=1

# Copy only necessary application code
COPY --chown=appuser:appuser api /app/api

# Set Python path (app first for our imports, then model dir for model imports)
ENV PYTHONPATH=/app:/app/Kokoro-82M

# Run FastAPI server with debug logging and reload
CMD ["uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]
=== Dockerfile.cpu (deleted, 44 lines; name inferred from the old workflow's `file: ./Dockerfile.cpu` reference) ===
FROM ubuntu:22.04

# Install base system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    python3-pip \
    python3-dev \
    espeak-ng \
    git \
    libsndfile1 \
    curl \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

# Install PyTorch CPU version and ONNX runtime
RUN pip3 install --no-cache-dir torch==2.5.1 --extra-index-url https://download.pytorch.org/whl/cpu

# Install all other dependencies from requirements.txt
COPY requirements.txt .
RUN pip3 install --no-cache-dir -r requirements.txt

# Copy application code and model
COPY . /app/

# Set working directory
WORKDIR /app

# Run with Python unbuffered output for live logging
ENV PYTHONUNBUFFERED=1

# Create non-root user
RUN useradd -m -u 1000 appuser

# Create directories and set permissions
RUN mkdir -p /app/Kokoro-82M && \
    chown -R appuser:appuser /app

# Switch to non-root user
USER appuser

# Set Python path (app first for our imports, then model dir for model imports)
ENV PYTHONPATH=/app:/app/Kokoro-82M

# Run FastAPI server with debug logging and reload
CMD ["uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]
=== MigrationWorkingNotes.md (new file, 70 lines) ===
# UV Setup
Deprecated notes for myself

## Structure
```
docker/
├── cpu/
│   ├── pyproject.toml     # CPU deps (torch CPU)
│   └── requirements.lock  # CPU lockfile
├── gpu/
│   ├── pyproject.toml     # GPU deps (torch CUDA)
│   └── requirements.lock  # GPU lockfile
└── shared/
    └── pyproject.toml     # Common deps
```

## Regenerate Lock Files

### CPU
```bash
cd docker/cpu
uv pip compile pyproject.toml ../shared/pyproject.toml --output-file requirements.lock
```

### GPU
```bash
cd docker/gpu
uv pip compile pyproject.toml ../shared/pyproject.toml --output-file requirements.lock
```

## Local Dev Setup

### CPU
```bash
cd docker/cpu
uv venv
.venv\Scripts\activate  # Windows
uv pip sync requirements.lock
```

### GPU
```bash
cd docker/gpu
uv venv
.venv\Scripts\activate  # Windows
uv pip sync requirements.lock --extra-index-url https://download.pytorch.org/whl/cu121 --index-strategy unsafe-best-match
```

### Run Server
```bash
# From project root with venv active:
uvicorn api.src.main:app --reload
```

## Docker

### CPU
```bash
cd docker/cpu
docker compose up
```

### GPU
```bash
cd docker/gpu
docker compose up
```

## Known Issues
- Module imports: Run server from project root
- PyTorch CUDA: Always use --extra-index-url and --index-strategy for GPU env
=== README.md (64 lines changed) ===
@@ -2,9 +2,9 @@
 <img src="githubbanner.png" alt="Kokoro TTS Banner">
 </p>

-# Kokoro TTS API
+# <sub><sub>_`FastKoko`_ </sub></sub>
 []()
 []()
 [](https://huggingface.co/hexgrad/Kokoro-82M/tree/c3b0d86e2a980e027ef71c28819ea02e351c2667) [](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero) [](https://www.buymeacoffee.com/remsky)

 Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model

@@ -24,14 +24,30 @@ Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokor
 The service can be accessed through either the API endpoints or the Gradio web interface.

 1. Install prerequisites:
-   - Install [Docker Desktop](https://www.docker.com/products/docker-desktop/) + [Git](https://git-scm.com/downloads)
-   - Clone and start the service:
+   - Install [Docker Desktop](https://www.docker.com/products/docker-desktop/)
+   - Clone the repository:
    ```bash
    git clone https://github.com/remsky/Kokoro-FastAPI.git
    cd Kokoro-FastAPI
-   docker compose up --build # for GPU
-   #docker compose -f docker-compose.cpu.yml up --build # for CPU
    ```
+
+2. Start the service:
+
+   - Using Docker Compose (Full setup including UI):
+   ```bash
+   cd docker/gpu # OR
+   # cd docker/cpu # Run this or the above
+   docker compose up --build
+   ```
+   - OR running the API alone using Docker (model + voice packs baked in):
+   ```bash
+   docker run -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-cpu:latest # CPU
+   docker run --gpus all -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-gpu:latest # Nvidia GPU
+   # Minified versions are available with `:latest-slim` tag.
+   ```
+
 2. Run locally as an OpenAI-Compatible Speech Endpoint
    ```python
    from openai import OpenAI

@@ -167,6 +183,21 @@ If you only want the API, just comment out everything in the docker-compose.yml
 Currently, voices created via the API are accessible here, but voice combination/creation has not yet been added

 *Note: Recent updates for streaming could lead to temporary glitches. If so, pull from the most recent stable release v0.0.2 to restore*
+
+### Disabling Local Saving
+
+You can disable local saving of audio files and hide the file view in the UI by setting the `DISABLE_LOCAL_SAVING` environment variable to `true`. This is useful when running the service on a server where you don't want to store generated audio files locally.
+
+When using Docker Compose:
+```yaml
+environment:
+  - DISABLE_LOCAL_SAVING=true
+```
+
+When running the Docker image directly:
+```bash
+docker run -p 7860:7860 -e DISABLE_LOCAL_SAVING=true ghcr.io/remsky/kokoro-fastapi-ui:latest
+```
 </details>

 <details>

@@ -320,6 +351,27 @@ See `examples/phoneme_examples/generate_phonemes.py` for a sample script.
 ## Known Issues

+<details>
+<summary>Versioning & Development</summary>
+
+I'm doing what I can to keep things stable, but we are on an early and rapid set of build cycles here.
+If you run into trouble, you may have to roll back a version on the release tags if something comes up, or build up from source and/or troubleshoot + submit a PR. Will leave the branch up here for the last known stable points:
+
+`v0.0.5post1`
+
+Free and open source is a community effort, and I love working on this project, though there's only really so many hours in a day. If you'd like to support the work, feel free to open a PR, buy me a coffee, or report any bugs/features/etc you find during use.
+
+<a href="https://www.buymeacoffee.com/remsky" target="_blank">
+  <img
+    src="https://cdn.buymeacoffee.com/buttons/v2/default-violet.png"
+    alt="Buy Me A Coffee"
+    style="height: 30px !important;width: 110px !important;"
+  >
+</a>
+
+</details>
+
 <details>
 <summary>Linux GPU Permissions</summary>
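The README's quick-start above ends right at `from openai import OpenAI`. As a minimal sketch of what a request against the local server could look like, assuming the standard `openai>=1.x` client: the `model` and `voice` values here are assumptions (the voice name is taken from the voice packs listed in `.gitignore`), while port 8880 matches the `docker run` lines above.

```python
from openai import OpenAI

# Point the client at the local Kokoro-FastAPI server (port 8880, as above).
client = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed")

resp = client.audio.speech.create(
    model="kokoro",    # assumed model identifier
    voice="af_bella",  # one of the bundled voice packs
    input="Hello from Kokoro FastAPI!",
)
with open("speech.mp3", "wb") as f:
    f.write(resp.content)  # raw audio bytes returned by the endpoint
```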
=== api/src/builds/__init__.py (new file, empty) ===

=== api/src/builds/config.json (new file, 26 lines) ===
{
    "decoder": {
        "type": "istftnet",
        "upsample_kernel_sizes": [20, 12],
        "upsample_rates": [10, 6],
        "gen_istft_hop_size": 5,
        "gen_istft_n_fft": 20,
        "resblock_dilation_sizes": [
            [1, 3, 5],
            [1, 3, 5],
            [1, 3, 5]
        ],
        "resblock_kernel_sizes": [3, 7, 11],
        "upsample_initial_channel": 512
    },
    "dim_in": 64,
    "dropout": 0.2,
    "hidden_dim": 512,
    "max_conv_dim": 512,
    "max_dur": 50,
    "multispeaker": true,
    "n_layer": 3,
    "n_mels": 80,
    "n_token": 178,
    "style_dim": 128
}
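The `decoder` keys above line up with the `Decoder` constructor in `api/src/builds/istftnet.py` below. A minimal loading sketch; the mapping of `hidden_dim`/`style_dim`/`n_mels` onto `dim_in`/`style_dim`/`dim_out` is an assumption following the upstream StyleTTS2 `models.py`, which this excerpt does not show.

```python
import json

from api.src.builds.istftnet import Decoder  # module path from this commit

with open("api/src/builds/config.json") as f:
    cfg = json.load(f)
dec = cfg["decoder"]  # istftnet-specific settings

decoder = Decoder(
    dim_in=cfg["hidden_dim"],   # assumed mapping (upstream StyleTTS2 convention)
    style_dim=cfg["style_dim"],
    dim_out=cfg["n_mels"],
    resblock_kernel_sizes=dec["resblock_kernel_sizes"],
    upsample_rates=dec["upsample_rates"],
    upsample_initial_channel=dec["upsample_initial_channel"],
    resblock_dilation_sizes=dec["resblock_dilation_sizes"],
    upsample_kernel_sizes=dec["upsample_kernel_sizes"],
    gen_istft_n_fft=dec["gen_istft_n_fft"],
    gen_istft_hop_size=dec["gen_istft_hop_size"],
)
```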
=== api/src/builds/istftnet.py (new file, 524 lines) ===
# https://github.com/yl4579/StyleTTS2/blob/main/Modules/istftnet.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.signal import get_window
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm, weight_norm


# https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py
def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)

def get_padding(kernel_size, dilation=1):
    return int((kernel_size*dilation - dilation)/2)

LRELU_SLOPE = 0.1

class AdaIN1d(nn.Module):
    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm1d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features*2)

    def forward(self, x, s):
        h = self.fc(s)
        h = h.view(h.size(0), h.size(1), 1)
        gamma, beta = torch.chunk(h, chunks=2, dim=1)
        return (1 + gamma) * self.norm(x) + beta

class AdaINResBlock1(torch.nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
        super(AdaINResBlock1, self).__init__()
        self.convs1 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                               padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                               padding=get_padding(kernel_size, dilation[1]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                               padding=get_padding(kernel_size, dilation[2])))
        ])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
                               padding=get_padding(kernel_size, 1)))
        ])
        self.convs2.apply(init_weights)

        self.adain1 = nn.ModuleList([
            AdaIN1d(style_dim, channels),
            AdaIN1d(style_dim, channels),
            AdaIN1d(style_dim, channels),
        ])

        self.adain2 = nn.ModuleList([
            AdaIN1d(style_dim, channels),
            AdaIN1d(style_dim, channels),
            AdaIN1d(style_dim, channels),
        ])

        self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
        self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])

    def forward(self, x, s):
        for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
            xt = n1(x, s)
            xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2)  # Snake1D
            xt = c1(xt)
            xt = n2(xt, s)
            xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2)  # Snake1D
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)

class TorchSTFT(torch.nn.Module):
    def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
        super().__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = torch.from_numpy(get_window(window, win_length, fftbins=True).astype(np.float32))

    def transform(self, input_data):
        forward_transform = torch.stft(
            input_data,
            self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device),
            return_complex=True)

        return torch.abs(forward_transform), torch.angle(forward_transform)

    def inverse(self, magnitude, phase):
        inverse_transform = torch.istft(
            magnitude * torch.exp(phase * 1j),
            self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device))

        return inverse_transform.unsqueeze(-2)  # unsqueeze to stay consistent with conv_transpose1d implementation

    def forward(self, input_data):
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction

class SineGen(torch.nn.Module):
    """ Definition of sine generator
    SineGen(samp_rate, harmonic_num = 0,
            sine_amp = 0.1, noise_std = 0.003,
            voiced_threshold = 0,
            flag_for_pulse=False)
    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine-waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    flag_for_pulse: this SineGen is used inside PulseGen (default False)
    Note: when flag_for_pulse is True, the first time step of a voiced
        segment is always sin(np.pi) or cos(0)
    """

    def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0,
                 flag_for_pulse=False):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold
        self.flag_for_pulse = flag_for_pulse
        self.upsample_scale = upsample_scale

    def _f02uv(self, f0):
        # generate uv signal
        uv = (f0 > self.voiced_threshold).type(torch.float32)
        return uv

    def _f02sine(self, f0_values):
        """ f0_values: (batchsize, length, dim)
            where dim indicates fundamental tone and overtones
        """
        # convert to F0 in rad. The integer part n can be ignored
        # because 2 * np.pi * n doesn't affect phase
        rad_values = (f0_values / self.sampling_rate) % 1

        # initial phase noise (no noise for fundamental component)
        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \
                              device=f0_values.device)
        rand_ini[:, 0] = 0
        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini

        # instantaneous phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
        if not self.flag_for_pulse:
            # # for normal case

            # # To prevent torch.cumsum numerical overflow,
            # # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
            # # Buffer tmp_over_one_idx indicates the time step to add -1.
            # # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
            # tmp_over_one = torch.cumsum(rad_values, 1) % 1
            # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
            # cumsum_shift = torch.zeros_like(rad_values)
            # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0

            # phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
            rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
                                                         scale_factor=1/self.upsample_scale,
                                                         mode="linear").transpose(1, 2)

            # tmp_over_one = torch.cumsum(rad_values, 1) % 1
            # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
            # cumsum_shift = torch.zeros_like(rad_values)
            # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0

            phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
            phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
                                                    scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
            sines = torch.sin(phase)

        else:
            # If necessary, make sure that the first time step of every
            # voiced segments is sin(pi) or cos(0)
            # This is used for pulse-train generation

            # identify the last time step in unvoiced segments
            uv = self._f02uv(f0_values)
            uv_1 = torch.roll(uv, shifts=-1, dims=1)
            uv_1[:, -1, :] = 1
            u_loc = (uv < 1) * (uv_1 > 0)

            # get the instantaneous phase
            tmp_cumsum = torch.cumsum(rad_values, dim=1)
            # different batch needs to be processed differently
            for idx in range(f0_values.shape[0]):
                temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
                temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
                # stores the accumulation of i.phase within
                # each voiced segments
                tmp_cumsum[idx, :, :] = 0
                tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum

            # rad_values - tmp_cumsum: remove the accumulation of i.phase
            # within the previous voiced segment.
            i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)

            # get the sines
            sines = torch.cos(i_phase * 2 * np.pi)
        return sines

    def forward(self, f0):
        """ sine_tensor, uv = forward(f0)
        input F0: tensor(batchsize=1, length, dim=1)
            f0 for unvoiced steps should be 0
        output sine_tensor: tensor(batchsize=1, length, dim)
        output uv: tensor(batchsize=1, length, 1)
        """
        f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
                             device=f0.device)
        # fundamental component
        fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))

        # generate sine waveforms
        sine_waves = self._f02sine(fn) * self.sine_amp

        # generate uv signal
        # uv = torch.ones(f0.shape)
        # uv = uv * (f0 > self.voiced_threshold)
        uv = self._f02uv(f0)

        # noise: for unvoiced should be similar to sine_amp
        #        std = self.sine_amp/3 -> max value ~ self.sine_amp
        #        for voiced regions is self.noise_std
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sine_waves)

        # first: set the unvoiced part to 0 by uv
        # then: additive noise
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise


class SourceModuleHnNSF(torch.nn.Module):
    """ SourceModule for hn-nsf
    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0)
    sampling_rate: sampling_rate in Hz
    harmonic_num: number of harmonic above F0 (default: 0)
    sine_amp: amplitude of sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003)
        note that amplitude of noise in unvoiced is decided
        by sine_amp
    voiced_threshold: threshold to set U/V given F0 (default: 0)
    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
    F0_sampled (batchsize, length, 1)
    Sine_source (batchsize, length, 1)
    noise_source (batchsize, length 1)
    uv (batchsize, length, 1)
    """

    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0):
        super(SourceModuleHnNSF, self).__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        # to produce sine waveforms
        self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
                                 sine_amp, add_noise_std, voiced_threshod)

        # to merge source harmonics into a single excitation
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x):
        """
        Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
        F0_sampled (batchsize, length, 1)
        Sine_source (batchsize, length, 1)
        noise_source (batchsize, length 1)
        """
        # source for harmonic branch
        with torch.no_grad():
            sine_wavs, uv, _ = self.l_sin_gen(x)
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))

        # source for noise branch, in the same shape as uv
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return sine_merge, noise, uv

def padDiff(x):
    return F.pad(F.pad(x, (0,0,-1,1), 'constant', 0) - x, (0,0,0,-1), 'constant', 0)


class Generator(torch.nn.Module):
    def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size):
        super(Generator, self).__init__()

        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        resblock = AdaINResBlock1

        self.m_source = SourceModuleHnNSF(
            sampling_rate=24000,
            upsample_scale=np.prod(upsample_rates) * gen_istft_hop_size,
            harmonic_num=8, voiced_threshod=10)
        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * gen_istft_hop_size)
        self.noise_convs = nn.ModuleList()
        self.noise_res = nn.ModuleList()

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(weight_norm(
                ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
                                k, u, padding=(k-u)//2)))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel//(2**(i+1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(resblock(ch, k, d, style_dim))

            c_cur = upsample_initial_channel // (2 ** (i + 1))

            if i + 1 < len(upsample_rates):  #
                stride_f0 = np.prod(upsample_rates[i + 1:])
                self.noise_convs.append(Conv1d(
                    gen_istft_n_fft + 2, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
                self.noise_res.append(resblock(c_cur, 7, [1,3,5], style_dim))
            else:
                self.noise_convs.append(Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1))
                self.noise_res.append(resblock(c_cur, 11, [1,3,5], style_dim))

        self.post_n_fft = gen_istft_n_fft
        self.conv_post = weight_norm(Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
        self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
        self.stft = TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)

    def forward(self, x, s, f0):
        with torch.no_grad():
            f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t

            har_source, noi_source, uv = self.m_source(f0)
            har_source = har_source.transpose(1, 2).squeeze(1)
            har_spec, har_phase = self.stft.transform(har_source)
            har = torch.cat([har_spec, har_phase], dim=1)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x_source = self.noise_convs[i](har)
            x_source = self.noise_res[i](x_source, s)

            x = self.ups[i](x)
            if i == self.num_upsamples - 1:
                x = self.reflection_pad(x)

            x = x + x_source
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i*self.num_kernels+j](x, s)
                else:
                    xs += self.resblocks[i*self.num_kernels+j](x, s)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
        phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
        return self.stft.inverse(spec, phase)

    def fw_phase(self, x, s):
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i*self.num_kernels+j](x, s)
                else:
                    xs += self.resblocks[i*self.num_kernels+j](x, s)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.reflection_pad(x)
        x = self.conv_post(x)
        spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
        phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
        return spec, phase

    def remove_weight_norm(self):
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        # NOTE: self.conv_pre is never defined on this class, so the next call
        # would raise AttributeError (latent bug carried over from upstream).
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)


class AdainResBlk1d(nn.Module):
    def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
                 upsample='none', dropout_p=0.0):
        super().__init__()
        self.actv = actv
        self.upsample_type = upsample
        self.upsample = UpSample1d(upsample)
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out, style_dim)
        self.dropout = nn.Dropout(dropout_p)

        if upsample == 'none':
            self.pool = nn.Identity()
        else:
            self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))

    def _build_weights(self, dim_in, dim_out, style_dim):
        self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
        self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
        self.norm1 = AdaIN1d(style_dim, dim_in)
        self.norm2 = AdaIN1d(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))

    def _shortcut(self, x):
        x = self.upsample(x)
        if self.learned_sc:
            x = self.conv1x1(x)
        return x

    def _residual(self, x, s):
        x = self.norm1(x, s)
        x = self.actv(x)
        x = self.pool(x)
        x = self.conv1(self.dropout(x))
        x = self.norm2(x, s)
        x = self.actv(x)
        x = self.conv2(self.dropout(x))
        return x

    def forward(self, x, s):
        out = self._residual(x, s)
        out = (out + self._shortcut(x)) / np.sqrt(2)
        return out

class UpSample1d(nn.Module):
    def __init__(self, layer_type):
        super().__init__()
        self.layer_type = layer_type

    def forward(self, x):
        if self.layer_type == 'none':
            return x
        else:
            return F.interpolate(x, scale_factor=2, mode='nearest')

class Decoder(nn.Module):
    def __init__(self, dim_in=512, F0_channel=512, style_dim=64, dim_out=80,
                 resblock_kernel_sizes = [3,7,11],
                 upsample_rates = [10, 6],
                 upsample_initial_channel=512,
                 resblock_dilation_sizes=[[1,3,5], [1,3,5], [1,3,5]],
                 upsample_kernel_sizes=[20, 12],
                 gen_istft_n_fft=20, gen_istft_hop_size=5):
        super().__init__()

        self.decode = nn.ModuleList()

        self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)

        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))

        self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))

        self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))

        self.asr_res = nn.Sequential(
            weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
        )

        self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates,
                                   upsample_initial_channel, resblock_dilation_sizes,
                                   upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size)

    def forward(self, asr, F0_curve, N, s):
        F0 = self.F0_conv(F0_curve.unsqueeze(1))
        N = self.N_conv(N.unsqueeze(1))

        x = torch.cat([asr, F0, N], axis=1)
        x = self.encode(x, s)

        asr_res = self.asr_res(asr)

        res = True
        for block in self.decode:
            if res:
                x = torch.cat([x, asr_res, F0, N], axis=1)
            x = block(x, s)
            if block.upsample_type != "none":
                res = False

        x = self.generator(x, s, F0_curve)
        return x
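A shape-level sanity check for two of the blocks defined above, runnable with random weights (a sketch, assuming this commit's module layout): `AdaIN1d` turns a style vector into per-channel gain and bias around an instance norm, and `TorchSTFT` performs the magnitude/phase analysis plus iSTFT resynthesis that closes the vocoder.

```python
import torch

from api.src.builds.istftnet import AdaIN1d, TorchSTFT

# AdaIN1d: style vector s modulates per-channel statistics of x.
adain = AdaIN1d(style_dim=128, num_features=64)
x = torch.randn(2, 64, 100)   # (batch, channels, frames)
s = torch.randn(2, 128)       # (batch, style_dim)
assert adain(x, s).shape == x.shape

# TorchSTFT round trip with the istftnet sizes from config.json.
stft = TorchSTFT(filter_length=20, hop_length=5, win_length=20)
audio = torch.randn(1, 2400)
mag, phase = stft.transform(audio)  # magnitude and phase spectrograms
wav = stft.inverse(mag, phase)      # (1, 1, ~2400), unsqueezed for conv layouts
```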
=== api/src/builds/kokoro.py (new file, 151 lines) ===
import re

import phonemizer
import torch


def split_num(num):
    num = num.group()
    if '.' in num:
        return num
    elif ':' in num:
        h, m = [int(n) for n in num.split(':')]
        if m == 0:
            return f"{h} o'clock"
        elif m < 10:
            return f'{h} oh {m}'
        return f'{h} {m}'
    year = int(num[:4])
    if year < 1100 or year % 1000 < 10:
        return num
    left, right = num[:2], int(num[2:4])
    s = 's' if num.endswith('s') else ''
    if 100 <= year % 1000 <= 999:
        if right == 0:
            return f'{left} hundred{s}'
        elif right < 10:
            return f'{left} oh {right}{s}'
    return f'{left} {right}{s}'

def flip_money(m):
    m = m.group()
    bill = 'dollar' if m[0] == '$' else 'pound'
    if m[-1].isalpha():
        return f'{m[1:]} {bill}s'
    elif '.' not in m:
        s = '' if m[1:] == '1' else 's'
        return f'{m[1:]} {bill}{s}'
    b, c = m[1:].split('.')
    s = '' if b == '1' else 's'
    c = int(c.ljust(2, '0'))
    coins = f"cent{'' if c == 1 else 's'}" if m[0] == '$' else ('penny' if c == 1 else 'pence')
    return f'{b} {bill}{s} and {c} {coins}'

def point_num(num):
    a, b = num.group().split('.')
    return ' point '.join([a, ' '.join(b)])

def normalize_text(text):
    text = text.replace(chr(8216), "'").replace(chr(8217), "'")
    text = text.replace('«', chr(8220)).replace('»', chr(8221))
    text = text.replace(chr(8220), '"').replace(chr(8221), '"')
    text = text.replace('(', '«').replace(')', '»')
    for a, b in zip('、。！，：；？', ',.!,:;?'):
        text = text.replace(a, b+' ')
    text = re.sub(r'[^\S \n]', ' ', text)
    text = re.sub(r'  +', ' ', text)
    text = re.sub(r'(?<=\n) +(?=\n)', '', text)
    text = re.sub(r'\bD[Rr]\.(?= [A-Z])', 'Doctor', text)
    text = re.sub(r'\b(?:Mr\.|MR\.(?= [A-Z]))', 'Mister', text)
    text = re.sub(r'\b(?:Ms\.|MS\.(?= [A-Z]))', 'Miss', text)
    text = re.sub(r'\b(?:Mrs\.|MRS\.(?= [A-Z]))', 'Mrs', text)
    text = re.sub(r'\betc\.(?! [A-Z])', 'etc', text)
    text = re.sub(r'(?i)\b(y)eah?\b', r"\1e'a", text)
    text = re.sub(r'\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)', split_num, text)
    text = re.sub(r'(?<=\d),(?=\d)', '', text)
    text = re.sub(r'(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b', flip_money, text)
    text = re.sub(r'\d*\.\d+', point_num, text)
    text = re.sub(r'(?<=\d)-(?=\d)', ' to ', text)
    text = re.sub(r'(?<=\d)S', ' S', text)
    text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
    text = re.sub(r"(?<=X')S\b", 's', text)
    text = re.sub(r'(?:[A-Za-z]\.){2,} [a-z]', lambda m: m.group().replace('.', '-'), text)
    text = re.sub(r'(?i)(?<=[A-Z])\.(?=[A-Z])', '-', text)
    return text.strip()

def get_vocab():
    _pad = "$"
    _punctuation = ';:,.!?¡¿—…"«»“” '
    _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
    _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
    symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
    dicts = {}
    for i in range(len(symbols)):
        dicts[symbols[i]] = i
    return dicts

VOCAB = get_vocab()
def tokenize(ps):
    return [i for i in map(VOCAB.get, ps) if i is not None]

phonemizers = dict(
    a=phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True),
    b=phonemizer.backend.EspeakBackend(language='en-gb', preserve_punctuation=True, with_stress=True),
)
def phonemize(text, lang, norm=True):
    if norm:
        text = normalize_text(text)
    ps = phonemizers[lang].phonemize([text])
    ps = ps[0] if ps else ''
    # https://en.wiktionary.org/wiki/kokoro#English
    ps = ps.replace('kəkˈoːɹoʊ', 'kˈoʊkəɹoʊ').replace('kəkˈɔːɹəʊ', 'kˈəʊkəɹəʊ')
    ps = ps.replace('ʲ', 'j').replace('r', 'ɹ').replace('x', 'k').replace('ɬ', 'l')
    ps = re.sub(r'(?<=[a-zɹː])(?=hˈʌndɹɪd)', ' ', ps)
    ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»“” ]|$)', 'z', ps)
    if lang == 'a':
        ps = re.sub(r'(?<=nˈaɪn)ti(?!ː)', 'di', ps)
    ps = ''.join(filter(lambda p: p in VOCAB, ps))
    return ps.strip()

def length_to_mask(lengths):
    mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
    mask = torch.gt(mask+1, lengths.unsqueeze(1))
    return mask

@torch.no_grad()
def forward(model, tokens, ref_s, speed):
    device = ref_s.device
    tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
    input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
    text_mask = length_to_mask(input_lengths).to(device)
    bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
    d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
    s = ref_s[:, 128:]
    d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
    x, _ = model.predictor.lstm(d)
    duration = model.predictor.duration_proj(x)
    duration = torch.sigmoid(duration).sum(axis=-1) / speed
    pred_dur = torch.round(duration).clamp(min=1).long()
    pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
    c_frame = 0
    for i in range(pred_aln_trg.size(0)):
        pred_aln_trg[i, c_frame:c_frame + pred_dur[0,i].item()] = 1
        c_frame += pred_dur[0,i].item()
    en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
    F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
    t_en = model.text_encoder(tokens, input_lengths, text_mask)
    asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
    return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()

def generate(model, text, voicepack, lang='a', speed=1, ps=None):
    ps = ps or phonemize(text, lang)
    tokens = tokenize(ps)
    if not tokens:
        return None
    elif len(tokens) > 510:
        tokens = tokens[:510]
        print('Truncated to 510 tokens')
    ref_s = voicepack[len(tokens)]
    out = forward(model, tokens, ref_s, speed)
    ps = ''.join(next(k for k, v in VOCAB.items() if i == v) for i in tokens)
    return out, ps
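A hedged usage sketch for the pipeline above. Importing the module instantiates the two espeak phonemizer backends, so espeak-ng must be installed; full synthesis additionally needs a loaded model and a voicepack tensor indexed by token count, which this excerpt does not construct.

```python
from api.src.builds.kokoro import normalize_text, phonemize, tokenize

text = "Dr. Smith paid $5.50 at 10:30."
print(normalize_text(text))  # titles, money, and times expanded into speakable words

ps = phonemize(text, lang='a')  # 'a' = en-us backend, 'b' = en-gb
tokens = tokenize(ps)           # phoneme string -> vocab ids; unknown symbols dropped
print(ps, tokens[:10])

# With a model and voicepack in hand (not shown here), synthesis would be:
#   audio, ps = generate(model, text, voicepack, lang='a', speed=1.0)
```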
375
api/src/builds/models.py
Normal file
375
api/src/builds/models.py
Normal file
|
@ -0,0 +1,375 @@
|
||||||
|
# https://github.com/yl4579/StyleTTS2/blob/main/models.py
import json
import os
import os.path as osp
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from munch import Munch
from torch.nn.utils import spectral_norm, weight_norm

from .istftnet import AdaIN1d, Decoder
from .plbert import load_plbert


class LinearNorm(torch.nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super(LinearNorm, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, x):
        return self.linear_layer(x)


class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        x = x.transpose(1, -1)
        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)


class TextEncoder(nn.Module):
    def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
        super().__init__()
        self.embedding = nn.Embedding(n_symbols, channels)

        padding = (kernel_size - 1) // 2
        self.cnn = nn.ModuleList()
        for _ in range(depth):
            self.cnn.append(nn.Sequential(
                weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
                LayerNorm(channels),
                actv,
                nn.Dropout(0.2),
            ))
        # self.cnn = nn.Sequential(*self.cnn)

        self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True)

    def forward(self, x, input_lengths, m):
        x = self.embedding(x)  # [B, T, emb]
        x = x.transpose(1, 2)  # [B, emb, T]
        m = m.to(input_lengths.device).unsqueeze(1)
        x.masked_fill_(m, 0.0)

        for c in self.cnn:
            x = c(x)
            x.masked_fill_(m, 0.0)

        x = x.transpose(1, 2)  # [B, T, chn]

        input_lengths = input_lengths.cpu().numpy()
        x = nn.utils.rnn.pack_padded_sequence(
            x, input_lengths, batch_first=True, enforce_sorted=False)

        self.lstm.flatten_parameters()
        x, _ = self.lstm(x)
        x, _ = nn.utils.rnn.pad_packed_sequence(
            x, batch_first=True)

        x = x.transpose(-1, -2)
        x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])

        x_pad[:, :, :x.shape[-1]] = x
        x = x_pad.to(x.device)

        x.masked_fill_(m, 0.0)

        return x

    def inference(self, x):
        x = self.embedding(x)
        x = x.transpose(1, 2)
        x = self.cnn(x)
        x = x.transpose(1, 2)
        self.lstm.flatten_parameters()
        x, _ = self.lstm(x)
        return x

    def length_to_mask(self, lengths):
        mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
        mask = torch.gt(mask+1, lengths.unsqueeze(1))
        return mask


class UpSample1d(nn.Module):
    def __init__(self, layer_type):
        super().__init__()
        self.layer_type = layer_type

    def forward(self, x):
        if self.layer_type == 'none':
            return x
        else:
            return F.interpolate(x, scale_factor=2, mode='nearest')


class AdainResBlk1d(nn.Module):
    def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
                 upsample='none', dropout_p=0.0):
        super().__init__()
        self.actv = actv
        self.upsample_type = upsample
        self.upsample = UpSample1d(upsample)
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out, style_dim)
        self.dropout = nn.Dropout(dropout_p)

        if upsample == 'none':
            self.pool = nn.Identity()
        else:
            self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))

    def _build_weights(self, dim_in, dim_out, style_dim):
        self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
        self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
        self.norm1 = AdaIN1d(style_dim, dim_in)
        self.norm2 = AdaIN1d(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))

    def _shortcut(self, x):
        x = self.upsample(x)
        if self.learned_sc:
            x = self.conv1x1(x)
        return x

    def _residual(self, x, s):
        x = self.norm1(x, s)
        x = self.actv(x)
        x = self.pool(x)
        x = self.conv1(self.dropout(x))
        x = self.norm2(x, s)
        x = self.actv(x)
        x = self.conv2(self.dropout(x))
        return x

    def forward(self, x, s):
        out = self._residual(x, s)
        out = (out + self._shortcut(x)) / np.sqrt(2)
        return out


class AdaLayerNorm(nn.Module):
    def __init__(self, style_dim, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.fc = nn.Linear(style_dim, channels*2)

    def forward(self, x, s):
        x = x.transpose(-1, -2)
        x = x.transpose(1, -1)

        h = self.fc(s)
        h = h.view(h.size(0), h.size(1), 1)
        gamma, beta = torch.chunk(h, chunks=2, dim=1)
        gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)

        x = F.layer_norm(x, (self.channels,), eps=self.eps)
        x = (1 + gamma) * x + beta
        return x.transpose(1, -1).transpose(-1, -2)


class ProsodyPredictor(nn.Module):

    def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
        super().__init__()

        self.text_encoder = DurationEncoder(sty_dim=style_dim,
                                            d_model=d_hid,
                                            nlayers=nlayers,
                                            dropout=dropout)

        self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
        self.duration_proj = LinearNorm(d_hid, max_dur)

        self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
        self.F0 = nn.ModuleList()
        self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
        self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
        self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))

        self.N = nn.ModuleList()
        self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
        self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
        self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))

        self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
        self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)

    def forward(self, texts, style, text_lengths, alignment, m):
        d = self.text_encoder(texts, style, text_lengths, m)

        batch_size = d.shape[0]
        text_size = d.shape[1]

        # predict duration
        input_lengths = text_lengths.cpu().numpy()
        x = nn.utils.rnn.pack_padded_sequence(
            d, input_lengths, batch_first=True, enforce_sorted=False)

        m = m.to(text_lengths.device).unsqueeze(1)

        self.lstm.flatten_parameters()
        x, _ = self.lstm(x)
        x, _ = nn.utils.rnn.pad_packed_sequence(
            x, batch_first=True)

        x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])

        x_pad[:, :x.shape[1], :] = x
        x = x_pad.to(x.device)

        duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))

        en = (d.transpose(-1, -2) @ alignment)

        return duration.squeeze(-1), en

    def F0Ntrain(self, x, s):
        x, _ = self.shared(x.transpose(-1, -2))

        F0 = x.transpose(-1, -2)
        for block in self.F0:
            F0 = block(F0, s)
        F0 = self.F0_proj(F0)

        N = x.transpose(-1, -2)
        for block in self.N:
            N = block(N, s)
        N = self.N_proj(N)

        return F0.squeeze(1), N.squeeze(1)

    def length_to_mask(self, lengths):
        mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
        mask = torch.gt(mask+1, lengths.unsqueeze(1))
        return mask


class DurationEncoder(nn.Module):

    def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
        super().__init__()
        self.lstms = nn.ModuleList()
        for _ in range(nlayers):
            self.lstms.append(nn.LSTM(d_model + sty_dim,
                                      d_model // 2,
                                      num_layers=1,
                                      batch_first=True,
                                      bidirectional=True,
                                      dropout=dropout))
            self.lstms.append(AdaLayerNorm(sty_dim, d_model))

        self.dropout = dropout
        self.d_model = d_model
        self.sty_dim = sty_dim

    def forward(self, x, style, text_lengths, m):
        masks = m.to(text_lengths.device)

        x = x.permute(2, 0, 1)
        s = style.expand(x.shape[0], x.shape[1], -1)
        x = torch.cat([x, s], axis=-1)
        x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)

        x = x.transpose(0, 1)
        input_lengths = text_lengths.cpu().numpy()
        x = x.transpose(-1, -2)

        for block in self.lstms:
            if isinstance(block, AdaLayerNorm):
                x = block(x.transpose(-1, -2), style).transpose(-1, -2)
                x = torch.cat([x, s.permute(1, -1, 0)], axis=1)
                x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
            else:
                x = x.transpose(-1, -2)
                x = nn.utils.rnn.pack_padded_sequence(
                    x, input_lengths, batch_first=True, enforce_sorted=False)
                block.flatten_parameters()
                x, _ = block(x)
                x, _ = nn.utils.rnn.pad_packed_sequence(
                    x, batch_first=True)
                x = F.dropout(x, p=self.dropout, training=self.training)
                x = x.transpose(-1, -2)

                x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])

                x_pad[:, :, :x.shape[-1]] = x
                x = x_pad.to(x.device)

        return x.transpose(-1, -2)

    def inference(self, x, style):
        x = self.embedding(x.transpose(-1, -2)) * np.sqrt(self.d_model)
        style = style.expand(x.shape[0], x.shape[1], -1)
        x = torch.cat([x, style], axis=-1)
        src = self.pos_encoder(x)
        output = self.transformer_encoder(src).transpose(0, 1)
        return output

    def length_to_mask(self, lengths):
        mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
        mask = torch.gt(mask+1, lengths.unsqueeze(1))
        return mask


# https://github.com/yl4579/StyleTTS2/blob/main/utils.py
def recursive_munch(d):
    if isinstance(d, dict):
        return Munch((k, recursive_munch(v)) for k, v in d.items())
    elif isinstance(d, list):
        return [recursive_munch(v) for v in d]
    else:
        return d
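
# Example: recursive_munch turns nested dicts into attribute-accessible
# Munch objects, so config.json entries read naturally:
#     args = recursive_munch({'decoder': {'type': 'istftnet'}})
#     assert args.decoder.type == 'istftnet'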

def build_model(path, device):
    config = Path(__file__).parent / 'config.json'
    assert config.exists(), f'Config path incorrect: config.json not found at {config}'
    with open(config, 'r') as r:
        args = recursive_munch(json.load(r))
    assert args.decoder.type == 'istftnet', f'Unknown decoder type: {args.decoder.type}'
    decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
                      resblock_kernel_sizes=args.decoder.resblock_kernel_sizes,
                      upsample_rates=args.decoder.upsample_rates,
                      upsample_initial_channel=args.decoder.upsample_initial_channel,
                      resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
                      upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
                      gen_istft_n_fft=args.decoder.gen_istft_n_fft, gen_istft_hop_size=args.decoder.gen_istft_hop_size)
    text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
    predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
    bert = load_plbert()
    bert_encoder = nn.Linear(bert.config.hidden_size, args.hidden_dim)
    for parent in [bert, bert_encoder, predictor, decoder, text_encoder]:
        for child in parent.children():
            if isinstance(child, nn.RNNBase):
                child.flatten_parameters()
    model = Munch(
        bert=bert.to(device).eval(),
        bert_encoder=bert_encoder.to(device).eval(),
        predictor=predictor.to(device).eval(),
        decoder=decoder.to(device).eval(),
        text_encoder=text_encoder.to(device).eval(),
    )
    for key, state_dict in torch.load(path, map_location='cpu', weights_only=True)['net'].items():
        assert key in model, key
        try:
            model[key].load_state_dict(state_dict)
        except:
            state_dict = {k[7:]: v for k, v in state_dict.items()}
            model[key].load_state_dict(state_dict, strict=False)
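        # The retry above strips the 'module.' prefix (k[7:]) that
        # nn.DataParallel checkpoints prepend to every key, then loads with
        # strict=False so keys absent from this inference build are skipped.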
    return model

16  api/src/builds/plbert.py  Normal file
@@ -0,0 +1,16 @@
# https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py
from transformers import AlbertConfig, AlbertModel


class CustomAlbert(AlbertModel):
    def forward(self, *args, **kwargs):
        # Call the original forward method
        outputs = super().forward(*args, **kwargs)
        # Only return the last_hidden_state
        return outputs.last_hidden_state


def load_plbert():
    plbert_config = {'vocab_size': 178, 'hidden_size': 768, 'num_attention_heads': 12, 'intermediate_size': 2048, 'max_position_embeddings': 512, 'num_hidden_layers': 12, 'dropout': 0.1}
    albert_base_configuration = AlbertConfig(**plbert_config)
    bert = CustomAlbert(albert_base_configuration)
    return bert
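
# Shape sketch (illustrative; assumes torch is imported): CustomAlbert hands
# back the last hidden state directly as a [batch, seq, hidden] tensor:
#     bert = load_plbert()
#     ids = torch.zeros(1, 16, dtype=torch.long)
#     hidden = bert(ids, attention_mask=torch.ones(1, 16, dtype=torch.long))
#     # hidden.shape == (1, 16, 768), matching hidden_size above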

@@ -13,7 +13,7 @@ class Settings(BaseSettings):
     output_dir: str = "output"
     output_dir_size_limit_mb: float = 500.0  # Maximum size of output directory in MB
     default_voice: str = "af"
-    model_dir: str = "/app/Kokoro-82M"  # Base directory for model files
+    model_dir: str = "/app/api/model_files"  # Base directory for model files
     pytorch_model_path: str = "kokoro-v0_19.pth"
     onnx_model_path: str = "kokoro-v0_19.onnx"
     voices_dir: str = "voices"
@@ -1,7 +1,7 @@
 import re
 
-import torch
 import phonemizer
+import torch
 
 
 def split_num(num):
@@ -6,15 +6,15 @@ import sys
 from contextlib import asynccontextmanager
 
 import uvicorn
-from loguru import logger
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
+from loguru import logger
 
 from .core.config import settings
-from .services.tts_model import TTSModel
 from .routers.development import router as dev_router
-from .services.tts_service import TTSService
 from .routers.openai_compatible import router as openai_router
+from .services.tts_model import TTSModel
+from .services.tts_service import TTSService
 
 
 def setup_logger():

@@ -47,7 +47,7 @@ async def lifespan(app: FastAPI):
     # Initialize the main model with warm-up
     voicepack_count = await TTSModel.setup()
     # boundary = "█████╗"*9
-    boundary = "░" * 24
+    boundary = "░" * 2*12
     startup_msg = f"""
 
 {boundary}
@@ -1,18 +1,18 @@
 from typing import List
 
 import numpy as np
+from fastapi import APIRouter, Depends, HTTPException, Response
 from loguru import logger
-from fastapi import Depends, Response, APIRouter, HTTPException
 
 from ..services.audio import AudioService
+from ..services.text_processing import phonemize, tokenize
 from ..services.tts_model import TTSModel
 from ..services.tts_service import TTSService
 from ..structures.text_schemas import (
+    GenerateFromPhonemesRequest,
     PhonemeRequest,
     PhonemeResponse,
-    GenerateFromPhonemesRequest,
 )
-from ..services.text_processing import tokenize, phonemize
 
 router = APIRouter(tags=["text processing"])
@@ -1,12 +1,12 @@
-from typing import List, Union, AsyncGenerator
+from typing import AsyncGenerator, List, Union
 
-from loguru import logger
-from fastapi import Header, Depends, Response, APIRouter, HTTPException
+from fastapi import APIRouter, Depends, Header, HTTPException, Response, Request
 from fastapi.responses import StreamingResponse
+from loguru import logger
 
 from ..services.audio import AudioService
-from ..structures.schemas import OpenAISpeechRequest
 from ..services.tts_service import TTSService
+from ..structures.schemas import OpenAISpeechRequest
 
 router = APIRouter(
     tags=["OpenAI Compatible TTS"],

@@ -49,22 +49,35 @@ async def process_voices(
 
 
 async def stream_audio_chunks(
-    tts_service: TTSService, request: OpenAISpeechRequest
+    tts_service: TTSService,
+    request: OpenAISpeechRequest,
+    client_request: Request
 ) -> AsyncGenerator[bytes, None]:
-    """Stream audio chunks as they're generated"""
+    """Stream audio chunks as they're generated with client disconnect handling"""
     voice_to_use = await process_voices(request.voice, tts_service)
+
+    try:
         async for chunk in tts_service.generate_audio_stream(
             text=request.input,
             voice=voice_to_use,
             speed=request.speed,
             output_format=request.response_format,
         ):
+            # Check if client is still connected
+            if await client_request.is_disconnected():
+                logger.info("Client disconnected, stopping audio generation")
+                break
             yield chunk
+    except Exception as e:
+        logger.error(f"Error in audio streaming: {str(e)}")
+        # Let the exception propagate to trigger cleanup
+        raise
 
 
 @router.post("/audio/speech")
 async def create_speech(
     request: OpenAISpeechRequest,
+    client_request: Request,
     tts_service: TTSService = Depends(get_tts_service),
     x_raw_response: str = Header(None, alias="x-raw-response"),
 ):

@@ -87,7 +100,7 @@ async def create_speech(
     if request.stream:
         # Stream audio chunks as they're generated
         return StreamingResponse(
-            stream_audio_chunks(tts_service, request),
+            stream_audio_chunks(tts_service, request, client_request),
             media_type=content_type,
             headers={
                 "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
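
The disconnect check above only matters when the client truly streams. A
hedged client-side sketch (the /v1 prefix, port 8880, and field names are
assumptions based on the OpenAI-style schema used here):

    import requests  # any streaming HTTP client works

    payload = {
        "model": "kokoro",
        "input": "Hello world.",
        "voice": "af",
        "response_format": "mp3",
        "stream": True,
    }
    with requests.post("http://localhost:8880/v1/audio/speech",
                       json=payload, stream=True) as resp:
        with open("speech.mp3", "wb") as out:
            for chunk in resp.iter_content(chunk_size=4096):
                # Closing the connection early is what trips
                # is_disconnected() server-side and halts generation.
                out.write(chunk)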
@@ -3,8 +3,8 @@
 from io import BytesIO
 
 import numpy as np
-import soundfile as sf
 import scipy.io.wavfile as wavfile
+import soundfile as sf
 from loguru import logger
 
 from ..core.config import settings

@@ -22,20 +22,19 @@ class AudioNormalizer:
     def normalize(
         self, audio_data: np.ndarray, is_last_chunk: bool = False
     ) -> np.ndarray:
-        """Normalize audio data to int16 range and trim chunk boundaries"""
-        # Convert to float32 if not already
+        """Convert audio data to int16 range and trim chunk boundaries"""
+        if len(audio_data) == 0:
+            raise ValueError("Audio data cannot be empty")
+
+        # Simple float32 to int16 conversion
         audio_float = audio_data.astype(np.float32)
 
-        # Normalize to [-1, 1] range first
-        if np.max(np.abs(audio_float)) > 0:
-            audio_float = audio_float / np.max(np.abs(audio_float))
-
-        # Trim end of non-final chunks to reduce gaps
+        # Trim for non-final chunks
         if not is_last_chunk and len(audio_float) > self.samples_to_trim:
-            audio_float = audio_float[: -self.samples_to_trim]
+            audio_float = audio_float[:-self.samples_to_trim]
 
-        # Scale to int16 range
-        return (audio_float * self.int16_max).astype(np.int16)
+        # Direct scaling like the non-streaming version
+        return (audio_float * 32767).astype(np.int16)
 
 
 class AudioService:
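
The new normalize() is a straight float-to-int16 cast. A worked sketch
(assuming the default AudioNormalizer construction):

    import numpy as np

    norm = AudioNormalizer()
    samples = np.array([0.0, 0.5, -1.0], dtype=np.float32)
    norm.normalize(samples, is_last_chunk=True)
    # -> array([     0,  16383, -32767], dtype=int16)

Unlike the old path, quiet chunks are no longer rescaled to full range by
their own peak, so streamed loudness now matches the non-streaming output.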
@@ -1,6 +1,6 @@
 from .normalizer import normalize_text
 from .phonemizer import EspeakBackend, PhonemizerBackend, phonemize
-from .vocabulary import VOCAB, tokenize, decode_tokens
+from .vocabulary import VOCAB, decode_tokens, tokenize
 
 __all__ = [
     "normalize_text",
@@ -5,19 +5,20 @@ import torch
 from loguru import logger
 from onnxruntime import (
     ExecutionMode,
-    SessionOptions,
-    InferenceSession,
     GraphOptimizationLevel,
+    InferenceSession,
+    SessionOptions,
 )
 
-from .tts_base import TTSBaseModel
 from ..core.config import settings
-from .text_processing import tokenize, phonemize
+from .text_processing import phonemize, tokenize
+from .tts_base import TTSBaseModel
 
 
 class TTSCPUModel(TTSBaseModel):
     _instance = None
     _onnx_session = None
+    _device = "cpu"
 
     @classmethod
     def get_instance(cls):

@@ -30,16 +31,14 @@ class TTSCPUModel(TTSBaseModel):
     def initialize(cls, model_dir: str, model_path: str = None):
         """Initialize ONNX model for CPU inference"""
         if cls._onnx_session is None:
+            try:
                 # Try loading ONNX model
                 onnx_path = os.path.join(model_dir, settings.onnx_model_path)
-            if os.path.exists(onnx_path):
-                logger.info(f"Loading ONNX model from {onnx_path}")
-            else:
-                logger.error(f"ONNX model not found at {onnx_path}")
-                return None
-
-            if not onnx_path:
-                return None
+                if not os.path.exists(onnx_path):
+                    logger.error(f"ONNX model not found at {onnx_path}")
+                    return None
+                logger.info(f"Loading ONNX model from {onnx_path}")
 
                 # Configure ONNX session for optimal performance
                 session_options = SessionOptions()

@@ -88,6 +87,9 @@ class TTSCPUModel(TTSBaseModel):
                 )
                 cls._onnx_session = session
                 return session
+            except Exception as e:
+                logger.error(f"Failed to initialize ONNX model: {e}")
+                return None
         return cls._onnx_session
 
     @classmethod
@@ -3,12 +3,12 @@ import time
 
 import numpy as np
 import torch
+from builds.models import build_model
 from loguru import logger
-from models import build_model
 
-from .tts_base import TTSBaseModel
 from ..core.config import settings
-from .text_processing import tokenize, phonemize
+from .text_processing import phonemize, tokenize
+from .tts_base import TTSBaseModel
 
 
 # @torch.no_grad()
@@ -2,19 +2,19 @@ import io
 import os
 import re
 import time
-from typing import List, Tuple, Optional
 from functools import lru_cache
+from typing import List, Optional, Tuple
 
-import numpy as np
-import torch
 import aiofiles.os
+import numpy as np
 import scipy.io.wavfile as wavfile
+import torch
 from loguru import logger
 
-from .audio import AudioService, AudioNormalizer
-from .tts_model import TTSModel
 from ..core.config import settings
+from .audio import AudioNormalizer, AudioService
 from .text_processing import chunker, normalize_text
+from .tts_model import TTSModel
 
 
 class TTSService:
@@ -4,9 +4,9 @@ from typing import List, Tuple
 import torch
 from loguru import logger
 
+from ..core.config import settings
 from .tts_model import TTSModel
 from .tts_service import TTSService
-from ..core.config import settings
 
 
 class WarmupService:
@@ -1,7 +1,7 @@
 from enum import Enum
-from typing import List, Union, Literal
+from typing import List, Literal, Union
 
-from pydantic import Field, BaseModel
+from pydantic import BaseModel, Field
 
 
 class VoiceCombineRequest(BaseModel):
@@ -1,4 +1,4 @@
-from pydantic import Field, BaseModel
+from pydantic import BaseModel, Field
 
 
 class PhonemeRequest(BaseModel):
BIN  api/src/voices/af_irulan.pt  Normal file
Binary file not shown.
@@ -1,11 +1,11 @@
 import os
-import sys
 import shutil
-from unittest.mock import Mock, MagicMock, patch
+import sys
+from unittest.mock import MagicMock, Mock, patch
 
+import aiofiles.threadpool
 import numpy as np
 import pytest
-import aiofiles.threadpool
 
 
 def cleanup_mock_dirs():

@@ -32,77 +32,7 @@ def cleanup():
     cleanup_mock_dirs()
 
 
-# Create mock torch module
-mock_torch = Mock()
-mock_torch.cuda = Mock()
-mock_torch.cuda.is_available = Mock(return_value=False)
-
-
-# Create a mock tensor class that supports basic operations
-class MockTensor:
-    def __init__(self, data):
-        self.data = data
-        if isinstance(data, (list, tuple)):
-            self.shape = [len(data)]
-        elif isinstance(data, MockTensor):
-            self.shape = data.shape
-        else:
-            self.shape = getattr(data, "shape", [1])
-
-    def __getitem__(self, idx):
-        if isinstance(self.data, (list, tuple)):
-            if isinstance(idx, slice):
-                return MockTensor(self.data[idx])
-            return self.data[idx]
-        return self
-
-    def max(self):
-        if isinstance(self.data, (list, tuple)):
-            max_val = max(self.data)
-            return MockTensor(max_val)
-        return 5  # Default for testing
-
-    def item(self):
-        if isinstance(self.data, (list, tuple)):
-            return max(self.data)
-        if isinstance(self.data, (int, float)):
-            return self.data
-        return 5  # Default for testing
-
-    def cuda(self):
-        """Support cuda conversion"""
-        return self
-
-    def any(self):
-        if isinstance(self.data, (list, tuple)):
-            return any(self.data)
-        return False
-
-    def all(self):
-        if isinstance(self.data, (list, tuple)):
-            return all(self.data)
-        return True
-
-    def unsqueeze(self, dim):
-        return self
-
-    def expand(self, *args):
-        return self
-
-    def type_as(self, other):
-        return self
-
-
-# Add tensor operations to mock torch
-mock_torch.tensor = lambda x: MockTensor(x)
-mock_torch.zeros = lambda *args: MockTensor(
-    [0] * (args[0] if isinstance(args[0], int) else args[0][0])
-)
-mock_torch.arange = lambda x: MockTensor(list(range(x)))
-mock_torch.gt = lambda x, y: MockTensor([False] * x.shape[0])
-
 # Mock modules before they're imported
-sys.modules["torch"] = mock_torch
 sys.modules["transformers"] = Mock()
 sys.modules["phonemizer"] = Mock()
 sys.modules["models"] = Mock()
@@ -5,7 +5,7 @@ from unittest.mock import patch
 import numpy as np
 import pytest
 
-from api.src.services.audio import AudioService, AudioNormalizer
+from api.src.services.audio import AudioNormalizer, AudioService
 
 
 @pytest.fixture(autouse=True)
@@ -1,10 +1,10 @@
 import asyncio
-from unittest.mock import Mock, AsyncMock
+from unittest.mock import AsyncMock, Mock
 
 import pytest
 import pytest_asyncio
-from httpx import AsyncClient
 from fastapi.testclient import TestClient
+from httpx import AsyncClient
 
 from ..src.main import app
@@ -7,8 +7,8 @@ import pytest
 import pytest_asyncio
 from httpx import AsyncClient
 
-from .conftest import MockTTSModel
 from ..src.main import app
+from .conftest import MockTTSModel
 
 
 @pytest_asyncio.fixture
@@ -1,15 +1,15 @@
 """Tests for TTS model implementations"""
 
 import os
-from unittest.mock import MagicMock, patch
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import numpy as np
-import torch
 import pytest
+import torch
 
+from api.src.services.tts_base import TTSBaseModel
 from api.src.services.tts_cpu import TTSCPUModel
 from api.src.services.tts_gpu import TTSGPUModel, length_to_mask
-from api.src.services.tts_base import TTSBaseModel
 
 
 # Base Model Tests
@@ -27,17 +27,31 @@ def test_get_device_error():
 @patch("os.listdir")
 @patch("torch.load")
 @patch("torch.save")
+@patch("api.src.services.tts_base.settings")
+@patch("api.src.services.warmup.WarmupService")
 async def test_setup_cuda_available(
-    mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available
+    mock_warmup_class, mock_settings, mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available
 ):
     """Test setup with CUDA available"""
     TTSBaseModel._device = None
-    mock_cuda_available.return_value = True
+    # Mock CUDA as unavailable since we're using CPU PyTorch
+    mock_cuda_available.return_value = False
     mock_exists.return_value = True
     mock_load.return_value = torch.zeros(1)
     mock_listdir.return_value = ["voice1.pt", "voice2.pt"]
     mock_join.return_value = "/mocked/path"
 
+    # Configure mock settings
+    mock_settings.model_dir = "/mock/model/dir"
+    mock_settings.onnx_model_path = "model.onnx"
+    mock_settings.voices_dir = "voices"
+
+    # Configure mock warmup service
+    mock_warmup = MagicMock()
+    mock_warmup.load_voices.return_value = [torch.zeros(1)]
+    mock_warmup.warmup_voices = AsyncMock()
+    mock_warmup_class.return_value = mock_warmup
+
     # Create mock model
     mock_model = MagicMock()
     mock_model.bert = MagicMock()

@@ -49,7 +63,7 @@ async def test_setup_cuda_available(
     TTSBaseModel._instance = mock_model
 
     voice_count = await TTSBaseModel.setup()
-    assert TTSBaseModel._device == "cuda"
+    assert TTSBaseModel._device == "cpu"
     assert voice_count == 2
 

@@ -60,8 +74,10 @@ async def test_setup_cuda_available(
 @patch("os.listdir")
 @patch("torch.load")
 @patch("torch.save")
+@patch("api.src.services.tts_base.settings")
+@patch("api.src.services.warmup.WarmupService")
 async def test_setup_cuda_unavailable(
-    mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available
+    mock_warmup_class, mock_settings, mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available
 ):
     """Test setup with CUDA unavailable"""
     TTSBaseModel._device = None

@@ -71,6 +87,17 @@ async def test_setup_cuda_unavailable(
     mock_listdir.return_value = ["voice1.pt", "voice2.pt"]
     mock_join.return_value = "/mocked/path"
 
+    # Configure mock settings
+    mock_settings.model_dir = "/mock/model/dir"
+    mock_settings.onnx_model_path = "model.onnx"
+    mock_settings.voices_dir = "voices"
+
+    # Configure mock warmup service
+    mock_warmup = MagicMock()
+    mock_warmup.load_voices.return_value = [torch.zeros(1)]
+    mock_warmup.warmup_voices = AsyncMock()
+    mock_warmup_class.return_value = mock_warmup
+
     # Create mock model
     mock_model = MagicMock()
     mock_model.bert = MagicMock()
@@ -4,8 +4,8 @@ import os
 from unittest.mock import MagicMock, call, patch
 
 import numpy as np
-import torch
 import pytest
+import torch
 from onnxruntime import InferenceSession
 
 from api.src.core.config import settings

@@ -1,79 +0,0 @@
-name: kokoro-fastapi
-services:
-  model-fetcher:
-    image: datamachines/git-lfs:latest
-    volumes:
-      - ./Kokoro-82M:/app/Kokoro-82M
-    working_dir: /app/Kokoro-82M
-    command: >
-      sh -c "
-        mkdir -p /app/Kokoro-82M;
-        cd /app/Kokoro-82M;
-        rm -f .git/index.lock;
-        if [ -z \"$(ls -A .)\" ]; then
-          git clone https://huggingface.co/hexgrad/Kokoro-82M .
-          touch .cloned;
-        else
-          rm -f .git/index.lock && \
-          git checkout main && \
-          git pull origin main && \
-          touch .cloned;
-        fi;
-        tail -f /dev/null
-      "
-    healthcheck:
-      test: ["CMD", "test", "-f", ".cloned"]
-      interval: 5s
-      timeout: 2s
-      retries: 300
-      start_period: 1s
-
-  kokoro-tts:
-    image: ghcr.io/remsky/kokoro-fastapi-cpu:v0.0.5post1
-    # Uncomment below (and comment out above) to build from source instead of using the released image
-    build:
-      context: .
-      dockerfile: Dockerfile.cpu
-    volumes:
-      - ./api/src:/app/api/src
-      - ./Kokoro-82M:/app/Kokoro-82M
-    ports:
-      - "8880:8880"
-    environment:
-      - PYTHONPATH=/app:/app/Kokoro-82M
-      # ONNX Optimization Settings for vectorized operations
-      - ONNX_NUM_THREADS=8  # Maximize core usage for vectorized ops
-      - ONNX_INTER_OP_THREADS=4  # Higher inter-op for parallel matrix operations
-      - ONNX_EXECUTION_MODE=parallel
-      - ONNX_OPTIMIZATION_LEVEL=all
-      - ONNX_MEMORY_PATTERN=true
-      - ONNX_ARENA_EXTEND_STRATEGY=kNextPowerOfTwo
-
-    healthcheck:
-      test: ["CMD", "curl", "-f", "http://localhost:8880/health"]
-      interval: 10s
-      timeout: 5s
-      retries: 30
-      start_period: 30s
-    depends_on:
-      model-fetcher:
-        condition: service_healthy
-
-  # Gradio UI service [Comment out everything below if you don't need it]
-  gradio-ui:
-    image: ghcr.io/remsky/kokoro-fastapi-ui:v0.0.5post1
-    # Uncomment below (and comment out above) to build from source instead of using the released image
-    # build:
-    #   context: ./ui
-    ports:
-      - "7860:7860"
-    volumes:
-      - ./ui/data:/app/ui/data
-      - ./ui/app.py:/app/app.py  # Mount app.py for hot reload
-    environment:
-      - GRADIO_WATCH=True  # Enable hot reloading
-      - PYTHONUNBUFFERED=1  # Ensure Python output is not buffered
-    depends_on:
-      kokoro-tts:
-        condition: service_healthy

@@ -1,80 +0,0 @@
-name: kokoro-fastapi
-services:
-  model-fetcher:
-    image: datamachines/git-lfs:latest
-    environment:
-      - SKIP_MODEL_FETCH=${SKIP_MODEL_FETCH:-false}
-    volumes:
-      - ./Kokoro-82M:/app/Kokoro-82M
-    working_dir: /app/Kokoro-82M
-    command: >
-      sh -c "
-        if [ \"$$SKIP_MODEL_FETCH\" = \"true\" ]; then
-          echo 'Skipping model fetch...' && touch .cloned;
-        else
-          rm -f .git/index.lock;
-          if [ -z \"$(ls -A .)\" ]; then
-            git clone https://huggingface.co/hexgrad/Kokoro-82M .
-            touch .cloned;
-          else
-            rm -f .git/index.lock && \
-            git checkout main && \
-            git pull origin main && \
-            touch .cloned;
-          fi;
-        fi;
-        tail -f /dev/null
-      "
-    healthcheck:
-      test: ["CMD", "test", "-f", ".cloned"]
-      interval: 5s
-      timeout: 2s
-      retries: 300
-      start_period: 1s
-
-  kokoro-tts:
-    image: ghcr.io/remsky/kokoro-fastapi-gpu:v0.0.5post1
-    # Uncomment below (and comment out above) to build from source instead of using the released image
-    # build:
-    #   context: .
-    volumes:
-      - ./api/src:/app/api/src
-      - ./Kokoro-82M:/app/Kokoro-82M
-    ports:
-      - "8880:8880"
-    environment:
-      - PYTHONPATH=/app:/app/Kokoro-82M
-    deploy:
-      resources:
-        reservations:
-          devices:
-            - driver: nvidia
-              count: 1
-              capabilities: [gpu]
-    healthcheck:
-      test: ["CMD", "curl", "-f", "http://localhost:8880/health"]
-      interval: 10s
-      timeout: 5s
-      retries: 30
-      start_period: 30s
-    depends_on:
-      model-fetcher:
-        condition: service_healthy
-
-  # Gradio UI service [Comment out everything below if you don't need it]
-  gradio-ui:
-    image: ghcr.io/remsky/kokoro-fastapi-ui:v0.0.5post1
-    # Uncomment below (and comment out above) to build from source instead of using the released image
-    # build:
-    #   context: ./ui
-    ports:
-      - "7860:7860"
-    volumes:
-      - ./ui/data:/app/ui/data
-      - ./ui/app.py:/app/app.py  # Mount app.py for hot reload
-    environment:
-      - GRADIO_WATCH=True  # Enable hot reloading
-      - PYTHONUNBUFFERED=1  # Ensure Python output is not buffered
-    depends_on:
-      kokoro-tts:
-        condition: service_healthy

62  docker/cpu/Dockerfile  Normal file
@@ -0,0 +1,62 @@
FROM python:3.10-slim

# Install dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    espeak-ng \
    git \
    libsndfile1 \
    curl \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

# Install uv
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/

# Create non-root user
RUN useradd -m -u 1000 appuser

# Create directories and set ownership
RUN mkdir -p /app/api/model_files && \
    mkdir -p /app/api/src/voices && \
    chown -R appuser:appuser /app

USER appuser

# Download and extract models
WORKDIR /app/api/model_files
RUN curl -L -o model.tar.gz https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.0.1/kokoro-82m-onnx.tar.gz && \
    tar xzf model.tar.gz && \
    rm model.tar.gz

# Download and extract voice models
WORKDIR /app/api/src/voices
RUN curl -L -o voices.tar.gz https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.0.1/voice-models.tar.gz && \
    tar xzf voices.tar.gz && \
    rm voices.tar.gz

# Switch back to app directory
WORKDIR /app

# Copy dependency files
COPY --chown=appuser:appuser pyproject.toml ./pyproject.toml

# Install dependencies
RUN --mount=type=cache,target=/root/.cache/uv \
    uv venv && \
    uv sync --extra cpu --no-install-project

# Copy project files
COPY --chown=appuser:appuser api ./api

# Install project
RUN --mount=type=cache,target=/root/.cache/uv \
    uv sync --extra cpu

# Set environment variables
ENV PYTHONUNBUFFERED=1
ENV PYTHONPATH=/app:/app/Kokoro-82M
ENV PATH="/app/.venv/bin:$PATH"
ENV UV_LINK_MODE=copy

# Run FastAPI server
CMD ["uv", "run", "python", "-m", "uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]

37  docker/cpu/docker-compose.yml  Normal file
@@ -0,0 +1,37 @@
name: kokoro-tts
services:
  kokoro-tts:
    image: ghcr.io/remsky/kokoro-fastapi-cpu:latest
    # Uncomment below (and comment out above) to build from source instead of using the released image
    # build:
    #   context: ../..
    #   dockerfile: docker/cpu/Dockerfile
    volumes:
      - ../../api/src:/app/api/src
    ports:
      - "8880:8880"
    environment:
      - PYTHONPATH=/app:/app/Kokoro-82M
      # ONNX Optimization Settings for vectorized operations
      - ONNX_NUM_THREADS=8  # Maximize core usage for vectorized ops
      - ONNX_INTER_OP_THREADS=4  # Higher inter-op for parallel matrix operations
      - ONNX_EXECUTION_MODE=parallel
      - ONNX_OPTIMIZATION_LEVEL=all
      - ONNX_MEMORY_PATTERN=true
      - ONNX_ARENA_EXTEND_STRATEGY=kNextPowerOfTwo

  # Gradio UI service [Comment out everything below if you don't need it]
  gradio-ui:
    image: ghcr.io/remsky/kokoro-fastapi:latest-ui
    # Uncomment below (and comment out above) to build from source instead of using the released image
    # build:
    #   context: ../../ui
    ports:
      - "7860:7860"
    volumes:
      - ../../ui/data:/app/ui/data
      - ../../ui/app.py:/app/app.py  # Mount app.py for hot reload
    environment:
      - GRADIO_WATCH=True  # Enable hot reloading
      - PYTHONUNBUFFERED=1  # Ensure Python output is not buffered
      - DISABLE_LOCAL_SAVING=false  # Set to 'true' to disable local saving and hide file view

22  docker/cpu/pyproject.toml  Normal file
@@ -0,0 +1,22 @@
[project]
name = "kokoro-fastapi-cpu"
version = "0.1.0"
description = "FastAPI TTS Service - CPU Version"
readme = "../README.md"
requires-python = ">=3.10"
dependencies = [
    # Core ML/DL for CPU
    "torch>=2.5.1",
    "transformers==4.47.1",
]

[tool.uv.workspace]
members = ["../shared"]

[tool.uv.sources]
torch = { index = "pytorch-cpu" }

[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true

229  docker/cpu/requirements.lock  Normal file
@@ -0,0 +1,229 @@
# This file was autogenerated by uv via the following command:
#    uv pip compile pyproject.toml ../shared/pyproject.toml --output-file requirements.lock
aiofiles==23.2.1
    # via kokoro-fastapi (../shared/pyproject.toml)
annotated-types==0.7.0
    # via pydantic
anyio==4.8.0
    # via starlette
attrs==24.3.0
    # via
    #   clldutils
    #   csvw
    #   jsonschema
    #   phonemizer
    #   referencing
babel==2.16.0
    # via csvw
certifi==2024.12.14
    # via requests
cffi==1.17.1
    # via soundfile
charset-normalizer==3.4.1
    # via requests
click==8.1.8
    # via
    #   kokoro-fastapi (../shared/pyproject.toml)
    #   uvicorn
clldutils==3.21.0
    # via segments
colorama==0.4.6
    # via
    #   click
    #   colorlog
    #   csvw
    #   loguru
    #   tqdm
coloredlogs==15.0.1
    # via onnxruntime
colorlog==6.9.0
    # via clldutils
csvw==3.5.1
    # via segments
dlinfo==1.2.1
    # via phonemizer
exceptiongroup==1.2.2
    # via anyio
fastapi==0.115.6
    # via kokoro-fastapi (../shared/pyproject.toml)
filelock==3.16.1
    # via
    #   huggingface-hub
    #   torch
    #   transformers
flatbuffers==24.12.23
    # via onnxruntime
fsspec==2024.12.0
    # via
    #   huggingface-hub
    #   torch
greenlet==3.1.1
    # via sqlalchemy
h11==0.14.0
    # via uvicorn
huggingface-hub==0.27.1
    # via
    #   tokenizers
    #   transformers
humanfriendly==10.0
    # via coloredlogs
idna==3.10
    # via
    #   anyio
    #   requests
isodate==0.7.2
    # via
    #   csvw
    #   rdflib
jinja2==3.1.5
    # via torch
joblib==1.4.2
    # via phonemizer
jsonschema==4.23.0
    # via csvw
jsonschema-specifications==2024.10.1
    # via jsonschema
language-tags==1.2.0
    # via csvw
loguru==0.7.3
    # via kokoro-fastapi (../shared/pyproject.toml)
lxml==5.3.0
    # via clldutils
markdown==3.7
    # via clldutils
markupsafe==3.0.2
    # via
    #   clldutils
    #   jinja2
mpmath==1.3.0
    # via sympy
munch==4.0.0
    # via kokoro-fastapi (../shared/pyproject.toml)
networkx==3.4.2
    # via torch
numpy==2.2.1
    # via
    #   kokoro-fastapi (../shared/pyproject.toml)
    #   onnxruntime
    #   scipy
    #   soundfile
    #   transformers
onnxruntime==1.20.1
    # via kokoro-fastapi (../shared/pyproject.toml)
packaging==24.2
    # via
    #   huggingface-hub
    #   onnxruntime
    #   transformers
phonemizer==3.3.0
    # via kokoro-fastapi (../shared/pyproject.toml)
protobuf==5.29.3
    # via onnxruntime
pycparser==2.22
    # via cffi
pydantic==2.10.4
    # via
    #   kokoro-fastapi (../shared/pyproject.toml)
    #   fastapi
    #   pydantic-settings
pydantic-core==2.27.2
    # via pydantic
pydantic-settings==2.7.0
    # via kokoro-fastapi (../shared/pyproject.toml)
pylatexenc==2.10
    # via clldutils
pyparsing==3.2.1
    # via rdflib
pyreadline3==3.5.4
    # via humanfriendly
python-dateutil==2.9.0.post0
    # via
    #   clldutils
    #   csvw
python-dotenv==1.0.1
    # via
    #   kokoro-fastapi (../shared/pyproject.toml)
    #   pydantic-settings
pyyaml==6.0.2
    # via
    #   huggingface-hub
    #   transformers
rdflib==7.1.2
    # via csvw
referencing==0.35.1
    # via
    #   jsonschema
    #   jsonschema-specifications
regex==2024.11.6
    # via
    #   kokoro-fastapi (../shared/pyproject.toml)
    #   segments
    #   tiktoken
    #   transformers
requests==2.32.3
    # via
    #   kokoro-fastapi (../shared/pyproject.toml)
    #   csvw
    #   huggingface-hub
    #   tiktoken
    #   transformers
rfc3986==1.5.0
    # via csvw
rpds-py==0.22.3
    # via
    #   jsonschema
    #   referencing
safetensors==0.5.2
    # via transformers
scipy==1.14.1
    # via kokoro-fastapi (../shared/pyproject.toml)
segments==2.2.1
    # via phonemizer
six==1.17.0
    # via python-dateutil
sniffio==1.3.1
    # via anyio
soundfile==0.13.0
    # via kokoro-fastapi (../shared/pyproject.toml)
sqlalchemy==2.0.27
    # via kokoro-fastapi (../shared/pyproject.toml)
starlette==0.41.3
    # via fastapi
sympy==1.13.1
    # via
    #   onnxruntime
    #   torch
tabulate==0.9.0
    # via clldutils
tiktoken==0.8.0
    # via kokoro-fastapi (../shared/pyproject.toml)
tokenizers==0.21.0
    # via transformers
torch==2.5.1+cpu
    # via kokoro-fastapi-cpu (pyproject.toml)
tqdm==4.67.1
    # via
    #   kokoro-fastapi (../shared/pyproject.toml)
    #   huggingface-hub
    #   transformers
transformers==4.47.1
    # via kokoro-fastapi-cpu (pyproject.toml)
typing-extensions==4.12.2
    # via
    #   anyio
    #   fastapi
    #   huggingface-hub
    #   phonemizer
    #   pydantic
    #   pydantic-core
    #   sqlalchemy
    #   torch
    #   uvicorn
uritemplate==4.1.1
    # via csvw
urllib3==2.3.0
    # via requests
uvicorn==0.34.0
    # via kokoro-fastapi (../shared/pyproject.toml)
win32-setctime==1.2.0
    # via loguru
docker/cpu/uv.lock  (generated, new file, 1841 lines; diff suppressed because it is too large)
docker/gpu/Dockerfile  (new file, 64 lines)
@@ -0,0 +1,64 @@
FROM nvidia/cuda:12.1.0-base-ubuntu22.04

# Install Python and other dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    python3.10 \
    python3.10-venv \
    espeak-ng \
    git \
    libsndfile1 \
    curl \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

# Install uv
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/

# Create non-root user
RUN useradd -m -u 1000 appuser

# Create directories and set ownership
RUN mkdir -p /app/api/model_files && \
    mkdir -p /app/api/src/voices && \
    chown -R appuser:appuser /app

USER appuser

# Download and extract models
WORKDIR /app/api/model_files
RUN curl -L -o model.tar.gz https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.0.1/kokoro-82m-pytorch.tar.gz && \
    tar xzf model.tar.gz && \
    rm model.tar.gz

# Download and extract voice models
WORKDIR /app/api/src/voices
RUN curl -L -o voices.tar.gz https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.0.1/voice-models.tar.gz && \
    tar xzf voices.tar.gz && \
    rm voices.tar.gz

# Switch back to app directory
WORKDIR /app

# Copy dependency files
COPY --chown=appuser:appuser pyproject.toml ./pyproject.toml

# Install dependencies
RUN --mount=type=cache,target=/root/.cache/uv \
    uv venv && \
    uv sync --extra gpu --no-install-project

# Copy project files
COPY --chown=appuser:appuser api ./api

# Install project
RUN --mount=type=cache,target=/root/.cache/uv \
    uv sync --extra gpu

# Set environment variables
ENV PYTHONUNBUFFERED=1
ENV PYTHONPATH=/app:/app/Kokoro-82M
ENV PATH="/app/.venv/bin:$PATH"
ENV UV_LINK_MODE=copy

# Run FastAPI server
CMD ["uv", "run", "python", "-m", "uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]
docker/gpu/docker-compose.yml  (new file, 37 lines)
@@ -0,0 +1,37 @@
name: kokoro-tts
services:
  kokoro-tts:
    image: ghcr.io/remsky/kokoro-fastapi-gpu:latest
    # Uncomment below (and comment out above) to build from source instead of using the released image
    # build:
    #   context: ../..
    #   dockerfile: docker/gpu/Dockerfile
    volumes:
      - ../../api/src:/app/api/src  # Mount src for development
    ports:
      - "8880:8880"
    environment:
      - PYTHONPATH=/app:/app/Kokoro-82M
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]

  # Gradio UI service
  gradio-ui:
    image: ghcr.io/remsky/kokoro-fastapi-ui:latest
    # Uncomment below to build from source instead of using the released image
    # build:
    #   context: ../../ui
    ports:
      - "7860:7860"
    volumes:
      - ../../ui/data:/app/ui/data
      - ../../ui/app.py:/app/app.py  # Mount app.py for hot reload
    environment:
      - GRADIO_WATCH=1  # Enable hot reloading
      - PYTHONUNBUFFERED=1  # Ensure Python output is not buffered
      - DISABLE_LOCAL_SAVING=false  # Set to 'true' to disable local saving and hide file view
docker/gpu/pyproject.toml  (new file, 22 lines)
@@ -0,0 +1,22 @@
[project]
name = "kokoro-fastapi-gpu"
version = "0.1.0"
description = "FastAPI TTS Service - GPU Version"
readme = "../README.md"
requires-python = ">=3.10"
dependencies = [
    # Core ML/DL for GPU
    "torch==2.5.1+cu121",
    "transformers==4.47.1",
]

[tool.uv.workspace]
members = ["../shared"]

[tool.uv.sources]
torch = { index = "pytorch-cuda" }

[[tool.uv.index]]
name = "pytorch-cuda"
url = "https://download.pytorch.org/whl/cu121"
explicit = true
docker/gpu/requirements.lock  (new file, 229 lines)
@@ -0,0 +1,229 @@
# This file was autogenerated by uv via the following command:
#    uv pip compile pyproject.toml ../shared/pyproject.toml --output-file requirements.lock
aiofiles==23.2.1
    # via kokoro-fastapi (../shared/pyproject.toml)
annotated-types==0.7.0
    # via pydantic
anyio==4.8.0
    # via starlette
attrs==24.3.0
    # via
    #   clldutils
    #   csvw
    #   jsonschema
    #   phonemizer
    #   referencing
babel==2.16.0
    # via csvw
certifi==2024.12.14
    # via requests
cffi==1.17.1
    # via soundfile
charset-normalizer==3.4.1
    # via requests
click==8.1.8
    # via
    #   kokoro-fastapi (../shared/pyproject.toml)
    #   uvicorn
clldutils==3.21.0
    # via segments
colorama==0.4.6
    # via
    #   click
    #   colorlog
    #   csvw
    #   loguru
    #   tqdm
coloredlogs==15.0.1
    # via onnxruntime
colorlog==6.9.0
    # via clldutils
csvw==3.5.1
    # via segments
dlinfo==1.2.1
    # via phonemizer
exceptiongroup==1.2.2
    # via anyio
fastapi==0.115.6
    # via kokoro-fastapi (../shared/pyproject.toml)
filelock==3.16.1
    # via
    #   huggingface-hub
    #   torch
    #   transformers
flatbuffers==24.12.23
    # via onnxruntime
fsspec==2024.12.0
    # via
    #   huggingface-hub
    #   torch
greenlet==3.1.1
    # via sqlalchemy
h11==0.14.0
    # via uvicorn
huggingface-hub==0.27.1
    # via
    #   tokenizers
    #   transformers
humanfriendly==10.0
    # via coloredlogs
idna==3.10
    # via
    #   anyio
    #   requests
isodate==0.7.2
    # via
    #   csvw
    #   rdflib
jinja2==3.1.5
    # via torch
joblib==1.4.2
    # via phonemizer
jsonschema==4.23.0
    # via csvw
jsonschema-specifications==2024.10.1
    # via jsonschema
language-tags==1.2.0
    # via csvw
loguru==0.7.3
    # via kokoro-fastapi (../shared/pyproject.toml)
lxml==5.3.0
    # via clldutils
markdown==3.7
    # via clldutils
markupsafe==3.0.2
    # via
    #   clldutils
    #   jinja2
mpmath==1.3.0
    # via sympy
munch==4.0.0
    # via kokoro-fastapi (../shared/pyproject.toml)
networkx==3.4.2
    # via torch
numpy==2.2.1
    # via
    #   kokoro-fastapi (../shared/pyproject.toml)
    #   onnxruntime
    #   scipy
    #   soundfile
    #   transformers
onnxruntime==1.20.1
    # via kokoro-fastapi (../shared/pyproject.toml)
packaging==24.2
    # via
    #   huggingface-hub
    #   onnxruntime
    #   transformers
phonemizer==3.3.0
    # via kokoro-fastapi (../shared/pyproject.toml)
protobuf==5.29.3
    # via onnxruntime
pycparser==2.22
    # via cffi
pydantic==2.10.4
    # via
    #   kokoro-fastapi (../shared/pyproject.toml)
    #   fastapi
    #   pydantic-settings
pydantic-core==2.27.2
    # via pydantic
pydantic-settings==2.7.0
    # via kokoro-fastapi (../shared/pyproject.toml)
pylatexenc==2.10
    # via clldutils
pyparsing==3.2.1
    # via rdflib
pyreadline3==3.5.4
    # via humanfriendly
python-dateutil==2.9.0.post0
    # via
    #   clldutils
    #   csvw
python-dotenv==1.0.1
    # via
    #   kokoro-fastapi (../shared/pyproject.toml)
    #   pydantic-settings
pyyaml==6.0.2
    # via
    #   huggingface-hub
    #   transformers
rdflib==7.1.2
    # via csvw
referencing==0.35.1
    # via
    #   jsonschema
    #   jsonschema-specifications
regex==2024.11.6
    # via
    #   kokoro-fastapi (../shared/pyproject.toml)
    #   segments
    #   tiktoken
    #   transformers
requests==2.32.3
    # via
    #   kokoro-fastapi (../shared/pyproject.toml)
    #   csvw
    #   huggingface-hub
    #   tiktoken
    #   transformers
rfc3986==1.5.0
    # via csvw
rpds-py==0.22.3
    # via
    #   jsonschema
    #   referencing
safetensors==0.5.2
    # via transformers
scipy==1.14.1
    # via kokoro-fastapi (../shared/pyproject.toml)
segments==2.2.1
    # via phonemizer
six==1.17.0
    # via python-dateutil
sniffio==1.3.1
    # via anyio
soundfile==0.13.0
    # via kokoro-fastapi (../shared/pyproject.toml)
sqlalchemy==2.0.27
    # via kokoro-fastapi (../shared/pyproject.toml)
starlette==0.41.3
    # via fastapi
sympy==1.13.1
    # via
    #   onnxruntime
    #   torch
tabulate==0.9.0
    # via clldutils
tiktoken==0.8.0
    # via kokoro-fastapi (../shared/pyproject.toml)
tokenizers==0.21.0
    # via transformers
torch==2.5.1+cu121
    # via kokoro-fastapi-gpu (pyproject.toml)
tqdm==4.67.1
    # via
    #   kokoro-fastapi (../shared/pyproject.toml)
    #   huggingface-hub
    #   transformers
transformers==4.47.1
    # via kokoro-fastapi-gpu (pyproject.toml)
typing-extensions==4.12.2
    # via
    #   anyio
    #   fastapi
    #   huggingface-hub
    #   phonemizer
    #   pydantic
    #   pydantic-core
    #   sqlalchemy
    #   torch
    #   uvicorn
uritemplate==4.1.1
    # via csvw
urllib3==2.3.0
    # via requests
uvicorn==0.34.0
    # via kokoro-fastapi (../shared/pyproject.toml)
win32-setctime==1.2.0
    # via loguru
docker/gpu/uv.lock  (generated, new file, 1914 lines; diff suppressed because it is too large)
docker/shared/pyproject.toml  (new file, 44 lines)
@@ -0,0 +1,44 @@
[project]
name = "kokoro-fastapi"
version = "0.1.0"
description = "FastAPI TTS Service"
readme = "../README.md"
requires-python = ">=3.10"
dependencies = [
    # Core dependencies
    "fastapi==0.115.6",
    "uvicorn==0.34.0",
    "click>=8.0.0",
    "pydantic==2.10.4",
    "pydantic-settings==2.7.0",
    "python-dotenv==1.0.1",
    "sqlalchemy==2.0.27",

    # ML/DL Base
    "numpy>=1.26.0",
    "scipy==1.14.1",
    "onnxruntime==1.20.1",

    # Audio processing
    "soundfile==0.13.0",

    # Text processing
    "phonemizer==3.3.0",
    "regex==2024.11.6",

    # Utilities
    "aiofiles==23.2.1",
    "tqdm==4.67.1",
    "requests==2.32.3",
    "munch==4.0.0",
    "tiktoken==0.8.0",
    "loguru==0.7.3",
]

[project.optional-dependencies]
test = [
    "pytest==8.0.0",
    "httpx==0.26.0",
    "pytest-asyncio==0.23.5",
    "ruff==0.9.1",
]
@@ -9,7 +9,7 @@ sqlalchemy==2.0.27
 
 # ML/DL
 transformers==4.47.1
-numpy==2.2.1
+numpy>=1.26.0  # Version managed by PyTorch dependencies
 scipy==1.14.1
 onnxruntime==1.20.1
 
@@ -21,7 +21,7 @@ phonemizer==3.3.0
 regex==2024.11.6
 
 # Utilities
-aiofiles==24.1.0
+aiofiles==23.2.1  # Last version before Windows path handling changes
 tqdm==4.67.1
 requests==2.32.3
 munch==4.0.0
docs/requirements.txt  (new file, 243 lines)
@@ -0,0 +1,243 @@
# This file was autogenerated by uv via the following command:
#    uv pip compile docs/requirements.in --universal --output-file docs/requirements.txt
aiofiles==23.2.1
    # via -r docs/requirements.in
annotated-types==0.7.0
    # via pydantic
anyio==4.8.0
    # via
    #   httpx
    #   starlette
attrs==24.3.0
    # via
    #   clldutils
    #   csvw
    #   jsonschema
    #   phonemizer
    #   referencing
babel==2.16.0
    # via csvw
certifi==2024.12.14
    # via
    #   httpcore
    #   httpx
    #   requests
cffi==1.17.1
    # via soundfile
charset-normalizer==3.4.1
    # via requests
click==8.1.8
    # via uvicorn
clldutils==3.21.0
    # via segments
colorama==0.4.6
    # via
    #   click
    #   colorlog
    #   csvw
    #   loguru
    #   pytest
    #   tqdm
coloredlogs==15.0.1
    # via onnxruntime
colorlog==6.9.0
    # via clldutils
csvw==3.5.1
    # via segments
dlinfo==1.2.1
    # via phonemizer
exceptiongroup==1.2.2 ; python_full_version < '3.11'
    # via
    #   anyio
    #   pytest
fastapi==0.115.6
    # via -r docs/requirements.in
filelock==3.16.1
    # via
    #   huggingface-hub
    #   transformers
flatbuffers==24.12.23
    # via onnxruntime
fsspec==2024.12.0
    # via huggingface-hub
greenlet==3.1.1 ; platform_machine == 'AMD64' or platform_machine == 'WIN32' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'ppc64le' or platform_machine == 'win32' or platform_machine == 'x86_64'
    # via sqlalchemy
h11==0.14.0
    # via
    #   httpcore
    #   uvicorn
httpcore==1.0.7
    # via httpx
httpx==0.26.0
    # via -r docs/requirements.in
huggingface-hub==0.27.1
    # via
    #   tokenizers
    #   transformers
humanfriendly==10.0
    # via coloredlogs
idna==3.10
    # via
    #   anyio
    #   httpx
    #   requests
iniconfig==2.0.0
    # via pytest
isodate==0.7.2
    # via
    #   csvw
    #   rdflib
joblib==1.4.2
    # via phonemizer
jsonschema==4.23.0
    # via csvw
jsonschema-specifications==2024.10.1
    # via jsonschema
language-tags==1.2.0
    # via csvw
loguru==0.7.3
    # via -r docs/requirements.in
lxml==5.3.0
    # via clldutils
markdown==3.7
    # via clldutils
markupsafe==3.0.2
    # via clldutils
mpmath==1.3.0
    # via sympy
munch==4.0.0
    # via -r docs/requirements.in
numpy==2.2.1
    # via
    #   -r docs/requirements.in
    #   onnxruntime
    #   scipy
    #   soundfile
    #   transformers
onnxruntime==1.20.1
    # via -r docs/requirements.in
packaging==24.2
    # via
    #   huggingface-hub
    #   onnxruntime
    #   pytest
    #   transformers
phonemizer==3.3.0
    # via -r docs/requirements.in
pluggy==1.5.0
    # via pytest
protobuf==5.29.3
    # via onnxruntime
pycparser==2.22
    # via cffi
pydantic==2.10.4
    # via
    #   -r docs/requirements.in
    #   fastapi
    #   pydantic-settings
pydantic-core==2.27.2
    # via pydantic
pydantic-settings==2.7.0
    # via -r docs/requirements.in
pylatexenc==2.10
    # via clldutils
pyparsing==3.2.1
    # via rdflib
pyreadline3==3.5.4 ; sys_platform == 'win32'
    # via humanfriendly
pytest==8.0.0
    # via
    #   -r docs/requirements.in
    #   pytest-asyncio
pytest-asyncio==0.23.5
    # via -r docs/requirements.in
python-dateutil==2.9.0.post0
    # via
    #   clldutils
    #   csvw
python-dotenv==1.0.1
    # via
    #   -r docs/requirements.in
    #   pydantic-settings
pyyaml==6.0.2
    # via
    #   huggingface-hub
    #   transformers
rdflib==7.1.2
    # via csvw
referencing==0.35.1
    # via
    #   jsonschema
    #   jsonschema-specifications
regex==2024.11.6
    # via
    #   -r docs/requirements.in
    #   segments
    #   tiktoken
    #   transformers
requests==2.32.3
    # via
    #   -r docs/requirements.in
    #   csvw
    #   huggingface-hub
    #   tiktoken
    #   transformers
rfc3986==1.5.0
    # via csvw
rpds-py==0.22.3
    # via
    #   jsonschema
    #   referencing
safetensors==0.5.2
    # via transformers
scipy==1.14.1
    # via -r docs/requirements.in
segments==2.2.1
    # via phonemizer
six==1.17.0
    # via python-dateutil
sniffio==1.3.1
    # via
    #   anyio
    #   httpx
soundfile==0.13.0
    # via -r docs/requirements.in
sqlalchemy==2.0.27
    # via -r docs/requirements.in
starlette==0.41.3
    # via fastapi
sympy==1.13.3
    # via onnxruntime
tabulate==0.9.0
    # via clldutils
tiktoken==0.8.0
    # via -r docs/requirements.in
tokenizers==0.21.0
    # via transformers
tomli==2.2.1 ; python_full_version < '3.11'
    # via pytest
tqdm==4.67.1
    # via
    #   -r docs/requirements.in
    #   huggingface-hub
    #   transformers
transformers==4.47.1
    # via -r docs/requirements.in
typing-extensions==4.12.2
    # via
    #   anyio
    #   fastapi
    #   huggingface-hub
    #   phonemizer
    #   pydantic
    #   pydantic-core
    #   sqlalchemy
    #   uvicorn
uritemplate==4.1.1
    # via csvw
urllib3==2.3.0
    # via requests
uvicorn==0.34.0
    # via -r docs/requirements.in
win32-setctime==1.2.0 ; sys_platform == 'win32'
    # via loguru
examples/requirements.txt  (new file, 2 lines)
@@ -0,0 +1,2 @@
openai>=1.0.0
pyaudio>=0.2.13
Binary file not shown.
@@ -8,7 +8,7 @@ import requests
 import sounddevice as sd
 
 
-def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"):
+def play_streaming_tts(text: str, output_file: str = None, voice: str = "af_sky"):
     """Stream TTS audio and play it back in real-time"""
 
     print("\nStarting TTS stream request...")
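Since "af_sky" is now the default, callers of this example helper no longer need to pass a voice at all. A minimal usage sketch within that example script (the text and output filename here are hypothetical):

# Hypothetical call; relies on the play_streaming_tts defined above.
play_streaming_tts("Hello from Kokoro-FastAPI!", output_file="stream_demo.wav")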
pyproject.toml  (new file, 90 lines)
@@ -0,0 +1,90 @@
[project]
name = "kokoro-fastapi"
version = "0.1.0"
description = "FastAPI TTS Service"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
    # Core dependencies
    "fastapi==0.115.6",
    "uvicorn==0.34.0",
    "click>=8.0.0",
    "pydantic==2.10.4",
    "pydantic-settings==2.7.0",
    "python-dotenv==1.0.1",
    "sqlalchemy==2.0.27",
    # ML/DL Base
    "numpy>=1.26.0",
    "scipy==1.14.1",
    "onnxruntime==1.20.1",
    # Audio processing
    "soundfile==0.13.0",
    # Text processing
    "phonemizer==3.3.0",
    "regex==2024.11.6",
    # Utilities
    "aiofiles==23.2.1",
    "tqdm==4.67.1",
    "requests==2.32.3",
    "munch==4.0.0",
    "tiktoken==0.8.0",
    "loguru==0.7.3",
    "transformers==4.47.1",
    "openai>=1.59.6",
    "ebooklib>=0.18",
    "html2text>=2024.2.26",
]

[project.optional-dependencies]
gpu = [
    "torch==2.5.1+cu121",
]
cpu = [
    "torch==2.5.1+cpu",
]
test = [
    "pytest==8.0.0",
    "pytest-cov==4.1.0",
    "httpx==0.26.0",
    "pytest-asyncio==0.23.5",
    "gradio>=5",
    "openai>=1.59.6",
]

[tool.uv]
conflicts = [
    [
        { extra = "cpu" },
        { extra = "gpu" },
    ],
]

[tool.uv.sources]
torch = [
    { index = "pytorch-cpu", extra = "cpu" },
    { index = "pytorch-cuda", extra = "gpu" },
]

[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true

[[tool.uv.index]]
name = "pytorch-cuda"
url = "https://download.pytorch.org/whl/cu121"
explicit = true

[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[tool.setuptools]
package-dir = {"" = "api/src"}
packages.find = {where = ["api/src"], namespaces = true}

[tool.pytest.ini_options]
testpaths = ["api/tests", "ui/tests"]
python_files = ["test_*.py"]
addopts = "--cov=api --cov=ui --cov-report=term-missing --cov-config=.coveragerc"
asyncio_mode = "strict"
@@ -1,14 +0,0 @@
-# Core dependencies for testing
-fastapi==0.115.6
-uvicorn==0.34.0
-pydantic==2.10.4
-pydantic-settings==2.7.0
-python-dotenv==1.0.1
-sqlalchemy==2.0.27
-
-# Testing
-pytest==8.0.0
-httpx==0.26.0
-pytest-asyncio==0.23.5
-pytest-cov==6.0.0
-gradio==4.19.2
@@ -1,9 +1,3 @@
-import warnings
-
-# Filter out Gradio Dropdown warnings about values not in choices
-#TODO: Warning continues to be displayed, though it isn't breaking anything
-warnings.filterwarnings('ignore', category=UserWarning, module='gradio.components.dropdown')
-
 from lib.interface import create_interface
 
 if __name__ == "__main__":
@@ -36,15 +36,18 @@ def check_api_status() -> Tuple[bool, List[str]]:
 
 
 def text_to_speech(
-    text: str, voice_id: str, format: str, speed: float
+    text: str, voice_id: str | list, format: str, speed: float
 ) -> Optional[str]:
     """Generate speech from text using TTS API."""
     if not text.strip():
         return None
 
+    # Handle multiple voices
+    voice_str = voice_id if isinstance(voice_id, str) else "+".join(voice_id)
+
     # Create output filename
     timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
-    output_filename = f"output_{timestamp}_voice-{voice_id}_speed-{speed}.{format}"
+    output_filename = f"output_{timestamp}_voice-{voice_str}_speed-{speed}.{format}"
     output_path = os.path.join(OUTPUTS_DIR, output_filename)
 
     try:
@@ -53,7 +56,7 @@ def text_to_speech(
             json={
                 "model": "kokoro",
                 "input": text,
-                "voice": voice_id,
+                "voice": voice_str,
                 "response_format": format,
                 "speed": float(speed),
             },
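The same "+"-joined voice string works when calling the API directly, not just through the UI. The sketch below mirrors the payload, headers, and timeout the UI code sends; the host and port match the Dockerfile and compose files in this diff, but the /v1/audio/speech endpoint path is an assumption based on the OpenAI-compatible request shape:

import requests

# Join voices the same way the UI does: ["voice1", "voice2"] -> "voice1+voice2".
voices = ["af_sky", "af"]  # assumed voice IDs, taken from examples in this diff
payload = {
    "model": "kokoro",
    "input": "Hello from a blended voice!",
    "voice": "+".join(voices),
    "response_format": "mp3",
    "speed": 1.0,
}
resp = requests.post(
    "http://localhost:8880/v1/audio/speech",  # endpoint path is an assumption
    json=payload,
    headers={"Content-Type": "application/json"},
    timeout=300,
)
resp.raise_for_status()
with open("blended.mp3", "wb") as f:
    f.write(resp.content)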
@@ -5,18 +5,26 @@ import gradio as gr
 from .. import files
 
 
-def create_input_column() -> Tuple[gr.Column, dict]:
+def create_input_column(disable_local_saving: bool = False) -> Tuple[gr.Column, dict]:
     """Create the input column with text input and file handling."""
     with gr.Column(scale=1) as col:
+        text_input = gr.Textbox(
+            label="Text to speak", placeholder="Enter text here...", lines=4
+        )
+
+        # Always show file upload but handle differently based on disable_local_saving
+        file_upload = gr.File(
+            label="Upload Text File (.txt)", file_types=[".txt"]
+        )
+
+        if not disable_local_saving:
+            # Show full interface with tabs when saving is enabled
         with gr.Tabs() as tabs:
             # Set first tab as selected by default
             tabs.selected = 0
             # Direct Input Tab
             with gr.TabItem("Direct Input"):
-                text_input = gr.Textbox(
-                    label="Text to speak", placeholder="Enter text here...", lines=4
-                )
-                text_submit = gr.Button("Generate Speech", variant="primary", size="lg")
+                text_submit_direct = gr.Button("Generate Speech", variant="primary", size="lg")
 
             # File Input Tab
             with gr.TabItem("From File"):
@@ -27,11 +35,6 @@ def create_input_column() -> Tuple[gr.Column, dict]:
                     value=None,
                 )
-
-                # Simple file upload
-                file_upload = gr.File(
-                    label="Upload Text File (.txt)", file_types=[".txt"]
-                )
 
                 file_preview = gr.Textbox(
                     label="File Content Preview", interactive=False, lines=4
                 )
@@ -43,14 +46,35 @@ def create_input_column() -> Tuple[gr.Column, dict]:
                     clear_files = gr.Button(
                         "Clear Files", variant="secondary", size="lg"
                     )
+        else:
+            # Just show the generate button when saving is disabled
+            text_submit_direct = gr.Button("Generate Speech", variant="primary", size="lg")
+            tabs = None
+            input_files_list = None
+            file_preview = None
+            file_submit = None
+            clear_files = None
+
+        # Initialize components based on disable_local_saving
+        if disable_local_saving:
+            components = {
+                "tabs": None,
+                "text_input": text_input,
+                "text_submit": text_submit_direct,
+                "file_select": None,
+                "file_upload": file_upload,  # Keep file upload even when saving is disabled
+                "file_preview": None,
+                "file_submit": None,
+                "clear_files": None,
+            }
+        else:
             components = {
                 "tabs": tabs,
                 "text_input": text_input,
+                "text_submit": text_submit_direct,
                 "file_select": input_files_list,
                 "file_upload": file_upload,
                 "file_preview": file_preview,
-                "text_submit": text_submit,
                 "file_submit": file_submit,
                 "clear_files": clear_files,
             }
@@ -20,10 +20,10 @@ def create_model_column(voice_ids: Optional[list] = None) -> Tuple[gr.Column, dict]:
 
         voice_input = gr.Dropdown(
             choices=voice_ids,
-            label="Voice",
-            value=voice_ids[0] if voice_ids else None,  # Set default value to first item if available
+            label="Voice(s)",
+            value=voice_ids[0] if voice_ids else None,
             interactive=True,
-            allow_custom_value=True,  # Allow temporary values during updates
+            multiselect=True,
         )
         format_input = gr.Dropdown(
             choices=config.AUDIO_FORMATS, label="Audio Format", value="mp3"
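A Gradio dropdown with multiselect=True returns a list of selected values rather than a single string, which is why text_to_speech above now accepts str | list and joins entries with "+". A minimal sketch of the component's behavior (voice IDs are assumed for illustration):

import gradio as gr

# With multiselect=True the component's value is a list such as
# ["af_sky", "af"]; downstream code joins it into "af_sky+af".
voice_input = gr.Dropdown(
    choices=["af_sky", "af"],  # assumed voice IDs
    label="Voice(s)",
    value=["af_sky"],
    interactive=True,
    multiselect=True,
)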
@@ -5,34 +5,43 @@ import gradio as gr
 from .. import files
 
 
-def create_output_column() -> Tuple[gr.Column, dict]:
+def create_output_column(disable_local_saving: bool = False) -> Tuple[gr.Column, dict]:
     """Create the output column with audio player and file list."""
     with gr.Column(scale=1) as col:
         gr.Markdown("### Latest Output")
-        audio_output = gr.Audio(label="Generated Speech", type="filepath")
+        audio_output = gr.Audio(
+            label="Generated Speech",
+            type="filepath",
+            waveform_options={"waveform_color": "#4C87AB"}
+        )
 
-        gr.Markdown("### Generated Files")
-        # Initialize dropdown with empty choices first
+        # Create file-related components with visible=False when local saving is disabled
+        gr.Markdown("### Generated Files", visible=not disable_local_saving)
         output_files = gr.Dropdown(
             label="Previous Outputs",
-            choices=[],
+            choices=files.list_output_files() if not disable_local_saving else [],
             value=None,
             allow_custom_value=True,
-            interactive=True,
+            visible=not disable_local_saving,
         )
-        # Then update choices after component creation
-        output_files.choices = files.list_output_files()
 
-        play_btn = gr.Button("▶️ Play Selected", size="sm")
+        play_btn = gr.Button(
+            "▶️ Play Selected",
+            size="sm",
+            visible=not disable_local_saving,
+        )
 
         selected_audio = gr.Audio(
-            label="Selected Output", type="filepath", visible=False
+            label="Selected Output",
+            type="filepath",
+            visible=False,  # Always initially hidden
         )
 
         clear_outputs = gr.Button(
             "⚠️ Delete All Previously Generated Output Audio 🗑️",
             size="sm",
             variant="secondary",
+            visible=not disable_local_saving,
         )
 
         components = {
@@ -11,12 +11,14 @@ def list_input_files() -> List[str]:
 
 
 def list_output_files() -> List[str]:
-    """List all output audio files."""
-    # Just return filenames since paths will be different inside/outside container
-    return [
-        f for f in os.listdir(OUTPUTS_DIR)
+    """List all output audio files, sorted by most recent first."""
+    files = [
+        os.path.join(OUTPUTS_DIR, f)
+        for f in os.listdir(OUTPUTS_DIR)
         if any(f.endswith(ext) for ext in AUDIO_FORMATS)
     ]
+    # Sort files by modification time, most recent first
+    return sorted(files, key=os.path.getmtime, reverse=True)
 
 
 def read_text_file(filename: str) -> str:
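Returning full paths rather than bare filenames is what makes the os.path.getmtime sort work: a bare filename would only resolve if the process happened to run from OUTPUTS_DIR. A self-contained sketch of the same pattern (directory and extensions are hypothetical):

import os

def newest_first(directory: str, extensions: tuple) -> list:
    """Return full paths under directory, newest first by modification time."""
    paths = [
        os.path.join(directory, f)
        for f in os.listdir(directory)
        if f.endswith(extensions)
    ]
    # getmtime needs a resolvable path, hence the os.path.join above
    return sorted(paths, key=os.path.getmtime, reverse=True)

# Hypothetical usage:
# newest_first("ui/data/outputs", (".mp3", ".wav"))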
@@ -1,11 +1,12 @@
 import os
 import shutil
 
 import gradio as gr
 
 from . import api, files
 
 
-def setup_event_handlers(components: dict):
+def setup_event_handlers(components: dict, disable_local_saving: bool = False):
     """Set up all event handlers for the UI components."""
 
     def refresh_status():
@@ -57,10 +58,20 @@
 
     def handle_file_upload(file):
         if file is None:
-            return gr.update(choices=files.list_input_files())
+            return "" if disable_local_saving else [gr.update(choices=files.list_input_files())]
 
         try:
-            # Copy file to inputs directory
+            # Read the file content
+            with open(file.name, 'r', encoding='utf-8') as f:
+                text_content = f.read()
+
+            if disable_local_saving:
+                # When saving is disabled, put content directly in text input
+                # Normalize whitespace by replacing newlines with spaces
+                normalized_text = ' '.join(text_content.split())
+                return normalized_text
+            else:
+                # When saving is enabled, save file and update dropdown
             filename = os.path.basename(file.name)
             target_path = os.path.join(files.INPUTS_DIR, filename)
@@ -73,11 +84,11 @@
                 counter += 1
 
             shutil.copy2(file.name, target_path)
+            return [gr.update(choices=files.list_input_files())]
 
         except Exception as e:
-            print(f"Error uploading file: {e}")
-
-        return gr.update(choices=files.list_input_files())
+            print(f"Error handling file: {e}")
+            return "" if disable_local_saving else [gr.update(choices=files.list_input_files())]
 
     def generate_from_text(text, voice, format, speed):
         """Generate speech from direct text input"""
@@ -90,18 +101,20 @@
             gr.Warning("Please enter text in the input box")
             return [None, gr.update(choices=files.list_output_files())]
 
+        # Only save text if local saving is enabled
+        if not disable_local_saving:
             files.save_text(text)
+
         result = api.text_to_speech(text, voice, format, speed)
         if result is None:
             gr.Warning("Failed to generate speech. Please try again.")
             return [None, gr.update(choices=files.list_output_files())]
 
-        # Update list and select the newly generated file
-        output_files = files.list_output_files()
-        last_file = output_files[-1] if output_files else None
         return [
             result,
-            gr.update(choices=output_files, value=last_file),
+            gr.update(
+                choices=files.list_output_files(), value=os.path.basename(result)
+            ),
         ]
 
     def generate_from_file(selected_file, voice, format, speed):
@@ -121,18 +134,15 @@
             gr.Warning("Failed to generate speech. Please try again.")
             return [None, gr.update(choices=files.list_output_files())]
 
-        # Update list and select the newly generated file
-        output_files = files.list_output_files()
-        last_file = output_files[-1] if output_files else None
         return [
             result,
-            gr.update(choices=output_files, value=last_file),
+            gr.update(
+                choices=files.list_output_files(), value=os.path.basename(result)
+            ),
         ]
 
-    def play_selected(filename):
-        if filename:
-            file_path = os.path.join(files.OUTPUTS_DIR, filename)
-            if os.path.exists(file_path):
-                return gr.update(value=file_path, visible=True)
+    def play_selected(file_path):
+        if file_path and os.path.exists(file_path):
+            return gr.update(value=file_path, visible=True)
         return gr.update(visible=False)
 
@@ -165,25 +175,45 @@
         outputs=[components["model"]["status_btn"], components["model"]["voice"]],
     )
 
+    # Connect text submit button (always present)
+    components["input"]["text_submit"].click(
+        fn=generate_from_text,
+        inputs=[
+            components["input"]["text_input"],
+            components["model"]["voice"],
+            components["model"]["format"],
+            components["model"]["speed"],
+        ],
+        outputs=[
+            components["output"]["audio_output"],
+            components["output"]["output_files"],
+        ],
+    )
+
+    # Only connect file-related handlers if components exist
+    if components["input"]["file_select"] is not None:
         components["input"]["file_select"].change(
             fn=handle_file_select,
            inputs=[components["input"]["file_select"]],
             outputs=[components["input"]["file_preview"]],
         )
 
+    if components["input"]["file_upload"] is not None:
+        # File upload handler - output depends on disable_local_saving
         components["input"]["file_upload"].upload(
             fn=handle_file_upload,
             inputs=[components["input"]["file_upload"]],
-            outputs=[components["input"]["file_select"]],
+            outputs=[components["input"]["text_input"] if disable_local_saving else components["input"]["file_select"]],
         )
 
+    if components["output"]["play_btn"] is not None:
         components["output"]["play_btn"].click(
             fn=play_selected,
             inputs=[components["output"]["output_files"]],
             outputs=[components["output"]["selected_audio"]],
         )
 
-    # Connect clear files button
+    if components["input"]["clear_files"] is not None:
         components["input"]["clear_files"].click(
             fn=clear_files,
             inputs=[
@@ -203,22 +233,7 @@ def setup_event_handlers(components: dict):
         ],
     )
 
-    # Connect submit buttons for each tab
-    components["input"]["text_submit"].click(
-        fn=generate_from_text,
-        inputs=[
-            components["input"]["text_input"],
-            components["model"]["voice"],
-            components["model"]["format"],
-            components["model"]["speed"],
-        ],
-        outputs=[
-            components["output"]["audio_output"],
-            components["output"]["output_files"],
-        ],
-    )
-
-    # Connect clear outputs button
+    if components["output"]["clear_outputs"] is not None:
         components["output"]["clear_outputs"].click(
             fn=clear_outputs,
             outputs=[
@@ -228,6 +243,7 @@ def setup_event_handlers(components: dict):
         ],
     )
 
+    if components["input"]["file_submit"] is not None:
         components["input"]["file_submit"].click(
             fn=generate_from_file,
             inputs=[
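When DISABLE_LOCAL_SAVING is set, the upload handler skips the copy to disk and instead normalizes the file's text straight into the textbox. The normalization is plain Python string handling; a standalone check (sample text hypothetical):

# ' '.join(text.split()) collapses newlines, tabs, and runs of spaces
# into single spaces, keeping an uploaded file as one paragraph.
text_content = "Line one.\n\nLine two.\t  Indented."
normalized_text = ' '.join(text_content.split())
assert normalized_text == "Line one. Line two. Indented."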
@@ -1,4 +1,5 @@
 import gradio as gr
+import os
 
 from . import api
 from .handlers import setup_event_handlers
@@ -10,6 +11,9 @@
     # Skip initial status check - let the timer handle it
     is_available, available_voices = False, []
 
+    # Check if local saving is disabled
+    disable_local_saving = os.getenv("DISABLE_LOCAL_SAVING", "false").lower() == "true"
+
     with gr.Blocks(title="Kokoro TTS Demo", theme=gr.themes.Monochrome()) as demo:
         gr.HTML(
             value='<div style="display: flex; gap: 0;">'
@@ -22,11 +26,11 @@
         # Main interface
         with gr.Row():
             # Create columns
-            input_col, input_components = create_input_column()
+            input_col, input_components = create_input_column(disable_local_saving)
             model_col, model_components = create_model_column(
                 available_voices
             )  # Pass initial voices
-            output_col, output_components = create_output_column()
+            output_col, output_components = create_output_column(disable_local_saving)
 
             # Collect all components
             components = {
@@ -36,7 +40,7 @@
         }
 
         # Set up event handlers
-        setup_event_handlers(components)
+        setup_event_handlers(components, disable_local_saving)
 
         # Add periodic status check with Timer
         def update_status():
@@ -106,11 +106,21 @@ def test_get_status_html_unavailable():
 
 
 def test_text_to_speech_api_params(mock_response, tmp_path):
     """Test correct API parameters are sent"""
+    test_cases = [
+        # Single voice as string
+        ("voice1", "voice1"),
+        # Multiple voices as list
+        (["voice1", "voice2"], "voice1+voice2"),
+        # Single voice as list
+        (["voice1"], "voice1"),
+    ]
+
+    for input_voice, expected_voice in test_cases:
         with patch("requests.post") as mock_post, patch(
             "ui.lib.api.OUTPUTS_DIR", str(tmp_path)
         ), patch("builtins.open", mock_open()):
             mock_post.return_value = mock_response({})
-            api.text_to_speech("test text", "voice1", "mp3", 1.5)
+            api.text_to_speech("test text", input_voice, "mp3", 1.5)
 
             mock_post.assert_called_once()
             args, kwargs = mock_post.call_args
@@ -119,7 +129,7 @@ def test_text_to_speech_api_params(mock_response, tmp_path):
             assert kwargs["json"] == {
                 "model": "kokoro",
                 "input": "test text",
-                "voice": "voice1",
+                "voice": expected_voice,
                 "response_format": "mp3",
                 "speed": 1.5,
             }
@@ -127,3 +137,23 @@ def test_text_to_speech_api_params(mock_response, tmp_path):
             # Check headers and timeout
             assert kwargs["headers"] == {"Content-Type": "application/json"}
             assert kwargs["timeout"] == 300
+
+
+def test_text_to_speech_output_filename(mock_response, tmp_path):
+    """Test output filename contains correct voice identifier"""
+    test_cases = [
+        # Single voice
+        ("voice1", lambda f: "voice-voice1" in f),
+        # Multiple voices
+        (["voice1", "voice2"], lambda f: "voice-voice1+voice2" in f),
+    ]
+
+    for input_voice, filename_check in test_cases:
+        with patch("requests.post", return_value=mock_response({})), patch(
+            "ui.lib.api.OUTPUTS_DIR", str(tmp_path)
+        ), patch("builtins.open", mock_open()) as mock_file:
+            result = api.text_to_speech("test text", input_voice, "mp3", 1.0)
+
+            assert result is not None
+            assert filename_check(result), f"Expected voice pattern not found in filename: {result}"
+            mock_file.assert_called_once()
@@ -36,8 +36,10 @@ def test_model_column_default_values():
     expected_choices = [(voice_id, voice_id) for voice_id in voice_ids]
     assert components["voice"].choices == expected_choices
     # Value is not converted to tuple format for the value property
-    assert components["voice"].value == voice_ids[0]
+    assert components["voice"].value == [voice_ids[0]]
     assert components["voice"].interactive is True
+    assert components["voice"].multiselect is True
+    assert components["voice"].label == "Voice(s)"
 
     # Test format dropdown
     # Gradio Dropdown converts choices to (value, label) tuples
@@ -136,7 +136,7 @@ def test_interface_components_presence():
 
     required_components = {
         "Text to speak",
-        "Voice",
+        "Voice(s)",
        "Audio Format",
         "Speed",
         "Generated Speech",