streaming word level time stamps

This commit is contained in:
Fireblade 2025-02-14 13:37:42 -05:00
parent 4027768920
commit 0b5ec320c7
7 changed files with 101 additions and 22 deletions

View file

@ -13,7 +13,7 @@ from ..core.config import settings
from ..core.model_config import model_config
from .base import BaseModelBackend
from .base import AudioChunk
from ..structures.schemas import WordTimestamp
class KokoroV1(BaseModelBackend):
"""Kokoro backend with controlled resource management."""
@ -281,11 +281,11 @@ class KokoroV1(BaseModelBackend):
start_time = float(token.start_ts) + current_offset
end_time = float(token.end_ts) + current_offset
word_timestamps.append(
{
"word": str(token.text).strip(),
"start_time": start_time,
"end_time": end_time,
}
WordTimestamp(
word=str(token.text).strip(),
start_time=start_time,
end_time=end_time
)
)
logger.debug(
f"Added timestamp for word '{token.text}': {start_time:.3f}s - {end_time:.3f}s"

View file

@ -15,6 +15,7 @@ from ..services.text_processing import smart_split
from ..services.tts_service import TTSService
from ..services.temp_manager import TempFileWriter
from ..structures import CaptionedSpeechRequest, CaptionedSpeechResponse, WordTimestamp
from ..structures.custom_responses import JSONStreamingResponse
from ..structures.text_schemas import (
GenerateFromPhonemesRequest,
PhonemeRequest,
@ -23,6 +24,7 @@ from ..structures.text_schemas import (
from .openai_compatible import process_voices, stream_audio_chunks
import json
import os
import base64
from pathlib import Path
@ -240,12 +242,10 @@ async def create_captioned_speech(
async def dual_output():
try:
# Write chunks to temp file and stream
async for chunk in generator:
async for chunk,chunk_data in generator:
if chunk: # Skip empty chunks
await temp_writer.write(chunk)
#if return_json:
# yield chunk, chunk_data
#else:
yield chunk
# Finalize the temp file
@ -260,14 +260,29 @@ async def create_captioned_speech(
await temp_writer.__aexit__(None, None, None)
# Stream with temp file writing
return StreamingResponse(
dual_output(), media_type=content_type, headers=headers
return JSONStreamingResponse(
dual_output(), media_type="application/json", headers=headers
)
async def single_output():
try:
# Stream chunks
async for chunk,chunk_data in generator:
if chunk: # Skip empty chunks
# Encode the chunk bytes into base 64
base64_chunk= base64.b64encode(chunk).decode("utf-8")
yield CaptionedSpeechResponse(audio=base64_chunk,audio_format=content_type,words=chunk_data.word_timestamps)
except Exception as e:
logger.error(f"Error in single output streaming: {e}")
raise
# NEED TO DO REPLACE THE RETURN WITH A JSON OBJECT CONTAINING BOTH THE FILE AND THE WORD TIMESTAMPS
# Standard streaming without download link
return StreamingResponse(
generator,
media_type=content_type,
return JSONStreamingResponse(
single_output(),
media_type="application/json",
headers={
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
"X-Accel-Buffering": "no",
@ -283,6 +298,7 @@ async def create_captioned_speech(
speed=request.speed,
lang_code=request.lang_code,
)
content, audio_data = await AudioService.convert_audio(
audio_data,
24000,

View file

@ -11,6 +11,7 @@ import numpy as np
import aiofiles
from structures.schemas import CaptionedSpeechRequest
import torch
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from fastapi.responses import FileResponse, StreamingResponse
@ -130,11 +131,17 @@ async def process_voices(
async def stream_audio_chunks(
tts_service: TTSService, request: OpenAISpeechRequest, client_request: Request
tts_service: TTSService, request: Union[OpenAISpeechRequest,CaptionedSpeechRequest], client_request: Request
) -> AsyncGenerator[Tuple[Union[np.ndarray,bytes],AudioChunk], None]:
"""Stream audio chunks as they're generated with client disconnect handling"""
voice_name = await process_voices(request.voice, tts_service)
unique_properties={
"return_timestamps":False
}
if hasattr(request, "return_timestamps"):
unique_properties["return_timestamps"]=request.return_timestamps
try:
logger.info(f"Starting audio generation with lang_code: {request.lang_code}")
async for chunk, chunk_data in tts_service.generate_audio_stream(
@ -144,7 +151,7 @@ async def stream_audio_chunks(
output_format=request.response_format,
lang_code=request.lang_code or request.voice[0],
normalization_options=request.normalization_options,
return_timestamps=False,
return_timestamps=unique_properties["return_timestamps"],
):
# Check if client is still connected

View file

@ -192,18 +192,22 @@ class AudioService:
normalizer = AudioNormalizer()
audio_chunk.audio=normalizer.normalize(audio_chunk.audio)
trimed_samples=0
# Trim start and end if enough samples
if len(audio_chunk.audio) > (2 * normalizer.samples_to_trim):
audio_chunk.audio = audio_chunk.audio[normalizer.samples_to_trim : -normalizer.samples_to_trim]
trimed_samples+=normalizer.samples_to_trim
# Find non silent portion and trim
start_index,end_index=normalizer.find_first_last_non_silent(audio_chunk.audio,chunk_text,speed,is_last_chunk=is_last_chunk)
audio_chunk.audio=audio_chunk.audio[start_index:end_index]
trimed_samples+=start_index
if audio_chunk.word_timestamps is not None:
for timestamp in audio_chunk.word_timestamps:
timestamp["start_time"]-=start_index / 24000
timestamp["end_time"]-=start_index / 24000
timestamp.start_time-=trimed_samples / 24000
timestamp.end_time-=trimed_samples / 24000
return audio_chunk

View file

@ -282,8 +282,8 @@ class TTSService:
):
if chunk_data.word_timestamps is not None:
for timestamp in chunk_data.word_timestamps:
timestamp["start_time"]+=current_offset
timestamp["end_time"]+=current_offset
timestamp.start_time+=current_offset
timestamp.end_time+=current_offset
current_offset+=len(chunk_data.audio) / 24000

View file

@ -0,0 +1,51 @@
from collections.abc import AsyncIterable, Iterable
import json
import typing
from pydantic import BaseModel
from starlette.background import BackgroundTask
from starlette.concurrency import iterate_in_threadpool
from starlette.responses import JSONResponse, StreamingResponse
class JSONStreamingResponse(StreamingResponse, JSONResponse):
"""StreamingResponse that also render with JSON."""
def __init__(
self,
content: Iterable | AsyncIterable,
status_code: int = 200,
headers: dict[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
) -> None:
if isinstance(content, AsyncIterable):
self._content_iterable: AsyncIterable = content
else:
self._content_iterable = iterate_in_threadpool(content)
async def body_iterator() -> AsyncIterable[bytes]:
async for content_ in self._content_iterable:
if isinstance(content_, BaseModel):
content_ = content_.model_dump()
yield self.render(content_)
self.body_iterator = body_iterator()
self.status_code = status_code
if media_type is not None:
self.media_type = media_type
self.background = background
self.init_headers(headers)
def render(self, content: typing.Any) -> bytes:
return (json.dumps(
content,
ensure_ascii=False,
allow_nan=False,
indent=None,
separators=(",", ":"),
) + "\n").encode("utf-8")

View file

@ -33,7 +33,8 @@ class WordTimestamp(BaseModel):
class CaptionedSpeechResponse(BaseModel):
"""Response schema for captioned speech endpoint"""
audio: bytes = Field(..., description="The generated audio data")
audio: str = Field(..., description="The generated audio data encoded in base 64")
audio_format: str = Field(..., description="The format of the output audio")
words: List[WordTimestamp] = Field(..., description="Word-level timestamps")
class NormalizationOptions(BaseModel):