Merge branch 'remsky:master' into master

This commit is contained in:
Kishor Prins 2025-06-21 09:53:19 -07:00 committed by GitHub
commit 0241423375
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 512 additions and 195 deletions

View file

@ -13,7 +13,7 @@
[![Tested at Model Commit](https://img.shields.io/badge/last--tested--model--commit-1.0::9901c2b-blue)](https://huggingface.co/hexgrad/Kokoro-82M/commit/9901c2b79161b6e898b7ea857ae5298f47b8b0d6)
Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model
- Multi-language support (English, Japanese, Korean, Chinese, _Vietnamese soon_)
- Multi-language support (English, Japanese, Chinese, _Vietnamese soon_)
- OpenAI-compatible Speech endpoint, NVIDIA GPU accelerated or CPU inference with PyTorch
- ONNX support coming soon, see v0.1.5 and earlier for legacy ONNX support in the interim
- Debug endpoints for monitoring system stats, integrated web UI on localhost:8880/web

View file

@ -1 +1 @@
0.3.0
0.2.4

View file

@ -31,6 +31,7 @@ class Settings(BaseSettings):
# Audio Settings
sample_rate: int = 24000
default_volume_multiplier: float = 1.0
# Text Processing Settings
target_min_tokens: int = 175 # Target minimum tokens per chunk
target_max_tokens: int = 250 # Target maximum tokens per chunk

View file

@ -300,7 +300,7 @@ async def get_web_file_path(filename: str) -> str:
)
# Construct web directory path relative to project root
web_dir = os.path.join("/app", settings.web_player_path)
web_dir = os.path.join(root_dir, settings.web_player_path)
# Search in web directory
search_paths = [web_dir]

View file

@ -141,6 +141,8 @@ Model files not found! You need to download the Kokoro V1 model:
try:
async for chunk in self._backend.generate(*args, **kwargs):
if settings.default_volume_multiplier != 1.0:
chunk.audio *= settings.default_volume_multiplier
yield chunk
except Exception as e:
raise RuntimeError(f"Generation failed: {e}")

View file

@ -319,6 +319,7 @@ async def create_captioned_speech(
writer=writer,
speed=request.speed,
return_timestamps=request.return_timestamps,
volume_multiplier=request.volume_multiplier,
normalization_options=request.normalization_options,
lang_code=request.lang_code,
)

View file

@ -152,6 +152,7 @@ async def stream_audio_chunks(
speed=request.speed,
output_format=request.response_format,
lang_code=request.lang_code,
volume_multiplier=request.volume_multiplier,
normalization_options=request.normalization_options,
return_timestamps=unique_properties["return_timestamps"],
):
@ -300,6 +301,7 @@ async def create_speech(
voice=voice_name,
writer=writer,
speed=request.speed,
volume_multiplier=request.volume_multiplier,
normalization_options=request.normalization_options,
lang_code=request.lang_code,
)

View file

@ -80,12 +80,12 @@ class AudioNormalizer:
non_silent_index_start, non_silent_index_end = None, None
for X in range(0, len(audio_data)):
if audio_data[X] > amplitude_threshold:
if abs(audio_data[X]) > amplitude_threshold:
non_silent_index_start = X
break
for X in range(len(audio_data) - 1, -1, -1):
if audio_data[X] > amplitude_threshold:
if abs(audio_data[X]) > amplitude_threshold:
non_silent_index_end = X
break

View file

@ -32,19 +32,29 @@ class StreamingAudioWriter:
if self.format in ["wav", "flac", "mp3", "pcm", "aac", "opus"]:
if self.format != "pcm":
self.output_buffer = BytesIO()
container_options = {}
# Try disabling Xing VBR header for MP3 to fix iOS timeline reading issues
if self.format == 'mp3':
# Disable Xing VBR header
container_options = {'write_xing': '0'}
logger.debug("Disabling Xing VBR header for MP3 encoding.")
self.container = av.open(
self.output_buffer,
mode="w",
format=self.format if self.format != "aac" else "adts",
options=container_options # Pass options here
)
self.stream = self.container.add_stream(
codec_map[self.format],
sample_rate=self.sample_rate,
rate=self.sample_rate,
layout="mono" if self.channels == 1 else "stereo",
)
self.stream.bit_rate = 128000
# Set bit_rate only for codecs where it's applicable and useful
if self.format in ['mp3', 'aac', 'opus']:
self.stream.bit_rate = 128000
else:
raise ValueError(f"Unsupported format: {format}")
raise ValueError(f"Unsupported format: {self.format}") # Use self.format here
def close(self):
if hasattr(self, "container"):
@ -65,12 +75,18 @@ class StreamingAudioWriter:
if finalize:
if self.format != "pcm":
# Flush stream encoder
packets = self.stream.encode(None)
for packet in packets:
self.container.mux(packet)
# Closing the container handles writing the trailer and finalizing the file.
# No explicit flush method is available or needed here.
logger.debug("Muxed final packets.")
# Get the final bytes from the buffer *before* closing it
data = self.output_buffer.getvalue()
self.close()
self.close() # Close container and buffer
return data
if audio_data is None or len(audio_data) == 0:

View file

@ -11,7 +11,7 @@ from typing import List, Optional, Union
import inflect
from numpy import number
from text_to_num import text2num
# from text_to_num import text2num
from torch import mul
from ...structures.schemas import NormalizationOptions
@ -134,6 +134,23 @@ VALID_UNITS = {
"px": "pixel", # CSS units
}
SYMBOL_REPLACEMENTS = {
'~': ' ',
'@': ' at ',
'#': ' number ',
'$': ' dollar ',
'%': ' percent ',
'^': ' ',
'&': ' and ',
'*': ' ',
'_': ' ',
'|': ' ',
'\\': ' ',
'/': ' slash ',
'=': ' equals ',
'+': ' plus ',
}
MONEY_UNITS = {"$": ("dollar", "cent"), "£": ("pound", "pence"), "": ("euro", "cent")}
# Pre-compiled regex patterns for performance
@ -391,6 +408,7 @@ def handle_time(t: re.Match[str]) -> str:
def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
"""Normalize text for TTS processing"""
# Handle email addresses first if enabled
if normalization_options.email_normalization:
text = EMAIL_PATTERN.sub(handle_email, text)
@ -415,7 +433,7 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st
text,
)
# Replace quotes and brackets
# Replace quotes and brackets (additional cleanup)
text = text.replace(chr(8216), "'").replace(chr(8217), "'")
text = text.replace("«", chr(8220)).replace("»", chr(8221))
text = text.replace(chr(8220), '"').replace(chr(8221), '"')
@ -435,6 +453,11 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st
text = re.sub(r" +", " ", text)
text = re.sub(r"(?<=\n) +(?=\n)", "", text)
# Handle special characters that might cause audio artifacts first
# Replace newlines with spaces (or pauses if needed)
text = text.replace('\n', ' ')
text = text.replace('\r', ' ')
# Handle titles and abbreviations
text = re.sub(r"\bD[Rr]\.(?= [A-Z])", "Doctor", text)
text = re.sub(r"\b(?:Mr\.|MR\.(?= [A-Z]))", "Mister", text)
@ -445,7 +468,7 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st
# Handle common words
text = re.sub(r"(?i)\b(y)eah?\b", r"\1e'a", text)
# Handle numbers and money
# Handle numbers and money BEFORE replacing special characters
text = re.sub(r"(?<=\d),(?=\d)", "", text)
text = MONEY_PATTERN.sub(
@ -457,6 +480,11 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st
text = re.sub(r"\d*\.\d+", handle_decimal, text)
# Handle other problematic symbols AFTER money/number processing
if normalization_options.replace_remaining_symbols:
for symbol, replacement in SYMBOL_REPLACEMENTS.items():
text = text.replace(symbol, replacement)
# Handle various formatting
text = re.sub(r"(?<=\d)-(?=\d)", " to ", text)
text = re.sub(r"(?<=\d)S", " S", text)
@ -467,4 +495,6 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st
)
text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)
text = re.sub(r"\s{2,}", " ", text)
return text.strip()

View file

@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
import phonemizer
from .normalizer import normalize_text
from ...structures.schemas import NormalizationOptions
phonemizers = {}
@ -75,7 +76,7 @@ def create_phonemizer(language: str = "a") -> PhonemizerBackend:
Phonemizer backend instance
"""
# Map language codes to espeak language codes
lang_map = {"a": "en-us", "b": "en-gb"}
lang_map = {"a": "en-us", "b": "en-gb", "z": "z"}
if language not in lang_map:
raise ValueError(f"Unsupported language code: {language}")
@ -83,20 +84,24 @@ def create_phonemizer(language: str = "a") -> PhonemizerBackend:
return EspeakBackend(lang_map[language])
def phonemize(text: str, language: str = "a", normalize: bool = True) -> str:
def phonemize(text: str, language: str = "a") -> str:
"""Convert text to phonemes
Args:
text: Text to convert to phonemes
language: Language code ('a' for US English, 'b' for British English)
normalize: Whether to normalize text before phonemization
Returns:
Phonemized text
"""
global phonemizers
if normalize:
text = normalize_text(text)
# Strip input text first to remove problematic leading/trailing spaces
text = text.strip()
if language not in phonemizers:
phonemizers[language] = create_phonemizer(language)
return phonemizers[language].phonemize(text)
result = phonemizers[language].phonemize(text)
# Final strip to ensure no leading/trailing spaces in phonemes
return result.strip()

View file

@ -2,7 +2,7 @@
import re
import time
from typing import AsyncGenerator, Dict, List, Tuple
from typing import AsyncGenerator, Dict, List, Tuple, Optional
from loguru import logger
@ -13,7 +13,11 @@ from .phonemizer import phonemize
from .vocabulary import tokenize
# Pre-compiled regex patterns for performance
CUSTOM_PHONEMES = re.compile(r"(\[([^\]]|\n)*?\])(\(\/([^\/)]|\n)*?\/\))")
# Updated regex to be more strict and avoid matching isolated brackets
# Only matches complete patterns like [word](/ipa/) and prevents catastrophic backtracking
CUSTOM_PHONEMES = re.compile(r"(\[[^\[\]]*?\])(\(\/[^\/\(\)]*?\/\))")
# Pattern to find pause tags like [pause:0.5s]
PAUSE_TAG_PATTERN = re.compile(r"\[pause:(\d+(?:\.\d+)?)s\]", re.IGNORECASE)
def process_text_chunk(
@ -30,6 +34,12 @@ def process_text_chunk(
List of token IDs
"""
start_time = time.time()
# Strip input text to remove any leading/trailing spaces that could cause artifacts
text = text.strip()
if not text:
return []
if skip_phonemize:
# Input is already phonemes, just tokenize
@ -42,7 +52,9 @@ def process_text_chunk(
t1 = time.time()
t0 = time.time()
phonemes = phonemize(text, language, normalize=False) # Already normalized
phonemes = phonemize(text, language)
# Strip phonemes result to ensure no extra spaces
phonemes = phonemes.strip()
t1 = time.time()
t0 = time.time()
@ -88,10 +100,16 @@ def process_text(text: str, language: str = "a") -> List[int]:
def get_sentence_info(
text: str, custom_phenomes_list: Dict[str, str]
text: str, custom_phenomes_list: Dict[str, str], lang_code: str = "a"
) -> List[Tuple[str, List[int], int]]:
"""Process all sentences and return info."""
sentences = re.split(r"([.!?;:])(?=\s|$)", text)
"""Process all sentences and return info"""
# Detect Chinese text
is_chinese = lang_code.startswith("z") or re.search(r"[\u4e00-\u9fff]", text)
if is_chinese:
# Split using Chinese punctuation
sentences = re.split(r"([,。!?;])+", text)
else:
sentences = re.split(r"([.!?;:])(?=\s|$)", text)
phoneme_length, min_value = len(custom_phenomes_list), 0
results = []
@ -104,16 +122,16 @@ def get_sentence_info(
current_id, custom_phenomes_list.pop(current_id)
)
min_value += 1
punct = sentences[i + 1] if i + 1 < len(sentences) else ""
if not sentence:
continue
full = sentence + punct
# Strip the full sentence to remove any leading/trailing spaces before processing
full = full.strip()
if not full: # Skip if empty after stripping
continue
tokens = process_text_chunk(full)
results.append((full, tokens, len(tokens)))
return results
@ -128,149 +146,189 @@ async def smart_split(
max_tokens: int = settings.absolute_max_tokens,
lang_code: str = "a",
normalization_options: NormalizationOptions = NormalizationOptions(),
) -> AsyncGenerator[Tuple[str, List[int]], None]:
"""Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens."""
) -> AsyncGenerator[Tuple[str, List[int], Optional[float]], None]:
"""Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens.
Yields:
Tuple of (text_chunk, tokens, pause_duration_s).
If pause_duration_s is not None, it's a pause chunk with empty text/tokens.
Otherwise, it's a text chunk containing the original text.
"""
start_time = time.time()
chunk_count = 0
logger.info(f"Starting smart split for {len(text)} chars")
custom_phoneme_list = {}
# --- Step 1: Split by Pause Tags FIRST ---
# This operates on the raw input text
parts = PAUSE_TAG_PATTERN.split(text)
logger.debug(f"Split raw text into {len(parts)} parts by pause tags.")
# Normalize text
if settings.advanced_text_normalization and normalization_options.normalize:
print(lang_code)
if lang_code in ["a", "b", "en-us", "en-gb"]:
text = CUSTOM_PHONEMES.sub(
lambda s: handle_custom_phonemes(s, custom_phoneme_list), text
)
text = normalize_text(text, normalization_options)
else:
logger.info(
"Skipping text normalization as it is only supported for english"
)
part_idx = 0
while part_idx < len(parts):
text_part_raw = parts[part_idx] # This part is raw text
part_idx += 1
# Process all sentences
sentences = get_sentence_info(text, custom_phoneme_list)
# --- Process Text Part ---
if text_part_raw and text_part_raw.strip(): # Only process if the part is not empty string
# Strip leading and trailing spaces to prevent pause tag splitting artifacts
text_part_raw = text_part_raw.strip()
current_chunk = []
current_tokens = []
current_count = 0
# Apply the original smart_split logic to this text part
custom_phoneme_list = {}
for sentence, tokens, count in sentences:
# Handle sentences that exceed max tokens
if count > max_tokens:
# Yield current chunk if any
if current_chunk:
chunk_text = " ".join(current_chunk)
chunk_count += 1
logger.debug(
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
)
yield chunk_text, current_tokens
current_chunk = []
current_tokens = []
current_count = 0
# Split long sentence on commas
clauses = re.split(r"([,])", sentence)
clause_chunk = []
clause_tokens = []
clause_count = 0
for j in range(0, len(clauses), 2):
clause = clauses[j].strip()
comma = clauses[j + 1] if j + 1 < len(clauses) else ""
if not clause:
continue
full_clause = clause + comma
tokens = process_text_chunk(full_clause)
count = len(tokens)
# If adding clause keeps us under max and not optimal yet
if (
clause_count + count <= max_tokens
and clause_count + count <= settings.target_max_tokens
):
clause_chunk.append(full_clause)
clause_tokens.extend(tokens)
clause_count += count
# Normalize text (original logic)
processed_text = text_part_raw
if settings.advanced_text_normalization and normalization_options.normalize:
if lang_code in ["a", "b", "en-us", "en-gb"]:
processed_text = CUSTOM_PHONEMES.sub(
lambda s: handle_custom_phonemes(s, custom_phoneme_list), processed_text
)
processed_text = normalize_text(processed_text, normalization_options)
else:
# Yield clause chunk if we have one
if clause_chunk:
chunk_text = " ".join(clause_chunk)
logger.info(
"Skipping text normalization as it is only supported for english"
)
# Process all sentences (original logic)
sentences = get_sentence_info(processed_text, custom_phoneme_list, lang_code=lang_code)
current_chunk = []
current_tokens = []
current_count = 0
for sentence, tokens, count in sentences:
# Handle sentences that exceed max tokens (original logic)
if count > max_tokens:
# Yield current chunk if any
if current_chunk:
chunk_text = " ".join(current_chunk).strip()
chunk_count += 1
logger.debug(
f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)"
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)"
)
yield chunk_text, clause_tokens
clause_chunk = [full_clause]
clause_tokens = tokens
clause_count = count
yield chunk_text, current_tokens, None
current_chunk = []
current_tokens = []
current_count = 0
# Don't forget last clause chunk
if clause_chunk:
chunk_text = " ".join(clause_chunk)
chunk_count += 1
logger.debug(
f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)"
)
yield chunk_text, clause_tokens
# Split long sentence on commas (original logic)
clauses = re.split(r"([,])", sentence)
clause_chunk = []
clause_tokens = []
clause_count = 0
# Regular sentence handling
elif (
current_count >= settings.target_min_tokens
and current_count + count > settings.target_max_tokens
):
# If we have a good sized chunk and adding next sentence exceeds target,
# yield current chunk and start new one
chunk_text = " ".join(current_chunk)
chunk_count += 1
logger.info(
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
)
yield chunk_text, current_tokens
current_chunk = [sentence]
current_tokens = tokens
current_count = count
elif current_count + count <= settings.target_max_tokens:
# Keep building chunk while under target max
current_chunk.append(sentence)
current_tokens.extend(tokens)
current_count += count
elif (
current_count + count <= max_tokens
and current_count < settings.target_min_tokens
):
# Only exceed target max if we haven't reached minimum size yet
current_chunk.append(sentence)
current_tokens.extend(tokens)
current_count += count
else:
# Yield current chunk and start new one
for j in range(0, len(clauses), 2):
clause = clauses[j].strip()
comma = clauses[j + 1] if j + 1 < len(clauses) else ""
if not clause:
continue
full_clause = clause + comma
tokens = process_text_chunk(full_clause)
count = len(tokens)
# If adding clause keeps us under max and not optimal yet
if (
clause_count + count <= max_tokens
and clause_count + count <= settings.target_max_tokens
):
clause_chunk.append(full_clause)
clause_tokens.extend(tokens)
clause_count += count
else:
# Yield clause chunk if we have one
if clause_chunk:
chunk_text = " ".join(clause_chunk).strip()
chunk_count += 1
logger.debug(
f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({clause_count} tokens)"
)
yield chunk_text, clause_tokens, None
clause_chunk = [full_clause]
clause_tokens = tokens
clause_count = count
# Don't forget last clause chunk
if clause_chunk:
chunk_text = " ".join(clause_chunk).strip()
chunk_count += 1
logger.debug(
f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({clause_count} tokens)"
)
yield chunk_text, clause_tokens, None
# Regular sentence handling (original logic)
elif (
current_count >= settings.target_min_tokens
and current_count + count > settings.target_max_tokens
):
# If we have a good sized chunk and adding next sentence exceeds target,
# yield current chunk and start new one
chunk_text = " ".join(current_chunk).strip()
chunk_count += 1
logger.info(
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)"
)
yield chunk_text, current_tokens, None
current_chunk = [sentence]
current_tokens = tokens
current_count = count
elif current_count + count <= settings.target_max_tokens:
# Keep building chunk while under target max
current_chunk.append(sentence)
current_tokens.extend(tokens)
current_count += count
elif (
current_count + count <= max_tokens
and current_count < settings.target_min_tokens
):
# Only exceed target max if we haven't reached minimum size yet
current_chunk.append(sentence)
current_tokens.extend(tokens)
current_count += count
else:
# Yield current chunk and start new one
if current_chunk:
chunk_text = " ".join(current_chunk).strip()
chunk_count += 1
logger.info(
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)"
)
yield chunk_text, current_tokens, None
current_chunk = [sentence]
current_tokens = tokens
current_count = count
# Don't forget the last chunk for this text part
if current_chunk:
chunk_text = " ".join(current_chunk)
chunk_text = " ".join(current_chunk).strip()
chunk_count += 1
logger.info(
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
f"Yielding final chunk {chunk_count} for part: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)"
)
yield chunk_text, current_tokens
current_chunk = [sentence]
current_tokens = tokens
current_count = count
yield chunk_text, current_tokens, None
# Don't forget the last chunk
if current_chunk:
chunk_text = " ".join(current_chunk)
chunk_count += 1
logger.info(
f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
)
yield chunk_text, current_tokens
# --- Handle Pause Part ---
# Check if the next part is a pause duration string
if part_idx < len(parts):
duration_str = parts[part_idx]
# Check if it looks like a valid number string captured by the regex group
if re.fullmatch(r"\d+(?:\.\d+)?", duration_str):
part_idx += 1 # Consume the duration string as it's been processed
try:
duration = float(duration_str)
if duration > 0:
chunk_count += 1
logger.info(f"Yielding pause chunk {chunk_count}: {duration}s")
yield "", [], duration # Yield pause chunk
except (ValueError, TypeError):
# This case should be rare if re.fullmatch passed, but handle anyway
logger.warning(f"Could not parse valid-looking pause duration: {duration_str}")
# --- End of parts loop ---
total_time = time.time() - start_time
logger.info(
f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks"
f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks (including pauses)"
)

View file

@ -23,6 +23,8 @@ def tokenize(phonemes: str) -> list[int]:
Returns:
List of token IDs
"""
# Strip phonemes to remove leading/trailing spaces that could cause artifacts
phonemes = phonemes.strip()
return [i for i in map(VOCAB.get, phonemes) if i is not None]

View file

@ -55,6 +55,7 @@ class TTSService:
output_format: Optional[str] = None,
is_first: bool = False,
is_last: bool = False,
volume_multiplier: Optional[float] = 1.0,
normalizer: Optional[AudioNormalizer] = None,
lang_code: Optional[str] = None,
return_timestamps: Optional[bool] = False,
@ -100,6 +101,7 @@ class TTSService:
lang_code=lang_code,
return_timestamps=return_timestamps,
):
chunk_data.audio*=volume_multiplier
# For streaming, convert to bytes
if output_format:
try:
@ -132,7 +134,7 @@ class TTSService:
speed=speed,
return_timestamps=return_timestamps,
)
if chunk_data.audio is None:
logger.error("Model generated None for audio chunk")
return
@ -141,6 +143,8 @@ class TTSService:
logger.error("Model generated empty audio chunk")
return
chunk_data.audio*=volume_multiplier
# For streaming, convert to bytes
if output_format:
try:
@ -259,6 +263,7 @@ class TTSService:
speed: float = 1.0,
output_format: str = "wav",
lang_code: Optional[str] = None,
volume_multiplier: Optional[float] = 1.0,
normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
return_timestamps: Optional[bool] = False,
) -> AsyncGenerator[AudioChunk, None]:
@ -280,48 +285,90 @@ class TTSService:
f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream"
)
# Process text in chunks with smart splitting
async for chunk_text, tokens in smart_split(
# Process text in chunks with smart splitting, handling pause tags
async for chunk_text, tokens, pause_duration_s in smart_split(
text,
lang_code=pipeline_lang_code,
normalization_options=normalization_options,
):
try:
# Process audio for chunk
async for chunk_data in self._process_chunk(
chunk_text, # Pass text for Kokoro V1
tokens, # Pass tokens for legacy backends
voice_name, # Pass voice name
voice_path, # Pass voice path
speed,
writer,
output_format,
is_first=(chunk_index == 0),
is_last=False, # We'll update the last chunk later
normalizer=stream_normalizer,
lang_code=pipeline_lang_code, # Pass lang_code
return_timestamps=return_timestamps,
):
if chunk_data.word_timestamps is not None:
for timestamp in chunk_data.word_timestamps:
timestamp.start_time += current_offset
timestamp.end_time += current_offset
if pause_duration_s is not None and pause_duration_s > 0:
# --- Handle Pause Chunk ---
try:
logger.debug(f"Generating {pause_duration_s}s silence chunk")
silence_samples = int(pause_duration_s * 24000) # 24kHz sample rate
# Create proper silence as int16 zeros to avoid normalization artifacts
silence_audio = np.zeros(silence_samples, dtype=np.int16)
pause_chunk = AudioChunk(audio=silence_audio, word_timestamps=[]) # Empty timestamps for silence
current_offset += len(chunk_data.audio) / 24000
# Format and yield the silence chunk
if output_format:
formatted_pause_chunk = await AudioService.convert_audio(
pause_chunk, output_format, writer, speed=speed, chunk_text="",
is_last_chunk=False, trim_audio=False, normalizer=stream_normalizer,
if chunk_data.output is not None:
yield chunk_data
else:
logger.warning(
f"No audio generated for chunk: '{chunk_text[:100]}...'"
)
chunk_index += 1
except Exception as e:
logger.error(
f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}"
)
continue
if formatted_pause_chunk.output:
yield formatted_pause_chunk
else: # Raw audio mode
# For raw audio mode, silence is already in the correct format (int16)
# Skip normalization to avoid any potential artifacts
if len(pause_chunk.audio) > 0:
yield pause_chunk
# Update offset based on silence duration
current_offset += pause_duration_s
chunk_index += 1 # Count pause as a yielded chunk
except Exception as e:
logger.error(f"Failed to process pause chunk: {str(e)}")
continue
elif tokens or chunk_text.strip(): # Process if there are tokens OR non-whitespace text
# --- Handle Text Chunk ---
try:
# Process audio for chunk
async for chunk_data in self._process_chunk(
chunk_text, # Pass text for Kokoro V1
tokens, # Pass tokens for legacy backends
voice_name, # Pass voice name
voice_path, # Pass voice path
speed,
writer,
output_format,
is_first=(chunk_index == 0),
volume_multiplier=volume_multiplier,
is_last=False, # We'll update the last chunk later
normalizer=stream_normalizer,
lang_code=pipeline_lang_code, # Pass lang_code
return_timestamps=return_timestamps,
):
if chunk_data.word_timestamps is not None:
for timestamp in chunk_data.word_timestamps:
timestamp.start_time += current_offset
timestamp.end_time += current_offset
# Update offset based on the actual duration of the generated audio chunk
chunk_duration = 0
if chunk_data.audio is not None and len(chunk_data.audio) > 0:
chunk_duration = len(chunk_data.audio) / 24000
current_offset += chunk_duration
# Yield the processed chunk (either formatted or raw)
if chunk_data.output is not None:
yield chunk_data
elif chunk_data.audio is not None and len(chunk_data.audio) > 0:
yield chunk_data
else:
logger.warning(
f"No audio generated for chunk: '{chunk_text[:100]}...'"
)
chunk_index += 1 # Increment chunk index after processing text
except Exception as e:
logger.error(
f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}"
)
continue
# Only finalize if we successfully processed at least one chunk
if chunk_index > 0:
@ -337,6 +384,7 @@ class TTSService:
output_format,
is_first=False,
is_last=True, # Signal this is the last chunk
volume_multiplier=volume_multiplier,
normalizer=stream_normalizer,
lang_code=pipeline_lang_code, # Pass lang_code
):
@ -356,6 +404,7 @@ class TTSService:
writer: StreamingAudioWriter,
speed: float = 1.0,
return_timestamps: bool = False,
volume_multiplier: Optional[float] = 1.0,
normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
lang_code: Optional[str] = None,
) -> AudioChunk:
@ -368,6 +417,7 @@ class TTSService:
voice,
writer,
speed=speed,
volume_multiplier=volume_multiplier,
normalization_options=normalization_options,
return_timestamps=return_timestamps,
lang_code=lang_code,

View file

@ -1,3 +1,4 @@
from email.policy import default
from enum import Enum
from typing import List, Literal, Optional, Union
@ -66,6 +67,10 @@ class NormalizationOptions(BaseModel):
default=True,
description="Changes phone numbers so they can be properly pronouced by kokoro",
)
replace_remaining_symbols: bool = Field(
default=True,
description="Replaces the remaining symbols after normalization with their words"
)
class OpenAISpeechRequest(BaseModel):
@ -108,6 +113,10 @@ class OpenAISpeechRequest(BaseModel):
default=None,
description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
)
volume_multiplier: Optional[float] = Field(
default = 1.0,
description="A volume multiplier to multiply the output audio by."
)
normalization_options: Optional[NormalizationOptions] = Field(
default=NormalizationOptions(),
description="Options for the normalization system",
@ -152,6 +161,10 @@ class CaptionedSpeechRequest(BaseModel):
default=None,
description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
)
volume_multiplier: Optional[float] = Field(
default = 1.0,
description="A volume multiplier to multiply the output audio by."
)
normalization_options: Optional[NormalizationOptions] = Field(
default=NormalizationOptions(),
description="Options for the normalization system",

View file

@ -175,6 +175,13 @@ def test_money():
== "The plant cost two hundred thousand dollars and eighty cents."
)
assert (
normalize_text(
"Your shopping spree cost $674.03!", normalization_options=NormalizationOptions()
)
== "Your shopping spree cost six hundred and seventy-four dollars and three cents!"
)
assert (
normalize_text(
"€30.2 is in euros", normalization_options=NormalizationOptions()
@ -315,3 +322,12 @@ def test_non_url_text():
normalize_text("It costs $50.", normalization_options=NormalizationOptions())
== "It costs fifty dollars."
)
def test_remaining_symbol():
"""Test that remaining symbols are replaced"""
assert (
normalize_text(
"I love buying products @ good store here & @ other store", normalization_options=NormalizationOptions()
)
== "I love buying products at good store here and at other store"
)

View file

@ -67,7 +67,7 @@ async def test_smart_split_short_text():
"""Test smart splitting with text under max tokens."""
text = "This is a short test sentence."
chunks = []
async for chunk_text, chunk_tokens in smart_split(text):
async for chunk_text, chunk_tokens, _ in smart_split(text):
chunks.append((chunk_text, chunk_tokens))
assert len(chunks) == 1
@ -82,7 +82,7 @@ async def test_smart_split_long_text():
text = ". ".join(["This is test sentence number " + str(i) for i in range(20)])
chunks = []
async for chunk_text, chunk_tokens in smart_split(text):
async for chunk_text, chunk_tokens, _ in smart_split(text):
chunks.append((chunk_text, chunk_tokens))
assert len(chunks) > 1
@ -98,8 +98,127 @@ async def test_smart_split_with_punctuation():
text = "First sentence! Second sentence? Third sentence; Fourth sentence: Fifth sentence."
chunks = []
async for chunk_text, chunk_tokens in smart_split(text):
async for chunk_text, chunk_tokens, _ in smart_split(text):
chunks.append(chunk_text)
# Verify punctuation is preserved
assert all(any(p in chunk for p in "!?;:.") for chunk in chunks)
def test_process_text_chunk_chinese_phonemes():
"""Test processing with Chinese pinyin phonemes."""
pinyin = "nǐ hǎo lì" # Example pinyin sequence with tones
tokens = process_text_chunk(pinyin, skip_phonemize=True, language="z")
assert isinstance(tokens, list)
assert len(tokens) > 0
def test_get_sentence_info_chinese():
"""Test Chinese sentence splitting and info extraction."""
text = "这是一个句子。这是第二个句子!第三个问题?"
results = get_sentence_info(text, {}, lang_code="z")
assert len(results) == 3
for sentence, tokens, count in results:
assert isinstance(sentence, str)
assert isinstance(tokens, list)
assert isinstance(count, int)
assert count == len(tokens)
assert count > 0
@pytest.mark.asyncio
async def test_smart_split_chinese_short():
"""Test Chinese smart splitting with short text."""
text = "这是一句话。"
chunks = []
async for chunk_text, chunk_tokens, _ in smart_split(text, lang_code="z"):
chunks.append((chunk_text, chunk_tokens))
assert len(chunks) == 1
assert isinstance(chunks[0][0], str)
assert isinstance(chunks[0][1], list)
@pytest.mark.asyncio
async def test_smart_split_chinese_long():
"""Test Chinese smart splitting with longer text."""
text = "".join([f"测试句子 {i}" for i in range(20)])
chunks = []
async for chunk_text, chunk_tokens, _ in smart_split(text, lang_code="z"):
chunks.append((chunk_text, chunk_tokens))
assert len(chunks) > 1
for chunk_text, chunk_tokens in chunks:
assert isinstance(chunk_text, str)
assert isinstance(chunk_tokens, list)
assert len(chunk_tokens) > 0
@pytest.mark.asyncio
async def test_smart_split_chinese_punctuation():
"""Test Chinese smart splitting with punctuation preservation."""
text = "第一句!第二问?第三句;第四句:第五句。"
chunks = []
async for chunk_text, _, _ in smart_split(text, lang_code="z"):
chunks.append(chunk_text)
# Verify Chinese punctuation is preserved
assert all(any(p in chunk for p in "!?;:。") for chunk in chunks)
@pytest.mark.asyncio
async def test_smart_split_with_pause():
"""Test smart splitting with pause tags."""
text = "Hello world [pause:2.5s] How are you?"
chunks = []
async for chunk_text, chunk_tokens, pause_duration in smart_split(text):
chunks.append((chunk_text, chunk_tokens, pause_duration))
# Should have 3 chunks: text, pause, text
assert len(chunks) == 3
# First chunk: text
assert chunks[0][2] is None # No pause
assert "Hello world" in chunks[0][0]
assert len(chunks[0][1]) > 0
# Second chunk: pause
assert chunks[1][2] == 2.5 # 2.5 second pause
assert chunks[1][0] == "" # Empty text
assert len(chunks[1][1]) == 0 # No tokens
# Third chunk: text
assert chunks[2][2] is None # No pause
assert "How are you?" in chunks[2][0]
assert len(chunks[2][1]) > 0
@pytest.mark.asyncio
async def test_smart_split_with_two_pause():
"""Test smart splitting with two pause tags."""
text = "[pause:0.5s][pause:1.67s]0.5"
chunks = []
async for chunk_text, chunk_tokens, pause_duration in smart_split(text):
chunks.append((chunk_text, chunk_tokens, pause_duration))
# Should have 3 chunks: pause, pause, text
assert len(chunks) == 3
# First chunk: pause
assert chunks[0][2] == 0.5 # 0.5 second pause
assert chunks[0][0] == "" # Empty text
assert len(chunks[0][1]) == 0
# Second chunk: pause
assert chunks[1][2] == 1.67 # 1.67 second pause
assert chunks[1][0] == "" # Empty text
assert len(chunks[1][1]) == 0 # No tokens
# Third chunk: text
assert chunks[2][2] is None # No pause
assert "zero point five" in chunks[2][0]
assert len(chunks[2][1]) > 0

View file

@ -3,6 +3,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
import torch
import os
from api.src.services.tts_service import TTSService
@ -102,7 +103,10 @@ async def test_get_voice_path_combined():
service = await TTSService.create("test_output")
name, path = await service._get_voices_path("voice1+voice2")
assert name == "voice1+voice2"
assert path.endswith("voice1+voice2.pt")
# Verify the path points to a temporary file with expected format
assert path.startswith("/tmp/")
assert "voice1+voice2" in path
assert path.endswith(".pt")
mock_save.assert_called_once()

View file

@ -3,9 +3,7 @@ import json
import requests
text = """the administration has offered up a platter of repression for more than a year and is still slated to lose $400 million.
Columbia is the largest private landowner in New York City and boasts an endowment of $14.8 billion;"""
text = """奶酪芝士很浓郁!臭豆腐芝士有争议?陈年奶酪价格昂贵。"""
Type = "wav"
@ -15,7 +13,7 @@ response = requests.post(
json={
"model": "kokoro",
"input": text,
"voice": "af_heart+af_sky",
"voice": "zf_xiaobei",
"speed": 1.0,
"response_format": Type,
"stream": False,