mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
378 lines
14 KiB
Python
378 lines
14 KiB
Python
#!/usr/bin/env python3
|
|
# Compatible with both Windows and Linux
|
|
"""
|
|
Kokoro TTS Race Condition Test
|
|
|
|
This script creates multiple concurrent requests to a Kokoro TTS service
|
|
to reproduce a race condition where audio outputs don't match the requested text.
|
|
Each thread generates a simple numbered sentence, which should make mismatches
|
|
easy to identify through listening.
|
|
|
|
To run:
|
|
python kokoro_race_condition_test.py --threads 8 --iterations 5 --url http://localhost:8880
|
|
"""
|
|
|
|
import argparse
|
|
import base64
|
|
import concurrent.futures
|
|
import json
|
|
import os
|
|
import sys
|
|
import time
|
|
import wave
|
|
from pathlib import Path
|
|
|
|
import requests
|
|
|
|
|
|
def setup_args(argv=None):
    """Parse command-line options for the race-condition test.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before; passing a list
            lets other code or tests drive the parser directly.

    Returns:
        argparse.Namespace with ``url``, ``threads``, ``iterations``, ``voice``,
        ``output_dir`` and ``debug`` attributes.

    Raises:
        SystemExit: argparse's usual exit on invalid options, including a
            ``--threads`` or ``--iterations`` value that is not a positive int.
    """

    def positive_int(value):
        # Zero/negative counts are meaningless for this test, and
        # ThreadPoolExecutor rejects max_workers <= 0 with a raw traceback.
        try:
            number = int(value)
        except ValueError:
            raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
        if number < 1:
            raise argparse.ArgumentTypeError(
                f"must be a positive integer, got {value!r}"
            )
        return number

    parser = argparse.ArgumentParser(description="Test Kokoro TTS for race conditions")
    parser.add_argument(
        "--url",
        default="http://localhost:8880",
        help="Base URL of the Kokoro TTS service",
    )
    parser.add_argument(
        "--threads",
        type=positive_int,
        default=8,
        help="Number of concurrent threads to use",
    )
    parser.add_argument(
        "--iterations",
        type=positive_int,
        default=5,
        help="Number of iterations per thread",
    )
    parser.add_argument("--voice", default="af_heart", help="Voice to use for TTS")
    parser.add_argument(
        "--output-dir",
        default="./tts_test_output",
        help="Directory to save output files",
    )
    parser.add_argument("--debug", action="store_true", help="Enable debug logging")
    return parser.parse_args(argv)
|
|
|
|
|
|
def generate_test_sentence(thread_id, iteration):
    """Build the sentence synthesized for one request.

    The ``<thread_id>-<iteration>`` tag is embedded twice, so a listener can
    tell at once whether a given audio clip belongs to this request.
    """
    tag = f"{thread_id}-{iteration}"
    parts = (
        f"This is test sentence number {tag}.",
        f"If you hear this sentence, you should hear the numbers {tag}.",
    )
    return " ".join(parts)
|
|
|
|
|
|
def log_message(message, debug=None, is_error=False):
    """Print a timestamped log line to stdout.

    Args:
        message: Text to print.
        debug: Visibility of a non-error message.
            ``None`` (the default) marks an informational message that is
            always shown, such as the run summary and next steps.
            ``True``/``False`` marks a verbose message that is shown only
            when true; callers pass their ``--debug`` flag here.
            Previously the default was ``False``, which silently hid every
            call that omitted this argument.
        is_error: When true, the line is prefixed ``[ERROR]`` and is always
            shown regardless of ``debug``.
    """
    visible = is_error or debug is None or bool(debug)
    if not visible:
        return
    timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
    prefix = "[ERROR]" if is_error else "[INFO]"
    # Flush right away so lines appear promptly in Docker output.
    print(f"{prefix} {timestamp} - {message}", flush=True)
|
|
|
|
|
|
def request_tts(url, test_id, text, voice, output_dir, debug=False):
    """Send one speech request to the Kokoro API and save its text and WAV output.

    The requested *text* is first written to ``test_<test_id>.txt`` on a
    best-effort basis: a failure there is logged and the request still goes
    ahead. Next, the audio returned by ``POST {url}/v1/audio/speech`` is written
    to ``test_<test_id>.wav``. Its header is then read back with :mod:`wave`,
    but only so the result can be logged.

    Args:
        url: Base URL of the service.
        test_id: Identifier used in the output file names and in log lines.
        text: Sentence to synthesize.
        voice: Voice name sent in the request body.
        output_dir: Directory receiving both output files.
        debug: Forwarded to ``log_message``; enables the verbose log lines.

    Returns:
        True once the audio bytes have been written. False for any of:
        a non-200 status, a body smaller than 100 bytes, an error while
        writing the WAV file, a request timeout, or any other exception.
        An unreadable WAV header is logged as an error but still returns True.
    """
    began = time.time()
    tag = f"Thread {test_id}"
    wav_path = os.path.join(output_dir, f"test_{test_id}.wav")
    txt_path = os.path.join(output_dir, f"test_{test_id}.txt")

    log_message(f"{tag}: Text will be saved to: {txt_path}", debug)
    log_message(f"{tag}: Audio will be saved to: {wav_path}", debug)

    # Keep the expected text next to the audio so the two can be compared later.
    try:
        with open(txt_path, "w") as txt_handle:
            txt_handle.write(text)
        log_message(f"{tag}: Successfully saved text file", debug)
    except Exception as err:
        log_message(
            f"{tag}: Error saving text file: {str(err)}", debug, is_error=True
        )

    payload = {
        "model": "kokoro",
        "input": text,
        "voice": voice,
        "response_format": "wav",
    }

    try:
        log_message(f"{tag}: Requesting TTS for: '{text}'", debug)
        resp = requests.post(
            f"{url}/v1/audio/speech",
            json=payload,
            headers={"Accept": "audio/wav"},
            timeout=60,  # seconds
        )

        log_message(f"{tag}: Response status code: {resp.status_code}", debug)
        log_message(
            f"{tag}: Response content type: {resp.headers.get('Content-Type', 'None')}",
            debug,
        )
        audio = resp.content
        log_message(f"{tag}: Response content length: {len(audio)} bytes", debug)

        if resp.status_code != 200:
            log_message(
                f"{tag}: API error: {resp.status_code} - {resp.text}",
                debug,
                is_error=True,
            )
            return False

        # Sanity check: a genuine WAV response is always larger than this.
        if len(audio) < 100:
            log_message(
                f"{tag}: Received suspiciously small audio data: {len(audio)} bytes",
                debug,
                is_error=True,
            )
            encoded = base64.b64encode(audio).decode('utf-8')
            log_message(f"{tag}: Content (base64): {encoded}", debug, is_error=True)
            return False

        try:
            with open(wav_path, "wb") as wav_handle:
                written = wav_handle.write(audio)
            log_message(f"{tag}: Wrote {written} bytes to {wav_path}", debug)

            if not os.path.exists(wav_path):
                log_message(
                    f"{tag}: File was not created: {wav_path}",
                    debug,
                    is_error=True,
                )
            else:
                on_disk = os.path.getsize(wav_path)
                log_message(
                    f"{tag}: Verified file exists with size: {on_disk} bytes",
                    debug,
                )
                # Reading the header back shows whether the bytes parse as WAV.
                try:
                    with wave.open(wav_path, "rb") as reader:
                        channels = reader.getnchannels()
                        sample_width = reader.getsampwidth()
                        framerate = reader.getframerate()
                        frames = reader.getnframes()
                    log_message(
                        f"{tag}: Valid WAV file - channels: {channels}, "
                        f"sample width: {sample_width}, framerate: {framerate}, frames: {frames}",
                        debug,
                    )
                except Exception as wav_err:
                    log_message(
                        f"{tag}: Invalid WAV file: {str(wav_err)}",
                        debug,
                        is_error=True,
                    )
        except Exception as save_err:
            log_message(
                f"{tag}: Error saving audio file: {str(save_err)}",
                debug,
                is_error=True,
            )
            return False

        elapsed = time.time() - began
        log_message(
            f"{tag}: Saved output to {wav_path} (time: {elapsed:.2f}s)",
            debug,
        )
        return True

    except requests.exceptions.Timeout:
        log_message(f"{tag}: Request timed out", debug, is_error=True)
        return False
    except Exception as err:
        log_message(f"{tag}: Exception: {str(err)}", debug, is_error=True)
        return False
|
|
|
|
|
|
def worker_task(thread_id, args):
    """Issue ``args.iterations`` sequential TTS requests for one worker.

    Each request's test id is ``f"{thread_id:02d}_{iteration:02d}"``, with the
    iteration counted from 1. Its sentence speaks the same numbers, so a clip
    whose audio disagrees with its file name exposes a mix-up between
    concurrent requests. A failed request is logged rather than raised. The
    worker sleeps 0.1 s after every request, the last one included.
    """
    for iteration in range(1, args.iterations + 1):
        sentence = generate_test_sentence(thread_id, iteration)
        ok = request_tts(
            args.url,
            f"{thread_id:02d}_{iteration:02d}",
            sentence,
            args.voice,
            args.output_dir,
            args.debug,
        )
        if not ok:
            log_message(
                f"Thread {thread_id}: Iteration {iteration} failed",
                args.debug,
                is_error=True,
            )
        # Brief pause so a single worker does not hammer the API back-to-back.
        time.sleep(0.1)
|
|
|
|
|
|
def _check_output_dir(output_dir):
    """Create *output_dir* if missing and log whether files can be written there.

    A failed write probe does not stop the run. It is logged together with
    the current directory and its listing, which helps spot a wrong relative
    path.
    """
    os.makedirs(output_dir, exist_ok=True)
    probe_path = os.path.join(output_dir, "write_test.txt")
    try:
        with open(probe_path, "w") as probe:
            probe.write("Testing write access\n")
        os.remove(probe_path)
        log_message(
            f"Successfully verified write access to output directory: {output_dir}"
        )
    except OSError as e:
        log_message(
            f"Warning: Cannot write to output directory {output_dir}: {str(e)}",
            is_error=True,
        )
        log_message(f"Current directory: {os.getcwd()}", is_error=True)
        log_message(f"Directory contents: {os.listdir('.')}", is_error=True)


def _check_service_health(url):
    """GET ``{url}/health`` and log the outcome; any failure is only a warning."""
    try:
        response = requests.get(f"{url}/health", timeout=5)
    except Exception as e:  # best-effort probe: never abort the test from here
        log_message(
            f"Warning: Cannot connect to Kokoro TTS service at {url}: {str(e)}",
            is_error=True,
        )
        return
    if response.status_code == 200:
        log_message(f"Successfully connected to Kokoro TTS service at {url}")
    else:
        log_message(
            f"Warning: Kokoro TTS service health check returned status {response.status_code}",
            is_error=True,
        )


def _log_run_summary(args, total_time):
    """Log timing statistics and instructions for checking the output files."""
    total_requests = args.threads * args.iterations
    log_message(f"Test completed in {total_time:.2f} seconds")
    log_message(f"Total requests: {total_requests}")
    # Guard both ratios so zero requests or a zero elapsed time cannot raise
    # ZeroDivisionError after the test has already finished.
    if total_requests > 0:
        # Wall-clock time divided by request count, so it reflects concurrency.
        log_message(
            f"Average time per request: {total_time / total_requests:.2f} seconds"
        )
    if total_time > 0:
        log_message(f"Requests per second: {total_requests / total_time:.2f}")
    log_message(f"Output files saved to: {os.path.abspath(args.output_dir)}")
    log_message(
        "To verify, listen to the audio files and check if they match the text files"
    )
    log_message(
        "If you hear audio describing a different test number than the filename, you've found a race condition"
    )


def run_test(args):
    """Run the concurrent TTS test described by *args* and log a summary.

    The run has three stages:

    1. Verify the output directory is writable and probe the service's
       health endpoint. Both checks only warn on failure.
    2. Start ``args.threads`` workers (ids 1..N), each running
       ``worker_task`` for ``args.iterations`` requests, and wait for all
       of them. An exception escaping a worker is logged, not re-raised.
    3. Log timing statistics.

    Args:
        args: Namespace from ``setup_args`` (url, threads, iterations, voice,
            output_dir, debug).
    """
    _check_output_dir(args.output_dir)
    _check_service_health(args.url)

    # perf_counter is monotonic, so wall-clock adjustments cannot skew timings.
    start_time = time.perf_counter()
    log_message(
        f"Starting test with {args.threads} threads, {args.iterations} iterations per thread"
    )

    with concurrent.futures.ThreadPoolExecutor(max_workers=args.threads) as executor:
        futures = [
            executor.submit(worker_task, thread_id, args)
            for thread_id in range(1, args.threads + 1)
        ]
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except Exception as e:
                log_message(
                    f"Thread execution failed: {str(e)}", args.debug, is_error=True
                )

    _log_run_summary(args, time.perf_counter() - start_time)
|
|
|
|
|
|
def analyze_audio_files(output_dir):
    """Log a summary of the WAV/TXT pairs written to *output_dir*.

    The report has three parts:

    - the number of ``.wav`` and ``.txt`` files found;
    - a table giving each WAV file's duration next to the text stored in the
      matching ``.txt`` file (``"N/A"`` when that file is absent);
    - up to 10 text files that have no corresponding WAV file.

    Files that cannot be parsed as WAV are logged as errors and skipped.

    Args:
        output_dir: Directory the test wrote its outputs to.
    """
    out_path = Path(output_dir)
    # Sort so the table follows file-name order; glob order is filesystem-
    # dependent and otherwise scrambles the report.
    wav_files = sorted(out_path.glob("*.wav"))
    txt_files = sorted(out_path.glob("*.txt"))

    log_message(f"Found {len(wav_files)} WAV files and {len(txt_files)} TXT files")

    if not wav_files:
        log_message(
            "No WAV files found! This indicates the TTS service requests may be failing.",
            is_error=True,
        )
        log_message(
            "Check the connection to the TTS service and the response status codes above.",
            is_error=True,
        )

    file_stats = []
    for wav_path in wav_files:
        try:
            with wave.open(str(wav_path), "rb") as wav_file:
                frames = wav_file.getnframes()
                rate = wav_file.getframerate()
            duration = frames / rate

            text_path = wav_path.with_suffix(".txt")
            text = text_path.read_text().strip() if text_path.exists() else "N/A"

            file_stats.append(
                {"filename": wav_path.name, "duration": duration, "text": text}
            )
        except Exception as e:  # per-file best effort: report and keep going
            log_message(f"Error analyzing {wav_path}: {str(e)}", False, is_error=True)

    if file_stats:
        log_message("\nAudio File Summary:")
        log_message(f"{'Filename':<20}{'Duration':<12}{'Text':<60}")
        log_message("-" * 92)
        for stat in file_stats:
            text = stat["text"]
            # Truncate long text to fit the 60-column Text field.
            shown = text[:57] + "..." if len(text) > 60 else text
            log_message(f"{stat['filename']:<20}{stat['duration']:<12.2f}{shown:<60}")

    # Text files whose WAV counterpart was never written (failed requests).
    missing_wavs = {p.stem for p in txt_files} - {p.stem for p in wav_files}
    if missing_wavs:
        log_message(
            f"\nFound {len(missing_wavs)} text files without corresponding WAV files:",
            is_error=True,
        )
        for stem in sorted(missing_wavs)[:10]:  # Limit to 10 for readability
            log_message(f" - {stem}.txt (no WAV file)", is_error=True)
        if len(missing_wavs) > 10:
            log_message(f" ... and {len(missing_wavs) - 10} more", is_error=True)
|
|
|
|
|
|
def main():
    """Parse options, run the concurrent test, summarize outputs, print next steps."""
    options = setup_args()
    run_test(options)
    analyze_audio_files(options.output_dir)

    next_steps = (
        "\nNext Steps:",
        "1. Listen to the generated audio files",
        "2. Verify if each audio correctly says its ID number",
        "3. Check for any mismatches between the audio content and the text files",
        "4. If mismatches are found, you've successfully reproduced the race condition",
    )
    for line in next_steps:
        log_message(line)


if __name__ == "__main__":
    main()
|