From b0f46276ebe0168a3c02c0fca68fd9cc8f4ecbd9 Mon Sep 17 00:00:00 2001 From: Mike Bailey Date: Sun, 11 May 2025 01:00:41 +1000 Subject: [PATCH 01/14] Update paths.py Use root_dir instead of /app This was breaking things for me when I started the app from a script that was not in the root_dir --- api/src/core/paths.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/src/core/paths.py b/api/src/core/paths.py index 0e60528..771b70c 100644 --- a/api/src/core/paths.py +++ b/api/src/core/paths.py @@ -300,7 +300,7 @@ async def get_web_file_path(filename: str) -> str: ) # Construct web directory path relative to project root - web_dir = os.path.join("/app", settings.web_player_path) + web_dir = os.path.join(root_dir, settings.web_player_path) # Search in web directory search_paths = [web_dir] From 75963c4aebea9c570b3ee015521869f4ea10a3e0 Mon Sep 17 00:00:00 2001 From: JCallicoat Date: Thu, 22 May 2025 06:49:37 -0500 Subject: [PATCH 02/14] Add a volume multiplier setting Allow configuring output volume via multiplier applied to np array of audio chunk. Defaults to 1.0 which is no-op. Fixes #110 --- api/src/core/config.py | 1 + api/src/inference/model_manager.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/api/src/core/config.py b/api/src/core/config.py index 1d4657c..3bc825c 100644 --- a/api/src/core/config.py +++ b/api/src/core/config.py @@ -31,6 +31,7 @@ class Settings(BaseSettings): # Audio Settings sample_rate: int = 24000 + volume_multiplier: float = 1.0 # Text Processing Settings target_min_tokens: int = 175 # Target minimum tokens per chunk target_max_tokens: int = 250 # Target maximum tokens per chunk diff --git a/api/src/inference/model_manager.py b/api/src/inference/model_manager.py index 9cef95f..0b4cd81 100644 --- a/api/src/inference/model_manager.py +++ b/api/src/inference/model_manager.py @@ -141,6 +141,8 @@ Model files not found! You need to download the Kokoro V1 model: try: async for chunk in self._backend.generate(*args, **kwargs): + if settings.volume_multiplier != 1.0: + chunk.audio *= settings.volume_multiplier yield chunk except Exception as e: raise RuntimeError(f"Generation failed: {e}") From 9c279f2b5eac3d85e793576c959621f13eaacaa2 Mon Sep 17 00:00:00 2001 From: jiaohuix <1152937237@qq.com> Date: Mon, 26 May 2025 15:30:03 +0800 Subject: [PATCH 03/14] feat(text): add Chinese punctuation-based sentence splitting for better TTS --- .../services/text_processing/normalizer.py | 2 +- .../text_processing/text_processor.py | 64 ++++++++++++------- 2 files changed, 41 insertions(+), 25 deletions(-) diff --git a/api/src/services/text_processing/normalizer.py b/api/src/services/text_processing/normalizer.py index 280a26e..5e5d6b6 100644 --- a/api/src/services/text_processing/normalizer.py +++ b/api/src/services/text_processing/normalizer.py @@ -11,7 +11,7 @@ from typing import List, Optional, Union import inflect from numpy import number -from text_to_num import text2num +# from text_to_num import text2num from torch import mul from ...structures.schemas import NormalizationOptions diff --git a/api/src/services/text_processing/text_processor.py b/api/src/services/text_processing/text_processor.py index 584affe..77fd525 100644 --- a/api/src/services/text_processing/text_processor.py +++ b/api/src/services/text_processing/text_processor.py @@ -88,32 +88,48 @@ def process_text(text: str, language: str = "a") -> List[int]: def get_sentence_info( - text: str, custom_phenomes_list: Dict[str, str] + text: str, custom_phenomes_list: Dict[str, str], lang_code: str = "a" ) -> List[Tuple[str, List[int], int]]: - """Process all sentences and return info.""" - sentences = re.split(r"([.!?;:])(?=\s|$)", text) + """Process all sentences and return info, 支持中文分句""" + # 判断是否为中文 + is_chinese = lang_code.startswith("zh") or re.search(r"[\u4e00-\u9fff]", text) + if is_chinese: + # 按中文标点断句 + sentences = re.split(r"([,。!?;])", text) + # 合并标点 + merged = [] + for i in range(0, len(sentences)-1, 2): + merged.append(sentences[i] + sentences[i+1]) + if len(sentences) % 2 == 1: + merged.append(sentences[-1]) + sentences = merged + else: + sentences = re.split(r"([.!?;:])(?=\s|$)", text) phoneme_length, min_value = len(custom_phenomes_list), 0 - results = [] - for i in range(0, len(sentences), 2): - sentence = sentences[i].strip() - for replaced in range(min_value, phoneme_length): - current_id = f"" - if current_id in sentence: - sentence = sentence.replace( - current_id, custom_phenomes_list.pop(current_id) - ) - min_value += 1 - - punct = sentences[i + 1] if i + 1 < len(sentences) else "" - - if not sentence: - continue - - full = sentence + punct - tokens = process_text_chunk(full) - results.append((full, tokens, len(tokens))) - + if is_chinese: + for sentence in sentences: + sentence = sentence.strip() + if not sentence: + continue + tokens = process_text_chunk(sentence) + results.append((sentence, tokens, len(tokens))) + else: + for i in range(0, len(sentences), 2): + sentence = sentences[i].strip() + for replaced in range(min_value, phoneme_length): + current_id = f"" + if current_id in sentence: + sentence = sentence.replace( + current_id, custom_phenomes_list.pop(current_id) + ) + min_value += 1 + punct = sentences[i + 1] if i + 1 < len(sentences) else "" + if not sentence: + continue + full = sentence + punct + tokens = process_text_chunk(full) + results.append((full, tokens, len(tokens))) return results @@ -150,7 +166,7 @@ async def smart_split( ) # Process all sentences - sentences = get_sentence_info(text, custom_phoneme_list) + sentences = get_sentence_info(text, custom_phoneme_list, lang_code=lang_code) current_chunk = [] current_tokens = [] From b89da1ff280d1d32bc70b8df01ff3c738e7889ba Mon Sep 17 00:00:00 2001 From: Fireblade2534 Date: Wed, 28 May 2025 14:53:00 +0000 Subject: [PATCH 04/14] Make the code cleaner and add tests --- .../services/text_processing/phonemizer.py | 2 +- .../text_processing/text_processor.py | 51 ++++++--------- api/tests/test_text_processor.py | 62 +++++++++++++++++++ dev/Test money.py | 6 +- 4 files changed, 83 insertions(+), 38 deletions(-) diff --git a/api/src/services/text_processing/phonemizer.py b/api/src/services/text_processing/phonemizer.py index 5a50d64..c010005 100644 --- a/api/src/services/text_processing/phonemizer.py +++ b/api/src/services/text_processing/phonemizer.py @@ -75,7 +75,7 @@ def create_phonemizer(language: str = "a") -> PhonemizerBackend: Phonemizer backend instance """ # Map language codes to espeak language codes - lang_map = {"a": "en-us", "b": "en-gb"} + lang_map = {"a": "en-us", "b": "en-gb", "z": "z"} if language not in lang_map: raise ValueError(f"Unsupported language code: {language}") diff --git a/api/src/services/text_processing/text_processor.py b/api/src/services/text_processing/text_processor.py index 77fd525..0dbb348 100644 --- a/api/src/services/text_processing/text_processor.py +++ b/api/src/services/text_processing/text_processor.py @@ -92,44 +92,30 @@ def get_sentence_info( ) -> List[Tuple[str, List[int], int]]: """Process all sentences and return info, 支持中文分句""" # 判断是否为中文 - is_chinese = lang_code.startswith("zh") or re.search(r"[\u4e00-\u9fff]", text) + is_chinese = lang_code.startswith("z") or re.search(r"[\u4e00-\u9fff]", text) if is_chinese: # 按中文标点断句 - sentences = re.split(r"([,。!?;])", text) - # 合并标点 - merged = [] - for i in range(0, len(sentences)-1, 2): - merged.append(sentences[i] + sentences[i+1]) - if len(sentences) % 2 == 1: - merged.append(sentences[-1]) - sentences = merged + sentences = re.split(r"([,。!?;])+", text) else: sentences = re.split(r"([.!?;:])(?=\s|$)", text) phoneme_length, min_value = len(custom_phenomes_list), 0 + results = [] - if is_chinese: - for sentence in sentences: - sentence = sentence.strip() - if not sentence: - continue - tokens = process_text_chunk(sentence) - results.append((sentence, tokens, len(tokens))) - else: - for i in range(0, len(sentences), 2): - sentence = sentences[i].strip() - for replaced in range(min_value, phoneme_length): - current_id = f"" - if current_id in sentence: - sentence = sentence.replace( - current_id, custom_phenomes_list.pop(current_id) - ) - min_value += 1 - punct = sentences[i + 1] if i + 1 < len(sentences) else "" - if not sentence: - continue - full = sentence + punct - tokens = process_text_chunk(full) - results.append((full, tokens, len(tokens))) + for i in range(0, len(sentences), 2): + sentence = sentences[i].strip() + for replaced in range(min_value, phoneme_length): + current_id = f"" + if current_id in sentence: + sentence = sentence.replace( + current_id, custom_phenomes_list.pop(current_id) + ) + min_value += 1 + punct = sentences[i + 1] if i + 1 < len(sentences) else "" + if not sentence: + continue + full = sentence + punct + tokens = process_text_chunk(full) + results.append((full, tokens, len(tokens))) return results @@ -154,7 +140,6 @@ async def smart_split( # Normalize text if settings.advanced_text_normalization and normalization_options.normalize: - print(lang_code) if lang_code in ["a", "b", "en-us", "en-gb"]: text = CUSTOM_PHONEMES.sub( lambda s: handle_custom_phonemes(s, custom_phoneme_list), text diff --git a/api/tests/test_text_processor.py b/api/tests/test_text_processor.py index bfcbcfe..6ff8282 100644 --- a/api/tests/test_text_processor.py +++ b/api/tests/test_text_processor.py @@ -103,3 +103,65 @@ async def test_smart_split_with_punctuation(): # Verify punctuation is preserved assert all(any(p in chunk for p in "!?;:.") for chunk in chunks) + +def test_process_text_chunk_chinese_phonemes(): + """Test processing with Chinese pinyin phonemes.""" + pinyin = "nǐ hǎo lì" # Example pinyin sequence with tones + tokens = process_text_chunk(pinyin, skip_phonemize=True, language="z") + assert isinstance(tokens, list) + assert len(tokens) > 0 + + +def test_get_sentence_info_chinese(): + """Test Chinese sentence splitting and info extraction.""" + text = "这是一个句子。这是第二个句子!第三个问题?" + results = get_sentence_info(text, {}, lang_code="z") + + assert len(results) == 3 + for sentence, tokens, count in results: + assert isinstance(sentence, str) + assert isinstance(tokens, list) + assert isinstance(count, int) + assert count == len(tokens) + assert count > 0 + +@pytest.mark.asyncio +async def test_smart_split_chinese_short(): + """Test Chinese smart splitting with short text.""" + text = "这是一句话。" + chunks = [] + async for chunk_text, chunk_tokens in smart_split(text, lang_code="z"): + chunks.append((chunk_text, chunk_tokens)) + + assert len(chunks) == 1 + assert isinstance(chunks[0][0], str) + assert isinstance(chunks[0][1], list) + + +@pytest.mark.asyncio +async def test_smart_split_chinese_long(): + """Test Chinese smart splitting with longer text.""" + text = "。".join([f"测试句子 {i}" for i in range(20)]) + + chunks = [] + async for chunk_text, chunk_tokens in smart_split(text, lang_code="z"): + chunks.append((chunk_text, chunk_tokens)) + + assert len(chunks) > 1 + for chunk_text, chunk_tokens in chunks: + assert isinstance(chunk_text, str) + assert isinstance(chunk_tokens, list) + assert len(chunk_tokens) > 0 + + +@pytest.mark.asyncio +async def test_smart_split_chinese_punctuation(): + """Test Chinese smart splitting with punctuation preservation.""" + text = "第一句!第二问?第三句;第四句:第五句。" + + chunks = [] + async for chunk_text, _ in smart_split(text, lang_code="z"): + chunks.append(chunk_text) + + # Verify Chinese punctuation is preserved + assert all(any(p in chunk for p in "!?;:。") for chunk in chunks) \ No newline at end of file diff --git a/dev/Test money.py b/dev/Test money.py index 47e2a9c..57d1fa6 100644 --- a/dev/Test money.py +++ b/dev/Test money.py @@ -3,9 +3,7 @@ import json import requests -text = """the administration has offered up a platter of repression for more than a year and is still slated to lose $400 million. - -Columbia is the largest private landowner in New York City and boasts an endowment of $14.8 billion;""" +text = """奶酪芝士很浓郁!臭豆腐芝士有争议?陈年奶酪价格昂贵。""" Type = "wav" @@ -15,7 +13,7 @@ response = requests.post( json={ "model": "kokoro", "input": text, - "voice": "af_heart+af_sky", + "voice": "zf_xiaobei", "speed": 1.0, "response_format": Type, "stream": False, From ab8ab7d7494b47b11d8dfc17b72acf77e63bdd9e Mon Sep 17 00:00:00 2001 From: Lukin Date: Fri, 30 May 2025 22:52:58 +0800 Subject: [PATCH 05/14] Refactor audio processing and text normalization: Update audio normalization to use absolute amplitude threshold, enhance streaming audio writer with MP3 container options, and improve text normalization by stripping spaces and handling special characters to prevent audio artifacts. --- api/src/services/audio.py | 4 ++-- api/src/services/streaming_audio_writer.py | 24 +++++++++++++++---- .../services/text_processing/normalizer.py | 24 ++++++++++++++++++- .../services/text_processing/phonemizer.py | 17 +++++++++++-- .../text_processing/text_processor.py | 24 ++++++++++++++----- .../services/text_processing/vocabulary.py | 2 ++ 6 files changed, 80 insertions(+), 15 deletions(-) diff --git a/api/src/services/audio.py b/api/src/services/audio.py index 5d1d3ff..6ae6d79 100644 --- a/api/src/services/audio.py +++ b/api/src/services/audio.py @@ -80,12 +80,12 @@ class AudioNormalizer: non_silent_index_start, non_silent_index_end = None, None for X in range(0, len(audio_data)): - if audio_data[X] > amplitude_threshold: + if abs(audio_data[X]) > amplitude_threshold: non_silent_index_start = X break for X in range(len(audio_data) - 1, -1, -1): - if audio_data[X] > amplitude_threshold: + if abs(audio_data[X]) > amplitude_threshold: non_silent_index_end = X break diff --git a/api/src/services/streaming_audio_writer.py b/api/src/services/streaming_audio_writer.py index e6ec2d6..85740aa 100644 --- a/api/src/services/streaming_audio_writer.py +++ b/api/src/services/streaming_audio_writer.py @@ -32,19 +32,29 @@ class StreamingAudioWriter: if self.format in ["wav", "flac", "mp3", "pcm", "aac", "opus"]: if self.format != "pcm": self.output_buffer = BytesIO() + container_options = {} + # Try disabling Xing VBR header for MP3 to fix iOS timeline reading issues + if self.format == 'mp3': + # Disable Xing VBR header + container_options = {'write_xing': '0'} + logger.debug("Disabling Xing VBR header for MP3 encoding.") + self.container = av.open( self.output_buffer, mode="w", format=self.format if self.format != "aac" else "adts", + options=container_options # Pass options here ) self.stream = self.container.add_stream( codec_map[self.format], - sample_rate=self.sample_rate, + rate=self.sample_rate, # Correct parameter name is 'rate' layout="mono" if self.channels == 1 else "stereo", ) - self.stream.bit_rate = 128000 + # Set bit_rate only for codecs where it's applicable and useful + if self.format in ['mp3', 'aac', 'opus']: + self.stream.bit_rate = 128000 # Example bitrate, can be configured else: - raise ValueError(f"Unsupported format: {format}") + raise ValueError(f"Unsupported format: {self.format}") # Use self.format here def close(self): if hasattr(self, "container"): @@ -65,12 +75,18 @@ class StreamingAudioWriter: if finalize: if self.format != "pcm": + # Flush stream encoder packets = self.stream.encode(None) for packet in packets: self.container.mux(packet) + # Closing the container handles writing the trailer and finalizing the file. + # No explicit flush method is available or needed here. + logger.debug("Muxed final packets.") + + # Get the final bytes from the buffer *before* closing it data = self.output_buffer.getvalue() - self.close() + self.close() # Close container and buffer return data if audio_data is None or len(audio_data) == 0: diff --git a/api/src/services/text_processing/normalizer.py b/api/src/services/text_processing/normalizer.py index 5e5d6b6..2163cbc 100644 --- a/api/src/services/text_processing/normalizer.py +++ b/api/src/services/text_processing/normalizer.py @@ -391,6 +391,7 @@ def handle_time(t: re.Match[str]) -> str: def normalize_text(text: str, normalization_options: NormalizationOptions) -> str: """Normalize text for TTS processing""" + # Handle email addresses first if enabled if normalization_options.email_normalization: text = EMAIL_PATTERN.sub(handle_email, text) @@ -415,7 +416,7 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st text, ) - # Replace quotes and brackets + # Replace quotes and brackets (additional cleanup) text = text.replace(chr(8216), "'").replace(chr(8217), "'") text = text.replace("«", chr(8220)).replace("»", chr(8221)) text = text.replace(chr(8220), '"').replace(chr(8221), '"') @@ -435,6 +436,27 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st text = re.sub(r" +", " ", text) text = re.sub(r"(?<=\n) +(?=\n)", "", text) + # Handle special characters that might cause audio artifacts first + # Replace newlines with spaces (or pauses if needed) + text = text.replace('\n', ' ') + text = text.replace('\r', ' ') + + # Handle other problematic symbols + text = text.replace('~', '') # Remove tilde + text = text.replace('@', ' at ') # At symbol + text = text.replace('#', ' number ') # Hash/pound + text = text.replace('$', ' dollar ') # Dollar sign (if not handled by money pattern) + text = text.replace('%', ' percent ') # Percent sign + text = text.replace('^', '') # Caret + text = text.replace('&', ' and ') # Ampersand + text = text.replace('*', '') # Asterisk + text = text.replace('_', ' ') # Underscore to space + text = text.replace('|', ' ') # Pipe to space + text = text.replace('\\', ' ') # Backslash to space + text = text.replace('/', ' slash ') # Forward slash to space (unless in URLs) + text = text.replace('=', ' equals ') # Equals sign + text = text.replace('+', ' plus ') # Plus sign + # Handle titles and abbreviations text = re.sub(r"\bD[Rr]\.(?= [A-Z])", "Doctor", text) text = re.sub(r"\b(?:Mr\.|MR\.(?= [A-Z]))", "Mister", text) diff --git a/api/src/services/text_processing/phonemizer.py b/api/src/services/text_processing/phonemizer.py index c010005..ae49cd9 100644 --- a/api/src/services/text_processing/phonemizer.py +++ b/api/src/services/text_processing/phonemizer.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod import phonemizer from .normalizer import normalize_text +from ...structures.schemas import NormalizationOptions phonemizers = {} @@ -95,8 +96,20 @@ def phonemize(text: str, language: str = "a", normalize: bool = True) -> str: Phonemized text """ global phonemizers + + # Strip input text first to remove problematic leading/trailing spaces + text = text.strip() + if normalize: - text = normalize_text(text) + # Create default normalization options and normalize text + normalization_options = NormalizationOptions() + text = normalize_text(text, normalization_options) + # Strip again after normalization + text = text.strip() + if language not in phonemizers: phonemizers[language] = create_phonemizer(language) - return phonemizers[language].phonemize(text) + + result = phonemizers[language].phonemize(text) + # Final strip to ensure no leading/trailing spaces in phonemes + return result.strip() diff --git a/api/src/services/text_processing/text_processor.py b/api/src/services/text_processing/text_processor.py index 0dbb348..3d90325 100644 --- a/api/src/services/text_processing/text_processor.py +++ b/api/src/services/text_processing/text_processor.py @@ -30,6 +30,12 @@ def process_text_chunk( List of token IDs """ start_time = time.time() + + # Strip input text to remove any leading/trailing spaces that could cause artifacts + text = text.strip() + + if not text: + return [] if skip_phonemize: # Input is already phonemes, just tokenize @@ -43,6 +49,8 @@ def process_text_chunk( t0 = time.time() phonemes = phonemize(text, language, normalize=False) # Already normalized + # Strip phonemes result to ensure no extra spaces + phonemes = phonemes.strip() t1 = time.time() t0 = time.time() @@ -114,6 +122,10 @@ def get_sentence_info( if not sentence: continue full = sentence + punct + # Strip the full sentence to remove any leading/trailing spaces before processing + full = full.strip() + if not full: # Skip if empty after stripping + continue tokens = process_text_chunk(full) results.append((full, tokens, len(tokens))) return results @@ -162,7 +174,7 @@ async def smart_split( if count > max_tokens: # Yield current chunk if any if current_chunk: - chunk_text = " ".join(current_chunk) + chunk_text = " ".join(current_chunk).strip() # Strip after joining chunk_count += 1 logger.debug( f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)" @@ -201,7 +213,7 @@ async def smart_split( else: # Yield clause chunk if we have one if clause_chunk: - chunk_text = " ".join(clause_chunk) + chunk_text = " ".join(clause_chunk).strip() # Strip after joining chunk_count += 1 logger.debug( f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)" @@ -213,7 +225,7 @@ async def smart_split( # Don't forget last clause chunk if clause_chunk: - chunk_text = " ".join(clause_chunk) + chunk_text = " ".join(clause_chunk).strip() # Strip after joining chunk_count += 1 logger.debug( f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)" @@ -227,7 +239,7 @@ async def smart_split( ): # If we have a good sized chunk and adding next sentence exceeds target, # yield current chunk and start new one - chunk_text = " ".join(current_chunk) + chunk_text = " ".join(current_chunk).strip() # Strip after joining chunk_count += 1 logger.info( f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)" @@ -252,7 +264,7 @@ async def smart_split( else: # Yield current chunk and start new one if current_chunk: - chunk_text = " ".join(current_chunk) + chunk_text = " ".join(current_chunk).strip() # Strip after joining chunk_count += 1 logger.info( f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)" @@ -264,7 +276,7 @@ async def smart_split( # Don't forget the last chunk if current_chunk: - chunk_text = " ".join(current_chunk) + chunk_text = " ".join(current_chunk).strip() # Strip after joining chunk_count += 1 logger.info( f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)" diff --git a/api/src/services/text_processing/vocabulary.py b/api/src/services/text_processing/vocabulary.py index 7a12892..d6d7863 100644 --- a/api/src/services/text_processing/vocabulary.py +++ b/api/src/services/text_processing/vocabulary.py @@ -23,6 +23,8 @@ def tokenize(phonemes: str) -> list[int]: Returns: List of token IDs """ + # Strip phonemes to remove leading/trailing spaces that could cause artifacts + phonemes = phonemes.strip() return [i for i in map(VOCAB.get, phonemes) if i is not None] From 84d2a4d806ecb3bf49021a36fe27b7eaa792c8c0 Mon Sep 17 00:00:00 2001 From: Lukin Date: Fri, 30 May 2025 23:06:41 +0800 Subject: [PATCH 06/14] Enhance TTS text processing: Implement pause tag handling in smart_split, allowing for better audio chunk generation with pauses. Update related tests to validate new functionality and ensure compatibility with existing features. --- .../text_processing/text_processor.py | 293 ++++++++++-------- api/src/services/tts_service.py | 112 ++++--- api/tests/test_text_processor.py | 44 ++- 3 files changed, 282 insertions(+), 167 deletions(-) diff --git a/api/src/services/text_processing/text_processor.py b/api/src/services/text_processing/text_processor.py index 3d90325..c5a442d 100644 --- a/api/src/services/text_processing/text_processor.py +++ b/api/src/services/text_processing/text_processor.py @@ -2,7 +2,7 @@ import re import time -from typing import AsyncGenerator, Dict, List, Tuple +from typing import AsyncGenerator, Dict, List, Tuple, Optional from loguru import logger @@ -13,7 +13,11 @@ from .phonemizer import phonemize from .vocabulary import tokenize # Pre-compiled regex patterns for performance -CUSTOM_PHONEMES = re.compile(r"(\[([^\]]|\n)*?\])(\(\/([^\/)]|\n)*?\/\))") +# Updated regex to be more strict and avoid matching isolated brackets +# Only matches complete patterns like [word](/ipa/) and prevents catastrophic backtracking +CUSTOM_PHONEMES = re.compile(r"(\[[^\[\]]*?\])(\(\/[^\/\(\)]*?\/\))") +# Pattern to find pause tags like [pause:0.5s] +PAUSE_TAG_PATTERN = re.compile(r"\[pause:(\d+(?:\.\d+)?)s\]", re.IGNORECASE) def process_text_chunk( @@ -142,148 +146,189 @@ async def smart_split( max_tokens: int = settings.absolute_max_tokens, lang_code: str = "a", normalization_options: NormalizationOptions = NormalizationOptions(), -) -> AsyncGenerator[Tuple[str, List[int]], None]: - """Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens.""" +) -> AsyncGenerator[Tuple[str, List[int], Optional[float]], None]: + """Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens. + + Yields: + Tuple of (text_chunk, tokens, pause_duration_s). + If pause_duration_s is not None, it's a pause chunk with empty text/tokens. + Otherwise, it's a text chunk containing the original text. + """ start_time = time.time() chunk_count = 0 logger.info(f"Starting smart split for {len(text)} chars") - custom_phoneme_list = {} + # --- Step 1: Split by Pause Tags FIRST --- + # This operates on the raw input text + parts = PAUSE_TAG_PATTERN.split(text) + logger.debug(f"Split raw text into {len(parts)} parts by pause tags.") - # Normalize text - if settings.advanced_text_normalization and normalization_options.normalize: - if lang_code in ["a", "b", "en-us", "en-gb"]: - text = CUSTOM_PHONEMES.sub( - lambda s: handle_custom_phonemes(s, custom_phoneme_list), text - ) - text = normalize_text(text, normalization_options) - else: - logger.info( - "Skipping text normalization as it is only supported for english" - ) + part_idx = 0 + while part_idx < len(parts): + text_part_raw = parts[part_idx] # This part is raw text + part_idx += 1 - # Process all sentences - sentences = get_sentence_info(text, custom_phoneme_list, lang_code=lang_code) + # --- Process Text Part --- + if text_part_raw and text_part_raw.strip(): # Only process if the part is not empty string + # Strip leading and trailing spaces to prevent pause tag splitting artifacts + text_part_raw = text_part_raw.strip() - current_chunk = [] - current_tokens = [] - current_count = 0 + # Apply the original smart_split logic to this text part + custom_phoneme_list = {} - for sentence, tokens, count in sentences: - # Handle sentences that exceed max tokens - if count > max_tokens: - # Yield current chunk if any - if current_chunk: - chunk_text = " ".join(current_chunk).strip() # Strip after joining - chunk_count += 1 - logger.debug( - f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)" - ) - yield chunk_text, current_tokens - current_chunk = [] - current_tokens = [] - current_count = 0 - - # Split long sentence on commas - clauses = re.split(r"([,])", sentence) - clause_chunk = [] - clause_tokens = [] - clause_count = 0 - - for j in range(0, len(clauses), 2): - clause = clauses[j].strip() - comma = clauses[j + 1] if j + 1 < len(clauses) else "" - - if not clause: - continue - - full_clause = clause + comma - - tokens = process_text_chunk(full_clause) - count = len(tokens) - - # If adding clause keeps us under max and not optimal yet - if ( - clause_count + count <= max_tokens - and clause_count + count <= settings.target_max_tokens - ): - clause_chunk.append(full_clause) - clause_tokens.extend(tokens) - clause_count += count + # Normalize text (original logic) + processed_text = text_part_raw + if settings.advanced_text_normalization and normalization_options.normalize: + if lang_code in ["a", "b", "en-us", "en-gb"]: + processed_text = CUSTOM_PHONEMES.sub( + lambda s: handle_custom_phonemes(s, custom_phoneme_list), processed_text + ) + processed_text = normalize_text(processed_text, normalization_options) else: - # Yield clause chunk if we have one - if clause_chunk: - chunk_text = " ".join(clause_chunk).strip() # Strip after joining + logger.info( + "Skipping text normalization as it is only supported for english" + ) + + # Process all sentences (original logic) + sentences = get_sentence_info(processed_text, custom_phoneme_list, lang_code=lang_code) + + current_chunk = [] + current_tokens = [] + current_count = 0 + + for sentence, tokens, count in sentences: + # Handle sentences that exceed max tokens (original logic) + if count > max_tokens: + # Yield current chunk if any + if current_chunk: + chunk_text = " ".join(current_chunk).strip() chunk_count += 1 logger.debug( - f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)" + f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)" ) - yield chunk_text, clause_tokens - clause_chunk = [full_clause] - clause_tokens = tokens - clause_count = count + yield chunk_text, current_tokens, None + current_chunk = [] + current_tokens = [] + current_count = 0 - # Don't forget last clause chunk - if clause_chunk: - chunk_text = " ".join(clause_chunk).strip() # Strip after joining - chunk_count += 1 - logger.debug( - f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)" - ) - yield chunk_text, clause_tokens + # Split long sentence on commas (original logic) + clauses = re.split(r"([,])", sentence) + clause_chunk = [] + clause_tokens = [] + clause_count = 0 - # Regular sentence handling - elif ( - current_count >= settings.target_min_tokens - and current_count + count > settings.target_max_tokens - ): - # If we have a good sized chunk and adding next sentence exceeds target, - # yield current chunk and start new one - chunk_text = " ".join(current_chunk).strip() # Strip after joining - chunk_count += 1 - logger.info( - f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)" - ) - yield chunk_text, current_tokens - current_chunk = [sentence] - current_tokens = tokens - current_count = count - elif current_count + count <= settings.target_max_tokens: - # Keep building chunk while under target max - current_chunk.append(sentence) - current_tokens.extend(tokens) - current_count += count - elif ( - current_count + count <= max_tokens - and current_count < settings.target_min_tokens - ): - # Only exceed target max if we haven't reached minimum size yet - current_chunk.append(sentence) - current_tokens.extend(tokens) - current_count += count - else: - # Yield current chunk and start new one + for j in range(0, len(clauses), 2): + clause = clauses[j].strip() + comma = clauses[j + 1] if j + 1 < len(clauses) else "" + + if not clause: + continue + + full_clause = clause + comma + + tokens = process_text_chunk(full_clause) + count = len(tokens) + + # If adding clause keeps us under max and not optimal yet + if ( + clause_count + count <= max_tokens + and clause_count + count <= settings.target_max_tokens + ): + clause_chunk.append(full_clause) + clause_tokens.extend(tokens) + clause_count += count + else: + # Yield clause chunk if we have one + if clause_chunk: + chunk_text = " ".join(clause_chunk).strip() + chunk_count += 1 + logger.debug( + f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({clause_count} tokens)" + ) + yield chunk_text, clause_tokens, None + clause_chunk = [full_clause] + clause_tokens = tokens + clause_count = count + + # Don't forget last clause chunk + if clause_chunk: + chunk_text = " ".join(clause_chunk).strip() + chunk_count += 1 + logger.debug( + f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({clause_count} tokens)" + ) + yield chunk_text, clause_tokens, None + + # Regular sentence handling (original logic) + elif ( + current_count >= settings.target_min_tokens + and current_count + count > settings.target_max_tokens + ): + # If we have a good sized chunk and adding next sentence exceeds target, + # yield current chunk and start new one + chunk_text = " ".join(current_chunk).strip() + chunk_count += 1 + logger.info( + f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)" + ) + yield chunk_text, current_tokens, None + current_chunk = [sentence] + current_tokens = tokens + current_count = count + elif current_count + count <= settings.target_max_tokens: + # Keep building chunk while under target max + current_chunk.append(sentence) + current_tokens.extend(tokens) + current_count += count + elif ( + current_count + count <= max_tokens + and current_count < settings.target_min_tokens + ): + # Only exceed target max if we haven't reached minimum size yet + current_chunk.append(sentence) + current_tokens.extend(tokens) + current_count += count + else: + # Yield current chunk and start new one + if current_chunk: + chunk_text = " ".join(current_chunk).strip() + chunk_count += 1 + logger.info( + f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)" + ) + yield chunk_text, current_tokens, None + current_chunk = [sentence] + current_tokens = tokens + current_count = count + + # Don't forget the last chunk for this text part if current_chunk: - chunk_text = " ".join(current_chunk).strip() # Strip after joining + chunk_text = " ".join(current_chunk).strip() chunk_count += 1 logger.info( - f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)" + f"Yielding final chunk {chunk_count} for part: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)" ) - yield chunk_text, current_tokens - current_chunk = [sentence] - current_tokens = tokens - current_count = count + yield chunk_text, current_tokens, None - # Don't forget the last chunk - if current_chunk: - chunk_text = " ".join(current_chunk).strip() # Strip after joining - chunk_count += 1 - logger.info( - f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)" - ) - yield chunk_text, current_tokens + # --- Handle Pause Part --- + # Check if the next part is a pause duration string + if part_idx < len(parts): + duration_str = parts[part_idx] + # Check if it looks like a valid number string captured by the regex group + if re.fullmatch(r"\d+(?:\.\d+)?", duration_str): + part_idx += 1 # Consume the duration string as it's been processed + try: + duration = float(duration_str) + if duration > 0: + chunk_count += 1 + logger.info(f"Yielding pause chunk {chunk_count}: {duration}s") + yield "", [], duration # Yield pause chunk + except (ValueError, TypeError): + # This case should be rare if re.fullmatch passed, but handle anyway + logger.warning(f"Could not parse valid-looking pause duration: {duration_str}") + # --- End of parts loop --- total_time = time.time() - start_time logger.info( - f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks" + f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks (including pauses)" ) diff --git a/api/src/services/tts_service.py b/api/src/services/tts_service.py index 0a69b85..399600e 100644 --- a/api/src/services/tts_service.py +++ b/api/src/services/tts_service.py @@ -280,48 +280,88 @@ class TTSService: f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream" ) - # Process text in chunks with smart splitting - async for chunk_text, tokens in smart_split( + # Process text in chunks with smart splitting, handling pause tags + async for chunk_text, tokens, pause_duration_s in smart_split( text, lang_code=pipeline_lang_code, normalization_options=normalization_options, ): - try: - # Process audio for chunk - async for chunk_data in self._process_chunk( - chunk_text, # Pass text for Kokoro V1 - tokens, # Pass tokens for legacy backends - voice_name, # Pass voice name - voice_path, # Pass voice path - speed, - writer, - output_format, - is_first=(chunk_index == 0), - is_last=False, # We'll update the last chunk later - normalizer=stream_normalizer, - lang_code=pipeline_lang_code, # Pass lang_code - return_timestamps=return_timestamps, - ): - if chunk_data.word_timestamps is not None: - for timestamp in chunk_data.word_timestamps: - timestamp.start_time += current_offset - timestamp.end_time += current_offset + if pause_duration_s is not None and pause_duration_s > 0: + # --- Handle Pause Chunk --- + try: + logger.debug(f"Generating {pause_duration_s}s silence chunk") + silence_samples = int(pause_duration_s * 24000) # 24kHz sample rate + # Create proper silence as int16 zeros to avoid normalization artifacts + silence_audio = np.zeros(silence_samples, dtype=np.int16) + pause_chunk = AudioChunk(audio=silence_audio, word_timestamps=[]) # Empty timestamps for silence - current_offset += len(chunk_data.audio) / 24000 - - if chunk_data.output is not None: - yield chunk_data - - else: - logger.warning( - f"No audio generated for chunk: '{chunk_text[:100]}...'" + # Format and yield the silence chunk + if output_format: + formatted_pause_chunk = await AudioService.convert_audio( + pause_chunk, output_format, writer, speed=speed, chunk_text="", + is_last_chunk=False, trim_audio=False, normalizer=stream_normalizer, ) - chunk_index += 1 - except Exception as e: - logger.error( - f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}" - ) - continue + if formatted_pause_chunk.output: + yield formatted_pause_chunk + else: # Raw audio mode + # For raw audio mode, silence is already in the correct format (int16) + # Skip normalization to avoid any potential artifacts + if len(pause_chunk.audio) > 0: + yield pause_chunk + + # Update offset based on silence duration + current_offset += pause_duration_s + chunk_index += 1 # Count pause as a yielded chunk + + except Exception as e: + logger.error(f"Failed to process pause chunk: {str(e)}") + continue + + elif tokens or chunk_text.strip(): # Process if there are tokens OR non-whitespace text + # --- Handle Text Chunk --- + try: + # Process audio for chunk + async for chunk_data in self._process_chunk( + chunk_text, # Pass text for Kokoro V1 + tokens, # Pass tokens for legacy backends + voice_name, # Pass voice name + voice_path, # Pass voice path + speed, + writer, + output_format, + is_first=(chunk_index == 0), + is_last=False, # We'll update the last chunk later + normalizer=stream_normalizer, + lang_code=pipeline_lang_code, # Pass lang_code + return_timestamps=return_timestamps, + ): + if chunk_data.word_timestamps is not None: + for timestamp in chunk_data.word_timestamps: + timestamp.start_time += current_offset + timestamp.end_time += current_offset + + # Update offset based on the actual duration of the generated audio chunk + chunk_duration = 0 + if chunk_data.audio is not None and len(chunk_data.audio) > 0: + chunk_duration = len(chunk_data.audio) / 24000 + current_offset += chunk_duration + + # Yield the processed chunk (either formatted or raw) + if chunk_data.output is not None: + yield chunk_data + elif chunk_data.audio is not None and len(chunk_data.audio) > 0: + yield chunk_data + else: + logger.warning( + f"No audio generated for chunk: '{chunk_text[:100]}...'" + ) + + chunk_index += 1 # Increment chunk index after processing text + except Exception as e: + logger.error( + f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}" + ) + continue # Only finalize if we successfully processed at least one chunk if chunk_index > 0: diff --git a/api/tests/test_text_processor.py b/api/tests/test_text_processor.py index 6ff8282..95c0259 100644 --- a/api/tests/test_text_processor.py +++ b/api/tests/test_text_processor.py @@ -67,7 +67,7 @@ async def test_smart_split_short_text(): """Test smart splitting with text under max tokens.""" text = "This is a short test sentence." chunks = [] - async for chunk_text, chunk_tokens in smart_split(text): + async for chunk_text, chunk_tokens, _ in smart_split(text): chunks.append((chunk_text, chunk_tokens)) assert len(chunks) == 1 @@ -82,7 +82,7 @@ async def test_smart_split_long_text(): text = ". ".join(["This is test sentence number " + str(i) for i in range(20)]) chunks = [] - async for chunk_text, chunk_tokens in smart_split(text): + async for chunk_text, chunk_tokens, _ in smart_split(text): chunks.append((chunk_text, chunk_tokens)) assert len(chunks) > 1 @@ -98,12 +98,13 @@ async def test_smart_split_with_punctuation(): text = "First sentence! Second sentence? Third sentence; Fourth sentence: Fifth sentence." chunks = [] - async for chunk_text, chunk_tokens in smart_split(text): + async for chunk_text, chunk_tokens, _ in smart_split(text): chunks.append(chunk_text) # Verify punctuation is preserved assert all(any(p in chunk for p in "!?;:.") for chunk in chunks) + def test_process_text_chunk_chinese_phonemes(): """Test processing with Chinese pinyin phonemes.""" pinyin = "nǐ hǎo lì" # Example pinyin sequence with tones @@ -125,12 +126,13 @@ def test_get_sentence_info_chinese(): assert count == len(tokens) assert count > 0 + @pytest.mark.asyncio async def test_smart_split_chinese_short(): """Test Chinese smart splitting with short text.""" text = "这是一句话。" chunks = [] - async for chunk_text, chunk_tokens in smart_split(text, lang_code="z"): + async for chunk_text, chunk_tokens, _ in smart_split(text, lang_code="z"): chunks.append((chunk_text, chunk_tokens)) assert len(chunks) == 1 @@ -144,7 +146,7 @@ async def test_smart_split_chinese_long(): text = "。".join([f"测试句子 {i}" for i in range(20)]) chunks = [] - async for chunk_text, chunk_tokens in smart_split(text, lang_code="z"): + async for chunk_text, chunk_tokens, _ in smart_split(text, lang_code="z"): chunks.append((chunk_text, chunk_tokens)) assert len(chunks) > 1 @@ -160,8 +162,36 @@ async def test_smart_split_chinese_punctuation(): text = "第一句!第二问?第三句;第四句:第五句。" chunks = [] - async for chunk_text, _ in smart_split(text, lang_code="z"): + async for chunk_text, _, _ in smart_split(text, lang_code="z"): chunks.append(chunk_text) # Verify Chinese punctuation is preserved - assert all(any(p in chunk for p in "!?;:。") for chunk in chunks) \ No newline at end of file + assert all(any(p in chunk for p in "!?;:。") for chunk in chunks) + + +@pytest.mark.asyncio +async def test_smart_split_with_pause(): + """Test smart splitting with pause tags.""" + text = "Hello world [pause:2.5s] How are you?" + + chunks = [] + async for chunk_text, chunk_tokens, pause_duration in smart_split(text): + chunks.append((chunk_text, chunk_tokens, pause_duration)) + + # Should have 3 chunks: text, pause, text + assert len(chunks) == 3 + + # First chunk: text + assert chunks[0][2] is None # No pause + assert "Hello world" in chunks[0][0] + assert len(chunks[0][1]) > 0 + + # Second chunk: pause + assert chunks[1][2] == 2.5 # 2.5 second pause + assert chunks[1][0] == "" # Empty text + assert len(chunks[1][1]) == 0 # No tokens + + # Third chunk: text + assert chunks[2][2] is None # No pause + assert "How are you?" in chunks[2][0] + assert len(chunks[2][1]) > 0 \ No newline at end of file From 888e3121ff018d8c7ae039178ac82d815b37133c Mon Sep 17 00:00:00 2001 From: Lukin Date: Sun, 1 Jun 2025 10:18:24 +0800 Subject: [PATCH 07/14] Refactor text normalization: Move handling of problematic symbols to occur after number and money processing to improve accuracy in text normalization. --- .../services/text_processing/normalizer.py | 46 +++++++++---------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/api/src/services/text_processing/normalizer.py b/api/src/services/text_processing/normalizer.py index 2163cbc..f439dfa 100644 --- a/api/src/services/text_processing/normalizer.py +++ b/api/src/services/text_processing/normalizer.py @@ -441,7 +441,29 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st text = text.replace('\n', ' ') text = text.replace('\r', ' ') - # Handle other problematic symbols + # Handle titles and abbreviations + text = re.sub(r"\bD[Rr]\.(?= [A-Z])", "Doctor", text) + text = re.sub(r"\b(?:Mr\.|MR\.(?= [A-Z]))", "Mister", text) + text = re.sub(r"\b(?:Ms\.|MS\.(?= [A-Z]))", "Miss", text) + text = re.sub(r"\b(?:Mrs\.|MRS\.(?= [A-Z]))", "Mrs", text) + text = re.sub(r"\betc\.(?! [A-Z])", "etc", text) + + # Handle common words + text = re.sub(r"(?i)\b(y)eah?\b", r"\1e'a", text) + + # Handle numbers and money BEFORE replacing special characters + text = re.sub(r"(?<=\d),(?=\d)", "", text) + + text = MONEY_PATTERN.sub( + handle_money, + text, + ) + + text = NUMBER_PATTERN.sub(handle_numbers, text) + + text = re.sub(r"\d*\.\d+", handle_decimal, text) + + # Handle other problematic symbols AFTER money/number processing text = text.replace('~', '') # Remove tilde text = text.replace('@', ' at ') # At symbol text = text.replace('#', ' number ') # Hash/pound @@ -457,28 +479,6 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st text = text.replace('=', ' equals ') # Equals sign text = text.replace('+', ' plus ') # Plus sign - # Handle titles and abbreviations - text = re.sub(r"\bD[Rr]\.(?= [A-Z])", "Doctor", text) - text = re.sub(r"\b(?:Mr\.|MR\.(?= [A-Z]))", "Mister", text) - text = re.sub(r"\b(?:Ms\.|MS\.(?= [A-Z]))", "Miss", text) - text = re.sub(r"\b(?:Mrs\.|MRS\.(?= [A-Z]))", "Mrs", text) - text = re.sub(r"\betc\.(?! [A-Z])", "etc", text) - - # Handle common words - text = re.sub(r"(?i)\b(y)eah?\b", r"\1e'a", text) - - # Handle numbers and money - text = re.sub(r"(?<=\d),(?=\d)", "", text) - - text = MONEY_PATTERN.sub( - handle_money, - text, - ) - - text = NUMBER_PATTERN.sub(handle_numbers, text) - - text = re.sub(r"\d*\.\d+", handle_decimal, text) - # Handle various formatting text = re.sub(r"(?<=\d)-(?=\d)", " to ", text) text = re.sub(r"(?<=\d)S", " S", text) From 0b2260602ad4e12ac91c381d4f33942db25e6743 Mon Sep 17 00:00:00 2001 From: Lukin Date: Sun, 1 Jun 2025 10:28:35 +0800 Subject: [PATCH 08/14] Update TTS service tests: Enhance test_get_voice_path_combined by mocking os.path.join to ensure correct path generation for combined voices. --- api/tests/test_tts_service.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/api/tests/test_tts_service.py b/api/tests/test_tts_service.py index ae8447a..23b5273 100644 --- a/api/tests/test_tts_service.py +++ b/api/tests/test_tts_service.py @@ -3,6 +3,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import numpy as np import pytest import torch +import os from api.src.services.tts_service import TTSService @@ -93,17 +94,20 @@ async def test_get_voice_path_combined(): patch("torch.load") as mock_load, patch("torch.save") as mock_save, patch("tempfile.gettempdir") as mock_temp, + patch("os.path.join") as mock_join, ): mock_get_model.return_value = model_manager mock_get_voice.return_value = voice_manager mock_temp.return_value = "/tmp" + mock_join.return_value = "/tmp/voice1+voice2.pt" mock_load.return_value = torch.ones(10) service = await TTSService.create("test_output") name, path = await service._get_voices_path("voice1+voice2") assert name == "voice1+voice2" - assert path.endswith("voice1+voice2.pt") + assert path == "/tmp/voice1+voice2.pt" mock_save.assert_called_once() + mock_join.assert_called_with("/tmp", "voice1+voice2.pt") @pytest.mark.asyncio From f7fb9c524a0a2d362fce0631867d8c92f30472db Mon Sep 17 00:00:00 2001 From: Lukin Date: Sun, 1 Jun 2025 10:35:59 +0800 Subject: [PATCH 09/14] Refactor TTS service tests: Update test_get_voice_path_combined to verify path format for combined voices, removing mock for os.path.join and enhancing assertions for path validation. --- api/tests/test_tts_service.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/api/tests/test_tts_service.py b/api/tests/test_tts_service.py index 23b5273..968c31f 100644 --- a/api/tests/test_tts_service.py +++ b/api/tests/test_tts_service.py @@ -94,20 +94,20 @@ async def test_get_voice_path_combined(): patch("torch.load") as mock_load, patch("torch.save") as mock_save, patch("tempfile.gettempdir") as mock_temp, - patch("os.path.join") as mock_join, ): mock_get_model.return_value = model_manager mock_get_voice.return_value = voice_manager mock_temp.return_value = "/tmp" - mock_join.return_value = "/tmp/voice1+voice2.pt" mock_load.return_value = torch.ones(10) service = await TTSService.create("test_output") name, path = await service._get_voices_path("voice1+voice2") assert name == "voice1+voice2" - assert path == "/tmp/voice1+voice2.pt" + # Verify the path points to a temporary file with expected format + assert path.startswith("/tmp/") + assert "voice1+voice2" in path + assert path.endswith(".pt") mock_save.assert_called_once() - mock_join.assert_called_with("/tmp", "voice1+voice2.pt") @pytest.mark.asyncio From d7d90cdc9d232c36d117deb2b3ffb37b82413fda Mon Sep 17 00:00:00 2001 From: Fireblade2534 Date: Thu, 12 Jun 2025 16:00:06 +0000 Subject: [PATCH 10/14] simplified some normalization and added more tests --- api/src/services/streaming_audio_writer.py | 4 +-- .../services/text_processing/normalizer.py | 36 +++++++++++-------- .../services/text_processing/phonemizer.py | 10 +----- .../text_processing/text_processor.py | 8 ++--- api/src/structures/schemas.py | 5 +++ api/tests/test_normalizer.py | 16 +++++++++ api/tests/test_text_processor.py | 27 ++++++++++++++ 7 files changed, 77 insertions(+), 29 deletions(-) diff --git a/api/src/services/streaming_audio_writer.py b/api/src/services/streaming_audio_writer.py index 85740aa..de9c84e 100644 --- a/api/src/services/streaming_audio_writer.py +++ b/api/src/services/streaming_audio_writer.py @@ -47,12 +47,12 @@ class StreamingAudioWriter: ) self.stream = self.container.add_stream( codec_map[self.format], - rate=self.sample_rate, # Correct parameter name is 'rate' + rate=self.sample_rate, layout="mono" if self.channels == 1 else "stereo", ) # Set bit_rate only for codecs where it's applicable and useful if self.format in ['mp3', 'aac', 'opus']: - self.stream.bit_rate = 128000 # Example bitrate, can be configured + self.stream.bit_rate = 128000 else: raise ValueError(f"Unsupported format: {self.format}") # Use self.format here diff --git a/api/src/services/text_processing/normalizer.py b/api/src/services/text_processing/normalizer.py index f439dfa..7908318 100644 --- a/api/src/services/text_processing/normalizer.py +++ b/api/src/services/text_processing/normalizer.py @@ -134,6 +134,23 @@ VALID_UNITS = { "px": "pixel", # CSS units } +SYMBOL_REPLACEMENTS = { + '~': ' ', + '@': ' at ', + '#': ' number ', + '$': ' dollar ', + '%': ' percent ', + '^': ' ', + '&': ' and ', + '*': ' ', + '_': ' ', + '|': ' ', + '\\': ' ', + '/': ' slash ', + '=': ' equals ', + '+': ' plus ', +} + MONEY_UNITS = {"$": ("dollar", "cent"), "£": ("pound", "pence"), "€": ("euro", "cent")} # Pre-compiled regex patterns for performance @@ -464,20 +481,9 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st text = re.sub(r"\d*\.\d+", handle_decimal, text) # Handle other problematic symbols AFTER money/number processing - text = text.replace('~', '') # Remove tilde - text = text.replace('@', ' at ') # At symbol - text = text.replace('#', ' number ') # Hash/pound - text = text.replace('$', ' dollar ') # Dollar sign (if not handled by money pattern) - text = text.replace('%', ' percent ') # Percent sign - text = text.replace('^', '') # Caret - text = text.replace('&', ' and ') # Ampersand - text = text.replace('*', '') # Asterisk - text = text.replace('_', ' ') # Underscore to space - text = text.replace('|', ' ') # Pipe to space - text = text.replace('\\', ' ') # Backslash to space - text = text.replace('/', ' slash ') # Forward slash to space (unless in URLs) - text = text.replace('=', ' equals ') # Equals sign - text = text.replace('+', ' plus ') # Plus sign + if normalization_options.replace_remaining_symbols: + for symbol, replacement in SYMBOL_REPLACEMENTS.items(): + text = text.replace(symbol, replacement) # Handle various formatting text = re.sub(r"(?<=\d)-(?=\d)", " to ", text) @@ -489,4 +495,6 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st ) text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text) + text = re.sub(r"\s{2,}", " ", text) + return text.strip() diff --git a/api/src/services/text_processing/phonemizer.py b/api/src/services/text_processing/phonemizer.py index ae49cd9..dabf328 100644 --- a/api/src/services/text_processing/phonemizer.py +++ b/api/src/services/text_processing/phonemizer.py @@ -84,13 +84,12 @@ def create_phonemizer(language: str = "a") -> PhonemizerBackend: return EspeakBackend(lang_map[language]) -def phonemize(text: str, language: str = "a", normalize: bool = True) -> str: +def phonemize(text: str, language: str = "a") -> str: """Convert text to phonemes Args: text: Text to convert to phonemes language: Language code ('a' for US English, 'b' for British English) - normalize: Whether to normalize text before phonemization Returns: Phonemized text @@ -100,13 +99,6 @@ def phonemize(text: str, language: str = "a", normalize: bool = True) -> str: # Strip input text first to remove problematic leading/trailing spaces text = text.strip() - if normalize: - # Create default normalization options and normalize text - normalization_options = NormalizationOptions() - text = normalize_text(text, normalization_options) - # Strip again after normalization - text = text.strip() - if language not in phonemizers: phonemizers[language] = create_phonemizer(language) diff --git a/api/src/services/text_processing/text_processor.py b/api/src/services/text_processing/text_processor.py index c5a442d..39b0d6c 100644 --- a/api/src/services/text_processing/text_processor.py +++ b/api/src/services/text_processing/text_processor.py @@ -52,7 +52,7 @@ def process_text_chunk( t1 = time.time() t0 = time.time() - phonemes = phonemize(text, language, normalize=False) # Already normalized + phonemes = phonemize(text, language) # Strip phonemes result to ensure no extra spaces phonemes = phonemes.strip() t1 = time.time() @@ -102,11 +102,11 @@ def process_text(text: str, language: str = "a") -> List[int]: def get_sentence_info( text: str, custom_phenomes_list: Dict[str, str], lang_code: str = "a" ) -> List[Tuple[str, List[int], int]]: - """Process all sentences and return info, 支持中文分句""" - # 判断是否为中文 + """Process all sentences and return info""" + # Detect Chinese text is_chinese = lang_code.startswith("z") or re.search(r"[\u4e00-\u9fff]", text) if is_chinese: - # 按中文标点断句 + # Split using Chinese punctuation sentences = re.split(r"([,。!?;])+", text) else: sentences = re.split(r"([.!?;:])(?=\s|$)", text) diff --git a/api/src/structures/schemas.py b/api/src/structures/schemas.py index 260224c..83b1c0b 100644 --- a/api/src/structures/schemas.py +++ b/api/src/structures/schemas.py @@ -1,3 +1,4 @@ +from email.policy import default from enum import Enum from typing import List, Literal, Optional, Union @@ -66,6 +67,10 @@ class NormalizationOptions(BaseModel): default=True, description="Changes phone numbers so they can be properly pronouced by kokoro", ) + replace_remaining_symbols: bool = Field( + default=True, + description="Replaces the remaining symbols after normalization with their words" + ) class OpenAISpeechRequest(BaseModel): diff --git a/api/tests/test_normalizer.py b/api/tests/test_normalizer.py index 3db0801..6b5a8bf 100644 --- a/api/tests/test_normalizer.py +++ b/api/tests/test_normalizer.py @@ -175,6 +175,13 @@ def test_money(): == "The plant cost two hundred thousand dollars and eighty cents." ) + assert ( + normalize_text( + "Your shopping spree cost $674.03!", normalization_options=NormalizationOptions() + ) + == "Your shopping spree cost six hundred and seventy-four dollars and three cents!" + ) + assert ( normalize_text( "€30.2 is in euros", normalization_options=NormalizationOptions() @@ -315,3 +322,12 @@ def test_non_url_text(): normalize_text("It costs $50.", normalization_options=NormalizationOptions()) == "It costs fifty dollars." ) + +def test_remaining_symbol(): + """Test that remaining symbols are replaced""" + assert ( + normalize_text( + "I love buying products @ good store here & @ other store", normalization_options=NormalizationOptions() + ) + == "I love buying products at good store here and at other store" + ) diff --git a/api/tests/test_text_processor.py b/api/tests/test_text_processor.py index 95c0259..7495d1d 100644 --- a/api/tests/test_text_processor.py +++ b/api/tests/test_text_processor.py @@ -194,4 +194,31 @@ async def test_smart_split_with_pause(): # Third chunk: text assert chunks[2][2] is None # No pause assert "How are you?" in chunks[2][0] + assert len(chunks[2][1]) > 0 + +@pytest.mark.asyncio +async def test_smart_split_with_two_pause(): + """Test smart splitting with two pause tags.""" + text = "[pause:0.5s][pause:1.67s]0.5" + + chunks = [] + async for chunk_text, chunk_tokens, pause_duration in smart_split(text): + chunks.append((chunk_text, chunk_tokens, pause_duration)) + + # Should have 3 chunks: pause, pause, text + assert len(chunks) == 3 + + # First chunk: pause + assert chunks[0][2] == 0.5 # 0.5 second pause + assert chunks[0][0] == "" # Empty text + assert len(chunks[0][1]) == 0 + + # Second chunk: pause + assert chunks[1][2] == 1.67 # 1.67 second pause + assert chunks[1][0] == "" # Empty text + assert len(chunks[1][1]) == 0 # No tokens + + # Third chunk: text + assert chunks[2][2] is None # No pause + assert "zero point five" in chunks[2][0] assert len(chunks[2][1]) > 0 \ No newline at end of file From cd82dd07355fca75a9cb37e75da624a0d3b3d0ae Mon Sep 17 00:00:00 2001 From: Fireblade2534 Date: Mon, 16 Jun 2025 16:39:30 +0000 Subject: [PATCH 11/14] Added a volume multiplier as a request parameter --- api/src/core/config.py | 2 +- api/src/inference/model_manager.py | 4 ++-- api/src/routers/development.py | 1 + api/src/routers/openai_compatible.py | 2 ++ api/src/services/tts_service.py | 11 ++++++++++- api/src/structures/schemas.py | 8 ++++++++ 6 files changed, 24 insertions(+), 4 deletions(-) diff --git a/api/src/core/config.py b/api/src/core/config.py index 3bc825c..87edce0 100644 --- a/api/src/core/config.py +++ b/api/src/core/config.py @@ -31,7 +31,7 @@ class Settings(BaseSettings): # Audio Settings sample_rate: int = 24000 - volume_multiplier: float = 1.0 + default_volume_multiplier: float = 1.0 # Text Processing Settings target_min_tokens: int = 175 # Target minimum tokens per chunk target_max_tokens: int = 250 # Target maximum tokens per chunk diff --git a/api/src/inference/model_manager.py b/api/src/inference/model_manager.py index 0b4cd81..eb817ec 100644 --- a/api/src/inference/model_manager.py +++ b/api/src/inference/model_manager.py @@ -141,8 +141,8 @@ Model files not found! You need to download the Kokoro V1 model: try: async for chunk in self._backend.generate(*args, **kwargs): - if settings.volume_multiplier != 1.0: - chunk.audio *= settings.volume_multiplier + if settings.default_volume_multiplier != 1.0: + chunk.audio *= settings.default_volume_multiplier yield chunk except Exception as e: raise RuntimeError(f"Generation failed: {e}") diff --git a/api/src/routers/development.py b/api/src/routers/development.py index d78aa3c..8c8ed7e 100644 --- a/api/src/routers/development.py +++ b/api/src/routers/development.py @@ -319,6 +319,7 @@ async def create_captioned_speech( writer=writer, speed=request.speed, return_timestamps=request.return_timestamps, + volume_multiplier=request.volume_multiplier, normalization_options=request.normalization_options, lang_code=request.lang_code, ) diff --git a/api/src/routers/openai_compatible.py b/api/src/routers/openai_compatible.py index 4819bc5..c325221 100644 --- a/api/src/routers/openai_compatible.py +++ b/api/src/routers/openai_compatible.py @@ -152,6 +152,7 @@ async def stream_audio_chunks( speed=request.speed, output_format=request.response_format, lang_code=request.lang_code, + volume_multiplier=request.volume_multiplier, normalization_options=request.normalization_options, return_timestamps=unique_properties["return_timestamps"], ): @@ -300,6 +301,7 @@ async def create_speech( voice=voice_name, writer=writer, speed=request.speed, + volume_multiplier=request.volume_multiplier, normalization_options=request.normalization_options, lang_code=request.lang_code, ) diff --git a/api/src/services/tts_service.py b/api/src/services/tts_service.py index 0a69b85..dca0d02 100644 --- a/api/src/services/tts_service.py +++ b/api/src/services/tts_service.py @@ -55,6 +55,7 @@ class TTSService: output_format: Optional[str] = None, is_first: bool = False, is_last: bool = False, + volume_multiplier: Optional[float] = 1.0, normalizer: Optional[AudioNormalizer] = None, lang_code: Optional[str] = None, return_timestamps: Optional[bool] = False, @@ -100,6 +101,7 @@ class TTSService: lang_code=lang_code, return_timestamps=return_timestamps, ): + chunk_data.audio*=volume_multiplier # For streaming, convert to bytes if output_format: try: @@ -132,7 +134,7 @@ class TTSService: speed=speed, return_timestamps=return_timestamps, ) - + if chunk_data.audio is None: logger.error("Model generated None for audio chunk") return @@ -141,6 +143,8 @@ class TTSService: logger.error("Model generated empty audio chunk") return + chunk_data.audio*=volume_multiplier + # For streaming, convert to bytes if output_format: try: @@ -259,6 +263,7 @@ class TTSService: speed: float = 1.0, output_format: str = "wav", lang_code: Optional[str] = None, + volume_multiplier: Optional[float] = 1.0, normalization_options: Optional[NormalizationOptions] = NormalizationOptions(), return_timestamps: Optional[bool] = False, ) -> AsyncGenerator[AudioChunk, None]: @@ -298,6 +303,7 @@ class TTSService: output_format, is_first=(chunk_index == 0), is_last=False, # We'll update the last chunk later + volume_multiplier=volume_multiplier, normalizer=stream_normalizer, lang_code=pipeline_lang_code, # Pass lang_code return_timestamps=return_timestamps, @@ -337,6 +343,7 @@ class TTSService: output_format, is_first=False, is_last=True, # Signal this is the last chunk + volume_multiplier=volume_multiplier, normalizer=stream_normalizer, lang_code=pipeline_lang_code, # Pass lang_code ): @@ -356,6 +363,7 @@ class TTSService: writer: StreamingAudioWriter, speed: float = 1.0, return_timestamps: bool = False, + volume_multiplier: Optional[float] = 1.0, normalization_options: Optional[NormalizationOptions] = NormalizationOptions(), lang_code: Optional[str] = None, ) -> AudioChunk: @@ -368,6 +376,7 @@ class TTSService: voice, writer, speed=speed, + volume_multiplier=volume_multiplier, normalization_options=normalization_options, return_timestamps=return_timestamps, lang_code=lang_code, diff --git a/api/src/structures/schemas.py b/api/src/structures/schemas.py index 260224c..f8273cc 100644 --- a/api/src/structures/schemas.py +++ b/api/src/structures/schemas.py @@ -108,6 +108,10 @@ class OpenAISpeechRequest(BaseModel): default=None, description="Optional language code to use for text processing. If not provided, will use first letter of voice name.", ) + volume_multiplier: Optional[float] = Field( + default = 1.0, + description="A volume multiplier to multiply the output audio by." + ) normalization_options: Optional[NormalizationOptions] = Field( default=NormalizationOptions(), description="Options for the normalization system", @@ -152,6 +156,10 @@ class CaptionedSpeechRequest(BaseModel): default=None, description="Optional language code to use for text processing. If not provided, will use first letter of voice name.", ) + volume_multiplier: Optional[float] = Field( + default = 1.0, + description="A volume multiplier to multiply the output audio by." + ) normalization_options: Optional[NormalizationOptions] = Field( default=NormalizationOptions(), description="Options for the normalization system", From ac491a9b1820f32b3cf5a3b6fffc30972b1e0b44 Mon Sep 17 00:00:00 2001 From: Fireblade2534 Date: Wed, 18 Jun 2025 22:02:33 +0000 Subject: [PATCH 12/14] Release --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 0d91a54..72f9fa8 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.3.0 +0.2.4 \ No newline at end of file From 2a6d2ae483eeec0db7fe29b3796111cc4df6cdb4 Mon Sep 17 00:00:00 2001 From: Fireblade2534 Date: Thu, 26 Jun 2025 00:16:06 +0000 Subject: [PATCH 13/14] Fix custom phenomes and make them more robust --- .../services/text_processing/normalizer.py | 2 +- .../text_processing/text_processor.py | 27 ++++------ api/tests/test_text_processor.py | 49 +++++++++++-------- 3 files changed, 39 insertions(+), 39 deletions(-) diff --git a/api/src/services/text_processing/normalizer.py b/api/src/services/text_processing/normalizer.py index 7908318..1b1c9f7 100644 --- a/api/src/services/text_processing/normalizer.py +++ b/api/src/services/text_processing/normalizer.py @@ -497,4 +497,4 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> st text = re.sub(r"\s{2,}", " ", text) - return text.strip() + return text diff --git a/api/src/services/text_processing/text_processor.py b/api/src/services/text_processing/text_processor.py index 39b0d6c..483618f 100644 --- a/api/src/services/text_processing/text_processor.py +++ b/api/src/services/text_processing/text_processor.py @@ -15,7 +15,7 @@ from .vocabulary import tokenize # Pre-compiled regex patterns for performance # Updated regex to be more strict and avoid matching isolated brackets # Only matches complete patterns like [word](/ipa/) and prevents catastrophic backtracking -CUSTOM_PHONEMES = re.compile(r"(\[[^\[\]]*?\])(\(\/[^\/\(\)]*?\/\))") +CUSTOM_PHONEMES = re.compile(r"(\[[^\[\]]*?\]\(\/[^\/\(\)]*?\/\))") # Pattern to find pause tags like [pause:0.5s] PAUSE_TAG_PATTERN = re.compile(r"\[pause:(\d+(?:\.\d+)?)s\]", re.IGNORECASE) @@ -100,7 +100,7 @@ def process_text(text: str, language: str = "a") -> List[int]: def get_sentence_info( - text: str, custom_phenomes_list: Dict[str, str], lang_code: str = "a" + text: str, lang_code: str = "a" ) -> List[Tuple[str, List[int], int]]: """Process all sentences and return info""" # Detect Chinese text @@ -110,18 +110,10 @@ def get_sentence_info( sentences = re.split(r"([,。!?;])+", text) else: sentences = re.split(r"([.!?;:])(?=\s|$)", text) - phoneme_length, min_value = len(custom_phenomes_list), 0 results = [] for i in range(0, len(sentences), 2): sentence = sentences[i].strip() - for replaced in range(min_value, phoneme_length): - current_id = f"" - if current_id in sentence: - sentence = sentence.replace( - current_id, custom_phenomes_list.pop(current_id) - ) - min_value += 1 punct = sentences[i + 1] if i + 1 < len(sentences) else "" if not sentence: continue @@ -173,24 +165,23 @@ async def smart_split( # Strip leading and trailing spaces to prevent pause tag splitting artifacts text_part_raw = text_part_raw.strip() - # Apply the original smart_split logic to this text part - custom_phoneme_list = {} - # Normalize text (original logic) processed_text = text_part_raw if settings.advanced_text_normalization and normalization_options.normalize: if lang_code in ["a", "b", "en-us", "en-gb"]: - processed_text = CUSTOM_PHONEMES.sub( - lambda s: handle_custom_phonemes(s, custom_phoneme_list), processed_text - ) - processed_text = normalize_text(processed_text, normalization_options) + processed_text = CUSTOM_PHONEMES.split(processed_text) + for index in range(0, len(processed_text), 2): + processed_text[index] = normalize_text(processed_text[index], normalization_options) + + + processed_text = "".join(processed_text).strip() else: logger.info( "Skipping text normalization as it is only supported for english" ) # Process all sentences (original logic) - sentences = get_sentence_info(processed_text, custom_phoneme_list, lang_code=lang_code) + sentences = get_sentence_info(processed_text, lang_code=lang_code) current_chunk = [] current_tokens = [] diff --git a/api/tests/test_text_processor.py b/api/tests/test_text_processor.py index 7495d1d..3fc8a87 100644 --- a/api/tests/test_text_processor.py +++ b/api/tests/test_text_processor.py @@ -34,7 +34,7 @@ def test_process_text_chunk_phonemes(): def test_get_sentence_info(): """Test sentence splitting and info extraction.""" text = "This is sentence one. This is sentence two! What about three?" - results = get_sentence_info(text, {}) + results = get_sentence_info(text) assert len(results) == 3 for sentence, tokens, count in results: @@ -44,24 +44,6 @@ def test_get_sentence_info(): assert count == len(tokens) assert count > 0 - -def test_get_sentence_info_phenomoes(): - """Test sentence splitting and info extraction.""" - text = ( - "This is sentence one. This is two! What about three?" - ) - results = get_sentence_info(text, {"": r"sˈɛntᵊns"}) - - assert len(results) == 3 - assert "sˈɛntᵊns" in results[1][0] - for sentence, tokens, count in results: - assert isinstance(sentence, str) - assert isinstance(tokens, list) - assert isinstance(count, int) - assert count == len(tokens) - assert count > 0 - - @pytest.mark.asyncio async def test_smart_split_short_text(): """Test smart splitting with text under max tokens.""" @@ -74,6 +56,33 @@ async def test_smart_split_short_text(): assert isinstance(chunks[0][0], str) assert isinstance(chunks[0][1], list) +@pytest.mark.asyncio +async def test_smart_custom_phenomes(): + """Test smart splitting with text under max tokens.""" + text = "This is a short test sentence. [Kokoro](/kˈOkəɹO/) has a feature called custom phenomes. This is made possible by [Misaki](/misˈɑki/), the custom phenomizer that [Kokoro](/kˈOkəɹO/) version 1.0 uses" + chunks = [] + async for chunk_text, chunk_tokens, pause_duration in smart_split(text): + chunks.append((chunk_text, chunk_tokens, pause_duration)) + + # Should have 1 chunks: text + assert len(chunks) == 1 + + # First chunk: text + assert chunks[0][2] is None # No pause + assert "This is a short test sentence. [Kokoro](/kˈOkəɹO/) has a feature called custom phenomes. This is made possible by [Misaki](/misˈɑki/), the custom phenomizer that [Kokoro](/kˈOkəɹO/) version one uses" in chunks[0][0] + assert len(chunks[0][1]) > 0 + +@pytest.mark.asyncio +async def test_smart_split_only_phenomes(): + """Test input that is entirely made of phenome annotations.""" + text = "[Kokoro](/kˈOkəɹO/) [Misaki 1.2](/misˈɑki/) [Test](/tɛst/)" + chunks = [] + async for chunk_text, chunk_tokens, pause_duration in smart_split(text, max_tokens=10): + chunks.append((chunk_text, chunk_tokens, pause_duration)) + + assert len(chunks) == 1 + assert "[Kokoro](/kˈOkəɹO/) [Misaki 1.2](/misˈɑki/) [Test](/tɛst/)" in chunks[0][0] + @pytest.mark.asyncio async def test_smart_split_long_text(): @@ -116,7 +125,7 @@ def test_process_text_chunk_chinese_phonemes(): def test_get_sentence_info_chinese(): """Test Chinese sentence splitting and info extraction.""" text = "这是一个句子。这是第二个句子!第三个问题?" - results = get_sentence_info(text, {}, lang_code="z") + results = get_sentence_info(text, lang_code="z") assert len(results) == 3 for sentence, tokens, count in results: From 8a55cd5bf591708e6387655c615fe3a62f796f31 Mon Sep 17 00:00:00 2001 From: Miggi Date: Fri, 27 Jun 2025 14:30:47 +0200 Subject: [PATCH 14/14] Update torch to 2.7.1 & Cuda 12.8.1 in Docker --- docker/gpu/Dockerfile | 2 +- pyproject.toml | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docker/gpu/Dockerfile b/docker/gpu/Dockerfile index 44c1ba7..5bc7b2e 100644 --- a/docker/gpu/Dockerfile +++ b/docker/gpu/Dockerfile @@ -1,4 +1,4 @@ -FROM --platform=$BUILDPLATFORM nvidia/cuda:12.8.0-cudnn-runtime-ubuntu24.04 +FROM --platform=$BUILDPLATFORM nvidia/cuda:12.8.1-cudnn-runtime-ubuntu24.04 # Set non-interactive frontend ENV DEBIAN_FRONTEND=noninteractive diff --git a/pyproject.toml b/pyproject.toml index 5d082f7..ffbefe5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,10 +44,10 @@ dependencies = [ [project.optional-dependencies] gpu = [ - "torch==2.6.0+cu124", + "torch==2.7.1+cu128", ] cpu = [ - "torch==2.6.0", + "torch==2.7.1", ] test = [ "pytest==8.3.5", @@ -79,7 +79,7 @@ explicit = true [[tool.uv.index]] name = "pytorch-cuda" -url = "https://download.pytorch.org/whl/cu124" +url = "https://download.pytorch.org/whl/cu128" explicit = true [build-system]