Enhance TTS API with logging, voice pack loading, and schema updates

remsky 2024-12-31 01:57:00 -07:00
parent 8ce8334345
commit c11a6ea6ea
12 changed files with 451 additions and 253 deletions

View file

@@ -1,26 +1,38 @@
"""
FastAPI OpenAI Compatible API
"""
import uvicorn
import logging
import sys
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from .core.config import settings
from .routers.openai_compatible import router as openai_router
# from services.tts import TTSModel
from .services.tts import TTSModel, TTSService
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Lifespan context manager for database initialization"""
"""Lifespan context manager for model initialization"""
logger.info("Loading TTS model and voice packs...")
# Initialize the main model
model, device = TTSModel.get_instance()
logger.info(f"Model loaded on {device}")
# Initialize all voice packs
tts_service = TTSService()
voices = tts_service.list_voices()
for voice in voices:
logger.info(f"Loading voice pack: {voice}")
TTSModel.get_voicepack(voice)
logger.info("All models and voice packs loaded successfully")
yield
# Initialize FastAPI app
app = FastAPI(
title=settings.api_title,
@@ -42,16 +54,19 @@ app.add_middleware(
# Include OpenAI compatible router
app.include_router(openai_router, prefix="/v1")
# Health check endpoint
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {"status": "healthy"}
@app.get("/v1/test")
async def test_endpoint():
"""Test endpoint to verify routing"""
return {"status": "ok"}
if __name__ == "__main__":
uvicorn.run("api.src.main:app", host=settings.host, port=settings.port, reload=True)
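The hunk above adds import sys and loguru's logger but does not show the sink configuration. A plausible setup, offered only as an assumption for illustration and not this repo's code, would be:

```python
# Hypothetical loguru configuration consistent with the new imports;
# the project's actual sink, format, and level may differ.
import sys
from loguru import logger

logger.remove()  # drop loguru's default stderr handler
logger.add(
    sys.stdout,
    format="{time:HH:mm:ss} | {level} | {message}",
    level="INFO",
    colorize=True,
)
```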

View file

@@ -1 +1 @@
#
#

View file

@@ -1,5 +1,4 @@
from fastapi import APIRouter, HTTPException, Response, Depends
from sqlalchemy.orm import Session
import logging
from ..structures.schemas import OpenAISpeechRequest
from ..services.tts import TTSService
@@ -12,14 +11,17 @@ router = APIRouter(
responses={404: {"description": "Not found"}},
)
def get_tts_service() -> TTSService:
"""Dependency to get TTSService instance with database session"""
return TTSService(start_worker=False) # Don't start worker thread for OpenAI endpoint
return TTSService(
start_worker=False
) # Don't start worker thread for OpenAI endpoint
@router.post("/audio/speech")
async def create_speech(
request: OpenAISpeechRequest,
tts_service: TTSService = Depends(get_tts_service)
request: OpenAISpeechRequest, tts_service: TTSService = Depends(get_tts_service)
):
"""OpenAI-compatible endpoint for text-to-speech"""
try:
@@ -28,28 +30,27 @@ async def create_speech(
text=request.input,
voice=request.voice,
speed=request.speed,
stitch_long_output=True
stitch_long_output=True,
)
# Convert to requested format
content = AudioService.convert_audio(audio, 24000, request.response_format)
return Response(
content=content,
media_type=f"audio/{request.response_format}",
headers={
"Content-Disposition": f"attachment; filename=speech.{request.response_format}"
}
},
)
except Exception as e:
logger.error(f"Error generating speech: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/audio/voices")
async def list_voices(
tts_service: TTSService = Depends(get_tts_service)
):
async def list_voices(tts_service: TTSService = Depends(get_tts_service)):
"""List all available voices for text-to-speech"""
try:
voices = tts_service.list_voices()
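For reference, a minimal client sketch for the endpoint above, assuming the server is reachable at http://localhost:8880 as in the benchmark script later in this commit:

```python
# POST to the OpenAI-compatible speech endpoint and save the audio.
import requests

response = requests.post(
    "http://localhost:8880/v1/audio/speech",
    json={
        "model": "tts-1",
        "input": "Hello world",
        "voice": "af",
        "response_format": "wav",
        "speed": 1.0,
    },
    timeout=60,
)
response.raise_for_status()
with open("speech.wav", "wb") as f:
    f.write(response.content)  # raw audio bytes from the Response body
```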

View file

@@ -1,4 +1,5 @@
"""Audio conversion service"""
from io import BytesIO
import numpy as np
import scipy.io.wavfile as wavfile
@@ -7,60 +8,69 @@ import logging
logger = logging.getLogger(__name__)
class AudioService:
"""Service for audio format conversions"""
@staticmethod
def convert_audio(audio_data: np.ndarray, sample_rate: int, output_format: str) -> bytes:
def convert_audio(
audio_data: np.ndarray, sample_rate: int, output_format: str
) -> bytes:
"""Convert audio data to specified format
Args:
audio_data: Numpy array of audio samples
sample_rate: Sample rate of the audio
output_format: Target format (wav, mp3, etc.)
Returns:
Bytes of the converted audio
"""
buffer = BytesIO()
try:
if output_format == 'wav':
if output_format == "wav":
logger.info("Writing to WAV format...")
wavfile.write(buffer, sample_rate, audio_data)
return buffer.getvalue()
elif output_format == 'mp3':
elif output_format == "mp3":
# For MP3, we need to convert to WAV first
logger.info("Converting to MP3 format...")
wav_buffer = BytesIO()
wavfile.write(wav_buffer, sample_rate, audio_data)
wav_buffer.seek(0)
# Convert WAV to MP3 using soundfile
buffer = BytesIO()
sf.write(buffer, audio_data, sample_rate, format='mp3')
sf.write(buffer, audio_data, sample_rate, format="mp3")
return buffer.getvalue()
elif output_format == 'opus':
elif output_format == "opus":
logger.info("Converting to Opus format...")
sf.write(buffer, audio_data, sample_rate, format='ogg', subtype='opus')
sf.write(buffer, audio_data, sample_rate, format="ogg", subtype="opus")
return buffer.getvalue()
elif output_format == 'flac':
elif output_format == "flac":
logger.info("Converting to FLAC format...")
sf.write(buffer, audio_data, sample_rate, format='flac')
sf.write(buffer, audio_data, sample_rate, format="flac")
return buffer.getvalue()
elif output_format == 'aac':
raise ValueError("AAC format is not currently supported. Please use wav, mp3, opus, or flac.")
elif output_format == 'pcm':
raise ValueError("PCM format is not currently supported. Please use wav, mp3, opus, or flac.")
elif output_format == "aac":
raise ValueError(
"AAC format is not currently supported. Please use wav, mp3, opus, or flac."
)
elif output_format == "pcm":
raise ValueError(
"PCM format is not currently supported. Please use wav, mp3, opus, or flac."
)
else:
raise ValueError(f"Format {output_format} not supported. Supported formats are: wav, mp3, opus, flac.")
raise ValueError(
f"Format {output_format} not supported. Supported formats are: wav, mp3, opus, flac."
)
except Exception as e:
logger.error(f"Error converting audio to {output_format}: {str(e)}")
raise ValueError(f"Failed to convert audio to {output_format}: {str(e)}")

View file

@@ -2,8 +2,7 @@ import os
import threading
import time
import io
from typing import Optional, List, Tuple
from sqlalchemy.orm import Session
from typing import List, Tuple
import numpy as np
import torch
import scipy.io.wavfile as wavfile
@@ -17,6 +16,7 @@ import tiktoken
logger = logging.getLogger(__name__)
enc = tiktoken.get_encoding("cl100k_base")
class TTSModel:
_instance = None
_lock = threading.Lock()
@@ -40,7 +40,9 @@ class TTSModel:
model, device = cls.get_instance()
if voice_name not in cls._voicepacks:
try:
voice_path = os.path.join(settings.model_dir, settings.voices_dir, f"{voice_name}.pt")
voice_path = os.path.join(
settings.model_dir, settings.voices_dir, f"{voice_name}.pt"
)
voicepack = torch.load(
voice_path, map_location=device, weights_only=True
)
@@ -61,9 +63,11 @@ class TTSService:
def _split_text(self, text: str) -> List[str]:
"""Split text into sentences"""
return [s.strip() for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()]
return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]
def _generate_audio(self, text: str, voice: str, speed: float, stitch_long_output: bool = True) -> Tuple[torch.Tensor, float]:
def _generate_audio(
self, text: str, voice: str, speed: float, stitch_long_output: bool = True
) -> Tuple[torch.Tensor, float]:
"""Generate audio and measure processing time"""
start_time = time.time()
@@ -87,22 +91,34 @@ class TTSService:
# Validate phonemization first
ps = phonemize(chunk, voice[0])
tokens = tokenize(ps)
logger.info(f"Processing chunk {i+1}/{len(chunks)}: {len(tokens)} tokens")
logger.info(
f"Processing chunk {i+1}/{len(chunks)}: {len(tokens)} tokens"
)
# Only proceed if phonemization succeeded
chunk_audio, _ = generate(model, chunk, voicepack, lang=voice[0], speed=speed)
chunk_audio, _ = generate(
model, chunk, voicepack, lang=voice[0], speed=speed
)
if chunk_audio is not None:
audio_chunks.append(chunk_audio)
else:
logger.error(f"No audio generated for chunk {i+1}/{len(chunks)}")
logger.error(
f"No audio generated for chunk {i+1}/{len(chunks)}"
)
except Exception as e:
logger.error(f"Failed to generate audio for chunk {i+1}/{len(chunks)}: '{chunk}'. Error: {str(e)}")
logger.error(
f"Failed to generate audio for chunk {i+1}/{len(chunks)}: '{chunk}'. Error: {str(e)}"
)
continue
if not audio_chunks:
raise ValueError("No audio chunks were generated successfully")
audio = np.concatenate(audio_chunks) if len(audio_chunks) > 1 else audio_chunks[0]
audio = (
np.concatenate(audio_chunks)
if len(audio_chunks) > 1
else audio_chunks[0]
)
else:
audio, _ = generate(model, text, voicepack, lang=voice[0], speed=speed)
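The _instance and _lock attributes on TTSModel above imply a double-checked-locking singleton; since the body of get_instance is not part of this hunk, here is a generic, self-contained sketch of the pattern, with placeholder names and a stand-in loader:

```python
import threading

class LazySingleton:
    """Sketch of the double-checked-locking pattern suggested above."""

    _instance = None
    _lock = threading.Lock()

    @classmethod
    def get_instance(cls):
        if cls._instance is None:          # fast path once loaded, no lock
            with cls._lock:                # serialize competing loaders
                if cls._instance is None:  # re-check after acquiring lock
                    cls._instance = cls._load()
        return cls._instance

    @staticmethod
    def _load():
        return object()  # stand-in for the expensive model construction
```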

View file

@@ -15,17 +15,24 @@ class TTSStatus(str, Enum):
class OpenAISpeechRequest(BaseModel):
model: Literal["tts-1", "tts-1-hd"] = "tts-1"
input: str = Field(..., description="The text to generate audio for")
voice: Literal["am_adam", "am_michael", "bm_lewis", "af", "bm_george", "bf_isabella", "bf_emma", "af_sarah", "af_bella"] = Field(
default="af",
description="The voice to use for generation"
)
voice: Literal[
"am_adam",
"am_michael",
"bm_lewis",
"af",
"bm_george",
"bf_isabella",
"bf_emma",
"af_sarah",
"af_bella",
] = Field(default="af", description="The voice to use for generation")
response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field(
default="mp3",
description="The format to return audio in. Supported formats: mp3, opus, flac, wav. AAC and PCM are not currently supported."
description="The format to return audio in. Supported formats: mp3, opus, flac, wav. AAC and PCM are not currently supported.",
)
speed: float = Field(
default=1.0,
ge=0.25,
le=4.0,
description="The speed of the generated audio. Select a value from 0.25 to 4.0."
description="The speed of the generated audio. Select a value from 0.25 to 4.0.",
)
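The Field constraints above reject bad requests at parse time, before any audio work happens; a small sketch of that behaviour, assuming the schema is importable as shown:

```python
# Defaults apply when optional fields are omitted; violations raise.
from pydantic import ValidationError

req = OpenAISpeechRequest(input="Hello world")
assert req.voice == "af" and req.response_format == "mp3" and req.speed == 1.0

try:
    OpenAISpeechRequest(input="Hello world", speed=5.0)  # above le=4.0
except ValidationError as err:
    print(err.errors()[0]["loc"])  # ('speed',)
```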

View file

@@ -3,17 +3,15 @@ from unittest.mock import Mock, patch
import sys
# Mock torch and other ML modules before they're imported
sys.modules['torch'] = Mock()
sys.modules['transformers'] = Mock()
sys.modules['phonemizer'] = Mock()
sys.modules['models'] = Mock()
sys.modules['models.build_model'] = Mock()
sys.modules['kokoro'] = Mock()
sys.modules['kokoro.generate'] = Mock()
sys.modules['kokoro.phonemize'] = Mock()
sys.modules['kokoro.tokenize'] = Mock()
from api.src.main import app
sys.modules["torch"] = Mock()
sys.modules["transformers"] = Mock()
sys.modules["phonemizer"] = Mock()
sys.modules["models"] = Mock()
sys.modules["models.build_model"] = Mock()
sys.modules["kokoro"] = Mock()
sys.modules["kokoro.generate"] = Mock()
sys.modules["kokoro.phonemize"] = Mock()
sys.modules["kokoro.tokenize"] = Mock()
@pytest.fixture(autouse=True)

View file

@@ -1,27 +1,44 @@
from fastapi.testclient import TestClient
import pytest
from unittest.mock import Mock, patch
from unittest.mock import Mock
from ..src.main import app
from ..src.services.tts import TTSService
from ..src.routers.openai_compatible import TTSService as OpenAITTSService
# Create test client
client = TestClient(app)
# Mock services
@pytest.fixture
def mock_tts_service(monkeypatch):
mock_service = Mock()
mock_service._generate_audio.return_value = (bytes([0, 1, 2, 3]), 1.0)
mock_service.list_voices.return_value = ["af", "bm_lewis", "bf_isabella", "bf_emma", "af_sarah", "af_bella", "am_adam", "am_michael", "bm_george"]
monkeypatch.setattr("api.src.routers.openai_compatible.TTSService", lambda *args, **kwargs: mock_service)
mock_service.list_voices.return_value = [
"af",
"bm_lewis",
"bf_isabella",
"bf_emma",
"af_sarah",
"af_bella",
"am_adam",
"am_michael",
"bm_george",
]
monkeypatch.setattr(
"api.src.routers.openai_compatible.TTSService",
lambda *args, **kwargs: mock_service,
)
return mock_service
@pytest.fixture
def mock_audio_service(monkeypatch):
def mock_convert(*args):
return b"converted mock audio data"
monkeypatch.setattr("api.src.routers.openai_compatible.AudioService.convert_audio", mock_convert)
monkeypatch.setattr(
"api.src.routers.openai_compatible.AudioService.convert_audio", mock_convert
)
def test_health_check():
"""Test the health check endpoint"""
@@ -29,6 +46,7 @@ def test_health_check():
assert response.status_code == 200
assert response.json() == {"status": "healthy"}
def test_openai_speech_endpoint(mock_tts_service, mock_audio_service):
"""Test the OpenAI-compatible speech endpoint"""
test_request = {
@@ -36,20 +54,18 @@ def test_openai_speech_endpoint(mock_tts_service, mock_audio_service):
"input": "Hello world",
"voice": "bm_lewis",
"response_format": "wav",
"speed": 1.0
"speed": 1.0,
}
response = client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 200
assert response.headers["content-type"] == "audio/wav"
assert response.headers["content-disposition"] == "attachment; filename=speech.wav"
mock_tts_service._generate_audio.assert_called_once_with(
text="Hello world",
voice="bm_lewis",
speed=1.0,
stitch_long_output=True
text="Hello world", voice="bm_lewis", speed=1.0, stitch_long_output=True
)
assert response.content == b"converted mock audio data"
def test_openai_speech_invalid_voice(mock_tts_service):
"""Test the OpenAI-compatible speech endpoint with invalid voice"""
test_request = {
@@ -57,11 +73,12 @@ def test_openai_speech_invalid_voice(mock_tts_service):
"input": "Hello world",
"voice": "invalid_voice",
"response_format": "wav",
"speed": 1.0
"speed": 1.0,
}
response = client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 422 # Validation error
def test_openai_speech_invalid_speed(mock_tts_service):
"""Test the OpenAI-compatible speech endpoint with invalid speed"""
test_request = {
@@ -69,11 +86,12 @@ def test_openai_speech_invalid_speed(mock_tts_service):
"input": "Hello world",
"voice": "af",
"response_format": "wav",
"speed": -1.0 # Invalid speed
"speed": -1.0, # Invalid speed
}
response = client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 422 # Validation error
def test_openai_speech_generation_error(mock_tts_service):
"""Test error handling in speech generation"""
mock_tts_service._generate_audio.side_effect = Exception("Generation failed")
@@ -82,7 +100,7 @@ def test_openai_speech_generation_error(mock_tts_service):
"input": "Hello world",
"voice": "af",
"response_format": "wav",
"speed": 1.0
"speed": 1.0,
}
response = client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 500
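An equivalent way to inject the mock is FastAPI's built-in dependency_overrides rather than monkeypatching the module attribute; a sketch, with import paths following the relative style used in this file:

```python
# Override the get_tts_service dependency directly on the app.
from ..src.routers.openai_compatible import get_tts_service

def test_speech_with_dependency_override(mock_tts_service, mock_audio_service):
    app.dependency_overrides[get_tts_service] = lambda: mock_tts_service
    try:
        response = client.post(
            "/v1/audio/speech",
            json={"model": "tts-1", "input": "Hi", "voice": "af",
                  "response_format": "wav", "speed": 1.0},
        )
        assert response.status_code == 200
    finally:
        app.dependency_overrides.clear()
```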

View file

@@ -3,43 +3,43 @@ import time
import json
import scipy.io.wavfile as wavfile
import requests
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.signal import savgol_filter
import tiktoken
import tiktoken
import psutil
import subprocess
from datetime import datetime
enc = tiktoken.get_encoding("cl100k_base")
def setup_plot(fig, ax, title):
"""Configure plot styling"""
# Improve grid
ax.grid(True, linestyle='--', alpha=0.3, color='#ffffff')
ax.grid(True, linestyle="--", alpha=0.3, color="#ffffff")
# Set title and labels with better fonts
ax.set_title(title, pad=20, fontsize=16, fontweight='bold', color='#ffffff')
ax.set_xlabel(ax.get_xlabel(), fontsize=14, fontweight='medium', color='#ffffff')
ax.set_ylabel(ax.get_ylabel(), fontsize=14, fontweight='medium', color='#ffffff')
ax.set_title(title, pad=20, fontsize=16, fontweight="bold", color="#ffffff")
ax.set_xlabel(ax.get_xlabel(), fontsize=14, fontweight="medium", color="#ffffff")
ax.set_ylabel(ax.get_ylabel(), fontsize=14, fontweight="medium", color="#ffffff")
# Improve tick labels
ax.tick_params(labelsize=12, colors='#ffffff')
ax.tick_params(labelsize=12, colors="#ffffff")
# Style spines
for spine in ax.spines.values():
spine.set_color('#ffffff')
spine.set_color("#ffffff")
spine.set_alpha(0.3)
spine.set_linewidth(0.5)
# Set background colors
ax.set_facecolor('#1a1a2e')
fig.patch.set_facecolor('#1a1a2e')
ax.set_facecolor("#1a1a2e")
fig.patch.set_facecolor("#1a1a2e")
return fig, ax
def get_text_for_tokens(text: str, num_tokens: int) -> str:
"""Get a slice of text that contains exactly num_tokens tokens"""
tokens = enc.encode(text)
@@ -47,14 +47,15 @@ def get_text_for_tokens(text: str, num_tokens: int) -> str:
return text
return enc.decode(tokens[:num_tokens])
def get_audio_length(audio_data: bytes) -> float:
"""Get audio length in seconds from bytes data"""
# Save to a temporary file
temp_path = 'examples/benchmarks/output/temp.wav'
temp_path = "examples/benchmarks/output/temp.wav"
os.makedirs(os.path.dirname(temp_path), exist_ok=True)
with open(temp_path, 'wb') as f:
with open(temp_path, "wb") as f:
f.write(audio_data)
# Read the audio file
try:
rate, data = wavfile.read(temp_path)
@@ -64,60 +65,65 @@ def get_audio_length(audio_data: bytes) -> float:
if os.path.exists(temp_path):
os.remove(temp_path)
def get_gpu_memory():
"""Get GPU memory usage using nvidia-smi"""
try:
result = subprocess.check_output(['nvidia-smi', '--query-gpu=memory.used', '--format=csv,nounits,noheader'])
return float(result.decode('utf-8').strip())
result = subprocess.check_output(
["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader"]
)
return float(result.decode("utf-8").strip())
except (subprocess.CalledProcessError, FileNotFoundError):
return None
def get_system_metrics():
"""Get current system metrics"""
metrics = {
'timestamp': datetime.now().isoformat(),
'cpu_percent': psutil.cpu_percent(),
'ram_percent': psutil.virtual_memory().percent,
'ram_used_gb': psutil.virtual_memory().used / (1024**3),
"timestamp": datetime.now().isoformat(),
"cpu_percent": psutil.cpu_percent(),
"ram_percent": psutil.virtual_memory().percent,
"ram_used_gb": psutil.virtual_memory().used / (1024**3),
}
gpu_mem = get_gpu_memory()
if gpu_mem is not None:
metrics['gpu_memory_used'] = gpu_mem
metrics["gpu_memory_used"] = gpu_mem
return metrics
def make_tts_request(text: str, timeout: int = 120) -> tuple[float, float]:
"""Make TTS request using OpenAI-compatible endpoint and return processing time and output length"""
try:
start_time = time.time()
# Make request to OpenAI-compatible endpoint
response = requests.post(
'http://localhost:8880/v1/audio/speech',
"http://localhost:8880/v1/audio/speech",
json={
'model': 'tts-1',
'input': text,
'voice': 'af',
'response_format': 'wav'
"model": "tts-1",
"input": text,
"voice": "af",
"response_format": "wav",
},
timeout=timeout
timeout=timeout,
)
response.raise_for_status()
processing_time = time.time() - start_time
audio_length = get_audio_length(response.content)
# Save the audio file
token_count = len(enc.encode(text))
output_file = f'examples/benchmarks/output/chunk_{token_count}_tokens.wav'
output_file = f"examples/benchmarks/output/chunk_{token_count}_tokens.wav"
os.makedirs(os.path.dirname(output_file), exist_ok=True)
with open(output_file, 'wb') as f:
with open(output_file, "wb") as f:
f.write(response.content)
print(f"Saved audio to {output_file}")
return processing_time, audio_length
except requests.exceptions.RequestException as e:
print(f"Error making request for text: {text[:50]}... Error: {str(e)}")
return None, None
@@ -125,91 +131,113 @@ def make_tts_request(text: str, timeout: int = 120) -> tuple[float, float]:
print(f"Error processing text: {text[:50]}... Error: {str(e)}")
return None, None
def plot_system_metrics(metrics_data):
"""Create plots for system metrics over time"""
df = pd.DataFrame(metrics_data)
df['timestamp'] = pd.to_datetime(df['timestamp'])
elapsed_time = (df['timestamp'] - df['timestamp'].iloc[0]).dt.total_seconds()
df["timestamp"] = pd.to_datetime(df["timestamp"])
elapsed_time = (df["timestamp"] - df["timestamp"].iloc[0]).dt.total_seconds()
# Get baseline values (first measurement)
baseline_cpu = df['cpu_percent'].iloc[0]
baseline_ram = df['ram_used_gb'].iloc[0]
baseline_gpu = df['gpu_memory_used'].iloc[0] / 1024 if 'gpu_memory_used' in df.columns else None # Convert MB to GB
baseline_cpu = df["cpu_percent"].iloc[0]
baseline_ram = df["ram_used_gb"].iloc[0]
baseline_gpu = (
df["gpu_memory_used"].iloc[0] / 1024
if "gpu_memory_used" in df.columns
else None
) # Convert MB to GB
# Convert GPU memory to GB
if 'gpu_memory_used' in df.columns:
df['gpu_memory_gb'] = df['gpu_memory_used'] / 1024
if "gpu_memory_used" in df.columns:
df["gpu_memory_gb"] = df["gpu_memory_used"] / 1024
# Set plotting style
plt.style.use('dark_background')
plt.style.use("dark_background")
# Create figure with 3 subplots (or 2 if no GPU)
has_gpu = 'gpu_memory_used' in df.columns
has_gpu = "gpu_memory_used" in df.columns
num_plots = 3 if has_gpu else 2
fig, axes = plt.subplots(num_plots, 1, figsize=(15, 5*num_plots))
fig.patch.set_facecolor('#1a1a2e')
fig, axes = plt.subplots(num_plots, 1, figsize=(15, 5 * num_plots))
fig.patch.set_facecolor("#1a1a2e")
# Apply rolling average for smoothing
window = min(5, len(df) // 2) # Smaller window for smoother lines
# Plot 1: CPU Usage
smoothed_cpu = df['cpu_percent'].rolling(window=window, center=True).mean()
sns.lineplot(x=elapsed_time, y=smoothed_cpu, ax=axes[0], color='#ff2a6d', linewidth=2)
axes[0].axhline(y=baseline_cpu, color='#05d9e8', linestyle='--', alpha=0.5, label='Baseline')
axes[0].set_xlabel('Time (seconds)', fontsize=14)
axes[0].set_ylabel('CPU Usage (%)', fontsize=14)
smoothed_cpu = df["cpu_percent"].rolling(window=window, center=True).mean()
sns.lineplot(
x=elapsed_time, y=smoothed_cpu, ax=axes[0], color="#ff2a6d", linewidth=2
)
axes[0].axhline(
y=baseline_cpu, color="#05d9e8", linestyle="--", alpha=0.5, label="Baseline"
)
axes[0].set_xlabel("Time (seconds)", fontsize=14)
axes[0].set_ylabel("CPU Usage (%)", fontsize=14)
axes[0].tick_params(labelsize=12)
axes[0].set_title('CPU Usage Over Time', pad=20, fontsize=16, fontweight='bold')
axes[0].set_ylim(0, max(df['cpu_percent']) * 1.1) # Add 10% padding
axes[0].set_title("CPU Usage Over Time", pad=20, fontsize=16, fontweight="bold")
axes[0].set_ylim(0, max(df["cpu_percent"]) * 1.1) # Add 10% padding
axes[0].legend()
# Plot 2: RAM Usage
smoothed_ram = df['ram_used_gb'].rolling(window=window, center=True).mean()
sns.lineplot(x=elapsed_time, y=smoothed_ram, ax=axes[1], color='#05d9e8', linewidth=2)
axes[1].axhline(y=baseline_ram, color='#ff2a6d', linestyle='--', alpha=0.5, label='Baseline')
axes[1].set_xlabel('Time (seconds)', fontsize=14)
axes[1].set_ylabel('RAM Usage (GB)', fontsize=14)
smoothed_ram = df["ram_used_gb"].rolling(window=window, center=True).mean()
sns.lineplot(
x=elapsed_time, y=smoothed_ram, ax=axes[1], color="#05d9e8", linewidth=2
)
axes[1].axhline(
y=baseline_ram, color="#ff2a6d", linestyle="--", alpha=0.5, label="Baseline"
)
axes[1].set_xlabel("Time (seconds)", fontsize=14)
axes[1].set_ylabel("RAM Usage (GB)", fontsize=14)
axes[1].tick_params(labelsize=12)
axes[1].set_title('RAM Usage Over Time', pad=20, fontsize=16, fontweight='bold')
axes[1].set_ylim(0, max(df['ram_used_gb']) * 1.1) # Add 10% padding
axes[1].set_title("RAM Usage Over Time", pad=20, fontsize=16, fontweight="bold")
axes[1].set_ylim(0, max(df["ram_used_gb"]) * 1.1) # Add 10% padding
axes[1].legend()
# Plot 3: GPU Memory (if available)
if has_gpu:
smoothed_gpu = df['gpu_memory_gb'].rolling(window=window, center=True).mean()
sns.lineplot(x=elapsed_time, y=smoothed_gpu, ax=axes[2], color='#ff2a6d', linewidth=2)
axes[2].axhline(y=baseline_gpu, color='#05d9e8', linestyle='--', alpha=0.5, label='Baseline')
axes[2].set_xlabel('Time (seconds)', fontsize=14)
axes[2].set_ylabel('GPU Memory (GB)', fontsize=14)
smoothed_gpu = df["gpu_memory_gb"].rolling(window=window, center=True).mean()
sns.lineplot(
x=elapsed_time, y=smoothed_gpu, ax=axes[2], color="#ff2a6d", linewidth=2
)
axes[2].axhline(
y=baseline_gpu, color="#05d9e8", linestyle="--", alpha=0.5, label="Baseline"
)
axes[2].set_xlabel("Time (seconds)", fontsize=14)
axes[2].set_ylabel("GPU Memory (GB)", fontsize=14)
axes[2].tick_params(labelsize=12)
axes[2].set_title('GPU Memory Usage Over Time', pad=20, fontsize=16, fontweight='bold')
axes[2].set_ylim(0, max(df['gpu_memory_gb']) * 1.1) # Add 10% padding
axes[2].set_title(
"GPU Memory Usage Over Time", pad=20, fontsize=16, fontweight="bold"
)
axes[2].set_ylim(0, max(df["gpu_memory_gb"]) * 1.1) # Add 10% padding
axes[2].legend()
# Style all subplots
for ax in axes:
ax.grid(True, linestyle='--', alpha=0.3)
ax.set_facecolor('#1a1a2e')
ax.grid(True, linestyle="--", alpha=0.3)
ax.set_facecolor("#1a1a2e")
for spine in ax.spines.values():
spine.set_color('#ffffff')
spine.set_color("#ffffff")
spine.set_alpha(0.3)
plt.tight_layout()
plt.savefig('examples/benchmarks/system_usage.png', dpi=300, bbox_inches='tight')
plt.savefig("examples/benchmarks/system_usage.png", dpi=300, bbox_inches="tight")
plt.close()
def main():
# Create output directory
os.makedirs('examples/benchmarks/output', exist_ok=True)
os.makedirs("examples/benchmarks/output", exist_ok=True)
# Read input text
with open('examples/benchmarks/the_time_machine_hg_wells.txt', 'r', encoding='utf-8') as f:
with open(
"examples/benchmarks/the_time_machine_hg_wells.txt", "r", encoding="utf-8"
) as f:
text = f.read()
# Get total tokens in file
total_tokens = len(enc.encode(text))
print(f"Total tokens in file: {total_tokens}")
# Generate token sizes with dense sampling at start and increasing intervals
dense_range = list(range(100, 600, 100)) # 100, 200, 300, 400, 500
medium_range = [750, 1000, 1500, 2000, 3000]
@@ -218,120 +246,160 @@ def main():
while current <= total_tokens:
large_range.append(current)
current *= 2
token_sizes = dense_range + medium_range + large_range
# Process chunks
results = []
system_metrics = []
test_start_time = time.time()
for num_tokens in token_sizes:
# Get text slice with exact token count
chunk = get_text_for_tokens(text, num_tokens)
actual_tokens = len(enc.encode(chunk))
print(f"\nProcessing chunk with {actual_tokens} tokens:")
print(f"Text preview: {chunk[:100]}...")
# Collect system metrics before processing
system_metrics.append(get_system_metrics())
processing_time, audio_length = make_tts_request(chunk)
if processing_time is None or audio_length is None:
print("Breaking loop due to error")
break
# Collect system metrics after processing
system_metrics.append(get_system_metrics())
results.append({
'tokens': actual_tokens,
'processing_time': processing_time,
'output_length': audio_length,
'realtime_factor': audio_length / processing_time,
'elapsed_time': time.time() - test_start_time
})
results.append(
{
"tokens": actual_tokens,
"processing_time": processing_time,
"output_length": audio_length,
"realtime_factor": audio_length / processing_time,
"elapsed_time": time.time() - test_start_time,
}
)
# Save intermediate results
with open('examples/benchmarks/benchmark_results.json', 'w') as f:
json.dump({
'results': results,
'system_metrics': system_metrics
}, f, indent=2)
with open("examples/benchmarks/benchmark_results.json", "w") as f:
json.dump(
{"results": results, "system_metrics": system_metrics}, f, indent=2
)
# Create DataFrame and calculate stats
df = pd.DataFrame(results)
if df.empty:
print("No data to plot")
return
# Calculate useful metrics
df['tokens_per_second'] = df['tokens'] / df['processing_time']
df["tokens_per_second"] = df["tokens"] / df["processing_time"]
# Write detailed stats
with open('examples/benchmarks/benchmark_stats.txt', 'w') as f:
with open("examples/benchmarks/benchmark_stats.txt", "w") as f:
f.write("=== Benchmark Statistics ===\n\n")
f.write("Overall Stats:\n")
f.write(f"Total tokens processed: {df['tokens'].sum()}\n")
f.write(f"Total audio generated: {df['output_length'].sum():.2f}s\n")
f.write(f"Total test duration: {df['elapsed_time'].max():.2f}s\n")
f.write(f"Average processing rate: {df['tokens_per_second'].mean():.2f} tokens/second\n")
f.write(
f"Average processing rate: {df['tokens_per_second'].mean():.2f} tokens/second\n"
)
f.write(f"Average realtime factor: {df['realtime_factor'].mean():.2f}x\n\n")
f.write("Per-chunk Stats:\n")
f.write(f"Average chunk size: {df['tokens'].mean():.2f} tokens\n")
f.write(f"Min chunk size: {df['tokens'].min():.2f} tokens\n")
f.write(f"Max chunk size: {df['tokens'].max():.2f} tokens\n")
f.write(f"Average processing time: {df['processing_time'].mean():.2f}s\n")
f.write(f"Average output length: {df['output_length'].mean():.2f}s\n\n")
f.write("Performance Ranges:\n")
f.write(f"Processing rate range: {df['tokens_per_second'].min():.2f} - {df['tokens_per_second'].max():.2f} tokens/second\n")
f.write(f"Realtime factor range: {df['realtime_factor'].min():.2f}x - {df['realtime_factor'].max():.2f}x\n")
f.write(
f"Processing rate range: {df['tokens_per_second'].min():.2f} - {df['tokens_per_second'].max():.2f} tokens/second\n"
)
f.write(
f"Realtime factor range: {df['realtime_factor'].min():.2f}x - {df['realtime_factor'].max():.2f}x\n"
)
# Set plotting style
plt.style.use('dark_background')
plt.style.use("dark_background")
# Plot 1: Processing Time vs Token Count
fig, ax = plt.subplots(figsize=(12, 8))
sns.scatterplot(data=df, x='tokens', y='processing_time', s=100, alpha=0.6, color='#ff2a6d')
sns.regplot(data=df, x='tokens', y='processing_time', scatter=False, color='#05d9e8', line_kws={'linewidth': 2})
corr = df['tokens'].corr(df['processing_time'])
plt.text(0.05, 0.95, f'Correlation: {corr:.2f}', transform=ax.transAxes, fontsize=10, color='#ffffff',
bbox=dict(facecolor='#1a1a2e', edgecolor='#ffffff', alpha=0.7))
setup_plot(fig, ax, 'Processing Time vs Input Size')
ax.set_xlabel('Number of Input Tokens')
ax.set_ylabel('Processing Time (seconds)')
plt.savefig('examples/benchmarks/processing_time.png', dpi=300, bbox_inches='tight')
sns.scatterplot(
data=df, x="tokens", y="processing_time", s=100, alpha=0.6, color="#ff2a6d"
)
sns.regplot(
data=df,
x="tokens",
y="processing_time",
scatter=False,
color="#05d9e8",
line_kws={"linewidth": 2},
)
corr = df["tokens"].corr(df["processing_time"])
plt.text(
0.05,
0.95,
f"Correlation: {corr:.2f}",
transform=ax.transAxes,
fontsize=10,
color="#ffffff",
bbox=dict(facecolor="#1a1a2e", edgecolor="#ffffff", alpha=0.7),
)
setup_plot(fig, ax, "Processing Time vs Input Size")
ax.set_xlabel("Number of Input Tokens")
ax.set_ylabel("Processing Time (seconds)")
plt.savefig("examples/benchmarks/processing_time.png", dpi=300, bbox_inches="tight")
plt.close()
# Plot 2: Realtime Factor vs Token Count
fig, ax = plt.subplots(figsize=(12, 8))
sns.scatterplot(data=df, x='tokens', y='realtime_factor', s=100, alpha=0.6, color='#ff2a6d')
sns.regplot(data=df, x='tokens', y='realtime_factor', scatter=False, color='#05d9e8', line_kws={'linewidth': 2})
corr = df['tokens'].corr(df['realtime_factor'])
plt.text(0.05, 0.95, f'Correlation: {corr:.2f}', transform=ax.transAxes, fontsize=10, color='#ffffff',
bbox=dict(facecolor='#1a1a2e', edgecolor='#ffffff', alpha=0.7))
setup_plot(fig, ax, 'Realtime Factor vs Input Size')
ax.set_xlabel('Number of Input Tokens')
ax.set_ylabel('Realtime Factor (output length / processing time)')
plt.savefig('examples/benchmarks/realtime_factor.png', dpi=300, bbox_inches='tight')
sns.scatterplot(
data=df, x="tokens", y="realtime_factor", s=100, alpha=0.6, color="#ff2a6d"
)
sns.regplot(
data=df,
x="tokens",
y="realtime_factor",
scatter=False,
color="#05d9e8",
line_kws={"linewidth": 2},
)
corr = df["tokens"].corr(df["realtime_factor"])
plt.text(
0.05,
0.95,
f"Correlation: {corr:.2f}",
transform=ax.transAxes,
fontsize=10,
color="#ffffff",
bbox=dict(facecolor="#1a1a2e", edgecolor="#ffffff", alpha=0.7),
)
setup_plot(fig, ax, "Realtime Factor vs Input Size")
ax.set_xlabel("Number of Input Tokens")
ax.set_ylabel("Realtime Factor (output length / processing time)")
plt.savefig("examples/benchmarks/realtime_factor.png", dpi=300, bbox_inches="tight")
plt.close()
# Plot system metrics
plot_system_metrics(system_metrics)
print("\nResults saved to:")
print("- examples/benchmarks/benchmark_results.json")
print("- examples/benchmarks/benchmark_stats.txt")
print("- examples/benchmarks/processing_time.png")
print("- examples/benchmarks/realtime_factor.png")
print("- examples/benchmarks/system_usage.png")
if any('gpu_memory_used' in m for m in system_metrics):
if any("gpu_memory_used" in m for m in system_metrics):
print("- examples/benchmarks/gpu_usage.png")
print("\nAudio files saved in examples/benchmarks/output/")
if __name__ == '__main__':
if __name__ == "__main__":
main()

View file

@@ -0,0 +1,59 @@
from pathlib import Path
import openai
import requests
SAMPLE_TEXT = """
That is the germ of my great discovery. But you are wrong to say that we cannot move about in Time. For instance, if I am recalling an incident very vividly I go back to the instant of its occurrence: I become absent-minded, as you say. I jump back for a moment.
"""
# Configure OpenAI client to use our local endpoint
client = openai.OpenAI(
timeout=60,
api_key="notneeded", # API key not required for our endpoint
base_url="http://localhost:8880/v1", # Point to our local server with v1 prefix
)
# Create output directory if it doesn't exist
output_dir = Path(__file__).parent / "output"
output_dir.mkdir(exist_ok=True)
def test_voice(voice: str):
speech_file = output_dir / f"speech_{voice}.wav"
print(f"\nTesting voice: {voice}")
print(f"Making request to {client.base_url}/audio/speech...")
try:
response = client.audio.speech.create(
model="tts-1", voice=voice, input=SAMPLE_TEXT, response_format="wav"
)
print("Got response, saving to file...")
with open(speech_file, "wb") as f:
f.write(response.content)
print(f"Success! Saved to: {speech_file}")
except Exception as e:
print(f"Error with voice {voice}: {str(e)}")
# First, get list of available voices using requests
print("Getting list of available voices...")
try:
# Convert base_url to string and ensure no double slashes
base_url = str(client.base_url).rstrip("/")
response = requests.get(f"{base_url}/audio/voices")
if response.status_code != 200:
raise Exception(f"Failed to get voices: {response.text}")
data = response.json()
if "voices" not in data:
raise Exception(f"Unexpected response format: {data}")
voices = data["voices"]
print(f"Found {len(voices)} voices: {', '.join(voices)}")
# Test each voice
for voice in voices:
test_voice(voice)
except Exception as e:
print(f"Error getting voices: {str(e)}")

View file

@@ -5,56 +5,58 @@ import openai
client = openai.OpenAI(
timeout=30,
api_key="notneeded", # API key not required for our endpoint
base_url="http://localhost:8880/v1" # Point to our local server with v1 prefix
base_url="http://localhost:8880/v1", # Point to our local server with v1 prefix
)
# Create output directory if it doesn't exist
output_dir = Path(__file__).parent / "output"
output_dir.mkdir(exist_ok=True)
def test_format(format: str, text: str = "The quick brown fox jumped over the lazy dog."):
def test_format(
format: str, text: str = "The quick brown fox jumped over the lazy dog."
):
speech_file = output_dir / f"speech_{format}.{format}"
print(f"\nTesting {format} format...")
print(f"Making request to {client.base_url}/audio/speech...")
try:
response = client.audio.speech.create(
model="tts-1",
voice="af",
input=text,
response_format=format
model="tts-1", voice="af", input=text, response_format=format
)
print(f"Got response, saving to file...")
with open(speech_file, 'wb') as f:
print("Got response, saving to file...")
with open(speech_file, "wb") as f:
f.write(response.content)
print(f"Success! Saved to: {speech_file}")
except Exception as e:
print(f"Error: {str(e)}")
def test_speed(speed: float):
speech_file = output_dir / f"speech_speed_{speed}.wav"
print(f"\nTesting speed {speed}x...")
print(f"Making request to {client.base_url}/audio/speech...")
try:
response = client.audio.speech.create(
model="tts-1",
voice="af",
input="The quick brown fox jumped over the lazy dog.",
response_format="wav",
speed=speed
speed=speed,
)
print(f"Got response, saving to file...")
with open(speech_file, 'wb') as f:
print("Got response, saving to file...")
with open(speech_file, "wb") as f:
f.write(response.content)
print(f"Success! Saved to: {speech_file}")
except Exception as e:
print(f"Error: {str(e)}")
# Test different formats
for format in ["wav", "mp3", "opus", "aac", "flac", "pcm"]:
test_format(format)
@@ -64,6 +66,9 @@ for speed in [0.25, 1.0, 2.0, 4.0]:  # 5.0 should fail as it's out of range
test_speed(speed)
# Test long text
test_format("wav", """
test_format(
"wav",
"""
That is the germ of my great discovery. But you are wrong to say that we cannot move about in Time. For instance, if I am recalling an incident very vividly I go back to the instant of its occurrence: I become absent-minded, as you say. I jump back for a moment.
""")
""",
)
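The speed loop's comment notes that 5.0 should fail; a sketch of asserting that negative case with the same client (openai-python v1 raises an APIStatusError subclass for the 422):

```python
# speed=5.0 exceeds the schema's le=4.0 bound, so the server returns 422.
try:
    client.audio.speech.create(
        model="tts-1",
        voice="af",
        input="This request should be rejected.",
        response_format="wav",
        speed=5.0,
    )
except openai.APIStatusError as e:
    print(f"Rejected as expected: HTTP {e.status_code}")
```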

View file

@ -24,6 +24,7 @@ tqdm==4.67.1
requests==2.32.3
munch==4.0.0
tiktoken==0.8.0
loguru==0.7.3
# Testing
pytest==8.0.0