mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
Enhance ONNX optimization settings and add validation script for TTS audio files
This commit is contained in:
parent
7df2a68fb4
commit
93aa205da9
13 changed files with 1976 additions and 2340 deletions
|
@ -18,6 +18,14 @@ class Settings(BaseSettings):
|
|||
onnx_model_path: str = "kokoro-v0_19.onnx"
|
||||
voices_dir: str = "voices"
|
||||
sample_rate: int = 24000
|
||||
|
||||
# ONNX Optimization Settings
|
||||
onnx_num_threads: int = 4 # Number of threads for intra-op parallelism
|
||||
onnx_inter_op_threads: int = 4 # Number of threads for inter-op parallelism
|
||||
onnx_execution_mode: str = "parallel" # parallel or sequential
|
||||
onnx_optimization_level: str = "all" # all, basic, or disabled
|
||||
onnx_memory_pattern: bool = True # Enable memory pattern optimization
|
||||
onnx_arena_extend_strategy: str = "kNextPowerOfTwo" # Memory allocation strategy
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
|
|
@ -31,14 +31,33 @@ class TTSCPUModel(TTSBaseModel):
|
|||
|
||||
# Configure ONNX session for optimal performance
|
||||
session_options = SessionOptions()
|
||||
session_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
session_options.intra_op_num_threads = 4 # Adjust based on CPU cores
|
||||
session_options.execution_mode = ExecutionMode.ORT_SEQUENTIAL
|
||||
|
||||
# Set optimization level
|
||||
if settings.onnx_optimization_level == "all":
|
||||
session_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
elif settings.onnx_optimization_level == "basic":
|
||||
session_options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC
|
||||
else:
|
||||
session_options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
|
||||
|
||||
# Configure threading
|
||||
session_options.intra_op_num_threads = settings.onnx_num_threads
|
||||
session_options.inter_op_num_threads = settings.onnx_inter_op_threads
|
||||
|
||||
# Set execution mode
|
||||
session_options.execution_mode = (
|
||||
ExecutionMode.ORT_PARALLEL
|
||||
if settings.onnx_execution_mode == "parallel"
|
||||
else ExecutionMode.ORT_SEQUENTIAL
|
||||
)
|
||||
|
||||
# Enable/disable memory pattern optimization
|
||||
session_options.enable_mem_pattern = settings.onnx_memory_pattern
|
||||
|
||||
# Configure CPU provider options
|
||||
provider_options = {
|
||||
'CPUExecutionProvider': {
|
||||
'arena_extend_strategy': 'kNextPowerOfTwo',
|
||||
'arena_extend_strategy': settings.onnx_arena_extend_strategy,
|
||||
'cpu_memory_arena_cfg': 'cpu:0'
|
||||
}
|
||||
}
|
||||
|
|
|
@ -36,6 +36,13 @@ services:
|
|||
- "8880:8880"
|
||||
environment:
|
||||
- PYTHONPATH=/app:/app/Kokoro-82M
|
||||
# ONNX Optimization Settings for vectorized operations
|
||||
- ONNX_NUM_THREADS=8 # Maximize core usage for vectorized ops
|
||||
- ONNX_INTER_OP_THREADS=4 # Higher inter-op for parallel matrix operations
|
||||
- ONNX_EXECUTION_MODE=parallel
|
||||
- ONNX_OPTIMIZATION_LEVEL=all
|
||||
- ONNX_MEMORY_PATTERN=true
|
||||
- ONNX_ARENA_EXTEND_STRATEGY=kNextPowerOfTwo
|
||||
depends_on:
|
||||
model-fetcher:
|
||||
condition: service_healthy
|
||||
|
|
0
examples/assorted_checks/__init__.py
Normal file
0
examples/assorted_checks/__init__.py
Normal file
|
@ -60,7 +60,7 @@ def main():
|
|||
# Initialize system monitor
|
||||
monitor = SystemMonitor(interval=1.0) # 1 second interval
|
||||
# Set prefix for output files (e.g. "gpu", "cpu", "onnx", etc.)
|
||||
prefix = "gpu"
|
||||
prefix = "cpu_2_1_seq"
|
||||
# Generate token sizes
|
||||
if 'gpu' in prefix:
|
||||
token_sizes = generate_token_sizes(
|
||||
|
@ -68,8 +68,8 @@ def main():
|
|||
dense_max=1000, sparse_step=1000)
|
||||
elif 'cpu' in prefix:
|
||||
token_sizes = generate_token_sizes(
|
||||
max_tokens=1000, dense_step=150,
|
||||
dense_max=800, sparse_step=0)
|
||||
max_tokens=1000, dense_step=300,
|
||||
dense_max=1000, sparse_step=0)
|
||||
else:
|
||||
token_sizes = generate_token_sizes(max_tokens=3000)
|
||||
|
||||
|
@ -122,6 +122,7 @@ def main():
|
|||
|
||||
# Calculate RTF using the correct formula
|
||||
rtf = real_time_factor(processing_time, audio_length)
|
||||
print(f"Real-Time Factor: {rtf:.5f}")
|
||||
|
||||
results.append({
|
||||
"tokens": actual_tokens,
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -0,0 +1,23 @@
|
|||
=== Benchmark Statistics (with correct RTF) ===
|
||||
|
||||
Total tokens processed: 1800
|
||||
Total audio generated (s): 568.53
|
||||
Total test duration (s): 244.10
|
||||
Average processing rate (tokens/s): 7.34
|
||||
Average RTF: 0.43
|
||||
Average Real Time Speed: 2.33
|
||||
|
||||
=== Per-chunk Stats ===
|
||||
|
||||
Average chunk size (tokens): 600.00
|
||||
Min chunk size (tokens): 300
|
||||
Max chunk size (tokens): 900
|
||||
Average processing time (s): 81.30
|
||||
Average output length (s): 189.51
|
||||
|
||||
=== Performance Ranges ===
|
||||
|
||||
Processing rate range (tokens/s): 7.21 - 7.47
|
||||
RTF range: 0.43x - 0.43x
|
||||
Real Time Speed range: 2.33x - 2.33x
|
||||
|
|
@ -1,23 +0,0 @@
|
|||
=== Benchmark Statistics (with correct RTF) ===
|
||||
|
||||
Total tokens processed: 2250
|
||||
Total audio generated (s): 710.80
|
||||
Total test duration (s): 332.81
|
||||
Average processing rate (tokens/s): 6.77
|
||||
Average RTF: 0.47
|
||||
Average Real Time Speed: 2.14
|
||||
|
||||
=== Per-chunk Stats ===
|
||||
|
||||
Average chunk size (tokens): 450.00
|
||||
Min chunk size (tokens): 150
|
||||
Max chunk size (tokens): 750
|
||||
Average processing time (s): 66.51
|
||||
Average output length (s): 142.16
|
||||
|
||||
=== Performance Ranges ===
|
||||
|
||||
Processing rate range (tokens/s): 6.50 - 7.00
|
||||
RTF range: 0.45x - 0.50x
|
||||
Real Time Speed range: 2.00x - 2.22x
|
||||
|
Binary file not shown.
Before Width: | Height: | Size: 234 KiB After Width: | Height: | Size: 231 KiB |
Binary file not shown.
Before Width: | Height: | Size: 212 KiB After Width: | Height: | Size: 181 KiB |
Binary file not shown.
Before Width: | Height: | Size: 449 KiB After Width: | Height: | Size: 454 KiB |
231
examples/assorted_checks/validate_wav.py
Normal file
231
examples/assorted_checks/validate_wav.py
Normal file
|
@ -0,0 +1,231 @@
|
|||
import numpy as np
|
||||
import soundfile as sf
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
def validate_tts(wav_path: str) -> dict:
    """
    Quick validation checks for TTS-generated audio files to detect common artifacts.

    Checks for:
    - Unnatural silence gaps
    - Audio glitches and artifacts
    - Repeated speech segments (stuck/looping)
    - Abrupt changes in speech
    - Audio quality issues

    Args:
        wav_path: Path to audio file (wav, mp3, etc)
    Returns:
        Dictionary with validation results: basic stats (duration, sample rate,
        peak, RMS, DC offset), an "issues" list of warning strings, and a
        "valid" flag that is True only when no issues were found. If the file
        cannot be read or analyzed, returns {"file", "error", "valid": False}.
    """
    try:
        # Load audio
        audio, sr = sf.read(wav_path)
        if len(audio.shape) > 1:
            audio = audio.mean(axis=1)  # Convert to mono

        # Basic audio stats
        duration = len(audio) / sr
        rms = np.sqrt(np.mean(audio**2))
        peak = np.max(np.abs(audio))
        dc_offset = np.mean(audio)

        # Calculate clipping stats if we're near peak.
        # NOTE: computed once here and reused by the clipping check below
        # (previously recomputed a second time inside the peak>=1.0 branch).
        clip_count = np.sum(np.abs(audio) >= 0.99)
        clip_percent = (clip_count / len(audio)) * 100
        if clip_percent > 0:
            clip_stats = f" ({clip_percent:.2e} ratio near peak)"
        else:
            clip_stats = " (no samples near peak)"

        # Convert to dB for analysis
        eps = np.finfo(float).eps
        db = 20 * np.log10(np.abs(audio) + eps)

        issues = []

        # Check if audio is too short (likely failed generation)
        if duration < 0.1:  # Less than 100ms
            issues.append("WARNING: Audio is suspiciously short - possible failed generation")

        # 1. Check for basic audio quality
        if peak >= 1.0:
            if clip_percent > 1.0:  # Only warn if more than 1% of samples clip
                issues.append(f"WARNING: Significant clipping detected ({clip_percent:.2e}% of samples)")
            elif clip_percent > 0.01:  # Add info if more than 0.01% but less than 1%
                issues.append(f"INFO: Minor peak limiting detected ({clip_percent:.2e}% of samples) - likely intentional normalization")

        if rms < 0.01:
            issues.append("WARNING: Audio is very quiet - possible failed generation")
        if abs(dc_offset) > 0.1:  # DC offset is particularly bad for speech
            issues.append(f"WARNING: High DC offset ({dc_offset:.3f}) - may cause audio artifacts")

        # 2. Check for long silence gaps (potential TTS failures)
        silence_threshold = -45  # dB
        min_silence = 2.0  # Only detect silences longer than 2 seconds
        window_size = int(min_silence * sr)
        silence_count = 0
        last_silence = -1

        # Skip the first 0.2s for silence detection (avoid false positives at start)
        start_idx = int(0.2 * sr)
        for i in range(start_idx, len(db) - window_size, window_size):
            window = db[i:i + window_size]
            if np.mean(window) < silence_threshold:
                # Verify the entire window is mostly silence
                silent_ratio = np.mean(window < silence_threshold)
                if silent_ratio > 0.9:  # 90% of the window should be below threshold
                    if last_silence == -1 or (i / sr - last_silence) > 2.0:  # Only count silences more than 2s apart
                        silence_count += 1
                        last_silence = i / sr
                        issues.append(f"WARNING: Long silence detected at {i/sr:.2f}s (duration: {min_silence:.1f}s)")

        if silence_count > 2:  # Only warn if there are multiple long silences
            issues.append(f"WARNING: Multiple long silences found ({silence_count} total) - possible generation issue")

        # 3. Check for extreme audio artifacts (changes too rapid for natural speech)
        # Use a longer window to avoid flagging normal phoneme transitions
        window_size = int(0.02 * sr)  # 20ms window
        db_smooth = np.convolve(db, np.ones(window_size) / window_size, 'same')
        db_diff = np.abs(np.diff(db_smooth))

        # Much higher threshold to only catch truly unnatural changes
        artifact_threshold = 40  # dB
        min_duration = int(0.01 * sr)  # Minimum 10ms duration

        # Find regions where the smoothed dB change is extreme
        artifact_points = np.where(db_diff > artifact_threshold)[0]

        if len(artifact_points) > 0:
            # Group artifacts that are very close together
            grouped_artifacts = []
            current_group = [artifact_points[0]]

            for i in range(1, len(artifact_points)):
                if (artifact_points[i] - current_group[-1]) < min_duration:
                    current_group.append(artifact_points[i])
                else:
                    if len(current_group) * (1 / sr) >= 0.01:  # Only keep groups lasting >= 10ms
                        grouped_artifacts.append(current_group)
                    current_group = [artifact_points[i]]

            # Don't forget the trailing group
            if len(current_group) * (1 / sr) >= 0.01:
                grouped_artifacts.append(current_group)

            # Report only the most severe artifacts
            for group in grouped_artifacts[:2]:  # Report up to 2 worst artifacts
                center_idx = group[len(group) // 2]
                db_change = db_diff[center_idx]
                if db_change > 45:  # Only report very extreme changes
                    issues.append(
                        f"WARNING: Possible audio artifact at {center_idx/sr:.2f}s "
                        f"({db_change:.1f}dB change over {len(group)/sr*1000:.0f}ms)"
                    )

        # 4. Check for repeated speech segments (stuck/looping)
        # Check both short and long sentence durations at audiobook speed (150-160 wpm)
        for chunk_duration in [5.0, 10.0]:  # 5s (~12 words) and 10s (~25 words) at ~audiobook speed
            chunk_size = int(chunk_duration * sr)
            overlap = int(0.2 * chunk_size)  # 20% overlap between chunks

            for i in range(0, len(audio) - 2 * chunk_size, overlap):
                chunk1 = audio[i:i + chunk_size]
                chunk2 = audio[i + chunk_size:i + 2 * chunk_size]

                # Ignore chunks that are mostly silence
                if np.mean(np.abs(chunk1)) < 0.01 or np.mean(np.abs(chunk2)) < 0.01:
                    continue

                try:
                    correlation = np.corrcoef(chunk1, chunk2)[0, 1]
                    if not np.isnan(correlation) and correlation > 0.92:  # Lower threshold for sentence-length chunks
                        issues.append(
                            f"WARNING: Possible repeated speech at {i/sr:.1f}s "
                            f"(~{int(chunk_duration*160/60):d} words, correlation: {correlation:.3f})"
                        )
                        break  # Found repetition at this duration, try next duration
                except Exception:
                    # Was a bare `except:` which also swallowed KeyboardInterrupt /
                    # SystemExit; a degenerate chunk pair just skips this position.
                    continue

        # 5. Check for extreme amplitude discontinuities (common in failed TTS)
        amplitude_envelope = np.abs(audio)
        window_size = sr // 10  # 100ms window for smoother envelope
        smooth_env = np.convolve(amplitude_envelope, np.ones(window_size) / float(window_size), 'same')
        env_diff = np.abs(np.diff(smooth_env))

        # Only detect very extreme amplitude changes
        jump_threshold = 0.5  # Much higher threshold
        jumps = np.where(env_diff > jump_threshold)[0]

        if len(jumps) > 0:
            # Group jumps that are close together
            grouped_jumps = []
            current_group = [jumps[0]]

            for i in range(1, len(jumps)):
                if (jumps[i] - current_group[-1]) < 0.05 * sr:  # Group within 50ms
                    current_group.append(jumps[i])
                else:
                    if len(current_group) >= 3:  # Only keep significant discontinuities
                        grouped_jumps.append(current_group)
                    current_group = [jumps[i]]

            if len(current_group) >= 3:
                grouped_jumps.append(current_group)

            # Report only the most severe discontinuities
            for group in grouped_jumps[:2]:  # Report up to 2 worst cases
                center_idx = group[len(group) // 2]
                jump_size = env_diff[center_idx]
                if jump_size > 0.6:  # Only report very extreme changes
                    issues.append(
                        f"WARNING: Possible audio discontinuity at {center_idx/sr:.2f}s "
                        f"({jump_size:.2f} amplitude ratio change)"
                    )

        return {
            "file": wav_path,
            "duration": f"{duration:.2f}s",
            "sample_rate": sr,
            "peak_amplitude": f"{peak:.3f}{clip_stats}",
            "rms_level": f"{rms:.3f}",
            "dc_offset": f"{dc_offset:.3f}",
            "issues": issues,
            "valid": len(issues) == 0
        }

    except Exception as e:
        # Any read/analysis failure is reported as an invalid file rather than raised.
        return {
            "file": wav_path,
            "error": str(e),
            "valid": False
        }
def _print_report(report: dict) -> None:
    """Print the validation report for a single file to stdout."""
    print(f"\nValidating: {report['file']}")
    if "error" in report:
        print(f"Error: {report['error']}")
        return

    print(f"Duration: {report['duration']}")
    print(f"Sample Rate: {report['sample_rate']} Hz")
    print(f"Peak Amplitude: {report['peak_amplitude']}")
    print(f"RMS Level: {report['rms_level']}")
    print(f"DC Offset: {report['dc_offset']}")

    if report["issues"]:
        print("\nIssues Found:")
        for problem in report["issues"]:
            print(f"- {problem}")
    else:
        print("\nNo issues found")


if __name__ == "__main__":
    # CLI entry point: validate a single audio file and print the report.
    cli = argparse.ArgumentParser(description="TTS Output Validator")
    cli.add_argument("wav_file", help="Path to audio file to validate")
    parsed = cli.parse_args()

    _print_report(validate_tts(parsed.wav_file))
72
examples/assorted_checks/validate_wavs.py
Normal file
72
examples/assorted_checks/validate_wavs.py
Normal file
|
@ -0,0 +1,72 @@
|
|||
import argparse
|
||||
from pathlib import Path
|
||||
from validate_wav import validate_tts
|
||||
|
||||
def print_validation_result(result: dict, rel_path: Path):
    """Print full validation details for a single file."""
    print(f"\nValidating: {rel_path}")

    # Failed reads carry only an "error" key - report it and stop.
    if "error" in result:
        print(f"Error: {result['error']}")
        return

    # Basic per-file statistics.
    print(f"Duration: {result['duration']}")
    print(f"Sample Rate: {result['sample_rate']} Hz")
    print(f"Peak Amplitude: {result['peak_amplitude']}")
    print(f"RMS Level: {result['rms_level']}")
    print(f"DC Offset: {result['dc_offset']}")

    detected = result["issues"]
    if not detected:
        print("\nNo issues found")
        return

    print("\nIssues Found:")
    for problem in detected:
        print(f"- {problem}")
def validate_directory(directory: str):
    """Validate all wav files in a directory with detailed output and summary."""
    root = Path(directory)

    # Recursively collect .wav files plus any .mp3s, in sorted order.
    candidates = sorted(list(root.rglob("*.wav")) + list(root.rglob("*.mp3")))

    if not candidates:
        print(f"No .wav or .mp3 files found in {directory}")
        return

    print(f"Found {len(candidates)} files in {directory}")
    print("=" * 80)

    # Detailed validation pass; remember each outcome for the summary below.
    outcomes = []
    for audio_file in candidates:
        outcome = validate_tts(str(audio_file))
        rel = audio_file.relative_to(root)
        print_validation_result(outcome, rel)
        outcomes.append((rel, outcome))
        print("=" * 80)

    # Compact one-line-per-file summary.
    print("\nSUMMARY:")
    for rel, outcome in outcomes:
        if "error" in outcome:
            print(f"{rel}: ERROR - {outcome['error']}")
        elif outcome["issues"]:
            # Show first issue in summary, indicate if there are more
            problems = outcome["issues"]
            headline = problems[0].replace("WARNING: ", "")
            remaining = len(problems) - 1
            if remaining:
                print(f"{rel}: FAIL - {headline} (+{remaining} more issues)")
            else:
                print(f"{rel}: FAIL - {headline}")
        else:
            print(f"{rel}: PASS")
if __name__ == "__main__":
    # CLI entry point: batch-validate every audio file under a directory.
    cli = argparse.ArgumentParser(description="Batch validate TTS wav files")
    cli.add_argument("directory", help="Directory containing wav files to validate")
    parsed = cli.parse_args()
    validate_directory(parsed.directory)
Loading…
Add table
Reference in a new issue