Kokoro-FastAPI/api/src/inference/kokoro_v1.py

187 lines
7.1 KiB
Python
Raw Normal View History

"""PyTorch inference backend with environment-based configuration."""
import gc
import os
from typing import AsyncGenerator, Optional, List, Union, Tuple
from contextlib import nullcontext
import numpy as np
import torch
from loguru import logger
from ..core import paths
from ..core.model_config import model_config
from ..core.config import settings
from .base import BaseModelBackend
from kokoro import KModel, KPipeline
class KokoroV1(BaseModelBackend):
    """Kokoro package based inference backend with environment-based configuration.

    Wraps a ``KModel`` and a ``KPipeline`` from the ``kokoro`` package. The
    target device is chosen once at construction time from ``settings.use_gpu``
    and CUDA availability.
    """

    def __init__(self) -> None:
        """Initialize backend based on environment configuration."""
        super().__init__()
        # Configure device based on settings; fall back to CPU when CUDA is
        # requested but not available.
        self._device = (
            "cuda" if settings.use_gpu and torch.cuda.is_available() else "cpu"
        )
        self._model: Optional[KModel] = None
        self._pipeline: Optional[KPipeline] = None

    async def load_model(self, path: str) -> None:
        """Load Kokoro model.

        The model's ``config.json`` is expected to sit in the same directory
        as the resolved model file.

        Args:
            path: Path to model file

        Raises:
            RuntimeError: If model loading fails (including a missing
                ``config.json``); the original error is chained as the cause.
        """
        try:
            # Get verified model path
            model_path = await paths.get_model_path(path)

            # Get config.json path from the same directory
            config_path = os.path.join(os.path.dirname(model_path), 'config.json')
            if not os.path.exists(config_path):
                raise RuntimeError(f"Config file not found: {config_path}")

            logger.info(f"Loading Kokoro model on {self._device}")
            logger.info(f"Config path: {config_path}")
            logger.info(f"Model path: {model_path}")

            # Initialize model with config and weights
            self._model = KModel(config=config_path, model=model_path).to(self._device).eval()

            # Initialize pipeline with American English by default
            self._pipeline = KPipeline(lang_code='a', model=self._model, device=self._device)
        except Exception as e:
            # Chain explicitly so the underlying traceback is preserved.
            raise RuntimeError(f"Failed to load Kokoro model: {e}") from e

    def _resolve_voice(
        self, voice: Union[str, Tuple[str, Union[torch.Tensor, str]]]
    ) -> Tuple[str, str]:
        """Turn a voice argument into ``(voice_path, lang_code)``.

        The language code is the first character of the voice name. When a
        tensor is supplied it is saved to ``<tempdir>/<voice_name>.pt`` so it
        can be handed to the pipeline as a path.
        """
        if isinstance(voice, str):
            voice_path = voice  # Voice path provided directly
            logger.debug(f"Using voice path directly: {voice_path}")
            # Get language code from first letter of voice name
            try:
                name = os.path.basename(voice_path)
                logger.debug(f"Voice basename: {name}")
                if name.endswith('.pt'):
                    name = name[:-3]
                lang_code = name[0]
                logger.debug(f"Extracted language code: {lang_code}")
            except Exception as e:
                # Default to American English if we can't get language code
                logger.warning(f"Failed to extract language code: {e}, defaulting to 'a'")
                lang_code = 'a'
            return voice_path, lang_code

        # Unpack voice name and tensor/path
        voice_name, voice_data = voice
        if isinstance(voice_data, str):
            # If voice_data is a path, use it directly
            voice_path = voice_data
            logger.debug(f"Using provided voice path: {voice_path}")
        else:
            # Save tensor to temporary file
            import tempfile

            temp_dir = tempfile.gettempdir()
            voice_path = os.path.join(temp_dir, f"{voice_name}.pt")
            logger.debug(f"Saving voice tensor to: {voice_path}")
            torch.save(voice_data, voice_path)

        # Get language code from voice name
        lang_code = voice_name[0]
        logger.debug(f"Using language code '{lang_code}' from voice name {voice_name}")
        return voice_path, lang_code

    async def generate(
        self, text: str, voice: Union[str, Tuple[str, Union[torch.Tensor, str]]], speed: float = 1.0
    ) -> AsyncGenerator[np.ndarray, None]:
        """Generate audio using model.

        Args:
            text: Input text to synthesize
            voice: Either a voice path string or a tuple of (voice_name, voice_tensor_or_path)
            speed: Speed multiplier

        Yields:
            Generated audio chunks

        Raises:
            RuntimeError: If the model is not loaded
            Exception: Any error raised during synthesis is logged and
                re-raised. On CUDA, when ``retry_on_oom`` is enabled, an
                out-of-memory error triggers exactly one retry first.
        """
        if not self.is_loaded:
            raise RuntimeError("Model not loaded")

        oom_retry_used = False
        while True:
            try:
                # Memory management for GPU
                if self._device == "cuda" and self._check_memory():
                    self._clear_memory()

                # Handle voice input
                voice_path, lang_code = self._resolve_voice(voice)

                # Update pipeline's language code if needed
                if self._pipeline.lang_code != lang_code:
                    logger.debug(f"Creating pipeline with lang_code='{lang_code}'")
                    self._pipeline = KPipeline(lang_code=lang_code, model=self._model, device=self._device)

                # Generate audio using pipeline
                logger.debug(f"Generating audio for text: '{text[:100]}...'")
                for i, result in enumerate(self._pipeline(text, voice=voice_path, speed=speed)):
                    logger.debug(f"Processing chunk {i+1}")
                    if result.audio is not None:
                        logger.debug(f"Got audio chunk {i+1} with shape: {result.audio.shape}")
                        yield result.audio.numpy()
                    else:
                        logger.warning(f"No audio in chunk {i+1}")

                # Finished cleanly: stop here so a successful (re)try is never
                # followed by re-raising an earlier error.
                return
            except Exception as e:
                logger.error(f"Generation failed: {e}")
                can_retry = (
                    not oom_retry_used
                    and self._device == "cuda"
                    and model_config.pytorch_gpu.retry_on_oom
                    and "out of memory" in str(e).lower()
                )
                if not can_retry:
                    raise
                # Retry at most once; a persistent OOM must surface to the
                # caller instead of retrying forever.
                # NOTE(review): the retry restarts synthesis from the beginning,
                # so chunks already yielded before the OOM are yielded again.
                oom_retry_used = True
                self._clear_memory()
            finally:
                if self._device == "cuda" and model_config.pytorch_gpu.sync_cuda:
                    torch.cuda.synchronize()

    def _check_memory(self) -> bool:
        """Check if memory usage is above threshold.

        Compares allocated CUDA memory (bytes / 1e9) with
        ``model_config.pytorch_gpu.memory_threshold``; always False on CPU.
        """
        if self._device == "cuda":
            memory_gb = torch.cuda.memory_allocated() / 1e9
            return memory_gb > model_config.pytorch_gpu.memory_threshold
        return False

    def _clear_memory(self) -> None:
        """Clear device memory (no-op unless running on CUDA)."""
        if self._device == "cuda":
            # Collect first so blocks owned by unreachable objects are freed
            # before the caching allocator is asked to release unused memory.
            gc.collect()
            torch.cuda.empty_cache()

    def unload(self) -> None:
        """Unload model and free resources."""
        # Dropping the references is what allows the objects to be reclaimed.
        self._model = None
        self._pipeline = None
        # Collect before emptying the CUDA cache so freshly freed blocks are
        # included in what gets released.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @property
    def is_loaded(self) -> bool:
        """Check if model is loaded (both model and pipeline are set)."""
        return self._model is not None and self._pipeline is not None

    @property
    def device(self) -> str:
        """Get device model is running on."""
        return self._device