Mirror of https://github.com/remsky/Kokoro-FastAPI.git (synced 2025-08-05 16:48:53 +00:00)
Adds the ability to subtract voices

commit aa403f2070 (parent f2c5bc1b71)
5 changed files with 125 additions and 113 deletions
@@ -29,7 +29,8 @@ class Settings(BaseSettings):
     target_min_tokens: int = 175  # Target minimum tokens per chunk
     target_max_tokens: int = 250  # Target maximum tokens per chunk
     absolute_max_tokens: int = 450  # Absolute maximum tokens per chunk
-    advanced_text_normalization: bool = True  # Preproesses the text before misiki which leads
+    advanced_text_normalization: bool = True  # Preproesses the text before misiki
+    voice_weight_normalization: bool = True  # Normalize the voice weights so they add up to 1

     gap_trim_ms: int = 1  # Base amount to trim from streaming chunk ends in milliseconds
     dynamic_gap_trim_padding_ms: int = 410  # Padding to add to dynamic gap trim
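The new voice_weight_normalization setting controls whether per-voice weights are rescaled to sum to 1 before voices are combined. As a minimal sketch of the arithmetic only (the helper name normalize_weights is hypothetical, not part of the codebase):

from typing import List

def normalize_weights(weights: List[float], enabled: bool = True) -> List[float]:
    """Rescale weights so they sum to 1 when enabled; pass them through otherwise."""
    total = sum(weights) if enabled else 1
    return [w / total for w in weights]

print(normalize_weights([2.0, 1.0]))         # [0.666..., 0.333...] - normalized
print(normalize_weights([2.0, 1.0], False))  # [2.0, 1.0] - raw weights kept

So "voice1(2)+voice2(1)" mixes at 2/3 and 1/3 with the flag on, and at the raw 2x and 1x scales with it off.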
@@ -1,20 +1,24 @@
+import base64
+import json
+import os
 import re
-from typing import List, Union, AsyncGenerator, Tuple
+from pathlib import Path
+from typing import AsyncGenerator, List, Tuple, Union

 import numpy as np
 import torch
 from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
-from fastapi.responses import JSONResponse, StreamingResponse, FileResponse
+from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
 from kokoro import KPipeline
 from loguru import logger

-from ..inference.base import AudioChunk
 from ..core.config import settings
+from ..inference.base import AudioChunk
 from ..services.audio import AudioNormalizer, AudioService
 from ..services.streaming_audio_writer import StreamingAudioWriter
+from ..services.temp_manager import TempFileWriter
 from ..services.text_processing import smart_split
 from ..services.tts_service import TTSService
-from ..services.temp_manager import TempFileWriter
 from ..structures import CaptionedSpeechRequest, CaptionedSpeechResponse, WordTimestamp
 from ..structures.custom_responses import JSONStreamingResponse
 from ..structures.text_schemas import (
@@ -22,12 +26,7 @@ from ..structures.text_schemas import (
     PhonemeRequest,
     PhonemeResponse,
 )
-from .openai_compatible import process_voices, stream_audio_chunks
-import json
-import os
-import base64
-from pathlib import Path
-
+from .openai_compatible import process_and_validate_voices, stream_audio_chunks

 router = APIRouter(tags=["text processing"])

@@ -169,7 +168,7 @@ async def create_captioned_speech(
     try:
         # model_name = get_model_name(request.model)
         tts_service = await get_tts_service()
-        voice_name = await process_voices(request.voice, tts_service)
+        voice_name = await process_and_validate_voices(request.voice, tts_service)

         # Set content type based on format
         content_type = {
@@ -5,20 +5,19 @@ import json
 import os
 import re
 import tempfile
-from typing import AsyncGenerator, Dict, List, Union, Tuple
+from typing import AsyncGenerator, Dict, List, Tuple, Union
-from urllib import response
-import numpy as np

 import aiofiles
-
-from structures.schemas import CaptionedSpeechRequest
+import numpy as np
 import torch
 from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
 from fastapi.responses import FileResponse, StreamingResponse
 from loguru import logger
+from structures.schemas import CaptionedSpeechRequest

-from ..inference.base import AudioChunk
 from ..core.config import settings
+from ..inference.base import AudioChunk
 from ..services.audio import AudioService
 from ..services.tts_service import TTSService
 from ..structures import OpenAISpeechRequest
@@ -80,7 +79,7 @@ def get_model_name(model: str) -> str:
     return base_name + ".pth"


-async def process_voices(
+async def process_and_validate_voices(
     voice_input: Union[str, List[str]], tts_service: TTSService
 ) -> str:
     """Process voice input, handling both string and list formats
@@ -88,53 +87,57 @@ async def process_voices(
     Returns:
         Voice name to use (with weights if specified)
     """
     voices = []
     # Convert input to list of voices
     if isinstance(voice_input, str):
-        # Check if it's an OpenAI voice name
-        mapped_voice = _openai_mappings["voices"].get(voice_input)
-        if mapped_voice:
-            voice_input = mapped_voice
-        # Split on + but preserve any parentheses
-        voices = []
-        for part in voice_input.split("+"):
-            part = part.strip()
-            if not part:
-                continue
-            # Extract voice name without weight
-            voice_name = part.split("(")[0].strip()
-            # Check if it's a valid voice
-            available_voices = await tts_service.list_voices()
-            if voice_name not in available_voices:
-                raise ValueError(
-                    f"Voice '{voice_name}' not found. Available voices: {', '.join(sorted(available_voices))}"
-                )
-            voices.append(part)
+        voice_input=voice_input.replace(" ","").strip()
+
+        if voice_input[-1] in "+-" or voice_input[0] in "+-":
+            raise ValueError(
+                f"Voice combination contains empty combine items"
+            )
+
+        if re.search(r"[+-]{2,}", voice_input) is not None:
+            raise ValueError(
+                f"Voice combination contains empty combine items"
+            )
+        voices = re.split(r"([-+])", voice_input)
     else:
-        # For list input, map each voice if it's an OpenAI voice name
-        voices = []
-        for v in voice_input:
-            mapped = _openai_mappings["voices"].get(v, v)
-            voice_name = mapped.split("(")[0].strip()
-            # Check if it's a valid voice
-            available_voices = await tts_service.list_voices()
-            if voice_name not in available_voices:
-                raise ValueError(
-                    f"Voice '{voice_name}' not found. Available voices: {', '.join(sorted(available_voices))}"
+        voices = [[item,"+"] for item in voice_input][:-1]
+
+    available_voices = await tts_service.list_voices()
+
+    for voice_index in range(0,len(voices), 2):
+
+        mapped_voice = voices[voice_index].split("(")
+        mapped_voice = list(map(str.strip, mapped_voice))
+
+        if len(mapped_voice) > 2:
+            raise ValueError(
+                f"Voice '{voices[voice_index]}' contains too many weight items"
+            )
-            voices.append(mapped)
+
+        if mapped_voice.count(")") > 1:
+            raise ValueError(
+                f"Voice '{voices[voice_index]}' contains too many weight items"
+            )
+
+        mapped_voice[0] = _openai_mappings["voices"].get(mapped_voice[0], mapped_voice[0])
-
-    if not voices:
-        raise ValueError("No voices provided")
-
-    # For multiple voices, combine them with +
-    return "+".join(voices)
+
+        if mapped_voice[0] not in available_voices:
+            raise ValueError(
+                f"Voice '{mapped_voice[0]}' not found. Available voices: {', '.join(sorted(available_voices))}"
+            )
+
+        voices[voice_index] = "(".join(mapped_voice)
+
+    return "".join(voices)


 async def stream_audio_chunks(
     tts_service: TTSService, request: Union[OpenAISpeechRequest,CaptionedSpeechRequest], client_request: Request
 ) -> AsyncGenerator[AudioChunk, None]:
     """Stream audio chunks as they're generated with client disconnect handling"""
-    voice_name = await process_voices(request.voice, tts_service)
+    voice_name = await process_and_validate_voices(request.voice, tts_service)

     unique_properties={
         "return_timestamps":False
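Because re.split is given a capturing group, the +/- operators are kept in the result list: voice names land at even indices and operators at odd indices, which is what the validation loop above steps over two entries at a time. A quick sketch of that standard-library behavior (the voice names are illustrative):

import re

# Splitting on a capturing group keeps the matched delimiters in the output.
parts = re.split(r"([-+])", "af_bella(0.7)+af_sky-af_nicole(0.3)")
print(parts)  # ['af_bella(0.7)', '+', 'af_sky', '-', 'af_nicole(0.3)']

# The new guards reject empty combine items such as doubled operators:
print(re.search(r"[+-]{2,}", "af_bella++af_sky") is not None)  # True

Joining the validated list back with "".join(voices) reproduces the combination string with any OpenAI voice aliases mapped to real voice names.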
@@ -191,7 +194,7 @@ async def create_speech(
     try:
         # model_name = get_model_name(request.model)
         tts_service = await get_tts_service()
-        voice_name = await process_voices(request.voice, tts_service)
+        voice_name = await process_and_validate_voices(request.voice, tts_service)

         # Set content type based on format
         content_type = {
@@ -2,24 +2,26 @@

 import asyncio
 import os
+import re
 import tempfile
 import time
 from typing import AsyncGenerator, List, Optional, Tuple, Union

-from ..inference.base import AudioChunk
 import numpy as np
 import torch
 from kokoro import KPipeline
 from loguru import logger

 from ..core.config import settings
+from ..inference.base import AudioChunk
 from ..inference.kokoro_v1 import KokoroV1
 from ..inference.model_manager import get_manager as get_model_manager
 from ..inference.voice_manager import get_manager as get_voice_manager
+from ..structures.schemas import NormalizationOptions
 from .audio import AudioNormalizer, AudioService
 from .text_processing import tokenize
 from .text_processing.text_processor import process_text_chunk, smart_split
-from ..structures.schemas import NormalizationOptions


 class TTSService:
     """Text-to-speech service."""
@@ -163,7 +165,15 @@ class TTSService:
         except Exception as e:
             logger.error(f"Failed to process tokens: {str(e)}")

-    async def _get_voice_path(self, voice: str) -> Tuple[str, str]:
+    async def _load_voice_from_path(self, path: str, weight: float):
+        # Check if the path is None and raise a ValueError if it is not
+        if not path:
+            raise ValueError(f"Voice not found at path: {path}")
+
+        logger.debug(f"Loading voice tensor from path: {path}")
+        return torch.load(path, map_location="cpu") * weight
+
+    async def _get_voices_path(self, voice: str) -> Tuple[str, str]:
         """Get voice path, handling combined voices.

         Args:
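The new _load_voice_from_path helper centralizes the load-and-scale step the old code repeated inline: each voice tensor is multiplied by its weight as it is loaded. A minimal standalone sketch of that step (the file path is illustrative):

import torch

# Save a dummy voice tensor, then load it back scaled by a weight,
# mirroring what _load_voice_from_path does for real voice files.
torch.save(torch.ones(4), "/tmp/demo_voice.pt")
weighted = torch.load("/tmp/demo_voice.pt", map_location="cpu") * 0.5
print(weighted)  # tensor([0.5000, 0.5000, 0.5000, 0.5000])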
@@ -176,64 +186,62 @@ class TTSService:
             RuntimeError: If voice not found
         """
         try:
-            # Check if it's a combined voice
-            if "+" in voice:
-                # Split on + but preserve any parentheses
-                voice_parts = []
-                weights = []
-                for part in voice.split("+"):
-                    part = part.strip()
-                    if not part:
-                        continue
-                    # Extract voice name and weight if present
-                    if "(" in part and ")" in part:
-                        voice_name = part.split("(")[0].strip()
-                        weight = float(part.split("(")[1].split(")")[0])
-                    else:
-                        voice_name = part
-                        weight = 1.0
-                    voice_parts.append(voice_name)
-                    weights.append(weight)
+            # Split the voice on + and - and ensure that they get added to the list eg: hi+bob = ["hi","+","bob"]
+            split_voice = re.split(r"([-+])", voice)

-                if len(voice_parts) < 2:
-                    raise RuntimeError(f"Invalid combined voice name: {voice}")
+            # If it is only once voice there is no point in loading it up, doing nothing with it, then saving it
+            if len(split_voice) == 1:

-                # Normalize weights to sum to 1
-                total_weight = sum(weights)
-                weights = [w / total_weight for w in weights]
+                # Since its a single voice the only time that the weight would matter is if voice_weight_normalization is off
+                if ("(" not in voice and ")" not in voice) or settings.voice_weight_normalization == True:

-                # Load and combine voices
-                voice_tensors = []
-                for v, w in zip(voice_parts, weights):
-                    path = await self._voice_manager.get_voice_path(v)
+                    path = await self._voice_manager.get_voice_path(voice)
                     if not path:
-                        raise RuntimeError(f"Voice not found: {v}")
-                    logger.debug(f"Loading voice tensor from: {path}")
-                    voice_tensor = torch.load(path, map_location="cpu")
-                    voice_tensors.append(voice_tensor * w)
+                        raise RuntimeError(f"Voice not found: {voice}")
+                    logger.debug(f"Using single voice path: {path}")
+                    return voice, path

-                # Sum the weighted voice tensors
-                logger.debug(
-                    f"Combining {len(voice_tensors)} voice tensors with weights {weights}"
-                )
-                combined = torch.sum(torch.stack(voice_tensors), dim=0)
+            total_weight = 0

+            for voice_index in range(0,len(split_voice),2):
+                voice_object = split_voice[voice_index]

-                # Save combined tensor
-                temp_dir = tempfile.gettempdir()
-                combined_path = os.path.join(temp_dir, f"{voice}.pt")
-                logger.debug(f"Saving combined voice to: {combined_path}")
-                torch.save(combined, combined_path)
+                if "(" in voice_object and ")" in voice_object:
+                    voice_name = voice_object.split("(")[0].strip()
+                    voice_weight = float(voice_object.split("(")[1].split(")")[0])
+                else:
+                    voice_name = voice_object
+                    voice_weight = 1

-                return voice, combined_path
-            else:
-                # Single voice
-                if "(" in voice and ")" in voice:
-                    voice = voice.split("(")[0].strip()
-                path = await self._voice_manager.get_voice_path(voice)
-                if not path:
-                    raise RuntimeError(f"Voice not found: {voice}")
-                logger.debug(f"Using single voice path: {path}")
-                return voice, path
+                total_weight += voice_weight
+                split_voice[voice_index] = (voice_name, voice_weight)

+            # If voice_weight_normalization is false prevent normalizing the weights by setting the total_weight to 1 so it divides each weight by 1
+            if settings.voice_weight_normalization == False:
+                total_weight = 1

+            # Load the first voice as the starting point for voices to be combined onto
+            path = await self._voice_manager.get_voice_path(split_voice[0][0])
+            combined_tensor = await self._load_voice_from_path(path, split_voice[0][1] / total_weight)

+            # Loop through each + or - in split_voice so they can be applied to combined voice
+            for operation_index in range(1,len(split_voice) - 1, 2):
+                # Get the voice path of the voice 1 index ahead of the operator
+                path = await self._voice_manager.get_voice_path(split_voice[operation_index+1][0])
+                voice_tensor = await self._load_voice_from_path(path, split_voice[operation_index + 1][1] / total_weight)

+                # Either add or subtract the voice from the current combined voice
+                if split_voice[operation_index] == "+":
+                    combined_tensor += voice_tensor
+                else:
+                    combined_tensor -= voice_tensor

+            # Save the new combined voice so it can be loaded latter
+            temp_dir = tempfile.gettempdir()
+            combined_path = os.path.join(temp_dir, f"{voice}.pt")
+            logger.debug(f"Saving combined voice to: {combined_path}")
+            torch.save(combined_tensor, combined_path)
+            return voice, combined_path
         except Exception as e:
             logger.error(f"Failed to get voice path: {e}")
             raise
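The combination loop walks the split list two entries at a time: the first voice tensor, scaled by its normalized weight, seeds the result, and each following operator applies the next weighted tensor with += or -=. A self-contained sketch of the same arithmetic on dummy tensors (names here are illustrative, not the service's API):

import torch

# Illustrative stand-ins for loaded voice embeddings.
voices = {"voice_a": torch.ones(4), "voice_b": torch.full((4,), 2.0)}

# Parsed form of "voice_a+voice_b" after weight extraction, weights defaulting to 1.
split_voice = [("voice_a", 1.0), "+", ("voice_b", 1.0)]
total_weight = sum(w for _, w in split_voice[::2])  # 2.0 when normalizing

# Seed with the first weighted voice, then apply each operator in turn.
combined = voices[split_voice[0][0]] * (split_voice[0][1] / total_weight)
for op_index in range(1, len(split_voice) - 1, 2):
    name, weight = split_voice[op_index + 1]
    tensor = voices[name] * (weight / total_weight)
    if split_voice[op_index] == "+":
        combined += tensor
    else:
        combined -= tensor

print(combined)  # tensor([1.5000, 1.5000, 1.5000, 1.5000])

With "voice_a-voice_b" the same loop yields -0.5 everywhere, which is what lets subtraction cancel characteristics two voices share.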
@@ -257,7 +265,7 @@ class TTSService:
             backend = self.model_manager.get_backend()

             # Get voice path, handling combined voices
-            voice_name, voice_path = await self._get_voice_path(voice)
+            voice_name, voice_path = await self._get_voices_path(voice)
             logger.debug(f"Using voice path: {voice_path}")

             # Use provided lang_code or determine from voice name
@@ -362,6 +370,7 @@ class TTSService:
         Returns:
             Combined voice tensor
         """
+
         return await self._voice_manager.combine_voices(voices)

     async def list_voices(self) -> List[str]:
@@ -390,7 +399,7 @@ class TTSService:
         try:
             # Get backend and voice path
             backend = self.model_manager.get_backend()
-            voice_name, voice_path = await self._get_voice_path(voice)
+            voice_name, voice_path = await self._get_voices_path(voice)

             if isinstance(backend, KokoroV1):
                 # For Kokoro V1, use generate_from_tokens with raw phonemes
@@ -74,7 +74,7 @@ async def test_get_voice_path_single():
         mock_get_voice.return_value = voice_manager

         service = await TTSService.create("test_output")
-        name, path = await service._get_voice_path("voice1")
+        name, path = await service._get_voices_path("voice1")
         assert name == "voice1"
         assert path == "/path/to/voice1.pt"
         voice_manager.get_voice_path.assert_called_once_with("voice1")
@@ -100,7 +100,7 @@ async def test_get_voice_path_combined():
         mock_load.return_value = torch.ones(10)

         service = await TTSService.create("test_output")
-        name, path = await service._get_voice_path("voice1+voice2")
+        name, path = await service._get_voices_path("voice1+voice2")
         assert name == "voice1+voice2"
         assert path.endswith("voice1+voice2.pt")
         mock_save.assert_called_once()
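The updated tests exercise single voices and + combination; a - combination would follow the same save path. A hedged sketch of such a test, mirroring the combined-voice test above (not part of the commit; it assumes the same voice_manager, torch.load, and torch.save mocks as that test):

        # Illustrative only: subtraction variant of the combined-voice test,
        # relying on the same patched mocks as test_get_voice_path_combined.
        service = await TTSService.create("test_output")
        name, path = await service._get_voices_path("voice1-voice2")
        assert name == "voice1-voice2"
        assert path.endswith("voice1-voice2.pt")
        mock_save.assert_called_once()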