mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
Merge branch 'master' into fondoger/silence-tags
This commit is contained in:
commit
5dbf2e2e4b
55 changed files with 1712 additions and 870 deletions
87
.github/workflows/build-push.yml
vendored
87
.github/workflows/build-push.yml
vendored
|
@ -1,87 +0,0 @@
|
||||||
name: Docker Build and push
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
tags: [ 'v*.*.*' ]
|
|
||||||
paths-ignore:
|
|
||||||
- '**.md'
|
|
||||||
- 'docs/**'
|
|
||||||
workflow_dispatch:
|
|
||||||
inputs:
|
|
||||||
version:
|
|
||||||
description: 'Version to build and publish (e.g. v0.2.0)'
|
|
||||||
required: true
|
|
||||||
type: string
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
prepare-release:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
outputs:
|
|
||||||
version: ${{ steps.get-version.outputs.version }}
|
|
||||||
is_prerelease: ${{ steps.check-prerelease.outputs.is_prerelease }}
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Get version
|
|
||||||
id: get-version
|
|
||||||
run: |
|
|
||||||
if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then
|
|
||||||
echo "version=${{ inputs.version }}" >> $GITHUB_OUTPUT
|
|
||||||
else
|
|
||||||
echo "version=$(cat VERSION)" >> $GITHUB_OUTPUT
|
|
||||||
fi
|
|
||||||
|
|
||||||
- name: Check if prerelease
|
|
||||||
id: check-prerelease
|
|
||||||
run: |
|
|
||||||
echo "is_prerelease=${{ contains(steps.get-version.outputs.version, '-pre') }}" >> $GITHUB_OUTPUT
|
|
||||||
|
|
||||||
build-images:
|
|
||||||
needs: prepare-release
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
permissions:
|
|
||||||
packages: write
|
|
||||||
env:
|
|
||||||
DOCKER_BUILDKIT: 1
|
|
||||||
BUILDKIT_STEP_LOG_MAX_SIZE: 10485760
|
|
||||||
# This environment variable will override the VERSION variable in your HCL file.
|
|
||||||
VERSION: ${{ needs.prepare-release.outputs.version }}
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Free disk space
|
|
||||||
run: |
|
|
||||||
echo "Listing current disk space"
|
|
||||||
df -h
|
|
||||||
echo "Cleaning up disk space..."
|
|
||||||
sudo rm -rf /usr/share/dotnet
|
|
||||||
sudo rm -rf /usr/local/lib/android
|
|
||||||
sudo rm -rf /opt/ghc
|
|
||||||
sudo rm -rf /opt/hostedtoolcache
|
|
||||||
docker system prune -af
|
|
||||||
echo "Disk space after cleanup"
|
|
||||||
df -h
|
|
||||||
|
|
||||||
- name: Set up QEMU
|
|
||||||
uses: docker/setup-qemu-action@v2
|
|
||||||
|
|
||||||
- name: Set up Docker Buildx
|
|
||||||
uses: docker/setup-buildx-action@v2
|
|
||||||
with:
|
|
||||||
driver-opts: |
|
|
||||||
image=moby/buildkit:latest
|
|
||||||
network=host
|
|
||||||
|
|
||||||
- name: Log in to GitHub Container Registry
|
|
||||||
uses: docker/login-action@v2
|
|
||||||
with:
|
|
||||||
registry: ghcr.io
|
|
||||||
username: ${{ github.actor }}
|
|
||||||
password: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
|
|
||||||
- name: Build and push images
|
|
||||||
run: |
|
|
||||||
# No need to override VERSION via --set; the env var does the job.
|
|
||||||
docker buildx bake --push
|
|
102
.github/workflows/docker-publish.yml
vendored
102
.github/workflows/docker-publish.yml
vendored
|
@ -1,102 +0,0 @@
|
||||||
name: Docker Build and Publish
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
tags: [ 'v*.*.*' ]
|
|
||||||
paths-ignore:
|
|
||||||
- '**.md'
|
|
||||||
- 'docs/**'
|
|
||||||
workflow_dispatch:
|
|
||||||
inputs:
|
|
||||||
version:
|
|
||||||
description: 'Version to release (e.g. v0.2.0)'
|
|
||||||
required: true
|
|
||||||
type: string
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
prepare-release:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
outputs:
|
|
||||||
version: ${{ steps.get-version.outputs.version }}
|
|
||||||
is_prerelease: ${{ steps.check-prerelease.outputs.is_prerelease }}
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Get version
|
|
||||||
id: get-version
|
|
||||||
run: |
|
|
||||||
if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then
|
|
||||||
echo "version=${{ inputs.version }}" >> $GITHUB_OUTPUT
|
|
||||||
else
|
|
||||||
echo "version=$(cat VERSION)" >> $GITHUB_OUTPUT
|
|
||||||
fi
|
|
||||||
|
|
||||||
- name: Check if prerelease
|
|
||||||
id: check-prerelease
|
|
||||||
run: echo "is_prerelease=${{ contains(steps.get-version.outputs.version, '-pre') }}" >> $GITHUB_OUTPUT
|
|
||||||
|
|
||||||
build-images:
|
|
||||||
needs: prepare-release
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
permissions:
|
|
||||||
packages: write
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Free disk space
|
|
||||||
run: |
|
|
||||||
echo "Listing current disk space"
|
|
||||||
df -h
|
|
||||||
echo "Cleaning up disk space..."
|
|
||||||
sudo rm -rf /usr/share/dotnet
|
|
||||||
sudo rm -rf /usr/local/lib/android
|
|
||||||
sudo rm -rf /opt/ghc
|
|
||||||
sudo rm -rf /opt/hostedtoolcache
|
|
||||||
docker system prune -af
|
|
||||||
echo "Disk space after cleanup"
|
|
||||||
df -h
|
|
||||||
|
|
||||||
- name: Set up QEMU
|
|
||||||
uses: docker/setup-qemu-action@v2
|
|
||||||
|
|
||||||
- name: Set up Docker Buildx
|
|
||||||
uses: docker/setup-buildx-action@v2
|
|
||||||
with:
|
|
||||||
driver-opts: |
|
|
||||||
image=moby/buildkit:latest
|
|
||||||
network=host
|
|
||||||
|
|
||||||
- name: Log in to GitHub Container Registry
|
|
||||||
uses: docker/login-action@v2
|
|
||||||
with:
|
|
||||||
registry: ghcr.io
|
|
||||||
username: ${{ github.actor }}
|
|
||||||
password: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
|
|
||||||
- name: Build and push images
|
|
||||||
env:
|
|
||||||
DOCKER_BUILDKIT: 1
|
|
||||||
BUILDKIT_STEP_LOG_MAX_SIZE: 10485760
|
|
||||||
VERSION: ${{ needs.prepare-release.outputs.version }}
|
|
||||||
run: docker buildx bake --push
|
|
||||||
|
|
||||||
create-release:
|
|
||||||
needs: [prepare-release, build-images]
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
permissions:
|
|
||||||
contents: write
|
|
||||||
steps:
|
|
||||||
- name: Checkout repository
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
with:
|
|
||||||
fetch-depth: 0
|
|
||||||
|
|
||||||
- name: Create Release
|
|
||||||
uses: softprops/action-gh-release@v1
|
|
||||||
with:
|
|
||||||
tag_name: ${{ needs.prepare-release.outputs.version }}
|
|
||||||
generate_release_notes: true
|
|
||||||
draft: true
|
|
||||||
prerelease: ${{ needs.prepare-release.outputs.is_prerelease }}
|
|
110
.github/workflows/release.yml
vendored
Normal file
110
.github/workflows/release.yml
vendored
Normal file
|
@ -0,0 +1,110 @@
|
||||||
|
name: Create Release and Publish Docker Images
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- release # Trigger when commits are pushed to the release branch (e.g., after merging master)
|
||||||
|
paths-ignore:
|
||||||
|
- '**.md'
|
||||||
|
- 'docs/**'
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
prepare-release:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
outputs:
|
||||||
|
version: ${{ steps.get-version.outputs.version }}
|
||||||
|
version_tag: ${{ steps.get-version.outputs.version_tag }}
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Get version from VERSION file
|
||||||
|
id: get-version
|
||||||
|
run: |
|
||||||
|
VERSION_PLAIN=$(cat VERSION)
|
||||||
|
echo "version=${VERSION_PLAIN}" >> $GITHUB_OUTPUT
|
||||||
|
echo "version_tag=v${VERSION_PLAIN}" >> $GITHUB_OUTPUT # Add 'v' prefix for tag
|
||||||
|
|
||||||
|
build-images:
|
||||||
|
needs: prepare-release
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
packages: write # Needed to push images to GHCR
|
||||||
|
env:
|
||||||
|
DOCKER_BUILDKIT: 1
|
||||||
|
BUILDKIT_STEP_LOG_MAX_SIZE: 10485760
|
||||||
|
# This environment variable will override the VERSION variable in docker-bake.hcl
|
||||||
|
VERSION: ${{ needs.prepare-release.outputs.version_tag }} # Use tag version (vX.Y.Z) for bake
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0 # Needed to check for existing tags
|
||||||
|
|
||||||
|
- name: Check if tag already exists
|
||||||
|
run: |
|
||||||
|
TAG_NAME="${{ needs.prepare-release.outputs.version_tag }}"
|
||||||
|
echo "Checking for existing tag: $TAG_NAME"
|
||||||
|
# Fetch tags explicitly just in case checkout didn't get them all
|
||||||
|
git fetch --tags
|
||||||
|
if git rev-parse "$TAG_NAME" >/dev/null 2>&1; then
|
||||||
|
echo "::error::Tag $TAG_NAME already exists. Please increment the version in the VERSION file."
|
||||||
|
exit 1
|
||||||
|
else
|
||||||
|
echo "Tag $TAG_NAME does not exist. Proceeding with release."
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Free disk space # Optional: Keep as needed for large builds
|
||||||
|
run: |
|
||||||
|
echo "Listing current disk space"
|
||||||
|
df -h
|
||||||
|
echo "Cleaning up disk space..."
|
||||||
|
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache
|
||||||
|
docker system prune -af
|
||||||
|
echo "Disk space after cleanup"
|
||||||
|
df -h
|
||||||
|
|
||||||
|
- name: Set up QEMU
|
||||||
|
uses: docker/setup-qemu-action@v3 # Use v3
|
||||||
|
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3 # Use v3
|
||||||
|
with:
|
||||||
|
driver-opts: |
|
||||||
|
image=moby/buildkit:latest
|
||||||
|
network=host
|
||||||
|
|
||||||
|
- name: Log in to GitHub Container Registry
|
||||||
|
uses: docker/login-action@v3 # Use v3
|
||||||
|
with:
|
||||||
|
registry: ghcr.io
|
||||||
|
username: ${{ github.actor }}
|
||||||
|
password: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|
||||||
|
- name: Build and push images using Docker Bake
|
||||||
|
run: |
|
||||||
|
echo "Building and pushing images for version ${{ needs.prepare-release.outputs.version_tag }}"
|
||||||
|
# The VERSION env var above sets the tag for the bake file targets
|
||||||
|
docker buildx bake --push
|
||||||
|
|
||||||
|
create-release:
|
||||||
|
needs: [prepare-release, build-images]
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
contents: write # Needed to create releases
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0 # Fetch all history for release notes generation
|
||||||
|
|
||||||
|
- name: Create GitHub Release
|
||||||
|
uses: softprops/action-gh-release@v2 # Use v2
|
||||||
|
with:
|
||||||
|
tag_name: ${{ needs.prepare-release.outputs.version_tag }} # Use vX.Y.Z tag
|
||||||
|
name: Release ${{ needs.prepare-release.outputs.version_tag }}
|
||||||
|
generate_release_notes: true # Auto-generate release notes
|
||||||
|
draft: false # Publish immediately
|
||||||
|
prerelease: false # Mark as a stable release
|
||||||
|
env:
|
||||||
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
31
CHANGELOG.md
31
CHANGELOG.md
|
@ -2,6 +2,37 @@
|
||||||
|
|
||||||
Notable changes to this project will be documented in this file.
|
Notable changes to this project will be documented in this file.
|
||||||
|
|
||||||
|
## [v0.3.0] - 2025-04-04
|
||||||
|
### Added
|
||||||
|
- Apple Silicon (MPS) acceleration support for macOS users.
|
||||||
|
- Voice subtraction capability for creating unique voice effects.
|
||||||
|
- Windows PowerShell start scripts (`start-cpu.ps1`, `start-gpu.ps1`).
|
||||||
|
- Automatic model downloading integrated into all start scripts.
|
||||||
|
- Example Helm chart values for Azure AKS and Nvidia GPU Operator deployments.
|
||||||
|
- `CONTRIBUTING.md` guidelines for developers.
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
- Version bump of underlying Kokoro and Misaki libraries
|
||||||
|
- Default API port reverted to 8880.
|
||||||
|
- Docker containers now run as a non-root user for enhanced security.
|
||||||
|
- Improved text normalization for numbers, currency, and time formats.
|
||||||
|
- Updated and improved Helm chart configurations and documentation.
|
||||||
|
- Enhanced temporary file management with better error tracking.
|
||||||
|
- Web UI dependencies (Siriwave) are now served locally.
|
||||||
|
- Standardized environment variable handling across shell/PowerShell scripts.
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
- Corrected an issue preventing download links from being returned when `streaming=false`.
|
||||||
|
- Resolved errors in Windows PowerShell scripts related to virtual environment activation order.
|
||||||
|
- Addressed potential segfaults during inference.
|
||||||
|
- Fixed various Helm chart issues related to health checks, ingress, and default values.
|
||||||
|
- Corrected audio quality degradation caused by incorrect bitrate settings in some cases.
|
||||||
|
- Ensured custom phonemes provided in input text are preserved.
|
||||||
|
- Fixed a 'MediaSource' error affecting playback stability in the web player.
|
||||||
|
|
||||||
|
### Removed
|
||||||
|
- Obsolete GitHub Actions build workflow, build and publish now occurs on merge to `Release` branch
|
||||||
|
|
||||||
## [v0.2.0post1] - 2025-02-07
|
## [v0.2.0post1] - 2025-02-07
|
||||||
- Fix: Building Kokoro from source with adjustments, to avoid CUDA lock
|
- Fix: Building Kokoro from source with adjustments, to avoid CUDA lock
|
||||||
- Fixed ARM64 compatibility on Spacy dep to avoid emulation slowdown
|
- Fixed ARM64 compatibility on Spacy dep to avoid emulation slowdown
|
||||||
|
|
86
CONTRIBUTING.md
Normal file
86
CONTRIBUTING.md
Normal file
|
@ -0,0 +1,86 @@
|
||||||
|
# Contributing to Kokoro-FastAPI
|
||||||
|
|
||||||
|
Always appreciate community involvement in making this project better.
|
||||||
|
|
||||||
|
## Development Setup
|
||||||
|
|
||||||
|
We use `uv` for managing Python environments and dependencies, and `ruff` for linting and formatting.
|
||||||
|
|
||||||
|
1. **Clone the repository:**
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/remsky/Kokoro-FastAPI.git
|
||||||
|
cd Kokoro-FastAPI
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Install `uv`:**
|
||||||
|
Follow the instructions on the [official `uv` documentation](https://docs.astral.sh/uv/install/).
|
||||||
|
|
||||||
|
3. **Create a virtual environment and install dependencies:**
|
||||||
|
It's recommended to use a virtual environment. `uv` can create one for you. Install the base dependencies along with the `test` and `cpu` extras (needed for running tests locally).
|
||||||
|
```bash
|
||||||
|
# Create and activate a virtual environment (e.g., named .venv)
|
||||||
|
uv venv
|
||||||
|
source .venv/bin/activate # On Linux/macOS
|
||||||
|
# .venv\Scripts\activate # On Windows
|
||||||
|
|
||||||
|
# Install dependencies including test requirements
|
||||||
|
uv pip install -e ".[test,cpu]"
|
||||||
|
```
|
||||||
|
*Note: If you have an NVIDIA GPU and want to test GPU-specific features locally, you can install `.[test,gpu]` instead, ensuring you have the correct CUDA toolkit installed.*
|
||||||
|
|
||||||
|
*Note: If running via uv locally, you will have to install espeak and handle any pathing issues that arise. The Docker images handle this automatically*
|
||||||
|
|
||||||
|
4. **Install `ruff` (if not already installed globally):**
|
||||||
|
While `ruff` might be included via dependencies, installing it explicitly ensures you have it available.
|
||||||
|
```bash
|
||||||
|
uv pip install ruff
|
||||||
|
```
|
||||||
|
|
||||||
|
## Running Tests
|
||||||
|
|
||||||
|
Before submitting changes, please ensure all tests pass as this is a automated requirement. The tests are run using `pytest`.
|
||||||
|
```bash
|
||||||
|
# Make sure your virtual environment is activated
|
||||||
|
uv run pytest
|
||||||
|
```
|
||||||
|
*Note: The CI workflow runs tests using `uv run pytest api/tests/ --asyncio-mode=auto --cov=api --cov-report=term-missing`. Running `uv run pytest` locally should cover the essential checks.*
|
||||||
|
|
||||||
|
## Testing with Docker Compose
|
||||||
|
|
||||||
|
In addition to local `pytest` runs, test your changes using Docker Compose to ensure they work correctly within the containerized environment. If you aren't able to test on CUDA hardware, make note so it can be tested by another maintainer
|
||||||
|
|
||||||
|
```bash
|
||||||
|
|
||||||
|
docker compose -f docker/cpu/docker-compose.yml up --build
|
||||||
|
+
|
||||||
|
docker compose -f docker/gpu/docker-compose.yml up --build
|
||||||
|
```
|
||||||
|
This command will build the Docker images (if they've changed) and start the services defined in the respective compose file. Verify the application starts correctly and test the relevant functionality.
|
||||||
|
|
||||||
|
## Code Formatting and Linting
|
||||||
|
|
||||||
|
We use `ruff` to maintain code quality and consistency. Please format and lint your code before committing.
|
||||||
|
|
||||||
|
1. **Format the code:**
|
||||||
|
```bash
|
||||||
|
# Make sure your virtual environment is activated
|
||||||
|
ruff format .
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Lint the code (and apply automatic fixes):**
|
||||||
|
```bash
|
||||||
|
# Make sure your virtual environment is activated
|
||||||
|
ruff check . --fix
|
||||||
|
```
|
||||||
|
Review any changes made by `--fix` and address any remaining linting errors manually.
|
||||||
|
|
||||||
|
## Submitting Changes
|
||||||
|
|
||||||
|
0. Clone the repo
|
||||||
|
1. Create a new branch for your feature or bug fix.
|
||||||
|
2. Make your changes, following setup, testing, and formatting guidelines above.
|
||||||
|
3. Please try to keep your changes inline with the current design, and modular. Large-scale changes will take longer to review and integrate, and have less chance of being approved outright.
|
||||||
|
4. Push your branch to your fork.
|
||||||
|
5. Open a Pull Request against the `master` branch of the main repository.
|
||||||
|
|
||||||
|
Thank you for contributing!
|
37
README.md
37
README.md
|
@ -3,12 +3,12 @@
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
# <sub><sub>_`FastKoko`_ </sub></sub>
|
# <sub><sub>_`FastKoko`_ </sub></sub>
|
||||||
[]()
|
[]()
|
||||||
[]()
|
[]()
|
||||||
[](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero)
|
[](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero)
|
||||||
|
|
||||||
[](https://github.com/hexgrad/kokoro)
|
[](https://github.com/hexgrad/kokoro)
|
||||||
[](https://github.com/hexgrad/misaki)
|
[](https://github.com/hexgrad/misaki)
|
||||||
|
|
||||||
[](https://huggingface.co/hexgrad/Kokoro-82M/commit/9901c2b79161b6e898b7ea857ae5298f47b8b0d6)
|
[](https://huggingface.co/hexgrad/Kokoro-82M/commit/9901c2b79161b6e898b7ea857ae5298f47b8b0d6)
|
||||||
|
|
||||||
|
@ -24,10 +24,6 @@ Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokor
|
||||||
### Integration Guides
|
### Integration Guides
|
||||||
[](https://github.com/remsky/Kokoro-FastAPI/wiki/Setup-Kubernetes) [](https://github.com/remsky/Kokoro-FastAPI/wiki/Integrations-DigitalOcean) [](https://github.com/remsky/Kokoro-FastAPI/wiki/Integrations-SillyTavern)
|
[](https://github.com/remsky/Kokoro-FastAPI/wiki/Setup-Kubernetes) [](https://github.com/remsky/Kokoro-FastAPI/wiki/Integrations-DigitalOcean) [](https://github.com/remsky/Kokoro-FastAPI/wiki/Integrations-SillyTavern)
|
||||||
[](https://github.com/remsky/Kokoro-FastAPI/wiki/Integrations-OpenWebUi)
|
[](https://github.com/remsky/Kokoro-FastAPI/wiki/Integrations-OpenWebUi)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Get Started
|
## Get Started
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
|
@ -38,11 +34,12 @@ Pre built images are available to run, with arm/multi-arch support, and baked in
|
||||||
Refer to the core/config.py file for a full list of variables which can be managed via the environment
|
Refer to the core/config.py file for a full list of variables which can be managed via the environment
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# the `latest` tag can be used, but should not be considered stable as it may include `nightly` branch builds
|
# the `latest` tag can be used, though it may have some unexpected bonus features which impact stability.
|
||||||
# it may have some bonus features however, and feedback/testing is welcome
|
Named versions should be pinned for your regular usage.
|
||||||
|
Feedback/testing is always welcome
|
||||||
|
|
||||||
docker run -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-cpu:v0.2.2 # CPU, or:
|
docker run -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-cpu:latest # CPU, or:
|
||||||
docker run --gpus all -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-gpu:v0.2.2 #NVIDIA GPU
|
docker run --gpus all -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-gpu:latest #NVIDIA GPU
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
@ -66,7 +63,7 @@ docker run --gpus all -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-gpu:v0.2.2 #NV
|
||||||
# *Note for Apple Silicon (M1/M2) users:
|
# *Note for Apple Silicon (M1/M2) users:
|
||||||
# The current GPU build relies on CUDA, which is not supported on Apple Silicon.
|
# The current GPU build relies on CUDA, which is not supported on Apple Silicon.
|
||||||
# If you are on an M1/M2/M3 Mac, please use the `docker/cpu` setup.
|
# If you are on an M1/M2/M3 Mac, please use the `docker/cpu` setup.
|
||||||
# MPS (Apple’s GPU acceleration) support is planned but not yet available.
|
# MPS (Apple's GPU acceleration) support is planned but not yet available.
|
||||||
|
|
||||||
# Models will auto-download, but if needed you can manually download:
|
# Models will auto-download, but if needed you can manually download:
|
||||||
python docker/scripts/download_model.py --output api/src/models/v1_0
|
python docker/scripts/download_model.py --output api/src/models/v1_0
|
||||||
|
@ -139,8 +136,8 @@ with client.audio.speech.with_streaming_response.create(
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary>OpenAI-Compatible Speech Endpoint</summary>
|
<summary>OpenAI-Compatible Speech Endpoint</summary>
|
||||||
|
|
||||||
|
@ -564,13 +561,15 @@ for chunk in response.iter_content(chunk_size=1024):
|
||||||
<details>
|
<details>
|
||||||
<summary>Versioning & Development</summary>
|
<summary>Versioning & Development</summary>
|
||||||
|
|
||||||
I'm doing what I can to keep things stable, but we are on an early and rapid set of build cycles here.
|
**Branching Strategy:**
|
||||||
If you run into trouble, you may have to roll back a version on the release tags if something comes up, or build up from source and/or troubleshoot + submit a PR. Will leave the branch up here for the last known stable points:
|
* **`release` branch:** Contains the latest stable build, recommended for production use. Docker images tagged with specific versions (e.g., `v0.3.0`) are built from this branch.
|
||||||
|
* **`master` branch:** Used for active development. It may contain experimental features, ongoing changes, or fixes not yet in a stable release. Use this branch if you want the absolute latest code, but be aware it might be less stable. The `latest` Docker tag often points to builds from this branch.
|
||||||
|
|
||||||
`v0.1.4`
|
Note: This is a *development* focused project at its core.
|
||||||
`v0.0.5post1`
|
|
||||||
|
|
||||||
Free and open source is a community effort, and I love working on this project, though there's only really so many hours in a day. If you'd like to support the work, feel free to open a PR, buy me a coffee, or report any bugs/features/etc you find during use.
|
If you run into trouble, you may have to roll back a version on the release tags if something comes up, or build up from source and/or troubleshoot + submit a PR.
|
||||||
|
|
||||||
|
Free and open source is a community effort, and there's only really so many hours in a day. If you'd like to support the work, feel free to open a PR, buy me a coffee, or report any bugs/features/etc you find during use.
|
||||||
|
|
||||||
<a href="https://www.buymeacoffee.com/remsky" target="_blank">
|
<a href="https://www.buymeacoffee.com/remsky" target="_blank">
|
||||||
<img
|
<img
|
||||||
|
|
2
VERSION
2
VERSION
|
@ -1 +1 @@
|
||||||
v0.2.1
|
0.3.0
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from pydantic_settings import BaseSettings
|
|
||||||
import torch
|
import torch
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
|
@ -14,9 +14,13 @@ class Settings(BaseSettings):
|
||||||
output_dir: str = "output"
|
output_dir: str = "output"
|
||||||
output_dir_size_limit_mb: float = 500.0 # Maximum size of output directory in MB
|
output_dir_size_limit_mb: float = 500.0 # Maximum size of output directory in MB
|
||||||
default_voice: str = "af_heart"
|
default_voice: str = "af_heart"
|
||||||
default_voice_code: str | None = None # If set, overrides the first letter of voice name, though api call param still takes precedence
|
default_voice_code: str | None = (
|
||||||
|
None # If set, overrides the first letter of voice name, though api call param still takes precedence
|
||||||
|
)
|
||||||
use_gpu: bool = True # Whether to use GPU acceleration if available
|
use_gpu: bool = True # Whether to use GPU acceleration if available
|
||||||
device_type: str | None = None # Will be auto-detected if None, can be "cuda", "mps", or "cpu"
|
device_type: str | None = (
|
||||||
|
None # Will be auto-detected if None, can be "cuda", "mps", or "cpu"
|
||||||
|
)
|
||||||
allow_local_voice_saving: bool = (
|
allow_local_voice_saving: bool = (
|
||||||
False # Whether to allow saving combined voices locally
|
False # Whether to allow saving combined voices locally
|
||||||
)
|
)
|
||||||
|
@ -32,11 +36,20 @@ class Settings(BaseSettings):
|
||||||
target_max_tokens: int = 250 # Target maximum tokens per chunk
|
target_max_tokens: int = 250 # Target maximum tokens per chunk
|
||||||
absolute_max_tokens: int = 450 # Absolute maximum tokens per chunk
|
absolute_max_tokens: int = 450 # Absolute maximum tokens per chunk
|
||||||
advanced_text_normalization: bool = True # Preproesses the text before misiki
|
advanced_text_normalization: bool = True # Preproesses the text before misiki
|
||||||
voice_weight_normalization: bool = True # Normalize the voice weights so they add up to 1
|
voice_weight_normalization: bool = (
|
||||||
|
True # Normalize the voice weights so they add up to 1
|
||||||
|
)
|
||||||
|
|
||||||
gap_trim_ms: int = 1 # Base amount to trim from streaming chunk ends in milliseconds
|
gap_trim_ms: int = (
|
||||||
|
1 # Base amount to trim from streaming chunk ends in milliseconds
|
||||||
|
)
|
||||||
dynamic_gap_trim_padding_ms: int = 410 # Padding to add to dynamic gap trim
|
dynamic_gap_trim_padding_ms: int = 410 # Padding to add to dynamic gap trim
|
||||||
dynamic_gap_trim_padding_char_multiplier: dict[str,float] = {".":1,"!":0.9,"?":1,",":0.8}
|
dynamic_gap_trim_padding_char_multiplier: dict[str, float] = {
|
||||||
|
".": 1,
|
||||||
|
"!": 0.9,
|
||||||
|
"?": 1,
|
||||||
|
",": 0.8,
|
||||||
|
}
|
||||||
|
|
||||||
# Web Player Settings
|
# Web Player Settings
|
||||||
enable_web_player: bool = True # Whether to serve the web player UI
|
enable_web_player: bool = True # Whether to serve the web player UI
|
||||||
|
@ -69,5 +82,4 @@ class Settings(BaseSettings):
|
||||||
return "cpu"
|
return "cpu"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|
|
@ -1,34 +1,41 @@
|
||||||
"""Base interface for Kokoro inference."""
|
"""Base interface for Kokoro inference."""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import AsyncGenerator, Optional, Tuple, Union, List
|
from typing import AsyncGenerator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
class AudioChunk:
|
class AudioChunk:
|
||||||
"""Class for audio chunks returned by model backends"""
|
"""Class for audio chunks returned by model backends"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
audio: np.ndarray,
|
audio: np.ndarray,
|
||||||
word_timestamps: Optional[List]=[],
|
word_timestamps: Optional[List] = [],
|
||||||
output: Optional[Union[bytes,np.ndarray]]=b""
|
output: Optional[Union[bytes, np.ndarray]] = b"",
|
||||||
):
|
):
|
||||||
self.audio=audio
|
self.audio = audio
|
||||||
self.word_timestamps=word_timestamps
|
self.word_timestamps = word_timestamps
|
||||||
self.output=output
|
self.output = output
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def combine(audio_chunk_list: List):
|
def combine(audio_chunk_list: List):
|
||||||
output=AudioChunk(audio_chunk_list[0].audio,audio_chunk_list[0].word_timestamps)
|
output = AudioChunk(
|
||||||
|
audio_chunk_list[0].audio, audio_chunk_list[0].word_timestamps
|
||||||
|
)
|
||||||
|
|
||||||
for audio_chunk in audio_chunk_list[1:]:
|
for audio_chunk in audio_chunk_list[1:]:
|
||||||
output.audio=np.concatenate((output.audio,audio_chunk.audio),dtype=np.int16)
|
output.audio = np.concatenate(
|
||||||
|
(output.audio, audio_chunk.audio), dtype=np.int16
|
||||||
|
)
|
||||||
if output.word_timestamps is not None:
|
if output.word_timestamps is not None:
|
||||||
output.word_timestamps+=audio_chunk.word_timestamps
|
output.word_timestamps += audio_chunk.word_timestamps
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
class ModelBackend(ABC):
|
class ModelBackend(ABC):
|
||||||
"""Abstract base class for model inference backend."""
|
"""Abstract base class for model inference backend."""
|
||||||
|
|
||||||
|
|
|
@ -11,9 +11,10 @@ from loguru import logger
|
||||||
from ..core import paths
|
from ..core import paths
|
||||||
from ..core.config import settings
|
from ..core.config import settings
|
||||||
from ..core.model_config import model_config
|
from ..core.model_config import model_config
|
||||||
from .base import BaseModelBackend
|
|
||||||
from .base import AudioChunk
|
|
||||||
from ..structures.schemas import WordTimestamp
|
from ..structures.schemas import WordTimestamp
|
||||||
|
from .base import AudioChunk, BaseModelBackend
|
||||||
|
|
||||||
|
|
||||||
class KokoroV1(BaseModelBackend):
|
class KokoroV1(BaseModelBackend):
|
||||||
"""Kokoro backend with controlled resource management."""
|
"""Kokoro backend with controlled resource management."""
|
||||||
|
|
||||||
|
@ -50,7 +51,9 @@ class KokoroV1(BaseModelBackend):
|
||||||
self._model = KModel(config=config_path, model=model_path).eval()
|
self._model = KModel(config=config_path, model=model_path).eval()
|
||||||
# For MPS, manually move ISTFT layers to CPU while keeping rest on MPS
|
# For MPS, manually move ISTFT layers to CPU while keeping rest on MPS
|
||||||
if self._device == "mps":
|
if self._device == "mps":
|
||||||
logger.info("Moving model to MPS device with CPU fallback for unsupported operations")
|
logger.info(
|
||||||
|
"Moving model to MPS device with CPU fallback for unsupported operations"
|
||||||
|
)
|
||||||
self._model = self._model.to(torch.device("mps"))
|
self._model = self._model.to(torch.device("mps"))
|
||||||
elif self._device == "cuda":
|
elif self._device == "cuda":
|
||||||
self._model = self._model.cuda()
|
self._model = self._model.cuda()
|
||||||
|
@ -244,7 +247,15 @@ class KokoroV1(BaseModelBackend):
|
||||||
voice_path = temp_path
|
voice_path = temp_path
|
||||||
|
|
||||||
# Use provided lang_code, settings voice code override, or first letter of voice name
|
# Use provided lang_code, settings voice code override, or first letter of voice name
|
||||||
pipeline_lang_code = lang_code if lang_code else (settings.default_voice_code if settings.default_voice_code else voice_name[0].lower())
|
pipeline_lang_code = (
|
||||||
|
lang_code
|
||||||
|
if lang_code
|
||||||
|
else (
|
||||||
|
settings.default_voice_code
|
||||||
|
if settings.default_voice_code
|
||||||
|
else voice_name[0].lower()
|
||||||
|
)
|
||||||
|
)
|
||||||
pipeline = self._get_pipeline(pipeline_lang_code)
|
pipeline = self._get_pipeline(pipeline_lang_code)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
@ -255,16 +266,19 @@ class KokoroV1(BaseModelBackend):
|
||||||
):
|
):
|
||||||
if result.audio is not None:
|
if result.audio is not None:
|
||||||
logger.debug(f"Got audio chunk with shape: {result.audio.shape}")
|
logger.debug(f"Got audio chunk with shape: {result.audio.shape}")
|
||||||
word_timestamps=None
|
word_timestamps = None
|
||||||
if return_timestamps and hasattr(result, "tokens") and result.tokens:
|
if (
|
||||||
word_timestamps=[]
|
return_timestamps
|
||||||
current_offset=0.0
|
and hasattr(result, "tokens")
|
||||||
|
and result.tokens
|
||||||
|
):
|
||||||
|
word_timestamps = []
|
||||||
|
current_offset = 0.0
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Processing chunk timestamps with {len(result.tokens)} tokens"
|
f"Processing chunk timestamps with {len(result.tokens)} tokens"
|
||||||
)
|
)
|
||||||
if result.pred_dur is not None:
|
if result.pred_dur is not None:
|
||||||
try:
|
try:
|
||||||
|
|
||||||
# Add timestamps with offset
|
# Add timestamps with offset
|
||||||
for token in result.tokens:
|
for token in result.tokens:
|
||||||
if not all(
|
if not all(
|
||||||
|
@ -285,7 +299,7 @@ class KokoroV1(BaseModelBackend):
|
||||||
WordTimestamp(
|
WordTimestamp(
|
||||||
word=str(token.text).strip(),
|
word=str(token.text).strip(),
|
||||||
start_time=start_time,
|
start_time=start_time,
|
||||||
end_time=end_time
|
end_time=end_time,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
@ -297,8 +311,9 @@ class KokoroV1(BaseModelBackend):
|
||||||
f"Failed to process timestamps for chunk: {e}"
|
f"Failed to process timestamps for chunk: {e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
yield AudioChunk(
|
||||||
yield AudioChunk(result.audio.numpy(),word_timestamps=word_timestamps)
|
result.audio.numpy(), word_timestamps=word_timestamps
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning("No audio in chunk")
|
logger.warning("No audio in chunk")
|
||||||
|
|
||||||
|
@ -329,7 +344,7 @@ class KokoroV1(BaseModelBackend):
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
elif self._device == "mps":
|
elif self._device == "mps":
|
||||||
# Empty cache if available (future-proofing)
|
# Empty cache if available (future-proofing)
|
||||||
if hasattr(torch.mps, 'empty_cache'):
|
if hasattr(torch.mps, "empty_cache"):
|
||||||
torch.mps.empty_cache()
|
torch.mps.empty_cache()
|
||||||
|
|
||||||
def unload(self) -> None:
|
def unload(self) -> None:
|
||||||
|
|
|
@ -3,8 +3,8 @@ import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
from fastapi import APIRouter
|
|
||||||
import torch
|
import torch
|
||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import GPUtil
|
import GPUtil
|
||||||
|
@ -119,7 +119,7 @@ async def get_system_info():
|
||||||
"type": "MPS",
|
"type": "MPS",
|
||||||
"available": True,
|
"available": True,
|
||||||
"device": "Apple Silicon",
|
"device": "Apple Silicon",
|
||||||
"backend": "Metal"
|
"backend": "Metal",
|
||||||
}
|
}
|
||||||
elif GPU_AVAILABLE:
|
elif GPU_AVAILABLE:
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -156,6 +156,7 @@ async def generate_from_phonemes(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/dev/captioned_speech")
|
@router.post("/dev/captioned_speech")
|
||||||
async def create_captioned_speech(
|
async def create_captioned_speech(
|
||||||
request: CaptionedSpeechRequest,
|
request: CaptionedSpeechRequest,
|
||||||
|
@ -184,7 +185,9 @@ async def create_captioned_speech(
|
||||||
# Check if streaming is requested (default for OpenAI client)
|
# Check if streaming is requested (default for OpenAI client)
|
||||||
if request.stream:
|
if request.stream:
|
||||||
# Create generator but don't start it yet
|
# Create generator but don't start it yet
|
||||||
generator = stream_audio_chunks(tts_service, request, client_request, writer)
|
generator = stream_audio_chunks(
|
||||||
|
tts_service, request, client_request, writer
|
||||||
|
)
|
||||||
|
|
||||||
# If download link requested, wrap generator with temp file writer
|
# If download link requested, wrap generator with temp file writer
|
||||||
if request.return_download_link:
|
if request.return_download_link:
|
||||||
|
@ -211,20 +214,31 @@ async def create_captioned_speech(
|
||||||
# Write chunks to temp file and stream
|
# Write chunks to temp file and stream
|
||||||
async for chunk_data in generator:
|
async for chunk_data in generator:
|
||||||
# The timestamp acumulator is only used when word level time stamps are generated but no audio is returned.
|
# The timestamp acumulator is only used when word level time stamps are generated but no audio is returned.
|
||||||
timestamp_acumulator=[]
|
timestamp_acumulator = []
|
||||||
|
|
||||||
if chunk_data.output: # Skip empty chunks
|
if chunk_data.output: # Skip empty chunks
|
||||||
await temp_writer.write(chunk_data.output)
|
await temp_writer.write(chunk_data.output)
|
||||||
base64_chunk= base64.b64encode(chunk_data.output).decode("utf-8")
|
base64_chunk = base64.b64encode(
|
||||||
|
chunk_data.output
|
||||||
|
).decode("utf-8")
|
||||||
|
|
||||||
# Add any chunks that may be in the acumulator into the return word_timestamps
|
# Add any chunks that may be in the acumulator into the return word_timestamps
|
||||||
chunk_data.word_timestamps=timestamp_acumulator + chunk_data.word_timestamps
|
chunk_data.word_timestamps = (
|
||||||
timestamp_acumulator=[]
|
timestamp_acumulator + chunk_data.word_timestamps
|
||||||
|
)
|
||||||
|
timestamp_acumulator = []
|
||||||
|
|
||||||
yield CaptionedSpeechResponse(audio=base64_chunk,audio_format=content_type,timestamps=chunk_data.word_timestamps)
|
yield CaptionedSpeechResponse(
|
||||||
|
audio=base64_chunk,
|
||||||
|
audio_format=content_type,
|
||||||
|
timestamps=chunk_data.word_timestamps,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if chunk_data.word_timestamps is not None and len(chunk_data.word_timestamps) > 0:
|
if (
|
||||||
timestamp_acumulator+=chunk_data.word_timestamps
|
chunk_data.word_timestamps is not None
|
||||||
|
and len(chunk_data.word_timestamps) > 0
|
||||||
|
):
|
||||||
|
timestamp_acumulator += chunk_data.word_timestamps
|
||||||
|
|
||||||
# Finalize the temp file
|
# Finalize the temp file
|
||||||
await temp_writer.finalize()
|
await temp_writer.finalize()
|
||||||
|
@ -246,25 +260,36 @@ async def create_captioned_speech(
|
||||||
async def single_output():
|
async def single_output():
|
||||||
try:
|
try:
|
||||||
# The timestamp acumulator is only used when word level time stamps are generated but no audio is returned.
|
# The timestamp acumulator is only used when word level time stamps are generated but no audio is returned.
|
||||||
timestamp_acumulator=[]
|
timestamp_acumulator = []
|
||||||
|
|
||||||
# Stream chunks
|
# Stream chunks
|
||||||
async for chunk_data in generator:
|
async for chunk_data in generator:
|
||||||
if chunk_data.output: # Skip empty chunks
|
if chunk_data.output: # Skip empty chunks
|
||||||
# Encode the chunk bytes into base 64
|
# Encode the chunk bytes into base 64
|
||||||
base64_chunk= base64.b64encode(chunk_data.output).decode("utf-8")
|
base64_chunk = base64.b64encode(chunk_data.output).decode(
|
||||||
|
"utf-8"
|
||||||
|
)
|
||||||
|
|
||||||
# Add any chunks that may be in the acumulator into the return word_timestamps
|
# Add any chunks that may be in the acumulator into the return word_timestamps
|
||||||
if chunk_data.word_timestamps != None:
|
if chunk_data.word_timestamps != None:
|
||||||
chunk_data.word_timestamps = timestamp_acumulator + chunk_data.word_timestamps
|
chunk_data.word_timestamps = (
|
||||||
|
timestamp_acumulator + chunk_data.word_timestamps
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
chunk_data.word_timestamps = []
|
chunk_data.word_timestamps = []
|
||||||
timestamp_acumulator=[]
|
timestamp_acumulator = []
|
||||||
|
|
||||||
yield CaptionedSpeechResponse(audio=base64_chunk,audio_format=content_type,timestamps=chunk_data.word_timestamps)
|
yield CaptionedSpeechResponse(
|
||||||
|
audio=base64_chunk,
|
||||||
|
audio_format=content_type,
|
||||||
|
timestamps=chunk_data.word_timestamps,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
if chunk_data.word_timestamps is not None and len(chunk_data.word_timestamps) > 0:
|
if (
|
||||||
timestamp_acumulator+=chunk_data.word_timestamps
|
chunk_data.word_timestamps is not None
|
||||||
|
and len(chunk_data.word_timestamps) > 0
|
||||||
|
):
|
||||||
|
timestamp_acumulator += chunk_data.word_timestamps
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in single output streaming: {e}")
|
logger.error(f"Error in single output streaming: {e}")
|
||||||
|
@ -309,11 +334,15 @@ async def create_captioned_speech(
|
||||||
writer,
|
writer,
|
||||||
is_last_chunk=True,
|
is_last_chunk=True,
|
||||||
)
|
)
|
||||||
output=audio_data.output + final.output
|
output = audio_data.output + final.output
|
||||||
|
|
||||||
base64_output= base64.b64encode(output).decode("utf-8")
|
base64_output = base64.b64encode(output).decode("utf-8")
|
||||||
|
|
||||||
content=CaptionedSpeechResponse(audio=base64_output,audio_format=content_type,timestamps=audio_data.word_timestamps).model_dump()
|
content = CaptionedSpeechResponse(
|
||||||
|
audio=base64_output,
|
||||||
|
audio_format=content_type,
|
||||||
|
timestamps=audio_data.word_timestamps,
|
||||||
|
).model_dump()
|
||||||
|
|
||||||
writer.close()
|
writer.close()
|
||||||
|
|
||||||
|
|
|
@ -10,18 +10,18 @@ from urllib import response
|
||||||
|
|
||||||
import aiofiles
|
import aiofiles
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from ..services.streaming_audio_writer import StreamingAudioWriter
|
|
||||||
import torch
|
import torch
|
||||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
|
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
|
||||||
from fastapi.responses import FileResponse, StreamingResponse
|
from fastapi.responses import FileResponse, StreamingResponse
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from ..structures.schemas import CaptionedSpeechRequest
|
|
||||||
|
|
||||||
from ..core.config import settings
|
from ..core.config import settings
|
||||||
from ..inference.base import AudioChunk
|
from ..inference.base import AudioChunk
|
||||||
from ..services.audio import AudioService
|
from ..services.audio import AudioService
|
||||||
|
from ..services.streaming_audio_writer import StreamingAudioWriter
|
||||||
from ..services.tts_service import TTSService
|
from ..services.tts_service import TTSService
|
||||||
from ..structures import OpenAISpeechRequest
|
from ..structures import OpenAISpeechRequest
|
||||||
|
from ..structures.schemas import CaptionedSpeechRequest
|
||||||
|
|
||||||
|
|
||||||
# Load OpenAI mappings
|
# Load OpenAI mappings
|
||||||
|
@ -80,7 +80,9 @@ def get_model_name(model: str) -> str:
|
||||||
return base_name + ".pth"
|
return base_name + ".pth"
|
||||||
|
|
||||||
|
|
||||||
async def process_and_validate_voices(voice_input: Union[str, List[str]], tts_service: TTSService) -> str:
|
async def process_and_validate_voices(
|
||||||
|
voice_input: Union[str, List[str]], tts_service: TTSService
|
||||||
|
) -> str:
|
||||||
"""Process voice input, handling both string and list formats
|
"""Process voice input, handling both string and list formats
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -107,22 +109,35 @@ async def process_and_validate_voices(voice_input: Union[str, List[str]], tts_se
|
||||||
mapped_voice = list(map(str.strip, mapped_voice))
|
mapped_voice = list(map(str.strip, mapped_voice))
|
||||||
|
|
||||||
if len(mapped_voice) > 2:
|
if len(mapped_voice) > 2:
|
||||||
raise ValueError(f"Voice '{voices[voice_index]}' contains too many weight items")
|
raise ValueError(
|
||||||
|
f"Voice '{voices[voice_index]}' contains too many weight items"
|
||||||
|
)
|
||||||
|
|
||||||
if mapped_voice.count(")") > 1:
|
if mapped_voice.count(")") > 1:
|
||||||
raise ValueError(f"Voice '{voices[voice_index]}' contains too many weight items")
|
raise ValueError(
|
||||||
|
f"Voice '{voices[voice_index]}' contains too many weight items"
|
||||||
|
)
|
||||||
|
|
||||||
mapped_voice[0] = _openai_mappings["voices"].get(mapped_voice[0], mapped_voice[0])
|
mapped_voice[0] = _openai_mappings["voices"].get(
|
||||||
|
mapped_voice[0], mapped_voice[0]
|
||||||
|
)
|
||||||
|
|
||||||
if mapped_voice[0] not in available_voices:
|
if mapped_voice[0] not in available_voices:
|
||||||
raise ValueError(f"Voice '{mapped_voice[0]}' not found. Available voices: {', '.join(sorted(available_voices))}")
|
raise ValueError(
|
||||||
|
f"Voice '{mapped_voice[0]}' not found. Available voices: {', '.join(sorted(available_voices))}"
|
||||||
|
)
|
||||||
|
|
||||||
voices[voice_index] = "(".join(mapped_voice)
|
voices[voice_index] = "(".join(mapped_voice)
|
||||||
|
|
||||||
return "".join(voices)
|
return "".join(voices)
|
||||||
|
|
||||||
|
|
||||||
async def stream_audio_chunks(tts_service: TTSService, request: Union[OpenAISpeechRequest, CaptionedSpeechRequest], client_request: Request, writer: StreamingAudioWriter) -> AsyncGenerator[AudioChunk, None]:
|
async def stream_audio_chunks(
|
||||||
|
tts_service: TTSService,
|
||||||
|
request: Union[OpenAISpeechRequest, CaptionedSpeechRequest],
|
||||||
|
client_request: Request,
|
||||||
|
writer: StreamingAudioWriter,
|
||||||
|
) -> AsyncGenerator[AudioChunk, None]:
|
||||||
"""Stream audio chunks as they're generated with client disconnect handling"""
|
"""Stream audio chunks as they're generated with client disconnect handling"""
|
||||||
voice_name = await process_and_validate_voices(request.voice, tts_service)
|
voice_name = await process_and_validate_voices(request.voice, tts_service)
|
||||||
unique_properties = {"return_timestamps": False}
|
unique_properties = {"return_timestamps": False}
|
||||||
|
@ -193,7 +208,9 @@ async def create_speech(
|
||||||
# Check if streaming is requested (default for OpenAI client)
|
# Check if streaming is requested (default for OpenAI client)
|
||||||
if request.stream:
|
if request.stream:
|
||||||
# Create generator but don't start it yet
|
# Create generator but don't start it yet
|
||||||
generator = stream_audio_chunks(tts_service, request, client_request, writer)
|
generator = stream_audio_chunks(
|
||||||
|
tts_service, request, client_request, writer
|
||||||
|
)
|
||||||
|
|
||||||
# If download link requested, wrap generator with temp file writer
|
# If download link requested, wrap generator with temp file writer
|
||||||
if request.return_download_link:
|
if request.return_download_link:
|
||||||
|
@ -245,7 +262,9 @@ async def create_speech(
|
||||||
writer.close()
|
writer.close()
|
||||||
|
|
||||||
# Stream with temp file writing
|
# Stream with temp file writing
|
||||||
return StreamingResponse(dual_output(), media_type=content_type, headers=headers)
|
return StreamingResponse(
|
||||||
|
dual_output(), media_type=content_type, headers=headers
|
||||||
|
)
|
||||||
|
|
||||||
async def single_output():
|
async def single_output():
|
||||||
try:
|
try:
|
||||||
|
@ -285,7 +304,13 @@ async def create_speech(
|
||||||
lang_code=request.lang_code,
|
lang_code=request.lang_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
audio_data = await AudioService.convert_audio(audio_data, request.response_format, writer, is_last_chunk=False, trim_audio=False)
|
audio_data = await AudioService.convert_audio(
|
||||||
|
audio_data,
|
||||||
|
request.response_format,
|
||||||
|
writer,
|
||||||
|
is_last_chunk=False,
|
||||||
|
trim_audio=False,
|
||||||
|
)
|
||||||
|
|
||||||
# Convert to requested format with proper finalization
|
# Convert to requested format with proper finalization
|
||||||
final = await AudioService.convert_audio(
|
final = await AudioService.convert_audio(
|
||||||
|
@ -384,7 +409,6 @@ async def create_speech(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/download/{filename}")
|
@router.get("/download/{filename}")
|
||||||
async def download_audio_file(filename: str):
|
async def download_audio_file(filename: str):
|
||||||
"""Download a generated audio file from temp storage"""
|
"""Download a generated audio file from temp storage"""
|
||||||
|
@ -392,7 +416,9 @@ async def download_audio_file(filename: str):
|
||||||
from ..core.paths import _find_file, get_content_type
|
from ..core.paths import _find_file, get_content_type
|
||||||
|
|
||||||
# Search for file in temp directory
|
# Search for file in temp directory
|
||||||
file_path = await _find_file(filename=filename, search_paths=[settings.temp_file_dir])
|
file_path = await _find_file(
|
||||||
|
filename=filename, search_paths=[settings.temp_file_dir]
|
||||||
|
)
|
||||||
|
|
||||||
# Get content type from path helper
|
# Get content type from path helper
|
||||||
content_type = await get_content_type(file_path)
|
content_type = await get_content_type(file_path)
|
||||||
|
@ -425,9 +451,24 @@ async def list_models():
|
||||||
try:
|
try:
|
||||||
# Create standard model list
|
# Create standard model list
|
||||||
models = [
|
models = [
|
||||||
{"id": "tts-1", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
|
{
|
||||||
{"id": "tts-1-hd", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
|
"id": "tts-1",
|
||||||
{"id": "kokoro", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
|
"object": "model",
|
||||||
|
"created": 1686935002,
|
||||||
|
"owned_by": "kokoro",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "tts-1-hd",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1686935002,
|
||||||
|
"owned_by": "kokoro",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "kokoro",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1686935002,
|
||||||
|
"owned_by": "kokoro",
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
return {"object": "list", "data": models}
|
return {"object": "list", "data": models}
|
||||||
|
@ -449,14 +490,36 @@ async def retrieve_model(model: str):
|
||||||
try:
|
try:
|
||||||
# Define available models
|
# Define available models
|
||||||
models = {
|
models = {
|
||||||
"tts-1": {"id": "tts-1", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
|
"tts-1": {
|
||||||
"tts-1-hd": {"id": "tts-1-hd", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
|
"id": "tts-1",
|
||||||
"kokoro": {"id": "kokoro", "object": "model", "created": 1686935002, "owned_by": "kokoro"},
|
"object": "model",
|
||||||
|
"created": 1686935002,
|
||||||
|
"owned_by": "kokoro",
|
||||||
|
},
|
||||||
|
"tts-1-hd": {
|
||||||
|
"id": "tts-1-hd",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1686935002,
|
||||||
|
"owned_by": "kokoro",
|
||||||
|
},
|
||||||
|
"kokoro": {
|
||||||
|
"id": "kokoro",
|
||||||
|
"object": "model",
|
||||||
|
"created": 1686935002,
|
||||||
|
"owned_by": "kokoro",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Check if requested model exists
|
# Check if requested model exists
|
||||||
if model not in models:
|
if model not in models:
|
||||||
raise HTTPException(status_code=404, detail={"error": "model_not_found", "message": f"Model '{model}' not found", "type": "invalid_request_error"})
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail={
|
||||||
|
"error": "model_not_found",
|
||||||
|
"message": f"Model '{model}' not found",
|
||||||
|
"type": "invalid_request_error",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# Return the specific model
|
# Return the specific model
|
||||||
return models[model]
|
return models[model]
|
||||||
|
@ -541,7 +604,9 @@ async def combine_voices(request: Union[str, List[str]]):
|
||||||
available_voices = await tts_service.list_voices()
|
available_voices = await tts_service.list_voices()
|
||||||
for voice in voices:
|
for voice in voices:
|
||||||
if voice not in available_voices:
|
if voice not in available_voices:
|
||||||
raise ValueError(f"Base voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}")
|
raise ValueError(
|
||||||
|
f"Base voice '{voice}' not found. Available voices: {', '.join(sorted(available_voices))}"
|
||||||
|
)
|
||||||
|
|
||||||
# Combine voices
|
# Combine voices
|
||||||
combined_tensor = await tts_service.combine_voices(voices=voices)
|
combined_tensor = await tts_service.combine_voices(voices=voices)
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
"""Audio conversion service"""
|
"""Audio conversion service"""
|
||||||
|
|
||||||
|
import math
|
||||||
import struct
|
import struct
|
||||||
import time
|
import time
|
||||||
from typing import Tuple
|
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import math
|
|
||||||
import scipy.io.wavfile as wavfile
|
import scipy.io.wavfile as wavfile
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
@ -14,8 +14,9 @@ from pydub import AudioSegment
|
||||||
from torch import norm
|
from torch import norm
|
||||||
|
|
||||||
from ..core.config import settings
|
from ..core.config import settings
|
||||||
from .streaming_audio_writer import StreamingAudioWriter
|
|
||||||
from ..inference.base import AudioChunk
|
from ..inference.base import AudioChunk
|
||||||
|
from .streaming_audio_writer import StreamingAudioWriter
|
||||||
|
|
||||||
|
|
||||||
class AudioNormalizer:
|
class AudioNormalizer:
|
||||||
"""Handles audio normalization state for a single stream"""
|
"""Handles audio normalization state for a single stream"""
|
||||||
|
@ -24,9 +25,16 @@ class AudioNormalizer:
|
||||||
self.chunk_trim_ms = settings.gap_trim_ms
|
self.chunk_trim_ms = settings.gap_trim_ms
|
||||||
self.sample_rate = 24000 # Sample rate of the audio
|
self.sample_rate = 24000 # Sample rate of the audio
|
||||||
self.samples_to_trim = int(self.chunk_trim_ms * self.sample_rate / 1000)
|
self.samples_to_trim = int(self.chunk_trim_ms * self.sample_rate / 1000)
|
||||||
self.samples_to_pad_start= int(50 * self.sample_rate / 1000)
|
self.samples_to_pad_start = int(50 * self.sample_rate / 1000)
|
||||||
|
|
||||||
def find_first_last_non_silent(self,audio_data: np.ndarray, chunk_text: str, speed: float, silence_threshold_db: int = -45, is_last_chunk: bool = False) -> tuple[int, int]:
|
def find_first_last_non_silent(
|
||||||
|
self,
|
||||||
|
audio_data: np.ndarray,
|
||||||
|
chunk_text: str,
|
||||||
|
speed: float,
|
||||||
|
silence_threshold_db: int = -45,
|
||||||
|
is_last_chunk: bool = False,
|
||||||
|
) -> tuple[int, int]:
|
||||||
"""Finds the indices of the first and last non-silent samples in audio data.
|
"""Finds the indices of the first and last non-silent samples in audio data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -40,37 +48,55 @@ class AudioNormalizer:
|
||||||
A tuple with the start of the non silent portion and with the end of the non silent portion
|
A tuple with the start of the non silent portion and with the end of the non silent portion
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pad_multiplier=1
|
pad_multiplier = 1
|
||||||
split_character=chunk_text.strip()
|
split_character = chunk_text.strip()
|
||||||
if len(split_character) > 0:
|
if len(split_character) > 0:
|
||||||
split_character=split_character[-1]
|
split_character = split_character[-1]
|
||||||
if split_character in settings.dynamic_gap_trim_padding_char_multiplier:
|
if split_character in settings.dynamic_gap_trim_padding_char_multiplier:
|
||||||
pad_multiplier=settings.dynamic_gap_trim_padding_char_multiplier[split_character]
|
pad_multiplier = settings.dynamic_gap_trim_padding_char_multiplier[
|
||||||
|
split_character
|
||||||
|
]
|
||||||
|
|
||||||
if not is_last_chunk:
|
if not is_last_chunk:
|
||||||
samples_to_pad_end= max(int((settings.dynamic_gap_trim_padding_ms * self.sample_rate * pad_multiplier) / 1000) - self.samples_to_pad_start, 0)
|
samples_to_pad_end = max(
|
||||||
|
int(
|
||||||
|
(
|
||||||
|
settings.dynamic_gap_trim_padding_ms
|
||||||
|
* self.sample_rate
|
||||||
|
* pad_multiplier
|
||||||
|
)
|
||||||
|
/ 1000
|
||||||
|
)
|
||||||
|
- self.samples_to_pad_start,
|
||||||
|
0,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
samples_to_pad_end=self.samples_to_pad_start
|
samples_to_pad_end = self.samples_to_pad_start
|
||||||
# Convert dBFS threshold to amplitude
|
# Convert dBFS threshold to amplitude
|
||||||
amplitude_threshold = np.iinfo(audio_data.dtype).max * (10 ** (silence_threshold_db / 20))
|
amplitude_threshold = np.iinfo(audio_data.dtype).max * (
|
||||||
|
10 ** (silence_threshold_db / 20)
|
||||||
|
)
|
||||||
# Find the first samples above the silence threshold at the start and end of the audio
|
# Find the first samples above the silence threshold at the start and end of the audio
|
||||||
non_silent_index_start, non_silent_index_end = None,None
|
non_silent_index_start, non_silent_index_end = None, None
|
||||||
|
|
||||||
for X in range(0,len(audio_data)):
|
for X in range(0, len(audio_data)):
|
||||||
if audio_data[X] > amplitude_threshold:
|
if audio_data[X] > amplitude_threshold:
|
||||||
non_silent_index_start=X
|
non_silent_index_start = X
|
||||||
break
|
break
|
||||||
|
|
||||||
for X in range(len(audio_data) - 1, -1, -1):
|
for X in range(len(audio_data) - 1, -1, -1):
|
||||||
if audio_data[X] > amplitude_threshold:
|
if audio_data[X] > amplitude_threshold:
|
||||||
non_silent_index_end=X
|
non_silent_index_end = X
|
||||||
break
|
break
|
||||||
|
|
||||||
# Handle the case where the entire audio is silent
|
# Handle the case where the entire audio is silent
|
||||||
if non_silent_index_start == None or non_silent_index_end == None:
|
if non_silent_index_start == None or non_silent_index_end == None:
|
||||||
return 0, len(audio_data)
|
return 0, len(audio_data)
|
||||||
|
|
||||||
return max(non_silent_index_start - self.samples_to_pad_start,0), min(non_silent_index_end + math.ceil(samples_to_pad_end / speed),len(audio_data))
|
return max(non_silent_index_start - self.samples_to_pad_start, 0), min(
|
||||||
|
non_silent_index_end + math.ceil(samples_to_pad_end / speed),
|
||||||
|
len(audio_data),
|
||||||
|
)
|
||||||
|
|
||||||
def normalize(self, audio_data: np.ndarray) -> np.ndarray:
|
def normalize(self, audio_data: np.ndarray) -> np.ndarray:
|
||||||
"""Convert audio data to int16 range
|
"""Convert audio data to int16 range
|
||||||
|
@ -85,6 +111,7 @@ class AudioNormalizer:
|
||||||
return np.clip(audio_data * 32767, -32768, 32767).astype(np.int16)
|
return np.clip(audio_data * 32767, -32768, 32767).astype(np.int16)
|
||||||
return audio_data
|
return audio_data
|
||||||
|
|
||||||
|
|
||||||
class AudioService:
|
class AudioService:
|
||||||
"""Service for audio format conversions with streaming support"""
|
"""Service for audio format conversions with streaming support"""
|
||||||
|
|
||||||
|
@ -155,18 +182,16 @@ class AudioService:
|
||||||
if len(audio_chunk.audio) > 0:
|
if len(audio_chunk.audio) > 0:
|
||||||
chunk_data = writer.write_chunk(audio_chunk.audio)
|
chunk_data = writer.write_chunk(audio_chunk.audio)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Then finalize if this is the last chunk
|
# Then finalize if this is the last chunk
|
||||||
if is_last_chunk:
|
if is_last_chunk:
|
||||||
final_data = writer.write_chunk(finalize=True)
|
final_data = writer.write_chunk(finalize=True)
|
||||||
|
|
||||||
if final_data:
|
if final_data:
|
||||||
audio_chunk.output=final_data
|
audio_chunk.output = final_data
|
||||||
return audio_chunk
|
return audio_chunk
|
||||||
|
|
||||||
if chunk_data:
|
if chunk_data:
|
||||||
audio_chunk.output=chunk_data
|
audio_chunk.output = chunk_data
|
||||||
return audio_chunk
|
return audio_chunk
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -174,8 +199,15 @@ class AudioService:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Failed to convert audio stream to {output_format}: {str(e)}"
|
f"Failed to convert audio stream to {output_format}: {str(e)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def trim_audio(audio_chunk: AudioChunk, chunk_text: str = "", speed: float = 1, is_last_chunk: bool = False, normalizer: AudioNormalizer = None) -> AudioChunk:
|
def trim_audio(
|
||||||
|
audio_chunk: AudioChunk,
|
||||||
|
chunk_text: str = "",
|
||||||
|
speed: float = 1,
|
||||||
|
is_last_chunk: bool = False,
|
||||||
|
normalizer: AudioNormalizer = None,
|
||||||
|
) -> AudioChunk:
|
||||||
"""Trim silence from start and end
|
"""Trim silence from start and end
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -191,23 +223,26 @@ class AudioService:
|
||||||
if normalizer is None:
|
if normalizer is None:
|
||||||
normalizer = AudioNormalizer()
|
normalizer = AudioNormalizer()
|
||||||
|
|
||||||
audio_chunk.audio=normalizer.normalize(audio_chunk.audio)
|
audio_chunk.audio = normalizer.normalize(audio_chunk.audio)
|
||||||
|
|
||||||
trimed_samples=0
|
trimed_samples = 0
|
||||||
# Trim start and end if enough samples
|
# Trim start and end if enough samples
|
||||||
if len(audio_chunk.audio) > (2 * normalizer.samples_to_trim):
|
if len(audio_chunk.audio) > (2 * normalizer.samples_to_trim):
|
||||||
audio_chunk.audio = audio_chunk.audio[normalizer.samples_to_trim : -normalizer.samples_to_trim]
|
audio_chunk.audio = audio_chunk.audio[
|
||||||
trimed_samples+=normalizer.samples_to_trim
|
normalizer.samples_to_trim : -normalizer.samples_to_trim
|
||||||
|
]
|
||||||
|
trimed_samples += normalizer.samples_to_trim
|
||||||
|
|
||||||
# Find non silent portion and trim
|
# Find non silent portion and trim
|
||||||
start_index,end_index=normalizer.find_first_last_non_silent(audio_chunk.audio,chunk_text,speed,is_last_chunk=is_last_chunk)
|
start_index, end_index = normalizer.find_first_last_non_silent(
|
||||||
|
audio_chunk.audio, chunk_text, speed, is_last_chunk=is_last_chunk
|
||||||
|
)
|
||||||
|
|
||||||
audio_chunk.audio=audio_chunk.audio[start_index:end_index]
|
audio_chunk.audio = audio_chunk.audio[start_index:end_index]
|
||||||
trimed_samples+=start_index
|
trimed_samples += start_index
|
||||||
|
|
||||||
if audio_chunk.word_timestamps is not None:
|
if audio_chunk.word_timestamps is not None:
|
||||||
for timestamp in audio_chunk.word_timestamps:
|
for timestamp in audio_chunk.word_timestamps:
|
||||||
timestamp.start_time-=trimed_samples / 24000
|
timestamp.start_time -= trimed_samples / 24000
|
||||||
timestamp.end_time-=trimed_samples / 24000
|
timestamp.end_time -= trimed_samples / 24000
|
||||||
return audio_chunk
|
return audio_chunk
|
||||||
|
|
|
@ -4,11 +4,12 @@ import struct
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import av
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pydub import AudioSegment
|
from pydub import AudioSegment
|
||||||
import av
|
|
||||||
|
|
||||||
class StreamingAudioWriter:
|
class StreamingAudioWriter:
|
||||||
"""Handles streaming audio format conversions"""
|
"""Handles streaming audio format conversions"""
|
||||||
|
@ -18,15 +19,29 @@ class StreamingAudioWriter:
|
||||||
self.sample_rate = sample_rate
|
self.sample_rate = sample_rate
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.bytes_written = 0
|
self.bytes_written = 0
|
||||||
self.pts=0
|
self.pts = 0
|
||||||
|
|
||||||
codec_map = {"wav":"pcm_s16le","mp3":"mp3","opus":"libopus","flac":"flac", "aac":"aac"}
|
codec_map = {
|
||||||
|
"wav": "pcm_s16le",
|
||||||
|
"mp3": "mp3",
|
||||||
|
"opus": "libopus",
|
||||||
|
"flac": "flac",
|
||||||
|
"aac": "aac",
|
||||||
|
}
|
||||||
# Format-specific setup
|
# Format-specific setup
|
||||||
if self.format in ["wav","flac","mp3","pcm","aac","opus"]:
|
if self.format in ["wav", "flac", "mp3", "pcm", "aac", "opus"]:
|
||||||
if self.format != "pcm":
|
if self.format != "pcm":
|
||||||
self.output_buffer = BytesIO()
|
self.output_buffer = BytesIO()
|
||||||
self.container = av.open(self.output_buffer, mode="w", format=self.format if self.format != "aac" else "adts")
|
self.container = av.open(
|
||||||
self.stream = self.container.add_stream(codec_map[self.format],sample_rate=self.sample_rate,layout='mono' if self.channels == 1 else 'stereo')
|
self.output_buffer,
|
||||||
|
mode="w",
|
||||||
|
format=self.format if self.format != "aac" else "adts",
|
||||||
|
)
|
||||||
|
self.stream = self.container.add_stream(
|
||||||
|
codec_map[self.format],
|
||||||
|
sample_rate=self.sample_rate,
|
||||||
|
layout="mono" if self.channels == 1 else "stereo",
|
||||||
|
)
|
||||||
self.stream.bit_rate = 128000
|
self.stream.bit_rate = 128000
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported format: {format}")
|
raise ValueError(f"Unsupported format: {format}")
|
||||||
|
@ -54,7 +69,7 @@ class StreamingAudioWriter:
|
||||||
for packet in packets:
|
for packet in packets:
|
||||||
self.container.mux(packet)
|
self.container.mux(packet)
|
||||||
|
|
||||||
data=self.output_buffer.getvalue()
|
data = self.output_buffer.getvalue()
|
||||||
self.close()
|
self.close()
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@ -65,9 +80,12 @@ class StreamingAudioWriter:
|
||||||
# Write raw bytes
|
# Write raw bytes
|
||||||
return audio_data.tobytes()
|
return audio_data.tobytes()
|
||||||
else:
|
else:
|
||||||
frame = av.AudioFrame.from_ndarray(audio_data.reshape(1, -1), format='s16', layout='mono' if self.channels == 1 else 'stereo')
|
frame = av.AudioFrame.from_ndarray(
|
||||||
frame.sample_rate=self.sample_rate
|
audio_data.reshape(1, -1),
|
||||||
|
format="s16",
|
||||||
|
layout="mono" if self.channels == 1 else "stereo",
|
||||||
|
)
|
||||||
|
frame.sample_rate = self.sample_rate
|
||||||
|
|
||||||
frame.pts = self.pts
|
frame.pts = self.pts
|
||||||
self.pts += frame.samples
|
self.pts += frame.samples
|
||||||
|
@ -80,4 +98,3 @@ class StreamingAudioWriter:
|
||||||
self.output_buffer.seek(0)
|
self.output_buffer.seek(0)
|
||||||
self.output_buffer.truncate(0)
|
self.output_buffer.truncate(0)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
|
@ -6,12 +6,13 @@ Converts them into a format suitable for text-to-speech processing.
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
|
||||||
import inflect
|
import inflect
|
||||||
from numpy import number
|
from numpy import number
|
||||||
from torch import mul
|
|
||||||
from ...structures.schemas import NormalizationOptions
|
|
||||||
|
|
||||||
from text_to_num import text2num
|
from text_to_num import text2num
|
||||||
|
from torch import mul
|
||||||
|
|
||||||
|
from ...structures.schemas import NormalizationOptions
|
||||||
|
|
||||||
# Constants
|
# Constants
|
||||||
VALID_TLDS = [
|
VALID_TLDS = [
|
||||||
|
@ -54,26 +55,81 @@ VALID_TLDS = [
|
||||||
"uk",
|
"uk",
|
||||||
"us",
|
"us",
|
||||||
"io",
|
"io",
|
||||||
|
"co",
|
||||||
]
|
]
|
||||||
|
|
||||||
VALID_UNITS = {
|
VALID_UNITS = {
|
||||||
"m":"meter", "cm":"centimeter", "mm":"millimeter", "km":"kilometer", "in":"inch", "ft":"foot", "yd":"yard", "mi":"mile", # Length
|
"m": "meter",
|
||||||
"g":"gram", "kg":"kilogram", "mg":"miligram", # Mass
|
"cm": "centimeter",
|
||||||
"s":"second", "ms":"milisecond", "min":"minutes", "h":"hour", # Time
|
"mm": "millimeter",
|
||||||
"l":"liter", "ml":"mililiter", "cl":"centiliter", "dl":"deciliter", # Volume
|
"km": "kilometer",
|
||||||
"kph":"kilometer per hour", "mph":"mile per hour","mi/h":"mile per hour", "m/s":"meter per second", "km/h":"kilometer per hour", "mm/s":"milimeter per second","cm/s":"centimeter per second", "ft/s":"feet per second","cm/h":"centimeter per day", # Speed
|
"in": "inch",
|
||||||
"°c":"degree celsius","c":"degree celsius", "°f":"degree fahrenheit","f":"degree fahrenheit", "k":"kelvin", # Temperature
|
"ft": "foot",
|
||||||
"pa":"pascal", "kpa":"kilopascal", "mpa":"megapascal", "atm":"atmosphere", # Pressure
|
"yd": "yard",
|
||||||
"hz":"hertz", "khz":"kilohertz", "mhz":"megahertz", "ghz":"gigahertz", # Frequency
|
"mi": "mile", # Length
|
||||||
"v":"volt", "kv":"kilovolt", "mv":"mergavolt", # Voltage
|
"g": "gram",
|
||||||
"a":"amp", "ma":"megaamp", "ka":"kiloamp", # Current
|
"kg": "kilogram",
|
||||||
"w":"watt", "kw":"kilowatt", "mw":"megawatt", # Power
|
"mg": "milligram", # Mass
|
||||||
"j":"joule", "kj":"kilojoule", "mj":"megajoule", # Energy
|
"s": "second",
|
||||||
"Ω":"ohm", "kΩ":"kiloohm", "mΩ":"megaohm", # Resistance (Ohm)
|
"ms": "millisecond",
|
||||||
"f":"farad", "µf":"microfarad", "nf":"nanofarad", "pf":"picofarad", # Capacitance
|
"min": "minutes",
|
||||||
"b":"bit", "kb":"kilobit", "mb":"megabit", "gb":"gigabit", "tb":"terabit", "pb":"petabit", # Data size
|
"h": "hour", # Time
|
||||||
"kbps":"kilobit per second","mbps":"megabit per second","gbps":"gigabit per second","tbps":"terabit per second",
|
"l": "liter",
|
||||||
"px":"pixel" # CSS units
|
"ml": "mililiter",
|
||||||
|
"cl": "centiliter",
|
||||||
|
"dl": "deciliter", # Volume
|
||||||
|
"kph": "kilometer per hour",
|
||||||
|
"mph": "mile per hour",
|
||||||
|
"mi/h": "mile per hour",
|
||||||
|
"m/s": "meter per second",
|
||||||
|
"km/h": "kilometer per hour",
|
||||||
|
"mm/s": "milimeter per second",
|
||||||
|
"cm/s": "centimeter per second",
|
||||||
|
"ft/s": "feet per second",
|
||||||
|
"cm/h": "centimeter per day", # Speed
|
||||||
|
"°c": "degree celsius",
|
||||||
|
"c": "degree celsius",
|
||||||
|
"°f": "degree fahrenheit",
|
||||||
|
"f": "degree fahrenheit",
|
||||||
|
"k": "kelvin", # Temperature
|
||||||
|
"pa": "pascal",
|
||||||
|
"kpa": "kilopascal",
|
||||||
|
"mpa": "megapascal",
|
||||||
|
"atm": "atmosphere", # Pressure
|
||||||
|
"hz": "hertz",
|
||||||
|
"khz": "kilohertz",
|
||||||
|
"mhz": "megahertz",
|
||||||
|
"ghz": "gigahertz", # Frequency
|
||||||
|
"v": "volt",
|
||||||
|
"kv": "kilovolt",
|
||||||
|
"mv": "mergavolt", # Voltage
|
||||||
|
"a": "amp",
|
||||||
|
"ma": "megaamp",
|
||||||
|
"ka": "kiloamp", # Current
|
||||||
|
"w": "watt",
|
||||||
|
"kw": "kilowatt",
|
||||||
|
"mw": "megawatt", # Power
|
||||||
|
"j": "joule",
|
||||||
|
"kj": "kilojoule",
|
||||||
|
"mj": "megajoule", # Energy
|
||||||
|
"Ω": "ohm",
|
||||||
|
"kΩ": "kiloohm",
|
||||||
|
"mΩ": "megaohm", # Resistance (Ohm)
|
||||||
|
"f": "farad",
|
||||||
|
"µf": "microfarad",
|
||||||
|
"nf": "nanofarad",
|
||||||
|
"pf": "picofarad", # Capacitance
|
||||||
|
"b": "bit",
|
||||||
|
"kb": "kilobit",
|
||||||
|
"mb": "megabit",
|
||||||
|
"gb": "gigabit",
|
||||||
|
"tb": "terabit",
|
||||||
|
"pb": "petabit", # Data size
|
||||||
|
"kbps": "kilobit per second",
|
||||||
|
"mbps": "megabit per second",
|
||||||
|
"gbps": "gigabit per second",
|
||||||
|
"tbps": "terabit per second",
|
||||||
|
"px": "pixel", # CSS units
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -88,11 +144,19 @@ URL_PATTERN = re.compile(
|
||||||
re.IGNORECASE,
|
re.IGNORECASE,
|
||||||
)
|
)
|
||||||
|
|
||||||
UNIT_PATTERN = re.compile(r"((?<!\w)([+-]?)(\d{1,3}(,\d{3})*|\d+)(\.\d+)?)\s*(" + "|".join(sorted(list(VALID_UNITS.keys()),reverse=True)) + r"""){1}(?=[^\w\d]{1}|\b)""",re.IGNORECASE)
|
UNIT_PATTERN = re.compile(
|
||||||
|
r"((?<!\w)([+-]?)(\d{1,3}(,\d{3})*|\d+)(\.\d+)?)\s*("
|
||||||
|
+ "|".join(sorted(list(VALID_UNITS.keys()), reverse=True))
|
||||||
|
+ r"""){1}(?=[^\w\d]{1}|\b)""",
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
|
||||||
TIME_PATTERN = re.compile(r"([0-9]{2} ?: ?[0-9]{2}( ?: ?[0-9]{2})?)( ?(pm|am)\b)?", re.IGNORECASE)
|
TIME_PATTERN = re.compile(
|
||||||
|
r"([0-9]{2} ?: ?[0-9]{2}( ?: ?[0-9]{2})?)( ?(pm|am)\b)?", re.IGNORECASE
|
||||||
|
)
|
||||||
|
|
||||||
|
INFLECT_ENGINE = inflect.engine()
|
||||||
|
|
||||||
INFLECT_ENGINE=inflect.engine()
|
|
||||||
|
|
||||||
def split_num(num: re.Match[str]) -> str:
|
def split_num(num: re.Match[str]) -> str:
|
||||||
"""Handle number splitting for various formats"""
|
"""Handle number splitting for various formats"""
|
||||||
|
@ -118,29 +182,32 @@ def split_num(num: re.Match[str]) -> str:
|
||||||
return f"{left} oh {right}{s}"
|
return f"{left} oh {right}{s}"
|
||||||
return f"{left} {right}{s}"
|
return f"{left} {right}{s}"
|
||||||
|
|
||||||
|
|
||||||
def handle_units(u: re.Match[str]) -> str:
|
def handle_units(u: re.Match[str]) -> str:
|
||||||
"""Converts units to their full form"""
|
"""Converts units to their full form"""
|
||||||
unit_string=u.group(6).strip()
|
unit_string = u.group(6).strip()
|
||||||
unit=unit_string
|
unit = unit_string
|
||||||
|
|
||||||
if unit_string.lower() in VALID_UNITS:
|
if unit_string.lower() in VALID_UNITS:
|
||||||
unit=VALID_UNITS[unit_string.lower()].split(" ")
|
unit = VALID_UNITS[unit_string.lower()].split(" ")
|
||||||
|
|
||||||
# Handles the B vs b case
|
# Handles the B vs b case
|
||||||
if unit[0].endswith("bit"):
|
if unit[0].endswith("bit"):
|
||||||
b_case=unit_string[min(1,len(unit_string) - 1)]
|
b_case = unit_string[min(1, len(unit_string) - 1)]
|
||||||
if b_case == "B":
|
if b_case == "B":
|
||||||
unit[0]=unit[0][:-3] + "byte"
|
unit[0] = unit[0][:-3] + "byte"
|
||||||
|
|
||||||
number=u.group(1).strip()
|
number = u.group(1).strip()
|
||||||
unit[0]=INFLECT_ENGINE.no(unit[0],number)
|
unit[0] = INFLECT_ENGINE.no(unit[0], number)
|
||||||
return " ".join(unit)
|
return " ".join(unit)
|
||||||
|
|
||||||
|
|
||||||
def conditional_int(number: float, threshold: float = 0.00001):
|
def conditional_int(number: float, threshold: float = 0.00001):
|
||||||
if abs(round(number) - number) < threshold:
|
if abs(round(number) - number) < threshold:
|
||||||
return int(round(number))
|
return int(round(number))
|
||||||
return number
|
return number
|
||||||
|
|
||||||
|
|
||||||
def handle_money(m: re.Match[str]) -> str:
|
def handle_money(m: re.Match[str]) -> str:
|
||||||
"""Convert money expressions to spoken form"""
|
"""Convert money expressions to spoken form"""
|
||||||
|
|
||||||
|
@ -166,6 +233,7 @@ def handle_money(m: re.Match[str]) -> str:
|
||||||
|
|
||||||
return text_number
|
return text_number
|
||||||
|
|
||||||
|
|
||||||
def handle_decimal(num: re.Match[str]) -> str:
|
def handle_decimal(num: re.Match[str]) -> str:
|
||||||
"""Convert decimal numbers to spoken form"""
|
"""Convert decimal numbers to spoken form"""
|
||||||
a, b = num.group().split(".")
|
a, b = num.group().split(".")
|
||||||
|
@ -229,34 +297,41 @@ def handle_url(u: re.Match[str]) -> str:
|
||||||
# Clean up extra spaces
|
# Clean up extra spaces
|
||||||
return re.sub(r"\s+", " ", url).strip()
|
return re.sub(r"\s+", " ", url).strip()
|
||||||
|
|
||||||
def handle_phone_number(p: re.Match[str]) -> str:
|
|
||||||
p=list(p.groups())
|
|
||||||
|
|
||||||
country_code=""
|
def handle_phone_number(p: re.Match[str]) -> str:
|
||||||
|
p = list(p.groups())
|
||||||
|
|
||||||
|
country_code = ""
|
||||||
if p[0] is not None:
|
if p[0] is not None:
|
||||||
p[0]=p[0].replace("+","")
|
p[0] = p[0].replace("+", "")
|
||||||
country_code += INFLECT_ENGINE.number_to_words(p[0])
|
country_code += INFLECT_ENGINE.number_to_words(p[0])
|
||||||
|
|
||||||
area_code=INFLECT_ENGINE.number_to_words(p[2].replace("(","").replace(")",""),group=1,comma="")
|
area_code = INFLECT_ENGINE.number_to_words(
|
||||||
|
p[2].replace("(", "").replace(")", ""), group=1, comma=""
|
||||||
|
)
|
||||||
|
|
||||||
telephone_prefix=INFLECT_ENGINE.number_to_words(p[3],group=1,comma="")
|
telephone_prefix = INFLECT_ENGINE.number_to_words(p[3], group=1, comma="")
|
||||||
|
|
||||||
line_number=INFLECT_ENGINE.number_to_words(p[4],group=1,comma="")
|
line_number = INFLECT_ENGINE.number_to_words(p[4], group=1, comma="")
|
||||||
|
|
||||||
|
return ",".join([country_code, area_code, telephone_prefix, line_number])
|
||||||
|
|
||||||
return ",".join([country_code,area_code,telephone_prefix,line_number])
|
|
||||||
|
|
||||||
def handle_time(t: re.Match[str]) -> str:
|
def handle_time(t: re.Match[str]) -> str:
|
||||||
t=t.groups()
|
t = t.groups()
|
||||||
|
|
||||||
numbers = " ".join([INFLECT_ENGINE.number_to_words(X.strip()) for X in t[0].split(":")])
|
numbers = " ".join(
|
||||||
|
[INFLECT_ENGINE.number_to_words(X.strip()) for X in t[0].split(":")]
|
||||||
|
)
|
||||||
|
|
||||||
half=""
|
half = ""
|
||||||
if t[2] is not None:
|
if t[2] is not None:
|
||||||
half=t[2].strip()
|
half = t[2].strip()
|
||||||
|
|
||||||
return numbers + half
|
return numbers + half
|
||||||
|
|
||||||
def normalize_text(text: str,normalization_options: NormalizationOptions) -> str:
|
|
||||||
|
def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
|
||||||
"""Normalize text for TTS processing"""
|
"""Normalize text for TTS processing"""
|
||||||
# Handle email addresses first if enabled
|
# Handle email addresses first if enabled
|
||||||
if normalization_options.email_normalization:
|
if normalization_options.email_normalization:
|
||||||
|
@ -268,15 +343,19 @@ def normalize_text(text: str,normalization_options: NormalizationOptions) -> str
|
||||||
|
|
||||||
# Pre-process numbers with units if enabled
|
# Pre-process numbers with units if enabled
|
||||||
if normalization_options.unit_normalization:
|
if normalization_options.unit_normalization:
|
||||||
text=UNIT_PATTERN.sub(handle_units,text)
|
text = UNIT_PATTERN.sub(handle_units, text)
|
||||||
|
|
||||||
# Replace optional pluralization
|
# Replace optional pluralization
|
||||||
if normalization_options.optional_pluralization_normalization:
|
if normalization_options.optional_pluralization_normalization:
|
||||||
text = re.sub(r"\(s\)","s",text)
|
text = re.sub(r"\(s\)", "s", text)
|
||||||
|
|
||||||
# Replace phone numbers:
|
# Replace phone numbers:
|
||||||
if normalization_options.phone_normalization:
|
if normalization_options.phone_normalization:
|
||||||
text = re.sub(r"(\+?\d{1,2})?([ .-]?)(\(?\d{3}\)?)[\s.-](\d{3})[\s.-](\d{4})",handle_phone_number,text)
|
text = re.sub(
|
||||||
|
r"(\+?\d{1,2})?([ .-]?)(\(?\d{3}\)?)[\s.-](\d{3})[\s.-](\d{4})",
|
||||||
|
handle_phone_number,
|
||||||
|
text,
|
||||||
|
)
|
||||||
|
|
||||||
# Replace quotes and brackets
|
# Replace quotes and brackets
|
||||||
text = text.replace(chr(8216), "'").replace(chr(8217), "'")
|
text = text.replace(chr(8216), "'").replace(chr(8217), "'")
|
||||||
|
@ -288,7 +367,10 @@ def normalize_text(text: str,normalization_options: NormalizationOptions) -> str
|
||||||
text = text.replace(a, b + " ")
|
text = text.replace(a, b + " ")
|
||||||
|
|
||||||
# Handle simple time in the format of HH:MM:SS
|
# Handle simple time in the format of HH:MM:SS
|
||||||
text = TIME_PATTERN.sub(handle_time, text, )
|
text = TIME_PATTERN.sub(
|
||||||
|
handle_time,
|
||||||
|
text,
|
||||||
|
)
|
||||||
|
|
||||||
# Clean up whitespace
|
# Clean up whitespace
|
||||||
text = re.sub(r"[^\S \n]", " ", text)
|
text = re.sub(r"[^\S \n]", " ", text)
|
||||||
|
@ -328,6 +410,6 @@ def normalize_text(text: str,normalization_options: NormalizationOptions) -> str
|
||||||
text = re.sub(
|
text = re.sub(
|
||||||
r"(?:[A-Za-z]\.){2,} [a-z]", lambda m: m.group().replace(".", "-"), text
|
r"(?:[A-Za-z]\.){2,} [a-z]", lambda m: m.group().replace(".", "-"), text
|
||||||
)
|
)
|
||||||
text = re.sub( r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)
|
text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)
|
||||||
|
|
||||||
return text.strip()
|
return text.strip()
|
||||||
|
|
|
@ -7,16 +7,17 @@ from typing import AsyncGenerator, Dict, List, Tuple
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from ...core.config import settings
|
from ...core.config import settings
|
||||||
|
from ...structures.schemas import NormalizationOptions
|
||||||
from .normalizer import normalize_text
|
from .normalizer import normalize_text
|
||||||
from .phonemizer import phonemize
|
from .phonemizer import phonemize
|
||||||
from .vocabulary import tokenize
|
from .vocabulary import tokenize
|
||||||
from ...structures.schemas import NormalizationOptions
|
|
||||||
|
|
||||||
# Pre-compiled regex patterns for performance
|
# Pre-compiled regex patterns for performance
|
||||||
CUSTOM_PHONEMES = re.compile(r"(\[([^\]]|\n)*?\])(\(\/([^\/)]|\n)*?\/\))")
|
CUSTOM_PHONEMES = re.compile(r"(\[([^\]]|\n)*?\])(\(\/([^\/)]|\n)*?\/\))")
|
||||||
# Matching: [silent 1s], [silent 0.5s], [silent .5s]
|
# Matching: [silent 1s], [silent 0.5s], [silent .5s]
|
||||||
SILENCE_TAG = re.compile(r"\[silent (\d*\.?\d+)s\]")
|
SILENCE_TAG = re.compile(r"\[silent (\d*\.?\d+)s\]")
|
||||||
|
|
||||||
|
|
||||||
def process_text_chunk(
|
def process_text_chunk(
|
||||||
text: str, language: str = "a", skip_phonemize: bool = False
|
text: str, language: str = "a", skip_phonemize: bool = False
|
||||||
) -> List[int]:
|
) -> List[int]:
|
||||||
|
@ -43,9 +44,7 @@ def process_text_chunk(
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
|
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
phonemes = phonemize(
|
phonemes = phonemize(text, language, normalize=False) # Already normalized
|
||||||
text, language, normalize=False
|
|
||||||
) # Already normalized
|
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
|
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
|
@ -89,7 +88,6 @@ def process_text(text: str, language: str = "a") -> List[int]:
|
||||||
|
|
||||||
return process_text_chunk(text, language)
|
return process_text_chunk(text, language)
|
||||||
|
|
||||||
|
|
||||||
def get_sentence_info(text: str, custom_phenomes_list: Dict[str, str]) -> List[Tuple[str, List[int], int]]:
|
def get_sentence_info(text: str, custom_phenomes_list: Dict[str, str]) -> List[Tuple[str, List[int], int]]:
|
||||||
"""
|
"""
|
||||||
Process all sentences and return info.
|
Process all sentences and return info.
|
||||||
|
@ -134,6 +132,7 @@ def get_sentence_info(text: str, custom_phenomes_list: Dict[str, str]) -> List[T
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def handle_custom_phonemes(s: re.Match[str], phenomes_list: Dict[str,str]) -> str:
|
def handle_custom_phonemes(s: re.Match[str], phenomes_list: Dict[str,str]) -> str:
|
||||||
"""
|
"""
|
||||||
Replace [text](/phonemes/) with a <|custom_phonemes_X|/> tag to avoid being normalized.
|
Replace [text](/phonemes/) with a <|custom_phonemes_X|/> tag to avoid being normalized.
|
||||||
|
@ -143,11 +142,12 @@ def handle_custom_phonemes(s: re.Match[str], phenomes_list: Dict[str,str]) -> st
|
||||||
phenomes_list[latest_id] = s.group(0).strip()
|
phenomes_list[latest_id] = s.group(0).strip()
|
||||||
return latest_id
|
return latest_id
|
||||||
|
|
||||||
|
|
||||||
async def smart_split(
|
async def smart_split(
|
||||||
text: str,
|
text: str,
|
||||||
max_tokens: int = settings.absolute_max_tokens,
|
max_tokens: int = settings.absolute_max_tokens,
|
||||||
lang_code: str = "a",
|
lang_code: str = "a",
|
||||||
normalization_options: NormalizationOptions = NormalizationOptions()
|
normalization_options: NormalizationOptions = NormalizationOptions(),
|
||||||
) -> AsyncGenerator[Tuple[str, List[int]], None]:
|
) -> AsyncGenerator[Tuple[str, List[int]], None]:
|
||||||
"""Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens."""
|
"""Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens."""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
@ -162,9 +162,11 @@ async def smart_split(
|
||||||
if settings.advanced_text_normalization and normalization_options.normalize:
|
if settings.advanced_text_normalization and normalization_options.normalize:
|
||||||
if lang_code in ["a","b","en-us","en-gb"]:
|
if lang_code in ["a","b","en-us","en-gb"]:
|
||||||
text = CUSTOM_PHONEMES.sub(lambda s: handle_custom_phonemes(s, custom_phoneme_list), text)
|
text = CUSTOM_PHONEMES.sub(lambda s: handle_custom_phonemes(s, custom_phoneme_list), text)
|
||||||
text=normalize_text(text,normalization_options)
|
text = normalize_text(text,normalization_options)
|
||||||
else:
|
else:
|
||||||
logger.info("Skipping text normalization as it is only supported for english")
|
logger.info(
|
||||||
|
"Skipping text normalization as it is only supported for english"
|
||||||
|
)
|
||||||
|
|
||||||
# Process all sentences
|
# Process all sentences
|
||||||
sentences = get_sentence_info(text, custom_phoneme_list)
|
sentences = get_sentence_info(text, custom_phoneme_list)
|
||||||
|
|
|
@ -8,7 +8,6 @@ import time
|
||||||
from typing import AsyncGenerator, List, Optional, Tuple, Union
|
from typing import AsyncGenerator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from .streaming_audio_writer import StreamingAudioWriter
|
|
||||||
import torch
|
import torch
|
||||||
from kokoro import KPipeline
|
from kokoro import KPipeline
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
@ -20,6 +19,7 @@ from ..inference.model_manager import get_manager as get_model_manager
|
||||||
from ..inference.voice_manager import get_manager as get_voice_manager
|
from ..inference.voice_manager import get_manager as get_voice_manager
|
||||||
from ..structures.schemas import NormalizationOptions
|
from ..structures.schemas import NormalizationOptions
|
||||||
from .audio import AudioNormalizer, AudioService
|
from .audio import AudioNormalizer, AudioService
|
||||||
|
from .streaming_audio_writer import StreamingAudioWriter
|
||||||
from .text_processing import tokenize
|
from .text_processing import tokenize
|
||||||
from .text_processing.text_processor import SILENCE_TAG, smart_split
|
from .text_processing.text_processor import SILENCE_TAG, smart_split
|
||||||
|
|
||||||
|
@ -88,7 +88,9 @@ class TTSService:
|
||||||
yield AudioChunk(np.array([], dtype=np.int16), output=b"")
|
yield AudioChunk(np.array([], dtype=np.int16), output=b"")
|
||||||
return
|
return
|
||||||
chunk_data = await AudioService.convert_audio(
|
chunk_data = await AudioService.convert_audio(
|
||||||
AudioChunk(np.array([], dtype=np.float32)), # Dummy data for type checking
|
AudioChunk(
|
||||||
|
np.array([], dtype=np.float32)
|
||||||
|
), # Dummy data for type checking
|
||||||
output_format,
|
output_format,
|
||||||
writer,
|
writer,
|
||||||
speed,
|
speed,
|
||||||
|
@ -133,13 +135,22 @@ class TTSService:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to convert audio: {str(e)}")
|
logger.error(f"Failed to convert audio: {str(e)}")
|
||||||
else:
|
else:
|
||||||
chunk_data = AudioService.trim_audio(chunk_data, chunk_text, speed, is_last, normalizer)
|
chunk_data = AudioService.trim_audio(
|
||||||
|
chunk_data, chunk_text, speed, is_last, normalizer
|
||||||
|
)
|
||||||
yield chunk_data
|
yield chunk_data
|
||||||
chunk_index += 1
|
chunk_index += 1
|
||||||
else:
|
else:
|
||||||
# For legacy backends, load voice tensor
|
# For legacy backends, load voice tensor
|
||||||
voice_tensor = await self._voice_manager.load_voice(voice_name, device=backend.device)
|
voice_tensor = await self._voice_manager.load_voice(
|
||||||
chunk_data = await self.model_manager.generate(tokens, voice_tensor, speed=speed, return_timestamps=return_timestamps)
|
voice_name, device=backend.device
|
||||||
|
)
|
||||||
|
chunk_data = await self.model_manager.generate(
|
||||||
|
tokens,
|
||||||
|
voice_tensor,
|
||||||
|
speed=speed,
|
||||||
|
return_timestamps=return_timestamps,
|
||||||
|
)
|
||||||
|
|
||||||
if chunk_data.audio is None:
|
if chunk_data.audio is None:
|
||||||
logger.error("Model generated None for audio chunk")
|
logger.error("Model generated None for audio chunk")
|
||||||
|
@ -165,7 +176,9 @@ class TTSService:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to convert audio: {str(e)}")
|
logger.error(f"Failed to convert audio: {str(e)}")
|
||||||
else:
|
else:
|
||||||
trimmed = AudioService.trim_audio(chunk_data, chunk_text, speed, is_last, normalizer)
|
trimmed = AudioService.trim_audio(
|
||||||
|
chunk_data, chunk_text, speed, is_last, normalizer
|
||||||
|
)
|
||||||
yield trimmed
|
yield trimmed
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to process tokens: {str(e)}")
|
logger.error(f"Failed to process tokens: {str(e)}")
|
||||||
|
@ -197,7 +210,9 @@ class TTSService:
|
||||||
# If it is only once voice there is no point in loading it up, doing nothing with it, then saving it
|
# If it is only once voice there is no point in loading it up, doing nothing with it, then saving it
|
||||||
if len(split_voice) == 1:
|
if len(split_voice) == 1:
|
||||||
# Since its a single voice the only time that the weight would matter is if voice_weight_normalization is off
|
# Since its a single voice the only time that the weight would matter is if voice_weight_normalization is off
|
||||||
if ("(" not in voice and ")" not in voice) or settings.voice_weight_normalization == True:
|
if (
|
||||||
|
"(" not in voice and ")" not in voice
|
||||||
|
) or settings.voice_weight_normalization == True:
|
||||||
path = await self._voice_manager.get_voice_path(voice)
|
path = await self._voice_manager.get_voice_path(voice)
|
||||||
if not path:
|
if not path:
|
||||||
raise RuntimeError(f"Voice not found: {voice}")
|
raise RuntimeError(f"Voice not found: {voice}")
|
||||||
|
@ -225,13 +240,19 @@ class TTSService:
|
||||||
|
|
||||||
# Load the first voice as the starting point for voices to be combined onto
|
# Load the first voice as the starting point for voices to be combined onto
|
||||||
path = await self._voice_manager.get_voice_path(split_voice[0][0])
|
path = await self._voice_manager.get_voice_path(split_voice[0][0])
|
||||||
combined_tensor = await self._load_voice_from_path(path, split_voice[0][1] / total_weight)
|
combined_tensor = await self._load_voice_from_path(
|
||||||
|
path, split_voice[0][1] / total_weight
|
||||||
|
)
|
||||||
|
|
||||||
# Loop through each + or - in split_voice so they can be applied to combined voice
|
# Loop through each + or - in split_voice so they can be applied to combined voice
|
||||||
for operation_index in range(1, len(split_voice) - 1, 2):
|
for operation_index in range(1, len(split_voice) - 1, 2):
|
||||||
# Get the voice path of the voice 1 index ahead of the operator
|
# Get the voice path of the voice 1 index ahead of the operator
|
||||||
path = await self._voice_manager.get_voice_path(split_voice[operation_index + 1][0])
|
path = await self._voice_manager.get_voice_path(
|
||||||
voice_tensor = await self._load_voice_from_path(path, split_voice[operation_index + 1][1] / total_weight)
|
split_voice[operation_index + 1][0]
|
||||||
|
)
|
||||||
|
voice_tensor = await self._load_voice_from_path(
|
||||||
|
path, split_voice[operation_index + 1][1] / total_weight
|
||||||
|
)
|
||||||
|
|
||||||
# Either add or subtract the voice from the current combined voice
|
# Either add or subtract the voice from the current combined voice
|
||||||
if split_voice[operation_index] == "+":
|
if split_voice[operation_index] == "+":
|
||||||
|
@ -274,10 +295,16 @@ class TTSService:
|
||||||
|
|
||||||
# Use provided lang_code or determine from voice name
|
# Use provided lang_code or determine from voice name
|
||||||
pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
|
pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
|
||||||
logger.info(f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream")
|
logger.info(
|
||||||
|
f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream"
|
||||||
|
)
|
||||||
|
|
||||||
# Process text in chunks with smart splitting
|
# Process text in chunks with smart splitting
|
||||||
async for chunk_text, tokens in smart_split(text, lang_code=pipeline_lang_code, normalization_options=normalization_options):
|
async for chunk_text, tokens in smart_split(
|
||||||
|
text,
|
||||||
|
lang_code=pipeline_lang_code,
|
||||||
|
normalization_options=normalization_options,
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
# Process audio for chunk
|
# Process audio for chunk
|
||||||
async for chunk_data in self._process_chunk(
|
async for chunk_data in self._process_chunk(
|
||||||
|
@ -305,10 +332,14 @@ class TTSService:
|
||||||
yield chunk_data
|
yield chunk_data
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.warning(f"No audio generated for chunk: '{chunk_text[:100]}...'")
|
logger.warning(
|
||||||
|
f"No audio generated for chunk: '{chunk_text[:100]}...'"
|
||||||
|
)
|
||||||
chunk_index += 1
|
chunk_index += 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}")
|
logger.error(
|
||||||
|
f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}"
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Only finalize if we successfully processed at least one chunk
|
# Only finalize if we successfully processed at least one chunk
|
||||||
|
@ -351,7 +382,16 @@ class TTSService:
|
||||||
audio_data_chunks = []
|
audio_data_chunks = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for audio_stream_data in self.generate_audio_stream(text, voice, writer, speed=speed, normalization_options=normalization_options, return_timestamps=return_timestamps, lang_code=lang_code, output_format=None):
|
async for audio_stream_data in self.generate_audio_stream(
|
||||||
|
text,
|
||||||
|
voice,
|
||||||
|
writer,
|
||||||
|
speed=speed,
|
||||||
|
normalization_options=normalization_options,
|
||||||
|
return_timestamps=return_timestamps,
|
||||||
|
lang_code=lang_code,
|
||||||
|
output_format=None,
|
||||||
|
):
|
||||||
if len(audio_stream_data.audio) > 0:
|
if len(audio_stream_data.audio) > 0:
|
||||||
audio_data_chunks.append(audio_stream_data)
|
audio_data_chunks.append(audio_stream_data)
|
||||||
|
|
||||||
|
@ -403,11 +443,15 @@ class TTSService:
|
||||||
result = None
|
result = None
|
||||||
# Use provided lang_code or determine from voice name
|
# Use provided lang_code or determine from voice name
|
||||||
pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
|
pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
|
||||||
logger.info(f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme pipeline")
|
logger.info(
|
||||||
|
f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme pipeline"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Use backend's pipeline management
|
# Use backend's pipeline management
|
||||||
for r in backend._get_pipeline(pipeline_lang_code).generate_from_tokens(
|
for r in backend._get_pipeline(
|
||||||
|
pipeline_lang_code
|
||||||
|
).generate_from_tokens(
|
||||||
tokens=phonemes, # Pass raw phonemes string
|
tokens=phonemes, # Pass raw phonemes string
|
||||||
voice=voice_path,
|
voice=voice_path,
|
||||||
speed=speed,
|
speed=speed,
|
||||||
|
@ -425,7 +469,9 @@ class TTSService:
|
||||||
processing_time = time.time() - start_time
|
processing_time = time.time() - start_time
|
||||||
return result.audio.numpy(), processing_time
|
return result.audio.numpy(), processing_time
|
||||||
else:
|
else:
|
||||||
raise ValueError("Phoneme generation only supported with Kokoro V1 backend")
|
raise ValueError(
|
||||||
|
"Phoneme generation only supported with Kokoro V1 backend"
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in phoneme audio generation: {str(e)}")
|
logger.error(f"Error in phoneme audio generation: {str(e)}")
|
||||||
|
|
|
@ -24,16 +24,12 @@ class JSONStreamingResponse(StreamingResponse, JSONResponse):
|
||||||
else:
|
else:
|
||||||
self._content_iterable = iterate_in_threadpool(content)
|
self._content_iterable = iterate_in_threadpool(content)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def body_iterator() -> AsyncIterable[bytes]:
|
async def body_iterator() -> AsyncIterable[bytes]:
|
||||||
async for content_ in self._content_iterable:
|
async for content_ in self._content_iterable:
|
||||||
if isinstance(content_, BaseModel):
|
if isinstance(content_, BaseModel):
|
||||||
content_ = content_.model_dump()
|
content_ = content_.model_dump()
|
||||||
yield self.render(content_)
|
yield self.render(content_)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
self.body_iterator = body_iterator()
|
self.body_iterator = body_iterator()
|
||||||
self.status_code = status_code
|
self.status_code = status_code
|
||||||
if media_type is not None:
|
if media_type is not None:
|
||||||
|
@ -42,10 +38,13 @@ class JSONStreamingResponse(StreamingResponse, JSONResponse):
|
||||||
self.init_headers(headers)
|
self.init_headers(headers)
|
||||||
|
|
||||||
def render(self, content: typing.Any) -> bytes:
|
def render(self, content: typing.Any) -> bytes:
|
||||||
return (json.dumps(
|
return (
|
||||||
|
json.dumps(
|
||||||
content,
|
content,
|
||||||
ensure_ascii=False,
|
ensure_ascii=False,
|
||||||
allow_nan=False,
|
allow_nan=False,
|
||||||
indent=None,
|
indent=None,
|
||||||
separators=(",", ":"),
|
separators=(",", ":"),
|
||||||
) + "\n").encode("utf-8")
|
)
|
||||||
|
+ "\n"
|
||||||
|
).encode("utf-8")
|
||||||
|
|
|
@ -35,16 +35,38 @@ class CaptionedSpeechResponse(BaseModel):
|
||||||
|
|
||||||
audio: str = Field(..., description="The generated audio data encoded in base 64")
|
audio: str = Field(..., description="The generated audio data encoded in base 64")
|
||||||
audio_format: str = Field(..., description="The format of the output audio")
|
audio_format: str = Field(..., description="The format of the output audio")
|
||||||
timestamps: Optional[List[WordTimestamp]] = Field(..., description="Word-level timestamps")
|
timestamps: Optional[List[WordTimestamp]] = Field(
|
||||||
|
..., description="Word-level timestamps"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class NormalizationOptions(BaseModel):
|
class NormalizationOptions(BaseModel):
|
||||||
"""Options for the normalization system"""
|
"""Options for the normalization system"""
|
||||||
normalize: bool = Field(default=True, description="Normalizes input text to make it easier for the model to say")
|
|
||||||
unit_normalization: bool = Field(default=False,description="Transforms units like 10KB to 10 kilobytes")
|
normalize: bool = Field(
|
||||||
url_normalization: bool = Field(default=True, description="Changes urls so they can be properly pronouced by kokoro")
|
default=True,
|
||||||
email_normalization: bool = Field(default=True, description="Changes emails so they can be properly pronouced by kokoro")
|
description="Normalizes input text to make it easier for the model to say",
|
||||||
optional_pluralization_normalization: bool = Field(default=True, description="Replaces (s) with s so some words get pronounced correctly")
|
)
|
||||||
phone_normalization: bool = Field(default=True, description="Changes phone numbers so they can be properly pronouced by kokoro")
|
unit_normalization: bool = Field(
|
||||||
|
default=False, description="Transforms units like 10KB to 10 kilobytes"
|
||||||
|
)
|
||||||
|
url_normalization: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Changes urls so they can be properly pronounced by kokoro",
|
||||||
|
)
|
||||||
|
email_normalization: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Changes emails so they can be properly pronouced by kokoro",
|
||||||
|
)
|
||||||
|
optional_pluralization_normalization: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Replaces (s) with s so some words get pronounced correctly",
|
||||||
|
)
|
||||||
|
phone_normalization: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Changes phone numbers so they can be properly pronouced by kokoro",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class OpenAISpeechRequest(BaseModel):
|
class OpenAISpeechRequest(BaseModel):
|
||||||
"""Request schema for OpenAI-compatible speech endpoint"""
|
"""Request schema for OpenAI-compatible speech endpoint"""
|
||||||
|
@ -62,10 +84,12 @@ class OpenAISpeechRequest(BaseModel):
|
||||||
default="mp3",
|
default="mp3",
|
||||||
description="The format to return audio in. Supported formats: mp3, opus, flac, wav, pcm. PCM format returns raw 16-bit samples without headers. AAC is not currently supported.",
|
description="The format to return audio in. Supported formats: mp3, opus, flac, wav, pcm. PCM format returns raw 16-bit samples without headers. AAC is not currently supported.",
|
||||||
)
|
)
|
||||||
download_format: Optional[Literal["mp3", "opus", "aac", "flac", "wav", "pcm"]] = Field(
|
download_format: Optional[Literal["mp3", "opus", "aac", "flac", "wav", "pcm"]] = (
|
||||||
|
Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Optional different format for the final download. If not provided, uses response_format.",
|
description="Optional different format for the final download. If not provided, uses response_format.",
|
||||||
)
|
)
|
||||||
|
)
|
||||||
speed: float = Field(
|
speed: float = Field(
|
||||||
default=1.0,
|
default=1.0,
|
||||||
ge=0.25,
|
ge=0.25,
|
||||||
|
@ -85,8 +109,8 @@ class OpenAISpeechRequest(BaseModel):
|
||||||
description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
|
description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
|
||||||
)
|
)
|
||||||
normalization_options: Optional[NormalizationOptions] = Field(
|
normalization_options: Optional[NormalizationOptions] = Field(
|
||||||
default= NormalizationOptions(),
|
default=NormalizationOptions(),
|
||||||
description= "Options for the normalization system"
|
description="Options for the normalization system",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -129,6 +153,6 @@ class CaptionedSpeechRequest(BaseModel):
|
||||||
description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
|
description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
|
||||||
)
|
)
|
||||||
normalization_options: Optional[NormalizationOptions] = Field(
|
normalization_options: Optional[NormalizationOptions] = Field(
|
||||||
default= NormalizationOptions(),
|
default=NormalizationOptions(),
|
||||||
description= "Options for the normalization system"
|
description="Options for the normalization system",
|
||||||
)
|
)
|
||||||
|
|
|
@ -69,4 +69,3 @@ async def tts_service(mock_model_manager, mock_voice_manager):
|
||||||
def test_voice():
|
def test_voice():
|
||||||
"""Return a test voice name."""
|
"""Return a test voice name."""
|
||||||
return "voice1"
|
return "voice1"
|
||||||
|
|
||||||
|
|
|
@ -5,9 +5,11 @@ from unittest.mock import patch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from api.src.services.audio import AudioNormalizer, AudioService
|
|
||||||
from api.src.inference.base import AudioChunk
|
from api.src.inference.base import AudioChunk
|
||||||
|
from api.src.services.audio import AudioNormalizer, AudioService
|
||||||
from api.src.services.streaming_audio_writer import StreamingAudioWriter
|
from api.src.services.streaming_audio_writer import StreamingAudioWriter
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def mock_settings():
|
def mock_settings():
|
||||||
"""Mock settings for all tests"""
|
"""Mock settings for all tests"""
|
||||||
|
@ -64,7 +66,9 @@ async def test_convert_to_mp3(sample_audio):
|
||||||
assert isinstance(audio_chunk, AudioChunk)
|
assert isinstance(audio_chunk, AudioChunk)
|
||||||
assert len(audio_chunk.output) > 0
|
assert len(audio_chunk.output) > 0
|
||||||
# Check MP3 header (ID3 or MPEG frame sync)
|
# Check MP3 header (ID3 or MPEG frame sync)
|
||||||
assert audio_chunk.output.startswith(b"ID3") or audio_chunk.output.startswith(b"\xff\xfb")
|
assert audio_chunk.output.startswith(b"ID3") or audio_chunk.output.startswith(
|
||||||
|
b"\xff\xfb"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@ -76,7 +80,7 @@ async def test_convert_to_opus(sample_audio):
|
||||||
writer = StreamingAudioWriter("opus", sample_rate=24000)
|
writer = StreamingAudioWriter("opus", sample_rate=24000)
|
||||||
|
|
||||||
audio_chunk = await AudioService.convert_audio(
|
audio_chunk = await AudioService.convert_audio(
|
||||||
AudioChunk(audio_data), "opus",writer
|
AudioChunk(audio_data), "opus", writer
|
||||||
)
|
)
|
||||||
|
|
||||||
writer.close()
|
writer.close()
|
||||||
|
@ -125,7 +129,9 @@ async def test_convert_to_aac(sample_audio):
|
||||||
assert isinstance(audio_chunk, AudioChunk)
|
assert isinstance(audio_chunk, AudioChunk)
|
||||||
assert len(audio_chunk.output) > 0
|
assert len(audio_chunk.output) > 0
|
||||||
# Check ADTS header (AAC)
|
# Check ADTS header (AAC)
|
||||||
assert audio_chunk.output.startswith(b"\xff\xf0") or audio_chunk.output.startswith(b"\xff\xf1")
|
assert audio_chunk.output.startswith(b"\xff\xf0") or audio_chunk.output.startswith(
|
||||||
|
b"\xff\xf1"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@ -150,7 +156,7 @@ async def test_convert_to_pcm(sample_audio):
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_convert_to_invalid_format_raises_error(sample_audio):
|
async def test_convert_to_invalid_format_raises_error(sample_audio):
|
||||||
"""Test that converting to an invalid format raises an error"""
|
"""Test that converting to an invalid format raises an error"""
|
||||||
#audio_data, sample_rate = sample_audio
|
# audio_data, sample_rate = sample_audio
|
||||||
with pytest.raises(ValueError, match="Unsupported format: invalid"):
|
with pytest.raises(ValueError, match="Unsupported format: invalid"):
|
||||||
writer = StreamingAudioWriter("invalid", sample_rate=24000)
|
writer = StreamingAudioWriter("invalid", sample_rate=24000)
|
||||||
|
|
||||||
|
@ -212,7 +218,6 @@ async def test_different_sample_rates(sample_audio):
|
||||||
sample_rates = [8000, 16000, 44100, 48000]
|
sample_rates = [8000, 16000, 44100, 48000]
|
||||||
|
|
||||||
for rate in sample_rates:
|
for rate in sample_rates:
|
||||||
|
|
||||||
writer = StreamingAudioWriter("wav", sample_rate=rate)
|
writer = StreamingAudioWriter("wav", sample_rate=rate)
|
||||||
|
|
||||||
audio_chunk = await AudioService.convert_audio(
|
audio_chunk = await AudioService.convert_audio(
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
import pytest
|
|
||||||
from unittest.mock import patch, MagicMock
|
|
||||||
import requests
|
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
def test_generate_captioned_speech():
|
def test_generate_captioned_speech():
|
||||||
"""Test the generate_captioned_speech function with mocked responses"""
|
"""Test the generate_captioned_speech function with mocked responses"""
|
||||||
|
@ -12,14 +14,15 @@ def test_generate_captioned_speech():
|
||||||
|
|
||||||
mock_timestamps_response = MagicMock()
|
mock_timestamps_response = MagicMock()
|
||||||
mock_timestamps_response.status_code = 200
|
mock_timestamps_response.status_code = 200
|
||||||
mock_timestamps_response.content = json.dumps({
|
mock_timestamps_response.content = json.dumps(
|
||||||
"audio":base64.b64encode(b"mock audio data").decode("utf-8"),
|
{
|
||||||
"timestamps":[{"word": "test", "start_time": 0.0, "end_time": 1.0}]
|
"audio": base64.b64encode(b"mock audio data").decode("utf-8"),
|
||||||
})
|
"timestamps": [{"word": "test", "start_time": 0.0, "end_time": 1.0}],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Patch the HTTP requests
|
# Patch the HTTP requests
|
||||||
with patch('requests.post', return_value=mock_timestamps_response):
|
with patch("requests.post", return_value=mock_timestamps_response):
|
||||||
|
|
||||||
# Import here to avoid module-level import issues
|
# Import here to avoid module-level import issues
|
||||||
from examples.captioned_speech_example import generate_captioned_speech
|
from examples.captioned_speech_example import generate_captioned_speech
|
||||||
|
|
||||||
|
|
|
@ -5,27 +5,48 @@ import pytest
|
||||||
from api.src.services.text_processing.normalizer import normalize_text
|
from api.src.services.text_processing.normalizer import normalize_text
|
||||||
from api.src.structures.schemas import NormalizationOptions
|
from api.src.structures.schemas import NormalizationOptions
|
||||||
|
|
||||||
|
|
||||||
def test_url_protocols():
|
def test_url_protocols():
|
||||||
"""Test URL protocol handling"""
|
"""Test URL protocol handling"""
|
||||||
assert (
|
assert (
|
||||||
normalize_text("Check out https://example.com",normalization_options=NormalizationOptions())
|
normalize_text(
|
||||||
|
"Check out https://example.com",
|
||||||
|
normalization_options=NormalizationOptions(),
|
||||||
|
)
|
||||||
== "Check out https example dot com"
|
== "Check out https example dot com"
|
||||||
)
|
)
|
||||||
assert normalize_text("Visit http://site.com",normalization_options=NormalizationOptions()) == "Visit http site dot com"
|
|
||||||
assert (
|
assert (
|
||||||
normalize_text("Go to https://test.org/path",normalization_options=NormalizationOptions())
|
normalize_text(
|
||||||
|
"Visit http://site.com", normalization_options=NormalizationOptions()
|
||||||
|
)
|
||||||
|
== "Visit http site dot com"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
normalize_text(
|
||||||
|
"Go to https://test.org/path", normalization_options=NormalizationOptions()
|
||||||
|
)
|
||||||
== "Go to https test dot org slash path"
|
== "Go to https test dot org slash path"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_url_www():
|
def test_url_www():
|
||||||
"""Test www prefix handling"""
|
"""Test www prefix handling"""
|
||||||
assert normalize_text("Go to www.example.com",normalization_options=NormalizationOptions()) == "Go to www example dot com"
|
|
||||||
assert (
|
assert (
|
||||||
normalize_text("Visit www.test.org/docs",normalization_options=NormalizationOptions()) == "Visit www test dot org slash docs"
|
normalize_text(
|
||||||
|
"Go to www.example.com", normalization_options=NormalizationOptions()
|
||||||
|
)
|
||||||
|
== "Go to www example dot com"
|
||||||
)
|
)
|
||||||
assert (
|
assert (
|
||||||
normalize_text("Check www.site.com?q=test",normalization_options=NormalizationOptions())
|
normalize_text(
|
||||||
|
"Visit www.test.org/docs", normalization_options=NormalizationOptions()
|
||||||
|
)
|
||||||
|
== "Visit www test dot org slash docs"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
normalize_text(
|
||||||
|
"Check www.site.com?q=test", normalization_options=NormalizationOptions()
|
||||||
|
)
|
||||||
== "Check www site dot com question-mark q equals test"
|
== "Check www site dot com question-mark q equals test"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -33,15 +54,21 @@ def test_url_www():
|
||||||
def test_url_localhost():
|
def test_url_localhost():
|
||||||
"""Test localhost URL handling"""
|
"""Test localhost URL handling"""
|
||||||
assert (
|
assert (
|
||||||
normalize_text("Running on localhost:7860",normalization_options=NormalizationOptions())
|
normalize_text(
|
||||||
|
"Running on localhost:7860", normalization_options=NormalizationOptions()
|
||||||
|
)
|
||||||
== "Running on localhost colon 78 60"
|
== "Running on localhost colon 78 60"
|
||||||
)
|
)
|
||||||
assert (
|
assert (
|
||||||
normalize_text("Server at localhost:8080/api",normalization_options=NormalizationOptions())
|
normalize_text(
|
||||||
|
"Server at localhost:8080/api", normalization_options=NormalizationOptions()
|
||||||
|
)
|
||||||
== "Server at localhost colon 80 80 slash api"
|
== "Server at localhost colon 80 80 slash api"
|
||||||
)
|
)
|
||||||
assert (
|
assert (
|
||||||
normalize_text("Test localhost:3000/test?v=1",normalization_options=NormalizationOptions())
|
normalize_text(
|
||||||
|
"Test localhost:3000/test?v=1", normalization_options=NormalizationOptions()
|
||||||
|
)
|
||||||
== "Test localhost colon 3000 slash test question-mark v equals 1"
|
== "Test localhost colon 3000 slash test question-mark v equals 1"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -49,48 +76,104 @@ def test_url_localhost():
|
||||||
def test_url_ip_addresses():
|
def test_url_ip_addresses():
|
||||||
"""Test IP address URL handling"""
|
"""Test IP address URL handling"""
|
||||||
assert (
|
assert (
|
||||||
normalize_text("Access 0.0.0.0:9090/test",normalization_options=NormalizationOptions())
|
normalize_text(
|
||||||
|
"Access 0.0.0.0:9090/test", normalization_options=NormalizationOptions()
|
||||||
|
)
|
||||||
== "Access 0 dot 0 dot 0 dot 0 colon 90 90 slash test"
|
== "Access 0 dot 0 dot 0 dot 0 colon 90 90 slash test"
|
||||||
)
|
)
|
||||||
assert (
|
assert (
|
||||||
normalize_text("API at 192.168.1.1:8000",normalization_options=NormalizationOptions())
|
normalize_text(
|
||||||
|
"API at 192.168.1.1:8000", normalization_options=NormalizationOptions()
|
||||||
|
)
|
||||||
== "API at 192 dot 168 dot 1 dot 1 colon 8000"
|
== "API at 192 dot 168 dot 1 dot 1 colon 8000"
|
||||||
)
|
)
|
||||||
assert normalize_text("Server 127.0.0.1",normalization_options=NormalizationOptions()) == "Server 127 dot 0 dot 0 dot 1"
|
assert (
|
||||||
|
normalize_text("Server 127.0.0.1", normalization_options=NormalizationOptions())
|
||||||
|
== "Server 127 dot 0 dot 0 dot 1"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_url_raw_domains():
|
def test_url_raw_domains():
|
||||||
"""Test raw domain handling"""
|
"""Test raw domain handling"""
|
||||||
assert (
|
assert (
|
||||||
normalize_text("Visit google.com/search",normalization_options=NormalizationOptions()) == "Visit google dot com slash search"
|
normalize_text(
|
||||||
|
"Visit google.com/search", normalization_options=NormalizationOptions()
|
||||||
|
)
|
||||||
|
== "Visit google dot com slash search"
|
||||||
)
|
)
|
||||||
assert (
|
assert (
|
||||||
normalize_text("Go to example.com/path?q=test",normalization_options=NormalizationOptions())
|
normalize_text(
|
||||||
|
"Go to example.com/path?q=test",
|
||||||
|
normalization_options=NormalizationOptions(),
|
||||||
|
)
|
||||||
== "Go to example dot com slash path question-mark q equals test"
|
== "Go to example dot com slash path question-mark q equals test"
|
||||||
)
|
)
|
||||||
assert normalize_text("Check docs.test.com",normalization_options=NormalizationOptions()) == "Check docs dot test dot com"
|
assert (
|
||||||
|
normalize_text(
|
||||||
|
"Check docs.test.com", normalization_options=NormalizationOptions()
|
||||||
|
)
|
||||||
|
== "Check docs dot test dot com"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_url_email_addresses():
|
def test_url_email_addresses():
|
||||||
"""Test email address handling"""
|
"""Test email address handling"""
|
||||||
assert (
|
assert (
|
||||||
normalize_text("Email me at user@example.com",normalization_options=NormalizationOptions())
|
normalize_text(
|
||||||
|
"Email me at user@example.com", normalization_options=NormalizationOptions()
|
||||||
|
)
|
||||||
== "Email me at user at example dot com"
|
== "Email me at user at example dot com"
|
||||||
)
|
)
|
||||||
assert normalize_text("Contact admin@test.org",normalization_options=NormalizationOptions()) == "Contact admin at test dot org"
|
|
||||||
assert (
|
assert (
|
||||||
normalize_text("Send to test.user@site.com",normalization_options=NormalizationOptions())
|
normalize_text(
|
||||||
|
"Contact admin@test.org", normalization_options=NormalizationOptions()
|
||||||
|
)
|
||||||
|
== "Contact admin at test dot org"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
normalize_text(
|
||||||
|
"Send to test.user@site.com", normalization_options=NormalizationOptions()
|
||||||
|
)
|
||||||
== "Send to test dot user at site dot com"
|
== "Send to test dot user at site dot com"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_money():
|
def test_money():
|
||||||
"""Test that money text is normalized correctly"""
|
"""Test that money text is normalized correctly"""
|
||||||
assert normalize_text("He lost $5.3 thousand.",normalization_options=NormalizationOptions()) == "He lost five point three thousand dollars."
|
assert (
|
||||||
assert normalize_text("To put it weirdly -$6.9 million",normalization_options=NormalizationOptions()) == "To put it weirdly minus six point nine million dollars"
|
normalize_text(
|
||||||
assert normalize_text("It costs $50.3.",normalization_options=NormalizationOptions()) == "It costs fifty dollars and thirty cents."
|
"He lost $5.3 thousand.", normalization_options=NormalizationOptions()
|
||||||
|
)
|
||||||
|
== "He lost five point three thousand dollars."
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
normalize_text(
|
||||||
|
"To put it weirdly -$6.9 million",
|
||||||
|
normalization_options=NormalizationOptions(),
|
||||||
|
)
|
||||||
|
== "To put it weirdly minus six point nine million dollars"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
normalize_text("It costs $50.3.", normalization_options=NormalizationOptions())
|
||||||
|
== "It costs fifty dollars and thirty cents."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_non_url_text():
|
def test_non_url_text():
|
||||||
"""Test that non-URL text is unaffected"""
|
"""Test that non-URL text is unaffected"""
|
||||||
assert normalize_text("This is not.a.url text",normalization_options=NormalizationOptions()) == "This is not-a-url text"
|
assert (
|
||||||
assert normalize_text("Hello, how are you today?",normalization_options=NormalizationOptions()) == "Hello, how are you today?"
|
normalize_text(
|
||||||
assert normalize_text("It costs $50.",normalization_options=NormalizationOptions()) == "It costs fifty dollars."
|
"This is not.a.url text", normalization_options=NormalizationOptions()
|
||||||
|
)
|
||||||
|
== "This is not-a-url text"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
normalize_text(
|
||||||
|
"Hello, how are you today?", normalization_options=NormalizationOptions()
|
||||||
|
)
|
||||||
|
== "Hello, how are you today?"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
normalize_text("It costs $50.", normalization_options=NormalizationOptions())
|
||||||
|
== "It costs fifty dollars."
|
||||||
|
)
|
||||||
|
|
|
@ -4,20 +4,19 @@ import os
|
||||||
from typing import AsyncGenerator, Tuple
|
from typing import AsyncGenerator, Tuple
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
from api.src.services.streaming_audio_writer import StreamingAudioWriter
|
|
||||||
|
|
||||||
from api.src.inference.base import AudioChunk
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from api.src.core.config import settings
|
from api.src.core.config import settings
|
||||||
|
from api.src.inference.base import AudioChunk
|
||||||
from api.src.main import app
|
from api.src.main import app
|
||||||
from api.src.routers.openai_compatible import (
|
from api.src.routers.openai_compatible import (
|
||||||
get_tts_service,
|
get_tts_service,
|
||||||
load_openai_mappings,
|
load_openai_mappings,
|
||||||
stream_audio_chunks,
|
stream_audio_chunks,
|
||||||
)
|
)
|
||||||
|
from api.src.services.streaming_audio_writer import StreamingAudioWriter
|
||||||
from api.src.services.tts_service import TTSService
|
from api.src.services.tts_service import TTSService
|
||||||
from api.src.structures.schemas import OpenAISpeechRequest
|
from api.src.structures.schemas import OpenAISpeechRequest
|
||||||
|
|
||||||
|
@ -114,7 +113,6 @@ def test_retrieve_model(mock_openai_mappings):
|
||||||
assert error["detail"]["type"] == "invalid_request_error"
|
assert error["detail"]["type"] == "invalid_request_error"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_tts_service_initialization():
|
async def test_get_tts_service_initialization():
|
||||||
"""Test TTSService initialization"""
|
"""Test TTSService initialization"""
|
||||||
|
@ -147,7 +145,7 @@ async def test_stream_audio_chunks_client_disconnect():
|
||||||
|
|
||||||
async def mock_stream(*args, **kwargs):
|
async def mock_stream(*args, **kwargs):
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
yield AudioChunk(np.ndarray([],np.int16),output=b"chunk")
|
yield AudioChunk(np.ndarray([], np.int16), output=b"chunk")
|
||||||
|
|
||||||
mock_service.generate_audio_stream = mock_stream
|
mock_service.generate_audio_stream = mock_stream
|
||||||
mock_service.list_voices.return_value = ["test_voice"]
|
mock_service.list_voices.return_value = ["test_voice"]
|
||||||
|
@ -243,10 +241,10 @@ def mock_tts_service(mock_audio_bytes):
|
||||||
"""Mock TTS service for testing."""
|
"""Mock TTS service for testing."""
|
||||||
with patch("api.src.routers.openai_compatible.get_tts_service") as mock_get:
|
with patch("api.src.routers.openai_compatible.get_tts_service") as mock_get:
|
||||||
service = AsyncMock(spec=TTSService)
|
service = AsyncMock(spec=TTSService)
|
||||||
service.generate_audio.return_value = AudioChunk(np.zeros(1000,np.int16))
|
service.generate_audio.return_value = AudioChunk(np.zeros(1000, np.int16))
|
||||||
|
|
||||||
async def mock_stream(*args, **kwargs) -> AsyncGenerator[AudioChunk, None]:
|
async def mock_stream(*args, **kwargs) -> AsyncGenerator[AudioChunk, None]:
|
||||||
yield AudioChunk(np.ndarray([],np.int16),output=mock_audio_bytes)
|
yield AudioChunk(np.ndarray([], np.int16), output=mock_audio_bytes)
|
||||||
|
|
||||||
service.generate_audio_stream = mock_stream
|
service.generate_audio_stream = mock_stream
|
||||||
service.list_voices.return_value = ["test_voice", "voice1", "voice2"]
|
service.list_voices.return_value = ["test_voice", "voice1", "voice2"]
|
||||||
|
@ -263,8 +261,10 @@ def test_openai_speech_endpoint(
|
||||||
):
|
):
|
||||||
"""Test the OpenAI-compatible speech endpoint with basic MP3 generation"""
|
"""Test the OpenAI-compatible speech endpoint with basic MP3 generation"""
|
||||||
# Configure mocks
|
# Configure mocks
|
||||||
mock_tts_service.generate_audio.return_value = AudioChunk(np.zeros(1000,np.int16))
|
mock_tts_service.generate_audio.return_value = AudioChunk(np.zeros(1000, np.int16))
|
||||||
mock_convert.return_value = AudioChunk(np.zeros(1000,np.int16),output=mock_audio_bytes)
|
mock_convert.return_value = AudioChunk(
|
||||||
|
np.zeros(1000, np.int16), output=mock_audio_bytes
|
||||||
|
)
|
||||||
|
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/audio/speech",
|
"/v1/audio/speech",
|
||||||
|
|
|
@ -44,9 +44,12 @@ def test_get_sentence_info():
|
||||||
assert count == len(tokens)
|
assert count == len(tokens)
|
||||||
assert count > 0
|
assert count > 0
|
||||||
|
|
||||||
|
|
||||||
def test_get_sentence_info_phenomoes():
|
def test_get_sentence_info_phenomoes():
|
||||||
"""Test sentence splitting and info extraction."""
|
"""Test sentence splitting and info extraction."""
|
||||||
text = "This is sentence one. This is </|custom_phonemes_0|/> two! What about three?"
|
text = (
|
||||||
|
"This is sentence one. This is </|custom_phonemes_0|/> two! What about three?"
|
||||||
|
)
|
||||||
results = get_sentence_info(text, {"</|custom_phonemes_0|/>": r"sˈɛntᵊns"})
|
results = get_sentence_info(text, {"</|custom_phonemes_0|/>": r"sˈɛntᵊns"})
|
||||||
|
|
||||||
assert len(results) == 3
|
assert len(results) == 3
|
||||||
|
@ -58,13 +61,14 @@ def test_get_sentence_info_phenomoes():
|
||||||
assert count == len(tokens)
|
assert count == len(tokens)
|
||||||
assert count > 0
|
assert count > 0
|
||||||
|
|
||||||
|
|
||||||
def test_get_sentence_info_silence_tags():
|
def test_get_sentence_info_silence_tags():
|
||||||
"""Test sentence splitting and info extraction with silence tags."""
|
"""Test sentence splitting and info extraction with silence tags."""
|
||||||
text = "This is a test sentence, [silent](/1s/) with silence for one second."
|
text = "This is a test sentence, [silent 1s] with silence for one second."
|
||||||
results = get_sentence_info(text, {})
|
results = get_sentence_info(text, {})
|
||||||
|
|
||||||
assert len(results) == 3
|
assert len(results) == 3
|
||||||
assert results[1][0] == "[silent](/1s/)"
|
assert results[1][0] == "[silent 1s]"
|
||||||
for sentence, tokens, count in results:
|
for sentence, tokens, count in results:
|
||||||
assert isinstance(sentence, str)
|
assert isinstance(sentence, str)
|
||||||
assert isinstance(tokens, list)
|
assert isinstance(tokens, list)
|
||||||
|
|
|
@ -2,8 +2,8 @@ apiVersion: v2
|
||||||
name: kokoro-fastapi
|
name: kokoro-fastapi
|
||||||
description: A Helm chart for deploying the Kokoro FastAPI TTS service to Kubernetes
|
description: A Helm chart for deploying the Kokoro FastAPI TTS service to Kubernetes
|
||||||
type: application
|
type: application
|
||||||
version: 0.2.0
|
version: 0.3.0
|
||||||
appVersion: "0.2.0"
|
appVersion: "0.3.0"
|
||||||
|
|
||||||
keywords:
|
keywords:
|
||||||
- tts
|
- tts
|
||||||
|
|
|
@ -17,35 +17,44 @@ import base64
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import requests
|
import sys
|
||||||
import time
|
import time
|
||||||
import wave
|
import wave
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
def setup_args():
|
def setup_args():
|
||||||
"""Parse command line arguments"""
|
"""Parse command line arguments"""
|
||||||
parser = argparse.ArgumentParser(description="Test Kokoro TTS for race conditions")
|
parser = argparse.ArgumentParser(description="Test Kokoro TTS for race conditions")
|
||||||
parser.add_argument("--url", default="http://localhost:8880",
|
parser.add_argument(
|
||||||
help="Base URL of the Kokoro TTS service")
|
"--url",
|
||||||
parser.add_argument("--threads", type=int, default=8,
|
default="http://localhost:8880",
|
||||||
help="Number of concurrent threads to use")
|
help="Base URL of the Kokoro TTS service",
|
||||||
parser.add_argument("--iterations", type=int, default=5,
|
)
|
||||||
help="Number of iterations per thread")
|
parser.add_argument(
|
||||||
parser.add_argument("--voice", default="af_heart",
|
"--threads", type=int, default=8, help="Number of concurrent threads to use"
|
||||||
help="Voice to use for TTS")
|
)
|
||||||
parser.add_argument("--output-dir", default="./tts_test_output",
|
parser.add_argument(
|
||||||
help="Directory to save output files")
|
"--iterations", type=int, default=5, help="Number of iterations per thread"
|
||||||
parser.add_argument("--debug", action="store_true",
|
)
|
||||||
help="Enable debug logging")
|
parser.add_argument("--voice", default="af_heart", help="Voice to use for TTS")
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-dir",
|
||||||
|
default="./tts_test_output",
|
||||||
|
help="Directory to save output files",
|
||||||
|
)
|
||||||
|
parser.add_argument("--debug", action="store_true", help="Enable debug logging")
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def generate_test_sentence(thread_id, iteration):
|
def generate_test_sentence(thread_id, iteration):
|
||||||
"""Generate a simple test sentence with numbers to make mismatches easily identifiable"""
|
"""Generate a simple test sentence with numbers to make mismatches easily identifiable"""
|
||||||
return f"This is test sentence number {thread_id}-{iteration}. " \
|
return (
|
||||||
|
f"This is test sentence number {thread_id}-{iteration}. "
|
||||||
f"If you hear this sentence, you should hear the numbers {thread_id}-{iteration}."
|
f"If you hear this sentence, you should hear the numbers {thread_id}-{iteration}."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def log_message(message, debug=False, is_error=False):
|
def log_message(message, debug=False, is_error=False):
|
||||||
|
@ -73,7 +82,9 @@ def request_tts(url, test_id, text, voice, output_dir, debug=False):
|
||||||
f.write(text)
|
f.write(text)
|
||||||
log_message(f"Thread {test_id}: Successfully saved text file", debug)
|
log_message(f"Thread {test_id}: Successfully saved text file", debug)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log_message(f"Thread {test_id}: Error saving text file: {str(e)}", debug, is_error=True)
|
log_message(
|
||||||
|
f"Thread {test_id}: Error saving text file: {str(e)}", debug, is_error=True
|
||||||
|
)
|
||||||
|
|
||||||
# Make the TTS request
|
# Make the TTS request
|
||||||
try:
|
try:
|
||||||
|
@ -85,56 +96,102 @@ def request_tts(url, test_id, text, voice, output_dir, debug=False):
|
||||||
"model": "kokoro",
|
"model": "kokoro",
|
||||||
"input": text,
|
"input": text,
|
||||||
"voice": voice,
|
"voice": voice,
|
||||||
"response_format": "wav"
|
"response_format": "wav",
|
||||||
},
|
},
|
||||||
headers={"Accept": "audio/wav"},
|
headers={"Accept": "audio/wav"},
|
||||||
timeout=60 # Increase timeout to 60 seconds
|
timeout=60, # Increase timeout to 60 seconds
|
||||||
)
|
)
|
||||||
|
|
||||||
log_message(f"Thread {test_id}: Response status code: {response.status_code}", debug)
|
log_message(
|
||||||
log_message(f"Thread {test_id}: Response content type: {response.headers.get('Content-Type', 'None')}", debug)
|
f"Thread {test_id}: Response status code: {response.status_code}", debug
|
||||||
log_message(f"Thread {test_id}: Response content length: {len(response.content)} bytes", debug)
|
)
|
||||||
|
log_message(
|
||||||
|
f"Thread {test_id}: Response content type: {response.headers.get('Content-Type', 'None')}",
|
||||||
|
debug,
|
||||||
|
)
|
||||||
|
log_message(
|
||||||
|
f"Thread {test_id}: Response content length: {len(response.content)} bytes",
|
||||||
|
debug,
|
||||||
|
)
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
log_message(f"Thread {test_id}: API error: {response.status_code} - {response.text}", debug, is_error=True)
|
log_message(
|
||||||
|
f"Thread {test_id}: API error: {response.status_code} - {response.text}",
|
||||||
|
debug,
|
||||||
|
is_error=True,
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check if we got valid audio data
|
# Check if we got valid audio data
|
||||||
if len(response.content) < 100: # Sanity check - WAV files should be larger than this
|
if (
|
||||||
log_message(f"Thread {test_id}: Received suspiciously small audio data: {len(response.content)} bytes", debug, is_error=True)
|
len(response.content) < 100
|
||||||
log_message(f"Thread {test_id}: Content (base64): {base64.b64encode(response.content).decode('utf-8')}", debug, is_error=True)
|
): # Sanity check - WAV files should be larger than this
|
||||||
|
log_message(
|
||||||
|
f"Thread {test_id}: Received suspiciously small audio data: {len(response.content)} bytes",
|
||||||
|
debug,
|
||||||
|
is_error=True,
|
||||||
|
)
|
||||||
|
log_message(
|
||||||
|
f"Thread {test_id}: Content (base64): {base64.b64encode(response.content).decode('utf-8')}",
|
||||||
|
debug,
|
||||||
|
is_error=True,
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Save the audio output with explicit error handling
|
# Save the audio output with explicit error handling
|
||||||
try:
|
try:
|
||||||
with open(output_file, "wb") as f:
|
with open(output_file, "wb") as f:
|
||||||
bytes_written = f.write(response.content)
|
bytes_written = f.write(response.content)
|
||||||
log_message(f"Thread {test_id}: Wrote {bytes_written} bytes to {output_file}", debug)
|
log_message(
|
||||||
|
f"Thread {test_id}: Wrote {bytes_written} bytes to {output_file}",
|
||||||
|
debug,
|
||||||
|
)
|
||||||
|
|
||||||
# Verify the WAV file exists and has content
|
# Verify the WAV file exists and has content
|
||||||
if os.path.exists(output_file):
|
if os.path.exists(output_file):
|
||||||
file_size = os.path.getsize(output_file)
|
file_size = os.path.getsize(output_file)
|
||||||
log_message(f"Thread {test_id}: Verified file exists with size: {file_size} bytes", debug)
|
log_message(
|
||||||
|
f"Thread {test_id}: Verified file exists with size: {file_size} bytes",
|
||||||
|
debug,
|
||||||
|
)
|
||||||
|
|
||||||
# Validate WAV file by reading its headers
|
# Validate WAV file by reading its headers
|
||||||
try:
|
try:
|
||||||
with wave.open(output_file, 'rb') as wav_file:
|
with wave.open(output_file, "rb") as wav_file:
|
||||||
channels = wav_file.getnchannels()
|
channels = wav_file.getnchannels()
|
||||||
sample_width = wav_file.getsampwidth()
|
sample_width = wav_file.getsampwidth()
|
||||||
framerate = wav_file.getframerate()
|
framerate = wav_file.getframerate()
|
||||||
frames = wav_file.getnframes()
|
frames = wav_file.getnframes()
|
||||||
log_message(f"Thread {test_id}: Valid WAV file - channels: {channels}, "
|
log_message(
|
||||||
f"sample width: {sample_width}, framerate: {framerate}, frames: {frames}", debug)
|
f"Thread {test_id}: Valid WAV file - channels: {channels}, "
|
||||||
|
f"sample width: {sample_width}, framerate: {framerate}, frames: {frames}",
|
||||||
|
debug,
|
||||||
|
)
|
||||||
except Exception as wav_error:
|
except Exception as wav_error:
|
||||||
log_message(f"Thread {test_id}: Invalid WAV file: {str(wav_error)}", debug, is_error=True)
|
log_message(
|
||||||
|
f"Thread {test_id}: Invalid WAV file: {str(wav_error)}",
|
||||||
|
debug,
|
||||||
|
is_error=True,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
log_message(f"Thread {test_id}: File was not created: {output_file}", debug, is_error=True)
|
log_message(
|
||||||
|
f"Thread {test_id}: File was not created: {output_file}",
|
||||||
|
debug,
|
||||||
|
is_error=True,
|
||||||
|
)
|
||||||
except Exception as save_error:
|
except Exception as save_error:
|
||||||
log_message(f"Thread {test_id}: Error saving audio file: {str(save_error)}", debug, is_error=True)
|
log_message(
|
||||||
|
f"Thread {test_id}: Error saving audio file: {str(save_error)}",
|
||||||
|
debug,
|
||||||
|
is_error=True,
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
log_message(f"Thread {test_id}: Saved output to {output_file} (time: {end_time - start_time:.2f}s)", debug)
|
log_message(
|
||||||
|
f"Thread {test_id}: Saved output to {output_file} (time: {end_time - start_time:.2f}s)",
|
||||||
|
debug,
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except requests.exceptions.Timeout:
|
except requests.exceptions.Timeout:
|
||||||
|
@ -151,10 +208,16 @@ def worker_task(thread_id, args):
|
||||||
iteration = i + 1
|
iteration = i + 1
|
||||||
test_id = f"{thread_id:02d}_{iteration:02d}"
|
test_id = f"{thread_id:02d}_{iteration:02d}"
|
||||||
text = generate_test_sentence(thread_id, iteration)
|
text = generate_test_sentence(thread_id, iteration)
|
||||||
success = request_tts(args.url, test_id, text, args.voice, args.output_dir, args.debug)
|
success = request_tts(
|
||||||
|
args.url, test_id, text, args.voice, args.output_dir, args.debug
|
||||||
|
)
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
log_message(f"Thread {thread_id}: Iteration {iteration} failed", args.debug, is_error=True)
|
log_message(
|
||||||
|
f"Thread {thread_id}: Iteration {iteration} failed",
|
||||||
|
args.debug,
|
||||||
|
is_error=True,
|
||||||
|
)
|
||||||
|
|
||||||
# Small delay between iterations to avoid overwhelming the API
|
# Small delay between iterations to avoid overwhelming the API
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
|
@ -171,9 +234,14 @@ def run_test(args):
|
||||||
with open(test_file, "w") as f:
|
with open(test_file, "w") as f:
|
||||||
f.write("Testing write access\n")
|
f.write("Testing write access\n")
|
||||||
os.remove(test_file)
|
os.remove(test_file)
|
||||||
log_message(f"Successfully verified write access to output directory: {args.output_dir}")
|
log_message(
|
||||||
|
f"Successfully verified write access to output directory: {args.output_dir}"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log_message(f"Warning: Cannot write to output directory {args.output_dir}: {str(e)}", is_error=True)
|
log_message(
|
||||||
|
f"Warning: Cannot write to output directory {args.output_dir}: {str(e)}",
|
||||||
|
is_error=True,
|
||||||
|
)
|
||||||
log_message(f"Current directory: {os.getcwd()}", is_error=True)
|
log_message(f"Current directory: {os.getcwd()}", is_error=True)
|
||||||
log_message(f"Directory contents: {os.listdir('.')}", is_error=True)
|
log_message(f"Directory contents: {os.listdir('.')}", is_error=True)
|
||||||
|
|
||||||
|
@ -183,13 +251,21 @@ def run_test(args):
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
log_message(f"Successfully connected to Kokoro TTS service at {args.url}")
|
log_message(f"Successfully connected to Kokoro TTS service at {args.url}")
|
||||||
else:
|
else:
|
||||||
log_message(f"Warning: Kokoro TTS service health check returned status {response.status_code}", is_error=True)
|
log_message(
|
||||||
|
f"Warning: Kokoro TTS service health check returned status {response.status_code}",
|
||||||
|
is_error=True,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log_message(f"Warning: Cannot connect to Kokoro TTS service at {args.url}: {str(e)}", is_error=True)
|
log_message(
|
||||||
|
f"Warning: Cannot connect to Kokoro TTS service at {args.url}: {str(e)}",
|
||||||
|
is_error=True,
|
||||||
|
)
|
||||||
|
|
||||||
# Record start time
|
# Record start time
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
log_message(f"Starting test with {args.threads} threads, {args.iterations} iterations per thread")
|
log_message(
|
||||||
|
f"Starting test with {args.threads} threads, {args.iterations} iterations per thread"
|
||||||
|
)
|
||||||
|
|
||||||
# Create and start worker threads
|
# Create and start worker threads
|
||||||
with concurrent.futures.ThreadPoolExecutor(max_workers=args.threads) as executor:
|
with concurrent.futures.ThreadPoolExecutor(max_workers=args.threads) as executor:
|
||||||
|
@ -202,7 +278,9 @@ def run_test(args):
|
||||||
try:
|
try:
|
||||||
future.result()
|
future.result()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log_message(f"Thread execution failed: {str(e)}", args.debug, is_error=True)
|
log_message(
|
||||||
|
f"Thread execution failed: {str(e)}", args.debug, is_error=True
|
||||||
|
)
|
||||||
|
|
||||||
# Record end time and print summary
|
# Record end time and print summary
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
@ -213,8 +291,12 @@ def run_test(args):
|
||||||
log_message(f"Average time per request: {total_time / total_requests:.2f} seconds")
|
log_message(f"Average time per request: {total_time / total_requests:.2f} seconds")
|
||||||
log_message(f"Requests per second: {total_requests / total_time:.2f}")
|
log_message(f"Requests per second: {total_requests / total_time:.2f}")
|
||||||
log_message(f"Output files saved to: {os.path.abspath(args.output_dir)}")
|
log_message(f"Output files saved to: {os.path.abspath(args.output_dir)}")
|
||||||
log_message("To verify, listen to the audio files and check if they match the text files")
|
log_message(
|
||||||
log_message("If you hear audio describing a different test number than the filename, you've found a race condition")
|
"To verify, listen to the audio files and check if they match the text files"
|
||||||
|
)
|
||||||
|
log_message(
|
||||||
|
"If you hear audio describing a different test number than the filename, you've found a race condition"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def analyze_audio_files(output_dir):
|
def analyze_audio_files(output_dir):
|
||||||
|
@ -226,30 +308,34 @@ def analyze_audio_files(output_dir):
|
||||||
log_message(f"Found {len(wav_files)} WAV files and {len(txt_files)} TXT files")
|
log_message(f"Found {len(wav_files)} WAV files and {len(txt_files)} TXT files")
|
||||||
|
|
||||||
if len(wav_files) == 0:
|
if len(wav_files) == 0:
|
||||||
log_message("No WAV files found! This indicates the TTS service requests may be failing.", is_error=True)
|
log_message(
|
||||||
log_message("Check the connection to the TTS service and the response status codes above.", is_error=True)
|
"No WAV files found! This indicates the TTS service requests may be failing.",
|
||||||
|
is_error=True,
|
||||||
|
)
|
||||||
|
log_message(
|
||||||
|
"Check the connection to the TTS service and the response status codes above.",
|
||||||
|
is_error=True,
|
||||||
|
)
|
||||||
|
|
||||||
file_stats = []
|
file_stats = []
|
||||||
for wav_path in wav_files:
|
for wav_path in wav_files:
|
||||||
try:
|
try:
|
||||||
with wave.open(str(wav_path), 'rb') as wav_file:
|
with wave.open(str(wav_path), "rb") as wav_file:
|
||||||
frames = wav_file.getnframes()
|
frames = wav_file.getnframes()
|
||||||
rate = wav_file.getframerate()
|
rate = wav_file.getframerate()
|
||||||
duration = frames / rate
|
duration = frames / rate
|
||||||
|
|
||||||
# Get corresponding text
|
# Get corresponding text
|
||||||
text_path = wav_path.with_suffix('.txt')
|
text_path = wav_path.with_suffix(".txt")
|
||||||
if text_path.exists():
|
if text_path.exists():
|
||||||
with open(text_path, 'r') as text_file:
|
with open(text_path, "r") as text_file:
|
||||||
text = text_file.read().strip()
|
text = text_file.read().strip()
|
||||||
else:
|
else:
|
||||||
text = "N/A"
|
text = "N/A"
|
||||||
|
|
||||||
file_stats.append({
|
file_stats.append(
|
||||||
'filename': wav_path.name,
|
{"filename": wav_path.name, "duration": duration, "text": text}
|
||||||
'duration': duration,
|
)
|
||||||
'text': text
|
|
||||||
})
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log_message(f"Error analyzing {wav_path}: {str(e)}", False, is_error=True)
|
log_message(f"Error analyzing {wav_path}: {str(e)}", False, is_error=True)
|
||||||
|
|
||||||
|
@ -259,12 +345,17 @@ def analyze_audio_files(output_dir):
|
||||||
log_message(f"{'Filename':<20}{'Duration':<12}{'Text':<60}")
|
log_message(f"{'Filename':<20}{'Duration':<12}{'Text':<60}")
|
||||||
log_message("-" * 92)
|
log_message("-" * 92)
|
||||||
for stat in file_stats:
|
for stat in file_stats:
|
||||||
log_message(f"{stat['filename']:<20}{stat['duration']:<12.2f}{stat['text'][:57]+'...' if len(stat['text']) > 60 else stat['text']:<60}")
|
log_message(
|
||||||
|
f"{stat['filename']:<20}{stat['duration']:<12.2f}{stat['text'][:57] + '...' if len(stat['text']) > 60 else stat['text']:<60}"
|
||||||
|
)
|
||||||
|
|
||||||
# List missing WAV files where text files exist
|
# List missing WAV files where text files exist
|
||||||
missing_wavs = set(p.stem for p in txt_files) - set(p.stem for p in wav_files)
|
missing_wavs = set(p.stem for p in txt_files) - set(p.stem for p in wav_files)
|
||||||
if missing_wavs:
|
if missing_wavs:
|
||||||
log_message(f"\nFound {len(missing_wavs)} text files without corresponding WAV files:", is_error=True)
|
log_message(
|
||||||
|
f"\nFound {len(missing_wavs)} text files without corresponding WAV files:",
|
||||||
|
is_error=True,
|
||||||
|
)
|
||||||
for stem in sorted(list(missing_wavs))[:10]: # Limit to 10 for readability
|
for stem in sorted(list(missing_wavs))[:10]: # Limit to 10 for readability
|
||||||
log_message(f" - {stem}.txt (no WAV file)", is_error=True)
|
log_message(f" - {stem}.txt (no WAV file)", is_error=True)
|
||||||
if len(missing_wavs) > 10:
|
if len(missing_wavs) > 10:
|
||||||
|
@ -279,5 +370,9 @@ if __name__ == "__main__":
|
||||||
log_message("\nNext Steps:")
|
log_message("\nNext Steps:")
|
||||||
log_message("1. Listen to the generated audio files")
|
log_message("1. Listen to the generated audio files")
|
||||||
log_message("2. Verify if each audio correctly says its ID number")
|
log_message("2. Verify if each audio correctly says its ID number")
|
||||||
log_message("3. Check for any mismatches between the audio content and the text files")
|
log_message(
|
||||||
log_message("4. If mismatches are found, you've successfully reproduced the race condition")
|
"3. Check for any mismatches between the audio content and the text files"
|
||||||
|
)
|
||||||
|
log_message(
|
||||||
|
"4. If mismatches are found, you've successfully reproduced the race condition"
|
||||||
|
)
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
import requests
|
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
|
|
||||||
import pydub
|
import pydub
|
||||||
text="""Delving into the Abyss: A Deeper Exploration of Meaning in 5 Seconds of Summer's "Jet Black Heart"
|
import requests
|
||||||
|
|
||||||
|
text = """Delving into the Abyss: A Deeper Exploration of Meaning in 5 Seconds of Summer's "Jet Black Heart"
|
||||||
|
|
||||||
5 Seconds of Summer, initially perceived as purveyors of upbeat, radio-friendly pop-punk, embarked on a significant artistic evolution with their album Sounds Good Feels Good. Among its tracks, "Jet Black Heart" stands out as a powerful testament to this shift, moving beyond catchy melodies and embracing a darker, more emotionally complex sound. Released in 2015, the song transcends the typical themes of youthful exuberance and romantic angst, instead plunging into the depths of personal turmoil and the corrosive effects of inner darkness on interpersonal relationships. "Jet Black Heart" is not merely a song about heartbreak; it is a raw and vulnerable exploration of internal struggle, self-destructive patterns, and the precarious flicker of hope that persists even in the face of profound emotional chaos. Through potent metaphors, starkly honest lyrics, and a sonic landscape that mirrors its thematic weight, the song offers a profound meditation on the human condition, grappling with the shadows that reside within us all and their far-reaching consequences.
|
5 Seconds of Summer, initially perceived as purveyors of upbeat, radio-friendly pop-punk, embarked on a significant artistic evolution with their album Sounds Good Feels Good. Among its tracks, "Jet Black Heart" stands out as a powerful testament to this shift, moving beyond catchy melodies and embracing a darker, more emotionally complex sound. Released in 2015, the song transcends the typical themes of youthful exuberance and romantic angst, instead plunging into the depths of personal turmoil and the corrosive effects of inner darkness on interpersonal relationships. "Jet Black Heart" is not merely a song about heartbreak; it is a raw and vulnerable exploration of internal struggle, self-destructive patterns, and the precarious flicker of hope that persists even in the face of profound emotional chaos. Through potent metaphors, starkly honest lyrics, and a sonic landscape that mirrors its thematic weight, the song offers a profound meditation on the human condition, grappling with the shadows that reside within us all and their far-reaching consequences.
|
||||||
|
|
||||||
|
@ -23,7 +25,7 @@ In conclusion, "Jet Black Heart" by 5 Seconds of Summer is far more than a typic
|
||||||
5 Seconds of Summer, initially perceived as purveyors of upbeat, radio-friendly pop-punk, embarked on a significant artistic evolution with their album Sounds Good Feels Good. Among its tracks, "Jet Black Heart" stands out as a powerful testament to this shift, moving beyond catchy melodies and embracing a darker, more emotionally complex sound. Released in 2015, the song transcends the typical themes of youthful exuberance and romantic angst, instead plunging into the depths of personal turmoil and the corrosive effects of inner darkness on interpersonal relationships. "Jet Black Heart" is not merely a song about heartbreak; it is a raw and vulnerable exploration of internal struggle, self-destructive patterns, and the precarious flicker of hope that persists even in the face of profound emotional chaos."""
|
5 Seconds of Summer, initially perceived as purveyors of upbeat, radio-friendly pop-punk, embarked on a significant artistic evolution with their album Sounds Good Feels Good. Among its tracks, "Jet Black Heart" stands out as a powerful testament to this shift, moving beyond catchy melodies and embracing a darker, more emotionally complex sound. Released in 2015, the song transcends the typical themes of youthful exuberance and romantic angst, instead plunging into the depths of personal turmoil and the corrosive effects of inner darkness on interpersonal relationships. "Jet Black Heart" is not merely a song about heartbreak; it is a raw and vulnerable exploration of internal struggle, self-destructive patterns, and the precarious flicker of hope that persists even in the face of profound emotional chaos."""
|
||||||
|
|
||||||
|
|
||||||
Type="wav"
|
Type = "wav"
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
"http://localhost:8880/dev/captioned_speech",
|
"http://localhost:8880/dev/captioned_speech",
|
||||||
json={
|
json={
|
||||||
|
@ -34,30 +36,34 @@ response = requests.post(
|
||||||
"response_format": Type,
|
"response_format": Type,
|
||||||
"stream": True,
|
"stream": True,
|
||||||
},
|
},
|
||||||
stream=True
|
stream=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
f=open(f"outputstream.{Type}","wb")
|
f = open(f"outputstream.{Type}", "wb")
|
||||||
for chunk in response.iter_lines(decode_unicode=True):
|
for chunk in response.iter_lines(decode_unicode=True):
|
||||||
if chunk:
|
if chunk:
|
||||||
temp_json=json.loads(chunk)
|
temp_json = json.loads(chunk)
|
||||||
if temp_json["timestamps"] != []:
|
if temp_json["timestamps"] != []:
|
||||||
chunk_json=temp_json
|
chunk_json = temp_json
|
||||||
|
|
||||||
# Decode base 64 stream to bytes
|
# Decode base 64 stream to bytes
|
||||||
chunk_audio=base64.b64decode(temp_json["audio"].encode("utf-8"))
|
chunk_audio = base64.b64decode(temp_json["audio"].encode("utf-8"))
|
||||||
|
|
||||||
# Process streaming chunks
|
# Process streaming chunks
|
||||||
f.write(chunk_audio)
|
f.write(chunk_audio)
|
||||||
|
|
||||||
# Print word level timestamps
|
# Print word level timestamps
|
||||||
last_chunks={"start_time":chunk_json["timestamps"][-10]["start_time"],"end_time":chunk_json["timestamps"][-3]["end_time"],"word":" ".join([X["word"] for X in chunk_json["timestamps"][-10:-3]])}
|
last_chunks = {
|
||||||
|
"start_time": chunk_json["timestamps"][-10]["start_time"],
|
||||||
|
"end_time": chunk_json["timestamps"][-3]["end_time"],
|
||||||
|
"word": " ".join([X["word"] for X in chunk_json["timestamps"][-10:-3]]),
|
||||||
|
}
|
||||||
|
|
||||||
print(f"CUTTING TO {last_chunks['word']}")
|
print(f"CUTTING TO {last_chunks['word']}")
|
||||||
|
|
||||||
audioseg=pydub.AudioSegment.from_file(f"outputstream.{Type}",format=Type)
|
audioseg = pydub.AudioSegment.from_file(f"outputstream.{Type}", format=Type)
|
||||||
audioseg=audioseg[last_chunks["start_time"]*1000:last_chunks["end_time"] * 1000]
|
audioseg = audioseg[last_chunks["start_time"] * 1000 : last_chunks["end_time"] * 1000]
|
||||||
audioseg.export(f"outputstreamcut.{Type}",format=Type)
|
audioseg.export(f"outputstreamcut.{Type}", format=Type)
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,13 +1,14 @@
|
||||||
import requests
|
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
|
|
||||||
text="""the administration has offered up a platter of repression for more than a year and is still slated to lose $400 million.
|
import requests
|
||||||
|
|
||||||
|
text = """the administration has offered up a platter of repression for more than a year and is still slated to lose $400 million.
|
||||||
|
|
||||||
Columbia is the largest private landowner in New York City and boasts an endowment of $14.8 billion;"""
|
Columbia is the largest private landowner in New York City and boasts an endowment of $14.8 billion;"""
|
||||||
|
|
||||||
|
|
||||||
Type="wav"
|
Type = "wav"
|
||||||
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
"http://localhost:8880/v1/audio/speech",
|
"http://localhost:8880/v1/audio/speech",
|
||||||
|
@ -19,7 +20,7 @@ response = requests.post(
|
||||||
"response_format": Type,
|
"response_format": Type,
|
||||||
"stream": False,
|
"stream": False,
|
||||||
},
|
},
|
||||||
stream=True
|
stream=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
with open(f"outputnostreammoney.{Type}", "wb") as f:
|
with open(f"outputnostreammoney.{Type}", "wb") as f:
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from text_to_num import text2num
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import inflect
|
import inflect
|
||||||
|
from text_to_num import text2num
|
||||||
from torch import mul
|
from torch import mul
|
||||||
|
|
||||||
INFLECT_ENGINE = inflect.engine()
|
INFLECT_ENGINE = inflect.engine()
|
||||||
|
@ -11,6 +12,7 @@ def conditional_int(number: float, threshold: float = 0.00001):
|
||||||
return int(round(number))
|
return int(round(number))
|
||||||
return number
|
return number
|
||||||
|
|
||||||
|
|
||||||
def handle_money(m: re.Match[str]) -> str:
|
def handle_money(m: re.Match[str]) -> str:
|
||||||
"""Convert money expressions to spoken form"""
|
"""Convert money expressions to spoken form"""
|
||||||
|
|
||||||
|
|
15
dev/Test.py
15
dev/Test.py
|
@ -1,8 +1,9 @@
|
||||||
import requests
|
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
|
|
||||||
text="""Delving into the Abyss: A Deeper Exploration of Meaning in 5 Seconds of Summer's "Jet Black Heart"
|
import requests
|
||||||
|
|
||||||
|
text = """Delving into the Abyss: A Deeper Exploration of Meaning in 5 Seconds of Summer's "Jet Black Heart"
|
||||||
|
|
||||||
5 Seconds of Summer, initially perceived as purveyors of upbeat, radio-friendly pop-punk, embarked on a significant artistic evolution with their album Sounds Good Feels Good. Among its tracks, "Jet Black Heart" stands out as a powerful testament to this shift, moving beyond catchy melodies and embracing a darker, more emotionally complex sound. Released in 2015, the song transcends the typical themes of youthful exuberance and romantic angst, instead plunging into the depths of personal turmoil and the corrosive effects of inner darkness on interpersonal relationships. "Jet Black Heart" is not merely a song about heartbreak; it is a raw and vulnerable exploration of internal struggle, self-destructive patterns, and the precarious flicker of hope that persists even in the face of profound emotional chaos. Through potent metaphors, starkly honest lyrics, and a sonic landscape that mirrors its thematic weight, the song offers a profound meditation on the human condition, grappling with the shadows that reside within us all and their far-reaching consequences.
|
5 Seconds of Summer, initially perceived as purveyors of upbeat, radio-friendly pop-punk, embarked on a significant artistic evolution with their album Sounds Good Feels Good. Among its tracks, "Jet Black Heart" stands out as a powerful testament to this shift, moving beyond catchy melodies and embracing a darker, more emotionally complex sound. Released in 2015, the song transcends the typical themes of youthful exuberance and romantic angst, instead plunging into the depths of personal turmoil and the corrosive effects of inner darkness on interpersonal relationships. "Jet Black Heart" is not merely a song about heartbreak; it is a raw and vulnerable exploration of internal struggle, self-destructive patterns, and the precarious flicker of hope that persists even in the face of profound emotional chaos. Through potent metaphors, starkly honest lyrics, and a sonic landscape that mirrors its thematic weight, the song offers a profound meditation on the human condition, grappling with the shadows that reside within us all and their far-reaching consequences.
|
||||||
|
|
||||||
|
@ -18,12 +19,12 @@ Beyond the lyrical content, the musical elements of "Jet Black Heart" contribute
|
||||||
|
|
||||||
In conclusion, "Jet Black Heart" by 5 Seconds of Summer is far more than a typical pop song; it is a poignant and deeply resonant exploration of inner darkness, self-destructive tendencies, and the fragile yet persistent hope for human connection and redemption. Through its powerful central metaphor of the "jet black heart," its unflinching portrayal of internal turmoil, and its subtle yet potent message of vulnerability and potential transformation, the song resonates with anyone who has grappled with their own inner demons and the complexities of human relationships. It is a reminder that even in the deepest darkness, a flicker of hope can endure, and that true healing and connection often emerge from the courageous act of confronting and sharing our most vulnerable selves. "Jet Black Heart" stands as a testament to 5 Seconds of Summer's artistic growth, showcasing their capacity to delve into profound emotional territories and create music that is not only catchy and engaging but also deeply meaningful and emotionally resonant, solidifying their position as a band capable of capturing the complexities of the human experience."""
|
In conclusion, "Jet Black Heart" by 5 Seconds of Summer is far more than a typical pop song; it is a poignant and deeply resonant exploration of inner darkness, self-destructive tendencies, and the fragile yet persistent hope for human connection and redemption. Through its powerful central metaphor of the "jet black heart," its unflinching portrayal of internal turmoil, and its subtle yet potent message of vulnerability and potential transformation, the song resonates with anyone who has grappled with their own inner demons and the complexities of human relationships. It is a reminder that even in the deepest darkness, a flicker of hope can endure, and that true healing and connection often emerge from the courageous act of confronting and sharing our most vulnerable selves. "Jet Black Heart" stands as a testament to 5 Seconds of Summer's artistic growth, showcasing their capacity to delve into profound emotional territories and create music that is not only catchy and engaging but also deeply meaningful and emotionally resonant, solidifying their position as a band capable of capturing the complexities of the human experience."""
|
||||||
|
|
||||||
text="""Delving into the Abyss: A Deeper Exploration of Meaning in 5 Seconds of Summer's "Jet Black Heart"
|
text = """Delving into the Abyss: A Deeper Exploration of Meaning in 5 Seconds of Summer's "Jet Black Heart"
|
||||||
|
|
||||||
5 Seconds of Summer, initially perceived as purveyors of upbeat, radio-friendly pop-punk, embarked on a significant artistic evolution with their album Sounds Good Feels Good. Among its tracks, "Jet Black Heart" stands out as a powerful testament to this shift, moving beyond catchy melodies and embracing a darker, more emotionally complex sound. Released in 2015, the song transcends the typical themes of youthful exuberance and romantic angst, instead plunging into the depths of personal turmoil and the corrosive effects of inner darkness on interpersonal relationships. "Jet Black Heart" is not merely a song about heartbreak; it is a raw and vulnerable exploration of internal struggle, self-destructive patterns, and the precarious flicker of hope that persists even in the face of profound emotional chaos."""
|
5 Seconds of Summer, initially perceived as purveyors of upbeat, radio-friendly pop-punk, embarked on a significant artistic evolution with their album Sounds Good Feels Good. Among its tracks, "Jet Black Heart" stands out as a powerful testament to this shift, moving beyond catchy melodies and embracing a darker, more emotionally complex sound. Released in 2015, the song transcends the typical themes of youthful exuberance and romantic angst, instead plunging into the depths of personal turmoil and the corrosive effects of inner darkness on interpersonal relationships. "Jet Black Heart" is not merely a song about heartbreak; it is a raw and vulnerable exploration of internal struggle, self-destructive patterns, and the precarious flicker of hope that persists even in the face of profound emotional chaos."""
|
||||||
|
|
||||||
|
|
||||||
Type="wav"
|
Type = "wav"
|
||||||
|
|
||||||
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
|
@ -36,11 +37,11 @@ response = requests.post(
|
||||||
"response_format": Type,
|
"response_format": Type,
|
||||||
"stream": True,
|
"stream": True,
|
||||||
},
|
},
|
||||||
stream=True
|
stream=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
f=open(f"outputstream.{Type}","wb")
|
f = open(f"outputstream.{Type}", "wb")
|
||||||
for chunk in response.iter_content():
|
for chunk in response.iter_content():
|
||||||
if chunk:
|
if chunk:
|
||||||
# Process streaming chunks
|
# Process streaming chunks
|
||||||
|
@ -56,7 +57,7 @@ response = requests.post(
|
||||||
"response_format": Type,
|
"response_format": Type,
|
||||||
"stream": False,
|
"stream": False,
|
||||||
},
|
},
|
||||||
stream=True
|
stream=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
with open(f"outputnostream.{Type}", "wb") as f:
|
with open(f"outputnostream.{Type}", "wb") as f:
|
||||||
|
|
|
@ -20,7 +20,7 @@ services:
|
||||||
|
|
||||||
# # Gradio UI service [Comment out everything below if you don't need it]
|
# # Gradio UI service [Comment out everything below if you don't need it]
|
||||||
# gradio-ui:
|
# gradio-ui:
|
||||||
# image: ghcr.io/remsky/kokoro-fastapi-ui:v0.2.0
|
# image: ghcr.io/remsky/kokoro-fastapi-ui:v${VERSION}
|
||||||
# # Uncomment below (and comment out above) to build from source instead of using the released image
|
# # Uncomment below (and comment out above) to build from source instead of using the released image
|
||||||
# build:
|
# build:
|
||||||
# context: ../../ui
|
# context: ../../ui
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
name: kokoro-tts-gpu
|
name: kokoro-tts-gpu
|
||||||
services:
|
services:
|
||||||
kokoro-tts:
|
kokoro-tts:
|
||||||
# image: ghcr.io/remsky/kokoro-fastapi-gpu:v0.2.0
|
# image: ghcr.io/remsky/kokoro-fastapi-gpu:v${VERSION}
|
||||||
build:
|
build:
|
||||||
context: ../..
|
context: ../..
|
||||||
dockerfile: docker/gpu/Dockerfile
|
dockerfile: docker/gpu/Dockerfile
|
||||||
|
@ -24,7 +24,7 @@ services:
|
||||||
|
|
||||||
# # Gradio UI service
|
# # Gradio UI service
|
||||||
# gradio-ui:
|
# gradio-ui:
|
||||||
# image: ghcr.io/remsky/kokoro-fastapi-ui:v0.2.0
|
# image: ghcr.io/remsky/kokoro-fastapi-ui:v${VERSION}
|
||||||
# # Uncomment below to build from source instead of using the released image
|
# # Uncomment below to build from source instead of using the released image
|
||||||
# # build:
|
# # build:
|
||||||
# # context: ../../ui
|
# # context: ../../ui
|
||||||
|
|
|
@ -91,9 +91,7 @@ def main():
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Download Kokoro v1.0 model")
|
parser = argparse.ArgumentParser(description="Download Kokoro v1.0 model")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--output",
|
"--output", required=True, help="Output directory for model files"
|
||||||
required=True,
|
|
||||||
help="Output directory for model files"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
|
@ -123,7 +123,7 @@ def main():
|
||||||
with open(wells_path, "r", encoding="utf-8") as f:
|
with open(wells_path, "r", encoding="utf-8") as f:
|
||||||
full_text = f.read()
|
full_text = f.read()
|
||||||
# Take first few paragraphs
|
# Take first few paragraphs
|
||||||
text = " ".join(full_text.split("\n\n")[:2])
|
text = " ".join(full_text.split("\n\n")[1:3])
|
||||||
|
|
||||||
print("\nStarting TTS stream playback...")
|
print("\nStarting TTS stream playback...")
|
||||||
print(f"Text length: {len(text)} characters")
|
print(f"Text length: {len(text)} characters")
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
[project]
|
[project]
|
||||||
name = "kokoro-fastapi"
|
name = "kokoro-fastapi"
|
||||||
version = "0.1.4"
|
version = "0.3.0"
|
||||||
description = "FastAPI TTS Service"
|
description = "FastAPI TTS Service"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
|
@ -31,10 +31,11 @@ dependencies = [
|
||||||
"matplotlib>=3.10.0",
|
"matplotlib>=3.10.0",
|
||||||
"mutagen>=1.47.0",
|
"mutagen>=1.47.0",
|
||||||
"psutil>=6.1.1",
|
"psutil>=6.1.1",
|
||||||
"kokoro @ git+https://github.com/hexgrad/kokoro.git@31a2b6337b8c1b1418ef68c48142328f640da938",
|
"espeakng-loader==0.2.4",
|
||||||
'misaki[en,ja,ko,zh] @ git+https://github.com/hexgrad/misaki.git@ebc76c21b66c5fc4866ed0ec234047177b396170',
|
"kokoro==0.9.2",
|
||||||
"spacy==3.7.2",
|
"misaki[en,ja,ko,zh]==0.9.3",
|
||||||
"en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl",
|
"spacy==3.8.5",
|
||||||
|
"en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl",
|
||||||
"inflect>=7.5.0",
|
"inflect>=7.5.0",
|
||||||
"phonemizer-fork>=3.3.2",
|
"phonemizer-fork>=3.3.2",
|
||||||
"av>=14.2.0",
|
"av>=14.2.0",
|
||||||
|
@ -53,8 +54,8 @@ test = [
|
||||||
"pytest-cov==6.0.0",
|
"pytest-cov==6.0.0",
|
||||||
"httpx==0.26.0",
|
"httpx==0.26.0",
|
||||||
"pytest-asyncio==0.25.3",
|
"pytest-asyncio==0.25.3",
|
||||||
"openai>=1.59.6",
|
|
||||||
"tomli>=2.0.1",
|
"tomli>=2.0.1",
|
||||||
|
"jinja2>=3.1.6"
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.uv]
|
[tool.uv]
|
||||||
|
|
46
scripts/fix_misaki.py
Normal file
46
scripts/fix_misaki.py
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
"""
|
||||||
|
Patch for misaki package to fix the EspeakWrapper.set_data_path issue.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import importlib.util
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# Find the misaki package
|
||||||
|
try:
|
||||||
|
import misaki
|
||||||
|
|
||||||
|
misaki_path = os.path.dirname(misaki.__file__)
|
||||||
|
print(f"Found misaki package at: {misaki_path}")
|
||||||
|
except ImportError:
|
||||||
|
print("Misaki package not found. Make sure it's installed.")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Path to the espeak.py file
|
||||||
|
espeak_file = os.path.join(misaki_path, "espeak.py")
|
||||||
|
|
||||||
|
if not os.path.exists(espeak_file):
|
||||||
|
print(f"Could not find {espeak_file}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Read the current content
|
||||||
|
with open(espeak_file, "r") as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
# Check if the problematic line exists
|
||||||
|
if "EspeakWrapper.set_data_path(espeakng_loader.get_data_path())" in content:
|
||||||
|
# Replace the problematic line
|
||||||
|
new_content = content.replace(
|
||||||
|
"EspeakWrapper.set_data_path(espeakng_loader.get_data_path())",
|
||||||
|
"# Fixed line to use data_path attribute instead of set_data_path method\n"
|
||||||
|
"EspeakWrapper.data_path = espeakng_loader.get_data_path()",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Write the modified content back
|
||||||
|
with open(espeak_file, "w") as f:
|
||||||
|
f.write(new_content)
|
||||||
|
|
||||||
|
print(f"Successfully patched {espeak_file}")
|
||||||
|
else:
|
||||||
|
print(f"The problematic line was not found in {espeak_file}")
|
||||||
|
print("The file may have already been patched or the issue is different.")
|
|
@ -1,54 +1,51 @@
|
||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
import tomli
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import tomli
|
||||||
|
|
||||||
|
|
||||||
def extract_dependency_info():
|
def extract_dependency_info():
|
||||||
"""Extract version and commit hash for kokoro and misaki from pyproject.toml"""
|
"""Extract version for kokoro and misaki from pyproject.toml"""
|
||||||
with open("pyproject.toml", "rb") as f:
|
with open("pyproject.toml", "rb") as f:
|
||||||
pyproject = tomli.load(f)
|
pyproject = tomli.load(f)
|
||||||
|
|
||||||
deps = pyproject["project"]["dependencies"]
|
deps = pyproject["project"]["dependencies"]
|
||||||
info = {}
|
info = {}
|
||||||
|
kokoro_found = False
|
||||||
|
misaki_found = False
|
||||||
|
|
||||||
# Extract kokoro info
|
|
||||||
for dep in deps:
|
for dep in deps:
|
||||||
if dep.startswith("kokoro @"):
|
# Match kokoro==version
|
||||||
# Extract version from the dependency string if available
|
kokoro_match = re.match(r"^kokoro==(.+)$", dep)
|
||||||
version_match = re.search(r"kokoro @ git\+https://github\.com/hexgrad/kokoro\.git@", dep)
|
if kokoro_match:
|
||||||
if version_match:
|
info["kokoro"] = {"version": kokoro_match.group(1)}
|
||||||
# If no explicit version, use v0.7.9 as shown in the README
|
kokoro_found = True
|
||||||
version = "v0.7.9"
|
|
||||||
commit_match = re.search(r"@([a-f0-9]{7})", dep)
|
# Match misaki[...] ==version or misaki==version
|
||||||
if commit_match:
|
misaki_match = re.match(r"^misaki(?:\[.*?\])?==(.+)$", dep)
|
||||||
info["kokoro"] = {
|
if misaki_match:
|
||||||
"version": version,
|
info["misaki"] = {"version": misaki_match.group(1)}
|
||||||
"commit": commit_match.group(1)
|
misaki_found = True
|
||||||
}
|
|
||||||
elif dep.startswith("misaki["):
|
# Stop if both found
|
||||||
# Extract version from the dependency string if available
|
if kokoro_found and misaki_found:
|
||||||
version_match = re.search(r"misaki\[.*?\] @ git\+https://github\.com/hexgrad/misaki\.git@", dep)
|
break
|
||||||
if version_match:
|
|
||||||
# If no explicit version, use v0.7.9 as shown in the README
|
if not kokoro_found:
|
||||||
version = "v0.7.9"
|
raise ValueError("Kokoro version not found in pyproject.toml dependencies")
|
||||||
commit_match = re.search(r"@([a-f0-9]{7})", dep)
|
if not misaki_found:
|
||||||
if commit_match:
|
raise ValueError("Misaki version not found in pyproject.toml dependencies")
|
||||||
info["misaki"] = {
|
|
||||||
"version": version,
|
|
||||||
"commit": commit_match.group(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
return info
|
return info
|
||||||
|
|
||||||
|
|
||||||
def run_pytest_with_coverage():
|
def run_pytest_with_coverage():
|
||||||
"""Run pytest with coverage and return the results"""
|
"""Run pytest with coverage and return the results"""
|
||||||
try:
|
try:
|
||||||
# Run pytest with coverage
|
# Run pytest with coverage
|
||||||
result = subprocess.run(
|
result = subprocess.run(
|
||||||
["pytest", "--cov=api", "-v"],
|
["pytest", "--cov=api", "-v"], capture_output=True, text=True, check=True
|
||||||
capture_output=True,
|
|
||||||
text=True,
|
|
||||||
check=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract test results
|
# Extract test results
|
||||||
|
@ -57,10 +54,7 @@ def run_pytest_with_coverage():
|
||||||
|
|
||||||
# Extract coverage from .coverage file
|
# Extract coverage from .coverage file
|
||||||
coverage_output = subprocess.run(
|
coverage_output = subprocess.run(
|
||||||
["coverage", "report"],
|
["coverage", "report"], capture_output=True, text=True, check=True
|
||||||
capture_output=True,
|
|
||||||
text=True,
|
|
||||||
check=True
|
|
||||||
).stdout
|
).stdout
|
||||||
|
|
||||||
# Extract total coverage percentage
|
# Extract total coverage percentage
|
||||||
|
@ -73,6 +67,7 @@ def run_pytest_with_coverage():
|
||||||
print(f"Output: {e.output}")
|
print(f"Output: {e.output}")
|
||||||
return 0, "0"
|
return 0, "0"
|
||||||
|
|
||||||
|
|
||||||
def update_readme_badges(passed_tests, coverage_percentage, dep_info):
|
def update_readme_badges(passed_tests, coverage_percentage, dep_info):
|
||||||
"""Update the badges in the README file"""
|
"""Update the badges in the README file"""
|
||||||
readme_path = Path("README.md")
|
readme_path = Path("README.md")
|
||||||
|
@ -84,37 +79,42 @@ def update_readme_badges(passed_tests, coverage_percentage, dep_info):
|
||||||
|
|
||||||
# Update tests badge
|
# Update tests badge
|
||||||
content = re.sub(
|
content = re.sub(
|
||||||
r'!\[Tests\]\(https://img\.shields\.io/badge/tests-\d+%20passed-[a-zA-Z]+\)',
|
r"!\[Tests\]\(https://img\.shields\.io/badge/tests-\d+%20passed-[a-zA-Z]+\)",
|
||||||
f'',
|
f"",
|
||||||
content
|
content,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update coverage badge
|
# Update coverage badge
|
||||||
content = re.sub(
|
content = re.sub(
|
||||||
r'!\[Coverage\]\(https://img\.shields\.io/badge/coverage-\d+%25-[a-zA-Z]+\)',
|
r"!\[Coverage\]\(https://img\.shields\.io/badge/coverage-\d+%25-[a-zA-Z]+\)",
|
||||||
f'',
|
f"",
|
||||||
content
|
content,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update kokoro badge
|
# Update kokoro badge
|
||||||
if "kokoro" in dep_info:
|
if "kokoro" in dep_info:
|
||||||
|
# Find badge like kokoro-v0.9.2::abcdefg-BB5420 or kokoro-v0.9.2-BB5420
|
||||||
|
kokoro_version = dep_info["kokoro"]["version"]
|
||||||
content = re.sub(
|
content = re.sub(
|
||||||
r'!\[Kokoro\]\(https://img\.shields\.io/badge/kokoro-[^)]+\)',
|
r"(!\[Kokoro\]\(https://img\.shields\.io/badge/kokoro-)[^)-]+(-BB5420\))",
|
||||||
f'',
|
lambda m: f"{m.group(1)}{kokoro_version}{m.group(2)}",
|
||||||
content
|
content,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update misaki badge
|
# Update misaki badge
|
||||||
if "misaki" in dep_info:
|
if "misaki" in dep_info:
|
||||||
|
# Find badge like misaki-v0.9.3::abcdefg-B8860B or misaki-v0.9.3-B8860B
|
||||||
|
misaki_version = dep_info["misaki"]["version"]
|
||||||
content = re.sub(
|
content = re.sub(
|
||||||
r'!\[Misaki\]\(https://img\.shields\.io/badge/misaki-[^)]+\)',
|
r"(!\[Misaki\]\(https://img\.shields\.io/badge/misaki-)[^)-]+(-B8860B\))",
|
||||||
f'',
|
lambda m: f"{m.group(1)}{misaki_version}{m.group(2)}",
|
||||||
content
|
content,
|
||||||
)
|
)
|
||||||
|
|
||||||
readme_path.write_text(content)
|
readme_path.write_text(content)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# Get dependency info
|
# Get dependency info
|
||||||
dep_info = extract_dependency_info()
|
dep_info = extract_dependency_info()
|
||||||
|
@ -128,11 +128,12 @@ def main():
|
||||||
print(f"- Tests: {passed_tests} passed")
|
print(f"- Tests: {passed_tests} passed")
|
||||||
print(f"- Coverage: {coverage_percentage}%")
|
print(f"- Coverage: {coverage_percentage}%")
|
||||||
if "kokoro" in dep_info:
|
if "kokoro" in dep_info:
|
||||||
print(f"- Kokoro: {dep_info['kokoro']['version']}::{dep_info['kokoro']['commit']}")
|
print(f"- Kokoro: {dep_info['kokoro']['version']}")
|
||||||
if "misaki" in dep_info:
|
if "misaki" in dep_info:
|
||||||
print(f"- Misaki: {dep_info['misaki']['version']}::{dep_info['misaki']['commit']}")
|
print(f"- Misaki: {dep_info['misaki']['version']}")
|
||||||
else:
|
else:
|
||||||
print("Failed to update badges")
|
print("Failed to update badges")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
234
scripts/update_version.py
Executable file
234
scripts/update_version.py
Executable file
|
@ -0,0 +1,234 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Version Update Script
|
||||||
|
|
||||||
|
This script reads the version from the VERSION file and updates references
|
||||||
|
in pyproject.toml, the Helm chart, and README.md.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
# Get the project root directory
|
||||||
|
ROOT_DIR = Path(__file__).parent.parent
|
||||||
|
|
||||||
|
# --- Configuration ---
|
||||||
|
VERSION_FILE = ROOT_DIR / "VERSION"
|
||||||
|
PYPROJECT_FILE = ROOT_DIR / "pyproject.toml"
|
||||||
|
HELM_CHART_FILE = ROOT_DIR / "charts" / "kokoro-fastapi" / "Chart.yaml"
|
||||||
|
README_FILE = ROOT_DIR / "README.md"
|
||||||
|
# --- End Configuration ---
|
||||||
|
|
||||||
|
|
||||||
|
def update_pyproject(version: str):
|
||||||
|
"""Updates the version in pyproject.toml"""
|
||||||
|
if not PYPROJECT_FILE.exists():
|
||||||
|
print(f"Skipping: {PYPROJECT_FILE} not found.")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
content = PYPROJECT_FILE.read_text()
|
||||||
|
# Regex to find and capture current version = "X.Y.Z" under [project]
|
||||||
|
pattern = r'(^\[project\]\s*(?:.*\s)*?version\s*=\s*)"([^"]+)"'
|
||||||
|
match = re.search(pattern, content, flags=re.MULTILINE)
|
||||||
|
|
||||||
|
if not match:
|
||||||
|
print(f"Warning: Version pattern not found in {PYPROJECT_FILE}")
|
||||||
|
return
|
||||||
|
|
||||||
|
current_version = match.group(2)
|
||||||
|
if current_version == version:
|
||||||
|
print(f"Already up-to-date: {PYPROJECT_FILE} (version {version})")
|
||||||
|
else:
|
||||||
|
# Perform replacement
|
||||||
|
new_content = re.sub(
|
||||||
|
pattern, rf'\1"{version}"', content, count=1, flags=re.MULTILINE
|
||||||
|
)
|
||||||
|
PYPROJECT_FILE.write_text(new_content)
|
||||||
|
print(f"Updated {PYPROJECT_FILE} from {current_version} to {version}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing {PYPROJECT_FILE}: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def update_helm_chart(version: str):
|
||||||
|
"""Updates the version and appVersion in the Helm chart"""
|
||||||
|
if not HELM_CHART_FILE.exists():
|
||||||
|
print(f"Skipping: {HELM_CHART_FILE} not found.")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
content = HELM_CHART_FILE.read_text()
|
||||||
|
original_content = content
|
||||||
|
updated_count = 0
|
||||||
|
|
||||||
|
# Update 'version:' line (unquoted)
|
||||||
|
# Looks for 'version:' followed by optional whitespace and the version number
|
||||||
|
version_pattern = r"^(version:\s*)(\S+)"
|
||||||
|
current_version_match = re.search(version_pattern, content, flags=re.MULTILINE)
|
||||||
|
if current_version_match and current_version_match.group(2) != version:
|
||||||
|
content = re.sub(
|
||||||
|
version_pattern,
|
||||||
|
rf"\g<1>{version}",
|
||||||
|
content,
|
||||||
|
count=1,
|
||||||
|
flags=re.MULTILINE,
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"Updating 'version' in {HELM_CHART_FILE} from {current_version_match.group(2)} to {version}"
|
||||||
|
)
|
||||||
|
updated_count += 1
|
||||||
|
elif current_version_match:
|
||||||
|
print(f"Already up-to-date: 'version' in {HELM_CHART_FILE} is {version}")
|
||||||
|
else:
|
||||||
|
print(f"Warning: 'version:' pattern not found in {HELM_CHART_FILE}")
|
||||||
|
|
||||||
|
# Update 'appVersion:' line (quoted or unquoted)
|
||||||
|
# Looks for 'appVersion:' followed by optional whitespace, optional quote, the version, optional quote
|
||||||
|
app_version_pattern = r"^(appVersion:\s*)(\"?)([^\"\s]+)(\"?)"
|
||||||
|
current_app_version_match = re.search(
|
||||||
|
app_version_pattern, content, flags=re.MULTILINE
|
||||||
|
)
|
||||||
|
|
||||||
|
if current_app_version_match:
|
||||||
|
leading_whitespace = current_app_version_match.group(
|
||||||
|
1
|
||||||
|
) # e.g., "appVersion: "
|
||||||
|
opening_quote = current_app_version_match.group(2) # e.g., '"' or ''
|
||||||
|
current_app_ver = current_app_version_match.group(3) # e.g., '0.2.0'
|
||||||
|
closing_quote = current_app_version_match.group(4) # e.g., '"' or ''
|
||||||
|
|
||||||
|
# Check if quotes were consistent (both present or both absent)
|
||||||
|
if opening_quote != closing_quote:
|
||||||
|
print(
|
||||||
|
f"Warning: Inconsistent quotes found for appVersion in {HELM_CHART_FILE}. Skipping update for this line."
|
||||||
|
)
|
||||||
|
elif (
|
||||||
|
current_app_ver == version and opening_quote == '"'
|
||||||
|
): # Check if already correct *and* quoted
|
||||||
|
print(
|
||||||
|
f"Already up-to-date: 'appVersion' in {HELM_CHART_FILE} is \"{version}\""
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Always replace with the quoted version
|
||||||
|
replacement = f'{leading_whitespace}"{version}"' # Ensure quotes
|
||||||
|
original_display = f"{opening_quote}{current_app_ver}{closing_quote}" # How it looked before
|
||||||
|
target_display = f'"{version}"' # How it should look
|
||||||
|
|
||||||
|
# Only report update if the displayed value actually changes
|
||||||
|
if original_display != target_display:
|
||||||
|
content = re.sub(
|
||||||
|
app_version_pattern,
|
||||||
|
replacement,
|
||||||
|
content,
|
||||||
|
count=1,
|
||||||
|
flags=re.MULTILINE,
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"Updating 'appVersion' in {HELM_CHART_FILE} from {original_display} to {target_display}"
|
||||||
|
)
|
||||||
|
updated_count += 1
|
||||||
|
else:
|
||||||
|
# It matches the target version but might need quoting fixed silently if we didn't update
|
||||||
|
# Or it was already correct. Check if content changed. If not, report up-to-date.
|
||||||
|
if not (
|
||||||
|
content != original_content and updated_count > 0
|
||||||
|
): # Avoid double message if version also changed
|
||||||
|
print(
|
||||||
|
f"Already up-to-date: 'appVersion' in {HELM_CHART_FILE} is {target_display}"
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
print(f"Warning: 'appVersion:' pattern not found in {HELM_CHART_FILE}")
|
||||||
|
|
||||||
|
# Write back only if changes were made
|
||||||
|
if content != original_content:
|
||||||
|
HELM_CHART_FILE.write_text(content)
|
||||||
|
# Confirmation message printed above during the specific update
|
||||||
|
elif updated_count == 0 and current_version_match and current_app_version_match:
|
||||||
|
# If no updates were made but patterns were found, confirm it's up-to-date overall
|
||||||
|
print(f"Already up-to-date: {HELM_CHART_FILE} (version {version})")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing {HELM_CHART_FILE}: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def update_readme(version_with_v: str):
|
||||||
|
"""Updates Docker image tags in README.md"""
|
||||||
|
if not README_FILE.exists():
|
||||||
|
print(f"Skipping: {README_FILE} not found.")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
content = README_FILE.read_text()
|
||||||
|
# Regex to find and capture current ghcr.io/.../kokoro-fastapi-(cpu|gpu):vX.Y.Z
|
||||||
|
pattern = r"(ghcr\.io/remsky/kokoro-fastapi-(?:cpu|gpu)):(v\d+\.\d+\.\d+)"
|
||||||
|
matches = list(re.finditer(pattern, content)) # Find all occurrences
|
||||||
|
|
||||||
|
if not matches:
|
||||||
|
print(f"Warning: Docker image tag pattern not found in {README_FILE}")
|
||||||
|
else:
|
||||||
|
updated_needed = False
|
||||||
|
for match in matches:
|
||||||
|
current_tag = match.group(2)
|
||||||
|
if current_tag != version_with_v:
|
||||||
|
updated_needed = True
|
||||||
|
break # Only need one mismatch to trigger update
|
||||||
|
|
||||||
|
if updated_needed:
|
||||||
|
# Perform replacement on all occurrences
|
||||||
|
new_content = re.sub(pattern, rf"\1:{version_with_v}", content)
|
||||||
|
README_FILE.write_text(new_content)
|
||||||
|
print(f"Updated Docker image tags in {README_FILE} to {version_with_v}")
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"Already up-to-date: Docker image tags in {README_FILE} (version {version_with_v})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for ':latest' tag usage remains the same
|
||||||
|
if ":latest" in content:
|
||||||
|
print(
|
||||||
|
f"Warning: Found ':latest' tag in {README_FILE}. Consider updating manually if needed."
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing {README_FILE}: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Read the version from the VERSION file
|
||||||
|
if not VERSION_FILE.exists():
|
||||||
|
print(f"Error: {VERSION_FILE} not found.")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
version = VERSION_FILE.read_text().strip()
|
||||||
|
if not re.match(r"^\d+\.\d+\.\d+$", version):
|
||||||
|
print(
|
||||||
|
f"Error: Invalid version format '{version}' in {VERSION_FILE}. Expected X.Y.Z"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error reading {VERSION_FILE}: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"Read version: {version} from {VERSION_FILE}")
|
||||||
|
print("-" * 20)
|
||||||
|
|
||||||
|
# Prepare versions (with and without 'v')
|
||||||
|
version_plain = version
|
||||||
|
version_with_v = f"v{version}"
|
||||||
|
|
||||||
|
# Update files
|
||||||
|
update_pyproject(version_plain)
|
||||||
|
update_helm_chart(version_plain)
|
||||||
|
update_readme(version_with_v)
|
||||||
|
|
||||||
|
print("-" * 20)
|
||||||
|
print("Version update script finished.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
|
@ -1,49 +0,0 @@
|
||||||
{
|
|
||||||
"document": "doc.report.command",
|
|
||||||
"version": "ov/command/slim/1.1",
|
|
||||||
"engine": "linux/amd64|ALP|x.1.42.2|29e62e7836de7b1004607c51c502537ffe1969f0|2025-01-16_07:48:54AM|x",
|
|
||||||
"containerized": false,
|
|
||||||
"host_distro": {
|
|
||||||
"name": "Ubuntu",
|
|
||||||
"version": "22.04",
|
|
||||||
"display_name": "Ubuntu 22.04.5 LTS"
|
|
||||||
},
|
|
||||||
"type": "slim",
|
|
||||||
"state": "error",
|
|
||||||
"target_reference": "kokoro-fastapi:latest",
|
|
||||||
"system": {
|
|
||||||
"type": "",
|
|
||||||
"release": "",
|
|
||||||
"distro": {
|
|
||||||
"name": "",
|
|
||||||
"version": "",
|
|
||||||
"display_name": ""
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"source_image": {
|
|
||||||
"identity": {
|
|
||||||
"id": ""
|
|
||||||
},
|
|
||||||
"size": 0,
|
|
||||||
"size_human": "",
|
|
||||||
"create_time": "",
|
|
||||||
"architecture": "",
|
|
||||||
"container_entry": {
|
|
||||||
"exe_path": ""
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"minified_image_size": 0,
|
|
||||||
"minified_image_size_human": "",
|
|
||||||
"minified_image": "",
|
|
||||||
"minified_image_id": "",
|
|
||||||
"minified_image_digest": "",
|
|
||||||
"minified_image_has_data": false,
|
|
||||||
"minified_by": 0,
|
|
||||||
"artifact_location": "",
|
|
||||||
"container_report_name": "",
|
|
||||||
"seccomp_profile_name": "",
|
|
||||||
"apparmor_profile_name": "",
|
|
||||||
"image_stack": null,
|
|
||||||
"image_created": false,
|
|
||||||
"image_build_engine": ""
|
|
||||||
}
|
|
|
@ -10,9 +10,17 @@ export PYTHONPATH=$PROJECT_ROOT:$PROJECT_ROOT/api
|
||||||
export MODEL_DIR=src/models
|
export MODEL_DIR=src/models
|
||||||
export VOICES_DIR=src/voices/v1_0
|
export VOICES_DIR=src/voices/v1_0
|
||||||
export WEB_PLAYER_PATH=$PROJECT_ROOT/web
|
export WEB_PLAYER_PATH=$PROJECT_ROOT/web
|
||||||
|
# Set the espeak-ng data path to your location
|
||||||
|
export ESPEAK_DATA_PATH=/usr/lib/x86_64-linux-gnu/espeak-ng-data
|
||||||
|
|
||||||
# Run FastAPI with CPU extras using uv run
|
# Run FastAPI with CPU extras using uv run
|
||||||
# Note: espeak may still require manual installation,
|
# Note: espeak may still require manual installation,
|
||||||
uv pip install -e ".[cpu]"
|
uv pip install -e ".[cpu]"
|
||||||
uv run --no-sync python docker/scripts/download_model.py --output api/src/models/v1_0
|
uv run --no-sync python docker/scripts/download_model.py --output api/src/models/v1_0
|
||||||
|
|
||||||
|
# Apply the misaki patch to fix possible EspeakWrapper issue in older versions
|
||||||
|
# echo "Applying misaki patch..."
|
||||||
|
# python scripts/fix_misaki.py
|
||||||
|
|
||||||
|
# Start the server
|
||||||
uv run --no-sync uvicorn api.src.main:app --host 0.0.0.0 --port 8880
|
uv run --no-sync uvicorn api.src.main:app --host 0.0.0.0 --port 8880
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import pytest
|
|
||||||
from unittest.mock import AsyncMock, Mock
|
from unittest.mock import AsyncMock, Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from api.src.services.tts_service import TTSService
|
from api.src.services.tts_service import TTSService
|
||||||
|
|
||||||
|
|
||||||
|
@ -30,8 +31,11 @@ async def mock_tts_service(mock_model_manager, mock_voice_manager):
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
async def setup_mocks(monkeypatch, mock_model_manager, mock_voice_manager, mock_tts_service):
|
async def setup_mocks(
|
||||||
|
monkeypatch, mock_model_manager, mock_voice_manager, mock_tts_service
|
||||||
|
):
|
||||||
"""Setup global mocks for UI tests"""
|
"""Setup global mocks for UI tests"""
|
||||||
|
|
||||||
async def mock_get_model():
|
async def mock_get_model():
|
||||||
return mock_model_manager
|
return mock_model_manager
|
||||||
|
|
||||||
|
@ -43,4 +47,6 @@ async def setup_mocks(monkeypatch, mock_model_manager, mock_voice_manager, mock_
|
||||||
|
|
||||||
monkeypatch.setattr("api.src.inference.model_manager.get_manager", mock_get_model)
|
monkeypatch.setattr("api.src.inference.model_manager.get_manager", mock_get_model)
|
||||||
monkeypatch.setattr("api.src.inference.voice_manager.get_manager", mock_get_voice)
|
monkeypatch.setattr("api.src.inference.voice_manager.get_manager", mock_get_voice)
|
||||||
monkeypatch.setattr("api.src.services.tts_service.TTSService.create", mock_create_service)
|
monkeypatch.setattr(
|
||||||
|
"api.src.services.tts_service.TTSService.create", mock_create_service
|
||||||
|
)
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from unittest.mock import patch, mock_open
|
from unittest.mock import mock_open, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import requests
|
import requests
|
||||||
|
@ -59,9 +59,11 @@ def test_check_api_status_connection_error():
|
||||||
|
|
||||||
def test_text_to_speech_success(mock_response, tmp_path):
|
def test_text_to_speech_success(mock_response, tmp_path):
|
||||||
"""Test successful speech generation"""
|
"""Test successful speech generation"""
|
||||||
with patch("requests.post", return_value=mock_response({})), patch(
|
with (
|
||||||
"ui.lib.api.OUTPUTS_DIR", str(tmp_path)
|
patch("requests.post", return_value=mock_response({})),
|
||||||
), patch("builtins.open", mock_open()) as mock_file:
|
patch("ui.lib.api.OUTPUTS_DIR", str(tmp_path)),
|
||||||
|
patch("builtins.open", mock_open()) as mock_file,
|
||||||
|
):
|
||||||
result = api.text_to_speech("test text", "voice1", "mp3", 1.0)
|
result = api.text_to_speech("test text", "voice1", "mp3", 1.0)
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
|
@ -116,9 +118,11 @@ def test_text_to_speech_api_params(mock_response, tmp_path):
|
||||||
]
|
]
|
||||||
|
|
||||||
for input_voice, expected_voice in test_cases:
|
for input_voice, expected_voice in test_cases:
|
||||||
with patch("requests.post") as mock_post, patch(
|
with (
|
||||||
"ui.lib.api.OUTPUTS_DIR", str(tmp_path)
|
patch("requests.post") as mock_post,
|
||||||
), patch("builtins.open", mock_open()):
|
patch("ui.lib.api.OUTPUTS_DIR", str(tmp_path)),
|
||||||
|
patch("builtins.open", mock_open()),
|
||||||
|
):
|
||||||
mock_post.return_value = mock_response({})
|
mock_post.return_value = mock_response({})
|
||||||
api.text_to_speech("test text", input_voice, "mp3", 1.5)
|
api.text_to_speech("test text", input_voice, "mp3", 1.5)
|
||||||
|
|
||||||
|
@ -149,11 +153,15 @@ def test_text_to_speech_output_filename(mock_response, tmp_path):
|
||||||
]
|
]
|
||||||
|
|
||||||
for input_voice, filename_check in test_cases:
|
for input_voice, filename_check in test_cases:
|
||||||
with patch("requests.post", return_value=mock_response({})), patch(
|
with (
|
||||||
"ui.lib.api.OUTPUTS_DIR", str(tmp_path)
|
patch("requests.post", return_value=mock_response({})),
|
||||||
), patch("builtins.open", mock_open()) as mock_file:
|
patch("ui.lib.api.OUTPUTS_DIR", str(tmp_path)),
|
||||||
|
patch("builtins.open", mock_open()) as mock_file,
|
||||||
|
):
|
||||||
result = api.text_to_speech("test text", input_voice, "mp3", 1.0)
|
result = api.text_to_speech("test text", input_voice, "mp3", 1.0)
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert filename_check(result), f"Expected voice pattern not found in filename: {result}"
|
assert filename_check(result), (
|
||||||
|
f"Expected voice pattern not found in filename: {result}"
|
||||||
|
)
|
||||||
mock_file.assert_called_once()
|
mock_file.assert_called_once()
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from ui.lib.config import AUDIO_FORMATS
|
|
||||||
from ui.lib.components.model import create_model_column
|
from ui.lib.components.model import create_model_column
|
||||||
from ui.lib.components.output import create_output_column
|
from ui.lib.components.output import create_output_column
|
||||||
|
from ui.lib.config import AUDIO_FORMATS
|
||||||
|
|
||||||
|
|
||||||
def test_create_model_column_structure():
|
def test_create_model_column_structure():
|
||||||
|
|
|
@ -15,8 +15,9 @@ def mock_dirs(tmp_path):
|
||||||
inputs_dir.mkdir()
|
inputs_dir.mkdir()
|
||||||
outputs_dir.mkdir()
|
outputs_dir.mkdir()
|
||||||
|
|
||||||
with patch("ui.lib.files.INPUTS_DIR", str(inputs_dir)), patch(
|
with (
|
||||||
"ui.lib.files.OUTPUTS_DIR", str(outputs_dir)
|
patch("ui.lib.files.INPUTS_DIR", str(inputs_dir)),
|
||||||
|
patch("ui.lib.files.OUTPUTS_DIR", str(outputs_dir)),
|
||||||
):
|
):
|
||||||
yield inputs_dir, outputs_dir
|
yield inputs_dir, outputs_dir
|
||||||
|
|
||||||
|
|
|
@ -62,8 +62,9 @@ def test_interface_html_links():
|
||||||
def test_update_status_available(mock_timer):
|
def test_update_status_available(mock_timer):
|
||||||
"""Test status update when service is available"""
|
"""Test status update when service is available"""
|
||||||
voices = ["voice1", "voice2"]
|
voices = ["voice1", "voice2"]
|
||||||
with patch("ui.lib.api.check_api_status", return_value=(True, voices)), patch(
|
with (
|
||||||
"gradio.Timer", return_value=mock_timer
|
patch("ui.lib.api.check_api_status", return_value=(True, voices)),
|
||||||
|
patch("gradio.Timer", return_value=mock_timer),
|
||||||
):
|
):
|
||||||
demo = create_interface()
|
demo = create_interface()
|
||||||
|
|
||||||
|
@ -81,8 +82,9 @@ def test_update_status_available(mock_timer):
|
||||||
|
|
||||||
def test_update_status_unavailable(mock_timer):
|
def test_update_status_unavailable(mock_timer):
|
||||||
"""Test status update when service is unavailable"""
|
"""Test status update when service is unavailable"""
|
||||||
with patch("ui.lib.api.check_api_status", return_value=(False, [])), patch(
|
with (
|
||||||
"gradio.Timer", return_value=mock_timer
|
patch("ui.lib.api.check_api_status", return_value=(False, [])),
|
||||||
|
patch("gradio.Timer", return_value=mock_timer),
|
||||||
):
|
):
|
||||||
demo = create_interface()
|
demo = create_interface()
|
||||||
update_fn = mock_timer.events[0].fn
|
update_fn = mock_timer.events[0].fn
|
||||||
|
@ -97,9 +99,10 @@ def test_update_status_unavailable(mock_timer):
|
||||||
|
|
||||||
def test_update_status_error(mock_timer):
|
def test_update_status_error(mock_timer):
|
||||||
"""Test status update when an error occurs"""
|
"""Test status update when an error occurs"""
|
||||||
with patch(
|
with (
|
||||||
"ui.lib.api.check_api_status", side_effect=Exception("Test error")
|
patch("ui.lib.api.check_api_status", side_effect=Exception("Test error")),
|
||||||
), patch("gradio.Timer", return_value=mock_timer):
|
patch("gradio.Timer", return_value=mock_timer),
|
||||||
|
):
|
||||||
demo = create_interface()
|
demo = create_interface()
|
||||||
update_fn = mock_timer.events[0].fn
|
update_fn = mock_timer.events[0].fn
|
||||||
|
|
||||||
|
@ -113,8 +116,9 @@ def test_update_status_error(mock_timer):
|
||||||
|
|
||||||
def test_timer_configuration(mock_timer):
|
def test_timer_configuration(mock_timer):
|
||||||
"""Test timer configuration"""
|
"""Test timer configuration"""
|
||||||
with patch("ui.lib.api.check_api_status", return_value=(False, [])), patch(
|
with (
|
||||||
"gradio.Timer", return_value=mock_timer
|
patch("ui.lib.api.check_api_status", return_value=(False, [])),
|
||||||
|
patch("gradio.Timer", return_value=mock_timer),
|
||||||
):
|
):
|
||||||
demo = create_interface()
|
demo = create_interface()
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import os
|
|
||||||
import datetime
|
import datetime
|
||||||
from typing import List, Tuple, Optional
|
import os
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
|
|
@ -13,9 +13,7 @@ def create_input_column(disable_local_saving: bool = False) -> Tuple[gr.Column,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Always show file upload but handle differently based on disable_local_saving
|
# Always show file upload but handle differently based on disable_local_saving
|
||||||
file_upload = gr.File(
|
file_upload = gr.File(label="Upload Text File (.txt)", file_types=[".txt"])
|
||||||
label="Upload Text File (.txt)", file_types=[".txt"]
|
|
||||||
)
|
|
||||||
|
|
||||||
if not disable_local_saving:
|
if not disable_local_saving:
|
||||||
# Show full interface with tabs when saving is enabled
|
# Show full interface with tabs when saving is enabled
|
||||||
|
@ -24,7 +22,9 @@ def create_input_column(disable_local_saving: bool = False) -> Tuple[gr.Column,
|
||||||
tabs.selected = 0
|
tabs.selected = 0
|
||||||
# Direct Input Tab
|
# Direct Input Tab
|
||||||
with gr.TabItem("Direct Input"):
|
with gr.TabItem("Direct Input"):
|
||||||
text_submit_direct = gr.Button("Generate Speech", variant="primary", size="lg")
|
text_submit_direct = gr.Button(
|
||||||
|
"Generate Speech", variant="primary", size="lg"
|
||||||
|
)
|
||||||
|
|
||||||
# File Input Tab
|
# File Input Tab
|
||||||
with gr.TabItem("From File"):
|
with gr.TabItem("From File"):
|
||||||
|
@ -48,7 +48,9 @@ def create_input_column(disable_local_saving: bool = False) -> Tuple[gr.Column,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Just show the generate button when saving is disabled
|
# Just show the generate button when saving is disabled
|
||||||
text_submit_direct = gr.Button("Generate Speech", variant="primary", size="lg")
|
text_submit_direct = gr.Button(
|
||||||
|
"Generate Speech", variant="primary", size="lg"
|
||||||
|
)
|
||||||
tabs = None
|
tabs = None
|
||||||
input_files_list = None
|
input_files_list = None
|
||||||
file_preview = None
|
file_preview = None
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Tuple, Optional
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,7 @@ def create_output_column(disable_local_saving: bool = False) -> Tuple[gr.Column,
|
||||||
audio_output = gr.Audio(
|
audio_output = gr.Audio(
|
||||||
label="Generated Speech",
|
label="Generated Speech",
|
||||||
type="filepath",
|
type="filepath",
|
||||||
waveform_options={"waveform_color": "#4C87AB"}
|
waveform_options={"waveform_color": "#4C87AB"},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create file-related components with visible=False when local saving is disabled
|
# Create file-related components with visible=False when local saving is disabled
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
import os
|
|
||||||
import datetime
|
import datetime
|
||||||
from typing import List, Tuple, Optional
|
import os
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from .config import INPUTS_DIR, OUTPUTS_DIR, AUDIO_FORMATS
|
from .config import AUDIO_FORMATS, INPUTS_DIR, OUTPUTS_DIR
|
||||||
|
|
||||||
|
|
||||||
def list_input_files() -> List[str]:
|
def list_input_files() -> List[str]:
|
||||||
|
|
|
@ -58,17 +58,21 @@ def setup_event_handlers(components: dict, disable_local_saving: bool = False):
|
||||||
|
|
||||||
def handle_file_upload(file):
|
def handle_file_upload(file):
|
||||||
if file is None:
|
if file is None:
|
||||||
return "" if disable_local_saving else [gr.update(choices=files.list_input_files())]
|
return (
|
||||||
|
""
|
||||||
|
if disable_local_saving
|
||||||
|
else [gr.update(choices=files.list_input_files())]
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Read the file content
|
# Read the file content
|
||||||
with open(file.name, 'r', encoding='utf-8') as f:
|
with open(file.name, "r", encoding="utf-8") as f:
|
||||||
text_content = f.read()
|
text_content = f.read()
|
||||||
|
|
||||||
if disable_local_saving:
|
if disable_local_saving:
|
||||||
# When saving is disabled, put content directly in text input
|
# When saving is disabled, put content directly in text input
|
||||||
# Normalize whitespace by replacing newlines with spaces
|
# Normalize whitespace by replacing newlines with spaces
|
||||||
normalized_text = ' '.join(text_content.split())
|
normalized_text = " ".join(text_content.split())
|
||||||
return normalized_text
|
return normalized_text
|
||||||
else:
|
else:
|
||||||
# When saving is enabled, save file and update dropdown
|
# When saving is enabled, save file and update dropdown
|
||||||
|
@ -88,7 +92,11 @@ def setup_event_handlers(components: dict, disable_local_saving: bool = False):
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error handling file: {e}")
|
print(f"Error handling file: {e}")
|
||||||
return "" if disable_local_saving else [gr.update(choices=files.list_input_files())]
|
return (
|
||||||
|
""
|
||||||
|
if disable_local_saving
|
||||||
|
else [gr.update(choices=files.list_input_files())]
|
||||||
|
)
|
||||||
|
|
||||||
def generate_from_text(text, voice, format, speed):
|
def generate_from_text(text, voice, format, speed):
|
||||||
"""Generate speech from direct text input"""
|
"""Generate speech from direct text input"""
|
||||||
|
@ -203,7 +211,11 @@ def setup_event_handlers(components: dict, disable_local_saving: bool = False):
|
||||||
components["input"]["file_upload"].upload(
|
components["input"]["file_upload"].upload(
|
||||||
fn=handle_file_upload,
|
fn=handle_file_upload,
|
||||||
inputs=[components["input"]["file_upload"]],
|
inputs=[components["input"]["file_upload"]],
|
||||||
outputs=[components["input"]["text_input"] if disable_local_saving else components["input"]["file_select"]],
|
outputs=[
|
||||||
|
components["input"]["text_input"]
|
||||||
|
if disable_local_saving
|
||||||
|
else components["input"]["file_select"]
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
if components["output"]["play_btn"] is not None:
|
if components["output"]["play_btn"] is not None:
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
import gradio as gr
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
from . import api
|
from . import api
|
||||||
from .handlers import setup_event_handlers
|
|
||||||
from .components import create_input_column, create_model_column, create_output_column
|
from .components import create_input_column, create_model_column, create_output_column
|
||||||
|
from .handlers import setup_event_handlers
|
||||||
|
|
||||||
|
|
||||||
def create_interface():
|
def create_interface():
|
||||||
|
|
Loading…
Add table
Reference in a new issue