import io
import os
import re
import time
from functools import lru_cache
from typing import List, Optional, Tuple

import aiofiles.os
import numpy as np
import scipy.io.wavfile as wavfile
import torch
from loguru import logger

from ..core.config import settings
from .audio import AudioNormalizer, AudioService
from .text_processing import chunker, normalize_text
from .tts_model import TTSModel


class TTSService:
    """Text-to-speech service built on the shared TTSModel instance."""

    def __init__(self, output_dir: Optional[str] = None):
        self.output_dir = output_dir
        self.model = TTSModel.get_instance()

    @staticmethod
    @lru_cache(maxsize=3)  # Cache up to 3 most recently used voices
    def _load_voice(voice_path: str) -> torch.Tensor:
        """Load and cache a voice model"""
        return torch.load(
            voice_path, map_location=TTSModel.get_device(), weights_only=True
        )

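    # Note: the lru_cache above is keyed by voice_path and shared process-wide,
    # so a voice file replaced on disk is not reloaded until its entry is
    # evicted (maxsize=3) or the process restarts.
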
    def _get_voice_path(self, voice_name: str) -> Optional[str]:
        """Get the path to a voice file"""
        voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice_name}.pt")
        return voice_path if os.path.exists(voice_path) else None

    def _generate_audio(
        self, text: str, voice: str, speed: float, stitch_long_output: bool = True
    ) -> Tuple[torch.Tensor, float]:
        """Generate complete audio and return with processing time"""
        audio, processing_time = self._generate_audio_internal(
            text, voice, speed, stitch_long_output
        )
        return audio, processing_time

    def _generate_audio_internal(
        self, text: str, voice: str, speed: float, stitch_long_output: bool = True
    ) -> Tuple[torch.Tensor, float]:
        """Generate audio and measure processing time"""
        start_time = time.time()

        try:
            # Validate, then normalize text once at the start
            if not text:
                raise ValueError("Text is empty")

            normalized = normalize_text(text)
            if not normalized:
                raise ValueError("Text is empty after preprocessing")
            text = str(normalized)

            # Check voice exists
            voice_path = self._get_voice_path(voice)
            if not voice_path:
                raise ValueError(f"Voice not found: {voice}")

            # Load voice using cached loader
            voicepack = self._load_voice(voice_path)

            # For non-streaming, preprocess all chunks first
            if stitch_long_output:
                # Preprocess all chunks to phonemes/tokens
                chunks_data = []
                for chunk in chunker.split_text(text):
                    try:
                        # The first letter of the voice name selects the language
                        phonemes, tokens = TTSModel.process_text(chunk, voice[0])
                        chunks_data.append((chunk, tokens))
                    except Exception as e:
                        logger.error(
                            f"Failed to process chunk: '{chunk}'. Error: {str(e)}"
                        )
                        continue

                if not chunks_data:
                    raise ValueError("No chunks were processed successfully")

                # Generate audio for all chunks
                audio_chunks = []
                for chunk, tokens in chunks_data:
                    try:
                        chunk_audio = TTSModel.generate_from_tokens(
                            tokens, voicepack, speed
                        )
                        if chunk_audio is not None:
                            audio_chunks.append(chunk_audio)
                        else:
                            logger.error(f"No audio generated for chunk: '{chunk}'")
                    except Exception as e:
                        logger.error(
                            f"Failed to generate audio for chunk: '{chunk}'. Error: {str(e)}"
                        )
                        continue

                if not audio_chunks:
                    raise ValueError("No audio chunks were generated successfully")

                # Concatenate all chunks
                audio = (
                    np.concatenate(audio_chunks)
                    if len(audio_chunks) > 1
                    else audio_chunks[0]
                )
            else:
                # Process single chunk
                phonemes, tokens = TTSModel.process_text(text, voice[0])
                audio = TTSModel.generate_from_tokens(tokens, voicepack, speed)

            processing_time = time.time() - start_time
            return audio, processing_time

        except Exception as e:
            logger.error(f"Error in audio generation: {str(e)}")
            raise

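    # A minimal usage sketch for the non-streaming path (the voice name and
    # output path are illustrative, assuming a matching .pt file exists in
    # TTSModel.VOICES_DIR):
    #
    #     service = TTSService()
    #     audio, seconds = service._generate_audio(
    #         "Hello world.", voice="af_bella", speed=1.0
    #     )
    #     service._save_audio(audio, "output/hello.wav")
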
    async def generate_audio_stream(
        self,
        text: str,
        voice: str,
        speed: float,
        output_format: str = "wav",
        silent: bool = False,  # not used in this method
    ):
        """Generate and yield audio chunks as they're generated for real-time streaming"""
        try:
            stream_start = time.time()
            # Create normalizer for consistent audio levels
            stream_normalizer = AudioNormalizer()

            # Input validation and preprocessing
            if not text:
                raise ValueError("Text is empty")
            preprocess_start = time.time()
            normalized = normalize_text(text)
            if not normalized:
                raise ValueError("Text is empty after preprocessing")
            text = str(normalized)
            logger.debug(
                f"Text preprocessing took: {(time.time() - preprocess_start)*1000:.1f}ms"
            )

            # Voice validation and loading
            voice_start = time.time()
            voice_path = self._get_voice_path(voice)
            if not voice_path:
                raise ValueError(f"Voice not found: {voice}")
            voicepack = self._load_voice(voice_path)
            logger.debug(
                f"Voice loading took: {(time.time() - voice_start)*1000:.1f}ms"
            )

            # Process chunks as they're generated
            is_first = True
            chunks_processed = 0

            # Process chunks as they come from the generator, peeking one chunk
            # ahead so the final chunk can be flagged for the encoder
            chunk_gen = chunker.split_text(text)
            current_chunk = next(chunk_gen, None)

            while current_chunk is not None:
                next_chunk = next(chunk_gen, None)  # Peek at next chunk
                chunks_processed += 1
                try:
                    # Process text and generate audio
                    phonemes, tokens = TTSModel.process_text(current_chunk, voice[0])
                    chunk_audio = TTSModel.generate_from_tokens(
                        tokens, voicepack, speed
                    )

                    if chunk_audio is not None:
                        # Convert chunk with proper streaming header handling
                        chunk_bytes = AudioService.convert_audio(
                            chunk_audio,
                            24000,
                            output_format,
                            is_first_chunk=is_first,
                            normalizer=stream_normalizer,
                            is_last_chunk=(next_chunk is None),  # Last if no next chunk
                            stream=True,  # Ensure proper streaming format handling
                        )

                        yield chunk_bytes
                        is_first = False
                    else:
                        logger.error(f"No audio generated for chunk: '{current_chunk}'")

                except Exception as e:
                    logger.error(
                        f"Failed to generate audio for chunk: '{current_chunk}'. Error: {str(e)}"
                    )

                current_chunk = next_chunk  # Move to next chunk

        except Exception as e:
            logger.error(f"Error in audio generation stream: {str(e)}")
            raise

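    # A minimal consumption sketch for the streaming path, e.g. from an async
    # route handler (the file name and voice are assumptions for illustration):
    #
    #     async def demo():
    #         service = TTSService()
    #         with open("stream.wav", "wb") as out:
    #             async for chunk in service.generate_audio_stream(
    #                 "Hello world.", voice="af_bella", speed=1.0, output_format="wav"
    #             ):
    #                 out.write(chunk)
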
    def _save_audio(self, audio: torch.Tensor, filepath: str):
        """Save audio to file"""
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        wavfile.write(filepath, 24000, audio)

    def _audio_to_bytes(self, audio: torch.Tensor) -> bytes:
        """Convert audio tensor to WAV bytes"""
        buffer = io.BytesIO()
        wavfile.write(buffer, 24000, audio)
        return buffer.getvalue()

    async def combine_voices(self, voices: List[str]) -> str:
        """Combine multiple voices into a new voice"""
        if len(voices) < 2:
            raise ValueError("At least 2 voices are required for combination")

        # Load voices
        t_voices: List[torch.Tensor] = []
        v_name: List[str] = []

        for voice in voices:
            try:
                voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice}.pt")
                voicepack = torch.load(
                    voice_path, map_location=TTSModel.get_device(), weights_only=True
                )
                t_voices.append(voicepack)
                v_name.append(voice)
            except Exception as e:
                raise ValueError(f"Failed to load voice {voice}: {str(e)}")

        # Combine voices by averaging the voicepack tensors
        try:
            combined_name = "_".join(v_name)
            combined_tensor = torch.mean(torch.stack(t_voices), dim=0)
            combined_path = os.path.join(TTSModel.VOICES_DIR, f"{combined_name}.pt")

            # Save combined voice
            try:
                torch.save(combined_tensor, combined_path)
            except Exception as e:
                raise RuntimeError(
                    f"Failed to save combined voice to {combined_path}: {str(e)}"
                )

            return combined_name

        except Exception as e:
            if not isinstance(e, (ValueError, RuntimeError)):
                raise RuntimeError(f"Error combining voices: {str(e)}")
            raise

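    # Usage sketch (voice names are illustrative): averaging two voicepacks
    # saves "af_bella_af_sarah.pt" alongside the originals and returns the
    # combined name for use in later generation calls:
    #
    #     name = await service.combine_voices(["af_bella", "af_sarah"])
    #     audio, _ = service._generate_audio("Hi there.", voice=name, speed=1.0)
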
    async def list_voices(self) -> List[str]:
        """List all available voices"""
        voices = []
        try:
            it = await aiofiles.os.scandir(TTSModel.VOICES_DIR)
            for entry in it:
                if entry.name.endswith(".pt"):
                    voices.append(entry.name[:-3])  # Remove .pt extension
        except Exception as e:
            logger.error(f"Error listing voices: {str(e)}")
        return sorted(voices)
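

# A minimal smoke-test sketch. Because this module uses relative imports it
# must be run via its package (the exact module path below is an assumption),
# with at least one .pt voice present in TTSModel.VOICES_DIR:
# `python -m api.src.services.tts_service`
if __name__ == "__main__":
    import asyncio

    async def _demo():
        service = TTSService()
        names = await service.list_voices()
        print(f"Available voices: {names}")
        if names:
            audio, seconds = service._generate_audio(
                "Hello from TTSService.", voice=names[0], speed=1.0
            )
            print(f"Generated {len(audio)} samples in {seconds:.2f}s")

    asyncio.run(_demo())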