mirror of https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00

WIP, Functional for CPU: Updated for ONNX runtime support, Dockerfile and TTS Service

parent f1131b4836
commit e4d8e74738

15 changed files with 946 additions and 313 deletions
Dockerfile
@@ -10,8 +10,9 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
     && apt-get clean \
     && rm -rf /var/lib/apt/lists/*

-# Install PyTorch CPU version
-RUN pip3 install --no-cache-dir torch==2.5.1 --extra-index-url https://download.pytorch.org/whl/cpu
+# Install PyTorch CPU version and ONNX runtime
+RUN pip3 install --no-cache-dir torch==2.5.1 --extra-index-url https://download.pytorch.org/whl/cpu && \
+    pip3 install --no-cache-dir onnxruntime==1.20.1

 # Install all other dependencies from requirements.txt
 COPY requirements.txt .
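Since the ONNX wheel is now pinned alongside the CPU-only torch build, a quick import check inside the built image confirms both landed correctly. A minimal sketch (the version expectations in the comments follow the pins above):

# Sanity check for the image: both pinned packages should import cleanly.
import torch
import onnxruntime as ort

print(torch.__version__)              # expected: 2.5.1 (CPU build)
print(ort.__version__)                # expected: 1.20.1
print(ort.get_available_providers())  # should include 'CPUExecutionProvider'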
api/src/main.py
@@ -10,7 +10,8 @@ from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware

 from .core.config import settings
-from .services.tts import TTSModel, TTSService
+from .services.tts_model import TTSModel
+from .services.tts_service import TTSService
 from .routers.openai_compatible import router as openai_router


@@ -20,7 +21,7 @@ async def lifespan(app: FastAPI):
     logger.info("Loading TTS model and voice packs...")

     # Initialize the main model with warm-up
-    model, voicepack_count = TTSModel.initialize()
+    voicepack_count = TTSModel.initialize()
     logger.info(f"Model loaded and warmed up on {TTSModel._device}")
     logger.info(f"{voicepack_count} voice packs loaded successfully")
     yield
api/src/routers/openai_compatible.py
@@ -3,7 +3,7 @@ from typing import List
 from loguru import logger
 from fastapi import Depends, Response, APIRouter, HTTPException

-from ..services.tts import TTSService
+from ..services.tts_service import TTSService
 from ..services.audio import AudioService
 from ..structures.schemas import OpenAISpeechRequest

@@ -15,9 +15,7 @@ router = APIRouter(

 def get_tts_service() -> TTSService:
     """Dependency to get TTSService instance with database session"""
-    return TTSService(
-        start_worker=False
-    )  # Don't start worker thread for OpenAI endpoint
+    return TTSService()  # Initialize TTSService with default settings


 @router.post("/audio/speech")
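With the simplified dependency, the route behaves as a plain OpenAI-style speech endpoint. A minimal client sketch; the host, port, and /v1 prefix are assumptions taken from the benchmark script later in this commit:

# Hypothetical client call against the OpenAI-compatible route registered above.
import requests

resp = requests.post(
    "http://localhost:8880/v1/audio/speech",
    json={"model": "kokoro", "input": "Hello world", "voice": "af", "response_format": "wav"},
)
resp.raise_for_status()
with open("hello.wav", "wb") as f:
    f.write(resp.content)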
api/src/services/__init__.py
@@ -1,3 +1,4 @@
-from .tts import TTSModel, TTSService
+from .tts_model import TTSModel
+from .tts_service import TTSService

 __all__ = ["TTSService", "TTSModel"]
api/src/services/tts.py (deleted, 286 lines)
@@ -1,286 +0,0 @@
import io
import os
import re
import time
import threading
from typing import List, Tuple, Optional

import numpy as np
import torch
import tiktoken
import scipy.io.wavfile as wavfile
from kokoro import generate, tokenize, phonemize, normalize_text
from loguru import logger
from models import build_model

from ..core.config import settings

enc = tiktoken.get_encoding("cl100k_base")


class TTSModel:
    _instance = None
    _device = None
    _lock = threading.Lock()

    # Directory for all voices (copied base voices, and any created combined voices)
    VOICES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "voices")

    @classmethod
    def initialize(cls):
        """Initialize and warm up the model"""
        with cls._lock:
            if cls._instance is None:
                # Initialize model
                cls._device = "cuda" if torch.cuda.is_available() else "cpu"
                logger.info(f"Initializing model on {cls._device}")
                model_path = os.path.join(settings.model_dir, settings.model_path)
                model = build_model(model_path, cls._device)
                cls._instance = model

                # Ensure voices directory exists
                os.makedirs(cls.VOICES_DIR, exist_ok=True)

                # Copy base voices to local directory
                base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
                if os.path.exists(base_voices_dir):
                    for file in os.listdir(base_voices_dir):
                        if file.endswith(".pt"):
                            voice_name = file[:-3]
                            voice_path = os.path.join(cls.VOICES_DIR, file)
                            if not os.path.exists(voice_path):
                                try:
                                    logger.info(
                                        f"Copying base voice {voice_name} to voices directory"
                                    )
                                    base_path = os.path.join(base_voices_dir, file)
                                    voicepack = torch.load(
                                        base_path,
                                        map_location=cls._device,
                                        weights_only=True,
                                    )
                                    torch.save(voicepack, voice_path)
                                except Exception as e:
                                    logger.error(
                                        f"Error copying voice {voice_name}: {str(e)}"
                                    )

                # Warm up with default voice
                try:
                    dummy_text = "Hello"
                    voice_path = os.path.join(cls.VOICES_DIR, "af.pt")
                    dummy_voicepack = torch.load(
                        voice_path, map_location=cls._device, weights_only=True
                    )
                    generate(model, dummy_text, dummy_voicepack, lang="a", speed=1.0)
                    logger.info("Model warm-up complete")
                except Exception as e:
                    logger.warning(f"Model warm-up failed: {e}")

            # Count voices in directory for validation
            voice_count = len(
                [f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")]
            )
            return cls._instance, voice_count

    @classmethod
    def get_instance(cls):
        """Get the initialized instance or raise an error"""
        if cls._instance is None:
            raise RuntimeError("Model not initialized. Call initialize() first.")
        return cls._instance, cls._device


class TTSService:
    def __init__(self, output_dir: str = None, start_worker: bool = False):
        self.output_dir = output_dir
        self._ensure_voices()
        if start_worker:
            self.start_worker()

    def _ensure_voices(self):
        """Copy base voices to local voices directory during initialization"""
        os.makedirs(TTSModel.VOICES_DIR, exist_ok=True)

        base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
        if os.path.exists(base_voices_dir):
            for file in os.listdir(base_voices_dir):
                if file.endswith(".pt"):
                    voice_name = file[:-3]
                    voice_path = os.path.join(TTSModel.VOICES_DIR, file)
                    if not os.path.exists(voice_path):
                        try:
                            logger.info(
                                f"Copying base voice {voice_name} to voices directory"
                            )
                            base_path = os.path.join(base_voices_dir, file)
                            voicepack = torch.load(
                                base_path,
                                map_location=TTSModel._device,
                                weights_only=True,
                            )
                            torch.save(voicepack, voice_path)
                        except Exception as e:
                            logger.error(f"Error copying voice {voice_name}: {str(e)}")

    def _split_text(self, text: str) -> List[str]:
        """Split text into sentences"""
        return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]

    def _get_voice_path(self, voice_name: str) -> Optional[str]:
        """Get the path to a voice file.

        Args:
            voice_name: Name of the voice to find

        Returns:
            Path to the voice file if found, None otherwise
        """
        voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice_name}.pt")
        return voice_path if os.path.exists(voice_path) else None

    def _generate_audio(
        self, text: str, voice: str, speed: float, stitch_long_output: bool = True
    ) -> Tuple[torch.Tensor, float]:
        """Generate audio and measure processing time"""
        start_time = time.time()

        try:
            # Normalize text once at the start
            text = normalize_text(text)
            if not text:
                raise ValueError("Text is empty after preprocessing")

            # Check voice exists
            voice_path = self._get_voice_path(voice)
            if not voice_path:
                raise ValueError(f"Voice not found: {voice}")

            # Load model and voice
            model = TTSModel._instance
            voicepack = torch.load(
                voice_path, map_location=TTSModel._device, weights_only=True
            )

            # Generate audio with or without stitching
            if stitch_long_output:
                chunks = self._split_text(text)
                audio_chunks = []

                # Process all chunks with same model/voicepack instance
                for i, chunk in enumerate(chunks):
                    try:
                        # Validate phonemization first
                        # ps = phonemize(chunk, voice[0])
                        # tokens = tokenize(ps)
                        # logger.debug(
                        #     f"Processing chunk {i + 1}/{len(chunks)}: {len(tokens)} tokens"
                        # )

                        # Only proceed if phonemization succeeded
                        chunk_audio, _ = generate(
                            model, chunk, voicepack, lang=voice[0], speed=speed
                        )
                        if chunk_audio is not None:
                            audio_chunks.append(chunk_audio)
                        else:
                            logger.error(
                                f"No audio generated for chunk {i + 1}/{len(chunks)}"
                            )
                    except Exception as e:
                        logger.error(
                            f"Failed to generate audio for chunk {i + 1}/{len(chunks)}: '{chunk}'. Error: {str(e)}"
                        )
                        continue

                if not audio_chunks:
                    raise ValueError("No audio chunks were generated successfully")

                audio = (
                    np.concatenate(audio_chunks)
                    if len(audio_chunks) > 1
                    else audio_chunks[0]
                )
            else:
                audio, _ = generate(model, text, voicepack, lang=voice[0], speed=speed)

            processing_time = time.time() - start_time
            return audio, processing_time

        except Exception as e:
            print(f"Error in audio generation: {str(e)}")
            raise

    def _save_audio(self, audio: torch.Tensor, filepath: str):
        """Save audio to file"""
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        wavfile.write(filepath, 24000, audio)

    def _audio_to_bytes(self, audio: torch.Tensor) -> bytes:
        """Convert audio tensor to WAV bytes"""
        buffer = io.BytesIO()
        wavfile.write(buffer, 24000, audio)
        return buffer.getvalue()

    def combine_voices(self, voices: List[str]) -> str:
        """Combine multiple voices into a new voice.

        Args:
            voices: List of voice names to combine

        Returns:
            Name of the combined voice

        Raises:
            ValueError: If less than 2 voices provided or voice loading fails
            RuntimeError: If voice combination or saving fails
        """
        if len(voices) < 2:
            raise ValueError("At least 2 voices are required for combination")

        # Load voices
        t_voices: List[torch.Tensor] = []
        v_name: List[str] = []

        for voice in voices:
            try:
                voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice}.pt")
                voicepack = torch.load(
                    voice_path, map_location=TTSModel._device, weights_only=True
                )
                t_voices.append(voicepack)
                v_name.append(voice)
            except Exception as e:
                raise ValueError(f"Failed to load voice {voice}: {str(e)}")

        # Combine voices
        try:
            f: str = "_".join(v_name)
            v = torch.mean(torch.stack(t_voices), dim=0)
            combined_path = os.path.join(TTSModel.VOICES_DIR, f"{f}.pt")

            # Save combined voice
            try:
                torch.save(v, combined_path)
            except Exception as e:
                raise RuntimeError(
                    f"Failed to save combined voice to {combined_path}: {str(e)}"
                )

            return f

        except Exception as e:
            if not isinstance(e, (ValueError, RuntimeError)):
                raise RuntimeError(f"Error combining voices: {str(e)}")
            raise

    def list_voices(self) -> List[str]:
        """List all available voices"""
        voices = []
        try:
            for file in os.listdir(TTSModel.VOICES_DIR):
                if file.endswith(".pt"):
                    voices.append(file[:-3])  # Remove .pt extension
        except Exception as e:
            logger.error(f"Error listing voices: {str(e)}")
        return sorted(voices)
api/src/services/tts_cpu.py (new file, 65 lines)
@@ -0,0 +1,65 @@
import os
import numpy as np
import torch
from onnxruntime import InferenceSession, SessionOptions, GraphOptimizationLevel, ExecutionMode
from loguru import logger


class TTSCPUModel:
    _instance = None
    _onnx_session = None

    @classmethod
    def initialize(cls, model_dir: str):
        """Initialize ONNX model for CPU inference"""
        if cls._onnx_session is None:
            # Try loading ONNX model
            onnx_path = os.path.join(model_dir, "kokoro-v0_19.onnx")
            if not os.path.exists(onnx_path):
                return None

            logger.info(f"Loading ONNX model from {onnx_path}")

            # Configure ONNX session for optimal performance
            session_options = SessionOptions()
            session_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
            session_options.intra_op_num_threads = 4  # Adjust based on CPU cores
            session_options.execution_mode = ExecutionMode.ORT_SEQUENTIAL

            # Configure CPU provider options
            provider_options = {
                'CPUExecutionProvider': {
                    'arena_extend_strategy': 'kNextPowerOfTwo',
                    'cpu_memory_arena_cfg': 'cpu:0'
                }
            }

            cls._onnx_session = InferenceSession(
                onnx_path,
                sess_options=session_options,
                providers=['CPUExecutionProvider'],
                provider_options=[provider_options]
            )

            return cls._onnx_session
        return cls._onnx_session

    @classmethod
    def generate(cls, tokens: list, voicepack: torch.Tensor, speed: float) -> np.ndarray:
        """Generate audio using ONNX model"""
        if cls._onnx_session is None:
            raise RuntimeError("ONNX model not initialized")

        # Pre-allocate and prepare inputs
        tokens_input = np.array([tokens], dtype=np.int64)
        style_input = voicepack[len(tokens)-2].numpy()  # Already has correct dimensions
        speed_input = np.full(1, speed, dtype=np.float32)  # More efficient than ones * speed

        # Run inference with optimized inputs
        return cls._onnx_session.run(
            None,
            {
                'tokens': tokens_input,
                'style': style_input,
                'speed': speed_input
            }
        )[0]
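Note the style lookup: the voicepack tensor is indexed by sequence length, so each input length selects its own style vector. A standalone sketch of the same inference call, assuming a local kokoro-v0_19.onnx and an af.pt voicepack are on disk (the token ids are hypothetical):

# Sketch of the ONNX inference path used by TTSCPUModel.generate above.
import numpy as np
import torch
from onnxruntime import InferenceSession

session = InferenceSession("kokoro-v0_19.onnx", providers=["CPUExecutionProvider"])
voicepack = torch.load("af.pt", map_location="cpu", weights_only=True)

tokens = [0, 12, 34, 56, 0]  # hypothetical padded token ids
audio = session.run(None, {
    "tokens": np.array([tokens], dtype=np.int64),  # batch of one
    "style": voicepack[len(tokens) - 2].numpy(),   # style vector keyed by length
    "speed": np.full(1, 1.0, dtype=np.float32),
})[0]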
api/src/services/tts_gpu.py (new file, 32 lines)
@@ -0,0 +1,32 @@
import os
import torch
from loguru import logger
from models import build_model
from kokoro import generate


class TTSGPUModel:
    _instance = None
    _device = "cuda"

    @classmethod
    def initialize(cls, model_dir: str, model_path: str):
        """Initialize PyTorch model for GPU inference"""
        if cls._instance is None and torch.cuda.is_available():
            try:
                logger.info("Initializing GPU model")
                model_path = os.path.join(model_dir, model_path)
                model = build_model(model_path, cls._device)
                cls._instance = model
                return cls._instance
            except Exception as e:
                logger.error(f"Failed to initialize GPU model: {e}")
                return None
        return cls._instance

    @classmethod
    def generate(cls, text: str, voicepack: torch.Tensor, lang: str, speed: float) -> tuple[torch.Tensor, dict]:
        """Generate audio using PyTorch model on GPU"""
        if cls._instance is None:
            raise RuntimeError("GPU model not initialized")

        return generate(cls._instance, text, voicepack, lang=lang, speed=speed)
api/src/services/tts_model.py (new file, 94 lines)
@@ -0,0 +1,94 @@
import os
import threading
import torch
from loguru import logger
from kokoro import tokenize, phonemize

from ..core.config import settings
from .tts_cpu import TTSCPUModel
from .tts_gpu import TTSGPUModel


class TTSModel:
    _device = None
    _lock = threading.Lock()
    VOICES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "voices")

    @classmethod
    def initialize(cls):
        """Initialize and warm up the model"""
        with cls._lock:
            # Set device and initialize model
            cls._device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Initializing model on {cls._device}")

            # Initialize appropriate model based on device
            if cls._device == "cuda":
                if not TTSGPUModel.initialize(settings.model_dir, settings.model_path):
                    raise RuntimeError("Failed to initialize GPU model")
            else:
                # Try CPU ONNX first, fallback to CPU PyTorch if needed
                if not TTSCPUModel.initialize(settings.model_dir):
                    logger.warning("ONNX initialization failed, falling back to PyTorch CPU")
                    if not TTSGPUModel.initialize(settings.model_dir, settings.model_path):
                        raise RuntimeError("Failed to initialize CPU model")

            # Setup voices directory
            os.makedirs(cls.VOICES_DIR, exist_ok=True)

            # Copy base voices to local directory
            base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
            if os.path.exists(base_voices_dir):
                for file in os.listdir(base_voices_dir):
                    if file.endswith(".pt"):
                        voice_name = file[:-3]
                        voice_path = os.path.join(cls.VOICES_DIR, file)
                        if not os.path.exists(voice_path):
                            try:
                                logger.info(
                                    f"Copying base voice {voice_name} to voices directory"
                                )
                                base_path = os.path.join(base_voices_dir, file)
                                voicepack = torch.load(
                                    base_path,
                                    map_location=cls._device,
                                    weights_only=True,
                                )
                                torch.save(voicepack, voice_path)
                            except Exception as e:
                                logger.error(
                                    f"Error copying voice {voice_name}: {str(e)}"
                                )

            # Warm up with default voice
            try:
                dummy_text = "Hello"
                voice_path = os.path.join(cls.VOICES_DIR, "af.pt")
                dummy_voicepack = torch.load(
                    voice_path, map_location=cls._device, weights_only=True
                )

                if cls._device == "cuda":
                    TTSGPUModel.generate(dummy_text, dummy_voicepack, "a", 1.0)
                else:
                    ps = phonemize(dummy_text, "a")
                    tokens = tokenize(ps)
                    tokens = [0] + tokens + [0]
                    TTSCPUModel.generate(tokens, dummy_voicepack, 1.0)

                logger.info("Model warm-up complete")
            except Exception as e:
                logger.warning(f"Model warm-up failed: {e}")

            # Count voices in directory
            voice_count = len(
                [f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")]
            )
            return voice_count

    @classmethod
    def get_device(cls):
        """Get the current device or raise an error"""
        if cls._device is None:
            raise RuntimeError("Model not initialized. Call initialize() first.")
        return cls._device
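The facade keeps call sites device-agnostic: callers only touch initialize() and get_device(), as the main.py lifespan hook above does. A minimal usage sketch (settings values come from ..core.config and are not shown in this diff):

# Sketch: how a caller drives the device-dispatching facade.
from api.src.services.tts_model import TTSModel

voice_count = TTSModel.initialize()  # picks CUDA, else ONNX CPU, else PyTorch CPU
print(f"{voice_count} voice packs ready on {TTSModel.get_device()}")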
api/src/services/tts_service.py (new file, 168 lines)
@@ -0,0 +1,168 @@
import io
import os
import re
import time
from typing import List, Tuple, Optional

import numpy as np
import torch
import scipy.io.wavfile as wavfile
from kokoro import tokenize, phonemize, normalize_text
from loguru import logger

from ..core.config import settings
from .tts_model import TTSModel
from .tts_cpu import TTSCPUModel
from .tts_gpu import TTSGPUModel


class TTSService:
    def __init__(self, output_dir: str = None):
        self.output_dir = output_dir

    def _split_text(self, text: str) -> List[str]:
        """Split text into sentences"""
        return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]

    def _get_voice_path(self, voice_name: str) -> Optional[str]:
        """Get the path to a voice file"""
        voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice_name}.pt")
        return voice_path if os.path.exists(voice_path) else None

    def _generate_audio(
        self, text: str, voice: str, speed: float, stitch_long_output: bool = True
    ) -> Tuple[torch.Tensor, float]:
        """Generate audio and measure processing time"""
        start_time = time.time()

        try:
            # Normalize text once at the start
            text = normalize_text(text)
            if not text:
                raise ValueError("Text is empty after preprocessing")

            # Check voice exists
            voice_path = self._get_voice_path(voice)
            if not voice_path:
                raise ValueError(f"Voice not found: {voice}")

            # Load voice
            voicepack = torch.load(
                voice_path, map_location=TTSModel.get_device(), weights_only=True
            )

            # Generate audio with or without stitching
            if stitch_long_output:
                chunks = self._split_text(text)
                audio_chunks = []

                # Process all chunks
                for i, chunk in enumerate(chunks):
                    try:
                        # Process chunk
                        if TTSModel.get_device() == "cuda":
                            chunk_audio, _ = TTSGPUModel.generate(chunk, voicepack, voice[0], speed)
                        else:
                            ps = phonemize(chunk, voice[0])
                            tokens = tokenize(ps)
                            tokens = [0] + tokens + [0]  # Add padding
                            chunk_audio = TTSCPUModel.generate(tokens, voicepack, speed)

                        if chunk_audio is not None:
                            audio_chunks.append(chunk_audio)
                        else:
                            logger.error(f"No audio generated for chunk {i + 1}/{len(chunks)}")

                    except Exception as e:
                        logger.error(
                            f"Failed to generate audio for chunk {i + 1}/{len(chunks)}: '{chunk}'. Error: {str(e)}"
                        )
                        continue

                if not audio_chunks:
                    raise ValueError("No audio chunks were generated successfully")

                audio = (
                    np.concatenate(audio_chunks)
                    if len(audio_chunks) > 1
                    else audio_chunks[0]
                )
            else:
                # Process single chunk
                if TTSModel.get_device() == "cuda":
                    audio, _ = TTSGPUModel.generate(text, voicepack, voice[0], speed)
                else:
                    ps = phonemize(text, voice[0])
                    tokens = tokenize(ps)
                    tokens = [0] + tokens + [0]  # Add padding
                    audio = TTSCPUModel.generate(tokens, voicepack, speed)

            processing_time = time.time() - start_time
            return audio, processing_time

        except Exception as e:
            logger.error(f"Error in audio generation: {str(e)}")
            raise

    def _save_audio(self, audio: torch.Tensor, filepath: str):
        """Save audio to file"""
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        wavfile.write(filepath, 24000, audio)

    def _audio_to_bytes(self, audio: torch.Tensor) -> bytes:
        """Convert audio tensor to WAV bytes"""
        buffer = io.BytesIO()
        wavfile.write(buffer, 24000, audio)
        return buffer.getvalue()

    def combine_voices(self, voices: List[str]) -> str:
        """Combine multiple voices into a new voice"""
        if len(voices) < 2:
            raise ValueError("At least 2 voices are required for combination")

        # Load voices
        t_voices: List[torch.Tensor] = []
        v_name: List[str] = []

        for voice in voices:
            try:
                voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice}.pt")
                voicepack = torch.load(
                    voice_path, map_location=TTSModel.get_device(), weights_only=True
                )
                t_voices.append(voicepack)
                v_name.append(voice)
            except Exception as e:
                raise ValueError(f"Failed to load voice {voice}: {str(e)}")

        # Combine voices
        try:
            f: str = "_".join(v_name)
            v = torch.mean(torch.stack(t_voices), dim=0)
            combined_path = os.path.join(TTSModel.VOICES_DIR, f"{f}.pt")

            # Save combined voice
            try:
                torch.save(v, combined_path)
            except Exception as e:
                raise RuntimeError(
                    f"Failed to save combined voice to {combined_path}: {str(e)}"
                )

            return f

        except Exception as e:
            if not isinstance(e, (ValueError, RuntimeError)):
                raise RuntimeError(f"Error combining voices: {str(e)}")
            raise

    def list_voices(self) -> List[str]:
        """List all available voices"""
        voices = []
        try:
            for file in os.listdir(TTSModel.VOICES_DIR):
                if file.endswith(".pt"):
                    voices.append(file[:-3])  # Remove .pt extension
        except Exception as e:
            logger.error(f"Error listing voices: {str(e)}")
        return sorted(voices)
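combine_voices is a plain element-wise average of voicepack tensors, with the result addressable by the joined name. A toy sketch of the same arithmetic on synthetic tensors (the shapes here are illustrative only, not the real voicepack dimensions):

# Toy illustration of the voice-combination arithmetic in combine_voices.
import torch

af = torch.randn(510, 1, 256)  # illustrative shape, stand-in for a loaded voicepack
bm = torch.randn(510, 1, 256)
combined = torch.mean(torch.stack([af, bm]), dim=0)  # element-wise average
torch.save(combined, "af_bm.pt")  # saved under the joined name, here "af_bm"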
@@ -7,7 +7,8 @@ import numpy as np
 import torch
 import pytest

-from api.src.services.tts import TTSModel, TTSService
+from api.src.services.tts_model import TTSModel
+from api.src.services.tts_service import TTSService


 @pytest.fixture
@@ -68,12 +69,12 @@ def test_list_voices(mock_join, mock_listdir, tts_service):
     assert "not_a_voice" not in voices


-@patch("api.src.services.tts.TTSModel.get_instance")
-@patch("api.src.services.tts.TTSModel.get_voicepack")
-@patch("api.src.services.tts.normalize_text")
-@patch("api.src.services.tts.phonemize")
-@patch("api.src.services.tts.tokenize")
-@patch("api.src.services.tts.generate")
+@patch("api.src.services.tts_model.TTSModel.get_instance")
+@patch("api.src.services.tts_model.TTSModel.get_voicepack")
+@patch("kokoro.normalize_text")
+@patch("kokoro.phonemize")
+@patch("kokoro.tokenize")
+@patch("kokoro.generate")
 def test_generate_audio_empty_text(
     mock_generate,
     mock_tokenize,
@@ -90,12 +91,12 @@ def test_generate_audio_empty_text(
     tts_service._generate_audio("", "af", 1.0)


-@patch("api.src.services.tts.TTSModel.get_instance")
+@patch("api.src.services.tts_model.TTSModel.get_instance")
 @patch("os.path.exists")
-@patch("api.src.services.tts.normalize_text")
-@patch("api.src.services.tts.phonemize")
-@patch("api.src.services.tts.tokenize")
-@patch("api.src.services.tts.generate")
+@patch("kokoro.normalize_text")
+@patch("kokoro.phonemize")
+@patch("kokoro.tokenize")
+@patch("kokoro.generate")
 @patch("torch.load")
 def test_generate_audio_no_chunks(
     mock_torch_load,
@@ -225,8 +226,8 @@ def test_generate_audio_success(
     assert len(audio) > 0


-@patch("api.src.services.tts.torch.cuda.is_available")
-@patch("api.src.services.tts.build_model")
+@patch("torch.cuda.is_available")
+@patch("models.build_model")
 def test_model_initialization_cuda(mock_build_model, mock_cuda_available):
     """Test model initialization with CUDA"""
     mock_cuda_available.return_value = True
@@ -257,8 +258,8 @@ def test_model_initialization_cpu(mock_build_model, mock_cuda_available):
     mock_build_model.assert_called_once()


-@patch("api.src.services.tts.TTSService._get_voice_path")
-@patch("api.src.services.tts.TTSModel.get_instance")
+@patch("api.src.services.tts_service.TTSService._get_voice_path")
+@patch("api.src.services.tts_model.TTSModel.get_instance")
 def test_voicepack_loading_error(mock_get_instance, mock_get_voice_path):
     """Test voicepack loading error handling"""
     mock_get_voice_path.return_value = None
@@ -271,7 +272,7 @@ def test_voicepack_loading_error(mock_get_instance, mock_get_voice_path):
     service._generate_audio("test", "nonexistent_voice", 1.0)


-@patch("api.src.services.tts.TTSModel")
+@patch("api.src.services.tts_model.TTSModel")
 def test_save_audio(mock_tts_model, tts_service, sample_audio, tmp_path):
     """Test saving audio to file"""
     output_dir = os.path.join(tmp_path, "test_output")
@@ -284,7 +285,7 @@ def test_save_audio(mock_tts_model, tts_service, sample_audio, tmp_path):
     assert os.path.getsize(output_path) > 0


-@patch("api.src.services.tts.TTSModel.get_instance")
+@patch("api.src.services.tts_model.TTSModel.get_instance")
 @patch("os.path.exists")
-@patch("api.src.services.tts.normalize_text")
-@patch("api.src.services.tts.generate")
examples/benchmarks/benchmark_results_cpu.json (new file, 216 lines)
@@ -0,0 +1,216 @@
{
  "results": [
    {"tokens": 100, "processing_time": 14.349808931350708, "output_length": 31.15, "rtf": 0.46, "elapsed_time": 14.716031074523926},
    {"tokens": 200, "processing_time": 28.341803312301636, "output_length": 62.6, "rtf": 0.45, "elapsed_time": 43.44207406044006},
    {"tokens": 300, "processing_time": 43.352553606033325, "output_length": 96.325, "rtf": 0.45, "elapsed_time": 87.26906609535217},
    {"tokens": 400, "processing_time": 71.02449822425842, "output_length": 128.575, "rtf": 0.55, "elapsed_time": 158.7198133468628},
    {"tokens": 500, "processing_time": 70.92521691322327, "output_length": 158.575, "rtf": 0.45, "elapsed_time": 230.01379895210266},
    {"tokens": 600, "processing_time": 83.6328592300415, "output_length": 189.25, "rtf": 0.44, "elapsed_time": 314.02610969543457},
    {"tokens": 700, "processing_time": 103.0810194015503, "output_length": 222.075, "rtf": 0.46, "elapsed_time": 417.5678551197052},
    {"tokens": 800, "processing_time": 127.02162909507751, "output_length": 253.85, "rtf": 0.5, "elapsed_time": 545.0128681659698},
    {"tokens": 900, "processing_time": 130.49781227111816, "output_length": 283.775, "rtf": 0.46, "elapsed_time": 675.8943417072296},
    {"tokens": 1000, "processing_time": 154.76425909996033, "output_length": 315.475, "rtf": 0.49, "elapsed_time": 831.0677945613861}
  ],
  "system_metrics": [
    {"timestamp": "2025-01-03T00:23:52.896889", "cpu_percent": 4.5, "ram_percent": 39.1, "ram_used_gb": 24.86032485961914, "gpu_memory_used": 1281.0},
    {"timestamp": "2025-01-03T00:24:07.429461", "cpu_percent": 4.5, "ram_percent": 39.1, "ram_used_gb": 24.847564697265625, "gpu_memory_used": 1285.0},
    {"timestamp": "2025-01-03T00:24:07.620587", "cpu_percent": 2.7, "ram_percent": 39.1, "ram_used_gb": 24.846607208251953, "gpu_memory_used": 1275.0},
    {"timestamp": "2025-01-03T00:24:36.140754", "cpu_percent": 5.4, "ram_percent": 39.1, "ram_used_gb": 24.857810974121094, "gpu_memory_used": 1267.0},
    {"timestamp": "2025-01-03T00:24:36.340675", "cpu_percent": 6.2, "ram_percent": 39.1, "ram_used_gb": 24.85773468017578, "gpu_memory_used": 1267.0},
    {"timestamp": "2025-01-03T00:25:19.905634", "cpu_percent": 29.1, "ram_percent": 39.2, "ram_used_gb": 24.920318603515625, "gpu_memory_used": 1256.0},
    {"timestamp": "2025-01-03T00:25:20.182219", "cpu_percent": 20.0, "ram_percent": 39.2, "ram_used_gb": 24.930198669433594, "gpu_memory_used": 1256.0},
    {"timestamp": "2025-01-03T00:26:31.414760", "cpu_percent": 5.3, "ram_percent": 39.5, "ram_used_gb": 25.127891540527344, "gpu_memory_used": 1259.0},
    {"timestamp": "2025-01-03T00:26:31.617256", "cpu_percent": 3.6, "ram_percent": 39.5, "ram_used_gb": 25.126346588134766, "gpu_memory_used": 1252.0},
    {"timestamp": "2025-01-03T00:27:42.736097", "cpu_percent": 10.5, "ram_percent": 39.5, "ram_used_gb": 25.100231170654297, "gpu_memory_used": 1249.0},
    {"timestamp": "2025-01-03T00:27:42.912870", "cpu_percent": 5.3, "ram_percent": 39.5, "ram_used_gb": 25.098285675048828, "gpu_memory_used": 1249.0},
    {"timestamp": "2025-01-03T00:29:06.725264", "cpu_percent": 8.9, "ram_percent": 39.5, "ram_used_gb": 25.123123168945312, "gpu_memory_used": 1239.0},
    {"timestamp": "2025-01-03T00:29:06.928826", "cpu_percent": 5.5, "ram_percent": 39.5, "ram_used_gb": 25.128646850585938, "gpu_memory_used": 1239.0},
    {"timestamp": "2025-01-03T00:30:50.206349", "cpu_percent": 49.6, "ram_percent": 39.6, "ram_used_gb": 25.162948608398438, "gpu_memory_used": 1245.0},
    {"timestamp": "2025-01-03T00:30:50.491837", "cpu_percent": 14.8, "ram_percent": 39.5, "ram_used_gb": 25.13379669189453, "gpu_memory_used": 1245.0},
    {"timestamp": "2025-01-03T00:32:57.721467", "cpu_percent": 6.2, "ram_percent": 39.6, "ram_used_gb": 25.187721252441406, "gpu_memory_used": 1384.0},
    {"timestamp": "2025-01-03T00:32:57.913350", "cpu_percent": 3.6, "ram_percent": 39.6, "ram_used_gb": 25.199390411376953, "gpu_memory_used": 1384.0},
    {"timestamp": "2025-01-03T00:35:08.608730", "cpu_percent": 6.3, "ram_percent": 39.8, "ram_used_gb": 25.311710357666016, "gpu_memory_used": 1330.0},
    {"timestamp": "2025-01-03T00:35:08.791851", "cpu_percent": 5.3, "ram_percent": 39.8, "ram_used_gb": 25.326683044433594, "gpu_memory_used": 1333.0},
    {"timestamp": "2025-01-03T00:37:43.782406", "cpu_percent": 6.8, "ram_percent": 40.6, "ram_used_gb": 25.803058624267578, "gpu_memory_used": 1409.0}
  ]
}
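The stats file below is derived from this JSON; a few lines of pandas reproduce the headline numbers from the committed file:

# Recompute the headline stats from the committed benchmark JSON.
import json
import pandas as pd

with open("examples/benchmarks/benchmark_results_cpu.json") as f:
    df = pd.DataFrame(json.load(f)["results"])

print(df["tokens"].sum())          # 5500 tokens processed
print(round(df["rtf"].mean(), 2))  # 0.47 average RTF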
examples/benchmarks/benchmark_stats_cpu.txt (new file, 19 lines)
@@ -0,0 +1,19 @@
=== Benchmark Statistics (with correct RTF) ===

Overall Stats:
Total tokens processed: 5500
Total audio generated: 1741.65s
Total test duration: 831.07s
Average processing rate: 6.72 tokens/second
Average RTF: 0.47x

Per-chunk Stats:
Average chunk size: 550.00 tokens
Min chunk size: 100.00 tokens
Max chunk size: 1000.00 tokens
Average processing time: 82.70s
Average output length: 174.17s

Performance Ranges:
Processing rate range: 5.63 - 7.17 tokens/second
RTF range: 0.44x - 0.55x
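Read RTF as processing time divided by audio duration, so anything under 1.0x is faster than real time. The averages above are self-consistent: 82.70 s of processing per 174.17 s of audio gives 82.70 / 174.17 ≈ 0.47, matching the reported average RTF.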
examples/benchmarks/benchmark_tts_rtf.py (new file, 323 lines)
@@ -0,0 +1,323 @@
import os
import json
import time
import subprocess
from datetime import datetime

import pandas as pd
import psutil
import seaborn as sns
import requests
import tiktoken
import scipy.io.wavfile as wavfile
import matplotlib.pyplot as plt

enc = tiktoken.get_encoding("cl100k_base")


def setup_plot(fig, ax, title):
    """Configure plot styling"""
    ax.grid(True, linestyle="--", alpha=0.3, color="#ffffff")
    ax.set_title(title, pad=20, fontsize=16, fontweight="bold", color="#ffffff")
    ax.set_xlabel(ax.get_xlabel(), fontsize=14, fontweight="medium", color="#ffffff")
    ax.set_ylabel(ax.get_ylabel(), fontsize=14, fontweight="medium", color="#ffffff")
    ax.tick_params(labelsize=12, colors="#ffffff")

    for spine in ax.spines.values():
        spine.set_color("#ffffff")
        spine.set_alpha(0.3)
        spine.set_linewidth(0.5)

    ax.set_facecolor("#1a1a2e")
    fig.patch.set_facecolor("#1a1a2e")
    return fig, ax


def get_text_for_tokens(text: str, num_tokens: int) -> str:
    """Get a slice of text that contains exactly num_tokens tokens"""
    tokens = enc.encode(text)
    if num_tokens > len(tokens):
        return text
    return enc.decode(tokens[:num_tokens])


def get_audio_length(audio_data: bytes) -> float:
    """Get audio length in seconds from bytes data"""
    temp_path = "examples/benchmarks/output/temp.wav"
    os.makedirs(os.path.dirname(temp_path), exist_ok=True)
    with open(temp_path, "wb") as f:
        f.write(audio_data)

    try:
        rate, data = wavfile.read(temp_path)
        return len(data) / rate
    finally:
        if os.path.exists(temp_path):
            os.remove(temp_path)


def get_gpu_memory():
    """Get GPU memory usage using nvidia-smi"""
    try:
        result = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader"]
        )
        return float(result.decode("utf-8").strip())
    except (subprocess.CalledProcessError, FileNotFoundError):
        return None


def get_system_metrics():
    """Get current system metrics"""
    # Take multiple CPU measurements over a short period
    samples = []
    for _ in range(3):  # Take 3 samples
        # Get both overall and per-CPU percentages
        overall_cpu = psutil.cpu_percent(interval=0.1)
        per_cpu = psutil.cpu_percent(percpu=True)
        avg_per_cpu = sum(per_cpu) / len(per_cpu)
        # Use the maximum of overall and average per-CPU
        samples.append(max(overall_cpu, avg_per_cpu))

    # Use the maximum CPU usage from all samples
    cpu_usage = round(max(samples), 2)

    metrics = {
        "timestamp": datetime.now().isoformat(),
        "cpu_percent": cpu_usage,
        "ram_percent": psutil.virtual_memory().percent,
        "ram_used_gb": psutil.virtual_memory().used / (1024**3),
    }

    gpu_mem = get_gpu_memory()
    if gpu_mem is not None:
        metrics["gpu_memory_used"] = gpu_mem

    return metrics


def real_time_factor(processing_time: float, audio_length: float, decimals: int = 2) -> float:
    """Calculate Real-Time Factor (RTF) as processing-time / length-of-audio"""
    rtf = processing_time / audio_length
    return round(rtf, decimals)


def make_tts_request(text: str, timeout: int = 1800) -> tuple[float, float]:
    """Make TTS request using OpenAI-compatible endpoint and return processing time and output length"""
    try:
        start_time = time.time()
        response = requests.post(
            "http://localhost:8880/v1/audio/speech",
            json={
                "model": "kokoro",
                "input": text,
                "voice": "af",
                "response_format": "wav",
            },
            timeout=timeout,
        )
        response.raise_for_status()

        processing_time = round(time.time() - start_time, 2)
        audio_length = round(get_audio_length(response.content), 2)

        # Save the audio file
        token_count = len(enc.encode(text))
        output_file = f"examples/benchmarks/output/chunk_{token_count}_tokens.wav"
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        with open(output_file, "wb") as f:
            f.write(response.content)
        print(f"Saved audio to {output_file}")

        return processing_time, audio_length

    except requests.exceptions.RequestException as e:
        print(f"Error making request for text: {text[:50]}... Error: {str(e)}")
        return None, None
    except Exception as e:
        print(f"Error processing text: {text[:50]}... Error: {str(e)}")
        return None, None


def plot_system_metrics(metrics_data):
    """Create plots for system metrics over time"""
    df = pd.DataFrame(metrics_data)
    df["timestamp"] = pd.to_datetime(df["timestamp"])
    elapsed_time = (df["timestamp"] - df["timestamp"].iloc[0]).dt.total_seconds()

    baseline_cpu = df["cpu_percent"].iloc[0]
    baseline_ram = df["ram_used_gb"].iloc[0]
    baseline_gpu = df["gpu_memory_used"].iloc[0] / 1024 if "gpu_memory_used" in df.columns else None

    if "gpu_memory_used" in df.columns:
        df["gpu_memory_gb"] = df["gpu_memory_used"] / 1024

    plt.style.use("dark_background")

    has_gpu = "gpu_memory_used" in df.columns
    num_plots = 3 if has_gpu else 2
    fig, axes = plt.subplots(num_plots, 1, figsize=(15, 5 * num_plots))
    fig.patch.set_facecolor("#1a1a2e")

    window = min(5, len(df) // 2)

    # Plot CPU Usage
    smoothed_cpu = df["cpu_percent"].rolling(window=window, center=True).mean()
    sns.lineplot(x=elapsed_time, y=smoothed_cpu, ax=axes[0], color="#ff2a6d", linewidth=2)
    axes[0].axhline(y=baseline_cpu, color="#05d9e8", linestyle="--", alpha=0.5, label="Baseline")
    axes[0].set_xlabel("Time (seconds)")
    axes[0].set_ylabel("CPU Usage (%)")
    axes[0].set_title("CPU Usage Over Time")
    axes[0].set_ylim(0, max(df["cpu_percent"]) * 1.1)
    axes[0].legend()

    # Plot RAM Usage
    smoothed_ram = df["ram_used_gb"].rolling(window=window, center=True).mean()
    sns.lineplot(x=elapsed_time, y=smoothed_ram, ax=axes[1], color="#05d9e8", linewidth=2)
    axes[1].axhline(y=baseline_ram, color="#ff2a6d", linestyle="--", alpha=0.5, label="Baseline")
    axes[1].set_xlabel("Time (seconds)")
    axes[1].set_ylabel("RAM Usage (GB)")
    axes[1].set_title("RAM Usage Over Time")
    axes[1].set_ylim(0, max(df["ram_used_gb"]) * 1.1)
    axes[1].legend()

    # Plot GPU Memory if available
    if has_gpu:
        smoothed_gpu = df["gpu_memory_gb"].rolling(window=window, center=True).mean()
        sns.lineplot(x=elapsed_time, y=smoothed_gpu, ax=axes[2], color="#ff2a6d", linewidth=2)
        axes[2].axhline(y=baseline_gpu, color="#05d9e8", linestyle="--", alpha=0.5, label="Baseline")
        axes[2].set_xlabel("Time (seconds)")
        axes[2].set_ylabel("GPU Memory (GB)")
        axes[2].set_title("GPU Memory Usage Over Time")
        axes[2].set_ylim(0, max(df["gpu_memory_gb"]) * 1.1)
        axes[2].legend()

    for ax in axes:
        ax.grid(True, linestyle="--", alpha=0.3)
        ax.set_facecolor("#1a1a2e")
        for spine in ax.spines.values():
            spine.set_color("#ffffff")
            spine.set_alpha(0.3)

    plt.tight_layout()
    plt.savefig("examples/benchmarks/system_usage_rtf.png", dpi=300, bbox_inches="tight")
    plt.close()


def main():
    os.makedirs("examples/benchmarks/output", exist_ok=True)

    with open("examples/benchmarks/the_time_machine_hg_wells.txt", "r", encoding="utf-8") as f:
        text = f.read()

    total_tokens = len(enc.encode(text))
    print(f"Total tokens in file: {total_tokens}")

    # Generate token sizes with dense sampling at start
    dense_range = list(range(100, 1001, 100))
    token_sizes = sorted(list(set(dense_range)))
    print(f"Testing sizes: {token_sizes}")

    results = []
    system_metrics = []
    test_start_time = time.time()

    for num_tokens in token_sizes:
        chunk = get_text_for_tokens(text, num_tokens)
        actual_tokens = len(enc.encode(chunk))

        print(f"\nProcessing chunk with {actual_tokens} tokens:")
        print(f"Text preview: {chunk[:100]}...")

        system_metrics.append(get_system_metrics())

        processing_time, audio_length = make_tts_request(chunk)
        if processing_time is None or audio_length is None:
            print("Breaking loop due to error")
            break

        system_metrics.append(get_system_metrics())

        # Calculate RTF using the correct formula
        rtf = real_time_factor(processing_time, audio_length)

        results.append({
            "tokens": actual_tokens,
            "processing_time": processing_time,
            "output_length": audio_length,
            "rtf": rtf,
            "elapsed_time": round(time.time() - test_start_time, 2),
        })

    with open("examples/benchmarks/benchmark_results_rtf.json", "w") as f:
        json.dump({"results": results, "system_metrics": system_metrics}, f, indent=2)

    df = pd.DataFrame(results)
    if df.empty:
        print("No data to plot")
        return

    df["tokens_per_second"] = df["tokens"] / df["processing_time"]

    with open("examples/benchmarks/benchmark_stats_rtf.txt", "w") as f:
        f.write("=== Benchmark Statistics (with correct RTF) ===\n\n")

        f.write("Overall Stats:\n")
        f.write(f"Total tokens processed: {df['tokens'].sum()}\n")
        f.write(f"Total audio generated: {df['output_length'].sum():.2f}s\n")
        f.write(f"Total test duration: {df['elapsed_time'].max():.2f}s\n")
        f.write(f"Average processing rate: {df['tokens_per_second'].mean():.2f} tokens/second\n")
        f.write(f"Average RTF: {df['rtf'].mean():.2f}x\n\n")

        f.write("Per-chunk Stats:\n")
        f.write(f"Average chunk size: {df['tokens'].mean():.2f} tokens\n")
        f.write(f"Min chunk size: {df['tokens'].min():.2f} tokens\n")
        f.write(f"Max chunk size: {df['tokens'].max():.2f} tokens\n")
        f.write(f"Average processing time: {df['processing_time'].mean():.2f}s\n")
        f.write(f"Average output length: {df['output_length'].mean():.2f}s\n\n")

        f.write("Performance Ranges:\n")
        f.write(f"Processing rate range: {df['tokens_per_second'].min():.2f} - {df['tokens_per_second'].max():.2f} tokens/second\n")
        f.write(f"RTF range: {df['rtf'].min():.2f}x - {df['rtf'].max():.2f}x\n")

    plt.style.use("dark_background")

    # Plot Processing Time vs Token Count
    fig, ax = plt.subplots(figsize=(12, 8))
    sns.scatterplot(data=df, x="tokens", y="processing_time", s=100, alpha=0.6, color="#ff2a6d")
    sns.regplot(data=df, x="tokens", y="processing_time", scatter=False, color="#05d9e8", line_kws={"linewidth": 2})
    corr = df["tokens"].corr(df["processing_time"])
    plt.text(0.05, 0.95, f"Correlation: {corr:.2f}", transform=ax.transAxes, fontsize=10, color="#ffffff",
             bbox=dict(facecolor="#1a1a2e", edgecolor="#ffffff", alpha=0.7))
    setup_plot(fig, ax, "Processing Time vs Input Size")
    ax.set_xlabel("Number of Input Tokens")
    ax.set_ylabel("Processing Time (seconds)")
    plt.savefig("examples/benchmarks/processing_time_rtf.png", dpi=300, bbox_inches="tight")
    plt.close()

    # Plot RTF vs Token Count
    fig, ax = plt.subplots(figsize=(12, 8))
    sns.scatterplot(data=df, x="tokens", y="rtf", s=100, alpha=0.6, color="#ff2a6d")
    sns.regplot(data=df, x="tokens", y="rtf", scatter=False, color="#05d9e8", line_kws={"linewidth": 2})
    corr = df["tokens"].corr(df["rtf"])
    plt.text(0.05, 0.95, f"Correlation: {corr:.2f}", transform=ax.transAxes, fontsize=10, color="#ffffff",
             bbox=dict(facecolor="#1a1a2e", edgecolor="#ffffff", alpha=0.7))
    setup_plot(fig, ax, "Real-Time Factor vs Input Size")
    ax.set_xlabel("Number of Input Tokens")
    ax.set_ylabel("Real-Time Factor (processing time / audio length)")
    plt.savefig("examples/benchmarks/realtime_factor_rtf.png", dpi=300, bbox_inches="tight")
    plt.close()

    plot_system_metrics(system_metrics)

    print("\nResults saved to:")
    print("- examples/benchmarks/benchmark_results_rtf.json")
    print("- examples/benchmarks/benchmark_stats_rtf.txt")
    print("- examples/benchmarks/processing_time_rtf.png")
    print("- examples/benchmarks/realtime_factor_rtf.png")
    print("- examples/benchmarks/system_usage_rtf.png")
    print("\nAudio files saved in examples/benchmarks/output/")


if __name__ == "__main__":
    main()
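To reproduce these numbers, the API must already be serving at http://localhost:8880 (the script only issues HTTP requests; it does not start the server), and examples/benchmarks/the_time_machine_hg_wells.txt must exist. Note the naming mismatch: the script writes *_rtf.* output files, while the artifacts checked in here carry a _cpu suffix, presumably renamed after the CPU run.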
examples/benchmarks/processing_time_cpu.png (new binary file, 254 KiB; not shown)
examples/benchmarks/realtime_factor_cpu.png (new binary file, 208 KiB; not shown)