diff --git a/Kokoro-82M b/Kokoro-82M deleted file mode 160000 index c97b7bb..0000000 --- a/Kokoro-82M +++ /dev/null @@ -1 +0,0 @@ -Subproject commit c97b7bbc3e60f447383c79b2f94fee861ff156ac diff --git a/README.md b/README.md index 28742d1..75f763f 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model - OpenAI-compatible Speech endpoint, with inline voice combination, and mapped naming/models for strict systems -- NVIDIA GPU accelerated or CPU inference (ONNX, Pytorch) (~80-300mb modelfile) +- NVIDIA GPU accelerated or CPU inference (ONNX, Pytorch) - very fast generation time - 35x-100x+ real time speed via 4060Ti+ - 5x+ real time speed via M3 Pro CPU diff --git a/api/src/core/config.py b/api/src/core/config.py index c5155cc..0337088 100644 --- a/api/src/core/config.py +++ b/api/src/core/config.py @@ -13,8 +13,8 @@ class Settings(BaseSettings): output_dir: str = "output" output_dir_size_limit_mb: float = 500.0 # Maximum size of output directory in MB default_voice: str = "af" - use_gpu: bool = False # Whether to use GPU acceleration if available - use_onnx: bool = True # Whether to use ONNX runtime + use_gpu: bool = True # Whether to use GPU acceleration if available + use_onnx: bool = False # Whether to use ONNX runtime allow_local_voice_saving: bool = False # Whether to allow saving combined voices locally # Container absolute paths diff --git a/api/src/core/model_config.py b/api/src/core/model_config.py index 7cb7b58..3f1e00b 100644 --- a/api/src/core/model_config.py +++ b/api/src/core/model_config.py @@ -11,7 +11,7 @@ class ONNXCPUConfig(BaseModel): instance_timeout: int = Field(300, description="Session timeout in seconds") # Runtime settings - num_threads: int = Field(4, description="Number of threads for parallel operations") + num_threads: int = Field(8, description="Number of threads for parallel operations") inter_op_threads: int = Field(4, description="Number of threads for operator parallelism") execution_mode: str = Field("parallel", description="ONNX execution mode") optimization_level: str = Field("all", description="ONNX optimization level") @@ -55,7 +55,6 @@ class PyTorchGPUConfig(BaseModel): """PyTorch GPU backend configuration.""" device_id: int = Field(0, description="CUDA device ID") - use_fp16: bool = Field(True, description="Whether to use FP16 precision") use_triton: bool = Field(True, description="Whether to use Triton for CUDA kernels") memory_threshold: float = Field(0.8, description="Memory threshold for cleanup") retry_on_oom: bool = Field(True, description="Whether to retry on OOM errors") diff --git a/api/src/inference/onnx_cpu.py b/api/src/inference/onnx_cpu.py index b17227c..10ee946 100644 --- a/api/src/inference/onnx_cpu.py +++ b/api/src/inference/onnx_cpu.py @@ -85,17 +85,23 @@ class ONNXCPUBackend(BaseModelBackend): style_input = voice[len(tokens) + 2].numpy() # Adjust index for start/end tokens speed_input = np.full(1, speed, dtype=np.float32) - # Run inference - result = self._session.run( - None, - { - "tokens": tokens_input, - "style": style_input, - "speed": speed_input - } - ) + # Build base inputs + inputs = { + "style": style_input, + "speed": speed_input + } - return result[0] + # Try both possible token input names #TODO: + for token_name in ["tokens", "input_ids"]: + try: + inputs[token_name] = tokens_input + result = self._session.run(None, inputs) + return result[0] + except Exception: + del inputs[token_name] + continue + + raise RuntimeError("Model does not accept either 'tokens' or 'input_ids' as input name") except Exception as e: raise RuntimeError(f"Generation failed: {e}") diff --git a/docker/cpu/download_onnx.py b/docker/cpu/download_onnx.py index a04718c..a97daf9 100755 --- a/docker/cpu/download_onnx.py +++ b/docker/cpu/download_onnx.py @@ -37,7 +37,7 @@ def main(custom_models: List[str] = None): # Default ONNX model if no arguments provided default_models = [ "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.onnx", - "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19_fp16.onnx" + # "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19_fp16.onnx" ] # Use provided models or default diff --git a/web/app.js b/web/app.js index 18ec16e..82ca420 100644 --- a/web/app.js +++ b/web/app.js @@ -14,7 +14,9 @@ class KokoroPlayer { waveContainer: document.getElementById('wave-container'), timeDisplay: document.getElementById('time-display'), downloadBtn: document.getElementById('download-btn'), - status: document.getElementById('status') + status: document.getElementById('status'), + speedSlider: document.getElementById('speed-slider'), + speedValue: document.getElementById('speed-value') }; this.isGenerating = false; @@ -201,6 +203,11 @@ class KokoroPlayer { this.elements.playPauseBtn.addEventListener('click', () => this.togglePlayPause()); this.elements.downloadBtn.addEventListener('click', () => this.downloadAudio()); + this.elements.speedSlider.addEventListener('input', (e) => { + const speed = parseFloat(e.target.value); + this.elements.speedValue.textContent = speed.toFixed(1); + }); + document.addEventListener('click', (e) => { if (!this.elements.voiceSearch.contains(e.target) && !this.elements.voiceDropdown.contains(e.target)) { @@ -329,7 +336,8 @@ class KokoroPlayer { input: text, voice: voice, response_format: 'mp3', - stream: true + stream: true, + speed: parseFloat(this.elements.speedSlider.value) }), signal: this.currentController.signal }); @@ -418,11 +426,13 @@ class KokoroPlayer { if (this.audioChunks.length === 0) return; const format = this.elements.formatSelect.value; + const voice = Array.from(this.selectedVoiceSet).join('+'); + const timestamp = new Date().toISOString().replace(/[:.]/g, '-'); const blob = new Blob(this.audioChunks, { type: `audio/${format}` }); const url = URL.createObjectURL(blob); const a = document.createElement('a'); a.href = url; - a.download = `generated-speech.${format}`; + a.download = `${voice}_${timestamp}.${format}`; document.body.appendChild(a); a.click(); document.body.removeChild(a); diff --git a/web/index.html b/web/index.html index 2fbc2ef..7c6a5b5 100644 --- a/web/index.html +++ b/web/index.html @@ -26,7 +26,7 @@