Fireblade 2025-02-11 22:32:10 -05:00
parent da1e280805
commit 45cdb607e6
4 changed files with 110 additions and 26 deletions


@@ -1,11 +1,20 @@
"""Base interface for Kokoro inference."""
from abc import ABC, abstractmethod
from typing import AsyncGenerator, Optional, Tuple, Union
from typing import AsyncGenerator, Optional, Tuple, Union, List
import numpy as np
import torch
class AudioChunk:
    """Class for audio chunks returned by model backends"""

    def __init__(self,
                 audio: np.ndarray,
                 word_timestamps: Optional[List] = None
                 ):
        self.audio = audio
        self.word_timestamps = word_timestamps


class ModelBackend(ABC):
    """Abstract base class for model inference backend."""
@@ -28,7 +37,7 @@ class ModelBackend(ABC):
        text: str,
        voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
        speed: float = 1.0,
    ) -> AsyncGenerator[np.ndarray, None]:
    ) -> AsyncGenerator[AudioChunk, None]:
        """Generate audio from text.

        Args:

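For orientation, a minimal sketch (not part of this commit) of how a caller might consume the AudioChunk objects the backend now yields; the backend instance, text, and voice values are assumptions:

# Hypothetical consumer of the AudioChunk-yielding generator (sketch only)
async def collect_audio(backend: ModelBackend, text: str, voice: str):
    parts, words = [], []
    async for chunk in backend.generate(text, voice, speed=1.0):
        parts.append(chunk.audio)          # np.ndarray of samples for this chunk
        if chunk.word_timestamps:          # None unless the backend filled them in
            words.extend(chunk.word_timestamps)
    return np.concatenate(parts), words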

@@ -12,7 +12,7 @@ from ..core import paths
from ..core.config import settings
from ..core.model_config import model_config
from .base import BaseModelBackend
from .base import AudioChunk
class KokoroV1(BaseModelBackend):
    """Kokoro backend with controlled resource management."""
@@ -181,7 +181,8 @@ class KokoroV1(BaseModelBackend):
        voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
        speed: float = 1.0,
        lang_code: Optional[str] = None,
    ) -> AsyncGenerator[np.ndarray, None]:
        return_timestamps: Optional[bool] = False,
    ) -> AsyncGenerator[AudioChunk, None]:
        """Generate audio using model.

        Args:
@@ -249,7 +250,64 @@
        ):
            if result.audio is not None:
                logger.debug(f"Got audio chunk with shape: {result.audio.shape}")
                yield result.audio.numpy()
                word_timestamps = None
                if return_timestamps and hasattr(result, "tokens") and result.tokens:
                    word_timestamps = []
                    current_offset = 0.0
                    logger.debug(
                        f"Processing chunk timestamps with {len(result.tokens)} tokens"
                    )
                    if result.pred_dur is not None:
                        try:
                            # Join timestamps for this chunk's tokens
                            KPipeline.join_timestamps(
                                result.tokens, result.pred_dur
                            )
                            # Add timestamps with offset
                            for token in result.tokens:
                                if not all(
                                    hasattr(token, attr)
                                    for attr in [
                                        "text",
                                        "start_ts",
                                        "end_ts",
                                    ]
                                ):
                                    continue
                                if not token.text or not token.text.strip():
                                    continue
                                start_time = float(token.start_ts) + current_offset
                                end_time = float(token.end_ts) + current_offset
                                word_timestamps.append(
                                    {
                                        "word": str(token.text).strip(),
                                        "start_time": start_time,
                                        "end_time": end_time,
                                    }
                                )
                                logger.debug(
                                    f"Added timestamp for word '{token.text}': {start_time:.3f}s - {end_time:.3f}s"
                                )
                            # Update offset for next chunk based on pred_dur
                            chunk_duration = (
                                float(result.pred_dur.sum()) / 80
                            )  # Convert frames to seconds
                            current_offset = max(
                                current_offset + chunk_duration, end_time
                            )
                            logger.debug(
                                f"Updated time offset to {current_offset:.3f}s"
                            )
                        except Exception as e:
                            logger.error(
                                f"Failed to process timestamps for chunk: {e}"
                            )
                yield AudioChunk(result.audio.numpy(), word_timestamps=word_timestamps)
            else:
                logger.warning("No audio in chunk")

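As a rough illustration of the offset bookkeeping in the chunk loop above (simplified; the /80 frames-per-second divisor is taken from the code), with made-up durations:

# Toy walk-through of the chunk-offset arithmetic above (hypothetical numbers)
current_offset = 0.0
for pred_dur_sum in (160, 240):            # summed duration frames for two successive chunks
    # word start_ts/end_ts within this chunk get current_offset added to them
    chunk_duration = pred_dur_sum / 80     # frames -> seconds: 2.0s, then 3.0s
    current_offset += chunk_duration       # the next chunk's words start 2.0s, then 5.0s in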

@@ -1,6 +1,8 @@
"""Audio conversion service"""
import struct
import time
from typing import Tuple
from io import BytesIO
import numpy as np
@@ -13,7 +15,7 @@ from torch import norm
from ..core.config import settings
from .streaming_audio_writer import StreamingAudioWriter
from ..inference.base import AudioChunk
class AudioNormalizer:
    """Handles audio normalization state for a single stream"""
@@ -113,7 +115,7 @@ class AudioService:
    @staticmethod
    async def convert_audio(
        audio_data: np.ndarray,
        audio_chunk: AudioChunk,
        sample_rate: int,
        output_format: str,
        speed: float = 1,
@@ -121,7 +123,7 @@
        is_first_chunk: bool = True,
        is_last_chunk: bool = False,
        normalizer: AudioNormalizer = None,
    ) -> bytes:
    ) -> Tuple[bytes, AudioChunk]:
        """Convert audio data to specified format with streaming support

        Args:
@@ -137,6 +139,7 @@
        Returns:
            Tuple of the converted audio bytes and the updated AudioChunk
        """
        try:
            # Validate format
            if output_format not in AudioService.SUPPORTED_FORMATS:
@@ -146,9 +149,11 @@
            if normalizer is None:
                normalizer = AudioNormalizer()

            normalized_audio = await normalizer.normalize(audio_data)
            normalized_audio = AudioService.trim_audio(normalized_audio,chunk_text,speed,is_last_chunk,normalizer)
            audio_chunk.audio = await normalizer.normalize(audio_chunk.audio)
            audio_chunk = AudioService.trim_audio(audio_chunk, chunk_text, speed, is_last_chunk, normalizer)

            # Get or create format-specific writer
            writer_key = f"{output_format}_{sample_rate}"
            if is_first_chunk or writer_key not in AudioService._writers:
@@ -156,14 +161,16 @@
                    output_format, sample_rate
                )
            writer = AudioService._writers[writer_key]

            # Write audio data first
            if len(normalized_audio) > 0:
                chunk_data = writer.write_chunk(normalized_audio)
            if len(audio_chunk.audio) > 0:
                chunk_data = writer.write_chunk(audio_chunk.audio)

            # Then finalize if this is the last chunk
            if is_last_chunk:
                final_data = writer.write_chunk(finalize=True)
                del AudioService._writers[writer_key]
                return final_data if final_data else b"", audio_chunk
@@ -175,7 +182,7 @@
                f"Failed to convert audio stream to {output_format}: {str(e)}"
            )
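A minimal sketch (not part of this commit) of the updated convert_audio contract, where the caller now receives both the encoded bytes and the processed AudioChunk; the sample values and the chunk_text keyword are assumptions, since part of the signature is outside this hunk:

# Hypothetical caller of the updated convert_audio (values and chunk_text keyword are assumptions)
async def encode_silence() -> bytes:
    samples = np.zeros(24000, dtype=np.float32)        # one second of 24 kHz silence
    chunk = AudioChunk(samples)                         # word_timestamps defaults to None
    audio_bytes, chunk = await AudioService.convert_audio(
        chunk, 24000, "wav", chunk_text="", is_first_chunk=True, is_last_chunk=False
    )
    return audio_bytes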
    @staticmethod
    def trim_audio(audio_data: np.ndarray, chunk_text: str = "", speed: float = 1, is_last_chunk: bool = False, normalizer: AudioNormalizer = None) -> np.ndarray:
    def trim_audio(audio_chunk: AudioChunk, chunk_text: str = "", speed: float = 1, is_last_chunk: bool = False, normalizer: AudioNormalizer = None) -> AudioChunk:
        """Trim silence from start and end

        Args:
@@ -192,9 +199,15 @@
            normalizer = AudioNormalizer()

        # Trim start and end if enough samples
        if len(audio_data) > (2 * normalizer.samples_to_trim):
            audio_data = audio_data[normalizer.samples_to_trim : -normalizer.samples_to_trim]
        if len(audio_chunk.audio) > (2 * normalizer.samples_to_trim):
            audio_chunk.audio = audio_chunk.audio[normalizer.samples_to_trim : -normalizer.samples_to_trim]

        # Find non silent portion and trim
        start_index,end_index=normalizer.find_first_last_non_silent(audio_data,chunk_text,speed,is_last_chunk=is_last_chunk)
        return audio_data[start_index:end_index]
        start_index, end_index = normalizer.find_first_last_non_silent(audio_chunk.audio, chunk_text, speed, is_last_chunk=is_last_chunk)
        audio_chunk.audio = audio_chunk.audio[start_index:end_index]

        # Shift word timestamps (seconds) back by the trimmed lead-in; audio is 24 kHz samples
        if audio_chunk.word_timestamps:
            for timestamp in audio_chunk.word_timestamps:
                timestamp["start_time"] -= start_index / 24000
                timestamp["end_time"] -= start_index / 24000
        return audio_chunk
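Since the trim above removes start_index samples of 24 kHz audio while word timestamps are stored in seconds, the shift applied to each timestamp is start_index / 24000; a toy check with made-up values:

# Toy check of the sample-to-seconds shift applied above (hypothetical values)
sample_rate = 24000
start_index = 12000                          # samples trimmed from the front of the chunk
shift = start_index / sample_rate            # 0.5 seconds
word = {"word": "hello", "start_time": 0.8, "end_time": 1.1}
word["start_time"] -= shift                  # 0.3
word["end_time"] -= shift                    # 0.6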


@@ -86,17 +86,18 @@ class TTSService:
        # Generate audio using pre-warmed model
        if isinstance(backend, KokoroV1):
            # For Kokoro V1, pass text and voice info with lang_code
            async for chunk_audio in self.model_manager.generate(
            async for chunk_data in self.model_manager.generate(
                chunk_text,
                (voice_name, voice_path),
                speed=speed,
                lang_code=lang_code,
                return_timestamps=True,
            ):
                # For streaming, convert to bytes
                if output_format:
                    try:
                        converted = await AudioService.convert_audio(
                            chunk_audio,
                        converted, chunk_data = await AudioService.convert_audio(
                            chunk_data,
                            24000,
                            output_format,
                            speed,
@@ -105,17 +106,20 @@
                            is_last_chunk=is_last,
                            normalizer=normalizer,
                        )
                        logger.debug(f"Chunk word timestamps: {chunk_data.word_timestamps}")
                        yield converted
                    except Exception as e:
                        logger.error(f"Failed to convert audio: {str(e)}")
                else:
                    trimmed = await AudioService.trim_audio(chunk_audio,
                    chunk_data = AudioService.trim_audio(chunk_data,
                        chunk_text,
                        speed,
                        is_last,
                        normalizer)
                    yield trimmed
                    logger.debug(f"Chunk word timestamps: {chunk_data.word_timestamps}")
                    yield chunk_data.audio
        else:
            logger.debug("Using legacy backend")
            # For legacy backends, load voice tensor
            voice_tensor = await self._voice_manager.load_voice(
                voice_name, device=backend.device
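Downstream of the service, each AudioChunk carries word_timestamps as plain dicts with word, start_time, and end_time keys; a hypothetical post-processing helper (not part of this commit) that renders them as caption lines:

# Hypothetical formatting of collected word timestamps (sketch only)
def to_caption_lines(word_timestamps: list) -> list:
    lines = []
    for ts in word_timestamps:
        lines.append(f'{ts["start_time"]:.2f}-{ts["end_time"]:.2f}\t{ts["word"]}')
    return lines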