Add interruptible streams

This commit is contained in:
remsky 2025-01-13 23:25:06 -07:00
parent 36f85638ac
commit cf72e4ed2b

View file

@ -1,6 +1,6 @@
from typing import AsyncGenerator, List, Union from typing import AsyncGenerator, List, Union
from fastapi import APIRouter, Depends, Header, HTTPException, Response from fastapi import APIRouter, Depends, Header, HTTPException, Response, Request
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from loguru import logger from loguru import logger
@ -49,22 +49,35 @@ async def process_voices(
async def stream_audio_chunks(
    tts_service: TTSService,
    request: OpenAISpeechRequest,
    client_request: Request,
) -> AsyncGenerator[bytes, None]:
    """Stream audio chunks as they're generated with client disconnect handling.

    Args:
        tts_service: Service whose ``generate_audio_stream`` produces the audio.
        request: Speech request supplying the input text, voice, speed and
            response format.
        client_request: Incoming HTTP request, polled before each chunk is
            yielded to detect a dropped client.

    Yields:
        Audio chunks in the order the TTS service produces them.

    Raises:
        Exception: Any error raised while streaming is logged and re-raised.
    """
    voice_to_use = await process_voices(request.voice, tts_service)
    audio_stream = None
    try:
        audio_stream = tts_service.generate_audio_stream(
            text=request.input,
            voice=voice_to_use,
            speed=request.speed,
            output_format=request.response_format,
        )
        async for chunk in audio_stream:
            # Check if client is still connected
            if await client_request.is_disconnected():
                logger.info("Client disconnected, stopping audio generation")
                break
            yield chunk
    except Exception as e:
        logger.error(f"Error in audio streaming: {e}")
        # Let the exception propagate to trigger cleanup
        raise
    finally:
        # Leaving the loop early (disconnect, error, or this generator being
        # closed) does not by itself close the upstream iterator; close it
        # explicitly so generation stops now rather than at garbage collection.
        # getattr guards against an iterator that exposes no aclose().
        close_stream = getattr(audio_stream, "aclose", None)
        if close_stream is not None:
            await close_stream()
@router.post("/audio/speech") @router.post("/audio/speech")
async def create_speech( async def create_speech(
request: OpenAISpeechRequest, request: OpenAISpeechRequest,
client_request: Request,
tts_service: TTSService = Depends(get_tts_service), tts_service: TTSService = Depends(get_tts_service),
x_raw_response: str = Header(None, alias="x-raw-response"), x_raw_response: str = Header(None, alias="x-raw-response"),
): ):
@ -87,7 +100,7 @@ async def create_speech(
if request.stream: if request.stream:
# Stream audio chunks as they're generated # Stream audio chunks as they're generated
return StreamingResponse( return StreamingResponse(
stream_audio_chunks(tts_service, request), stream_audio_chunks(tts_service, request, client_request),
media_type=content_type, media_type=content_type,
headers={ headers={
"Content-Disposition": f"attachment; filename=speech.{request.response_format}", "Content-Disposition": f"attachment; filename=speech.{request.response_format}",