mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00
WIP, functional for CPU: updated for ONNX runtime support, Dockerfile, and TTS service
parent: f1131b4836
commit: e4d8e74738
15 changed files with 946 additions and 313 deletions
Dockerfile
@@ -10,8 +10,9 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
     && apt-get clean \
     && rm -rf /var/lib/apt/lists/*
 
-# Install PyTorch CPU version
-RUN pip3 install --no-cache-dir torch==2.5.1 --extra-index-url https://download.pytorch.org/whl/cpu
+# Install PyTorch CPU version and ONNX runtime
+RUN pip3 install --no-cache-dir torch==2.5.1 --extra-index-url https://download.pytorch.org/whl/cpu && \
+    pip3 install --no-cache-dir onnxruntime==1.20.1
 
 # Install all other dependencies from requirements.txt
 COPY requirements.txt .
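
As a quick sanity check on the rebuilt image, both pinned runtimes should import side by side. A minimal sketch, run inside the container (the version assertions simply mirror the pins above):

# Verify the two pinned runtimes from the Dockerfile coexist in the image.
import torch
import onnxruntime

assert torch.__version__.startswith("2.5.1")   # the CPU wheel reports e.g. "2.5.1+cpu"
assert onnxruntime.__version__ == "1.20.1"
print(onnxruntime.get_available_providers())   # "CPUExecutionProvider" should be listed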
api/src/main.py
@@ -10,7 +10,8 @@ from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 
 from .core.config import settings
-from .services.tts import TTSModel, TTSService
+from .services.tts_model import TTSModel
+from .services.tts_service import TTSService
 from .routers.openai_compatible import router as openai_router
 
 
@@ -20,7 +21,7 @@ async def lifespan(app: FastAPI):
     logger.info("Loading TTS model and voice packs...")
 
     # Initialize the main model with warm-up
-    model, voicepack_count = TTSModel.initialize()
+    voicepack_count = TTSModel.initialize()
     logger.info(f"Model loaded and warmed up on {TTSModel._device}")
     logger.info(f"{voicepack_count} voice packs loaded successfully")
     yield
api/src/routers/openai_compatible.py
@@ -3,7 +3,7 @@ from typing import List
 from loguru import logger
 from fastapi import Depends, Response, APIRouter, HTTPException
 
-from ..services.tts import TTSService
+from ..services.tts_service import TTSService
 from ..services.audio import AudioService
 from ..structures.schemas import OpenAISpeechRequest
 
@@ -15,9 +15,7 @@ router = APIRouter(
 
 def get_tts_service() -> TTSService:
     """Dependency to get TTSService instance with database session"""
-    return TTSService(
-        start_worker=False
-    )  # Don't start worker thread for OpenAI endpoint
+    return TTSService()  # Initialize TTSService with default settings
 
 
 @router.post("/audio/speech")
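
For reference, the endpoint this dependency feeds can be exercised directly. A minimal client sketch, assuming the server is listening on port 8880 as in the benchmark script added later in this commit:

# POST to the OpenAI-compatible speech endpoint and save the WAV response.
import requests

resp = requests.post(
    "http://localhost:8880/v1/audio/speech",
    json={"model": "kokoro", "input": "Hello world.", "voice": "af", "response_format": "wav"},
    timeout=60,
)
resp.raise_for_status()
with open("hello.wav", "wb") as f:
    f.write(resp.content)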
api/src/services/__init__.py
@@ -1,3 +1,4 @@
-from .tts import TTSModel, TTSService
+from .tts_model import TTSModel
+from .tts_service import TTSService
 
 __all__ = ["TTSService", "TTSModel"]
api/src/services/tts.py (deleted, 286 lines)

import io
import os
import re
import time
import threading
from typing import List, Tuple, Optional

import numpy as np
import torch
import tiktoken
import scipy.io.wavfile as wavfile
from kokoro import generate, tokenize, phonemize, normalize_text
from loguru import logger
from models import build_model

from ..core.config import settings

enc = tiktoken.get_encoding("cl100k_base")


class TTSModel:
    _instance = None
    _device = None
    _lock = threading.Lock()

    # Directory for all voices (copied base voices, and any created combined voices)
    VOICES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "voices")

    @classmethod
    def initialize(cls):
        """Initialize and warm up the model"""
        with cls._lock:
            if cls._instance is None:
                # Initialize model
                cls._device = "cuda" if torch.cuda.is_available() else "cpu"
                logger.info(f"Initializing model on {cls._device}")
                model_path = os.path.join(settings.model_dir, settings.model_path)
                model = build_model(model_path, cls._device)
                cls._instance = model

                # Ensure voices directory exists
                os.makedirs(cls.VOICES_DIR, exist_ok=True)

                # Copy base voices to local directory
                base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
                if os.path.exists(base_voices_dir):
                    for file in os.listdir(base_voices_dir):
                        if file.endswith(".pt"):
                            voice_name = file[:-3]
                            voice_path = os.path.join(cls.VOICES_DIR, file)
                            if not os.path.exists(voice_path):
                                try:
                                    logger.info(
                                        f"Copying base voice {voice_name} to voices directory"
                                    )
                                    base_path = os.path.join(base_voices_dir, file)
                                    voicepack = torch.load(
                                        base_path,
                                        map_location=cls._device,
                                        weights_only=True,
                                    )
                                    torch.save(voicepack, voice_path)
                                except Exception as e:
                                    logger.error(
                                        f"Error copying voice {voice_name}: {str(e)}"
                                    )

                # Warm up with default voice
                try:
                    dummy_text = "Hello"
                    voice_path = os.path.join(cls.VOICES_DIR, "af.pt")
                    dummy_voicepack = torch.load(
                        voice_path, map_location=cls._device, weights_only=True
                    )
                    generate(model, dummy_text, dummy_voicepack, lang="a", speed=1.0)
                    logger.info("Model warm-up complete")
                except Exception as e:
                    logger.warning(f"Model warm-up failed: {e}")

            # Count voices in directory for validation
            voice_count = len(
                [f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")]
            )
            return cls._instance, voice_count

    @classmethod
    def get_instance(cls):
        """Get the initialized instance or raise an error"""
        if cls._instance is None:
            raise RuntimeError("Model not initialized. Call initialize() first.")
        return cls._instance, cls._device


class TTSService:
    def __init__(self, output_dir: str = None, start_worker: bool = False):
        self.output_dir = output_dir
        self._ensure_voices()
        if start_worker:
            self.start_worker()

    def _ensure_voices(self):
        """Copy base voices to local voices directory during initialization"""
        os.makedirs(TTSModel.VOICES_DIR, exist_ok=True)

        base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
        if os.path.exists(base_voices_dir):
            for file in os.listdir(base_voices_dir):
                if file.endswith(".pt"):
                    voice_name = file[:-3]
                    voice_path = os.path.join(TTSModel.VOICES_DIR, file)
                    if not os.path.exists(voice_path):
                        try:
                            logger.info(
                                f"Copying base voice {voice_name} to voices directory"
                            )
                            base_path = os.path.join(base_voices_dir, file)
                            voicepack = torch.load(
                                base_path,
                                map_location=TTSModel._device,
                                weights_only=True,
                            )
                            torch.save(voicepack, voice_path)
                        except Exception as e:
                            logger.error(f"Error copying voice {voice_name}: {str(e)}")

    def _split_text(self, text: str) -> List[str]:
        """Split text into sentences"""
        return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]

    def _get_voice_path(self, voice_name: str) -> Optional[str]:
        """Get the path to a voice file.

        Args:
            voice_name: Name of the voice to find

        Returns:
            Path to the voice file if found, None otherwise
        """
        voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice_name}.pt")
        return voice_path if os.path.exists(voice_path) else None

    def _generate_audio(
        self, text: str, voice: str, speed: float, stitch_long_output: bool = True
    ) -> Tuple[torch.Tensor, float]:
        """Generate audio and measure processing time"""
        start_time = time.time()

        try:
            # Normalize text once at the start
            text = normalize_text(text)
            if not text:
                raise ValueError("Text is empty after preprocessing")

            # Check voice exists
            voice_path = self._get_voice_path(voice)
            if not voice_path:
                raise ValueError(f"Voice not found: {voice}")

            # Load model and voice
            model = TTSModel._instance
            voicepack = torch.load(
                voice_path, map_location=TTSModel._device, weights_only=True
            )

            # Generate audio with or without stitching
            if stitch_long_output:
                chunks = self._split_text(text)
                audio_chunks = []

                # Process all chunks with same model/voicepack instance
                for i, chunk in enumerate(chunks):
                    try:
                        # Validate phonemization first
                        # ps = phonemize(chunk, voice[0])
                        # tokens = tokenize(ps)
                        # logger.debug(
                        #     f"Processing chunk {i + 1}/{len(chunks)}: {len(tokens)} tokens"
                        # )

                        # Only proceed if phonemization succeeded
                        chunk_audio, _ = generate(
                            model, chunk, voicepack, lang=voice[0], speed=speed
                        )
                        if chunk_audio is not None:
                            audio_chunks.append(chunk_audio)
                        else:
                            logger.error(
                                f"No audio generated for chunk {i + 1}/{len(chunks)}"
                            )
                    except Exception as e:
                        logger.error(
                            f"Failed to generate audio for chunk {i + 1}/{len(chunks)}: '{chunk}'. Error: {str(e)}"
                        )
                        continue

                if not audio_chunks:
                    raise ValueError("No audio chunks were generated successfully")

                audio = (
                    np.concatenate(audio_chunks)
                    if len(audio_chunks) > 1
                    else audio_chunks[0]
                )
            else:
                audio, _ = generate(model, text, voicepack, lang=voice[0], speed=speed)

            processing_time = time.time() - start_time
            return audio, processing_time

        except Exception as e:
            print(f"Error in audio generation: {str(e)}")
            raise

    def _save_audio(self, audio: torch.Tensor, filepath: str):
        """Save audio to file"""
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        wavfile.write(filepath, 24000, audio)

    def _audio_to_bytes(self, audio: torch.Tensor) -> bytes:
        """Convert audio tensor to WAV bytes"""
        buffer = io.BytesIO()
        wavfile.write(buffer, 24000, audio)
        return buffer.getvalue()

    def combine_voices(self, voices: List[str]) -> str:
        """Combine multiple voices into a new voice.

        Args:
            voices: List of voice names to combine

        Returns:
            Name of the combined voice

        Raises:
            ValueError: If less than 2 voices provided or voice loading fails
            RuntimeError: If voice combination or saving fails
        """
        if len(voices) < 2:
            raise ValueError("At least 2 voices are required for combination")

        # Load voices
        t_voices: List[torch.Tensor] = []
        v_name: List[str] = []

        for voice in voices:
            try:
                voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice}.pt")
                voicepack = torch.load(
                    voice_path, map_location=TTSModel._device, weights_only=True
                )
                t_voices.append(voicepack)
                v_name.append(voice)
            except Exception as e:
                raise ValueError(f"Failed to load voice {voice}: {str(e)}")

        # Combine voices
        try:
            f: str = "_".join(v_name)
            v = torch.mean(torch.stack(t_voices), dim=0)
            combined_path = os.path.join(TTSModel.VOICES_DIR, f"{f}.pt")

            # Save combined voice
            try:
                torch.save(v, combined_path)
            except Exception as e:
                raise RuntimeError(
                    f"Failed to save combined voice to {combined_path}: {str(e)}"
                )

            return f

        except Exception as e:
            if not isinstance(e, (ValueError, RuntimeError)):
                raise RuntimeError(f"Error combining voices: {str(e)}")
            raise

    def list_voices(self) -> List[str]:
        """List all available voices"""
        voices = []
        try:
            for file in os.listdir(TTSModel.VOICES_DIR):
                if file.endswith(".pt"):
                    voices.append(file[:-3])  # Remove .pt extension
        except Exception as e:
            logger.error(f"Error listing voices: {str(e)}")
        return sorted(voices)
api/src/services/tts_cpu.py (new file, 65 lines)

import os
import numpy as np
import torch
from onnxruntime import InferenceSession, SessionOptions, GraphOptimizationLevel, ExecutionMode
from loguru import logger


class TTSCPUModel:
    _instance = None
    _onnx_session = None

    @classmethod
    def initialize(cls, model_dir: str):
        """Initialize ONNX model for CPU inference"""
        if cls._onnx_session is None:
            # Try loading ONNX model
            onnx_path = os.path.join(model_dir, "kokoro-v0_19.onnx")
            if not os.path.exists(onnx_path):
                return None

            logger.info(f"Loading ONNX model from {onnx_path}")

            # Configure ONNX session for optimal performance
            session_options = SessionOptions()
            session_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
            session_options.intra_op_num_threads = 4  # Adjust based on CPU cores
            session_options.execution_mode = ExecutionMode.ORT_SEQUENTIAL

            # Configure CPU provider options
            provider_options = {
                'CPUExecutionProvider': {
                    'arena_extend_strategy': 'kNextPowerOfTwo',
                    'cpu_memory_arena_cfg': 'cpu:0'
                }
            }

            cls._onnx_session = InferenceSession(
                onnx_path,
                sess_options=session_options,
                providers=['CPUExecutionProvider'],
                provider_options=[provider_options]
            )

            return cls._onnx_session
        return cls._onnx_session

    @classmethod
    def generate(cls, tokens: list, voicepack: torch.Tensor, speed: float) -> np.ndarray:
        """Generate audio using ONNX model"""
        if cls._onnx_session is None:
            raise RuntimeError("ONNX model not initialized")

        # Pre-allocate and prepare inputs
        tokens_input = np.array([tokens], dtype=np.int64)
        style_input = voicepack[len(tokens)-2].numpy()  # Already has correct dimensions
        speed_input = np.full(1, speed, dtype=np.float32)  # More efficient than ones * speed

        # Run inference with optimized inputs
        return cls._onnx_session.run(
            None,
            {
                'tokens': tokens_input,
                'style': style_input,
                'speed': speed_input
            }
        )[0]
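
Unlike the PyTorch path, the ONNX path takes pre-tokenized input rather than raw text, so callers phonemize and tokenize first (exactly what tts_service.py does below). A minimal sketch of driving it directly; the model and voicepack paths are illustrative:

# Drive TTSCPUModel by hand, mirroring the service's CPU branch.
import torch
from api.src.services.tts_cpu import TTSCPUModel
from kokoro import phonemize, tokenize

TTSCPUModel.initialize("models")  # looks for models/kokoro-v0_19.onnx
voicepack = torch.load("api/src/voices/af.pt", map_location="cpu", weights_only=True)

ps = phonemize("Hello there!", "a")   # "a" = American English, as in the warm-up code
tokens = [0] + tokenize(ps) + [0]     # zero-padding, matching the service code
audio = TTSCPUModel.generate(tokens, voicepack, speed=1.0)  # numpy samples at 24 kHz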
api/src/services/tts_gpu.py (new file, 32 lines)

import os
import torch
from loguru import logger
from models import build_model
from kokoro import generate


class TTSGPUModel:
    _instance = None
    _device = "cuda"

    @classmethod
    def initialize(cls, model_dir: str, model_path: str):
        """Initialize PyTorch model for GPU inference"""
        if cls._instance is None and torch.cuda.is_available():
            try:
                logger.info("Initializing GPU model")
                model_path = os.path.join(model_dir, model_path)
                model = build_model(model_path, cls._device)
                cls._instance = model
                return cls._instance
            except Exception as e:
                logger.error(f"Failed to initialize GPU model: {e}")
                return None
        return cls._instance

    @classmethod
    def generate(cls, text: str, voicepack: torch.Tensor, lang: str, speed: float) -> tuple[torch.Tensor, dict]:
        """Generate audio using PyTorch model on GPU"""
        if cls._instance is None:
            raise RuntimeError("GPU model not initialized")

        return generate(cls._instance, text, voicepack, lang=lang, speed=speed)
api/src/services/tts_model.py (new file, 94 lines)

import os
import threading
import torch
from loguru import logger
from kokoro import tokenize, phonemize

from ..core.config import settings
from .tts_cpu import TTSCPUModel
from .tts_gpu import TTSGPUModel


class TTSModel:
    _device = None
    _lock = threading.Lock()
    VOICES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "voices")

    @classmethod
    def initialize(cls):
        """Initialize and warm up the model"""
        with cls._lock:
            # Set device and initialize model
            cls._device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Initializing model on {cls._device}")

            # Initialize appropriate model based on device
            if cls._device == "cuda":
                if not TTSGPUModel.initialize(settings.model_dir, settings.model_path):
                    raise RuntimeError("Failed to initialize GPU model")
            else:
                # Try CPU ONNX first, fallback to CPU PyTorch if needed
                if not TTSCPUModel.initialize(settings.model_dir):
                    logger.warning("ONNX initialization failed, falling back to PyTorch CPU")
                    if not TTSGPUModel.initialize(settings.model_dir, settings.model_path):
                        raise RuntimeError("Failed to initialize CPU model")

            # Setup voices directory
            os.makedirs(cls.VOICES_DIR, exist_ok=True)

            # Copy base voices to local directory
            base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
            if os.path.exists(base_voices_dir):
                for file in os.listdir(base_voices_dir):
                    if file.endswith(".pt"):
                        voice_name = file[:-3]
                        voice_path = os.path.join(cls.VOICES_DIR, file)
                        if not os.path.exists(voice_path):
                            try:
                                logger.info(
                                    f"Copying base voice {voice_name} to voices directory"
                                )
                                base_path = os.path.join(base_voices_dir, file)
                                voicepack = torch.load(
                                    base_path,
                                    map_location=cls._device,
                                    weights_only=True,
                                )
                                torch.save(voicepack, voice_path)
                            except Exception as e:
                                logger.error(
                                    f"Error copying voice {voice_name}: {str(e)}"
                                )

            # Warm up with default voice
            try:
                dummy_text = "Hello"
                voice_path = os.path.join(cls.VOICES_DIR, "af.pt")
                dummy_voicepack = torch.load(
                    voice_path, map_location=cls._device, weights_only=True
                )

                if cls._device == "cuda":
                    TTSGPUModel.generate(dummy_text, dummy_voicepack, "a", 1.0)
                else:
                    ps = phonemize(dummy_text, "a")
                    tokens = tokenize(ps)
                    tokens = [0] + tokens + [0]
                    TTSCPUModel.generate(tokens, dummy_voicepack, 1.0)

                logger.info("Model warm-up complete")
            except Exception as e:
                logger.warning(f"Model warm-up failed: {e}")

            # Count voices in directory
            voice_count = len(
                [f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")]
            )
            return voice_count

    @classmethod
    def get_device(cls):
        """Get the current device or raise an error"""
        if cls._device is None:
            raise RuntimeError("Model not initialized. Call initialize() first.")
        return cls._device
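
Since initialize() now returns only the voice count (matching the main.py hunk above), a minimal smoke test of the new facade looks like this, assuming models and base voices are already present under settings.model_dir:

from api.src.services.tts_model import TTSModel

voice_count = TTSModel.initialize()        # ONNX on CPU-only hosts, PyTorch on CUDA hosts
print(TTSModel.get_device(), voice_count)  # e.g. "cpu 1" with just the default af voice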
api/src/services/tts_service.py (new file, 168 lines)

import io
import os
import re
import time
from typing import List, Tuple, Optional

import numpy as np
import torch
import scipy.io.wavfile as wavfile
from kokoro import tokenize, phonemize, normalize_text
from loguru import logger

from ..core.config import settings
from .tts_model import TTSModel
from .tts_cpu import TTSCPUModel
from .tts_gpu import TTSGPUModel


class TTSService:
    def __init__(self, output_dir: str = None):
        self.output_dir = output_dir

    def _split_text(self, text: str) -> List[str]:
        """Split text into sentences"""
        return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]

    def _get_voice_path(self, voice_name: str) -> Optional[str]:
        """Get the path to a voice file"""
        voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice_name}.pt")
        return voice_path if os.path.exists(voice_path) else None

    def _generate_audio(
        self, text: str, voice: str, speed: float, stitch_long_output: bool = True
    ) -> Tuple[torch.Tensor, float]:
        """Generate audio and measure processing time"""
        start_time = time.time()

        try:
            # Normalize text once at the start
            text = normalize_text(text)
            if not text:
                raise ValueError("Text is empty after preprocessing")

            # Check voice exists
            voice_path = self._get_voice_path(voice)
            if not voice_path:
                raise ValueError(f"Voice not found: {voice}")

            # Load voice
            voicepack = torch.load(
                voice_path, map_location=TTSModel.get_device(), weights_only=True
            )

            # Generate audio with or without stitching
            if stitch_long_output:
                chunks = self._split_text(text)
                audio_chunks = []

                # Process all chunks
                for i, chunk in enumerate(chunks):
                    try:
                        # Process chunk
                        if TTSModel.get_device() == "cuda":
                            chunk_audio, _ = TTSGPUModel.generate(chunk, voicepack, voice[0], speed)
                        else:
                            ps = phonemize(chunk, voice[0])
                            tokens = tokenize(ps)
                            tokens = [0] + tokens + [0]  # Add padding
                            chunk_audio = TTSCPUModel.generate(tokens, voicepack, speed)

                        if chunk_audio is not None:
                            audio_chunks.append(chunk_audio)
                        else:
                            logger.error(f"No audio generated for chunk {i + 1}/{len(chunks)}")

                    except Exception as e:
                        logger.error(
                            f"Failed to generate audio for chunk {i + 1}/{len(chunks)}: '{chunk}'. Error: {str(e)}"
                        )
                        continue

                if not audio_chunks:
                    raise ValueError("No audio chunks were generated successfully")

                audio = (
                    np.concatenate(audio_chunks)
                    if len(audio_chunks) > 1
                    else audio_chunks[0]
                )
            else:
                # Process single chunk
                if TTSModel.get_device() == "cuda":
                    audio, _ = TTSGPUModel.generate(text, voicepack, voice[0], speed)
                else:
                    ps = phonemize(text, voice[0])
                    tokens = tokenize(ps)
                    tokens = [0] + tokens + [0]  # Add padding
                    audio = TTSCPUModel.generate(tokens, voicepack, speed)

            processing_time = time.time() - start_time
            return audio, processing_time

        except Exception as e:
            logger.error(f"Error in audio generation: {str(e)}")
            raise

    def _save_audio(self, audio: torch.Tensor, filepath: str):
        """Save audio to file"""
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        wavfile.write(filepath, 24000, audio)

    def _audio_to_bytes(self, audio: torch.Tensor) -> bytes:
        """Convert audio tensor to WAV bytes"""
        buffer = io.BytesIO()
        wavfile.write(buffer, 24000, audio)
        return buffer.getvalue()

    def combine_voices(self, voices: List[str]) -> str:
        """Combine multiple voices into a new voice"""
        if len(voices) < 2:
            raise ValueError("At least 2 voices are required for combination")

        # Load voices
        t_voices: List[torch.Tensor] = []
        v_name: List[str] = []

        for voice in voices:
            try:
                voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice}.pt")
                voicepack = torch.load(
                    voice_path, map_location=TTSModel.get_device(), weights_only=True
                )
                t_voices.append(voicepack)
                v_name.append(voice)
            except Exception as e:
                raise ValueError(f"Failed to load voice {voice}: {str(e)}")

        # Combine voices
        try:
            f: str = "_".join(v_name)
            v = torch.mean(torch.stack(t_voices), dim=0)
            combined_path = os.path.join(TTSModel.VOICES_DIR, f"{f}.pt")

            # Save combined voice
            try:
                torch.save(v, combined_path)
            except Exception as e:
                raise RuntimeError(
                    f"Failed to save combined voice to {combined_path}: {str(e)}"
                )

            return f

        except Exception as e:
            if not isinstance(e, (ValueError, RuntimeError)):
                raise RuntimeError(f"Error combining voices: {str(e)}")
            raise

    def list_voices(self) -> List[str]:
        """List all available voices"""
        voices = []
        try:
            for file in os.listdir(TTSModel.VOICES_DIR):
                if file.endswith(".pt"):
                    voices.append(file[:-3])  # Remove .pt extension
        except Exception as e:
            logger.error(f"Error listing voices: {str(e)}")
        return sorted(voices)
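
combine_voices carries over unchanged from the deleted module: the voicepack tensors are stacked and averaged element-wise, and the blend is saved under the joined name. A toy illustration with synthetic tensors standing in for two real voicepacks:

import torch

a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])  # stand-in for one voicepack
b = torch.tensor([[3.0, 4.0], [5.0, 6.0]])  # stand-in for another
blend = torch.mean(torch.stack([a, b]), dim=0)
print(blend)  # tensor([[2., 3.], [4., 5.]]) — the element-wise average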
api/tests/test_tts_service.py
@@ -7,7 +7,8 @@ import numpy as np
 import torch
 import pytest
 
-from api.src.services.tts import TTSModel, TTSService
+from api.src.services.tts_model import TTSModel
+from api.src.services.tts_service import TTSService
 
 
 @pytest.fixture
@@ -68,12 +69,12 @@ def test_list_voices(mock_join, mock_listdir, tts_service):
     assert "not_a_voice" not in voices
 
 
-@patch("api.src.services.tts.TTSModel.get_instance")
-@patch("api.src.services.tts.TTSModel.get_voicepack")
-@patch("api.src.services.tts.normalize_text")
-@patch("api.src.services.tts.phonemize")
-@patch("api.src.services.tts.tokenize")
-@patch("api.src.services.tts.generate")
+@patch("api.src.services.tts_model.TTSModel.get_instance")
+@patch("api.src.services.tts_model.TTSModel.get_voicepack")
+@patch("kokoro.normalize_text")
+@patch("kokoro.phonemize")
+@patch("kokoro.tokenize")
+@patch("kokoro.generate")
 def test_generate_audio_empty_text(
     mock_generate,
     mock_tokenize,
@@ -90,12 +91,12 @@ def test_generate_audio_empty_text(
         tts_service._generate_audio("", "af", 1.0)
 
 
-@patch("api.src.services.tts.TTSModel.get_instance")
+@patch("api.src.services.tts_model.TTSModel.get_instance")
 @patch("os.path.exists")
-@patch("api.src.services.tts.normalize_text")
-@patch("api.src.services.tts.phonemize")
-@patch("api.src.services.tts.tokenize")
-@patch("api.src.services.tts.generate")
+@patch("kokoro.normalize_text")
+@patch("kokoro.phonemize")
+@patch("kokoro.tokenize")
+@patch("kokoro.generate")
 @patch("torch.load")
 def test_generate_audio_no_chunks(
     mock_torch_load,
@@ -225,8 +226,8 @@ def test_generate_audio_success(
     assert len(audio) > 0
 
 
-@patch("api.src.services.tts.torch.cuda.is_available")
-@patch("api.src.services.tts.build_model")
+@patch("torch.cuda.is_available")
+@patch("models.build_model")
 def test_model_initialization_cuda(mock_build_model, mock_cuda_available):
     """Test model initialization with CUDA"""
     mock_cuda_available.return_value = True
@@ -257,8 +258,8 @@ def test_model_initialization_cpu(mock_build_model, mock_cuda_available):
     mock_build_model.assert_called_once()
 
 
-@patch("api.src.services.tts.TTSService._get_voice_path")
-@patch("api.src.services.tts.TTSModel.get_instance")
+@patch("api.src.services.tts_service.TTSService._get_voice_path")
+@patch("api.src.services.tts_model.TTSModel.get_instance")
 def test_voicepack_loading_error(mock_get_instance, mock_get_voice_path):
     """Test voicepack loading error handling"""
     mock_get_voice_path.return_value = None
@@ -271,7 +272,7 @@ def test_voicepack_loading_error(mock_get_instance, mock_get_voice_path):
     service._generate_audio("test", "nonexistent_voice", 1.0)
 
 
-@patch("api.src.services.tts.TTSModel")
+@patch("api.src.services.tts_model.TTSModel")
 def test_save_audio(mock_tts_model, tts_service, sample_audio, tmp_path):
     """Test saving audio to file"""
     output_dir = os.path.join(tmp_path, "test_output")
@@ -284,7 +285,7 @@ def test_save_audio(mock_tts_model, tts_service, sample_audio, tmp_path):
     assert os.path.getsize(output_path) > 0
 
 
-@patch("api.src.services.tts.TTSModel.get_instance")
+@patch("api.src.services.tts_model.TTSModel.get_instance")
 @patch("os.path.exists")
 @patch("api.src.services.tts.normalize_text")
 @patch("api.src.services.tts.generate")
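
The retargeted @patch decorators follow the usual mock rule: patch a name where it is looked up at call time. After the module split, model access lives in tts_model, and the kokoro helpers are now imported and called via their own module. A condensed, hypothetical example of the updated style (the test body here is illustrative, not from the suite):

from unittest.mock import patch

@patch("api.src.services.tts_model.TTSModel.get_instance")  # new module path
@patch("kokoro.generate")                                   # patched at its defining module
def test_sketch(mock_generate, mock_get_instance):          # bottom decorator binds first
    mock_generate.return_value = ([0.0] * 24000, None)      # one second of fake audio
    ...  # exercise TTSService._generate_audio here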
examples/benchmarks/benchmark_results_cpu.json (new file, 216 lines)

{
  "results": [
    {"tokens": 100, "processing_time": 14.349808931350708, "output_length": 31.15, "rtf": 0.46, "elapsed_time": 14.716031074523926},
    {"tokens": 200, "processing_time": 28.341803312301636, "output_length": 62.6, "rtf": 0.45, "elapsed_time": 43.44207406044006},
    {"tokens": 300, "processing_time": 43.352553606033325, "output_length": 96.325, "rtf": 0.45, "elapsed_time": 87.26906609535217},
    {"tokens": 400, "processing_time": 71.02449822425842, "output_length": 128.575, "rtf": 0.55, "elapsed_time": 158.7198133468628},
    {"tokens": 500, "processing_time": 70.92521691322327, "output_length": 158.575, "rtf": 0.45, "elapsed_time": 230.01379895210266},
    {"tokens": 600, "processing_time": 83.6328592300415, "output_length": 189.25, "rtf": 0.44, "elapsed_time": 314.02610969543457},
    {"tokens": 700, "processing_time": 103.0810194015503, "output_length": 222.075, "rtf": 0.46, "elapsed_time": 417.5678551197052},
    {"tokens": 800, "processing_time": 127.02162909507751, "output_length": 253.85, "rtf": 0.5, "elapsed_time": 545.0128681659698},
    {"tokens": 900, "processing_time": 130.49781227111816, "output_length": 283.775, "rtf": 0.46, "elapsed_time": 675.8943417072296},
    {"tokens": 1000, "processing_time": 154.76425909996033, "output_length": 315.475, "rtf": 0.49, "elapsed_time": 831.0677945613861}
  ],
  "system_metrics": [
    {"timestamp": "2025-01-03T00:23:52.896889", "cpu_percent": 4.5, "ram_percent": 39.1, "ram_used_gb": 24.86032485961914, "gpu_memory_used": 1281.0},
    {"timestamp": "2025-01-03T00:24:07.429461", "cpu_percent": 4.5, "ram_percent": 39.1, "ram_used_gb": 24.847564697265625, "gpu_memory_used": 1285.0},
    {"timestamp": "2025-01-03T00:24:07.620587", "cpu_percent": 2.7, "ram_percent": 39.1, "ram_used_gb": 24.846607208251953, "gpu_memory_used": 1275.0},
    {"timestamp": "2025-01-03T00:24:36.140754", "cpu_percent": 5.4, "ram_percent": 39.1, "ram_used_gb": 24.857810974121094, "gpu_memory_used": 1267.0},
    {"timestamp": "2025-01-03T00:24:36.340675", "cpu_percent": 6.2, "ram_percent": 39.1, "ram_used_gb": 24.85773468017578, "gpu_memory_used": 1267.0},
    {"timestamp": "2025-01-03T00:25:19.905634", "cpu_percent": 29.1, "ram_percent": 39.2, "ram_used_gb": 24.920318603515625, "gpu_memory_used": 1256.0},
    {"timestamp": "2025-01-03T00:25:20.182219", "cpu_percent": 20.0, "ram_percent": 39.2, "ram_used_gb": 24.930198669433594, "gpu_memory_used": 1256.0},
    {"timestamp": "2025-01-03T00:26:31.414760", "cpu_percent": 5.3, "ram_percent": 39.5, "ram_used_gb": 25.127891540527344, "gpu_memory_used": 1259.0},
    {"timestamp": "2025-01-03T00:26:31.617256", "cpu_percent": 3.6, "ram_percent": 39.5, "ram_used_gb": 25.126346588134766, "gpu_memory_used": 1252.0},
    {"timestamp": "2025-01-03T00:27:42.736097", "cpu_percent": 10.5, "ram_percent": 39.5, "ram_used_gb": 25.100231170654297, "gpu_memory_used": 1249.0},
    {"timestamp": "2025-01-03T00:27:42.912870", "cpu_percent": 5.3, "ram_percent": 39.5, "ram_used_gb": 25.098285675048828, "gpu_memory_used": 1249.0},
    {"timestamp": "2025-01-03T00:29:06.725264", "cpu_percent": 8.9, "ram_percent": 39.5, "ram_used_gb": 25.123123168945312, "gpu_memory_used": 1239.0},
    {"timestamp": "2025-01-03T00:29:06.928826", "cpu_percent": 5.5, "ram_percent": 39.5, "ram_used_gb": 25.128646850585938, "gpu_memory_used": 1239.0},
    {"timestamp": "2025-01-03T00:30:50.206349", "cpu_percent": 49.6, "ram_percent": 39.6, "ram_used_gb": 25.162948608398438, "gpu_memory_used": 1245.0},
    {"timestamp": "2025-01-03T00:30:50.491837", "cpu_percent": 14.8, "ram_percent": 39.5, "ram_used_gb": 25.13379669189453, "gpu_memory_used": 1245.0},
    {"timestamp": "2025-01-03T00:32:57.721467", "cpu_percent": 6.2, "ram_percent": 39.6, "ram_used_gb": 25.187721252441406, "gpu_memory_used": 1384.0},
    {"timestamp": "2025-01-03T00:32:57.913350", "cpu_percent": 3.6, "ram_percent": 39.6, "ram_used_gb": 25.199390411376953, "gpu_memory_used": 1384.0},
    {"timestamp": "2025-01-03T00:35:08.608730", "cpu_percent": 6.3, "ram_percent": 39.8, "ram_used_gb": 25.311710357666016, "gpu_memory_used": 1330.0},
    {"timestamp": "2025-01-03T00:35:08.791851", "cpu_percent": 5.3, "ram_percent": 39.8, "ram_used_gb": 25.326683044433594, "gpu_memory_used": 1333.0},
    {"timestamp": "2025-01-03T00:37:43.782406", "cpu_percent": 6.8, "ram_percent": 40.6, "ram_used_gb": 25.803058624267578, "gpu_memory_used": 1409.0}
  ]
}
examples/benchmarks/benchmark_stats_cpu.txt (new file, 19 lines)

=== Benchmark Statistics (with correct RTF) ===

Overall Stats:
Total tokens processed: 5500
Total audio generated: 1741.65s
Total test duration: 831.07s
Average processing rate: 6.72 tokens/second
Average RTF: 0.47x

Per-chunk Stats:
Average chunk size: 550.00 tokens
Min chunk size: 100.00 tokens
Max chunk size: 1000.00 tokens
Average processing time: 82.70s
Average output length: 174.17s

Performance Ranges:
Processing rate range: 5.63 - 7.17 tokens/second
RTF range: 0.44x - 0.55x
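
The RTF column in these artifacts is processing time divided by audio duration, so values below 1.0 mean faster-than-real-time synthesis. Checking the first entry of the JSON results above:

rtf = 14.349808931350708 / 31.15  # processing_time / output_length, first result above
print(round(rtf, 2))              # 0.46, matching the recorded "rtf" field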
examples/benchmarks/benchmark_tts_rtf.py (new file, 323 lines)

import os
import json
import time
import subprocess
from datetime import datetime

import pandas as pd
import psutil
import seaborn as sns
import requests
import tiktoken
import scipy.io.wavfile as wavfile
import matplotlib.pyplot as plt

enc = tiktoken.get_encoding("cl100k_base")


def setup_plot(fig, ax, title):
    """Configure plot styling"""
    ax.grid(True, linestyle="--", alpha=0.3, color="#ffffff")
    ax.set_title(title, pad=20, fontsize=16, fontweight="bold", color="#ffffff")
    ax.set_xlabel(ax.get_xlabel(), fontsize=14, fontweight="medium", color="#ffffff")
    ax.set_ylabel(ax.get_ylabel(), fontsize=14, fontweight="medium", color="#ffffff")
    ax.tick_params(labelsize=12, colors="#ffffff")

    for spine in ax.spines.values():
        spine.set_color("#ffffff")
        spine.set_alpha(0.3)
        spine.set_linewidth(0.5)

    ax.set_facecolor("#1a1a2e")
    fig.patch.set_facecolor("#1a1a2e")
    return fig, ax


def get_text_for_tokens(text: str, num_tokens: int) -> str:
    """Get a slice of text that contains exactly num_tokens tokens"""
    tokens = enc.encode(text)
    if num_tokens > len(tokens):
        return text
    return enc.decode(tokens[:num_tokens])


def get_audio_length(audio_data: bytes) -> float:
    """Get audio length in seconds from bytes data"""
    temp_path = "examples/benchmarks/output/temp.wav"
    os.makedirs(os.path.dirname(temp_path), exist_ok=True)
    with open(temp_path, "wb") as f:
        f.write(audio_data)

    try:
        rate, data = wavfile.read(temp_path)
        return len(data) / rate
    finally:
        if os.path.exists(temp_path):
            os.remove(temp_path)


def get_gpu_memory():
    """Get GPU memory usage using nvidia-smi"""
    try:
        result = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader"]
        )
        return float(result.decode("utf-8").strip())
    except (subprocess.CalledProcessError, FileNotFoundError):
        return None


def get_system_metrics():
    """Get current system metrics"""
    # Take multiple CPU measurements over a short period
    samples = []
    for _ in range(3):  # Take 3 samples
        # Get both overall and per-CPU percentages
        overall_cpu = psutil.cpu_percent(interval=0.1)
        per_cpu = psutil.cpu_percent(percpu=True)
        avg_per_cpu = sum(per_cpu) / len(per_cpu)
        # Use the maximum of overall and average per-CPU
        samples.append(max(overall_cpu, avg_per_cpu))

    # Use the maximum CPU usage from all samples
    cpu_usage = round(max(samples), 2)

    metrics = {
        "timestamp": datetime.now().isoformat(),
        "cpu_percent": cpu_usage,
        "ram_percent": psutil.virtual_memory().percent,
        "ram_used_gb": psutil.virtual_memory().used / (1024**3),
    }

    gpu_mem = get_gpu_memory()
    if gpu_mem is not None:
        metrics["gpu_memory_used"] = gpu_mem

    return metrics


def real_time_factor(processing_time: float, audio_length: float, decimals: int = 2) -> float:
    """Calculate Real-Time Factor (RTF) as processing-time / length-of-audio"""
    rtf = processing_time / audio_length
    return round(rtf, decimals)


def make_tts_request(text: str, timeout: int = 1800) -> tuple[float, float]:
    """Make TTS request using OpenAI-compatible endpoint and return processing time and output length"""
    try:
        start_time = time.time()
        response = requests.post(
            "http://localhost:8880/v1/audio/speech",
            json={
                "model": "kokoro",
                "input": text,
                "voice": "af",
                "response_format": "wav",
            },
            timeout=timeout,
        )
        response.raise_for_status()

        processing_time = round(time.time() - start_time, 2)
        audio_length = round(get_audio_length(response.content), 2)

        # Save the audio file
        token_count = len(enc.encode(text))
        output_file = f"examples/benchmarks/output/chunk_{token_count}_tokens.wav"
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        with open(output_file, "wb") as f:
            f.write(response.content)
        print(f"Saved audio to {output_file}")

        return processing_time, audio_length

    except requests.exceptions.RequestException as e:
        print(f"Error making request for text: {text[:50]}... Error: {str(e)}")
        return None, None
    except Exception as e:
        print(f"Error processing text: {text[:50]}... Error: {str(e)}")
        return None, None


def plot_system_metrics(metrics_data):
    """Create plots for system metrics over time"""
    df = pd.DataFrame(metrics_data)
    df["timestamp"] = pd.to_datetime(df["timestamp"])
    elapsed_time = (df["timestamp"] - df["timestamp"].iloc[0]).dt.total_seconds()

    baseline_cpu = df["cpu_percent"].iloc[0]
    baseline_ram = df["ram_used_gb"].iloc[0]
    baseline_gpu = df["gpu_memory_used"].iloc[0] / 1024 if "gpu_memory_used" in df.columns else None

    if "gpu_memory_used" in df.columns:
        df["gpu_memory_gb"] = df["gpu_memory_used"] / 1024

    plt.style.use("dark_background")

    has_gpu = "gpu_memory_used" in df.columns
    num_plots = 3 if has_gpu else 2
    fig, axes = plt.subplots(num_plots, 1, figsize=(15, 5 * num_plots))
    fig.patch.set_facecolor("#1a1a2e")

    window = min(5, len(df) // 2)

    # Plot CPU Usage
    smoothed_cpu = df["cpu_percent"].rolling(window=window, center=True).mean()
    sns.lineplot(x=elapsed_time, y=smoothed_cpu, ax=axes[0], color="#ff2a6d", linewidth=2)
    axes[0].axhline(y=baseline_cpu, color="#05d9e8", linestyle="--", alpha=0.5, label="Baseline")
    axes[0].set_xlabel("Time (seconds)")
    axes[0].set_ylabel("CPU Usage (%)")
    axes[0].set_title("CPU Usage Over Time")
    axes[0].set_ylim(0, max(df["cpu_percent"]) * 1.1)
    axes[0].legend()

    # Plot RAM Usage
    smoothed_ram = df["ram_used_gb"].rolling(window=window, center=True).mean()
    sns.lineplot(x=elapsed_time, y=smoothed_ram, ax=axes[1], color="#05d9e8", linewidth=2)
    axes[1].axhline(y=baseline_ram, color="#ff2a6d", linestyle="--", alpha=0.5, label="Baseline")
    axes[1].set_xlabel("Time (seconds)")
    axes[1].set_ylabel("RAM Usage (GB)")
    axes[1].set_title("RAM Usage Over Time")
    axes[1].set_ylim(0, max(df["ram_used_gb"]) * 1.1)
    axes[1].legend()

    # Plot GPU Memory if available
    if has_gpu:
        smoothed_gpu = df["gpu_memory_gb"].rolling(window=window, center=True).mean()
        sns.lineplot(x=elapsed_time, y=smoothed_gpu, ax=axes[2], color="#ff2a6d", linewidth=2)
        axes[2].axhline(y=baseline_gpu, color="#05d9e8", linestyle="--", alpha=0.5, label="Baseline")
        axes[2].set_xlabel("Time (seconds)")
        axes[2].set_ylabel("GPU Memory (GB)")
        axes[2].set_title("GPU Memory Usage Over Time")
        axes[2].set_ylim(0, max(df["gpu_memory_gb"]) * 1.1)
        axes[2].legend()

    for ax in axes:
        ax.grid(True, linestyle="--", alpha=0.3)
        ax.set_facecolor("#1a1a2e")
        for spine in ax.spines.values():
            spine.set_color("#ffffff")
            spine.set_alpha(0.3)

    plt.tight_layout()
    plt.savefig("examples/benchmarks/system_usage_rtf.png", dpi=300, bbox_inches="tight")
    plt.close()


def main():
    os.makedirs("examples/benchmarks/output", exist_ok=True)

    with open("examples/benchmarks/the_time_machine_hg_wells.txt", "r", encoding="utf-8") as f:
        text = f.read()

    total_tokens = len(enc.encode(text))
    print(f"Total tokens in file: {total_tokens}")

    # Generate token sizes with dense sampling at start
    dense_range = list(range(100, 1001, 100))
    token_sizes = sorted(list(set(dense_range)))
    print(f"Testing sizes: {token_sizes}")

    results = []
    system_metrics = []
    test_start_time = time.time()

    for num_tokens in token_sizes:
        chunk = get_text_for_tokens(text, num_tokens)
        actual_tokens = len(enc.encode(chunk))

        print(f"\nProcessing chunk with {actual_tokens} tokens:")
        print(f"Text preview: {chunk[:100]}...")

        system_metrics.append(get_system_metrics())

        processing_time, audio_length = make_tts_request(chunk)
        if processing_time is None or audio_length is None:
            print("Breaking loop due to error")
            break

        system_metrics.append(get_system_metrics())

        # Calculate RTF using the correct formula
        rtf = real_time_factor(processing_time, audio_length)

        results.append({
            "tokens": actual_tokens,
            "processing_time": processing_time,
            "output_length": audio_length,
            "rtf": rtf,
            "elapsed_time": round(time.time() - test_start_time, 2),
        })

    with open("examples/benchmarks/benchmark_results_rtf.json", "w") as f:
        json.dump({"results": results, "system_metrics": system_metrics}, f, indent=2)

    df = pd.DataFrame(results)
    if df.empty:
        print("No data to plot")
        return

    df["tokens_per_second"] = df["tokens"] / df["processing_time"]

    with open("examples/benchmarks/benchmark_stats_rtf.txt", "w") as f:
        f.write("=== Benchmark Statistics (with correct RTF) ===\n\n")

        f.write("Overall Stats:\n")
        f.write(f"Total tokens processed: {df['tokens'].sum()}\n")
        f.write(f"Total audio generated: {df['output_length'].sum():.2f}s\n")
        f.write(f"Total test duration: {df['elapsed_time'].max():.2f}s\n")
        f.write(f"Average processing rate: {df['tokens_per_second'].mean():.2f} tokens/second\n")
        f.write(f"Average RTF: {df['rtf'].mean():.2f}x\n\n")

        f.write("Per-chunk Stats:\n")
        f.write(f"Average chunk size: {df['tokens'].mean():.2f} tokens\n")
        f.write(f"Min chunk size: {df['tokens'].min():.2f} tokens\n")
        f.write(f"Max chunk size: {df['tokens'].max():.2f} tokens\n")
        f.write(f"Average processing time: {df['processing_time'].mean():.2f}s\n")
        f.write(f"Average output length: {df['output_length'].mean():.2f}s\n\n")

        f.write("Performance Ranges:\n")
        f.write(f"Processing rate range: {df['tokens_per_second'].min():.2f} - {df['tokens_per_second'].max():.2f} tokens/second\n")
        f.write(f"RTF range: {df['rtf'].min():.2f}x - {df['rtf'].max():.2f}x\n")

    plt.style.use("dark_background")

    # Plot Processing Time vs Token Count
    fig, ax = plt.subplots(figsize=(12, 8))
    sns.scatterplot(data=df, x="tokens", y="processing_time", s=100, alpha=0.6, color="#ff2a6d")
    sns.regplot(data=df, x="tokens", y="processing_time", scatter=False, color="#05d9e8", line_kws={"linewidth": 2})
    corr = df["tokens"].corr(df["processing_time"])
    plt.text(0.05, 0.95, f"Correlation: {corr:.2f}", transform=ax.transAxes, fontsize=10, color="#ffffff",
             bbox=dict(facecolor="#1a1a2e", edgecolor="#ffffff", alpha=0.7))
    setup_plot(fig, ax, "Processing Time vs Input Size")
    ax.set_xlabel("Number of Input Tokens")
    ax.set_ylabel("Processing Time (seconds)")
    plt.savefig("examples/benchmarks/processing_time_rtf.png", dpi=300, bbox_inches="tight")
    plt.close()

    # Plot RTF vs Token Count
    fig, ax = plt.subplots(figsize=(12, 8))
    sns.scatterplot(data=df, x="tokens", y="rtf", s=100, alpha=0.6, color="#ff2a6d")
    sns.regplot(data=df, x="tokens", y="rtf", scatter=False, color="#05d9e8", line_kws={"linewidth": 2})
    corr = df["tokens"].corr(df["rtf"])
    plt.text(0.05, 0.95, f"Correlation: {corr:.2f}", transform=ax.transAxes, fontsize=10, color="#ffffff",
             bbox=dict(facecolor="#1a1a2e", edgecolor="#ffffff", alpha=0.7))
    setup_plot(fig, ax, "Real-Time Factor vs Input Size")
    ax.set_xlabel("Number of Input Tokens")
    ax.set_ylabel("Real-Time Factor (processing time / audio length)")
    plt.savefig("examples/benchmarks/realtime_factor_rtf.png", dpi=300, bbox_inches="tight")
    plt.close()

    plot_system_metrics(system_metrics)

    print("\nResults saved to:")
    print("- examples/benchmarks/benchmark_results_rtf.json")
    print("- examples/benchmarks/benchmark_stats_rtf.txt")
    print("- examples/benchmarks/processing_time_rtf.png")
    print("- examples/benchmarks/realtime_factor_rtf.png")
    print("- examples/benchmarks/system_usage_rtf.png")
    print("\nAudio files saved in examples/benchmarks/output/")


if __name__ == "__main__":
    main()
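
To reproduce the CPU artifacts above: start the API (the script targets http://localhost:8880) and run the benchmark from the repository root; it reads examples/benchmarks/the_time_machine_hg_wells.txt and writes the *_rtf.json, *_rtf.txt, and *_rtf.png outputs listed in its final print block.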
examples/benchmarks/processing_time_cpu.png (new binary file, 254 KiB; not shown)
examples/benchmarks/realtime_factor_cpu.png (new binary file, 208 KiB; not shown)