Merge branch 'master' into feat/gradio-gui

remsky 2025-01-01 17:39:54 -07:00
commit 19321eabb2
13 changed files with 301 additions and 114 deletions

.coveragerc

@@ -1,6 +1,9 @@
[run]
source = api
omit = Kokoro-82M/*
omit =
    Kokoro-82M/*
    MagicMock/*
    test_*.py
[report]
exclude_lines =

41 .dockerignore Normal file

@@ -0,0 +1,41 @@
# Version control
.git
.gitignore
# Python
__pycache__
*.pyc
*.pyo
*.pyd
.Python
*.py[cod]
*$py.class
.pytest_cache
.coverage
.coveragerc
# Environment
# .env
.venv
env/
venv/
ENV/
# IDE
.idea
.vscode
*.swp
*.swo
# Project specific
examples/
Kokoro-82M/
ui/
tests/
*.md
*.txt
!requirements.txt
# Docker
Dockerfile*
docker-compose*

2 .gitignore vendored

@@ -1,6 +1,6 @@
output/
ui/data/*
*.db
*.pyc

27 CHANGELOG.md Normal file

@@ -0,0 +1,27 @@
# Changelog
Notable changes to this project will be documented in this file.
## 2024-01-09
### Modified
#### Configuration Changes
- Updated Docker configurations:
  - Changes to `Dockerfile`:
    - Improved layer caching by separating dependency and code layers
  - Updates to `docker-compose.yml` and `docker-compose.cpu.yml`:
    - Removed commit lock from model fetching to allow automatic model updates from HF
    - Added git index lock cleanup
#### API Changes
- Modified `api/src/main.py`
- Updated TTS service implementation in `api/src/services/tts.py`:
  - Added device management for better resource control:
    - Voices are now copied from the model repository to the `api/src/voices` directory for persistence
  - Refactored voice pack handling:
    - Removed the static voice pack dictionary
    - Voices are now loaded on demand from disk
  - Added model warm-up functionality (sketched below):
    - Model now initializes with a dummy text generation
    - Uses the default voice (`af.pt`) for warm-up
    - Model is ready for inference on the first request
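
A condensed sketch of the warm-up flow described above, assuming the voice layout and the model repo's `generate()` helper shown in the `api/src/services/tts.py` diff later in this commit (the function name and import path here are illustrative):

    import os
    import torch
    from kokoro import generate  # helper from the Kokoro-82M repo; import path assumed

    def warm_up_model(model, voices_dir: str, device: str) -> None:
        # One dummy generation with the default voice, so the first real
        # request does not pay the initialization cost.
        voice_path = os.path.join(voices_dir, "af.pt")
        voicepack = torch.load(voice_path, map_location=device, weights_only=True)
        generate(model, "Hello", voicepack, lang="a", speed=1.0)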

Dockerfile

@@ -17,25 +17,25 @@ RUN pip3 install --no-cache-dir torch==2.5.1 --extra-index-url https://download.
COPY requirements.txt .
RUN pip3 install --no-cache-dir -r requirements.txt
# Copy application code and model
COPY . /app/
# Set working directory
WORKDIR /app
# Run with Python unbuffered output for live logging
ENV PYTHONUNBUFFERED=1
# Create non-root user
RUN useradd -m -u 1000 appuser
# Create directories and set permissions
# Create model directory and set ownership
RUN mkdir -p /app/Kokoro-82M && \
chown -R appuser:appuser /app
# Switch to non-root user
USER appuser
# Run with Python unbuffered output for live logging
ENV PYTHONUNBUFFERED=1
# Copy only necessary application code
COPY --chown=appuser:appuser api /app/api
# Set Python path (app first for our imports, then model dir for model imports)
ENV PYTHONPATH=/app:/app/Kokoro-82M
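
For reference, the effect of that `PYTHONPATH` setting is that both trees resolve as top-level imports inside the container, with `/app` searched first (the model repo's module name is an assumption here):

    # Illustrative: with PYTHONPATH=/app:/app/Kokoro-82M,
    from api.src.main import app   # found under /app (our code takes precedence)
    from kokoro import generate    # found under /app/Kokoro-82M (module name assumed)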

README.md

@@ -3,9 +3,9 @@
</p>
# Kokoro TTS API
[![Model Commit](https://img.shields.io/badge/model--commit-a67f113-blue)](https://huggingface.co/hexgrad/Kokoro-82M/tree/8228a351f87c8a6076502c1e3b7e72e821ebec9a)
[![Tests](https://img.shields.io/badge/tests-36%20passed-darkgreen)]()
[![Coverage](https://img.shields.io/badge/coverage-91%25-darkgreen)]()
[![Tests](https://img.shields.io/badge/tests-37%20passed-darkgreen)]()
[![Coverage](https://img.shields.io/badge/coverage-81%25-darkgreen)]()
[![Tested at Model Commit](https://img.shields.io/badge/last--tested--model--commit-a67f113-blue)](https://huggingface.co/hexgrad/Kokoro-82M/tree/c3b0d86e2a980e027ef71c28819ea02e351c2667)
FastAPI wrapper for the [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model, providing an OpenAI-compatible endpoint with:
- NVIDIA GPU-accelerated inference (or CPU) option

api/src/main.py

@@ -19,18 +19,10 @@ async def lifespan(app: FastAPI):
"""Lifespan context manager for model initialization"""
logger.info("Loading TTS model and voice packs...")
# Initialize the main model
model, device = TTSModel.get_instance()
logger.info(f"Model loaded on {device}")
# Initialize all voice packs
tts_service = TTSService()
voices = tts_service.list_voices()
for voice in voices:
logger.info(f"Loading voice pack: {voice}")
TTSModel.get_voicepack(voice)
logger.info("All models and voice packs loaded successfully")
# Initialize the main model with warm-up
model, voicepack_count = TTSModel.initialize()
logger.info(f"Model loaded and warmed up on {TTSModel._device}")
logger.info(f"{voicepack_count} voice packs loaded successfully")
yield
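
For context, a minimal sketch of how this lifespan hooks into the application (FastAPI's standard `lifespan=` wiring; the decorator, imports, and app construction are assumed, since this hunk only shows the function body):

    from contextlib import asynccontextmanager
    from fastapi import FastAPI
    from loguru import logger  # logging setup assumed

    from api.src.services.tts import TTSModel  # import path assumed

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        logger.info("Loading TTS model and voice packs...")
        model, voicepack_count = TTSModel.initialize()  # build, copy voices, warm up
        yield  # application serves requests here

    app = FastAPI(lifespan=lifespan)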

api/src/services/tts.py

@@ -21,43 +21,63 @@ enc = tiktoken.get_encoding("cl100k_base")
class TTSModel:
_instance = None
_device = None
_lock = threading.Lock()
_voicepacks = {}
# Directory for all voices (copied base voices, and any created combined voices)
VOICES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "voices")
@classmethod
def get_instance(cls):
if cls._instance is None:
def initialize(cls):
"""Initialize and warm up the model"""
with cls._lock:
if cls._instance is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Initializing model on {device}")
# Initialize model
cls._device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Initializing model on {cls._device}")
model_path = os.path.join(settings.model_dir, settings.model_path)
model = build_model(model_path, device)
# Note: RNN memory optimization is handled internally by the model
cls._instance = (model, device)
return cls._instance
model = build_model(model_path, cls._device)
cls._instance = model
# Ensure voices directory exists
os.makedirs(cls.VOICES_DIR, exist_ok=True)
# Copy base voices to local directory
base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
if os.path.exists(base_voices_dir):
for file in os.listdir(base_voices_dir):
if file.endswith(".pt"):
voice_name = file[:-3]
voice_path = os.path.join(cls.VOICES_DIR, file)
if not os.path.exists(voice_path):
try:
logger.info(f"Copying base voice {voice_name} to voices directory")
base_path = os.path.join(base_voices_dir, file)
voicepack = torch.load(base_path, map_location=cls._device, weights_only=True)
torch.save(voicepack, voice_path)
except Exception as e:
logger.error(f"Error copying voice {voice_name}: {str(e)}")
# Warm up with default voice
try:
dummy_text = "Hello"
voice_path = os.path.join(cls.VOICES_DIR, "af.pt")
dummy_voicepack = torch.load(voice_path, map_location=cls._device, weights_only=True)
generate(model, dummy_text, dummy_voicepack, lang='a', speed=1.0)
logger.info("Model warm-up complete")
except Exception as e:
logger.warning(f"Model warm-up failed: {e}")
# Count voices in directory for validation
voice_count = len([f for f in os.listdir(cls.VOICES_DIR) if f.endswith('.pt')])
return cls._instance, voice_count
@classmethod
def get_voicepack(cls, voice_name: str) -> torch.Tensor:
"""Get a voice pack from the voices directory."""
model, device = cls.get_instance()
if voice_name not in cls._voicepacks:
try:
voice_path = os.path.join(cls.VOICES_DIR, f"{voice_name}.pt")
if not os.path.exists(voice_path):
raise FileNotFoundError(f"Voice file not found: {voice_name}")
voicepack = torch.load(voice_path, map_location=device, weights_only=True)
cls._voicepacks[voice_name] = voicepack
except Exception as e:
logger.error(f"Error loading voice {voice_name}: {str(e)}")
if voice_name != "af":
return cls.get_voicepack("af")
raise
return cls._voicepacks[voice_name]
def get_instance(cls):
"""Get the initialized instance or raise an error"""
if cls._instance is None:
raise RuntimeError("Model not initialized. Call initialize() first.")
return cls._instance, cls._device
class TTSService:
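
In short, the call order this refactor assumes is initialize-once at startup, then read-only access (a sketch, not code from this commit):

    # Startup (e.g. in the FastAPI lifespan): builds the model, copies base
    # voices into VOICES_DIR, and runs the warm-up generation.
    model, voice_count = TTSModel.initialize()

    # Any later caller: returns the cached (model, device) pair, or raises
    # RuntimeError if initialize() was never called.
    model, device = TTSModel.get_instance()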
@@ -79,9 +99,9 @@ class TTSService:
voice_path = os.path.join(TTSModel.VOICES_DIR, file)
if not os.path.exists(voice_path):
try:
base_path = os.path.join(base_voices_dir, file)
logger.info(f"Copying base voice {voice_name} to voices directory")
voicepack = torch.load(base_path, map_location=TTSModel.get_instance()[1], weights_only=True)
base_path = os.path.join(base_voices_dir, file)
voicepack = torch.load(base_path, map_location=TTSModel._device, weights_only=True)
torch.save(voicepack, voice_path)
except Exception as e:
logger.error(f"Error copying voice {voice_name}: {str(e)}")
@@ -114,21 +134,21 @@ class TTSService:
if not text:
raise ValueError("Text is empty after preprocessing")
# Get model instance
model, device = TTSModel.get_instance()
# Load voice
# Check voice exists
voice_path = self._get_voice_path(voice)
if not voice_path:
raise ValueError(f"Voice not found: {voice}")
voicepack = torch.load(voice_path, map_location=device, weights_only=True)
# Load model and voice
model = TTSModel._instance
voicepack = torch.load(voice_path, map_location=TTSModel._device, weights_only=True)
# Generate audio with or without stitching
if stitch_long_output:
chunks = self._split_text(text)
audio_chunks = []
# Process all chunks with same model/voicepack instance
for i, chunk in enumerate(chunks):
try:
# Validate phonemization first
@@ -204,12 +224,9 @@ class TTSService:
v_name: List[str] = []
for voice in voices:
voice_path = self._get_voice_path(voice)
if not voice_path:
raise ValueError(f"Voice not found: {voice}")
try:
voicepack = torch.load(voice_path, map_location=TTSModel.get_instance()[1], weights_only=True)
voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice}.pt")
voicepack = torch.load(voice_path, map_location=TTSModel._device, weights_only=True)
t_voices.append(voicepack)
v_name.append(voice)
except Exception as e:

api/tests/conftest.py

@@ -1,8 +1,23 @@
import os
import shutil
import sys
from unittest.mock import Mock, patch

import pytest


def cleanup_mock_dirs():
    """Clean up any MagicMock directories created during tests"""
    mock_dir = "MagicMock"
    if os.path.exists(mock_dir):
        shutil.rmtree(mock_dir)


@pytest.fixture(autouse=True)
def cleanup():
    """Automatically clean up before and after each test"""
    cleanup_mock_dirs()
    yield
    cleanup_mock_dirs()


# Mock torch and other ML modules before they're imported
sys.modules["torch"] = Mock()
sys.modules["transformers"] = Mock()

api/tests/test_main.py

@@ -1,45 +1,116 @@
"""Tests for main FastAPI application"""
"""Tests for FastAPI application"""
import pytest
from unittest.mock import patch, MagicMock
from fastapi.testclient import TestClient
from api.src.main import app
from api.src.main import app, lifespan
@pytest.fixture
def client():
def test_client():
"""Create a test client"""
return TestClient(app)
def test_health_check(client):
def test_health_check(test_client):
"""Test health check endpoint"""
response = client.get("/health")
response = test_client.get("/health")
assert response.status_code == 200
assert response.json() == {"status": "healthy"}
def test_test_endpoint(client):
"""Test the test endpoint"""
response = client.get("/v1/test")
assert response.status_code == 200
assert response.json() == {"status": "ok"}
@pytest.mark.asyncio
@patch('api.src.main.TTSModel')
@patch('api.src.main.logger')
async def test_lifespan_successful_warmup(mock_logger, mock_tts_model):
    """Test successful model warmup in lifespan"""
    # Mock the model initialization with model info and voicepack count
    mock_model = MagicMock()
    # Mock file system for voice counting
    mock_tts_model.VOICES_DIR = "/mock/voices"
    with patch('os.listdir', return_value=['voice1.pt', 'voice2.pt', 'voice3.pt']):
        mock_tts_model.initialize.return_value = (mock_model, 3)  # 3 voice files
        mock_tts_model._device = "cuda"  # Set device class variable

        # Create an async generator from the lifespan context manager
        async_gen = lifespan(MagicMock())
        # Start the context manager
        await async_gen.__aenter__()

        # Verify the expected logging sequence
        mock_logger.info.assert_any_call("Loading TTS model and voice packs...")
        mock_logger.info.assert_any_call("Model loaded and warmed up on cuda")
        mock_logger.info.assert_any_call("3 voice packs loaded successfully")

        # Verify model initialization was called
        mock_tts_model.initialize.assert_called_once()

        # Clean up
        await async_gen.__aexit__(None, None, None)
def test_cors_headers(client):
"""Test CORS headers are present"""
response = client.get(
"/health",
headers={"Origin": "http://testserver"},
)
assert response.status_code == 200
assert response.headers["access-control-allow-origin"] == "*"
@pytest.mark.asyncio
@patch('api.src.main.TTSModel')
@patch('api.src.main.logger')
async def test_lifespan_failed_warmup(mock_logger, mock_tts_model):
    """Test failed model warmup in lifespan"""
    # Mock the model initialization to fail
    mock_tts_model.initialize.side_effect = Exception("Failed to initialize model")

    # Create an async generator from the lifespan context manager
    async_gen = lifespan(MagicMock())

    # Verify the exception is raised
    with pytest.raises(Exception, match="Failed to initialize model"):
        await async_gen.__aenter__()

    # Verify the expected logging sequence
    mock_logger.info.assert_called_with("Loading TTS model and voice packs...")

    # Clean up
    await async_gen.__aexit__(None, None, None)
def test_openapi_schema(client):
"""Test OpenAPI schema is accessible"""
response = client.get("/openapi.json")
assert response.status_code == 200
schema = response.json()
assert schema["info"]["title"] == app.title
assert schema["info"]["version"] == app.version
@pytest.mark.asyncio
@patch('api.src.main.TTSModel')
async def test_lifespan_cuda_warmup(mock_tts_model):
    """Test model warmup specifically on CUDA"""
    # Mock the model initialization with CUDA and voicepacks
    mock_model = MagicMock()
    # Mock file system for voice counting
    mock_tts_model.VOICES_DIR = "/mock/voices"
    with patch('os.listdir', return_value=['voice1.pt', 'voice2.pt']):
        mock_tts_model.initialize.return_value = (mock_model, 2)  # 2 voice files
        mock_tts_model._device = "cuda"  # Set device class variable

        # Create an async generator from the lifespan context manager
        async_gen = lifespan(MagicMock())
        await async_gen.__aenter__()

        # Verify model was initialized
        mock_tts_model.initialize.assert_called_once()

        # Clean up
        await async_gen.__aexit__(None, None, None)

@pytest.mark.asyncio
@patch('api.src.main.TTSModel')
async def test_lifespan_cpu_fallback(mock_tts_model):
    """Test model warmup falling back to CPU"""
    # Mock the model initialization with CPU and voicepacks
    mock_model = MagicMock()
    # Mock file system for voice counting
    mock_tts_model.VOICES_DIR = "/mock/voices"
    with patch('os.listdir', return_value=['voice1.pt', 'voice2.pt', 'voice3.pt', 'voice4.pt']):
        mock_tts_model.initialize.return_value = (mock_model, 4)  # 4 voice files
        mock_tts_model._device = "cpu"  # Set device class variable

        # Create an async generator from the lifespan context manager
        async_gen = lifespan(MagicMock())
        await async_gen.__aenter__()

        # Verify model was initialized
        mock_tts_model.initialize.assert_called_once()

        # Clean up
        await async_gen.__aexit__(None, None, None)

api/tests/test_tts_service.py

@@ -131,9 +131,9 @@ def test_model_initialization_cuda(mock_build_model, mock_cuda_available):
mock_build_model.return_value = mock_model
TTSModel._instance = None # Reset singleton
model, device = TTSModel.get_instance()
model, voice_count = TTSModel.initialize()
assert device == "cuda"
assert TTSModel._device == "cuda" # Check the class variable instead
assert model == mock_model
mock_build_model.assert_called_once()
@@ -147,31 +147,34 @@ def test_model_initialization_cpu(mock_build_model, mock_cuda_available):
mock_build_model.return_value = mock_model
TTSModel._instance = None # Reset singleton
model, device = TTSModel.get_instance()
model, voice_count = TTSModel.initialize()
assert device == "cpu"
assert TTSModel._device == "cpu" # Check the class variable instead
assert model == mock_model
mock_build_model.assert_called_once()
@patch('os.path.exists')
@patch('api.src.services.tts.torch.load')
@patch('os.path.join')
def test_voicepack_loading_error(mock_join, mock_torch_load, mock_exists):
@patch('api.src.services.tts.TTSService._get_voice_path')
@patch('api.src.services.tts.TTSModel.get_instance')
def test_voicepack_loading_error(mock_get_instance, mock_get_voice_path):
"""Test voicepack loading error handling"""
mock_join.side_effect = lambda *args: '/'.join(args)
mock_exists.side_effect = lambda x: False # All voice files don't exist
mock_get_voice_path.return_value = None
mock_get_instance.return_value = (MagicMock(), "cpu")
TTSModel._instance = (MagicMock(), "cpu") # Mock instance
TTSModel._voicepacks = {} # Reset voicepacks
with pytest.raises(FileNotFoundError, match="Voice file not found: af"):
TTSModel.get_voicepack("nonexistent_voice")
service = TTSService(start_worker=False)
with pytest.raises(ValueError, match="Voice not found: nonexistent_voice"):
service._generate_audio("test", "nonexistent_voice", 1.0)
def test_save_audio(tts_service, sample_audio, tmp_path):
@patch('api.src.services.tts.TTSModel')
def test_save_audio(mock_tts_model, tts_service, sample_audio, tmp_path):
"""Test saving audio to file"""
output_path = os.path.join(tmp_path, "test_output", "audio.wav")
output_dir = os.path.join(tmp_path, "test_output")
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, "audio.wav")
tts_service._save_audio(sample_audio, output_path)
assert os.path.exists(output_path)

docker-compose.cpu.yml

@@ -6,18 +6,21 @@ services:
working_dir: /app/Kokoro-82M
command: >
sh -c "
rm -f .git/index.lock;
if [ -z \"$(ls -A .)\" ]; then
git clone https://huggingface.co/hexgrad/Kokoro-82M . && \
git checkout 8228a351f87c8a6076502c1e3b7e72e821ebec9a;
git clone https://huggingface.co/hexgrad/Kokoro-82M
touch .cloned;
else
rm -f .git/index.lock && \
git checkout main && \
git pull origin main && \
touch .cloned;
fi;
tail -f /dev/null
"
healthcheck:
test: ["CMD", "test", "-f", ".cloned"]
interval: 1s
interval: 3s
timeout: 1s
retries: 120
start_period: 1s

docker-compose.yml

@@ -6,18 +6,21 @@ services:
working_dir: /app/Kokoro-82M
command: >
sh -c "
rm -f .git/index.lock;
if [ -z \"$(ls -A .)\" ]; then
git clone https://huggingface.co/hexgrad/Kokoro-82M . && \
git checkout 8228a351f87c8a6076502c1e3b7e72e821ebec9a;
git clone https://huggingface.co/hexgrad/Kokoro-82M
touch .cloned;
else
rm -f .git/index.lock && \
git checkout main && \
git pull origin main && \
touch .cloned;
fi;
tail -f /dev/null
"
healthcheck:
test: ["CMD", "test", "-f", ".cloned"]
interval: 1s
interval: 3s
timeout: 1s
retries: 120
start_period: 1s
@@ -42,3 +45,15 @@ services:
    depends_on:
      model-fetcher:
        condition: service_healthy

  # # Gradio UI service
  # gradio-ui:
  #   build:
  #     context: ./ui
  #   ports:
  #     - "7860:7860"
  #   volumes:
  #     - ./ui/data:/app/ui/data
  #     - ./ui/app.py:/app/app.py  # Mount app.py for hot reload
  #   environment:
  #     - GRADIO_WATCH=True  # Enable hot reloading