diff --git a/README.md b/README.md index d9b5a61..5e9da1d 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ [![Tested at Model Commit](https://img.shields.io/badge/last--tested--model--commit-1.0::9901c2b-blue)](https://huggingface.co/hexgrad/Kokoro-82M/commit/9901c2b79161b6e898b7ea857ae5298f47b8b0d6) Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model -- Multi-language support (English, Japanese, Korean, Chinese, _Vietnamese soon_) +- Multi-language support (English, Japanese, Chinese, _Vietnamese soon_) - OpenAI-compatible Speech endpoint, NVIDIA GPU accelerated or CPU inference with PyTorch - ONNX support coming soon, see v0.1.5 and earlier for legacy ONNX support in the interim - Debug endpoints for monitoring system stats, integrated web UI on localhost:8880/web @@ -34,10 +34,12 @@ Pre built images are available to run, with arm/multi-arch support, and baked in Refer to the core/config.py file for a full list of variables which can be managed via the environment ```bash -# the `latest` tag can be used, though it may have some unexpected bonus features which impact stability. Named versions should be pinned for your regular usage. Feedback/testing is always welcome +# the `latest` tag can be used, though it may have some unexpected bonus features which impact stability. + Named versions should be pinned for your regular usage. + Feedback/testing is always welcome -docker run -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-cpu:v0.3.0 # CPU, or: -docker run --gpus all -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-gpu:v0.3.0 #NVIDIA GPU +docker run -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-cpu:latest # CPU, or: +docker run --gpus all -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-gpu:latest #NVIDIA GPU ``` diff --git a/VERSION b/VERSION index 0d91a54..72f9fa8 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.3.0 +0.2.4 \ No newline at end of file diff --git a/api/src/core/config.py b/api/src/core/config.py index 1d4657c..87edce0 100644 --- a/api/src/core/config.py +++ b/api/src/core/config.py @@ -31,6 +31,7 @@ class Settings(BaseSettings): # Audio Settings sample_rate: int = 24000 + default_volume_multiplier: float = 1.0 # Text Processing Settings target_min_tokens: int = 175 # Target minimum tokens per chunk target_max_tokens: int = 250 # Target maximum tokens per chunk diff --git a/api/src/core/paths.py b/api/src/core/paths.py index 0e60528..771b70c 100644 --- a/api/src/core/paths.py +++ b/api/src/core/paths.py @@ -300,7 +300,7 @@ async def get_web_file_path(filename: str) -> str: ) # Construct web directory path relative to project root - web_dir = os.path.join("/app", settings.web_player_path) + web_dir = os.path.join(root_dir, settings.web_player_path) # Search in web directory search_paths = [web_dir] diff --git a/api/src/inference/model_manager.py b/api/src/inference/model_manager.py index 9cef95f..eb817ec 100644 --- a/api/src/inference/model_manager.py +++ b/api/src/inference/model_manager.py @@ -141,6 +141,8 @@ Model files not found! You need to download the Kokoro V1 model: try: async for chunk in self._backend.generate(*args, **kwargs): + if settings.default_volume_multiplier != 1.0: + chunk.audio *= settings.default_volume_multiplier yield chunk except Exception as e: raise RuntimeError(f"Generation failed: {e}") diff --git a/api/src/routers/development.py b/api/src/routers/development.py index e749119..8c8ed7e 100644 --- a/api/src/routers/development.py +++ b/api/src/routers/development.py @@ -104,7 +104,7 @@ async def generate_from_phonemes( if chunk_audio is not None: # Normalize audio before writing - normalized_audio = await normalizer.normalize(chunk_audio) + normalized_audio = normalizer.normalize(chunk_audio) # Write chunk and yield bytes chunk_bytes = writer.write_chunk(normalized_audio) if chunk_bytes: @@ -114,6 +114,7 @@ async def generate_from_phonemes( final_bytes = writer.write_chunk(finalize=True) if final_bytes: yield final_bytes + writer.close() else: raise ValueError("Failed to generate audio data") @@ -223,10 +224,13 @@ async def create_captioned_speech( ).decode("utf-8") # Add any chunks that may be in the acumulator into the return word_timestamps - chunk_data.word_timestamps = ( - timestamp_acumulator + chunk_data.word_timestamps - ) - timestamp_acumulator = [] + if chunk_data.word_timestamps is not None: + chunk_data.word_timestamps = ( + timestamp_acumulator + chunk_data.word_timestamps + ) + timestamp_acumulator = [] + else: + chunk_data.word_timestamps = [] yield CaptionedSpeechResponse( audio=base64_chunk, @@ -271,7 +275,7 @@ async def create_captioned_speech( ) # Add any chunks that may be in the acumulator into the return word_timestamps - if chunk_data.word_timestamps != None: + if chunk_data.word_timestamps is not None: chunk_data.word_timestamps = ( timestamp_acumulator + chunk_data.word_timestamps ) @@ -315,6 +319,7 @@ async def create_captioned_speech( writer=writer, speed=request.speed, return_timestamps=request.return_timestamps, + volume_multiplier=request.volume_multiplier, normalization_options=request.normalization_options, lang_code=request.lang_code, ) diff --git a/api/src/routers/openai_compatible.py b/api/src/routers/openai_compatible.py index 4819bc5..c325221 100644 --- a/api/src/routers/openai_compatible.py +++ b/api/src/routers/openai_compatible.py @@ -152,6 +152,7 @@ async def stream_audio_chunks( speed=request.speed, output_format=request.response_format, lang_code=request.lang_code, + volume_multiplier=request.volume_multiplier, normalization_options=request.normalization_options, return_timestamps=unique_properties["return_timestamps"], ): @@ -300,6 +301,7 @@ async def create_speech( voice=voice_name, writer=writer, speed=request.speed, + volume_multiplier=request.volume_multiplier, normalization_options=request.normalization_options, lang_code=request.lang_code, ) diff --git a/api/src/services/audio.py b/api/src/services/audio.py index 5d1d3ff..6ae6d79 100644 --- a/api/src/services/audio.py +++ b/api/src/services/audio.py @@ -80,12 +80,12 @@ class AudioNormalizer: non_silent_index_start, non_silent_index_end = None, None for X in range(0, len(audio_data)): - if audio_data[X] > amplitude_threshold: + if abs(audio_data[X]) > amplitude_threshold: non_silent_index_start = X break for X in range(len(audio_data) - 1, -1, -1): - if audio_data[X] > amplitude_threshold: + if abs(audio_data[X]) > amplitude_threshold: non_silent_index_end = X break diff --git a/api/src/services/streaming_audio_writer.py b/api/src/services/streaming_audio_writer.py index e6ec2d6..de9c84e 100644 --- a/api/src/services/streaming_audio_writer.py +++ b/api/src/services/streaming_audio_writer.py @@ -32,19 +32,29 @@ class StreamingAudioWriter: if self.format in ["wav", "flac", "mp3", "pcm", "aac", "opus"]: if self.format != "pcm": self.output_buffer = BytesIO() + container_options = {} + # Try disabling Xing VBR header for MP3 to fix iOS timeline reading issues + if self.format == 'mp3': + # Disable Xing VBR header + container_options = {'write_xing': '0'} + logger.debug("Disabling Xing VBR header for MP3 encoding.") + self.container = av.open( self.output_buffer, mode="w", format=self.format if self.format != "aac" else "adts", + options=container_options # Pass options here ) self.stream = self.container.add_stream( codec_map[self.format], - sample_rate=self.sample_rate, + rate=self.sample_rate, layout="mono" if self.channels == 1 else "stereo", ) - self.stream.bit_rate = 128000 + # Set bit_rate only for codecs where it's applicable and useful + if self.format in ['mp3', 'aac', 'opus']: + self.stream.bit_rate = 128000 else: - raise ValueError(f"Unsupported format: {format}") + raise ValueError(f"Unsupported format: {self.format}") # Use self.format here def close(self): if hasattr(self, "container"): @@ -65,12 +75,18 @@ class StreamingAudioWriter: if finalize: if self.format != "pcm": + # Flush stream encoder packets = self.stream.encode(None) for packet in packets: self.container.mux(packet) + # Closing the container handles writing the trailer and finalizing the file. + # No explicit flush method is available or needed here. + logger.debug("Muxed final packets.") + + # Get the final bytes from the buffer *before* closing it data = self.output_buffer.getvalue() - self.close() + self.close() # Close container and buffer return data if audio_data is None or len(audio_data) == 0: diff --git a/api/src/services/text_processing/normalizer.py b/api/src/services/text_processing/normalizer.py index e9f73c0..7908318 100644 --- a/api/src/services/text_processing/normalizer.py +++ b/api/src/services/text_processing/normalizer.py @@ -4,12 +4,14 @@ Handles various text formats including URLs, emails, numbers, money, and special Converts them into a format suitable for text-to-speech processing. """ +import math import re from functools import lru_cache +from typing import List, Optional, Union import inflect from numpy import number -from text_to_num import text2num +# from text_to_num import text2num from torch import mul from ...structures.schemas import NormalizationOptions @@ -132,6 +134,24 @@ VALID_UNITS = { "px": "pixel", # CSS units } +SYMBOL_REPLACEMENTS = { + '~': ' ', + '@': ' at ', + '#': ' number ', + '$': ' dollar ', + '%': ' percent ', + '^': ' ', + '&': ' and ', + '*': ' ', + '_': ' ', + '|': ' ', + '\\': ' ', + '/': ' slash ', + '=': ' equals ', + '+': ' plus ', +} + +MONEY_UNITS = {"$": ("dollar", "cent"), "£": ("pound", "pence"), "€": ("euro", "cent")} # Pre-compiled regex patterns for performance EMAIL_PATTERN = re.compile( @@ -152,37 +172,24 @@ UNIT_PATTERN = re.compile( ) TIME_PATTERN = re.compile( - r"([0-9]{2} ?: ?[0-9]{2}( ?: ?[0-9]{2})?)( ?(pm|am)\b)?", re.IGNORECASE + r"([0-9]{1,2} ?: ?[0-9]{2}( ?: ?[0-9]{2})?)( ?(pm|am)\b)?", re.IGNORECASE +) + +MONEY_PATTERN = re.compile( + r"(-?)([" + + "".join(MONEY_UNITS.keys()) + + r"])(\d+(?:\.\d+)?)((?: hundred| thousand| (?:[bm]|tr|quadr)illion|k|m|b|t)*)\b", + re.IGNORECASE, +) + +NUMBER_PATTERN = re.compile( + r"(-?)(\d+(?:\.\d+)?)((?: hundred| thousand| (?:[bm]|tr|quadr)illion|k|m|b)*)\b", + re.IGNORECASE, ) INFLECT_ENGINE = inflect.engine() -def split_num(num: re.Match[str]) -> str: - """Handle number splitting for various formats""" - num = num.group() - if "." in num: - return num - elif ":" in num: - h, m = [int(n) for n in num.split(":")] - if m == 0: - return f"{h} o'clock" - elif m < 10: - return f"{h} oh {m}" - return f"{h} {m}" - year = int(num[:4]) - if year < 1100 or year % 1000 < 10: - return num - left, right = num[:2], int(num[2:4]) - s = "s" if num.endswith("s") else "" - if 100 <= year % 1000 <= 999: - if right == 0: - return f"{left} hundred{s}" - elif right < 10: - return f"{left} oh {right}{s}" - return f"{left} {right}{s}" - - def handle_units(u: re.Match[str]) -> str: """Converts units to their full form""" unit_string = u.group(6).strip() @@ -208,14 +215,61 @@ def conditional_int(number: float, threshold: float = 0.00001): return number +def translate_multiplier(multiplier: str) -> str: + """Translate multiplier abrevations to words""" + + multiplier_translation = { + "k": "thousand", + "m": "million", + "b": "billion", + "t": "trillion", + } + if multiplier.lower() in multiplier_translation: + return multiplier_translation[multiplier.lower()] + return multiplier.strip() + + +def split_four_digit(number: float): + part1 = str(conditional_int(number))[:2] + part2 = str(conditional_int(number))[2:] + return f"{INFLECT_ENGINE.number_to_words(part1)} {INFLECT_ENGINE.number_to_words(part2)}" + + +def handle_numbers(n: re.Match[str]) -> str: + number = n.group(2) + + try: + number = float(number) + except: + return n.group() + + if n.group(1) == "-": + number *= -1 + + multiplier = translate_multiplier(n.group(3)) + + number = conditional_int(number) + if multiplier != "": + multiplier = f" {multiplier}" + else: + if ( + number % 1 == 0 + and len(str(number)) == 4 + and number > 1500 + and number % 1000 > 9 + ): + return split_four_digit(number) + + return f"{INFLECT_ENGINE.number_to_words(number)}{multiplier}" + + def handle_money(m: re.Match[str]) -> str: """Convert money expressions to spoken form""" - bill = "dollar" if m.group(2) == "$" else "pound" - coin = "cent" if m.group(2) == "$" else "pence" + bill, coin = MONEY_UNITS[m.group(2)] + number = m.group(3) - multiplier = m.group(4) try: number = float(number) except: @@ -224,12 +278,17 @@ def handle_money(m: re.Match[str]) -> str: if m.group(1) == "-": number *= -1 + multiplier = translate_multiplier(m.group(4)) + + if multiplier != "": + multiplier = f" {multiplier}" + if number % 1 == 0 or multiplier != "": text_number = f"{INFLECT_ENGINE.number_to_words(conditional_int(number))}{multiplier} {INFLECT_ENGINE.plural(bill, count=number)}" else: sub_number = int(str(number).split(".")[-1].ljust(2, "0")) - text_number = f"{INFLECT_ENGINE.number_to_words(int(round(number)))} {INFLECT_ENGINE.plural(bill, count=number)} and {INFLECT_ENGINE.number_to_words(sub_number)} {INFLECT_ENGINE.plural(coin, count=sub_number)}" + text_number = f"{INFLECT_ENGINE.number_to_words(int(math.floor(number)))} {INFLECT_ENGINE.plural(bill, count=number)} and {INFLECT_ENGINE.number_to_words(sub_number)} {INFLECT_ENGINE.plural(coin, count=sub_number)}" return text_number @@ -320,19 +379,36 @@ def handle_phone_number(p: re.Match[str]) -> str: def handle_time(t: re.Match[str]) -> str: t = t.groups() - numbers = " ".join( - [INFLECT_ENGINE.number_to_words(X.strip()) for X in t[0].split(":")] - ) + time_parts = t[0].split(":") + + numbers = [] + numbers.append(INFLECT_ENGINE.number_to_words(time_parts[0].strip())) + + minute_number = INFLECT_ENGINE.number_to_words(time_parts[1].strip()) + if int(time_parts[1]) < 10: + if int(time_parts[1]) != 0: + numbers.append(f"oh {minute_number}") + else: + numbers.append(minute_number) half = "" - if t[2] is not None: - half = t[2].strip() + if len(time_parts) > 2: + seconds_number = INFLECT_ENGINE.number_to_words(time_parts[2].strip()) + second_word = INFLECT_ENGINE.plural("second", int(time_parts[2].strip())) + numbers.append(f"and {seconds_number} {second_word}") + else: + if t[2] is not None: + half = " " + t[2].strip() + else: + if int(time_parts[1]) == 0: + numbers.append("o'clock") - return numbers + half + return " ".join(numbers) + half def normalize_text(text: str, normalization_options: NormalizationOptions) -> str: """Normalize text for TTS processing""" + # Handle email addresses first if enabled if normalization_options.email_normalization: text = EMAIL_PATTERN.sub(handle_email, text) @@ -357,7 +433,7 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st text, ) - # Replace quotes and brackets + # Replace quotes and brackets (additional cleanup) text = text.replace(chr(8216), "'").replace(chr(8217), "'") text = text.replace("«", chr(8220)).replace("»", chr(8221)) text = text.replace(chr(8220), '"').replace(chr(8221), '"') @@ -366,7 +442,7 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st for a, b in zip("、。!,:;?–", ",.!,:;?-"): text = text.replace(a, b + " ") - # Handle simple time in the format of HH:MM:SS + # Handle simple time in the format of HH:MM:SS (am/pm) text = TIME_PATTERN.sub( handle_time, text, @@ -377,6 +453,11 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st text = re.sub(r" +", " ", text) text = re.sub(r"(?<=\n) +(?=\n)", "", text) + # Handle special characters that might cause audio artifacts first + # Replace newlines with spaces (or pauses if needed) + text = text.replace('\n', ' ') + text = text.replace('\r', ' ') + # Handle titles and abbreviations text = re.sub(r"\bD[Rr]\.(?= [A-Z])", "Doctor", text) text = re.sub(r"\b(?:Mr\.|MR\.(?= [A-Z]))", "Mister", text) @@ -387,21 +468,23 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st # Handle common words text = re.sub(r"(?i)\b(y)eah?\b", r"\1e'a", text) - # Handle numbers and money + # Handle numbers and money BEFORE replacing special characters text = re.sub(r"(?<=\d),(?=\d)", "", text) - text = re.sub( - r"(?i)(-?)([$£])(\d+(?:\.\d+)?)((?: hundred| thousand| (?:[bm]|tr|quadr)illion)*)\b", + text = MONEY_PATTERN.sub( handle_money, text, ) - text = re.sub( - r"\d*\.\d+|\b\d{4}s?\b|(? st ) text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text) + text = re.sub(r"\s{2,}", " ", text) + return text.strip() diff --git a/api/src/services/text_processing/phonemizer.py b/api/src/services/text_processing/phonemizer.py index 5a50d64..dabf328 100644 --- a/api/src/services/text_processing/phonemizer.py +++ b/api/src/services/text_processing/phonemizer.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod import phonemizer from .normalizer import normalize_text +from ...structures.schemas import NormalizationOptions phonemizers = {} @@ -75,7 +76,7 @@ def create_phonemizer(language: str = "a") -> PhonemizerBackend: Phonemizer backend instance """ # Map language codes to espeak language codes - lang_map = {"a": "en-us", "b": "en-gb"} + lang_map = {"a": "en-us", "b": "en-gb", "z": "z"} if language not in lang_map: raise ValueError(f"Unsupported language code: {language}") @@ -83,20 +84,24 @@ def create_phonemizer(language: str = "a") -> PhonemizerBackend: return EspeakBackend(lang_map[language]) -def phonemize(text: str, language: str = "a", normalize: bool = True) -> str: +def phonemize(text: str, language: str = "a") -> str: """Convert text to phonemes Args: text: Text to convert to phonemes language: Language code ('a' for US English, 'b' for British English) - normalize: Whether to normalize text before phonemization Returns: Phonemized text """ global phonemizers - if normalize: - text = normalize_text(text) + + # Strip input text first to remove problematic leading/trailing spaces + text = text.strip() + if language not in phonemizers: phonemizers[language] = create_phonemizer(language) - return phonemizers[language].phonemize(text) + + result = phonemizers[language].phonemize(text) + # Final strip to ensure no leading/trailing spaces in phonemes + return result.strip() diff --git a/api/src/services/text_processing/text_processor.py b/api/src/services/text_processing/text_processor.py index 584affe..39b0d6c 100644 --- a/api/src/services/text_processing/text_processor.py +++ b/api/src/services/text_processing/text_processor.py @@ -2,7 +2,7 @@ import re import time -from typing import AsyncGenerator, Dict, List, Tuple +from typing import AsyncGenerator, Dict, List, Tuple, Optional from loguru import logger @@ -13,7 +13,11 @@ from .phonemizer import phonemize from .vocabulary import tokenize # Pre-compiled regex patterns for performance -CUSTOM_PHONEMES = re.compile(r"(\[([^\]]|\n)*?\])(\(\/([^\/)]|\n)*?\/\))") +# Updated regex to be more strict and avoid matching isolated brackets +# Only matches complete patterns like [word](/ipa/) and prevents catastrophic backtracking +CUSTOM_PHONEMES = re.compile(r"(\[[^\[\]]*?\])(\(\/[^\/\(\)]*?\/\))") +# Pattern to find pause tags like [pause:0.5s] +PAUSE_TAG_PATTERN = re.compile(r"\[pause:(\d+(?:\.\d+)?)s\]", re.IGNORECASE) def process_text_chunk( @@ -30,6 +34,12 @@ def process_text_chunk( List of token IDs """ start_time = time.time() + + # Strip input text to remove any leading/trailing spaces that could cause artifacts + text = text.strip() + + if not text: + return [] if skip_phonemize: # Input is already phonemes, just tokenize @@ -42,7 +52,9 @@ def process_text_chunk( t1 = time.time() t0 = time.time() - phonemes = phonemize(text, language, normalize=False) # Already normalized + phonemes = phonemize(text, language) + # Strip phonemes result to ensure no extra spaces + phonemes = phonemes.strip() t1 = time.time() t0 = time.time() @@ -88,10 +100,16 @@ def process_text(text: str, language: str = "a") -> List[int]: def get_sentence_info( - text: str, custom_phenomes_list: Dict[str, str] + text: str, custom_phenomes_list: Dict[str, str], lang_code: str = "a" ) -> List[Tuple[str, List[int], int]]: - """Process all sentences and return info.""" - sentences = re.split(r"([.!?;:])(?=\s|$)", text) + """Process all sentences and return info""" + # Detect Chinese text + is_chinese = lang_code.startswith("z") or re.search(r"[\u4e00-\u9fff]", text) + if is_chinese: + # Split using Chinese punctuation + sentences = re.split(r"([,。!?;])+", text) + else: + sentences = re.split(r"([.!?;:])(?=\s|$)", text) phoneme_length, min_value = len(custom_phenomes_list), 0 results = [] @@ -104,16 +122,16 @@ def get_sentence_info( current_id, custom_phenomes_list.pop(current_id) ) min_value += 1 - punct = sentences[i + 1] if i + 1 < len(sentences) else "" - if not sentence: continue - full = sentence + punct + # Strip the full sentence to remove any leading/trailing spaces before processing + full = full.strip() + if not full: # Skip if empty after stripping + continue tokens = process_text_chunk(full) results.append((full, tokens, len(tokens))) - return results @@ -128,149 +146,189 @@ async def smart_split( max_tokens: int = settings.absolute_max_tokens, lang_code: str = "a", normalization_options: NormalizationOptions = NormalizationOptions(), -) -> AsyncGenerator[Tuple[str, List[int]], None]: - """Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens.""" +) -> AsyncGenerator[Tuple[str, List[int], Optional[float]], None]: + """Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens. + + Yields: + Tuple of (text_chunk, tokens, pause_duration_s). + If pause_duration_s is not None, it's a pause chunk with empty text/tokens. + Otherwise, it's a text chunk containing the original text. + """ start_time = time.time() chunk_count = 0 logger.info(f"Starting smart split for {len(text)} chars") - custom_phoneme_list = {} + # --- Step 1: Split by Pause Tags FIRST --- + # This operates on the raw input text + parts = PAUSE_TAG_PATTERN.split(text) + logger.debug(f"Split raw text into {len(parts)} parts by pause tags.") - # Normalize text - if settings.advanced_text_normalization and normalization_options.normalize: - print(lang_code) - if lang_code in ["a", "b", "en-us", "en-gb"]: - text = CUSTOM_PHONEMES.sub( - lambda s: handle_custom_phonemes(s, custom_phoneme_list), text - ) - text = normalize_text(text, normalization_options) - else: - logger.info( - "Skipping text normalization as it is only supported for english" - ) + part_idx = 0 + while part_idx < len(parts): + text_part_raw = parts[part_idx] # This part is raw text + part_idx += 1 - # Process all sentences - sentences = get_sentence_info(text, custom_phoneme_list) + # --- Process Text Part --- + if text_part_raw and text_part_raw.strip(): # Only process if the part is not empty string + # Strip leading and trailing spaces to prevent pause tag splitting artifacts + text_part_raw = text_part_raw.strip() - current_chunk = [] - current_tokens = [] - current_count = 0 + # Apply the original smart_split logic to this text part + custom_phoneme_list = {} - for sentence, tokens, count in sentences: - # Handle sentences that exceed max tokens - if count > max_tokens: - # Yield current chunk if any - if current_chunk: - chunk_text = " ".join(current_chunk) - chunk_count += 1 - logger.debug( - f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)" - ) - yield chunk_text, current_tokens - current_chunk = [] - current_tokens = [] - current_count = 0 - - # Split long sentence on commas - clauses = re.split(r"([,])", sentence) - clause_chunk = [] - clause_tokens = [] - clause_count = 0 - - for j in range(0, len(clauses), 2): - clause = clauses[j].strip() - comma = clauses[j + 1] if j + 1 < len(clauses) else "" - - if not clause: - continue - - full_clause = clause + comma - - tokens = process_text_chunk(full_clause) - count = len(tokens) - - # If adding clause keeps us under max and not optimal yet - if ( - clause_count + count <= max_tokens - and clause_count + count <= settings.target_max_tokens - ): - clause_chunk.append(full_clause) - clause_tokens.extend(tokens) - clause_count += count + # Normalize text (original logic) + processed_text = text_part_raw + if settings.advanced_text_normalization and normalization_options.normalize: + if lang_code in ["a", "b", "en-us", "en-gb"]: + processed_text = CUSTOM_PHONEMES.sub( + lambda s: handle_custom_phonemes(s, custom_phoneme_list), processed_text + ) + processed_text = normalize_text(processed_text, normalization_options) else: - # Yield clause chunk if we have one - if clause_chunk: - chunk_text = " ".join(clause_chunk) + logger.info( + "Skipping text normalization as it is only supported for english" + ) + + # Process all sentences (original logic) + sentences = get_sentence_info(processed_text, custom_phoneme_list, lang_code=lang_code) + + current_chunk = [] + current_tokens = [] + current_count = 0 + + for sentence, tokens, count in sentences: + # Handle sentences that exceed max tokens (original logic) + if count > max_tokens: + # Yield current chunk if any + if current_chunk: + chunk_text = " ".join(current_chunk).strip() chunk_count += 1 logger.debug( - f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)" + f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)" ) - yield chunk_text, clause_tokens - clause_chunk = [full_clause] - clause_tokens = tokens - clause_count = count + yield chunk_text, current_tokens, None + current_chunk = [] + current_tokens = [] + current_count = 0 - # Don't forget last clause chunk - if clause_chunk: - chunk_text = " ".join(clause_chunk) - chunk_count += 1 - logger.debug( - f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)" - ) - yield chunk_text, clause_tokens + # Split long sentence on commas (original logic) + clauses = re.split(r"([,])", sentence) + clause_chunk = [] + clause_tokens = [] + clause_count = 0 - # Regular sentence handling - elif ( - current_count >= settings.target_min_tokens - and current_count + count > settings.target_max_tokens - ): - # If we have a good sized chunk and adding next sentence exceeds target, - # yield current chunk and start new one - chunk_text = " ".join(current_chunk) - chunk_count += 1 - logger.info( - f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)" - ) - yield chunk_text, current_tokens - current_chunk = [sentence] - current_tokens = tokens - current_count = count - elif current_count + count <= settings.target_max_tokens: - # Keep building chunk while under target max - current_chunk.append(sentence) - current_tokens.extend(tokens) - current_count += count - elif ( - current_count + count <= max_tokens - and current_count < settings.target_min_tokens - ): - # Only exceed target max if we haven't reached minimum size yet - current_chunk.append(sentence) - current_tokens.extend(tokens) - current_count += count - else: - # Yield current chunk and start new one + for j in range(0, len(clauses), 2): + clause = clauses[j].strip() + comma = clauses[j + 1] if j + 1 < len(clauses) else "" + + if not clause: + continue + + full_clause = clause + comma + + tokens = process_text_chunk(full_clause) + count = len(tokens) + + # If adding clause keeps us under max and not optimal yet + if ( + clause_count + count <= max_tokens + and clause_count + count <= settings.target_max_tokens + ): + clause_chunk.append(full_clause) + clause_tokens.extend(tokens) + clause_count += count + else: + # Yield clause chunk if we have one + if clause_chunk: + chunk_text = " ".join(clause_chunk).strip() + chunk_count += 1 + logger.debug( + f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({clause_count} tokens)" + ) + yield chunk_text, clause_tokens, None + clause_chunk = [full_clause] + clause_tokens = tokens + clause_count = count + + # Don't forget last clause chunk + if clause_chunk: + chunk_text = " ".join(clause_chunk).strip() + chunk_count += 1 + logger.debug( + f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({clause_count} tokens)" + ) + yield chunk_text, clause_tokens, None + + # Regular sentence handling (original logic) + elif ( + current_count >= settings.target_min_tokens + and current_count + count > settings.target_max_tokens + ): + # If we have a good sized chunk and adding next sentence exceeds target, + # yield current chunk and start new one + chunk_text = " ".join(current_chunk).strip() + chunk_count += 1 + logger.info( + f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)" + ) + yield chunk_text, current_tokens, None + current_chunk = [sentence] + current_tokens = tokens + current_count = count + elif current_count + count <= settings.target_max_tokens: + # Keep building chunk while under target max + current_chunk.append(sentence) + current_tokens.extend(tokens) + current_count += count + elif ( + current_count + count <= max_tokens + and current_count < settings.target_min_tokens + ): + # Only exceed target max if we haven't reached minimum size yet + current_chunk.append(sentence) + current_tokens.extend(tokens) + current_count += count + else: + # Yield current chunk and start new one + if current_chunk: + chunk_text = " ".join(current_chunk).strip() + chunk_count += 1 + logger.info( + f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)" + ) + yield chunk_text, current_tokens, None + current_chunk = [sentence] + current_tokens = tokens + current_count = count + + # Don't forget the last chunk for this text part if current_chunk: - chunk_text = " ".join(current_chunk) + chunk_text = " ".join(current_chunk).strip() chunk_count += 1 logger.info( - f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)" + f"Yielding final chunk {chunk_count} for part: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)" ) - yield chunk_text, current_tokens - current_chunk = [sentence] - current_tokens = tokens - current_count = count + yield chunk_text, current_tokens, None - # Don't forget the last chunk - if current_chunk: - chunk_text = " ".join(current_chunk) - chunk_count += 1 - logger.info( - f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)" - ) - yield chunk_text, current_tokens + # --- Handle Pause Part --- + # Check if the next part is a pause duration string + if part_idx < len(parts): + duration_str = parts[part_idx] + # Check if it looks like a valid number string captured by the regex group + if re.fullmatch(r"\d+(?:\.\d+)?", duration_str): + part_idx += 1 # Consume the duration string as it's been processed + try: + duration = float(duration_str) + if duration > 0: + chunk_count += 1 + logger.info(f"Yielding pause chunk {chunk_count}: {duration}s") + yield "", [], duration # Yield pause chunk + except (ValueError, TypeError): + # This case should be rare if re.fullmatch passed, but handle anyway + logger.warning(f"Could not parse valid-looking pause duration: {duration_str}") + # --- End of parts loop --- total_time = time.time() - start_time logger.info( - f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks" + f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks (including pauses)" ) diff --git a/api/src/services/text_processing/vocabulary.py b/api/src/services/text_processing/vocabulary.py index 7a12892..d6d7863 100644 --- a/api/src/services/text_processing/vocabulary.py +++ b/api/src/services/text_processing/vocabulary.py @@ -23,6 +23,8 @@ def tokenize(phonemes: str) -> list[int]: Returns: List of token IDs """ + # Strip phonemes to remove leading/trailing spaces that could cause artifacts + phonemes = phonemes.strip() return [i for i in map(VOCAB.get, phonemes) if i is not None] diff --git a/api/src/services/tts_service.py b/api/src/services/tts_service.py index 0a69b85..46c2fb4 100644 --- a/api/src/services/tts_service.py +++ b/api/src/services/tts_service.py @@ -55,6 +55,7 @@ class TTSService: output_format: Optional[str] = None, is_first: bool = False, is_last: bool = False, + volume_multiplier: Optional[float] = 1.0, normalizer: Optional[AudioNormalizer] = None, lang_code: Optional[str] = None, return_timestamps: Optional[bool] = False, @@ -100,6 +101,7 @@ class TTSService: lang_code=lang_code, return_timestamps=return_timestamps, ): + chunk_data.audio*=volume_multiplier # For streaming, convert to bytes if output_format: try: @@ -132,7 +134,7 @@ class TTSService: speed=speed, return_timestamps=return_timestamps, ) - + if chunk_data.audio is None: logger.error("Model generated None for audio chunk") return @@ -141,6 +143,8 @@ class TTSService: logger.error("Model generated empty audio chunk") return + chunk_data.audio*=volume_multiplier + # For streaming, convert to bytes if output_format: try: @@ -259,6 +263,7 @@ class TTSService: speed: float = 1.0, output_format: str = "wav", lang_code: Optional[str] = None, + volume_multiplier: Optional[float] = 1.0, normalization_options: Optional[NormalizationOptions] = NormalizationOptions(), return_timestamps: Optional[bool] = False, ) -> AsyncGenerator[AudioChunk, None]: @@ -280,48 +285,90 @@ class TTSService: f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream" ) - # Process text in chunks with smart splitting - async for chunk_text, tokens in smart_split( + # Process text in chunks with smart splitting, handling pause tags + async for chunk_text, tokens, pause_duration_s in smart_split( text, lang_code=pipeline_lang_code, normalization_options=normalization_options, ): - try: - # Process audio for chunk - async for chunk_data in self._process_chunk( - chunk_text, # Pass text for Kokoro V1 - tokens, # Pass tokens for legacy backends - voice_name, # Pass voice name - voice_path, # Pass voice path - speed, - writer, - output_format, - is_first=(chunk_index == 0), - is_last=False, # We'll update the last chunk later - normalizer=stream_normalizer, - lang_code=pipeline_lang_code, # Pass lang_code - return_timestamps=return_timestamps, - ): - if chunk_data.word_timestamps is not None: - for timestamp in chunk_data.word_timestamps: - timestamp.start_time += current_offset - timestamp.end_time += current_offset + if pause_duration_s is not None and pause_duration_s > 0: + # --- Handle Pause Chunk --- + try: + logger.debug(f"Generating {pause_duration_s}s silence chunk") + silence_samples = int(pause_duration_s * 24000) # 24kHz sample rate + # Create proper silence as int16 zeros to avoid normalization artifacts + silence_audio = np.zeros(silence_samples, dtype=np.int16) + pause_chunk = AudioChunk(audio=silence_audio, word_timestamps=[]) # Empty timestamps for silence - current_offset += len(chunk_data.audio) / 24000 + # Format and yield the silence chunk + if output_format: + formatted_pause_chunk = await AudioService.convert_audio( + pause_chunk, output_format, writer, speed=speed, chunk_text="", + is_last_chunk=False, trim_audio=False, normalizer=stream_normalizer, - if chunk_data.output is not None: - yield chunk_data - - else: - logger.warning( - f"No audio generated for chunk: '{chunk_text[:100]}...'" ) - chunk_index += 1 - except Exception as e: - logger.error( - f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}" - ) - continue + if formatted_pause_chunk.output: + yield formatted_pause_chunk + else: # Raw audio mode + # For raw audio mode, silence is already in the correct format (int16) + # Skip normalization to avoid any potential artifacts + if len(pause_chunk.audio) > 0: + yield pause_chunk + + # Update offset based on silence duration + current_offset += pause_duration_s + chunk_index += 1 # Count pause as a yielded chunk + + except Exception as e: + logger.error(f"Failed to process pause chunk: {str(e)}") + continue + + elif tokens or chunk_text.strip(): # Process if there are tokens OR non-whitespace text + # --- Handle Text Chunk --- + try: + # Process audio for chunk + async for chunk_data in self._process_chunk( + chunk_text, # Pass text for Kokoro V1 + tokens, # Pass tokens for legacy backends + voice_name, # Pass voice name + voice_path, # Pass voice path + speed, + writer, + output_format, + is_first=(chunk_index == 0), + volume_multiplier=volume_multiplier, + is_last=False, # We'll update the last chunk later + normalizer=stream_normalizer, + lang_code=pipeline_lang_code, # Pass lang_code + return_timestamps=return_timestamps, + ): + if chunk_data.word_timestamps is not None: + for timestamp in chunk_data.word_timestamps: + timestamp.start_time += current_offset + timestamp.end_time += current_offset + + # Update offset based on the actual duration of the generated audio chunk + chunk_duration = 0 + if chunk_data.audio is not None and len(chunk_data.audio) > 0: + chunk_duration = len(chunk_data.audio) / 24000 + current_offset += chunk_duration + + # Yield the processed chunk (either formatted or raw) + if chunk_data.output is not None: + yield chunk_data + elif chunk_data.audio is not None and len(chunk_data.audio) > 0: + yield chunk_data + else: + logger.warning( + f"No audio generated for chunk: '{chunk_text[:100]}...'" + ) + + chunk_index += 1 # Increment chunk index after processing text + except Exception as e: + logger.error( + f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}" + ) + continue # Only finalize if we successfully processed at least one chunk if chunk_index > 0: @@ -337,6 +384,7 @@ class TTSService: output_format, is_first=False, is_last=True, # Signal this is the last chunk + volume_multiplier=volume_multiplier, normalizer=stream_normalizer, lang_code=pipeline_lang_code, # Pass lang_code ): @@ -356,6 +404,7 @@ class TTSService: writer: StreamingAudioWriter, speed: float = 1.0, return_timestamps: bool = False, + volume_multiplier: Optional[float] = 1.0, normalization_options: Optional[NormalizationOptions] = NormalizationOptions(), lang_code: Optional[str] = None, ) -> AudioChunk: @@ -368,6 +417,7 @@ class TTSService: voice, writer, speed=speed, + volume_multiplier=volume_multiplier, normalization_options=normalization_options, return_timestamps=return_timestamps, lang_code=lang_code, diff --git a/api/src/structures/schemas.py b/api/src/structures/schemas.py index 260224c..0aeab7b 100644 --- a/api/src/structures/schemas.py +++ b/api/src/structures/schemas.py @@ -1,3 +1,4 @@ +from email.policy import default from enum import Enum from typing import List, Literal, Optional, Union @@ -66,6 +67,10 @@ class NormalizationOptions(BaseModel): default=True, description="Changes phone numbers so they can be properly pronouced by kokoro", ) + replace_remaining_symbols: bool = Field( + default=True, + description="Replaces the remaining symbols after normalization with their words" + ) class OpenAISpeechRequest(BaseModel): @@ -108,6 +113,10 @@ class OpenAISpeechRequest(BaseModel): default=None, description="Optional language code to use for text processing. If not provided, will use first letter of voice name.", ) + volume_multiplier: Optional[float] = Field( + default = 1.0, + description="A volume multiplier to multiply the output audio by." + ) normalization_options: Optional[NormalizationOptions] = Field( default=NormalizationOptions(), description="Options for the normalization system", @@ -152,6 +161,10 @@ class CaptionedSpeechRequest(BaseModel): default=None, description="Optional language code to use for text processing. If not provided, will use first letter of voice name.", ) + volume_multiplier: Optional[float] = Field( + default = 1.0, + description="A volume multiplier to multiply the output audio by." + ) normalization_options: Optional[NormalizationOptions] = Field( default=NormalizationOptions(), description="Options for the normalization system", diff --git a/api/tests/test_normalizer.py b/api/tests/test_normalizer.py index 0b48e94..6b5a8bf 100644 --- a/api/tests/test_normalizer.py +++ b/api/tests/test_normalizer.py @@ -57,19 +57,19 @@ def test_url_localhost(): normalize_text( "Running on localhost:7860", normalization_options=NormalizationOptions() ) - == "Running on localhost colon 78 60" + == "Running on localhost colon seventy-eight sixty" ) assert ( normalize_text( "Server at localhost:8080/api", normalization_options=NormalizationOptions() ) - == "Server at localhost colon 80 80 slash api" + == "Server at localhost colon eighty eighty slash api" ) assert ( normalize_text( "Test localhost:3000/test?v=1", normalization_options=NormalizationOptions() ) - == "Test localhost colon 3000 slash test question-mark v equals 1" + == "Test localhost colon three thousand slash test question-mark v equals one" ) @@ -79,17 +79,17 @@ def test_url_ip_addresses(): normalize_text( "Access 0.0.0.0:9090/test", normalization_options=NormalizationOptions() ) - == "Access 0 dot 0 dot 0 dot 0 colon 90 90 slash test" + == "Access zero dot zero dot zero dot zero colon ninety ninety slash test" ) assert ( normalize_text( "API at 192.168.1.1:8000", normalization_options=NormalizationOptions() ) - == "API at 192 dot 168 dot 1 dot 1 colon 8000" + == "API at one hundred and ninety-two dot one hundred and sixty-eight dot one dot one colon eight thousand" ) assert ( normalize_text("Server 127.0.0.1", normalization_options=NormalizationOptions()) - == "Server 127 dot 0 dot 0 dot 1" + == "Server one hundred and twenty-seven dot zero dot zero dot one" ) @@ -146,6 +146,15 @@ def test_money(): ) == "He lost five point three thousand dollars." ) + + assert ( + normalize_text( + "He went gambling and lost about $25.05k.", + normalization_options=NormalizationOptions(), + ) + == "He went gambling and lost about twenty-five point zero five thousand dollars." + ) + assert ( normalize_text( "To put it weirdly -$6.9 million", @@ -153,11 +162,147 @@ def test_money(): ) == "To put it weirdly minus six point nine million dollars" ) + assert ( normalize_text("It costs $50.3.", normalization_options=NormalizationOptions()) == "It costs fifty dollars and thirty cents." ) + assert ( + normalize_text( + "The plant cost $200,000.8.", normalization_options=NormalizationOptions() + ) + == "The plant cost two hundred thousand dollars and eighty cents." + ) + + assert ( + normalize_text( + "Your shopping spree cost $674.03!", normalization_options=NormalizationOptions() + ) + == "Your shopping spree cost six hundred and seventy-four dollars and three cents!" + ) + + assert ( + normalize_text( + "€30.2 is in euros", normalization_options=NormalizationOptions() + ) + == "thirty euros and twenty cents is in euros" + ) + + +def test_time(): + """Test time normalization""" + + assert ( + normalize_text( + "Your flight leaves at 10:35 pm", + normalization_options=NormalizationOptions(), + ) + == "Your flight leaves at ten thirty-five pm" + ) + + assert ( + normalize_text( + "He departed for london around 5:03 am.", + normalization_options=NormalizationOptions(), + ) + == "He departed for london around five oh three am." + ) + + assert ( + normalize_text( + "Only the 13:42 and 15:12 slots are available.", + normalization_options=NormalizationOptions(), + ) + == "Only the thirteen forty-two and fifteen twelve slots are available." + ) + + assert ( + normalize_text( + "It is currently 1:00 pm", normalization_options=NormalizationOptions() + ) + == "It is currently one pm" + ) + + assert ( + normalize_text( + "It is currently 3:00", normalization_options=NormalizationOptions() + ) + == "It is currently three o'clock" + ) + + assert ( + normalize_text( + "12:00 am is midnight", normalization_options=NormalizationOptions() + ) + == "twelve am is midnight" + ) + + +def test_number(): + """Test number normalization""" + + assert ( + normalize_text( + "I bought 1035 cans of soda", normalization_options=NormalizationOptions() + ) + == "I bought one thousand and thirty-five cans of soda" + ) + + assert ( + normalize_text( + "The bus has a maximum capacity of 62 people", + normalization_options=NormalizationOptions(), + ) + == "The bus has a maximum capacity of sixty-two people" + ) + + assert ( + normalize_text( + "There are 1300 products left in stock", + normalization_options=NormalizationOptions(), + ) + == "There are one thousand, three hundred products left in stock" + ) + + assert ( + normalize_text( + "The population is 7,890,000 people.", + normalization_options=NormalizationOptions(), + ) + == "The population is seven million, eight hundred and ninety thousand people." + ) + + assert ( + normalize_text( + "He looked around but only found 1.6k of the 10k bricks", + normalization_options=NormalizationOptions(), + ) + == "He looked around but only found one point six thousand of the ten thousand bricks" + ) + + assert ( + normalize_text( + "The book has 342 pages.", normalization_options=NormalizationOptions() + ) + == "The book has three hundred and forty-two pages." + ) + + assert ( + normalize_text( + "He made -50 sales today.", normalization_options=NormalizationOptions() + ) + == "He made minus fifty sales today." + ) + + assert ( + normalize_text( + "56.789 to the power of 1.35 million", + normalization_options=NormalizationOptions(), + ) + == "fifty-six point seven eight nine to the power of one point three five million" + ) + def test_non_url_text(): """Test that non-URL text is unaffected""" @@ -177,3 +322,12 @@ def test_non_url_text(): normalize_text("It costs $50.", normalization_options=NormalizationOptions()) == "It costs fifty dollars." ) + +def test_remaining_symbol(): + """Test that remaining symbols are replaced""" + assert ( + normalize_text( + "I love buying products @ good store here & @ other store", normalization_options=NormalizationOptions() + ) + == "I love buying products at good store here and at other store" + ) diff --git a/api/tests/test_text_processor.py b/api/tests/test_text_processor.py index bfcbcfe..7495d1d 100644 --- a/api/tests/test_text_processor.py +++ b/api/tests/test_text_processor.py @@ -67,7 +67,7 @@ async def test_smart_split_short_text(): """Test smart splitting with text under max tokens.""" text = "This is a short test sentence." chunks = [] - async for chunk_text, chunk_tokens in smart_split(text): + async for chunk_text, chunk_tokens, _ in smart_split(text): chunks.append((chunk_text, chunk_tokens)) assert len(chunks) == 1 @@ -82,7 +82,7 @@ async def test_smart_split_long_text(): text = ". ".join(["This is test sentence number " + str(i) for i in range(20)]) chunks = [] - async for chunk_text, chunk_tokens in smart_split(text): + async for chunk_text, chunk_tokens, _ in smart_split(text): chunks.append((chunk_text, chunk_tokens)) assert len(chunks) > 1 @@ -98,8 +98,127 @@ async def test_smart_split_with_punctuation(): text = "First sentence! Second sentence? Third sentence; Fourth sentence: Fifth sentence." chunks = [] - async for chunk_text, chunk_tokens in smart_split(text): + async for chunk_text, chunk_tokens, _ in smart_split(text): chunks.append(chunk_text) # Verify punctuation is preserved assert all(any(p in chunk for p in "!?;:.") for chunk in chunks) + + +def test_process_text_chunk_chinese_phonemes(): + """Test processing with Chinese pinyin phonemes.""" + pinyin = "nǐ hǎo lì" # Example pinyin sequence with tones + tokens = process_text_chunk(pinyin, skip_phonemize=True, language="z") + assert isinstance(tokens, list) + assert len(tokens) > 0 + + +def test_get_sentence_info_chinese(): + """Test Chinese sentence splitting and info extraction.""" + text = "这是一个句子。这是第二个句子!第三个问题?" + results = get_sentence_info(text, {}, lang_code="z") + + assert len(results) == 3 + for sentence, tokens, count in results: + assert isinstance(sentence, str) + assert isinstance(tokens, list) + assert isinstance(count, int) + assert count == len(tokens) + assert count > 0 + + +@pytest.mark.asyncio +async def test_smart_split_chinese_short(): + """Test Chinese smart splitting with short text.""" + text = "这是一句话。" + chunks = [] + async for chunk_text, chunk_tokens, _ in smart_split(text, lang_code="z"): + chunks.append((chunk_text, chunk_tokens)) + + assert len(chunks) == 1 + assert isinstance(chunks[0][0], str) + assert isinstance(chunks[0][1], list) + + +@pytest.mark.asyncio +async def test_smart_split_chinese_long(): + """Test Chinese smart splitting with longer text.""" + text = "。".join([f"测试句子 {i}" for i in range(20)]) + + chunks = [] + async for chunk_text, chunk_tokens, _ in smart_split(text, lang_code="z"): + chunks.append((chunk_text, chunk_tokens)) + + assert len(chunks) > 1 + for chunk_text, chunk_tokens in chunks: + assert isinstance(chunk_text, str) + assert isinstance(chunk_tokens, list) + assert len(chunk_tokens) > 0 + + +@pytest.mark.asyncio +async def test_smart_split_chinese_punctuation(): + """Test Chinese smart splitting with punctuation preservation.""" + text = "第一句!第二问?第三句;第四句:第五句。" + + chunks = [] + async for chunk_text, _, _ in smart_split(text, lang_code="z"): + chunks.append(chunk_text) + + # Verify Chinese punctuation is preserved + assert all(any(p in chunk for p in "!?;:。") for chunk in chunks) + + +@pytest.mark.asyncio +async def test_smart_split_with_pause(): + """Test smart splitting with pause tags.""" + text = "Hello world [pause:2.5s] How are you?" + + chunks = [] + async for chunk_text, chunk_tokens, pause_duration in smart_split(text): + chunks.append((chunk_text, chunk_tokens, pause_duration)) + + # Should have 3 chunks: text, pause, text + assert len(chunks) == 3 + + # First chunk: text + assert chunks[0][2] is None # No pause + assert "Hello world" in chunks[0][0] + assert len(chunks[0][1]) > 0 + + # Second chunk: pause + assert chunks[1][2] == 2.5 # 2.5 second pause + assert chunks[1][0] == "" # Empty text + assert len(chunks[1][1]) == 0 # No tokens + + # Third chunk: text + assert chunks[2][2] is None # No pause + assert "How are you?" in chunks[2][0] + assert len(chunks[2][1]) > 0 + +@pytest.mark.asyncio +async def test_smart_split_with_two_pause(): + """Test smart splitting with two pause tags.""" + text = "[pause:0.5s][pause:1.67s]0.5" + + chunks = [] + async for chunk_text, chunk_tokens, pause_duration in smart_split(text): + chunks.append((chunk_text, chunk_tokens, pause_duration)) + + # Should have 3 chunks: pause, pause, text + assert len(chunks) == 3 + + # First chunk: pause + assert chunks[0][2] == 0.5 # 0.5 second pause + assert chunks[0][0] == "" # Empty text + assert len(chunks[0][1]) == 0 + + # Second chunk: pause + assert chunks[1][2] == 1.67 # 1.67 second pause + assert chunks[1][0] == "" # Empty text + assert len(chunks[1][1]) == 0 # No tokens + + # Third chunk: text + assert chunks[2][2] is None # No pause + assert "zero point five" in chunks[2][0] + assert len(chunks[2][1]) > 0 \ No newline at end of file diff --git a/api/tests/test_tts_service.py b/api/tests/test_tts_service.py index ae8447a..968c31f 100644 --- a/api/tests/test_tts_service.py +++ b/api/tests/test_tts_service.py @@ -3,6 +3,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import numpy as np import pytest import torch +import os from api.src.services.tts_service import TTSService @@ -102,7 +103,10 @@ async def test_get_voice_path_combined(): service = await TTSService.create("test_output") name, path = await service._get_voices_path("voice1+voice2") assert name == "voice1+voice2" - assert path.endswith("voice1+voice2.pt") + # Verify the path points to a temporary file with expected format + assert path.startswith("/tmp/") + assert "voice1+voice2" in path + assert path.endswith(".pt") mock_save.assert_called_once() diff --git a/dev/Test Phon.py b/dev/Test Phon.py new file mode 100644 index 0000000..d3ba783 --- /dev/null +++ b/dev/Test Phon.py @@ -0,0 +1,23 @@ +import base64 +import json + +import pydub +import requests + +def generate_audio_from_phonemes(phonemes: str, voice: str = "af_bella"): + """Generate audio from phonemes""" + response = requests.post( + "http://localhost:8880/dev/generate_from_phonemes", + json={"phonemes": phonemes, "voice": voice}, + headers={"Accept": "audio/wav"} + ) + if response.status_code != 200: + print(f"Error: {response.text}") + return None + return response.content + + + + +with open(f"outputnostreammoney.wav", "wb") as f: + f.write(generate_audio_from_phonemes(r"mɪsəki ɪz ɐn ɪkspˌɛɹəmˈɛntᵊl ʤˈitəpˈi ˈɛnʤən dəzˈInd tə pˈWəɹ fjˈuʧəɹ vˈɜɹʒənz ʌv kəkˈɔɹO mˈɑdᵊlz.")) \ No newline at end of file diff --git a/dev/Test copy 2.py b/dev/Test copy 2.py new file mode 100644 index 0000000..52634ec --- /dev/null +++ b/dev/Test copy 2.py @@ -0,0 +1,38 @@ +import base64 +import json + +import pydub +import requests + +text = """Running on localhost:7860""" + + +Type = "wav" +response = requests.post( + "http://localhost:8880/dev/captioned_speech", + json={ + "model": "kokoro", + "input": text, + "voice": "af_heart+af_sky", + "speed": 1.0, + "response_format": Type, + "stream": True, + }, + stream=True, +) + +f = open(f"outputstream.{Type}", "wb") +for chunk in response.iter_lines(decode_unicode=True): + if chunk: + temp_json = json.loads(chunk) + if temp_json["timestamps"] != []: + chunk_json = temp_json + + # Decode base 64 stream to bytes + chunk_audio = base64.b64decode(temp_json["audio"].encode("utf-8")) + + # Process streaming chunks + f.write(chunk_audio) + + # Print word level timestamps + print(chunk_json["timestamps"]) diff --git a/dev/Test money.py b/dev/Test money.py index 47e2a9c..57d1fa6 100644 --- a/dev/Test money.py +++ b/dev/Test money.py @@ -3,9 +3,7 @@ import json import requests -text = """the administration has offered up a platter of repression for more than a year and is still slated to lose $400 million. - -Columbia is the largest private landowner in New York City and boasts an endowment of $14.8 billion;""" +text = """奶酪芝士很浓郁!臭豆腐芝士有争议?陈年奶酪价格昂贵。""" Type = "wav" @@ -15,7 +13,7 @@ response = requests.post( json={ "model": "kokoro", "input": text, - "voice": "af_heart+af_sky", + "voice": "zf_xiaobei", "speed": 1.0, "response_format": Type, "stream": False, diff --git a/docker/cpu/Dockerfile b/docker/cpu/Dockerfile index d770a6c..f528307 100644 --- a/docker/cpu/Dockerfile +++ b/docker/cpu/Dockerfile @@ -30,6 +30,10 @@ WORKDIR /app # Copy dependency files COPY --chown=appuser:appuser pyproject.toml ./pyproject.toml +# Install Rust (required to build sudachipy and pyopenjtalk-plus) +RUN curl https://sh.rustup.rs -sSf | sh -s -- -y +ENV PATH="/home/appuser/.cargo/bin:$PATH" + # Install dependencies RUN --mount=type=cache,target=/root/.cache/uv \ uv venv --python 3.10 && \