mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
Add interruptible streams
This commit is contained in:
parent
36f85638ac
commit
cf72e4ed2b
1 changed files with 24 additions and 11 deletions
|
@ -1,6 +1,6 @@
|
|||
from typing import AsyncGenerator, List, Union
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Response
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Response, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from loguru import logger
|
||||
|
||||
|
@ -49,22 +49,35 @@ async def process_voices(
|
|||
|
||||
|
||||
async def stream_audio_chunks(
|
||||
tts_service: TTSService, request: OpenAISpeechRequest
|
||||
tts_service: TTSService,
|
||||
request: OpenAISpeechRequest,
|
||||
client_request: Request
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
"""Stream audio chunks as they're generated"""
|
||||
"""Stream audio chunks as they're generated with client disconnect handling"""
|
||||
voice_to_use = await process_voices(request.voice, tts_service)
|
||||
async for chunk in tts_service.generate_audio_stream(
|
||||
text=request.input,
|
||||
voice=voice_to_use,
|
||||
speed=request.speed,
|
||||
output_format=request.response_format,
|
||||
):
|
||||
yield chunk
|
||||
|
||||
try:
|
||||
async for chunk in tts_service.generate_audio_stream(
|
||||
text=request.input,
|
||||
voice=voice_to_use,
|
||||
speed=request.speed,
|
||||
output_format=request.response_format,
|
||||
):
|
||||
# Check if client is still connected
|
||||
if await client_request.is_disconnected():
|
||||
logger.info("Client disconnected, stopping audio generation")
|
||||
break
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
logger.error(f"Error in audio streaming: {str(e)}")
|
||||
# Let the exception propagate to trigger cleanup
|
||||
raise
|
||||
|
||||
|
||||
@router.post("/audio/speech")
|
||||
async def create_speech(
|
||||
request: OpenAISpeechRequest,
|
||||
client_request: Request,
|
||||
tts_service: TTSService = Depends(get_tts_service),
|
||||
x_raw_response: str = Header(None, alias="x-raw-response"),
|
||||
):
|
||||
|
@ -87,7 +100,7 @@ async def create_speech(
|
|||
if request.stream:
|
||||
# Stream audio chunks as they're generated
|
||||
return StreamingResponse(
|
||||
stream_audio_chunks(tts_service, request),
|
||||
stream_audio_chunks(tts_service, request, client_request),
|
||||
media_type=content_type,
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
|
||||
|
|
Loading…
Add table
Reference in a new issue