From 0b5ec320c769367cb7d0edcdf47f597a232858f2 Mon Sep 17 00:00:00 2001
From: Fireblade
Date: Fri, 14 Feb 2025 13:37:42 -0500
Subject: [PATCH] Streaming word-level timestamps

---
 api/src/inference/kokoro_v1.py         | 12 +++---
 api/src/routers/development.py         | 34 ++++++++++++-----
 api/src/routers/openai_compatible.py   | 11 +++++-
 api/src/services/audio.py              |  8 +++-
 api/src/services/tts_service.py        |  4 +-
 api/src/structures/custom_responses.py | 51 ++++++++++++++++++++++++++
 api/src/structures/schemas.py          |  3 +-
 7 files changed, 101 insertions(+), 22 deletions(-)
 create mode 100644 api/src/structures/custom_responses.py

diff --git a/api/src/inference/kokoro_v1.py b/api/src/inference/kokoro_v1.py
index 3361ade..419ade7 100644
--- a/api/src/inference/kokoro_v1.py
+++ b/api/src/inference/kokoro_v1.py
@@ -13,7 +13,7 @@
 from ..core.config import settings
 from ..core.model_config import model_config
 from .base import BaseModelBackend
 from .base import AudioChunk
-
+from ..structures.schemas import WordTimestamp
 class KokoroV1(BaseModelBackend):
     """Kokoro backend with controlled resource management."""
@@ -281,11 +281,11 @@ class KokoroV1(BaseModelBackend):
                         start_time = float(token.start_ts) + current_offset
                         end_time = float(token.end_ts) + current_offset
                         word_timestamps.append(
-                            {
-                                "word": str(token.text).strip(),
-                                "start_time": start_time,
-                                "end_time": end_time,
-                            }
+                            WordTimestamp(
+                                word=str(token.text).strip(),
+                                start_time=start_time,
+                                end_time=end_time
+                            )
                         )
                         logger.debug(
                             f"Added timestamp for word '{token.text}': {start_time:.3f}s - {end_time:.3f}s"
diff --git a/api/src/routers/development.py b/api/src/routers/development.py
index e02c9b5..56226b6 100644
--- a/api/src/routers/development.py
+++ b/api/src/routers/development.py
@@ -15,6 +15,7 @@ from ..services.text_processing import smart_split
 from ..services.tts_service import TTSService
 from ..services.temp_manager import TempFileWriter
 from ..structures import CaptionedSpeechRequest, CaptionedSpeechResponse, WordTimestamp
+from ..structures.custom_responses import JSONStreamingResponse
 from ..structures.text_schemas import (
     GenerateFromPhonemesRequest,
     PhonemeRequest,
@@ -23,6 +24,7 @@
 from .openai_compatible import process_voices, stream_audio_chunks
 import json
 import os
+import base64
 from pathlib import Path
 
 
@@ -240,12 +242,10 @@ async def create_captioned_speech(
             async def dual_output():
                 try:
                     # Write chunks to temp file and stream
-                    async for chunk in generator:
+                    async for chunk, chunk_data in generator:
                         if chunk:  # Skip empty chunks
                             await temp_writer.write(chunk)
-                            #if return_json:
-                            #    yield chunk, chunk_data
-                            #else:
+
                             yield chunk
 
                     # Finalize the temp file
@@ -260,14 +260,29 @@ async def create_captioned_speech(
                     await temp_writer.__aexit__(None, None, None)
 
             # Stream with temp file writing
-            return StreamingResponse(
-                dual_output(), media_type=content_type, headers=headers
+            return JSONStreamingResponse(
+                dual_output(), media_type="application/json", headers=headers
             )
 
+        async def single_output():
+            try:
+                # Stream chunks
+                async for chunk, chunk_data in generator:
+                    if chunk:  # Skip empty chunks
+                        # Encode the chunk bytes into base64
+                        base64_chunk = base64.b64encode(chunk).decode("utf-8")
+
+                        yield CaptionedSpeechResponse(audio=base64_chunk, audio_format=content_type, words=chunk_data.word_timestamps)
+            except Exception as e:
+                logger.error(f"Error in single output streaming: {e}")
+                raise
+
+        # Stream the audio chunks and their word timestamps together as
+        # newline-delimited JSON (one CaptionedSpeechResponse per line)
         # Standard streaming without download link
-        return StreamingResponse(
-            generator,
-            media_type=content_type,
+        return JSONStreamingResponse(
+            single_output(),
+            media_type="application/json",
             headers={
                 "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
                 "X-Accel-Buffering": "no",
@@ -283,6 +298,7 @@ async def create_captioned_speech(
             speed=request.speed,
             lang_code=request.lang_code,
         )
+
         content, audio_data = await AudioService.convert_audio(
             audio_data,
             24000,
diff --git a/api/src/routers/openai_compatible.py b/api/src/routers/openai_compatible.py
index 771830b..98eed83 100644
--- a/api/src/routers/openai_compatible.py
+++ b/api/src/routers/openai_compatible.py
@@ -11,6 +11,7 @@
 
 import numpy as np
 import aiofiles
+from ..structures.schemas import CaptionedSpeechRequest
 import torch
 from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
 from fastapi.responses import FileResponse, StreamingResponse
@@ -130,11 +131,17 @@ async def process_voices(
 
 
 async def stream_audio_chunks(
-    tts_service: TTSService, request: OpenAISpeechRequest, client_request: Request
+    tts_service: TTSService, request: Union[OpenAISpeechRequest, CaptionedSpeechRequest], client_request: Request
 ) -> AsyncGenerator[Tuple[Union[np.ndarray, bytes], AudioChunk], None]:
     """Stream audio chunks as they're generated with client disconnect handling"""
     voice_name = await process_voices(request.voice, tts_service)
 
+    unique_properties = {
+        "return_timestamps": False
+    }
+    if hasattr(request, "return_timestamps"):
+        unique_properties["return_timestamps"] = request.return_timestamps
+
     try:
         logger.info(f"Starting audio generation with lang_code: {request.lang_code}")
         async for chunk, chunk_data in tts_service.generate_audio_stream(
@@ -144,7 +151,7 @@
             output_format=request.response_format,
             lang_code=request.lang_code or request.voice[0],
             normalization_options=request.normalization_options,
-            return_timestamps=False,
+            return_timestamps=unique_properties["return_timestamps"],
         ):
 
             # Check if client is still connected
diff --git a/api/src/services/audio.py b/api/src/services/audio.py
index c713c09..7fdb49f 100644
--- a/api/src/services/audio.py
+++ b/api/src/services/audio.py
@@ -192,18 +192,22 @@ class AudioService:
 
         normalizer = AudioNormalizer()
 
         audio_chunk.audio = normalizer.normalize(audio_chunk.audio)
+
+        trimmed_samples = 0
         # Trim start and end if enough samples
         if len(audio_chunk.audio) > (2 * normalizer.samples_to_trim):
             audio_chunk.audio = audio_chunk.audio[normalizer.samples_to_trim : -normalizer.samples_to_trim]
+            trimmed_samples += normalizer.samples_to_trim
 
         # Find non silent portion and trim
         start_index, end_index = normalizer.find_first_last_non_silent(audio_chunk.audio, chunk_text, speed, is_last_chunk=is_last_chunk)
         audio_chunk.audio = audio_chunk.audio[start_index:end_index]
+        trimmed_samples += start_index
 
         if audio_chunk.word_timestamps is not None:
             for timestamp in audio_chunk.word_timestamps:
-                timestamp["start_time"] -= start_index / 24000
-                timestamp["end_time"] -= start_index / 24000
+                timestamp.start_time -= trimmed_samples / 24000
+                timestamp.end_time -= trimmed_samples / 24000
 
         return audio_chunk
\ No newline at end of file
diff --git a/api/src/services/tts_service.py b/api/src/services/tts_service.py
index a8e05ad..c3b6d26 100644
--- a/api/src/services/tts_service.py
+++ b/api/src/services/tts_service.py
@@ -282,8 +282,8 @@ class TTSService:
             ):
                 if chunk_data.word_timestamps is not None:
                     for timestamp in chunk_data.word_timestamps:
timestamp["start_time"]+=current_offset - timestamp["end_time"]+=current_offset + timestamp.start_time+=current_offset + timestamp.end_time+=current_offset current_offset+=len(chunk_data.audio) / 24000 diff --git a/api/src/structures/custom_responses.py b/api/src/structures/custom_responses.py new file mode 100644 index 0000000..3d3a987 --- /dev/null +++ b/api/src/structures/custom_responses.py @@ -0,0 +1,51 @@ +from collections.abc import AsyncIterable, Iterable + +import json +import typing +from pydantic import BaseModel +from starlette.background import BackgroundTask +from starlette.concurrency import iterate_in_threadpool +from starlette.responses import JSONResponse, StreamingResponse + + +class JSONStreamingResponse(StreamingResponse, JSONResponse): + """StreamingResponse that also render with JSON.""" + + def __init__( + self, + content: Iterable | AsyncIterable, + status_code: int = 200, + headers: dict[str, str] | None = None, + media_type: str | None = None, + background: BackgroundTask | None = None, + ) -> None: + if isinstance(content, AsyncIterable): + self._content_iterable: AsyncIterable = content + else: + self._content_iterable = iterate_in_threadpool(content) + + + + async def body_iterator() -> AsyncIterable[bytes]: + async for content_ in self._content_iterable: + if isinstance(content_, BaseModel): + content_ = content_.model_dump() + yield self.render(content_) + + + + self.body_iterator = body_iterator() + self.status_code = status_code + if media_type is not None: + self.media_type = media_type + self.background = background + self.init_headers(headers) + + def render(self, content: typing.Any) -> bytes: + return (json.dumps( + content, + ensure_ascii=False, + allow_nan=False, + indent=None, + separators=(",", ":"), + ) + "\n").encode("utf-8") \ No newline at end of file diff --git a/api/src/structures/schemas.py b/api/src/structures/schemas.py index b838b38..fe9d816 100644 --- a/api/src/structures/schemas.py +++ b/api/src/structures/schemas.py @@ -33,7 +33,8 @@ class WordTimestamp(BaseModel): class CaptionedSpeechResponse(BaseModel): """Response schema for captioned speech endpoint""" - audio: bytes = Field(..., description="The generated audio data") + audio: str = Field(..., description="The generated audio data encoded in base 64") + audio_format: str = Field(..., description="The format of the output audio") words: List[WordTimestamp] = Field(..., description="Word-level timestamps") class NormalizationOptions(BaseModel):