fix: UI stability, memory safety

remsky 2025-01-12 21:33:23 -07:00
parent 234445f5ae
commit f4dc292440
9 changed files with 144 additions and 66 deletions

View file

@@ -7,6 +7,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
espeak-ng \
git \
libsndfile1 \
curl \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*

View file

@@ -41,8 +41,6 @@ class TTSCPUModel(TTSBaseModel):
if not onnx_path:
return None
logger.info(f"Loading ONNX model from {onnx_path}")
# Configure ONNX session for optimal performance
session_options = SessionOptions()

View file

@@ -38,48 +38,64 @@ from .text_processing import tokenize, phonemize
# return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
@torch.no_grad()
def forward(model, tokens, ref_s, speed):
"""Forward pass through the model with light optimizations that preserve output quality"""
"""Forward pass through the model with moderate memory management"""
device = ref_s.device
try:
# Initial tensor setup with proper device placement
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
text_mask = length_to_mask(input_lengths).to(device)
# Keep original token handling but optimize device placement
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
text_mask = length_to_mask(input_lengths).to(device)
# Split and clone reference signals with explicit device placement
s_content = ref_s[:, 128:].clone().to(device)
s_ref = ref_s[:, :128].clone().to(device)
# BERT and encoder pass
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
# BERT and encoder pass
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
# Split reference signal once for efficiency
s_content = ref_s[:, 128:]
s_ref = ref_s[:, :128]
# Predictor forward pass
d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
x, _ = model.predictor.lstm(d)
# Predictor forward pass
d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
x, _ = model.predictor.lstm(d)
# Duration prediction
duration = model.predictor.duration_proj(x)
duration = torch.sigmoid(duration).sum(axis=-1) / speed
pred_dur = torch.round(duration).clamp(min=1).long()
# Only cleanup large intermediates
del duration, x
# Duration prediction - keeping original logic
duration = model.predictor.duration_proj(x)
duration = torch.sigmoid(duration).sum(axis=-1) / speed
pred_dur = torch.round(duration).clamp(min=1).long()
# Alignment matrix construction
pred_aln_trg = torch.zeros(input_lengths.item(), pred_dur.sum().item(), device=device)
c_frame = 0
for i in range(pred_aln_trg.size(0)):
pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
c_frame += pred_dur[0, i].item()
pred_aln_trg = pred_aln_trg.unsqueeze(0)
# Alignment matrix construction - keeping original approach for quality
pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item(), device=device)
c_frame = 0
for i in range(pred_aln_trg.size(0)):
pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
c_frame += pred_dur[0, i].item()
# Matrix multiplications with selective cleanup
en = d.transpose(-1, -2) @ pred_aln_trg
del d # Free large intermediate tensor
F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
del en # Free large intermediate tensor
# Matrix multiplications - reuse unsqueezed tensor
pred_aln_trg = pred_aln_trg.unsqueeze(0) # Do unsqueeze once
en = d.transpose(-1, -2) @ pred_aln_trg
F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
# Final text encoding and decoding
t_en = model.text_encoder(tokens, input_lengths, text_mask)
asr = t_en @ pred_aln_trg
del t_en # Free large intermediate tensor
# Text encoding and final decoding
t_en = model.text_encoder(tokens, input_lengths, text_mask)
asr = t_en @ pred_aln_trg
return model.decoder(asr, F0_pred, N_pred, s_ref).squeeze().cpu().numpy()
# Final decoding and transfer to CPU
output = model.decoder(asr, F0_pred, N_pred, s_ref)
result = output.squeeze().cpu().numpy()
return result
finally:
# Let PyTorch handle most cleanup automatically
# Only explicitly free the largest tensors
del pred_aln_trg, asr
# def length_to_mask(lengths):
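
The forward() rewrite above keeps the original synthesis math and adds explicit cleanup of the largest intermediates before the final decode. A minimal sketch of that pattern, with placeholder method names (encode/align/decode are illustrative, not the project's API):

import numpy as np
import torch

@torch.no_grad()
def synthesize(model, features: torch.Tensor) -> np.ndarray:
    # Placeholder pipeline: encode -> align -> decode, freeing big tensors early.
    hidden = model.encode(features)           # large intermediate
    frames = model.align(hidden)
    del hidden                                # release it as soon as it is consumed
    audio = model.decode(frames)
    del frames
    return audio.squeeze().cpu().numpy()      # detach the result from the GPU
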
@@ -179,7 +195,7 @@ class TTSGPUModel(TTSBaseModel):
def generate_from_tokens(
cls, tokens: list[int], voicepack: torch.Tensor, speed: float
) -> np.ndarray:
"""Generate audio from tokens
"""Generate audio from tokens with moderate memory management
Args:
tokens: Token IDs
@@ -192,10 +208,55 @@ class TTSGPUModel(TTSBaseModel):
if cls._instance is None:
raise RuntimeError("GPU model not initialized")
# Get reference style
ref_s = voicepack[len(tokens)]
# Generate audio
audio = forward(cls._instance, tokens, ref_s, speed)
return audio
try:
device = cls._device
# Check memory pressure
if torch.cuda.is_available():
memory_allocated = torch.cuda.memory_allocated(device) / 1e9 # Convert to GB
if memory_allocated > 2.0: # 2GB limit
logger.info(
f"Memory usage above 2GB threshold: {memory_allocated:.2f}GB. "
f"Clearing cache"
)
torch.cuda.empty_cache()
import gc
gc.collect()
# Get reference style with proper device placement
ref_s = voicepack[len(tokens)].clone().to(device)
# Generate audio
audio = forward(cls._instance, tokens, ref_s, speed)
return audio
except RuntimeError as e:
if "out of memory" in str(e):
# On OOM, do a full cleanup and retry
if torch.cuda.is_available():
logger.warning("Out of memory detected, performing full cleanup")
torch.cuda.synchronize()
torch.cuda.empty_cache()
import gc
gc.collect()
# Log memory stats after cleanup
memory_allocated = torch.cuda.memory_allocated(device)
memory_reserved = torch.cuda.memory_reserved(device)
logger.info(
f"Memory after OOM cleanup: "
f"Allocated: {memory_allocated / 1e9:.2f}GB, "
f"Reserved: {memory_reserved / 1e9:.2f}GB"
)
# Retry generation
ref_s = voicepack[len(tokens)].clone().to(device)
audio = forward(cls._instance, tokens, ref_s, speed)
return audio
raise
finally:
# Only synchronize at the top level, no empty_cache
if torch.cuda.is_available():
torch.cuda.synchronize()
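
The memory-pressure check and the out-of-memory retry above follow a generic shape that can be factored into a small helper. A rough sketch, assuming a CUDA device and a retry-once policy (the helper name and the 2 GB threshold are illustrative, not part of this codebase):

import gc
import torch

def run_with_oom_retry(fn, *args, threshold_gb: float = 2.0):
    # Proactively clear the CUDA cache when allocations exceed the threshold.
    if torch.cuda.is_available():
        if torch.cuda.memory_allocated() / 1e9 > threshold_gb:
            torch.cuda.empty_cache()
            gc.collect()
    try:
        return fn(*args)
    except RuntimeError as e:
        if "out of memory" not in str(e) or not torch.cuda.is_available():
            raise
        torch.cuda.synchronize()      # wait for pending kernels before cleanup
        torch.cuda.empty_cache()
        gc.collect()
        return fn(*args)              # single retry after a full cleanup

Used in this style, generate_from_tokens would reduce to something like audio = run_with_oom_retry(forward, cls._instance, tokens, ref_s, speed).
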

View file

@@ -6,6 +6,8 @@ services:
working_dir: /app/Kokoro-82M
command: >
sh -c "
mkdir -p /app/Kokoro-82M;
cd /app/Kokoro-82M;
rm -f .git/index.lock;
if [ -z \"$(ls -A .)\" ]; then
git clone https://huggingface.co/hexgrad/Kokoro-82M .
@@ -26,11 +28,11 @@ services:
start_period: 1s
kokoro-tts:
image: ghcr.io/remsky/kokoro-fastapi-cpu:latest
# image: ghcr.io/remsky/kokoro-fastapi-cpu:latest
# Uncomment below (and comment out above) to build from source instead of using the released image
# build:
# context: .
# dockerfile: Dockerfile.cpu
build:
context: .
dockerfile: Dockerfile.cpu
volumes:
- ./api/src:/app/api/src
- ./Kokoro-82M:/app/Kokoro-82M
@@ -46,6 +48,12 @@ services:
- ONNX_MEMORY_PATTERN=true
- ONNX_ARENA_EXTEND_STRATEGY=kNextPowerOfTwo
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8880/health"]
interval: 10s
timeout: 5s
retries: 30
start_period: 30s
depends_on:
model-fetcher:
condition: service_healthy
@@ -64,3 +72,6 @@ services:
environment:
- GRADIO_WATCH=True # Enable hot reloading
- PYTHONUNBUFFERED=1 # Ensure Python output is not buffered
depends_on:
kokoro-tts:
condition: service_healthy
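
The healthcheck added above (together with curl in the image and the depends_on service_healthy condition) makes the UI container wait until the API answers on /health. The same probe written as a small Python check for local debugging; only the /health path and port come from the compose file, the helper itself is a sketch:

import requests

def api_is_healthy(base_url: str = "http://localhost:8880") -> bool:
    # Mirrors the compose healthcheck: a plain GET against /health.
    try:
        return requests.get(f"{base_url}/health", timeout=5).ok
    except requests.RequestException:
        return False
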

View file

@@ -8,6 +8,8 @@ services:
working_dir: /app/Kokoro-82M
command: >
sh -c "
mkdir -p /app/Kokoro-82M;
cd /app/Kokoro-82M;
if [ \"$$SKIP_MODEL_FETCH\" = \"true\" ]; then
echo 'Skipping model fetch...' && touch .cloned;
else
@@ -32,10 +34,10 @@ services:
start_period: 1s
kokoro-tts:
image: ghcr.io/remsky/kokoro-fastapi-gpu:latest
# image: ghcr.io/remsky/kokoro-fastapi-gpu:latest
# Uncomment below (and comment out above) to build from source instead of using the released image
# build:
# context: .
build:
context: .
volumes:
- ./api/src:/app/api/src
- ./Kokoro-82M:/app/Kokoro-82M
@@ -51,7 +53,7 @@ services:
count: 1
capabilities: [gpu]
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8880/v1/audio/voices"]
test: ["CMD", "curl", "-f", "http://localhost:8880/health"]
interval: 10s
timeout: 5s
retries: 30
@@ -65,7 +67,7 @@ services:
image: ghcr.io/remsky/kokoro-fastapi-ui:latest
# Uncomment below (and comment out above) to build from source instead of using the released image
# build:
# context: ./ui
# context: ./ui
ports:
- "7860:7860"
volumes:

View file

@@ -21,8 +21,9 @@ def create_model_column(voice_ids: Optional[list] = None) -> Tuple[gr.Column, dict]:
voice_input = gr.Dropdown(
choices=voice_ids,
label="Voice",
value=voice_ids[0] if voice_ids else None,
value=None, # Start with no value to avoid errors
interactive=True,
allow_custom_value=True, # Allow temporary values during updates
)
format_input = gr.Dropdown(
choices=config.AUDIO_FORMATS, label="Audio Format", value="mp3"

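The dropdown changes above, starting with no selected value and allowing custom values, keep Gradio from raising when the choice list is refreshed while the current value is momentarily absent from it. A small self-contained sketch of that pattern; the filenames and refresh handler are made up for illustration, and the gr.update style matches the handlers later in this commit:

import gradio as gr

def refresh_choices():
    new_files = ["out_001.mp3", "out_002.mp3"]   # stand-in for files.list_output_files()
    # allow_custom_value lets the dropdown hold a value that is not (yet) in choices.
    return gr.update(choices=new_files, value=new_files[-1] if new_files else None)

with gr.Blocks() as demo:
    outputs = gr.Dropdown(choices=[], value=None, label="Previous Outputs",
                          allow_custom_value=True)
    gr.Button("Refresh").click(refresh_choices, outputs=outputs)
# demo.launch()
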
View file

@@ -16,7 +16,7 @@ def create_output_column() -> Tuple[gr.Column, dict]:
label="Previous Outputs",
choices=files.list_output_files(),
value=None,
allow_custom_value=False,
allow_custom_value=True,
)
play_btn = gr.Button("▶️ Play Selected", size="sm")
@@ -40,3 +40,4 @@ def create_output_column() -> Tuple[gr.Column, dict]:
}
return col, components
return col, components

View file

@@ -12,9 +12,9 @@ def list_input_files() -> List[str]:
def list_output_files() -> List[str]:
"""List all output audio files."""
# Just return filenames since paths will be different inside/outside container
return [
os.path.join(OUTPUTS_DIR, f)
for f in os.listdir(OUTPUTS_DIR)
f for f in os.listdir(OUTPUTS_DIR)
if any(f.endswith(ext) for ext in AUDIO_FORMATS)
]
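
Returning bare filenames here keeps the dropdown values valid on both sides of the container boundary; the directory is re-attached only where a real path is needed (as play_selected does later in this commit). A minimal sketch of that split, with hypothetical directory and extension constants:

import os
from typing import List, Optional

OUTPUTS_DIR = "output"                       # differs between host and container
AUDIO_FORMATS = (".mp3", ".wav", ".opus")    # illustrative subset

def list_output_files() -> List[str]:
    # Store only filenames; the directory prefix is environment-specific.
    return [f for f in os.listdir(OUTPUTS_DIR) if f.endswith(AUDIO_FORMATS)]

def resolve_output(filename: str) -> Optional[str]:
    # Re-attach the directory at the point of use.
    path = os.path.join(OUTPUTS_DIR, filename)
    return path if os.path.exists(path) else None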

View file

@@ -1,6 +1,5 @@
import os
import shutil
import gradio as gr
from . import api, files
@@ -97,11 +96,12 @@ def setup_event_handlers(components: dict):
gr.Warning("Failed to generate speech. Please try again.")
return [None, gr.update(choices=files.list_output_files())]
# Update list and select the newly generated file
output_files = files.list_output_files()
last_file = output_files[-1] if output_files else None
return [
result,
gr.update(
choices=files.list_output_files(), value=os.path.basename(result)
),
gr.update(choices=output_files, value=last_file),
]
def generate_from_file(selected_file, voice, format, speed):
@@ -121,16 +121,19 @@ def setup_event_handlers(components: dict):
gr.Warning("Failed to generate speech. Please try again.")
return [None, gr.update(choices=files.list_output_files())]
# Update list and select the newly generated file
output_files = files.list_output_files()
last_file = output_files[-1] if output_files else None
return [
result,
gr.update(
choices=files.list_output_files(), value=os.path.basename(result)
),
gr.update(choices=output_files, value=last_file),
]
def play_selected(file_path):
if file_path and os.path.exists(file_path):
return gr.update(value=file_path, visible=True)
def play_selected(filename):
if filename:
file_path = os.path.join(files.OUTPUTS_DIR, filename)
if os.path.exists(file_path):
return gr.update(value=file_path, visible=True)
return gr.update(visible=False)
def clear_files(voice, format, speed):