Fireblade 2025-02-11 22:32:10 -05:00
parent da1e280805
commit 45cdb607e6
4 changed files with 110 additions and 26 deletions

View file

@@ -1,12 +1,21 @@
 """Base interface for Kokoro inference."""
 from abc import ABC, abstractmethod
-from typing import AsyncGenerator, Optional, Tuple, Union
+from typing import AsyncGenerator, Optional, Tuple, Union, List
 import numpy as np
 import torch
+class AudioChunk:
+    """Class for audio chunks returned by model backends"""
+    def __init__(self,
+                 audio: np.ndarray,
+                 word_timestamps: Optional[List]=None
+                 ):
+        self.audio=audio
+        self.word_timestamps=word_timestamps
 class ModelBackend(ABC):
     """Abstract base class for model inference backend."""
@@ -28,7 +37,7 @@ class ModelBackend(ABC):
         text: str,
         voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
         speed: float = 1.0,
-    ) -> AsyncGenerator[np.ndarray, None]:
+    ) -> AsyncGenerator[AudioChunk, None]:
         """Generate audio from text.
         Args:

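For reference, a minimal sketch (not part of the commit) of how a caller might drain the AudioChunk stream that a ModelBackend.generate() implementation now yields. The helper name and the way the backend instance is obtained are assumptions.

import numpy as np

async def collect_audio(backend, text: str, voice: str):
    """Drain a backend's generate() stream into one array plus flat word timestamps."""
    chunks, timestamps = [], []
    async for chunk in backend.generate(text, voice, speed=1.0):
        chunks.append(chunk.audio)                    # np.ndarray of samples for this chunk
        if chunk.word_timestamps:                     # None unless the backend attached timestamps
            timestamps.extend(chunk.word_timestamps)  # dicts with word, start_time, end_time
    return np.concatenate(chunks), timestamps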
View file

@@ -12,7 +12,7 @@ from ..core import paths
 from ..core.config import settings
 from ..core.model_config import model_config
 from .base import BaseModelBackend
+from .base import AudioChunk
 class KokoroV1(BaseModelBackend):
     """Kokoro backend with controlled resource management."""
@@ -181,7 +181,8 @@ class KokoroV1(BaseModelBackend):
         voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
         speed: float = 1.0,
         lang_code: Optional[str] = None,
-    ) -> AsyncGenerator[np.ndarray, None]:
+        return_timestamps: Optional[bool] = False,
+    ) -> AsyncGenerator[AudioChunk, None]:
         """Generate audio using model.
         Args:
@@ -249,7 +250,64 @@ class KokoroV1(BaseModelBackend):
             ):
                 if result.audio is not None:
                     logger.debug(f"Got audio chunk with shape: {result.audio.shape}")
-                    yield result.audio.numpy()
+                    word_timestamps=None
+                    if return_timestamps and hasattr(result, "tokens") and result.tokens:
+                        word_timestamps=[]
+                        current_offset=0.0
+                        logger.debug(
+                            f"Processing chunk timestamps with {len(result.tokens)} tokens"
+                        )
+                        if result.pred_dur is not None:
+                            try:
+                                # Join timestamps for this chunk's tokens
+                                KPipeline.join_timestamps(
+                                    result.tokens, result.pred_dur
+                                )
+                                # Add timestamps with offset
+                                for token in result.tokens:
+                                    if not all(
+                                        hasattr(token, attr)
+                                        for attr in [
+                                            "text",
+                                            "start_ts",
+                                            "end_ts",
+                                        ]
+                                    ):
+                                        continue
+                                    if not token.text or not token.text.strip():
+                                        continue
+                                    start_time = float(token.start_ts) + current_offset
+                                    end_time = float(token.end_ts) + current_offset
+                                    word_timestamps.append(
+                                        {
+                                            "word": str(token.text).strip(),
+                                            "start_time": start_time,
+                                            "end_time": end_time,
+                                        }
+                                    )
+                                    logger.debug(
+                                        f"Added timestamp for word '{token.text}': {start_time:.3f}s - {end_time:.3f}s"
+                                    )
+                                # Update offset for next chunk based on pred_dur
+                                chunk_duration = (
+                                    float(result.pred_dur.sum()) / 80
+                                )  # Convert frames to seconds
+                                current_offset = max(
+                                    current_offset + chunk_duration, end_time
+                                )
+                                logger.debug(
+                                    f"Updated time offset to {current_offset:.3f}s"
+                                )
+                            except Exception as e:
+                                logger.error(
+                                    f"Failed to process timestamps for chunk: {e}"
+                                )
+                    yield AudioChunk(result.audio.numpy(),word_timestamps=word_timestamps)
                 else:
                     logger.warning("No audio in chunk")

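The offset bookkeeping above turns predicted duration frames into seconds and carries a running total across chunks, so each chunk's word timestamps are absolute within the stream. A minimal sketch of that arithmetic, assuming the same 80 frames-per-second figure the commit divides by; the helper is illustrative, and the real loop additionally takes the max with the last token's end time.

FRAMES_PER_SECOND = 80  # assumption carried over from the commit's "/ 80" conversion

def accumulate_offsets(chunk_frame_totals: list) -> list:
    """Return the absolute start offset, in seconds, of each chunk."""
    offsets, current = [], 0.0
    for frames in chunk_frame_totals:
        offsets.append(current)                 # this chunk starts where the previous one ended
        current += frames / FRAMES_PER_SECOND   # pred_dur.sum() / 80 in the code above
    return offsets

# Example: chunks of 400, 560 and 240 frames start at 0.0 s, 5.0 s and 12.0 s.
assert accumulate_offsets([400, 560, 240]) == [0.0, 5.0, 12.0]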
View file

@@ -1,6 +1,8 @@
 """Audio conversion service"""
 import struct
+import time
+from typing import Tuple
 from io import BytesIO
 import numpy as np
@@ -13,7 +15,7 @@ from torch import norm
 from ..core.config import settings
 from .streaming_audio_writer import StreamingAudioWriter
+from ..inference.base import AudioChunk
 class AudioNormalizer:
     """Handles audio normalization state for a single stream"""
@@ -113,7 +115,7 @@ class AudioService:
     @staticmethod
     async def convert_audio(
-        audio_data: np.ndarray,
+        audio_chunk: AudioChunk,
         sample_rate: int,
         output_format: str,
         speed: float = 1,
@@ -121,7 +123,7 @@ class AudioService:
         is_first_chunk: bool = True,
         is_last_chunk: bool = False,
         normalizer: AudioNormalizer = None,
-    ) -> bytes:
+    ) -> Tuple[bytes,AudioChunk]:
         """Convert audio data to specified format with streaming support
         Args:
@@ -137,6 +139,7 @@ class AudioService:
         Returns:
             Bytes of the converted audio chunk
         """
+
         try:
             # Validate format
             if output_format not in AudioService.SUPPORTED_FORMATS:
@@ -145,10 +148,12 @@ class AudioService:
             # Always normalize audio to ensure proper amplitude scaling
             if normalizer is None:
                 normalizer = AudioNormalizer()
-            normalized_audio = await normalizer.normalize(audio_data)
-            normalized_audio = AudioService.trim_audio(normalized_audio,chunk_text,speed,is_last_chunk,normalizer)
+            print("1")
+            audio_chunk.audio = await normalizer.normalize(audio_chunk.audio)
+            print("2")
+            audio_chunk = AudioService.trim_audio(audio_chunk,chunk_text,speed,is_last_chunk,normalizer)
+            print("3")
             # Get or create format-specific writer
             writer_key = f"{output_format}_{sample_rate}"
             if is_first_chunk or writer_key not in AudioService._writers:
@@ -156,14 +161,16 @@ class AudioService:
                     output_format, sample_rate
                 )
             writer = AudioService._writers[writer_key]
+            print("4")
             # Write audio data first
-            if len(normalized_audio) > 0:
-                chunk_data = writer.write_chunk(normalized_audio)
+            if len(audio_chunk.audio) > 0:
+                chunk_data = writer.write_chunk(audio_chunk.audio)
+            print("5")
             # Then finalize if this is the last chunk
             if is_last_chunk:
+                print("6")
                 final_data = writer.write_chunk(finalize=True)
+                print("7")
                 del AudioService._writers[writer_key]
                 return final_data if final_data else b""
@@ -175,7 +182,7 @@ class AudioService:
                 f"Failed to convert audio stream to {output_format}: {str(e)}"
             )
     @staticmethod
-    def trim_audio(audio_data: np.ndarray, chunk_text: str = "", speed: float = 1, is_last_chunk: bool = False, normalizer: AudioNormalizer = None) -> np.ndarray:
+    def trim_audio(audio_chunk: AudioChunk, chunk_text: str = "", speed: float = 1, is_last_chunk: bool = False, normalizer: AudioNormalizer = None) -> AudioChunk:
         """Trim silence from start and end
         Args:
@@ -192,9 +199,15 @@ class AudioService:
             normalizer = AudioNormalizer()
         # Trim start and end if enough samples
-        if len(audio_data) > (2 * normalizer.samples_to_trim):
-            audio_data = audio_data[normalizer.samples_to_trim : -normalizer.samples_to_trim]
+        if len(audio_chunk.audio) > (2 * normalizer.samples_to_trim):
+            audio_chunk.audio = audio_chunk.audio[normalizer.samples_to_trim : -normalizer.samples_to_trim]
         # Find non silent portion and trim
-        start_index,end_index=normalizer.find_first_last_non_silent(audio_data,chunk_text,speed,is_last_chunk=is_last_chunk)
-        return audio_data[start_index:end_index]
+        start_index,end_index=normalizer.find_first_last_non_silent(audio_chunk.audio,chunk_text,speed,is_last_chunk=is_last_chunk)
+        audio_chunk.audio=audio_chunk.audio[start_index:end_index]
+        for timestamp in audio_chunk.word_timestamps:
+            timestamp["start_time"]-=start_index * 24000
+            timestamp["end_time"]-=start_index * 24000
+        return audio_chunk

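A minimal usage sketch of the updated AudioService.convert_audio(), which now takes an AudioChunk and, per its new annotation, returns the encoded bytes together with the normalized and trimmed chunk. The wrapper function and the "mp3" format are illustrative assumptions; the keyword names follow the signature shown in the diff.

async def encode_chunk(audio_chunk, normalizer, is_first, is_last):
    converted_bytes, audio_chunk = await AudioService.convert_audio(
        audio_chunk,   # AudioChunk instead of a bare np.ndarray
        24000,         # sample_rate used throughout the service
        "mp3",         # output_format (illustrative)
        speed=1.0,
        is_first_chunk=is_first,
        is_last_chunk=is_last,
        normalizer=normalizer,
    )
    # The word timestamps ride along on the returned chunk after normalization and trimming.
    return converted_bytes, audio_chunk.word_timestamps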
View file

@@ -86,17 +86,18 @@ class TTSService:
             # Generate audio using pre-warmed model
             if isinstance(backend, KokoroV1):
                 # For Kokoro V1, pass text and voice info with lang_code
-                async for chunk_audio in self.model_manager.generate(
+                async for chunk_data in self.model_manager.generate(
                     chunk_text,
                     (voice_name, voice_path),
                     speed=speed,
                     lang_code=lang_code,
+                    return_timestamps=True,
                 ):
                     # For streaming, convert to bytes
                     if output_format:
                         try:
-                            converted = await AudioService.convert_audio(
-                                chunk_audio,
+                            converted, chunk_data = await AudioService.convert_audio(
+                                chunk_data,
                                 24000,
                                 output_format,
                                 speed,
@@ -105,17 +106,20 @@ class TTSService:
                                 is_last_chunk=is_last,
                                 normalizer=normalizer,
                             )
+                            print(chunk_data.word_timestamps)
                             yield converted
                         except Exception as e:
                             logger.error(f"Failed to convert audio: {str(e)}")
                     else:
-                        trimmed = await AudioService.trim_audio(chunk_audio,
+                        chunk_data = await AudioService.trim_audio(chunk_data,
                             chunk_text,
                             speed,
                             is_last,
                             normalizer)
-                        yield trimmed
+                        print(chunk_data.word_timestamps)
+                        yield chunk_data.audio
             else:
+                print("old backend")
                 # For legacy backends, load voice tensor
                 voice_tensor = await self._voice_manager.load_voice(
                     voice_name, device=backend.device
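
For orientation, the per-chunk timestamp payload now attached to each AudioChunk has this shape; the values below are illustrative, not output of the commit.

example_word_timestamps = [
    {"word": "Hello", "start_time": 0.000, "end_time": 0.312},  # seconds, per the backend's offset bookkeeping
    {"word": "world", "start_time": 0.350, "end_time": 0.781},
]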