Add a .gitattributes

This commit is contained in:
Fireblade 2025-02-18 17:44:03 -05:00
parent b00c9ec28d
commit 7f15ba8fed
5 changed files with 699 additions and 694 deletions

5
.gitattributes vendored Normal file
View file

@ -0,0 +1,5 @@
* text=auto
*.py text eol=lf
*.sh text eol=lf
*.yml text eol=lf

View file

@ -1,413 +1,413 @@
"""Async file and path operations."""
import io
import json
import os
from pathlib import Path
from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Set
import aiofiles
import aiofiles.os
import torch
from loguru import logger
from .config import settings
async def _find_file(
    filename: str,
    search_paths: List[str],
    filter_fn: Optional[Callable[[str], bool]] = None,
) -> str:
    """Resolve a file name against a list of candidate directories.

    Args:
        filename: Name of the file to locate. An absolute path that already
            exists is returned as-is, without consulting ``filter_fn``.
        search_paths: Directories to probe, in order.
        filter_fn: Optional predicate called with each existing candidate
            path; a candidate is accepted only when it returns True.

    Returns:
        The first existing (and accepted) path.

    Raises:
        FileNotFoundError: If no candidate exists or passes the filter.
    """
    # Fast path: an existing absolute path short-circuits the search.
    if os.path.isabs(filename):
        if await aiofiles.os.path.exists(filename):
            return filename

    for directory in search_paths:
        candidate = os.path.join(directory, filename)
        if not await aiofiles.os.path.exists(candidate):
            continue
        if filter_fn is not None and not filter_fn(candidate):
            continue
        return candidate

    raise FileNotFoundError(f"File not found: {filename} in paths: {search_paths}")
async def _scan_directories(
    search_paths: List[str], filter_fn: Optional[Callable[[str], bool]] = None
) -> Set[str]:
    """Collect entry names found directly inside the given directories.

    Args:
        search_paths: Directories to scan; missing ones are skipped.
        filter_fn: Optional predicate called with each entry *name*; only
            names for which it returns True are kept.

    Returns:
        Set of matching entry names (names only, not full paths).
    """
    matches: Set[str] = set()
    for directory in search_paths:
        if not await aiofiles.os.path.exists(directory):
            continue
        try:
            # Await the listing first, then walk it synchronously.
            listing = await aiofiles.os.scandir(directory)
            matches.update(
                item.name
                for item in listing
                if filter_fn is None or filter_fn(item.name)
            )
        except Exception as e:
            # Best-effort scan: a failing directory is logged and skipped.
            logger.warning(f"Error scanning {directory}: {e}")
    return matches
async def get_model_path(model_name: str) -> str:
    """Locate a model file inside the configured model directory.

    The model directory is ``settings.model_dir`` joined onto the directory
    obtained by applying ``os.path.dirname`` three times to this file's path.
    The directory is created if it does not exist yet.

    Args:
        model_name: Name of the model file.

    Returns:
        Path to the model file.

    Raises:
        FileNotFoundError: If the model file is not found.
    """
    module_dir = os.path.dirname(__file__)
    api_dir = os.path.dirname(os.path.dirname(module_dir))
    model_dir = os.path.join(api_dir, settings.model_dir)

    # Make sure the directory exists before searching it.
    os.makedirs(model_dir, exist_ok=True)

    logger.debug(f"Searching for model in path: {model_dir}")
    return await _find_file(model_name, [model_dir])
async def get_voice_path(voice_name: str) -> str:
    """Locate a voice file inside the configured voices directory.

    The voices directory is ``settings.voices_dir`` joined onto the directory
    obtained by applying ``os.path.dirname`` three times to this file's path.
    The directory is created if it does not exist yet.

    Args:
        voice_name: Name of the voice (without the ``.pt`` extension).

    Returns:
        Path to the voice file.

    Raises:
        FileNotFoundError: If the voice file is not found.
    """
    module_dir = os.path.dirname(__file__)
    api_dir = os.path.dirname(os.path.dirname(module_dir))
    voice_dir = os.path.join(api_dir, settings.voices_dir)

    # Make sure the directory exists before searching it.
    os.makedirs(voice_dir, exist_ok=True)

    logger.debug(f"Searching for voice in path: {voice_dir}")
    return await _find_file(f"{voice_name}.pt", [voice_dir])
async def list_voices() -> List[str]:
    """List the voices available in the configured voices directory.

    The voices directory is resolved the same way as in ``get_voice_path``
    and is created if it does not exist yet.

    Returns:
        Sorted list of voice names (file names with ``.pt`` removed).
    """
    module_dir = os.path.dirname(__file__)
    api_dir = os.path.dirname(os.path.dirname(module_dir))
    voice_dir = os.path.join(api_dir, settings.voices_dir)

    # Make sure the directory exists before scanning it.
    os.makedirs(voice_dir, exist_ok=True)

    logger.debug(f"Scanning for voices in path: {voice_dir}")
    found = await _scan_directories(
        [voice_dir], lambda name: name.endswith(".pt")
    )
    # Drop the three-character ".pt" suffix from every match.
    return sorted(name[:-3] for name in found)
async def load_voice_tensor(
    voice_path: str, device: str = "cpu", weights_only=False
) -> torch.Tensor:
    """Read a voice file asynchronously and deserialize it with torch.

    Args:
        voice_path: Path to the voice file.
        device: Passed to ``torch.load`` as ``map_location``.
        weights_only: Passed straight through to ``torch.load``.
            NOTE(review): with the default False, ``torch.load`` may unpickle
            arbitrary objects, so only trusted files should be loaded.

    Returns:
        The object produced by ``torch.load`` (expected to be a tensor).

    Raises:
        RuntimeError: If the file cannot be read or deserialized.
    """
    try:
        async with aiofiles.open(voice_path, "rb") as handle:
            raw = await handle.read()
        # Deserialize from memory so torch never touches the file directly.
        buffer = io.BytesIO(raw)
        return torch.load(buffer, map_location=device, weights_only=weights_only)
    except Exception as e:
        raise RuntimeError(f"Failed to load voice tensor from {voice_path}: {e}")
async def save_voice_tensor(tensor: torch.Tensor, voice_path: str) -> None:
    """Serialize a voice tensor with torch and write it asynchronously.

    Args:
        tensor: Voice tensor to save.
        voice_path: Destination path of the voice file.

    Raises:
        RuntimeError: If serialization or writing fails.
    """
    try:
        # Serialize into memory first, then write the bytes in one call.
        buffer = io.BytesIO()
        torch.save(tensor, buffer)
        payload = buffer.getvalue()
        async with aiofiles.open(voice_path, "wb") as out:
            await out.write(payload)
    except Exception as e:
        raise RuntimeError(f"Failed to save voice tensor to {voice_path}: {e}")
async def load_json(path: str) -> dict:
    """Read a UTF-8 JSON file asynchronously and parse it.

    Args:
        path: Path to the JSON file.

    Returns:
        The value produced by ``json.loads`` (annotated as a dict; the code
        does not check the top-level type).

    Raises:
        RuntimeError: If the file cannot be read or parsed.
    """
    try:
        async with aiofiles.open(path, "r", encoding="utf-8") as handle:
            text = await handle.read()
        return json.loads(text)
    except Exception as e:
        raise RuntimeError(f"Failed to load JSON file {path}: {e}")
async def load_model_weights(path: str, device: str = "cpu") -> dict:
    """Read a model file asynchronously and deserialize it with torch.

    The file is always loaded with ``weights_only=True``.

    Args:
        path: Path to the model file.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The object produced by ``torch.load``.

    Raises:
        RuntimeError: If the file cannot be read or deserialized.
    """
    try:
        async with aiofiles.open(path, "rb") as handle:
            raw = await handle.read()
        buffer = io.BytesIO(raw)
        return torch.load(buffer, map_location=device, weights_only=True)
    except Exception as e:
        raise RuntimeError(f"Failed to load model weights from {path}: {e}")
async def read_file(path: str) -> str:
    """Read a whole UTF-8 text file asynchronously.

    Args:
        path: Path to the file.

    Returns:
        File contents as a string.

    Raises:
        RuntimeError: If the file cannot be opened or read.
    """
    try:
        async with aiofiles.open(path, "r", encoding="utf-8") as handle:
            text = await handle.read()
    except Exception as e:
        raise RuntimeError(f"Failed to read file {path}: {e}")
    return text
async def read_bytes(path: str) -> bytes:
    """Read a whole file asynchronously in binary mode.

    Args:
        path: Path to the file.

    Returns:
        File contents as bytes.

    Raises:
        RuntimeError: If the file cannot be opened or read.
    """
    try:
        async with aiofiles.open(path, "rb") as handle:
            payload = await handle.read()
    except Exception as e:
        raise RuntimeError(f"Failed to read file {path}: {e}")
    return payload
async def get_web_file_path(filename: str) -> str:
    """Locate a static file of the web player.

    The web directory is ``settings.web_player_path`` joined onto the fixed
    ``/app`` prefix (``os.path.join`` returns the setting unchanged when it is
    itself absolute). It is not derived from this module's location.

    Args:
        filename: Name of the file inside the web directory.

    Returns:
        Path to the file.

    Raises:
        FileNotFoundError: If the file is not found.
    """
    web_dir = os.path.join("/app", settings.web_player_path)

    logger.debug(f"Searching for web file in path: {web_dir}")
    return await _find_file(filename, [web_dir])
async def get_content_type(path: str) -> str:
    """Map a file path to a content type based on its extension.

    The extension is compared case-insensitively.

    Args:
        path: Path (or bare name) of the file.

    Returns:
        Content type string; ``application/octet-stream`` for extensions
        that are not in the table.
    """
    known_types = {
        ".html": "text/html",
        ".js": "application/javascript",
        ".css": "text/css",
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".gif": "image/gif",
        ".svg": "image/svg+xml",
        ".ico": "image/x-icon",
    }
    _, extension = os.path.splitext(path)
    return known_types.get(extension.lower(), "application/octet-stream")
async def verify_model_path(model_path: str) -> bool:
    """Report whether something exists at ``model_path``.

    Args:
        model_path: Path to check.

    Returns:
        Result of the asynchronous existence check.
    """
    exists = await aiofiles.os.path.exists(model_path)
    return exists
async def cleanup_temp_files() -> None:
    """Delete expired files from the temp directory (intended for startup).

    A regular file is expired when its modification time plus
    ``settings.max_temp_dir_age_hours`` lies before the current time. If the
    temp directory does not exist it is created and nothing else happens.
    All errors are logged as warnings rather than raised.
    """
    import time  # local import: not part of this module's import block

    try:
        if not await aiofiles.os.path.exists(settings.temp_file_dir):
            await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)
            return

        max_age_seconds = settings.max_temp_dir_age_hours * 3600
        now = time.time()

        entries = await aiofiles.os.scandir(settings.temp_file_dir)
        for entry in entries:
            if not entry.is_file():
                continue
            stat = await aiofiles.os.stat(entry.path)
            # Compare the expiry time against "now". Comparing it against
            # st_mtime itself can never be true for a positive max age.
            if stat.st_mtime + max_age_seconds >= now:
                continue
            try:
                await aiofiles.os.remove(entry.path)
                logger.info(f"Cleaned up old temp file: {entry.name}")
            except Exception as e:
                # One undeletable file must not stop the rest of the sweep.
                logger.warning(
                    f"Failed to delete old temp file {entry.name}: {e}"
                )
    except Exception as e:
        logger.warning(f"Error cleaning temp files: {e}")
async def get_temp_file_path(filename: str) -> str:
    """Build the path of a file inside the temp directory.

    The temp directory is created when it does not exist yet; the file
    itself is neither created nor checked.

    Args:
        filename: Name of the temp file.

    Returns:
        ``settings.temp_file_dir`` joined with ``filename``.
    """
    temp_dir = settings.temp_file_dir
    target = os.path.join(temp_dir, filename)

    # Make sure the directory exists before handing the path out.
    if not await aiofiles.os.path.exists(temp_dir):
        await aiofiles.os.makedirs(temp_dir, exist_ok=True)

    return target
async def list_temp_files() -> List[str]:
    """List the regular files in the temp directory.

    Returns:
        Names of the regular files found; an empty list when the temp
        directory does not exist.
    """
    temp_dir = settings.temp_file_dir
    if not await aiofiles.os.path.exists(temp_dir):
        return []

    names: List[str] = []
    for item in await aiofiles.os.scandir(temp_dir):
        if item.is_file():
            names.append(item.name)
    return names
async def get_temp_dir_size() -> int:
    """Sum the sizes of the regular files in the temp directory.

    Only direct entries are counted; subdirectories are not descended into.

    Returns:
        Total size in bytes; 0 when the temp directory does not exist.
    """
    temp_dir = settings.temp_file_dir
    if not await aiofiles.os.path.exists(temp_dir):
        return 0

    size_in_bytes = 0
    for item in await aiofiles.os.scandir(temp_dir):
        if not item.is_file():
            continue
        info = await aiofiles.os.stat(item.path)
        size_in_bytes += info.st_size
    return size_in_bytes
"""Async file and path operations."""
import io
import json
import os
from pathlib import Path
from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Set
import aiofiles
import aiofiles.os
import torch
from loguru import logger
from .config import settings
async def _find_file(
    filename: str,
    search_paths: List[str],
    filter_fn: Optional[Callable[[str], bool]] = None,
) -> str:
    """Resolve a file name against a list of candidate directories.

    Args:
        filename: Name of the file to locate. An absolute path that already
            exists is returned as-is, without consulting ``filter_fn``.
        search_paths: Directories to probe, in order.
        filter_fn: Optional predicate called with each existing candidate
            path; a candidate is accepted only when it returns True.

    Returns:
        The first existing (and accepted) path.

    Raises:
        FileNotFoundError: If no candidate exists or passes the filter.
    """
    # Fast path: an existing absolute path short-circuits the search.
    if os.path.isabs(filename):
        if await aiofiles.os.path.exists(filename):
            return filename

    for directory in search_paths:
        candidate = os.path.join(directory, filename)
        if not await aiofiles.os.path.exists(candidate):
            continue
        if filter_fn is not None and not filter_fn(candidate):
            continue
        return candidate

    raise FileNotFoundError(f"File not found: {filename} in paths: {search_paths}")
async def _scan_directories(
    search_paths: List[str], filter_fn: Optional[Callable[[str], bool]] = None
) -> Set[str]:
    """Collect entry names found directly inside the given directories.

    Args:
        search_paths: Directories to scan; missing ones are skipped.
        filter_fn: Optional predicate called with each entry *name*; only
            names for which it returns True are kept.

    Returns:
        Set of matching entry names (names only, not full paths).
    """
    matches: Set[str] = set()
    for directory in search_paths:
        if not await aiofiles.os.path.exists(directory):
            continue
        try:
            # Await the listing first, then walk it synchronously.
            listing = await aiofiles.os.scandir(directory)
            matches.update(
                item.name
                for item in listing
                if filter_fn is None or filter_fn(item.name)
            )
        except Exception as e:
            # Best-effort scan: a failing directory is logged and skipped.
            logger.warning(f"Error scanning {directory}: {e}")
    return matches
async def get_model_path(model_name: str) -> str:
    """Locate a model file inside the configured model directory.

    The model directory is ``settings.model_dir`` joined onto the directory
    obtained by applying ``os.path.dirname`` three times to this file's path.
    The directory is created if it does not exist yet.

    Args:
        model_name: Name of the model file.

    Returns:
        Path to the model file.

    Raises:
        FileNotFoundError: If the model file is not found.
    """
    module_dir = os.path.dirname(__file__)
    api_dir = os.path.dirname(os.path.dirname(module_dir))
    model_dir = os.path.join(api_dir, settings.model_dir)

    # Make sure the directory exists before searching it.
    os.makedirs(model_dir, exist_ok=True)

    logger.debug(f"Searching for model in path: {model_dir}")
    return await _find_file(model_name, [model_dir])
async def get_voice_path(voice_name: str) -> str:
    """Locate a voice file inside the configured voices directory.

    The voices directory is ``settings.voices_dir`` joined onto the directory
    obtained by applying ``os.path.dirname`` three times to this file's path.
    The directory is created if it does not exist yet.

    Args:
        voice_name: Name of the voice (without the ``.pt`` extension).

    Returns:
        Path to the voice file.

    Raises:
        FileNotFoundError: If the voice file is not found.
    """
    module_dir = os.path.dirname(__file__)
    api_dir = os.path.dirname(os.path.dirname(module_dir))
    voice_dir = os.path.join(api_dir, settings.voices_dir)

    # Make sure the directory exists before searching it.
    os.makedirs(voice_dir, exist_ok=True)

    logger.debug(f"Searching for voice in path: {voice_dir}")
    return await _find_file(f"{voice_name}.pt", [voice_dir])
async def list_voices() -> List[str]:
    """List the voices available in the configured voices directory.

    The voices directory is resolved the same way as in ``get_voice_path``
    and is created if it does not exist yet.

    Returns:
        Sorted list of voice names (file names with ``.pt`` removed).
    """
    module_dir = os.path.dirname(__file__)
    api_dir = os.path.dirname(os.path.dirname(module_dir))
    voice_dir = os.path.join(api_dir, settings.voices_dir)

    # Make sure the directory exists before scanning it.
    os.makedirs(voice_dir, exist_ok=True)

    logger.debug(f"Scanning for voices in path: {voice_dir}")
    found = await _scan_directories(
        [voice_dir], lambda name: name.endswith(".pt")
    )
    # Drop the three-character ".pt" suffix from every match.
    return sorted(name[:-3] for name in found)
async def load_voice_tensor(
    voice_path: str, device: str = "cpu", weights_only=False
) -> torch.Tensor:
    """Read a voice file asynchronously and deserialize it with torch.

    Args:
        voice_path: Path to the voice file.
        device: Passed to ``torch.load`` as ``map_location``.
        weights_only: Passed straight through to ``torch.load``.
            NOTE(review): with the default False, ``torch.load`` may unpickle
            arbitrary objects, so only trusted files should be loaded.

    Returns:
        The object produced by ``torch.load`` (expected to be a tensor).

    Raises:
        RuntimeError: If the file cannot be read or deserialized.
    """
    try:
        async with aiofiles.open(voice_path, "rb") as handle:
            raw = await handle.read()
        # Deserialize from memory so torch never touches the file directly.
        buffer = io.BytesIO(raw)
        return torch.load(buffer, map_location=device, weights_only=weights_only)
    except Exception as e:
        raise RuntimeError(f"Failed to load voice tensor from {voice_path}: {e}")
async def save_voice_tensor(tensor: torch.Tensor, voice_path: str) -> None:
    """Serialize a voice tensor with torch and write it asynchronously.

    Args:
        tensor: Voice tensor to save.
        voice_path: Destination path of the voice file.

    Raises:
        RuntimeError: If serialization or writing fails.
    """
    try:
        # Serialize into memory first, then write the bytes in one call.
        buffer = io.BytesIO()
        torch.save(tensor, buffer)
        payload = buffer.getvalue()
        async with aiofiles.open(voice_path, "wb") as out:
            await out.write(payload)
    except Exception as e:
        raise RuntimeError(f"Failed to save voice tensor to {voice_path}: {e}")
async def load_json(path: str) -> dict:
    """Read a UTF-8 JSON file asynchronously and parse it.

    Args:
        path: Path to the JSON file.

    Returns:
        The value produced by ``json.loads`` (annotated as a dict; the code
        does not check the top-level type).

    Raises:
        RuntimeError: If the file cannot be read or parsed.
    """
    try:
        async with aiofiles.open(path, "r", encoding="utf-8") as handle:
            text = await handle.read()
        return json.loads(text)
    except Exception as e:
        raise RuntimeError(f"Failed to load JSON file {path}: {e}")
async def load_model_weights(path: str, device: str = "cpu") -> dict:
    """Read a model file asynchronously and deserialize it with torch.

    The file is always loaded with ``weights_only=True``.

    Args:
        path: Path to the model file.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The object produced by ``torch.load``.

    Raises:
        RuntimeError: If the file cannot be read or deserialized.
    """
    try:
        async with aiofiles.open(path, "rb") as handle:
            raw = await handle.read()
        buffer = io.BytesIO(raw)
        return torch.load(buffer, map_location=device, weights_only=True)
    except Exception as e:
        raise RuntimeError(f"Failed to load model weights from {path}: {e}")
async def read_file(path: str) -> str:
    """Read a whole UTF-8 text file asynchronously.

    Args:
        path: Path to the file.

    Returns:
        File contents as a string.

    Raises:
        RuntimeError: If the file cannot be opened or read.
    """
    try:
        async with aiofiles.open(path, "r", encoding="utf-8") as handle:
            text = await handle.read()
    except Exception as e:
        raise RuntimeError(f"Failed to read file {path}: {e}")
    return text
async def read_bytes(path: str) -> bytes:
    """Read a whole file asynchronously in binary mode.

    Args:
        path: Path to the file.

    Returns:
        File contents as bytes.

    Raises:
        RuntimeError: If the file cannot be opened or read.
    """
    try:
        async with aiofiles.open(path, "rb") as handle:
            payload = await handle.read()
    except Exception as e:
        raise RuntimeError(f"Failed to read file {path}: {e}")
    return payload
async def get_web_file_path(filename: str) -> str:
    """Locate a static file of the web player.

    The web directory is ``settings.web_player_path`` joined onto the fixed
    ``/app`` prefix (``os.path.join`` returns the setting unchanged when it is
    itself absolute). It is not derived from this module's location.

    Args:
        filename: Name of the file inside the web directory.

    Returns:
        Path to the file.

    Raises:
        FileNotFoundError: If the file is not found.
    """
    web_dir = os.path.join("/app", settings.web_player_path)

    logger.debug(f"Searching for web file in path: {web_dir}")
    return await _find_file(filename, [web_dir])
async def get_content_type(path: str) -> str:
    """Map a file path to a content type based on its extension.

    The extension is compared case-insensitively.

    Args:
        path: Path (or bare name) of the file.

    Returns:
        Content type string; ``application/octet-stream`` for extensions
        that are not in the table.
    """
    known_types = {
        ".html": "text/html",
        ".js": "application/javascript",
        ".css": "text/css",
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".gif": "image/gif",
        ".svg": "image/svg+xml",
        ".ico": "image/x-icon",
    }
    _, extension = os.path.splitext(path)
    return known_types.get(extension.lower(), "application/octet-stream")
async def verify_model_path(model_path: str) -> bool:
    """Report whether something exists at ``model_path``.

    Args:
        model_path: Path to check.

    Returns:
        Result of the asynchronous existence check.
    """
    exists = await aiofiles.os.path.exists(model_path)
    return exists
async def cleanup_temp_files() -> None:
    """Delete expired files from the temp directory (intended for startup).

    A regular file is expired when its modification time plus
    ``settings.max_temp_dir_age_hours`` lies before the current time. If the
    temp directory does not exist it is created and nothing else happens.
    All errors are logged as warnings rather than raised.
    """
    import time  # local import: not part of this module's import block

    try:
        if not await aiofiles.os.path.exists(settings.temp_file_dir):
            await aiofiles.os.makedirs(settings.temp_file_dir, exist_ok=True)
            return

        max_age_seconds = settings.max_temp_dir_age_hours * 3600
        now = time.time()

        entries = await aiofiles.os.scandir(settings.temp_file_dir)
        for entry in entries:
            if not entry.is_file():
                continue
            stat = await aiofiles.os.stat(entry.path)
            # Compare the expiry time against "now". Comparing it against
            # st_mtime itself can never be true for a positive max age.
            if stat.st_mtime + max_age_seconds >= now:
                continue
            try:
                await aiofiles.os.remove(entry.path)
                logger.info(f"Cleaned up old temp file: {entry.name}")
            except Exception as e:
                # One undeletable file must not stop the rest of the sweep.
                logger.warning(
                    f"Failed to delete old temp file {entry.name}: {e}"
                )
    except Exception as e:
        logger.warning(f"Error cleaning temp files: {e}")
async def get_temp_file_path(filename: str) -> str:
    """Build the path of a file inside the temp directory.

    The temp directory is created when it does not exist yet; the file
    itself is neither created nor checked.

    Args:
        filename: Name of the temp file.

    Returns:
        ``settings.temp_file_dir`` joined with ``filename``.
    """
    temp_dir = settings.temp_file_dir
    target = os.path.join(temp_dir, filename)

    # Make sure the directory exists before handing the path out.
    if not await aiofiles.os.path.exists(temp_dir):
        await aiofiles.os.makedirs(temp_dir, exist_ok=True)

    return target
async def list_temp_files() -> List[str]:
    """List the regular files in the temp directory.

    Returns:
        Names of the regular files found; an empty list when the temp
        directory does not exist.
    """
    temp_dir = settings.temp_file_dir
    if not await aiofiles.os.path.exists(temp_dir):
        return []

    names: List[str] = []
    for item in await aiofiles.os.scandir(temp_dir):
        if item.is_file():
            names.append(item.name)
    return names
async def get_temp_dir_size() -> int:
    """Sum the sizes of the regular files in the temp directory.

    Only direct entries are counted; subdirectories are not descended into.

    Returns:
        Total size in bytes; 0 when the temp directory does not exist.
    """
    temp_dir = settings.temp_file_dir
    if not await aiofiles.os.path.exists(temp_dir):
        return 0

    size_in_bytes = 0
    for item in await aiofiles.os.scandir(temp_dir):
        if not item.is_file():
            continue
        info = await aiofiles.os.stat(item.path)
        size_in_bytes += info.st_size
    return size_in_bytes

View file

@ -1,12 +1,12 @@
"""Model inference package."""
from .base import BaseModelBackend
from .kokoro_v1 import KokoroV1
from .model_manager import ModelManager, get_manager
__all__ = [
"BaseModelBackend",
"ModelManager",
"get_manager",
"KokoroV1",
]
"""Model inference package."""
from .base import BaseModelBackend
from .kokoro_v1 import KokoroV1
from .model_manager import ModelManager, get_manager
__all__ = [
"BaseModelBackend",
"ModelManager",
"get_manager",
"KokoroV1",
]

View file

@ -1,98 +1,98 @@
"""Base interface for Kokoro inference."""
from abc import ABC, abstractmethod
from typing import AsyncGenerator, Optional, Tuple, Union
import numpy as np
import torch
class ModelBackend(ABC):
    """Interface every model inference backend has to implement."""

    @abstractmethod
    async def load_model(self, path: str) -> None:
        """Load the model stored at ``path``.

        Args:
            path: Path to the model file.

        Raises:
            RuntimeError: If model loading fails.
        """

    @abstractmethod
    async def generate(
        self,
        text: str,
        voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
        speed: float = 1.0,
    ) -> AsyncGenerator[np.ndarray, None]:
        """Synthesize audio for ``text``.

        Args:
            text: Input text to synthesize.
            voice: Either a voice path, or a ``(name, tensor_or_path)`` tuple.
            speed: Speed multiplier.

        Yields:
            Generated audio chunks.

        Raises:
            RuntimeError: If generation fails.
        """

    @abstractmethod
    def unload(self) -> None:
        """Release the model and any resources it holds."""

    @property
    @abstractmethod
    def is_loaded(self) -> bool:
        """Whether a model is currently loaded.

        Returns:
            True if a model is loaded, False otherwise.
        """

    @property
    @abstractmethod
    def device(self) -> str:
        """Device the model runs on.

        Returns:
            Device string ('cpu' or 'cuda').
        """
class BaseModelBackend(ModelBackend):
    """Shared state handling for model backends.

    Tracks the loaded model and its device and implements ``is_loaded``,
    ``device`` and ``unload``; ``load_model`` and ``generate`` are not
    implemented here.
    """

    def __init__(self):
        """Start with no model loaded and the device set to 'cpu'."""
        self._model: Optional[torch.nn.Module] = None
        self._device: str = "cpu"

    @property
    def is_loaded(self) -> bool:
        """Whether a model object is currently held."""
        return not (self._model is None)

    @property
    def device(self) -> str:
        """Device string recorded for this backend."""
        return self._device

    def unload(self) -> None:
        """Drop the model reference and release cached CUDA memory."""
        if self._model is None:
            return

        del self._model
        self._model = None
        # NOTE(review): the CUDA cleanup is assumed to run only when a model
        # was actually unloaded - confirm against the original indentation.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
"""Base interface for Kokoro inference."""
from abc import ABC, abstractmethod
from typing import AsyncGenerator, Optional, Tuple, Union
import numpy as np
import torch
class ModelBackend(ABC):
    """Interface every model inference backend has to implement."""

    @abstractmethod
    async def load_model(self, path: str) -> None:
        """Load the model stored at ``path``.

        Args:
            path: Path to the model file.

        Raises:
            RuntimeError: If model loading fails.
        """

    @abstractmethod
    async def generate(
        self,
        text: str,
        voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
        speed: float = 1.0,
    ) -> AsyncGenerator[np.ndarray, None]:
        """Synthesize audio for ``text``.

        Args:
            text: Input text to synthesize.
            voice: Either a voice path, or a ``(name, tensor_or_path)`` tuple.
            speed: Speed multiplier.

        Yields:
            Generated audio chunks.

        Raises:
            RuntimeError: If generation fails.
        """

    @abstractmethod
    def unload(self) -> None:
        """Release the model and any resources it holds."""

    @property
    @abstractmethod
    def is_loaded(self) -> bool:
        """Whether a model is currently loaded.

        Returns:
            True if a model is loaded, False otherwise.
        """

    @property
    @abstractmethod
    def device(self) -> str:
        """Device the model runs on.

        Returns:
            Device string ('cpu' or 'cuda').
        """
class BaseModelBackend(ModelBackend):
    """Shared state handling for model backends.

    Tracks the loaded model and its device and implements ``is_loaded``,
    ``device`` and ``unload``; ``load_model`` and ``generate`` are not
    implemented here.
    """

    def __init__(self):
        """Start with no model loaded and the device set to 'cpu'."""
        self._model: Optional[torch.nn.Module] = None
        self._device: str = "cpu"

    @property
    def is_loaded(self) -> bool:
        """Whether a model object is currently held."""
        return not (self._model is None)

    @property
    def device(self) -> str:
        """Device string recorded for this backend."""
        return self._device

    def unload(self) -> None:
        """Drop the model reference and release cached CUDA memory."""
        if self._model is None:
            return

        del self._model
        self._model = None
        # NOTE(review): the CUDA cleanup is assumed to run only when a model
        # was actually unloaded - confirm against the original indentation.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

View file

@ -1,171 +1,171 @@
"""Kokoro V1 model management."""
from typing import Optional
from loguru import logger
from ..core import paths
from ..core.config import settings
from ..core.model_config import ModelConfig, model_config
from .base import BaseModelBackend
from .kokoro_v1 import KokoroV1
class ModelManager:
    """Manages Kokoro V1 model loading and inference."""

    # Singleton instance; filled in lazily by get_manager().
    _instance = None

    def __init__(self, config: Optional[ModelConfig] = None):
        """Initialize manager.

        Args:
            config: Optional model configuration override; the module-level
                ``model_config`` is used when omitted (or falsy).
        """
        self._config = config or model_config
        self._backend: Optional[KokoroV1] = None  # Explicitly type as KokoroV1
        self._device: Optional[str] = None

    def _determine_device(self) -> str:
        """Return "cuda" when ``settings.use_gpu`` is truthy, else "cpu"."""
        return "cuda" if settings.use_gpu else "cpu"

    async def initialize(self) -> None:
        """Pick the device and create the Kokoro V1 backend.

        Raises:
            RuntimeError: If backend construction fails.
        """
        try:
            self._device = self._determine_device()
            logger.info(f"Initializing Kokoro V1 on {self._device}")
            self._backend = KokoroV1()
        except Exception as e:
            raise RuntimeError(f"Failed to initialize Kokoro V1: {e}") from e

    async def initialize_with_warmup(self, voice_manager) -> tuple[str, str, int]:
        """Initialize the backend, load the model and run a warmup generation.

        Args:
            voice_manager: Voice manager instance (accepted for interface
                compatibility; not used by this method).

        Returns:
            Tuple of (device, backend type, voice count).

        Raises:
            RuntimeError: If initialization or warmup fails.
            SystemExit: If the model file is missing (after logging download
                instructions).
        """
        import time

        start = time.perf_counter()

        try:
            # Initialize backend
            await self.initialize()

            # Load model
            model_path = self._config.pytorch_kokoro_v1_file
            await self.load_model(model_path)

            # Resolve the default voice and run one short generation.
            try:
                voices = await paths.list_voices()
                voice_path = await paths.get_voice_path(settings.default_voice)

                warmup_text = "Warmup text for initialization."
                voice_name = settings.default_voice
                logger.debug(f"Using default voice '{voice_name}' for warmup")
                # Drain the generator; the audio itself is discarded.
                async for _ in self.generate(warmup_text, (voice_name, voice_path)):
                    pass
            except Exception as e:
                raise RuntimeError(f"Failed to get default voice: {e}") from e

            ms = int((time.perf_counter() - start) * 1000)
            logger.info(f"Warmup completed in {ms}ms")

            return self._device, "kokoro_v1", len(voices)
        except FileNotFoundError:
            logger.error("""
Model files not found! You need to download the Kokoro V1 model:
1. Download model using the script:
python docker/scripts/download_model.py --output api/src/models/v1_0
2. Or set environment variable in docker-compose:
DOWNLOAD_MODEL=true
""")
            # The site-provided exit() helper is meant for interactive use and
            # may be absent; raise SystemExit directly instead.
            # NOTE(review): status 0 is kept for backward compatibility, but a
            # non-zero status would signal the failure to supervisors.
            raise SystemExit(0)
        except Exception as e:
            raise RuntimeError(f"Warmup failed: {e}") from e

    def get_backend(self) -> BaseModelBackend:
        """Get initialized backend.

        Returns:
            Initialized backend instance.

        Raises:
            RuntimeError: If backend not initialized.
        """
        if not self._backend:
            raise RuntimeError("Backend not initialized")
        return self._backend

    async def load_model(self, path: str) -> None:
        """Load model using initialized backend.

        Args:
            path: Path to model file.

        Raises:
            FileNotFoundError: Propagated unchanged from the backend.
            RuntimeError: If the backend is not initialized or loading fails
                for any other reason.
        """
        if not self._backend:
            raise RuntimeError("Backend not initialized")

        try:
            await self._backend.load_model(path)
        except FileNotFoundError:
            # Bare raise keeps the original traceback intact.
            raise
        except Exception as e:
            raise RuntimeError(f"Failed to load model: {e}") from e

    async def generate(self, *args, **kwargs):
        """Generate audio using initialized backend.

        All arguments are forwarded to the backend's ``generate``.

        Yields:
            Chunks produced by the backend.

        Raises:
            RuntimeError: If the backend is not initialized or generation fails.
        """
        if not self._backend:
            raise RuntimeError("Backend not initialized")

        try:
            async for chunk in self._backend.generate(*args, **kwargs):
                yield chunk
        except Exception as e:
            raise RuntimeError(f"Generation failed: {e}") from e

    def unload_all(self) -> None:
        """Unload model and free resources."""
        if self._backend:
            self._backend.unload()
            self._backend = None

    @property
    def current_backend(self) -> str:
        """Get current backend type (always "kokoro_v1")."""
        return "kokoro_v1"
async def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
    """Return the shared ModelManager, creating it on first use.

    Args:
        config: Optional configuration override; only used when the shared
            instance has not been created yet.

    Returns:
        The shared ModelManager instance.
    """
    manager = ModelManager._instance
    if manager is None:
        manager = ModelManager(config)
        ModelManager._instance = manager
    return manager
"""Kokoro V1 model management."""
from typing import Optional
from loguru import logger
from ..core import paths
from ..core.config import settings
from ..core.model_config import ModelConfig, model_config
from .base import BaseModelBackend
from .kokoro_v1 import KokoroV1
class ModelManager:
    """Manages Kokoro V1 model loading and inference."""

    # Singleton instance; filled in lazily by get_manager().
    _instance = None

    def __init__(self, config: Optional[ModelConfig] = None):
        """Initialize manager.

        Args:
            config: Optional model configuration override; the module-level
                ``model_config`` is used when omitted (or falsy).
        """
        self._config = config or model_config
        self._backend: Optional[KokoroV1] = None  # Explicitly type as KokoroV1
        self._device: Optional[str] = None

    def _determine_device(self) -> str:
        """Return "cuda" when ``settings.use_gpu`` is truthy, else "cpu"."""
        return "cuda" if settings.use_gpu else "cpu"

    async def initialize(self) -> None:
        """Pick the device and create the Kokoro V1 backend.

        Raises:
            RuntimeError: If backend construction fails.
        """
        try:
            self._device = self._determine_device()
            logger.info(f"Initializing Kokoro V1 on {self._device}")
            self._backend = KokoroV1()
        except Exception as e:
            raise RuntimeError(f"Failed to initialize Kokoro V1: {e}") from e

    async def initialize_with_warmup(self, voice_manager) -> tuple[str, str, int]:
        """Initialize the backend, load the model and run a warmup generation.

        Args:
            voice_manager: Voice manager instance (accepted for interface
                compatibility; not used by this method).

        Returns:
            Tuple of (device, backend type, voice count).

        Raises:
            RuntimeError: If initialization or warmup fails.
            SystemExit: If the model file is missing (after logging download
                instructions).
        """
        import time

        start = time.perf_counter()

        try:
            # Initialize backend
            await self.initialize()

            # Load model
            model_path = self._config.pytorch_kokoro_v1_file
            await self.load_model(model_path)

            # Resolve the default voice and run one short generation.
            try:
                voices = await paths.list_voices()
                voice_path = await paths.get_voice_path(settings.default_voice)

                warmup_text = "Warmup text for initialization."
                voice_name = settings.default_voice
                logger.debug(f"Using default voice '{voice_name}' for warmup")
                # Drain the generator; the audio itself is discarded.
                async for _ in self.generate(warmup_text, (voice_name, voice_path)):
                    pass
            except Exception as e:
                raise RuntimeError(f"Failed to get default voice: {e}") from e

            ms = int((time.perf_counter() - start) * 1000)
            logger.info(f"Warmup completed in {ms}ms")

            return self._device, "kokoro_v1", len(voices)
        except FileNotFoundError:
            logger.error("""
Model files not found! You need to download the Kokoro V1 model:
1. Download model using the script:
python docker/scripts/download_model.py --output api/src/models/v1_0
2. Or set environment variable in docker-compose:
DOWNLOAD_MODEL=true
""")
            # The site-provided exit() helper is meant for interactive use and
            # may be absent; raise SystemExit directly instead.
            # NOTE(review): status 0 is kept for backward compatibility, but a
            # non-zero status would signal the failure to supervisors.
            raise SystemExit(0)
        except Exception as e:
            raise RuntimeError(f"Warmup failed: {e}") from e

    def get_backend(self) -> BaseModelBackend:
        """Get initialized backend.

        Returns:
            Initialized backend instance.

        Raises:
            RuntimeError: If backend not initialized.
        """
        if not self._backend:
            raise RuntimeError("Backend not initialized")
        return self._backend

    async def load_model(self, path: str) -> None:
        """Load model using initialized backend.

        Args:
            path: Path to model file.

        Raises:
            FileNotFoundError: Propagated unchanged from the backend.
            RuntimeError: If the backend is not initialized or loading fails
                for any other reason.
        """
        if not self._backend:
            raise RuntimeError("Backend not initialized")

        try:
            await self._backend.load_model(path)
        except FileNotFoundError:
            # Bare raise keeps the original traceback intact.
            raise
        except Exception as e:
            raise RuntimeError(f"Failed to load model: {e}") from e

    async def generate(self, *args, **kwargs):
        """Generate audio using initialized backend.

        All arguments are forwarded to the backend's ``generate``.

        Yields:
            Chunks produced by the backend.

        Raises:
            RuntimeError: If the backend is not initialized or generation fails.
        """
        if not self._backend:
            raise RuntimeError("Backend not initialized")

        try:
            async for chunk in self._backend.generate(*args, **kwargs):
                yield chunk
        except Exception as e:
            raise RuntimeError(f"Generation failed: {e}") from e

    def unload_all(self) -> None:
        """Unload model and free resources."""
        if self._backend:
            self._backend.unload()
            self._backend = None

    @property
    def current_backend(self) -> str:
        """Get current backend type (always "kokoro_v1")."""
        return "kokoro_v1"
async def get_manager(config: Optional[ModelConfig] = None) -> ModelManager:
    """Return the shared ModelManager, creating it on first use.

    Args:
        config: Optional configuration override; only used when the shared
            instance has not been created yet.

    Returns:
        The shared ModelManager instance.
    """
    manager = ModelManager._instance
    if manager is None:
        manager = ModelManager(config)
        ModelManager._instance = manager
    return manager