Mirror of https://github.com/remsky/Kokoro-FastAPI.git
Synced 2025-08-05 16:48:53 +00:00

Merge remote-tracking branch 'origin/master' into release

This commit is contained in commit dd8aa26813.

22 changed files with 833 additions and 252 deletions

README.md (10 lines changed)
@@ -13,7 +13,7 @@
 [](https://huggingface.co/hexgrad/Kokoro-82M/commit/9901c2b79161b6e898b7ea857ae5298f47b8b0d6)

 Dockerized FastAPI wrapper for [Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech model
-- Multi-language support (English, Japanese, Korean, Chinese, _Vietnamese soon_)
+- Multi-language support (English, Japanese, Chinese, _Vietnamese soon_)
 - OpenAI-compatible Speech endpoint, NVIDIA GPU accelerated or CPU inference with PyTorch
 - ONNX support coming soon, see v0.1.5 and earlier for legacy ONNX support in the interim
 - Debug endpoints for monitoring system stats, integrated web UI on localhost:8880/web

@@ -34,10 +34,12 @@ Pre built images are available to run, with arm/multi-arch support, and baked in
 Refer to the core/config.py file for a full list of variables which can be managed via the environment

 ```bash
-# the `latest` tag can be used, though it may have some unexpected bonus features which impact stability. Named versions should be pinned for your regular usage. Feedback/testing is always welcome
+# the `latest` tag can be used, though it may have some unexpected bonus features which impact stability.
+Named versions should be pinned for your regular usage.
+Feedback/testing is always welcome

-docker run -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-cpu:v0.3.0 # CPU, or:
-docker run --gpus all -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-gpu:v0.3.0 #NVIDIA GPU
+docker run -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-cpu:latest # CPU, or:
+docker run --gpus all -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-gpu:latest #NVIDIA GPU
 ```
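For orientation, a minimal request against the OpenAI-compatible speech endpoint mentioned in the README bullets might look like the sketch below. This is a hypothetical usage example, not part of the commit: the `/v1/audio/speech` path, payload fields, and voice name follow the usual OpenAI-style convention for this project and should be treated as assumptions.

```python
# Hypothetical usage sketch (not part of this commit).
# Assumes the container started above is listening on localhost:8880 and
# exposes an OpenAI-style /v1/audio/speech endpoint.
import requests

response = requests.post(
    "http://localhost:8880/v1/audio/speech",
    json={
        "model": "kokoro",           # assumed model name
        "input": "Hello from Kokoro-FastAPI!",
        "voice": "af_heart",         # assumed voice id
        "response_format": "mp3",
    },
)
response.raise_for_status()
with open("output.mp3", "wb") as f:
    f.write(response.content)
```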
VERSION (2 lines changed)
@@ -1 +1 @@
-0.3.0
+0.2.4
@@ -31,6 +31,7 @@ class Settings(BaseSettings):

     # Audio Settings
     sample_rate: int = 24000
+    default_volume_multiplier: float = 1.0

     # Text Processing Settings
     target_min_tokens: int = 175 # Target minimum tokens per chunk
     target_max_tokens: int = 250 # Target maximum tokens per chunk
@@ -300,7 +300,7 @@ async def get_web_file_path(filename: str) -> str:
     )

     # Construct web directory path relative to project root
-    web_dir = os.path.join("/app", settings.web_player_path)
+    web_dir = os.path.join(root_dir, settings.web_player_path)

     # Search in web directory
     search_paths = [web_dir]
@@ -141,6 +141,8 @@ Model files not found! You need to download the Kokoro V1 model:

 try:
     async for chunk in self._backend.generate(*args, **kwargs):
+        if settings.default_volume_multiplier != 1.0:
+            chunk.audio *= settings.default_volume_multiplier
         yield chunk
 except Exception as e:
     raise RuntimeError(f"Generation failed: {e}")
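The new `default_volume_multiplier` setting is applied as a plain element-wise scale on each generated chunk. A minimal sketch of the effect, assuming the chunk audio is a float32 NumPy array (an assumption about the backend's output format, not something this diff states):

```python
import numpy as np

# Hypothetical illustration of the volume scaling added above.
default_volume_multiplier = 0.5                               # e.g. from Settings
chunk_audio = np.array([0.1, -0.2, 0.8], dtype=np.float32)    # fake audio samples

if default_volume_multiplier != 1.0:
    chunk_audio *= default_volume_multiplier                  # same operation as the diff

print(chunk_audio)  # [ 0.05 -0.1   0.4 ] -> quieter output; 1.0 leaves the chunk untouched
```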
@@ -104,7 +104,7 @@ async def generate_from_phonemes(

 if chunk_audio is not None:
     # Normalize audio before writing
-    normalized_audio = await normalizer.normalize(chunk_audio)
+    normalized_audio = normalizer.normalize(chunk_audio)
     # Write chunk and yield bytes
     chunk_bytes = writer.write_chunk(normalized_audio)
     if chunk_bytes:

@@ -114,6 +114,7 @@ async def generate_from_phonemes(
     final_bytes = writer.write_chunk(finalize=True)
     if final_bytes:
         yield final_bytes
+    writer.close()
 else:
     raise ValueError("Failed to generate audio data")
@@ -223,10 +224,13 @@ async def create_captioned_speech(
 ).decode("utf-8")

 # Add any chunks that may be in the acumulator into the return word_timestamps
-chunk_data.word_timestamps = (
-    timestamp_acumulator + chunk_data.word_timestamps
-)
-timestamp_acumulator = []
+if chunk_data.word_timestamps is not None:
+    chunk_data.word_timestamps = (
+        timestamp_acumulator + chunk_data.word_timestamps
+    )
+    timestamp_acumulator = []
+else:
+    chunk_data.word_timestamps = []

 yield CaptionedSpeechResponse(
     audio=base64_chunk,

@@ -271,7 +275,7 @@ async def create_captioned_speech(
 )

 # Add any chunks that may be in the acumulator into the return word_timestamps
-if chunk_data.word_timestamps != None:
+if chunk_data.word_timestamps is not None:
     chunk_data.word_timestamps = (
         timestamp_acumulator + chunk_data.word_timestamps
     )

@@ -315,6 +319,7 @@ async def create_captioned_speech(
 writer=writer,
 speed=request.speed,
 return_timestamps=request.return_timestamps,
+volume_multiplier=request.volume_multiplier,
 normalization_options=request.normalization_options,
 lang_code=request.lang_code,
 )
@@ -152,6 +152,7 @@ async def stream_audio_chunks(
 speed=request.speed,
 output_format=request.response_format,
 lang_code=request.lang_code,
+volume_multiplier=request.volume_multiplier,
 normalization_options=request.normalization_options,
 return_timestamps=unique_properties["return_timestamps"],
 ):

@@ -300,6 +301,7 @@ async def create_speech(
 voice=voice_name,
 writer=writer,
 speed=request.speed,
+volume_multiplier=request.volume_multiplier,
 normalization_options=request.normalization_options,
 lang_code=request.lang_code,
 )
@@ -80,12 +80,12 @@ class AudioNormalizer:
 non_silent_index_start, non_silent_index_end = None, None

 for X in range(0, len(audio_data)):
-    if audio_data[X] > amplitude_threshold:
+    if abs(audio_data[X]) > amplitude_threshold:
         non_silent_index_start = X
         break

 for X in range(len(audio_data) - 1, -1, -1):
-    if audio_data[X] > amplitude_threshold:
+    if abs(audio_data[X]) > amplitude_threshold:
         non_silent_index_end = X
         break
@@ -32,19 +32,29 @@ class StreamingAudioWriter:
 if self.format in ["wav", "flac", "mp3", "pcm", "aac", "opus"]:
     if self.format != "pcm":
         self.output_buffer = BytesIO()
+        container_options = {}
+        # Try disabling Xing VBR header for MP3 to fix iOS timeline reading issues
+        if self.format == 'mp3':
+            # Disable Xing VBR header
+            container_options = {'write_xing': '0'}
+            logger.debug("Disabling Xing VBR header for MP3 encoding.")
+
         self.container = av.open(
             self.output_buffer,
             mode="w",
             format=self.format if self.format != "aac" else "adts",
+            options=container_options  # Pass options here
         )
         self.stream = self.container.add_stream(
             codec_map[self.format],
-            sample_rate=self.sample_rate,
+            rate=self.sample_rate,
             layout="mono" if self.channels == 1 else "stereo",
         )
-        self.stream.bit_rate = 128000
+        # Set bit_rate only for codecs where it's applicable and useful
+        if self.format in ['mp3', 'aac', 'opus']:
+            self.stream.bit_rate = 128000
 else:
-    raise ValueError(f"Unsupported format: {format}")
+    raise ValueError(f"Unsupported format: {self.format}")  # Use self.format here

 def close(self):
     if hasattr(self, "container"):

@@ -65,12 +75,18 @@ class StreamingAudioWriter:

 if finalize:
     if self.format != "pcm":
+        # Flush stream encoder
         packets = self.stream.encode(None)
         for packet in packets:
             self.container.mux(packet)

+        # Closing the container handles writing the trailer and finalizing the file.
+        # No explicit flush method is available or needed here.
+        logger.debug("Muxed final packets.")

+    # Get the final bytes from the buffer *before* closing it
     data = self.output_buffer.getvalue()
-    self.close()
+    self.close()  # Close container and buffer
     return data

 if audio_data is None or len(audio_data) == 0:
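The Xing/VBR-header workaround above can be exercised in isolation with PyAV. This is a minimal stand-alone sketch, not the project's writer: it assumes PyAV accepts `write_xing` as an MP3 muxer option (the same assumption the diff makes) and that `libmp3lame` is the codec the writer's `codec_map` uses for MP3, which this diff does not show.

```python
# Hypothetical sketch of the MP3 container setup shown above (assumptions noted in the lead-in).
from io import BytesIO

import av
import numpy as np

buffer = BytesIO()
container = av.open(
    buffer,
    mode="w",
    format="mp3",
    options={"write_xing": "0"},  # disable the Xing VBR header, as in the diff
)
stream = container.add_stream("libmp3lame", rate=24000, layout="mono")
stream.bit_rate = 128000

# One second of silence as packed 16-bit PCM, shaped (channels, samples) for PyAV.
samples = np.zeros((1, 24000), dtype=np.int16)
frame = av.AudioFrame.from_ndarray(samples, format="s16", layout="mono")
frame.sample_rate = 24000

for packet in stream.encode(frame):
    container.mux(packet)
for packet in stream.encode(None):  # flush the encoder
    container.mux(packet)
container.close()

mp3_bytes = buffer.getvalue()
print(len(mp3_bytes), "bytes of MP3 written without a Xing header")
```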
@@ -4,12 +4,14 @@ Handles various text formats including URLs, emails, numbers, money, and special
 Converts them into a format suitable for text-to-speech processing.
 """

+import math
 import re
 from functools import lru_cache
+from typing import List, Optional, Union

 import inflect
 from numpy import number
-from text_to_num import text2num
+# from text_to_num import text2num
 from torch import mul

 from ...structures.schemas import NormalizationOptions

@@ -132,6 +134,24 @@ VALID_UNITS = {
     "px": "pixel", # CSS units
 }

+SYMBOL_REPLACEMENTS = {
+    '~': ' ',
+    '@': ' at ',
+    '#': ' number ',
+    '$': ' dollar ',
+    '%': ' percent ',
+    '^': ' ',
+    '&': ' and ',
+    '*': ' ',
+    '_': ' ',
+    '|': ' ',
+    '\\': ' ',
+    '/': ' slash ',
+    '=': ' equals ',
+    '+': ' plus ',
+}
+
+MONEY_UNITS = {"$": ("dollar", "cent"), "£": ("pound", "pence"), "€": ("euro", "cent")}
+
 # Pre-compiled regex patterns for performance
 EMAIL_PATTERN = re.compile(

@@ -152,37 +172,24 @@ UNIT_PATTERN = re.compile(
 )

 TIME_PATTERN = re.compile(
-    r"([0-9]{2} ?: ?[0-9]{2}( ?: ?[0-9]{2})?)( ?(pm|am)\b)?", re.IGNORECASE
+    r"([0-9]{1,2} ?: ?[0-9]{2}( ?: ?[0-9]{2})?)( ?(pm|am)\b)?", re.IGNORECASE
+)
+
+MONEY_PATTERN = re.compile(
+    r"(-?)(["
+    + "".join(MONEY_UNITS.keys())
+    + r"])(\d+(?:\.\d+)?)((?: hundred| thousand| (?:[bm]|tr|quadr)illion|k|m|b|t)*)\b",
+    re.IGNORECASE,
+)
+
+NUMBER_PATTERN = re.compile(
+    r"(-?)(\d+(?:\.\d+)?)((?: hundred| thousand| (?:[bm]|tr|quadr)illion|k|m|b)*)\b",
+    re.IGNORECASE,
 )

 INFLECT_ENGINE = inflect.engine()

-
-def split_num(num: re.Match[str]) -> str:
-    """Handle number splitting for various formats"""
-    num = num.group()
-    if "." in num:
-        return num
-    elif ":" in num:
-        h, m = [int(n) for n in num.split(":")]
-        if m == 0:
-            return f"{h} o'clock"
-        elif m < 10:
-            return f"{h} oh {m}"
-        return f"{h} {m}"
-    year = int(num[:4])
-    if year < 1100 or year % 1000 < 10:
-        return num
-    left, right = num[:2], int(num[2:4])
-    s = "s" if num.endswith("s") else ""
-    if 100 <= year % 1000 <= 999:
-        if right == 0:
-            return f"{left} hundred{s}"
-        elif right < 10:
-            return f"{left} oh {right}{s}"
-        return f"{left} {right}{s}"

 def handle_units(u: re.Match[str]) -> str:
     """Converts units to their full form"""
     unit_string = u.group(6).strip()

@@ -208,14 +215,61 @@ def conditional_int(number: float, threshold: float = 0.00001):
     return number


+def translate_multiplier(multiplier: str) -> str:
+    """Translate multiplier abrevations to words"""
+
+    multiplier_translation = {
+        "k": "thousand",
+        "m": "million",
+        "b": "billion",
+        "t": "trillion",
+    }
+    if multiplier.lower() in multiplier_translation:
+        return multiplier_translation[multiplier.lower()]
+    return multiplier.strip()
+
+
+def split_four_digit(number: float):
+    part1 = str(conditional_int(number))[:2]
+    part2 = str(conditional_int(number))[2:]
+    return f"{INFLECT_ENGINE.number_to_words(part1)} {INFLECT_ENGINE.number_to_words(part2)}"
+
+
+def handle_numbers(n: re.Match[str]) -> str:
+    number = n.group(2)
+
+    try:
+        number = float(number)
+    except:
+        return n.group()
+
+    if n.group(1) == "-":
+        number *= -1
+
+    multiplier = translate_multiplier(n.group(3))
+
+    number = conditional_int(number)
+    if multiplier != "":
+        multiplier = f" {multiplier}"
+    else:
+        if (
+            number % 1 == 0
+            and len(str(number)) == 4
+            and number > 1500
+            and number % 1000 > 9
+        ):
+            return split_four_digit(number)
+
+    return f"{INFLECT_ENGINE.number_to_words(number)}{multiplier}"
+
+
 def handle_money(m: re.Match[str]) -> str:
     """Convert money expressions to spoken form"""

-    bill = "dollar" if m.group(2) == "$" else "pound"
-    coin = "cent" if m.group(2) == "$" else "pence"
+    bill, coin = MONEY_UNITS[m.group(2)]
     number = m.group(3)

-    multiplier = m.group(4)
     try:
         number = float(number)
     except:

@@ -224,12 +278,17 @@ def handle_money(m: re.Match[str]) -> str:
     if m.group(1) == "-":
         number *= -1

+    multiplier = translate_multiplier(m.group(4))
+
+    if multiplier != "":
+        multiplier = f" {multiplier}"
+
     if number % 1 == 0 or multiplier != "":
         text_number = f"{INFLECT_ENGINE.number_to_words(conditional_int(number))}{multiplier} {INFLECT_ENGINE.plural(bill, count=number)}"
     else:
         sub_number = int(str(number).split(".")[-1].ljust(2, "0"))

-        text_number = f"{INFLECT_ENGINE.number_to_words(int(round(number)))} {INFLECT_ENGINE.plural(bill, count=number)} and {INFLECT_ENGINE.number_to_words(sub_number)} {INFLECT_ENGINE.plural(coin, count=sub_number)}"
+        text_number = f"{INFLECT_ENGINE.number_to_words(int(math.floor(number)))} {INFLECT_ENGINE.plural(bill, count=number)} and {INFLECT_ENGINE.number_to_words(sub_number)} {INFLECT_ENGINE.plural(coin, count=sub_number)}"

     return text_number
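To see what the new money pattern captures before `handle_money` converts it, here is a small self-contained check that reuses `MONEY_UNITS` and `MONEY_PATTERN` exactly as added in this commit; the sample strings are illustrative only.

```python
# Quick check of the new MONEY_PATTERN, copied from the diff above.
import re

MONEY_UNITS = {"$": ("dollar", "cent"), "£": ("pound", "pence"), "€": ("euro", "cent")}
MONEY_PATTERN = re.compile(
    r"(-?)(["
    + "".join(MONEY_UNITS.keys())
    + r"])(\d+(?:\.\d+)?)((?: hundred| thousand| (?:[bm]|tr|quadr)illion|k|m|b|t)*)\b",
    re.IGNORECASE,
)

for sample in ["$25.05k", "-$6.9 million", "£50", "€1.5 billion"]:
    match = MONEY_PATTERN.search(sample)
    print(sample, "->", match.groups() if match else None)
# "$25.05k"       -> ('', '$', '25.05', 'k')
# "-$6.9 million" -> ('-', '$', '6.9', ' million')
```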
@@ -320,19 +379,36 @@ def handle_phone_number(p: re.Match[str]) -> str:
 def handle_time(t: re.Match[str]) -> str:
     t = t.groups()

-    numbers = " ".join(
-        [INFLECT_ENGINE.number_to_words(X.strip()) for X in t[0].split(":")]
-    )
+    time_parts = t[0].split(":")
+
+    numbers = []
+    numbers.append(INFLECT_ENGINE.number_to_words(time_parts[0].strip()))
+
+    minute_number = INFLECT_ENGINE.number_to_words(time_parts[1].strip())
+    if int(time_parts[1]) < 10:
+        if int(time_parts[1]) != 0:
+            numbers.append(f"oh {minute_number}")
+    else:
+        numbers.append(minute_number)

     half = ""
-    if t[2] is not None:
-        half = t[2].strip()
+    if len(time_parts) > 2:
+        seconds_number = INFLECT_ENGINE.number_to_words(time_parts[2].strip())
+        second_word = INFLECT_ENGINE.plural("second", int(time_parts[2].strip()))
+        numbers.append(f"and {seconds_number} {second_word}")
+    else:
+        if t[2] is not None:
+            half = " " + t[2].strip()
+        else:
+            if int(time_parts[1]) == 0:
+                numbers.append("o'clock")

-    return numbers + half
+    return " ".join(numbers) + half


 def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
     """Normalize text for TTS processing"""

     # Handle email addresses first if enabled
     if normalization_options.email_normalization:
         text = EMAIL_PATTERN.sub(handle_email, text)

@@ -357,7 +433,7 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
     text,
 )

-# Replace quotes and brackets
+# Replace quotes and brackets (additional cleanup)
 text = text.replace(chr(8216), "'").replace(chr(8217), "'")
 text = text.replace("«", chr(8220)).replace("»", chr(8221))
 text = text.replace(chr(8220), '"').replace(chr(8221), '"')

@@ -366,7 +442,7 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
 for a, b in zip("、。!,:;?–", ",.!,:;?-"):
     text = text.replace(a, b + " ")

-# Handle simple time in the format of HH:MM:SS
+# Handle simple time in the format of HH:MM:SS (am/pm)
 text = TIME_PATTERN.sub(
     handle_time,
     text,

@@ -377,6 +453,11 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
 text = re.sub(r" +", " ", text)
 text = re.sub(r"(?<=\n) +(?=\n)", "", text)

+# Handle special characters that might cause audio artifacts first
+# Replace newlines with spaces (or pauses if needed)
+text = text.replace('\n', ' ')
+text = text.replace('\r', ' ')
+
 # Handle titles and abbreviations
 text = re.sub(r"\bD[Rr]\.(?= [A-Z])", "Doctor", text)
 text = re.sub(r"\b(?:Mr\.|MR\.(?= [A-Z]))", "Mister", text)

@@ -387,21 +468,23 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
 # Handle common words
 text = re.sub(r"(?i)\b(y)eah?\b", r"\1e'a", text)

-# Handle numbers and money
+# Handle numbers and money BEFORE replacing special characters
 text = re.sub(r"(?<=\d),(?=\d)", "", text)

-text = re.sub(
-    r"(?i)(-?)([$£])(\d+(?:\.\d+)?)((?: hundred| thousand| (?:[bm]|tr|quadr)illion)*)\b",
+text = MONEY_PATTERN.sub(
     handle_money,
     text,
 )

-text = re.sub(
-    r"\d*\.\d+|\b\d{4}s?\b|(?<!:)\b(?:[1-9]|1[0-2]):[0-5]\d\b(?!:)", split_num, text
-)
+text = NUMBER_PATTERN.sub(handle_numbers, text)

 text = re.sub(r"\d*\.\d+", handle_decimal, text)

+# Handle other problematic symbols AFTER money/number processing
+if normalization_options.replace_remaining_symbols:
+    for symbol, replacement in SYMBOL_REPLACEMENTS.items():
+        text = text.replace(symbol, replacement)
+
 # Handle various formatting
 text = re.sub(r"(?<=\d)-(?=\d)", " to ", text)
 text = re.sub(r"(?<=\d)S", " S", text)

@@ -412,4 +495,6 @@ def normalize_text(text: str, normalization_options: NormalizationOptions) -> str:
 )
 text = re.sub(r"(?i)(?<=[A-Z])\.(?=[A-Z])", "-", text)

+text = re.sub(r"\s{2,}", " ", text)
+
 return text.strip()
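The hour group in `TIME_PATTERN` changed from `{2}` to `{1,2}` so that single-digit hours are caught. A small stand-alone comparison of the old and new patterns from this diff (the sample sentence is illustrative only):

```python
# Why the hour group changed from {2} to {1,2}: single-digit hours were skipped before.
import re

OLD_TIME_PATTERN = re.compile(
    r"([0-9]{2} ?: ?[0-9]{2}( ?: ?[0-9]{2})?)( ?(pm|am)\b)?", re.IGNORECASE
)
NEW_TIME_PATTERN = re.compile(
    r"([0-9]{1,2} ?: ?[0-9]{2}( ?: ?[0-9]{2})?)( ?(pm|am)\b)?", re.IGNORECASE
)

sample = "Meet me at 9:30 am."
print(bool(OLD_TIME_PATTERN.search(sample)))  # False: a two-digit hour was required
print(bool(NEW_TIME_PATTERN.search(sample)))  # True: 9:30 am now matches
```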
@@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
 import phonemizer

 from .normalizer import normalize_text
+from ...structures.schemas import NormalizationOptions

 phonemizers = {}

@@ -75,7 +76,7 @@ def create_phonemizer(language: str = "a") -> PhonemizerBackend:
     Phonemizer backend instance
     """
     # Map language codes to espeak language codes
-    lang_map = {"a": "en-us", "b": "en-gb"}
+    lang_map = {"a": "en-us", "b": "en-gb", "z": "z"}

     if language not in lang_map:
         raise ValueError(f"Unsupported language code: {language}")

@@ -83,20 +84,24 @@ def create_phonemizer(language: str = "a") -> PhonemizerBackend:
     return EspeakBackend(lang_map[language])


-def phonemize(text: str, language: str = "a", normalize: bool = True) -> str:
+def phonemize(text: str, language: str = "a") -> str:
     """Convert text to phonemes

     Args:
         text: Text to convert to phonemes
         language: Language code ('a' for US English, 'b' for British English)
-        normalize: Whether to normalize text before phonemization

     Returns:
         Phonemized text
     """
     global phonemizers
-    if normalize:
-        text = normalize_text(text)
+
+    # Strip input text first to remove problematic leading/trailing spaces
+    text = text.strip()
+
     if language not in phonemizers:
         phonemizers[language] = create_phonemizer(language)
-    return phonemizers[language].phonemize(text)
+
+    result = phonemizers[language].phonemize(text)
+    # Final strip to ensure no leading/trailing spaces in phonemes
+    return result.strip()
@@ -2,7 +2,7 @@

 import re
 import time
-from typing import AsyncGenerator, Dict, List, Tuple
+from typing import AsyncGenerator, Dict, List, Tuple, Optional

 from loguru import logger

@@ -13,7 +13,11 @@ from .phonemizer import phonemize
 from .vocabulary import tokenize

 # Pre-compiled regex patterns for performance
-CUSTOM_PHONEMES = re.compile(r"(\[([^\]]|\n)*?\])(\(\/([^\/)]|\n)*?\/\))")
+# Updated regex to be more strict and avoid matching isolated brackets
+# Only matches complete patterns like [word](/ipa/) and prevents catastrophic backtracking
+CUSTOM_PHONEMES = re.compile(r"(\[[^\[\]]*?\])(\(\/[^\/\(\)]*?\/\))")
+# Pattern to find pause tags like [pause:0.5s]
+PAUSE_TAG_PATTERN = re.compile(r"\[pause:(\d+(?:\.\d+)?)s\]", re.IGNORECASE)


 def process_text_chunk(

@@ -30,6 +34,12 @@ def process_text_chunk(
     List of token IDs
     """
     start_time = time.time()

+    # Strip input text to remove any leading/trailing spaces that could cause artifacts
+    text = text.strip()
+
+    if not text:
+        return []
+
     if skip_phonemize:
         # Input is already phonemes, just tokenize

@@ -42,7 +52,9 @@ def process_text_chunk(
     t1 = time.time()

     t0 = time.time()
-    phonemes = phonemize(text, language, normalize=False) # Already normalized
+    phonemes = phonemize(text, language)
+    # Strip phonemes result to ensure no extra spaces
+    phonemes = phonemes.strip()
     t1 = time.time()

     t0 = time.time()

@@ -88,10 +100,16 @@ def process_text(text: str, language: str = "a") -> List[int]:


 def get_sentence_info(
-    text: str, custom_phenomes_list: Dict[str, str]
+    text: str, custom_phenomes_list: Dict[str, str], lang_code: str = "a"
 ) -> List[Tuple[str, List[int], int]]:
-    """Process all sentences and return info."""
-    sentences = re.split(r"([.!?;:])(?=\s|$)", text)
+    """Process all sentences and return info"""
+    # Detect Chinese text
+    is_chinese = lang_code.startswith("z") or re.search(r"[\u4e00-\u9fff]", text)
+    if is_chinese:
+        # Split using Chinese punctuation
+        sentences = re.split(r"([,。!?;])+", text)
+    else:
+        sentences = re.split(r"([.!?;:])(?=\s|$)", text)
     phoneme_length, min_value = len(custom_phenomes_list), 0

     results = []

@@ -104,16 +122,16 @@ def get_sentence_info(
             current_id, custom_phenomes_list.pop(current_id)
         )
         min_value += 1

     punct = sentences[i + 1] if i + 1 < len(sentences) else ""

     if not sentence:
         continue

     full = sentence + punct
+    # Strip the full sentence to remove any leading/trailing spaces before processing
+    full = full.strip()
+    if not full:  # Skip if empty after stripping
+        continue
     tokens = process_text_chunk(full)
     results.append((full, tokens, len(tokens)))

 return results

@@ -128,149 +146,189 @@ async def smart_split(
     max_tokens: int = settings.absolute_max_tokens,
     lang_code: str = "a",
     normalization_options: NormalizationOptions = NormalizationOptions(),
-) -> AsyncGenerator[Tuple[str, List[int]], None]:
-    """Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens."""
+) -> AsyncGenerator[Tuple[str, List[int], Optional[float]], None]:
+    """Build optimal chunks targeting 300-400 tokens, never exceeding max_tokens.
+
+    Yields:
+        Tuple of (text_chunk, tokens, pause_duration_s).
+        If pause_duration_s is not None, it's a pause chunk with empty text/tokens.
+        Otherwise, it's a text chunk containing the original text.
+    """
     start_time = time.time()
     chunk_count = 0
     logger.info(f"Starting smart split for {len(text)} chars")

-    custom_phoneme_list = {}
-
-    # Normalize text
-    if settings.advanced_text_normalization and normalization_options.normalize:
-        print(lang_code)
-        if lang_code in ["a", "b", "en-us", "en-gb"]:
-            text = CUSTOM_PHONEMES.sub(
-                lambda s: handle_custom_phonemes(s, custom_phoneme_list), text
-            )
-            text = normalize_text(text, normalization_options)
-        else:
-            logger.info(
-                "Skipping text normalization as it is only supported for english"
-            )
-
-    # Process all sentences
-    sentences = get_sentence_info(text, custom_phoneme_list)
-
-    current_chunk = []
-    current_tokens = []
-    current_count = 0
-
-    for sentence, tokens, count in sentences:
-        # Handle sentences that exceed max tokens
-        if count > max_tokens:
-            # Yield current chunk if any
-            if current_chunk:
-                chunk_text = " ".join(current_chunk)
-                chunk_count += 1
-                logger.debug(
-                    f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
-                )
-                yield chunk_text, current_tokens
-                current_chunk = []
-                current_tokens = []
-                current_count = 0
-
-            # Split long sentence on commas
-            clauses = re.split(r"([,])", sentence)
-            clause_chunk = []
-            clause_tokens = []
-            clause_count = 0
-
-            for j in range(0, len(clauses), 2):
-                clause = clauses[j].strip()
-                comma = clauses[j + 1] if j + 1 < len(clauses) else ""
-
-                if not clause:
-                    continue
-
-                full_clause = clause + comma
-
-                tokens = process_text_chunk(full_clause)
-                count = len(tokens)
-
-                # If adding clause keeps us under max and not optimal yet
-                if (
-                    clause_count + count <= max_tokens
-                    and clause_count + count <= settings.target_max_tokens
-                ):
-                    clause_chunk.append(full_clause)
-                    clause_tokens.extend(tokens)
-                    clause_count += count
-                else:
-                    # Yield clause chunk if we have one
-                    if clause_chunk:
-                        chunk_text = " ".join(clause_chunk)
-                        chunk_count += 1
-                        logger.debug(
-                            f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
-                        )
-                        yield chunk_text, current_tokens
-                        current_chunk = []
-                        current_tokens = []
-                        current_count = 0
-
-            # Split long sentence on commas
-            if clause_chunk:
-                chunk_text = " ".join(clause_chunk)
-                chunk_count += 1
-                logger.debug(
-                    f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({clause_count} tokens)"
-                )
-                yield chunk_text, clause_tokens
-
-        # Regular sentence handling
-        elif (
-            current_count >= settings.target_min_tokens
-            and current_count + count > settings.target_max_tokens
-        ):
-            # If we have a good sized chunk and adding next sentence exceeds target,
-            # yield current chunk and start new one
-            chunk_text = " ".join(current_chunk)
-            chunk_count += 1
-            logger.info(
-                f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
-            )
-            yield chunk_text, current_tokens
-            current_chunk = [sentence]
-            current_tokens = tokens
-            current_count = count
-        elif current_count + count <= settings.target_max_tokens:
-            # Keep building chunk while under target max
-            current_chunk.append(sentence)
-            current_tokens.extend(tokens)
-            current_count += count
-        elif (
-            current_count + count <= max_tokens
-            and current_count < settings.target_min_tokens
-        ):
-            # Only exceed target max if we haven't reached minimum size yet
-            current_chunk.append(sentence)
-            current_tokens.extend(tokens)
-            current_count += count
-        else:
-            # Yield current chunk and start new one
-            if current_chunk:
-                chunk_text = " ".join(current_chunk)
-                chunk_count += 1
-                logger.info(
-                    f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
-                )
-                yield chunk_text, current_tokens
-            current_chunk = [sentence]
-            current_tokens = tokens
-            current_count = count
-
-    # Don't forget the last chunk
-    if current_chunk:
-        chunk_text = " ".join(current_chunk)
-        chunk_count += 1
-        logger.info(
-            f"Yielding final chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(text) > 50 else ''}' ({current_count} tokens)"
-        )
-        yield chunk_text, current_tokens
+    # --- Step 1: Split by Pause Tags FIRST ---
+    # This operates on the raw input text
+    parts = PAUSE_TAG_PATTERN.split(text)
+    logger.debug(f"Split raw text into {len(parts)} parts by pause tags.")
+
+    part_idx = 0
+    while part_idx < len(parts):
+        text_part_raw = parts[part_idx]  # This part is raw text
+        part_idx += 1
+
+        # --- Process Text Part ---
+        if text_part_raw and text_part_raw.strip():  # Only process if the part is not empty string
+            # Strip leading and trailing spaces to prevent pause tag splitting artifacts
+            text_part_raw = text_part_raw.strip()
+
+            # Apply the original smart_split logic to this text part
+            custom_phoneme_list = {}
+
+            # Normalize text (original logic)
+            processed_text = text_part_raw
+            if settings.advanced_text_normalization and normalization_options.normalize:
+                if lang_code in ["a", "b", "en-us", "en-gb"]:
+                    processed_text = CUSTOM_PHONEMES.sub(
+                        lambda s: handle_custom_phonemes(s, custom_phoneme_list), processed_text
+                    )
+                    processed_text = normalize_text(processed_text, normalization_options)
+                else:
+                    logger.info(
+                        "Skipping text normalization as it is only supported for english"
+                    )
+
+            # Process all sentences (original logic)
+            sentences = get_sentence_info(processed_text, custom_phoneme_list, lang_code=lang_code)
+
+            current_chunk = []
+            current_tokens = []
+            current_count = 0
+
+            for sentence, tokens, count in sentences:
+                # Handle sentences that exceed max tokens (original logic)
+                if count > max_tokens:
+                    # Yield current chunk if any
+                    if current_chunk:
+                        chunk_text = " ".join(current_chunk).strip()
+                        chunk_count += 1
+                        logger.debug(
+                            f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)"
+                        )
+                        yield chunk_text, current_tokens, None
+                        current_chunk = []
+                        current_tokens = []
+                        current_count = 0
+
+                    # Split long sentence on commas (original logic)
+                    clauses = re.split(r"([,])", sentence)
+                    clause_chunk = []
+                    clause_tokens = []
+                    clause_count = 0
+
+                    for j in range(0, len(clauses), 2):
+                        clause = clauses[j].strip()
+                        comma = clauses[j + 1] if j + 1 < len(clauses) else ""
+
+                        if not clause:
+                            continue
+
+                        full_clause = clause + comma
+
+                        tokens = process_text_chunk(full_clause)
+                        count = len(tokens)
+
+                        # If adding clause keeps us under max and not optimal yet
+                        if (
+                            clause_count + count <= max_tokens
+                            and clause_count + count <= settings.target_max_tokens
+                        ):
+                            clause_chunk.append(full_clause)
+                            clause_tokens.extend(tokens)
+                            clause_count += count
+                        else:
+                            # Yield clause chunk if we have one
+                            if clause_chunk:
+                                chunk_text = " ".join(clause_chunk).strip()
+                                chunk_count += 1
+                                logger.debug(
+                                    f"Yielding clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({clause_count} tokens)"
+                                )
+                                yield chunk_text, clause_tokens, None
+                                clause_chunk = [full_clause]
+                                clause_tokens = tokens
+                                clause_count = count
+
+                    # Don't forget last clause chunk
+                    if clause_chunk:
+                        chunk_text = " ".join(clause_chunk).strip()
+                        chunk_count += 1
+                        logger.debug(
+                            f"Yielding final clause chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({clause_count} tokens)"
+                        )
+                        yield chunk_text, clause_tokens, None
+
+                # Regular sentence handling (original logic)
+                elif (
+                    current_count >= settings.target_min_tokens
+                    and current_count + count > settings.target_max_tokens
+                ):
+                    # If we have a good sized chunk and adding next sentence exceeds target,
+                    # yield current chunk and start new one
+                    chunk_text = " ".join(current_chunk).strip()
+                    chunk_count += 1
+                    logger.info(
+                        f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)"
+                    )
+                    yield chunk_text, current_tokens, None
+                    current_chunk = [sentence]
+                    current_tokens = tokens
+                    current_count = count
+                elif current_count + count <= settings.target_max_tokens:
+                    # Keep building chunk while under target max
+                    current_chunk.append(sentence)
+                    current_tokens.extend(tokens)
+                    current_count += count
+                elif (
+                    current_count + count <= max_tokens
+                    and current_count < settings.target_min_tokens
+                ):
+                    # Only exceed target max if we haven't reached minimum size yet
+                    current_chunk.append(sentence)
+                    current_tokens.extend(tokens)
+                    current_count += count
+                else:
+                    # Yield current chunk and start new one
+                    if current_chunk:
+                        chunk_text = " ".join(current_chunk).strip()
+                        chunk_count += 1
+                        logger.info(
+                            f"Yielding chunk {chunk_count}: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)"
+                        )
+                        yield chunk_text, current_tokens, None
+                    current_chunk = [sentence]
+                    current_tokens = tokens
+                    current_count = count
+
+            # Don't forget the last chunk for this text part
+            if current_chunk:
+                chunk_text = " ".join(current_chunk).strip()
+                chunk_count += 1
+                logger.info(
+                    f"Yielding final chunk {chunk_count} for part: '{chunk_text[:50]}{'...' if len(processed_text) > 50 else ''}' ({current_count} tokens)"
+                )
+                yield chunk_text, current_tokens, None
+
+        # --- Handle Pause Part ---
+        # Check if the next part is a pause duration string
+        if part_idx < len(parts):
+            duration_str = parts[part_idx]
+            # Check if it looks like a valid number string captured by the regex group
+            if re.fullmatch(r"\d+(?:\.\d+)?", duration_str):
+                part_idx += 1  # Consume the duration string as it's been processed
+                try:
+                    duration = float(duration_str)
+                    if duration > 0:
+                        chunk_count += 1
+                        logger.info(f"Yielding pause chunk {chunk_count}: {duration}s")
+                        yield "", [], duration  # Yield pause chunk
+                except (ValueError, TypeError):
+                    # This case should be rare if re.fullmatch passed, but handle anyway
+                    logger.warning(f"Could not parse valid-looking pause duration: {duration_str}")
+
+    # --- End of parts loop ---
     total_time = time.time() - start_time
     logger.info(
-        f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks"
+        f"Split completed in {total_time * 1000:.2f}ms, produced {chunk_count} chunks (including pauses)"
     )
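The pause-tag flow above can be hard to follow inside the diff, so here is a small, self-contained sketch of the same mechanism: `PAUSE_TAG_PATTERN.split` alternates raw text parts with captured durations, and the consumer yields either a text chunk or a pause. This illustrates the idea only; it does not import the project's actual `smart_split`.

```python
import re

PAUSE_TAG_PATTERN = re.compile(r"\[pause:(\d+(?:\.\d+)?)s\]", re.IGNORECASE)

text = "Welcome to the demo. [pause:0.5s] This sentence plays after a half-second of silence."

# Because the duration is a capture group, re.split alternates text parts and durations:
# ["Welcome to the demo. ", "0.5", " This sentence plays after a half-second of silence."]
parts = PAUSE_TAG_PATTERN.split(text)

part_idx = 0
while part_idx < len(parts):
    text_part = parts[part_idx]
    part_idx += 1
    if text_part.strip():
        print(("TEXT", text_part.strip()))
    if part_idx < len(parts) and re.fullmatch(r"\d+(?:\.\d+)?", parts[part_idx]):
        print(("PAUSE", float(parts[part_idx])))
        part_idx += 1
```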
@@ -23,6 +23,8 @@ def tokenize(phonemes: str) -> list[int]:
     Returns:
         List of token IDs
     """
+    # Strip phonemes to remove leading/trailing spaces that could cause artifacts
+    phonemes = phonemes.strip()
     return [i for i in map(VOCAB.get, phonemes) if i is not None]
@@ -55,6 +55,7 @@ class TTSService:
 output_format: Optional[str] = None,
 is_first: bool = False,
 is_last: bool = False,
+volume_multiplier: Optional[float] = 1.0,
 normalizer: Optional[AudioNormalizer] = None,
 lang_code: Optional[str] = None,
 return_timestamps: Optional[bool] = False,

@@ -100,6 +101,7 @@ class TTSService:
 lang_code=lang_code,
 return_timestamps=return_timestamps,
 ):
+    chunk_data.audio*=volume_multiplier
     # For streaming, convert to bytes
     if output_format:
         try:

@@ -132,7 +134,7 @@ class TTSService:
 speed=speed,
 return_timestamps=return_timestamps,
 )

 if chunk_data.audio is None:
     logger.error("Model generated None for audio chunk")
     return

@@ -141,6 +143,8 @@ class TTSService:
     logger.error("Model generated empty audio chunk")
     return

+chunk_data.audio*=volume_multiplier
+
 # For streaming, convert to bytes
 if output_format:
     try:

@@ -259,6 +263,7 @@ class TTSService:
 speed: float = 1.0,
 output_format: str = "wav",
 lang_code: Optional[str] = None,
+volume_multiplier: Optional[float] = 1.0,
 normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
 return_timestamps: Optional[bool] = False,
 ) -> AsyncGenerator[AudioChunk, None]:

@@ -280,48 +285,90 @@ class TTSService:
 f"Using lang_code '{pipeline_lang_code}' for voice '{voice_name}' in audio stream"
 )

-# Process text in chunks with smart splitting
-async for chunk_text, tokens in smart_split(
+# Process text in chunks with smart splitting, handling pause tags
+async for chunk_text, tokens, pause_duration_s in smart_split(
     text,
     lang_code=pipeline_lang_code,
     normalization_options=normalization_options,
 ):
-    try:
-        # Process audio for chunk
-        async for chunk_data in self._process_chunk(
-            chunk_text,  # Pass text for Kokoro V1
-            tokens,  # Pass tokens for legacy backends
-            voice_name,  # Pass voice name
-            voice_path,  # Pass voice path
-            speed,
-            writer,
-            output_format,
-            is_first=(chunk_index == 0),
-            is_last=False,  # We'll update the last chunk later
-            normalizer=stream_normalizer,
-            lang_code=pipeline_lang_code,  # Pass lang_code
-            return_timestamps=return_timestamps,
-        ):
-            if chunk_data.word_timestamps is not None:
-                for timestamp in chunk_data.word_timestamps:
-                    timestamp.start_time += current_offset
-                    timestamp.end_time += current_offset
-
-            current_offset += len(chunk_data.audio) / 24000
-
-            if chunk_data.output is not None:
-                yield chunk_data
-            else:
-                logger.warning(
-                    f"No audio generated for chunk: '{chunk_text[:100]}...'"
-                )
-        chunk_index += 1
-    except Exception as e:
-        logger.error(
-            f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}"
-        )
-        continue
+    if pause_duration_s is not None and pause_duration_s > 0:
+        # --- Handle Pause Chunk ---
+        try:
+            logger.debug(f"Generating {pause_duration_s}s silence chunk")
+            silence_samples = int(pause_duration_s * 24000)  # 24kHz sample rate
+            # Create proper silence as int16 zeros to avoid normalization artifacts
+            silence_audio = np.zeros(silence_samples, dtype=np.int16)
+            pause_chunk = AudioChunk(audio=silence_audio, word_timestamps=[])  # Empty timestamps for silence
+
+            # Format and yield the silence chunk
+            if output_format:
+                formatted_pause_chunk = await AudioService.convert_audio(
+                    pause_chunk, output_format, writer, speed=speed, chunk_text="",
+                    is_last_chunk=False, trim_audio=False, normalizer=stream_normalizer,
+                )
+                if formatted_pause_chunk.output:
+                    yield formatted_pause_chunk
+            else:  # Raw audio mode
+                # For raw audio mode, silence is already in the correct format (int16)
+                # Skip normalization to avoid any potential artifacts
+                if len(pause_chunk.audio) > 0:
+                    yield pause_chunk
+
+            # Update offset based on silence duration
+            current_offset += pause_duration_s
+            chunk_index += 1  # Count pause as a yielded chunk
+
+        except Exception as e:
+            logger.error(f"Failed to process pause chunk: {str(e)}")
+            continue
+
+    elif tokens or chunk_text.strip():  # Process if there are tokens OR non-whitespace text
+        # --- Handle Text Chunk ---
+        try:
+            # Process audio for chunk
+            async for chunk_data in self._process_chunk(
+                chunk_text,  # Pass text for Kokoro V1
+                tokens,  # Pass tokens for legacy backends
+                voice_name,  # Pass voice name
+                voice_path,  # Pass voice path
+                speed,
+                writer,
+                output_format,
+                is_first=(chunk_index == 0),
+                volume_multiplier=volume_multiplier,
+                is_last=False,  # We'll update the last chunk later
+                normalizer=stream_normalizer,
+                lang_code=pipeline_lang_code,  # Pass lang_code
+                return_timestamps=return_timestamps,
+            ):
+                if chunk_data.word_timestamps is not None:
+                    for timestamp in chunk_data.word_timestamps:
+                        timestamp.start_time += current_offset
+                        timestamp.end_time += current_offset
+
+                # Update offset based on the actual duration of the generated audio chunk
+                chunk_duration = 0
+                if chunk_data.audio is not None and len(chunk_data.audio) > 0:
+                    chunk_duration = len(chunk_data.audio) / 24000
+                    current_offset += chunk_duration
+
+                # Yield the processed chunk (either formatted or raw)
+                if chunk_data.output is not None:
+                    yield chunk_data
+                elif chunk_data.audio is not None and len(chunk_data.audio) > 0:
+                    yield chunk_data
+                else:
+                    logger.warning(
+                        f"No audio generated for chunk: '{chunk_text[:100]}...'"
+                    )
+
+            chunk_index += 1  # Increment chunk index after processing text
+        except Exception as e:
+            logger.error(
+                f"Failed to process audio for chunk: '{chunk_text[:100]}...'. Error: {str(e)}"
+            )
+            continue

 # Only finalize if we successfully processed at least one chunk
 if chunk_index > 0:

@@ -337,6 +384,7 @@ class TTSService:
 output_format,
 is_first=False,
 is_last=True,  # Signal this is the last chunk
+volume_multiplier=volume_multiplier,
 normalizer=stream_normalizer,
 lang_code=pipeline_lang_code,  # Pass lang_code
 ):

@@ -356,6 +404,7 @@ class TTSService:
 writer: StreamingAudioWriter,
 speed: float = 1.0,
 return_timestamps: bool = False,
+volume_multiplier: Optional[float] = 1.0,
 normalization_options: Optional[NormalizationOptions] = NormalizationOptions(),
 lang_code: Optional[str] = None,
 ) -> AudioChunk:

@@ -368,6 +417,7 @@ class TTSService:
 voice,
 writer,
 speed=speed,
+volume_multiplier=volume_multiplier,
 normalization_options=normalization_options,
 return_timestamps=return_timestamps,
 lang_code=lang_code,
@ -1,3 +1,4 @@
|
||||||
|
from email.policy import default
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Literal, Optional, Union
|
from typing import List, Literal, Optional, Union
|
||||||
|
|
||||||
|
@ -66,6 +67,10 @@ class NormalizationOptions(BaseModel):
|
||||||
default=True,
|
default=True,
|
||||||
description="Changes phone numbers so they can be properly pronouced by kokoro",
|
description="Changes phone numbers so they can be properly pronouced by kokoro",
|
||||||
)
|
)
|
||||||
|
replace_remaining_symbols: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Replaces the remaining symbols after normalization with their words"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class OpenAISpeechRequest(BaseModel):
|
class OpenAISpeechRequest(BaseModel):
|
||||||
|
@ -108,6 +113,10 @@ class OpenAISpeechRequest(BaseModel):
|
||||||
default=None,
|
default=None,
|
||||||
description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
|
description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
|
||||||
)
|
)
|
||||||
|
volume_multiplier: Optional[float] = Field(
|
||||||
|
default = 1.0,
|
||||||
|
description="A volume multiplier to multiply the output audio by."
|
||||||
|
)
|
||||||
normalization_options: Optional[NormalizationOptions] = Field(
|
normalization_options: Optional[NormalizationOptions] = Field(
|
||||||
default=NormalizationOptions(),
|
default=NormalizationOptions(),
|
||||||
description="Options for the normalization system",
|
description="Options for the normalization system",
|
||||||
|
@@ -152,6 +161,10 @@ class CaptionedSpeechRequest(BaseModel):
        default=None,
        description="Optional language code to use for text processing. If not provided, will use first letter of voice name.",
    )
+   volume_multiplier: Optional[float] = Field(
+       default = 1.0,
+       description="A volume multiplier to multiply the output audio by."
+   )
    normalization_options: Optional[NormalizationOptions] = Field(
        default=NormalizationOptions(),
        description="Options for the normalization system",
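For context on how the new field is meant to be consumed, here is a minimal client sketch, assuming the OpenAI-compatible speech route is served at /v1/audio/speech (the route, voice name, and the 2.0 value are illustrative assumptions, not taken from this diff):

```python
import requests

# Hedged sketch: request speech with the new volume_multiplier field.
# /v1/audio/speech is an assumed route for the OpenAI-compatible endpoint.
response = requests.post(
    "http://localhost:8880/v1/audio/speech",
    json={
        "model": "kokoro",
        "input": "Hello world",
        "voice": "af_heart",
        "response_format": "wav",
        "volume_multiplier": 2.0,  # defaults to 1.0, i.e. unchanged loudness
    },
)
response.raise_for_status()

with open("output.wav", "wb") as f:
    f.write(response.content)
```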
@@ -57,19 +57,19 @@ def test_url_localhost():
        normalize_text(
            "Running on localhost:7860", normalization_options=NormalizationOptions()
        )
-       == "Running on localhost colon 78 60"
+       == "Running on localhost colon seventy-eight sixty"
    )
    assert (
        normalize_text(
            "Server at localhost:8080/api", normalization_options=NormalizationOptions()
        )
-       == "Server at localhost colon 80 80 slash api"
+       == "Server at localhost colon eighty eighty slash api"
    )
    assert (
        normalize_text(
            "Test localhost:3000/test?v=1", normalization_options=NormalizationOptions()
        )
-       == "Test localhost colon 3000 slash test question-mark v equals 1"
+       == "Test localhost colon three thousand slash test question-mark v equals one"
    )
@@ -79,17 +79,17 @@ def test_url_ip_addresses():
        normalize_text(
            "Access 0.0.0.0:9090/test", normalization_options=NormalizationOptions()
        )
-       == "Access 0 dot 0 dot 0 dot 0 colon 90 90 slash test"
+       == "Access zero dot zero dot zero dot zero colon ninety ninety slash test"
    )
    assert (
        normalize_text(
            "API at 192.168.1.1:8000", normalization_options=NormalizationOptions()
        )
-       == "API at 192 dot 168 dot 1 dot 1 colon 8000"
+       == "API at one hundred and ninety-two dot one hundred and sixty-eight dot one dot one colon eight thousand"
    )
    assert (
        normalize_text("Server 127.0.0.1", normalization_options=NormalizationOptions())
-       == "Server 127 dot 0 dot 0 dot 1"
+       == "Server one hundred and twenty-seven dot zero dot zero dot one"
    )
@@ -146,6 +146,15 @@ def test_money():
        )
        == "He lost five point three thousand dollars."
    )
+
+   assert (
+       normalize_text(
+           "He went gambling and lost about $25.05k.",
+           normalization_options=NormalizationOptions(),
+       )
+       == "He went gambling and lost about twenty-five point zero five thousand dollars."
+   )
+
    assert (
        normalize_text(
            "To put it weirdly -$6.9 million",
@@ -153,11 +162,147 @@ def test_money():
        )
        == "To put it weirdly minus six point nine million dollars"
    )

    assert (
        normalize_text("It costs $50.3.", normalization_options=NormalizationOptions())
        == "It costs fifty dollars and thirty cents."
    )
+
+   assert (
+       normalize_text(
+           "The plant cost $200,000.8.", normalization_options=NormalizationOptions()
+       )
+       == "The plant cost two hundred thousand dollars and eighty cents."
+   )
+
+   assert (
+       normalize_text(
+           "Your shopping spree cost $674.03!", normalization_options=NormalizationOptions()
+       )
+       == "Your shopping spree cost six hundred and seventy-four dollars and three cents!"
+   )
+
+   assert (
+       normalize_text(
+           "€30.2 is in euros", normalization_options=NormalizationOptions()
+       )
+       == "thirty euros and twenty cents is in euros"
+   )
+
+
+def test_time():
+   """Test time normalization"""
+
+   assert (
+       normalize_text(
+           "Your flight leaves at 10:35 pm",
+           normalization_options=NormalizationOptions(),
+       )
+       == "Your flight leaves at ten thirty-five pm"
+   )
+
+   assert (
+       normalize_text(
+           "He departed for london around 5:03 am.",
+           normalization_options=NormalizationOptions(),
+       )
+       == "He departed for london around five oh three am."
+   )
+
+   assert (
+       normalize_text(
+           "Only the 13:42 and 15:12 slots are available.",
+           normalization_options=NormalizationOptions(),
+       )
+       == "Only the thirteen forty-two and fifteen twelve slots are available."
+   )
+
+   assert (
+       normalize_text(
+           "It is currently 1:00 pm", normalization_options=NormalizationOptions()
+       )
+       == "It is currently one pm"
+   )
+
+   assert (
+       normalize_text(
+           "It is currently 3:00", normalization_options=NormalizationOptions()
+       )
+       == "It is currently three o'clock"
+   )
+
+   assert (
+       normalize_text(
+           "12:00 am is midnight", normalization_options=NormalizationOptions()
+       )
+       == "twelve am is midnight"
+   )
+
+
+def test_number():
+   """Test number normalization"""
+
+   assert (
+       normalize_text(
+           "I bought 1035 cans of soda", normalization_options=NormalizationOptions()
+       )
+       == "I bought one thousand and thirty-five cans of soda"
+   )
+
+   assert (
+       normalize_text(
+           "The bus has a maximum capacity of 62 people",
+           normalization_options=NormalizationOptions(),
+       )
+       == "The bus has a maximum capacity of sixty-two people"
+   )
+
+   assert (
+       normalize_text(
+           "There are 1300 products left in stock",
+           normalization_options=NormalizationOptions(),
+       )
+       == "There are one thousand, three hundred products left in stock"
+   )
+
+   assert (
+       normalize_text(
+           "The population is 7,890,000 people.",
+           normalization_options=NormalizationOptions(),
+       )
+       == "The population is seven million, eight hundred and ninety thousand people."
+   )
+
+   assert (
+       normalize_text(
+           "He looked around but only found 1.6k of the 10k bricks",
+           normalization_options=NormalizationOptions(),
+       )
+       == "He looked around but only found one point six thousand of the ten thousand bricks"
+   )
+
+   assert (
+       normalize_text(
+           "The book has 342 pages.", normalization_options=NormalizationOptions()
+       )
+       == "The book has three hundred and forty-two pages."
+   )
+
+   assert (
+       normalize_text(
+           "He made -50 sales today.", normalization_options=NormalizationOptions()
+       )
+       == "He made minus fifty sales today."
+   )
+
+   assert (
+       normalize_text(
+           "56.789 to the power of 1.35 million",
+           normalization_options=NormalizationOptions(),
+       )
+       == "fifty-six point seven eight nine to the power of one point three five million"
+   )


 def test_non_url_text():
    """Test that non-URL text is unaffected"""
@@ -177,3 +322,12 @@ def test_non_url_text():
        normalize_text("It costs $50.", normalization_options=NormalizationOptions())
        == "It costs fifty dollars."
    )
+
+
+def test_remaining_symbol():
+   """Test that remaining symbols are replaced"""
+   assert (
+       normalize_text(
+           "I love buying products @ good store here & @ other store", normalization_options=NormalizationOptions()
+       )
+       == "I love buying products at good store here and at other store"
+   )
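The new replace_remaining_symbols flag exercised above defaults to on; a minimal sketch of toggling it directly is shown below, with both import paths inferred rather than confirmed by this diff:

```python
# Hedged sketch: compare normalization with and without the new flag.
# Both import paths are assumptions; the tests above already have these names in scope.
from api.src.services.text_processing.normalizer import normalize_text
from api.src.structures.schemas import NormalizationOptions

text = "Buy milk @ the store & bakery"

default_out = normalize_text(text, normalization_options=NormalizationOptions())
raw_out = normalize_text(
    text, normalization_options=NormalizationOptions(replace_remaining_symbols=False)
)

print(default_out)  # "@" and "&" spoken as "at" and "and"
print(raw_out)      # remaining symbols left for later stages to handle
```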
@@ -67,7 +67,7 @@ async def test_smart_split_short_text():
    """Test smart splitting with text under max tokens."""
    text = "This is a short test sentence."
    chunks = []
-   async for chunk_text, chunk_tokens in smart_split(text):
+   async for chunk_text, chunk_tokens, _ in smart_split(text):
        chunks.append((chunk_text, chunk_tokens))

    assert len(chunks) == 1
@@ -82,7 +82,7 @@ async def test_smart_split_long_text():
    text = ". ".join(["This is test sentence number " + str(i) for i in range(20)])

    chunks = []
-   async for chunk_text, chunk_tokens in smart_split(text):
+   async for chunk_text, chunk_tokens, _ in smart_split(text):
        chunks.append((chunk_text, chunk_tokens))

    assert len(chunks) > 1
@@ -98,8 +98,127 @@ async def test_smart_split_with_punctuation():
    text = "First sentence! Second sentence? Third sentence; Fourth sentence: Fifth sentence."

    chunks = []
-   async for chunk_text, chunk_tokens in smart_split(text):
+   async for chunk_text, chunk_tokens, _ in smart_split(text):
        chunks.append(chunk_text)

    # Verify punctuation is preserved
    assert all(any(p in chunk for p in "!?;:.") for chunk in chunks)
+
+
+def test_process_text_chunk_chinese_phonemes():
+   """Test processing with Chinese pinyin phonemes."""
+   pinyin = "nǐ hǎo lì"  # Example pinyin sequence with tones
+   tokens = process_text_chunk(pinyin, skip_phonemize=True, language="z")
+   assert isinstance(tokens, list)
+   assert len(tokens) > 0
+
+
+def test_get_sentence_info_chinese():
+   """Test Chinese sentence splitting and info extraction."""
+   text = "这是一个句子。这是第二个句子!第三个问题?"
+   results = get_sentence_info(text, {}, lang_code="z")
+
+   assert len(results) == 3
+   for sentence, tokens, count in results:
+       assert isinstance(sentence, str)
+       assert isinstance(tokens, list)
+       assert isinstance(count, int)
+       assert count == len(tokens)
+       assert count > 0
+
+
+@pytest.mark.asyncio
+async def test_smart_split_chinese_short():
+   """Test Chinese smart splitting with short text."""
+   text = "这是一句话。"
+   chunks = []
+   async for chunk_text, chunk_tokens, _ in smart_split(text, lang_code="z"):
+       chunks.append((chunk_text, chunk_tokens))
+
+   assert len(chunks) == 1
+   assert isinstance(chunks[0][0], str)
+   assert isinstance(chunks[0][1], list)
+
+
+@pytest.mark.asyncio
+async def test_smart_split_chinese_long():
+   """Test Chinese smart splitting with longer text."""
+   text = "。".join([f"测试句子 {i}" for i in range(20)])
+
+   chunks = []
+   async for chunk_text, chunk_tokens, _ in smart_split(text, lang_code="z"):
+       chunks.append((chunk_text, chunk_tokens))
+
+   assert len(chunks) > 1
+   for chunk_text, chunk_tokens in chunks:
+       assert isinstance(chunk_text, str)
+       assert isinstance(chunk_tokens, list)
+       assert len(chunk_tokens) > 0
+
+
+@pytest.mark.asyncio
+async def test_smart_split_chinese_punctuation():
+   """Test Chinese smart splitting with punctuation preservation."""
+   text = "第一句!第二问?第三句;第四句:第五句。"
+
+   chunks = []
+   async for chunk_text, _, _ in smart_split(text, lang_code="z"):
+       chunks.append(chunk_text)
+
+   # Verify Chinese punctuation is preserved
+   assert all(any(p in chunk for p in "!?;:。") for chunk in chunks)
+
+
+@pytest.mark.asyncio
+async def test_smart_split_with_pause():
+   """Test smart splitting with pause tags."""
+   text = "Hello world [pause:2.5s] How are you?"
+
+   chunks = []
+   async for chunk_text, chunk_tokens, pause_duration in smart_split(text):
+       chunks.append((chunk_text, chunk_tokens, pause_duration))
+
+   # Should have 3 chunks: text, pause, text
+   assert len(chunks) == 3
+
+   # First chunk: text
+   assert chunks[0][2] is None  # No pause
+   assert "Hello world" in chunks[0][0]
+   assert len(chunks[0][1]) > 0
+
+   # Second chunk: pause
+   assert chunks[1][2] == 2.5  # 2.5 second pause
+   assert chunks[1][0] == ""  # Empty text
+   assert len(chunks[1][1]) == 0  # No tokens
+
+   # Third chunk: text
+   assert chunks[2][2] is None  # No pause
+   assert "How are you?" in chunks[2][0]
+   assert len(chunks[2][1]) > 0
+
+
+@pytest.mark.asyncio
+async def test_smart_split_with_two_pause():
+   """Test smart splitting with two pause tags."""
+   text = "[pause:0.5s][pause:1.67s]0.5"
+
+   chunks = []
+   async for chunk_text, chunk_tokens, pause_duration in smart_split(text):
+       chunks.append((chunk_text, chunk_tokens, pause_duration))
+
+   # Should have 3 chunks: pause, pause, text
+   assert len(chunks) == 3
+
+   # First chunk: pause
+   assert chunks[0][2] == 0.5  # 0.5 second pause
+   assert chunks[0][0] == ""  # Empty text
+   assert len(chunks[0][1]) == 0
+
+   # Second chunk: pause
+   assert chunks[1][2] == 1.67  # 1.67 second pause
+   assert chunks[1][0] == ""  # Empty text
+   assert len(chunks[1][1]) == 0  # No tokens
+
+   # Third chunk: text
+   assert chunks[2][2] is None  # No pause
+   assert "zero point five" in chunks[2][0]
+   assert len(chunks[2][1]) > 0
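As a quick illustration of the new three-value yield from smart_split (chunk text, tokens, optional pause duration), here is a minimal consumer sketch; the module path is an assumption, not taken from this diff, and the asyncio driver is added only for illustration:

```python
import asyncio

# Hedged sketch: drive smart_split and branch on pause chunks.
# The import path below is assumed; the tests above import smart_split directly.
from api.src.services.text_processing import smart_split


async def main():
    text = "Hello world [pause:1.5s] How are you?"
    async for chunk_text, chunk_tokens, pause_duration in smart_split(text):
        if pause_duration is not None:
            print(f"insert {pause_duration}s of silence")
        else:
            print(f"synthesize {chunk_text!r} ({len(chunk_tokens)} tokens)")


asyncio.run(main())
```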
@@ -3,6 +3,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
 import numpy as np
 import pytest
 import torch
+import os

 from api.src.services.tts_service import TTSService
@@ -102,7 +103,10 @@ async def test_get_voice_path_combined():
    service = await TTSService.create("test_output")
    name, path = await service._get_voices_path("voice1+voice2")
    assert name == "voice1+voice2"
-   assert path.endswith("voice1+voice2.pt")
+   # Verify the path points to a temporary file with expected format
+   assert path.startswith("/tmp/")
+   assert "voice1+voice2" in path
+   assert path.endswith(".pt")
    mock_save.assert_called_once()
dev/Test Phon.py (new file, 23 lines added)
@@ -0,0 +1,23 @@
+import base64
+import json
+
+import pydub
+import requests
+
+
+def generate_audio_from_phonemes(phonemes: str, voice: str = "af_bella"):
+   """Generate audio from phonemes"""
+   response = requests.post(
+       "http://localhost:8880/dev/generate_from_phonemes",
+       json={"phonemes": phonemes, "voice": voice},
+       headers={"Accept": "audio/wav"}
+   )
+   if response.status_code != 200:
+       print(f"Error: {response.text}")
+       return None
+   return response.content
+
+
+with open(f"outputnostreammoney.wav", "wb") as f:
+   f.write(generate_audio_from_phonemes(r"mɪsəki ɪz ɐn ɪkspˌɛɹəmˈɛntᵊl ʤˈitəpˈi ˈɛnʤən dəzˈInd tə pˈWəɹ fjˈuʧəɹ vˈɜɹʒənz ʌv kəkˈɔɹO mˈɑdᵊlz."))
dev/Test copy 2.py (new file, 38 lines added)
@@ -0,0 +1,38 @@
+import base64
+import json
+
+import pydub
+import requests
+
+text = """Running on localhost:7860"""
+
+
+Type = "wav"
+response = requests.post(
+   "http://localhost:8880/dev/captioned_speech",
+   json={
+       "model": "kokoro",
+       "input": text,
+       "voice": "af_heart+af_sky",
+       "speed": 1.0,
+       "response_format": Type,
+       "stream": True,
+   },
+   stream=True,
+)
+
+f = open(f"outputstream.{Type}", "wb")
+for chunk in response.iter_lines(decode_unicode=True):
+   if chunk:
+       temp_json = json.loads(chunk)
+       if temp_json["timestamps"] != []:
+           chunk_json = temp_json
+
+       # Decode base 64 stream to bytes
+       chunk_audio = base64.b64decode(temp_json["audio"].encode("utf-8"))
+
+       # Process streaming chunks
+       f.write(chunk_audio)
+
+       # Print word level timestamps
+       print(chunk_json["timestamps"])
@@ -3,9 +3,7 @@ import json

 import requests

-text = """the administration has offered up a platter of repression for more than a year and is still slated to lose $400 million.
-
-Columbia is the largest private landowner in New York City and boasts an endowment of $14.8 billion;"""
+text = """奶酪芝士很浓郁!臭豆腐芝士有争议?陈年奶酪价格昂贵。"""


 Type = "wav"
@@ -15,7 +13,7 @@ response = requests.post(
    json={
        "model": "kokoro",
        "input": text,
-       "voice": "af_heart+af_sky",
+       "voice": "zf_xiaobei",
        "speed": 1.0,
        "response_format": Type,
        "stream": False,
@@ -30,6 +30,10 @@ WORKDIR /app
 # Copy dependency files
 COPY --chown=appuser:appuser pyproject.toml ./pyproject.toml

+# Install Rust (required to build sudachipy and pyopenjtalk-plus)
+RUN curl https://sh.rustup.rs -sSf | sh -s -- -y
+ENV PATH="/home/appuser/.cargo/bin:$PATH"
+
 # Install dependencies
 RUN --mount=type=cache,target=/root/.cache/uv \
    uv venv --python 3.10 && \