mirror of https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
streaming word level time stamps
This commit is contained in:
parent 4027768920
commit 0b5ec320c7
7 changed files with 101 additions and 22 deletions

@@ -13,7 +13,7 @@ from ..core.config import settings
 from ..core.model_config import model_config
 from .base import BaseModelBackend
 from .base import AudioChunk
-
+from ..structures.schemas import WordTimestamp

 class KokoroV1(BaseModelBackend):
     """Kokoro backend with controlled resource management."""

@@ -281,11 +281,11 @@ class KokoroV1(BaseModelBackend):
                     start_time = float(token.start_ts) + current_offset
                     end_time = float(token.end_ts) + current_offset
                     word_timestamps.append(
-                        {
-                            "word": str(token.text).strip(),
-                            "start_time": start_time,
-                            "end_time": end_time,
-                        }
+                        WordTimestamp(
+                            word=str(token.text).strip(),
+                            start_time=start_time,
+                            end_time=end_time
+                        )
                     )
                     logger.debug(
                         f"Added timestamp for word '{token.text}': {start_time:.3f}s - {end_time:.3f}s"
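
For reference, a minimal sketch of the WordTimestamp schema this hunk switches to, consistent with the keyword arguments used above; the real model lives in ..structures.schemas and may carry extra metadata:

from pydantic import BaseModel, Field

class WordTimestamp(BaseModel):
    """Word-level timing for generated speech, in seconds."""
    word: str = Field(..., description="The word as spoken")
    start_time: float = Field(..., description="Start of the word in the audio")
    end_time: float = Field(..., description="End of the word in the audio")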

@@ -15,6 +15,7 @@ from ..services.text_processing import smart_split
 from ..services.tts_service import TTSService
 from ..services.temp_manager import TempFileWriter
 from ..structures import CaptionedSpeechRequest, CaptionedSpeechResponse, WordTimestamp
+from ..structures.custom_responses import JSONStreamingResponse
 from ..structures.text_schemas import (
     GenerateFromPhonemesRequest,
     PhonemeRequest,

@@ -23,6 +24,7 @@ from ..structures.text_schemas import (
 from .openai_compatible import process_voices, stream_audio_chunks
 import json
 import os
+import base64
 from pathlib import Path

@@ -240,12 +242,10 @@ async def create_captioned_speech(
         async def dual_output():
             try:
                 # Write chunks to temp file and stream
-                async for chunk in generator:
+                async for chunk,chunk_data in generator:
                     if chunk: # Skip empty chunks
                         await temp_writer.write(chunk)
+                        #if return_json:
+                        #    yield chunk, chunk_data
+                        #else:
                         yield chunk

                 # Finalize the temp file

@@ -260,14 +260,29 @@ async def create_captioned_speech(
                 await temp_writer.__aexit__(None, None, None)

             # Stream with temp file writing
-            return StreamingResponse(
-                dual_output(), media_type=content_type, headers=headers
+            return JSONStreamingResponse(
+                dual_output(), media_type="application/json", headers=headers
             )

+        async def single_output():
+            try:
+                # Stream chunks
+                async for chunk,chunk_data in generator:
+                    if chunk: # Skip empty chunks
+                        # Encode the chunk bytes into base 64
+                        base64_chunk= base64.b64encode(chunk).decode("utf-8")
+
+                        yield CaptionedSpeechResponse(audio=base64_chunk,audio_format=content_type,words=chunk_data.word_timestamps)
+            except Exception as e:
+                logger.error(f"Error in single output streaming: {e}")
+                raise
+
-        # NEED TO DO REPLACE THE RETURN WITH A JSON OBJECT CONTAINING BOTH THE FILE AND THE WORD TIMESTAMPS
-
         # Standard streaming without download link
-        return StreamingResponse(
-            generator,
-            media_type=content_type,
+        return JSONStreamingResponse(
+            single_output(),
+            media_type="application/json",
             headers={
                 "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
                 "X-Accel-Buffering": "no",
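
For context, a client of this endpoint now receives newline-delimited JSON instead of raw audio bytes. A minimal consumer sketch, assuming the route is mounted at /dev/captioned_speech on the default port; the URL, payload fields, and voice name are illustrative, and requests is an assumed dependency:

import base64
import json

import requests

response = requests.post(
    "http://localhost:8880/dev/captioned_speech",  # assumed mount path
    json={"input": "Hello world", "voice": "af_heart", "response_format": "mp3"},
    stream=True,
)

audio = bytearray()
for line in response.iter_lines():
    if not line:
        continue  # skip blank lines
    payload = json.loads(line)                        # one JSON object per line
    audio.extend(base64.b64decode(payload["audio"]))  # base64 -> raw audio bytes
    for word in payload["words"]:                     # word-level timestamps
        print(word["word"], word["start_time"], word["end_time"])

with open("speech.mp3", "wb") as f:
    f.write(audio)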

@@ -283,6 +298,7 @@ async def create_captioned_speech(
             speed=request.speed,
             lang_code=request.lang_code,
         )

         content, audio_data = await AudioService.convert_audio(
             audio_data,
             24000,

@@ -11,6 +11,7 @@ import numpy as np

 import aiofiles

+from structures.schemas import CaptionedSpeechRequest
 import torch
 from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
 from fastapi.responses import FileResponse, StreamingResponse

@@ -130,11 +131,17 @@ async def process_voices(


 async def stream_audio_chunks(
-    tts_service: TTSService, request: OpenAISpeechRequest, client_request: Request
+    tts_service: TTSService, request: Union[OpenAISpeechRequest,CaptionedSpeechRequest], client_request: Request
 ) -> AsyncGenerator[Tuple[Union[np.ndarray,bytes],AudioChunk], None]:
     """Stream audio chunks as they're generated with client disconnect handling"""
     voice_name = await process_voices(request.voice, tts_service)

+    unique_properties={
+        "return_timestamps":False
+    }
+    if hasattr(request, "return_timestamps"):
+        unique_properties["return_timestamps"]=request.return_timestamps
+
     try:
         logger.info(f"Starting audio generation with lang_code: {request.lang_code}")
         async for chunk, chunk_data in tts_service.generate_audio_stream(

@@ -144,7 +151,7 @@ async def stream_audio_chunks(
             output_format=request.response_format,
             lang_code=request.lang_code or request.voice[0],
             normalization_options=request.normalization_options,
-            return_timestamps=False,
+            return_timestamps=unique_properties["return_timestamps"],
         ):

             # Check if client is still connected

@@ -192,18 +192,22 @@ class AudioService:
         normalizer = AudioNormalizer()

         audio_chunk.audio=normalizer.normalize(audio_chunk.audio)

+        trimed_samples=0
         # Trim start and end if enough samples
         if len(audio_chunk.audio) > (2 * normalizer.samples_to_trim):
             audio_chunk.audio = audio_chunk.audio[normalizer.samples_to_trim : -normalizer.samples_to_trim]
+            trimed_samples+=normalizer.samples_to_trim

         # Find non silent portion and trim
         start_index,end_index=normalizer.find_first_last_non_silent(audio_chunk.audio,chunk_text,speed,is_last_chunk=is_last_chunk)

         audio_chunk.audio=audio_chunk.audio[start_index:end_index]
+        trimed_samples+=start_index

         if audio_chunk.word_timestamps is not None:
             for timestamp in audio_chunk.word_timestamps:
-                timestamp["start_time"]-=start_index / 24000
-                timestamp["end_time"]-=start_index / 24000
+                timestamp.start_time-=trimed_samples / 24000
+                timestamp.end_time-=trimed_samples / 24000
         return audio_chunk
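
The point of trimed_samples is that word timestamps must shift by the total number of samples removed ahead of the audio, not just by the silence-trim start_index. A worked example with assumed trim values at the service's 24 kHz rate:

SAMPLE_RATE = 24000
samples_to_trim = 1200   # assumed: 50 ms shaved off each edge first
start_index = 2400       # assumed: 100 ms of leading silence removed next

trimed_samples = samples_to_trim + start_index
print(trimed_samples / SAMPLE_RATE)  # 0.15 -> every timestamp moves 150 ms earlier

# The old code subtracted only start_index / 24000 (100 ms here), so
# timestamps ran 50 ms late whenever the fixed edge trim was applied.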

@@ -282,8 +282,8 @@ class TTSService:
             ):
                 if chunk_data.word_timestamps is not None:
                     for timestamp in chunk_data.word_timestamps:
-                        timestamp["start_time"]+=current_offset
-                        timestamp["end_time"]+=current_offset
+                        timestamp.start_time+=current_offset
+                        timestamp.end_time+=current_offset

                 current_offset+=len(chunk_data.audio) / 24000
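
Each chunk's timestamps arrive chunk-relative, so the service promotes them to stream-absolute times by adding the running duration of all previously emitted chunks. A sketch of that bookkeeping with invented chunk sizes:

SAMPLE_RATE = 24000
chunk_lengths = [48000, 24000]  # two chunks: 2.0 s and 1.0 s of audio

current_offset = 0.0
for num_samples in chunk_lengths:
    # ...word timestamps in this chunk get +current_offset here...
    current_offset += num_samples / SAMPLE_RATE

print(current_offset)  # 3.0; a word at 0.2 s into the second chunk lands at 2.2 s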

api/src/structures/custom_responses.py (new file, 51 lines)

@@ -0,0 +1,51 @@
+from collections.abc import AsyncIterable, Iterable
+
+import json
+import typing
+
+from pydantic import BaseModel
+from starlette.background import BackgroundTask
+from starlette.concurrency import iterate_in_threadpool
+from starlette.responses import JSONResponse, StreamingResponse
+
+
+class JSONStreamingResponse(StreamingResponse, JSONResponse):
+    """StreamingResponse that also renders with JSON."""
+
+    def __init__(
+        self,
+        content: Iterable | AsyncIterable,
+        status_code: int = 200,
+        headers: dict[str, str] | None = None,
+        media_type: str | None = None,
+        background: BackgroundTask | None = None,
+    ) -> None:
+        if isinstance(content, AsyncIterable):
+            self._content_iterable: AsyncIterable = content
+        else:
+            self._content_iterable = iterate_in_threadpool(content)
+
+        async def body_iterator() -> AsyncIterable[bytes]:
+            async for content_ in self._content_iterable:
+                if isinstance(content_, BaseModel):
+                    content_ = content_.model_dump()
+                yield self.render(content_)
+
+        self.body_iterator = body_iterator()
+        self.status_code = status_code
+        if media_type is not None:
+            self.media_type = media_type
+        self.background = background
+        self.init_headers(headers)
+
+    def render(self, content: typing.Any) -> bytes:
+        return (json.dumps(
+            content,
+            ensure_ascii=False,
+            allow_nan=False,
+            indent=None,
+            separators=(",", ":"),
+        ) + "\n").encode("utf-8")
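
A usage sketch for this response class, outside the commit itself: the Event model and /events route are hypothetical, and the import path assumes the package layout above. Pydantic models yielded by the generator are dumped to dicts by body_iterator and rendered as one compact JSON document per line:

import asyncio

from fastapi import FastAPI
from pydantic import BaseModel

from api.src.structures.custom_responses import JSONStreamingResponse

app = FastAPI()

class Event(BaseModel):
    index: int
    message: str

@app.get("/events")
async def events():
    async def generate():
        for i in range(3):
            await asyncio.sleep(0.1)  # simulate work between chunks
            yield Event(index=i, message="tick")  # becomes one JSON line
    return JSONStreamingResponse(generate(), media_type="application/json")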

@@ -33,7 +33,8 @@ class WordTimestamp(BaseModel):
 class CaptionedSpeechResponse(BaseModel):
     """Response schema for captioned speech endpoint"""

-    audio: bytes = Field(..., description="The generated audio data")
+    audio: str = Field(..., description="The generated audio data encoded in base 64")
+    audio_format: str = Field(..., description="The format of the output audio")
     words: List[WordTimestamp] = Field(..., description="Word-level timestamps")


 class NormalizationOptions(BaseModel):
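
Combined with JSONStreamingResponse, each line of the captioned stream is one compact rendering of this model. An illustrative line (values invented) looks like:

{"audio":"UklGRiQ...","audio_format":"mp3","words":[{"word":"Hello","start_time":0.0,"end_time":0.32}]}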