Merge pull request #7 from remsky/feat/onnx-inference

Feat/onnx inference
Added optimization options for ONNX inference
Refactored phonemizer/tokenizer services
Cleaned up benchmarking and check scripts
Added auto-WAV validation
remsky 2025-01-04 02:48:21 -07:00 committed by GitHub
commit fe114c3367
67 changed files with 6407 additions and 1548 deletions


@@ -6,6 +6,7 @@ omit =
Kokoro-82M/*
MagicMock/*
test_*.py
examples/*
[report]
exclude_lines =

.gitignore (vendored)

@@ -1,5 +1,6 @@
output/
output/*
output_audio/*
ui/data/*
*.db
@@ -16,3 +17,10 @@ env/
.coverage
examples/assorted_checks/benchmarks/output_audio/*
examples/assorted_checks/test_combinations/output/*
examples/assorted_checks/test_openai/output/*
examples/assorted_checks/test_voices/output/*
examples/assorted_checks/test_formats/output/*
ui/RepoScreenshot.png


@@ -2,6 +2,20 @@
Notable changes to this project will be documented in this file.
## 2025-01-04
### Added
- ONNX Support:
- Added single batch ONNX support for CPU inference
- Roughly 0.4 RTF (2.4x real-time speed)
### Modified
- Code Refactoring:
- Work on modularizing phonemizer and tokenizer into separate services
- Incorporated these services into a dev endpoint
- Testing and Benchmarking:
- Cleaned up benchmarking scripts
- Cleaned up test scripts
- Added auto-WAV validation scripts
## 2025-01-02
- Audio Format Support:


@@ -10,8 +10,9 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
# Install PyTorch CPU version
RUN pip3 install --no-cache-dir torch==2.5.1 --extra-index-url https://download.pytorch.org/whl/cpu
# Install PyTorch CPU version and ONNX runtime
RUN pip3 install --no-cache-dir torch==2.5.1 --extra-index-url https://download.pytorch.org/whl/cpu && \
pip3 install --no-cache-dir onnxruntime==1.20.1
# Install all other dependencies from requirements.txt
COPY requirements.txt .


@@ -3,8 +3,8 @@
</p>
# Kokoro TTS API
[![Tests](https://img.shields.io/badge/tests-89%20passed-darkgreen)]()
[![Coverage](https://img.shields.io/badge/coverage-80%25-darkgreen)]()
[![Tests](https://img.shields.io/badge/tests-95%20passed-darkgreen)]()
[![Coverage](https://img.shields.io/badge/coverage-72%25-darkgreen)]()
[![Tested at Model Commit](https://img.shields.io/badge/last--tested--model--commit-a67f113-blue)](https://huggingface.co/hexgrad/Kokoro-82M/tree/c3b0d86e2a980e027ef71c28819ea02e351c2667) [![Try on Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Try%20on-Spaces-blue)](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero)
Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model
@@ -187,15 +187,13 @@ Key Performance Metrics:
<summary>GPU Vs. CPU</summary>
```bash
# GPU: Requires NVIDIA GPU with CUDA 12.1 support
# GPU: Requires NVIDIA GPU with CUDA 12.1 support (~35x realtime speed)
docker compose up --build
# CPU: ~10x slower than GPU inference
# CPU: ONNX optimized inference (~2.4x realtime speed)
docker compose -f docker-compose.cpu.yml up --build
```
*Note: CPU Inference is currently a very basic implementation, and not heavily tested*
</details>
<details>


@@ -14,9 +14,18 @@ class Settings(BaseSettings):
output_dir_size_limit_mb: float = 500.0 # Maximum size of output directory in MB
default_voice: str = "af"
model_dir: str = "/app/Kokoro-82M" # Base directory for model files
model_path: str = "kokoro-v0_19.pth"
pytorch_model_path: str = "kokoro-v0_19.pth"
onnx_model_path: str = "kokoro-v0_19.onnx"
voices_dir: str = "voices"
sample_rate: int = 24000
# ONNX Optimization Settings
onnx_num_threads: int = 4 # Number of threads for intra-op parallelism
onnx_inter_op_threads: int = 4 # Number of threads for inter-op parallelism
onnx_execution_mode: str = "parallel" # parallel or sequential
onnx_optimization_level: str = "all" # all, basic, or disabled
onnx_memory_pattern: bool = True # Enable memory pattern optimization
onnx_arena_extend_strategy: str = "kNextPowerOfTwo" # Memory allocation strategy
class Config:
env_file = ".env"
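Since these are pydantic `BaseSettings` fields, the ONNX session options above can be tuned per deployment through environment variables or the `.env` file named in `Config`. A minimal sketch, assuming the module path `api.src.core.config` implied by the imports elsewhere in this diff; the values are illustrative only:

```python
import os

# Hypothetical overrides: pydantic BaseSettings matches environment variables
# to field names case-insensitively, so ONNX_NUM_THREADS -> onnx_num_threads.
os.environ["ONNX_NUM_THREADS"] = "8"
os.environ["ONNX_EXECUTION_MODE"] = "sequential"
os.environ["ONNX_OPTIMIZATION_LEVEL"] = "basic"

from api.src.core.config import Settings

settings = Settings()
print(settings.onnx_num_threads)  # -> 8, if the overrides were picked up
```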

api/src/core/kokoro.py (new file)

@@ -0,0 +1,185 @@
import re
import torch
import phonemizer
def split_num(num):
num = num.group()
if "." in num:
return num
elif ":" in num:
h, m = [int(n) for n in num.split(":")]
if m == 0:
return f"{h} o'clock"
elif m < 10:
return f"{h} oh {m}"
return f"{h} {m}"
year = int(num[:4])
if year < 1100 or year % 1000 < 10:
return num
left, right = num[:2], int(num[2:4])
s = "s" if num.endswith("s") else ""
if 100 <= year % 1000 <= 999:
if right == 0:
return f"{left} hundred{s}"
elif right < 10:
return f"{left} oh {right}{s}"
return f"{left} {right}{s}"
def flip_money(m):
m = m.group()
bill = "dollar" if m[0] == "$" else "pound"
if m[-1].isalpha():
return f"{m[1:]} {bill}s"
elif "." not in m:
s = "" if m[1:] == "1" else "s"
return f"{m[1:]} {bill}{s}"
b, c = m[1:].split(".")
s = "" if b == "1" else "s"
c = int(c.ljust(2, "0"))
coins = (
f"cent{'' if c == 1 else 's'}"
if m[0] == "$"
else ("penny" if c == 1 else "pence")
)
return f"{b} {bill}{s} and {c} {coins}"
def point_num(num):
a, b = num.group().split(".")
return " point ".join([a, " ".join(b)])
def normalize_text(text):
text = text.replace(chr(8216), "'").replace(chr(8217), "'")
text = text.replace("«", chr(8220)).replace("»", chr(8221))
text = text.replace(chr(8220), '"').replace(chr(8221), '"')
text = text.replace("(", "«").replace(")", "»")
for a, b in zip("、。!,:;?", ",.!,:;?"):
text = text.replace(a, b + " ")
text = re.sub(r"[^\S \n]", " ", text)
text = re.sub(r" +", " ", text)
text = re.sub(r"(?<=\n) +(?=\n)", "", text)
text = re.sub(r"\bD[Rr]\.(?= [A-Z])", "Doctor", text)
text = re.sub(r"\b(?:Mr\.|MR\.(?= [A-Z]))", "Mister", text)
text = re.sub(r"\b(?:Ms\.|MS\.(?= [A-Z]))", "Miss", text)
text = re.sub(r"\b(?:Mrs\.|MRS\.(?= [A-Z]))", "Mrs", text)
text = re.sub(r"\betc\.(?! [A-Z])", "etc", text)
text = re.sub(r"(?i)\b(y)eah?\b", r"\1e'a", text)
text = re.sub(
r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)", split_num, text
)
text = re.sub(r"(?<=\d),(?=\d)", "", text)
text = re.sub(
r"(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b",
flip_money,
text,
)
text = re.sub(r"\d*\.\d+", point_num, text)
text = re.sub(r"(?<=\d)-(?=\d)", " to ", text)
text = re.sub(r"(?<=\d)S", " S", text)
text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
text = re.sub(r"(?<=X')S\b", "s", text)
text = re.sub(
r"(?:[A-Za-z]\.){2,} [a-z]", lambda m: m.group().replace(".", "-"), text
)
text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)
return text.strip()
def get_vocab():
_pad = "$"
_punctuation = ';:,.!?¡¿—…"«»“” '
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'"
symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
dicts = {}
for i in range(len((symbols))):
dicts[symbols[i]] = i
return dicts
VOCAB = get_vocab()
def tokenize(ps):
return [i for i in map(VOCAB.get, ps) if i is not None]
phonemizers = dict(
a=phonemizer.backend.EspeakBackend(
language="en-us", preserve_punctuation=True, with_stress=True
),
b=phonemizer.backend.EspeakBackend(
language="en-gb", preserve_punctuation=True, with_stress=True
),
)
def phonemize(text, lang, norm=True):
if norm:
text = normalize_text(text)
ps = phonemizers[lang].phonemize([text])
ps = ps[0] if ps else ""
# https://en.wiktionary.org/wiki/kokoro#English
ps = ps.replace("kəkˈoːɹoʊ", "kˈoʊkəɹoʊ").replace("kəkˈɔːɹəʊ", "kˈəʊkəɹəʊ")
ps = ps.replace("ʲ", "j").replace("r", "ɹ").replace("x", "k").replace("ɬ", "l")
ps = re.sub(r"(?<=[a-zɹː])(?=hˈʌndɹɪd)", " ", ps)
ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»“” ]|$)', "z", ps)
if lang == "a":
ps = re.sub(r"(?<=nˈaɪn)ti(?!ː)", "di", ps)
ps = "".join(filter(lambda p: p in VOCAB, ps))
return ps.strip()
def length_to_mask(lengths):
mask = (
torch.arange(lengths.max())
.unsqueeze(0)
.expand(lengths.shape[0], -1)
.type_as(lengths)
)
mask = torch.gt(mask + 1, lengths.unsqueeze(1))
return mask
@torch.no_grad()
def forward(model, tokens, ref_s, speed):
device = ref_s.device
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
text_mask = length_to_mask(input_lengths).to(device)
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
s = ref_s[:, 128:]
d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
x, _ = model.predictor.lstm(d)
duration = model.predictor.duration_proj(x)
duration = torch.sigmoid(duration).sum(axis=-1) / speed
pred_dur = torch.round(duration).clamp(min=1).long()
pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
c_frame = 0
for i in range(pred_aln_trg.size(0)):
pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
c_frame += pred_dur[0, i].item()
en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
t_en = model.text_encoder(tokens, input_lengths, text_mask)
asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
def generate(model, text, voicepack, lang="a", speed=1):
ps = phonemize(text, lang)
tokens = tokenize(ps)
if not tokens:
return None
elif len(tokens) > 510:
tokens = tokens[:510]
print("Truncated to 510 tokens")
ref_s = voicepack[len(tokens)]
out = forward(model, tokens, ref_s, speed)
ps = "".join(next(k for k, v in VOCAB.items() if i == v) for i in tokens)
return out, ps


@@ -10,8 +10,10 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from .core.config import settings
from .services.tts import TTSModel, TTSService
from .services.tts_model import TTSModel
from .services.tts_service import TTSService
from .routers.openai_compatible import router as openai_router
from .routers.text_processing import router as text_router
@asynccontextmanager
@@ -20,8 +22,8 @@ async def lifespan(app: FastAPI):
logger.info("Loading TTS model and voice packs...")
# Initialize the main model with warm-up
model, voicepack_count = TTSModel.initialize()
logger.info(f"Model loaded and warmed up on {TTSModel._device}")
voicepack_count = TTSModel.setup()
logger.info(f"Model loaded and warmed up on {TTSModel.get_device()}")
logger.info(f"{voicepack_count} voice packs loaded successfully")
yield
@@ -44,8 +46,9 @@ app.add_middleware(
allow_headers=["*"],
)
# Include OpenAI compatible router
# Include routers
app.include_router(openai_router, prefix="/v1")
app.include_router(text_router)
# Health check endpoint


@@ -3,7 +3,7 @@ from typing import List
from loguru import logger
from fastapi import Depends, Response, APIRouter, HTTPException
from ..services.tts import TTSService
from ..services.tts_service import TTSService
from ..services.audio import AudioService
from ..structures.schemas import OpenAISpeechRequest
@@ -15,9 +15,7 @@ router = APIRouter(
def get_tts_service() -> TTSService:
"""Dependency to get TTSService instance with database session"""
return TTSService(
start_worker=False
) # Don't start worker thread for OpenAI endpoint
return TTSService() # Initialize TTSService with default settings
@router.post("/audio/speech")


@@ -0,0 +1,30 @@
from fastapi import APIRouter
from ..structures.text_schemas import PhonemeRequest, PhonemeResponse
from ..services.text_processing import phonemize, tokenize
router = APIRouter(
prefix="/text",
tags=["text processing"]
)
@router.post("/phonemize", response_model=PhonemeResponse)
async def phonemize_text(request: PhonemeRequest) -> PhonemeResponse:
"""Convert text to phonemes and tokens: Rough attempt
Args:
request: Request containing text and language
Returns:
Phonemes and token IDs
"""
# Get phonemes
phonemes = phonemize(request.text, request.language)
# Get tokens
tokens = tokenize(phonemes)
tokens = [0] + tokens + [0] # Add start/end tokens
return PhonemeResponse(
phonemes=phonemes,
tokens=tokens
)
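Once the service is running, the new dev endpoint can be exercised directly. A minimal sketch, assuming the API is reachable at `localhost:8880` (the port is an assumption, not taken from this diff); the router's `/text` prefix plus the route name gives `/text/phonemize`:

```python
import requests

# Hypothetical local call to the /text/phonemize endpoint defined above.
resp = requests.post(
    "http://localhost:8880/text/phonemize",
    json={"text": "Hello world", "language": "a"},  # "a" = US English
)
resp.raise_for_status()
data = resp.json()
print(data["phonemes"])  # phoneme string
print(data["tokens"])    # token IDs, wrapped in start/end 0 tokens
```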


@@ -1,3 +1,3 @@
from .tts import TTSModel, TTSService
from .tts_service import TTSService
__all__ = ["TTSService", "TTSModel"]
__all__ = ["TTSService"]


@@ -0,0 +1,13 @@
from .normalizer import normalize_text
from .phonemizer import phonemize, PhonemizerBackend, EspeakBackend
from .vocabulary import tokenize, decode_tokens, VOCAB
__all__ = [
'normalize_text',
'phonemize',
'tokenize',
'decode_tokens',
'VOCAB',
'PhonemizerBackend',
'EspeakBackend'
]


@@ -0,0 +1,111 @@
import re
def split_num(num: re.Match) -> str:
"""Handle number splitting for various formats"""
num = num.group()
if "." in num:
return num
elif ":" in num:
h, m = [int(n) for n in num.split(":")]
if m == 0:
return f"{h} o'clock"
elif m < 10:
return f"{h} oh {m}"
return f"{h} {m}"
year = int(num[:4])
if year < 1100 or year % 1000 < 10:
return num
left, right = num[:2], int(num[2:4])
s = "s" if num.endswith("s") else ""
if 100 <= year % 1000 <= 999:
if right == 0:
return f"{left} hundred{s}"
elif right < 10:
return f"{left} oh {right}{s}"
return f"{left} {right}{s}"
def handle_money(m: re.Match) -> str:
"""Convert money expressions to spoken form"""
m = m.group()
bill = "dollar" if m[0] == "$" else "pound"
if m[-1].isalpha():
return f"{m[1:]} {bill}s"
elif "." not in m:
s = "" if m[1:] == "1" else "s"
return f"{m[1:]} {bill}{s}"
b, c = m[1:].split(".")
s = "" if b == "1" else "s"
c = int(c.ljust(2, "0"))
coins = (
f"cent{'' if c == 1 else 's'}"
if m[0] == "$"
else ("penny" if c == 1 else "pence")
)
return f"{b} {bill}{s} and {c} {coins}"
def handle_decimal(num: re.Match) -> str:
"""Convert decimal numbers to spoken form"""
a, b = num.group().split(".")
return " point ".join([a, " ".join(b)])
def normalize_text(text: str) -> str:
"""Normalize text for TTS processing
Args:
text: Input text to normalize
Returns:
Normalized text
"""
# Replace quotes and brackets
text = text.replace(chr(8216), "'").replace(chr(8217), "'")
text = text.replace("«", chr(8220)).replace("»", chr(8221))
text = text.replace(chr(8220), '"').replace(chr(8221), '"')
text = text.replace("(", "«").replace(")", "»")
# Handle CJK punctuation
for a, b in zip("、。!,:;?", ",.!,:;?"):
text = text.replace(a, b + " ")
# Clean up whitespace
text = re.sub(r"[^\S \n]", " ", text)
text = re.sub(r" +", " ", text)
text = re.sub(r"(?<=\n) +(?=\n)", "", text)
# Handle titles and abbreviations
text = re.sub(r"\bD[Rr]\.(?= [A-Z])", "Doctor", text)
text = re.sub(r"\b(?:Mr\.|MR\.(?= [A-Z]))", "Mister", text)
text = re.sub(r"\b(?:Ms\.|MS\.(?= [A-Z]))", "Miss", text)
text = re.sub(r"\b(?:Mrs\.|MRS\.(?= [A-Z]))", "Mrs", text)
text = re.sub(r"\betc\.(?! [A-Z])", "etc", text)
# Handle common words
text = re.sub(r"(?i)\b(y)eah?\b", r"\1e'a", text)
# Handle numbers and money
text = re.sub(
r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)",
split_num,
text
)
text = re.sub(r"(?<=\d),(?=\d)", "", text)
text = re.sub(
r"(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b",
handle_money,
text,
)
text = re.sub(r"\d*\.\d+", handle_decimal, text)
# Handle various formatting
text = re.sub(r"(?<=\d)-(?=\d)", " to ", text)
text = re.sub(r"(?<=\d)S", " S", text)
text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
text = re.sub(r"(?<=X')S\b", "s", text)
text = re.sub(
r"(?:[A-Za-z]\.){2,} [a-z]",
lambda m: m.group().replace(".", "-"),
text
)
text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)
return text.strip()


@@ -0,0 +1,97 @@
import re
from abc import ABC, abstractmethod
import phonemizer
from .normalizer import normalize_text
class PhonemizerBackend(ABC):
"""Abstract base class for phonemization backends"""
@abstractmethod
def phonemize(self, text: str) -> str:
"""Convert text to phonemes
Args:
text: Text to convert to phonemes
Returns:
Phonemized text
"""
pass
class EspeakBackend(PhonemizerBackend):
"""Espeak-based phonemizer implementation"""
def __init__(self, language: str):
"""Initialize espeak backend
Args:
language: Language code ('en-us' or 'en-gb')
"""
self.backend = phonemizer.backend.EspeakBackend(
language=language,
preserve_punctuation=True,
with_stress=True
)
self.language = language
def phonemize(self, text: str) -> str:
"""Convert text to phonemes using espeak
Args:
text: Text to convert to phonemes
Returns:
Phonemized text
"""
# Phonemize text
ps = self.backend.phonemize([text])
ps = ps[0] if ps else ""
# Handle special cases
ps = ps.replace("kəkˈoːɹoʊ", "kˈoʊkəɹoʊ").replace("kəkˈɔːɹəʊ", "kˈəʊkəɹəʊ")
ps = ps.replace("ʲ", "j").replace("r", "ɹ").replace("x", "k").replace("ɬ", "l")
ps = re.sub(r"(?<=[a-zɹː])(?=hˈʌndɹɪd)", " ", ps)
ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»“” ]|$)', "z", ps)
# Language-specific rules
if self.language == "en-us":
ps = re.sub(r"(?<=nˈaɪn)ti(?!ː)", "di", ps)
return ps.strip()
def create_phonemizer(language: str = "a") -> PhonemizerBackend:
"""Factory function to create phonemizer backend
Args:
language: Language code ('a' for US English, 'b' for British English)
Returns:
Phonemizer backend instance
"""
# Map language codes to espeak language codes
lang_map = {
"a": "en-us",
"b": "en-gb"
}
if language not in lang_map:
raise ValueError(f"Unsupported language code: {language}")
return EspeakBackend(lang_map[language])
def phonemize(text: str, language: str = "a", normalize: bool = True) -> str:
"""Convert text to phonemes
Args:
text: Text to convert to phonemes
language: Language code ('a' for US English, 'b' for British English)
normalize: Whether to normalize text before phonemization
Returns:
Phonemized text
"""
if normalize:
text = normalize_text(text)
phonemizer = create_phonemizer(language)
return phonemizer.phonemize(text)


@@ -0,0 +1,37 @@
def get_vocab():
"""Get the vocabulary dictionary mapping characters to token IDs"""
_pad = "$"
_punctuation = ';:,.!?¡¿—…"«»“” '
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'"
# Create vocabulary dictionary
symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
return {symbol: i for i, symbol in enumerate(symbols)}
# Initialize vocabulary
VOCAB = get_vocab()
def tokenize(phonemes: str) -> list[int]:
"""Convert phonemes string to token IDs
Args:
phonemes: String of phonemes to tokenize
Returns:
List of token IDs
"""
return [i for i in map(VOCAB.get, phonemes) if i is not None]
def decode_tokens(tokens: list[int]) -> str:
"""Convert token IDs back to phonemes string
Args:
tokens: List of token IDs
Returns:
String of phonemes
"""
# Create reverse mapping
id_to_symbol = {i: s for s, i in VOCAB.items()}
return "".join(id_to_symbol[t] for t in tokens)


@@ -1,286 +0,0 @@
import io
import os
import re
import time
import threading
from typing import List, Tuple, Optional
import numpy as np
import torch
import tiktoken
import scipy.io.wavfile as wavfile
from kokoro import generate, tokenize, phonemize, normalize_text
from loguru import logger
from models import build_model
from ..core.config import settings
enc = tiktoken.get_encoding("cl100k_base")
class TTSModel:
_instance = None
_device = None
_lock = threading.Lock()
# Directory for all voices (copied base voices, and any created combined voices)
VOICES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "voices")
@classmethod
def initialize(cls):
"""Initialize and warm up the model"""
with cls._lock:
if cls._instance is None:
# Initialize model
cls._device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Initializing model on {cls._device}")
model_path = os.path.join(settings.model_dir, settings.model_path)
model = build_model(model_path, cls._device)
cls._instance = model
# Ensure voices directory exists
os.makedirs(cls.VOICES_DIR, exist_ok=True)
# Copy base voices to local directory
base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
if os.path.exists(base_voices_dir):
for file in os.listdir(base_voices_dir):
if file.endswith(".pt"):
voice_name = file[:-3]
voice_path = os.path.join(cls.VOICES_DIR, file)
if not os.path.exists(voice_path):
try:
logger.info(
f"Copying base voice {voice_name} to voices directory"
)
base_path = os.path.join(base_voices_dir, file)
voicepack = torch.load(
base_path,
map_location=cls._device,
weights_only=True,
)
torch.save(voicepack, voice_path)
except Exception as e:
logger.error(
f"Error copying voice {voice_name}: {str(e)}"
)
# Warm up with default voice
try:
dummy_text = "Hello"
voice_path = os.path.join(cls.VOICES_DIR, "af.pt")
dummy_voicepack = torch.load(
voice_path, map_location=cls._device, weights_only=True
)
generate(model, dummy_text, dummy_voicepack, lang="a", speed=1.0)
logger.info("Model warm-up complete")
except Exception as e:
logger.warning(f"Model warm-up failed: {e}")
# Count voices in directory for validation
voice_count = len(
[f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")]
)
return cls._instance, voice_count
@classmethod
def get_instance(cls):
"""Get the initialized instance or raise an error"""
if cls._instance is None:
raise RuntimeError("Model not initialized. Call initialize() first.")
return cls._instance, cls._device
class TTSService:
def __init__(self, output_dir: str = None, start_worker: bool = False):
self.output_dir = output_dir
self._ensure_voices()
if start_worker:
self.start_worker()
def _ensure_voices(self):
"""Copy base voices to local voices directory during initialization"""
os.makedirs(TTSModel.VOICES_DIR, exist_ok=True)
base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
if os.path.exists(base_voices_dir):
for file in os.listdir(base_voices_dir):
if file.endswith(".pt"):
voice_name = file[:-3]
voice_path = os.path.join(TTSModel.VOICES_DIR, file)
if not os.path.exists(voice_path):
try:
logger.info(
f"Copying base voice {voice_name} to voices directory"
)
base_path = os.path.join(base_voices_dir, file)
voicepack = torch.load(
base_path,
map_location=TTSModel._device,
weights_only=True,
)
torch.save(voicepack, voice_path)
except Exception as e:
logger.error(f"Error copying voice {voice_name}: {str(e)}")
def _split_text(self, text: str) -> List[str]:
"""Split text into sentences"""
return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]
def _get_voice_path(self, voice_name: str) -> Optional[str]:
"""Get the path to a voice file.
Args:
voice_name: Name of the voice to find
Returns:
Path to the voice file if found, None otherwise
"""
voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice_name}.pt")
return voice_path if os.path.exists(voice_path) else None
def _generate_audio(
self, text: str, voice: str, speed: float, stitch_long_output: bool = True
) -> Tuple[torch.Tensor, float]:
"""Generate audio and measure processing time"""
start_time = time.time()
try:
# Normalize text once at the start
text = normalize_text(text)
if not text:
raise ValueError("Text is empty after preprocessing")
# Check voice exists
voice_path = self._get_voice_path(voice)
if not voice_path:
raise ValueError(f"Voice not found: {voice}")
# Load model and voice
model = TTSModel._instance
voicepack = torch.load(
voice_path, map_location=TTSModel._device, weights_only=True
)
# Generate audio with or without stitching
if stitch_long_output:
chunks = self._split_text(text)
audio_chunks = []
# Process all chunks with same model/voicepack instance
for i, chunk in enumerate(chunks):
try:
# Validate phonemization first
# ps = phonemize(chunk, voice[0])
# tokens = tokenize(ps)
# logger.debug(
# f"Processing chunk {i + 1}/{len(chunks)}: {len(tokens)} tokens"
# )
# Only proceed if phonemization succeeded
chunk_audio, _ = generate(
model, chunk, voicepack, lang=voice[0], speed=speed
)
if chunk_audio is not None:
audio_chunks.append(chunk_audio)
else:
logger.error(
f"No audio generated for chunk {i + 1}/{len(chunks)}"
)
except Exception as e:
logger.error(
f"Failed to generate audio for chunk {i + 1}/{len(chunks)}: '{chunk}'. Error: {str(e)}"
)
continue
if not audio_chunks:
raise ValueError("No audio chunks were generated successfully")
audio = (
np.concatenate(audio_chunks)
if len(audio_chunks) > 1
else audio_chunks[0]
)
else:
audio, _ = generate(model, text, voicepack, lang=voice[0], speed=speed)
processing_time = time.time() - start_time
return audio, processing_time
except Exception as e:
print(f"Error in audio generation: {str(e)}")
raise
def _save_audio(self, audio: torch.Tensor, filepath: str):
"""Save audio to file"""
os.makedirs(os.path.dirname(filepath), exist_ok=True)
wavfile.write(filepath, 24000, audio)
def _audio_to_bytes(self, audio: torch.Tensor) -> bytes:
"""Convert audio tensor to WAV bytes"""
buffer = io.BytesIO()
wavfile.write(buffer, 24000, audio)
return buffer.getvalue()
def combine_voices(self, voices: List[str]) -> str:
"""Combine multiple voices into a new voice.
Args:
voices: List of voice names to combine
Returns:
Name of the combined voice
Raises:
ValueError: If less than 2 voices provided or voice loading fails
RuntimeError: If voice combination or saving fails
"""
if len(voices) < 2:
raise ValueError("At least 2 voices are required for combination")
# Load voices
t_voices: List[torch.Tensor] = []
v_name: List[str] = []
for voice in voices:
try:
voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice}.pt")
voicepack = torch.load(
voice_path, map_location=TTSModel._device, weights_only=True
)
t_voices.append(voicepack)
v_name.append(voice)
except Exception as e:
raise ValueError(f"Failed to load voice {voice}: {str(e)}")
# Combine voices
try:
f: str = "_".join(v_name)
v = torch.mean(torch.stack(t_voices), dim=0)
combined_path = os.path.join(TTSModel.VOICES_DIR, f"{f}.pt")
# Save combined voice
try:
torch.save(v, combined_path)
except Exception as e:
raise RuntimeError(
f"Failed to save combined voice to {combined_path}: {str(e)}"
)
return f
except Exception as e:
if not isinstance(e, (ValueError, RuntimeError)):
raise RuntimeError(f"Error combining voices: {str(e)}")
raise
def list_voices(self) -> List[str]:
"""List all available voices"""
voices = []
try:
for file in os.listdir(TTSModel.VOICES_DIR):
if file.endswith(".pt"):
voices.append(file[:-3]) # Remove .pt extension
except Exception as e:
logger.error(f"Error listing voices: {str(e)}")
return sorted(voices)


@@ -0,0 +1,136 @@
import os
import threading
from abc import ABC, abstractmethod
from typing import List, Tuple
import torch
import numpy as np
from loguru import logger
from ..core.config import settings
class TTSBaseModel(ABC):
_instance = None
_lock = threading.Lock()
_device = None
VOICES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "voices")
@classmethod
def setup(cls):
"""Initialize model and setup voices"""
with cls._lock:
# Set device
cuda_available = torch.cuda.is_available()
logger.info(f"CUDA available: {cuda_available}")
if cuda_available:
try:
# Test CUDA device
test_tensor = torch.zeros(1).cuda()
logger.info("CUDA test successful")
model_path = os.path.join(settings.model_dir, settings.pytorch_model_path)
cls._device = "cuda"
except Exception as e:
logger.error(f"CUDA test failed: {e}")
cls._device = "cpu"
else:
cls._device = "cpu"
model_path = os.path.join(settings.model_dir, settings.onnx_model_path)
logger.info(f"Initializing model on {cls._device}")
# Initialize model
if not cls.initialize(settings.model_dir, model_path=model_path):
raise RuntimeError(f"Failed to initialize {cls._device.upper()} model")
# Setup voices directory
os.makedirs(cls.VOICES_DIR, exist_ok=True)
# Copy base voices to local directory
base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
if os.path.exists(base_voices_dir):
for file in os.listdir(base_voices_dir):
if file.endswith(".pt"):
voice_name = file[:-3]
voice_path = os.path.join(cls.VOICES_DIR, file)
if not os.path.exists(voice_path):
try:
logger.info(f"Copying base voice {voice_name} to voices directory")
base_path = os.path.join(base_voices_dir, file)
voicepack = torch.load(base_path, map_location=cls._device, weights_only=True)
torch.save(voicepack, voice_path)
except Exception as e:
logger.error(f"Error copying voice {voice_name}: {str(e)}")
# Warm up with default voice
try:
dummy_text = "Hello"
voice_path = os.path.join(cls.VOICES_DIR, "af.pt")
dummy_voicepack = torch.load(voice_path, map_location=cls._device, weights_only=True)
# Process text and generate audio
phonemes, tokens = cls.process_text(dummy_text, "a")
cls.generate_from_tokens(tokens, dummy_voicepack, 1.0)
logger.info("Model warm-up complete")
except Exception as e:
logger.warning(f"Model warm-up failed: {e}")
# Count voices in directory
voice_count = len([f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")])
return voice_count
@classmethod
@abstractmethod
def initialize(cls, model_dir: str, model_path: str = None):
"""Initialize the model"""
pass
@classmethod
@abstractmethod
def process_text(cls, text: str, language: str) -> Tuple[str, List[int]]:
"""Process text into phonemes and tokens
Args:
text: Input text
language: Language code
Returns:
tuple[str, list[int]]: Phonemes and token IDs
"""
pass
@classmethod
@abstractmethod
def generate_from_text(cls, text: str, voicepack: torch.Tensor, language: str, speed: float) -> Tuple[np.ndarray, str]:
"""Generate audio from text
Args:
text: Input text
voicepack: Voice tensor
language: Language code
speed: Speed factor
Returns:
tuple[np.ndarray, str]: Generated audio samples and phonemes
"""
pass
@classmethod
@abstractmethod
def generate_from_tokens(cls, tokens: List[int], voicepack: torch.Tensor, speed: float) -> np.ndarray:
"""Generate audio from tokens
Args:
tokens: Token IDs
voicepack: Voice tensor
speed: Speed factor
Returns:
np.ndarray: Generated audio samples
"""
pass
@classmethod
def get_device(cls):
"""Get the current device"""
if cls._device is None:
raise RuntimeError("Model not initialized. Call setup() first.")
return cls._device

api/src/services/tts_cpu.py (new file)

@@ -0,0 +1,144 @@
import os
import numpy as np
import torch
from onnxruntime import InferenceSession, SessionOptions, GraphOptimizationLevel, ExecutionMode
from loguru import logger
from .tts_base import TTSBaseModel
from .text_processing import phonemize, tokenize
from ..core.config import settings
class TTSCPUModel(TTSBaseModel):
_instance = None
_onnx_session = None
@classmethod
def initialize(cls, model_dir: str, model_path: str = None):
"""Initialize ONNX model for CPU inference"""
if cls._onnx_session is None:
# Try loading ONNX model
onnx_path = os.path.join(model_dir, settings.onnx_model_path)
if os.path.exists(onnx_path):
logger.info(f"Loading ONNX model from {onnx_path}")
else:
logger.error(f"ONNX model not found at {onnx_path}")
return None
if not onnx_path:
return None
logger.info(f"Loading ONNX model from {onnx_path}")
# Configure ONNX session for optimal performance
session_options = SessionOptions()
# Set optimization level
if settings.onnx_optimization_level == "all":
session_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
elif settings.onnx_optimization_level == "basic":
session_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
else:
session_options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
# Configure threading
session_options.intra_op_num_threads = settings.onnx_num_threads
session_options.inter_op_num_threads = settings.onnx_inter_op_threads
# Set execution mode
session_options.execution_mode = (
ExecutionMode.ORT_PARALLEL
if settings.onnx_execution_mode == "parallel"
else ExecutionMode.ORT_SEQUENTIAL
)
# Enable/disable memory pattern optimization
session_options.enable_mem_pattern = settings.onnx_memory_pattern
# Configure CPU provider options
provider_options = {
'CPUExecutionProvider': {
'arena_extend_strategy': settings.onnx_arena_extend_strategy,
'cpu_memory_arena_cfg': 'cpu:0'
}
}
cls._onnx_session = InferenceSession(
onnx_path,
sess_options=session_options,
providers=['CPUExecutionProvider'],
provider_options=[provider_options]
)
return cls._onnx_session
return cls._onnx_session
@classmethod
def process_text(cls, text: str, language: str) -> tuple[str, list[int]]:
"""Process text into phonemes and tokens
Args:
text: Input text
language: Language code
Returns:
tuple[str, list[int]]: Phonemes and token IDs
"""
phonemes = phonemize(text, language)
tokens = tokenize(phonemes)
tokens = [0] + tokens + [0] # Add start/end tokens
return phonemes, tokens
@classmethod
def generate_from_text(cls, text: str, voicepack: torch.Tensor, language: str, speed: float) -> tuple[np.ndarray, str]:
"""Generate audio from text
Args:
text: Input text
voicepack: Voice tensor
language: Language code
speed: Speed factor
Returns:
tuple[np.ndarray, str]: Generated audio samples and phonemes
"""
if cls._onnx_session is None:
raise RuntimeError("ONNX model not initialized")
# Process text
phonemes, tokens = cls.process_text(text, language)
# Generate audio
audio = cls.generate_from_tokens(tokens, voicepack, speed)
return audio, phonemes
@classmethod
def generate_from_tokens(cls, tokens: list[int], voicepack: torch.Tensor, speed: float) -> np.ndarray:
"""Generate audio from tokens
Args:
tokens: Token IDs
voicepack: Voice tensor
speed: Speed factor
Returns:
np.ndarray: Generated audio samples
"""
if cls._onnx_session is None:
raise RuntimeError("ONNX model not initialized")
# Pre-allocate and prepare inputs
tokens_input = np.array([tokens], dtype=np.int64)
style_input = voicepack[len(tokens)-2].numpy() # Already has correct dimensions
speed_input = np.full(1, speed, dtype=np.float32) # More efficient than ones * speed
# Run inference with optimized inputs
result = cls._onnx_session.run(
None,
{
'tokens': tokens_input,
'style': style_input,
'speed': speed_input
}
)
return result[0]

api/src/services/tts_gpu.py (new file)

@@ -0,0 +1,127 @@
import os
import numpy as np
import torch
from loguru import logger
from models import build_model
from .text_processing import phonemize, tokenize
from .tts_base import TTSBaseModel
from ..core.config import settings
@torch.no_grad()
def forward(model, tokens, ref_s, speed):
"""Forward pass through the model"""
device = ref_s.device
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
text_mask = length_to_mask(input_lengths).to(device)
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
s = ref_s[:, 128:]
d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
x, _ = model.predictor.lstm(d)
duration = model.predictor.duration_proj(x)
duration = torch.sigmoid(duration).sum(axis=-1) / speed
pred_dur = torch.round(duration).clamp(min=1).long()
pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
c_frame = 0
for i in range(pred_aln_trg.size(0)):
pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
c_frame += pred_dur[0, i].item()
en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
t_en = model.text_encoder(tokens, input_lengths, text_mask)
asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
def length_to_mask(lengths):
"""Create attention mask from lengths"""
mask = (
torch.arange(lengths.max())
.unsqueeze(0)
.expand(lengths.shape[0], -1)
.type_as(lengths)
)
mask = torch.gt(mask + 1, lengths.unsqueeze(1))
return mask
class TTSGPUModel(TTSBaseModel):
_instance = None
_device = "cuda"
@classmethod
def initialize(cls, model_dir: str, model_path: str):
"""Initialize PyTorch model for GPU inference"""
if cls._instance is None and torch.cuda.is_available():
try:
logger.info("Initializing GPU model")
model_path = os.path.join(model_dir, settings.pytorch_model_path)
model = build_model(model_path, cls._device)
cls._instance = model
return cls._instance
except Exception as e:
logger.error(f"Failed to initialize GPU model: {e}")
return None
return cls._instance
@classmethod
def process_text(cls, text: str, language: str) -> tuple[str, list[int]]:
"""Process text into phonemes and tokens
Args:
text: Input text
language: Language code
Returns:
tuple[str, list[int]]: Phonemes and token IDs
"""
phonemes = phonemize(text, language)
tokens = tokenize(phonemes)
return phonemes, tokens
@classmethod
def generate_from_text(cls, text: str, voicepack: torch.Tensor, language: str, speed: float) -> tuple[np.ndarray, str]:
"""Generate audio from text
Args:
text: Input text
voicepack: Voice tensor
language: Language code
speed: Speed factor
Returns:
tuple[np.ndarray, str]: Generated audio samples and phonemes
"""
if cls._instance is None:
raise RuntimeError("GPU model not initialized")
# Process text
phonemes, tokens = cls.process_text(text, language)
# Generate audio
audio = cls.generate_from_tokens(tokens, voicepack, speed)
return audio, phonemes
@classmethod
def generate_from_tokens(cls, tokens: list[int], voicepack: torch.Tensor, speed: float) -> np.ndarray:
"""Generate audio from tokens
Args:
tokens: Token IDs
voicepack: Voice tensor
speed: Speed factor
Returns:
np.ndarray: Generated audio samples
"""
if cls._instance is None:
raise RuntimeError("GPU model not initialized")
# Get reference style
ref_s = voicepack[len(tokens)]
# Generate audio
audio = forward(cls._instance, tokens, ref_s, speed)
return audio


@@ -0,0 +1,8 @@
import torch
if torch.cuda.is_available():
from .tts_gpu import TTSGPUModel as TTSModel
else:
from .tts_cpu import TTSCPUModel as TTSModel
__all__ = ["TTSModel"]


@@ -0,0 +1,161 @@
import io
import os
import re
import time
from typing import List, Tuple, Optional
import numpy as np
import torch
import scipy.io.wavfile as wavfile
from .text_processing import normalize_text
from loguru import logger
from ..core.config import settings
from .tts_model import TTSModel
class TTSService:
def __init__(self, output_dir: str = None):
self.output_dir = output_dir
def _split_text(self, text: str) -> List[str]:
"""Split text into sentences"""
if not isinstance(text, str):
text = str(text) if text is not None else ""
return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]
def _get_voice_path(self, voice_name: str) -> Optional[str]:
"""Get the path to a voice file"""
voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice_name}.pt")
return voice_path if os.path.exists(voice_path) else None
def _generate_audio(
self, text: str, voice: str, speed: float, stitch_long_output: bool = True
) -> Tuple[torch.Tensor, float]:
"""Generate audio and measure processing time"""
start_time = time.time()
try:
# Normalize text once at the start
if not text:
raise ValueError("Text is empty after preprocessing")
normalized = normalize_text(text)
if not normalized:
raise ValueError("Text is empty after preprocessing")
text = str(normalized)
# Check voice exists
voice_path = self._get_voice_path(voice)
if not voice_path:
raise ValueError(f"Voice not found: {voice}")
# Load voice
voicepack = torch.load(
voice_path, map_location=TTSModel.get_device(), weights_only=True
)
# Generate audio with or without stitching
if stitch_long_output:
chunks = self._split_text(text)
audio_chunks = []
# Process all chunks
for i, chunk in enumerate(chunks):
try:
# Process text and generate audio
phonemes, tokens = TTSModel.process_text(chunk, voice[0])
chunk_audio = TTSModel.generate_from_tokens(tokens, voicepack, speed)
if chunk_audio is not None:
audio_chunks.append(chunk_audio)
else:
logger.error(f"No audio generated for chunk {i + 1}/{len(chunks)}")
except Exception as e:
logger.error(
f"Failed to generate audio for chunk {i + 1}/{len(chunks)}: '{chunk}'. Error: {str(e)}"
)
continue
if not audio_chunks:
raise ValueError("No audio chunks were generated successfully")
audio = (
np.concatenate(audio_chunks)
if len(audio_chunks) > 1
else audio_chunks[0]
)
else:
# Process single chunk
phonemes, tokens = TTSModel.process_text(text, voice[0])
audio = TTSModel.generate_from_tokens(tokens, voicepack, speed)
processing_time = time.time() - start_time
return audio, processing_time
except Exception as e:
logger.error(f"Error in audio generation: {str(e)}")
raise
def _save_audio(self, audio: torch.Tensor, filepath: str):
"""Save audio to file"""
os.makedirs(os.path.dirname(filepath), exist_ok=True)
wavfile.write(filepath, 24000, audio)
def _audio_to_bytes(self, audio: torch.Tensor) -> bytes:
"""Convert audio tensor to WAV bytes"""
buffer = io.BytesIO()
wavfile.write(buffer, 24000, audio)
return buffer.getvalue()
def combine_voices(self, voices: List[str]) -> str:
"""Combine multiple voices into a new voice"""
if len(voices) < 2:
raise ValueError("At least 2 voices are required for combination")
# Load voices
t_voices: List[torch.Tensor] = []
v_name: List[str] = []
for voice in voices:
try:
voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice}.pt")
voicepack = torch.load(
voice_path, map_location=TTSModel.get_device(), weights_only=True
)
t_voices.append(voicepack)
v_name.append(voice)
except Exception as e:
raise ValueError(f"Failed to load voice {voice}: {str(e)}")
# Combine voices
try:
f: str = "_".join(v_name)
v = torch.mean(torch.stack(t_voices), dim=0)
combined_path = os.path.join(TTSModel.VOICES_DIR, f"{f}.pt")
# Save combined voice
try:
torch.save(v, combined_path)
except Exception as e:
raise RuntimeError(
f"Failed to save combined voice to {combined_path}: {str(e)}"
)
return f
except Exception as e:
if not isinstance(e, (ValueError, RuntimeError)):
raise RuntimeError(f"Error combining voices: {str(e)}")
raise
def list_voices(self) -> List[str]:
"""List all available voices"""
voices = []
try:
for file in os.listdir(TTSModel.VOICES_DIR):
if file.endswith(".pt"):
voices.append(file[:-3]) # Remove .pt extension
except Exception as e:
logger.error(f"Error listing voices: {str(e)}")
return sorted(voices)


@@ -0,0 +1,9 @@
from pydantic import BaseModel
class PhonemeRequest(BaseModel):
text: str
language: str = "a" # Default to American English
class PhonemeResponse(BaseModel):
phonemes: str
tokens: list[int]


@@ -21,8 +21,73 @@ def cleanup():
cleanup_mock_dirs()
# Mock torch and other ML modules before they're imported
sys.modules["torch"] = Mock()
# Create mock torch module
mock_torch = Mock()
mock_torch.cuda = Mock()
mock_torch.cuda.is_available = Mock(return_value=False)
# Create a mock tensor class that supports basic operations
class MockTensor:
def __init__(self, data):
self.data = data
if isinstance(data, (list, tuple)):
self.shape = [len(data)]
elif isinstance(data, MockTensor):
self.shape = data.shape
else:
self.shape = getattr(data, 'shape', [1])
def __getitem__(self, idx):
if isinstance(self.data, (list, tuple)):
if isinstance(idx, slice):
return MockTensor(self.data[idx])
return self.data[idx]
return self
def max(self):
if isinstance(self.data, (list, tuple)):
max_val = max(self.data)
return MockTensor(max_val)
return 5 # Default for testing
def item(self):
if isinstance(self.data, (list, tuple)):
return max(self.data)
if isinstance(self.data, (int, float)):
return self.data
return 5 # Default for testing
def cuda(self):
"""Support cuda conversion"""
return self
def any(self):
if isinstance(self.data, (list, tuple)):
return any(self.data)
return False
def all(self):
if isinstance(self.data, (list, tuple)):
return all(self.data)
return True
def unsqueeze(self, dim):
return self
def expand(self, *args):
return self
def type_as(self, other):
return self
# Add tensor operations to mock torch
mock_torch.tensor = lambda x: MockTensor(x)
mock_torch.zeros = lambda *args: MockTensor([0] * (args[0] if isinstance(args[0], int) else args[0][0]))
mock_torch.arange = lambda x: MockTensor(list(range(x)))
mock_torch.gt = lambda x, y: MockTensor([False] * x.shape[0])
# Mock modules before they're imported
sys.modules["torch"] = mock_torch
sys.modules["transformers"] = Mock()
sys.modules["phonemizer"] = Mock()
sys.modules["models"] = Mock()
@@ -31,14 +96,22 @@ sys.modules["kokoro"] = Mock()
sys.modules["kokoro.generate"] = Mock()
sys.modules["kokoro.phonemize"] = Mock()
sys.modules["kokoro.tokenize"] = Mock()
sys.modules["onnxruntime"] = Mock()
@pytest.fixture(autouse=True)
def mock_tts_model():
"""Mock TTSModel to avoid loading real models during tests"""
with patch("api.src.services.tts.TTSModel") as mock:
"""Mock TTSModel and TTS model initialization"""
with patch("api.src.services.tts_model.TTSModel") as mock_tts_model, \
patch("api.src.services.tts_base.TTSBaseModel") as mock_base_model:
# Mock TTSModel
model_instance = Mock()
model_instance.get_instance.return_value = model_instance
model_instance.get_voicepack.return_value = None
mock.get_instance.return_value = model_instance
mock_tts_model.get_instance.return_value = model_instance
# Mock TTS model initialization
mock_base_model.setup.return_value = 1 # Return dummy voice count
yield model_instance


@@ -26,13 +26,11 @@ def test_health_check(test_client):
@patch("api.src.main.logger")
async def test_lifespan_successful_warmup(mock_logger, mock_tts_model):
"""Test successful model warmup in lifespan"""
# Mock the model initialization with model info and voicepack count
mock_model = MagicMock()
# Mock file system for voice counting
mock_tts_model.VOICES_DIR = "/mock/voices"
with patch("os.listdir", return_value=["voice1.pt", "voice2.pt", "voice3.pt"]):
mock_tts_model.initialize.return_value = (mock_model, 3) # 3 voice files
mock_tts_model._device = "cuda" # Set device class variable
mock_tts_model.setup.return_value = 3 # 3 voice files
mock_tts_model.get_device.return_value = "cuda"
# Create an async generator from the lifespan context manager
async_gen = lifespan(MagicMock())
@@ -44,8 +42,8 @@ async def test_lifespan_successful_warmup(mock_logger, mock_tts_model):
mock_logger.info.assert_any_call("Model loaded and warmed up on cuda")
mock_logger.info.assert_any_call("3 voice packs loaded successfully")
# Verify model initialization was called
mock_tts_model.initialize.assert_called_once()
# Verify model setup was called
mock_tts_model.setup.assert_called_once()
# Clean up
await async_gen.__aexit__(None, None, None)
@@ -56,14 +54,14 @@ async def test_lifespan_successful_warmup(mock_logger, mock_tts_model):
@patch("api.src.main.logger")
async def test_lifespan_failed_warmup(mock_logger, mock_tts_model):
"""Test failed model warmup in lifespan"""
# Mock the model initialization to fail
mock_tts_model.initialize.side_effect = Exception("Failed to initialize model")
# Mock the model setup to fail
mock_tts_model.setup.side_effect = RuntimeError("Failed to initialize model")
# Create an async generator from the lifespan context manager
async_gen = lifespan(MagicMock())
# Verify the exception is raised
with pytest.raises(Exception, match="Failed to initialize model"):
with pytest.raises(RuntimeError, match="Failed to initialize model"):
await async_gen.__aenter__()
# Verify the expected logging sequence
@@ -77,20 +75,18 @@ async def test_lifespan_failed_warmup(mock_logger, mock_tts_model):
@patch("api.src.main.TTSModel")
async def test_lifespan_cuda_warmup(mock_tts_model):
"""Test model warmup specifically on CUDA"""
# Mock the model initialization with CUDA and voicepacks
mock_model = MagicMock()
# Mock file system for voice counting
mock_tts_model.VOICES_DIR = "/mock/voices"
with patch("os.listdir", return_value=["voice1.pt", "voice2.pt"]):
mock_tts_model.initialize.return_value = (mock_model, 2) # 2 voice files
mock_tts_model._device = "cuda" # Set device class variable
mock_tts_model.setup.return_value = 2 # 2 voice files
mock_tts_model.get_device.return_value = "cuda"
# Create an async generator from the lifespan context manager
async_gen = lifespan(MagicMock())
await async_gen.__aenter__()
# Verify model was initialized
mock_tts_model.initialize.assert_called_once()
# Verify model setup was called
mock_tts_model.setup.assert_called_once()
# Clean up
await async_gen.__aexit__(None, None, None)
@@ -100,22 +96,20 @@ async def test_lifespan_cuda_warmup(mock_tts_model):
@patch("api.src.main.TTSModel")
async def test_lifespan_cpu_fallback(mock_tts_model):
"""Test model warmup falling back to CPU"""
# Mock the model initialization with CPU and voicepacks
mock_model = MagicMock()
# Mock file system for voice counting
mock_tts_model.VOICES_DIR = "/mock/voices"
with patch(
"os.listdir", return_value=["voice1.pt", "voice2.pt", "voice3.pt", "voice4.pt"]
):
mock_tts_model.initialize.return_value = (mock_model, 4) # 4 voice files
mock_tts_model._device = "cpu" # Set device class variable
mock_tts_model.setup.return_value = 4 # 4 voice files
mock_tts_model.get_device.return_value = "cpu"
# Create an async generator from the lifespan context manager
async_gen = lifespan(MagicMock())
await async_gen.__aenter__()
# Verify model was initialized
mock_tts_model.initialize.assert_called_once()
# Verify model setup was called
mock_tts_model.setup.assert_called_once()
# Clean up
await async_gen.__aexit__(None, None, None)


@@ -0,0 +1,144 @@
"""Tests for TTS model implementations"""
import os
import torch
import pytest
import numpy as np
from unittest.mock import patch, MagicMock
from api.src.services.tts_base import TTSBaseModel
from api.src.services.tts_cpu import TTSCPUModel
from api.src.services.tts_gpu import TTSGPUModel, length_to_mask
# Base Model Tests
def test_get_device_error():
"""Test get_device() raises error when not initialized"""
TTSBaseModel._device = None
with pytest.raises(RuntimeError, match="Model not initialized"):
TTSBaseModel.get_device()
@patch('torch.cuda.is_available')
@patch('os.path.exists')
@patch('os.path.join')
@patch('os.listdir')
@patch('torch.load')
@patch('torch.save')
def test_setup_cuda_available(mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available):
"""Test setup with CUDA available"""
TTSBaseModel._device = None
mock_cuda_available.return_value = True
mock_exists.return_value = True
mock_load.return_value = torch.zeros(1)
mock_listdir.return_value = ["voice1.pt", "voice2.pt"]
mock_join.return_value = "/mocked/path"
# Mock the abstract methods
TTSBaseModel.initialize = MagicMock(return_value=True)
TTSBaseModel.process_text = MagicMock(return_value=("dummy", [1,2,3]))
TTSBaseModel.generate_from_tokens = MagicMock(return_value=np.zeros(1000))
voice_count = TTSBaseModel.setup()
assert TTSBaseModel._device == "cuda"
assert voice_count == 2
@patch('torch.cuda.is_available')
@patch('os.path.exists')
@patch('os.path.join')
@patch('os.listdir')
@patch('torch.load')
@patch('torch.save')
def test_setup_cuda_unavailable(mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available):
"""Test setup with CUDA unavailable"""
TTSBaseModel._device = None
mock_cuda_available.return_value = False
mock_exists.return_value = True
mock_load.return_value = torch.zeros(1)
mock_listdir.return_value = ["voice1.pt", "voice2.pt"]
mock_join.return_value = "/mocked/path"
# Mock the abstract methods
TTSBaseModel.initialize = MagicMock(return_value=True)
TTSBaseModel.process_text = MagicMock(return_value=("dummy", [1,2,3]))
TTSBaseModel.generate_from_tokens = MagicMock(return_value=np.zeros(1000))
voice_count = TTSBaseModel.setup()
assert TTSBaseModel._device == "cpu"
assert voice_count == 2
# CPU Model Tests
def test_cpu_initialize_missing_model():
"""Test CPU initialize with missing model"""
with patch('os.path.exists', return_value=False):
result = TTSCPUModel.initialize("dummy_dir")
assert result is None
def test_cpu_generate_uninitialized():
"""Test CPU generate methods with uninitialized model"""
TTSCPUModel._onnx_session = None
with pytest.raises(RuntimeError, match="ONNX model not initialized"):
TTSCPUModel.generate_from_text("test", torch.zeros(1), "en", 1.0)
with pytest.raises(RuntimeError, match="ONNX model not initialized"):
TTSCPUModel.generate_from_tokens([1,2,3], torch.zeros(1), 1.0)
def test_cpu_process_text():
"""Test CPU process_text functionality"""
with patch('api.src.services.tts_cpu.phonemize') as mock_phonemize, \
patch('api.src.services.tts_cpu.tokenize') as mock_tokenize:
mock_phonemize.return_value = "test phonemes"
mock_tokenize.return_value = [1, 2, 3]
phonemes, tokens = TTSCPUModel.process_text("test", "en")
assert phonemes == "test phonemes"
assert tokens == [0, 1, 2, 3, 0] # Should add start/end tokens
# GPU Model Tests
@patch('torch.cuda.is_available')
def test_gpu_initialize_cuda_unavailable(mock_cuda_available):
"""Test GPU initialize with CUDA unavailable"""
mock_cuda_available.return_value = False
TTSGPUModel._instance = None
result = TTSGPUModel.initialize("dummy_dir", "dummy_path")
assert result is None
@patch('api.src.services.tts_gpu.length_to_mask')
def test_gpu_length_to_mask(mock_length_to_mask):
"""Test length_to_mask function"""
# Setup mock return value
expected_mask = torch.tensor([
[False, False, False, True, True],
[False, False, False, False, False]
])
mock_length_to_mask.return_value = expected_mask
# Call function with test input
lengths = torch.tensor([3, 5])
mask = mock_length_to_mask(lengths)
# Verify mock was called with correct input
mock_length_to_mask.assert_called_once()
assert torch.equal(mask, expected_mask)
def test_gpu_generate_uninitialized():
"""Test GPU generate methods with uninitialized model"""
TTSGPUModel._instance = None
with pytest.raises(RuntimeError, match="GPU model not initialized"):
TTSGPUModel.generate_from_text("test", torch.zeros(1), "en", 1.0)
with pytest.raises(RuntimeError, match="GPU model not initialized"):
TTSGPUModel.generate_from_tokens([1,2,3], torch.zeros(1), 1.0)
def test_gpu_process_text():
"""Test GPU process_text functionality"""
with patch('api.src.services.tts_gpu.phonemize') as mock_phonemize, \
patch('api.src.services.tts_gpu.tokenize') as mock_tokenize:
mock_phonemize.return_value = "test phonemes"
mock_tokenize.return_value = [1, 2, 3]
phonemes, tokens = TTSGPUModel.process_text("test", "en")
assert phonemes == "test phonemes"
assert tokens == [1, 2, 3] # GPU implementation doesn't add start/end tokens


@@ -6,14 +6,19 @@ from unittest.mock import MagicMock, call, patch
import numpy as np
import torch
import pytest
from onnxruntime import InferenceSession
from api.src.services.tts import TTSModel, TTSService
from api.src.core.config import settings
from api.src.services.tts_model import TTSModel
from api.src.services.tts_service import TTSService
from api.src.services.tts_cpu import TTSCPUModel
from api.src.services.tts_gpu import TTSGPUModel
@pytest.fixture
def tts_service():
"""Create a TTSService instance for testing"""
return TTSService(start_worker=False)
return TTSService()
@pytest.fixture
@@ -68,80 +73,143 @@ def test_list_voices(mock_join, mock_listdir, tts_service):
assert "not_a_voice" not in voices
@patch("api.src.services.tts.TTSModel.get_instance")
@patch("api.src.services.tts.TTSModel.get_voicepack")
@patch("api.src.services.tts.normalize_text")
@patch("api.src.services.tts.phonemize")
@patch("api.src.services.tts.tokenize")
@patch("api.src.services.tts.generate")
def test_generate_audio_empty_text(
mock_generate,
mock_tokenize,
mock_phonemize,
mock_normalize,
mock_voicepack,
mock_instance,
tts_service,
):
"""Test generating audio with empty text"""
mock_normalize.return_value = ""
@patch("os.listdir")
def test_list_voices_error(mock_listdir, tts_service):
"""Test error handling in list_voices"""
mock_listdir.side_effect = Exception("Failed to list directory")
voices = tts_service.list_voices()
assert voices == []
def mock_model_setup(cuda_available=False):
"""Helper function to mock model setup"""
# Reset model state
TTSModel._instance = None
TTSModel._device = None
TTSModel._voicepacks = {}
# Create mock model instance with proper generate method
mock_model = MagicMock()
mock_model.generate.return_value = np.zeros(24000, dtype=np.float32)
TTSModel._instance = mock_model
# Set device based on CUDA availability
TTSModel._device = "cuda" if cuda_available else "cpu"
return 3 # Return voice count (including af.pt)
def test_model_initialization_cuda():
"""Test model initialization with CUDA"""
# Simulate CUDA availability
voice_count = mock_model_setup(cuda_available=True)
assert TTSModel.get_device() == "cuda"
assert voice_count == 3 # voice1.pt, voice2.pt, af.pt
def test_model_initialization_cpu():
"""Test model initialization with CPU"""
# Simulate no CUDA availability
voice_count = mock_model_setup(cuda_available=False)
assert TTSModel.get_device() == "cpu"
assert voice_count == 3 # voice1.pt, voice2.pt, af.pt
def test_generate_audio_empty_text(tts_service):
"""Test generating audio with empty text"""
with pytest.raises(ValueError, match="Text is empty after preprocessing"):
tts_service._generate_audio("", "af", 1.0)
@patch("api.src.services.tts.TTSModel.get_instance")
@patch("api.src.services.tts_model.TTSModel.get_instance")
@patch("api.src.services.tts_model.TTSModel.get_device")
@patch("os.path.exists")
@patch("api.src.services.tts.normalize_text")
@patch("api.src.services.tts.phonemize")
@patch("api.src.services.tts.tokenize")
@patch("api.src.services.tts.generate")
@patch("kokoro.normalize_text")
@patch("kokoro.phonemize")
@patch("kokoro.tokenize")
@patch("kokoro.generate")
@patch("torch.load")
def test_generate_audio_no_chunks(
def test_generate_audio_phonemize_error(
mock_torch_load,
mock_generate,
mock_tokenize,
mock_phonemize,
mock_normalize,
mock_exists,
mock_get_device,
mock_instance,
tts_service,
):
"""Test generating audio with no successful chunks"""
"""Test handling phonemization error"""
mock_normalize.return_value = "Test text"
mock_phonemize.return_value = "Test text"
mock_tokenize.return_value = ["test", "text"]
mock_generate.return_value = (None, None)
mock_instance.return_value = (MagicMock(), "cpu")
mock_phonemize.side_effect = Exception("Phonemization failed")
mock_instance.return_value = (mock_generate, "cpu") # Use the same mock for consistency
mock_get_device.return_value = "cpu"
mock_exists.return_value = True
mock_torch_load.return_value = MagicMock()
mock_torch_load.return_value = torch.zeros((10, 24000))
mock_generate.return_value = (None, None)
with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
tts_service._generate_audio("Test text", "af", 1.0)
@patch("torch.load")
@patch("torch.save")
@patch("torch.stack")
@patch("torch.mean")
@patch("api.src.services.tts_model.TTSModel.get_instance")
@patch("api.src.services.tts_model.TTSModel.get_device")
@patch("os.path.exists")
def test_combine_voices(
mock_exists, mock_mean, mock_stack, mock_save, mock_load, tts_service
@patch("kokoro.normalize_text")
@patch("kokoro.phonemize")
@patch("kokoro.tokenize")
@patch("kokoro.generate")
@patch("torch.load")
def test_generate_audio_error(
mock_torch_load,
mock_generate,
mock_tokenize,
mock_phonemize,
mock_normalize,
mock_exists,
mock_get_device,
mock_instance,
tts_service,
):
"""Test combining multiple voices"""
# Setup mocks
"""Test handling generation error"""
mock_normalize.return_value = "Test text"
mock_phonemize.return_value = "Test text"
mock_tokenize.return_value = [1, 2] # Return integers instead of strings
mock_generate.side_effect = Exception("Generation failed")
mock_instance.return_value = (mock_generate, "cpu") # Use the same mock for consistency
mock_get_device.return_value = "cpu"
mock_exists.return_value = True
mock_load.return_value = torch.tensor([1.0, 2.0])
mock_stack.return_value = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
mock_mean.return_value = torch.tensor([2.0, 3.0])
mock_torch_load.return_value = torch.zeros((10, 24000))
# Test combining two voices
result = tts_service.combine_voices(["voice1", "voice2"])
with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
tts_service._generate_audio("Test text", "af", 1.0)
assert result == "voice1_voice2"
mock_stack.assert_called_once()
mock_mean.assert_called_once()
mock_save.assert_called_once()
def test_save_audio(tts_service, sample_audio, tmp_path):
"""Test saving audio to file"""
output_path = os.path.join(tmp_path, "test_output.wav")
tts_service._save_audio(sample_audio, output_path)
assert os.path.exists(output_path)
assert os.path.getsize(output_path) > 0
def test_combine_voices(tts_service):
"""Test combining multiple voices"""
# Setup mocks for torch operations
with patch('torch.load', return_value=torch.tensor([1.0, 2.0])), \
patch('torch.stack', return_value=torch.tensor([[1.0, 2.0], [3.0, 4.0]])), \
patch('torch.mean', return_value=torch.tensor([2.0, 3.0])), \
patch('torch.save'), \
patch('os.path.exists', return_value=True):
# Test combining two voices
result = tts_service.combine_voices(["voice1", "voice2"])
assert result == "voice1_voice2"
def test_combine_voices_invalid_input(tts_service):
@ -155,221 +223,17 @@ def test_combine_voices_invalid_input(tts_service):
tts_service.combine_voices(["voice1"])
@patch("os.makedirs")
@patch("os.path.exists")
@patch("os.listdir")
@patch("torch.load")
@patch("torch.save")
@patch("os.path.join")
def test_ensure_voices(
mock_join,
mock_save,
mock_load,
mock_listdir,
mock_exists,
mock_makedirs,
tts_service,
):
"""Test voice directory initialization"""
# Setup mocks
mock_exists.side_effect = [
True,
False,
False,
] # base_dir exists, voice files don't exist
mock_listdir.return_value = ["voice1.pt", "voice2.pt"]
mock_load.return_value = MagicMock()
mock_join.return_value = "/fake/path"
# Test voice directory initialization
tts_service._ensure_voices()
# Verify directory was created
mock_makedirs.assert_called_once()
# Verify voices were loaded and saved
assert mock_load.call_count == len(mock_listdir.return_value)
assert mock_save.call_count == len(mock_listdir.return_value)
@patch("api.src.services.tts.TTSModel.get_instance")
@patch("os.path.exists")
@patch("api.src.services.tts.normalize_text")
@patch("api.src.services.tts.phonemize")
@patch("api.src.services.tts.tokenize")
@patch("api.src.services.tts.generate")
@patch("torch.load")
def test_generate_audio_success(
mock_torch_load,
mock_generate,
mock_tokenize,
mock_phonemize,
mock_normalize,
mock_exists,
mock_instance,
tts_service,
sample_audio,
):
"""Test successful audio generation"""
mock_normalize.return_value = "Test text"
mock_phonemize.return_value = "Test text"
mock_tokenize.return_value = ["test", "text"]
mock_generate.return_value = (sample_audio, None)
mock_instance.return_value = (MagicMock(), "cpu")
mock_exists.return_value = True
mock_torch_load.return_value = MagicMock()
audio, processing_time = tts_service._generate_audio("Test text", "af", 1.0)
assert isinstance(audio, np.ndarray)
assert isinstance(processing_time, float)
assert len(audio) > 0
@patch("api.src.services.tts.torch.cuda.is_available")
@patch("api.src.services.tts.build_model")
def test_model_initialization_cuda(mock_build_model, mock_cuda_available):
"""Test model initialization with CUDA"""
mock_cuda_available.return_value = True
mock_model = MagicMock()
mock_build_model.return_value = mock_model
TTSModel._instance = None # Reset singleton
model, voice_count = TTSModel.initialize()
assert TTSModel._device == "cuda" # Check the class variable instead
assert model == mock_model
mock_build_model.assert_called_once()
@patch("api.src.services.tts.torch.cuda.is_available")
@patch("api.src.services.tts.build_model")
def test_model_initialization_cpu(mock_build_model, mock_cuda_available):
"""Test model initialization with CPU"""
mock_cuda_available.return_value = False
mock_model = MagicMock()
mock_build_model.return_value = mock_model
TTSModel._instance = None # Reset singleton
model, voice_count = TTSModel.initialize()
assert TTSModel._device == "cpu" # Check the class variable instead
assert model == mock_model
mock_build_model.assert_called_once()
@patch("api.src.services.tts.TTSService._get_voice_path")
@patch("api.src.services.tts.TTSModel.get_instance")
@patch("api.src.services.tts_service.TTSService._get_voice_path")
@patch("api.src.services.tts_model.TTSModel.get_instance")
def test_voicepack_loading_error(mock_get_instance, mock_get_voice_path):
"""Test voicepack loading error handling"""
mock_get_voice_path.return_value = None
mock_get_instance.return_value = (MagicMock(), "cpu")
mock_instance = MagicMock()
mock_instance.generate.return_value = np.zeros(24000, dtype=np.float32)
mock_get_instance.return_value = (mock_instance, "cpu")
TTSModel._voicepacks = {} # Reset voicepacks
service = TTSService(start_worker=False)
service = TTSService()
with pytest.raises(ValueError, match="Voice not found: nonexistent_voice"):
service._generate_audio("test", "nonexistent_voice", 1.0)
@patch("api.src.services.tts.TTSModel")
def test_save_audio(mock_tts_model, tts_service, sample_audio, tmp_path):
"""Test saving audio to file"""
output_dir = os.path.join(tmp_path, "test_output")
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, "audio.wav")
tts_service._save_audio(sample_audio, output_path)
assert os.path.exists(output_path)
assert os.path.getsize(output_path) > 0
@patch("api.src.services.tts.TTSModel.get_instance")
@patch("os.path.exists")
@patch("api.src.services.tts.normalize_text")
@patch("api.src.services.tts.generate")
@patch("torch.load")
def test_generate_audio_without_stitching(
mock_torch_load,
mock_generate,
mock_normalize,
mock_exists,
mock_instance,
tts_service,
sample_audio,
):
"""Test generating audio without text stitching"""
mock_normalize.return_value = "Test text"
mock_generate.return_value = (sample_audio, None)
mock_instance.return_value = (MagicMock(), "cpu")
mock_exists.return_value = True
mock_torch_load.return_value = MagicMock()
audio, processing_time = tts_service._generate_audio(
"Test text", "af", 1.0, stitch_long_output=False
)
assert isinstance(audio, np.ndarray)
assert len(audio) > 0
mock_generate.assert_called_once()
@patch("os.listdir")
def test_list_voices_error(mock_listdir, tts_service):
"""Test error handling in list_voices"""
mock_listdir.side_effect = Exception("Failed to list directory")
voices = tts_service.list_voices()
assert voices == []
@patch("api.src.services.tts.TTSModel.get_instance")
@patch("os.path.exists")
@patch("api.src.services.tts.normalize_text")
@patch("api.src.services.tts.phonemize")
@patch("api.src.services.tts.tokenize")
@patch("api.src.services.tts.generate")
@patch("torch.load")
def test_generate_audio_phonemize_error(
mock_torch_load,
mock_generate,
mock_tokenize,
mock_phonemize,
mock_normalize,
mock_exists,
mock_instance,
tts_service,
):
"""Test handling phonemization error"""
mock_normalize.return_value = "Test text"
mock_phonemize.side_effect = Exception("Phonemization failed")
mock_instance.return_value = (MagicMock(), "cpu")
mock_exists.return_value = True
mock_torch_load.return_value = MagicMock()
mock_generate.return_value = (None, None)
with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
tts_service._generate_audio("Test text", "af", 1.0)
@patch("api.src.services.tts.TTSModel.get_instance")
@patch("os.path.exists")
@patch("api.src.services.tts.normalize_text")
@patch("api.src.services.tts.generate")
@patch("torch.load")
def test_generate_audio_error(
mock_torch_load,
mock_generate,
mock_normalize,
mock_exists,
mock_instance,
tts_service,
):
"""Test handling generation error"""
mock_normalize.return_value = "Test text"
mock_generate.side_effect = Exception("Generation failed")
mock_instance.return_value = (MagicMock(), "cpu")
mock_exists.return_value = True
mock_torch_load.return_value = MagicMock()
with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
tts_service._generate_audio("Test text", "af", 1.0)

View file

@ -36,6 +36,13 @@ services:
- "8880:8880"
environment:
- PYTHONPATH=/app:/app/Kokoro-82M
# ONNX Optimization Settings for vectorized operations
- ONNX_NUM_THREADS=8 # Maximize core usage for vectorized ops
- ONNX_INTER_OP_THREADS=4 # Higher inter-op for parallel matrix operations
- ONNX_EXECUTION_MODE=parallel
- ONNX_OPTIMIZATION_LEVEL=all
- ONNX_MEMORY_PATTERN=true
- ONNX_ARENA_EXTEND_STRATEGY=kNextPowerOfTwo
depends_on:
model-fetcher:
condition: service_healthy
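
The ONNX_* variables above tune the CPU runtime's thread pools, graph optimizations, and allocator behavior. Below is a minimal sketch of how such variables could be read into an `onnxruntime.SessionOptions` object; the variable names come from the compose file, but the helper name, defaults, and call site are illustrative assumptions, not the service's actual code.

```python
# Illustrative sketch only: turning the ONNX_* environment variables above
# into onnxruntime session options. Function name and defaults are assumptions.
import os
import onnxruntime as ort

def session_options_from_env() -> ort.SessionOptions:
    opts = ort.SessionOptions()
    # Intra-op threads run vectorized kernels; inter-op threads run independent graph nodes
    opts.intra_op_num_threads = int(os.getenv("ONNX_NUM_THREADS", "8"))
    opts.inter_op_num_threads = int(os.getenv("ONNX_INTER_OP_THREADS", "4"))
    if os.getenv("ONNX_EXECUTION_MODE", "parallel") == "parallel":
        opts.execution_mode = ort.ExecutionMode.ORT_PARALLEL
    if os.getenv("ONNX_OPTIMIZATION_LEVEL", "all") == "all":
        opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    # Reuse memory allocation patterns between runs when enabled
    opts.enable_mem_pattern = os.getenv("ONNX_MEMORY_PATTERN", "true").lower() == "true"
    # ONNX_ARENA_EXTEND_STRATEGY would be forwarded to the execution provider's
    # allocator options when the InferenceSession is created (not shown here).
    return opts
```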

0
examples/__init__.py Normal file
View file

View file

View file

@ -0,0 +1,242 @@
#!/usr/bin/env python3
import os
import json
import time
import threading
import queue
import pandas as pd
import sys
from datetime import datetime
from lib.shared_plotting import plot_system_metrics, plot_correlation
from lib.shared_utils import (
get_system_metrics, save_json_results, write_benchmark_stats,
real_time_factor
)
from lib.shared_benchmark_utils import (
get_text_for_tokens, make_tts_request, generate_token_sizes, enc
)
class SystemMonitor:
def __init__(self, interval=1.0):
self.interval = interval
self.metrics_queue = queue.Queue()
self.stop_event = threading.Event()
self.metrics_timeline = []
self.start_time = None
def _monitor_loop(self):
"""Background thread function to collect system metrics."""
while not self.stop_event.is_set():
metrics = get_system_metrics()
metrics["relative_time"] = time.time() - self.start_time
self.metrics_queue.put(metrics)
time.sleep(self.interval)
def start(self):
"""Start the monitoring thread."""
self.start_time = time.time()
self.monitor_thread = threading.Thread(target=self._monitor_loop)
self.monitor_thread.daemon = True
self.monitor_thread.start()
def stop(self):
"""Stop the monitoring thread and collect final metrics."""
self.stop_event.set()
if hasattr(self, 'monitor_thread'):
self.monitor_thread.join(timeout=2)
# Collect all metrics from queue
while True:
try:
metrics = self.metrics_queue.get_nowait()
self.metrics_timeline.append(metrics)
except queue.Empty:
break
return self.metrics_timeline
def main():
# Initialize system monitor
monitor = SystemMonitor(interval=1.0) # 1 second interval
# Set prefix for output files (e.g. "gpu", "cpu", "onnx", etc.)
prefix = "gpu"
# Generate token sizes
if 'gpu' in prefix:
token_sizes = generate_token_sizes(
max_tokens=5000, dense_step=150,
dense_max=1000, sparse_step=1000)
elif 'cpu' in prefix:
token_sizes = generate_token_sizes(
max_tokens=1000, dense_step=300,
dense_max=1000, sparse_step=0)
else:
token_sizes = generate_token_sizes(max_tokens=3000)
# Set up paths relative to this file
script_dir = os.path.dirname(os.path.abspath(__file__))
output_dir = os.path.join(script_dir, "output_audio")
output_data_dir = os.path.join(script_dir, "output_data")
output_plots_dir = os.path.join(script_dir, "output_plots")
# Create output directories
os.makedirs(output_dir, exist_ok=True)
os.makedirs(output_data_dir, exist_ok=True)
os.makedirs(output_plots_dir, exist_ok=True)
# Function to prefix filenames
def prefix_path(path: str, filename: str) -> str:
if prefix:
filename = f"{prefix}_{filename}"
return os.path.join(path, filename)
with open(os.path.join(script_dir, "the_time_machine_hg_wells.txt"), "r", encoding="utf-8") as f:
text = f.read()
total_tokens = len(enc.encode(text))
print(f"Total tokens in file: {total_tokens}")
print(f"Testing sizes: {token_sizes}")
results = []
test_start_time = time.time()
# Start system monitoring
monitor.start()
for num_tokens in token_sizes:
chunk = get_text_for_tokens(text, num_tokens)
actual_tokens = len(enc.encode(chunk))
print(f"\nProcessing chunk with {actual_tokens} tokens:")
print(f"Text preview: {chunk[:100]}...")
processing_time, audio_length = make_tts_request(
chunk,
output_dir=output_dir,
prefix=prefix
)
if processing_time is None or audio_length is None:
print("Breaking loop due to error")
break
# Calculate RTF using the correct formula
rtf = real_time_factor(processing_time, audio_length)
print(f"Real-Time Factor: {rtf:.5f}")
results.append({
"tokens": actual_tokens,
"processing_time": processing_time,
"output_length": audio_length,
"rtf": rtf,
"elapsed_time": round(time.time() - test_start_time, 2),
})
df = pd.DataFrame(results)
if df.empty:
print("No data to plot")
return
df["tokens_per_second"] = df["tokens"] / df["processing_time"]
# Write benchmark stats
stats = [
{
"title": "Benchmark Statistics (with correct RTF)",
"stats": {
"Total tokens processed": df['tokens'].sum(),
"Total audio generated (s)": df['output_length'].sum(),
"Total test duration (s)": df['elapsed_time'].max(),
"Average processing rate (tokens/s)": df['tokens_per_second'].mean(),
"Average RTF": df['rtf'].mean(),
"Average Real Time Speed": 1/df['rtf'].mean()
}
},
{
"title": "Per-chunk Stats",
"stats": {
"Average chunk size (tokens)": df['tokens'].mean(),
"Min chunk size (tokens)": df['tokens'].min(),
"Max chunk size (tokens)": df['tokens'].max(),
"Average processing time (s)": df['processing_time'].mean(),
"Average output length (s)": df['output_length'].mean()
}
},
{
"title": "Performance Ranges",
"stats": {
"Processing rate range (tokens/s)": f"{df['tokens_per_second'].min():.2f} - {df['tokens_per_second'].max():.2f}",
"RTF range": f"{df['rtf'].min():.2f}x - {df['rtf'].max():.2f}x",
"Real Time Speed range": f"{1/df['rtf'].max():.2f}x - {1/df['rtf'].min():.2f}x"
}
}
]
write_benchmark_stats(stats, prefix_path(output_data_dir, "benchmark_stats_rtf.txt"))
# Plot Processing Time vs Token Count
plot_correlation(
df, "tokens", "processing_time",
"Processing Time vs Input Size",
"Number of Input Tokens",
"Processing Time (seconds)",
prefix_path(output_plots_dir, "processing_time_rtf.png")
)
# Plot RTF vs Token Count
plot_correlation(
df, "tokens", "rtf",
"Real-Time Factor vs Input Size",
"Number of Input Tokens",
"Real-Time Factor (processing time / audio length)",
prefix_path(output_plots_dir, "realtime_factor_rtf.png")
)
# Stop monitoring and get final metrics
final_metrics = monitor.stop()
# Convert metrics timeline to DataFrame for stats
metrics_df = pd.DataFrame(final_metrics)
# Add system usage stats
if not metrics_df.empty:
stats.append({
"title": "System Usage Statistics",
"stats": {
"Peak CPU Usage (%)": metrics_df['cpu_percent'].max(),
"Avg CPU Usage (%)": metrics_df['cpu_percent'].mean(),
"Peak RAM Usage (%)": metrics_df['ram_percent'].max(),
"Avg RAM Usage (%)": metrics_df['ram_percent'].mean(),
"Peak RAM Used (GB)": metrics_df['ram_used_gb'].max(),
"Avg RAM Used (GB)": metrics_df['ram_used_gb'].mean(),
}
})
if 'gpu_memory_used' in metrics_df:
stats[-1]["stats"].update({
"Peak GPU Memory (MB)": metrics_df['gpu_memory_used'].max(),
"Avg GPU Memory (MB)": metrics_df['gpu_memory_used'].mean(),
})
# Plot system metrics
plot_system_metrics(final_metrics, prefix_path(output_plots_dir, "system_usage_rtf.png"))
# Save final results
save_json_results(
{
"results": results,
"system_metrics": final_metrics,
"test_duration": time.time() - test_start_time
},
prefix_path(output_data_dir, "benchmark_results_rtf.json")
)
print("\nResults saved to:")
print(f"- {prefix_path(output_data_dir, 'benchmark_results_rtf.json')}")
print(f"- {prefix_path(output_data_dir, 'benchmark_stats_rtf.txt')}")
print(f"- {prefix_path(output_plots_dir, 'processing_time_rtf.png')}")
print(f"- {prefix_path(output_plots_dir, 'realtime_factor_rtf.png')}")
print(f"- {prefix_path(output_plots_dir, 'system_usage_rtf.png')}")
print(f"\nAudio files saved in {output_dir} with prefix: {prefix or '(none)'}")
if __name__ == "__main__":
main()

View file

@ -0,0 +1,165 @@
import os
import json
import time
import pandas as pd
from examples.assorted_checks.lib.shared_plotting import plot_system_metrics, plot_correlation
from examples.assorted_checks.lib.shared_utils import (
get_system_metrics, save_json_results, write_benchmark_stats
)
from examples.assorted_checks.lib.shared_benchmark_utils import (
get_text_for_tokens, make_tts_request, generate_token_sizes, enc
)
def main():
# Get optional prefix from first command line argument
import sys
prefix = sys.argv[1] if len(sys.argv) > 1 else ""
# Set up paths relative to this file
script_dir = os.path.dirname(os.path.abspath(__file__))
output_dir = os.path.join(script_dir, "output_audio")
output_data_dir = os.path.join(script_dir, "output_data")
output_plots_dir = os.path.join(script_dir, "output_plots")
# Create output directories
os.makedirs(output_dir, exist_ok=True)
os.makedirs(output_data_dir, exist_ok=True)
os.makedirs(output_plots_dir, exist_ok=True)
# Function to prefix filenames
def prefix_path(path: str, filename: str) -> str:
if prefix:
filename = f"{prefix}_{filename}"
return os.path.join(path, filename)
# Read input text
with open(
os.path.join(script_dir, "the_time_machine_hg_wells.txt"), "r", encoding="utf-8"
) as f:
text = f.read()
# Get total tokens in file
total_tokens = len(enc.encode(text))
print(f"Total tokens in file: {total_tokens}")
token_sizes = generate_token_sizes(total_tokens)
print(f"Testing sizes: {token_sizes}")
# Process chunks
results = []
system_metrics = []
test_start_time = time.time()
for num_tokens in token_sizes:
# Get text slice with exact token count
chunk = get_text_for_tokens(text, num_tokens)
actual_tokens = len(enc.encode(chunk))
print(f"\nProcessing chunk with {actual_tokens} tokens:")
print(f"Text preview: {chunk[:100]}...")
# Collect system metrics before processing
system_metrics.append(get_system_metrics())
processing_time, audio_length = make_tts_request(chunk)
if processing_time is None or audio_length is None:
print("Breaking loop due to error")
break
# Collect system metrics after processing
system_metrics.append(get_system_metrics())
results.append(
{
"tokens": actual_tokens,
"processing_time": processing_time,
"output_length": audio_length,
"realtime_factor": audio_length / processing_time,
"elapsed_time": time.time() - test_start_time,
}
)
# Save intermediate results
save_json_results(
{"results": results, "system_metrics": system_metrics},
prefix_path(output_data_dir, "benchmark_results.json")
)
# Create DataFrame and calculate stats
df = pd.DataFrame(results)
if df.empty:
print("No data to plot")
return
# Calculate useful metrics
df["tokens_per_second"] = df["tokens"] / df["processing_time"]
# Write benchmark stats
stats = [
{
"title": "Benchmark Statistics",
"stats": {
"Total tokens processed": df['tokens'].sum(),
"Total audio generated (s)": df['output_length'].sum(),
"Total test duration (s)": df['elapsed_time'].max(),
"Average processing rate (tokens/s)": df['tokens_per_second'].mean(),
"Average realtime factor": df['realtime_factor'].mean()
}
},
{
"title": "Per-chunk Stats",
"stats": {
"Average chunk size (tokens)": df['tokens'].mean(),
"Min chunk size (tokens)": df['tokens'].min(),
"Max chunk size (tokens)": df['tokens'].max(),
"Average processing time (s)": df['processing_time'].mean(),
"Average output length (s)": df['output_length'].mean()
}
},
{
"title": "Performance Ranges",
"stats": {
"Processing rate range (tokens/s)": f"{df['tokens_per_second'].min():.2f} - {df['tokens_per_second'].max():.2f}",
"Realtime factor range": f"{df['realtime_factor'].min():.2f}x - {df['realtime_factor'].max():.2f}x"
}
}
]
write_benchmark_stats(stats, prefix_path(output_data_dir, "benchmark_stats.txt"))
# Plot Processing Time vs Token Count
plot_correlation(
df, "tokens", "processing_time",
"Processing Time vs Input Size",
"Number of Input Tokens",
"Processing Time (seconds)",
prefix_path(output_plots_dir, "processing_time.png")
)
# Plot Realtime Factor vs Token Count
plot_correlation(
df, "tokens", "realtime_factor",
"Realtime Factor vs Input Size",
"Number of Input Tokens",
"Realtime Factor (output length / processing time)",
prefix_path(output_plots_dir, "realtime_factor.png")
)
# Plot system metrics
plot_system_metrics(system_metrics, prefix_path(output_plots_dir, "system_usage.png"))
print("\nResults saved to:")
print(f"- {prefix_path(output_data_dir, 'benchmark_results.json')}")
print(f"- {prefix_path(output_data_dir, 'benchmark_stats.txt')}")
print(f"- {prefix_path(output_plots_dir, 'processing_time.png')}")
print(f"- {prefix_path(output_plots_dir, 'realtime_factor.png')}")
print(f"- {prefix_path(output_plots_dir, 'system_usage.png')}")
if any("gpu_memory_used" in m for m in system_metrics):
print(f"- {prefix_path(output_plots_dir, 'gpu_usage.png')}")
print(f"\nAudio files saved in {output_dir} with prefix: {prefix or '(none)'}")
if __name__ == "__main__":
main()

View file

@ -0,0 +1,111 @@
"""Shared utilities specific to TTS benchmarking."""
import time
from typing import List, Optional, Tuple
import requests
import tiktoken
from .shared_utils import get_audio_length, save_audio_file
# Global tokenizer instance
enc = tiktoken.get_encoding("cl100k_base")
def get_text_for_tokens(text: str, num_tokens: int) -> str:
"""Get a slice of text that contains exactly num_tokens tokens.
Args:
text: Input text to slice
num_tokens: Desired number of tokens
Returns:
str: Text slice containing num_tokens tokens, or the full text if it has fewer
"""
tokens = enc.encode(text)
if num_tokens > len(tokens):
return text
return enc.decode(tokens[:num_tokens])
def make_tts_request(
text: str,
output_dir: str = None,
timeout: int = 1800,
prefix: str = ""
) -> Tuple[Optional[float], Optional[float]]:
"""Make TTS request using OpenAI-compatible endpoint.
Args:
text: Input text to convert to speech
output_dir: Directory to save audio files. If None, audio won't be saved.
timeout: Request timeout in seconds
prefix: Optional prefix for output filenames
Returns:
tuple: (processing_time, audio_length) in seconds, or (None, None) on error
"""
try:
start_time = time.time()
response = requests.post(
"http://localhost:8880/v1/audio/speech",
json={
"model": "kokoro",
"input": text,
"voice": "af",
"response_format": "wav",
},
timeout=timeout,
)
response.raise_for_status()
processing_time = round(time.time() - start_time, 2)
# Calculate audio length from response content
audio_length = get_audio_length(response.content)
# Save the audio file if output_dir is provided
if output_dir:
token_count = len(enc.encode(text))
output_file = save_audio_file(
response.content,
f"chunk_{token_count}_tokens",
output_dir
)
print(f"Saved audio to {output_file}")
return processing_time, audio_length
except requests.exceptions.RequestException as e:
print(f"Error making request for text: {text[:50]}... Error: {str(e)}")
return None, None
except Exception as e:
print(f"Error processing text: {text[:50]}... Error: {str(e)}")
return None, None
def generate_token_sizes(
max_tokens: int,
dense_step: int = 100,
dense_max: int = 1000,
sparse_step: int = 1000
) -> List[int]:
"""Generate token size ranges with dense sampling at start.
Args:
max_tokens: Maximum number of tokens to generate sizes up to
dense_step: Step size for dense sampling range
dense_max: Maximum value for dense sampling
sparse_step: Step size for sparse sampling range
Returns:
list: Sorted list of token sizes
"""
# Dense sampling at start
dense_range = list(range(dense_step, dense_max + 1, dense_step))
if max_tokens <= dense_max or sparse_step < dense_max:
return sorted(dense_range)
# Sparse sampling for larger sizes
sparse_range = list(range(dense_max + sparse_step, max_tokens + 1, sparse_step))
# Combine and deduplicate
return sorted(list(set(dense_range + sparse_range)))
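
With the defaults above, sampling is dense up to 1000 tokens and then continues in 1000-token steps. A small worked example (the max_tokens value is assumed for illustration):

```python
# Dense 100..1000 in steps of 100, then sparse every 1000 tokens up to the cap
sizes = generate_token_sizes(max_tokens=3000)
# -> [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 2000, 3000]
```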

View file

@ -0,0 +1,176 @@
"""Shared plotting utilities for benchmarks and tests."""
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
# Common style configurations
STYLE_CONFIG = {
"background_color": "#1a1a2e",
"primary_color": "#ff2a6d",
"secondary_color": "#05d9e8",
"grid_color": "#ffffff",
"text_color": "#ffffff",
"font_sizes": {
"title": 16,
"label": 14,
"tick": 12,
"text": 10
}
}
def setup_plot(fig, ax, title, xlabel=None, ylabel=None):
"""Configure plot styling with consistent theme.
Args:
fig: matplotlib figure object
ax: matplotlib axis object
title: str, plot title
xlabel: str, optional x-axis label
ylabel: str, optional y-axis label
Returns:
tuple: (fig, ax) with applied styling
"""
# Grid styling
ax.grid(True, linestyle="--", alpha=0.3, color=STYLE_CONFIG["grid_color"])
# Title and labels
ax.set_title(title, pad=20,
fontsize=STYLE_CONFIG["font_sizes"]["title"],
fontweight="bold",
color=STYLE_CONFIG["text_color"])
if xlabel:
ax.set_xlabel(xlabel,
fontsize=STYLE_CONFIG["font_sizes"]["label"],
fontweight="medium",
color=STYLE_CONFIG["text_color"])
if ylabel:
ax.set_ylabel(ylabel,
fontsize=STYLE_CONFIG["font_sizes"]["label"],
fontweight="medium",
color=STYLE_CONFIG["text_color"])
# Tick styling
ax.tick_params(labelsize=STYLE_CONFIG["font_sizes"]["tick"],
colors=STYLE_CONFIG["text_color"])
# Spine styling
for spine in ax.spines.values():
spine.set_color(STYLE_CONFIG["text_color"])
spine.set_alpha(0.3)
spine.set_linewidth(0.5)
# Background colors
ax.set_facecolor(STYLE_CONFIG["background_color"])
fig.patch.set_facecolor(STYLE_CONFIG["background_color"])
return fig, ax
def plot_system_metrics(metrics_data, output_path):
"""Create plots for system metrics over time.
Args:
metrics_data: list of dicts containing system metrics
output_path: str, path to save the output plot
"""
df = pd.DataFrame(metrics_data)
df["timestamp"] = pd.to_datetime(df["timestamp"])
elapsed_time = (df["timestamp"] - df["timestamp"].iloc[0]).dt.total_seconds()
# Get baseline values
baseline_cpu = df["cpu_percent"].iloc[0]
baseline_ram = df["ram_used_gb"].iloc[0]
baseline_gpu = df["gpu_memory_used"].iloc[0] / 1024 if "gpu_memory_used" in df.columns else None
# Convert GPU memory to GB if present
if "gpu_memory_used" in df.columns:
df["gpu_memory_gb"] = df["gpu_memory_used"] / 1024
plt.style.use("dark_background")
# Create subplots based on available metrics
has_gpu = "gpu_memory_used" in df.columns
num_plots = 3 if has_gpu else 2
fig, axes = plt.subplots(num_plots, 1, figsize=(15, 5 * num_plots))
fig.patch.set_facecolor(STYLE_CONFIG["background_color"])
# Smoothing window
window = min(5, len(df) // 2)
# Plot CPU Usage
smoothed_cpu = df["cpu_percent"].rolling(window=window, center=True).mean()
sns.lineplot(x=elapsed_time, y=smoothed_cpu, ax=axes[0],
color=STYLE_CONFIG["primary_color"], linewidth=2)
axes[0].axhline(y=baseline_cpu, color=STYLE_CONFIG["secondary_color"],
linestyle="--", alpha=0.5, label="Baseline")
setup_plot(fig, axes[0], "CPU Usage Over Time",
xlabel="Time (seconds)", ylabel="CPU Usage (%)")
axes[0].set_ylim(0, max(df["cpu_percent"]) * 1.1)
axes[0].legend()
# Plot RAM Usage
smoothed_ram = df["ram_used_gb"].rolling(window=window, center=True).mean()
sns.lineplot(x=elapsed_time, y=smoothed_ram, ax=axes[1],
color=STYLE_CONFIG["secondary_color"], linewidth=2)
axes[1].axhline(y=baseline_ram, color=STYLE_CONFIG["primary_color"],
linestyle="--", alpha=0.5, label="Baseline")
setup_plot(fig, axes[1], "RAM Usage Over Time",
xlabel="Time (seconds)", ylabel="RAM Usage (GB)")
axes[1].set_ylim(0, max(df["ram_used_gb"]) * 1.1)
axes[1].legend()
# Plot GPU Memory if available
if has_gpu:
smoothed_gpu = df["gpu_memory_gb"].rolling(window=window, center=True).mean()
sns.lineplot(x=elapsed_time, y=smoothed_gpu, ax=axes[2],
color=STYLE_CONFIG["primary_color"], linewidth=2)
axes[2].axhline(y=baseline_gpu, color=STYLE_CONFIG["secondary_color"],
linestyle="--", alpha=0.5, label="Baseline")
setup_plot(fig, axes[2], "GPU Memory Usage Over Time",
xlabel="Time (seconds)", ylabel="GPU Memory (GB)")
axes[2].set_ylim(0, max(df["gpu_memory_gb"]) * 1.1)
axes[2].legend()
plt.tight_layout()
plt.savefig(output_path, dpi=300, bbox_inches="tight")
plt.close()
def plot_correlation(df, x, y, title, xlabel, ylabel, output_path):
"""Create correlation plot with regression line and correlation coefficient.
Args:
df: pandas DataFrame containing the data
x: str, column name for x-axis
y: str, column name for y-axis
title: str, plot title
xlabel: str, x-axis label
ylabel: str, y-axis label
output_path: str, path to save the output plot
"""
plt.style.use("dark_background")
fig, ax = plt.subplots(figsize=(12, 8))
# Scatter plot
sns.scatterplot(data=df, x=x, y=y, s=100, alpha=0.6,
color=STYLE_CONFIG["primary_color"])
# Regression line
sns.regplot(data=df, x=x, y=y, scatter=False,
color=STYLE_CONFIG["secondary_color"],
line_kws={"linewidth": 2})
# Add correlation coefficient
corr = df[x].corr(df[y])
plt.text(0.05, 0.95, f"Correlation: {corr:.2f}",
transform=ax.transAxes,
fontsize=STYLE_CONFIG["font_sizes"]["text"],
color=STYLE_CONFIG["text_color"],
bbox=dict(facecolor=STYLE_CONFIG["background_color"],
edgecolor=STYLE_CONFIG["text_color"],
alpha=0.7))
setup_plot(fig, ax, title, xlabel=xlabel, ylabel=ylabel)
plt.savefig(output_path, dpi=300, bbox_inches="tight")
plt.close()

View file

@ -0,0 +1,174 @@
"""Shared utilities for benchmarks and tests."""
import os
import json
import subprocess
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
import psutil
import scipy.io.wavfile as wavfile
# Check for torch availability once at module level
TORCH_AVAILABLE = False
try:
import torch
TORCH_AVAILABLE = torch.cuda.is_available()
except ImportError:
pass
def get_audio_length(audio_data: bytes, temp_dir: str = None) -> float:
"""Get audio length in seconds from bytes data.
Args:
audio_data: Raw audio bytes
temp_dir: Directory for temporary file. If None, uses system temp directory.
Returns:
float: Audio length in seconds
"""
if temp_dir is None:
import tempfile
temp_dir = tempfile.gettempdir()
temp_path = os.path.join(temp_dir, "temp.wav")
os.makedirs(temp_dir, exist_ok=True)
with open(temp_path, "wb") as f:
f.write(audio_data)
try:
rate, data = wavfile.read(temp_path)
return len(data) / rate
finally:
if os.path.exists(temp_path):
os.remove(temp_path)
def get_gpu_memory(average: bool = True) -> Optional[Union[float, List[float]]]:
"""Get GPU memory usage using PyTorch if available, falling back to nvidia-smi.
Args:
average: If True and multiple GPUs present, returns average memory usage.
If False, returns list of memory usage per GPU.
Returns:
float or List[float] or None: GPU memory usage in MB. Returns None if no GPU available.
If average=False and multiple GPUs present, returns list of values.
"""
if TORCH_AVAILABLE:
n_gpus = torch.cuda.device_count()
memory_used = []
for i in range(n_gpus):
memory_used.append(torch.cuda.memory_allocated(i) / 1024**2) # Convert to MB
if average and len(memory_used) > 0:
return sum(memory_used) / len(memory_used)
return memory_used if len(memory_used) > 1 else memory_used[0]
# Fall back to nvidia-smi
try:
result = subprocess.check_output(
["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader"]
)
memory_values = [float(x.strip()) for x in result.decode("utf-8").split("\n") if x.strip()]
if average and len(memory_values) > 0:
return sum(memory_values) / len(memory_values)
return memory_values if len(memory_values) > 1 else memory_values[0]
except (subprocess.CalledProcessError, FileNotFoundError):
return None
def get_system_metrics() -> Dict[str, Union[str, float]]:
"""Get current system metrics including CPU, RAM, and GPU if available.
Returns:
dict: System metrics including timestamp, CPU%, RAM%, RAM GB, and GPU MB if available
"""
# Get per-CPU percentages and calculate average
cpu_percentages = psutil.cpu_percent(percpu=True)
avg_cpu = sum(cpu_percentages) / len(cpu_percentages)
metrics = {
"timestamp": datetime.now().isoformat(),
"cpu_percent": round(avg_cpu, 2),
"ram_percent": psutil.virtual_memory().percent,
"ram_used_gb": psutil.virtual_memory().used / (1024**3),
}
gpu_mem = get_gpu_memory(average=True) # Use average for system metrics
if gpu_mem is not None:
metrics["gpu_memory_used"] = round(gpu_mem, 2)
return metrics
def save_audio_file(audio_data: bytes, identifier: str, output_dir: str) -> str:
"""Save audio data to a file with proper naming and directory creation.
Args:
audio_data: Raw audio bytes
identifier: String to identify this audio file (e.g. token count, test name)
output_dir: Directory to save the file
Returns:
str: Path to the saved audio file
"""
os.makedirs(output_dir, exist_ok=True)
output_file = os.path.join(output_dir, f"{identifier}.wav")
with open(output_file, "wb") as f:
f.write(audio_data)
return output_file
def write_benchmark_stats(stats: List[Dict[str, Any]], output_file: str) -> None:
"""Write benchmark statistics to a file in a clean, organized format.
Args:
stats: List of dictionaries containing stat name/value pairs
output_file: Path to output file
"""
os.makedirs(os.path.dirname(output_file), exist_ok=True)
with open(output_file, "w") as f:
for section in stats:
# Write section header
f.write(f"=== {section['title']} ===\n\n")
# Write stats
for label, value in section['stats'].items():
if isinstance(value, float):
f.write(f"{label}: {value:.2f}\n")
else:
f.write(f"{label}: {value}\n")
f.write("\n")
def save_json_results(results: Dict[str, Any], output_file: str) -> None:
"""Save benchmark results to a JSON file with proper formatting.
Args:
results: Dictionary of results to save
output_file: Path to output file
"""
os.makedirs(os.path.dirname(output_file), exist_ok=True)
with open(output_file, "w") as f:
json.dump(results, f, indent=2)
def real_time_factor(processing_time: float, audio_length: float, decimals: int = 2) -> float:
"""Calculate Real-Time Factor (RTF) as processing-time / length-of-audio.
Args:
processing_time: Time taken to process/generate audio
audio_length: Length of the generated audio
decimals: Number of decimal places to round to
Returns:
float: RTF value
"""
rtf = processing_time / audio_length
return round(rtf, decimals)
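
As a worked example, the first GPU chunk in this PR's benchmark results took 14.35 s of processing to produce 31.15 s of audio:

```python
rtf = real_time_factor(14.35, 31.15)  # 0.46 -> faster than real time
speed = round(1 / rtf, 2)             # ~2.17x, the "Real Time Speed" reported in the stats files
```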

View file

@ -0,0 +1,111 @@
{
"results": [
{
"tokens": 100,
"processing_time": 18.833295583724976,
"output_length": 31.15,
"realtime_factor": 1.6539856161403135,
"elapsed_time": 19.024322748184204
},
{
"tokens": 200,
"processing_time": 38.95506024360657,
"output_length": 62.6,
"realtime_factor": 1.6069799304257042,
"elapsed_time": 58.21527123451233
},
{
"tokens": 300,
"processing_time": 49.74252939224243,
"output_length": 96.325,
"realtime_factor": 1.9364716908630366,
"elapsed_time": 108.19673728942871
},
{
"tokens": 400,
"processing_time": 61.349056243896484,
"output_length": 128.575,
"realtime_factor": 2.095794261102292,
"elapsed_time": 169.733656167984
},
{
"tokens": 500,
"processing_time": 82.86568236351013,
"output_length": 158.575,
"realtime_factor": 1.9136389815071193,
"elapsed_time": 252.7968451976776
}
],
"system_metrics": [
{
"timestamp": "2025-01-03T00:13:49.865330",
"cpu_percent": 8.0,
"ram_percent": 39.4,
"ram_used_gb": 25.03811264038086,
"gpu_memory_used": 1204.0
},
{
"timestamp": "2025-01-03T00:14:08.781551",
"cpu_percent": 26.8,
"ram_percent": 42.6,
"ram_used_gb": 27.090862274169922,
"gpu_memory_used": 1225.0
},
{
"timestamp": "2025-01-03T00:14:08.916973",
"cpu_percent": 16.1,
"ram_percent": 42.6,
"ram_used_gb": 27.089553833007812,
"gpu_memory_used": 1225.0
},
{
"timestamp": "2025-01-03T00:14:47.979053",
"cpu_percent": 31.5,
"ram_percent": 43.6,
"ram_used_gb": 27.714427947998047,
"gpu_memory_used": 1225.0
},
{
"timestamp": "2025-01-03T00:14:48.098976",
"cpu_percent": 20.0,
"ram_percent": 43.6,
"ram_used_gb": 27.704315185546875,
"gpu_memory_used": 1211.0
},
{
"timestamp": "2025-01-03T00:15:37.944729",
"cpu_percent": 29.7,
"ram_percent": 38.6,
"ram_used_gb": 24.53925323486328,
"gpu_memory_used": 1217.0
},
{
"timestamp": "2025-01-03T00:15:38.071915",
"cpu_percent": 8.6,
"ram_percent": 38.5,
"ram_used_gb": 24.51690673828125,
"gpu_memory_used": 1208.0
},
{
"timestamp": "2025-01-03T00:16:39.525449",
"cpu_percent": 23.4,
"ram_percent": 38.8,
"ram_used_gb": 24.71230697631836,
"gpu_memory_used": 1221.0
},
{
"timestamp": "2025-01-03T00:16:39.612442",
"cpu_percent": 5.5,
"ram_percent": 38.9,
"ram_used_gb": 24.72066879272461,
"gpu_memory_used": 1221.0
},
{
"timestamp": "2025-01-03T00:18:02.569076",
"cpu_percent": 27.4,
"ram_percent": 39.1,
"ram_used_gb": 24.868202209472656,
"gpu_memory_used": 1264.0
}
]
}

View file

@ -0,0 +1,216 @@
{
"results": [
{
"tokens": 100,
"processing_time": 14.349808931350708,
"output_length": 31.15,
"rtf": 0.46,
"elapsed_time": 14.716031074523926
},
{
"tokens": 200,
"processing_time": 28.341803312301636,
"output_length": 62.6,
"rtf": 0.45,
"elapsed_time": 43.44207406044006
},
{
"tokens": 300,
"processing_time": 43.352553606033325,
"output_length": 96.325,
"rtf": 0.45,
"elapsed_time": 87.26906609535217
},
{
"tokens": 400,
"processing_time": 71.02449822425842,
"output_length": 128.575,
"rtf": 0.55,
"elapsed_time": 158.7198133468628
},
{
"tokens": 500,
"processing_time": 70.92521691322327,
"output_length": 158.575,
"rtf": 0.45,
"elapsed_time": 230.01379895210266
},
{
"tokens": 600,
"processing_time": 83.6328592300415,
"output_length": 189.25,
"rtf": 0.44,
"elapsed_time": 314.02610969543457
},
{
"tokens": 700,
"processing_time": 103.0810194015503,
"output_length": 222.075,
"rtf": 0.46,
"elapsed_time": 417.5678551197052
},
{
"tokens": 800,
"processing_time": 127.02162909507751,
"output_length": 253.85,
"rtf": 0.5,
"elapsed_time": 545.0128681659698
},
{
"tokens": 900,
"processing_time": 130.49781227111816,
"output_length": 283.775,
"rtf": 0.46,
"elapsed_time": 675.8943417072296
},
{
"tokens": 1000,
"processing_time": 154.76425909996033,
"output_length": 315.475,
"rtf": 0.49,
"elapsed_time": 831.0677945613861
}
],
"system_metrics": [
{
"timestamp": "2025-01-03T00:23:52.896889",
"cpu_percent": 4.5,
"ram_percent": 39.1,
"ram_used_gb": 24.86032485961914,
"gpu_memory_used": 1281.0
},
{
"timestamp": "2025-01-03T00:24:07.429461",
"cpu_percent": 4.5,
"ram_percent": 39.1,
"ram_used_gb": 24.847564697265625,
"gpu_memory_used": 1285.0
},
{
"timestamp": "2025-01-03T00:24:07.620587",
"cpu_percent": 2.7,
"ram_percent": 39.1,
"ram_used_gb": 24.846607208251953,
"gpu_memory_used": 1275.0
},
{
"timestamp": "2025-01-03T00:24:36.140754",
"cpu_percent": 5.4,
"ram_percent": 39.1,
"ram_used_gb": 24.857810974121094,
"gpu_memory_used": 1267.0
},
{
"timestamp": "2025-01-03T00:24:36.340675",
"cpu_percent": 6.2,
"ram_percent": 39.1,
"ram_used_gb": 24.85773468017578,
"gpu_memory_used": 1267.0
},
{
"timestamp": "2025-01-03T00:25:19.905634",
"cpu_percent": 29.1,
"ram_percent": 39.2,
"ram_used_gb": 24.920318603515625,
"gpu_memory_used": 1256.0
},
{
"timestamp": "2025-01-03T00:25:20.182219",
"cpu_percent": 20.0,
"ram_percent": 39.2,
"ram_used_gb": 24.930198669433594,
"gpu_memory_used": 1256.0
},
{
"timestamp": "2025-01-03T00:26:31.414760",
"cpu_percent": 5.3,
"ram_percent": 39.5,
"ram_used_gb": 25.127891540527344,
"gpu_memory_used": 1259.0
},
{
"timestamp": "2025-01-03T00:26:31.617256",
"cpu_percent": 3.6,
"ram_percent": 39.5,
"ram_used_gb": 25.126346588134766,
"gpu_memory_used": 1252.0
},
{
"timestamp": "2025-01-03T00:27:42.736097",
"cpu_percent": 10.5,
"ram_percent": 39.5,
"ram_used_gb": 25.100231170654297,
"gpu_memory_used": 1249.0
},
{
"timestamp": "2025-01-03T00:27:42.912870",
"cpu_percent": 5.3,
"ram_percent": 39.5,
"ram_used_gb": 25.098285675048828,
"gpu_memory_used": 1249.0
},
{
"timestamp": "2025-01-03T00:29:06.725264",
"cpu_percent": 8.9,
"ram_percent": 39.5,
"ram_used_gb": 25.123123168945312,
"gpu_memory_used": 1239.0
},
{
"timestamp": "2025-01-03T00:29:06.928826",
"cpu_percent": 5.5,
"ram_percent": 39.5,
"ram_used_gb": 25.128646850585938,
"gpu_memory_used": 1239.0
},
{
"timestamp": "2025-01-03T00:30:50.206349",
"cpu_percent": 49.6,
"ram_percent": 39.6,
"ram_used_gb": 25.162948608398438,
"gpu_memory_used": 1245.0
},
{
"timestamp": "2025-01-03T00:30:50.491837",
"cpu_percent": 14.8,
"ram_percent": 39.5,
"ram_used_gb": 25.13379669189453,
"gpu_memory_used": 1245.0
},
{
"timestamp": "2025-01-03T00:32:57.721467",
"cpu_percent": 6.2,
"ram_percent": 39.6,
"ram_used_gb": 25.187721252441406,
"gpu_memory_used": 1384.0
},
{
"timestamp": "2025-01-03T00:32:57.913350",
"cpu_percent": 3.6,
"ram_percent": 39.6,
"ram_used_gb": 25.199390411376953,
"gpu_memory_used": 1384.0
},
{
"timestamp": "2025-01-03T00:35:08.608730",
"cpu_percent": 6.3,
"ram_percent": 39.8,
"ram_used_gb": 25.311710357666016,
"gpu_memory_used": 1330.0
},
{
"timestamp": "2025-01-03T00:35:08.791851",
"cpu_percent": 5.3,
"ram_percent": 39.8,
"ram_used_gb": 25.326683044433594,
"gpu_memory_used": 1333.0
},
{
"timestamp": "2025-01-03T00:37:43.782406",
"cpu_percent": 6.8,
"ram_percent": 40.6,
"ram_used_gb": 25.803058624267578,
"gpu_memory_used": 1409.0
}
]
}

View file

@ -0,0 +1,300 @@
{
"results": [
{
"tokens": 100,
"processing_time": 0.96,
"output_length": 31.1,
"rtf": 0.03,
"elapsed_time": 1.11
},
{
"tokens": 250,
"processing_time": 2.23,
"output_length": 77.17,
"rtf": 0.03,
"elapsed_time": 3.49
},
{
"tokens": 400,
"processing_time": 4.05,
"output_length": 128.05,
"rtf": 0.03,
"elapsed_time": 7.77
},
{
"tokens": 550,
"processing_time": 4.06,
"output_length": 171.45,
"rtf": 0.02,
"elapsed_time": 12.0
},
{
"tokens": 700,
"processing_time": 6.01,
"output_length": 221.6,
"rtf": 0.03,
"elapsed_time": 18.16
},
{
"tokens": 850,
"processing_time": 6.9,
"output_length": 269.1,
"rtf": 0.03,
"elapsed_time": 25.21
},
{
"tokens": 1000,
"processing_time": 7.65,
"output_length": 315.05,
"rtf": 0.02,
"elapsed_time": 33.03
},
{
"tokens": 6000,
"processing_time": 48.7,
"output_length": 1837.1,
"rtf": 0.03,
"elapsed_time": 82.21
},
{
"tokens": 11000,
"processing_time": 92.44,
"output_length": 3388.57,
"rtf": 0.03,
"elapsed_time": 175.46
},
{
"tokens": 16000,
"processing_time": 163.61,
"output_length": 4977.32,
"rtf": 0.03,
"elapsed_time": 340.46
},
{
"tokens": 21000,
"processing_time": 209.72,
"output_length": 6533.3,
"rtf": 0.03,
"elapsed_time": 551.92
},
{
"tokens": 26000,
"processing_time": 329.35,
"output_length": 8068.15,
"rtf": 0.04,
"elapsed_time": 883.37
},
{
"tokens": 31000,
"processing_time": 473.52,
"output_length": 9611.48,
"rtf": 0.05,
"elapsed_time": 1359.28
},
{
"tokens": 36000,
"processing_time": 650.98,
"output_length": 11157.15,
"rtf": 0.06,
"elapsed_time": 2012.9
}
],
"system_metrics": [
{
"timestamp": "2025-01-03T14:41:01.331735",
"cpu_percent": 7.5,
"ram_percent": 50.2,
"ram_used_gb": 31.960269927978516,
"gpu_memory_used": 3191.0
},
{
"timestamp": "2025-01-03T14:41:02.357116",
"cpu_percent": 17.01,
"ram_percent": 50.2,
"ram_used_gb": 31.96163558959961,
"gpu_memory_used": 3426.0
},
{
"timestamp": "2025-01-03T14:41:02.445009",
"cpu_percent": 9.5,
"ram_percent": 50.3,
"ram_used_gb": 31.966781616210938,
"gpu_memory_used": 3426.0
},
{
"timestamp": "2025-01-03T14:41:04.742152",
"cpu_percent": 18.27,
"ram_percent": 50.4,
"ram_used_gb": 32.08788299560547,
"gpu_memory_used": 3642.0
},
{
"timestamp": "2025-01-03T14:41:04.847795",
"cpu_percent": 16.27,
"ram_percent": 50.5,
"ram_used_gb": 32.094364166259766,
"gpu_memory_used": 3640.0
},
{
"timestamp": "2025-01-03T14:41:09.019590",
"cpu_percent": 15.97,
"ram_percent": 50.7,
"ram_used_gb": 32.23244094848633,
"gpu_memory_used": 3640.0
},
{
"timestamp": "2025-01-03T14:41:09.110324",
"cpu_percent": 3.54,
"ram_percent": 50.7,
"ram_used_gb": 32.234458923339844,
"gpu_memory_used": 3640.0
},
{
"timestamp": "2025-01-03T14:41:13.252607",
"cpu_percent": 13.4,
"ram_percent": 50.6,
"ram_used_gb": 32.194271087646484,
"gpu_memory_used": 3935.0
},
{
"timestamp": "2025-01-03T14:41:13.327557",
"cpu_percent": 4.69,
"ram_percent": 50.6,
"ram_used_gb": 32.191776275634766,
"gpu_memory_used": 3935.0
},
{
"timestamp": "2025-01-03T14:41:19.413633",
"cpu_percent": 12.92,
"ram_percent": 50.9,
"ram_used_gb": 32.3467903137207,
"gpu_memory_used": 4250.0
},
{
"timestamp": "2025-01-03T14:41:19.492758",
"cpu_percent": 7.5,
"ram_percent": 50.8,
"ram_used_gb": 32.34375,
"gpu_memory_used": 4250.0
},
{
"timestamp": "2025-01-03T14:41:26.467284",
"cpu_percent": 13.09,
"ram_percent": 51.2,
"ram_used_gb": 32.56281280517578,
"gpu_memory_used": 4249.0
},
{
"timestamp": "2025-01-03T14:41:26.553559",
"cpu_percent": 8.39,
"ram_percent": 51.2,
"ram_used_gb": 32.56183624267578,
"gpu_memory_used": 4249.0
},
{
"timestamp": "2025-01-03T14:41:34.284362",
"cpu_percent": 12.61,
"ram_percent": 51.7,
"ram_used_gb": 32.874778747558594,
"gpu_memory_used": 4250.0
},
{
"timestamp": "2025-01-03T14:41:34.362353",
"cpu_percent": 1.25,
"ram_percent": 51.7,
"ram_used_gb": 32.87461471557617,
"gpu_memory_used": 4250.0
},
{
"timestamp": "2025-01-03T14:42:23.471312",
"cpu_percent": 11.64,
"ram_percent": 54.9,
"ram_used_gb": 34.90264129638672,
"gpu_memory_used": 4647.0
},
{
"timestamp": "2025-01-03T14:42:23.547203",
"cpu_percent": 5.31,
"ram_percent": 54.9,
"ram_used_gb": 34.91563415527344,
"gpu_memory_used": 4647.0
},
{
"timestamp": "2025-01-03T14:43:56.724933",
"cpu_percent": 12.97,
"ram_percent": 59.5,
"ram_used_gb": 37.84241485595703,
"gpu_memory_used": 4655.0
},
{
"timestamp": "2025-01-03T14:43:56.815453",
"cpu_percent": 11.75,
"ram_percent": 59.5,
"ram_used_gb": 37.832679748535156,
"gpu_memory_used": 4655.0
},
{
"timestamp": "2025-01-03T14:46:41.705155",
"cpu_percent": 12.94,
"ram_percent": 66.3,
"ram_used_gb": 42.1534538269043,
"gpu_memory_used": 4729.0
},
{
"timestamp": "2025-01-03T14:46:41.835177",
"cpu_percent": 7.73,
"ram_percent": 66.2,
"ram_used_gb": 42.13554000854492,
"gpu_memory_used": 4729.0
},
{
"timestamp": "2025-01-03T14:50:13.166236",
"cpu_percent": 11.62,
"ram_percent": 73.4,
"ram_used_gb": 46.71288299560547,
"gpu_memory_used": 4676.0
},
{
"timestamp": "2025-01-03T14:50:13.261611",
"cpu_percent": 8.16,
"ram_percent": 73.4,
"ram_used_gb": 46.71356201171875,
"gpu_memory_used": 4676.0
},
{
"timestamp": "2025-01-03T14:55:44.623607",
"cpu_percent": 12.92,
"ram_percent": 82.8,
"ram_used_gb": 52.65533447265625,
"gpu_memory_used": 4636.0
},
{
"timestamp": "2025-01-03T14:55:44.735410",
"cpu_percent": 15.29,
"ram_percent": 82.7,
"ram_used_gb": 52.63290786743164,
"gpu_memory_used": 4636.0
},
{
"timestamp": "2025-01-03T15:03:40.534449",
"cpu_percent": 13.88,
"ram_percent": 85.0,
"ram_used_gb": 54.050071716308594,
"gpu_memory_used": 4771.0
},
{
"timestamp": "2025-01-03T15:03:40.638708",
"cpu_percent": 12.21,
"ram_percent": 85.0,
"ram_used_gb": 54.053733825683594,
"gpu_memory_used": 4771.0
},
{
"timestamp": "2025-01-03T15:14:34.159142",
"cpu_percent": 14.51,
"ram_percent": 78.1,
"ram_used_gb": 49.70396423339844,
"gpu_memory_used": 4739.0
}
]
}

View file

@ -0,0 +1,19 @@
=== Benchmark Statistics (with correct RTF) ===
Overall Stats:
Total tokens processed: 5500
Total audio generated: 1741.65s
Total test duration: 831.07s
Average processing rate: 6.72 tokens/second
Average RTF: 0.47x
Per-chunk Stats:
Average chunk size: 550.00 tokens
Min chunk size: 100.00 tokens
Max chunk size: 1000.00 tokens
Average processing time: 82.70s
Average output length: 174.17s
Performance Ranges:
Processing rate range: 5.63 - 7.17 tokens/second
RTF range: 0.44x - 0.55x

View file

@ -0,0 +1,9 @@
=== Benchmark Statistics (with correct RTF) ===
Overall Stats:
Total tokens processed: 150850
Total audio generated: 46786.59s
Total test duration: 2012.90s
Average processing rate: 104.34 tokens/second
Average RTF: 0.03x

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,23 @@
=== Benchmark Statistics (with correct RTF) ===
Total tokens processed: 1800
Total audio generated (s): 568.53
Total test duration (s): 244.10
Average processing rate (tokens/s): 7.34
Average RTF: 0.43
Average Real Time Speed: 2.33
=== Per-chunk Stats ===
Average chunk size (tokens): 600.00
Min chunk size (tokens): 300
Max chunk size (tokens): 900
Average processing time (s): 81.30
Average output length (s): 189.51
=== Performance Ranges ===
Processing rate range (tokens/s): 7.21 - 7.47
RTF range: 0.43x - 0.43x
Real Time Speed range: 2.33x - 2.33x

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,23 @@
=== Benchmark Statistics (with correct RTF) ===
Total tokens processed: 17150
Total audio generated (s): 5296.38
Total test duration (s): 155.23
Average processing rate (tokens/s): 102.86
Average RTF: 0.03
Average Real Time Speed: 31.25
=== Per-chunk Stats ===
Average chunk size (tokens): 1715.00
Min chunk size (tokens): 150
Max chunk size (tokens): 5000
Average processing time (s): 15.39
Average output length (s): 529.64
=== Performance Ranges ===
Processing rate range (tokens/s): 80.65 - 125.10
RTF range: 0.03x - 0.04x
Real Time Speed range: 25.00x - 33.33x

Binary file not shown.

After

Width:  |  Height:  |  Size: 231 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 181 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 454 KiB

View file

Before

Width:  |  Height:  |  Size: 764 KiB

After

Width:  |  Height:  |  Size: 764 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 238 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 250 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 459 KiB

View file

Before

Width:  |  Height:  |  Size: 198 KiB

After

Width:  |  Height:  |  Size: 198 KiB

View file

@ -332,8 +332,8 @@ def main():
)
parser.add_argument("--url", default="http://localhost:8880", help="API base URL")
parser.add_argument(
"--output-dir",
default="examples/output",
"--output-dir",
default="examples/assorted_checks/test_combinations/output",
help="Output directory for audio files",
)
args = parser.parse_args()

View file

@ -0,0 +1,231 @@
import numpy as np
import soundfile as sf
import argparse
from pathlib import Path
def validate_tts(wav_path: str) -> dict:
"""
Quick validation checks for TTS-generated audio files to detect common artifacts.
Checks for:
- Unnatural silence gaps
- Audio glitches and artifacts
- Repeated speech segments (stuck/looping)
- Abrupt changes in speech
- Audio quality issues
Args:
wav_path: Path to audio file (wav, mp3, etc)
Returns:
Dictionary with validation results
"""
try:
# Load audio
audio, sr = sf.read(wav_path)
if len(audio.shape) > 1:
audio = audio.mean(axis=1) # Convert to mono
# Basic audio stats
duration = len(audio) / sr
rms = np.sqrt(np.mean(audio**2))
peak = np.max(np.abs(audio))
dc_offset = np.mean(audio)
# Calculate clipping stats if we're near peak
clip_count = np.sum(np.abs(audio) >= 0.99)
clip_percent = (clip_count / len(audio)) * 100
if clip_percent > 0:
clip_stats = f" ({clip_percent:.2e} ratio near peak)"
else:
clip_stats = " (no samples near peak)"
# Convert to dB for analysis
eps = np.finfo(float).eps
db = 20 * np.log10(np.abs(audio) + eps)
issues = []
# Check if audio is too short (likely failed generation)
if duration < 0.1: # Less than 100ms
issues.append("WARNING: Audio is suspiciously short - possible failed generation")
# 1. Check for basic audio quality
if peak >= 1.0:
# Calculate percentage of samples that are clipping
clip_count = np.sum(np.abs(audio) >= 0.99)
clip_percent = (clip_count / len(audio)) * 100
if clip_percent > 1.0: # Only warn if more than 1% of samples clip
issues.append(f"WARNING: Significant clipping detected ({clip_percent:.2e}% of samples)")
elif clip_percent > 0.01: # Add info if more than 0.01% but less than 1%
issues.append(f"INFO: Minor peak limiting detected ({clip_percent:.2e}% of samples) - likely intentional normalization")
if rms < 0.01:
issues.append("WARNING: Audio is very quiet - possible failed generation")
if abs(dc_offset) > 0.1: # DC offset is particularly bad for speech
issues.append(f"WARNING: High DC offset ({dc_offset:.3f}) - may cause audio artifacts")
# 2. Check for long silence gaps (potential TTS failures)
silence_threshold = -45 # dB
min_silence = 2.0 # Only detect silences longer than 2 seconds
window_size = int(min_silence * sr)
silence_count = 0
last_silence = -1
# Skip the first 0.2s for silence detection (avoid false positives at start)
start_idx = int(0.2 * sr)
for i in range(start_idx, len(db) - window_size, window_size):
window = db[i:i+window_size]
if np.mean(window) < silence_threshold:
# Verify the entire window is mostly silence
silent_ratio = np.mean(window < silence_threshold)
if silent_ratio > 0.9: # 90% of the window should be below threshold
if last_silence == -1 or (i/sr - last_silence) > 2.0: # Only count silences more than 2s apart
silence_count += 1
last_silence = i/sr
issues.append(f"WARNING: Long silence detected at {i/sr:.2f}s (duration: {min_silence:.1f}s)")
if silence_count > 2: # Only warn if there are multiple long silences
issues.append(f"WARNING: Multiple long silences found ({silence_count} total) - possible generation issue")
# 3. Check for extreme audio artifacts (changes too rapid for natural speech)
# Use a longer window to avoid flagging normal phoneme transitions
window_size = int(0.02 * sr) # 20ms window
db_smooth = np.convolve(db, np.ones(window_size)/window_size, 'same')
db_diff = np.abs(np.diff(db_smooth))
# Much higher threshold to only catch truly unnatural changes
artifact_threshold = 40 # dB
min_duration = int(0.01 * sr) # Minimum 10ms duration
# Find regions where the smoothed dB change is extreme
artifact_points = np.where(db_diff > artifact_threshold)[0]
if len(artifact_points) > 0:
# Group artifacts that are very close together
grouped_artifacts = []
current_group = [artifact_points[0]]
for i in range(1, len(artifact_points)):
if (artifact_points[i] - current_group[-1]) < min_duration:
current_group.append(artifact_points[i])
else:
if len(current_group) * (1/sr) >= 0.01: # Only keep groups lasting >= 10ms
grouped_artifacts.append(current_group)
current_group = [artifact_points[i]]
if len(current_group) * (1/sr) >= 0.01:
grouped_artifacts.append(current_group)
# Report only the most severe artifacts
for group in grouped_artifacts[:2]: # Report up to 2 worst artifacts
center_idx = group[len(group)//2]
db_change = db_diff[center_idx]
if db_change > 45: # Only report very extreme changes
issues.append(
f"WARNING: Possible audio artifact at {center_idx/sr:.2f}s "
f"({db_change:.1f}dB change over {len(group)/sr*1000:.0f}ms)"
)
# 4. Check for repeated speech segments (stuck/looping)
# Check both short and long sentence durations at audiobook speed (150-160 wpm)
for chunk_duration in [5.0, 10.0]: # 5s (~12 words) and 10s (~25 words) at ~audiobook speed
chunk_size = int(chunk_duration * sr)
overlap = int(0.2 * chunk_size) # 20% overlap between chunks
for i in range(0, len(audio) - 2*chunk_size, overlap):
chunk1 = audio[i:i+chunk_size]
chunk2 = audio[i+chunk_size:i+2*chunk_size]
# Ignore chunks that are mostly silence
if np.mean(np.abs(chunk1)) < 0.01 or np.mean(np.abs(chunk2)) < 0.01:
continue
try:
correlation = np.corrcoef(chunk1, chunk2)[0,1]
if not np.isnan(correlation) and correlation > 0.92: # Lower threshold for sentence-length chunks
issues.append(
f"WARNING: Possible repeated speech at {i/sr:.1f}s "
f"(~{int(chunk_duration*160/60):d} words, correlation: {correlation:.3f})"
)
break # Found repetition at this duration, try next duration
except Exception:  # skip chunks where the correlation cannot be computed
continue
# 5. Check for extreme amplitude discontinuities (common in failed TTS)
amplitude_envelope = np.abs(audio)
window_size = sr // 10 # 100ms window for smoother envelope
smooth_env = np.convolve(amplitude_envelope, np.ones(window_size)/float(window_size), 'same')
env_diff = np.abs(np.diff(smooth_env))
# Only detect very extreme amplitude changes
jump_threshold = 0.5 # Much higher threshold
jumps = np.where(env_diff > jump_threshold)[0]
if len(jumps) > 0:
# Group jumps that are close together
grouped_jumps = []
current_group = [jumps[0]]
for i in range(1, len(jumps)):
if (jumps[i] - current_group[-1]) < 0.05 * sr: # Group within 50ms
current_group.append(jumps[i])
else:
if len(current_group) >= 3: # Only keep significant discontinuities
grouped_jumps.append(current_group)
current_group = [jumps[i]]
if len(current_group) >= 3:
grouped_jumps.append(current_group)
# Report only the most severe discontinuities
for group in grouped_jumps[:2]: # Report up to 2 worst cases
center_idx = group[len(group)//2]
jump_size = env_diff[center_idx]
if jump_size > 0.6: # Only report very extreme changes
issues.append(
f"WARNING: Possible audio discontinuity at {center_idx/sr:.2f}s "
f"({jump_size:.2f} amplitude ratio change)"
)
return {
"file": wav_path,
"duration": f"{duration:.2f}s",
"sample_rate": sr,
"peak_amplitude": f"{peak:.3f}{clip_stats}",
"rms_level": f"{rms:.3f}",
"dc_offset": f"{dc_offset:.3f}",
"issues": issues,
"valid": len(issues) == 0
}
except Exception as e:
return {
"file": wav_path,
"error": str(e),
"valid": False
}
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="TTS Output Validator")
parser.add_argument("wav_file", help="Path to audio file to validate")
args = parser.parse_args()
result = validate_tts(args.wav_file)
print(f"\nValidating: {result['file']}")
if "error" in result:
print(f"Error: {result['error']}")
else:
print(f"Duration: {result['duration']}")
print(f"Sample Rate: {result['sample_rate']} Hz")
print(f"Peak Amplitude: {result['peak_amplitude']}")
print(f"RMS Level: {result['rms_level']}")
print(f"DC Offset: {result['dc_offset']}")
if result["issues"]:
print("\nIssues Found:")
for issue in result["issues"]:
print(f"- {issue}")
else:
print("\nNo issues found")

View file

@ -0,0 +1,72 @@
import argparse
from pathlib import Path
from validate_wav import validate_tts
def print_validation_result(result: dict, rel_path: Path):
"""Print full validation details for a single file."""
print(f"\nValidating: {rel_path}")
if "error" in result:
print(f"Error: {result['error']}")
else:
print(f"Duration: {result['duration']}")
print(f"Sample Rate: {result['sample_rate']} Hz")
print(f"Peak Amplitude: {result['peak_amplitude']}")
print(f"RMS Level: {result['rms_level']}")
print(f"DC Offset: {result['dc_offset']}")
if result["issues"]:
print("\nIssues Found:")
for issue in result["issues"]:
print(f"- {issue}")
else:
print("\nNo issues found")
def validate_directory(directory: str):
"""Validate all wav files in a directory with detailed output and summary."""
dir_path = Path(directory)
# Find all wav files (including nested directories)
wav_files = list(dir_path.rglob("*.wav"))
wav_files.extend(dir_path.rglob("*.mp3")) # Also check mp3s
wav_files = sorted(wav_files)
if not wav_files:
print(f"No .wav or .mp3 files found in {directory}")
return
print(f"Found {len(wav_files)} files in {directory}")
print("=" * 80)
# Store results for summary
results = []
# Detailed validation output
for wav_file in wav_files:
result = validate_tts(str(wav_file))
rel_path = wav_file.relative_to(dir_path)
print_validation_result(result, rel_path)
results.append((rel_path, result))
print("=" * 80)
# Summary with detailed issues
print("\nSUMMARY:")
for rel_path, result in results:
if "error" in result:
print(f"{rel_path}: ERROR - {result['error']}")
elif result["issues"]:
# Show first issue in summary, indicate if there are more
issues = result["issues"]
first_issue = issues[0].replace("WARNING: ", "")
if len(issues) > 1:
print(f"{rel_path}: FAIL - {first_issue} (+{len(issues)-1} more issues)")
else:
print(f"{rel_path}: FAIL - {first_issue}")
else:
print(f"{rel_path}: PASS")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Batch validate TTS wav files")
parser.add_argument("directory", help="Directory containing wav files to validate")
args = parser.parse_args()
validate_directory(args.directory)
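
The batch checker can be used the same way; a sketch under the assumption that it is saved as `validate_wavs.py` (the filename is not visible in this diff) and that the target directory exists:

```python
# Validates every .wav/.mp3 under a directory, printing per-file details
# and then a PASS/FAIL summary. Equivalent to: python validate_wavs.py <directory>
# "validate_wavs" (module name) and the directory path are illustrative assumptions.
from validate_wavs import validate_directory

validate_directory("output_audio")
```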

Binary file not shown (removed image, 754 KiB).

View file

@ -1,531 +0,0 @@
{
"results": [
{
"tokens": 100,
"processing_time": 8.54442310333252,
"output_length": 31.15,
"realtime_factor": 3.6456527987068887,
"elapsed_time": 8.720048666000366
},
{
"tokens": 200,
"processing_time": 1.3838517665863037,
"output_length": 62.6,
"realtime_factor": 45.236058883981606,
"elapsed_time": 10.258155345916748
},
{
"tokens": 300,
"processing_time": 2.2024788856506348,
"output_length": 96.325,
"realtime_factor": 43.73481200095347,
"elapsed_time": 12.594647407531738
},
{
"tokens": 400,
"processing_time": 3.175424098968506,
"output_length": 128.55,
"realtime_factor": 40.48278150995886,
"elapsed_time": 16.005898475646973
},
{
"tokens": 500,
"processing_time": 3.205301523208618,
"output_length": 158.55,
"realtime_factor": 49.46492517224587,
"elapsed_time": 19.377076625823975
},
{
"tokens": 600,
"processing_time": 3.9976348876953125,
"output_length": 189.225,
"realtime_factor": 47.33423769700254,
"elapsed_time": 23.568575859069824
},
{
"tokens": 700,
"processing_time": 4.98036003112793,
"output_length": 222.05,
"realtime_factor": 44.58513011351734,
"elapsed_time": 28.767319917678833
},
{
"tokens": 800,
"processing_time": 5.156893491744995,
"output_length": 253.825,
"realtime_factor": 49.22052402406907,
"elapsed_time": 34.1369092464447
},
{
"tokens": 900,
"processing_time": 5.8110880851745605,
"output_length": 283.75,
"realtime_factor": 48.82906537312906,
"elapsed_time": 40.16419458389282
},
{
"tokens": 1000,
"processing_time": 6.686216354370117,
"output_length": 315.45,
"realtime_factor": 47.17914935460046,
"elapsed_time": 47.11375427246094
},
{
"tokens": 2000,
"processing_time": 13.290695905685425,
"output_length": 624.925,
"realtime_factor": 47.01973504131358,
"elapsed_time": 60.842002630233765
},
{
"tokens": 3000,
"processing_time": 20.058005571365356,
"output_length": 932.05,
"realtime_factor": 46.46773063671828,
"elapsed_time": 81.50969815254211
},
{
"tokens": 4000,
"processing_time": 26.38338828086853,
"output_length": 1222.975,
"realtime_factor": 46.353978002394015,
"elapsed_time": 108.76348638534546
},
{
"tokens": 5000,
"processing_time": 32.472310066223145,
"output_length": 1525.15,
"realtime_factor": 46.967708699801484,
"elapsed_time": 142.2994668483734
},
{
"tokens": 6000,
"processing_time": 42.67592263221741,
"output_length": 1837.525,
"realtime_factor": 43.0576514030137,
"elapsed_time": 186.26759266853333
},
{
"tokens": 7000,
"processing_time": 51.601537466049194,
"output_length": 2146.875,
"realtime_factor": 41.60486499869347,
"elapsed_time": 239.59922289848328
},
{
"tokens": 8000,
"processing_time": 51.86434292793274,
"output_length": 2458.425,
"realtime_factor": 47.401063258741466,
"elapsed_time": 293.4462616443634
},
{
"tokens": 9000,
"processing_time": 60.4497971534729,
"output_length": 2772.1,
"realtime_factor": 45.857887545297416,
"elapsed_time": 356.02399826049805
},
{
"tokens": 10000,
"processing_time": 71.75962543487549,
"output_length": 3085.625,
"realtime_factor": 42.99945800024164,
"elapsed_time": 430.50863671302795
},
{
"tokens": 11000,
"processing_time": 96.66409230232239,
"output_length": 3389.3,
"realtime_factor": 35.062657904030935,
"elapsed_time": 529.3296246528625
},
{
"tokens": 12000,
"processing_time": 85.70126295089722,
"output_length": 3703.175,
"realtime_factor": 43.21027336693678,
"elapsed_time": 618.0248212814331
},
{
"tokens": 13000,
"processing_time": 97.2874686717987,
"output_length": 4030.825,
"realtime_factor": 41.43210893479068,
"elapsed_time": 717.9070522785187
},
{
"tokens": 14000,
"processing_time": 105.1045708656311,
"output_length": 4356.775,
"realtime_factor": 41.451812838566596,
"elapsed_time": 826.1140224933624
},
{
"tokens": 15000,
"processing_time": 111.0716404914856,
"output_length": 4663.325,
"realtime_factor": 41.984839508672565,
"elapsed_time": 940.0645899772644
},
{
"tokens": 16000,
"processing_time": 116.61742973327637,
"output_length": 4978.65,
"realtime_factor": 42.692160266154104,
"elapsed_time": 1061.1957621574402
}
],
"system_metrics": [
{
"timestamp": "2024-12-31T03:12:36.009478",
"cpu_percent": 8.1,
"ram_percent": 66.8,
"ram_used_gb": 42.47850799560547,
"gpu_memory_used": 2124.0
},
{
"timestamp": "2024-12-31T03:12:44.639678",
"cpu_percent": 7.7,
"ram_percent": 69.1,
"ram_used_gb": 43.984352111816406,
"gpu_memory_used": 3486.0
},
{
"timestamp": "2024-12-31T03:12:44.731107",
"cpu_percent": 8.3,
"ram_percent": 69.1,
"ram_used_gb": 43.97468948364258,
"gpu_memory_used": 3484.0
},
{
"timestamp": "2024-12-31T03:12:46.189723",
"cpu_percent": 14.2,
"ram_percent": 69.1,
"ram_used_gb": 43.98275375366211,
"gpu_memory_used": 3697.0
},
{
"timestamp": "2024-12-31T03:12:46.265437",
"cpu_percent": 4.7,
"ram_percent": 69.1,
"ram_used_gb": 43.982975006103516,
"gpu_memory_used": 3697.0
},
{
"timestamp": "2024-12-31T03:12:48.536216",
"cpu_percent": 12.5,
"ram_percent": 69.0,
"ram_used_gb": 43.86142349243164,
"gpu_memory_used": 3697.0
},
{
"timestamp": "2024-12-31T03:12:48.603827",
"cpu_percent": 6.2,
"ram_percent": 69.0,
"ram_used_gb": 43.8692626953125,
"gpu_memory_used": 3694.0
},
{
"timestamp": "2024-12-31T03:12:51.905764",
"cpu_percent": 14.2,
"ram_percent": 69.1,
"ram_used_gb": 43.93961715698242,
"gpu_memory_used": 3690.0
},
{
"timestamp": "2024-12-31T03:12:52.028178",
"cpu_percent": 26.0,
"ram_percent": 69.1,
"ram_used_gb": 43.944759368896484,
"gpu_memory_used": 3690.0
},
{
"timestamp": "2024-12-31T03:12:55.320709",
"cpu_percent": 13.2,
"ram_percent": 69.1,
"ram_used_gb": 43.943058013916016,
"gpu_memory_used": 3685.0
},
{
"timestamp": "2024-12-31T03:12:55.386582",
"cpu_percent": 3.2,
"ram_percent": 69.1,
"ram_used_gb": 43.9305419921875,
"gpu_memory_used": 3685.0
},
{
"timestamp": "2024-12-31T03:12:59.492304",
"cpu_percent": 15.6,
"ram_percent": 69.1,
"ram_used_gb": 43.964195251464844,
"gpu_memory_used": 4053.0
},
{
"timestamp": "2024-12-31T03:12:59.586143",
"cpu_percent": 2.1,
"ram_percent": 69.1,
"ram_used_gb": 43.9642448425293,
"gpu_memory_used": 4053.0
},
{
"timestamp": "2024-12-31T03:13:04.705286",
"cpu_percent": 12.0,
"ram_percent": 69.2,
"ram_used_gb": 43.992374420166016,
"gpu_memory_used": 4059.0
},
{
"timestamp": "2024-12-31T03:13:04.779475",
"cpu_percent": 4.7,
"ram_percent": 69.2,
"ram_used_gb": 43.9922981262207,
"gpu_memory_used": 4059.0
},
{
"timestamp": "2024-12-31T03:13:10.063292",
"cpu_percent": 12.4,
"ram_percent": 69.2,
"ram_used_gb": 44.004146575927734,
"gpu_memory_used": 4041.0
},
{
"timestamp": "2024-12-31T03:13:10.155395",
"cpu_percent": 6.8,
"ram_percent": 69.2,
"ram_used_gb": 44.004215240478516,
"gpu_memory_used": 4041.0
},
{
"timestamp": "2024-12-31T03:13:16.097887",
"cpu_percent": 13.1,
"ram_percent": 69.2,
"ram_used_gb": 44.0260009765625,
"gpu_memory_used": 4042.0
},
{
"timestamp": "2024-12-31T03:13:16.171478",
"cpu_percent": 4.5,
"ram_percent": 69.2,
"ram_used_gb": 44.02027130126953,
"gpu_memory_used": 4042.0
},
{
"timestamp": "2024-12-31T03:13:23.044945",
"cpu_percent": 12.6,
"ram_percent": 69.2,
"ram_used_gb": 44.03746795654297,
"gpu_memory_used": 4044.0
},
{
"timestamp": "2024-12-31T03:13:23.127442",
"cpu_percent": 8.3,
"ram_percent": 69.2,
"ram_used_gb": 44.0373420715332,
"gpu_memory_used": 4044.0
},
{
"timestamp": "2024-12-31T03:13:36.780309",
"cpu_percent": 12.5,
"ram_percent": 69.2,
"ram_used_gb": 44.00790786743164,
"gpu_memory_used": 4034.0
},
{
"timestamp": "2024-12-31T03:13:36.853474",
"cpu_percent": 6.2,
"ram_percent": 69.2,
"ram_used_gb": 44.00779724121094,
"gpu_memory_used": 4034.0
},
{
"timestamp": "2024-12-31T03:13:57.449274",
"cpu_percent": 12.4,
"ram_percent": 69.2,
"ram_used_gb": 44.0432243347168,
"gpu_memory_used": 4034.0
},
{
"timestamp": "2024-12-31T03:13:57.524592",
"cpu_percent": 6.2,
"ram_percent": 69.2,
"ram_used_gb": 44.03204345703125,
"gpu_memory_used": 4034.0
},
{
"timestamp": "2024-12-31T03:14:24.698822",
"cpu_percent": 13.4,
"ram_percent": 69.5,
"ram_used_gb": 44.18327331542969,
"gpu_memory_used": 4480.0
},
{
"timestamp": "2024-12-31T03:14:24.783683",
"cpu_percent": 4.2,
"ram_percent": 69.5,
"ram_used_gb": 44.182212829589844,
"gpu_memory_used": 4480.0
},
{
"timestamp": "2024-12-31T03:14:58.242642",
"cpu_percent": 12.8,
"ram_percent": 69.5,
"ram_used_gb": 44.20225524902344,
"gpu_memory_used": 4476.0
},
{
"timestamp": "2024-12-31T03:14:58.310907",
"cpu_percent": 2.9,
"ram_percent": 69.5,
"ram_used_gb": 44.19659423828125,
"gpu_memory_used": 4476.0
},
{
"timestamp": "2024-12-31T03:15:42.196813",
"cpu_percent": 14.3,
"ram_percent": 69.9,
"ram_used_gb": 44.43781661987305,
"gpu_memory_used": 4494.0
},
{
"timestamp": "2024-12-31T03:15:42.288427",
"cpu_percent": 13.7,
"ram_percent": 69.9,
"ram_used_gb": 44.439701080322266,
"gpu_memory_used": 4494.0
},
{
"timestamp": "2024-12-31T03:16:35.483849",
"cpu_percent": 14.7,
"ram_percent": 65.0,
"ram_used_gb": 41.35385513305664,
"gpu_memory_used": 4506.0
},
{
"timestamp": "2024-12-31T03:16:35.626628",
"cpu_percent": 32.9,
"ram_percent": 65.0,
"ram_used_gb": 41.34442138671875,
"gpu_memory_used": 4506.0
},
{
"timestamp": "2024-12-31T03:17:29.378353",
"cpu_percent": 13.4,
"ram_percent": 64.3,
"ram_used_gb": 40.8721809387207,
"gpu_memory_used": 4485.0
},
{
"timestamp": "2024-12-31T03:17:29.457464",
"cpu_percent": 5.1,
"ram_percent": 64.3,
"ram_used_gb": 40.875389099121094,
"gpu_memory_used": 4485.0
},
{
"timestamp": "2024-12-31T03:18:31.955862",
"cpu_percent": 14.3,
"ram_percent": 65.0,
"ram_used_gb": 41.360206604003906,
"gpu_memory_used": 4484.0
},
{
"timestamp": "2024-12-31T03:18:32.038999",
"cpu_percent": 12.5,
"ram_percent": 65.0,
"ram_used_gb": 41.37223434448242,
"gpu_memory_used": 4484.0
},
{
"timestamp": "2024-12-31T03:19:46.454105",
"cpu_percent": 13.9,
"ram_percent": 65.3,
"ram_used_gb": 41.562198638916016,
"gpu_memory_used": 4487.0
},
{
"timestamp": "2024-12-31T03:19:46.524303",
"cpu_percent": 6.8,
"ram_percent": 65.3,
"ram_used_gb": 41.56681442260742,
"gpu_memory_used": 4487.0
},
{
"timestamp": "2024-12-31T03:21:25.251452",
"cpu_percent": 23.7,
"ram_percent": 62.0,
"ram_used_gb": 39.456459045410156,
"gpu_memory_used": 4488.0
},
{
"timestamp": "2024-12-31T03:21:25.348643",
"cpu_percent": 2.9,
"ram_percent": 62.0,
"ram_used_gb": 39.454288482666016,
"gpu_memory_used": 4487.0
},
{
"timestamp": "2024-12-31T03:22:53.939896",
"cpu_percent": 12.9,
"ram_percent": 62.1,
"ram_used_gb": 39.50320053100586,
"gpu_memory_used": 4488.0
},
{
"timestamp": "2024-12-31T03:22:54.041607",
"cpu_percent": 8.3,
"ram_percent": 62.1,
"ram_used_gb": 39.49895095825195,
"gpu_memory_used": 4488.0
},
{
"timestamp": "2024-12-31T03:24:33.835432",
"cpu_percent": 12.9,
"ram_percent": 62.3,
"ram_used_gb": 39.647212982177734,
"gpu_memory_used": 4503.0
},
{
"timestamp": "2024-12-31T03:24:33.923914",
"cpu_percent": 7.6,
"ram_percent": 62.3,
"ram_used_gb": 39.64302062988281,
"gpu_memory_used": 4503.0
},
{
"timestamp": "2024-12-31T03:26:22.021598",
"cpu_percent": 12.9,
"ram_percent": 58.4,
"ram_used_gb": 37.162540435791016,
"gpu_memory_used": 4491.0
},
{
"timestamp": "2024-12-31T03:26:22.142138",
"cpu_percent": 12.0,
"ram_percent": 58.4,
"ram_used_gb": 37.162010192871094,
"gpu_memory_used": 4487.0
},
{
"timestamp": "2024-12-31T03:28:15.970365",
"cpu_percent": 15.0,
"ram_percent": 58.2,
"ram_used_gb": 37.04011535644531,
"gpu_memory_used": 4481.0
},
{
"timestamp": "2024-12-31T03:28:16.096459",
"cpu_percent": 12.4,
"ram_percent": 58.2,
"ram_used_gb": 37.035972595214844,
"gpu_memory_used": 4473.0
},
{
"timestamp": "2024-12-31T03:30:17.092257",
"cpu_percent": 12.4,
"ram_percent": 58.4,
"ram_used_gb": 37.14639663696289,
"gpu_memory_used": 4459.0
}
]
}

View file

@ -1,19 +0,0 @@
=== Benchmark Statistics ===
Overall Stats:
Total tokens processed: 140500
Total audio generated: 43469.18s
Total test duration: 1061.20s
Average processing rate: 137.67 tokens/second
Average realtime factor: 42.93x
Per-chunk Stats:
Average chunk size: 5620.00 tokens
Min chunk size: 100.00 tokens
Max chunk size: 16000.00 tokens
Average processing time: 41.13s
Average output length: 1738.77s
Performance Ranges:
Processing rate range: 11.70 - 155.99 tokens/second
Realtime factor range: 3.65x - 49.46x
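
Each figure in this stats file is an aggregate of the per-chunk records in the benchmark_results.json shown above; a condensed sketch of that aggregation, mirroring the formulas in the benchmark script removed further below (the path is illustrative):

```python
# Recompute the headline benchmark stats from the per-chunk records.
# Field names match the JSON above and the script below.
import json

with open("benchmark_results.json") as f:
    results = json.load(f)["results"]

total_tokens = sum(r["tokens"] for r in results)
total_audio = sum(r["output_length"] for r in results)
rates = [r["tokens"] / r["processing_time"] for r in results]
rtfs = [r["realtime_factor"] for r in results]

print(f"Total tokens processed: {total_tokens}")
print(f"Total audio generated: {total_audio:.2f}s")
print(f"Average processing rate: {sum(rates) / len(rates):.2f} tokens/second")
print(f"Average realtime factor: {sum(rtfs) / len(rtfs):.2f}x")
print(f"Realtime factor range: {min(rtfs):.2f}x - {max(rtfs):.2f}x")
```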

View file

@ -1,406 +0,0 @@
import os
import json
import time
import subprocess
from datetime import datetime
import pandas as pd
import psutil
import seaborn as sns
import requests
import tiktoken
import scipy.io.wavfile as wavfile
import matplotlib.pyplot as plt
enc = tiktoken.get_encoding("cl100k_base")
def setup_plot(fig, ax, title):
"""Configure plot styling"""
# Improve grid
ax.grid(True, linestyle="--", alpha=0.3, color="#ffffff")
# Set title and labels with better fonts
ax.set_title(title, pad=20, fontsize=16, fontweight="bold", color="#ffffff")
ax.set_xlabel(ax.get_xlabel(), fontsize=14, fontweight="medium", color="#ffffff")
ax.set_ylabel(ax.get_ylabel(), fontsize=14, fontweight="medium", color="#ffffff")
# Improve tick labels
ax.tick_params(labelsize=12, colors="#ffffff")
# Style spines
for spine in ax.spines.values():
spine.set_color("#ffffff")
spine.set_alpha(0.3)
spine.set_linewidth(0.5)
# Set background colors
ax.set_facecolor("#1a1a2e")
fig.patch.set_facecolor("#1a1a2e")
return fig, ax
def get_text_for_tokens(text: str, num_tokens: int) -> str:
"""Get a slice of text that contains exactly num_tokens tokens"""
tokens = enc.encode(text)
if num_tokens > len(tokens):
return text
return enc.decode(tokens[:num_tokens])
def get_audio_length(audio_data: bytes) -> float:
"""Get audio length in seconds from bytes data"""
# Save to a temporary file
temp_path = "examples/benchmarks/output/temp.wav"
os.makedirs(os.path.dirname(temp_path), exist_ok=True)
with open(temp_path, "wb") as f:
f.write(audio_data)
# Read the audio file
try:
rate, data = wavfile.read(temp_path)
return len(data) / rate
finally:
# Clean up temp file
if os.path.exists(temp_path):
os.remove(temp_path)
def get_gpu_memory():
"""Get GPU memory usage using nvidia-smi"""
try:
result = subprocess.check_output(
["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader"]
)
return float(result.decode("utf-8").strip())
except (subprocess.CalledProcessError, FileNotFoundError):
return None
def get_system_metrics():
"""Get current system metrics"""
metrics = {
"timestamp": datetime.now().isoformat(),
"cpu_percent": psutil.cpu_percent(),
"ram_percent": psutil.virtual_memory().percent,
"ram_used_gb": psutil.virtual_memory().used / (1024**3),
}
gpu_mem = get_gpu_memory()
if gpu_mem is not None:
metrics["gpu_memory_used"] = gpu_mem
return metrics
def make_tts_request(text: str, timeout: int = 120) -> tuple[float, float]:
"""Make TTS request using OpenAI-compatible endpoint and return processing time and output length"""
try:
start_time = time.time()
# Make request to OpenAI-compatible endpoint
response = requests.post(
"http://localhost:8880/v1/audio/speech",
json={
"model": "kokoro",
"input": text,
"voice": "af",
"response_format": "wav",
},
timeout=timeout,
)
response.raise_for_status()
processing_time = time.time() - start_time
audio_length = get_audio_length(response.content)
# Save the audio file
token_count = len(enc.encode(text))
output_file = f"examples/benchmarks/output/chunk_{token_count}_tokens.wav"
os.makedirs(os.path.dirname(output_file), exist_ok=True)
with open(output_file, "wb") as f:
f.write(response.content)
print(f"Saved audio to {output_file}")
return processing_time, audio_length
except requests.exceptions.RequestException as e:
print(f"Error making request for text: {text[:50]}... Error: {str(e)}")
return None, None
except Exception as e:
print(f"Error processing text: {text[:50]}... Error: {str(e)}")
return None, None
def plot_system_metrics(metrics_data):
"""Create plots for system metrics over time"""
df = pd.DataFrame(metrics_data)
df["timestamp"] = pd.to_datetime(df["timestamp"])
elapsed_time = (df["timestamp"] - df["timestamp"].iloc[0]).dt.total_seconds()
# Get baseline values (first measurement)
baseline_cpu = df["cpu_percent"].iloc[0]
baseline_ram = df["ram_used_gb"].iloc[0]
baseline_gpu = (
df["gpu_memory_used"].iloc[0] / 1024
if "gpu_memory_used" in df.columns
else None
) # Convert MB to GB
# Convert GPU memory to GB
if "gpu_memory_used" in df.columns:
df["gpu_memory_gb"] = df["gpu_memory_used"] / 1024
# Set plotting style
plt.style.use("dark_background")
# Create figure with 3 subplots (or 2 if no GPU)
has_gpu = "gpu_memory_used" in df.columns
num_plots = 3 if has_gpu else 2
fig, axes = plt.subplots(num_plots, 1, figsize=(15, 5 * num_plots))
fig.patch.set_facecolor("#1a1a2e")
# Apply rolling average for smoothing
window = min(5, len(df) // 2) # Smaller window for smoother lines
# Plot 1: CPU Usage
smoothed_cpu = df["cpu_percent"].rolling(window=window, center=True).mean()
sns.lineplot(
x=elapsed_time, y=smoothed_cpu, ax=axes[0], color="#ff2a6d", linewidth=2
)
axes[0].axhline(
y=baseline_cpu, color="#05d9e8", linestyle="--", alpha=0.5, label="Baseline"
)
axes[0].set_xlabel("Time (seconds)", fontsize=14)
axes[0].set_ylabel("CPU Usage (%)", fontsize=14)
axes[0].tick_params(labelsize=12)
axes[0].set_title("CPU Usage Over Time", pad=20, fontsize=16, fontweight="bold")
axes[0].set_ylim(0, max(df["cpu_percent"]) * 1.1) # Add 10% padding
axes[0].legend()
# Plot 2: RAM Usage
smoothed_ram = df["ram_used_gb"].rolling(window=window, center=True).mean()
sns.lineplot(
x=elapsed_time, y=smoothed_ram, ax=axes[1], color="#05d9e8", linewidth=2
)
axes[1].axhline(
y=baseline_ram, color="#ff2a6d", linestyle="--", alpha=0.5, label="Baseline"
)
axes[1].set_xlabel("Time (seconds)", fontsize=14)
axes[1].set_ylabel("RAM Usage (GB)", fontsize=14)
axes[1].tick_params(labelsize=12)
axes[1].set_title("RAM Usage Over Time", pad=20, fontsize=16, fontweight="bold")
axes[1].set_ylim(0, max(df["ram_used_gb"]) * 1.1) # Add 10% padding
axes[1].legend()
# Plot 3: GPU Memory (if available)
if has_gpu:
smoothed_gpu = df["gpu_memory_gb"].rolling(window=window, center=True).mean()
sns.lineplot(
x=elapsed_time, y=smoothed_gpu, ax=axes[2], color="#ff2a6d", linewidth=2
)
axes[2].axhline(
y=baseline_gpu, color="#05d9e8", linestyle="--", alpha=0.5, label="Baseline"
)
axes[2].set_xlabel("Time (seconds)", fontsize=14)
axes[2].set_ylabel("GPU Memory (GB)", fontsize=14)
axes[2].tick_params(labelsize=12)
axes[2].set_title(
"GPU Memory Usage Over Time", pad=20, fontsize=16, fontweight="bold"
)
axes[2].set_ylim(0, max(df["gpu_memory_gb"]) * 1.1) # Add 10% padding
axes[2].legend()
# Style all subplots
for ax in axes:
ax.grid(True, linestyle="--", alpha=0.3)
ax.set_facecolor("#1a1a2e")
for spine in ax.spines.values():
spine.set_color("#ffffff")
spine.set_alpha(0.3)
plt.tight_layout()
plt.savefig("examples/benchmarks/system_usage.png", dpi=300, bbox_inches="tight")
plt.close()
def main():
# Create output directory
os.makedirs("examples/benchmarks/output", exist_ok=True)
# Read input text
with open(
"examples/benchmarks/the_time_machine_hg_wells.txt", "r", encoding="utf-8"
) as f:
text = f.read()
# Get total tokens in file
total_tokens = len(enc.encode(text))
print(f"Total tokens in file: {total_tokens}")
# Generate token sizes with dense sampling at start and increasing intervals
dense_range = list(range(100, 1001, 100))
current = max(dense_range)
large_range = []
while current <= total_tokens:
large_range.append(current)
current += 1000
token_sizes = sorted(list(set(dense_range + large_range)))
print(f"Testing sizes: {token_sizes}")
# Process chunks
results = []
system_metrics = []
test_start_time = time.time()
for num_tokens in token_sizes:
# Get text slice with exact token count
chunk = get_text_for_tokens(text, num_tokens)
actual_tokens = len(enc.encode(chunk))
print(f"\nProcessing chunk with {actual_tokens} tokens:")
print(f"Text preview: {chunk[:100]}...")
# Collect system metrics before processing
system_metrics.append(get_system_metrics())
processing_time, audio_length = make_tts_request(chunk)
if processing_time is None or audio_length is None:
print("Breaking loop due to error")
break
# Collect system metrics after processing
system_metrics.append(get_system_metrics())
results.append(
{
"tokens": actual_tokens,
"processing_time": processing_time,
"output_length": audio_length,
"realtime_factor": audio_length / processing_time,
"elapsed_time": time.time() - test_start_time,
}
)
# Save intermediate results
with open("examples/benchmarks/benchmark_results.json", "w") as f:
json.dump(
{"results": results, "system_metrics": system_metrics}, f, indent=2
)
# Create DataFrame and calculate stats
df = pd.DataFrame(results)
if df.empty:
print("No data to plot")
return
# Calculate useful metrics
df["tokens_per_second"] = df["tokens"] / df["processing_time"]
# Write detailed stats
with open("examples/benchmarks/benchmark_stats.txt", "w") as f:
f.write("=== Benchmark Statistics ===\n\n")
f.write("Overall Stats:\n")
f.write(f"Total tokens processed: {df['tokens'].sum()}\n")
f.write(f"Total audio generated: {df['output_length'].sum():.2f}s\n")
f.write(f"Total test duration: {df['elapsed_time'].max():.2f}s\n")
f.write(
f"Average processing rate: {df['tokens_per_second'].mean():.2f} tokens/second\n"
)
f.write(f"Average realtime factor: {df['realtime_factor'].mean():.2f}x\n\n")
f.write("Per-chunk Stats:\n")
f.write(f"Average chunk size: {df['tokens'].mean():.2f} tokens\n")
f.write(f"Min chunk size: {df['tokens'].min():.2f} tokens\n")
f.write(f"Max chunk size: {df['tokens'].max():.2f} tokens\n")
f.write(f"Average processing time: {df['processing_time'].mean():.2f}s\n")
f.write(f"Average output length: {df['output_length'].mean():.2f}s\n\n")
f.write("Performance Ranges:\n")
f.write(
f"Processing rate range: {df['tokens_per_second'].min():.2f} - {df['tokens_per_second'].max():.2f} tokens/second\n"
)
f.write(
f"Realtime factor range: {df['realtime_factor'].min():.2f}x - {df['realtime_factor'].max():.2f}x\n"
)
# Set plotting style
plt.style.use("dark_background")
# Plot 1: Processing Time vs Token Count
fig, ax = plt.subplots(figsize=(12, 8))
sns.scatterplot(
data=df, x="tokens", y="processing_time", s=100, alpha=0.6, color="#ff2a6d"
)
sns.regplot(
data=df,
x="tokens",
y="processing_time",
scatter=False,
color="#05d9e8",
line_kws={"linewidth": 2},
)
corr = df["tokens"].corr(df["processing_time"])
plt.text(
0.05,
0.95,
f"Correlation: {corr:.2f}",
transform=ax.transAxes,
fontsize=10,
color="#ffffff",
bbox=dict(facecolor="#1a1a2e", edgecolor="#ffffff", alpha=0.7),
)
setup_plot(fig, ax, "Processing Time vs Input Size")
ax.set_xlabel("Number of Input Tokens")
ax.set_ylabel("Processing Time (seconds)")
plt.savefig("examples/benchmarks/processing_time.png", dpi=300, bbox_inches="tight")
plt.close()
# Plot 2: Realtime Factor vs Token Count
fig, ax = plt.subplots(figsize=(12, 8))
sns.scatterplot(
data=df, x="tokens", y="realtime_factor", s=100, alpha=0.6, color="#ff2a6d"
)
sns.regplot(
data=df,
x="tokens",
y="realtime_factor",
scatter=False,
color="#05d9e8",
line_kws={"linewidth": 2},
)
corr = df["tokens"].corr(df["realtime_factor"])
plt.text(
0.05,
0.95,
f"Correlation: {corr:.2f}",
transform=ax.transAxes,
fontsize=10,
color="#ffffff",
bbox=dict(facecolor="#1a1a2e", edgecolor="#ffffff", alpha=0.7),
)
setup_plot(fig, ax, "Realtime Factor vs Input Size")
ax.set_xlabel("Number of Input Tokens")
ax.set_ylabel("Realtime Factor (output length / processing time)")
plt.savefig("examples/benchmarks/realtime_factor.png", dpi=300, bbox_inches="tight")
plt.close()
# Plot system metrics
plot_system_metrics(system_metrics)
print("\nResults saved to:")
print("- examples/benchmarks/benchmark_results.json")
print("- examples/benchmarks/benchmark_stats.txt")
print("- examples/benchmarks/processing_time.png")
print("- examples/benchmarks/realtime_factor.png")
print("- examples/benchmarks/system_usage.png")
if any("gpu_memory_used" in m for m in system_metrics):
print("- examples/benchmarks/gpu_usage.png")
print("\nAudio files saved in examples/benchmarks/output/")
if __name__ == "__main__":
main()

Binary file not shown (removed image, 283 KiB).

Binary file not shown (removed image, 223 KiB).

Binary file not shown (removed image, 406 KiB).