Kokoro-FastAPI/api/src/services/tts.py

155 lines
5.5 KiB
Python

import io
import os
import re
import time
import threading
from typing import List, Tuple
import numpy as np
import torch
import tiktoken
import scipy.io.wavfile as wavfile
from kokoro import generate, tokenize, phonemize, normalize_text
from loguru import logger
from models import build_model
from ..core.config import settings
# Module-level tiktoken encoder for the "cl100k_base" encoding, created once at import time.
# NOTE(review): `enc` is not referenced anywhere else in the visible part of this
# file — confirm it is still needed before keeping the import-time cost.
enc = tiktoken.get_encoding("cl100k_base")
class TTSModel:
    """Class-level holder for the shared TTS model and loaded voicepacks.

    The class is never instantiated; all state lives in class attributes and
    is accessed through the classmethods below.
    """

    # (model, device) tuple once the model has been built; None until then.
    _instance = None
    # Serializes first-time model construction in get_instance().
    _lock = threading.Lock()
    # Cache: voice name -> object returned by torch.load() for that voice file.
    _voicepacks = {}

    @classmethod
    def get_instance(cls):
        """Return the shared ``(model, device)`` tuple, building it on first use.

        The device is ``"cuda"`` when ``torch.cuda.is_available()`` is true,
        otherwise ``"cpu"``. The model file is located at
        ``settings.model_dir / settings.model_path`` and built with
        ``build_model``.

        Returns:
            The cached ``(model, device)`` tuple.
        """
        if cls._instance is None:
            with cls._lock:
                # Re-check under the lock: another thread may have finished
                # building the model while this one was waiting for the lock.
                if cls._instance is None:
                    device = "cuda" if torch.cuda.is_available() else "cpu"
                    logger.info(f"Initializing model on {device}")
                    model_path = os.path.join(settings.model_dir, settings.model_path)
                    model = build_model(model_path, device)
                    # Note: RNN memory optimization is handled internally by the model
                    cls._instance = (model, device)
        return cls._instance

    @classmethod
    def get_voicepack(cls, voice_name: str) -> torch.Tensor:
        """Return the voicepack for ``voice_name``, loading and caching it if needed.

        The voice file is read from
        ``settings.model_dir / settings.voices_dir / "<voice_name>.pt"`` and
        mapped onto the device chosen by :meth:`get_instance`.

        If loading fails for any voice other than ``"af"``, the error is logged
        and the ``"af"`` voicepack is returned instead. The failed name is not
        cached, so a later call with the same name retries the load.

        Args:
            voice_name: Base name of the voice file, without the ``.pt`` suffix.
                NOTE(review): the name is joined into a filesystem path as-is;
                confirm callers validate it if it can come from a request.

        Returns:
            The object produced by ``torch.load`` for the voice file.

        Raises:
            Exception: Whatever ``torch.load`` raised, when the ``"af"`` voice
                itself cannot be loaded.
        """
        # Always resolve the model first: this guarantees the device is known
        # (and the model is built) before any voice is loaded.
        _, device = cls.get_instance()
        if voice_name not in cls._voicepacks:
            try:
                voice_path = os.path.join(
                    settings.model_dir, settings.voices_dir, f"{voice_name}.pt"
                )
                voicepack = torch.load(
                    voice_path, map_location=device, weights_only=True
                )
                cls._voicepacks[voice_name] = voicepack
            except Exception as e:
                logger.error(f"Error loading voice {voice_name}: {str(e)}")
                # Fall back to the "af" voice; if "af" itself is the one that
                # failed there is nothing left to fall back to, so re-raise.
                if voice_name != "af":
                    return cls.get_voicepack("af")
                raise
        return cls._voicepacks[voice_name]
class TTSService:
    """Text-to-speech service: splits text, generates audio, and serializes it to WAV."""

    # Sample rate (Hz) written into every WAV produced by this service.
    SAMPLE_RATE = 24000
    # Split after '.', '!' or '?' when followed by whitespace; the punctuation
    # stays attached to the preceding sentence because of the lookbehind.
    _SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?])\s+")

    def __init__(self, output_dir: str = None, start_worker: bool = False):
        """Store the output directory and optionally start a worker.

        Args:
            output_dir: Stored on the instance as ``self.output_dir``.
            start_worker: When true, ``self.start_worker()`` is called.
                NOTE(review): no ``start_worker`` method is defined in the
                visible part of this class, so passing ``True`` raises
                ``AttributeError`` unless a subclass provides it — confirm.
        """
        self.output_dir = output_dir
        if start_worker:
            self.start_worker()

    def _split_text(self, text: str) -> List[str]:
        """Split text into sentences, dropping empty pieces.

        Args:
            text: Text to split.

        Returns:
            Stripped, non-empty sentences in their original order.
        """
        pieces = (piece.strip() for piece in self._SENTENCE_BOUNDARY.split(text))
        return [piece for piece in pieces if piece]

    def _generate_audio(
        self, text: str, voice: str, speed: float, stitch_long_output: bool = True
    ) -> Tuple[np.ndarray, float]:
        """Generate audio for ``text`` and measure the processing time.

        Args:
            text: Input text; it is passed through ``normalize_text`` first.
            voice: Voice name. Its first character is used as the language
                code for ``phonemize`` and ``generate``.
            speed: Speed value forwarded to ``generate``.
            stitch_long_output: When true, the text is split into sentences,
                each sentence is generated separately, and the results are
                concatenated. Sentences that fail are logged and skipped.

        Returns:
            ``(audio, processing_time)`` where ``processing_time`` is the
            elapsed time in seconds. ``audio`` is whatever ``generate``
            produced (presumably a NumPy array — confirm against ``kokoro``),
            or the ``np.concatenate`` of several such results.

        Raises:
            ValueError: If the text is empty after normalization, or if no
                chunk produced any audio in stitching mode.
            Exception: Any other error is logged and re-raised unchanged.
        """
        # perf_counter is monotonic, so the duration cannot be skewed by
        # wall-clock adjustments while generation is running.
        start_time = time.perf_counter()
        try:
            # Normalize text once at the start
            text = normalize_text(text)
            if not text:
                raise ValueError("Text is empty after preprocessing")

            # Get model instance and voicepack
            model, device = TTSModel.get_instance()
            voicepack = TTSModel.get_voicepack(voice)

            # Generate audio with or without stitching
            if stitch_long_output:
                chunks = self._split_text(text)
                total = len(chunks)
                audio_chunks = []
                for i, chunk in enumerate(chunks):
                    try:
                        # Validate phonemization first: a failure here skips
                        # the chunk before the model is invoked.
                        ps = phonemize(chunk, voice[0])
                        tokens = tokenize(ps)
                        logger.debug(
                            f"Processing chunk {i+1}/{total}: {len(tokens)} tokens"
                        )
                        # Only proceed if phonemization succeeded
                        chunk_audio, _ = generate(
                            model, chunk, voicepack, lang=voice[0], speed=speed
                        )
                        if chunk_audio is not None:
                            audio_chunks.append(chunk_audio)
                        else:
                            logger.error(
                                f"No audio generated for chunk {i+1}/{total}"
                            )
                    except Exception as e:
                        # Best effort: one bad sentence must not discard the
                        # audio already generated for the others.
                        logger.error(
                            f"Failed to generate audio for chunk {i+1}/{total}: '{chunk}'. Error: {str(e)}"
                        )
                        continue

                if not audio_chunks:
                    raise ValueError("No audio chunks were generated successfully")

                audio = (
                    np.concatenate(audio_chunks)
                    if len(audio_chunks) > 1
                    else audio_chunks[0]
                )
            else:
                audio, _ = generate(model, text, voicepack, lang=voice[0], speed=speed)

            processing_time = time.perf_counter() - start_time
            return audio, processing_time
        except Exception as e:
            logger.error(f"Error in audio generation: {str(e)}")
            raise

    def _save_audio(self, audio: np.ndarray, filepath: str):
        """Write ``audio`` to ``filepath`` as a WAV file.

        Missing parent directories are created. A bare filename (no directory
        component) is written relative to the current working directory.

        Args:
            audio: Samples to write; passed unchanged to ``wavfile.write``.
            filepath: Destination path of the WAV file.
        """
        parent = os.path.dirname(filepath)
        # os.makedirs("") raises FileNotFoundError, so only create the parent
        # when the path actually has a directory component.
        if parent:
            os.makedirs(parent, exist_ok=True)
        wavfile.write(filepath, self.SAMPLE_RATE, audio)

    def _audio_to_bytes(self, audio: np.ndarray) -> bytes:
        """Serialize ``audio`` to in-memory WAV bytes.

        Args:
            audio: Samples to encode; passed unchanged to ``wavfile.write``.

        Returns:
            The complete WAV file contents.
        """
        buffer = io.BytesIO()
        wavfile.write(buffer, self.SAMPLE_RATE, audio)
        return buffer.getvalue()

    def list_voices(self) -> list[str]:
        """List all available voices.

        Looks for ``*.pt`` files in ``settings.model_dir / settings.voices_dir``.

        Returns:
            Voice names (file names without the ``.pt`` suffix), sorted so the
            result does not depend on directory listing order. An empty list
            is returned, and the error logged, if the directory cannot be read.
        """
        suffix = ".pt"
        try:
            voices_path = os.path.join(settings.model_dir, settings.voices_dir)
            return sorted(
                entry[: -len(suffix)]
                for entry in os.listdir(voices_path)
                if entry.endswith(suffix)
            )
        except Exception as e:
            logger.error(f"Error listing voices: {str(e)}")
            return []