Mirror of https://github.com/remsky/Kokoro-FastAPI.git (synced 2025-04-13 09:39:17 +00:00)
Merge pull request #145 from remsky/revert-92-v0.1.2-pre
Revert "Adds support for creating weighted voice combinations" (Implemented somewhat differently)
commit 8f86d60319
7 changed files with 35 additions and 300 deletions
@@ -182,106 +182,6 @@ class VoiceManager:
         except Exception:
             return False
 
-    async def create_weighted_voice(
-        self,
-        formula: str,
-        normalize: bool = False,
-        device: str = "cpu",
-    ) -> str:
-        """
-        Parse the voice formula string (e.g. '0.3 * voiceA + 0.5 * voiceB')
-        and return a combined torch.Tensor representing the weighted sum
-        of the given voices.
-
-        Args:
-            formula: Weighted voice formula string.
-            voice_manager: A class that has a `load_voice(voice_name, device)` -> Tensor
-            device: 'cpu' or 'cuda' for the final tensor.
-            normalize: If True, divide the final result by the sum of the weights
-                so the total "magnitude" remains consistent.
-
-        Returns:
-            A torch.Tensor containing the combined voice embedding.
-        """
-        pairs = self.parse_voice_formula(formula)  # [(weight, voiceName), (weight, voiceName), ...]
-
-        # Validate the pairs
-        for weight, voice_name in pairs:
-            if weight <= 0:
-                raise ValueError(f"Invalid weight {weight} for voice {voice_name}.")
-
-        if not pairs:
-            raise ValueError("No valid weighted voices found in formula.")
-
-        # Keep track of total weight if we plan to normalize.
-        total_weight = 0.0
-        weighted_sum = None
-        combined_name = ""
-
-        for weight, voice_name in pairs:
-            # 1) Load each base voice from your manager/service
-            base_voice = await self.load_voice(voice_name, device=device)
-
-            # 2) Extend the combined voice name
-            if combined_name == "":
-                combined_name = voice_name
-            else:
-                combined_name += f"+{voice_name}"
-
-            # 3) Multiply by weight and accumulate
-            if weighted_sum is None:
-                # Clone so we don't modify the base voice in memory
-                weighted_sum = base_voice.clone() * weight
-            else:
-                weighted_sum += (base_voice * weight)
-
-            total_weight += weight
-
-        if weighted_sum is None:
-            raise ValueError("No voices were combined. Check the formula syntax.")
-
-        # Optional normalization
-        if normalize and total_weight != 0.0:
-            weighted_sum /= total_weight
-
-        if settings.allow_local_voice_saving:
-            # Save to disk
-            api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
-            voices_dir = os.path.join(api_dir, settings.voices_dir)
-            os.makedirs(voices_dir, exist_ok=True)
-            combined_path = os.path.join(voices_dir, f"{formula}.pt")
-            try:
-                torch.save(weighted_sum, combined_path)
-            except Exception as e:
-                logger.warning(f"Failed to save combined voice: {e}")
-                # Continue without saving - will be combined on-the-fly when needed
-
-        return combined_name
-
-    def parse_voice_formula(self, formula: str) -> List[tuple[float, str]]:
-        """
-        Parse the voice formula string (e.g. '0.3 * voiceA + 0.5 * voiceB')
-        and return a list of (weight, voiceName) pairs.
-
-        Args:
-            formula: Weighted voice formula string.
-
-        Returns:
-            List of (weight, voiceName) pairs.
-        """
-        pairs = []
-        parts = formula.split('+')
-        for part in parts:
-            weight, voice_name = part.strip().split('*')
-            pairs.append((float(weight), voice_name.strip()))
-        return pairs
-
     @property
     def cache_info(self) -> Dict[str, int]:
         """Get cache statistics.
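For reference, the removed combination logic reduces to a weighted average of voice embedding tensors: parse_voice_formula('0.3 * voiceA + 0.5 * voiceB') yields [(0.3, 'voiceA'), (0.5, 'voiceB')], and those pairs drive a weighted sum. A minimal standalone sketch of the same idea, assuming each voice ships as a .pt tensor of identical shape (the paths and names below are illustrative, not the project's actual layout):

import torch

def combine_voices(weighted_voices, normalize=True):
    # weighted_voices: e.g. [(0.3, "voices/af_bella.pt"), (0.7, "voices/am_adam.pt")]
    weighted_sum, total_weight = None, 0.0
    for weight, path in weighted_voices:
        voice = torch.load(path, map_location="cpu")  # one embedding tensor per voice
        contribution = voice * weight
        weighted_sum = contribution if weighted_sum is None else weighted_sum + contribution
        total_weight += weight
    if normalize and total_weight != 0.0:
        weighted_sum /= total_weight  # keep the result on the same scale as a single voice
    return weighted_sum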
@@ -2,7 +2,7 @@ import json
 import os
 from typing import AsyncGenerator, Dict, List, Union
 
-from fastapi import APIRouter, Header, HTTPException, Request, Response
+from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
 from fastapi.responses import StreamingResponse
 from loguru import logger
 
@@ -112,17 +112,7 @@ async def stream_audio_chunks(
     client_request: Request
 ) -> AsyncGenerator[bytes, None]:
     """Stream audio chunks as they're generated with client disconnect handling"""
-    # Check if 'request.voice' is a weighted formula (contains '*')
-    if '*' in request.voice:
-        # Weighted formula path
-        voice_to_use = await tts_service._voice_manager.create_weighted_voice(
-            formula=request.voice,
-            normalize=True
-        )
-
-    else:
-        # Normal single or multi-voice path
-        voice_to_use = await process_voices(request.voice, tts_service)
+    voice_to_use = await process_voices(request.voice, tts_service)
 
     try:
         async for chunk in tts_service.generate_audio_stream(
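The docstring above mentions client disconnect handling. In FastAPI/Starlette that pattern usually looks like the generic sketch below (an illustration of the technique, not the project's exact implementation; the generator and names are assumed):

from typing import AsyncGenerator
from fastapi import Request

async def stream_with_disconnect_check(
    chunks: AsyncGenerator[bytes, None], client_request: Request
) -> AsyncGenerator[bytes, None]:
    # Stop producing audio as soon as the client goes away.
    async for chunk in chunks:
        if await client_request.is_disconnected():
            break
        yield chunk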
@@ -169,7 +159,10 @@ async def create_speech(
         # Get global service instance
         tts_service = await get_tts_service()
 
-        # Set content type based on format
+        # Process voice combination and validate
+        voice_to_use = await process_voices(request.voice, tts_service)
+
+        # Set content type based on format
         content_type = {
             "mp3": "audio/mpeg",
             "opus": "audio/opus",
@@ -216,27 +209,13 @@ async def create_speech(
                 },
             )
         else:
-            # Check if 'request.voice' is a weighted formula (contains '*')
-            if '*' in request.voice:
-                # Weighted formula path
-                print("Weighted formula path")
-                voice_to_use = await tts_service._voice_manager.create_weighted_voice(
-                    formula=request.voice,
-                    normalize=True
-                )
-                print(voice_to_use)
-            else:
-                # Normal single or multi-voice path
-                print("Normal single or multi-voice path")
-                # Otherwise, handle normal single or multi-voice logic
-                voice_to_use = await process_voices(request.voice, tts_service)
-            # Generate complete audio using public interface
-            audio, _ = await tts_service.generate_audio(
-                text=request.input,
-                voice=voice_to_use,
-                speed=request.speed,
-                stitch_long_output=True
-            )
+            # Generate complete audio using public interface
+            audio, _ = await tts_service.generate_audio(
+                text=request.input,
+                voice=voice_to_use,
+                speed=request.speed,
+                stitch_long_output=True
+            )
 
             # Convert to requested format
             content = await AudioService.convert_audio(
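With the weighted-formula branch reverted, combined voices go through process_voices, which the surrounding code treats as the multi-voice path. A hedged client-side sketch, assuming a local deployment on port 8880 (the port that appears elsewhere in this diff) exposing the OpenAI-compatible /v1/audio/speech route, with voice names taken from the example script below:

import requests

resp = requests.post(
    "http://localhost:8880/v1/audio/speech",  # assumed endpoint and port
    json={
        "model": "kokoro",
        "voice": "af_bella+am_adam",  # plus-joined combination rather than '0.5 * af_bella + 0.5 * am_adam'
        "input": "Hello from a combined voice.",
        "response_format": "mp3",
    },
)
resp.raise_for_status()
with open("combined.mp3", "wb") as f:
    f.write(resp.content)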
@@ -1,6 +1,7 @@
-name: InstaVoice
+name: kokoro-tts
 services:
-  server1:
+  kokoro-tts:
+    # image: ghcr.io/remsky/kokoro-fastapi-gpu:v0.1.0
     build:
       context: ../..
       dockerfile: docker/gpu/Dockerfile
@@ -18,59 +19,23 @@ services:
         reservations:
           devices:
             - driver: nvidia
-              count: all
+              count: 1
               capabilities: [gpu]
 
-  server2:
-    build:
-      context: ../..
-      dockerfile: docker/gpu/Dockerfile
-    volumes:
-      - ../../api:/app/api
-    ports:
-      - "8880:8880"
-    environment:
-      - PYTHONPATH=/app:/app/api
-      - USE_GPU=true
-      - USE_ONNX=false
-      - PYTHONUNBUFFERED=1
-    deploy:
-      resources:
-        reservations:
-          devices:
-            - driver: nvidia
-              count: all
-              capabilities: [gpu]
-
-  server3:
-    build:
-      context: ../..
-      dockerfile: docker/gpu/Dockerfile
-    volumes:
-      - ../../api:/app/api
-    ports:
-      - "8880:8880"
-    environment:
-      - PYTHONPATH=/app:/app/api
-      - USE_GPU=true
-      - USE_ONNX=false
-      - PYTHONUNBUFFERED=1
-    deploy:
-      resources:
-        reservations:
-          devices:
-            - driver: nvidia
-              count: all
-              capabilities: [gpu]
-
-  nginx:
-    image: nginx:alpine
-    ports:
-      - "80:80"  # Expose port 80 on the host machine
-    volumes:
-      - ./nginx.conf:/etc/nginx/nginx.conf  # Load custom NGINX configuration
-    depends_on:
-      - server3
-      - server1
-      - server2
+  # # Gradio UI service
+  # gradio-ui:
+  #   image: ghcr.io/remsky/kokoro-fastapi-ui:v0.1.0
+  #   # Uncomment below to build from source instead of using the released image
+  #   # build:
+  #   #   context: ../../ui
+  #   ports:
+  #     - "7860:7860"
+  #   volumes:
+  #     - ../../ui/data:/app/ui/data
+  #     - ../../ui/app.py:/app/app.py  # Mount app.py for hot reload
+  #   environment:
+  #     - GRADIO_WATCH=1  # Enable hot reloading
+  #     - PYTHONUNBUFFERED=1  # Ensure Python output is not buffered
+  #     - DISABLE_LOCAL_SAVING=false  # Set to 'true' to disable local saving and hide file view
+  #     - API_HOST=kokoro-tts  # Set TTS service URL
+  #     - API_PORT=8880  # Set TTS service PORT
@@ -1,78 +0,0 @@
-user nginx;
-worker_processes auto;  # Automatically adjust worker processes based on available CPUs
-
-events {
-    worker_connections 1024;  # Maximum simultaneous connections per worker
-    use epoll;                # Use efficient event handling for Linux
-}
-
-http {
-    # Basic security headers
-    add_header X-Frame-Options SAMEORIGIN always;       # Prevent clickjacking
-    add_header X-Content-Type-Options nosniff always;   # Prevent MIME-type sniffing
-    add_header X-XSS-Protection "1; mode=block" always; # Enable XSS protection in browsers
-    add_header Strict-Transport-Security "max-age=31536000; includeSubDomains" always;  # Enforce HTTPS
-    add_header Content-Security-Policy "default-src 'self';" always;  # Restrict resource loading to same origin
-
-    # Timeouts
-    sendfile on;               # Enable sendfile for efficient file serving
-    tcp_nopush on;             # Reduce packet overhead
-    tcp_nodelay on;            # Minimize latency
-    keepalive_timeout 65;      # Keep connections alive for 65 seconds
-    client_max_body_size 10m;  # Limit request body size to 10MB
-    client_body_timeout 12;    # Timeout for client body read
-    client_header_timeout 12;  # Timeout for client header read
-
-    # Compression
-    gzip on;               # Enable gzip compression
-    gzip_disable "msie6";  # Disable gzip for old browsers
-    gzip_vary on;          # Add "Vary: Accept-Encoding" header
-    gzip_proxied any;      # Enable gzip for proxied requests
-    gzip_comp_level 6;     # Compression level
-    gzip_types text/plain text/css application/json application/javascript text/xml application/xml application/xml+rss text/javascript;
-
-    # Load balancing upstream
-    upstream backend {
-        least_conn;  # Use least connections load balancing strategy
-        server server1:8880 max_fails=3 fail_timeout=5s;  # Add health check for backend servers
-        # Uncomment additional servers for scaling:
-        server server2:8880 max_fails=3 fail_timeout=5s;
-        server server3:8880 max_fails=3 fail_timeout=5s;
-    }
-
-    server {
-        listen 80;
-
-        # Redirect HTTP to HTTPS (optional)
-        # Uncomment the lines below if SSL is configured:
-        # listen 443 ssl;
-        # ssl_certificate /path/to/certificate.crt;
-        # ssl_certificate_key /path/to/private.key;
-
-        location / {
-            proxy_pass http://backend;  # Proxy traffic to the backend servers
-            proxy_http_version 1.1;     # Use HTTP/1.1 for persistent connections
-            proxy_set_header Upgrade $http_upgrade;
-            proxy_set_header Connection "upgrade";
-            proxy_set_header Host $host;
-            proxy_set_header X-Forwarded-For $remote_addr;  # Forward client IP
-            proxy_cache_bypass $http_upgrade;
-            proxy_read_timeout 60s;     # Adjust read timeout for backend
-            proxy_connect_timeout 60s;  # Adjust connection timeout for backend
-            proxy_send_timeout 60s;     # Adjust send timeout for backend
-        }
-
-        # Custom error pages
-        error_page 502 503 504 /50x.html;
-        location = /50x.html {
-            root /usr/share/nginx/html;
-        }
-
-        # Deny access to hidden files (e.g., .git)
-        location ~ /\. {
-            deny all;
-            access_log off;
-            log_not_found off;
-        }
-    }
-}
@@ -31,7 +31,7 @@ def stream_to_speakers() -> None:
 
     with openai.audio.speech.with_streaming_response.create(
         model="kokoro",
-        voice="0.100 * af + 0.300 * am_adam + 0.400 * am_michael + 0.100 * bf_emma + 0.100 * bm_lewis ",
+        voice="af_bella",
         response_format="pcm",  # similar to WAV, but without a header chunk at the start.
         input="""I see skies of blue and clouds of white
             The bright blessed days, the dark sacred nights
Binary file not shown.
setup.sh (31 deletions)
@@ -1,31 +0,0 @@
-#!/bin/bash
-
-# Ensure models directory exists
-mkdir -p api/src/models
-
-# Function to download a file
-download_file() {
-    local url="$1"
-    local filename=$(basename "$url")
-    echo "Downloading $filename..."
-    curl -L "$url" -o "api/src/models/$filename"
-}
-
-# Default PTH model if no arguments provided
-DEFAULT_MODELS=(
-    "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth"
-)
-
-# Use provided models or default
-if [ $# -gt 0 ]; then
-    MODELS=("$@")
-else
-    MODELS=("${DEFAULT_MODELS[@]}")
-fi
-
-# Download all models
-for model in "${MODELS[@]}"; do
-    download_file "$model"
-done
-
-echo "PyTorch model download complete!"