From cf61cfa0052d0acdfad44a3e35feb871cd79de12 Mon Sep 17 00:00:00 2001
From: remsky
Date: Wed, 1 Jan 2025 03:41:23 -0700
Subject: [PATCH 1/3] Update commit hash to include af_sky

---
 docker-compose.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docker-compose.yml b/docker-compose.yml
index aacb121..60af4f3 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -8,7 +8,7 @@ services:
       sh -c "
         if [ -z \"$(ls -A .)\" ]; then
           git clone https://huggingface.co/hexgrad/Kokoro-82M . && \
-          git checkout 8228a351f87c8a6076502c1e3b7e72e821ebec9a;
+          git checkout 7e9ebc5be7f66a1843b585b63d19d55b5d58ce30;
           touch .cloned;
         else
           touch .cloned;

From 7938de0f4ad632933d6131170fa2aecd10236cc0 Mon Sep 17 00:00:00 2001
From: remsky
Date: Wed, 1 Jan 2025 03:41:52 -0700
Subject: [PATCH 2/3] Update docker-compose.cpu.yml

---
 docker-compose.cpu.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docker-compose.cpu.yml b/docker-compose.cpu.yml
index 02aa674..b581639 100644
--- a/docker-compose.cpu.yml
+++ b/docker-compose.cpu.yml
@@ -8,7 +8,7 @@ services:
       sh -c "
         if [ -z \"$(ls -A .)\" ]; then
           git clone https://huggingface.co/hexgrad/Kokoro-82M . && \
-          git checkout 8228a351f87c8a6076502c1e3b7e72e821ebec9a;
+          git checkout 7e9ebc5be7f66a1843b585b63d19d55b5d58ce30;
           touch .cloned;
        else
           touch .cloned;

From 53cf71c151a5b26eb54da5d29ed55a4254e3e54b Mon Sep 17 00:00:00 2001
From: remsky
Date: Wed, 1 Jan 2025 17:38:22 -0700
Subject: [PATCH 3/3] -Removed commit lock on HF repo -Warm start added to
 model initialization -Layer caching tweaks to dockerfile

---
 .coveragerc                   |   5 +-
 .dockerignore                 |  41 ++++++++++++
 .gitignore                    |   2 +-
 CHANGELOG.md                  |  27 ++++++++
 Dockerfile                    |  14 ++--
 README.md                     |   6 +-
 api/src/main.py               |  16 ++---
 api/src/services/tts.py       | 103 ++++++++++++++++------------
 api/tests/conftest.py         |  15 +++++
 api/tests/test_main.py        | 123 +++++++++++++++++++++++++++-------
 api/tests/test_tts_service.py |  33 ++++-----
 docker-compose.cpu.yml        |   9 ++-
 docker-compose.yml            |  21 +++++-
 13 files changed, 301 insertions(+), 114 deletions(-)
 create mode 100644 .dockerignore
 create mode 100644 CHANGELOG.md

diff --git a/.coveragerc b/.coveragerc
index f422eda..4072f19 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -1,6 +1,9 @@
 [run]
 source = api
-omit = Kokoro-82M/*
+omit =
+    Kokoro-82M/*
+    MagicMock/*
+    test_*.py
 
 [report]
 exclude_lines =

diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000..b456f25
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,41 @@
+# Version control
+.git
+.gitignore
+
+# Python
+__pycache__
+*.pyc
+*.pyo
+*.pyd
+.Python
+*.py[cod]
+*$py.class
+.pytest_cache
+.coverage
+.coveragerc
+
+# Environment
+# .env
+.venv
+env/
+venv/
+ENV/
+
+# IDE
+.idea
+.vscode
+*.swp
+*.swo
+
+# Project specific
+examples/
+Kokoro-82M/
+ui/
+tests/
+*.md
+*.txt
+!requirements.txt
+
+# Docker
+Dockerfile*
+docker-compose*

diff --git a/.gitignore b/.gitignore
index 1d9db35..98b9187 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,6 @@
 output/
-
+ui/data/*
 
 *.db
 *.pyc

diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000..36715cd
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,27 @@
+# Changelog
+
+Notable changes to this project will be documented in this file.
+
+## 2025-01-01
+
+### Modified
+#### Configuration Changes
+- Updated Docker configurations:
+  - Changes to `Dockerfile`:
+    - Improved layer caching by separating dependency and code layers
+  - Updates to `docker-compose.yml` and `docker-compose.cpu.yml`:
+    - Removed commit lock from model fetching to allow automatic model updates from HF
+    - Added git index lock cleanup
+
+#### API Changes
+- Modified `api/src/main.py` to initialize and warm up the model via `TTSModel.initialize()` at startup
+- Updated TTS service implementation in `api/src/services/tts.py`:
+  - Added device management for better resource control:
+    - Voices are now copied from the model repository to the `api/src/voices` directory for persistence
+  - Refactored voice pack handling:
+    - Removed static voice pack dictionary
+    - On-demand voice loading from disk
+  - Added model warm-up functionality:
+    - Model now initializes with a dummy text generation
+    - Uses default voice (`af.pt`) for warm-up
+    - Model is ready for inference on first request

diff --git a/Dockerfile b/Dockerfile
index e06d314..7d70af9 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -17,25 +17,25 @@ RUN pip3 install --no-cache-dir torch==2.5.1 --extra-index-url https://download.
 COPY requirements.txt .
 RUN pip3 install --no-cache-dir -r requirements.txt
 
-# Copy application code and model
-COPY . /app/
-
 # Set working directory
 WORKDIR /app
 
-# Run with Python unbuffered output for live logging
-ENV PYTHONUNBUFFERED=1
-
 # Create non-root user
 RUN useradd -m -u 1000 appuser
 
-# Create directories and set permissions
+# Create model directory and set ownership
 RUN mkdir -p /app/Kokoro-82M && \
     chown -R appuser:appuser /app
 
 # Switch to non-root user
 USER appuser
 
+# Run with Python unbuffered output for live logging
+ENV PYTHONUNBUFFERED=1
+
+# Copy only necessary application code
+COPY --chown=appuser:appuser api /app/api
+
 # Set Python path (app first for our imports, then model dir for model imports)
 ENV PYTHONPATH=/app:/app/Kokoro-82M

diff --git a/README.md b/README.md
index fa70f3b..a626cc0 100644
--- a/README.md
+++ b/README.md
@@ -3,9 +3,9 @@

 # Kokoro TTS API
-[![Model Commit](https://img.shields.io/badge/model--commit-a67f113-blue)](https://huggingface.co/hexgrad/Kokoro-82M/tree/8228a351f87c8a6076502c1e3b7e72e821ebec9a)
-[![Tests](https://img.shields.io/badge/tests-36%20passed-darkgreen)]()
-[![Coverage](https://img.shields.io/badge/coverage-91%25-darkgreen)]()
+[![Tests](https://img.shields.io/badge/tests-37%20passed-darkgreen)]()
+[![Coverage](https://img.shields.io/badge/coverage-81%25-darkgreen)]()
+[![Tested at Model Commit](https://img.shields.io/badge/last--tested--model--commit-a67f113-blue)](https://huggingface.co/hexgrad/Kokoro-82M/tree/c3b0d86e2a980e027ef71c28819ea02e351c2667)
 
 FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model, providing an OpenAI-compatible endpoint with:
 - NVIDIA GPU accelerated inference (or CPU) option

diff --git a/api/src/main.py b/api/src/main.py
index 9115ade..ebe2f53 100644
--- a/api/src/main.py
+++ b/api/src/main.py
@@ -19,18 +19,10 @@ async def lifespan(app: FastAPI):
     """Lifespan context manager for model initialization"""
     logger.info("Loading TTS model and voice packs...")
 
-    # Initialize the main model
-    model, device = TTSModel.get_instance()
-    logger.info(f"Model loaded on {device}")
-
-    # Initialize all voice packs
-    tts_service = TTSService()
-    voices = tts_service.list_voices()
-    for voice in voices:
-        logger.info(f"Loading voice pack: {voice}")
-        TTSModel.get_voicepack(voice)
-
-    logger.info("All models and voice packs loaded successfully")
+    # Initialize the main model with warm-up
+    model, voicepack_count = TTSModel.initialize()
+    logger.info(f"Model loaded and warmed up on {TTSModel._device}")
+    logger.info(f"{voicepack_count} voice packs loaded successfully")
 
     yield

diff --git a/api/src/services/tts.py b/api/src/services/tts.py
index 686ef5d..de76836 100644
--- a/api/src/services/tts.py
+++ b/api/src/services/tts.py
@@ -21,43 +21,63 @@ enc = tiktoken.get_encoding("cl100k_base")
 
 class TTSModel:
     _instance = None
+    _device = None
     _lock = threading.Lock()
-    _voicepacks = {}
 
     # Directory for all voices (copied base voices, and any created combined voices)
     VOICES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "voices")
 
     @classmethod
-    def get_instance(cls):
-        if cls._instance is None:
-            with cls._lock:
-                if cls._instance is None:
-                    device = "cuda" if torch.cuda.is_available() else "cpu"
-                    print(f"Initializing model on {device}")
-                    model_path = os.path.join(settings.model_dir, settings.model_path)
-                    model = build_model(model_path, device)
-                    # Note: RNN memory optimization is handled internally by the model
-                    cls._instance = (model, device)
-        return cls._instance
-
-    @classmethod
-    def get_voicepack(cls, voice_name: str) -> torch.Tensor:
-        """Get a voice pack from the voices directory."""
-        model, device = cls.get_instance()
-        if voice_name not in cls._voicepacks:
-            try:
-                voice_path = os.path.join(cls.VOICES_DIR, f"{voice_name}.pt")
-                if not os.path.exists(voice_path):
-                    raise FileNotFoundError(f"Voice file not found: {voice_name}")
-
-                voicepack = torch.load(voice_path, map_location=device, weights_only=True)
-                cls._voicepacks[voice_name] = voicepack
-            except Exception as e:
-                logger.error(f"Error loading voice {voice_name}: {str(e)}")
-                if voice_name != "af":
-                    return cls.get_voicepack("af")
-                raise
-        return cls._voicepacks[voice_name]
+    def initialize(cls):
+        """Initialize and warm up the model"""
+        with cls._lock:
+            if cls._instance is None:
+                # Initialize model
+                cls._device = "cuda" if torch.cuda.is_available() else "cpu"
+                logger.info(f"Initializing model on {cls._device}")
+                model_path = os.path.join(settings.model_dir, settings.model_path)
+                model = build_model(model_path, cls._device)
+                cls._instance = model
+
+                # Ensure voices directory exists
+                os.makedirs(cls.VOICES_DIR, exist_ok=True)
+
+                # Copy base voices to local directory
+                base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
+                if os.path.exists(base_voices_dir):
+                    for file in os.listdir(base_voices_dir):
+                        if file.endswith(".pt"):
+                            voice_name = file[:-3]
+                            voice_path = os.path.join(cls.VOICES_DIR, file)
+                            if not os.path.exists(voice_path):
+                                try:
+                                    logger.info(f"Copying base voice {voice_name} to voices directory")
+                                    base_path = os.path.join(base_voices_dir, file)
+                                    voicepack = torch.load(base_path, map_location=cls._device, weights_only=True)
+                                    torch.save(voicepack, voice_path)
+                                except Exception as e:
+                                    logger.error(f"Error copying voice {voice_name}: {str(e)}")
+
+                # Warm up with default voice
+                try:
+                    dummy_text = "Hello"
+                    voice_path = os.path.join(cls.VOICES_DIR, "af.pt")
+                    dummy_voicepack = torch.load(voice_path, map_location=cls._device, weights_only=True)
+                    generate(model, dummy_text, dummy_voicepack, lang='a', speed=1.0)
+                    logger.info("Model warm-up complete")
+                except Exception as e:
+                    logger.warning(f"Model warm-up failed: {e}")
+
+            # Count voices in directory for validation
+            voice_count = len([f for f in os.listdir(cls.VOICES_DIR) if f.endswith('.pt')])
+            return cls._instance, voice_count
+
+    @classmethod
+    def get_instance(cls):
+        """Get the initialized instance or raise an error"""
+        if cls._instance is None:
+            raise RuntimeError("Model not initialized. Call initialize() first.")
+        return cls._instance, cls._device
 
 
 class TTSService:
@@ -79,9 +99,9 @@ class TTSService:
                     voice_path = os.path.join(TTSModel.VOICES_DIR, file)
                     if not os.path.exists(voice_path):
                         try:
-                            base_path = os.path.join(base_voices_dir, file)
                             logger.info(f"Copying base voice {voice_name} to voices directory")
-                            voicepack = torch.load(base_path, map_location=TTSModel.get_instance()[1], weights_only=True)
+                            base_path = os.path.join(base_voices_dir, file)
+                            voicepack = torch.load(base_path, map_location=TTSModel._device, weights_only=True)
                             torch.save(voicepack, voice_path)
                         except Exception as e:
                             logger.error(f"Error copying voice {voice_name}: {str(e)}")
@@ -114,21 +134,21 @@ class TTSService:
             if not text:
                 raise ValueError("Text is empty after preprocessing")
 
-            # Get model instance
-            model, device = TTSModel.get_instance()
-
-            # Load voice
+            # Check voice exists
             voice_path = self._get_voice_path(voice)
             if not voice_path:
                 raise ValueError(f"Voice not found: {voice}")
-
-            voicepack = torch.load(voice_path, map_location=device, weights_only=True)
+
+            # Load model and voice
+            model = TTSModel._instance
+            voicepack = torch.load(voice_path, map_location=TTSModel._device, weights_only=True)
 
             # Generate audio with or without stitching
             if stitch_long_output:
                 chunks = self._split_text(text)
                 audio_chunks = []
 
+                # Process all chunks with same model/voicepack instance
                 for i, chunk in enumerate(chunks):
                     try:
                         # Validate phonemization first
@@ -204,12 +224,9 @@ class TTSService:
         v_name: List[str] = []
 
         for voice in voices:
-            voice_path = self._get_voice_path(voice)
-            if not voice_path:
-                raise ValueError(f"Voice not found: {voice}")
-
             try:
-                voicepack = torch.load(voice_path, map_location=TTSModel.get_instance()[1], weights_only=True)
+                voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice}.pt")
+                voicepack = torch.load(voice_path, map_location=TTSModel._device, weights_only=True)
                 t_voices.append(voicepack)
                 v_name.append(voice)
             except Exception as e:

diff --git a/api/tests/conftest.py b/api/tests/conftest.py
index 6648c15..5972003 100644
--- a/api/tests/conftest.py
+++ b/api/tests/conftest.py
@@ -1,8 +1,23 @@
+import os
+import shutil
 import sys
 from unittest.mock import Mock, patch
 
 import pytest
 
+def cleanup_mock_dirs():
+    """Clean up any MagicMock directories created during tests"""
+    mock_dir = "MagicMock"
+    if os.path.exists(mock_dir):
+        shutil.rmtree(mock_dir)
+
+@pytest.fixture(autouse=True)
+def cleanup():
+    """Automatically clean up before and after each test"""
+    cleanup_mock_dirs()
+    yield
+    cleanup_mock_dirs()
+
 # Mock torch and other ML modules before they're imported
 sys.modules["torch"] = Mock()
 sys.modules["transformers"] = Mock()

diff --git a/api/tests/test_main.py b/api/tests/test_main.py
index 9493d27..4eedc64 100644
--- a/api/tests/test_main.py
+++ b/api/tests/test_main.py
@@ -1,45 +1,116 @@
-"""Tests for main FastAPI application"""
+"""Tests for FastAPI application"""
 
 import pytest
 from unittest.mock import patch, MagicMock
 from fastapi.testclient import TestClient
-
-from api.src.main import app
+from api.src.main import app, lifespan
 
 
 @pytest.fixture
-def client():
+def test_client():
     """Create a test client"""
     return TestClient(app)
 
 
-def test_health_check(client):
+def test_health_check(test_client):
     """Test health check endpoint"""
-    response = client.get("/health")
+    response = test_client.get("/health")
     assert response.status_code == 200
     assert response.json() == {"status": "healthy"}
 
 
-def test_test_endpoint(client):
-    """Test the test endpoint"""
-    response = client.get("/v1/test")
-    assert response.status_code == 200
-    assert response.json() == {"status": "ok"}
+@pytest.mark.asyncio
+@patch('api.src.main.TTSModel')
+@patch('api.src.main.logger')
+async def test_lifespan_successful_warmup(mock_logger, mock_tts_model):
+    """Test successful model warmup in lifespan"""
+    # Mock the model initialization with model info and voicepack count
+    mock_model = MagicMock()
+    # Mock file system for voice counting
+    mock_tts_model.VOICES_DIR = "/mock/voices"
+    with patch('os.listdir', return_value=['voice1.pt', 'voice2.pt', 'voice3.pt']):
+        mock_tts_model.initialize.return_value = (mock_model, 3)  # 3 voice files
+        mock_tts_model._device = "cuda"  # Set device class variable
+
+        # Create an async generator from the lifespan context manager
+        async_gen = lifespan(MagicMock())
+        # Start the context manager
+        await async_gen.__aenter__()
+
+        # Verify the expected logging sequence
+        mock_logger.info.assert_any_call("Loading TTS model and voice packs...")
+        mock_logger.info.assert_any_call("Model loaded and warmed up on cuda")
+        mock_logger.info.assert_any_call("3 voice packs loaded successfully")
+
+        # Verify model initialization was called
+        mock_tts_model.initialize.assert_called_once()
+
+        # Clean up
+        await async_gen.__aexit__(None, None, None)
 
 
-def test_cors_headers(client):
-    """Test CORS headers are present"""
-    response = client.get(
-        "/health",
-        headers={"Origin": "http://testserver"},
-    )
-    assert response.status_code == 200
-    assert response.headers["access-control-allow-origin"] == "*"
+@pytest.mark.asyncio
+@patch('api.src.main.TTSModel')
+@patch('api.src.main.logger')
+async def test_lifespan_failed_warmup(mock_logger, mock_tts_model):
+    """Test failed model warmup in lifespan"""
+    # Mock the model initialization to fail
+    mock_tts_model.initialize.side_effect = Exception("Failed to initialize model")
model") + + # Create an async generator from the lifespan context manager + async_gen = lifespan(MagicMock()) + + # Verify the exception is raised + with pytest.raises(Exception, match="Failed to initialize model"): + await async_gen.__aenter__() + + # Verify the expected logging sequence + mock_logger.info.assert_called_with("Loading TTS model and voice packs...") + + # Clean up + await async_gen.__aexit__(None, None, None) -def test_openapi_schema(client): - """Test OpenAPI schema is accessible""" - response = client.get("/openapi.json") - assert response.status_code == 200 - schema = response.json() - assert schema["info"]["title"] == app.title - assert schema["info"]["version"] == app.version +@pytest.mark.asyncio +@patch('api.src.main.TTSModel') +async def test_lifespan_cuda_warmup(mock_tts_model): + """Test model warmup specifically on CUDA""" + # Mock the model initialization with CUDA and voicepacks + mock_model = MagicMock() + # Mock file system for voice counting + mock_tts_model.VOICES_DIR = "/mock/voices" + with patch('os.listdir', return_value=['voice1.pt', 'voice2.pt']): + mock_tts_model.initialize.return_value = (mock_model, 2) # 2 voice files + mock_tts_model._device = "cuda" # Set device class variable + + # Create an async generator from the lifespan context manager + async_gen = lifespan(MagicMock()) + await async_gen.__aenter__() + + # Verify model was initialized + mock_tts_model.initialize.assert_called_once() + + # Clean up + await async_gen.__aexit__(None, None, None) + + +@pytest.mark.asyncio +@patch('api.src.main.TTSModel') +async def test_lifespan_cpu_fallback(mock_tts_model): + """Test model warmup falling back to CPU""" + # Mock the model initialization with CPU and voicepacks + mock_model = MagicMock() + # Mock file system for voice counting + mock_tts_model.VOICES_DIR = "/mock/voices" + with patch('os.listdir', return_value=['voice1.pt', 'voice2.pt', 'voice3.pt', 'voice4.pt']): + mock_tts_model.initialize.return_value = (mock_model, 4) # 4 voice files + mock_tts_model._device = "cpu" # Set device class variable + + # Create an async generator from the lifespan context manager + async_gen = lifespan(MagicMock()) + await async_gen.__aenter__() + + # Verify model was initialized + mock_tts_model.initialize.assert_called_once() + + # Clean up + await async_gen.__aexit__(None, None, None) diff --git a/api/tests/test_tts_service.py b/api/tests/test_tts_service.py index 533c514..a0273ad 100644 --- a/api/tests/test_tts_service.py +++ b/api/tests/test_tts_service.py @@ -131,9 +131,9 @@ def test_model_initialization_cuda(mock_build_model, mock_cuda_available): mock_build_model.return_value = mock_model TTSModel._instance = None # Reset singleton - model, device = TTSModel.get_instance() + model, voice_count = TTSModel.initialize() - assert device == "cuda" + assert TTSModel._device == "cuda" # Check the class variable instead assert model == mock_model mock_build_model.assert_called_once() @@ -147,31 +147,34 @@ def test_model_initialization_cpu(mock_build_model, mock_cuda_available): mock_build_model.return_value = mock_model TTSModel._instance = None # Reset singleton - model, device = TTSModel.get_instance() + model, voice_count = TTSModel.initialize() - assert device == "cpu" + assert TTSModel._device == "cpu" # Check the class variable instead assert model == mock_model mock_build_model.assert_called_once() -@patch('os.path.exists') -@patch('api.src.services.tts.torch.load') -@patch('os.path.join') -def test_voicepack_loading_error(mock_join, mock_torch_load, 
+@patch('api.src.services.tts.TTSService._get_voice_path')
+@patch('api.src.services.tts.TTSModel.get_instance')
+def test_voicepack_loading_error(mock_get_instance, mock_get_voice_path):
     """Test voicepack loading error handling"""
-    mock_join.side_effect = lambda *args: '/'.join(args)
-    mock_exists.side_effect = lambda x: False  # All voice files don't exist
+    mock_get_voice_path.return_value = None
+    mock_get_instance.return_value = (MagicMock(), "cpu")
 
-    TTSModel._instance = (MagicMock(), "cpu")  # Mock instance
     TTSModel._voicepacks = {}  # Reset voicepacks
 
-    with pytest.raises(FileNotFoundError, match="Voice file not found: af"):
-        TTSModel.get_voicepack("nonexistent_voice")
+    service = TTSService(start_worker=False)
+    with pytest.raises(ValueError, match="Voice not found: nonexistent_voice"):
+        service._generate_audio("test", "nonexistent_voice", 1.0)
 
 
-def test_save_audio(tts_service, sample_audio, tmp_path):
+@patch('api.src.services.tts.TTSModel')
+def test_save_audio(mock_tts_model, tts_service, sample_audio, tmp_path):
     """Test saving audio to file"""
-    output_path = os.path.join(tmp_path, "test_output", "audio.wav")
+    output_dir = os.path.join(tmp_path, "test_output")
+    os.makedirs(output_dir, exist_ok=True)
+    output_path = os.path.join(output_dir, "audio.wav")
+
     tts_service._save_audio(sample_audio, output_path)
 
     assert os.path.exists(output_path)

diff --git a/docker-compose.cpu.yml b/docker-compose.cpu.yml
index b581639..2daeb46 100644
--- a/docker-compose.cpu.yml
+++ b/docker-compose.cpu.yml
@@ -6,18 +6,21 @@ services:
     working_dir: /app/Kokoro-82M
     command: >
       sh -c "
+        rm -f .git/index.lock;
         if [ -z \"$(ls -A .)\" ]; then
           git clone https://huggingface.co/hexgrad/Kokoro-82M . && \
-          git checkout 7e9ebc5be7f66a1843b585b63d19d55b5d58ce30;
           touch .cloned;
         else
+          rm -f .git/index.lock && \
+          git checkout main && \
+          git pull origin main && \
           touch .cloned;
         fi;
         tail -f /dev/null
       "
     healthcheck:
       test: ["CMD", "test", "-f", ".cloned"]
-      interval: 1s
+      interval: 3s
       timeout: 1s
       retries: 120
       start_period: 1s

diff --git a/docker-compose.yml b/docker-compose.yml
index 60af4f3..565d158 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -6,18 +6,21 @@ services:
     working_dir: /app/Kokoro-82M
     command: >
       sh -c "
+        rm -f .git/index.lock;
         if [ -z \"$(ls -A .)\" ]; then
           git clone https://huggingface.co/hexgrad/Kokoro-82M . && \
-          git checkout 7e9ebc5be7f66a1843b585b63d19d55b5d58ce30;
           touch .cloned;
         else
+          rm -f .git/index.lock && \
+          git checkout main && \
+          git pull origin main && \
           touch .cloned;
         fi;
         tail -f /dev/null
      "
     healthcheck:
       test: ["CMD", "test", "-f", ".cloned"]
-      interval: 1s
+      interval: 3s
       timeout: 1s
       retries: 120
       start_period: 1s
@@ -42,3 +45,15 @@
     depends_on:
       model-fetcher:
         condition: service_healthy
+
+  # # Gradio UI service
+  # gradio-ui:
+  #   build:
+  #     context: ./ui
+  #   ports:
+  #     - "7860:7860"
+  #   volumes:
+  #     - ./ui/data:/app/ui/data
+  #     - ./ui/app.py:/app/app.py  # Mount app.py for hot reload
+  #   environment:
+  #     - GRADIO_WATCH=True  # Enable hot reloading
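
Editor's note (appended for review; not part of the patch series): patch 3 moves model construction, voice copying, and a dummy generation into the FastAPI lifespan hook via `TTSModel.initialize()`, so the server does not begin accepting traffic until warm-up has finished. A deployment can therefore treat the `/health` endpoint exercised in `test_main.py` as a readiness signal. Below is a minimal readiness-poll sketch using only the Python standard library; the base URL and port are assumptions, since the API service's port mapping does not appear in these patches.

```python
"""Readiness-poll sketch (illustration only, not part of the patches).

Because the lifespan hook blocks until TTSModel.initialize() returns, a
200 response from /health implies the model is already loaded and warmed up.
"""
import time
import urllib.error
import urllib.request

BASE_URL = "http://localhost:8000"  # assumed port; adjust to your deployment


def wait_until_healthy(timeout_s: float = 120.0, interval_s: float = 1.0) -> bool:
    """Poll GET /health until it returns 200; return False if the timeout elapses."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(f"{BASE_URL}/health", timeout=5) as resp:
                if resp.status == 200:
                    return True
        except (urllib.error.URLError, OSError):
            pass  # server not accepting connections yet; keep polling
        time.sleep(interval_s)
    return False


if __name__ == "__main__":
    print("ready" if wait_until_healthy() else "timed out waiting for /health")
```

The same probe could back a container healthcheck on the API service, mirroring the `.cloned` file check the compose files use to gate `depends_on` for the model-fetcher.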