Mirror of https://github.com/remsky/Kokoro-FastAPI.git, synced 2025-04-13 09:39:17 +00:00
Revert "Adds support for creating weighted voice combinations"
This commit is contained in:
parent d5709097e2
commit f11a6b3e2b
7 changed files with 35 additions and 300 deletions
@@ -182,106 +182,6 @@ class VoiceManager:
         except Exception:
             return False
 
-    async def create_weighted_voice(
-        self,
-        formula: str,
-        normalize: bool = False,
-        device: str = "cpu",
-    ) -> str:
-        """
-        Parse the voice formula string (e.g. '0.3 * voiceA + 0.5 * voiceB')
-        and return a combined torch.Tensor representing the weighted sum
-        of the given voices.
-
-        Args:
-            formula: Weighted voice formula string.
-            voice_manager: A class that has a `load_voice(voice_name, device)` -> Tensor
-            device: 'cpu' or 'cuda' for the final tensor.
-            normalize: If True, divide the final result by the sum of the weights
-                so the total "magnitude" remains consistent.
-
-        Returns:
-            A torch.Tensor containing the combined voice embedding.
-        """
-        pairs = self.parse_voice_formula(formula)  # [(weight, voiceName), (weight, voiceName), ...]
-
-        # Validate the pairs
-        for weight, voice_name in pairs:
-            if weight <= 0:
-                raise ValueError(f"Invalid weight {weight} for voice {voice_name}.")
-
-        if not pairs:
-            raise ValueError("No valid weighted voices found in formula.")
-
-        # Keep track of total weight if we plan to normalize.
-        total_weight = 0.0
-        weighted_sum = None
-        combined_name = ""
-
-        for weight, voice_name in pairs:
-            # 1) Load each base voice from your manager/service
-            base_voice = await self.load_voice(voice_name, device=device)
-
-            # 3) Combine the base voices using the weights
-            if combined_name == "":
-                combined_name = voice_name
-            else:
-                combined_name += f"+{voice_name}"
-
-            # 2) Multiply by weight and accumulate
-            if weighted_sum is None:
-                # Clone so we don't modify the base voice in memory
-                weighted_sum = base_voice.clone() * weight
-            else:
-                weighted_sum += (base_voice * weight)
-
-            total_weight += weight
-
-        if weighted_sum is None:
-            raise ValueError("No voices were combined. Check the formula syntax.")
-
-        # Optional normalization
-        if normalize and total_weight != 0.0:
-            weighted_sum /= total_weight
-
-        if settings.allow_local_voice_saving:
-            # Save to disk
-            api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
-            voices_dir = os.path.join(api_dir, settings.voices_dir)
-            os.makedirs(voices_dir, exist_ok=True)
-            combined_path = os.path.join(voices_dir, f"{formula}.pt")
-            try:
-                torch.save(weighted_sum, combined_path)
-            except Exception as e:
-                logger.warning(f"Failed to save combined voice: {e}")
-                # Continue without saving - will be combined on-the-fly when needed
-
-        return combined_name
-
-    def parse_voice_formula(self, formula: str) -> List[tuple[float, str]]:
-        """
-        Parse the voice formula string (e.g. '0.3 * voiceA + 0.5 * voiceB')
-        and return a list of (weight, voiceName) pairs.
-
-        Args:
-            formula: Weighted voice formula string.
-
-        Returns:
-            List of (weight, voiceName) pairs.
-        """
-        pairs = []
-        parts = formula.split('+')
-        for part in parts:
-            weight, voice_name = part.strip().split('*')
-            pairs.append((float(weight), voice_name.strip()))
-        return pairs
-
     @property
     def cache_info(self) -> Dict[str, int]:
         """Get cache statistics.
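For reference, the reverted feature parsed formulas like '0.3 * voiceA + 0.5 * voiceB' into (weight, name) pairs and accumulated a weighted sum of the voice tensors, optionally normalizing by the total weight. A minimal standalone sketch of that logic, assuming dummy tensors in place of real voice packs (the helper names and data below are illustrative only, not the repository's API after this revert):

# Sketch of the reverted weighted-voice logic; names and data are illustrative.
from typing import List, Tuple

import torch


def parse_voice_formula(formula: str) -> List[Tuple[float, str]]:
    """Parse '0.3 * voiceA + 0.5 * voiceB' into [(0.3, 'voiceA'), (0.5, 'voiceB')]."""
    pairs = []
    for part in formula.split("+"):
        weight, voice_name = part.strip().split("*")
        pairs.append((float(weight), voice_name.strip()))
    return pairs


def combine_voices(formula: str, voices: dict, normalize: bool = False) -> torch.Tensor:
    """Weighted sum of pre-loaded voice tensors, mirroring the removed create_weighted_voice."""
    weighted_sum, total_weight = None, 0.0
    for weight, name in parse_voice_formula(formula):
        if weight <= 0:
            raise ValueError(f"Invalid weight {weight} for voice {name}.")
        base = voices[name]
        weighted_sum = base * weight if weighted_sum is None else weighted_sum + base * weight
        total_weight += weight
    if weighted_sum is None:
        raise ValueError("No voices were combined. Check the formula syntax.")
    if normalize and total_weight != 0.0:
        weighted_sum = weighted_sum / total_weight
    return weighted_sum


# Dummy 4-dim "embeddings" standing in for real voice packs:
voices = {"voiceA": torch.ones(4), "voiceB": torch.full((4,), 2.0)}
print(combine_voices("0.3 * voiceA + 0.5 * voiceB", voices, normalize=True))
# tensor([1.6250, 1.6250, 1.6250, 1.6250])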
@@ -2,7 +2,7 @@ import json
 import os
 from typing import AsyncGenerator, Dict, List, Union
 
-from fastapi import APIRouter, Header, HTTPException, Request, Response
+from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
 from fastapi.responses import StreamingResponse
 from loguru import logger
 
@@ -112,17 +112,7 @@ async def stream_audio_chunks(
     client_request: Request
 ) -> AsyncGenerator[bytes, None]:
     """Stream audio chunks as they're generated with client disconnect handling"""
-    # Check if 'request.voice' is a weighted formula (contains '*')
-    if '*' in request.voice:
-        # Weighted formula path
-        voice_to_use = await tts_service._voice_manager.create_weighted_voice(
-            formula=request.voice,
-            normalize=True
-        )
-
-    else:
-        # Normal single or multi-voice path
-        voice_to_use = await process_voices(request.voice, tts_service)
+    voice_to_use = await process_voices(request.voice, tts_service)
 
     try:
         async for chunk in tts_service.generate_audio_stream(
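After this revert, the streaming path resolves request.voice only through process_voices, so clients pass a plain voice name (or, judging by the '+' naming used elsewhere in the codebase, a 'voiceA+voiceB' combination) rather than a weighted formula. A client-side sketch using the OpenAI SDK, assuming a local server on port 8880 exposing the OpenAI-compatible /v1 routes; the base_url, api_key placeholder, and output path are assumptions:

# Client-side streaming sketch; base_url, api_key, and output path are assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8880/v1", api_key="not-needed")

with client.audio.speech.with_streaming_response.create(
    model="kokoro",
    voice="af_bella",          # a single named voice; weighted formulas are no longer accepted
    response_format="pcm",     # raw PCM chunks, as in the repository's speaker-streaming example
    input="Streaming after the revert uses plain voice names.",
) as response:
    with open("output.pcm", "wb") as f:
        for chunk in response.iter_bytes():
            f.write(chunk)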
@@ -169,7 +159,10 @@ async def create_speech(
         # Get global service instance
         tts_service = await get_tts_service()
 
-        # Set content type based on format
+        # Process voice combination and validate
+        voice_to_use = await process_voices(request.voice, tts_service)
+
+        # Set content type based on format
         content_type = {
             "mp3": "audio/mpeg",
             "opus": "audio/opus",
@@ -216,27 +209,13 @@ async def create_speech(
                 },
             )
         else:
-            # Check if 'request.voice' is a weighted formula (contains '*')
-            if '*' in request.voice:
-                # Weighted formula path
-                print("Weighted formula path")
-                voice_to_use = await tts_service._voice_manager.create_weighted_voice(
-                    formula=request.voice,
-                    normalize=True
-                )
-                print(voice_to_use)
-            else:
-                # Normal single or multi-voice path
-                print("Normal single or multi-voice path")
-                # Otherwise, handle normal single or multi-voice logic
-                voice_to_use = await process_voices(request.voice, tts_service)
-            # Generate complete audio using public interface
-            audio, _ = await tts_service.generate_audio(
-                text=request.input,
-                voice=voice_to_use,
-                speed=request.speed,
-                stitch_long_output=True
-            )
+            # Generate complete audio using public interface
+            audio, _ = await tts_service.generate_audio(
+                text=request.input,
+                voice=voice_to_use,
+                speed=request.speed,
+                stitch_long_output=True
+            )
 
             # Convert to requested format
             content = await AudioService.convert_audio(
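The non-streaming branch above now always goes through process_voices and then generate_audio. A minimal request sketch against the server, assuming the OpenAI-compatible /v1/audio/speech route and port 8880 from the compose file below; the route, port, and output filename are assumptions:

# Plain HTTP sketch of the simplified create_speech path; route, port, and filename are assumptions.
import requests

resp = requests.post(
    "http://localhost:8880/v1/audio/speech",
    json={
        "model": "kokoro",
        "input": "Hello from the simplified voice path.",
        "voice": "af_bella",       # plain voice name; '0.3 * voiceA + ...' formulas were reverted
        "response_format": "mp3",
        "speed": 1.0,
    },
)
resp.raise_for_status()
with open("speech.mp3", "wb") as f:
    f.write(resp.content)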
@@ -1,6 +1,7 @@
-name: InstaVoice
+name: kokoro-tts
 services:
-  server1:
+  kokoro-tts:
+    # image: ghcr.io/remsky/kokoro-fastapi-gpu:v0.1.0
     build:
       context: ../..
       dockerfile: docker/gpu/Dockerfile
@@ -18,59 +19,23 @@ services:
         reservations:
           devices:
             - driver: nvidia
-              count: all
+              count: 1
               capabilities: [gpu]
 
-  server2:
-    build:
-      context: ../..
-      dockerfile: docker/gpu/Dockerfile
-    volumes:
-      - ../../api:/app/api
-    ports:
-      - "8880:8880"
-    environment:
-      - PYTHONPATH=/app:/app/api
-      - USE_GPU=true
-      - USE_ONNX=false
-      - PYTHONUNBUFFERED=1
-    deploy:
-      resources:
-        reservations:
-          devices:
-            - driver: nvidia
-              count: all
-              capabilities: [gpu]
-
-  server3:
-    build:
-      context: ../..
-      dockerfile: docker/gpu/Dockerfile
-    volumes:
-      - ../../api:/app/api
-    ports:
-      - "8880:8880"
-    environment:
-      - PYTHONPATH=/app:/app/api
-      - USE_GPU=true
-      - USE_ONNX=false
-      - PYTHONUNBUFFERED=1
-    deploy:
-      resources:
-        reservations:
-          devices:
-            - driver: nvidia
-              count: all
-              capabilities: [gpu]
-
-  nginx:
-    image: nginx:alpine
-    ports:
-      - "80:80"  # Expose port 80 on the host machine
-    volumes:
-      - ./nginx.conf:/etc/nginx/nginx.conf  # Load custom NGINX configuration
-    depends_on:
-      - server3
-      - server1
-      - server2
+  # # Gradio UI service
+  # gradio-ui:
+  #   image: ghcr.io/remsky/kokoro-fastapi-ui:v0.1.0
+  #   # Uncomment below to build from source instead of using the released image
+  #   # build:
+  #   #   context: ../../ui
+  #   ports:
+  #     - "7860:7860"
+  #   volumes:
+  #     - ../../ui/data:/app/ui/data
+  #     - ../../ui/app.py:/app/app.py  # Mount app.py for hot reload
+  #   environment:
+  #     - GRADIO_WATCH=1  # Enable hot reloading
+  #     - PYTHONUNBUFFERED=1  # Ensure Python output is not buffered
+  #     - DISABLE_LOCAL_SAVING=false  # Set to 'true' to disable local saving and hide file view
+  #     - API_HOST=kokoro-tts  # Set TTS service URL
+  #     - API_PORT=8880  # Set TTS service PORT
@@ -1,78 +0,0 @@
-user nginx;
-worker_processes auto;  # Automatically adjust worker processes based on available CPUs
-
-events {
-    worker_connections 1024;  # Maximum simultaneous connections per worker
-    use epoll;  # Use efficient event handling for Linux
-}
-
-http {
-    # Basic security headers
-    add_header X-Frame-Options SAMEORIGIN always;  # Prevent clickjacking
-    add_header X-Content-Type-Options nosniff always;  # Prevent MIME-type sniffing
-    add_header X-XSS-Protection "1; mode=block" always;  # Enable XSS protection in browsers
-    add_header Strict-Transport-Security "max-age=31536000; includeSubDomains" always;  # Enforce HTTPS
-    add_header Content-Security-Policy "default-src 'self';" always;  # Restrict resource loading to same origin
-
-    # Timeouts
-    sendfile on;  # Enable sendfile for efficient file serving
-    tcp_nopush on;  # Reduce packet overhead
-    tcp_nodelay on;  # Minimize latency
-    keepalive_timeout 65;  # Keep connections alive for 65 seconds
-    client_max_body_size 10m;  # Limit request body size to 10MB
-    client_body_timeout 12;  # Timeout for client body read
-    client_header_timeout 12;  # Timeout for client header read
-
-    # Compression
-    gzip on;  # Enable gzip compression
-    gzip_disable "msie6";  # Disable gzip for old browsers
-    gzip_vary on;  # Add "Vary: Accept-Encoding" header
-    gzip_proxied any;  # Enable gzip for proxied requests
-    gzip_comp_level 6;  # Compression level
-    gzip_types text/plain text/css application/json application/javascript text/xml application/xml application/xml+rss text/javascript;
-
-    # Load balancing upstream
-    upstream backend {
-        least_conn;  # Use least connections load balancing strategy
-        server server1:8880 max_fails=3 fail_timeout=5s;  # Add health check for backend servers
-        # Uncomment additional servers for scaling:
-        server server2:8880 max_fails=3 fail_timeout=5s;
-        server server3:8880 max_fails=3 fail_timeout=5s;
-    }
-
-    server {
-        listen 80;
-
-        # Redirect HTTP to HTTPS (optional)
-        # Uncomment the lines below if SSL is configured:
-        # listen 443 ssl;
-        # ssl_certificate /path/to/certificate.crt;
-        # ssl_certificate_key /path/to/private.key;
-
-        location / {
-            proxy_pass http://backend;  # Proxy traffic to the backend servers
-            proxy_http_version 1.1;  # Use HTTP/1.1 for persistent connections
-            proxy_set_header Upgrade $http_upgrade;
-            proxy_set_header Connection "upgrade";
-            proxy_set_header Host $host;
-            proxy_set_header X-Forwarded-For $remote_addr;  # Forward client IP
-            proxy_cache_bypass $http_upgrade;
-            proxy_read_timeout 60s;  # Adjust read timeout for backend
-            proxy_connect_timeout 60s;  # Adjust connection timeout for backend
-            proxy_send_timeout 60s;  # Adjust send timeout for backend
-        }
-
-        # Custom error pages
-        error_page 502 503 504 /50x.html;
-        location = /50x.html {
-            root /usr/share/nginx/html;
-        }
-
-        # Deny access to hidden files (e.g., .git)
-        location ~ /\. {
-            deny all;
-            access_log off;
-            log_not_found off;
-        }
-    }
-}
@@ -31,7 +31,7 @@ def stream_to_speakers() -> None:
 
     with openai.audio.speech.with_streaming_response.create(
         model="kokoro",
-        voice="0.100 * af + 0.300 * am_adam + 0.400 * am_michael + 0.100 * bf_emma + 0.100 * bm_lewis ",
+        voice="af_bella",
         response_format="pcm",  # similar to WAV, but without a header chunk at the start.
         input="""I see skies of blue and clouds of white
             The bright blessed days, the dark sacred nights
Binary file not shown.
setup.sh (31 lines deleted)
@@ -1,31 +0,0 @@
-#!/bin/bash
-
-# Ensure models directory exists
-mkdir -p api/src/models
-
-# Function to download a file
-download_file() {
-    local url="$1"
-    local filename=$(basename "$url")
-    echo "Downloading $filename..."
-    curl -L "$url" -o "api/src/models/$filename"
-}
-
-# Default PTH model if no arguments provided
-DEFAULT_MODELS=(
-    "https://github.com/remsky/Kokoro-FastAPI/releases/download/v0.1.0/kokoro-v0_19.pth"
-)
-
-# Use provided models or default
-if [ $# -gt 0 ]; then
-    MODELS=("$@")
-else
-    MODELS=("${DEFAULT_MODELS[@]}")
-fi
-
-# Download all models
-for model in "${MODELS[@]}"; do
-    download_file "$model"
-done
-
-echo "PyTorch model download complete!"