# Kokoro-FastAPI/api/src/services/tts_service.py

"""TTS service using model and voice managers."""

import asyncio
import os
import re
import tempfile
import time
from typing import AsyncGenerator, List, Optional, Tuple, Union

import numpy as np
import torch
from kokoro import KPipeline
from loguru import logger

from ..core.config import settings
from ..inference.base import AudioChunk
from ..inference.kokoro_v1 import KokoroV1
from ..inference.model_manager import get_manager as get_model_manager
from ..inference.voice_manager import get_manager as get_voice_manager
from ..structures.schemas import NormalizationOptions
from .audio import AudioNormalizer, AudioService
from .text_processing import tokenize
from .text_processing.text_processor import process_text_chunk, smart_split


class TTSService:
    """Text-to-speech service."""

    # Limit concurrent chunk processing
    _chunk_semaphore = asyncio.Semaphore(4)

    def __init__(self, output_dir: Optional[str] = None):
        """Initialize service."""
        self.output_dir = output_dir
        self.model_manager = None
        self._voice_manager = None

    @classmethod
    async def create(cls, output_dir: Optional[str] = None) -> "TTSService":
        """Create and initialize TTSService instance."""
        service = cls(output_dir)
        service.model_manager = await get_model_manager()
        service._voice_manager = await get_voice_manager()
        return service
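
    # Example (hypothetical usage, assuming an async context and configured managers):
    #     service = await TTSService.create(output_dir="/tmp/tts_out")
    #     chunk = await service.generate_audio("Hello world", voice="af_jadzia")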

    async def _process_chunk(
        self,
        chunk_text: str,
        tokens: List[int],
        voice_name: str,
        voice_path: str,
        speed: float,
        output_format: Optional[str] = None,
        is_first: bool = False,
        is_last: bool = False,
        normalizer: Optional[AudioNormalizer] = None,
        lang_code: Optional[str] = None,
        return_timestamps: Optional[bool] = False,
    ) -> AsyncGenerator[AudioChunk, None]:
        """Process tokens into audio."""
        async with self._chunk_semaphore:
            try:
                # Handle stream finalization
                if is_last:
                    # Skip format conversion for raw audio mode
                    if not output_format:
                        yield AudioChunk(np.array([], dtype=np.int16), output=b"")
                        return
                    chunk_data = await AudioService.convert_audio(
                        AudioChunk(np.array([], dtype=np.float32)),  # Dummy data for type checking
                        24000,
                        output_format,
                        speed,
                        "",
                        is_first_chunk=False,
                        normalizer=normalizer,
                        is_last_chunk=True,
                    )
                    yield chunk_data
                    return

                # Skip empty chunks
                if not tokens and not chunk_text:
                    return

                # Get backend
                backend = self.model_manager.get_backend()

                # Generate audio using pre-warmed model
                if isinstance(backend, KokoroV1):
                    chunk_index = 0
                    # For Kokoro V1, pass text and voice info with lang_code
                    async for chunk_data in self.model_manager.generate(
                        chunk_text,
                        (voice_name, voice_path),
                        speed=speed,
                        lang_code=lang_code,
                        return_timestamps=return_timestamps,
                    ):
                        # For streaming, convert to bytes
                        if output_format:
                            try:
                                chunk_data = await AudioService.convert_audio(
                                    chunk_data,
                                    24000,
                                    output_format,
                                    speed,
                                    chunk_text,
                                    is_first_chunk=is_first and chunk_index == 0,
                                    is_last_chunk=is_last,
                                    normalizer=normalizer,
                                )
                                yield chunk_data
                            except Exception as e:
                                logger.error(f"Failed to convert audio: {str(e)}")
                        else:
                            chunk_data = AudioService.trim_audio(
                                chunk_data, chunk_text, speed, is_last, normalizer
                            )
                            yield chunk_data
                        chunk_index += 1
                else:
                    # For legacy backends, load voice tensor
                    voice_tensor = await self._voice_manager.load_voice(
                        voice_name, device=backend.device
                    )
                    chunk_data = await self.model_manager.generate(
                        tokens, voice_tensor, speed=speed, return_timestamps=return_timestamps
                    )

                    if chunk_data.audio is None:
                        logger.error("Model generated None for audio chunk")
                        return

                    if len(chunk_data.audio) == 0:
                        logger.error("Model generated empty audio chunk")
                        return

                    # For streaming, convert to bytes
                    if output_format:
                        try:
                            chunk_data = await AudioService.convert_audio(
                                chunk_data,
                                24000,
                                output_format,
                                speed,
                                chunk_text,
                                is_first_chunk=is_first,
                                normalizer=normalizer,
                                is_last_chunk=is_last,
                            )
                            yield chunk_data
                        except Exception as e:
                            logger.error(f"Failed to convert audio: {str(e)}")
                    else:
                        trimmed = AudioService.trim_audio(
                            chunk_data, chunk_text, speed, is_last, normalizer
                        )
                        yield trimmed
            except Exception as e:
                logger.error(f"Failed to process tokens: {str(e)}")

    async def _load_voice_from_path(self, path: str, weight: float):
        # Raise if no path was provided for the voice
        if not path:
            raise ValueError(f"Voice not found at path: {path}")
        logger.debug(f"Loading voice tensor from path: {path}")
        return torch.load(path, map_location="cpu") * weight
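
    # For illustration: with weight=0.5 the loaded tensor is scaled by half, so when
    # several weighted tensors are summed in _get_voices_path the result is a
    # weighted blend of the component voices.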

    async def _get_voices_path(self, voice: str) -> Tuple[str, str]:
        """Get voice path, handling combined voices.

        Args:
            voice: Voice name or combined voice names (e.g., 'af_jadzia+af_jessica')

        Returns:
            Tuple of (voice name to use, voice path to use)

        Raises:
            RuntimeError: If voice not found
        """
        try:
            # Split the voice on + and - and keep the operators in the list,
            # e.g. "hi+bob" -> ["hi", "+", "bob"]
            split_voice = re.split(r"([-+])", voice)

            # If it is only one voice, there is no point in loading it,
            # doing nothing with it, then saving it again
            if len(split_voice) == 1:
                # For a single voice, the weight only matters if
                # voice_weight_normalization is off
                if (
                    "(" not in voice and ")" not in voice
                ) or settings.voice_weight_normalization:
                    path = await self._voice_manager.get_voice_path(voice)
                    if not path:
                        raise RuntimeError(f"Voice not found: {voice}")
                    logger.debug(f"Using single voice path: {path}")
                    return voice, path

            total_weight = 0
            for voice_index in range(0, len(split_voice), 2):
                voice_object = split_voice[voice_index]

                if "(" in voice_object and ")" in voice_object:
                    voice_name = voice_object.split("(")[0].strip()
                    voice_weight = float(voice_object.split("(")[1].split(")")[0])
                else:
                    voice_name = voice_object
                    voice_weight = 1

                total_weight += voice_weight
                split_voice[voice_index] = (voice_name, voice_weight)

            # If voice_weight_normalization is off, skip normalization by setting
            # total_weight to 1 so each weight is divided by 1
            if not settings.voice_weight_normalization:
                total_weight = 1

            # Load the first voice as the starting point for the combination
            path = await self._voice_manager.get_voice_path(split_voice[0][0])
            combined_tensor = await self._load_voice_from_path(
                path, split_voice[0][1] / total_weight
            )

            # Loop through each + or - operator in split_voice and apply it
            # to the combined voice
            for operation_index in range(1, len(split_voice) - 1, 2):
                # Get the voice path of the voice one index ahead of the operator
                path = await self._voice_manager.get_voice_path(
                    split_voice[operation_index + 1][0]
                )
                voice_tensor = await self._load_voice_from_path(
                    path, split_voice[operation_index + 1][1] / total_weight
                )

                # Either add or subtract the voice from the current combined voice
                if split_voice[operation_index] == "+":
                    combined_tensor += voice_tensor
                else:
                    combined_tensor -= voice_tensor

            # Save the new combined voice so it can be loaded later
            temp_dir = tempfile.gettempdir()
            combined_path = os.path.join(temp_dir, f"{voice}.pt")
            logger.debug(f"Saving combined voice to: {combined_path}")
            torch.save(combined_tensor, combined_path)
            return voice, combined_path
        except Exception as e:
            logger.error(f"Failed to get voice path: {e}")
            raise
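
    # Worked example (hypothetical voice names): "af_bella(2)+af_sky" splits into
    # [("af_bella", 2.0), "+", ("af_sky", 1)]. With voice_weight_normalization on,
    # total_weight = 3, so the blend is 2/3 * af_bella + 1/3 * af_sky; with it off,
    # the raw weights 2 and 1 are applied as-is.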

    async def generate_audio_stream(
        self,
        text: str,
        voice: str,
        speed: float = 1.0,
        output_format: str = "wav",
        lang_code: Optional[str] = None,
        normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
        return_timestamps: Optional[bool] = False,
    ) -> AsyncGenerator[AudioChunk, None]:
        """Generate and stream audio chunks."""
        stream_normalizer = AudioNormalizer()
        chunk_index = 0
        current_offset = 0.0

        try:
            # Get backend
            backend = self.model_manager.get_backend()

            # Get voice path, handling combined voices
            voice_name, voice_path = await self._get_voices_path(voice)
            logger.debug(f"Using voice path: {voice_path}")

            # Use provided lang_code or determine from voice name
            pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
            logger.info(
                f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream"
            )
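
            # For example, a voice named "af_jadzia" falls back to lang_code "a"
            # when no explicit lang_code is passed (the first letter of the voice
            # name is used as the language code).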

            # Process text in chunks with smart splitting
            async for chunk_text, tokens in smart_split(
                text, lang_code=lang_code, normalization_options=normalization_options
            ):
                try:
                    # Process audio for chunk
                    async for chunk_data in self._process_chunk(
                        chunk_text,  # Pass text for Kokoro V1
                        tokens,  # Pass tokens for legacy backends
                        voice_name,  # Pass voice name
                        voice_path,  # Pass voice path
                        speed,
                        output_format,
                        is_first=(chunk_index == 0),
                        is_last=False,  # We'll update the last chunk later
                        normalizer=stream_normalizer,
                        lang_code=pipeline_lang_code,  # Pass lang_code
                        return_timestamps=return_timestamps,
                    ):
                        if chunk_data.word_timestamps is not None:
                            # Shift word timestamps by the running stream offset
                            for timestamp in chunk_data.word_timestamps:
                                timestamp.start_time += current_offset
                                timestamp.end_time += current_offset

                        current_offset += len(chunk_data.audio) / 24000

                        if chunk_data.output is not None:
                            yield chunk_data
                        else:
                            logger.warning(
                                f"No audio generated for chunk: '{chunk_text[:100]}...'"
                            )

                    chunk_index += 1
                except Exception as e:
                    logger.error(
                        f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}"
                    )
                    continue

            # Only finalize if we successfully processed at least one chunk
            if chunk_index > 0:
                try:
                    # Empty tokens list to finalize audio
                    async for chunk_data in self._process_chunk(
                        "",  # Empty text
                        [],  # Empty tokens
                        voice_name,
                        voice_path,
                        speed,
                        output_format,
                        is_first=False,
                        is_last=True,  # Signal this is the last chunk
                        normalizer=stream_normalizer,
                        lang_code=pipeline_lang_code,  # Pass lang_code
                    ):
                        if chunk_data.output is not None:
                            yield chunk_data
                except Exception as e:
                    logger.error(f"Failed to finalize audio stream: {str(e)}")
        except Exception as e:
            logger.error(f"Error in audio stream generation: {str(e)}")
            raise e

    async def generate_audio(
        self,
        text: str,
        voice: str,
        speed: float = 1.0,
        return_timestamps: bool = False,
        normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
        lang_code: Optional[str] = None,
    ) -> AudioChunk:
        """Generate complete audio for text using streaming internally."""
        audio_data_chunks = []

        try:
            async for audio_stream_data in self.generate_audio_stream(
                text,
                voice,
                speed=speed,
                normalization_options=normalization_options,
                return_timestamps=return_timestamps,
                lang_code=lang_code,
                output_format=None,
            ):
                if len(audio_stream_data.audio) > 0:
                    audio_data_chunks.append(audio_stream_data)

            combined_audio_data = AudioChunk.combine(audio_data_chunks)
            return combined_audio_data
        except Exception as e:
            logger.error(f"Error in audio generation: {str(e)}")
            raise
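
    # Example (hypothetical usage): collect a full utterance in one call.
    #     chunk = await service.generate_audio("Hello there!", voice="af_jadzia")
    #     samples = chunk.audio  # raw sample array at 24 kHz, per the code above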

    async def combine_voices(self, voices: List[str]) -> torch.Tensor:
        """Combine multiple voices.

        Returns:
            Combined voice tensor
        """
        return await self._voice_manager.combine_voices(voices)

    async def list_voices(self) -> List[str]:
        """List available voices."""
        return await self._voice_manager.list_voices()

    async def generate_from_phonemes(
        self,
        phonemes: str,
        voice: str,
        speed: float = 1.0,
        lang_code: Optional[str] = None,
    ) -> Tuple[np.ndarray, float]:
        """Generate audio directly from phonemes.

        Args:
            phonemes: Phonemes in Kokoro format
            voice: Voice name
            speed: Speed multiplier
            lang_code: Optional language code override

        Returns:
            Tuple of (audio array, processing time)
        """
        start_time = time.time()

        try:
            # Get backend and voice path
            backend = self.model_manager.get_backend()
            voice_name, voice_path = await self._get_voices_path(voice)

            if isinstance(backend, KokoroV1):
                # For Kokoro V1, use generate_from_tokens with raw phonemes
                result = None

                # Use provided lang_code or determine from voice name
                pipeline_lang_code = lang_code if lang_code else voice[:1].lower()
                logger.info(
                    f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in phoneme pipeline"
                )

                try:
                    # Use backend's pipeline management
                    for r in backend._get_pipeline(
                        pipeline_lang_code
                    ).generate_from_tokens(
                        tokens=phonemes,  # Pass raw phonemes string
                        voice=voice_path,
                        speed=speed,
                    ):
                        if r.audio is not None:
                            result = r
                            break
                except Exception as e:
                    logger.error(f"Failed to generate from phonemes: {e}")
                    raise RuntimeError(f"Phoneme generation failed: {e}")

                if result is None or result.audio is None:
                    raise ValueError("No audio generated")

                processing_time = time.time() - start_time
                return result.audio.numpy(), processing_time
            else:
                raise ValueError(
                    "Phoneme generation only supported with Kokoro V1 backend"
                )
        except Exception as e:
            logger.error(f"Error in phoneme audio generation: {str(e)}")
            raise
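

# Minimal streaming sketch (hypothetical usage; assumes the package's managers are
# configured and "af_jadzia" is an installed voice — adjust names to your setup):
#
#     import asyncio
#
#     async def main():
#         service = await TTSService.create()
#         async for chunk in service.generate_audio_stream(
#             "Streaming synthesis, one chunk at a time.",
#             voice="af_jadzia",
#             output_format="mp3",
#         ):
#             if chunk.output:
#                 ...  # write chunk.output (encoded bytes) to a socket or file
#
#     asyncio.run(main())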