Mirror of https://github.com/remsky/Kokoro-FastAPI.git (synced 2025-04-13 09:39:17 +00:00)

Commit e5e85b32d2 (parent 09b7c2cf1e)

- Add model instance pooling, better concurrency
- Add load testing setup with Locust

11 changed files with 737 additions and 530 deletions
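Reader's note: the core of this change is that every generation request now borrows a ModelInstance from a ModelPool and holds a chunk semaphore while streaming audio. The sketch below condenses that flow using the names introduced in the diff; it is an illustrative summary, not code lifted from the repository.

```python
# Minimal sketch of the pooled generation flow added in this commit.
# The names (ModelPool, ModelInstance, get_instance, release_instance,
# _chunk_semaphore) come from the diff below; error handling is trimmed.
async def generate(manager, *args, **kwargs):
    instance = await manager._pool.get_instance()       # wait for a free model instance
    try:
        async with manager._chunk_semaphore:             # cap concurrent chunk work
            async for chunk in instance.generate(*args, **kwargs):
                yield chunk
    finally:
        await manager._pool.release_instance(instance)   # hand the instance back to the pool
```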
@@ -22,6 +22,9 @@ class PyTorchConfig(BaseModel):
     memory_threshold: float = Field(0.8, description="Memory threshold for cleanup")
     retry_on_oom: bool = Field(True, description="Whether to retry on OOM errors")
+    max_concurrent_models: int = Field(2, description="Maximum number of concurrent model instances")
+    max_queue_size: int = Field(32, description="Maximum size of request queue")
+    chunk_semaphore_limit: int = Field(4, description="Maximum concurrent chunk processing per model")

     class Config:
         frozen = True
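These three new fields are the knobs the pool code reads later in this diff: pool size, shared wait queue, and per-manager chunk concurrency. A minimal, self-contained sketch of how they map onto asyncio primitives, using a stand-in object instead of the real pydantic config:

```python
import asyncio

# Stand-in for the real PyTorchConfig values (assumption: only the defaults shown above).
class _PyTorchGPU:
    max_concurrent_models = 2
    max_queue_size = 32
    chunk_semaphore_limit = 4

pool_size = _PyTorchGPU.max_concurrent_models                        # number of ModelInstance objects
request_queue = asyncio.Queue(maxsize=_PyTorchGPU.max_queue_size)    # shared wait queue for requests
chunk_semaphore = asyncio.Semaphore(_PyTorchGPU.chunk_semaphore_limit)  # cap on concurrent chunk work
```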
@@ -24,6 +24,11 @@ class KokoroV1(BaseModelBackend):
         self._device = "cuda" if settings.use_gpu else "cpu"
         self._model: Optional[KModel] = None
         self._pipelines: Dict[str, KPipeline] = {}  # Store pipelines by lang_code
+        self._stream: Optional[torch.cuda.Stream] = None
+
+    def set_stream(self, stream: torch.cuda.Stream) -> None:
+        """Set CUDA stream for this instance."""
+        self._stream = stream

     async def load_model(self, path: str) -> None:
         """Load pre-baked model.

@@ -146,6 +151,19 @@ class KokoroV1(BaseModelBackend):
             logger.debug(
                 f"Generating audio from tokens with lang_code '{pipeline_lang_code}': '{tokens[:100]}...'"
             )
+
+            # Use CUDA stream if available
+            if self._stream and self._device == "cuda":
+                with torch.cuda.stream(self._stream):
+                    for result in pipeline.generate_from_tokens(
+                        tokens=tokens, voice=voice_path, speed=speed, model=self._model
+                    ):
+                        if result.audio is not None:
+                            logger.debug(f"Got audio chunk with shape: {result.audio.shape}")
+                            yield result.audio.numpy()
+                        else:
+                            logger.warning("No audio in chunk")
+            else:
             for result in pipeline.generate_from_tokens(
                 tokens=tokens, voice=voice_path, speed=speed, model=self._model
             ):

@@ -239,6 +257,19 @@ class KokoroV1(BaseModelBackend):
             logger.debug(
                 f"Generating audio for text with lang_code '{pipeline_lang_code}': '{text[:100]}...'"
             )
+
+            # Use CUDA stream if available
+            if self._stream and self._device == "cuda":
+                with torch.cuda.stream(self._stream):
+                    for result in pipeline(
+                        text, voice=voice_path, speed=speed, model=self._model
+                    ):
+                        if result.audio is not None:
+                            logger.debug(f"Got audio chunk with shape: {result.audio.shape}")
+                            yield result.audio.numpy()
+                        else:
+                            logger.warning("No audio in chunk")
+            else:
             for result in pipeline(
                 text, voice=voice_path, speed=speed, model=self._model
             ):
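The per-instance stream handling above uses the standard PyTorch stream API: work issued inside a torch.cuda.stream(...) context is queued on that stream rather than the default one, which is what lets separate model instances overlap GPU work. A small self-contained illustration of the pattern (not taken from the repository):

```python
import torch

# Standalone illustration of the torch.cuda.stream pattern used above.
# Runs only on a CUDA machine; otherwise it just reports that CUDA is missing.
if torch.cuda.is_available():
    stream = torch.cuda.Stream()               # dedicated stream, analogous to ModelInstance._stream
    x = torch.randn(1024, 1024, device="cuda")
    with torch.cuda.stream(stream):            # kernels below are enqueued on `stream`
        y = x @ x
    torch.cuda.current_stream().wait_stream(stream)  # synchronize before using y elsewhere
    print(y.shape)
else:
    print("CUDA not available; stream example skipped")
```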
@@ -1,7 +1,10 @@
 """Kokoro V1 model management."""

-from typing import Optional
+import asyncio
+import time
+from typing import Dict, List, Optional, Tuple
+
+import torch
 from loguru import logger

 from ..core import paths

@@ -11,131 +14,47 @@ from .base import BaseModelBackend
 from .kokoro_v1 import KokoroV1


-class ModelManager:
-    """Manages Kokoro V1 model loading and inference."""
+class ModelInstance:
+    """Individual model instance with its own CUDA stream."""

-    # Singleton instance
-    _instance = None
-
-    def __init__(self, config: Optional[ModelConfig] = None):
-        """Initialize manager.
-
-        Args:
-            config: Optional model configuration override
-        """
-        self._config = config or model_config
-        self._backend: Optional[KokoroV1] = None  # Explicitly type as KokoroV1
+    def __init__(self, instance_id: int):
+        """Initialize model instance."""
+        self.instance_id = instance_id
+        self._backend: Optional[KokoroV1] = None
         self._device: Optional[str] = None
+        self._stream: Optional[torch.cuda.Stream] = None if not settings.use_gpu else torch.cuda.Stream()
+        self._in_use = False
+        self._last_used = 0.0

-    def _determine_device(self) -> str:
-        """Determine device based on settings."""
-        return "cuda" if settings.use_gpu else "cpu"
+    @property
+    def is_available(self) -> bool:
+        """Check if instance is available."""
+        return not self._in_use

     async def initialize(self) -> None:
-        """Initialize Kokoro V1 backend."""
+        """Initialize model instance."""
         try:
-            self._device = self._determine_device()
-            logger.info(f"Initializing Kokoro V1 on {self._device}")
+            self._device = "cuda" if settings.use_gpu else "cpu"
+            logger.info(f"Initializing Kokoro V1 instance {self.instance_id} on {self._device}")
             self._backend = KokoroV1()
+            if self._stream:
+                self._backend.set_stream(self._stream)
         except Exception as e:
-            raise RuntimeError(f"Failed to initialize Kokoro V1: {e}")
+            raise RuntimeError(f"Failed to initialize Kokoro V1 instance {self.instance_id}: {e}")

-    async def initialize_with_warmup(self, voice_manager) -> tuple[str, str, int]:
-        """Initialize and warm up model.
-
-        Args:
-            voice_manager: Voice manager instance for warmup
-
-        Returns:
-            Tuple of (device, backend type, voice count)
-
-        Raises:
-            RuntimeError: If initialization fails
-        """
-        import time
-
-        start = time.perf_counter()
-
-        try:
-            # Initialize backend
-            await self.initialize()
-
-            # Load model
-            model_path = self._config.pytorch_kokoro_v1_file
-            await self.load_model(model_path)
-
-            # Use paths module to get voice path
-            try:
-                voices = await paths.list_voices()
-                voice_path = await paths.get_voice_path(settings.default_voice)
-
-                # Warm up with short text
-                warmup_text = "Warmup text for initialization."
-                # Use default voice name for warmup
-                voice_name = settings.default_voice
-                logger.debug(f"Using default voice '{voice_name}' for warmup")
-                async for _ in self.generate(warmup_text, (voice_name, voice_path)):
-                    pass
-            except Exception as e:
-                raise RuntimeError(f"Failed to get default voice: {e}")
-
-            ms = int((time.perf_counter() - start) * 1000)
-            logger.info(f"Warmup completed in {ms}ms")
-
-            return self._device, "kokoro_v1", len(voices)
-        except FileNotFoundError as e:
-            logger.error("""
-Model files not found! You need to download the Kokoro V1 model:
-
-1. Download model using the script:
-   python docker/scripts/download_model.py --output api/src/models/v1_0
-
-2. Or set environment variable in docker-compose:
-   DOWNLOAD_MODEL=true
-""")
-            exit(0)
-        except Exception as e:
-            raise RuntimeError(f"Warmup failed: {e}")
-
-    def get_backend(self) -> BaseModelBackend:
-        """Get initialized backend.
-
-        Returns:
-            Initialized backend instance
-
-        Raises:
-            RuntimeError: If backend not initialized
-        """
-        if not self._backend:
-            raise RuntimeError("Backend not initialized")
-        return self._backend
-
     async def load_model(self, path: str) -> None:
-        """Load model using initialized backend.
-
-        Args:
-            path: Path to model file
-
-        Raises:
-            RuntimeError: If loading fails
-        """
+        """Load model using initialized backend."""
         if not self._backend:
             raise RuntimeError("Backend not initialized")

         try:
             await self._backend.load_model(path)
-        except FileNotFoundError as e:
-            raise e
         except Exception as e:
-            raise RuntimeError(f"Failed to load model: {e}")
+            raise RuntimeError(f"Failed to load model for instance {self.instance_id}: {e}")

     async def generate(self, *args, **kwargs):
-        """Generate audio using initialized backend.
-
-        Raises:
-            RuntimeError: If generation fails
-        """
+        """Generate audio using initialized backend."""
         if not self._backend:
             raise RuntimeError("Backend not initialized")

@@ -143,18 +62,138 @@ Model files not found! You need to download the Kokoro V1 model:
             async for chunk in self._backend.generate(*args, **kwargs):
                 yield chunk
         except Exception as e:
-            raise RuntimeError(f"Generation failed: {e}")
+            raise RuntimeError(f"Generation failed for instance {self.instance_id}: {e}")

-    def unload_all(self) -> None:
+    def unload(self) -> None:
         """Unload model and free resources."""
         if self._backend:
             self._backend.unload()
             self._backend = None

-    @property
-    def current_backend(self) -> str:
-        """Get current backend type."""
-        return "kokoro_v1"
+
+class ModelPool:
+    """Pool of model instances."""
+
+    def __init__(self, max_instances: int):
+        """Initialize model pool."""
+        self.max_instances = max_instances
+        self._instances: List[ModelInstance] = []
+        self._request_queue: asyncio.Queue = asyncio.Queue(maxsize=model_config.pytorch_gpu.max_queue_size)
+        self._lock = asyncio.Lock()
+
+    async def initialize(self) -> None:
+        """Initialize model pool."""
+        async with self._lock:
+            for i in range(self.max_instances):
+                instance = ModelInstance(i)
+                await instance.initialize()
+                self._instances.append(instance)
+
+    async def get_instance(self) -> ModelInstance:
+        """Get available model instance or wait for one."""
+        while True:
+            # Try to find an available instance
+            for instance in self._instances:
+                if instance.is_available:
+                    instance._in_use = True
+                    instance._last_used = time.time()
+                    return instance
+
+            # If no instance is available, wait in queue
+            try:
+                await self._request_queue.put(asyncio.current_task())
+                await asyncio.sleep(0.1)  # Small delay to prevent busy waiting
+            except asyncio.QueueFull:
+                raise RuntimeError("Request queue is full")
+
+    async def release_instance(self, instance: ModelInstance) -> None:
+        """Release model instance back to pool."""
+        instance._in_use = False
+        # Process next request in queue if any
+        if not self._request_queue.empty():
+            waiting_task = await self._request_queue.get()
+            if not waiting_task.done():
+                waiting_task.set_result(None)
+
+
+class ModelManager:
+    """Manages Kokoro V1 model loading and inference."""
+
+    # Singleton instance
+    _instance = None
+
+    def __init__(self, config: Optional[ModelConfig] = None):
+        """Initialize manager."""
+        self._config = config or model_config
+        self._pool: Optional[ModelPool] = None
+        self._chunk_semaphore = asyncio.Semaphore(self._config.pytorch_gpu.chunk_semaphore_limit)
+
+    async def initialize(self) -> None:
+        """Initialize model pool."""
+        if not self._pool:
+            self._pool = ModelPool(self._config.pytorch_gpu.max_concurrent_models)
+            await self._pool.initialize()
+
+    async def initialize_with_warmup(self, voice_manager) -> tuple[str, str, int]:
+        """Initialize and warm up model pool.
+
+        Args:
+            voice_manager: Voice manager instance for warmup
+
+        Returns:
+            Tuple of (device, backend type, voice count)
+        """
+        try:
+            # Initialize pool
+            await self.initialize()
+
+            # Load model on all instances
+            model_path = self._config.pytorch_kokoro_v1_file
+            for instance in self._pool._instances:
+                await instance.load_model(model_path)
+
+            # Warm up first instance
+            instance = self._pool._instances[0]
+            try:
+                voices = await paths.list_voices()
+                voice_path = await paths.get_voice_path(settings.default_voice)
+
+                # Warm up with short text
+                warmup_text = "Warmup text for initialization."
+                voice_name = settings.default_voice
+                logger.debug(f"Using default voice '{voice_name}' for warmup")
+                async for _ in instance.generate(warmup_text, (voice_name, voice_path)):
+                    pass
+
+            except Exception as e:
+                raise RuntimeError(f"Failed to get default voice: {e}")
+
+            return "cuda" if settings.use_gpu else "cpu", "kokoro_v1", len(voices)
+
+        except Exception as e:
+            raise RuntimeError(f"Warmup failed: {e}")
+
+    async def generate(self, *args, **kwargs):
+        """Generate audio using model pool."""
+        if not self._pool:
+            raise RuntimeError("Model pool not initialized")
+
+        # Get available instance
+        instance = await self._pool.get_instance()
+        try:
+            async with self._chunk_semaphore:
+                async for chunk in instance.generate(*args, **kwargs):
+                    yield chunk
+        finally:
+            # Release instance back to pool
+            await self._pool.release_instance(instance)
+
+    def unload_all(self) -> None:
+        """Unload all models and free resources."""
+        if self._pool:
+            for instance in self._pool._instances:
+                instance.unload()
+            self._pool = None


 async def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
@@ -143,53 +143,43 @@ async def get_system_info():
     }


-@router.get("/debug/session_pools")
-async def get_session_pool_info():
-    """Get information about ONNX session pools."""
+@router.get("/debug/model_pool")
+async def get_model_pool_info():
+    """Get information about model pool status."""
     from ..inference.model_manager import get_manager

     manager = await get_manager()
-    pools = manager._session_pools
     current_time = time.time()

-    pool_info = {}
+    if not manager._pool:
+        return {"status": "Model pool not initialized"}

-    # Get CPU pool info
-    if "onnx_cpu" in pools:
-        cpu_pool = pools["onnx_cpu"]
-        pool_info["cpu"] = {
-            "active_sessions": len(cpu_pool._sessions),
-            "max_sessions": cpu_pool._max_size,
-            "sessions": [
-                {"model": path, "age_seconds": current_time - info.last_used}
-                for path, info in cpu_pool._sessions.items()
-            ],
-        }
+    pool_info = {
+        "max_instances": manager._pool.max_instances,
+        "active_instances": len(manager._pool._instances),
+        "queue_size": manager._pool._request_queue.qsize(),
+        "max_queue_size": manager._pool._request_queue.maxsize,
+        "instances": []
+    }

-    # Get GPU pool info
-    if "onnx_gpu" in pools:
-        gpu_pool = pools["onnx_gpu"]
-        pool_info["gpu"] = {
-            "active_sessions": len(gpu_pool._sessions),
-            "max_streams": gpu_pool._max_size,
-            "available_streams": len(gpu_pool._available_streams),
-            "sessions": [
-                {
-                    "model": path,
-                    "age_seconds": current_time - info.last_used,
-                    "stream_id": info.stream_id,
-                }
-                for path, info in gpu_pool._sessions.items()
-            ],
-        }
+    # Get instance info
+    for instance in manager._pool._instances:
+        instance_info = {
+            "id": instance.instance_id,
+            "in_use": instance._in_use,
+            "device": instance._device,
+            "has_stream": instance._stream is not None,
+            "last_used": current_time - instance._last_used if instance._last_used > 0 else None
+        }
+        pool_info["instances"].append(instance_info)

-    # Add GPU memory info if available
+    # Add GPU info if available
     if GPU_AVAILABLE:
         try:
             gpus = GPUtil.getGPUs()
             if gpus:
                 gpu = gpus[0]  # Assume first GPU
-                pool_info["gpu"]["memory"] = {
+                pool_info["gpu_memory"] = {
                     "total_mb": gpu.memoryTotal,
                     "used_mb": gpu.memoryUsed,
                     "free_mb": gpu.memoryFree,
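For illustration, the reworked endpoint can be polled during a load test with a few lines of Python; the key names below match the handler above, while the host and port are simply the defaults used elsewhere in this repository:

```python
import requests

# Assumes a locally running server on the repo's default port.
resp = requests.get("http://localhost:8880/debug/model_pool", timeout=5)
info = resp.json()
print("max_instances:", info.get("max_instances"), "queue_size:", info.get("queue_size"))
for inst in info.get("instances", []):
    print(f"instance {inst['id']}: in_use={inst['in_use']} device={inst['device']}")
```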
@@ -23,14 +23,13 @@ from .text_processing.text_processor import process_text_chunk, smart_split
 class TTSService:
     """Text-to-speech service."""

-    # Limit concurrent chunk processing
-    _chunk_semaphore = asyncio.Semaphore(4)
-
     def __init__(self, output_dir: str = None):
         """Initialize service."""
         self.output_dir = output_dir
         self.model_manager = None
         self._voice_manager = None
+        # Create request queue for global request management
+        self._request_queue = asyncio.Queue(maxsize=32)

     @classmethod
     async def create(cls, output_dir: str = None) -> "TTSService":

@@ -54,7 +53,6 @@ class TTSService:
         lang_code: Optional[str] = None,
     ) -> AsyncGenerator[Union[np.ndarray, bytes], None]:
         """Process tokens into audio."""
-        async with self._chunk_semaphore:
         try:
             # Handle stream finalization
             if is_last:

@@ -78,12 +76,7 @@ class TTSService:
             if not tokens and not chunk_text:
                 return

-            # Get backend
-            backend = self.model_manager.get_backend()
-
-            # Generate audio using pre-warmed model
-            if isinstance(backend, KokoroV1):
-                # For Kokoro V1, pass text and voice info with lang_code
+            # Generate audio using model pool
             async for chunk_audio in self.model_manager.generate(
                 chunk_text,
                 (voice_name, voice_path),

@@ -106,39 +99,7 @@ class TTSService:
                     logger.error(f"Failed to convert audio: {str(e)}")
                 else:
                     yield chunk_audio
-            else:
-                # For legacy backends, load voice tensor
-                voice_tensor = await self._voice_manager.load_voice(
-                    voice_name, device=backend.device
-                )
-                chunk_audio = await self.model_manager.generate(
-                    tokens, voice_tensor, speed=speed
-                )
-
-                if chunk_audio is None:
-                    logger.error("Model generated None for audio chunk")
-                    return
-
-                if len(chunk_audio) == 0:
-                    logger.error("Model generated empty audio chunk")
-                    return
-
-                # For streaming, convert to bytes
-                if output_format:
-                    try:
-                        converted = await AudioService.convert_audio(
-                            chunk_audio,
-                            24000,
-                            output_format,
-                            is_first_chunk=is_first,
-                            normalizer=normalizer,
-                            is_last_chunk=is_last,
-                        )
-                        yield converted
-                    except Exception as e:
-                        logger.error(f"Failed to convert audio: {str(e)}")
-                else:
-                    yield chunk_audio
         except Exception as e:
             logger.error(f"Failed to process tokens: {str(e)}")

@@ -228,9 +189,6 @@ class TTSService:
         chunk_index = 0

         try:
-            # Get backend
-            backend = self.model_manager.get_backend()
-
             # Get voice path, handling combined voices
             voice_name, voice_path = await self._get_voice_path(voice)
             logger.debug(f"Using voice path: {voice_path}")
@@ -310,187 +268,25 @@ class TTSService:
             word_timestamps = []

         try:
-            # Get backend and voice path
-            backend = self.model_manager.get_backend()
+            # Get voice path
             voice_name, voice_path = await self._get_voice_path(voice)

-            if isinstance(backend, KokoroV1):
             # Use provided lang_code or determine from voice name
             pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
             logger.info(
                 f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in text chunking"
             )

-                # Get pipelines from backend for proper device management
-                try:
-                    # Initialize quiet pipeline for text chunking
-                    text_chunks = []
-                    current_offset = 0.0  # Track time offset for timestamps
-
-                    logger.debug("Splitting text into chunks...")
-                    # Use backend's pipeline management
-                    for result in backend._get_pipeline(pipeline_lang_code)(text):
-                        if result.graphemes and result.phonemes:
-                            text_chunks.append((result.graphemes, result.phonemes))
-                    logger.debug(f"Split text into {len(text_chunks)} chunks")
-
-                    # Process each chunk
-                    for chunk_idx, (chunk_text, chunk_phonemes) in enumerate(
-                        text_chunks
-                    ):
-                        logger.debug(
-                            f"Processing chunk {chunk_idx + 1}/{len(text_chunks)}: '{chunk_text[:50]}...'"
-                        )
-
-                        # Use backend's pipeline for generation
-                        for result in backend._get_pipeline(pipeline_lang_code)(
-                            chunk_text, voice=voice_path, speed=speed
-                        ):
-                            # Collect audio chunks
-                            if result.audio is not None:
-                                chunks.append(result.audio.numpy())
-
-                                # Process timestamps for this chunk
-                                if (
-                                    return_timestamps
-                                    and hasattr(result, "tokens")
-                                    and result.tokens
-                                ):
-                                    logger.debug(
-                                        f"Processing chunk timestamps with {len(result.tokens)} tokens"
-                                    )
-                                    if result.pred_dur is not None:
-                                        try:
-                                            # Join timestamps for this chunk's tokens
-                                            KPipeline.join_timestamps(
-                                                result.tokens, result.pred_dur
-                                            )
-
-                                            # Add timestamps with offset
-                                            for token in result.tokens:
-                                                if not all(
-                                                    hasattr(token, attr)
-                                                    for attr in [
-                                                        "text",
-                                                        "start_ts",
-                                                        "end_ts",
-                                                    ]
-                                                ):
-                                                    continue
-                                                if not token.text or not token.text.strip():
-                                                    continue
-
-                                                # Apply offset to timestamps
-                                                start_time = (
-                                                    float(token.start_ts) + current_offset
-                                                )
-                                                end_time = (
-                                                    float(token.end_ts) + current_offset
-                                                )
-
-                                                word_timestamps.append(
-                                                    {
-                                                        "word": str(token.text).strip(),
-                                                        "start_time": start_time,
-                                                        "end_time": end_time,
-                                                    }
-                                                )
-                                                logger.debug(
-                                                    f"Added timestamp for word '{token.text}': {start_time:.3f}s - {end_time:.3f}s"
-                                                )
-
-                                            # Update offset for next chunk based on pred_dur
-                                            chunk_duration = (
-                                                float(result.pred_dur.sum()) / 80
-                                            )  # Convert frames to seconds
-                                            current_offset = max(
-                                                current_offset + chunk_duration, end_time
-                                            )
-                                            logger.debug(
-                                                f"Updated time offset to {current_offset:.3f}s"
-                                            )
-
-                                        except Exception as e:
-                                            logger.error(
-                                                f"Failed to process timestamps for chunk: {e}"
-                                            )
-                                    logger.debug(
-                                        f"Processing timestamps with pred_dur shape: {result.pred_dur.shape}"
-                                    )
-                                    try:
-                                        # Join timestamps for this chunk's tokens
-                                        KPipeline.join_timestamps(
-                                            result.tokens, result.pred_dur
-                                        )
-                                        logger.debug(
-                                            "Successfully joined timestamps for chunk"
-                                        )
-                                    except Exception as e:
-                                        logger.error(
-                                            f"Failed to join timestamps for chunk: {e}"
-                                        )
-                                        continue
-
-                                    # Convert tokens to timestamps
-                                    for token in result.tokens:
-                                        try:
-                                            # Skip tokens without required attributes
-                                            if not all(
-                                                hasattr(token, attr)
-                                                for attr in ["text", "start_ts", "end_ts"]
-                                            ):
-                                                logger.debug(
-                                                    f"Skipping token missing attributes: {dir(token)}"
-                                                )
-                                                continue
-
-                                            # Get and validate text
-                                            text = (
-                                                str(token.text).strip()
-                                                if token.text is not None
-                                                else ""
-                                            )
-                                            if not text:
-                                                logger.debug("Skipping empty token")
-                                                continue
-
-                                            # Get and validate timestamps
-                                            start_ts = getattr(token, "start_ts", None)
-                                            end_ts = getattr(token, "end_ts", None)
-                                            if start_ts is None or end_ts is None:
-                                                logger.debug(
-                                                    f"Skipping token with None timestamps: {text}"
-                                                )
-                                                continue
-
-                                            # Convert timestamps to float
-                                            try:
-                                                start_time = float(start_ts)
-                                                end_time = float(end_ts)
-                                            except (TypeError, ValueError):
-                                                logger.debug(
-                                                    f"Skipping token with invalid timestamps: {text}"
-                                                )
-                                                continue
-
-                                            # Add timestamp
-                                            word_timestamps.append(
-                                                {
-                                                    "word": text,
-                                                    "start_time": start_time,
-                                                    "end_time": end_time,
-                                                }
-                                            )
-                                            logger.debug(
-                                                f"Added timestamp for word '{text}': {start_time:.3f}s - {end_time:.3f}s"
-                                            )
-                                        except Exception as e:
-                                            logger.warning(f"Error processing token: {e}")
-                                            continue
-
-                except Exception as e:
-                    logger.error(f"Failed to process text with pipeline: {e}")
-                    raise RuntimeError(f"Pipeline processing failed: {e}")
+            # Process text in chunks
+            async for chunk_text, tokens in smart_split(text):
+                # Generate audio for chunk using model pool
+                async for chunk_audio in self.model_manager.generate(
+                    chunk_text,
+                    (voice_name, voice_path),
+                    speed=speed,
+                    lang_code=pipeline_lang_code,
+                ):
+                    chunks.append(chunk_audio)

             if not chunks:
                 raise ValueError("No audio chunks were generated successfully")
@@ -499,60 +295,11 @@ class TTSService:
             audio = np.concatenate(chunks) if len(chunks) > 1 else chunks[0]
             processing_time = time.time() - start_time

+            # Return with timestamps if requested
             if return_timestamps:
-                # Validate timestamps before returning
-                if not word_timestamps:
-                    logger.warning("No valid timestamps were generated")
-                else:
-                    # Sort timestamps by start time to ensure proper order
-                    word_timestamps.sort(key=lambda x: x["start_time"])
-                    # Validate timestamp sequence
-                    for i in range(1, len(word_timestamps)):
-                        prev = word_timestamps[i - 1]
-                        curr = word_timestamps[i]
-                        if curr["start_time"] < prev["end_time"]:
-                            logger.warning(
-                                f"Overlapping timestamps detected: '{prev['word']}' ({prev['start_time']:.3f}-{prev['end_time']:.3f}) and '{curr['word']}' ({curr['start_time']:.3f}-{curr['end_time']:.3f})"
-                            )
-
-                    logger.debug(
-                        f"Returning {len(word_timestamps)} word timestamps"
-                    )
-                    logger.debug(
-                        f"First timestamp: {word_timestamps[0]['word']} at {word_timestamps[0]['start_time']:.3f}s"
-                    )
-                    logger.debug(
-                        f"Last timestamp: {word_timestamps[-1]['word']} at {word_timestamps[-1]['end_time']:.3f}s"
-                    )
-
                 return audio, processing_time, word_timestamps
             return audio, processing_time

-            else:
-                # For legacy backends
-                async for chunk in self.generate_audio_stream(
-                    text,
-                    voice,
-                    speed,  # Default to WAV for raw audio
-                ):
-                    if chunk is not None:
-                        chunks.append(chunk)
-
-                if not chunks:
-                    raise ValueError("No audio chunks were generated successfully")
-
-                # Combine chunks
-                audio = np.concatenate(chunks) if len(chunks) > 1 else chunks[0]
-                processing_time = time.time() - start_time
-
-                if return_timestamps:
-                    return (
-                        audio,
-                        processing_time,
-                        [],
-                    )  # Empty timestamps for legacy backends
-                return audio, processing_time
-
         except Exception as e:
             logger.error(f"Error in audio generation: {str(e)}")
             raise
@@ -589,44 +336,32 @@ class TTSService:
         """
         start_time = time.time()
         try:
-            # Get backend and voice path
-            backend = self.model_manager.get_backend()
+            # Get voice path
             voice_name, voice_path = await self._get_voice_path(voice)

-            if isinstance(backend, KokoroV1):
-                # For Kokoro V1, use generate_from_tokens with raw phonemes
-                result = None
             # Use provided lang_code or determine from voice name
             pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
             logger.info(
                 f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme pipeline"
             )

-                try:
-                    # Use backend's pipeline management
-                    for r in backend._get_pipeline(
-                        pipeline_lang_code
-                    ).generate_from_tokens(
-                        tokens=phonemes,  # Pass raw phonemes string
-                        voice=voice_path,
-                        speed=speed,
-                    ):
-                        if r.audio is not None:
-                            result = r
-                            break
-                except Exception as e:
-                    logger.error(f"Failed to generate from phonemes: {e}")
-                    raise RuntimeError(f"Phoneme generation failed: {e}")
+            # Generate audio using model pool
+            chunks = []
+            async for chunk_audio in self.model_manager.generate(
+                phonemes,
+                (voice_name, voice_path),
+                speed=speed,
+                lang_code=pipeline_lang_code,
+            ):
+                chunks.append(chunk_audio)

-            if result is None or result.audio is None:
+            if not chunks:
                 raise ValueError("No audio generated")

-            processing_time = time.time() - start_time
-            return result.audio.numpy(), processing_time
-            else:
-                raise ValueError(
-                    "Phoneme generation only supported with Kokoro V1 backend"
-                )
+            # Combine chunks
+            audio = np.concatenate(chunks) if len(chunks) > 1 else chunks[0]
+            processing_time = time.time() - start_time
+            return audio, processing_time

         except Exception as e:
             logger.error(f"Error in phoneme audio generation: {str(e)}")
@@ -13,7 +13,7 @@ Accept: application/json
 ### Get Session Pool Status
 # Shows active ONNX sessions, CUDA stream usage, and session ages
 # Useful for debugging resource exhaustion issues
-GET http://localhost:8880/debug/session_pools
+GET http://localhost:8880/debug/model_pool
 Accept: application/json

 ### List Available Models
test_client/README.md (new file, 142 lines)
@@ -0,0 +1,142 @@

# Kokoro FastAPI Load Testing

This directory contains load testing scripts using Locust to test the Kokoro FastAPI server's performance under concurrent load.

## Docker Setup

The easiest way to run the tests is using Docker:

```bash
# Build the Docker image
docker build -t kokoro-locust .

# Run with web interface (default)
docker run -p 8089:8089 -e LOCUST_HOST=http://host.docker.internal:8880 kokoro-locust

# Run headless mode with specific parameters
docker run -e LOCUST_HOST=http://host.docker.internal:8880 \
    -e LOCUST_HEADLESS=true \
    -e LOCUST_USERS=10 \
    -e LOCUST_SPAWN_RATE=1 \
    -e LOCUST_RUN_TIME=5m \
    kokoro-locust
```

### Environment Variables

- `LOCUST_HOST`: Target server URL (default: http://localhost:8880)
- `LOCUST_USERS`: Number of users to simulate (default: 10)
- `LOCUST_SPAWN_RATE`: Users to spawn per second (default: 1)
- `LOCUST_RUN_TIME`: Test duration (default: 5m)
- `LOCUST_HEADLESS`: Run without web UI if true (default: false)

### Accessing Results

- Web UI: http://localhost:8089 when running in web mode
- HTML Report: Generated in headless mode, copy from container:
  ```bash
  docker cp <container_id>:/locust/report.html ./report.html
  ```

## Local Setup (Alternative)

If you prefer running without Docker:

1. Create a virtual environment and install requirements:
   ```bash
   python -m venv venv
   source venv/bin/activate  # On Windows: venv\Scripts\activate
   pip install -r requirements.txt
   ```

2. Make sure your Kokoro FastAPI server is running (default: http://localhost:8880)

3. Run Locust:
   ```bash
   # Web UI mode
   locust -f locustfile.py --host http://localhost:8880

   # Headless mode
   locust -f locustfile.py --host http://localhost:8880 --users 10 --spawn-rate 1 --run-time 5m --headless
   ```

## Test Scenarios

The load test includes:
1. TTS endpoint testing with short phrases
2. Model pool monitoring

## Testing Different Configurations

To test with different numbers of model instances:

1. Set the model instance count in your server environment:
   ```bash
   export PYTORCH_MAX_CONCURRENT_MODELS=2  # Adjust as needed
   ```

2. Restart your Kokoro FastAPI server

3. Run the load test with different user counts:
   ```bash
   # Example: Test with 20 users
   docker run -e LOCUST_HOST=http://host.docker.internal:8880 \
       -e LOCUST_HEADLESS=true \
       -e LOCUST_USERS=20 \
       -e LOCUST_SPAWN_RATE=2 \
       -e LOCUST_RUN_TIME=5m \
       kokoro-locust
   ```

## Example Test Matrix

Test your server with different configurations:

| Model Instances | Concurrent Users | Expected Load |
|-----------------|------------------|---------------|
| 1               | 5                | Light         |
| 2               | 10               | Medium        |
| 4               | 20               | Heavy         |

## Quick Test Script

Here's a quick script to test multiple configurations:

```bash
#!/bin/bash

# Array of test configurations
configs=(
    "1,5"   # 1 instance, 5 users
    "2,10"  # 2 instances, 10 users
    "4,20"  # 4 instances, 20 users
)

for config in "${configs[@]}"; do
    IFS=',' read -r instances users <<< "$config"

    echo "Testing with $instances instances and $users users..."

    # Set instance count on server (you'll need to implement this)
    # ssh server "export PYTORCH_MAX_CONCURRENT_MODELS=$instances && restart_server"

    # Run load test
    docker run -e LOCUST_HOST=http://host.docker.internal:8880 \
        -e LOCUST_HEADLESS=true \
        -e LOCUST_USERS=$users \
        -e LOCUST_SPAWN_RATE=1 \
        -e LOCUST_RUN_TIME=5m \
        kokoro-locust

    echo "Waiting 30s before next test..."
    sleep 30
done
```

## Tips

1. Start with low user counts and gradually increase
2. Monitor server resources during tests
3. Use the debug endpoint (/debug/model_pool) to monitor instance usage
4. Check server logs for any errors or bottlenecks
5. When using Docker, use `host.docker.internal` to access localhost
test_client/locustfile.py (new file, 190 lines)
@@ -0,0 +1,190 @@

from locust import HttpUser, task, between, events
import json
import time

class SystemStats:
    def __init__(self):
        self.queue_size = 0
        self.active_instances = 0
        self.gpu_memory_used = 0
        self.cpu_percent = 0
        self.memory_percent = 0
        self.error_count = 0
        self.last_error = None

system_stats = SystemStats()

@events.init.add_listener
def on_locust_init(environment, **_kwargs):
    @environment.web_ui.app.route("/system-stats")
    def system_stats_page():
        return {
            "queue_size": system_stats.queue_size,
            "active_instances": system_stats.active_instances,
            "gpu_memory_used": system_stats.gpu_memory_used,
            "cpu_percent": system_stats.cpu_percent,
            "memory_percent": system_stats.memory_percent,
            "error_count": system_stats.error_count,
            "last_error": system_stats.last_error
        }

class KokoroUser(HttpUser):
    wait_time = between(2, 3)  # Increased wait time to reduce load

    def on_start(self):
        """Initialize test data."""
        self.test_phrases = [
            "Hello, how are you today?",
            "The quick brown fox jumps over the lazy dog.",
            "Testing voice synthesis with a short phrase.",
            "I hope this works well!",
            "Just a quick test of the system."
        ]

        self.test_config = {
            "model": "kokoro",
            "voice": "af_nova",
            "response_format": "mp3",
            "speed": 1.0,
            "stream": False
        }

    @task(1)
    def test_tts_endpoint(self):
        """Test the TTS endpoint with short phrases."""
        import random
        test_text = random.choice(self.test_phrases)

        payload = {
            **self.test_config,
            "input": test_text
        }

        with self.client.post(
            "/v1/audio/speech",
            json=payload,
            catch_response=True,
            name="/v1/audio/speech (short text)"
        ) as response:
            try:
                if response.status_code == 200:
                    response.success()
                elif response.status_code == 429:  # Too Many Requests
                    response.failure("Rate limit exceeded")
                    system_stats.error_count += 1
                    system_stats.last_error = "Rate limit exceeded"
                elif response.status_code >= 500:
                    error_msg = f"Server error: {response.status_code}"
                    try:
                        error_data = response.json()
                        if 'detail' in error_data:
                            error_msg = f"Server error: {error_data['detail']}"
                    except:
                        pass
                    response.failure(error_msg)
                    system_stats.error_count += 1
                    system_stats.last_error = error_msg
                else:
                    response.failure(f"Unexpected status: {response.status_code}")
                    system_stats.error_count += 1
                    system_stats.last_error = f"Unexpected status: {response.status_code}"
            except Exception as e:
                error_msg = f"Request failed: {str(e)}"
                response.failure(error_msg)
                system_stats.error_count += 1
                system_stats.last_error = error_msg

    @task(1)  # Reduced monitoring frequency
    def monitor_system(self):
        """Monitor system metrics via debug endpoints."""
        # Get model pool stats
        with self.client.get(
            "/debug/model_pool",
            catch_response=True,
            name="Debug - Model Pool"
        ) as response:
            if response.status_code == 200:
                data = response.json()
                system_stats.queue_size = data.get("queue_size", 0)
                system_stats.active_instances = data.get("active_instances", 0)
                if "gpu_memory" in data:
                    system_stats.gpu_memory_used = data["gpu_memory"]["used_mb"]

                # Report metrics
                self.environment.events.request.fire(
                    request_type="METRIC",
                    name="Queue Size",
                    response_time=system_stats.queue_size,
                    response_length=0,
                    exception=None
                )
                self.environment.events.request.fire(
                    request_type="METRIC",
                    name="Active Instances",
                    response_time=system_stats.active_instances,
                    response_length=0,
                    exception=None
                )
                if "gpu_memory" in data:
                    self.environment.events.request.fire(
                        request_type="METRIC",
                        name="GPU Memory (MB)",
                        response_time=system_stats.gpu_memory_used,
                        response_length=0,
                        exception=None
                    )

        # Get system stats
        with self.client.get(
            "/debug/system",
            catch_response=True,
            name="Debug - System"
        ) as response:
            if response.status_code == 200:
                data = response.json()
                system_stats.cpu_percent = data.get("cpu", {}).get("cpu_percent", 0)
                system_stats.memory_percent = data.get("process", {}).get("memory_percent", 0)

                # Report metrics
                self.environment.events.request.fire(
                    request_type="METRIC",
                    name="CPU %",
                    response_time=system_stats.cpu_percent,
                    response_length=0,
                    exception=None
                )
                self.environment.events.request.fire(
                    request_type="METRIC",
                    name="Memory %",
                    response_time=system_stats.memory_percent,
                    response_length=0,
                    exception=None
                )

# Add custom charts
@events.init_command_line_parser.add_listener
def init_parser(parser):
    parser.add_argument(
        '--custom-stats',
        dest='custom_stats',
        action='store_true',
        help='Enable custom statistics in web UI'
    )

# Stats processor
def process_stats():
    stats = {
        "Queue Size": system_stats.queue_size,
        "Active Instances": system_stats.active_instances,
        "GPU Memory (MB)": system_stats.gpu_memory_used,
        "CPU %": system_stats.cpu_percent,
        "Memory %": system_stats.memory_percent,
        "Error Count": system_stats.error_count
    }
    return stats

@events.test_stop.add_listener
def on_test_stop(environment, **_kwargs):
    print("\nFinal System Stats:")
    for metric, value in process_stats().items():
        print(f"{metric}: {value}")
test_client/requirements.txt (new file, 2 lines)
@@ -0,0 +1,2 @@

locust==2.24.0
aiohttp==3.9.3
test_client/run_tests.sh (new executable file, 59 lines)
@@ -0,0 +1,59 @@

#!/bin/bash

# Build the Docker image if needed
if [[ "$(docker images -q kokoro-locust 2> /dev/null)" == "" ]]; then
    echo "Building Kokoro Locust image..."
    docker build -t kokoro-locust .
fi

# Array of test configurations: instances,users,spawn_rate,run_time
configs=(
    "1,5,1,3m"   # Light load: 1 instance, 5 users
    "2,10,2,3m"  # Medium load: 2 instances, 10 users
    "4,20,2,3m"  # Heavy load: 4 instances, 20 users
)

# Create results directory
mkdir -p test_results
timestamp=$(date +%Y%m%d_%H%M%S)
results_dir="test_results/run_${timestamp}"
mkdir -p "$results_dir"

# Run tests for each configuration
for config in "${configs[@]}"; do
    IFS=',' read -r instances users spawn_rate run_time <<< "$config"

    echo "----------------------------------------"
    echo "Testing with configuration:"
    echo "- Model instances: $instances"
    echo "- Concurrent users: $users"
    echo "- Spawn rate: $spawn_rate"
    echo "- Run time: $run_time"
    echo "----------------------------------------"

    # Export instance count for the server (if running locally)
    export PYTORCH_MAX_CONCURRENT_MODELS=$instances

    # Run load test
    docker run --rm \
        -e LOCUST_HOST=http://host.docker.internal:8880 \
        -e LOCUST_HEADLESS=true \
        -e LOCUST_USERS=$users \
        -e LOCUST_SPAWN_RATE=$spawn_rate \
        -e LOCUST_RUN_TIME=$run_time \
        --name kokoro-locust-test \
        kokoro-locust

    # Copy the report
    test_name="instances${instances}_users${users}"
    docker cp kokoro-locust-test:/locust/report.html "$results_dir/${test_name}_report.html"

    echo "Test complete. Report saved to $results_dir/${test_name}_report.html"
    echo "Waiting 30s before next test..."
    sleep 30
done

echo "----------------------------------------"
echo "All tests complete!"
echo "Results saved in: $results_dir"
echo "----------------------------------------"
test_client/start.sh (new file, 16 lines)
@@ -0,0 +1,16 @@

#!/bin/bash

# If LOCUST_HEADLESS is true, run in headless mode with specified parameters
if [ "$LOCUST_HEADLESS" = "true" ]; then
    locust -f locustfile.py \
        --host ${LOCUST_HOST} \
        --users ${LOCUST_USERS} \
        --spawn-rate ${LOCUST_SPAWN_RATE} \
        --run-time ${LOCUST_RUN_TIME} \
        --headless \
        --print-stats \
        --html report.html
else
    # Run with web interface
    locust -f locustfile.py --host ${LOCUST_HOST}
fi