Mirror of https://github.com/remsky/Kokoro-FastAPI.git (synced 2025-08-05 16:48:53 +00:00)

fix: ui stability, memory safeties

commit f4dc292440, parent 234445f5ae
9 changed files with 144 additions and 66 deletions
@@ -7,6 +7,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
     espeak-ng \
     git \
     libsndfile1 \
+    curl \
     && apt-get clean \
     && rm -rf /var/lib/apt/lists/*
 
@@ -41,8 +41,6 @@ class TTSCPUModel(TTSBaseModel):
         if not onnx_path:
             return None
 
-        logger.info(f"Loading ONNX model from {onnx_path}")
-
         # Configure ONNX session for optimal performance
         session_options = SessionOptions()
 
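The diff leaves the ONNX session configuration itself untouched. For readers unfamiliar with those knobs, here is a hedged sketch of what such a CPU session setup can look like; it is illustrative only, using documented onnxruntime options that correspond to the ONNX_* environment variables added to docker-compose further down, and is not the project's exact loader.

    # Hedged sketch, not project code: SessionOptions knobs matching the ONNX_* env vars below.
    import onnxruntime as ort

    def load_onnx_model(onnx_path: str) -> ort.InferenceSession:
        session_options = ort.SessionOptions()
        session_options.enable_mem_pattern = True       # cf. ONNX_MEMORY_PATTERN=true
        session_options.enable_cpu_mem_arena = True      # reuse a CPU memory arena between runs
        session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        session_options.intra_op_num_threads = 4          # assumed value, tune per host
        return ort.InferenceSession(
            onnx_path, sess_options=session_options, providers=["CPUExecutionProvider"]
        )
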
@@ -38,48 +38,64 @@ from .text_processing import tokenize, phonemize
 # return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
 @torch.no_grad()
 def forward(model, tokens, ref_s, speed):
-    """Forward pass through the model with light optimizations that preserve output quality"""
+    """Forward pass through the model with moderate memory management"""
     device = ref_s.device
 
-    # Keep original token handling but optimize device placement
-    tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
-    input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
-    text_mask = length_to_mask(input_lengths).to(device)
-
-    # BERT and encoder pass
-    bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
-    d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
-
-    # Split reference signal once for efficiency
-    s_content = ref_s[:, 128:]
-    s_ref = ref_s[:, :128]
-
-    # Predictor forward pass
-    d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
-    x, _ = model.predictor.lstm(d)
-
-    # Duration prediction - keeping original logic
-    duration = model.predictor.duration_proj(x)
-    duration = torch.sigmoid(duration).sum(axis=-1) / speed
-    pred_dur = torch.round(duration).clamp(min=1).long()
-
-    # Alignment matrix construction - keeping original approach for quality
-    pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item(), device=device)
-    c_frame = 0
-    for i in range(pred_aln_trg.size(0)):
-        pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
-        c_frame += pred_dur[0, i].item()
-
-    # Matrix multiplications - reuse unsqueezed tensor
-    pred_aln_trg = pred_aln_trg.unsqueeze(0)  # Do unsqueeze once
-    en = d.transpose(-1, -2) @ pred_aln_trg
-    F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
-
-    # Text encoding and final decoding
-    t_en = model.text_encoder(tokens, input_lengths, text_mask)
-    asr = t_en @ pred_aln_trg
-
-    return model.decoder(asr, F0_pred, N_pred, s_ref).squeeze().cpu().numpy()
+    try:
+        # Initial tensor setup with proper device placement
+        tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
+        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
+        text_mask = length_to_mask(input_lengths).to(device)
+
+        # Split and clone reference signals with explicit device placement
+        s_content = ref_s[:, 128:].clone().to(device)
+        s_ref = ref_s[:, :128].clone().to(device)
+
+        # BERT and encoder pass
+        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
+        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
+
+        # Predictor forward pass
+        d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
+        x, _ = model.predictor.lstm(d)
+
+        # Duration prediction
+        duration = model.predictor.duration_proj(x)
+        duration = torch.sigmoid(duration).sum(axis=-1) / speed
+        pred_dur = torch.round(duration).clamp(min=1).long()
+        # Only cleanup large intermediates
+        del duration, x
+
+        # Alignment matrix construction
+        pred_aln_trg = torch.zeros(input_lengths.item(), pred_dur.sum().item(), device=device)
+        c_frame = 0
+        for i in range(pred_aln_trg.size(0)):
+            pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
+            c_frame += pred_dur[0, i].item()
+        pred_aln_trg = pred_aln_trg.unsqueeze(0)
+
+        # Matrix multiplications with selective cleanup
+        en = d.transpose(-1, -2) @ pred_aln_trg
+        del d  # Free large intermediate tensor
+
+        F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
+        del en  # Free large intermediate tensor
+
+        # Final text encoding and decoding
+        t_en = model.text_encoder(tokens, input_lengths, text_mask)
+        asr = t_en @ pred_aln_trg
+        del t_en  # Free large intermediate tensor
+
+        # Final decoding and transfer to CPU
+        output = model.decoder(asr, F0_pred, N_pred, s_ref)
+        result = output.squeeze().cpu().numpy()
+
+        return result
+
+    finally:
+        # Let PyTorch handle most cleanup automatically
+        # Only explicitly free the largest tensors
+        del pred_aln_trg, asr
 
 
 # def length_to_mask(lengths):
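The length_to_mask helper only appears as a commented-out context line above. Assuming the usual convention that the mask is True at padded positions (so that (~text_mask).int() becomes the attention mask), a typical implementation looks roughly like this sketch; it is not necessarily the repository's version.

    import torch

    def length_to_mask(lengths: torch.Tensor) -> torch.Tensor:
        """Boolean mask that is True at padded positions (assumed semantics)."""
        max_len = int(lengths.max().item())
        positions = torch.arange(max_len, device=lengths.device).unsqueeze(0)  # (1, max_len)
        return positions >= lengths.unsqueeze(1)  # (batch, max_len), True beyond each length
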
@@ -179,7 +195,7 @@ class TTSGPUModel(TTSBaseModel):
     def generate_from_tokens(
         cls, tokens: list[int], voicepack: torch.Tensor, speed: float
     ) -> np.ndarray:
-        """Generate audio from tokens
+        """Generate audio from tokens with moderate memory management
 
         Args:
             tokens: Token IDs
@@ -192,10 +208,55 @@ class TTSGPUModel(TTSBaseModel):
         if cls._instance is None:
             raise RuntimeError("GPU model not initialized")
 
-        # Get reference style
-        ref_s = voicepack[len(tokens)]
-
-        # Generate audio
-        audio = forward(cls._instance, tokens, ref_s, speed)
-
-        return audio
+        try:
+            device = cls._device
+
+            # Check memory pressure
+            if torch.cuda.is_available():
+                memory_allocated = torch.cuda.memory_allocated(device) / 1e9  # Convert to GB
+                if memory_allocated > 2.0:  # 2GB limit
+                    logger.info(
+                        f"Memory usage above 2GB threshold:{memory_allocated:.2f}GB "
+                        f"Clearing cache"
+                    )
+                    torch.cuda.empty_cache()
+                    import gc
+                    gc.collect()
+
+            # Get reference style with proper device placement
+            ref_s = voicepack[len(tokens)].clone().to(device)
+
+            # Generate audio
+            audio = forward(cls._instance, tokens, ref_s, speed)
+
+            return audio
+
+        except RuntimeError as e:
+            if "out of memory" in str(e):
+                # On OOM, do a full cleanup and retry
+                if torch.cuda.is_available():
+                    logger.warning("Out of memory detected, performing full cleanup")
+                    torch.cuda.synchronize()
+                    torch.cuda.empty_cache()
+                    import gc
+                    gc.collect()
+
+                    # Log memory stats after cleanup
+                    memory_allocated = torch.cuda.memory_allocated(device)
+                    memory_reserved = torch.cuda.memory_reserved(device)
+                    logger.info(
+                        f"Memory after OOM cleanup: "
+                        f"Allocated: {memory_allocated / 1e9:.2f}GB, "
+                        f"Reserved: {memory_reserved / 1e9:.2f}GB"
+                    )
+
+                # Retry generation
+                ref_s = voicepack[len(tokens)].clone().to(device)
+                audio = forward(cls._instance, tokens, ref_s, speed)
+                return audio
+            raise
+
+        finally:
+            # Only synchronize at the top level, no empty_cache
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
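The memory-pressure check and single OOM retry above can be read as one reusable pattern. A hedged, stand-alone sketch of that pattern follows; the helper name and the 2 GB threshold are illustrative and not part of the repository.

    import gc
    import torch

    def run_with_oom_retry(fn, *args, threshold_gb: float = 2.0, **kwargs):
        """Call fn, clearing the CUDA cache first if usage is high, and retry once on OOM."""
        if torch.cuda.is_available() and torch.cuda.memory_allocated() / 1e9 > threshold_gb:
            torch.cuda.empty_cache()
            gc.collect()
        try:
            return fn(*args, **kwargs)
        except RuntimeError as e:
            if "out of memory" not in str(e) or not torch.cuda.is_available():
                raise
            # Full cleanup, then a single retry
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
            gc.collect()
            return fn(*args, **kwargs)
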
@@ -6,6 +6,8 @@ services:
     working_dir: /app/Kokoro-82M
     command: >
       sh -c "
+        mkdir -p /app/Kokoro-82M;
+        cd /app/Kokoro-82M;
         rm -f .git/index.lock;
         if [ -z \"$(ls -A .)\" ]; then
           git clone https://huggingface.co/hexgrad/Kokoro-82M .
@@ -26,11 +28,11 @@ services:
       start_period: 1s
 
   kokoro-tts:
-    image: ghcr.io/remsky/kokoro-fastapi-cpu:latest
+    # image: ghcr.io/remsky/kokoro-fastapi-cpu:latest
     # Uncomment below (and comment out above) to build from source instead of using the released image
-    # build:
-    #   context: .
-    #   dockerfile: Dockerfile.cpu
+    build:
+      context: .
+      dockerfile: Dockerfile.cpu
     volumes:
       - ./api/src:/app/api/src
       - ./Kokoro-82M:/app/Kokoro-82M
@@ -46,6 +48,12 @@ services:
       - ONNX_MEMORY_PATTERN=true
       - ONNX_ARENA_EXTEND_STRATEGY=kNextPowerOfTwo
 
+    healthcheck:
+      test: ["CMD", "curl", "-f", "http://localhost:8880/health"]
+      interval: 10s
+      timeout: 5s
+      retries: 30
+      start_period: 30s
     depends_on:
       model-fetcher:
         condition: service_healthy
@@ -64,3 +72,6 @@ services:
     environment:
       - GRADIO_WATCH=True  # Enable hot reloading
       - PYTHONUNBUFFERED=1  # Ensure Python output is not buffered
+    depends_on:
+      kokoro-tts:
+        condition: service_healthy
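Both compose files now gate dependent services on the API's /health endpoint via curl. The same check can be reproduced from Python when debugging startup order; this small sketch mirrors the URL and retry counts from the compose healthchecks but is not project code.

    import time
    import urllib.error
    import urllib.request

    def wait_for_health(url: str = "http://localhost:8880/health",
                        retries: int = 30, interval: float = 10.0) -> bool:
        """Poll the health endpoint the same way the compose healthcheck does."""
        for _ in range(retries):
            try:
                with urllib.request.urlopen(url, timeout=5) as resp:
                    if resp.status == 200:
                        return True
            except (urllib.error.URLError, OSError):
                pass  # service not up yet, keep polling
            time.sleep(interval)
        return False
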
@@ -8,6 +8,8 @@ services:
     working_dir: /app/Kokoro-82M
     command: >
       sh -c "
+        mkdir -p /app/Kokoro-82M;
+        cd /app/Kokoro-82M;
         if [ \"$$SKIP_MODEL_FETCH\" = \"true\" ]; then
           echo 'Skipping model fetch...' && touch .cloned;
         else
@@ -32,10 +34,10 @@ services:
       start_period: 1s
 
   kokoro-tts:
-    image: ghcr.io/remsky/kokoro-fastapi-gpu:latest
+    # image: ghcr.io/remsky/kokoro-fastapi-gpu:latest
     # Uncomment below (and comment out above) to build from source instead of using the released image
-    # build:
-    #   context: .
+    build:
+      context: .
     volumes:
       - ./api/src:/app/api/src
       - ./Kokoro-82M:/app/Kokoro-82M
@@ -51,7 +53,7 @@ services:
           count: 1
           capabilities: [gpu]
     healthcheck:
-      test: ["CMD", "curl", "-f", "http://localhost:8880/v1/audio/voices"]
+      test: ["CMD", "curl", "-f", "http://localhost:8880/health"]
      interval: 10s
      timeout: 5s
      retries: 30
@@ -21,8 +21,9 @@ def create_model_column(voice_ids: Optional[list] = None) -> Tuple[gr.Column, dict]:
         voice_input = gr.Dropdown(
             choices=voice_ids,
             label="Voice",
-            value=voice_ids[0] if voice_ids else None,
+            value=None,  # Start with no value to avoid errors
             interactive=True,
+            allow_custom_value=True,  # Allow temporary values during updates
         )
         format_input = gr.Dropdown(
             choices=config.AUDIO_FORMATS, label="Audio Format", value="mp3"
@@ -16,7 +16,7 @@ def create_output_column() -> Tuple[gr.Column, dict]:
             label="Previous Outputs",
             choices=files.list_output_files(),
             value=None,
-            allow_custom_value=False,
+            allow_custom_value=True,
         )
 
         play_btn = gr.Button("▶️ Play Selected", size="sm")
@@ -40,3 +40,4 @@ def create_output_column() -> Tuple[gr.Column, dict]:
     }
 
     return col, components
+    return col, components
@@ -12,9 +12,9 @@ def list_input_files() -> List[str]:
 
 def list_output_files() -> List[str]:
     """List all output audio files."""
+    # Just return filenames since paths will be different inside/outside container
     return [
-        os.path.join(OUTPUTS_DIR, f)
-        for f in os.listdir(OUTPUTS_DIR)
+        f for f in os.listdir(OUTPUTS_DIR)
         if any(f.endswith(ext) for ext in AUDIO_FORMATS)
     ]
 
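Returning bare filenames makes sense because the outputs directory is mounted at different paths inside and outside the container, so each caller joins the name against its own OUTPUTS_DIR. A hedged sketch of that consuming side, mirroring what play_selected does in the handlers below; the names here are assumptions, not repository code.

    import os
    from typing import Optional

    # Assumed to mirror files.OUTPUTS_DIR on the consuming side.
    OUTPUTS_DIR = os.path.join(os.path.dirname(__file__), "outputs")

    def resolve_output(filename: str) -> Optional[str]:
        """Join a bare filename from list_output_files() onto this process's outputs dir."""
        path = os.path.join(OUTPUTS_DIR, filename)
        return path if os.path.exists(path) else None
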
@@ -1,6 +1,5 @@
 import os
 import shutil
-
 import gradio as gr
 
 from . import api, files
@@ -97,11 +96,12 @@ def setup_event_handlers(components: dict):
             gr.Warning("Failed to generate speech. Please try again.")
             return [None, gr.update(choices=files.list_output_files())]
 
+        # Update list and select the newly generated file
+        output_files = files.list_output_files()
+        last_file = output_files[-1] if output_files else None
         return [
             result,
-            gr.update(
-                choices=files.list_output_files(), value=os.path.basename(result)
-            ),
+            gr.update(choices=output_files, value=last_file),
         ]
 
     def generate_from_file(selected_file, voice, format, speed):
@@ -121,15 +121,18 @@ def setup_event_handlers(components: dict):
             gr.Warning("Failed to generate speech. Please try again.")
             return [None, gr.update(choices=files.list_output_files())]
 
+        # Update list and select the newly generated file
+        output_files = files.list_output_files()
+        last_file = output_files[-1] if output_files else None
         return [
             result,
-            gr.update(
-                choices=files.list_output_files(), value=os.path.basename(result)
-            ),
+            gr.update(choices=output_files, value=last_file),
         ]
 
-    def play_selected(file_path):
-        if file_path and os.path.exists(file_path):
-            return gr.update(value=file_path, visible=True)
+    def play_selected(filename):
+        if filename:
+            file_path = os.path.join(files.OUTPUTS_DIR, filename)
+            if os.path.exists(file_path):
+                return gr.update(value=file_path, visible=True)
         return gr.update(visible=False)
|
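The handlers above are wired to Gradio components elsewhere in the UI package; that wiring is not part of this diff. For orientation, here is a hedged sketch of what such wiring typically looks like; the component keys and function name are illustrative assumptions, not the repository's exact ones.

    def wire_play_button(components: dict, play_selected) -> None:
        """Illustrative sketch: connect the 'Play Selected' button to the play_selected handler."""
        # Assumed component keys; the real dict is built in the components module.
        components["play_btn"].click(
            fn=play_selected,                      # receives the filename chosen in the dropdown
            inputs=[components["output_files"]],   # dropdown of previous outputs
            outputs=[components["audio_output"]],  # audio player that play_selected updates
        )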