Merge remote-tracking branch 'origin/master' into release
Some checks failed
Create Release and Publish Docker Images / prepare-release (push) Has been cancelled
Create Release and Publish Docker Images / build-images (push) Has been cancelled
Create Release and Publish Docker Images / create-release (push) Has been cancelled

This commit is contained in:
Fireblade2534 2025-06-18 22:04:28 +00:00
commit dd8aa26813
22 changed files with 833 additions and 252 deletions

View file

@ -13,7 +13,7 @@
[![Tested at Model Commit](https://img.shields.io/badge/last--tested--model--commit-1.0::9901c2b-blue)](https://huggingface.co/hexgrad/Kokoro-82M/commit/9901c2b79161b6e898b7ea857ae5298f47b8b0d6) [![Tested at Model Commit](https://img.shields.io/badge/last--tested--model--commit-1.0::9901c2b-blue)](https://huggingface.co/hexgrad/Kokoro-82M/commit/9901c2b79161b6e898b7ea857ae5298f47b8b0d6)
Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model
- Multi-language support (English, Japanese, Korean, Chinese, _Vietnamese soon_) - Multi-language support (English, Japanese, Chinese, _Vietnamese soon_)
- OpenAI-compatible Speech endpoint, NVIDIA GPU accelerated or CPU inference with PyTorch - OpenAI-compatible Speech endpoint, NVIDIA GPU accelerated or CPU inference with PyTorch
- ONNX support coming soon, see v0.1.5 and earlier for legacy ONNX support in the interim - ONNX support coming soon, see v0.1.5 and earlier for legacy ONNX support in the interim
- Debug endpoints for monitoring system stats, integrated web UI on localhost:8880/web - Debug endpoints for monitoring system stats, integrated web UI on localhost:8880/web
@ -34,10 +34,12 @@ Pre built images are available to run, with arm/multi-arch support, and baked in
Refer to the core/config.py file for a full list of variables which can be managed via the environment Refer to the core/config.py file for a full list of variables which can be managed via the environment
```bash ```bash
# the `latest` tag can be used, though it may have some unexpected bonus features which impact stability. Named versions should be pinned for your regular usage. Feedback/testing is always welcome # the `latest` tag can be used, though it may have some unexpected bonus features which impact stability.
Named versions should be pinned for your regular usage.
Feedback/testing is always welcome
docker run -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-cpu:v0.3.0 # CPU, or: docker run -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-cpu:latest # CPU, or:
docker run --gpus all -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-gpu:v0.3.0 #NVIDIA GPU docker run --gpus all -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-gpu:latest #NVIDIA GPU
``` ```

View file

@ -1 +1 @@
0.3.0 0.2.4

View file

@ -31,6 +31,7 @@ class Settings(BaseSettings):
# Audio Settings # Audio Settings
sample_rate: int = 24000 sample_rate: int = 24000
default_volume_multiplier: float = 1.0
# Text Processing Settings # Text Processing Settings
target_min_tokens: int = 175 # Target minimum tokens per chunk target_min_tokens: int = 175 # Target minimum tokens per chunk
target_max_tokens: int = 250 # Target maximum tokens per chunk target_max_tokens: int = 250 # Target maximum tokens per chunk

View file

@ -300,7 +300,7 @@ async def get_web_file_path(filename: str) -> str:
) )
# Construct web directory path relative to project root # Construct web directory path relative to project root
web_dir = os.path.join("/app", settings.web_player_path) web_dir = os.path.join(root_dir, settings.web_player_path)
# Search in web directory # Search in web directory
search_paths = [web_dir] search_paths = [web_dir]

View file

@ -141,6 +141,8 @@ Model files not found! You need to download the Kokoro V1 model:
try: try:
async for chunk in self._backend.generate(*args, **kwargs): async for chunk in self._backend.generate(*args, **kwargs):
if settings.default_volume_multiplier != 1.0:
chunk.audio *= settings.default_volume_multiplier
yield chunk yield chunk
except Exception as e: except Exception as e:
raise RuntimeError(f"Generation failed: {e}") raise RuntimeError(f"Generation failed: {e}")

View file

@ -104,7 +104,7 @@ async def generate_from_phonemes(
if chunk_audio is not None: if chunk_audio is not None:
# Normalize audio before writing # Normalize audio before writing
normalized_audio = await normalizer.normalize(chunk_audio) normalized_audio = normalizer.normalize(chunk_audio)
# Write chunk and yield bytes # Write chunk and yield bytes
chunk_bytes = writer.write_chunk(normalized_audio) chunk_bytes = writer.write_chunk(normalized_audio)
if chunk_bytes: if chunk_bytes:
@ -114,6 +114,7 @@ async def generate_from_phonemes(
final_bytes = writer.write_chunk(finalize=True) final_bytes = writer.write_chunk(finalize=True)
if final_bytes: if final_bytes:
yield final_bytes yield final_bytes
writer.close()
else: else:
raise ValueError("Failed to generate audio data") raise ValueError("Failed to generate audio data")
@ -223,10 +224,13 @@ async def create_captioned_speech(
).decode("utf-8") ).decode("utf-8")
# Add any chunks that may be in the acumulator into the return word_timestamps # Add any chunks that may be in the acumulator into the return word_timestamps
chunk_data.word_timestamps = ( if chunk_data.word_timestamps is not None:
timestamp_acumulator + chunk_data.word_timestamps chunk_data.word_timestamps = (
) timestamp_acumulator + chunk_data.word_timestamps
timestamp_acumulator = [] )
timestamp_acumulator = []
else:
chunk_data.word_timestamps = []
yield CaptionedSpeechResponse( yield CaptionedSpeechResponse(
audio=base64_chunk, audio=base64_chunk,
@ -271,7 +275,7 @@ async def create_captioned_speech(
) )
# Add any chunks that may be in the acumulator into the return word_timestamps # Add any chunks that may be in the acumulator into the return word_timestamps
if chunk_data.word_timestamps != None: if chunk_data.word_timestamps is not None:
chunk_data.word_timestamps = ( chunk_data.word_timestamps = (
timestamp_acumulator + chunk_data.word_timestamps timestamp_acumulator + chunk_data.word_timestamps
) )
@ -315,6 +319,7 @@ async def create_captioned_speech(
writer=writer, writer=writer,
speed=request.speed, speed=request.speed,
return_timestamps=request.return_timestamps, return_timestamps=request.return_timestamps,
volume_multiplier=request.volume_multiplier,
normalization_options=request.normalization_options, normalization_options=request.normalization_options,
lang_code=request.lang_code, lang_code=request.lang_code,
) )

View file

@ -152,6 +152,7 @@ async def stream_audio_chunks(
speed=request.speed, speed=request.speed,
output_format=request.response_format, output_format=request.response_format,
lang_code=request.lang_code, lang_code=request.lang_code,
volume_multiplier=request.volume_multiplier,
normalization_options=request.normalization_options, normalization_options=request.normalization_options,
return_timestamps=unique_properties["return_timestamps"], return_timestamps=unique_properties["return_timestamps"],
): ):
@ -300,6 +301,7 @@ async def create_speech(
voice=voice_name, voice=voice_name,
writer=writer, writer=writer,
speed=request.speed, speed=request.speed,
volume_multiplier=request.volume_multiplier,
normalization_options=request.normalization_options, normalization_options=request.normalization_options,
lang_code=request.lang_code, lang_code=request.lang_code,
) )

View file

@ -80,12 +80,12 @@ class AudioNormalizer:
non_silent_index_start, non_silent_index_end = None, None non_silent_index_start, non_silent_index_end = None, None
for X in range(0, len(audio_data)): for X in range(0, len(audio_data)):
if audio_data[X] > amplitude_threshold: if abs(audio_data[X]) > amplitude_threshold:
non_silent_index_start = X non_silent_index_start = X
break break
for X in range(len(audio_data) - 1, -1, -1): for X in range(len(audio_data) - 1, -1, -1):
if audio_data[X] > amplitude_threshold: if abs(audio_data[X]) > amplitude_threshold:
non_silent_index_end = X non_silent_index_end = X
break break

View file

@ -32,19 +32,29 @@ class StreamingAudioWriter:
if self.format in ["wav", "flac", "mp3", "pcm", "aac", "opus"]: if self.format in ["wav", "flac", "mp3", "pcm", "aac", "opus"]:
if self.format != "pcm": if self.format != "pcm":
self.output_buffer = BytesIO() self.output_buffer = BytesIO()
container_options = {}
# Try disabling Xing VBR header for MP3 to fix iOS timeline reading issues
if self.format == 'mp3':
# Disable Xing VBR header
container_options = {'write_xing': '0'}
logger.debug("Disabling Xing VBR header for MP3 encoding.")
self.container = av.open( self.container = av.open(
self.output_buffer, self.output_buffer,
mode="w", mode="w",
format=self.format if self.format != "aac" else "adts", format=self.format if self.format != "aac" else "adts",
options=container_options # Pass options here
) )
self.stream = self.container.add_stream( self.stream = self.container.add_stream(
codec_map[self.format], codec_map[self.format],
sample_rate=self.sample_rate, rate=self.sample_rate,
layout="mono" if self.channels == 1 else "stereo", layout="mono" if self.channels == 1 else "stereo",
) )
self.stream.bit_rate = 128000 # Set bit_rate only for codecs where it's applicable and useful
if self.format in ['mp3', 'aac', 'opus']:
self.stream.bit_rate = 128000
else: else:
raise ValueError(f"Unsupported format: {format}") raise ValueError(f"Unsupported format: {self.format}") # Use self.format here
def close(self): def close(self):
if hasattr(self, "container"): if hasattr(self, "container"):
@ -65,12 +75,18 @@ class StreamingAudioWriter:
if finalize: if finalize:
if self.format != "pcm": if self.format != "pcm":
# Flush stream encoder
packets = self.stream.encode(None) packets = self.stream.encode(None)
for packet in packets: for packet in packets:
self.container.mux(packet) self.container.mux(packet)
# Closing the container handles writing the trailer and finalizing the file.
# No explicit flush method is available or needed here.
logger.debug("Muxed final packets.")
# Get the final bytes from the buffer *before* closing it
data = self.output_buffer.getvalue() data = self.output_buffer.getvalue()
self.close() self.close() # Close container and buffer
return data return data
if audio_data is None or len(audio_data) == 0: if audio_data is None or len(audio_data) == 0:

View file

@ -4,12 +4,14 @@ Handles various text formats including URLs, emails, numbers, money, and special
Converts them into a format suitable for text-to-speech processing. Converts them into a format suitable for text-to-speech processing.
""" """
import math
import re import re
from functools import lru_cache from functools import lru_cache
from typing import List, Optional, Union
import inflect import inflect
from numpy import number from numpy import number
from text_to_num import text2num # from text_to_num import text2num
from torch import mul from torch import mul
from ...structures.schemas import NormalizationOptions from ...structures.schemas import NormalizationOptions
@ -132,6 +134,24 @@ VALID_UNITS = {
"px": "pixel", # CSS units "px": "pixel", # CSS units
} }
SYMBOL_REPLACEMENTS = {
'~': ' ',
'@': ' at ',
'#': ' number ',
'$': ' dollar ',
'%': ' percent ',
'^': ' ',
'&': ' and ',
'*': ' ',
'_': ' ',
'|': ' ',
'\\': ' ',
'/': ' slash ',
'=': ' equals ',
'+': ' plus ',
}
MONEY_UNITS = {"$": ("dollar", "cent"), "£": ("pound", "pence"), "": ("euro", "cent")}
# Pre-compiled regex patterns for performance # Pre-compiled regex patterns for performance
EMAIL_PATTERN = re.compile( EMAIL_PATTERN = re.compile(
@ -152,37 +172,24 @@ UNIT_PATTERN = re.compile(
) )
TIME_PATTERN = re.compile( TIME_PATTERN = re.compile(
r"([0-9]{2} ?: ?[0-9]{2}( ?: ?[0-9]{2})?)( ?(pm|am)\b)?", re.IGNORECASE r"([0-9]{1,2} ?: ?[0-9]{2}( ?: ?[0-9]{2})?)( ?(pm|am)\b)?", re.IGNORECASE
)
MONEY_PATTERN = re.compile(
r"(-?)(["
+ "".join(MONEY_UNITS.keys())
+ r"])(\d+(?:\.\d+)?)((?: hundred| thousand| (?:[bm]|tr|quadr)illion|k|m|b|t)*)\b",
re.IGNORECASE,
)
NUMBER_PATTERN = re.compile(
r"(-?)(\d+(?:\.\d+)?)((?: hundred| thousand| (?:[bm]|tr|quadr)illion|k|m|b)*)\b",
re.IGNORECASE,
) )
INFLECT_ENGINE = inflect.engine() INFLECT_ENGINE = inflect.engine()
def split_num(num: re.Match[str]) -> str:
"""Handle number splitting for various formats"""
num = num.group()
if "." in num:
return num
elif ":" in num:
h, m = [int(n) for n in num.split(":")]
if m == 0:
return f"{h} o'clock"
elif m < 10:
return f"{h} oh {m}"
return f"{h} {m}"
year = int(num[:4])
if year < 1100 or year % 1000 < 10:
return num
left, right = num[:2], int(num[2:4])
s = "s" if num.endswith("s") else ""
if 100 <= year % 1000 <= 999:
if right == 0:
return f"{left} hundred{s}"
elif right < 10:
return f"{left} oh {right}{s}"
return f"{left} {right}{s}"
def handle_units(u: re.Match[str]) -> str: def handle_units(u: re.Match[str]) -> str:
"""Converts units to their full form""" """Converts units to their full form"""
unit_string = u.group(6).strip() unit_string = u.group(6).strip()
@ -208,14 +215,61 @@ def conditional_int(number: float, threshold: float = 0.00001):
return number return number
def translate_multiplier(multiplier: str) -> str:
"""Translate multiplier abrevations to words"""
multiplier_translation = {
"k": "thousand",
"m": "million",
"b": "billion",
"t": "trillion",
}
if multiplier.lower() in multiplier_translation:
return multiplier_translation[multiplier.lower()]
return multiplier.strip()
def split_four_digit(number: float):
part1 = str(conditional_int(number))[:2]
part2 = str(conditional_int(number))[2:]
return f"{INFLECT_ENGINE.number_to_words(part1)} {INFLECT_ENGINE.number_to_words(part2)}"
def handle_numbers(n: re.Match[str]) -> str:
number = n.group(2)
try:
number = float(number)
except:
return n.group()
if n.group(1) == "-":
number *= -1
multiplier = translate_multiplier(n.group(3))
number = conditional_int(number)
if multiplier != "":
multiplier = f" {multiplier}"
else:
if (
number % 1 == 0
and len(str(number)) == 4
and number > 1500
and number % 1000 > 9
):
return split_four_digit(number)
return f"{INFLECT_ENGINE.number_to_words(number)}{multiplier}"
def handle_money(m: re.Match[str]) -> str: def handle_money(m: re.Match[str]) -> str:
"""Convert money expressions to spoken form""" """Convert money expressions to spoken form"""
bill = "dollar" if m.group(2) == "$" else "pound" bill, coin = MONEY_UNITS[m.group(2)]
coin = "cent" if m.group(2) == "$" else "pence"
number = m.group(3) number = m.group(3)
multiplier = m.group(4)
try: try:
number = float(number) number = float(number)
except: except:
@ -224,12 +278,17 @@ def handle_money(m: re.Match[str]) -> str:
if m.group(1) == "-": if m.group(1) == "-":
number *= -1 number *= -1
multiplier = translate_multiplier(m.group(4))
if multiplier != "":
multiplier = f" {multiplier}"
if number % 1 == 0 or multiplier != "": if number % 1 == 0 or multiplier != "":
text_number = f"{INFLECT_ENGINE.number_to_words(conditional_int(number))}{multiplier} {INFLECT_ENGINE.plural(bill, count=number)}" text_number = f"{INFLECT_ENGINE.number_to_words(conditional_int(number))}{multiplier} {INFLECT_ENGINE.plural(bill, count=number)}"
else: else:
sub_number = int(str(number).split(".")[-1].ljust(2, "0")) sub_number = int(str(number).split(".")[-1].ljust(2, "0"))
text_number = f"{INFLECT_ENGINE.number_to_words(int(round(number)))} {INFLECT_ENGINE.plural(bill, count=number)} and {INFLECT_ENGINE.number_to_words(sub_number)} {INFLECT_ENGINE.plural(coin, count=sub_number)}" text_number = f"{INFLECT_ENGINE.number_to_words(int(math.floor(number)))} {INFLECT_ENGINE.plural(bill, count=number)} and {INFLECT_ENGINE.number_to_words(sub_number)} {INFLECT_ENGINE.plural(coin, count=sub_number)}"
return text_number return text_number
@ -320,19 +379,36 @@ def handle_phone_number(p: re.Match[str]) -> str:
def handle_time(t: re.Match[str]) -> str: def handle_time(t: re.Match[str]) -> str:
t = t.groups() t = t.groups()
numbers = " ".join( time_parts = t[0].split(":")
[INFLECT_ENGINE.number_to_words(X.strip()) for X in t[0].split(":")]
) numbers = []
numbers.append(INFLECT_ENGINE.number_to_words(time_parts[0].strip()))
minute_number = INFLECT_ENGINE.number_to_words(time_parts[1].strip())
if int(time_parts[1]) < 10:
if int(time_parts[1]) != 0:
numbers.append(f"oh {minute_number}")
else:
numbers.append(minute_number)
half = "" half = ""
if t[2] is not None: if len(time_parts) > 2:
half = t[2].strip() seconds_number = INFLECT_ENGINE.number_to_words(time_parts[2].strip())
second_word = INFLECT_ENGINE.plural("second", int(time_parts[2].strip()))
numbers.append(f"and {seconds_number} {second_word}")
else:
if t[2] is not None:
half = " " + t[2].strip()
else:
if int(time_parts[1]) == 0:
numbers.append("o'clock")
return numbers + half return " ".join(numbers) + half
def normalize_text(text: str, normalization_options: NormalizationOptions) -> str: def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
"""Normalize text for TTS processing""" """Normalize text for TTS processing"""
# Handle email addresses first if enabled # Handle email addresses first if enabled
if normalization_options.email_normalization: if normalization_options.email_normalization:
text = EMAIL_PATTERN.sub(handle_email, text) text = EMAIL_PATTERN.sub(handle_email, text)
@ -357,7 +433,7 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st
text, text,
) )
# Replace quotes and brackets # Replace quotes and brackets (additional cleanup)
text = text.replace(chr(8216), "'").replace(chr(8217), "'") text = text.replace(chr(8216), "'").replace(chr(8217), "'")
text = text.replace("«", chr(8220)).replace("»", chr(8221)) text = text.replace("«", chr(8220)).replace("»", chr(8221))
text = text.replace(chr(8220), '"').replace(chr(8221), '"') text = text.replace(chr(8220), '"').replace(chr(8221), '"')
@ -366,7 +442,7 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st
for a, b in zip("、。!,:;?–", ",.!,:;?-"): for a, b in zip("、。!,:;?–", ",.!,:;?-"):
text = text.replace(a, b + " ") text = text.replace(a, b + " ")
# Handle simple time in the format of HH:MM:SS # Handle simple time in the format of HH:MM:SS (am/pm)
text = TIME_PATTERN.sub( text = TIME_PATTERN.sub(
handle_time, handle_time,
text, text,
@ -377,6 +453,11 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st
text = re.sub(r" +", " ", text) text = re.sub(r" +", " ", text)
text = re.sub(r"(?<=\n) +(?=\n)", "", text) text = re.sub(r"(?<=\n) +(?=\n)", "", text)
# Handle special characters that might cause audio artifacts first
# Replace newlines with spaces (or pauses if needed)
text = text.replace('\n', ' ')
text = text.replace('\r', ' ')
# Handle titles and abbreviations # Handle titles and abbreviations
text = re.sub(r"\bD[Rr]\.(?= [A-Z])", "Doctor", text) text = re.sub(r"\bD[Rr]\.(?= [A-Z])", "Doctor", text)
text = re.sub(r"\b(?:Mr\.|MR\.(?= [A-Z]))", "Mister", text) text = re.sub(r"\b(?:Mr\.|MR\.(?= [A-Z]))", "Mister", text)
@ -387,21 +468,23 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st
# Handle common words # Handle common words
text = re.sub(r"(?i)\b(y)eah?\b", r"\1e'a", text) text = re.sub(r"(?i)\b(y)eah?\b", r"\1e'a", text)
# Handle numbers and money # Handle numbers and money BEFORE replacing special characters
text = re.sub(r"(?<=\d),(?=\d)", "", text) text = re.sub(r"(?<=\d),(?=\d)", "", text)
text = re.sub( text = MONEY_PATTERN.sub(
r"(?i)(-?)([$£])(\d+(?:\.\d+)?)((?: hundred| thousand| (?:[bm]|tr|quadr)illion)*)\b",
handle_money, handle_money,
text, text,
) )
text = re.sub( text = NUMBER_PATTERN.sub(handle_numbers, text)
r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)", split_num, text
)
text = re.sub(r"\d*\.\d+", handle_decimal, text) text = re.sub(r"\d*\.\d+", handle_decimal, text)
# Handle other problematic symbols AFTER money/number processing
if normalization_options.replace_remaining_symbols:
for symbol, replacement in SYMBOL_REPLACEMENTS.items():
text = text.replace(symbol, replacement)
# Handle various formatting # Handle various formatting
text = re.sub(r"(?<=\d)-(?=\d)", " to ", text) text = re.sub(r"(?<=\d)-(?=\d)", " to ", text)
text = re.sub(r"(?<=\d)S", " S", text) text = re.sub(r"(?<=\d)S", " S", text)
@ -412,4 +495,6 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st
) )
text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text) text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)
text = re.sub(r"\s{2,}", " ", text)
return text.strip() return text.strip()

View file

@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
import phonemizer import phonemizer
from .normalizer import normalize_text from .normalizer import normalize_text
from ...structures.schemas import NormalizationOptions
phonemizers = {} phonemizers = {}
@ -75,7 +76,7 @@ def create_phonemizer(language: str = "a") -> PhonemizerBackend:
Phonemizer backend instance Phonemizer backend instance
""" """
# Map language codes to espeak language codes # Map language codes to espeak language codes
lang_map = {"a": "en-us", "b": "en-gb"} lang_map = {"a": "en-us", "b": "en-gb", "z": "z"}
if language not in lang_map: if language not in lang_map:
raise ValueError(f"Unsupported language code: {language}") raise ValueError(f"Unsupported language code: {language}")
@ -83,20 +84,24 @@ def create_phonemizer(language: str = "a") -> PhonemizerBackend:
return EspeakBackend(lang_map[language]) return EspeakBackend(lang_map[language])
def phonemize(text: str, language: str = "a", normalize: bool = True) -> str: def phonemize(text: str, language: str = "a") -> str:
"""Convert text to phonemes """Convert text to phonemes
Args: Args:
text: Text to convert to phonemes text: Text to convert to phonemes
language: Language code ('a' for US English, 'b' for British English) language: Language code ('a' for US English, 'b' for British English)
normalize: Whether to normalize text before phonemization
Returns: Returns:
Phonemized text Phonemized text
""" """
global phonemizers global phonemizers
if normalize:
text = normalize_text(text) # Strip input text first to remove problematic leading/trailing spaces
text = text.strip()
if language not in phonemizers: if language not in phonemizers:
phonemizers[language] = create_phonemizer(language) phonemizers[language] = create_phonemizer(language)
return phonemizers[language].phonemize(text)
result = phonemizers[language].phonemize(text)
# Final strip to ensure no leading/trailing spaces in phonemes
return result.strip()

View file

@ -2,7 +2,7 @@
import re import re
import time import time
from typing import AsyncGenerator, Dict, List, Tuple from typing import AsyncGenerator, Dict, List, Tuple, Optional
from loguru import logger from loguru import logger
@ -13,7 +13,11 @@ from .phonemizer import phonemize
from .vocabulary import tokenize from .vocabulary import tokenize
# Pre-compiled regex patterns for performance # Pre-compiled regex patterns for performance
CUSTOM_PHONEMES = re.compile(r"(\[([^\]]|\n)*?\])(\(\/([^\/)]|\n)*?\/\))") # Updated regex to be more strict and avoid matching isolated brackets
# Only matches complete patterns like [word](/ipa/) and prevents catastrophic backtracking
CUSTOM_PHONEMES = re.compile(r"(\[[^\[\]]*?\])(\(\/[^\/\(\)]*?\/\))")
# Pattern to find pause tags like [pause:0.5s]
PAUSE_TAG_PATTERN = re.compile(r"\[pause:(\d+(?:\.\d+)?)s\]", re.IGNORECASE)
def process_text_chunk( def process_text_chunk(
@ -30,6 +34,12 @@ def process_text_chunk(
List of token IDs List of token IDs
""" """
start_time = time.time() start_time = time.time()
# Strip input text to remove any leading/trailing spaces that could cause artifacts
text = text.strip()
if not text:
return []
if skip_phonemize: if skip_phonemize:
# Input is already phonemes, just tokenize # Input is already phonemes, just tokenize
@ -42,7 +52,9 @@ def process_text_chunk(
t1 = time.time() t1 = time.time()
t0 = time.time() t0 = time.time()
phonemes = phonemize(text, language, normalize=False) # Already normalized phonemes = phonemize(text, language)
# Strip phonemes result to ensure no extra spaces
phonemes = phonemes.strip()
t1 = time.time() t1 = time.time()
t0 = time.time() t0 = time.time()
@ -88,10 +100,16 @@ def process_text(text: str, language: str = "a") -> List[int]:
def get_sentence_info( def get_sentence_info(
text: str, custom_phenomes_list: Dict[str, str] text: str, custom_phenomes_list: Dict[str, str], lang_code: str = "a"
) -> List[Tuple[str, List[int], int]]: ) -> List[Tuple[str, List[int], int]]:
"""Process all sentences and return info.""" """Process all sentences and return info"""
sentences = re.split(r"([.!?;:])(?=\s|$)", text) # Detect Chinese text
is_chinese = lang_code.startswith("z") or re.search(r"[\u4e00-\u9fff]", text)
if is_chinese:
# Split using Chinese punctuation
sentences = re.split(r"([,。!?;])+", text)
else:
sentences = re.split(r"([.!?;:])(?=\s|$)", text)
phoneme_length, min_value = len(custom_phenomes_list), 0 phoneme_length, min_value = len(custom_phenomes_list), 0
results = [] results = []
@ -104,16 +122,16 @@ def get_sentence_info(
current_id, custom_phenomes_list.pop(current_id) current_id, custom_phenomes_list.pop(current_id)
) )
min_value += 1 min_value += 1
punct = sentences[i + 1] if i + 1 < len(sentences) else "" punct = sentences[i + 1] if i + 1 < len(sentences) else ""
if not sentence: if not sentence:
continue continue
full = sentence + punct full = sentence + punct
# Strip the full sentence to remove any leading/trailing spaces before processing
full = full.strip()
if not full: # Skip if empty after stripping
continue
tokens = process_text_chunk(full) tokens = process_text_chunk(full)
results.append((full, tokens, len(tokens))) results.append((full, tokens, len(tokens)))
return results return results
@ -128,149 +146,189 @@ async def smart_split(
max_tokens: int = settings.absolute_max_tokens, max_tokens: int = settings.absolute_max_tokens,
lang_code: str = "a", lang_code: str = "a",
normalization_options: NormalizationOptions = NormalizationOptions(), normalization_options: NormalizationOptions = NormalizationOptions(),
) -> AsyncGenerator[Tuple[str, List[int]], None]: ) -> AsyncGenerator[Tuple[str, List[int], Optional[float]], None]:
"""Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens.""" """Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens.
Yields:
Tuple of (text_chunk, tokens, pause_duration_s).
If pause_duration_s is not None, it's a pause chunk with empty text/tokens.
Otherwise, it's a text chunk containing the original text.
"""
start_time = time.time() start_time = time.time()
chunk_count = 0 chunk_count = 0
logger.info(f"Starting smart split for {len(text)} chars") logger.info(f"Starting smart split for {len(text)} chars")
custom_phoneme_list = {} # --- Step 1: Split by Pause Tags FIRST ---
# This operates on the raw input text
parts = PAUSE_TAG_PATTERN.split(text)
logger.debug(f"Split raw text into {len(parts)} parts by pause tags.")
# Normalize text part_idx = 0
if settings.advanced_text_normalization and normalization_options.normalize: while part_idx < len(parts):
print(lang_code) text_part_raw = parts[part_idx] # This part is raw text
if lang_code in ["a", "b", "en-us", "en-gb"]: part_idx += 1
text = CUSTOM_PHONEMES.sub(
lambda s: handle_custom_phonemes(s, custom_phoneme_list), text
)
text = normalize_text(text, normalization_options)
else:
logger.info(
"Skipping text normalization as it is only supported for english"
)
# Process all sentences # --- Process Text Part ---
sentences = get_sentence_info(text, custom_phoneme_list) if text_part_raw and text_part_raw.strip(): # Only process if the part is not empty string
# Strip leading and trailing spaces to prevent pause tag splitting artifacts
text_part_raw = text_part_raw.strip()
current_chunk = [] # Apply the original smart_split logic to this text part
current_tokens = [] custom_phoneme_list = {}
current_count = 0
for sentence, tokens, count in sentences: # Normalize text (original logic)
# Handle sentences that exceed max tokens processed_text = text_part_raw
if count > max_tokens: if settings.advanced_text_normalization and normalization_options.normalize:
# Yield current chunk if any if lang_code in ["a", "b", "en-us", "en-gb"]:
if current_chunk: processed_text = CUSTOM_PHONEMES.sub(
chunk_text = " ".join(current_chunk) lambda s: handle_custom_phonemes(s, custom_phoneme_list), processed_text
chunk_count += 1 )
logger.debug( processed_text = normalize_text(processed_text, normalization_options)
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
)
yield chunk_text, current_tokens
current_chunk = []
current_tokens = []
current_count = 0
# Split long sentence on commas
clauses = re.split(r"([,])", sentence)
clause_chunk = []
clause_tokens = []
clause_count = 0
for j in range(0, len(clauses), 2):
clause = clauses[j].strip()
comma = clauses[j + 1] if j + 1 < len(clauses) else ""
if not clause:
continue
full_clause = clause + comma
tokens = process_text_chunk(full_clause)
count = len(tokens)
# If adding clause keeps us under max and not optimal yet
if (
clause_count + count <= max_tokens
and clause_count + count <= settings.target_max_tokens
):
clause_chunk.append(full_clause)
clause_tokens.extend(tokens)
clause_count += count
else: else:
# Yield clause chunk if we have one logger.info(
if clause_chunk: "Skipping text normalization as it is only supported for english"
chunk_text = " ".join(clause_chunk) )
# Process all sentences (original logic)
sentences = get_sentence_info(processed_text, custom_phoneme_list, lang_code=lang_code)
current_chunk = []
current_tokens = []
current_count = 0
for sentence, tokens, count in sentences:
# Handle sentences that exceed max tokens (original logic)
if count > max_tokens:
# Yield current chunk if any
if current_chunk:
chunk_text = " ".join(current_chunk).strip()
chunk_count += 1 chunk_count += 1
logger.debug( logger.debug(
f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)" f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)"
) )
yield chunk_text, clause_tokens yield chunk_text, current_tokens, None
clause_chunk = [full_clause] current_chunk = []
clause_tokens = tokens current_tokens = []
clause_count = count current_count = 0
# Don't forget last clause chunk # Split long sentence on commas (original logic)
if clause_chunk: clauses = re.split(r"([,])", sentence)
chunk_text = " ".join(clause_chunk) clause_chunk = []
chunk_count += 1 clause_tokens = []
logger.debug( clause_count = 0
f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)"
)
yield chunk_text, clause_tokens
# Regular sentence handling for j in range(0, len(clauses), 2):
elif ( clause = clauses[j].strip()
current_count >= settings.target_min_tokens comma = clauses[j + 1] if j + 1 < len(clauses) else ""
and current_count + count > settings.target_max_tokens
): if not clause:
# If we have a good sized chunk and adding next sentence exceeds target, continue
# yield current chunk and start new one
chunk_text = " ".join(current_chunk) full_clause = clause + comma
chunk_count += 1
logger.info( tokens = process_text_chunk(full_clause)
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)" count = len(tokens)
)
yield chunk_text, current_tokens # If adding clause keeps us under max and not optimal yet
current_chunk = [sentence] if (
current_tokens = tokens clause_count + count <= max_tokens
current_count = count and clause_count + count <= settings.target_max_tokens
elif current_count + count <= settings.target_max_tokens: ):
# Keep building chunk while under target max clause_chunk.append(full_clause)
current_chunk.append(sentence) clause_tokens.extend(tokens)
current_tokens.extend(tokens) clause_count += count
current_count += count else:
elif ( # Yield clause chunk if we have one
current_count + count <= max_tokens if clause_chunk:
and current_count < settings.target_min_tokens chunk_text = " ".join(clause_chunk).strip()
): chunk_count += 1
# Only exceed target max if we haven't reached minimum size yet logger.debug(
current_chunk.append(sentence) f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({clause_count} tokens)"
current_tokens.extend(tokens) )
current_count += count yield chunk_text, clause_tokens, None
else: clause_chunk = [full_clause]
# Yield current chunk and start new one clause_tokens = tokens
clause_count = count
# Don't forget last clause chunk
if clause_chunk:
chunk_text = " ".join(clause_chunk).strip()
chunk_count += 1
logger.debug(
f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({clause_count} tokens)"
)
yield chunk_text, clause_tokens, None
# Regular sentence handling (original logic)
elif (
current_count >= settings.target_min_tokens
and current_count + count > settings.target_max_tokens
):
# If we have a good sized chunk and adding next sentence exceeds target,
# yield current chunk and start new one
chunk_text = " ".join(current_chunk).strip()
chunk_count += 1
logger.info(
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)"
)
yield chunk_text, current_tokens, None
current_chunk = [sentence]
current_tokens = tokens
current_count = count
elif current_count + count <= settings.target_max_tokens:
# Keep building chunk while under target max
current_chunk.append(sentence)
current_tokens.extend(tokens)
current_count += count
elif (
current_count + count <= max_tokens
and current_count < settings.target_min_tokens
):
# Only exceed target max if we haven't reached minimum size yet
current_chunk.append(sentence)
current_tokens.extend(tokens)
current_count += count
else:
# Yield current chunk and start new one
if current_chunk:
chunk_text = " ".join(current_chunk).strip()
chunk_count += 1
logger.info(
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)"
)
yield chunk_text, current_tokens, None
current_chunk = [sentence]
current_tokens = tokens
current_count = count
# Don't forget the last chunk for this text part
if current_chunk: if current_chunk:
chunk_text = " ".join(current_chunk) chunk_text = " ".join(current_chunk).strip()
chunk_count += 1 chunk_count += 1
logger.info( logger.info(
f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)" f"Yielding final chunk {chunk_count} for part: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)"
) )
yield chunk_text, current_tokens yield chunk_text, current_tokens, None
current_chunk = [sentence]
current_tokens = tokens
current_count = count
# Don't forget the last chunk # --- Handle Pause Part ---
if current_chunk: # Check if the next part is a pause duration string
chunk_text = " ".join(current_chunk) if part_idx < len(parts):
chunk_count += 1 duration_str = parts[part_idx]
logger.info( # Check if it looks like a valid number string captured by the regex group
f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)" if re.fullmatch(r"\d+(?:\.\d+)?", duration_str):
) part_idx += 1 # Consume the duration string as it's been processed
yield chunk_text, current_tokens try:
duration = float(duration_str)
if duration > 0:
chunk_count += 1
logger.info(f"Yielding pause chunk {chunk_count}: {duration}s")
yield "", [], duration # Yield pause chunk
except (ValueError, TypeError):
# This case should be rare if re.fullmatch passed, but handle anyway
logger.warning(f"Could not parse valid-looking pause duration: {duration_str}")
# --- End of parts loop ---
total_time = time.time() - start_time total_time = time.time() - start_time
logger.info( logger.info(
f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks" f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks (including pauses)"
) )

View file

@ -23,6 +23,8 @@ def tokenize(phonemes: str) -> list[int]:
Returns: Returns:
List of token IDs List of token IDs
""" """
# Strip phonemes to remove leading/trailing spaces that could cause artifacts
phonemes = phonemes.strip()
return [i for i in map(VOCAB.get, phonemes) if i is not None] return [i for i in map(VOCAB.get, phonemes) if i is not None]

View file

@ -55,6 +55,7 @@ class TTSService:
output_format: Optional[str] = None, output_format: Optional[str] = None,
is_first: bool = False, is_first: bool = False,
is_last: bool = False, is_last: bool = False,
volume_multiplier: Optional[float] = 1.0,
normalizer: Optional[AudioNormalizer] = None, normalizer: Optional[AudioNormalizer] = None,
lang_code: Optional[str] = None, lang_code: Optional[str] = None,
return_timestamps: Optional[bool] = False, return_timestamps: Optional[bool] = False,
@ -100,6 +101,7 @@ class TTSService:
lang_code=lang_code, lang_code=lang_code,
return_timestamps=return_timestamps, return_timestamps=return_timestamps,
): ):
chunk_data.audio*=volume_multiplier
# For streaming, convert to bytes # For streaming, convert to bytes
if output_format: if output_format:
try: try:
@ -132,7 +134,7 @@ class TTSService:
speed=speed, speed=speed,
return_timestamps=return_timestamps, return_timestamps=return_timestamps,
) )
if chunk_data.audio is None: if chunk_data.audio is None:
logger.error("Model generated None for audio chunk") logger.error("Model generated None for audio chunk")
return return
@ -141,6 +143,8 @@ class TTSService:
logger.error("Model generated empty audio chunk") logger.error("Model generated empty audio chunk")
return return
chunk_data.audio*=volume_multiplier
# For streaming, convert to bytes # For streaming, convert to bytes
if output_format: if output_format:
try: try:
@ -259,6 +263,7 @@ class TTSService:
speed: float = 1.0, speed: float = 1.0,
output_format: str = "wav", output_format: str = "wav",
lang_code: Optional[str] = None, lang_code: Optional[str] = None,
volume_multiplier: Optional[float] = 1.0,
normalization_options: Optional[NormalizationOptions] = NormalizationOptions(), normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
return_timestamps: Optional[bool] = False, return_timestamps: Optional[bool] = False,
) -> AsyncGenerator[AudioChunk, None]: ) -> AsyncGenerator[AudioChunk, None]:
@ -280,48 +285,90 @@ class TTSService:
f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream" f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream"
) )
# Process text in chunks with smart splitting # Process text in chunks with smart splitting, handling pause tags
async for chunk_text, tokens in smart_split( async for chunk_text, tokens, pause_duration_s in smart_split(
text, text,
lang_code=pipeline_lang_code, lang_code=pipeline_lang_code,
normalization_options=normalization_options, normalization_options=normalization_options,
): ):
try: if pause_duration_s is not None and pause_duration_s > 0:
# Process audio for chunk # --- Handle Pause Chunk ---
async for chunk_data in self._process_chunk( try:
chunk_text, # Pass text for Kokoro V1 logger.debug(f"Generating {pause_duration_s}s silence chunk")
tokens, # Pass tokens for legacy backends silence_samples = int(pause_duration_s * 24000) # 24kHz sample rate
voice_name, # Pass voice name # Create proper silence as int16 zeros to avoid normalization artifacts
voice_path, # Pass voice path silence_audio = np.zeros(silence_samples, dtype=np.int16)
speed, pause_chunk = AudioChunk(audio=silence_audio, word_timestamps=[]) # Empty timestamps for silence
writer,
output_format,
is_first=(chunk_index == 0),
is_last=False, # We'll update the last chunk later
normalizer=stream_normalizer,
lang_code=pipeline_lang_code, # Pass lang_code
return_timestamps=return_timestamps,
):
if chunk_data.word_timestamps is not None:
for timestamp in chunk_data.word_timestamps:
timestamp.start_time += current_offset
timestamp.end_time += current_offset
current_offset += len(chunk_data.audio) / 24000 # Format and yield the silence chunk
if output_format:
formatted_pause_chunk = await AudioService.convert_audio(
pause_chunk, output_format, writer, speed=speed, chunk_text="",
is_last_chunk=False, trim_audio=False, normalizer=stream_normalizer,
if chunk_data.output is not None:
yield chunk_data
else:
logger.warning(
f"No audio generated for chunk: '{chunk_text[:100]}...'"
) )
chunk_index += 1 if formatted_pause_chunk.output:
except Exception as e: yield formatted_pause_chunk
logger.error( else: # Raw audio mode
f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}" # For raw audio mode, silence is already in the correct format (int16)
) # Skip normalization to avoid any potential artifacts
continue if len(pause_chunk.audio) > 0:
yield pause_chunk
# Update offset based on silence duration
current_offset += pause_duration_s
chunk_index += 1 # Count pause as a yielded chunk
except Exception as e:
logger.error(f"Failed to process pause chunk: {str(e)}")
continue
elif tokens or chunk_text.strip(): # Process if there are tokens OR non-whitespace text
# --- Handle Text Chunk ---
try:
# Process audio for chunk
async for chunk_data in self._process_chunk(
chunk_text, # Pass text for Kokoro V1
tokens, # Pass tokens for legacy backends
voice_name, # Pass voice name
voice_path, # Pass voice path
speed,
writer,
output_format,
is_first=(chunk_index == 0),
volume_multiplier=volume_multiplier,
is_last=False, # We'll update the last chunk later
normalizer=stream_normalizer,
lang_code=pipeline_lang_code, # Pass lang_code
return_timestamps=return_timestamps,
):
if chunk_data.word_timestamps is not None:
for timestamp in chunk_data.word_timestamps:
timestamp.start_time += current_offset
timestamp.end_time += current_offset
# Update offset based on the actual duration of the generated audio chunk
chunk_duration = 0
if chunk_data.audio is not None and len(chunk_data.audio) > 0:
chunk_duration = len(chunk_data.audio) / 24000
current_offset += chunk_duration
# Yield the processed chunk (either formatted or raw)
if chunk_data.output is not None:
yield chunk_data
elif chunk_data.audio is not None and len(chunk_data.audio) > 0:
yield chunk_data
else:
logger.warning(
f"No audio generated for chunk: '{chunk_text[:100]}...'"
)
chunk_index += 1 # Increment chunk index after processing text
except Exception as e:
logger.error(
f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}"
)
continue
# Only finalize if we successfully processed at least one chunk # Only finalize if we successfully processed at least one chunk
if chunk_index > 0: if chunk_index > 0:
@ -337,6 +384,7 @@ class TTSService:
output_format, output_format,
is_first=False, is_first=False,
is_last=True, # Signal this is the last chunk is_last=True, # Signal this is the last chunk
volume_multiplier=volume_multiplier,
normalizer=stream_normalizer, normalizer=stream_normalizer,
lang_code=pipeline_lang_code, # Pass lang_code lang_code=pipeline_lang_code, # Pass lang_code
): ):
@ -356,6 +404,7 @@ class TTSService:
writer: StreamingAudioWriter, writer: StreamingAudioWriter,
speed: float = 1.0, speed: float = 1.0,
return_timestamps: bool = False, return_timestamps: bool = False,
volume_multiplier: Optional[float] = 1.0,
normalization_options: Optional[NormalizationOptions] = NormalizationOptions(), normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
lang_code: Optional[str] = None, lang_code: Optional[str] = None,
) -> AudioChunk: ) -> AudioChunk:
@ -368,6 +417,7 @@ class TTSService:
voice, voice,
writer, writer,
speed=speed, speed=speed,
volume_multiplier=volume_multiplier,
normalization_options=normalization_options, normalization_options=normalization_options,
return_timestamps=return_timestamps, return_timestamps=return_timestamps,
lang_code=lang_code, lang_code=lang_code,

View file

@ -1,3 +1,4 @@
from email.policy import default
from enum import Enum from enum import Enum
from typing import List, Literal, Optional, Union from typing import List, Literal, Optional, Union
@ -66,6 +67,10 @@ class NormalizationOptions(BaseModel):
default=True, default=True,
description="Changes phone numbers so they can be properly pronouced by kokoro", description="Changes phone numbers so they can be properly pronouced by kokoro",
) )
replace_remaining_symbols: bool = Field(
default=True,
description="Replaces the remaining symbols after normalization with their words"
)
class OpenAISpeechRequest(BaseModel): class OpenAISpeechRequest(BaseModel):
@ -108,6 +113,10 @@ class OpenAISpeechRequest(BaseModel):
default=None, default=None,
description="Optional language code to use for text processing. If not provided, will use first letter of voice name.", description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
) )
volume_multiplier: Optional[float] = Field(
default = 1.0,
description="A volume multiplier to multiply the output audio by."
)
normalization_options: Optional[NormalizationOptions] = Field( normalization_options: Optional[NormalizationOptions] = Field(
default=NormalizationOptions(), default=NormalizationOptions(),
description="Options for the normalization system", description="Options for the normalization system",
@ -152,6 +161,10 @@ class CaptionedSpeechRequest(BaseModel):
default=None, default=None,
description="Optional language code to use for text processing. If not provided, will use first letter of voice name.", description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
) )
volume_multiplier: Optional[float] = Field(
default = 1.0,
description="A volume multiplier to multiply the output audio by."
)
normalization_options: Optional[NormalizationOptions] = Field( normalization_options: Optional[NormalizationOptions] = Field(
default=NormalizationOptions(), default=NormalizationOptions(),
description="Options for the normalization system", description="Options for the normalization system",

View file

@ -57,19 +57,19 @@ def test_url_localhost():
normalize_text( normalize_text(
"Running on localhost:7860", normalization_options=NormalizationOptions() "Running on localhost:7860", normalization_options=NormalizationOptions()
) )
== "Running on localhost colon 78 60" == "Running on localhost colon seventy-eight sixty"
) )
assert ( assert (
normalize_text( normalize_text(
"Server at localhost:8080/api", normalization_options=NormalizationOptions() "Server at localhost:8080/api", normalization_options=NormalizationOptions()
) )
== "Server at localhost colon 80 80 slash api" == "Server at localhost colon eighty eighty slash api"
) )
assert ( assert (
normalize_text( normalize_text(
"Test localhost:3000/test?v=1", normalization_options=NormalizationOptions() "Test localhost:3000/test?v=1", normalization_options=NormalizationOptions()
) )
== "Test localhost colon 3000 slash test question-mark v equals 1" == "Test localhost colon three thousand slash test question-mark v equals one"
) )
@ -79,17 +79,17 @@ def test_url_ip_addresses():
normalize_text( normalize_text(
"Access 0.0.0.0:9090/test", normalization_options=NormalizationOptions() "Access 0.0.0.0:9090/test", normalization_options=NormalizationOptions()
) )
== "Access 0 dot 0 dot 0 dot 0 colon 90 90 slash test" == "Access zero dot zero dot zero dot zero colon ninety ninety slash test"
) )
assert ( assert (
normalize_text( normalize_text(
"API at 192.168.1.1:8000", normalization_options=NormalizationOptions() "API at 192.168.1.1:8000", normalization_options=NormalizationOptions()
) )
== "API at 192 dot 168 dot 1 dot 1 colon 8000" == "API at one hundred and ninety-two dot one hundred and sixty-eight dot one dot one colon eight thousand"
) )
assert ( assert (
normalize_text("Server 127.0.0.1", normalization_options=NormalizationOptions()) normalize_text("Server 127.0.0.1", normalization_options=NormalizationOptions())
== "Server 127 dot 0 dot 0 dot 1" == "Server one hundred and twenty-seven dot zero dot zero dot one"
) )
@ -146,6 +146,15 @@ def test_money():
) )
== "He lost five point three thousand dollars." == "He lost five point three thousand dollars."
) )
assert (
normalize_text(
"He went gambling and lost about $25.05k.",
normalization_options=NormalizationOptions(),
)
== "He went gambling and lost about twenty-five point zero five thousand dollars."
)
assert ( assert (
normalize_text( normalize_text(
"To put it weirdly -$6.9 million", "To put it weirdly -$6.9 million",
@ -153,11 +162,147 @@ def test_money():
) )
== "To put it weirdly minus six point nine million dollars" == "To put it weirdly minus six point nine million dollars"
) )
assert ( assert (
normalize_text("It costs $50.3.", normalization_options=NormalizationOptions()) normalize_text("It costs $50.3.", normalization_options=NormalizationOptions())
== "It costs fifty dollars and thirty cents." == "It costs fifty dollars and thirty cents."
) )
assert (
normalize_text(
"The plant cost $200,000.8.", normalization_options=NormalizationOptions()
)
== "The plant cost two hundred thousand dollars and eighty cents."
)
assert (
normalize_text(
"Your shopping spree cost $674.03!", normalization_options=NormalizationOptions()
)
== "Your shopping spree cost six hundred and seventy-four dollars and three cents!"
)
assert (
normalize_text(
"€30.2 is in euros", normalization_options=NormalizationOptions()
)
== "thirty euros and twenty cents is in euros"
)
def test_time():
"""Test time normalization"""
assert (
normalize_text(
"Your flight leaves at 10:35 pm",
normalization_options=NormalizationOptions(),
)
== "Your flight leaves at ten thirty-five pm"
)
assert (
normalize_text(
"He departed for london around 5:03 am.",
normalization_options=NormalizationOptions(),
)
== "He departed for london around five oh three am."
)
assert (
normalize_text(
"Only the 13:42 and 15:12 slots are available.",
normalization_options=NormalizationOptions(),
)
== "Only the thirteen forty-two and fifteen twelve slots are available."
)
assert (
normalize_text(
"It is currently 1:00 pm", normalization_options=NormalizationOptions()
)
== "It is currently one pm"
)
assert (
normalize_text(
"It is currently 3:00", normalization_options=NormalizationOptions()
)
== "It is currently three o'clock"
)
assert (
normalize_text(
"12:00 am is midnight", normalization_options=NormalizationOptions()
)
== "twelve am is midnight"
)
def test_number():
"""Test number normalization"""
assert (
normalize_text(
"I bought 1035 cans of soda", normalization_options=NormalizationOptions()
)
== "I bought one thousand and thirty-five cans of soda"
)
assert (
normalize_text(
"The bus has a maximum capacity of 62 people",
normalization_options=NormalizationOptions(),
)
== "The bus has a maximum capacity of sixty-two people"
)
assert (
normalize_text(
"There are 1300 products left in stock",
normalization_options=NormalizationOptions(),
)
== "There are one thousand, three hundred products left in stock"
)
assert (
normalize_text(
"The population is 7,890,000 people.",
normalization_options=NormalizationOptions(),
)
== "The population is seven million, eight hundred and ninety thousand people."
)
assert (
normalize_text(
"He looked around but only found 1.6k of the 10k bricks",
normalization_options=NormalizationOptions(),
)
== "He looked around but only found one point six thousand of the ten thousand bricks"
)
assert (
normalize_text(
"The book has 342 pages.", normalization_options=NormalizationOptions()
)
== "The book has three hundred and forty-two pages."
)
assert (
normalize_text(
"He made -50 sales today.", normalization_options=NormalizationOptions()
)
== "He made minus fifty sales today."
)
assert (
normalize_text(
"56.789 to the power of 1.35 million",
normalization_options=NormalizationOptions(),
)
== "fifty-six point seven eight nine to the power of one point three five million"
)
def test_non_url_text(): def test_non_url_text():
"""Test that non-URL text is unaffected""" """Test that non-URL text is unaffected"""
@ -177,3 +322,12 @@ def test_non_url_text():
normalize_text("It costs $50.", normalization_options=NormalizationOptions()) normalize_text("It costs $50.", normalization_options=NormalizationOptions())
== "It costs fifty dollars." == "It costs fifty dollars."
) )
def test_remaining_symbol():
"""Test that remaining symbols are replaced"""
assert (
normalize_text(
"I love buying products @ good store here & @ other store", normalization_options=NormalizationOptions()
)
== "I love buying products at good store here and at other store"
)

View file

@ -67,7 +67,7 @@ async def test_smart_split_short_text():
"""Test smart splitting with text under max tokens.""" """Test smart splitting with text under max tokens."""
text = "This is a short test sentence." text = "This is a short test sentence."
chunks = [] chunks = []
async for chunk_text, chunk_tokens in smart_split(text): async for chunk_text, chunk_tokens, _ in smart_split(text):
chunks.append((chunk_text, chunk_tokens)) chunks.append((chunk_text, chunk_tokens))
assert len(chunks) == 1 assert len(chunks) == 1
@ -82,7 +82,7 @@ async def test_smart_split_long_text():
text = ". ".join(["This is test sentence number " + str(i) for i in range(20)]) text = ". ".join(["This is test sentence number " + str(i) for i in range(20)])
chunks = [] chunks = []
async for chunk_text, chunk_tokens in smart_split(text): async for chunk_text, chunk_tokens, _ in smart_split(text):
chunks.append((chunk_text, chunk_tokens)) chunks.append((chunk_text, chunk_tokens))
assert len(chunks) > 1 assert len(chunks) > 1
@ -98,8 +98,127 @@ async def test_smart_split_with_punctuation():
text = "First sentence! Second sentence? Third sentence; Fourth sentence: Fifth sentence." text = "First sentence! Second sentence? Third sentence; Fourth sentence: Fifth sentence."
chunks = [] chunks = []
async for chunk_text, chunk_tokens in smart_split(text): async for chunk_text, chunk_tokens, _ in smart_split(text):
chunks.append(chunk_text) chunks.append(chunk_text)
# Verify punctuation is preserved # Verify punctuation is preserved
assert all(any(p in chunk for p in "!?;:.") for chunk in chunks) assert all(any(p in chunk for p in "!?;:.") for chunk in chunks)
def test_process_text_chunk_chinese_phonemes():
"""Test processing with Chinese pinyin phonemes."""
pinyin = "nǐ hǎo lì" # Example pinyin sequence with tones
tokens = process_text_chunk(pinyin, skip_phonemize=True, language="z")
assert isinstance(tokens, list)
assert len(tokens) > 0
def test_get_sentence_info_chinese():
"""Test Chinese sentence splitting and info extraction."""
text = "这是一个句子。这是第二个句子!第三个问题?"
results = get_sentence_info(text, {}, lang_code="z")
assert len(results) == 3
for sentence, tokens, count in results:
assert isinstance(sentence, str)
assert isinstance(tokens, list)
assert isinstance(count, int)
assert count == len(tokens)
assert count > 0
@pytest.mark.asyncio
async def test_smart_split_chinese_short():
"""Test Chinese smart splitting with short text."""
text = "这是一句话。"
chunks = []
async for chunk_text, chunk_tokens, _ in smart_split(text, lang_code="z"):
chunks.append((chunk_text, chunk_tokens))
assert len(chunks) == 1
assert isinstance(chunks[0][0], str)
assert isinstance(chunks[0][1], list)
@pytest.mark.asyncio
async def test_smart_split_chinese_long():
"""Test Chinese smart splitting with longer text."""
text = "".join([f"测试句子 {i}" for i in range(20)])
chunks = []
async for chunk_text, chunk_tokens, _ in smart_split(text, lang_code="z"):
chunks.append((chunk_text, chunk_tokens))
assert len(chunks) > 1
for chunk_text, chunk_tokens in chunks:
assert isinstance(chunk_text, str)
assert isinstance(chunk_tokens, list)
assert len(chunk_tokens) > 0
@pytest.mark.asyncio
async def test_smart_split_chinese_punctuation():
"""Test Chinese smart splitting with punctuation preservation."""
text = "第一句!第二问?第三句;第四句:第五句。"
chunks = []
async for chunk_text, _, _ in smart_split(text, lang_code="z"):
chunks.append(chunk_text)
# Verify Chinese punctuation is preserved
assert all(any(p in chunk for p in "!?;:。") for chunk in chunks)
@pytest.mark.asyncio
async def test_smart_split_with_pause():
"""Test smart splitting with pause tags."""
text = "Hello world [pause:2.5s] How are you?"
chunks = []
async for chunk_text, chunk_tokens, pause_duration in smart_split(text):
chunks.append((chunk_text, chunk_tokens, pause_duration))
# Should have 3 chunks: text, pause, text
assert len(chunks) == 3
# First chunk: text
assert chunks[0][2] is None # No pause
assert "Hello world" in chunks[0][0]
assert len(chunks[0][1]) > 0
# Second chunk: pause
assert chunks[1][2] == 2.5 # 2.5 second pause
assert chunks[1][0] == "" # Empty text
assert len(chunks[1][1]) == 0 # No tokens
# Third chunk: text
assert chunks[2][2] is None # No pause
assert "How are you?" in chunks[2][0]
assert len(chunks[2][1]) > 0
@pytest.mark.asyncio
async def test_smart_split_with_two_pause():
"""Test smart splitting with two pause tags."""
text = "[pause:0.5s][pause:1.67s]0.5"
chunks = []
async for chunk_text, chunk_tokens, pause_duration in smart_split(text):
chunks.append((chunk_text, chunk_tokens, pause_duration))
# Should have 3 chunks: pause, pause, text
assert len(chunks) == 3
# First chunk: pause
assert chunks[0][2] == 0.5 # 0.5 second pause
assert chunks[0][0] == "" # Empty text
assert len(chunks[0][1]) == 0
# Second chunk: pause
assert chunks[1][2] == 1.67 # 1.67 second pause
assert chunks[1][0] == "" # Empty text
assert len(chunks[1][1]) == 0 # No tokens
# Third chunk: text
assert chunks[2][2] is None # No pause
assert "zero point five" in chunks[2][0]
assert len(chunks[2][1]) > 0

View file

@ -3,6 +3,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np import numpy as np
import pytest import pytest
import torch import torch
import os
from api.src.services.tts_service import TTSService from api.src.services.tts_service import TTSService
@ -102,7 +103,10 @@ async def test_get_voice_path_combined():
service = await TTSService.create("test_output") service = await TTSService.create("test_output")
name, path = await service._get_voices_path("voice1+voice2") name, path = await service._get_voices_path("voice1+voice2")
assert name == "voice1+voice2" assert name == "voice1+voice2"
assert path.endswith("voice1+voice2.pt") # Verify the path points to a temporary file with expected format
assert path.startswith("/tmp/")
assert "voice1+voice2" in path
assert path.endswith(".pt")
mock_save.assert_called_once() mock_save.assert_called_once()

23
dev/Test Phon.py Normal file
View file

@ -0,0 +1,23 @@
import base64
import json
import pydub
import requests
def generate_audio_from_phonemes(phonemes: str, voice: str = "af_bella"):
"""Generate audio from phonemes"""
response = requests.post(
"http://localhost:8880/dev/generate_from_phonemes",
json={"phonemes": phonemes, "voice": voice},
headers={"Accept": "audio/wav"}
)
if response.status_code != 200:
print(f"Error: {response.text}")
return None
return response.content
with open(f"outputnostreammoney.wav", "wb") as f:
f.write(generate_audio_from_phonemes(r"mɪsəki ɪz ɐn ɪkspˌɛɹəmˈɛntᵊl ʤˈitəpˈi ˈɛnʤən dəzˈInd tə pˈWəɹ fjˈuʧəɹ vˈɜɹʒənz ʌv kəkˈɔɹO mˈɑdᵊlz."))

38
dev/Test copy 2.py Normal file
View file

@ -0,0 +1,38 @@
import base64
import json
import pydub
import requests
text = """Running on localhost:7860"""
Type = "wav"
response = requests.post(
"http://localhost:8880/dev/captioned_speech",
json={
"model": "kokoro",
"input": text,
"voice": "af_heart+af_sky",
"speed": 1.0,
"response_format": Type,
"stream": True,
},
stream=True,
)
f = open(f"outputstream.{Type}", "wb")
for chunk in response.iter_lines(decode_unicode=True):
if chunk:
temp_json = json.loads(chunk)
if temp_json["timestamps"] != []:
chunk_json = temp_json
# Decode base 64 stream to bytes
chunk_audio = base64.b64decode(temp_json["audio"].encode("utf-8"))
# Process streaming chunks
f.write(chunk_audio)
# Print word level timestamps
print(chunk_json["timestamps"])

View file

@ -3,9 +3,7 @@ import json
import requests import requests
text = """the administration has offered up a platter of repression for more than a year and is still slated to lose $400 million. text = """奶酪芝士很浓郁!臭豆腐芝士有争议?陈年奶酪价格昂贵。"""
Columbia is the largest private landowner in New York City and boasts an endowment of $14.8 billion;"""
Type = "wav" Type = "wav"
@ -15,7 +13,7 @@ response = requests.post(
json={ json={
"model": "kokoro", "model": "kokoro",
"input": text, "input": text,
"voice": "af_heart+af_sky", "voice": "zf_xiaobei",
"speed": 1.0, "speed": 1.0,
"response_format": Type, "response_format": Type,
"stream": False, "stream": False,

View file

@ -30,6 +30,10 @@ WORKDIR /app
# Copy dependency files # Copy dependency files
COPY --chown=appuser:appuser pyproject.toml ./pyproject.toml COPY --chown=appuser:appuser pyproject.toml ./pyproject.toml
# Install Rust (required to build sudachipy and pyopenjtalk-plus)
RUN curl https://sh.rustup.rs -sSf | sh -s -- -y
ENV PATH="/home/appuser/.cargo/bin:$PATH"
# Install dependencies # Install dependencies
RUN --mount=type=cache,target=/root/.cache/uv \ RUN --mount=type=cache,target=/root/.cache/uv \
uv venv --python 3.10 && \ uv venv --python 3.10 && \