diff --git a/api/src/main.py b/api/src/main.py
index 78c588a..362602c 100644
--- a/api/src/main.py
+++ b/api/src/main.py
@@ -1,26 +1,38 @@
 """
 FastAPI OpenAI Compatible API
 """
+
 import uvicorn
-import logging
-import sys
 from contextlib import asynccontextmanager
 
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
+from loguru import logger
 
 from .core.config import settings
 from .routers.openai_compatible import router as openai_router
-# from services.tts import TTSModel
+from .services.tts import TTSModel, TTSService
 
-logger = logging.getLogger(__name__)
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    """Lifespan context manager for database initialization"""
+    """Lifespan context manager for model initialization"""
+    logger.info("Loading TTS model and voice packs...")
+    # Initialize the main model
+    model, device = TTSModel.get_instance()
+    logger.info(f"Model loaded on {device}")
+    # Initialize all voice packs
+    tts_service = TTSService()
+    voices = tts_service.list_voices()
+    for voice in voices:
+        logger.info(f"Loading voice pack: {voice}")
+        TTSModel.get_voicepack(voice)
+
+    logger.info("All models and voice packs loaded successfully")
     yield
 
+
 # Initialize FastAPI app
 app = FastAPI(
     title=settings.api_title,
@@ -42,16 +54,19 @@ app.add_middleware(
 # Include OpenAI compatible router
 app.include_router(openai_router, prefix="/v1")
 
+
 # Health check endpoint
 @app.get("/health")
 async def health_check():
     """Health check endpoint"""
     return {"status": "healthy"}
 
+
 @app.get("/v1/test")
 async def test_endpoint():
     """Test endpoint to verify routing"""
     return {"status": "ok"}
 
+
 if __name__ == "__main__":
     uvicorn.run("api.src.main:app", host=settings.host, port=settings.port, reload=True)
diff --git a/api/src/routers/__init__.py b/api/src/routers/__init__.py
index 4287ca8..792d600 100644
--- a/api/src/routers/__init__.py
+++ b/api/src/routers/__init__.py
@@ -1 +1 @@
-#
\ No newline at end of file
+#
diff --git a/api/src/routers/openai_compatible.py b/api/src/routers/openai_compatible.py
index b9f8da1..df20d66 100644
--- a/api/src/routers/openai_compatible.py
+++ b/api/src/routers/openai_compatible.py
@@ -1,5 +1,4 @@
 from fastapi import APIRouter, HTTPException, Response, Depends
-from sqlalchemy.orm import Session
 import logging
 from ..structures.schemas import OpenAISpeechRequest
 from ..services.tts import TTSService
@@ -12,14 +11,17 @@ router = APIRouter(
     responses={404: {"description": "Not found"}},
 )
 
+
 def get_tts_service() -> TTSService:
     """Dependency to get TTSService instance with database session"""
-    return TTSService(start_worker=False)  # Don't start worker thread for OpenAI endpoint
+    return TTSService(
+        start_worker=False
+    )  # Don't start worker thread for OpenAI endpoint
+
 
 @router.post("/audio/speech")
 async def create_speech(
-    request: OpenAISpeechRequest,
-    tts_service: TTSService = Depends(get_tts_service)
+    request: OpenAISpeechRequest, tts_service: TTSService = Depends(get_tts_service)
 ):
     """OpenAI-compatible endpoint for text-to-speech"""
     try:
@@ -28,28 +30,27 @@ async def create_speech(
             text=request.input,
             voice=request.voice,
             speed=request.speed,
-            stitch_long_output=True
+            stitch_long_output=True,
         )
-        
+
         # Convert to requested format
         content = AudioService.convert_audio(audio, 24000, request.response_format)
-        
+
         return Response(
             content=content,
             media_type=f"audio/{request.response_format}",
             headers={
                 "Content-Disposition": f"attachment; filename=speech.{request.response_format}"
-            }
+            },
         )
-        
+
     except Exception as e:
         logger.error(f"Error generating speech: {str(e)}")
         raise HTTPException(status_code=500, detail=str(e))
 
+
 @router.get("/audio/voices")
-async def list_voices(
-    tts_service: TTSService = Depends(get_tts_service)
-):
+async def list_voices(tts_service: TTSService = Depends(get_tts_service)):
     """List all available voices for text-to-speech"""
     try:
         voices = tts_service.list_voices()
diff --git a/api/src/services/audio.py b/api/src/services/audio.py
index e4e0d6f..d3408ff 100644
--- a/api/src/services/audio.py
+++ b/api/src/services/audio.py
@@ -1,4 +1,5 @@
 """Audio conversion service"""
+
 from io import BytesIO
 import numpy as np
 import scipy.io.wavfile as wavfile
@@ -7,60 +8,69 @@ import logging
 
 logger = logging.getLogger(__name__)
 
+
 class AudioService:
     """Service for audio format conversions"""
-    
+
     @staticmethod
-    def convert_audio(audio_data: np.ndarray, sample_rate: int, output_format: str) -> bytes:
+    def convert_audio(
+        audio_data: np.ndarray, sample_rate: int, output_format: str
+    ) -> bytes:
         """Convert audio data to specified format
-        
+
         Args:
             audio_data: Numpy array of audio samples
             sample_rate: Sample rate of the audio
            output_format: Target format (wav, mp3, etc.)
-        
+
         Returns:
             Bytes of the converted audio
         """
         buffer = BytesIO()
-        
+
         try:
-            if output_format == 'wav':
+            if output_format == "wav":
                 logger.info("Writing to WAV format...")
                 wavfile.write(buffer, sample_rate, audio_data)
                 return buffer.getvalue()
-            
-            elif output_format == 'mp3':
+
+            elif output_format == "mp3":
                 # For MP3, we need to convert to WAV first
                 logger.info("Converting to MP3 format...")
                 wav_buffer = BytesIO()
                 wavfile.write(wav_buffer, sample_rate, audio_data)
                 wav_buffer.seek(0)
-            
+
                 # Convert WAV to MP3 using soundfile
                 buffer = BytesIO()
-                sf.write(buffer, audio_data, sample_rate, format='mp3')
+                sf.write(buffer, audio_data, sample_rate, format="mp3")
                 return buffer.getvalue()
-            
-            elif output_format == 'opus':
+
+            elif output_format == "opus":
                 logger.info("Converting to Opus format...")
-                sf.write(buffer, audio_data, sample_rate, format='ogg', subtype='opus')
+                sf.write(buffer, audio_data, sample_rate, format="ogg", subtype="opus")
                 return buffer.getvalue()
-            
-            elif output_format == 'flac':
+
+            elif output_format == "flac":
                 logger.info("Converting to FLAC format...")
-                sf.write(buffer, audio_data, sample_rate, format='flac')
+                sf.write(buffer, audio_data, sample_rate, format="flac")
                 return buffer.getvalue()
-            
-            elif output_format == 'aac':
-                raise ValueError("AAC format is not currently supported. Please use wav, mp3, opus, or flac.")
-            
-            elif output_format == 'pcm':
-                raise ValueError("PCM format is not currently supported. Please use wav, mp3, opus, or flac.")
-            
+
+            elif output_format == "aac":
+                raise ValueError(
+                    "AAC format is not currently supported. Please use wav, mp3, opus, or flac."
+                )
+
+            elif output_format == "pcm":
+                raise ValueError(
+                    "PCM format is not currently supported. Please use wav, mp3, opus, or flac."
+                )
+
             else:
-                raise ValueError(f"Format {output_format} not supported. Supported formats are: wav, mp3, opus, flac.")
-        
+                raise ValueError(
+                    f"Format {output_format} not supported. Supported formats are: wav, mp3, opus, flac."
+                )
+
         except Exception as e:
             logger.error(f"Error converting audio to {output_format}: {str(e)}")
             raise ValueError(f"Failed to convert audio to {output_format}: {str(e)}")
diff --git a/api/src/services/tts.py b/api/src/services/tts.py
index bf1de81..35d9bcb 100644
--- a/api/src/services/tts.py
+++ b/api/src/services/tts.py
@@ -2,8 +2,7 @@ import os
 import threading
 import time
 import io
-from typing import Optional, List, Tuple
-from sqlalchemy.orm import Session
+from typing import List, Tuple
 import numpy as np
 import torch
 import scipy.io.wavfile as wavfile
@@ -17,6 +16,7 @@ import tiktoken
 logger = logging.getLogger(__name__)
 enc = tiktoken.get_encoding("cl100k_base")
 
+
 class TTSModel:
     _instance = None
     _lock = threading.Lock()
@@ -40,7 +40,9 @@ class TTSModel:
         model, device = cls.get_instance()
         if voice_name not in cls._voicepacks:
             try:
-                voice_path = os.path.join(settings.model_dir, settings.voices_dir, f"{voice_name}.pt")
+                voice_path = os.path.join(
+                    settings.model_dir, settings.voices_dir, f"{voice_name}.pt"
+                )
                 voicepack = torch.load(
                     voice_path, map_location=device, weights_only=True
                 )
@@ -61,9 +63,11 @@ class TTSService:
 
     def _split_text(self, text: str) -> List[str]:
         """Split text into sentences"""
-        return [s.strip() for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()]
+        return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]
 
-    def _generate_audio(self, text: str, voice: str, speed: float, stitch_long_output: bool = True) -> Tuple[torch.Tensor, float]:
+    def _generate_audio(
+        self, text: str, voice: str, speed: float, stitch_long_output: bool = True
+    ) -> Tuple[torch.Tensor, float]:
         """Generate audio and measure processing time"""
         start_time = time.time()
 
@@ -87,22 +91,34 @@ class TTSService:
                     # Validate phonemization first
                     ps = phonemize(chunk, voice[0])
                     tokens = tokenize(ps)
-                    logger.info(f"Processing chunk {i+1}/{len(chunks)}: {len(tokens)} tokens")
-                    
+                    logger.info(
+                        f"Processing chunk {i+1}/{len(chunks)}: {len(tokens)} tokens"
+                    )
+
                     # Only proceed if phonemization succeeded
-                    chunk_audio, _ = generate(model, chunk, voicepack, lang=voice[0], speed=speed)
+                    chunk_audio, _ = generate(
+                        model, chunk, voicepack, lang=voice[0], speed=speed
+                    )
                     if chunk_audio is not None:
                         audio_chunks.append(chunk_audio)
                     else:
-                        logger.error(f"No audio generated for chunk {i+1}/{len(chunks)}")
+                        logger.error(
+                            f"No audio generated for chunk {i+1}/{len(chunks)}"
+                        )
                 except Exception as e:
-                    logger.error(f"Failed to generate audio for chunk {i+1}/{len(chunks)}: '{chunk}'. Error: {str(e)}")
+                    logger.error(
+                        f"Failed to generate audio for chunk {i+1}/{len(chunks)}: '{chunk}'. Error: {str(e)}"
+                    )
                     continue
 
             if not audio_chunks:
                 raise ValueError("No audio chunks were generated successfully")
 
-            audio = np.concatenate(audio_chunks) if len(audio_chunks) > 1 else audio_chunks[0]
+            audio = (
+                np.concatenate(audio_chunks)
+                if len(audio_chunks) > 1
+                else audio_chunks[0]
+            )
         else:
             audio, _ = generate(model, text, voicepack, lang=voice[0], speed=speed)
 
diff --git a/api/src/structures/schemas.py b/api/src/structures/schemas.py
index d24e4d5..bb00fc7 100644
--- a/api/src/structures/schemas.py
+++ b/api/src/structures/schemas.py
@@ -15,17 +15,24 @@ class TTSStatus(str, Enum):
 class OpenAISpeechRequest(BaseModel):
     model: Literal["tts-1", "tts-1-hd"] = "tts-1"
     input: str = Field(..., description="The text to generate audio for")
-    voice: Literal["am_adam", "am_michael", "bm_lewis", "af", "bm_george", "bf_isabella", "bf_emma", "af_sarah", "af_bella"] = Field(
-        default="af",
-        description="The voice to use for generation"
-    )
+    voice: Literal[
+        "am_adam",
+        "am_michael",
+        "bm_lewis",
+        "af",
+        "bm_george",
+        "bf_isabella",
+        "bf_emma",
+        "af_sarah",
+        "af_bella",
+    ] = Field(default="af", description="The voice to use for generation")
     response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field(
         default="mp3",
-        description="The format to return audio in. Supported formats: mp3, opus, flac, wav. AAC and PCM are not currently supported."
+        description="The format to return audio in. Supported formats: mp3, opus, flac, wav. AAC and PCM are not currently supported.",
     )
     speed: float = Field(
         default=1.0,
         ge=0.25,
         le=4.0,
-        description="The speed of the generated audio. Select a value from 0.25 to 4.0."
+        description="The speed of the generated audio. Select a value from 0.25 to 4.0.",
     )
diff --git a/api/tests/conftest.py b/api/tests/conftest.py
index c367c92..ecb8229 100644
--- a/api/tests/conftest.py
+++ b/api/tests/conftest.py
@@ -3,17 +3,15 @@ from unittest.mock import Mock, patch
 import sys
 
 # Mock torch and other ML modules before they're imported
-sys.modules['torch'] = Mock()
-sys.modules['transformers'] = Mock()
-sys.modules['phonemizer'] = Mock()
-sys.modules['models'] = Mock()
-sys.modules['models.build_model'] = Mock()
-sys.modules['kokoro'] = Mock()
-sys.modules['kokoro.generate'] = Mock()
-sys.modules['kokoro.phonemize'] = Mock()
-sys.modules['kokoro.tokenize'] = Mock()
-
-from api.src.main import app
+sys.modules["torch"] = Mock()
+sys.modules["transformers"] = Mock()
+sys.modules["phonemizer"] = Mock()
+sys.modules["models"] = Mock()
+sys.modules["models.build_model"] = Mock()
+sys.modules["kokoro"] = Mock()
+sys.modules["kokoro.generate"] = Mock()
+sys.modules["kokoro.phonemize"] = Mock()
+sys.modules["kokoro.tokenize"] = Mock()
 
 
 @pytest.fixture(autouse=True)
diff --git a/api/tests/test_endpoints.py b/api/tests/test_endpoints.py
index a779679..c2223f0 100644
--- a/api/tests/test_endpoints.py
+++ b/api/tests/test_endpoints.py
@@ -1,27 +1,44 @@
 from fastapi.testclient import TestClient
 import pytest
-from unittest.mock import Mock, patch
+from unittest.mock import Mock
 from ..src.main import app
-from ..src.services.tts import TTSService
-from ..src.routers.openai_compatible import TTSService as OpenAITTSService
 
 # Create test client
 client = TestClient(app)
 
+
 # Mock services
 @pytest.fixture
 def mock_tts_service(monkeypatch):
     mock_service = Mock()
     mock_service._generate_audio.return_value = (bytes([0, 1, 2, 3]), 1.0)
-    mock_service.list_voices.return_value = ["af", "bm_lewis", "bf_isabella", "bf_emma", "af_sarah", "af_bella", "am_adam", "am_michael", "bm_george"]
-    monkeypatch.setattr("api.src.routers.openai_compatible.TTSService", lambda *args, **kwargs: mock_service)
+    mock_service.list_voices.return_value = [
+        "af",
+        "bm_lewis",
+        "bf_isabella",
+        "bf_emma",
+        "af_sarah",
+        "af_bella",
+        "am_adam",
+        "am_michael",
+        "bm_george",
+    ]
+    monkeypatch.setattr(
+        "api.src.routers.openai_compatible.TTSService",
+        lambda *args, **kwargs: mock_service,
+    )
     return mock_service
 
+
 @pytest.fixture
 def mock_audio_service(monkeypatch):
     def mock_convert(*args):
         return b"converted mock audio data"
-    monkeypatch.setattr("api.src.routers.openai_compatible.AudioService.convert_audio", mock_convert)
+
+    monkeypatch.setattr(
+        "api.src.routers.openai_compatible.AudioService.convert_audio", mock_convert
+    )
+
 
 def test_health_check():
     """Test the health check endpoint"""
@@ -29,6 +46,7 @@ def test_health_check():
     assert response.status_code == 200
     assert response.json() == {"status": "healthy"}
 
+
 def test_openai_speech_endpoint(mock_tts_service, mock_audio_service):
     """Test the OpenAI-compatible speech endpoint"""
     test_request = {
@@ -36,20 +54,18 @@ def test_openai_speech_endpoint(mock_tts_service, mock_audio_service):
         "input": "Hello world",
         "voice": "bm_lewis",
         "response_format": "wav",
-        "speed": 1.0
+        "speed": 1.0,
     }
     response = client.post("/v1/audio/speech", json=test_request)
     assert response.status_code == 200
     assert response.headers["content-type"] == "audio/wav"
     assert response.headers["content-disposition"] == "attachment; filename=speech.wav"
     mock_tts_service._generate_audio.assert_called_once_with(
-        text="Hello world",
-        voice="bm_lewis",
-        speed=1.0,
-        stitch_long_output=True
+        text="Hello world", voice="bm_lewis", speed=1.0, stitch_long_output=True
     )
     assert response.content == b"converted mock audio data"
 
+
 def test_openai_speech_invalid_voice(mock_tts_service):
     """Test the OpenAI-compatible speech endpoint with invalid voice"""
     test_request = {
@@ -57,11 +73,12 @@ def test_openai_speech_invalid_voice(mock_tts_service):
         "input": "Hello world",
         "voice": "invalid_voice",
         "response_format": "wav",
-        "speed": 1.0
+        "speed": 1.0,
     }
     response = client.post("/v1/audio/speech", json=test_request)
     assert response.status_code == 422  # Validation error
 
+
 def test_openai_speech_invalid_speed(mock_tts_service):
     """Test the OpenAI-compatible speech endpoint with invalid speed"""
     test_request = {
@@ -69,11 +86,12 @@ def test_openai_speech_invalid_speed(mock_tts_service):
         "input": "Hello world",
         "voice": "af",
         "response_format": "wav",
-        "speed": -1.0  # Invalid speed
+        "speed": -1.0,  # Invalid speed
     }
     response = client.post("/v1/audio/speech", json=test_request)
     assert response.status_code == 422  # Validation error
 
+
 def test_openai_speech_generation_error(mock_tts_service):
     """Test error handling in speech generation"""
     mock_tts_service._generate_audio.side_effect = Exception("Generation failed")
@@ -82,7 +100,7 @@ def test_openai_speech_generation_error(mock_tts_service):
         "input": "Hello world",
         "voice": "af",
         "response_format": "wav",
-        "speed": 1.0
+        "speed": 1.0,
     }
     response = client.post("/v1/audio/speech", json=test_request)
     assert response.status_code == 500
diff --git a/examples/benchmarks/benchmark_tts.py b/examples/benchmarks/benchmark_tts.py
index f95ce29..61b51f1 100644
--- a/examples/benchmarks/benchmark_tts.py
+++ b/examples/benchmarks/benchmark_tts.py
@@ -3,43 +3,43 @@ import time
 import json
 import scipy.io.wavfile as wavfile
 import requests
-import numpy as np
 import pandas as pd
 import seaborn as sns
 import matplotlib.pyplot as plt
-from scipy.signal import savgol_filter
-import tiktoken
+import tiktoken
 import psutil
 import subprocess
 from datetime import datetime
 
 enc = tiktoken.get_encoding("cl100k_base")
 
+
 def setup_plot(fig, ax, title):
     """Configure plot styling"""
     # Improve grid
-    ax.grid(True, linestyle='--', alpha=0.3, color='#ffffff')
-    
+    ax.grid(True, linestyle="--", alpha=0.3, color="#ffffff")
+
     # Set title and labels with better fonts
-    ax.set_title(title, pad=20, fontsize=16, fontweight='bold', color='#ffffff')
-    ax.set_xlabel(ax.get_xlabel(), fontsize=14, fontweight='medium', color='#ffffff')
-    ax.set_ylabel(ax.get_ylabel(), fontsize=14, fontweight='medium', color='#ffffff')
-    
+    ax.set_title(title, pad=20, fontsize=16, fontweight="bold", color="#ffffff")
+    ax.set_xlabel(ax.get_xlabel(), fontsize=14, fontweight="medium", color="#ffffff")
+    ax.set_ylabel(ax.get_ylabel(), fontsize=14, fontweight="medium", color="#ffffff")
+
     # Improve tick labels
-    ax.tick_params(labelsize=12, colors='#ffffff')
-    
+    ax.tick_params(labelsize=12, colors="#ffffff")
+
     # Style spines
     for spine in ax.spines.values():
-        spine.set_color('#ffffff')
+        spine.set_color("#ffffff")
         spine.set_alpha(0.3)
         spine.set_linewidth(0.5)
-    
+
     # Set background colors
-    ax.set_facecolor('#1a1a2e')
-    fig.patch.set_facecolor('#1a1a2e')
-    
+    ax.set_facecolor("#1a1a2e")
+    fig.patch.set_facecolor("#1a1a2e")
+
     return fig, ax
 
+
 def get_text_for_tokens(text: str, num_tokens: int) -> str:
     """Get a slice of text that contains exactly num_tokens tokens"""
     tokens = enc.encode(text)
@@ -47,14 +47,15 @@ def get_text_for_tokens(text: str, num_tokens: int) -> str:
         return text
     return enc.decode(tokens[:num_tokens])
 
+
 def get_audio_length(audio_data: bytes) -> float:
     """Get audio length in seconds from bytes data"""
     # Save to a temporary file
-    temp_path = 'examples/benchmarks/output/temp.wav'
+    temp_path = "examples/benchmarks/output/temp.wav"
     os.makedirs(os.path.dirname(temp_path), exist_ok=True)
-    with open(temp_path, 'wb') as f:
+    with open(temp_path, "wb") as f:
         f.write(audio_data)
-    
+
     # Read the audio file
     try:
         rate, data = wavfile.read(temp_path)
@@ -64,60 +65,65 @@ def get_audio_length(audio_data: bytes) -> float:
         if os.path.exists(temp_path):
             os.remove(temp_path)
 
+
 def get_gpu_memory():
     """Get GPU memory usage using nvidia-smi"""
     try:
-        result = subprocess.check_output(['nvidia-smi', '--query-gpu=memory.used', '--format=csv,nounits,noheader'])
-        return float(result.decode('utf-8').strip())
+        result = subprocess.check_output(
+            ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader"]
+        )
+        return float(result.decode("utf-8").strip())
     except (subprocess.CalledProcessError, FileNotFoundError):
         return None
 
+
 def get_system_metrics():
     """Get current system metrics"""
     metrics = {
-        'timestamp': datetime.now().isoformat(),
-        'cpu_percent': psutil.cpu_percent(),
-        'ram_percent': psutil.virtual_memory().percent,
-        'ram_used_gb': psutil.virtual_memory().used / (1024**3),
+        "timestamp": datetime.now().isoformat(),
+        "cpu_percent": psutil.cpu_percent(),
+        "ram_percent": psutil.virtual_memory().percent,
+        "ram_used_gb": psutil.virtual_memory().used / (1024**3),
     }
-    
+
     gpu_mem = get_gpu_memory()
     if gpu_mem is not None:
-        metrics['gpu_memory_used'] = gpu_mem
-    
+        metrics["gpu_memory_used"] = gpu_mem
+
     return metrics
 
+
 def make_tts_request(text: str, timeout: int = 120) -> tuple[float, float]:
     """Make TTS request using OpenAI-compatible endpoint and return processing time and output length"""
     try:
         start_time = time.time()
-        
+
         # Make request to OpenAI-compatible endpoint
         response = requests.post(
-            'http://localhost:8880/v1/audio/speech',
+            "http://localhost:8880/v1/audio/speech",
             json={
-                'model': 'tts-1',
-                'input': text,
-                'voice': 'af',
-                'response_format': 'wav'
+                "model": "tts-1",
+                "input": text,
+                "voice": "af",
+                "response_format": "wav",
             },
-            timeout=timeout
+            timeout=timeout,
         )
         response.raise_for_status()
-        
+
         processing_time = time.time() - start_time
         audio_length = get_audio_length(response.content)
-        
+
         # Save the audio file
         token_count = len(enc.encode(text))
-        output_file = f'examples/benchmarks/output/chunk_{token_count}_tokens.wav'
+        output_file = f"examples/benchmarks/output/chunk_{token_count}_tokens.wav"
         os.makedirs(os.path.dirname(output_file), exist_ok=True)
-        with open(output_file, 'wb') as f:
+        with open(output_file, "wb") as f:
             f.write(response.content)
         print(f"Saved audio to {output_file}")
-        
+
         return processing_time, audio_length
-        
+
     except requests.exceptions.RequestException as e:
         print(f"Error making request for text: {text[:50]}... Error: {str(e)}")
         return None, None
@@ -125,91 +131,113 @@ def make_tts_request(text: str, timeout: int = 120) -> tuple[float, float]:
         print(f"Error processing text: {text[:50]}... Error: {str(e)}")
         return None, None
 
+
 def plot_system_metrics(metrics_data):
     """Create plots for system metrics over time"""
     df = pd.DataFrame(metrics_data)
-    df['timestamp'] = pd.to_datetime(df['timestamp'])
-    elapsed_time = (df['timestamp'] - df['timestamp'].iloc[0]).dt.total_seconds()
-    
+    df["timestamp"] = pd.to_datetime(df["timestamp"])
+    elapsed_time = (df["timestamp"] - df["timestamp"].iloc[0]).dt.total_seconds()
+
     # Get baseline values (first measurement)
-    baseline_cpu = df['cpu_percent'].iloc[0]
-    baseline_ram = df['ram_used_gb'].iloc[0]
-    baseline_gpu = df['gpu_memory_used'].iloc[0] / 1024 if 'gpu_memory_used' in df.columns else None  # Convert MB to GB
-    
+    baseline_cpu = df["cpu_percent"].iloc[0]
+    baseline_ram = df["ram_used_gb"].iloc[0]
+    baseline_gpu = (
+        df["gpu_memory_used"].iloc[0] / 1024
+        if "gpu_memory_used" in df.columns
+        else None
+    )  # Convert MB to GB
+
     # Convert GPU memory to GB
-    if 'gpu_memory_used' in df.columns:
-        df['gpu_memory_gb'] = df['gpu_memory_used'] / 1024
-    
+    if "gpu_memory_used" in df.columns:
+        df["gpu_memory_gb"] = df["gpu_memory_used"] / 1024
+
     # Set plotting style
-    plt.style.use('dark_background')
-    
+    plt.style.use("dark_background")
+
     # Create figure with 3 subplots (or 2 if no GPU)
-    has_gpu = 'gpu_memory_used' in df.columns
+    has_gpu = "gpu_memory_used" in df.columns
     num_plots = 3 if has_gpu else 2
-    fig, axes = plt.subplots(num_plots, 1, figsize=(15, 5*num_plots))
-    fig.patch.set_facecolor('#1a1a2e')
-    
+    fig, axes = plt.subplots(num_plots, 1, figsize=(15, 5 * num_plots))
+    fig.patch.set_facecolor("#1a1a2e")
+
     # Apply rolling average for smoothing
     window = min(5, len(df) // 2)  # Smaller window for smoother lines
-    
+
     # Plot 1: CPU Usage
-    smoothed_cpu = df['cpu_percent'].rolling(window=window, center=True).mean()
-    sns.lineplot(x=elapsed_time, y=smoothed_cpu, ax=axes[0], color='#ff2a6d', linewidth=2)
-    axes[0].axhline(y=baseline_cpu, color='#05d9e8', linestyle='--', alpha=0.5, label='Baseline')
-    axes[0].set_xlabel('Time (seconds)', fontsize=14)
-    axes[0].set_ylabel('CPU Usage (%)', fontsize=14)
+    smoothed_cpu = df["cpu_percent"].rolling(window=window, center=True).mean()
+    sns.lineplot(
+        x=elapsed_time, y=smoothed_cpu, ax=axes[0], color="#ff2a6d", linewidth=2
+    )
+    axes[0].axhline(
+        y=baseline_cpu, color="#05d9e8", linestyle="--", alpha=0.5, label="Baseline"
+    )
+    axes[0].set_xlabel("Time (seconds)", fontsize=14)
+    axes[0].set_ylabel("CPU Usage (%)", fontsize=14)
     axes[0].tick_params(labelsize=12)
-    axes[0].set_title('CPU Usage Over Time', pad=20, fontsize=16, fontweight='bold')
-    axes[0].set_ylim(0, max(df['cpu_percent']) * 1.1)  # Add 10% padding
+    axes[0].set_title("CPU Usage Over Time", pad=20, fontsize=16, fontweight="bold")
+    axes[0].set_ylim(0, max(df["cpu_percent"]) * 1.1)  # Add 10% padding
     axes[0].legend()
-    
+
     # Plot 2: RAM Usage
-    smoothed_ram = df['ram_used_gb'].rolling(window=window, center=True).mean()
-    sns.lineplot(x=elapsed_time, y=smoothed_ram, ax=axes[1], color='#05d9e8', linewidth=2)
-    axes[1].axhline(y=baseline_ram, color='#ff2a6d', linestyle='--', alpha=0.5, label='Baseline')
-    axes[1].set_xlabel('Time (seconds)', fontsize=14)
-    axes[1].set_ylabel('RAM Usage (GB)', fontsize=14)
+    smoothed_ram = df["ram_used_gb"].rolling(window=window, center=True).mean()
+    sns.lineplot(
+        x=elapsed_time, y=smoothed_ram, ax=axes[1], color="#05d9e8", linewidth=2
+    )
+    axes[1].axhline(
+        y=baseline_ram, color="#ff2a6d", linestyle="--", alpha=0.5, label="Baseline"
+    )
+    axes[1].set_xlabel("Time (seconds)", fontsize=14)
+    axes[1].set_ylabel("RAM Usage (GB)", fontsize=14)
     axes[1].tick_params(labelsize=12)
-    axes[1].set_title('RAM Usage Over Time', pad=20, fontsize=16, fontweight='bold')
-    axes[1].set_ylim(0, max(df['ram_used_gb']) * 1.1)  # Add 10% padding
+    axes[1].set_title("RAM Usage Over Time", pad=20, fontsize=16, fontweight="bold")
+    axes[1].set_ylim(0, max(df["ram_used_gb"]) * 1.1)  # Add 10% padding
     axes[1].legend()
-    
+
     # Plot 3: GPU Memory (if available)
     if has_gpu:
-        smoothed_gpu = df['gpu_memory_gb'].rolling(window=window, center=True).mean()
-        sns.lineplot(x=elapsed_time, y=smoothed_gpu, ax=axes[2], color='#ff2a6d', linewidth=2)
-        axes[2].axhline(y=baseline_gpu, color='#05d9e8', linestyle='--', alpha=0.5, label='Baseline')
-        axes[2].set_xlabel('Time (seconds)', fontsize=14)
-        axes[2].set_ylabel('GPU Memory (GB)', fontsize=14)
+        smoothed_gpu = df["gpu_memory_gb"].rolling(window=window, center=True).mean()
+        sns.lineplot(
+            x=elapsed_time, y=smoothed_gpu, ax=axes[2], color="#ff2a6d", linewidth=2
+        )
+        axes[2].axhline(
+            y=baseline_gpu, color="#05d9e8", linestyle="--", alpha=0.5, label="Baseline"
+        )
+        axes[2].set_xlabel("Time (seconds)", fontsize=14)
+        axes[2].set_ylabel("GPU Memory (GB)", fontsize=14)
         axes[2].tick_params(labelsize=12)
-        axes[2].set_title('GPU Memory Usage Over Time', pad=20, fontsize=16, fontweight='bold')
-        axes[2].set_ylim(0, max(df['gpu_memory_gb']) * 1.1)  # Add 10% padding
+        axes[2].set_title(
+            "GPU Memory Usage Over Time", pad=20, fontsize=16, fontweight="bold"
+        )
+        axes[2].set_ylim(0, max(df["gpu_memory_gb"]) * 1.1)  # Add 10% padding
         axes[2].legend()
-    
+
     # Style all subplots
     for ax in axes:
-        ax.grid(True, linestyle='--', alpha=0.3)
-        ax.set_facecolor('#1a1a2e')
+        ax.grid(True, linestyle="--", alpha=0.3)
+        ax.set_facecolor("#1a1a2e")
         for spine in ax.spines.values():
-            spine.set_color('#ffffff')
+            spine.set_color("#ffffff")
             spine.set_alpha(0.3)
-    
+
     plt.tight_layout()
-    plt.savefig('examples/benchmarks/system_usage.png', dpi=300, bbox_inches='tight')
+    plt.savefig("examples/benchmarks/system_usage.png", dpi=300, bbox_inches="tight")
     plt.close()
 
+
 def main():
     # Create output directory
-    os.makedirs('examples/benchmarks/output', exist_ok=True)
-    
+    os.makedirs("examples/benchmarks/output", exist_ok=True)
+
     # Read input text
-    with open('examples/benchmarks/the_time_machine_hg_wells.txt', 'r', encoding='utf-8') as f:
+    with open(
+        "examples/benchmarks/the_time_machine_hg_wells.txt", "r", encoding="utf-8"
+    ) as f:
         text = f.read()
-    
+
     # Get total tokens in file
     total_tokens = len(enc.encode(text))
     print(f"Total tokens in file: {total_tokens}")
-    
+
     # Generate token sizes with dense sampling at start and increasing intervals
     dense_range = list(range(100, 600, 100))  # 100, 200, 300, 400, 500
     medium_range = [750, 1000, 1500, 2000, 3000]
@@ -218,120 +246,160 @@ def main():
     while current <= total_tokens:
         large_range.append(current)
         current *= 2
-    
+
     token_sizes = dense_range + medium_range + large_range
-    
+
     # Process chunks
     results = []
     system_metrics = []
     test_start_time = time.time()
-    
+
     for num_tokens in token_sizes:
         # Get text slice with exact token count
         chunk = get_text_for_tokens(text, num_tokens)
         actual_tokens = len(enc.encode(chunk))
-        
+
         print(f"\nProcessing chunk with {actual_tokens} tokens:")
         print(f"Text preview: {chunk[:100]}...")
-        
+
         # Collect system metrics before processing
         system_metrics.append(get_system_metrics())
-        
+
         processing_time, audio_length = make_tts_request(chunk)
         if processing_time is None or audio_length is None:
             print("Breaking loop due to error")
             break
-        
+
         # Collect system metrics after processing
         system_metrics.append(get_system_metrics())
-        
-        results.append({
-            'tokens': actual_tokens,
-            'processing_time': processing_time,
-            'output_length': audio_length,
-            'realtime_factor': audio_length / processing_time,
-            'elapsed_time': time.time() - test_start_time
-        })
-        
+
+        results.append(
+            {
+                "tokens": actual_tokens,
+                "processing_time": processing_time,
+                "output_length": audio_length,
+                "realtime_factor": audio_length / processing_time,
+                "elapsed_time": time.time() - test_start_time,
+            }
+        )
+
         # Save intermediate results
-        with open('examples/benchmarks/benchmark_results.json', 'w') as f:
-            json.dump({
-                'results': results,
-                'system_metrics': system_metrics
-            }, f, indent=2)
-    
+        with open("examples/benchmarks/benchmark_results.json", "w") as f:
+            json.dump(
+                {"results": results, "system_metrics": system_metrics}, f, indent=2
+            )
+
     # Create DataFrame and calculate stats
     df = pd.DataFrame(results)
     if df.empty:
         print("No data to plot")
         return
-    
+
     # Calculate useful metrics
-    df['tokens_per_second'] = df['tokens'] / df['processing_time']
-    
+    df["tokens_per_second"] = df["tokens"] / df["processing_time"]
+
     # Write detailed stats
-    with open('examples/benchmarks/benchmark_stats.txt', 'w') as f:
+    with open("examples/benchmarks/benchmark_stats.txt", "w") as f:
         f.write("=== Benchmark Statistics ===\n\n")
-        
+
         f.write("Overall Stats:\n")
         f.write(f"Total tokens processed: {df['tokens'].sum()}\n")
         f.write(f"Total audio generated: {df['output_length'].sum():.2f}s\n")
         f.write(f"Total test duration: {df['elapsed_time'].max():.2f}s\n")
-        f.write(f"Average processing rate: {df['tokens_per_second'].mean():.2f} tokens/second\n")
+        f.write(
+            f"Average processing rate: {df['tokens_per_second'].mean():.2f} tokens/second\n"
+        )
         f.write(f"Average realtime factor: {df['realtime_factor'].mean():.2f}x\n\n")
-        
+
         f.write("Per-chunk Stats:\n")
         f.write(f"Average chunk size: {df['tokens'].mean():.2f} tokens\n")
         f.write(f"Min chunk size: {df['tokens'].min():.2f} tokens\n")
         f.write(f"Max chunk size: {df['tokens'].max():.2f} tokens\n")
         f.write(f"Average processing time: {df['processing_time'].mean():.2f}s\n")
         f.write(f"Average output length: {df['output_length'].mean():.2f}s\n\n")
-        
+
         f.write("Performance Ranges:\n")
-        f.write(f"Processing rate range: {df['tokens_per_second'].min():.2f} - {df['tokens_per_second'].max():.2f} tokens/second\n")
-        f.write(f"Realtime factor range: {df['realtime_factor'].min():.2f}x - {df['realtime_factor'].max():.2f}x\n")
+        f.write(
+            f"Processing rate range: {df['tokens_per_second'].min():.2f} - {df['tokens_per_second'].max():.2f} tokens/second\n"
+        )
+        f.write(
+            f"Realtime factor range: {df['realtime_factor'].min():.2f}x - {df['realtime_factor'].max():.2f}x\n"
+        )
 
     # Set plotting style
-    plt.style.use('dark_background')
-    
+    plt.style.use("dark_background")
+
     # Plot 1: Processing Time vs Token Count
     fig, ax = plt.subplots(figsize=(12, 8))
-    sns.scatterplot(data=df, x='tokens', y='processing_time', s=100, alpha=0.6, color='#ff2a6d')
-    sns.regplot(data=df, x='tokens', y='processing_time', scatter=False, color='#05d9e8', line_kws={'linewidth': 2})
-    corr = df['tokens'].corr(df['processing_time'])
-    plt.text(0.05, 0.95, f'Correlation: {corr:.2f}', transform=ax.transAxes, fontsize=10, color='#ffffff',
-             bbox=dict(facecolor='#1a1a2e', edgecolor='#ffffff', alpha=0.7))
-    setup_plot(fig, ax, 'Processing Time vs Input Size')
-    ax.set_xlabel('Number of Input Tokens')
-    ax.set_ylabel('Processing Time (seconds)')
-    plt.savefig('examples/benchmarks/processing_time.png', dpi=300, bbox_inches='tight')
+    sns.scatterplot(
+        data=df, x="tokens", y="processing_time", s=100, alpha=0.6, color="#ff2a6d"
+    )
+    sns.regplot(
+        data=df,
+        x="tokens",
+        y="processing_time",
+        scatter=False,
+        color="#05d9e8",
+        line_kws={"linewidth": 2},
+    )
+    corr = df["tokens"].corr(df["processing_time"])
+    plt.text(
+        0.05,
+        0.95,
+        f"Correlation: {corr:.2f}",
+        transform=ax.transAxes,
+        fontsize=10,
+        color="#ffffff",
+        bbox=dict(facecolor="#1a1a2e", edgecolor="#ffffff", alpha=0.7),
+    )
+    setup_plot(fig, ax, "Processing Time vs Input Size")
+    ax.set_xlabel("Number of Input Tokens")
+    ax.set_ylabel("Processing Time (seconds)")
+    plt.savefig("examples/benchmarks/processing_time.png", dpi=300, bbox_inches="tight")
     plt.close()
-    
+
     # Plot 2: Realtime Factor vs Token Count
     fig, ax = plt.subplots(figsize=(12, 8))
-    sns.scatterplot(data=df, x='tokens', y='realtime_factor', s=100, alpha=0.6, color='#ff2a6d')
-    sns.regplot(data=df, x='tokens', y='realtime_factor', scatter=False, color='#05d9e8', line_kws={'linewidth': 2})
-    corr = df['tokens'].corr(df['realtime_factor'])
-    plt.text(0.05, 0.95, f'Correlation: {corr:.2f}', transform=ax.transAxes, fontsize=10, color='#ffffff',
-             bbox=dict(facecolor='#1a1a2e', edgecolor='#ffffff', alpha=0.7))
-    setup_plot(fig, ax, 'Realtime Factor vs Input Size')
-    ax.set_xlabel('Number of Input Tokens')
-    ax.set_ylabel('Realtime Factor (output length / processing time)')
-    plt.savefig('examples/benchmarks/realtime_factor.png', dpi=300, bbox_inches='tight')
+    sns.scatterplot(
+        data=df, x="tokens", y="realtime_factor", s=100, alpha=0.6, color="#ff2a6d"
+    )
+    sns.regplot(
+        data=df,
+        x="tokens",
+        y="realtime_factor",
+        scatter=False,
+        color="#05d9e8",
+        line_kws={"linewidth": 2},
+    )
+    corr = df["tokens"].corr(df["realtime_factor"])
+    plt.text(
+        0.05,
+        0.95,
+        f"Correlation: {corr:.2f}",
+        transform=ax.transAxes,
+        fontsize=10,
+        color="#ffffff",
+        bbox=dict(facecolor="#1a1a2e", edgecolor="#ffffff", alpha=0.7),
+    )
+    setup_plot(fig, ax, "Realtime Factor vs Input Size")
+    ax.set_xlabel("Number of Input Tokens")
+    ax.set_ylabel("Realtime Factor (output length / processing time)")
+    plt.savefig("examples/benchmarks/realtime_factor.png", dpi=300, bbox_inches="tight")
     plt.close()
-    
+
     # Plot system metrics
     plot_system_metrics(system_metrics)
-    
+
     print("\nResults saved to:")
     print("- examples/benchmarks/benchmark_results.json")
     print("- examples/benchmarks/benchmark_stats.txt")
     print("- examples/benchmarks/processing_time.png")
     print("- examples/benchmarks/realtime_factor.png")
     print("- examples/benchmarks/system_usage.png")
-    if any('gpu_memory_used' in m for m in system_metrics):
+    if any("gpu_memory_used" in m for m in system_metrics):
        print("- examples/benchmarks/gpu_usage.png")
     print("\nAudio files saved in examples/benchmarks/output/")
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     main()
diff --git a/examples/test_all_voices.py b/examples/test_all_voices.py
new file mode 100644
index 0000000..3f1c88a
--- /dev/null
+++ b/examples/test_all_voices.py
@@ -0,0 +1,59 @@
+from pathlib import Path
+import openai
+import requests
+
+SAMPLE_TEXT = """
+That is the germ of my great discovery. But you are wrong to say that we cannot move about in Time. For instance, if I am recalling an incident very vividly I go back to the instant of its occurrence: I become absent-minded, as you say. I jump back for a moment.
+"""
+
+# Configure OpenAI client to use our local endpoint
+client = openai.OpenAI(
+    timeout=60,
+    api_key="notneeded",  # API key not required for our endpoint
+    base_url="http://localhost:8880/v1",  # Point to our local server with v1 prefix
+)
+
+# Create output directory if it doesn't exist
+output_dir = Path(__file__).parent / "output"
+output_dir.mkdir(exist_ok=True)
+
+
+def test_voice(voice: str):
+    speech_file = output_dir / f"speech_{voice}.wav"
+    print(f"\nTesting voice: {voice}")
+    print(f"Making request to {client.base_url}/audio/speech...")
+
+    try:
+        response = client.audio.speech.create(
+            model="tts-1", voice=voice, input=SAMPLE_TEXT, response_format="wav"
+        )
+
+        print("Got response, saving to file...")
+        with open(speech_file, "wb") as f:
+            f.write(response.content)
+        print(f"Success! Saved to: {speech_file}")
+
+    except Exception as e:
+        print(f"Error with voice {voice}: {str(e)}")
+
+
+# First, get list of available voices using requests
+print("Getting list of available voices...")
+try:
+    # Convert base_url to string and ensure no double slashes
+    base_url = str(client.base_url).rstrip("/")
+    response = requests.get(f"{base_url}/audio/voices")
+    if response.status_code != 200:
+        raise Exception(f"Failed to get voices: {response.text}")
+    data = response.json()
+    if "voices" not in data:
+        raise Exception(f"Unexpected response format: {data}")
+    voices = data["voices"]
+    print(f"Found {len(voices)} voices: {', '.join(voices)}")
+
+    # Test each voice
+    for voice in voices:
+        test_voice(voice)
+
+except Exception as e:
+    print(f"Error getting voices: {str(e)}")
diff --git a/examples/test_openai_tts.py b/examples/test_openai_tts.py
index f3635aa..fd9d7d6 100644
--- a/examples/test_openai_tts.py
+++ b/examples/test_openai_tts.py
@@ -5,56 +5,58 @@ import openai
 client = openai.OpenAI(
     timeout=30,
     api_key="notneeded",  # API key not required for our endpoint
-    base_url="http://localhost:8880/v1"  # Point to our local server with v1 prefix
+    base_url="http://localhost:8880/v1",  # Point to our local server with v1 prefix
 )
 
 # Create output directory if it doesn't exist
 output_dir = Path(__file__).parent / "output"
 output_dir.mkdir(exist_ok=True)
 
-def test_format(format: str, text: str = "The quick brown fox jumped over the lazy dog."):
+
+def test_format(
+    format: str, text: str = "The quick brown fox jumped over the lazy dog."
+):
     speech_file = output_dir / f"speech_{format}.{format}"
     print(f"\nTesting {format} format...")
     print(f"Making request to {client.base_url}/audio/speech...")
-    
+
     try:
         response = client.audio.speech.create(
-            model="tts-1",
-            voice="af",
-            input=text,
-            response_format=format
+            model="tts-1", voice="af", input=text, response_format=format
         )
-        
-        print(f"Got response, saving to file...")
-        with open(speech_file, 'wb') as f:
+
+        print("Got response, saving to file...")
+        with open(speech_file, "wb") as f:
             f.write(response.content)
         print(f"Success! Saved to: {speech_file}")
-    
+
     except Exception as e:
         print(f"Error: {str(e)}")
 
+
 def test_speed(speed: float):
     speech_file = output_dir / f"speech_speed_{speed}.wav"
     print(f"\nTesting speed {speed}x...")
     print(f"Making request to {client.base_url}/audio/speech...")
-    
+
     try:
         response = client.audio.speech.create(
             model="tts-1",
             voice="af",
             input="The quick brown fox jumped over the lazy dog.",
             response_format="wav",
-            speed=speed
+            speed=speed,
         )
-        
-        print(f"Got response, saving to file...")
-        with open(speech_file, 'wb') as f:
+
+        print("Got response, saving to file...")
+        with open(speech_file, "wb") as f:
             f.write(response.content)
         print(f"Success! Saved to: {speech_file}")
-    
+
     except Exception as e:
         print(f"Error: {str(e)}")
 
+
 # Test different formats
 for format in ["wav", "mp3", "opus", "aac", "flac", "pcm"]:
     test_format(format)
@@ -64,6 +66,9 @@ for speed in [0.25, 1.0, 2.0, 4.0]:  # 5.0 should fail as it's out of range
     test_speed(speed)
 
 # Test long text
-test_format("wav", """
+test_format(
+    "wav",
+    """
 That is the germ of my great discovery. But you are wrong to say that we cannot move about in Time. For instance, if I am recalling an incident very vividly I go back to the instant of its occurrence: I become absent-minded, as you say. I jump back for a moment.
-""")
+""",
+)
diff --git a/requirements.txt b/requirements.txt
index d5c3312..284620c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -24,6 +24,7 @@ tqdm==4.67.1
 requests==2.32.3
 munch==4.0.0
 tiktoken===0.8.0
+loguru==0.7.3
 
 # Testing
 pytest==8.0.0