Enhance TTS API with logging, voice pack loading, and schema updates

This commit is contained in:
remsky 2024-12-31 01:57:00 -07:00
parent 8ce8334345
commit c11a6ea6ea
12 changed files with 451 additions and 253 deletions

View file

@ -1,26 +1,38 @@
""" """
FastAPI OpenAI Compatible API FastAPI OpenAI Compatible API
""" """
import uvicorn import uvicorn
import logging
import sys
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from .core.config import settings from .core.config import settings
from .routers.openai_compatible import router as openai_router from .routers.openai_compatible import router as openai_router
# from services.tts import TTSModel from .services.tts import TTSModel, TTSService
logger = logging.getLogger(__name__)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
"""Lifespan context manager for database initialization""" """Lifespan context manager for model initialization"""
logger.info("Loading TTS model and voice packs...")
# Initialize the main model
model, device = TTSModel.get_instance()
logger.info(f"Model loaded on {device}")
# Initialize all voice packs
tts_service = TTSService()
voices = tts_service.list_voices()
for voice in voices:
logger.info(f"Loading voice pack: {voice}")
TTSModel.get_voicepack(voice)
logger.info("All models and voice packs loaded successfully")
yield yield
# Initialize FastAPI app # Initialize FastAPI app
app = FastAPI( app = FastAPI(
title=settings.api_title, title=settings.api_title,
@ -42,16 +54,19 @@ app.add_middleware(
# Include OpenAI compatible router # Include OpenAI compatible router
app.include_router(openai_router, prefix="/v1") app.include_router(openai_router, prefix="/v1")
# Health check endpoint # Health check endpoint
@app.get("/health") @app.get("/health")
async def health_check(): async def health_check():
"""Health check endpoint""" """Health check endpoint"""
return {"status": "healthy"} return {"status": "healthy"}
@app.get("/v1/test") @app.get("/v1/test")
async def test_endpoint(): async def test_endpoint():
"""Test endpoint to verify routing""" """Test endpoint to verify routing"""
return {"status": "ok"} return {"status": "ok"}
if __name__ == "__main__": if __name__ == "__main__":
uvicorn.run("api.src.main:app", host=settings.host, port=settings.port, reload=True) uvicorn.run("api.src.main:app", host=settings.host, port=settings.port, reload=True)

View file

@ -1 +1 @@
# #

View file

@ -1,5 +1,4 @@
from fastapi import APIRouter, HTTPException, Response, Depends from fastapi import APIRouter, HTTPException, Response, Depends
from sqlalchemy.orm import Session
import logging import logging
from ..structures.schemas import OpenAISpeechRequest from ..structures.schemas import OpenAISpeechRequest
from ..services.tts import TTSService from ..services.tts import TTSService
@ -12,14 +11,17 @@ router = APIRouter(
responses={404: {"description": "Not found"}}, responses={404: {"description": "Not found"}},
) )
def get_tts_service() -> TTSService: def get_tts_service() -> TTSService:
"""Dependency to get TTSService instance with database session""" """Dependency to get TTSService instance with database session"""
return TTSService(start_worker=False) # Don't start worker thread for OpenAI endpoint return TTSService(
start_worker=False
) # Don't start worker thread for OpenAI endpoint
@router.post("/audio/speech") @router.post("/audio/speech")
async def create_speech( async def create_speech(
request: OpenAISpeechRequest, request: OpenAISpeechRequest, tts_service: TTSService = Depends(get_tts_service)
tts_service: TTSService = Depends(get_tts_service)
): ):
"""OpenAI-compatible endpoint for text-to-speech""" """OpenAI-compatible endpoint for text-to-speech"""
try: try:
@ -28,28 +30,27 @@ async def create_speech(
text=request.input, text=request.input,
voice=request.voice, voice=request.voice,
speed=request.speed, speed=request.speed,
stitch_long_output=True stitch_long_output=True,
) )
# Convert to requested format # Convert to requested format
content = AudioService.convert_audio(audio, 24000, request.response_format) content = AudioService.convert_audio(audio, 24000, request.response_format)
return Response( return Response(
content=content, content=content,
media_type=f"audio/{request.response_format}", media_type=f"audio/{request.response_format}",
headers={ headers={
"Content-Disposition": f"attachment; filename=speech.{request.response_format}" "Content-Disposition": f"attachment; filename=speech.{request.response_format}"
} },
) )
except Exception as e: except Exception as e:
logger.error(f"Error generating speech: {str(e)}") logger.error(f"Error generating speech: {str(e)}")
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.get("/audio/voices") @router.get("/audio/voices")
async def list_voices( async def list_voices(tts_service: TTSService = Depends(get_tts_service)):
tts_service: TTSService = Depends(get_tts_service)
):
"""List all available voices for text-to-speech""" """List all available voices for text-to-speech"""
try: try:
voices = tts_service.list_voices() voices = tts_service.list_voices()

View file

@ -1,4 +1,5 @@
"""Audio conversion service""" """Audio conversion service"""
from io import BytesIO from io import BytesIO
import numpy as np import numpy as np
import scipy.io.wavfile as wavfile import scipy.io.wavfile as wavfile
@ -7,60 +8,69 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class AudioService: class AudioService:
"""Service for audio format conversions""" """Service for audio format conversions"""
@staticmethod @staticmethod
def convert_audio(audio_data: np.ndarray, sample_rate: int, output_format: str) -> bytes: def convert_audio(
audio_data: np.ndarray, sample_rate: int, output_format: str
) -> bytes:
"""Convert audio data to specified format """Convert audio data to specified format
Args: Args:
audio_data: Numpy array of audio samples audio_data: Numpy array of audio samples
sample_rate: Sample rate of the audio sample_rate: Sample rate of the audio
output_format: Target format (wav, mp3, etc.) output_format: Target format (wav, mp3, etc.)
Returns: Returns:
Bytes of the converted audio Bytes of the converted audio
""" """
buffer = BytesIO() buffer = BytesIO()
try: try:
if output_format == 'wav': if output_format == "wav":
logger.info("Writing to WAV format...") logger.info("Writing to WAV format...")
wavfile.write(buffer, sample_rate, audio_data) wavfile.write(buffer, sample_rate, audio_data)
return buffer.getvalue() return buffer.getvalue()
elif output_format == 'mp3': elif output_format == "mp3":
# For MP3, we need to convert to WAV first # For MP3, we need to convert to WAV first
logger.info("Converting to MP3 format...") logger.info("Converting to MP3 format...")
wav_buffer = BytesIO() wav_buffer = BytesIO()
wavfile.write(wav_buffer, sample_rate, audio_data) wavfile.write(wav_buffer, sample_rate, audio_data)
wav_buffer.seek(0) wav_buffer.seek(0)
# Convert WAV to MP3 using soundfile # Convert WAV to MP3 using soundfile
buffer = BytesIO() buffer = BytesIO()
sf.write(buffer, audio_data, sample_rate, format='mp3') sf.write(buffer, audio_data, sample_rate, format="mp3")
return buffer.getvalue() return buffer.getvalue()
elif output_format == 'opus': elif output_format == "opus":
logger.info("Converting to Opus format...") logger.info("Converting to Opus format...")
sf.write(buffer, audio_data, sample_rate, format='ogg', subtype='opus') sf.write(buffer, audio_data, sample_rate, format="ogg", subtype="opus")
return buffer.getvalue() return buffer.getvalue()
elif output_format == 'flac': elif output_format == "flac":
logger.info("Converting to FLAC format...") logger.info("Converting to FLAC format...")
sf.write(buffer, audio_data, sample_rate, format='flac') sf.write(buffer, audio_data, sample_rate, format="flac")
return buffer.getvalue() return buffer.getvalue()
elif output_format == 'aac': elif output_format == "aac":
raise ValueError("AAC format is not currently supported. Please use wav, mp3, opus, or flac.") raise ValueError(
"AAC format is not currently supported. Please use wav, mp3, opus, or flac."
elif output_format == 'pcm': )
raise ValueError("PCM format is not currently supported. Please use wav, mp3, opus, or flac.")
elif output_format == "pcm":
raise ValueError(
"PCM format is not currently supported. Please use wav, mp3, opus, or flac."
)
else: else:
raise ValueError(f"Format {output_format} not supported. Supported formats are: wav, mp3, opus, flac.") raise ValueError(
f"Format {output_format} not supported. Supported formats are: wav, mp3, opus, flac."
)
except Exception as e: except Exception as e:
logger.error(f"Error converting audio to {output_format}: {str(e)}") logger.error(f"Error converting audio to {output_format}: {str(e)}")
raise ValueError(f"Failed to convert audio to {output_format}: {str(e)}") raise ValueError(f"Failed to convert audio to {output_format}: {str(e)}")

View file

@ -2,8 +2,7 @@ import os
import threading import threading
import time import time
import io import io
from typing import Optional, List, Tuple from typing import List, Tuple
from sqlalchemy.orm import Session
import numpy as np import numpy as np
import torch import torch
import scipy.io.wavfile as wavfile import scipy.io.wavfile as wavfile
@ -17,6 +16,7 @@ import tiktoken
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
enc = tiktoken.get_encoding("cl100k_base") enc = tiktoken.get_encoding("cl100k_base")
class TTSModel: class TTSModel:
_instance = None _instance = None
_lock = threading.Lock() _lock = threading.Lock()
@ -40,7 +40,9 @@ class TTSModel:
model, device = cls.get_instance() model, device = cls.get_instance()
if voice_name not in cls._voicepacks: if voice_name not in cls._voicepacks:
try: try:
voice_path = os.path.join(settings.model_dir, settings.voices_dir, f"{voice_name}.pt") voice_path = os.path.join(
settings.model_dir, settings.voices_dir, f"{voice_name}.pt"
)
voicepack = torch.load( voicepack = torch.load(
voice_path, map_location=device, weights_only=True voice_path, map_location=device, weights_only=True
) )
@ -61,9 +63,11 @@ class TTSService:
def _split_text(self, text: str) -> List[str]: def _split_text(self, text: str) -> List[str]:
"""Split text into sentences""" """Split text into sentences"""
return [s.strip() for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()] return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]
def _generate_audio(self, text: str, voice: str, speed: float, stitch_long_output: bool = True) -> Tuple[torch.Tensor, float]: def _generate_audio(
self, text: str, voice: str, speed: float, stitch_long_output: bool = True
) -> Tuple[torch.Tensor, float]:
"""Generate audio and measure processing time""" """Generate audio and measure processing time"""
start_time = time.time() start_time = time.time()
@ -87,22 +91,34 @@ class TTSService:
# Validate phonemization first # Validate phonemization first
ps = phonemize(chunk, voice[0]) ps = phonemize(chunk, voice[0])
tokens = tokenize(ps) tokens = tokenize(ps)
logger.info(f"Processing chunk {i+1}/{len(chunks)}: {len(tokens)} tokens") logger.info(
f"Processing chunk {i+1}/{len(chunks)}: {len(tokens)} tokens"
)
# Only proceed if phonemization succeeded # Only proceed if phonemization succeeded
chunk_audio, _ = generate(model, chunk, voicepack, lang=voice[0], speed=speed) chunk_audio, _ = generate(
model, chunk, voicepack, lang=voice[0], speed=speed
)
if chunk_audio is not None: if chunk_audio is not None:
audio_chunks.append(chunk_audio) audio_chunks.append(chunk_audio)
else: else:
logger.error(f"No audio generated for chunk {i+1}/{len(chunks)}") logger.error(
f"No audio generated for chunk {i+1}/{len(chunks)}"
)
except Exception as e: except Exception as e:
logger.error(f"Failed to generate audio for chunk {i+1}/{len(chunks)}: '{chunk}'. Error: {str(e)}") logger.error(
f"Failed to generate audio for chunk {i+1}/{len(chunks)}: '{chunk}'. Error: {str(e)}"
)
continue continue
if not audio_chunks: if not audio_chunks:
raise ValueError("No audio chunks were generated successfully") raise ValueError("No audio chunks were generated successfully")
audio = np.concatenate(audio_chunks) if len(audio_chunks) > 1 else audio_chunks[0] audio = (
np.concatenate(audio_chunks)
if len(audio_chunks) > 1
else audio_chunks[0]
)
else: else:
audio, _ = generate(model, text, voicepack, lang=voice[0], speed=speed) audio, _ = generate(model, text, voicepack, lang=voice[0], speed=speed)

View file

@ -15,17 +15,24 @@ class TTSStatus(str, Enum):
class OpenAISpeechRequest(BaseModel): class OpenAISpeechRequest(BaseModel):
model: Literal["tts-1", "tts-1-hd"] = "tts-1" model: Literal["tts-1", "tts-1-hd"] = "tts-1"
input: str = Field(..., description="The text to generate audio for") input: str = Field(..., description="The text to generate audio for")
voice: Literal["am_adam", "am_michael", "bm_lewis", "af", "bm_george", "bf_isabella", "bf_emma", "af_sarah", "af_bella"] = Field( voice: Literal[
default="af", "am_adam",
description="The voice to use for generation" "am_michael",
) "bm_lewis",
"af",
"bm_george",
"bf_isabella",
"bf_emma",
"af_sarah",
"af_bella",
] = Field(default="af", description="The voice to use for generation")
response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field( response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field(
default="mp3", default="mp3",
description="The format to return audio in. Supported formats: mp3, opus, flac, wav. AAC and PCM are not currently supported." description="The format to return audio in. Supported formats: mp3, opus, flac, wav. AAC and PCM are not currently supported.",
) )
speed: float = Field( speed: float = Field(
default=1.0, default=1.0,
ge=0.25, ge=0.25,
le=4.0, le=4.0,
description="The speed of the generated audio. Select a value from 0.25 to 4.0." description="The speed of the generated audio. Select a value from 0.25 to 4.0.",
) )

View file

@ -3,17 +3,15 @@ from unittest.mock import Mock, patch
import sys import sys
# Mock torch and other ML modules before they're imported # Mock torch and other ML modules before they're imported
sys.modules['torch'] = Mock() sys.modules["torch"] = Mock()
sys.modules['transformers'] = Mock() sys.modules["transformers"] = Mock()
sys.modules['phonemizer'] = Mock() sys.modules["phonemizer"] = Mock()
sys.modules['models'] = Mock() sys.modules["models"] = Mock()
sys.modules['models.build_model'] = Mock() sys.modules["models.build_model"] = Mock()
sys.modules['kokoro'] = Mock() sys.modules["kokoro"] = Mock()
sys.modules['kokoro.generate'] = Mock() sys.modules["kokoro.generate"] = Mock()
sys.modules['kokoro.phonemize'] = Mock() sys.modules["kokoro.phonemize"] = Mock()
sys.modules['kokoro.tokenize'] = Mock() sys.modules["kokoro.tokenize"] = Mock()
from api.src.main import app
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)

View file

@ -1,27 +1,44 @@
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
import pytest import pytest
from unittest.mock import Mock, patch from unittest.mock import Mock
from ..src.main import app from ..src.main import app
from ..src.services.tts import TTSService
from ..src.routers.openai_compatible import TTSService as OpenAITTSService
# Create test client # Create test client
client = TestClient(app) client = TestClient(app)
# Mock services # Mock services
@pytest.fixture @pytest.fixture
def mock_tts_service(monkeypatch): def mock_tts_service(monkeypatch):
mock_service = Mock() mock_service = Mock()
mock_service._generate_audio.return_value = (bytes([0, 1, 2, 3]), 1.0) mock_service._generate_audio.return_value = (bytes([0, 1, 2, 3]), 1.0)
mock_service.list_voices.return_value = ["af", "bm_lewis", "bf_isabella", "bf_emma", "af_sarah", "af_bella", "am_adam", "am_michael", "bm_george"] mock_service.list_voices.return_value = [
monkeypatch.setattr("api.src.routers.openai_compatible.TTSService", lambda *args, **kwargs: mock_service) "af",
"bm_lewis",
"bf_isabella",
"bf_emma",
"af_sarah",
"af_bella",
"am_adam",
"am_michael",
"bm_george",
]
monkeypatch.setattr(
"api.src.routers.openai_compatible.TTSService",
lambda *args, **kwargs: mock_service,
)
return mock_service return mock_service
@pytest.fixture @pytest.fixture
def mock_audio_service(monkeypatch): def mock_audio_service(monkeypatch):
def mock_convert(*args): def mock_convert(*args):
return b"converted mock audio data" return b"converted mock audio data"
monkeypatch.setattr("api.src.routers.openai_compatible.AudioService.convert_audio", mock_convert)
monkeypatch.setattr(
"api.src.routers.openai_compatible.AudioService.convert_audio", mock_convert
)
def test_health_check(): def test_health_check():
"""Test the health check endpoint""" """Test the health check endpoint"""
@ -29,6 +46,7 @@ def test_health_check():
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == {"status": "healthy"} assert response.json() == {"status": "healthy"}
def test_openai_speech_endpoint(mock_tts_service, mock_audio_service): def test_openai_speech_endpoint(mock_tts_service, mock_audio_service):
"""Test the OpenAI-compatible speech endpoint""" """Test the OpenAI-compatible speech endpoint"""
test_request = { test_request = {
@ -36,20 +54,18 @@ def test_openai_speech_endpoint(mock_tts_service, mock_audio_service):
"input": "Hello world", "input": "Hello world",
"voice": "bm_lewis", "voice": "bm_lewis",
"response_format": "wav", "response_format": "wav",
"speed": 1.0 "speed": 1.0,
} }
response = client.post("/v1/audio/speech", json=test_request) response = client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 200 assert response.status_code == 200
assert response.headers["content-type"] == "audio/wav" assert response.headers["content-type"] == "audio/wav"
assert response.headers["content-disposition"] == "attachment; filename=speech.wav" assert response.headers["content-disposition"] == "attachment; filename=speech.wav"
mock_tts_service._generate_audio.assert_called_once_with( mock_tts_service._generate_audio.assert_called_once_with(
text="Hello world", text="Hello world", voice="bm_lewis", speed=1.0, stitch_long_output=True
voice="bm_lewis",
speed=1.0,
stitch_long_output=True
) )
assert response.content == b"converted mock audio data" assert response.content == b"converted mock audio data"
def test_openai_speech_invalid_voice(mock_tts_service): def test_openai_speech_invalid_voice(mock_tts_service):
"""Test the OpenAI-compatible speech endpoint with invalid voice""" """Test the OpenAI-compatible speech endpoint with invalid voice"""
test_request = { test_request = {
@ -57,11 +73,12 @@ def test_openai_speech_invalid_voice(mock_tts_service):
"input": "Hello world", "input": "Hello world",
"voice": "invalid_voice", "voice": "invalid_voice",
"response_format": "wav", "response_format": "wav",
"speed": 1.0 "speed": 1.0,
} }
response = client.post("/v1/audio/speech", json=test_request) response = client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 422 # Validation error assert response.status_code == 422 # Validation error
def test_openai_speech_invalid_speed(mock_tts_service): def test_openai_speech_invalid_speed(mock_tts_service):
"""Test the OpenAI-compatible speech endpoint with invalid speed""" """Test the OpenAI-compatible speech endpoint with invalid speed"""
test_request = { test_request = {
@ -69,11 +86,12 @@ def test_openai_speech_invalid_speed(mock_tts_service):
"input": "Hello world", "input": "Hello world",
"voice": "af", "voice": "af",
"response_format": "wav", "response_format": "wav",
"speed": -1.0 # Invalid speed "speed": -1.0, # Invalid speed
} }
response = client.post("/v1/audio/speech", json=test_request) response = client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 422 # Validation error assert response.status_code == 422 # Validation error
def test_openai_speech_generation_error(mock_tts_service): def test_openai_speech_generation_error(mock_tts_service):
"""Test error handling in speech generation""" """Test error handling in speech generation"""
mock_tts_service._generate_audio.side_effect = Exception("Generation failed") mock_tts_service._generate_audio.side_effect = Exception("Generation failed")
@ -82,7 +100,7 @@ def test_openai_speech_generation_error(mock_tts_service):
"input": "Hello world", "input": "Hello world",
"voice": "af", "voice": "af",
"response_format": "wav", "response_format": "wav",
"speed": 1.0 "speed": 1.0,
} }
response = client.post("/v1/audio/speech", json=test_request) response = client.post("/v1/audio/speech", json=test_request)
assert response.status_code == 500 assert response.status_code == 500

View file

@ -3,43 +3,43 @@ import time
import json import json
import scipy.io.wavfile as wavfile import scipy.io.wavfile as wavfile
import requests import requests
import numpy as np
import pandas as pd import pandas as pd
import seaborn as sns import seaborn as sns
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from scipy.signal import savgol_filter import tiktoken
import tiktoken
import psutil import psutil
import subprocess import subprocess
from datetime import datetime from datetime import datetime
enc = tiktoken.get_encoding("cl100k_base") enc = tiktoken.get_encoding("cl100k_base")
def setup_plot(fig, ax, title): def setup_plot(fig, ax, title):
"""Configure plot styling""" """Configure plot styling"""
# Improve grid # Improve grid
ax.grid(True, linestyle='--', alpha=0.3, color='#ffffff') ax.grid(True, linestyle="--", alpha=0.3, color="#ffffff")
# Set title and labels with better fonts # Set title and labels with better fonts
ax.set_title(title, pad=20, fontsize=16, fontweight='bold', color='#ffffff') ax.set_title(title, pad=20, fontsize=16, fontweight="bold", color="#ffffff")
ax.set_xlabel(ax.get_xlabel(), fontsize=14, fontweight='medium', color='#ffffff') ax.set_xlabel(ax.get_xlabel(), fontsize=14, fontweight="medium", color="#ffffff")
ax.set_ylabel(ax.get_ylabel(), fontsize=14, fontweight='medium', color='#ffffff') ax.set_ylabel(ax.get_ylabel(), fontsize=14, fontweight="medium", color="#ffffff")
# Improve tick labels # Improve tick labels
ax.tick_params(labelsize=12, colors='#ffffff') ax.tick_params(labelsize=12, colors="#ffffff")
# Style spines # Style spines
for spine in ax.spines.values(): for spine in ax.spines.values():
spine.set_color('#ffffff') spine.set_color("#ffffff")
spine.set_alpha(0.3) spine.set_alpha(0.3)
spine.set_linewidth(0.5) spine.set_linewidth(0.5)
# Set background colors # Set background colors
ax.set_facecolor('#1a1a2e') ax.set_facecolor("#1a1a2e")
fig.patch.set_facecolor('#1a1a2e') fig.patch.set_facecolor("#1a1a2e")
return fig, ax return fig, ax
def get_text_for_tokens(text: str, num_tokens: int) -> str: def get_text_for_tokens(text: str, num_tokens: int) -> str:
"""Get a slice of text that contains exactly num_tokens tokens""" """Get a slice of text that contains exactly num_tokens tokens"""
tokens = enc.encode(text) tokens = enc.encode(text)
@ -47,14 +47,15 @@ def get_text_for_tokens(text: str, num_tokens: int) -> str:
return text return text
return enc.decode(tokens[:num_tokens]) return enc.decode(tokens[:num_tokens])
def get_audio_length(audio_data: bytes) -> float: def get_audio_length(audio_data: bytes) -> float:
"""Get audio length in seconds from bytes data""" """Get audio length in seconds from bytes data"""
# Save to a temporary file # Save to a temporary file
temp_path = 'examples/benchmarks/output/temp.wav' temp_path = "examples/benchmarks/output/temp.wav"
os.makedirs(os.path.dirname(temp_path), exist_ok=True) os.makedirs(os.path.dirname(temp_path), exist_ok=True)
with open(temp_path, 'wb') as f: with open(temp_path, "wb") as f:
f.write(audio_data) f.write(audio_data)
# Read the audio file # Read the audio file
try: try:
rate, data = wavfile.read(temp_path) rate, data = wavfile.read(temp_path)
@ -64,60 +65,65 @@ def get_audio_length(audio_data: bytes) -> float:
if os.path.exists(temp_path): if os.path.exists(temp_path):
os.remove(temp_path) os.remove(temp_path)
def get_gpu_memory(): def get_gpu_memory():
"""Get GPU memory usage using nvidia-smi""" """Get GPU memory usage using nvidia-smi"""
try: try:
result = subprocess.check_output(['nvidia-smi', '--query-gpu=memory.used', '--format=csv,nounits,noheader']) result = subprocess.check_output(
return float(result.decode('utf-8').strip()) ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader"]
)
return float(result.decode("utf-8").strip())
except (subprocess.CalledProcessError, FileNotFoundError): except (subprocess.CalledProcessError, FileNotFoundError):
return None return None
def get_system_metrics(): def get_system_metrics():
"""Get current system metrics""" """Get current system metrics"""
metrics = { metrics = {
'timestamp': datetime.now().isoformat(), "timestamp": datetime.now().isoformat(),
'cpu_percent': psutil.cpu_percent(), "cpu_percent": psutil.cpu_percent(),
'ram_percent': psutil.virtual_memory().percent, "ram_percent": psutil.virtual_memory().percent,
'ram_used_gb': psutil.virtual_memory().used / (1024**3), "ram_used_gb": psutil.virtual_memory().used / (1024**3),
} }
gpu_mem = get_gpu_memory() gpu_mem = get_gpu_memory()
if gpu_mem is not None: if gpu_mem is not None:
metrics['gpu_memory_used'] = gpu_mem metrics["gpu_memory_used"] = gpu_mem
return metrics return metrics
def make_tts_request(text: str, timeout: int = 120) -> tuple[float, float]: def make_tts_request(text: str, timeout: int = 120) -> tuple[float, float]:
"""Make TTS request using OpenAI-compatible endpoint and return processing time and output length""" """Make TTS request using OpenAI-compatible endpoint and return processing time and output length"""
try: try:
start_time = time.time() start_time = time.time()
# Make request to OpenAI-compatible endpoint # Make request to OpenAI-compatible endpoint
response = requests.post( response = requests.post(
'http://localhost:8880/v1/audio/speech', "http://localhost:8880/v1/audio/speech",
json={ json={
'model': 'tts-1', "model": "tts-1",
'input': text, "input": text,
'voice': 'af', "voice": "af",
'response_format': 'wav' "response_format": "wav",
}, },
timeout=timeout timeout=timeout,
) )
response.raise_for_status() response.raise_for_status()
processing_time = time.time() - start_time processing_time = time.time() - start_time
audio_length = get_audio_length(response.content) audio_length = get_audio_length(response.content)
# Save the audio file # Save the audio file
token_count = len(enc.encode(text)) token_count = len(enc.encode(text))
output_file = f'examples/benchmarks/output/chunk_{token_count}_tokens.wav' output_file = f"examples/benchmarks/output/chunk_{token_count}_tokens.wav"
os.makedirs(os.path.dirname(output_file), exist_ok=True) os.makedirs(os.path.dirname(output_file), exist_ok=True)
with open(output_file, 'wb') as f: with open(output_file, "wb") as f:
f.write(response.content) f.write(response.content)
print(f"Saved audio to {output_file}") print(f"Saved audio to {output_file}")
return processing_time, audio_length return processing_time, audio_length
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
print(f"Error making request for text: {text[:50]}... Error: {str(e)}") print(f"Error making request for text: {text[:50]}... Error: {str(e)}")
return None, None return None, None
@ -125,91 +131,113 @@ def make_tts_request(text: str, timeout: int = 120) -> tuple[float, float]:
print(f"Error processing text: {text[:50]}... Error: {str(e)}") print(f"Error processing text: {text[:50]}... Error: {str(e)}")
return None, None return None, None
def plot_system_metrics(metrics_data): def plot_system_metrics(metrics_data):
"""Create plots for system metrics over time""" """Create plots for system metrics over time"""
df = pd.DataFrame(metrics_data) df = pd.DataFrame(metrics_data)
df['timestamp'] = pd.to_datetime(df['timestamp']) df["timestamp"] = pd.to_datetime(df["timestamp"])
elapsed_time = (df['timestamp'] - df['timestamp'].iloc[0]).dt.total_seconds() elapsed_time = (df["timestamp"] - df["timestamp"].iloc[0]).dt.total_seconds()
# Get baseline values (first measurement) # Get baseline values (first measurement)
baseline_cpu = df['cpu_percent'].iloc[0] baseline_cpu = df["cpu_percent"].iloc[0]
baseline_ram = df['ram_used_gb'].iloc[0] baseline_ram = df["ram_used_gb"].iloc[0]
baseline_gpu = df['gpu_memory_used'].iloc[0] / 1024 if 'gpu_memory_used' in df.columns else None # Convert MB to GB baseline_gpu = (
df["gpu_memory_used"].iloc[0] / 1024
if "gpu_memory_used" in df.columns
else None
) # Convert MB to GB
# Convert GPU memory to GB # Convert GPU memory to GB
if 'gpu_memory_used' in df.columns: if "gpu_memory_used" in df.columns:
df['gpu_memory_gb'] = df['gpu_memory_used'] / 1024 df["gpu_memory_gb"] = df["gpu_memory_used"] / 1024
# Set plotting style # Set plotting style
plt.style.use('dark_background') plt.style.use("dark_background")
# Create figure with 3 subplots (or 2 if no GPU) # Create figure with 3 subplots (or 2 if no GPU)
has_gpu = 'gpu_memory_used' in df.columns has_gpu = "gpu_memory_used" in df.columns
num_plots = 3 if has_gpu else 2 num_plots = 3 if has_gpu else 2
fig, axes = plt.subplots(num_plots, 1, figsize=(15, 5*num_plots)) fig, axes = plt.subplots(num_plots, 1, figsize=(15, 5 * num_plots))
fig.patch.set_facecolor('#1a1a2e') fig.patch.set_facecolor("#1a1a2e")
# Apply rolling average for smoothing # Apply rolling average for smoothing
window = min(5, len(df) // 2) # Smaller window for smoother lines window = min(5, len(df) // 2) # Smaller window for smoother lines
# Plot 1: CPU Usage # Plot 1: CPU Usage
smoothed_cpu = df['cpu_percent'].rolling(window=window, center=True).mean() smoothed_cpu = df["cpu_percent"].rolling(window=window, center=True).mean()
sns.lineplot(x=elapsed_time, y=smoothed_cpu, ax=axes[0], color='#ff2a6d', linewidth=2) sns.lineplot(
axes[0].axhline(y=baseline_cpu, color='#05d9e8', linestyle='--', alpha=0.5, label='Baseline') x=elapsed_time, y=smoothed_cpu, ax=axes[0], color="#ff2a6d", linewidth=2
axes[0].set_xlabel('Time (seconds)', fontsize=14) )
axes[0].set_ylabel('CPU Usage (%)', fontsize=14) axes[0].axhline(
y=baseline_cpu, color="#05d9e8", linestyle="--", alpha=0.5, label="Baseline"
)
axes[0].set_xlabel("Time (seconds)", fontsize=14)
axes[0].set_ylabel("CPU Usage (%)", fontsize=14)
axes[0].tick_params(labelsize=12) axes[0].tick_params(labelsize=12)
axes[0].set_title('CPU Usage Over Time', pad=20, fontsize=16, fontweight='bold') axes[0].set_title("CPU Usage Over Time", pad=20, fontsize=16, fontweight="bold")
axes[0].set_ylim(0, max(df['cpu_percent']) * 1.1) # Add 10% padding axes[0].set_ylim(0, max(df["cpu_percent"]) * 1.1) # Add 10% padding
axes[0].legend() axes[0].legend()
# Plot 2: RAM Usage # Plot 2: RAM Usage
smoothed_ram = df['ram_used_gb'].rolling(window=window, center=True).mean() smoothed_ram = df["ram_used_gb"].rolling(window=window, center=True).mean()
sns.lineplot(x=elapsed_time, y=smoothed_ram, ax=axes[1], color='#05d9e8', linewidth=2) sns.lineplot(
axes[1].axhline(y=baseline_ram, color='#ff2a6d', linestyle='--', alpha=0.5, label='Baseline') x=elapsed_time, y=smoothed_ram, ax=axes[1], color="#05d9e8", linewidth=2
axes[1].set_xlabel('Time (seconds)', fontsize=14) )
axes[1].set_ylabel('RAM Usage (GB)', fontsize=14) axes[1].axhline(
y=baseline_ram, color="#ff2a6d", linestyle="--", alpha=0.5, label="Baseline"
)
axes[1].set_xlabel("Time (seconds)", fontsize=14)
axes[1].set_ylabel("RAM Usage (GB)", fontsize=14)
axes[1].tick_params(labelsize=12) axes[1].tick_params(labelsize=12)
axes[1].set_title('RAM Usage Over Time', pad=20, fontsize=16, fontweight='bold') axes[1].set_title("RAM Usage Over Time", pad=20, fontsize=16, fontweight="bold")
axes[1].set_ylim(0, max(df['ram_used_gb']) * 1.1) # Add 10% padding axes[1].set_ylim(0, max(df["ram_used_gb"]) * 1.1) # Add 10% padding
axes[1].legend() axes[1].legend()
# Plot 3: GPU Memory (if available) # Plot 3: GPU Memory (if available)
if has_gpu: if has_gpu:
smoothed_gpu = df['gpu_memory_gb'].rolling(window=window, center=True).mean() smoothed_gpu = df["gpu_memory_gb"].rolling(window=window, center=True).mean()
sns.lineplot(x=elapsed_time, y=smoothed_gpu, ax=axes[2], color='#ff2a6d', linewidth=2) sns.lineplot(
axes[2].axhline(y=baseline_gpu, color='#05d9e8', linestyle='--', alpha=0.5, label='Baseline') x=elapsed_time, y=smoothed_gpu, ax=axes[2], color="#ff2a6d", linewidth=2
axes[2].set_xlabel('Time (seconds)', fontsize=14) )
axes[2].set_ylabel('GPU Memory (GB)', fontsize=14) axes[2].axhline(
y=baseline_gpu, color="#05d9e8", linestyle="--", alpha=0.5, label="Baseline"
)
axes[2].set_xlabel("Time (seconds)", fontsize=14)
axes[2].set_ylabel("GPU Memory (GB)", fontsize=14)
axes[2].tick_params(labelsize=12) axes[2].tick_params(labelsize=12)
axes[2].set_title('GPU Memory Usage Over Time', pad=20, fontsize=16, fontweight='bold') axes[2].set_title(
axes[2].set_ylim(0, max(df['gpu_memory_gb']) * 1.1) # Add 10% padding "GPU Memory Usage Over Time", pad=20, fontsize=16, fontweight="bold"
)
axes[2].set_ylim(0, max(df["gpu_memory_gb"]) * 1.1) # Add 10% padding
axes[2].legend() axes[2].legend()
# Style all subplots # Style all subplots
for ax in axes: for ax in axes:
ax.grid(True, linestyle='--', alpha=0.3) ax.grid(True, linestyle="--", alpha=0.3)
ax.set_facecolor('#1a1a2e') ax.set_facecolor("#1a1a2e")
for spine in ax.spines.values(): for spine in ax.spines.values():
spine.set_color('#ffffff') spine.set_color("#ffffff")
spine.set_alpha(0.3) spine.set_alpha(0.3)
plt.tight_layout() plt.tight_layout()
plt.savefig('examples/benchmarks/system_usage.png', dpi=300, bbox_inches='tight') plt.savefig("examples/benchmarks/system_usage.png", dpi=300, bbox_inches="tight")
plt.close() plt.close()
def main(): def main():
# Create output directory # Create output directory
os.makedirs('examples/benchmarks/output', exist_ok=True) os.makedirs("examples/benchmarks/output", exist_ok=True)
# Read input text # Read input text
with open('examples/benchmarks/the_time_machine_hg_wells.txt', 'r', encoding='utf-8') as f: with open(
"examples/benchmarks/the_time_machine_hg_wells.txt", "r", encoding="utf-8"
) as f:
text = f.read() text = f.read()
# Get total tokens in file # Get total tokens in file
total_tokens = len(enc.encode(text)) total_tokens = len(enc.encode(text))
print(f"Total tokens in file: {total_tokens}") print(f"Total tokens in file: {total_tokens}")
# Generate token sizes with dense sampling at start and increasing intervals # Generate token sizes with dense sampling at start and increasing intervals
dense_range = list(range(100, 600, 100)) # 100, 200, 300, 400, 500 dense_range = list(range(100, 600, 100)) # 100, 200, 300, 400, 500
medium_range = [750, 1000, 1500, 2000, 3000] medium_range = [750, 1000, 1500, 2000, 3000]
@ -218,120 +246,160 @@ def main():
while current <= total_tokens: while current <= total_tokens:
large_range.append(current) large_range.append(current)
current *= 2 current *= 2
token_sizes = dense_range + medium_range + large_range token_sizes = dense_range + medium_range + large_range
# Process chunks # Process chunks
results = [] results = []
system_metrics = [] system_metrics = []
test_start_time = time.time() test_start_time = time.time()
for num_tokens in token_sizes: for num_tokens in token_sizes:
# Get text slice with exact token count # Get text slice with exact token count
chunk = get_text_for_tokens(text, num_tokens) chunk = get_text_for_tokens(text, num_tokens)
actual_tokens = len(enc.encode(chunk)) actual_tokens = len(enc.encode(chunk))
print(f"\nProcessing chunk with {actual_tokens} tokens:") print(f"\nProcessing chunk with {actual_tokens} tokens:")
print(f"Text preview: {chunk[:100]}...") print(f"Text preview: {chunk[:100]}...")
# Collect system metrics before processing # Collect system metrics before processing
system_metrics.append(get_system_metrics()) system_metrics.append(get_system_metrics())
processing_time, audio_length = make_tts_request(chunk) processing_time, audio_length = make_tts_request(chunk)
if processing_time is None or audio_length is None: if processing_time is None or audio_length is None:
print("Breaking loop due to error") print("Breaking loop due to error")
break break
# Collect system metrics after processing # Collect system metrics after processing
system_metrics.append(get_system_metrics()) system_metrics.append(get_system_metrics())
results.append({ results.append(
'tokens': actual_tokens, {
'processing_time': processing_time, "tokens": actual_tokens,
'output_length': audio_length, "processing_time": processing_time,
'realtime_factor': audio_length / processing_time, "output_length": audio_length,
'elapsed_time': time.time() - test_start_time "realtime_factor": audio_length / processing_time,
}) "elapsed_time": time.time() - test_start_time,
}
)
# Save intermediate results # Save intermediate results
with open('examples/benchmarks/benchmark_results.json', 'w') as f: with open("examples/benchmarks/benchmark_results.json", "w") as f:
json.dump({ json.dump(
'results': results, {"results": results, "system_metrics": system_metrics}, f, indent=2
'system_metrics': system_metrics )
}, f, indent=2)
# Create DataFrame and calculate stats # Create DataFrame and calculate stats
df = pd.DataFrame(results) df = pd.DataFrame(results)
if df.empty: if df.empty:
print("No data to plot") print("No data to plot")
return return
# Calculate useful metrics # Calculate useful metrics
df['tokens_per_second'] = df['tokens'] / df['processing_time'] df["tokens_per_second"] = df["tokens"] / df["processing_time"]
# Write detailed stats # Write detailed stats
with open('examples/benchmarks/benchmark_stats.txt', 'w') as f: with open("examples/benchmarks/benchmark_stats.txt", "w") as f:
f.write("=== Benchmark Statistics ===\n\n") f.write("=== Benchmark Statistics ===\n\n")
f.write("Overall Stats:\n") f.write("Overall Stats:\n")
f.write(f"Total tokens processed: {df['tokens'].sum()}\n") f.write(f"Total tokens processed: {df['tokens'].sum()}\n")
f.write(f"Total audio generated: {df['output_length'].sum():.2f}s\n") f.write(f"Total audio generated: {df['output_length'].sum():.2f}s\n")
f.write(f"Total test duration: {df['elapsed_time'].max():.2f}s\n") f.write(f"Total test duration: {df['elapsed_time'].max():.2f}s\n")
f.write(f"Average processing rate: {df['tokens_per_second'].mean():.2f} tokens/second\n") f.write(
f"Average processing rate: {df['tokens_per_second'].mean():.2f} tokens/second\n"
)
f.write(f"Average realtime factor: {df['realtime_factor'].mean():.2f}x\n\n") f.write(f"Average realtime factor: {df['realtime_factor'].mean():.2f}x\n\n")
f.write("Per-chunk Stats:\n") f.write("Per-chunk Stats:\n")
f.write(f"Average chunk size: {df['tokens'].mean():.2f} tokens\n") f.write(f"Average chunk size: {df['tokens'].mean():.2f} tokens\n")
f.write(f"Min chunk size: {df['tokens'].min():.2f} tokens\n") f.write(f"Min chunk size: {df['tokens'].min():.2f} tokens\n")
f.write(f"Max chunk size: {df['tokens'].max():.2f} tokens\n") f.write(f"Max chunk size: {df['tokens'].max():.2f} tokens\n")
f.write(f"Average processing time: {df['processing_time'].mean():.2f}s\n") f.write(f"Average processing time: {df['processing_time'].mean():.2f}s\n")
f.write(f"Average output length: {df['output_length'].mean():.2f}s\n\n") f.write(f"Average output length: {df['output_length'].mean():.2f}s\n\n")
f.write("Performance Ranges:\n") f.write("Performance Ranges:\n")
f.write(f"Processing rate range: {df['tokens_per_second'].min():.2f} - {df['tokens_per_second'].max():.2f} tokens/second\n") f.write(
f.write(f"Realtime factor range: {df['realtime_factor'].min():.2f}x - {df['realtime_factor'].max():.2f}x\n") f"Processing rate range: {df['tokens_per_second'].min():.2f} - {df['tokens_per_second'].max():.2f} tokens/second\n"
)
f.write(
f"Realtime factor range: {df['realtime_factor'].min():.2f}x - {df['realtime_factor'].max():.2f}x\n"
)
# Set plotting style # Set plotting style
plt.style.use('dark_background') plt.style.use("dark_background")
# Plot 1: Processing Time vs Token Count # Plot 1: Processing Time vs Token Count
fig, ax = plt.subplots(figsize=(12, 8)) fig, ax = plt.subplots(figsize=(12, 8))
sns.scatterplot(data=df, x='tokens', y='processing_time', s=100, alpha=0.6, color='#ff2a6d') sns.scatterplot(
sns.regplot(data=df, x='tokens', y='processing_time', scatter=False, color='#05d9e8', line_kws={'linewidth': 2}) data=df, x="tokens", y="processing_time", s=100, alpha=0.6, color="#ff2a6d"
corr = df['tokens'].corr(df['processing_time']) )
plt.text(0.05, 0.95, f'Correlation: {corr:.2f}', transform=ax.transAxes, fontsize=10, color='#ffffff', sns.regplot(
bbox=dict(facecolor='#1a1a2e', edgecolor='#ffffff', alpha=0.7)) data=df,
setup_plot(fig, ax, 'Processing Time vs Input Size') x="tokens",
ax.set_xlabel('Number of Input Tokens') y="processing_time",
ax.set_ylabel('Processing Time (seconds)') scatter=False,
plt.savefig('examples/benchmarks/processing_time.png', dpi=300, bbox_inches='tight') color="#05d9e8",
line_kws={"linewidth": 2},
)
corr = df["tokens"].corr(df["processing_time"])
plt.text(
0.05,
0.95,
f"Correlation: {corr:.2f}",
transform=ax.transAxes,
fontsize=10,
color="#ffffff",
bbox=dict(facecolor="#1a1a2e", edgecolor="#ffffff", alpha=0.7),
)
setup_plot(fig, ax, "Processing Time vs Input Size")
ax.set_xlabel("Number of Input Tokens")
ax.set_ylabel("Processing Time (seconds)")
plt.savefig("examples/benchmarks/processing_time.png", dpi=300, bbox_inches="tight")
plt.close() plt.close()
# Plot 2: Realtime Factor vs Token Count # Plot 2: Realtime Factor vs Token Count
fig, ax = plt.subplots(figsize=(12, 8)) fig, ax = plt.subplots(figsize=(12, 8))
sns.scatterplot(data=df, x='tokens', y='realtime_factor', s=100, alpha=0.6, color='#ff2a6d') sns.scatterplot(
sns.regplot(data=df, x='tokens', y='realtime_factor', scatter=False, color='#05d9e8', line_kws={'linewidth': 2}) data=df, x="tokens", y="realtime_factor", s=100, alpha=0.6, color="#ff2a6d"
corr = df['tokens'].corr(df['realtime_factor']) )
plt.text(0.05, 0.95, f'Correlation: {corr:.2f}', transform=ax.transAxes, fontsize=10, color='#ffffff', sns.regplot(
bbox=dict(facecolor='#1a1a2e', edgecolor='#ffffff', alpha=0.7)) data=df,
setup_plot(fig, ax, 'Realtime Factor vs Input Size') x="tokens",
ax.set_xlabel('Number of Input Tokens') y="realtime_factor",
ax.set_ylabel('Realtime Factor (output length / processing time)') scatter=False,
plt.savefig('examples/benchmarks/realtime_factor.png', dpi=300, bbox_inches='tight') color="#05d9e8",
line_kws={"linewidth": 2},
)
corr = df["tokens"].corr(df["realtime_factor"])
plt.text(
0.05,
0.95,
f"Correlation: {corr:.2f}",
transform=ax.transAxes,
fontsize=10,
color="#ffffff",
bbox=dict(facecolor="#1a1a2e", edgecolor="#ffffff", alpha=0.7),
)
setup_plot(fig, ax, "Realtime Factor vs Input Size")
ax.set_xlabel("Number of Input Tokens")
ax.set_ylabel("Realtime Factor (output length / processing time)")
plt.savefig("examples/benchmarks/realtime_factor.png", dpi=300, bbox_inches="tight")
plt.close() plt.close()
# Plot system metrics # Plot system metrics
plot_system_metrics(system_metrics) plot_system_metrics(system_metrics)
print("\nResults saved to:") print("\nResults saved to:")
print("- examples/benchmarks/benchmark_results.json") print("- examples/benchmarks/benchmark_results.json")
print("- examples/benchmarks/benchmark_stats.txt") print("- examples/benchmarks/benchmark_stats.txt")
print("- examples/benchmarks/processing_time.png") print("- examples/benchmarks/processing_time.png")
print("- examples/benchmarks/realtime_factor.png") print("- examples/benchmarks/realtime_factor.png")
print("- examples/benchmarks/system_usage.png") print("- examples/benchmarks/system_usage.png")
if any('gpu_memory_used' in m for m in system_metrics): if any("gpu_memory_used" in m for m in system_metrics):
print("- examples/benchmarks/gpu_usage.png") print("- examples/benchmarks/gpu_usage.png")
print("\nAudio files saved in examples/benchmarks/output/") print("\nAudio files saved in examples/benchmarks/output/")
if __name__ == '__main__':
if __name__ == "__main__":
main() main()

View file

@ -0,0 +1,59 @@
from pathlib import Path
import openai
import requests
# Text that test_voice() submits as the `input` of every speech request.
SAMPLE_TEXT = """
That is the germ of my great discovery. But you are wrong to say that we cannot move about in Time. For instance, if I am recalling an incident very vividly I go back to the instant of its occurrence: I become absent-minded, as you say. I jump back for a moment.
"""
# Configure OpenAI client to use our local endpoint
# (timeout=60 bounds each request made through this client)
client = openai.OpenAI(
timeout=60,
api_key="notneeded", # API key not required for our endpoint
base_url="http://localhost:8880/v1", # Point to our local server with v1 prefix
)
# Create output directory if it doesn't exist
# (an "output" folder next to this script; exist_ok=True makes re-runs harmless)
output_dir = Path(__file__).parent / "output"
output_dir.mkdir(exist_ok=True)
def test_voice(voice: str):
    """Synthesize SAMPLE_TEXT with one voice and save the result as a WAV file.

    The request goes through the module-level ``client``; the returned bytes
    are written to ``output_dir / "speech_<voice>.wav"``. Failures are printed
    instead of raised, so a caller iterating over several voices continues
    with the next one.

    Args:
        voice: Voice identifier sent as the ``voice`` field of the request and
            embedded in the output file name.
    """
    speech_file = output_dir / f"speech_{voice}.wav"
    # client.base_url carries a trailing slash; strip it so the logged URL
    # does not contain a double slash before "audio/speech".
    endpoint = f"{str(client.base_url).rstrip('/')}/audio/speech"
    print(f"\nTesting voice: {voice}")
    print(f"Making request to {endpoint}...")
    try:
        response = client.audio.speech.create(
            model="tts-1", voice=voice, input=SAMPLE_TEXT, response_format="wav"
        )
        print("Got response, saving to file...")
        with open(speech_file, "wb") as f:
            f.write(response.content)
        print(f"Success! Saved to: {speech_file}")
    except Exception as e:
        # Best-effort example script: report the failure and keep going.
        print(f"Error with voice {voice}: {str(e)}")
# First, get list of available voices using requests
print("Getting list of available voices...")
try:
    # Convert base_url to string and ensure no double slashes
    base_url = str(client.base_url).rstrip("/")
    # requests applies no timeout by default, so an unresponsive server would
    # block this script forever; use the same 60-second bound as the client.
    response = requests.get(f"{base_url}/audio/voices", timeout=60)
    if response.status_code != 200:
        raise Exception(f"Failed to get voices: {response.text}")
    data = response.json()
    if "voices" not in data:
        raise Exception(f"Unexpected response format: {data}")
    voices = data["voices"]
    print(f"Found {len(voices)} voices: {', '.join(voices)}")
    # Test each voice (test_voice reports its own errors and does not raise
    # for ordinary failures, so one bad voice does not stop the loop)
    for voice in voices:
        test_voice(voice)
except Exception as e:
    # Covers connection errors, timeouts, non-200 replies and malformed JSON.
    print(f"Error getting voices: {str(e)}")

View file

@ -5,56 +5,58 @@ import openai
client = openai.OpenAI( client = openai.OpenAI(
timeout=30, timeout=30,
api_key="notneeded", # API key not required for our endpoint api_key="notneeded", # API key not required for our endpoint
base_url="http://localhost:8880/v1" # Point to our local server with v1 prefix base_url="http://localhost:8880/v1", # Point to our local server with v1 prefix
) )
# Create output directory if it doesn't exist # Create output directory if it doesn't exist
output_dir = Path(__file__).parent / "output" output_dir = Path(__file__).parent / "output"
output_dir.mkdir(exist_ok=True) output_dir.mkdir(exist_ok=True)
def test_format(
    format: str, text: str = "The quick brown fox jumped over the lazy dog."
):
    """Request speech in one audio format and save it under ``output_dir``.

    The output file is named ``speech_<format>.<format>``. Failures are
    printed instead of raised so the caller can continue with other formats.

    Args:
        format: Value passed as ``response_format``; also used as the file
            extension. (The name shadows the builtin but is kept so existing
            keyword callers keep working.)
        text: Text to synthesize.
    """
    speech_file = output_dir / f"speech_{format}.{format}"
    # client.base_url carries a trailing slash; strip it so the logged URL
    # does not contain a double slash before "audio/speech".
    endpoint = f"{str(client.base_url).rstrip('/')}/audio/speech"
    print(f"\nTesting {format} format...")
    print(f"Making request to {endpoint}...")
    try:
        response = client.audio.speech.create(
            model="tts-1", voice="af", input=text, response_format=format
        )
        print("Got response, saving to file...")
        with open(speech_file, "wb") as f:
            f.write(response.content)
        print(f"Success! Saved to: {speech_file}")
    except Exception as e:
        # Best-effort example script: report the failure and keep going.
        print(f"Error: {str(e)}")
def test_speed(speed: float):
    """Request WAV speech at a given speed multiplier and save it.

    The output file is ``output_dir / "speech_speed_<speed>.wav"``. Failures
    are printed instead of raised so the caller can continue with other
    speeds.

    Args:
        speed: Value passed as the ``speed`` field of the request and embedded
            in the output file name.
    """
    speech_file = output_dir / f"speech_speed_{speed}.wav"
    # client.base_url carries a trailing slash; strip it so the logged URL
    # does not contain a double slash before "audio/speech".
    endpoint = f"{str(client.base_url).rstrip('/')}/audio/speech"
    print(f"\nTesting speed {speed}x...")
    print(f"Making request to {endpoint}...")
    try:
        response = client.audio.speech.create(
            model="tts-1",
            voice="af",
            input="The quick brown fox jumped over the lazy dog.",
            response_format="wav",
            speed=speed,
        )
        print("Got response, saving to file...")
        with open(speech_file, "wb") as f:
            f.write(response.content)
        print(f"Success! Saved to: {speech_file}")
    except Exception as e:
        # Best-effort example script: report the failure and keep going.
        print(f"Error: {str(e)}")
# Test different formats # Test different formats
for format in ["wav", "mp3", "opus", "aac", "flac", "pcm"]: for format in ["wav", "mp3", "opus", "aac", "flac", "pcm"]:
test_format(format) test_format(format)
@ -64,6 +66,9 @@ for speed in [0.25, 1.0, 2.0, 4.0]: # 5.0 should fail as it's out of range
test_speed(speed) test_speed(speed)
# Test long text # Test long text
test_format("wav", """ test_format(
"wav",
"""
That is the germ of my great discovery. But you are wrong to say that we cannot move about in Time. For instance, if I am recalling an incident very vividly I go back to the instant of its occurrence: I become absent-minded, as you say. I jump back for a moment. That is the germ of my great discovery. But you are wrong to say that we cannot move about in Time. For instance, if I am recalling an incident very vividly I go back to the instant of its occurrence: I become absent-minded, as you say. I jump back for a moment.
""") """,
)

View file

@ -24,6 +24,7 @@ tqdm==4.67.1
requests==2.32.3 requests==2.32.3
munch==4.0.0 munch==4.0.0
tiktoken===0.8.0 tiktoken===0.8.0
loguru==0.7.3
# Testing # Testing
pytest==8.0.0 pytest==8.0.0