Merge pre-release: Update CI workflow for uv

remsky 2025-01-14 04:15:50 -07:00
commit 5045cf968e
69 changed files with 9202 additions and 762 deletions

View file

@ -7,6 +7,7 @@ omit =
MagicMock/*
test_*.py
examples/*
src/builds/*
[report]
exclude_lines =

View file

@ -1,51 +1,32 @@
# name: CI
name: CI
# on:
# push:
# branches: [ "develop", "master" ]
# pull_request:
# branches: [ "develop", "master" ]
on:
push:
branches: [ "master", "pre-release" ]
pull_request:
branches: [ "master", "pre-release" ]
# jobs:
# test:
# runs-on: ubuntu-latest
# strategy:
# matrix:
# python-version: ["3.9", "3.10", "3.11"]
# fail-fast: false
jobs:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10"]
fail-fast: false
# steps:
# - uses: actions/checkout@v4
steps:
- uses: actions/checkout@v4
# - name: Set up Python ${{ matrix.python-version }}
# uses: actions/setup-python@v5
# with:
# python-version: ${{ matrix.python-version }}
# - name: Set up pip cache
# uses: actions/cache@v3
# with:
# path: ~/.cache/pip
# key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements*.txt') }}
# restore-keys: |
# ${{ runner.os }}-pip-
- name: Install uv
uses: astral-sh/setup-uv@v5
with:
python-version: ${{ matrix.python-version }}
enable-cache: true
# - name: Install PyTorch CPU
# run: |
# python -m pip install --upgrade pip
# pip install torch --index-url https://download.pytorch.org/whl/cpu
- name: Install dependencies
run: |
uv pip install -e .[test,cpu]
# - name: Install dependencies
# run: |
# pip install ruff pytest-cov
# pip install -r requirements.txt
# pip install -r requirements-test.txt
# - name: Lint with ruff
# run: |
# ruff check .
# - name: Test with pytest
# run: |
# pytest --asyncio-mode=auto --cov=api --cov-report=term-missing
- name: Run Tests
run: |
uv run pytest api/tests/ --asyncio-mode=auto --cov=api --cov-report=term-missing
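For reference, the CI steps above map to roughly the following local commands. This is a sketch, not part of the workflow; it assumes uv is installed locally and that the `test` and `cpu` extras exist as referenced in the install step.

```bash
# Approximate local equivalent of the CI job above
uv venv --python 3.10                  # match the single Python version in the matrix
uv pip install -e .[test,cpu]          # same editable install with test + CPU extras
uv run pytest api/tests/ --asyncio-mode=auto --cov=api --cov-report=term-missing
```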

View file

@ -1,7 +1,9 @@
name: Docker Build and Publish
name: Docker Build, Slim, and Publish
on:
push:
branches:
- master
tags: [ 'v*.*.*' ]
# Allow manual trigger from GitHub UI
workflow_dispatch:
@ -16,6 +18,7 @@ jobs:
permissions:
contents: read
packages: write
actions: write
steps:
- name: Checkout repository
@ -28,67 +31,76 @@ jobs:
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
# Extract metadata for GPU image
- name: Extract metadata (tags, labels) for GPU Docker
id: meta-gpu
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
tags: |
type=semver,pattern=v{{version}}
type=semver,pattern=v{{major}}.{{minor}}
type=semver,pattern=v{{major}}
type=raw,value=latest
# Set up image names (converting to lowercase)
- name: Set image names
run: |
echo "GPU_IMAGE_NAME=${{ env.REGISTRY }}/$(echo ${{ env.IMAGE_NAME }} | tr '[:upper:]' '[:lower:]')-gpu" >> $GITHUB_ENV
echo "CPU_IMAGE_NAME=${{ env.REGISTRY }}/$(echo ${{ env.IMAGE_NAME }} | tr '[:upper:]' '[:lower:]')-cpu" >> $GITHUB_ENV
echo "UI_IMAGE_NAME=${{ env.REGISTRY }}/$(echo ${{ env.IMAGE_NAME }} | tr '[:upper:]' '[:lower:]')-ui" >> $GITHUB_ENV
# Extract metadata for CPU image
- name: Extract metadata (tags, labels) for CPU Docker
id: meta-cpu
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
flavor: |
suffix=-cpu
tags: |
type=semver,pattern=v{{version}}
type=semver,pattern=v{{major}}.{{minor}}
type=semver,pattern=v{{major}}
type=raw,value=latest
# Build and push GPU version
- name: Build and push GPU Docker image
# Build GPU version
- name: Build GPU Docker image
uses: docker/build-push-action@v5
with:
context: .
file: ./Dockerfile
push: true
tags: ${{ steps.meta-gpu.outputs.tags }}
labels: ${{ steps.meta-gpu.outputs.labels }}
file: ./docker/gpu/Dockerfile
push: false
load: true
tags: ${{ env.GPU_IMAGE_NAME }}:v0.1.0
build-args: |
DOCKER_BUILDKIT=1
platforms: linux/amd64
# Build and push CPU version
- name: Build and push CPU Docker image
# Slim GPU version
- name: Slim GPU Docker image
uses: kitabisa/docker-slim-action@v1
env:
DSLIM_HTTP_PROBE: false
with:
target: ${{ env.GPU_IMAGE_NAME }}:v0.1.0
tag: v0.1.0-slim
# Push GPU versions
- name: Push GPU Docker images
run: |
docker push ${{ env.GPU_IMAGE_NAME }}:v0.1.0
docker push ${{ env.GPU_IMAGE_NAME }}:v0.1.0-slim
docker tag ${{ env.GPU_IMAGE_NAME }}:v0.1.0 ${{ env.GPU_IMAGE_NAME }}:latest
docker tag ${{ env.GPU_IMAGE_NAME }}:v0.1.0-slim ${{ env.GPU_IMAGE_NAME }}:latest-slim
docker push ${{ env.GPU_IMAGE_NAME }}:latest
docker push ${{ env.GPU_IMAGE_NAME }}:latest-slim
# Build CPU version
- name: Build CPU Docker image
uses: docker/build-push-action@v5
with:
context: .
file: ./Dockerfile.cpu
push: true
tags: ${{ steps.meta-cpu.outputs.tags }}
labels: ${{ steps.meta-cpu.outputs.labels }}
file: ./docker/cpu/Dockerfile
push: false
load: true
tags: ${{ env.CPU_IMAGE_NAME }}:v0.1.0
build-args: |
DOCKER_BUILDKIT=1
platforms: linux/amd64
# Extract metadata for UI image
- name: Extract metadata (tags, labels) for UI Docker
id: meta-ui
uses: docker/metadata-action@v5
# Slim CPU version
- name: Slim CPU Docker image
uses: kitabisa/docker-slim-action@v1
env:
DSLIM_HTTP_PROBE: false
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
flavor: |
suffix=-ui
tags: |
type=semver,pattern=v{{version}}
type=semver,pattern=v{{major}}.{{minor}}
type=semver,pattern=v{{major}}
type=raw,value=latest
target: ${{ env.CPU_IMAGE_NAME }}:v0.1.0
tag: v0.1.0-slim
# Push CPU versions
- name: Push CPU Docker images
run: |
docker push ${{ env.CPU_IMAGE_NAME }}:v0.1.0
docker push ${{ env.CPU_IMAGE_NAME }}:v0.1.0-slim
docker tag ${{ env.CPU_IMAGE_NAME }}:v0.1.0 ${{ env.CPU_IMAGE_NAME }}:latest
docker tag ${{ env.CPU_IMAGE_NAME }}:v0.1.0-slim ${{ env.CPU_IMAGE_NAME }}:latest-slim
docker push ${{ env.CPU_IMAGE_NAME }}:latest
docker push ${{ env.CPU_IMAGE_NAME }}:latest-slim
# Build and push UI version
- name: Build and push UI Docker image
@ -97,8 +109,11 @@ jobs:
context: ./ui
file: ./ui/Dockerfile
push: true
tags: ${{ steps.meta-ui.outputs.tags }}
labels: ${{ steps.meta-ui.outputs.labels }}
tags: |
${{ env.UI_IMAGE_NAME }}:v0.1.0
${{ env.UI_IMAGE_NAME }}:latest
build-args: |
DOCKER_BUILDKIT=1
platforms: linux/amd64
create-release:
@ -108,13 +123,16 @@ jobs:
if: startsWith(github.ref, 'refs/tags/')
permissions:
contents: write
packages: write
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Create Release
uses: softprops/action-gh-release@v1
env:
IS_PRERELEASE: ${{ contains(github.ref, '-pre') }}
with:
generate_release_notes: true
draft: false
prerelease: false
prerelease: ${{ contains(github.ref, '-pre') }}

71
.gitignore vendored
View file

@ -2,51 +2,78 @@
.git
# Python
__pycache__
__pycache__/
*.pyc
*.pyo
*.pyd
*.pt
.Python
*.py[cod]
*$py.class
.Python
.pytest_cache
.coverage
.coveragerc
# Python package build artifacts
*.egg-info/
*.egg
dist/
build/
# Environment
# .env
.venv
.venv/
env/
venv/
ENV/
# IDE
.idea
.vscode
.idea/
.vscode/
*.swp
*.swo
# Project specific
*examples/*.wav
*examples/*.pcm
*examples/*.mp3
*examples/*.flac
*examples/*.acc
*examples/*.ogg
# Model files
*.pt
*.pth
*.tar*
# Voice files
api/src/voices/af_bella.pt
api/src/voices/af_nicole.pt
api/src/voices/af_sarah.pt
api/src/voices/af_sky.pt
api/src/voices/af.pt
api/src/voices/am_adam.pt
api/src/voices/am_michael.pt
api/src/voices/bf_emma.pt
api/src/voices/bf_isabella.pt
api/src/voices/bm_george.pt
api/src/voices/bm_lewis.pt
# Audio files
examples/*.wav
examples/*.pcm
examples/*.mp3
examples/*.flac
examples/*.acc
examples/*.ogg
examples/speech.mp3
examples/phoneme_examples/output/example_1.wav
examples/phoneme_examples/output/example_2.wav
examples/phoneme_examples/output/example_3.wav
# Other project files
Kokoro-82M/
ui/data
tests/
*.md
*.txt
requirements.txt
ui/data/
EXTERNAL_UV_DOCUMENTATION*
# Docker
Dockerfile*
docker-compose*
*.egg-info
*.pt
*.wav
*.tar*
examples/assorted_checks/River_of_Teet_-_Sarah_Gailey.epub
examples/ebook_test/chapter_to_audio.py
examples/ebook_test/chapters_to_audio.py
examples/ebook_test/parse_epub.py
examples/ebook_test/River_of_Teet_-_Sarah_Gailey.epub
examples/ebook_test/River_of_Teet_-_Sarah_Gailey.txt

1
.python-version Normal file
View file

@ -0,0 +1 @@
3.10

View file

@ -1,11 +1,12 @@
line-length = 88
exclude = ["examples"]
[lint]
select = ["I"]
[lint.isort]
combine-as-imports = true
force-wrap-aliases = true
length-sort = true
split-on-trailing-comma = true
section-order = ["future", "standard-library", "third-party", "first-party", "local-folder"]

View file

@ -2,6 +2,17 @@
Notable changes to this project will be documented in this file.
## [v0.1.0] - 2025-01-13
### Changed
- Major Docker improvements:
- Baked model directly into Dockerfile for improved deployment reliability
- Switched to uv for dependency management
- Streamlined container builds and reduced image sizes
- Dependency Management:
- Migrated from pip/poetry to uv for faster, more reliable package management
- Added uv.lock for deterministic builds
- Updated dependency resolution strategy
## [v0.0.5post1] - 2025-01-11
### Fixed
- Docker image tagging and versioning improvements (-gpu, -cpu, -ui)

View file

@ -1,44 +0,0 @@
FROM nvidia/cuda:12.1.0-base-ubuntu22.04
# Install base system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
python3-pip \
python3-dev \
espeak-ng \
git \
libsndfile1 \
curl \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
# Install PyTorch with CUDA support first
RUN pip3 install --no-cache-dir torch==2.5.1 --extra-index-url https://download.pytorch.org/whl/cu121
# Install all other dependencies from requirements.txt
COPY requirements.txt .
RUN pip3 install --no-cache-dir -r requirements.txt
# Set working directory
WORKDIR /app
# Create non-root user
RUN useradd -m -u 1000 appuser
# Create model directory and set ownership
RUN mkdir -p /app/Kokoro-82M && \
chown -R appuser:appuser /app
# Switch to non-root user
USER appuser
# Run with Python unbuffered output for live logging
ENV PYTHONUNBUFFERED=1
# Copy only necessary application code
COPY --chown=appuser:appuser api /app/api
# Set Python path (app first for our imports, then model dir for model imports)
ENV PYTHONPATH=/app:/app/Kokoro-82M
# Run FastAPI server with debug logging and reload
CMD ["uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]

View file

@ -1,44 +0,0 @@
FROM ubuntu:22.04
# Install base system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
python3-pip \
python3-dev \
espeak-ng \
git \
libsndfile1 \
curl \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
# Install PyTorch CPU version and ONNX runtime
RUN pip3 install --no-cache-dir torch==2.5.1 --extra-index-url https://download.pytorch.org/whl/cpu
# Install all other dependencies from requirements.txt
COPY requirements.txt .
RUN pip3 install --no-cache-dir -r requirements.txt
# Copy application code and model
COPY . /app/
# Set working directory
WORKDIR /app
# Run with Python unbuffered output for live logging
ENV PYTHONUNBUFFERED=1
# Create non-root user
RUN useradd -m -u 1000 appuser
# Create directories and set permissions
RUN mkdir -p /app/Kokoro-82M && \
chown -R appuser:appuser /app
# Switch to non-root user
USER appuser
# Set Python path (app first for our imports, then model dir for model imports)
ENV PYTHONPATH=/app:/app/Kokoro-82M
# Run FastAPI server with debug logging and reload
CMD ["uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]

70
MigrationWorkingNotes.md Normal file
View file

@ -0,0 +1,70 @@
# UV Setup
Deprecated notes for myself
## Structure
```
docker/
├── cpu/
│ ├── pyproject.toml # CPU deps (torch CPU)
│ └── requirements.lock # CPU lockfile
├── gpu/
│ ├── pyproject.toml # GPU deps (torch CUDA)
│ └── requirements.lock # GPU lockfile
└── shared/
└── pyproject.toml # Common deps
```
## Regenerate Lock Files
### CPU
```bash
cd docker/cpu
uv pip compile pyproject.toml ../shared/pyproject.toml --output-file requirements.lock
```
### GPU
```bash
cd docker/gpu
uv pip compile pyproject.toml ../shared/pyproject.toml --output-file requirements.lock
```
## Local Dev Setup
### CPU
```bash
cd docker/cpu
uv venv
.venv\Scripts\activate # Windows
uv pip sync requirements.lock
```
### GPU
```bash
cd docker/gpu
uv venv
.venv\Scripts\activate # Windows
uv pip sync requirements.lock --extra-index-url https://download.pytorch.org/whl/cu121 --index-strategy unsafe-best-match
```
### Run Server
```bash
# From project root with venv active:
uvicorn api.src.main:app --reload
```
## Docker
### CPU
```bash
cd docker/cpu
docker compose up
```
### GPU
```bash
cd docker/gpu
docker compose up
```
## Known Issues
- Module imports: Run server from project root
- PyTorch CUDA: Always use --extra-index-url and --index-strategy for GPU env

View file

@ -2,9 +2,9 @@
<img src="githubbanner.png" alt="Kokoro TTS Banner">
</p>
# Kokoro TTS API
# <sub><sub>_`FastKoko`_ </sub></sub>
[![Tests](https://img.shields.io/badge/tests-117%20passed-darkgreen)]()
[![Coverage](https://img.shields.io/badge/coverage-75%25-darkgreen)]()
[![Coverage](https://img.shields.io/badge/coverage-60%25-grey)]()
[![Tested at Model Commit](https://img.shields.io/badge/last--tested--model--commit-a67f113-blue)](https://huggingface.co/hexgrad/Kokoro-82M/tree/c3b0d86e2a980e027ef71c28819ea02e351c2667) [![Try on Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Try%20on-Spaces-blue)](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero) [![Buy Me A Coffee](https://img.shields.io/badge/BMC-✨☕-gray?style=flat-square)](https://www.buymeacoffee.com/remsky)
Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model
@ -24,14 +24,30 @@ Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokor
The service can be accessed through either the API endpoints or the Gradio web interface.
1. Install prerequisites:
- Install [Docker Desktop](https://www.docker.com/products/docker-desktop/) + [Git](https://git-scm.com/downloads)
- Clone and start the service:
- Install [Docker Desktop](https://www.docker.com/products/docker-desktop/)
- Clone the repository:
```bash
git clone https://github.com/remsky/Kokoro-FastAPI.git
cd Kokoro-FastAPI
docker compose up --build # for GPU
#docker compose -f docker-compose.cpu.yml up --build # for CPU
```
2. Start the service:
- Using Docker Compose (Full setup including UI):
```bash
cd docker/gpu # OR
# cd docker/cpu # Run this or the above
docker compose up --build
```
- OR run the API alone using Docker (model + voice packs baked in):
```bash
docker run -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-cpu:latest # CPU
docker run --gpus all -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-gpu:latest # Nvidia GPU
# Minified versions are available with the `:latest-slim` tag.
```
3. Run locally as an OpenAI-Compatible Speech Endpoint
```python
from openai import OpenAI
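# A minimal usage sketch (an assumption, not taken from the diff): it presumes the API
# is mounted under /v1 on port 8880, matching the docker run commands above, and that
# the default "af" voice is available. The api_key is a placeholder since the local
# server does not check it.
client = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed")

response = client.audio.speech.create(
    model="kokoro",            # model name here is illustrative
    voice="af",                # default voice from the API settings
    input="Hello world!",
    response_format="mp3",
)
with open("speech.mp3", "wb") as f:
    f.write(response.content)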
@ -167,6 +183,21 @@ If you only want the API, just comment out everything in the docker-compose.yml
Currently, voices created via the API are accessible here, but voice combination/creation has not yet been added
*Note: Recent updates for streaming could lead to temporary glitches. If so, pull from the most recent stable release v0.0.2 to restore*
### Disabling Local Saving
You can disable local saving of audio files and hide the file view in the UI by setting the `DISABLE_LOCAL_SAVING` environment variable to `true`. This is useful when running the service on a server where you don't want to store generated audio files locally.
When using Docker Compose:
```yaml
environment:
- DISABLE_LOCAL_SAVING=true
```
When running the Docker image directly:
```bash
docker run -p 7860:7860 -e DISABLE_LOCAL_SAVING=true ghcr.io/remsky/kokoro-fastapi-ui:latest
```
</details>
<details>
@ -320,6 +351,27 @@ See `examples/phoneme_examples/generate_phonemes.py` for a sample script.
## Known Issues
<details>
<summary>Versioning & Development</summary>
I'm doing what I can to keep things stable, but we're in an early and rapid set of build cycles here.
If you run into trouble, you may have to roll back to an earlier release tag, or build from source and troubleshoot (and ideally submit a PR). I'll keep a branch up for the last known stable point:
`v0.0.5post1`
Free and open source is a community effort, and I love working on this project, but there are only so many hours in a day. If you'd like to support the work, feel free to open a PR, buy me a coffee, or report any bugs or feature requests you come across.
<a href="https://www.buymeacoffee.com/remsky" target="_blank">
<img
src="https://cdn.buymeacoffee.com/buttons/v2/default-violet.png"
alt="Buy Me A Coffee"
style="height: 30px !important;width: 110px !important;"
>
</a>
</details>
<details>
<summary>Linux GPU Permissions</summary>

View file

View file

@ -0,0 +1,26 @@
{
"decoder": {
"type": "istftnet",
"upsample_kernel_sizes": [20, 12],
"upsample_rates": [10, 6],
"gen_istft_hop_size": 5,
"gen_istft_n_fft": 20,
"resblock_dilation_sizes": [
[1, 3, 5],
[1, 3, 5],
[1, 3, 5]
],
"resblock_kernel_sizes": [3, 7, 11],
"upsample_initial_channel": 512
},
"dim_in": 64,
"dropout": 0.2,
"hidden_dim": 512,
"max_conv_dim": 512,
"max_dur": 50,
"multispeaker": true,
"n_layer": 3,
"n_mels": 80,
"n_token": 178,
"style_dim": 128
}

524
api/src/builds/istftnet.py Normal file
View file

@ -0,0 +1,524 @@
# https://github.com/yl4579/StyleTTS2/blob/main/Modules/istftnet.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.signal import get_window
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm, weight_norm
# https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return int((kernel_size*dilation - dilation)/2)
LRELU_SLOPE = 0.1
class AdaIN1d(nn.Module):
def __init__(self, style_dim, num_features):
super().__init__()
self.norm = nn.InstanceNorm1d(num_features, affine=False)
self.fc = nn.Linear(style_dim, num_features*2)
def forward(self, x, s):
h = self.fc(s)
h = h.view(h.size(0), h.size(1), 1)
gamma, beta = torch.chunk(h, chunks=2, dim=1)
return (1 + gamma) * self.norm(x) + beta
class AdaINResBlock1(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
super(AdaINResBlock1, self).__init__()
self.convs1 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2])))
])
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1)))
])
self.convs2.apply(init_weights)
self.adain1 = nn.ModuleList([
AdaIN1d(style_dim, channels),
AdaIN1d(style_dim, channels),
AdaIN1d(style_dim, channels),
])
self.adain2 = nn.ModuleList([
AdaIN1d(style_dim, channels),
AdaIN1d(style_dim, channels),
AdaIN1d(style_dim, channels),
])
self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])
def forward(self, x, s):
for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
xt = n1(x, s)
xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2) # Snake1D
xt = c1(xt)
xt = n2(xt, s)
xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2) # Snake1D
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
class TorchSTFT(torch.nn.Module):
def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
super().__init__()
self.filter_length = filter_length
self.hop_length = hop_length
self.win_length = win_length
self.window = torch.from_numpy(get_window(window, win_length, fftbins=True).astype(np.float32))
def transform(self, input_data):
forward_transform = torch.stft(
input_data,
self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device),
return_complex=True)
return torch.abs(forward_transform), torch.angle(forward_transform)
def inverse(self, magnitude, phase):
inverse_transform = torch.istft(
magnitude * torch.exp(phase * 1j),
self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device))
return inverse_transform.unsqueeze(-2) # unsqueeze to stay consistent with conv_transpose1d implementation
def forward(self, input_data):
self.magnitude, self.phase = self.transform(input_data)
reconstruction = self.inverse(self.magnitude, self.phase)
return reconstruction
class SineGen(torch.nn.Module):
""" Definition of sine generator
SineGen(samp_rate, harmonic_num = 0,
sine_amp = 0.1, noise_std = 0.003,
voiced_threshold = 0,
flag_for_pulse=False)
samp_rate: sampling rate in Hz
harmonic_num: number of harmonic overtones (default 0)
sine_amp: amplitude of sine-wavefrom (default 0.1)
noise_std: std of Gaussian noise (default 0.003)
voiced_thoreshold: F0 threshold for U/V classification (default 0)
flag_for_pulse: this SinGen is used inside PulseGen (default False)
Note: when flag_for_pulse is True, the first time step of a voiced
segment is always sin(np.pi) or cos(0)
"""
def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
sine_amp=0.1, noise_std=0.003,
voiced_threshold=0,
flag_for_pulse=False):
super(SineGen, self).__init__()
self.sine_amp = sine_amp
self.noise_std = noise_std
self.harmonic_num = harmonic_num
self.dim = self.harmonic_num + 1
self.sampling_rate = samp_rate
self.voiced_threshold = voiced_threshold
self.flag_for_pulse = flag_for_pulse
self.upsample_scale = upsample_scale
def _f02uv(self, f0):
# generate uv signal
uv = (f0 > self.voiced_threshold).type(torch.float32)
return uv
def _f02sine(self, f0_values):
""" f0_values: (batchsize, length, dim)
where dim indicates fundamental tone and overtones
"""
# convert to F0 in rad. The interger part n can be ignored
# because 2 * np.pi * n doesn't affect phase
rad_values = (f0_values / self.sampling_rate) % 1
# initial phase noise (no noise for fundamental component)
rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \
device=f0_values.device)
rand_ini[:, 0] = 0
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
# instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
if not self.flag_for_pulse:
# # for normal case
# # To prevent torch.cumsum numerical overflow,
# # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
# # Buffer tmp_over_one_idx indicates the time step to add -1.
# # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
# tmp_over_one = torch.cumsum(rad_values, 1) % 1
# tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
# cumsum_shift = torch.zeros_like(rad_values)
# cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
# phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
scale_factor=1/self.upsample_scale,
mode="linear").transpose(1, 2)
# tmp_over_one = torch.cumsum(rad_values, 1) % 1
# tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
# cumsum_shift = torch.zeros_like(rad_values)
# cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
sines = torch.sin(phase)
else:
# If necessary, make sure that the first time step of every
# voiced segments is sin(pi) or cos(0)
# This is used for pulse-train generation
# identify the last time step in unvoiced segments
uv = self._f02uv(f0_values)
uv_1 = torch.roll(uv, shifts=-1, dims=1)
uv_1[:, -1, :] = 1
u_loc = (uv < 1) * (uv_1 > 0)
# get the instantanouse phase
tmp_cumsum = torch.cumsum(rad_values, dim=1)
# different batch needs to be processed differently
for idx in range(f0_values.shape[0]):
temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
# stores the accumulation of i.phase within
# each voiced segments
tmp_cumsum[idx, :, :] = 0
tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
# rad_values - tmp_cumsum: remove the accumulation of i.phase
# within the previous voiced segment.
i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
# get the sines
sines = torch.cos(i_phase * 2 * np.pi)
return sines
def forward(self, f0):
""" sine_tensor, uv = forward(f0)
input F0: tensor(batchsize=1, length, dim=1)
f0 for unvoiced steps should be 0
output sine_tensor: tensor(batchsize=1, length, dim)
output uv: tensor(batchsize=1, length, 1)
"""
f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
device=f0.device)
# fundamental component
fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
# generate sine waveforms
sine_waves = self._f02sine(fn) * self.sine_amp
# generate uv signal
# uv = torch.ones(f0.shape)
# uv = uv * (f0 > self.voiced_threshold)
uv = self._f02uv(f0)
# noise: for unvoiced should be similar to sine_amp
# std = self.sine_amp/3 -> max value ~ self.sine_amp
# . for voiced regions is self.noise_std
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
noise = noise_amp * torch.randn_like(sine_waves)
# first: set the unvoiced part to 0 by uv
# then: additive noise
sine_waves = sine_waves * uv + noise
return sine_waves, uv, noise
class SourceModuleHnNSF(torch.nn.Module):
""" SourceModule for hn-nsf
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
add_noise_std=0.003, voiced_threshod=0)
sampling_rate: sampling_rate in Hz
harmonic_num: number of harmonic above F0 (default: 0)
sine_amp: amplitude of sine source signal (default: 0.1)
add_noise_std: std of additive Gaussian noise (default: 0.003)
note that amplitude of noise in unvoiced is decided
by sine_amp
voiced_threshold: threhold to set U/V given F0 (default: 0)
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
F0_sampled (batchsize, length, 1)
Sine_source (batchsize, length, 1)
noise_source (batchsize, length 1)
uv (batchsize, length, 1)
"""
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
add_noise_std=0.003, voiced_threshod=0):
super(SourceModuleHnNSF, self).__init__()
self.sine_amp = sine_amp
self.noise_std = add_noise_std
# to produce sine waveforms
self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
sine_amp, add_noise_std, voiced_threshod)
# to merge source harmonics into a single excitation
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
self.l_tanh = torch.nn.Tanh()
def forward(self, x):
"""
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
F0_sampled (batchsize, length, 1)
Sine_source (batchsize, length, 1)
noise_source (batchsize, length 1)
"""
# source for harmonic branch
with torch.no_grad():
sine_wavs, uv, _ = self.l_sin_gen(x)
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
# source for noise branch, in the same shape as uv
noise = torch.randn_like(uv) * self.sine_amp / 3
return sine_merge, noise, uv
def padDiff(x):
return F.pad(F.pad(x, (0,0,-1,1), 'constant', 0) - x, (0,0,0,-1), 'constant', 0)
class Generator(torch.nn.Module):
def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size):
super(Generator, self).__init__()
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
resblock = AdaINResBlock1
self.m_source = SourceModuleHnNSF(
sampling_rate=24000,
upsample_scale=np.prod(upsample_rates) * gen_istft_hop_size,
harmonic_num=8, voiced_threshod=10)
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * gen_istft_hop_size)
self.noise_convs = nn.ModuleList()
self.noise_res = nn.ModuleList()
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
self.ups.append(weight_norm(
ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
k, u, padding=(k-u)//2)))
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel//(2**(i+1))
for j, (k, d) in enumerate(zip(resblock_kernel_sizes,resblock_dilation_sizes)):
self.resblocks.append(resblock(ch, k, d, style_dim))
c_cur = upsample_initial_channel // (2 ** (i + 1))
if i + 1 < len(upsample_rates): #
stride_f0 = np.prod(upsample_rates[i + 1:])
self.noise_convs.append(Conv1d(
gen_istft_n_fft + 2, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
self.noise_res.append(resblock(c_cur, 7, [1,3,5], style_dim))
else:
self.noise_convs.append(Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1))
self.noise_res.append(resblock(c_cur, 11, [1,3,5], style_dim))
self.post_n_fft = gen_istft_n_fft
self.conv_post = weight_norm(Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
self.ups.apply(init_weights)
self.conv_post.apply(init_weights)
self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
self.stft = TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)
def forward(self, x, s, f0):
with torch.no_grad():
f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
har_source, noi_source, uv = self.m_source(f0)
har_source = har_source.transpose(1, 2).squeeze(1)
har_spec, har_phase = self.stft.transform(har_source)
har = torch.cat([har_spec, har_phase], dim=1)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, LRELU_SLOPE)
x_source = self.noise_convs[i](har)
x_source = self.noise_res[i](x_source, s)
x = self.ups[i](x)
if i == self.num_upsamples - 1:
x = self.reflection_pad(x)
x = x + x_source
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i*self.num_kernels+j](x, s)
else:
xs += self.resblocks[i*self.num_kernels+j](x, s)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
return self.stft.inverse(spec, phase)
def fw_phase(self, x, s):
for i in range(self.num_upsamples):
x = F.leaky_relu(x, LRELU_SLOPE)
x = self.ups[i](x)
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i*self.num_kernels+j](x, s)
else:
xs += self.resblocks[i*self.num_kernels+j](x, s)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.reflection_pad(x)
x = self.conv_post(x)
spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
return spec, phase
def remove_weight_norm(self):
print('Removing weight norm...')
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
class AdainResBlk1d(nn.Module):
def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
upsample='none', dropout_p=0.0):
super().__init__()
self.actv = actv
self.upsample_type = upsample
self.upsample = UpSample1d(upsample)
self.learned_sc = dim_in != dim_out
self._build_weights(dim_in, dim_out, style_dim)
self.dropout = nn.Dropout(dropout_p)
if upsample == 'none':
self.pool = nn.Identity()
else:
self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
def _build_weights(self, dim_in, dim_out, style_dim):
self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
self.norm1 = AdaIN1d(style_dim, dim_in)
self.norm2 = AdaIN1d(style_dim, dim_out)
if self.learned_sc:
self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
def _shortcut(self, x):
x = self.upsample(x)
if self.learned_sc:
x = self.conv1x1(x)
return x
def _residual(self, x, s):
x = self.norm1(x, s)
x = self.actv(x)
x = self.pool(x)
x = self.conv1(self.dropout(x))
x = self.norm2(x, s)
x = self.actv(x)
x = self.conv2(self.dropout(x))
return x
def forward(self, x, s):
out = self._residual(x, s)
out = (out + self._shortcut(x)) / np.sqrt(2)
return out
class UpSample1d(nn.Module):
def __init__(self, layer_type):
super().__init__()
self.layer_type = layer_type
def forward(self, x):
if self.layer_type == 'none':
return x
else:
return F.interpolate(x, scale_factor=2, mode='nearest')
class Decoder(nn.Module):
def __init__(self, dim_in=512, F0_channel=512, style_dim=64, dim_out=80,
resblock_kernel_sizes = [3,7,11],
upsample_rates = [10, 6],
upsample_initial_channel=512,
resblock_dilation_sizes=[[1,3,5], [1,3,5], [1,3,5]],
upsample_kernel_sizes=[20, 12],
gen_istft_n_fft=20, gen_istft_hop_size=5):
super().__init__()
self.decode = nn.ModuleList()
self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
self.asr_res = nn.Sequential(
weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
)
self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates,
upsample_initial_channel, resblock_dilation_sizes,
upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size)
def forward(self, asr, F0_curve, N, s):
F0 = self.F0_conv(F0_curve.unsqueeze(1))
N = self.N_conv(N.unsqueeze(1))
x = torch.cat([asr, F0, N], axis=1)
x = self.encode(x, s)
asr_res = self.asr_res(asr)
res = True
for block in self.decode:
if res:
x = torch.cat([x, asr_res, F0, N], axis=1)
x = block(x, s)
if block.upsample_type != "none":
res = False
x = self.generator(x, s, F0_curve)
return x

151
api/src/builds/kokoro.py Normal file
View file

@ -0,0 +1,151 @@
import re
import phonemizer
import torch
def split_num(num):
num = num.group()
if '.' in num:
return num
elif ':' in num:
h, m = [int(n) for n in num.split(':')]
if m == 0:
return f"{h} o'clock"
elif m < 10:
return f'{h} oh {m}'
return f'{h} {m}'
year = int(num[:4])
if year < 1100 or year % 1000 < 10:
return num
left, right = num[:2], int(num[2:4])
s = 's' if num.endswith('s') else ''
if 100 <= year % 1000 <= 999:
if right == 0:
return f'{left} hundred{s}'
elif right < 10:
return f'{left} oh {right}{s}'
return f'{left} {right}{s}'
def flip_money(m):
m = m.group()
bill = 'dollar' if m[0] == '$' else 'pound'
if m[-1].isalpha():
return f'{m[1:]} {bill}s'
elif '.' not in m:
s = '' if m[1:] == '1' else 's'
return f'{m[1:]} {bill}{s}'
b, c = m[1:].split('.')
s = '' if b == '1' else 's'
c = int(c.ljust(2, '0'))
coins = f"cent{'' if c == 1 else 's'}" if m[0] == '$' else ('penny' if c == 1 else 'pence')
return f'{b} {bill}{s} and {c} {coins}'
def point_num(num):
a, b = num.group().split('.')
return ' point '.join([a, ' '.join(b)])
def normalize_text(text):
text = text.replace(chr(8216), "'").replace(chr(8217), "'")
text = text.replace('«', chr(8220)).replace('»', chr(8221))
text = text.replace(chr(8220), '"').replace(chr(8221), '"')
text = text.replace('(', '«').replace(')', '»')
for a, b in zip('、。!,:;?', ',.!,:;?'):
text = text.replace(a, b+' ')
text = re.sub(r'[^\S \n]', ' ', text)
text = re.sub(r' +', ' ', text)
text = re.sub(r'(?<=\n) +(?=\n)', '', text)
text = re.sub(r'\bD[Rr]\.(?= [A-Z])', 'Doctor', text)
text = re.sub(r'\b(?:Mr\.|MR\.(?= [A-Z]))', 'Mister', text)
text = re.sub(r'\b(?:Ms\.|MS\.(?= [A-Z]))', 'Miss', text)
text = re.sub(r'\b(?:Mrs\.|MRS\.(?= [A-Z]))', 'Mrs', text)
text = re.sub(r'\betc\.(?! [A-Z])', 'etc', text)
text = re.sub(r'(?i)\b(y)eah?\b', r"\1e'a", text)
text = re.sub(r'\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)', split_num, text)
text = re.sub(r'(?<=\d),(?=\d)', '', text)
text = re.sub(r'(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b', flip_money, text)
text = re.sub(r'\d*\.\d+', point_num, text)
text = re.sub(r'(?<=\d)-(?=\d)', ' to ', text)
text = re.sub(r'(?<=\d)S', ' S', text)
text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
text = re.sub(r"(?<=X')S\b", 's', text)
text = re.sub(r'(?:[A-Za-z]\.){2,} [a-z]', lambda m: m.group().replace('.', '-'), text)
text = re.sub(r'(?i)(?<=[A-Z])\.(?=[A-Z])', '-', text)
return text.strip()
def get_vocab():
_pad = "$"
_punctuation = ';:,.!?¡¿—…"«»“” '
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'"
symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
dicts = {}
for i in range(len((symbols))):
dicts[symbols[i]] = i
return dicts
VOCAB = get_vocab()
def tokenize(ps):
return [i for i in map(VOCAB.get, ps) if i is not None]
phonemizers = dict(
a=phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True),
b=phonemizer.backend.EspeakBackend(language='en-gb', preserve_punctuation=True, with_stress=True),
)
def phonemize(text, lang, norm=True):
if norm:
text = normalize_text(text)
ps = phonemizers[lang].phonemize([text])
ps = ps[0] if ps else ''
# https://en.wiktionary.org/wiki/kokoro#English
ps = ps.replace('kəkˈoːɹoʊ', 'kˈoʊkəɹoʊ').replace('kəkˈɔːɹəʊ', 'kˈəʊkəɹəʊ')
ps = ps.replace('ʲ', 'j').replace('r', 'ɹ').replace('x', 'k').replace('ɬ', 'l')
ps = re.sub(r'(?<=[a-zɹː])(?=hˈʌndɹɪd)', ' ', ps)
ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»“” ]|$)', 'z', ps)
if lang == 'a':
ps = re.sub(r'(?<=nˈaɪn)ti(?!ː)', 'di', ps)
ps = ''.join(filter(lambda p: p in VOCAB, ps))
return ps.strip()
def length_to_mask(lengths):
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
mask = torch.gt(mask+1, lengths.unsqueeze(1))
return mask
@torch.no_grad()
def forward(model, tokens, ref_s, speed):
device = ref_s.device
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
text_mask = length_to_mask(input_lengths).to(device)
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
s = ref_s[:, 128:]
d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
x, _ = model.predictor.lstm(d)
duration = model.predictor.duration_proj(x)
duration = torch.sigmoid(duration).sum(axis=-1) / speed
pred_dur = torch.round(duration).clamp(min=1).long()
pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
c_frame = 0
for i in range(pred_aln_trg.size(0)):
pred_aln_trg[i, c_frame:c_frame + pred_dur[0,i].item()] = 1
c_frame += pred_dur[0,i].item()
en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
t_en = model.text_encoder(tokens, input_lengths, text_mask)
asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
def generate(model, text, voicepack, lang='a', speed=1, ps=None):
ps = ps or phonemize(text, lang)
tokens = tokenize(ps)
if not tokens:
return None
elif len(tokens) > 510:
tokens = tokens[:510]
print('Truncated to 510 tokens')
ref_s = voicepack[len(tokens)]
out = forward(model, tokens, ref_s, speed)
ps = ''.join(next(k for k, v in VOCAB.items() if i == v) for i in tokens)
return out, ps

375
api/src/builds/models.py Normal file
View file

@ -0,0 +1,375 @@
# https://github.com/yl4579/StyleTTS2/blob/main/models.py
import json
import os
import os.path as osp
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from munch import Munch
from torch.nn.utils import spectral_norm, weight_norm
from .istftnet import AdaIN1d, Decoder
from .plbert import load_plbert
class LinearNorm(torch.nn.Module):
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
super(LinearNorm, self).__init__()
self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
torch.nn.init.xavier_uniform_(
self.linear_layer.weight,
gain=torch.nn.init.calculate_gain(w_init_gain))
def forward(self, x):
return self.linear_layer(x)
class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-5):
super().__init__()
self.channels = channels
self.eps = eps
self.gamma = nn.Parameter(torch.ones(channels))
self.beta = nn.Parameter(torch.zeros(channels))
def forward(self, x):
x = x.transpose(1, -1)
x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
return x.transpose(1, -1)
class TextEncoder(nn.Module):
def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
super().__init__()
self.embedding = nn.Embedding(n_symbols, channels)
padding = (kernel_size - 1) // 2
self.cnn = nn.ModuleList()
for _ in range(depth):
self.cnn.append(nn.Sequential(
weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
LayerNorm(channels),
actv,
nn.Dropout(0.2),
))
# self.cnn = nn.Sequential(*self.cnn)
self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True)
def forward(self, x, input_lengths, m):
x = self.embedding(x) # [B, T, emb]
x = x.transpose(1, 2) # [B, emb, T]
m = m.to(input_lengths.device).unsqueeze(1)
x.masked_fill_(m, 0.0)
for c in self.cnn:
x = c(x)
x.masked_fill_(m, 0.0)
x = x.transpose(1, 2) # [B, T, chn]
input_lengths = input_lengths.cpu().numpy()
x = nn.utils.rnn.pack_padded_sequence(
x, input_lengths, batch_first=True, enforce_sorted=False)
self.lstm.flatten_parameters()
x, _ = self.lstm(x)
x, _ = nn.utils.rnn.pad_packed_sequence(
x, batch_first=True)
x = x.transpose(-1, -2)
x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
x_pad[:, :, :x.shape[-1]] = x
x = x_pad.to(x.device)
x.masked_fill_(m, 0.0)
return x
def inference(self, x):
x = self.embedding(x)
x = x.transpose(1, 2)
x = self.cnn(x)
x = x.transpose(1, 2)
self.lstm.flatten_parameters()
x, _ = self.lstm(x)
return x
def length_to_mask(self, lengths):
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
mask = torch.gt(mask+1, lengths.unsqueeze(1))
return mask
class UpSample1d(nn.Module):
def __init__(self, layer_type):
super().__init__()
self.layer_type = layer_type
def forward(self, x):
if self.layer_type == 'none':
return x
else:
return F.interpolate(x, scale_factor=2, mode='nearest')
class AdainResBlk1d(nn.Module):
def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2),
upsample='none', dropout_p=0.0):
super().__init__()
self.actv = actv
self.upsample_type = upsample
self.upsample = UpSample1d(upsample)
self.learned_sc = dim_in != dim_out
self._build_weights(dim_in, dim_out, style_dim)
self.dropout = nn.Dropout(dropout_p)
if upsample == 'none':
self.pool = nn.Identity()
else:
self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))
def _build_weights(self, dim_in, dim_out, style_dim):
self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
self.norm1 = AdaIN1d(style_dim, dim_in)
self.norm2 = AdaIN1d(style_dim, dim_out)
if self.learned_sc:
self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
def _shortcut(self, x):
x = self.upsample(x)
if self.learned_sc:
x = self.conv1x1(x)
return x
def _residual(self, x, s):
x = self.norm1(x, s)
x = self.actv(x)
x = self.pool(x)
x = self.conv1(self.dropout(x))
x = self.norm2(x, s)
x = self.actv(x)
x = self.conv2(self.dropout(x))
return x
def forward(self, x, s):
out = self._residual(x, s)
out = (out + self._shortcut(x)) / np.sqrt(2)
return out
class AdaLayerNorm(nn.Module):
def __init__(self, style_dim, channels, eps=1e-5):
super().__init__()
self.channels = channels
self.eps = eps
self.fc = nn.Linear(style_dim, channels*2)
def forward(self, x, s):
x = x.transpose(-1, -2)
x = x.transpose(1, -1)
h = self.fc(s)
h = h.view(h.size(0), h.size(1), 1)
gamma, beta = torch.chunk(h, chunks=2, dim=1)
gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
x = F.layer_norm(x, (self.channels,), eps=self.eps)
x = (1 + gamma) * x + beta
return x.transpose(1, -1).transpose(-1, -2)
class ProsodyPredictor(nn.Module):
def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
super().__init__()
self.text_encoder = DurationEncoder(sty_dim=style_dim,
d_model=d_hid,
nlayers=nlayers,
dropout=dropout)
self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
self.duration_proj = LinearNorm(d_hid, max_dur)
self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
self.F0 = nn.ModuleList()
self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
self.N = nn.ModuleList()
self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
def forward(self, texts, style, text_lengths, alignment, m):
d = self.text_encoder(texts, style, text_lengths, m)
batch_size = d.shape[0]
text_size = d.shape[1]
# predict duration
input_lengths = text_lengths.cpu().numpy()
x = nn.utils.rnn.pack_padded_sequence(
d, input_lengths, batch_first=True, enforce_sorted=False)
m = m.to(text_lengths.device).unsqueeze(1)
self.lstm.flatten_parameters()
x, _ = self.lstm(x)
x, _ = nn.utils.rnn.pad_packed_sequence(
x, batch_first=True)
x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
x_pad[:, :x.shape[1], :] = x
x = x_pad.to(x.device)
duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=self.training))
en = (d.transpose(-1, -2) @ alignment)
return duration.squeeze(-1), en
def F0Ntrain(self, x, s):
x, _ = self.shared(x.transpose(-1, -2))
F0 = x.transpose(-1, -2)
for block in self.F0:
F0 = block(F0, s)
F0 = self.F0_proj(F0)
N = x.transpose(-1, -2)
for block in self.N:
N = block(N, s)
N = self.N_proj(N)
return F0.squeeze(1), N.squeeze(1)
def length_to_mask(self, lengths):
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
mask = torch.gt(mask+1, lengths.unsqueeze(1))
return mask
class DurationEncoder(nn.Module):
def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
super().__init__()
self.lstms = nn.ModuleList()
for _ in range(nlayers):
self.lstms.append(nn.LSTM(d_model + sty_dim,
d_model // 2,
num_layers=1,
batch_first=True,
bidirectional=True,
dropout=dropout))
self.lstms.append(AdaLayerNorm(sty_dim, d_model))
self.dropout = dropout
self.d_model = d_model
self.sty_dim = sty_dim
def forward(self, x, style, text_lengths, m):
masks = m.to(text_lengths.device)
x = x.permute(2, 0, 1)
s = style.expand(x.shape[0], x.shape[1], -1)
x = torch.cat([x, s], axis=-1)
x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
x = x.transpose(0, 1)
input_lengths = text_lengths.cpu().numpy()
x = x.transpose(-1, -2)
for block in self.lstms:
if isinstance(block, AdaLayerNorm):
x = block(x.transpose(-1, -2), style).transpose(-1, -2)
x = torch.cat([x, s.permute(1, -1, 0)], axis=1)
x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
else:
x = x.transpose(-1, -2)
x = nn.utils.rnn.pack_padded_sequence(
x, input_lengths, batch_first=True, enforce_sorted=False)
block.flatten_parameters()
x, _ = block(x)
x, _ = nn.utils.rnn.pad_packed_sequence(
x, batch_first=True)
x = F.dropout(x, p=self.dropout, training=self.training)
x = x.transpose(-1, -2)
x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
x_pad[:, :, :x.shape[-1]] = x
x = x_pad.to(x.device)
return x.transpose(-1, -2)
def inference(self, x, style):
x = self.embedding(x.transpose(-1, -2)) * np.sqrt(self.d_model)
style = style.expand(x.shape[0], x.shape[1], -1)
x = torch.cat([x, style], axis=-1)
src = self.pos_encoder(x)
output = self.transformer_encoder(src).transpose(0, 1)
return output
def length_to_mask(self, lengths):
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
mask = torch.gt(mask+1, lengths.unsqueeze(1))
return mask
# https://github.com/yl4579/StyleTTS2/blob/main/utils.py
def recursive_munch(d):
if isinstance(d, dict):
return Munch((k, recursive_munch(v)) for k, v in d.items())
elif isinstance(d, list):
return [recursive_munch(v) for v in d]
else:
return d
def build_model(path, device):
config = Path(__file__).parent / 'config.json'
assert config.exists(), f'Config path incorrect: config.json not found at {config}'
with open(config, 'r') as r:
args = recursive_munch(json.load(r))
assert args.decoder.type == 'istftnet', f'Unknown decoder type: {args.decoder.type}'
decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
resblock_kernel_sizes = args.decoder.resblock_kernel_sizes,
upsample_rates = args.decoder.upsample_rates,
upsample_initial_channel=args.decoder.upsample_initial_channel,
resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
gen_istft_n_fft=args.decoder.gen_istft_n_fft, gen_istft_hop_size=args.decoder.gen_istft_hop_size)
text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
bert = load_plbert()
bert_encoder = nn.Linear(bert.config.hidden_size, args.hidden_dim)
for parent in [bert, bert_encoder, predictor, decoder, text_encoder]:
for child in parent.children():
if isinstance(child, nn.RNNBase):
child.flatten_parameters()
model = Munch(
bert=bert.to(device).eval(),
bert_encoder=bert_encoder.to(device).eval(),
predictor=predictor.to(device).eval(),
decoder=decoder.to(device).eval(),
text_encoder=text_encoder.to(device).eval(),
)
for key, state_dict in torch.load(path, map_location='cpu', weights_only=True)['net'].items():
assert key in model, key
try:
model[key].load_state_dict(state_dict)
except:
state_dict = {k[7:]: v for k, v in state_dict.items()}
model[key].load_state_dict(state_dict, strict=False)
return model

16
api/src/builds/plbert.py Normal file
View file

@ -0,0 +1,16 @@
# https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py
from transformers import AlbertConfig, AlbertModel
class CustomAlbert(AlbertModel):
def forward(self, *args, **kwargs):
# Call the original forward method
outputs = super().forward(*args, **kwargs)
# Only return the last_hidden_state
return outputs.last_hidden_state
def load_plbert():
plbert_config = {'vocab_size': 178, 'hidden_size': 768, 'num_attention_heads': 12, 'intermediate_size': 2048, 'max_position_embeddings': 512, 'num_hidden_layers': 12, 'dropout': 0.1}
albert_base_configuration = AlbertConfig(**plbert_config)
bert = CustomAlbert(albert_base_configuration)
return bert

View file

@ -13,7 +13,7 @@ class Settings(BaseSettings):
output_dir: str = "output"
output_dir_size_limit_mb: float = 500.0 # Maximum size of output directory in MB
default_voice: str = "af"
model_dir: str = "/app/Kokoro-82M" # Base directory for model files
model_dir: str = "/app/api/model_files" # Base directory for model files
pytorch_model_path: str = "kokoro-v0_19.pth"
onnx_model_path: str = "kokoro-v0_19.onnx"
voices_dir: str = "voices"

View file

@ -1,7 +1,7 @@
import re
import torch
import phonemizer
import torch
def split_num(num):

View file

@ -6,15 +6,15 @@ import sys
from contextlib import asynccontextmanager
import uvicorn
from loguru import logger
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from .core.config import settings
from .services.tts_model import TTSModel
from .routers.development import router as dev_router
from .services.tts_service import TTSService
from .routers.openai_compatible import router as openai_router
from .services.tts_model import TTSModel
from .services.tts_service import TTSService
def setup_logger():
@ -47,7 +47,7 @@ async def lifespan(app: FastAPI):
# Initialize the main model with warm-up
voicepack_count = await TTSModel.setup()
# boundary = "█████╗"*9
boundary = "" * 24
boundary = "" * 2*12
startup_msg = f"""
{boundary}

View file

@ -1,18 +1,18 @@
from typing import List
import numpy as np
from fastapi import APIRouter, Depends, HTTPException, Response
from loguru import logger
from fastapi import Depends, Response, APIRouter, HTTPException
from ..services.audio import AudioService
from ..services.text_processing import phonemize, tokenize
from ..services.tts_model import TTSModel
from ..services.tts_service import TTSService
from ..structures.text_schemas import (
GenerateFromPhonemesRequest,
PhonemeRequest,
PhonemeResponse,
GenerateFromPhonemesRequest,
)
from ..services.text_processing import tokenize, phonemize
router = APIRouter(tags=["text processing"])

View file

@ -1,12 +1,12 @@
from typing import List, Union, AsyncGenerator
from typing import AsyncGenerator, List, Union
from loguru import logger
from fastapi import Header, Depends, Response, APIRouter, HTTPException
from fastapi import APIRouter, Depends, Header, HTTPException, Response, Request
from fastapi.responses import StreamingResponse
from loguru import logger
from ..services.audio import AudioService
from ..structures.schemas import OpenAISpeechRequest
from ..services.tts_service import TTSService
from ..structures.schemas import OpenAISpeechRequest
router = APIRouter(
tags=["OpenAI Compatible TTS"],
@ -49,22 +49,35 @@ async def process_voices(
async def stream_audio_chunks(
tts_service: TTSService, request: OpenAISpeechRequest
tts_service: TTSService,
request: OpenAISpeechRequest,
client_request: Request
) -> AsyncGenerator[bytes, None]:
"""Stream audio chunks as they're generated"""
"""Stream audio chunks as they're generated with client disconnect handling"""
voice_to_use = await process_voices(request.voice, tts_service)
async for chunk in tts_service.generate_audio_stream(
text=request.input,
voice=voice_to_use,
speed=request.speed,
output_format=request.response_format,
):
yield chunk
try:
async for chunk in tts_service.generate_audio_stream(
text=request.input,
voice=voice_to_use,
speed=request.speed,
output_format=request.response_format,
):
# Check if client is still connected
if await client_request.is_disconnected():
logger.info("Client disconnected, stopping audio generation")
break
yield chunk
except Exception as e:
logger.error(f"Error in audio streaming: {str(e)}")
# Let the exception propagate to trigger cleanup
raise
@router.post("/audio/speech")
async def create_speech(
request: OpenAISpeechRequest,
client_request: Request,
tts_service: TTSService = Depends(get_tts_service),
x_raw_response: str = Header(None, alias="x-raw-response"),
):
@ -87,7 +100,7 @@ async def create_speech(
if request.stream:
# Stream audio chunks as they're generated
return StreamingResponse(
stream_audio_chunks(tts_service, request),
stream_audio_chunks(tts_service, request, client_request),
media_type=content_type,
headers={
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",

View file

@ -3,8 +3,8 @@
from io import BytesIO
import numpy as np
import soundfile as sf
import scipy.io.wavfile as wavfile
import soundfile as sf
from loguru import logger
from ..core.config import settings
@ -22,20 +22,19 @@ class AudioNormalizer:
def normalize(
self, audio_data: np.ndarray, is_last_chunk: bool = False
) -> np.ndarray:
"""Normalize audio data to int16 range and trim chunk boundaries"""
# Convert to float32 if not already
"""Convert audio data to int16 range and trim chunk boundaries"""
if len(audio_data) == 0:
raise ValueError("Audio data cannot be empty")
# Simple float32 to int16 conversion
audio_float = audio_data.astype(np.float32)
# Normalize to [-1, 1] range first
if np.max(np.abs(audio_float)) > 0:
audio_float = audio_float / np.max(np.abs(audio_float))
# Trim end of non-final chunks to reduce gaps
# Trim for non-final chunks
if not is_last_chunk and len(audio_float) > self.samples_to_trim:
audio_float = audio_float[: -self.samples_to_trim]
# Scale to int16 range
return (audio_float * self.int16_max).astype(np.int16)
audio_float = audio_float[:-self.samples_to_trim]
# Direct scaling like the non-streaming version
return (audio_float * 32767).astype(np.int16)
class AudioService:
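
A standalone sketch of the conversion the hunk above settles on (numpy only; samples_to_trim is held on the normalizer instance in the real class, passed as a plain argument here):

import numpy as np

def normalize_chunk(audio: np.ndarray, samples_to_trim: int, is_last_chunk: bool = False) -> np.ndarray:
    """Scale float audio into int16 range, trimming the tail of non-final chunks."""
    if len(audio) == 0:
        raise ValueError("Audio data cannot be empty")
    audio = audio.astype(np.float32)
    peak = np.max(np.abs(audio))
    if peak > 0:
        audio = audio / peak                  # bring into [-1, 1]
    if not is_last_chunk and len(audio) > samples_to_trim:
        audio = audio[:-samples_to_trim]      # trim chunk boundary to reduce audible gaps
    return (audio * 32767).astype(np.int16)   # 32767 == int16 max, as in the non-streaming path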

View file

@ -1,6 +1,6 @@
from .normalizer import normalize_text
from .phonemizer import EspeakBackend, PhonemizerBackend, phonemize
from .vocabulary import VOCAB, tokenize, decode_tokens
from .vocabulary import VOCAB, decode_tokens, tokenize
__all__ = [
"normalize_text",

View file

@ -5,19 +5,20 @@ import torch
from loguru import logger
from onnxruntime import (
ExecutionMode,
SessionOptions,
InferenceSession,
GraphOptimizationLevel,
InferenceSession,
SessionOptions,
)
from .tts_base import TTSBaseModel
from ..core.config import settings
from .text_processing import tokenize, phonemize
from .text_processing import phonemize, tokenize
from .tts_base import TTSBaseModel
class TTSCPUModel(TTSBaseModel):
_instance = None
_onnx_session = None
_device = "cpu"
@classmethod
def get_instance(cls):
@ -30,64 +31,65 @@ class TTSCPUModel(TTSBaseModel):
def initialize(cls, model_dir: str, model_path: str = None):
"""Initialize ONNX model for CPU inference"""
if cls._onnx_session is None:
# Try loading ONNX model
onnx_path = os.path.join(model_dir, settings.onnx_model_path)
if os.path.exists(onnx_path):
try:
# Try loading ONNX model
onnx_path = os.path.join(model_dir, settings.onnx_model_path)
if not os.path.exists(onnx_path):
logger.error(f"ONNX model not found at {onnx_path}")
return None
logger.info(f"Loading ONNX model from {onnx_path}")
else:
logger.error(f"ONNX model not found at {onnx_path}")
return None
if not onnx_path:
return None
# Configure ONNX session for optimal performance
session_options = SessionOptions()
# Configure ONNX session for optimal performance
session_options = SessionOptions()
# Set optimization level
if settings.onnx_optimization_level == "all":
session_options.graph_optimization_level = (
GraphOptimizationLevel.ORT_ENABLE_ALL
)
elif settings.onnx_optimization_level == "basic":
session_options.graph_optimization_level = (
GraphOptimizationLevel.ORT_ENABLE_BASIC
)
else:
session_options.graph_optimization_level = (
GraphOptimizationLevel.ORT_DISABLE_ALL
)
# Set optimization level
if settings.onnx_optimization_level == "all":
session_options.graph_optimization_level = (
GraphOptimizationLevel.ORT_ENABLE_ALL
)
elif settings.onnx_optimization_level == "basic":
session_options.graph_optimization_level = (
GraphOptimizationLevel.ORT_ENABLE_BASIC
)
else:
session_options.graph_optimization_level = (
GraphOptimizationLevel.ORT_DISABLE_ALL
# Configure threading
session_options.intra_op_num_threads = settings.onnx_num_threads
session_options.inter_op_num_threads = settings.onnx_inter_op_threads
# Set execution mode
session_options.execution_mode = (
ExecutionMode.ORT_PARALLEL
if settings.onnx_execution_mode == "parallel"
else ExecutionMode.ORT_SEQUENTIAL
)
# Configure threading
session_options.intra_op_num_threads = settings.onnx_num_threads
session_options.inter_op_num_threads = settings.onnx_inter_op_threads
# Enable/disable memory pattern optimization
session_options.enable_mem_pattern = settings.onnx_memory_pattern
# Set execution mode
session_options.execution_mode = (
ExecutionMode.ORT_PARALLEL
if settings.onnx_execution_mode == "parallel"
else ExecutionMode.ORT_SEQUENTIAL
)
# Enable/disable memory pattern optimization
session_options.enable_mem_pattern = settings.onnx_memory_pattern
# Configure CPU provider options
provider_options = {
"CPUExecutionProvider": {
"arena_extend_strategy": settings.onnx_arena_extend_strategy,
"cpu_memory_arena_cfg": "cpu:0",
# Configure CPU provider options
provider_options = {
"CPUExecutionProvider": {
"arena_extend_strategy": settings.onnx_arena_extend_strategy,
"cpu_memory_arena_cfg": "cpu:0",
}
}
}
session = InferenceSession(
onnx_path,
sess_options=session_options,
providers=["CPUExecutionProvider"],
provider_options=[provider_options],
)
cls._onnx_session = session
return session
session = InferenceSession(
onnx_path,
sess_options=session_options,
providers=["CPUExecutionProvider"],
provider_options=[provider_options],
)
cls._onnx_session = session
return session
except Exception as e:
logger.error(f"Failed to initialize ONNX model: {e}")
return None
return cls._onnx_session
@classmethod
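
A condensed, self-contained sketch of the ONNX session setup this hunk converges on; the thread counts, optimization level, and model path are hard-coded placeholders here, whereas the service reads them from settings:

from onnxruntime import ExecutionMode, GraphOptimizationLevel, InferenceSession, SessionOptions

opts = SessionOptions()
opts.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL   # settings.onnx_optimization_level == "all"
opts.intra_op_num_threads = 8                                           # ONNX_NUM_THREADS
opts.inter_op_num_threads = 4                                           # ONNX_INTER_OP_THREADS
opts.execution_mode = ExecutionMode.ORT_PARALLEL                        # ONNX_EXECUTION_MODE=parallel
opts.enable_mem_pattern = True                                          # ONNX_MEMORY_PATTERN

session = InferenceSession(
    "model.onnx",                  # placeholder; the service joins model_dir and settings.onnx_model_path
    sess_options=opts,
    providers=["CPUExecutionProvider"],
)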

View file

@ -3,12 +3,12 @@ import time
import numpy as np
import torch
from builds.models import build_model
from loguru import logger
from models import build_model
from .tts_base import TTSBaseModel
from ..core.config import settings
from .text_processing import tokenize, phonemize
from .text_processing import phonemize, tokenize
from .tts_base import TTSBaseModel
# @torch.no_grad()

View file

@ -2,19 +2,19 @@ import io
import os
import re
import time
from typing import List, Tuple, Optional
from functools import lru_cache
from typing import List, Optional, Tuple
import numpy as np
import torch
import aiofiles.os
import numpy as np
import scipy.io.wavfile as wavfile
import torch
from loguru import logger
from .audio import AudioService, AudioNormalizer
from .tts_model import TTSModel
from ..core.config import settings
from .audio import AudioNormalizer, AudioService
from .text_processing import chunker, normalize_text
from .tts_model import TTSModel
class TTSService:

View file

@ -4,9 +4,9 @@ from typing import List, Tuple
import torch
from loguru import logger
from ..core.config import settings
from .tts_model import TTSModel
from .tts_service import TTSService
from ..core.config import settings
class WarmupService:

View file

@ -1,7 +1,7 @@
from enum import Enum
from typing import List, Union, Literal
from typing import List, Literal, Union
from pydantic import Field, BaseModel
from pydantic import BaseModel, Field
class VoiceCombineRequest(BaseModel):

View file

@ -1,4 +1,4 @@
from pydantic import Field, BaseModel
from pydantic import BaseModel, Field
class PhonemeRequest(BaseModel):

BIN api/src/voices/af_irulan.pt Normal file

Binary file not shown.

View file

@ -1,11 +1,11 @@
import os
import sys
import shutil
from unittest.mock import Mock, MagicMock, patch
import sys
from unittest.mock import MagicMock, Mock, patch
import aiofiles.threadpool
import numpy as np
import pytest
import aiofiles.threadpool
def cleanup_mock_dirs():
@ -32,77 +32,7 @@ def cleanup():
cleanup_mock_dirs()
# Create mock torch module
mock_torch = Mock()
mock_torch.cuda = Mock()
mock_torch.cuda.is_available = Mock(return_value=False)
# Create a mock tensor class that supports basic operations
class MockTensor:
def __init__(self, data):
self.data = data
if isinstance(data, (list, tuple)):
self.shape = [len(data)]
elif isinstance(data, MockTensor):
self.shape = data.shape
else:
self.shape = getattr(data, "shape", [1])
def __getitem__(self, idx):
if isinstance(self.data, (list, tuple)):
if isinstance(idx, slice):
return MockTensor(self.data[idx])
return self.data[idx]
return self
def max(self):
if isinstance(self.data, (list, tuple)):
max_val = max(self.data)
return MockTensor(max_val)
return 5 # Default for testing
def item(self):
if isinstance(self.data, (list, tuple)):
return max(self.data)
if isinstance(self.data, (int, float)):
return self.data
return 5 # Default for testing
def cuda(self):
"""Support cuda conversion"""
return self
def any(self):
if isinstance(self.data, (list, tuple)):
return any(self.data)
return False
def all(self):
if isinstance(self.data, (list, tuple)):
return all(self.data)
return True
def unsqueeze(self, dim):
return self
def expand(self, *args):
return self
def type_as(self, other):
return self
# Add tensor operations to mock torch
mock_torch.tensor = lambda x: MockTensor(x)
mock_torch.zeros = lambda *args: MockTensor(
[0] * (args[0] if isinstance(args[0], int) else args[0][0])
)
mock_torch.arange = lambda x: MockTensor(list(range(x)))
mock_torch.gt = lambda x, y: MockTensor([False] * x.shape[0])
# Mock modules before they're imported
sys.modules["torch"] = mock_torch
sys.modules["transformers"] = Mock()
sys.modules["phonemizer"] = Mock()
sys.modules["models"] = Mock()

View file

@ -5,7 +5,7 @@ from unittest.mock import patch
import numpy as np
import pytest
from api.src.services.audio import AudioService, AudioNormalizer
from api.src.services.audio import AudioNormalizer, AudioService
@pytest.fixture(autouse=True)

View file

@ -1,10 +1,10 @@
import asyncio
from unittest.mock import Mock, AsyncMock
from unittest.mock import AsyncMock, Mock
import pytest
import pytest_asyncio
from httpx import AsyncClient
from fastapi.testclient import TestClient
from httpx import AsyncClient
from ..src.main import app

View file

@ -7,8 +7,8 @@ import pytest
import pytest_asyncio
from httpx import AsyncClient
from .conftest import MockTTSModel
from ..src.main import app
from .conftest import MockTTSModel
@pytest_asyncio.fixture

View file

@ -1,15 +1,15 @@
"""Tests for TTS model implementations"""
import os
from unittest.mock import MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import torch
import pytest
import torch
from api.src.services.tts_base import TTSBaseModel
from api.src.services.tts_cpu import TTSCPUModel
from api.src.services.tts_gpu import TTSGPUModel, length_to_mask
from api.src.services.tts_base import TTSBaseModel
# Base Model Tests
@ -27,16 +27,30 @@ def test_get_device_error():
@patch("os.listdir")
@patch("torch.load")
@patch("torch.save")
@patch("api.src.services.tts_base.settings")
@patch("api.src.services.warmup.WarmupService")
async def test_setup_cuda_available(
mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available
mock_warmup_class, mock_settings, mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available
):
"""Test setup with CUDA available"""
TTSBaseModel._device = None
mock_cuda_available.return_value = True
# Mock CUDA as unavailable since we're using CPU PyTorch
mock_cuda_available.return_value = False
mock_exists.return_value = True
mock_load.return_value = torch.zeros(1)
mock_listdir.return_value = ["voice1.pt", "voice2.pt"]
mock_join.return_value = "/mocked/path"
# Configure mock settings
mock_settings.model_dir = "/mock/model/dir"
mock_settings.onnx_model_path = "model.onnx"
mock_settings.voices_dir = "voices"
# Configure mock warmup service
mock_warmup = MagicMock()
mock_warmup.load_voices.return_value = [torch.zeros(1)]
mock_warmup.warmup_voices = AsyncMock()
mock_warmup_class.return_value = mock_warmup
# Create mock model
mock_model = MagicMock()
@ -49,7 +63,7 @@ async def test_setup_cuda_available(
TTSBaseModel._instance = mock_model
voice_count = await TTSBaseModel.setup()
assert TTSBaseModel._device == "cuda"
assert TTSBaseModel._device == "cpu"
assert voice_count == 2
@ -60,8 +74,10 @@ async def test_setup_cuda_available(
@patch("os.listdir")
@patch("torch.load")
@patch("torch.save")
@patch("api.src.services.tts_base.settings")
@patch("api.src.services.warmup.WarmupService")
async def test_setup_cuda_unavailable(
mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available
mock_warmup_class, mock_settings, mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available
):
"""Test setup with CUDA unavailable"""
TTSBaseModel._device = None
@ -70,6 +86,17 @@ async def test_setup_cuda_unavailable(
mock_load.return_value = torch.zeros(1)
mock_listdir.return_value = ["voice1.pt", "voice2.pt"]
mock_join.return_value = "/mocked/path"
# Configure mock settings
mock_settings.model_dir = "/mock/model/dir"
mock_settings.onnx_model_path = "model.onnx"
mock_settings.voices_dir = "voices"
# Configure mock warmup service
mock_warmup = MagicMock()
mock_warmup.load_voices.return_value = [torch.zeros(1)]
mock_warmup.warmup_voices = AsyncMock()
mock_warmup_class.return_value = mock_warmup
# Create mock model
mock_model = MagicMock()

View file

@ -4,8 +4,8 @@ import os
from unittest.mock import MagicMock, call, patch
import numpy as np
import torch
import pytest
import torch
from onnxruntime import InferenceSession
from api.src.core.config import settings

View file

@ -1,79 +0,0 @@
name: kokoro-fastapi
services:
model-fetcher:
image: datamachines/git-lfs:latest
volumes:
- ./Kokoro-82M:/app/Kokoro-82M
working_dir: /app/Kokoro-82M
command: >
sh -c "
mkdir -p /app/Kokoro-82M;
cd /app/Kokoro-82M;
rm -f .git/index.lock;
if [ -z \"$(ls -A .)\" ]; then
git clone https://huggingface.co/hexgrad/Kokoro-82M .
touch .cloned;
else
rm -f .git/index.lock && \
git checkout main && \
git pull origin main && \
touch .cloned;
fi;
tail -f /dev/null
"
healthcheck:
test: ["CMD", "test", "-f", ".cloned"]
interval: 5s
timeout: 2s
retries: 300
start_period: 1s
kokoro-tts:
image: ghcr.io/remsky/kokoro-fastapi-cpu:v0.0.5post1
# Uncomment below (and comment out above) to build from source instead of using the released image
build:
context: .
dockerfile: Dockerfile.cpu
volumes:
- ./api/src:/app/api/src
- ./Kokoro-82M:/app/Kokoro-82M
ports:
- "8880:8880"
environment:
- PYTHONPATH=/app:/app/Kokoro-82M
# ONNX Optimization Settings for vectorized operations
- ONNX_NUM_THREADS=8 # Maximize core usage for vectorized ops
- ONNX_INTER_OP_THREADS=4 # Higher inter-op for parallel matrix operations
- ONNX_EXECUTION_MODE=parallel
- ONNX_OPTIMIZATION_LEVEL=all
- ONNX_MEMORY_PATTERN=true
- ONNX_ARENA_EXTEND_STRATEGY=kNextPowerOfTwo
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8880/health"]
interval: 10s
timeout: 5s
retries: 30
start_period: 30s
depends_on:
model-fetcher:
condition: service_healthy
# Gradio UI service [Comment out everything below if you don't need it]
gradio-ui:
image: ghcr.io/remsky/kokoro-fastapi-ui:v0.0.5post1
# Uncomment below (and comment out above) to build from source instead of using the released image
# build:
# context: ./ui
ports:
- "7860:7860"
volumes:
- ./ui/data:/app/ui/data
- ./ui/app.py:/app/app.py # Mount app.py for hot reload
environment:
- GRADIO_WATCH=True # Enable hot reloading
- PYTHONUNBUFFERED=1 # Ensure Python output is not buffered
depends_on:
kokoro-tts:
condition: service_healthy

View file

@ -1,80 +0,0 @@
name: kokoro-fastapi
services:
model-fetcher:
image: datamachines/git-lfs:latest
environment:
- SKIP_MODEL_FETCH=${SKIP_MODEL_FETCH:-false}
volumes:
- ./Kokoro-82M:/app/Kokoro-82M
working_dir: /app/Kokoro-82M
command: >
sh -c "
if [ \"$$SKIP_MODEL_FETCH\" = \"true\" ]; then
echo 'Skipping model fetch...' && touch .cloned;
else
rm -f .git/index.lock;
if [ -z \"$(ls -A .)\" ]; then
git clone https://huggingface.co/hexgrad/Kokoro-82M .
touch .cloned;
else
rm -f .git/index.lock && \
git checkout main && \
git pull origin main && \
touch .cloned;
fi;
fi;
tail -f /dev/null
"
healthcheck:
test: ["CMD", "test", "-f", ".cloned"]
interval: 5s
timeout: 2s
retries: 300
start_period: 1s
kokoro-tts:
image: ghcr.io/remsky/kokoro-fastapi-gpu:v0.0.5post1
# Uncomment below (and comment out above) to build from source instead of using the released image
# build:
# context: .
volumes:
- ./api/src:/app/api/src
- ./Kokoro-82M:/app/Kokoro-82M
ports:
- "8880:8880"
environment:
- PYTHONPATH=/app:/app/Kokoro-82M
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8880/health"]
interval: 10s
timeout: 5s
retries: 30
start_period: 30s
depends_on:
model-fetcher:
condition: service_healthy
# Gradio UI service [Comment out everything below if you don't need it]
gradio-ui:
image: ghcr.io/remsky/kokoro-fastapi-ui:v0.0.5post1
# Uncomment below (and comment out above) to build from source instead of using the released image
# build:
# context: ./ui
ports:
- "7860:7860"
volumes:
- ./ui/data:/app/ui/data
- ./ui/app.py:/app/app.py # Mount app.py for hot reload
environment:
- GRADIO_WATCH=True # Enable hot reloading
- PYTHONUNBUFFERED=1 # Ensure Python output is not buffered
depends_on:
kokoro-tts:
condition: service_healthy

docker/cpu/Dockerfile Normal file (62 lines)
View file

@ -0,0 +1,62 @@
FROM python:3.10-slim
# Install dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
espeak-ng \
git \
libsndfile1 \
curl \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
# Install uv
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
# Create non-root user
RUN useradd -m -u 1000 appuser
# Create directories and set ownership
RUN mkdir -p /app/api/model_files && \
mkdir -p /app/api/src/voices && \
chown -R appuser:appuser /app
USER appuser
# Download and extract models
WORKDIR /app/api/model_files
RUN curl -L -o model.tar.gz https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.0.1/kokoro-82m-onnx.tar.gz && \
tar xzf model.tar.gz && \
rm model.tar.gz
# Download and extract voice models
WORKDIR /app/api/src/voices
RUN curl -L -o voices.tar.gz https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.0.1/voice-models.tar.gz && \
tar xzf voices.tar.gz && \
rm voices.tar.gz
# Switch back to app directory
WORKDIR /app
# Copy dependency files
COPY --chown=appuser:appuser pyproject.toml ./pyproject.toml
# Install dependencies
RUN --mount=type=cache,target=/root/.cache/uv \
uv venv && \
uv sync --extra cpu --no-install-project
# Copy project files
COPY --chown=appuser:appuser api ./api
# Install project
RUN --mount=type=cache,target=/root/.cache/uv \
uv sync --extra cpu
# Set environment variables
ENV PYTHONUNBUFFERED=1
ENV PYTHONPATH=/app:/app/Kokoro-82M
ENV PATH="/app/.venv/bin:$PATH"
ENV UV_LINK_MODE=copy
# Run FastAPI server
CMD ["uv", "run", "python", "-m", "uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]

View file

@ -0,0 +1,37 @@
name: kokoro-tts
services:
kokoro-tts:
image: ghcr.io/remsky/kokoro-fastapi-cpu:latest
# Uncomment below (and comment out above) to build from source instead of using the released image
# build:
# context: ../..
# dockerfile: docker/cpu/Dockerfile
volumes:
- ../../api/src:/app/api/src
ports:
- "8880:8880"
environment:
- PYTHONPATH=/app:/app/Kokoro-82M
# ONNX Optimization Settings for vectorized operations
- ONNX_NUM_THREADS=8 # Maximize core usage for vectorized ops
- ONNX_INTER_OP_THREADS=4 # Higher inter-op for parallel matrix operations
- ONNX_EXECUTION_MODE=parallel
- ONNX_OPTIMIZATION_LEVEL=all
- ONNX_MEMORY_PATTERN=true
- ONNX_ARENA_EXTEND_STRATEGY=kNextPowerOfTwo
# Gradio UI service [Comment out everything below if you don't need it]
gradio-ui:
image: ghcr.io/remsky/kokoro-fastapi:latest-ui
# Uncomment below (and comment out above) to build from source instead of using the released image
# build:
# context: ../../ui
ports:
- "7860:7860"
volumes:
- ../../ui/data:/app/ui/data
- ../../ui/app.py:/app/app.py # Mount app.py for hot reload
environment:
- GRADIO_WATCH=True # Enable hot reloading
- PYTHONUNBUFFERED=1 # Ensure Python output is not buffered
- DISABLE_LOCAL_SAVING=false # Set to 'true' to disable local saving and hide file view

docker/cpu/pyproject.toml Normal file (22 lines)
View file

@ -0,0 +1,22 @@
[project]
name = "kokoro-fastapi-cpu"
version = "0.1.0"
description = "FastAPI TTS Service - CPU Version"
readme = "../README.md"
requires-python = ">=3.10"
dependencies = [
# Core ML/DL for CPU
"torch>=2.5.1",
"transformers==4.47.1",
]
[tool.uv.workspace]
members = ["../shared"]
[tool.uv.sources]
torch = { index = "pytorch-cpu" }
[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true

View file

@ -0,0 +1,229 @@
# This file was autogenerated by uv via the following command:
# uv pip compile pyproject.toml ../shared/pyproject.toml --output-file requirements.lock
aiofiles==23.2.1
# via kokoro-fastapi (../shared/pyproject.toml)
annotated-types==0.7.0
# via pydantic
anyio==4.8.0
# via starlette
attrs==24.3.0
# via
# clldutils
# csvw
# jsonschema
# phonemizer
# referencing
babel==2.16.0
# via csvw
certifi==2024.12.14
# via requests
cffi==1.17.1
# via soundfile
charset-normalizer==3.4.1
# via requests
click==8.1.8
# via
# kokoro-fastapi (../shared/pyproject.toml)
# uvicorn
clldutils==3.21.0
# via segments
colorama==0.4.6
# via
# click
# colorlog
# csvw
# loguru
# tqdm
coloredlogs==15.0.1
# via onnxruntime
colorlog==6.9.0
# via clldutils
csvw==3.5.1
# via segments
dlinfo==1.2.1
# via phonemizer
exceptiongroup==1.2.2
# via anyio
fastapi==0.115.6
# via kokoro-fastapi (../shared/pyproject.toml)
filelock==3.16.1
# via
# huggingface-hub
# torch
# transformers
flatbuffers==24.12.23
# via onnxruntime
fsspec==2024.12.0
# via
# huggingface-hub
# torch
greenlet==3.1.1
# via sqlalchemy
h11==0.14.0
# via uvicorn
huggingface-hub==0.27.1
# via
# tokenizers
# transformers
humanfriendly==10.0
# via coloredlogs
idna==3.10
# via
# anyio
# requests
isodate==0.7.2
# via
# csvw
# rdflib
jinja2==3.1.5
# via torch
joblib==1.4.2
# via phonemizer
jsonschema==4.23.0
# via csvw
jsonschema-specifications==2024.10.1
# via jsonschema
language-tags==1.2.0
# via csvw
loguru==0.7.3
# via kokoro-fastapi (../shared/pyproject.toml)
lxml==5.3.0
# via clldutils
markdown==3.7
# via clldutils
markupsafe==3.0.2
# via
# clldutils
# jinja2
mpmath==1.3.0
# via sympy
munch==4.0.0
# via kokoro-fastapi (../shared/pyproject.toml)
networkx==3.4.2
# via torch
numpy==2.2.1
# via
# kokoro-fastapi (../shared/pyproject.toml)
# onnxruntime
# scipy
# soundfile
# transformers
onnxruntime==1.20.1
# via kokoro-fastapi (../shared/pyproject.toml)
packaging==24.2
# via
# huggingface-hub
# onnxruntime
# transformers
phonemizer==3.3.0
# via kokoro-fastapi (../shared/pyproject.toml)
protobuf==5.29.3
# via onnxruntime
pycparser==2.22
# via cffi
pydantic==2.10.4
# via
# kokoro-fastapi (../shared/pyproject.toml)
# fastapi
# pydantic-settings
pydantic-core==2.27.2
# via pydantic
pydantic-settings==2.7.0
# via kokoro-fastapi (../shared/pyproject.toml)
pylatexenc==2.10
# via clldutils
pyparsing==3.2.1
# via rdflib
pyreadline3==3.5.4
# via humanfriendly
python-dateutil==2.9.0.post0
# via
# clldutils
# csvw
python-dotenv==1.0.1
# via
# kokoro-fastapi (../shared/pyproject.toml)
# pydantic-settings
pyyaml==6.0.2
# via
# huggingface-hub
# transformers
rdflib==7.1.2
# via csvw
referencing==0.35.1
# via
# jsonschema
# jsonschema-specifications
regex==2024.11.6
# via
# kokoro-fastapi (../shared/pyproject.toml)
# segments
# tiktoken
# transformers
requests==2.32.3
# via
# kokoro-fastapi (../shared/pyproject.toml)
# csvw
# huggingface-hub
# tiktoken
# transformers
rfc3986==1.5.0
# via csvw
rpds-py==0.22.3
# via
# jsonschema
# referencing
safetensors==0.5.2
# via transformers
scipy==1.14.1
# via kokoro-fastapi (../shared/pyproject.toml)
segments==2.2.1
# via phonemizer
six==1.17.0
# via python-dateutil
sniffio==1.3.1
# via anyio
soundfile==0.13.0
# via kokoro-fastapi (../shared/pyproject.toml)
sqlalchemy==2.0.27
# via kokoro-fastapi (../shared/pyproject.toml)
starlette==0.41.3
# via fastapi
sympy==1.13.1
# via
# onnxruntime
# torch
tabulate==0.9.0
# via clldutils
tiktoken==0.8.0
# via kokoro-fastapi (../shared/pyproject.toml)
tokenizers==0.21.0
# via transformers
torch==2.5.1+cpu
# via kokoro-fastapi-cpu (pyproject.toml)
tqdm==4.67.1
# via
# kokoro-fastapi (../shared/pyproject.toml)
# huggingface-hub
# transformers
transformers==4.47.1
# via kokoro-fastapi-cpu (pyproject.toml)
typing-extensions==4.12.2
# via
# anyio
# fastapi
# huggingface-hub
# phonemizer
# pydantic
# pydantic-core
# sqlalchemy
# torch
# uvicorn
uritemplate==4.1.1
# via csvw
urllib3==2.3.0
# via requests
uvicorn==0.34.0
# via kokoro-fastapi (../shared/pyproject.toml)
win32-setctime==1.2.0
# via loguru

docker/cpu/uv.lock generated Normal file (1841 lines)

File diff suppressed because it is too large.

docker/gpu/Dockerfile Normal file (64 lines)
View file

@ -0,0 +1,64 @@
FROM nvidia/cuda:12.1.0-base-ubuntu22.04
# Install Python and other dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
python3.10 \
python3.10-venv \
espeak-ng \
git \
libsndfile1 \
curl \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
# Install uv
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
# Create non-root user
RUN useradd -m -u 1000 appuser
# Create directories and set ownership
RUN mkdir -p /app/api/model_files && \
mkdir -p /app/api/src/voices && \
chown -R appuser:appuser /app
USER appuser
# Download and extract models
WORKDIR /app/api/model_files
RUN curl -L -o model.tar.gz https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.0.1/kokoro-82m-pytorch.tar.gz && \
tar xzf model.tar.gz && \
rm model.tar.gz
# Download and extract voice models
WORKDIR /app/api/src/voices
RUN curl -L -o voices.tar.gz https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.0.1/voice-models.tar.gz && \
tar xzf voices.tar.gz && \
rm voices.tar.gz
# Switch back to app directory
WORKDIR /app
# Copy dependency files
COPY --chown=appuser:appuser pyproject.toml ./pyproject.toml
# Install dependencies
RUN --mount=type=cache,target=/root/.cache/uv \
uv venv && \
uv sync --extra gpu --no-install-project
# Copy project files
COPY --chown=appuser:appuser api ./api
# Install project
RUN --mount=type=cache,target=/root/.cache/uv \
uv sync --extra gpu
# Set environment variables
ENV PYTHONUNBUFFERED=1
ENV PYTHONPATH=/app:/app/Kokoro-82M
ENV PATH="/app/.venv/bin:$PATH"
ENV UV_LINK_MODE=copy
# Run FastAPI server
CMD ["uv", "run", "python", "-m", "uvicorn", "api.src.main:app", "--host", "0.0.0.0", "--port", "8880", "--log-level", "debug"]

View file

@ -0,0 +1,37 @@
name: kokoro-tts
services:
kokoro-tts:
image: ghcr.io/remsky/kokoro-fastapi-gpu:latest
# Uncomment below (and comment out above) to build from source instead of using the released image
# build:
# context: ../..
# dockerfile: docker/gpu/Dockerfile
volumes:
- ../../api/src:/app/api/src # Mount src for development
ports:
- "8880:8880"
environment:
- PYTHONPATH=/app:/app/Kokoro-82M
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
# Gradio UI service
gradio-ui:
image: ghcr.io/remsky/kokoro-fastapi-ui:latest
# Uncomment below to build from source instead of using the released image
# build:
# context: ../../ui
ports:
- "7860:7860"
volumes:
- ../../ui/data:/app/ui/data
- ../../ui/app.py:/app/app.py # Mount app.py for hot reload
environment:
- GRADIO_WATCH=1 # Enable hot reloading
- PYTHONUNBUFFERED=1 # Ensure Python output is not buffered
- DISABLE_LOCAL_SAVING=false # Set to 'true' to disable local saving and hide file view

docker/gpu/pyproject.toml Normal file (22 lines)
View file

@ -0,0 +1,22 @@
[project]
name = "kokoro-fastapi-gpu"
version = "0.1.0"
description = "FastAPI TTS Service - GPU Version"
readme = "../README.md"
requires-python = ">=3.10"
dependencies = [
# Core ML/DL for GPU
"torch==2.5.1+cu121",
"transformers==4.47.1",
]
[tool.uv.workspace]
members = ["../shared"]
[tool.uv.sources]
torch = { index = "pytorch-cuda" }
[[tool.uv.index]]
name = "pytorch-cuda"
url = "https://download.pytorch.org/whl/cu121"
explicit = true

View file

@ -0,0 +1,229 @@
# This file was autogenerated by uv via the following command:
# uv pip compile pyproject.toml ../shared/pyproject.toml --output-file requirements.lock
aiofiles==23.2.1
# via kokoro-fastapi (../shared/pyproject.toml)
annotated-types==0.7.0
# via pydantic
anyio==4.8.0
# via starlette
attrs==24.3.0
# via
# clldutils
# csvw
# jsonschema
# phonemizer
# referencing
babel==2.16.0
# via csvw
certifi==2024.12.14
# via requests
cffi==1.17.1
# via soundfile
charset-normalizer==3.4.1
# via requests
click==8.1.8
# via
# kokoro-fastapi (../shared/pyproject.toml)
# uvicorn
clldutils==3.21.0
# via segments
colorama==0.4.6
# via
# click
# colorlog
# csvw
# loguru
# tqdm
coloredlogs==15.0.1
# via onnxruntime
colorlog==6.9.0
# via clldutils
csvw==3.5.1
# via segments
dlinfo==1.2.1
# via phonemizer
exceptiongroup==1.2.2
# via anyio
fastapi==0.115.6
# via kokoro-fastapi (../shared/pyproject.toml)
filelock==3.16.1
# via
# huggingface-hub
# torch
# transformers
flatbuffers==24.12.23
# via onnxruntime
fsspec==2024.12.0
# via
# huggingface-hub
# torch
greenlet==3.1.1
# via sqlalchemy
h11==0.14.0
# via uvicorn
huggingface-hub==0.27.1
# via
# tokenizers
# transformers
humanfriendly==10.0
# via coloredlogs
idna==3.10
# via
# anyio
# requests
isodate==0.7.2
# via
# csvw
# rdflib
jinja2==3.1.5
# via torch
joblib==1.4.2
# via phonemizer
jsonschema==4.23.0
# via csvw
jsonschema-specifications==2024.10.1
# via jsonschema
language-tags==1.2.0
# via csvw
loguru==0.7.3
# via kokoro-fastapi (../shared/pyproject.toml)
lxml==5.3.0
# via clldutils
markdown==3.7
# via clldutils
markupsafe==3.0.2
# via
# clldutils
# jinja2
mpmath==1.3.0
# via sympy
munch==4.0.0
# via kokoro-fastapi (../shared/pyproject.toml)
networkx==3.4.2
# via torch
numpy==2.2.1
# via
# kokoro-fastapi (../shared/pyproject.toml)
# onnxruntime
# scipy
# soundfile
# transformers
onnxruntime==1.20.1
# via kokoro-fastapi (../shared/pyproject.toml)
packaging==24.2
# via
# huggingface-hub
# onnxruntime
# transformers
phonemizer==3.3.0
# via kokoro-fastapi (../shared/pyproject.toml)
protobuf==5.29.3
# via onnxruntime
pycparser==2.22
# via cffi
pydantic==2.10.4
# via
# kokoro-fastapi (../shared/pyproject.toml)
# fastapi
# pydantic-settings
pydantic-core==2.27.2
# via pydantic
pydantic-settings==2.7.0
# via kokoro-fastapi (../shared/pyproject.toml)
pylatexenc==2.10
# via clldutils
pyparsing==3.2.1
# via rdflib
pyreadline3==3.5.4
# via humanfriendly
python-dateutil==2.9.0.post0
# via
# clldutils
# csvw
python-dotenv==1.0.1
# via
# kokoro-fastapi (../shared/pyproject.toml)
# pydantic-settings
pyyaml==6.0.2
# via
# huggingface-hub
# transformers
rdflib==7.1.2
# via csvw
referencing==0.35.1
# via
# jsonschema
# jsonschema-specifications
regex==2024.11.6
# via
# kokoro-fastapi (../shared/pyproject.toml)
# segments
# tiktoken
# transformers
requests==2.32.3
# via
# kokoro-fastapi (../shared/pyproject.toml)
# csvw
# huggingface-hub
# tiktoken
# transformers
rfc3986==1.5.0
# via csvw
rpds-py==0.22.3
# via
# jsonschema
# referencing
safetensors==0.5.2
# via transformers
scipy==1.14.1
# via kokoro-fastapi (../shared/pyproject.toml)
segments==2.2.1
# via phonemizer
six==1.17.0
# via python-dateutil
sniffio==1.3.1
# via anyio
soundfile==0.13.0
# via kokoro-fastapi (../shared/pyproject.toml)
sqlalchemy==2.0.27
# via kokoro-fastapi (../shared/pyproject.toml)
starlette==0.41.3
# via fastapi
sympy==1.13.1
# via
# onnxruntime
# torch
tabulate==0.9.0
# via clldutils
tiktoken==0.8.0
# via kokoro-fastapi (../shared/pyproject.toml)
tokenizers==0.21.0
# via transformers
torch==2.5.1+cu121
# via kokoro-fastapi-gpu (pyproject.toml)
tqdm==4.67.1
# via
# kokoro-fastapi (../shared/pyproject.toml)
# huggingface-hub
# transformers
transformers==4.47.1
# via kokoro-fastapi-gpu (pyproject.toml)
typing-extensions==4.12.2
# via
# anyio
# fastapi
# huggingface-hub
# phonemizer
# pydantic
# pydantic-core
# sqlalchemy
# torch
# uvicorn
uritemplate==4.1.1
# via csvw
urllib3==2.3.0
# via requests
uvicorn==0.34.0
# via kokoro-fastapi (../shared/pyproject.toml)
win32-setctime==1.2.0
# via loguru

docker/gpu/uv.lock generated Normal file (1914 lines)

File diff suppressed because it is too large.

View file

@ -0,0 +1,44 @@
[project]
name = "kokoro-fastapi"
version = "0.1.0"
description = "FastAPI TTS Service"
readme = "../README.md"
requires-python = ">=3.10"
dependencies = [
# Core dependencies
"fastapi==0.115.6",
"uvicorn==0.34.0",
"click>=8.0.0",
"pydantic==2.10.4",
"pydantic-settings==2.7.0",
"python-dotenv==1.0.1",
"sqlalchemy==2.0.27",
# ML/DL Base
"numpy>=1.26.0",
"scipy==1.14.1",
"onnxruntime==1.20.1",
# Audio processing
"soundfile==0.13.0",
# Text processing
"phonemizer==3.3.0",
"regex==2024.11.6",
# Utilities
"aiofiles==23.2.1",
"tqdm==4.67.1",
"requests==2.32.3",
"munch==4.0.0",
"tiktoken==0.8.0",
"loguru==0.7.3",
]
[project.optional-dependencies]
test = [
"pytest==8.0.0",
"httpx==0.26.0",
"pytest-asyncio==0.23.5",
"ruff==0.9.1",
]

View file

@ -9,7 +9,7 @@ sqlalchemy==2.0.27
# ML/DL
transformers==4.47.1
numpy==2.2.1
numpy>=1.26.0 # Version managed by PyTorch dependencies
scipy==1.14.1
onnxruntime==1.20.1
@ -21,7 +21,7 @@ phonemizer==3.3.0
regex==2024.11.6
# Utilities
aiofiles==24.1.0
aiofiles==23.2.1 # Last version before Windows path handling changes
tqdm==4.67.1
requests==2.32.3
munch==4.0.0

docs/requirements.txt Normal file (243 lines)
View file

@ -0,0 +1,243 @@
# This file was autogenerated by uv via the following command:
# uv pip compile docs/requirements.in --universal --output-file docs/requirements.txt
aiofiles==23.2.1
# via -r docs/requirements.in
annotated-types==0.7.0
# via pydantic
anyio==4.8.0
# via
# httpx
# starlette
attrs==24.3.0
# via
# clldutils
# csvw
# jsonschema
# phonemizer
# referencing
babel==2.16.0
# via csvw
certifi==2024.12.14
# via
# httpcore
# httpx
# requests
cffi==1.17.1
# via soundfile
charset-normalizer==3.4.1
# via requests
click==8.1.8
# via uvicorn
clldutils==3.21.0
# via segments
colorama==0.4.6
# via
# click
# colorlog
# csvw
# loguru
# pytest
# tqdm
coloredlogs==15.0.1
# via onnxruntime
colorlog==6.9.0
# via clldutils
csvw==3.5.1
# via segments
dlinfo==1.2.1
# via phonemizer
exceptiongroup==1.2.2 ; python_full_version < '3.11'
# via
# anyio
# pytest
fastapi==0.115.6
# via -r docs/requirements.in
filelock==3.16.1
# via
# huggingface-hub
# transformers
flatbuffers==24.12.23
# via onnxruntime
fsspec==2024.12.0
# via huggingface-hub
greenlet==3.1.1 ; platform_machine == 'AMD64' or platform_machine == 'WIN32' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'ppc64le' or platform_machine == 'win32' or platform_machine == 'x86_64'
# via sqlalchemy
h11==0.14.0
# via
# httpcore
# uvicorn
httpcore==1.0.7
# via httpx
httpx==0.26.0
# via -r docs/requirements.in
huggingface-hub==0.27.1
# via
# tokenizers
# transformers
humanfriendly==10.0
# via coloredlogs
idna==3.10
# via
# anyio
# httpx
# requests
iniconfig==2.0.0
# via pytest
isodate==0.7.2
# via
# csvw
# rdflib
joblib==1.4.2
# via phonemizer
jsonschema==4.23.0
# via csvw
jsonschema-specifications==2024.10.1
# via jsonschema
language-tags==1.2.0
# via csvw
loguru==0.7.3
# via -r docs/requirements.in
lxml==5.3.0
# via clldutils
markdown==3.7
# via clldutils
markupsafe==3.0.2
# via clldutils
mpmath==1.3.0
# via sympy
munch==4.0.0
# via -r docs/requirements.in
numpy==2.2.1
# via
# -r docs/requirements.in
# onnxruntime
# scipy
# soundfile
# transformers
onnxruntime==1.20.1
# via -r docs/requirements.in
packaging==24.2
# via
# huggingface-hub
# onnxruntime
# pytest
# transformers
phonemizer==3.3.0
# via -r docs/requirements.in
pluggy==1.5.0
# via pytest
protobuf==5.29.3
# via onnxruntime
pycparser==2.22
# via cffi
pydantic==2.10.4
# via
# -r docs/requirements.in
# fastapi
# pydantic-settings
pydantic-core==2.27.2
# via pydantic
pydantic-settings==2.7.0
# via -r docs/requirements.in
pylatexenc==2.10
# via clldutils
pyparsing==3.2.1
# via rdflib
pyreadline3==3.5.4 ; sys_platform == 'win32'
# via humanfriendly
pytest==8.0.0
# via
# -r docs/requirements.in
# pytest-asyncio
pytest-asyncio==0.23.5
# via -r docs/requirements.in
python-dateutil==2.9.0.post0
# via
# clldutils
# csvw
python-dotenv==1.0.1
# via
# -r docs/requirements.in
# pydantic-settings
pyyaml==6.0.2
# via
# huggingface-hub
# transformers
rdflib==7.1.2
# via csvw
referencing==0.35.1
# via
# jsonschema
# jsonschema-specifications
regex==2024.11.6
# via
# -r docs/requirements.in
# segments
# tiktoken
# transformers
requests==2.32.3
# via
# -r docs/requirements.in
# csvw
# huggingface-hub
# tiktoken
# transformers
rfc3986==1.5.0
# via csvw
rpds-py==0.22.3
# via
# jsonschema
# referencing
safetensors==0.5.2
# via transformers
scipy==1.14.1
# via -r docs/requirements.in
segments==2.2.1
# via phonemizer
six==1.17.0
# via python-dateutil
sniffio==1.3.1
# via
# anyio
# httpx
soundfile==0.13.0
# via -r docs/requirements.in
sqlalchemy==2.0.27
# via -r docs/requirements.in
starlette==0.41.3
# via fastapi
sympy==1.13.3
# via onnxruntime
tabulate==0.9.0
# via clldutils
tiktoken==0.8.0
# via -r docs/requirements.in
tokenizers==0.21.0
# via transformers
tomli==2.2.1 ; python_full_version < '3.11'
# via pytest
tqdm==4.67.1
# via
# -r docs/requirements.in
# huggingface-hub
# transformers
transformers==4.47.1
# via -r docs/requirements.in
typing-extensions==4.12.2
# via
# anyio
# fastapi
# huggingface-hub
# phonemizer
# pydantic
# pydantic-core
# sqlalchemy
# uvicorn
uritemplate==4.1.1
# via csvw
urllib3==2.3.0
# via requests
uvicorn==0.34.0
# via -r docs/requirements.in
win32-setctime==1.2.0 ; sys_platform == 'win32'
# via loguru

View file

@ -0,0 +1,2 @@
openai>=1.0.0
pyaudio>=0.2.13

Binary file not shown.

View file

@ -8,7 +8,7 @@ import requests
import sounddevice as sd
def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"):
def play_streaming_tts(text: str, output_file: str = None, voice: str = "af_sky"):
"""Stream TTS audio and play it back in real-time"""
print("\nStarting TTS stream request...")

pyproject.toml Normal file (90 lines)
View file

@ -0,0 +1,90 @@
[project]
name = "kokoro-fastapi"
version = "0.1.0"
description = "FastAPI TTS Service"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
# Core dependencies
"fastapi==0.115.6",
"uvicorn==0.34.0",
"click>=8.0.0",
"pydantic==2.10.4",
"pydantic-settings==2.7.0",
"python-dotenv==1.0.1",
"sqlalchemy==2.0.27",
# ML/DL Base
"numpy>=1.26.0",
"scipy==1.14.1",
"onnxruntime==1.20.1",
# Audio processing
"soundfile==0.13.0",
# Text processing
"phonemizer==3.3.0",
"regex==2024.11.6",
# Utilities
"aiofiles==23.2.1",
"tqdm==4.67.1",
"requests==2.32.3",
"munch==4.0.0",
"tiktoken==0.8.0",
"loguru==0.7.3",
"transformers==4.47.1",
"openai>=1.59.6",
"ebooklib>=0.18",
"html2text>=2024.2.26",
]
[project.optional-dependencies]
gpu = [
"torch==2.5.1+cu121",
]
cpu = [
"torch==2.5.1+cpu",
]
test = [
"pytest==8.0.0",
"pytest-cov==4.1.0",
"httpx==0.26.0",
"pytest-asyncio==0.23.5",
"gradio>=5",
"openai>=1.59.6",
]
[tool.uv]
conflicts = [
[
{ extra = "cpu" },
{ extra = "gpu" },
],
]
[tool.uv.sources]
torch = [
{ index = "pytorch-cpu", extra = "cpu" },
{ index = "pytorch-cuda", extra = "gpu" },
]
[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true
[[tool.uv.index]]
name = "pytorch-cuda"
url = "https://download.pytorch.org/whl/cu121"
explicit = true
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[tool.setuptools]
package-dir = {"" = "api/src"}
packages.find = {where = ["api/src"], namespaces = true}
[tool.pytest.ini_options]
testpaths = ["api/tests", "ui/tests"]
python_files = ["test_*.py"]
addopts = "--cov=api --cov=ui --cov-report=term-missing --cov-config=.coveragerc"
asyncio_mode = "strict"

View file

@ -1,14 +0,0 @@
# Core dependencies for testing
fastapi==0.115.6
uvicorn==0.34.0
pydantic==2.10.4
pydantic-settings==2.7.0
python-dotenv==1.0.1
sqlalchemy==2.0.27
# Testing
pytest==8.0.0
httpx==0.26.0
pytest-asyncio==0.23.5
pytest-cov==6.0.0
gradio==4.19.2

View file

@ -1,9 +1,3 @@
import warnings
# Filter out Gradio Dropdown warnings about values not in choices
#TODO: Warning continues to be displayed, though it isn't breaking anything
warnings.filterwarnings('ignore', category=UserWarning, module='gradio.components.dropdown')
from lib.interface import create_interface
if __name__ == "__main__":

View file

@ -36,15 +36,18 @@ def check_api_status() -> Tuple[bool, List[str]]:
def text_to_speech(
text: str, voice_id: str, format: str, speed: float
text: str, voice_id: str | list, format: str, speed: float
) -> Optional[str]:
"""Generate speech from text using TTS API."""
if not text.strip():
return None
# Handle multiple voices
voice_str = voice_id if isinstance(voice_id, str) else "+".join(voice_id)
# Create output filename
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
output_filename = f"output_{timestamp}_voice-{voice_id}_speed-{speed}.{format}"
output_filename = f"output_{timestamp}_voice-{voice_str}_speed-{speed}.{format}"
output_path = os.path.join(OUTPUTS_DIR, output_filename)
try:
@ -53,7 +56,7 @@ def text_to_speech(
json={
"model": "kokoro",
"input": text,
"voice": voice_id,
"voice": voice_str,
"response_format": format,
"speed": float(speed),
},
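
The multi-voice handling above simply joins the selected voice names with "+" before building the request body; a small sketch (the voice names af_sky and af_bella are assumed examples):

voice_id = ["af_sky", "af_bella"]
voice_str = voice_id if isinstance(voice_id, str) else "+".join(voice_id)   # -> "af_sky+af_bella"
payload = {
    "model": "kokoro",
    "input": "Blending two voices.",
    "voice": voice_str,
    "response_format": "mp3",
    "speed": 1.0,
}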

View file

@ -5,54 +5,78 @@ import gradio as gr
from .. import files
def create_input_column() -> Tuple[gr.Column, dict]:
def create_input_column(disable_local_saving: bool = False) -> Tuple[gr.Column, dict]:
"""Create the input column with text input and file handling."""
with gr.Column(scale=1) as col:
with gr.Tabs() as tabs:
# Set first tab as selected by default
tabs.selected = 0
# Direct Input Tab
with gr.TabItem("Direct Input"):
text_input = gr.Textbox(
label="Text to speak", placeholder="Enter text here...", lines=4
)
text_submit = gr.Button("Generate Speech", variant="primary", size="lg")
text_input = gr.Textbox(
label="Text to speak", placeholder="Enter text here...", lines=4
)
# Always show file upload but handle differently based on disable_local_saving
file_upload = gr.File(
label="Upload Text File (.txt)", file_types=[".txt"]
)
if not disable_local_saving:
# Show full interface with tabs when saving is enabled
with gr.Tabs() as tabs:
# Set first tab as selected by default
tabs.selected = 0
# Direct Input Tab
with gr.TabItem("Direct Input"):
text_submit_direct = gr.Button("Generate Speech", variant="primary", size="lg")
# File Input Tab
with gr.TabItem("From File"):
# Existing files dropdown
input_files_list = gr.Dropdown(
label="Select Existing File",
choices=files.list_input_files(),
value=None,
)
# Simple file upload
file_upload = gr.File(
label="Upload Text File (.txt)", file_types=[".txt"]
)
file_preview = gr.Textbox(
label="File Content Preview", interactive=False, lines=4
)
with gr.Row():
file_submit = gr.Button(
"Generate Speech", variant="primary", size="lg"
)
clear_files = gr.Button(
"Clear Files", variant="secondary", size="lg"
# File Input Tab
with gr.TabItem("From File"):
# Existing files dropdown
input_files_list = gr.Dropdown(
label="Select Existing File",
choices=files.list_input_files(),
value=None,
)
components = {
"tabs": tabs,
"text_input": text_input,
"file_select": input_files_list,
"file_upload": file_upload,
"file_preview": file_preview,
"text_submit": text_submit,
"file_submit": file_submit,
"clear_files": clear_files,
}
file_preview = gr.Textbox(
label="File Content Preview", interactive=False, lines=4
)
with gr.Row():
file_submit = gr.Button(
"Generate Speech", variant="primary", size="lg"
)
clear_files = gr.Button(
"Clear Files", variant="secondary", size="lg"
)
else:
# Just show the generate button when saving is disabled
text_submit_direct = gr.Button("Generate Speech", variant="primary", size="lg")
tabs = None
input_files_list = None
file_preview = None
file_submit = None
clear_files = None
# Initialize components based on disable_local_saving
if disable_local_saving:
components = {
"tabs": None,
"text_input": text_input,
"text_submit": text_submit_direct,
"file_select": None,
"file_upload": file_upload, # Keep file upload even when saving is disabled
"file_preview": None,
"file_submit": None,
"clear_files": None,
}
else:
components = {
"tabs": tabs,
"text_input": text_input,
"text_submit": text_submit_direct,
"file_select": input_files_list,
"file_upload": file_upload,
"file_preview": file_preview,
"file_submit": file_submit,
"clear_files": clear_files,
}
return col, components

View file

@ -20,10 +20,10 @@ def create_model_column(voice_ids: Optional[list] = None) -> Tuple[gr.Column, di
voice_input = gr.Dropdown(
choices=voice_ids,
label="Voice",
value=voice_ids[0] if voice_ids else None, # Set default value to first item if available
label="Voice(s)",
value=voice_ids[0] if voice_ids else None,
interactive=True,
allow_custom_value=True, # Allow temporary values during updates
multiselect=True,
)
format_input = gr.Dropdown(
choices=config.AUDIO_FORMATS, label="Audio Format", value="mp3"

View file

@ -5,34 +5,43 @@ import gradio as gr
from .. import files
def create_output_column() -> Tuple[gr.Column, dict]:
def create_output_column(disable_local_saving: bool = False) -> Tuple[gr.Column, dict]:
"""Create the output column with audio player and file list."""
with gr.Column(scale=1) as col:
gr.Markdown("### Latest Output")
audio_output = gr.Audio(label="Generated Speech", type="filepath")
audio_output = gr.Audio(
label="Generated Speech",
type="filepath",
waveform_options={"waveform_color": "#4C87AB"}
)
gr.Markdown("### Generated Files")
# Initialize dropdown with empty choices first
# Create file-related components with visible=False when local saving is disabled
gr.Markdown("### Generated Files", visible=not disable_local_saving)
output_files = gr.Dropdown(
label="Previous Outputs",
choices=[],
choices=files.list_output_files() if not disable_local_saving else [],
value=None,
allow_custom_value=True,
interactive=True,
visible=not disable_local_saving,
)
# Then update choices after component creation
output_files.choices = files.list_output_files()
play_btn = gr.Button("▶️ Play Selected", size="sm")
play_btn = gr.Button(
"▶️ Play Selected",
size="sm",
visible=not disable_local_saving,
)
selected_audio = gr.Audio(
label="Selected Output", type="filepath", visible=False
label="Selected Output",
type="filepath",
visible=False, # Always initially hidden
)
clear_outputs = gr.Button(
"⚠️ Delete All Previously Generated Output Audio 🗑️",
size="sm",
variant="secondary",
visible=not disable_local_saving,
)
components = {

View file

@ -11,12 +11,14 @@ def list_input_files() -> List[str]:
def list_output_files() -> List[str]:
"""List all output audio files."""
# Just return filenames since paths will be different inside/outside container
return [
f for f in os.listdir(OUTPUTS_DIR)
"""List all output audio files, sorted by most recent first."""
files = [
os.path.join(OUTPUTS_DIR, f)
for f in os.listdir(OUTPUTS_DIR)
if any(f.endswith(ext) for ext in AUDIO_FORMATS)
]
# Sort files by modification time, most recent first
return sorted(files, key=os.path.getmtime, reverse=True)
def read_text_file(filename: str) -> str:

View file

@ -1,11 +1,12 @@
import os
import shutil
import gradio as gr
from . import api, files
def setup_event_handlers(components: dict):
def setup_event_handlers(components: dict, disable_local_saving: bool = False):
"""Set up all event handlers for the UI components."""
def refresh_status():
@ -57,27 +58,37 @@ def setup_event_handlers(components: dict):
def handle_file_upload(file):
if file is None:
return gr.update(choices=files.list_input_files())
return "" if disable_local_saving else [gr.update(choices=files.list_input_files())]
try:
# Copy file to inputs directory
filename = os.path.basename(file.name)
target_path = os.path.join(files.INPUTS_DIR, filename)
# Read the file content
with open(file.name, 'r', encoding='utf-8') as f:
text_content = f.read()
# Handle duplicate filenames
base, ext = os.path.splitext(filename)
counter = 1
while os.path.exists(target_path):
new_name = f"{base}_{counter}{ext}"
target_path = os.path.join(files.INPUTS_DIR, new_name)
counter += 1
if disable_local_saving:
# When saving is disabled, put content directly in text input
# Normalize whitespace by replacing newlines with spaces
normalized_text = ' '.join(text_content.split())
return normalized_text
else:
# When saving is enabled, save file and update dropdown
filename = os.path.basename(file.name)
target_path = os.path.join(files.INPUTS_DIR, filename)
shutil.copy2(file.name, target_path)
# Handle duplicate filenames
base, ext = os.path.splitext(filename)
counter = 1
while os.path.exists(target_path):
new_name = f"{base}_{counter}{ext}"
target_path = os.path.join(files.INPUTS_DIR, new_name)
counter += 1
shutil.copy2(file.name, target_path)
return [gr.update(choices=files.list_input_files())]
except Exception as e:
print(f"Error uploading file: {e}")
return gr.update(choices=files.list_input_files())
print(f"Error handling file: {e}")
return "" if disable_local_saving else [gr.update(choices=files.list_input_files())]
def generate_from_text(text, voice, format, speed):
"""Generate speech from direct text input"""
@ -90,18 +101,20 @@ def setup_event_handlers(components: dict):
gr.Warning("Please enter text in the input box")
return [None, gr.update(choices=files.list_output_files())]
files.save_text(text)
# Only save text if local saving is enabled
if not disable_local_saving:
files.save_text(text)
result = api.text_to_speech(text, voice, format, speed)
if result is None:
gr.Warning("Failed to generate speech. Please try again.")
return [None, gr.update(choices=files.list_output_files())]
# Update list and select the newly generated file
output_files = files.list_output_files()
last_file = output_files[-1] if output_files else None
return [
result,
gr.update(choices=output_files, value=last_file),
gr.update(
choices=files.list_output_files(), value=os.path.basename(result)
),
]
def generate_from_file(selected_file, voice, format, speed):
@ -121,19 +134,16 @@ def setup_event_handlers(components: dict):
gr.Warning("Failed to generate speech. Please try again.")
return [None, gr.update(choices=files.list_output_files())]
# Update list and select the newly generated file
output_files = files.list_output_files()
last_file = output_files[-1] if output_files else None
return [
result,
gr.update(choices=output_files, value=last_file),
gr.update(
choices=files.list_output_files(), value=os.path.basename(result)
),
]
def play_selected(filename):
if filename:
file_path = os.path.join(files.OUTPUTS_DIR, filename)
if os.path.exists(file_path):
return gr.update(value=file_path, visible=True)
def play_selected(file_path):
if file_path and os.path.exists(file_path):
return gr.update(value=file_path, visible=True)
return gr.update(visible=False)
def clear_files(voice, format, speed):
@ -165,45 +175,7 @@ def setup_event_handlers(components: dict):
outputs=[components["model"]["status_btn"], components["model"]["voice"]],
)
components["input"]["file_select"].change(
fn=handle_file_select,
inputs=[components["input"]["file_select"]],
outputs=[components["input"]["file_preview"]],
)
components["input"]["file_upload"].upload(
fn=handle_file_upload,
inputs=[components["input"]["file_upload"]],
outputs=[components["input"]["file_select"]],
)
components["output"]["play_btn"].click(
fn=play_selected,
inputs=[components["output"]["output_files"]],
outputs=[components["output"]["selected_audio"]],
)
# Connect clear files button
components["input"]["clear_files"].click(
fn=clear_files,
inputs=[
components["model"]["voice"],
components["model"]["format"],
components["model"]["speed"],
],
outputs=[
components["input"]["file_select"],
components["input"]["file_upload"],
components["input"]["file_preview"],
components["output"]["audio_output"],
components["output"]["output_files"],
components["model"]["voice"],
components["model"]["format"],
components["model"]["speed"],
],
)
# Connect submit buttons for each tab
# Connect text submit button (always present)
components["input"]["text_submit"].click(
fn=generate_from_text,
inputs=[
@ -218,26 +190,70 @@ def setup_event_handlers(components: dict):
],
)
# Connect clear outputs button
components["output"]["clear_outputs"].click(
fn=clear_outputs,
outputs=[
components["output"]["audio_output"],
components["output"]["output_files"],
components["output"]["selected_audio"],
],
)
# Only connect file-related handlers if components exist
if components["input"]["file_select"] is not None:
components["input"]["file_select"].change(
fn=handle_file_select,
inputs=[components["input"]["file_select"]],
outputs=[components["input"]["file_preview"]],
)
components["input"]["file_submit"].click(
fn=generate_from_file,
inputs=[
components["input"]["file_select"],
components["model"]["voice"],
components["model"]["format"],
components["model"]["speed"],
],
outputs=[
components["output"]["audio_output"],
components["output"]["output_files"],
],
)
if components["input"]["file_upload"] is not None:
# File upload handler - output depends on disable_local_saving
components["input"]["file_upload"].upload(
fn=handle_file_upload,
inputs=[components["input"]["file_upload"]],
outputs=[components["input"]["text_input"] if disable_local_saving else components["input"]["file_select"]],
)
if components["output"]["play_btn"] is not None:
components["output"]["play_btn"].click(
fn=play_selected,
inputs=[components["output"]["output_files"]],
outputs=[components["output"]["selected_audio"]],
)
if components["input"]["clear_files"] is not None:
components["input"]["clear_files"].click(
fn=clear_files,
inputs=[
components["model"]["voice"],
components["model"]["format"],
components["model"]["speed"],
],
outputs=[
components["input"]["file_select"],
components["input"]["file_upload"],
components["input"]["file_preview"],
components["output"]["audio_output"],
components["output"]["output_files"],
components["model"]["voice"],
components["model"]["format"],
components["model"]["speed"],
],
)
if components["output"]["clear_outputs"] is not None:
components["output"]["clear_outputs"].click(
fn=clear_outputs,
outputs=[
components["output"]["audio_output"],
components["output"]["output_files"],
components["output"]["selected_audio"],
],
)
if components["input"]["file_submit"] is not None:
components["input"]["file_submit"].click(
fn=generate_from_file,
inputs=[
components["input"]["file_select"],
components["model"]["voice"],
components["model"]["format"],
components["model"]["speed"],
],
outputs=[
components["output"]["audio_output"],
components["output"]["output_files"],
],
)

View file

@ -1,4 +1,5 @@
import gradio as gr
import os
from . import api
from .handlers import setup_event_handlers
@ -10,6 +11,9 @@ def create_interface():
# Skip initial status check - let the timer handle it
is_available, available_voices = False, []
# Check if local saving is disabled
disable_local_saving = os.getenv("DISABLE_LOCAL_SAVING", "false").lower() == "true"
with gr.Blocks(title="Kokoro TTS Demo", theme=gr.themes.Monochrome()) as demo:
gr.HTML(
value='<div style="display: flex; gap: 0;">'
@ -22,11 +26,11 @@ def create_interface():
# Main interface
with gr.Row():
# Create columns
input_col, input_components = create_input_column()
input_col, input_components = create_input_column(disable_local_saving)
model_col, model_components = create_model_column(
available_voices
) # Pass initial voices
output_col, output_components = create_output_column()
output_col, output_components = create_output_column(disable_local_saving)
# Collect all components
components = {
@ -36,7 +40,7 @@ def create_interface():
}
# Set up event handlers
setup_event_handlers(components)
setup_event_handlers(components, disable_local_saving)
# Add periodic status check with Timer
def update_status():

View file

@ -106,24 +106,54 @@ def test_get_status_html_unavailable():
def test_text_to_speech_api_params(mock_response, tmp_path):
"""Test correct API parameters are sent"""
with patch("requests.post") as mock_post, patch(
"ui.lib.api.OUTPUTS_DIR", str(tmp_path)
), patch("builtins.open", mock_open()):
mock_post.return_value = mock_response({})
api.text_to_speech("test text", "voice1", "mp3", 1.5)
test_cases = [
# Single voice as string
("voice1", "voice1"),
# Multiple voices as list
(["voice1", "voice2"], "voice1+voice2"),
# Single voice as list
(["voice1"], "voice1"),
]
mock_post.assert_called_once()
args, kwargs = mock_post.call_args
for input_voice, expected_voice in test_cases:
with patch("requests.post") as mock_post, patch(
"ui.lib.api.OUTPUTS_DIR", str(tmp_path)
), patch("builtins.open", mock_open()):
mock_post.return_value = mock_response({})
api.text_to_speech("test text", input_voice, "mp3", 1.5)
# Check request body
assert kwargs["json"] == {
"model": "kokoro",
"input": "test text",
"voice": "voice1",
"response_format": "mp3",
"speed": 1.5,
}
mock_post.assert_called_once()
args, kwargs = mock_post.call_args
# Check headers and timeout
assert kwargs["headers"] == {"Content-Type": "application/json"}
assert kwargs["timeout"] == 300
# Check request body
assert kwargs["json"] == {
"model": "kokoro",
"input": "test text",
"voice": expected_voice,
"response_format": "mp3",
"speed": 1.5,
}
# Check headers and timeout
assert kwargs["headers"] == {"Content-Type": "application/json"}
assert kwargs["timeout"] == 300
def test_text_to_speech_output_filename(mock_response, tmp_path):
"""Test output filename contains correct voice identifier"""
test_cases = [
# Single voice
("voice1", lambda f: "voice-voice1" in f),
# Multiple voices
(["voice1", "voice2"], lambda f: "voice-voice1+voice2" in f),
]
for input_voice, filename_check in test_cases:
with patch("requests.post", return_value=mock_response({})), patch(
"ui.lib.api.OUTPUTS_DIR", str(tmp_path)
), patch("builtins.open", mock_open()) as mock_file:
result = api.text_to_speech("test text", input_voice, "mp3", 1.0)
assert result is not None
assert filename_check(result), f"Expected voice pattern not found in filename: {result}"
mock_file.assert_called_once()

View file

@ -36,8 +36,10 @@ def test_model_column_default_values():
expected_choices = [(voice_id, voice_id) for voice_id in voice_ids]
assert components["voice"].choices == expected_choices
# Value is not converted to tuple format for the value property
assert components["voice"].value == voice_ids[0]
assert components["voice"].value == [voice_ids[0]]
assert components["voice"].interactive is True
assert components["voice"].multiselect is True
assert components["voice"].label == "Voice(s)"
# Test format dropdown
# Gradio Dropdown converts choices to (value, label) tuples

View file

@ -136,7 +136,7 @@ def test_interface_components_presence():
required_components = {
"Text to speak",
"Voice",
"Voice(s)",
"Audio Format",
"Speed",
"Generated Speech",

uv.lock generated Normal file (2556 lines)

File diff suppressed because it is too large.