Merge branch 'remsky:master' into master

This commit is contained in:
Kishor Prins 2025-06-21 09:53:19 -07:00 committed by GitHub
commit 0241423375
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 512 additions and 195 deletions

View file

@ -13,7 +13,7 @@
[![Tested at Model Commit](https://img.shields.io/badge/last--tested--model--commit-1.0::9901c2b-blue)](https://huggingface.co/hexgrad/Kokoro-82M/commit/9901c2b79161b6e898b7ea857ae5298f47b8b0d6) [![Tested at Model Commit](https://img.shields.io/badge/last--tested--model--commit-1.0::9901c2b-blue)](https://huggingface.co/hexgrad/Kokoro-82M/commit/9901c2b79161b6e898b7ea857ae5298f47b8b0d6)
Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model
- Multi-language support (English, Japanese, Korean, Chinese, _Vietnamese soon_) - Multi-language support (English, Japanese, Chinese, _Vietnamese soon_)
- OpenAI-compatible Speech endpoint, NVIDIA GPU accelerated or CPU inference with PyTorch - OpenAI-compatible Speech endpoint, NVIDIA GPU accelerated or CPU inference with PyTorch
- ONNX support coming soon, see v0.1.5 and earlier for legacy ONNX support in the interim - ONNX support coming soon, see v0.1.5 and earlier for legacy ONNX support in the interim
- Debug endpoints for monitoring system stats, integrated web UI on localhost:8880/web - Debug endpoints for monitoring system stats, integrated web UI on localhost:8880/web

View file

@ -1 +1 @@
0.3.0 0.2.4

View file

@ -31,6 +31,7 @@ class Settings(BaseSettings):
# Audio Settings # Audio Settings
sample_rate: int = 24000 sample_rate: int = 24000
default_volume_multiplier: float = 1.0
# Text Processing Settings # Text Processing Settings
target_min_tokens: int = 175 # Target minimum tokens per chunk target_min_tokens: int = 175 # Target minimum tokens per chunk
target_max_tokens: int = 250 # Target maximum tokens per chunk target_max_tokens: int = 250 # Target maximum tokens per chunk

View file

@ -300,7 +300,7 @@ async def get_web_file_path(filename: str) -> str:
) )
# Construct web directory path relative to project root # Construct web directory path relative to project root
web_dir = os.path.join("/app", settings.web_player_path) web_dir = os.path.join(root_dir, settings.web_player_path)
# Search in web directory # Search in web directory
search_paths = [web_dir] search_paths = [web_dir]

View file

@ -141,6 +141,8 @@ Model files not found! You need to download the Kokoro V1 model:
try: try:
async for chunk in self._backend.generate(*args, **kwargs): async for chunk in self._backend.generate(*args, **kwargs):
if settings.default_volume_multiplier != 1.0:
chunk.audio *= settings.default_volume_multiplier
yield chunk yield chunk
except Exception as e: except Exception as e:
raise RuntimeError(f"Generation failed: {e}") raise RuntimeError(f"Generation failed: {e}")

View file

@ -319,6 +319,7 @@ async def create_captioned_speech(
writer=writer, writer=writer,
speed=request.speed, speed=request.speed,
return_timestamps=request.return_timestamps, return_timestamps=request.return_timestamps,
volume_multiplier=request.volume_multiplier,
normalization_options=request.normalization_options, normalization_options=request.normalization_options,
lang_code=request.lang_code, lang_code=request.lang_code,
) )

View file

@ -152,6 +152,7 @@ async def stream_audio_chunks(
speed=request.speed, speed=request.speed,
output_format=request.response_format, output_format=request.response_format,
lang_code=request.lang_code, lang_code=request.lang_code,
volume_multiplier=request.volume_multiplier,
normalization_options=request.normalization_options, normalization_options=request.normalization_options,
return_timestamps=unique_properties["return_timestamps"], return_timestamps=unique_properties["return_timestamps"],
): ):
@ -300,6 +301,7 @@ async def create_speech(
voice=voice_name, voice=voice_name,
writer=writer, writer=writer,
speed=request.speed, speed=request.speed,
volume_multiplier=request.volume_multiplier,
normalization_options=request.normalization_options, normalization_options=request.normalization_options,
lang_code=request.lang_code, lang_code=request.lang_code,
) )

View file

@ -80,12 +80,12 @@ class AudioNormalizer:
non_silent_index_start, non_silent_index_end = None, None non_silent_index_start, non_silent_index_end = None, None
for X in range(0, len(audio_data)): for X in range(0, len(audio_data)):
if audio_data[X] > amplitude_threshold: if abs(audio_data[X]) > amplitude_threshold:
non_silent_index_start = X non_silent_index_start = X
break break
for X in range(len(audio_data) - 1, -1, -1): for X in range(len(audio_data) - 1, -1, -1):
if audio_data[X] > amplitude_threshold: if abs(audio_data[X]) > amplitude_threshold:
non_silent_index_end = X non_silent_index_end = X
break break

View file

@ -32,19 +32,29 @@ class StreamingAudioWriter:
if self.format in ["wav", "flac", "mp3", "pcm", "aac", "opus"]: if self.format in ["wav", "flac", "mp3", "pcm", "aac", "opus"]:
if self.format != "pcm": if self.format != "pcm":
self.output_buffer = BytesIO() self.output_buffer = BytesIO()
container_options = {}
# Try disabling Xing VBR header for MP3 to fix iOS timeline reading issues
if self.format == 'mp3':
# Disable Xing VBR header
container_options = {'write_xing': '0'}
logger.debug("Disabling Xing VBR header for MP3 encoding.")
self.container = av.open( self.container = av.open(
self.output_buffer, self.output_buffer,
mode="w", mode="w",
format=self.format if self.format != "aac" else "adts", format=self.format if self.format != "aac" else "adts",
options=container_options # Pass options here
) )
self.stream = self.container.add_stream( self.stream = self.container.add_stream(
codec_map[self.format], codec_map[self.format],
sample_rate=self.sample_rate, rate=self.sample_rate,
layout="mono" if self.channels == 1 else "stereo", layout="mono" if self.channels == 1 else "stereo",
) )
# Set bit_rate only for codecs where it's applicable and useful
if self.format in ['mp3', 'aac', 'opus']:
self.stream.bit_rate = 128000 self.stream.bit_rate = 128000
else: else:
raise ValueError(f"Unsupported format: {format}") raise ValueError(f"Unsupported format: {self.format}") # Use self.format here
def close(self): def close(self):
if hasattr(self, "container"): if hasattr(self, "container"):
@ -65,12 +75,18 @@ class StreamingAudioWriter:
if finalize: if finalize:
if self.format != "pcm": if self.format != "pcm":
# Flush stream encoder
packets = self.stream.encode(None) packets = self.stream.encode(None)
for packet in packets: for packet in packets:
self.container.mux(packet) self.container.mux(packet)
# Closing the container handles writing the trailer and finalizing the file.
# No explicit flush method is available or needed here.
logger.debug("Muxed final packets.")
# Get the final bytes from the buffer *before* closing it
data = self.output_buffer.getvalue() data = self.output_buffer.getvalue()
self.close() self.close() # Close container and buffer
return data return data
if audio_data is None or len(audio_data) == 0: if audio_data is None or len(audio_data) == 0:

View file

@ -11,7 +11,7 @@ from typing import List, Optional, Union
import inflect import inflect
from numpy import number from numpy import number
from text_to_num import text2num # from text_to_num import text2num
from torch import mul from torch import mul
from ...structures.schemas import NormalizationOptions from ...structures.schemas import NormalizationOptions
@ -134,6 +134,23 @@ VALID_UNITS = {
"px": "pixel", # CSS units "px": "pixel", # CSS units
} }
SYMBOL_REPLACEMENTS = {
'~': ' ',
'@': ' at ',
'#': ' number ',
'$': ' dollar ',
'%': ' percent ',
'^': ' ',
'&': ' and ',
'*': ' ',
'_': ' ',
'|': ' ',
'\\': ' ',
'/': ' slash ',
'=': ' equals ',
'+': ' plus ',
}
MONEY_UNITS = {"$": ("dollar", "cent"), "£": ("pound", "pence"), "": ("euro", "cent")} MONEY_UNITS = {"$": ("dollar", "cent"), "£": ("pound", "pence"), "": ("euro", "cent")}
# Pre-compiled regex patterns for performance # Pre-compiled regex patterns for performance
@ -391,6 +408,7 @@ def handle_time(t: re.Match[str]) -> str:
def normalize_text(text: str, normalization_options: NormalizationOptions) -> str: def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
"""Normalize text for TTS processing""" """Normalize text for TTS processing"""
# Handle email addresses first if enabled # Handle email addresses first if enabled
if normalization_options.email_normalization: if normalization_options.email_normalization:
text = EMAIL_PATTERN.sub(handle_email, text) text = EMAIL_PATTERN.sub(handle_email, text)
@ -415,7 +433,7 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st
text, text,
) )
# Replace quotes and brackets # Replace quotes and brackets (additional cleanup)
text = text.replace(chr(8216), "'").replace(chr(8217), "'") text = text.replace(chr(8216), "'").replace(chr(8217), "'")
text = text.replace("«", chr(8220)).replace("»", chr(8221)) text = text.replace("«", chr(8220)).replace("»", chr(8221))
text = text.replace(chr(8220), '"').replace(chr(8221), '"') text = text.replace(chr(8220), '"').replace(chr(8221), '"')
@ -435,6 +453,11 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st
text = re.sub(r" +", " ", text) text = re.sub(r" +", " ", text)
text = re.sub(r"(?<=\n) +(?=\n)", "", text) text = re.sub(r"(?<=\n) +(?=\n)", "", text)
# Handle special characters that might cause audio artifacts first
# Replace newlines with spaces (or pauses if needed)
text = text.replace('\n', ' ')
text = text.replace('\r', ' ')
# Handle titles and abbreviations # Handle titles and abbreviations
text = re.sub(r"\bD[Rr]\.(?= [A-Z])", "Doctor", text) text = re.sub(r"\bD[Rr]\.(?= [A-Z])", "Doctor", text)
text = re.sub(r"\b(?:Mr\.|MR\.(?= [A-Z]))", "Mister", text) text = re.sub(r"\b(?:Mr\.|MR\.(?= [A-Z]))", "Mister", text)
@ -445,7 +468,7 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st
# Handle common words # Handle common words
text = re.sub(r"(?i)\b(y)eah?\b", r"\1e'a", text) text = re.sub(r"(?i)\b(y)eah?\b", r"\1e'a", text)
# Handle numbers and money # Handle numbers and money BEFORE replacing special characters
text = re.sub(r"(?<=\d),(?=\d)", "", text) text = re.sub(r"(?<=\d),(?=\d)", "", text)
text = MONEY_PATTERN.sub( text = MONEY_PATTERN.sub(
@ -457,6 +480,11 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st
text = re.sub(r"\d*\.\d+", handle_decimal, text) text = re.sub(r"\d*\.\d+", handle_decimal, text)
# Handle other problematic symbols AFTER money/number processing
if normalization_options.replace_remaining_symbols:
for symbol, replacement in SYMBOL_REPLACEMENTS.items():
text = text.replace(symbol, replacement)
# Handle various formatting # Handle various formatting
text = re.sub(r"(?<=\d)-(?=\d)", " to ", text) text = re.sub(r"(?<=\d)-(?=\d)", " to ", text)
text = re.sub(r"(?<=\d)S", " S", text) text = re.sub(r"(?<=\d)S", " S", text)
@ -467,4 +495,6 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st
) )
text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text) text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)
text = re.sub(r"\s{2,}", " ", text)
return text.strip() return text.strip()

View file

@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
import phonemizer import phonemizer
from .normalizer import normalize_text from .normalizer import normalize_text
from ...structures.schemas import NormalizationOptions
phonemizers = {} phonemizers = {}
@ -75,7 +76,7 @@ def create_phonemizer(language: str = "a") -> PhonemizerBackend:
Phonemizer backend instance Phonemizer backend instance
""" """
# Map language codes to espeak language codes # Map language codes to espeak language codes
lang_map = {"a": "en-us", "b": "en-gb"} lang_map = {"a": "en-us", "b": "en-gb", "z": "z"}
if language not in lang_map: if language not in lang_map:
raise ValueError(f"Unsupported language code: {language}") raise ValueError(f"Unsupported language code: {language}")
@ -83,20 +84,24 @@ def create_phonemizer(language: str = "a") -> PhonemizerBackend:
return EspeakBackend(lang_map[language]) return EspeakBackend(lang_map[language])
def phonemize(text: str, language: str = "a", normalize: bool = True) -> str: def phonemize(text: str, language: str = "a") -> str:
"""Convert text to phonemes """Convert text to phonemes
Args: Args:
text: Text to convert to phonemes text: Text to convert to phonemes
language: Language code ('a' for US English, 'b' for British English) language: Language code ('a' for US English, 'b' for British English)
normalize: Whether to normalize text before phonemization
Returns: Returns:
Phonemized text Phonemized text
""" """
global phonemizers global phonemizers
if normalize:
text = normalize_text(text) # Strip input text first to remove problematic leading/trailing spaces
text = text.strip()
if language not in phonemizers: if language not in phonemizers:
phonemizers[language] = create_phonemizer(language) phonemizers[language] = create_phonemizer(language)
return phonemizers[language].phonemize(text)
result = phonemizers[language].phonemize(text)
# Final strip to ensure no leading/trailing spaces in phonemes
return result.strip()

View file

@ -2,7 +2,7 @@
import re import re
import time import time
from typing import AsyncGenerator, Dict, List, Tuple from typing import AsyncGenerator, Dict, List, Tuple, Optional
from loguru import logger from loguru import logger
@ -13,7 +13,11 @@ from .phonemizer import phonemize
from .vocabulary import tokenize from .vocabulary import tokenize
# Pre-compiled regex patterns for performance # Pre-compiled regex patterns for performance
CUSTOM_PHONEMES = re.compile(r"(\[([^\]]|\n)*?\])(\(\/([^\/)]|\n)*?\/\))") # Updated regex to be more strict and avoid matching isolated brackets
# Only matches complete patterns like [word](/ipa/) and prevents catastrophic backtracking
CUSTOM_PHONEMES = re.compile(r"(\[[^\[\]]*?\])(\(\/[^\/\(\)]*?\/\))")
# Pattern to find pause tags like [pause:0.5s]
PAUSE_TAG_PATTERN = re.compile(r"\[pause:(\d+(?:\.\d+)?)s\]", re.IGNORECASE)
def process_text_chunk( def process_text_chunk(
@ -31,6 +35,12 @@ def process_text_chunk(
""" """
start_time = time.time() start_time = time.time()
# Strip input text to remove any leading/trailing spaces that could cause artifacts
text = text.strip()
if not text:
return []
if skip_phonemize: if skip_phonemize:
# Input is already phonemes, just tokenize # Input is already phonemes, just tokenize
t0 = time.time() t0 = time.time()
@ -42,7 +52,9 @@ def process_text_chunk(
t1 = time.time() t1 = time.time()
t0 = time.time() t0 = time.time()
phonemes = phonemize(text, language, normalize=False) # Already normalized phonemes = phonemize(text, language)
# Strip phonemes result to ensure no extra spaces
phonemes = phonemes.strip()
t1 = time.time() t1 = time.time()
t0 = time.time() t0 = time.time()
@ -88,9 +100,15 @@ def process_text(text: str, language: str = "a") -> List[int]:
def get_sentence_info( def get_sentence_info(
text: str, custom_phenomes_list: Dict[str, str] text: str, custom_phenomes_list: Dict[str, str], lang_code: str = "a"
) -> List[Tuple[str, List[int], int]]: ) -> List[Tuple[str, List[int], int]]:
"""Process all sentences and return info.""" """Process all sentences and return info"""
# Detect Chinese text
is_chinese = lang_code.startswith("z") or re.search(r"[\u4e00-\u9fff]", text)
if is_chinese:
# Split using Chinese punctuation
sentences = re.split(r"([,。!?;])+", text)
else:
sentences = re.split(r"([.!?;:])(?=\s|$)", text) sentences = re.split(r"([.!?;:])(?=\s|$)", text)
phoneme_length, min_value = len(custom_phenomes_list), 0 phoneme_length, min_value = len(custom_phenomes_list), 0
@ -104,16 +122,16 @@ def get_sentence_info(
current_id, custom_phenomes_list.pop(current_id) current_id, custom_phenomes_list.pop(current_id)
) )
min_value += 1 min_value += 1
punct = sentences[i + 1] if i + 1 < len(sentences) else "" punct = sentences[i + 1] if i + 1 < len(sentences) else ""
if not sentence: if not sentence:
continue continue
full = sentence + punct full = sentence + punct
# Strip the full sentence to remove any leading/trailing spaces before processing
full = full.strip()
if not full: # Skip if empty after stripping
continue
tokens = process_text_chunk(full) tokens = process_text_chunk(full)
results.append((full, tokens, len(tokens))) results.append((full, tokens, len(tokens)))
return results return results
@ -128,50 +146,72 @@ async def smart_split(
max_tokens: int = settings.absolute_max_tokens, max_tokens: int = settings.absolute_max_tokens,
lang_code: str = "a", lang_code: str = "a",
normalization_options: NormalizationOptions = NormalizationOptions(), normalization_options: NormalizationOptions = NormalizationOptions(),
) -> AsyncGenerator[Tuple[str, List[int]], None]: ) -> AsyncGenerator[Tuple[str, List[int], Optional[float]], None]:
"""Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens.""" """Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens.
Yields:
Tuple of (text_chunk, tokens, pause_duration_s).
If pause_duration_s is not None, it's a pause chunk with empty text/tokens.
Otherwise, it's a text chunk containing the original text.
"""
start_time = time.time() start_time = time.time()
chunk_count = 0 chunk_count = 0
logger.info(f"Starting smart split for {len(text)} chars") logger.info(f"Starting smart split for {len(text)} chars")
# --- Step 1: Split by Pause Tags FIRST ---
# This operates on the raw input text
parts = PAUSE_TAG_PATTERN.split(text)
logger.debug(f"Split raw text into {len(parts)} parts by pause tags.")
part_idx = 0
while part_idx < len(parts):
text_part_raw = parts[part_idx] # This part is raw text
part_idx += 1
# --- Process Text Part ---
if text_part_raw and text_part_raw.strip(): # Only process if the part is not empty string
# Strip leading and trailing spaces to prevent pause tag splitting artifacts
text_part_raw = text_part_raw.strip()
# Apply the original smart_split logic to this text part
custom_phoneme_list = {} custom_phoneme_list = {}
# Normalize text # Normalize text (original logic)
processed_text = text_part_raw
if settings.advanced_text_normalization and normalization_options.normalize: if settings.advanced_text_normalization and normalization_options.normalize:
print(lang_code)
if lang_code in ["a", "b", "en-us", "en-gb"]: if lang_code in ["a", "b", "en-us", "en-gb"]:
text = CUSTOM_PHONEMES.sub( processed_text = CUSTOM_PHONEMES.sub(
lambda s: handle_custom_phonemes(s, custom_phoneme_list), text lambda s: handle_custom_phonemes(s, custom_phoneme_list), processed_text
) )
text = normalize_text(text, normalization_options) processed_text = normalize_text(processed_text, normalization_options)
else: else:
logger.info( logger.info(
"Skipping text normalization as it is only supported for english" "Skipping text normalization as it is only supported for english"
) )
# Process all sentences # Process all sentences (original logic)
sentences = get_sentence_info(text, custom_phoneme_list) sentences = get_sentence_info(processed_text, custom_phoneme_list, lang_code=lang_code)
current_chunk = [] current_chunk = []
current_tokens = [] current_tokens = []
current_count = 0 current_count = 0
for sentence, tokens, count in sentences: for sentence, tokens, count in sentences:
# Handle sentences that exceed max tokens # Handle sentences that exceed max tokens (original logic)
if count > max_tokens: if count > max_tokens:
# Yield current chunk if any # Yield current chunk if any
if current_chunk: if current_chunk:
chunk_text = " ".join(current_chunk) chunk_text = " ".join(current_chunk).strip()
chunk_count += 1 chunk_count += 1
logger.debug( logger.debug(
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)" f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)"
) )
yield chunk_text, current_tokens yield chunk_text, current_tokens, None
current_chunk = [] current_chunk = []
current_tokens = [] current_tokens = []
current_count = 0 current_count = 0
# Split long sentence on commas # Split long sentence on commas (original logic)
clauses = re.split(r"([,])", sentence) clauses = re.split(r"([,])", sentence)
clause_chunk = [] clause_chunk = []
clause_tokens = [] clause_tokens = []
@ -200,38 +240,38 @@ async def smart_split(
else: else:
# Yield clause chunk if we have one # Yield clause chunk if we have one
if clause_chunk: if clause_chunk:
chunk_text = " ".join(clause_chunk) chunk_text = " ".join(clause_chunk).strip()
chunk_count += 1 chunk_count += 1
logger.debug( logger.debug(
f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)" f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({clause_count} tokens)"
) )
yield chunk_text, clause_tokens yield chunk_text, clause_tokens, None
clause_chunk = [full_clause] clause_chunk = [full_clause]
clause_tokens = tokens clause_tokens = tokens
clause_count = count clause_count = count
# Don't forget last clause chunk # Don't forget last clause chunk
if clause_chunk: if clause_chunk:
chunk_text = " ".join(clause_chunk) chunk_text = " ".join(clause_chunk).strip()
chunk_count += 1 chunk_count += 1
logger.debug( logger.debug(
f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)" f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({clause_count} tokens)"
) )
yield chunk_text, clause_tokens yield chunk_text, clause_tokens, None
# Regular sentence handling # Regular sentence handling (original logic)
elif ( elif (
current_count >= settings.target_min_tokens current_count >= settings.target_min_tokens
and current_count + count > settings.target_max_tokens and current_count + count > settings.target_max_tokens
): ):
# If we have a good sized chunk and adding next sentence exceeds target, # If we have a good sized chunk and adding next sentence exceeds target,
# yield current chunk and start new one # yield current chunk and start new one
chunk_text = " ".join(current_chunk) chunk_text = " ".join(current_chunk).strip()
chunk_count += 1 chunk_count += 1
logger.info( logger.info(
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)" f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)"
) )
yield chunk_text, current_tokens yield chunk_text, current_tokens, None
current_chunk = [sentence] current_chunk = [sentence]
current_tokens = tokens current_tokens = tokens
current_count = count current_count = count
@ -251,26 +291,44 @@ async def smart_split(
else: else:
# Yield current chunk and start new one # Yield current chunk and start new one
if current_chunk: if current_chunk:
chunk_text = " ".join(current_chunk) chunk_text = " ".join(current_chunk).strip()
chunk_count += 1 chunk_count += 1
logger.info( logger.info(
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)" f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)"
) )
yield chunk_text, current_tokens yield chunk_text, current_tokens, None
current_chunk = [sentence] current_chunk = [sentence]
current_tokens = tokens current_tokens = tokens
current_count = count current_count = count
# Don't forget the last chunk # Don't forget the last chunk for this text part
if current_chunk: if current_chunk:
chunk_text = " ".join(current_chunk) chunk_text = " ".join(current_chunk).strip()
chunk_count += 1 chunk_count += 1
logger.info( logger.info(
f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)" f"Yielding final chunk {chunk_count} for part: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)"
) )
yield chunk_text, current_tokens yield chunk_text, current_tokens, None
# --- Handle Pause Part ---
# Check if the next part is a pause duration string
if part_idx < len(parts):
duration_str = parts[part_idx]
# Check if it looks like a valid number string captured by the regex group
if re.fullmatch(r"\d+(?:\.\d+)?", duration_str):
part_idx += 1 # Consume the duration string as it's been processed
try:
duration = float(duration_str)
if duration > 0:
chunk_count += 1
logger.info(f"Yielding pause chunk {chunk_count}: {duration}s")
yield "", [], duration # Yield pause chunk
except (ValueError, TypeError):
# This case should be rare if re.fullmatch passed, but handle anyway
logger.warning(f"Could not parse valid-looking pause duration: {duration_str}")
# --- End of parts loop ---
total_time = time.time() - start_time total_time = time.time() - start_time
logger.info( logger.info(
f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks" f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks (including pauses)"
) )

View file

@ -23,6 +23,8 @@ def tokenize(phonemes: str) -> list[int]:
Returns: Returns:
List of token IDs List of token IDs
""" """
# Strip phonemes to remove leading/trailing spaces that could cause artifacts
phonemes = phonemes.strip()
return [i for i in map(VOCAB.get, phonemes) if i is not None] return [i for i in map(VOCAB.get, phonemes) if i is not None]

View file

@ -55,6 +55,7 @@ class TTSService:
output_format: Optional[str] = None, output_format: Optional[str] = None,
is_first: bool = False, is_first: bool = False,
is_last: bool = False, is_last: bool = False,
volume_multiplier: Optional[float] = 1.0,
normalizer: Optional[AudioNormalizer] = None, normalizer: Optional[AudioNormalizer] = None,
lang_code: Optional[str] = None, lang_code: Optional[str] = None,
return_timestamps: Optional[bool] = False, return_timestamps: Optional[bool] = False,
@ -100,6 +101,7 @@ class TTSService:
lang_code=lang_code, lang_code=lang_code,
return_timestamps=return_timestamps, return_timestamps=return_timestamps,
): ):
chunk_data.audio*=volume_multiplier
# For streaming, convert to bytes # For streaming, convert to bytes
if output_format: if output_format:
try: try:
@ -141,6 +143,8 @@ class TTSService:
logger.error("Model generated empty audio chunk") logger.error("Model generated empty audio chunk")
return return
chunk_data.audio*=volume_multiplier
# For streaming, convert to bytes # For streaming, convert to bytes
if output_format: if output_format:
try: try:
@ -259,6 +263,7 @@ class TTSService:
speed: float = 1.0, speed: float = 1.0,
output_format: str = "wav", output_format: str = "wav",
lang_code: Optional[str] = None, lang_code: Optional[str] = None,
volume_multiplier: Optional[float] = 1.0,
normalization_options: Optional[NormalizationOptions] = NormalizationOptions(), normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
return_timestamps: Optional[bool] = False, return_timestamps: Optional[bool] = False,
) -> AsyncGenerator[AudioChunk, None]: ) -> AsyncGenerator[AudioChunk, None]:
@ -280,12 +285,46 @@ class TTSService:
f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream" f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream"
) )
# Process text in chunks with smart splitting # Process text in chunks with smart splitting, handling pause tags
async for chunk_text, tokens in smart_split( async for chunk_text, tokens, pause_duration_s in smart_split(
text, text,
lang_code=pipeline_lang_code, lang_code=pipeline_lang_code,
normalization_options=normalization_options, normalization_options=normalization_options,
): ):
if pause_duration_s is not None and pause_duration_s > 0:
# --- Handle Pause Chunk ---
try:
logger.debug(f"Generating {pause_duration_s}s silence chunk")
silence_samples = int(pause_duration_s * 24000) # 24kHz sample rate
# Create proper silence as int16 zeros to avoid normalization artifacts
silence_audio = np.zeros(silence_samples, dtype=np.int16)
pause_chunk = AudioChunk(audio=silence_audio, word_timestamps=[]) # Empty timestamps for silence
# Format and yield the silence chunk
if output_format:
formatted_pause_chunk = await AudioService.convert_audio(
pause_chunk, output_format, writer, speed=speed, chunk_text="",
is_last_chunk=False, trim_audio=False, normalizer=stream_normalizer,
)
if formatted_pause_chunk.output:
yield formatted_pause_chunk
else: # Raw audio mode
# For raw audio mode, silence is already in the correct format (int16)
# Skip normalization to avoid any potential artifacts
if len(pause_chunk.audio) > 0:
yield pause_chunk
# Update offset based on silence duration
current_offset += pause_duration_s
chunk_index += 1 # Count pause as a yielded chunk
except Exception as e:
logger.error(f"Failed to process pause chunk: {str(e)}")
continue
elif tokens or chunk_text.strip(): # Process if there are tokens OR non-whitespace text
# --- Handle Text Chunk ---
try: try:
# Process audio for chunk # Process audio for chunk
async for chunk_data in self._process_chunk( async for chunk_data in self._process_chunk(
@ -297,6 +336,7 @@ class TTSService:
writer, writer,
output_format, output_format,
is_first=(chunk_index == 0), is_first=(chunk_index == 0),
volume_multiplier=volume_multiplier,
is_last=False, # We'll update the last chunk later is_last=False, # We'll update the last chunk later
normalizer=stream_normalizer, normalizer=stream_normalizer,
lang_code=pipeline_lang_code, # Pass lang_code lang_code=pipeline_lang_code, # Pass lang_code
@ -307,16 +347,23 @@ class TTSService:
timestamp.start_time += current_offset timestamp.start_time += current_offset
timestamp.end_time += current_offset timestamp.end_time += current_offset
current_offset += len(chunk_data.audio) / 24000 # Update offset based on the actual duration of the generated audio chunk
chunk_duration = 0
if chunk_data.audio is not None and len(chunk_data.audio) > 0:
chunk_duration = len(chunk_data.audio) / 24000
current_offset += chunk_duration
# Yield the processed chunk (either formatted or raw)
if chunk_data.output is not None: if chunk_data.output is not None:
yield chunk_data yield chunk_data
elif chunk_data.audio is not None and len(chunk_data.audio) > 0:
yield chunk_data
else: else:
logger.warning( logger.warning(
f"No audio generated for chunk: '{chunk_text[:100]}...'" f"No audio generated for chunk: '{chunk_text[:100]}...'"
) )
chunk_index += 1
chunk_index += 1 # Increment chunk index after processing text
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}" f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}"
@ -337,6 +384,7 @@ class TTSService:
output_format, output_format,
is_first=False, is_first=False,
is_last=True, # Signal this is the last chunk is_last=True, # Signal this is the last chunk
volume_multiplier=volume_multiplier,
normalizer=stream_normalizer, normalizer=stream_normalizer,
lang_code=pipeline_lang_code, # Pass lang_code lang_code=pipeline_lang_code, # Pass lang_code
): ):
@ -356,6 +404,7 @@ class TTSService:
writer: StreamingAudioWriter, writer: StreamingAudioWriter,
speed: float = 1.0, speed: float = 1.0,
return_timestamps: bool = False, return_timestamps: bool = False,
volume_multiplier: Optional[float] = 1.0,
normalization_options: Optional[NormalizationOptions] = NormalizationOptions(), normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
lang_code: Optional[str] = None, lang_code: Optional[str] = None,
) -> AudioChunk: ) -> AudioChunk:
@ -368,6 +417,7 @@ class TTSService:
voice, voice,
writer, writer,
speed=speed, speed=speed,
volume_multiplier=volume_multiplier,
normalization_options=normalization_options, normalization_options=normalization_options,
return_timestamps=return_timestamps, return_timestamps=return_timestamps,
lang_code=lang_code, lang_code=lang_code,

View file

@ -1,3 +1,4 @@
from email.policy import default
from enum import Enum from enum import Enum
from typing import List, Literal, Optional, Union from typing import List, Literal, Optional, Union
@ -66,6 +67,10 @@ class NormalizationOptions(BaseModel):
default=True, default=True,
description="Changes phone numbers so they can be properly pronouced by kokoro", description="Changes phone numbers so they can be properly pronouced by kokoro",
) )
replace_remaining_symbols: bool = Field(
default=True,
description="Replaces the remaining symbols after normalization with their words"
)
class OpenAISpeechRequest(BaseModel): class OpenAISpeechRequest(BaseModel):
@ -108,6 +113,10 @@ class OpenAISpeechRequest(BaseModel):
default=None, default=None,
description="Optional language code to use for text processing. If not provided, will use first letter of voice name.", description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
) )
volume_multiplier: Optional[float] = Field(
default = 1.0,
description="A volume multiplier to multiply the output audio by."
)
normalization_options: Optional[NormalizationOptions] = Field( normalization_options: Optional[NormalizationOptions] = Field(
default=NormalizationOptions(), default=NormalizationOptions(),
description="Options for the normalization system", description="Options for the normalization system",
@ -152,6 +161,10 @@ class CaptionedSpeechRequest(BaseModel):
default=None, default=None,
description="Optional language code to use for text processing. If not provided, will use first letter of voice name.", description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
) )
volume_multiplier: Optional[float] = Field(
default = 1.0,
description="A volume multiplier to multiply the output audio by."
)
normalization_options: Optional[NormalizationOptions] = Field( normalization_options: Optional[NormalizationOptions] = Field(
default=NormalizationOptions(), default=NormalizationOptions(),
description="Options for the normalization system", description="Options for the normalization system",

View file

@ -175,6 +175,13 @@ def test_money():
== "The plant cost two hundred thousand dollars and eighty cents." == "The plant cost two hundred thousand dollars and eighty cents."
) )
assert (
normalize_text(
"Your shopping spree cost $674.03!", normalization_options=NormalizationOptions()
)
== "Your shopping spree cost six hundred and seventy-four dollars and three cents!"
)
assert ( assert (
normalize_text( normalize_text(
"€30.2 is in euros", normalization_options=NormalizationOptions() "€30.2 is in euros", normalization_options=NormalizationOptions()
@ -315,3 +322,12 @@ def test_non_url_text():
normalize_text("It costs $50.", normalization_options=NormalizationOptions()) normalize_text("It costs $50.", normalization_options=NormalizationOptions())
== "It costs fifty dollars." == "It costs fifty dollars."
) )
def test_remaining_symbol():
"""Test that remaining symbols are replaced"""
assert (
normalize_text(
"I love buying products @ good store here & @ other store", normalization_options=NormalizationOptions()
)
== "I love buying products at good store here and at other store"
)

View file

@ -67,7 +67,7 @@ async def test_smart_split_short_text():
"""Test smart splitting with text under max tokens.""" """Test smart splitting with text under max tokens."""
text = "This is a short test sentence." text = "This is a short test sentence."
chunks = [] chunks = []
async for chunk_text, chunk_tokens in smart_split(text): async for chunk_text, chunk_tokens, _ in smart_split(text):
chunks.append((chunk_text, chunk_tokens)) chunks.append((chunk_text, chunk_tokens))
assert len(chunks) == 1 assert len(chunks) == 1
@ -82,7 +82,7 @@ async def test_smart_split_long_text():
text = ". ".join(["This is test sentence number " + str(i) for i in range(20)]) text = ". ".join(["This is test sentence number " + str(i) for i in range(20)])
chunks = [] chunks = []
async for chunk_text, chunk_tokens in smart_split(text): async for chunk_text, chunk_tokens, _ in smart_split(text):
chunks.append((chunk_text, chunk_tokens)) chunks.append((chunk_text, chunk_tokens))
assert len(chunks) > 1 assert len(chunks) > 1
@ -98,8 +98,127 @@ async def test_smart_split_with_punctuation():
text = "First sentence! Second sentence? Third sentence; Fourth sentence: Fifth sentence." text = "First sentence! Second sentence? Third sentence; Fourth sentence: Fifth sentence."
chunks = [] chunks = []
async for chunk_text, chunk_tokens in smart_split(text): async for chunk_text, chunk_tokens, _ in smart_split(text):
chunks.append(chunk_text) chunks.append(chunk_text)
# Verify punctuation is preserved # Verify punctuation is preserved
assert all(any(p in chunk for p in "!?;:.") for chunk in chunks) assert all(any(p in chunk for p in "!?;:.") for chunk in chunks)
def test_process_text_chunk_chinese_phonemes():
"""Test processing with Chinese pinyin phonemes."""
pinyin = "nǐ hǎo lì" # Example pinyin sequence with tones
tokens = process_text_chunk(pinyin, skip_phonemize=True, language="z")
assert isinstance(tokens, list)
assert len(tokens) > 0
def test_get_sentence_info_chinese():
"""Test Chinese sentence splitting and info extraction."""
text = "这是一个句子。这是第二个句子!第三个问题?"
results = get_sentence_info(text, {}, lang_code="z")
assert len(results) == 3
for sentence, tokens, count in results:
assert isinstance(sentence, str)
assert isinstance(tokens, list)
assert isinstance(count, int)
assert count == len(tokens)
assert count > 0
@pytest.mark.asyncio
async def test_smart_split_chinese_short():
"""Test Chinese smart splitting with short text."""
text = "这是一句话。"
chunks = []
async for chunk_text, chunk_tokens, _ in smart_split(text, lang_code="z"):
chunks.append((chunk_text, chunk_tokens))
assert len(chunks) == 1
assert isinstance(chunks[0][0], str)
assert isinstance(chunks[0][1], list)
@pytest.mark.asyncio
async def test_smart_split_chinese_long():
"""Test Chinese smart splitting with longer text."""
text = "".join([f"测试句子 {i}" for i in range(20)])
chunks = []
async for chunk_text, chunk_tokens, _ in smart_split(text, lang_code="z"):
chunks.append((chunk_text, chunk_tokens))
assert len(chunks) > 1
for chunk_text, chunk_tokens in chunks:
assert isinstance(chunk_text, str)
assert isinstance(chunk_tokens, list)
assert len(chunk_tokens) > 0
@pytest.mark.asyncio
async def test_smart_split_chinese_punctuation():
"""Test Chinese smart splitting with punctuation preservation."""
text = "第一句!第二问?第三句;第四句:第五句。"
chunks = []
async for chunk_text, _, _ in smart_split(text, lang_code="z"):
chunks.append(chunk_text)
# Verify Chinese punctuation is preserved
assert all(any(p in chunk for p in "!?;:。") for chunk in chunks)
@pytest.mark.asyncio
async def test_smart_split_with_pause():
"""Test smart splitting with pause tags."""
text = "Hello world [pause:2.5s] How are you?"
chunks = []
async for chunk_text, chunk_tokens, pause_duration in smart_split(text):
chunks.append((chunk_text, chunk_tokens, pause_duration))
# Should have 3 chunks: text, pause, text
assert len(chunks) == 3
# First chunk: text
assert chunks[0][2] is None # No pause
assert "Hello world" in chunks[0][0]
assert len(chunks[0][1]) > 0
# Second chunk: pause
assert chunks[1][2] == 2.5 # 2.5 second pause
assert chunks[1][0] == "" # Empty text
assert len(chunks[1][1]) == 0 # No tokens
# Third chunk: text
assert chunks[2][2] is None # No pause
assert "How are you?" in chunks[2][0]
assert len(chunks[2][1]) > 0
@pytest.mark.asyncio
async def test_smart_split_with_two_pause():
"""Test smart splitting with two pause tags."""
text = "[pause:0.5s][pause:1.67s]0.5"
chunks = []
async for chunk_text, chunk_tokens, pause_duration in smart_split(text):
chunks.append((chunk_text, chunk_tokens, pause_duration))
# Should have 3 chunks: pause, pause, text
assert len(chunks) == 3
# First chunk: pause
assert chunks[0][2] == 0.5 # 0.5 second pause
assert chunks[0][0] == "" # Empty text
assert len(chunks[0][1]) == 0
# Second chunk: pause
assert chunks[1][2] == 1.67 # 1.67 second pause
assert chunks[1][0] == "" # Empty text
assert len(chunks[1][1]) == 0 # No tokens
# Third chunk: text
assert chunks[2][2] is None # No pause
assert "zero point five" in chunks[2][0]
assert len(chunks[2][1]) > 0

View file

@ -3,6 +3,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np import numpy as np
import pytest import pytest
import torch import torch
import os
from api.src.services.tts_service import TTSService from api.src.services.tts_service import TTSService
@ -102,7 +103,10 @@ async def test_get_voice_path_combined():
service = await TTSService.create("test_output") service = await TTSService.create("test_output")
name, path = await service._get_voices_path("voice1+voice2") name, path = await service._get_voices_path("voice1+voice2")
assert name == "voice1+voice2" assert name == "voice1+voice2"
assert path.endswith("voice1+voice2.pt") # Verify the path points to a temporary file with expected format
assert path.startswith("/tmp/")
assert "voice1+voice2" in path
assert path.endswith(".pt")
mock_save.assert_called_once() mock_save.assert_called_once()

View file

@ -3,9 +3,7 @@ import json
import requests import requests
text = """the administration has offered up a platter of repression for more than a year and is still slated to lose $400 million. text = """奶酪芝士很浓郁!臭豆腐芝士有争议?陈年奶酪价格昂贵。"""
Columbia is the largest private landowner in New York City and boasts an endowment of $14.8 billion;"""
Type = "wav" Type = "wav"
@ -15,7 +13,7 @@ response = requests.post(
json={ json={
"model": "kokoro", "model": "kokoro",
"input": text, "input": text,
"voice": "af_heart+af_sky", "voice": "zf_xiaobei",
"speed": 1.0, "speed": 1.0,
"response_format": Type, "response_format": Type,
"stream": False, "stream": False,