Mirror of https://github.com/remsky/Kokoro-FastAPI.git (synced 2025-04-13 09:39:17 +00:00)

Commit 19321eabb2: Merge branch 'master' into feat/gradio-gui

13 changed files with 301 additions and 114 deletions
.coveragerc
```diff
@@ -1,6 +1,9 @@
 [run]
 source = api
-omit = Kokoro-82M/*
+omit =
+    Kokoro-82M/*
+    MagicMock/*
+    test_*.py

 [report]
 exclude_lines =
```
.dockerignore (new file, 41 lines)
```diff
@@ -0,0 +1,41 @@
+# Version control
+.git
+.gitignore
+
+# Python
+__pycache__
+*.pyc
+*.pyo
+*.pyd
+.Python
+*.py[cod]
+*$py.class
+.pytest_cache
+.coverage
+.coveragerc
+
+# Environment
+# .env
+.venv
+env/
+venv/
+ENV/
+
+# IDE
+.idea
+.vscode
+*.swp
+*.swo
+
+# Project specific
+examples/
+Kokoro-82M/
+ui/
+tests/
+*.md
+*.txt
+!requirements.txt
+
+# Docker
+Dockerfile*
+docker-compose*
```
.gitignore (vendored, 2 changes)
```diff
@@ -1,6 +1,6 @@
 output/
+ui/data/*
 *.db
 *.pyc
```
CHANGELOG.md (new file, 27 lines)
```diff
@@ -0,0 +1,27 @@
+# Changelog
+
+Notable changes to this project will be documented in this file.
+
+## 2024-01-09
+
+### Modified
+
+#### Configuration Changes
+- Updated Docker configurations:
+  - Changes to `Dockerfile`:
+    - Improved layer caching by separating dependency and code layers
+  - Updates to `docker-compose.yml` and `docker-compose.cpu.yml`:
+    - Removed commit lock from model fetching to allow automatic model updates from HF
+    - Added git index lock cleanup
+
+#### API Changes
+- Modified `api/src/main.py`
+- Updated TTS service implementation in `api/src/services/tts.py`:
+  - Added device management for better resource control:
+    - Voices are now copied from the model repository to the api/src/voices directory for persistence
+  - Refactored voice pack handling:
+    - Removed static voice pack dictionary
+    - On-demand voice loading from disk
+  - Added model warm-up functionality:
+    - Model now initializes with a dummy text generation
+    - Uses default voice (af.pt) for warm-up
+    - Model is ready for inference on first request
```
Dockerfile (14 changes)
```diff
@@ -17,25 +17,25 @@ RUN pip3 install --no-cache-dir torch==2.5.1 --extra-index-url https://download.
 COPY requirements.txt .
 RUN pip3 install --no-cache-dir -r requirements.txt

-# Copy application code and model
-COPY . /app/
-
 # Set working directory
 WORKDIR /app

-# Run with Python unbuffered output for live logging
-ENV PYTHONUNBUFFERED=1
-
 # Create non-root user
 RUN useradd -m -u 1000 appuser

-# Create directories and set permissions
+# Create model directory and set ownership
 RUN mkdir -p /app/Kokoro-82M && \
     chown -R appuser:appuser /app

 # Switch to non-root user
 USER appuser

+# Run with Python unbuffered output for live logging
+ENV PYTHONUNBUFFERED=1
+
+# Copy only necessary application code
+COPY --chown=appuser:appuser api /app/api
+
 # Set Python path (app first for our imports, then model dir for model imports)
 ENV PYTHONPATH=/app:/app/Kokoro-82M
```
README.md
```diff
@@ -3,9 +3,9 @@
 </p>

 # Kokoro TTS API
-[![…](…)](https://huggingface.co/hexgrad/Kokoro-82M/tree/8228a351f87c8a6076502c1e3b7e72e821ebec9a)
+[![…](…)]()
 [![…](…)]()
-[![…](…)]()
+[![…](…)](https://huggingface.co/hexgrad/Kokoro-82M/tree/c3b0d86e2a980e027ef71c28819ea02e351c2667)

 FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model, providing an OpenAI-compatible endpoint with:
 - NVIDIA GPU accelerated inference (or CPU) option
```
api/src/main.py
```diff
@@ -19,18 +19,10 @@ async def lifespan(app: FastAPI):
     """Lifespan context manager for model initialization"""
     logger.info("Loading TTS model and voice packs...")

-    # Initialize the main model
-    model, device = TTSModel.get_instance()
-    logger.info(f"Model loaded on {device}")
-
-    # Initialize all voice packs
-    tts_service = TTSService()
-    voices = tts_service.list_voices()
-    for voice in voices:
-        logger.info(f"Loading voice pack: {voice}")
-        TTSModel.get_voicepack(voice)
-
-    logger.info("All models and voice packs loaded successfully")
+    # Initialize the main model with warm-up
+    model, voicepack_count = TTSModel.initialize()
+    logger.info(f"Model loaded and warmed up on {TTSModel._device}")
+    logger.info(f"{voicepack_count} voice packs loaded successfully")
     yield
```
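For orientation, here is a minimal, hedged sketch (not the project's actual file) of how a FastAPI lifespan handler like the one above is registered; the startup body is a stand-in for the real TTSModel.initialize() call:

```python
from contextlib import asynccontextmanager

from fastapi import FastAPI


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: runs once per process, before the first request is served.
    # The real handler calls TTSModel.initialize() and logs the result here.
    print("loading model, warming up, counting voice packs")
    yield
    # Shutdown: runs after the last request is served.


app = FastAPI(lifespan=lifespan)
```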
api/src/services/tts.py
```diff
@@ -21,43 +21,63 @@ enc = tiktoken.get_encoding("cl100k_base")

 class TTSModel:
     _instance = None
+    _device = None
     _lock = threading.Lock()
-    _voicepacks = {}

     # Directory for all voices (copied base voices, and any created combined voices)
     VOICES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "voices")

     @classmethod
-    def get_instance(cls):
-        if cls._instance is None:
-            with cls._lock:
-                if cls._instance is None:
-                    device = "cuda" if torch.cuda.is_available() else "cpu"
-                    print(f"Initializing model on {device}")
-                    model_path = os.path.join(settings.model_dir, settings.model_path)
-                    model = build_model(model_path, device)
-                    # Note: RNN memory optimization is handled internally by the model
-                    cls._instance = (model, device)
-        return cls._instance
-
-    @classmethod
-    def get_voicepack(cls, voice_name: str) -> torch.Tensor:
-        """Get a voice pack from the voices directory."""
-        model, device = cls.get_instance()
-        if voice_name not in cls._voicepacks:
-            try:
-                voice_path = os.path.join(cls.VOICES_DIR, f"{voice_name}.pt")
-                if not os.path.exists(voice_path):
-                    raise FileNotFoundError(f"Voice file not found: {voice_name}")
-
-                voicepack = torch.load(voice_path, map_location=device, weights_only=True)
-                cls._voicepacks[voice_name] = voicepack
-            except Exception as e:
-                logger.error(f"Error loading voice {voice_name}: {str(e)}")
-                if voice_name != "af":
-                    return cls.get_voicepack("af")
-                raise
-        return cls._voicepacks[voice_name]
+    def initialize(cls):
+        """Initialize and warm up the model"""
+        with cls._lock:
+            if cls._instance is None:
+                # Initialize model
+                cls._device = "cuda" if torch.cuda.is_available() else "cpu"
+                logger.info(f"Initializing model on {cls._device}")
+                model_path = os.path.join(settings.model_dir, settings.model_path)
+                model = build_model(model_path, cls._device)
+                cls._instance = model
+
+                # Ensure voices directory exists
+                os.makedirs(cls.VOICES_DIR, exist_ok=True)
+
+                # Copy base voices to local directory
+                base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
+                if os.path.exists(base_voices_dir):
+                    for file in os.listdir(base_voices_dir):
+                        if file.endswith(".pt"):
+                            voice_name = file[:-3]
+                            voice_path = os.path.join(cls.VOICES_DIR, file)
+                            if not os.path.exists(voice_path):
+                                try:
+                                    logger.info(f"Copying base voice {voice_name} to voices directory")
+                                    base_path = os.path.join(base_voices_dir, file)
+                                    voicepack = torch.load(base_path, map_location=cls._device, weights_only=True)
+                                    torch.save(voicepack, voice_path)
+                                except Exception as e:
+                                    logger.error(f"Error copying voice {voice_name}: {str(e)}")
+
+                # Warm up with default voice
+                try:
+                    dummy_text = "Hello"
+                    voice_path = os.path.join(cls.VOICES_DIR, "af.pt")
+                    dummy_voicepack = torch.load(voice_path, map_location=cls._device, weights_only=True)
+                    generate(model, dummy_text, dummy_voicepack, lang='a', speed=1.0)
+                    logger.info("Model warm-up complete")
+                except Exception as e:
+                    logger.warning(f"Model warm-up failed: {e}")
+
+        # Count voices in directory for validation
+        voice_count = len([f for f in os.listdir(cls.VOICES_DIR) if f.endswith('.pt')])
+        return cls._instance, voice_count
+
+    @classmethod
+    def get_instance(cls):
+        """Get the initialized instance or raise an error"""
+        if cls._instance is None:
+            raise RuntimeError("Model not initialized. Call initialize() first.")
+        return cls._instance, cls._device


 class TTSService:
```
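To make the new call order concrete, here is a self-contained sketch of the initialize-then-get_instance pattern the hunk introduces; ModelHolder and its stub bodies are illustrative stand-ins, not the project's code:

```python
import threading


class ModelHolder:
    _instance = None
    _device = None
    _lock = threading.Lock()

    @classmethod
    def initialize(cls):
        with cls._lock:  # serialize first-time construction across threads
            if cls._instance is None:
                cls._device = "cpu"        # stand-in for the CUDA availability check
                cls._instance = object()   # stand-in for build_model(model_path, device)
        return cls._instance

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            raise RuntimeError("Model not initialized. Call initialize() first.")
        return cls._instance, cls._device


ModelHolder.initialize()                     # once, at application startup
model, device = ModelHolder.get_instance()   # afterwards, from any caller
```

Unlike the old lazily-initializing get_instance, this split makes a forgotten startup call fail loudly instead of silently re-initializing the model mid-request.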
```diff
@@ -79,9 +99,9 @@ class TTSService:
                 voice_path = os.path.join(TTSModel.VOICES_DIR, file)
                 if not os.path.exists(voice_path):
                     try:
-                        base_path = os.path.join(base_voices_dir, file)
                         logger.info(f"Copying base voice {voice_name} to voices directory")
-                        voicepack = torch.load(base_path, map_location=TTSModel.get_instance()[1], weights_only=True)
+                        base_path = os.path.join(base_voices_dir, file)
+                        voicepack = torch.load(base_path, map_location=TTSModel._device, weights_only=True)
                         torch.save(voicepack, voice_path)
                     except Exception as e:
                         logger.error(f"Error copying voice {voice_name}: {str(e)}")
```
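One design note on the hunk above: copying voices via torch.load followed by torch.save validates each checkpoint and remaps its tensors to the target device as part of the copy. If that validation were not wanted, a plain byte copy would be cheaper; a hedged sketch for contrast (copy_voice is a hypothetical helper, not in the diff):

```python
import shutil


def copy_voice(base_path: str, voice_path: str) -> None:
    # Byte-for-byte copy; unlike torch.load + torch.save it performs
    # no checkpoint validation and no device remapping.
    shutil.copy(base_path, voice_path)
```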
```diff
@@ -114,21 +134,21 @@ class TTSService:
         if not text:
             raise ValueError("Text is empty after preprocessing")

-        # Get model instance
-        model, device = TTSModel.get_instance()
-
-        # Load voice
+        # Check voice exists
         voice_path = self._get_voice_path(voice)
         if not voice_path:
             raise ValueError(f"Voice not found: {voice}")

-        voicepack = torch.load(voice_path, map_location=device, weights_only=True)
+        # Load model and voice
+        model = TTSModel._instance
+        voicepack = torch.load(voice_path, map_location=TTSModel._device, weights_only=True)

         # Generate audio with or without stitching
         if stitch_long_output:
             chunks = self._split_text(text)
             audio_chunks = []

+            # Process all chunks with same model/voicepack instance
             for i, chunk in enumerate(chunks):
                 try:
                     # Validate phonemization first
```
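The stitching branch above splits long input into chunks, synthesizes each chunk with the same model and voicepack, and concatenates the resulting audio. A minimal sketch of that join, assuming each chunk yields a 1-D float array; generate_chunk is a hypothetical stand-in for the real per-chunk call:

```python
import numpy as np


def generate_chunk(text: str) -> np.ndarray:
    # Hypothetical stand-in: the real code runs the TTS model per chunk.
    return np.zeros(2400, dtype=np.float32)


chunks = ["First sentence.", "Second sentence."]
audio = np.concatenate([generate_chunk(c) for c in chunks])
```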
```diff
@@ -204,12 +224,9 @@ class TTSService:
         v_name: List[str] = []

         for voice in voices:
-            voice_path = self._get_voice_path(voice)
-            if not voice_path:
-                raise ValueError(f"Voice not found: {voice}")
-
             try:
-                voicepack = torch.load(voice_path, map_location=TTSModel.get_instance()[1], weights_only=True)
+                voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice}.pt")
+                voicepack = torch.load(voice_path, map_location=TTSModel._device, weights_only=True)
                 t_voices.append(voicepack)
                 v_name.append(voice)
             except Exception as e:
```
conftest.py
```diff
@@ -1,8 +1,23 @@
+import os
+import shutil
 import sys
 from unittest.mock import Mock, patch

 import pytest

+
+def cleanup_mock_dirs():
+    """Clean up any MagicMock directories created during tests"""
+    mock_dir = "MagicMock"
+    if os.path.exists(mock_dir):
+        shutil.rmtree(mock_dir)
+
+
+@pytest.fixture(autouse=True)
+def cleanup():
+    """Automatically clean up before and after each test"""
+    cleanup_mock_dirs()
+    yield
+    cleanup_mock_dirs()
+
 # Mock torch and other ML modules before they're imported
 sys.modules["torch"] = Mock()
 sys.modules["transformers"] = Mock()
```
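The conftest works because entries in sys.modules shadow real imports: once "torch" maps to a Mock, every later `import torch` returns that Mock, and any attribute access or call on it yields further mocks. A small illustration of the mechanism:

```python
import sys
from unittest.mock import Mock

# Must run before any module under test executes `import torch`.
sys.modules["torch"] = Mock()

import torch  # resolves to the Mock above, not the real library

tensor = torch.load("voice.pt")  # no file I/O; returns another Mock
print(type(tensor))              # <class 'unittest.mock.Mock'>
```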
test_main.py
```diff
@@ -1,45 +1,116 @@
-"""Tests for main FastAPI application"""
+"""Tests for FastAPI application"""
 import pytest
 from unittest.mock import patch, MagicMock
 from fastapi.testclient import TestClient
-from api.src.main import app
+from api.src.main import app, lifespan


 @pytest.fixture
-def client():
+def test_client():
     """Create a test client"""
     return TestClient(app)


-def test_health_check(client):
+def test_health_check(test_client):
     """Test health check endpoint"""
-    response = client.get("/health")
+    response = test_client.get("/health")
     assert response.status_code == 200
     assert response.json() == {"status": "healthy"}


-def test_test_endpoint(client):
-    """Test the test endpoint"""
-    response = client.get("/v1/test")
-    assert response.status_code == 200
-    assert response.json() == {"status": "ok"}
-
-
-def test_cors_headers(client):
-    """Test CORS headers are present"""
-    response = client.get(
-        "/health",
-        headers={"Origin": "http://testserver"},
-    )
-    assert response.status_code == 200
-    assert response.headers["access-control-allow-origin"] == "*"
-
-
-def test_openapi_schema(client):
-    """Test OpenAPI schema is accessible"""
-    response = client.get("/openapi.json")
-    assert response.status_code == 200
-    schema = response.json()
-    assert schema["info"]["title"] == app.title
-    assert schema["info"]["version"] == app.version
+@pytest.mark.asyncio
+@patch('api.src.main.TTSModel')
+@patch('api.src.main.logger')
+async def test_lifespan_successful_warmup(mock_logger, mock_tts_model):
+    """Test successful model warmup in lifespan"""
+    # Mock the model initialization with model info and voicepack count
+    mock_model = MagicMock()
+    # Mock file system for voice counting
+    mock_tts_model.VOICES_DIR = "/mock/voices"
+    with patch('os.listdir', return_value=['voice1.pt', 'voice2.pt', 'voice3.pt']):
+        mock_tts_model.initialize.return_value = (mock_model, 3)  # 3 voice files
+        mock_tts_model._device = "cuda"  # Set device class variable
+
+        # Create an async generator from the lifespan context manager
+        async_gen = lifespan(MagicMock())
+        # Start the context manager
+        await async_gen.__aenter__()
+
+        # Verify the expected logging sequence
+        mock_logger.info.assert_any_call("Loading TTS model and voice packs...")
+        mock_logger.info.assert_any_call("Model loaded and warmed up on cuda")
+        mock_logger.info.assert_any_call("3 voice packs loaded successfully")
+
+        # Verify model initialization was called
+        mock_tts_model.initialize.assert_called_once()
+
+        # Clean up
+        await async_gen.__aexit__(None, None, None)
+
+
+@pytest.mark.asyncio
+@patch('api.src.main.TTSModel')
+@patch('api.src.main.logger')
+async def test_lifespan_failed_warmup(mock_logger, mock_tts_model):
+    """Test failed model warmup in lifespan"""
+    # Mock the model initialization to fail
+    mock_tts_model.initialize.side_effect = Exception("Failed to initialize model")
+
+    # Create an async generator from the lifespan context manager
+    async_gen = lifespan(MagicMock())
+
+    # Verify the exception is raised
+    with pytest.raises(Exception, match="Failed to initialize model"):
+        await async_gen.__aenter__()
+
+    # Verify the expected logging sequence
+    mock_logger.info.assert_called_with("Loading TTS model and voice packs...")
+
+    # Clean up
+    await async_gen.__aexit__(None, None, None)
+
+
+@pytest.mark.asyncio
+@patch('api.src.main.TTSModel')
+async def test_lifespan_cuda_warmup(mock_tts_model):
+    """Test model warmup specifically on CUDA"""
+    # Mock the model initialization with CUDA and voicepacks
+    mock_model = MagicMock()
+    # Mock file system for voice counting
+    mock_tts_model.VOICES_DIR = "/mock/voices"
+    with patch('os.listdir', return_value=['voice1.pt', 'voice2.pt']):
+        mock_tts_model.initialize.return_value = (mock_model, 2)  # 2 voice files
+        mock_tts_model._device = "cuda"  # Set device class variable
+
+        # Create an async generator from the lifespan context manager
+        async_gen = lifespan(MagicMock())
+        await async_gen.__aenter__()
+
+        # Verify model was initialized
+        mock_tts_model.initialize.assert_called_once()
+
+        # Clean up
+        await async_gen.__aexit__(None, None, None)
+
+
+@pytest.mark.asyncio
+@patch('api.src.main.TTSModel')
+async def test_lifespan_cpu_fallback(mock_tts_model):
+    """Test model warmup falling back to CPU"""
+    # Mock the model initialization with CPU and voicepacks
+    mock_model = MagicMock()
+    # Mock file system for voice counting
+    mock_tts_model.VOICES_DIR = "/mock/voices"
+    with patch('os.listdir', return_value=['voice1.pt', 'voice2.pt', 'voice3.pt', 'voice4.pt']):
+        mock_tts_model.initialize.return_value = (mock_model, 4)  # 4 voice files
+        mock_tts_model._device = "cpu"  # Set device class variable
+
+        # Create an async generator from the lifespan context manager
+        async_gen = lifespan(MagicMock())
+        await async_gen.__aenter__()
+
+        # Verify model was initialized
+        mock_tts_model.initialize.assert_called_once()
+
+        # Clean up
+        await async_gen.__aexit__(None, None, None)
```
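These tests drive the lifespan context manager by calling __aenter__ and __aexit__ explicitly, which keeps the assertions between the startup and shutdown phases visible. For reference, the equivalent `async with` form, as a hedged sketch:

```python
import asyncio
from unittest.mock import MagicMock

from api.src.main import lifespan


async def exercise_lifespan():
    # Equivalent to: gen = lifespan(app); await gen.__aenter__(); ...;
    # await gen.__aexit__(None, None, None)
    async with lifespan(MagicMock()):
        pass  # the application would serve requests while this block is open


# asyncio.run(exercise_lifespan())
```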
test_tts.py
```diff
@@ -131,9 +131,9 @@ def test_model_initialization_cuda(mock_build_model, mock_cuda_available):
     mock_build_model.return_value = mock_model

     TTSModel._instance = None  # Reset singleton
-    model, device = TTSModel.get_instance()
+    model, voice_count = TTSModel.initialize()

-    assert device == "cuda"
+    assert TTSModel._device == "cuda"  # Check the class variable instead
     assert model == mock_model
     mock_build_model.assert_called_once()

@@ -147,31 +147,34 @@ def test_model_initialization_cpu(mock_build_model, mock_cuda_available):
     mock_build_model.return_value = mock_model

     TTSModel._instance = None  # Reset singleton
-    model, device = TTSModel.get_instance()
+    model, voice_count = TTSModel.initialize()

-    assert device == "cpu"
+    assert TTSModel._device == "cpu"  # Check the class variable instead
     assert model == mock_model
     mock_build_model.assert_called_once()


-@patch('os.path.exists')
-@patch('api.src.services.tts.torch.load')
-@patch('os.path.join')
-def test_voicepack_loading_error(mock_join, mock_torch_load, mock_exists):
+@patch('api.src.services.tts.TTSService._get_voice_path')
+@patch('api.src.services.tts.TTSModel.get_instance')
+def test_voicepack_loading_error(mock_get_instance, mock_get_voice_path):
     """Test voicepack loading error handling"""
-    mock_join.side_effect = lambda *args: '/'.join(args)
-    mock_exists.side_effect = lambda x: False  # All voice files don't exist
+    mock_get_voice_path.return_value = None
+    mock_get_instance.return_value = (MagicMock(), "cpu")

-    TTSModel._instance = (MagicMock(), "cpu")  # Mock instance
     TTSModel._voicepacks = {}  # Reset voicepacks

-    with pytest.raises(FileNotFoundError, match="Voice file not found: af"):
-        TTSModel.get_voicepack("nonexistent_voice")
+    service = TTSService(start_worker=False)
+    with pytest.raises(ValueError, match="Voice not found: nonexistent_voice"):
+        service._generate_audio("test", "nonexistent_voice", 1.0)


-def test_save_audio(tts_service, sample_audio, tmp_path):
+@patch('api.src.services.tts.TTSModel')
+def test_save_audio(mock_tts_model, tts_service, sample_audio, tmp_path):
     """Test saving audio to file"""
-    output_path = os.path.join(tmp_path, "test_output", "audio.wav")
+    output_dir = os.path.join(tmp_path, "test_output")
+    os.makedirs(output_dir, exist_ok=True)
+    output_path = os.path.join(output_dir, "audio.wav")

     tts_service._save_audio(sample_audio, output_path)

     assert os.path.exists(output_path)
```
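Since TTSModel keeps process-wide state, each test above resets `_instance` by hand before initializing. A hypothetical autouse fixture (not part of this diff) that would centralize that reset:

```python
import pytest

from api.src.services.tts import TTSModel


@pytest.fixture(autouse=True)
def reset_tts_model():
    # Clear singleton state so every test starts from an uninitialized model.
    TTSModel._instance = None
    TTSModel._device = None
    yield
    TTSModel._instance = None
    TTSModel._device = None
```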
docker-compose.cpu.yml
```diff
@@ -6,18 +6,21 @@ services:
     working_dir: /app/Kokoro-82M
     command: >
       sh -c "
+        rm -f .git/index.lock;
         if [ -z \"$(ls -A .)\" ]; then
-          git clone https://huggingface.co/hexgrad/Kokoro-82M . && \
-          git checkout 8228a351f87c8a6076502c1e3b7e72e821ebec9a;
+          git clone https://huggingface.co/hexgrad/Kokoro-82M . && \
           touch .cloned;
         else
+          rm -f .git/index.lock && \
+          git checkout main && \
+          git pull origin main && \
           touch .cloned;
         fi;
         tail -f /dev/null
       "
     healthcheck:
       test: ["CMD", "test", "-f", ".cloned"]
-      interval: 1s
+      interval: 3s
       timeout: 1s
       retries: 120
       start_period: 1s
```
docker-compose.yml
```diff
@@ -6,18 +6,21 @@ services:
     working_dir: /app/Kokoro-82M
     command: >
       sh -c "
+        rm -f .git/index.lock;
         if [ -z \"$(ls -A .)\" ]; then
-          git clone https://huggingface.co/hexgrad/Kokoro-82M . && \
-          git checkout 8228a351f87c8a6076502c1e3b7e72e821ebec9a;
+          git clone https://huggingface.co/hexgrad/Kokoro-82M . && \
           touch .cloned;
         else
+          rm -f .git/index.lock && \
+          git checkout main && \
+          git pull origin main && \
           touch .cloned;
         fi;
         tail -f /dev/null
       "
     healthcheck:
       test: ["CMD", "test", "-f", ".cloned"]
-      interval: 1s
+      interval: 3s
       timeout: 1s
       retries: 120
       start_period: 1s

@@ -42,3 +45,15 @@ services:
     depends_on:
       model-fetcher:
         condition: service_healthy
+
+  # # Gradio UI service
+  # gradio-ui:
+  #   build:
+  #     context: ./ui
+  #   ports:
+  #     - "7860:7860"
+  #   volumes:
+  #     - ./ui/data:/app/ui/data
+  #     - ./ui/app.py:/app/app.py  # Mount app.py for hot reload
+  #   environment:
+  #     - GRADIO_WATCH=True  # Enable hot reloading
```