"""Clean Kokoro implementation with controlled resource management."""
|
|
|
|
import os
|
|
from typing import AsyncGenerator, Dict, Optional, Tuple, Union
|
|
|
|
import numpy as np
|
|
import torch
|
|
from kokoro import KModel, KPipeline
|
|
from loguru import logger
|
|
|
|
from ..core import paths
|
|
from ..core.config import settings
|
|
from ..core.model_config import model_config
|
|
from .base import BaseModelBackend
|
|
|
|
|
|
class KokoroV1(BaseModelBackend):
    """Kokoro backend with controlled resource management."""

    def __init__(self):
        """Initialize backend with environment-based configuration."""
        super().__init__()
        # Strictly respect settings.use_gpu
        self._device = "cuda" if settings.use_gpu else "cpu"
        self._model: Optional[KModel] = None
        self._pipelines: Dict[str, KPipeline] = {}  # Store pipelines by lang_code

    async def load_model(self, path: str) -> None:
        """Load pre-baked model.

        Args:
            path: Path to model file

        Raises:
            RuntimeError: If model loading fails
        """
        try:
            # Get verified model path
            model_path = await paths.get_model_path(path)
            config_path = os.path.join(os.path.dirname(model_path), "config.json")

            if not os.path.exists(config_path):
                raise RuntimeError(f"Config file not found: {config_path}")

            logger.info(f"Loading Kokoro model on {self._device}")
            logger.info(f"Config path: {config_path}")
            logger.info(f"Model path: {model_path}")

            # Load model and let KModel handle device mapping
            self._model = KModel(config=config_path, model=model_path).eval()
            # Move to CUDA if needed
            if self._device == "cuda":
                self._model = self._model.cuda()

        except FileNotFoundError as e:
            raise e
        except Exception as e:
            raise RuntimeError(f"Failed to load Kokoro model: {e}")

    def _get_pipeline(self, lang_code: str) -> KPipeline:
        """Get or create pipeline for language code.

        Args:
            lang_code: Language code to use

        Returns:
            KPipeline instance for the language
        """
        if not self._model:
            raise RuntimeError("Model not loaded")

        if lang_code not in self._pipelines:
            logger.info(f"Creating new pipeline for language code: {lang_code}")
            self._pipelines[lang_code] = KPipeline(
                lang_code=lang_code, model=self._model, device=self._device
            )
        return self._pipelines[lang_code]

    async def generate_from_tokens(
        self,
        tokens: str,
        voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
        speed: float = 1.0,
        lang_code: Optional[str] = None,
    ) -> AsyncGenerator[np.ndarray, None]:
        """Generate audio from phoneme tokens.

        Args:
            tokens: Input phoneme tokens to synthesize
            voice: Either a voice path string or a tuple of (voice_name, voice_tensor/path)
            speed: Speed multiplier
            lang_code: Optional language code override

        Yields:
            Generated audio chunks

        Raises:
            RuntimeError: If generation fails
        """
        if not self.is_loaded:
            raise RuntimeError("Model not loaded")

        try:
            # Memory management for GPU
            if self._device == "cuda":
                if self._check_memory():
                    self._clear_memory()

            # Handle voice input
            voice_path: str
            voice_name: str
            if isinstance(voice, tuple):
                voice_name, voice_data = voice
                if isinstance(voice_data, str):
                    voice_path = voice_data
                else:
                    # Save tensor to temporary file
                    import tempfile

                    temp_dir = tempfile.gettempdir()
                    voice_path = os.path.join(temp_dir, f"{voice_name}.pt")
                    # Save tensor with CPU mapping for portability
                    torch.save(voice_data.cpu(), voice_path)
            else:
                voice_path = voice
                voice_name = os.path.splitext(os.path.basename(voice_path))[0]

            # Load voice tensor with proper device mapping
            voice_tensor = await paths.load_voice_tensor(
                voice_path, device=self._device
            )
            # Save back to a temporary file with proper device mapping
            import tempfile

            temp_dir = tempfile.gettempdir()
            temp_path = os.path.join(
                temp_dir, f"temp_voice_{os.path.basename(voice_path)}"
            )
            await paths.save_voice_tensor(voice_tensor, temp_path)
            voice_path = temp_path

            # Use provided lang_code, settings voice code override, or first letter
            # of the voice name, in that priority order
            if lang_code:  # API parameter takes priority
                pipeline_lang_code = lang_code
            elif settings.default_voice_code:  # settings override is next
                pipeline_lang_code = settings.default_voice_code
            else:  # fall back to the first letter of the voice name
                pipeline_lang_code = voice_name[0].lower()

            pipeline = self._get_pipeline(pipeline_lang_code)

            logger.debug(
                f"Generating audio from tokens with lang_code '{pipeline_lang_code}': '{tokens[:100]}...'"
            )
            for result in pipeline.generate_from_tokens(
                tokens=tokens, voice=voice_path, speed=speed, model=self._model
            ):
                if result.audio is not None:
                    logger.debug(f"Got audio chunk with shape: {result.audio.shape}")
                    yield result.audio.numpy()
                else:
                    logger.warning("No audio in chunk")

        except Exception as e:
            logger.error(f"Generation failed: {e}")
            if (
                self._device == "cuda"
                and model_config.pytorch_gpu.retry_on_oom
                and "out of memory" in str(e).lower()
            ):
                self._clear_memory()
                async for chunk in self.generate_from_tokens(
                    tokens, voice, speed, lang_code
                ):
                    yield chunk
                # Retry succeeded; don't re-raise the original OOM error
                return
            raise

    async def generate(
        self,
        text: str,
        voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
        speed: float = 1.0,
        lang_code: Optional[str] = None,
    ) -> AsyncGenerator[np.ndarray, None]:
        """Generate audio using model.

        Args:
            text: Input text to synthesize
            voice: Either a voice path string or a tuple of (voice_name, voice_tensor/path)
            speed: Speed multiplier
            lang_code: Optional language code override

        Yields:
            Generated audio chunks

        Raises:
            RuntimeError: If generation fails
        """
        if not self.is_loaded:
            raise RuntimeError("Model not loaded")

        try:
            # Memory management for GPU
            if self._device == "cuda":
                if self._check_memory():
                    self._clear_memory()

            # Handle voice input
            voice_path: str
            voice_name: str
            if isinstance(voice, tuple):
                voice_name, voice_data = voice
                if isinstance(voice_data, str):
                    voice_path = voice_data
                else:
                    # Save tensor to temporary file
                    import tempfile

                    temp_dir = tempfile.gettempdir()
                    voice_path = os.path.join(temp_dir, f"{voice_name}.pt")
                    # Save tensor with CPU mapping for portability
                    torch.save(voice_data.cpu(), voice_path)
            else:
                voice_path = voice
                voice_name = os.path.splitext(os.path.basename(voice_path))[0]

            # Load voice tensor with proper device mapping
            voice_tensor = await paths.load_voice_tensor(
                voice_path, device=self._device
            )
            # Save back to a temporary file with proper device mapping
            import tempfile

            temp_dir = tempfile.gettempdir()
            temp_path = os.path.join(
                temp_dir, f"temp_voice_{os.path.basename(voice_path)}"
            )
            await paths.save_voice_tensor(voice_tensor, temp_path)
            voice_path = temp_path

            # Use provided lang_code, settings voice code override, or first letter
            # of the voice name (same priority order as generate_from_tokens)
            if lang_code:  # API parameter takes priority
                pipeline_lang_code = lang_code
            elif settings.default_voice_code:  # settings override is next
                pipeline_lang_code = settings.default_voice_code
            else:  # fall back to the first letter of the voice name
                pipeline_lang_code = voice_name[0].lower()

            pipeline = self._get_pipeline(pipeline_lang_code)

            logger.debug(
                f"Generating audio for text with lang_code '{pipeline_lang_code}': '{text[:100]}...'"
            )
            for result in pipeline(
                text, voice=voice_path, speed=speed, model=self._model
            ):
                if result.audio is not None:
                    logger.debug(f"Got audio chunk with shape: {result.audio.shape}")
                    yield result.audio.numpy()
                else:
                    logger.warning("No audio in chunk")

        except Exception as e:
            logger.error(f"Generation failed: {e}")
            if (
                self._device == "cuda"
                and model_config.pytorch_gpu.retry_on_oom
                and "out of memory" in str(e).lower()
            ):
                self._clear_memory()
                async for chunk in self.generate(text, voice, speed, lang_code):
                    yield chunk
                # Retry succeeded; don't re-raise the original OOM error
                return
            raise

    def _check_memory(self) -> bool:
        """Check if memory usage is above threshold."""
        if self._device == "cuda":
            memory_gb = torch.cuda.memory_allocated() / 1e9
            return memory_gb > model_config.pytorch_gpu.memory_threshold
        return False

    def _clear_memory(self) -> None:
        """Clear device memory."""
        if self._device == "cuda":
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    def unload(self) -> None:
        """Unload model and free resources."""
        if self._model is not None:
            del self._model
            self._model = None
        for pipeline in self._pipelines.values():
            del pipeline
        self._pipelines.clear()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    @property
    def is_loaded(self) -> bool:
        """Check if model is loaded."""
        return self._model is not None

    @property
    def device(self) -> str:
        """Get device model is running on."""
        return self._device
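

# Illustrative usage sketch, not part of the original module: a minimal async
# driver showing how this backend's API fits together. The model path and voice
# file below are hypothetical placeholders; lang_code "a" matches the fallback
# rule above (first letter of an "af_*" voice name). Kept as comments because
# this module uses relative imports and is not meant to run as a script.
#
#   import asyncio
#
#   async def _demo() -> None:
#       backend = KokoroV1()
#       await backend.load_model("path/to/kokoro_model.pth")  # hypothetical path
#       async for chunk in backend.generate(
#           "Hello there!", voice="voices/af_heart.pt", speed=1.0, lang_code="a"
#       ):
#           print(f"got audio chunk with shape {chunk.shape}")
#       backend.unload()
#
#   asyncio.run(_demo())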