Merge branch 'master' into master

remsky 2025-02-12 23:31:13 -07:00 committed by GitHub
commit 694b7435f1
23 changed files with 693 additions and 85 deletions


@@ -29,8 +29,11 @@ class Settings(BaseSettings):
     target_min_tokens: int = 175  # Target minimum tokens per chunk
     target_max_tokens: int = 250  # Target maximum tokens per chunk
     absolute_max_tokens: int = 450  # Absolute maximum tokens per chunk
+    advanced_text_normalization: bool = True  # Preprocesses the text before misaki, which leads to more accurate results
-    gap_trim_ms: int = 250  # Amount to trim from streaming chunk ends in milliseconds
+    gap_trim_ms: int = 1  # Base amount to trim from streaming chunk ends in milliseconds
+    dynamic_gap_trim_padding_ms: int = 410  # Padding to add to dynamic gap trim
+    dynamic_gap_trim_padding_char_multiplier: dict[str, float] = {".": 1, "!": 0.9, "?": 1, ",": 0.8}

     # Web Player Settings
     enable_web_player: bool = True  # Whether to serve the web player UI
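The three new settings drive the dynamic gap-trim behaviour introduced in this commit: a base padding window (dynamic_gap_trim_padding_ms) is scaled by a per-character multiplier keyed on the chunk's final punctuation mark. A minimal standalone sketch of that arithmetic, using the default values above (the real logic lives in AudioNormalizer.find_first_last_non_silent further down):

    SAMPLE_RATE = 24000  # fixed output rate used throughout the service
    dynamic_gap_trim_padding_ms = 410
    char_multiplier = {".": 1, "!": 0.9, "?": 1, ",": 0.8}

    def pad_samples_for(chunk_text: str) -> int:
        """Padding in samples appended after a chunk, scaled by its last character."""
        multiplier = char_multiplier.get(chunk_text.strip()[-1:], 1)
        return int(dynamic_gap_trim_padding_ms * SAMPLE_RATE * multiplier / 1000)

    print(pad_samples_for("Hello!"))  # 410 ms * 0.9 -> 8856 samples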


@@ -150,7 +150,7 @@ class KokoroV1(BaseModelBackend):
             pipeline = self._get_pipeline(pipeline_lang_code)
             logger.debug(
-                f"Generating audio from tokens with lang_code '{pipeline_lang_code}': '{tokens[:100]}...'"
+                f"Generating audio from tokens with lang_code '{pipeline_lang_code}': '{tokens[:100]}{'...' if len(tokens) > 100 else ''}'"
             )
             for result in pipeline.generate_from_tokens(
                 tokens=tokens, voice=voice_path, speed=speed, model=self._model
@@ -198,7 +198,6 @@ class KokoroV1(BaseModelBackend):
         """
         if not self.is_loaded:
             raise RuntimeError("Model not loaded")
-
         try:
             # Memory management for GPU
             if self._device == "cuda":
@@ -243,7 +242,7 @@ class KokoroV1(BaseModelBackend):
             pipeline = self._get_pipeline(pipeline_lang_code)
             logger.debug(
-                f"Generating audio for text with lang_code '{pipeline_lang_code}': '{text[:100]}...'"
+                f"Generating audio for text with lang_code '{pipeline_lang_code}': '{text[:100]}{'...' if len(text) > 100 else ''}'"
             )
             for result in pipeline(
                 text, voice=voice_path, speed=speed, model=self._model


@@ -3,6 +3,7 @@
 import io
 import json
 import os
+import re
 import tempfile
 from typing import AsyncGenerator, Dict, List, Union
@@ -137,7 +138,8 @@ async def stream_audio_chunks(
         voice=voice_name,
         speed=request.speed,
         output_format=request.response_format,
-        lang_code=request.lang_code if request.lang_code else (settings.default_voice_code if settings.default_voice_code else voice_name[0].lower()),
+        lang_code=request.lang_code or settings.default_voice_code or voice_name[0].lower(),
+        normalization_options=request.normalization_options
     ):
         # Check if client is still connected
         is_disconnected = client_request.is_disconnected
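The rewritten lang_code line leans on Python's `or` returning its first truthy operand, which collapses the nested conditionals into a single fall-through chain. Illustrated with plain stand-in variables (the real values come from the request and settings objects; the voice name is hypothetical):

    request_lang_code = None      # client omitted the field
    default_voice_code = None     # no server-wide default configured
    voice_name = "af_bella"       # hypothetical voice name

    lang_code = request_lang_code or default_voice_code or voice_name[0].lower()
    print(lang_code)  # "a" -- falls through to the voice name's first letter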


@@ -4,10 +4,12 @@ import struct
 from io import BytesIO

 import numpy as np
+import math
 import scipy.io.wavfile as wavfile
 import soundfile as sf
 from loguru import logger
 from pydub import AudioSegment
+from torch import norm

 from ..core.config import settings
 from .streaming_audio_writer import StreamingAudioWriter
@@ -20,23 +22,66 @@ class AudioNormalizer:
         self.chunk_trim_ms = settings.gap_trim_ms
         self.sample_rate = 24000  # Sample rate of the audio
         self.samples_to_trim = int(self.chunk_trim_ms * self.sample_rate / 1000)
+        self.samples_to_pad_start = int(50 * self.sample_rate / 1000)
+
+    def find_first_last_non_silent(self, audio_data: np.ndarray, chunk_text: str, speed: float, silence_threshold_db: int = -45, is_last_chunk: bool = False) -> tuple[int, int]:
+        """Finds the indices of the first and last non-silent samples in audio data.
+
+        Args:
+            audio_data: Input audio data as numpy array
+            chunk_text: The text sent to the model to generate the resulting speech
+            speed: The speaking speed of the voice
+            silence_threshold_db: How quiet audio has to be to be considered silent
+            is_last_chunk: Whether this is the last chunk
+
+        Returns:
+            A tuple with the start and the end of the non-silent portion
+        """
+        pad_multiplier = 1
+        split_character = chunk_text.strip()
+        if len(split_character) > 0:
+            split_character = split_character[-1]
+            if split_character in settings.dynamic_gap_trim_padding_char_multiplier:
+                pad_multiplier = settings.dynamic_gap_trim_padding_char_multiplier[split_character]
+
+        if not is_last_chunk:
+            samples_to_pad_end = max(int((settings.dynamic_gap_trim_padding_ms * self.sample_rate * pad_multiplier) / 1000) - self.samples_to_pad_start, 0)
+        else:
+            samples_to_pad_end = self.samples_to_pad_start
+
+        # Convert dBFS threshold to amplitude
+        amplitude_threshold = np.iinfo(audio_data.dtype).max * (10 ** (silence_threshold_db / 20))
+
+        # Find the first samples above the silence threshold at the start and end of the audio
+        non_silent_index_start, non_silent_index_end = None, None
+        for i in range(len(audio_data)):
+            if audio_data[i] > amplitude_threshold:
+                non_silent_index_start = i
+                break
+        for i in range(len(audio_data) - 1, -1, -1):
+            if audio_data[i] > amplitude_threshold:
+                non_silent_index_end = i
+                break
+
+        # Handle the case where the entire audio is silent
+        if non_silent_index_start is None or non_silent_index_end is None:
+            return 0, len(audio_data)
+
+        return max(non_silent_index_start - self.samples_to_pad_start, 0), min(non_silent_index_end + math.ceil(samples_to_pad_end / speed), len(audio_data))
+
     async def normalize(self, audio_data: np.ndarray) -> np.ndarray:
-        """Convert audio data to int16 range and trim silence from start and end
+        """Convert audio data to int16 range

         Args:
             audio_data: Input audio data as numpy array

         Returns:
-            Normalized and trimmed audio data
+            Normalized audio data
         """
         if len(audio_data) == 0:
             raise ValueError("Empty audio data")

-        # Trim start and end if enough samples
-        if len(audio_data) > (2 * self.samples_to_trim):
-            audio_data = audio_data[self.samples_to_trim : -self.samples_to_trim]
-
         # Scale directly to int16 range with clipping
         return np.clip(audio_data * 32767, -32768, 32767).astype(np.int16)
@@ -71,6 +116,8 @@ class AudioService:
         audio_data: np.ndarray,
         sample_rate: int,
         output_format: str,
+        speed: float = 1,
+        chunk_text: str = "",
         is_first_chunk: bool = True,
         is_last_chunk: bool = False,
         normalizer: AudioNormalizer = None,
@@ -81,6 +128,8 @@ class AudioService:
             audio_data: Numpy array of audio samples
             sample_rate: Sample rate of the audio
             output_format: Target format (wav, mp3, ogg, pcm)
+            speed: The speaking speed of the voice
+            chunk_text: The text sent to the model to generate the resulting speech
             is_first_chunk: Whether this is the first chunk
             is_last_chunk: Whether this is the last chunk
             normalizer: Optional AudioNormalizer instance for consistent normalization
@@ -96,8 +145,10 @@ class AudioService:
             # Always normalize audio to ensure proper amplitude scaling
             if normalizer is None:
                 normalizer = AudioNormalizer()
+
             normalized_audio = await normalizer.normalize(audio_data)
+            normalized_audio = AudioService.trim_audio(normalized_audio, chunk_text, speed, is_last_chunk, normalizer)

             # Get or create format-specific writer
             writer_key = f"{output_format}_{sample_rate}"
             if is_first_chunk or writer_key not in AudioService._writers:
@@ -123,3 +174,27 @@ class AudioService:
             raise ValueError(
                 f"Failed to convert audio stream to {output_format}: {str(e)}"
             )
+
+    @staticmethod
+    def trim_audio(audio_data: np.ndarray, chunk_text: str = "", speed: float = 1, is_last_chunk: bool = False, normalizer: AudioNormalizer = None) -> np.ndarray:
+        """Trim silence from start and end
+
+        Args:
+            audio_data: Input audio data as numpy array
+            chunk_text: The text sent to the model to generate the resulting speech
+            speed: The speaking speed of the voice
+            is_last_chunk: Whether this is the last chunk
+            normalizer: Optional AudioNormalizer instance for consistent normalization
+
+        Returns:
+            Trimmed audio data
+        """
+        if normalizer is None:
+            normalizer = AudioNormalizer()
+
+        # Trim start and end if enough samples
+        if len(audio_data) > (2 * normalizer.samples_to_trim):
+            audio_data = audio_data[normalizer.samples_to_trim : -normalizer.samples_to_trim]
+
+        # Find non-silent portion and trim
+        start_index, end_index = normalizer.find_first_last_non_silent(audio_data, chunk_text, speed, is_last_chunk=is_last_chunk)
+        return audio_data[start_index:end_index]
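A usage sketch of the new two-stage path (normalize to int16 first, then trim), assuming a 24 kHz float32 chunk like the model backends produce; the numbers are illustrative and the calls are assumed to run inside an async context:

    import numpy as np

    chunk = np.concatenate([
        np.zeros(2400, dtype=np.float32),        # 100 ms of leading silence
        0.5 * np.ones(24000, dtype=np.float32),  # 1 s of "speech"
        np.zeros(2400, dtype=np.float32),        # 100 ms of trailing silence
    ])

    normalizer = AudioNormalizer()
    int16_audio = await normalizer.normalize(chunk)  # scaling only; trimming moved out
    trimmed = AudioService.trim_audio(
        int16_audio, chunk_text="Hello.", speed=1.0,
        is_last_chunk=False, normalizer=normalizer,
    )
    # Leading silence is cut back to ~50 ms of padding; trailing padding is
    # scaled by the "." multiplier from the new settings.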


@@ -6,6 +6,9 @@ Converts them into a format suitable for text-to-speech processing.
 import re
 from functools import lru_cache

+import inflect
+
+from ...structures.schemas import NormalizationOptions

 # Constants
 VALID_TLDS = [
@@ -50,6 +53,27 @@ VALID_TLDS = [
     "io",
 ]
+VALID_UNITS = {
+    "m": "meter", "cm": "centimeter", "mm": "millimeter", "km": "kilometer", "in": "inch", "ft": "foot", "yd": "yard", "mi": "mile",  # Length
+    "g": "gram", "kg": "kilogram", "mg": "milligram",  # Mass
+    "s": "second", "ms": "millisecond", "min": "minute", "h": "hour",  # Time
+    "l": "liter", "ml": "milliliter", "cl": "centiliter", "dl": "deciliter",  # Volume
+    "kph": "kilometer per hour", "mph": "mile per hour", "mi/h": "mile per hour", "m/s": "meter per second", "km/h": "kilometer per hour", "mm/s": "millimeter per second", "cm/s": "centimeter per second", "ft/s": "foot per second", "cm/h": "centimeter per hour",  # Speed
+    "°c": "degree celsius", "c": "degree celsius", "°f": "degree fahrenheit", "f": "degree fahrenheit", "k": "kelvin",  # Temperature
+    "pa": "pascal", "kpa": "kilopascal", "mpa": "megapascal", "atm": "atmosphere",  # Pressure
+    "hz": "hertz", "khz": "kilohertz", "mhz": "megahertz", "ghz": "gigahertz",  # Frequency
+    "v": "volt", "kv": "kilovolt", "mv": "megavolt",  # Voltage
+    "a": "amp", "ma": "megaamp", "ka": "kiloamp",  # Current
+    "w": "watt", "kw": "kilowatt", "mw": "megawatt",  # Power
+    "j": "joule", "kj": "kilojoule", "mj": "megajoule",  # Energy
+    "Ω": "ohm", "kΩ": "kiloohm", "MΩ": "megaohm",  # Resistance (Ohm)
+    "f": "farad", "µf": "microfarad", "nf": "nanofarad", "pf": "picofarad",  # Capacitance
+    "b": "bit", "kb": "kilobit", "mb": "megabit", "gb": "gigabit", "tb": "terabit", "pb": "petabit",  # Data size
+    "kbps": "kilobit per second", "mbps": "megabit per second", "gbps": "gigabit per second", "tbps": "terabit per second",
+    "px": "pixel",  # CSS units
+}
 # Pre-compiled regex patterns for performance
 EMAIL_PATTERN = re.compile(
     r"\b[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-z]{2,}\b", re.IGNORECASE
@@ -61,6 +85,9 @@ URL_PATTERN = re.compile(
     re.IGNORECASE,
 )

+UNIT_PATTERN = re.compile(r"((?<!\w)([+-]?)(\d{1,3}(,\d{3})*|\d+)(\.\d+)?)\s*(" + "|".join(sorted(list(VALID_UNITS.keys()), reverse=True)) + r"""){1}(?=[^\w\d]{1}|\b)""", re.IGNORECASE)
+
+INFLECT_ENGINE = inflect.engine()
 def split_num(num: re.Match[str]) -> str:
     """Handle number splitting for various formats"""
@@ -86,6 +113,23 @@ def split_num(num: re.Match[str]) -> str:
         return f"{left} oh {right}{s}"
     return f"{left} {right}{s}"

+def handle_units(u: re.Match[str]) -> str:
+    """Converts units to their full form"""
+    unit_string = u.group(6).strip()
+    unit = unit_string
+    if unit_string.lower() in VALID_UNITS:
+        unit = VALID_UNITS[unit_string.lower()].split(" ")
+
+        # Handles the B vs b case (bytes vs bits)
+        if unit[0].endswith("bit"):
+            b_case = unit_string[min(1, len(unit_string) - 1)]
+            if b_case == "B":
+                unit[0] = unit[0][:-3] + "byte"
+
+        number = u.group(1).strip()
+        unit[0] = INFLECT_ENGINE.no(unit[0], number)
+    return " ".join(unit)
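What the unit pass produces once wired up, sketched with a few inputs (outputs assume inflect's no() pluralizes as documented; note how an uppercase B flips the data-size words from bits to bytes):

    examples = [
        "The download is 10MB",   # "B" after the prefix -> bytes, not bits
        "It draws 2a at 5v",      # lookups are case-insensitive via .lower()
        "Top speed is 88mph",
    ]
    for text in examples:
        print(UNIT_PATTERN.sub(handle_units, text))
    # The download is 10 megabytes
    # It draws 2 amps at 5 volts
    # Top speed is 88 miles per hour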
 def handle_money(m: re.Match[str]) -> str:
     """Convert money expressions to spoken form"""

@@ -171,30 +215,32 @@ def handle_url(u: re.Match[str]) -> str:
     return re.sub(r"\s+", " ", url).strip()

-def normalize_urls(text: str) -> str:
-    """Pre-process URLs before other text normalization"""
-    # Handle email addresses first
-    text = EMAIL_PATTERN.sub(handle_email, text)
-
-    # Handle URLs
-    text = URL_PATTERN.sub(handle_url, text)
-
-    return text
-
-def normalize_text(text: str) -> str:
+def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
     """Normalize text for TTS processing"""
-    # Pre-process URLs first
-    text = normalize_urls(text)
+    # Handle email addresses first if enabled
+    if normalization_options.email_normalization:
+        text = EMAIL_PATTERN.sub(handle_email, text)
+
+    # Handle URLs if enabled
+    if normalization_options.url_normalization:
+        text = URL_PATTERN.sub(handle_url, text)
+
+    # Pre-process numbers with units if enabled
+    if normalization_options.unit_normalization:
+        text = UNIT_PATTERN.sub(handle_units, text)
+
+    # Replace optional pluralization
+    if normalization_options.optional_pluralization_normalization:
+        text = re.sub(r"\(s\)", "s", text)

     # Replace quotes and brackets
     text = text.replace(chr(8216), "'").replace(chr(8217), "'")
     text = text.replace("«", chr(8220)).replace("»", chr(8221))
     text = text.replace(chr(8220), '"').replace(chr(8221), '"')
     text = text.replace("(", "«").replace(")", "»")

-    # Handle CJK punctuation
-    for a, b in zip("、。!,:;?", ",.!,:;?"):
+    # Handle CJK punctuation and some non-standard chars
+    for a, b in zip("、。!,:;?", ",.!,:;?-"):
         text = text.replace(a, b + " ")

     # Clean up whitespace
@@ -216,12 +262,14 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
     text = re.sub(
         r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)", split_num, text
     )
+
     text = re.sub(r"(?<=\d),(?=\d)", "", text)
+
     text = re.sub(
         r"(?i)[$£]\d+(?:\.\d+)?(?: hundred| thousand| (?:[bm]|tr)illion)*\b|[$£]\d+\.\d\d?\b",
         handle_money,
         text,
     )
     text = re.sub(r"\d*\.\d+", handle_decimal, text)

     # Handle various formatting
@@ -232,6 +280,6 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
     text = re.sub(
         r"(?:[A-Za-z]\.){2,} [a-z]", lambda m: m.group().replace(".", "-"), text
     )
     text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)

     return text.strip()
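End-to-end, the reworked entry point takes the options object, so individual passes can be toggled per request. A small sketch of the expected behaviour (output shown is approximate):

    opts = NormalizationOptions(unit_normalization=True)
    print(normalize_text("Email user@example.com about the 10MB file", opts))
    # -> "Email user at example dot com about the 10 megabytes file"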


@@ -10,7 +10,7 @@ from ...core.config import settings
 from .normalizer import normalize_text
 from .phonemizer import phonemize
 from .vocabulary import tokenize
+from ...structures.schemas import NormalizationOptions

 def process_text_chunk(
     text: str, language: str = "a", skip_phonemize: bool = False
@@ -26,7 +26,7 @@ def process_text_chunk(
         List of token IDs
     """
     start_time = time.time()

     if skip_phonemize:
         # Input is already phonemes, just tokenize
         t0 = time.time()
@@ -35,12 +35,11 @@ def process_text_chunk(
     else:
         # Normal text processing pipeline
         t0 = time.time()
-        normalized = normalize_text(text)
         t1 = time.time()

         t0 = time.time()
         phonemes = phonemize(
-            normalized, language, normalize=False
+            text, language, normalize=False
         )  # Already normalized
         t1 = time.time()
@@ -50,7 +49,7 @@ def process_text_chunk(
     total_time = time.time() - start_time
     logger.debug(
-        f"Total processing took {total_time * 1000:.2f}ms for chunk: '{text[:50]}...'"
+        f"Total processing took {total_time * 1000:.2f}ms for chunk: '{text[:50]}{'...' if len(text) > 50 else ''}'"
     )

     return tokens
@@ -61,7 +60,7 @@ async def yield_chunk(
 ) -> Tuple[str, List[int]]:
     """Yield a chunk with consistent logging."""
     logger.debug(
-        f"Yielding chunk {chunk_count}: '{text[:50]}...' ({len(tokens)} tokens)"
+        f"Yielding chunk {chunk_count}: '{text[:50]}{'...' if len(text) > 50 else ''}' ({len(tokens)} tokens)"
     )
     return text, tokens
@@ -88,9 +87,8 @@ def process_text(text: str, language: str = "a") -> List[int]:

 def get_sentence_info(text: str) -> List[Tuple[str, List[int], int]]:
     """Process all sentences and return info."""
-    sentences = re.split(r"([.!?;:])", text)
+    sentences = re.split(r"([.!?;:])(?=\s|$)", text)
     results = []
-
     for i in range(0, len(sentences), 2):
         sentence = sentences[i].strip()
         punct = sentences[i + 1] if i + 1 < len(sentences) else ""
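The added (?=\s|$) lookahead only treats punctuation as a sentence boundary when whitespace or end-of-string follows, which stops decimals and dotted domains from being split mid-token. A quick comparison:

    import re

    sample = "Pi is 3.14. Check example.com."
    print(re.split(r"([.!?;:])", sample))          # old: splits inside 3.14 and example.com
    print(re.split(r"([.!?;:])(?=\s|$)", sample))  # new: ['Pi is 3.14', '.', ' Check example.com', '.', '']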
@@ -106,13 +104,19 @@ def get_sentence_info(text: str) -> List[Tuple[str, List[int], int]]:

 async def smart_split(
-    text: str, max_tokens: int = settings.absolute_max_tokens
+    text: str,
+    max_tokens: int = settings.absolute_max_tokens,
+    normalization_options: NormalizationOptions = NormalizationOptions()
 ) -> AsyncGenerator[Tuple[str, List[int]], None]:
     """Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens."""
     start_time = time.time()
     chunk_count = 0
     logger.info(f"Starting smart split for {len(text)} chars")

+    # Normalize text
+    if settings.advanced_text_normalization and normalization_options.normalize:
+        text = normalize_text(text, normalization_options)
+
     # Process all sentences
     sentences = get_sentence_info(text)
@@ -128,7 +132,7 @@ async def smart_split(
             chunk_text = " ".join(current_chunk)
             chunk_count += 1
             logger.debug(
-                f"Yielding chunk {chunk_count}: '{chunk_text[:50]}...' ({current_count} tokens)"
+                f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(chunk_text) > 50 else ''}' ({current_count} tokens)"
             )
             yield chunk_text, current_tokens
             current_chunk = []
@@ -149,6 +153,7 @@ async def smart_split(
                     continue

                 full_clause = clause + comma
+
                 tokens = process_text_chunk(full_clause)
                 count = len(tokens)
@@ -166,7 +171,7 @@ async def smart_split(
                         chunk_text = " ".join(clause_chunk)
                         chunk_count += 1
                         logger.debug(
-                            f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}...' ({clause_count} tokens)"
+                            f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(chunk_text) > 50 else ''}' ({clause_count} tokens)"
                         )
                         yield chunk_text, clause_tokens
                         clause_chunk = [full_clause]
@@ -178,7 +183,7 @@ async def smart_split(
                     chunk_text = " ".join(clause_chunk)
                     chunk_count += 1
                     logger.debug(
-                        f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}...' ({clause_count} tokens)"
+                        f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(chunk_text) > 50 else ''}' ({clause_count} tokens)"
                     )
                     yield chunk_text, clause_tokens
@@ -192,7 +197,7 @@ async def smart_split(
             chunk_text = " ".join(current_chunk)
             chunk_count += 1
             logger.info(
-                f"Yielding chunk {chunk_count}: '{chunk_text[:50]}...' ({current_count} tokens)"
+                f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(chunk_text) > 50 else ''}' ({current_count} tokens)"
             )
             yield chunk_text, current_tokens
             current_chunk = [sentence]
@@ -217,7 +222,7 @@ async def smart_split(
             chunk_text = " ".join(current_chunk)
             chunk_count += 1
             logger.info(
-                f"Yielding chunk {chunk_count}: '{chunk_text[:50]}...' ({current_count} tokens)"
+                f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(chunk_text) > 50 else ''}' ({current_count} tokens)"
             )
             yield chunk_text, current_tokens
             current_chunk = [sentence]
@@ -229,7 +234,7 @@ async def smart_split(
         chunk_text = " ".join(current_chunk)
         chunk_count += 1
         logger.info(
-            f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}...' ({current_count} tokens)"
+            f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(chunk_text) > 50 else ''}' ({current_count} tokens)"
         )
         yield chunk_text, current_tokens
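A hypothetical driver for the new keyword path through smart_split (the surrounding event loop and input text are assumptions; normalization now happens once up front rather than per chunk inside process_text_chunk):

    import asyncio

    async def demo(long_text: str) -> None:
        opts = NormalizationOptions(normalize=True, unit_normalization=True)
        async for chunk_text, tokens in smart_split(long_text, normalization_options=opts):
            print(f"{len(tokens):4d} tokens: {chunk_text[:60]!r}")

    asyncio.run(demo("Visit example.com to grab the 10MB build. " * 50))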


@@ -18,7 +18,7 @@ from ..inference.voice_manager import get_manager as get_voice_manager
 from .audio import AudioNormalizer, AudioService
 from .text_processing import tokenize
 from .text_processing.text_processor import process_text_chunk, smart_split
+from ..structures.schemas import NormalizationOptions

 class TTSService:
     """Text-to-speech service."""
@@ -67,6 +67,8 @@ class TTSService:
                     np.array([0], dtype=np.float32),  # Dummy data for type checking
                     24000,
                     output_format,
+                    speed,
+                    "",
                     is_first_chunk=False,
                     normalizer=normalizer,
                     is_last_chunk=True,
@@ -97,15 +99,22 @@ class TTSService:
                             chunk_audio,
                             24000,
                             output_format,
+                            speed,
+                            chunk_text,
                             is_first_chunk=is_first,
-                            normalizer=normalizer,
                             is_last_chunk=is_last,
+                            normalizer=normalizer,
                         )
                         yield converted
                     except Exception as e:
                         logger.error(f"Failed to convert audio: {str(e)}")
                 else:
-                    yield chunk_audio
+                    trimmed = AudioService.trim_audio(
+                        chunk_audio, chunk_text, speed, is_last, normalizer
+                    )
+                    yield trimmed
             else:
                 # For legacy backends, load voice tensor
                 voice_tensor = await self._voice_manager.load_voice(
@@ -130,6 +139,8 @@ class TTSService:
                             chunk_audio,
                             24000,
                             output_format,
+                            speed,
+                            chunk_text,
                             is_first_chunk=is_first,
                             normalizer=normalizer,
                             is_last_chunk=is_last,
@@ -138,7 +149,12 @@ class TTSService:
                     except Exception as e:
                         logger.error(f"Failed to convert audio: {str(e)}")
                 else:
-                    yield chunk_audio
+                    trimmed = AudioService.trim_audio(
+                        chunk_audio, chunk_text, speed, is_last, normalizer
+                    )
+                    yield trimmed

         except Exception as e:
             logger.error(f"Failed to process tokens: {str(e)}")
@@ -222,6 +238,7 @@ class TTSService:
         speed: float = 1.0,
         output_format: str = "wav",
         lang_code: Optional[str] = None,
+        normalization_options: Optional[NormalizationOptions] = NormalizationOptions()
     ) -> AsyncGenerator[bytes, None]:
         """Generate and stream audio chunks."""
         stream_normalizer = AudioNormalizer()
@@ -242,7 +259,7 @@ class TTSService:
             )

             # Process text in chunks with smart splitting
-            async for chunk_text, tokens in smart_split(text):
+            async for chunk_text, tokens in smart_split(text, normalization_options=normalization_options):
                 try:
                     # Process audio for chunk
                     async for result in self._process_chunk(


@@ -36,7 +36,14 @@ class CaptionedSpeechResponse(BaseModel):
     audio: bytes = Field(..., description="The generated audio data")
     words: List[WordTimestamp] = Field(..., description="Word-level timestamps")

+class NormalizationOptions(BaseModel):
+    """Options for the normalization system"""
+    normalize: bool = Field(default=True, description="Normalizes input text to make it easier for the model to say")
+    unit_normalization: bool = Field(default=False, description="Transforms units like 10KB to 10 kilobytes")
+    url_normalization: bool = Field(default=True, description="Changes URLs so they can be properly pronounced by kokoro")
+    email_normalization: bool = Field(default=True, description="Changes emails so they can be properly pronounced by kokoro")
+    optional_pluralization_normalization: bool = Field(default=True, description="Replaces (s) with s so some words get pronounced correctly")

 class OpenAISpeechRequest(BaseModel):
     """Request schema for OpenAI-compatible speech endpoint"""
@@ -71,6 +78,10 @@ class OpenAISpeechRequest(BaseModel):
         default=None,
         description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
     )
+    normalization_options: Optional[NormalizationOptions] = Field(
+        default=NormalizationOptions(),
+        description="Options for the normalization system",
+    )

 class CaptionedSpeechRequest(BaseModel):
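For clients, the new field rides along in the OpenAI-compatible request body. A sketch of a request exercising it (the endpoint path and voice name follow this project's conventions but are assumptions here):

    payload = {
        "model": "kokoro",
        "input": "Visit example.com(s) to download the 10MB file.",
        "voice": "af_bella",
        "response_format": "mp3",
        "normalization_options": {
            "unit_normalization": True,
            "optional_pluralization_normalization": True,
        },
    }
    # e.g. requests.post("http://localhost:8880/v1/audio/speech", json=payload)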


@@ -3,29 +3,29 @@
 import pytest

 from api.src.services.text_processing.normalizer import normalize_text
+from api.src.structures.schemas import NormalizationOptions

 def test_url_protocols():
     """Test URL protocol handling"""
     assert (
-        normalize_text("Check out https://example.com")
+        normalize_text("Check out https://example.com", normalization_options=NormalizationOptions())
         == "Check out https example dot com"
     )
-    assert normalize_text("Visit http://site.com") == "Visit http site dot com"
+    assert normalize_text("Visit http://site.com", normalization_options=NormalizationOptions()) == "Visit http site dot com"
     assert (
-        normalize_text("Go to https://test.org/path")
+        normalize_text("Go to https://test.org/path", normalization_options=NormalizationOptions())
         == "Go to https test dot org slash path"
     )

 def test_url_www():
     """Test www prefix handling"""
-    assert normalize_text("Go to www.example.com") == "Go to www example dot com"
+    assert normalize_text("Go to www.example.com", normalization_options=NormalizationOptions()) == "Go to www example dot com"
     assert (
-        normalize_text("Visit www.test.org/docs") == "Visit www test dot org slash docs"
+        normalize_text("Visit www.test.org/docs", normalization_options=NormalizationOptions()) == "Visit www test dot org slash docs"
     )
     assert (
-        normalize_text("Check www.site.com?q=test")
+        normalize_text("Check www.site.com?q=test", normalization_options=NormalizationOptions())
         == "Check www site dot com question-mark q equals test"
     )

@@ -33,15 +33,15 @@ def test_url_www():
 def test_url_localhost():
     """Test localhost URL handling"""
     assert (
-        normalize_text("Running on localhost:7860")
+        normalize_text("Running on localhost:7860", normalization_options=NormalizationOptions())
         == "Running on localhost colon 78 60"
     )
     assert (
-        normalize_text("Server at localhost:8080/api")
+        normalize_text("Server at localhost:8080/api", normalization_options=NormalizationOptions())
         == "Server at localhost colon 80 80 slash api"
     )
     assert (
-        normalize_text("Test localhost:3000/test?v=1")
+        normalize_text("Test localhost:3000/test?v=1", normalization_options=NormalizationOptions())
         == "Test localhost colon 3000 slash test question-mark v equals 1"
     )

@@ -49,43 +49,43 @@ def test_url_localhost():
 def test_url_ip_addresses():
     """Test IP address URL handling"""
     assert (
-        normalize_text("Access 0.0.0.0:9090/test")
+        normalize_text("Access 0.0.0.0:9090/test", normalization_options=NormalizationOptions())
         == "Access 0 dot 0 dot 0 dot 0 colon 90 90 slash test"
     )
     assert (
-        normalize_text("API at 192.168.1.1:8000")
+        normalize_text("API at 192.168.1.1:8000", normalization_options=NormalizationOptions())
         == "API at 192 dot 168 dot 1 dot 1 colon 8000"
     )
-    assert normalize_text("Server 127.0.0.1") == "Server 127 dot 0 dot 0 dot 1"
+    assert normalize_text("Server 127.0.0.1", normalization_options=NormalizationOptions()) == "Server 127 dot 0 dot 0 dot 1"

 def test_url_raw_domains():
     """Test raw domain handling"""
     assert (
-        normalize_text("Visit google.com/search") == "Visit google dot com slash search"
+        normalize_text("Visit google.com/search", normalization_options=NormalizationOptions()) == "Visit google dot com slash search"
     )
     assert (
-        normalize_text("Go to example.com/path?q=test")
+        normalize_text("Go to example.com/path?q=test", normalization_options=NormalizationOptions())
         == "Go to example dot com slash path question-mark q equals test"
     )
-    assert normalize_text("Check docs.test.com") == "Check docs dot test dot com"
+    assert normalize_text("Check docs.test.com", normalization_options=NormalizationOptions()) == "Check docs dot test dot com"

 def test_url_email_addresses():
     """Test email address handling"""
     assert (
-        normalize_text("Email me at user@example.com")
+        normalize_text("Email me at user@example.com", normalization_options=NormalizationOptions())
         == "Email me at user at example dot com"
     )
-    assert normalize_text("Contact admin@test.org") == "Contact admin at test dot org"
+    assert normalize_text("Contact admin@test.org", normalization_options=NormalizationOptions()) == "Contact admin at test dot org"
     assert (
-        normalize_text("Send to test.user@site.com")
+        normalize_text("Send to test.user@site.com", normalization_options=NormalizationOptions())
         == "Send to test dot user at site dot com"
     )

 def test_non_url_text():
     """Test that non-URL text is unaffected"""
-    assert normalize_text("This is not.a.url text") == "This is not-a-url text"
-    assert normalize_text("Hello, how are you today?") == "Hello, how are you today?"
-    assert normalize_text("It costs $50.") == "It costs 50 dollars."
+    assert normalize_text("This is not.a.url text", normalization_options=NormalizationOptions()) == "This is not-a-url text"
+    assert normalize_text("Hello, how are you today?", normalization_options=NormalizationOptions()) == "Hello, how are you today?"
+    assert normalize_text("It costs $50.", normalization_options=NormalizationOptions()) == "It costs 50 dollars."


@ -0,0 +1,23 @@
# Patterns to ignore when building packages.
# This supports shell glob matching, relative path matching, and
# negation (prefixed with !). Only one pattern per line.
.DS_Store
# Common VCS dirs
.git/
.gitignore
.bzr/
.bzrignore
.hg/
.hgignore
.svn/
# Common backup files
*.swp
*.bak
*.tmp
*.orig
*~
# Various IDEs
.project
.idea/
*.tmproj
.vscode/


@ -0,0 +1,24 @@
apiVersion: v2
name: kokoro-fastapi
description: A Helm chart for kokoro-fastapi
# A chart can be either an 'application' or a 'library' chart.
#
# Application charts are a collection of templates that can be packaged into versioned archives
# to be deployed.
#
# Library charts provide useful utilities or functions for the chart developer. They're included as
# a dependency of application charts to inject those utilities and functions into the rendering
# pipeline. Library charts do not define any templates and therefore cannot be deployed.
type: application
# This is the chart version. This version number should be incremented each time you make changes
# to the chart and its templates, including the app version.
# Versions are expected to follow Semantic Versioning (https://semver.org/)
version: 0.1.0
# This is the version number of the application being deployed. This version number should be
# incremented each time you make changes to the application. Versions are not expected to
# follow Semantic Versioning. They should reflect the version the application is using.
# It is recommended to use it with quotes.
appVersion: "1.16.0"


@ -0,0 +1,22 @@
1. Get the application URL by running these commands:
{{- if .Values.ingress.enabled }}
{{- range $host := .Values.ingress.hosts }}
{{- range .paths }}
http{{ if $.Values.ingress.tls }}s{{ end }}://{{ $host.host }}{{ .path }}
{{- end }}
{{- end }}
{{- else if contains "NodePort" .Values.service.type }}
export NODE_PORT=$(kubectl get --namespace {{ .Release.Namespace }} -o jsonpath="{.spec.ports[0].nodePort}" services {{ include "kokoro-fastapi.fullname" . }})
export NODE_IP=$(kubectl get nodes --namespace {{ .Release.Namespace }} -o jsonpath="{.items[0].status.addresses[0].address}")
echo http://$NODE_IP:$NODE_PORT
{{- else if contains "LoadBalancer" .Values.service.type }}
NOTE: It may take a few minutes for the LoadBalancer IP to be available.
You can watch the status by running 'kubectl get --namespace {{ .Release.Namespace }} svc -w {{ include "kokoro-fastapi.fullname" . }}'
export SERVICE_IP=$(kubectl get svc --namespace {{ .Release.Namespace }} {{ include "kokoro-fastapi.fullname" . }} --template "{{"{{ range (index .status.loadBalancer.ingress 0) }}{{.}}{{ end }}"}}")
echo http://$SERVICE_IP:{{ .Values.service.port }}
{{- else if contains "ClusterIP" .Values.service.type }}
export POD_NAME=$(kubectl get pods --namespace {{ .Release.Namespace }} -l "app.kubernetes.io/name={{ include "kokoro-fastapi.name" . }},app.kubernetes.io/instance={{ .Release.Name }}" -o jsonpath="{.items[0].metadata.name}")
export CONTAINER_PORT=$(kubectl get pod --namespace {{ .Release.Namespace }} $POD_NAME -o jsonpath="{.spec.containers[0].ports[0].containerPort}")
echo "Visit http://127.0.0.1:8080 to use your application"
kubectl --namespace {{ .Release.Namespace }} port-forward $POD_NAME 8080:$CONTAINER_PORT
{{- end }}


@ -0,0 +1,62 @@
{{/*
Expand the name of the chart.
*/}}
{{- define "kokoro-fastapi.name" -}}
{{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }}
{{- end }}
{{/*
Create a default fully qualified app name.
We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec).
If release name contains chart name it will be used as a full name.
*/}}
{{- define "kokoro-fastapi.fullname" -}}
{{- if .Values.fullnameOverride }}
{{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }}
{{- else }}
{{- $name := default .Chart.Name .Values.nameOverride }}
{{- if contains $name .Release.Name }}
{{- .Release.Name | trunc 63 | trimSuffix "-" }}
{{- else }}
{{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }}
{{- end }}
{{- end }}
{{- end }}
{{/*
Create chart name and version as used by the chart label.
*/}}
{{- define "kokoro-fastapi.chart" -}}
{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }}
{{- end }}
{{/*
Common labels
*/}}
{{- define "kokoro-fastapi.labels" -}}
helm.sh/chart: {{ include "kokoro-fastapi.chart" . }}
{{ include "kokoro-fastapi.selectorLabels" . }}
{{- if .Chart.AppVersion }}
app.kubernetes.io/version: {{ .Chart.AppVersion | quote }}
{{- end }}
app.kubernetes.io/managed-by: {{ .Release.Service }}
{{- end }}
{{/*
Selector labels
*/}}
{{- define "kokoro-fastapi.selectorLabels" -}}
app.kubernetes.io/name: {{ include "kokoro-fastapi.name" . }}
app.kubernetes.io/instance: {{ .Release.Name }}
{{- end }}
{{/*
Create the name of the service account to use
*/}}
{{- define "kokoro-fastapi.serviceAccountName" -}}
{{- if .Values.serviceAccount.create }}
{{- default (include "kokoro-fastapi.fullname" .) .Values.serviceAccount.name }}
{{- else }}
{{- default "default" .Values.serviceAccount.name }}
{{- end }}
{{- end }}


@ -0,0 +1,28 @@
{{- if .Values.autoscaling.enabled }}
apiVersion: autoscaling/v2beta1
kind: HorizontalPodAutoscaler
metadata:
name: {{ include "kokoro-fastapi.fullname" . }}
labels:
{{- include "kokoro-fastapi.labels" . | nindent 4 }}
spec:
scaleTargetRef:
apiVersion: apps/v1
kind: Deployment
name: {{ include "kokoro-fastapi.fullname" . }}
minReplicas: {{ .Values.autoscaling.minReplicas }}
maxReplicas: {{ .Values.autoscaling.maxReplicas }}
metrics:
{{- if .Values.autoscaling.targetCPUUtilizationPercentage }}
- type: Resource
resource:
name: cpu
targetAverageUtilization: {{ .Values.autoscaling.targetCPUUtilizationPercentage }}
{{- end }}
{{- if .Values.autoscaling.targetMemoryUtilizationPercentage }}
- type: Resource
resource:
name: memory
targetAverageUtilization: {{ .Values.autoscaling.targetMemoryUtilizationPercentage }}
{{- end }}
{{- end }}


@ -0,0 +1,82 @@
{{- if .Values.ingress.enabled -}}
{{- $fullName := include "kokoro-fastapi.fullname" . -}}
{{- $svcPort := .Values.service.port -}}
{{- $rewriteTargets := (list) -}}
{{- with .Values.ingress.host }}
{{- range .endpoints }}
{{- $serviceName := default $fullName .serviceName -}}
{{- $rewrite := .rewrite | default "none" -}}
{{- if not (has $rewrite $rewriteTargets ) -}}
{{- $rewriteTargets = append $rewriteTargets $rewrite -}}
{{- end -}}
{{- end}}
{{- end }}
{{- range $key := $rewriteTargets }}
{{- $expandedRewrite := regexReplaceAll "/(.*)$" $key "slash${1}" -}}
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
{{- if eq $key "none" }}
name: {{ $fullName }}
{{- else }}
name: {{ $fullName }}-{{ $expandedRewrite }}
{{- end }}
labels:
{{- include "kokoro-fastapi.labels" $ | nindent 4 }}
{{- if ne $key "none" }}
annotations:
nginx.ingress.kubernetes.io/rewrite-target: {{ regexReplaceAll "/$" $key "" }}/$2
{{- end }}
spec:
{{- if $.Values.ingress.tls }}
tls:
{{- range $.Values.ingress.tls }}
- hosts:
{{- range .hosts }}
- {{ . | quote }}
{{- end }}
secretName: {{ .secretName }}
{{- end }}
{{- end }}
rules:
{{- with $.Values.ingress.host }}
- host: {{ .name | quote }}
http:
paths:
{{- range .endpoints }}
{{- $serviceName := default $fullName .serviceName -}}
{{- $servicePort := default (print "http") .servicePort -}}
{{- if eq ( .rewrite | default "none" ) $key }}
{{- range .paths }}
{{- if not (contains "@" .) }}
{{- if eq $key "none" }}
- path: {{ . }}
{{- else }}
- path: {{ regexReplaceAll "(.*)/$" . "${1}" }}(/|$)(.*)
{{- end }}
pathType: Prefix
backend:
service:
name: "{{ $fullName }}-{{ $serviceName }}"
port:
number: {{ $servicePort }}
{{- else }}
{{- $path := . -}}
{{- $replicaCount := include "getServiceNameReplicaCount" (dict "global" $.Values "serviceName" $serviceName ) -}}
{{- range $count, $e := until ($replicaCount|int) }}
- path: {{ $path | replace "@" ( . | toString ) }}(/|$)(.*)
pathType: Prefix
backend:
service:
name: "{{ $fullName }}-{{ $serviceName }}-{{ . }}"
port:
number: {{ $servicePort }}
{{- end }}
{{- end }}
{{- end }}
{{- end }}
{{- end }}
{{- end }}
---
{{- end }}
{{- end }}


@ -0,0 +1,71 @@
apiVersion: apps/v1
kind: Deployment
metadata:
name: {{ include "kokoro-fastapi.fullname" . }}-kokoro-tts
labels:
{{- include "kokoro-fastapi.labels" . | nindent 4 }}
spec:
{{- if not .Values.autoscaling.enabled }}
replicas: {{ .Values.kokoroTTS.replicaCount }}
{{- end }}
selector:
matchLabels:
{{- include "kokoro-fastapi.selectorLabels" . | nindent 6 }}
template:
metadata:
{{- with .Values.podAnnotations }}
annotations:
{{- toYaml . | nindent 8 }}
{{- end }}
labels:
{{- include "kokoro-fastapi.selectorLabels" . | nindent 8 }}
spec:
{{- with .Values.images.imagePullSecrets }}
imagePullSecrets:
{{- toYaml . | nindent 8 }}
{{- end }}
serviceAccountName: {{ include "kokoro-fastapi.serviceAccountName" . }}
securityContext:
{{- toYaml .Values.podSecurityContext | nindent 8 }}
initContainers: []
containers:
- name: kokoro-tts
securityContext:
{{- toYaml .Values.securityContext | nindent 12 }}
image: "{{ .Values.kokoroTTS.repository }}:{{ .Values.kokoroTTS.tag | default .Chart.AppVersion }}"
imagePullPolicy: {{ .Values.kokoroTTS.pullPolicy }}
env:
- name: PYTHONPATH
value: "/app:/app/api"
- name: USE_GPU
value: "true"
- name: PYTHONUNBUFFERED
value: "1"
ports:
- name: kokoro-tts-http
containerPort: {{ .Values.kokoroTTS.port | default 8880 }}
protocol: TCP
livenessProbe:
httpGet:
path: /health
port: kokoro-tts-http
readinessProbe:
httpGet:
path: /health
port: kokoro-tts-http
resources:
{{- toYaml .Values.kokoroTTS.resources | nindent 12 }}
volumeMounts: []
volumes: []
{{- with .Values.nodeSelector }}
nodeSelector:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- with .Values.affinity }}
affinity:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- with .Values.tolerations }}
tolerations:
{{- toYaml . | nindent 8 }}
{{- end }}


@ -0,0 +1,15 @@
apiVersion: v1
kind: Service
metadata:
name: {{ include "kokoro-fastapi.fullname" . }}-kokoro-tts-service
labels:
{{- include "kokoro-fastapi.labels" . | nindent 4 }}
spec:
type: {{ .Values.service.type }}
ports:
- port: {{ .Values.kokoroTTS.port }}
targetPort: kokoro-tts-http
protocol: TCP
name: kokoro-tts-http
selector:
{{- include "kokoro-fastapi.selectorLabels" . | nindent 4 }}


@ -0,0 +1,12 @@
{{- if .Values.serviceAccount.create -}}
apiVersion: v1
kind: ServiceAccount
metadata:
name: {{ include "kokoro-fastapi.serviceAccountName" . }}
labels:
{{- include "kokoro-fastapi.labels" . | nindent 4 }}
{{- with .Values.serviceAccount.annotations }}
annotations:
{{- toYaml . | nindent 4 }}
{{- end }}
{{- end }}


@ -0,0 +1,15 @@
apiVersion: v1
kind: Pod
metadata:
name: "{{ include "kokoro-fastapi.fullname" . }}-test-connection"
labels:
{{- include "kokoro-fastapi.labels" . | nindent 4 }}
annotations:
"helm.sh/hook": test
spec:
containers:
- name: wget
image: busybox
command: ['wget']
args: ['{{ include "kokoro-fastapi.fullname" . }}:{{ .Values.service.port }}']
restartPolicy: Never


@ -0,0 +1,94 @@
# Default values for kokoro-fastapi.
# This is a YAML-formatted file.
# Declare variables to be passed into your templates.
replicaCount: 1
images:
pullPolicy: "Always"
imagePullSecrets: [ ]
nameOverride: ""
fullnameOverride: ""
serviceAccount:
# Specifies whether a service account should be created
create: true
# Annotations to add to the service account
annotations: {}
# The name of the service account to use.
# If not set and create is true, a name is generated using the fullname template
name: ""
podAnnotations: {}
podSecurityContext: {}
# fsGroup: 2000
securityContext: {}
# capabilities:
# drop:
# - ALL
# readOnlyRootFilesystem: true
# runAsNonRoot: true
# runAsUser: 1000
service:
type: ClusterIP
ingress:
enabled: false
className: ""
annotations: {}
# kubernetes.io/ingress.class: nginx
# kubernetes.io/tls-acme: "true"
host:
name: kokoro.example.com
endpoints:
- paths:
- "/"
serviceName: "fastapi"
servicePort: 8880
tls: []
# - secretName: chart-example-tls
# hosts:
# - chart-example.local
kokoroTTS:
repository: "ghcr.io/remsky/kokoro-fastapi-gpu"
tag: "latest"
pullPolicy: Always
serviceName: "fastapi"
port: 8880
replicaCount: 1
resources:
limits:
nvidia.com/gpu: 1
requests:
nvidia.com/gpu: 1
# We usually recommend not to specify default resources and to leave this as a conscious
# choice for the user. This also increases chances charts run on environments with little
# resources, such as Minikube. If you do want to specify resources, uncomment the following
# lines, adjust them as necessary, and remove the curly braces after 'resources:'.
# limits:
# cpu: 100m
# memory: 128Mi
# requests:
# cpu: 100m
# memory: 128Mi
autoscaling:
enabled: false
minReplicas: 1
maxReplicas: 100
targetCPUUtilizationPercentage: 80
# targetMemoryUtilizationPercentage: 80
nodeSelector: {}
tolerations: []
affinity: {}


@@ -9,10 +9,10 @@ RUN apt-get update && apt-get install -y \
     curl \
     ffmpeg \
     g++ \
     && apt-get clean \
     && rm -rf /var/lib/apt/lists/* \
     && mkdir -p /usr/share/espeak-ng-data \
     && ln -s /usr/lib/*/espeak-ng-data/* /usr/share/espeak-ng-data/

 # Install UV using the installer script
 RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
@@ -32,7 +32,7 @@ COPY --chown=appuser:appuser pyproject.toml ./pyproject.toml

 # Install dependencies
 RUN --mount=type=cache,target=/root/.cache/uv \
-    uv venv && \
+    uv venv --python 3.10 && \
     uv sync --extra cpu

 # Copy project files including models


@@ -1,11 +1,11 @@
-FROM --platform=$BUILDPLATFORM nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04
+FROM --platform=$BUILDPLATFORM nvidia/cuda:12.8.0-cudnn-runtime-ubuntu24.04

 # Set non-interactive frontend
 ENV DEBIAN_FRONTEND=noninteractive

 # Install Python and other dependencies
 RUN apt-get update && apt-get install -y \
     python3.10 \
-    python3.10-venv \
+    python3-venv \
     espeak-ng \
     espeak-ng-data \
     git \
@@ -23,7 +23,7 @@ RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
     mv /root/.local/bin/uvx /usr/local/bin/

 # Create non-root user and set up directories and permissions
-RUN useradd -m -u 1000 appuser && \
+RUN useradd -m -u 1001 appuser && \
     mkdir -p /app/api/src/models/v1_0 && \
     chown -R appuser:appuser /app
@@ -39,7 +39,7 @@ ENV PHONEMIZER_ESPEAK_PATH=/usr/bin \

 # Install dependencies with GPU extras (using cache mounts)
 RUN --mount=type=cache,target=/root/.cache/uv \
-    uv venv && \
+    uv venv --python 3.10 && \
     uv sync --extra gpu

 # Copy project files including models


@@ -18,8 +18,6 @@ dependencies = [
     "scipy==1.14.1",
     # Audio processing
     "soundfile==0.13.0",
-    # Text processing
-    "phonemizer==3.3.0",
     "regex==2024.11.6",
     # Utilities
     "aiofiles==23.2.1",
@@ -36,7 +34,9 @@ dependencies = [
     "kokoro @ git+https://github.com/hexgrad/kokoro.git@31a2b6337b8c1b1418ef68c48142328f640da938",
     'misaki[en,ja,ko,zh] @ git+https://github.com/hexgrad/misaki.git@ebc76c21b66c5fc4866ed0ec234047177b396170',
     "spacy==3.7.2",
-    "en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl"
+    "en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl",
+    "inflect>=7.5.0",
+    "phonemizer-fork>=3.3.2",
 ]

[project.optional-dependencies]