mirror of
https://github.com/remsky/Kokoro-FastAPI.git
synced 2025-08-05 16:48:53 +00:00
209 lines
8 KiB
Python
209 lines
8 KiB
Python
import os
|
|
import threading
|
|
import time
|
|
import io
|
|
from typing import Optional, List, Tuple
|
|
from sqlalchemy.orm import Session, sessionmaker
|
|
from ..models.schemas import TTSStatus
|
|
from ..database.models import TTSQueue
|
|
import numpy as np
|
|
import torch
|
|
import scipy.io.wavfile as wavfile
|
|
from models import build_model
|
|
from kokoro import generate, phonemize, tokenize
|
|
from ..database.queue import QueueDB
|
|
|
|
|
|
class TTSModel:
    """Process-wide holder for the Kokoro model and a cache of loaded voicepacks.

    The model is built lazily on the first ``get_instance`` call. Voicepacks are
    loaded from ``voices/<name>.pt`` (a path relative to the current working
    directory) and cached in a class-level dict keyed by voice name.
    """

    _instance = None  # (model, device) tuple once the model has been built
    _lock = threading.Lock()  # guards one-time model construction
    _voicepacks = {}  # voice name -> loaded voicepack

    @classmethod
    def get_instance(cls):
        """Return the shared ``(model, device)`` tuple, building it on first use.

        ``device`` is ``"cuda"`` when ``torch.cuda.is_available()`` is true,
        otherwise ``"cpu"``. The ``None`` check is repeated while holding the
        lock so that concurrent first callers build the model only once.
        """
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    device = "cuda" if torch.cuda.is_available() else "cpu"
                    print(f"Initializing model on {device}")
                    model = build_model("kokoro-v0_19.pth", device)
                    cls._instance = (model, device)
        return cls._instance

    @classmethod
    def get_voicepack(cls, voice_name: str) -> torch.Tensor:
        """Return the voicepack for ``voice_name``, loading and caching it on first use.

        The file ``voices/<voice_name>.pt`` is loaded with ``weights_only=True``
        onto the model's device. If loading fails for any voice other than
        ``"af"``, the error is printed and the ``"af"`` voicepack is returned
        instead (the failed name is not cached). A failure to load ``"af"``
        itself is re-raised.

        Raises:
            Exception: whatever loading ``"af"`` raised, when ``"af"`` is
                unavailable.
        """
        _, device = cls.get_instance()
        if voice_name not in cls._voicepacks:
            try:
                # Voice names arrive with queued requests; legitimate names are bare
                # file names (see the voices/ listing), so refuse anything with a
                # path separator rather than letting it escape the voices/ folder.
                if "/" in voice_name or "\\" in voice_name:
                    raise ValueError(f"Invalid voice name: {voice_name!r}")
                voicepack = torch.load(
                    f"voices/{voice_name}.pt", map_location=device, weights_only=True
                )
                cls._voicepacks[voice_name] = voicepack
            except Exception as e:
                print(f"Error loading voice {voice_name}: {str(e)}")
                if voice_name != "af":
                    return cls.get_voicepack("af")
                raise
        return cls._voicepacks[voice_name]
|
|
|
|
|
|
class TTSService:
    """Queue-backed text-to-speech service.

    Requests are stored through ``QueueDB``. A daemon worker thread started in
    ``__init__`` polls for pending requests, synthesizes them with the shared
    ``TTSModel`` and writes ``speech_<id>.wav`` files into ``output_dir``.
    """

    # Leave wider margin from 510 limit to account for tokenizer differences.
    MAX_TOKENS = 450
    # Sample rate (Hz) used when writing WAV output.
    SAMPLE_RATE = 24000
    # Seconds the worker waits before polling again when the queue is empty
    # or an iteration failed.
    POLL_INTERVAL = 1

    def __init__(self, db: Session, output_dir: Optional[str] = None):
        """Create the service and start its background worker.

        Args:
            db: Session used for request-side queue operations. Its bound engine
                is reused by the worker to open its own sessions.
            output_dir: Directory for generated WAV files. Defaults to
                ``../output`` relative to this module; created if missing.
        """
        if output_dir is None:
            output_dir = os.path.join(os.path.dirname(__file__), "..", "output")
        self.output_dir = output_dir
        self.db = QueueDB(db)
        self.engine = db.get_bind()  # Get engine from session
        os.makedirs(output_dir, exist_ok=True)
        self._start_worker()

    def _start_worker(self):
        """Start the background worker thread (daemon, so it never blocks exit)."""
        self.worker = threading.Thread(target=self._process_queue, daemon=True)
        self.worker.start()

    def _find_boundary(self, text: str, max_tokens: int, voice: str, margin: int = 50) -> int:
        """Return a character index at which to cut a leading chunk off ``text``.

        The markers ``'. '``, ``'; '`` and ``', '`` are tried in that order. For
        each, only its last occurrence within the first ``max_tokens + margin``
        characters is considered. It is accepted if the text up to and including
        the marker phonemizes to at most ``max_tokens`` tokens.

        If no marker qualifies, the cut falls on the last space within the first
        ``max_tokens`` characters, or at ``max_tokens`` when there is no usable
        space. This fallback treats ``max_tokens`` as a character count and does
        not re-check the token count.

        For a positive ``max_tokens`` the returned index is always > 0, so a
        caller that consumes ``text[:index]`` always makes progress.
        """
        for marker in ['. ', '; ', ', ']:
            window = text[:max_tokens + margin]  # Look a bit beyond the limit
            last_idx = window.rfind(marker)
            if last_idx == -1:
                continue
            end = last_idx + len(marker)
            candidate = text[:end].strip()
            if len(tokenize(phonemize(candidate, voice[0]))) <= max_tokens:
                return end

        last_space = text[:max_tokens].rfind(' ')
        # A space at index 0 would yield an empty chunk, so treat it like "no space".
        return last_space if last_space > 0 else max_tokens

    def _split_text(self, text: str, voice: str) -> List[str]:
        """Split ``text`` into non-empty chunks that each respect the token limit.

        Cuts are placed at sentence/clause boundaries where possible (see
        ``_find_boundary``) to try to maintain sentence structure. Returns an
        empty list for empty or whitespace-only text.
        """
        chunks: List[str] = []
        remaining = text.strip()
        while remaining:
            # If remaining text is within limit, add it as final chunk.
            if len(tokenize(phonemize(remaining, voice[0]))) <= self.MAX_TOKENS:
                chunks.append(remaining)
                break
            split_pos = self._find_boundary(remaining, self.MAX_TOKENS, voice)
            chunk = remaining[:split_pos].strip()
            if chunk:
                chunks.append(chunk)
            remaining = remaining[split_pos:].strip()
        return chunks

    def _generate_audio(self, text: str, voice: str, speed: float, stitch_long_output: bool = True) -> Tuple[torch.Tensor, float]:
        """Synthesize ``text`` and measure how long it took.

        Args:
            text: Text to speak; must contain non-whitespace characters.
            voice: Voice name; its first character is passed to ``generate``
                as the language code.
            speed: Speech speed passed to ``generate`` (in both modes).
            stitch_long_output: When true, the text is split into chunks within
                the token limit and the per-chunk audio is concatenated. When
                false, the whole text is passed to ``generate`` at once.

        Returns:
            ``(audio, seconds)``. ``audio`` is whatever ``generate`` returns,
            concatenated with ``np.concatenate`` when there are several chunks.
            ``seconds`` includes model/voicepack loading.

        Raises:
            ValueError: if ``text`` is empty or whitespace-only.
        """
        if not text or not text.strip():
            raise ValueError("Cannot generate audio for empty text")

        start_time = time.perf_counter()
        model, _ = TTSModel.get_instance()
        voicepack = TTSModel.get_voicepack(voice)

        if stitch_long_output:
            audio_chunks = [
                generate(model, chunk, voicepack, lang=voice[0], speed=speed)[0]
                for chunk in self._split_text(text, voice)
            ]
            audio = np.concatenate(audio_chunks) if len(audio_chunks) > 1 else audio_chunks[0]
        else:
            # Previously this path ignored the requested speed.
            audio, _ = generate(model, text, voicepack, lang=voice[0], speed=speed)

        processing_time = time.perf_counter() - start_time
        return audio, processing_time

    def _save_audio(self, audio: torch.Tensor, filepath: str):
        """Write ``audio`` as a WAV file at ``filepath``, creating parent directories."""
        directory = os.path.dirname(filepath)
        # os.makedirs("") raises, so only create a directory when the path has one.
        if directory:
            os.makedirs(directory, exist_ok=True)
        wavfile.write(filepath, self.SAMPLE_RATE, audio)

    def _audio_to_bytes(self, audio: torch.Tensor) -> bytes:
        """Encode ``audio`` as an in-memory WAV file and return its bytes."""
        buffer = io.BytesIO()
        wavfile.write(buffer, self.SAMPLE_RATE, audio)
        return buffer.getvalue()

    def _process_request(self, db, request) -> None:
        """Synthesize one queued request, save it, and record the outcome.

        On success the request is marked COMPLETED with its absolute output
        path and processing time. On any error it is marked FAILED. An error
        raised while recording FAILED propagates to the caller.
        """
        try:
            audio, processing_time = self._generate_audio(
                request.text,
                request.voice,
                request.speed,
                request.stitch_long_output,
            )
            output_file = os.path.abspath(
                os.path.join(self.output_dir, f"speech_{request.id}.wav")
            )
            self._save_audio(audio, output_file)
            db.update_status(
                request.id,
                TTSStatus.COMPLETED,
                output_file=output_file,
                processing_time=processing_time,
            )
        except Exception as e:
            print(f"Error processing request {request.id}: {str(e)}")
            db.update_status(request.id, TTSStatus.FAILED)

    def _process_queue(self):
        """Background worker loop: fetch and process pending requests forever.

        Each iteration opens a new session on ``self.engine`` rather than using
        the session given to ``__init__``. The worker sleeps ``POLL_INTERVAL``
        seconds only when the queue was empty or the iteration failed, so a
        backlog is drained without a fixed delay between requests. Unexpected
        errors (e.g. from the database) are printed and the loop keeps running
        instead of terminating the thread.
        """
        session_factory = sessionmaker(bind=self.engine)
        while True:
            handled = False
            try:
                with session_factory() as session:
                    queue = QueueDB(session)
                    request = queue.get_next_pending()
                    if request:
                        self._process_request(queue, request)
                        handled = True
            except Exception as e:
                print(f"Queue worker error: {str(e)}")
            if not handled:
                time.sleep(self.POLL_INTERVAL)  # Prevent busy waiting

    def list_voices(self) -> list[str]:
        """Return the names (without ``.pt``) of voice files in ``voices/``.

        The directory is resolved relative to the current working directory.
        If it cannot be listed, the error is printed and an empty list is returned.
        """
        try:
            return [name[:-3] for name in os.listdir("voices") if name.endswith(".pt")]
        except Exception as e:
            print(f"Error listing voices: {str(e)}")
            return []

    def create_tts_request(self, text: str, voice: str = "af", speed: float = 1.0, stitch_long_output: bool = True) -> int:
        """Enqueue a new TTS request and return its ID."""
        return self.db.add_request(text, voice, speed, stitch_long_output)

    def get_request_status(self, request_id: int) -> Optional[TTSQueue]:
        """Return the stored request row for ``request_id`` (as returned by ``QueueDB.get_status``)."""
        return self.db.get_status(request_id)
|