Mirror of https://github.com/remsky/Kokoro-FastAPI.git (synced 2025-08-05 16:48:53 +00:00)

Adds the ability to subtract voices

Commit aa403f2070 (parent f2c5bc1b71)
5 changed files with 125 additions and 113 deletions
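In practice, this commit extends the voice-combination syntax: "+" blends a weighted voice into the mix, and "-" now subtracts one. A minimal usage sketch (the default port 8880, the OpenAI-compatible /v1/audio/speech endpoint, and the specific voice names are assumptions from the project's README, not part of this diff):

import requests

# Blend two voices with explicit weights, then subtract a third.
# Weighted syntax is "name(weight)"; bare names default to weight 1.
response = requests.post(
    "http://localhost:8880/v1/audio/speech",
    json={
        "model": "kokoro",
        "input": "Hello world!",
        "voice": "af_bella(0.7)+af_sky(0.3)-am_adam(0.1)",
        "response_format": "mp3",
    },
)
with open("output.mp3", "wb") as f:
    f.write(response.content)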
@@ -29,7 +29,8 @@ class Settings(BaseSettings):
     target_min_tokens: int = 175  # Target minimum tokens per chunk
     target_max_tokens: int = 250  # Target maximum tokens per chunk
     absolute_max_tokens: int = 450  # Absolute maximum tokens per chunk
-    advanced_text_normalization: bool = True  # Preprocesses the text before misaki which leads
+    advanced_text_normalization: bool = True  # Preprocesses the text before misaki
+    voice_weight_normalization: bool = True  # Normalize the voice weights so they add up to 1
 
     gap_trim_ms: int = 1  # Base amount to trim from streaming chunk ends in milliseconds
     dynamic_gap_trim_padding_ms: int = 410  # Padding to add to dynamic gap trim
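The new voice_weight_normalization flag controls whether combination weights are rescaled to sum to 1 before the voice tensors are mixed. A quick sketch of the arithmetic (illustrative only, not code from this diff):

weights = [0.7, 0.3, 0.1]  # as written in "af_bella(0.7)+af_sky(0.3)-am_adam(0.1)"
total = sum(weights)       # 1.1
normalized = [w / total for w in weights]
print(normalized)          # [0.636..., 0.272..., 0.090...]
# With voice_weight_normalization = False, total_weight is forced to 1,
# so the raw weights 0.7 / 0.3 / 0.1 are applied unscaled.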
@@ -1,20 +1,24 @@
+import base64
+import json
+import os
 import re
-from typing import List, Union, AsyncGenerator, Tuple
+from pathlib import Path
+from typing import AsyncGenerator, List, Tuple, Union
 
 import numpy as np
 import torch
 from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
-from fastapi.responses import JSONResponse, StreamingResponse, FileResponse
+from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
 from kokoro import KPipeline
 from loguru import logger
 
-from ..inference.base import AudioChunk
 from ..core.config import settings
+from ..inference.base import AudioChunk
 from ..services.audio import AudioNormalizer, AudioService
 from ..services.streaming_audio_writer import StreamingAudioWriter
+from ..services.temp_manager import TempFileWriter
 from ..services.text_processing import smart_split
 from ..services.tts_service import TTSService
-from ..services.temp_manager import TempFileWriter
 from ..structures import CaptionedSpeechRequest, CaptionedSpeechResponse, WordTimestamp
 from ..structures.custom_responses import JSONStreamingResponse
 from ..structures.text_schemas import (
@@ -22,12 +26,7 @@ from ..structures.text_schemas import (
     PhonemeRequest,
     PhonemeResponse,
 )
-from .openai_compatible import process_voices, stream_audio_chunks
-import json
-import os
-import base64
-from pathlib import Path
-
+from .openai_compatible import process_and_validate_voices, stream_audio_chunks
 
 router = APIRouter(tags=["text processing"])
 
@@ -169,7 +168,7 @@ async def create_captioned_speech(
     try:
         # model_name = get_model_name(request.model)
         tts_service = await get_tts_service()
-        voice_name = await process_voices(request.voice, tts_service)
+        voice_name = await process_and_validate_voices(request.voice, tts_service)
 
         # Set content type based on format
         content_type = {
@@ -5,20 +5,19 @@ import json
 import os
 import re
 import tempfile
-from typing import AsyncGenerator, Dict, List, Union, Tuple
+from typing import AsyncGenerator, Dict, List, Tuple, Union
 from urllib import response
-import numpy as np
 
 import aiofiles
-from structures.schemas import CaptionedSpeechRequest
+import numpy as np
 import torch
 from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
 from fastapi.responses import FileResponse, StreamingResponse
 from loguru import logger
+from structures.schemas import CaptionedSpeechRequest
 
-from ..inference.base import AudioChunk
 from ..core.config import settings
+from ..inference.base import AudioChunk
 from ..services.audio import AudioService
 from ..services.tts_service import TTSService
 from ..structures import OpenAISpeechRequest
@@ -80,7 +79,7 @@ def get_model_name(model: str) -> str:
     return base_name + ".pth"
 
 
-async def process_voices(
+async def process_and_validate_voices(
     voice_input: Union[str, List[str]], tts_service: TTSService
 ) -> str:
     """Process voice input, handling both string and list formats
@@ -88,53 +87,57 @@
 
     Returns:
         Voice name to use (with weights if specified)
     """
+    voices = []
     # Convert input to list of voices
     if isinstance(voice_input, str):
-        # Check if it's an OpenAI voice name
-        mapped_voice = _openai_mappings["voices"].get(voice_input)
-        if mapped_voice:
-            voice_input = mapped_voice
-        # Split on + but preserve any parentheses
-        voices = []
-        for part in voice_input.split("+"):
-            part = part.strip()
-            if not part:
-                continue
-            # Extract voice name without weight
-            voice_name = part.split("(")[0].strip()
-            # Check if it's a valid voice
-            available_voices = await tts_service.list_voices()
-            if voice_name not in available_voices:
-                raise ValueError(
-                    f"Voice '{voice_name}' not found. Available voices: {', '.join(sorted(available_voices))}"
-                )
-            voices.append(part)
+        voice_input = voice_input.replace(" ", "").strip()
+
+        if voice_input[-1] in "+-" or voice_input[0] in "+-":
+            raise ValueError(
+                f"Voice combination contains empty combine items"
+            )
+
+        if re.search(r"[+-]{2,}", voice_input) is not None:
+            raise ValueError(
+                f"Voice combination contains empty combine items"
+            )
+        voices = re.split(r"([-+])", voice_input)
     else:
-        # For list input, map each voice if it's an OpenAI voice name
-        voices = []
-        for v in voice_input:
-            mapped = _openai_mappings["voices"].get(v, v)
-            voice_name = mapped.split("(")[0].strip()
-            # Check if it's a valid voice
-            available_voices = await tts_service.list_voices()
-            if voice_name not in available_voices:
-                raise ValueError(
-                    f"Voice '{voice_name}' not found. Available voices: {', '.join(sorted(available_voices))}"
-                )
-            voices.append(mapped)
-
-    if not voices:
-        raise ValueError("No voices provided")
-
-    # For multiple voices, combine them with +
-    return "+".join(voices)
+        voices = [[item, "+"] for item in voice_input][:-1]
+
+    available_voices = await tts_service.list_voices()
+
+    for voice_index in range(0, len(voices), 2):
+        mapped_voice = voices[voice_index].split("(")
+        mapped_voice = list(map(str.strip, mapped_voice))
+
+        if len(mapped_voice) > 2:
+            raise ValueError(
+                f"Voice '{voices[voice_index]}' contains too many weight items"
+            )
+
+        if mapped_voice.count(")") > 1:
+            raise ValueError(
+                f"Voice '{voices[voice_index]}' contains too many weight items"
+            )
+
+        mapped_voice[0] = _openai_mappings["voices"].get(mapped_voice[0], mapped_voice[0])
+
+        if mapped_voice[0] not in available_voices:
+            raise ValueError(
+                f"Voice '{mapped_voice[0]}' not found. Available voices: {', '.join(sorted(available_voices))}"
+            )
+
+        voices[voice_index] = "(".join(mapped_voice)
+
+    return "".join(voices)
 
 
 async def stream_audio_chunks(
     tts_service: TTSService, request: Union[OpenAISpeechRequest, CaptionedSpeechRequest], client_request: Request
 ) -> AsyncGenerator[AudioChunk, None]:
     """Stream audio chunks as they're generated with client disconnect handling"""
-    voice_name = await process_voices(request.voice, tts_service)
+    voice_name = await process_and_validate_voices(request.voice, tts_service)
 
     unique_properties = {
         "return_timestamps": False
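The validation above hinges on re.split with a capturing group, which keeps the "+"/"-" operators in the result: voice terms land at even indices and operators at odd ones. A standalone sketch of the same parsing (parse_combined_voice is a hypothetical helper mirroring the diff, not part of it):

import re

def parse_combined_voice(voice_input: str) -> list:
    voice_input = voice_input.replace(" ", "").strip()
    # A leading/trailing operator or a doubled operator implies an empty combine item.
    if voice_input[0] in "+-" or voice_input[-1] in "+-":
        raise ValueError("Voice combination contains empty combine items")
    if re.search(r"[+-]{2,}", voice_input):
        raise ValueError("Voice combination contains empty combine items")
    # The capturing group keeps the separators in the output list.
    return re.split(r"([-+])", voice_input)

print(parse_combined_voice("af_bella(0.7)+af_sky(0.3)-am_adam"))
# ['af_bella(0.7)', '+', 'af_sky(0.3)', '-', 'am_adam']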
@@ -191,7 +194,7 @@ async def create_speech(
     try:
         # model_name = get_model_name(request.model)
         tts_service = await get_tts_service()
-        voice_name = await process_voices(request.voice, tts_service)
+        voice_name = await process_and_validate_voices(request.voice, tts_service)
 
         # Set content type based on format
         content_type = {
@@ -2,24 +2,26 @@
 
 import asyncio
 import os
+import re
 import tempfile
 import time
 from typing import AsyncGenerator, List, Optional, Tuple, Union
 
-from ..inference.base import AudioChunk
 import numpy as np
 import torch
 from kokoro import KPipeline
 from loguru import logger
 
 from ..core.config import settings
+from ..inference.base import AudioChunk
 from ..inference.kokoro_v1 import KokoroV1
 from ..inference.model_manager import get_manager as get_model_manager
 from ..inference.voice_manager import get_manager as get_voice_manager
+from ..structures.schemas import NormalizationOptions
 from .audio import AudioNormalizer, AudioService
 from .text_processing import tokenize
 from .text_processing.text_processor import process_text_chunk, smart_split
-from ..structures.schemas import NormalizationOptions
 
 
 class TTSService:
     """Text-to-speech service."""
@@ -163,7 +165,15 @@ class TTSService:
         except Exception as e:
             logger.error(f"Failed to process tokens: {str(e)}")
 
-    async def _get_voice_path(self, voice: str) -> Tuple[str, str]:
+    async def _load_voice_from_path(self, path: str, weight: float):
+        # Check if the path is None and raise a ValueError if it is
+        if not path:
+            raise ValueError(f"Voice not found at path: {path}")
+
+        logger.debug(f"Loading voice tensor from path: {path}")
+        return torch.load(path, map_location="cpu") * weight
+
+    async def _get_voices_path(self, voice: str) -> Tuple[str, str]:
         """Get voice path, handling combined voices.
 
         Args:
@@ -176,64 +186,62 @@
             RuntimeError: If voice not found
         """
         try:
-            # Check if it's a combined voice
-            if "+" in voice:
-                # Split on + but preserve any parentheses
-                voice_parts = []
-                weights = []
-                for part in voice.split("+"):
-                    part = part.strip()
-                    if not part:
-                        continue
-                    # Extract voice name and weight if present
-                    if "(" in part and ")" in part:
-                        voice_name = part.split("(")[0].strip()
-                        weight = float(part.split("(")[1].split(")")[0])
-                    else:
-                        voice_name = part
-                        weight = 1.0
-                    voice_parts.append(voice_name)
-                    weights.append(weight)
-
-                if len(voice_parts) < 2:
-                    raise RuntimeError(f"Invalid combined voice name: {voice}")
-
-                # Normalize weights to sum to 1
-                total_weight = sum(weights)
-                weights = [w / total_weight for w in weights]
-
-                # Load and combine voices
-                voice_tensors = []
-                for v, w in zip(voice_parts, weights):
-                    path = await self._voice_manager.get_voice_path(v)
-                    if not path:
-                        raise RuntimeError(f"Voice not found: {v}")
-                    logger.debug(f"Loading voice tensor from: {path}")
-                    voice_tensor = torch.load(path, map_location="cpu")
-                    voice_tensors.append(voice_tensor * w)
-
-                # Sum the weighted voice tensors
-                logger.debug(
-                    f"Combining {len(voice_tensors)} voice tensors with weights {weights}"
-                )
-                combined = torch.sum(torch.stack(voice_tensors), dim=0)
-
-                # Save combined tensor
-                temp_dir = tempfile.gettempdir()
-                combined_path = os.path.join(temp_dir, f"{voice}.pt")
-                logger.debug(f"Saving combined voice to: {combined_path}")
-                torch.save(combined, combined_path)
-
-                return voice, combined_path
-            else:
-                # Single voice
-                if "(" in voice and ")" in voice:
-                    voice = voice.split("(")[0].strip()
-                path = await self._voice_manager.get_voice_path(voice)
-                if not path:
-                    raise RuntimeError(f"Voice not found: {voice}")
-                logger.debug(f"Using single voice path: {path}")
-                return voice, path
+            # Split the voice on + and - and ensure the operators get added to the list, e.g. hi+bob = ["hi", "+", "bob"]
+            split_voice = re.split(r"([-+])", voice)
+
+            # If it is only one voice there is no point in loading it up, doing nothing with it, then saving it
+            if len(split_voice) == 1:
+                # Since it's a single voice, the only time the weight would matter is if voice_weight_normalization is off
+                if ("(" not in voice and ")" not in voice) or settings.voice_weight_normalization == True:
+                    path = await self._voice_manager.get_voice_path(voice)
+                    if not path:
+                        raise RuntimeError(f"Voice not found: {voice}")
+                    logger.debug(f"Using single voice path: {path}")
+                    return voice, path
+
+            total_weight = 0
+
+            for voice_index in range(0, len(split_voice), 2):
+                voice_object = split_voice[voice_index]
+
+                if "(" in voice_object and ")" in voice_object:
+                    voice_name = voice_object.split("(")[0].strip()
+                    voice_weight = float(voice_object.split("(")[1].split(")")[0])
+                else:
+                    voice_name = voice_object
+                    voice_weight = 1
+
+                total_weight += voice_weight
+                split_voice[voice_index] = (voice_name, voice_weight)
+
+            # If voice_weight_normalization is false, prevent normalizing the weights by setting total_weight to 1 so each weight is divided by 1
+            if settings.voice_weight_normalization == False:
+                total_weight = 1
+
+            # Load the first voice as the starting point for voices to be combined onto
+            path = await self._voice_manager.get_voice_path(split_voice[0][0])
+            combined_tensor = await self._load_voice_from_path(path, split_voice[0][1] / total_weight)
+
+            # Loop through each + or - in split_voice so they can be applied to the combined voice
+            for operation_index in range(1, len(split_voice) - 1, 2):
+                # Get the voice path of the voice 1 index ahead of the operator
+                path = await self._voice_manager.get_voice_path(split_voice[operation_index + 1][0])
+                voice_tensor = await self._load_voice_from_path(path, split_voice[operation_index + 1][1] / total_weight)
+
+                # Either add or subtract the voice from the current combined voice
+                if split_voice[operation_index] == "+":
+                    combined_tensor += voice_tensor
+                else:
+                    combined_tensor -= voice_tensor
+
+            # Save the new combined voice so it can be loaded later
+            temp_dir = tempfile.gettempdir()
+            combined_path = os.path.join(temp_dir, f"{voice}.pt")
+            logger.debug(f"Saving combined voice to: {combined_path}")
+            torch.save(combined_tensor, combined_path)
+            return voice, combined_path
         except Exception as e:
             logger.error(f"Failed to get voice path: {e}")
             raise
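Once parsed, the combination itself is plain tensor arithmetic: every voice tensor is scaled by weight / total_weight, the first scaled tensor seeds the accumulator, and each following operator adds or subtracts the next one. A self-contained sketch of that core (combine_voice_tensors is a hypothetical helper; real tensors come from the voice manager and torch.load, here replaced by an in-memory dict):

import torch

def combine_voice_tensors(parsed: list, tensors: dict, normalize: bool = True) -> torch.Tensor:
    # parsed alternates (name, weight) tuples and "+"/"-" operators, e.g.
    # [("af_bella", 0.7), "+", ("af_sky", 0.3), "-", ("am_adam", 1.0)]
    total = sum(term[1] for term in parsed[::2]) if normalize else 1.0
    combined = tensors[parsed[0][0]] * (parsed[0][1] / total)
    for op_index in range(1, len(parsed) - 1, 2):
        name, weight = parsed[op_index + 1]
        scaled = tensors[name] * (weight / total)
        combined = combined + scaled if parsed[op_index] == "+" else combined - scaled
    return combined

# Toy stand-ins for real voice embedding tensors:
tensors = {
    "af_bella": torch.ones(4),
    "af_sky": torch.full((4,), 2.0),
    "am_adam": torch.full((4,), 4.0),
}
parsed = [("af_bella", 0.7), "+", ("af_sky", 0.3), "-", ("am_adam", 1.0)]
print(combine_voice_tensors(parsed, tensors))

Saving the result under the full combination string in tempfile.gettempdir() means the pipeline can then load a "voice1-voice2" mix from disk exactly like a stock voice file.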
@@ -257,7 +265,7 @@ class TTSService:
             backend = self.model_manager.get_backend()
 
             # Get voice path, handling combined voices
-            voice_name, voice_path = await self._get_voice_path(voice)
+            voice_name, voice_path = await self._get_voices_path(voice)
             logger.debug(f"Using voice path: {voice_path}")
 
             # Use provided lang_code or determine from voice name
@@ -362,6 +370,7 @@ class TTSService:
         Returns:
             Combined voice tensor
         """
+
         return await self._voice_manager.combine_voices(voices)
 
     async def list_voices(self) -> List[str]:
@@ -390,7 +399,7 @@ class TTSService:
         try:
             # Get backend and voice path
             backend = self.model_manager.get_backend()
-            voice_name, voice_path = await self._get_voice_path(voice)
+            voice_name, voice_path = await self._get_voices_path(voice)
 
             if isinstance(backend, KokoroV1):
                 # For Kokoro V1, use generate_from_tokens with raw phonemes
@@ -74,7 +74,7 @@ async def test_get_voice_path_single():
     mock_get_voice.return_value = voice_manager
 
     service = await TTSService.create("test_output")
-    name, path = await service._get_voice_path("voice1")
+    name, path = await service._get_voices_path("voice1")
     assert name == "voice1"
     assert path == "/path/to/voice1.pt"
     voice_manager.get_voice_path.assert_called_once_with("voice1")
@@ -100,7 +100,7 @@ async def test_get_voice_path_combined():
     mock_load.return_value = torch.ones(10)
 
     service = await TTSService.create("test_output")
-    name, path = await service._get_voice_path("voice1+voice2")
+    name, path = await service._get_voices_path("voice1+voice2")
     assert name == "voice1+voice2"
     assert path.endswith("voice1+voice2.pt")
     mock_save.assert_called_once()
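The updated tests still only exercise "+"; a subtraction case would follow the same pattern. A hypothetical companion test (the patch targets and fixture shape are assumed to match the existing test module, not taken from this diff):

from unittest.mock import AsyncMock, patch

import pytest
import torch

@pytest.mark.asyncio
async def test_get_voices_path_subtracted():
    # Mirrors test_get_voice_path_combined, but with the new "-" operator.
    # TTSService import and any model-manager mocking are assumed to be set up
    # as in the surrounding test module.
    with patch("api.src.services.tts_service.get_voice_manager") as mock_get_voice, \
         patch("torch.load") as mock_load, \
         patch("torch.save") as mock_save:
        voice_manager = AsyncMock()
        voice_manager.get_voice_path.return_value = "/path/to/voice.pt"
        mock_get_voice.return_value = voice_manager
        mock_load.return_value = torch.ones(10)

        service = await TTSService.create("test_output")
        name, path = await service._get_voices_path("voice1-voice2")
        assert name == "voice1-voice2"
        assert path.endswith("voice1-voice2.pt")
        mock_save.assert_called_once()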