Mirror of https://github.com/remsky/Kokoro-FastAPI.git
Synced 2025-04-13 09:39:17 +00:00

Merge pull request #235 from fireblade2534/fixes

This commit is contained in commit e4744f5545.

8 changed files with 134 additions and 123 deletions
@@ -23,7 +23,7 @@ In conclusion, "Jet Black Heart" by 5 Seconds of Summer is far more than a typic
 5 Seconds of Summer, initially perceived as purveyors of upbeat, radio-friendly pop-punk, embarked on a significant artistic evolution with their album Sounds Good Feels Good. Among its tracks, "Jet Black Heart" stands out as a powerful testament to this shift, moving beyond catchy melodies and embracing a darker, more emotionally complex sound. Released in 2015, the song transcends the typical themes of youthful exuberance and romantic angst, instead plunging into the depths of personal turmoil and the corrosive effects of inner darkness on interpersonal relationships. "Jet Black Heart" is not merely a song about heartbreak; it is a raw and vulnerable exploration of internal struggle, self-destructive patterns, and the precarious flicker of hope that persists even in the face of profound emotional chaos."""

-Type="mp3"
+Type="wav"

 response = requests.post(
     "http://localhost:8880/dev/captioned_speech",
     json={
@@ -51,12 +51,12 @@ for chunk in response.iter_lines(decode_unicode=True):
         f.write(chunk_audio)

 # Print word level timestamps
-last3=chunk_json["timestamps"][-3]
+last_chunks={"start_time":chunk_json["timestamps"][-10]["start_time"],"end_time":chunk_json["timestamps"][-3]["end_time"],"word":" ".join([X["word"] for X in chunk_json["timestamps"][-10:-3]])}

-print(f"CUTTING TO {last3['word']}")
+print(f"CUTTING TO {last_chunks['word']}")

 audioseg=pydub.AudioSegment.from_file(f"outputstream.{Type}",format=Type)
-audioseg=audioseg[last3["start_time"]*1000:last3["end_time"] * 1000]
+audioseg=audioseg[last_chunks["start_time"]*1000:last_chunks["end_time"] * 1000]
 audioseg.export(f"outputstreamcut.{Type}",format=Type)
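The updated example now cuts a multi-word span (from the tenth-from-last word to the third-from-last) out of the saved stream instead of a single word. As a stand-alone illustration of that pydub slicing pattern, here is a minimal sketch with made-up timestamp values standing in for a real chunk_json (pydub and an existing outputstream.wav are assumed):

```python
import pydub

Type = "wav"  # must match the format the stream was saved in

# Hypothetical word timestamps, shaped like chunk_json["timestamps"]
# (start_time / end_time are in seconds):
timestamps = [
    {"word": "hello", "start_time": 0.00, "end_time": 0.35},
    {"word": "dark", "start_time": 0.40, "end_time": 0.70},
    {"word": "world", "start_time": 0.75, "end_time": 1.10},
]

# Span from the first word of interest to the last:
span = {
    "start_time": timestamps[0]["start_time"],
    "end_time": timestamps[-1]["end_time"],
    "word": " ".join(t["word"] for t in timestamps),
}

audioseg = pydub.AudioSegment.from_file(f"outputstream.{Type}", format=Type)
# pydub slices in milliseconds, while the timestamps are in seconds.
audioseg = audioseg[int(span["start_time"] * 1000):int(span["end_time"] * 1000)]
audioseg.export(f"outputstreamcut.{Type}", format=Type)
```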
@@ -29,7 +29,8 @@ class Settings(BaseSettings):
     target_min_tokens: int = 175  # Target minimum tokens per chunk
     target_max_tokens: int = 250  # Target maximum tokens per chunk
     absolute_max_tokens: int = 450  # Absolute maximum tokens per chunk
-    advanced_text_normalization: bool = True  # Preprocesses the text before misaki, which leads
+    advanced_text_normalization: bool = True  # Preprocesses the text before misaki
+    voice_weight_normalization: bool = True  # Normalize the voice weights so they add up to 1

     gap_trim_ms: int = 1  # Base amount to trim from streaming chunk ends in milliseconds
     dynamic_gap_trim_padding_ms: int = 410  # Padding to add to dynamic gap trim
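The new voice_weight_normalization flag controls whether the weights in a combined-voice request are rescaled to sum to 1. A quick illustration of the arithmetic (the voice names and weights are made up):

```python
# A request like "af_bella(2)+af_sky(1)" carries per-voice weights.
weights = {"af_bella": 2.0, "af_sky": 1.0}

total = sum(weights.values())  # 3.0
normalized = {name: w / total for name, w in weights.items()}
print(normalized)  # {'af_bella': 0.666..., 'af_sky': 0.333...}

# With voice_weight_normalization disabled, the service divides by 1
# instead, so the raw weights 2.0 and 1.0 are applied unchanged.
```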
@@ -259,10 +259,6 @@ class KokoroV1(BaseModelBackend):
                 )
                 if result.pred_dur is not None:
                     try:
                         # Join timestamps for this chunk's tokens
                         KPipeline.join_timestamps(
                             result.tokens, result.pred_dur
                         )

                         # Add timestamps with offset
                         for token in result.tokens:
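For readers unfamiliar with this code path: join_timestamps merges the predicted durations into the chunk's tokens, and the loop that follows shifts each token's chunk-relative timing onto the whole-stream timeline. A minimal sketch of that offsetting step, with plain dicts standing in for kokoro's token objects (the field names here are assumptions, not the library's API):

```python
# Hypothetical tokens with chunk-relative timings, in seconds:
chunk_tokens = [
    {"text": "hello", "start_ts": 0.00, "end_ts": 0.40},
    {"text": "world", "start_ts": 0.45, "end_ts": 0.90},
]
current_offset = 12.5  # seconds of audio already emitted by earlier chunks

# Shift every token onto the stream timeline:
word_timestamps = [
    {
        "word": tok["text"],
        "start_time": tok["start_ts"] + current_offset,
        "end_time": tok["end_ts"] + current_offset,
    }
    for tok in chunk_tokens
]
print(word_timestamps)
```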
@@ -1,20 +1,24 @@
 import base64
 import json
 import os
 import re
-from typing import List, Union, AsyncGenerator, Tuple
+from pathlib import Path
+from typing import AsyncGenerator, List, Tuple, Union

 import numpy as np
 import torch
 from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
-from fastapi.responses import JSONResponse, StreamingResponse, FileResponse
+from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
 from kokoro import KPipeline
 from loguru import logger

-from ..inference.base import AudioChunk
 from ..core.config import settings
+from ..inference.base import AudioChunk
 from ..services.audio import AudioNormalizer, AudioService
 from ..services.streaming_audio_writer import StreamingAudioWriter
+from ..services.temp_manager import TempFileWriter
 from ..services.text_processing import smart_split
 from ..services.tts_service import TTSService
-from ..services.temp_manager import TempFileWriter
+from ..structures import CaptionedSpeechRequest, CaptionedSpeechResponse, WordTimestamp
+from ..structures.custom_responses import JSONStreamingResponse
 from ..structures.text_schemas import (

@@ -22,12 +26,7 @@ from ..structures.text_schemas import (
     PhonemeRequest,
     PhonemeResponse,
 )
-from .openai_compatible import process_voices, stream_audio_chunks
-import json
-import os
-import base64
-from pathlib import Path
+from .openai_compatible import process_and_validate_voices, stream_audio_chunks

 router = APIRouter(tags=["text processing"])
@@ -169,7 +168,7 @@ async def create_captioned_speech(
     try:
         # model_name = get_model_name(request.model)
         tts_service = await get_tts_service()
-        voice_name = await process_voices(request.voice, tts_service)
+        voice_name = await process_and_validate_voices(request.voice, tts_service)

         # Set content type based on format
         content_type = {
@@ -254,7 +253,10 @@ async def create_captioned_speech(
                     base64_chunk = base64.b64encode(chunk_data.output).decode("utf-8")

                     # Add any chunks that may be in the accumulator into the returned word_timestamps
-                    chunk_data.word_timestamps = timestamp_acumulator + chunk_data.word_timestamps
+                    if chunk_data.word_timestamps != None:
+                        chunk_data.word_timestamps = timestamp_acumulator + chunk_data.word_timestamps
+                    else:
+                        chunk_data.word_timestamps = []
                     timestamp_acumulator = []

                     yield CaptionedSpeechResponse(audio=base64_chunk, audio_format=content_type, timestamps=chunk_data.word_timestamps)
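The practical effect of this hunk is that the timestamps field of each streamed response is now always a list, possibly empty, rather than None. A client-side sketch of consuming the stream (the request-body fields are assumptions based on the example script earlier in this diff):

```python
import base64
import json

import requests

response = requests.post(
    "http://localhost:8880/dev/captioned_speech",
    json={
        "model": "kokoro",
        "input": "Hello world",
        "voice": "af_bella",
        "response_format": "mp3",
        "stream": True,
    },
    stream=True,
)

with open("output.mp3", "wb") as f:
    for line in response.iter_lines(decode_unicode=True):
        if not line:
            continue
        chunk_json = json.loads(line)
        f.write(base64.b64decode(chunk_json["audio"]))
        # "timestamps" is now always a list, so no None check is needed:
        for ts in chunk_json["timestamps"]:
            print(ts["word"], ts["start_time"], ts["end_time"])
```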
@@ -5,20 +5,19 @@ import json
 import os
 import re
 import tempfile
-from typing import AsyncGenerator, Dict, List, Union, Tuple
-from urllib import response
-import numpy as np
+from typing import AsyncGenerator, Dict, List, Tuple, Union

 import aiofiles
-
-from structures.schemas import CaptionedSpeechRequest
+import numpy as np
 import torch
 from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
 from fastapi.responses import FileResponse, StreamingResponse
 from loguru import logger
+from structures.schemas import CaptionedSpeechRequest

-from ..inference.base import AudioChunk
 from ..core.config import settings
+from ..inference.base import AudioChunk
 from ..services.audio import AudioService
 from ..services.tts_service import TTSService
 from ..structures import OpenAISpeechRequest
@@ -80,7 +79,7 @@ def get_model_name(model: str) -> str:
     return base_name + ".pth"


-async def process_voices(
+async def process_and_validate_voices(
     voice_input: Union[str, List[str]], tts_service: TTSService
 ) -> str:
     """Process voice input, handling both string and list formats
@@ -88,53 +87,57 @@ async def process_voices(
     Returns:
         Voice name to use (with weights if specified)
     """
-    voices = []
     # Convert input to list of voices
     if isinstance(voice_input, str):
-        # Check if it's an OpenAI voice name
-        mapped_voice = _openai_mappings["voices"].get(voice_input)
-        if mapped_voice:
-            voice_input = mapped_voice
-        # Split on + but preserve any parentheses
-        voices = []
-        for part in voice_input.split("+"):
-            part = part.strip()
-            if not part:
-                continue
-            # Extract voice name without weight
-            voice_name = part.split("(")[0].strip()
-            # Check if it's a valid voice
-            available_voices = await tts_service.list_voices()
-            if voice_name not in available_voices:
-                raise ValueError(
-                    f"Voice '{voice_name}' not found. Available voices: {', '.join(sorted(available_voices))}"
-                )
-            voices.append(part)
+        voice_input = voice_input.replace(" ", "").strip()
+
+        if voice_input[-1] in "+-" or voice_input[0] in "+-":
+            raise ValueError("Voice combination contains empty combine items")
+
+        if re.search(r"[+-]{2,}", voice_input) is not None:
+            raise ValueError("Voice combination contains empty combine items")
+        voices = re.split(r"([-+])", voice_input)
     else:
-        # For list input, map each voice if it's an OpenAI voice name
-        voices = []
-        for v in voice_input:
-            mapped = _openai_mappings["voices"].get(v, v)
-            voice_name = mapped.split("(")[0].strip()
-            # Check if it's a valid voice
-            available_voices = await tts_service.list_voices()
-            if voice_name not in available_voices:
-                raise ValueError(
-                    f"Voice '{voice_name}' not found. Available voices: {', '.join(sorted(available_voices))}"
-                )
-            voices.append(mapped)
-
-    if not voices:
-        raise ValueError("No voices provided")
-
-    # For multiple voices, combine them with +
-    return "+".join(voices)
+        voices = [[item, "+"] for item in voice_input][:-1]
+
+    available_voices = await tts_service.list_voices()
+
+    for voice_index in range(0, len(voices), 2):
+        mapped_voice = voices[voice_index].split("(")
+        mapped_voice = list(map(str.strip, mapped_voice))
+
+        if len(mapped_voice) > 2:
+            raise ValueError(
+                f"Voice '{voices[voice_index]}' contains too many weight items"
+            )
+
+        if mapped_voice.count(")") > 1:
+            raise ValueError(
+                f"Voice '{voices[voice_index]}' contains too many weight items"
+            )
+
+        mapped_voice[0] = _openai_mappings["voices"].get(mapped_voice[0], mapped_voice[0])
+
+        if mapped_voice[0] not in available_voices:
+            raise ValueError(
+                f"Voice '{mapped_voice[0]}' not found. Available voices: {', '.join(sorted(available_voices))}"
+            )
+
+        voices[voice_index] = "(".join(mapped_voice)
+
+    return "".join(voices)


 async def stream_audio_chunks(
     tts_service: TTSService, request: Union[OpenAISpeechRequest, CaptionedSpeechRequest], client_request: Request
 ) -> AsyncGenerator[AudioChunk, None]:
     """Stream audio chunks as they're generated with client disconnect handling"""
-    voice_name = await process_voices(request.voice, tts_service)
+    voice_name = await process_and_validate_voices(request.voice, tts_service)

     unique_properties = {
         "return_timestamps": False
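The rewritten function validates and parses the combination string up front instead of silently splitting on +. A hypothetical stand-alone version of just the parsing rules, to show what the regular expressions accept and reject:

```python
import re


def parse_voice_combination(voice_input: str) -> list:
    """Sketch of the string branch above (not the router's actual helper)."""
    voice_input = voice_input.replace(" ", "").strip()

    # A leading or trailing operator means an empty combine item.
    if voice_input[0] in "+-" or voice_input[-1] in "+-":
        raise ValueError("Voice combination contains empty combine items")

    # Two or more operators in a row also mean an empty item.
    if re.search(r"[+-]{2,}", voice_input):
        raise ValueError("Voice combination contains empty combine items")

    # The capturing group keeps the operators in the result list.
    return re.split(r"([-+])", voice_input)


print(parse_voice_combination("af_bella(2)+af_sky-af_nicole"))
# ['af_bella(2)', '+', 'af_sky', '-', 'af_nicole']
```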
@@ -191,7 +194,7 @@ async def create_speech(
     try:
         # model_name = get_model_name(request.model)
         tts_service = await get_tts_service()
-        voice_name = await process_voices(request.voice, tts_service)
+        voice_name = await process_and_validate_voices(request.voice, tts_service)

         # Set content type based on format
         content_type = {
@@ -121,7 +121,7 @@ class AudioService:
         is_last_chunk: bool = False,
         trim_audio: bool = True,
         normalizer: AudioNormalizer = None,
-    ) -> Tuple[AudioChunk]:
+    ) -> AudioChunk:
        """Convert audio data to specified format with streaming support

        Args:
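The annotation fix is worth a note: Tuple[AudioChunk] declares a one-element tuple, not "an AudioChunk", and the method returns a bare chunk. A toy illustration of the difference:

```python
from typing import Tuple


def one_element_tuple() -> Tuple[int]:  # promises (5,) style values
    return (5,)


def bare_value() -> int:  # promises a plain 5
    return 5
```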
@@ -2,24 +2,26 @@

 import asyncio
 import os
+import re
 import tempfile
 import time
 from typing import AsyncGenerator, List, Optional, Tuple, Union

-from ..inference.base import AudioChunk
 import numpy as np
 import torch
 from kokoro import KPipeline
 from loguru import logger

 from ..core.config import settings
+from ..inference.base import AudioChunk
 from ..inference.kokoro_v1 import KokoroV1
 from ..inference.model_manager import get_manager as get_model_manager
 from ..inference.voice_manager import get_manager as get_voice_manager
+from ..structures.schemas import NormalizationOptions
 from .audio import AudioNormalizer, AudioService
 from .text_processing import tokenize
 from .text_processing.text_processor import process_text_chunk, smart_split
-from ..structures.schemas import NormalizationOptions


 class TTSService:
     """Text-to-speech service."""
@@ -163,7 +165,15 @@ class TTSService:
         except Exception as e:
             logger.error(f"Failed to process tokens: {str(e)}")

-    async def _get_voice_path(self, voice: str) -> Tuple[str, str]:
+    async def _load_voice_from_path(self, path: str, weight: float):
+        # Raise a ValueError if no voice file exists at the given path
+        if not path:
+            raise ValueError(f"Voice not found at path: {path}")
+
+        logger.debug(f"Loading voice tensor from path: {path}")
+        return torch.load(path, map_location="cpu") * weight
+
+    async def _get_voices_path(self, voice: str) -> Tuple[str, str]:
         """Get voice path, handling combined voices.

         Args:
@@ -176,64 +186,62 @@ class TTSService:
             RuntimeError: If voice not found
         """
         try:
-            # Check if it's a combined voice
-            if "+" in voice:
-                # Split on + but preserve any parentheses
-                voice_parts = []
-                weights = []
-                for part in voice.split("+"):
-                    part = part.strip()
-                    if not part:
-                        continue
-                    # Extract voice name and weight if present
-                    if "(" in part and ")" in part:
-                        voice_name = part.split("(")[0].strip()
-                        weight = float(part.split("(")[1].split(")")[0])
-                    else:
-                        voice_name = part
-                        weight = 1.0
-                    voice_parts.append(voice_name)
-                    weights.append(weight)
-
-                if len(voice_parts) < 2:
-                    raise RuntimeError(f"Invalid combined voice name: {voice}")
-
-                # Normalize weights to sum to 1
-                total_weight = sum(weights)
-                weights = [w / total_weight for w in weights]
-
-                # Load and combine voices
-                voice_tensors = []
-                for v, w in zip(voice_parts, weights):
-                    path = await self._voice_manager.get_voice_path(v)
-                    if not path:
-                        raise RuntimeError(f"Voice not found: {v}")
-                    logger.debug(f"Loading voice tensor from: {path}")
-                    voice_tensor = torch.load(path, map_location="cpu")
-                    voice_tensors.append(voice_tensor * w)
-
-                # Sum the weighted voice tensors
-                logger.debug(
-                    f"Combining {len(voice_tensors)} voice tensors with weights {weights}"
-                )
-                combined = torch.sum(torch.stack(voice_tensors), dim=0)
-
-                # Save combined tensor
-                temp_dir = tempfile.gettempdir()
-                combined_path = os.path.join(temp_dir, f"{voice}.pt")
-                logger.debug(f"Saving combined voice to: {combined_path}")
-                torch.save(combined, combined_path)
-
-                return voice, combined_path
-            else:
-                # Single voice
-                if "(" in voice and ")" in voice:
-                    voice = voice.split("(")[0].strip()
-                path = await self._voice_manager.get_voice_path(voice)
-                if not path:
-                    raise RuntimeError(f"Voice not found: {voice}")
-                logger.debug(f"Using single voice path: {path}")
-                return voice, path
+            # Split the voice on + and -, keeping the operators in the list, e.g. "hi+bob" -> ["hi", "+", "bob"]
+            split_voice = re.split(r"([-+])", voice)
+
+            # If it is only one voice there is no point in loading it, doing nothing with it, then saving it again
+            if len(split_voice) == 1:
+
+                # Since it's a single voice, the only time the weight matters is if voice_weight_normalization is off
+                if ("(" not in voice and ")" not in voice) or settings.voice_weight_normalization == True:
+
+                    path = await self._voice_manager.get_voice_path(voice)
+                    if not path:
+                        raise RuntimeError(f"Voice not found: {voice}")
+                    logger.debug(f"Using single voice path: {path}")
+                    return voice, path
+
+            total_weight = 0
+
+            for voice_index in range(0, len(split_voice), 2):
+                voice_object = split_voice[voice_index]
+
+                if "(" in voice_object and ")" in voice_object:
+                    voice_name = voice_object.split("(")[0].strip()
+                    voice_weight = float(voice_object.split("(")[1].split(")")[0])
+                else:
+                    voice_name = voice_object
+                    voice_weight = 1
+
+                total_weight += voice_weight
+                split_voice[voice_index] = (voice_name, voice_weight)
+
+            # If voice_weight_normalization is false, skip normalizing by setting total_weight to 1 so each weight is divided by 1
+            if settings.voice_weight_normalization == False:
+                total_weight = 1
+
+            # Load the first voice as the starting point for the others to be combined onto
+            path = await self._voice_manager.get_voice_path(split_voice[0][0])
+            combined_tensor = await self._load_voice_from_path(path, split_voice[0][1] / total_weight)
+
+            # Loop through each + or - in split_voice and apply it to the combined voice
+            for operation_index in range(1, len(split_voice) - 1, 2):
+                # Get the path of the voice one index ahead of the operator
+                path = await self._voice_manager.get_voice_path(split_voice[operation_index + 1][0])
+                voice_tensor = await self._load_voice_from_path(path, split_voice[operation_index + 1][1] / total_weight)
+
+                # Either add or subtract the voice from the current combined voice
+                if split_voice[operation_index] == "+":
+                    combined_tensor += voice_tensor
+                else:
+                    combined_tensor -= voice_tensor
+
+            # Save the new combined voice so it can be loaded later
+            temp_dir = tempfile.gettempdir()
+            combined_path = os.path.join(temp_dir, f"{voice}.pt")
+            logger.debug(f"Saving combined voice to: {combined_path}")
+            torch.save(combined_tensor, combined_path)
+            return voice, combined_path
         except Exception as e:
             logger.error(f"Failed to get voice path: {e}")
             raise
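The combination itself is plain tensor arithmetic: each voice pack is loaded, pre-scaled by weight / total_weight, then added to or subtracted from a running tensor. A minimal sketch of that arithmetic (the tensor shape and file names are illustrative, not taken from the repository):

```python
import torch


def load_voice(path: str, weight: float) -> torch.Tensor:
    # Mirrors _load_voice_from_path: load on CPU, pre-scale by the weight.
    return torch.load(path, map_location="cpu") * weight


# Stand-ins for two loaded, weight-scaled voice packs:
voice_a = torch.randn(510, 1, 256) * (1.0 / 2.0)  # weight 1.0, total_weight 2.0
voice_b = torch.randn(510, 1, 256) * (1.0 / 2.0)

# "a+b" adds the second tensor; "a-b" would subtract it instead:
combined = voice_a + voice_b

torch.save(combined, "combined_voice.pt")
```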
@@ -257,7 +265,7 @@ class TTSService:
             backend = self.model_manager.get_backend()

             # Get voice path, handling combined voices
-            voice_name, voice_path = await self._get_voice_path(voice)
+            voice_name, voice_path = await self._get_voices_path(voice)
             logger.debug(f"Using voice path: {voice_path}")

             # Use provided lang_code or determine from voice name
@@ -362,6 +370,7 @@ class TTSService:
         Returns:
             Combined voice tensor
         """
+
         return await self._voice_manager.combine_voices(voices)

     async def list_voices(self) -> List[str]:
@@ -390,7 +399,7 @@ class TTSService:
         try:
             # Get backend and voice path
             backend = self.model_manager.get_backend()
-            voice_name, voice_path = await self._get_voice_path(voice)
+            voice_name, voice_path = await self._get_voices_path(voice)

             if isinstance(backend, KokoroV1):
                 # For Kokoro V1, use generate_from_tokens with raw phonemes
@@ -74,7 +74,7 @@ async def test_get_voice_path_single():
     mock_get_voice.return_value = voice_manager

     service = await TTSService.create("test_output")
-    name, path = await service._get_voice_path("voice1")
+    name, path = await service._get_voices_path("voice1")
     assert name == "voice1"
     assert path == "/path/to/voice1.pt"
     voice_manager.get_voice_path.assert_called_once_with("voice1")

@@ -100,7 +100,7 @@ async def test_get_voice_path_combined():
     mock_load.return_value = torch.ones(10)

     service = await TTSService.create("test_output")
-    name, path = await service._get_voice_path("voice1+voice2")
+    name, path = await service._get_voices_path("voice1+voice2")
     assert name == "voice1+voice2"
     assert path.endswith("voice1+voice2.pt")
     mock_save.assert_called_once()