"""Clean Kokoro implementation with controlled resource management.""" import os from typing import AsyncGenerator, Dict, Optional, Tuple, Union import numpy as np import torch from kokoro import KModel, KPipeline from loguru import logger from ..core import paths from ..core.config import settings from ..core.model_config import model_config from .base import BaseModelBackend class KokoroV1(BaseModelBackend): """Kokoro backend with controlled resource management.""" def __init__(self): """Initialize backend with environment-based configuration.""" super().__init__() # Strictly respect settings.use_gpu self._device = "cuda" if settings.use_gpu else "cpu" self._model: Optional[KModel] = None self._pipelines: Dict[str, KPipeline] = {} # Store pipelines by lang_code async def load_model(self, path: str) -> None: """Load pre-baked model. Args: path: Path to model file Raises: RuntimeError: If model loading fails """ try: # Get verified model path model_path = await paths.get_model_path(path) config_path = os.path.join(os.path.dirname(model_path), "config.json") if not os.path.exists(config_path): raise RuntimeError(f"Config file not found: {config_path}") logger.info(f"Loading Kokoro model on {self._device}") logger.info(f"Config path: {config_path}") logger.info(f"Model path: {model_path}") # Load model and let KModel handle device mapping self._model = KModel(config=config_path, model=model_path).eval() # Move to CUDA if needed if self._device == "cuda": self._model = self._model.cuda() except FileNotFoundError as e: raise e except Exception as e: raise RuntimeError(f"Failed to load Kokoro model: {e}") def _get_pipeline(self, lang_code: str) -> KPipeline: """Get or create pipeline for language code. Args: lang_code: Language code to use Returns: KPipeline instance for the language """ if not self._model: raise RuntimeError("Model not loaded") if lang_code not in self._pipelines: logger.info(f"Creating new pipeline for language code: {lang_code}") self._pipelines[lang_code] = KPipeline( lang_code=lang_code, model=self._model, device=self._device ) return self._pipelines[lang_code] async def generate_from_tokens( self, tokens: str, voice: Union[str, Tuple[str, Union[torch.Tensor, str]]], speed: float = 1.0, lang_code: Optional[str] = None, ) -> AsyncGenerator[np.ndarray, None]: """Generate audio from phoneme tokens. Args: tokens: Input phoneme tokens to synthesize voice: Either a voice path string or a tuple of (voice_name, voice_tensor/path) speed: Speed multiplier lang_code: Optional language code override Yields: Generated audio chunks Raises: RuntimeError: If generation fails """ if not self.is_loaded: raise RuntimeError("Model not loaded") try: # Memory management for GPU if self._device == "cuda": if self._check_memory(): self._clear_memory() # Handle voice input voice_path: str voice_name: str if isinstance(voice, tuple): voice_name, voice_data = voice if isinstance(voice_data, str): voice_path = voice_data else: # Save tensor to temporary file import tempfile temp_dir = tempfile.gettempdir() voice_path = os.path.join(temp_dir, f"{voice_name}.pt") # Save tensor with CPU mapping for portability torch.save(voice_data.cpu(), voice_path) else: voice_path = voice voice_name = os.path.splitext(os.path.basename(voice_path))[0] # Load voice tensor with proper device mapping voice_tensor = await paths.load_voice_tensor( voice_path, device=self._device ) # Save back to a temporary file with proper device mapping import tempfile temp_dir = tempfile.gettempdir() temp_path = os.path.join( temp_dir, f"temp_voice_{os.path.basename(voice_path)}" ) await paths.save_voice_tensor(voice_tensor, temp_path) voice_path = temp_path # Use provided lang_code or get from voice name pipeline_lang_code = lang_code if lang_code else voice_name[0].lower() pipeline = self._get_pipeline(pipeline_lang_code) logger.debug( f"Generating audio from tokens with lang_code '{pipeline_lang_code}': '{tokens[:100]}{'...' if len(tokens) > 100 else ''}'" ) for result in pipeline.generate_from_tokens( tokens=tokens, voice=voice_path, speed=speed, model=self._model ): if result.audio is not None: logger.debug(f"Got audio chunk with shape: {result.audio.shape}") yield result.audio.numpy() else: logger.warning("No audio in chunk") except Exception as e: logger.error(f"Generation failed: {e}") if ( self._device == "cuda" and model_config.pytorch_gpu.retry_on_oom and "out of memory" in str(e).lower() ): self._clear_memory() async for chunk in self.generate_from_tokens( tokens, voice, speed, lang_code ): yield chunk raise async def generate( self, text: str, voice: Union[str, Tuple[str, Union[torch.Tensor, str]]], speed: float = 1.0, lang_code: Optional[str] = None, ) -> AsyncGenerator[np.ndarray, None]: """Generate audio using model. Args: text: Input text to synthesize voice: Either a voice path string or a tuple of (voice_name, voice_tensor/path) speed: Speed multiplier lang_code: Optional language code override Yields: Generated audio chunks Raises: RuntimeError: If generation fails """ if not self.is_loaded: raise RuntimeError("Model not loaded") try: # Memory management for GPU if self._device == "cuda": if self._check_memory(): self._clear_memory() # Handle voice input voice_path: str voice_name: str if isinstance(voice, tuple): voice_name, voice_data = voice if isinstance(voice_data, str): voice_path = voice_data else: # Save tensor to temporary file import tempfile temp_dir = tempfile.gettempdir() voice_path = os.path.join(temp_dir, f"{voice_name}.pt") # Save tensor with CPU mapping for portability torch.save(voice_data.cpu(), voice_path) else: voice_path = voice voice_name = os.path.splitext(os.path.basename(voice_path))[0] # Load voice tensor with proper device mapping voice_tensor = await paths.load_voice_tensor( voice_path, device=self._device ) # Save back to a temporary file with proper device mapping import tempfile temp_dir = tempfile.gettempdir() temp_path = os.path.join( temp_dir, f"temp_voice_{os.path.basename(voice_path)}" ) await paths.save_voice_tensor(voice_tensor, temp_path) voice_path = temp_path # Use provided lang_code or get from voice name pipeline_lang_code = lang_code if lang_code else voice_name[0].lower() pipeline = self._get_pipeline(pipeline_lang_code) logger.debug( f"Generating audio for text with lang_code '{pipeline_lang_code}': '{text[:100]}{'...' if len(text) > 100 else ''}'" ) for result in pipeline( text, voice=voice_path, speed=speed, model=self._model ): if result.audio is not None: logger.debug(f"Got audio chunk with shape: {result.audio.shape}") yield result.audio.numpy() else: logger.warning("No audio in chunk") except Exception as e: logger.error(f"Generation failed: {e}") if ( self._device == "cuda" and model_config.pytorch_gpu.retry_on_oom and "out of memory" in str(e).lower() ): self._clear_memory() async for chunk in self.generate(text, voice, speed, lang_code): yield chunk raise def _check_memory(self) -> bool: """Check if memory usage is above threshold.""" if self._device == "cuda": memory_gb = torch.cuda.memory_allocated() / 1e9 return memory_gb > model_config.pytorch_gpu.memory_threshold return False def _clear_memory(self) -> None: """Clear device memory.""" if self._device == "cuda": torch.cuda.empty_cache() torch.cuda.synchronize() def unload(self) -> None: """Unload model and free resources.""" if self._model is not None: del self._model self._model = None for pipeline in self._pipelines.values(): del pipeline self._pipelines.clear() if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() @property def is_loaded(self) -> bool: """Check if model is loaded.""" return self._model is not None @property def device(self) -> str: """Get device model is running on.""" return self._device