mirror of https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00

streaming word level time stamps

This commit is contained in:
parent 4027768920, commit 0b5ec320c7

7 changed files with 101 additions and 22 deletions
@@ -13,7 +13,7 @@ from ..core.config import settings
 from ..core.model_config import model_config
 from .base import BaseModelBackend
 from .base import AudioChunk
+from ..structures.schemas import WordTimestamp
 
 class KokoroV1(BaseModelBackend):
     """Kokoro backend with controlled resource management."""
 
@@ -281,11 +281,11 @@ class KokoroV1(BaseModelBackend):
                     start_time = float(token.start_ts) + current_offset
                     end_time = float(token.end_ts) + current_offset
                     word_timestamps.append(
-                        {
-                            "word": str(token.text).strip(),
-                            "start_time": start_time,
-                            "end_time": end_time,
-                        }
+                        WordTimestamp(
+                            word=str(token.text).strip(),
+                            start_time=start_time,
+                            end_time=end_time
+                        )
                     )
                     logger.debug(
                         f"Added timestamp for word '{token.text}': {start_time:.3f}s - {end_time:.3f}s"
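The hunk above swaps plain dicts for the WordTimestamp Pydantic model when collecting per-word timings. Based only on the field names used here (and the WordTimestamp hunk in schemas.py further down), the schema is presumably shaped roughly like this sketch; the float types are an assumption, not taken from the diff:

    from pydantic import BaseModel

    # Rough sketch of the schema implied by the keyword arguments above;
    # field types are assumed.
    class WordTimestamp(BaseModel):
        word: str
        start_time: float
        end_time: float
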
@@ -15,6 +15,7 @@ from ..services.text_processing import smart_split
 from ..services.tts_service import TTSService
 from ..services.temp_manager import TempFileWriter
 from ..structures import CaptionedSpeechRequest, CaptionedSpeechResponse, WordTimestamp
+from ..structures.custom_responses import JSONStreamingResponse
 from ..structures.text_schemas import (
     GenerateFromPhonemesRequest,
     PhonemeRequest,
@@ -23,6 +24,7 @@ from ..structures.text_schemas import (
 from .openai_compatible import process_voices, stream_audio_chunks
 import json
 import os
+import base64
 from pathlib import Path
 
 
@@ -240,12 +242,10 @@ async def create_captioned_speech(
             async def dual_output():
                 try:
                     # Write chunks to temp file and stream
-                    async for chunk in generator:
+                    async for chunk, chunk_data in generator:
                         if chunk:  # Skip empty chunks
                             await temp_writer.write(chunk)
-                            #if return_json:
-                            #    yield chunk, chunk_data
-                            #else:
                             yield chunk
 
                 # Finalize the temp file
@@ -260,14 +260,29 @@ async def create_captioned_speech(
                 await temp_writer.__aexit__(None, None, None)
 
             # Stream with temp file writing
-            return StreamingResponse(
-                dual_output(), media_type=content_type, headers=headers
+            return JSONStreamingResponse(
+                dual_output(), media_type="application/json", headers=headers
             )
 
+        async def single_output():
+            try:
+                # Stream chunks
+                async for chunk, chunk_data in generator:
+                    if chunk:  # Skip empty chunks
+                        # Encode the chunk bytes into base 64
+                        base64_chunk = base64.b64encode(chunk).decode("utf-8")
+
+                        yield CaptionedSpeechResponse(audio=base64_chunk, audio_format=content_type, words=chunk_data.word_timestamps)
+            except Exception as e:
+                logger.error(f"Error in single output streaming: {e}")
+                raise
+
+        # NEED TO DO REPLACE THE RETURN WITH A JSON OBJECT CONTAINING BOTH THE FILE AND THE WORD TIMESTAMPS
+
         # Standard streaming without download link
-        return StreamingResponse(
-            generator,
-            media_type=content_type,
+        return JSONStreamingResponse(
+            single_output(),
+            media_type="application/json",
             headers={
                 "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
                 "X-Accel-Buffering": "no",
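Since JSONStreamingResponse (the new api/src/structures/custom_responses.py below) renders every yielded CaptionedSpeechResponse as one compact JSON object followed by a newline, a client can read the captioned-speech stream line by line and base64-decode each audio chunk. A rough consumption sketch, assuming the requests library and an illustrative endpoint URL and payload (none of which come from this diff):

    import base64
    import json

    import requests  # assumed client-side dependency, not part of the commit

    # URL and payload are illustrative only
    resp = requests.post(
        "http://localhost:8880/dev/captioned_speech",
        json={"input": "Hello world", "voice": "af_heart", "response_format": "mp3"},
        stream=True,
    )
    for line in resp.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)                         # one CaptionedSpeechResponse per line
        audio_bytes = base64.b64decode(chunk["audio"])   # base64-encoded audio for this chunk
        word_timestamps = chunk["words"]                 # word-level timestamps for this chunk
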
@@ -283,6 +298,7 @@ async def create_captioned_speech(
             speed=request.speed,
             lang_code=request.lang_code,
         )
 
         content, audio_data = await AudioService.convert_audio(
             audio_data,
             24000,
@@ -11,6 +11,7 @@ import numpy as np
 
 import aiofiles
 
+from structures.schemas import CaptionedSpeechRequest
 import torch
 from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
 from fastapi.responses import FileResponse, StreamingResponse
@@ -130,11 +131,17 @@ async def process_voices(
 
 
 async def stream_audio_chunks(
-    tts_service: TTSService, request: OpenAISpeechRequest, client_request: Request
+    tts_service: TTSService, request: Union[OpenAISpeechRequest, CaptionedSpeechRequest], client_request: Request
 ) -> AsyncGenerator[Tuple[Union[np.ndarray, bytes], AudioChunk], None]:
     """Stream audio chunks as they're generated with client disconnect handling"""
     voice_name = await process_voices(request.voice, tts_service)
 
+    unique_properties = {
+        "return_timestamps": False
+    }
+    if hasattr(request, "return_timestamps"):
+        unique_properties["return_timestamps"] = request.return_timestamps
+
     try:
         logger.info(f"Starting audio generation with lang_code: {request.lang_code}")
         async for chunk, chunk_data in tts_service.generate_audio_stream(
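The unique_properties lookup above exists because stream_audio_chunks now accepts both request types, but only CaptionedSpeechRequest defines return_timestamps. A more compact equivalent (an alternative sketch, not what the commit uses) would be:

    # Falls back to False for OpenAISpeechRequest, which has no return_timestamps field
    return_timestamps = getattr(request, "return_timestamps", False)
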
@@ -144,7 +151,7 @@ async def stream_audio_chunks(
             output_format=request.response_format,
             lang_code=request.lang_code or request.voice[0],
             normalization_options=request.normalization_options,
-            return_timestamps=False,
+            return_timestamps=unique_properties["return_timestamps"],
         ):
 
             # Check if client is still connected
@@ -192,18 +192,22 @@ class AudioService:
         normalizer = AudioNormalizer()
 
         audio_chunk.audio = normalizer.normalize(audio_chunk.audio)
 
+        trimed_samples = 0
         # Trim start and end if enough samples
         if len(audio_chunk.audio) > (2 * normalizer.samples_to_trim):
             audio_chunk.audio = audio_chunk.audio[normalizer.samples_to_trim : -normalizer.samples_to_trim]
+            trimed_samples += normalizer.samples_to_trim
 
         # Find non silent portion and trim
         start_index, end_index = normalizer.find_first_last_non_silent(audio_chunk.audio, chunk_text, speed, is_last_chunk=is_last_chunk)
 
         audio_chunk.audio = audio_chunk.audio[start_index:end_index]
+        trimed_samples += start_index
 
         if audio_chunk.word_timestamps is not None:
             for timestamp in audio_chunk.word_timestamps:
-                timestamp["start_time"] -= start_index / 24000
-                timestamp["end_time"] -= start_index / 24000
+                timestamp.start_time -= trimed_samples / 24000
+                timestamp.end_time -= trimed_samples / 24000
         return audio_chunk
 
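The timestamp correction above tracks how many samples were cut from the front of the chunk and converts that count to seconds at the pipeline's 24 kHz rate before shifting each word earlier. A tiny worked example with made-up numbers:

    samples_to_trim = 12000                      # assumed silence trim at the chunk start
    start_index = 800                            # assumed first non-silent sample index
    trimed_samples = samples_to_trim + start_index
    shift = trimed_samples / 24000               # 12800 / 24000 ≈ 0.533 s subtracted per word
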
@@ -282,8 +282,8 @@ class TTSService:
             ):
                 if chunk_data.word_timestamps is not None:
                     for timestamp in chunk_data.word_timestamps:
-                        timestamp["start_time"] += current_offset
-                        timestamp["end_time"] += current_offset
+                        timestamp.start_time += current_offset
+                        timestamp.end_time += current_offset
 
                 current_offset += len(chunk_data.audio) / 24000
 
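In the service-level hunk above, current_offset accumulates the duration of every chunk already emitted, so word timestamps of later chunks end up on the whole-utterance timeline. A small sketch of that bookkeeping with hypothetical chunk sizes:

    current_offset = 0.0
    for chunk_samples in (48000, 36000, 60000):  # hypothetical chunk lengths in samples
        # word timestamps inside this chunk get current_offset added at this point
        current_offset += chunk_samples / 24000  # 2.0 s, then 3.5 s, then 6.0 s total
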
api/src/structures/custom_responses.py (new file, 51 lines)

@@ -0,0 +1,51 @@
+from collections.abc import AsyncIterable, Iterable
+
+import json
+import typing
+from pydantic import BaseModel
+from starlette.background import BackgroundTask
+from starlette.concurrency import iterate_in_threadpool
+from starlette.responses import JSONResponse, StreamingResponse
+
+
+class JSONStreamingResponse(StreamingResponse, JSONResponse):
+    """StreamingResponse that also render with JSON."""
+
+    def __init__(
+        self,
+        content: Iterable | AsyncIterable,
+        status_code: int = 200,
+        headers: dict[str, str] | None = None,
+        media_type: str | None = None,
+        background: BackgroundTask | None = None,
+    ) -> None:
+        if isinstance(content, AsyncIterable):
+            self._content_iterable: AsyncIterable = content
+        else:
+            self._content_iterable = iterate_in_threadpool(content)
+
+
+
+        async def body_iterator() -> AsyncIterable[bytes]:
+            async for content_ in self._content_iterable:
+                if isinstance(content_, BaseModel):
+                    content_ = content_.model_dump()
+                yield self.render(content_)
+
+
+
+        self.body_iterator = body_iterator()
+        self.status_code = status_code
+        if media_type is not None:
+            self.media_type = media_type
+        self.background = background
+        self.init_headers(headers)
+
+    def render(self, content: typing.Any) -> bytes:
+        return (json.dumps(
+            content,
+            ensure_ascii=False,
+            allow_nan=False,
+            indent=None,
+            separators=(",", ":"),
+        ) + "\n").encode("utf-8")
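A minimal usage sketch for the new response class: any iterable or async iterable of Pydantic models (or plain JSON-serializable values) becomes a newline-delimited JSON stream. The toy route, model, and import path below are illustrative, not part of the commit:

    from fastapi import FastAPI
    from pydantic import BaseModel

    from api.src.structures.custom_responses import JSONStreamingResponse  # import path assumed

    app = FastAPI()

    class Item(BaseModel):
        n: int

    async def items():
        for n in range(3):
            yield Item(n=n)  # BaseModel instances are model_dump()-ed before rendering

    @app.get("/items")
    async def stream_items():
        # The client receives: {"n":0}\n{"n":1}\n{"n":2}\n
        return JSONStreamingResponse(items(), media_type="application/json")
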
@@ -33,7 +33,8 @@ class WordTimestamp(BaseModel):
 class CaptionedSpeechResponse(BaseModel):
     """Response schema for captioned speech endpoint"""
 
-    audio: bytes = Field(..., description="The generated audio data")
+    audio: str = Field(..., description="The generated audio data encoded in base 64")
+    audio_format: str = Field(..., description="The format of the output audio")
     words: List[WordTimestamp] = Field(..., description="Word-level timestamps")
 
 class NormalizationOptions(BaseModel):
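With the schema change above, each line of the captioned-speech stream would serialize to roughly the following JSON; the values and the truncated base64 string are illustrative only:

    {"audio":"UklGRiQAAABXQVZF...","audio_format":"mp3","words":[{"word":"Hello","start_time":0.0,"end_time":0.38},{"word":"world","start_time":0.38,"end_time":0.74}]}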