diff --git a/.coverage b/.coverage index 2133e27..3939b52 100644 Binary files a/.coverage and b/.coverage differ diff --git a/.gitignore b/.gitignore index f61cc2d..eaacb2e 100644 --- a/.gitignore +++ b/.gitignore @@ -23,4 +23,5 @@ examples/assorted_checks/test_openai/output/* examples/assorted_checks/test_voices/output/* examples/assorted_checks/test_formats/output/* +examples/assorted_checks/benchmarks/output_audio_stream/* ui/RepoScreenshot.png diff --git a/api/src/main.py b/api/src/main.py index c2a567e..50f31eb 100644 --- a/api/src/main.py +++ b/api/src/main.py @@ -23,8 +23,16 @@ async def lifespan(app: FastAPI): # Initialize the main model with warm-up voicepack_count = TTSModel.setup() + logger.info(""" + ███████╗█████╗█████████████████╗ ██╗██████╗██╗ ██╗██████╗ + ██╔════██╔══████╔════╚══██╔══██║ ██╔██╔═══████║ ██╔██╔═══██╗ + █████╗ ██████████████╗ ██║ █████╔╝██║ ███████╔╝██║ ██║ + ██╔══╝ ██╔══██╚════██║ ██║ ██╔═██╗██║ ████╔═██╗██║ ██║ + ██║ ██║ █████████║ ██║ ██║ ██╚██████╔██║ ██╚██████╔╝ + ╚═╝ ╚═╝ ╚═╚══════╝ ╚═╝ ╚═╝ ╚═╝╚═════╝╚═╝ ╚═╝╚═════╝ """) logger.info(f"Model loaded and warmed up on {TTSModel.get_device()}") logger.info(f"{voicepack_count} voice packs loaded successfully") + logger.info("#" * 80) yield diff --git a/api/src/routers/openai_compatible.py b/api/src/routers/openai_compatible.py index 6663d7b..c8fa610 100644 --- a/api/src/routers/openai_compatible.py +++ b/api/src/routers/openai_compatible.py @@ -2,10 +2,12 @@ from typing import List from loguru import logger from fastapi import Depends, Response, APIRouter, HTTPException +from fastapi.responses import StreamingResponse from ..services.tts_service import TTSService from ..services.audio import AudioService from ..structures.schemas import OpenAISpeechRequest +from typing import AsyncGenerator router = APIRouter( tags=["OpenAI Compatible TTS"], @@ -18,6 +20,16 @@ def get_tts_service() -> TTSService: return TTSService() # Initialize TTSService with default settings +async def stream_audio_chunks(tts_service: TTSService, request: OpenAISpeechRequest) -> AsyncGenerator[bytes, None]: + """Stream audio chunks as they're generated""" + async for chunk in tts_service.generate_audio_stream( + text=request.input, + voice=request.voice, + speed=request.speed, + output_format=request.response_format + ): + yield chunk + @router.post("/audio/speech") async def create_speech( request: OpenAISpeechRequest, tts_service: TTSService = Depends(get_tts_service) @@ -31,24 +43,52 @@ async def create_speech( f"Voice '{request.voice}' not found. 
Available voices: {', '.join(sorted(available_voices))}" ) - # Generate audio directly using TTSService's method - audio, _ = tts_service._generate_audio( - text=request.input, - voice=request.voice, - speed=request.speed, - stitch_long_output=True, - ) + # Set content type based on format + content_type = { + "mp3": "audio/mpeg", + "opus": "audio/opus", + "aac": "audio/aac", + "flac": "audio/flac", + "wav": "audio/wav", + "pcm": "audio/pcm", + }.get(request.response_format, f"audio/{request.response_format}") - # Convert to requested format - content = AudioService.convert_audio(audio, 24000, request.response_format) + if request.stream: + # Stream audio chunks as they're generated + return StreamingResponse( + stream_audio_chunks(tts_service, request), + media_type=content_type, + headers={ + "Content-Disposition": f"attachment; filename=speech.{request.response_format}", + "X-Accel-Buffering": "no", # Disable proxy buffering + "Cache-Control": "no-cache", # Prevent caching + }, + ) + else: + # Generate complete audio + audio, _ = tts_service._generate_audio( + text=request.input, + voice=request.voice, + speed=request.speed, + stitch_long_output=True, + ) - return Response( - content=content, - media_type=f"audio/{request.response_format}", - headers={ - "Content-Disposition": f"attachment; filename=speech.{request.response_format}" - }, - ) + # Convert to requested format + content = AudioService.convert_audio( + audio, + 24000, + request.response_format, + is_first_chunk=True + ) + + return Response( + content=content, + media_type=content_type, + headers={ + "Content-Disposition": f"attachment; filename=speech.{request.response_format}", + "Cache-Control": "no-cache", # Prevent caching + }, + ) except ValueError as e: logger.error(f"Invalid request: {str(e)}") diff --git a/api/src/services/audio.py b/api/src/services/audio.py index b8cc708..4883eed 100644 --- a/api/src/services/audio.py +++ b/api/src/services/audio.py @@ -7,12 +7,35 @@ import soundfile as sf from loguru import logger +class AudioNormalizer: + """Handles audio normalization state for a single stream""" + def __init__(self): + self.int16_max = np.iinfo(np.int16).max + + def normalize(self, audio_data: np.ndarray) -> np.ndarray: + """Normalize audio data to int16 range""" + # Convert to float64 for accurate scaling + audio_float = audio_data.astype(np.float64) + + # Scale to int16 range while preserving relative amplitudes + max_val = np.abs(audio_float).max() + if max_val > 0: + scaling = self.int16_max / max_val + audio_float *= scaling + + # Clip to int16 range and convert + return np.clip(audio_float, -self.int16_max, self.int16_max).astype(np.int16) + class AudioService: """Service for audio format conversions""" - + @staticmethod def convert_audio( - audio_data: np.ndarray, sample_rate: int, output_format: str + audio_data: np.ndarray, + sample_rate: int, + output_format: str, + is_first_chunk: bool = True, + normalizer: AudioNormalizer = None ) -> bytes: """Convert audio data to specified format @@ -20,6 +43,7 @@ class AudioService: audio_data: Numpy array of audio samples sample_rate: Sample rate of the audio output_format: Target format (wav, mp3, opus, flac, pcm) + is_first_chunk: Whether this is the first chunk of a stream Returns: Bytes of the converted audio @@ -27,30 +51,34 @@ class AudioService: buffer = BytesIO() try: - if output_format == "wav": + # Normalize audio if normalizer provided, otherwise just convert to int16 + if normalizer is not None: + normalized_audio = normalizer.normalize(audio_data) + 
else: + normalized_audio = audio_data.astype(np.int16) + + if output_format == "pcm": + logger.info("Writing PCM data...") + # Raw 16-bit PCM samples, no header + buffer.write(normalized_audio.tobytes()) + elif output_format == "wav": logger.info("Writing to WAV format...") - # Ensure audio_data is in int16 format for WAV - audio_data_wav = ( - audio_data / np.abs(audio_data).max() * np.iinfo(np.int16).max - ).astype(np.int16) # Normalize - sf.write(buffer, audio_data_wav, sample_rate, format="WAV") - elif output_format == "mp3": - logger.info("Converting to MP3 format...") - # soundfile can write MP3 if ffmpeg or libsox is installed - sf.write(buffer, audio_data, sample_rate, format="MP3") + # Always include WAV header for WAV format + sf.write(buffer, normalized_audio, sample_rate, format="WAV", subtype='PCM_16') + elif output_format in ["mp3", "aac"]: + logger.info(f"Converting to {output_format.upper()} format...") + # Use lower bitrate for streaming + sf.write(buffer, normalized_audio, sample_rate, format=output_format.upper(), + subtype='COMPRESSED') elif output_format == "opus": logger.info("Converting to Opus format...") - sf.write(buffer, audio_data, sample_rate, format="OGG", subtype="OPUS") + # Use lower bitrate and smaller frame size for streaming + sf.write(buffer, normalized_audio, sample_rate, format="OGG", subtype="OPUS") elif output_format == "flac": logger.info("Converting to FLAC format...") - sf.write(buffer, audio_data, sample_rate, format="FLAC") - elif output_format == "pcm": - logger.info("Extracting PCM data...") - # Ensure audio_data is in int16 format for PCM - audio_data_pcm = ( - audio_data / np.abs(audio_data).max() * np.iinfo(np.int16).max - ).astype(np.int16) # Normalize - buffer.write(audio_data_pcm.tobytes()) + # Use smaller block size for streaming + sf.write(buffer, normalized_audio, sample_rate, format="FLAC", + subtype='PCM_16') else: raise ValueError( f"Format {output_format} not supported. Supported formats are: wav, mp3, opus, flac, pcm." 
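
Note (not part of the diff): a minimal sketch of how the reworked convert_audio() is meant to be driven per chunk during streaming, using the signature introduced above. The import path and the placeholder chunks are assumptions for illustration; inside the package the service uses relative imports instead.

    # Illustrative only: one AudioNormalizer per stream, one convert_audio() call per chunk.
    import numpy as np
    from api.src.services.audio import AudioService, AudioNormalizer  # path assumed from the file layout above

    stream_normalizer = AudioNormalizer()  # created once per stream, as generate_audio_stream does
    chunks = [np.random.randn(24000).astype(np.float32) for _ in range(3)]  # placeholder 1-second chunks at 24 kHz

    for i, chunk in enumerate(chunks):
        chunk_bytes = AudioService.convert_audio(
            chunk,
            24000,                       # Kokoro sample rate used throughout the diff
            "pcm",                       # raw 16-bit samples, no header
            is_first_chunk=(i == 0),
            normalizer=stream_normalizer,
        )
        # chunk_bytes can now be yielded to a StreamingResponse or appended to a file
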
diff --git a/api/src/services/text_processing/normalizer.py b/api/src/services/text_processing/normalizer.py index db5b7db..34f8d4b 100644 --- a/api/src/services/text_processing/normalizer.py +++ b/api/src/services/text_processing/normalizer.py @@ -1,4 +1,5 @@ import re +from functools import lru_cache def split_num(num: re.Match) -> str: """Handle number splitting for various formats""" @@ -48,6 +49,7 @@ def handle_decimal(num: re.Match) -> str: a, b = num.group().split(".") return " point ".join([a, " ".join(b)]) +@lru_cache(maxsize=1000) # Cache normalized text results def normalize_text(text: str) -> str: """Normalize text for TTS processing diff --git a/api/src/services/tts_service.py b/api/src/services/tts_service.py index 6d763fe..66c053b 100644 --- a/api/src/services/tts_service.py +++ b/api/src/services/tts_service.py @@ -3,6 +3,7 @@ import os import re import time from typing import List, Tuple, Optional +from functools import lru_cache import numpy as np import torch @@ -12,6 +13,7 @@ from loguru import logger from ..core.config import settings from .tts_model import TTSModel +from .audio import AudioService, AudioNormalizer class TTSService: @@ -24,6 +26,12 @@ class TTSService: text = str(text) if text is not None else "" return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()] + @staticmethod + @lru_cache(maxsize=20) # Cache up to 8 most recently used voices + def _load_voice(voice_path: str) -> torch.Tensor: + """Load and cache a voice model""" + return torch.load(voice_path, map_location=TTSModel.get_device(), weights_only=True) + def _get_voice_path(self, voice_name: str) -> Optional[str]: """Get the path to a voice file""" voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice_name}.pt") @@ -31,6 +39,13 @@ class TTSService: def _generate_audio( self, text: str, voice: str, speed: float, stitch_long_output: bool = True + ) -> Tuple[torch.Tensor, float]: + """Generate complete audio and return with processing time""" + audio, processing_time = self._generate_audio_internal(text, voice, speed, stitch_long_output) + return audio, processing_time + + def _generate_audio_internal( + self, text: str, voice: str, speed: float, stitch_long_output: bool = True ) -> Tuple[torch.Tensor, float]: """Generate audio and measure processing time""" start_time = time.time() @@ -49,10 +64,8 @@ class TTSService: if not voice_path: raise ValueError(f"Voice not found: {voice}") - # Load voice - voicepack = torch.load( - voice_path, map_location=TTSModel.get_device(), weights_only=True - ) + # Load voice using cached loader + voicepack = self._load_voice(voice_path) # Generate audio with or without stitching if stitch_long_output: @@ -97,6 +110,78 @@ class TTSService: logger.error(f"Error in audio generation: {str(e)}") raise + async def generate_audio_stream( + self, text: str, voice: str, speed: float, output_format: str = "wav" + ): + """Generate and yield audio chunks as they're generated for real-time streaming""" + try: + # Create normalizer for consistent audio levels + stream_normalizer = AudioNormalizer() + + # Input validation and preprocessing + if not text: + raise ValueError("Text is empty") + normalized = normalize_text(text) + if not normalized: + raise ValueError("Text is empty after preprocessing") + text = str(normalized) + + # Voice validation and loading + voice_path = self._get_voice_path(voice) + if not voice_path: + raise ValueError(f"Voice not found: {voice}") + voicepack = self._load_voice(voice_path) + + # Split text into smaller chunks for faster 
streaming + # Use shorter chunks for real-time delivery + chunks = [] + sentences = self._split_text(text) + current_chunk = [] + current_length = 0 + target_length = 100 # Target ~100 characters per chunk for faster processing + + for sentence in sentences: + current_chunk.append(sentence) + current_length += len(sentence) + if current_length >= target_length: + chunks.append(" ".join(current_chunk)) + current_chunk = [] + current_length = 0 + + if current_chunk: + chunks.append(" ".join(current_chunk)) + + # Process and stream chunks + for i, chunk in enumerate(chunks): + try: + # Process text and generate audio + phonemes, tokens = TTSModel.process_text(chunk, voice[0]) + chunk_audio = TTSModel.generate_from_tokens(tokens, voicepack, speed) + + if chunk_audio is not None: + # Convert chunk with proper header handling + chunk_bytes = AudioService.convert_audio( + chunk_audio, + 24000, + output_format, + is_first_chunk=(i == 0), + normalizer=stream_normalizer + ) + yield chunk_bytes + else: + logger.error(f"No audio generated for chunk {i + 1}/{len(chunks)}") + + except Exception as e: + logger.error( + f"Failed to generate audio for chunk {i + 1}/{len(chunks)}: '{chunk}'. Error: {str(e)}" + ) + continue + + except Exception as e: + logger.error(f"Error in audio generation stream: {str(e)}") + raise + + def _save_audio(self, audio: torch.Tensor, filepath: str): """Save audio to file""" os.makedirs(os.path.dirname(filepath), exist_ok=True) diff --git a/api/src/structures/schemas.py b/api/src/structures/schemas.py index bc778bb..92d188e 100644 --- a/api/src/structures/schemas.py +++ b/api/src/structures/schemas.py @@ -22,7 +22,7 @@ class OpenAISpeechRequest(BaseModel): ) response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field( default="mp3", - description="The format to return audio in. Supported formats: mp3, opus, flac, wav. AAC and PCM are not currently supported.", + description="The format to return audio in. Supported formats: mp3, opus, flac, wav, pcm. PCM format returns raw 16-bit samples without headers. AAC is not currently supported.", ) speed: float = Field( default=1.0, @@ -30,3 +30,7 @@ class OpenAISpeechRequest(BaseModel): le=4.0, description="The speed of the generated audio. Select a value from 0.25 to 4.0.", ) + stream: bool = Field( + default=False, + description="If true, audio will be streamed as it's generated. 
Each chunk will be a complete sentence.", + ) diff --git a/docker-compose.yml b/docker-compose.yml index 2e7a86f..7308745 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -46,14 +46,14 @@ services: model-fetcher: condition: service_healthy - # Gradio UI service [Comment out everything below if you don't need it] - gradio-ui: - build: - context: ./ui - ports: - - "7860:7860" - volumes: - - ./ui/data:/app/ui/data - - ./ui/app.py:/app/app.py # Mount app.py for hot reload - environment: - - GRADIO_WATCH=True # Enable hot reloading + # # Gradio UI service [Comment out everything below if you don't need it] + # gradio-ui: + # build: + # context: ./ui + # ports: + # - "7860:7860" + # volumes: + # - ./ui/data:/app/ui/data + # - ./ui/app.py:/app/app.py # Mount app.py for hot reload + # environment: + # - GRADIO_WATCH=True # Enable hot reloading diff --git a/examples/assorted_checks/benchmarks/benchmark_first_token.py b/examples/assorted_checks/benchmarks/benchmark_first_token.py new file mode 100644 index 0000000..6709876 --- /dev/null +++ b/examples/assorted_checks/benchmarks/benchmark_first_token.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 +import os +import time +import json +import numpy as np +import requests +import pandas as pd +from lib.shared_benchmark_utils import get_text_for_tokens, enc +from lib.shared_utils import save_json_results +from lib.shared_plotting import plot_correlation, plot_timeline + +def measure_first_token(text: str, output_dir: str, tokens: int, run_number: int) -> dict: + """Measure time to audio via API calls and save the audio output""" + results = { + "text_length": len(text), + "token_count": len(enc.encode(text)), + "total_time": None, + "time_to_first_chunk": None, + "error": None, + "audio_path": None, + "audio_length": None # Length of output audio in seconds + } + + try: + start_time = time.time() + + # Make request without streaming + response = requests.post( + "http://localhost:8880/v1/audio/speech", + json={ + "model": "kokoro", + "input": text, + "voice": "af", + "response_format": "wav", + "stream": False + }, + timeout=1800 + ) + response.raise_for_status() + + # Save complete audio + audio_filename = f"benchmark_tokens{tokens}_run{run_number}.wav" + audio_path = os.path.join(output_dir, audio_filename) + results["audio_path"] = audio_path + + content = response.content + with open(audio_path, 'wb') as f: + f.write(content) + + # Calculate audio length using scipy + import scipy.io.wavfile as wavfile + sample_rate, audio_data = wavfile.read(audio_path) + results["audio_length"] = len(audio_data) / sample_rate # Length in seconds + results["time_to_first_chunk"] = time.time() - start_time + + results["total_time"] = time.time() - start_time + return results + + except Exception as e: + results["error"] = str(e) + return results + +def main(): + # Set up paths + script_dir = os.path.dirname(os.path.abspath(__file__)) + output_dir = os.path.join(script_dir, "output_audio") + output_data_dir = os.path.join(script_dir, "output_data") + + # Create output directories + os.makedirs(output_dir, exist_ok=True) + os.makedirs(output_data_dir, exist_ok=True) + + # Load sample text + with open(os.path.join(script_dir, "the_time_machine_hg_wells.txt"), "r", encoding="utf-8") as f: + text = f.read() + + # Test specific token counts + token_sizes = [10, 25, 50, 100, 200, 500] + all_results = [] + + for tokens in token_sizes: + print(f"\nTesting {tokens} tokens") + test_text = get_text_for_tokens(text, tokens) + actual_tokens = len(enc.encode(test_text)) + 
print(f"Text preview: {test_text[:50]}...") + + # Run test 3 times for each size to get average + for i in range(5): + print(f"Run {i+1}/3...") + result = measure_first_token(test_text, output_dir, tokens, i + 1) + result["target_tokens"] = tokens + result["actual_tokens"] = actual_tokens + result["run_number"] = i + 1 + + print(f"Time to Audio: {result.get('time_to_first_chunk', 'N/A'):.3f}s") + print(f"Total time: {result.get('total_time', 'N/A'):.3f}s") + + if result["error"]: + print(f"Error: {result['error']}") + + all_results.append(result) + + # Calculate averages per token size + summary = {} + for tokens in token_sizes: + matching_results = [r for r in all_results if r["target_tokens"] == tokens and not r["error"]] + if matching_results: + avg_first_chunk = sum(r["time_to_first_chunk"] for r in matching_results) / len(matching_results) + avg_total = sum(r["total_time"] for r in matching_results) / len(matching_results) + avg_audio_length = sum(r["audio_length"] for r in matching_results) / len(matching_results) + summary[tokens] = { + "avg_time_to_first_chunk": round(avg_first_chunk, 3), + "avg_total_time": round(avg_total, 3), + "avg_audio_length": round(avg_audio_length, 3), + "num_successful_runs": len(matching_results) + } + + # Save results + # Save results + results_data = { + "individual_runs": all_results, + "summary": summary, + "timestamp": time.strftime("%Y-%m-%d %H:%M:%S") + } + save_json_results( + results_data, + os.path.join(output_data_dir, "first_token_benchmark.json") + ) + + # Create plot directory if it doesn't exist + output_plots_dir = os.path.join(script_dir, "output_plots") + os.makedirs(output_plots_dir, exist_ok=True) + + # Create DataFrame for plotting + df = pd.DataFrame(all_results) + + # Create both plots + plot_correlation( + df, "target_tokens", "time_to_first_chunk", + "Time to Audio vs Input Size", + "Number of Input Tokens", + "Time to Audio (seconds)", + os.path.join(output_plots_dir, "first_token_latency.png") + ) + + plot_timeline( + df, + os.path.join(output_plots_dir, "first_token_timeline.png") + ) + + print("\nResults and plots saved to:") + print(f"- {os.path.join(output_data_dir, 'first_token_benchmark.json')}") + print(f"- {os.path.join(output_plots_dir, 'first_token_latency.png')}") + print(f"- {os.path.join(output_plots_dir, 'first_token_timeline.png')}") + +if __name__ == "__main__": + main() diff --git a/examples/assorted_checks/benchmarks/benchmark_first_token_stream.py b/examples/assorted_checks/benchmarks/benchmark_first_token_stream.py new file mode 100644 index 0000000..275cd91 --- /dev/null +++ b/examples/assorted_checks/benchmarks/benchmark_first_token_stream.py @@ -0,0 +1,207 @@ +#!/usr/bin/env python3 +import os +import time +import json +import numpy as np +import requests +import pandas as pd +from lib.shared_benchmark_utils import get_text_for_tokens, enc +from lib.shared_utils import save_json_results +from lib.shared_plotting import plot_correlation, plot_timeline + +def measure_first_token(text: str, output_dir: str, tokens: int, run_number: int) -> dict: + """Measure time to audio via API calls and save the audio output""" + results = { + "text_length": len(text), + "token_count": len(enc.encode(text)), + "total_time": None, + "time_to_first_chunk": None, + "error": None, + "audio_path": None, + "audio_length": None # Length of output audio in seconds + } + + try: + start_time = time.time() + + # Make request with streaming enabled + response = requests.post( + "http://localhost:8880/v1/audio/speech", + json={ + 
"model": "kokoro", + "input": text, + "voice": "af", + "response_format": "wav", + "stream": True + }, + stream=True, + timeout=1800 + ) + response.raise_for_status() + + # Save complete audio + audio_filename = f"benchmark_tokens{tokens}_run{run_number}_stream.wav" + audio_path = os.path.join(output_dir, audio_filename) + results["audio_path"] = audio_path + + first_chunk_time = None + chunks = [] + for chunk in response.iter_content(chunk_size=1024): + if chunk: + if first_chunk_time is None: + first_chunk_time = time.time() + results["time_to_first_chunk"] = first_chunk_time - start_time + chunks.append(chunk) + + # Extract WAV header and data separately + # First chunk has header + data, subsequent chunks are raw PCM + if not chunks: + raise ValueError("No audio chunks received") + + first_chunk = chunks[0] + remaining_chunks = chunks[1:] + + # Find end of WAV header (44 bytes for standard WAV) + header = first_chunk[:44] + first_data = first_chunk[44:] + + # Concatenate all PCM data + all_data = first_data + b''.join(remaining_chunks) + + # Update WAV header with total data size + import struct + data_size = len(all_data) + # Update data size field (bytes 4-7) + header = header[:4] + struct.pack(' dict: """ - Quick validation checks for TTS-generated audio files to detect common artifacts. - - Checks for: - - Unnatural silence gaps - - Audio glitches and artifacts - - Repeated speech segments (stuck/looping) - - Abrupt changes in speech - - Audio quality issues - - Args: - wav_path: Path to audio file (wav, mp3, etc) - Returns: - Dictionary with validation results + Validation checks for TTS-generated audio files to detect common artifacts. """ try: - # Load audio + # Load and process audio audio, sr = sf.read(wav_path) if len(audio.shape) > 1: - audio = audio.mean(axis=1) # Convert to mono - - # Basic audio stats + audio = np.mean(audio, axis=1) + duration = len(audio) / sr - rms = np.sqrt(np.mean(audio**2)) - peak = np.max(np.abs(audio)) - dc_offset = np.mean(audio) - - # Calculate clipping stats if we're near peak - clip_count = np.sum(np.abs(audio) >= 0.99) - clip_percent = (clip_count / len(audio)) * 100 - if clip_percent > 0: - clip_stats = f" ({clip_percent:.2e} ratio near peak)" - else: - clip_stats = " (no samples near peak)" - - # Convert to dB for analysis - eps = np.finfo(float).eps - db = 20 * np.log10(np.abs(audio) + eps) - issues = [] - # Check if audio is too short (likely failed generation) - if duration < 0.1: # Less than 100ms - issues.append("WARNING: Audio is suspiciously short - possible failed generation") + # Basic quality checks + abs_audio = np.abs(audio) + stats = { + 'rms': float(np.sqrt(np.mean(audio**2))), + 'peak': float(np.max(abs_audio)), + 'dc_offset': float(np.mean(audio)) + } - # 1. 
Check for basic audio quality - if peak >= 1.0: - # Calculate percentage of samples that are clipping - clip_count = np.sum(np.abs(audio) >= 0.99) - clip_percent = (clip_count / len(audio)) * 100 + clip_count = np.sum(abs_audio >= 0.99) + clip_percent = (clip_count / len(audio)) * 100 + + if duration < 0.1: + issues.append("WARNING: Audio is suspiciously short - possible failed generation") - if clip_percent > 1.0: # Only warn if more than 1% of samples clip + if stats['peak'] >= 1.0: + if clip_percent > 1.0: issues.append(f"WARNING: Significant clipping detected ({clip_percent:.2e}% of samples)") - elif clip_percent > 0.01: # Add info if more than 0.01% but less than 1% - issues.append(f"INFO: Minor peak limiting detected ({clip_percent:.2e}% of samples) - likely intentional normalization") + elif clip_percent > 0.01: + issues.append(f"INFO: Minor peak limiting detected ({clip_percent:.2e}% of samples)") - if rms < 0.01: + if stats['rms'] < 0.01: issues.append("WARNING: Audio is very quiet - possible failed generation") - if abs(dc_offset) > 0.1: # DC offset is particularly bad for speech - issues.append(f"WARNING: High DC offset ({dc_offset:.3f}) - may cause audio artifacts") - # 2. Check for long silence gaps (potential TTS failures) + if abs(stats['dc_offset']) > 0.1: + issues.append(f"WARNING: High DC offset ({stats['dc_offset']:.3f})") + + # Check for long silence gaps + eps = np.finfo(float).eps + db = 20 * np.log10(abs_audio + eps) silence_threshold = -45 # dB - min_silence = 2.0 # Only detect silences longer than 2 seconds + min_silence = 2.0 # seconds window_size = int(min_silence * sr) silence_count = 0 last_silence = -1 - # Skip the first 0.2s for silence detection (avoid false positives at start) - start_idx = int(0.2 * sr) + start_idx = int(0.2 * sr) # Skip first 0.2s for i in range(start_idx, len(db) - window_size, window_size): window = db[i:i+window_size] if np.mean(window) < silence_threshold: - # Verify the entire window is mostly silence silent_ratio = np.mean(window < silence_threshold) - if silent_ratio > 0.9: # 90% of the window should be below threshold - if last_silence == -1 or (i/sr - last_silence) > 2.0: # Only count silences more than 2s apart + if silent_ratio > 0.9: + if last_silence == -1 or (i/sr - last_silence) > 2.0: silence_count += 1 last_silence = i/sr issues.append(f"WARNING: Long silence detected at {i/sr:.2f}s (duration: {min_silence:.1f}s)") - if silence_count > 2: # Only warn if there are multiple long silences - issues.append(f"WARNING: Multiple long silences found ({silence_count} total) - possible generation issue") - - # 3. 
Check for extreme audio artifacts (changes too rapid for natural speech) - # Use a longer window to avoid flagging normal phoneme transitions - window_size = int(0.02 * sr) # 20ms window - db_smooth = np.convolve(db, np.ones(window_size)/window_size, 'same') - db_diff = np.abs(np.diff(db_smooth)) + if silence_count > 2: + issues.append(f"WARNING: Multiple long silences found ({silence_count} total)") + + # Detect audio artifacts + diff = np.diff(audio) + abs_diff = np.abs(diff) + window_size = min(int(0.005 * sr), 256) + window = np.ones(window_size)/window_size + local_avg_diff = np.convolve(abs_diff, window, mode='same') - # Much higher threshold to only catch truly unnatural changes - artifact_threshold = 40 # dB - min_duration = int(0.01 * sr) # Minimum 10ms duration + spikes = (abs_diff > (10 * local_avg_diff)) & (abs_diff > 0.1) + artifact_indices = np.nonzero(spikes)[0] - # Find regions where the smoothed dB change is extreme - artifact_points = np.where(db_diff > artifact_threshold)[0] - - if len(artifact_points) > 0: - # Group artifacts that are very close together - grouped_artifacts = [] - current_group = [artifact_points[0]] + artifacts = [] + if len(artifact_indices) > 0: + gaps = np.diff(artifact_indices) + min_gap = int(0.005 * sr) + break_points = np.nonzero(gaps > min_gap)[0] + 1 + groups = np.split(artifact_indices, break_points) - for i in range(1, len(artifact_points)): - if (artifact_points[i] - current_group[-1]) < min_duration: - current_group.append(artifact_points[i]) - else: - if len(current_group) * (1/sr) >= 0.01: # Only keep groups lasting >= 10ms - grouped_artifacts.append(current_group) - current_group = [artifact_points[i]] - - if len(current_group) * (1/sr) >= 0.01: - grouped_artifacts.append(current_group) - - # Report only the most severe artifacts - for group in grouped_artifacts[:2]: # Report up to 2 worst artifacts - center_idx = group[len(group)//2] - db_change = db_diff[center_idx] - if db_change > 45: # Only report very extreme changes - issues.append( - f"WARNING: Possible audio artifact at {center_idx/sr:.2f}s " - f"({db_change:.1f}dB change over {len(group)/sr*1000:.0f}ms)" - ) - - # 4. 
Check for repeated speech segments (stuck/looping) - # Check both short and long sentence durations at audiobook speed (150-160 wpm) - for chunk_duration in [5.0, 10.0]: # 5s (~12 words) and 10s (~25 words) at ~audiobook speed + for group in groups: + if len(group) >= 5: + severity = np.max(abs_diff[group]) + if severity > 0.2: + center_idx = group[len(group)//2] + artifacts.append({ + 'time': float(center_idx/sr), # Ensure float for consistent timing + 'severity': float(severity) + }) + issues.append( + f"WARNING: Audio discontinuity at {center_idx/sr:.3f}s " + f"(severity: {severity:.3f})" + ) + + # Check for repeated speech segments + for chunk_duration in [5.0, 10.0]: chunk_size = int(chunk_duration * sr) - overlap = int(0.2 * chunk_size) # 20% overlap between chunks + overlap = int(0.2 * chunk_size) for i in range(0, len(audio) - 2*chunk_size, overlap): chunk1 = audio[i:i+chunk_size] chunk2 = audio[i+chunk_size:i+2*chunk_size] - # Ignore chunks that are mostly silence if np.mean(np.abs(chunk1)) < 0.01 or np.mean(np.abs(chunk2)) < 0.01: continue try: correlation = np.corrcoef(chunk1, chunk2)[0,1] - if not np.isnan(correlation) and correlation > 0.92: # Lower threshold for sentence-length chunks + if not np.isnan(correlation) and correlation > 0.92: issues.append( f"WARNING: Possible repeated speech at {i/sr:.1f}s " f"(~{int(chunk_duration*160/60):d} words, correlation: {correlation:.3f})" ) - break # Found repetition at this duration, try next duration + break except: continue - - # 5. Check for extreme amplitude discontinuities (common in failed TTS) - amplitude_envelope = np.abs(audio) - window_size = sr // 10 # 100ms window for smoother envelope - smooth_env = np.convolve(amplitude_envelope, np.ones(window_size)/float(window_size), 'same') - env_diff = np.abs(np.diff(smooth_env)) - - # Only detect very extreme amplitude changes - jump_threshold = 0.5 # Much higher threshold - jumps = np.where(env_diff > jump_threshold)[0] - - if len(jumps) > 0: - # Group jumps that are close together - grouped_jumps = [] - current_group = [jumps[0]] - - for i in range(1, len(jumps)): - if (jumps[i] - current_group[-1]) < 0.05 * sr: # Group within 50ms - current_group.append(jumps[i]) - else: - if len(current_group) >= 3: # Only keep significant discontinuities - grouped_jumps.append(current_group) - current_group = [jumps[i]] - - if len(current_group) >= 3: - grouped_jumps.append(current_group) - - # Report only the most severe discontinuities - for group in grouped_jumps[:2]: # Report up to 2 worst cases - center_idx = group[len(group)//2] - jump_size = env_diff[center_idx] - if jump_size > 0.6: # Only report very extreme changes - issues.append( - f"WARNING: Possible audio discontinuity at {center_idx/sr:.2f}s " - f"({jump_size:.2f} amplitude ratio change)" - ) - + return { "file": wav_path, "duration": f"{duration:.2f}s", "sample_rate": sr, - "peak_amplitude": f"{peak:.3f}{clip_stats}", - "rms_level": f"{rms:.3f}", - "dc_offset": f"{dc_offset:.3f}", + "peak_amplitude": f"{stats['peak']:.3f}", + "rms_level": f"{stats['rms']:.3f}", + "dc_offset": f"{stats['dc_offset']:.3f}", + "artifact_count": len(artifacts), + "artifact_locations": [a['time'] for a in artifacts], + "artifact_severities": [a['severity'] for a in artifacts], "issues": issues, "valid": len(issues) == 0 } @@ -206,12 +141,78 @@ def validate_tts(wav_path: str) -> dict: "valid": False } -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="TTS Output Validator") - parser.add_argument("wav_file", help="Path to audio 
file to validate") - args = parser.parse_args() +def generate_analysis_plots(wav_path: str, output_dir: str, validation_result: Dict[str, Any]): + """ + Generate analysis plots for audio file with time-aligned visualizations. + """ + import matplotlib.pyplot as plt + from scipy.signal import spectrogram - result = validate_tts(args.wav_file) + # Load audio + audio, sr = sf.read(wav_path) + if len(audio.shape) > 1: + audio = np.mean(audio, axis=1) + + # Create figure with shared x-axis + fig = plt.figure(figsize=(15, 8)) + gs = plt.GridSpec(2, 1, height_ratios=[1.2, 0.8], hspace=0.1) + ax1 = fig.add_subplot(gs[0]) + ax2 = fig.add_subplot(gs[1], sharex=ax1) + + # Calculate spectrogram + nperseg = 2048 + noverlap = 1536 + f, t, Sxx = spectrogram(audio, sr, nperseg=nperseg, noverlap=noverlap, + window='hann', scaling='spectrum') + + # Plot spectrogram + im = ax1.pcolormesh(t, f, 10 * np.log10(Sxx + 1e-10), + shading='gouraud', cmap='viridis', + vmin=-100, vmax=-20) + ax1.set_ylabel('Frequency [Hz]', fontsize=10) + cbar = plt.colorbar(im, ax=ax1, label='dB') + ax1.set_title('Spectrogram', pad=10, fontsize=12) + + # Plot waveform with exact time alignment + times = np.arange(len(audio)) / sr + ax2.plot(times, audio, color='#2E5596', alpha=0.7, linewidth=0.5, label='Audio') + ax2.set_ylabel('Amplitude', fontsize=10) + ax2.set_xlabel('Time [sec]', fontsize=10) + ax2.grid(True, alpha=0.2) + + # Add artifact markers + if 'artifact_locations' in validation_result and validation_result['artifact_locations']: + for loc in validation_result['artifact_locations']: + ax1.axvline(x=loc, color='red', alpha=0.7, linewidth=2) + ax2.axvline(x=loc, color='red', alpha=0.7, linewidth=2, label='Detected Artifacts') + + # Add legend to both plots + if len(validation_result['artifact_locations']) > 0: + ax1.plot([], [], color='red', linewidth=2, label='Detected Artifacts') + ax1.legend(loc='upper right', fontsize=8) + # Only add unique labels to legend + handles, labels = ax2.get_legend_handles_labels() + unique_labels = dict(zip(labels, handles)) + ax2.legend(unique_labels.values(), unique_labels.keys(), + loc='upper right', fontsize=8) + + # Set common x limits + xlim = (0, len(audio)/sr) + ax1.set_xlim(xlim) + ax2.set_xlim(xlim) + og_filename = Path(wav_path).name.split(".")[0] + # Save plot + plt.savefig(Path(output_dir) / f"{og_filename}_audio_analysis.png", dpi=300, bbox_inches='tight') + plt.close() + +if __name__ == "__main__": + wav_file = r"C:\Users\jerem\Desktop\Kokoro-FastAPI\examples\output.wav" + silent=False + + result = validate_tts(wav_file) + if not silent: + wav_root_dir = Path(wav_file).parent + generate_analysis_plots(wav_file, wav_root_dir, result) print(f"\nValidating: {result['file']}") if "error" in result: @@ -222,10 +223,11 @@ if __name__ == "__main__": print(f"Peak Amplitude: {result['peak_amplitude']}") print(f"RMS Level: {result['rms_level']}") print(f"DC Offset: {result['dc_offset']}") + print(f"Detected Artifacts: {result['artifact_count']}") if result["issues"]: print("\nIssues Found:") for issue in result["issues"]: print(f"- {issue}") else: - print("\nNo issues found") + print("\nNo issues found") \ No newline at end of file diff --git a/examples/audio_analysis.png b/examples/audio_analysis.png new file mode 100644 index 0000000..7c3034f Binary files /dev/null and b/examples/audio_analysis.png differ diff --git a/examples/output.wav b/examples/output.wav new file mode 100644 index 0000000..4f23759 Binary files /dev/null and b/examples/output.wav differ diff --git 
a/examples/output_audio_analysis.png b/examples/output_audio_analysis.png new file mode 100644 index 0000000..8d0541d Binary files /dev/null and b/examples/output_audio_analysis.png differ diff --git a/examples/stream_tts_playback.py b/examples/stream_tts_playback.py new file mode 100644 index 0000000..5670f50 --- /dev/null +++ b/examples/stream_tts_playback.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python3 +import requests +import sounddevice as sd +import numpy as np +import time +import os +import wave + +def play_streaming_tts(text: str, output_file: str = None, voice: str = "af"): + """Stream TTS audio and play it back in real-time""" + + print("\nStarting TTS stream request...") + start_time = time.time() + + # Initialize variables + sample_rate = 24000 # Known sample rate for Kokoro + audio_started = False + stream = None + chunk_count = 0 + total_bytes = 0 + first_chunk_time = None + all_audio_data = bytearray() # Raw PCM audio data + + # Make streaming request to API + try: + response = requests.post( + "http://localhost:8880/v1/audio/speech", + json={ + "model": "kokoro", + "input": text, + "voice": voice, + "response_format": "pcm", + "stream": True + }, + stream=True, + timeout=1800 + ) + response.raise_for_status() + print(f"Request started successfully after {time.time() - start_time:.2f}s") + + # Process streaming response + for chunk in response.iter_content(chunk_size=1024): + if chunk: + chunk_count += 1 + total_bytes += len(chunk) + + # Handle first chunk + if not audio_started: + first_chunk_time = time.time() + print(f"\nReceived first chunk after {first_chunk_time - start_time:.2f}s") + print(f"First chunk size: {len(chunk)} bytes") + + # Accumulate raw audio data + all_audio_data.extend(chunk) + + # Convert PCM to float32 for playback + audio_data = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) + # Scale to [-1, 1] range for sounddevice + audio_data = audio_data / 32768.0 + + # Start audio stream + stream = sd.OutputStream( + samplerate=sample_rate, + channels=1, + dtype=np.float32 + ) + stream.start() + audio_started = True + print("Audio playback started") + + # Play first chunk + if len(audio_data) > 0: + stream.write(audio_data) + + # Handle subsequent chunks + else: + # Accumulate raw audio data + all_audio_data.extend(chunk) + + # Convert PCM to float32 for playback + audio_data = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) + audio_data = audio_data / 32768.0 + if len(audio_data) > 0: + stream.write(audio_data) + + # Log progress every 10 chunks + if chunk_count % 10 == 0: + elapsed = time.time() - start_time + print(f"Progress: {chunk_count} chunks, {total_bytes/1024:.1f}KB received, {elapsed:.1f}s elapsed") + + # Final stats + total_time = time.time() - start_time + print(f"\nStream complete:") + print(f"Total chunks: {chunk_count}") + print(f"Total data: {total_bytes/1024:.1f}KB") + print(f"Total time: {total_time:.2f}s") + print(f"Average speed: {(total_bytes/1024)/total_time:.1f}KB/s") + + # Save as WAV file + if output_file: + print(f"\nWriting audio to {output_file}") + with wave.open(output_file, 'wb') as wav_file: + wav_file.setnchannels(1) # Mono + wav_file.setsampwidth(2) # 2 bytes per sample (16-bit) + wav_file.setframerate(sample_rate) + wav_file.writeframes(all_audio_data) + print(f"Saved {len(all_audio_data)} bytes of audio data") + + # Clean up + if stream is not None: + stream.stop() + stream.close() + + except requests.exceptions.ConnectionError as e: + print(f"Connection error - Is the server running? 
Error: {str(e)}") + if stream is not None: + stream.stop() + stream.close() + except Exception as e: + print(f"Error during streaming: {str(e)}") + if stream is not None: + stream.stop() + stream.close() + +def main(): + # Load sample text from HG Wells + script_dir = os.path.dirname(os.path.abspath(__file__)) + wells_path = os.path.join(script_dir, "assorted_checks/benchmarks/the_time_machine_hg_wells.txt") + output_path = os.path.join(script_dir, "output.wav") + + with open(wells_path, "r", encoding="utf-8") as f: + full_text = f.read() + # Take first few paragraphs + text = " ".join(full_text.split("\n\n")[:2]) + + print("\nStarting TTS stream playback...") + print(f"Text length: {len(text)} characters") + print("\nFirst 100 characters:") + print(text[:100] + "...") + + play_streaming_tts(text, output_file=output_path) + +if __name__ == "__main__": + main()