diff --git a/api/src/inference/base.py b/api/src/inference/base.py
index 7f3cd3f..93fcf69 100644
--- a/api/src/inference/base.py
+++ b/api/src/inference/base.py
@@ -1,12 +1,21 @@
 """Base interface for Kokoro inference."""
 
 from abc import ABC, abstractmethod
-from typing import AsyncGenerator, Optional, Tuple, Union
+from typing import AsyncGenerator, Optional, Tuple, Union, List
 
 import numpy as np
 import torch
 
-
+class AudioChunk:
+    """Class for audio chunks returned by model backends."""
+
+    def __init__(self,
+                 audio: np.ndarray,
+                 word_timestamps: Optional[List] = None
+                 ):
+        self.audio = audio
+        self.word_timestamps = word_timestamps
+
 class ModelBackend(ABC):
     """Abstract base class for model inference backend."""
 
@@ -28,7 +37,7 @@ class ModelBackend(ABC):
         text: str,
         voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
         speed: float = 1.0,
-    ) -> AsyncGenerator[np.ndarray, None]:
+    ) -> AsyncGenerator[AudioChunk, None]:
         """Generate audio from text.
 
         Args:
diff --git a/api/src/inference/kokoro_v1.py b/api/src/inference/kokoro_v1.py
index 99e76fa..91fb44d 100644
--- a/api/src/inference/kokoro_v1.py
+++ b/api/src/inference/kokoro_v1.py
@@ -12,7 +12,7 @@ from ..core import paths
 from ..core.config import settings
 from ..core.model_config import model_config
 from .base import BaseModelBackend
-
+from .base import AudioChunk
 
 class KokoroV1(BaseModelBackend):
     """Kokoro backend with controlled resource management."""
@@ -181,7 +181,8 @@ class KokoroV1(BaseModelBackend):
         text: str,
         voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
         speed: float = 1.0,
         lang_code: Optional[str] = None,
-    ) -> AsyncGenerator[np.ndarray, None]:
+        return_timestamps: Optional[bool] = False,
+    ) -> AsyncGenerator[AudioChunk, None]:
         """Generate audio using model.
 
         Args:
@@ -249,7 +250,64 @@ class KokoroV1(BaseModelBackend):
            ):
                if result.audio is not None:
                    logger.debug(f"Got audio chunk with shape: {result.audio.shape}")
-                    yield result.audio.numpy()
+                    word_timestamps = None
+                    if return_timestamps and hasattr(result, "tokens") and result.tokens:
+                        word_timestamps = []
+                        current_offset = 0.0
+                        logger.debug(
+                            f"Processing chunk timestamps with {len(result.tokens)} tokens"
+                        )
+                        if result.pred_dur is not None:
+                            try:
+                                # Join timestamps for this chunk's tokens
+                                KPipeline.join_timestamps(
+                                    result.tokens, result.pred_dur
+                                )
+
+                                # Add timestamps with offset
+                                for token in result.tokens:
+                                    if not all(
+                                        hasattr(token, attr)
+                                        for attr in [
+                                            "text",
+                                            "start_ts",
+                                            "end_ts",
+                                        ]
+                                    ):
+                                        continue
+                                    if not token.text or not token.text.strip():
+                                        continue
+
+                                    start_time = float(token.start_ts) + current_offset
+                                    end_time = float(token.end_ts) + current_offset
+                                    word_timestamps.append(
+                                        {
+                                            "word": str(token.text).strip(),
+                                            "start_time": start_time,
+                                            "end_time": end_time,
+                                        }
+                                    )
+                                    logger.debug(
+                                        f"Added timestamp for word '{token.text}': {start_time:.3f}s - {end_time:.3f}s"
+                                    )
+
+                                # Update offset for next chunk based on pred_dur
+                                chunk_duration = (
+                                    float(result.pred_dur.sum()) / 80
+                                )  # Convert frames to seconds
+                                current_offset = max(
+                                    current_offset + chunk_duration, end_time
+                                )
+                                logger.debug(
+                                    f"Updated time offset to {current_offset:.3f}s"
+                                )
+
+                            except Exception as e:
+                                logger.error(
+                                    f"Failed to process timestamps for chunk: {e}"
+                                )
+
+                    yield AudioChunk(result.audio.numpy(), word_timestamps=word_timestamps)
                else:
                    logger.warning("No audio in chunk")
diff --git a/api/src/services/audio.py b/api/src/services/audio.py
index 64062b8..53c0ed4 100644
--- a/api/src/services/audio.py
+++ b/api/src/services/audio.py
@@ -1,6 +1,8 @@
"""Audio conversion service""" import struct +import time +from typing import Tuple from io import BytesIO import numpy as np @@ -13,7 +15,7 @@ from torch import norm from ..core.config import settings from .streaming_audio_writer import StreamingAudioWriter - +from ..inference.base import AudioChunk class AudioNormalizer: """Handles audio normalization state for a single stream""" @@ -113,7 +115,7 @@ class AudioService: @staticmethod async def convert_audio( - audio_data: np.ndarray, + audio_chunk: AudioChunk, sample_rate: int, output_format: str, speed: float = 1, @@ -121,7 +123,7 @@ class AudioService: is_first_chunk: bool = True, is_last_chunk: bool = False, normalizer: AudioNormalizer = None, - ) -> bytes: + ) -> Tuple[bytes,AudioChunk]: """Convert audio data to specified format with streaming support Args: @@ -137,6 +139,7 @@ class AudioService: Returns: Bytes of the converted audio chunk """ + try: # Validate format if output_format not in AudioService.SUPPORTED_FORMATS: @@ -145,10 +148,12 @@ class AudioService: # Always normalize audio to ensure proper amplitude scaling if normalizer is None: normalizer = AudioNormalizer() - - normalized_audio = await normalizer.normalize(audio_data) - normalized_audio = AudioService.trim_audio(normalized_audio,chunk_text,speed,is_last_chunk,normalizer) + print("1") + audio_chunk.audio = await normalizer.normalize(audio_chunk.audio) + print("2") + audio_chunk = AudioService.trim_audio(audio_chunk,chunk_text,speed,is_last_chunk,normalizer) + print("3") # Get or create format-specific writer writer_key = f"{output_format}_{sample_rate}" if is_first_chunk or writer_key not in AudioService._writers: @@ -156,14 +161,16 @@ class AudioService: output_format, sample_rate ) writer = AudioService._writers[writer_key] - + print("4") # Write audio data first - if len(normalized_audio) > 0: - chunk_data = writer.write_chunk(normalized_audio) - + if len(audio_chunk.audio) > 0: + chunk_data = writer.write_chunk(audio_chunk.audio) + print("5") # Then finalize if this is the last chunk if is_last_chunk: + print("6") final_data = writer.write_chunk(finalize=True) + print("7") del AudioService._writers[writer_key] return final_data if final_data else b"" @@ -175,7 +182,7 @@ class AudioService: f"Failed to convert audio stream to {output_format}: {str(e)}" ) @staticmethod - def trim_audio(audio_data: np.ndarray, chunk_text: str = "", speed: float = 1, is_last_chunk: bool = False, normalizer: AudioNormalizer = None) -> np.ndarray: + def trim_audio(audio_chunk: AudioChunk, chunk_text: str = "", speed: float = 1, is_last_chunk: bool = False, normalizer: AudioNormalizer = None) -> AudioChunk: """Trim silence from start and end Args: @@ -192,9 +199,15 @@ class AudioService: normalizer = AudioNormalizer() # Trim start and end if enough samples - if len(audio_data) > (2 * normalizer.samples_to_trim): - audio_data = audio_data[normalizer.samples_to_trim : -normalizer.samples_to_trim] + if len(audio_chunk.audio) > (2 * normalizer.samples_to_trim): + audio_chunk.audio = audio_chunk.audio[normalizer.samples_to_trim : -normalizer.samples_to_trim] # Find non silent portion and trim - start_index,end_index=normalizer.find_first_last_non_silent(audio_data,chunk_text,speed,is_last_chunk=is_last_chunk) - return audio_data[start_index:end_index] \ No newline at end of file + start_index,end_index=normalizer.find_first_last_non_silent(audio_chunk.audio,chunk_text,speed,is_last_chunk=is_last_chunk) + + audio_chunk.audio=audio_chunk.audio[start_index:end_index] + for timestamp in 
audio_chunk.word_timestamps: + timestamp["start_time"]-=start_index * 24000 + timestamp["end_time"]-=start_index * 24000 + return audio_chunk + \ No newline at end of file diff --git a/api/src/services/tts_service.py b/api/src/services/tts_service.py index ba4dcc4..67e2641 100644 --- a/api/src/services/tts_service.py +++ b/api/src/services/tts_service.py @@ -86,17 +86,18 @@ class TTSService: # Generate audio using pre-warmed model if isinstance(backend, KokoroV1): # For Kokoro V1, pass text and voice info with lang_code - async for chunk_audio in self.model_manager.generate( + async for chunk_data in self.model_manager.generate( chunk_text, (voice_name, voice_path), speed=speed, lang_code=lang_code, + return_timestamps=True, ): # For streaming, convert to bytes if output_format: try: - converted = await AudioService.convert_audio( - chunk_audio, + converted, chunk_data = await AudioService.convert_audio( + chunk_data, 24000, output_format, speed, @@ -105,17 +106,20 @@ class TTSService: is_last_chunk=is_last, normalizer=normalizer, ) + print(chunk_data.word_timestamps) yield converted except Exception as e: logger.error(f"Failed to convert audio: {str(e)}") else: - trimmed = await AudioService.trim_audio(chunk_audio, + chunk_data = await AudioService.trim_audio(chunk_data, chunk_text, speed, is_last, normalizer) - yield trimmed + print(chunk_data.word_timestamps) + yield chunk_data.audio else: + print("old backend") # For legacy backends, load voice tensor voice_tensor = await self._voice_manager.load_voice( voice_name, device=backend.device
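For reference, here is a minimal standalone sketch of how the new AudioChunk type and the trim/timestamp bookkeeping above are meant to fit together. It is illustrative only: the SAMPLE_RATE constant and the trim_chunk helper are hypothetical names introduced for this example and are not part of the diff.

# Illustrative sketch: mirrors the AudioChunk + trim-offset logic from this PR.
# SAMPLE_RATE and trim_chunk are assumptions for the example, not repo APIs.
from typing import List, Optional

import numpy as np

SAMPLE_RATE = 24000  # Kokoro output rate assumed by the services above


class AudioChunk:
    """Audio samples plus optional word-level timestamps, as in api/src/inference/base.py."""

    def __init__(self, audio: np.ndarray, word_timestamps: Optional[List] = None):
        self.audio = audio
        self.word_timestamps = word_timestamps


def trim_chunk(chunk: AudioChunk, start_index: int, end_index: int) -> AudioChunk:
    """Keep samples in [start_index, end_index) and shift word timestamps to match."""
    chunk.audio = chunk.audio[start_index:end_index]
    if chunk.word_timestamps is not None:
        offset_s = start_index / SAMPLE_RATE  # samples -> seconds
        for ts in chunk.word_timestamps:
            ts["start_time"] -= offset_s
            ts["end_time"] -= offset_s
    return chunk


if __name__ == "__main__":
    chunk = AudioChunk(
        audio=np.zeros(SAMPLE_RATE, dtype=np.float32),  # 1 second of silence
        word_timestamps=[{"word": "hello", "start_time": 0.30, "end_time": 0.55}],
    )
    # Trim 0.1 s from each end; the word's timestamps shift to roughly 0.20 - 0.45 s.
    trimmed = trim_chunk(chunk, start_index=2400, end_index=21600)
    print(trimmed.word_timestamps)

Running the example shows the timestamp shift staying in seconds, which is why the trim_audio change above divides start_index by the sample rate rather than multiplying.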