Merge pull request #7 from remsky/feat/onnx-inference
Feat/onnx inference

- Added some optimization options for ONNX
- Refactored phonemizer/tokenizer into separate services
- Cleaned up benchmarking and check scripts
- Added auto-WAV validation
.coveragerc
@@ -6,6 +1,7 @@ omit =
    Kokoro-82M/*
    MagicMock/*
    test_*.py
    examples/*

[report]
exclude_lines =
.gitignore
@@ -1,5 +1,6 @@
output/
output/*
output_audio/*
ui/data/*

*.db
@@ -16,3 +17,10 @@ env/
.coverage

examples/assorted_checks/benchmarks/output_audio/*
examples/assorted_checks/test_combinations/output/*
examples/assorted_checks/test_openai/output/*

examples/assorted_checks/test_voices/output/*
examples/assorted_checks/test_formats/output/*
ui/RepoScreenshot.png
CHANGELOG.md
@@ -2,6 +2,20 @@
 Notable changes to this project will be documented in this file.

+## 2025-01-04
+### Added
+- ONNX Support:
+  - Added single batch ONNX support for CPU inference
+  - Roughly 0.4 RTF (2.4x real-time speed)
+
+### Modified
+- Code Refactoring:
+  - Work on modularizing phonemizer and tokenizer into separate services
+  - Incorporated these services into a dev endpoint
+- Testing and Benchmarking:
+  - Cleaned up benchmarking scripts
+  - Cleaned up test scripts
+  - Added auto-WAV validation scripts

 ## 2025-01-02
 - Audio Format Support:
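For context on the RTF figure above: real-time factor is compute time divided by audio duration, so 0.4 RTF means audio is produced about 2.5x faster than it plays back (reported here as ~2.4x). A minimal sketch with illustrative numbers:

```python
# Real-time factor: seconds of compute per second of audio produced
def realtime_factor(compute_seconds: float, audio_seconds: float) -> float:
    return compute_seconds / audio_seconds

rtf = realtime_factor(4.0, 10.0)  # 0.4 -> 1 / 0.4 = 2.5x real-time
```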
@@ -10,8 +10,9 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
     && apt-get clean \
     && rm -rf /var/lib/apt/lists/*

-# Install PyTorch CPU version
-RUN pip3 install --no-cache-dir torch==2.5.1 --extra-index-url https://download.pytorch.org/whl/cpu
+# Install PyTorch CPU version and ONNX runtime
+RUN pip3 install --no-cache-dir torch==2.5.1 --extra-index-url https://download.pytorch.org/whl/cpu && \
+    pip3 install --no-cache-dir onnxruntime==1.20.1

 # Install all other dependencies from requirements.txt
 COPY requirements.txt .
README.md
@@ -3,8 +3,8 @@
</p>

# Kokoro TTS API
[]()
[]()
[]()
[]()
[](https://huggingface.co/hexgrad/Kokoro-82M/tree/c3b0d86e2a980e027ef71c28819ea02e351c2667) [](https://huggingface.co/spaces/Remsky/Kokoro-TTS-Zero)

Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model
@@ -187,15 +187,13 @@ Key Performance Metrics:
 <summary>GPU Vs. CPU</summary>

 ```bash
-# GPU: Requires NVIDIA GPU with CUDA 12.1 support
+# GPU: Requires NVIDIA GPU with CUDA 12.1 support (~35x realtime speed)
 docker compose up --build

-# CPU: ~10x slower than GPU inference
+# CPU: ONNX optimized inference (~2.4x realtime speed)
 docker compose -f docker-compose.cpu.yml up --build
 ```

 *Note: CPU Inference is currently a very basic implementation, and not heavily tested*

 </details>

 <details>
api/src/core/config.py
@@ -14,9 +14,18 @@ class Settings(BaseSettings):
     output_dir_size_limit_mb: float = 500.0  # Maximum size of output directory in MB
     default_voice: str = "af"
     model_dir: str = "/app/Kokoro-82M"  # Base directory for model files
-    model_path: str = "kokoro-v0_19.pth"
+    pytorch_model_path: str = "kokoro-v0_19.pth"
+    onnx_model_path: str = "kokoro-v0_19.onnx"
     voices_dir: str = "voices"
     sample_rate: int = 24000

+    # ONNX Optimization Settings
+    onnx_num_threads: int = 4  # Number of threads for intra-op parallelism
+    onnx_inter_op_threads: int = 4  # Number of threads for inter-op parallelism
+    onnx_execution_mode: str = "parallel"  # parallel or sequential
+    onnx_optimization_level: str = "all"  # all, basic, or disabled
+    onnx_memory_pattern: bool = True  # Enable memory pattern optimization
+    onnx_arena_extend_strategy: str = "kNextPowerOfTwo"  # Memory allocation strategy
+
     class Config:
         env_file = ".env"
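Because `Settings` extends pydantic's `BaseSettings` with `env_file = ".env"`, the new ONNX knobs should be tunable per deployment without code changes. An illustrative `.env` (example values, not the defaults above):

```
ONNX_NUM_THREADS=8
ONNX_INTER_OP_THREADS=2
ONNX_EXECUTION_MODE=sequential
ONNX_OPTIMIZATION_LEVEL=basic
ONNX_MEMORY_PATTERN=false
```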
api/src/core/kokoro.py (new file)
@@ -0,0 +1,185 @@
import re

import torch
import phonemizer


def split_num(num):
    num = num.group()
    if "." in num:
        return num
    elif ":" in num:
        h, m = [int(n) for n in num.split(":")]
        if m == 0:
            return f"{h} o'clock"
        elif m < 10:
            return f"{h} oh {m}"
        return f"{h} {m}"
    year = int(num[:4])
    if year < 1100 or year % 1000 < 10:
        return num
    left, right = num[:2], int(num[2:4])
    s = "s" if num.endswith("s") else ""
    if 100 <= year % 1000 <= 999:
        if right == 0:
            return f"{left} hundred{s}"
        elif right < 10:
            return f"{left} oh {right}{s}"
    return f"{left} {right}{s}"


def flip_money(m):
    m = m.group()
    bill = "dollar" if m[0] == "$" else "pound"
    if m[-1].isalpha():
        return f"{m[1:]} {bill}s"
    elif "." not in m:
        s = "" if m[1:] == "1" else "s"
        return f"{m[1:]} {bill}{s}"
    b, c = m[1:].split(".")
    s = "" if b == "1" else "s"
    c = int(c.ljust(2, "0"))
    coins = (
        f"cent{'' if c == 1 else 's'}"
        if m[0] == "$"
        else ("penny" if c == 1 else "pence")
    )
    return f"{b} {bill}{s} and {c} {coins}"


def point_num(num):
    a, b = num.group().split(".")
    return " point ".join([a, " ".join(b)])


def normalize_text(text):
    text = text.replace(chr(8216), "'").replace(chr(8217), "'")
    text = text.replace("«", chr(8220)).replace("»", chr(8221))
    text = text.replace(chr(8220), '"').replace(chr(8221), '"')
    text = text.replace("(", "«").replace(")", "»")
    for a, b in zip("、。！，：；？", ",.!,:;?"):
        text = text.replace(a, b + " ")
    text = re.sub(r"[^\S \n]", " ", text)
    text = re.sub(r" +", " ", text)
    text = re.sub(r"(?<=\n) +(?=\n)", "", text)
    text = re.sub(r"\bD[Rr]\.(?= [A-Z])", "Doctor", text)
    text = re.sub(r"\b(?:Mr\.|MR\.(?= [A-Z]))", "Mister", text)
    text = re.sub(r"\b(?:Ms\.|MS\.(?= [A-Z]))", "Miss", text)
    text = re.sub(r"\b(?:Mrs\.|MRS\.(?= [A-Z]))", "Mrs", text)
    text = re.sub(r"\betc\.(?! [A-Z])", "etc", text)
    text = re.sub(r"(?i)\b(y)eah?\b", r"\1e'a", text)
    text = re.sub(
        r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)", split_num, text
    )
    text = re.sub(r"(?<=\d),(?=\d)", "", text)
    text = re.sub(
        r"(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b",
        flip_money,
        text,
    )
    text = re.sub(r"\d*\.\d+", point_num, text)
    text = re.sub(r"(?<=\d)-(?=\d)", " to ", text)
    text = re.sub(r"(?<=\d)S", " S", text)
    text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
    text = re.sub(r"(?<=X')S\b", "s", text)
    text = re.sub(
        r"(?:[A-Za-z]\.){2,} [a-z]", lambda m: m.group().replace(".", "-"), text
    )
    text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)
    return text.strip()


def get_vocab():
    _pad = "$"
    _punctuation = ';:,.!?¡¿—…"«»“” '
    _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
    _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
    symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
    dicts = {}
    for i in range(len(symbols)):
        dicts[symbols[i]] = i
    return dicts


VOCAB = get_vocab()


def tokenize(ps):
    return [i for i in map(VOCAB.get, ps) if i is not None]


phonemizers = dict(
    a=phonemizer.backend.EspeakBackend(
        language="en-us", preserve_punctuation=True, with_stress=True
    ),
    b=phonemizer.backend.EspeakBackend(
        language="en-gb", preserve_punctuation=True, with_stress=True
    ),
)


def phonemize(text, lang, norm=True):
    if norm:
        text = normalize_text(text)
    ps = phonemizers[lang].phonemize([text])
    ps = ps[0] if ps else ""
    # https://en.wiktionary.org/wiki/kokoro#English
    ps = ps.replace("kəkˈoːɹoʊ", "kˈoʊkəɹoʊ").replace("kəkˈɔːɹəʊ", "kˈəʊkəɹəʊ")
    ps = ps.replace("ʲ", "j").replace("r", "ɹ").replace("x", "k").replace("ɬ", "l")
    ps = re.sub(r"(?<=[a-zɹː])(?=hˈʌndɹɪd)", " ", ps)
    ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»“” ]|$)', "z", ps)
    if lang == "a":
        ps = re.sub(r"(?<=nˈaɪn)ti(?!ː)", "di", ps)
    ps = "".join(filter(lambda p: p in VOCAB, ps))
    return ps.strip()


def length_to_mask(lengths):
    mask = (
        torch.arange(lengths.max())
        .unsqueeze(0)
        .expand(lengths.shape[0], -1)
        .type_as(lengths)
    )
    mask = torch.gt(mask + 1, lengths.unsqueeze(1))
    return mask


@torch.no_grad()
def forward(model, tokens, ref_s, speed):
    device = ref_s.device
    tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
    input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
    text_mask = length_to_mask(input_lengths).to(device)
    bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
    d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
    s = ref_s[:, 128:]
    d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
    x, _ = model.predictor.lstm(d)
    duration = model.predictor.duration_proj(x)
    duration = torch.sigmoid(duration).sum(axis=-1) / speed
    pred_dur = torch.round(duration).clamp(min=1).long()
    pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
    c_frame = 0
    for i in range(pred_aln_trg.size(0)):
        pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
        c_frame += pred_dur[0, i].item()
    en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
    F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
    t_en = model.text_encoder(tokens, input_lengths, text_mask)
    asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
    return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()


def generate(model, text, voicepack, lang="a", speed=1):
    ps = phonemize(text, lang)
    tokens = tokenize(ps)
    if not tokens:
        return None
    elif len(tokens) > 510:
        tokens = tokens[:510]
        print("Truncated to 510 tokens")
    ref_s = voicepack[len(tokens)]
    out = forward(model, tokens, ref_s, speed)
    ps = "".join(next(k for k, v in VOCAB.items() if i == v) for i in tokens)
    return out, ps
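A minimal sketch of driving this module directly, assuming the Kokoro-82M assets (including the `models.py` that provides `build_model`) are importable; paths are illustrative:

```python
import torch
from models import build_model
from api.src.core.kokoro import generate

model = build_model("Kokoro-82M/kokoro-v0_19.pth", "cpu")
voicepack = torch.load("Kokoro-82M/voices/af.pt", map_location="cpu", weights_only=True)
audio, phonemes = generate(model, "Hello there!", voicepack, lang="a", speed=1.0)  # 24 kHz samples
```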
api/src/main.py
@@ -10,8 +10,10 @@ from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware

 from .core.config import settings
-from .services.tts import TTSModel, TTSService
+from .services.tts_model import TTSModel
+from .services.tts_service import TTSService
 from .routers.openai_compatible import router as openai_router
+from .routers.text_processing import router as text_router


 @asynccontextmanager
@@ -20,8 +22,8 @@ async def lifespan(app: FastAPI):
     logger.info("Loading TTS model and voice packs...")

     # Initialize the main model with warm-up
-    model, voicepack_count = TTSModel.initialize()
-    logger.info(f"Model loaded and warmed up on {TTSModel._device}")
+    voicepack_count = TTSModel.setup()
+    logger.info(f"Model loaded and warmed up on {TTSModel.get_device()}")
     logger.info(f"{voicepack_count} voice packs loaded successfully")
     yield

@@ -44,8 +46,9 @@ app.add_middleware(
     allow_headers=["*"],
 )

-# Include OpenAI compatible router
+# Include routers
 app.include_router(openai_router, prefix="/v1")
+app.include_router(text_router)


 # Health check endpoint
api/src/routers/openai_compatible.py
@@ -3,7 +3,7 @@ from typing import List
 from loguru import logger
 from fastapi import Depends, Response, APIRouter, HTTPException

-from ..services.tts import TTSService
+from ..services.tts_service import TTSService
 from ..services.audio import AudioService
 from ..structures.schemas import OpenAISpeechRequest


@@ -15,9 +15,7 @@ router = APIRouter(

 def get_tts_service() -> TTSService:
     """Dependency to get TTSService instance with database session"""
-    return TTSService(
-        start_worker=False
-    )  # Don't start worker thread for OpenAI endpoint
+    return TTSService()  # Initialize TTSService with default settings


 @router.post("/audio/speech")
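For reference, a sketch of calling the endpoint this router serves; the request fields follow `OpenAISpeechRequest` as used in the repo, and port 8880 is assumed from the project's compose setup:

```python
import requests

# Hypothetical local call to the OpenAI-compatible speech endpoint
resp = requests.post(
    "http://localhost:8880/v1/audio/speech",
    json={"model": "kokoro", "input": "Hello world!", "voice": "af", "response_format": "wav"},
)
with open("speech.wav", "wb") as f:
    f.write(resp.content)
```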
api/src/routers/text_processing.py (new file)
@@ -0,0 +1,30 @@
from fastapi import APIRouter
from ..structures.text_schemas import PhonemeRequest, PhonemeResponse
from ..services.text_processing import phonemize, tokenize

router = APIRouter(
    prefix="/text",
    tags=["text processing"]
)

@router.post("/phonemize", response_model=PhonemeResponse)
async def phonemize_text(request: PhonemeRequest) -> PhonemeResponse:
    """Convert text to phonemes and tokens: Rough attempt

    Args:
        request: Request containing text and language

    Returns:
        Phonemes and token IDs
    """
    # Get phonemes
    phonemes = phonemize(request.text, request.language)

    # Get tokens
    tokens = tokenize(phonemes)
    tokens = [0] + tokens + [0]  # Add start/end tokens

    return PhonemeResponse(
        phonemes=phonemes,
        tokens=tokens
    )
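A quick way to exercise the new dev endpoint (host and port assumed):

```python
import requests

resp = requests.post(
    "http://localhost:8880/text/phonemize",
    json={"text": "Hello world", "language": "a"},
)
print(resp.json())  # {"phonemes": "...", "tokens": [0, ..., 0]}
```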
api/src/services/__init__.py
@@ -1,3 +1,3 @@
-from .tts import TTSModel, TTSService
+from .tts_service import TTSService

-__all__ = ["TTSService", "TTSModel"]
+__all__ = ["TTSService"]
api/src/services/text_processing/__init__.py (new file)
@@ -0,0 +1,13 @@
from .normalizer import normalize_text
from .phonemizer import phonemize, PhonemizerBackend, EspeakBackend
from .vocabulary import tokenize, decode_tokens, VOCAB

__all__ = [
    'normalize_text',
    'phonemize',
    'tokenize',
    'decode_tokens',
    'VOCAB',
    'PhonemizerBackend',
    'EspeakBackend'
]
api/src/services/text_processing/normalizer.py (new file)
@@ -0,0 +1,111 @@
import re

def split_num(num: re.Match) -> str:
    """Handle number splitting for various formats"""
    num = num.group()
    if "." in num:
        return num
    elif ":" in num:
        h, m = [int(n) for n in num.split(":")]
        if m == 0:
            return f"{h} o'clock"
        elif m < 10:
            return f"{h} oh {m}"
        return f"{h} {m}"
    year = int(num[:4])
    if year < 1100 or year % 1000 < 10:
        return num
    left, right = num[:2], int(num[2:4])
    s = "s" if num.endswith("s") else ""
    if 100 <= year % 1000 <= 999:
        if right == 0:
            return f"{left} hundred{s}"
        elif right < 10:
            return f"{left} oh {right}{s}"
    return f"{left} {right}{s}"

def handle_money(m: re.Match) -> str:
    """Convert money expressions to spoken form"""
    m = m.group()
    bill = "dollar" if m[0] == "$" else "pound"
    if m[-1].isalpha():
        return f"{m[1:]} {bill}s"
    elif "." not in m:
        s = "" if m[1:] == "1" else "s"
        return f"{m[1:]} {bill}{s}"
    b, c = m[1:].split(".")
    s = "" if b == "1" else "s"
    c = int(c.ljust(2, "0"))
    coins = (
        f"cent{'' if c == 1 else 's'}"
        if m[0] == "$"
        else ("penny" if c == 1 else "pence")
    )
    return f"{b} {bill}{s} and {c} {coins}"

def handle_decimal(num: re.Match) -> str:
    """Convert decimal numbers to spoken form"""
    a, b = num.group().split(".")
    return " point ".join([a, " ".join(b)])

def normalize_text(text: str) -> str:
    """Normalize text for TTS processing

    Args:
        text: Input text to normalize

    Returns:
        Normalized text
    """
    # Replace quotes and brackets
    text = text.replace(chr(8216), "'").replace(chr(8217), "'")
    text = text.replace("«", chr(8220)).replace("»", chr(8221))
    text = text.replace(chr(8220), '"').replace(chr(8221), '"')
    text = text.replace("(", "«").replace(")", "»")

    # Handle CJK punctuation
    for a, b in zip("、。！，：；？", ",.!,:;?"):
        text = text.replace(a, b + " ")

    # Clean up whitespace
    text = re.sub(r"[^\S \n]", " ", text)
    text = re.sub(r" +", " ", text)
    text = re.sub(r"(?<=\n) +(?=\n)", "", text)

    # Handle titles and abbreviations
    text = re.sub(r"\bD[Rr]\.(?= [A-Z])", "Doctor", text)
    text = re.sub(r"\b(?:Mr\.|MR\.(?= [A-Z]))", "Mister", text)
    text = re.sub(r"\b(?:Ms\.|MS\.(?= [A-Z]))", "Miss", text)
    text = re.sub(r"\b(?:Mrs\.|MRS\.(?= [A-Z]))", "Mrs", text)
    text = re.sub(r"\betc\.(?! [A-Z])", "etc", text)

    # Handle common words
    text = re.sub(r"(?i)\b(y)eah?\b", r"\1e'a", text)

    # Handle numbers and money
    text = re.sub(
        r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)",
        split_num,
        text
    )
    text = re.sub(r"(?<=\d),(?=\d)", "", text)
    text = re.sub(
        r"(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b",
        handle_money,
        text,
    )
    text = re.sub(r"\d*\.\d+", handle_decimal, text)

    # Handle various formatting
    text = re.sub(r"(?<=\d)-(?=\d)", " to ", text)
    text = re.sub(r"(?<=\d)S", " S", text)
    text = re.sub(r"(?<=[BCDFGHJ-NP-TV-Z])'?s\b", "'S", text)
    text = re.sub(r"(?<=X')S\b", "s", text)
    text = re.sub(
        r"(?:[A-Za-z]\.){2,} [a-z]",
        lambda m: m.group().replace(".", "-"),
        text
    )
    text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)

    return text.strip()
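A worked example of the rules above (title, money, and time handling):

```python
from api.src.services.text_processing import normalize_text

print(normalize_text("Dr. Smith charged $5.50 at 10:30."))
# -> "Doctor Smith charged 5 dollars and 50 cents at 10 30."
```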
api/src/services/text_processing/phonemizer.py (new file)
@@ -0,0 +1,97 @@
import re
from abc import ABC, abstractmethod
import phonemizer
from .normalizer import normalize_text

class PhonemizerBackend(ABC):
    """Abstract base class for phonemization backends"""

    @abstractmethod
    def phonemize(self, text: str) -> str:
        """Convert text to phonemes

        Args:
            text: Text to convert to phonemes

        Returns:
            Phonemized text
        """
        pass

class EspeakBackend(PhonemizerBackend):
    """Espeak-based phonemizer implementation"""

    def __init__(self, language: str):
        """Initialize espeak backend

        Args:
            language: Language code ('en-us' or 'en-gb')
        """
        self.backend = phonemizer.backend.EspeakBackend(
            language=language,
            preserve_punctuation=True,
            with_stress=True
        )
        self.language = language

    def phonemize(self, text: str) -> str:
        """Convert text to phonemes using espeak

        Args:
            text: Text to convert to phonemes

        Returns:
            Phonemized text
        """
        # Phonemize text
        ps = self.backend.phonemize([text])
        ps = ps[0] if ps else ""

        # Handle special cases
        ps = ps.replace("kəkˈoːɹoʊ", "kˈoʊkəɹoʊ").replace("kəkˈɔːɹəʊ", "kˈəʊkəɹəʊ")
        ps = ps.replace("ʲ", "j").replace("r", "ɹ").replace("x", "k").replace("ɬ", "l")
        ps = re.sub(r"(?<=[a-zɹː])(?=hˈʌndɹɪd)", " ", ps)
        ps = re.sub(r' z(?=[;:,.!?¡¿—…"«»“” ]|$)', "z", ps)

        # Language-specific rules
        if self.language == "en-us":
            ps = re.sub(r"(?<=nˈaɪn)ti(?!ː)", "di", ps)

        return ps.strip()

def create_phonemizer(language: str = "a") -> PhonemizerBackend:
    """Factory function to create phonemizer backend

    Args:
        language: Language code ('a' for US English, 'b' for British English)

    Returns:
        Phonemizer backend instance
    """
    # Map language codes to espeak language codes
    lang_map = {
        "a": "en-us",
        "b": "en-gb"
    }

    if language not in lang_map:
        raise ValueError(f"Unsupported language code: {language}")

    return EspeakBackend(lang_map[language])

def phonemize(text: str, language: str = "a", normalize: bool = True) -> str:
    """Convert text to phonemes

    Args:
        text: Text to convert to phonemes
        language: Language code ('a' for US English, 'b' for British English)
        normalize: Whether to normalize text before phonemization

    Returns:
        Phonemized text
    """
    if normalize:
        text = normalize_text(text)

    phonemizer = create_phonemizer(language)
    return phonemizer.phonemize(text)
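Usage sketch (requires an espeak backend installed on the host; output shown is indicative):

```python
from api.src.services.text_processing import phonemize

ps = phonemize("Hello world", language="a")  # en-us espeak, normalization on by default
print(ps)  # IPA string, e.g. "həlˈoʊ wˈɜːld"
```

Note that `phonemize` builds a fresh `EspeakBackend` via `create_phonemizer` on every call; callers on a hot path may want to cache the backend instance.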
api/src/services/text_processing/vocabulary.py (new file)
@@ -0,0 +1,37 @@
def get_vocab():
    """Get the vocabulary dictionary mapping characters to token IDs"""
    _pad = "$"
    _punctuation = ';:,.!?¡¿—…"«»“” '
    _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
    _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"

    # Create vocabulary dictionary
    symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
    return {symbol: i for i, symbol in enumerate(symbols)}

# Initialize vocabulary
VOCAB = get_vocab()

def tokenize(phonemes: str) -> list[int]:
    """Convert phonemes string to token IDs

    Args:
        phonemes: String of phonemes to tokenize

    Returns:
        List of token IDs
    """
    return [i for i in map(VOCAB.get, phonemes) if i is not None]

def decode_tokens(tokens: list[int]) -> str:
    """Convert token IDs back to phonemes string

    Args:
        tokens: List of token IDs

    Returns:
        String of phonemes
    """
    # Create reverse mapping
    id_to_symbol = {i: s for s, i in VOCAB.items()}
    return "".join(id_to_symbol[t] for t in tokens)
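A round-trip sketch: `decode_tokens` inverts `tokenize` as long as every input symbol is in `VOCAB` (unknown symbols are silently dropped on the way in):

```python
from api.src.services.text_processing import tokenize, decode_tokens

tokens = tokenize("hˈɛloʊ")
assert decode_tokens(tokens) == "hˈɛloʊ"
```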
api/src/services/tts.py (deleted)
@@ -1,286 +0,0 @@
import io
import os
import re
import time
import threading
from typing import List, Tuple, Optional

import numpy as np
import torch
import tiktoken
import scipy.io.wavfile as wavfile
from kokoro import generate, tokenize, phonemize, normalize_text
from loguru import logger
from models import build_model

from ..core.config import settings

enc = tiktoken.get_encoding("cl100k_base")


class TTSModel:
    _instance = None
    _device = None
    _lock = threading.Lock()

    # Directory for all voices (copied base voices, and any created combined voices)
    VOICES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "voices")

    @classmethod
    def initialize(cls):
        """Initialize and warm up the model"""
        with cls._lock:
            if cls._instance is None:
                # Initialize model
                cls._device = "cuda" if torch.cuda.is_available() else "cpu"
                logger.info(f"Initializing model on {cls._device}")
                model_path = os.path.join(settings.model_dir, settings.model_path)
                model = build_model(model_path, cls._device)
                cls._instance = model

                # Ensure voices directory exists
                os.makedirs(cls.VOICES_DIR, exist_ok=True)

                # Copy base voices to local directory
                base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
                if os.path.exists(base_voices_dir):
                    for file in os.listdir(base_voices_dir):
                        if file.endswith(".pt"):
                            voice_name = file[:-3]
                            voice_path = os.path.join(cls.VOICES_DIR, file)
                            if not os.path.exists(voice_path):
                                try:
                                    logger.info(
                                        f"Copying base voice {voice_name} to voices directory"
                                    )
                                    base_path = os.path.join(base_voices_dir, file)
                                    voicepack = torch.load(
                                        base_path,
                                        map_location=cls._device,
                                        weights_only=True,
                                    )
                                    torch.save(voicepack, voice_path)
                                except Exception as e:
                                    logger.error(
                                        f"Error copying voice {voice_name}: {str(e)}"
                                    )

                # Warm up with default voice
                try:
                    dummy_text = "Hello"
                    voice_path = os.path.join(cls.VOICES_DIR, "af.pt")
                    dummy_voicepack = torch.load(
                        voice_path, map_location=cls._device, weights_only=True
                    )
                    generate(model, dummy_text, dummy_voicepack, lang="a", speed=1.0)
                    logger.info("Model warm-up complete")
                except Exception as e:
                    logger.warning(f"Model warm-up failed: {e}")

            # Count voices in directory for validation
            voice_count = len(
                [f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")]
            )
            return cls._instance, voice_count

    @classmethod
    def get_instance(cls):
        """Get the initialized instance or raise an error"""
        if cls._instance is None:
            raise RuntimeError("Model not initialized. Call initialize() first.")
        return cls._instance, cls._device


class TTSService:
    def __init__(self, output_dir: str = None, start_worker: bool = False):
        self.output_dir = output_dir
        self._ensure_voices()
        if start_worker:
            self.start_worker()

    def _ensure_voices(self):
        """Copy base voices to local voices directory during initialization"""
        os.makedirs(TTSModel.VOICES_DIR, exist_ok=True)

        base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
        if os.path.exists(base_voices_dir):
            for file in os.listdir(base_voices_dir):
                if file.endswith(".pt"):
                    voice_name = file[:-3]
                    voice_path = os.path.join(TTSModel.VOICES_DIR, file)
                    if not os.path.exists(voice_path):
                        try:
                            logger.info(
                                f"Copying base voice {voice_name} to voices directory"
                            )
                            base_path = os.path.join(base_voices_dir, file)
                            voicepack = torch.load(
                                base_path,
                                map_location=TTSModel._device,
                                weights_only=True,
                            )
                            torch.save(voicepack, voice_path)
                        except Exception as e:
                            logger.error(f"Error copying voice {voice_name}: {str(e)}")

    def _split_text(self, text: str) -> List[str]:
        """Split text into sentences"""
        return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]

    def _get_voice_path(self, voice_name: str) -> Optional[str]:
        """Get the path to a voice file.

        Args:
            voice_name: Name of the voice to find

        Returns:
            Path to the voice file if found, None otherwise
        """
        voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice_name}.pt")
        return voice_path if os.path.exists(voice_path) else None

    def _generate_audio(
        self, text: str, voice: str, speed: float, stitch_long_output: bool = True
    ) -> Tuple[torch.Tensor, float]:
        """Generate audio and measure processing time"""
        start_time = time.time()

        try:
            # Normalize text once at the start
            text = normalize_text(text)
            if not text:
                raise ValueError("Text is empty after preprocessing")

            # Check voice exists
            voice_path = self._get_voice_path(voice)
            if not voice_path:
                raise ValueError(f"Voice not found: {voice}")

            # Load model and voice
            model = TTSModel._instance
            voicepack = torch.load(
                voice_path, map_location=TTSModel._device, weights_only=True
            )

            # Generate audio with or without stitching
            if stitch_long_output:
                chunks = self._split_text(text)
                audio_chunks = []

                # Process all chunks with same model/voicepack instance
                for i, chunk in enumerate(chunks):
                    try:
                        # Validate phonemization first
                        # ps = phonemize(chunk, voice[0])
                        # tokens = tokenize(ps)
                        # logger.debug(
                        #     f"Processing chunk {i + 1}/{len(chunks)}: {len(tokens)} tokens"
                        # )

                        # Only proceed if phonemization succeeded
                        chunk_audio, _ = generate(
                            model, chunk, voicepack, lang=voice[0], speed=speed
                        )
                        if chunk_audio is not None:
                            audio_chunks.append(chunk_audio)
                        else:
                            logger.error(
                                f"No audio generated for chunk {i + 1}/{len(chunks)}"
                            )
                    except Exception as e:
                        logger.error(
                            f"Failed to generate audio for chunk {i + 1}/{len(chunks)}: '{chunk}'. Error: {str(e)}"
                        )
                        continue

                if not audio_chunks:
                    raise ValueError("No audio chunks were generated successfully")

                audio = (
                    np.concatenate(audio_chunks)
                    if len(audio_chunks) > 1
                    else audio_chunks[0]
                )
            else:
                audio, _ = generate(model, text, voicepack, lang=voice[0], speed=speed)

            processing_time = time.time() - start_time
            return audio, processing_time

        except Exception as e:
            print(f"Error in audio generation: {str(e)}")
            raise

    def _save_audio(self, audio: torch.Tensor, filepath: str):
        """Save audio to file"""
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        wavfile.write(filepath, 24000, audio)

    def _audio_to_bytes(self, audio: torch.Tensor) -> bytes:
        """Convert audio tensor to WAV bytes"""
        buffer = io.BytesIO()
        wavfile.write(buffer, 24000, audio)
        return buffer.getvalue()

    def combine_voices(self, voices: List[str]) -> str:
        """Combine multiple voices into a new voice.

        Args:
            voices: List of voice names to combine

        Returns:
            Name of the combined voice

        Raises:
            ValueError: If less than 2 voices provided or voice loading fails
            RuntimeError: If voice combination or saving fails
        """
        if len(voices) < 2:
            raise ValueError("At least 2 voices are required for combination")

        # Load voices
        t_voices: List[torch.Tensor] = []
        v_name: List[str] = []

        for voice in voices:
            try:
                voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice}.pt")
                voicepack = torch.load(
                    voice_path, map_location=TTSModel._device, weights_only=True
                )
                t_voices.append(voicepack)
                v_name.append(voice)
            except Exception as e:
                raise ValueError(f"Failed to load voice {voice}: {str(e)}")

        # Combine voices
        try:
            f: str = "_".join(v_name)
            v = torch.mean(torch.stack(t_voices), dim=0)
            combined_path = os.path.join(TTSModel.VOICES_DIR, f"{f}.pt")

            # Save combined voice
            try:
                torch.save(v, combined_path)
            except Exception as e:
                raise RuntimeError(
                    f"Failed to save combined voice to {combined_path}: {str(e)}"
                )

            return f

        except Exception as e:
            if not isinstance(e, (ValueError, RuntimeError)):
                raise RuntimeError(f"Error combining voices: {str(e)}")
            raise

    def list_voices(self) -> List[str]:
        """List all available voices"""
        voices = []
        try:
            for file in os.listdir(TTSModel.VOICES_DIR):
                if file.endswith(".pt"):
                    voices.append(file[:-3])  # Remove .pt extension
        except Exception as e:
            logger.error(f"Error listing voices: {str(e)}")
        return sorted(voices)
api/src/services/tts_base.py (new file)
@@ -0,0 +1,136 @@
import os
import threading
from abc import ABC, abstractmethod
from typing import List, Tuple
import torch
import numpy as np
from loguru import logger

from ..core.config import settings

class TTSBaseModel(ABC):
    _instance = None
    _lock = threading.Lock()
    _device = None
    VOICES_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "voices")

    @classmethod
    def setup(cls):
        """Initialize model and setup voices"""
        with cls._lock:
            # Set device
            cuda_available = torch.cuda.is_available()
            logger.info(f"CUDA available: {cuda_available}")
            if cuda_available:
                try:
                    # Test CUDA device
                    test_tensor = torch.zeros(1).cuda()
                    logger.info("CUDA test successful")
                    model_path = os.path.join(settings.model_dir, settings.pytorch_model_path)
                    cls._device = "cuda"
                except Exception as e:
                    logger.error(f"CUDA test failed: {e}")
                    cls._device = "cpu"
                    # Fall back to the ONNX model so model_path is always bound
                    model_path = os.path.join(settings.model_dir, settings.onnx_model_path)
            else:
                cls._device = "cpu"
                model_path = os.path.join(settings.model_dir, settings.onnx_model_path)
            logger.info(f"Initializing model on {cls._device}")

            # Initialize model
            if not cls.initialize(settings.model_dir, model_path=model_path):
                raise RuntimeError(f"Failed to initialize {cls._device.upper()} model")

            # Setup voices directory
            os.makedirs(cls.VOICES_DIR, exist_ok=True)

            # Copy base voices to local directory
            base_voices_dir = os.path.join(settings.model_dir, settings.voices_dir)
            if os.path.exists(base_voices_dir):
                for file in os.listdir(base_voices_dir):
                    if file.endswith(".pt"):
                        voice_name = file[:-3]
                        voice_path = os.path.join(cls.VOICES_DIR, file)
                        if not os.path.exists(voice_path):
                            try:
                                logger.info(f"Copying base voice {voice_name} to voices directory")
                                base_path = os.path.join(base_voices_dir, file)
                                voicepack = torch.load(base_path, map_location=cls._device, weights_only=True)
                                torch.save(voicepack, voice_path)
                            except Exception as e:
                                logger.error(f"Error copying voice {voice_name}: {str(e)}")

            # Warm up with default voice
            try:
                dummy_text = "Hello"
                voice_path = os.path.join(cls.VOICES_DIR, "af.pt")
                dummy_voicepack = torch.load(voice_path, map_location=cls._device, weights_only=True)

                # Process text and generate audio
                phonemes, tokens = cls.process_text(dummy_text, "a")
                cls.generate_from_tokens(tokens, dummy_voicepack, 1.0)

                logger.info("Model warm-up complete")
            except Exception as e:
                logger.warning(f"Model warm-up failed: {e}")

            # Count voices in directory
            voice_count = len([f for f in os.listdir(cls.VOICES_DIR) if f.endswith(".pt")])
            return voice_count

    @classmethod
    @abstractmethod
    def initialize(cls, model_dir: str, model_path: str = None):
        """Initialize the model"""
        pass

    @classmethod
    @abstractmethod
    def process_text(cls, text: str, language: str) -> Tuple[str, List[int]]:
        """Process text into phonemes and tokens

        Args:
            text: Input text
            language: Language code

        Returns:
            tuple[str, list[int]]: Phonemes and token IDs
        """
        pass

    @classmethod
    @abstractmethod
    def generate_from_text(cls, text: str, voicepack: torch.Tensor, language: str, speed: float) -> Tuple[np.ndarray, str]:
        """Generate audio from text

        Args:
            text: Input text
            voicepack: Voice tensor
            language: Language code
            speed: Speed factor

        Returns:
            tuple[np.ndarray, str]: Generated audio samples and phonemes
        """
        pass

    @classmethod
    @abstractmethod
    def generate_from_tokens(cls, tokens: List[int], voicepack: torch.Tensor, speed: float) -> np.ndarray:
        """Generate audio from tokens

        Args:
            tokens: Token IDs
            voicepack: Voice tensor
            speed: Speed factor

        Returns:
            np.ndarray: Generated audio samples
        """
        pass

    @classmethod
    def get_device(cls):
        """Get the current device"""
        if cls._device is None:
            raise RuntimeError("Model not initialized. Call setup() first.")
        return cls._device
api/src/services/tts_cpu.py (new file)
@@ -0,0 +1,144 @@
import os
import numpy as np
import torch
from onnxruntime import InferenceSession, SessionOptions, GraphOptimizationLevel, ExecutionMode
from loguru import logger

from .tts_base import TTSBaseModel
from .text_processing import phonemize, tokenize
from ..core.config import settings

class TTSCPUModel(TTSBaseModel):
    _instance = None
    _onnx_session = None

    @classmethod
    def initialize(cls, model_dir: str, model_path: str = None):
        """Initialize ONNX model for CPU inference"""
        if cls._onnx_session is None:
            # Try loading ONNX model
            onnx_path = os.path.join(model_dir, settings.onnx_model_path)
            if not os.path.exists(onnx_path):
                logger.error(f"ONNX model not found at {onnx_path}")
                return None

            logger.info(f"Loading ONNX model from {onnx_path}")

            # Configure ONNX session for optimal performance
            session_options = SessionOptions()

            # Set optimization level
            if settings.onnx_optimization_level == "all":
                session_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
            elif settings.onnx_optimization_level == "basic":
                session_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
            else:
                session_options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL

            # Configure threading
            session_options.intra_op_num_threads = settings.onnx_num_threads
            session_options.inter_op_num_threads = settings.onnx_inter_op_threads

            # Set execution mode
            session_options.execution_mode = (
                ExecutionMode.ORT_PARALLEL
                if settings.onnx_execution_mode == "parallel"
                else ExecutionMode.ORT_SEQUENTIAL
            )

            # Enable/disable memory pattern optimization
            session_options.enable_mem_pattern = settings.onnx_memory_pattern

            # Configure CPU provider options
            provider_options = {
                'CPUExecutionProvider': {
                    'arena_extend_strategy': settings.onnx_arena_extend_strategy,
                    'cpu_memory_arena_cfg': 'cpu:0'
                }
            }

            cls._onnx_session = InferenceSession(
                onnx_path,
                sess_options=session_options,
                providers=['CPUExecutionProvider'],
                provider_options=[provider_options]
            )

        return cls._onnx_session

    @classmethod
    def process_text(cls, text: str, language: str) -> tuple[str, list[int]]:
        """Process text into phonemes and tokens

        Args:
            text: Input text
            language: Language code

        Returns:
            tuple[str, list[int]]: Phonemes and token IDs
        """
        phonemes = phonemize(text, language)
        tokens = tokenize(phonemes)
        tokens = [0] + tokens + [0]  # Add start/end tokens
        return phonemes, tokens

    @classmethod
    def generate_from_text(cls, text: str, voicepack: torch.Tensor, language: str, speed: float) -> tuple[np.ndarray, str]:
        """Generate audio from text

        Args:
            text: Input text
            voicepack: Voice tensor
            language: Language code
            speed: Speed factor

        Returns:
            tuple[np.ndarray, str]: Generated audio samples and phonemes
        """
        if cls._onnx_session is None:
            raise RuntimeError("ONNX model not initialized")

        # Process text
        phonemes, tokens = cls.process_text(text, language)

        # Generate audio
        audio = cls.generate_from_tokens(tokens, voicepack, speed)

        return audio, phonemes

    @classmethod
    def generate_from_tokens(cls, tokens: list[int], voicepack: torch.Tensor, speed: float) -> np.ndarray:
        """Generate audio from tokens

        Args:
            tokens: Token IDs
            voicepack: Voice tensor
            speed: Speed factor

        Returns:
            np.ndarray: Generated audio samples
        """
        if cls._onnx_session is None:
            raise RuntimeError("ONNX model not initialized")

        # Pre-allocate and prepare inputs
        tokens_input = np.array([tokens], dtype=np.int64)
        style_input = voicepack[len(tokens)-2].numpy()  # Already has correct dimensions
        speed_input = np.full(1, speed, dtype=np.float32)  # More efficient than ones * speed

        # Run inference with optimized inputs
        result = cls._onnx_session.run(
            None,
            {
                'tokens': tokens_input,
                'style': style_input,
                'speed': speed_input
            }
        )
        return result[0]
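For sanity-checking a `kokoro-v0_19.onnx` export outside the service, a standalone sketch; the input names match the session call above, while the token IDs and the zero style vector are placeholders (the 128+128 split in the PyTorch path suggests a 256-dim style), so the output will not sound like a real voice:

```python
import numpy as np
from onnxruntime import InferenceSession

sess = InferenceSession("kokoro-v0_19.onnx", providers=["CPUExecutionProvider"])
tokens = np.array([[0, 50, 83, 54, 156, 0]], dtype=np.int64)  # padded token IDs (illustrative)
style = np.zeros((1, 256), dtype=np.float32)                  # placeholder style vector
speed = np.full(1, 1.0, dtype=np.float32)
audio = sess.run(None, {"tokens": tokens, "style": style, "speed": speed})[0]
```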
api/src/services/tts_gpu.py (new file)
@@ -0,0 +1,127 @@
import os
import numpy as np
import torch
from loguru import logger
from models import build_model
from .text_processing import phonemize, tokenize

from .tts_base import TTSBaseModel
from ..core.config import settings

@torch.no_grad()
def forward(model, tokens, ref_s, speed):
    """Forward pass through the model"""
    device = ref_s.device
    tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
    input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
    text_mask = length_to_mask(input_lengths).to(device)
    bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
    d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
    s = ref_s[:, 128:]
    d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
    x, _ = model.predictor.lstm(d)
    duration = model.predictor.duration_proj(x)
    duration = torch.sigmoid(duration).sum(axis=-1) / speed
    pred_dur = torch.round(duration).clamp(min=1).long()
    pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item())
    c_frame = 0
    for i in range(pred_aln_trg.size(0)):
        pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
        c_frame += pred_dur[0, i].item()
    en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
    F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
    t_en = model.text_encoder(tokens, input_lengths, text_mask)
    asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
    return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()

def length_to_mask(lengths):
    """Create attention mask from lengths"""
    mask = (
        torch.arange(lengths.max())
        .unsqueeze(0)
        .expand(lengths.shape[0], -1)
        .type_as(lengths)
    )
    mask = torch.gt(mask + 1, lengths.unsqueeze(1))
    return mask

class TTSGPUModel(TTSBaseModel):
    _instance = None
    _device = "cuda"

    @classmethod
    def initialize(cls, model_dir: str, model_path: str):
        """Initialize PyTorch model for GPU inference"""
        if cls._instance is None and torch.cuda.is_available():
            try:
                logger.info("Initializing GPU model")
                model_path = os.path.join(model_dir, settings.pytorch_model_path)
                model = build_model(model_path, cls._device)
                cls._instance = model
                return cls._instance
            except Exception as e:
                logger.error(f"Failed to initialize GPU model: {e}")
                return None
        return cls._instance

    @classmethod
    def process_text(cls, text: str, language: str) -> tuple[str, list[int]]:
        """Process text into phonemes and tokens

        Args:
            text: Input text
            language: Language code

        Returns:
            tuple[str, list[int]]: Phonemes and token IDs
        """
        phonemes = phonemize(text, language)
        tokens = tokenize(phonemes)
        return phonemes, tokens

    @classmethod
    def generate_from_text(cls, text: str, voicepack: torch.Tensor, language: str, speed: float) -> tuple[np.ndarray, str]:
        """Generate audio from text

        Args:
            text: Input text
            voicepack: Voice tensor
            language: Language code
            speed: Speed factor

        Returns:
            tuple[np.ndarray, str]: Generated audio samples and phonemes
        """
        if cls._instance is None:
            raise RuntimeError("GPU model not initialized")

        # Process text
        phonemes, tokens = cls.process_text(text, language)

        # Generate audio
        audio = cls.generate_from_tokens(tokens, voicepack, speed)

        return audio, phonemes

    @classmethod
    def generate_from_tokens(cls, tokens: list[int], voicepack: torch.Tensor, speed: float) -> np.ndarray:
        """Generate audio from tokens

        Args:
            tokens: Token IDs
            voicepack: Voice tensor
            speed: Speed factor

        Returns:
            np.ndarray: Generated audio samples
        """
        if cls._instance is None:
            raise RuntimeError("GPU model not initialized")

        # Get reference style
        ref_s = voicepack[len(tokens)]

        # Generate audio
        audio = forward(cls._instance, tokens, ref_s, speed)

        return audio
api/src/services/tts_model.py (new file)
@@ -0,0 +1,8 @@
import torch

if torch.cuda.is_available():
    from .tts_gpu import TTSGPUModel as TTSModel
else:
    from .tts_cpu import TTSCPUModel as TTSModel

__all__ = ["TTSModel"]
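The import-time switch keeps callers backend-agnostic; e.g.:

```python
from api.src.services.tts_model import TTSModel  # ONNX on CPU-only hosts, PyTorch on CUDA hosts

voice_count = TTSModel.setup()  # loads the model, copies voices, runs the warm-up
phonemes, tokens = TTSModel.process_text("Hello", "a")
```

One trade-off of this pattern is that the backend is fixed at first import, so it cannot be switched at runtime.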
api/src/services/tts_service.py (new file)
@@ -0,0 +1,161 @@
import io
import os
import re
import time
from typing import List, Tuple, Optional

import numpy as np
import torch
import scipy.io.wavfile as wavfile
from .text_processing import normalize_text
from loguru import logger

from ..core.config import settings
from .tts_model import TTSModel


class TTSService:
    def __init__(self, output_dir: str = None):
        self.output_dir = output_dir

    def _split_text(self, text: str) -> List[str]:
        """Split text into sentences"""
        if not isinstance(text, str):
            text = str(text) if text is not None else ""
        return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]

    def _get_voice_path(self, voice_name: str) -> Optional[str]:
        """Get the path to a voice file"""
        voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice_name}.pt")
        return voice_path if os.path.exists(voice_path) else None

    def _generate_audio(
        self, text: str, voice: str, speed: float, stitch_long_output: bool = True
    ) -> Tuple[torch.Tensor, float]:
        """Generate audio and measure processing time"""
        start_time = time.time()

        try:
            # Normalize text once at the start
            if not text:
                raise ValueError("Text is empty")
            normalized = normalize_text(text)
            if not normalized:
                raise ValueError("Text is empty after preprocessing")
            text = str(normalized)

            # Check voice exists
            voice_path = self._get_voice_path(voice)
            if not voice_path:
                raise ValueError(f"Voice not found: {voice}")

            # Load voice
            voicepack = torch.load(
                voice_path, map_location=TTSModel.get_device(), weights_only=True
            )

            # Generate audio with or without stitching
            if stitch_long_output:
                chunks = self._split_text(text)
                audio_chunks = []

                # Process all chunks
                for i, chunk in enumerate(chunks):
                    try:
                        # Process text and generate audio
                        phonemes, tokens = TTSModel.process_text(chunk, voice[0])
                        chunk_audio = TTSModel.generate_from_tokens(tokens, voicepack, speed)

                        if chunk_audio is not None:
                            audio_chunks.append(chunk_audio)
                        else:
                            logger.error(f"No audio generated for chunk {i + 1}/{len(chunks)}")

                    except Exception as e:
                        logger.error(
                            f"Failed to generate audio for chunk {i + 1}/{len(chunks)}: '{chunk}'. Error: {str(e)}"
                        )
                        continue

                if not audio_chunks:
                    raise ValueError("No audio chunks were generated successfully")

                audio = (
                    np.concatenate(audio_chunks)
                    if len(audio_chunks) > 1
                    else audio_chunks[0]
                )
            else:
                # Process single chunk
                phonemes, tokens = TTSModel.process_text(text, voice[0])
                audio = TTSModel.generate_from_tokens(tokens, voicepack, speed)

            processing_time = time.time() - start_time
            return audio, processing_time

        except Exception as e:
            logger.error(f"Error in audio generation: {str(e)}")
            raise

    def _save_audio(self, audio: torch.Tensor, filepath: str):
        """Save audio to file"""
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        wavfile.write(filepath, 24000, audio)

    def _audio_to_bytes(self, audio: torch.Tensor) -> bytes:
        """Convert audio tensor to WAV bytes"""
        buffer = io.BytesIO()
        wavfile.write(buffer, 24000, audio)
        return buffer.getvalue()

    def combine_voices(self, voices: List[str]) -> str:
        """Combine multiple voices into a new voice"""
        if len(voices) < 2:
            raise ValueError("At least 2 voices are required for combination")

        # Load voices
        t_voices: List[torch.Tensor] = []
        v_name: List[str] = []

        for voice in voices:
            try:
                voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice}.pt")
                voicepack = torch.load(
                    voice_path, map_location=TTSModel.get_device(), weights_only=True
                )
                t_voices.append(voicepack)
                v_name.append(voice)
            except Exception as e:
                raise ValueError(f"Failed to load voice {voice}: {str(e)}")

        # Combine voices
        try:
            f: str = "_".join(v_name)
            v = torch.mean(torch.stack(t_voices), dim=0)
            combined_path = os.path.join(TTSModel.VOICES_DIR, f"{f}.pt")

            # Save combined voice
            try:
                torch.save(v, combined_path)
            except Exception as e:
                raise RuntimeError(
                    f"Failed to save combined voice to {combined_path}: {str(e)}"
                )

            return f

        except Exception as e:
            if not isinstance(e, (ValueError, RuntimeError)):
                raise RuntimeError(f"Error combining voices: {str(e)}")
            raise

    def list_voices(self) -> List[str]:
        """List all available voices"""
        voices = []
        try:
            for file in os.listdir(TTSModel.VOICES_DIR):
                if file.endswith(".pt"):
                    voices.append(file[:-3])  # Remove .pt extension
        except Exception as e:
            logger.error(f"Error listing voices: {str(e)}")
        return sorted(voices)
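A usage sketch of the refactored service; it delegates to whichever `TTSModel` backend was selected above, so `TTSModel.setup()` must have run first (the underscore-prefixed helpers are internal, shown here only for illustration):

```python
from api.src.services.tts_service import TTSService

service = TTSService()
audio, seconds = service._generate_audio("Hello world. How are you?", voice="af", speed=1.0)
service._save_audio(audio, "output/hello.wav")  # 24 kHz WAV
```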
api/src/structures/text_schemas.py (new file)
@@ -0,0 +1,9 @@
from pydantic import BaseModel

class PhonemeRequest(BaseModel):
    text: str
    language: str = "a"  # Default to American English

class PhonemeResponse(BaseModel):
    phonemes: str
    tokens: list[int]
api/tests/conftest.py
@@ -21,8 +21,73 @@ def cleanup():
     cleanup_mock_dirs()


-# Mock torch and other ML modules before they're imported
-sys.modules["torch"] = Mock()
+# Create mock torch module
+mock_torch = Mock()
+mock_torch.cuda = Mock()
+mock_torch.cuda.is_available = Mock(return_value=False)
+
+# Create a mock tensor class that supports basic operations
+class MockTensor:
+    def __init__(self, data):
+        self.data = data
+        if isinstance(data, (list, tuple)):
+            self.shape = [len(data)]
+        elif isinstance(data, MockTensor):
+            self.shape = data.shape
+        else:
+            self.shape = getattr(data, 'shape', [1])
+
+    def __getitem__(self, idx):
+        if isinstance(self.data, (list, tuple)):
+            if isinstance(idx, slice):
+                return MockTensor(self.data[idx])
+            return self.data[idx]
+        return self
+
+    def max(self):
+        if isinstance(self.data, (list, tuple)):
+            max_val = max(self.data)
+            return MockTensor(max_val)
+        return 5  # Default for testing
+
+    def item(self):
+        if isinstance(self.data, (list, tuple)):
+            return max(self.data)
+        if isinstance(self.data, (int, float)):
+            return self.data
+        return 5  # Default for testing
+
+    def cuda(self):
+        """Support cuda conversion"""
+        return self
+
+    def any(self):
+        if isinstance(self.data, (list, tuple)):
+            return any(self.data)
+        return False
+
+    def all(self):
+        if isinstance(self.data, (list, tuple)):
+            return all(self.data)
+        return True
+
+    def unsqueeze(self, dim):
+        return self
+
+    def expand(self, *args):
+        return self
+
+    def type_as(self, other):
+        return self
+
+# Add tensor operations to mock torch
+mock_torch.tensor = lambda x: MockTensor(x)
+mock_torch.zeros = lambda *args: MockTensor([0] * (args[0] if isinstance(args[0], int) else args[0][0]))
+mock_torch.arange = lambda x: MockTensor(list(range(x)))
+mock_torch.gt = lambda x, y: MockTensor([False] * x.shape[0])
+
+# Mock modules before they're imported
+sys.modules["torch"] = mock_torch
 sys.modules["transformers"] = Mock()
 sys.modules["phonemizer"] = Mock()
 sys.modules["models"] = Mock()
@@ -31,14 +96,22 @@ sys.modules["kokoro"] = Mock()
 sys.modules["kokoro.generate"] = Mock()
 sys.modules["kokoro.phonemize"] = Mock()
 sys.modules["kokoro.tokenize"] = Mock()
+sys.modules["onnxruntime"] = Mock()


 @pytest.fixture(autouse=True)
 def mock_tts_model():
-    """Mock TTSModel to avoid loading real models during tests"""
-    with patch("api.src.services.tts.TTSModel") as mock:
+    """Mock TTSModel and TTS model initialization"""
+    with patch("api.src.services.tts_model.TTSModel") as mock_tts_model, \
+         patch("api.src.services.tts_base.TTSBaseModel") as mock_base_model:
+
+        # Mock TTSModel
         model_instance = Mock()
         model_instance.get_instance.return_value = model_instance
         model_instance.get_voicepack.return_value = None
-        mock.get_instance.return_value = model_instance
+        mock_tts_model.get_instance.return_value = model_instance
+
+        # Mock TTS model initialization
+        mock_base_model.setup.return_value = 1  # Return dummy voice count
+
         yield model_instance
@@ -26,13 +26,11 @@ def test_health_check(test_client):
@patch("api.src.main.logger")
async def test_lifespan_successful_warmup(mock_logger, mock_tts_model):
    """Test successful model warmup in lifespan"""
    # Mock the model initialization with model info and voicepack count
    mock_model = MagicMock()
    # Mock file system for voice counting
    mock_tts_model.VOICES_DIR = "/mock/voices"
    with patch("os.listdir", return_value=["voice1.pt", "voice2.pt", "voice3.pt"]):
        mock_tts_model.initialize.return_value = (mock_model, 3)  # 3 voice files
        mock_tts_model._device = "cuda"  # Set device class variable
    mock_tts_model.setup.return_value = 3  # 3 voice files
    mock_tts_model.get_device.return_value = "cuda"

    # Create an async generator from the lifespan context manager
    async_gen = lifespan(MagicMock())

@@ -44,8 +42,8 @@ async def test_lifespan_successful_warmup(mock_logger, mock_tts_model):
    mock_logger.info.assert_any_call("Model loaded and warmed up on cuda")
    mock_logger.info.assert_any_call("3 voice packs loaded successfully")

    # Verify model initialization was called
    mock_tts_model.initialize.assert_called_once()
    # Verify model setup was called
    mock_tts_model.setup.assert_called_once()

    # Clean up
    await async_gen.__aexit__(None, None, None)

@@ -56,14 +54,14 @@ async def test_lifespan_successful_warmup(mock_logger, mock_tts_model):
@patch("api.src.main.logger")
async def test_lifespan_failed_warmup(mock_logger, mock_tts_model):
    """Test failed model warmup in lifespan"""
    # Mock the model initialization to fail
    mock_tts_model.initialize.side_effect = Exception("Failed to initialize model")
    # Mock the model setup to fail
    mock_tts_model.setup.side_effect = RuntimeError("Failed to initialize model")

    # Create an async generator from the lifespan context manager
    async_gen = lifespan(MagicMock())

    # Verify the exception is raised
    with pytest.raises(Exception, match="Failed to initialize model"):
    with pytest.raises(RuntimeError, match="Failed to initialize model"):
        await async_gen.__aenter__()

    # Verify the expected logging sequence

@@ -77,20 +75,18 @@ async def test_lifespan_failed_warmup(mock_logger, mock_tts_model):
@patch("api.src.main.TTSModel")
async def test_lifespan_cuda_warmup(mock_tts_model):
    """Test model warmup specifically on CUDA"""
    # Mock the model initialization with CUDA and voicepacks
    mock_model = MagicMock()
    # Mock file system for voice counting
    mock_tts_model.VOICES_DIR = "/mock/voices"
    with patch("os.listdir", return_value=["voice1.pt", "voice2.pt"]):
        mock_tts_model.initialize.return_value = (mock_model, 2)  # 2 voice files
        mock_tts_model._device = "cuda"  # Set device class variable
    mock_tts_model.setup.return_value = 2  # 2 voice files
    mock_tts_model.get_device.return_value = "cuda"

    # Create an async generator from the lifespan context manager
    async_gen = lifespan(MagicMock())
    await async_gen.__aenter__()

    # Verify model was initialized
    mock_tts_model.initialize.assert_called_once()
    # Verify model setup was called
    mock_tts_model.setup.assert_called_once()

    # Clean up
    await async_gen.__aexit__(None, None, None)

@@ -100,22 +96,20 @@ async def test_lifespan_cuda_warmup(mock_tts_model):
@patch("api.src.main.TTSModel")
async def test_lifespan_cpu_fallback(mock_tts_model):
    """Test model warmup falling back to CPU"""
    # Mock the model initialization with CPU and voicepacks
    mock_model = MagicMock()
    # Mock file system for voice counting
    mock_tts_model.VOICES_DIR = "/mock/voices"
    with patch(
        "os.listdir", return_value=["voice1.pt", "voice2.pt", "voice3.pt", "voice4.pt"]
    ):
        mock_tts_model.initialize.return_value = (mock_model, 4)  # 4 voice files
        mock_tts_model._device = "cpu"  # Set device class variable
    mock_tts_model.setup.return_value = 4  # 4 voice files
    mock_tts_model.get_device.return_value = "cpu"

    # Create an async generator from the lifespan context manager
    async_gen = lifespan(MagicMock())
    await async_gen.__aenter__()

    # Verify model was initialized
    mock_tts_model.initialize.assert_called_once()
    # Verify model setup was called
    mock_tts_model.setup.assert_called_once()

    # Clean up
    await async_gen.__aexit__(None, None, None)
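These lifespan tests drive the context manager by hand via `__aenter__`/`__aexit__`. A minimal sketch of that pattern outside FastAPI, assuming the app's `lifespan` is a standard `asynccontextmanager` (the bodies here are placeholders, not the repo's warmup code):

```python
import asyncio
from contextlib import asynccontextmanager

@asynccontextmanager
async def lifespan(app):
    print("warmup")    # model setup/warmup would run here
    yield
    print("shutdown")  # cleanup would run here

async def demo():
    ctx = lifespan(object())
    await ctx.__aenter__()                 # runs everything before `yield`
    await ctx.__aexit__(None, None, None)  # runs everything after `yield`

asyncio.run(demo())
```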
144
api/tests/test_tts_implementations.py
Normal file

@@ -0,0 +1,144 @@
"""Tests for TTS model implementations"""
import os
import torch
import pytest
import numpy as np
from unittest.mock import patch, MagicMock

from api.src.services.tts_base import TTSBaseModel
from api.src.services.tts_cpu import TTSCPUModel
from api.src.services.tts_gpu import TTSGPUModel, length_to_mask

# Base Model Tests
def test_get_device_error():
    """Test get_device() raises error when not initialized"""
    TTSBaseModel._device = None
    with pytest.raises(RuntimeError, match="Model not initialized"):
        TTSBaseModel.get_device()

@patch('torch.cuda.is_available')
@patch('os.path.exists')
@patch('os.path.join')
@patch('os.listdir')
@patch('torch.load')
@patch('torch.save')
def test_setup_cuda_available(mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available):
    """Test setup with CUDA available"""
    TTSBaseModel._device = None
    mock_cuda_available.return_value = True
    mock_exists.return_value = True
    mock_load.return_value = torch.zeros(1)
    mock_listdir.return_value = ["voice1.pt", "voice2.pt"]
    mock_join.return_value = "/mocked/path"

    # Mock the abstract methods
    TTSBaseModel.initialize = MagicMock(return_value=True)
    TTSBaseModel.process_text = MagicMock(return_value=("dummy", [1,2,3]))
    TTSBaseModel.generate_from_tokens = MagicMock(return_value=np.zeros(1000))

    voice_count = TTSBaseModel.setup()
    assert TTSBaseModel._device == "cuda"
    assert voice_count == 2

@patch('torch.cuda.is_available')
@patch('os.path.exists')
@patch('os.path.join')
@patch('os.listdir')
@patch('torch.load')
@patch('torch.save')
def test_setup_cuda_unavailable(mock_save, mock_load, mock_listdir, mock_join, mock_exists, mock_cuda_available):
    """Test setup with CUDA unavailable"""
    TTSBaseModel._device = None
    mock_cuda_available.return_value = False
    mock_exists.return_value = True
    mock_load.return_value = torch.zeros(1)
    mock_listdir.return_value = ["voice1.pt", "voice2.pt"]
    mock_join.return_value = "/mocked/path"

    # Mock the abstract methods
    TTSBaseModel.initialize = MagicMock(return_value=True)
    TTSBaseModel.process_text = MagicMock(return_value=("dummy", [1,2,3]))
    TTSBaseModel.generate_from_tokens = MagicMock(return_value=np.zeros(1000))

    voice_count = TTSBaseModel.setup()
    assert TTSBaseModel._device == "cpu"
    assert voice_count == 2

# CPU Model Tests
def test_cpu_initialize_missing_model():
    """Test CPU initialize with missing model"""
    with patch('os.path.exists', return_value=False):
        result = TTSCPUModel.initialize("dummy_dir")
        assert result is None

def test_cpu_generate_uninitialized():
    """Test CPU generate methods with uninitialized model"""
    TTSCPUModel._onnx_session = None

    with pytest.raises(RuntimeError, match="ONNX model not initialized"):
        TTSCPUModel.generate_from_text("test", torch.zeros(1), "en", 1.0)

    with pytest.raises(RuntimeError, match="ONNX model not initialized"):
        TTSCPUModel.generate_from_tokens([1,2,3], torch.zeros(1), 1.0)

def test_cpu_process_text():
    """Test CPU process_text functionality"""
    with patch('api.src.services.tts_cpu.phonemize') as mock_phonemize, \
         patch('api.src.services.tts_cpu.tokenize') as mock_tokenize:

        mock_phonemize.return_value = "test phonemes"
        mock_tokenize.return_value = [1, 2, 3]

        phonemes, tokens = TTSCPUModel.process_text("test", "en")
        assert phonemes == "test phonemes"
        assert tokens == [0, 1, 2, 3, 0]  # Should add start/end tokens

# GPU Model Tests
@patch('torch.cuda.is_available')
def test_gpu_initialize_cuda_unavailable(mock_cuda_available):
    """Test GPU initialize with CUDA unavailable"""
    mock_cuda_available.return_value = False
    TTSGPUModel._instance = None

    result = TTSGPUModel.initialize("dummy_dir", "dummy_path")
    assert result is None

@patch('api.src.services.tts_gpu.length_to_mask')
def test_gpu_length_to_mask(mock_length_to_mask):
    """Test length_to_mask function"""
    # Setup mock return value
    expected_mask = torch.tensor([
        [False, False, False, True, True],
        [False, False, False, False, False]
    ])
    mock_length_to_mask.return_value = expected_mask

    # Call function with test input
    lengths = torch.tensor([3, 5])
    mask = mock_length_to_mask(lengths)

    # Verify mock was called with correct input
    mock_length_to_mask.assert_called_once()
    assert torch.equal(mask, expected_mask)

def test_gpu_generate_uninitialized():
    """Test GPU generate methods with uninitialized model"""
    TTSGPUModel._instance = None

    with pytest.raises(RuntimeError, match="GPU model not initialized"):
        TTSGPUModel.generate_from_text("test", torch.zeros(1), "en", 1.0)

    with pytest.raises(RuntimeError, match="GPU model not initialized"):
        TTSGPUModel.generate_from_tokens([1,2,3], torch.zeros(1), 1.0)

def test_gpu_process_text():
    """Test GPU process_text functionality"""
    with patch('api.src.services.tts_gpu.phonemize') as mock_phonemize, \
         patch('api.src.services.tts_gpu.tokenize') as mock_tokenize:

        mock_phonemize.return_value = "test phonemes"
        mock_tokenize.return_value = [1, 2, 3]

        phonemes, tokens = TTSGPUModel.process_text("test", "en")
        assert phonemes == "test phonemes"
        assert tokens == [1, 2, 3]  # GPU implementation doesn't add start/end tokens
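To exercise just this new file, one option is to invoke pytest programmatically; the path below assumes the repo root as the working directory:

```python
# Equivalent to `pytest api/tests/test_tts_implementations.py -v` on the command line.
import pytest

exit_code = pytest.main(["api/tests/test_tts_implementations.py", "-v"])
print(f"pytest exited with {exit_code}")  # 0 means all tests passed
```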
@@ -6,14 +6,19 @@ from unittest.mock import MagicMock, call, patch
import numpy as np
import torch
import pytest
from onnxruntime import InferenceSession

from api.src.services.tts import TTSModel, TTSService
from api.src.core.config import settings
from api.src.services.tts_model import TTSModel
from api.src.services.tts_service import TTSService
from api.src.services.tts_cpu import TTSCPUModel
from api.src.services.tts_gpu import TTSGPUModel


@pytest.fixture
def tts_service():
    """Create a TTSService instance for testing"""
    return TTSService(start_worker=False)
    return TTSService()


@pytest.fixture
@@ -68,80 +73,143 @@ def test_list_voices(mock_join, mock_listdir, tts_service):
    assert "not_a_voice" not in voices


@patch("api.src.services.tts.TTSModel.get_instance")
@patch("api.src.services.tts.TTSModel.get_voicepack")
@patch("api.src.services.tts.normalize_text")
@patch("api.src.services.tts.phonemize")
@patch("api.src.services.tts.tokenize")
@patch("api.src.services.tts.generate")
def test_generate_audio_empty_text(
    mock_generate,
    mock_tokenize,
    mock_phonemize,
    mock_normalize,
    mock_voicepack,
    mock_instance,
    tts_service,
):
    """Test generating audio with empty text"""
    mock_normalize.return_value = ""
@patch("os.listdir")
def test_list_voices_error(mock_listdir, tts_service):
    """Test error handling in list_voices"""
    mock_listdir.side_effect = Exception("Failed to list directory")

    voices = tts_service.list_voices()
    assert voices == []


def mock_model_setup(cuda_available=False):
    """Helper function to mock model setup"""
    # Reset model state
    TTSModel._instance = None
    TTSModel._device = None
    TTSModel._voicepacks = {}

    # Create mock model instance with proper generate method
    mock_model = MagicMock()
    mock_model.generate.return_value = np.zeros(24000, dtype=np.float32)
    TTSModel._instance = mock_model

    # Set device based on CUDA availability
    TTSModel._device = "cuda" if cuda_available else "cpu"

    return 3  # Return voice count (including af.pt)


def test_model_initialization_cuda():
    """Test model initialization with CUDA"""
    # Simulate CUDA availability
    voice_count = mock_model_setup(cuda_available=True)

    assert TTSModel.get_device() == "cuda"
    assert voice_count == 3  # voice1.pt, voice2.pt, af.pt


def test_model_initialization_cpu():
    """Test model initialization with CPU"""
    # Simulate no CUDA availability
    voice_count = mock_model_setup(cuda_available=False)

    assert TTSModel.get_device() == "cpu"
    assert voice_count == 3  # voice1.pt, voice2.pt, af.pt


def test_generate_audio_empty_text(tts_service):
    """Test generating audio with empty text"""
    with pytest.raises(ValueError, match="Text is empty after preprocessing"):
        tts_service._generate_audio("", "af", 1.0)


@patch("api.src.services.tts.TTSModel.get_instance")
@patch("api.src.services.tts_model.TTSModel.get_instance")
@patch("api.src.services.tts_model.TTSModel.get_device")
@patch("os.path.exists")
@patch("api.src.services.tts.normalize_text")
@patch("api.src.services.tts.phonemize")
@patch("api.src.services.tts.tokenize")
@patch("api.src.services.tts.generate")
@patch("kokoro.normalize_text")
@patch("kokoro.phonemize")
@patch("kokoro.tokenize")
@patch("kokoro.generate")
@patch("torch.load")
def test_generate_audio_no_chunks(
def test_generate_audio_phonemize_error(
    mock_torch_load,
    mock_generate,
    mock_tokenize,
    mock_phonemize,
    mock_normalize,
    mock_exists,
    mock_get_device,
    mock_instance,
    tts_service,
):
    """Test generating audio with no successful chunks"""
    """Test handling phonemization error"""
    mock_normalize.return_value = "Test text"
    mock_phonemize.return_value = "Test text"
    mock_tokenize.return_value = ["test", "text"]
    mock_generate.return_value = (None, None)
    mock_instance.return_value = (MagicMock(), "cpu")
    mock_phonemize.side_effect = Exception("Phonemization failed")
    mock_instance.return_value = (mock_generate, "cpu")  # Use the same mock for consistency
    mock_get_device.return_value = "cpu"
    mock_exists.return_value = True
    mock_torch_load.return_value = MagicMock()
    mock_torch_load.return_value = torch.zeros((10, 24000))
    mock_generate.return_value = (None, None)

    with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
        tts_service._generate_audio("Test text", "af", 1.0)


@patch("torch.load")
@patch("torch.save")
@patch("torch.stack")
@patch("torch.mean")
@patch("api.src.services.tts_model.TTSModel.get_instance")
@patch("api.src.services.tts_model.TTSModel.get_device")
@patch("os.path.exists")
def test_combine_voices(
    mock_exists, mock_mean, mock_stack, mock_save, mock_load, tts_service
@patch("kokoro.normalize_text")
@patch("kokoro.phonemize")
@patch("kokoro.tokenize")
@patch("kokoro.generate")
@patch("torch.load")
def test_generate_audio_error(
    mock_torch_load,
    mock_generate,
    mock_tokenize,
    mock_phonemize,
    mock_normalize,
    mock_exists,
    mock_get_device,
    mock_instance,
    tts_service,
):
    """Test combining multiple voices"""
    # Setup mocks
    """Test handling generation error"""
    mock_normalize.return_value = "Test text"
    mock_phonemize.return_value = "Test text"
    mock_tokenize.return_value = [1, 2]  # Return integers instead of strings
    mock_generate.side_effect = Exception("Generation failed")
    mock_instance.return_value = (mock_generate, "cpu")  # Use the same mock for consistency
    mock_get_device.return_value = "cpu"
    mock_exists.return_value = True
    mock_load.return_value = torch.tensor([1.0, 2.0])
    mock_stack.return_value = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    mock_mean.return_value = torch.tensor([2.0, 3.0])
    mock_torch_load.return_value = torch.zeros((10, 24000))

    # Test combining two voices
    result = tts_service.combine_voices(["voice1", "voice2"])
    with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
        tts_service._generate_audio("Test text", "af", 1.0)

    assert result == "voice1_voice2"
    mock_stack.assert_called_once()
    mock_mean.assert_called_once()
    mock_save.assert_called_once()

def test_save_audio(tts_service, sample_audio, tmp_path):
    """Test saving audio to file"""
    output_path = os.path.join(tmp_path, "test_output.wav")
    tts_service._save_audio(sample_audio, output_path)
    assert os.path.exists(output_path)
    assert os.path.getsize(output_path) > 0


def test_combine_voices(tts_service):
    """Test combining multiple voices"""
    # Setup mocks for torch operations
    with patch('torch.load', return_value=torch.tensor([1.0, 2.0])), \
         patch('torch.stack', return_value=torch.tensor([[1.0, 2.0], [3.0, 4.0]])), \
         patch('torch.mean', return_value=torch.tensor([2.0, 3.0])), \
         patch('torch.save'), \
         patch('os.path.exists', return_value=True):

        # Test combining two voices
        result = tts_service.combine_voices(["voice1", "voice2"])

        assert result == "voice1_voice2"


def test_combine_voices_invalid_input(tts_service):
@@ -155,221 +223,17 @@ def test_combine_voices_invalid_input(tts_service):
        tts_service.combine_voices(["voice1"])


@patch("os.makedirs")
@patch("os.path.exists")
@patch("os.listdir")
@patch("torch.load")
@patch("torch.save")
@patch("os.path.join")
def test_ensure_voices(
    mock_join,
    mock_save,
    mock_load,
    mock_listdir,
    mock_exists,
    mock_makedirs,
    tts_service,
):
    """Test voice directory initialization"""
    # Setup mocks
    mock_exists.side_effect = [
        True,
        False,
        False,
    ]  # base_dir exists, voice files don't exist
    mock_listdir.return_value = ["voice1.pt", "voice2.pt"]
    mock_load.return_value = MagicMock()
    mock_join.return_value = "/fake/path"

    # Test voice directory initialization
    tts_service._ensure_voices()

    # Verify directory was created
    mock_makedirs.assert_called_once()

    # Verify voices were loaded and saved
    assert mock_load.call_count == len(mock_listdir.return_value)
    assert mock_save.call_count == len(mock_listdir.return_value)


@patch("api.src.services.tts.TTSModel.get_instance")
@patch("os.path.exists")
@patch("api.src.services.tts.normalize_text")
@patch("api.src.services.tts.phonemize")
@patch("api.src.services.tts.tokenize")
@patch("api.src.services.tts.generate")
@patch("torch.load")
def test_generate_audio_success(
    mock_torch_load,
    mock_generate,
    mock_tokenize,
    mock_phonemize,
    mock_normalize,
    mock_exists,
    mock_instance,
    tts_service,
    sample_audio,
):
    """Test successful audio generation"""
    mock_normalize.return_value = "Test text"
    mock_phonemize.return_value = "Test text"
    mock_tokenize.return_value = ["test", "text"]
    mock_generate.return_value = (sample_audio, None)
    mock_instance.return_value = (MagicMock(), "cpu")
    mock_exists.return_value = True
    mock_torch_load.return_value = MagicMock()

    audio, processing_time = tts_service._generate_audio("Test text", "af", 1.0)
    assert isinstance(audio, np.ndarray)
    assert isinstance(processing_time, float)
    assert len(audio) > 0


@patch("api.src.services.tts.torch.cuda.is_available")
@patch("api.src.services.tts.build_model")
def test_model_initialization_cuda(mock_build_model, mock_cuda_available):
    """Test model initialization with CUDA"""
    mock_cuda_available.return_value = True
    mock_model = MagicMock()
    mock_build_model.return_value = mock_model

    TTSModel._instance = None  # Reset singleton
    model, voice_count = TTSModel.initialize()

    assert TTSModel._device == "cuda"  # Check the class variable instead
    assert model == mock_model
    mock_build_model.assert_called_once()


@patch("api.src.services.tts.torch.cuda.is_available")
@patch("api.src.services.tts.build_model")
def test_model_initialization_cpu(mock_build_model, mock_cuda_available):
    """Test model initialization with CPU"""
    mock_cuda_available.return_value = False
    mock_model = MagicMock()
    mock_build_model.return_value = mock_model

    TTSModel._instance = None  # Reset singleton
    model, voice_count = TTSModel.initialize()

    assert TTSModel._device == "cpu"  # Check the class variable instead
    assert model == mock_model
    mock_build_model.assert_called_once()


@patch("api.src.services.tts.TTSService._get_voice_path")
@patch("api.src.services.tts.TTSModel.get_instance")
@patch("api.src.services.tts_service.TTSService._get_voice_path")
@patch("api.src.services.tts_model.TTSModel.get_instance")
def test_voicepack_loading_error(mock_get_instance, mock_get_voice_path):
    """Test voicepack loading error handling"""
    mock_get_voice_path.return_value = None
    mock_get_instance.return_value = (MagicMock(), "cpu")
    mock_instance = MagicMock()
    mock_instance.generate.return_value = np.zeros(24000, dtype=np.float32)
    mock_get_instance.return_value = (mock_instance, "cpu")

    TTSModel._voicepacks = {}  # Reset voicepacks

    service = TTSService(start_worker=False)
    service = TTSService()
    with pytest.raises(ValueError, match="Voice not found: nonexistent_voice"):
        service._generate_audio("test", "nonexistent_voice", 1.0)


@patch("api.src.services.tts.TTSModel")
def test_save_audio(mock_tts_model, tts_service, sample_audio, tmp_path):
    """Test saving audio to file"""
    output_dir = os.path.join(tmp_path, "test_output")
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "audio.wav")

    tts_service._save_audio(sample_audio, output_path)

    assert os.path.exists(output_path)
    assert os.path.getsize(output_path) > 0


@patch("api.src.services.tts.TTSModel.get_instance")
@patch("os.path.exists")
@patch("api.src.services.tts.normalize_text")
@patch("api.src.services.tts.generate")
@patch("torch.load")
def test_generate_audio_without_stitching(
    mock_torch_load,
    mock_generate,
    mock_normalize,
    mock_exists,
    mock_instance,
    tts_service,
    sample_audio,
):
    """Test generating audio without text stitching"""
    mock_normalize.return_value = "Test text"
    mock_generate.return_value = (sample_audio, None)
    mock_instance.return_value = (MagicMock(), "cpu")
    mock_exists.return_value = True
    mock_torch_load.return_value = MagicMock()

    audio, processing_time = tts_service._generate_audio(
        "Test text", "af", 1.0, stitch_long_output=False
    )
    assert isinstance(audio, np.ndarray)
    assert len(audio) > 0
    mock_generate.assert_called_once()


@patch("os.listdir")
def test_list_voices_error(mock_listdir, tts_service):
    """Test error handling in list_voices"""
    mock_listdir.side_effect = Exception("Failed to list directory")

    voices = tts_service.list_voices()
    assert voices == []


@patch("api.src.services.tts.TTSModel.get_instance")
@patch("os.path.exists")
@patch("api.src.services.tts.normalize_text")
@patch("api.src.services.tts.phonemize")
@patch("api.src.services.tts.tokenize")
@patch("api.src.services.tts.generate")
@patch("torch.load")
def test_generate_audio_phonemize_error(
    mock_torch_load,
    mock_generate,
    mock_tokenize,
    mock_phonemize,
    mock_normalize,
    mock_exists,
    mock_instance,
    tts_service,
):
    """Test handling phonemization error"""
    mock_normalize.return_value = "Test text"
    mock_phonemize.side_effect = Exception("Phonemization failed")
    mock_instance.return_value = (MagicMock(), "cpu")
    mock_exists.return_value = True
    mock_torch_load.return_value = MagicMock()
    mock_generate.return_value = (None, None)

    with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
        tts_service._generate_audio("Test text", "af", 1.0)


@patch("api.src.services.tts.TTSModel.get_instance")
@patch("os.path.exists")
@patch("api.src.services.tts.normalize_text")
@patch("api.src.services.tts.generate")
@patch("torch.load")
def test_generate_audio_error(
    mock_torch_load,
    mock_generate,
    mock_normalize,
    mock_exists,
    mock_instance,
    tts_service,
):
    """Test handling generation error"""
    mock_normalize.return_value = "Test text"
    mock_generate.side_effect = Exception("Generation failed")
    mock_instance.return_value = (MagicMock(), "cpu")
    mock_exists.return_value = True
    mock_torch_load.return_value = MagicMock()

    with pytest.raises(ValueError, match="No audio chunks were generated successfully"):
        tts_service._generate_audio("Test text", "af", 1.0)
@@ -36,6 +36,13 @@ services:
      - "8880:8880"
    environment:
      - PYTHONPATH=/app:/app/Kokoro-82M
      # ONNX Optimization Settings for vectorized operations
      - ONNX_NUM_THREADS=8  # Maximize core usage for vectorized ops
      - ONNX_INTER_OP_THREADS=4  # Higher inter-op for parallel matrix operations
      - ONNX_EXECUTION_MODE=parallel
      - ONNX_OPTIMIZATION_LEVEL=all
      - ONNX_MEMORY_PATTERN=true
      - ONNX_ARENA_EXTEND_STRATEGY=kNextPowerOfTwo
    depends_on:
      model-fetcher:
        condition: service_healthy
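For context, a hedged sketch of how environment variables like these could be consumed when building the ONNX session. The mapping below is an assumption for illustration, not necessarily what the service's CPU model code does; the model filename is hypothetical, and `ONNX_ARENA_EXTEND_STRATEGY` is left unmapped here:

```python
# Sketch only: maps the compose file's ONNX_* variables onto onnxruntime
# SessionOptions. Variable names match the compose file; the wiring is assumed.
import os
import onnxruntime as ort

sess_options = ort.SessionOptions()
sess_options.intra_op_num_threads = int(os.getenv("ONNX_NUM_THREADS", "8"))
sess_options.inter_op_num_threads = int(os.getenv("ONNX_INTER_OP_THREADS", "4"))

if os.getenv("ONNX_EXECUTION_MODE", "parallel") == "parallel":
    sess_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
if os.getenv("ONNX_OPTIMIZATION_LEVEL", "all") == "all":
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.enable_mem_pattern = os.getenv("ONNX_MEMORY_PATTERN", "true") == "true"
# ONNX_ARENA_EXTEND_STRATEGY is not applied in this sketch.

session = ort.InferenceSession(
    "kokoro-v0_19.onnx",  # hypothetical model path
    sess_options=sess_options,
    providers=["CPUExecutionProvider"],
)
```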
0
examples/__init__.py
Normal file
0
examples/assorted_checks/__init__.py
Normal file
0
examples/assorted_checks/benchmarks/__init__.py
Normal file
242
examples/assorted_checks/benchmarks/benchmark_tts_rtf.py
Normal file
@@ -0,0 +1,242 @@
#!/usr/bin/env python3
import os
import json
import time
import threading
import queue
import pandas as pd
import sys
from datetime import datetime

from lib.shared_plotting import plot_system_metrics, plot_correlation
from lib.shared_utils import (
    get_system_metrics, save_json_results, write_benchmark_stats,
    real_time_factor
)
from lib.shared_benchmark_utils import (
    get_text_for_tokens, make_tts_request, generate_token_sizes, enc
)

class SystemMonitor:
    def __init__(self, interval=1.0):
        self.interval = interval
        self.metrics_queue = queue.Queue()
        self.stop_event = threading.Event()
        self.metrics_timeline = []
        self.start_time = None

    def _monitor_loop(self):
        """Background thread function to collect system metrics."""
        while not self.stop_event.is_set():
            metrics = get_system_metrics()
            metrics["relative_time"] = time.time() - self.start_time
            self.metrics_queue.put(metrics)
            time.sleep(self.interval)

    def start(self):
        """Start the monitoring thread."""
        self.start_time = time.time()
        self.monitor_thread = threading.Thread(target=self._monitor_loop)
        self.monitor_thread.daemon = True
        self.monitor_thread.start()

    def stop(self):
        """Stop the monitoring thread and collect final metrics."""
        self.stop_event.set()
        if hasattr(self, 'monitor_thread'):
            self.monitor_thread.join(timeout=2)

        # Collect all metrics from queue
        while True:
            try:
                metrics = self.metrics_queue.get_nowait()
                self.metrics_timeline.append(metrics)
            except queue.Empty:
                break

        return self.metrics_timeline

def main():
    # Initialize system monitor
    monitor = SystemMonitor(interval=1.0)  # 1 second interval
    # Set prefix for output files (e.g. "gpu", "cpu", "onnx", etc.)
    prefix = "gpu"
    # Generate token sizes
    if 'gpu' in prefix:
        token_sizes = generate_token_sizes(
            max_tokens=5000, dense_step=150,
            dense_max=1000, sparse_step=1000)
    elif 'cpu' in prefix:
        token_sizes = generate_token_sizes(
            max_tokens=1000, dense_step=300,
            dense_max=1000, sparse_step=0)
    else:
        token_sizes = generate_token_sizes(max_tokens=3000)

    # Set up paths relative to this file
    script_dir = os.path.dirname(os.path.abspath(__file__))
    output_dir = os.path.join(script_dir, "output_audio")
    output_data_dir = os.path.join(script_dir, "output_data")
    output_plots_dir = os.path.join(script_dir, "output_plots")

    # Create output directories
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(output_data_dir, exist_ok=True)
    os.makedirs(output_plots_dir, exist_ok=True)

    # Function to prefix filenames
    def prefix_path(path: str, filename: str) -> str:
        if prefix:
            filename = f"{prefix}_{filename}"
        return os.path.join(path, filename)

    with open(os.path.join(script_dir, "the_time_machine_hg_wells.txt"), "r", encoding="utf-8") as f:
        text = f.read()

    total_tokens = len(enc.encode(text))
    print(f"Total tokens in file: {total_tokens}")

    print(f"Testing sizes: {token_sizes}")

    results = []
    test_start_time = time.time()

    # Start system monitoring
    monitor.start()

    for num_tokens in token_sizes:
        chunk = get_text_for_tokens(text, num_tokens)
        actual_tokens = len(enc.encode(chunk))

        print(f"\nProcessing chunk with {actual_tokens} tokens:")
        print(f"Text preview: {chunk[:100]}...")

        processing_time, audio_length = make_tts_request(
            chunk,
            output_dir=output_dir,
            prefix=prefix
        )
        if processing_time is None or audio_length is None:
            print("Breaking loop due to error")
            break

        # Calculate RTF using the correct formula
        rtf = real_time_factor(processing_time, audio_length)
        print(f"Real-Time Factor: {rtf:.5f}")

        results.append({
            "tokens": actual_tokens,
            "processing_time": processing_time,
            "output_length": audio_length,
            "rtf": rtf,
            "elapsed_time": round(time.time() - test_start_time, 2),
        })

    df = pd.DataFrame(results)
    if df.empty:
        print("No data to plot")
        return

    df["tokens_per_second"] = df["tokens"] / df["processing_time"]

    # Write benchmark stats
    stats = [
        {
            "title": "Benchmark Statistics (with correct RTF)",
            "stats": {
                "Total tokens processed": df['tokens'].sum(),
                "Total audio generated (s)": df['output_length'].sum(),
                "Total test duration (s)": df['elapsed_time'].max(),
                "Average processing rate (tokens/s)": df['tokens_per_second'].mean(),
                "Average RTF": df['rtf'].mean(),
                "Average Real Time Speed": 1/df['rtf'].mean()
            }
        },
        {
            "title": "Per-chunk Stats",
            "stats": {
                "Average chunk size (tokens)": df['tokens'].mean(),
                "Min chunk size (tokens)": df['tokens'].min(),
                "Max chunk size (tokens)": df['tokens'].max(),
                "Average processing time (s)": df['processing_time'].mean(),
                "Average output length (s)": df['output_length'].mean()
            }
        },
        {
            "title": "Performance Ranges",
            "stats": {
                "Processing rate range (tokens/s)": f"{df['tokens_per_second'].min():.2f} - {df['tokens_per_second'].max():.2f}",
                "RTF range": f"{df['rtf'].min():.2f}x - {df['rtf'].max():.2f}x",
                "Real Time Speed range": f"{1/df['rtf'].max():.2f}x - {1/df['rtf'].min():.2f}x"
            }
        }
    ]
    write_benchmark_stats(stats, prefix_path(output_data_dir, "benchmark_stats_rtf.txt"))

    # Plot Processing Time vs Token Count
    plot_correlation(
        df, "tokens", "processing_time",
        "Processing Time vs Input Size",
        "Number of Input Tokens",
        "Processing Time (seconds)",
        prefix_path(output_plots_dir, "processing_time_rtf.png")
    )

    # Plot RTF vs Token Count
    plot_correlation(
        df, "tokens", "rtf",
        "Real-Time Factor vs Input Size",
        "Number of Input Tokens",
        "Real-Time Factor (processing time / audio length)",
        prefix_path(output_plots_dir, "realtime_factor_rtf.png")
    )

    # Stop monitoring and get final metrics
    final_metrics = monitor.stop()

    # Convert metrics timeline to DataFrame for stats
    metrics_df = pd.DataFrame(final_metrics)

    # Add system usage stats
    if not metrics_df.empty:
        stats.append({
            "title": "System Usage Statistics",
            "stats": {
                "Peak CPU Usage (%)": metrics_df['cpu_percent'].max(),
                "Avg CPU Usage (%)": metrics_df['cpu_percent'].mean(),
                "Peak RAM Usage (%)": metrics_df['ram_percent'].max(),
                "Avg RAM Usage (%)": metrics_df['ram_percent'].mean(),
                "Peak RAM Used (GB)": metrics_df['ram_used_gb'].max(),
                "Avg RAM Used (GB)": metrics_df['ram_used_gb'].mean(),
            }
        })
        if 'gpu_memory_used' in metrics_df:
            stats[-1]["stats"].update({
                "Peak GPU Memory (MB)": metrics_df['gpu_memory_used'].max(),
                "Avg GPU Memory (MB)": metrics_df['gpu_memory_used'].mean(),
            })

    # Plot system metrics
    plot_system_metrics(final_metrics, prefix_path(output_plots_dir, "system_usage_rtf.png"))

    # Save final results
    save_json_results(
        {
            "results": results,
            "system_metrics": final_metrics,
            "test_duration": time.time() - test_start_time
        },
        prefix_path(output_data_dir, "benchmark_results_rtf.json")
    )

    print("\nResults saved to:")
    print(f"- {prefix_path(output_data_dir, 'benchmark_results_rtf.json')}")
    print(f"- {prefix_path(output_data_dir, 'benchmark_stats_rtf.txt')}")
    print(f"- {prefix_path(output_plots_dir, 'processing_time_rtf.png')}")
    print(f"- {prefix_path(output_plots_dir, 'realtime_factor_rtf.png')}")
    print(f"- {prefix_path(output_plots_dir, 'system_usage_rtf.png')}")
    print(f"\nAudio files saved in {output_dir} with prefix: {prefix or '(none)'}")


if __name__ == "__main__":
    main()
165
examples/assorted_checks/benchmarks/depr_benchmark_tts.py
Normal file

@@ -0,0 +1,165 @@
import os
import json
import time
import pandas as pd
from examples.assorted_checks.lib.shared_plotting import plot_system_metrics, plot_correlation
from examples.assorted_checks.lib.shared_utils import (
    get_system_metrics, save_json_results, write_benchmark_stats
)
from examples.assorted_checks.lib.shared_benchmark_utils import (
    get_text_for_tokens, make_tts_request, generate_token_sizes, enc
)


def main():
    # Get optional prefix from first command line argument
    import sys
    prefix = sys.argv[1] if len(sys.argv) > 1 else ""

    # Set up paths relative to this file
    script_dir = os.path.dirname(os.path.abspath(__file__))
    output_dir = os.path.join(script_dir, "output_audio")
    output_data_dir = os.path.join(script_dir, "output_data")
    output_plots_dir = os.path.join(script_dir, "output_plots")

    # Create output directories
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(output_data_dir, exist_ok=True)
    os.makedirs(output_plots_dir, exist_ok=True)

    # Function to prefix filenames
    def prefix_path(path: str, filename: str) -> str:
        if prefix:
            filename = f"{prefix}_{filename}"
        return os.path.join(path, filename)

    # Read input text
    with open(
        os.path.join(script_dir, "the_time_machine_hg_wells.txt"), "r", encoding="utf-8"
    ) as f:
        text = f.read()

    # Get total tokens in file
    total_tokens = len(enc.encode(text))
    print(f"Total tokens in file: {total_tokens}")

    token_sizes = generate_token_sizes(total_tokens)

    print(f"Testing sizes: {token_sizes}")

    # Process chunks
    results = []
    system_metrics = []
    test_start_time = time.time()

    for num_tokens in token_sizes:
        # Get text slice with exact token count
        chunk = get_text_for_tokens(text, num_tokens)
        actual_tokens = len(enc.encode(chunk))

        print(f"\nProcessing chunk with {actual_tokens} tokens:")
        print(f"Text preview: {chunk[:100]}...")

        # Collect system metrics before processing
        system_metrics.append(get_system_metrics())

        processing_time, audio_length = make_tts_request(chunk)
        if processing_time is None or audio_length is None:
            print("Breaking loop due to error")
            break

        # Collect system metrics after processing
        system_metrics.append(get_system_metrics())

        results.append(
            {
                "tokens": actual_tokens,
                "processing_time": processing_time,
                "output_length": audio_length,
                "realtime_factor": audio_length / processing_time,
                "elapsed_time": time.time() - test_start_time,
            }
        )

        # Save intermediate results
        save_json_results(
            {"results": results, "system_metrics": system_metrics},
            prefix_path(output_data_dir, "benchmark_results.json")
        )

    # Create DataFrame and calculate stats
    df = pd.DataFrame(results)
    if df.empty:
        print("No data to plot")
        return

    # Calculate useful metrics
    df["tokens_per_second"] = df["tokens"] / df["processing_time"]

    # Write benchmark stats
    stats = [
        {
            "title": "Benchmark Statistics",
            "stats": {
                "Total tokens processed": df['tokens'].sum(),
                "Total audio generated (s)": df['output_length'].sum(),
                "Total test duration (s)": df['elapsed_time'].max(),
                "Average processing rate (tokens/s)": df['tokens_per_second'].mean(),
                "Average realtime factor": df['realtime_factor'].mean()
            }
        },
        {
            "title": "Per-chunk Stats",
            "stats": {
                "Average chunk size (tokens)": df['tokens'].mean(),
                "Min chunk size (tokens)": df['tokens'].min(),
                "Max chunk size (tokens)": df['tokens'].max(),
                "Average processing time (s)": df['processing_time'].mean(),
                "Average output length (s)": df['output_length'].mean()
            }
        },
        {
            "title": "Performance Ranges",
            "stats": {
                "Processing rate range (tokens/s)": f"{df['tokens_per_second'].min():.2f} - {df['tokens_per_second'].max():.2f}",
                "Realtime factor range": f"{df['realtime_factor'].min():.2f}x - {df['realtime_factor'].max():.2f}x"
            }
        }
    ]
    write_benchmark_stats(stats, prefix_path(output_data_dir, "benchmark_stats.txt"))

    # Plot Processing Time vs Token Count
    plot_correlation(
        df, "tokens", "processing_time",
        "Processing Time vs Input Size",
        "Number of Input Tokens",
        "Processing Time (seconds)",
        prefix_path(output_plots_dir, "processing_time.png")
    )

    # Plot Realtime Factor vs Token Count
    plot_correlation(
        df, "tokens", "realtime_factor",
        "Realtime Factor vs Input Size",
        "Number of Input Tokens",
        "Realtime Factor (output length / processing time)",
        prefix_path(output_plots_dir, "realtime_factor.png")
    )

    # Plot system metrics
    plot_system_metrics(system_metrics, prefix_path(output_plots_dir, "system_usage.png"))

    print("\nResults saved to:")
    print(f"- {prefix_path(output_data_dir, 'benchmark_results.json')}")
    print(f"- {prefix_path(output_data_dir, 'benchmark_stats.txt')}")
    print(f"- {prefix_path(output_plots_dir, 'processing_time.png')}")
    print(f"- {prefix_path(output_plots_dir, 'realtime_factor.png')}")
    print(f"- {prefix_path(output_plots_dir, 'system_usage.png')}")
    if any("gpu_memory_used" in m for m in system_metrics):
        print(f"- {prefix_path(output_plots_dir, 'gpu_usage.png')}")
    print(f"\nAudio files saved in {output_dir} with prefix: {prefix or '(none)'}")


if __name__ == "__main__":
    main()
0
examples/assorted_checks/benchmarks/lib/__init__.py
Normal file

@@ -0,0 +1,111 @@
"""Shared utilities specific to TTS benchmarking."""
import time
from typing import List, Optional, Tuple

import requests
import tiktoken

from .shared_utils import get_audio_length, save_audio_file

# Global tokenizer instance
enc = tiktoken.get_encoding("cl100k_base")


def get_text_for_tokens(text: str, num_tokens: int) -> str:
    """Get a slice of text that contains exactly num_tokens tokens.

    Args:
        text: Input text to slice
        num_tokens: Desired number of tokens

    Returns:
        str: Text slice containing exactly num_tokens tokens
    """
    tokens = enc.encode(text)
    if num_tokens > len(tokens):
        return text
    return enc.decode(tokens[:num_tokens])


def make_tts_request(
    text: str,
    output_dir: str = None,
    timeout: int = 1800,
    prefix: str = ""
) -> Tuple[Optional[float], Optional[float]]:
    """Make TTS request using OpenAI-compatible endpoint.

    Args:
        text: Input text to convert to speech
        output_dir: Directory to save audio files. If None, audio won't be saved.
        timeout: Request timeout in seconds
        prefix: Optional prefix for output filenames

    Returns:
        tuple: (processing_time, audio_length) in seconds, or (None, None) on error
    """
    try:
        start_time = time.time()
        response = requests.post(
            "http://localhost:8880/v1/audio/speech",
            json={
                "model": "kokoro",
                "input": text,
                "voice": "af",
                "response_format": "wav",
            },
            timeout=timeout,
        )
        response.raise_for_status()

        processing_time = round(time.time() - start_time, 2)
        # Calculate audio length from response content
        audio_length = get_audio_length(response.content)

        # Save the audio file if output_dir is provided
        if output_dir:
            token_count = len(enc.encode(text))
            output_file = save_audio_file(
                response.content,
                f"chunk_{token_count}_tokens",
                output_dir
            )
            print(f"Saved audio to {output_file}")

        return processing_time, audio_length

    except requests.exceptions.RequestException as e:
        print(f"Error making request for text: {text[:50]}... Error: {str(e)}")
        return None, None
    except Exception as e:
        print(f"Error processing text: {text[:50]}... Error: {str(e)}")
        return None, None


def generate_token_sizes(
    max_tokens: int,
    dense_step: int = 100,
    dense_max: int = 1000,
    sparse_step: int = 1000
) -> List[int]:
    """Generate token size ranges with dense sampling at start.

    Args:
        max_tokens: Maximum number of tokens to generate sizes up to
        dense_step: Step size for dense sampling range
        dense_max: Maximum value for dense sampling
        sparse_step: Step size for sparse sampling range

    Returns:
        list: Sorted list of token sizes
    """
    # Dense sampling at start
    dense_range = list(range(dense_step, dense_max + 1, dense_step))

    if max_tokens <= dense_max or sparse_step < dense_max:
        return sorted(dense_range)
    # Sparse sampling for larger sizes
    sparse_range = list(range(dense_max + sparse_step, max_tokens + 1, sparse_step))

    # Combine and deduplicate
    return sorted(list(set(dense_range + sparse_range)))
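With the defaults above, the helper yields dense coverage up to `dense_max` and sparse coverage beyond it; a quick check of the expected output:

```python
# Import path as used by benchmark_tts_rtf.py (run from the benchmarks/ directory).
from lib.shared_benchmark_utils import generate_token_sizes

sizes = generate_token_sizes(max_tokens=5000)  # dense_step=100, dense_max=1000, sparse_step=1000
print(sizes)
# [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 2000, 3000, 4000, 5000]
```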
176
examples/assorted_checks/benchmarks/lib/shared_plotting.py
Normal file

@@ -0,0 +1,176 @@
"""Shared plotting utilities for benchmarks and tests."""
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Common style configurations
STYLE_CONFIG = {
    "background_color": "#1a1a2e",
    "primary_color": "#ff2a6d",
    "secondary_color": "#05d9e8",
    "grid_color": "#ffffff",
    "text_color": "#ffffff",
    "font_sizes": {
        "title": 16,
        "label": 14,
        "tick": 12,
        "text": 10
    }
}

def setup_plot(fig, ax, title, xlabel=None, ylabel=None):
    """Configure plot styling with consistent theme.

    Args:
        fig: matplotlib figure object
        ax: matplotlib axis object
        title: str, plot title
        xlabel: str, optional x-axis label
        ylabel: str, optional y-axis label

    Returns:
        tuple: (fig, ax) with applied styling
    """
    # Grid styling
    ax.grid(True, linestyle="--", alpha=0.3, color=STYLE_CONFIG["grid_color"])

    # Title and labels
    ax.set_title(title, pad=20,
                 fontsize=STYLE_CONFIG["font_sizes"]["title"],
                 fontweight="bold",
                 color=STYLE_CONFIG["text_color"])

    if xlabel:
        ax.set_xlabel(xlabel,
                      fontsize=STYLE_CONFIG["font_sizes"]["label"],
                      fontweight="medium",
                      color=STYLE_CONFIG["text_color"])
    if ylabel:
        ax.set_ylabel(ylabel,
                      fontsize=STYLE_CONFIG["font_sizes"]["label"],
                      fontweight="medium",
                      color=STYLE_CONFIG["text_color"])

    # Tick styling
    ax.tick_params(labelsize=STYLE_CONFIG["font_sizes"]["tick"],
                   colors=STYLE_CONFIG["text_color"])

    # Spine styling
    for spine in ax.spines.values():
        spine.set_color(STYLE_CONFIG["text_color"])
        spine.set_alpha(0.3)
        spine.set_linewidth(0.5)

    # Background colors
    ax.set_facecolor(STYLE_CONFIG["background_color"])
    fig.patch.set_facecolor(STYLE_CONFIG["background_color"])

    return fig, ax

def plot_system_metrics(metrics_data, output_path):
    """Create plots for system metrics over time.

    Args:
        metrics_data: list of dicts containing system metrics
        output_path: str, path to save the output plot
    """
    df = pd.DataFrame(metrics_data)
    df["timestamp"] = pd.to_datetime(df["timestamp"])
    elapsed_time = (df["timestamp"] - df["timestamp"].iloc[0]).dt.total_seconds()

    # Get baseline values
    baseline_cpu = df["cpu_percent"].iloc[0]
    baseline_ram = df["ram_used_gb"].iloc[0]
    baseline_gpu = df["gpu_memory_used"].iloc[0] / 1024 if "gpu_memory_used" in df.columns else None

    # Convert GPU memory to GB if present
    if "gpu_memory_used" in df.columns:
        df["gpu_memory_gb"] = df["gpu_memory_used"] / 1024

    plt.style.use("dark_background")

    # Create subplots based on available metrics
    has_gpu = "gpu_memory_used" in df.columns
    num_plots = 3 if has_gpu else 2
    fig, axes = plt.subplots(num_plots, 1, figsize=(15, 5 * num_plots))
    fig.patch.set_facecolor(STYLE_CONFIG["background_color"])

    # Smoothing window
    window = min(5, len(df) // 2)

    # Plot CPU Usage
    smoothed_cpu = df["cpu_percent"].rolling(window=window, center=True).mean()
    sns.lineplot(x=elapsed_time, y=smoothed_cpu, ax=axes[0],
                 color=STYLE_CONFIG["primary_color"], linewidth=2)
    axes[0].axhline(y=baseline_cpu, color=STYLE_CONFIG["secondary_color"],
                    linestyle="--", alpha=0.5, label="Baseline")
    setup_plot(fig, axes[0], "CPU Usage Over Time",
               xlabel="Time (seconds)", ylabel="CPU Usage (%)")
    axes[0].set_ylim(0, max(df["cpu_percent"]) * 1.1)
    axes[0].legend()

    # Plot RAM Usage
    smoothed_ram = df["ram_used_gb"].rolling(window=window, center=True).mean()
    sns.lineplot(x=elapsed_time, y=smoothed_ram, ax=axes[1],
                 color=STYLE_CONFIG["secondary_color"], linewidth=2)
    axes[1].axhline(y=baseline_ram, color=STYLE_CONFIG["primary_color"],
                    linestyle="--", alpha=0.5, label="Baseline")
    setup_plot(fig, axes[1], "RAM Usage Over Time",
               xlabel="Time (seconds)", ylabel="RAM Usage (GB)")
    axes[1].set_ylim(0, max(df["ram_used_gb"]) * 1.1)
    axes[1].legend()

    # Plot GPU Memory if available
    if has_gpu:
        smoothed_gpu = df["gpu_memory_gb"].rolling(window=window, center=True).mean()
        sns.lineplot(x=elapsed_time, y=smoothed_gpu, ax=axes[2],
                     color=STYLE_CONFIG["primary_color"], linewidth=2)
        axes[2].axhline(y=baseline_gpu, color=STYLE_CONFIG["secondary_color"],
                        linestyle="--", alpha=0.5, label="Baseline")
        setup_plot(fig, axes[2], "GPU Memory Usage Over Time",
                   xlabel="Time (seconds)", ylabel="GPU Memory (GB)")
        axes[2].set_ylim(0, max(df["gpu_memory_gb"]) * 1.1)
        axes[2].legend()

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close()

def plot_correlation(df, x, y, title, xlabel, ylabel, output_path):
    """Create correlation plot with regression line and correlation coefficient.

    Args:
        df: pandas DataFrame containing the data
        x: str, column name for x-axis
        y: str, column name for y-axis
        title: str, plot title
        xlabel: str, x-axis label
        ylabel: str, y-axis label
        output_path: str, path to save the output plot
    """
    plt.style.use("dark_background")

    fig, ax = plt.subplots(figsize=(12, 8))

    # Scatter plot
    sns.scatterplot(data=df, x=x, y=y, s=100, alpha=0.6,
                    color=STYLE_CONFIG["primary_color"])

    # Regression line
    sns.regplot(data=df, x=x, y=y, scatter=False,
                color=STYLE_CONFIG["secondary_color"],
                line_kws={"linewidth": 2})

    # Add correlation coefficient
    corr = df[x].corr(df[y])
    plt.text(0.05, 0.95, f"Correlation: {corr:.2f}",
             transform=ax.transAxes,
             fontsize=STYLE_CONFIG["font_sizes"]["text"],
             color=STYLE_CONFIG["text_color"],
             bbox=dict(facecolor=STYLE_CONFIG["background_color"],
                       edgecolor=STYLE_CONFIG["text_color"],
                       alpha=0.7))

    setup_plot(fig, ax, title, xlabel=xlabel, ylabel=ylabel)
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close()
174
examples/assorted_checks/benchmarks/lib/shared_utils.py
Normal file
|
@ -0,0 +1,174 @@
|
|||
"""Shared utilities for benchmarks and tests."""
|
||||
import os
|
||||
import json
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import psutil
|
||||
import scipy.io.wavfile as wavfile
|
||||
|
||||
# Check for torch availability once at module level
|
||||
TORCH_AVAILABLE = False
|
||||
try:
|
||||
import torch
|
||||
TORCH_AVAILABLE = torch.cuda.is_available()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def get_audio_length(audio_data: bytes, temp_dir: str = None) -> float:
|
||||
"""Get audio length in seconds from bytes data.
|
||||
|
||||
Args:
|
||||
audio_data: Raw audio bytes
|
||||
        temp_dir: Directory for the temporary file. If None, uses the system temp directory.

    Returns:
        float: Audio length in seconds
    """
    if temp_dir is None:
        import tempfile

        temp_dir = tempfile.gettempdir()

    temp_path = os.path.join(temp_dir, "temp.wav")
    os.makedirs(temp_dir, exist_ok=True)

    with open(temp_path, "wb") as f:
        f.write(audio_data)

    try:
        rate, data = wavfile.read(temp_path)
        return len(data) / rate
    finally:
        if os.path.exists(temp_path):
            os.remove(temp_path)


def get_gpu_memory(average: bool = True) -> Optional[Union[float, List[float]]]:
    """Get GPU memory usage using PyTorch if available, falling back to nvidia-smi.

    Args:
        average: If True and multiple GPUs are present, return the average memory usage.
            If False, return a list of per-GPU memory usage.

    Returns:
        float, List[float], or None: GPU memory usage in MB. Returns None if no GPU is
            available. If average=False and multiple GPUs are present, returns a list of values.
    """
    if TORCH_AVAILABLE:
        n_gpus = torch.cuda.device_count()
        memory_used = []
        for i in range(n_gpus):
            memory_used.append(torch.cuda.memory_allocated(i) / 1024**2)  # Convert to MB

        if not memory_used:  # Torch is installed but no CUDA device is visible
            return None
        if average and len(memory_used) > 0:
            return sum(memory_used) / len(memory_used)
        return memory_used if len(memory_used) > 1 else memory_used[0]

    # Fall back to nvidia-smi
    try:
        result = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader"]
        )
        memory_values = [
            float(x.strip()) for x in result.decode("utf-8").split("\n") if x.strip()
        ]

        if average and len(memory_values) > 0:
            return sum(memory_values) / len(memory_values)
        return memory_values if len(memory_values) > 1 else memory_values[0]
    except (subprocess.CalledProcessError, FileNotFoundError):
        return None


def get_system_metrics() -> Dict[str, Union[str, float]]:
    """Get current system metrics including CPU, RAM, and GPU if available.

    Returns:
        dict: System metrics including timestamp, CPU %, RAM %, RAM GB, and GPU MB if available
    """
    # Get per-CPU percentages and calculate the average
    cpu_percentages = psutil.cpu_percent(percpu=True)
    avg_cpu = sum(cpu_percentages) / len(cpu_percentages)

    metrics = {
        "timestamp": datetime.now().isoformat(),
        "cpu_percent": round(avg_cpu, 2),
        "ram_percent": psutil.virtual_memory().percent,
        "ram_used_gb": psutil.virtual_memory().used / (1024**3),
    }

    gpu_mem = get_gpu_memory(average=True)  # Use the average for system metrics
    if gpu_mem is not None:
        metrics["gpu_memory_used"] = round(gpu_mem, 2)

    return metrics


def save_audio_file(audio_data: bytes, identifier: str, output_dir: str) -> str:
    """Save audio data to a file with proper naming and directory creation.

    Args:
        audio_data: Raw audio bytes
        identifier: String identifying this audio file (e.g. token count, test name)
        output_dir: Directory to save the file in

    Returns:
        str: Path to the saved audio file
    """
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f"{identifier}.wav")

    with open(output_file, "wb") as f:
        f.write(audio_data)

    return output_file


def write_benchmark_stats(stats: List[Dict[str, Any]], output_file: str) -> None:
    """Write benchmark statistics to a file in a clean, organized format.

    Args:
        stats: List of dictionaries containing stat name/value pairs
        output_file: Path to the output file
    """
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    with open(output_file, "w") as f:
        for section in stats:
            # Write the section header
            f.write(f"=== {section['title']} ===\n\n")

            # Write the stats
            for label, value in section["stats"].items():
                if isinstance(value, float):
                    f.write(f"{label}: {value:.2f}\n")
                else:
                    f.write(f"{label}: {value}\n")
            f.write("\n")


def save_json_results(results: Dict[str, Any], output_file: str) -> None:
    """Save benchmark results to a JSON file with proper formatting.

    Args:
        results: Dictionary of results to save
        output_file: Path to the output file
    """
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    with open(output_file, "w") as f:
        json.dump(results, f, indent=2)


def real_time_factor(processing_time: float, audio_length: float, decimals: int = 2) -> float:
    """Calculate the Real-Time Factor (RTF) as processing time / audio length.

    Args:
        processing_time: Time taken to process/generate the audio (seconds)
        audio_length: Length of the generated audio (seconds)
        decimals: Number of decimal places to round to

    Returns:
        float: RTF value
    """
    rtf = processing_time / audio_length
    return round(rtf, decimals)
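
A minimal sketch of how these helpers might compose into a benchmark-reporting flow. The `samples` iterable, the directory paths, and the stats layout below are illustrative assumptions, not part of this module:

```python
# Hypothetical glue code; `samples` is assumed to yield
# (identifier, processing_time, audio_bytes) tuples from some TTS benchmark loop.
def summarize_run(samples, audio_dir="output_audio", stats_dir="output"):
    results = []
    for identifier, processing_time, audio_bytes in samples:
        audio_len = get_audio_length(audio_bytes)
        results.append({
            "processing_time": processing_time,
            "output_length": audio_len,
            "rtf": real_time_factor(processing_time, audio_len),
        })
        save_audio_file(audio_bytes, identifier, audio_dir)

    # Persist raw results plus a snapshot of system state.
    save_json_results(
        {"results": results, "system_metrics": [get_system_metrics()]},
        os.path.join(stats_dir, "benchmark_results.json"),
    )
    # Emit a human-readable summary in the sectioned format used by the stats files.
    write_benchmark_stats(
        [{
            "title": "Overall Stats",
            "stats": {
                "Total chunks": len(results),
                "Average RTF": sum(r["rtf"] for r in results) / max(len(results), 1),
            },
        }],
        os.path.join(stats_dir, "benchmark_stats.txt"),
    )
```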

@ -0,0 +1,111 @@
{
  "results": [
    {"tokens": 100, "processing_time": 18.833295583724976, "output_length": 31.15, "realtime_factor": 1.6539856161403135, "elapsed_time": 19.024322748184204},
    {"tokens": 200, "processing_time": 38.95506024360657, "output_length": 62.6, "realtime_factor": 1.6069799304257042, "elapsed_time": 58.21527123451233},
    {"tokens": 300, "processing_time": 49.74252939224243, "output_length": 96.325, "realtime_factor": 1.9364716908630366, "elapsed_time": 108.19673728942871},
    {"tokens": 400, "processing_time": 61.349056243896484, "output_length": 128.575, "realtime_factor": 2.095794261102292, "elapsed_time": 169.733656167984},
    {"tokens": 500, "processing_time": 82.86568236351013, "output_length": 158.575, "realtime_factor": 1.9136389815071193, "elapsed_time": 252.7968451976776}
  ],
  "system_metrics": [
    {"timestamp": "2025-01-03T00:13:49.865330", "cpu_percent": 8.0, "ram_percent": 39.4, "ram_used_gb": 25.03811264038086, "gpu_memory_used": 1204.0},
    {"timestamp": "2025-01-03T00:14:08.781551", "cpu_percent": 26.8, "ram_percent": 42.6, "ram_used_gb": 27.090862274169922, "gpu_memory_used": 1225.0},
    {"timestamp": "2025-01-03T00:14:08.916973", "cpu_percent": 16.1, "ram_percent": 42.6, "ram_used_gb": 27.089553833007812, "gpu_memory_used": 1225.0},
    {"timestamp": "2025-01-03T00:14:47.979053", "cpu_percent": 31.5, "ram_percent": 43.6, "ram_used_gb": 27.714427947998047, "gpu_memory_used": 1225.0},
    {"timestamp": "2025-01-03T00:14:48.098976", "cpu_percent": 20.0, "ram_percent": 43.6, "ram_used_gb": 27.704315185546875, "gpu_memory_used": 1211.0},
    {"timestamp": "2025-01-03T00:15:37.944729", "cpu_percent": 29.7, "ram_percent": 38.6, "ram_used_gb": 24.53925323486328, "gpu_memory_used": 1217.0},
    {"timestamp": "2025-01-03T00:15:38.071915", "cpu_percent": 8.6, "ram_percent": 38.5, "ram_used_gb": 24.51690673828125, "gpu_memory_used": 1208.0},
    {"timestamp": "2025-01-03T00:16:39.525449", "cpu_percent": 23.4, "ram_percent": 38.8, "ram_used_gb": 24.71230697631836, "gpu_memory_used": 1221.0},
    {"timestamp": "2025-01-03T00:16:39.612442", "cpu_percent": 5.5, "ram_percent": 38.9, "ram_used_gb": 24.72066879272461, "gpu_memory_used": 1221.0},
    {"timestamp": "2025-01-03T00:18:02.569076", "cpu_percent": 27.4, "ram_percent": 39.1, "ram_used_gb": 24.868202209472656, "gpu_memory_used": 1264.0}
  ]
}
@ -0,0 +1,216 @@
{
  "results": [
    {"tokens": 100, "processing_time": 14.349808931350708, "output_length": 31.15, "rtf": 0.46, "elapsed_time": 14.716031074523926},
    {"tokens": 200, "processing_time": 28.341803312301636, "output_length": 62.6, "rtf": 0.45, "elapsed_time": 43.44207406044006},
    {"tokens": 300, "processing_time": 43.352553606033325, "output_length": 96.325, "rtf": 0.45, "elapsed_time": 87.26906609535217},
    {"tokens": 400, "processing_time": 71.02449822425842, "output_length": 128.575, "rtf": 0.55, "elapsed_time": 158.7198133468628},
    {"tokens": 500, "processing_time": 70.92521691322327, "output_length": 158.575, "rtf": 0.45, "elapsed_time": 230.01379895210266},
    {"tokens": 600, "processing_time": 83.6328592300415, "output_length": 189.25, "rtf": 0.44, "elapsed_time": 314.02610969543457},
    {"tokens": 700, "processing_time": 103.0810194015503, "output_length": 222.075, "rtf": 0.46, "elapsed_time": 417.5678551197052},
    {"tokens": 800, "processing_time": 127.02162909507751, "output_length": 253.85, "rtf": 0.5, "elapsed_time": 545.0128681659698},
    {"tokens": 900, "processing_time": 130.49781227111816, "output_length": 283.775, "rtf": 0.46, "elapsed_time": 675.8943417072296},
    {"tokens": 1000, "processing_time": 154.76425909996033, "output_length": 315.475, "rtf": 0.49, "elapsed_time": 831.0677945613861}
  ],
  "system_metrics": [
    {"timestamp": "2025-01-03T00:23:52.896889", "cpu_percent": 4.5, "ram_percent": 39.1, "ram_used_gb": 24.86032485961914, "gpu_memory_used": 1281.0},
    {"timestamp": "2025-01-03T00:24:07.429461", "cpu_percent": 4.5, "ram_percent": 39.1, "ram_used_gb": 24.847564697265625, "gpu_memory_used": 1285.0},
    {"timestamp": "2025-01-03T00:24:07.620587", "cpu_percent": 2.7, "ram_percent": 39.1, "ram_used_gb": 24.846607208251953, "gpu_memory_used": 1275.0},
    {"timestamp": "2025-01-03T00:24:36.140754", "cpu_percent": 5.4, "ram_percent": 39.1, "ram_used_gb": 24.857810974121094, "gpu_memory_used": 1267.0},
    {"timestamp": "2025-01-03T00:24:36.340675", "cpu_percent": 6.2, "ram_percent": 39.1, "ram_used_gb": 24.85773468017578, "gpu_memory_used": 1267.0},
    {"timestamp": "2025-01-03T00:25:19.905634", "cpu_percent": 29.1, "ram_percent": 39.2, "ram_used_gb": 24.920318603515625, "gpu_memory_used": 1256.0},
    {"timestamp": "2025-01-03T00:25:20.182219", "cpu_percent": 20.0, "ram_percent": 39.2, "ram_used_gb": 24.930198669433594, "gpu_memory_used": 1256.0},
    {"timestamp": "2025-01-03T00:26:31.414760", "cpu_percent": 5.3, "ram_percent": 39.5, "ram_used_gb": 25.127891540527344, "gpu_memory_used": 1259.0},
    {"timestamp": "2025-01-03T00:26:31.617256", "cpu_percent": 3.6, "ram_percent": 39.5, "ram_used_gb": 25.126346588134766, "gpu_memory_used": 1252.0},
    {"timestamp": "2025-01-03T00:27:42.736097", "cpu_percent": 10.5, "ram_percent": 39.5, "ram_used_gb": 25.100231170654297, "gpu_memory_used": 1249.0},
    {"timestamp": "2025-01-03T00:27:42.912870", "cpu_percent": 5.3, "ram_percent": 39.5, "ram_used_gb": 25.098285675048828, "gpu_memory_used": 1249.0},
    {"timestamp": "2025-01-03T00:29:06.725264", "cpu_percent": 8.9, "ram_percent": 39.5, "ram_used_gb": 25.123123168945312, "gpu_memory_used": 1239.0},
    {"timestamp": "2025-01-03T00:29:06.928826", "cpu_percent": 5.5, "ram_percent": 39.5, "ram_used_gb": 25.128646850585938, "gpu_memory_used": 1239.0},
    {"timestamp": "2025-01-03T00:30:50.206349", "cpu_percent": 49.6, "ram_percent": 39.6, "ram_used_gb": 25.162948608398438, "gpu_memory_used": 1245.0},
    {"timestamp": "2025-01-03T00:30:50.491837", "cpu_percent": 14.8, "ram_percent": 39.5, "ram_used_gb": 25.13379669189453, "gpu_memory_used": 1245.0},
    {"timestamp": "2025-01-03T00:32:57.721467", "cpu_percent": 6.2, "ram_percent": 39.6, "ram_used_gb": 25.187721252441406, "gpu_memory_used": 1384.0},
    {"timestamp": "2025-01-03T00:32:57.913350", "cpu_percent": 3.6, "ram_percent": 39.6, "ram_used_gb": 25.199390411376953, "gpu_memory_used": 1384.0},
    {"timestamp": "2025-01-03T00:35:08.608730", "cpu_percent": 6.3, "ram_percent": 39.8, "ram_used_gb": 25.311710357666016, "gpu_memory_used": 1330.0},
    {"timestamp": "2025-01-03T00:35:08.791851", "cpu_percent": 5.3, "ram_percent": 39.8, "ram_used_gb": 25.326683044433594, "gpu_memory_used": 1333.0},
    {"timestamp": "2025-01-03T00:37:43.782406", "cpu_percent": 6.8, "ram_percent": 40.6, "ram_used_gb": 25.803058624267578, "gpu_memory_used": 1409.0}
  ]
}
@ -0,0 +1,300 @@
{
  "results": [
    {"tokens": 100, "processing_time": 0.96, "output_length": 31.1, "rtf": 0.03, "elapsed_time": 1.11},
    {"tokens": 250, "processing_time": 2.23, "output_length": 77.17, "rtf": 0.03, "elapsed_time": 3.49},
    {"tokens": 400, "processing_time": 4.05, "output_length": 128.05, "rtf": 0.03, "elapsed_time": 7.77},
    {"tokens": 550, "processing_time": 4.06, "output_length": 171.45, "rtf": 0.02, "elapsed_time": 12.0},
    {"tokens": 700, "processing_time": 6.01, "output_length": 221.6, "rtf": 0.03, "elapsed_time": 18.16},
    {"tokens": 850, "processing_time": 6.9, "output_length": 269.1, "rtf": 0.03, "elapsed_time": 25.21},
    {"tokens": 1000, "processing_time": 7.65, "output_length": 315.05, "rtf": 0.02, "elapsed_time": 33.03},
    {"tokens": 6000, "processing_time": 48.7, "output_length": 1837.1, "rtf": 0.03, "elapsed_time": 82.21},
    {"tokens": 11000, "processing_time": 92.44, "output_length": 3388.57, "rtf": 0.03, "elapsed_time": 175.46},
    {"tokens": 16000, "processing_time": 163.61, "output_length": 4977.32, "rtf": 0.03, "elapsed_time": 340.46},
    {"tokens": 21000, "processing_time": 209.72, "output_length": 6533.3, "rtf": 0.03, "elapsed_time": 551.92},
    {"tokens": 26000, "processing_time": 329.35, "output_length": 8068.15, "rtf": 0.04, "elapsed_time": 883.37},
    {"tokens": 31000, "processing_time": 473.52, "output_length": 9611.48, "rtf": 0.05, "elapsed_time": 1359.28},
    {"tokens": 36000, "processing_time": 650.98, "output_length": 11157.15, "rtf": 0.06, "elapsed_time": 2012.9}
  ],
  "system_metrics": [
    {"timestamp": "2025-01-03T14:41:01.331735", "cpu_percent": 7.5, "ram_percent": 50.2, "ram_used_gb": 31.960269927978516, "gpu_memory_used": 3191.0},
    {"timestamp": "2025-01-03T14:41:02.357116", "cpu_percent": 17.01, "ram_percent": 50.2, "ram_used_gb": 31.96163558959961, "gpu_memory_used": 3426.0},
    {"timestamp": "2025-01-03T14:41:02.445009", "cpu_percent": 9.5, "ram_percent": 50.3, "ram_used_gb": 31.966781616210938, "gpu_memory_used": 3426.0},
    {"timestamp": "2025-01-03T14:41:04.742152", "cpu_percent": 18.27, "ram_percent": 50.4, "ram_used_gb": 32.08788299560547, "gpu_memory_used": 3642.0},
    {"timestamp": "2025-01-03T14:41:04.847795", "cpu_percent": 16.27, "ram_percent": 50.5, "ram_used_gb": 32.094364166259766, "gpu_memory_used": 3640.0},
    {"timestamp": "2025-01-03T14:41:09.019590", "cpu_percent": 15.97, "ram_percent": 50.7, "ram_used_gb": 32.23244094848633, "gpu_memory_used": 3640.0},
    {"timestamp": "2025-01-03T14:41:09.110324", "cpu_percent": 3.54, "ram_percent": 50.7, "ram_used_gb": 32.234458923339844, "gpu_memory_used": 3640.0},
    {"timestamp": "2025-01-03T14:41:13.252607", "cpu_percent": 13.4, "ram_percent": 50.6, "ram_used_gb": 32.194271087646484, "gpu_memory_used": 3935.0},
    {"timestamp": "2025-01-03T14:41:13.327557", "cpu_percent": 4.69, "ram_percent": 50.6, "ram_used_gb": 32.191776275634766, "gpu_memory_used": 3935.0},
    {"timestamp": "2025-01-03T14:41:19.413633", "cpu_percent": 12.92, "ram_percent": 50.9, "ram_used_gb": 32.3467903137207, "gpu_memory_used": 4250.0},
    {"timestamp": "2025-01-03T14:41:19.492758", "cpu_percent": 7.5, "ram_percent": 50.8, "ram_used_gb": 32.34375, "gpu_memory_used": 4250.0},
    {"timestamp": "2025-01-03T14:41:26.467284", "cpu_percent": 13.09, "ram_percent": 51.2, "ram_used_gb": 32.56281280517578, "gpu_memory_used": 4249.0},
    {"timestamp": "2025-01-03T14:41:26.553559", "cpu_percent": 8.39, "ram_percent": 51.2, "ram_used_gb": 32.56183624267578, "gpu_memory_used": 4249.0},
    {"timestamp": "2025-01-03T14:41:34.284362", "cpu_percent": 12.61, "ram_percent": 51.7, "ram_used_gb": 32.874778747558594, "gpu_memory_used": 4250.0},
    {"timestamp": "2025-01-03T14:41:34.362353", "cpu_percent": 1.25, "ram_percent": 51.7, "ram_used_gb": 32.87461471557617, "gpu_memory_used": 4250.0},
    {"timestamp": "2025-01-03T14:42:23.471312", "cpu_percent": 11.64, "ram_percent": 54.9, "ram_used_gb": 34.90264129638672, "gpu_memory_used": 4647.0},
    {"timestamp": "2025-01-03T14:42:23.547203", "cpu_percent": 5.31, "ram_percent": 54.9, "ram_used_gb": 34.91563415527344, "gpu_memory_used": 4647.0},
    {"timestamp": "2025-01-03T14:43:56.724933", "cpu_percent": 12.97, "ram_percent": 59.5, "ram_used_gb": 37.84241485595703, "gpu_memory_used": 4655.0},
    {"timestamp": "2025-01-03T14:43:56.815453", "cpu_percent": 11.75, "ram_percent": 59.5, "ram_used_gb": 37.832679748535156, "gpu_memory_used": 4655.0},
    {"timestamp": "2025-01-03T14:46:41.705155", "cpu_percent": 12.94, "ram_percent": 66.3, "ram_used_gb": 42.1534538269043, "gpu_memory_used": 4729.0},
    {"timestamp": "2025-01-03T14:46:41.835177", "cpu_percent": 7.73, "ram_percent": 66.2, "ram_used_gb": 42.13554000854492, "gpu_memory_used": 4729.0},
    {"timestamp": "2025-01-03T14:50:13.166236", "cpu_percent": 11.62, "ram_percent": 73.4, "ram_used_gb": 46.71288299560547, "gpu_memory_used": 4676.0},
    {"timestamp": "2025-01-03T14:50:13.261611", "cpu_percent": 8.16, "ram_percent": 73.4, "ram_used_gb": 46.71356201171875, "gpu_memory_used": 4676.0},
    {"timestamp": "2025-01-03T14:55:44.623607", "cpu_percent": 12.92, "ram_percent": 82.8, "ram_used_gb": 52.65533447265625, "gpu_memory_used": 4636.0},
    {"timestamp": "2025-01-03T14:55:44.735410", "cpu_percent": 15.29, "ram_percent": 82.7, "ram_used_gb": 52.63290786743164, "gpu_memory_used": 4636.0},
    {"timestamp": "2025-01-03T15:03:40.534449", "cpu_percent": 13.88, "ram_percent": 85.0, "ram_used_gb": 54.050071716308594, "gpu_memory_used": 4771.0},
    {"timestamp": "2025-01-03T15:03:40.638708", "cpu_percent": 12.21, "ram_percent": 85.0, "ram_used_gb": 54.053733825683594, "gpu_memory_used": 4771.0},
    {"timestamp": "2025-01-03T15:14:34.159142", "cpu_percent": 14.51, "ram_percent": 78.1, "ram_used_gb": 49.70396423339844, "gpu_memory_used": 4739.0}
  ]
}
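The summary files below can be derived directly from result files like the three above; a minimal sketch, assuming a file that uses the newer `rtf` key (the filename is hypothetical):

```python
import json

# Load a committed benchmark result file (path is an assumption for illustration).
with open("benchmark_results_cpu.json") as f:
    data = json.load(f)

results = data["results"]
total_tokens = sum(r["tokens"] for r in results)
total_audio = sum(r["output_length"] for r in results)
total_time = results[-1]["elapsed_time"]  # elapsed_time is cumulative
rates = [r["tokens"] / r["processing_time"] for r in results]  # per-chunk tokens/s
avg_rtf = sum(r["rtf"] for r in results) / len(results)

print(f"Total tokens processed: {total_tokens}")
print(f"Total audio generated: {total_audio:.2f}s")
print(f"Total test duration: {total_time:.2f}s")
print(f"Average processing rate: {sum(rates) / len(rates):.2f} tokens/second")
print(f"Average RTF: {avg_rtf:.2f}x")
```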
@ -0,0 +1,19 @@
=== Benchmark Statistics (with correct RTF) ===

Overall Stats:
Total tokens processed: 5500
Total audio generated: 1741.65s
Total test duration: 831.07s
Average processing rate: 6.72 tokens/second
Average RTF: 0.47x

Per-chunk Stats:
Average chunk size: 550.00 tokens
Min chunk size: 100.00 tokens
Max chunk size: 1000.00 tokens
Average processing time: 82.70s
Average output length: 174.17s

Performance Ranges:
Processing rate range: 5.63 - 7.17 tokens/second
RTF range: 0.44x - 0.55x
@ -0,0 +1,9 @@
=== Benchmark Statistics (with correct RTF) ===

Overall Stats:
Total tokens processed: 150850
Total audio generated: 46786.59s
Total test duration: 2012.90s
Average processing rate: 104.34 tokens/second
Average RTF: 0.03x
@ -0,0 +1,23 @@
=== Benchmark Statistics (with correct RTF) ===

Total tokens processed: 1800
Total audio generated (s): 568.53
Total test duration (s): 244.10
Average processing rate (tokens/s): 7.34
Average RTF: 0.43
Average Real Time Speed: 2.33

=== Per-chunk Stats ===

Average chunk size (tokens): 600.00
Min chunk size (tokens): 300
Max chunk size (tokens): 900
Average processing time (s): 81.30
Average output length (s): 189.51

=== Performance Ranges ===

Processing rate range (tokens/s): 7.21 - 7.47
RTF range: 0.43x - 0.43x
Real Time Speed range: 2.33x - 2.33x
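The "Real Time Speed" figures in these summaries appear to be the reciprocal of RTF (1 / 0.43 ≈ 2.33 above; 1 / 0.03 ≈ 33 for the GPU run that follows); a one-line helper makes the relationship explicit:

```python
def real_time_speed(rtf: float, decimals: int = 2) -> float:
    """Audio seconds produced per second of processing, i.e. the inverse of RTF."""
    return round(1 / rtf, decimals)

assert real_time_speed(0.43) == 2.33  # matches the summary above
```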
@ -0,0 +1,23 @@
=== Benchmark Statistics (with correct RTF) ===

Total tokens processed: 17150
Total audio generated (s): 5296.38
Total test duration (s): 155.23
Average processing rate (tokens/s): 102.86
Average RTF: 0.03
Average Real Time Speed: 31.25

=== Per-chunk Stats ===

Average chunk size (tokens): 1715.00
Min chunk size (tokens): 150
Max chunk size (tokens): 5000
Average processing time (s): 15.39
Average output length (s): 529.64

=== Performance Ranges ===

Processing rate range (tokens/s): 80.65 - 125.10
RTF range: 0.03x - 0.04x
Real Time Speed range: 25.00x - 33.33x
[Image assets: six new plots/screenshots added (181-459 KiB each) and two existing images moved unchanged (764 KiB and 198 KiB)]
@ -332,8 +332,8 @@ def main():
     )
     parser.add_argument("--url", default="http://localhost:8880", help="API base URL")
     parser.add_argument(
         "--output-dir",
-        default="examples/output",
+        default="examples/assorted_checks/test_combinations/output",
         help="Output directory for audio files",
     )
     args = parser.parse_args()
231
examples/assorted_checks/validate_wav.py
Normal file
@ -0,0 +1,231 @@
import numpy as np
import soundfile as sf
import argparse
from pathlib import Path


def validate_tts(wav_path: str) -> dict:
    """
    Quick validation checks for TTS-generated audio files to detect common artifacts.

    Checks for:
    - Unnatural silence gaps
    - Audio glitches and artifacts
    - Repeated speech segments (stuck/looping)
    - Abrupt changes in speech
    - Audio quality issues

    Args:
        wav_path: Path to the audio file (wav, mp3, etc.)

    Returns:
        Dictionary with validation results
    """
    try:
        # Load audio
        audio, sr = sf.read(wav_path)
        if len(audio.shape) > 1:
            audio = audio.mean(axis=1)  # Convert to mono

        # Basic audio stats
        duration = len(audio) / sr
        rms = np.sqrt(np.mean(audio**2))
        peak = np.max(np.abs(audio))
        dc_offset = np.mean(audio)

        # Calculate clipping stats if we're near peak
        clip_count = np.sum(np.abs(audio) >= 0.99)
        clip_percent = (clip_count / len(audio)) * 100
        if clip_percent > 0:
            clip_stats = f" ({clip_percent:.2e} ratio near peak)"
        else:
            clip_stats = " (no samples near peak)"

        # Convert to dB for analysis
        eps = np.finfo(float).eps
        db = 20 * np.log10(np.abs(audio) + eps)

        issues = []

        # Check if the audio is too short (likely a failed generation)
        if duration < 0.1:  # Less than 100ms
            issues.append("WARNING: Audio is suspiciously short - possible failed generation")

        # 1. Check basic audio quality
        if peak >= 1.0:
            # Calculate the percentage of samples that are clipping
            clip_count = np.sum(np.abs(audio) >= 0.99)
            clip_percent = (clip_count / len(audio)) * 100

            if clip_percent > 1.0:  # Only warn if more than 1% of samples clip
                issues.append(f"WARNING: Significant clipping detected ({clip_percent:.2e}% of samples)")
            elif clip_percent > 0.01:  # Add info if more than 0.01% but less than 1%
                issues.append(f"INFO: Minor peak limiting detected ({clip_percent:.2e}% of samples) - likely intentional normalization")

        if rms < 0.01:
            issues.append("WARNING: Audio is very quiet - possible failed generation")
        if abs(dc_offset) > 0.1:  # DC offset is particularly bad for speech
            issues.append(f"WARNING: High DC offset ({dc_offset:.3f}) - may cause audio artifacts")

        # 2. Check for long silence gaps (potential TTS failures)
        silence_threshold = -45  # dB
        min_silence = 2.0  # Only detect silences longer than 2 seconds
        window_size = int(min_silence * sr)
        silence_count = 0
        last_silence = -1

        # Skip the first 0.2s for silence detection (avoids false positives at the start)
        start_idx = int(0.2 * sr)
        for i in range(start_idx, len(db) - window_size, window_size):
            window = db[i:i + window_size]
            if np.mean(window) < silence_threshold:
                # Verify the entire window is mostly silence
                silent_ratio = np.mean(window < silence_threshold)
                if silent_ratio > 0.9:  # 90% of the window should be below the threshold
                    if last_silence == -1 or (i / sr - last_silence) > 2.0:  # Only count silences more than 2s apart
                        silence_count += 1
                        last_silence = i / sr
                        issues.append(f"WARNING: Long silence detected at {i/sr:.2f}s (duration: {min_silence:.1f}s)")

        if silence_count > 2:  # Only warn if there are multiple long silences
            issues.append(f"WARNING: Multiple long silences found ({silence_count} total) - possible generation issue")

        # 3. Check for extreme audio artifacts (changes too rapid for natural speech)
        # Use a longer window to avoid flagging normal phoneme transitions
        window_size = int(0.02 * sr)  # 20ms window
        db_smooth = np.convolve(db, np.ones(window_size) / window_size, 'same')
        db_diff = np.abs(np.diff(db_smooth))

        # Much higher threshold to only catch truly unnatural changes
        artifact_threshold = 40  # dB
        min_duration = int(0.01 * sr)  # Minimum 10ms duration

        # Find regions where the smoothed dB change is extreme
        artifact_points = np.where(db_diff > artifact_threshold)[0]

        if len(artifact_points) > 0:
            # Group artifacts that are very close together
            grouped_artifacts = []
            current_group = [artifact_points[0]]

            for i in range(1, len(artifact_points)):
                if (artifact_points[i] - current_group[-1]) < min_duration:
                    current_group.append(artifact_points[i])
                else:
                    if len(current_group) * (1 / sr) >= 0.01:  # Only keep groups lasting >= 10ms
                        grouped_artifacts.append(current_group)
                    current_group = [artifact_points[i]]

            if len(current_group) * (1 / sr) >= 0.01:
                grouped_artifacts.append(current_group)

            # Report only the most severe artifacts
            for group in grouped_artifacts[:2]:  # Report up to 2 worst artifacts
                center_idx = group[len(group) // 2]
                db_change = db_diff[center_idx]
                if db_change > 45:  # Only report very extreme changes
                    issues.append(
                        f"WARNING: Possible audio artifact at {center_idx/sr:.2f}s "
                        f"({db_change:.1f}dB change over {len(group)/sr*1000:.0f}ms)"
                    )

        # 4. Check for repeated speech segments (stuck/looping)
        # Check both short and long sentence durations at audiobook speed (150-160 wpm)
        for chunk_duration in [5.0, 10.0]:  # 5s (~12 words) and 10s (~25 words) at ~audiobook speed
            chunk_size = int(chunk_duration * sr)
            overlap = int(0.2 * chunk_size)  # 20% overlap between chunks

            for i in range(0, len(audio) - 2 * chunk_size, overlap):
                chunk1 = audio[i:i + chunk_size]
                chunk2 = audio[i + chunk_size:i + 2 * chunk_size]

                # Ignore chunks that are mostly silence
                if np.mean(np.abs(chunk1)) < 0.01 or np.mean(np.abs(chunk2)) < 0.01:
                    continue

                try:
                    correlation = np.corrcoef(chunk1, chunk2)[0, 1]
                    if not np.isnan(correlation) and correlation > 0.92:  # Lower threshold for sentence-length chunks
                        issues.append(
                            f"WARNING: Possible repeated speech at {i/sr:.1f}s "
                            f"(~{int(chunk_duration*160/60):d} words, correlation: {correlation:.3f})"
                        )
                        break  # Found repetition at this duration, try the next duration
                except Exception:  # correlation may fail on malformed chunks; skip them
                    continue

        # 5. Check for extreme amplitude discontinuities (common in failed TTS)
        amplitude_envelope = np.abs(audio)
        window_size = sr // 10  # 100ms window for a smoother envelope
        smooth_env = np.convolve(amplitude_envelope, np.ones(window_size) / float(window_size), 'same')
        env_diff = np.abs(np.diff(smooth_env))

        # Only detect very extreme amplitude changes
        jump_threshold = 0.5  # Much higher threshold
        jumps = np.where(env_diff > jump_threshold)[0]

        if len(jumps) > 0:
            # Group jumps that are close together
            grouped_jumps = []
            current_group = [jumps[0]]

            for i in range(1, len(jumps)):
                if (jumps[i] - current_group[-1]) < 0.05 * sr:  # Group within 50ms
                    current_group.append(jumps[i])
                else:
                    if len(current_group) >= 3:  # Only keep significant discontinuities
                        grouped_jumps.append(current_group)
                    current_group = [jumps[i]]

            if len(current_group) >= 3:
                grouped_jumps.append(current_group)

            # Report only the most severe discontinuities
            for group in grouped_jumps[:2]:  # Report up to 2 worst cases
                center_idx = group[len(group) // 2]
                jump_size = env_diff[center_idx]
                if jump_size > 0.6:  # Only report very extreme changes
                    issues.append(
                        f"WARNING: Possible audio discontinuity at {center_idx/sr:.2f}s "
                        f"({jump_size:.2f} amplitude ratio change)"
                    )

        return {
            "file": wav_path,
            "duration": f"{duration:.2f}s",
            "sample_rate": sr,
            "peak_amplitude": f"{peak:.3f}{clip_stats}",
            "rms_level": f"{rms:.3f}",
            "dc_offset": f"{dc_offset:.3f}",
            "issues": issues,
            "valid": len(issues) == 0
        }

    except Exception as e:
        return {
            "file": wav_path,
            "error": str(e),
            "valid": False
        }


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="TTS Output Validator")
    parser.add_argument("wav_file", help="Path to audio file to validate")
    args = parser.parse_args()

    result = validate_tts(args.wav_file)

    print(f"\nValidating: {result['file']}")
    if "error" in result:
        print(f"Error: {result['error']}")
    else:
        print(f"Duration: {result['duration']}")
        print(f"Sample Rate: {result['sample_rate']} Hz")
        print(f"Peak Amplitude: {result['peak_amplitude']}")
        print(f"RMS Level: {result['rms_level']}")
        print(f"DC Offset: {result['dc_offset']}")

        if result["issues"]:
            print("\nIssues Found:")
            for issue in result["issues"]:
                print(f"- {issue}")
        else:
            print("\nNo issues found")
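Beyond the CLI entry point, `validate_tts` can also be called programmatically; a minimal sketch (the file path is an assumption):

```python
from validate_wav import validate_tts

result = validate_tts("output_audio/sample.wav")  # hypothetical path
if not result["valid"]:
    # Either a load error or a list of WARNING/INFO strings.
    for issue in result.get("issues", [result.get("error")]):
        print(issue)
```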
72
examples/assorted_checks/validate_wavs.py
Normal file
@ -0,0 +1,72 @@
import argparse
from pathlib import Path
from validate_wav import validate_tts


def print_validation_result(result: dict, rel_path: Path):
    """Print full validation details for a single file."""
    print(f"\nValidating: {rel_path}")
    if "error" in result:
        print(f"Error: {result['error']}")
    else:
        print(f"Duration: {result['duration']}")
        print(f"Sample Rate: {result['sample_rate']} Hz")
        print(f"Peak Amplitude: {result['peak_amplitude']}")
        print(f"RMS Level: {result['rms_level']}")
        print(f"DC Offset: {result['dc_offset']}")

        if result["issues"]:
            print("\nIssues Found:")
            for issue in result["issues"]:
                print(f"- {issue}")
        else:
            print("\nNo issues found")


def validate_directory(directory: str):
    """Validate all wav files in a directory with detailed output and a summary."""
    dir_path = Path(directory)

    # Find all wav files (including nested directories)
    wav_files = list(dir_path.rglob("*.wav"))
    wav_files.extend(dir_path.rglob("*.mp3"))  # Also check mp3s
    wav_files = sorted(wav_files)

    if not wav_files:
        print(f"No .wav or .mp3 files found in {directory}")
        return

    print(f"Found {len(wav_files)} files in {directory}")
    print("=" * 80)

    # Store results for the summary
    results = []

    # Detailed validation output
    for wav_file in wav_files:
        result = validate_tts(str(wav_file))
        rel_path = wav_file.relative_to(dir_path)
        print_validation_result(result, rel_path)
        results.append((rel_path, result))
        print("=" * 80)

    # Summary with detailed issues
    print("\nSUMMARY:")
    for rel_path, result in results:
        if "error" in result:
            print(f"{rel_path}: ERROR - {result['error']}")
        elif result["issues"]:
            # Show the first issue in the summary, and indicate if there are more
            issues = result["issues"]
            first_issue = issues[0].replace("WARNING: ", "")
            if len(issues) > 1:
                print(f"{rel_path}: FAIL - {first_issue} (+{len(issues)-1} more issues)")
            else:
                print(f"{rel_path}: FAIL - {first_issue}")
        else:
            print(f"{rel_path}: PASS")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Batch validate TTS wav files")
    parser.add_argument("directory", help="Directory containing wav files to validate")
    args = parser.parse_args()

    validate_directory(args.directory)
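Batch validation then reduces to a single call; a minimal sketch (the directory is one of the output folders listed in .gitignore, chosen for illustration):

```python
from validate_wavs import validate_directory

# Prints per-file details plus a PASS/FAIL summary for every .wav/.mp3 found.
validate_directory("examples/assorted_checks/test_voices/output")
```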
[Image asset removed (754 KiB)]
@ -1,531 +0,0 @@
{
  "results": [
    {"tokens": 100, "processing_time": 8.54442310333252, "output_length": 31.15, "realtime_factor": 3.6456527987068887, "elapsed_time": 8.720048666000366},
    {"tokens": 200, "processing_time": 1.3838517665863037, "output_length": 62.6, "realtime_factor": 45.236058883981606, "elapsed_time": 10.258155345916748},
    {"tokens": 300, "processing_time": 2.2024788856506348, "output_length": 96.325, "realtime_factor": 43.73481200095347, "elapsed_time": 12.594647407531738},
    {"tokens": 400, "processing_time": 3.175424098968506, "output_length": 128.55, "realtime_factor": 40.48278150995886, "elapsed_time": 16.005898475646973},
    {"tokens": 500, "processing_time": 3.205301523208618, "output_length": 158.55, "realtime_factor": 49.46492517224587, "elapsed_time": 19.377076625823975},
    {"tokens": 600, "processing_time": 3.9976348876953125, "output_length": 189.225, "realtime_factor": 47.33423769700254, "elapsed_time": 23.568575859069824},
    {"tokens": 700, "processing_time": 4.98036003112793, "output_length": 222.05, "realtime_factor": 44.58513011351734, "elapsed_time": 28.767319917678833},
    {"tokens": 800, "processing_time": 5.156893491744995, "output_length": 253.825, "realtime_factor": 49.22052402406907, "elapsed_time": 34.1369092464447},
    {"tokens": 900, "processing_time": 5.8110880851745605, "output_length": 283.75, "realtime_factor": 48.82906537312906, "elapsed_time": 40.16419458389282},
    {"tokens": 1000, "processing_time": 6.686216354370117, "output_length": 315.45, "realtime_factor": 47.17914935460046, "elapsed_time": 47.11375427246094},
    {"tokens": 2000, "processing_time": 13.290695905685425, "output_length": 624.925, "realtime_factor": 47.01973504131358, "elapsed_time": 60.842002630233765},
    {"tokens": 3000, "processing_time": 20.058005571365356, "output_length": 932.05, "realtime_factor": 46.46773063671828, "elapsed_time": 81.50969815254211},
    {"tokens": 4000, "processing_time": 26.38338828086853, "output_length": 1222.975, "realtime_factor": 46.353978002394015, "elapsed_time": 108.76348638534546},
    {"tokens": 5000, "processing_time": 32.472310066223145, "output_length": 1525.15, "realtime_factor": 46.967708699801484, "elapsed_time": 142.2994668483734},
    {"tokens": 6000, "processing_time": 42.67592263221741, "output_length": 1837.525, "realtime_factor": 43.0576514030137, "elapsed_time": 186.26759266853333},
    {"tokens": 7000, "processing_time": 51.601537466049194, "output_length": 2146.875, "realtime_factor": 41.60486499869347, "elapsed_time": 239.59922289848328},
    {"tokens": 8000, "processing_time": 51.86434292793274, "output_length": 2458.425, "realtime_factor": 47.401063258741466, "elapsed_time": 293.4462616443634},
    {"tokens": 9000, "processing_time": 60.4497971534729, "output_length": 2772.1, "realtime_factor": 45.857887545297416, "elapsed_time": 356.02399826049805},
    {"tokens": 10000, "processing_time": 71.75962543487549, "output_length": 3085.625, "realtime_factor": 42.99945800024164, "elapsed_time": 430.50863671302795},
    {"tokens": 11000, "processing_time": 96.66409230232239, "output_length": 3389.3, "realtime_factor": 35.062657904030935, "elapsed_time": 529.3296246528625},
    {"tokens": 12000, "processing_time": 85.70126295089722, "output_length": 3703.175, "realtime_factor": 43.21027336693678, "elapsed_time": 618.0248212814331},
    {"tokens": 13000, "processing_time": 97.2874686717987, "output_length": 4030.825, "realtime_factor": 41.43210893479068, "elapsed_time": 717.9070522785187},
    {"tokens": 14000, "processing_time": 105.1045708656311, "output_length": 4356.775, "realtime_factor": 41.451812838566596, "elapsed_time": 826.1140224933624},
    {"tokens": 15000, "processing_time": 111.0716404914856, "output_length": 4663.325, "realtime_factor": 41.984839508672565, "elapsed_time": 940.0645899772644},
    {"tokens": 16000, "processing_time": 116.61742973327637, "output_length": 4978.65, "realtime_factor": 42.692160266154104, "elapsed_time": 1061.1957621574402}
  ],
  "system_metrics": [
    {"timestamp": "2024-12-31T03:12:36.009478", "cpu_percent": 8.1, "ram_percent": 66.8, "ram_used_gb": 42.47850799560547, "gpu_memory_used": 2124.0},
    {"timestamp": "2024-12-31T03:12:44.639678", "cpu_percent": 7.7, "ram_percent": 69.1, "ram_used_gb": 43.984352111816406, "gpu_memory_used": 3486.0},
    {"timestamp": "2024-12-31T03:12:44.731107", "cpu_percent": 8.3, "ram_percent": 69.1, "ram_used_gb": 43.97468948364258, "gpu_memory_used": 3484.0},
    {"timestamp": "2024-12-31T03:12:46.189723", "cpu_percent": 14.2, "ram_percent": 69.1, "ram_used_gb": 43.98275375366211, "gpu_memory_used": 3697.0},
    {"timestamp": "2024-12-31T03:12:46.265437", "cpu_percent": 4.7, "ram_percent": 69.1, "ram_used_gb": 43.982975006103516, "gpu_memory_used": 3697.0},
    {"timestamp": "2024-12-31T03:12:48.536216", "cpu_percent": 12.5, "ram_percent": 69.0, "ram_used_gb": 43.86142349243164, "gpu_memory_used": 3697.0},
    {"timestamp": "2024-12-31T03:12:48.603827", "cpu_percent": 6.2, "ram_percent": 69.0, "ram_used_gb": 43.8692626953125, "gpu_memory_used": 3694.0},
    {"timestamp": "2024-12-31T03:12:51.905764", "cpu_percent": 14.2, "ram_percent": 69.1, "ram_used_gb": 43.93961715698242, "gpu_memory_used": 3690.0},
    {"timestamp": "2024-12-31T03:12:52.028178", "cpu_percent": 26.0, "ram_percent": 69.1, "ram_used_gb": 43.944759368896484, "gpu_memory_used": 3690.0},
    {"timestamp": "2024-12-31T03:12:55.320709", "cpu_percent": 13.2, "ram_percent": 69.1, "ram_used_gb": 43.943058013916016, "gpu_memory_used": 3685.0},
    {"timestamp": "2024-12-31T03:12:55.386582", "cpu_percent": 3.2, "ram_percent": 69.1, "ram_used_gb": 43.9305419921875, "gpu_memory_used": 3685.0},
    {"timestamp": "2024-12-31T03:12:59.492304", "cpu_percent": 15.6, "ram_percent": 69.1, "ram_used_gb": 43.964195251464844, "gpu_memory_used": 4053.0},
    {"timestamp": "2024-12-31T03:12:59.586143", "cpu_percent": 2.1, "ram_percent": 69.1, "ram_used_gb": 43.9642448425293, "gpu_memory_used": 4053.0},
    {"timestamp": "2024-12-31T03:13:04.705286", "cpu_percent": 12.0, "ram_percent": 69.2, "ram_used_gb": 43.992374420166016, "gpu_memory_used": 4059.0},
    {"timestamp": "2024-12-31T03:13:04.779475", "cpu_percent": 4.7, "ram_percent": 69.2, "ram_used_gb": 43.9922981262207, "gpu_memory_used": 4059.0},
    {"timestamp": "2024-12-31T03:13:10.063292", "cpu_percent": 12.4, "ram_percent": 69.2, "ram_used_gb": 44.004146575927734, "gpu_memory_used": 4041.0},
    {"timestamp": "2024-12-31T03:13:10.155395", "cpu_percent": 6.8, "ram_percent": 69.2, "ram_used_gb": 44.004215240478516, "gpu_memory_used": 4041.0},
    {"timestamp": "2024-12-31T03:13:16.097887", "cpu_percent": 13.1, "ram_percent": 69.2, "ram_used_gb": 44.0260009765625, "gpu_memory_used": 4042.0},
    {"timestamp": "2024-12-31T03:13:16.171478", "cpu_percent": 4.5, "ram_percent": 69.2, "ram_used_gb": 44.02027130126953, "gpu_memory_used": 4042.0},
    {"timestamp": "2024-12-31T03:13:23.044945", "cpu_percent": 12.6, "ram_percent": 69.2, "ram_used_gb": 44.03746795654297, "gpu_memory_used": 4044.0},
    {"timestamp": "2024-12-31T03:13:23.127442", "cpu_percent": 8.3, "ram_percent": 69.2, "ram_used_gb": 44.0373420715332, "gpu_memory_used": 4044.0},
    {"timestamp": "2024-12-31T03:13:36.780309", "cpu_percent": 12.5, "ram_percent": 69.2, "ram_used_gb": 44.00790786743164, "gpu_memory_used": 4034.0},
    {"timestamp": "2024-12-31T03:13:36.853474", "cpu_percent": 6.2, "ram_percent": 69.2, "ram_used_gb": 44.00779724121094, "gpu_memory_used": 4034.0},
    {"timestamp": "2024-12-31T03:13:57.449274", "cpu_percent": 12.4, "ram_percent": 69.2, "ram_used_gb": 44.0432243347168, "gpu_memory_used": 4034.0},
    {"timestamp": "2024-12-31T03:13:57.524592", "cpu_percent": 6.2, "ram_percent": 69.2, "ram_used_gb": 44.03204345703125, "gpu_memory_used": 4034.0},
    {"timestamp": "2024-12-31T03:14:24.698822", "cpu_percent": 13.4, "ram_percent": 69.5, "ram_used_gb": 44.18327331542969, "gpu_memory_used": 4480.0},
    {"timestamp": "2024-12-31T03:14:24.783683", "cpu_percent": 4.2, "ram_percent": 69.5, "ram_used_gb": 44.182212829589844, "gpu_memory_used": 4480.0},
    {"timestamp": "2024-12-31T03:14:58.242642", "cpu_percent": 12.8, "ram_percent": 69.5, "ram_used_gb": 44.20225524902344, "gpu_memory_used": 4476.0},
    {"timestamp": "2024-12-31T03:14:58.310907", "cpu_percent": 2.9, "ram_percent": 69.5, "ram_used_gb": 44.19659423828125, "gpu_memory_used": 4476.0},
    {"timestamp": "2024-12-31T03:15:42.196813", "cpu_percent": 14.3, "ram_percent": 69.9, "ram_used_gb": 44.43781661987305, "gpu_memory_used": 4494.0},
    {"timestamp": "2024-12-31T03:15:42.288427", "cpu_percent": 13.7, "ram_percent": 69.9, "ram_used_gb": 44.439701080322266, "gpu_memory_used": 4494.0},
    {"timestamp": "2024-12-31T03:16:35.483849", "cpu_percent": 14.7, "ram_percent": 65.0, "ram_used_gb": 41.35385513305664, "gpu_memory_used": 4506.0},
    {"timestamp": "2024-12-31T03:16:35.626628", "cpu_percent": 32.9, "ram_percent": 65.0, "ram_used_gb": 41.34442138671875, "gpu_memory_used": 4506.0},
    {"timestamp": "2024-12-31T03:17:29.378353", "cpu_percent": 13.4, "ram_percent": 64.3, "ram_used_gb": 40.8721809387207, "gpu_memory_used": 4485.0},
    {"timestamp": "2024-12-31T03:17:29.457464", "cpu_percent": 5.1, "ram_percent": 64.3, "ram_used_gb": 40.875389099121094, "gpu_memory_used": 4485.0},
    {"timestamp": "2024-12-31T03:18:31.955862", "cpu_percent": 14.3, "ram_percent": 65.0, "ram_used_gb": 41.360206604003906, "gpu_memory_used": 4484.0},
    {"timestamp": "2024-12-31T03:18:32.038999", "cpu_percent": 12.5, "ram_percent": 65.0, "ram_used_gb": 41.37223434448242, "gpu_memory_used": 4484.0},
    {"timestamp": "2024-12-31T03:19:46.454105", "cpu_percent": 13.9, "ram_percent": 65.3, "ram_used_gb": 41.562198638916016, "gpu_memory_used": 4487.0},
    {"timestamp": "2024-12-31T03:19:46.524303", "cpu_percent": 6.8, "ram_percent": 65.3, "ram_used_gb": 41.56681442260742, "gpu_memory_used": 4487.0},
    {"timestamp": "2024-12-31T03:21:25.251452", "cpu_percent": 23.7, "ram_percent": 62.0, "ram_used_gb": 39.456459045410156, "gpu_memory_used": 4488.0},
    {"timestamp": "2024-12-31T03:21:25.348643", "cpu_percent": 2.9, "ram_percent": 62.0, "ram_used_gb": 39.454288482666016, "gpu_memory_used": 4487.0},
    {"timestamp": "2024-12-31T03:22:53.939896", "cpu_percent": 12.9, "ram_percent": 62.1, "ram_used_gb": 39.50320053100586, "gpu_memory_used": 4488.0},
    {"timestamp": "2024-12-31T03:22:54.041607", "cpu_percent": 8.3, "ram_percent": 62.1, "ram_used_gb": 39.49895095825195, "gpu_memory_used": 4488.0},
    {"timestamp": "2024-12-31T03:24:33.835432", "cpu_percent": 12.9, "ram_percent": 62.3, "ram_used_gb": 39.647212982177734, "gpu_memory_used": 4503.0},
    {"timestamp": "2024-12-31T03:24:33.923914", "cpu_percent": 7.6, "ram_percent": 62.3, "ram_used_gb": 39.64302062988281, "gpu_memory_used": 4503.0},
    {"timestamp": "2024-12-31T03:26:22.021598", "cpu_percent": 12.9, "ram_percent": 58.4, "ram_used_gb": 37.162540435791016, "gpu_memory_used": 4491.0},
    {"timestamp": "2024-12-31T03:26:22.142138", "cpu_percent": 12.0, "ram_percent": 58.4, "ram_used_gb": 37.162010192871094, "gpu_memory_used": 4487.0},
    {"timestamp": "2024-12-31T03:28:15.970365", "cpu_percent": 15.0, "ram_percent": 58.2, "ram_used_gb": 37.04011535644531, "gpu_memory_used": 4481.0},
    {"timestamp": "2024-12-31T03:28:16.096459", "cpu_percent": 12.4, "ram_percent": 58.2, "ram_used_gb": 37.035972595214844, "gpu_memory_used": 4473.0},
    {"timestamp": "2024-12-31T03:30:17.092257", "cpu_percent": 12.4, "ram_percent": 58.4, "ram_used_gb": 37.14639663696289, "gpu_memory_used": 4459.0}
  ]
}
@ -1,19 +0,0 @@
=== Benchmark Statistics ===

Overall Stats:
Total tokens processed: 140500
Total audio generated: 43469.18s
Total test duration: 1061.20s
Average processing rate: 137.67 tokens/second
Average realtime factor: 42.93x

Per-chunk Stats:
Average chunk size: 5620.00 tokens
Min chunk size: 100.00 tokens
Max chunk size: 16000.00 tokens
Average processing time: 41.13s
Average output length: 1738.77s

Performance Ranges:
Processing rate range: 11.70 - 155.99 tokens/second
Realtime factor range: 3.65x - 49.46x
@ -1,406 +0,0 @@
|
|||
import os
|
||||
import json
|
||||
import time
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
|
||||
import pandas as pd
|
||||
import psutil
|
||||
import seaborn as sns
|
||||
import requests
|
||||
import tiktoken
|
||||
import scipy.io.wavfile as wavfile
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
enc = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
|
||||
def setup_plot(fig, ax, title):
|
||||
"""Configure plot styling"""
|
||||
# Improve grid
|
||||
ax.grid(True, linestyle="--", alpha=0.3, color="#ffffff")
|
||||
|
||||
# Set title and labels with better fonts
|
||||
ax.set_title(title, pad=20, fontsize=16, fontweight="bold", color="#ffffff")
|
||||
ax.set_xlabel(ax.get_xlabel(), fontsize=14, fontweight="medium", color="#ffffff")
|
||||
ax.set_ylabel(ax.get_ylabel(), fontsize=14, fontweight="medium", color="#ffffff")
|
||||
|
||||
# Improve tick labels
|
||||
ax.tick_params(labelsize=12, colors="#ffffff")
|
||||
|
||||
# Style spines
|
||||
for spine in ax.spines.values():
|
||||
spine.set_color("#ffffff")
|
||||
spine.set_alpha(0.3)
|
||||
spine.set_linewidth(0.5)
|
||||
|
||||
# Set background colors
|
||||
ax.set_facecolor("#1a1a2e")
|
||||
fig.patch.set_facecolor("#1a1a2e")
|
||||
|
||||
return fig, ax
|
||||
|
||||
|
||||
def get_text_for_tokens(text: str, num_tokens: int) -> str:
|
||||
"""Get a slice of text that contains exactly num_tokens tokens"""
|
||||
tokens = enc.encode(text)
|
||||
if num_tokens > len(tokens):
|
||||
return text
|
||||
return enc.decode(tokens[:num_tokens])
|
||||
|
||||
|
||||
def get_audio_length(audio_data: bytes) -> float:
|
||||
"""Get audio length in seconds from bytes data"""
|
||||
# Save to a temporary file
|
||||
temp_path = "examples/benchmarks/output/temp.wav"
|
||||
os.makedirs(os.path.dirname(temp_path), exist_ok=True)
|
||||
with open(temp_path, "wb") as f:
|
||||
f.write(audio_data)
|
||||
|
||||
# Read the audio file
|
||||
try:
|
||||
rate, data = wavfile.read(temp_path)
|
||||
return len(data) / rate
|
||||
finally:
|
||||
# Clean up temp file
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
|
||||
def get_gpu_memory():
|
||||
"""Get GPU memory usage using nvidia-smi"""
|
||||
try:
|
||||
result = subprocess.check_output(
|
||||
["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader"]
|
||||
)
|
||||
return float(result.decode("utf-8").strip())
|
||||
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||
return None
|
||||
|
||||
|
||||
def get_system_metrics():
|
||||
"""Get current system metrics"""
|
||||
metrics = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"cpu_percent": psutil.cpu_percent(),
|
||||
"ram_percent": psutil.virtual_memory().percent,
|
||||
"ram_used_gb": psutil.virtual_memory().used / (1024**3),
|
||||
}
|
||||
|
||||
gpu_mem = get_gpu_memory()
|
||||
if gpu_mem is not None:
|
||||
metrics["gpu_memory_used"] = gpu_mem
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def make_tts_request(text: str, timeout: int = 120) -> tuple[float, float]:
|
||||
"""Make TTS request using OpenAI-compatible endpoint and return processing time and output length"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
# Make request to OpenAI-compatible endpoint
|
||||
response = requests.post(
|
||||
"http://localhost:8880/v1/audio/speech",
|
||||
json={
|
||||
"model": "kokoro",
|
||||
"input": text,
|
||||
"voice": "af",
|
||||
"response_format": "wav",
|
||||
},
|
||||
timeout=timeout,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
processing_time = time.time() - start_time
|
||||
audio_length = get_audio_length(response.content)
|
||||
|
||||
# Save the audio file
|
||||
token_count = len(enc.encode(text))
|
||||
output_file = f"examples/benchmarks/output/chunk_{token_count}_tokens.wav"
|
||||
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
||||
with open(output_file, "wb") as f:
|
||||
f.write(response.content)
|
||||
print(f"Saved audio to {output_file}")
|
||||
|
||||
return processing_time, audio_length
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
print(f"Error making request for text: {text[:50]}... Error: {str(e)}")
|
||||
return None, None
|
||||
except Exception as e:
|
||||
print(f"Error processing text: {text[:50]}... Error: {str(e)}")
|
||||
return None, None
|
||||
|
||||
|
||||
def plot_system_metrics(metrics_data):
|
||||
"""Create plots for system metrics over time"""
|
||||
df = pd.DataFrame(metrics_data)
|
||||
df["timestamp"] = pd.to_datetime(df["timestamp"])
|
||||
elapsed_time = (df["timestamp"] - df["timestamp"].iloc[0]).dt.total_seconds()
|
||||
|
||||
# Get baseline values (first measurement)
|
||||
baseline_cpu = df["cpu_percent"].iloc[0]
|
||||
baseline_ram = df["ram_used_gb"].iloc[0]
|
||||
baseline_gpu = (
|
||||
df["gpu_memory_used"].iloc[0] / 1024
|
||||
if "gpu_memory_used" in df.columns
|
||||
else None
|
||||
) # Convert MB to GB
|
||||
|
||||
# Convert GPU memory to GB
|
||||
if "gpu_memory_used" in df.columns:
|
||||
df["gpu_memory_gb"] = df["gpu_memory_used"] / 1024
|
||||
|
||||
# Set plotting style
|
||||
plt.style.use("dark_background")
|
||||
|
||||
# Create figure with 3 subplots (or 2 if no GPU)
|
||||
has_gpu = "gpu_memory_used" in df.columns
|
||||
num_plots = 3 if has_gpu else 2
|
||||
fig, axes = plt.subplots(num_plots, 1, figsize=(15, 5 * num_plots))
|
||||
fig.patch.set_facecolor("#1a1a2e")
|
||||
|
||||
# Apply rolling average for smoothing
|
||||
window = min(5, len(df) // 2) # Smaller window for smoother lines
|
||||
|
||||
    # Plot 1: CPU Usage
    smoothed_cpu = df["cpu_percent"].rolling(window=window, center=True).mean()
    sns.lineplot(
        x=elapsed_time, y=smoothed_cpu, ax=axes[0], color="#ff2a6d", linewidth=2
    )
    axes[0].axhline(
        y=baseline_cpu, color="#05d9e8", linestyle="--", alpha=0.5, label="Baseline"
    )
    axes[0].set_xlabel("Time (seconds)", fontsize=14)
    axes[0].set_ylabel("CPU Usage (%)", fontsize=14)
    axes[0].tick_params(labelsize=12)
    axes[0].set_title("CPU Usage Over Time", pad=20, fontsize=16, fontweight="bold")
    axes[0].set_ylim(0, max(df["cpu_percent"]) * 1.1)  # Add 10% padding
    axes[0].legend()

    # Plot 2: RAM Usage
    smoothed_ram = df["ram_used_gb"].rolling(window=window, center=True).mean()
    sns.lineplot(
        x=elapsed_time, y=smoothed_ram, ax=axes[1], color="#05d9e8", linewidth=2
    )
    axes[1].axhline(
        y=baseline_ram, color="#ff2a6d", linestyle="--", alpha=0.5, label="Baseline"
    )
    axes[1].set_xlabel("Time (seconds)", fontsize=14)
    axes[1].set_ylabel("RAM Usage (GB)", fontsize=14)
    axes[1].tick_params(labelsize=12)
    axes[1].set_title("RAM Usage Over Time", pad=20, fontsize=16, fontweight="bold")
    axes[1].set_ylim(0, max(df["ram_used_gb"]) * 1.1)  # Add 10% padding
    axes[1].legend()

    # Plot 3: GPU Memory (if available)
    if has_gpu:
        smoothed_gpu = df["gpu_memory_gb"].rolling(window=window, center=True).mean()
        sns.lineplot(
            x=elapsed_time, y=smoothed_gpu, ax=axes[2], color="#ff2a6d", linewidth=2
        )
        axes[2].axhline(
            y=baseline_gpu, color="#05d9e8", linestyle="--", alpha=0.5, label="Baseline"
        )
        axes[2].set_xlabel("Time (seconds)", fontsize=14)
        axes[2].set_ylabel("GPU Memory (GB)", fontsize=14)
        axes[2].tick_params(labelsize=12)
        axes[2].set_title(
            "GPU Memory Usage Over Time", pad=20, fontsize=16, fontweight="bold"
        )
        axes[2].set_ylim(0, max(df["gpu_memory_gb"]) * 1.1)  # Add 10% padding
        axes[2].legend()

    # Style all subplots
    for ax in axes:
        ax.grid(True, linestyle="--", alpha=0.3)
        ax.set_facecolor("#1a1a2e")
        for spine in ax.spines.values():
            spine.set_color("#ffffff")
            spine.set_alpha(0.3)

    plt.tight_layout()
    plt.savefig("examples/benchmarks/system_usage.png", dpi=300, bbox_inches="tight")
    plt.close()


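# main() drives the benchmark end to end: slice the source text into
# progressively larger token chunks, synthesize each chunk through the API,
# then write summary statistics and plots.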
def main():
    # Create output directory
    os.makedirs("examples/benchmarks/output", exist_ok=True)

    # Read input text
    with open(
        "examples/benchmarks/the_time_machine_hg_wells.txt", "r", encoding="utf-8"
    ) as f:
        text = f.read()

    # Get total tokens in file
    total_tokens = len(enc.encode(text))
    print(f"Total tokens in file: {total_tokens}")

    # Generate token sizes with dense sampling at start and increasing intervals
    dense_range = list(range(100, 1001, 100))
    current = max(dense_range)
    large_range = []
    while current <= total_tokens:
        large_range.append(current)
        current += 1000

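    # For example, a ~5,000-token input would be tested at
    # [100, 200, ..., 900, 1000, 2000, 3000, 4000, 5000].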
    token_sizes = sorted(set(dense_range + large_range))
    print(f"Testing sizes: {token_sizes}")

    # Process chunks
    results = []
    system_metrics = []
    test_start_time = time.time()

    for num_tokens in token_sizes:
        # Get text slice with exact token count
        chunk = get_text_for_tokens(text, num_tokens)
        actual_tokens = len(enc.encode(chunk))

        print(f"\nProcessing chunk with {actual_tokens} tokens:")
        print(f"Text preview: {chunk[:100]}...")

        # Collect system metrics before processing
        system_metrics.append(get_system_metrics())

        processing_time, audio_length = make_tts_request(chunk)
        if processing_time is None or audio_length is None:
            print("Breaking loop due to error")
            break

        # Collect system metrics after processing
        system_metrics.append(get_system_metrics())

        results.append(
            {
                "tokens": actual_tokens,
                "processing_time": processing_time,
                "output_length": audio_length,
                # Realtime factor: seconds of audio produced per second of
                # processing; values above 1.0 are faster than real time
                "realtime_factor": audio_length / processing_time,
                "elapsed_time": time.time() - test_start_time,
            }
        )

        # Save intermediate results after every chunk so partial data
        # survives an interrupted run
        with open("examples/benchmarks/benchmark_results.json", "w") as f:
            json.dump(
                {"results": results, "system_metrics": system_metrics}, f, indent=2
            )

    # Create DataFrame and calculate stats
    df = pd.DataFrame(results)
    if df.empty:
        print("No data to plot")
        return

    # Calculate useful metrics
    df["tokens_per_second"] = df["tokens"] / df["processing_time"]
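    # tokens_per_second measures input throughput; realtime_factor (stored per
    # chunk in the loop above) is seconds of audio produced per second of
    # wall-clock processing.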

    # Write detailed stats
    with open("examples/benchmarks/benchmark_stats.txt", "w") as f:
        f.write("=== Benchmark Statistics ===\n\n")

        f.write("Overall Stats:\n")
        f.write(f"Total tokens processed: {df['tokens'].sum()}\n")
        f.write(f"Total audio generated: {df['output_length'].sum():.2f}s\n")
        f.write(f"Total test duration: {df['elapsed_time'].max():.2f}s\n")
        f.write(
            f"Average processing rate: {df['tokens_per_second'].mean():.2f} tokens/second\n"
        )
        f.write(f"Average realtime factor: {df['realtime_factor'].mean():.2f}x\n\n")

        f.write("Per-chunk Stats:\n")
        f.write(f"Average chunk size: {df['tokens'].mean():.2f} tokens\n")
        f.write(f"Min chunk size: {df['tokens'].min():.0f} tokens\n")
        f.write(f"Max chunk size: {df['tokens'].max():.0f} tokens\n")
        f.write(f"Average processing time: {df['processing_time'].mean():.2f}s\n")
        f.write(f"Average output length: {df['output_length'].mean():.2f}s\n\n")

        f.write("Performance Ranges:\n")
        f.write(
            f"Processing rate range: {df['tokens_per_second'].min():.2f} - {df['tokens_per_second'].max():.2f} tokens/second\n"
        )
        f.write(
            f"Realtime factor range: {df['realtime_factor'].min():.2f}x - {df['realtime_factor'].max():.2f}x\n"
        )

    # Set plotting style
    plt.style.use("dark_background")

    # Plot 1: Processing Time vs Token Count
    fig, ax = plt.subplots(figsize=(12, 8))
    sns.scatterplot(
        data=df, x="tokens", y="processing_time", s=100, alpha=0.6, color="#ff2a6d"
    )
    sns.regplot(
        data=df,
        x="tokens",
        y="processing_time",
        scatter=False,
        color="#05d9e8",
        line_kws={"linewidth": 2},
    )
    corr = df["tokens"].corr(df["processing_time"])
    plt.text(
        0.05,
        0.95,
        f"Correlation: {corr:.2f}",
        transform=ax.transAxes,
        fontsize=10,
        color="#ffffff",
        bbox=dict(facecolor="#1a1a2e", edgecolor="#ffffff", alpha=0.7),
    )
    setup_plot(fig, ax, "Processing Time vs Input Size")
    ax.set_xlabel("Number of Input Tokens")
    ax.set_ylabel("Processing Time (seconds)")
    plt.savefig("examples/benchmarks/processing_time.png", dpi=300, bbox_inches="tight")
    plt.close()

    # Plot 2: Realtime Factor vs Token Count
    fig, ax = plt.subplots(figsize=(12, 8))
    sns.scatterplot(
        data=df, x="tokens", y="realtime_factor", s=100, alpha=0.6, color="#ff2a6d"
    )
    sns.regplot(
        data=df,
        x="tokens",
        y="realtime_factor",
        scatter=False,
        color="#05d9e8",
        line_kws={"linewidth": 2},
    )
    corr = df["tokens"].corr(df["realtime_factor"])
    plt.text(
        0.05,
        0.95,
        f"Correlation: {corr:.2f}",
        transform=ax.transAxes,
        fontsize=10,
        color="#ffffff",
        bbox=dict(facecolor="#1a1a2e", edgecolor="#ffffff", alpha=0.7),
    )
    setup_plot(fig, ax, "Realtime Factor vs Input Size")
    ax.set_xlabel("Number of Input Tokens")
    ax.set_ylabel("Realtime Factor (output length / processing time)")
    plt.savefig("examples/benchmarks/realtime_factor.png", dpi=300, bbox_inches="tight")
    plt.close()

    # Plot system metrics
    plot_system_metrics(system_metrics)

    print("\nResults saved to:")
    print("- examples/benchmarks/benchmark_results.json")
    print("- examples/benchmarks/benchmark_stats.txt")
    print("- examples/benchmarks/processing_time.png")
    print("- examples/benchmarks/realtime_factor.png")
    print("- examples/benchmarks/system_usage.png")
    if any("gpu_memory_used" in m for m in system_metrics):
        print("  (system_usage.png includes a GPU memory subplot)")
    print("\nAudio files saved in examples/benchmarks/output/")


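# Assumes the Kokoro FastAPI server is already up and serving locally
# (e.g. started via docker compose) before the benchmark is run.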
if __name__ == "__main__":
    main()