Kokoro-FastAPI/api/src/services/tts.py

136 lines
4.7 KiB
Python
Raw Normal View History

import os
import threading
import time
import io
from typing import Optional, Tuple
import torch
import scipy.io.wavfile as wavfile
from models import build_model
from kokoro import generate
from ..database.queue import QueueDB
class TTSModel:
_instance = None
_lock = threading.Lock()
_voicepacks = {}
@classmethod
def get_instance(cls):
if cls._instance is None:
with cls._lock:
if cls._instance is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Initializing model on {device}")
model = build_model("kokoro-v0_19.pth", device)
cls._instance = (model, device)
return cls._instance
@classmethod
def get_voicepack(cls, voice_name: str) -> torch.Tensor:
model, device = cls.get_instance()
if voice_name not in cls._voicepacks:
try:
voicepack = torch.load(
f"voices/{voice_name}.pt", map_location=device, weights_only=True
)
cls._voicepacks[voice_name] = voicepack
except Exception as e:
print(f"Error loading voice {voice_name}: {str(e)}")
if voice_name != "af":
return cls.get_voicepack("af")
raise
return cls._voicepacks[voice_name]
class TTSService:
def __init__(self, output_dir: str = None):
if output_dir is None:
output_dir = os.path.join(os.path.dirname(__file__), "..", "output")
self.output_dir = output_dir
self.db = QueueDB()
os.makedirs(output_dir, exist_ok=True)
self._start_worker()
def _start_worker(self):
"""Start the background worker thread"""
self.worker = threading.Thread(target=self._process_queue, daemon=True)
self.worker.start()
def _generate_audio(self, text: str, voice: str) -> Tuple[torch.Tensor, float]:
"""Generate audio and measure processing time"""
start_time = time.time()
# Get model instance and voicepack
model, device = TTSModel.get_instance()
voicepack = TTSModel.get_voicepack(voice)
# Generate audio
audio, _ = generate(model, text, voicepack, lang=voice[0])
processing_time = time.time() - start_time
return audio, processing_time
def _save_audio(self, audio: torch.Tensor, filepath: str):
"""Save audio to file"""
os.makedirs(os.path.dirname(filepath), exist_ok=True)
wavfile.write(filepath, 24000, audio)
def _audio_to_bytes(self, audio: torch.Tensor) -> bytes:
"""Convert audio tensor to WAV bytes"""
buffer = io.BytesIO()
wavfile.write(buffer, 24000, audio)
return buffer.getvalue()
def _process_queue(self):
"""Background worker that processes the queue"""
while True:
next_request = self.db.get_next_pending()
if next_request:
request_id, text, voice = next_request
try:
# Generate audio and measure time
audio, processing_time = self._generate_audio(text, voice)
# Save to file
output_file = os.path.join(
self.output_dir, f"speech_{request_id}.wav"
)
self._save_audio(audio, output_file)
# Update status with processing time
self.db.update_status(
request_id,
"completed",
output_file=output_file,
processing_time=processing_time,
)
except Exception as e:
print(f"Error processing request {request_id}: {str(e)}")
self.db.update_status(request_id, "failed")
time.sleep(1) # Prevent busy waiting
def list_voices(self) -> list[str]:
"""List all available voices"""
voices = []
try:
for file in os.listdir("voices"):
if file.endswith(".pt"):
voice_name = file[:-3] # Remove .pt extension
voices.append(voice_name)
except Exception as e:
print(f"Error listing voices: {str(e)}")
return voices
def create_tts_request(self, text: str, voice: str = "af") -> int:
"""Create a new TTS request and return the request ID"""
return self.db.add_request(text, voice)
def get_request_status(
self, request_id: int
) -> Optional[Tuple[str, Optional[str], Optional[float]]]:
"""Get the status, output file path, and processing time for a request"""
return self.db.get_status(request_id)