# Kokoro-FastAPI/api/src/inference/kokoro_v1.py

"""Clean Kokoro implementation with controlled resource management."""

import os
import tempfile
from typing import AsyncGenerator, Dict, Optional, Tuple, Union

import numpy as np
import torch
from kokoro import KModel, KPipeline
from loguru import logger

from ..core import paths
from ..core.config import settings
from ..core.model_config import model_config
from .base import BaseModelBackend


class KokoroV1(BaseModelBackend):
    """Kokoro backend with controlled resource management."""

    def __init__(self):
        """Initialize backend with environment-based configuration."""
        super().__init__()
        # Strictly respect settings.use_gpu
        self._device = "cuda" if settings.use_gpu else "cpu"
        self._model: Optional[KModel] = None
        self._pipelines: Dict[str, KPipeline] = {}  # Pipelines keyed by lang_code
        self._stream: Optional[torch.cuda.Stream] = None

    def set_stream(self, stream: torch.cuda.Stream) -> None:
        """Set the CUDA stream used for this instance's generation calls."""
        self._stream = stream
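
    # Usage sketch (hypothetical caller): giving each backend instance its own
    # CUDA stream lets concurrent requests avoid serializing on the default
    # stream.
    #
    #     backend = KokoroV1()
    #     if settings.use_gpu:
    #         backend.set_stream(torch.cuda.Stream())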

    async def load_model(self, path: str) -> None:
        """Load pre-baked model.

        Args:
            path: Path to model file

        Raises:
            RuntimeError: If model loading fails
        """
        try:
            # Get verified model path
            model_path = await paths.get_model_path(path)
            config_path = os.path.join(os.path.dirname(model_path), "config.json")
            if not os.path.exists(config_path):
                raise RuntimeError(f"Config file not found: {config_path}")

            logger.info(f"Loading Kokoro model on {self._device}")
            logger.info(f"Config path: {config_path}")
            logger.info(f"Model path: {model_path}")

            # Load model and let KModel handle device mapping
            self._model = KModel(config=config_path, model=model_path).eval()
            # Move to CUDA if needed
            if self._device == "cuda":
                self._model = self._model.cuda()
        except FileNotFoundError:
            # Propagate missing-file errors unchanged
            raise
        except Exception as e:
            raise RuntimeError(f"Failed to load Kokoro model: {e}") from e

    def _get_pipeline(self, lang_code: str) -> KPipeline:
        """Get or create pipeline for language code.

        Args:
            lang_code: Language code to use

        Returns:
            KPipeline instance for the language
        """
        if not self._model:
            raise RuntimeError("Model not loaded")

        if lang_code not in self._pipelines:
            logger.info(f"Creating new pipeline for language code: {lang_code}")
            self._pipelines[lang_code] = KPipeline(
                lang_code=lang_code, model=self._model, device=self._device
            )
        return self._pipelines[lang_code]
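
    # Design note: pipelines are cached per lang_code so repeated requests for
    # a language reuse one KPipeline rather than constructing a new one per
    # call; every cached pipeline shares the single loaded KModel.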

    async def generate_from_tokens(
        self,
        tokens: str,
        voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
        speed: float = 1.0,
        lang_code: Optional[str] = None,
    ) -> AsyncGenerator[np.ndarray, None]:
        """Generate audio from phoneme tokens.

        Args:
            tokens: Input phoneme tokens to synthesize
            voice: Either a voice path string or a tuple of (voice_name, voice_tensor/path)
            speed: Speed multiplier
            lang_code: Optional language code override

        Yields:
            Generated audio chunks

        Raises:
            RuntimeError: If generation fails
        """
        if not self.is_loaded:
            raise RuntimeError("Model not loaded")

        try:
            # Memory management for GPU
            if self._device == "cuda" and self._check_memory():
                self._clear_memory()

            # Handle voice input
            voice_path: str
            voice_name: str
            if isinstance(voice, tuple):
                voice_name, voice_data = voice
                if isinstance(voice_data, str):
                    voice_path = voice_data
                else:
                    # Save tensor to a temporary file, with CPU mapping for
                    # portability
                    voice_path = os.path.join(tempfile.gettempdir(), f"{voice_name}.pt")
                    torch.save(voice_data.cpu(), voice_path)
            else:
                voice_path = voice
                voice_name = os.path.splitext(os.path.basename(voice_path))[0]

            # Load voice tensor with proper device mapping, then save it back
            # to a temporary file so the pipeline reads a device-correct copy
            voice_tensor = await paths.load_voice_tensor(voice_path, device=self._device)
            temp_path = os.path.join(
                tempfile.gettempdir(), f"temp_voice_{os.path.basename(voice_path)}"
            )
            await paths.save_voice_tensor(voice_tensor, temp_path)
            voice_path = temp_path

            # Use provided lang_code, otherwise infer it from the first letter
            # of the voice name (e.g. "af_heart" -> "a")
            pipeline_lang_code = lang_code if lang_code else voice_name[0].lower()
            pipeline = self._get_pipeline(pipeline_lang_code)

            logger.debug(
                f"Generating audio from tokens with lang_code '{pipeline_lang_code}': '{tokens[:100]}...'"
            )

            # Use the dedicated CUDA stream if one was set
            if self._stream and self._device == "cuda":
                with torch.cuda.stream(self._stream):
                    for result in pipeline.generate_from_tokens(
                        tokens=tokens, voice=voice_path, speed=speed, model=self._model
                    ):
                        if result.audio is not None:
                            logger.debug(f"Got audio chunk with shape: {result.audio.shape}")
                            yield result.audio.numpy()
                        else:
                            logger.warning("No audio in chunk")
            else:
                for result in pipeline.generate_from_tokens(
                    tokens=tokens, voice=voice_path, speed=speed, model=self._model
                ):
                    if result.audio is not None:
                        logger.debug(f"Got audio chunk with shape: {result.audio.shape}")
                        yield result.audio.numpy()
                    else:
                        logger.warning("No audio in chunk")
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            if (
                self._device == "cuda"
                and model_config.pytorch_gpu.retry_on_oom
                and "out of memory" in str(e).lower()
            ):
                # Clear the CUDA cache and retry once; return afterwards so a
                # successful retry does not re-raise the original OOM error
                self._clear_memory()
                async for chunk in self.generate_from_tokens(
                    tokens, voice, speed, lang_code
                ):
                    yield chunk
                return
            raise
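
    # Consumption sketch (hypothetical caller; "af_heart.pt" and lang_code "a"
    # are illustrative values, not requirements of this backend):
    #
    #     async for chunk in backend.generate_from_tokens(
    #         tokens=phonemes, voice="af_heart.pt", speed=1.0, lang_code="a"
    #     ):
    #         buffer.write(chunk.tobytes())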

    async def generate(
        self,
        text: str,
        voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
        speed: float = 1.0,
        lang_code: Optional[str] = None,
    ) -> AsyncGenerator[np.ndarray, None]:
        """Generate audio from text using the model.

        Args:
            text: Input text to synthesize
            voice: Either a voice path string or a tuple of (voice_name, voice_tensor/path)
            speed: Speed multiplier
            lang_code: Optional language code override

        Yields:
            Generated audio chunks

        Raises:
            RuntimeError: If generation fails
        """
        if not self.is_loaded:
            raise RuntimeError("Model not loaded")

        try:
            # Memory management for GPU
            if self._device == "cuda" and self._check_memory():
                self._clear_memory()

            # Handle voice input (same scheme as generate_from_tokens)
            voice_path: str
            voice_name: str
            if isinstance(voice, tuple):
                voice_name, voice_data = voice
                if isinstance(voice_data, str):
                    voice_path = voice_data
                else:
                    # Save tensor to a temporary file, with CPU mapping for
                    # portability
                    voice_path = os.path.join(tempfile.gettempdir(), f"{voice_name}.pt")
                    torch.save(voice_data.cpu(), voice_path)
            else:
                voice_path = voice
                voice_name = os.path.splitext(os.path.basename(voice_path))[0]

            # Load voice tensor with proper device mapping, then save it back
            # to a temporary file so the pipeline reads a device-correct copy
            voice_tensor = await paths.load_voice_tensor(voice_path, device=self._device)
            temp_path = os.path.join(
                tempfile.gettempdir(), f"temp_voice_{os.path.basename(voice_path)}"
            )
            await paths.save_voice_tensor(voice_tensor, temp_path)
            voice_path = temp_path

            # Use provided lang_code, otherwise infer it from the first letter
            # of the voice name
            pipeline_lang_code = lang_code if lang_code else voice_name[0].lower()
            pipeline = self._get_pipeline(pipeline_lang_code)

            logger.debug(
                f"Generating audio for text with lang_code '{pipeline_lang_code}': '{text[:100]}...'"
            )

            # Use the dedicated CUDA stream if one was set
            if self._stream and self._device == "cuda":
                with torch.cuda.stream(self._stream):
                    for result in pipeline(
                        text, voice=voice_path, speed=speed, model=self._model
                    ):
                        if result.audio is not None:
                            logger.debug(f"Got audio chunk with shape: {result.audio.shape}")
                            yield result.audio.numpy()
                        else:
                            logger.warning("No audio in chunk")
            else:
                for result in pipeline(
                    text, voice=voice_path, speed=speed, model=self._model
                ):
                    if result.audio is not None:
                        logger.debug(f"Got audio chunk with shape: {result.audio.shape}")
                        yield result.audio.numpy()
                    else:
                        logger.warning("No audio in chunk")
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            if (
                self._device == "cuda"
                and model_config.pytorch_gpu.retry_on_oom
                and "out of memory" in str(e).lower()
            ):
                # Clear the CUDA cache and retry once; return afterwards so a
                # successful retry does not re-raise the original OOM error
                self._clear_memory()
                async for chunk in self.generate(text, voice, speed, lang_code):
                    yield chunk
                return
            raise

    def _check_memory(self) -> bool:
        """Check if GPU memory usage is above the configured threshold."""
        if self._device == "cuda":
            # memory_allocated() reports bytes; the configured threshold is
            # therefore interpreted in GB
            memory_gb = torch.cuda.memory_allocated() / 1e9
            return memory_gb > model_config.pytorch_gpu.memory_threshold
        return False

    def _clear_memory(self) -> None:
        """Clear device memory."""
        if self._device == "cuda":
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    def unload(self) -> None:
        """Unload model and free resources."""
        if self._model is not None:
            del self._model
            self._model = None
        # Dropping the dict entries releases the cached pipelines
        self._pipelines.clear()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    @property
    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        return self._model is not None

    @property
    def device(self) -> str:
        """Get device model is running on."""
        return self._device
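

# Minimal end-to-end sketch (assumptions: an async caller, a weights file that
# paths.get_model_path() can resolve, and a voice .pt on disk; the literal
# paths below are placeholders):
#
#     import asyncio
#
#     async def main():
#         backend = KokoroV1()
#         await backend.load_model("v1_0/kokoro-v1_0.pth")
#         async for chunk in backend.generate("Hello world", voice="af_heart.pt"):
#             print(chunk.shape)
#         backend.unload()
#
#     asyncio.run(main())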