Fix merge conflicts

Kishor Prins 2025-04-05 08:33:21 -07:00
commit b19fd1d179
56 changed files with 1729 additions and 893 deletions


@ -1,87 +0,0 @@
name: Docker Build and push
on:
push:
tags: [ 'v*.*.*' ]
paths-ignore:
- '**.md'
- 'docs/**'
workflow_dispatch:
inputs:
version:
description: 'Version to build and publish (e.g. v0.2.0)'
required: true
type: string
jobs:
prepare-release:
runs-on: ubuntu-latest
outputs:
version: ${{ steps.get-version.outputs.version }}
is_prerelease: ${{ steps.check-prerelease.outputs.is_prerelease }}
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Get version
id: get-version
run: |
if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then
echo "version=${{ inputs.version }}" >> $GITHUB_OUTPUT
else
echo "version=$(cat VERSION)" >> $GITHUB_OUTPUT
fi
- name: Check if prerelease
id: check-prerelease
run: |
echo "is_prerelease=${{ contains(steps.get-version.outputs.version, '-pre') }}" >> $GITHUB_OUTPUT
build-images:
needs: prepare-release
runs-on: ubuntu-latest
permissions:
packages: write
env:
DOCKER_BUILDKIT: 1
BUILDKIT_STEP_LOG_MAX_SIZE: 10485760
# This environment variable will override the VERSION variable in your HCL file.
VERSION: ${{ needs.prepare-release.outputs.version }}
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Free disk space
run: |
echo "Listing current disk space"
df -h
echo "Cleaning up disk space..."
sudo rm -rf /usr/share/dotnet
sudo rm -rf /usr/local/lib/android
sudo rm -rf /opt/ghc
sudo rm -rf /opt/hostedtoolcache
docker system prune -af
echo "Disk space after cleanup"
df -h
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
with:
driver-opts: |
image=moby/buildkit:latest
network=host
- name: Log in to GitHub Container Registry
uses: docker/login-action@v2
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Build and push images
run: |
# No need to override VERSION via --set; the env var does the job.
docker buildx bake --push


@ -1,102 +0,0 @@
name: Docker Build and Publish
on:
push:
tags: [ 'v*.*.*' ]
paths-ignore:
- '**.md'
- 'docs/**'
workflow_dispatch:
inputs:
version:
description: 'Version to release (e.g. v0.2.0)'
required: true
type: string
jobs:
prepare-release:
runs-on: ubuntu-latest
outputs:
version: ${{ steps.get-version.outputs.version }}
is_prerelease: ${{ steps.check-prerelease.outputs.is_prerelease }}
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Get version
id: get-version
run: |
if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then
echo "version=${{ inputs.version }}" >> $GITHUB_OUTPUT
else
echo "version=$(cat VERSION)" >> $GITHUB_OUTPUT
fi
- name: Check if prerelease
id: check-prerelease
run: echo "is_prerelease=${{ contains(steps.get-version.outputs.version, '-pre') }}" >> $GITHUB_OUTPUT
build-images:
needs: prepare-release
runs-on: ubuntu-latest
permissions:
packages: write
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Free disk space
run: |
echo "Listing current disk space"
df -h
echo "Cleaning up disk space..."
sudo rm -rf /usr/share/dotnet
sudo rm -rf /usr/local/lib/android
sudo rm -rf /opt/ghc
sudo rm -rf /opt/hostedtoolcache
docker system prune -af
echo "Disk space after cleanup"
df -h
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
with:
driver-opts: |
image=moby/buildkit:latest
network=host
- name: Log in to GitHub Container Registry
uses: docker/login-action@v2
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Build and push images
env:
DOCKER_BUILDKIT: 1
BUILDKIT_STEP_LOG_MAX_SIZE: 10485760
VERSION: ${{ needs.prepare-release.outputs.version }}
run: docker buildx bake --push
create-release:
needs: [prepare-release, build-images]
runs-on: ubuntu-latest
permissions:
contents: write
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Create Release
uses: softprops/action-gh-release@v1
with:
tag_name: ${{ needs.prepare-release.outputs.version }}
generate_release_notes: true
draft: true
prerelease: ${{ needs.prepare-release.outputs.is_prerelease }}

.github/workflows/release.yml

@ -0,0 +1,110 @@
name: Create Release and Publish Docker Images
on:
push:
branches:
- release # Trigger when commits are pushed to the release branch (e.g., after merging master)
paths-ignore:
- '**.md'
- 'docs/**'
jobs:
prepare-release:
runs-on: ubuntu-latest
outputs:
version: ${{ steps.get-version.outputs.version }}
version_tag: ${{ steps.get-version.outputs.version_tag }}
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Get version from VERSION file
id: get-version
run: |
VERSION_PLAIN=$(cat VERSION)
echo "version=${VERSION_PLAIN}" >> $GITHUB_OUTPUT
echo "version_tag=v${VERSION_PLAIN}" >> $GITHUB_OUTPUT # Add 'v' prefix for tag
build-images:
needs: prepare-release
runs-on: ubuntu-latest
permissions:
packages: write # Needed to push images to GHCR
env:
DOCKER_BUILDKIT: 1
BUILDKIT_STEP_LOG_MAX_SIZE: 10485760
# This environment variable will override the VERSION variable in docker-bake.hcl
VERSION: ${{ needs.prepare-release.outputs.version_tag }} # Use tag version (vX.Y.Z) for bake
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0 # Needed to check for existing tags
- name: Check if tag already exists
run: |
TAG_NAME="${{ needs.prepare-release.outputs.version_tag }}"
echo "Checking for existing tag: $TAG_NAME"
# Fetch tags explicitly just in case checkout didn't get them all
git fetch --tags
if git rev-parse "$TAG_NAME" >/dev/null 2>&1; then
echo "::error::Tag $TAG_NAME already exists. Please increment the version in the VERSION file."
exit 1
else
echo "Tag $TAG_NAME does not exist. Proceeding with release."
fi
- name: Free disk space # Optional: Keep as needed for large builds
run: |
echo "Listing current disk space"
df -h
echo "Cleaning up disk space..."
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache
docker system prune -af
echo "Disk space after cleanup"
df -h
- name: Set up QEMU
uses: docker/setup-qemu-action@v3 # Use v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3 # Use v3
with:
driver-opts: |
image=moby/buildkit:latest
network=host
- name: Log in to GitHub Container Registry
uses: docker/login-action@v3 # Use v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Build and push images using Docker Bake
run: |
echo "Building and pushing images for version ${{ needs.prepare-release.outputs.version_tag }}"
# The VERSION env var above sets the tag for the bake file targets
docker buildx bake --push
create-release:
needs: [prepare-release, build-images]
runs-on: ubuntu-latest
permissions:
contents: write # Needed to create releases
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0 # Fetch all history for release notes generation
- name: Create GitHub Release
uses: softprops/action-gh-release@v2 # Use v2
with:
tag_name: ${{ needs.prepare-release.outputs.version_tag }} # Use vX.Y.Z tag
name: Release ${{ needs.prepare-release.outputs.version_tag }}
generate_release_notes: true # Auto-generate release notes
draft: false # Publish immediately
prerelease: false # Mark as a stable release
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
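For maintainers, the release flow this workflow implies can be driven with plain git. A minimal sketch, assuming `master` is merged into `release`; the version string here is illustrative:

```bash
# Bump the version the workflow reads; it becomes the vX.Y.Z tag
echo "0.3.1" > VERSION
git add VERSION && git commit -m "Bump version to 0.3.1"

# Pushing the merge to the release branch triggers the workflow above;
# build-images fails early if the v0.3.1 tag already exists
git checkout release
git merge master
git push origin release
```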


@ -2,6 +2,37 @@
Notable changes to this project will be documented in this file.
## [v0.3.0] - 2025-04-04
### Added
- Apple Silicon (MPS) acceleration support for macOS users.
- Voice subtraction capability for creating unique voice effects.
- Windows PowerShell start scripts (`start-cpu.ps1`, `start-gpu.ps1`).
- Automatic model downloading integrated into all start scripts.
- Example Helm chart values for Azure AKS and Nvidia GPU Operator deployments.
- `CONTRIBUTING.md` guidelines for developers.
### Changed
- Version bump of underlying Kokoro and Misaki libraries.
- Default API port reverted to 8880.
- Docker containers now run as a non-root user for enhanced security.
- Improved text normalization for numbers, currency, and time formats.
- Updated and improved Helm chart configurations and documentation.
- Enhanced temporary file management with better error tracking.
- Web UI dependencies (Siriwave) are now served locally.
- Standardized environment variable handling across shell/PowerShell scripts.
### Fixed
- Corrected an issue preventing download links from being returned when `streaming=false`.
- Resolved errors in Windows PowerShell scripts related to virtual environment activation order.
- Addressed potential segfaults during inference.
- Fixed various Helm chart issues related to health checks, ingress, and default values.
- Corrected audio quality degradation caused by incorrect bitrate settings in some cases.
- Ensured custom phonemes provided in input text are preserved.
- Fixed a 'MediaSource' error affecting playback stability in the web player.
### Removed
- Obsolete GitHub Actions build workflow; build and publish now occur on merge to the `release` branch.
## [v0.2.0post1] - 2025-02-07
- Fix: Building Kokoro from source with adjustments, to avoid CUDA lock
- Fixed ARM64 compatibility on Spacy dep to avoid emulation slowdown

CONTRIBUTING.md

@ -0,0 +1,86 @@
# Contributing to Kokoro-FastAPI
We always appreciate community involvement in making this project better.
## Development Setup
We use `uv` for managing Python environments and dependencies, and `ruff` for linting and formatting.
1. **Clone the repository:**
```bash
git clone https://github.com/remsky/Kokoro-FastAPI.git
cd Kokoro-FastAPI
```
2. **Install `uv`:**
Follow the instructions on the [official `uv` documentation](https://docs.astral.sh/uv/install/).
3. **Create a virtual environment and install dependencies:**
It's recommended to use a virtual environment. `uv` can create one for you. Install the base dependencies along with the `test` and `cpu` extras (needed for running tests locally).
```bash
# Create and activate a virtual environment (e.g., named .venv)
uv venv
source .venv/bin/activate # On Linux/macOS
# .venv\Scripts\activate # On Windows
# Install dependencies including test requirements
uv pip install -e ".[test,cpu]"
```
*Note: If you have an NVIDIA GPU and want to test GPU-specific features locally, you can install `.[test,gpu]` instead, ensuring you have the correct CUDA toolkit installed.*
*Note: If running via `uv` locally, you will have to install espeak and handle any pathing issues that arise; the Docker images handle this automatically. A sketch of one approach is shown below.*
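A minimal sketch of the espeak setup on Debian/Ubuntu (the package name and the `PHONEMIZER_ESPEAK_LIBRARY` variable are assumptions based on a standard espeak-ng/phonemizer install; the library path varies by system):

```bash
# Install the espeak-ng engine (assumed package name on Debian/Ubuntu)
sudo apt-get install -y espeak-ng

# If the shared library is not found automatically, point phonemizer at it.
# The exact path is system-dependent; locate it with `find` or `ldconfig -p`.
export PHONEMIZER_ESPEAK_LIBRARY=/usr/lib/x86_64-linux-gnu/libespeak-ng.so.1
```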
4. **Install `ruff` (if not already installed globally):**
While `ruff` might be included via dependencies, installing it explicitly ensures you have it available.
```bash
uv pip install ruff
```
## Running Tests
Before submitting changes, please ensure all tests pass, as this is an automated requirement. The tests are run using `pytest`.
```bash
# Make sure your virtual environment is activated
uv run pytest
```
*Note: The CI workflow runs tests using `uv run pytest api/tests/ --asyncio-mode=auto --cov=api --cov-report=term-missing`. Running `uv run pytest` locally should cover the essential checks.*
## Testing with Docker Compose
In addition to local `pytest` runs, test your changes using Docker Compose to ensure they work correctly within the containerized environment. If you aren't able to test on CUDA hardware, make a note of it so another maintainer can test it.
```bash
docker compose -f docker/cpu/docker-compose.yml up --build
# and/or, if testing the GPU image:
docker compose -f docker/gpu/docker-compose.yml up --build
```
This command will build the Docker images (if they've changed) and start the services defined in the respective compose file. Verify the application starts correctly and test the relevant functionality.
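Once a container is up, a quick smoke test can confirm the API responds. A hedged sketch (the endpoint path follows the OpenAI-compatible API; the exact request fields are assumptions based on the examples in the README):

```bash
# Request a short clip from the OpenAI-compatible speech endpoint (assumed path)
curl -s http://localhost:8880/v1/audio/speech \
  -H "Content-Type: application/json" \
  -d '{"model": "kokoro", "input": "Hello world", "voice": "af_heart"}' \
  -o hello.mp3
```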
## Code Formatting and Linting
We use `ruff` to maintain code quality and consistency. Please format and lint your code before committing.
1. **Format the code:**
```bash
# Make sure your virtual environment is activated
ruff format .
```
2. **Lint the code (and apply automatic fixes):**
```bash
# Make sure your virtual environment is activated
ruff check . --fix
```
Review any changes made by `--fix` and address any remaining linting errors manually.
## Submitting Changes
1. Fork the repository and clone your fork.
2. Create a new branch for your feature or bug fix.
3. Make your changes, following the setup, testing, and formatting guidelines above.
4. Please try to keep your changes in line with the current design, and modular. Large-scale changes will take longer to review and integrate, and have less chance of being approved outright.
5. Push your branch to your fork.
6. Open a Pull Request against the `master` branch of the main repository.
Thank you for contributing!


@ -3,12 +3,12 @@
</p>
# <sub><sub>_`FastKoko`_ </sub></sub>
[![Tests](https://img.shields.io/badge/tests-69%20passed-darkgreen)]()
[![Coverage](https://img.shields.io/badge/coverage-52%25-tan)]()
[![Tests](https://img.shields.io/badge/tests-69-darkgreen)]()
[![Coverage](https://img.shields.io/badge/coverage-54%25-tan)]()
[![Try on Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Try%20on-Spaces-blue)](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero)
[![Kokoro](https://img.shields.io/badge/kokoro-v0.7.9::31a2b63-BB5420)](https://github.com/hexgrad/kokoro)
[![Misaki](https://img.shields.io/badge/misaki-v0.7.9::ebc76c2-B8860B)](https://github.com/hexgrad/misaki)
[![Kokoro](https://img.shields.io/badge/kokoro-0.9.2-BB5420)](https://github.com/hexgrad/kokoro)
[![Misaki](https://img.shields.io/badge/misaki-0.9.3-B8860B)](https://github.com/hexgrad/misaki)
[![Tested at Model Commit](https://img.shields.io/badge/last--tested--model--commit-1.0::9901c2b-blue)](https://huggingface.co/hexgrad/Kokoro-82M/commit/9901c2b79161b6e898b7ea857ae5298f47b8b0d6)
@ -24,10 +24,6 @@ Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokor
### Integration Guides
[![Helm Chart](https://img.shields.io/badge/Helm%20Chart-black?style=flat&logo=helm&logoColor=white)](https://github.com/remsky/Kokoro-FastAPI/wiki/Setup-Kubernetes) [![DigitalOcean](https://img.shields.io/badge/DigitalOcean-black?style=flat&logo=digitalocean&logoColor=white)](https://github.com/remsky/Kokoro-FastAPI/wiki/Integrations-DigitalOcean) [![SillyTavern](https://img.shields.io/badge/SillyTavern-black?style=flat&color=red)](https://github.com/remsky/Kokoro-FastAPI/wiki/Integrations-SillyTavern)
[![OpenWebUI](https://img.shields.io/badge/OpenWebUI-black?style=flat&color=white)](https://github.com/remsky/Kokoro-FastAPI/wiki/Integrations-OpenWebUi)
## Get Started
<details>
@ -38,11 +34,12 @@ Pre built images are available to run, with arm/multi-arch support, and baked in
Refer to the `core/config.py` file for a full list of variables which can be managed via the environment; an example of overriding them follows the commands below.
```bash
# the `latest` tag can be used, but should not be considered stable as it may include `nightly` branch builds
# it may have some bonus features however, and feedback/testing is welcome
# the `latest` tag can be used, though it may have some unexpected bonus features which impact stability.
# Named versions should be pinned for your regular usage.
# Feedback/testing is always welcome.
docker run -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-cpu:v0.2.2 # CPU, or:
docker run --gpus all -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-gpu:v0.2.2 #NVIDIA GPU
docker run -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-cpu:latest # CPU, or:
docker run --gpus all -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-gpu:latest #NVIDIA GPU
```
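Since `Settings` is a pydantic-settings model, any field in `core/config.py` can be overridden through the container environment; pydantic-settings matches field names case-insensitively. A minimal sketch using fields visible in this commit (the values are illustrative):

```bash
# Override Settings fields from core/config.py via environment variables
docker run -p 8880:8880 \
  -e DEFAULT_VOICE=af_heart \
  -e ENABLE_WEB_PLAYER=false \
  ghcr.io/remsky/kokoro-fastapi-cpu:latest
```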
@ -66,7 +63,7 @@ docker run --gpus all -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-gpu:v0.2.2 #NV
# *Note for Apple Silicon (M1/M2) users:
# The current GPU build relies on CUDA, which is not supported on Apple Silicon.
# If you are on an M1/M2/M3 Mac, please use the `docker/cpu` setup.
# MPS (Apples GPU acceleration) support is planned but not yet available.
# MPS (Apple's GPU acceleration) support is planned but not yet available.
# Models will auto-download, but if needed you can manually download:
python docker/scripts/download_model.py --output api/src/models/v1_0
@ -139,8 +136,8 @@ with client.audio.speech.with_streaming_response.create(
</details>
## Features
<details>
<summary>OpenAI-Compatible Speech Endpoint</summary>
@ -550,13 +547,15 @@ for chunk in response.iter_content(chunk_size=1024):
<details>
<summary>Versioning & Development</summary>
I'm doing what I can to keep things stable, but we are on an early and rapid set of build cycles here.
If you run into trouble, you may have to roll back a version on the release tags if something comes up, or build up from source and/or troubleshoot + submit a PR. Will leave the branch up here for the last known stable points:
**Branching Strategy:**
* **`release` branch:** Contains the latest stable build, recommended for production use. Docker images tagged with specific versions (e.g., `v0.3.0`) are built from this branch.
* **`master` branch:** Used for active development. It may contain experimental features, ongoing changes, or fixes not yet in a stable release. Use this branch if you want the absolute latest code, but be aware it might be less stable. The `latest` Docker tag often points to builds from this branch.
`v0.1.4`
`v0.0.5post1`
Note: This is a *development* focused project at its core.
Free and open source is a community effort, and I love working on this project, though there's only really so many hours in a day. If you'd like to support the work, feel free to open a PR, buy me a coffee, or report any bugs/features/etc you find during use.
If you run into trouble, you may have to roll back a version on the release tags if something comes up, or build up from source and/or troubleshoot + submit a PR.
Free and open source is a community effort, and there's only really so many hours in a day. If you'd like to support the work, feel free to open a PR, buy me a coffee, or report any bugs/features/etc you find during use.
<a href="https://www.buymeacoffee.com/remsky" target="_blank">
<img


@ -1 +1 @@
v0.2.1
0.3.0


@ -1,5 +1,5 @@
from pydantic_settings import BaseSettings
import torch
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
@ -14,9 +14,13 @@ class Settings(BaseSettings):
output_dir: str = "output"
output_dir_size_limit_mb: float = 500.0 # Maximum size of output directory in MB
default_voice: str = "af_heart"
default_voice_code: str | None = None # If set, overrides the first letter of voice name, though api call param still takes precedence
default_voice_code: str | None = (
None # If set, overrides the first letter of voice name, though api call param still takes precedence
)
use_gpu: bool = True # Whether to use GPU acceleration if available
device_type: str | None = None # Will be auto-detected if None, can be "cuda", "mps", or "cpu"
device_type: str | None = (
None # Will be auto-detected if None, can be "cuda", "mps", or "cpu"
)
allow_local_voice_saving: bool = (
False # Whether to allow saving combined voices locally
)
@ -31,12 +35,21 @@ class Settings(BaseSettings):
target_min_tokens: int = 175 # Target minimum tokens per chunk
target_max_tokens: int = 250 # Target maximum tokens per chunk
absolute_max_tokens: int = 450 # Absolute maximum tokens per chunk
advanced_text_normalization: bool = True # Preprocesses the text before misaki
voice_weight_normalization: bool = True # Normalize the voice weights so they add up to 1
advanced_text_normalization: bool = True  # Preprocesses the text before misaki
voice_weight_normalization: bool = (
True # Normalize the voice weights so they add up to 1
)
gap_trim_ms: int = 1 # Base amount to trim from streaming chunk ends in milliseconds
dynamic_gap_trim_padding_ms: int = 410 # Padding to add to dynamic gap trim
dynamic_gap_trim_padding_char_multiplier: dict[str,float] = {".":1,"!":0.9,"?":1,",":0.8}
gap_trim_ms: int = (
1 # Base amount to trim from streaming chunk ends in milliseconds
)
dynamic_gap_trim_padding_ms: int = 410 # Padding to add to dynamic gap trim
dynamic_gap_trim_padding_char_multiplier: dict[str, float] = {
".": 1,
"!": 0.9,
"?": 1,
",": 0.8,
}
# Web Player Settings
enable_web_player: bool = True # Whether to serve the web player UI
@ -69,5 +82,4 @@ class Settings(BaseSettings):
return "cpu"
settings = Settings()


@ -1,34 +1,41 @@
"""Base interface for Kokoro inference."""
from abc import ABC, abstractmethod
from typing import AsyncGenerator, Optional, Tuple, Union, List
from typing import AsyncGenerator, List, Optional, Tuple, Union
import numpy as np
import torch
class AudioChunk:
"""Class for audio chunks returned by model backends"""
def __init__(self,
audio: np.ndarray,
word_timestamps: Optional[List]=[],
output: Optional[Union[bytes,np.ndarray]]=b""
):
self.audio=audio
self.word_timestamps=word_timestamps
self.output=output
def __init__(
self,
audio: np.ndarray,
word_timestamps: Optional[List] = [],
output: Optional[Union[bytes, np.ndarray]] = b"",
):
self.audio = audio
self.word_timestamps = word_timestamps
self.output = output
@staticmethod
def combine(audio_chunk_list: List):
output=AudioChunk(audio_chunk_list[0].audio,audio_chunk_list[0].word_timestamps)
output = AudioChunk(
audio_chunk_list[0].audio, audio_chunk_list[0].word_timestamps
)
for audio_chunk in audio_chunk_list[1:]:
output.audio=np.concatenate((output.audio,audio_chunk.audio),dtype=np.int16)
output.audio = np.concatenate(
(output.audio, audio_chunk.audio), dtype=np.int16
)
if output.word_timestamps is not None:
output.word_timestamps+=audio_chunk.word_timestamps
output.word_timestamps += audio_chunk.word_timestamps
return output
class ModelBackend(ABC):
"""Abstract base class for model inference backend."""


@ -11,9 +11,10 @@ from loguru import logger
from ..core import paths
from ..core.config import settings
from ..core.model_config import model_config
from .base import BaseModelBackend
from .base import AudioChunk
from ..structures.schemas import WordTimestamp
from .base import AudioChunk, BaseModelBackend
class KokoroV1(BaseModelBackend):
"""Kokoro backend with controlled resource management."""
@ -50,7 +51,9 @@ class KokoroV1(BaseModelBackend):
self._model = KModel(config=config_path, model=model_path).eval()
# For MPS, manually move ISTFT layers to CPU while keeping rest on MPS
if self._device == "mps":
logger.info("Moving model to MPS device with CPU fallback for unsupported operations")
logger.info(
"Moving model to MPS device with CPU fallback for unsupported operations"
)
self._model = self._model.to(torch.device("mps"))
elif self._device == "cuda":
self._model = self._model.cuda()
@ -145,11 +148,11 @@ class KokoroV1(BaseModelBackend):
voice_path = temp_path
# Use provided lang_code, settings voice code override, or first letter of voice name
if lang_code: # api is given priority
if lang_code: # api is given priority
pipeline_lang_code = lang_code
elif settings.default_voice_code: # settings is next priority
elif settings.default_voice_code: # settings is next priority
pipeline_lang_code = settings.default_voice_code
else: # voice name is default/fallback
else: # voice name is default/fallback
pipeline_lang_code = voice_name[0].lower()
pipeline = self._get_pipeline(pipeline_lang_code)
@ -244,7 +247,15 @@ class KokoroV1(BaseModelBackend):
voice_path = temp_path
# Use provided lang_code, settings voice code override, or first letter of voice name
pipeline_lang_code = lang_code if lang_code else (settings.default_voice_code if settings.default_voice_code else voice_name[0].lower())
pipeline_lang_code = (
lang_code
if lang_code
else (
settings.default_voice_code
if settings.default_voice_code
else voice_name[0].lower()
)
)
pipeline = self._get_pipeline(pipeline_lang_code)
logger.debug(
@ -255,16 +266,19 @@ class KokoroV1(BaseModelBackend):
):
if result.audio is not None:
logger.debug(f"Got audio chunk with shape: {result.audio.shape}")
word_timestamps=None
if return_timestamps and hasattr(result, "tokens") and result.tokens:
word_timestamps=[]
current_offset=0.0
word_timestamps = None
if (
return_timestamps
and hasattr(result, "tokens")
and result.tokens
):
word_timestamps = []
current_offset = 0.0
logger.debug(
f"Processing chunk timestamps with {len(result.tokens)} tokens"
)
f"Processing chunk timestamps with {len(result.tokens)} tokens"
)
if result.pred_dur is not None:
try:
# Add timestamps with offset
for token in result.tokens:
if not all(
@ -285,7 +299,7 @@ class KokoroV1(BaseModelBackend):
WordTimestamp(
word=str(token.text).strip(),
start_time=start_time,
end_time=end_time
end_time=end_time,
)
)
logger.debug(
@ -297,8 +311,9 @@ class KokoroV1(BaseModelBackend):
f"Failed to process timestamps for chunk: {e}"
)
yield AudioChunk(result.audio.numpy(),word_timestamps=word_timestamps)
yield AudioChunk(
result.audio.numpy(), word_timestamps=word_timestamps
)
else:
logger.warning("No audio in chunk")
@ -329,7 +344,7 @@ class KokoroV1(BaseModelBackend):
torch.cuda.synchronize()
elif self._device == "mps":
# Empty cache if available (future-proofing)
if hasattr(torch.mps, 'empty_cache'):
if hasattr(torch.mps, "empty_cache"):
torch.mps.empty_cache()
def unload(self) -> None:


@ -3,8 +3,8 @@ import time
from datetime import datetime
import psutil
from fastapi import APIRouter
import torch
from fastapi import APIRouter
try:
import GPUtil
@ -115,12 +115,12 @@ async def get_system_info():
# GPU Info if available
gpu_info = None
if torch.backends.mps.is_available():
gpu_info = {
"type": "MPS",
"available": True,
"device": "Apple Silicon",
"backend": "Metal"
}
gpu_info = {
"type": "MPS",
"available": True,
"device": "Apple Silicon",
"backend": "Metal",
}
elif GPU_AVAILABLE:
try:
gpus = GPUtil.getGPUs()


@ -156,6 +156,7 @@ async def generate_from_phonemes(
},
)
@router.post("/dev/captioned_speech")
async def create_captioned_speech(
request: CaptionedSpeechRequest,
@ -184,7 +185,9 @@ async def create_captioned_speech(
# Check if streaming is requested (default for OpenAI client)
if request.stream:
# Create generator but don't start it yet
generator = stream_audio_chunks(tts_service, request, client_request, writer)
generator = stream_audio_chunks(
tts_service, request, client_request, writer
)
# If download link requested, wrap generator with temp file writer
if request.return_download_link:
@ -211,21 +214,32 @@ async def create_captioned_speech(
# Write chunks to temp file and stream
async for chunk_data in generator:
# The timestamp accumulator is only used when word-level timestamps are generated but no audio is returned.
timestamp_acumulator=[]
timestamp_acumulator = []
if chunk_data.output: # Skip empty chunks
await temp_writer.write(chunk_data.output)
base64_chunk= base64.b64encode(chunk_data.output).decode("utf-8")
base64_chunk = base64.b64encode(
chunk_data.output
).decode("utf-8")
# Add any chunks that may be in the accumulator into the returned word_timestamps
chunk_data.word_timestamps=timestamp_acumulator + chunk_data.word_timestamps
timestamp_acumulator=[]
yield CaptionedSpeechResponse(audio=base64_chunk,audio_format=content_type,timestamps=chunk_data.word_timestamps)
chunk_data.word_timestamps = (
timestamp_acumulator + chunk_data.word_timestamps
)
timestamp_acumulator = []
yield CaptionedSpeechResponse(
audio=base64_chunk,
audio_format=content_type,
timestamps=chunk_data.word_timestamps,
)
else:
if chunk_data.word_timestamps is not None and len(chunk_data.word_timestamps) > 0:
timestamp_acumulator+=chunk_data.word_timestamps
if (
chunk_data.word_timestamps is not None
and len(chunk_data.word_timestamps) > 0
):
timestamp_acumulator += chunk_data.word_timestamps
# Finalize the temp file
await temp_writer.finalize()
except Exception as e:
@ -246,26 +260,37 @@ async def create_captioned_speech(
async def single_output():
try:
# The timestamp accumulator is only used when word-level timestamps are generated but no audio is returned.
timestamp_acumulator=[]
timestamp_acumulator = []
# Stream chunks
async for chunk_data in generator:
if chunk_data.output: # Skip empty chunks
# Encode the chunk bytes into base 64
base64_chunk= base64.b64encode(chunk_data.output).decode("utf-8")
base64_chunk = base64.b64encode(chunk_data.output).decode(
"utf-8"
)
# Add any chunks that may be in the accumulator into the returned word_timestamps
if chunk_data.word_timestamps != None:
chunk_data.word_timestamps = timestamp_acumulator + chunk_data.word_timestamps
chunk_data.word_timestamps = (
timestamp_acumulator + chunk_data.word_timestamps
)
else:
chunk_data.word_timestamps = []
timestamp_acumulator=[]
yield CaptionedSpeechResponse(audio=base64_chunk,audio_format=content_type,timestamps=chunk_data.word_timestamps)
timestamp_acumulator = []
yield CaptionedSpeechResponse(
audio=base64_chunk,
audio_format=content_type,
timestamps=chunk_data.word_timestamps,
)
else:
if chunk_data.word_timestamps is not None and len(chunk_data.word_timestamps) > 0:
timestamp_acumulator+=chunk_data.word_timestamps
if (
chunk_data.word_timestamps is not None
and len(chunk_data.word_timestamps) > 0
):
timestamp_acumulator += chunk_data.word_timestamps
except Exception as e:
logger.error(f"Error in single output streaming: {e}")
writer.close()
@ -293,7 +318,7 @@ async def create_captioned_speech(
normalization_options=request.normalization_options,
lang_code=request.lang_code,
)
audio_data = await AudioService.convert_audio(
audio_data,
request.response_format,
@ -301,7 +326,7 @@ async def create_captioned_speech(
is_last_chunk=False,
trim_audio=False,
)
# Convert to requested format with proper finalization
final = await AudioService.convert_audio(
AudioChunk(np.array([], dtype=np.int16)),
@ -309,11 +334,15 @@ async def create_captioned_speech(
writer,
is_last_chunk=True,
)
output=audio_data.output + final.output
base64_output= base64.b64encode(output).decode("utf-8")
content=CaptionedSpeechResponse(audio=base64_output,audio_format=content_type,timestamps=audio_data.word_timestamps).model_dump()
output = audio_data.output + final.output
base64_output = base64.b64encode(output).decode("utf-8")
content = CaptionedSpeechResponse(
audio=base64_output,
audio_format=content_type,
timestamps=audio_data.word_timestamps,
).model_dump()
writer.close()
@ -376,4 +405,4 @@ async def create_captioned_speech(
"message": str(e),
"type": "server_error",
},
)
)
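For reference, a hedged sketch of calling the captioned speech endpoint above with `stream=false` (the mount prefix and the `input` field are assumptions; `voice`, `stream`, and the base64 `audio`/`timestamps` response fields are visible in this diff):

```bash
# Request captioned speech and print the word-level timestamps (requires jq)
curl -s http://localhost:8880/dev/captioned_speech \
  -H "Content-Type: application/json" \
  -d '{"input": "Hello world", "voice": "af_heart", "stream": false}' \
  | jq '.timestamps'
```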


@ -10,18 +10,18 @@ from urllib import response
import aiofiles
import numpy as np
from ..services.streaming_audio_writer import StreamingAudioWriter
import torch
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from fastapi.responses import FileResponse, StreamingResponse
from loguru import logger
from ..structures.schemas import CaptionedSpeechRequest
from ..core.config import settings
from ..inference.base import AudioChunk
from ..services.audio import AudioService
from ..services.streaming_audio_writer import StreamingAudioWriter
from ..services.tts_service import TTSService
from ..structures import OpenAISpeechRequest
from ..structures.schemas import CaptionedSpeechRequest
# Load OpenAI mappings
@ -80,7 +80,9 @@ def get_model_name(model: str) -> str:
return base_name + ".pth"
async def process_and_validate_voices(voice_input: Union[str, List[str]], tts_service: TTSService) -> str:
async def process_and_validate_voices(
voice_input: Union[str, List[str]], tts_service: TTSService
) -> str:
"""Process voice input, handling both string and list formats
Returns:
@ -107,22 +109,35 @@ async def process_and_validate_voices(voice_input: Union[str, List[str]], tts_se
mapped_voice = list(map(str.strip, mapped_voice))
if len(mapped_voice) > 2:
raise ValueError(f"Voice '{voices[voice_index]}' contains too many weight items")
raise ValueError(
f"Voice '{voices[voice_index]}' contains too many weight items"
)
if mapped_voice.count(")") > 1:
raise ValueError(f"Voice '{voices[voice_index]}' contains too many weight items")
raise ValueError(
f"Voice '{voices[voice_index]}' contains too many weight items"
)
mapped_voice[0] = _openai_mappings["voices"].get(mapped_voice[0], mapped_voice[0])
mapped_voice[0] = _openai_mappings["voices"].get(
mapped_voice[0], mapped_voice[0]
)
if mapped_voice[0] not in available_voices:
raise ValueError(f"Voice '{mapped_voice[0]}' not found. Available voices: {', '.join(sorted(available_voices))}")
raise ValueError(
f"Voice '{mapped_voice[0]}' not found. Available voices: {', '.join(sorted(available_voices))}"
)
voices[voice_index] = "(".join(mapped_voice)
return "".join(voices)
async def stream_audio_chunks(tts_service: TTSService, request: Union[OpenAISpeechRequest, CaptionedSpeechRequest], client_request: Request, writer: StreamingAudioWriter) -> AsyncGenerator[AudioChunk, None]:
async def stream_audio_chunks(
tts_service: TTSService,
request: Union[OpenAISpeechRequest, CaptionedSpeechRequest],
client_request: Request,
writer: StreamingAudioWriter,
) -> AsyncGenerator[AudioChunk, None]:
"""Stream audio chunks as they're generated with client disconnect handling"""
voice_name = await process_and_validate_voices(request.voice, tts_service)
unique_properties = {"return_timestamps": False}
@ -193,7 +208,9 @@ async def create_speech(
# Check if streaming is requested (default for OpenAI client)
if request.stream:
# Create generator but don't start it yet
generator = stream_audio_chunks(tts_service, request, client_request, writer)
generator = stream_audio_chunks(
tts_service, request, client_request, writer
)
# If download link requested, wrap generator with temp file writer
if request.return_download_link:
@ -215,7 +232,7 @@ async def create_speech(
"Transfer-Encoding": "chunked",
"X-Download-Path": download_path,
}
# Add header to indicate if temp file writing is available
if temp_writer._write_error:
headers["X-Download-Status"] = "unavailable"
@ -245,7 +262,9 @@ async def create_speech(
writer.close()
# Stream with temp file writing
return StreamingResponse(dual_output(), media_type=content_type, headers=headers)
return StreamingResponse(
dual_output(), media_type=content_type, headers=headers
)
async def single_output():
try:
@ -285,7 +304,13 @@ async def create_speech(
lang_code=request.lang_code,
)
audio_data = await AudioService.convert_audio(audio_data, request.response_format, writer, is_last_chunk=False, trim_audio=False)
audio_data = await AudioService.convert_audio(
audio_data,
request.response_format,
writer,
is_last_chunk=False,
trim_audio=False,
)
# Convert to requested format with proper finalization
final = await AudioService.convert_audio(
@ -334,7 +359,7 @@ async def create_speech(
except ValueError as e:
# Handle validation errors
logger.warning(f"Invalid request: {str(e)}")
try:
writer.close()
except:
@ -382,7 +407,6 @@ async def create_speech(
"type": "server_error",
},
)
@router.get("/download/{filename}")
@ -392,7 +416,9 @@ async def download_audio_file(filename: str):
from ..core.paths import _find_file, get_content_type
# Search for file in temp directory
file_path = await _find_file(filename=filename, search_paths=[settings.temp_file_dir])
file_path = await _find_file(
filename=filename, search_paths=[settings.temp_file_dir]
)
# Get content type from path helper
content_type = await get_content_type(file_path)
@ -425,9 +451,24 @@ async def list_models():
try:
# Create standard model list
models = [
{"id": "tts-1", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
{"id": "tts-1-hd", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
{"id": "kokoro", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
{
"id": "tts-1",
"object": "model",
"created": 1686935002,
"owned_by": "kokoro",
},
{
"id": "tts-1-hd",
"object": "model",
"created": 1686935002,
"owned_by": "kokoro",
},
{
"id": "kokoro",
"object": "model",
"created": 1686935002,
"owned_by": "kokoro",
},
]
return {"object": "list", "data": models}
@ -449,14 +490,36 @@ async def retrieve_model(model: str):
try:
# Define available models
models = {
"tts-1": {"id": "tts-1", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
"tts-1-hd": {"id": "tts-1-hd", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
"kokoro": {"id": "kokoro", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
"tts-1": {
"id": "tts-1",
"object": "model",
"created": 1686935002,
"owned_by": "kokoro",
},
"tts-1-hd": {
"id": "tts-1-hd",
"object": "model",
"created": 1686935002,
"owned_by": "kokoro",
},
"kokoro": {
"id": "kokoro",
"object": "model",
"created": 1686935002,
"owned_by": "kokoro",
},
}
# Check if requested model exists
if model not in models:
raise HTTPException(status_code=404, detail={"error": "model_not_found", "message": f"Model '{model}' not found", "type": "invalid_request_error"})
raise HTTPException(
status_code=404,
detail={
"error": "model_not_found",
"message": f"Model '{model}' not found",
"type": "invalid_request_error",
},
)
# Return the specific model
return models[model]
@ -541,7 +604,9 @@ async def combine_voices(request: Union[str, List[str]]):
available_voices = await tts_service.list_voices()
for voice in voices:
if voice not in available_voices:
raise ValueError(f"Base voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}")
raise ValueError(
f"Base voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}"
)
# Combine voices
combined_tensor = await tts_service.combine_voices(voices=voices)


@ -1,12 +1,12 @@
"""Audio conversion service"""
import math
import struct
import time
from typing import Tuple
from io import BytesIO
from typing import Tuple
import numpy as np
import math
import scipy.io.wavfile as wavfile
import soundfile as sf
from loguru import logger
@ -14,8 +14,9 @@ from pydub import AudioSegment
from torch import norm
from ..core.config import settings
from .streaming_audio_writer import StreamingAudioWriter
from ..inference.base import AudioChunk
from .streaming_audio_writer import StreamingAudioWriter
class AudioNormalizer:
"""Handles audio normalization state for a single stream"""
@ -24,53 +25,78 @@ class AudioNormalizer:
self.chunk_trim_ms = settings.gap_trim_ms
self.sample_rate = 24000 # Sample rate of the audio
self.samples_to_trim = int(self.chunk_trim_ms * self.sample_rate / 1000)
self.samples_to_pad_start= int(50 * self.sample_rate / 1000)
self.samples_to_pad_start = int(50 * self.sample_rate / 1000)
def find_first_last_non_silent(self,audio_data: np.ndarray, chunk_text: str, speed: float, silence_threshold_db: int = -45, is_last_chunk: bool = False) -> tuple[int, int]:
def find_first_last_non_silent(
self,
audio_data: np.ndarray,
chunk_text: str,
speed: float,
silence_threshold_db: int = -45,
is_last_chunk: bool = False,
) -> tuple[int, int]:
"""Finds the indices of the first and last non-silent samples in audio data.
Args:
audio_data: Input audio data as numpy array
chunk_text: The text sent to the model to generate the resulting speech
speed: The speaking speed of the voice
silence_threshold_db: How quiet audio has to be to be considered silent
is_last_chunk: Whether this is the last chunk
Returns:
A tuple with the start and end indices of the non-silent portion
"""
pad_multiplier=1
split_character=chunk_text.strip()
pad_multiplier = 1
split_character = chunk_text.strip()
if len(split_character) > 0:
split_character=split_character[-1]
split_character = split_character[-1]
if split_character in settings.dynamic_gap_trim_padding_char_multiplier:
pad_multiplier=settings.dynamic_gap_trim_padding_char_multiplier[split_character]
pad_multiplier = settings.dynamic_gap_trim_padding_char_multiplier[
split_character
]
if not is_last_chunk:
samples_to_pad_end= max(int((settings.dynamic_gap_trim_padding_ms * self.sample_rate * pad_multiplier) / 1000) - self.samples_to_pad_start, 0)
samples_to_pad_end = max(
int(
(
settings.dynamic_gap_trim_padding_ms
* self.sample_rate
* pad_multiplier
)
/ 1000
)
- self.samples_to_pad_start,
0,
)
else:
samples_to_pad_end=self.samples_to_pad_start
samples_to_pad_end = self.samples_to_pad_start
# Convert dBFS threshold to amplitude
amplitude_threshold = np.iinfo(audio_data.dtype).max * (10 ** (silence_threshold_db / 20))
amplitude_threshold = np.iinfo(audio_data.dtype).max * (
10 ** (silence_threshold_db / 20)
)
# Find the first samples above the silence threshold at the start and end of the audio
non_silent_index_start, non_silent_index_end = None,None
non_silent_index_start, non_silent_index_end = None, None
for X in range(0,len(audio_data)):
for X in range(0, len(audio_data)):
if audio_data[X] > amplitude_threshold:
non_silent_index_start=X
non_silent_index_start = X
break
for X in range(len(audio_data) - 1, -1, -1):
if audio_data[X] > amplitude_threshold:
non_silent_index_end=X
non_silent_index_end = X
break
# Handle the case where the entire audio is silent
if non_silent_index_start == None or non_silent_index_end == None:
return 0, len(audio_data)
return max(non_silent_index_start - self.samples_to_pad_start,0), min(non_silent_index_end + math.ceil(samples_to_pad_end / speed),len(audio_data))
return max(non_silent_index_start - self.samples_to_pad_start, 0), min(
non_silent_index_end + math.ceil(samples_to_pad_end / speed),
len(audio_data),
)
def normalize(self, audio_data: np.ndarray) -> np.ndarray:
"""Convert audio data to int16 range
@ -85,6 +111,7 @@ class AudioNormalizer:
return np.clip(audio_data * 32767, -32768, 32767).astype(np.int16)
return audio_data
class AudioService:
"""Service for audio format conversions with streaming support"""
@ -143,28 +170,28 @@ class AudioService:
# Always normalize audio to ensure proper amplitude scaling
if normalizer is None:
normalizer = AudioNormalizer()
audio_chunk.audio = normalizer.normalize(audio_chunk.audio)
if trim_audio == True:
audio_chunk = AudioService.trim_audio(audio_chunk,chunk_text,speed,is_last_chunk,normalizer)
audio_chunk = AudioService.trim_audio(
audio_chunk, chunk_text, speed, is_last_chunk, normalizer
)
# Write audio data first
if len(audio_chunk.audio) > 0:
chunk_data = writer.write_chunk(audio_chunk.audio)
# Then finalize if this is the last chunk
if is_last_chunk:
final_data = writer.write_chunk(finalize=True)
if final_data:
audio_chunk.output=final_data
audio_chunk.output = final_data
return audio_chunk
if chunk_data:
audio_chunk.output=chunk_data
audio_chunk.output = chunk_data
return audio_chunk
except Exception as e:
@ -172,8 +199,15 @@ class AudioService:
raise ValueError(
f"Failed to convert audio stream to {output_format}: {str(e)}"
)
@staticmethod
def trim_audio(audio_chunk: AudioChunk, chunk_text: str = "", speed: float = 1, is_last_chunk: bool = False, normalizer: AudioNormalizer = None) -> AudioChunk:
def trim_audio(
audio_chunk: AudioChunk,
chunk_text: str = "",
speed: float = 1,
is_last_chunk: bool = False,
normalizer: AudioNormalizer = None,
) -> AudioChunk:
"""Trim silence from start and end
Args:
@ -182,30 +216,33 @@ class AudioService:
speed: The speaking speed of the voice
is_last_chunk: Whether this is the last chunk
normalizer: Optional AudioNormalizer instance for consistent normalization
Returns:
Trimmed audio data
"""
if normalizer is None:
normalizer = AudioNormalizer()
audio_chunk.audio=normalizer.normalize(audio_chunk.audio)
trimed_samples=0
audio_chunk.audio = normalizer.normalize(audio_chunk.audio)
trimed_samples = 0
# Trim start and end if enough samples
if len(audio_chunk.audio) > (2 * normalizer.samples_to_trim):
audio_chunk.audio = audio_chunk.audio[normalizer.samples_to_trim : -normalizer.samples_to_trim]
trimed_samples+=normalizer.samples_to_trim
# Find non silent portion and trim
start_index,end_index=normalizer.find_first_last_non_silent(audio_chunk.audio,chunk_text,speed,is_last_chunk=is_last_chunk)
audio_chunk.audio=audio_chunk.audio[start_index:end_index]
trimed_samples+=start_index
audio_chunk.audio = audio_chunk.audio[
normalizer.samples_to_trim : -normalizer.samples_to_trim
]
trimed_samples += normalizer.samples_to_trim
# Find non silent portion and trim
start_index, end_index = normalizer.find_first_last_non_silent(
audio_chunk.audio, chunk_text, speed, is_last_chunk=is_last_chunk
)
audio_chunk.audio = audio_chunk.audio[start_index:end_index]
trimed_samples += start_index
if audio_chunk.word_timestamps is not None:
for timestamp in audio_chunk.word_timestamps:
timestamp.start_time-=trimed_samples / 24000
timestamp.end_time-=trimed_samples / 24000
timestamp.start_time -= trimed_samples / 24000
timestamp.end_time -= trimed_samples / 24000
return audio_chunk


@ -4,11 +4,12 @@ import struct
from io import BytesIO
from typing import Optional
import av
import numpy as np
import soundfile as sf
from loguru import logger
from pydub import AudioSegment
import av
class StreamingAudioWriter:
"""Handles streaming audio format conversions"""
@ -18,15 +19,29 @@ class StreamingAudioWriter:
self.sample_rate = sample_rate
self.channels = channels
self.bytes_written = 0
self.pts=0
self.pts = 0
codec_map = {"wav":"pcm_s16le","mp3":"mp3","opus":"libopus","flac":"flac", "aac":"aac"}
codec_map = {
"wav": "pcm_s16le",
"mp3": "mp3",
"opus": "libopus",
"flac": "flac",
"aac": "aac",
}
# Format-specific setup
if self.format in ["wav","flac","mp3","pcm","aac","opus"]:
if self.format in ["wav", "flac", "mp3", "pcm", "aac", "opus"]:
if self.format != "pcm":
self.output_buffer = BytesIO()
self.container = av.open(self.output_buffer, mode="w", format=self.format if self.format != "aac" else "adts")
self.stream = self.container.add_stream(codec_map[self.format],sample_rate=self.sample_rate,layout='mono' if self.channels == 1 else 'stereo')
self.container = av.open(
self.output_buffer,
mode="w",
format=self.format if self.format != "aac" else "adts",
)
self.stream = self.container.add_stream(
codec_map[self.format],
sample_rate=self.sample_rate,
layout="mono" if self.channels == 1 else "stereo",
)
self.stream.bit_rate = 128000
else:
raise ValueError(f"Unsupported format: {format}")
@ -53,8 +68,8 @@ class StreamingAudioWriter:
packets = self.stream.encode(None)
for packet in packets:
self.container.mux(packet)
data=self.output_buffer.getvalue()
data = self.output_buffer.getvalue()
self.close()
return data
@ -65,19 +80,21 @@ class StreamingAudioWriter:
# Write raw bytes
return audio_data.tobytes()
else:
frame = av.AudioFrame.from_ndarray(audio_data.reshape(1, -1), format='s16', layout='mono' if self.channels == 1 else 'stereo')
frame.sample_rate=self.sample_rate
frame = av.AudioFrame.from_ndarray(
audio_data.reshape(1, -1),
format="s16",
layout="mono" if self.channels == 1 else "stereo",
)
frame.sample_rate = self.sample_rate
frame.pts = self.pts
self.pts += frame.samples
packets = self.stream.encode(frame)
for packet in packets:
self.container.mux(packet)
data = self.output_buffer.getvalue()
self.output_buffer.seek(0)
self.output_buffer.truncate(0)
return data
return data


@ -110,7 +110,7 @@ class TempFileWriter:
# Set a placeholder path so the API can still function
self.temp_path = f"unavailable_{self.format}"
self.download_path = f"/download/{self.temp_path}"
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
@ -131,11 +131,11 @@ class TempFileWriter:
"""
if self._finalized:
raise RuntimeError("Cannot write to finalized temp file")
# Skip writing if we've already encountered an error
if self._write_error or not self.temp_file:
return
try:
await self.temp_file.write(chunk)
await self.temp_file.flush()
@ -157,7 +157,7 @@ class TempFileWriter:
if self._write_error or not self.temp_file:
self._finalized = True
return self.download_path
try:
await self.temp_file.close()
self._finalized = True


@ -6,12 +6,13 @@ Converts them into a format suitable for text-to-speech processing.
import re
from functools import lru_cache
import inflect
from numpy import number
from torch import mul
from ...structures.schemas import NormalizationOptions
from text_to_num import text2num
from torch import mul
from ...structures.schemas import NormalizationOptions
# Constants
VALID_TLDS = [
@ -54,26 +55,81 @@ VALID_TLDS = [
"uk",
"us",
"io",
"co",
]
VALID_UNITS = {
"m":"meter", "cm":"centimeter", "mm":"millimeter", "km":"kilometer", "in":"inch", "ft":"foot", "yd":"yard", "mi":"mile", # Length
"g":"gram", "kg":"kilogram", "mg":"miligram", # Mass
"s":"second", "ms":"milisecond", "min":"minutes", "h":"hour", # Time
"l":"liter", "ml":"mililiter", "cl":"centiliter", "dl":"deciliter", # Volume
"kph":"kilometer per hour", "mph":"mile per hour","mi/h":"mile per hour", "m/s":"meter per second", "km/h":"kilometer per hour", "mm/s":"milimeter per second","cm/s":"centimeter per second", "ft/s":"feet per second","cm/h":"centimeter per day", # Speed
"°c":"degree celsius","c":"degree celsius", "°f":"degree fahrenheit","f":"degree fahrenheit", "k":"kelvin", # Temperature
"pa":"pascal", "kpa":"kilopascal", "mpa":"megapascal", "atm":"atmosphere", # Pressure
"hz":"hertz", "khz":"kilohertz", "mhz":"megahertz", "ghz":"gigahertz", # Frequency
"v":"volt", "kv":"kilovolt", "mv":"mergavolt", # Voltage
"a":"amp", "ma":"megaamp", "ka":"kiloamp", # Current
"w":"watt", "kw":"kilowatt", "mw":"megawatt", # Power
"j":"joule", "kj":"kilojoule", "mj":"megajoule", # Energy
"Ω":"ohm", "":"kiloohm", "":"megaohm", # Resistance (Ohm)
"f":"farad", "µf":"microfarad", "nf":"nanofarad", "pf":"picofarad", # Capacitance
"b":"bit", "kb":"kilobit", "mb":"megabit", "gb":"gigabit", "tb":"terabit", "pb":"petabit", # Data size
"kbps":"kilobit per second","mbps":"megabit per second","gbps":"gigabit per second","tbps":"terabit per second",
"px":"pixel" # CSS units
"m": "meter",
"cm": "centimeter",
"mm": "millimeter",
"km": "kilometer",
"in": "inch",
"ft": "foot",
"yd": "yard",
"mi": "mile", # Length
"g": "gram",
"kg": "kilogram",
"mg": "milligram", # Mass
"s": "second",
"ms": "millisecond",
"min": "minutes",
"h": "hour", # Time
"l": "liter",
"ml": "mililiter",
"cl": "centiliter",
"dl": "deciliter", # Volume
"kph": "kilometer per hour",
"mph": "mile per hour",
"mi/h": "mile per hour",
"m/s": "meter per second",
"km/h": "kilometer per hour",
"mm/s": "milimeter per second",
"cm/s": "centimeter per second",
"ft/s": "feet per second",
"cm/h": "centimeter per day", # Speed
"°c": "degree celsius",
"c": "degree celsius",
"°f": "degree fahrenheit",
"f": "degree fahrenheit",
"k": "kelvin", # Temperature
"pa": "pascal",
"kpa": "kilopascal",
"mpa": "megapascal",
"atm": "atmosphere", # Pressure
"hz": "hertz",
"khz": "kilohertz",
"mhz": "megahertz",
"ghz": "gigahertz", # Frequency
"v": "volt",
"kv": "kilovolt",
"mv": "mergavolt", # Voltage
"a": "amp",
"ma": "megaamp",
"ka": "kiloamp", # Current
"w": "watt",
"kw": "kilowatt",
"mw": "megawatt", # Power
"j": "joule",
"kj": "kilojoule",
"mj": "megajoule", # Energy
"Ω": "ohm",
"": "kiloohm",
"": "megaohm", # Resistance (Ohm)
"f": "farad",
"µf": "microfarad",
"nf": "nanofarad",
"pf": "picofarad", # Capacitance
"b": "bit",
"kb": "kilobit",
"mb": "megabit",
"gb": "gigabit",
"tb": "terabit",
"pb": "petabit", # Data size
"kbps": "kilobit per second",
"mbps": "megabit per second",
"gbps": "gigabit per second",
"tbps": "terabit per second",
"px": "pixel", # CSS units
}
@ -88,11 +144,19 @@ URL_PATTERN = re.compile(
re.IGNORECASE,
)
UNIT_PATTERN = re.compile(r"((?<!\w)([+-]?)(\d{1,3}(,\d{3})*|\d+)(\.\d+)?)\s*(" + "|".join(sorted(list(VALID_UNITS.keys()),reverse=True)) + r"""){1}(?=[^\w\d]{1}|\b)""",re.IGNORECASE)
UNIT_PATTERN = re.compile(
r"((?<!\w)([+-]?)(\d{1,3}(,\d{3})*|\d+)(\.\d+)?)\s*("
+ "|".join(sorted(list(VALID_UNITS.keys()), reverse=True))
+ r"""){1}(?=[^\w\d]{1}|\b)""",
re.IGNORECASE,
)
TIME_PATTERN = re.compile(r"([0-9]{2} ?: ?[0-9]{2}( ?: ?[0-9]{2})?)( ?(pm|am)\b)?", re.IGNORECASE)
TIME_PATTERN = re.compile(
r"([0-9]{2} ?: ?[0-9]{2}( ?: ?[0-9]{2})?)( ?(pm|am)\b)?", re.IGNORECASE
)
INFLECT_ENGINE = inflect.engine()
INFLECT_ENGINE=inflect.engine()
def split_num(num: re.Match[str]) -> str:
"""Handle number splitting for various formats"""
@ -118,29 +182,32 @@ def split_num(num: re.Match[str]) -> str:
return f"{left} oh {right}{s}"
return f"{left} {right}{s}"
def handle_units(u: re.Match[str]) -> str:
"""Converts units to their full form"""
unit_string=u.group(6).strip()
unit=unit_string
unit_string = u.group(6).strip()
unit = unit_string
if unit_string.lower() in VALID_UNITS:
unit=VALID_UNITS[unit_string.lower()].split(" ")
unit = VALID_UNITS[unit_string.lower()].split(" ")
# Handles the B vs b case
if unit[0].endswith("bit"):
b_case=unit_string[min(1,len(unit_string) - 1)]
b_case = unit_string[min(1, len(unit_string) - 1)]
if b_case == "B":
unit[0]=unit[0][:-3] + "byte"
number=u.group(1).strip()
unit[0]=INFLECT_ENGINE.no(unit[0],number)
unit[0] = unit[0][:-3] + "byte"
number = u.group(1).strip()
unit[0] = INFLECT_ENGINE.no(unit[0], number)
return " ".join(unit)
def conditional_int(number: float, threshold: float = 0.00001):
if abs(round(number) - number) < threshold:
return int(round(number))
return number
def handle_money(m: re.Match[str]) -> str:
"""Convert money expressions to spoken form"""
@ -153,7 +220,7 @@ def handle_money(m: re.Match[str]) -> str:
number = float(number)
except:
return m.group()
if m.group(1) == "-":
number *= -1
@ -166,6 +233,7 @@ def handle_money(m: re.Match[str]) -> str:
return text_number
def handle_decimal(num: re.Match[str]) -> str:
"""Convert decimal numbers to spoken form"""
a, b = num.group().split(".")
@ -229,34 +297,41 @@ def handle_url(u: re.Match[str]) -> str:
# Clean up extra spaces
return re.sub(r"\s+", " ", url).strip()
def handle_phone_number(p: re.Match[str]) -> str:
p=list(p.groups())
country_code=""
p = list(p.groups())
country_code = ""
if p[0] is not None:
p[0]=p[0].replace("+","")
p[0] = p[0].replace("+", "")
country_code += INFLECT_ENGINE.number_to_words(p[0])
area_code=INFLECT_ENGINE.number_to_words(p[2].replace("(","").replace(")",""),group=1,comma="")
telephone_prefix=INFLECT_ENGINE.number_to_words(p[3],group=1,comma="")
line_number=INFLECT_ENGINE.number_to_words(p[4],group=1,comma="")
return ",".join([country_code,area_code,telephone_prefix,line_number])
area_code = INFLECT_ENGINE.number_to_words(
p[2].replace("(", "").replace(")", ""), group=1, comma=""
)
telephone_prefix = INFLECT_ENGINE.number_to_words(p[3], group=1, comma="")
line_number = INFLECT_ENGINE.number_to_words(p[4], group=1, comma="")
return ",".join([country_code, area_code, telephone_prefix, line_number])
def handle_time(t: re.Match[str]) -> str:
t=t.groups()
numbers = " ".join([INFLECT_ENGINE.number_to_words(X.strip()) for X in t[0].split(":")])
half=""
t = t.groups()
numbers = " ".join(
[INFLECT_ENGINE.number_to_words(X.strip()) for X in t[0].split(":")]
)
half = ""
if t[2] is not None:
half=t[2].strip()
half = t[2].strip()
return numbers + half
def normalize_text(text: str,normalization_options: NormalizationOptions) -> str:
def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
"""Normalize text for TTS processing"""
# Handle email addresses first if enabled
if normalization_options.email_normalization:
@ -268,16 +343,20 @@ def normalize_text(text: str,normalization_options: NormalizationOptions) -> str
# Pre-process numbers with units if enabled
if normalization_options.unit_normalization:
text=UNIT_PATTERN.sub(handle_units,text)
text = UNIT_PATTERN.sub(handle_units, text)
# Replace optional pluralization
if normalization_options.optional_pluralization_normalization:
text = re.sub(r"\(s\)","s",text)
text = re.sub(r"\(s\)", "s", text)
# Replace phone numbers:
if normalization_options.phone_normalization:
text = re.sub(r"(\+?\d{1,2})?([ .-]?)(\(?\d{3}\)?)[\s.-](\d{3})[\s.-](\d{4})",handle_phone_number,text)
text = re.sub(
r"(\+?\d{1,2})?([ .-]?)(\(?\d{3}\)?)[\s.-](\d{3})[\s.-](\d{4})",
handle_phone_number,
text,
)
# Replace quotes and brackets
text = text.replace(chr(8216), "'").replace(chr(8217), "'")
text = text.replace("«", chr(8220)).replace("»", chr(8221))
@ -288,7 +367,10 @@ def normalize_text(text: str,normalization_options: NormalizationOptions) -> str
text = text.replace(a, b + " ")
# Handle simple time in the format of HH:MM:SS
text = TIME_PATTERN.sub(handle_time, text, )
text = TIME_PATTERN.sub(
handle_time,
text,
)
# Clean up whitespace
text = re.sub(r"[^\S \n]", " ", text)
@ -307,17 +389,17 @@ def normalize_text(text: str,normalization_options: NormalizationOptions) -> str
# Handle numbers and money
text = re.sub(r"(?<=\d),(?=\d)", "", text)
text = re.sub(
r"(?i)(-?)([$£])(\d+(?:\.\d+)?)((?: hundred| thousand| (?:[bm]|tr|quadr)illion)*)\b",
handle_money,
text,
)
text = re.sub(
r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)", split_num, text
)
text = re.sub(r"\d*\.\d+", handle_decimal, text)
# Handle various formatting
@ -328,6 +410,6 @@ def normalize_text(text: str,normalization_options: NormalizationOptions) -> str
text = re.sub(
r"(?:[A-Za-z]\.){2,} [a-z]", lambda m: m.group().replace(".", "-"), text
)
text = re.sub( r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)
text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)
return text.strip()

View file

@ -7,14 +7,15 @@ from typing import AsyncGenerator, Dict, List, Tuple
from loguru import logger
from ...core.config import settings
from ...structures.schemas import NormalizationOptions
from .normalizer import normalize_text
from .phonemizer import phonemize
from .vocabulary import tokenize
from ...structures.schemas import NormalizationOptions
# Pre-compiled regex patterns for performance
CUSTOM_PHONEMES = re.compile(r"(\[([^\]]|\n)*?\])(\(\/([^\/)]|\n)*?\/\))")
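# Illustrative note (inferred from the regex, not stated in the diff): the
# pattern matches markdown-style custom pronunciations of the form
# [text](/phonemes/), which smart_split later swaps out for placeholder ids.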
def process_text_chunk(
text: str, language: str = "a", skip_phonemize: bool = False
) -> List[int]:
@ -29,7 +30,7 @@ def process_text_chunk(
List of token IDs
"""
start_time = time.time()
if skip_phonemize:
# Input is already phonemes, just tokenize
t0 = time.time()
@ -41,9 +42,7 @@ def process_text_chunk(
t1 = time.time()
t0 = time.time()
phonemes = phonemize(
text, language, normalize=False
) # Already normalized
phonemes = phonemize(text, language, normalize=False) # Already normalized
t1 = time.time()
t0 = time.time()
@ -88,21 +87,24 @@ def process_text(text: str, language: str = "a") -> List[int]:
return process_text_chunk(text, language)
def get_sentence_info(text: str, custom_phenomes_list: Dict[str, str]) -> List[Tuple[str, List[int], int]]:
def get_sentence_info(
text: str, custom_phenomes_list: Dict[str, str]
) -> List[Tuple[str, List[int], int]]:
"""Process all sentences and return info."""
sentences = re.split(r"([.!?;:])(?=\s|$)", text)
phoneme_length, min_value = len(custom_phenomes_list), 0
results = []
for i in range(0, len(sentences), 2):
sentence = sentences[i].strip()
for replaced in range(min_value, phoneme_length):
current_id = f"</|custom_phonemes_{replaced}|/>"
if current_id in sentence:
sentence = sentence.replace(current_id, custom_phenomes_list.pop(current_id))
sentence = sentence.replace(
current_id, custom_phenomes_list.pop(current_id)
)
min_value += 1
punct = sentences[i + 1] if i + 1 < len(sentences) else ""
if not sentence:
@ -114,16 +116,18 @@ def get_sentence_info(text: str, custom_phenomes_list: Dict[str, str]) -> List[T
return results
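# Illustrative sketch of the split above: the capturing group keeps the
# delimiters, so sentences alternate with their punctuation, e.g.
#   re.split(r"([.!?;:])(?=\s|$)", "Hi there. Bye!")
#   -> ['Hi there', '.', ' Bye', '!', '']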
def handle_custom_phonemes(s: re.Match[str], phenomes_list: Dict[str,str]) -> str:
def handle_custom_phonemes(s: re.Match[str], phenomes_list: Dict[str, str]) -> str:
latest_id = f"</|custom_phonemes_{len(phenomes_list)}|/>"
phenomes_list[latest_id] = s.group(0).strip()
return latest_id
async def smart_split(
text: str,
text: str,
max_tokens: int = settings.absolute_max_tokens,
lang_code: str = "a",
normalization_options: NormalizationOptions = NormalizationOptions()
normalization_options: NormalizationOptions = NormalizationOptions(),
) -> AsyncGenerator[Tuple[str, List[int]], None]:
"""Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens."""
start_time = time.time()
@ -135,11 +139,15 @@ async def smart_split(
# Normalize text
if settings.advanced_text_normalization and normalization_options.normalize:
print(lang_code)
if lang_code in ["a","b","en-us","en-gb"]:
text = CUSTOM_PHONEMES.sub(lambda s: handle_custom_phonemes(s, custom_phoneme_list), text)
text=normalize_text(text,normalization_options)
if lang_code in ["a", "b", "en-us", "en-gb"]:
text = CUSTOM_PHONEMES.sub(
lambda s: handle_custom_phonemes(s, custom_phoneme_list), text
)
text = normalize_text(text, normalization_options)
else:
logger.info("Skipping text normalization as it is only supported for english")
logger.info(
"Skipping text normalization as it is only supported for english"
)
# Process all sentences
sentences = get_sentence_info(text, custom_phoneme_list)
@ -177,7 +185,7 @@ async def smart_split(
continue
full_clause = clause + comma
tokens = process_text_chunk(full_clause)
count = len(tokens)

View file

@ -8,7 +8,6 @@ import time
from typing import AsyncGenerator, List, Optional, Tuple, Union
import numpy as np
from .streaming_audio_writer import StreamingAudioWriter
import torch
from kokoro import KPipeline
from loguru import logger
@ -20,6 +19,7 @@ from ..inference.model_manager import get_manager as get_model_manager
from ..inference.voice_manager import get_manager as get_voice_manager
from ..structures.schemas import NormalizationOptions
from .audio import AudioNormalizer, AudioService
from .streaming_audio_writer import StreamingAudioWriter
from .text_processing import tokenize
from .text_processing.text_processor import process_text_chunk, smart_split
@ -69,7 +69,9 @@ class TTSService:
yield AudioChunk(np.array([], dtype=np.int16), output=b"")
return
chunk_data = await AudioService.convert_audio(
AudioChunk(np.array([], dtype=np.float32)), # Dummy data for type checking
AudioChunk(
np.array([], dtype=np.float32)
), # Dummy data for type checking
output_format,
writer,
speed,
@ -114,13 +116,22 @@ class TTSService:
except Exception as e:
logger.error(f"Failed to convert audio: {str(e)}")
else:
chunk_data = AudioService.trim_audio(chunk_data, chunk_text, speed, is_last, normalizer)
chunk_data = AudioService.trim_audio(
chunk_data, chunk_text, speed, is_last, normalizer
)
yield chunk_data
chunk_index += 1
else:
# For legacy backends, load voice tensor
voice_tensor = await self._voice_manager.load_voice(voice_name, device=backend.device)
chunk_data = await self.model_manager.generate(tokens, voice_tensor, speed=speed, return_timestamps=return_timestamps)
voice_tensor = await self._voice_manager.load_voice(
voice_name, device=backend.device
)
chunk_data = await self.model_manager.generate(
tokens,
voice_tensor,
speed=speed,
return_timestamps=return_timestamps,
)
if chunk_data.audio is None:
logger.error("Model generated None for audio chunk")
@ -146,7 +157,9 @@ class TTSService:
except Exception as e:
logger.error(f"Failed to convert audio: {str(e)}")
else:
trimmed = AudioService.trim_audio(chunk_data, chunk_text, speed, is_last, normalizer)
trimmed = AudioService.trim_audio(
chunk_data, chunk_text, speed, is_last, normalizer
)
yield trimmed
except Exception as e:
logger.error(f"Failed to process tokens: {str(e)}")
@ -178,7 +191,9 @@ class TTSService:
# If it is only one voice there is no point in loading it up, doing nothing with it, then saving it
if len(split_voice) == 1:
# Since it's a single voice, the only time the weight would matter is if voice_weight_normalization is off
if ("(" not in voice and ")" not in voice) or settings.voice_weight_normalization == True:
if (
"(" not in voice and ")" not in voice
) or settings.voice_weight_normalization == True:
path = await self._voice_manager.get_voice_path(voice)
if not path:
raise RuntimeError(f"Voice not found: {voice}")
@ -206,13 +221,19 @@ class TTSService:
# Load the first voice as the starting point for voices to be combined onto
path = await self._voice_manager.get_voice_path(split_voice[0][0])
combined_tensor = await self._load_voice_from_path(path, split_voice[0][1] / total_weight)
combined_tensor = await self._load_voice_from_path(
path, split_voice[0][1] / total_weight
)
# Loop through each + or - in split_voice so they can be applied to combined voice
for operation_index in range(1, len(split_voice) - 1, 2):
# Get the voice path of the voice 1 index ahead of the operator
path = await self._voice_manager.get_voice_path(split_voice[operation_index + 1][0])
voice_tensor = await self._load_voice_from_path(path, split_voice[operation_index + 1][1] / total_weight)
path = await self._voice_manager.get_voice_path(
split_voice[operation_index + 1][0]
)
voice_tensor = await self._load_voice_from_path(
path, split_voice[operation_index + 1][1] / total_weight
)
# Either add or subtract the voice from the current combined voice
if split_voice[operation_index] == "+":
@ -255,10 +276,16 @@ class TTSService:
# Use provided lang_code or determine from voice name
pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
logger.info(f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream")
logger.info(
f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream"
)
# Process text in chunks with smart splitting
async for chunk_text, tokens in smart_split(text, lang_code=pipeline_lang_code, normalization_options=normalization_options):
async for chunk_text, tokens in smart_split(
text,
lang_code=pipeline_lang_code,
normalization_options=normalization_options,
):
try:
# Process audio for chunk
async for chunk_data in self._process_chunk(
@ -286,10 +313,14 @@ class TTSService:
yield chunk_data
else:
logger.warning(f"No audio generated for chunk: '{chunk_text[:100]}...'")
logger.warning(
f"No audio generated for chunk: '{chunk_text[:100]}...'"
)
chunk_index += 1
except Exception as e:
logger.error(f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}")
logger.error(
f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}"
)
continue
# Only finalize if we successfully processed at least one chunk
@ -332,7 +363,16 @@ class TTSService:
audio_data_chunks = []
try:
async for audio_stream_data in self.generate_audio_stream(text, voice, writer, speed=speed, normalization_options=normalization_options, return_timestamps=return_timestamps, lang_code=lang_code, output_format=None):
async for audio_stream_data in self.generate_audio_stream(
text,
voice,
writer,
speed=speed,
normalization_options=normalization_options,
return_timestamps=return_timestamps,
lang_code=lang_code,
output_format=None,
):
if len(audio_stream_data.audio) > 0:
audio_data_chunks.append(audio_stream_data)
@ -384,11 +424,15 @@ class TTSService:
result = None
# Use provided lang_code or determine from voice name
pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
logger.info(f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme pipeline")
logger.info(
f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme pipeline"
)
try:
# Use backend's pipeline management
for r in backend._get_pipeline(pipeline_lang_code).generate_from_tokens(
for r in backend._get_pipeline(
pipeline_lang_code
).generate_from_tokens(
tokens=phonemes, # Pass raw phonemes string
voice=voice_path,
speed=speed,
@ -406,7 +450,9 @@ class TTSService:
processing_time = time.time() - start_time
return result.audio.numpy(), processing_time
else:
raise ValueError("Phoneme generation only supported with Kokoro V1 backend")
raise ValueError(
"Phoneme generation only supported with Kokoro V1 backend"
)
except Exception as e:
logger.error(f"Error in phoneme audio generation: {str(e)}")

View file

@ -24,28 +24,27 @@ class JSONStreamingResponse(StreamingResponse, JSONResponse):
else:
self._content_iterable = iterate_in_threadpool(content)
async def body_iterator() -> AsyncIterable[bytes]:
async for content_ in self._content_iterable:
if isinstance(content_, BaseModel):
content_ = content_.model_dump()
yield self.render(content_)
self.body_iterator = body_iterator()
self.status_code = status_code
if media_type is not None:
self.media_type = media_type
self.background = background
self.init_headers(headers)
def render(self, content: typing.Any) -> bytes:
return (json.dumps(
return (
json.dumps(
content,
ensure_ascii=False,
allow_nan=False,
indent=None,
separators=(",", ":"),
) + "\n").encode("utf-8")
)
+ "\n"
).encode("utf-8")

View file

@ -35,17 +35,39 @@ class CaptionedSpeechResponse(BaseModel):
audio: str = Field(..., description="The generated audio data encoded in base 64")
audio_format: str = Field(..., description="The format of the output audio")
timestamps: Optional[List[WordTimestamp]] = Field(..., description="Word-level timestamps")
timestamps: Optional[List[WordTimestamp]] = Field(
..., description="Word-level timestamps"
)
class NormalizationOptions(BaseModel):
"""Options for the normalization system"""
normalize: bool = Field(default=True, description="Normalizes input text to make it easier for the model to say")
unit_normalization: bool = Field(default=False,description="Transforms units like 10KB to 10 kilobytes")
url_normalization: bool = Field(default=True, description="Changes urls so they can be properly pronounced by kokoro")
email_normalization: bool = Field(default=True, description="Changes emails so they can be properly pronounced by kokoro")
optional_pluralization_normalization: bool = Field(default=True, description="Replaces (s) with s so some words get pronounced correctly")
phone_normalization: bool = Field(default=True, description="Changes phone numbers so they can be properly pronounced by kokoro")
normalize: bool = Field(
default=True,
description="Normalizes input text to make it easier for the model to say",
)
unit_normalization: bool = Field(
default=False, description="Transforms units like 10KB to 10 kilobytes"
)
url_normalization: bool = Field(
default=True,
description="Changes urls so they can be properly pronounced by kokoro",
)
email_normalization: bool = Field(
default=True,
description="Changes emails so they can be properly pronouced by kokoro",
)
optional_pluralization_normalization: bool = Field(
default=True,
description="Replaces (s) with s so some words get pronounced correctly",
)
phone_normalization: bool = Field(
default=True,
description="Changes phone numbers so they can be properly pronouced by kokoro",
)
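# Illustrative usage (a sketch, not part of the diff):
#   opts = NormalizationOptions(unit_normalization=True, phone_normalization=False)
#   normalize_text("Send the 10KB file", normalization_options=opts)
#   # -> "Send the 10 kilobytes file", with phone handling disabled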
class OpenAISpeechRequest(BaseModel):
"""Request schema for OpenAI-compatible speech endpoint"""
@ -62,9 +84,11 @@ class OpenAISpeechRequest(BaseModel):
default="mp3",
description="The format to return audio in. Supported formats: mp3, opus, flac, wav, pcm. PCM format returns raw 16-bit samples without headers. AAC is not currently supported.",
)
download_format: Optional[Literal["mp3", "opus", "aac", "flac", "wav", "pcm"]] = Field(
default=None,
description="Optional different format for the final download. If not provided, uses response_format.",
download_format: Optional[Literal["mp3", "opus", "aac", "flac", "wav", "pcm"]] = (
Field(
default=None,
description="Optional different format for the final download. If not provided, uses response_format.",
)
)
speed: float = Field(
default=1.0,
@ -85,8 +109,8 @@ class OpenAISpeechRequest(BaseModel):
description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
)
normalization_options: Optional[NormalizationOptions] = Field(
default= NormalizationOptions(),
description= "Options for the normalization system"
default=NormalizationOptions(),
description="Options for the normalization system",
)
@ -129,6 +153,6 @@ class CaptionedSpeechRequest(BaseModel):
description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
)
normalization_options: Optional[NormalizationOptions] = Field(
default= NormalizationOptions(),
description= "Options for the normalization system"
default=NormalizationOptions(),
description="Options for the normalization system",
)

View file

@ -69,4 +69,3 @@ async def tts_service(mock_model_manager, mock_voice_manager):
def test_voice():
"""Return a test voice name."""
return "voice1"

View file

@ -5,9 +5,11 @@ from unittest.mock import patch
import numpy as np
import pytest
from api.src.services.audio import AudioNormalizer, AudioService
from api.src.inference.base import AudioChunk
from api.src.services.audio import AudioNormalizer, AudioService
from api.src.services.streaming_audio_writer import StreamingAudioWriter
@pytest.fixture(autouse=True)
def mock_settings():
"""Mock settings for all tests"""
@ -64,7 +66,9 @@ async def test_convert_to_mp3(sample_audio):
assert isinstance(audio_chunk, AudioChunk)
assert len(audio_chunk.output) > 0
# Check MP3 header (ID3 or MPEG frame sync)
assert audio_chunk.output.startswith(b"ID3") or audio_chunk.output.startswith(b"\xff\xfb")
assert audio_chunk.output.startswith(b"ID3") or audio_chunk.output.startswith(
b"\xff\xfb"
)
@pytest.mark.asyncio
@ -76,7 +80,7 @@ async def test_convert_to_opus(sample_audio):
writer = StreamingAudioWriter("opus", sample_rate=24000)
audio_chunk = await AudioService.convert_audio(
AudioChunk(audio_data), "opus",writer
AudioChunk(audio_data), "opus", writer
)
writer.close()
@ -120,12 +124,14 @@ async def test_convert_to_aac(sample_audio):
)
writer.close()
assert isinstance(audio_chunk.output, bytes)
assert isinstance(audio_chunk, AudioChunk)
assert len(audio_chunk.output) > 0
# Check ADTS header (AAC)
assert audio_chunk.output.startswith(b"\xff\xf0") or audio_chunk.output.startswith(b"\xff\xf1")
assert audio_chunk.output.startswith(b"\xff\xf0") or audio_chunk.output.startswith(
b"\xff\xf1"
)
@pytest.mark.asyncio
@ -150,7 +156,7 @@ async def test_convert_to_pcm(sample_audio):
@pytest.mark.asyncio
async def test_convert_to_invalid_format_raises_error(sample_audio):
"""Test that converting to an invalid format raises an error"""
#audio_data, sample_rate = sample_audio
# audio_data, sample_rate = sample_audio
with pytest.raises(ValueError, match="Unsupported format: invalid"):
writer = StreamingAudioWriter("invalid", sample_rate=24000)
@ -212,7 +218,6 @@ async def test_different_sample_rates(sample_audio):
sample_rates = [8000, 16000, 44100, 48000]
for rate in sample_rates:
writer = StreamingAudioWriter("wav", sample_rate=rate)
audio_chunk = await AudioService.convert_audio(

View file

@ -1,8 +1,10 @@
import pytest
from unittest.mock import patch, MagicMock
import requests
import base64
import json
from unittest.mock import MagicMock, patch
import pytest
import requests
def test_generate_captioned_speech():
"""Test the generate_captioned_speech function with mocked responses"""
@ -12,20 +14,21 @@ def test_generate_captioned_speech():
mock_timestamps_response = MagicMock()
mock_timestamps_response.status_code = 200
mock_timestamps_response.content = json.dumps({
"audio":base64.b64encode(b"mock audio data").decode("utf-8"),
"timestamps":[{"word": "test", "start_time": 0.0, "end_time": 1.0}]
})
mock_timestamps_response.content = json.dumps(
{
"audio": base64.b64encode(b"mock audio data").decode("utf-8"),
"timestamps": [{"word": "test", "start_time": 0.0, "end_time": 1.0}],
}
)
# Patch the HTTP requests
with patch('requests.post', return_value=mock_timestamps_response):
with patch("requests.post", return_value=mock_timestamps_response):
# Import here to avoid module-level import issues
from examples.captioned_speech_example import generate_captioned_speech
# Test the function
audio, timestamps = generate_captioned_speech("test text")
# Verify we got both audio and timestamps
assert audio == b"mock audio data"
assert timestamps == [{"word": "test", "start_time": 0.0, "end_time": 1.0}]
assert timestamps == [{"word": "test", "start_time": 0.0, "end_time": 1.0}]

View file

@ -5,27 +5,48 @@ import pytest
from api.src.services.text_processing.normalizer import normalize_text
from api.src.structures.schemas import NormalizationOptions
def test_url_protocols():
"""Test URL protocol handling"""
assert (
normalize_text("Check out https://example.com",normalization_options=NormalizationOptions())
normalize_text(
"Check out https://example.com",
normalization_options=NormalizationOptions(),
)
== "Check out https example dot com"
)
assert normalize_text("Visit http://site.com",normalization_options=NormalizationOptions()) == "Visit http site dot com"
assert (
normalize_text("Go to https://test.org/path",normalization_options=NormalizationOptions())
normalize_text(
"Visit http://site.com", normalization_options=NormalizationOptions()
)
== "Visit http site dot com"
)
assert (
normalize_text(
"Go to https://test.org/path", normalization_options=NormalizationOptions()
)
== "Go to https test dot org slash path"
)
def test_url_www():
"""Test www prefix handling"""
assert normalize_text("Go to www.example.com",normalization_options=NormalizationOptions()) == "Go to www example dot com"
assert (
normalize_text("Visit www.test.org/docs",normalization_options=NormalizationOptions()) == "Visit www test dot org slash docs"
normalize_text(
"Go to www.example.com", normalization_options=NormalizationOptions()
)
== "Go to www example dot com"
)
assert (
normalize_text("Check www.site.com?q=test",normalization_options=NormalizationOptions())
normalize_text(
"Visit www.test.org/docs", normalization_options=NormalizationOptions()
)
== "Visit www test dot org slash docs"
)
assert (
normalize_text(
"Check www.site.com?q=test", normalization_options=NormalizationOptions()
)
== "Check www site dot com question-mark q equals test"
)
@ -33,15 +54,21 @@ def test_url_www():
def test_url_localhost():
"""Test localhost URL handling"""
assert (
normalize_text("Running on localhost:7860",normalization_options=NormalizationOptions())
normalize_text(
"Running on localhost:7860", normalization_options=NormalizationOptions()
)
== "Running on localhost colon 78 60"
)
assert (
normalize_text("Server at localhost:8080/api",normalization_options=NormalizationOptions())
normalize_text(
"Server at localhost:8080/api", normalization_options=NormalizationOptions()
)
== "Server at localhost colon 80 80 slash api"
)
assert (
normalize_text("Test localhost:3000/test?v=1",normalization_options=NormalizationOptions())
normalize_text(
"Test localhost:3000/test?v=1", normalization_options=NormalizationOptions()
)
== "Test localhost colon 3000 slash test question-mark v equals 1"
)
@ -49,48 +76,104 @@ def test_url_localhost():
def test_url_ip_addresses():
"""Test IP address URL handling"""
assert (
normalize_text("Access 0.0.0.0:9090/test",normalization_options=NormalizationOptions())
normalize_text(
"Access 0.0.0.0:9090/test", normalization_options=NormalizationOptions()
)
== "Access 0 dot 0 dot 0 dot 0 colon 90 90 slash test"
)
assert (
normalize_text("API at 192.168.1.1:8000",normalization_options=NormalizationOptions())
normalize_text(
"API at 192.168.1.1:8000", normalization_options=NormalizationOptions()
)
== "API at 192 dot 168 dot 1 dot 1 colon 8000"
)
assert normalize_text("Server 127.0.0.1",normalization_options=NormalizationOptions()) == "Server 127 dot 0 dot 0 dot 1"
assert (
normalize_text("Server 127.0.0.1", normalization_options=NormalizationOptions())
== "Server 127 dot 0 dot 0 dot 1"
)
def test_url_raw_domains():
"""Test raw domain handling"""
assert (
normalize_text("Visit google.com/search",normalization_options=NormalizationOptions()) == "Visit google dot com slash search"
normalize_text(
"Visit google.com/search", normalization_options=NormalizationOptions()
)
== "Visit google dot com slash search"
)
assert (
normalize_text("Go to example.com/path?q=test",normalization_options=NormalizationOptions())
normalize_text(
"Go to example.com/path?q=test",
normalization_options=NormalizationOptions(),
)
== "Go to example dot com slash path question-mark q equals test"
)
assert normalize_text("Check docs.test.com",normalization_options=NormalizationOptions()) == "Check docs dot test dot com"
assert (
normalize_text(
"Check docs.test.com", normalization_options=NormalizationOptions()
)
== "Check docs dot test dot com"
)
def test_url_email_addresses():
"""Test email address handling"""
assert (
normalize_text("Email me at user@example.com",normalization_options=NormalizationOptions())
normalize_text(
"Email me at user@example.com", normalization_options=NormalizationOptions()
)
== "Email me at user at example dot com"
)
assert normalize_text("Contact admin@test.org",normalization_options=NormalizationOptions()) == "Contact admin at test dot org"
assert (
normalize_text("Send to test.user@site.com",normalization_options=NormalizationOptions())
normalize_text(
"Contact admin@test.org", normalization_options=NormalizationOptions()
)
== "Contact admin at test dot org"
)
assert (
normalize_text(
"Send to test.user@site.com", normalization_options=NormalizationOptions()
)
== "Send to test dot user at site dot com"
)
def test_money():
"""Test that money text is normalized correctly"""
assert normalize_text("He lost $5.3 thousand.",normalization_options=NormalizationOptions()) == "He lost five point three thousand dollars."
assert normalize_text("To put it weirdly -$6.9 million",normalization_options=NormalizationOptions()) == "To put it weirdly minus six point nine million dollars"
assert normalize_text("It costs $50.3.",normalization_options=NormalizationOptions()) == "It costs fifty dollars and thirty cents."
assert (
normalize_text(
"He lost $5.3 thousand.", normalization_options=NormalizationOptions()
)
== "He lost five point three thousand dollars."
)
assert (
normalize_text(
"To put it weirdly -$6.9 million",
normalization_options=NormalizationOptions(),
)
== "To put it weirdly minus six point nine million dollars"
)
assert (
normalize_text("It costs $50.3.", normalization_options=NormalizationOptions())
== "It costs fifty dollars and thirty cents."
)
def test_non_url_text():
"""Test that non-URL text is unaffected"""
assert normalize_text("This is not.a.url text",normalization_options=NormalizationOptions()) == "This is not-a-url text"
assert normalize_text("Hello, how are you today?",normalization_options=NormalizationOptions()) == "Hello, how are you today?"
assert normalize_text("It costs $50.",normalization_options=NormalizationOptions()) == "It costs fifty dollars."
assert (
normalize_text(
"This is not.a.url text", normalization_options=NormalizationOptions()
)
== "This is not-a-url text"
)
assert (
normalize_text(
"Hello, how are you today?", normalization_options=NormalizationOptions()
)
== "Hello, how are you today?"
)
assert (
normalize_text("It costs $50.", normalization_options=NormalizationOptions())
== "It costs fifty dollars."
)

View file

@ -4,20 +4,19 @@ import os
from typing import AsyncGenerator, Tuple
from unittest.mock import AsyncMock, MagicMock, patch
from api.src.services.streaming_audio_writer import StreamingAudioWriter
from api.src.inference.base import AudioChunk
import numpy as np
import pytest
from fastapi.testclient import TestClient
from api.src.core.config import settings
from api.src.inference.base import AudioChunk
from api.src.main import app
from api.src.routers.openai_compatible import (
get_tts_service,
load_openai_mappings,
stream_audio_chunks,
)
from api.src.services.streaming_audio_writer import StreamingAudioWriter
from api.src.services.tts_service import TTSService
from api.src.structures.schemas import OpenAISpeechRequest
@ -80,13 +79,13 @@ def test_list_models(mock_openai_mappings):
assert data["object"] == "list"
assert isinstance(data["data"], list)
assert len(data["data"]) == 3 # tts-1, tts-1-hd, and kokoro
# Verify all expected models are present
model_ids = [model["id"] for model in data["data"]]
assert "tts-1" in model_ids
assert "tts-1-hd" in model_ids
assert "kokoro" in model_ids
# Verify model format
for model in data["data"]:
assert model["object"] == "model"
@ -114,7 +113,6 @@ def test_retrieve_model(mock_openai_mappings):
assert error["detail"]["type"] == "invalid_request_error"
@pytest.mark.asyncio
async def test_get_tts_service_initialization():
"""Test TTSService initialization"""
@ -147,7 +145,7 @@ async def test_stream_audio_chunks_client_disconnect():
async def mock_stream(*args, **kwargs):
for i in range(5):
yield AudioChunk(np.ndarray([],np.int16),output=b"chunk")
yield AudioChunk(np.ndarray([], np.int16), output=b"chunk")
mock_service.generate_audio_stream = mock_stream
mock_service.list_voices.return_value = ["test_voice"]
@ -243,10 +241,10 @@ def mock_tts_service(mock_audio_bytes):
"""Mock TTS service for testing."""
with patch("api.src.routers.openai_compatible.get_tts_service") as mock_get:
service = AsyncMock(spec=TTSService)
service.generate_audio.return_value = AudioChunk(np.zeros(1000,np.int16))
service.generate_audio.return_value = AudioChunk(np.zeros(1000, np.int16))
async def mock_stream(*args, **kwargs) -> AsyncGenerator[AudioChunk, None]:
yield AudioChunk(np.ndarray([],np.int16),output=mock_audio_bytes)
yield AudioChunk(np.ndarray([], np.int16), output=mock_audio_bytes)
service.generate_audio_stream = mock_stream
service.list_voices.return_value = ["test_voice", "voice1", "voice2"]
@ -263,8 +261,10 @@ def test_openai_speech_endpoint(
):
"""Test the OpenAI-compatible speech endpoint with basic MP3 generation"""
# Configure mocks
mock_tts_service.generate_audio.return_value = AudioChunk(np.zeros(1000,np.int16))
mock_convert.return_value = AudioChunk(np.zeros(1000,np.int16),output=mock_audio_bytes)
mock_tts_service.generate_audio.return_value = AudioChunk(np.zeros(1000, np.int16))
mock_convert.return_value = AudioChunk(
np.zeros(1000, np.int16), output=mock_audio_bytes
)
response = client.post(
"/v1/audio/speech",

View file

@ -44,9 +44,12 @@ def test_get_sentence_info():
assert count == len(tokens)
assert count > 0
def test_get_sentence_info_phenomoes():
"""Test sentence splitting and info extraction."""
text = "This is sentence one. This is </|custom_phonemes_0|/> two! What about three?"
text = (
"This is sentence one. This is </|custom_phonemes_0|/> two! What about three?"
)
results = get_sentence_info(text, {"</|custom_phonemes_0|/>": r"sˈɛntᵊns"})
assert len(results) == 3
@ -58,6 +61,7 @@ def test_get_sentence_info_phenomoes():
assert count == len(tokens)
assert count > 0
@pytest.mark.asyncio
async def test_smart_split_short_text():
"""Test smart splitting with text under max tokens."""

View file

@ -2,8 +2,8 @@ apiVersion: v2
name: kokoro-fastapi
description: A Helm chart for deploying the Kokoro FastAPI TTS service to Kubernetes
type: application
version: 0.2.0
appVersion: "0.2.0"
version: 0.3.0
appVersion: "0.3.0"
keywords:
- tts

View file

@ -17,35 +17,44 @@ import base64
import concurrent.futures
import json
import os
import requests
import sys
import time
import wave
import sys
from pathlib import Path
import requests
def setup_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(description="Test Kokoro TTS for race conditions")
parser.add_argument("--url", default="http://localhost:8880",
help="Base URL of the Kokoro TTS service")
parser.add_argument("--threads", type=int, default=8,
help="Number of concurrent threads to use")
parser.add_argument("--iterations", type=int, default=5,
help="Number of iterations per thread")
parser.add_argument("--voice", default="af_heart",
help="Voice to use for TTS")
parser.add_argument("--output-dir", default="./tts_test_output",
help="Directory to save output files")
parser.add_argument("--debug", action="store_true",
help="Enable debug logging")
parser.add_argument(
"--url",
default="http://localhost:8880",
help="Base URL of the Kokoro TTS service",
)
parser.add_argument(
"--threads", type=int, default=8, help="Number of concurrent threads to use"
)
parser.add_argument(
"--iterations", type=int, default=5, help="Number of iterations per thread"
)
parser.add_argument("--voice", default="af_heart", help="Voice to use for TTS")
parser.add_argument(
"--output-dir",
default="./tts_test_output",
help="Directory to save output files",
)
parser.add_argument("--debug", action="store_true", help="Enable debug logging")
return parser.parse_args()
def generate_test_sentence(thread_id, iteration):
"""Generate a simple test sentence with numbers to make mismatches easily identifiable"""
return f"This is test sentence number {thread_id}-{iteration}. " \
f"If you hear this sentence, you should hear the numbers {thread_id}-{iteration}."
return (
f"This is test sentence number {thread_id}-{iteration}. "
f"If you hear this sentence, you should hear the numbers {thread_id}-{iteration}."
)
def log_message(message, debug=False, is_error=False):
@ -62,81 +71,129 @@ def request_tts(url, test_id, text, voice, output_dir, debug=False):
start_time = time.time()
output_file = os.path.join(output_dir, f"test_{test_id}.wav")
text_file = os.path.join(output_dir, f"test_{test_id}.txt")
# Log output paths for debugging
log_message(f"Thread {test_id}: Text will be saved to: {text_file}", debug)
log_message(f"Thread {test_id}: Audio will be saved to: {output_file}", debug)
# Save the text for later comparison
try:
with open(text_file, "w") as f:
f.write(text)
log_message(f"Thread {test_id}: Successfully saved text file", debug)
except Exception as e:
log_message(f"Thread {test_id}: Error saving text file: {str(e)}", debug, is_error=True)
log_message(
f"Thread {test_id}: Error saving text file: {str(e)}", debug, is_error=True
)
# Make the TTS request
try:
log_message(f"Thread {test_id}: Requesting TTS for: '{text}'", debug)
response = requests.post(
f"{url}/v1/audio/speech",
json={
"model": "kokoro",
"input": text,
"voice": voice,
"response_format": "wav"
"response_format": "wav",
},
headers={"Accept": "audio/wav"},
timeout=60 # Increase timeout to 60 seconds
timeout=60, # Increase timeout to 60 seconds
)
log_message(f"Thread {test_id}: Response status code: {response.status_code}", debug)
log_message(f"Thread {test_id}: Response content type: {response.headers.get('Content-Type', 'None')}", debug)
log_message(f"Thread {test_id}: Response content length: {len(response.content)} bytes", debug)
log_message(
f"Thread {test_id}: Response status code: {response.status_code}", debug
)
log_message(
f"Thread {test_id}: Response content type: {response.headers.get('Content-Type', 'None')}",
debug,
)
log_message(
f"Thread {test_id}: Response content length: {len(response.content)} bytes",
debug,
)
if response.status_code != 200:
log_message(f"Thread {test_id}: API error: {response.status_code} - {response.text}", debug, is_error=True)
log_message(
f"Thread {test_id}: API error: {response.status_code} - {response.text}",
debug,
is_error=True,
)
return False
# Check if we got valid audio data
if len(response.content) < 100: # Sanity check - WAV files should be larger than this
log_message(f"Thread {test_id}: Received suspiciously small audio data: {len(response.content)} bytes", debug, is_error=True)
log_message(f"Thread {test_id}: Content (base64): {base64.b64encode(response.content).decode('utf-8')}", debug, is_error=True)
if (
len(response.content) < 100
): # Sanity check - WAV files should be larger than this
log_message(
f"Thread {test_id}: Received suspiciously small audio data: {len(response.content)} bytes",
debug,
is_error=True,
)
log_message(
f"Thread {test_id}: Content (base64): {base64.b64encode(response.content).decode('utf-8')}",
debug,
is_error=True,
)
return False
# Save the audio output with explicit error handling
try:
with open(output_file, "wb") as f:
bytes_written = f.write(response.content)
log_message(f"Thread {test_id}: Wrote {bytes_written} bytes to {output_file}", debug)
log_message(
f"Thread {test_id}: Wrote {bytes_written} bytes to {output_file}",
debug,
)
# Verify the WAV file exists and has content
if os.path.exists(output_file):
file_size = os.path.getsize(output_file)
log_message(f"Thread {test_id}: Verified file exists with size: {file_size} bytes", debug)
log_message(
f"Thread {test_id}: Verified file exists with size: {file_size} bytes",
debug,
)
# Validate WAV file by reading its headers
try:
with wave.open(output_file, 'rb') as wav_file:
with wave.open(output_file, "rb") as wav_file:
channels = wav_file.getnchannels()
sample_width = wav_file.getsampwidth()
framerate = wav_file.getframerate()
frames = wav_file.getnframes()
log_message(f"Thread {test_id}: Valid WAV file - channels: {channels}, "
f"sample width: {sample_width}, framerate: {framerate}, frames: {frames}", debug)
log_message(
f"Thread {test_id}: Valid WAV file - channels: {channels}, "
f"sample width: {sample_width}, framerate: {framerate}, frames: {frames}",
debug,
)
except Exception as wav_error:
log_message(f"Thread {test_id}: Invalid WAV file: {str(wav_error)}", debug, is_error=True)
log_message(
f"Thread {test_id}: Invalid WAV file: {str(wav_error)}",
debug,
is_error=True,
)
else:
log_message(f"Thread {test_id}: File was not created: {output_file}", debug, is_error=True)
log_message(
f"Thread {test_id}: File was not created: {output_file}",
debug,
is_error=True,
)
except Exception as save_error:
log_message(f"Thread {test_id}: Error saving audio file: {str(save_error)}", debug, is_error=True)
log_message(
f"Thread {test_id}: Error saving audio file: {str(save_error)}",
debug,
is_error=True,
)
return False
end_time = time.time()
log_message(f"Thread {test_id}: Saved output to {output_file} (time: {end_time - start_time:.2f}s)", debug)
log_message(
f"Thread {test_id}: Saved output to {output_file} (time: {end_time - start_time:.2f}s)",
debug,
)
return True
except requests.exceptions.Timeout:
log_message(f"Thread {test_id}: Request timed out", debug, is_error=True)
return False
@ -151,11 +208,17 @@ def worker_task(thread_id, args):
iteration = i + 1
test_id = f"{thread_id:02d}_{iteration:02d}"
text = generate_test_sentence(thread_id, iteration)
success = request_tts(args.url, test_id, text, args.voice, args.output_dir, args.debug)
success = request_tts(
args.url, test_id, text, args.voice, args.output_dir, args.debug
)
if not success:
log_message(f"Thread {thread_id}: Iteration {iteration} failed", args.debug, is_error=True)
log_message(
f"Thread {thread_id}: Iteration {iteration} failed",
args.debug,
is_error=True,
)
# Small delay between iterations to avoid overwhelming the API
time.sleep(0.1)
@ -164,46 +227,61 @@ def run_test(args):
"""Run the test with the specified parameters"""
# Ensure output directory exists and check permissions
os.makedirs(args.output_dir, exist_ok=True)
# Test write access to the output directory
test_file = os.path.join(args.output_dir, "write_test.txt")
try:
with open(test_file, "w") as f:
f.write("Testing write access\n")
os.remove(test_file)
log_message(f"Successfully verified write access to output directory: {args.output_dir}")
log_message(
f"Successfully verified write access to output directory: {args.output_dir}"
)
except Exception as e:
log_message(f"Warning: Cannot write to output directory {args.output_dir}: {str(e)}", is_error=True)
log_message(
f"Warning: Cannot write to output directory {args.output_dir}: {str(e)}",
is_error=True,
)
log_message(f"Current directory: {os.getcwd()}", is_error=True)
log_message(f"Directory contents: {os.listdir('.')}", is_error=True)
# Test connection to Kokoro TTS service
try:
response = requests.get(f"{args.url}/health", timeout=5)
if response.status_code == 200:
log_message(f"Successfully connected to Kokoro TTS service at {args.url}")
else:
log_message(f"Warning: Kokoro TTS service health check returned status {response.status_code}", is_error=True)
log_message(
f"Warning: Kokoro TTS service health check returned status {response.status_code}",
is_error=True,
)
except Exception as e:
log_message(f"Warning: Cannot connect to Kokoro TTS service at {args.url}: {str(e)}", is_error=True)
log_message(
f"Warning: Cannot connect to Kokoro TTS service at {args.url}: {str(e)}",
is_error=True,
)
# Record start time
start_time = time.time()
log_message(f"Starting test with {args.threads} threads, {args.iterations} iterations per thread")
log_message(
f"Starting test with {args.threads} threads, {args.iterations} iterations per thread"
)
# Create and start worker threads
with concurrent.futures.ThreadPoolExecutor(max_workers=args.threads) as executor:
futures = []
for thread_id in range(1, args.threads + 1):
futures.append(executor.submit(worker_task, thread_id, args))
# Wait for all tasks to complete
for future in concurrent.futures.as_completed(futures):
try:
future.result()
except Exception as e:
log_message(f"Thread execution failed: {str(e)}", args.debug, is_error=True)
log_message(
f"Thread execution failed: {str(e)}", args.debug, is_error=True
)
# Record end time and print summary
end_time = time.time()
total_time = end_time - start_time
@ -213,8 +291,12 @@ def run_test(args):
log_message(f"Average time per request: {total_time / total_requests:.2f} seconds")
log_message(f"Requests per second: {total_requests / total_time:.2f}")
log_message(f"Output files saved to: {os.path.abspath(args.output_dir)}")
log_message("To verify, listen to the audio files and check if they match the text files")
log_message("If you hear audio describing a different test number than the filename, you've found a race condition")
log_message(
"To verify, listen to the audio files and check if they match the text files"
)
log_message(
"If you hear audio describing a different test number than the filename, you've found a race condition"
)
def analyze_audio_files(output_dir):
@ -222,49 +304,58 @@ def analyze_audio_files(output_dir):
# Look for both WAV and TXT files
wav_files = list(Path(output_dir).glob("*.wav"))
txt_files = list(Path(output_dir).glob("*.txt"))
log_message(f"Found {len(wav_files)} WAV files and {len(txt_files)} TXT files")
if len(wav_files) == 0:
log_message("No WAV files found! This indicates the TTS service requests may be failing.", is_error=True)
log_message("Check the connection to the TTS service and the response status codes above.", is_error=True)
log_message(
"No WAV files found! This indicates the TTS service requests may be failing.",
is_error=True,
)
log_message(
"Check the connection to the TTS service and the response status codes above.",
is_error=True,
)
file_stats = []
for wav_path in wav_files:
try:
with wave.open(str(wav_path), 'rb') as wav_file:
with wave.open(str(wav_path), "rb") as wav_file:
frames = wav_file.getnframes()
rate = wav_file.getframerate()
duration = frames / rate
# Get corresponding text
text_path = wav_path.with_suffix('.txt')
text_path = wav_path.with_suffix(".txt")
if text_path.exists():
with open(text_path, 'r') as text_file:
with open(text_path, "r") as text_file:
text = text_file.read().strip()
else:
text = "N/A"
file_stats.append({
'filename': wav_path.name,
'duration': duration,
'text': text
})
file_stats.append(
{"filename": wav_path.name, "duration": duration, "text": text}
)
except Exception as e:
log_message(f"Error analyzing {wav_path}: {str(e)}", False, is_error=True)
# Print summary table
if file_stats:
log_message("\nAudio File Summary:")
log_message(f"{'Filename':<20}{'Duration':<12}{'Text':<60}")
log_message("-" * 92)
for stat in file_stats:
log_message(f"{stat['filename']:<20}{stat['duration']:<12.2f}{stat['text'][:57]+'...' if len(stat['text']) > 60 else stat['text']:<60}")
log_message(
f"{stat['filename']:<20}{stat['duration']:<12.2f}{stat['text'][:57] + '...' if len(stat['text']) > 60 else stat['text']:<60}"
)
# List missing WAV files where text files exist
missing_wavs = set(p.stem for p in txt_files) - set(p.stem for p in wav_files)
if missing_wavs:
log_message(f"\nFound {len(missing_wavs)} text files without corresponding WAV files:", is_error=True)
log_message(
f"\nFound {len(missing_wavs)} text files without corresponding WAV files:",
is_error=True,
)
for stem in sorted(list(missing_wavs))[:10]: # Limit to 10 for readability
log_message(f" - {stem}.txt (no WAV file)", is_error=True)
if len(missing_wavs) > 10:
@ -275,9 +366,13 @@ if __name__ == "__main__":
args = setup_args()
run_test(args)
analyze_audio_files(args.output_dir)
log_message("\nNext Steps:")
log_message("1. Listen to the generated audio files")
log_message("2. Verify if each audio correctly says its ID number")
log_message("3. Check for any mismatches between the audio content and the text files")
log_message("4. If mismatches are found, you've successfully reproduced the race condition")
log_message(
"3. Check for any mismatches between the audio content and the text files"
)
log_message(
"4. If mismatches are found, you've successfully reproduced the race condition"
)

View file

@ -1,8 +1,10 @@
import requests
import base64
import json
import pydub
text="""Delving into the Abyss: A Deeper Exploration of Meaning in 5 Seconds of Summer's "Jet Black Heart"
import requests
text = """Delving into the Abyss: A Deeper Exploration of Meaning in 5 Seconds of Summer's "Jet Black Heart"
5 Seconds of Summer, initially perceived as purveyors of upbeat, radio-friendly pop-punk, embarked on a significant artistic evolution with their album Sounds Good Feels Good. Among its tracks, "Jet Black Heart" stands out as a powerful testament to this shift, moving beyond catchy melodies and embracing a darker, more emotionally complex sound. Released in 2015, the song transcends the typical themes of youthful exuberance and romantic angst, instead plunging into the depths of personal turmoil and the corrosive effects of inner darkness on interpersonal relationships. "Jet Black Heart" is not merely a song about heartbreak; it is a raw and vulnerable exploration of internal struggle, self-destructive patterns, and the precarious flicker of hope that persists even in the face of profound emotional chaos. Through potent metaphors, starkly honest lyrics, and a sonic landscape that mirrors its thematic weight, the song offers a profound meditation on the human condition, grappling with the shadows that reside within us all and their far-reaching consequences.
@ -23,7 +25,7 @@ In conclusion, "Jet Black Heart" by 5 Seconds of Summer is far more than a typic
5 Seconds of Summer, initially perceived as purveyors of upbeat, radio-friendly pop-punk, embarked on a significant artistic evolution with their album Sounds Good Feels Good. Among its tracks, "Jet Black Heart" stands out as a powerful testament to this shift, moving beyond catchy melodies and embracing a darker, more emotionally complex sound. Released in 2015, the song transcends the typical themes of youthful exuberance and romantic angst, instead plunging into the depths of personal turmoil and the corrosive effects of inner darkness on interpersonal relationships. "Jet Black Heart" is not merely a song about heartbreak; it is a raw and vulnerable exploration of internal struggle, self-destructive patterns, and the precarious flicker of hope that persists even in the face of profound emotional chaos."""
Type="wav"
Type = "wav"
response = requests.post(
"http://localhost:8880/dev/captioned_speech",
json={
@ -34,30 +36,34 @@ response = requests.post(
"response_format": Type,
"stream": True,
},
stream=True
stream=True,
)
f=open(f"outputstream.{Type}","wb")
f = open(f"outputstream.{Type}", "wb")
for chunk in response.iter_lines(decode_unicode=True):
if chunk:
temp_json=json.loads(chunk)
temp_json = json.loads(chunk)
if temp_json["timestamps"] != []:
chunk_json=temp_json
chunk_json = temp_json
# Decode base64 stream to bytes
chunk_audio=base64.b64decode(temp_json["audio"].encode("utf-8"))
chunk_audio = base64.b64decode(temp_json["audio"].encode("utf-8"))
# Process streaming chunks
f.write(chunk_audio)
# Print word level timestamps
last_chunks={"start_time":chunk_json["timestamps"][-10]["start_time"],"end_time":chunk_json["timestamps"][-3]["end_time"],"word":" ".join([X["word"] for X in chunk_json["timestamps"][-10:-3]])}
last_chunks = {
"start_time": chunk_json["timestamps"][-10]["start_time"],
"end_time": chunk_json["timestamps"][-3]["end_time"],
"word": " ".join([X["word"] for X in chunk_json["timestamps"][-10:-3]]),
}
print(f"CUTTING TO {last_chunks['word']}")
audioseg=pydub.AudioSegment.from_file(f"outputstream.{Type}",format=Type)
audioseg=audioseg[last_chunks["start_time"]*1000:last_chunks["end_time"] * 1000]
audioseg.export(f"outputstreamcut.{Type}",format=Type)
audioseg = pydub.AudioSegment.from_file(f"outputstream.{Type}", format=Type)
audioseg = audioseg[last_chunks["start_time"] * 1000 : last_chunks["end_time"] * 1000]
audioseg.export(f"outputstreamcut.{Type}", format=Type)
"""
@ -85,4 +91,4 @@ with open(f"outputnostream.{Type}", "wb") as f:
# Print word level timestamps
print(audio_json["timestamps"])
"""
"""

View file

@ -1,13 +1,14 @@
import requests
import base64
import json
text="""the administration has offered up a platter of repression for more than a year and is still slated to lose $400 million.
import requests
text = """the administration has offered up a platter of repression for more than a year and is still slated to lose $400 million.
Columbia is the largest private landowner in New York City and boasts an endowment of $14.8 billion;"""
Type="wav"
Type = "wav"
response = requests.post(
"http://localhost:8880/v1/audio/speech",
@ -19,7 +20,7 @@ response = requests.post(
"response_format": Type,
"stream": False,
},
stream=True
stream=True,
)
with open(f"outputnostreammoney.{Type}", "wb") as f:

View file

@ -1,6 +1,7 @@
from text_to_num import text2num
import re
import inflect
from text_to_num import text2num
from torch import mul
INFLECT_ENGINE = inflect.engine()
@ -11,6 +12,7 @@ def conditional_int(number: float, threshold: float = 0.00001):
return int(round(number))
return number
def handle_money(m: re.Match[str]) -> str:
"""Convert money expressions to spoken form"""
@ -23,7 +25,7 @@ def handle_money(m: re.Match[str]) -> str:
number = float(number)
except:
return m.group()
if m.group(1) == "-":
number *= -1

View file

@ -1,8 +1,9 @@
import requests
import base64
import json
text="""Delving into the Abyss: A Deeper Exploration of Meaning in 5 Seconds of Summer's "Jet Black Heart"
import requests
text = """Delving into the Abyss: A Deeper Exploration of Meaning in 5 Seconds of Summer's "Jet Black Heart"
5 Seconds of Summer, initially perceived as purveyors of upbeat, radio-friendly pop-punk, embarked on a significant artistic evolution with their album Sounds Good Feels Good. Among its tracks, "Jet Black Heart" stands out as a powerful testament to this shift, moving beyond catchy melodies and embracing a darker, more emotionally complex sound. Released in 2015, the song transcends the typical themes of youthful exuberance and romantic angst, instead plunging into the depths of personal turmoil and the corrosive effects of inner darkness on interpersonal relationships. "Jet Black Heart" is not merely a song about heartbreak; it is a raw and vulnerable exploration of internal struggle, self-destructive patterns, and the precarious flicker of hope that persists even in the face of profound emotional chaos. Through potent metaphors, starkly honest lyrics, and a sonic landscape that mirrors its thematic weight, the song offers a profound meditation on the human condition, grappling with the shadows that reside within us all and their far-reaching consequences.
@ -18,12 +19,12 @@ Beyond the lyrical content, the musical elements of "Jet Black Heart" contribute
In conclusion, "Jet Black Heart" by 5 Seconds of Summer is far more than a typical pop song; it is a poignant and deeply resonant exploration of inner darkness, self-destructive tendencies, and the fragile yet persistent hope for human connection and redemption. Through its powerful central metaphor of the "jet black heart," its unflinching portrayal of internal turmoil, and its subtle yet potent message of vulnerability and potential transformation, the song resonates with anyone who has grappled with their own inner demons and the complexities of human relationships. It is a reminder that even in the deepest darkness, a flicker of hope can endure, and that true healing and connection often emerge from the courageous act of confronting and sharing our most vulnerable selves. "Jet Black Heart" stands as a testament to 5 Seconds of Summer's artistic growth, showcasing their capacity to delve into profound emotional territories and create music that is not only catchy and engaging but also deeply meaningful and emotionally resonant, solidifying their position as a band capable of capturing the complexities of the human experience."""
text="""Delving into the Abyss: A Deeper Exploration of Meaning in 5 Seconds of Summer's "Jet Black Heart"
text = """Delving into the Abyss: A Deeper Exploration of Meaning in 5 Seconds of Summer's "Jet Black Heart"
5 Seconds of Summer, initially perceived as purveyors of upbeat, radio-friendly pop-punk, embarked on a significant artistic evolution with their album Sounds Good Feels Good. Among its tracks, "Jet Black Heart" stands out as a powerful testament to this shift, moving beyond catchy melodies and embracing a darker, more emotionally complex sound. Released in 2015, the song transcends the typical themes of youthful exuberance and romantic angst, instead plunging into the depths of personal turmoil and the corrosive effects of inner darkness on interpersonal relationships. "Jet Black Heart" is not merely a song about heartbreak; it is a raw and vulnerable exploration of internal struggle, self-destructive patterns, and the precarious flicker of hope that persists even in the face of profound emotional chaos."""
Type="wav"
Type = "wav"
response = requests.post(
@ -36,11 +37,11 @@ response = requests.post(
"response_format": Type,
"stream": True,
},
stream=True
stream=True,
)
f=open(f"outputstream.{Type}","wb")
f = open(f"outputstream.{Type}", "wb")
for chunk in response.iter_content():
if chunk:
# Process streaming chunks
@ -56,7 +57,7 @@ response = requests.post(
"response_format": Type,
"stream": False,
},
stream=True
stream=True,
)
with open(f"outputnostream.{Type}", "wb") as f:

View file

@ -20,7 +20,7 @@ services:
# # Gradio UI service [Comment out everything below if you don't need it]
# gradio-ui:
# image: ghcr.io/remsky/kokoro-fastapi-ui:v0.2.0
# image: ghcr.io/remsky/kokoro-fastapi-ui:v${VERSION}
# # Uncomment below (and comment out above) to build from source instead of using the released image
# build:
# context: ../../ui

View file

@ -1,7 +1,7 @@
name: kokoro-tts-gpu
services:
kokoro-tts:
# image: ghcr.io/remsky/kokoro-fastapi-gpu:v0.2.0
# image: ghcr.io/remsky/kokoro-fastapi-gpu:v${VERSION}
build:
context: ../..
dockerfile: docker/gpu/Dockerfile
@ -24,7 +24,7 @@ services:
# # Gradio UI service
# gradio-ui:
# image: ghcr.io/remsky/kokoro-fastapi-ui:v0.2.0
# image: ghcr.io/remsky/kokoro-fastapi-ui:v${VERSION}
# # Uncomment below to build from source instead of using the released image
# # build:
# # context: ../../ui

View file

@ -11,11 +11,11 @@ from loguru import logger
def verify_files(model_path: str, config_path: str) -> bool:
"""Verify that model files exist and are valid.
Args:
model_path: Path to model file
config_path: Path to config file
Returns:
True if files exist and are valid
"""
@ -25,15 +25,15 @@ def verify_files(model_path: str, config_path: str) -> bool:
return False
if not os.path.exists(config_path):
return False
# Verify config file is valid JSON
with open(config_path) as f:
config = json.load(f)
# Check model file size (should be non-zero)
if os.path.getsize(model_path) == 0:
return False
return True
except Exception:
return False
@ -41,45 +41,45 @@ def verify_files(model_path: str, config_path: str) -> bool:
def download_model(output_dir: str) -> None:
"""Download model files from GitHub release.
Args:
output_dir: Directory to save model files
"""
try:
# Create output directory
os.makedirs(output_dir, exist_ok=True)
# Define file paths
model_file = "kokoro-v1_0.pth"
config_file = "config.json"
model_path = os.path.join(output_dir, model_file)
config_path = os.path.join(output_dir, config_file)
# Check if files already exist and are valid
if verify_files(model_path, config_path):
logger.info("Model files already exist and are valid")
return
logger.info("Downloading Kokoro v1.0 model files")
# GitHub release URLs (to be updated with v0.2.0 release)
base_url = "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.4"
model_url = f"{base_url}/{model_file}"
config_url = f"{base_url}/{config_file}"
# Download files
logger.info("Downloading model file...")
urlretrieve(model_url, model_path)
logger.info("Downloading config file...")
urlretrieve(config_url, config_path)
# Verify downloaded files
if not verify_files(model_path, config_path):
raise RuntimeError("Failed to verify downloaded files")
logger.info(f"✓ Model files prepared in {output_dir}")
except Exception as e:
logger.error(f"Failed to download model: {e}")
raise
@ -88,17 +88,15 @@ def download_model(output_dir: str) -> None:
def main():
"""Main entry point."""
import argparse
parser = argparse.ArgumentParser(description="Download Kokoro v1.0 model")
parser.add_argument(
"--output",
required=True,
help="Output directory for model files"
"--output", required=True, help="Output directory for model files"
)
args = parser.parse_args()
download_model(args.output)
if __name__ == "__main__":
main()
main()

View file

@ -123,7 +123,7 @@ def main():
with open(wells_path, "r", encoding="utf-8") as f:
full_text = f.read()
# Take first few paragraphs
text = " ".join(full_text.split("\n\n")[:2])
text = " ".join(full_text.split("\n\n")[1:3])
print("\nStarting TTS stream playback...")
print(f"Text length: {len(text)} characters")

View file

@ -1,6 +1,6 @@
[project]
name = "kokoro-fastapi"
version = "0.1.4"
version = "0.3.0"
description = "FastAPI TTS Service"
readme = "README.md"
requires-python = ">=3.10"
@ -31,10 +31,11 @@ dependencies = [
"matplotlib>=3.10.0",
"mutagen>=1.47.0",
"psutil>=6.1.1",
"kokoro @ git+https://github.com/hexgrad/kokoro.git@31a2b6337b8c1b1418ef68c48142328f640da938",
'misaki[en,ja,ko,zh]',
"spacy==3.7.2",
"en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl",
"espeakng-loader==0.2.4",
"kokoro==0.9.2",
"misaki[en,ja,ko,zh]==0.9.3",
"spacy==3.8.5",
"en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl",
"inflect>=7.5.0",
"phonemizer-fork>=3.3.2",
"av>=14.2.0",
@ -57,8 +58,8 @@ test = [
"pytest-cov==6.0.0",
"httpx==0.26.0",
"pytest-asyncio==0.25.3",
"openai>=1.59.6",
"tomli>=2.0.1",
"jinja2>=3.1.6"
]
[tool.uv]

46
scripts/fix_misaki.py Normal file
View file

@ -0,0 +1,46 @@
"""
Patch for misaki package to fix the EspeakWrapper.set_data_path issue.
"""
import importlib.util
import os
import sys
# Find the misaki package
try:
import misaki
misaki_path = os.path.dirname(misaki.__file__)
print(f"Found misaki package at: {misaki_path}")
except ImportError:
print("Misaki package not found. Make sure it's installed.")
sys.exit(1)
# Path to the espeak.py file
espeak_file = os.path.join(misaki_path, "espeak.py")
if not os.path.exists(espeak_file):
print(f"Could not find {espeak_file}")
sys.exit(1)
# Read the current content
with open(espeak_file, "r") as f:
content = f.read()
# Check if the problematic line exists
if "EspeakWrapper.set_data_path(espeakng_loader.get_data_path())" in content:
# Replace the problematic line
new_content = content.replace(
"EspeakWrapper.set_data_path(espeakng_loader.get_data_path())",
"# Fixed line to use data_path attribute instead of set_data_path method\n"
"EspeakWrapper.data_path = espeakng_loader.get_data_path()",
)
# Write the modified content back
with open(espeak_file, "w") as f:
f.write(new_content)
print(f"Successfully patched {espeak_file}")
else:
print(f"The problematic line was not found in {espeak_file}")
print("The file may have already been patched or the issue is different.")

View file

@ -1,138 +1,139 @@
import re
import subprocess
import tomli
from pathlib import Path
import tomli
def extract_dependency_info():
"""Extract version and commit hash for kokoro and misaki from pyproject.toml"""
"""Extract version for kokoro and misaki from pyproject.toml"""
with open("pyproject.toml", "rb") as f:
pyproject = tomli.load(f)
deps = pyproject["project"]["dependencies"]
info = {}
# Extract kokoro info
kokoro_found = False
misaki_found = False
for dep in deps:
if dep.startswith("kokoro @"):
# Extract version from the dependency string if available
version_match = re.search(r"kokoro @ git\+https://github\.com/hexgrad/kokoro\.git@", dep)
if version_match:
# If no explicit version, use v0.7.9 as shown in the README
version = "v0.7.9"
commit_match = re.search(r"@([a-f0-9]{7})", dep)
if commit_match:
info["kokoro"] = {
"version": version,
"commit": commit_match.group(1)
}
elif dep.startswith("misaki["):
# Extract version from the dependency string if available
version_match = re.search(r"misaki\[.*?\] @ git\+https://github\.com/hexgrad/misaki\.git@", dep)
if version_match:
# If no explicit version, use v0.7.9 as shown in the README
version = "v0.7.9"
commit_match = re.search(r"@([a-f0-9]{7})", dep)
if commit_match:
info["misaki"] = {
"version": version,
"commit": commit_match.group(1)
}
# Match kokoro==version
kokoro_match = re.match(r"^kokoro==(.+)$", dep)
if kokoro_match:
info["kokoro"] = {"version": kokoro_match.group(1)}
kokoro_found = True
# Match misaki[...] ==version or misaki==version
misaki_match = re.match(r"^misaki(?:\[.*?\])?==(.+)$", dep)
if misaki_match:
info["misaki"] = {"version": misaki_match.group(1)}
misaki_found = True
# Stop if both found
if kokoro_found and misaki_found:
break
if not kokoro_found:
raise ValueError("Kokoro version not found in pyproject.toml dependencies")
if not misaki_found:
raise ValueError("Misaki version not found in pyproject.toml dependencies")
return info
def run_pytest_with_coverage():
"""Run pytest with coverage and return the results"""
try:
# Run pytest with coverage
result = subprocess.run(
["pytest", "--cov=api", "-v"],
capture_output=True,
text=True,
check=True
["pytest", "--cov=api", "-v"], capture_output=True, text=True, check=True
)
# Extract test results
test_output = result.stdout
passed_tests = len(re.findall(r"PASSED", test_output))
# Extract coverage from .coverage file
coverage_output = subprocess.run(
["coverage", "report"],
capture_output=True,
text=True,
check=True
["coverage", "report"], capture_output=True, text=True, check=True
).stdout
# Extract total coverage percentage
coverage_match = re.search(r"TOTAL\s+\d+\s+\d+\s+(\d+)%", coverage_output)
coverage_percentage = coverage_match.group(1) if coverage_match else "0"
return passed_tests, coverage_percentage
except subprocess.CalledProcessError as e:
print(f"Error running tests: {e}")
print(f"Output: {e.output}")
return 0, "0"
def update_readme_badges(passed_tests, coverage_percentage, dep_info):
"""Update the badges in the README file"""
readme_path = Path("README.md")
if not readme_path.exists():
print("README.md not found")
return False
content = readme_path.read_text()
# Update tests badge
content = re.sub(
r'!\[Tests\]\(https://img\.shields\.io/badge/tests-\d+%20passed-[a-zA-Z]+\)',
f'![Tests](https://img.shields.io/badge/tests-{passed_tests}%20passed-darkgreen)',
content
r"!\[Tests\]\(https://img\.shields\.io/badge/tests-\d+%20passed-[a-zA-Z]+\)",
f"![Tests](https://img.shields.io/badge/tests-{passed_tests}%20passed-darkgreen)",
content,
)
# Update coverage badge
content = re.sub(
r'!\[Coverage\]\(https://img\.shields\.io/badge/coverage-\d+%25-[a-zA-Z]+\)',
f'![Coverage](https://img.shields.io/badge/coverage-{coverage_percentage}%25-tan)',
content
r"!\[Coverage\]\(https://img\.shields\.io/badge/coverage-\d+%25-[a-zA-Z]+\)",
f"![Coverage](https://img.shields.io/badge/coverage-{coverage_percentage}%25-tan)",
content,
)
# Update kokoro badge
if "kokoro" in dep_info:
# Find badge like kokoro-v0.9.2::abcdefg-BB5420 or kokoro-v0.9.2-BB5420
kokoro_version = dep_info["kokoro"]["version"]
content = re.sub(
r'!\[Kokoro\]\(https://img\.shields\.io/badge/kokoro-[^)]+\)',
f'![Kokoro](https://img.shields.io/badge/kokoro-{dep_info["kokoro"]["version"]}::{dep_info["kokoro"]["commit"]}-BB5420)',
content
r"(!\[Kokoro\]\(https://img\.shields\.io/badge/kokoro-)[^)-]+(-BB5420\))",
lambda m: f"{m.group(1)}{kokoro_version}{m.group(2)}",
content,
)
# Update misaki badge
if "misaki" in dep_info:
# Find badge like misaki-v0.9.3::abcdefg-B8860B or misaki-v0.9.3-B8860B
misaki_version = dep_info["misaki"]["version"]
content = re.sub(
r'!\[Misaki\]\(https://img\.shields\.io/badge/misaki-[^)]+\)',
f'![Misaki](https://img.shields.io/badge/misaki-{dep_info["misaki"]["version"]}::{dep_info["misaki"]["commit"]}-B8860B)',
content
r"(!\[Misaki\]\(https://img\.shields\.io/badge/misaki-)[^)-]+(-B8860B\))",
lambda m: f"{m.group(1)}{misaki_version}{m.group(2)}",
content,
)
readme_path.write_text(content)
return True
def main():
# Get dependency info
dep_info = extract_dependency_info()
# Run tests and get coverage
passed_tests, coverage_percentage = run_pytest_with_coverage()
# Update badges
if update_readme_badges(passed_tests, coverage_percentage, dep_info):
print(f"Updated badges:")
print(f"- Tests: {passed_tests} passed")
print(f"- Coverage: {coverage_percentage}%")
if "kokoro" in dep_info:
print(f"- Kokoro: {dep_info['kokoro']['version']}::{dep_info['kokoro']['commit']}")
print(f"- Kokoro: {dep_info['kokoro']['version']}")
if "misaki" in dep_info:
print(f"- Misaki: {dep_info['misaki']['version']}::{dep_info['misaki']['commit']}")
print(f"- Misaki: {dep_info['misaki']['version']}")
else:
print("Failed to update badges")
if __name__ == "__main__":
main()
main()
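
The badge regexes above replace only the version segment between the label and the color code; a standalone check of the kokoro substitution (sample line invented for illustration):

    import re

    line = "![Kokoro](https://img.shields.io/badge/kokoro-v0.9.1-BB5420)"
    print(re.sub(
        r"(!\[Kokoro\]\(https://img\.shields\.io/badge/kokoro-)[^)-]+(-BB5420\))",
        lambda m: f"{m.group(1)}v0.9.2{m.group(2)}",
        line,
    ))  # ![Kokoro](https://img.shields.io/badge/kokoro-v0.9.2-BB5420)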

234
scripts/update_version.py Executable file
View file

@ -0,0 +1,234 @@
#!/usr/bin/env python3
"""
Version Update Script
This script reads the version from the VERSION file and updates references
in pyproject.toml, the Helm chart, and README.md.
"""
import re
from pathlib import Path
# Get the project root directory
ROOT_DIR = Path(__file__).parent.parent
# --- Configuration ---
VERSION_FILE = ROOT_DIR / "VERSION"
PYPROJECT_FILE = ROOT_DIR / "pyproject.toml"
HELM_CHART_FILE = ROOT_DIR / "charts" / "kokoro-fastapi" / "Chart.yaml"
README_FILE = ROOT_DIR / "README.md"
# --- End Configuration ---
def update_pyproject(version: str):
"""Updates the version in pyproject.toml"""
if not PYPROJECT_FILE.exists():
print(f"Skipping: {PYPROJECT_FILE} not found.")
return
try:
content = PYPROJECT_FILE.read_text()
# Regex to find and capture current version = "X.Y.Z" under [project]
pattern = r'(^\[project\]\s*(?:.*\s)*?version\s*=\s*)"([^"]+)"'
match = re.search(pattern, content, flags=re.MULTILINE)
if not match:
print(f"Warning: Version pattern not found in {PYPROJECT_FILE}")
return
current_version = match.group(2)
if current_version == version:
print(f"Already up-to-date: {PYPROJECT_FILE} (version {version})")
else:
# Perform replacement
new_content = re.sub(
pattern, rf'\1"{version}"', content, count=1, flags=re.MULTILINE
)
PYPROJECT_FILE.write_text(new_content)
print(f"Updated {PYPROJECT_FILE} from {current_version} to {version}")
except Exception as e:
print(f"Error processing {PYPROJECT_FILE}: {e}")
def update_helm_chart(version: str):
"""Updates the version and appVersion in the Helm chart"""
if not HELM_CHART_FILE.exists():
print(f"Skipping: {HELM_CHART_FILE} not found.")
return
try:
content = HELM_CHART_FILE.read_text()
original_content = content
updated_count = 0
# Update 'version:' line (unquoted)
# Looks for 'version:' followed by optional whitespace and the version number
version_pattern = r"^(version:\s*)(\S+)"
current_version_match = re.search(version_pattern, content, flags=re.MULTILINE)
if current_version_match and current_version_match.group(2) != version:
content = re.sub(
version_pattern,
rf"\g<1>{version}",
content,
count=1,
flags=re.MULTILINE,
)
print(
f"Updating 'version' in {HELM_CHART_FILE} from {current_version_match.group(2)} to {version}"
)
updated_count += 1
elif current_version_match:
print(f"Already up-to-date: 'version' in {HELM_CHART_FILE} is {version}")
else:
print(f"Warning: 'version:' pattern not found in {HELM_CHART_FILE}")
# Update 'appVersion:' line (quoted or unquoted)
# Looks for 'appVersion:' followed by optional whitespace, optional quote, the version, optional quote
app_version_pattern = r"^(appVersion:\s*)(\"?)([^\"\s]+)(\"?)"
current_app_version_match = re.search(
app_version_pattern, content, flags=re.MULTILINE
)
if current_app_version_match:
prefix = current_app_version_match.group(1)  # e.g., 'appVersion: '
opening_quote = current_app_version_match.group(2) # e.g., '"' or ''
current_app_ver = current_app_version_match.group(3) # e.g., '0.2.0'
closing_quote = current_app_version_match.group(4) # e.g., '"' or ''
# Check if quotes were consistent (both present or both absent)
if opening_quote != closing_quote:
print(
f"Warning: Inconsistent quotes found for appVersion in {HELM_CHART_FILE}. Skipping update for this line."
)
elif (
current_app_ver == version and opening_quote == '"'
): # Check if already correct *and* quoted
print(
f"Already up-to-date: 'appVersion' in {HELM_CHART_FILE} is \"{version}\""
)
else:
# Always replace with the quoted version
replacement = f'{prefix}"{version}"'  # ensure quotes
original_display = f"{opening_quote}{current_app_ver}{closing_quote}" # How it looked before
target_display = f'"{version}"' # How it should look
# Only report update if the displayed value actually changes
if original_display != target_display:
content = re.sub(
app_version_pattern,
replacement,
content,
count=1,
flags=re.MULTILINE,
)
print(
f"Updating 'appVersion' in {HELM_CHART_FILE} from {original_display} to {target_display}"
)
updated_count += 1
else:
# The value already matches the target; only report it as up-to-date when
# nothing else changed, to avoid a double message if 'version' was also bumped above.
if not (content != original_content and updated_count > 0):
print(
f"Already up-to-date: 'appVersion' in {HELM_CHART_FILE} is {target_display}"
)
else:
print(f"Warning: 'appVersion:' pattern not found in {HELM_CHART_FILE}")
# Write back only if changes were made
if content != original_content:
HELM_CHART_FILE.write_text(content)
# Confirmation message printed above during the specific update
elif updated_count == 0 and current_version_match and current_app_version_match:
# If no updates were made but patterns were found, confirm it's up-to-date overall
print(f"Already up-to-date: {HELM_CHART_FILE} (version {version})")
except Exception as e:
print(f"Error processing {HELM_CHART_FILE}: {e}")
def update_readme(version_with_v: str):
"""Updates Docker image tags in README.md"""
if not README_FILE.exists():
print(f"Skipping: {README_FILE} not found.")
return
try:
content = README_FILE.read_text()
# Regex to find and capture current ghcr.io/.../kokoro-fastapi-(cpu|gpu):vX.Y.Z
pattern = r"(ghcr\.io/remsky/kokoro-fastapi-(?:cpu|gpu)):(v\d+\.\d+\.\d+)"
matches = list(re.finditer(pattern, content)) # Find all occurrences
if not matches:
print(f"Warning: Docker image tag pattern not found in {README_FILE}")
else:
update_needed = False
for match in matches:
current_tag = match.group(2)
if current_tag != version_with_v:
update_needed = True
break  # one mismatch is enough to trigger an update
if update_needed:
# Perform replacement on all occurrences
new_content = re.sub(pattern, rf"\1:{version_with_v}", content)
README_FILE.write_text(new_content)
print(f"Updated Docker image tags in {README_FILE} to {version_with_v}")
else:
print(
f"Already up-to-date: Docker image tags in {README_FILE} (version {version_with_v})"
)
# Warn if any ':latest' tags remain in the README
if ":latest" in content:
print(
f"Warning: Found ':latest' tag in {README_FILE}. Consider updating manually if needed."
)
except Exception as e:
print(f"Error processing {README_FILE}: {e}")
def main():
# Read the version from the VERSION file
if not VERSION_FILE.exists():
print(f"Error: {VERSION_FILE} not found.")
return
try:
version = VERSION_FILE.read_text().strip()
if not re.match(r"^\d+\.\d+\.\d+$", version):
print(
f"Error: Invalid version format '{version}' in {VERSION_FILE}. Expected X.Y.Z"
)
return
except Exception as e:
print(f"Error reading {VERSION_FILE}: {e}")
return
print(f"Read version: {version} from {VERSION_FILE}")
print("-" * 20)
# Prepare versions (with and without 'v')
version_plain = version
version_with_v = f"v{version}"
# Update files
update_pyproject(version_plain)
update_helm_chart(version_plain)
update_readme(version_with_v)
print("-" * 20)
print("Version update script finished.")
if __name__ == "__main__":
main()
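
The pyproject pattern anchors on the [project] table before matching version, so stray version keys elsewhere are left alone; a standalone check with invented sample content:

    import re

    sample = '[project]\nname = "kokoro-fastapi"\nversion = "0.2.0"\n'
    pattern = r'(^\[project\]\s*(?:.*\s)*?version\s*=\s*)"([^"]+)"'
    print(re.sub(pattern, r'\g<1>"0.3.0"', sample, count=1, flags=re.MULTILINE))
    # ...version = "0.3.0"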

View file

@ -1,49 +0,0 @@
{
"document": "doc.report.command",
"version": "ov/command/slim/1.1",
"engine": "linux/amd64|ALP|x.1.42.2|29e62e7836de7b1004607c51c502537ffe1969f0|2025-01-16_07:48:54AM|x",
"containerized": false,
"host_distro": {
"name": "Ubuntu",
"version": "22.04",
"display_name": "Ubuntu 22.04.5 LTS"
},
"type": "slim",
"state": "error",
"target_reference": "kokoro-fastapi:latest",
"system": {
"type": "",
"release": "",
"distro": {
"name": "",
"version": "",
"display_name": ""
}
},
"source_image": {
"identity": {
"id": ""
},
"size": 0,
"size_human": "",
"create_time": "",
"architecture": "",
"container_entry": {
"exe_path": ""
}
},
"minified_image_size": 0,
"minified_image_size_human": "",
"minified_image": "",
"minified_image_id": "",
"minified_image_digest": "",
"minified_image_has_data": false,
"minified_by": 0,
"artifact_location": "",
"container_report_name": "",
"seccomp_profile_name": "",
"apparmor_profile_name": "",
"image_stack": null,
"image_created": false,
"image_build_engine": ""
}

View file

@ -10,9 +10,17 @@ export PYTHONPATH=$PROJECT_ROOT:$PROJECT_ROOT/api
export MODEL_DIR=src/models
export VOICES_DIR=src/voices/v1_0
export WEB_PLAYER_PATH=$PROJECT_ROOT/web
# Set the espeak-ng data path for your installation
export ESPEAK_DATA_PATH=/usr/lib/x86_64-linux-gnu/espeak-ng-data
# Run FastAPI with CPU extras using uv run
# Note: espeak-ng may still require manual installation.
uv pip install -e ".[cpu]"
uv run --no-sync python docker/scripts/download_model.py --output api/src/models/v1_0
# Apply the misaki patch to fix a possible EspeakWrapper issue in older versions
# echo "Applying misaki patch..."
# python scripts/fix_misaki.py
# Start the server
uv run --no-sync uvicorn api.src.main:app --host 0.0.0.0 --port 8880
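
A quick pre-flight check (a sketch) that the exported espeak-ng data path actually exists; the default below mirrors the path used in the script:

    import os

    data_path = os.environ.get(
        "ESPEAK_DATA_PATH", "/usr/lib/x86_64-linux-gnu/espeak-ng-data"
    )
    print(f"{data_path} exists: {os.path.isdir(data_path)}")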

View file

@ -3,13 +3,6 @@
# Get project root directory
PROJECT_ROOT=$(pwd)
# Create mps-specific venv directory
VENV_DIR="$PROJECT_ROOT/.venv-mps"
if [ ! -d "$VENV_DIR" ]; then
echo "Creating MPS-specific virtual environment..."
python3 -m venv "$VENV_DIR"
fi
# Set other environment variables
export USE_GPU=true
export USE_ONNX=false
@ -18,18 +11,11 @@ export MODEL_DIR=src/models
export VOICES_DIR=src/voices/v1_0
export WEB_PLAYER_PATH=$PROJECT_ROOT/web
# Set environment variables
export USE_GPU=true
export USE_ONNX=false
export PYTHONPATH=$PROJECT_ROOT:$PROJECT_ROOT/api
export MODEL_DIR=src/models
export VOICES_DIR=src/voices/v1_0
export WEB_PLAYER_PATH=$PROJECT_ROOT/web
export DEVICE_TYPE=mps
# Enable MPS fallback for unsupported operations
export PYTORCH_ENABLE_MPS_FALLBACK=1
# Run FastAPI with GPU extras using uv run
uv pip install -e .
uv run --no-sync python docker/scripts/download_model.py --output api/src/models/v1_0
uv run --no-sync uvicorn api.src.main:app --host 0.0.0.0 --port 8880
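
With PYTORCH_ENABLE_MPS_FALLBACK=1 set, operators without Metal kernels fall back to CPU instead of raising; a minimal device probe (assumes torch is installed):

    import torch

    # True on Apple Silicon builds of PyTorch with Metal support
    use_mps = torch.backends.mps.is_available()
    print("Using device:", "mps" if use_mps else "cpu")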

View file

@ -1,6 +1,7 @@
import pytest
from unittest.mock import AsyncMock, Mock
import pytest
from api.src.services.tts_service import TTSService
@ -30,17 +31,22 @@ async def mock_tts_service(mock_model_manager, mock_voice_manager):
@pytest.fixture(autouse=True)
async def setup_mocks(monkeypatch, mock_model_manager, mock_voice_manager, mock_tts_service):
async def setup_mocks(
monkeypatch, mock_model_manager, mock_voice_manager, mock_tts_service
):
"""Setup global mocks for UI tests"""
async def mock_get_model():
return mock_model_manager
async def mock_get_voice():
return mock_voice_manager
async def mock_create_service():
return mock_tts_service
monkeypatch.setattr("api.src.inference.model_manager.get_manager", mock_get_model)
monkeypatch.setattr("api.src.inference.voice_manager.get_manager", mock_get_voice)
monkeypatch.setattr("api.src.services.tts_service.TTSService.create", mock_create_service)
monkeypatch.setattr(
"api.src.services.tts_service.TTSService.create", mock_create_service
)
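
Because the patched factories are awaited by the application code, each replacement must itself be an async callable; a minimal self-contained sketch of the pattern:

    import asyncio

    async def mock_get_model():
        return "fake-model-manager"

    # e.g. monkeypatch.setattr("api.src.inference.model_manager.get_manager", mock_get_model)
    assert asyncio.run(mock_get_model()) == "fake-model-manager"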

View file

@ -1,4 +1,4 @@
from unittest.mock import patch, mock_open
from unittest.mock import mock_open, patch
import pytest
import requests
@ -59,9 +59,11 @@ def test_check_api_status_connection_error():
def test_text_to_speech_success(mock_response, tmp_path):
"""Test successful speech generation"""
with patch("requests.post", return_value=mock_response({})), patch(
"ui.lib.api.OUTPUTS_DIR", str(tmp_path)
), patch("builtins.open", mock_open()) as mock_file:
with (
patch("requests.post", return_value=mock_response({})),
patch("ui.lib.api.OUTPUTS_DIR", str(tmp_path)),
patch("builtins.open", mock_open()) as mock_file,
):
result = api.text_to_speech("test text", "voice1", "mp3", 1.0)
assert result is not None
@ -116,9 +118,11 @@ def test_text_to_speech_api_params(mock_response, tmp_path):
]
for input_voice, expected_voice in test_cases:
with patch("requests.post") as mock_post, patch(
"ui.lib.api.OUTPUTS_DIR", str(tmp_path)
), patch("builtins.open", mock_open()):
with (
patch("requests.post") as mock_post,
patch("ui.lib.api.OUTPUTS_DIR", str(tmp_path)),
patch("builtins.open", mock_open()),
):
mock_post.return_value = mock_response({})
api.text_to_speech("test text", input_voice, "mp3", 1.5)
@ -149,11 +153,15 @@ def test_text_to_speech_output_filename(mock_response, tmp_path):
]
for input_voice, filename_check in test_cases:
with patch("requests.post", return_value=mock_response({})), patch(
"ui.lib.api.OUTPUTS_DIR", str(tmp_path)
), patch("builtins.open", mock_open()) as mock_file:
with (
patch("requests.post", return_value=mock_response({})),
patch("ui.lib.api.OUTPUTS_DIR", str(tmp_path)),
patch("builtins.open", mock_open()) as mock_file,
):
result = api.text_to_speech("test text", input_voice, "mp3", 1.0)
assert result is not None
assert filename_check(result), f"Expected voice pattern not found in filename: {result}"
assert filename_check(result), (
f"Expected voice pattern not found in filename: {result}"
)
mock_file.assert_called_once()
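
The parenthesized multi-context with blocks adopted throughout these tests are a Python 3.10+ form (matching the project's requires-python); a minimal self-contained example:

    import os
    from unittest.mock import mock_open, patch

    with (
        patch("builtins.open", mock_open(read_data="hello")),
        patch("os.path.exists", return_value=True),
    ):
        assert open("anything.txt").read() == "hello"
        assert os.path.exists("anything.txt")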

View file

@ -1,9 +1,9 @@
import gradio as gr
import pytest
from ui.lib.config import AUDIO_FORMATS
from ui.lib.components.model import create_model_column
from ui.lib.components.output import create_output_column
from ui.lib.config import AUDIO_FORMATS
def test_create_model_column_structure():

View file

@ -15,8 +15,9 @@ def mock_dirs(tmp_path):
inputs_dir.mkdir()
outputs_dir.mkdir()
with patch("ui.lib.files.INPUTS_DIR", str(inputs_dir)), patch(
"ui.lib.files.OUTPUTS_DIR", str(outputs_dir)
with (
patch("ui.lib.files.INPUTS_DIR", str(inputs_dir)),
patch("ui.lib.files.OUTPUTS_DIR", str(outputs_dir)),
):
yield inputs_dir, outputs_dir

View file

@ -62,8 +62,9 @@ def test_interface_html_links():
def test_update_status_available(mock_timer):
"""Test status update when service is available"""
voices = ["voice1", "voice2"]
with patch("ui.lib.api.check_api_status", return_value=(True, voices)), patch(
"gradio.Timer", return_value=mock_timer
with (
patch("ui.lib.api.check_api_status", return_value=(True, voices)),
patch("gradio.Timer", return_value=mock_timer),
):
demo = create_interface()
@ -81,8 +82,9 @@ def test_update_status_available(mock_timer):
def test_update_status_unavailable(mock_timer):
"""Test status update when service is unavailable"""
with patch("ui.lib.api.check_api_status", return_value=(False, [])), patch(
"gradio.Timer", return_value=mock_timer
with (
patch("ui.lib.api.check_api_status", return_value=(False, [])),
patch("gradio.Timer", return_value=mock_timer),
):
demo = create_interface()
update_fn = mock_timer.events[0].fn
@ -97,9 +99,10 @@ def test_update_status_unavailable(mock_timer):
def test_update_status_error(mock_timer):
"""Test status update when an error occurs"""
with patch(
"ui.lib.api.check_api_status", side_effect=Exception("Test error")
), patch("gradio.Timer", return_value=mock_timer):
with (
patch("ui.lib.api.check_api_status", side_effect=Exception("Test error")),
patch("gradio.Timer", return_value=mock_timer),
):
demo = create_interface()
update_fn = mock_timer.events[0].fn
@ -113,8 +116,9 @@ def test_update_status_error(mock_timer):
def test_timer_configuration(mock_timer):
"""Test timer configuration"""
with patch("ui.lib.api.check_api_status", return_value=(False, [])), patch(
"gradio.Timer", return_value=mock_timer
with (
patch("ui.lib.api.check_api_status", return_value=(False, [])),
patch("gradio.Timer", return_value=mock_timer),
):
demo = create_interface()

View file

@ -1,6 +1,6 @@
import os
import datetime
from typing import List, Tuple, Optional
import os
from typing import List, Optional, Tuple
import requests

View file

@ -11,12 +11,10 @@ def create_input_column(disable_local_saving: bool = False) -> Tuple[gr.Column,
text_input = gr.Textbox(
label="Text to speak", placeholder="Enter text here...", lines=4
)
# Always show file upload but handle differently based on disable_local_saving
file_upload = gr.File(
label="Upload Text File (.txt)", file_types=[".txt"]
)
file_upload = gr.File(label="Upload Text File (.txt)", file_types=[".txt"])
if not disable_local_saving:
# Show full interface with tabs when saving is enabled
with gr.Tabs() as tabs:
@ -24,7 +22,9 @@ def create_input_column(disable_local_saving: bool = False) -> Tuple[gr.Column,
tabs.selected = 0
# Direct Input Tab
with gr.TabItem("Direct Input"):
text_submit_direct = gr.Button("Generate Speech", variant="primary", size="lg")
text_submit_direct = gr.Button(
"Generate Speech", variant="primary", size="lg"
)
# File Input Tab
with gr.TabItem("From File"):
@ -48,7 +48,9 @@ def create_input_column(disable_local_saving: bool = False) -> Tuple[gr.Column,
)
else:
# Just show the generate button when saving is disabled
text_submit_direct = gr.Button("Generate Speech", variant="primary", size="lg")
text_submit_direct = gr.Button(
"Generate Speech", variant="primary", size="lg"
)
tabs = None
input_files_list = None
file_preview = None

View file

@ -1,4 +1,4 @@
from typing import Tuple, Optional
from typing import Optional, Tuple
import gradio as gr

View file

@ -12,7 +12,7 @@ def create_output_column(disable_local_saving: bool = False) -> Tuple[gr.Column,
audio_output = gr.Audio(
label="Generated Speech",
type="filepath",
waveform_options={"waveform_color": "#4C87AB"}
waveform_options={"waveform_color": "#4C87AB"},
)
# Create file-related components with visible=False when local saving is disabled
@ -26,14 +26,14 @@ def create_output_column(disable_local_saving: bool = False) -> Tuple[gr.Column,
)
play_btn = gr.Button(
"▶️ Play Selected",
"▶️ Play Selected",
size="sm",
visible=not disable_local_saving,
)
selected_audio = gr.Audio(
label="Selected Output",
type="filepath",
label="Selected Output",
type="filepath",
visible=False, # Always initially hidden
)

View file

@ -1,8 +1,8 @@
import os
import datetime
from typing import List, Tuple, Optional
import os
from typing import List, Optional, Tuple
from .config import INPUTS_DIR, OUTPUTS_DIR, AUDIO_FORMATS
from .config import AUDIO_FORMATS, INPUTS_DIR, OUTPUTS_DIR
def list_input_files() -> List[str]:

View file

@ -58,17 +58,21 @@ def setup_event_handlers(components: dict, disable_local_saving: bool = False):
def handle_file_upload(file):
if file is None:
return "" if disable_local_saving else [gr.update(choices=files.list_input_files())]
return (
""
if disable_local_saving
else [gr.update(choices=files.list_input_files())]
)
try:
# Read the file content
with open(file.name, 'r', encoding='utf-8') as f:
with open(file.name, "r", encoding="utf-8") as f:
text_content = f.read()
if disable_local_saving:
# When saving is disabled, put content directly in text input
# Normalize whitespace by replacing newlines with spaces
normalized_text = ' '.join(text_content.split())
normalized_text = " ".join(text_content.split())
return normalized_text
else:
# When saving is enabled, save file and update dropdown
@ -88,7 +92,11 @@ def setup_event_handlers(components: dict, disable_local_saving: bool = False):
except Exception as e:
print(f"Error handling file: {e}")
return "" if disable_local_saving else [gr.update(choices=files.list_input_files())]
return (
""
if disable_local_saving
else [gr.update(choices=files.list_input_files())]
)
def generate_from_text(text, voice, format, speed):
"""Generate speech from direct text input"""
@ -104,7 +112,7 @@ def setup_event_handlers(components: dict, disable_local_saving: bool = False):
# Only save text if local saving is enabled
if not disable_local_saving:
files.save_text(text)
result = api.text_to_speech(text, voice, format, speed)
if result is None:
gr.Warning("Failed to generate speech. Please try again.")
@ -203,7 +211,11 @@ def setup_event_handlers(components: dict, disable_local_saving: bool = False):
components["input"]["file_upload"].upload(
fn=handle_file_upload,
inputs=[components["input"]["file_upload"]],
outputs=[components["input"]["text_input"] if disable_local_saving else components["input"]["file_select"]],
outputs=[
components["input"]["text_input"]
if disable_local_saving
else components["input"]["file_select"]
],
)
if components["output"]["play_btn"] is not None:

View file

@ -1,9 +1,10 @@
import gradio as gr
import os
import gradio as gr
from . import api
from .handlers import setup_event_handlers
from .components import create_input_column, create_model_column, create_output_column
from .handlers import setup_event_handlers
def create_interface():