mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
More working on streaming timestamps
This commit is contained in:
parent
51b6b01589
commit
91d370d97f
2 changed files with 35 additions and 31 deletions
|
@ -5,9 +5,10 @@ import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import tempfile
|
import tempfile
|
||||||
from typing import AsyncGenerator, Dict, List, Union
|
from typing import AsyncGenerator, Dict, List, Union, Tuple
|
||||||
|
|
||||||
import aiofiles
|
import aiofiles
|
||||||
|
from inference.base import AudioChunk
|
||||||
import torch
|
import torch
|
||||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
|
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response
|
||||||
from fastapi.responses import FileResponse, StreamingResponse
|
from fastapi.responses import FileResponse, StreamingResponse
|
||||||
|
@ -127,13 +128,13 @@ async def process_voices(
|
||||||
|
|
||||||
async def stream_audio_chunks(
|
async def stream_audio_chunks(
|
||||||
tts_service: TTSService, request: OpenAISpeechRequest, client_request: Request
|
tts_service: TTSService, request: OpenAISpeechRequest, client_request: Request
|
||||||
) -> AsyncGenerator[bytes, None]:
|
) -> AsyncGenerator[Tuple[bytes,AudioChunk], None]:
|
||||||
"""Stream audio chunks as they're generated with client disconnect handling"""
|
"""Stream audio chunks as they're generated with client disconnect handling"""
|
||||||
voice_name = await process_voices(request.voice, tts_service)
|
voice_name = await process_voices(request.voice, tts_service)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(f"Starting audio generation with lang_code: {request.lang_code}")
|
logger.info(f"Starting audio generation with lang_code: {request.lang_code}")
|
||||||
async for chunk in tts_service.generate_audio_stream(
|
async for chunk, chunk_data in tts_service.generate_audio_stream(
|
||||||
text=request.input,
|
text=request.input,
|
||||||
voice=voice_name,
|
voice=voice_name,
|
||||||
speed=request.speed,
|
speed=request.speed,
|
||||||
|
@ -148,7 +149,7 @@ async def stream_audio_chunks(
|
||||||
if is_disconnected:
|
if is_disconnected:
|
||||||
logger.info("Client disconnected, stopping audio generation")
|
logger.info("Client disconnected, stopping audio generation")
|
||||||
break
|
break
|
||||||
yield chunk
|
yield chunk, chunk_data
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in audio streaming: {str(e)}")
|
logger.error(f"Error in audio streaming: {str(e)}")
|
||||||
# Let the exception propagate to trigger cleanup
|
# Let the exception propagate to trigger cleanup
|
||||||
|
@ -213,13 +214,16 @@ async def create_speech(
|
||||||
}
|
}
|
||||||
|
|
||||||
# Create async generator for streaming
|
# Create async generator for streaming
|
||||||
async def dual_output():
|
async def dual_output(return_json:bool=False):
|
||||||
try:
|
try:
|
||||||
# Write chunks to temp file and stream
|
# Write chunks to temp file and stream
|
||||||
async for chunk in generator:
|
async for chunk, chunk_data in generator:
|
||||||
if chunk: # Skip empty chunks
|
if chunk: # Skip empty chunks
|
||||||
await temp_writer.write(chunk)
|
await temp_writer.write(chunk)
|
||||||
yield chunk
|
if return_json:
|
||||||
|
yield chunk, chunk_data
|
||||||
|
else:
|
||||||
|
yield chunk
|
||||||
|
|
||||||
# Finalize the temp file
|
# Finalize the temp file
|
||||||
await temp_writer.finalize()
|
await temp_writer.finalize()
|
||||||
|
|
|
@ -53,7 +53,8 @@ class TTSService:
|
||||||
is_last: bool = False,
|
is_last: bool = False,
|
||||||
normalizer: Optional[AudioNormalizer] = None,
|
normalizer: Optional[AudioNormalizer] = None,
|
||||||
lang_code: Optional[str] = None,
|
lang_code: Optional[str] = None,
|
||||||
) -> AsyncGenerator[Union[np.ndarray, bytes], None]:
|
return_timestamps: Optional[bool] = False,
|
||||||
|
) -> AsyncGenerator[Tuple[Union[np.ndarray, bytes],AudioChunk], Tuple[None,None]]:
|
||||||
"""Process tokens into audio."""
|
"""Process tokens into audio."""
|
||||||
async with self._chunk_semaphore:
|
async with self._chunk_semaphore:
|
||||||
try:
|
try:
|
||||||
|
@ -63,7 +64,7 @@ class TTSService:
|
||||||
if not output_format:
|
if not output_format:
|
||||||
yield np.array([], dtype=np.float32)
|
yield np.array([], dtype=np.float32)
|
||||||
return
|
return
|
||||||
result, _ = await AudioService.convert_audio(
|
result, chunk_data = await AudioService.convert_audio(
|
||||||
AudioChunk(np.array([0], dtype=np.float32)), # Dummy data for type checking
|
AudioChunk(np.array([0], dtype=np.float32)), # Dummy data for type checking
|
||||||
24000,
|
24000,
|
||||||
output_format,
|
output_format,
|
||||||
|
@ -73,7 +74,7 @@ class TTSService:
|
||||||
normalizer=normalizer,
|
normalizer=normalizer,
|
||||||
is_last_chunk=True,
|
is_last_chunk=True,
|
||||||
)
|
)
|
||||||
yield result
|
yield result, chunk_data
|
||||||
return
|
return
|
||||||
|
|
||||||
# Skip empty chunks
|
# Skip empty chunks
|
||||||
|
@ -91,7 +92,7 @@ class TTSService:
|
||||||
(voice_name, voice_path),
|
(voice_name, voice_path),
|
||||||
speed=speed,
|
speed=speed,
|
||||||
lang_code=lang_code,
|
lang_code=lang_code,
|
||||||
return_timestamps=True,
|
return_timestamps=return_timestamps,
|
||||||
):
|
):
|
||||||
# For streaming, convert to bytes
|
# For streaming, convert to bytes
|
||||||
if output_format:
|
if output_format:
|
||||||
|
@ -106,8 +107,7 @@ class TTSService:
|
||||||
is_last_chunk=is_last,
|
is_last_chunk=is_last,
|
||||||
normalizer=normalizer,
|
normalizer=normalizer,
|
||||||
)
|
)
|
||||||
print(chunk_data.word_timestamps)
|
yield converted, chunk_data
|
||||||
yield converted
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to convert audio: {str(e)}")
|
logger.error(f"Failed to convert audio: {str(e)}")
|
||||||
else:
|
else:
|
||||||
|
@ -116,31 +116,30 @@ class TTSService:
|
||||||
speed,
|
speed,
|
||||||
is_last,
|
is_last,
|
||||||
normalizer)
|
normalizer)
|
||||||
print(chunk_data.word_timestamps)
|
yield chunk_data.audio, chunk_data
|
||||||
yield chunk_data.audio
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
# For legacy backends, load voice tensor
|
# For legacy backends, load voice tensor
|
||||||
voice_tensor = await self._voice_manager.load_voice(
|
voice_tensor = await self._voice_manager.load_voice(
|
||||||
voice_name, device=backend.device
|
voice_name, device=backend.device
|
||||||
)
|
)
|
||||||
chunk_audio = await self.model_manager.generate(
|
chunk_data = await self.model_manager.generate(
|
||||||
tokens, voice_tensor, speed=speed
|
tokens, voice_tensor, speed=speed, return_timestamps=return_timestamps
|
||||||
)
|
)
|
||||||
|
|
||||||
if chunk_audio is None:
|
if chunk_data.audio is None:
|
||||||
logger.error("Model generated None for audio chunk")
|
logger.error("Model generated None for audio chunk")
|
||||||
return
|
return
|
||||||
|
|
||||||
if len(chunk_audio) == 0:
|
if len(chunk_data.audio) == 0:
|
||||||
logger.error("Model generated empty audio chunk")
|
logger.error("Model generated empty audio chunk")
|
||||||
return
|
return
|
||||||
|
|
||||||
# For streaming, convert to bytes
|
# For streaming, convert to bytes
|
||||||
if output_format:
|
if output_format:
|
||||||
try:
|
try:
|
||||||
converted = await AudioService.convert_audio(
|
converted, chunk_data = await AudioService.convert_audio(
|
||||||
chunk_audio,
|
chunk_data,
|
||||||
24000,
|
24000,
|
||||||
output_format,
|
output_format,
|
||||||
speed,
|
speed,
|
||||||
|
@ -149,16 +148,16 @@ class TTSService:
|
||||||
normalizer=normalizer,
|
normalizer=normalizer,
|
||||||
is_last_chunk=is_last,
|
is_last_chunk=is_last,
|
||||||
)
|
)
|
||||||
yield converted
|
yield converted, chunk_data
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to convert audio: {str(e)}")
|
logger.error(f"Failed to convert audio: {str(e)}")
|
||||||
else:
|
else:
|
||||||
trimmed = await AudioService.trim_audio(chunk_audio,
|
trimmed = await AudioService.trim_audio(chunk_data,
|
||||||
chunk_text,
|
chunk_text,
|
||||||
speed,
|
speed,
|
||||||
is_last,
|
is_last,
|
||||||
normalizer)
|
normalizer)
|
||||||
yield trimmed
|
yield trimmed.audio, trimmed
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to process tokens: {str(e)}")
|
logger.error(f"Failed to process tokens: {str(e)}")
|
||||||
|
|
||||||
|
@ -242,8 +241,9 @@ class TTSService:
|
||||||
speed: float = 1.0,
|
speed: float = 1.0,
|
||||||
output_format: str = "wav",
|
output_format: str = "wav",
|
||||||
lang_code: Optional[str] = None,
|
lang_code: Optional[str] = None,
|
||||||
normalization_options: Optional[NormalizationOptions] = NormalizationOptions()
|
normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
|
||||||
) -> AsyncGenerator[bytes, None]:
|
return_timestamps: Optional[bool] = False,
|
||||||
|
) -> AsyncGenerator[Tuple[bytes,AudioChunk], None]:
|
||||||
"""Generate and stream audio chunks."""
|
"""Generate and stream audio chunks."""
|
||||||
stream_normalizer = AudioNormalizer()
|
stream_normalizer = AudioNormalizer()
|
||||||
chunk_index = 0
|
chunk_index = 0
|
||||||
|
@ -266,7 +266,7 @@ class TTSService:
|
||||||
async for chunk_text, tokens in smart_split(text,normalization_options=normalization_options):
|
async for chunk_text, tokens in smart_split(text,normalization_options=normalization_options):
|
||||||
try:
|
try:
|
||||||
# Process audio for chunk
|
# Process audio for chunk
|
||||||
async for result in self._process_chunk(
|
async for result, chunk_data in self._process_chunk(
|
||||||
chunk_text, # Pass text for Kokoro V1
|
chunk_text, # Pass text for Kokoro V1
|
||||||
tokens, # Pass tokens for legacy backends
|
tokens, # Pass tokens for legacy backends
|
||||||
voice_name, # Pass voice name
|
voice_name, # Pass voice name
|
||||||
|
@ -279,7 +279,7 @@ class TTSService:
|
||||||
lang_code=pipeline_lang_code, # Pass lang_code
|
lang_code=pipeline_lang_code, # Pass lang_code
|
||||||
):
|
):
|
||||||
if result is not None:
|
if result is not None:
|
||||||
yield result
|
yield result,chunk_data
|
||||||
chunk_index += 1
|
chunk_index += 1
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
@ -296,7 +296,7 @@ class TTSService:
|
||||||
if chunk_index > 0:
|
if chunk_index > 0:
|
||||||
try:
|
try:
|
||||||
# Empty tokens list to finalize audio
|
# Empty tokens list to finalize audio
|
||||||
async for result in self._process_chunk(
|
async for result,chunk_data in self._process_chunk(
|
||||||
"", # Empty text
|
"", # Empty text
|
||||||
[], # Empty tokens
|
[], # Empty tokens
|
||||||
voice_name,
|
voice_name,
|
||||||
|
@ -309,7 +309,7 @@ class TTSService:
|
||||||
lang_code=pipeline_lang_code, # Pass lang_code
|
lang_code=pipeline_lang_code, # Pass lang_code
|
||||||
):
|
):
|
||||||
if result is not None:
|
if result is not None:
|
||||||
yield result
|
yield result, chunk_data
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to finalize audio stream: {str(e)}")
|
logger.error(f"Failed to finalize audio stream: {str(e)}")
|
||||||
|
|
||||||
|
@ -325,7 +325,7 @@ class TTSService:
|
||||||
speed: float = 1.0,
|
speed: float = 1.0,
|
||||||
return_timestamps: bool = False,
|
return_timestamps: bool = False,
|
||||||
lang_code: Optional[str] = None,
|
lang_code: Optional[str] = None,
|
||||||
) -> Union[Tuple[np.ndarray, float], Tuple[np.ndarray, float, List[dict]]]:
|
) -> Tuple[Tuple[np.ndarray,AudioChunk]]:
|
||||||
"""Generate complete audio for text using streaming internally."""
|
"""Generate complete audio for text using streaming internally."""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
chunks = []
|
chunks = []
|
||||||
|
|
Loading…
Add table
Reference in a new issue