# api/src/services/tts.py — TTS service for Kokoro-FastAPI

import io
import os
import re
import threading
import time
from typing import List, Tuple

import numpy as np
import scipy.io.wavfile as wavfile
import tiktoken
import torch
from loguru import logger

from kokoro import generate, normalize_text, phonemize, tokenize
from models import build_model

from ..core.config import settings
# Shared tokenizer for token-count estimates (cl100k_base encoding).
# NOTE(review): `enc` is not referenced anywhere in this file's visible code —
# confirm it is used by another module before removing.
enc = tiktoken.get_encoding("cl100k_base")
class TTSModel:
    """Process-wide singleton holding the TTS model and cached voicepacks."""

    _instance = None  # (model, device) tuple once initialized
    _lock = threading.Lock()  # guards lazy one-time initialization
    _voicepacks = {}  # voice name -> loaded voicepack tensor

    @classmethod
    def get_instance(cls):
        """Return the shared ``(model, device)`` pair, building it on first use.

        Uses double-checked locking so that concurrent first callers build
        the model exactly once.

        Returns:
            Tuple of (model, device string, "cuda" or "cpu").
        """
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    device = "cuda" if torch.cuda.is_available() else "cpu"
                    # Consistency fix: use the module logger instead of print()
                    # (the rest of this file logs via loguru).
                    logger.info(f"Initializing model on {device}")
                    model_path = os.path.join(settings.model_dir, settings.model_path)
                    model = build_model(model_path, device)
                    # Note: RNN memory optimization is handled internally by the model
                    cls._instance = (model, device)
        return cls._instance

    @classmethod
    def get_voicepack(cls, voice_name: str) -> torch.Tensor:
        """Return the voicepack tensor for ``voice_name``, loading and caching it.

        Falls back to the default "af" voice when loading fails; re-raises
        the original error if even the default cannot be loaded.

        Args:
            voice_name: Voice file stem (``<voice_name>.pt`` under the voices dir).

        Raises:
            Exception: Propagated from ``torch.load`` when the "af" fallback
                itself fails to load.
        """
        _, device = cls.get_instance()
        if voice_name not in cls._voicepacks:
            try:
                voice_path = os.path.join(
                    settings.model_dir, settings.voices_dir, f"{voice_name}.pt"
                )
                voicepack = torch.load(
                    voice_path, map_location=device, weights_only=True
                )
                cls._voicepacks[voice_name] = voicepack
            except Exception as e:
                # Consistency fix: logger instead of print().
                logger.error(f"Error loading voice {voice_name}: {str(e)}")
                if voice_name != "af":
                    # Fall back to the default voice; recursion terminates
                    # because the "af" branch raises instead of recursing.
                    return cls.get_voicepack("af")
                raise
        return cls._voicepacks[voice_name]
class TTSService:
    """High-level text-to-speech service built on the shared ``TTSModel``."""

    def __init__(self, output_dir: str = None, start_worker: bool = False):
        """Create the service.

        Args:
            output_dir: Directory for generated audio files (may be None).
            start_worker: If True, start the background worker immediately.
        """
        self.output_dir = output_dir
        if start_worker:
            # NOTE(review): start_worker() is not defined in this view of the
            # file — presumably provided elsewhere; verify before enabling.
            self.start_worker()

    def _split_text(self, text: str) -> List[str]:
        """Split text into sentences on ``.``, ``!`` or ``?`` followed by whitespace."""
        return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]

    def _generate_audio(
        self, text: str, voice: str, speed: float, stitch_long_output: bool = True
    ) -> Tuple[torch.Tensor, float]:
        """Generate audio for ``text`` and measure processing time.

        Args:
            text: Input text; normalized before synthesis.
            voice: Voice name; its first character selects the language.
            speed: Playback speed multiplier passed to ``generate``.
            stitch_long_output: When True, split into sentences, synthesize
                each, and concatenate; failed chunks are skipped.

        Returns:
            Tuple of (audio array, elapsed seconds).

        Raises:
            ValueError: If the text is empty after normalization or no chunk
                produced audio.
        """
        start_time = time.time()
        try:
            # Normalize text once at the start
            text = normalize_text(text)
            if not text:
                raise ValueError("Text is empty after preprocessing")

            # Get model instance and voicepack
            model, device = TTSModel.get_instance()
            voicepack = TTSModel.get_voicepack(voice)

            if stitch_long_output:
                chunks = self._split_text(text)
                audio_chunks = []
                for i, chunk in enumerate(chunks):
                    try:
                        # Validate phonemization first so bad chunks are
                        # skipped before the expensive generate() call.
                        ps = phonemize(chunk, voice[0])
                        tokens = tokenize(ps)
                        logger.debug(
                            f"Processing chunk {i + 1}/{len(chunks)}: {len(tokens)} tokens"
                        )
                        # Only proceed if phonemization succeeded
                        chunk_audio, _ = generate(
                            model, chunk, voicepack, lang=voice[0], speed=speed
                        )
                        if chunk_audio is not None:
                            audio_chunks.append(chunk_audio)
                        else:
                            logger.error(
                                f"No audio generated for chunk {i + 1}/{len(chunks)}"
                            )
                    except Exception as e:
                        logger.error(
                            f"Failed to generate audio for chunk {i + 1}/{len(chunks)}: '{chunk}'. Error: {str(e)}"
                        )
                        continue

                if not audio_chunks:
                    raise ValueError("No audio chunks were generated successfully")
                audio = (
                    np.concatenate(audio_chunks)
                    if len(audio_chunks) > 1
                    else audio_chunks[0]
                )
            else:
                audio, _ = generate(model, text, voicepack, lang=voice[0], speed=speed)

            processing_time = time.time() - start_time
            return audio, processing_time
        except Exception as e:
            # Consistency fix: logger instead of print().
            logger.error(f"Error in audio generation: {str(e)}")
            raise

    def _save_audio(self, audio: torch.Tensor, filepath: str):
        """Write ``audio`` to ``filepath`` as a 24 kHz WAV, creating parent dirs."""
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        wavfile.write(filepath, 24000, audio)

    def _audio_to_bytes(self, audio: torch.Tensor) -> bytes:
        """Serialize ``audio`` to in-memory 24 kHz WAV bytes."""
        buffer = io.BytesIO()
        wavfile.write(buffer, 24000, audio)
        return buffer.getvalue()

    def combine_voices(self, voices: List[str]) -> str:
        """Average two or more voicepacks into a new combined voice.

        The combined tensor is saved as ``<name1>_<name2>[...].pt`` in the
        voices directory and that joined name is returned. Returns "af" when
        fewer than two voices are requested, fewer than two are found, or
        loading fails.
        """
        if len(voices) < 2:
            return "af"
        t_voices: List[torch.Tensor] = []
        v_name: List[str] = []
        # Consistency fix: resolve the voices directory from settings (was
        # hard-coded "voices", unlike list_voices/TTSModel.get_voicepack).
        voices_path = os.path.join(settings.model_dir, settings.voices_dir)
        try:
            for file in os.listdir(voices_path):
                # Robustness fix: only consider .pt files (previously every
                # directory entry was name-stripped and matched).
                if not file.endswith(".pt"):
                    continue
                voice_name = file[:-3]  # Remove .pt extension
                if voice_name in voices:
                    v_name.append(voice_name)
                    t_voices.append(
                        torch.load(os.path.join(voices_path, file), weights_only=True)
                    )
        except Exception as e:
            logger.error(f"Error combining voices: {str(e)}")
            return "af"
        # Robustness fix: torch.stack on <2 tensors previously raised outside
        # the try block when requested voices were missing on disk.
        if len(t_voices) < 2:
            logger.error(f"Fewer than two of the requested voices exist: {voices}")
            return "af"
        f: str = "_".join(v_name)
        v = torch.mean(torch.stack(t_voices), dim=0)
        torch.save(v, os.path.join(voices_path, f"{f}.pt"))
        return f

    def list_voices(self) -> List[str]:
        """List all available voices (``.pt`` files in the voices directory)."""
        voices = []
        try:
            voices_path = os.path.join(settings.model_dir, settings.voices_dir)
            for file in os.listdir(voices_path):
                if file.endswith(".pt"):
                    voices.append(file[:-3])  # Remove .pt extension
        except Exception as e:
            # Consistency fix: logger instead of print().
            logger.error(f"Error listing voices: {str(e)}")
        return voices