Merge branch 'master' into feat/gradio-gui

remsky 2025-01-01 17:39:54 -07:00
commit 19321eabb2
13 changed files with 301 additions and 114 deletions

.coveragerc

@@ -1,6 +1,9 @@
 [run]
 source = api
-omit = Kokoro-82M/*
+omit =
+    Kokoro-82M/*
+    MagicMock/*
+    test_*.py

 [report]
 exclude_lines =

.dockerignore Normal file

@@ -0,0 +1,41 @@
# Version control
.git
.gitignore
# Python
__pycache__
*.pyc
*.pyo
*.pyd
.Python
*.py[cod]
*$py.class
.pytest_cache
.coverage
.coveragerc
# Environment
# .env
.venv
env/
venv/
ENV/
# IDE
.idea
.vscode
*.swp
*.swo
# Project specific
examples/
Kokoro-82M/
ui/
tests/
*.md
*.txt
!requirements.txt
# Docker
Dockerfile*
docker-compose*

.gitignore vendored

@@ -1,6 +1,6 @@
 output/
+ui/data/*
 *.db
 *.pyc

CHANGELOG.md Normal file

@@ -0,0 +1,27 @@
# Changelog
Notable changes to this project will be documented in this file.
## 2024-01-09
### Modified
#### Configuration Changes
- Updated Docker configurations:
  - Changes to `Dockerfile`:
    - Improved layer caching by separating dependency and code layers
  - Updates to `docker-compose.yml` and `docker-compose.cpu.yml`:
    - Removed commit lock from model fetching to allow automatic model updates from HF
    - Added git index lock cleanup

#### API Changes
- Modified `api/src/main.py`
- Updated TTS service implementation in `api/src/services/tts.py`:
  - Added device management for better resource control:
    - Voices are now copied from the model repository to the `api/src/voices` directory for persistence
  - Refactored voice pack handling:
    - Removed static voice pack dictionary
    - On-demand voice loading from disk
  - Added model warm-up functionality (see the sketch after this list):
    - Model now initializes with a dummy text generation
    - Uses default voice (`af.pt`) for warm-up
    - Model is ready for inference on first request
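
The warm-up flow described above appears in full in the api/src/services/tts.py diff below. As a minimal sketch of the pattern, assuming the build_model and generate helpers provided by the bundled Kokoro-82M checkout (both are used in that diff):

import torch
from models import build_model  # assumed: provided by the Kokoro-82M checkout on PYTHONPATH
from kokoro import generate     # assumed likewise

def initialize_with_warmup(model_path: str, voices_dir: str):
    """Sketch only: load the model once, then run a dummy generation so the
    first real request does not pay the lazy-initialization cost."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = build_model(model_path, device)
    # Default voice (af.pt) is used for the dummy generation
    voicepack = torch.load(f"{voices_dir}/af.pt", map_location=device, weights_only=True)
    generate(model, "Hello", voicepack, lang="a", speed=1.0)
    return model, device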

Dockerfile

@@ -17,25 +17,25 @@ RUN pip3 install --no-cache-dir torch==2.5.1 --extra-index-url https://download.
 COPY requirements.txt .
 RUN pip3 install --no-cache-dir -r requirements.txt

-# Copy application code and model
-COPY . /app/
-
 # Set working directory
 WORKDIR /app

-# Run with Python unbuffered output for live logging
-ENV PYTHONUNBUFFERED=1
-
 # Create non-root user
 RUN useradd -m -u 1000 appuser

-# Create directories and set permissions
+# Create model directory and set ownership
 RUN mkdir -p /app/Kokoro-82M && \
     chown -R appuser:appuser /app

 # Switch to non-root user
 USER appuser

+# Run with Python unbuffered output for live logging
+ENV PYTHONUNBUFFERED=1
+
+# Copy only necessary application code
+COPY --chown=appuser:appuser api /app/api
+
 # Set Python path (app first for our imports, then model dir for model imports)
 ENV PYTHONPATH=/app:/app/Kokoro-82M

README.md

@@ -3,9 +3,9 @@
 </p>

 # Kokoro TTS API

-[![Model Commit](https://img.shields.io/badge/model--commit-a67f113-blue)](https://huggingface.co/hexgrad/Kokoro-82M/tree/8228a351f87c8a6076502c1e3b7e72e821ebec9a)
-[![Tests](https://img.shields.io/badge/tests-36%20passed-darkgreen)]()
-[![Coverage](https://img.shields.io/badge/coverage-91%25-darkgreen)]()
+[![Tests](https://img.shields.io/badge/tests-37%20passed-darkgreen)]()
+[![Coverage](https://img.shields.io/badge/coverage-81%25-darkgreen)]()
+[![Tested at Model Commit](https://img.shields.io/badge/last--tested--model--commit-a67f113-blue)](https://huggingface.co/hexgrad/Kokoro-82M/tree/c3b0d86e2a980e027ef71c28819ea02e351c2667)

 FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model, providing an OpenAI-compatible endpoint with:
 - NVIDIA GPU accelerated inference (or CPU) option
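
Since the README advertises an OpenAI-compatible endpoint, a request shaped like OpenAI's speech API should work against a running instance. A hedged sketch: the /v1/audio/speech route, the 8880 port, and the "kokoro" model name are assumptions not shown in this excerpt (af is the repo's default voice):

import requests

# Assumed route, port, and payload shape (OpenAI-style audio.speech request)
resp = requests.post(
    "http://localhost:8880/v1/audio/speech",
    json={"model": "kokoro", "input": "Hello world!", "voice": "af"},
    timeout=60,
)
resp.raise_for_status()
with open("speech.wav", "wb") as f:
    f.write(resp.content)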

api/src/main.py

@@ -19,18 +19,10 @@ async def lifespan(app: FastAPI):
     """Lifespan context manager for model initialization"""
     logger.info("Loading TTS model and voice packs...")

-    # Initialize the main model
-    model, device = TTSModel.get_instance()
-    logger.info(f"Model loaded on {device}")
-
-    # Initialize all voice packs
-    tts_service = TTSService()
-    voices = tts_service.list_voices()
-    for voice in voices:
-        logger.info(f"Loading voice pack: {voice}")
-        TTSModel.get_voicepack(voice)
-    logger.info("All models and voice packs loaded successfully")
+    # Initialize the main model with warm-up
+    model, voicepack_count = TTSModel.initialize()
+    logger.info(f"Model loaded and warmed up on {TTSModel._device}")
+    logger.info(f"{voicepack_count} voice packs loaded successfully")

     yield
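
For readers unfamiliar with the pattern: lifespan is an async context manager that FastAPI enters on startup and exits on shutdown, so everything before the yield runs exactly once before the first request is served. A minimal sketch of the wiring (the TTSModel import path is taken from this commit's layout):

from contextlib import asynccontextmanager
from fastapi import FastAPI

from api.src.services.tts import TTSModel

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: runs once, before any request is handled
    model, voicepack_count = TTSModel.initialize()
    yield
    # Shutdown: anything after the yield runs when the server stops

app = FastAPI(lifespan=lifespan)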

api/src/services/tts.py

@@ -21,43 +21,63 @@ enc = tiktoken.get_encoding("cl100k_base")

 class TTSModel:
     _instance = None
+    _device = None
     _lock = threading.Lock()
-    _voicepacks = {}

     # Directory for all voices (copied base voices, and any created combined voices)
     VOICES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "voices")

     @classmethod
-    def get_instance(cls):
-        if cls._instance is None:
-            with cls._lock:
-                if cls._instance is None:
-                    device = "cuda" if torch.cuda.is_available() else "cpu"
-                    print(f"Initializing model on {device}")
-                    model_path = os.path.join(settings.model_dir, settings.model_path)
-                    model = build_model(model_path, device)
-                    # Note: RNN memory optimization is handled internally by the model
-                    cls._instance = (model, device)
-        return cls._instance
+    def initialize(cls):
+        """Initialize and warm up the model"""
+        with cls._lock:
+            if cls._instance is None:
+                # Initialize model
+                cls._device = "cuda" if torch.cuda.is_available() else "cpu"
+                logger.info(f"Initializing model on {cls._device}")
+                model_path = os.path.join(settings.model_dir, settings.model_path)
+                model = build_model(model_path, cls._device)
+                cls._instance = model
+
+                # Ensure voices directory exists
+                os.makedirs(cls.VOICES_DIR, exist_ok=True)
+
+                # Copy base voices to local directory
+                base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
+                if os.path.exists(base_voices_dir):
+                    for file in os.listdir(base_voices_dir):
+                        if file.endswith(".pt"):
+                            voice_name = file[:-3]
+                            voice_path = os.path.join(cls.VOICES_DIR, file)
+                            if not os.path.exists(voice_path):
+                                try:
+                                    logger.info(f"Copying base voice {voice_name} to voices directory")
+                                    base_path = os.path.join(base_voices_dir, file)
+                                    voicepack = torch.load(base_path, map_location=cls._device, weights_only=True)
+                                    torch.save(voicepack, voice_path)
+                                except Exception as e:
+                                    logger.error(f"Error copying voice {voice_name}: {str(e)}")
+
+                # Warm up with default voice
+                try:
+                    dummy_text = "Hello"
+                    voice_path = os.path.join(cls.VOICES_DIR, "af.pt")
+                    dummy_voicepack = torch.load(voice_path, map_location=cls._device, weights_only=True)
+                    generate(model, dummy_text, dummy_voicepack, lang='a', speed=1.0)
+                    logger.info("Model warm-up complete")
+                except Exception as e:
+                    logger.warning(f"Model warm-up failed: {e}")
+
+        # Count voices in directory for validation
+        voice_count = len([f for f in os.listdir(cls.VOICES_DIR) if f.endswith('.pt')])
+        return cls._instance, voice_count

     @classmethod
-    def get_voicepack(cls, voice_name: str) -> torch.Tensor:
-        """Get a voice pack from the voices directory."""
-        model, device = cls.get_instance()
-        if voice_name not in cls._voicepacks:
-            try:
-                voice_path = os.path.join(cls.VOICES_DIR, f"{voice_name}.pt")
-                if not os.path.exists(voice_path):
-                    raise FileNotFoundError(f"Voice file not found: {voice_name}")
-                voicepack = torch.load(voice_path, map_location=device, weights_only=True)
-                cls._voicepacks[voice_name] = voicepack
-            except Exception as e:
-                logger.error(f"Error loading voice {voice_name}: {str(e)}")
-                if voice_name != "af":
-                    return cls.get_voicepack("af")
-                raise
-        return cls._voicepacks[voice_name]
+    def get_instance(cls):
+        """Get the initialized instance or raise an error"""
+        if cls._instance is None:
+            raise RuntimeError("Model not initialized. Call initialize() first.")
+        return cls._instance, cls._device


 class TTSService:
@@ -79,9 +99,9 @@ class TTSService:
                     voice_path = os.path.join(TTSModel.VOICES_DIR, file)
                     if not os.path.exists(voice_path):
                         try:
-                            base_path = os.path.join(base_voices_dir, file)
                             logger.info(f"Copying base voice {voice_name} to voices directory")
-                            voicepack = torch.load(base_path, map_location=TTSModel.get_instance()[1], weights_only=True)
+                            base_path = os.path.join(base_voices_dir, file)
+                            voicepack = torch.load(base_path, map_location=TTSModel._device, weights_only=True)
                             torch.save(voicepack, voice_path)
                         except Exception as e:
                             logger.error(f"Error copying voice {voice_name}: {str(e)}")
@@ -114,21 +134,21 @@ class TTSService:
             if not text:
                 raise ValueError("Text is empty after preprocessing")

-            # Get model instance
-            model, device = TTSModel.get_instance()
-
-            # Load voice
+            # Check voice exists
             voice_path = self._get_voice_path(voice)
             if not voice_path:
                 raise ValueError(f"Voice not found: {voice}")
-            voicepack = torch.load(voice_path, map_location=device, weights_only=True)
+
+            # Load model and voice
+            model = TTSModel._instance
+            voicepack = torch.load(voice_path, map_location=TTSModel._device, weights_only=True)

             # Generate audio with or without stitching
             if stitch_long_output:
                 chunks = self._split_text(text)
                 audio_chunks = []

+                # Process all chunks with same model/voicepack instance
                 for i, chunk in enumerate(chunks):
                     try:
                         # Validate phonemization first
@@ -204,12 +224,9 @@ class TTSService:
             v_name: List[str] = []

             for voice in voices:
-                voice_path = self._get_voice_path(voice)
-                if not voice_path:
-                    raise ValueError(f"Voice not found: {voice}")
                 try:
-                    voicepack = torch.load(voice_path, map_location=TTSModel.get_instance()[1], weights_only=True)
+                    voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice}.pt")
+                    voicepack = torch.load(voice_path, map_location=TTSModel._device, weights_only=True)
                     t_voices.append(voicepack)
                     v_name.append(voice)
                 except Exception as e:
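
A note on the new singleton contract above: initialize() must run first (the lifespan hook in api/src/main.py does this), after which any caller may use get_instance(). The intended call order, in sketch form:

from api.src.services.tts import TTSModel

# During startup (done by the FastAPI lifespan hook):
model, voice_count = TTSModel.initialize()   # builds the model, copies voices, warms up

# Later, anywhere in the service code:
model, device = TTSModel.get_instance()      # raises RuntimeError if initialize() never ran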

api/tests/conftest.py

@@ -1,8 +1,23 @@
+import os
+import shutil
 import sys
 from unittest.mock import Mock, patch

 import pytest

+
+def cleanup_mock_dirs():
+    """Clean up any MagicMock directories created during tests"""
+    mock_dir = "MagicMock"
+    if os.path.exists(mock_dir):
+        shutil.rmtree(mock_dir)
+
+
+@pytest.fixture(autouse=True)
+def cleanup():
+    """Automatically clean up before and after each test"""
+    cleanup_mock_dirs()
+    yield
+    cleanup_mock_dirs()
+
 # Mock torch and other ML modules before they're imported
 sys.modules["torch"] = Mock()
 sys.modules["transformers"] = Mock()

api/tests/test_main.py

@@ -1,45 +1,116 @@
-"""Tests for main FastAPI application"""
+"""Tests for FastAPI application"""
 import pytest
 from unittest.mock import patch, MagicMock
 from fastapi.testclient import TestClient
-from api.src.main import app
+from api.src.main import app, lifespan


 @pytest.fixture
-def client():
+def test_client():
     """Create a test client"""
     return TestClient(app)


-def test_health_check(client):
+def test_health_check(test_client):
     """Test health check endpoint"""
-    response = client.get("/health")
+    response = test_client.get("/health")
     assert response.status_code == 200
     assert response.json() == {"status": "healthy"}


-def test_test_endpoint(client):
-    """Test the test endpoint"""
-    response = client.get("/v1/test")
-    assert response.status_code == 200
-    assert response.json() == {"status": "ok"}
-
-
-def test_cors_headers(client):
-    """Test CORS headers are present"""
-    response = client.get(
-        "/health",
-        headers={"Origin": "http://testserver"},
-    )
-    assert response.status_code == 200
-    assert response.headers["access-control-allow-origin"] == "*"
-
-
-def test_openapi_schema(client):
-    """Test OpenAPI schema is accessible"""
-    response = client.get("/openapi.json")
-    assert response.status_code == 200
-    schema = response.json()
-    assert schema["info"]["title"] == app.title
-    assert schema["info"]["version"] == app.version
+@pytest.mark.asyncio
+@patch('api.src.main.TTSModel')
+@patch('api.src.main.logger')
+async def test_lifespan_successful_warmup(mock_logger, mock_tts_model):
+    """Test successful model warmup in lifespan"""
+    # Mock the model initialization with model info and voicepack count
+    mock_model = MagicMock()
+    # Mock file system for voice counting
+    mock_tts_model.VOICES_DIR = "/mock/voices"
+    with patch('os.listdir', return_value=['voice1.pt', 'voice2.pt', 'voice3.pt']):
+        mock_tts_model.initialize.return_value = (mock_model, 3)  # 3 voice files
+        mock_tts_model._device = "cuda"  # Set device class variable
+
+        # Create an async generator from the lifespan context manager
+        async_gen = lifespan(MagicMock())
+        # Start the context manager
+        await async_gen.__aenter__()
+
+        # Verify the expected logging sequence
+        mock_logger.info.assert_any_call("Loading TTS model and voice packs...")
+        mock_logger.info.assert_any_call("Model loaded and warmed up on cuda")
+        mock_logger.info.assert_any_call("3 voice packs loaded successfully")
+
+        # Verify model initialization was called
+        mock_tts_model.initialize.assert_called_once()
+
+        # Clean up
+        await async_gen.__aexit__(None, None, None)
+
+
+@pytest.mark.asyncio
+@patch('api.src.main.TTSModel')
+@patch('api.src.main.logger')
+async def test_lifespan_failed_warmup(mock_logger, mock_tts_model):
+    """Test failed model warmup in lifespan"""
+    # Mock the model initialization to fail
+    mock_tts_model.initialize.side_effect = Exception("Failed to initialize model")
+
+    # Create an async generator from the lifespan context manager
+    async_gen = lifespan(MagicMock())
+
+    # Verify the exception is raised
+    with pytest.raises(Exception, match="Failed to initialize model"):
+        await async_gen.__aenter__()
+
+    # Verify the expected logging sequence
+    mock_logger.info.assert_called_with("Loading TTS model and voice packs...")
+
+    # Clean up
+    await async_gen.__aexit__(None, None, None)
+
+
+@pytest.mark.asyncio
+@patch('api.src.main.TTSModel')
+async def test_lifespan_cuda_warmup(mock_tts_model):
+    """Test model warmup specifically on CUDA"""
+    # Mock the model initialization with CUDA and voicepacks
+    mock_model = MagicMock()
+    # Mock file system for voice counting
+    mock_tts_model.VOICES_DIR = "/mock/voices"
+    with patch('os.listdir', return_value=['voice1.pt', 'voice2.pt']):
+        mock_tts_model.initialize.return_value = (mock_model, 2)  # 2 voice files
+        mock_tts_model._device = "cuda"  # Set device class variable
+
+        # Create an async generator from the lifespan context manager
+        async_gen = lifespan(MagicMock())
+        await async_gen.__aenter__()
+
+        # Verify model was initialized
+        mock_tts_model.initialize.assert_called_once()
+
+        # Clean up
+        await async_gen.__aexit__(None, None, None)
+
+
+@pytest.mark.asyncio
+@patch('api.src.main.TTSModel')
+async def test_lifespan_cpu_fallback(mock_tts_model):
+    """Test model warmup falling back to CPU"""
+    # Mock the model initialization with CPU and voicepacks
+    mock_model = MagicMock()
+    # Mock file system for voice counting
+    mock_tts_model.VOICES_DIR = "/mock/voices"
+    with patch('os.listdir', return_value=['voice1.pt', 'voice2.pt', 'voice3.pt', 'voice4.pt']):
+        mock_tts_model.initialize.return_value = (mock_model, 4)  # 4 voice files
+        mock_tts_model._device = "cpu"  # Set device class variable
+
+        # Create an async generator from the lifespan context manager
+        async_gen = lifespan(MagicMock())
+        await async_gen.__aenter__()
+
+        # Verify model was initialized
+        mock_tts_model.initialize.assert_called_once()
+
+        # Clean up
+        await async_gen.__aexit__(None, None, None)
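
The tests above drive the lifespan context manager by hand with __aenter__()/__aexit__() so assertions can run while the app is "started"; outside of tests, a self-contained equivalent using the idiomatic form would be:

import asyncio
from api.src.main import app, lifespan

async def startup_probe():
    # Equivalent to the manual __aenter__()/__aexit__(None, None, None) calls above
    async with lifespan(app):
        pass  # the application is initialized at this point

asyncio.run(startup_probe())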

api/tests/ (TTS service tests)

@@ -131,9 +131,9 @@ def test_model_initialization_cuda(mock_build_model, mock_cuda_available):
     mock_build_model.return_value = mock_model
     TTSModel._instance = None  # Reset singleton

-    model, device = TTSModel.get_instance()
-    assert device == "cuda"
+    model, voice_count = TTSModel.initialize()
+    assert TTSModel._device == "cuda"  # Check the class variable instead
     assert model == mock_model
     mock_build_model.assert_called_once()
@@ -147,31 +147,34 @@ def test_model_initialization_cpu(mock_build_model, mock_cuda_available):
     mock_build_model.return_value = mock_model
     TTSModel._instance = None  # Reset singleton

-    model, device = TTSModel.get_instance()
-    assert device == "cpu"
+    model, voice_count = TTSModel.initialize()
+    assert TTSModel._device == "cpu"  # Check the class variable instead
     assert model == mock_model
     mock_build_model.assert_called_once()


-@patch('os.path.exists')
-@patch('api.src.services.tts.torch.load')
-@patch('os.path.join')
-def test_voicepack_loading_error(mock_join, mock_torch_load, mock_exists):
+@patch('api.src.services.tts.TTSService._get_voice_path')
+@patch('api.src.services.tts.TTSModel.get_instance')
+def test_voicepack_loading_error(mock_get_instance, mock_get_voice_path):
     """Test voicepack loading error handling"""
-    mock_join.side_effect = lambda *args: '/'.join(args)
-    mock_exists.side_effect = lambda x: False  # All voice files don't exist
+    mock_get_voice_path.return_value = None
+    mock_get_instance.return_value = (MagicMock(), "cpu")
+    TTSModel._instance = (MagicMock(), "cpu")  # Mock instance
     TTSModel._voicepacks = {}  # Reset voicepacks
-    with pytest.raises(FileNotFoundError, match="Voice file not found: af"):
-        TTSModel.get_voicepack("nonexistent_voice")
+    service = TTSService(start_worker=False)
+    with pytest.raises(ValueError, match="Voice not found: nonexistent_voice"):
+        service._generate_audio("test", "nonexistent_voice", 1.0)


-def test_save_audio(tts_service, sample_audio, tmp_path):
+@patch('api.src.services.tts.TTSModel')
+def test_save_audio(mock_tts_model, tts_service, sample_audio, tmp_path):
     """Test saving audio to file"""
-    output_path = os.path.join(tmp_path, "test_output", "audio.wav")
+    output_dir = os.path.join(tmp_path, "test_output")
+    os.makedirs(output_dir, exist_ok=True)
+    output_path = os.path.join(output_dir, "audio.wav")
     tts_service._save_audio(sample_audio, output_path)
     assert os.path.exists(output_path)

docker-compose.cpu.yml

@@ -6,18 +6,21 @@ services:
     working_dir: /app/Kokoro-82M
     command: >
       sh -c "
+        rm -f .git/index.lock;
         if [ -z \"$(ls -A .)\" ]; then
-          git clone https://huggingface.co/hexgrad/Kokoro-82M . && \
-          git checkout 8228a351f87c8a6076502c1e3b7e72e821ebec9a;
+          git clone https://huggingface.co/hexgrad/Kokoro-82M
           touch .cloned;
         else
+          rm -f .git/index.lock && \
+          git checkout main && \
+          git pull origin main && \
           touch .cloned;
         fi;
         tail -f /dev/null
       "
     healthcheck:
       test: ["CMD", "test", "-f", ".cloned"]
-      interval: 1s
+      interval: 3s
       timeout: 1s
       retries: 120
       start_period: 1s

docker-compose.yml

@@ -6,18 +6,21 @@ services:
     working_dir: /app/Kokoro-82M
     command: >
       sh -c "
+        rm -f .git/index.lock;
         if [ -z \"$(ls -A .)\" ]; then
-          git clone https://huggingface.co/hexgrad/Kokoro-82M . && \
-          git checkout 8228a351f87c8a6076502c1e3b7e72e821ebec9a;
+          git clone https://huggingface.co/hexgrad/Kokoro-82M
           touch .cloned;
         else
+          rm -f .git/index.lock && \
+          git checkout main && \
+          git pull origin main && \
           touch .cloned;
         fi;
         tail -f /dev/null
       "
     healthcheck:
       test: ["CMD", "test", "-f", ".cloned"]
-      interval: 1s
+      interval: 3s
       timeout: 1s
       retries: 120
       start_period: 1s
@@ -42,3 +45,15 @@ services:
     depends_on:
       model-fetcher:
         condition: service_healthy
+
+  # # Gradio UI service
+  # gradio-ui:
+  #   build:
+  #     context: ./ui
+  #   ports:
+  #     - "7860:7860"
+  #   volumes:
+  #     - ./ui/data:/app/ui/data
+  #     - ./ui/app.py:/app/app.py  # Mount app.py for hot reload
+  #   environment:
+  #     - GRADIO_WATCH=True  # Enable hot reloading