diff --git a/api/src/core/config.py b/api/src/core/config.py
index f5fd569..9e2e61f 100644
--- a/api/src/core/config.py
+++ b/api/src/core/config.py
@@ -29,7 +29,8 @@ class Settings(BaseSettings):
     target_min_tokens: int = 175  # Target minimum tokens per chunk
     target_max_tokens: int = 250  # Target maximum tokens per chunk
     absolute_max_tokens: int = 450  # Absolute maximum tokens per chunk
-    advanced_text_normalization: bool = True  # Preproesses the text before misiki which leads
+    advanced_text_normalization: bool = True  # Preprocesses the text before misaki
+    voice_weight_normalization: bool = True  # Normalize the voice weights so they add up to 1
 
     gap_trim_ms: int = 1  # Base amount to trim from streaming chunk ends in milliseconds
     dynamic_gap_trim_padding_ms: int = 410  # Padding to add to dynamic gap trim
diff --git a/api/src/routers/development.py b/api/src/routers/development.py
index 569ae25..ec5596e 100644
--- a/api/src/routers/development.py
+++ b/api/src/routers/development.py
@@ -1,20 +1,24 @@
+import base64
+import json
+import os
 import re
-from typing import List, Union, AsyncGenerator, Tuple
+from pathlib import Path
+from typing import AsyncGenerator, List, Tuple, Union
 
 import numpy as np
 import torch
 from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
-from fastapi.responses import JSONResponse, StreamingResponse, FileResponse
+from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
 from kokoro import KPipeline
 from loguru import logger
 
-from ..inference.base import AudioChunk
 from ..core.config import settings
+from ..inference.base import AudioChunk
 from ..services.audio import AudioNormalizer, AudioService
 from ..services.streaming_audio_writer import StreamingAudioWriter
+from ..services.temp_manager import TempFileWriter
 from ..services.text_processing import smart_split
 from ..services.tts_service import TTSService
-from ..services.temp_manager import TempFileWriter
 from ..structures import CaptionedSpeechRequest, CaptionedSpeechResponse, WordTimestamp
 from ..structures.custom_responses import JSONStreamingResponse
 from ..structures.text_schemas import (
@@ -22,12 +26,7 @@ from ..structures.text_schemas import (
     PhonemeRequest,
     PhonemeResponse,
 )
-from .openai_compatible import process_voices, stream_audio_chunks
-import json
-import os
-import base64
-from pathlib import Path
-
+from .openai_compatible import process_and_validate_voices, stream_audio_chunks
 
 router = APIRouter(tags=["text processing"])
@@ -169,7 +168,7 @@ async def create_captioned_speech(
     try:
         # model_name = get_model_name(request.model)
         tts_service = await get_tts_service()
-        voice_name = await process_voices(request.voice, tts_service)
+        voice_name = await process_and_validate_voices(request.voice, tts_service)
 
         # Set content type based on format
         content_type = {
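
For orientation, a minimal sketch of what the new voice_weight_normalization flag controls; the weights below are hypothetical, and the real logic lives in TTSService._get_voices_path further down:

    # Sketch, assuming a request like "af_bella(2)+af_sky(1)" was parsed into these weights
    weights = [2.0, 1.0]
    total = sum(weights)                         # 3.0
    normalized = [w / total for w in weights]    # [0.667, 0.333], sums to 1
    # With voice_weight_normalization = False, the raw weights 2.0 and 1.0
    # would be applied to the voice tensors unchanged.
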
diff --git a/api/src/routers/openai_compatible.py b/api/src/routers/openai_compatible.py
index c4036d6..97844b1 100644
--- a/api/src/routers/openai_compatible.py
+++ b/api/src/routers/openai_compatible.py
@@ -5,20 +5,19 @@ import json
 import os
 import re
 import tempfile
-from typing import AsyncGenerator, Dict, List, Union, Tuple
+from typing import AsyncGenerator, Dict, List, Tuple, Union
 from urllib import response
 
-import numpy as np
 import aiofiles
-
-from structures.schemas import CaptionedSpeechRequest
+import numpy as np
 import torch
 from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
 from fastapi.responses import FileResponse, StreamingResponse
 from loguru import logger
+from structures.schemas import CaptionedSpeechRequest
 
-from ..inference.base import AudioChunk
 from ..core.config import settings
+from ..inference.base import AudioChunk
 from ..services.audio import AudioService
 from ..services.tts_service import TTSService
 from ..structures import OpenAISpeechRequest
@@ -80,7 +79,7 @@ def get_model_name(model: str) -> str:
     return base_name + ".pth"
 
 
-async def process_voices(
+async def process_and_validate_voices(
     voice_input: Union[str, List[str]], tts_service: TTSService
 ) -> str:
     """Process voice input, handling both string and list formats
@@ -88,53 +87,57 @@
     Returns:
         Voice name to use (with weights if specified)
     """
+    voices = []
     # Convert input to list of voices
     if isinstance(voice_input, str):
-        # Check if it's an OpenAI voice name
-        mapped_voice = _openai_mappings["voices"].get(voice_input)
-        if mapped_voice:
-            voice_input = mapped_voice
-        # Split on + but preserve any parentheses
-        voices = []
-        for part in voice_input.split("+"):
-            part = part.strip()
-            if not part:
-                continue
-            # Extract voice name without weight
-            voice_name = part.split("(")[0].strip()
-            # Check if it's a valid voice
-            available_voices = await tts_service.list_voices()
-            if voice_name not in available_voices:
-                raise ValueError(
-                    f"Voice '{voice_name}' not found. Available voices: {', '.join(sorted(available_voices))}"
-                )
-            voices.append(part)
+        voice_input = voice_input.replace(" ", "").strip()
+
+        if voice_input[-1] in "+-" or voice_input[0] in "+-":
+            raise ValueError("Voice combination contains empty combine items")
+
+        if re.search(r"[+-]{2,}", voice_input) is not None:
+            raise ValueError("Voice combination contains empty combine items")
+
+        voices = re.split(r"([-+])", voice_input)
     else:
-        # For list input, map each voice if it's an OpenAI voice name
-        voices = []
-        for v in voice_input:
-            mapped = _openai_mappings["voices"].get(v, v)
-            voice_name = mapped.split("(")[0].strip()
-            # Check if it's a valid voice
-            available_voices = await tts_service.list_voices()
-            if voice_name not in available_voices:
-                raise ValueError(
-                    f"Voice '{voice_name}' not found. Available voices: {', '.join(sorted(available_voices))}"
-                )
-            voices.append(mapped)
+        # Interleave "+" between the listed voices, then drop the trailing operator
+        voices = [subitem for item in voice_input for subitem in [item, "+"]][:-1]
+
+    available_voices = await tts_service.list_voices()
+
+    # Voices sit at even indices; odd indices hold the +/- operators
+    for voice_index in range(0, len(voices), 2):
+        mapped_voice = voices[voice_index].split("(")
+        mapped_voice = list(map(str.strip, mapped_voice))
+
+        if len(mapped_voice) > 2:
+            raise ValueError(
+                f"Voice '{voices[voice_index]}' contains too many weight items"
+            )
+
+        if voices[voice_index].count(")") > 1:
+            raise ValueError(
+                f"Voice '{voices[voice_index]}' contains too many weight items"
+            )
+
+        # Map OpenAI voice aliases to their internal equivalents
+        mapped_voice[0] = _openai_mappings["voices"].get(mapped_voice[0], mapped_voice[0])
 
-    if not voices:
-        raise ValueError("No voices provided")
-
-    # For multiple voices, combine them with +
-    return "+".join(voices)
+        if mapped_voice[0] not in available_voices:
+            raise ValueError(
+                f"Voice '{mapped_voice[0]}' not found. Available voices: {', '.join(sorted(available_voices))}"
+            )
+
+        voices[voice_index] = "(".join(mapped_voice)
+
+    return "".join(voices)
 
 
 async def stream_audio_chunks(
     tts_service: TTSService, request: Union[OpenAISpeechRequest,CaptionedSpeechRequest], client_request: Request
 ) -> AsyncGenerator[AudioChunk, None]:
     """Stream audio chunks as they're generated with client disconnect handling"""
-    voice_name = await process_voices(request.voice, tts_service)
+    voice_name = await process_and_validate_voices(request.voice, tts_service)
 
     unique_properties={
         "return_timestamps":False
@@ -191,7 +194,7 @@ async def create_speech(
     try:
         # model_name = get_model_name(request.model)
         tts_service = await get_tts_service()
-        voice_name = await process_voices(request.voice, tts_service)
+        voice_name = await process_and_validate_voices(request.voice, tts_service)
 
         # Set content type based on format
         content_type = {
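
To illustrate the parsing scheme process_and_validate_voices now relies on: re.split with a capturing group keeps the +/- operators in the result, so voice names land at even indices and operators at odd ones. A small standalone sketch (the voice names are only examples):

    import re

    voice_input = "af_bella(1.2)+af_sky-af_nicole(0.5)"
    parts = re.split(r"([-+])", voice_input)
    # ['af_bella(1.2)', '+', 'af_sky', '-', 'af_nicole(0.5)']

    for i in range(0, len(parts), 2):    # even indices hold the voices
        name = parts[i].split("(")[0]    # drop the optional "(weight)" suffix
        print(name)                      # af_bella, af_sky, af_nicole
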
diff --git a/api/src/services/tts_service.py b/api/src/services/tts_service.py
index a115c18..cb9f655 100644
--- a/api/src/services/tts_service.py
+++ b/api/src/services/tts_service.py
@@ -2,24 +2,26 @@
 
 import asyncio
 import os
+import re
 import tempfile
 import time
 from typing import AsyncGenerator, List, Optional, Tuple, Union
 
-from ..inference.base import AudioChunk
 import numpy as np
 import torch
 from kokoro import KPipeline
 from loguru import logger
 
 from ..core.config import settings
+from ..inference.base import AudioChunk
 from ..inference.kokoro_v1 import KokoroV1
 from ..inference.model_manager import get_manager as get_model_manager
 from ..inference.voice_manager import get_manager as get_voice_manager
+from ..structures.schemas import NormalizationOptions
 from .audio import AudioNormalizer, AudioService
 from .text_processing import tokenize
 from .text_processing.text_processor import process_text_chunk, smart_split
-from ..structures.schemas import NormalizationOptions
+
 
 class TTSService:
     """Text-to-speech service."""
@@ -163,7 +165,15 @@ class TTSService:
         except Exception as e:
             logger.error(f"Failed to process tokens: {str(e)}")
 
-    async def _get_voice_path(self, voice: str) -> Tuple[str, str]:
+    async def _load_voice_from_path(self, path: str, weight: float):
+        # Raise a ValueError if no voice path was found
+        if not path:
+            raise ValueError(f"Voice not found at path: {path}")
+
+        logger.debug(f"Loading voice tensor from path: {path}")
+        return torch.load(path, map_location="cpu") * weight
+
+    async def _get_voices_path(self, voice: str) -> Tuple[str, str]:
         """Get voice path, handling combined voices.
 
         Args:
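
The new _load_voice_from_path helper folds the weight multiply into the tensor load, so callers receive pre-scaled tensors they can simply add or subtract. Roughly equivalent standalone code (the path here is made up):

    import torch

    # What _load_voice_from_path("/voices/af_bella.pt", 0.5) boils down to:
    tensor = torch.load("/voices/af_bella.pt", map_location="cpu") * 0.5
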
@@ -176,64 +186,62 @@ class TTSService:
             RuntimeError: If voice not found
         """
         try:
-            # Check if it's a combined voice
-            if "+" in voice:
-                # Split on + but preserve any parentheses
-                voice_parts = []
-                weights = []
-                for part in voice.split("+"):
-                    part = part.strip()
-                    if not part:
-                        continue
-                    # Extract voice name and weight if present
-                    if "(" in part and ")" in part:
-                        voice_name = part.split("(")[0].strip()
-                        weight = float(part.split("(")[1].split(")")[0])
-                    else:
-                        voice_name = part
-                        weight = 1.0
-                    voice_parts.append(voice_name)
-                    weights.append(weight)
+            # Split the voice on + and -, keeping the operators in the list, e.g. "hi+bob" -> ["hi", "+", "bob"]
+            split_voice = re.split(r"([-+])", voice)
 
-                if len(voice_parts) < 2:
-                    raise RuntimeError(f"Invalid combined voice name: {voice}")
+            # If it is only one voice, there is no point in loading it, doing nothing with it, then saving it
+            if len(split_voice) == 1:
 
-                # Normalize weights to sum to 1
-                total_weight = sum(weights)
-                weights = [w / total_weight for w in weights]
+                # Since it's a single voice, the weight only matters when voice_weight_normalization is off
+                if ("(" not in voice and ")" not in voice) or settings.voice_weight_normalization:
 
-                # Load and combine voices
-                voice_tensors = []
-                for v, w in zip(voice_parts, weights):
-                    path = await self._voice_manager.get_voice_path(v)
+                    path = await self._voice_manager.get_voice_path(voice)
                     if not path:
-                        raise RuntimeError(f"Voice not found: {v}")
-                    logger.debug(f"Loading voice tensor from: {path}")
-                    voice_tensor = torch.load(path, map_location="cpu")
-                    voice_tensors.append(voice_tensor * w)
+                        raise RuntimeError(f"Voice not found: {voice}")
+                    logger.debug(f"Using single voice path: {path}")
+                    return voice, path
 
-                # Sum the weighted voice tensors
-                logger.debug(
-                    f"Combining {len(voice_tensors)} voice tensors with weights {weights}"
-                )
-                combined = torch.sum(torch.stack(voice_tensors), dim=0)
+            total_weight = 0
+
+            # Parse each voice into a (name, weight) tuple, accumulating the total weight
+            for voice_index in range(0, len(split_voice), 2):
+                voice_object = split_voice[voice_index]
 
-                # Save combined tensor
-                temp_dir = tempfile.gettempdir()
-                combined_path = os.path.join(temp_dir, f"{voice}.pt")
-                logger.debug(f"Saving combined voice to: {combined_path}")
-                torch.save(combined, combined_path)
+                if "(" in voice_object and ")" in voice_object:
+                    voice_name = voice_object.split("(")[0].strip()
+                    voice_weight = float(voice_object.split("(")[1].split(")")[0])
+                else:
+                    voice_name = voice_object
+                    voice_weight = 1
 
-                return voice, combined_path
-            else:
-                # Single voice
-                if "(" in voice and ")" in voice:
-                    voice = voice.split("(")[0].strip()
-                path = await self._voice_manager.get_voice_path(voice)
-                if not path:
-                    raise RuntimeError(f"Voice not found: {voice}")
-                logger.debug(f"Using single voice path: {path}")
-                return voice, path
+                total_weight += voice_weight
+                split_voice[voice_index] = (voice_name, voice_weight)
+
+            # If voice_weight_normalization is off, divide by 1 so each weight is left unchanged
+            if not settings.voice_weight_normalization:
+                total_weight = 1
+
+            # Load the first voice as the base tensor the others are combined onto
+            path = await self._voice_manager.get_voice_path(split_voice[0][0])
+            combined_tensor = await self._load_voice_from_path(path, split_voice[0][1] / total_weight)
+
+            # Walk each + or - in split_voice and apply the voice that follows it to the combined tensor
+            for operation_index in range(1, len(split_voice) - 1, 2):
+                # Get the voice path of the voice one index ahead of the operator
+                path = await self._voice_manager.get_voice_path(split_voice[operation_index + 1][0])
+                voice_tensor = await self._load_voice_from_path(path, split_voice[operation_index + 1][1] / total_weight)
+
+                # Either add or subtract the voice from the current combined voice
+                if split_voice[operation_index] == "+":
+                    combined_tensor += voice_tensor
+                else:
+                    combined_tensor -= voice_tensor
+
+            # Save the new combined voice so it can be loaded later
+            temp_dir = tempfile.gettempdir()
+            combined_path = os.path.join(temp_dir, f"{voice}.pt")
+            logger.debug(f"Saving combined voice to: {combined_path}")
+            torch.save(combined_tensor, combined_path)
+            return voice, combined_path
         except Exception as e:
             logger.error(f"Failed to get voice path: {e}")
             raise
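
Concretely, the rewritten method computes a signed, weight-normalized sum, and a subtracted voice still contributes its weight to total_weight. For a hypothetical "af_bella(2)+af_sky-af_nicole(0.5)" with normalization on, the arithmetic works out to:

    import torch

    # Illustrative tensors standing in for the loaded voice embeddings
    af_bella, af_sky, af_nicole = (torch.randn(510, 1, 256) for _ in range(3))

    total_weight = 2 + 1 + 0.5  # 3.5, including the subtracted voice's weight
    combined = (
        af_bella * (2 / total_weight)
        + af_sky * (1 / total_weight)
        - af_nicole * (0.5 / total_weight)
    )
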
@@ -257,7 +265,7 @@ class TTSService:
             backend = self.model_manager.get_backend()
 
             # Get voice path, handling combined voices
-            voice_name, voice_path = await self._get_voice_path(voice)
+            voice_name, voice_path = await self._get_voices_path(voice)
             logger.debug(f"Using voice path: {voice_path}")
 
             # Use provided lang_code or determine from voice name
@@ -362,6 +370,7 @@ class TTSService:
         Returns:
             Combined voice tensor
         """
+
         return await self._voice_manager.combine_voices(voices)
 
     async def list_voices(self) -> List[str]:
@@ -390,7 +399,7 @@ class TTSService:
         try:
             # Get backend and voice path
             backend = self.model_manager.get_backend()
-            voice_name, voice_path = await self._get_voice_path(voice)
+            voice_name, voice_path = await self._get_voices_path(voice)
 
             if isinstance(backend, KokoroV1):
                 # For Kokoro V1, use generate_from_tokens with raw phonemes
diff --git a/api/tests/test_tts_service.py b/api/tests/test_tts_service.py
index 0d45824..ae8447a 100644
--- a/api/tests/test_tts_service.py
+++ b/api/tests/test_tts_service.py
@@ -74,7 +74,7 @@ async def test_get_voice_path_single():
         mock_get_voice.return_value = voice_manager
 
         service = await TTSService.create("test_output")
-        name, path = await service._get_voice_path("voice1")
+        name, path = await service._get_voices_path("voice1")
         assert name == "voice1"
         assert path == "/path/to/voice1.pt"
         voice_manager.get_voice_path.assert_called_once_with("voice1")
@@ -100,7 +100,7 @@ async def test_get_voice_path_combined():
         mock_load.return_value = torch.ones(10)
 
         service = await TTSService.create("test_output")
-        name, path = await service._get_voice_path("voice1+voice2")
+        name, path = await service._get_voices_path("voice1+voice2")
         assert name == "voice1+voice2"
         assert path.endswith("voice1+voice2.pt")
         mock_save.assert_called_once()
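
Since the existing tests were only renamed, a follow-up test for the new subtraction syntax could mirror the combined-voice test above; the patch targets, fixtures, and imports below are assumed to match the surrounding test module:

    @pytest.mark.asyncio
    async def test_get_voices_path_subtracted():
        """Test getting path for a combined voice using the '-' operator"""
        with patch("api.src.services.tts_service.get_voice_manager") as mock_get_voice, \
             patch("torch.load") as mock_load, \
             patch("torch.save") as mock_save:
            voice_manager = AsyncMock()
            voice_manager.get_voice_path.return_value = "/path/to/voice.pt"
            mock_get_voice.return_value = voice_manager
            mock_load.return_value = torch.ones(10)

            service = await TTSService.create("test_output")
            name, path = await service._get_voices_path("voice1-voice2")
            assert name == "voice1-voice2"
            assert path.endswith("voice1-voice2.pt")
            mock_save.assert_called_once()
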