"""PyTorch inference backend with environment-based configuration."""
|
||
|
|
||
|
import gc
|
||
|
import os
|
||
|
from typing import AsyncGenerator, Optional, List, Union, Tuple
|
||
|
from contextlib import nullcontext
|
||
|
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
from loguru import logger
|
||
|
|
||
|
from ..core import paths
|
||
|
from ..core.model_config import model_config
|
||
|
from ..core.config import settings
|
||
|
from .base import BaseModelBackend
|
||
|
from kokoro import KModel, KPipeline
|
||
|
|
||
|
|
||
|
class KokoroV1(BaseModelBackend):
    """Kokoro package based inference backend with environment-based configuration."""

    def __init__(self):
        """Initialize backend based on environment configuration."""
        super().__init__()

        # Configure device based on settings
        self._device = (
            "cuda" if settings.use_gpu and torch.cuda.is_available() else "cpu"
        )
        self._model: Optional[KModel] = None
        self._pipeline: Optional[KPipeline] = None

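    # Device selection is driven entirely by configuration: settings.use_gpu plus
    # torch.cuda.is_available() picks "cuda" or "cpu" here, and the GPU-specific
    # knobs used later (model_config.pytorch_gpu.retry_on_oom, .sync_cuda and
    # .memory_threshold) come from the same environment-based config.
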
    async def load_model(self, path: str) -> None:
        """Load Kokoro model.

        Args:
            path: Path to model file

        Raises:
            RuntimeError: If model loading fails
        """
        try:
            # Get verified model path
            model_path = await paths.get_model_path(path)

            # Get config.json path from the same directory
            config_path = os.path.join(os.path.dirname(model_path), 'config.json')

            if not os.path.exists(config_path):
                raise RuntimeError(f"Config file not found: {config_path}")

            logger.info(f"Loading Kokoro model on {self._device}")
            logger.info(f"Config path: {config_path}")
            logger.info(f"Model path: {model_path}")

            # Initialize model with config and weights
            self._model = KModel(config=config_path, model=model_path).to(self._device).eval()
            # Initialize pipeline with American English by default
            self._pipeline = KPipeline(lang_code='a', model=self._model, device=self._device)

        except Exception as e:
            raise RuntimeError(f"Failed to load Kokoro model: {e}")

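    # load_model derives config.json from the directory of the weights file, so the
    # two must sit side by side; an illustrative (assumed) layout would be
    # v1_0/kokoro-v1_0.pth next to v1_0/config.json.
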
    async def generate(
        self, text: str, voice: Union[str, Tuple[str, Union[torch.Tensor, str]]], speed: float = 1.0
    ) -> AsyncGenerator[np.ndarray, None]:
        """Generate audio using model.

        Args:
            text: Input text to synthesize
            voice: Either a voice path string or a tuple of (voice_name, voice_tensor_or_path)
            speed: Speed multiplier

        Yields:
            Generated audio chunks

        Raises:
            RuntimeError: If generation fails
        """
        if not self.is_loaded:
            raise RuntimeError("Model not loaded")

        try:
            # Memory management for GPU
            if self._device == "cuda":
                if self._check_memory():
                    self._clear_memory()

            # Handle voice input
            if isinstance(voice, str):
                voice_path = voice  # Voice path provided directly
                logger.debug(f"Using voice path directly: {voice_path}")
                # Get language code from first letter of voice name
                try:
                    name = os.path.basename(voice_path)
                    logger.debug(f"Voice basename: {name}")
                    if name.endswith('.pt'):
                        name = name[:-3]
                    lang_code = name[0]
                    logger.debug(f"Extracted language code: {lang_code}")
                except Exception as e:
                    # Default to American English if we can't get language code
                    logger.warning(f"Failed to extract language code: {e}, defaulting to 'a'")
                    lang_code = 'a'
            else:
                # Unpack voice name and tensor/path
                voice_name, voice_data = voice
                # If voice_data is a path, use it directly
                if isinstance(voice_data, str):
                    voice_path = voice_data
                    logger.debug(f"Using provided voice path: {voice_path}")
                else:
                    # Save tensor to temporary file
                    import tempfile
                    temp_dir = tempfile.gettempdir()
                    voice_path = os.path.join(temp_dir, f"{voice_name}.pt")
                    logger.debug(f"Saving voice tensor to: {voice_path}")
                    torch.save(voice_data, voice_path)
                # Get language code from voice name
                lang_code = voice_name[0]
                logger.debug(f"Using language code '{lang_code}' from voice name {voice_name}")

            # Update pipeline's language code if needed
            if self._pipeline.lang_code != lang_code:
                logger.debug(f"Creating pipeline with lang_code='{lang_code}'")
                self._pipeline = KPipeline(lang_code=lang_code, model=self._model, device=self._device)

            # Generate audio using pipeline
            logger.debug(f"Generating audio for text: '{text[:100]}...'")
            for i, result in enumerate(self._pipeline(text, voice=voice_path, speed=speed)):
                logger.debug(f"Processing chunk {i+1}")
                if result.audio is not None:
                    logger.debug(f"Got audio chunk {i+1} with shape: {result.audio.shape}")
                    yield result.audio.numpy()
                else:
                    logger.warning(f"No audio in chunk {i+1}")

        except Exception as e:
            logger.error(f"Generation failed: {e}")
            if (
                self._device == "cuda"
                and model_config.pytorch_gpu.retry_on_oom
                and "out of memory" in str(e).lower()
            ):
                # Retry once after clearing GPU memory; if the retry stream
                # completes, stop here instead of re-raising the original error
                self._clear_memory()
                async for chunk in self.generate(text, voice, speed):
                    yield chunk
                return
            raise
        finally:
            if self._device == "cuda" and model_config.pytorch_gpu.sync_cuda:
                torch.cuda.synchronize()

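    # The voice argument accepts either a path to a saved voice tensor or a
    # (voice_name, tensor_or_path) tuple; in every form the first letter of the
    # voice name is used as the pipeline language code ('a' being the American
    # English default from load_model). Illustrative call shapes, with assumed
    # names and paths that are not defined in this module:
    #
    #     backend.generate(text, "voices/af_heart.pt")
    #     backend.generate(text, ("af_heart", voice_tensor))
    #     backend.generate(text, ("af_heart", "voices/af_heart.pt"))
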
    def _check_memory(self) -> bool:
        """Check if memory usage is above threshold."""
        if self._device == "cuda":
            memory_gb = torch.cuda.memory_allocated() / 1e9
            return memory_gb > model_config.pytorch_gpu.memory_threshold
        return False

    def _clear_memory(self) -> None:
        """Clear device memory."""
        if self._device == "cuda":
            torch.cuda.empty_cache()
            gc.collect()

    def unload(self) -> None:
        """Unload model and free resources."""
        if self._model is not None:
            del self._model
            self._model = None
        if self._pipeline is not None:
            del self._pipeline
            self._pipeline = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    @property
    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        return self._model is not None and self._pipeline is not None

    @property
    def device(self) -> str:
        """Get device model is running on."""
        return self._device
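

# Usage sketch: one way a caller might drive this backend from async code. The
# model and voice paths are illustrative assumptions, not values defined in this
# module, and the snippet presumes it runs where the relative imports above resolve.
#
#     import asyncio
#     import numpy as np
#
#     async def synthesize() -> np.ndarray:
#         backend = KokoroV1()
#         await backend.load_model("v1_0/kokoro-v1_0.pth")
#         chunks = []
#         async for chunk in backend.generate("Hello there!", "voices/af_heart.pt", speed=1.0):
#             chunks.append(chunk)
#         backend.unload()
#         return np.concatenate(chunks) if chunks else np.zeros(0, dtype=np.float32)
#
#     audio = asyncio.run(synthesize())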