mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00
WIP
This commit is contained in:
parent
da1e280805
commit
45cdb607e6
4 changed files with 110 additions and 26 deletions
|
@ -1,11 +1,20 @@
|
|||
"""Base interface for Kokoro inference."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import AsyncGenerator, Optional, Tuple, Union
|
||||
from typing import AsyncGenerator, Optional, Tuple, Union, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
class AudioChunk:
|
||||
"""Class for audio chunks returned by model backends"""
|
||||
|
||||
def __init__(self,
|
||||
audio: np.ndarray,
|
||||
word_timestamps: Optional[List]=None
|
||||
):
|
||||
self.audio=audio
|
||||
self.word_timestamps=word_timestamps
|
||||
|
||||
class ModelBackend(ABC):
|
||||
"""Abstract base class for model inference backend."""
|
||||
|
@ -28,7 +37,7 @@ class ModelBackend(ABC):
|
|||
text: str,
|
||||
voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
|
||||
speed: float = 1.0,
|
||||
) -> AsyncGenerator[np.ndarray, None]:
|
||||
) -> AsyncGenerator[AudioChunk, None]:
|
||||
"""Generate audio from text.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -12,7 +12,7 @@ from ..core import paths
|
|||
from ..core.config import settings
|
||||
from ..core.model_config import model_config
|
||||
from .base import BaseModelBackend
|
||||
|
||||
from .base import AudioChunk
|
||||
|
||||
class KokoroV1(BaseModelBackend):
|
||||
"""Kokoro backend with controlled resource management."""
|
||||
|
@ -181,7 +181,8 @@ class KokoroV1(BaseModelBackend):
|
|||
voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
|
||||
speed: float = 1.0,
|
||||
lang_code: Optional[str] = None,
|
||||
) -> AsyncGenerator[np.ndarray, None]:
|
||||
return_timestamps: Optional[bool] = False,
|
||||
) -> AsyncGenerator[AudioChunk, None]:
|
||||
"""Generate audio using model.
|
||||
|
||||
Args:
|
||||
|
@ -249,7 +250,64 @@ class KokoroV1(BaseModelBackend):
|
|||
):
|
||||
if result.audio is not None:
|
||||
logger.debug(f"Got audio chunk with shape: {result.audio.shape}")
|
||||
yield result.audio.numpy()
|
||||
word_timestamps=None
|
||||
if return_timestamps and hasattr(result, "tokens") and result.tokens:
|
||||
word_timestamps=[]
|
||||
current_offset=0.0
|
||||
logger.debug(
|
||||
f"Processing chunk timestamps with {len(result.tokens)} tokens"
|
||||
)
|
||||
if result.pred_dur is not None:
|
||||
try:
|
||||
# Join timestamps for this chunk's tokens
|
||||
KPipeline.join_timestamps(
|
||||
result.tokens, result.pred_dur
|
||||
)
|
||||
|
||||
# Add timestamps with offset
|
||||
for token in result.tokens:
|
||||
if not all(
|
||||
hasattr(token, attr)
|
||||
for attr in [
|
||||
"text",
|
||||
"start_ts",
|
||||
"end_ts",
|
||||
]
|
||||
):
|
||||
continue
|
||||
if not token.text or not token.text.strip():
|
||||
continue
|
||||
|
||||
start_time = float(token.start_ts) + current_offset
|
||||
end_time = float(token.end_ts) + current_offset
|
||||
word_timestamps.append(
|
||||
{
|
||||
"word": str(token.text).strip(),
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
}
|
||||
)
|
||||
logger.debug(
|
||||
f"Added timestamp for word '{token.text}': {start_time:.3f}s - {end_time:.3f}s"
|
||||
)
|
||||
|
||||
# Update offset for next chunk based on pred_dur
|
||||
chunk_duration = (
|
||||
float(result.pred_dur.sum()) / 80
|
||||
) # Convert frames to seconds
|
||||
current_offset = max(
|
||||
current_offset + chunk_duration, end_time
|
||||
)
|
||||
logger.debug(
|
||||
f"Updated time offset to {current_offset:.3f}s"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to process timestamps for chunk: {e}"
|
||||
)
|
||||
|
||||
yield AudioChunk(result.audio.numpy(),word_timestamps=word_timestamps)
|
||||
else:
|
||||
logger.warning("No audio in chunk")
|
||||
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
"""Audio conversion service"""
|
||||
|
||||
import struct
|
||||
import time
|
||||
from typing import Tuple
|
||||
from io import BytesIO
|
||||
|
||||
import numpy as np
|
||||
|
@ -13,7 +15,7 @@ from torch import norm
|
|||
|
||||
from ..core.config import settings
|
||||
from .streaming_audio_writer import StreamingAudioWriter
|
||||
|
||||
from ..inference.base import AudioChunk
|
||||
|
||||
class AudioNormalizer:
|
||||
"""Handles audio normalization state for a single stream"""
|
||||
|
@ -113,7 +115,7 @@ class AudioService:
|
|||
|
||||
@staticmethod
|
||||
async def convert_audio(
|
||||
audio_data: np.ndarray,
|
||||
audio_chunk: AudioChunk,
|
||||
sample_rate: int,
|
||||
output_format: str,
|
||||
speed: float = 1,
|
||||
|
@ -121,7 +123,7 @@ class AudioService:
|
|||
is_first_chunk: bool = True,
|
||||
is_last_chunk: bool = False,
|
||||
normalizer: AudioNormalizer = None,
|
||||
) -> bytes:
|
||||
) -> Tuple[bytes,AudioChunk]:
|
||||
"""Convert audio data to specified format with streaming support
|
||||
|
||||
Args:
|
||||
|
@ -137,6 +139,7 @@ class AudioService:
|
|||
Returns:
|
||||
Bytes of the converted audio chunk
|
||||
"""
|
||||
|
||||
try:
|
||||
# Validate format
|
||||
if output_format not in AudioService.SUPPORTED_FORMATS:
|
||||
|
@ -146,9 +149,11 @@ class AudioService:
|
|||
if normalizer is None:
|
||||
normalizer = AudioNormalizer()
|
||||
|
||||
normalized_audio = await normalizer.normalize(audio_data)
|
||||
normalized_audio = AudioService.trim_audio(normalized_audio,chunk_text,speed,is_last_chunk,normalizer)
|
||||
|
||||
print("1")
|
||||
audio_chunk.audio = await normalizer.normalize(audio_chunk.audio)
|
||||
print("2")
|
||||
audio_chunk = AudioService.trim_audio(audio_chunk,chunk_text,speed,is_last_chunk,normalizer)
|
||||
print("3")
|
||||
# Get or create format-specific writer
|
||||
writer_key = f"{output_format}_{sample_rate}"
|
||||
if is_first_chunk or writer_key not in AudioService._writers:
|
||||
|
@ -156,14 +161,16 @@ class AudioService:
|
|||
output_format, sample_rate
|
||||
)
|
||||
writer = AudioService._writers[writer_key]
|
||||
|
||||
print("4")
|
||||
# Write audio data first
|
||||
if len(normalized_audio) > 0:
|
||||
chunk_data = writer.write_chunk(normalized_audio)
|
||||
|
||||
if len(audio_chunk.audio) > 0:
|
||||
chunk_data = writer.write_chunk(audio_chunk.audio)
|
||||
print("5")
|
||||
# Then finalize if this is the last chunk
|
||||
if is_last_chunk:
|
||||
print("6")
|
||||
final_data = writer.write_chunk(finalize=True)
|
||||
print("7")
|
||||
del AudioService._writers[writer_key]
|
||||
return final_data if final_data else b""
|
||||
|
||||
|
@ -175,7 +182,7 @@ class AudioService:
|
|||
f"Failed to convert audio stream to {output_format}: {str(e)}"
|
||||
)
|
||||
@staticmethod
|
||||
def trim_audio(audio_data: np.ndarray, chunk_text: str = "", speed: float = 1, is_last_chunk: bool = False, normalizer: AudioNormalizer = None) -> np.ndarray:
|
||||
def trim_audio(audio_chunk: AudioChunk, chunk_text: str = "", speed: float = 1, is_last_chunk: bool = False, normalizer: AudioNormalizer = None) -> AudioChunk:
|
||||
"""Trim silence from start and end
|
||||
|
||||
Args:
|
||||
|
@ -192,9 +199,15 @@ class AudioService:
|
|||
normalizer = AudioNormalizer()
|
||||
|
||||
# Trim start and end if enough samples
|
||||
if len(audio_data) > (2 * normalizer.samples_to_trim):
|
||||
audio_data = audio_data[normalizer.samples_to_trim : -normalizer.samples_to_trim]
|
||||
if len(audio_chunk.audio) > (2 * normalizer.samples_to_trim):
|
||||
audio_chunk.audio = audio_chunk.audio[normalizer.samples_to_trim : -normalizer.samples_to_trim]
|
||||
|
||||
# Find non silent portion and trim
|
||||
start_index,end_index=normalizer.find_first_last_non_silent(audio_data,chunk_text,speed,is_last_chunk=is_last_chunk)
|
||||
return audio_data[start_index:end_index]
|
||||
start_index,end_index=normalizer.find_first_last_non_silent(audio_chunk.audio,chunk_text,speed,is_last_chunk=is_last_chunk)
|
||||
|
||||
audio_chunk.audio=audio_chunk.audio[start_index:end_index]
|
||||
for timestamp in audio_chunk.word_timestamps:
|
||||
timestamp["start_time"]-=start_index * 24000
|
||||
timestamp["end_time"]-=start_index * 24000
|
||||
return audio_chunk
|
||||
|
|
@ -86,17 +86,18 @@ class TTSService:
|
|||
# Generate audio using pre-warmed model
|
||||
if isinstance(backend, KokoroV1):
|
||||
# For Kokoro V1, pass text and voice info with lang_code
|
||||
async for chunk_audio in self.model_manager.generate(
|
||||
async for chunk_data in self.model_manager.generate(
|
||||
chunk_text,
|
||||
(voice_name, voice_path),
|
||||
speed=speed,
|
||||
lang_code=lang_code,
|
||||
return_timestamps=True,
|
||||
):
|
||||
# For streaming, convert to bytes
|
||||
if output_format:
|
||||
try:
|
||||
converted = await AudioService.convert_audio(
|
||||
chunk_audio,
|
||||
converted, chunk_data = await AudioService.convert_audio(
|
||||
chunk_data,
|
||||
24000,
|
||||
output_format,
|
||||
speed,
|
||||
|
@ -105,17 +106,20 @@ class TTSService:
|
|||
is_last_chunk=is_last,
|
||||
normalizer=normalizer,
|
||||
)
|
||||
print(chunk_data.word_timestamps)
|
||||
yield converted
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert audio: {str(e)}")
|
||||
else:
|
||||
trimmed = await AudioService.trim_audio(chunk_audio,
|
||||
chunk_data = await AudioService.trim_audio(chunk_data,
|
||||
chunk_text,
|
||||
speed,
|
||||
is_last,
|
||||
normalizer)
|
||||
yield trimmed
|
||||
print(chunk_data.word_timestamps)
|
||||
yield chunk_data.audio
|
||||
else:
|
||||
print("old backend")
|
||||
# For legacy backends, load voice tensor
|
||||
voice_tensor = await self._voice_manager.load_voice(
|
||||
voice_name, device=backend.device
|
||||
|
|
Loading…
Add table
Reference in a new issue