diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..cc9c4d9 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,5 @@ +* text=auto + +*.py text eol=lf +*.sh text eol=lf +*.yml text eol=lf \ No newline at end of file diff --git a/api/src/core/paths.py b/api/src/core/paths.py index 6a024f3..0e60528 100644 --- a/api/src/core/paths.py +++ b/api/src/core/paths.py @@ -1,413 +1,413 @@ -"""Async file and path operations.""" - -import io -import json -import os -from pathlib import Path -from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Set - -import aiofiles -import aiofiles.os -import torch -from loguru import logger - -from .config import settings - - -async def _find_file( - filename: str, - search_paths: List[str], - filter_fn: Optional[Callable[[str], bool]] = None, -) -> str: - """Find file in search paths. - - Args: - filename: Name of file to find - search_paths: List of paths to search in - filter_fn: Optional function to filter files - - Returns: - Absolute path to file - - Raises: - RuntimeError: If file not found - """ - if os.path.isabs(filename) and await aiofiles.os.path.exists(filename): - return filename - - for path in search_paths: - full_path = os.path.join(path, filename) - if await aiofiles.os.path.exists(full_path): - if filter_fn is None or filter_fn(full_path): - return full_path - - raise FileNotFoundError(f"File not found: {filename} in paths: {search_paths}") - - -async def _scan_directories( - search_paths: List[str], filter_fn: Optional[Callable[[str], bool]] = None -) -> Set[str]: - """Scan directories for files. - - Args: - search_paths: List of paths to scan - filter_fn: Optional function to filter files - - Returns: - Set of matching filenames - """ - results = set() - - for path in search_paths: - if not await aiofiles.os.path.exists(path): - continue - - try: - # Get directory entries first - entries = await aiofiles.os.scandir(path) - # Then process entries after await completes - for entry in entries: - if filter_fn is None or filter_fn(entry.name): - results.add(entry.name) - except Exception as e: - logger.warning(f"Error scanning {path}: {e}") - - return results - - -async def get_model_path(model_name: str) -> str: - """Get path to model file. - - Args: - model_name: Name of model file - - Returns: - Absolute path to model file - - Raises: - RuntimeError: If model not found - """ - # Get api directory path (two levels up from core) - api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) - - # Construct model directory path relative to api directory - model_dir = os.path.join(api_dir, settings.model_dir) - - # Ensure model directory exists - os.makedirs(model_dir, exist_ok=True) - - # Search in model directory - search_paths = [model_dir] - logger.debug(f"Searching for model in path: {model_dir}") - - return await _find_file(model_name, search_paths) - - -async def get_voice_path(voice_name: str) -> str: - """Get path to voice file. - - Args: - voice_name: Name of voice file (without .pt extension) - - Returns: - Absolute path to voice file - - Raises: - RuntimeError: If voice not found - """ - # Get api directory path - api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) - - # Construct voice directory path relative to api directory - voice_dir = os.path.join(api_dir, settings.voices_dir) - - # Ensure voice directory exists - os.makedirs(voice_dir, exist_ok=True) - - voice_file = f"{voice_name}.pt" - - # Search in voice directory/o - search_paths = [voice_dir] - logger.debug(f"Searching for voice in path: {voice_dir}") - - return await _find_file(voice_file, search_paths) - - -async def list_voices() -> List[str]: - """List available voice files. - - Returns: - List of voice names (without .pt extension) - """ - # Get api directory path - api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) - - # Construct voice directory path relative to api directory - voice_dir = os.path.join(api_dir, settings.voices_dir) - - # Ensure voice directory exists - os.makedirs(voice_dir, exist_ok=True) - - # Search in voice directory - search_paths = [voice_dir] - logger.debug(f"Scanning for voices in path: {voice_dir}") - - def filter_voice_files(name: str) -> bool: - return name.endswith(".pt") - - voices = await _scan_directories(search_paths, filter_voice_files) - return sorted([name[:-3] for name in voices]) # Remove .pt extension - - -async def load_voice_tensor( - voice_path: str, device: str = "cpu", weights_only=False -) -> torch.Tensor: - """Load voice tensor from file. - - Args: - voice_path: Path to voice file - device: Device to load tensor to - - Returns: - Voice tensor - - Raises: - RuntimeError: If file cannot be read - """ - try: - async with aiofiles.open(voice_path, "rb") as f: - data = await f.read() - return torch.load( - io.BytesIO(data), map_location=device, weights_only=weights_only - ) - except Exception as e: - raise RuntimeError(f"Failed to load voice tensor from {voice_path}: {e}") - - -async def save_voice_tensor(tensor: torch.Tensor, voice_path: str) -> None: - """Save voice tensor to file. - - Args: - tensor: Voice tensor to save - voice_path: Path to save voice file - - Raises: - RuntimeError: If file cannot be written - """ - try: - buffer = io.BytesIO() - torch.save(tensor, buffer) - async with aiofiles.open(voice_path, "wb") as f: - await f.write(buffer.getvalue()) - except Exception as e: - raise RuntimeError(f"Failed to save voice tensor to {voice_path}: {e}") - - -async def load_json(path: str) -> dict: - """Load JSON file asynchronously. - - Args: - path: Path to JSON file - - Returns: - Parsed JSON data - - Raises: - RuntimeError: If file cannot be read or parsed - """ - try: - async with aiofiles.open(path, "r", encoding="utf-8") as f: - content = await f.read() - return json.loads(content) - except Exception as e: - raise RuntimeError(f"Failed to load JSON file {path}: {e}") - - -async def load_model_weights(path: str, device: str = "cpu") -> dict: - """Load model weights asynchronously. - - Args: - path: Path to model file (.pth or .onnx) - device: Device to load model to - - Returns: - Model weights - - Raises: - RuntimeError: If file cannot be read - """ - try: - async with aiofiles.open(path, "rb") as f: - data = await f.read() - return torch.load(io.BytesIO(data), map_location=device, weights_only=True) - except Exception as e: - raise RuntimeError(f"Failed to load model weights from {path}: {e}") - - -async def read_file(path: str) -> str: - """Read text file asynchronously. - - Args: - path: Path to file - - Returns: - File contents as string - - Raises: - RuntimeError: If file cannot be read - """ - try: - async with aiofiles.open(path, "r", encoding="utf-8") as f: - return await f.read() - except Exception as e: - raise RuntimeError(f"Failed to read file {path}: {e}") - - -async def read_bytes(path: str) -> bytes: - """Read file as bytes asynchronously. - - Args: - path: Path to file - - Returns: - File contents as bytes - - Raises: - RuntimeError: If file cannot be read - """ - try: - async with aiofiles.open(path, "rb") as f: - return await f.read() - except Exception as e: - raise RuntimeError(f"Failed to read file {path}: {e}") - - -async def get_web_file_path(filename: str) -> str: - """Get path to web static file. - - Args: - filename: Name of file in web directory - - Returns: - Absolute path to file - - Raises: - RuntimeError: If file not found - """ - # Get project root directory (four levels up from core to get to project root) - root_dir = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.dirname(__file__))) - ) - - # Construct web directory path relative to project root - web_dir = os.path.join("/app", settings.web_player_path) - - # Search in web directory - search_paths = [web_dir] - logger.debug(f"Searching for web file in path: {web_dir}") - - return await _find_file(filename, search_paths) - - -async def get_content_type(path: str) -> str: - """Get content type for file. - - Args: - path: Path to file - - Returns: - Content type string - """ - ext = os.path.splitext(path)[1].lower() - return { - ".html": "text/html", - ".js": "application/javascript", - ".css": "text/css", - ".png": "image/png", - ".jpg": "image/jpeg", - ".jpeg": "image/jpeg", - ".gif": "image/gif", - ".svg": "image/svg+xml", - ".ico": "image/x-icon", - }.get(ext, "application/octet-stream") - - -async def verify_model_path(model_path: str) -> bool: - """Verify model file exists at path.""" - return await aiofiles.os.path.exists(model_path) - - -async def cleanup_temp_files() -> None: - """Clean up old temp files on startup""" - try: - if not await aiofiles.os.path.exists(settings.temp_file_dir): - await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True) - return - - entries = await aiofiles.os.scandir(settings.temp_file_dir) - for entry in entries: - if entry.is_file(): - stat = await aiofiles.os.stat(entry.path) - max_age = stat.st_mtime + (settings.max_temp_dir_age_hours * 3600) - if max_age < stat.st_mtime: - try: - await aiofiles.os.remove(entry.path) - logger.info(f"Cleaned up old temp file: {entry.name}") - except Exception as e: - logger.warning( - f"Failed to delete old temp file {entry.name}: {e}" - ) - except Exception as e: - logger.warning(f"Error cleaning temp files: {e}") - - -async def get_temp_file_path(filename: str) -> str: - """Get path to temporary audio file. - - Args: - filename: Name of temp file - - Returns: - Absolute path to temp file - - Raises: - RuntimeError: If temp directory does not exist - """ - temp_path = os.path.join(settings.temp_file_dir, filename) - - # Ensure temp directory exists - if not await aiofiles.os.path.exists(settings.temp_file_dir): - await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True) - - return temp_path - - -async def list_temp_files() -> List[str]: - """List temporary audio files. - - Returns: - List of temp file names - """ - if not await aiofiles.os.path.exists(settings.temp_file_dir): - return [] - - entries = await aiofiles.os.scandir(settings.temp_file_dir) - return [entry.name for entry in entries if entry.is_file()] - - -async def get_temp_dir_size() -> int: - """Get total size of temp directory in bytes. - - Returns: - Size in bytes - """ - if not await aiofiles.os.path.exists(settings.temp_file_dir): - return 0 - - total = 0 - entries = await aiofiles.os.scandir(settings.temp_file_dir) - for entry in entries: - if entry.is_file(): - stat = await aiofiles.os.stat(entry.path) - total += stat.st_size - return total +"""Async file and path operations.""" + +import io +import json +import os +from pathlib import Path +from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Set + +import aiofiles +import aiofiles.os +import torch +from loguru import logger + +from .config import settings + + +async def _find_file( + filename: str, + search_paths: List[str], + filter_fn: Optional[Callable[[str], bool]] = None, +) -> str: + """Find file in search paths. + + Args: + filename: Name of file to find + search_paths: List of paths to search in + filter_fn: Optional function to filter files + + Returns: + Absolute path to file + + Raises: + RuntimeError: If file not found + """ + if os.path.isabs(filename) and await aiofiles.os.path.exists(filename): + return filename + + for path in search_paths: + full_path = os.path.join(path, filename) + if await aiofiles.os.path.exists(full_path): + if filter_fn is None or filter_fn(full_path): + return full_path + + raise FileNotFoundError(f"File not found: {filename} in paths: {search_paths}") + + +async def _scan_directories( + search_paths: List[str], filter_fn: Optional[Callable[[str], bool]] = None +) -> Set[str]: + """Scan directories for files. + + Args: + search_paths: List of paths to scan + filter_fn: Optional function to filter files + + Returns: + Set of matching filenames + """ + results = set() + + for path in search_paths: + if not await aiofiles.os.path.exists(path): + continue + + try: + # Get directory entries first + entries = await aiofiles.os.scandir(path) + # Then process entries after await completes + for entry in entries: + if filter_fn is None or filter_fn(entry.name): + results.add(entry.name) + except Exception as e: + logger.warning(f"Error scanning {path}: {e}") + + return results + + +async def get_model_path(model_name: str) -> str: + """Get path to model file. + + Args: + model_name: Name of model file + + Returns: + Absolute path to model file + + Raises: + RuntimeError: If model not found + """ + # Get api directory path (two levels up from core) + api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + + # Construct model directory path relative to api directory + model_dir = os.path.join(api_dir, settings.model_dir) + + # Ensure model directory exists + os.makedirs(model_dir, exist_ok=True) + + # Search in model directory + search_paths = [model_dir] + logger.debug(f"Searching for model in path: {model_dir}") + + return await _find_file(model_name, search_paths) + + +async def get_voice_path(voice_name: str) -> str: + """Get path to voice file. + + Args: + voice_name: Name of voice file (without .pt extension) + + Returns: + Absolute path to voice file + + Raises: + RuntimeError: If voice not found + """ + # Get api directory path + api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + + # Construct voice directory path relative to api directory + voice_dir = os.path.join(api_dir, settings.voices_dir) + + # Ensure voice directory exists + os.makedirs(voice_dir, exist_ok=True) + + voice_file = f"{voice_name}.pt" + + # Search in voice directory/o + search_paths = [voice_dir] + logger.debug(f"Searching for voice in path: {voice_dir}") + + return await _find_file(voice_file, search_paths) + + +async def list_voices() -> List[str]: + """List available voice files. + + Returns: + List of voice names (without .pt extension) + """ + # Get api directory path + api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + + # Construct voice directory path relative to api directory + voice_dir = os.path.join(api_dir, settings.voices_dir) + + # Ensure voice directory exists + os.makedirs(voice_dir, exist_ok=True) + + # Search in voice directory + search_paths = [voice_dir] + logger.debug(f"Scanning for voices in path: {voice_dir}") + + def filter_voice_files(name: str) -> bool: + return name.endswith(".pt") + + voices = await _scan_directories(search_paths, filter_voice_files) + return sorted([name[:-3] for name in voices]) # Remove .pt extension + + +async def load_voice_tensor( + voice_path: str, device: str = "cpu", weights_only=False +) -> torch.Tensor: + """Load voice tensor from file. + + Args: + voice_path: Path to voice file + device: Device to load tensor to + + Returns: + Voice tensor + + Raises: + RuntimeError: If file cannot be read + """ + try: + async with aiofiles.open(voice_path, "rb") as f: + data = await f.read() + return torch.load( + io.BytesIO(data), map_location=device, weights_only=weights_only + ) + except Exception as e: + raise RuntimeError(f"Failed to load voice tensor from {voice_path}: {e}") + + +async def save_voice_tensor(tensor: torch.Tensor, voice_path: str) -> None: + """Save voice tensor to file. + + Args: + tensor: Voice tensor to save + voice_path: Path to save voice file + + Raises: + RuntimeError: If file cannot be written + """ + try: + buffer = io.BytesIO() + torch.save(tensor, buffer) + async with aiofiles.open(voice_path, "wb") as f: + await f.write(buffer.getvalue()) + except Exception as e: + raise RuntimeError(f"Failed to save voice tensor to {voice_path}: {e}") + + +async def load_json(path: str) -> dict: + """Load JSON file asynchronously. + + Args: + path: Path to JSON file + + Returns: + Parsed JSON data + + Raises: + RuntimeError: If file cannot be read or parsed + """ + try: + async with aiofiles.open(path, "r", encoding="utf-8") as f: + content = await f.read() + return json.loads(content) + except Exception as e: + raise RuntimeError(f"Failed to load JSON file {path}: {e}") + + +async def load_model_weights(path: str, device: str = "cpu") -> dict: + """Load model weights asynchronously. + + Args: + path: Path to model file (.pth or .onnx) + device: Device to load model to + + Returns: + Model weights + + Raises: + RuntimeError: If file cannot be read + """ + try: + async with aiofiles.open(path, "rb") as f: + data = await f.read() + return torch.load(io.BytesIO(data), map_location=device, weights_only=True) + except Exception as e: + raise RuntimeError(f"Failed to load model weights from {path}: {e}") + + +async def read_file(path: str) -> str: + """Read text file asynchronously. + + Args: + path: Path to file + + Returns: + File contents as string + + Raises: + RuntimeError: If file cannot be read + """ + try: + async with aiofiles.open(path, "r", encoding="utf-8") as f: + return await f.read() + except Exception as e: + raise RuntimeError(f"Failed to read file {path}: {e}") + + +async def read_bytes(path: str) -> bytes: + """Read file as bytes asynchronously. + + Args: + path: Path to file + + Returns: + File contents as bytes + + Raises: + RuntimeError: If file cannot be read + """ + try: + async with aiofiles.open(path, "rb") as f: + return await f.read() + except Exception as e: + raise RuntimeError(f"Failed to read file {path}: {e}") + + +async def get_web_file_path(filename: str) -> str: + """Get path to web static file. + + Args: + filename: Name of file in web directory + + Returns: + Absolute path to file + + Raises: + RuntimeError: If file not found + """ + # Get project root directory (four levels up from core to get to project root) + root_dir = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + ) + + # Construct web directory path relative to project root + web_dir = os.path.join("/app", settings.web_player_path) + + # Search in web directory + search_paths = [web_dir] + logger.debug(f"Searching for web file in path: {web_dir}") + + return await _find_file(filename, search_paths) + + +async def get_content_type(path: str) -> str: + """Get content type for file. + + Args: + path: Path to file + + Returns: + Content type string + """ + ext = os.path.splitext(path)[1].lower() + return { + ".html": "text/html", + ".js": "application/javascript", + ".css": "text/css", + ".png": "image/png", + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".gif": "image/gif", + ".svg": "image/svg+xml", + ".ico": "image/x-icon", + }.get(ext, "application/octet-stream") + + +async def verify_model_path(model_path: str) -> bool: + """Verify model file exists at path.""" + return await aiofiles.os.path.exists(model_path) + + +async def cleanup_temp_files() -> None: + """Clean up old temp files on startup""" + try: + if not await aiofiles.os.path.exists(settings.temp_file_dir): + await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True) + return + + entries = await aiofiles.os.scandir(settings.temp_file_dir) + for entry in entries: + if entry.is_file(): + stat = await aiofiles.os.stat(entry.path) + max_age = stat.st_mtime + (settings.max_temp_dir_age_hours * 3600) + if max_age < stat.st_mtime: + try: + await aiofiles.os.remove(entry.path) + logger.info(f"Cleaned up old temp file: {entry.name}") + except Exception as e: + logger.warning( + f"Failed to delete old temp file {entry.name}: {e}" + ) + except Exception as e: + logger.warning(f"Error cleaning temp files: {e}") + + +async def get_temp_file_path(filename: str) -> str: + """Get path to temporary audio file. + + Args: + filename: Name of temp file + + Returns: + Absolute path to temp file + + Raises: + RuntimeError: If temp directory does not exist + """ + temp_path = os.path.join(settings.temp_file_dir, filename) + + # Ensure temp directory exists + if not await aiofiles.os.path.exists(settings.temp_file_dir): + await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True) + + return temp_path + + +async def list_temp_files() -> List[str]: + """List temporary audio files. + + Returns: + List of temp file names + """ + if not await aiofiles.os.path.exists(settings.temp_file_dir): + return [] + + entries = await aiofiles.os.scandir(settings.temp_file_dir) + return [entry.name for entry in entries if entry.is_file()] + + +async def get_temp_dir_size() -> int: + """Get total size of temp directory in bytes. + + Returns: + Size in bytes + """ + if not await aiofiles.os.path.exists(settings.temp_file_dir): + return 0 + + total = 0 + entries = await aiofiles.os.scandir(settings.temp_file_dir) + for entry in entries: + if entry.is_file(): + stat = await aiofiles.os.stat(entry.path) + total += stat.st_size + return total diff --git a/api/src/inference/__init__.py b/api/src/inference/__init__.py index cbbf95d..68c7ce3 100644 --- a/api/src/inference/__init__.py +++ b/api/src/inference/__init__.py @@ -1,12 +1,12 @@ -"""Model inference package.""" - -from .base import BaseModelBackend -from .kokoro_v1 import KokoroV1 -from .model_manager import ModelManager, get_manager - -__all__ = [ - "BaseModelBackend", - "ModelManager", - "get_manager", - "KokoroV1", -] +"""Model inference package.""" + +from .base import BaseModelBackend +from .kokoro_v1 import KokoroV1 +from .model_manager import ModelManager, get_manager + +__all__ = [ + "BaseModelBackend", + "ModelManager", + "get_manager", + "KokoroV1", +] diff --git a/api/src/inference/base.py b/api/src/inference/base.py index 7f3cd3f..5629bd7 100644 --- a/api/src/inference/base.py +++ b/api/src/inference/base.py @@ -1,98 +1,98 @@ -"""Base interface for Kokoro inference.""" - -from abc import ABC, abstractmethod -from typing import AsyncGenerator, Optional, Tuple, Union - -import numpy as np -import torch - - -class ModelBackend(ABC): - """Abstract base class for model inference backend.""" - - @abstractmethod - async def load_model(self, path: str) -> None: - """Load model from path. - - Args: - path: Path to model file - - Raises: - RuntimeError: If model loading fails - """ - pass - - @abstractmethod - async def generate( - self, - text: str, - voice: Union[str, Tuple[str, Union[torch.Tensor, str]]], - speed: float = 1.0, - ) -> AsyncGenerator[np.ndarray, None]: - """Generate audio from text. - - Args: - text: Input text to synthesize - voice: Either a voice path or tuple of (name, tensor/path) - speed: Speed multiplier - - Yields: - Generated audio chunks - - Raises: - RuntimeError: If generation fails - """ - pass - - @abstractmethod - def unload(self) -> None: - """Unload model and free resources.""" - pass - - @property - @abstractmethod - def is_loaded(self) -> bool: - """Check if model is loaded. - - Returns: - True if model is loaded, False otherwise - """ - pass - - @property - @abstractmethod - def device(self) -> str: - """Get device model is running on. - - Returns: - Device string ('cpu' or 'cuda') - """ - pass - - -class BaseModelBackend(ModelBackend): - """Base implementation of model backend.""" - - def __init__(self): - """Initialize base backend.""" - self._model: Optional[torch.nn.Module] = None - self._device: str = "cpu" - - @property - def is_loaded(self) -> bool: - """Check if model is loaded.""" - return self._model is not None - - @property - def device(self) -> str: - """Get device model is running on.""" - return self._device - - def unload(self) -> None: - """Unload model and free resources.""" - if self._model is not None: - del self._model - self._model = None - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.synchronize() +"""Base interface for Kokoro inference.""" + +from abc import ABC, abstractmethod +from typing import AsyncGenerator, Optional, Tuple, Union + +import numpy as np +import torch + + +class ModelBackend(ABC): + """Abstract base class for model inference backend.""" + + @abstractmethod + async def load_model(self, path: str) -> None: + """Load model from path. + + Args: + path: Path to model file + + Raises: + RuntimeError: If model loading fails + """ + pass + + @abstractmethod + async def generate( + self, + text: str, + voice: Union[str, Tuple[str, Union[torch.Tensor, str]]], + speed: float = 1.0, + ) -> AsyncGenerator[np.ndarray, None]: + """Generate audio from text. + + Args: + text: Input text to synthesize + voice: Either a voice path or tuple of (name, tensor/path) + speed: Speed multiplier + + Yields: + Generated audio chunks + + Raises: + RuntimeError: If generation fails + """ + pass + + @abstractmethod + def unload(self) -> None: + """Unload model and free resources.""" + pass + + @property + @abstractmethod + def is_loaded(self) -> bool: + """Check if model is loaded. + + Returns: + True if model is loaded, False otherwise + """ + pass + + @property + @abstractmethod + def device(self) -> str: + """Get device model is running on. + + Returns: + Device string ('cpu' or 'cuda') + """ + pass + + +class BaseModelBackend(ModelBackend): + """Base implementation of model backend.""" + + def __init__(self): + """Initialize base backend.""" + self._model: Optional[torch.nn.Module] = None + self._device: str = "cpu" + + @property + def is_loaded(self) -> bool: + """Check if model is loaded.""" + return self._model is not None + + @property + def device(self) -> str: + """Get device model is running on.""" + return self._device + + def unload(self) -> None: + """Unload model and free resources.""" + if self._model is not None: + del self._model + self._model = None + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() diff --git a/api/src/inference/model_manager.py b/api/src/inference/model_manager.py index 06e3aa1..9cef95f 100644 --- a/api/src/inference/model_manager.py +++ b/api/src/inference/model_manager.py @@ -1,171 +1,171 @@ -"""Kokoro V1 model management.""" - -from typing import Optional - -from loguru import logger - -from ..core import paths -from ..core.config import settings -from ..core.model_config import ModelConfig, model_config -from .base import BaseModelBackend -from .kokoro_v1 import KokoroV1 - - -class ModelManager: - """Manages Kokoro V1 model loading and inference.""" - - # Singleton instance - _instance = None - - def __init__(self, config: Optional[ModelConfig] = None): - """Initialize manager. - - Args: - config: Optional model configuration override - """ - self._config = config or model_config - self._backend: Optional[KokoroV1] = None # Explicitly type as KokoroV1 - self._device: Optional[str] = None - - def _determine_device(self) -> str: - """Determine device based on settings.""" - return "cuda" if settings.use_gpu else "cpu" - - async def initialize(self) -> None: - """Initialize Kokoro V1 backend.""" - try: - self._device = self._determine_device() - logger.info(f"Initializing Kokoro V1 on {self._device}") - self._backend = KokoroV1() - - except Exception as e: - raise RuntimeError(f"Failed to initialize Kokoro V1: {e}") - - async def initialize_with_warmup(self, voice_manager) -> tuple[str, str, int]: - """Initialize and warm up model. - - Args: - voice_manager: Voice manager instance for warmup - - Returns: - Tuple of (device, backend type, voice count) - - Raises: - RuntimeError: If initialization fails - """ - import time - - start = time.perf_counter() - - try: - # Initialize backend - await self.initialize() - - # Load model - model_path = self._config.pytorch_kokoro_v1_file - await self.load_model(model_path) - - # Use paths module to get voice path - try: - voices = await paths.list_voices() - voice_path = await paths.get_voice_path(settings.default_voice) - - # Warm up with short text - warmup_text = "Warmup text for initialization." - # Use default voice name for warmup - voice_name = settings.default_voice - logger.debug(f"Using default voice '{voice_name}' for warmup") - async for _ in self.generate(warmup_text, (voice_name, voice_path)): - pass - except Exception as e: - raise RuntimeError(f"Failed to get default voice: {e}") - - ms = int((time.perf_counter() - start) * 1000) - logger.info(f"Warmup completed in {ms}ms") - - return self._device, "kokoro_v1", len(voices) - except FileNotFoundError as e: - logger.error(""" -Model files not found! You need to download the Kokoro V1 model: - -1. Download model using the script: - python docker/scripts/download_model.py --output api/src/models/v1_0 - -2. Or set environment variable in docker-compose: - DOWNLOAD_MODEL=true -""") - exit(0) - except Exception as e: - raise RuntimeError(f"Warmup failed: {e}") - - def get_backend(self) -> BaseModelBackend: - """Get initialized backend. - - Returns: - Initialized backend instance - - Raises: - RuntimeError: If backend not initialized - """ - if not self._backend: - raise RuntimeError("Backend not initialized") - return self._backend - - async def load_model(self, path: str) -> None: - """Load model using initialized backend. - - Args: - path: Path to model file - - Raises: - RuntimeError: If loading fails - """ - if not self._backend: - raise RuntimeError("Backend not initialized") - - try: - await self._backend.load_model(path) - except FileNotFoundError as e: - raise e - except Exception as e: - raise RuntimeError(f"Failed to load model: {e}") - - async def generate(self, *args, **kwargs): - """Generate audio using initialized backend. - - Raises: - RuntimeError: If generation fails - """ - if not self._backend: - raise RuntimeError("Backend not initialized") - - try: - async for chunk in self._backend.generate(*args, **kwargs): - yield chunk - except Exception as e: - raise RuntimeError(f"Generation failed: {e}") - - def unload_all(self) -> None: - """Unload model and free resources.""" - if self._backend: - self._backend.unload() - self._backend = None - - @property - def current_backend(self) -> str: - """Get current backend type.""" - return "kokoro_v1" - - -async def get_manager(config: Optional[ModelConfig] = None) -> ModelManager: - """Get model manager instance. - - Args: - config: Optional configuration override - - Returns: - ModelManager instance - """ - if ModelManager._instance is None: - ModelManager._instance = ModelManager(config) - return ModelManager._instance +"""Kokoro V1 model management.""" + +from typing import Optional + +from loguru import logger + +from ..core import paths +from ..core.config import settings +from ..core.model_config import ModelConfig, model_config +from .base import BaseModelBackend +from .kokoro_v1 import KokoroV1 + + +class ModelManager: + """Manages Kokoro V1 model loading and inference.""" + + # Singleton instance + _instance = None + + def __init__(self, config: Optional[ModelConfig] = None): + """Initialize manager. + + Args: + config: Optional model configuration override + """ + self._config = config or model_config + self._backend: Optional[KokoroV1] = None # Explicitly type as KokoroV1 + self._device: Optional[str] = None + + def _determine_device(self) -> str: + """Determine device based on settings.""" + return "cuda" if settings.use_gpu else "cpu" + + async def initialize(self) -> None: + """Initialize Kokoro V1 backend.""" + try: + self._device = self._determine_device() + logger.info(f"Initializing Kokoro V1 on {self._device}") + self._backend = KokoroV1() + + except Exception as e: + raise RuntimeError(f"Failed to initialize Kokoro V1: {e}") + + async def initialize_with_warmup(self, voice_manager) -> tuple[str, str, int]: + """Initialize and warm up model. + + Args: + voice_manager: Voice manager instance for warmup + + Returns: + Tuple of (device, backend type, voice count) + + Raises: + RuntimeError: If initialization fails + """ + import time + + start = time.perf_counter() + + try: + # Initialize backend + await self.initialize() + + # Load model + model_path = self._config.pytorch_kokoro_v1_file + await self.load_model(model_path) + + # Use paths module to get voice path + try: + voices = await paths.list_voices() + voice_path = await paths.get_voice_path(settings.default_voice) + + # Warm up with short text + warmup_text = "Warmup text for initialization." + # Use default voice name for warmup + voice_name = settings.default_voice + logger.debug(f"Using default voice '{voice_name}' for warmup") + async for _ in self.generate(warmup_text, (voice_name, voice_path)): + pass + except Exception as e: + raise RuntimeError(f"Failed to get default voice: {e}") + + ms = int((time.perf_counter() - start) * 1000) + logger.info(f"Warmup completed in {ms}ms") + + return self._device, "kokoro_v1", len(voices) + except FileNotFoundError as e: + logger.error(""" +Model files not found! You need to download the Kokoro V1 model: + +1. Download model using the script: + python docker/scripts/download_model.py --output api/src/models/v1_0 + +2. Or set environment variable in docker-compose: + DOWNLOAD_MODEL=true +""") + exit(0) + except Exception as e: + raise RuntimeError(f"Warmup failed: {e}") + + def get_backend(self) -> BaseModelBackend: + """Get initialized backend. + + Returns: + Initialized backend instance + + Raises: + RuntimeError: If backend not initialized + """ + if not self._backend: + raise RuntimeError("Backend not initialized") + return self._backend + + async def load_model(self, path: str) -> None: + """Load model using initialized backend. + + Args: + path: Path to model file + + Raises: + RuntimeError: If loading fails + """ + if not self._backend: + raise RuntimeError("Backend not initialized") + + try: + await self._backend.load_model(path) + except FileNotFoundError as e: + raise e + except Exception as e: + raise RuntimeError(f"Failed to load model: {e}") + + async def generate(self, *args, **kwargs): + """Generate audio using initialized backend. + + Raises: + RuntimeError: If generation fails + """ + if not self._backend: + raise RuntimeError("Backend not initialized") + + try: + async for chunk in self._backend.generate(*args, **kwargs): + yield chunk + except Exception as e: + raise RuntimeError(f"Generation failed: {e}") + + def unload_all(self) -> None: + """Unload model and free resources.""" + if self._backend: + self._backend.unload() + self._backend = None + + @property + def current_backend(self) -> str: + """Get current backend type.""" + return "kokoro_v1" + + +async def get_manager(config: Optional[ModelConfig] = None) -> ModelManager: + """Get model manager instance. + + Args: + config: Optional configuration override + + Returns: + ModelManager instance + """ + if ModelManager._instance is None: + ModelManager._instance = ModelManager(config) + return ModelManager._instance