#!/usr/bin/env python3
# Compatible with both Windows and Linux
"""
Kokoro TTS Race Condition Test

This script creates multiple concurrent requests to a Kokoro TTS service
to reproduce a race condition where audio outputs don't match the requested text.
Each thread generates a simple numbered sentence, which should make mismatches
easy to identify through listening.

To run:
python kokoro_race_condition_test.py --threads 8 --iterations 5 --url http://localhost:8880
"""

import argparse
import base64
import concurrent.futures
import json
import os
import sys
import time
import wave
from pathlib import Path

import requests


def _positive_int(value):
    """Parse *value* as an integer >= 1 (used as an argparse ``type=``).

    Raises:
        argparse.ArgumentTypeError: if *value* is not an integer or is < 1.
            argparse turns this into a usage error and exits.
    """
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def setup_args():
    """Parse command line arguments.

    Returns:
        argparse.Namespace with the attributes ``url``, ``threads``,
        ``iterations``, ``voice``, ``output_dir`` and ``debug``.

    ``--threads`` and ``--iterations`` must be positive: a zero or negative
    value would otherwise only fail later, when the thread pool is created
    or when the per-request averages are computed.
    """
    parser = argparse.ArgumentParser(description="Test Kokoro TTS for race conditions")
    parser.add_argument(
        "--url",
        default="http://localhost:8880",
        help="Base URL of the Kokoro TTS service",
    )
    parser.add_argument(
        "--threads",
        type=_positive_int,
        default=8,
        help="Number of concurrent threads to use",
    )
    parser.add_argument(
        "--iterations",
        type=_positive_int,
        default=5,
        help="Number of iterations per thread",
    )
    parser.add_argument("--voice", default="af_heart", help="Voice to use for TTS")
    parser.add_argument(
        "--output-dir",
        default="./tts_test_output",
        help="Directory to save output files",
    )
    parser.add_argument("--debug", action="store_true", help="Enable debug logging")
    return parser.parse_args()


def generate_test_sentence(thread_id, iteration):
    """Generate a simple test sentence with numbers to make mismatches easily identifiable.

    The ``<thread_id>-<iteration>`` tag is spoken twice so a listener can
    tell which request a given audio file actually belongs to.
    """
    tag = f"{thread_id}-{iteration}"
    return (
        f"This is test sentence number {tag}. "
        f"If you hear this sentence, you should hear the numbers {tag}."
    )
def log_message(message, debug=False, is_error=False):
    """Log messages with timestamps.

    Args:
        message: Text to print.
        debug: When true, informational messages are printed.
        is_error: When true, the message is always printed, with an
            ``[ERROR]`` prefix instead of ``[INFO]``.

    Nothing is printed when both ``debug`` and ``is_error`` are false.
    """
    if not (is_error or debug):
        return
    timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
    prefix = "[ERROR]" if is_error else "[INFO]"
    print(f"{prefix} {timestamp} - {message}")
    sys.stdout.flush()  # Ensure logs are visible in Docker output


def _log_always(message):
    """Print an informational message regardless of the --debug flag.

    ``log_message`` drops non-error messages unless ``debug`` is true, so
    user-facing progress and summary output has to opt in explicitly.
    """
    log_message(message, debug=True)


def _log_wav_details(output_file, test_id, debug):
    """Open *output_file* as WAV and log its header fields, or log why it is invalid."""
    try:
        with wave.open(output_file, 'rb') as wav_file:
            channels = wav_file.getnchannels()
            sample_width = wav_file.getsampwidth()
            framerate = wav_file.getframerate()
            frames = wav_file.getnframes()
        log_message(
            f"Thread {test_id}: Valid WAV file - channels: {channels}, "
            f"sample width: {sample_width}, framerate: {framerate}, frames: {frames}",
            debug,
        )
    except Exception as wav_error:
        # Best effort: a malformed header is reported but does not abort the run.
        log_message(f"Thread {test_id}: Invalid WAV file: {wav_error}", debug, is_error=True)


def request_tts(url, test_id, text, voice, output_dir, debug=False):
    """Request TTS from the Kokoro API and save the WAV output.

    Writes ``test_<test_id>.txt`` (the requested text) and
    ``test_<test_id>.wav`` (the response body) into *output_dir*.

    Args:
        url: Base URL of the service; ``/v1/audio/speech`` is appended.
        test_id: Identifier used in log lines and output file names.
        text: Sentence to synthesize.
        voice: Voice name sent in the request.
        output_dir: Existing directory for the output files.
        debug: Forwarded to ``log_message`` for informational output.

    Returns:
        True when a response was received and written to disk; False on a
        non-200 status, a response shorter than 100 bytes, a timeout, a
        failure to save the audio, or any other exception.
    """
    start_time = time.time()
    output_file = os.path.join(output_dir, f"test_{test_id}.wav")
    text_file = os.path.join(output_dir, f"test_{test_id}.txt")

    # Log output paths for debugging
    log_message(f"Thread {test_id}: Text will be saved to: {text_file}", debug)
    log_message(f"Thread {test_id}: Audio will be saved to: {output_file}", debug)

    # Save the text for later comparison. A failure here is logged but does
    # not prevent the TTS request itself.
    try:
        # Explicit encoding so the file reads back the same on every platform.
        with open(text_file, "w", encoding="utf-8") as f:
            f.write(text)
        log_message(f"Thread {test_id}: Successfully saved text file", debug)
    except Exception as e:
        log_message(f"Thread {test_id}: Error saving text file: {e}", debug, is_error=True)

    # Make the TTS request
    try:
        log_message(f"Thread {test_id}: Requesting TTS for: '{text}'", debug)
        response = requests.post(
            f"{url}/v1/audio/speech",
            json={
                "model": "kokoro",
                "input": text,
                "voice": voice,
                "response_format": "wav",
            },
            headers={"Accept": "audio/wav"},
            timeout=60,  # Increase timeout to 60 seconds
        )

        log_message(f"Thread {test_id}: Response status code: {response.status_code}", debug)
        log_message(
            f"Thread {test_id}: Response content type: "
            f"{response.headers.get('Content-Type', 'None')}",
            debug,
        )
        log_message(
            f"Thread {test_id}: Response content length: {len(response.content)} bytes",
            debug,
        )

        if response.status_code != 200:
            log_message(
                f"Thread {test_id}: API error: {response.status_code} - {response.text}",
                debug,
                is_error=True,
            )
            return False

        # Check if we got valid audio data
        if len(response.content) < 100:  # Sanity check - WAV files should be larger than this
            log_message(
                f"Thread {test_id}: Received suspiciously small audio data: "
                f"{len(response.content)} bytes",
                debug,
                is_error=True,
            )
            log_message(
                f"Thread {test_id}: Content (base64): "
                f"{base64.b64encode(response.content).decode('utf-8')}",
                debug,
                is_error=True,
            )
            return False

        # Save the audio output with explicit error handling
        try:
            with open(output_file, "wb") as f:
                bytes_written = f.write(response.content)
            log_message(f"Thread {test_id}: Wrote {bytes_written} bytes to {output_file}", debug)

            # Verify the WAV file exists and has content
            if os.path.exists(output_file):
                file_size = os.path.getsize(output_file)
                log_message(
                    f"Thread {test_id}: Verified file exists with size: {file_size} bytes",
                    debug,
                )
                # Validate WAV file by reading its headers
                _log_wav_details(output_file, test_id, debug)
            else:
                log_message(
                    f"Thread {test_id}: File was not created: {output_file}",
                    debug,
                    is_error=True,
                )
        except Exception as save_error:
            log_message(
                f"Thread {test_id}: Error saving audio file: {save_error}",
                debug,
                is_error=True,
            )
            return False

        end_time = time.time()
        log_message(
            f"Thread {test_id}: Saved output to {output_file} "
            f"(time: {end_time - start_time:.2f}s)",
            debug,
        )
        return True

    except requests.exceptions.Timeout:
        log_message(f"Thread {test_id}: Request timed out", debug, is_error=True)
        return False
    except Exception as e:
        log_message(f"Thread {test_id}: Exception: {e}", debug, is_error=True)
        return False


def worker_task(thread_id, args):
    """Worker task for each thread.

    Runs ``args.iterations`` sequential TTS requests, numbering iterations
    from 1 and naming each request ``<thread_id>_<iteration>`` (zero-padded
    to two digits). Failed iterations are logged and the loop continues.
    """
    for iteration in range(1, args.iterations + 1):
        test_id = f"{thread_id:02d}_{iteration:02d}"
        text = generate_test_sentence(thread_id, iteration)
        success = request_tts(args.url, test_id, text, args.voice, args.output_dir, args.debug)

        if not success:
            log_message(
                f"Thread {thread_id}: Iteration {iteration} failed",
                args.debug,
                is_error=True,
            )

        # Small delay between iterations to avoid overwhelming the API
        time.sleep(0.1)


def run_test(args):
    """Run the test with the specified parameters.

    Creates the output directory, checks write access and the service's
    ``/health`` endpoint (both checks only warn on failure), runs one
    ``worker_task`` per thread in a thread pool, then prints a timing
    summary.

    Args:
        args: Namespace providing ``url``, ``threads``, ``iterations``,
            ``voice``, ``output_dir`` and ``debug``.
    """
    # Ensure output directory exists and check permissions
    os.makedirs(args.output_dir, exist_ok=True)

    # Test write access to the output directory
    test_file = os.path.join(args.output_dir, "write_test.txt")
    try:
        with open(test_file, "w") as f:
            f.write("Testing write access\n")
        os.remove(test_file)
        _log_always(f"Successfully verified write access to output directory: {args.output_dir}")
    except Exception as e:
        log_message(
            f"Warning: Cannot write to output directory {args.output_dir}: {e}",
            is_error=True,
        )
        log_message(f"Current directory: {os.getcwd()}", is_error=True)
        log_message(f"Directory contents: {os.listdir('.')}", is_error=True)

    # Test connection to Kokoro TTS service
    try:
        response = requests.get(f"{args.url}/health", timeout=5)
        if response.status_code == 200:
            _log_always(f"Successfully connected to Kokoro TTS service at {args.url}")
        else:
            log_message(
                f"Warning: Kokoro TTS service health check returned status "
                f"{response.status_code}",
                is_error=True,
            )
    except Exception as e:
        log_message(
            f"Warning: Cannot connect to Kokoro TTS service at {args.url}: {e}",
            is_error=True,
        )

    # Record start time
    start_time = time.time()
    _log_always(
        f"Starting test with {args.threads} threads, {args.iterations} iterations per thread"
    )

    # Create and start worker threads. ThreadPoolExecutor rejects
    # max_workers < 1, so clamp it; with zero threads nothing is submitted.
    with concurrent.futures.ThreadPoolExecutor(max_workers=max(1, args.threads)) as executor:
        futures = [
            executor.submit(worker_task, thread_id, args)
            for thread_id in range(1, args.threads + 1)
        ]

        # Wait for all tasks to complete
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except Exception as e:
                log_message(f"Thread execution failed: {e}", args.debug, is_error=True)

    # Record end time and print summary
    end_time = time.time()
    total_time = end_time - start_time
    total_requests = args.threads * args.iterations

    _log_always(f"Test completed in {total_time:.2f} seconds")
    _log_always(f"Total requests: {total_requests}")
    # Skip the ratios that would divide by zero (no requests, or a run too
    # short for the clock to measure).
    if total_requests:
        _log_always(f"Average time per request: {total_time / total_requests:.2f} seconds")
    if total_time > 0:
        _log_always(f"Requests per second: {total_requests / total_time:.2f}")
    _log_always(f"Output files saved to: {os.path.abspath(args.output_dir)}")
    _log_always("To verify, listen to the audio files and check if they match the text files")
    _log_always(
        "If you hear audio describing a different test number than the filename, "
        "you've found a race condition"
    )


def analyze_audio_files(output_dir):
    """Provide summary of the generated audio files.

    Prints how many ``*.wav`` and ``*.txt`` files are in *output_dir*, a
    table of each readable WAV file's duration and matching text, and up to
    ten text files that have no WAV file with the same stem.
    """
    # Look for both WAV and TXT files; sorted so the table order is stable.
    wav_files = sorted(Path(output_dir).glob("*.wav"))
    txt_files = sorted(Path(output_dir).glob("*.txt"))

    _log_always(f"Found {len(wav_files)} WAV files and {len(txt_files)} TXT files")

    if not wav_files:
        log_message(
            "No WAV files found! This indicates the TTS service requests may be failing.",
            is_error=True,
        )
        log_message(
            "Check the connection to the TTS service and the response status codes above.",
            is_error=True,
        )

    file_stats = []
    for wav_path in wav_files:
        try:
            with wave.open(str(wav_path), 'rb') as wav_file:
                frames = wav_file.getnframes()
                rate = wav_file.getframerate()
            duration = frames / rate

            # Get corresponding text
            text_path = wav_path.with_suffix('.txt')
            if text_path.exists():
                with open(text_path, 'r', encoding="utf-8") as text_file:
                    text = text_file.read().strip()
            else:
                text = "N/A"

            file_stats.append({
                'filename': wav_path.name,
                'duration': duration,
                'text': text,
            })
        except Exception as e:
            # One unreadable file should not stop the rest of the summary.
            log_message(f"Error analyzing {wav_path}: {e}", False, is_error=True)

    # Print summary table
    if file_stats:
        _log_always("\nAudio File Summary:")
        _log_always(f"{'Filename':<20}{'Duration':<12}{'Text':<60}")
        _log_always("-" * 92)
        for stat in file_stats:
            text = stat['text']
            # Truncate long sentences so the table stays within its columns.
            shown = text[:57] + '...' if len(text) > 60 else text
            _log_always(f"{stat['filename']:<20}{stat['duration']:<12.2f}{shown:<60}")

    # List missing WAV files where text files exist
    missing_wavs = {p.stem for p in txt_files} - {p.stem for p in wav_files}
    if missing_wavs:
        log_message(
            f"\nFound {len(missing_wavs)} text files without corresponding WAV files:",
            is_error=True,
        )
        for stem in sorted(missing_wavs)[:10]:  # Limit to 10 for readability
            log_message(f" - {stem}.txt (no WAV file)", is_error=True)
        if len(missing_wavs) > 10:
            log_message(f" ... and {len(missing_wavs) - 10} more", is_error=True)


if __name__ == "__main__":
    args = setup_args()
    run_test(args)
    analyze_audio_files(args.output_dir)
    _log_always("\nNext Steps:")
    _log_always("1. Listen to the generated audio files")
    _log_always("2. Verify if each audio correctly says its ID number")
    _log_always("3. Check for any mismatches between the audio content and the text files")
    _log_always("4. If mismatches are found, you've successfully reproduced the race condition")