Support chunked phoneme input in generate_from_phonemes; bound ONNX GPU style index.

* api/src/inference/onnx_gpu.py: wrap the style-vector index with modulo
  voice.size(0) so it stays inside the voice tensor.
* api/src/routers/development.py: phonemize_text no longer adds start/end
  tokens; generate_from_phonemes accepts a string or a list of phoneme chunks,
  loads the voice tensor once, generates audio per chunk and concatenates it.
* api/src/structures/text_schemas.py: add StitchOptions; phonemes accepts
  str or list of str; new optional `options` field.
* examples/phoneme_examples/generate_phonemes.py: print server errors and
  split phoneme strings longer than 500 characters into chunks.

Review notes:
* Line breaks and indentation were reconstructed from a whitespace-collapsed
  copy of this patch; context-line indentation is inferred. If the patch does
  not apply cleanly, retry with `git apply --ignore-whitespace`.
* The `index` hashes are those of the original commit and do not reflect the
  review edits made to added lines below.
* `options` (StitchOptions) is accepted by the schema but is not read by the
  router in this patch.

diff --git a/api/src/inference/onnx_gpu.py b/api/src/inference/onnx_gpu.py
index f11534e..266bad5 100644
--- a/api/src/inference/onnx_gpu.py
+++ b/api/src/inference/onnx_gpu.py
@@ -87,7 +87,9 @@ class ONNXGPUBackend(BaseModelBackend):
         try:
             # Prepare inputs
             tokens_input = np.array([[0, *tokens, 0]], dtype=np.int64) # Add start/end tokens
-            style_input = voice[len(tokens) + 2].cpu().numpy() # Move to CPU for ONNX
+            # Use modulo to ensure index stays within voice tensor bounds
+            style_idx = (len(tokens) + 2) % voice.size(0) # Add 2 for start/end tokens
+            style_input = voice[style_idx].cpu().numpy() # Move to CPU for ONNX
             speed_input = np.full(1, speed, dtype=np.float32)
 
             # Run inference
diff --git a/api/src/routers/development.py b/api/src/routers/development.py
index b8ecf35..df1b638 100644
--- a/api/src/routers/development.py
+++ b/api/src/routers/development.py
@@ -1,6 +1,7 @@
 from typing import List
 
 import numpy as np
+import torch
 from fastapi import APIRouter, Depends, HTTPException, Request, Response
 from loguru import logger
 
@@ -42,10 +43,8 @@ async def phonemize_text(request: PhonemeRequest) -> PhonemeResponse:
         if not phonemes:
             raise ValueError("Failed to generate phonemes")
 
-        # Get tokens
+        # Get tokens (without adding start/end tokens to match process_text behavior)
         tokens = tokenize(phonemes)
-        tokens = [0] + tokens + [0] # Add start/end tokens
-
         return PhonemeResponse(phonemes=phonemes, tokens=tokens)
     except ValueError as e:
         logger.error(f"Error in phoneme generation: {str(e)}")
@@ -93,23 +92,56 @@ async def generate_from_phonemes(
                 },
             )
 
-        # Convert phonemes to tokens
-        tokens = tokenize(request.phonemes)
-        tokens = [0] + tokens + [0] # Add start/end tokens
+        # Handle both single string and list of chunks
+        phoneme_chunks = [request.phonemes] if isinstance(request.phonemes, str) else request.phonemes
+        audio_chunks = []
 
-        # Generate audio directly from tokens
-        audio = await tts_service.model_manager.generate(
-            tokens,
+        # Load voice tensor first since we'll need it for all chunks
+        voice_tensor = await tts_service._voice_manager.load_voice(
             request.voice,
-            speed=request.speed
+            device=tts_service.model_manager.get_backend().device
         )
 
-        if audio is None:
-            raise ValueError("Failed to generate audio")
+        try:
+            # Process each chunk
+            for chunk in phoneme_chunks:
+                # Convert chunk to tokens
+                tokens = tokenize(chunk)
+                tokens = [0] + tokens + [0]  # Add start/end tokens
+
+                # Validate chunk length (the count includes the start/end tokens added above)
+                if len(tokens) > 510:
+                    raise ValueError(
+                        f"Chunk too long ({len(tokens)} tokens). Each chunk must be at most 510 tokens."
+                    )
+
+                # Generate audio for chunk
+                chunk_audio = await tts_service.model_manager.generate(
+                    tokens,
+                    voice_tensor,
+                    speed=request.speed
+                )
+                # Fail loudly rather than silently dropping a chunk from the stitched output
+                if chunk_audio is None:
+                    raise ValueError("Failed to generate audio")
+                audio_chunks.append(chunk_audio)
+
+            # Combine chunks if needed
+            if len(audio_chunks) > 1:
+                audio = np.concatenate(audio_chunks)
+            elif len(audio_chunks) == 1:
+                audio = audio_chunks[0]
+            else:
+                raise ValueError("No audio chunks were generated")
+
+        finally:
+            # Clean up voice tensor
+            del voice_tensor
+            torch.cuda.empty_cache()
 
         # Convert to WAV bytes
         wav_bytes = AudioService.convert_audio(
-            audio, 24000, "wav", is_first_chunk=True, is_last_chunk=True, stream=False
+            audio, 24000, "wav", is_first_chunk=True, is_last_chunk=True, stream=False,
         )
 
         return Response(
diff --git a/api/src/structures/text_schemas.py b/api/src/structures/text_schemas.py
index 1b6afc9..f25d37a 100644
--- a/api/src/structures/text_schemas.py
+++ b/api/src/structures/text_schemas.py
@@ -1,5 +1,6 @@
 from pydantic import BaseModel, Field
-
+from pydantic import validator
+from typing import List, Union, Optional
 
 class PhonemeRequest(BaseModel):
     text: str
@@ -11,9 +12,34 @@ class PhonemeResponse(BaseModel):
     tokens: list[int]
 
 
+class StitchOptions(BaseModel):
+    """Options for stitching audio chunks together"""
+    gap_method: str = Field(
+        default="static_trim",
+        description="Method to handle gaps between chunks. Currently only 'static_trim' supported."
+    )
+    trim_ms: int = Field(
+        default=0,
+        ge=0,
+        description="Milliseconds to trim from chunk boundaries when using static_trim"
+    )
+
+    @validator('gap_method')
+    def validate_gap_method(cls, v):
+        if v != 'static_trim':
+            raise ValueError("Currently only 'static_trim' gap method is supported")
+        return v
+
 class GenerateFromPhonemesRequest(BaseModel):
-    phonemes: str
+    phonemes: Union[str, List[str]] = Field(
+        ...,
+        description="Single phoneme string or list of phoneme chunks to stitch together"
+    )
     voice: str = Field(..., description="Voice ID to use for generation")
     speed: float = Field(
         default=1.0, ge=0.1, le=5.0, description="Speed factor for generation"
     )
+    options: Optional[StitchOptions] = Field(
+        default=None,
+        description="Optional settings for audio generation and stitching"
+    )
diff --git a/examples/phoneme_examples/generate_phonemes.py b/examples/phoneme_examples/generate_phonemes.py
index 6b261a8..0c6b8c2 100644
--- a/examples/phoneme_examples/generate_phonemes.py
+++ b/examples/phoneme_examples/generate_phonemes.py
@@ -46,17 +46,29 @@ def generate_audio_from_phonemes(
         WAV audio bytes if successful, None if failed
     """
     # Create the request payload
-    payload = {"phonemes": phonemes, "voice": voice, "speed": speed}
+    payload = {
+        "phonemes": phonemes,
+        "voice": voice,
+        "speed": speed,
+        "stitch_long_content": True  # Set to False to get the error message instead. NOTE(review): not a field of GenerateFromPhonemesRequest in this patch - confirm the server reads it
+    }
 
-    # Make POST request to generate audio
-    response = requests.post(
-        "http://localhost:8880/text/generate_from_phonemes", json=payload
-    )
-
-    # Raise exception for error status codes
-    response.raise_for_status()
-
-    return response.content
+    try:
+        # Make POST request to generate audio
+        response = requests.post(
+            "http://localhost:8880/text/generate_from_phonemes", json=payload
+        )
+        response.raise_for_status()
+        return response.content
+    except requests.HTTPError as e:
+        # Get the error details from the response
+        try:
+            error_details = response.json()
+            error_msg = error_details.get('detail', {}).get('message', str(e))
+            print(f"Server Error: {error_msg}")
+        except (ValueError, AttributeError):  # body is not JSON, or not the expected dict shape
+            print(f"Error: {e}")
+        return None
 
 
 def main():
@@ -66,7 +78,13 @@ def main():
         "How are you today? I am doing reasonably well, thank you for asking",
         """This is a test of the phoneme generation system. Do not be alarmed.
         This is only a test. If this were a real phoneme emergency, '
-        you would be instructed to a phoneme shelter in your area.""",
+        you would be instructed to a phoneme shelter in your area. Repeat.
+        This is a test of the phoneme generation system. Do not be alarmed.
+        This is only a test. If this were a real phoneme emergency, '
+        you would be instructed to a phoneme shelter in your area. Repeat.
+        This is a test of the phoneme generation system. Do not be alarmed.
+        This is only a test. If this were a real phoneme emergency, '
+        you would be instructed to a phoneme shelter in your area""",
     ]
 
     print("Generating phonemes and audio for example texts...\n")
@@ -85,6 +103,9 @@ def main():
 
         # Generate audio from phonemes
         print("Generating audio...")
+        if len(phonemes) > 500: # split into arrays of 500 phonemes
+            phonemes = [phonemes[i:i+500] for i in range(0, len(phonemes), 500)]
+
         audio_bytes = generate_audio_from_phonemes(phonemes)
 
         if audio_bytes:
diff --git a/examples/speech.mp3 b/examples/speech.mp3
index 749f647..c21bec0 100644
Binary files a/examples/speech.mp3 and b/examples/speech.mp3 differ