mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00
Refactor ONNX GPU backend and phoneme generation: improve token handling, add chunk processing for audio generation, and introduce initial stitch options for audio chunks.
This commit is contained in:
parent
d50214d3be
commit
66f46e82f9
5 changed files with 106 additions and 27 deletions
|
@ -87,7 +87,9 @@ class ONNXGPUBackend(BaseModelBackend):
|
||||||
try:
|
try:
|
||||||
# Prepare inputs
|
# Prepare inputs
|
||||||
tokens_input = np.array([[0, *tokens, 0]], dtype=np.int64) # Add start/end tokens
|
tokens_input = np.array([[0, *tokens, 0]], dtype=np.int64) # Add start/end tokens
|
||||||
style_input = voice[len(tokens) + 2].cpu().numpy() # Move to CPU for ONNX
|
# Use modulo to ensure index stays within voice tensor bounds
|
||||||
|
style_idx = (len(tokens) + 2) % voice.size(0) # Add 2 for start/end tokens
|
||||||
|
style_input = voice[style_idx].cpu().numpy() # Move to CPU for ONNX
|
||||||
speed_input = np.full(1, speed, dtype=np.float32)
|
speed_input = np.full(1, speed, dtype=np.float32)
|
||||||
|
|
||||||
# Run inference
|
# Run inference
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
@ -42,10 +43,8 @@ async def phonemize_text(request: PhonemeRequest) -> PhonemeResponse:
|
||||||
if not phonemes:
|
if not phonemes:
|
||||||
raise ValueError("Failed to generate phonemes")
|
raise ValueError("Failed to generate phonemes")
|
||||||
|
|
||||||
# Get tokens
|
# Get tokens (without adding start/end tokens to match process_text behavior)
|
||||||
tokens = tokenize(phonemes)
|
tokens = tokenize(phonemes)
|
||||||
tokens = [0] + tokens + [0] # Add start/end tokens
|
|
||||||
|
|
||||||
return PhonemeResponse(phonemes=phonemes, tokens=tokens)
|
return PhonemeResponse(phonemes=phonemes, tokens=tokens)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
logger.error(f"Error in phoneme generation: {str(e)}")
|
logger.error(f"Error in phoneme generation: {str(e)}")
|
||||||
|
@ -93,23 +92,54 @@ async def generate_from_phonemes(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Convert phonemes to tokens
|
# Handle both single string and list of chunks
|
||||||
tokens = tokenize(request.phonemes)
|
phoneme_chunks = [request.phonemes] if isinstance(request.phonemes, str) else request.phonemes
|
||||||
tokens = [0] + tokens + [0] # Add start/end tokens
|
audio_chunks = []
|
||||||
|
|
||||||
# Generate audio directly from tokens
|
# Load voice tensor first since we'll need it for all chunks
|
||||||
audio = await tts_service.model_manager.generate(
|
voice_tensor = await tts_service._voice_manager.load_voice(
|
||||||
tokens,
|
|
||||||
request.voice,
|
request.voice,
|
||||||
speed=request.speed
|
device=tts_service.model_manager.get_backend().device
|
||||||
)
|
)
|
||||||
|
|
||||||
if audio is None:
|
try:
|
||||||
raise ValueError("Failed to generate audio")
|
# Process each chunk
|
||||||
|
for chunk in phoneme_chunks:
|
||||||
|
# Convert chunk to tokens
|
||||||
|
tokens = tokenize(chunk)
|
||||||
|
tokens = [0] + tokens + [0] # Add start/end tokens
|
||||||
|
|
||||||
|
# Validate chunk length
|
||||||
|
if len(tokens) > 510: # 510 to leave room for start/end tokens
|
||||||
|
raise ValueError(
|
||||||
|
f"Chunk too long ({len(tokens)} tokens). Each chunk must be under 510 tokens."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate audio for chunk
|
||||||
|
chunk_audio = await tts_service.model_manager.generate(
|
||||||
|
tokens,
|
||||||
|
voice_tensor,
|
||||||
|
speed=request.speed
|
||||||
|
)
|
||||||
|
if chunk_audio is not None:
|
||||||
|
audio_chunks.append(chunk_audio)
|
||||||
|
|
||||||
|
# Combine chunks if needed
|
||||||
|
if len(audio_chunks) > 1:
|
||||||
|
audio = np.concatenate(audio_chunks)
|
||||||
|
elif len(audio_chunks) == 1:
|
||||||
|
audio = audio_chunks[0]
|
||||||
|
else:
|
||||||
|
raise ValueError("No audio chunks were generated")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Clean up voice tensor
|
||||||
|
del voice_tensor
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
# Convert to WAV bytes
|
# Convert to WAV bytes
|
||||||
wav_bytes = AudioService.convert_audio(
|
wav_bytes = AudioService.convert_audio(
|
||||||
audio, 24000, "wav", is_first_chunk=True, is_last_chunk=True, stream=False
|
audio, 24000, "wav", is_first_chunk=True, is_last_chunk=True, stream=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
return Response(
|
return Response(
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
from pydantic import validator
|
||||||
|
from typing import List, Union, Optional
|
||||||
|
|
||||||
class PhonemeRequest(BaseModel):
|
class PhonemeRequest(BaseModel):
|
||||||
text: str
|
text: str
|
||||||
|
@ -11,9 +12,34 @@ class PhonemeResponse(BaseModel):
|
||||||
tokens: list[int]
|
tokens: list[int]
|
||||||
|
|
||||||
|
|
||||||
|
class StitchOptions(BaseModel):
|
||||||
|
"""Options for stitching audio chunks together"""
|
||||||
|
gap_method: str = Field(
|
||||||
|
default="static_trim",
|
||||||
|
description="Method to handle gaps between chunks. Currently only 'static_trim' supported."
|
||||||
|
)
|
||||||
|
trim_ms: int = Field(
|
||||||
|
default=0,
|
||||||
|
ge=0,
|
||||||
|
description="Milliseconds to trim from chunk boundaries when using static_trim"
|
||||||
|
)
|
||||||
|
|
||||||
|
@validator('gap_method')
|
||||||
|
def validate_gap_method(cls, v):
|
||||||
|
if v != 'static_trim':
|
||||||
|
raise ValueError("Currently only 'static_trim' gap method is supported")
|
||||||
|
return v
|
||||||
|
|
||||||
class GenerateFromPhonemesRequest(BaseModel):
|
class GenerateFromPhonemesRequest(BaseModel):
|
||||||
phonemes: str
|
phonemes: Union[str, List[str]] = Field(
|
||||||
|
...,
|
||||||
|
description="Single phoneme string or list of phoneme chunks to stitch together"
|
||||||
|
)
|
||||||
voice: str = Field(..., description="Voice ID to use for generation")
|
voice: str = Field(..., description="Voice ID to use for generation")
|
||||||
speed: float = Field(
|
speed: float = Field(
|
||||||
default=1.0, ge=0.1, le=5.0, description="Speed factor for generation"
|
default=1.0, ge=0.1, le=5.0, description="Speed factor for generation"
|
||||||
)
|
)
|
||||||
|
options: Optional[StitchOptions] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Optional settings for audio generation and stitching"
|
||||||
|
)
|
||||||
|
|
|
@ -46,17 +46,29 @@ def generate_audio_from_phonemes(
|
||||||
WAV audio bytes if successful, None if failed
|
WAV audio bytes if successful, None if failed
|
||||||
"""
|
"""
|
||||||
# Create the request payload
|
# Create the request payload
|
||||||
payload = {"phonemes": phonemes, "voice": voice, "speed": speed}
|
payload = {
|
||||||
|
"phonemes": phonemes,
|
||||||
|
"voice": voice,
|
||||||
|
"speed": speed,
|
||||||
|
"stitch_long_content": True # Default to false to get the error message
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
# Make POST request to generate audio
|
# Make POST request to generate audio
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
"http://localhost:8880/text/generate_from_phonemes", json=payload
|
"http://localhost:8880/text/generate_from_phonemes", json=payload
|
||||||
)
|
)
|
||||||
|
|
||||||
# Raise exception for error status codes
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
return response.content
|
return response.content
|
||||||
|
except requests.HTTPError as e:
|
||||||
|
# Get the error details from the response
|
||||||
|
try:
|
||||||
|
error_details = response.json()
|
||||||
|
error_msg = error_details.get('detail', {}).get('message', str(e))
|
||||||
|
print(f"Server Error: {error_msg}")
|
||||||
|
except:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
@ -66,7 +78,13 @@ def main():
|
||||||
"How are you today? I am doing reasonably well, thank you for asking",
|
"How are you today? I am doing reasonably well, thank you for asking",
|
||||||
"""This is a test of the phoneme generation system. Do not be alarmed.
|
"""This is a test of the phoneme generation system. Do not be alarmed.
|
||||||
This is only a test. If this were a real phoneme emergency, '
|
This is only a test. If this were a real phoneme emergency, '
|
||||||
you would be instructed to a phoneme shelter in your area.""",
|
you would be instructed to a phoneme shelter in your area. Repeat.
|
||||||
|
This is a test of the phoneme generation system. Do not be alarmed.
|
||||||
|
This is only a test. If this were a real phoneme emergency, '
|
||||||
|
you would be instructed to a phoneme shelter in your area. Repeat.
|
||||||
|
This is a test of the phoneme generation system. Do not be alarmed.
|
||||||
|
This is only a test. If this were a real phoneme emergency, '
|
||||||
|
you would be instructed to a phoneme shelter in your area""",
|
||||||
]
|
]
|
||||||
|
|
||||||
print("Generating phonemes and audio for example texts...\n")
|
print("Generating phonemes and audio for example texts...\n")
|
||||||
|
@ -85,6 +103,9 @@ def main():
|
||||||
|
|
||||||
# Generate audio from phonemes
|
# Generate audio from phonemes
|
||||||
print("Generating audio...")
|
print("Generating audio...")
|
||||||
|
if len(phonemes) > 500: # split into arrays of 500 phonemes
|
||||||
|
phonemes = [phonemes[i:i+500] for i in range(0, len(phonemes), 500)]
|
||||||
|
|
||||||
audio_bytes = generate_audio_from_phonemes(phonemes)
|
audio_bytes = generate_audio_from_phonemes(phonemes)
|
||||||
|
|
||||||
if audio_bytes:
|
if audio_bytes:
|
||||||
|
|
Binary file not shown.
Loading…
Add table
Reference in a new issue