mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
Adds support for creating weighted voice combinations
Implements a new method to parse weighted voice formulas and generate combined audio outputs based on specified weights. This enhancement allows for more diverse audio generation by letting users specify multiple voices with respective weights, improving flexibility in voice management. Updates voice processing logic in relevant API routes to handle weighted formulas seamlessly.
This commit is contained in:
parent
3547d95ee6
commit
44c62467ae
4 changed files with 135 additions and 14 deletions
|
@ -182,6 +182,106 @@ class VoiceManager:
|
|||
except Exception:
|
||||
return False
|
||||
|
||||
async def create_weighted_voice(
    self,
    formula: str,
    normalize: bool = False,
    device: str = "cpu",
) -> str:
    """Create a combined voice from a weighted formula and return its name.

    Parses a formula such as ``'0.3 * voiceA + 0.5 * voiceB'``, loads each
    referenced voice, accumulates the weighted sum of their embeddings, and
    (when allowed by settings) saves the combined tensor to disk.

    Args:
        formula: Weighted voice formula string, e.g. ``'0.3 * voiceA + 0.5 * voiceB'``.
        normalize: If True, divide the final tensor by the sum of the weights
            so the total "magnitude" remains consistent.
        device: Device for the loaded tensors, ``'cpu'`` or ``'cuda'``.

    Returns:
        The combined voice name (e.g. ``'voiceA+voiceB'``) identifying the
        blended embedding.

    Raises:
        ValueError: If the formula yields no pairs or contains a
            non-positive weight.
    """
    # [(weight, voice_name), (weight, voice_name), ...]
    pairs = self.parse_voice_formula(formula)

    # Guard against an empty formula before any per-pair work.
    if not pairs:
        raise ValueError("No valid weighted voices found in formula.")

    # Validate all weights up front, before any (potentially expensive) loads.
    for weight, voice_name in pairs:
        if weight <= 0:
            raise ValueError(f"Invalid weight {weight} for voice {voice_name}.")

    # Track the total weight in case we normalize at the end.
    total_weight = 0.0
    weighted_sum = None
    name_parts = []

    for weight, voice_name in pairs:
        # 1) Load each base voice from the manager/service.
        base_voice = await self.load_voice(voice_name, device=device)

        # 2) Record the voice for the combined name.
        name_parts.append(voice_name)

        # 3) Multiply by weight and accumulate.
        if weighted_sum is None:
            # Clone so we don't modify the cached base voice in memory.
            weighted_sum = base_voice.clone() * weight
        else:
            weighted_sum += base_voice * weight

        total_weight += weight

    if weighted_sum is None:
        raise ValueError("No voices were combined. Check the formula syntax.")

    combined_name = "+".join(name_parts)

    # Optional normalization.
    if normalize and total_weight != 0.0:
        weighted_sum /= total_weight

    if settings.allow_local_voice_saving:
        # Save to disk under the combined name (the value we return), not the
        # raw formula: the formula contains '*' and spaces, which are unsafe
        # in filenames and would never match a later lookup by the returned
        # name.
        api_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
        voices_dir = os.path.join(api_dir, settings.voices_dir)
        os.makedirs(voices_dir, exist_ok=True)
        combined_path = os.path.join(voices_dir, f"{combined_name}.pt")
        try:
            torch.save(weighted_sum, combined_path)
        except Exception as e:
            logger.warning(f"Failed to save combined voice: {e}")
            # Continue without saving - will be combined on-the-fly when needed

    return combined_name
|
||||
|
||||
|
||||
|
||||
|
||||
def parse_voice_formula(self, formula: str) -> List[tuple[float, str]]:
    """Parse a weighted voice formula into (weight, voice_name) pairs.

    Args:
        formula: Weighted voice formula string,
            e.g. ``'0.3 * voiceA + 0.5 * voiceB'``.

    Returns:
        List of ``(weight, voice_name)`` pairs in formula order.

    Raises:
        ValueError: If a term is not of the form ``<weight> * <voice>``
            or the weight is not a valid number.
    """
    pairs = []
    for part in formula.split('+'):
        term = part.strip()
        if not term:
            # Tolerate stray '+' separators / trailing '+'.
            continue
        pieces = term.split('*')
        if len(pieces) != 2:
            raise ValueError(
                f"Invalid term {term!r} in voice formula; "
                "expected '<weight> * <voice>'."
            )
        weight_str, voice_name = pieces
        try:
            weight = float(weight_str)
        except ValueError:
            # Re-raise with the offending token for easier debugging.
            raise ValueError(
                f"Invalid weight {weight_str.strip()!r} in voice formula."
            ) from None
        pairs.append((weight, voice_name.strip()))
    return pairs
|
||||
|
||||
@property
|
||||
def cache_info(self) -> Dict[str, int]:
|
||||
"""Get cache statistics.
|
||||
|
|
|
@ -2,7 +2,7 @@ import json
|
|||
import os
|
||||
from typing import AsyncGenerator, Dict, List, Union
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
|
||||
from fastapi import APIRouter, Header, HTTPException, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from loguru import logger
|
||||
|
||||
|
@ -112,7 +112,17 @@ async def stream_audio_chunks(
|
|||
client_request: Request
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
"""Stream audio chunks as they're generated with client disconnect handling"""
|
||||
voice_to_use = await process_voices(request.voice, tts_service)
|
||||
# Check if 'request.voice' is a weighted formula (contains '*')
|
||||
if '*' in request.voice:
|
||||
# Weighted formula path
|
||||
voice_to_use = await tts_service._voice_manager.create_weighted_voice(
|
||||
formula=request.voice,
|
||||
normalize=True
|
||||
)
|
||||
|
||||
else:
|
||||
# Normal single or multi-voice path
|
||||
voice_to_use = await process_voices(request.voice, tts_service)
|
||||
|
||||
try:
|
||||
async for chunk in tts_service.generate_audio_stream(
|
||||
|
@ -159,10 +169,7 @@ async def create_speech(
|
|||
# Get global service instance
|
||||
tts_service = await get_tts_service()
|
||||
|
||||
# Process voice combination and validate
|
||||
voice_to_use = await process_voices(request.voice, tts_service)
|
||||
|
||||
# Set content type based on format
|
||||
# Set content type based on format
|
||||
content_type = {
|
||||
"mp3": "audio/mpeg",
|
||||
"opus": "audio/opus",
|
||||
|
@ -209,13 +216,27 @@ async def create_speech(
|
|||
},
|
||||
)
|
||||
else:
|
||||
# Generate complete audio using public interface
|
||||
audio, _ = await tts_service.generate_audio(
|
||||
text=request.input,
|
||||
voice=voice_to_use,
|
||||
speed=request.speed,
|
||||
stitch_long_output=True
|
||||
)
|
||||
# Check if 'request.voice' is a weighted formula (contains '*')
|
||||
if '*' in request.voice:
|
||||
# Weighted formula path
|
||||
print("Weighted formula path")
|
||||
voice_to_use = await tts_service._voice_manager.create_weighted_voice(
|
||||
formula=request.voice,
|
||||
normalize=True
|
||||
)
|
||||
print(voice_to_use)
|
||||
else:
|
||||
# Normal single or multi-voice path
|
||||
print("Normal single or multi-voice path")
|
||||
# Otherwise, handle normal single or multi-voice logic
|
||||
voice_to_use = await process_voices(request.voice, tts_service)
|
||||
# Generate complete audio using public interface
|
||||
audio, _ = await tts_service.generate_audio(
|
||||
text=request.input,
|
||||
voice=voice_to_use,
|
||||
speed=request.speed,
|
||||
stitch_long_output=True
|
||||
)
|
||||
|
||||
# Convert to requested format
|
||||
content = await AudioService.convert_audio(
|
||||
|
|
|
@ -31,7 +31,7 @@ def stream_to_speakers() -> None:
|
|||
|
||||
with openai.audio.speech.with_streaming_response.create(
|
||||
model="kokoro",
|
||||
voice="af_bella",
|
||||
voice="0.100 * af + 0.300 * am_adam + 0.400 * am_michael + 0.100 * bf_emma + 0.100 * bm_lewis ",
|
||||
response_format="pcm", # similar to WAV, but without a header chunk at the start.
|
||||
input="""I see skies of blue and clouds of white
|
||||
The bright blessed days, the dark sacred nights
|
||||
|
|
Binary file not shown.
Loading…
Add table
Reference in a new issue