mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
Refactor Docker configurations for GPU and CPU, update test paths, and remove deprecated tests
This commit is contained in:
parent
165ffccd01
commit
ac7947b51a
24 changed files with 495 additions and 263 deletions
|
@ -3,8 +3,8 @@
|
|||
</p>
|
||||
|
||||
# <sub><sub>_`FastKoko`_ </sub></sub>
|
||||
[]()
|
||||
[]()
|
||||
[]()
|
||||
[]()
|
||||
[](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero)
|
||||
|
||||
[](https://huggingface.co/hexgrad/Kokoro-82M/commit/9901c2b79161b6e898b7ea857ae5298f47b8b0d6)
|
||||
|
|
|
@ -93,6 +93,7 @@ Model files not found! You need to download the Kokoro V1 model:
|
|||
{boundary}
|
||||
"""
|
||||
startup_msg += f"\nModel warmed up on {device}: {model}"
|
||||
startup_msg += f"CUDA: {torch.cuda.is_available()}"
|
||||
startup_msg += f"\n{voicepack_count} voice packs loaded"
|
||||
|
||||
# Add web player info if enabled
|
||||
|
|
75
api/tests/test_kokoro_v1.py
Normal file
75
api/tests/test_kokoro_v1.py
Normal file
|
@ -0,0 +1,75 @@
|
|||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import torch
|
||||
import numpy as np
|
||||
from api.src.inference.kokoro_v1 import KokoroV1
|
||||
|
||||
@pytest.fixture
|
||||
def kokoro_backend():
|
||||
"""Create a KokoroV1 instance for testing."""
|
||||
return KokoroV1()
|
||||
|
||||
def test_initial_state(kokoro_backend):
|
||||
"""Test initial state of KokoroV1."""
|
||||
assert not kokoro_backend.is_loaded
|
||||
assert kokoro_backend._model is None
|
||||
assert kokoro_backend._pipeline is None
|
||||
# Device should be set based on settings
|
||||
assert kokoro_backend.device in ["cuda", "cpu"]
|
||||
|
||||
@patch('torch.cuda.is_available', return_value=True)
|
||||
@patch('torch.cuda.memory_allocated')
|
||||
def test_memory_management(mock_memory, mock_cuda, kokoro_backend):
|
||||
"""Test GPU memory management functions."""
|
||||
# Mock GPU memory usage
|
||||
mock_memory.return_value = 5e9 # 5GB
|
||||
|
||||
# Test memory check
|
||||
with patch('api.src.inference.kokoro_v1.model_config') as mock_config:
|
||||
mock_config.pytorch_gpu.memory_threshold = 4
|
||||
assert kokoro_backend._check_memory() == True
|
||||
|
||||
mock_config.pytorch_gpu.memory_threshold = 6
|
||||
assert kokoro_backend._check_memory() == False
|
||||
|
||||
@patch('torch.cuda.empty_cache')
|
||||
@patch('torch.cuda.synchronize')
|
||||
def test_clear_memory(mock_sync, mock_clear, kokoro_backend):
|
||||
"""Test memory clearing."""
|
||||
with patch.object(kokoro_backend, '_device', 'cuda'):
|
||||
kokoro_backend._clear_memory()
|
||||
mock_clear.assert_called_once()
|
||||
mock_sync.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_model_validation(kokoro_backend):
|
||||
"""Test model loading validation."""
|
||||
with pytest.raises(RuntimeError, match="Failed to load Kokoro model"):
|
||||
await kokoro_backend.load_model("nonexistent_model.pth")
|
||||
|
||||
def test_unload(kokoro_backend):
|
||||
"""Test model unloading."""
|
||||
# Mock loaded state
|
||||
kokoro_backend._model = MagicMock()
|
||||
kokoro_backend._pipeline = MagicMock()
|
||||
assert kokoro_backend.is_loaded
|
||||
|
||||
# Test unload
|
||||
kokoro_backend.unload()
|
||||
assert not kokoro_backend.is_loaded
|
||||
assert kokoro_backend._model is None
|
||||
assert kokoro_backend._pipeline is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_validation(kokoro_backend):
|
||||
"""Test generation validation."""
|
||||
with pytest.raises(RuntimeError, match="Model not loaded"):
|
||||
async for _ in kokoro_backend.generate("test", "voice"):
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_from_tokens_validation(kokoro_backend):
|
||||
"""Test token generation validation."""
|
||||
with pytest.raises(RuntimeError, match="Model not loaded"):
|
||||
async for _ in kokoro_backend.generate_from_tokens("test tokens", "voice"):
|
||||
pass
|
|
@ -330,16 +330,19 @@ def test_list_voices(mock_tts_service):
|
|||
assert "voice1" in data["voices"]
|
||||
assert "voice2" in data["voices"]
|
||||
|
||||
def test_combine_voices(mock_tts_service):
|
||||
@patch('api.src.routers.openai_compatible.settings')
|
||||
def test_combine_voices(mock_settings, mock_tts_service):
|
||||
"""Test combining voices endpoint"""
|
||||
# Enable local voice saving for this test
|
||||
mock_settings.allow_local_voice_saving = True
|
||||
|
||||
response = client.post(
|
||||
"/v1/audio/voices/combine",
|
||||
json="voice1+voice2"
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "voice" in data
|
||||
assert data["voice"] == "voice1_voice2"
|
||||
assert response.headers["content-type"] == "application/octet-stream"
|
||||
assert "voice1+voice2.pt" in response.headers["content-disposition"]
|
||||
|
||||
def test_server_error(mock_tts_service, test_voice):
|
||||
"""Test handling of server errors"""
|
||||
|
|
116
api/tests/test_paths.py
Normal file
116
api/tests/test_paths.py
Normal file
|
@ -0,0 +1,116 @@
|
|||
import os
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
from api.src.core.paths import (
|
||||
_find_file,
|
||||
_scan_directories,
|
||||
get_content_type,
|
||||
get_temp_file_path,
|
||||
list_temp_files,
|
||||
get_temp_dir_size
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_file_exists():
|
||||
"""Test finding existing file."""
|
||||
with patch('aiofiles.os.path.exists') as mock_exists:
|
||||
mock_exists.return_value = True
|
||||
path = await _find_file("test.txt", ["/test/path"])
|
||||
assert path == "/test/path/test.txt"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_file_not_exists():
|
||||
"""Test finding non-existent file."""
|
||||
with patch('aiofiles.os.path.exists') as mock_exists:
|
||||
mock_exists.return_value = False
|
||||
with pytest.raises(RuntimeError, match="File not found"):
|
||||
await _find_file("test.txt", ["/test/path"])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_file_with_filter():
|
||||
"""Test finding file with filter function."""
|
||||
with patch('aiofiles.os.path.exists') as mock_exists:
|
||||
mock_exists.return_value = True
|
||||
filter_fn = lambda p: p.endswith('.txt')
|
||||
path = await _find_file("test.txt", ["/test/path"], filter_fn)
|
||||
assert path == "/test/path/test.txt"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scan_directories():
|
||||
"""Test scanning directories."""
|
||||
mock_entry = type('MockEntry', (), {'name': 'test.txt'})()
|
||||
|
||||
with patch('aiofiles.os.path.exists') as mock_exists, \
|
||||
patch('aiofiles.os.scandir') as mock_scandir:
|
||||
mock_exists.return_value = True
|
||||
mock_scandir.return_value = [mock_entry]
|
||||
|
||||
files = await _scan_directories(["/test/path"])
|
||||
assert "test.txt" in files
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_content_type():
|
||||
"""Test content type detection."""
|
||||
test_cases = [
|
||||
("test.html", "text/html"),
|
||||
("test.js", "application/javascript"),
|
||||
("test.css", "text/css"),
|
||||
("test.png", "image/png"),
|
||||
("test.unknown", "application/octet-stream"),
|
||||
]
|
||||
|
||||
for filename, expected in test_cases:
|
||||
content_type = await get_content_type(filename)
|
||||
assert content_type == expected
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_temp_file_path():
|
||||
"""Test temp file path generation."""
|
||||
with patch('aiofiles.os.path.exists') as mock_exists, \
|
||||
patch('aiofiles.os.makedirs') as mock_makedirs:
|
||||
mock_exists.return_value = False
|
||||
|
||||
path = await get_temp_file_path("test.wav")
|
||||
assert "test.wav" in path
|
||||
mock_makedirs.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_temp_files():
|
||||
"""Test listing temp files."""
|
||||
class MockEntry:
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
def is_file(self):
|
||||
return True
|
||||
|
||||
mock_entry = MockEntry('test.wav')
|
||||
|
||||
with patch('aiofiles.os.path.exists') as mock_exists, \
|
||||
patch('aiofiles.os.scandir') as mock_scandir:
|
||||
mock_exists.return_value = True
|
||||
mock_scandir.return_value = [mock_entry]
|
||||
|
||||
files = await list_temp_files()
|
||||
assert "test.wav" in files
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_temp_dir_size():
|
||||
"""Test getting temp directory size."""
|
||||
class MockEntry:
|
||||
def __init__(self, path):
|
||||
self.path = path
|
||||
def is_file(self):
|
||||
return True
|
||||
|
||||
mock_entry = MockEntry('/tmp/test.wav')
|
||||
mock_stat = type('MockStat', (), {'st_size': 1024})()
|
||||
|
||||
with patch('aiofiles.os.path.exists') as mock_exists, \
|
||||
patch('aiofiles.os.scandir') as mock_scandir, \
|
||||
patch('aiofiles.os.stat') as mock_stat_fn:
|
||||
mock_exists.return_value = True
|
||||
mock_scandir.return_value = [mock_entry]
|
||||
mock_stat_fn.return_value = mock_stat
|
||||
|
||||
size = await get_temp_dir_size()
|
||||
assert size == 1024
|
80
api/tests/test_text_processor.py
Normal file
80
api/tests/test_text_processor.py
Normal file
|
@ -0,0 +1,80 @@
|
|||
import pytest
|
||||
from api.src.services.text_processing.text_processor import (
|
||||
process_text_chunk,
|
||||
get_sentence_info,
|
||||
smart_split
|
||||
)
|
||||
|
||||
def test_process_text_chunk_basic():
|
||||
"""Test basic text chunk processing."""
|
||||
text = "Hello world"
|
||||
tokens = process_text_chunk(text)
|
||||
assert isinstance(tokens, list)
|
||||
assert len(tokens) > 0
|
||||
|
||||
def test_process_text_chunk_empty():
|
||||
"""Test processing empty text."""
|
||||
text = ""
|
||||
tokens = process_text_chunk(text)
|
||||
assert isinstance(tokens, list)
|
||||
assert len(tokens) == 0
|
||||
|
||||
def test_process_text_chunk_phonemes():
|
||||
"""Test processing with skip_phonemize."""
|
||||
phonemes = "h @ l @U" # Example phoneme sequence
|
||||
tokens = process_text_chunk(phonemes, skip_phonemize=True)
|
||||
assert isinstance(tokens, list)
|
||||
assert len(tokens) > 0
|
||||
|
||||
def test_get_sentence_info():
|
||||
"""Test sentence splitting and info extraction."""
|
||||
text = "This is sentence one. This is sentence two! What about three?"
|
||||
results = get_sentence_info(text)
|
||||
|
||||
assert len(results) == 3
|
||||
for sentence, tokens, count in results:
|
||||
assert isinstance(sentence, str)
|
||||
assert isinstance(tokens, list)
|
||||
assert isinstance(count, int)
|
||||
assert count == len(tokens)
|
||||
assert count > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_split_short_text():
|
||||
"""Test smart splitting with text under max tokens."""
|
||||
text = "This is a short test sentence."
|
||||
chunks = []
|
||||
async for chunk_text, chunk_tokens in smart_split(text):
|
||||
chunks.append((chunk_text, chunk_tokens))
|
||||
|
||||
assert len(chunks) == 1
|
||||
assert isinstance(chunks[0][0], str)
|
||||
assert isinstance(chunks[0][1], list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_split_long_text():
|
||||
"""Test smart splitting with longer text."""
|
||||
# Create text that should split into multiple chunks
|
||||
text = ". ".join(["This is test sentence number " + str(i) for i in range(20)])
|
||||
|
||||
chunks = []
|
||||
async for chunk_text, chunk_tokens in smart_split(text):
|
||||
chunks.append((chunk_text, chunk_tokens))
|
||||
|
||||
assert len(chunks) > 1
|
||||
for chunk_text, chunk_tokens in chunks:
|
||||
assert isinstance(chunk_text, str)
|
||||
assert isinstance(chunk_tokens, list)
|
||||
assert len(chunk_tokens) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_smart_split_with_punctuation():
|
||||
"""Test smart splitting handles punctuation correctly."""
|
||||
text = "First sentence! Second sentence? Third sentence; Fourth sentence: Fifth sentence."
|
||||
|
||||
chunks = []
|
||||
async for chunk_text, chunk_tokens in smart_split(text):
|
||||
chunks.append(chunk_text)
|
||||
|
||||
# Verify punctuation is preserved
|
||||
assert all(any(p in chunk for p in "!?;:." ) for chunk in chunks)
|
104
api/tests/test_tts_service.py
Normal file
104
api/tests/test_tts_service.py
Normal file
|
@ -0,0 +1,104 @@
|
|||
import pytest
|
||||
import numpy as np
|
||||
import torch
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from api.src.services.tts_service import TTSService
|
||||
|
||||
@pytest.fixture
|
||||
def mock_managers():
|
||||
"""Mock model and voice managers."""
|
||||
async def _mock_managers():
|
||||
model_manager = AsyncMock()
|
||||
model_manager.get_backend.return_value = MagicMock()
|
||||
|
||||
voice_manager = AsyncMock()
|
||||
voice_manager.get_voice_path.return_value = "/path/to/voice.pt"
|
||||
voice_manager.list_voices.return_value = ["voice1", "voice2"]
|
||||
|
||||
with patch("api.src.services.tts_service.get_model_manager") as mock_get_model, \
|
||||
patch("api.src.services.tts_service.get_voice_manager") as mock_get_voice:
|
||||
mock_get_model.return_value = model_manager
|
||||
mock_get_voice.return_value = voice_manager
|
||||
return model_manager, voice_manager
|
||||
return _mock_managers()
|
||||
|
||||
@pytest.fixture
|
||||
def tts_service(mock_managers):
|
||||
"""Create TTSService instance with mocked dependencies."""
|
||||
async def _create_service():
|
||||
return await TTSService.create("test_output")
|
||||
return _create_service()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_creation():
|
||||
"""Test service creation and initialization."""
|
||||
model_manager = AsyncMock()
|
||||
voice_manager = AsyncMock()
|
||||
|
||||
with patch("api.src.services.tts_service.get_model_manager") as mock_get_model, \
|
||||
patch("api.src.services.tts_service.get_voice_manager") as mock_get_voice:
|
||||
mock_get_model.return_value = model_manager
|
||||
mock_get_voice.return_value = voice_manager
|
||||
|
||||
service = await TTSService.create("test_output")
|
||||
assert service.output_dir == "test_output"
|
||||
assert service.model_manager is model_manager
|
||||
assert service._voice_manager is voice_manager
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_voice_path_single():
|
||||
"""Test getting path for single voice."""
|
||||
model_manager = AsyncMock()
|
||||
voice_manager = AsyncMock()
|
||||
voice_manager.get_voice_path.return_value = "/path/to/voice1.pt"
|
||||
|
||||
with patch("api.src.services.tts_service.get_model_manager") as mock_get_model, \
|
||||
patch("api.src.services.tts_service.get_voice_manager") as mock_get_voice:
|
||||
mock_get_model.return_value = model_manager
|
||||
mock_get_voice.return_value = voice_manager
|
||||
|
||||
service = await TTSService.create("test_output")
|
||||
name, path = await service._get_voice_path("voice1")
|
||||
assert name == "voice1"
|
||||
assert path == "/path/to/voice1.pt"
|
||||
voice_manager.get_voice_path.assert_called_once_with("voice1")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_voice_path_combined():
|
||||
"""Test getting path for combined voices."""
|
||||
model_manager = AsyncMock()
|
||||
voice_manager = AsyncMock()
|
||||
voice_manager.get_voice_path.return_value = "/path/to/voice.pt"
|
||||
|
||||
with patch("api.src.services.tts_service.get_model_manager") as mock_get_model, \
|
||||
patch("api.src.services.tts_service.get_voice_manager") as mock_get_voice, \
|
||||
patch("torch.load") as mock_load, \
|
||||
patch("torch.save") as mock_save, \
|
||||
patch("tempfile.gettempdir") as mock_temp:
|
||||
mock_get_model.return_value = model_manager
|
||||
mock_get_voice.return_value = voice_manager
|
||||
mock_temp.return_value = "/tmp"
|
||||
mock_load.return_value = torch.ones(10)
|
||||
|
||||
service = await TTSService.create("test_output")
|
||||
name, path = await service._get_voice_path("voice1+voice2")
|
||||
assert name == "voice1+voice2"
|
||||
assert path.endswith("voice1+voice2.pt")
|
||||
mock_save.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_voices():
|
||||
"""Test listing available voices."""
|
||||
model_manager = AsyncMock()
|
||||
voice_manager = AsyncMock()
|
||||
voice_manager.list_voices.return_value = ["voice1", "voice2"]
|
||||
|
||||
with patch("api.src.services.tts_service.get_model_manager") as mock_get_model, \
|
||||
patch("api.src.services.tts_service.get_voice_manager") as mock_get_voice:
|
||||
mock_get_model.return_value = model_manager
|
||||
mock_get_voice.return_value = voice_manager
|
||||
|
||||
service = await TTSService.create("test_output")
|
||||
voices = await service.list_voices()
|
||||
assert voices == ["voice1", "voice2"]
|
||||
voice_manager.list_voices.assert_called_once()
|
|
@ -1,142 +0,0 @@
|
|||
# import pytest
|
||||
# import numpy as np
|
||||
# from unittest.mock import AsyncMock, patch
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_generate_audio(tts_service, mock_audio_output, test_voice):
|
||||
# """Test basic audio generation"""
|
||||
# audio, processing_time = await tts_service.generate_audio(
|
||||
# text="Hello world",
|
||||
# voice=test_voice,
|
||||
# speed=1.0
|
||||
# )
|
||||
|
||||
# assert isinstance(audio, np.ndarray)
|
||||
# assert audio == mock_audio_output.tobytes()
|
||||
# assert processing_time > 0
|
||||
# tts_service.model_manager.generate.assert_called_once()
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_generate_audio_with_combined_voice(tts_service, mock_audio_output):
|
||||
# """Test audio generation with a combined voice"""
|
||||
# test_voices = ["voice1", "voice2"]
|
||||
# combined_id = await tts_service._voice_manager.combine_voices(test_voices)
|
||||
|
||||
# audio, processing_time = await tts_service.generate_audio(
|
||||
# text="Hello world",
|
||||
# voice=combined_id,
|
||||
# speed=1.0
|
||||
# )
|
||||
|
||||
# assert isinstance(audio, np.ndarray)
|
||||
# assert np.array_equal(audio, mock_audio_output)
|
||||
# assert processing_time > 0
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_generate_audio_stream(tts_service, mock_audio_output, test_voice):
|
||||
# """Test streaming audio generation"""
|
||||
# tts_service.model_manager.generate.return_value = mock_audio_output
|
||||
|
||||
# chunks = []
|
||||
# async for chunk in tts_service.generate_audio_stream(
|
||||
# text="Hello world",
|
||||
# voice=test_voice,
|
||||
# speed=1.0,
|
||||
# output_format="pcm"
|
||||
# ):
|
||||
# assert isinstance(chunk, bytes)
|
||||
# chunks.append(chunk)
|
||||
|
||||
# assert len(chunks) > 0
|
||||
# tts_service.model_manager.generate.assert_called()
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_empty_text(tts_service, test_voice):
|
||||
# """Test handling empty text"""
|
||||
# with pytest.raises(ValueError) as exc_info:
|
||||
# await tts_service.generate_audio(
|
||||
# text="",
|
||||
# voice=test_voice,
|
||||
# speed=1.0
|
||||
# )
|
||||
# assert "No audio chunks were generated successfully" in str(exc_info.value)
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_invalid_voice(tts_service):
|
||||
# """Test handling invalid voice"""
|
||||
# tts_service._voice_manager.load_voice.side_effect = ValueError("Voice not found")
|
||||
|
||||
# with pytest.raises(ValueError) as exc_info:
|
||||
# await tts_service.generate_audio(
|
||||
# text="Hello world",
|
||||
# voice="invalid_voice",
|
||||
# speed=1.0
|
||||
# )
|
||||
# assert "Voice not found" in str(exc_info.value)
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_model_generation_error(tts_service, test_voice):
|
||||
# """Test handling model generation error"""
|
||||
# # Make generate return None to simulate failed generation
|
||||
# tts_service.model_manager.generate.return_value = None
|
||||
|
||||
# with pytest.raises(ValueError) as exc_info:
|
||||
# await tts_service.generate_audio(
|
||||
# text="Hello world",
|
||||
# voice=test_voice,
|
||||
# speed=1.0
|
||||
# )
|
||||
# assert "No audio chunks were generated successfully" in str(exc_info.value)
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_streaming_generation_error(tts_service, test_voice):
|
||||
# """Test handling streaming generation error"""
|
||||
# # Make generate return None to simulate failed generation
|
||||
# tts_service.model_manager.generate.return_value = None
|
||||
|
||||
# chunks = []
|
||||
# async for chunk in tts_service.generate_audio_stream(
|
||||
# text="Hello world",
|
||||
# voice=test_voice,
|
||||
# speed=1.0,
|
||||
# output_format="pcm"
|
||||
# ):
|
||||
# chunks.append(chunk)
|
||||
|
||||
# # Should get no chunks if generation fails
|
||||
# assert len(chunks) == 0
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_list_voices(tts_service):
|
||||
# """Test listing available voices"""
|
||||
# voices = await tts_service.list_voices()
|
||||
# assert len(voices) == 2
|
||||
# assert "voice1" in voices
|
||||
# assert "voice2" in voices
|
||||
# tts_service._voice_manager.list_voices.assert_called_once()
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_combine_voices(tts_service):
|
||||
# """Test combining voices"""
|
||||
# test_voices = ["voice1", "voice2"]
|
||||
# combined_id = await tts_service.combine_voices(test_voices)
|
||||
# assert combined_id == "voice1_voice2"
|
||||
# tts_service._voice_manager.combine_voices.assert_called_once_with(test_voices)
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_chunked_text_processing(tts_service, test_voice, mock_audio_output):
|
||||
# """Test processing chunked text"""
|
||||
# # Create text that will force chunking by exceeding max tokens
|
||||
# long_text = "This is a test sentence." * 100 # Should be way over 500 tokens
|
||||
|
||||
# # Don't mock smart_split - let it actually split the text
|
||||
# audio, processing_time = await tts_service.generate_audio(
|
||||
# text=long_text,
|
||||
# voice=test_voice,
|
||||
# speed=1.0
|
||||
# )
|
||||
|
||||
# # Should be called multiple times due to chunking
|
||||
# assert tts_service.model_manager.generate.call_count > 1
|
||||
# assert isinstance(audio, np.ndarray)
|
||||
# assert processing_time > 0
|
80
docker-bake.hcl
Normal file
80
docker-bake.hcl
Normal file
|
@ -0,0 +1,80 @@
|
|||
# Variables for reuse
|
||||
variable "VERSION" {
|
||||
default = "latest"
|
||||
}
|
||||
|
||||
variable "REGISTRY" {
|
||||
default = "ghcr.io"
|
||||
}
|
||||
|
||||
variable "OWNER" {
|
||||
default = "remsky"
|
||||
}
|
||||
|
||||
variable "REPO" {
|
||||
default = "kokoro-fastapi"
|
||||
}
|
||||
|
||||
# Common settings shared between targets
|
||||
target "_common" {
|
||||
context = "."
|
||||
args = {
|
||||
DEBIAN_FRONTEND = "noninteractive"
|
||||
}
|
||||
cache-from = ["type=registry,ref=${REGISTRY}/${OWNER}/${REPO}-cache"]
|
||||
cache-to = ["type=registry,ref=${REGISTRY}/${OWNER}/${REPO}-cache,mode=max"]
|
||||
}
|
||||
|
||||
# Base settings for CPU builds
|
||||
target "_cpu_base" {
|
||||
inherits = ["_common"]
|
||||
dockerfile = "docker/cpu/Dockerfile"
|
||||
}
|
||||
|
||||
# Base settings for GPU builds
|
||||
target "_gpu_base" {
|
||||
inherits = ["_common"]
|
||||
dockerfile = "docker/gpu/Dockerfile"
|
||||
}
|
||||
|
||||
# CPU target with multi-platform support
|
||||
target "cpu" {
|
||||
inherits = ["_cpu_base"]
|
||||
platforms = ["linux/amd64", "linux/arm64"]
|
||||
tags = [
|
||||
"${REGISTRY}/${OWNER}/${REPO}-cpu:${VERSION}",
|
||||
"${REGISTRY}/${OWNER}/${REPO}-cpu:latest"
|
||||
]
|
||||
}
|
||||
|
||||
# GPU target with multi-platform support
|
||||
target "gpu" {
|
||||
inherits = ["_gpu_base"]
|
||||
platforms = ["linux/amd64", "linux/arm64"]
|
||||
tags = [
|
||||
"${REGISTRY}/${OWNER}/${REPO}-gpu:${VERSION}",
|
||||
"${REGISTRY}/${OWNER}/${REPO}-gpu:latest"
|
||||
]
|
||||
}
|
||||
|
||||
# Default group to build both CPU and GPU versions
|
||||
group "default" {
|
||||
targets = ["cpu", "gpu"]
|
||||
}
|
||||
|
||||
# Development targets for faster local builds
|
||||
target "cpu-dev" {
|
||||
inherits = ["_cpu_base"]
|
||||
# No multi-platform for dev builds
|
||||
tags = ["${REGISTRY}/${OWNER}/${REPO}-cpu:dev"]
|
||||
}
|
||||
|
||||
target "gpu-dev" {
|
||||
inherits = ["_gpu_base"]
|
||||
# No multi-platform for dev builds
|
||||
tags = ["${REGISTRY}/${OWNER}/${REPO}-gpu:dev"]
|
||||
}
|
||||
|
||||
group "dev" {
|
||||
targets = ["cpu-dev", "gpu-dev"]
|
||||
}
|
|
@ -4,33 +4,9 @@ set -e
|
|||
# Get version from argument or use default
|
||||
VERSION=${1:-"latest"}
|
||||
|
||||
# GitHub Container Registry settings
|
||||
REGISTRY="ghcr.io"
|
||||
OWNER="remsky"
|
||||
REPO="kokoro-fastapi"
|
||||
|
||||
# Create and use a new builder that supports multi-platform builds
|
||||
docker buildx create --name multiplatform-builder --use || true
|
||||
|
||||
# Build CPU image with multi-platform support
|
||||
echo "Building CPU image..."
|
||||
docker buildx build --platform linux/amd64,linux/arm64 \
|
||||
-t ${REGISTRY}/${OWNER}/${REPO}-cpu:${VERSION} \
|
||||
-t ${REGISTRY}/${OWNER}/${REPO}-cpu:latest \
|
||||
-f docker/cpu/Dockerfile \
|
||||
--push .
|
||||
|
||||
# Build GPU image with multi-platform support
|
||||
echo "Building GPU image..."
|
||||
docker buildx build --platform linux/amd64,linux/arm64 \
|
||||
-t ${REGISTRY}/${OWNER}/${REPO}-gpu:${VERSION} \
|
||||
-t ${REGISTRY}/${OWNER}/${REPO}-gpu:latest \
|
||||
-f docker/gpu/Dockerfile \
|
||||
--push .
|
||||
# Build both CPU and GPU images using docker buildx bake
|
||||
echo "Building CPU and GPU images..."
|
||||
VERSION=$VERSION docker buildx bake --push
|
||||
|
||||
echo "Build complete!"
|
||||
echo "Created images:"
|
||||
echo "- ${REGISTRY}/${OWNER}/${REPO}-cpu:${VERSION} (linux/amd64, linux/arm64)"
|
||||
echo "- ${REGISTRY}/${OWNER}/${REPO}-cpu:latest (linux/amd64, linux/arm64)"
|
||||
echo "- ${REGISTRY}/${OWNER}/${REPO}-gpu:${VERSION} (linux/amd64, linux/arm64)"
|
||||
echo "- ${REGISTRY}/${OWNER}/${REPO}-gpu:latest (linux/amd64, linux/arm64)"
|
||||
echo "Created images with version: $VERSION"
|
||||
|
|
|
@ -3,14 +3,15 @@ FROM --platform=$BUILDPLATFORM python:3.10-slim
|
|||
# Install dependencies and check espeak location
|
||||
RUN apt-get update && apt-get install -y \
|
||||
espeak-ng \
|
||||
espeak-ng-data \
|
||||
git \
|
||||
libsndfile1 \
|
||||
curl \
|
||||
ffmpeg \
|
||||
&& dpkg -L espeak-ng \
|
||||
&& find / -name "espeak-ng-data" \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
&& rm -rf /var/lib/apt/lists/* \
|
||||
&& mkdir -p /usr/share/espeak-ng-data \
|
||||
&& ln -s /usr/lib/x86_64-linux-gnu/espeak-ng-data/* /usr/share/espeak-ng-data/
|
||||
|
||||
# Install UV using the installer script
|
||||
RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
||||
|
@ -20,9 +21,7 @@ RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
|||
# Create non-root user and set up directories and permissions
|
||||
RUN useradd -m -u 1000 appuser && \
|
||||
mkdir -p /app/api/src/models/v1_0 && \
|
||||
chown -R appuser:appuser /app && \
|
||||
chown -R appuser:appuser /lib/x86_64-linux-gnu/espeak-ng-data
|
||||
|
||||
chown -R appuser:appuser /app
|
||||
|
||||
USER appuser
|
||||
WORKDIR /app
|
||||
|
@ -33,17 +32,13 @@ COPY --chown=appuser:appuser pyproject.toml ./pyproject.toml
|
|||
# Install dependencies
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv venv && \
|
||||
uv sync --extra cpu --no-install-project
|
||||
uv sync --extra cpu
|
||||
|
||||
# Copy project files including models
|
||||
COPY --chown=appuser:appuser api ./api
|
||||
COPY --chown=appuser:appuser web ./web
|
||||
COPY --chown=appuser:appuser docker/scripts/download_model.* ./
|
||||
|
||||
# Install project
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv sync --extra cpu
|
||||
|
||||
# Set environment variables
|
||||
ENV PYTHONUNBUFFERED=1 \
|
||||
PYTHONPATH=/app:/app/api \
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
name: kokoro-tts
|
||||
name: kokoro-fastapi-cpu
|
||||
services:
|
||||
kokoro-tts:
|
||||
# image: ghcr.io/remsky/kokoro-fastapi-cpu:v0.2.0
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: docker/cpu/Dockerfile
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
FROM --platform=$BUILDPLATFORM nvidia/cuda:12.1.0-cudnn8-runtime-ubuntu22.04
|
||||
|
||||
FROM --platform=$BUILDPLATFORM nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04
|
||||
# Set non-interactive frontend
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
|
@ -8,30 +7,26 @@ RUN apt-get update && apt-get install -y \
|
|||
python3.10 \
|
||||
python3.10-venv \
|
||||
espeak-ng \
|
||||
espeak-ng-data \
|
||||
git \
|
||||
libsndfile1 \
|
||||
curl \
|
||||
ffmpeg \
|
||||
&& ls -la /usr/lib/x86_64-linux-gnu/espeak-ng-data \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/* \
|
||||
&& mkdir -p /usr/share/espeak-ng-data \
|
||||
&& ln -s /usr/lib/x86_64-linux-gnu/espeak-ng-data/* /usr/share/espeak-ng-data/
|
||||
|
||||
# Create user and set up permissions
|
||||
RUN useradd -m -u 1000 appuser && \
|
||||
mkdir -p /app/api/src/models/v1_0 && \
|
||||
chown -R appuser:appuser /app && \
|
||||
chown -R appuser:appuser /usr/lib/x86_64-linux-gnu/espeak-ng-data
|
||||
|
||||
|
||||
# Rest of your Dockerfile...
|
||||
|
||||
# Install UV in a separate step
|
||||
# Install UV using the installer script
|
||||
RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
||||
mv /root/.local/bin/uv /usr/local/bin/ && \
|
||||
mv /root/.local/bin/uvx /usr/local/bin/
|
||||
mv /root/.local/bin/uvx /usr/local/bin/ && \
|
||||
useradd -m -u 1000 appuser && \
|
||||
mkdir -p /app/api/src/models/v1_0 && \
|
||||
chown -R appuser:appuser /app
|
||||
|
||||
USER appuser
|
||||
WORKDIR /app
|
||||
|
||||
# Copy dependency files
|
||||
COPY --chown=appuser:appuser pyproject.toml ./pyproject.toml
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
name: kokoro-tts
|
||||
name: kokoro-tts-gpu
|
||||
services:
|
||||
kokoro-tts:
|
||||
# image: ghcr.io/remsky/kokoro-fastapi-gpu:v0.2.0
|
||||
|
|
|
@ -1,45 +0,0 @@
|
|||
[project]
|
||||
name = "kokoro-fastapi"
|
||||
version = "0.1.0"
|
||||
description = "FastAPI TTS Service"
|
||||
readme = "../README.md"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
# Core dependencies
|
||||
"fastapi==0.115.6",
|
||||
"uvicorn==0.34.0",
|
||||
"click>=8.0.0",
|
||||
"pydantic==2.10.4",
|
||||
"pydantic-settings==2.7.0",
|
||||
"python-dotenv==1.0.1",
|
||||
"sqlalchemy==2.0.27",
|
||||
|
||||
# ML/DL Base
|
||||
"numpy>=1.26.0",
|
||||
"scipy==1.14.1",
|
||||
"onnxruntime==1.20.1",
|
||||
|
||||
# Audio processing
|
||||
"soundfile==0.13.0",
|
||||
|
||||
# Text processing
|
||||
"phonemizer==3.3.0",
|
||||
"regex==2024.11.6",
|
||||
|
||||
# Utilities
|
||||
"aiofiles==23.2.1",
|
||||
"tqdm==4.67.1",
|
||||
"requests==2.32.3",
|
||||
"munch==4.0.0",
|
||||
"tiktoken==0.8.0",
|
||||
"loguru==0.7.3",
|
||||
"pydub>=0.25.1",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
test = [
|
||||
"pytest==8.0.0",
|
||||
"httpx==0.26.0",
|
||||
"pytest-asyncio==0.23.5",
|
||||
"ruff==0.9.1",
|
||||
]
|
|
@ -28,28 +28,23 @@ dependencies = [
|
|||
"munch==4.0.0",
|
||||
"tiktoken==0.8.0",
|
||||
"loguru==0.7.3",
|
||||
# "transformers==4.47.1",
|
||||
"openai>=1.59.6",
|
||||
# "ebooklib>=0.18",
|
||||
# "html2text>=2024.2.26",
|
||||
"pydub>=0.25.1",
|
||||
"matplotlib>=3.10.0",
|
||||
"mutagen>=1.47.0",
|
||||
"psutil>=6.1.1",
|
||||
"kokoro==0.7.6",
|
||||
'misaki[en,ja,ko,zh,vi]==0.7.6',
|
||||
"kokoro==0.7.9",
|
||||
'misaki[en,ja,ko,zh,vi]==0.7.9',
|
||||
"spacy>=3.7.6",
|
||||
"en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl"
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
gpu = [
|
||||
"torch==2.5.1+cu121",
|
||||
#"onnxruntime-gpu==1.20.1",
|
||||
"torch==2.6.0+cu124",
|
||||
]
|
||||
cpu = [
|
||||
"torch==2.5.1",
|
||||
#"onnxruntime==1.20.1",
|
||||
"torch==2.6.0",
|
||||
]
|
||||
test = [
|
||||
"pytest==8.0.0",
|
||||
|
@ -81,7 +76,7 @@ explicit = true
|
|||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cuda"
|
||||
url = "https://download.pytorch.org/whl/cu121"
|
||||
url = "https://download.pytorch.org/whl/cu124"
|
||||
explicit = true
|
||||
|
||||
[build-system]
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
[pytest]
|
||||
testpaths = api/tests ui/tests
|
||||
testpaths = api/tests
|
||||
python_files = test_*.py
|
||||
addopts = -v --tb=short --cov=api --cov=ui --cov-report=term-missing --cov-config=.coveragerc
|
||||
addopts = -v --tb=short --cov=api --cov-report=term-missing --cov-config=.coveragerc
|
||||
pythonpath = .
|
||||
|
|
Loading…
Add table
Reference in a new issue