Merge pull request #235 from fireblade2534/fixes

This commit is contained in:
remsky 2025-03-12 02:22:04 -06:00 committed by GitHub
commit e4744f5545
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 134 additions and 123 deletions

View file

@ -23,7 +23,7 @@ In conclusion, "Jet Black Heart" by 5 Seconds of Summer is far more than a typic
5 Seconds of Summer, initially perceived as purveyors of upbeat, radio-friendly pop-punk, embarked on a significant artistic evolution with their album Sounds Good Feels Good. Among its tracks, "Jet Black Heart" stands out as a powerful testament to this shift, moving beyond catchy melodies and embracing a darker, more emotionally complex sound. Released in 2015, the song transcends the typical themes of youthful exuberance and romantic angst, instead plunging into the depths of personal turmoil and the corrosive effects of inner darkness on interpersonal relationships. "Jet Black Heart" is not merely a song about heartbreak; it is a raw and vulnerable exploration of internal struggle, self-destructive patterns, and the precarious flicker of hope that persists even in the face of profound emotional chaos."""
Type="mp3"
Type="wav"
response = requests.post(
"http://localhost:8880/dev/captioned_speech",
json={
@ -51,12 +51,12 @@ for chunk in response.iter_lines(decode_unicode=True):
f.write(chunk_audio)
# Print word level timestamps
last3=chunk_json["timestamps"][-3]
last_chunks={"start_time":chunk_json["timestamps"][-10]["start_time"],"end_time":chunk_json["timestamps"][-3]["end_time"],"word":" ".join([X["word"] for X in chunk_json["timestamps"][-10:-3]])}
print(f"CUTTING TO {last3['word']}")
print(f"CUTTING TO {last_chunks['word']}")
audioseg=pydub.AudioSegment.from_file(f"outputstream.{Type}",format=Type)
audioseg=audioseg[last3["start_time"]*1000:last3["end_time"] * 1000]
audioseg=audioseg[last_chunks["start_time"]*1000:last_chunks["end_time"] * 1000]
audioseg.export(f"outputstreamcut.{Type}",format=Type)

View file

@ -29,7 +29,8 @@ class Settings(BaseSettings):
target_min_tokens: int = 175 # Target minimum tokens per chunk
target_max_tokens: int = 250 # Target maximum tokens per chunk
absolute_max_tokens: int = 450 # Absolute maximum tokens per chunk
advanced_text_normalization: bool = True # Preproesses the text before misiki which leads
advanced_text_normalization: bool = True # Preproesses the text before misiki
voice_weight_normalization: bool = True # Normalize the voice weights so they add up to 1
gap_trim_ms: int = 1 # Base amount to trim from streaming chunk ends in milliseconds
dynamic_gap_trim_padding_ms: int = 410 # Padding to add to dynamic gap trim

View file

@ -259,10 +259,6 @@ class KokoroV1(BaseModelBackend):
)
if result.pred_dur is not None:
try:
# Join timestamps for this chunk's tokens
KPipeline.join_timestamps(
result.tokens, result.pred_dur
)
# Add timestamps with offset
for token in result.tokens:

View file

@ -1,20 +1,24 @@
import base64
import json
import os
import re
from typing import List, Union, AsyncGenerator, Tuple
from pathlib import Path
from typing import AsyncGenerator, List, Tuple, Union
import numpy as np
import torch
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from fastapi.responses import JSONResponse, StreamingResponse, FileResponse
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
from kokoro import KPipeline
from loguru import logger
from ..inference.base import AudioChunk
from ..core.config import settings
from ..inference.base import AudioChunk
from ..services.audio import AudioNormalizer, AudioService
from ..services.streaming_audio_writer import StreamingAudioWriter
from ..services.temp_manager import TempFileWriter
from ..services.text_processing import smart_split
from ..services.tts_service import TTSService
from ..services.temp_manager import TempFileWriter
from ..structures import CaptionedSpeechRequest, CaptionedSpeechResponse, WordTimestamp
from ..structures.custom_responses import JSONStreamingResponse
from ..structures.text_schemas import (
@ -22,12 +26,7 @@ from ..structures.text_schemas import (
PhonemeRequest,
PhonemeResponse,
)
from .openai_compatible import process_voices, stream_audio_chunks
import json
import os
import base64
from pathlib import Path
from .openai_compatible import process_and_validate_voices, stream_audio_chunks
router = APIRouter(tags=["text processing"])
@ -169,7 +168,7 @@ async def create_captioned_speech(
try:
# model_name = get_model_name(request.model)
tts_service = await get_tts_service()
voice_name = await process_voices(request.voice, tts_service)
voice_name = await process_and_validate_voices(request.voice, tts_service)
# Set content type based on format
content_type = {
@ -254,7 +253,10 @@ async def create_captioned_speech(
base64_chunk= base64.b64encode(chunk_data.output).decode("utf-8")
# Add any chunks that may be in the acumulator into the return word_timestamps
chunk_data.word_timestamps=timestamp_acumulator + chunk_data.word_timestamps
if chunk_data.word_timestamps != None:
chunk_data.word_timestamps = timestamp_acumulator + chunk_data.word_timestamps
else:
chunk_data.word_timestamps = []
timestamp_acumulator=[]
yield CaptionedSpeechResponse(audio=base64_chunk,audio_format=content_type,timestamps=chunk_data.word_timestamps)

View file

@ -5,20 +5,19 @@ import json
import os
import re
import tempfile
from typing import AsyncGenerator, Dict, List, Union, Tuple
from typing import AsyncGenerator, Dict, List, Tuple, Union
from urllib import response
import numpy as np
import aiofiles
from structures.schemas import CaptionedSpeechRequest
import numpy as np
import torch
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from fastapi.responses import FileResponse, StreamingResponse
from loguru import logger
from structures.schemas import CaptionedSpeechRequest
from ..inference.base import AudioChunk
from ..core.config import settings
from ..inference.base import AudioChunk
from ..services.audio import AudioService
from ..services.tts_service import TTSService
from ..structures import OpenAISpeechRequest
@ -80,7 +79,7 @@ def get_model_name(model: str) -> str:
return base_name + ".pth"
async def process_voices(
async def process_and_validate_voices(
voice_input: Union[str, List[str]], tts_service: TTSService
) -> str:
"""Process voice input, handling both string and list formats
@ -88,53 +87,57 @@ async def process_voices(
Returns:
Voice name to use (with weights if specified)
"""
voices = []
# Convert input to list of voices
if isinstance(voice_input, str):
# Check if it's an OpenAI voice name
mapped_voice = _openai_mappings["voices"].get(voice_input)
if mapped_voice:
voice_input = mapped_voice
# Split on + but preserve any parentheses
voices = []
for part in voice_input.split("+"):
part = part.strip()
if not part:
continue
# Extract voice name without weight
voice_name = part.split("(")[0].strip()
# Check if it's a valid voice
available_voices = await tts_service.list_voices()
if voice_name not in available_voices:
raise ValueError(
f"Voice '{voice_name}' not found. Available voices: {', '.join(sorted(available_voices))}"
)
voices.append(part)
voice_input=voice_input.replace(" ","").strip()
if voice_input[-1] in "+-" or voice_input[0] in "+-":
raise ValueError(
f"Voice combination contains empty combine items"
)
if re.search(r"[+-]{2,}", voice_input) is not None:
raise ValueError(
f"Voice combination contains empty combine items"
)
voices = re.split(r"([-+])", voice_input)
else:
# For list input, map each voice if it's an OpenAI voice name
voices = []
for v in voice_input:
mapped = _openai_mappings["voices"].get(v, v)
voice_name = mapped.split("(")[0].strip()
# Check if it's a valid voice
available_voices = await tts_service.list_voices()
if voice_name not in available_voices:
raise ValueError(
f"Voice '{voice_name}' not found. Available voices: {', '.join(sorted(available_voices))}"
voices = [[item,"+"] for item in voice_input][:-1]
available_voices = await tts_service.list_voices()
for voice_index in range(0,len(voices), 2):
mapped_voice = voices[voice_index].split("(")
mapped_voice = list(map(str.strip, mapped_voice))
if len(mapped_voice) > 2:
raise ValueError(
f"Voice '{voices[voice_index]}' contains too many weight items"
)
voices.append(mapped)
if mapped_voice.count(")") > 1:
raise ValueError(
f"Voice '{voices[voice_index]}' contains too many weight items"
)
mapped_voice[0] = _openai_mappings["voices"].get(mapped_voice[0], mapped_voice[0])
if not voices:
raise ValueError("No voices provided")
# For multiple voices, combine them with +
return "+".join(voices)
if mapped_voice[0] not in available_voices:
raise ValueError(
f"Voice '{mapped_voice[0]}' not found. Available voices: {', '.join(sorted(available_voices))}"
)
voices[voice_index] = "(".join(mapped_voice)
return "".join(voices)
async def stream_audio_chunks(
tts_service: TTSService, request: Union[OpenAISpeechRequest,CaptionedSpeechRequest], client_request: Request
) -> AsyncGenerator[AudioChunk, None]:
"""Stream audio chunks as they're generated with client disconnect handling"""
voice_name = await process_voices(request.voice, tts_service)
voice_name = await process_and_validate_voices(request.voice, tts_service)
unique_properties={
"return_timestamps":False
@ -191,7 +194,7 @@ async def create_speech(
try:
# model_name = get_model_name(request.model)
tts_service = await get_tts_service()
voice_name = await process_voices(request.voice, tts_service)
voice_name = await process_and_validate_voices(request.voice, tts_service)
# Set content type based on format
content_type = {

View file

@ -121,7 +121,7 @@ class AudioService:
is_last_chunk: bool = False,
trim_audio: bool = True,
normalizer: AudioNormalizer = None,
) -> Tuple[AudioChunk]:
) -> AudioChunk:
"""Convert audio data to specified format with streaming support
Args:

View file

@ -2,24 +2,26 @@
import asyncio
import os
import re
import tempfile
import time
from typing import AsyncGenerator, List, Optional, Tuple, Union
from ..inference.base import AudioChunk
import numpy as np
import torch
from kokoro import KPipeline
from loguru import logger
from ..core.config import settings
from ..inference.base import AudioChunk
from ..inference.kokoro_v1 import KokoroV1
from ..inference.model_manager import get_manager as get_model_manager
from ..inference.voice_manager import get_manager as get_voice_manager
from ..structures.schemas import NormalizationOptions
from .audio import AudioNormalizer, AudioService
from .text_processing import tokenize
from .text_processing.text_processor import process_text_chunk, smart_split
from ..structures.schemas import NormalizationOptions
class TTSService:
"""Text-to-speech service."""
@ -163,7 +165,15 @@ class TTSService:
except Exception as e:
logger.error(f"Failed to process tokens: {str(e)}")
async def _get_voice_path(self, voice: str) -> Tuple[str, str]:
async def _load_voice_from_path(self, path: str, weight: float):
# Check if the path is None and raise a ValueError if it is not
if not path:
raise ValueError(f"Voice not found at path: {path}")
logger.debug(f"Loading voice tensor from path: {path}")
return torch.load(path, map_location="cpu") * weight
async def _get_voices_path(self, voice: str) -> Tuple[str, str]:
"""Get voice path, handling combined voices.
Args:
@ -176,64 +186,62 @@ class TTSService:
RuntimeError: If voice not found
"""
try:
# Check if it's a combined voice
if "+" in voice:
# Split on + but preserve any parentheses
voice_parts = []
weights = []
for part in voice.split("+"):
part = part.strip()
if not part:
continue
# Extract voice name and weight if present
if "(" in part and ")" in part:
voice_name = part.split("(")[0].strip()
weight = float(part.split("(")[1].split(")")[0])
else:
voice_name = part
weight = 1.0
voice_parts.append(voice_name)
weights.append(weight)
# Split the voice on + and - and ensure that they get added to the list eg: hi+bob = ["hi","+","bob"]
split_voice = re.split(r"([-+])", voice)
if len(voice_parts) < 2:
raise RuntimeError(f"Invalid combined voice name: {voice}")
# If it is only once voice there is no point in loading it up, doing nothing with it, then saving it
if len(split_voice) == 1:
# Normalize weights to sum to 1
total_weight = sum(weights)
weights = [w / total_weight for w in weights]
# Since its a single voice the only time that the weight would matter is if voice_weight_normalization is off
if ("(" not in voice and ")" not in voice) or settings.voice_weight_normalization == True:
# Load and combine voices
voice_tensors = []
for v, w in zip(voice_parts, weights):
path = await self._voice_manager.get_voice_path(v)
path = await self._voice_manager.get_voice_path(voice)
if not path:
raise RuntimeError(f"Voice not found: {v}")
logger.debug(f"Loading voice tensor from: {path}")
voice_tensor = torch.load(path, map_location="cpu")
voice_tensors.append(voice_tensor * w)
raise RuntimeError(f"Voice not found: {voice}")
logger.debug(f"Using single voice path: {path}")
return voice, path
# Sum the weighted voice tensors
logger.debug(
f"Combining {len(voice_tensors)} voice tensors with weights {weights}"
)
combined = torch.sum(torch.stack(voice_tensors), dim=0)
total_weight = 0
for voice_index in range(0,len(split_voice),2):
voice_object = split_voice[voice_index]
# Save combined tensor
temp_dir = tempfile.gettempdir()
combined_path = os.path.join(temp_dir, f"{voice}.pt")
logger.debug(f"Saving combined voice to: {combined_path}")
torch.save(combined, combined_path)
if "(" in voice_object and ")" in voice_object:
voice_name = voice_object.split("(")[0].strip()
voice_weight = float(voice_object.split("(")[1].split(")")[0])
else:
voice_name = voice_object
voice_weight = 1
return voice, combined_path
else:
# Single voice
if "(" in voice and ")" in voice:
voice = voice.split("(")[0].strip()
path = await self._voice_manager.get_voice_path(voice)
if not path:
raise RuntimeError(f"Voice not found: {voice}")
logger.debug(f"Using single voice path: {path}")
return voice, path
total_weight += voice_weight
split_voice[voice_index] = (voice_name, voice_weight)
# If voice_weight_normalization is false prevent normalizing the weights by setting the total_weight to 1 so it divides each weight by 1
if settings.voice_weight_normalization == False:
total_weight = 1
# Load the first voice as the starting point for voices to be combined onto
path = await self._voice_manager.get_voice_path(split_voice[0][0])
combined_tensor = await self._load_voice_from_path(path, split_voice[0][1] / total_weight)
# Loop through each + or - in split_voice so they can be applied to combined voice
for operation_index in range(1,len(split_voice) - 1, 2):
# Get the voice path of the voice 1 index ahead of the operator
path = await self._voice_manager.get_voice_path(split_voice[operation_index+1][0])
voice_tensor = await self._load_voice_from_path(path, split_voice[operation_index + 1][1] / total_weight)
# Either add or subtract the voice from the current combined voice
if split_voice[operation_index] == "+":
combined_tensor += voice_tensor
else:
combined_tensor -= voice_tensor
# Save the new combined voice so it can be loaded latter
temp_dir = tempfile.gettempdir()
combined_path = os.path.join(temp_dir, f"{voice}.pt")
logger.debug(f"Saving combined voice to: {combined_path}")
torch.save(combined_tensor, combined_path)
return voice, combined_path
except Exception as e:
logger.error(f"Failed to get voice path: {e}")
raise
@ -257,7 +265,7 @@ class TTSService:
backend = self.model_manager.get_backend()
# Get voice path, handling combined voices
voice_name, voice_path = await self._get_voice_path(voice)
voice_name, voice_path = await self._get_voices_path(voice)
logger.debug(f"Using voice path: {voice_path}")
# Use provided lang_code or determine from voice name
@ -362,6 +370,7 @@ class TTSService:
Returns:
Combined voice tensor
"""
return await self._voice_manager.combine_voices(voices)
async def list_voices(self) -> List[str]:
@ -390,7 +399,7 @@ class TTSService:
try:
# Get backend and voice path
backend = self.model_manager.get_backend()
voice_name, voice_path = await self._get_voice_path(voice)
voice_name, voice_path = await self._get_voices_path(voice)
if isinstance(backend, KokoroV1):
# For Kokoro V1, use generate_from_tokens with raw phonemes

View file

@ -74,7 +74,7 @@ async def test_get_voice_path_single():
mock_get_voice.return_value = voice_manager
service = await TTSService.create("test_output")
name, path = await service._get_voice_path("voice1")
name, path = await service._get_voices_path("voice1")
assert name == "voice1"
assert path == "/path/to/voice1.pt"
voice_manager.get_voice_path.assert_called_once_with("voice1")
@ -100,7 +100,7 @@ async def test_get_voice_path_combined():
mock_load.return_value = torch.ones(10)
service = await TTSService.create("test_output")
name, path = await service._get_voice_path("voice1+voice2")
name, path = await service._get_voices_path("voice1+voice2")
assert name == "voice1+voice2"
assert path.endswith("voice1+voice2.pt")
mock_save.assert_called_once()