mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-04-13 09:39:17 +00:00
136 lines
4.7 KiB
Python
136 lines
4.7 KiB
Python
![]() |
import os
|
||
|
import threading
|
||
|
import time
|
||
|
import io
|
||
|
from typing import Optional, Tuple
|
||
|
import torch
|
||
|
import scipy.io.wavfile as wavfile
|
||
|
from models import build_model
|
||
|
from kokoro import generate
|
||
|
from ..database.queue import QueueDB
|
||
|
|
||
|
|
||
|
class TTSModel:
    """Process-wide holder for the TTS model and a cache of loaded voicepacks.

    All state lives on the class itself: ``_instance`` stores the
    ``(model, device)`` pair once it has been built, and ``_voicepacks`` maps
    a voice name to the tensor loaded for it.
    """

    _instance = None
    _lock = threading.Lock()
    _voicepacks = {}

    @classmethod
    def get_instance(cls):
        """Return the shared ``(model, device)`` pair, building it on first use.

        The device is ``"cuda"`` when ``torch.cuda.is_available()`` and
        ``"cpu"`` otherwise.  Construction uses double-checked locking on
        ``_lock`` so that concurrent first callers build the model only once.
        """
        if cls._instance is None:
            with cls._lock:
                # Re-check under the lock: another thread may have finished
                # building the model while this one was waiting.
                if cls._instance is None:
                    device = "cuda" if torch.cuda.is_available() else "cpu"
                    print(f"Initializing model on {device}")
                    model = build_model("kokoro-v0_19.pth", device)
                    cls._instance = (model, device)
        return cls._instance

    @staticmethod
    def _is_safe_voice_name(voice_name: str) -> bool:
        """Return True when ``voice_name`` is a plain file stem.

        Empty names and names containing a path separator or NUL are
        rejected so that the name cannot address a file outside ``voices/``.
        """
        if not voice_name:
            return False
        return not any(ch in voice_name for ch in ("/", "\\", "\0"))

    @classmethod
    def get_voicepack(cls, voice_name: str) -> torch.Tensor:
        """Return the voicepack for ``voice_name``, loading and caching it.

        The voicepack is read from ``voices/{voice_name}.pt`` (relative to the
        current working directory) onto the model's device.  If the name is
        invalid or loading fails for any reason, the error is printed and the
        ``"af"`` voicepack is returned instead; the fallback result is not
        cached under the requested name.

        Raises:
            Exception: whatever the load raised, when ``"af"`` itself cannot
                be loaded.
        """
        # Only the device is needed here; the call also makes sure the model
        # has been initialised, as the original code did.
        _, device = cls.get_instance()

        if voice_name in cls._voicepacks:
            return cls._voicepacks[voice_name]

        try:
            # voice_name originates from client requests, so refuse anything
            # that could escape the voices directory (e.g. "../secret").
            if not cls._is_safe_voice_name(voice_name):
                raise ValueError(f"invalid voice name: {voice_name!r}")
            voicepack = torch.load(
                f"voices/{voice_name}.pt", map_location=device, weights_only=True
            )
        except Exception as e:
            print(f"Error loading voice {voice_name}: {str(e)}")
            if voice_name != "af":
                return cls.get_voicepack("af")
            # The default voice itself is unavailable: nothing to fall back to.
            raise

        cls._voicepacks[voice_name] = voicepack
        return voicepack
|
||
|
|
||
|
|
||
|
class TTSService:
    """Queue-backed text-to-speech service with one background worker.

    Requests are stored through ``QueueDB``.  A daemon thread started from
    ``__init__`` polls for pending requests, generates audio for each one,
    writes it as a WAV file under ``output_dir`` and records the outcome.
    """

    # Sample rate (Hz) passed to ``scipy.io.wavfile.write`` for every WAV.
    SAMPLE_RATE = 24000
    # Seconds the worker sleeps between two polls of the queue.
    POLL_INTERVAL_SECONDS = 1

    def __init__(self, output_dir: Optional[str] = None):
        """Create the service, its queue database and its worker thread.

        Args:
            output_dir: Directory for generated WAV files.  Defaults to an
                ``output`` directory next to this module's parent package
                directory.  It is created if missing.
        """
        if output_dir is None:
            output_dir = os.path.join(os.path.dirname(__file__), "..", "output")
        self.output_dir = output_dir
        self.db = QueueDB()
        os.makedirs(output_dir, exist_ok=True)
        self._start_worker()

    def _start_worker(self):
        """Start the background worker thread"""
        self.worker = threading.Thread(target=self._process_queue, daemon=True)
        self.worker.start()

    def _generate_audio(self, text: str, voice: str) -> Tuple[torch.Tensor, float]:
        """Generate audio for ``text`` and measure how long it took.

        The measured time covers fetching the model and voicepack (including
        first-use initialisation) as well as the generation call itself.

        Returns:
            ``(audio, processing_time)`` where ``processing_time`` is in
            seconds.
        """
        # perf_counter is monotonic, so the duration cannot be skewed by
        # system clock adjustments the way time.time() can.
        start_time = time.perf_counter()

        model, _ = TTSModel.get_instance()
        voicepack = TTSModel.get_voicepack(voice)

        # The first character of the voice name is passed as the language
        # code.  NOTE(review): an empty ``voice`` raises IndexError here, and
        # when get_voicepack falls back to "af" the language still comes from
        # the requested name -- confirm this is intended.
        audio, _ = generate(model, text, voicepack, lang=voice[0])

        return audio, time.perf_counter() - start_time

    @staticmethod
    def _to_numpy(audio):
        """Return ``audio`` as a NumPy array if it is a torch tensor.

        ``scipy.io.wavfile.write`` needs a NumPy array; anything that is not
        a ``torch.Tensor`` is passed through unchanged.
        """
        if isinstance(audio, torch.Tensor):
            return audio.detach().cpu().numpy()
        return audio

    def _save_audio(self, audio: torch.Tensor, filepath: str):
        """Write ``audio`` to ``filepath`` as a WAV file at ``SAMPLE_RATE``.

        The parent directory is created when ``filepath`` has one.
        """
        parent = os.path.dirname(filepath)
        # A bare filename has an empty dirname, and os.makedirs("") raises.
        if parent:
            os.makedirs(parent, exist_ok=True)
        wavfile.write(filepath, self.SAMPLE_RATE, self._to_numpy(audio))

    def _audio_to_bytes(self, audio: torch.Tensor) -> bytes:
        """Encode ``audio`` as WAV at ``SAMPLE_RATE`` and return the bytes."""
        buffer = io.BytesIO()
        wavfile.write(buffer, self.SAMPLE_RATE, self._to_numpy(audio))
        return buffer.getvalue()

    def _handle_request(self, request_id, text: str, voice: str):
        """Process one queued request and record ``completed`` or ``failed``."""
        try:
            audio, processing_time = self._generate_audio(text, voice)

            output_file = os.path.join(
                self.output_dir, f"speech_{request_id}.wav"
            )
            self._save_audio(audio, output_file)

            self.db.update_status(
                request_id,
                "completed",
                output_file=output_file,
                processing_time=processing_time,
            )
        except Exception as e:
            print(f"Error processing request {request_id}: {str(e)}")
            self.db.update_status(request_id, "failed")

    def _process_queue(self):
        """Background worker that processes the queue.

        Runs forever: fetch the next pending request, handle it, then sleep
        ``POLL_INTERVAL_SECONDS`` before polling again.
        """
        while True:
            try:
                next_request = self.db.get_next_pending()
                if next_request:
                    request_id, text, voice = next_request
                    self._handle_request(request_id, text, voice)
            except Exception as e:
                # A database error here (or while marking a request failed)
                # must not kill the only worker thread, otherwise every later
                # request would stay pending forever.
                print(f"Error in queue worker: {str(e)}")

            time.sleep(self.POLL_INTERVAL_SECONDS)  # Prevent busy waiting

    def list_voices(self) -> list[str]:
        """List all available voices.

        Returns the sorted stems of the ``.pt`` files in the ``voices``
        directory (relative to the current working directory).  If the
        directory cannot be read, the error is printed and an empty list is
        returned.
        """
        try:
            entries = os.listdir("voices")
        except OSError as e:
            print(f"Error listing voices: {str(e)}")
            return []
        # Strip the 3-character ".pt" extension to get the voice name.
        return sorted(name[:-3] for name in entries if name.endswith(".pt"))

    def create_tts_request(self, text: str, voice: str = "af") -> int:
        """Create a new TTS request and return the request ID"""
        return self.db.add_request(text, voice)

    def get_request_status(
        self, request_id: int
    ) -> Optional[Tuple[str, Optional[str], Optional[float]]]:
        """Get the status, output file path, and processing time for a request"""
        return self.db.get_status(request_id)
|