Add interruptible streams

This commit is contained in:
remsky 2025-01-13 23:25:06 -07:00
parent 36f85638ac
commit cf72e4ed2b

View file

@ -1,6 +1,6 @@
from typing import AsyncGenerator, List, Union
from fastapi import APIRouter, Depends, Header, HTTPException, Response
from fastapi import APIRouter, Depends, Header, HTTPException, Response, Request
from fastapi.responses import StreamingResponse
from loguru import logger
@ -49,22 +49,35 @@ async def process_voices(
async def stream_audio_chunks(
tts_service: TTSService, request: OpenAISpeechRequest
tts_service: TTSService,
request: OpenAISpeechRequest,
client_request: Request
) -> AsyncGenerator[bytes, None]:
"""Stream audio chunks as they're generated"""
"""Stream audio chunks as they're generated with client disconnect handling"""
voice_to_use = await process_voices(request.voice, tts_service)
async for chunk in tts_service.generate_audio_stream(
text=request.input,
voice=voice_to_use,
speed=request.speed,
output_format=request.response_format,
):
yield chunk
try:
async for chunk in tts_service.generate_audio_stream(
text=request.input,
voice=voice_to_use,
speed=request.speed,
output_format=request.response_format,
):
# Check if client is still connected
if await client_request.is_disconnected():
logger.info("Client disconnected, stopping audio generation")
break
yield chunk
except Exception as e:
logger.error(f"Error in audio streaming: {str(e)}")
# Let the exception propagate to trigger cleanup
raise
@router.post("/audio/speech")
async def create_speech(
request: OpenAISpeechRequest,
client_request: Request,
tts_service: TTSService = Depends(get_tts_service),
x_raw_response: str = Header(None, alias="x-raw-response"),
):
@ -87,7 +100,7 @@ async def create_speech(
if request.stream:
# Stream audio chunks as they're generated
return StreamingResponse(
stream_audio_chunks(tts_service, request),
stream_audio_chunks(tts_service, request, client_request),
media_type=content_type,
headers={
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",