Kokoro-FastAPI/api/src/services/tts_service.py

import io
import os
import re
import time
from typing import List, Tuple, Optional
import numpy as np
import torch
import scipy.io.wavfile as wavfile
from kokoro import tokenize, phonemize, normalize_text
from loguru import logger
from ..core.config import settings
from .tts_model import TTSModel
from .tts_cpu import TTSCPUModel
from .tts_gpu import TTSGPUModel


class TTSService:
    def __init__(self, output_dir: Optional[str] = None):
        self.output_dir = output_dir

    def _split_text(self, text: str) -> List[str]:
        """Split text into sentences"""
        return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]

    def _get_voice_path(self, voice_name: str) -> Optional[str]:
        """Get the path to a voice file"""
        voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice_name}.pt")
        return voice_path if os.path.exists(voice_path) else None

    def _generate_audio(
        self, text: str, voice: str, speed: float, stitch_long_output: bool = True
    ) -> Tuple[torch.Tensor, float]:
        """Generate audio and measure processing time"""
        start_time = time.time()

        try:
            # Normalize text once at the start
            text = normalize_text(text)
            if not text:
                raise ValueError("Text is empty after preprocessing")

            # Check voice exists
            voice_path = self._get_voice_path(voice)
            if not voice_path:
                raise ValueError(f"Voice not found: {voice}")

            # Load voice
            voicepack = torch.load(
                voice_path, map_location=TTSModel.get_device(), weights_only=True
            )

            # Generate audio with or without stitching
            if stitch_long_output:
                chunks = self._split_text(text)
                audio_chunks = []

                # Process all chunks
                for i, chunk in enumerate(chunks):
                    try:
                        # Process chunk
                        if TTSModel.get_device() == "cuda":
                            chunk_audio, _ = TTSGPUModel.generate(
                                chunk, voicepack, voice[0], speed
                            )
                        else:
                            ps = phonemize(chunk, voice[0])
                            tokens = tokenize(ps)
                            tokens = [0] + tokens + [0]  # Add padding
                            chunk_audio = TTSCPUModel.generate(tokens, voicepack, speed)

                        if chunk_audio is not None:
                            audio_chunks.append(chunk_audio)
                        else:
                            logger.error(
                                f"No audio generated for chunk {i + 1}/{len(chunks)}"
                            )
                    except Exception as e:
                        logger.error(
                            f"Failed to generate audio for chunk {i + 1}/{len(chunks)}: '{chunk}'. Error: {str(e)}"
                        )
                        continue

                if not audio_chunks:
                    raise ValueError("No audio chunks were generated successfully")

                audio = (
                    np.concatenate(audio_chunks)
                    if len(audio_chunks) > 1
                    else audio_chunks[0]
                )
            else:
                # Process single chunk
                if TTSModel.get_device() == "cuda":
                    audio, _ = TTSGPUModel.generate(text, voicepack, voice[0], speed)
                else:
                    ps = phonemize(text, voice[0])
                    tokens = tokenize(ps)
                    tokens = [0] + tokens + [0]  # Add padding
                    audio = TTSCPUModel.generate(tokens, voicepack, speed)

            processing_time = time.time() - start_time
            return audio, processing_time

        except Exception as e:
            logger.error(f"Error in audio generation: {str(e)}")
            raise

    def _save_audio(self, audio: torch.Tensor, filepath: str):
        """Save audio to file"""
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        wavfile.write(filepath, 24000, audio)

    def _audio_to_bytes(self, audio: torch.Tensor) -> bytes:
        """Convert audio tensor to WAV bytes"""
        buffer = io.BytesIO()
        wavfile.write(buffer, 24000, audio)
        return buffer.getvalue()

    def combine_voices(self, voices: List[str]) -> str:
        """Combine multiple voices into a new voice"""
        if len(voices) < 2:
            raise ValueError("At least 2 voices are required for combination")

        # Load voices
        t_voices: List[torch.Tensor] = []
        v_name: List[str] = []

        for voice in voices:
            try:
                voice_path = os.path.join(TTSModel.VOICES_DIR, f"{voice}.pt")
                voicepack = torch.load(
                    voice_path, map_location=TTSModel.get_device(), weights_only=True
                )
                t_voices.append(voicepack)
                v_name.append(voice)
            except Exception as e:
                raise ValueError(f"Failed to load voice {voice}: {str(e)}")

        # Combine voices
        try:
            f: str = "_".join(v_name)
            v = torch.mean(torch.stack(t_voices), dim=0)
            combined_path = os.path.join(TTSModel.VOICES_DIR, f"{f}.pt")

            # Save combined voice
            try:
                torch.save(v, combined_path)
            except Exception as e:
                raise RuntimeError(
                    f"Failed to save combined voice to {combined_path}: {str(e)}"
                )

            return f
        except Exception as e:
            if not isinstance(e, (ValueError, RuntimeError)):
                raise RuntimeError(f"Error combining voices: {str(e)}")
            raise

    def list_voices(self) -> List[str]:
        """List all available voices"""
        voices = []
        try:
            for file in os.listdir(TTSModel.VOICES_DIR):
                if file.endswith(".pt"):
                    voices.append(file[:-3])  # Remove .pt extension
        except Exception as e:
            logger.error(f"Error listing voices: {str(e)}")
        return sorted(voices)
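

# Example usage (illustrative sketch only, not part of the service). It assumes a
# voice file such as "af_bella.pt" exists in TTSModel.VOICES_DIR and that this
# module is imported from within the API package so the relative imports resolve:
#
#     service = TTSService(output_dir="output")
#     audio, seconds = service._generate_audio(
#         "Hello world. This is a test.", voice="af_bella", speed=1.0
#     )
#     service._save_audio(audio, "output/hello.wav")
#     print(f"Generated in {seconds:.2f}s; available voices: {service.list_voices()}")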