Refactor ONNX GPU backend and phoneme generation: improve token handling, add chunk processing for audio generation, and introduce initial stitch options for audio chunks.

remsky 2025-01-22 17:43:38 -07:00
parent d50214d3be
commit 66f46e82f9
5 changed files with 106 additions and 27 deletions

View file

@@ -87,7 +87,9 @@ class ONNXGPUBackend(BaseModelBackend):
         try:
             # Prepare inputs
             tokens_input = np.array([[0, *tokens, 0]], dtype=np.int64)  # Add start/end tokens
-            style_input = voice[len(tokens) + 2].cpu().numpy()  # Move to CPU for ONNX
+            # Use modulo to ensure index stays within voice tensor bounds
+            style_idx = (len(tokens) + 2) % voice.size(0)  # Add 2 for start/end tokens
+            style_input = voice[style_idx].cpu().numpy()  # Move to CPU for ONNX
             speed_input = np.full(1, speed, dtype=np.float32)

             # Run inference
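To see why the modulo guard matters, here is a minimal standalone sketch; the voice tensor shape and token count are invented for illustration, not taken from this commit:

    import torch

    voice = torch.randn(510, 256)   # hypothetical style table: 510 frames of 256-dim vectors
    tokens = list(range(600))       # a token sequence longer than the style table

    # Old code: voice[len(tokens) + 2] raises IndexError, since 602 >= 510.
    # New code: wrap the index so the lookup always stays in bounds.
    style_idx = (len(tokens) + 2) % voice.size(0)   # 602 % 510 == 92
    style_input = voice[style_idx].cpu().numpy()    # safe lookup, ready to feed ONNX
    print(style_idx, style_input.shape)             # 92 (256,)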

View file

@ -1,6 +1,7 @@
from typing import List from typing import List
import numpy as np import numpy as np
import torch
from fastapi import APIRouter, Depends, HTTPException, Request, Response from fastapi import APIRouter, Depends, HTTPException, Request, Response
from loguru import logger from loguru import logger
@@ -42,10 +43,8 @@ async def phonemize_text(request: PhonemeRequest) -> PhonemeResponse:
         if not phonemes:
             raise ValueError("Failed to generate phonemes")

-        # Get tokens
+        # Get tokens (without adding start/end tokens to match process_text behavior)
         tokens = tokenize(phonemes)
-        tokens = [0] + tokens + [0]  # Add start/end tokens

         return PhonemeResponse(phonemes=phonemes, tokens=tokens)
     except ValueError as e:
         logger.error(f"Error in phoneme generation: {str(e)}")
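The contract after this hunk: /phonemize returns unpadded tokens, and start/end padding is applied only where the model input is built (as in the backend hunk above). A sketch with a stand-in tokenizer, since the real tokenize is not shown in this diff:

    def tokenize(phonemes: str) -> list[int]:
        # Stand-in tokenizer: one illustrative ID per phoneme character.
        return [ord(c) % 100 for c in phonemes]

    raw = tokenize("hVlO")        # what the endpoint now returns: no padding
    model_input = [0, *raw, 0]    # padding added once, at inference time
    print(raw)
    print(model_input)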
@@ -93,23 +92,54 @@ async def generate_from_phonemes(
             },
         )

-    # Convert phonemes to tokens
-    tokens = tokenize(request.phonemes)
-    tokens = [0] + tokens + [0]  # Add start/end tokens
+    # Handle both single string and list of chunks
+    phoneme_chunks = [request.phonemes] if isinstance(request.phonemes, str) else request.phonemes
+    audio_chunks = []

-    # Generate audio directly from tokens
-    audio = await tts_service.model_manager.generate(
-        tokens,
+    # Load voice tensor first since we'll need it for all chunks
+    voice_tensor = await tts_service._voice_manager.load_voice(
         request.voice,
-        speed=request.speed
+        device=tts_service.model_manager.get_backend().device
     )

-    if audio is None:
-        raise ValueError("Failed to generate audio")
+    try:
+        # Process each chunk
+        for chunk in phoneme_chunks:
+            # Convert chunk to tokens
+            tokens = tokenize(chunk)
+            tokens = [0] + tokens + [0]  # Add start/end tokens
+
+            # Validate chunk length
+            if len(tokens) > 510:  # 510 to leave room for start/end tokens
+                raise ValueError(
+                    f"Chunk too long ({len(tokens)} tokens). Each chunk must be under 510 tokens."
+                )
+
+            # Generate audio for chunk
+            chunk_audio = await tts_service.model_manager.generate(
+                tokens,
+                voice_tensor,
+                speed=request.speed
+            )
+            if chunk_audio is not None:
+                audio_chunks.append(chunk_audio)
+
+        # Combine chunks if needed
+        if len(audio_chunks) > 1:
+            audio = np.concatenate(audio_chunks)
+        elif len(audio_chunks) == 1:
+            audio = audio_chunks[0]
+        else:
+            raise ValueError("No audio chunks were generated")
+    finally:
+        # Clean up voice tensor
+        del voice_tensor
+        torch.cuda.empty_cache()

     # Convert to WAV bytes
     wav_bytes = AudioService.convert_audio(
-        audio, 24000, "wav", is_first_chunk=True, is_last_chunk=True, stream=False
+        audio, 24000, "wav", is_first_chunk=True, is_last_chunk=True, stream=False,
     )

     return Response(
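Stripped of service plumbing, the new endpoint logic is a map-and-concatenate over chunks with a hard per-chunk token limit. A self-contained sketch of that shape, where generate_chunk and the character-to-token mapping are stand-ins for the real model call and tokenizer:

    import numpy as np

    def generate_chunk(tokens: list[int]) -> np.ndarray:
        # Stand-in for model_manager.generate(): fake silent audio, ~240 samples per token.
        return np.zeros(len(tokens) * 240, dtype=np.float32)

    phoneme_chunks = ["tEst wVn", "tEst tu"]
    audio_chunks = []
    for chunk in phoneme_chunks:
        tokens = [0, *[ord(c) % 100 for c in chunk], 0]  # pad each chunk independently
        if len(tokens) > 510:
            raise ValueError(f"Chunk too long ({len(tokens)} tokens)")
        audio_chunks.append(generate_chunk(tokens))

    audio = np.concatenate(audio_chunks) if len(audio_chunks) > 1 else audio_chunks[0]
    print(audio.shape)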

View file

@@ -1,5 +1,6 @@
 from pydantic import BaseModel, Field
+from pydantic import validator
+from typing import List, Union, Optional


 class PhonemeRequest(BaseModel):
     text: str
@@ -11,9 +12,34 @@ class PhonemeResponse(BaseModel):
     tokens: list[int]


+class StitchOptions(BaseModel):
+    """Options for stitching audio chunks together"""
+    gap_method: str = Field(
+        default="static_trim",
+        description="Method to handle gaps between chunks. Currently only 'static_trim' supported."
+    )
+    trim_ms: int = Field(
+        default=0,
+        ge=0,
+        description="Milliseconds to trim from chunk boundaries when using static_trim"
+    )
+
+    @validator('gap_method')
+    def validate_gap_method(cls, v):
+        if v != 'static_trim':
+            raise ValueError("Currently only 'static_trim' gap method is supported")
+        return v
+
+
 class GenerateFromPhonemesRequest(BaseModel):
-    phonemes: str
+    phonemes: Union[str, List[str]] = Field(
+        ...,
+        description="Single phoneme string or list of phoneme chunks to stitch together"
+    )
     voice: str = Field(..., description="Voice ID to use for generation")
     speed: float = Field(
         default=1.0, ge=0.1, le=5.0, description="Speed factor for generation"
     )
+    options: Optional[StitchOptions] = Field(
+        default=None,
+        description="Optional settings for audio generation and stitching"
+    )
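For illustration, a valid request against these models could look as follows; the voice ID and trim value are made up, and any gap_method other than 'static_trim' is rejected by the validator:

    req = GenerateFromPhonemesRequest(
        phonemes=["tEst wVn", "tEst tu"],  # list form opts into chunk stitching
        voice="af_bella",                  # example voice ID (assumed)
        speed=1.0,
        options=StitchOptions(gap_method="static_trim", trim_ms=20),
    )
    print(req.options.trim_ms)  # 20

    # StitchOptions(gap_method="crossfade") would raise a validation error.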

View file

@@ -46,17 +46,29 @@ def generate_audio_from_phonemes(
         WAV audio bytes if successful, None if failed
     """
     # Create the request payload
-    payload = {"phonemes": phonemes, "voice": voice, "speed": speed}
+    payload = {
+        "phonemes": phonemes,
+        "voice": voice,
+        "speed": speed,
+        "stitch_long_content": True  # Set False to surface the server's chunk-length error instead
+    }

+    try:
         # Make POST request to generate audio
         response = requests.post(
             "http://localhost:8880/text/generate_from_phonemes", json=payload
         )
-        # Raise exception for error status codes
         response.raise_for_status()
         return response.content
+    except requests.HTTPError as e:
+        # Get the error details from the response
+        try:
+            error_details = response.json()
+            error_msg = error_details.get('detail', {}).get('message', str(e))
+            print(f"Server Error: {error_msg}")
+        except Exception:
+            print(f"Error: {e}")
+        return None


 def main():
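Against a running server, the helper either returns WAV bytes or prints the structured error and returns None. A hypothetical call, assuming the helper's signature is generate_audio_from_phonemes(phonemes, voice, speed):

    audio_bytes = generate_audio_from_phonemes("tEst fOnimz", "af_bella", 1.0)
    if audio_bytes is not None:
        with open("output.wav", "wb") as f:  # persist the returned WAV payload
            f.write(audio_bytes)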
@@ -66,7 +78,13 @@ def main():
         "How are you today? I am doing reasonably well, thank you for asking",
         """This is a test of the phoneme generation system. Do not be alarmed.
         This is only a test. If this were a real phoneme emergency, '
-        you would be instructed to a phoneme shelter in your area.""",
+        you would be instructed to a phoneme shelter in your area. Repeat.
+        This is a test of the phoneme generation system. Do not be alarmed.
+        This is only a test. If this were a real phoneme emergency, '
+        you would be instructed to a phoneme shelter in your area. Repeat.
+        This is a test of the phoneme generation system. Do not be alarmed.
+        This is only a test. If this were a real phoneme emergency, '
+        you would be instructed to a phoneme shelter in your area""",
     ]

     print("Generating phonemes and audio for example texts...\n")
@@ -85,6 +103,9 @@ def main():
         # Generate audio from phonemes
         print("Generating audio...")
+        if len(phonemes) > 500:  # split into arrays of 500 phonemes
+            phonemes = [phonemes[i:i+500] for i in range(0, len(phonemes), 500)]
+
         audio_bytes = generate_audio_from_phonemes(phonemes)

         if audio_bytes:
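Note the client-side split counts characters of the phoneme string, not model tokens, so 500 leaves headroom under the server's 510-token limit only if tokens map roughly one-to-one to characters. The slice itself is a plain fixed-width comprehension:

    phonemes = "a" * 1200  # stand-in phoneme string longer than one chunk
    chunks = [phonemes[i:i + 500] for i in range(0, len(phonemes), 500)]
    print([len(c) for c in chunks])  # [500, 500, 200]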

Binary file not shown.