Mirror of https://github.com/remsky/Kokoro-FastAPI.git (synced 2025-08-05 16:48:53 +00:00)

Commit 93aa205da9 (parent 7df2a68fb4)
Enhance ONNX optimization settings and add validation script for TTS audio files

13 changed files with 1976 additions and 2340 deletions
@@ -19,6 +19,14 @@ class Settings(BaseSettings):
     voices_dir: str = "voices"
     sample_rate: int = 24000
+
+    # ONNX Optimization Settings
+    onnx_num_threads: int = 4  # Number of threads for intra-op parallelism
+    onnx_inter_op_threads: int = 4  # Number of threads for inter-op parallelism
+    onnx_execution_mode: str = "parallel"  # parallel or sequential
+    onnx_optimization_level: str = "all"  # all, basic, or disabled
+    onnx_memory_pattern: bool = True  # Enable memory pattern optimization
+    onnx_arena_extend_strategy: str = "kNextPowerOfTwo"  # Memory allocation strategy
 
     class Config:
         env_file = ".env"
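
Note (sketch, not part of this commit): with a pydantic-style BaseSettings and env_file = ".env", each new field can also be set through an upper-cased environment variable, which is how the docker-compose entries further down feed these values into the API container. A minimal, self-contained illustration assuming pydantic v1's BaseSettings:

    import os
    from pydantic import BaseSettings  # pydantic v1-style settings class

    class Settings(BaseSettings):
        # stand-ins for two of the fields added above
        onnx_num_threads: int = 4
        onnx_execution_mode: str = "parallel"

        class Config:
            env_file = ".env"

    os.environ["ONNX_NUM_THREADS"] = "8"  # same variable the compose file exports
    assert Settings().onnx_num_threads == 8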

@@ -31,14 +31,33 @@ class TTSCPUModel(TTSBaseModel):
 
         # Configure ONNX session for optimal performance
         session_options = SessionOptions()
-        session_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
-        session_options.intra_op_num_threads = 4  # Adjust based on CPU cores
-        session_options.execution_mode = ExecutionMode.ORT_SEQUENTIAL
+
+        # Set optimization level
+        if settings.onnx_optimization_level == "all":
+            session_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
+        elif settings.onnx_optimization_level == "basic":
+            session_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
+        else:
+            session_options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
+
+        # Configure threading
+        session_options.intra_op_num_threads = settings.onnx_num_threads
+        session_options.inter_op_num_threads = settings.onnx_inter_op_threads
+
+        # Set execution mode
+        session_options.execution_mode = (
+            ExecutionMode.ORT_PARALLEL
+            if settings.onnx_execution_mode == "parallel"
+            else ExecutionMode.ORT_SEQUENTIAL
+        )
+
+        # Enable/disable memory pattern optimization
+        session_options.enable_mem_pattern = settings.onnx_memory_pattern
 
         # Configure CPU provider options
         provider_options = {
             'CPUExecutionProvider': {
-                'arena_extend_strategy': 'kNextPowerOfTwo',
+                'arena_extend_strategy': settings.onnx_arena_extend_strategy,
                 'cpu_memory_arena_cfg': 'cpu:0'
             }
         }
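
Note (sketch, not from this commit): the session options and provider options built above would typically be handed to onnxruntime's InferenceSession when the CPU model is loaded, roughly like this (model path is illustrative):

    import onnxruntime as ort

    session_options = ort.SessionOptions()
    session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    session_options.execution_mode = ort.ExecutionMode.ORT_PARALLEL
    session_options.intra_op_num_threads = 4
    session_options.inter_op_num_threads = 4
    session_options.enable_mem_pattern = True

    session = ort.InferenceSession(
        "kokoro-v0_19.onnx",  # illustrative model path
        sess_options=session_options,
        providers=["CPUExecutionProvider"],
        provider_options=[{"arena_extend_strategy": "kNextPowerOfTwo"}],
    )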

@@ -36,6 +36,13 @@ services:
       - "8880:8880"
     environment:
       - PYTHONPATH=/app:/app/Kokoro-82M
+      # ONNX Optimization Settings for vectorized operations
+      - ONNX_NUM_THREADS=8  # Maximize core usage for vectorized ops
+      - ONNX_INTER_OP_THREADS=4  # Higher inter-op for parallel matrix operations
+      - ONNX_EXECUTION_MODE=parallel
+      - ONNX_OPTIMIZATION_LEVEL=all
+      - ONNX_MEMORY_PATTERN=true
+      - ONNX_ARENA_EXTEND_STRATEGY=kNextPowerOfTwo
     depends_on:
       model-fetcher:
         condition: service_healthy

examples/assorted_checks/__init__.py (new file, 0 lines)

@@ -60,7 +60,7 @@ def main():
     # Initialize system monitor
     monitor = SystemMonitor(interval=1.0)  # 1 second interval
     # Set prefix for output files (e.g. "gpu", "cpu", "onnx", etc.)
-    prefix = "gpu"
+    prefix = "cpu_2_1_seq"
     # Generate token sizes
     if 'gpu' in prefix:
         token_sizes = generate_token_sizes(
@@ -68,8 +68,8 @@ def main():
             dense_max=1000, sparse_step=1000)
     elif 'cpu' in prefix:
         token_sizes = generate_token_sizes(
-            max_tokens=1000, dense_step=150,
-            dense_max=800, sparse_step=0)
+            max_tokens=1000, dense_step=300,
+            dense_max=1000, sparse_step=0)
     else:
         token_sizes = generate_token_sizes(max_tokens=3000)
 
@@ -122,6 +122,7 @@ def main():
 
         # Calculate RTF using the correct formula
         rtf = real_time_factor(processing_time, audio_length)
+        print(f"Real-Time Factor: {rtf:.5f}")
 
         results.append({
             "tokens": actual_tokens,
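
Note (assumed helper, defined elsewhere in the benchmark utilities): real_time_factor here is processing time divided by the duration of the generated audio, so values below 1.0 mean generation runs faster than real time. A minimal sketch, checked against the aggregate numbers in the new stats file below:

    def real_time_factor(processing_time: float, audio_length: float) -> float:
        return processing_time / audio_length

    rtf = real_time_factor(244.10, 568.53)  # 244.10 s of processing for 568.53 s of audio -> ~0.43
    real_time_speed = 1 / rtf               # ~2.33x faster than real time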
File diff suppressed because it is too large

@@ -0,0 +1,23 @@
=== Benchmark Statistics (with correct RTF) ===

Total tokens processed: 1800
Total audio generated (s): 568.53
Total test duration (s): 244.10
Average processing rate (tokens/s): 7.34
Average RTF: 0.43
Average Real Time Speed: 2.33

=== Per-chunk Stats ===

Average chunk size (tokens): 600.00
Min chunk size (tokens): 300
Max chunk size (tokens): 900
Average processing time (s): 81.30
Average output length (s): 189.51

=== Performance Ranges ===

Processing rate range (tokens/s): 7.21 - 7.47
RTF range: 0.43x - 0.43x
Real Time Speed range: 2.33x - 2.33x

@@ -1,23 +0,0 @@
=== Benchmark Statistics (with correct RTF) ===

Total tokens processed: 2250
Total audio generated (s): 710.80
Total test duration (s): 332.81
Average processing rate (tokens/s): 6.77
Average RTF: 0.47
Average Real Time Speed: 2.14

=== Per-chunk Stats ===

Average chunk size (tokens): 450.00
Min chunk size (tokens): 150
Max chunk size (tokens): 750
Average processing time (s): 66.51
Average output length (s): 142.16

=== Performance Ranges ===

Processing rate range (tokens/s): 6.50 - 7.00
RTF range: 0.45x - 0.50x
Real Time Speed range: 2.00x - 2.22x

Binary file not shown. Before size: 234 KiB | After size: 231 KiB
Binary file not shown. Before size: 212 KiB | After size: 181 KiB
Binary file not shown. Before size: 449 KiB | After size: 454 KiB

examples/assorted_checks/validate_wav.py (new file, 231 lines)
@@ -0,0 +1,231 @@
import numpy as np
import soundfile as sf
import argparse
from pathlib import Path


def validate_tts(wav_path: str) -> dict:
    """
    Quick validation checks for TTS-generated audio files to detect common artifacts.

    Checks for:
    - Unnatural silence gaps
    - Audio glitches and artifacts
    - Repeated speech segments (stuck/looping)
    - Abrupt changes in speech
    - Audio quality issues

    Args:
        wav_path: Path to audio file (wav, mp3, etc)
    Returns:
        Dictionary with validation results
    """
    try:
        # Load audio
        audio, sr = sf.read(wav_path)
        if len(audio.shape) > 1:
            audio = audio.mean(axis=1)  # Convert to mono

        # Basic audio stats
        duration = len(audio) / sr
        rms = np.sqrt(np.mean(audio**2))
        peak = np.max(np.abs(audio))
        dc_offset = np.mean(audio)

        # Calculate clipping stats if we're near peak
        clip_count = np.sum(np.abs(audio) >= 0.99)
        clip_percent = (clip_count / len(audio)) * 100
        if clip_percent > 0:
            clip_stats = f" ({clip_percent:.2e} ratio near peak)"
        else:
            clip_stats = " (no samples near peak)"

        # Convert to dB for analysis
        eps = np.finfo(float).eps
        db = 20 * np.log10(np.abs(audio) + eps)

        issues = []

        # Check if audio is too short (likely failed generation)
        if duration < 0.1:  # Less than 100ms
            issues.append("WARNING: Audio is suspiciously short - possible failed generation")

        # 1. Check for basic audio quality
        if peak >= 1.0:
            # Calculate percentage of samples that are clipping
            clip_count = np.sum(np.abs(audio) >= 0.99)
            clip_percent = (clip_count / len(audio)) * 100

            if clip_percent > 1.0:  # Only warn if more than 1% of samples clip
                issues.append(f"WARNING: Significant clipping detected ({clip_percent:.2e}% of samples)")
            elif clip_percent > 0.01:  # Add info if more than 0.01% but less than 1%
                issues.append(f"INFO: Minor peak limiting detected ({clip_percent:.2e}% of samples) - likely intentional normalization")

        if rms < 0.01:
            issues.append("WARNING: Audio is very quiet - possible failed generation")
        if abs(dc_offset) > 0.1:  # DC offset is particularly bad for speech
            issues.append(f"WARNING: High DC offset ({dc_offset:.3f}) - may cause audio artifacts")

        # 2. Check for long silence gaps (potential TTS failures)
        silence_threshold = -45  # dB
        min_silence = 2.0  # Only detect silences longer than 2 seconds
        window_size = int(min_silence * sr)
        silence_count = 0
        last_silence = -1

        # Skip the first 0.2s for silence detection (avoid false positives at start)
        start_idx = int(0.2 * sr)
        for i in range(start_idx, len(db) - window_size, window_size):
            window = db[i:i+window_size]
            if np.mean(window) < silence_threshold:
                # Verify the entire window is mostly silence
                silent_ratio = np.mean(window < silence_threshold)
                if silent_ratio > 0.9:  # 90% of the window should be below threshold
                    if last_silence == -1 or (i/sr - last_silence) > 2.0:  # Only count silences more than 2s apart
                        silence_count += 1
                        last_silence = i/sr
                        issues.append(f"WARNING: Long silence detected at {i/sr:.2f}s (duration: {min_silence:.1f}s)")

        if silence_count > 2:  # Only warn if there are multiple long silences
            issues.append(f"WARNING: Multiple long silences found ({silence_count} total) - possible generation issue")

        # 3. Check for extreme audio artifacts (changes too rapid for natural speech)
        # Use a longer window to avoid flagging normal phoneme transitions
        window_size = int(0.02 * sr)  # 20ms window
        db_smooth = np.convolve(db, np.ones(window_size)/window_size, 'same')
        db_diff = np.abs(np.diff(db_smooth))

        # Much higher threshold to only catch truly unnatural changes
        artifact_threshold = 40  # dB
        min_duration = int(0.01 * sr)  # Minimum 10ms duration

        # Find regions where the smoothed dB change is extreme
        artifact_points = np.where(db_diff > artifact_threshold)[0]

        if len(artifact_points) > 0:
            # Group artifacts that are very close together
            grouped_artifacts = []
            current_group = [artifact_points[0]]

            for i in range(1, len(artifact_points)):
                if (artifact_points[i] - current_group[-1]) < min_duration:
                    current_group.append(artifact_points[i])
                else:
                    if len(current_group) * (1/sr) >= 0.01:  # Only keep groups lasting >= 10ms
                        grouped_artifacts.append(current_group)
                    current_group = [artifact_points[i]]

            if len(current_group) * (1/sr) >= 0.01:
                grouped_artifacts.append(current_group)

            # Report only the most severe artifacts
            for group in grouped_artifacts[:2]:  # Report up to 2 worst artifacts
                center_idx = group[len(group)//2]
                db_change = db_diff[center_idx]
                if db_change > 45:  # Only report very extreme changes
                    issues.append(
                        f"WARNING: Possible audio artifact at {center_idx/sr:.2f}s "
                        f"({db_change:.1f}dB change over {len(group)/sr*1000:.0f}ms)"
                    )

        # 4. Check for repeated speech segments (stuck/looping)
        # Check both short and long sentence durations at audiobook speed (150-160 wpm)
        for chunk_duration in [5.0, 10.0]:  # 5s (~12 words) and 10s (~25 words) at ~audiobook speed
            chunk_size = int(chunk_duration * sr)
            overlap = int(0.2 * chunk_size)  # 20% overlap between chunks

            for i in range(0, len(audio) - 2*chunk_size, overlap):
                chunk1 = audio[i:i+chunk_size]
                chunk2 = audio[i+chunk_size:i+2*chunk_size]

                # Ignore chunks that are mostly silence
                if np.mean(np.abs(chunk1)) < 0.01 or np.mean(np.abs(chunk2)) < 0.01:
                    continue

                try:
                    correlation = np.corrcoef(chunk1, chunk2)[0,1]
                    if not np.isnan(correlation) and correlation > 0.92:  # Lower threshold for sentence-length chunks
                        issues.append(
                            f"WARNING: Possible repeated speech at {i/sr:.1f}s "
                            f"(~{int(chunk_duration*160/60):d} words, correlation: {correlation:.3f})"
                        )
                        break  # Found repetition at this duration, try next duration
                except:
                    continue

        # 5. Check for extreme amplitude discontinuities (common in failed TTS)
        amplitude_envelope = np.abs(audio)
        window_size = sr // 10  # 100ms window for smoother envelope
        smooth_env = np.convolve(amplitude_envelope, np.ones(window_size)/float(window_size), 'same')
        env_diff = np.abs(np.diff(smooth_env))

        # Only detect very extreme amplitude changes
        jump_threshold = 0.5  # Much higher threshold
        jumps = np.where(env_diff > jump_threshold)[0]

        if len(jumps) > 0:
            # Group jumps that are close together
            grouped_jumps = []
            current_group = [jumps[0]]

            for i in range(1, len(jumps)):
                if (jumps[i] - current_group[-1]) < 0.05 * sr:  # Group within 50ms
                    current_group.append(jumps[i])
                else:
                    if len(current_group) >= 3:  # Only keep significant discontinuities
                        grouped_jumps.append(current_group)
                    current_group = [jumps[i]]

            if len(current_group) >= 3:
                grouped_jumps.append(current_group)

            # Report only the most severe discontinuities
            for group in grouped_jumps[:2]:  # Report up to 2 worst cases
                center_idx = group[len(group)//2]
                jump_size = env_diff[center_idx]
                if jump_size > 0.6:  # Only report very extreme changes
                    issues.append(
                        f"WARNING: Possible audio discontinuity at {center_idx/sr:.2f}s "
                        f"({jump_size:.2f} amplitude ratio change)"
                    )

        return {
            "file": wav_path,
            "duration": f"{duration:.2f}s",
            "sample_rate": sr,
            "peak_amplitude": f"{peak:.3f}{clip_stats}",
            "rms_level": f"{rms:.3f}",
            "dc_offset": f"{dc_offset:.3f}",
            "issues": issues,
            "valid": len(issues) == 0
        }

    except Exception as e:
        return {
            "file": wav_path,
            "error": str(e),
            "valid": False
        }


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="TTS Output Validator")
    parser.add_argument("wav_file", help="Path to audio file to validate")
    args = parser.parse_args()

    result = validate_tts(args.wav_file)

    print(f"\nValidating: {result['file']}")
    if "error" in result:
        print(f"Error: {result['error']}")
    else:
        print(f"Duration: {result['duration']}")
        print(f"Sample Rate: {result['sample_rate']} Hz")
        print(f"Peak Amplitude: {result['peak_amplitude']}")
        print(f"RMS Level: {result['rms_level']}")
        print(f"DC Offset: {result['dc_offset']}")

        if result["issues"]:
            print("\nIssues Found:")
            for issue in result["issues"]:
                print(f"- {issue}")
        else:
            print("\nNo issues found")
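
The validator can be run directly on a single file, e.g. python examples/assorted_checks/validate_wav.py output.wav (file name illustrative); it prints the duration, level, and offset stats above plus any WARNING/INFO issues found.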

examples/assorted_checks/validate_wavs.py (new file, 72 lines)
@@ -0,0 +1,72 @@
import argparse
from pathlib import Path
from validate_wav import validate_tts


def print_validation_result(result: dict, rel_path: Path):
    """Print full validation details for a single file."""
    print(f"\nValidating: {rel_path}")
    if "error" in result:
        print(f"Error: {result['error']}")
    else:
        print(f"Duration: {result['duration']}")
        print(f"Sample Rate: {result['sample_rate']} Hz")
        print(f"Peak Amplitude: {result['peak_amplitude']}")
        print(f"RMS Level: {result['rms_level']}")
        print(f"DC Offset: {result['dc_offset']}")

        if result["issues"]:
            print("\nIssues Found:")
            for issue in result["issues"]:
                print(f"- {issue}")
        else:
            print("\nNo issues found")


def validate_directory(directory: str):
    """Validate all wav files in a directory with detailed output and summary."""
    dir_path = Path(directory)

    # Find all wav files (including nested directories)
    wav_files = list(dir_path.rglob("*.wav"))
    wav_files.extend(dir_path.rglob("*.mp3"))  # Also check mp3s
    wav_files = sorted(wav_files)

    if not wav_files:
        print(f"No .wav or .mp3 files found in {directory}")
        return

    print(f"Found {len(wav_files)} files in {directory}")
    print("=" * 80)

    # Store results for summary
    results = []

    # Detailed validation output
    for wav_file in wav_files:
        result = validate_tts(str(wav_file))
        rel_path = wav_file.relative_to(dir_path)
        print_validation_result(result, rel_path)
        results.append((rel_path, result))
        print("=" * 80)

    # Summary with detailed issues
    print("\nSUMMARY:")
    for rel_path, result in results:
        if "error" in result:
            print(f"{rel_path}: ERROR - {result['error']}")
        elif result["issues"]:
            # Show first issue in summary, indicate if there are more
            issues = result["issues"]
            first_issue = issues[0].replace("WARNING: ", "")
            if len(issues) > 1:
                print(f"{rel_path}: FAIL - {first_issue} (+{len(issues)-1} more issues)")
            else:
                print(f"{rel_path}: FAIL - {first_issue}")
        else:
            print(f"{rel_path}: PASS")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Batch validate TTS wav files")
    parser.add_argument("directory", help="Directory containing wav files to validate")
    args = parser.parse_args()

    validate_directory(args.directory)
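
The batch script takes a directory, e.g. python examples/assorted_checks/validate_wavs.py path/to/outputs (path illustrative); it runs validate_tts on every .wav and .mp3 underneath it and finishes with a per-file PASS/FAIL summary.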