Kokoro-FastAPI/api/src/inference/kokoro_v1.py

"""Clean Kokoro implementation with controlled resource management."""

import os
from typing import AsyncGenerator, Dict, Optional, Tuple, Union

import numpy as np
import torch
from kokoro import KModel, KPipeline
from loguru import logger

from ..core import paths
from ..core.config import settings
from ..core.model_config import model_config
from ..structures.schemas import WordTimestamp
from .base import AudioChunk, BaseModelBackend


class KokoroV1(BaseModelBackend):
    """Kokoro backend with controlled resource management."""

    def __init__(self):
        """Initialize backend with environment-based configuration."""
        super().__init__()
        # Strictly respect settings.use_gpu
        self._device = settings.get_device()
        self._model: Optional[KModel] = None
        self._pipelines: Dict[str, KPipeline] = {}  # Store pipelines by lang_code

    async def load_model(self, path: str) -> None:
        """Load pre-baked model.

        Args:
            path: Path to model file

        Raises:
            RuntimeError: If model loading fails
        """
        try:
            # Get verified model path
            model_path = await paths.get_model_path(path)
            config_path = os.path.join(os.path.dirname(model_path), "config.json")

            if not os.path.exists(config_path):
                raise RuntimeError(f"Config file not found: {config_path}")

            logger.info(f"Loading Kokoro model on {self._device}")
            logger.info(f"Config path: {config_path}")
            logger.info(f"Model path: {model_path}")

            # Load model and let KModel handle device mapping
            self._model = KModel(config=config_path, model=model_path).eval()
            # For MPS, move the whole model to the MPS device and rely on CPU
            # fallback for operations MPS does not support
            if self._device == "mps":
                logger.info(
                    "Moving model to MPS device with CPU fallback for unsupported operations"
                )
                self._model = self._model.to(torch.device("mps"))
            elif self._device == "cuda":
                self._model = self._model.cuda()
            else:
                self._model = self._model.cpu()
        except FileNotFoundError:
            raise
        except Exception as e:
            raise RuntimeError(f"Failed to load Kokoro model: {e}") from e

    def _get_pipeline(self, lang_code: str) -> KPipeline:
        """Get or create pipeline for language code.

        Args:
            lang_code: Language code to use

        Returns:
            KPipeline instance for the language
        """
        if not self._model:
            raise RuntimeError("Model not loaded")

        if lang_code not in self._pipelines:
            logger.info(f"Creating new pipeline for language code: {lang_code}")
            self._pipelines[lang_code] = KPipeline(
                lang_code=lang_code, model=self._model, device=self._device
            )
        return self._pipelines[lang_code]

    async def generate_from_tokens(
        self,
        tokens: str,
        voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
        speed: float = 1.0,
        lang_code: Optional[str] = None,
    ) -> AsyncGenerator[np.ndarray, None]:
        """Generate audio from phoneme tokens.

        Args:
            tokens: Input phoneme tokens to synthesize
            voice: Either a voice path string or a tuple of (voice_name, voice_tensor/path)
            speed: Speed multiplier
            lang_code: Optional language code override

        Yields:
            Generated audio chunks

        Raises:
            RuntimeError: If generation fails
        """
        if not self.is_loaded:
            raise RuntimeError("Model not loaded")

        try:
            # Memory management for GPU
            if self._device == "cuda":
                if self._check_memory():
                    self._clear_memory()

            # Handle voice input
            voice_path: str
            voice_name: str
            if isinstance(voice, tuple):
                voice_name, voice_data = voice
                if isinstance(voice_data, str):
                    voice_path = voice_data
                else:
                    # Save tensor to temporary file
                    import tempfile

                    temp_dir = tempfile.gettempdir()
                    voice_path = os.path.join(temp_dir, f"{voice_name}.pt")
                    # Save tensor with CPU mapping for portability
                    torch.save(voice_data.cpu(), voice_path)
            else:
                voice_path = voice
                voice_name = os.path.splitext(os.path.basename(voice_path))[0]

            # Load voice tensor with proper device mapping
            voice_tensor = await paths.load_voice_tensor(
                voice_path, device=self._device
            )
            # Save back to a temporary file with proper device mapping
            import tempfile

            temp_dir = tempfile.gettempdir()
            temp_path = os.path.join(
                temp_dir, f"temp_voice_{os.path.basename(voice_path)}"
            )
            await paths.save_voice_tensor(voice_tensor, temp_path)
            voice_path = temp_path

            # Use provided lang_code, settings voice code override, or first letter of voice name
            if lang_code:  # API argument takes priority
                pipeline_lang_code = lang_code
            elif settings.default_voice_code:  # settings override is next
                pipeline_lang_code = settings.default_voice_code
            else:  # fall back to the first letter of the voice name
                pipeline_lang_code = voice_name[0].lower()

            pipeline = self._get_pipeline(pipeline_lang_code)

            logger.debug(
                f"Generating audio from tokens with lang_code '{pipeline_lang_code}': '{tokens[:100]}{'...' if len(tokens) > 100 else ''}'"
            )
            for result in pipeline.generate_from_tokens(
                tokens=tokens, voice=voice_path, speed=speed, model=self._model
            ):
                if result.audio is not None:
                    logger.debug(f"Got audio chunk with shape: {result.audio.shape}")
                    yield result.audio.numpy()
                else:
                    logger.warning("No audio in chunk")
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            if (
                self._device == "cuda"
                and model_config.pytorch_gpu.retry_on_oom
                and "out of memory" in str(e).lower()
            ):
                self._clear_memory()
                async for chunk in self.generate_from_tokens(
                    tokens, voice, speed, lang_code
                ):
                    yield chunk
                # Retry succeeded; don't re-raise the original OOM error
                return
            raise

    async def generate(
        self,
        text: str,
        voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
        speed: float = 1.0,
        lang_code: Optional[str] = None,
        return_timestamps: Optional[bool] = False,
    ) -> AsyncGenerator[AudioChunk, None]:
        """Generate audio using model.

        Args:
            text: Input text to synthesize
            voice: Either a voice path string or a tuple of (voice_name, voice_tensor/path)
            speed: Speed multiplier
            lang_code: Optional language code override
            return_timestamps: Whether to attach word-level timestamps to each chunk

        Yields:
            Generated audio chunks

        Raises:
            RuntimeError: If generation fails
        """
        if not self.is_loaded:
            raise RuntimeError("Model not loaded")

        try:
            # Memory management for GPU
            if self._device == "cuda":
                if self._check_memory():
                    self._clear_memory()

            # Handle voice input
            voice_path: str
            voice_name: str
            if isinstance(voice, tuple):
                voice_name, voice_data = voice
                if isinstance(voice_data, str):
                    voice_path = voice_data
                else:
                    # Save tensor to temporary file
                    import tempfile

                    temp_dir = tempfile.gettempdir()
                    voice_path = os.path.join(temp_dir, f"{voice_name}.pt")
                    # Save tensor with CPU mapping for portability
                    torch.save(voice_data.cpu(), voice_path)
            else:
                voice_path = voice
                voice_name = os.path.splitext(os.path.basename(voice_path))[0]

            # Load voice tensor with proper device mapping
            voice_tensor = await paths.load_voice_tensor(
                voice_path, device=self._device
            )
            # Save back to a temporary file with proper device mapping
            import tempfile

            temp_dir = tempfile.gettempdir()
            temp_path = os.path.join(
                temp_dir, f"temp_voice_{os.path.basename(voice_path)}"
            )
            await paths.save_voice_tensor(voice_tensor, temp_path)
            voice_path = temp_path

            # Use provided lang_code, settings voice code override, or first letter of voice name
            pipeline_lang_code = (
                lang_code or settings.default_voice_code or voice_name[0].lower()
            )
            pipeline = self._get_pipeline(pipeline_lang_code)

            logger.debug(
                f"Generating audio for text with lang_code '{pipeline_lang_code}': '{text[:100]}{'...' if len(text) > 100 else ''}'"
            )
            for result in pipeline(
                text, voice=voice_path, speed=speed, model=self._model
            ):
                if result.audio is not None:
                    logger.debug(f"Got audio chunk with shape: {result.audio.shape}")
                    word_timestamps = None
                    if (
                        return_timestamps
                        and hasattr(result, "tokens")
                        and result.tokens
                    ):
                        word_timestamps = []
                        current_offset = 0.0
                        logger.debug(
                            f"Processing chunk timestamps with {len(result.tokens)} tokens"
                        )
                        if result.pred_dur is not None:
                            try:
                                # Add timestamps with offset
                                for token in result.tokens:
                                    if not all(
                                        hasattr(token, attr)
                                        for attr in ["text", "start_ts", "end_ts"]
                                    ):
                                        continue
                                    if not token.text or not token.text.strip():
                                        continue

                                    start_time = float(token.start_ts) + current_offset
                                    end_time = float(token.end_ts) + current_offset
                                    word_timestamps.append(
                                        WordTimestamp(
                                            word=str(token.text).strip(),
                                            start_time=start_time,
                                            end_time=end_time,
                                        )
                                    )
                                    logger.debug(
                                        f"Added timestamp for word '{token.text}': {start_time:.3f}s - {end_time:.3f}s"
                                    )
                            except Exception as e:
                                logger.error(
                                    f"Failed to process timestamps for chunk: {e}"
                                )

                    yield AudioChunk(
                        result.audio.numpy(), word_timestamps=word_timestamps
                    )
                else:
                    logger.warning("No audio in chunk")
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            if (
                self._device == "cuda"
                and model_config.pytorch_gpu.retry_on_oom
                and "out of memory" in str(e).lower()
            ):
                self._clear_memory()
                # Pass return_timestamps through so a retried stream keeps its timestamps
                async for chunk in self.generate(
                    text, voice, speed, lang_code, return_timestamps
                ):
                    yield chunk
                # Retry succeeded; don't re-raise the original OOM error
                return
            raise

    def _check_memory(self) -> bool:
        """Check if memory usage is above threshold."""
        if self._device == "cuda":
            memory_gb = torch.cuda.memory_allocated() / 1e9
            return memory_gb > model_config.pytorch_gpu.memory_threshold
        # MPS doesn't provide memory management APIs
        return False

    def _clear_memory(self) -> None:
        """Clear device memory."""
        if self._device == "cuda":
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        elif self._device == "mps":
            # Empty cache if available (future-proofing)
            if hasattr(torch.mps, "empty_cache"):
                torch.mps.empty_cache()
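
    # A sketch (assumption, not the actual ..core.model_config module) of the
    # pytorch_gpu settings the two helpers above read; only the field names used
    # in this file are shown, and the defaults are illustrative:
    #
    #     class PyTorchGPUConfig(BaseModel):
    #         memory_threshold: float = 2.0  # GB allocated before _clear_memory() runs
    #         retry_on_oom: bool = True      # retry generation once after a CUDA OOM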

    def unload(self) -> None:
        """Unload model and free resources."""
        if self._model is not None:
            del self._model
            self._model = None
        for pipeline in self._pipelines.values():
            del pipeline
        self._pipelines.clear()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    @property
    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        return self._model is not None

    @property
    def device(self) -> str:
        """Get device model is running on."""
        return self._device
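

# Usage sketch: not part of the original module. The model and voice paths below
# are placeholders; actual resolution depends on how `settings` and `paths` are
# configured in a deployment. It only illustrates the intended call pattern:
# load once, stream chunks, then unload.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        backend = KokoroV1()
        # Hypothetical model filename; paths.get_model_path() resolves the real location
        await backend.load_model("kokoro-v1_0.pth")

        pieces = []
        async for chunk in backend.generate(
            "Hello from Kokoro.",
            voice="/path/to/voices/af_heart.pt",  # hypothetical voice tensor path
            speed=1.0,
            return_timestamps=True,
        ):
            # AudioChunk fields assumed from their use in generate() above
            pieces.append(chunk.audio)
            for ts in chunk.word_timestamps or []:
                logger.info(f"{ts.word}: {ts.start_time:.2f}s - {ts.end_time:.2f}s")

        backend.unload()
        if pieces:
            logger.info(f"Generated {np.concatenate(pieces).shape[0]} samples")

    asyncio.run(_demo())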