mirror of https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
314 lines · 12 KiB · Python
import io
import json
import os
import subprocess
import time
from datetime import datetime

import matplotlib.pyplot as plt
import pandas as pd
import psutil
import requests
import scipy.io.wavfile as wavfile
import seaborn as sns
import tiktoken
|
|
|
|
enc = tiktoken.get_encoding("cl100k_base")
|
|
|
|
|
|
def setup_plot(fig, ax, title):
|
|
"""Configure plot styling"""
|
|
ax.grid(True, linestyle="--", alpha=0.3, color="#ffffff")
|
|
ax.set_title(title, pad=20, fontsize=16, fontweight="bold", color="#ffffff")
|
|
ax.set_xlabel(ax.get_xlabel(), fontsize=14, fontweight="medium", color="#ffffff")
|
|
ax.set_ylabel(ax.get_ylabel(), fontsize=14, fontweight="medium", color="#ffffff")
|
|
ax.tick_params(labelsize=12, colors="#ffffff")
|
|
|
|
for spine in ax.spines.values():
|
|
spine.set_color("#ffffff")
|
|
spine.set_alpha(0.3)
|
|
spine.set_linewidth(0.5)
|
|
|
|
ax.set_facecolor("#1a1a2e")
|
|
fig.patch.set_facecolor("#1a1a2e")
|
|
return fig, ax
|
|
|
|
|
|
def get_text_for_tokens(text: str, num_tokens: int) -> str:
|
|
"""Get a slice of text that contains exactly num_tokens tokens"""
|
|
tokens = enc.encode(text)
|
|
if num_tokens > len(tokens):
|
|
return text
|
|
return enc.decode(tokens[:num_tokens])
|
|
|
|
|
|
def get_audio_length(audio_data: bytes) -> float:
|
|
"""Get audio length in seconds from bytes data"""
|
|
temp_path = "examples/benchmarks/output/temp.wav"
|
|
os.makedirs(os.path.dirname(temp_path), exist_ok=True)
|
|
with open(temp_path, "wb") as f:
|
|
f.write(audio_data)
|
|
|
|
try:
|
|
rate, data = wavfile.read(temp_path)
|
|
return len(data) / rate
|
|
finally:
|
|
if os.path.exists(temp_path):
|
|
os.remove(temp_path)
|
|
|
|
|
|
def get_gpu_memory():
|
|
"""Get GPU memory usage using nvidia-smi"""
|
|
try:
|
|
result = subprocess.check_output(
|
|
["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader"]
|
|
)
|
|
return float(result.decode("utf-8").strip())
|
|
except (subprocess.CalledProcessError, FileNotFoundError):
|
|
return None
|
|
|
|
|
|
def get_system_metrics():
|
|
"""Get current system metrics"""
|
|
# Get per-CPU percentages and calculate average
|
|
cpu_percentages = psutil.cpu_percent(percpu=True)
|
|
avg_cpu = sum(cpu_percentages) / len(cpu_percentages)
|
|
|
|
metrics = {
|
|
"timestamp": datetime.now().isoformat(),
|
|
"cpu_percent": round(avg_cpu, 2),
|
|
"ram_percent": psutil.virtual_memory().percent,
|
|
"ram_used_gb": psutil.virtual_memory().used / (1024**3),
|
|
}
|
|
|
|
gpu_mem = get_gpu_memory()
|
|
if gpu_mem is not None:
|
|
metrics["gpu_memory_used"] = gpu_mem
|
|
|
|
return metrics
|
|
|
|
|
|
def real_time_factor(processing_time: float, audio_length: float, decimals: int = 2) -> float:
|
|
"""Calculate Real-Time Factor (RTF) as processing-time / length-of-audio"""
|
|
rtf = processing_time / audio_length
|
|
return round(rtf, decimals)
|
|
|
|
|
|
def make_tts_request(text: str, timeout: int = 1800) -> tuple[float, float]:
|
|
"""Make TTS request using OpenAI-compatible endpoint and return processing time and output length"""
|
|
try:
|
|
start_time = time.time()
|
|
response = requests.post(
|
|
"http://localhost:8880/v1/audio/speech",
|
|
json={
|
|
"model": "kokoro",
|
|
"input": text,
|
|
"voice": "af",
|
|
"response_format": "wav",
|
|
},
|
|
timeout=timeout,
|
|
)
|
|
response.raise_for_status()
|
|
|
|
processing_time = round(time.time() - start_time, 2)
|
|
audio_length = round(get_audio_length(response.content), 2)
|
|
|
|
# Save the audio file
|
|
token_count = len(enc.encode(text))
|
|
output_file = f"examples/benchmarks/output/chunk_{token_count}_tokens.wav"
|
|
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
|
with open(output_file, "wb") as f:
|
|
f.write(response.content)
|
|
print(f"Saved audio to {output_file}")
|
|
|
|
return processing_time, audio_length
|
|
|
|
except requests.exceptions.RequestException as e:
|
|
print(f"Error making request for text: {text[:50]}... Error: {str(e)}")
|
|
return None, None
|
|
except Exception as e:
|
|
print(f"Error processing text: {text[:50]}... Error: {str(e)}")
|
|
return None, None
|
|
|
|
|
|
def plot_system_metrics(metrics_data):
|
|
"""Create plots for system metrics over time"""
|
|
df = pd.DataFrame(metrics_data)
|
|
df["timestamp"] = pd.to_datetime(df["timestamp"])
|
|
elapsed_time = (df["timestamp"] - df["timestamp"].iloc[0]).dt.total_seconds()
|
|
|
|
baseline_cpu = df["cpu_percent"].iloc[0]
|
|
baseline_ram = df["ram_used_gb"].iloc[0]
|
|
baseline_gpu = df["gpu_memory_used"].iloc[0] / 1024 if "gpu_memory_used" in df.columns else None
|
|
|
|
if "gpu_memory_used" in df.columns:
|
|
df["gpu_memory_gb"] = df["gpu_memory_used"] / 1024
|
|
|
|
plt.style.use("dark_background")
|
|
|
|
has_gpu = "gpu_memory_used" in df.columns
|
|
num_plots = 3 if has_gpu else 2
|
|
fig, axes = plt.subplots(num_plots, 1, figsize=(15, 5 * num_plots))
|
|
fig.patch.set_facecolor("#1a1a2e")
|
|
|
|
window = min(5, len(df) // 2)
|
|
|
|
# Plot CPU Usage
|
|
smoothed_cpu = df["cpu_percent"].rolling(window=window, center=True).mean()
|
|
sns.lineplot(x=elapsed_time, y=smoothed_cpu, ax=axes[0], color="#ff2a6d", linewidth=2)
|
|
axes[0].axhline(y=baseline_cpu, color="#05d9e8", linestyle="--", alpha=0.5, label="Baseline")
|
|
axes[0].set_xlabel("Time (seconds)")
|
|
axes[0].set_ylabel("CPU Usage (%)")
|
|
axes[0].set_title("CPU Usage Over Time")
|
|
axes[0].set_ylim(0, max(df["cpu_percent"]) * 1.1)
|
|
axes[0].legend()
|
|
|
|
# Plot RAM Usage
|
|
smoothed_ram = df["ram_used_gb"].rolling(window=window, center=True).mean()
|
|
sns.lineplot(x=elapsed_time, y=smoothed_ram, ax=axes[1], color="#05d9e8", linewidth=2)
|
|
axes[1].axhline(y=baseline_ram, color="#ff2a6d", linestyle="--", alpha=0.5, label="Baseline")
|
|
axes[1].set_xlabel("Time (seconds)")
|
|
axes[1].set_ylabel("RAM Usage (GB)")
|
|
axes[1].set_title("RAM Usage Over Time")
|
|
axes[1].set_ylim(0, max(df["ram_used_gb"]) * 1.1)
|
|
axes[1].legend()
|
|
|
|
# Plot GPU Memory if available
|
|
if has_gpu:
|
|
smoothed_gpu = df["gpu_memory_gb"].rolling(window=window, center=True).mean()
|
|
sns.lineplot(x=elapsed_time, y=smoothed_gpu, ax=axes[2], color="#ff2a6d", linewidth=2)
|
|
axes[2].axhline(y=baseline_gpu, color="#05d9e8", linestyle="--", alpha=0.5, label="Baseline")
|
|
axes[2].set_xlabel("Time (seconds)")
|
|
axes[2].set_ylabel("GPU Memory (GB)")
|
|
axes[2].set_title("GPU Memory Usage Over Time")
|
|
axes[2].set_ylim(0, max(df["gpu_memory_gb"]) * 1.1)
|
|
axes[2].legend()
|
|
|
|
for ax in axes:
|
|
ax.grid(True, linestyle="--", alpha=0.3)
|
|
ax.set_facecolor("#1a1a2e")
|
|
for spine in ax.spines.values():
|
|
spine.set_color("#ffffff")
|
|
spine.set_alpha(0.3)
|
|
|
|
plt.tight_layout()
|
|
plt.savefig("examples/benchmarks/system_usage_rtf.png", dpi=300, bbox_inches="tight")
|
|
plt.close()
|
|
|
|
|
|
def main():
|
|
os.makedirs("examples/benchmarks/output", exist_ok=True)
|
|
|
|
with open("examples/benchmarks/the_time_machine_hg_wells.txt", "r", encoding="utf-8") as f:
|
|
text = f.read()
|
|
|
|
total_tokens = len(enc.encode(text))
|
|
print(f"Total tokens in file: {total_tokens}")
|
|
|
|
# Generate token sizes with dense sampling at start
|
|
dense_range = list(range(100, 1001, 100))
|
|
token_sizes = sorted(list(set(dense_range)))
|
|
print(f"Testing sizes: {token_sizes}")
|
|
|
|
results = []
|
|
system_metrics = []
|
|
test_start_time = time.time()
|
|
|
|
for num_tokens in token_sizes:
|
|
chunk = get_text_for_tokens(text, num_tokens)
|
|
actual_tokens = len(enc.encode(chunk))
|
|
|
|
print(f"\nProcessing chunk with {actual_tokens} tokens:")
|
|
print(f"Text preview: {chunk[:100]}...")
|
|
|
|
system_metrics.append(get_system_metrics())
|
|
|
|
processing_time, audio_length = make_tts_request(chunk)
|
|
if processing_time is None or audio_length is None:
|
|
print("Breaking loop due to error")
|
|
break
|
|
|
|
system_metrics.append(get_system_metrics())
|
|
|
|
# Calculate RTF using the correct formula
|
|
rtf = real_time_factor(processing_time, audio_length)
|
|
|
|
results.append({
|
|
"tokens": actual_tokens,
|
|
"processing_time": processing_time,
|
|
"output_length": audio_length,
|
|
"rtf": rtf,
|
|
"elapsed_time": round(time.time() - test_start_time, 2),
|
|
})
|
|
|
|
with open("examples/benchmarks/benchmark_results_rtf.json", "w") as f:
|
|
json.dump({"results": results, "system_metrics": system_metrics}, f, indent=2)
|
|
|
|
df = pd.DataFrame(results)
|
|
if df.empty:
|
|
print("No data to plot")
|
|
return
|
|
|
|
df["tokens_per_second"] = df["tokens"] / df["processing_time"]
|
|
|
|
with open("examples/benchmarks/benchmark_stats_rtf.txt", "w") as f:
|
|
f.write("=== Benchmark Statistics (with correct RTF) ===\n\n")
|
|
|
|
f.write("Overall Stats:\n")
|
|
f.write(f"Total tokens processed: {df['tokens'].sum()}\n")
|
|
f.write(f"Total audio generated: {df['output_length'].sum():.2f}s\n")
|
|
f.write(f"Total test duration: {df['elapsed_time'].max():.2f}s\n")
|
|
f.write(f"Average processing rate: {df['tokens_per_second'].mean():.2f} tokens/second\n")
|
|
f.write(f"Average RTF: {df['rtf'].mean():.2f}x\n\n")
|
|
|
|
f.write("Per-chunk Stats:\n")
|
|
f.write(f"Average chunk size: {df['tokens'].mean():.2f} tokens\n")
|
|
f.write(f"Min chunk size: {df['tokens'].min():.2f} tokens\n")
|
|
f.write(f"Max chunk size: {df['tokens'].max():.2f} tokens\n")
|
|
f.write(f"Average processing time: {df['processing_time'].mean():.2f}s\n")
|
|
f.write(f"Average output length: {df['output_length'].mean():.2f}s\n\n")
|
|
|
|
f.write("Performance Ranges:\n")
|
|
f.write(f"Processing rate range: {df['tokens_per_second'].min():.2f} - {df['tokens_per_second'].max():.2f} tokens/second\n")
|
|
f.write(f"RTF range: {df['rtf'].min():.2f}x - {df['rtf'].max():.2f}x\n")
|
|
|
|
plt.style.use("dark_background")
|
|
|
|
# Plot Processing Time vs Token Count
|
|
fig, ax = plt.subplots(figsize=(12, 8))
|
|
sns.scatterplot(data=df, x="tokens", y="processing_time", s=100, alpha=0.6, color="#ff2a6d")
|
|
sns.regplot(data=df, x="tokens", y="processing_time", scatter=False, color="#05d9e8", line_kws={"linewidth": 2})
|
|
corr = df["tokens"].corr(df["processing_time"])
|
|
plt.text(0.05, 0.95, f"Correlation: {corr:.2f}", transform=ax.transAxes, fontsize=10, color="#ffffff",
|
|
bbox=dict(facecolor="#1a1a2e", edgecolor="#ffffff", alpha=0.7))
|
|
setup_plot(fig, ax, "Processing Time vs Input Size")
|
|
ax.set_xlabel("Number of Input Tokens")
|
|
ax.set_ylabel("Processing Time (seconds)")
|
|
plt.savefig("examples/benchmarks/processing_time_rtf.png", dpi=300, bbox_inches="tight")
|
|
plt.close()
|
|
|
|
# Plot RTF vs Token Count
|
|
fig, ax = plt.subplots(figsize=(12, 8))
|
|
sns.scatterplot(data=df, x="tokens", y="rtf", s=100, alpha=0.6, color="#ff2a6d")
|
|
sns.regplot(data=df, x="tokens", y="rtf", scatter=False, color="#05d9e8", line_kws={"linewidth": 2})
|
|
corr = df["tokens"].corr(df["rtf"])
|
|
plt.text(0.05, 0.95, f"Correlation: {corr:.2f}", transform=ax.transAxes, fontsize=10, color="#ffffff",
|
|
bbox=dict(facecolor="#1a1a2e", edgecolor="#ffffff", alpha=0.7))
|
|
setup_plot(fig, ax, "Real-Time Factor vs Input Size")
|
|
ax.set_xlabel("Number of Input Tokens")
|
|
ax.set_ylabel("Real-Time Factor (processing time / audio length)")
|
|
plt.savefig("examples/benchmarks/realtime_factor_rtf.png", dpi=300, bbox_inches="tight")
|
|
plt.close()
|
|
|
|
plot_system_metrics(system_metrics)
|
|
|
|
print("\nResults saved to:")
|
|
print("- examples/benchmarks/benchmark_results_rtf.json")
|
|
print("- examples/benchmarks/benchmark_stats_rtf.txt")
|
|
print("- examples/benchmarks/processing_time_rtf.png")
|
|
print("- examples/benchmarks/realtime_factor_rtf.png")
|
|
print("- examples/benchmarks/system_usage_rtf.png")
|
|
print("\nAudio files saved in examples/benchmarks/output/")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|