Commit 3244ef5f9c by Kishor Prins, 2025-06-30 00:55:33 +00:00 (committed by GitHub)
GPG key ID: B5690EEEBB952194 (no known key found for this signature in database)
8 changed files with 234 additions and 49 deletions


@@ -28,6 +28,7 @@ jobs:
build-images:
needs: prepare-release
runs-on: ubuntu-latest
timeout-minutes: 60
permissions:
packages: write # Needed to push images to GHCR
env:
@@ -35,6 +36,9 @@ jobs:
BUILDKIT_STEP_LOG_MAX_SIZE: 10485760
# This environment variable will override the VERSION variable in docker-bake.hcl
VERSION: ${{ needs.prepare-release.outputs.version_tag }} # Use tag version (vX.Y.Z) for bake
strategy:
matrix:
build_target: ["cpu", "cpu-arm64", "gpu-arm64", "gpu", "rocm"]
steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -60,7 +64,7 @@ jobs:
df -h
echo "Cleaning up disk space..."
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache
docker system prune -af
sudo docker system prune -af
echo "Disk space after cleanup"
df -h
@@ -85,7 +89,7 @@ jobs:
run: |
echo "Building and pushing images for version ${{ needs.prepare-release.outputs.version_tag }}"
# The VERSION env var above sets the tag for the bake file targets
docker buildx bake --push
docker buildx bake ${{ matrix.build_target }} --push
create-release:
needs: [prepare-release, build-images]


@@ -497,4 +497,4 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
text = re.sub(r"\s{2,}", " ", text)
return text.strip()
return text


@@ -15,7 +15,7 @@ from .vocabulary import tokenize
# Pre-compiled regex patterns for performance
# Updated regex to be more strict and avoid matching isolated brackets
# Only matches complete patterns like [word](/ipa/) and prevents catastrophic backtracking
CUSTOM_PHONEMES = re.compile(r"(\[[^\[\]]*?\])(\(\/[^\/\(\)]*?\/\))")
CUSTOM_PHONEMES = re.compile(r"(\[[^\[\]]*?\]\(\/[^\/\(\)]*?\/\))")
# Pattern to find pause tags like [pause:0.5s]
PAUSE_TAG_PATTERN = re.compile(r"\[pause:(\d+(?:\.\d+)?)s\]", re.IGNORECASE)
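
For reference, a minimal standalone sketch (illustrative input, not from the repo) of why merging the two capture groups into one matters: with a single group, re.split() returns each whole [word](/ipa/) annotation as one untouched element.

```python
import re

# Single-group pattern from this commit: one capture spans the whole
# [word](/ipa/) annotation, so re.split() keeps it intact.
CUSTOM_PHONEMES = re.compile(r"(\[[^\[\]]*?\]\(\/[^\/\(\)]*?\/\))")

parts = CUSTOM_PHONEMES.split("Say [Kokoro](/kˈOkəɹO/) slowly.")
print(parts)  # ['Say ', '[Kokoro](/kˈOkəɹO/)', ' slowly.']
```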
@@ -100,7 +100,7 @@ def process_text(text: str, language: str = "a") -> List[int]:
def get_sentence_info(
text: str, custom_phenomes_list: Dict[str, str], lang_code: str = "a"
text: str, lang_code: str = "a"
) -> List[Tuple[str, List[int], int]]:
"""Process all sentences and return info"""
# Detect Chinese text
@@ -110,18 +110,10 @@ def get_sentence_info(
sentences = re.split(r"([,。!?;])+", text)
else:
sentences = re.split(r"([.!?;:])(?=\s|$)", text)
phoneme_length, min_value = len(custom_phenomes_list), 0
results = []
for i in range(0, len(sentences), 2):
sentence = sentences[i].strip()
for replaced in range(min_value, phoneme_length):
current_id = f"</|custom_phonemes_{replaced}|/>"
if current_id in sentence:
sentence = sentence.replace(
current_id, custom_phenomes_list.pop(current_id)
)
min_value += 1
punct = sentences[i + 1] if i + 1 < len(sentences) else ""
if not sentence:
continue
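
A small illustrative sketch (sample text, not from the repo) of the pairing logic above: because the punctuation class is captured, re.split() alternates sentence and punctuation, and stepping through by two re-attaches them.

```python
import re

parts = re.split(r"([.!?;:])(?=\s|$)", "One. Two! Three?")
# ['One', '.', ' Two', '!', ' Three', '?', '']
for i in range(0, len(parts), 2):
    sentence = parts[i].strip()
    punct = parts[i + 1] if i + 1 < len(parts) else ""  # re-attach punctuation
    if sentence:
        print(sentence + punct)  # One. / Two! / Three?
```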
@@ -173,24 +165,23 @@ async def smart_split(
# Strip leading and trailing spaces to prevent pause tag splitting artifacts
text_part_raw = text_part_raw.strip()
# Apply the original smart_split logic to this text part
custom_phoneme_list = {}
# Normalize text (original logic)
processed_text = text_part_raw
if settings.advanced_text_normalization and normalization_options.normalize:
if lang_code in ["a", "b", "en-us", "en-gb"]:
processed_text = CUSTOM_PHONEMES.sub(
lambda s: handle_custom_phonemes(s, custom_phoneme_list), processed_text
)
processed_text = normalize_text(processed_text, normalization_options)
processed_text = CUSTOM_PHONEMES.split(processed_text)
for index in range(0, len(processed_text), 2):
processed_text[index] = normalize_text(processed_text[index], normalization_options)
processed_text = "".join(processed_text).strip()
else:
logger.info(
"Skipping text normalization as it is only supported for english"
)
# Process all sentences (original logic)
sentences = get_sentence_info(processed_text, custom_phoneme_list, lang_code=lang_code)
sentences = get_sentence_info(processed_text, lang_code=lang_code)
current_chunk = []
current_tokens = []
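
To make the new control flow concrete, a hedged sketch with a stand-in normalizer (the real normalize_text takes NormalizationOptions): splitting on the single-group pattern interleaves plain text (even indices) with annotations (odd indices), so only the plain text is normalized and the old placeholder dictionary becomes unnecessary.

```python
import re

CUSTOM_PHONEMES = re.compile(r"(\[[^\[\]]*?\]\(\/[^\/\(\)]*?\/\))")

def fake_normalize(piece: str) -> str:
    # Stand-in for the real normalize_text(); e.g. it spells out numbers.
    return piece.replace("1.0", "one point zero")

parts = CUSTOM_PHONEMES.split("[Kokoro](/kˈOkəɹO/) version 1.0 uses it")
for i in range(0, len(parts), 2):  # even indices hold plain text only
    parts[i] = fake_normalize(parts[i])
print("".join(parts).strip())
# [Kokoro](/kˈOkəɹO/) version one point zero uses it
```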


@@ -34,7 +34,7 @@ def test_process_text_chunk_phonemes():
def test_get_sentence_info():
"""Test sentence splitting and info extraction."""
text = "This is sentence one. This is sentence two! What about three?"
results = get_sentence_info(text, {})
results = get_sentence_info(text)
assert len(results) == 3
for sentence, tokens, count in results:
@@ -44,24 +44,6 @@ def test_get_sentence_info():
assert count == len(tokens)
assert count > 0
def test_get_sentence_info_phenomoes():
"""Test sentence splitting and info extraction."""
text = (
"This is sentence one. This is </|custom_phonemes_0|/> two! What about three?"
)
results = get_sentence_info(text, {"</|custom_phonemes_0|/>": r"sˈɛntᵊns"})
assert len(results) == 3
assert "sˈɛntᵊns" in results[1][0]
for sentence, tokens, count in results:
assert isinstance(sentence, str)
assert isinstance(tokens, list)
assert isinstance(count, int)
assert count == len(tokens)
assert count > 0
@pytest.mark.asyncio
async def test_smart_split_short_text():
"""Test smart splitting with text under max tokens."""
@@ -74,6 +56,33 @@ async def test_smart_split_short_text():
assert isinstance(chunks[0][0], str)
assert isinstance(chunks[0][1], list)
@pytest.mark.asyncio
async def test_smart_custom_phenomes():
"""Test smart splitting with text under max tokens."""
text = "This is a short test sentence. [Kokoro](/kˈOkəɹO/) has a feature called custom phenomes. This is made possible by [Misaki](/misˈɑki/), the custom phenomizer that [Kokoro](/kˈOkəɹO/) version 1.0 uses"
chunks = []
async for chunk_text, chunk_tokens, pause_duration in smart_split(text):
chunks.append((chunk_text, chunk_tokens, pause_duration))
# Should have 1 chunk: text
assert len(chunks) == 1
# First chunk: text
assert chunks[0][2] is None # No pause
assert "This is a short test sentence. [Kokoro](/kˈOkəɹO/) has a feature called custom phenomes. This is made possible by [Misaki](/misˈɑki/), the custom phenomizer that [Kokoro](/kˈOkəɹO/) version one uses" in chunks[0][0]
assert len(chunks[0][1]) > 0
@pytest.mark.asyncio
async def test_smart_split_only_phenomes():
"""Test input that is entirely made of phenome annotations."""
text = "[Kokoro](/kˈOkəɹO/) [Misaki 1.2](/misˈɑki/) [Test](/tɛst/)"
chunks = []
async for chunk_text, chunk_tokens, pause_duration in smart_split(text, max_tokens=10):
chunks.append((chunk_text, chunk_tokens, pause_duration))
assert len(chunks) == 1
assert "[Kokoro](/kˈOkəɹO/) [Misaki 1.2](/misˈɑki/) [Test](/tɛst/)" in chunks[0][0]
@pytest.mark.asyncio
async def test_smart_split_long_text():
@@ -116,7 +125,7 @@ def test_process_text_chunk_chinese_phonemes():
def test_get_sentence_info_chinese():
"""Test Chinese sentence splitting and info extraction."""
text = "这是一个句子。这是第二个句子!第三个问题?"
results = get_sentence_info(text, {}, lang_code="z")
results = get_sentence_info(text, lang_code="z")
assert len(results) == 3
for sentence, tokens, count in results:


@@ -40,10 +40,25 @@ target "_gpu_base" {
dockerfile = "docker/gpu/Dockerfile"
}
# Base settings for AMD ROCm builds
target "_rocm_base" {
inherits = ["_common"]
dockerfile = "docker/rocm/Dockerfile"
}
# CPU target with multi-platform support
target "cpu" {
inherits = ["_cpu_base"]
platforms = ["linux/amd64", "linux/arm64"]
platforms = ["linux/amd64"]
tags = [
"${REGISTRY}/${OWNER}/${REPO}-cpu:${VERSION}",
"${REGISTRY}/${OWNER}/${REPO}-cpu:latest"
]
}
target "cpu-arm64" {
inherits = ["_cpu_base"]
platforms = ["linux/arm64"]
tags = [
"${REGISTRY}/${OWNER}/${REPO}-cpu:${VERSION}",
"${REGISTRY}/${OWNER}/${REPO}-cpu:latest"
@@ -53,16 +68,51 @@ target "cpu" {
# GPU target with multi-platform support
target "gpu" {
inherits = ["_gpu_base"]
platforms = ["linux/amd64", "linux/arm64"]
platforms = ["linux/amd64"]
tags = [
"${REGISTRY}/${OWNER}/${REPO}-gpu:${VERSION}",
"${REGISTRY}/${OWNER}/${REPO}-gpu:latest"
]
}
# Default group to build both CPU and GPU versions
group "default" {
targets = ["cpu", "gpu"]
target "gpu-arm64" {
inherits = ["_gpu_base"]
platforms = ["linux/arm64"]
tags = [
"${REGISTRY}/${OWNER}/${REPO}-gpu:${VERSION}",
"${REGISTRY}/${OWNER}/${REPO}-gpu:latest"
]
}
# AMD ROCm target
target "rocm" {
inherits = ["_rocm_base"]
platforms = ["linux/amd64"]
tags = [
"${REGISTRY}/${OWNER}/${REPO}-rocm:${VERSION}",
"${REGISTRY}/${OWNER}/${REPO}-rocm:latest"
]
}
# Build groups for parallel builds
group "cpu" {
targets = ["cpu"]
}
group "cpu-arm64" {
targets = ["cpu-arm64"]
}
group "gpu-arm64" {
targets = ["gpu-arm64"]
}
group "gpu" {
targets = ["gpu"]
}
group "rocm" {
targets = ["rocm"]
}
# Development targets for faster local builds
@@ -78,6 +128,12 @@ target "gpu-dev" {
tags = ["${REGISTRY}/${OWNER}/${REPO}-gpu:dev"]
}
target "rocm-dev" {
inherits = ["_rocm_base"]
# No multi-platform for dev builds
tags = ["${REGISTRY}/${OWNER}/${REPO}-rocm:dev"]
}
group "dev" {
targets = ["cpu-dev", "gpu-dev"]
}
targets = ["cpu-dev", "gpu-dev", "rocm-dev"]
}
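
Since each new group holds exactly one target, the release workflow's matrix can fan the five builds out independently. A hedged local sketch of that fan-out, assuming docker buildx is installed (CI adds --push):

```python
import subprocess

# One bake group per target, mirroring the workflow matrix above.
TARGETS = ["cpu", "cpu-arm64", "gpu-arm64", "gpu", "rocm"]

for target in TARGETS:
    # Equivalent to the CI step `docker buildx bake <target> --push`,
    # minus the push for a local dry run.
    subprocess.run(["docker", "buildx", "bake", target], check=True)
```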

docker/rocm/Dockerfile (new file, 72 lines)

@@ -0,0 +1,72 @@
FROM rocm/dev-ubuntu-24.04:6.4.1
ENV DEBIAN_FRONTEND=noninteractive \
PHONEMIZER_ESPEAK_PATH=/usr/bin \
PHONEMIZER_ESPEAK_DATA=/usr/share/espeak-ng-data \
ESPEAK_DATA_PATH=/usr/share/espeak-ng-data
# Install Python and other dependencies
RUN apt-get update && apt upgrade -y && apt-get install -y --no-install-recommends \
espeak-ng \
espeak-ng-data \
git \
libsndfile1 \
curl \
ffmpeg \
wget \
nano \
g++ \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/* \
&& mkdir -p /usr/share/espeak-ng-data \
&& ln -s /usr/lib/*/espeak-ng-data/* /usr/share/espeak-ng-data/
RUN mkdir -p /app/api/src/models/v1_0
WORKDIR /app
# Install UV using the installer script
RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
mv /root/.local/bin/uv /usr/local/bin/ && \
mv /root/.local/bin/uvx /usr/local/bin/
# Create non-root user and set up directories and permissions
RUN useradd -m -u 1001 appuser && \
mkdir -p /app/api/src/models/v1_0 && \
chown -R appuser:appuser /app
USER appuser
WORKDIR /app
# Copy dependency files
COPY --chown=appuser:appuser pyproject.toml ./pyproject.toml
ENV PHONEMIZER_ESPEAK_PATH=/usr/bin \
PHONEMIZER_ESPEAK_DATA=/usr/share/espeak-ng-data \
ESPEAK_DATA_PATH=/usr/share/espeak-ng-data
# Install dependencies with ROCm extras (using cache mounts)
RUN --mount=type=cache,target=/root/.cache/uv \
uv venv --python 3.10 && \
uv sync --extra rocm
# Copy project files including models
COPY --chown=appuser:appuser api ./api
COPY --chown=appuser:appuser web ./web
COPY --chown=appuser:appuser docker/scripts/ ./
RUN chmod +x ./entrypoint.sh
# Set all environment variables in one go
ENV PYTHONUNBUFFERED=1 \
PYTHONPATH=/app:/app/api \
PATH="/app/.venv/bin:$PATH" \
UV_LINK_MODE=copy \
USE_GPU=true
ENV DOWNLOAD_MODEL=true
# Download model if enabled
RUN if [ "$DOWNLOAD_MODEL" = "true" ]; then \
python download_model.py --output api/src/models/v1_0; \
fi
ENV DEVICE="gpu"
# Run FastAPI server through entrypoint.sh
CMD ["./entrypoint.sh"]


@@ -0,0 +1,38 @@
services:
kokoro-tts:
image: kprinssu/kokoro-fastapi:rocm
devices:
- /dev/dri
- /dev/kfd
security_opt:
- seccomp:unconfined
cap_add:
- SYS_PTRACE
group_add:
# NOTE: These groups are the group ids for: video, input, and render
# Numbers can be found via running: getent group $GROUP_NAME | cut -d: -f3
- 44
- 993
- 996
restart: 'always'
volumes:
- ./kokoro-tts/config:/root/.config/miopen
- ./kokoro-tts/cache:/root/.cache/miopen
ports:
- 8880:8880
environment:
- USE_GPU=true
- TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1
# IMPORTANT: ROCm's MIOpen library will be slow if it has to figure out the optimal kernel shapes for each model
# See the documentation on performance tuning: https://github.com/ROCm/MIOpen/blob/develop/docs/conceptual/tuningdb.rst
# The volumes above cache the MIOpen shape files and user database for subsequent runs
#
# Steps:
# 1. Run Kokoro once with the following environment variables set:
# - MIOPEN_FIND_MODE=3
# - MIOPEN_FIND_ENFORCE=3
# 2. Generate various recordings using sample data (e.g. first couple paragraphs of Dracula); this will be slow
# 3. Comment out/remove the previously set environment variables
# 4. Add the following environment variables to enable caching of model shapes:
# - MIOPEN_FIND_MODE=2
# 5. Restart the container and run Kokoro again; it should be much faster
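
The numeric gids in group_add vary between distros. A small sketch, using Python's standard grp module on the Linux host, to look up the values the NOTE above refers to:

```python
import grp

# Print the host gids for the groups named in group_add.
for name in ("video", "input", "render"):
    try:
        print(f"{name}: {grp.getgrnam(name).gr_gid}")
    except KeyError:
        print(f"{name}: not present on this host")
```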


@@ -46,6 +46,10 @@ dependencies = [
gpu = [
"torch==2.6.0+cu124",
]
rocm = [
"torch==2.8.0.dev20250627+rocm6.4",
"pytorch-triton-rocm>=3.2.0",
]
cpu = [
"torch==2.6.0",
]
@@ -63,13 +67,19 @@ conflicts = [
[
{ extra = "cpu" },
{ extra = "gpu" },
{ extra = "rocm" },
],
]
[tool.uv.sources]
torch = [
{ index = "pytorch-cpu", extra = "cpu" },
{ index = "pytorch-cuda", extra = "gpu" },
{ index = "pytorch-rocm", extra = "rocm" },
]
pytorch-triton-rocm = [
{ index = "pytorch-rocm", extra = "rocm" },
]
[[tool.uv.index]]
@@ -82,6 +92,11 @@ name = "pytorch-cuda"
url = "https://download.pytorch.org/whl/cu124"
explicit = true
[[tool.uv.index]]
name = "pytorch-rocm"
url = "https://download.pytorch.org/whl/nightly/rocm6.4"
explicit = true
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
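
As a usage note: cpu, gpu, and rocm are declared as conflicting extras, so only one can be active per sync, and uv resolves torch from the matching index. A hedged sketch with a hypothetical wrapper (not in the repo):

```python
import subprocess

def sync_for(accelerator: str) -> None:
    # Hypothetical helper: exactly one accelerator extra at a time;
    # the conflicts table above stops uv from mixing wheel indexes.
    assert accelerator in {"cpu", "gpu", "rocm"}
    subprocess.run(["uv", "sync", "--extra", accelerator], check=True)

sync_for("rocm")  # torch comes from the pytorch-rocm nightly index
```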
@@ -93,5 +108,5 @@ packages.find = {where = ["api/src"], namespaces = true}
[tool.pytest.ini_options]
testpaths = ["api/tests", "ui/tests"]
python_files = ["test_*.py"]
addopts = "--cov=api --cov=ui --cov-report=term-missing --cov-config=.coveragerc --full-trace"
addopts = "--cov=api --cov=ui --cov-report=term-missing --cov-config=.coveragerc --full-trace"
asyncio_mode = "auto"