diff --git a/Dockerfile b/Dockerfile
index 7d70af9..3cc5689 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -7,6 +7,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
     espeak-ng \
     git \
     libsndfile1 \
+    curl \
     && apt-get clean \
     && rm -rf /var/lib/apt/lists/*
 
diff --git a/api/src/services/tts_cpu.py b/api/src/services/tts_cpu.py
index 3ee3395..5284750 100644
--- a/api/src/services/tts_cpu.py
+++ b/api/src/services/tts_cpu.py
@@ -41,8 +41,6 @@ class TTSCPUModel(TTSBaseModel):
         if not onnx_path:
             return None
 
-        logger.info(f"Loading ONNX model from {onnx_path}")
-
         # Configure ONNX session for optimal performance
         session_options = SessionOptions()
 
diff --git a/api/src/services/tts_gpu.py b/api/src/services/tts_gpu.py
index 87e9ef2..1e5f4a1 100644
--- a/api/src/services/tts_gpu.py
+++ b/api/src/services/tts_gpu.py
@@ -38,48 +38,64 @@ from .text_processing import tokenize, phonemize
 #     return model.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze().cpu().numpy()
 
 
 @torch.no_grad()
 def forward(model, tokens, ref_s, speed):
-    """Forward pass through the model with light optimizations that preserve output quality"""
+    """Forward pass through the model with moderate memory management"""
     device = ref_s.device
+
+    try:
+        # Initial tensor setup with proper device placement
+        tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
+        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
+        text_mask = length_to_mask(input_lengths).to(device)
 
-    # Keep original token handling but optimize device placement
-    tokens = torch.LongTensor([[0, *tokens, 0]]).to(device)
-    input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
-    text_mask = length_to_mask(input_lengths).to(device)
+        # Split and clone reference signals with explicit device placement
+        s_content = ref_s[:, 128:].clone().to(device)
+        s_ref = ref_s[:, :128].clone().to(device)
 
-    # BERT and encoder pass
-    bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
-    d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
+        # BERT and encoder pass
+        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
+        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
 
-    # Split reference signal once for efficiency
-    s_content = ref_s[:, 128:]
-    s_ref = ref_s[:, :128]
+        # Predictor forward pass
+        d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
+        x, _ = model.predictor.lstm(d)
 
-    # Predictor forward pass
-    d = model.predictor.text_encoder(d_en, s_content, input_lengths, text_mask)
-    x, _ = model.predictor.lstm(d)
+        # Duration prediction
+        duration = model.predictor.duration_proj(x)
+        duration = torch.sigmoid(duration).sum(axis=-1) / speed
+        pred_dur = torch.round(duration).clamp(min=1).long()
+        # Only cleanup large intermediates
+        del duration, x
 
-    # Duration prediction - keeping original logic
-    duration = model.predictor.duration_proj(x)
-    duration = torch.sigmoid(duration).sum(axis=-1) / speed
-    pred_dur = torch.round(duration).clamp(min=1).long()
+        # Alignment matrix construction
+        pred_aln_trg = torch.zeros(input_lengths.item(), pred_dur.sum().item(), device=device)
+        c_frame = 0
+        for i in range(pred_aln_trg.size(0)):
+            pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
+            c_frame += pred_dur[0, i].item()
+        pred_aln_trg = pred_aln_trg.unsqueeze(0)
 
-    # Alignment matrix construction - keeping original approach for quality
-    pred_aln_trg = torch.zeros(input_lengths, pred_dur.sum().item(), device=device)
-    c_frame = 0
-    for i in range(pred_aln_trg.size(0)):
-        pred_aln_trg[i, c_frame : c_frame + pred_dur[0, i].item()] = 1
-        c_frame += pred_dur[0, i].item()
+        # Matrix multiplications with selective cleanup
+        en = d.transpose(-1, -2) @ pred_aln_trg
+        del d  # Free large intermediate tensor
+
+        F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
+        del en  # Free large intermediate tensor
 
-    # Matrix multiplications - reuse unsqueezed tensor
-    pred_aln_trg = pred_aln_trg.unsqueeze(0)  # Do unsqueeze once
-    en = d.transpose(-1, -2) @ pred_aln_trg
-    F0_pred, N_pred = model.predictor.F0Ntrain(en, s_content)
+        # Final text encoding and decoding
+        t_en = model.text_encoder(tokens, input_lengths, text_mask)
+        asr = t_en @ pred_aln_trg
+        del t_en  # Free large intermediate tensor
 
-    # Text encoding and final decoding
-    t_en = model.text_encoder(tokens, input_lengths, text_mask)
-    asr = t_en @ pred_aln_trg
-
-    return model.decoder(asr, F0_pred, N_pred, s_ref).squeeze().cpu().numpy()
+        # Final decoding and transfer to CPU
+        output = model.decoder(asr, F0_pred, N_pred, s_ref)
+        result = output.squeeze().cpu().numpy()
+
+        return result
+
+    finally:
+        # Let PyTorch handle most cleanup automatically
+        # Only explicitly free the largest tensors
+        del pred_aln_trg, asr
 
 
 # def length_to_mask(lengths):
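Note on the forward() hunk above: the alignment-matrix loop it keeps expands each token's predicted duration into a run of ones, so a single matmul can upsample token-level features to frame level. A minimal sketch on toy values (the durations and feature width below are made up, not taken from the model):

```python
# Illustrative only: how the duration-to-alignment expansion behaves.
import torch

pred_dur = torch.tensor([[2, 1, 3]])       # hypothetical per-token frame counts
num_tokens = pred_dur.shape[-1]
total_frames = int(pred_dur.sum().item())

pred_aln_trg = torch.zeros(num_tokens, total_frames)
c_frame = 0
for i in range(num_tokens):
    pred_aln_trg[i, c_frame : c_frame + int(pred_dur[0, i])] = 1
    c_frame += int(pred_dur[0, i])
pred_aln_trg = pred_aln_trg.unsqueeze(0)   # (1, tokens, frames)

d = torch.randn(1, num_tokens, 4)          # stand-in for predictor features
en = d.transpose(-1, -2) @ pred_aln_trg    # (1, 4, frames): token features repeated per frame
print(en.shape)                            # torch.Size([1, 4, 6])
```

The hunk leaves this logic intact; the functional changes are the explicit .item() on input_lengths, cloning the reference signals onto the device, and freeing large intermediates with del inside a try/finally.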
@@ -179,7 +195,7 @@ class TTSGPUModel(TTSBaseModel):
     def generate_from_tokens(
         cls, tokens: list[int], voicepack: torch.Tensor, speed: float
     ) -> np.ndarray:
-        """Generate audio from tokens
+        """Generate audio from tokens with moderate memory management
 
         Args:
             tokens: Token IDs
@@ -192,10 +208,55 @@ class TTSGPUModel(TTSBaseModel):
         if cls._instance is None:
             raise RuntimeError("GPU model not initialized")
 
-        # Get reference style
-        ref_s = voicepack[len(tokens)]
-
-        # Generate audio
-        audio = forward(cls._instance, tokens, ref_s, speed)
-
-        return audio
+        try:
+            device = cls._device
+
+            # Check memory pressure
+            if torch.cuda.is_available():
+                memory_allocated = torch.cuda.memory_allocated(device) / 1e9  # Convert to GB
+                if memory_allocated > 2.0:  # 2GB limit
+                    logger.info(
+                        f"Memory usage above 2GB threshold: {memory_allocated:.2f}GB "
+                        f"Clearing cache"
+                    )
+                    torch.cuda.empty_cache()
+                    import gc
+                    gc.collect()
+
+            # Get reference style with proper device placement
+            ref_s = voicepack[len(tokens)].clone().to(device)
+
+            # Generate audio
+            audio = forward(cls._instance, tokens, ref_s, speed)
+
+            return audio
+
+        except RuntimeError as e:
+            if "out of memory" in str(e):
+                # On OOM, do a full cleanup and retry
+                if torch.cuda.is_available():
+                    logger.warning("Out of memory detected, performing full cleanup")
+                    torch.cuda.synchronize()
+                    torch.cuda.empty_cache()
+                    import gc
+                    gc.collect()
+
+                # Log memory stats after cleanup
+                memory_allocated = torch.cuda.memory_allocated(device)
+                memory_reserved = torch.cuda.memory_reserved(device)
+                logger.info(
+                    f"Memory after OOM cleanup: "
+                    f"Allocated: {memory_allocated / 1e9:.2f}GB, "
+                    f"Reserved: {memory_reserved / 1e9:.2f}GB"
+                )
+
+                # Retry generation
+                ref_s = voicepack[len(tokens)].clone().to(device)
+                audio = forward(cls._instance, tokens, ref_s, speed)
+                return audio
+            raise
+
+        finally:
+            # Only synchronize at the top level, no empty_cache
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
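The generate_from_tokens() changes above implement a check-then-retry policy: clear the CUDA cache when allocated memory crosses the 2 GB threshold, and on an out-of-memory RuntimeError do a full cleanup and retry once. The same pattern, factored into a standalone helper for illustration (the helper name and the single-retry policy are assumptions, not part of this diff):

```python
# Illustrative only: retry a CUDA-backed call once after an out-of-memory error.
import gc

import torch


def run_with_oom_retry(fn, *args, **kwargs):
    """Call fn once; on a CUDA out-of-memory RuntimeError, clear caches and retry once."""
    try:
        return fn(*args, **kwargs)
    except RuntimeError as e:
        if "out of memory" not in str(e) or not torch.cuda.is_available():
            raise
        # Mirror the cleanup in the handler above before the single retry
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        gc.collect()
        return fn(*args, **kwargs)
```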
diff --git a/docker-compose.cpu.yml b/docker-compose.cpu.yml
index 2266364..de4b4fe 100644
--- a/docker-compose.cpu.yml
+++ b/docker-compose.cpu.yml
@@ -6,6 +6,8 @@ services:
     working_dir: /app/Kokoro-82M
     command: >
       sh -c "
+        mkdir -p /app/Kokoro-82M;
+        cd /app/Kokoro-82M;
         rm -f .git/index.lock;
         if [ -z \"$(ls -A .)\" ]; then
           git clone https://huggingface.co/hexgrad/Kokoro-82M .
@@ -26,11 +28,11 @@ services:
       start_period: 1s
 
   kokoro-tts:
-    image: ghcr.io/remsky/kokoro-fastapi-cpu:latest
+    # image: ghcr.io/remsky/kokoro-fastapi-cpu:latest
     # Uncomment below (and comment out above) to build from source instead of using the released image
-    # build:
-    #   context: .
-    #   dockerfile: Dockerfile.cpu
+    build:
+      context: .
+      dockerfile: Dockerfile.cpu
     volumes:
       - ./api/src:/app/api/src
       - ./Kokoro-82M:/app/Kokoro-82M
@@ -46,6 +48,12 @@ services:
      - ONNX_MEMORY_PATTERN=true
      - ONNX_ARENA_EXTEND_STRATEGY=kNextPowerOfTwo
+    healthcheck:
+      test: ["CMD", "curl", "-f", "http://localhost:8880/health"]
+      interval: 10s
+      timeout: 5s
+      retries: 30
+      start_period: 30s
     depends_on:
       model-fetcher:
         condition: service_healthy
 
@@ -64,3 +72,6 @@ services:
     environment:
      - GRADIO_WATCH=True # Enable hot reloading
      - PYTHONUNBUFFERED=1 # Ensure Python output is not buffered
+    depends_on:
+      kokoro-tts:
+        condition: service_healthy
diff --git a/docker-compose.yml b/docker-compose.yml
index 3e4bd3d..7970695 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -8,6 +8,8 @@ services:
     working_dir: /app/Kokoro-82M
     command: >
       sh -c "
+        mkdir -p /app/Kokoro-82M;
+        cd /app/Kokoro-82M;
         if [ \"$$SKIP_MODEL_FETCH\" = \"true\" ]; then
           echo 'Skipping model fetch...' && touch .cloned;
         else
@@ -32,10 +34,10 @@ services:
       start_period: 1s
 
   kokoro-tts:
-    image: ghcr.io/remsky/kokoro-fastapi-gpu:latest
+    # image: ghcr.io/remsky/kokoro-fastapi-gpu:latest
     # Uncomment below (and comment out above) to build from source instead of using the released image
-    # build:
-    #   context: .
+    build:
+      context: .
     volumes:
       - ./api/src:/app/api/src
       - ./Kokoro-82M:/app/Kokoro-82M
@@ -51,7 +53,7 @@ services:
               count: 1
               capabilities: [gpu]
     healthcheck:
-      test: ["CMD", "curl", "-f", "http://localhost:8880/v1/audio/voices"]
+      test: ["CMD", "curl", "-f", "http://localhost:8880/health"]
       interval: 10s
       timeout: 5s
       retries: 30
@@ -65,7 +67,7 @@ services:
     image: ghcr.io/remsky/kokoro-fastapi-ui:latest
     # Uncomment below (and comment out above) to build from source instead of using the released image
     # build:
-    #  context: ./ui
+    #   context: ./ui
     ports:
       - "7860:7860"
     volumes:
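Both compose files now point their healthchecks at http://localhost:8880/health, and the Dockerfile change above adds curl so the check can run inside the CPU image. The API container is assumed to expose a lightweight route of roughly this shape; a hypothetical FastAPI sketch, not taken from the repository:

```python
# Hypothetical sketch of the /health route the compose healthchecks assume exists.
# The real application's route and response shape may differ.
from fastapi import FastAPI

app = FastAPI()


@app.get("/health")
async def health() -> dict:
    # Liveness only: returns 200 as soon as the app is serving requests
    return {"status": "ok"}
```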
"""List all output audio files.""" + # Just return filenames since paths will be different inside/outside container return [ - os.path.join(OUTPUTS_DIR, f) - for f in os.listdir(OUTPUTS_DIR) + f for f in os.listdir(OUTPUTS_DIR) if any(f.endswith(ext) for ext in AUDIO_FORMATS) ] diff --git a/ui/lib/handlers.py b/ui/lib/handlers.py index eba6cda..30062a0 100644 --- a/ui/lib/handlers.py +++ b/ui/lib/handlers.py @@ -1,6 +1,5 @@ import os import shutil - import gradio as gr from . import api, files @@ -97,11 +96,12 @@ def setup_event_handlers(components: dict): gr.Warning("Failed to generate speech. Please try again.") return [None, gr.update(choices=files.list_output_files())] + # Update list and select the newly generated file + output_files = files.list_output_files() + last_file = output_files[-1] if output_files else None return [ result, - gr.update( - choices=files.list_output_files(), value=os.path.basename(result) - ), + gr.update(choices=output_files, value=last_file), ] def generate_from_file(selected_file, voice, format, speed): @@ -121,16 +121,19 @@ def setup_event_handlers(components: dict): gr.Warning("Failed to generate speech. Please try again.") return [None, gr.update(choices=files.list_output_files())] + # Update list and select the newly generated file + output_files = files.list_output_files() + last_file = output_files[-1] if output_files else None return [ result, - gr.update( - choices=files.list_output_files(), value=os.path.basename(result) - ), + gr.update(choices=output_files, value=last_file), ] - def play_selected(file_path): - if file_path and os.path.exists(file_path): - return gr.update(value=file_path, visible=True) + def play_selected(filename): + if filename: + file_path = os.path.join(files.OUTPUTS_DIR, filename) + if os.path.exists(file_path): + return gr.update(value=file_path, visible=True) return gr.update(visible=False) def clear_files(voice, format, speed):