Add a .gitattributes

This commit is contained in:
Fireblade 2025-02-18 17:44:03 -05:00
parent b00c9ec28d
commit 7f15ba8fed
5 changed files with 699 additions and 694 deletions

5
.gitattributes vendored Normal file
View file

@ -0,0 +1,5 @@
* text=auto
*.py text eol=lf
*.sh text eol=lf
*.yml text eol=lf

View file

@ -1,413 +1,413 @@
"""Async file and path operations.""" """Async file and path operations."""
import io import io
import json import json
import os import os
from pathlib import Path from pathlib import Path
from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Set from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Set
import aiofiles import aiofiles
import aiofiles.os import aiofiles.os
import torch import torch
from loguru import logger from loguru import logger
from .config import settings from .config import settings
async def _find_file(
    filename: str,
    search_paths: List[str],
    filter_fn: Optional[Callable[[str], bool]] = None,
) -> str:
    """Locate ``filename`` and return the path at which it exists.

    An absolute ``filename`` that exists is returned unchanged, without
    consulting ``filter_fn``. Otherwise each directory in ``search_paths`` is
    tried in order and the first existing candidate accepted by ``filter_fn``
    is returned.

    Args:
        filename: File name, or absolute path, to look for
        search_paths: Directories to probe, in priority order
        filter_fn: Optional predicate applied to each existing candidate path

    Returns:
        Path of the first match

    Raises:
        FileNotFoundError: If no candidate exists or passes the filter
    """
    # Short-circuit keeps the existence check from running for relative names.
    already_resolved = os.path.isabs(filename) and await aiofiles.os.path.exists(
        filename
    )
    if already_resolved:
        return filename

    for directory in search_paths:
        candidate = os.path.join(directory, filename)
        if not await aiofiles.os.path.exists(candidate):
            continue
        if filter_fn is not None and not filter_fn(candidate):
            continue
        return candidate

    raise FileNotFoundError(f"File not found: {filename} in paths: {search_paths}")
async def _scan_directories(
    search_paths: List[str], filter_fn: Optional[Callable[[str], bool]] = None
) -> Set[str]:
    """Collect the names of directory entries found under ``search_paths``.

    Directories that do not exist are skipped. A failure while scanning one
    directory is logged as a warning and does not stop the remaining ones.

    Args:
        search_paths: Directories to scan
        filter_fn: Optional predicate applied to each entry name

    Returns:
        Set of entry names that passed the filter
    """
    found = set()
    for directory in search_paths:
        if await aiofiles.os.path.exists(directory):
            try:
                # Await the listing first, then walk it synchronously.
                listing = await aiofiles.os.scandir(directory)
                for item in listing:
                    if filter_fn is not None and not filter_fn(item.name):
                        continue
                    found.add(item.name)
            except Exception as e:
                logger.warning(f"Error scanning {directory}: {e}")
    return found
async def get_model_path(model_name: str) -> str:
    """Resolve a model file name inside the configured model directory.

    The model directory is ``settings.model_dir`` joined onto the directory
    three ``dirname`` steps above this file; it is created when missing.

    Args:
        model_name: Name of model file

    Returns:
        Path to the model file

    Raises:
        FileNotFoundError: If the model file is not found
    """
    core_dir = os.path.dirname(__file__)
    api_root = os.path.dirname(os.path.dirname(core_dir))
    models_root = os.path.join(api_root, settings.model_dir)

    # Create the directory up front so the lookup below never hits a missing dir.
    os.makedirs(models_root, exist_ok=True)

    logger.debug(f"Searching for model in path: {models_root}")
    return await _find_file(model_name, [models_root])
async def get_voice_path(voice_name: str) -> str:
    """Resolve a voice name to its ``.pt`` file in the configured voice directory.

    The voice directory is ``settings.voices_dir`` joined onto the directory
    three ``dirname`` steps above this file; it is created when missing.

    Args:
        voice_name: Name of voice file (without .pt extension)

    Returns:
        Path to the voice file

    Raises:
        FileNotFoundError: If the voice file is not found
    """
    core_dir = os.path.dirname(__file__)
    api_root = os.path.dirname(os.path.dirname(core_dir))
    voices_root = os.path.join(api_root, settings.voices_dir)

    # Create the directory up front so the lookup below never hits a missing dir.
    os.makedirs(voices_root, exist_ok=True)

    target = f"{voice_name}.pt"
    logger.debug(f"Searching for voice in path: {voices_root}")
    return await _find_file(target, [voices_root])
async def list_voices() -> List[str]:
    """List the voices available in the configured voice directory.

    The voice directory is ``settings.voices_dir`` joined onto the directory
    three ``dirname`` steps above this file; it is created when missing.

    Returns:
        Sorted list of voice names (without .pt extension)
    """
    core_dir = os.path.dirname(__file__)
    api_root = os.path.dirname(os.path.dirname(core_dir))
    voices_root = os.path.join(api_root, settings.voices_dir)
    os.makedirs(voices_root, exist_ok=True)

    logger.debug(f"Scanning for voices in path: {voices_root}")

    suffix = ".pt"
    matches = await _scan_directories(
        [voices_root], lambda name: name.endswith(suffix)
    )

    # Strip the extension so callers see bare voice names.
    stems = [name[: -len(suffix)] for name in matches]
    stems.sort()
    return stems
async def load_voice_tensor(
    voice_path: str, device: str = "cpu", weights_only: bool = False
) -> torch.Tensor:
    """Load voice tensor from file.

    The file is read asynchronously into memory and then deserialized with
    ``torch.load``.

    Args:
        voice_path: Path to voice file
        device: Device to load tensor to (passed as ``map_location``)
        weights_only: Forwarded to ``torch.load``. The default ``False`` is
            kept for backward compatibility.

    Returns:
        Voice tensor

    Raises:
        RuntimeError: If the file cannot be read or deserialized; the
            original exception is attached as ``__cause__``
    """
    try:
        async with aiofiles.open(voice_path, "rb") as f:
            data = await f.read()
        # NOTE(review): with weights_only=False torch.load performs full
        # unpickling, which can execute arbitrary code from the file. Only
        # load voice files from trusted sources, or pass weights_only=True.
        return torch.load(
            io.BytesIO(data), map_location=device, weights_only=weights_only
        )
    except Exception as e:
        raise RuntimeError(
            f"Failed to load voice tensor from {voice_path}: {e}"
        ) from e
async def save_voice_tensor(tensor: torch.Tensor, voice_path: str) -> None:
    """Serialize a voice tensor and write it to ``voice_path``.

    The tensor is serialized into an in-memory buffer with ``torch.save`` and
    the resulting bytes are written asynchronously.

    Args:
        tensor: Voice tensor to save
        voice_path: Destination path of the voice file

    Raises:
        RuntimeError: If serialization or writing fails
    """
    try:
        serialized = io.BytesIO()
        torch.save(tensor, serialized)
        async with aiofiles.open(voice_path, "wb") as out:
            payload = serialized.getvalue()
            await out.write(payload)
    except Exception as e:
        raise RuntimeError(f"Failed to save voice tensor to {voice_path}: {e}")
async def load_json(path: str) -> dict:
    """Read a UTF-8 JSON file asynchronously and parse it.

    Args:
        path: Path to JSON file

    Returns:
        Parsed JSON data

    Raises:
        RuntimeError: If the file cannot be read or parsed
    """
    try:
        async with aiofiles.open(path, "r", encoding="utf-8") as handle:
            raw = await handle.read()
        return json.loads(raw)
    except Exception as e:
        raise RuntimeError(f"Failed to load JSON file {path}: {e}")
async def load_model_weights(path: str, device: str = "cpu") -> dict:
    """Read a model file asynchronously and deserialize it with ``torch.load``.

    Loading always uses ``weights_only=True``.

    Args:
        path: Path to model file (.pth or .onnx)
        device: Device to load model to (passed as ``map_location``)

    Returns:
        Model weights

    Raises:
        RuntimeError: If the file cannot be read or deserialized
    """
    try:
        async with aiofiles.open(path, "rb") as handle:
            blob = await handle.read()
        stream = io.BytesIO(blob)
        return torch.load(stream, map_location=device, weights_only=True)
    except Exception as e:
        raise RuntimeError(f"Failed to load model weights from {path}: {e}")
async def read_file(path: str) -> str:
    """Read a UTF-8 text file asynchronously.

    Args:
        path: Path to file

    Returns:
        File contents as string

    Raises:
        RuntimeError: If the file cannot be read
    """
    try:
        async with aiofiles.open(path, "r", encoding="utf-8") as handle:
            text = await handle.read()
    except Exception as e:
        raise RuntimeError(f"Failed to read file {path}: {e}")
    return text
async def read_bytes(path: str) -> bytes:
    """Read a file asynchronously in binary mode.

    Args:
        path: Path to file

    Returns:
        File contents as bytes

    Raises:
        RuntimeError: If the file cannot be read
    """
    try:
        async with aiofiles.open(path, "rb") as handle:
            blob = await handle.read()
    except Exception as e:
        raise RuntimeError(f"Failed to read file {path}: {e}")
    return blob
async def get_web_file_path(filename: str) -> str:
    """Get path to web static file.

    The web directory is ``settings.web_player_path`` joined onto the fixed
    base ``/app``; it is not derived from this file's location.

    Args:
        filename: Name of file in web directory

    Returns:
        Path to the file

    Raises:
        FileNotFoundError: If the file is not found
    """
    # NOTE(review): the "/app" base is hard-coded (presumably the container
    # working directory) - confirm before running outside that layout.
    web_dir = os.path.join("/app", settings.web_player_path)

    logger.debug(f"Searching for web file in path: {web_dir}")
    return await _find_file(filename, [web_dir])
async def get_content_type(path: str) -> str:
    """Map a file path to a content type based on its extension.

    The extension is compared case-insensitively. Unknown or missing
    extensions fall back to ``application/octet-stream``.

    Args:
        path: Path to file

    Returns:
        Content type string
    """
    known_types = {
        ".html": "text/html",
        ".js": "application/javascript",
        ".css": "text/css",
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".gif": "image/gif",
        ".svg": "image/svg+xml",
        ".ico": "image/x-icon",
    }
    _, extension = os.path.splitext(path)
    return known_types.get(extension.lower(), "application/octet-stream")
async def verify_model_path(model_path: str) -> bool:
    """Report whether ``model_path`` exists on disk.

    Args:
        model_path: Path to check

    Returns:
        Result of the asynchronous existence check
    """
    found = await aiofiles.os.path.exists(model_path)
    return found
async def cleanup_temp_files() -> None:
    """Clean up old temp files on startup.

    Creates ``settings.temp_file_dir`` when it does not exist. Otherwise
    removes every regular file in it whose modification time is more than
    ``settings.max_temp_dir_age_hours`` hours in the past.

    This is best-effort: failures are logged as warnings and never raised.
    """
    import time

    try:
        if not await aiofiles.os.path.exists(settings.temp_file_dir):
            await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)
            return

        now = time.time()
        max_age_seconds = settings.max_temp_dir_age_hours * 3600

        entries = await aiofiles.os.scandir(settings.temp_file_dir)
        for entry in entries:
            if not entry.is_file():
                continue

            stat = await aiofiles.os.stat(entry.path)
            # Expiry must be compared with the current time. Comparing it
            # with the file's own mtime can never be true for a positive age,
            # which left every temp file in place.
            expires_at = stat.st_mtime + max_age_seconds
            if expires_at >= now:
                continue

            try:
                await aiofiles.os.remove(entry.path)
                logger.info(f"Cleaned up old temp file: {entry.name}")
            except Exception as e:
                logger.warning(f"Failed to delete old temp file {entry.name}: {e}")
    except Exception as e:
        logger.warning(f"Error cleaning temp files: {e}")
async def get_temp_file_path(filename: str) -> str:
    """Build the path of a temporary audio file.

    The temp directory (``settings.temp_file_dir``) is created when it does
    not exist. The file itself is not created or checked.

    Args:
        filename: Name of temp file

    Returns:
        ``filename`` joined onto the temp directory
    """
    temp_dir = settings.temp_file_dir
    result = os.path.join(temp_dir, filename)

    dir_present = await aiofiles.os.path.exists(settings.temp_file_dir)
    if not dir_present:
        await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)

    return result
async def list_temp_files() -> List[str]:
    """List the regular files in the temp directory.

    Returns:
        List of temp file names; empty when the temp directory does not exist
    """
    if await aiofiles.os.path.exists(settings.temp_file_dir):
        listing = await aiofiles.os.scandir(settings.temp_file_dir)
        names = []
        for item in listing:
            if item.is_file():
                names.append(item.name)
        return names
    return []
async def get_temp_dir_size() -> int:
    """Sum the sizes of the regular files directly inside the temp directory.

    Returns:
        Size in bytes; 0 when the temp directory does not exist
    """
    if not await aiofiles.os.path.exists(settings.temp_file_dir):
        return 0

    listing = await aiofiles.os.scandir(settings.temp_file_dir)
    sizes = [
        (await aiofiles.os.stat(item.path)).st_size
        for item in listing
        if item.is_file()
    ]
    return sum(sizes)

View file

@ -1,12 +1,12 @@
"""Model inference package.""" """Model inference package."""
from .base import BaseModelBackend from .base import BaseModelBackend
from .kokoro_v1 import KokoroV1 from .kokoro_v1 import KokoroV1
from .model_manager import ModelManager, get_manager from .model_manager import ModelManager, get_manager
# Names re-exported as the public API of this package.
__all__ = ["BaseModelBackend", "ModelManager", "get_manager", "KokoroV1"]

View file

@ -1,98 +1,98 @@
"""Base interface for Kokoro inference.""" """Base interface for Kokoro inference."""
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import AsyncGenerator, Optional, Tuple, Union from typing import AsyncGenerator, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
class ModelBackend(ABC):
    """Interface that every model inference backend must implement."""

    @abstractmethod
    async def load_model(self, path: str) -> None:
        """Load the model stored at ``path``.

        Args:
            path: Path to model file

        Raises:
            RuntimeError: If model loading fails
        """
        ...

    @abstractmethod
    async def generate(
        self,
        text: str,
        voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
        speed: float = 1.0,
    ) -> AsyncGenerator[np.ndarray, None]:
        """Synthesize audio for ``text``.

        Args:
            text: Input text to synthesize
            voice: Either a voice path or tuple of (name, tensor/path)
            speed: Speed multiplier

        Yields:
            Generated audio chunks

        Raises:
            RuntimeError: If generation fails
        """
        ...

    @abstractmethod
    def unload(self) -> None:
        """Release the model and any resources it holds."""
        ...

    @property
    @abstractmethod
    def is_loaded(self) -> bool:
        """Whether a model is currently loaded.

        Returns:
            True if model is loaded, False otherwise
        """
        ...

    @property
    @abstractmethod
    def device(self) -> str:
        """Device the model is running on.

        Returns:
            Device string ('cpu' or 'cuda')
        """
        ...
class BaseModelBackend(ModelBackend):
    """Shared state and housekeeping for concrete model backends.

    ``load_model`` and ``generate`` are not implemented here, so subclasses
    must still provide them.
    """

    def __init__(self):
        """Start with no model loaded and the device set to "cpu"."""
        self._model: Optional[torch.nn.Module] = None
        self._device: str = "cpu"

    @property
    def is_loaded(self) -> bool:
        """Whether a model object is currently held."""
        return self._model is not None

    @property
    def device(self) -> str:
        """Device string recorded for this backend."""
        return self._device

    def unload(self) -> None:
        """Drop the held model and, when CUDA is available, clear its cache."""
        if self._model is None:
            return

        del self._model
        self._model = None  # re-create the attribute so is_loaded keeps working

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

View file

@ -1,171 +1,171 @@
"""Kokoro V1 model management.""" """Kokoro V1 model management."""
from typing import Optional from typing import Optional
from loguru import logger from loguru import logger
from ..core import paths from ..core import paths
from ..core.config import settings from ..core.config import settings
from ..core.model_config import ModelConfig, model_config from ..core.model_config import ModelConfig, model_config
from .base import BaseModelBackend from .base import BaseModelBackend
from .kokoro_v1 import KokoroV1 from .kokoro_v1 import KokoroV1
class ModelManager:
    """Owns the Kokoro V1 backend: creation, model loading, warmup, inference."""

    # Process-wide instance, populated by get_manager().
    _instance = None

    def __init__(self, config: Optional[ModelConfig] = None):
        """Create a manager with no backend yet.

        Args:
            config: Optional model configuration override; falls back to the
                module-level ``model_config``
        """
        self._config = config or model_config
        self._backend: Optional[KokoroV1] = None
        self._device: Optional[str] = None

    def _determine_device(self) -> str:
        """Pick "cuda" when ``settings.use_gpu`` is truthy, otherwise "cpu"."""
        if settings.use_gpu:
            return "cuda"
        return "cpu"

    async def initialize(self) -> None:
        """Record the target device and create the Kokoro V1 backend.

        Raises:
            RuntimeError: If any step of initialization fails
        """
        try:
            chosen = self._determine_device()
            self._device = chosen
            logger.info(f"Initializing Kokoro V1 on {chosen}")
            self._backend = KokoroV1()
        except Exception as e:
            raise RuntimeError(f"Failed to initialize Kokoro V1: {e}")

    async def initialize_with_warmup(self, voice_manager) -> tuple[str, str, int]:
        """Initialize the backend, load the model and run one warmup generation.

        Args:
            voice_manager: Voice manager instance for warmup (not used in
                this method body)

        Returns:
            Tuple of (device, backend type, voice count)

        Raises:
            RuntimeError: If initialization or warmup fails
            SystemExit: If a FileNotFoundError reaches the outer handler
        """
        import time

        started_at = time.perf_counter()

        try:
            await self.initialize()
            await self.load_model(self._config.pytorch_kokoro_v1_file)

            # Voice lookup and warmup share one handler, so any failure here
            # is reported as a default-voice problem.
            try:
                voices = await paths.list_voices()
                voice_name = settings.default_voice
                voice_path = await paths.get_voice_path(voice_name)

                warmup_text = "Warmup text for initialization."
                logger.debug(f"Using default voice '{voice_name}' for warmup")

                # Drain the generator; the audio itself is discarded.
                async for _ in self.generate(warmup_text, (voice_name, voice_path)):
                    pass
            except Exception as e:
                raise RuntimeError(f"Failed to get default voice: {e}")

            elapsed_ms = int((time.perf_counter() - started_at) * 1000)
            logger.info(f"Warmup completed in {elapsed_ms}ms")

            return self._device, "kokoro_v1", len(voices)
        except FileNotFoundError:
            logger.error("""
Model files not found! You need to download the Kokoro V1 model:
1. Download model using the script:
python docker/scripts/download_model.py --output api/src/models/v1_0
2. Or set environment variable in docker-compose:
DOWNLOAD_MODEL=true
""")
            # NOTE(review): this exits with status 0 although startup failed -
            # confirm whether a non-zero status is wanted here.
            exit(0)
        except Exception as e:
            raise RuntimeError(f"Warmup failed: {e}")

    def get_backend(self) -> BaseModelBackend:
        """Return the backend created by ``initialize``.

        Returns:
            Initialized backend instance

        Raises:
            RuntimeError: If backend not initialized
        """
        backend = self._backend
        if not backend:
            raise RuntimeError("Backend not initialized")
        return backend

    async def load_model(self, path: str) -> None:
        """Load a model file through the backend.

        Args:
            path: Path to model file

        Raises:
            RuntimeError: If the backend is missing or loading fails
            FileNotFoundError: Propagated unchanged from the backend
        """
        if not self._backend:
            raise RuntimeError("Backend not initialized")

        try:
            await self._backend.load_model(path)
        except FileNotFoundError:
            # Let callers tell "file missing" apart from other load failures.
            raise
        except Exception as e:
            raise RuntimeError(f"Failed to load model: {e}")

    async def generate(self, *args, **kwargs):
        """Yield audio chunks produced by the backend.

        All arguments are forwarded to the backend's ``generate``.

        Raises:
            RuntimeError: If the backend is missing or generation fails
        """
        if not self._backend:
            raise RuntimeError("Backend not initialized")

        try:
            async for piece in self._backend.generate(*args, **kwargs):
                yield piece
        except Exception as e:
            raise RuntimeError(f"Generation failed: {e}")

    def unload_all(self) -> None:
        """Unload the backend, if any, and forget it."""
        if not self._backend:
            return
        self._backend.unload()
        self._backend = None

    @property
    def current_backend(self) -> str:
        """Backend type identifier; always "kokoro_v1"."""
        return "kokoro_v1"
async def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
    """Return the shared ModelManager, creating it on first use.

    ``config`` is only used when the instance is created; later calls return
    the existing instance regardless of the argument.

    Args:
        config: Optional configuration override

    Returns:
        ModelManager instance
    """
    manager = ModelManager._instance
    if manager is None:
        manager = ModelManager(config)
        ModelManager._instance = manager
    return manager