diff --git a/api/src/inference/voice_manager.py b/api/src/inference/voice_manager.py
index 56557e6..01f2372 100644
--- a/api/src/inference/voice_manager.py
+++ b/api/src/inference/voice_manager.py
@@ -182,6 +182,106 @@ class VoiceManager:
         except Exception:
             return False
 
+    async def create_weighted_voice(
+        self,
+        formula: str,
+        normalize: bool = False,
+        device: str = "cpu",
+    ) -> str:
+        """
+        Parse a weighted voice formula (e.g. '0.3 * voiceA + 0.5 * voiceB'),
+        combine the referenced voices into a single weighted-sum embedding,
+        and return the combined voice name.
+
+        Args:
+            formula: Weighted voice formula string.
+            normalize: If True, divide the final result by the sum of the weights
+                       so the total "magnitude" remains consistent.
+            device: 'cpu' or 'cuda' for the combined tensor.
+
+        Returns:
+            The combined voice name (e.g. 'voiceA+voiceB').
+        """
+        pairs = self.parse_voice_formula(formula)  # [(weight, voice_name), ...]
+
+        if not pairs:
+            raise ValueError("No valid weighted voices found in formula.")
+
+        # Validate the pairs
+        for weight, voice_name in pairs:
+            if weight <= 0:
+                raise ValueError(f"Invalid weight {weight} for voice {voice_name}.")
+
+        # Keep track of the total weight in case we normalize at the end.
+        total_weight = 0.0
+        weighted_sum = None
+        combined_name = ""
+
+        for weight, voice_name in pairs:
+            # 1) Load each base voice from the manager
+            base_voice = await self.load_voice(voice_name, device=device)
+
+            # 2) Build the combined voice name
+            if combined_name == "":
+                combined_name = voice_name
+            else:
+                combined_name += f"+{voice_name}"
+
+            # 3) Multiply by the weight and accumulate
+            if weighted_sum is None:
+                # Clone so we don't modify the cached base voice in memory
+                weighted_sum = base_voice.clone() * weight
+            else:
+                weighted_sum += base_voice * weight
+
+            total_weight += weight
+
+        if weighted_sum is None:
+            raise ValueError("No voices were combined. Check the formula syntax.")
+
+        # Optional normalization
+        if normalize and total_weight != 0.0:
+            weighted_sum /= total_weight
+
+        if settings.allow_local_voice_saving:
+            # Save the combined voice to disk
+            api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
+            voices_dir = os.path.join(api_dir, settings.voices_dir)
+            os.makedirs(voices_dir, exist_ok=True)
+            combined_path = os.path.join(voices_dir, f"{formula}.pt")
+            try:
+                torch.save(weighted_sum, combined_path)
+            except Exception as e:
+                logger.warning(f"Failed to save combined voice: {e}")
+                # Continue without saving - the voice will be combined on the fly when needed
+
+        return combined_name
+
+    def parse_voice_formula(self, formula: str) -> List[tuple[float, str]]:
+        """
+        Parse a weighted voice formula (e.g. '0.3 * voiceA + 0.5 * voiceB')
+        into a list of (weight, voice_name) pairs.
+
+        Args:
+            formula: Weighted voice formula string.
+
+        Returns:
+            List of (weight, voice_name) pairs.
+        """
+        pairs = []
+        for part in formula.split('+'):
+            if '*' not in part:
+                raise ValueError(f"Invalid term '{part.strip()}' in formula; expected 'weight * voice'.")
+            weight, voice_name = part.strip().split('*')
+            pairs.append((float(weight), voice_name.strip()))
+        return pairs
+
     @property
     def cache_info(self) -> Dict[str, int]:
         """Get cache statistics.
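For reference, a minimal standalone sketch of the weighted-combination math applied above. The voice names and embedding size are illustrative, and the combine helper is hypothetical rather than part of the codebase:

    import torch

    def parse_voice_formula(formula: str) -> list[tuple[float, str]]:
        # '0.3 * voiceA + 0.5 * voiceB' -> [(0.3, 'voiceA'), (0.5, 'voiceB')]
        pairs = []
        for part in formula.split('+'):
            weight, name = part.strip().split('*')
            pairs.append((float(weight), name.strip()))
        return pairs

    def combine(embeddings: dict, formula: str, normalize: bool = True) -> torch.Tensor:
        weighted_sum, total_weight = None, 0.0
        for weight, name in parse_voice_formula(formula):
            term = embeddings[name] * weight
            weighted_sum = term if weighted_sum is None else weighted_sum + term
            total_weight += weight
        # Dividing by the total weight keeps the blend on the same scale as a single voice.
        return weighted_sum / total_weight if normalize and total_weight else weighted_sum

    # Stand-ins for the tensors VoiceManager.load_voice would return.
    voices = {"voiceA": torch.randn(256), "voiceB": torch.randn(256)}
    blended = combine(voices, "0.3 * voiceA + 0.5 * voiceB")

With normalize=True the result here is divided by 0.8, so the formula yields a 3:5 blend at single-voice magnitude.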
diff --git a/api/src/routers/openai_compatible.py b/api/src/routers/openai_compatible.py
index 5908a56..4196f4c 100644
--- a/api/src/routers/openai_compatible.py
+++ b/api/src/routers/openai_compatible.py
@@ -2,7 +2,7 @@
 import json
 import os
 from typing import AsyncGenerator, Dict, List, Union
-from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
+from fastapi import APIRouter, Header, HTTPException, Request, Response
 from fastapi.responses import StreamingResponse
 from loguru import logger
@@ -112,7 +112,17 @@ async def stream_audio_chunks(
     client_request: Request
 ) -> AsyncGenerator[bytes, None]:
     """Stream audio chunks as they're generated with client disconnect handling"""
-    voice_to_use = await process_voices(request.voice, tts_service)
+    # Check whether 'request.voice' is a weighted formula (contains '*')
+    if '*' in request.voice:
+        # Weighted formula path
+        voice_to_use = await tts_service._voice_manager.create_weighted_voice(
+            formula=request.voice,
+            normalize=True
+        )
+    else:
+        # Normal single or multi-voice path
+        voice_to_use = await process_voices(request.voice, tts_service)
 
     try:
         async for chunk in tts_service.generate_audio_stream(
@@ -159,10 +169,7 @@ async def create_speech(
         # Get global service instance
         tts_service = await get_tts_service()
 
-        # Process voice combination and validate
-        voice_to_use = await process_voices(request.voice, tts_service)
-
-        # Set content type based on format
+        # Set content type based on format
         content_type = {
             "mp3": "audio/mpeg",
             "opus": "audio/opus",
@@ -209,13 +216,27 @@ async def create_speech(
                 },
             )
         else:
-            # Generate complete audio using public interface
-            audio, _ = await tts_service.generate_audio(
-                text=request.input,
-                voice=voice_to_use,
-                speed=request.speed,
-                stitch_long_output=True
-            )
+            # Check whether 'request.voice' is a weighted formula (contains '*')
+            if '*' in request.voice:
+                # Weighted formula path
+                logger.debug(f"Using weighted voice formula: {request.voice}")
+                voice_to_use = await tts_service._voice_manager.create_weighted_voice(
+                    formula=request.voice,
+                    normalize=True
+                )
+            else:
+                # Normal single or multi-voice path
+                voice_to_use = await process_voices(request.voice, tts_service)
+
+            # Generate complete audio using public interface
+            audio, _ = await tts_service.generate_audio(
+                text=request.input,
+                voice=voice_to_use,
+                speed=request.speed,
+                stitch_long_output=True
+            )
 
             # Convert to requested format
             content = await AudioService.convert_audio(
diff --git a/docker/gpu/docker-compose.yml b/docker/gpu/docker-compose.yml
index f27e15b..c711925 100644
--- a/docker/gpu/docker-compose.yml
+++ b/docker/gpu/docker-compose.yml
@@ -1,7 +1,6 @@
-name: kokoro-tts
+name: InstaVoice
 services:
-  kokoro-tts:
-    # image: ghcr.io/remsky/kokoro-fastapi-gpu:v0.1.0
+  server1:
     build:
       context: ../..
      dockerfile: docker/gpu/Dockerfile
@@ -19,23 +18,59 @@ services:
       reservations:
         devices:
           - driver: nvidia
-            count: 1
+            count: all
             capabilities: [gpu]
 
-  # # Gradio UI service
-  # gradio-ui:
-  #   image: ghcr.io/remsky/kokoro-fastapi-ui:v0.1.0
-  #   # Uncomment below to build from source instead of using the released image
-  #   # build:
-  #   #   context: ../../ui
-  #   ports:
-  #     - "7860:7860"
-  #   volumes:
-  #     - ../../ui/data:/app/ui/data
-  #     - ../../ui/app.py:/app/app.py  # Mount app.py for hot reload
-  #   environment:
-  #     - GRADIO_WATCH=1  # Enable hot reloading
-  #     - PYTHONUNBUFFERED=1  # Ensure Python output is not buffered
-  #     - DISABLE_LOCAL_SAVING=false  # Set to 'true' to disable local saving and hide file view
-  #     - API_HOST=kokoro-tts  # Set TTS service URL
-  #     - API_PORT=8880  # Set TTS service PORT
+  server2:
+    build:
+      context: ../..
+      dockerfile: docker/gpu/Dockerfile
+    volumes:
+      - ../../api:/app/api
+    ports:
+      - "8881:8880"  # Distinct host port so all backends can run alongside server1
+    environment:
+      - PYTHONPATH=/app:/app/api
+      - USE_GPU=true
+      - USE_ONNX=false
+      - PYTHONUNBUFFERED=1
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              count: all
+              capabilities: [gpu]
+
+  server3:
+    build:
+      context: ../..
+      dockerfile: docker/gpu/Dockerfile
+    volumes:
+      - ../../api:/app/api
+    ports:
+      - "8882:8880"  # Distinct host port so all backends can run alongside server1
+    environment:
+      - PYTHONPATH=/app:/app/api
+      - USE_GPU=true
+      - USE_ONNX=false
+      - PYTHONUNBUFFERED=1
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              count: all
+              capabilities: [gpu]
+
+  nginx:
+    image: nginx:alpine
+    ports:
+      - "80:80"  # Expose port 80 on the host machine
+    volumes:
+      - ./nginx.conf:/etc/nginx/nginx.conf  # Load custom NGINX configuration
+    depends_on:
+      - server1
+      - server2
+      - server3
diff --git a/docker/gpu/nginx.conf b/docker/gpu/nginx.conf
new file mode 100644
index 0000000..a11ab84
--- /dev/null
+++ b/docker/gpu/nginx.conf
@@ -0,0 +1,78 @@
+user nginx;
+worker_processes auto;  # Automatically adjust worker processes based on available CPUs
+
+events {
+    worker_connections 1024;  # Maximum simultaneous connections per worker
+    use epoll;  # Use efficient event handling for Linux
+}
+
+http {
+    # Basic security headers
+    add_header X-Frame-Options SAMEORIGIN always;  # Prevent clickjacking
+    add_header X-Content-Type-Options nosniff always;  # Prevent MIME-type sniffing
+    add_header X-XSS-Protection "1; mode=block" always;  # Enable XSS protection in browsers
+    add_header Strict-Transport-Security "max-age=31536000; includeSubDomains" always;  # Enforce HTTPS
+    add_header Content-Security-Policy "default-src 'self';" always;  # Restrict resource loading to same origin
+
+    # Performance and timeouts
+    sendfile on;  # Enable sendfile for efficient file serving
+    tcp_nopush on;  # Reduce packet overhead
+    tcp_nodelay on;  # Minimize latency
+    keepalive_timeout 65;  # Keep connections alive for 65 seconds
+    client_max_body_size 10m;  # Limit request body size to 10MB
+    client_body_timeout 12;  # Timeout for reading the client body
+    client_header_timeout 12;  # Timeout for reading client headers
+
+    # Compression
+    gzip on;  # Enable gzip compression
+    gzip_disable "msie6";  # Disable gzip for old browsers
+    gzip_vary on;  # Add "Vary: Accept-Encoding" header
+    gzip_proxied any;  # Enable gzip for proxied requests
+    gzip_comp_level 6;  # Compression level
+    gzip_types text/plain text/css application/json application/javascript text/xml application/xml application/xml+rss text/javascript;
+
+    # Load balancing upstream
+    upstream backend {
+        least_conn;  # Use least-connections load balancing strategy
+        server server1:8880 max_fails=3 fail_timeout=5s;  # Mark a backend as failed after 3 errors within 5s
+        server server2:8880 max_fails=3 fail_timeout=5s;
+        server server3:8880 max_fails=3 fail_timeout=5s;
+    }
+
+    server {
+        listen 80;
+
+        # Optional HTTPS setup
+        # Uncomment the lines below if SSL is configured:
+        # listen 443 ssl;
+        # ssl_certificate /path/to/certificate.crt;
+        # ssl_certificate_key /path/to/private.key;
+
+        location / {
+            proxy_pass http://backend;  # Proxy traffic to the backend servers
+            proxy_http_version 1.1;  # Use HTTP/1.1 for persistent connections
+            proxy_set_header Upgrade $http_upgrade;
+            proxy_set_header Connection "upgrade";
+            proxy_set_header Host $host;
+            proxy_set_header X-Forwarded-For $remote_addr;  # Forward client IP
+            proxy_cache_bypass $http_upgrade;
+            proxy_read_timeout 60s;  # Read timeout for backend responses
+            proxy_connect_timeout 60s;  # Connection timeout to backend
+            proxy_send_timeout 60s;  # Send timeout to backend
+        }
+
+        # Custom error pages
+        error_page 502 503 504 /50x.html;
+        location = /50x.html {
+            root /usr/share/nginx/html;
+        }
+
+        # Deny access to hidden files (e.g., .git)
+        location ~ /\. {
+            deny all;
+            access_log off;
+            log_not_found off;
+        }
+    }
+}
diff --git a/examples/openai_streaming_audio.py b/examples/openai_streaming_audio.py
index 353ee3d..4345e4d 100644
--- a/examples/openai_streaming_audio.py
+++ b/examples/openai_streaming_audio.py
@@ -31,7 +31,7 @@ def stream_to_speakers() -> None:
 
     with openai.audio.speech.with_streaming_response.create(
         model="kokoro",
-        voice="af_bella",
+        voice="0.100 * af + 0.300 * am_adam + 0.400 * am_michael + 0.100 * bf_emma + 0.100 * bm_lewis",
         response_format="pcm",  # similar to WAV, but without a header chunk at the start.
         input="""I see skies of blue and clouds of white
             The bright blessed days, the dark sacred nights
diff --git a/examples/speech.mp3 b/examples/speech.mp3
index c21bec0..62bc37a 100644
Binary files a/examples/speech.mp3 and b/examples/speech.mp3 differ
diff --git a/setup.sh b/setup.sh
new file mode 100755
index 0000000..c8bda83
--- /dev/null
+++ b/setup.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+
+# Ensure models directory exists
+mkdir -p api/src/models
+
+# Function to download a file
+download_file() {
+    local url="$1"
+    local filename=$(basename "$url")
+    echo "Downloading $filename..."
+    curl -L "$url" -o "api/src/models/$filename"
+}
+
+# Default PTH model if no arguments provided
+DEFAULT_MODELS=(
+    "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth"
+)
+
+# Use provided models or default
+if [ $# -gt 0 ]; then
+    MODELS=("$@")
+else
+    MODELS=("${DEFAULT_MODELS[@]}")
+fi
+
+# Download all models
+for model in "${MODELS[@]}"; do
+    download_file "$model"
+done
+
+echo "PyTorch model download complete!"
\ No newline at end of file
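As a usage sketch for the feature this patch adds (mirroring examples/openai_streaming_audio.py; the base URL, API key, voice mix, and output filename are assumptions, not part of the change), a weighted formula can be passed anywhere a plain voice name is accepted:

    from openai import OpenAI

    # Point the client at one backend directly (8880) or at the nginx proxy on port 80.
    client = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed")

    with client.audio.speech.with_streaming_response.create(
        model="kokoro",
        voice="0.3 * af_bella + 0.7 * am_adam",  # '*' in the voice triggers create_weighted_voice
        response_format="mp3",
        input="Testing the weighted voice blend.",
    ) as response:
        response.stream_to_file("weighted_blend.mp3")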