mirror of https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00
WIP
This commit is contained in:
parent da1e280805
commit 45cdb607e6
4 changed files with 110 additions and 26 deletions

@@ -1,12 +1,21 @@
 """Base interface for Kokoro inference."""
 
 from abc import ABC, abstractmethod
-from typing import AsyncGenerator, Optional, Tuple, Union
+from typing import AsyncGenerator, Optional, Tuple, Union, List
 
 import numpy as np
 import torch
 
+class AudioChunk:
+    """Class for audio chunks returned by model backends"""
+
+    def __init__(self,
+                 audio: np.ndarray,
+                 word_timestamps: Optional[List]=None
+                 ):
+        self.audio=audio
+        self.word_timestamps=word_timestamps
 
 class ModelBackend(ABC):
     """Abstract base class for model inference backend."""

@@ -28,7 +37,7 @@ class ModelBackend(ABC):
         text: str,
         voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
         speed: float = 1.0,
-    ) -> AsyncGenerator[np.ndarray, None]:
+    ) -> AsyncGenerator[AudioChunk, None]:
         """Generate audio from text.
 
         Args:
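
For orientation, here is a minimal sketch (not part of this commit) of how a caller might consume the new AudioChunk objects from a backend's generate(); the timestamp dictionary keys follow the code in this commit, while the helper name and arguments are illustrative assumptions:

# Hypothetical consumer of AudioChunk objects; "backend" is any ModelBackend implementation.
async def collect_chunks(backend, text, voice):
    samples = []
    async for chunk in backend.generate(text, voice, speed=1.0):
        samples.append(chunk.audio)  # np.ndarray of raw samples
        for ts in chunk.word_timestamps or []:  # dicts with "word", "start_time", "end_time"
            print(ts["word"], ts["start_time"], ts["end_time"])
    return samples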

@@ -12,7 +12,7 @@ from ..core import paths
 from ..core.config import settings
 from ..core.model_config import model_config
 from .base import BaseModelBackend
+from .base import AudioChunk
 
 class KokoroV1(BaseModelBackend):
     """Kokoro backend with controlled resource management."""

@@ -181,7 +181,8 @@ class KokoroV1(BaseModelBackend):
         voice: Union[str, Tuple[str, Union[torch.Tensor, str]]],
         speed: float = 1.0,
         lang_code: Optional[str] = None,
-    ) -> AsyncGenerator[np.ndarray, None]:
+        return_timestamps: Optional[bool] = False,
+    ) -> AsyncGenerator[AudioChunk, None]:
         """Generate audio using model.
 
         Args:

@@ -249,7 +250,64 @@
            ):
                if result.audio is not None:
                    logger.debug(f"Got audio chunk with shape: {result.audio.shape}")
-                    yield result.audio.numpy()
+                    word_timestamps=None
+                    if return_timestamps and hasattr(result, "tokens") and result.tokens:
+                        word_timestamps=[]
+                        current_offset=0.0
+                        logger.debug(
+                            f"Processing chunk timestamps with {len(result.tokens)} tokens"
+                        )
+                        if result.pred_dur is not None:
+                            try:
+                                # Join timestamps for this chunk's tokens
+                                KPipeline.join_timestamps(
+                                    result.tokens, result.pred_dur
+                                )
+
+                                # Add timestamps with offset
+                                for token in result.tokens:
+                                    if not all(
+                                        hasattr(token, attr)
+                                        for attr in [
+                                            "text",
+                                            "start_ts",
+                                            "end_ts",
+                                        ]
+                                    ):
+                                        continue
+                                    if not token.text or not token.text.strip():
+                                        continue
+
+                                    start_time = float(token.start_ts) + current_offset
+                                    end_time = float(token.end_ts) + current_offset
+                                    word_timestamps.append(
+                                        {
+                                            "word": str(token.text).strip(),
+                                            "start_time": start_time,
+                                            "end_time": end_time,
+                                        }
+                                    )
+                                    logger.debug(
+                                        f"Added timestamp for word '{token.text}': {start_time:.3f}s - {end_time:.3f}s"
+                                    )
+
+                                # Update offset for next chunk based on pred_dur
+                                chunk_duration = (
+                                    float(result.pred_dur.sum()) / 80
+                                )  # Convert frames to seconds
+                                current_offset = max(
+                                    current_offset + chunk_duration, end_time
+                                )
+                                logger.debug(
+                                    f"Updated time offset to {current_offset:.3f}s"
+                                )
+
+                            except Exception as e:
+                                logger.error(
+                                    f"Failed to process timestamps for chunk: {e}"
+                                )
+
+                    yield AudioChunk(result.audio.numpy(),word_timestamps=word_timestamps)
                 else:
                     logger.warning("No audio in chunk")
 
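
A small worked example of the timestamp offset arithmetic above, assuming (as the code comment does) 80 duration frames per second; all numbers are made up for illustration:

# Chunk 1 predicts 192 duration frames -> 192 / 80 = 2.4 s of audio.
current_offset = 0.0
word_in_chunk1 = {"start_time": 0.5 + current_offset, "end_time": 0.9 + current_offset}  # 0.5-0.9 s

# After chunk 1 the offset advances by its duration (or the last end time, whichever is larger).
current_offset = max(current_offset + 192 / 80, word_in_chunk1["end_time"])  # 2.4

# The same local times in chunk 2 now land at 2.9-3.3 s on the global timeline.
word_in_chunk2 = {"start_time": 0.5 + current_offset, "end_time": 0.9 + current_offset}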

@@ -1,6 +1,8 @@
 """Audio conversion service"""
 
 import struct
+import time
+from typing import Tuple
 from io import BytesIO
 
 import numpy as np

@@ -13,7 +15,7 @@ from torch import norm
 
 from ..core.config import settings
 from .streaming_audio_writer import StreamingAudioWriter
+from ..inference.base import AudioChunk
 
 class AudioNormalizer:
     """Handles audio normalization state for a single stream"""

@@ -113,7 +115,7 @@ class AudioService:
 
     @staticmethod
     async def convert_audio(
-        audio_data: np.ndarray,
+        audio_chunk: AudioChunk,
         sample_rate: int,
         output_format: str,
         speed: float = 1,

@@ -121,7 +123,7 @@ class AudioService:
         is_first_chunk: bool = True,
         is_last_chunk: bool = False,
         normalizer: AudioNormalizer = None,
-    ) -> bytes:
+    ) -> Tuple[bytes,AudioChunk]:
         """Convert audio data to specified format with streaming support
 
         Args:

@@ -137,6 +139,7 @@
         Returns:
            Bytes of the converted audio chunk
        """
 
        try:
            # Validate format
            if output_format not in AudioService.SUPPORTED_FORMATS:

@@ -145,10 +148,12 @@
            # Always normalize audio to ensure proper amplitude scaling
            if normalizer is None:
                normalizer = AudioNormalizer()
 
-            normalized_audio = await normalizer.normalize(audio_data)
-            normalized_audio = AudioService.trim_audio(normalized_audio,chunk_text,speed,is_last_chunk,normalizer)
-
+            print("1")
+            audio_chunk.audio = await normalizer.normalize(audio_chunk.audio)
+            print("2")
+            audio_chunk = AudioService.trim_audio(audio_chunk,chunk_text,speed,is_last_chunk,normalizer)
+            print("3")
            # Get or create format-specific writer
            writer_key = f"{output_format}_{sample_rate}"
            if is_first_chunk or writer_key not in AudioService._writers:

@@ -156,14 +161,16 @@
                    output_format, sample_rate
                )
            writer = AudioService._writers[writer_key]
+            print("4")
            # Write audio data first
-            if len(normalized_audio) > 0:
-                chunk_data = writer.write_chunk(normalized_audio)
+            if len(audio_chunk.audio) > 0:
+                chunk_data = writer.write_chunk(audio_chunk.audio)
+            print("5")
            # Then finalize if this is the last chunk
            if is_last_chunk:
+                print("6")
                final_data = writer.write_chunk(finalize=True)
+                print("7")
                del AudioService._writers[writer_key]
                return final_data if final_data else b""
 
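
A minimal usage sketch for the new convert_audio signature, assuming "wav" is among the supported formats and that the call returns per the Tuple[bytes, AudioChunk] annotation (the finalize path shown above still returns bare bytes in this WIP state):

import numpy as np

async def encode_chunk(samples: np.ndarray) -> bytes:
    # Hypothetical caller: wrap raw samples in an AudioChunk and convert a single, non-final chunk.
    chunk = AudioChunk(samples, word_timestamps=[])
    encoded, chunk = await AudioService.convert_audio(
        chunk, 24000, "wav",
        chunk_text="", is_first_chunk=True, is_last_chunk=False,
    )
    return encoded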

@@ -175,7 +182,7 @@
                f"Failed to convert audio stream to {output_format}: {str(e)}"
            )
    @staticmethod
-    def trim_audio(audio_data: np.ndarray, chunk_text: str = "", speed: float = 1, is_last_chunk: bool = False, normalizer: AudioNormalizer = None) -> np.ndarray:
+    def trim_audio(audio_chunk: AudioChunk, chunk_text: str = "", speed: float = 1, is_last_chunk: bool = False, normalizer: AudioNormalizer = None) -> AudioChunk:
        """Trim silence from start and end
 
        Args:

@@ -192,9 +199,15 @@
            normalizer = AudioNormalizer()
 
        # Trim start and end if enough samples
-        if len(audio_data) > (2 * normalizer.samples_to_trim):
-            audio_data = audio_data[normalizer.samples_to_trim : -normalizer.samples_to_trim]
+        if len(audio_chunk.audio) > (2 * normalizer.samples_to_trim):
+            audio_chunk.audio = audio_chunk.audio[normalizer.samples_to_trim : -normalizer.samples_to_trim]
 
        # Find non silent portion and trim
-        start_index,end_index=normalizer.find_first_last_non_silent(audio_data,chunk_text,speed,is_last_chunk=is_last_chunk)
-        return audio_data[start_index:end_index]
+        start_index,end_index=normalizer.find_first_last_non_silent(audio_chunk.audio,chunk_text,speed,is_last_chunk=is_last_chunk)
+
+        audio_chunk.audio=audio_chunk.audio[start_index:end_index]
+        for timestamp in audio_chunk.word_timestamps:
+            timestamp["start_time"]-=start_index * 24000
+            timestamp["end_time"]-=start_index * 24000
+        return audio_chunk

@@ -86,17 +86,18 @@ class TTSService:
                # Generate audio using pre-warmed model
                if isinstance(backend, KokoroV1):
                    # For Kokoro V1, pass text and voice info with lang_code
-                    async for chunk_audio in self.model_manager.generate(
+                    async for chunk_data in self.model_manager.generate(
                        chunk_text,
                        (voice_name, voice_path),
                        speed=speed,
                        lang_code=lang_code,
+                        return_timestamps=True,
                    ):
                        # For streaming, convert to bytes
                        if output_format:
                            try:
-                                converted = await AudioService.convert_audio(
-                                    chunk_audio,
+                                converted, chunk_data = await AudioService.convert_audio(
+                                    chunk_data,
                                    24000,
                                    output_format,
                                    speed,

@@ -105,17 +106,20 @@ class TTSService:
                                    is_last_chunk=is_last,
                                    normalizer=normalizer,
                                )
+                                print(chunk_data.word_timestamps)
                                yield converted
                            except Exception as e:
                                logger.error(f"Failed to convert audio: {str(e)}")
                        else:
-                            trimmed = await AudioService.trim_audio(chunk_audio,
+                            chunk_data = await AudioService.trim_audio(chunk_data,
                                chunk_text,
                                speed,
                                is_last,
                                normalizer)
-                            yield trimmed
+                            print(chunk_data.word_timestamps)
+                            yield chunk_data.audio
                else:
+                    print("old backend")
                    # For legacy backends, load voice tensor
                    voice_tensor = await self._voice_manager.load_voice(
                        voice_name, device=backend.device
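
Finally, a sketch of the streaming call the service code above now makes, with placeholders for the text, voice, and language arguments; it simply surfaces each chunk's audio and word timings to the caller:

# Hypothetical wrapper around the KokoroV1 streaming path shown above.
async def stream_with_timestamps(model_manager, text, voice_name, voice_path, lang_code):
    async for chunk_data in model_manager.generate(
        text,
        (voice_name, voice_path),
        speed=1.0,
        lang_code=lang_code,
        return_timestamps=True,
    ):
        # chunk_data is an AudioChunk: raw samples plus optional word timings.
        yield chunk_data.audio, chunk_data.word_timestamps or []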