From 91d370d97f26fc9545759fd796f5d601bd337d4c Mon Sep 17 00:00:00 2001 From: Fireblade2534 Date: Wed, 12 Feb 2025 17:13:56 +0000 Subject: [PATCH] More working on streaming timestamps --- api/src/routers/openai_compatible.py | 18 +++++++---- api/src/services/tts_service.py | 48 ++++++++++++++-------------- 2 files changed, 35 insertions(+), 31 deletions(-) diff --git a/api/src/routers/openai_compatible.py b/api/src/routers/openai_compatible.py index 3be678a..5090a4b 100644 --- a/api/src/routers/openai_compatible.py +++ b/api/src/routers/openai_compatible.py @@ -5,9 +5,10 @@ import json import os import re import tempfile -from typing import AsyncGenerator, Dict, List, Union +from typing import AsyncGenerator, Dict, List, Union, Tuple import aiofiles +from inference.base import AudioChunk import torch from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response from fastapi.responses import FileResponse, StreamingResponse @@ -127,13 +128,13 @@ async def process_voices( async def stream_audio_chunks( tts_service: TTSService, request: OpenAISpeechRequest, client_request: Request -) -> AsyncGenerator[bytes, None]: +) -> AsyncGenerator[Tuple[bytes,AudioChunk], None]: """Stream audio chunks as they're generated with client disconnect handling""" voice_name = await process_voices(request.voice, tts_service) try: logger.info(f"Starting audio generation with lang_code: {request.lang_code}") - async for chunk in tts_service.generate_audio_stream( + async for chunk, chunk_data in tts_service.generate_audio_stream( text=request.input, voice=voice_name, speed=request.speed, @@ -148,7 +149,7 @@ async def stream_audio_chunks( if is_disconnected: logger.info("Client disconnected, stopping audio generation") break - yield chunk + yield chunk, chunk_data except Exception as e: logger.error(f"Error in audio streaming: {str(e)}") # Let the exception propagate to trigger cleanup @@ -213,13 +214,16 @@ async def create_speech( } # Create async generator for streaming - async def dual_output(): + async def dual_output(return_json:bool=False): try: # Write chunks to temp file and stream - async for chunk in generator: + async for chunk, chunk_data in generator: if chunk: # Skip empty chunks await temp_writer.write(chunk) - yield chunk + if return_json: + yield chunk, chunk_data + else: + yield chunk # Finalize the temp file await temp_writer.finalize() diff --git a/api/src/services/tts_service.py b/api/src/services/tts_service.py index 0fcaac7..7f2dcaa 100644 --- a/api/src/services/tts_service.py +++ b/api/src/services/tts_service.py @@ -53,7 +53,8 @@ class TTSService: is_last: bool = False, normalizer: Optional[AudioNormalizer] = None, lang_code: Optional[str] = None, - ) -> AsyncGenerator[Union[np.ndarray, bytes], None]: + return_timestamps: Optional[bool] = False, + ) -> AsyncGenerator[Tuple[Union[np.ndarray, bytes],AudioChunk], Tuple[None,None]]: """Process tokens into audio.""" async with self._chunk_semaphore: try: @@ -63,7 +64,7 @@ class TTSService: if not output_format: yield np.array([], dtype=np.float32) return - result, _ = await AudioService.convert_audio( + result, chunk_data = await AudioService.convert_audio( AudioChunk(np.array([0], dtype=np.float32)), # Dummy data for type checking 24000, output_format, @@ -73,7 +74,7 @@ class TTSService: normalizer=normalizer, is_last_chunk=True, ) - yield result + yield result, chunk_data return # Skip empty chunks @@ -91,7 +92,7 @@ class TTSService: (voice_name, voice_path), speed=speed, lang_code=lang_code, - return_timestamps=True, + return_timestamps=return_timestamps, ): # For streaming, convert to bytes if output_format: @@ -106,8 +107,7 @@ class TTSService: is_last_chunk=is_last, normalizer=normalizer, ) - print(chunk_data.word_timestamps) - yield converted + yield converted, chunk_data except Exception as e: logger.error(f"Failed to convert audio: {str(e)}") else: @@ -116,31 +116,30 @@ class TTSService: speed, is_last, normalizer) - print(chunk_data.word_timestamps) - yield chunk_data.audio + yield chunk_data.audio, chunk_data else: # For legacy backends, load voice tensor voice_tensor = await self._voice_manager.load_voice( voice_name, device=backend.device ) - chunk_audio = await self.model_manager.generate( - tokens, voice_tensor, speed=speed + chunk_data = await self.model_manager.generate( + tokens, voice_tensor, speed=speed, return_timestamps=return_timestamps ) - if chunk_audio is None: + if chunk_data.audio is None: logger.error("Model generated None for audio chunk") return - if len(chunk_audio) == 0: + if len(chunk_data.audio) == 0: logger.error("Model generated empty audio chunk") return # For streaming, convert to bytes if output_format: try: - converted = await AudioService.convert_audio( - chunk_audio, + converted, chunk_data = await AudioService.convert_audio( + chunk_data, 24000, output_format, speed, @@ -149,16 +148,16 @@ class TTSService: normalizer=normalizer, is_last_chunk=is_last, ) - yield converted + yield converted, chunk_data except Exception as e: logger.error(f"Failed to convert audio: {str(e)}") else: - trimmed = await AudioService.trim_audio(chunk_audio, + trimmed = await AudioService.trim_audio(chunk_data, chunk_text, speed, is_last, normalizer) - yield trimmed + yield trimmed.audio, trimmed except Exception as e: logger.error(f"Failed to process tokens: {str(e)}") @@ -242,8 +241,9 @@ class TTSService: speed: float = 1.0, output_format: str = "wav", lang_code: Optional[str] = None, - normalization_options: Optional[NormalizationOptions] = NormalizationOptions() - ) -> AsyncGenerator[bytes, None]: + normalization_options: Optional[NormalizationOptions] = NormalizationOptions(), + return_timestamps: Optional[bool] = False, + ) -> AsyncGenerator[Tuple[bytes,AudioChunk], None]: """Generate and stream audio chunks.""" stream_normalizer = AudioNormalizer() chunk_index = 0 @@ -266,7 +266,7 @@ class TTSService: async for chunk_text, tokens in smart_split(text,normalization_options=normalization_options): try: # Process audio for chunk - async for result in self._process_chunk( + async for result, chunk_data in self._process_chunk( chunk_text, # Pass text for Kokoro V1 tokens, # Pass tokens for legacy backends voice_name, # Pass voice name @@ -279,7 +279,7 @@ class TTSService: lang_code=pipeline_lang_code, # Pass lang_code ): if result is not None: - yield result + yield result,chunk_data chunk_index += 1 else: logger.warning( @@ -296,7 +296,7 @@ class TTSService: if chunk_index > 0: try: # Empty tokens list to finalize audio - async for result in self._process_chunk( + async for result,chunk_data in self._process_chunk( "", # Empty text [], # Empty tokens voice_name, @@ -309,7 +309,7 @@ class TTSService: lang_code=pipeline_lang_code, # Pass lang_code ): if result is not None: - yield result + yield result, chunk_data except Exception as e: logger.error(f"Failed to finalize audio stream: {str(e)}") @@ -325,7 +325,7 @@ class TTSService: speed: float = 1.0, return_timestamps: bool = False, lang_code: Optional[str] = None, - ) -> Union[Tuple[np.ndarray, float], Tuple[np.ndarray, float, List[dict]]]: + ) -> Tuple[Tuple[np.ndarray,AudioChunk]]: """Generate complete audio for text using streaming internally.""" start_time = time.time() chunks = []