diff --git a/api/src/builds/v1_0/config.json b/api/src/builds/v1_0/config.json new file mode 100644 index 0000000..25f35b9 --- /dev/null +++ b/api/src/builds/v1_0/config.json @@ -0,0 +1,172 @@ +{ + "istftnet": { + "upsample_kernel_sizes": [ + 20, + 12 + ], + "upsample_rates": [ + 10, + 6 + ], + "gen_istft_hop_size": 5, + "gen_istft_n_fft": 20, + "resblock_dilation_sizes": [ + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ] + ], + "resblock_kernel_sizes": [ + 3, + 7, + 11 + ], + "upsample_initial_channel": 512 + }, + "dim_in": 64, + "dropout": 0.2, + "hidden_dim": 512, + "max_conv_dim": 512, + "max_dur": 50, + "multispeaker": true, + "n_layer": 3, + "n_mels": 80, + "n_token": 178, + "style_dim": 128, + "text_encoder_kernel_size": 5, + "plbert": { + "hidden_size": 768, + "num_attention_heads": 12, + "intermediate_size": 2048, + "max_position_embeddings": 512, + "num_hidden_layers": 12, + "dropout": 0.1 + }, + "vocab": { + ";": 1, + ":": 2, + ",": 3, + ".": 4, + "!": 5, + "?": 6, + "—": 9, + "…": 10, + "\"": 11, + "(": 12, + ")": 13, + "“": 14, + "”": 15, + " ": 16, + "̃": 17, + "ʣ": 18, + "ʥ": 19, + "ʦ": 20, + "ʨ": 21, + "ᵝ": 22, + "ꭧ": 23, + "A": 24, + "I": 25, + "O": 31, + "Q": 33, + "S": 35, + "T": 36, + "W": 39, + "Y": 41, + "ᵊ": 42, + "a": 43, + "b": 44, + "c": 45, + "d": 46, + "e": 47, + "f": 48, + "h": 50, + "i": 51, + "j": 52, + "k": 53, + "l": 54, + "m": 55, + "n": 56, + "o": 57, + "p": 58, + "q": 59, + "r": 60, + "s": 61, + "t": 62, + "u": 63, + "v": 64, + "w": 65, + "x": 66, + "y": 67, + "z": 68, + "ɑ": 69, + "ɐ": 70, + "ɒ": 71, + "æ": 72, + "β": 75, + "ɔ": 76, + "ɕ": 77, + "ç": 78, + "ɖ": 80, + "ð": 81, + "ʤ": 82, + "ə": 83, + "ɚ": 85, + "ɛ": 86, + "ɜ": 87, + "ɟ": 90, + "ɡ": 92, + "ɥ": 99, + "ɨ": 101, + "ɪ": 102, + "ʝ": 103, + "ɯ": 110, + "ɰ": 111, + "ŋ": 112, + "ɳ": 113, + "ɲ": 114, + "ɴ": 115, + "ø": 116, + "ɸ": 118, + "θ": 119, + "œ": 120, + "ɹ": 123, + "ɾ": 125, + "ɻ": 126, + "ʁ": 128, + "ɽ": 129, + "ʂ": 130, + "ʃ": 131, + "ʈ": 132, + "ʧ": 133, + "ʊ": 135, + "ʋ": 136, + "ʌ": 138, + "ɣ": 139, + "ɤ": 140, + "χ": 142, + "ʎ": 143, + "ʒ": 147, + "ʔ": 148, + "ˈ": 156, + "ˌ": 157, + "ː": 158, + "ʰ": 162, + "ʲ": 164, + "↓": 169, + "→": 171, + "↗": 172, + "↘": 173, + "ᵻ": 177 + } +} \ No newline at end of file diff --git a/api/src/core/model_config.py b/api/src/core/model_config.py index 7da360e..21f55b1 100644 --- a/api/src/core/model_config.py +++ b/api/src/core/model_config.py @@ -2,6 +2,9 @@ from pydantic import BaseModel, Field +class KokoroV1Config(BaseModel): + languages: list[str] = ["en"] + class ONNXCPUConfig(BaseModel): """ONNX CPU runtime configuration.""" @@ -77,6 +80,7 @@ class ModelConfig(BaseModel): voice_cache_size: int = Field(2, description="Maximum number of cached voices") # Model filenames + pytorch_kokoro_v1_file: str = Field("v1_0/kokoro-v1_0.pth", description="PyTorch Kokoro V1 model filename") pytorch_model_file: str = Field("kokoro-v0_19-half.pth", description="PyTorch model filename") onnx_model_file: str = Field("kokoro-v0_19.onnx", description="ONNX model filename") @@ -93,7 +97,7 @@ class ModelConfig(BaseModel): """Get configuration for specific backend. 
Args: - backend_type: Backend type ('pytorch_cpu', 'pytorch_gpu', 'onnx_cpu', 'onnx_gpu') + backend_type: Backend type ('pytorch_cpu', 'pytorch_gpu', 'onnx_cpu', 'onnx_gpu', 'kokoro_v1') Returns: Backend-specific configuration @@ -102,7 +106,7 @@ class ModelConfig(BaseModel): ValueError: If backend type is invalid """ if backend_type not in { - 'pytorch_cpu', 'pytorch_gpu', 'onnx_cpu', 'onnx_gpu' + 'pytorch_cpu', 'pytorch_gpu', 'onnx_cpu', 'onnx_gpu', 'kokoro_v1' }: raise ValueError(f"Invalid backend type: {backend_type}") diff --git a/api/src/inference/kokoro_v1.py b/api/src/inference/kokoro_v1.py new file mode 100644 index 0000000..b294f80 --- /dev/null +++ b/api/src/inference/kokoro_v1.py @@ -0,0 +1,186 @@ +"""PyTorch inference backend with environment-based configuration.""" + +import gc +import os +from typing import AsyncGenerator, Optional, List, Union, Tuple +from contextlib import nullcontext + +import numpy as np +import torch +from loguru import logger + +from ..core import paths +from ..core.model_config import model_config +from ..core.config import settings +from .base import BaseModelBackend +from kokoro import KModel, KPipeline + + +class KokoroV1(BaseModelBackend): + """Kokoro package based inference backend with environment-based configuration.""" + + def __init__(self): + """Initialize backend based on environment configuration.""" + super().__init__() + + # Configure device based on settings + self._device = ( + "cuda" if settings.use_gpu and torch.cuda.is_available() else "cpu" + ) + self._model: Optional[KModel] = None + self._pipeline: Optional[KPipeline] = None + + async def load_model(self, path: str) -> None: + """Load Kokoro model. + + Args: + path: Path to model file + + Raises: + RuntimeError: If model loading fails + """ + try: + # Get verified model path + model_path = await paths.get_model_path(path) + + # Get config.json path from the same directory + config_path = os.path.join(os.path.dirname(model_path), 'config.json') + + if not os.path.exists(config_path): + raise RuntimeError(f"Config file not found: {config_path}") + + logger.info(f"Loading Kokoro model on {self._device}") + logger.info(f"Config path: {config_path}") + logger.info(f"Model path: {model_path}") + + # Initialize model with config and weights + self._model = KModel(config=config_path, model=model_path).to(self._device).eval() + # Initialize pipeline with American English by default + self._pipeline = KPipeline(lang_code='a', model=self._model, device=self._device) + + except Exception as e: + raise RuntimeError(f"Failed to load Kokoro model: {e}") + + async def generate( + self, text: str, voice: Union[str, Tuple[str, Union[torch.Tensor, str]]], speed: float = 1.0 + ) -> AsyncGenerator[np.ndarray, None]: + """Generate audio using model. 
+
+        Args:
+            text: Input text to synthesize
+            voice: Either a voice path string or a tuple of (voice_name, voice_tensor_or_path)
+            speed: Speed multiplier
+
+        Yields:
+            Generated audio chunks
+
+        Raises:
+            RuntimeError: If generation fails
+        """
+        if not self.is_loaded:
+            raise RuntimeError("Model not loaded")
+
+        try:
+            # Memory management for GPU
+            if self._device == "cuda":
+                if self._check_memory():
+                    self._clear_memory()
+
+            # Handle voice input
+            if isinstance(voice, str):
+                voice_path = voice  # Voice path provided directly
+                logger.debug(f"Using voice path directly: {voice_path}")
+                # Get language code from first letter of voice name
+                try:
+                    name = os.path.basename(voice_path)
+                    logger.debug(f"Voice basename: {name}")
+                    if name.endswith('.pt'):
+                        name = name[:-3]
+                    lang_code = name[0]
+                    logger.debug(f"Extracted language code: {lang_code}")
+                except Exception as e:
+                    # Default to American English if we can't get language code
+                    logger.warning(f"Failed to extract language code: {e}, defaulting to 'a'")
+                    lang_code = 'a'
+            else:
+                # Unpack voice name and tensor/path
+                voice_name, voice_data = voice
+                # If voice_data is a path, use it directly
+                if isinstance(voice_data, str):
+                    voice_path = voice_data
+                    logger.debug(f"Using provided voice path: {voice_path}")
+                else:
+                    # Save tensor to temporary file
+                    import tempfile
+                    temp_dir = tempfile.gettempdir()
+                    voice_path = os.path.join(temp_dir, f"{voice_name}.pt")
+                    logger.debug(f"Saving voice tensor to: {voice_path}")
+                    torch.save(voice_data, voice_path)
+                # Get language code from voice name
+                lang_code = voice_name[0]
+                logger.debug(f"Using language code '{lang_code}' from voice name {voice_name}")
+
+            # Update pipeline's language code if needed
+            if self._pipeline.lang_code != lang_code:
+                logger.debug(f"Creating pipeline with lang_code='{lang_code}'")
+                self._pipeline = KPipeline(lang_code=lang_code, model=self._model, device=self._device)
+
+            # Generate audio using pipeline
+            logger.debug(f"Generating audio for text: '{text[:100]}...'")
+            for i, result in enumerate(self._pipeline(text, voice=voice_path, speed=speed)):
+                logger.debug(f"Processing chunk {i+1}")
+                if result.audio is not None:
+                    logger.debug(f"Got audio chunk {i+1} with shape: {result.audio.shape}")
+                    yield result.audio.numpy()
+                else:
+                    logger.warning(f"No audio in chunk {i+1}")
+
+        except Exception as e:
+            logger.error(f"Generation failed: {e}")
+            if (
+                self._device == "cuda"
+                and model_config.pytorch_gpu.retry_on_oom
+                and "out of memory" in str(e).lower()
+            ):
+                self._clear_memory()
+                async for chunk in self.generate(text, voice, speed):
+                    yield chunk
+                return  # Retry succeeded; don't re-raise the original OOM error
+            raise
+        finally:
+            if self._device == "cuda" and model_config.pytorch_gpu.sync_cuda:
+                torch.cuda.synchronize()
+
+    def _check_memory(self) -> bool:
+        """Check if memory usage is above threshold."""
+        if self._device == "cuda":
+            memory_gb = torch.cuda.memory_allocated() / 1e9
+            return memory_gb > model_config.pytorch_gpu.memory_threshold
+        return False
+
+    def _clear_memory(self) -> None:
+        """Clear device memory."""
+        if self._device == "cuda":
+            torch.cuda.empty_cache()
+            gc.collect()
+
+    def unload(self) -> None:
+        """Unload model and free resources."""
+        if self._model is not None:
+            del self._model
+            self._model = None
+        if self._pipeline is not None:
+            del self._pipeline
+            self._pipeline = None
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        gc.collect()
+
+    @property
+    def is_loaded(self) -> bool:
+        """Check if model is loaded."""
+        return self._model is not None and self._pipeline is not None
+
+    
@property + def device(self) -> str: + """Get device model is running on.""" + return self._device diff --git a/api/src/inference/model_manager.py b/api/src/inference/model_manager.py index 5fe53c9..a0e87e3 100644 --- a/api/src/inference/model_manager.py +++ b/api/src/inference/model_manager.py @@ -1,9 +1,10 @@ """Model management and caching.""" import asyncio -from typing import Dict, Optional, Tuple +from typing import Dict, Optional, Tuple, Union, AsyncGenerator import torch +import numpy as np from loguru import logger from ..core import paths @@ -13,6 +14,7 @@ from .base import BaseModelBackend from .onnx_cpu import ONNXCPUBackend from .onnx_gpu import ONNXGPUBackend from .pytorch_backend import PyTorchBackend +from .kokoro_v1 import KokoroV1 from .session_pool import CPUSessionPool, StreamingSessionPool @@ -56,7 +58,13 @@ class ModelManager: device = self._determine_device() try: - if device == "cuda": + # First check if we should use Kokoro V1 + if model_config.pytorch_kokoro_v1_file: + self._backends['kokoro_v1'] = KokoroV1() + self._current_backend = 'kokoro_v1' + logger.info(f"Initialized new Kokoro V1 backend on {device}") + # Otherwise use legacy backends + elif device == "cuda": if settings.use_onnx: self._backends['onnx_gpu'] = ONNXGPUBackend() self._current_backend = 'onnx_gpu' @@ -93,8 +101,11 @@ class ModelManager: RuntimeError: If initialization fails """ try: - # Determine backend type based on settings - if settings.use_onnx: + # First check if we should use Kokoro V1 + if model_config.pytorch_kokoro_v1_file: + backend_type = 'kokoro_v1' + # Otherwise determine legacy backend type + elif settings.use_onnx: backend_type = 'onnx_gpu' if settings.use_gpu and torch.cuda.is_available() else 'onnx_cpu' else: backend_type = 'pytorch' @@ -103,17 +114,26 @@ class ModelManager: backend = self.get_backend(backend_type) # Get and verify model path - model_file = model_config.pytorch_model_file if not settings.use_onnx else model_config.onnx_model_file + if backend_type == 'kokoro_v1': + model_file = model_config.pytorch_kokoro_v1_file + else: + model_file = model_config.pytorch_model_file if not settings.use_onnx else model_config.onnx_model_file model_path = await paths.get_model_path(model_file) if not await paths.verify_model_path(model_path): raise RuntimeError(f"Model file not found: {model_path}") # Pre-cache default voice and use for warmup - warmup_voice = await voice_manager.load_voice( + warmup_voice_tensor = await voice_manager.load_voice( settings.default_voice, device=backend.device) logger.info(f"Pre-cached voice {settings.default_voice} for warmup") + # For Kokoro V1, wrap voice in tuple with name + if isinstance(backend, KokoroV1): + warmup_voice = (settings.default_voice, warmup_voice_tensor) + else: + warmup_voice = warmup_voice_tensor + # Initialize model with warmup voice await self.load_model(model_path, warmup_voice, backend_type) @@ -126,7 +146,7 @@ class ModelManager: # Get device info for return device = "GPU" if settings.use_gpu else "CPU" - model = "ONNX" if settings.use_onnx else "PyTorch" + model = "Kokoro V1" if backend_type == 'kokoro_v1' else ("ONNX" if settings.use_onnx else "PyTorch") return device, model, voicepack_count @@ -137,7 +157,7 @@ class ModelManager: def get_backend(self, backend_type: Optional[str] = None) -> BaseModelBackend: """Get specified backend. 
Args: - backend_type: Backend type ('pytorch_cpu', 'pytorch_gpu', 'onnx_cpu', 'onnx_gpu'), + backend_type: Backend type ('pytorch_cpu', 'pytorch_gpu', 'onnx_cpu', 'onnx_gpu', 'kokoro_v1'), uses default if None Returns: Model backend instance @@ -166,15 +186,18 @@ class ModelManager: Returns: Backend type to use """ - # If ONNX is preferred or model is ONNX format - if settings.use_onnx or model_path.lower().endswith('.onnx'): + # Check if it's a Kokoro V1 model + if model_path.endswith(model_config.pytorch_kokoro_v1_file): + return 'kokoro_v1' + # Otherwise use legacy backend determination + elif settings.use_onnx or model_path.lower().endswith('.onnx'): return 'onnx_gpu' if settings.use_gpu and torch.cuda.is_available() else 'onnx_cpu' return 'pytorch' async def load_model( self, model_path: str, - warmup_voice: Optional[torch.Tensor] = None, + warmup_voice: Optional[Union[str, Tuple[str, torch.Tensor]]] = None, backend_type: Optional[str] = None ) -> None: """Load model on specified backend. @@ -206,7 +229,7 @@ class ModelManager: self._loaded_models[backend_type] = abs_path logger.info(f"Fetched model instance from {backend_type} pool") - # For PyTorch backends, load normally + # For PyTorch and Kokoro backends, load normally else: # Check if model is already loaded if (backend_type in self._loaded_models and @@ -229,27 +252,34 @@ class ModelManager: self._loaded_models.pop(backend_type, None) raise RuntimeError(f"Failed to load model: {e}") - async def _warmup_inference(self, backend: BaseModelBackend, voice: torch.Tensor) -> None: + async def _warmup_inference( + self, + backend: BaseModelBackend, + voice: Union[str, Tuple[str, torch.Tensor]] + ) -> None: """Run warmup inference to initialize model. Args: backend: Model backend to warm up - voice: Voice tensor already loaded on correct device + voice: Voice path or (name, tensor) tuple """ try: - # Import here to avoid circular imports - from ..services.text_processing import process_text - - # Use real text + # Use real text for warmup text = "Testing text to speech synthesis." - # Process through pipeline - tokens = process_text(text) - if not tokens: - raise ValueError("Text processing failed") - # Run inference - backend.generate(tokens, voice, speed=1.0) + if isinstance(backend, KokoroV1): + async for _ in backend.generate(text, voice, speed=1.0): + pass # Just run through the chunks + else: + # Import here to avoid circular imports + from ..services.text_processing import process_text + tokens = process_text(text) + if not tokens: + raise ValueError("Text processing failed") + # For legacy backends, extract tensor if needed + voice_tensor = voice[1] if isinstance(voice, tuple) else voice + backend.generate(tokens, voice_tensor, speed=1.0) logger.debug("Completed warmup inference") except Exception as e: @@ -258,21 +288,21 @@ class ModelManager: async def generate( self, - tokens: list[int], - voice: torch.Tensor, + input_text: str, + voice: Union[str, Tuple[str, torch.Tensor]], speed: float = 1.0, backend_type: Optional[str] = None - ) -> torch.Tensor: + ) -> AsyncGenerator[np.ndarray, None]: """Generate audio using specified backend. 
Args: - tokens: Input token IDs - voice: Voice tensor already loaded on correct device + input_text: Input text to synthesize + voice: Voice path or (name, tensor) tuple speed: Speed multiplier backend_type: Backend to use, uses default if None - Returns: - Generated audio tensor + Yields: + Generated audio chunks Raises: RuntimeError: If generation fails @@ -282,9 +312,20 @@ class ModelManager: raise RuntimeError("Model not loaded") try: - # Generate audio using provided voice tensor + # Generate audio using provided voice # No lock needed here since inference is thread-safe - return backend.generate(tokens, voice, speed) + if isinstance(backend, KokoroV1): + async for chunk in backend.generate(input_text, voice, speed): + yield chunk + else: + # Import here to avoid circular imports + from ..services.text_processing import process_text + tokens = process_text(input_text) + if not tokens: + raise ValueError("Text processing failed") + # For legacy backends, extract tensor if needed + voice_tensor = voice[1] if isinstance(voice, tuple) else voice + yield backend.generate(tokens, voice_tensor, speed) except Exception as e: raise RuntimeError(f"Generation failed: {e}") @@ -294,7 +335,7 @@ class ModelManager: for pool in self._session_pools.values(): pool.cleanup() - # Unload PyTorch backends + # Unload all backends for backend in self._backends.values(): backend.unload() @@ -303,14 +344,12 @@ class ModelManager: @property def available_backends(self) -> list[str]: - """Get list of available backends. - """ + """Get list of available backends.""" return list(self._backends.keys()) @property def current_backend(self) -> str: - """Get current default backend. - """ + """Get current default backend.""" return self._current_backend @@ -336,4 +375,3 @@ async def get_manager(config: Optional[ModelConfig] = None) -> ModelManager: _manager_instance = ModelManager(config) await _manager_instance.initialize() return _manager_instance - diff --git a/api/src/inference/voice_manager.py b/api/src/inference/voice_manager.py index 56557e6..6ac6545 100644 --- a/api/src/inference/voice_manager.py +++ b/api/src/inference/voice_manager.py @@ -8,6 +8,7 @@ from loguru import logger from ..core import paths from ..core.config import settings +from ..core.model_config import model_config from ..structures.model_schemas import VoiceConfig @@ -33,8 +34,28 @@ class VoiceManager: Path to voice file if exists, None otherwise """ api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) - voice_path = os.path.join(api_dir, settings.voices_dir, f"{voice_name}.pt") - return voice_path if os.path.exists(voice_path) else None + voices_dir = os.path.join(api_dir, settings.voices_dir) + + logger.debug(f"Looking for voice: {voice_name}") + logger.debug(f"Base voices directory: {voices_dir}") + + # Check v1_0 subdirectory first if using Kokoro V1 + if model_config.pytorch_kokoro_v1_file: + v1_path = os.path.join(voices_dir, 'v1_0', f"{voice_name}.pt") + logger.debug(f"Checking v1_0 path: {v1_path}") + if os.path.exists(v1_path): + logger.debug(f"Found voice in v1_0: {v1_path}") + return v1_path + + # Fall back to main voices directory + voice_path = os.path.join(voices_dir, f"{voice_name}.pt") + logger.debug(f"Checking main path: {voice_path}") + if os.path.exists(voice_path): + logger.debug(f"Found voice in main dir: {voice_path}") + return voice_path + + logger.debug(f"Voice not found: {voice_name}") + return None async def load_voice(self, voice_name: str, device: str = "cpu") -> torch.Tensor: """Load voice tensor. 
@@ -74,10 +95,12 @@ class VoiceManager: # Check cache cache_key = f"{voice_path}_{device}" if self._config.use_cache and cache_key in self._voice_cache: + logger.debug(f"Using cached voice: {voice_name} from {voice_path}") return self._voice_cache[cache_key] # Load voice tensor try: + logger.debug(f"Loading voice tensor from: {voice_path}") voice = await paths.load_voice_tensor(voice_path, device=device) except Exception as e: raise RuntimeError(f"Failed to load voice {voice_name}: {e}") @@ -86,7 +109,7 @@ class VoiceManager: if self._config.use_cache: self._manage_cache() self._voice_cache[cache_key] = voice - logger.debug(f"Cached voice: {voice_name} on {device}") + logger.debug(f"Cached voice: {voice_name} on {device} from {voice_path}") return voice @@ -128,6 +151,11 @@ class VoiceManager: # Save to disk api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) voices_dir = os.path.join(api_dir, settings.voices_dir) + + # Save in v1_0 directory if using Kokoro V1 + if model_config.pytorch_kokoro_v1_file: + voices_dir = os.path.join(voices_dir, 'v1_0') + os.makedirs(voices_dir, exist_ok=True) combined_path = os.path.join(voices_dir, f"{combined_name}.pt") @@ -157,9 +185,21 @@ class VoiceManager: voices_dir = os.path.join(api_dir, settings.voices_dir) os.makedirs(voices_dir, exist_ok=True) - for entry in os.listdir(voices_dir): - if entry.endswith(".pt"): - voices.add(entry[:-3]) + # Check v1_0 subdirectory if using Kokoro V1 + if model_config.pytorch_kokoro_v1_file: + v1_dir = os.path.join(voices_dir, 'v1_0') + logger.debug(f"Checking v1_0 directory: {v1_dir}") + if os.path.exists(v1_dir): + for entry in os.listdir(v1_dir): + if entry.endswith(".pt"): + voices.add(entry[:-3]) + logger.debug(f"Found v1_0 voice: {entry[:-3]}") + else: + # Check main voices directory + for entry in os.listdir(voices_dir): + if entry.endswith(".pt"): + voices.add(entry[:-3]) + logger.debug(f"Found main voice: {entry[:-3]}") except Exception as e: logger.error(f"Error listing voices: {e}") @@ -177,7 +217,7 @@ class VoiceManager: try: if not os.path.exists(voice_path): return False - voice = torch.load(voice_path, map_location="cpu") + voice = torch.load(voice_path, map_location="cpu", weights_only=False) return isinstance(voice, torch.Tensor) except Exception: return False diff --git a/api/src/main.py b/api/src/main.py index 5932020..9c74fcf 100644 --- a/api/src/main.py +++ b/api/src/main.py @@ -30,7 +30,7 @@ def setup_logger(): "{level: <8} | " "{message}", "colorize": True, - "level": "INFO", + "level": "DEBUG", }, ], } diff --git a/api/src/models/v1_0/config.json b/api/src/models/v1_0/config.json new file mode 100644 index 0000000..25f35b9 --- /dev/null +++ b/api/src/models/v1_0/config.json @@ -0,0 +1,172 @@ +{ + "istftnet": { + "upsample_kernel_sizes": [ + 20, + 12 + ], + "upsample_rates": [ + 10, + 6 + ], + "gen_istft_hop_size": 5, + "gen_istft_n_fft": 20, + "resblock_dilation_sizes": [ + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ] + ], + "resblock_kernel_sizes": [ + 3, + 7, + 11 + ], + "upsample_initial_channel": 512 + }, + "dim_in": 64, + "dropout": 0.2, + "hidden_dim": 512, + "max_conv_dim": 512, + "max_dur": 50, + "multispeaker": true, + "n_layer": 3, + "n_mels": 80, + "n_token": 178, + "style_dim": 128, + "text_encoder_kernel_size": 5, + "plbert": { + "hidden_size": 768, + "num_attention_heads": 12, + "intermediate_size": 2048, + "max_position_embeddings": 512, + "num_hidden_layers": 12, + "dropout": 0.1 + }, + "vocab": { + ";": 1, + ":": 2, + ",": 3, + ".": 4, + 
"!": 5, + "?": 6, + "—": 9, + "…": 10, + "\"": 11, + "(": 12, + ")": 13, + "“": 14, + "”": 15, + " ": 16, + "̃": 17, + "ʣ": 18, + "ʥ": 19, + "ʦ": 20, + "ʨ": 21, + "ᵝ": 22, + "ꭧ": 23, + "A": 24, + "I": 25, + "O": 31, + "Q": 33, + "S": 35, + "T": 36, + "W": 39, + "Y": 41, + "ᵊ": 42, + "a": 43, + "b": 44, + "c": 45, + "d": 46, + "e": 47, + "f": 48, + "h": 50, + "i": 51, + "j": 52, + "k": 53, + "l": 54, + "m": 55, + "n": 56, + "o": 57, + "p": 58, + "q": 59, + "r": 60, + "s": 61, + "t": 62, + "u": 63, + "v": 64, + "w": 65, + "x": 66, + "y": 67, + "z": 68, + "ɑ": 69, + "ɐ": 70, + "ɒ": 71, + "æ": 72, + "β": 75, + "ɔ": 76, + "ɕ": 77, + "ç": 78, + "ɖ": 80, + "ð": 81, + "ʤ": 82, + "ə": 83, + "ɚ": 85, + "ɛ": 86, + "ɜ": 87, + "ɟ": 90, + "ɡ": 92, + "ɥ": 99, + "ɨ": 101, + "ɪ": 102, + "ʝ": 103, + "ɯ": 110, + "ɰ": 111, + "ŋ": 112, + "ɳ": 113, + "ɲ": 114, + "ɴ": 115, + "ø": 116, + "ɸ": 118, + "θ": 119, + "œ": 120, + "ɹ": 123, + "ɾ": 125, + "ɻ": 126, + "ʁ": 128, + "ɽ": 129, + "ʂ": 130, + "ʃ": 131, + "ʈ": 132, + "ʧ": 133, + "ʊ": 135, + "ʋ": 136, + "ʌ": 138, + "ɣ": 139, + "ɤ": 140, + "χ": 142, + "ʎ": 143, + "ʒ": 147, + "ʔ": 148, + "ˈ": 156, + "ˌ": 157, + "ː": 158, + "ʰ": 162, + "ʲ": 164, + "↓": 169, + "→": 171, + "↗": 172, + "↘": 173, + "ᵻ": 177 + } +} \ No newline at end of file diff --git a/api/src/services/tts_service.py b/api/src/services/tts_service.py index 0b37ea9..bcaf830 100644 --- a/api/src/services/tts_service.py +++ b/api/src/services/tts_service.py @@ -1,9 +1,11 @@ """TTS service using model and voice managers.""" +import os import time +import tempfile from typing import List, Tuple, Optional, AsyncGenerator, Union -import asyncio +import asyncio import numpy as np import torch from loguru import logger @@ -14,6 +16,8 @@ from ..inference.voice_manager import get_manager as get_voice_manager from .audio import AudioNormalizer, AudioService from .text_processing.text_processor import process_text_chunk, smart_split from .text_processing import tokenize +from ..inference.kokoro_v1 import KokoroV1 + class TTSService: """Text-to-speech service.""" @@ -37,14 +41,16 @@ class TTSService: async def _process_chunk( self, + chunk_text: str, tokens: List[int], - voice_tensor: torch.Tensor, + voice_name: str, + voice_path: str, speed: float, output_format: Optional[str] = None, is_first: bool = False, is_last: bool = False, normalizer: Optional[AudioNormalizer] = None, - ) -> Optional[Union[np.ndarray, bytes]]: + ) -> AsyncGenerator[Union[np.ndarray, bytes], None]: """Process tokens into audio.""" async with self._chunk_semaphore: try: @@ -52,9 +58,10 @@ class TTSService: if is_last: # Skip format conversion for raw audio mode if not output_format: - return np.array([], dtype=np.float32) + yield np.array([], dtype=np.float32) + return - return await AudioService.convert_audio( + result = await AudioService.convert_audio( np.array([0], dtype=np.float32), # Dummy data for type checking 24000, output_format, @@ -62,45 +69,126 @@ class TTSService: normalizer=normalizer, is_last_chunk=True ) + yield result + return # Skip empty chunks - if not tokens: - return None + if not tokens and not chunk_text: + return + + # Get backend + backend = self.model_manager.get_backend() # Generate audio using pre-warmed model - chunk_audio = await self.model_manager.generate( - tokens, - voice_tensor, - speed=speed - ) - - if chunk_audio is None: - logger.error("Model generated None for audio chunk") - return None - - if len(chunk_audio) == 0: - logger.error("Model generated empty audio chunk") - return None + if 
isinstance(backend, KokoroV1): + # For Kokoro V1, pass text and voice info + async for chunk_audio in self.model_manager.generate( + chunk_text, + (voice_name, voice_path), + speed=speed + ): + # For streaming, convert to bytes + if output_format: + try: + converted = await AudioService.convert_audio( + chunk_audio, + 24000, + output_format, + is_first_chunk=is_first, + normalizer=normalizer, + is_last_chunk=is_last + ) + yield converted + except Exception as e: + logger.error(f"Failed to convert audio: {str(e)}") + else: + yield chunk_audio + else: + # For legacy backends, load voice tensor + voice_tensor = await self._voice_manager.load_voice(voice_name, device=backend.device) + chunk_audio = await self.model_manager.generate( + tokens, + voice_tensor, + speed=speed + ) - # For streaming, convert to bytes - if output_format: - try: - return await AudioService.convert_audio( - chunk_audio, - 24000, - output_format, - is_first_chunk=is_first, - normalizer=normalizer, - is_last_chunk=is_last - ) - except Exception as e: - logger.error(f"Failed to convert audio: {str(e)}") - return None + if chunk_audio is None: + logger.error("Model generated None for audio chunk") + return - return chunk_audio + if len(chunk_audio) == 0: + logger.error("Model generated empty audio chunk") + return + + # For streaming, convert to bytes + if output_format: + try: + converted = await AudioService.convert_audio( + chunk_audio, + 24000, + output_format, + is_first_chunk=is_first, + normalizer=normalizer, + is_last_chunk=is_last + ) + yield converted + except Exception as e: + logger.error(f"Failed to convert audio: {str(e)}") + else: + yield chunk_audio except Exception as e: logger.error(f"Failed to process tokens: {str(e)}") - return None + + async def _get_voice_path(self, voice: str) -> Tuple[str, str]: + """Get voice path, handling combined voices. 
+ + Args: + voice: Voice name or combined voice names (e.g., 'af_jadzia+af_jessica') + + Returns: + Tuple of (voice name to use, voice path to use) + + Raises: + RuntimeError: If voice not found + """ + try: + # Check if it's a combined voice + if "+" in voice: + voices = [v.strip() for v in voice.split("+") if v.strip()] + if len(voices) < 2: + raise RuntimeError(f"Invalid combined voice name: {voice}") + + # Load and combine voices + voice_tensors = [] + for v in voices: + path = self._voice_manager.get_voice_path(v) + if not path: + raise RuntimeError(f"Voice not found: {v}") + logger.debug(f"Loading voice tensor from: {path}") + voice_tensor = torch.load(path, map_location="cpu") + voice_tensors.append(voice_tensor) + + # Average the voice tensors + logger.debug(f"Combining {len(voice_tensors)} voice tensors") + combined = torch.mean(torch.stack(voice_tensors), dim=0) + + # Save combined tensor + temp_dir = tempfile.gettempdir() + combined_path = os.path.join(temp_dir, f"{voice}.pt") + logger.debug(f"Saving combined voice to: {combined_path}") + torch.save(combined, combined_path) + + return voice, combined_path + else: + # Single voice + path = self._voice_manager.get_voice_path(voice) + if not path: + raise RuntimeError(f"Voice not found: {voice}") + logger.debug(f"Using single voice path: {path}") + return voice, path + except Exception as e: + logger.error(f"Failed to get voice path: {e}") + raise async def generate_audio_stream( self, @@ -111,33 +199,36 @@ class TTSService: ) -> AsyncGenerator[bytes, None]: """Generate and stream audio chunks.""" stream_normalizer = AudioNormalizer() - voice_tensor = None chunk_index = 0 try: - # Get backend and load voice (should be fast if cached) + # Get backend backend = self.model_manager.get_backend() - voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device) + + # Get voice path, handling combined voices + voice_name, voice_path = await self._get_voice_path(voice) + logger.debug(f"Using voice path: {voice_path}") # Process text in chunks with smart splitting async for chunk_text, tokens in smart_split(text): try: # Process audio for chunk - result = await self._process_chunk( - tokens, # Now always a flat List[int] - voice_tensor, + async for result in self._process_chunk( + chunk_text, # Pass text for Kokoro V1 + tokens, # Pass tokens for legacy backends + voice_name, # Pass voice name + voice_path, # Pass voice path speed, output_format, is_first=(chunk_index == 0), is_last=False, # We'll update the last chunk later normalizer=stream_normalizer - ) - - if result is not None: - yield result - chunk_index += 1 - else: - logger.warning(f"No audio generated for chunk: '{chunk_text[:100]}...'") + ): + if result is not None: + yield result + chunk_index += 1 + else: + logger.warning(f"No audio generated for chunk: '{chunk_text[:100]}...'") except Exception as e: logger.error(f"Failed to process audio for chunk: '{chunk_text[:100]}...'. 
Error: {str(e)}") @@ -147,81 +238,25 @@ class TTSService: if chunk_index > 0: try: # Empty tokens list to finalize audio - final_result = await self._process_chunk( - [], # Empty tokens list - voice_tensor, + async for result in self._process_chunk( + "", # Empty text + [], # Empty tokens + voice_name, + voice_path, speed, output_format, is_first=False, - is_last=True, + is_last=True, # Signal this is the last chunk normalizer=stream_normalizer - ) - if final_result is not None: - logger.debug("Yielding final chunk to finalize audio") - yield final_result - else: - logger.warning("Final chunk processing returned None") + ): + if result is not None: + yield result except Exception as e: - logger.error(f"Failed to process final chunk: {str(e)}") - else: - logger.warning("No audio chunks were successfully processed") - - except Exception as e: - logger.error(f"Error in audio generation stream: {str(e)}") - raise - finally: - if voice_tensor is not None: - del voice_tensor - torch.cuda.empty_cache() - - async def generate_from_phonemes( - self, phonemes: str, voice: str, speed: float = 1.0 - ) -> Tuple[np.ndarray, float]: - """Generate audio from phonemes. - - Args: - phonemes: Phoneme string to synthesize - voice: Voice ID to use - speed: Speed multiplier - - Returns: - Tuple of (audio array, processing time) - """ - start_time = time.time() - voice_tensor = None - - try: - # Get backend and load voice - backend = self.model_manager.get_backend() - voice_tensor = await self._voice_manager.load_voice(voice, device=backend.device) - - # Convert phonemes to tokens - tokens = tokenize(phonemes) - if len(tokens) > 500: # Model context limit - raise ValueError(f"Phoneme sequence too long ({len(tokens)} tokens, max 500)") - - tokens = [0] + tokens + [0] # Add start/end tokens - - # Generate audio - audio = await self.model_manager.generate( - tokens, - voice_tensor, - speed=speed - ) - - if audio is None: - raise ValueError("Failed to generate audio") - - processing_time = time.time() - start_time - return audio, processing_time + logger.error(f"Failed to finalize audio stream: {str(e)}") except Exception as e: logger.error(f"Error in phoneme audio generation: {str(e)}") raise - finally: - if voice_tensor is not None: - del voice_tensor - torch.cuda.empty_cache() async def generate_audio( self, text: str, voice: str, speed: float = 1.0 @@ -263,4 +298,4 @@ class TTSService: async def list_voices(self) -> List[str]: """List available voices.""" - return await self._voice_manager.list_voices() \ No newline at end of file + return await self._voice_manager.list_voices() diff --git a/api/src/voices/v1_0/af_bella.pt b/api/src/voices/v1_0/af_bella.pt new file mode 100644 index 0000000..02e6874 Binary files /dev/null and b/api/src/voices/v1_0/af_bella.pt differ diff --git a/api/src/voices/v1_0/af_heart.pt b/api/src/voices/v1_0/af_heart.pt new file mode 100644 index 0000000..508d04a Binary files /dev/null and b/api/src/voices/v1_0/af_heart.pt differ diff --git a/api/src/voices/v1_0/af_jadzia.pt b/api/src/voices/v1_0/af_jadzia.pt new file mode 100644 index 0000000..a96fbec Binary files /dev/null and b/api/src/voices/v1_0/af_jadzia.pt differ diff --git a/api/src/voices/v1_0/af_nicole.pt b/api/src/voices/v1_0/af_nicole.pt new file mode 100644 index 0000000..207da10 Binary files /dev/null and b/api/src/voices/v1_0/af_nicole.pt differ diff --git a/api/src/voices/v1_0/af_sarah.pt b/api/src/voices/v1_0/af_sarah.pt new file mode 100644 index 0000000..b488607 Binary files /dev/null and 
b/api/src/voices/v1_0/af_sarah.pt differ diff --git a/api/src/voices/v1_0/af_sky.pt b/api/src/voices/v1_0/af_sky.pt new file mode 100644 index 0000000..a047bc6 Binary files /dev/null and b/api/src/voices/v1_0/af_sky.pt differ diff --git a/api/src/voices/v1_0/am_adam.pt b/api/src/voices/v1_0/am_adam.pt new file mode 100644 index 0000000..f36cd12 Binary files /dev/null and b/api/src/voices/v1_0/am_adam.pt differ diff --git a/api/src/voices/v1_0/am_michael.pt b/api/src/voices/v1_0/am_michael.pt new file mode 100644 index 0000000..b755169 Binary files /dev/null and b/api/src/voices/v1_0/am_michael.pt differ diff --git a/api/src/voices/v1_0/am_santa.pt b/api/src/voices/v1_0/am_santa.pt new file mode 100644 index 0000000..26b612a Binary files /dev/null and b/api/src/voices/v1_0/am_santa.pt differ diff --git a/api/src/voices/v1_0/bf_emma.pt b/api/src/voices/v1_0/bf_emma.pt new file mode 100644 index 0000000..c020e85 Binary files /dev/null and b/api/src/voices/v1_0/bf_emma.pt differ diff --git a/api/src/voices/v1_0/bf_isabella.pt b/api/src/voices/v1_0/bf_isabella.pt new file mode 100644 index 0000000..c9e2d73 Binary files /dev/null and b/api/src/voices/v1_0/bf_isabella.pt differ diff --git a/api/src/voices/v1_0/bm_george.pt b/api/src/voices/v1_0/bm_george.pt new file mode 100644 index 0000000..3077341 Binary files /dev/null and b/api/src/voices/v1_0/bm_george.pt differ diff --git a/api/src/voices/v1_0/bm_lewis.pt b/api/src/voices/v1_0/bm_lewis.pt new file mode 100644 index 0000000..529102a Binary files /dev/null and b/api/src/voices/v1_0/bm_lewis.pt differ diff --git a/api/src/voices/v1_0/ef_dora.pt b/api/src/voices/v1_0/ef_dora.pt new file mode 100644 index 0000000..3c2ec76 Binary files /dev/null and b/api/src/voices/v1_0/ef_dora.pt differ diff --git a/api/src/voices/v1_0/em_alex.pt b/api/src/voices/v1_0/em_alex.pt new file mode 100644 index 0000000..bee5535 Binary files /dev/null and b/api/src/voices/v1_0/em_alex.pt differ diff --git a/api/src/voices/v1_0/em_santa.pt b/api/src/voices/v1_0/em_santa.pt new file mode 100644 index 0000000..8d8a09c Binary files /dev/null and b/api/src/voices/v1_0/em_santa.pt differ diff --git a/api/src/voices/v1_0/ff_siwis.pt b/api/src/voices/v1_0/ff_siwis.pt new file mode 100644 index 0000000..5910d6a Binary files /dev/null and b/api/src/voices/v1_0/ff_siwis.pt differ diff --git a/api/src/voices/v1_0/hf_alpha.pt b/api/src/voices/v1_0/hf_alpha.pt new file mode 100644 index 0000000..3adbf3f Binary files /dev/null and b/api/src/voices/v1_0/hf_alpha.pt differ diff --git a/api/src/voices/v1_0/hf_beta.pt b/api/src/voices/v1_0/hf_beta.pt new file mode 100644 index 0000000..410c5cc Binary files /dev/null and b/api/src/voices/v1_0/hf_beta.pt differ diff --git a/api/src/voices/v1_0/hm_omega.pt b/api/src/voices/v1_0/hm_omega.pt new file mode 100644 index 0000000..6b1979a Binary files /dev/null and b/api/src/voices/v1_0/hm_omega.pt differ diff --git a/api/src/voices/v1_0/hm_psi.pt b/api/src/voices/v1_0/hm_psi.pt new file mode 100644 index 0000000..7172b5a Binary files /dev/null and b/api/src/voices/v1_0/hm_psi.pt differ diff --git a/api/src/voices/v1_0/if_sara.pt b/api/src/voices/v1_0/if_sara.pt new file mode 100644 index 0000000..92dc90a Binary files /dev/null and b/api/src/voices/v1_0/if_sara.pt differ diff --git a/api/src/voices/v1_0/im_nicola.pt b/api/src/voices/v1_0/im_nicola.pt new file mode 100644 index 0000000..935172f Binary files /dev/null and b/api/src/voices/v1_0/im_nicola.pt differ diff --git a/api/src/voices/v1_0/jf_alpha.pt b/api/src/voices/v1_0/jf_alpha.pt new 
file mode 100644 index 0000000..28077d0 Binary files /dev/null and b/api/src/voices/v1_0/jf_alpha.pt differ diff --git a/api/src/voices/v1_0/jf_gongitsune.pt b/api/src/voices/v1_0/jf_gongitsune.pt new file mode 100644 index 0000000..832500d Binary files /dev/null and b/api/src/voices/v1_0/jf_gongitsune.pt differ diff --git a/api/src/voices/v1_0/jf_nezumi.pt b/api/src/voices/v1_0/jf_nezumi.pt new file mode 100644 index 0000000..e9d02df Binary files /dev/null and b/api/src/voices/v1_0/jf_nezumi.pt differ diff --git a/api/src/voices/v1_0/jf_tebukuro.pt b/api/src/voices/v1_0/jf_tebukuro.pt new file mode 100644 index 0000000..11f71f8 Binary files /dev/null and b/api/src/voices/v1_0/jf_tebukuro.pt differ diff --git a/api/src/voices/v1_0/jm_kumo.pt b/api/src/voices/v1_0/jm_kumo.pt new file mode 100644 index 0000000..eb31d22 Binary files /dev/null and b/api/src/voices/v1_0/jm_kumo.pt differ diff --git a/api/src/voices/v1_0/pf_dora.pt b/api/src/voices/v1_0/pf_dora.pt new file mode 100644 index 0000000..67c5d40 Binary files /dev/null and b/api/src/voices/v1_0/pf_dora.pt differ diff --git a/api/src/voices/v1_0/pm_alex.pt b/api/src/voices/v1_0/pm_alex.pt new file mode 100644 index 0000000..21642f0 Binary files /dev/null and b/api/src/voices/v1_0/pm_alex.pt differ diff --git a/api/src/voices/v1_0/pm_santa.pt b/api/src/voices/v1_0/pm_santa.pt new file mode 100644 index 0000000..5906fe4 Binary files /dev/null and b/api/src/voices/v1_0/pm_santa.pt differ diff --git a/api/src/voices/v1_0/zf_xiaobei.pt b/api/src/voices/v1_0/zf_xiaobei.pt new file mode 100644 index 0000000..fba4240 Binary files /dev/null and b/api/src/voices/v1_0/zf_xiaobei.pt differ diff --git a/api/src/voices/v1_0/zf_xiaoni.pt b/api/src/voices/v1_0/zf_xiaoni.pt new file mode 100644 index 0000000..df30374 Binary files /dev/null and b/api/src/voices/v1_0/zf_xiaoni.pt differ diff --git a/api/src/voices/v1_0/zf_xiaoxiao.pt b/api/src/voices/v1_0/zf_xiaoxiao.pt new file mode 100644 index 0000000..8be88c7 Binary files /dev/null and b/api/src/voices/v1_0/zf_xiaoxiao.pt differ diff --git a/api/src/voices/v1_0/zf_xiaoyi.pt b/api/src/voices/v1_0/zf_xiaoyi.pt new file mode 100644 index 0000000..e3c02bb Binary files /dev/null and b/api/src/voices/v1_0/zf_xiaoyi.pt differ diff --git a/api/src/voices/v1_0/zm_yunjian.pt b/api/src/voices/v1_0/zm_yunjian.pt new file mode 100644 index 0000000..11bbcea Binary files /dev/null and b/api/src/voices/v1_0/zm_yunjian.pt differ diff --git a/api/src/voices/v1_0/zm_yunxi.pt b/api/src/voices/v1_0/zm_yunxi.pt new file mode 100644 index 0000000..faaa4fa Binary files /dev/null and b/api/src/voices/v1_0/zm_yunxi.pt differ diff --git a/api/src/voices/v1_0/zm_yunxia.pt b/api/src/voices/v1_0/zm_yunxia.pt new file mode 100644 index 0000000..4aa15ae Binary files /dev/null and b/api/src/voices/v1_0/zm_yunxia.pt differ diff --git a/api/src/voices/v1_0/zm_yunyang.pt b/api/src/voices/v1_0/zm_yunyang.pt new file mode 100644 index 0000000..4c44a0d Binary files /dev/null and b/api/src/voices/v1_0/zm_yunyang.pt differ diff --git a/docs/architecture/kokoro_v1_integration.md b/docs/architecture/kokoro_v1_integration.md new file mode 100644 index 0000000..015cf08 --- /dev/null +++ b/docs/architecture/kokoro_v1_integration.md @@ -0,0 +1,113 @@ +# Kokoro V1 Integration Architecture + +## Overview + +This document outlines the architectural approach for integrating the new Kokoro V1 library into our existing inference system. 
The goal is to bypass most of the legacy model machinery while maintaining compatibility with our existing interfaces, particularly the OpenAI-compatible streaming endpoint.
+
+## Current System
+
+The current system uses a `ModelBackend` interface with multiple implementations (ONNX CPU/GPU, PyTorch CPU/GPU). This interface requires:
+
+- Async model loading
+- Audio generation from tokens and voice tensors
+- Resource cleanup
+- Device management
+
+## Integration Approach
+
+### 1. KokoroV1 Backend Implementation
+
+We'll create a `KokoroV1` class implementing the `ModelBackend` interface that wraps the new Kokoro library:
+
+```python
+class KokoroV1(BaseModelBackend):
+    def __init__(self):
+        super().__init__()
+        self._model = None
+        self._pipeline = None
+        self._device = "cuda" if settings.use_gpu and torch.cuda.is_available() else "cpu"
+```
+
+### 2. Model Loading
+
+The load_model method will initialize both KModel and KPipeline, picking up config.json from the model's directory:
+
+```python
+async def load_model(self, path: str) -> None:
+    model_path = await paths.get_model_path(path)
+    config_path = os.path.join(os.path.dirname(model_path), 'config.json')
+    self._model = KModel(config=config_path, model=model_path).to(self._device).eval()
+    self._pipeline = KPipeline(lang_code='a', model=self._model, device=self._device)
+```
+
+### 3. Audio Generation
+
+The generate method takes text plus a voice reference (a path, or a name/tensor pair) and streams audio chunks from KPipeline:
+
+```python
+async def generate(
+    self, text: str, voice: Union[str, Tuple[str, torch.Tensor]], speed: float = 1.0
+) -> AsyncGenerator[np.ndarray, None]:
+    # Resolve the voice to a .pt file path (saving tensors to a temp file if needed)
+    # Infer lang_code from the voice name and rebuild KPipeline when it changes
+    # Yield numpy audio chunks from self._pipeline(text, voice=voice_path, speed=speed)
+```
+
+### 4. Streaming Support
+
+The Kokoro V1 backend must maintain compatibility with our OpenAI-compatible streaming endpoint. Key requirements:
+
+1. **Chunked Generation**: The pipeline's output should be compatible with our streaming infrastructure:
+   ```python
+   async def generate_stream(self, text: str, voice_path: str) -> AsyncGenerator[np.ndarray, None]:
+       results = self._pipeline(text, voice=voice_path)
+       for result in results:
+           yield result.audio.numpy()
+   ```
+
+2. **Format Conversion**: Support for various output formats:
+   - MP3
+   - Opus
+   - AAC
+   - FLAC
+   - WAV
+   - PCM
+
+3. **Voice Management**:
+   - Support for voice combination (mean of multiple voice embeddings)
+   - Dynamic voice loading and caching
+   - Voice listing and validation
+
+4. **Error Handling**:
+   - Proper error propagation for client disconnects
+   - Format conversion errors
+   - Resource cleanup on failures
+
+### 5. Configuration Integration
+
+We'll use the existing configuration system:
+
+```python
+config = model_config.pytorch_kokoro_v1_file  # Model file path
+```
+
+## Benefits
+
+1. **Simplified Pipeline**: Direct use of Kokoro library's built-in pipeline
+2. **Better Language Support**: Access to Kokoro's wider language capabilities
+3. **Automatic Chunking**: Built-in text chunking and processing
+4. **Phoneme Generation**: Access to phoneme output for better analysis
+5. **Streaming Compatibility**: Maintains existing streaming functionality
+
+## Migration Strategy
+
+1. Implement KokoroV1 backend with streaming support
+2. Add to model manager's available backends
+3. Make it the default for new requests
+4. Keep legacy backends available for backward compatibility
+5. Update voice management to handle both legacy and new voice formats
+
+## Next Steps
+
+1. Implement the KokoroV1 backend
+2. Ensure streaming compatibility with OpenAI endpoint
+3. Add tests to verify both streaming and non-streaming functionality
+4. Update documentation for new capabilities
+5. Add monitoring for streaming performance
\ No newline at end of file
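As a usage sketch to accompany the design doc above (illustrative only, not part of the diff): the module path and the voice file location are assumptions based on this PR's layout, while `load_model` and `generate` mirror the backend implemented in `api/src/inference/kokoro_v1.py`.

```python
# Sketch: drive the KokoroV1 backend end to end.
# Assumptions: import path and voice file location; not part of the diff.
import asyncio

import numpy as np

from api.src.inference.kokoro_v1 import KokoroV1  # assumed module path


async def synthesize(text: str, voice_path: str) -> np.ndarray:
    backend = KokoroV1()
    # Relative model path is resolved by paths.get_model_path inside load_model
    await backend.load_model("v1_0/kokoro-v1_0.pth")
    # generate() yields numpy chunks; collect and join them
    chunks = [chunk async for chunk in backend.generate(text, voice_path, speed=1.0)]
    backend.unload()
    return np.concatenate(chunks)


if __name__ == "__main__":
    audio = asyncio.run(
        synthesize("Testing text to speech synthesis.", "api/src/voices/v1_0/af_heart.pt")
    )
    print(f"Generated {audio.shape[0] / 24000:.2f}s of audio")  # output is 24 kHz
```

Passing a voice path whose basename starts with the language letter (here `a` for American English) is what lets the backend pick the right KPipeline, per the language-code logic in `generate`.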
diff --git a/docs/architecture/nlp_dependencies.md b/docs/architecture/nlp_dependencies.md
new file mode 100644
index 0000000..db008ae
--- /dev/null
+++ b/docs/architecture/nlp_dependencies.md
@@ -0,0 +1,66 @@
+# NLP Dependencies Management
+
+## Overview
+
+This document outlines our approach to managing NLP dependencies, particularly the spaCy models required by our dependencies (such as misaki). The goal is to ensure reliable model availability while preventing runtime download attempts that could cause failures.
+
+## Challenge
+
+One of our dependencies, misaki, attempts to download the spaCy model `en_core_web_sm` at runtime. This can lead to failures if:
+- The download fails due to network issues
+- The environment lacks proper permissions
+- The system is running in a restricted environment
+
+## Solution
+
+### Model Management with UV
+
+We use UV as our package manager. For spaCy model management, we have two approaches:
+
+1. **Development Environment Setup**
+   ```bash
+   uv run --with spacy -- spacy download en_core_web_sm
+   ```
+   This command:
+   - Temporarily installs spaCy if not present
+   - Downloads the required model
+   - Places it in the appropriate location
+
+2. **Project Environment**
+   - Add spaCy as a project dependency in pyproject.toml
+   - Run `uv run -- spacy download en_core_web_sm` in the project directory
+   - This installs the model in the project's virtual environment
+
+### Docker Environment
+
+For containerized deployments:
+1. Add the model download step in the Dockerfile
+2. Ensure the model is available before application startup
+3. Configure misaki to use the pre-downloaded model
+
+## Benefits
+
+1. **Reliability**: Prevents runtime download attempts
+2. **Reproducibility**: Model version is consistent across environments
+3. **Performance**: No startup delay from download attempts
+4. **Security**: Better control over external downloads
+
+## Implementation Notes
+
+1. Development environments should use the `uv run --with spacy` approach for flexibility
+2. CI/CD pipelines should include model download in their setup phase
+3. Docker builds should pre-download models during image creation
+4. Application code should verify model availability at startup
+
+## Future Considerations
+
+1. Consider caching models in a shared location for multiple services
+2. Implement version pinning for NLP models
+3. Add health checks to verify model availability
+4. Monitor model usage and performance
+
+## Related Documentation
+
+- [Kokoro V1 Integration](kokoro_v1_integration.md)
+- UV Package Manager Documentation
+- spaCy Model Management Guide
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index ca45eb5..d83e8af 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,13 +30,17 @@ dependencies = [
     "loguru==0.7.3",
     "transformers==4.47.1",
     "openai>=1.59.6",
-    "ebooklib>=0.18",
-    "html2text>=2024.2.26",
+    # "ebooklib>=0.18",
+    # "html2text>=2024.2.26",
     "pydub>=0.25.1",
     "matplotlib>=3.10.0",
     "semchunk>=3.0.1",
     "mutagen>=1.47.0",
     "psutil>=6.1.1",
+    "kokoro==0.3.5",
+    'misaki[en,ja,ko,zh,vi]==0.6.7',
+    "spacy>=3.7.6",
+    "en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl"
 ]

 [project.optional-dependencies]
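Since nlp_dependencies.md above calls for verifying model availability at application startup, here is a minimal sketch of such a check. The helper name is hypothetical and nothing below is part of this diff; it only relies on the `en_core_web_sm` wheel pinned in pyproject.toml.

```python
# Hypothetical startup check suggested by nlp_dependencies.md; not part of this diff.
import spacy


def verify_nlp_models() -> None:
    """Fail fast if the spaCy model pinned in pyproject.toml is missing."""
    try:
        spacy.load("en_core_web_sm")
    except OSError as exc:  # spaCy raises OSError when a model is not installed
        raise RuntimeError(
            "en_core_web_sm is not installed; it should come from the pinned "
            "wheel in pyproject.toml rather than a runtime download"
        ) from exc
```

Calling this once during service startup surfaces a missing model as a clear configuration error instead of a mid-request download attempt by misaki.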