mirror of https://github.com/remsky/Kokoro-FastAPI.git, synced 2025-04-13 09:39:17 +00:00
Enhance TTS API with logging, voice pack loading, and schema updates

parent 8ce8334345
commit c11a6ea6ea
12 changed files with 451 additions and 253 deletions
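For orientation before the per-file hunks: the commit's main user-facing surface is the OpenAI-compatible speech route mounted under /v1. A minimal sketch of exercising it over plain HTTP, assuming a server running locally on port 8880 as in the example scripts changed below (the payload fields come straight from the OpenAISpeechRequest schema in this commit):

import requests

# POST to the OpenAI-compatible route; payload fields mirror OpenAISpeechRequest.
response = requests.post(
    "http://localhost:8880/v1/audio/speech",
    json={
        "model": "tts-1",
        "input": "Hello world",
        "voice": "af",  # schema default
        "response_format": "wav",  # wav, mp3, opus, flac are the working formats
        "speed": 1.0,  # schema allows 0.25 to 4.0
    },
    timeout=120,
)
response.raise_for_status()

# The endpoint returns raw audio bytes with an attachment Content-Disposition header.
with open("speech.wav", "wb") as f:
    f.write(response.content)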
api/src/main.py

@@ -1,26 +1,38 @@
 """
 FastAPI OpenAI Compatible API
 """

 import uvicorn
 import logging
 import sys
 from contextlib import asynccontextmanager
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from loguru import logger

 from .core.config import settings
 from .routers.openai_compatible import router as openai_router
-# from services.tts import TTSModel
+from .services.tts import TTSModel, TTSService

 logger = logging.getLogger(__name__)


 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    """Lifespan context manager for database initialization"""
+    """Lifespan context manager for model initialization"""
     logger.info("Loading TTS model and voice packs...")

     # Initialize the main model
     model, device = TTSModel.get_instance()
     logger.info(f"Model loaded on {device}")

     # Initialize all voice packs
     tts_service = TTSService()
     voices = tts_service.list_voices()
     for voice in voices:
         logger.info(f"Loading voice pack: {voice}")
         TTSModel.get_voicepack(voice)

     logger.info("All models and voice packs loaded successfully")
     yield


 # Initialize FastAPI app
 app = FastAPI(
     title=settings.api_title,

@@ -42,16 +54,19 @@ app.add_middleware(
 # Include OpenAI compatible router
 app.include_router(openai_router, prefix="/v1")


 # Health check endpoint
 @app.get("/health")
 async def health_check():
     """Health check endpoint"""
     return {"status": "healthy"}


 @app.get("/v1/test")
 async def test_endpoint():
     """Test endpoint to verify routing"""
     return {"status": "ok"}


 if __name__ == "__main__":
     uvicorn.run("api.src.main:app", host=settings.host, port=settings.port, reload=True)
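Note that the first hunk ends inside the FastAPI(...) constructor call, so this capture does not show the lifespan handler being wired up. For the warm-up above to run at startup, the handler has to be passed to the constructor; a hypothetical completion of the truncated call (the lifespan=lifespan argument is an assumption, not shown in the diff):

app = FastAPI(
    title=settings.api_title,
    lifespan=lifespan,  # assumed wiring; without it the warm-up above never runs
)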
api/src/routers/openai_compatible.py

@@ -1,5 +1,4 @@
 from fastapi import APIRouter, HTTPException, Response, Depends
-from sqlalchemy.orm import Session
 import logging
 from ..structures.schemas import OpenAISpeechRequest
 from ..services.tts import TTSService

@@ -12,14 +11,17 @@ router = APIRouter(
     responses={404: {"description": "Not found"}},
 )


 def get_tts_service() -> TTSService:
     """Dependency to get TTSService instance with database session"""
-    return TTSService(start_worker=False)  # Don't start worker thread for OpenAI endpoint
+    return TTSService(
+        start_worker=False
+    )  # Don't start worker thread for OpenAI endpoint


 @router.post("/audio/speech")
 async def create_speech(
-    request: OpenAISpeechRequest,
-    tts_service: TTSService = Depends(get_tts_service)
+    request: OpenAISpeechRequest, tts_service: TTSService = Depends(get_tts_service)
 ):
     """OpenAI-compatible endpoint for text-to-speech"""
     try:

@@ -28,7 +30,7 @@ async def create_speech(
             text=request.input,
             voice=request.voice,
             speed=request.speed,
-            stitch_long_output=True
+            stitch_long_output=True,
         )

         # Convert to requested format

@@ -39,17 +41,16 @@ async def create_speech(
             media_type=f"audio/{request.response_format}",
             headers={
                 "Content-Disposition": f"attachment; filename=speech.{request.response_format}"
-            }
+            },
         )

     except Exception as e:
         logger.error(f"Error generating speech: {str(e)}")
         raise HTTPException(status_code=500, detail=str(e))


 @router.get("/audio/voices")
-async def list_voices(
-    tts_service: TTSService = Depends(get_tts_service)
-):
+async def list_voices(tts_service: TTSService = Depends(get_tts_service)):
     """List all available voices for text-to-speech"""
     try:
         voices = tts_service.list_voices()
api/src/services/audio.py

@@ -1,4 +1,5 @@
 """Audio conversion service"""
+
 from io import BytesIO
 import numpy as np
 import scipy.io.wavfile as wavfile

@@ -7,11 +8,14 @@ import logging
 logger = logging.getLogger(__name__)


 class AudioService:
     """Service for audio format conversions"""

     @staticmethod
-    def convert_audio(audio_data: np.ndarray, sample_rate: int, output_format: str) -> bytes:
+    def convert_audio(
+        audio_data: np.ndarray, sample_rate: int, output_format: str
+    ) -> bytes:
         """Convert audio data to specified format

         Args:

@@ -25,12 +29,12 @@ class AudioService:
         buffer = BytesIO()

         try:
-            if output_format == 'wav':
+            if output_format == "wav":
                 logger.info("Writing to WAV format...")
                 wavfile.write(buffer, sample_rate, audio_data)
                 return buffer.getvalue()

-            elif output_format == 'mp3':
+            elif output_format == "mp3":
                 # For MP3, we need to convert to WAV first
                 logger.info("Converting to MP3 format...")
                 wav_buffer = BytesIO()

@@ -39,27 +43,33 @@ class AudioService:
                 # Convert WAV to MP3 using soundfile
                 buffer = BytesIO()
-                sf.write(buffer, audio_data, sample_rate, format='mp3')
+                sf.write(buffer, audio_data, sample_rate, format="mp3")
                 return buffer.getvalue()

-            elif output_format == 'opus':
+            elif output_format == "opus":
                 logger.info("Converting to Opus format...")
-                sf.write(buffer, audio_data, sample_rate, format='ogg', subtype='opus')
+                sf.write(buffer, audio_data, sample_rate, format="ogg", subtype="opus")
                 return buffer.getvalue()

-            elif output_format == 'flac':
+            elif output_format == "flac":
                 logger.info("Converting to FLAC format...")
-                sf.write(buffer, audio_data, sample_rate, format='flac')
+                sf.write(buffer, audio_data, sample_rate, format="flac")
                 return buffer.getvalue()

-            elif output_format == 'aac':
-                raise ValueError("AAC format is not currently supported. Please use wav, mp3, opus, or flac.")
+            elif output_format == "aac":
+                raise ValueError(
+                    "AAC format is not currently supported. Please use wav, mp3, opus, or flac."
+                )

-            elif output_format == 'pcm':
-                raise ValueError("PCM format is not currently supported. Please use wav, mp3, opus, or flac.")
+            elif output_format == "pcm":
+                raise ValueError(
+                    "PCM format is not currently supported. Please use wav, mp3, opus, or flac."
+                )

             else:
-                raise ValueError(f"Format {output_format} not supported. Supported formats are: wav, mp3, opus, flac.")
+                raise ValueError(
+                    f"Format {output_format} not supported. Supported formats are: wav, mp3, opus, flac."
+                )

         except Exception as e:
             logger.error(f"Error converting audio to {output_format}: {str(e)}")
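A minimal sketch of driving AudioService.convert_audio on its own, with a synthetic sine wave as stand-in input (the module path api/src/services/audio.py is inferred from the package layout; note also that the mp3/opus/flac branches call an sf object, presumably soundfile, whose import is not visible in these hunks):

import numpy as np

# Inferred module path; adjust to the actual package layout.
from api.src.services.audio import AudioService

sample_rate = 24000
t = np.linspace(0, 1.0, sample_rate, endpoint=False)
tone = (0.3 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)  # one second of A440

# The "wav" branch only needs scipy, so it works without soundfile installed.
wav_bytes = AudioService.convert_audio(tone, sample_rate, "wav")
print(f"Converted {len(wav_bytes)} bytes of WAV audio")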
api/src/services/tts.py

@@ -2,8 +2,7 @@ import os
 import threading
 import time
 import io
-from typing import Optional, List, Tuple
-from sqlalchemy.orm import Session
+from typing import List, Tuple
 import numpy as np
 import torch
 import scipy.io.wavfile as wavfile

@@ -17,6 +16,7 @@ import tiktoken
 logger = logging.getLogger(__name__)
 enc = tiktoken.get_encoding("cl100k_base")

+
 class TTSModel:
     _instance = None
     _lock = threading.Lock()

@@ -40,7 +40,9 @@ class TTSModel:
         model, device = cls.get_instance()
         if voice_name not in cls._voicepacks:
             try:
-                voice_path = os.path.join(settings.model_dir, settings.voices_dir, f"{voice_name}.pt")
+                voice_path = os.path.join(
+                    settings.model_dir, settings.voices_dir, f"{voice_name}.pt"
+                )
                 voicepack = torch.load(
                     voice_path, map_location=device, weights_only=True
                 )

@@ -61,9 +63,11 @@ class TTSService:

     def _split_text(self, text: str) -> List[str]:
         """Split text into sentences"""
-        return [s.strip() for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()]
+        return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]

-    def _generate_audio(self, text: str, voice: str, speed: float, stitch_long_output: bool = True) -> Tuple[torch.Tensor, float]:
+    def _generate_audio(
+        self, text: str, voice: str, speed: float, stitch_long_output: bool = True
+    ) -> Tuple[torch.Tensor, float]:
         """Generate audio and measure processing time"""
         start_time = time.time()

@@ -87,22 +91,34 @@ class TTSService:
                     # Validate phonemization first
                     ps = phonemize(chunk, voice[0])
                     tokens = tokenize(ps)
-                    logger.info(f"Processing chunk {i+1}/{len(chunks)}: {len(tokens)} tokens")
+                    logger.info(
+                        f"Processing chunk {i+1}/{len(chunks)}: {len(tokens)} tokens"
+                    )

                     # Only proceed if phonemization succeeded
-                    chunk_audio, _ = generate(model, chunk, voicepack, lang=voice[0], speed=speed)
+                    chunk_audio, _ = generate(
+                        model, chunk, voicepack, lang=voice[0], speed=speed
+                    )
                     if chunk_audio is not None:
                         audio_chunks.append(chunk_audio)
                     else:
-                        logger.error(f"No audio generated for chunk {i+1}/{len(chunks)}")
+                        logger.error(
+                            f"No audio generated for chunk {i+1}/{len(chunks)}"
+                        )
                 except Exception as e:
-                    logger.error(f"Failed to generate audio for chunk {i+1}/{len(chunks)}: '{chunk}'. Error: {str(e)}")
+                    logger.error(
+                        f"Failed to generate audio for chunk {i+1}/{len(chunks)}: '{chunk}'. Error: {str(e)}"
+                    )
                     continue

             if not audio_chunks:
                 raise ValueError("No audio chunks were generated successfully")

-            audio = np.concatenate(audio_chunks) if len(audio_chunks) > 1 else audio_chunks[0]
+            audio = (
+                np.concatenate(audio_chunks)
+                if len(audio_chunks) > 1
+                else audio_chunks[0]
+            )
         else:
             audio, _ = generate(model, text, voicepack, lang=voice[0], speed=speed)
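The _split_text helper above is a plain regex splitter: it breaks on whitespace that follows sentence-ending punctuation, keeping the punctuation attached to each sentence. A standalone illustration with a made-up sentence:

import re

text = "Hello world! How are you today? This is a test."
chunks = [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]
print(chunks)
# ['Hello world!', 'How are you today?', 'This is a test.']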
api/src/structures/schemas.py

@@ -15,17 +15,24 @@ class TTSStatus(str, Enum):
 class OpenAISpeechRequest(BaseModel):
     model: Literal["tts-1", "tts-1-hd"] = "tts-1"
     input: str = Field(..., description="The text to generate audio for")
-    voice: Literal["am_adam", "am_michael", "bm_lewis", "af", "bm_george", "bf_isabella", "bf_emma", "af_sarah", "af_bella"] = Field(
-        default="af",
-        description="The voice to use for generation"
-    )
+    voice: Literal[
+        "am_adam",
+        "am_michael",
+        "bm_lewis",
+        "af",
+        "bm_george",
+        "bf_isabella",
+        "bf_emma",
+        "af_sarah",
+        "af_bella",
+    ] = Field(default="af", description="The voice to use for generation")
     response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field(
         default="mp3",
-        description="The format to return audio in. Supported formats: mp3, opus, flac, wav. AAC and PCM are not currently supported."
+        description="The format to return audio in. Supported formats: mp3, opus, flac, wav. AAC and PCM are not currently supported.",
     )
     speed: float = Field(
         default=1.0,
         ge=0.25,
         le=4.0,
-        description="The speed of the generated audio. Select a value from 0.25 to 4.0."
+        description="The speed of the generated audio. Select a value from 0.25 to 4.0.",
     )
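The ge/le bounds on speed are what make the endpoint tests below expect 422 responses. A minimal sketch of the schema's validation in isolation, assuming the module path implied by the router import (api/src/structures/schemas.py):

from pydantic import ValidationError

from api.src.structures.schemas import OpenAISpeechRequest  # inferred path

# Defaults fill in model="tts-1", voice="af", response_format="mp3", speed=1.0.
req = OpenAISpeechRequest(input="Hello world")
print(req.voice, req.response_format, req.speed)

# speed=-1.0 violates ge=0.25; FastAPI surfaces this as a 422 response.
try:
    OpenAISpeechRequest(input="Hello world", speed=-1.0)
except ValidationError as e:
    print(e.errors()[0]["loc"])  # ('speed',)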
@@ -3,17 +3,15 @@ from unittest.mock import Mock, patch
 import sys

 # Mock torch and other ML modules before they're imported
-sys.modules['torch'] = Mock()
-sys.modules['transformers'] = Mock()
-sys.modules['phonemizer'] = Mock()
-sys.modules['models'] = Mock()
-sys.modules['models.build_model'] = Mock()
-sys.modules['kokoro'] = Mock()
-sys.modules['kokoro.generate'] = Mock()
-sys.modules['kokoro.phonemize'] = Mock()
-sys.modules['kokoro.tokenize'] = Mock()
-
-from api.src.main import app
+sys.modules["torch"] = Mock()
+sys.modules["transformers"] = Mock()
+sys.modules["phonemizer"] = Mock()
+sys.modules["models"] = Mock()
+sys.modules["models.build_model"] = Mock()
+sys.modules["kokoro"] = Mock()
+sys.modules["kokoro.generate"] = Mock()
+sys.modules["kokoro.phonemize"] = Mock()
+sys.modules["kokoro.tokenize"] = Mock()


 @pytest.fixture(autouse=True)
@@ -1,27 +1,44 @@
 from fastapi.testclient import TestClient
 import pytest
-from unittest.mock import Mock, patch
+from unittest.mock import Mock
 from ..src.main import app
 from ..src.services.tts import TTSService
 from ..src.routers.openai_compatible import TTSService as OpenAITTSService

 # Create test client
 client = TestClient(app)


 # Mock services
 @pytest.fixture
 def mock_tts_service(monkeypatch):
     mock_service = Mock()
     mock_service._generate_audio.return_value = (bytes([0, 1, 2, 3]), 1.0)
-    mock_service.list_voices.return_value = ["af", "bm_lewis", "bf_isabella", "bf_emma", "af_sarah", "af_bella", "am_adam", "am_michael", "bm_george"]
-    monkeypatch.setattr("api.src.routers.openai_compatible.TTSService", lambda *args, **kwargs: mock_service)
+    mock_service.list_voices.return_value = [
+        "af",
+        "bm_lewis",
+        "bf_isabella",
+        "bf_emma",
+        "af_sarah",
+        "af_bella",
+        "am_adam",
+        "am_michael",
+        "bm_george",
+    ]
+    monkeypatch.setattr(
+        "api.src.routers.openai_compatible.TTSService",
+        lambda *args, **kwargs: mock_service,
+    )
     return mock_service


 @pytest.fixture
 def mock_audio_service(monkeypatch):
     def mock_convert(*args):
         return b"converted mock audio data"
-    monkeypatch.setattr("api.src.routers.openai_compatible.AudioService.convert_audio", mock_convert)
+
+    monkeypatch.setattr(
+        "api.src.routers.openai_compatible.AudioService.convert_audio", mock_convert
+    )


 def test_health_check():
     """Test the health check endpoint"""

@@ -29,6 +46,7 @@ def test_health_check():
     assert response.status_code == 200
     assert response.json() == {"status": "healthy"}

+
 def test_openai_speech_endpoint(mock_tts_service, mock_audio_service):
     """Test the OpenAI-compatible speech endpoint"""
     test_request = {

@@ -36,20 +54,18 @@ def test_openai_speech_endpoint(mock_tts_service, mock_audio_service):
         "input": "Hello world",
         "voice": "bm_lewis",
         "response_format": "wav",
-        "speed": 1.0
+        "speed": 1.0,
     }
     response = client.post("/v1/audio/speech", json=test_request)
     assert response.status_code == 200
     assert response.headers["content-type"] == "audio/wav"
     assert response.headers["content-disposition"] == "attachment; filename=speech.wav"
     mock_tts_service._generate_audio.assert_called_once_with(
-        text="Hello world",
-        voice="bm_lewis",
-        speed=1.0,
-        stitch_long_output=True
+        text="Hello world", voice="bm_lewis", speed=1.0, stitch_long_output=True
     )
     assert response.content == b"converted mock audio data"


 def test_openai_speech_invalid_voice(mock_tts_service):
     """Test the OpenAI-compatible speech endpoint with invalid voice"""
     test_request = {

@@ -57,11 +73,12 @@ def test_openai_speech_invalid_voice(mock_tts_service):
         "input": "Hello world",
         "voice": "invalid_voice",
         "response_format": "wav",
-        "speed": 1.0
+        "speed": 1.0,
     }
     response = client.post("/v1/audio/speech", json=test_request)
     assert response.status_code == 422  # Validation error


 def test_openai_speech_invalid_speed(mock_tts_service):
     """Test the OpenAI-compatible speech endpoint with invalid speed"""
     test_request = {

@@ -69,11 +86,12 @@ def test_openai_speech_invalid_speed(mock_tts_service):
         "input": "Hello world",
         "voice": "af",
         "response_format": "wav",
-        "speed": -1.0  # Invalid speed
+        "speed": -1.0,  # Invalid speed
     }
     response = client.post("/v1/audio/speech", json=test_request)
     assert response.status_code == 422  # Validation error


 def test_openai_speech_generation_error(mock_tts_service):
     """Test error handling in speech generation"""
     mock_tts_service._generate_audio.side_effect = Exception("Generation failed")

@@ -82,7 +100,7 @@ def test_openai_speech_generation_error(mock_tts_service):
         "input": "Hello world",
         "voice": "af",
         "response_format": "wav",
-        "speed": 1.0
+        "speed": 1.0,
     }
     response = client.post("/v1/audio/speech", json=test_request)
     assert response.status_code == 500
@@ -3,11 +3,9 @@ import time
 import json
 import scipy.io.wavfile as wavfile
 import requests
-import numpy as np
 import pandas as pd
 import seaborn as sns
 import matplotlib.pyplot as plt
-from scipy.signal import savgol_filter
 import tiktoken
 import psutil
 import subprocess

@@ -15,31 +13,33 @@ from datetime import datetime
 enc = tiktoken.get_encoding("cl100k_base")


 def setup_plot(fig, ax, title):
     """Configure plot styling"""
     # Improve grid
-    ax.grid(True, linestyle='--', alpha=0.3, color='#ffffff')
+    ax.grid(True, linestyle="--", alpha=0.3, color="#ffffff")

     # Set title and labels with better fonts
-    ax.set_title(title, pad=20, fontsize=16, fontweight='bold', color='#ffffff')
-    ax.set_xlabel(ax.get_xlabel(), fontsize=14, fontweight='medium', color='#ffffff')
-    ax.set_ylabel(ax.get_ylabel(), fontsize=14, fontweight='medium', color='#ffffff')
+    ax.set_title(title, pad=20, fontsize=16, fontweight="bold", color="#ffffff")
+    ax.set_xlabel(ax.get_xlabel(), fontsize=14, fontweight="medium", color="#ffffff")
+    ax.set_ylabel(ax.get_ylabel(), fontsize=14, fontweight="medium", color="#ffffff")

     # Improve tick labels
-    ax.tick_params(labelsize=12, colors='#ffffff')
+    ax.tick_params(labelsize=12, colors="#ffffff")

     # Style spines
     for spine in ax.spines.values():
-        spine.set_color('#ffffff')
+        spine.set_color("#ffffff")
         spine.set_alpha(0.3)
         spine.set_linewidth(0.5)

     # Set background colors
-    ax.set_facecolor('#1a1a2e')
-    fig.patch.set_facecolor('#1a1a2e')
+    ax.set_facecolor("#1a1a2e")
+    fig.patch.set_facecolor("#1a1a2e")

     return fig, ax


 def get_text_for_tokens(text: str, num_tokens: int) -> str:
     """Get a slice of text that contains exactly num_tokens tokens"""
     tokens = enc.encode(text)

@@ -47,12 +47,13 @@ def get_text_for_tokens(text: str, num_tokens: int) -> str:
         return text
     return enc.decode(tokens[:num_tokens])


 def get_audio_length(audio_data: bytes) -> float:
     """Get audio length in seconds from bytes data"""
     # Save to a temporary file
-    temp_path = 'examples/benchmarks/output/temp.wav'
+    temp_path = "examples/benchmarks/output/temp.wav"
     os.makedirs(os.path.dirname(temp_path), exist_ok=True)
-    with open(temp_path, 'wb') as f:
+    with open(temp_path, "wb") as f:
         f.write(audio_data)

     # Read the audio file

@@ -64,29 +65,34 @@ def get_audio_length(audio_data: bytes) -> float:
     if os.path.exists(temp_path):
         os.remove(temp_path)


 def get_gpu_memory():
     """Get GPU memory usage using nvidia-smi"""
     try:
-        result = subprocess.check_output(['nvidia-smi', '--query-gpu=memory.used', '--format=csv,nounits,noheader'])
-        return float(result.decode('utf-8').strip())
+        result = subprocess.check_output(
+            ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader"]
+        )
+        return float(result.decode("utf-8").strip())
     except (subprocess.CalledProcessError, FileNotFoundError):
         return None


 def get_system_metrics():
     """Get current system metrics"""
     metrics = {
-        'timestamp': datetime.now().isoformat(),
-        'cpu_percent': psutil.cpu_percent(),
-        'ram_percent': psutil.virtual_memory().percent,
-        'ram_used_gb': psutil.virtual_memory().used / (1024**3),
+        "timestamp": datetime.now().isoformat(),
+        "cpu_percent": psutil.cpu_percent(),
+        "ram_percent": psutil.virtual_memory().percent,
+        "ram_used_gb": psutil.virtual_memory().used / (1024**3),
     }

     gpu_mem = get_gpu_memory()
     if gpu_mem is not None:
-        metrics['gpu_memory_used'] = gpu_mem
+        metrics["gpu_memory_used"] = gpu_mem

     return metrics


 def make_tts_request(text: str, timeout: int = 120) -> tuple[float, float]:
     """Make TTS request using OpenAI-compatible endpoint and return processing time and output length"""
     try:

@@ -94,14 +100,14 @@ def make_tts_request(text: str, timeout: int = 120) -> tuple[float, float]:
         # Make request to OpenAI-compatible endpoint
         response = requests.post(
-            'http://localhost:8880/v1/audio/speech',
+            "http://localhost:8880/v1/audio/speech",
             json={
-                'model': 'tts-1',
-                'input': text,
-                'voice': 'af',
-                'response_format': 'wav'
+                "model": "tts-1",
+                "input": text,
+                "voice": "af",
+                "response_format": "wav",
             },
-            timeout=timeout
+            timeout=timeout,
         )
         response.raise_for_status()

@@ -110,9 +116,9 @@ def make_tts_request(text: str, timeout: int = 120) -> tuple[float, float]:
         # Save the audio file
         token_count = len(enc.encode(text))
-        output_file = f'examples/benchmarks/output/chunk_{token_count}_tokens.wav'
+        output_file = f"examples/benchmarks/output/chunk_{token_count}_tokens.wav"
         os.makedirs(os.path.dirname(output_file), exist_ok=True)
-        with open(output_file, 'wb') as f:
+        with open(output_file, "wb") as f:
             f.write(response.content)
         print(f"Saved audio to {output_file}")

@@ -125,85 +131,107 @@ def make_tts_request(text: str, timeout: int = 120) -> tuple[float, float]:
         print(f"Error processing text: {text[:50]}... Error: {str(e)}")
         return None, None


 def plot_system_metrics(metrics_data):
     """Create plots for system metrics over time"""
     df = pd.DataFrame(metrics_data)
-    df['timestamp'] = pd.to_datetime(df['timestamp'])
-    elapsed_time = (df['timestamp'] - df['timestamp'].iloc[0]).dt.total_seconds()
+    df["timestamp"] = pd.to_datetime(df["timestamp"])
+    elapsed_time = (df["timestamp"] - df["timestamp"].iloc[0]).dt.total_seconds()

     # Get baseline values (first measurement)
-    baseline_cpu = df['cpu_percent'].iloc[0]
-    baseline_ram = df['ram_used_gb'].iloc[0]
-    baseline_gpu = df['gpu_memory_used'].iloc[0] / 1024 if 'gpu_memory_used' in df.columns else None  # Convert MB to GB
+    baseline_cpu = df["cpu_percent"].iloc[0]
+    baseline_ram = df["ram_used_gb"].iloc[0]
+    baseline_gpu = (
+        df["gpu_memory_used"].iloc[0] / 1024
+        if "gpu_memory_used" in df.columns
+        else None
+    )  # Convert MB to GB

     # Convert GPU memory to GB
-    if 'gpu_memory_used' in df.columns:
-        df['gpu_memory_gb'] = df['gpu_memory_used'] / 1024
+    if "gpu_memory_used" in df.columns:
+        df["gpu_memory_gb"] = df["gpu_memory_used"] / 1024

     # Set plotting style
-    plt.style.use('dark_background')
+    plt.style.use("dark_background")

     # Create figure with 3 subplots (or 2 if no GPU)
-    has_gpu = 'gpu_memory_used' in df.columns
+    has_gpu = "gpu_memory_used" in df.columns
     num_plots = 3 if has_gpu else 2
-    fig, axes = plt.subplots(num_plots, 1, figsize=(15, 5*num_plots))
-    fig.patch.set_facecolor('#1a1a2e')
+    fig, axes = plt.subplots(num_plots, 1, figsize=(15, 5 * num_plots))
+    fig.patch.set_facecolor("#1a1a2e")

     # Apply rolling average for smoothing
     window = min(5, len(df) // 2)  # Smaller window for smoother lines

     # Plot 1: CPU Usage
-    smoothed_cpu = df['cpu_percent'].rolling(window=window, center=True).mean()
-    sns.lineplot(x=elapsed_time, y=smoothed_cpu, ax=axes[0], color='#ff2a6d', linewidth=2)
-    axes[0].axhline(y=baseline_cpu, color='#05d9e8', linestyle='--', alpha=0.5, label='Baseline')
-    axes[0].set_xlabel('Time (seconds)', fontsize=14)
-    axes[0].set_ylabel('CPU Usage (%)', fontsize=14)
+    smoothed_cpu = df["cpu_percent"].rolling(window=window, center=True).mean()
+    sns.lineplot(
+        x=elapsed_time, y=smoothed_cpu, ax=axes[0], color="#ff2a6d", linewidth=2
+    )
+    axes[0].axhline(
+        y=baseline_cpu, color="#05d9e8", linestyle="--", alpha=0.5, label="Baseline"
+    )
+    axes[0].set_xlabel("Time (seconds)", fontsize=14)
+    axes[0].set_ylabel("CPU Usage (%)", fontsize=14)
     axes[0].tick_params(labelsize=12)
-    axes[0].set_title('CPU Usage Over Time', pad=20, fontsize=16, fontweight='bold')
-    axes[0].set_ylim(0, max(df['cpu_percent']) * 1.1)  # Add 10% padding
+    axes[0].set_title("CPU Usage Over Time", pad=20, fontsize=16, fontweight="bold")
+    axes[0].set_ylim(0, max(df["cpu_percent"]) * 1.1)  # Add 10% padding
     axes[0].legend()

     # Plot 2: RAM Usage
-    smoothed_ram = df['ram_used_gb'].rolling(window=window, center=True).mean()
-    sns.lineplot(x=elapsed_time, y=smoothed_ram, ax=axes[1], color='#05d9e8', linewidth=2)
-    axes[1].axhline(y=baseline_ram, color='#ff2a6d', linestyle='--', alpha=0.5, label='Baseline')
-    axes[1].set_xlabel('Time (seconds)', fontsize=14)
-    axes[1].set_ylabel('RAM Usage (GB)', fontsize=14)
+    smoothed_ram = df["ram_used_gb"].rolling(window=window, center=True).mean()
+    sns.lineplot(
+        x=elapsed_time, y=smoothed_ram, ax=axes[1], color="#05d9e8", linewidth=2
+    )
+    axes[1].axhline(
+        y=baseline_ram, color="#ff2a6d", linestyle="--", alpha=0.5, label="Baseline"
+    )
+    axes[1].set_xlabel("Time (seconds)", fontsize=14)
+    axes[1].set_ylabel("RAM Usage (GB)", fontsize=14)
     axes[1].tick_params(labelsize=12)
-    axes[1].set_title('RAM Usage Over Time', pad=20, fontsize=16, fontweight='bold')
-    axes[1].set_ylim(0, max(df['ram_used_gb']) * 1.1)  # Add 10% padding
+    axes[1].set_title("RAM Usage Over Time", pad=20, fontsize=16, fontweight="bold")
+    axes[1].set_ylim(0, max(df["ram_used_gb"]) * 1.1)  # Add 10% padding
     axes[1].legend()

     # Plot 3: GPU Memory (if available)
     if has_gpu:
-        smoothed_gpu = df['gpu_memory_gb'].rolling(window=window, center=True).mean()
-        sns.lineplot(x=elapsed_time, y=smoothed_gpu, ax=axes[2], color='#ff2a6d', linewidth=2)
-        axes[2].axhline(y=baseline_gpu, color='#05d9e8', linestyle='--', alpha=0.5, label='Baseline')
-        axes[2].set_xlabel('Time (seconds)', fontsize=14)
-        axes[2].set_ylabel('GPU Memory (GB)', fontsize=14)
+        smoothed_gpu = df["gpu_memory_gb"].rolling(window=window, center=True).mean()
+        sns.lineplot(
+            x=elapsed_time, y=smoothed_gpu, ax=axes[2], color="#ff2a6d", linewidth=2
+        )
+        axes[2].axhline(
+            y=baseline_gpu, color="#05d9e8", linestyle="--", alpha=0.5, label="Baseline"
+        )
+        axes[2].set_xlabel("Time (seconds)", fontsize=14)
+        axes[2].set_ylabel("GPU Memory (GB)", fontsize=14)
         axes[2].tick_params(labelsize=12)
-        axes[2].set_title('GPU Memory Usage Over Time', pad=20, fontsize=16, fontweight='bold')
-        axes[2].set_ylim(0, max(df['gpu_memory_gb']) * 1.1)  # Add 10% padding
+        axes[2].set_title(
+            "GPU Memory Usage Over Time", pad=20, fontsize=16, fontweight="bold"
+        )
+        axes[2].set_ylim(0, max(df["gpu_memory_gb"]) * 1.1)  # Add 10% padding
         axes[2].legend()

     # Style all subplots
     for ax in axes:
-        ax.grid(True, linestyle='--', alpha=0.3)
-        ax.set_facecolor('#1a1a2e')
+        ax.grid(True, linestyle="--", alpha=0.3)
+        ax.set_facecolor("#1a1a2e")
         for spine in ax.spines.values():
-            spine.set_color('#ffffff')
+            spine.set_color("#ffffff")
             spine.set_alpha(0.3)

     plt.tight_layout()
-    plt.savefig('examples/benchmarks/system_usage.png', dpi=300, bbox_inches='tight')
+    plt.savefig("examples/benchmarks/system_usage.png", dpi=300, bbox_inches="tight")
     plt.close()


 def main():
     # Create output directory
-    os.makedirs('examples/benchmarks/output', exist_ok=True)
+    os.makedirs("examples/benchmarks/output", exist_ok=True)

     # Read input text
-    with open('examples/benchmarks/the_time_machine_hg_wells.txt', 'r', encoding='utf-8') as f:
+    with open(
+        "examples/benchmarks/the_time_machine_hg_wells.txt", "r", encoding="utf-8"
+    ) as f:
         text = f.read()

     # Get total tokens in file

@@ -245,20 +273,21 @@ def main():
         # Collect system metrics after processing
         system_metrics.append(get_system_metrics())

-        results.append({
-            'tokens': actual_tokens,
-            'processing_time': processing_time,
-            'output_length': audio_length,
-            'realtime_factor': audio_length / processing_time,
-            'elapsed_time': time.time() - test_start_time
-        })
+        results.append(
+            {
+                "tokens": actual_tokens,
+                "processing_time": processing_time,
+                "output_length": audio_length,
+                "realtime_factor": audio_length / processing_time,
+                "elapsed_time": time.time() - test_start_time,
+            }
+        )

         # Save intermediate results
-        with open('examples/benchmarks/benchmark_results.json', 'w') as f:
-            json.dump({
-                'results': results,
-                'system_metrics': system_metrics
-            }, f, indent=2)
+        with open("examples/benchmarks/benchmark_results.json", "w") as f:
+            json.dump(
+                {"results": results, "system_metrics": system_metrics}, f, indent=2
+            )

     # Create DataFrame and calculate stats
     df = pd.DataFrame(results)

@@ -267,17 +296,19 @@ def main():
         return

     # Calculate useful metrics
-    df['tokens_per_second'] = df['tokens'] / df['processing_time']
+    df["tokens_per_second"] = df["tokens"] / df["processing_time"]

     # Write detailed stats
-    with open('examples/benchmarks/benchmark_stats.txt', 'w') as f:
+    with open("examples/benchmarks/benchmark_stats.txt", "w") as f:
         f.write("=== Benchmark Statistics ===\n\n")

         f.write("Overall Stats:\n")
         f.write(f"Total tokens processed: {df['tokens'].sum()}\n")
         f.write(f"Total audio generated: {df['output_length'].sum():.2f}s\n")
         f.write(f"Total test duration: {df['elapsed_time'].max():.2f}s\n")
-        f.write(f"Average processing rate: {df['tokens_per_second'].mean():.2f} tokens/second\n")
+        f.write(
+            f"Average processing rate: {df['tokens_per_second'].mean():.2f} tokens/second\n"
+        )
         f.write(f"Average realtime factor: {df['realtime_factor'].mean():.2f}x\n\n")

         f.write("Per-chunk Stats:\n")

@@ -288,36 +319,72 @@ def main():
         f.write(f"Average output length: {df['output_length'].mean():.2f}s\n\n")

         f.write("Performance Ranges:\n")
-        f.write(f"Processing rate range: {df['tokens_per_second'].min():.2f} - {df['tokens_per_second'].max():.2f} tokens/second\n")
-        f.write(f"Realtime factor range: {df['realtime_factor'].min():.2f}x - {df['realtime_factor'].max():.2f}x\n")
+        f.write(
+            f"Processing rate range: {df['tokens_per_second'].min():.2f} - {df['tokens_per_second'].max():.2f} tokens/second\n"
+        )
+        f.write(
+            f"Realtime factor range: {df['realtime_factor'].min():.2f}x - {df['realtime_factor'].max():.2f}x\n"
+        )

     # Set plotting style
-    plt.style.use('dark_background')
+    plt.style.use("dark_background")

     # Plot 1: Processing Time vs Token Count
     fig, ax = plt.subplots(figsize=(12, 8))
-    sns.scatterplot(data=df, x='tokens', y='processing_time', s=100, alpha=0.6, color='#ff2a6d')
-    sns.regplot(data=df, x='tokens', y='processing_time', scatter=False, color='#05d9e8', line_kws={'linewidth': 2})
-    corr = df['tokens'].corr(df['processing_time'])
-    plt.text(0.05, 0.95, f'Correlation: {corr:.2f}', transform=ax.transAxes, fontsize=10, color='#ffffff',
-             bbox=dict(facecolor='#1a1a2e', edgecolor='#ffffff', alpha=0.7))
-    setup_plot(fig, ax, 'Processing Time vs Input Size')
-    ax.set_xlabel('Number of Input Tokens')
-    ax.set_ylabel('Processing Time (seconds)')
-    plt.savefig('examples/benchmarks/processing_time.png', dpi=300, bbox_inches='tight')
+    sns.scatterplot(
+        data=df, x="tokens", y="processing_time", s=100, alpha=0.6, color="#ff2a6d"
+    )
+    sns.regplot(
+        data=df,
+        x="tokens",
+        y="processing_time",
+        scatter=False,
+        color="#05d9e8",
+        line_kws={"linewidth": 2},
+    )
+    corr = df["tokens"].corr(df["processing_time"])
+    plt.text(
+        0.05,
+        0.95,
+        f"Correlation: {corr:.2f}",
+        transform=ax.transAxes,
+        fontsize=10,
+        color="#ffffff",
+        bbox=dict(facecolor="#1a1a2e", edgecolor="#ffffff", alpha=0.7),
+    )
+    setup_plot(fig, ax, "Processing Time vs Input Size")
+    ax.set_xlabel("Number of Input Tokens")
+    ax.set_ylabel("Processing Time (seconds)")
+    plt.savefig("examples/benchmarks/processing_time.png", dpi=300, bbox_inches="tight")
     plt.close()

     # Plot 2: Realtime Factor vs Token Count
     fig, ax = plt.subplots(figsize=(12, 8))
-    sns.scatterplot(data=df, x='tokens', y='realtime_factor', s=100, alpha=0.6, color='#ff2a6d')
-    sns.regplot(data=df, x='tokens', y='realtime_factor', scatter=False, color='#05d9e8', line_kws={'linewidth': 2})
-    corr = df['tokens'].corr(df['realtime_factor'])
-    plt.text(0.05, 0.95, f'Correlation: {corr:.2f}', transform=ax.transAxes, fontsize=10, color='#ffffff',
-             bbox=dict(facecolor='#1a1a2e', edgecolor='#ffffff', alpha=0.7))
-    setup_plot(fig, ax, 'Realtime Factor vs Input Size')
-    ax.set_xlabel('Number of Input Tokens')
-    ax.set_ylabel('Realtime Factor (output length / processing time)')
-    plt.savefig('examples/benchmarks/realtime_factor.png', dpi=300, bbox_inches='tight')
+    sns.scatterplot(
+        data=df, x="tokens", y="realtime_factor", s=100, alpha=0.6, color="#ff2a6d"
+    )
+    sns.regplot(
+        data=df,
+        x="tokens",
+        y="realtime_factor",
+        scatter=False,
+        color="#05d9e8",
+        line_kws={"linewidth": 2},
+    )
+    corr = df["tokens"].corr(df["realtime_factor"])
+    plt.text(
+        0.05,
+        0.95,
+        f"Correlation: {corr:.2f}",
+        transform=ax.transAxes,
+        fontsize=10,
+        color="#ffffff",
+        bbox=dict(facecolor="#1a1a2e", edgecolor="#ffffff", alpha=0.7),
+    )
+    setup_plot(fig, ax, "Realtime Factor vs Input Size")
+    ax.set_xlabel("Number of Input Tokens")
+    ax.set_ylabel("Realtime Factor (output length / processing time)")
+    plt.savefig("examples/benchmarks/realtime_factor.png", dpi=300, bbox_inches="tight")
     plt.close()

     # Plot system metrics

@@ -329,9 +396,10 @@ def main():
     print("- examples/benchmarks/processing_time.png")
     print("- examples/benchmarks/realtime_factor.png")
     print("- examples/benchmarks/system_usage.png")
-    if any('gpu_memory_used' in m for m in system_metrics):
+    if any("gpu_memory_used" in m for m in system_metrics):
         print("- examples/benchmarks/gpu_usage.png")
     print("\nAudio files saved in examples/benchmarks/output/")

-if __name__ == '__main__':
+
+if __name__ == "__main__":
     main()
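One benchmark detail worth noting: the plots smooth the raw psutil samples with a small centered rolling mean (window = min(5, len(df) // 2)) rather than a Savitzky-Golay filter. A standalone sketch of that smoothing step, with made-up samples:

import pandas as pd

# Made-up CPU readings standing in for the benchmark's psutil samples.
df = pd.DataFrame({"cpu_percent": [10, 50, 12, 48, 11, 47, 13, 49]})

window = min(5, len(df) // 2)  # same cap as plot_system_metrics
smoothed = df["cpu_percent"].rolling(window=window, center=True).mean()
print(smoothed.tolist())  # NaN at the edges where the window is incomplete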
examples/test_all_voices.py (new file, +59 lines)

@@ -0,0 +1,59 @@
+from pathlib import Path
+import openai
+import requests
+
+SAMPLE_TEXT = """
+That is the germ of my great discovery. But you are wrong to say that we cannot move about in Time. For instance, if I am recalling an incident very vividly I go back to the instant of its occurrence: I become absent-minded, as you say. I jump back for a moment.
+"""
+
+# Configure OpenAI client to use our local endpoint
+client = openai.OpenAI(
+    timeout=60,
+    api_key="notneeded",  # API key not required for our endpoint
+    base_url="http://localhost:8880/v1",  # Point to our local server with v1 prefix
+)
+
+# Create output directory if it doesn't exist
+output_dir = Path(__file__).parent / "output"
+output_dir.mkdir(exist_ok=True)
+
+
+def test_voice(voice: str):
+    speech_file = output_dir / f"speech_{voice}.wav"
+    print(f"\nTesting voice: {voice}")
+    print(f"Making request to {client.base_url}/audio/speech...")
+
+    try:
+        response = client.audio.speech.create(
+            model="tts-1", voice=voice, input=SAMPLE_TEXT, response_format="wav"
+        )
+
+        print("Got response, saving to file...")
+        with open(speech_file, "wb") as f:
+            f.write(response.content)
+        print(f"Success! Saved to: {speech_file}")
+
+    except Exception as e:
+        print(f"Error with voice {voice}: {str(e)}")
+
+
+# First, get list of available voices using requests
+print("Getting list of available voices...")
+try:
+    # Convert base_url to string and ensure no double slashes
+    base_url = str(client.base_url).rstrip("/")
+    response = requests.get(f"{base_url}/audio/voices")
+    if response.status_code != 200:
+        raise Exception(f"Failed to get voices: {response.text}")
+    data = response.json()
+    if "voices" not in data:
+        raise Exception(f"Unexpected response format: {data}")
+    voices = data["voices"]
+    print(f"Found {len(voices)} voices: {', '.join(voices)}")
+
+    # Test each voice
+    for voice in voices:
+        test_voice(voice)
+
+except Exception as e:
+    print(f"Error getting voices: {str(e)}")
@@ -5,34 +5,35 @@ import openai
 client = openai.OpenAI(
     timeout=30,
     api_key="notneeded",  # API key not required for our endpoint
-    base_url="http://localhost:8880/v1"  # Point to our local server with v1 prefix
+    base_url="http://localhost:8880/v1",  # Point to our local server with v1 prefix
 )

 # Create output directory if it doesn't exist
 output_dir = Path(__file__).parent / "output"
 output_dir.mkdir(exist_ok=True)

-def test_format(format: str, text: str = "The quick brown fox jumped over the lazy dog."):
+
+def test_format(
+    format: str, text: str = "The quick brown fox jumped over the lazy dog."
+):
     speech_file = output_dir / f"speech_{format}.{format}"
     print(f"\nTesting {format} format...")
     print(f"Making request to {client.base_url}/audio/speech...")

     try:
         response = client.audio.speech.create(
-            model="tts-1",
-            voice="af",
-            input=text,
-            response_format=format
+            model="tts-1", voice="af", input=text, response_format=format
         )

-        print(f"Got response, saving to file...")
-        with open(speech_file, 'wb') as f:
+        print("Got response, saving to file...")
+        with open(speech_file, "wb") as f:
             f.write(response.content)
         print(f"Success! Saved to: {speech_file}")

     except Exception as e:
         print(f"Error: {str(e)}")


 def test_speed(speed: float):
     speech_file = output_dir / f"speech_speed_{speed}.wav"
     print(f"\nTesting speed {speed}x...")

@@ -44,17 +45,18 @@ def test_speed(speed: float):
             voice="af",
             input="The quick brown fox jumped over the lazy dog.",
             response_format="wav",
-            speed=speed
+            speed=speed,
         )

-        print(f"Got response, saving to file...")
-        with open(speech_file, 'wb') as f:
+        print("Got response, saving to file...")
+        with open(speech_file, "wb") as f:
             f.write(response.content)
         print(f"Success! Saved to: {speech_file}")

     except Exception as e:
         print(f"Error: {str(e)}")


 # Test different formats
 for format in ["wav", "mp3", "opus", "aac", "flac", "pcm"]:
     test_format(format)

@@ -64,6 +66,9 @@ for speed in [0.25, 1.0, 2.0, 4.0]:  # 5.0 should fail as it's out of range
     test_speed(speed)

 # Test long text
-test_format("wav", """
+test_format(
+    "wav",
+    """
 That is the germ of my great discovery. But you are wrong to say that we cannot move about in Time. For instance, if I am recalling an incident very vividly I go back to the instant of its occurrence: I become absent-minded, as you say. I jump back for a moment.
-""")
+""",
+)
requirements.txt

@@ -24,6 +24,7 @@ tqdm==4.67.1
 requests==2.32.3
 munch==4.0.0
 tiktoken===0.8.0
+loguru==0.7.3

 # Testing
 pytest==8.0.0