Mirror of https://github.com/remsky/Kokoro-FastAPI.git (synced 2025-08-05 16:48:53 +00:00)
Started work on streaming word-level timestamps, and on transitioning the dev code to reuse more of the OpenAI-compatible endpoint.
parent 7772dbc2e4
commit 4027768920
4 changed files with 184 additions and 11 deletions
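For orientation, here is a sketch of how a client might exercise the captioned-speech endpoint this commit is building toward: stream the audio body while reading the download path from the response headers. The route path, port, and payload values are assumptions inferred from the request schema and headers in the diff below, not anything this commit documents.

```python
# Hypothetical client for the captioned-speech endpoint touched in this
# commit. URL, port, and payload fields are assumptions based on the diff.
import requests

payload = {
    "model": "kokoro",             # assumed model name
    "input": "Hello there, world!",
    "voice": "af_bella",           # assumed voice id
    "response_format": "mp3",
    "stream": True,                # default per CaptionedSpeechRequest
    "return_timestamps": True,     # default per CaptionedSpeechRequest
    "return_download_link": True,  # ask for the X-Download-Path header
}

with requests.post(
    "http://localhost:8880/dev/captioned_speech",  # assumed route
    json=payload,
    stream=True,
) as resp:
    resp.raise_for_status()
    # Headers arrive before the body, so the download path is available
    # as soon as streaming starts.
    print("download path:", resp.headers.get("X-Download-Path"))
    with open("speech.mp3", "wb") as f:
        for chunk in resp.iter_content(chunk_size=8192):
            if chunk:
                f.write(chunk)
```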
@@ -7,6 +7,7 @@ from fastapi.responses import StreamingResponse, FileResponse
 from kokoro import KPipeline
 from loguru import logger
 
+from ..inference.base import AudioChunk
 from ..core.config import settings
 from ..services.audio import AudioNormalizer, AudioService
 from ..services.streaming_audio_writer import StreamingAudioWriter
@@ -19,6 +20,7 @@ from ..structures.text_schemas import (
     PhonemeRequest,
     PhonemeResponse,
 )
+from .openai_compatible import process_voices, stream_audio_chunks
 import json
 import os
 from pathlib import Path
@@ -194,6 +196,154 @@ async def create_captioned_speech(
     tts_service: TTSService = Depends(get_tts_service),
 ):
     """Generate audio with word-level timestamps using streaming approach"""
+
+    try:
+        # model_name = get_model_name(request.model)
+        tts_service = await get_tts_service()
+        voice_name = await process_voices(request.voice, tts_service)
+
+        # Set content type based on format
+        content_type = {
+            "mp3": "audio/mpeg",
+            "opus": "audio/opus",
+            "aac": "audio/aac",
+            "flac": "audio/flac",
+            "wav": "audio/wav",
+            "pcm": "audio/pcm",
+        }.get(request.response_format, f"audio/{request.response_format}")
+
+        # Check if streaming is requested (default for OpenAI client)
+        if request.stream:
+            # Create generator but don't start it yet
+            generator = stream_audio_chunks(tts_service, request, client_request)
+
+            # If download link requested, wrap generator with temp file writer
+            if request.return_download_link:
+                from ..services.temp_manager import TempFileWriter
+
+                temp_writer = TempFileWriter(request.response_format)
+                await temp_writer.__aenter__()  # Initialize temp file
+
+                # Get download path immediately after temp file creation
+                download_path = temp_writer.download_path
+
+                # Create response headers with download path
+                headers = {
+                    "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
+                    "X-Accel-Buffering": "no",
+                    "Cache-Control": "no-cache",
+                    "Transfer-Encoding": "chunked",
+                    "X-Download-Path": download_path,
+                }
+
+                # Create async generator for streaming
+                async def dual_output():
+                    try:
+                        # Write chunks to temp file and stream
+                        async for chunk in generator:
+                            if chunk:  # Skip empty chunks
+                                await temp_writer.write(chunk)
+                                #if return_json:
+                                #    yield chunk, chunk_data
+                                #else:
+                                yield chunk
+
+                        # Finalize the temp file
+                        await temp_writer.finalize()
+                    except Exception as e:
+                        logger.error(f"Error in dual output streaming: {e}")
+                        await temp_writer.__aexit__(type(e), e, e.__traceback__)
+                        raise
+                    finally:
+                        # Ensure temp writer is closed
+                        if not temp_writer._finalized:
+                            await temp_writer.__aexit__(None, None, None)
+
+                # Stream with temp file writing
+                return StreamingResponse(
+                    dual_output(), media_type=content_type, headers=headers
+                )
+
+            # Standard streaming without download link
+            return StreamingResponse(
+                generator,
+                media_type=content_type,
+                headers={
+                    "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
+                    "X-Accel-Buffering": "no",
+                    "Cache-Control": "no-cache",
+                    "Transfer-Encoding": "chunked",
+                },
+            )
+        else:
+            # Generate complete audio using public interface
+            _, audio_data = await tts_service.generate_audio(
+                text=request.input,
+                voice=voice_name,
+                speed=request.speed,
+                lang_code=request.lang_code,
+            )
+            content, audio_data = await AudioService.convert_audio(
+                audio_data,
+                24000,
+                request.response_format,
+                is_first_chunk=True,
+                is_last_chunk=False,
+            )
+
+            # Convert to requested format with proper finalization
+            final, _ = await AudioService.convert_audio(
+                AudioChunk(np.array([], dtype=np.int16)),
+                24000,
+                request.response_format,
+                is_first_chunk=False,
+                is_last_chunk=True,
+            )
+            output = content + final
+            return Response(
+                content=output,
+                media_type=content_type,
+                headers={
+                    "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
+                    "Cache-Control": "no-cache",  # Prevent caching
+                },
+            )
+
+    except ValueError as e:
+        # Handle validation errors
+        logger.warning(f"Invalid request: {str(e)}")
+        raise HTTPException(
+            status_code=400,
+            detail={
+                "error": "validation_error",
+                "message": str(e),
+                "type": "invalid_request_error",
+            },
+        )
+    except RuntimeError as e:
+        # Handle runtime/processing errors
+        logger.error(f"Processing error: {str(e)}")
+        raise HTTPException(
+            status_code=500,
+            detail={
+                "error": "processing_error",
+                "message": str(e),
+                "type": "server_error",
+            },
+        )
+    except Exception as e:
+        # Handle unexpected errors
+        logger.error(f"Unexpected error in speech generation: {str(e)}")
+        raise HTTPException(
+            status_code=500,
+            detail={
+                "error": "processing_error",
+                "message": str(e),
+                "type": "server_error",
+            },
+        )
+
+    """
     try:
         # Set content type based on format
         content_type = {
@@ -344,3 +494,4 @@ async def create_captioned_speech(
                 "type": "server_error",
             },
         )
+    """
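The heart of the streaming branch above is the `dual_output` generator: it tees every encoded chunk into a temp file (so a download link can be served afterwards) while still yielding the chunk to the client. A standalone sketch of that pattern, with a plain file standing in for the repo's `TempFileWriter`:

```python
# Minimal sketch of the dual-output pattern: stream chunks to the caller
# while mirroring them to a file. Names here are illustrative stand-ins.
import asyncio
from typing import AsyncGenerator


async def fake_audio_chunks() -> AsyncGenerator[bytes, None]:
    """Stand-in for stream_audio_chunks: yields encoded audio chunks."""
    for i in range(3):
        await asyncio.sleep(0)  # pretend each chunk takes time to generate
        yield f"chunk-{i}\n".encode()


async def dual_output(
    source: AsyncGenerator[bytes, None], path: str
) -> AsyncGenerator[bytes, None]:
    """Yield every chunk to the caller while appending it to `path`."""
    with open(path, "wb") as f:
        async for chunk in source:
            if chunk:  # skip empty chunks, as the endpoint does
                f.write(chunk)
                yield chunk


async def main() -> None:
    async for chunk in dual_output(fake_audio_chunks(), "speech.bin"):
        print("streamed:", chunk)


asyncio.run(main())
```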
@@ -10,12 +10,13 @@ from urllib import response
 import numpy as np
 
 import aiofiles
+from ..inference.base import AudioChunk
+
 import torch
 from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
 from fastapi.responses import FileResponse, StreamingResponse
 from loguru import logger
 
-from ..inference.base import AudioChunk
 from ..core.config import settings
 from ..services.audio import AudioService
 from ..services.tts_service import TTSService
@@ -130,7 +131,7 @@ async def process_voices(
 
 async def stream_audio_chunks(
     tts_service: TTSService, request: OpenAISpeechRequest, client_request: Request
-) -> AsyncGenerator[list, None]:
+) -> AsyncGenerator[Tuple[Union[np.ndarray,bytes],AudioChunk], None]:
     """Stream audio chunks as they're generated with client disconnect handling"""
     voice_name = await process_voices(request.voice, tts_service)
 
@@ -154,7 +155,7 @@ async def stream_audio_chunks(
                 logger.info("Client disconnected, stopping audio generation")
                 break
 
-            yield chunk
+            yield (chunk,chunk_data)
     except Exception as e:
         logger.error(f"Error in audio streaming: {str(e)}")
         # Let the exception propagate to trigger cleanup
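The two hunks above change the shared generator's contract: instead of bare chunks, `stream_audio_chunks` now yields `(chunk, chunk_data)` pairs, where `chunk` is the encoded audio (bytes, or a raw ndarray when no output format is set) and `chunk_data` is the `AudioChunk` carrying timing metadata. A minimal sketch of the new contract; the `AudioChunk` stand-in here is assumed for illustration, not the class from `..inference.base`:

```python
import asyncio
from dataclasses import dataclass, field
from typing import AsyncGenerator, Tuple, Union

import numpy as np


@dataclass
class AudioChunk:  # stand-in for ..inference.base.AudioChunk (assumed shape)
    audio: np.ndarray
    word_timestamps: list = field(default_factory=list)


async def stream_audio_chunks() -> AsyncGenerator[
    Tuple[Union[np.ndarray, bytes], AudioChunk], None
]:
    """Yields (encoded_chunk, metadata) pairs, matching the new signature."""
    for i in range(2):
        meta = AudioChunk(
            audio=np.zeros(100, dtype=np.int16),
            word_timestamps=[{"word": f"w{i}", "start": i * 0.5, "end": i * 0.5 + 0.4}],
        )
        yield b"encoded-bytes", meta


async def main() -> None:
    async for chunk, chunk_data in stream_audio_chunks():
        # Callers that only need audio can ignore chunk_data entirely.
        print(len(chunk), chunk_data.word_timestamps)


asyncio.run(main())
```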
@@ -223,7 +224,7 @@ async def create_speech(
                 async def dual_output():
                     try:
                         # Write chunks to temp file and stream
-                        async for chunk in generator:
+                        async for chunk,chunk_data in generator:
                             if chunk:  # Skip empty chunks
                                 await temp_writer.write(chunk)
                                 #if return_json:
@@ -247,9 +248,19 @@ async def create_speech(
                     dual_output(), media_type=content_type, headers=headers
                 )
 
+            async def single_output():
+                try:
+                    # Stream chunks
+                    async for chunk,chunk_data in generator:
+                        if chunk:  # Skip empty chunks
+                            yield chunk
+                except Exception as e:
+                    logger.error(f"Error in single output streaming: {e}")
+                    raise
+
             # Standard streaming without download link
             return StreamingResponse(
-                generator,
+                single_output(),
                 media_type=content_type,
                 headers={
                     "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
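`single_output` is the counterpart adapter for the plain-streaming path: `StreamingResponse` iterates its body and sends each item as a raw chunk, so once the generator yields tuples the response needs a wrapper that unpacks them and forwards only the audio bytes. The same idea in isolation (a sketch, assuming a tuple-yielding source like the new `stream_audio_chunks`):

```python
from typing import Any, AsyncGenerator, Tuple


async def single_output(
    generator: AsyncGenerator[Tuple[bytes, Any], None],
) -> AsyncGenerator[bytes, None]:
    """Drop the timestamp metadata and re-expose a bytes-only stream."""
    async for chunk, _chunk_data in generator:
        if chunk:  # skip empty chunks, as the endpoint does
            yield chunk
```

Without this wrapper, `StreamingResponse` would attempt to send each `(chunk, chunk_data)` tuple as a body chunk rather than the audio bytes alone.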
@@ -332,24 +332,23 @@ class TTSService:
         text: str,
         voice: str,
         speed: float = 1.0,
-        return_timestamps: bool = True,
+        return_timestamps: bool = False,
         lang_code: Optional[str] = None,
     ) -> Tuple[Tuple[np.ndarray,AudioChunk]]:
         """Generate complete audio for text using streaming internally."""
         start_time = time.time()
-        audio_chunks = []
         audio_data_chunks=[]
 
         try:
-            async for audio_stream,audio_stream_data in self.generate_audio_stream(text,voice,speed=speed,return_timestamps=return_timestamps,lang_code=lang_code,output_format=None):
-                audio_chunks.append(audio_stream_data.audio)
+            async for _,audio_stream_data in self.generate_audio_stream(text,voice,speed=speed,return_timestamps=return_timestamps,lang_code=lang_code,output_format=None):
+
                 audio_data_chunks.append(audio_stream_data)
+
 
 
-            combined_audio=np.concatenate(audio_chunks,dtype=np.int16)
 
             combined_audio_data=AudioChunk.combine(audio_data_chunks)
-            return combined_audio,combined_audio_data
+            return combined_audio_data.audio,combined_audio_data
             """
             # Get backend and voice path
             backend = self.model_manager.get_backend()
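In `generate_audio`, the manual `np.concatenate` over raw chunks gives way to accumulating `AudioChunk` objects and letting `AudioChunk.combine` produce the final result, so the combined audio and its timestamp metadata travel together. The diff doesn't show `combine` itself; here is a plausible sketch, assuming it concatenates audio and merges word timestamps that are already on a global timeline (the real class lives in `..inference.base`):

```python
# Assumed shape of AudioChunk and its combine() helper; illustrative only,
# not the implementation from ..inference.base.
from dataclasses import dataclass, field
from typing import List

import numpy as np


@dataclass
class AudioChunk:
    audio: np.ndarray
    word_timestamps: List[dict] = field(default_factory=list)

    @staticmethod
    def combine(chunks: List["AudioChunk"]) -> "AudioChunk":
        """Concatenate audio and merge timestamp lists in order."""
        audio = np.concatenate([c.audio for c in chunks], dtype=np.int16)
        timestamps = [ts for c in chunks for ts in c.word_timestamps]
        return AudioChunk(audio=audio, word_timestamps=timestamps)


# Mirrors the new generate_audio flow: collect every chunk's metadata,
# combine once at the end, and return (combined.audio, combined).
chunks = [
    AudioChunk(np.ones(10, dtype=np.int16), [{"word": "hi", "start": 0.0, "end": 0.2}]),
    AudioChunk(np.ones(10, dtype=np.int16), [{"word": "there", "start": 0.2, "end": 0.5}]),
]
combined = AudioChunk.combine(chunks)
print(combined.audio.shape, combined.word_timestamps)
```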
@@ -106,11 +106,23 @@ class CaptionedSpeechRequest(BaseModel):
         le=4.0,
         description="The speed of the generated audio. Select a value from 0.25 to 4.0.",
     )
     stream: bool = Field(
         default=True,  # Default to streaming for OpenAI compatibility
         description="If true (default), audio will be streamed as it's generated. Each chunk will be a complete sentence.",
     )
     return_timestamps: bool = Field(
         default=True,
         description="If true (default), returns word-level timestamps in the response",
     )
+    return_download_link: bool = Field(
+        default=False,
+        description="If true, returns a download link in X-Download-Path header after streaming completes",
+    )
+    lang_code: Optional[str] = Field(
+        default=None,
+        description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
+    )
+    normalization_options: Optional[NormalizationOptions] = Field(
+        default= NormalizationOptions(),
+        description= "Options for the normalization system"
+    )
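Finally, the request schema gains the fields the new endpoint reads. A usage sketch with pydantic stand-ins; only the fields from this hunk are modeled, and the real `NormalizationOptions` (plus `input`, `voice`, and the other base fields) is defined elsewhere in the repo:

```python
from typing import Optional

from pydantic import BaseModel, Field


class NormalizationOptions(BaseModel):
    """Stand-in; the real options model lives elsewhere in the repo."""


class CaptionedSpeechRequest(BaseModel):
    # Only the fields from this hunk; input/voice/etc. omitted for brevity.
    stream: bool = Field(default=True)
    return_timestamps: bool = Field(default=True)
    return_download_link: bool = Field(default=False)
    lang_code: Optional[str] = Field(default=None)
    normalization_options: Optional[NormalizationOptions] = Field(
        default=NormalizationOptions()
    )


req = CaptionedSpeechRequest(return_download_link=True)
print(req.stream, req.return_timestamps, req.lang_code)  # True True None
```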