Mirror of https://github.com/remsky/Kokoro-FastAPI.git, synced 2025-04-13 09:39:17 +00:00
Merge pull request #92 from rvuyyuru2/v0.1.2-pre
Adds support for creating weighted voice combinations (reimplemented in v0.2.0)
Commit d5709097e2: 7 changed files with 300 additions and 35 deletions
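
For orientation before the diff: the formula string weights named voice packs and sums their embeddings. A minimal client-side sketch of the feature (assuming a local server on port 8880 and the OpenAI-compatible route; the voice names, weights, and output path are illustrative):

import openai

# Hypothetical client call; any '*' in the voice string selects the weighted-formula path.
client = openai.OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed")

with client.audio.speech.with_streaming_response.create(
    model="kokoro",
    voice="0.3 * af_bella + 0.7 * am_adam",  # weighted voice combination formula
    response_format="mp3",
    input="Hello from a blended voice.",
) as response:
    response.stream_to_file("blended.mp3")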
@@ -182,6 +182,106 @@ class VoiceManager:
         except Exception:
             return False
 
+    async def create_weighted_voice(
+        self,
+        formula: str,
+        normalize: bool = False,
+        device: str = "cpu",
+    ) -> str:
+        """
+        Parse the voice formula string (e.g. '0.3 * voiceA + 0.5 * voiceB'),
+        build the weighted sum of the referenced voices, optionally save the
+        combined tensor to disk, and return the combined voice name.
+
+        Args:
+            formula: Weighted voice formula string.
+            normalize: If True, divide the final result by the sum of the
+                weights so the total "magnitude" remains consistent.
+            device: 'cpu' or 'cuda' for the final tensor.
+
+        Returns:
+            The combined voice name (e.g. 'voiceA+voiceB').
+        """
+        pairs = self.parse_voice_formula(formula)  # [(weight, voice_name), ...]
+
+        # Validate the pairs
+        for weight, voice_name in pairs:
+            if weight <= 0:
+                raise ValueError(f"Invalid weight {weight} for voice {voice_name}.")
+
+        if not pairs:
+            raise ValueError("No valid weighted voices found in formula.")
+
+        # Keep track of total weight if we plan to normalize.
+        total_weight = 0.0
+        weighted_sum = None
+        combined_name = ""
+
+        for weight, voice_name in pairs:
+            # 1) Load each base voice from the manager/service
+            base_voice = await self.load_voice(voice_name, device=device)
+
+            # 2) Build up the combined voice name
+            if combined_name == "":
+                combined_name = voice_name
+            else:
+                combined_name += f"+{voice_name}"
+
+            # 3) Multiply by the weight and accumulate
+            if weighted_sum is None:
+                # Clone so we don't modify the base voice in memory
+                weighted_sum = base_voice.clone() * weight
+            else:
+                weighted_sum += (base_voice * weight)
+
+            total_weight += weight
+
+        if weighted_sum is None:
+            raise ValueError("No voices were combined. Check the formula syntax.")
+
+        # Optional normalization
+        if normalize and total_weight != 0.0:
+            weighted_sum /= total_weight
+
+        if settings.allow_local_voice_saving:
+            # Save to disk
+            api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
+            voices_dir = os.path.join(api_dir, settings.voices_dir)
+            os.makedirs(voices_dir, exist_ok=True)
+            combined_path = os.path.join(voices_dir, f"{formula}.pt")
+            try:
+                torch.save(weighted_sum, combined_path)
+            except Exception as e:
+                logger.warning(f"Failed to save combined voice: {e}")
+                # Continue without saving; the voice is combined on the fly when needed
+
+        return combined_name
+
+    def parse_voice_formula(self, formula: str) -> List[tuple[float, str]]:
+        """
+        Parse the voice formula string (e.g. '0.3 * voiceA + 0.5 * voiceB')
+        and return a list of (weight, voice_name) pairs.
+
+        Args:
+            formula: Weighted voice formula string.
+
+        Returns:
+            List of (weight, voice_name) pairs.
+        """
+        pairs = []
+        parts = formula.split('+')
+        for part in parts:
+            weight, voice_name = part.strip().split('*')
+            pairs.append((float(weight), voice_name.strip()))
+        return pairs
+
     @property
     def cache_info(self) -> Dict[str, int]:
         """Get cache statistics.
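The accumulation above is a plain weighted tensor sum. A self-contained sketch of the same arithmetic, with toy tensors standing in for real voice packs (the names and sizes are hypothetical):

import torch

# Toy stand-ins for loaded voice embeddings; real packs are larger tensors.
voice_a = torch.full((3,), 1.0)
voice_b = torch.full((3,), 2.0)
pairs = [(0.3, voice_a), (0.5, voice_b)]  # as parsed from '0.3 * voiceA + 0.5 * voiceB'

weighted_sum = None
total_weight = 0.0
for weight, voice in pairs:
    if weighted_sum is None:
        # Clone so the cached base voice is never mutated in place.
        weighted_sum = voice.clone() * weight
    else:
        weighted_sum += voice * weight
    total_weight += weight

# normalize=True divides by the summed weights (0.8 here), so formulas whose
# weights do not sum to 1 still yield a consistently scaled embedding.
print(weighted_sum / total_weight)  # tensor([1.6250, 1.6250, 1.6250])
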
@@ -2,7 +2,7 @@ import json
 import os
 from typing import AsyncGenerator, Dict, List, Union
 
-from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
+from fastapi import APIRouter, Header, HTTPException, Request, Response
 from fastapi.responses import StreamingResponse
 from loguru import logger
 
@@ -112,6 +112,16 @@ async def stream_audio_chunks(
     client_request: Request
 ) -> AsyncGenerator[bytes, None]:
     """Stream audio chunks as they're generated with client disconnect handling"""
-    voice_to_use = await process_voices(request.voice, tts_service)
+    # Check if 'request.voice' is a weighted formula (contains '*')
+    if '*' in request.voice:
+        # Weighted formula path
+        voice_to_use = await tts_service._voice_manager.create_weighted_voice(
+            formula=request.voice,
+            normalize=True
+        )
+    else:
+        # Normal single or multi-voice path
+        voice_to_use = await process_voices(request.voice, tts_service)
 
     try:
@@ -159,9 +169,6 @@ async def create_speech(
     # Get global service instance
     tts_service = await get_tts_service()
 
-    # Process voice combination and validate
-    voice_to_use = await process_voices(request.voice, tts_service)
-
     # Set content type based on format
     content_type = {
         "mp3": "audio/mpeg",
@@ -209,6 +216,20 @@ async def create_speech(
             },
         )
     else:
+        # Check if 'request.voice' is a weighted formula (contains '*')
+        if '*' in request.voice:
+            # Weighted formula path
+            print("Weighted formula path")
+            voice_to_use = await tts_service._voice_manager.create_weighted_voice(
+                formula=request.voice,
+                normalize=True
+            )
+            print(voice_to_use)
+        else:
+            # Normal single or multi-voice path
+            print("Normal single or multi-voice path")
+            voice_to_use = await process_voices(request.voice, tts_service)
         # Generate complete audio using public interface
         audio, _ = await tts_service.generate_audio(
             text=request.input,
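Both endpoints branch on a bare '*' membership test, so any voice string containing '*' must follow the strict grammar 'weight * name [+ weight * name ...]'. A standalone sketch of what the router hands to create_weighted_voice, mirroring parse_voice_formula from the first hunk:

def parse_voice_formula(formula: str) -> list[tuple[float, str]]:
    # Split terms on '+', then each term on '*': "0.3 * af_bella" -> (0.3, "af_bella").
    pairs = []
    for part in formula.split('+'):
        weight, voice_name = part.strip().split('*')
        pairs.append((float(weight), voice_name.strip()))
    return pairs

print(parse_voice_formula("0.3 * af_bella + 0.7 * am_adam"))
# [(0.3, 'af_bella'), (0.7, 'am_adam')]
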
@@ -1,7 +1,6 @@
-name: kokoro-tts
+name: InstaVoice
 services:
-  kokoro-tts:
-    # image: ghcr.io/remsky/kokoro-fastapi-gpu:v0.1.0
+  server1:
     build:
       context: ../..
       dockerfile: docker/gpu/Dockerfile
@@ -19,23 +18,59 @@ services:
       reservations:
         devices:
           - driver: nvidia
-            count: 1
+            count: all
             capabilities: [gpu]
 
-  # # Gradio UI service
-  # gradio-ui:
-  #   image: ghcr.io/remsky/kokoro-fastapi-ui:v0.1.0
-  #   # Uncomment below to build from source instead of using the released image
-  #   # build:
-  #   #   context: ../../ui
-  #   ports:
-  #     - "7860:7860"
-  #   volumes:
-  #     - ../../ui/data:/app/ui/data
-  #     - ../../ui/app.py:/app/app.py  # Mount app.py for hot reload
-  #   environment:
-  #     - GRADIO_WATCH=1  # Enable hot reloading
-  #     - PYTHONUNBUFFERED=1  # Ensure Python output is not buffered
-  #     - DISABLE_LOCAL_SAVING=false  # Set to 'true' to disable local saving and hide file view
-  #     - API_HOST=kokoro-tts  # Set TTS service URL
-  #     - API_PORT=8880  # Set TTS service PORT
+  server2:
+    build:
+      context: ../..
+      dockerfile: docker/gpu/Dockerfile
+    volumes:
+      - ../../api:/app/api
+    ports:
+      - "8880:8880"
+    environment:
+      - PYTHONPATH=/app:/app/api
+      - USE_GPU=true
+      - USE_ONNX=false
+      - PYTHONUNBUFFERED=1
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              count: all
+              capabilities: [gpu]
+
+  server3:
+    build:
+      context: ../..
+      dockerfile: docker/gpu/Dockerfile
+    volumes:
+      - ../../api:/app/api
+    ports:
+      - "8880:8880"
+    environment:
+      - PYTHONPATH=/app:/app/api
+      - USE_GPU=true
+      - USE_ONNX=false
+      - PYTHONUNBUFFERED=1
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              count: all
+              capabilities: [gpu]
+
+  nginx:
+    image: nginx:alpine
+    ports:
+      - "80:80"  # Expose port 80 on the host machine
+    volumes:
+      - ./nginx.conf:/etc/nginx/nginx.conf  # Load custom NGINX configuration
+    depends_on:
+      - server3
+      - server1
+      - server2
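With nginx fronting the three replicas, clients target port 80 instead of 8880. A minimal smoke test of the balanced endpoint (a sketch, assuming the stack is up locally and exposes the OpenAI-compatible route):

import requests

# Each request lands on whichever replica nginx's least_conn strategy picks;
# the responses should be interchangeable regardless of the backend chosen.
for i in range(3):
    resp = requests.post(
        "http://localhost/v1/audio/speech",
        json={"model": "kokoro", "voice": "af_bella",
              "input": f"load balancer check {i}", "response_format": "mp3"},
        timeout=60,
    )
    print(i, resp.status_code, len(resp.content), "bytes")
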
docker/gpu/nginx.conf (new file, 78 lines)
@@ -0,0 +1,78 @@
+user nginx;
+worker_processes auto;  # Automatically adjust worker processes based on available CPUs
+
+events {
+    worker_connections 1024;  # Maximum simultaneous connections per worker
+    use epoll;  # Use efficient event handling for Linux
+}
+
+http {
+    # Basic security headers
+    add_header X-Frame-Options SAMEORIGIN always;  # Prevent clickjacking
+    add_header X-Content-Type-Options nosniff always;  # Prevent MIME-type sniffing
+    add_header X-XSS-Protection "1; mode=block" always;  # Enable XSS protection in browsers
+    add_header Strict-Transport-Security "max-age=31536000; includeSubDomains" always;  # Enforce HTTPS
+    add_header Content-Security-Policy "default-src 'self';" always;  # Restrict resource loading to same origin
+
+    # Performance and timeouts
+    sendfile on;  # Enable sendfile for efficient file serving
+    tcp_nopush on;  # Reduce packet overhead
+    tcp_nodelay on;  # Minimize latency
+    keepalive_timeout 65;  # Keep connections alive for 65 seconds
+    client_max_body_size 10m;  # Limit request body size to 10MB
+    client_body_timeout 12;  # Timeout for client body read
+    client_header_timeout 12;  # Timeout for client header read
+
+    # Compression
+    gzip on;  # Enable gzip compression
+    gzip_disable "msie6";  # Disable gzip for old browsers
+    gzip_vary on;  # Add "Vary: Accept-Encoding" header
+    gzip_proxied any;  # Enable gzip for proxied requests
+    gzip_comp_level 6;  # Compression level
+    gzip_types text/plain text/css application/json application/javascript text/xml application/xml application/xml+rss text/javascript;
+
+    # Load-balancing upstream
+    upstream backend {
+        least_conn;  # Use the least-connections load-balancing strategy
+        server server1:8880 max_fails=3 fail_timeout=5s;  # Health checks for backend servers
+        server server2:8880 max_fails=3 fail_timeout=5s;
+        server server3:8880 max_fails=3 fail_timeout=5s;
+    }
+
+    server {
+        listen 80;
+
+        # Redirect HTTP to HTTPS (optional)
+        # Uncomment the lines below if SSL is configured:
+        # listen 443 ssl;
+        # ssl_certificate /path/to/certificate.crt;
+        # ssl_certificate_key /path/to/private.key;
+
+        location / {
+            proxy_pass http://backend;  # Proxy traffic to the backend servers
+            proxy_http_version 1.1;  # Use HTTP/1.1 for persistent connections
+            proxy_set_header Upgrade $http_upgrade;
+            proxy_set_header Connection "upgrade";
+            proxy_set_header Host $host;
+            proxy_set_header X-Forwarded-For $remote_addr;  # Forward client IP
+            proxy_cache_bypass $http_upgrade;
+            proxy_read_timeout 60s;  # Adjust read timeout for backend
+            proxy_connect_timeout 60s;  # Adjust connection timeout for backend
+            proxy_send_timeout 60s;  # Adjust send timeout for backend
+        }
+
+        # Custom error pages
+        error_page 502 503 504 /50x.html;
+        location = /50x.html {
+            root /usr/share/nginx/html;
+        }
+
+        # Deny access to hidden files (e.g., .git)
+        location ~ /\. {
+            deny all;
+            access_log off;
+            log_not_found off;
+        }
+    }
+}
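A note on the upstream health settings: with max_fails=3 fail_timeout=5s, nginx considers a backend unavailable for 5 seconds once three proxied attempts to it fail within that window, and least_conn then routes traffic to the remaining servers.
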
@@ -31,7 +31,7 @@ def stream_to_speakers() -> None:
 
     with openai.audio.speech.with_streaming_response.create(
         model="kokoro",
-        voice="af_bella",
+        voice="0.100 * af + 0.300 * am_adam + 0.400 * am_michael + 0.100 * bf_emma + 0.100 * bm_lewis",
         response_format="pcm",  # similar to WAV, but without a header chunk at the start.
         input="""I see skies of blue and clouds of white
             The bright blessed days, the dark sacred nights
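The five weights in this formula sum to exactly 1.000 (0.100 + 0.300 + 0.400 + 0.100 + 0.100), so the normalize=True division applied on the server's weighted path leaves the combined embedding's scale unchanged.
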
Binary file not shown.
setup.sh (new executable file, 31 lines)
@@ -0,0 +1,31 @@
+#!/bin/bash
+
+# Ensure models directory exists
+mkdir -p api/src/models
+
+# Function to download a file
+download_file() {
+    local url="$1"
+    local filename=$(basename "$url")
+    echo "Downloading $filename..."
+    curl -L "$url" -o "api/src/models/$filename"
+}
+
+# Default PTH model if no arguments provided
+DEFAULT_MODELS=(
+    "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth"
+)
+
+# Use provided models or default
+if [ $# -gt 0 ]; then
+    MODELS=("$@")
+else
+    MODELS=("${DEFAULT_MODELS[@]}")
+fi
+
+# Download all models
+for model in "${MODELS[@]}"; do
+    download_file "$model"
+done
+
+echo "PyTorch model download complete!"