From ac7947b51aa3c4a01cc7a7b8132caddb2957c2e4 Mon Sep 17 00:00:00 2001
From: remsky
Date: Thu, 6 Feb 2025 23:43:26 -0700
Subject: [PATCH] Refactor Docker configurations for GPU and CPU, update test
paths, and remove deprecated tests
---
README.md | 4 +-
api/src/main.py | 1 +
api/tests/test_kokoro_v1.py | 75 +++++++++++
api/tests/test_openai_endpoints.py | 11 +-
api/tests/test_paths.py | 116 ++++++++++++++++
api/tests/test_text_processor.py | 80 +++++++++++
api/tests/test_tts_service.py | 104 ++++++++++++++
api/tests/test_tts_service_new.py | 142 --------------------
docker-bake.hcl | 80 +++++++++++
docker/build.sh | 32 +----
docker/cpu/Dockerfile | 17 +--
docker/cpu/docker-compose.yml | 3 +-
docker/gpu/Dockerfile | 27 ++--
docker/gpu/docker-compose.yml | 2 +-
docker/shared/pyproject.toml | 45 -------
pyproject.toml | 15 +--
pytest.ini | 4 +-
ui/{tests => depr_tests}/conftest.py | 0
ui/{tests => depr_tests}/test_api.py | 0
ui/{tests => depr_tests}/test_components.py | 0
ui/{tests => depr_tests}/test_files.py | 0
ui/{tests => depr_tests}/test_handlers.py | 0
ui/{tests => depr_tests}/test_input.py | 0
ui/{tests => depr_tests}/test_interface.py | 0
24 files changed, 495 insertions(+), 263 deletions(-)
create mode 100644 api/tests/test_kokoro_v1.py
create mode 100644 api/tests/test_paths.py
create mode 100644 api/tests/test_text_processor.py
create mode 100644 api/tests/test_tts_service.py
delete mode 100644 api/tests/test_tts_service_new.py
create mode 100644 docker-bake.hcl
delete mode 100644 docker/shared/pyproject.toml
rename ui/{tests => depr_tests}/conftest.py (100%)
rename ui/{tests => depr_tests}/test_api.py (100%)
rename ui/{tests => depr_tests}/test_components.py (100%)
rename ui/{tests => depr_tests}/test_files.py (100%)
rename ui/{tests => depr_tests}/test_handlers.py (100%)
rename ui/{tests => depr_tests}/test_input.py (100%)
rename ui/{tests => depr_tests}/test_interface.py (100%)
diff --git a/README.md b/README.md
index a049c8d..993ecc9 100644
--- a/README.md
+++ b/README.md
@@ -3,8 +3,8 @@
# _`FastKoko`_
-[]()
-[]()
+[]()
+[]()
[](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero)
[](https://huggingface.co/hexgrad/Kokoro-82M/commit/9901c2b79161b6e898b7ea857ae5298f47b8b0d6)
diff --git a/api/src/main.py b/api/src/main.py
index 1a720e8..5e6a32c 100644
--- a/api/src/main.py
+++ b/api/src/main.py
@@ -93,6 +93,7 @@ Model files not found! You need to download the Kokoro V1 model:
{boundary}
"""
startup_msg += f"\nModel warmed up on {device}: {model}"
+ startup_msg += f"CUDA: {torch.cuda.is_available()}"
startup_msg += f"\n{voicepack_count} voice packs loaded"
# Add web player info if enabled
diff --git a/api/tests/test_kokoro_v1.py b/api/tests/test_kokoro_v1.py
new file mode 100644
index 0000000..ad6a7d0
--- /dev/null
+++ b/api/tests/test_kokoro_v1.py
@@ -0,0 +1,75 @@
+import pytest
+from unittest.mock import patch, MagicMock
+import torch
+import numpy as np
+from api.src.inference.kokoro_v1 import KokoroV1
+
+@pytest.fixture
+def kokoro_backend():
+ """Create a KokoroV1 instance for testing."""
+ return KokoroV1()
+
+def test_initial_state(kokoro_backend):
+ """Test initial state of KokoroV1."""
+ assert not kokoro_backend.is_loaded
+ assert kokoro_backend._model is None
+ assert kokoro_backend._pipeline is None
+ # Device should be set based on settings
+ assert kokoro_backend.device in ["cuda", "cpu"]
+
+@patch('torch.cuda.is_available', return_value=True)
+@patch('torch.cuda.memory_allocated')
+def test_memory_management(mock_memory, mock_cuda, kokoro_backend):
+ """Test GPU memory management functions."""
+ # Mock GPU memory usage
+ mock_memory.return_value = 5e9 # 5GB
+
+ # Test memory check
+ with patch('api.src.inference.kokoro_v1.model_config') as mock_config:
+ mock_config.pytorch_gpu.memory_threshold = 4
+ assert kokoro_backend._check_memory() == True
+
+ mock_config.pytorch_gpu.memory_threshold = 6
+ assert kokoro_backend._check_memory() == False
+
+@patch('torch.cuda.empty_cache')
+@patch('torch.cuda.synchronize')
+def test_clear_memory(mock_sync, mock_clear, kokoro_backend):
+ """Test memory clearing."""
+ with patch.object(kokoro_backend, '_device', 'cuda'):
+ kokoro_backend._clear_memory()
+ mock_clear.assert_called_once()
+ mock_sync.assert_called_once()
+
+@pytest.mark.asyncio
+async def test_load_model_validation(kokoro_backend):
+ """Test model loading validation."""
+ with pytest.raises(RuntimeError, match="Failed to load Kokoro model"):
+ await kokoro_backend.load_model("nonexistent_model.pth")
+
+def test_unload(kokoro_backend):
+ """Test model unloading."""
+ # Mock loaded state
+ kokoro_backend._model = MagicMock()
+ kokoro_backend._pipeline = MagicMock()
+ assert kokoro_backend.is_loaded
+
+ # Test unload
+ kokoro_backend.unload()
+ assert not kokoro_backend.is_loaded
+ assert kokoro_backend._model is None
+ assert kokoro_backend._pipeline is None
+
+@pytest.mark.asyncio
+async def test_generate_validation(kokoro_backend):
+ """Test generation validation."""
+ with pytest.raises(RuntimeError, match="Model not loaded"):
+ async for _ in kokoro_backend.generate("test", "voice"):
+ pass
+
+@pytest.mark.asyncio
+async def test_generate_from_tokens_validation(kokoro_backend):
+ """Test token generation validation."""
+ with pytest.raises(RuntimeError, match="Model not loaded"):
+ async for _ in kokoro_backend.generate_from_tokens("test tokens", "voice"):
+ pass
\ No newline at end of file
diff --git a/api/tests/test_openai_endpoints.py b/api/tests/test_openai_endpoints.py
index 8e28fe7..7e9386f 100644
--- a/api/tests/test_openai_endpoints.py
+++ b/api/tests/test_openai_endpoints.py
@@ -330,16 +330,19 @@ def test_list_voices(mock_tts_service):
assert "voice1" in data["voices"]
assert "voice2" in data["voices"]
-def test_combine_voices(mock_tts_service):
+@patch('api.src.routers.openai_compatible.settings')
+def test_combine_voices(mock_settings, mock_tts_service):
"""Test combining voices endpoint"""
+ # Enable local voice saving for this test
+ mock_settings.allow_local_voice_saving = True
+
response = client.post(
"/v1/audio/voices/combine",
json="voice1+voice2"
)
assert response.status_code == 200
- data = response.json()
- assert "voice" in data
- assert data["voice"] == "voice1_voice2"
+ assert response.headers["content-type"] == "application/octet-stream"
+ assert "voice1+voice2.pt" in response.headers["content-disposition"]
def test_server_error(mock_tts_service, test_voice):
"""Test handling of server errors"""
diff --git a/api/tests/test_paths.py b/api/tests/test_paths.py
new file mode 100644
index 0000000..209cffb
--- /dev/null
+++ b/api/tests/test_paths.py
@@ -0,0 +1,116 @@
+import os
+import pytest
+from unittest.mock import patch
+from api.src.core.paths import (
+ _find_file,
+ _scan_directories,
+ get_content_type,
+ get_temp_file_path,
+ list_temp_files,
+ get_temp_dir_size
+)
+
+@pytest.mark.asyncio
+async def test_find_file_exists():
+ """Test finding existing file."""
+ with patch('aiofiles.os.path.exists') as mock_exists:
+ mock_exists.return_value = True
+ path = await _find_file("test.txt", ["/test/path"])
+ assert path == "/test/path/test.txt"
+
+@pytest.mark.asyncio
+async def test_find_file_not_exists():
+ """Test finding non-existent file."""
+ with patch('aiofiles.os.path.exists') as mock_exists:
+ mock_exists.return_value = False
+ with pytest.raises(RuntimeError, match="File not found"):
+ await _find_file("test.txt", ["/test/path"])
+
+@pytest.mark.asyncio
+async def test_find_file_with_filter():
+ """Test finding file with filter function."""
+ with patch('aiofiles.os.path.exists') as mock_exists:
+ mock_exists.return_value = True
+ filter_fn = lambda p: p.endswith('.txt')
+ path = await _find_file("test.txt", ["/test/path"], filter_fn)
+ assert path == "/test/path/test.txt"
+
+@pytest.mark.asyncio
+async def test_scan_directories():
+ """Test scanning directories."""
+ mock_entry = type('MockEntry', (), {'name': 'test.txt'})()
+
+ with patch('aiofiles.os.path.exists') as mock_exists, \
+ patch('aiofiles.os.scandir') as mock_scandir:
+ mock_exists.return_value = True
+ mock_scandir.return_value = [mock_entry]
+
+ files = await _scan_directories(["/test/path"])
+ assert "test.txt" in files
+
+@pytest.mark.asyncio
+async def test_get_content_type():
+ """Test content type detection."""
+ test_cases = [
+ ("test.html", "text/html"),
+ ("test.js", "application/javascript"),
+ ("test.css", "text/css"),
+ ("test.png", "image/png"),
+ ("test.unknown", "application/octet-stream"),
+ ]
+
+ for filename, expected in test_cases:
+ content_type = await get_content_type(filename)
+ assert content_type == expected
+
+@pytest.mark.asyncio
+async def test_get_temp_file_path():
+ """Test temp file path generation."""
+ with patch('aiofiles.os.path.exists') as mock_exists, \
+ patch('aiofiles.os.makedirs') as mock_makedirs:
+ mock_exists.return_value = False
+
+ path = await get_temp_file_path("test.wav")
+ assert "test.wav" in path
+ mock_makedirs.assert_called_once()
+
+@pytest.mark.asyncio
+async def test_list_temp_files():
+ """Test listing temp files."""
+ class MockEntry:
+ def __init__(self, name):
+ self.name = name
+ def is_file(self):
+ return True
+
+ mock_entry = MockEntry('test.wav')
+
+ with patch('aiofiles.os.path.exists') as mock_exists, \
+ patch('aiofiles.os.scandir') as mock_scandir:
+ mock_exists.return_value = True
+ mock_scandir.return_value = [mock_entry]
+
+ files = await list_temp_files()
+ assert "test.wav" in files
+
+@pytest.mark.asyncio
+async def test_get_temp_dir_size():
+ """Test getting temp directory size."""
+ class MockEntry:
+ def __init__(self, path):
+ self.path = path
+ def is_file(self):
+ return True
+
+ mock_entry = MockEntry('/tmp/test.wav')
+ mock_stat = type('MockStat', (), {'st_size': 1024})()
+
+ with patch('aiofiles.os.path.exists') as mock_exists, \
+ patch('aiofiles.os.scandir') as mock_scandir, \
+ patch('aiofiles.os.stat') as mock_stat_fn:
+ mock_exists.return_value = True
+ mock_scandir.return_value = [mock_entry]
+ mock_stat_fn.return_value = mock_stat
+
+ size = await get_temp_dir_size()
+ assert size == 1024
\ No newline at end of file
diff --git a/api/tests/test_text_processor.py b/api/tests/test_text_processor.py
new file mode 100644
index 0000000..4b1d37d
--- /dev/null
+++ b/api/tests/test_text_processor.py
@@ -0,0 +1,80 @@
+import pytest
+from api.src.services.text_processing.text_processor import (
+ process_text_chunk,
+ get_sentence_info,
+ smart_split
+)
+
+def test_process_text_chunk_basic():
+ """Test basic text chunk processing."""
+ text = "Hello world"
+ tokens = process_text_chunk(text)
+ assert isinstance(tokens, list)
+ assert len(tokens) > 0
+
+def test_process_text_chunk_empty():
+ """Test processing empty text."""
+ text = ""
+ tokens = process_text_chunk(text)
+ assert isinstance(tokens, list)
+ assert len(tokens) == 0
+
+def test_process_text_chunk_phonemes():
+ """Test processing with skip_phonemize."""
+ phonemes = "h @ l @U" # Example phoneme sequence
+ tokens = process_text_chunk(phonemes, skip_phonemize=True)
+ assert isinstance(tokens, list)
+ assert len(tokens) > 0
+
+def test_get_sentence_info():
+ """Test sentence splitting and info extraction."""
+ text = "This is sentence one. This is sentence two! What about three?"
+ results = get_sentence_info(text)
+
+ assert len(results) == 3
+ for sentence, tokens, count in results:
+ assert isinstance(sentence, str)
+ assert isinstance(tokens, list)
+ assert isinstance(count, int)
+ assert count == len(tokens)
+ assert count > 0
+
+@pytest.mark.asyncio
+async def test_smart_split_short_text():
+ """Test smart splitting with text under max tokens."""
+ text = "This is a short test sentence."
+ chunks = []
+ async for chunk_text, chunk_tokens in smart_split(text):
+ chunks.append((chunk_text, chunk_tokens))
+
+ assert len(chunks) == 1
+ assert isinstance(chunks[0][0], str)
+ assert isinstance(chunks[0][1], list)
+
+@pytest.mark.asyncio
+async def test_smart_split_long_text():
+ """Test smart splitting with longer text."""
+ # Create text that should split into multiple chunks
+ text = ". ".join(["This is test sentence number " + str(i) for i in range(20)])
+
+ chunks = []
+ async for chunk_text, chunk_tokens in smart_split(text):
+ chunks.append((chunk_text, chunk_tokens))
+
+ assert len(chunks) > 1
+ for chunk_text, chunk_tokens in chunks:
+ assert isinstance(chunk_text, str)
+ assert isinstance(chunk_tokens, list)
+ assert len(chunk_tokens) > 0
+
+@pytest.mark.asyncio
+async def test_smart_split_with_punctuation():
+ """Test smart splitting handles punctuation correctly."""
+ text = "First sentence! Second sentence? Third sentence; Fourth sentence: Fifth sentence."
+
+ chunks = []
+ async for chunk_text, chunk_tokens in smart_split(text):
+ chunks.append(chunk_text)
+
+ # Verify punctuation is preserved
+ assert all(any(p in chunk for p in "!?;:." ) for chunk in chunks)
\ No newline at end of file
diff --git a/api/tests/test_tts_service.py b/api/tests/test_tts_service.py
new file mode 100644
index 0000000..5f4129a
--- /dev/null
+++ b/api/tests/test_tts_service.py
@@ -0,0 +1,104 @@
+import pytest
+import numpy as np
+import torch
+from unittest.mock import AsyncMock, patch, MagicMock
+from api.src.services.tts_service import TTSService
+
+@pytest.fixture
+def mock_managers():
+ """Mock model and voice managers."""
+ async def _mock_managers():
+ model_manager = AsyncMock()
+ model_manager.get_backend.return_value = MagicMock()
+
+ voice_manager = AsyncMock()
+ voice_manager.get_voice_path.return_value = "/path/to/voice.pt"
+ voice_manager.list_voices.return_value = ["voice1", "voice2"]
+
+ with patch("api.src.services.tts_service.get_model_manager") as mock_get_model, \
+ patch("api.src.services.tts_service.get_voice_manager") as mock_get_voice:
+ mock_get_model.return_value = model_manager
+ mock_get_voice.return_value = voice_manager
+ return model_manager, voice_manager
+ return _mock_managers()
+
+@pytest.fixture
+def tts_service(mock_managers):
+ """Create TTSService instance with mocked dependencies."""
+ async def _create_service():
+ return await TTSService.create("test_output")
+ return _create_service()
+
+@pytest.mark.asyncio
+async def test_service_creation():
+ """Test service creation and initialization."""
+ model_manager = AsyncMock()
+ voice_manager = AsyncMock()
+
+ with patch("api.src.services.tts_service.get_model_manager") as mock_get_model, \
+ patch("api.src.services.tts_service.get_voice_manager") as mock_get_voice:
+ mock_get_model.return_value = model_manager
+ mock_get_voice.return_value = voice_manager
+
+ service = await TTSService.create("test_output")
+ assert service.output_dir == "test_output"
+ assert service.model_manager is model_manager
+ assert service._voice_manager is voice_manager
+
+@pytest.mark.asyncio
+async def test_get_voice_path_single():
+ """Test getting path for single voice."""
+ model_manager = AsyncMock()
+ voice_manager = AsyncMock()
+ voice_manager.get_voice_path.return_value = "/path/to/voice1.pt"
+
+ with patch("api.src.services.tts_service.get_model_manager") as mock_get_model, \
+ patch("api.src.services.tts_service.get_voice_manager") as mock_get_voice:
+ mock_get_model.return_value = model_manager
+ mock_get_voice.return_value = voice_manager
+
+ service = await TTSService.create("test_output")
+ name, path = await service._get_voice_path("voice1")
+ assert name == "voice1"
+ assert path == "/path/to/voice1.pt"
+ voice_manager.get_voice_path.assert_called_once_with("voice1")
+
+@pytest.mark.asyncio
+async def test_get_voice_path_combined():
+ """Test getting path for combined voices."""
+ model_manager = AsyncMock()
+ voice_manager = AsyncMock()
+ voice_manager.get_voice_path.return_value = "/path/to/voice.pt"
+
+ with patch("api.src.services.tts_service.get_model_manager") as mock_get_model, \
+ patch("api.src.services.tts_service.get_voice_manager") as mock_get_voice, \
+ patch("torch.load") as mock_load, \
+ patch("torch.save") as mock_save, \
+ patch("tempfile.gettempdir") as mock_temp:
+ mock_get_model.return_value = model_manager
+ mock_get_voice.return_value = voice_manager
+ mock_temp.return_value = "/tmp"
+ mock_load.return_value = torch.ones(10)
+
+ service = await TTSService.create("test_output")
+ name, path = await service._get_voice_path("voice1+voice2")
+ assert name == "voice1+voice2"
+ assert path.endswith("voice1+voice2.pt")
+ mock_save.assert_called_once()
+
+@pytest.mark.asyncio
+async def test_list_voices():
+ """Test listing available voices."""
+ model_manager = AsyncMock()
+ voice_manager = AsyncMock()
+ voice_manager.list_voices.return_value = ["voice1", "voice2"]
+
+ with patch("api.src.services.tts_service.get_model_manager") as mock_get_model, \
+ patch("api.src.services.tts_service.get_voice_manager") as mock_get_voice:
+ mock_get_model.return_value = model_manager
+ mock_get_voice.return_value = voice_manager
+
+ service = await TTSService.create("test_output")
+ voices = await service.list_voices()
+ assert voices == ["voice1", "voice2"]
+ voice_manager.list_voices.assert_called_once()
\ No newline at end of file
diff --git a/api/tests/test_tts_service_new.py b/api/tests/test_tts_service_new.py
deleted file mode 100644
index 42d02ae..0000000
--- a/api/tests/test_tts_service_new.py
+++ /dev/null
@@ -1,142 +0,0 @@
-# import pytest
-# import numpy as np
-# from unittest.mock import AsyncMock, patch
-
-# @pytest.mark.asyncio
-# async def test_generate_audio(tts_service, mock_audio_output, test_voice):
-# """Test basic audio generation"""
-# audio, processing_time = await tts_service.generate_audio(
-# text="Hello world",
-# voice=test_voice,
-# speed=1.0
-# )
-
-# assert isinstance(audio, np.ndarray)
-# assert audio == mock_audio_output.tobytes()
-# assert processing_time > 0
-# tts_service.model_manager.generate.assert_called_once()
-
-# @pytest.mark.asyncio
-# async def test_generate_audio_with_combined_voice(tts_service, mock_audio_output):
-# """Test audio generation with a combined voice"""
-# test_voices = ["voice1", "voice2"]
-# combined_id = await tts_service._voice_manager.combine_voices(test_voices)
-
-# audio, processing_time = await tts_service.generate_audio(
-# text="Hello world",
-# voice=combined_id,
-# speed=1.0
-# )
-
-# assert isinstance(audio, np.ndarray)
-# assert np.array_equal(audio, mock_audio_output)
-# assert processing_time > 0
-
-# @pytest.mark.asyncio
-# async def test_generate_audio_stream(tts_service, mock_audio_output, test_voice):
-# """Test streaming audio generation"""
-# tts_service.model_manager.generate.return_value = mock_audio_output
-
-# chunks = []
-# async for chunk in tts_service.generate_audio_stream(
-# text="Hello world",
-# voice=test_voice,
-# speed=1.0,
-# output_format="pcm"
-# ):
-# assert isinstance(chunk, bytes)
-# chunks.append(chunk)
-
-# assert len(chunks) > 0
-# tts_service.model_manager.generate.assert_called()
-
-# @pytest.mark.asyncio
-# async def test_empty_text(tts_service, test_voice):
-# """Test handling empty text"""
-# with pytest.raises(ValueError) as exc_info:
-# await tts_service.generate_audio(
-# text="",
-# voice=test_voice,
-# speed=1.0
-# )
-# assert "No audio chunks were generated successfully" in str(exc_info.value)
-
-# @pytest.mark.asyncio
-# async def test_invalid_voice(tts_service):
-# """Test handling invalid voice"""
-# tts_service._voice_manager.load_voice.side_effect = ValueError("Voice not found")
-
-# with pytest.raises(ValueError) as exc_info:
-# await tts_service.generate_audio(
-# text="Hello world",
-# voice="invalid_voice",
-# speed=1.0
-# )
-# assert "Voice not found" in str(exc_info.value)
-
-# @pytest.mark.asyncio
-# async def test_model_generation_error(tts_service, test_voice):
-# """Test handling model generation error"""
-# # Make generate return None to simulate failed generation
-# tts_service.model_manager.generate.return_value = None
-
-# with pytest.raises(ValueError) as exc_info:
-# await tts_service.generate_audio(
-# text="Hello world",
-# voice=test_voice,
-# speed=1.0
-# )
-# assert "No audio chunks were generated successfully" in str(exc_info.value)
-
-# @pytest.mark.asyncio
-# async def test_streaming_generation_error(tts_service, test_voice):
-# """Test handling streaming generation error"""
-# # Make generate return None to simulate failed generation
-# tts_service.model_manager.generate.return_value = None
-
-# chunks = []
-# async for chunk in tts_service.generate_audio_stream(
-# text="Hello world",
-# voice=test_voice,
-# speed=1.0,
-# output_format="pcm"
-# ):
-# chunks.append(chunk)
-
-# # Should get no chunks if generation fails
-# assert len(chunks) == 0
-
-# @pytest.mark.asyncio
-# async def test_list_voices(tts_service):
-# """Test listing available voices"""
-# voices = await tts_service.list_voices()
-# assert len(voices) == 2
-# assert "voice1" in voices
-# assert "voice2" in voices
-# tts_service._voice_manager.list_voices.assert_called_once()
-
-# @pytest.mark.asyncio
-# async def test_combine_voices(tts_service):
-# """Test combining voices"""
-# test_voices = ["voice1", "voice2"]
-# combined_id = await tts_service.combine_voices(test_voices)
-# assert combined_id == "voice1_voice2"
-# tts_service._voice_manager.combine_voices.assert_called_once_with(test_voices)
-
-# @pytest.mark.asyncio
-# async def test_chunked_text_processing(tts_service, test_voice, mock_audio_output):
-# """Test processing chunked text"""
-# # Create text that will force chunking by exceeding max tokens
-# long_text = "This is a test sentence." * 100 # Should be way over 500 tokens
-
-# # Don't mock smart_split - let it actually split the text
-# audio, processing_time = await tts_service.generate_audio(
-# text=long_text,
-# voice=test_voice,
-# speed=1.0
-# )
-
-# # Should be called multiple times due to chunking
-# assert tts_service.model_manager.generate.call_count > 1
-# assert isinstance(audio, np.ndarray)
-# assert processing_time > 0
\ No newline at end of file
diff --git a/docker-bake.hcl b/docker-bake.hcl
new file mode 100644
index 0000000..bed8065
--- /dev/null
+++ b/docker-bake.hcl
@@ -0,0 +1,80 @@
+# Variables for reuse
+variable "VERSION" {
+ default = "latest"
+}
+
+variable "REGISTRY" {
+ default = "ghcr.io"
+}
+
+variable "OWNER" {
+ default = "remsky"
+}
+
+variable "REPO" {
+ default = "kokoro-fastapi"
+}
+
+# Common settings shared between targets
+target "_common" {
+ context = "."
+ args = {
+ DEBIAN_FRONTEND = "noninteractive"
+ }
+ cache-from = ["type=registry,ref=${REGISTRY}/${OWNER}/${REPO}-cache"]
+ cache-to = ["type=registry,ref=${REGISTRY}/${OWNER}/${REPO}-cache,mode=max"]
+}
+
+# Base settings for CPU builds
+target "_cpu_base" {
+ inherits = ["_common"]
+ dockerfile = "docker/cpu/Dockerfile"
+}
+
+# Base settings for GPU builds
+target "_gpu_base" {
+ inherits = ["_common"]
+ dockerfile = "docker/gpu/Dockerfile"
+}
+
+# CPU target with multi-platform support
+target "cpu" {
+ inherits = ["_cpu_base"]
+ platforms = ["linux/amd64", "linux/arm64"]
+ tags = [
+ "${REGISTRY}/${OWNER}/${REPO}-cpu:${VERSION}",
+ "${REGISTRY}/${OWNER}/${REPO}-cpu:latest"
+ ]
+}
+
+# GPU target with multi-platform support
+target "gpu" {
+ inherits = ["_gpu_base"]
+ platforms = ["linux/amd64", "linux/arm64"]
+ tags = [
+ "${REGISTRY}/${OWNER}/${REPO}-gpu:${VERSION}",
+ "${REGISTRY}/${OWNER}/${REPO}-gpu:latest"
+ ]
+}
+
+# Default group to build both CPU and GPU versions
+group "default" {
+ targets = ["cpu", "gpu"]
+}
+
+# Development targets for faster local builds
+target "cpu-dev" {
+ inherits = ["_cpu_base"]
+ # No multi-platform for dev builds
+ tags = ["${REGISTRY}/${OWNER}/${REPO}-cpu:dev"]
+}
+
+target "gpu-dev" {
+ inherits = ["_gpu_base"]
+ # No multi-platform for dev builds
+ tags = ["${REGISTRY}/${OWNER}/${REPO}-gpu:dev"]
+}
+
+group "dev" {
+ targets = ["cpu-dev", "gpu-dev"]
+}
\ No newline at end of file
diff --git a/docker/build.sh b/docker/build.sh
index 57b5d16..c002127 100755
--- a/docker/build.sh
+++ b/docker/build.sh
@@ -4,33 +4,9 @@ set -e
# Get version from argument or use default
VERSION=${1:-"latest"}
-# GitHub Container Registry settings
-REGISTRY="ghcr.io"
-OWNER="remsky"
-REPO="kokoro-fastapi"
-
-# Create and use a new builder that supports multi-platform builds
-docker buildx create --name multiplatform-builder --use || true
-
-# Build CPU image with multi-platform support
-echo "Building CPU image..."
-docker buildx build --platform linux/amd64,linux/arm64 \
- -t ${REGISTRY}/${OWNER}/${REPO}-cpu:${VERSION} \
- -t ${REGISTRY}/${OWNER}/${REPO}-cpu:latest \
- -f docker/cpu/Dockerfile \
- --push .
-
-# Build GPU image with multi-platform support
-echo "Building GPU image..."
-docker buildx build --platform linux/amd64,linux/arm64 \
- -t ${REGISTRY}/${OWNER}/${REPO}-gpu:${VERSION} \
- -t ${REGISTRY}/${OWNER}/${REPO}-gpu:latest \
- -f docker/gpu/Dockerfile \
- --push .
+# Build both CPU and GPU images using docker buildx bake
+echo "Building CPU and GPU images..."
+VERSION=$VERSION docker buildx bake --push
echo "Build complete!"
-echo "Created images:"
-echo "- ${REGISTRY}/${OWNER}/${REPO}-cpu:${VERSION} (linux/amd64, linux/arm64)"
-echo "- ${REGISTRY}/${OWNER}/${REPO}-cpu:latest (linux/amd64, linux/arm64)"
-echo "- ${REGISTRY}/${OWNER}/${REPO}-gpu:${VERSION} (linux/amd64, linux/arm64)"
-echo "- ${REGISTRY}/${OWNER}/${REPO}-gpu:latest (linux/amd64, linux/arm64)"
+echo "Created images with version: $VERSION"
diff --git a/docker/cpu/Dockerfile b/docker/cpu/Dockerfile
index b97cd9c..a3fc898 100644
--- a/docker/cpu/Dockerfile
+++ b/docker/cpu/Dockerfile
@@ -3,14 +3,15 @@ FROM --platform=$BUILDPLATFORM python:3.10-slim
# Install dependencies and check espeak location
RUN apt-get update && apt-get install -y \
espeak-ng \
+ espeak-ng-data \
git \
libsndfile1 \
curl \
ffmpeg \
- && dpkg -L espeak-ng \
- && find / -name "espeak-ng-data" \
&& apt-get clean \
- && rm -rf /var/lib/apt/lists/*
+ && rm -rf /var/lib/apt/lists/* \
+ && mkdir -p /usr/share/espeak-ng-data \
+ && ln -s /usr/lib/x86_64-linux-gnu/espeak-ng-data/* /usr/share/espeak-ng-data/
# Install UV using the installer script
RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
@@ -20,9 +21,7 @@ RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
# Create non-root user and set up directories and permissions
RUN useradd -m -u 1000 appuser && \
mkdir -p /app/api/src/models/v1_0 && \
- chown -R appuser:appuser /app && \
- chown -R appuser:appuser /lib/x86_64-linux-gnu/espeak-ng-data
-
+ chown -R appuser:appuser /app
USER appuser
WORKDIR /app
@@ -33,17 +32,13 @@ COPY --chown=appuser:appuser pyproject.toml ./pyproject.toml
# Install dependencies
RUN --mount=type=cache,target=/root/.cache/uv \
uv venv && \
- uv sync --extra cpu --no-install-project
+ uv sync --extra cpu
# Copy project files including models
COPY --chown=appuser:appuser api ./api
COPY --chown=appuser:appuser web ./web
COPY --chown=appuser:appuser docker/scripts/download_model.* ./
-# Install project
-RUN --mount=type=cache,target=/root/.cache/uv \
- uv sync --extra cpu
-
# Set environment variables
ENV PYTHONUNBUFFERED=1 \
PYTHONPATH=/app:/app/api \
diff --git a/docker/cpu/docker-compose.yml b/docker/cpu/docker-compose.yml
index a365475..ed15540 100644
--- a/docker/cpu/docker-compose.yml
+++ b/docker/cpu/docker-compose.yml
@@ -1,7 +1,6 @@
-name: kokoro-tts
+name: kokoro-fastapi-cpu
services:
kokoro-tts:
- # image: ghcr.io/remsky/kokoro-fastapi-cpu:v0.2.0
build:
context: ../..
dockerfile: docker/cpu/Dockerfile
diff --git a/docker/gpu/Dockerfile b/docker/gpu/Dockerfile
index 7fcbb11..df9d383 100644
--- a/docker/gpu/Dockerfile
+++ b/docker/gpu/Dockerfile
@@ -1,5 +1,4 @@
-FROM --platform=$BUILDPLATFORM nvidia/cuda:12.1.0-cudnn8-runtime-ubuntu22.04
-
+FROM --platform=$BUILDPLATFORM nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04
# Set non-interactive frontend
ENV DEBIAN_FRONTEND=noninteractive
@@ -8,30 +7,26 @@ RUN apt-get update && apt-get install -y \
python3.10 \
python3.10-venv \
espeak-ng \
+ espeak-ng-data \
git \
libsndfile1 \
curl \
ffmpeg \
- && ls -la /usr/lib/x86_64-linux-gnu/espeak-ng-data \
- && apt-get clean \
- && rm -rf /var/lib/apt/lists/*
+ && apt-get clean && rm -rf /var/lib/apt/lists/* \
+ && mkdir -p /usr/share/espeak-ng-data \
+ && ln -s /usr/lib/x86_64-linux-gnu/espeak-ng-data/* /usr/share/espeak-ng-data/
-# Create user and set up permissions
-RUN useradd -m -u 1000 appuser && \
- mkdir -p /app/api/src/models/v1_0 && \
- chown -R appuser:appuser /app && \
- chown -R appuser:appuser /usr/lib/x86_64-linux-gnu/espeak-ng-data
-
-
-# Rest of your Dockerfile...
-
-# Install UV in a separate step
+# Install UV using the installer script
RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
mv /root/.local/bin/uv /usr/local/bin/ && \
- mv /root/.local/bin/uvx /usr/local/bin/
+ mv /root/.local/bin/uvx /usr/local/bin/ && \
+ useradd -m -u 1000 appuser && \
+ mkdir -p /app/api/src/models/v1_0 && \
+ chown -R appuser:appuser /app
USER appuser
WORKDIR /app
+
# Copy dependency files
COPY --chown=appuser:appuser pyproject.toml ./pyproject.toml
diff --git a/docker/gpu/docker-compose.yml b/docker/gpu/docker-compose.yml
index 508d011..762aca6 100644
--- a/docker/gpu/docker-compose.yml
+++ b/docker/gpu/docker-compose.yml
@@ -1,4 +1,4 @@
-name: kokoro-tts
+name: kokoro-tts-gpu
services:
kokoro-tts:
# image: ghcr.io/remsky/kokoro-fastapi-gpu:v0.2.0
diff --git a/docker/shared/pyproject.toml b/docker/shared/pyproject.toml
deleted file mode 100644
index f45ff5e..0000000
--- a/docker/shared/pyproject.toml
+++ /dev/null
@@ -1,45 +0,0 @@
-[project]
-name = "kokoro-fastapi"
-version = "0.1.0"
-description = "FastAPI TTS Service"
-readme = "../README.md"
-requires-python = ">=3.10"
-dependencies = [
- # Core dependencies
- "fastapi==0.115.6",
- "uvicorn==0.34.0",
- "click>=8.0.0",
- "pydantic==2.10.4",
- "pydantic-settings==2.7.0",
- "python-dotenv==1.0.1",
- "sqlalchemy==2.0.27",
-
- # ML/DL Base
- "numpy>=1.26.0",
- "scipy==1.14.1",
- "onnxruntime==1.20.1",
-
- # Audio processing
- "soundfile==0.13.0",
-
- # Text processing
- "phonemizer==3.3.0",
- "regex==2024.11.6",
-
- # Utilities
- "aiofiles==23.2.1",
- "tqdm==4.67.1",
- "requests==2.32.3",
- "munch==4.0.0",
- "tiktoken==0.8.0",
- "loguru==0.7.3",
- "pydub>=0.25.1",
-]
-
-[project.optional-dependencies]
-test = [
- "pytest==8.0.0",
- "httpx==0.26.0",
- "pytest-asyncio==0.23.5",
- "ruff==0.9.1",
-]
diff --git a/pyproject.toml b/pyproject.toml
index 6d09582..551dc59 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,28 +28,23 @@ dependencies = [
"munch==4.0.0",
"tiktoken==0.8.0",
"loguru==0.7.3",
- # "transformers==4.47.1",
"openai>=1.59.6",
- # "ebooklib>=0.18",
- # "html2text>=2024.2.26",
"pydub>=0.25.1",
"matplotlib>=3.10.0",
"mutagen>=1.47.0",
"psutil>=6.1.1",
- "kokoro==0.7.6",
- 'misaki[en,ja,ko,zh,vi]==0.7.6',
+ "kokoro==0.7.9",
+ 'misaki[en,ja,ko,zh,vi]==0.7.9',
"spacy>=3.7.6",
"en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl"
]
[project.optional-dependencies]
gpu = [
- "torch==2.5.1+cu121",
- #"onnxruntime-gpu==1.20.1",
+ "torch==2.6.0+cu124",
]
cpu = [
- "torch==2.5.1",
- #"onnxruntime==1.20.1",
+ "torch==2.6.0",
]
test = [
"pytest==8.0.0",
@@ -81,7 +76,7 @@ explicit = true
[[tool.uv.index]]
name = "pytorch-cuda"
-url = "https://download.pytorch.org/whl/cu121"
+url = "https://download.pytorch.org/whl/cu124"
explicit = true
[build-system]
diff --git a/pytest.ini b/pytest.ini
index 47be4b5..3bcd461 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -1,5 +1,5 @@
[pytest]
-testpaths = api/tests ui/tests
+testpaths = api/tests
python_files = test_*.py
-addopts = -v --tb=short --cov=api --cov=ui --cov-report=term-missing --cov-config=.coveragerc
+addopts = -v --tb=short --cov=api --cov-report=term-missing --cov-config=.coveragerc
pythonpath = .
diff --git a/ui/tests/conftest.py b/ui/depr_tests/conftest.py
similarity index 100%
rename from ui/tests/conftest.py
rename to ui/depr_tests/conftest.py
diff --git a/ui/tests/test_api.py b/ui/depr_tests/test_api.py
similarity index 100%
rename from ui/tests/test_api.py
rename to ui/depr_tests/test_api.py
diff --git a/ui/tests/test_components.py b/ui/depr_tests/test_components.py
similarity index 100%
rename from ui/tests/test_components.py
rename to ui/depr_tests/test_components.py
diff --git a/ui/tests/test_files.py b/ui/depr_tests/test_files.py
similarity index 100%
rename from ui/tests/test_files.py
rename to ui/depr_tests/test_files.py
diff --git a/ui/tests/test_handlers.py b/ui/depr_tests/test_handlers.py
similarity index 100%
rename from ui/tests/test_handlers.py
rename to ui/depr_tests/test_handlers.py
diff --git a/ui/tests/test_input.py b/ui/depr_tests/test_input.py
similarity index 100%
rename from ui/tests/test_input.py
rename to ui/depr_tests/test_input.py
diff --git a/ui/tests/test_interface.py b/ui/depr_tests/test_interface.py
similarity index 100%
rename from ui/tests/test_interface.py
rename to ui/depr_tests/test_interface.py