More work on streaming timestamps

This commit is contained in:
Fireblade2534 2025-02-12 17:13:56 +00:00
parent 51b6b01589
commit 91d370d97f
2 changed files with 35 additions and 31 deletions

View file

@ -5,9 +5,10 @@ import json
import os
import re
import tempfile
from typing import AsyncGenerator, Dict, List, Union
from typing import AsyncGenerator, Dict, List, Union, Tuple
import aiofiles
from inference.base import AudioChunk
import torch
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from fastapi.responses import FileResponse, StreamingResponse
@ -127,13 +128,13 @@ async def process_voices(
async def stream_audio_chunks(
tts_service: TTSService, request: OpenAISpeechRequest, client_request: Request
) -> AsyncGenerator[bytes, None]:
) -> AsyncGenerator[Tuple[bytes,AudioChunk], None]:
"""Stream audio chunks as they're generated with client disconnect handling"""
voice_name = await process_voices(request.voice, tts_service)
try:
logger.info(f"Starting audio generation with lang_code: {request.lang_code}")
async for chunk in tts_service.generate_audio_stream(
async for chunk, chunk_data in tts_service.generate_audio_stream(
text=request.input,
voice=voice_name,
speed=request.speed,
@ -148,7 +149,7 @@ async def stream_audio_chunks(
if is_disconnected:
logger.info("Client disconnected, stopping audio generation")
break
yield chunk
yield chunk, chunk_data
except Exception as e:
logger.error(f"Error in audio streaming: {str(e)}")
# Let the exception propagate to trigger cleanup
@ -213,13 +214,16 @@ async def create_speech(
}
# Create async generator for streaming
async def dual_output():
async def dual_output(return_json:bool=False):
try:
# Write chunks to temp file and stream
async for chunk in generator:
async for chunk, chunk_data in generator:
if chunk: # Skip empty chunks
await temp_writer.write(chunk)
yield chunk
if return_json:
yield chunk, chunk_data
else:
yield chunk
# Finalize the temp file
await temp_writer.finalize()

View file

@ -53,7 +53,8 @@ class TTSService:
is_last: bool = False,
normalizer: Optional[AudioNormalizer] = None,
lang_code: Optional[str] = None,
) -> AsyncGenerator[Union[np.ndarray, bytes], None]:
return_timestamps: Optional[bool] = False,
) -> AsyncGenerator[Tuple[Union[np.ndarray, bytes],AudioChunk], Tuple[None,None]]:
"""Process tokens into audio."""
async with self._chunk_semaphore:
try:
@ -63,7 +64,7 @@ class TTSService:
if not output_format:
yield np.array([], dtype=np.float32)
return
result, _ = await AudioService.convert_audio(
result, chunk_data = await AudioService.convert_audio(
AudioChunk(np.array([0], dtype=np.float32)), # Dummy data for type checking
24000,
output_format,
@ -73,7 +74,7 @@ class TTSService:
normalizer=normalizer,
is_last_chunk=True,
)
yield result
yield result, chunk_data
return
# Skip empty chunks
@ -91,7 +92,7 @@ class TTSService:
(voice_name, voice_path),
speed=speed,
lang_code=lang_code,
return_timestamps=True,
return_timestamps=return_timestamps,
):
# For streaming, convert to bytes
if output_format:
@ -106,8 +107,7 @@ class TTSService:
is_last_chunk=is_last,
normalizer=normalizer,
)
print(chunk_data.word_timestamps)
yield converted
yield converted, chunk_data
except Exception as e:
logger.error(f"Failed to convert audio: {str(e)}")
else:
@ -116,31 +116,30 @@ class TTSService:
speed,
is_last,
normalizer)
print(chunk_data.word_timestamps)
yield chunk_data.audio
yield chunk_data.audio, chunk_data
else:
# For legacy backends, load voice tensor
voice_tensor = await self._voice_manager.load_voice(
voice_name, device=backend.device
)
chunk_audio = await self.model_manager.generate(
tokens, voice_tensor, speed=speed
chunk_data = await self.model_manager.generate(
tokens, voice_tensor, speed=speed, return_timestamps=return_timestamps
)
if chunk_audio is None:
if chunk_data.audio is None:
logger.error("Model generated None for audio chunk")
return
if len(chunk_audio) == 0:
if len(chunk_data.audio) == 0:
logger.error("Model generated empty audio chunk")
return
# For streaming, convert to bytes
if output_format:
try:
converted = await AudioService.convert_audio(
chunk_audio,
converted, chunk_data = await AudioService.convert_audio(
chunk_data,
24000,
output_format,
speed,
@ -149,16 +148,16 @@ class TTSService:
normalizer=normalizer,
is_last_chunk=is_last,
)
yield converted
yield converted, chunk_data
except Exception as e:
logger.error(f"Failed to convert audio: {str(e)}")
else:
trimmed = await AudioService.trim_audio(chunk_audio,
trimmed = await AudioService.trim_audio(chunk_data,
chunk_text,
speed,
is_last,
normalizer)
yield trimmed
yield trimmed.audio, trimmed
except Exception as e:
logger.error(f"Failed to process tokens: {str(e)}")
@ -242,8 +241,9 @@ class TTSService:
speed: float = 1.0,
output_format: str = "wav",
lang_code: Optional[str] = None,
normalization_options: Optional[NormalizationOptions] = NormalizationOptions()
) -> AsyncGenerator[bytes, None]:
normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
return_timestamps: Optional[bool] = False,
) -> AsyncGenerator[Tuple[bytes,AudioChunk], None]:
"""Generate and stream audio chunks."""
stream_normalizer = AudioNormalizer()
chunk_index = 0
@ -266,7 +266,7 @@ class TTSService:
async for chunk_text, tokens in smart_split(text,normalization_options=normalization_options):
try:
# Process audio for chunk
async for result in self._process_chunk(
async for result, chunk_data in self._process_chunk(
chunk_text, # Pass text for Kokoro V1
tokens, # Pass tokens for legacy backends
voice_name, # Pass voice name
@ -279,7 +279,7 @@ class TTSService:
lang_code=pipeline_lang_code, # Pass lang_code
):
if result is not None:
yield result
yield result,chunk_data
chunk_index += 1
else:
logger.warning(
@ -296,7 +296,7 @@ class TTSService:
if chunk_index > 0:
try:
# Empty tokens list to finalize audio
async for result in self._process_chunk(
async for result,chunk_data in self._process_chunk(
"", # Empty text
[], # Empty tokens
voice_name,
@ -309,7 +309,7 @@ class TTSService:
lang_code=pipeline_lang_code, # Pass lang_code
):
if result is not None:
yield result
yield result, chunk_data
except Exception as e:
logger.error(f"Failed to finalize audio stream: {str(e)}")
@ -325,7 +325,7 @@ class TTSService:
speed: float = 1.0,
return_timestamps: bool = False,
lang_code: Optional[str] = None,
) -> Union[Tuple[np.ndarray, float], Tuple[np.ndarray, float, List[dict]]]:
) -> Tuple[Tuple[np.ndarray,AudioChunk]]:
"""Generate complete audio for text using streaming internally."""
start_time = time.time()
chunks = []