mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-21 05:44:06 +00:00
Merge 67ac35f7bf
into dd8aa26813
This commit is contained in:
commit
3244ef5f9c
8 changed files with 234 additions and 49 deletions
8
.github/workflows/release.yml
vendored
8
.github/workflows/release.yml
vendored
|
@ -28,6 +28,7 @@ jobs:
|
||||||
build-images:
|
build-images:
|
||||||
needs: prepare-release
|
needs: prepare-release
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
timeout-minutes: 60
|
||||||
permissions:
|
permissions:
|
||||||
packages: write # Needed to push images to GHCR
|
packages: write # Needed to push images to GHCR
|
||||||
env:
|
env:
|
||||||
|
@ -35,6 +36,9 @@ jobs:
|
||||||
BUILDKIT_STEP_LOG_MAX_SIZE: 10485760
|
BUILDKIT_STEP_LOG_MAX_SIZE: 10485760
|
||||||
# This environment variable will override the VERSION variable in docker-bake.hcl
|
# This environment variable will override the VERSION variable in docker-bake.hcl
|
||||||
VERSION: ${{ needs.prepare-release.outputs.version_tag }} # Use tag version (vX.Y.Z) for bake
|
VERSION: ${{ needs.prepare-release.outputs.version_tag }} # Use tag version (vX.Y.Z) for bake
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
build_target: ["cpu", "cpu-arm64", "gpu-arm64", "gpu", "rocm"]
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
@ -60,7 +64,7 @@ jobs:
|
||||||
df -h
|
df -h
|
||||||
echo "Cleaning up disk space..."
|
echo "Cleaning up disk space..."
|
||||||
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache
|
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache
|
||||||
docker system prune -af
|
sudo docker system prune -af
|
||||||
echo "Disk space after cleanup"
|
echo "Disk space after cleanup"
|
||||||
df -h
|
df -h
|
||||||
|
|
||||||
|
@ -85,7 +89,7 @@ jobs:
|
||||||
run: |
|
run: |
|
||||||
echo "Building and pushing images for version ${{ needs.prepare-release.outputs.version_tag }}"
|
echo "Building and pushing images for version ${{ needs.prepare-release.outputs.version_tag }}"
|
||||||
# The VERSION env var above sets the tag for the bake file targets
|
# The VERSION env var above sets the tag for the bake file targets
|
||||||
docker buildx bake --push
|
docker buildx bake ${{ matrix.build_target }} --push
|
||||||
|
|
||||||
create-release:
|
create-release:
|
||||||
needs: [prepare-release, build-images]
|
needs: [prepare-release, build-images]
|
||||||
|
|
|
@ -497,4 +497,4 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st
|
||||||
|
|
||||||
text = re.sub(r"\s{2,}", " ", text)
|
text = re.sub(r"\s{2,}", " ", text)
|
||||||
|
|
||||||
return text.strip()
|
return text
|
||||||
|
|
|
@ -15,7 +15,7 @@ from .vocabulary import tokenize
|
||||||
# Pre-compiled regex patterns for performance
|
# Pre-compiled regex patterns for performance
|
||||||
# Updated regex to be more strict and avoid matching isolated brackets
|
# Updated regex to be more strict and avoid matching isolated brackets
|
||||||
# Only matches complete patterns like [word](/ipa/) and prevents catastrophic backtracking
|
# Only matches complete patterns like [word](/ipa/) and prevents catastrophic backtracking
|
||||||
CUSTOM_PHONEMES = re.compile(r"(\[[^\[\]]*?\])(\(\/[^\/\(\)]*?\/\))")
|
CUSTOM_PHONEMES = re.compile(r"(\[[^\[\]]*?\]\(\/[^\/\(\)]*?\/\))")
|
||||||
# Pattern to find pause tags like [pause:0.5s]
|
# Pattern to find pause tags like [pause:0.5s]
|
||||||
PAUSE_TAG_PATTERN = re.compile(r"\[pause:(\d+(?:\.\d+)?)s\]", re.IGNORECASE)
|
PAUSE_TAG_PATTERN = re.compile(r"\[pause:(\d+(?:\.\d+)?)s\]", re.IGNORECASE)
|
||||||
|
|
||||||
|
@ -100,7 +100,7 @@ def process_text(text: str, language: str = "a") -> List[int]:
|
||||||
|
|
||||||
|
|
||||||
def get_sentence_info(
|
def get_sentence_info(
|
||||||
text: str, custom_phenomes_list: Dict[str, str], lang_code: str = "a"
|
text: str, lang_code: str = "a"
|
||||||
) -> List[Tuple[str, List[int], int]]:
|
) -> List[Tuple[str, List[int], int]]:
|
||||||
"""Process all sentences and return info"""
|
"""Process all sentences and return info"""
|
||||||
# Detect Chinese text
|
# Detect Chinese text
|
||||||
|
@ -110,18 +110,10 @@ def get_sentence_info(
|
||||||
sentences = re.split(r"([,。!?;])+", text)
|
sentences = re.split(r"([,。!?;])+", text)
|
||||||
else:
|
else:
|
||||||
sentences = re.split(r"([.!?;:])(?=\s|$)", text)
|
sentences = re.split(r"([.!?;:])(?=\s|$)", text)
|
||||||
phoneme_length, min_value = len(custom_phenomes_list), 0
|
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for i in range(0, len(sentences), 2):
|
for i in range(0, len(sentences), 2):
|
||||||
sentence = sentences[i].strip()
|
sentence = sentences[i].strip()
|
||||||
for replaced in range(min_value, phoneme_length):
|
|
||||||
current_id = f"</|custom_phonemes_{replaced}|/>"
|
|
||||||
if current_id in sentence:
|
|
||||||
sentence = sentence.replace(
|
|
||||||
current_id, custom_phenomes_list.pop(current_id)
|
|
||||||
)
|
|
||||||
min_value += 1
|
|
||||||
punct = sentences[i + 1] if i + 1 < len(sentences) else ""
|
punct = sentences[i + 1] if i + 1 < len(sentences) else ""
|
||||||
if not sentence:
|
if not sentence:
|
||||||
continue
|
continue
|
||||||
|
@ -173,24 +165,23 @@ async def smart_split(
|
||||||
# Strip leading and trailing spaces to prevent pause tag splitting artifacts
|
# Strip leading and trailing spaces to prevent pause tag splitting artifacts
|
||||||
text_part_raw = text_part_raw.strip()
|
text_part_raw = text_part_raw.strip()
|
||||||
|
|
||||||
# Apply the original smart_split logic to this text part
|
|
||||||
custom_phoneme_list = {}
|
|
||||||
|
|
||||||
# Normalize text (original logic)
|
# Normalize text (original logic)
|
||||||
processed_text = text_part_raw
|
processed_text = text_part_raw
|
||||||
if settings.advanced_text_normalization and normalization_options.normalize:
|
if settings.advanced_text_normalization and normalization_options.normalize:
|
||||||
if lang_code in ["a", "b", "en-us", "en-gb"]:
|
if lang_code in ["a", "b", "en-us", "en-gb"]:
|
||||||
processed_text = CUSTOM_PHONEMES.sub(
|
processed_text = CUSTOM_PHONEMES.split(processed_text)
|
||||||
lambda s: handle_custom_phonemes(s, custom_phoneme_list), processed_text
|
for index in range(0, len(processed_text), 2):
|
||||||
)
|
processed_text[index] = normalize_text(processed_text[index], normalization_options)
|
||||||
processed_text = normalize_text(processed_text, normalization_options)
|
|
||||||
|
|
||||||
|
processed_text = "".join(processed_text).strip()
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Skipping text normalization as it is only supported for english"
|
"Skipping text normalization as it is only supported for english"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Process all sentences (original logic)
|
# Process all sentences (original logic)
|
||||||
sentences = get_sentence_info(processed_text, custom_phoneme_list, lang_code=lang_code)
|
sentences = get_sentence_info(processed_text, lang_code=lang_code)
|
||||||
|
|
||||||
current_chunk = []
|
current_chunk = []
|
||||||
current_tokens = []
|
current_tokens = []
|
||||||
|
|
|
@ -34,7 +34,7 @@ def test_process_text_chunk_phonemes():
|
||||||
def test_get_sentence_info():
|
def test_get_sentence_info():
|
||||||
"""Test sentence splitting and info extraction."""
|
"""Test sentence splitting and info extraction."""
|
||||||
text = "This is sentence one. This is sentence two! What about three?"
|
text = "This is sentence one. This is sentence two! What about three?"
|
||||||
results = get_sentence_info(text, {})
|
results = get_sentence_info(text)
|
||||||
|
|
||||||
assert len(results) == 3
|
assert len(results) == 3
|
||||||
for sentence, tokens, count in results:
|
for sentence, tokens, count in results:
|
||||||
|
@ -44,24 +44,6 @@ def test_get_sentence_info():
|
||||||
assert count == len(tokens)
|
assert count == len(tokens)
|
||||||
assert count > 0
|
assert count > 0
|
||||||
|
|
||||||
|
|
||||||
def test_get_sentence_info_phenomoes():
|
|
||||||
"""Test sentence splitting and info extraction."""
|
|
||||||
text = (
|
|
||||||
"This is sentence one. This is </|custom_phonemes_0|/> two! What about three?"
|
|
||||||
)
|
|
||||||
results = get_sentence_info(text, {"</|custom_phonemes_0|/>": r"sˈɛntᵊns"})
|
|
||||||
|
|
||||||
assert len(results) == 3
|
|
||||||
assert "sˈɛntᵊns" in results[1][0]
|
|
||||||
for sentence, tokens, count in results:
|
|
||||||
assert isinstance(sentence, str)
|
|
||||||
assert isinstance(tokens, list)
|
|
||||||
assert isinstance(count, int)
|
|
||||||
assert count == len(tokens)
|
|
||||||
assert count > 0
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_smart_split_short_text():
|
async def test_smart_split_short_text():
|
||||||
"""Test smart splitting with text under max tokens."""
|
"""Test smart splitting with text under max tokens."""
|
||||||
|
@ -74,6 +56,33 @@ async def test_smart_split_short_text():
|
||||||
assert isinstance(chunks[0][0], str)
|
assert isinstance(chunks[0][0], str)
|
||||||
assert isinstance(chunks[0][1], list)
|
assert isinstance(chunks[0][1], list)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_smart_custom_phenomes():
|
||||||
|
"""Test smart splitting with text under max tokens."""
|
||||||
|
text = "This is a short test sentence. [Kokoro](/kˈOkəɹO/) has a feature called custom phenomes. This is made possible by [Misaki](/misˈɑki/), the custom phenomizer that [Kokoro](/kˈOkəɹO/) version 1.0 uses"
|
||||||
|
chunks = []
|
||||||
|
async for chunk_text, chunk_tokens, pause_duration in smart_split(text):
|
||||||
|
chunks.append((chunk_text, chunk_tokens, pause_duration))
|
||||||
|
|
||||||
|
# Should have 1 chunks: text
|
||||||
|
assert len(chunks) == 1
|
||||||
|
|
||||||
|
# First chunk: text
|
||||||
|
assert chunks[0][2] is None # No pause
|
||||||
|
assert "This is a short test sentence. [Kokoro](/kˈOkəɹO/) has a feature called custom phenomes. This is made possible by [Misaki](/misˈɑki/), the custom phenomizer that [Kokoro](/kˈOkəɹO/) version one uses" in chunks[0][0]
|
||||||
|
assert len(chunks[0][1]) > 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_smart_split_only_phenomes():
|
||||||
|
"""Test input that is entirely made of phenome annotations."""
|
||||||
|
text = "[Kokoro](/kˈOkəɹO/) [Misaki 1.2](/misˈɑki/) [Test](/tɛst/)"
|
||||||
|
chunks = []
|
||||||
|
async for chunk_text, chunk_tokens, pause_duration in smart_split(text, max_tokens=10):
|
||||||
|
chunks.append((chunk_text, chunk_tokens, pause_duration))
|
||||||
|
|
||||||
|
assert len(chunks) == 1
|
||||||
|
assert "[Kokoro](/kˈOkəɹO/) [Misaki 1.2](/misˈɑki/) [Test](/tɛst/)" in chunks[0][0]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_smart_split_long_text():
|
async def test_smart_split_long_text():
|
||||||
|
@ -116,7 +125,7 @@ def test_process_text_chunk_chinese_phonemes():
|
||||||
def test_get_sentence_info_chinese():
|
def test_get_sentence_info_chinese():
|
||||||
"""Test Chinese sentence splitting and info extraction."""
|
"""Test Chinese sentence splitting and info extraction."""
|
||||||
text = "这是一个句子。这是第二个句子!第三个问题?"
|
text = "这是一个句子。这是第二个句子!第三个问题?"
|
||||||
results = get_sentence_info(text, {}, lang_code="z")
|
results = get_sentence_info(text, lang_code="z")
|
||||||
|
|
||||||
assert len(results) == 3
|
assert len(results) == 3
|
||||||
for sentence, tokens, count in results:
|
for sentence, tokens, count in results:
|
||||||
|
|
|
@ -40,10 +40,25 @@ target "_gpu_base" {
|
||||||
dockerfile = "docker/gpu/Dockerfile"
|
dockerfile = "docker/gpu/Dockerfile"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Base settings for AMD ROCm builds
|
||||||
|
target "_rocm_base" {
|
||||||
|
inherits = ["_common"]
|
||||||
|
dockerfile = "docker/rocm/Dockerfile"
|
||||||
|
}
|
||||||
|
|
||||||
# CPU target with multi-platform support
|
# CPU target with multi-platform support
|
||||||
target "cpu" {
|
target "cpu" {
|
||||||
inherits = ["_cpu_base"]
|
inherits = ["_cpu_base"]
|
||||||
platforms = ["linux/amd64", "linux/arm64"]
|
platforms = ["linux/amd64"]
|
||||||
|
tags = [
|
||||||
|
"${REGISTRY}/${OWNER}/${REPO}-cpu:${VERSION}",
|
||||||
|
"${REGISTRY}/${OWNER}/${REPO}-cpu:latest"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
target "cpu-arm64" {
|
||||||
|
inherits = ["_cpu_base"]
|
||||||
|
platforms = ["linux/arm64"]
|
||||||
tags = [
|
tags = [
|
||||||
"${REGISTRY}/${OWNER}/${REPO}-cpu:${VERSION}",
|
"${REGISTRY}/${OWNER}/${REPO}-cpu:${VERSION}",
|
||||||
"${REGISTRY}/${OWNER}/${REPO}-cpu:latest"
|
"${REGISTRY}/${OWNER}/${REPO}-cpu:latest"
|
||||||
|
@ -53,16 +68,51 @@ target "cpu" {
|
||||||
# GPU target with multi-platform support
|
# GPU target with multi-platform support
|
||||||
target "gpu" {
|
target "gpu" {
|
||||||
inherits = ["_gpu_base"]
|
inherits = ["_gpu_base"]
|
||||||
platforms = ["linux/amd64", "linux/arm64"]
|
platforms = ["linux/amd64"]
|
||||||
tags = [
|
tags = [
|
||||||
"${REGISTRY}/${OWNER}/${REPO}-gpu:${VERSION}",
|
"${REGISTRY}/${OWNER}/${REPO}-gpu:${VERSION}",
|
||||||
"${REGISTRY}/${OWNER}/${REPO}-gpu:latest"
|
"${REGISTRY}/${OWNER}/${REPO}-gpu:latest"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
# Default group to build both CPU and GPU versions
|
target "gpu-arm64" {
|
||||||
group "default" {
|
inherits = ["_gpu_base"]
|
||||||
targets = ["cpu", "gpu"]
|
platforms = ["linux/arm64"]
|
||||||
|
tags = [
|
||||||
|
"${REGISTRY}/${OWNER}/${REPO}-gpu:${VERSION}",
|
||||||
|
"${REGISTRY}/${OWNER}/${REPO}-gpu:latest"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
# AMD ROCm target with multi-platform support
|
||||||
|
target "rocm" {
|
||||||
|
inherits = ["_rocm_base"]
|
||||||
|
platforms = ["linux/amd64"]
|
||||||
|
tags = [
|
||||||
|
"${REGISTRY}/${OWNER}/${REPO}-rocm:${VERSION}",
|
||||||
|
"${REGISTRY}/${OWNER}/${REPO}-rocm:latest"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
# Build groups for parallel builds
|
||||||
|
group "cpu" {
|
||||||
|
targets = ["cpu"]
|
||||||
|
}
|
||||||
|
|
||||||
|
group "cpu-arm64" {
|
||||||
|
targets = ["cpu-arm64"]
|
||||||
|
}
|
||||||
|
|
||||||
|
group "gpu-arm64" {
|
||||||
|
targets = ["gpu-arm64"]
|
||||||
|
}
|
||||||
|
|
||||||
|
group "gpu" {
|
||||||
|
targets = ["gpu"]
|
||||||
|
}
|
||||||
|
|
||||||
|
group "rocm" {
|
||||||
|
targets = ["rocm"]
|
||||||
}
|
}
|
||||||
|
|
||||||
# Development targets for faster local builds
|
# Development targets for faster local builds
|
||||||
|
@ -78,6 +128,12 @@ target "gpu-dev" {
|
||||||
tags = ["${REGISTRY}/${OWNER}/${REPO}-gpu:dev"]
|
tags = ["${REGISTRY}/${OWNER}/${REPO}-gpu:dev"]
|
||||||
}
|
}
|
||||||
|
|
||||||
group "dev" {
|
target "rocm-dev" {
|
||||||
targets = ["cpu-dev", "gpu-dev"]
|
inherits = ["_rocm_base"]
|
||||||
|
# No multi-platform for dev builds
|
||||||
|
tags = ["${REGISTRY}/${OWNER}/${REPO}-rocm:dev"]
|
||||||
|
}
|
||||||
|
|
||||||
|
group "dev" {
|
||||||
|
targets = ["cpu-dev", "gpu-dev", "rocm-dev"]
|
||||||
}
|
}
|
72
docker/rocm/Dockerfile
Normal file
72
docker/rocm/Dockerfile
Normal file
|
@ -0,0 +1,72 @@
|
||||||
|
FROM rocm/dev-ubuntu-24.04:6.4.1
|
||||||
|
ENV DEBIAN_FRONTEND=noninteractive \
|
||||||
|
PHONEMIZER_ESPEAK_PATH=/usr/bin \
|
||||||
|
PHONEMIZER_ESPEAK_DATA=/usr/share/espeak-ng-data \
|
||||||
|
ESPEAK_DATA_PATH=/usr/share/espeak-ng-data
|
||||||
|
|
||||||
|
# Install Python and other dependencies
|
||||||
|
RUN apt-get update && apt upgrade -y && apt-get install -y --no-install-recommends \
|
||||||
|
espeak-ng \
|
||||||
|
espeak-ng-data \
|
||||||
|
git \
|
||||||
|
libsndfile1 \
|
||||||
|
curl \
|
||||||
|
ffmpeg \
|
||||||
|
wget \
|
||||||
|
nano \
|
||||||
|
g++ \
|
||||||
|
&& apt-get clean \
|
||||||
|
&& rm -rf /var/lib/apt/lists/* \
|
||||||
|
&& mkdir -p /usr/share/espeak-ng-data \
|
||||||
|
&& ln -s /usr/lib/*/espeak-ng-data/* /usr/share/espeak-ng-data/
|
||||||
|
|
||||||
|
RUN mkdir -p /app/api/src/models/v1_0
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Install UV using the installer script
|
||||||
|
RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
||||||
|
mv /root/.local/bin/uv /usr/local/bin/ && \
|
||||||
|
mv /root/.local/bin/uvx /usr/local/bin/
|
||||||
|
|
||||||
|
# Create non-root user and set up directories and permissions
|
||||||
|
RUN useradd -m -u 1001 appuser && \
|
||||||
|
mkdir -p /app/api/src/models/v1_0 && \
|
||||||
|
chown -R appuser:appuser /app
|
||||||
|
|
||||||
|
USER appuser
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Copy dependency files
|
||||||
|
COPY --chown=appuser:appuser pyproject.toml ./pyproject.toml
|
||||||
|
|
||||||
|
ENV PHONEMIZER_ESPEAK_PATH=/usr/bin \
|
||||||
|
PHONEMIZER_ESPEAK_DATA=/usr/share/espeak-ng-data \
|
||||||
|
ESPEAK_DATA_PATH=/usr/share/espeak-ng-data
|
||||||
|
|
||||||
|
# Install dependencies with GPU extras (using cache mounts)
|
||||||
|
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||||
|
uv venv --python 3.10 && \
|
||||||
|
uv sync --extra rocm
|
||||||
|
|
||||||
|
# Copy project files including models
|
||||||
|
COPY --chown=appuser:appuser api ./api
|
||||||
|
COPY --chown=appuser:appuser web ./web
|
||||||
|
COPY --chown=appuser:appuser docker/scripts/ ./
|
||||||
|
RUN chmod +x ./entrypoint.sh
|
||||||
|
|
||||||
|
# Set all environment variables in one go
|
||||||
|
ENV PYTHONUNBUFFERED=1 \
|
||||||
|
PYTHONPATH=/app:/app/api \
|
||||||
|
PATH="/app/.venv/bin:$PATH" \
|
||||||
|
UV_LINK_MODE=copy \
|
||||||
|
USE_GPU=true
|
||||||
|
|
||||||
|
ENV DOWNLOAD_MODEL=true
|
||||||
|
# Download model if enabled
|
||||||
|
RUN if [ "$DOWNLOAD_MODEL" = "true" ]; then \
|
||||||
|
python download_model.py --output api/src/models/v1_0; \
|
||||||
|
fi
|
||||||
|
|
||||||
|
ENV DEVICE="gpu"
|
||||||
|
# Run FastAPI server through entrypoint.sh
|
||||||
|
CMD ["./entrypoint.sh"]
|
38
docker/rocm/docker-compose.yml
Normal file
38
docker/rocm/docker-compose.yml
Normal file
|
@ -0,0 +1,38 @@
|
||||||
|
services:
|
||||||
|
kokoro-tts:
|
||||||
|
image: kprinssu/kokoro-fastapi:rocm
|
||||||
|
devices:
|
||||||
|
- /dev/dri
|
||||||
|
- /dev/kfd
|
||||||
|
security_opt:
|
||||||
|
- seccomp:unconfined
|
||||||
|
cap_add:
|
||||||
|
- SYS_PTRACE
|
||||||
|
group_add:
|
||||||
|
# NOTE: These groups are the group ids for: video, input, and render
|
||||||
|
# Numbers can be found via running: getent group $GROUP_NAME | cut -d: -f3
|
||||||
|
- 44
|
||||||
|
- 993
|
||||||
|
- 996
|
||||||
|
restart: 'always'
|
||||||
|
volumes:
|
||||||
|
- ./kokoro-tts/config:/root/.config/miopen
|
||||||
|
- ./kokoro-tts/cache:/root/.cache/miopen
|
||||||
|
ports:
|
||||||
|
- 8880:8880
|
||||||
|
environment:
|
||||||
|
- USE_GPU=true
|
||||||
|
- TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1
|
||||||
|
# IMPORTANT: ROCm's MIOpen libray will be slow if it has to figure out the optimal kernel shapes for each model
|
||||||
|
# See documentation on performancing tuning: https://github.com/ROCm/MIOpen/blob/develop/docs/conceptual/tuningdb.rst
|
||||||
|
# The volumes above cache the MIOpen shape files and user database for subsequent runs
|
||||||
|
#
|
||||||
|
# Steps:
|
||||||
|
# 1. Run Kokoro once with the following environment variables set:
|
||||||
|
# - MIOPEN_FIND_MODE=3
|
||||||
|
# - MIOPEN_FIND_ENFORCE=3
|
||||||
|
# 2. Generate various recordings using sample data (e.g. first couple paragraphs of Dracula); this will be slow
|
||||||
|
# 3. Comment out/remove the previously set environment variables
|
||||||
|
# 4. Add the following environment variables to enable caching of model shapes:
|
||||||
|
# - MIOPEN_FIND_MODE=2
|
||||||
|
# 5. Restart the container and run Kokoro again, it should be much faster
|
|
@ -46,6 +46,10 @@ dependencies = [
|
||||||
gpu = [
|
gpu = [
|
||||||
"torch==2.6.0+cu124",
|
"torch==2.6.0+cu124",
|
||||||
]
|
]
|
||||||
|
rocm = [
|
||||||
|
"torch==2.8.0.dev20250627+rocm6.4",
|
||||||
|
"pytorch-triton-rocm>=3.2.0",
|
||||||
|
]
|
||||||
cpu = [
|
cpu = [
|
||||||
"torch==2.6.0",
|
"torch==2.6.0",
|
||||||
]
|
]
|
||||||
|
@ -63,13 +67,19 @@ conflicts = [
|
||||||
[
|
[
|
||||||
{ extra = "cpu" },
|
{ extra = "cpu" },
|
||||||
{ extra = "gpu" },
|
{ extra = "gpu" },
|
||||||
|
{ extra = "rocm" },
|
||||||
],
|
],
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
[tool.uv.sources]
|
[tool.uv.sources]
|
||||||
torch = [
|
torch = [
|
||||||
{ index = "pytorch-cpu", extra = "cpu" },
|
{ index = "pytorch-cpu", extra = "cpu" },
|
||||||
{ index = "pytorch-cuda", extra = "gpu" },
|
{ index = "pytorch-cuda", extra = "gpu" },
|
||||||
|
{ index = "pytorch-rocm", extra = "rocm" },
|
||||||
|
]
|
||||||
|
pytorch-triton-rocm = [
|
||||||
|
{ index = "pytorch-rocm", extra = "rocm" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[tool.uv.index]]
|
[[tool.uv.index]]
|
||||||
|
@ -82,6 +92,11 @@ name = "pytorch-cuda"
|
||||||
url = "https://download.pytorch.org/whl/cu124"
|
url = "https://download.pytorch.org/whl/cu124"
|
||||||
explicit = true
|
explicit = true
|
||||||
|
|
||||||
|
[[tool.uv.index]]
|
||||||
|
name = "pytorch-rocm"
|
||||||
|
url = "https://download.pytorch.org/whl/nightly/rocm6.4"
|
||||||
|
explicit = true
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["setuptools>=61.0"]
|
requires = ["setuptools>=61.0"]
|
||||||
build-backend = "setuptools.build_meta"
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
Loading…
Add table
Reference in a new issue