mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
fix: ui stability, memory safeties
This commit is contained in:
parent
234445f5ae
commit
f4dc292440
9 changed files with 144 additions and 66 deletions
|
@ -7,6 +7,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
|||
espeak-ng \
|
||||
git \
|
||||
libsndfile1 \
|
||||
curl \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
|
|
|
@ -41,8 +41,6 @@ class TTSCPUModel(TTSBaseModel):
|
|||
if not onnx_path:
|
||||
return None
|
||||
|
||||
logger.info(f"Loading ONNX model from {onnx_path}")
|
||||
|
||||
# Configure ONNX session for optimal performance
|
||||
session_options = SessionOptions()
|
||||
|
||||
|
|
|
@ -38,48 +38,64 @@ from .text_processing import tokenize, phonemize
|
|||
# return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
|
||||
@torch.no_grad()
|
||||
def forward(model, tokens, ref_s, speed):
|
||||
"""Forward pass through the model with light optimizations that preserve output quality"""
|
||||
"""Forward pass through the model with moderate memory management"""
|
||||
device = ref_s.device
|
||||
|
||||
try:
|
||||
# Initial tensor setup with proper device placement
|
||||
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
|
||||
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
|
||||
text_mask = length_to_mask(input_lengths).to(device)
|
||||
|
||||
# Keep original token handling but optimize device placement
|
||||
tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
|
||||
input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
|
||||
text_mask = length_to_mask(input_lengths).to(device)
|
||||
# Split and clone reference signals with explicit device placement
|
||||
s_content = ref_s[:, 128:].clone().to(device)
|
||||
s_ref = ref_s[:, :128].clone().to(device)
|
||||
|
||||
# BERT and encoder pass
|
||||
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
|
||||
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
|
||||
# BERT and encoder pass
|
||||
bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
|
||||
d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
|
||||
|
||||
# Split reference signal once for efficiency
|
||||
s_content = ref_s[:, 128:]
|
||||
s_ref = ref_s[:, :128]
|
||||
# Predictor forward pass
|
||||
d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
|
||||
x, _ = model.predictor.lstm(d)
|
||||
|
||||
# Predictor forward pass
|
||||
d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
|
||||
x, _ = model.predictor.lstm(d)
|
||||
# Duration prediction
|
||||
duration = model.predictor.duration_proj(x)
|
||||
duration = torch.sigmoid(duration).sum(axis=-1) / speed
|
||||
pred_dur = torch.round(duration).clamp(min=1).long()
|
||||
# Only cleanup large intermediates
|
||||
del duration, x
|
||||
|
||||
# Duration prediction - keeping original logic
|
||||
duration = model.predictor.duration_proj(x)
|
||||
duration = torch.sigmoid(duration).sum(axis=-1) / speed
|
||||
pred_dur = torch.round(duration).clamp(min=1).long()
|
||||
# Alignment matrix construction
|
||||
pred_aln_trg = torch.zeros(input_lengths.item(), pred_dur.sum().item(), device=device)
|
||||
c_frame = 0
|
||||
for i in range(pred_aln_trg.size(0)):
|
||||
pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
|
||||
c_frame += pred_dur[0, i].item()
|
||||
pred_aln_trg = pred_aln_trg.unsqueeze(0)
|
||||
|
||||
# Alignment matrix construction - keeping original approach for quality
|
||||
pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item(), device=device)
|
||||
c_frame = 0
|
||||
for i in range(pred_aln_trg.size(0)):
|
||||
pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
|
||||
c_frame += pred_dur[0, i].item()
|
||||
# Matrix multiplications with selective cleanup
|
||||
en = d.transpose(-1, -2) @ pred_aln_trg
|
||||
del d # Free large intermediate tensor
|
||||
|
||||
F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
|
||||
del en # Free large intermediate tensor
|
||||
|
||||
# Matrix multiplications - reuse unsqueezed tensor
|
||||
pred_aln_trg = pred_aln_trg.unsqueeze(0) # Do unsqueeze once
|
||||
en = d.transpose(-1, -2) @ pred_aln_trg
|
||||
F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
|
||||
# Final text encoding and decoding
|
||||
t_en = model.text_encoder(tokens, input_lengths, text_mask)
|
||||
asr = t_en @ pred_aln_trg
|
||||
del t_en # Free large intermediate tensor
|
||||
|
||||
# Text encoding and final decoding
|
||||
t_en = model.text_encoder(tokens, input_lengths, text_mask)
|
||||
asr = t_en @ pred_aln_trg
|
||||
|
||||
return model.decoder(asr, F0_pred, N_pred, s_ref).squeeze().cpu().numpy()
|
||||
# Final decoding and transfer to CPU
|
||||
output = model.decoder(asr, F0_pred, N_pred, s_ref)
|
||||
result = output.squeeze().cpu().numpy()
|
||||
|
||||
return result
|
||||
|
||||
finally:
|
||||
# Let PyTorch handle most cleanup automatically
|
||||
# Only explicitly free the largest tensors
|
||||
del pred_aln_trg, asr
|
||||
|
||||
|
||||
# def length_to_mask(lengths):
|
||||
|
@ -179,7 +195,7 @@ class TTSGPUModel(TTSBaseModel):
|
|||
def generate_from_tokens(
|
||||
cls, tokens: list[int], voicepack: torch.Tensor, speed: float
|
||||
) -> np.ndarray:
|
||||
"""Generate audio from tokens
|
||||
"""Generate audio from tokens with moderate memory management
|
||||
|
||||
Args:
|
||||
tokens: Token IDs
|
||||
|
@ -192,10 +208,55 @@ class TTSGPUModel(TTSBaseModel):
|
|||
if cls._instance is None:
|
||||
raise RuntimeError("GPU model not initialized")
|
||||
|
||||
# Get reference style
|
||||
ref_s = voicepack[len(tokens)]
|
||||
|
||||
# Generate audio
|
||||
audio = forward(cls._instance, tokens, ref_s, speed)
|
||||
|
||||
return audio
|
||||
try:
|
||||
device = cls._device
|
||||
|
||||
# Check memory pressure
|
||||
if torch.cuda.is_available():
|
||||
memory_allocated = torch.cuda.memory_allocated(device) / 1e9 # Convert to GB
|
||||
if memory_allocated > 2.0: # 2GB limit
|
||||
logger.info(
|
||||
f"Memory usage above 2GB threshold:{memory_allocated:.2f}GB "
|
||||
f"Clearing cache"
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
import gc
|
||||
gc.collect()
|
||||
|
||||
# Get reference style with proper device placement
|
||||
ref_s = voicepack[len(tokens)].clone().to(device)
|
||||
|
||||
# Generate audio
|
||||
audio = forward(cls._instance, tokens, ref_s, speed)
|
||||
|
||||
return audio
|
||||
|
||||
except RuntimeError as e:
|
||||
if "out of memory" in str(e):
|
||||
# On OOM, do a full cleanup and retry
|
||||
if torch.cuda.is_available():
|
||||
logger.warning("Out of memory detected, performing full cleanup")
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.empty_cache()
|
||||
import gc
|
||||
gc.collect()
|
||||
|
||||
# Log memory stats after cleanup
|
||||
memory_allocated = torch.cuda.memory_allocated(device)
|
||||
memory_reserved = torch.cuda.memory_reserved(device)
|
||||
logger.info(
|
||||
f"Memory after OOM cleanup: "
|
||||
f"Allocated: {memory_allocated / 1e9:.2f}GB, "
|
||||
f"Reserved: {memory_reserved / 1e9:.2f}GB"
|
||||
)
|
||||
|
||||
# Retry generation
|
||||
ref_s = voicepack[len(tokens)].clone().to(device)
|
||||
audio = forward(cls._instance, tokens, ref_s, speed)
|
||||
return audio
|
||||
raise
|
||||
|
||||
finally:
|
||||
# Only synchronize at the top level, no empty_cache
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
|
|
|
@ -6,6 +6,8 @@ services:
|
|||
working_dir: /app/Kokoro-82M
|
||||
command: >
|
||||
sh -c "
|
||||
mkdir -p /app/Kokoro-82M;
|
||||
cd /app/Kokoro-82M;
|
||||
rm -f .git/index.lock;
|
||||
if [ -z \"$(ls -A .)\" ]; then
|
||||
git clone https://huggingface.co/hexgrad/Kokoro-82M .
|
||||
|
@ -26,11 +28,11 @@ services:
|
|||
start_period: 1s
|
||||
|
||||
kokoro-tts:
|
||||
image: ghcr.io/remsky/kokoro-fastapi-cpu:latest
|
||||
# image: ghcr.io/remsky/kokoro-fastapi-cpu:latest
|
||||
# Uncomment below (and comment out above) to build from source instead of using the released image
|
||||
# build:
|
||||
# context: .
|
||||
# dockerfile: Dockerfile.cpu
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.cpu
|
||||
volumes:
|
||||
- ./api/src:/app/api/src
|
||||
- ./Kokoro-82M:/app/Kokoro-82M
|
||||
|
@ -46,6 +48,12 @@ services:
|
|||
- ONNX_MEMORY_PATTERN=true
|
||||
- ONNX_ARENA_EXTEND_STRATEGY=kNextPowerOfTwo
|
||||
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8880/health"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 30
|
||||
start_period: 30s
|
||||
depends_on:
|
||||
model-fetcher:
|
||||
condition: service_healthy
|
||||
|
@ -64,3 +72,6 @@ services:
|
|||
environment:
|
||||
- GRADIO_WATCH=True # Enable hot reloading
|
||||
- PYTHONUNBUFFERED=1 # Ensure Python output is not buffered
|
||||
depends_on:
|
||||
kokoro-tts:
|
||||
condition: service_healthy
|
||||
|
|
|
@ -8,6 +8,8 @@ services:
|
|||
working_dir: /app/Kokoro-82M
|
||||
command: >
|
||||
sh -c "
|
||||
mkdir -p /app/Kokoro-82M;
|
||||
cd /app/Kokoro-82M;
|
||||
if [ \"$$SKIP_MODEL_FETCH\" = \"true\" ]; then
|
||||
echo 'Skipping model fetch...' && touch .cloned;
|
||||
else
|
||||
|
@ -32,10 +34,10 @@ services:
|
|||
start_period: 1s
|
||||
|
||||
kokoro-tts:
|
||||
image: ghcr.io/remsky/kokoro-fastapi-gpu:latest
|
||||
# image: ghcr.io/remsky/kokoro-fastapi-gpu:latest
|
||||
# Uncomment below (and comment out above) to build from source instead of using the released image
|
||||
# build:
|
||||
# context: .
|
||||
build:
|
||||
context: .
|
||||
volumes:
|
||||
- ./api/src:/app/api/src
|
||||
- ./Kokoro-82M:/app/Kokoro-82M
|
||||
|
@ -51,7 +53,7 @@ services:
|
|||
count: 1
|
||||
capabilities: [gpu]
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8880/v1/audio/voices"]
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8880/health"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 30
|
||||
|
@ -65,7 +67,7 @@ services:
|
|||
image: ghcr.io/remsky/kokoro-fastapi-ui:latest
|
||||
# Uncomment below (and comment out above) to build from source instead of using the released image
|
||||
# build:
|
||||
# context: ./ui
|
||||
# context: ./ui
|
||||
ports:
|
||||
- "7860:7860"
|
||||
volumes:
|
||||
|
|
|
@ -21,8 +21,9 @@ def create_model_column(voice_ids: Optional[list] = None) -> Tuple[gr.Column, di
|
|||
voice_input = gr.Dropdown(
|
||||
choices=voice_ids,
|
||||
label="Voice",
|
||||
value=voice_ids[0] if voice_ids else None,
|
||||
value=None, # Start with no value to avoid errors
|
||||
interactive=True,
|
||||
allow_custom_value=True, # Allow temporary values during updates
|
||||
)
|
||||
format_input = gr.Dropdown(
|
||||
choices=config.AUDIO_FORMATS, label="Audio Format", value="mp3"
|
||||
|
|
|
@ -16,7 +16,7 @@ def create_output_column() -> Tuple[gr.Column, dict]:
|
|||
label="Previous Outputs",
|
||||
choices=files.list_output_files(),
|
||||
value=None,
|
||||
allow_custom_value=False,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
|
||||
play_btn = gr.Button("▶️ Play Selected", size="sm")
|
||||
|
@ -40,3 +40,4 @@ def create_output_column() -> Tuple[gr.Column, dict]:
|
|||
}
|
||||
|
||||
return col, components
|
||||
return col, components
|
||||
|
|
|
@ -12,9 +12,9 @@ def list_input_files() -> List[str]:
|
|||
|
||||
def list_output_files() -> List[str]:
|
||||
"""List all output audio files."""
|
||||
# Just return filenames since paths will be different inside/outside container
|
||||
return [
|
||||
os.path.join(OUTPUTS_DIR, f)
|
||||
for f in os.listdir(OUTPUTS_DIR)
|
||||
f for f in os.listdir(OUTPUTS_DIR)
|
||||
if any(f.endswith(ext) for ext in AUDIO_FORMATS)
|
||||
]
|
||||
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import os
|
||||
import shutil
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from . import api, files
|
||||
|
@ -97,11 +96,12 @@ def setup_event_handlers(components: dict):
|
|||
gr.Warning("Failed to generate speech. Please try again.")
|
||||
return [None, gr.update(choices=files.list_output_files())]
|
||||
|
||||
# Update list and select the newly generated file
|
||||
output_files = files.list_output_files()
|
||||
last_file = output_files[-1] if output_files else None
|
||||
return [
|
||||
result,
|
||||
gr.update(
|
||||
choices=files.list_output_files(), value=os.path.basename(result)
|
||||
),
|
||||
gr.update(choices=output_files, value=last_file),
|
||||
]
|
||||
|
||||
def generate_from_file(selected_file, voice, format, speed):
|
||||
|
@ -121,16 +121,19 @@ def setup_event_handlers(components: dict):
|
|||
gr.Warning("Failed to generate speech. Please try again.")
|
||||
return [None, gr.update(choices=files.list_output_files())]
|
||||
|
||||
# Update list and select the newly generated file
|
||||
output_files = files.list_output_files()
|
||||
last_file = output_files[-1] if output_files else None
|
||||
return [
|
||||
result,
|
||||
gr.update(
|
||||
choices=files.list_output_files(), value=os.path.basename(result)
|
||||
),
|
||||
gr.update(choices=output_files, value=last_file),
|
||||
]
|
||||
|
||||
def play_selected(file_path):
|
||||
if file_path and os.path.exists(file_path):
|
||||
return gr.update(value=file_path, visible=True)
|
||||
def play_selected(filename):
|
||||
if filename:
|
||||
file_path = os.path.join(files.OUTPUTS_DIR, filename)
|
||||
if os.path.exists(file_path):
|
||||
return gr.update(value=file_path, visible=True)
|
||||
return gr.update(visible=False)
|
||||
|
||||
def clear_files(voice, format, speed):
|
||||
|
|
Loading…
Add table
Reference in a new issue