Mirror of https://github.com/remsky/Kokoro-FastAPI.git, synced 2025-04-13 09:39:17 +00:00

WIP: CPU/GPU functional; a few straggling tests to fix and check.

This commit is contained in:
parent e4d8e74738
commit 9496a3a63f

11 changed files with 366 additions and 233 deletions
@@ -21,8 +21,8 @@ async def lifespan(app: FastAPI):
     logger.info("Loading TTS model and voice packs...")
 
     # Initialize the main model with warm-up
-    voicepack_count = TTSModel.initialize()
-    logger.info(f"Model loaded and warmed up on {TTSModel._device}")
+    voicepack_count = TTSModel.setup()
+    logger.info(f"Model loaded and warmed up on {TTSModel.get_device()}")
     logger.info(f"{voicepack_count} voice packs loaded successfully")
     yield
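For orientation, this is how the reworked startup reads in full; a minimal sketch assuming the usual asynccontextmanager wiring around the hunk above, which the diff itself does not show:

from contextlib import asynccontextmanager

from fastapi import FastAPI
from loguru import logger

from .services.tts_model import TTSModel  # import path assumed from this repo layout


@asynccontextmanager
async def lifespan(app: FastAPI):
    logger.info("Loading TTS model and voice packs...")
    # setup() now owns device selection, model init, voice copying and warm-up
    voicepack_count = TTSModel.setup()
    logger.info(f"Model loaded and warmed up on {TTSModel.get_device()}")
    logger.info(f"{voicepack_count} voice packs loaded successfully")
    yield  # application runs here; teardown would go after the yield


app = FastAPI(lifespan=lifespan)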
@@ -1,4 +1,3 @@
-from .tts_model import TTSModel
 from .tts_service import TTSService
 
-__all__ = ["TTSService", "TTSModel"]
+__all__ = ["TTSService"]
api/src/services/tts_base.py (new file, 110 lines)
@@ -0,0 +1,110 @@
+import os
+import threading
+from abc import ABC, abstractmethod
+import torch
+import numpy as np
+from loguru import logger
+from kokoro import tokenize, phonemize
+from typing import Union, List
+
+from ..core.config import settings
+
+
+class TTSBaseModel(ABC):
+    _instance = None
+    _lock = threading.Lock()
+    _device = None
+    VOICES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "voices")
+
+    @classmethod
+    def setup(cls):
+        """Initialize model and setup voices"""
+        with cls._lock:
+            # Set device
+            cuda_available = torch.cuda.is_available()
+            logger.info(f"CUDA available: {cuda_available}")
+            if cuda_available:
+                try:
+                    # Test CUDA device
+                    test_tensor = torch.zeros(1).cuda()
+                    logger.info("CUDA test successful")
+                    cls._device = "cuda"
+                except Exception as e:
+                    logger.error(f"CUDA test failed: {e}")
+                    cls._device = "cpu"
+            else:
+                cls._device = "cpu"
+            logger.info(f"Initializing model on {cls._device}")
+
+            # Initialize model
+            if not cls.initialize(settings.model_dir, settings.model_path):
+                raise RuntimeError(f"Failed to initialize {cls._device.upper()} model")
+
+            # Setup voices directory
+            os.makedirs(cls.VOICES_DIR, exist_ok=True)
+
+            # Copy base voices to local directory
+            base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
+            if os.path.exists(base_voices_dir):
+                for file in os.listdir(base_voices_dir):
+                    if file.endswith(".pt"):
+                        voice_name = file[:-3]
+                        voice_path = os.path.join(cls.VOICES_DIR, file)
+                        if not os.path.exists(voice_path):
+                            try:
+                                logger.info(f"Copying base voice {voice_name} to voices directory")
+                                base_path = os.path.join(base_voices_dir, file)
+                                voicepack = torch.load(base_path, map_location=cls._device, weights_only=True)
+                                torch.save(voicepack, voice_path)
+                            except Exception as e:
+                                logger.error(f"Error copying voice {voice_name}: {str(e)}")
+
+            # Warm up with default voice
+            try:
+                dummy_text = "Hello"
+                voice_path = os.path.join(cls.VOICES_DIR, "af.pt")
+                dummy_voicepack = torch.load(voice_path, map_location=cls._device, weights_only=True)
+
+                if cls._device == "cuda":
+                    cls.generate(dummy_text, dummy_voicepack, "a", 1.0)
+                else:
+                    ps = phonemize(dummy_text, "a")
+                    tokens = tokenize(ps)
+                    tokens = [0] + tokens + [0]
+                    cls.generate(tokens, dummy_voicepack, 1.0)
+
+                logger.info("Model warm-up complete")
+            except Exception as e:
+                logger.warning(f"Model warm-up failed: {e}")
+
+            # Count voices in directory
+            voice_count = len([f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")])
+            return voice_count
+
+    @classmethod
+    @abstractmethod
+    def initialize(cls, model_dir: str, model_path: str = None):
+        """Initialize the model"""
+        pass
+
+    @classmethod
+    @abstractmethod
+    def generate(cls, input_data: Union[str, List[int]], voicepack: torch.Tensor, *args) -> np.ndarray:
+        """Generate audio from input
+
+        Args:
+            input_data: Either text string (GPU) or tokenized input (CPU)
+            voicepack: Voice tensor
+            *args: Additional args (lang+speed for GPU, speed for CPU)
+
+        Returns:
+            np.ndarray: Generated audio samples
+        """
+        pass
+
+    @classmethod
+    def get_device(cls):
+        """Get the current device"""
+        if cls._device is None:
+            raise RuntimeError("Model not initialized. Call setup() first.")
+        return cls._device
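The base class keeps initialize and generate abstract, so a backend only supplies model loading and raw inference while setup() owns device probing, voice copying, and warm-up. A toy subclass sketch (EchoModel is hypothetical, purely to show the contract):

import numpy as np
import torch

from api.src.services.tts_base import TTSBaseModel


class EchoModel(TTSBaseModel):
    """Hypothetical backend: returns silence, just to demonstrate the contract."""

    @classmethod
    def initialize(cls, model_dir: str, model_path: str = None):
        cls._instance = object()  # stand-in for a real model handle
        return cls._instance      # truthy return signals success to setup()

    @classmethod
    def generate(cls, input_data, voicepack: torch.Tensor, *args) -> np.ndarray:
        speed = args[-1]  # speed is the last positional arg on both code paths
        # one "second" of silence at an assumed 24 kHz sample rate
        return np.zeros(int(24000 / speed), dtype=np.float32)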
@@ -4,17 +4,35 @@ import torch
 from onnxruntime import InferenceSession, SessionOptions, GraphOptimizationLevel, ExecutionMode
 from loguru import logger
 
-class TTSCPUModel:
+from .tts_base import TTSBaseModel
+
+class TTSCPUModel(TTSBaseModel):
     _instance = None
     _onnx_session = None
 
     @classmethod
-    def initialize(cls, model_dir: str):
+    def initialize(cls, model_dir: str, model_path: str = None):
         """Initialize ONNX model for CPU inference"""
         if cls._onnx_session is None:
-            # Try loading ONNX model
-            onnx_path = os.path.join(model_dir, "kokoro-v0_19.onnx")
-            if not os.path.exists(onnx_path):
+            # First try the specified path if provided
+            if model_path and model_path.endswith('.onnx'):
+                onnx_path = os.path.join(model_dir, model_path)
+                if os.path.exists(onnx_path):
+                    logger.info(f"Loading specified ONNX model from {onnx_path}")
+                else:
+                    onnx_path = None
+            else:
+                # Look for any .onnx file in the directory as fallback
+                onnx_files = [f for f in os.listdir(model_dir) if f.endswith('.onnx')]
+                if onnx_files:
+                    onnx_path = os.path.join(model_dir, onnx_files[0])
+                    logger.info(f"Found ONNX model: {onnx_path}")
+                else:
+                    logger.error(f"No ONNX model found in {model_dir}")
+                    return None
+
+            if not onnx_path:
+                return None
+
+            logger.info(f"Loading ONNX model from {onnx_path}")
@@ -44,22 +62,33 @@ class TTSCPUModel:
         return cls._onnx_session
 
     @classmethod
-    def generate(cls, tokens: list, voicepack: torch.Tensor, speed: float) -> np.ndarray:
-        """Generate audio using ONNX model"""
+    def generate(cls, input_data: list[int], voicepack: torch.Tensor, *args) -> np.ndarray:
+        """Generate audio using ONNX model
+
+        Args:
+            input_data: list of token IDs
+            voicepack: Voice tensor
+            *args: (speed,) tuple
+
+        Returns:
+            np.ndarray: Generated audio samples
+        """
         if cls._onnx_session is None:
             raise RuntimeError("ONNX model not initialized")
 
+        speed = args[0]
         # Pre-allocate and prepare inputs
-        tokens_input = np.array([tokens], dtype=np.int64)
-        style_input = voicepack[len(tokens)-2].numpy()  # Already has correct dimensions
+        tokens_input = np.array([input_data], dtype=np.int64)
+        style_input = voicepack[len(input_data)-2].numpy()  # Already has correct dimensions
         speed_input = np.full(1, speed, dtype=np.float32)  # More efficient than ones * speed
 
         # Run inference with optimized inputs
-        return cls._onnx_session.run(
+        result = cls._onnx_session.run(
             None,
             {
                 'tokens': tokens_input,
                 'style': style_input,
                 'speed': speed_input
             }
-        )[0]
+        )
+        return result[0]
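Stripped of the class plumbing, the CPU path is one named-input run call against the session. A standalone sketch; the input names (tokens, style, speed) come from this diff, while the style-vector shape and model filename are assumptions:

import numpy as np
from onnxruntime import InferenceSession

# Assumes a kokoro-v0_19.onnx export with the input names used above.
session = InferenceSession("kokoro-v0_19.onnx", providers=["CPUExecutionProvider"])

tokens = [0, 12, 34, 56, 0]                        # padded token IDs
tokens_input = np.array([tokens], dtype=np.int64)  # shape (1, seq_len)
# Stand-in for voicepack[len(tokens)-2]; shape (1, 256) is assumed here.
style_input = np.zeros((1, 256), dtype=np.float32)
speed_input = np.full(1, 1.0, dtype=np.float32)

audio = session.run(None, {
    "tokens": tokens_input,
    "style": style_input,
    "speed": speed_input,
})[0]
print(audio.shape)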
@@ -1,10 +1,13 @@
 import os
+import numpy as np
 import torch
 from loguru import logger
 from models import build_model
 from kokoro import generate
 
-class TTSGPUModel:
+from .tts_base import TTSBaseModel
+
+class TTSGPUModel(TTSBaseModel):
     _instance = None
     _device = "cuda"
@@ -24,9 +27,26 @@ class TTSGPUModel:
         return cls._instance
 
     @classmethod
-    def generate(cls, text: str, voicepack: torch.Tensor, lang: str, speed: float) -> tuple[torch.Tensor, dict]:
-        """Generate audio using PyTorch model on GPU"""
+    def generate(cls, input_data: str, voicepack: torch.Tensor, *args) -> np.ndarray:
+        """Generate audio using PyTorch model on GPU
+
+        Args:
+            input_data: Text string to generate audio from
+            voicepack: Voice tensor
+            *args: (lang, speed) tuple
+
+        Returns:
+            np.ndarray: Generated audio samples
+        """
         if cls._instance is None:
             raise RuntimeError("GPU model not initialized")
 
-        return generate(cls._instance, text, voicepack, lang=lang, speed=speed)
+        lang, speed = args
+        result = generate(cls._instance, input_data, voicepack, lang=lang, speed=speed)
+        # kokoro.generate returns (audio, metadata, info), we only want audio
+        audio = result[0]
+
+        # Convert to numpy array if needed
+        if isinstance(audio, torch.Tensor):
+            audio = audio.cpu().numpy()
+        return audio
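Call sites now get an np.ndarray back from either backend instead of unpacking a tuple on the GPU path only. A usage sketch, assuming a CUDA machine where TTSGPUModel.setup() has already run and the default af voicepack has been copied into place:

import torch

from api.src.services.tts_gpu import TTSGPUModel

# Voicepack path follows the VOICES_DIR convention above; adjust as needed.
voicepack = torch.load("api/src/voices/af.pt", map_location="cuda", weights_only=True)

# GPU backend takes raw text plus (lang, speed) varargs and returns np.ndarray.
audio = TTSGPUModel.generate("Hello world", voicepack, "a", 1.0)
print(audio.dtype, audio.shape)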
@@ -1,94 +1,8 @@
-import os
-import threading
 import torch
-from loguru import logger
-from kokoro import tokenize, phonemize
 
-from ..core.config import settings
-from .tts_cpu import TTSCPUModel
-from .tts_gpu import TTSGPUModel
+if torch.cuda.is_available():
+    from .tts_gpu import TTSGPUModel as TTSModel
+else:
+    from .tts_cpu import TTSCPUModel as TTSModel
 
-
-class TTSModel:
-    _device = None
-    _lock = threading.Lock()
-    VOICES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "voices")
-
-    @classmethod
-    def initialize(cls):
-        """Initialize and warm up the model"""
-        with cls._lock:
-            # Set device and initialize model
-            cls._device = "cuda" if torch.cuda.is_available() else "cpu"
-            logger.info(f"Initializing model on {cls._device}")
-
-            # Initialize appropriate model based on device
-            if cls._device == "cuda":
-                if not TTSGPUModel.initialize(settings.model_dir, settings.model_path):
-                    raise RuntimeError("Failed to initialize GPU model")
-            else:
-                # Try CPU ONNX first, fallback to CPU PyTorch if needed
-                if not TTSCPUModel.initialize(settings.model_dir):
-                    logger.warning("ONNX initialization failed, falling back to PyTorch CPU")
-                    if not TTSGPUModel.initialize(settings.model_dir, settings.model_path):
-                        raise RuntimeError("Failed to initialize CPU model")
-
-            # Setup voices directory
-            os.makedirs(cls.VOICES_DIR, exist_ok=True)
-
-            # Copy base voices to local directory
-            base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
-            if os.path.exists(base_voices_dir):
-                for file in os.listdir(base_voices_dir):
-                    if file.endswith(".pt"):
-                        voice_name = file[:-3]
-                        voice_path = os.path.join(cls.VOICES_DIR, file)
-                        if not os.path.exists(voice_path):
-                            try:
-                                logger.info(
-                                    f"Copying base voice {voice_name} to voices directory"
-                                )
-                                base_path = os.path.join(base_voices_dir, file)
-                                voicepack = torch.load(
-                                    base_path,
-                                    map_location=cls._device,
-                                    weights_only=True,
-                                )
-                                torch.save(voicepack, voice_path)
-                            except Exception as e:
-                                logger.error(
-                                    f"Error copying voice {voice_name}: {str(e)}"
-                                )
-
-            # Warm up with default voice
-            try:
-                dummy_text = "Hello"
-                voice_path = os.path.join(cls.VOICES_DIR, "af.pt")
-                dummy_voicepack = torch.load(
-                    voice_path, map_location=cls._device, weights_only=True
-                )
-
-                if cls._device == "cuda":
-                    TTSGPUModel.generate(dummy_text, dummy_voicepack, "a", 1.0)
-                else:
-                    ps = phonemize(dummy_text, "a")
-                    tokens = tokenize(ps)
-                    tokens = [0] + tokens + [0]
-                    TTSCPUModel.generate(tokens, dummy_voicepack, 1.0)
-
-                logger.info("Model warm-up complete")
-            except Exception as e:
-                logger.warning(f"Model warm-up failed: {e}")
-
-            # Count voices in directory
-            voice_count = len(
-                [f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")]
-            )
-            return voice_count
-
-    @classmethod
-    def get_device(cls):
-        """Get the current device or raise an error"""
-        if cls._device is None:
-            raise RuntimeError("Model not initialized. Call initialize() first.")
-        return cls._device
+
+__all__ = ["TTSModel"]
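Since the backend is now chosen once at import time, every caller sees the same class object; a sketch of the consumer's view:

# TTSModel is whichever backend won the import-time check above:
# TTSGPUModel when CUDA is available, TTSCPUModel otherwise.
from api.src.services.tts_model import TTSModel

voice_count = TTSModel.setup()  # shared flow inherited from TTSBaseModel.setup()
print(TTSModel.get_device(), voice_count)

One consequence of this design: the device choice is frozen at first import, so tests that need a specific backend patch api.src.services.tts_model.TTSModel rather than the concrete classes, as the updated conftest below does.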
@@ -12,8 +12,6 @@ from loguru import logger
 
 from ..core.config import settings
 from .tts_model import TTSModel
-from .tts_cpu import TTSCPUModel
-from .tts_gpu import TTSGPUModel
 
 
 class TTSService:
@@ -22,6 +20,8 @@ class TTSService:
 
     def _split_text(self, text: str) -> List[str]:
         """Split text into sentences"""
+        if not isinstance(text, str):
+            text = str(text) if text is not None else ""
         return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]
 
     def _get_voice_path(self, voice_name: str) -> Optional[str]:
@@ -37,9 +37,12 @@ class TTSService:
 
         try:
             # Normalize text once at the start
-            text = normalize_text(text)
-            if not text:
-                raise ValueError("Text is empty after preprocessing")
+            normalized = normalize_text(text)
+            if not normalized:
+                raise ValueError("Text is empty after preprocessing")
+            text = str(normalized)
 
             # Check voice exists
             voice_path = self._get_voice_path(voice)
@@ -61,12 +64,18 @@ class TTSService:
                 try:
                     # Process chunk
                     if TTSModel.get_device() == "cuda":
-                        chunk_audio, _ = TTSGPUModel.generate(chunk, voicepack, voice[0], speed)
+                        # GPU takes (text, voicepack, lang, speed)
+                        try:
+                            chunk_audio = TTSModel.generate(chunk, voicepack, voice[0], speed)
+                        except RuntimeError as e:
+                            logger.error(f"Failed to generate audio: {str(e)}")
+                            chunk_audio = None
                     else:
+                        # CPU takes (tokens, voicepack, speed)
                         ps = phonemize(chunk, voice[0])
                         tokens = tokenize(ps)
-                        tokens = [0] + tokens + [0]  # Add padding
-                        chunk_audio = TTSCPUModel.generate(tokens, voicepack, speed)
+                        tokens = [0] + list(tokens) + [0]  # Add padding
+                        chunk_audio = TTSModel.generate(tokens, voicepack, speed)
 
                     if chunk_audio is not None:
                         audio_chunks.append(chunk_audio)
@@ -90,12 +99,18 @@ class TTSService:
             else:
                 # Process single chunk
                 if TTSModel.get_device() == "cuda":
-                    audio, _ = TTSGPUModel.generate(text, voicepack, voice[0], speed)
+                    # GPU takes (text, voicepack, lang, speed)
+                    try:
+                        audio = TTSModel.generate(text, voicepack, voice[0], speed)
+                    except RuntimeError as e:
+                        logger.error(f"Failed to generate audio: {str(e)}")
+                        raise ValueError("No audio chunks were generated successfully")
                 else:
+                    # CPU takes (tokens, voicepack, speed)
                     ps = phonemize(text, voice[0])
                     tokens = tokenize(ps)
-                    tokens = [0] + tokens + [0]  # Add padding
-                    audio = TTSCPUModel.generate(tokens, voicepack, speed)
+                    tokens = [0] + list(tokens) + [0]  # Add padding
+                    audio = TTSModel.generate(tokens, voicepack, speed)
 
             processing_time = time.time() - start_time
             return audio, processing_time
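Taken together with the GPU branch above, the service now funnels both devices through the single TTSModel.generate entry point. On CPU the full text-to-audio flow is phonemize, tokenize, pad, generate; a minimal sketch, assuming a CPU-only machine where TTSModel.setup() has already run and the af voicepack exists at the assumed path:

import torch
from kokoro import phonemize, tokenize

from api.src.services.tts_model import TTSModel

text, voice = "Hello there!", "af"
voicepack = torch.load(f"api/src/voices/{voice}.pt", map_location="cpu", weights_only=True)

ps = phonemize(text, voice[0])            # language code is the first letter of the voice name
tokens = [0] + list(tokenize(ps)) + [0]   # pad with start/end tokens
audio = TTSModel.generate(tokens, voicepack, 1.0)  # CPU signature: (tokens, voicepack, speed)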
@@ -36,7 +36,7 @@ sys.modules["kokoro.tokenize"] = Mock()
 @pytest.fixture(autouse=True)
 def mock_tts_model():
     """Mock TTSModel to avoid loading real models during tests"""
-    with patch("api.src.services.tts.TTSModel") as mock:
+    with patch("api.src.services.tts_model.TTSModel") as mock:
         model_instance = Mock()
         model_instance.get_instance.return_value = model_instance
         model_instance.get_voicepack.return_value = None
@@ -26,13 +26,11 @@ def test_health_check(test_client):
 @patch("api.src.main.logger")
 async def test_lifespan_successful_warmup(mock_logger, mock_tts_model):
     """Test successful model warmup in lifespan"""
-    # Mock the model initialization with model info and voicepack count
-    mock_model = MagicMock()
-    mock_tts_model.initialize.return_value = (mock_model, 3)  # 3 voice files
-    mock_tts_model._device = "cuda"  # Set device class variable
+    # Mock file system for voice counting
+    mock_tts_model.VOICES_DIR = "/mock/voices"
+    with patch("os.listdir", return_value=["voice1.pt", "voice2.pt", "voice3.pt"]):
+        mock_tts_model.setup.return_value = 3  # 3 voice files
+        mock_tts_model.get_device.return_value = "cuda"
 
     # Create an async generator from the lifespan context manager
     async_gen = lifespan(MagicMock())
@@ -44,8 +42,8 @@ async def test_lifespan_successful_warmup(mock_logger, mock_tts_model):
     mock_logger.info.assert_any_call("Model loaded and warmed up on cuda")
     mock_logger.info.assert_any_call("3 voice packs loaded successfully")
 
-    # Verify model initialization was called
-    mock_tts_model.initialize.assert_called_once()
+    # Verify model setup was called
+    mock_tts_model.setup.assert_called_once()
 
     # Clean up
     await async_gen.__aexit__(None, None, None)
@@ -56,14 +54,14 @@ async def test_lifespan_successful_warmup(mock_logger, mock_tts_model):
 @patch("api.src.main.logger")
 async def test_lifespan_failed_warmup(mock_logger, mock_tts_model):
     """Test failed model warmup in lifespan"""
-    # Mock the model initialization to fail
-    mock_tts_model.initialize.side_effect = Exception("Failed to initialize model")
+    # Mock the model setup to fail
+    mock_tts_model.setup.side_effect = RuntimeError("Failed to initialize model")
 
     # Create an async generator from the lifespan context manager
     async_gen = lifespan(MagicMock())
 
     # Verify the exception is raised
-    with pytest.raises(Exception, match="Failed to initialize model"):
+    with pytest.raises(RuntimeError, match="Failed to initialize model"):
         await async_gen.__aenter__()
 
     # Verify the expected logging sequence
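These tests drive the lifespan context manager by hand rather than through a test client; the same pattern works for any asynccontextmanager. A compact, self-contained sketch with a toy lifespan standing in for the real one (assumes pytest-asyncio is installed):

from contextlib import asynccontextmanager

import pytest


@asynccontextmanager
async def toy_lifespan(app):
    # Stand-in for the real lifespan: fail before startup completes.
    raise RuntimeError("Failed to initialize model")
    yield  # unreachable, but makes this an async generator


@pytest.mark.asyncio  # assumes the pytest-asyncio plugin
async def test_toy_lifespan_failure():
    async_gen = toy_lifespan(object())
    # __aenter__ runs the generator up to the first yield, so the error surfaces here.
    with pytest.raises(RuntimeError, match="Failed to initialize model"):
        await async_gen.__aenter__()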
@@ -77,20 +75,18 @@ async def test_lifespan_failed_warmup(mock_logger, mock_tts_model):
 @patch("api.src.main.TTSModel")
 async def test_lifespan_cuda_warmup(mock_tts_model):
     """Test model warmup specifically on CUDA"""
-    # Mock the model initialization with CUDA and voicepacks
-    mock_model = MagicMock()
-    mock_tts_model.initialize.return_value = (mock_model, 2)  # 2 voice files
-    mock_tts_model._device = "cuda"  # Set device class variable
+    # Mock file system for voice counting
+    mock_tts_model.VOICES_DIR = "/mock/voices"
+    with patch("os.listdir", return_value=["voice1.pt", "voice2.pt"]):
+        mock_tts_model.setup.return_value = 2  # 2 voice files
+        mock_tts_model.get_device.return_value = "cuda"
 
     # Create an async generator from the lifespan context manager
     async_gen = lifespan(MagicMock())
     await async_gen.__aenter__()
 
-    # Verify model was initialized
-    mock_tts_model.initialize.assert_called_once()
+    # Verify model setup was called
+    mock_tts_model.setup.assert_called_once()
 
     # Clean up
     await async_gen.__aexit__(None, None, None)
@@ -100,22 +96,20 @@ async def test_lifespan_cuda_warmup(mock_tts_model):
 @patch("api.src.main.TTSModel")
 async def test_lifespan_cpu_fallback(mock_tts_model):
     """Test model warmup falling back to CPU"""
-    # Mock the model initialization with CPU and voicepacks
-    mock_model = MagicMock()
-    mock_tts_model.initialize.return_value = (mock_model, 4)  # 4 voice files
-    mock_tts_model._device = "cpu"  # Set device class variable
+    # Mock file system for voice counting
+    mock_tts_model.VOICES_DIR = "/mock/voices"
+    with patch(
+        "os.listdir", return_value=["voice1.pt", "voice2.pt", "voice3.pt", "voice4.pt"]
+    ):
+        mock_tts_model.setup.return_value = 4  # 4 voice files
+        mock_tts_model.get_device.return_value = "cpu"
 
     # Create an async generator from the lifespan context manager
     async_gen = lifespan(MagicMock())
     await async_gen.__aenter__()
 
-    # Verify model was initialized
-    mock_tts_model.initialize.assert_called_once()
+    # Verify model setup was called
+    mock_tts_model.setup.assert_called_once()
 
     # Clean up
     await async_gen.__aexit__(None, None, None)
@@ -7,6 +7,7 @@ import numpy as np
 import torch
 import pytest
 
+from api.src.core.config import settings
 from api.src.services.tts_model import TTSModel
 from api.src.services.tts_service import TTSService
@@ -14,7 +15,7 @@ from api.src.services.tts_service import TTSService
 @pytest.fixture
 def tts_service():
     """Create a TTSService instance for testing"""
-    return TTSService(start_worker=False)
+    return TTSService()
 
 
 @pytest.fixture
@@ -86,6 +87,7 @@ def test_generate_audio_empty_text(
 ):
     """Test generating audio with empty text"""
     mock_normalize.return_value = ""
+    mock_instance.return_value = (MagicMock(), "cpu")
 
     with pytest.raises(ValueError, match="Text is empty after preprocessing"):
         tts_service._generate_audio("", "af", 1.0)
@@ -111,7 +113,7 @@ def test_generate_audio_no_chunks(
     """Test generating audio with no successful chunks"""
     mock_normalize.return_value = "Test text"
     mock_phonemize.return_value = "Test text"
-    mock_tokenize.return_value = ["test", "text"]
+    mock_tokenize.return_value = [1, 2]  # Return integers instead of strings
     mock_generate.return_value = (None, None)
     mock_instance.return_value = (MagicMock(), "cpu")
     mock_exists.return_value = True
@@ -156,57 +158,23 @@ def test_combine_voices_invalid_input(tts_service):
         tts_service.combine_voices(["voice1"])
 
 
-@patch("os.makedirs")
-@patch("os.path.exists")
-@patch("os.listdir")
-@patch("torch.load")
-@patch("torch.save")
-@patch("os.path.join")
-def test_ensure_voices(
-    mock_join,
-    mock_save,
-    mock_load,
-    mock_listdir,
-    mock_exists,
-    mock_makedirs,
-    tts_service,
-):
-    """Test voice directory initialization"""
-    # Setup mocks
-    mock_exists.side_effect = [
-        True,
-        False,
-        False,
-    ]  # base_dir exists, voice files don't exist
-    mock_listdir.return_value = ["voice1.pt", "voice2.pt"]
-    mock_load.return_value = MagicMock()
-    mock_join.return_value = "/fake/path"
-
-    # Test voice directory initialization
-    tts_service._ensure_voices()
-
-    # Verify directory was created
-    mock_makedirs.assert_called_once()
-
-    # Verify voices were loaded and saved
-    assert mock_load.call_count == len(mock_listdir.return_value)
-    assert mock_save.call_count == len(mock_listdir.return_value)
-
-
-@patch("api.src.services.tts.TTSModel.get_instance")
+@patch("api.src.services.tts_model.TTSModel.get_instance")
+@patch("api.src.services.tts_model.TTSModel.get_device")
+@patch("api.src.services.tts_model.TTSModel.generate")
 @patch("os.path.exists")
-@patch("api.src.services.tts.normalize_text")
-@patch("api.src.services.tts.phonemize")
-@patch("api.src.services.tts.tokenize")
-@patch("api.src.services.tts.generate")
+@patch("kokoro.normalize_text")
+@patch("kokoro.phonemize")
+@patch("kokoro.tokenize")
 @patch("torch.load")
 def test_generate_audio_success(
     mock_torch_load,
     mock_generate,
     mock_tokenize,
     mock_phonemize,
     mock_normalize,
     mock_exists,
     mock_model_generate,
     mock_get_device,
     mock_instance,
     tts_service,
     sample_audio,
@@ -214,12 +182,17 @@ def test_generate_audio_success(
     """Test successful audio generation"""
     mock_normalize.return_value = "Test text"
     mock_phonemize.return_value = "Test text"
-    mock_tokenize.return_value = ["test", "text"]
-    mock_generate.return_value = (sample_audio, None)
+    mock_tokenize.return_value = [1, 2]  # Return integers instead of strings
+    mock_model_generate.return_value = sample_audio
+    mock_instance.return_value = (MagicMock(), "cpu")
+    mock_get_device.return_value = "cpu"
     mock_exists.return_value = True
     mock_torch_load.return_value = MagicMock()
 
+    # Initialize model
+    TTSModel._instance = None
+    TTSModel._device = "cpu"
+
     audio, processing_time = tts_service._generate_audio("Test text", "af", 1.0)
     assert isinstance(audio, np.ndarray)
     assert isinstance(processing_time, float)
@@ -227,35 +200,94 @@ def test_generate_audio_success(
 
 
 @patch("torch.cuda.is_available")
-@patch("models.build_model")
-def test_model_initialization_cuda(mock_build_model, mock_cuda_available):
+@patch("api.src.services.tts_gpu.TTSGPUModel.initialize")
+@patch("os.makedirs")
+@patch("os.path.exists")
+@patch("os.listdir")
+@patch("torch.load")
+@patch("torch.save")
+@patch("api.src.core.config.settings")
+@patch("torch.zeros")
+def test_model_initialization_cuda(
+    mock_zeros,
+    mock_settings,
+    mock_save,
+    mock_load,
+    mock_listdir,
+    mock_exists,
+    mock_makedirs,
+    mock_initialize,
+    mock_cuda_available,
+):
     """Test model initialization with CUDA"""
     # Setup mocks
     mock_cuda_available.return_value = True
-    mock_model = MagicMock()
-    mock_build_model.return_value = mock_model
+    mock_initialize.return_value = True
+    mock_exists.return_value = True
+    mock_listdir.return_value = ["voice1.pt", "voice2.pt"]
+    mock_load.return_value = torch.zeros(1)
+    mock_settings.model_dir = "test_dir"
+    mock_settings.model_path = "test_path"
+    mock_settings.voices_dir = "voices"
+    mock_zeros.return_value = torch.zeros(1)
 
-    TTSModel._instance = None  # Reset singleton
-    model, voice_count = TTSModel.initialize()
+    # Reset singleton and device
+    TTSModel._instance = None
+    TTSModel._device = None
+
+    # Mock settings to prevent actual file operations
+    with patch.object(settings, 'model_dir', 'test_dir'), \
+         patch.object(settings, 'model_path', 'test_path'):
+        voice_count = TTSModel.setup()
 
-    assert TTSModel._device == "cuda"  # Check the class variable instead
-    assert model == mock_model
-    mock_build_model.assert_called_once()
+    assert TTSModel.get_device() == "cuda"
+    assert voice_count == 2
+    mock_initialize.assert_called_once_with("test_dir", "test_path")
 
 
-@patch("api.src.services.tts.torch.cuda.is_available")
-@patch("api.src.services.tts.build_model")
-def test_model_initialization_cpu(mock_build_model, mock_cuda_available):
+@patch("torch.cuda.is_available")
+@patch("api.src.services.tts_base.TTSBaseModel.initialize")
+@patch("os.makedirs")
+@patch("os.path.exists")
+@patch("os.listdir")
+@patch("torch.load")
+@patch("torch.save")
+@patch("api.src.core.config.settings")
+@patch("torch.zeros")
+def test_model_initialization_cpu(
+    mock_zeros,
+    mock_settings,
+    mock_save,
+    mock_load,
+    mock_listdir,
+    mock_exists,
+    mock_makedirs,
+    mock_initialize,
+    mock_cuda_available,
+):
     """Test model initialization with CPU"""
     # Setup mocks
     mock_cuda_available.return_value = False
-    mock_model = MagicMock()
-    mock_build_model.return_value = mock_model
+    mock_initialize.return_value = False  # This will trigger the RuntimeError
+    mock_exists.return_value = True
+    mock_listdir.return_value = ["voice1.pt", "voice2.pt", "voice3.pt"]
+    mock_load.return_value = torch.zeros(1)
+    mock_settings.model_dir = "test_dir"
+    mock_settings.model_path = "test_path"
+    mock_settings.voices_dir = "voices"
+    mock_zeros.return_value = torch.zeros(1)
 
-    TTSModel._instance = None  # Reset singleton
-    model, voice_count = TTSModel.initialize()
+    # Reset singleton and device
+    TTSModel._instance = None
+    TTSModel._device = None
 
-    assert TTSModel._device == "cpu"  # Check the class variable instead
-    assert model == mock_model
-    mock_build_model.assert_called_once()
+    # Mock settings to prevent actual file operations
+    with patch.object(settings, 'model_dir', 'test_dir'), \
+         patch.object(settings, 'model_path', 'test_path'), \
+         pytest.raises(RuntimeError, match="Failed to initialize CPU model"):
+        TTSModel.setup()
+
+    mock_initialize.assert_called_once_with("test_dir", "test_path")
 
 
 @patch("api.src.services.tts_service.TTSService._get_voice_path")
@@ -267,7 +299,7 @@ def test_voicepack_loading_error(mock_get_instance, mock_get_voice_path):
 
     TTSModel._voicepacks = {}  # Reset voicepacks
 
-    service = TTSService(start_worker=False)
+    service = TTSService()
     with pytest.raises(ValueError, match="Voice not found: nonexistent_voice"):
         service._generate_audio("test", "nonexistent_voice", 1.0)
@@ -286,23 +318,32 @@ def test_save_audio(mock_tts_model, tts_service, sample_audio, tmp_path):
 
 
+@patch("api.src.services.tts_model.TTSModel.get_instance")
+@patch("api.src.services.tts_model.TTSModel.get_device")
+@patch("api.src.services.tts_model.TTSModel.generate")
 @patch("os.path.exists")
-@patch("api.src.services.tts.normalize_text")
-@patch("api.src.services.tts.generate")
+@patch("kokoro.normalize_text")
+@patch("kokoro.phonemize")
+@patch("kokoro.tokenize")
 @patch("torch.load")
 def test_generate_audio_without_stitching(
     mock_torch_load,
     mock_generate,
     mock_tokenize,
     mock_phonemize,
     mock_normalize,
     mock_exists,
     mock_model_generate,
     mock_get_device,
     mock_instance,
     tts_service,
     sample_audio,
 ):
     """Test generating audio without text stitching"""
     mock_normalize.return_value = "Test text"
-    mock_generate.return_value = (sample_audio, None)
+    mock_phonemize.return_value = "Test text"
+    mock_tokenize.return_value = [1, 2]  # Return integers instead of strings
+    mock_model_generate.return_value = sample_audio
+    mock_instance.return_value = (MagicMock(), "cpu")
+    mock_get_device.return_value = "cpu"
     mock_exists.return_value = True
     mock_torch_load.return_value = MagicMock()
 
@@ -311,7 +352,7 @@ def test_generate_audio_without_stitching(
     )
     assert isinstance(audio, np.ndarray)
     assert len(audio) > 0
-    mock_generate.assert_called_once()
+    mock_model_generate.assert_called_once()
 
 
 @patch("os.listdir")
@@ -323,12 +364,13 @@ def test_list_voices_error(mock_listdir, tts_service):
     assert voices == []
 
 
-@patch("api.src.services.tts.TTSModel.get_instance")
+@patch("api.src.services.tts_model.TTSModel.get_instance")
+@patch("api.src.services.tts_model.TTSModel.get_device")
 @patch("os.path.exists")
-@patch("api.src.services.tts.normalize_text")
-@patch("api.src.services.tts.phonemize")
-@patch("api.src.services.tts.tokenize")
-@patch("api.src.services.tts.generate")
+@patch("kokoro.normalize_text")
+@patch("kokoro.phonemize")
+@patch("kokoro.tokenize")
+@patch("kokoro.generate")
 @patch("torch.load")
 def test_generate_audio_phonemize_error(
     mock_torch_load,
@@ -337,6 +379,7 @@ def test_generate_audio_phonemize_error(
     mock_phonemize,
     mock_normalize,
     mock_exists,
+    mock_get_device,
    mock_instance,
     tts_service,
 ):
@@ -344,33 +387,51 @@ def test_generate_audio_phonemize_error(
     mock_normalize.return_value = "Test text"
     mock_phonemize.side_effect = Exception("Phonemization failed")
+    mock_instance.return_value = (MagicMock(), "cpu")
+    mock_get_device.return_value = "cpu"
     mock_exists.return_value = True
     mock_torch_load.return_value = MagicMock()
+    mock_generate.return_value = (None, None)
+
+    # Initialize model
+    TTSModel._instance = None
+    TTSModel._device = "cpu"
 
     with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
         tts_service._generate_audio("Test text", "af", 1.0)
 
 
-@patch("api.src.services.tts.TTSModel.get_instance")
+@patch("api.src.services.tts_model.TTSModel.get_instance")
+@patch("api.src.services.tts_model.TTSModel.get_device")
 @patch("os.path.exists")
-@patch("api.src.services.tts.normalize_text")
-@patch("api.src.services.tts.generate")
+@patch("kokoro.normalize_text")
+@patch("kokoro.phonemize")
+@patch("kokoro.tokenize")
+@patch("kokoro.generate")
 @patch("torch.load")
 def test_generate_audio_error(
     mock_torch_load,
     mock_generate,
     mock_tokenize,
     mock_phonemize,
     mock_normalize,
     mock_exists,
+    mock_get_device,
     mock_instance,
     tts_service,
 ):
     """Test handling generation error"""
     mock_normalize.return_value = "Test text"
+    mock_phonemize.return_value = "Test text"
+    mock_tokenize.return_value = [1, 2]  # Return integers instead of strings
     mock_generate.side_effect = Exception("Generation failed")
+    mock_instance.return_value = (MagicMock(), "cpu")
+    mock_get_device.return_value = "cpu"
     mock_exists.return_value = True
     mock_torch_load.return_value = MagicMock()
 
+    # Initialize model
+    TTSModel._instance = None
+    TTSModel._device = "cpu"
+
     with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
         tts_service._generate_audio("Test text", "af", 1.0)
@@ -69,22 +69,13 @@ def get_gpu_memory():
 
 def get_system_metrics():
     """Get current system metrics"""
-    # Take multiple CPU measurements over a short period
-    samples = []
-    for _ in range(3):  # Take 3 samples
-        # Get both overall and per-CPU percentages
-        overall_cpu = psutil.cpu_percent(interval=0.1)
-        per_cpu = psutil.cpu_percent(percpu=True)
-        avg_per_cpu = sum(per_cpu) / len(per_cpu)
-        # Use the maximum of overall and average per-CPU
-        samples.append(max(overall_cpu, avg_per_cpu))
-
-    # Use the maximum CPU usage from all samples
-    cpu_usage = round(max(samples), 2)
+    # Get per-CPU percentages and calculate average
+    cpu_percentages = psutil.cpu_percent(percpu=True)
+    avg_cpu = sum(cpu_percentages) / len(cpu_percentages)
 
     metrics = {
         "timestamp": datetime.now().isoformat(),
-        "cpu_percent": cpu_usage,
+        "cpu_percent": round(avg_cpu, 2),
         "ram_percent": psutil.virtual_memory().percent,
         "ram_used_gb": psutil.virtual_memory().used / (1024**3),
     }
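The rewrite also changes timing behavior: psutil.cpu_percent(interval=0.1) blocks for the interval, so three samples cost roughly 300 ms per poll, while the new interval=None form returns immediately using the delta since the previous call. A standalone sketch of the new collection; note the first non-blocking call should be treated as a throwaway warm-up:

from datetime import datetime

import psutil

# Warm-up: the first interval=None call has no previous sample to diff against.
psutil.cpu_percent(percpu=True)


def get_system_metrics() -> dict:
    """Non-blocking snapshot: per-CPU averages accumulated since the previous call."""
    cpu_percentages = psutil.cpu_percent(percpu=True)
    avg_cpu = sum(cpu_percentages) / len(cpu_percentages)
    return {
        "timestamp": datetime.now().isoformat(),
        "cpu_percent": round(avg_cpu, 2),
        "ram_percent": psutil.virtual_memory().percent,
        "ram_used_gb": psutil.virtual_memory().used / (1024**3),
    }


print(get_system_metrics())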