import io
import os
import threading
import time
from typing import List, Optional, Tuple

import numpy as np
import scipy.io.wavfile as wavfile
import torch
from sqlalchemy.orm import Session, sessionmaker

from models import build_model
from kokoro import generate, phonemize, tokenize

from ..database.models import TTSQueue
from ..database.queue import QueueDB
from ..models.schemas import TTSStatus


class TTSModel:
    """Process-wide singleton holding the Kokoro TTS model and cached voicepacks.

    The model is loaded lazily on first use (double-checked locking so the
    expensive ``build_model`` call happens at most once even under
    concurrent access). Voicepacks are cached per voice name in
    ``_voicepacks``.
    """

    _instance = None            # (model, device) tuple once loaded
    _lock = threading.Lock()    # guards first-time model construction
    _voicepacks = {}            # voice name -> loaded voicepack tensor

    @classmethod
    def get_instance(cls):
        """Return the shared ``(model, device)`` pair, loading it on first call.

        Device is CUDA when available, otherwise CPU.
        """
        if cls._instance is None:
            with cls._lock:
                # Re-check inside the lock: another thread may have won the race.
                if cls._instance is None:
                    device = "cuda" if torch.cuda.is_available() else "cpu"
                    print(f"Initializing model on {device}")
                    model = build_model("kokoro-v0_19.pth", device)
                    cls._instance = (model, device)
        return cls._instance

    @classmethod
    def get_voicepack(cls, voice_name: str) -> torch.Tensor:
        """Return the voicepack tensor for ``voice_name``, caching the load.

        Voicepacks are loaded from ``voices/<name>.pt`` onto the model's
        device. On a load failure, falls back to the default voice "af";
        if "af" itself fails to load, the original exception is re-raised.
        """
        model, device = cls.get_instance()
        if voice_name not in cls._voicepacks:
            try:
                voicepack = torch.load(
                    f"voices/{voice_name}.pt",
                    map_location=device,
                    weights_only=True,  # safe load: tensors only, no arbitrary pickle
                )
                cls._voicepacks[voice_name] = voicepack
            except Exception as e:
                print(f"Error loading voice {voice_name}: {str(e)}")
                if voice_name != "af":
                    # Fall back to the default voice rather than failing the request.
                    return cls.get_voicepack("af")
                raise
        return cls._voicepacks[voice_name]


class TTSService:
    """Queue-backed TTS service.

    Requests are persisted through ``QueueDB``; a daemon worker thread polls
    the queue, synthesizes audio with the shared ``TTSModel``, writes WAV
    files into ``output_dir``, and records completion status and timing.
    """

    def __init__(self, db: Session, output_dir: Optional[str] = None):
        """Set up output directory, queue access, and the background worker.

        Args:
            db: SQLAlchemy session used for queue persistence; its bound
                engine is also kept so the worker thread can open its own
                sessions (sessions are not thread-safe, engines are).
            output_dir: Directory for generated WAV files. Defaults to
                ``../output`` relative to this module.
        """
        if output_dir is None:
            output_dir = os.path.join(os.path.dirname(__file__), "..", "output")
        self.output_dir = output_dir
        self.db = QueueDB(db)
        self.engine = db.get_bind()  # Get engine from session
        os.makedirs(output_dir, exist_ok=True)
        self._start_worker()

    def _start_worker(self):
        """Start the background worker thread (daemon, so it dies with the process)."""
        self.worker = threading.Thread(target=self._process_queue, daemon=True)
        self.worker.start()

    def _find_boundary(self, text: str, max_tokens: int, voice: str,
                       margin: int = 50) -> int:
        """Find the closest sentence/clause boundary within the token limit.

        Tries boundary markers in order of preference (sentence end, then
        semicolon, then comma), searching up to ``margin`` characters past
        ``max_tokens`` and verifying the candidate chunk actually tokenizes
        within the limit. Falls back to the last whitespace, then to a hard
        cut at ``max_tokens``.

        Returns:
            Character index in ``text`` at which to split.
        """
        # Try different boundary markers in order of preference
        for marker in ['. ', '; ', ', ']:
            # Look for the last occurrence of marker before max_tokens
            test_text = text[:max_tokens + margin]  # Look a bit beyond the limit
            last_idx = test_text.rfind(marker)
            if last_idx != -1:
                # Verify this boundary is within our token limit
                candidate = text[:last_idx + len(marker)].strip()
                ps = phonemize(candidate, voice[0])
                tokens = tokenize(ps)
                if len(tokens) <= max_tokens:
                    return last_idx + len(marker)

        # If no good boundary found, find last whitespace within limit
        test_text = text[:max_tokens]
        last_space = test_text.rfind(' ')
        return last_space if last_space != -1 else max_tokens

    def _split_text(self, text: str, voice: str) -> List[str]:
        """Split text into chunks that respect the tokenizer's limit.

        Tries to split on sentence/clause boundaries (via
        ``_find_boundary``) so chunks sound natural when concatenated.
        """
        # Leave wide margin from the 510-token limit to account for
        # tokenizer differences between the boundary probe and generation.
        MAX_TOKENS = 450
        chunks = []
        remaining = text

        while remaining:
            # If remaining text is within limit, add it as the final chunk
            ps = phonemize(remaining, voice[0])
            tokens = tokenize(ps)
            if len(tokens) <= MAX_TOKENS:
                chunks.append(remaining.strip())
                break

            # Find best boundary position and continue with the remainder
            split_pos = self._find_boundary(remaining, MAX_TOKENS, voice)
            chunks.append(remaining[:split_pos].strip())
            remaining = remaining[split_pos:].strip()

        return chunks

    def _generate_audio(self, text: str, voice: str, speed: float,
                        stitch_long_output: bool = True) -> Tuple[torch.Tensor, float]:
        """Generate audio for ``text`` and measure processing time.

        Args:
            text: Input text to synthesize.
            voice: Voice name; its first character selects the language.
            speed: Playback speed factor passed to the generator.
            stitch_long_output: When True, long text is split into
                token-limited chunks and the audio is concatenated.

        Returns:
            Tuple of (audio array, elapsed seconds).
        """
        start_time = time.time()

        # Get model instance and voicepack
        model, device = TTSModel.get_instance()
        voicepack = TTSModel.get_voicepack(voice)

        if stitch_long_output:
            # Split text if needed and generate audio for each chunk
            chunks = self._split_text(text, voice)
            audio_chunks = []
            for chunk in chunks:
                chunk_audio, _ = generate(model, chunk, voicepack,
                                          lang=voice[0], speed=speed)
                audio_chunks.append(chunk_audio)

            # Concatenate audio chunks
            if len(audio_chunks) > 1:
                audio = np.concatenate(audio_chunks)
            else:
                audio = audio_chunks[0]
        else:
            # Generate single chunk without splitting.
            # BUGFIX: speed was previously dropped on this path, so
            # stitch_long_output=False silently ignored the requested speed.
            audio, _ = generate(model, text, voicepack, lang=voice[0], speed=speed)

        processing_time = time.time() - start_time
        return audio, processing_time

    def _save_audio(self, audio: torch.Tensor, filepath: str):
        """Write audio to ``filepath`` as a 24 kHz WAV, creating parent dirs."""
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        wavfile.write(filepath, 24000, audio)

    def _audio_to_bytes(self, audio: torch.Tensor) -> bytes:
        """Serialize audio to in-memory WAV bytes (24 kHz)."""
        buffer = io.BytesIO()
        wavfile.write(buffer, 24000, audio)
        return buffer.getvalue()

    def _process_queue(self):
        """Background worker loop that processes queued TTS requests.

        Runs forever in the daemon thread. Each iteration opens a fresh
        session (the constructor's session belongs to the main thread), pops
        the next pending request, synthesizes and saves its audio, and marks
        it COMPLETED — or FAILED on any error — then sleeps briefly.
        """
        # Session factory bound to the engine captured in __init__
        Session = sessionmaker(bind=self.engine)

        while True:
            # Create a new session for each iteration
            with Session() as session:
                db = QueueDB(session)
                request = db.get_next_pending()
                if request:
                    try:
                        # Generate audio and measure time
                        audio, processing_time = self._generate_audio(
                            request.text,
                            request.voice,
                            request.speed,
                            request.stitch_long_output,
                        )

                        # Save to file
                        output_file = os.path.abspath(os.path.join(
                            self.output_dir, f"speech_{request.id}.wav"
                        ))
                        self._save_audio(audio, output_file)

                        # Update status with processing time
                        db.update_status(
                            request.id,
                            TTSStatus.COMPLETED,
                            output_file=output_file,
                            processing_time=processing_time,
                        )
                    except Exception as e:
                        print(f"Error processing request {request.id}: {str(e)}")
                        db.update_status(request.id, TTSStatus.FAILED)

            time.sleep(1)  # Prevent busy waiting

    def list_voices(self) -> list[str]:
        """List all available voice names found in the ``voices`` directory.

        Returns an empty list (after logging) if the directory is unreadable.
        """
        voices = []
        try:
            for file in os.listdir("voices"):
                if file.endswith(".pt"):
                    voice_name = file[:-3]  # Remove .pt extension
                    voices.append(voice_name)
        except Exception as e:
            print(f"Error listing voices: {str(e)}")
        return voices

    def create_tts_request(self, text: str, voice: str = "af", speed: float = 1.0,
                           stitch_long_output: bool = True) -> int:
        """Create a new TTS request and return the request ID."""
        return self.db.add_request(text, voice, speed, stitch_long_output)

    def get_request_status(self, request_id: int) -> Optional[TTSQueue]:
        """Get the full request row for ``request_id`` (None if unknown)."""
        return self.db.get_status(request_id)