Streaming word-level timestamps

Fireblade 2025-02-14 13:37:42 -05:00
parent 4027768920
commit 0b5ec320c7
7 changed files with 101 additions and 22 deletions
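
For context on what this commit enables end to end, here is a minimal client sketch (not part of the commit) showing how the streamed captioned-speech response might be consumed. The endpoint path, port, and request fields are assumptions taken from the surrounding codebase rather than from this diff; each streamed line is expected to be one CaptionedSpeechResponse rendered as a JSON object by the new JSONStreamingResponse.

# Hypothetical client; URL, port, and payload fields are assumptions, not from this diff.
import base64
import json
import requests

payload = {
    "input": "Hello there, how are you today?",
    "voice": "af_bella",            # assumed voice name
    "response_format": "mp3",
    "stream": True,
    "return_timestamps": True,
}

audio = bytearray()
with requests.post("http://localhost:8880/dev/captioned_speech", json=payload, stream=True) as resp:
    resp.raise_for_status()
    # JSONStreamingResponse emits one JSON object per line (newline-delimited JSON).
    for line in resp.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)
        audio.extend(base64.b64decode(chunk["audio"]))
        for word in chunk["words"]:
            print(f'{word["word"]}: {word["start_time"]:.3f}s - {word["end_time"]:.3f}s')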

View file

@@ -13,7 +13,7 @@ from ..core.config import settings
 from ..core.model_config import model_config
 from .base import BaseModelBackend
 from .base import AudioChunk
+from ..structures.schemas import WordTimestamp

 class KokoroV1(BaseModelBackend):
     """Kokoro backend with controlled resource management."""
@@ -281,11 +281,11 @@ class KokoroV1(BaseModelBackend):
                         start_time = float(token.start_ts) + current_offset
                         end_time = float(token.end_ts) + current_offset
                         word_timestamps.append(
-                            {
-                                "word": str(token.text).strip(),
-                                "start_time": start_time,
-                                "end_time": end_time,
-                            }
+                            WordTimestamp(
+                                word=str(token.text).strip(),
+                                start_time=start_time,
+                                end_time=end_time
+                            )
                         )
                         logger.debug(
                             f"Added timestamp for word '{token.text}': {start_time:.3f}s - {end_time:.3f}s"

View file

@@ -15,6 +15,7 @@ from ..services.text_processing import smart_split
 from ..services.tts_service import TTSService
 from ..services.temp_manager import TempFileWriter
 from ..structures import CaptionedSpeechRequest, CaptionedSpeechResponse, WordTimestamp
+from ..structures.custom_responses import JSONStreamingResponse
 from ..structures.text_schemas import (
     GenerateFromPhonemesRequest,
     PhonemeRequest,
@@ -23,6 +24,7 @@ from ..structures.text_schemas import (
 from .openai_compatible import process_voices, stream_audio_chunks
 import json
 import os
+import base64
 from pathlib import Path
@@ -240,12 +242,10 @@ async def create_captioned_speech(
         async def dual_output():
             try:
                 # Write chunks to temp file and stream
-                async for chunk in generator:
+                async for chunk, chunk_data in generator:
                     if chunk:  # Skip empty chunks
                         await temp_writer.write(chunk)
-                        #if return_json:
-                        #    yield chunk, chunk_data
-                        #else:
                         yield chunk

                 # Finalize the temp file
@@ -260,14 +260,29 @@ async def create_captioned_speech(
                 await temp_writer.__aexit__(None, None, None)

         # Stream with temp file writing
-        return StreamingResponse(
-            dual_output(), media_type=content_type, headers=headers
-        )
+        return JSONStreamingResponse(
+            dual_output(), media_type="application/json", headers=headers
+        )
+
+        async def single_output():
+            try:
+                # Stream chunks
+                async for chunk, chunk_data in generator:
+                    if chunk:  # Skip empty chunks
+                        # Encode the chunk bytes into base 64
+                        base64_chunk = base64.b64encode(chunk).decode("utf-8")
+                        yield CaptionedSpeechResponse(audio=base64_chunk, audio_format=content_type, words=chunk_data.word_timestamps)
+            except Exception as e:
+                logger.error(f"Error in single output streaming: {e}")
+                raise
-        # NEED TO DO REPLACE THE RETURN WITH A JSON OBJECT CONTAINING BOTH THE FILE AND THE WORD TIMESTAMPS
         # Standard streaming without download link
-        return StreamingResponse(
-            generator,
-            media_type=content_type,
+        return JSONStreamingResponse(
+            single_output(),
+            media_type="application/json",
             headers={
                 "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
                 "X-Accel-Buffering": "no",
@@ -283,6 +298,7 @@ async def create_captioned_speech(
             speed=request.speed,
             lang_code=request.lang_code,
         )
         content, audio_data = await AudioService.convert_audio(
             audio_data,
             24000,

View file

@@ -11,6 +11,7 @@ import numpy as np
 import aiofiles
+from structures.schemas import CaptionedSpeechRequest
 import torch
 from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
 from fastapi.responses import FileResponse, StreamingResponse
@@ -130,11 +131,17 @@ async def process_voices(
 async def stream_audio_chunks(
-    tts_service: TTSService, request: OpenAISpeechRequest, client_request: Request
+    tts_service: TTSService, request: Union[OpenAISpeechRequest, CaptionedSpeechRequest], client_request: Request
 ) -> AsyncGenerator[Tuple[Union[np.ndarray, bytes], AudioChunk], None]:
     """Stream audio chunks as they're generated with client disconnect handling"""
     voice_name = await process_voices(request.voice, tts_service)

+    unique_properties = {
+        "return_timestamps": False
+    }
+    if hasattr(request, "return_timestamps"):
+        unique_properties["return_timestamps"] = request.return_timestamps
+
     try:
         logger.info(f"Starting audio generation with lang_code: {request.lang_code}")
         async for chunk, chunk_data in tts_service.generate_audio_stream(
@@ -144,7 +151,7 @@ async def stream_audio_chunks(
             output_format=request.response_format,
             lang_code=request.lang_code or request.voice[0],
             normalization_options=request.normalization_options,
-            return_timestamps=False,
+            return_timestamps=unique_properties["return_timestamps"],
         ):
             # Check if client is still connected
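
The hasattr gate above lets the same generator serve both OpenAISpeechRequest, which has no return_timestamps field, and CaptionedSpeechRequest, which does. An equivalent, more compact lookup (illustrative only, not what the commit uses) would be:

# Illustrative alternative to the hasattr/dict pattern above.
return_timestamps = getattr(request, "return_timestamps", False)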

View file

@@ -192,18 +192,22 @@ class AudioService:
         normalizer = AudioNormalizer()
         audio_chunk.audio = normalizer.normalize(audio_chunk.audio)
+        trimed_samples = 0

         # Trim start and end if enough samples
         if len(audio_chunk.audio) > (2 * normalizer.samples_to_trim):
             audio_chunk.audio = audio_chunk.audio[normalizer.samples_to_trim : -normalizer.samples_to_trim]
+            trimed_samples += normalizer.samples_to_trim

         # Find non silent portion and trim
         start_index, end_index = normalizer.find_first_last_non_silent(audio_chunk.audio, chunk_text, speed, is_last_chunk=is_last_chunk)
         audio_chunk.audio = audio_chunk.audio[start_index:end_index]
+        trimed_samples += start_index

         if audio_chunk.word_timestamps is not None:
             for timestamp in audio_chunk.word_timestamps:
-                timestamp["start_time"] -= start_index / 24000
-                timestamp["end_time"] -= start_index / 24000
+                timestamp.start_time -= trimed_samples / 24000
+                timestamp.end_time -= trimed_samples / 24000

         return audio_chunk
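
The accumulated trimed_samples counter matters because every sample removed ahead of a word shifts that word earlier in the trimmed audio by 1/24000 of a second. A small worked example, assuming the 24 kHz sample rate used throughout and invented trim values:

# Invented values: 60 ms of audio removed before the first word at 24 kHz.
samples_to_trim = 1200          # removed from the chunk start by the normalizer
start_index = 240               # further leading silence found afterwards
trimed_samples = samples_to_trim + start_index

shift = trimed_samples / 24000  # 0.06 s
# A word originally reported at 0.50 s now starts at 0.44 s in the trimmed chunk.
print(f"{0.50 - shift:.2f}")    # -> 0.44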

View file

@@ -282,8 +282,8 @@ class TTSService:
         ):
             if chunk_data.word_timestamps is not None:
                 for timestamp in chunk_data.word_timestamps:
-                    timestamp["start_time"] += current_offset
-                    timestamp["end_time"] += current_offset
+                    timestamp.start_time += current_offset
+                    timestamp.end_time += current_offset
             current_offset += len(chunk_data.audio) / 24000
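
This per-chunk offset is what turns chunk-relative timestamps into positions in the overall stream: words in a chunk are shifted by the total duration of all audio already emitted, then the offset grows by the current chunk's duration. A small illustration with invented numbers, again assuming 24 kHz audio:

# Invented numbers: two chunks, word start times are chunk-relative.
chunk_lengths = [36000, 48000]          # samples per chunk at 24 kHz -> 1.5 s and 2.0 s
word_starts = [[0.2, 0.9], [0.3, 1.1]]  # chunk-relative start times in seconds

current_offset = 0.0
absolute = []
for length, starts in zip(chunk_lengths, word_starts):
    absolute.extend(start + current_offset for start in starts)
    current_offset += length / 24000

print(absolute)  # [0.2, 0.9, 1.8, 2.6]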

View file

@@ -0,0 +1,51 @@
+from collections.abc import AsyncIterable, Iterable
+import json
+import typing
+
+from pydantic import BaseModel
+from starlette.background import BackgroundTask
+from starlette.concurrency import iterate_in_threadpool
+from starlette.responses import JSONResponse, StreamingResponse
+
+
+class JSONStreamingResponse(StreamingResponse, JSONResponse):
+    """StreamingResponse that also renders with JSON."""
+
+    def __init__(
+        self,
+        content: Iterable | AsyncIterable,
+        status_code: int = 200,
+        headers: dict[str, str] | None = None,
+        media_type: str | None = None,
+        background: BackgroundTask | None = None,
+    ) -> None:
+        if isinstance(content, AsyncIterable):
+            self._content_iterable: AsyncIterable = content
+        else:
+            self._content_iterable = iterate_in_threadpool(content)
+
+        async def body_iterator() -> AsyncIterable[bytes]:
+            async for content_ in self._content_iterable:
+                if isinstance(content_, BaseModel):
+                    content_ = content_.model_dump()
+                yield self.render(content_)
+
+        self.body_iterator = body_iterator()
+        self.status_code = status_code
+        if media_type is not None:
+            self.media_type = media_type
+        self.background = background
+        self.init_headers(headers)
+
+    def render(self, content: typing.Any) -> bytes:
+        return (json.dumps(
+            content,
+            ensure_ascii=False,
+            allow_nan=False,
+            indent=None,
+            separators=(",", ":"),
+        ) + "\n").encode("utf-8")

View file

@@ -33,7 +33,8 @@ class WordTimestamp(BaseModel):
 class CaptionedSpeechResponse(BaseModel):
     """Response schema for captioned speech endpoint"""

-    audio: bytes = Field(..., description="The generated audio data")
+    audio: str = Field(..., description="The generated audio data encoded in base 64")
+    audio_format: str = Field(..., description="The format of the output audio")
     words: List[WordTimestamp] = Field(..., description="Word-level timestamps")

 class NormalizationOptions(BaseModel):
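
Putting the schema together with the streaming changes above, each line of the captioned-speech stream should be one serialized CaptionedSpeechResponse. A rough sketch of a single streamed chunk, with invented values:

# Illustrative only; the base64 audio and timings are made up.
{
    "audio": "UklGRiQAAABXQVZF...",   # base64-encoded audio bytes for this chunk
    "audio_format": "mp3",
    "words": [
        {"word": "Hello", "start_time": 0.0, "end_time": 0.32},
        {"word": "there", "start_time": 0.32, "end_time": 0.58}
    ]
}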