mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00
Started work on allowing streaming word level timestamps as well as transitioning the dev code so it uses a lot more from the open ai endpoint
This commit is contained in:
parent
7772dbc2e4
commit
4027768920
4 changed files with 184 additions and 11 deletions
|
@ -7,6 +7,7 @@ from fastapi.responses import StreamingResponse, FileResponse
|
||||||
from kokoro import KPipeline
|
from kokoro import KPipeline
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from ..inference.base import AudioChunk
|
||||||
from ..core.config import settings
|
from ..core.config import settings
|
||||||
from ..services.audio import AudioNormalizer, AudioService
|
from ..services.audio import AudioNormalizer, AudioService
|
||||||
from ..services.streaming_audio_writer import StreamingAudioWriter
|
from ..services.streaming_audio_writer import StreamingAudioWriter
|
||||||
|
@ -19,6 +20,7 @@ from ..structures.text_schemas import (
|
||||||
PhonemeRequest,
|
PhonemeRequest,
|
||||||
PhonemeResponse,
|
PhonemeResponse,
|
||||||
)
|
)
|
||||||
|
from .openai_compatible import process_voices, stream_audio_chunks
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -194,6 +196,154 @@ async def create_captioned_speech(
|
||||||
tts_service: TTSService = Depends(get_tts_service),
|
tts_service: TTSService = Depends(get_tts_service),
|
||||||
):
|
):
|
||||||
"""Generate audio with word-level timestamps using streaming approach"""
|
"""Generate audio with word-level timestamps using streaming approach"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# model_name = get_model_name(request.model)
|
||||||
|
tts_service = await get_tts_service()
|
||||||
|
voice_name = await process_voices(request.voice, tts_service)
|
||||||
|
|
||||||
|
# Set content type based on format
|
||||||
|
content_type = {
|
||||||
|
"mp3": "audio/mpeg",
|
||||||
|
"opus": "audio/opus",
|
||||||
|
"aac": "audio/aac",
|
||||||
|
"flac": "audio/flac",
|
||||||
|
"wav": "audio/wav",
|
||||||
|
"pcm": "audio/pcm",
|
||||||
|
}.get(request.response_format, f"audio/{request.response_format}")
|
||||||
|
|
||||||
|
# Check if streaming is requested (default for OpenAI client)
|
||||||
|
if request.stream:
|
||||||
|
# Create generator but don't start it yet
|
||||||
|
generator = stream_audio_chunks(tts_service, request, client_request)
|
||||||
|
|
||||||
|
# If download link requested, wrap generator with temp file writer
|
||||||
|
if request.return_download_link:
|
||||||
|
from ..services.temp_manager import TempFileWriter
|
||||||
|
|
||||||
|
temp_writer = TempFileWriter(request.response_format)
|
||||||
|
await temp_writer.__aenter__() # Initialize temp file
|
||||||
|
|
||||||
|
# Get download path immediately after temp file creation
|
||||||
|
download_path = temp_writer.download_path
|
||||||
|
|
||||||
|
# Create response headers with download path
|
||||||
|
headers = {
|
||||||
|
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Transfer-Encoding": "chunked",
|
||||||
|
"X-Download-Path": download_path,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create async generator for streaming
|
||||||
|
async def dual_output():
|
||||||
|
try:
|
||||||
|
# Write chunks to temp file and stream
|
||||||
|
async for chunk in generator:
|
||||||
|
if chunk: # Skip empty chunks
|
||||||
|
await temp_writer.write(chunk)
|
||||||
|
#if return_json:
|
||||||
|
# yield chunk, chunk_data
|
||||||
|
#else:
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
# Finalize the temp file
|
||||||
|
await temp_writer.finalize()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in dual output streaming: {e}")
|
||||||
|
await temp_writer.__aexit__(type(e), e, e.__traceback__)
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
# Ensure temp writer is closed
|
||||||
|
if not temp_writer._finalized:
|
||||||
|
await temp_writer.__aexit__(None, None, None)
|
||||||
|
|
||||||
|
# Stream with temp file writing
|
||||||
|
return StreamingResponse(
|
||||||
|
dual_output(), media_type=content_type, headers=headers
|
||||||
|
)
|
||||||
|
|
||||||
|
# Standard streaming without download link
|
||||||
|
return StreamingResponse(
|
||||||
|
generator,
|
||||||
|
media_type=content_type,
|
||||||
|
headers={
|
||||||
|
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
|
||||||
|
"X-Accel-Buffering": "no",
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Transfer-Encoding": "chunked",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Generate complete audio using public interface
|
||||||
|
_, audio_data = await tts_service.generate_audio(
|
||||||
|
text=request.input,
|
||||||
|
voice=voice_name,
|
||||||
|
speed=request.speed,
|
||||||
|
lang_code=request.lang_code,
|
||||||
|
)
|
||||||
|
content, audio_data = await AudioService.convert_audio(
|
||||||
|
audio_data,
|
||||||
|
24000,
|
||||||
|
request.response_format,
|
||||||
|
is_first_chunk=True,
|
||||||
|
is_last_chunk=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert to requested format with proper finalization
|
||||||
|
final, _ = await AudioService.convert_audio(
|
||||||
|
AudioChunk(np.array([], dtype=np.int16)),
|
||||||
|
24000,
|
||||||
|
request.response_format,
|
||||||
|
is_first_chunk=False,
|
||||||
|
is_last_chunk=True,
|
||||||
|
)
|
||||||
|
output=content+final
|
||||||
|
return Response(
|
||||||
|
content=output,
|
||||||
|
media_type=content_type,
|
||||||
|
headers={
|
||||||
|
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
|
||||||
|
"Cache-Control": "no-cache", # Prevent caching
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
# Handle validation errors
|
||||||
|
logger.warning(f"Invalid request: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail={
|
||||||
|
"error": "validation_error",
|
||||||
|
"message": str(e),
|
||||||
|
"type": "invalid_request_error",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except RuntimeError as e:
|
||||||
|
# Handle runtime/processing errors
|
||||||
|
logger.error(f"Processing error: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail={
|
||||||
|
"error": "processing_error",
|
||||||
|
"message": str(e),
|
||||||
|
"type": "server_error",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# Handle unexpected errors
|
||||||
|
logger.error(f"Unexpected error in speech generation: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail={
|
||||||
|
"error": "processing_error",
|
||||||
|
"message": str(e),
|
||||||
|
"type": "server_error",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
# Set content type based on format
|
# Set content type based on format
|
||||||
content_type = {
|
content_type = {
|
||||||
|
@ -344,3 +494,4 @@ async def create_captioned_speech(
|
||||||
"type": "server_error",
|
"type": "server_error",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
"""
|
|
@ -10,12 +10,13 @@ from urllib import response
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import aiofiles
|
import aiofiles
|
||||||
from ..inference.base import AudioChunk
|
|
||||||
import torch
|
import torch
|
||||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
|
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
|
||||||
from fastapi.responses import FileResponse, StreamingResponse
|
from fastapi.responses import FileResponse, StreamingResponse
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from ..inference.base import AudioChunk
|
||||||
from ..core.config import settings
|
from ..core.config import settings
|
||||||
from ..services.audio import AudioService
|
from ..services.audio import AudioService
|
||||||
from ..services.tts_service import TTSService
|
from ..services.tts_service import TTSService
|
||||||
|
@ -130,7 +131,7 @@ async def process_voices(
|
||||||
|
|
||||||
async def stream_audio_chunks(
|
async def stream_audio_chunks(
|
||||||
tts_service: TTSService, request: OpenAISpeechRequest, client_request: Request
|
tts_service: TTSService, request: OpenAISpeechRequest, client_request: Request
|
||||||
) -> AsyncGenerator[list, None]:
|
) -> AsyncGenerator[Tuple[Union[np.ndarray,bytes],AudioChunk], None]:
|
||||||
"""Stream audio chunks as they're generated with client disconnect handling"""
|
"""Stream audio chunks as they're generated with client disconnect handling"""
|
||||||
voice_name = await process_voices(request.voice, tts_service)
|
voice_name = await process_voices(request.voice, tts_service)
|
||||||
|
|
||||||
|
@ -154,7 +155,7 @@ async def stream_audio_chunks(
|
||||||
logger.info("Client disconnected, stopping audio generation")
|
logger.info("Client disconnected, stopping audio generation")
|
||||||
break
|
break
|
||||||
|
|
||||||
yield chunk
|
yield (chunk,chunk_data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in audio streaming: {str(e)}")
|
logger.error(f"Error in audio streaming: {str(e)}")
|
||||||
# Let the exception propagate to trigger cleanup
|
# Let the exception propagate to trigger cleanup
|
||||||
|
@ -223,7 +224,7 @@ async def create_speech(
|
||||||
async def dual_output():
|
async def dual_output():
|
||||||
try:
|
try:
|
||||||
# Write chunks to temp file and stream
|
# Write chunks to temp file and stream
|
||||||
async for chunk in generator:
|
async for chunk,chunk_data in generator:
|
||||||
if chunk: # Skip empty chunks
|
if chunk: # Skip empty chunks
|
||||||
await temp_writer.write(chunk)
|
await temp_writer.write(chunk)
|
||||||
#if return_json:
|
#if return_json:
|
||||||
|
@ -247,9 +248,19 @@ async def create_speech(
|
||||||
dual_output(), media_type=content_type, headers=headers
|
dual_output(), media_type=content_type, headers=headers
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def single_output():
|
||||||
|
try:
|
||||||
|
# Stream chunks
|
||||||
|
async for chunk,chunk_data in generator:
|
||||||
|
if chunk: # Skip empty chunks
|
||||||
|
yield chunk
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in single output streaming: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
# Standard streaming without download link
|
# Standard streaming without download link
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
generator,
|
single_output(),
|
||||||
media_type=content_type,
|
media_type=content_type,
|
||||||
headers={
|
headers={
|
||||||
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
|
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
|
||||||
|
|
|
@ -332,24 +332,23 @@ class TTSService:
|
||||||
text: str,
|
text: str,
|
||||||
voice: str,
|
voice: str,
|
||||||
speed: float = 1.0,
|
speed: float = 1.0,
|
||||||
return_timestamps: bool = True,
|
return_timestamps: bool = False,
|
||||||
lang_code: Optional[str] = None,
|
lang_code: Optional[str] = None,
|
||||||
) -> Tuple[Tuple[np.ndarray,AudioChunk]]:
|
) -> Tuple[Tuple[np.ndarray,AudioChunk]]:
|
||||||
"""Generate complete audio for text using streaming internally."""
|
"""Generate complete audio for text using streaming internally."""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
audio_chunks = []
|
|
||||||
audio_data_chunks=[]
|
audio_data_chunks=[]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for audio_stream,audio_stream_data in self.generate_audio_stream(text,voice,speed=speed,return_timestamps=return_timestamps,lang_code=lang_code,output_format=None):
|
async for _,audio_stream_data in self.generate_audio_stream(text,voice,speed=speed,return_timestamps=return_timestamps,lang_code=lang_code,output_format=None):
|
||||||
audio_chunks.append(audio_stream_data.audio)
|
|
||||||
audio_data_chunks.append(audio_stream_data)
|
audio_data_chunks.append(audio_stream_data)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
combined_audio=np.concatenate(audio_chunks,dtype=np.int16)
|
|
||||||
combined_audio_data=AudioChunk.combine(audio_data_chunks)
|
combined_audio_data=AudioChunk.combine(audio_data_chunks)
|
||||||
return combined_audio,combined_audio_data
|
return combined_audio_data.audio,combined_audio_data
|
||||||
"""
|
"""
|
||||||
# Get backend and voice path
|
# Get backend and voice path
|
||||||
backend = self.model_manager.get_backend()
|
backend = self.model_manager.get_backend()
|
||||||
|
|
|
@ -106,11 +106,23 @@ class CaptionedSpeechRequest(BaseModel):
|
||||||
le=4.0,
|
le=4.0,
|
||||||
description="The speed of the generated audio. Select a value from 0.25 to 4.0.",
|
description="The speed of the generated audio. Select a value from 0.25 to 4.0.",
|
||||||
)
|
)
|
||||||
|
stream: bool = Field(
|
||||||
|
default=True, # Default to streaming for OpenAI compatibility
|
||||||
|
description="If true (default), audio will be streamed as it's generated. Each chunk will be a complete sentence.",
|
||||||
|
)
|
||||||
return_timestamps: bool = Field(
|
return_timestamps: bool = Field(
|
||||||
default=True,
|
default=True,
|
||||||
description="If true (default), returns word-level timestamps in the response",
|
description="If true (default), returns word-level timestamps in the response",
|
||||||
)
|
)
|
||||||
|
return_download_link: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description="If true, returns a download link in X-Download-Path header after streaming completes",
|
||||||
|
)
|
||||||
lang_code: Optional[str] = Field(
|
lang_code: Optional[str] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
|
description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
|
||||||
)
|
)
|
||||||
|
normalization_options: Optional[NormalizationOptions] = Field(
|
||||||
|
default= NormalizationOptions(),
|
||||||
|
description= "Options for the normalization system"
|
||||||
|
)
|
||||||
|
|
Loading…
Add table
Reference in a new issue