Refactor ONNX GPU backend and phoneme generation: improve token handling, add chunk processing for audio generation, and introduce initial stitch options for audio chunks.

This commit is contained in:
remsky 2025-01-22 17:43:38 -07:00
parent d50214d3be
commit 66f46e82f9
5 changed files with 106 additions and 27 deletions

View file

@@ -87,7 +87,9 @@ class ONNXGPUBackend(BaseModelBackend):
try:
# Prepare inputs
tokens_input = np.array([[0, *tokens, 0]], dtype=np.int64) # Add start/end tokens
style_input = voice[len(tokens) + 2].cpu().numpy() # Move to CPU for ONNX
# Use modulo to ensure index stays within voice tensor bounds
style_idx = (len(tokens) + 2) % voice.size(0) # Add 2 for start/end tokens
style_input = voice[style_idx].cpu().numpy() # Move to CPU for ONNX
speed_input = np.full(1, speed, dtype=np.float32)
# Run inference

View file

@@ -1,6 +1,7 @@
from typing import List
import numpy as np
import torch
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from loguru import logger
@@ -42,10 +43,8 @@ async def phonemize_text(request: PhonemeRequest) -> PhonemeResponse:
if not phonemes:
raise ValueError("Failed to generate phonemes")
# Get tokens
# Get tokens (without adding start/end tokens to match process_text behavior)
tokens = tokenize(phonemes)
tokens = [0] + tokens + [0] # Add start/end tokens
return PhonemeResponse(phonemes=phonemes, tokens=tokens)
except ValueError as e:
logger.error(f"Error in phoneme generation: {str(e)}")
@@ -93,23 +92,54 @@ async def generate_from_phonemes(
},
)
# Convert phonemes to tokens
tokens = tokenize(request.phonemes)
tokens = [0] + tokens + [0] # Add start/end tokens
# Handle both single string and list of chunks
phoneme_chunks = [request.phonemes] if isinstance(request.phonemes, str) else request.phonemes
audio_chunks = []
# Generate audio directly from tokens
audio = await tts_service.model_manager.generate(
tokens,
# Load voice tensor first since we'll need it for all chunks
voice_tensor = await tts_service._voice_manager.load_voice(
request.voice,
speed=request.speed
device=tts_service.model_manager.get_backend().device
)
if audio is None:
raise ValueError("Failed to generate audio")
try:
# Process each chunk
for chunk in phoneme_chunks:
# Convert chunk to tokens
tokens = tokenize(chunk)
tokens = [0] + tokens + [0] # Add start/end tokens
# Validate chunk length
if len(tokens) > 510: # 510 to leave room for start/end tokens
raise ValueError(
f"Chunk too long ({len(tokens)} tokens). Each chunk must be under 510 tokens."
)
# Generate audio for chunk
chunk_audio = await tts_service.model_manager.generate(
tokens,
voice_tensor,
speed=request.speed
)
if chunk_audio is not None:
audio_chunks.append(chunk_audio)
# Combine chunks if needed
if len(audio_chunks) > 1:
audio = np.concatenate(audio_chunks)
elif len(audio_chunks) == 1:
audio = audio_chunks[0]
else:
raise ValueError("No audio chunks were generated")
finally:
# Clean up voice tensor
del voice_tensor
torch.cuda.empty_cache()
# Convert to WAV bytes
wav_bytes = AudioService.convert_audio(
audio, 24000, "wav", is_first_chunk=True, is_last_chunk=True, stream=False
audio, 24000, "wav", is_first_chunk=True, is_last_chunk=True, stream=False,
)
return Response(
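A sketch of calling the updated endpoint with pre-split phoneme chunks (URL and field names as used elsewhere in this commit; the phoneme strings and voice ID are illustrative). Each chunk must tokenize to under 510 tokens; the server stitches the per-chunk audio before returning a single WAV:

import requests

payload = {
    "phonemes": ["hˈɛloʊ wˈɜːld", "ðɪs ɪz ɐ sˈɛkənd tʃˈʌŋk"],  # a list is treated as chunks
    "voice": "af",   # assumed voice ID for illustration
    "speed": 1.0,
}
resp = requests.post("http://localhost:8880/text/generate_from_phonemes", json=payload)
resp.raise_for_status()
with open("stitched.wav", "wb") as f:
    f.write(resp.content)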

View file

@@ -1,5 +1,6 @@
from pydantic import BaseModel, Field
from pydantic import validator
from typing import List, Union, Optional
class PhonemeRequest(BaseModel):
text: str
@@ -11,9 +12,34 @@ class PhonemeResponse(BaseModel):
tokens: list[int]
class StitchOptions(BaseModel):
"""Options for stitching audio chunks together"""
gap_method: str = Field(
default="static_trim",
description="Method to handle gaps between chunks. Currently only 'static_trim' supported."
)
trim_ms: int = Field(
default=0,
ge=0,
description="Milliseconds to trim from chunk boundaries when using static_trim"
)
@validator('gap_method')
def validate_gap_method(cls, v):
if v != 'static_trim':
raise ValueError("Currently only 'static_trim' gap method is supported")
return v
class GenerateFromPhonemesRequest(BaseModel):
phonemes: str
phonemes: Union[str, List[str]] = Field(
...,
description="Single phoneme string or list of phoneme chunks to stitch together"
)
voice: str = Field(..., description="Voice ID to use for generation")
speed: float = Field(
default=1.0, ge=0.1, le=5.0, description="Speed factor for generation"
)
options: Optional[StitchOptions] = Field(
default=None,
description="Optional settings for audio generation and stitching"
)
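A minimal usage sketch of the new models, assuming StitchOptions and GenerateFromPhonemesRequest from the schema above are in scope (voice ID and phoneme strings are illustrative):

req = GenerateFromPhonemesRequest(
    phonemes=["hˈɛloʊ", "wˈɜːld"],
    voice="af",
    options=StitchOptions(gap_method="static_trim", trim_ms=10),
)

# Any other gap_method is rejected by the validator; pydantic raises a
# ValidationError, which is a subclass of ValueError.
try:
    StitchOptions(gap_method="crossfade")
except ValueError as err:
    print(err)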

View file

@@ -46,17 +46,29 @@ def generate_audio_from_phonemes(
WAV audio bytes if successful, None if failed
"""
# Create the request payload
payload = {"phonemes": phonemes, "voice": voice, "speed": speed}
payload = {
"phonemes": phonemes,
"voice": voice,
"speed": speed,
"stitch_long_content": True # Default to false to get the error message
}
# Make POST request to generate audio
response = requests.post(
"http://localhost:8880/text/generate_from_phonemes", json=payload
)
# Raise exception for error status codes
response.raise_for_status()
return response.content
try:
# Make POST request to generate audio
response = requests.post(
"http://localhost:8880/text/generate_from_phonemes", json=payload
)
response.raise_for_status()
return response.content
except requests.HTTPError as e:
# Get the error details from the response
try:
error_details = response.json()
error_msg = error_details.get('detail', {}).get('message', str(e))
print(f"Server Error: {error_msg}")
except:
print(f"Error: {e}")
return None
def main():
@@ -66,7 +78,13 @@ def main():
"How are you today? I am doing reasonably well, thank you for asking",
"""This is a test of the phoneme generation system. Do not be alarmed.
This is only a test. If this were a real phoneme emergency, '
you would be instructed to a phoneme shelter in your area.""",
you would be instructed to a phoneme shelter in your area. Repeat.
This is a test of the phoneme generation system. Do not be alarmed.
This is only a test. If this were a real phoneme emergency, '
you would be instructed to a phoneme shelter in your area. Repeat.
This is a test of the phoneme generation system. Do not be alarmed.
This is only a test. If this were a real phoneme emergency, '
you would be instructed to a phoneme shelter in your area""",
]
print("Generating phonemes and audio for example texts...\n")
@@ -85,6 +103,9 @@ def main():
# Generate audio from phonemes
print("Generating audio...")
if len(phonemes) > 500: # split into arrays of 500 phonemes
phonemes = [phonemes[i:i+500] for i in range(0, len(phonemes), 500)]
audio_bytes = generate_audio_from_phonemes(phonemes)
if audio_bytes:
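The splitting above is a plain character-based slice of the phoneme string, not a token-aware split; a standalone sketch of the same idea:

def split_phonemes(phonemes: str, size: int = 500) -> list[str]:
    """Split a long phoneme string into fixed-size chunks (no word-boundary handling)."""
    return [phonemes[i:i + size] for i in range(0, len(phonemes), size)]

chunks = split_phonemes("a" * 1200)
assert [len(c) for c in chunks] == [500, 500, 200]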

Binary file not shown.