mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00
Refactor ONNX GPU backend and phoneme generation: improve token handling, add chunk processing for audio generation, and initial introduce stitch options for audio chunks.
This commit is contained in:
parent
d50214d3be
commit
66f46e82f9
5 changed files with 106 additions and 27 deletions
|
@ -87,7 +87,9 @@ class ONNXGPUBackend(BaseModelBackend):
|
|||
try:
|
||||
# Prepare inputs
|
||||
tokens_input = np.array([[0, *tokens, 0]], dtype=np.int64) # Add start/end tokens
|
||||
style_input = voice[len(tokens) + 2].cpu().numpy() # Move to CPU for ONNX
|
||||
# Use modulo to ensure index stays within voice tensor bounds
|
||||
style_idx = (len(tokens) + 2) % voice.size(0) # Add 2 for start/end tokens
|
||||
style_input = voice[style_idx].cpu().numpy() # Move to CPU for ONNX
|
||||
speed_input = np.full(1, speed, dtype=np.float32)
|
||||
|
||||
# Run inference
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||
from loguru import logger
|
||||
|
||||
|
@ -42,10 +43,8 @@ async def phonemize_text(request: PhonemeRequest) -> PhonemeResponse:
|
|||
if not phonemes:
|
||||
raise ValueError("Failed to generate phonemes")
|
||||
|
||||
# Get tokens
|
||||
# Get tokens (without adding start/end tokens to match process_text behavior)
|
||||
tokens = tokenize(phonemes)
|
||||
tokens = [0] + tokens + [0] # Add start/end tokens
|
||||
|
||||
return PhonemeResponse(phonemes=phonemes, tokens=tokens)
|
||||
except ValueError as e:
|
||||
logger.error(f"Error in phoneme generation: {str(e)}")
|
||||
|
@ -93,23 +92,54 @@ async def generate_from_phonemes(
|
|||
},
|
||||
)
|
||||
|
||||
# Convert phonemes to tokens
|
||||
tokens = tokenize(request.phonemes)
|
||||
tokens = [0] + tokens + [0] # Add start/end tokens
|
||||
# Handle both single string and list of chunks
|
||||
phoneme_chunks = [request.phonemes] if isinstance(request.phonemes, str) else request.phonemes
|
||||
audio_chunks = []
|
||||
|
||||
# Generate audio directly from tokens
|
||||
audio = await tts_service.model_manager.generate(
|
||||
tokens,
|
||||
# Load voice tensor first since we'll need it for all chunks
|
||||
voice_tensor = await tts_service._voice_manager.load_voice(
|
||||
request.voice,
|
||||
speed=request.speed
|
||||
device=tts_service.model_manager.get_backend().device
|
||||
)
|
||||
|
||||
if audio is None:
|
||||
raise ValueError("Failed to generate audio")
|
||||
try:
|
||||
# Process each chunk
|
||||
for chunk in phoneme_chunks:
|
||||
# Convert chunk to tokens
|
||||
tokens = tokenize(chunk)
|
||||
tokens = [0] + tokens + [0] # Add start/end tokens
|
||||
|
||||
# Validate chunk length
|
||||
if len(tokens) > 510: # 510 to leave room for start/end tokens
|
||||
raise ValueError(
|
||||
f"Chunk too long ({len(tokens)} tokens). Each chunk must be under 510 tokens."
|
||||
)
|
||||
|
||||
# Generate audio for chunk
|
||||
chunk_audio = await tts_service.model_manager.generate(
|
||||
tokens,
|
||||
voice_tensor,
|
||||
speed=request.speed
|
||||
)
|
||||
if chunk_audio is not None:
|
||||
audio_chunks.append(chunk_audio)
|
||||
|
||||
# Combine chunks if needed
|
||||
if len(audio_chunks) > 1:
|
||||
audio = np.concatenate(audio_chunks)
|
||||
elif len(audio_chunks) == 1:
|
||||
audio = audio_chunks[0]
|
||||
else:
|
||||
raise ValueError("No audio chunks were generated")
|
||||
|
||||
finally:
|
||||
# Clean up voice tensor
|
||||
del voice_tensor
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Convert to WAV bytes
|
||||
wav_bytes = AudioService.convert_audio(
|
||||
audio, 24000, "wav", is_first_chunk=True, is_last_chunk=True, stream=False
|
||||
audio, 24000, "wav", is_first_chunk=True, is_last_chunk=True, stream=False,
|
||||
)
|
||||
|
||||
return Response(
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
from pydantic import validator
|
||||
from typing import List, Union, Optional
|
||||
|
||||
class PhonemeRequest(BaseModel):
|
||||
text: str
|
||||
|
@ -11,9 +12,34 @@ class PhonemeResponse(BaseModel):
|
|||
tokens: list[int]
|
||||
|
||||
|
||||
class StitchOptions(BaseModel):
|
||||
"""Options for stitching audio chunks together"""
|
||||
gap_method: str = Field(
|
||||
default="static_trim",
|
||||
description="Method to handle gaps between chunks. Currently only 'static_trim' supported."
|
||||
)
|
||||
trim_ms: int = Field(
|
||||
default=0,
|
||||
ge=0,
|
||||
description="Milliseconds to trim from chunk boundaries when using static_trim"
|
||||
)
|
||||
|
||||
@validator('gap_method')
|
||||
def validate_gap_method(cls, v):
|
||||
if v != 'static_trim':
|
||||
raise ValueError("Currently only 'static_trim' gap method is supported")
|
||||
return v
|
||||
|
||||
class GenerateFromPhonemesRequest(BaseModel):
|
||||
phonemes: str
|
||||
phonemes: Union[str, List[str]] = Field(
|
||||
...,
|
||||
description="Single phoneme string or list of phoneme chunks to stitch together"
|
||||
)
|
||||
voice: str = Field(..., description="Voice ID to use for generation")
|
||||
speed: float = Field(
|
||||
default=1.0, ge=0.1, le=5.0, description="Speed factor for generation"
|
||||
)
|
||||
options: Optional[StitchOptions] = Field(
|
||||
default=None,
|
||||
description="Optional settings for audio generation and stitching"
|
||||
)
|
||||
|
|
|
@ -46,17 +46,29 @@ def generate_audio_from_phonemes(
|
|||
WAV audio bytes if successful, None if failed
|
||||
"""
|
||||
# Create the request payload
|
||||
payload = {"phonemes": phonemes, "voice": voice, "speed": speed}
|
||||
payload = {
|
||||
"phonemes": phonemes,
|
||||
"voice": voice,
|
||||
"speed": speed,
|
||||
"stitch_long_content": True # Default to false to get the error message
|
||||
}
|
||||
|
||||
# Make POST request to generate audio
|
||||
response = requests.post(
|
||||
"http://localhost:8880/text/generate_from_phonemes", json=payload
|
||||
)
|
||||
|
||||
# Raise exception for error status codes
|
||||
response.raise_for_status()
|
||||
|
||||
return response.content
|
||||
try:
|
||||
# Make POST request to generate audio
|
||||
response = requests.post(
|
||||
"http://localhost:8880/text/generate_from_phonemes", json=payload
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
except requests.HTTPError as e:
|
||||
# Get the error details from the response
|
||||
try:
|
||||
error_details = response.json()
|
||||
error_msg = error_details.get('detail', {}).get('message', str(e))
|
||||
print(f"Server Error: {error_msg}")
|
||||
except:
|
||||
print(f"Error: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -66,7 +78,13 @@ def main():
|
|||
"How are you today? I am doing reasonably well, thank you for asking",
|
||||
"""This is a test of the phoneme generation system. Do not be alarmed.
|
||||
This is only a test. If this were a real phoneme emergency, '
|
||||
you would be instructed to a phoneme shelter in your area.""",
|
||||
you would be instructed to a phoneme shelter in your area. Repeat.
|
||||
This is a test of the phoneme generation system. Do not be alarmed.
|
||||
This is only a test. If this were a real phoneme emergency, '
|
||||
you would be instructed to a phoneme shelter in your area. Repeat.
|
||||
This is a test of the phoneme generation system. Do not be alarmed.
|
||||
This is only a test. If this were a real phoneme emergency, '
|
||||
you would be instructed to a phoneme shelter in your area""",
|
||||
]
|
||||
|
||||
print("Generating phonemes and audio for example texts...\n")
|
||||
|
@ -85,6 +103,9 @@ def main():
|
|||
|
||||
# Generate audio from phonemes
|
||||
print("Generating audio...")
|
||||
if len(phonemes) > 500: # split into arrays of 500 phonemes
|
||||
phonemes = [phonemes[i:i+500] for i in range(0, len(phonemes), 500)]
|
||||
|
||||
audio_bytes = generate_audio_from_phonemes(phonemes)
|
||||
|
||||
if audio_bytes:
|
||||
|
|
Binary file not shown.
Loading…
Add table
Reference in a new issue