mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
More work on streaming timestamps
This commit is contained in:
parent
51b6b01589
commit
91d370d97f
2 changed files with 35 additions and 31 deletions
|
@ -5,9 +5,10 @@ import json
|
|||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from typing import AsyncGenerator, Dict, List, Union
|
||||
from typing import AsyncGenerator, Dict, List, Union, Tuple
|
||||
|
||||
import aiofiles
|
||||
from inference.base import AudioChunk
|
||||
import torch
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
|
@ -127,13 +128,13 @@ async def process_voices(
|
|||
|
||||
async def stream_audio_chunks(
|
||||
tts_service: TTSService, request: OpenAISpeechRequest, client_request: Request
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
) -> AsyncGenerator[Tuple[bytes,AudioChunk], None]:
|
||||
"""Stream audio chunks as they're generated with client disconnect handling"""
|
||||
voice_name = await process_voices(request.voice, tts_service)
|
||||
|
||||
try:
|
||||
logger.info(f"Starting audio generation with lang_code: {request.lang_code}")
|
||||
async for chunk in tts_service.generate_audio_stream(
|
||||
async for chunk, chunk_data in tts_service.generate_audio_stream(
|
||||
text=request.input,
|
||||
voice=voice_name,
|
||||
speed=request.speed,
|
||||
|
@ -148,7 +149,7 @@ async def stream_audio_chunks(
|
|||
if is_disconnected:
|
||||
logger.info("Client disconnected, stopping audio generation")
|
||||
break
|
||||
yield chunk
|
||||
yield chunk, chunk_data
|
||||
except Exception as e:
|
||||
logger.error(f"Error in audio streaming: {str(e)}")
|
||||
# Let the exception propagate to trigger cleanup
|
||||
|
@ -213,12 +214,15 @@ async def create_speech(
|
|||
}
|
||||
|
||||
# Create async generator for streaming
|
||||
async def dual_output():
|
||||
async def dual_output(return_json:bool=False):
|
||||
try:
|
||||
# Write chunks to temp file and stream
|
||||
async for chunk in generator:
|
||||
async for chunk, chunk_data in generator:
|
||||
if chunk: # Skip empty chunks
|
||||
await temp_writer.write(chunk)
|
||||
if return_json:
|
||||
yield chunk, chunk_data
|
||||
else:
|
||||
yield chunk
|
||||
|
||||
# Finalize the temp file
|
||||
|
|
|
@ -53,7 +53,8 @@ class TTSService:
|
|||
is_last: bool = False,
|
||||
normalizer: Optional[AudioNormalizer] = None,
|
||||
lang_code: Optional[str] = None,
|
||||
) -> AsyncGenerator[Union[np.ndarray, bytes], None]:
|
||||
return_timestamps: Optional[bool] = False,
|
||||
) -> AsyncGenerator[Tuple[Union[np.ndarray, bytes],AudioChunk], Tuple[None,None]]:
|
||||
"""Process tokens into audio."""
|
||||
async with self._chunk_semaphore:
|
||||
try:
|
||||
|
@ -63,7 +64,7 @@ class TTSService:
|
|||
if not output_format:
|
||||
yield np.array([], dtype=np.float32)
|
||||
return
|
||||
result, _ = await AudioService.convert_audio(
|
||||
result, chunk_data = await AudioService.convert_audio(
|
||||
AudioChunk(np.array([0], dtype=np.float32)), # Dummy data for type checking
|
||||
24000,
|
||||
output_format,
|
||||
|
@ -73,7 +74,7 @@ class TTSService:
|
|||
normalizer=normalizer,
|
||||
is_last_chunk=True,
|
||||
)
|
||||
yield result
|
||||
yield result, chunk_data
|
||||
return
|
||||
|
||||
# Skip empty chunks
|
||||
|
@ -91,7 +92,7 @@ class TTSService:
|
|||
(voice_name, voice_path),
|
||||
speed=speed,
|
||||
lang_code=lang_code,
|
||||
return_timestamps=True,
|
||||
return_timestamps=return_timestamps,
|
||||
):
|
||||
# For streaming, convert to bytes
|
||||
if output_format:
|
||||
|
@ -106,8 +107,7 @@ class TTSService:
|
|||
is_last_chunk=is_last,
|
||||
normalizer=normalizer,
|
||||
)
|
||||
print(chunk_data.word_timestamps)
|
||||
yield converted
|
||||
yield converted, chunk_data
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert audio: {str(e)}")
|
||||
else:
|
||||
|
@ -116,31 +116,30 @@ class TTSService:
|
|||
speed,
|
||||
is_last,
|
||||
normalizer)
|
||||
print(chunk_data.word_timestamps)
|
||||
yield chunk_data.audio
|
||||
yield chunk_data.audio, chunk_data
|
||||
else:
|
||||
|
||||
# For legacy backends, load voice tensor
|
||||
voice_tensor = await self._voice_manager.load_voice(
|
||||
voice_name, device=backend.device
|
||||
)
|
||||
chunk_audio = await self.model_manager.generate(
|
||||
tokens, voice_tensor, speed=speed
|
||||
chunk_data = await self.model_manager.generate(
|
||||
tokens, voice_tensor, speed=speed, return_timestamps=return_timestamps
|
||||
)
|
||||
|
||||
if chunk_audio is None:
|
||||
if chunk_data.audio is None:
|
||||
logger.error("Model generated None for audio chunk")
|
||||
return
|
||||
|
||||
if len(chunk_audio) == 0:
|
||||
if len(chunk_data.audio) == 0:
|
||||
logger.error("Model generated empty audio chunk")
|
||||
return
|
||||
|
||||
# For streaming, convert to bytes
|
||||
if output_format:
|
||||
try:
|
||||
converted = await AudioService.convert_audio(
|
||||
chunk_audio,
|
||||
converted, chunk_data = await AudioService.convert_audio(
|
||||
chunk_data,
|
||||
24000,
|
||||
output_format,
|
||||
speed,
|
||||
|
@ -149,16 +148,16 @@ class TTSService:
|
|||
normalizer=normalizer,
|
||||
is_last_chunk=is_last,
|
||||
)
|
||||
yield converted
|
||||
yield converted, chunk_data
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert audio: {str(e)}")
|
||||
else:
|
||||
trimmed = await AudioService.trim_audio(chunk_audio,
|
||||
trimmed = await AudioService.trim_audio(chunk_data,
|
||||
chunk_text,
|
||||
speed,
|
||||
is_last,
|
||||
normalizer)
|
||||
yield trimmed
|
||||
yield trimmed.audio, trimmed
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process tokens: {str(e)}")
|
||||
|
||||
|
@ -242,8 +241,9 @@ class TTSService:
|
|||
speed: float = 1.0,
|
||||
output_format: str = "wav",
|
||||
lang_code: Optional[str] = None,
|
||||
normalization_options: Optional[NormalizationOptions] = NormalizationOptions()
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
|
||||
return_timestamps: Optional[bool] = False,
|
||||
) -> AsyncGenerator[Tuple[bytes,AudioChunk], None]:
|
||||
"""Generate and stream audio chunks."""
|
||||
stream_normalizer = AudioNormalizer()
|
||||
chunk_index = 0
|
||||
|
@ -266,7 +266,7 @@ class TTSService:
|
|||
async for chunk_text, tokens in smart_split(text,normalization_options=normalization_options):
|
||||
try:
|
||||
# Process audio for chunk
|
||||
async for result in self._process_chunk(
|
||||
async for result, chunk_data in self._process_chunk(
|
||||
chunk_text, # Pass text for Kokoro V1
|
||||
tokens, # Pass tokens for legacy backends
|
||||
voice_name, # Pass voice name
|
||||
|
@ -279,7 +279,7 @@ class TTSService:
|
|||
lang_code=pipeline_lang_code, # Pass lang_code
|
||||
):
|
||||
if result is not None:
|
||||
yield result
|
||||
yield result,chunk_data
|
||||
chunk_index += 1
|
||||
else:
|
||||
logger.warning(
|
||||
|
@ -296,7 +296,7 @@ class TTSService:
|
|||
if chunk_index > 0:
|
||||
try:
|
||||
# Empty tokens list to finalize audio
|
||||
async for result in self._process_chunk(
|
||||
async for result,chunk_data in self._process_chunk(
|
||||
"", # Empty text
|
||||
[], # Empty tokens
|
||||
voice_name,
|
||||
|
@ -309,7 +309,7 @@ class TTSService:
|
|||
lang_code=pipeline_lang_code, # Pass lang_code
|
||||
):
|
||||
if result is not None:
|
||||
yield result
|
||||
yield result, chunk_data
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to finalize audio stream: {str(e)}")
|
||||
|
||||
|
@ -325,7 +325,7 @@ class TTSService:
|
|||
speed: float = 1.0,
|
||||
return_timestamps: bool = False,
|
||||
lang_code: Optional[str] = None,
|
||||
) -> Union[Tuple[np.ndarray, float], Tuple[np.ndarray, float, List[dict]]]:
|
||||
) -> Tuple[Tuple[np.ndarray,AudioChunk]]:
|
||||
"""Generate complete audio for text using streaming internally."""
|
||||
start_time = time.time()
|
||||
chunks = []
|
||||
|
|
Loading…
Add table
Reference in a new issue