From cf72e4ed2b6dd94066cde84a321a7543801b6b75 Mon Sep 17 00:00:00 2001 From: remsky Date: Mon, 13 Jan 2025 23:25:06 -0700 Subject: [PATCH] Add interruptible streams --- api/src/routers/openai_compatible.py | 35 +++++++++++++++++++--------- 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/api/src/routers/openai_compatible.py b/api/src/routers/openai_compatible.py index 96dc174..57d1257 100644 --- a/api/src/routers/openai_compatible.py +++ b/api/src/routers/openai_compatible.py @@ -1,6 +1,6 @@ from typing import AsyncGenerator, List, Union -from fastapi import APIRouter, Depends, Header, HTTPException, Response +from fastapi import APIRouter, Depends, Header, HTTPException, Response, Request from fastapi.responses import StreamingResponse from loguru import logger @@ -49,22 +49,35 @@ async def process_voices( async def stream_audio_chunks( - tts_service: TTSService, request: OpenAISpeechRequest + tts_service: TTSService, + request: OpenAISpeechRequest, + client_request: Request ) -> AsyncGenerator[bytes, None]: - """Stream audio chunks as they're generated""" + """Stream audio chunks as they're generated with client disconnect handling""" voice_to_use = await process_voices(request.voice, tts_service) - async for chunk in tts_service.generate_audio_stream( - text=request.input, - voice=voice_to_use, - speed=request.speed, - output_format=request.response_format, - ): - yield chunk + + try: + async for chunk in tts_service.generate_audio_stream( + text=request.input, + voice=voice_to_use, + speed=request.speed, + output_format=request.response_format, + ): + # Check if client is still connected + if await client_request.is_disconnected(): + logger.info("Client disconnected, stopping audio generation") + break + yield chunk + except Exception as e: + logger.error(f"Error in audio streaming: {str(e)}") + # Let the exception propagate to trigger cleanup + raise @router.post("/audio/speech") async def create_speech( request: OpenAISpeechRequest, + client_request: Request, tts_service: TTSService = Depends(get_tts_service), x_raw_response: str = Header(None, alias="x-raw-response"), ): @@ -87,7 +100,7 @@ async def create_speech( if request.stream: # Stream audio chunks as they're generated return StreamingResponse( - stream_audio_chunks(tts_service, request), + stream_audio_chunks(tts_service, request, client_request), media_type=content_type, headers={ "Content-Disposition": f"attachment; filename=speech.{request.response_format}",