import io
import os
import re
import time
from functools import lru_cache
from typing import List, Optional, Tuple

import aiofiles.os
import numpy as np
import scipy.io.wavfile as wavfile
import torch
from loguru import logger

from ..core.config import settings
from .audio import AudioNormalizer, AudioService
from .text_processing import chunker, normalize_text
from .tts_model import TTSModel


class TTSService:
    def __init__(self, output_dir: Optional[str] = None):
        self.output_dir = output_dir
        self.model = TTSModel.get_instance()

    @staticmethod
    @lru_cache(maxsize=3)  # Cache up to 3 most recently used voices
    def _load_voice(voice_path: str) -> torch.Tensor:
        """Load and cache a voice model"""
        return torch.load(
            voice_path, map_location=TTSModel.get_device(), weights_only=True
        )

    def _get_voice_path(self, voice_name: str) -> Optional[str]:
        """Get the path to a voice file"""
        voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice_name}.pt")
        return voice_path if os.path.exists(voice_path) else None

    def _generate_audio(
        self, text: str, voice: str, speed: float, stitch_long_output: bool = True
    ) -> Tuple[torch.Tensor, float]:
        """Generate complete audio and return with processing time"""
        audio, processing_time = self._generate_audio_internal(
            text, voice, speed, stitch_long_output
        )
        return audio, processing_time

    def _generate_audio_internal(
        self, text: str, voice: str, speed: float, stitch_long_output: bool = True
    ) -> Tuple[torch.Tensor, float]:
        """Generate audio and measure processing time"""
        start_time = time.time()

        try:
            # Normalize text once at the start
            if not text:
                raise ValueError("Text is empty")
            normalized = normalize_text(text)
            if not normalized:
                raise ValueError("Text is empty after preprocessing")
            text = str(normalized)

            # Check voice exists
            voice_path = self._get_voice_path(voice)
            if not voice_path:
                raise ValueError(f"Voice not found: {voice}")

            # Load voice using cached loader
            voicepack = self._load_voice(voice_path)

            # For non-streaming, preprocess all chunks first
            if stitch_long_output:
                # Preprocess all chunks to phonemes/tokens
                chunks_data = []
                for chunk in chunker.split_text(text):
                    try:
                        phonemes, tokens = TTSModel.process_text(chunk, voice[0])
                        chunks_data.append((chunk, tokens))
                    except Exception as e:
                        logger.error(
                            f"Failed to process chunk: '{chunk}'. Error: {str(e)}"
                        )
                        continue

                if not chunks_data:
                    raise ValueError("No chunks were processed successfully")

                # Generate audio for all chunks
                audio_chunks = []
                for chunk, tokens in chunks_data:
                    try:
                        chunk_audio = TTSModel.generate_from_tokens(
                            tokens, voicepack, speed
                        )
                        if chunk_audio is not None:
                            audio_chunks.append(chunk_audio)
                        else:
                            logger.error(f"No audio generated for chunk: '{chunk}'")
                    except Exception as e:
                        logger.error(
                            f"Failed to generate audio for chunk: '{chunk}'. Error: {str(e)}"
                        )
                        continue

                if not audio_chunks:
                    raise ValueError("No audio chunks were generated successfully")

                # Concatenate all chunks
                audio = (
                    np.concatenate(audio_chunks)
                    if len(audio_chunks) > 1
                    else audio_chunks[0]
                )
            else:
                # Process single chunk
                phonemes, tokens = TTSModel.process_text(text, voice[0])
                audio = TTSModel.generate_from_tokens(tokens, voicepack, speed)

            processing_time = time.time() - start_time
            return audio, processing_time

        except Exception as e:
            logger.error(f"Error in audio generation: {str(e)}")
            raise

    async def generate_audio_stream(
        self,
        text: str,
        voice: str,
        speed: float,
        output_format: str = "wav",
        silent=False,
    ):
        """Generate and yield audio chunks as they're generated for real-time streaming"""
        try:
            stream_start = time.time()
            # Create normalizer for consistent audio levels
            stream_normalizer = AudioNormalizer()

            # Input validation and preprocessing
            if not text:
                raise ValueError("Text is empty")
            preprocess_start = time.time()
            normalized = normalize_text(text)
            if not normalized:
                raise ValueError("Text is empty after preprocessing")
            text = str(normalized)
            logger.debug(
                f"Text preprocessing took: {(time.time() - preprocess_start)*1000:.1f}ms"
            )

            # Voice validation and loading
            voice_start = time.time()
            voice_path = self._get_voice_path(voice)
            if not voice_path:
                raise ValueError(f"Voice not found: {voice}")
            voicepack = self._load_voice(voice_path)
            logger.debug(
                f"Voice loading took: {(time.time() - voice_start)*1000:.1f}ms"
            )

            # Process chunks as they're generated
            is_first = True
            chunks_processed = 0

            # Process chunks as they come from generator
            chunk_gen = chunker.split_text(text)
            current_chunk = next(chunk_gen, None)

            while current_chunk is not None:
                next_chunk = next(chunk_gen, None)  # Peek at next chunk
                chunks_processed += 1
                try:
                    # Process text and generate audio
                    phonemes, tokens = TTSModel.process_text(current_chunk, voice[0])
                    chunk_audio = TTSModel.generate_from_tokens(
                        tokens, voicepack, speed
                    )

                    if chunk_audio is not None:
                        # Convert chunk with proper streaming header handling
                        chunk_bytes = AudioService.convert_audio(
                            chunk_audio,
                            24000,
                            output_format,
                            is_first_chunk=is_first,
                            normalizer=stream_normalizer,
                            is_last_chunk=(next_chunk is None),  # Last if no next chunk
                            stream=True,  # Ensure proper streaming format handling
                        )

                        yield chunk_bytes
                        is_first = False
                    else:
                        logger.error(
                            f"No audio generated for chunk: '{current_chunk}'"
                        )

                except Exception as e:
                    logger.error(
                        f"Failed to generate audio for chunk: '{current_chunk}'. Error: {str(e)}"
                    )

                current_chunk = next_chunk  # Move to next chunk

        except Exception as e:
            logger.error(f"Error in audio generation stream: {str(e)}")
            raise

    def _save_audio(self, audio: torch.Tensor, filepath: str):
        """Save audio to file"""
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        wavfile.write(filepath, 24000, audio)

    def _audio_to_bytes(self, audio: torch.Tensor) -> bytes:
        """Convert audio tensor to WAV bytes"""
        buffer = io.BytesIO()
        wavfile.write(buffer, 24000, audio)
        return buffer.getvalue()

    async def combine_voices(self, voices: List[str]) -> str:
        """Combine multiple voices into a new voice"""
        if len(voices) < 2:
            raise ValueError("At least 2 voices are required for combination")

        # Load voices
        t_voices: List[torch.Tensor] = []
        v_name: List[str] = []

        for voice in voices:
            try:
                voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice}.pt")
                voicepack = torch.load(
                    voice_path, map_location=TTSModel.get_device(), weights_only=True
                )
                t_voices.append(voicepack)
                v_name.append(voice)
            except Exception as e:
                raise ValueError(f"Failed to load voice {voice}: {str(e)}")

        # Combine voices
        try:
            f: str = "_".join(v_name)
            v = torch.mean(torch.stack(t_voices), dim=0)
            combined_path = os.path.join(TTSModel.VOICES_DIR, f"{f}.pt")

            # Save combined voice
            try:
                torch.save(v, combined_path)
            except Exception as e:
                raise RuntimeError(
                    f"Failed to save combined voice to {combined_path}: {str(e)}"
                )

            return f

        except Exception as e:
            if not isinstance(e, (ValueError, RuntimeError)):
                raise RuntimeError(f"Error combining voices: {str(e)}")
            raise

    async def list_voices(self) -> List[str]:
        """List all available voices"""
        voices = []
        try:
            it = await aiofiles.os.scandir(TTSModel.VOICES_DIR)
            for entry in it:
                if entry.name.endswith(".pt"):
                    voices.append(entry.name[:-3])  # Remove .pt extension
        except Exception as e:
            logger.error(f"Error listing voices: {str(e)}")
        return sorted(voices)