More working on streaming timestamps

This commit is contained in:
Fireblade2534 2025-02-12 17:13:56 +00:00
parent 51b6b01589
commit 91d370d97f
2 changed files with 35 additions and 31 deletions

View file

@ -5,9 +5,10 @@ import json
import os import os
import re import re
import tempfile import tempfile
from typing import AsyncGenerator, Dict, List, Union from typing import AsyncGenerator, Dict, List, Union, Tuple
import aiofiles import aiofiles
from inference.base import AudioChunk
import torch import torch
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
from fastapi.responses import FileResponse, StreamingResponse from fastapi.responses import FileResponse, StreamingResponse
@ -127,13 +128,13 @@ async def process_voices(
async def stream_audio_chunks( async def stream_audio_chunks(
tts_service: TTSService, request: OpenAISpeechRequest, client_request: Request tts_service: TTSService, request: OpenAISpeechRequest, client_request: Request
) -> AsyncGenerator[bytes, None]: ) -> AsyncGenerator[Tuple[bytes,AudioChunk], None]:
"""Stream audio chunks as they're generated with client disconnect handling""" """Stream audio chunks as they're generated with client disconnect handling"""
voice_name = await process_voices(request.voice, tts_service) voice_name = await process_voices(request.voice, tts_service)
try: try:
logger.info(f"Starting audio generation with lang_code: {request.lang_code}") logger.info(f"Starting audio generation with lang_code: {request.lang_code}")
async for chunk in tts_service.generate_audio_stream( async for chunk, chunk_data in tts_service.generate_audio_stream(
text=request.input, text=request.input,
voice=voice_name, voice=voice_name,
speed=request.speed, speed=request.speed,
@ -148,7 +149,7 @@ async def stream_audio_chunks(
if is_disconnected: if is_disconnected:
logger.info("Client disconnected, stopping audio generation") logger.info("Client disconnected, stopping audio generation")
break break
yield chunk yield chunk, chunk_data
except Exception as e: except Exception as e:
logger.error(f"Error in audio streaming: {str(e)}") logger.error(f"Error in audio streaming: {str(e)}")
# Let the exception propagate to trigger cleanup # Let the exception propagate to trigger cleanup
@ -213,13 +214,16 @@ async def create_speech(
} }
# Create async generator for streaming # Create async generator for streaming
async def dual_output(): async def dual_output(return_json:bool=False):
try: try:
# Write chunks to temp file and stream # Write chunks to temp file and stream
async for chunk in generator: async for chunk, chunk_data in generator:
if chunk: # Skip empty chunks if chunk: # Skip empty chunks
await temp_writer.write(chunk) await temp_writer.write(chunk)
yield chunk if return_json:
yield chunk, chunk_data
else:
yield chunk
# Finalize the temp file # Finalize the temp file
await temp_writer.finalize() await temp_writer.finalize()

View file

@ -53,7 +53,8 @@ class TTSService:
is_last: bool = False, is_last: bool = False,
normalizer: Optional[AudioNormalizer] = None, normalizer: Optional[AudioNormalizer] = None,
lang_code: Optional[str] = None, lang_code: Optional[str] = None,
) -> AsyncGenerator[Union[np.ndarray, bytes], None]: return_timestamps: Optional[bool] = False,
) -> AsyncGenerator[Tuple[Union[np.ndarray, bytes],AudioChunk], Tuple[None,None]]:
"""Process tokens into audio.""" """Process tokens into audio."""
async with self._chunk_semaphore: async with self._chunk_semaphore:
try: try:
@ -63,7 +64,7 @@ class TTSService:
if not output_format: if not output_format:
yield np.array([], dtype=np.float32) yield np.array([], dtype=np.float32)
return return
result, _ = await AudioService.convert_audio( result, chunk_data = await AudioService.convert_audio(
AudioChunk(np.array([0], dtype=np.float32)), # Dummy data for type checking AudioChunk(np.array([0], dtype=np.float32)), # Dummy data for type checking
24000, 24000,
output_format, output_format,
@ -73,7 +74,7 @@ class TTSService:
normalizer=normalizer, normalizer=normalizer,
is_last_chunk=True, is_last_chunk=True,
) )
yield result yield result, chunk_data
return return
# Skip empty chunks # Skip empty chunks
@ -91,7 +92,7 @@ class TTSService:
(voice_name, voice_path), (voice_name, voice_path),
speed=speed, speed=speed,
lang_code=lang_code, lang_code=lang_code,
return_timestamps=True, return_timestamps=return_timestamps,
): ):
# For streaming, convert to bytes # For streaming, convert to bytes
if output_format: if output_format:
@ -106,8 +107,7 @@ class TTSService:
is_last_chunk=is_last, is_last_chunk=is_last,
normalizer=normalizer, normalizer=normalizer,
) )
print(chunk_data.word_timestamps) yield converted, chunk_data
yield converted
except Exception as e: except Exception as e:
logger.error(f"Failed to convert audio: {str(e)}") logger.error(f"Failed to convert audio: {str(e)}")
else: else:
@ -116,31 +116,30 @@ class TTSService:
speed, speed,
is_last, is_last,
normalizer) normalizer)
print(chunk_data.word_timestamps) yield chunk_data.audio, chunk_data
yield chunk_data.audio
else: else:
# For legacy backends, load voice tensor # For legacy backends, load voice tensor
voice_tensor = await self._voice_manager.load_voice( voice_tensor = await self._voice_manager.load_voice(
voice_name, device=backend.device voice_name, device=backend.device
) )
chunk_audio = await self.model_manager.generate( chunk_data = await self.model_manager.generate(
tokens, voice_tensor, speed=speed tokens, voice_tensor, speed=speed, return_timestamps=return_timestamps
) )
if chunk_audio is None: if chunk_data.audio is None:
logger.error("Model generated None for audio chunk") logger.error("Model generated None for audio chunk")
return return
if len(chunk_audio) == 0: if len(chunk_data.audio) == 0:
logger.error("Model generated empty audio chunk") logger.error("Model generated empty audio chunk")
return return
# For streaming, convert to bytes # For streaming, convert to bytes
if output_format: if output_format:
try: try:
converted = await AudioService.convert_audio( converted, chunk_data = await AudioService.convert_audio(
chunk_audio, chunk_data,
24000, 24000,
output_format, output_format,
speed, speed,
@ -149,16 +148,16 @@ class TTSService:
normalizer=normalizer, normalizer=normalizer,
is_last_chunk=is_last, is_last_chunk=is_last,
) )
yield converted yield converted, chunk_data
except Exception as e: except Exception as e:
logger.error(f"Failed to convert audio: {str(e)}") logger.error(f"Failed to convert audio: {str(e)}")
else: else:
trimmed = await AudioService.trim_audio(chunk_audio, trimmed = await AudioService.trim_audio(chunk_data,
chunk_text, chunk_text,
speed, speed,
is_last, is_last,
normalizer) normalizer)
yield trimmed yield trimmed.audio, trimmed
except Exception as e: except Exception as e:
logger.error(f"Failed to process tokens: {str(e)}") logger.error(f"Failed to process tokens: {str(e)}")
@ -242,8 +241,9 @@ class TTSService:
speed: float = 1.0, speed: float = 1.0,
output_format: str = "wav", output_format: str = "wav",
lang_code: Optional[str] = None, lang_code: Optional[str] = None,
normalization_options: Optional[NormalizationOptions] = NormalizationOptions() normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
) -> AsyncGenerator[bytes, None]: return_timestamps: Optional[bool] = False,
) -> AsyncGenerator[Tuple[bytes,AudioChunk], None]:
"""Generate and stream audio chunks.""" """Generate and stream audio chunks."""
stream_normalizer = AudioNormalizer() stream_normalizer = AudioNormalizer()
chunk_index = 0 chunk_index = 0
@ -266,7 +266,7 @@ class TTSService:
async for chunk_text, tokens in smart_split(text,normalization_options=normalization_options): async for chunk_text, tokens in smart_split(text,normalization_options=normalization_options):
try: try:
# Process audio for chunk # Process audio for chunk
async for result in self._process_chunk( async for result, chunk_data in self._process_chunk(
chunk_text, # Pass text for Kokoro V1 chunk_text, # Pass text for Kokoro V1
tokens, # Pass tokens for legacy backends tokens, # Pass tokens for legacy backends
voice_name, # Pass voice name voice_name, # Pass voice name
@ -279,7 +279,7 @@ class TTSService:
lang_code=pipeline_lang_code, # Pass lang_code lang_code=pipeline_lang_code, # Pass lang_code
): ):
if result is not None: if result is not None:
yield result yield result,chunk_data
chunk_index += 1 chunk_index += 1
else: else:
logger.warning( logger.warning(
@ -296,7 +296,7 @@ class TTSService:
if chunk_index > 0: if chunk_index > 0:
try: try:
# Empty tokens list to finalize audio # Empty tokens list to finalize audio
async for result in self._process_chunk( async for result,chunk_data in self._process_chunk(
"", # Empty text "", # Empty text
[], # Empty tokens [], # Empty tokens
voice_name, voice_name,
@ -309,7 +309,7 @@ class TTSService:
lang_code=pipeline_lang_code, # Pass lang_code lang_code=pipeline_lang_code, # Pass lang_code
): ):
if result is not None: if result is not None:
yield result yield result, chunk_data
except Exception as e: except Exception as e:
logger.error(f"Failed to finalize audio stream: {str(e)}") logger.error(f"Failed to finalize audio stream: {str(e)}")
@ -325,7 +325,7 @@ class TTSService:
speed: float = 1.0, speed: float = 1.0,
return_timestamps: bool = False, return_timestamps: bool = False,
lang_code: Optional[str] = None, lang_code: Optional[str] = None,
) -> Union[Tuple[np.ndarray, float], Tuple[np.ndarray, float, List[dict]]]: ) -> Tuple[Tuple[np.ndarray,AudioChunk]]:
"""Generate complete audio for text using streaming internally.""" """Generate complete audio for text using streaming internally."""
start_time = time.time() start_time = time.time()
chunks = [] chunks = []