Refactor LabelledString and relevant classes

This commit is contained in:
YishiMichael 2022-04-17 13:57:03 +08:00
parent e9298c5faf
commit 0e0244128c
No known key found for this signature in database
GPG key ID: EC615C0C5A86BC80
5 changed files with 412 additions and 486 deletions

View file

@ -6,6 +6,7 @@ import numpy as np
from manimlib.animation.animation import Animation
from manimlib.mobject.svg.labelled_string import LabelledString
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.bezier import integer_interpolate
from manimlib.utils.config_ops import digest_config
@ -212,7 +213,9 @@ class AddTextWordByWord(ShowIncreasingSubsets):
def __init__(self, string_mobject, **kwargs):
assert isinstance(string_mobject, LabelledString)
grouped_mobject = string_mobject.submob_groups
grouped_mobject = VGroup(*[
part for _, part in string_mobject.get_group_part_items()
])
digest_config(self, kwargs)
if self.run_time is None:
self.run_time = self.time_per_word * len(grouped_mobject)

View file

@ -168,40 +168,36 @@ class TransformMatchingStrings(AnimationGroup):
assert isinstance(source, LabelledString)
assert isinstance(target, LabelledString)
anims = []
source_indices = list(range(len(source.labelled_submobject_items)))
target_indices = list(range(len(target.labelled_submobject_items)))
def get_indices_lists(mobject, parts):
labelled_submobjects = [
submob
for _, submob in mobject.labelled_submobject_items
]
return [
source_submobs = [
submob for _, submob in source.labelled_submobject_items
]
target_submobs = [
submob for _, submob in target.labelled_submobject_items
]
source_indices = list(range(len(source_submobs)))
target_indices = list(range(len(target_submobs)))
def get_filtered_indices_lists(parts, submobs, rest_indices):
return list(filter(
lambda indices_list: all([
index in rest_indices
for index in indices_list
]),
[
labelled_submobjects.index(submob)
for submob in part
[submobs.index(submob) for submob in part]
for part in parts
]
for part in parts
]
))
def add_anims_from(anim_class, func, source_args, target_args=None):
if target_args is None:
target_args = source_args.copy()
for source_arg, target_arg in zip(source_args, target_args):
source_parts = func(source, source_arg)
target_parts = func(target, target_arg)
source_indices_lists = list(filter(
lambda indices_list: all([
index in source_indices
for index in indices_list
]), get_indices_lists(source, source_parts)
))
target_indices_lists = list(filter(
lambda indices_list: all([
index in target_indices
for index in indices_list
]), get_indices_lists(target, target_parts)
))
def add_anims(anim_class, parts_pairs):
for source_parts, target_parts in parts_pairs:
source_indices_lists = get_filtered_indices_lists(
source_parts, source_submobs, source_indices
)
target_indices_lists = get_filtered_indices_lists(
target_parts, target_submobs, target_indices
)
if not source_indices_lists or not target_indices_lists:
continue
anims.append(anim_class(source_parts, target_parts, **kwargs))
@ -210,29 +206,45 @@ class TransformMatchingStrings(AnimationGroup):
for index in it.chain(*target_indices_lists):
target_indices.remove(index)
def get_common_substrs(substrs_from_source, substrs_from_target):
return sorted([
substr for substr in substrs_from_source
if substr and substr in substrs_from_target
], key=len, reverse=True)
def get_substr_to_parts_map(part_items):
result = {}
for substr, part in part_items:
if substr not in result:
result[substr] = []
result[substr].append(part)
return result
add_anims_from(
ReplacementTransform, LabelledString.select_parts,
self.key_map.keys(), self.key_map.values()
def add_anims_from(anim_class, func):
source_substr_to_parts_map = get_substr_to_parts_map(func(source))
target_substr_to_parts_map = get_substr_to_parts_map(func(target))
add_anims(
anim_class,
[
(
VGroup(*source_substr_to_parts_map[substr]),
VGroup(*target_substr_to_parts_map[substr])
)
for substr in sorted([
s for s in source_substr_to_parts_map.keys()
if s and s in target_substr_to_parts_map.keys()
], key=len, reverse=True)
]
)
add_anims(
ReplacementTransform,
[
(source.select_parts(k), target.select_parts(v))
for k, v in self.key_map.items()
]
)
add_anims_from(
FadeTransformPieces, LabelledString.select_parts,
get_common_substrs(
source.specified_substrs,
target.specified_substrs
)
FadeTransformPieces,
LabelledString.get_specified_part_items
)
add_anims_from(
FadeTransformPieces, LabelledString.select_parts_by_group_substr,
get_common_substrs(
source.group_substrs,
target.group_substrs
)
FadeTransformPieces,
LabelledString.get_group_part_items
)
rest_source = VGroup(*[source[index] for index in source_indices])

View file

@ -85,6 +85,7 @@ class LabelledString(SVGMobject, ABC):
submob_color_ints = [0] * len(self.submobjects)
if len(self.submobjects) != len(submob_color_ints):
print(len(self.submobjects), len(submob_color_ints))
raise ValueError(
"Cannot align submobjects of the labelled svg "
"to the original svg"
@ -106,31 +107,25 @@ class LabelledString(SVGMobject, ABC):
def pre_parse(self) -> None:
self.string_len = len(self.string)
self.full_span = (0, self.string_len)
self.space_spans = self.find_spans(r"\s+")
self.base_color_int = self.color_to_int(self.base_color)
def parse(self) -> None:
self.command_repl_items = self.get_command_repl_items()
self.command_spans = self.get_command_spans()
self.extra_entity_spans = self.get_extra_entity_spans()
self.skippable_indices = self.get_skippable_indices()
self.entity_spans = self.get_entity_spans()
self.extra_ignored_spans = self.get_extra_ignored_spans()
self.skipped_spans = self.get_skipped_spans()
self.internal_specified_spans = self.get_internal_specified_spans()
self.external_specified_spans = self.get_external_specified_spans()
self.bracket_spans = self.get_bracket_spans()
self.extra_isolated_items = self.get_extra_isolated_items()
self.specified_items = self.get_specified_items()
self.specified_spans = self.get_specified_spans()
self.label_span_list = self.get_label_span_list()
self.check_overlapping()
self.label_span_list = self.get_label_span_list()
if len(self.label_span_list) >= 16777216:
raise ValueError("Cannot handle that many substrings")
def post_parse(self) -> None:
self.labelled_submobject_items = [
(submob.label, submob)
for submob in self.submobjects
]
self.specified_substrs = self.get_specified_substrs()
self.group_items = self.get_group_items()
self.group_substrs = self.get_group_substrs()
self.submob_groups = self.get_submob_groups()
def copy(self):
return self.deepcopy()
@ -140,16 +135,21 @@ class LabelledString(SVGMobject, ABC):
def get_substr(self, span: Span) -> str:
return self.string[slice(*span)]
def find_spans(self, pattern: str | re.Pattern) -> list[Span]:
def match(self, pattern: str | re.Pattern, **kwargs) -> re.Pattern | None:
if isinstance(pattern, str):
pattern = re.compile(pattern)
return re.compile(pattern).match(self.string, **kwargs)
def find_spans(self, pattern: str | re.Pattern, **kwargs) -> list[Span]:
if isinstance(pattern, str):
pattern = re.compile(pattern)
return [
match_obj.span()
for match_obj in pattern.finditer(self.string)
for match_obj in pattern.finditer(self.string, **kwargs)
]
def match_at(self, pattern: str, pos: int) -> re.Pattern | None:
return re.compile(pattern).match(self.string, pos=pos)
def find_indices(self, pattern: str | re.Pattern, **kwargs) -> list[int]:
return [index for index, _ in self.find_spans(pattern, **kwargs)]
@staticmethod
def is_single_selector(selector: Selector) -> bool:
@ -230,41 +230,24 @@ class LabelledString(SVGMobject, ABC):
spans = LabelledString.get_neighbouring_pairs(indices)
return list(zip(unique_vals, spans))
@staticmethod
def find_region_index(seq: list[int], val: int) -> int:
# Returns an integer in `range(-1, len(seq))` satisfying
# `seq[result] <= val < seq[result + 1]`.
# `seq` should be sorted in ascending order.
if not seq or val < seq[0]:
return -1
result = len(seq) - 1
while val < seq[result]:
result -= 1
return result
@staticmethod
def take_nearest_value(seq: list[int], val: int, index_shift: int) -> int:
sorted_seq = sorted(seq)
index = LabelledString.find_region_index(sorted_seq, val)
return sorted_seq[index + index_shift]
@staticmethod
def generate_span_repl_dict(
inserted_string_pairs: list[tuple[Span, tuple[str, str]]],
other_repl_items: list[tuple[Span, str]]
repl_items: list[tuple[Span, str]]
) -> dict[Span, str]:
result = dict(other_repl_items)
result = dict(repl_items)
if not inserted_string_pairs:
return result
indices, _, _, inserted_strings = zip(*sorted([
indices, _, _, _, inserted_strings = zip(*sorted([
(
span[flag],
item[0][flag],
-flag,
-span[1 - flag],
str_pair[flag]
-item[0][1 - flag],
(1, -1)[flag] * item_index,
item[1][flag]
)
for span, str_pair in inserted_string_pairs
for item_index, item in enumerate(inserted_string_pairs)
for flag in range(2)
]))
result.update({
@ -295,22 +278,6 @@ class LabelledString(SVGMobject, ABC):
repl_strs.append("")
return "".join(it.chain(*zip(pieces, repl_strs)))
@staticmethod
def rslide(index: int, skipped: list[Span]) -> int:
transfer_dict = dict(sorted(skipped))
while index in transfer_dict.keys():
index = transfer_dict[index]
return index
@staticmethod
def lslide(index: int, skipped: list[Span]) -> int:
transfer_dict = dict(sorted([
skipped_span[::-1] for skipped_span in skipped
], reverse=True))
while index in transfer_dict.keys():
index = transfer_dict[index]
return index
@staticmethod
def color_to_int(color: ManimColor) -> int:
hex_code = rgb_to_hex(color_to_rgb(color))
@ -323,80 +290,63 @@ class LabelledString(SVGMobject, ABC):
# Parsing
@abstractmethod
def get_command_repl_items(self) -> list[tuple[Span, str]]:
def get_skippable_indices(self) -> list[int]:
return []
def get_command_spans(self) -> list[Span]:
return [cmd_span for cmd_span, _ in self.command_repl_items]
@staticmethod
def shrink_span(span: Span, skippable_indices: list[int]) -> Span:
span_begin, span_end = span
while span_begin in skippable_indices:
span_begin += 1
while span_end - 1 in skippable_indices:
span_end -= 1
return (span_begin, span_end)
@abstractmethod
def get_extra_entity_spans(self) -> list[Span]:
return []
def get_entity_spans(self) -> list[Span]:
return list(it.chain(
self.command_spans,
self.extra_entity_spans
return []
@abstractmethod
def get_bracket_spans(self) -> list[Span]:
return []
@abstractmethod
def get_extra_isolated_items(self) -> list[tuple[Span, dict[str, str]]]:
return []
def get_specified_items(self) -> list[tuple[Span, dict[str, str]]]:
span_items = list(it.chain(
self.extra_isolated_items,
[
(span, {})
for span in self.find_spans_by_selector(self.isolate)
]
))
def is_splittable_index(self, index: int) -> bool:
return not any([
entity_span[0] < index < entity_span[1]
for entity_span in self.entity_spans
])
@abstractmethod
def get_extra_ignored_spans(self) -> list[int]:
return []
def get_skipped_spans(self) -> list[Span]:
return list(it.chain(
self.find_spans(r"\s"),
self.command_spans,
self.extra_ignored_spans
))
def shrink_span(self, span: Span) -> Span:
return (
self.rslide(span[0], self.skipped_spans),
self.lslide(span[1], self.skipped_spans)
)
@abstractmethod
def get_internal_specified_spans(self) -> list[Span]:
return []
@abstractmethod
def get_external_specified_spans(self) -> list[Span]:
return []
result = []
for span, attr_dict in span_items:
shrinked_span = self.shrink_span(span, self.skippable_indices)
if shrinked_span[0] >= shrinked_span[1]:
continue
if any([
entity_span[0] < index < entity_span[1]
for index in shrinked_span
for entity_span in self.entity_spans
]):
continue
result.append((shrinked_span, attr_dict))
return result
def get_specified_spans(self) -> list[Span]:
spans = list(it.chain(
self.internal_specified_spans,
self.external_specified_spans,
self.find_spans_by_selector(self.isolate)
))
filtered_spans = list(filter(
lambda span: all([
self.is_splittable_index(index)
for index in span
]),
spans
))
shrinked_spans = list(filter(
lambda span: span[0] < span[1],
[self.shrink_span(span) for span in filtered_spans]
))
return remove_list_redundancies(shrinked_spans)
@abstractmethod
def get_label_span_list(self) -> list[Span]:
return []
return remove_list_redundancies([
span for span, _ in self.specified_items
])
def check_overlapping(self) -> None:
if len(self.label_span_list) >= 16777216:
raise ValueError("Cannot label that many substrings")
for span_0, span_1 in it.product(self.label_span_list, repeat=2):
spans = remove_list_redundancies(list(it.chain(
self.specified_spans,
self.bracket_spans
)))
for span_0, span_1 in it.product(spans, repeat=2):
if not span_0[0] < span_1[0] < span_0[1] < span_1[1]:
continue
raise ValueError(
@ -404,23 +354,21 @@ class LabelledString(SVGMobject, ABC):
f"'{self.get_substr(span_0)}' and '{self.get_substr(span_1)}'"
)
@abstractmethod
def get_label_span_list(self) -> list[Span]:
return []
@abstractmethod
def get_content(self, is_labelled: bool) -> str:
return ""
# Post-parsing
@abstractmethod
def get_cleaned_substr(self, span: Span) -> str:
span_repl_dict = dict.fromkeys(self.command_spans, "")
return self.get_replaced_substr(span, span_repl_dict)
return ""
def get_specified_substrs(self) -> list[str]:
return remove_list_redundancies([
self.get_cleaned_substr(span)
for span in self.specified_spans
])
def get_group_items(self) -> list[tuple[str, VGroup]]:
def get_group_part_items(self) -> list[tuple[str, VGroup]]:
if not self.labelled_submobject_items:
return []
@ -445,41 +393,33 @@ class LabelledString(SVGMobject, ABC):
ordered_spans
)
]
shrinked_spans = [
self.shrink_span(span)
group_substrs = [
self.get_cleaned_substr(span) if span[0] < span[1] else ""
for span in self.get_complement_spans(
interval_spans, (ordered_spans[0][0], ordered_spans[-1][1])
)
]
group_substrs = [
self.get_cleaned_substr(span) if span[0] < span[1] else ""
for span in shrinked_spans
]
submob_groups = VGroup(*[
VGroup(*labelled_submobjects[slice(*submob_span)])
for submob_span in labelled_submob_spans
])
return list(zip(group_substrs, submob_groups))
def get_group_substrs(self) -> list[str]:
return [group_substr for group_substr, _ in self.group_items]
def get_submob_groups(self) -> list[VGroup]:
return [submob_group for _, submob_group in self.group_items]
def select_parts_by_group_substr(self, substr: str) -> VGroup:
return VGroup(*[
group
for group_substr, group in self.group_items
if group_substr == substr
])
def get_specified_part_items(self) -> list[tuple[str, VGroup]]:
return [
(
self.get_substr(span),
self.select_part_by_span(span, substring=False)
)
for span in self.specified_spans
]
# Selector
def find_span_components(
self, custom_span: Span, substring: bool = True
) -> list[Span]:
shrinked_span = self.shrink_span(custom_span)
shrinked_span = self.shrink_span(custom_span, self.skippable_indices)
if shrinked_span[0] >= shrinked_span[1]:
return []
@ -488,12 +428,12 @@ class LabelledString(SVGMobject, ABC):
self.full_span,
*self.label_span_list
)))
span_begin = self.take_nearest_value(
indices, shrinked_span[0], 0
)
span_end = self.take_nearest_value(
indices, shrinked_span[1] - 1, 1
)
span_begin = max(filter(
lambda index: index <= shrinked_span[0], indices
))
span_end = min(filter(
lambda index: index >= shrinked_span[1], indices
))
else:
span_begin, span_end = shrinked_span

View file

@ -33,6 +33,15 @@ if TYPE_CHECKING:
SCALE_FACTOR_PER_FONT_POINT = 0.001
TEX_COLOR_COMMANDS_DICT = {
"\\color": (1, False),
"\\textcolor": (1, False),
"\\pagecolor": (1, True),
"\\colorbox": (1, True),
"\\fcolorbox": (2, True),
}
class MTex(LabelledString):
CONFIG = {
"font_size": 48,
@ -78,10 +87,12 @@ class MTex(LabelledString):
def pre_parse(self) -> None:
super().pre_parse()
self.backslash_indices = self.get_backslash_indices()
self.brace_index_pairs = self.get_brace_index_pairs()
self.script_char_spans = self.get_script_char_spans()
self.command_spans = self.get_command_spans()
self.brace_spans = self.get_brace_spans()
self.script_char_indices = self.get_script_char_indices()
self.script_content_spans = self.get_script_content_spans()
self.script_spans = self.get_script_spans()
self.command_repl_items = self.get_command_repl_items()
# Toolkits
@ -95,61 +106,61 @@ class MTex(LabelledString):
def get_backslash_indices(self) -> list[int]:
# The latter of `\\` doesn't count.
return list(it.chain(*[
range(span[0], span[1], 2)
for span in self.find_spans(r"\\+")
]))
return self.find_indices(r"\\.")
def get_unescaped_char_spans(self, chars: str):
return sorted(filter(
lambda span: span[0] - 1 not in self.backslash_indices,
list(it.chain(*[
self.find_spans(re.escape(char))
for char in chars
]))
def get_command_spans(self) -> list[Span]:
return [
self.match(r"\\(?:[a-zA-Z]+|.)", pos=index).span()
for index in self.backslash_indices
]
def get_unescaped_char_indices(self, char: str) -> list[int]:
return list(filter(
lambda index: index - 1 not in self.backslash_indices,
self.find_indices(re.escape(char))
))
def get_brace_index_pairs(self) -> list[Span]:
left_brace_indices = []
right_brace_indices = []
left_brace_indices_stack = []
for span in self.get_unescaped_char_spans("{}"):
index = span[0]
if self.get_substr(span) == "{":
left_brace_indices_stack.append(index)
def get_brace_spans(self) -> list[Span]:
span_begins = []
span_ends = []
span_begins_stack = []
char_items = sorted([
(index, char)
for char in "{}"
for index in self.get_unescaped_char_indices(char)
])
for index, char in char_items:
if char == "{":
span_begins_stack.append(index)
else:
if not left_brace_indices_stack:
if not span_begins_stack:
raise ValueError("Missing '{' inserted")
left_brace_index = left_brace_indices_stack.pop()
left_brace_indices.append(left_brace_index)
right_brace_indices.append(index)
if left_brace_indices_stack:
span_begins.append(span_begins_stack.pop())
span_ends.append(index + 1)
if span_begins_stack:
raise ValueError("Missing '}' inserted")
return list(zip(left_brace_indices, right_brace_indices))
return list(zip(span_begins, span_ends))
def get_script_char_spans(self) -> list[int]:
return self.get_unescaped_char_spans("_^")
def get_script_char_indices(self) -> list[int]:
return list(it.chain(*[
self.get_unescaped_char_indices(char)
for char in "_^"
]))
def get_script_content_spans(self) -> list[Span]:
result = []
brace_indices_dict = dict(self.brace_index_pairs)
script_pattern = r"[a-zA-Z0-9]|\\[a-zA-Z]+"
for char_span in self.script_char_spans:
span_begin = self.rslide(char_span[1], self.space_spans)
if span_begin in brace_indices_dict.keys():
span_end = brace_indices_dict[span_begin] + 1
script_entity_dict = dict(it.chain(
self.brace_spans,
self.command_spans
))
for index in self.script_char_indices:
span_begin = self.match(r"\s*", pos=index + 1).end()
if span_begin in script_entity_dict.keys():
span_end = script_entity_dict[span_begin]
else:
match_obj = self.match_at(script_pattern, span_begin)
match_obj = self.match(r".", pos=span_begin)
if match_obj is None:
script_name = {
"_": "subscript",
"^": "superscript"
}[self.get_string(char_span)]
raise ValueError(
f"Unclear {script_name} detected while parsing "
f"(position {char_span[0]}). "
"Please use braces to clarify"
)
continue
span_end = match_obj.end()
result.append((span_begin, span_end))
return result
@ -157,46 +168,29 @@ class MTex(LabelledString):
def get_script_spans(self) -> list[Span]:
return [
(
self.lslide(char_span[0], self.space_spans),
self.match(r"[\s\S]*?(\s*)$", endpos=index).start(1),
script_content_span[1]
)
for char_span, script_content_span in zip(
self.script_char_spans, self.script_content_spans
for index, script_content_span in zip(
self.script_char_indices, self.script_content_spans
)
]
# Parsing
def get_command_repl_items(self) -> list[tuple[Span, str]]:
color_related_command_dict = {
"color": (1, False),
"textcolor": (1, False),
"pagecolor": (1, True),
"colorbox": (1, True),
"fcolorbox": (2, True),
}
result = []
backslash_indices = self.backslash_indices
right_brace_indices = [
right_index
for left_index, right_index in self.brace_index_pairs
]
pattern = "".join([
r"\\",
"(",
"|".join(color_related_command_dict.keys()),
")",
r"(?![a-zA-Z])"
])
for match_obj in re.finditer(pattern, self.string):
span_begin, cmd_end = match_obj.span()
if span_begin not in backslash_indices:
brace_spans_dict = dict(self.brace_spans)
brace_begins = list(brace_spans_dict.keys())
for cmd_span in self.command_spans:
cmd_name = self.get_substr(cmd_span)
if cmd_name not in TEX_COLOR_COMMANDS_DICT.keys():
continue
cmd_name = match_obj.group(1)
n_braces, substitute_cmd = color_related_command_dict[cmd_name]
span_end = self.take_nearest_value(
right_brace_indices, cmd_end, n_braces
) + 1
n_braces, substitute_cmd = TEX_COLOR_COMMANDS_DICT[cmd_name]
span_begin, span_end = cmd_span
for _ in n_braces:
span_end = brace_spans_dict[min(filter(
lambda index: index >= span_end,
brace_begins
))]
if substitute_cmd:
repl_str = "\\" + cmd_name + n_braces * "{black}"
else:
@ -204,51 +198,60 @@ class MTex(LabelledString):
result.append(((span_begin, span_end), repl_str))
return result
def get_extra_entity_spans(self) -> list[Span]:
return [
self.match_at(r"\\([a-zA-Z]+|.?)", index).span()
for index in self.backslash_indices
]
# Parsing
def get_extra_ignored_spans(self) -> list[int]:
return self.script_char_spans.copy()
def get_skippable_indices(self) -> list[int]:
return list(it.chain(
self.find_indices(r"\s"),
self.script_char_indices
))
def get_internal_specified_spans(self) -> list[Span]:
# Match paired double braces (`{{...}}`).
def get_entity_spans(self) -> list[Span]:
return self.command_spans.copy()
def get_bracket_spans(self) -> list[Span]:
return self.brace_spans.copy()
def get_extra_isolated_items(self) -> list[tuple[Span, dict[str, str]]]:
result = []
reversed_brace_indices_dict = dict([
pair[::-1] for pair in self.brace_index_pairs
])
# Match paired double braces (`{{...}}`).
reversed_brace_spans_dict = dict(sorted([
pair[::-1] for pair in self.brace_spans
]))
skip = False
for prev_right_index, right_index in self.get_neighbouring_pairs(
list(reversed_brace_indices_dict.keys())
for prev_brace_end, brace_end in self.get_neighbouring_pairs(
list(reversed_brace_spans_dict.keys())
):
if skip:
skip = False
continue
if right_index != prev_right_index + 1:
if brace_end != prev_brace_end + 1:
continue
left_index = reversed_brace_indices_dict[right_index]
prev_left_index = reversed_brace_indices_dict[prev_right_index]
if left_index != prev_left_index - 1:
brace_begin = reversed_brace_spans_dict[brace_end]
prev_brace_begin = reversed_brace_spans_dict[prev_brace_end]
if brace_begin != prev_brace_begin - 1:
continue
result.append((left_index, right_index + 1))
result.append((brace_begin, brace_end))
skip = True
return result
def get_external_specified_spans(self) -> list[Span]:
return list(it.chain(*[
result.extend(it.chain(*[
self.find_spans_by_selector(selector)
for selector in self.tex_to_color_map.keys()
]))
return [(span, {}) for span in result]
def get_label_span_list(self) -> list[Span]:
result = self.script_content_spans.copy()
reversed_script_spans_dict = dict([
script_span[::-1] for script_span in self.script_spans
])
for span_begin, span_end in self.specified_spans:
shrinked_end = self.lslide(span_end, self.script_spans)
if span_begin >= shrinked_end:
while span_end in reversed_script_spans_dict.keys():
span_end = reversed_script_spans_dict[span_end]
if span_begin >= span_end:
continue
shrinked_span = (span_begin, shrinked_end)
shrinked_span = (span_begin, span_end)
if shrinked_span in result:
continue
result.append(shrinked_span)
@ -256,12 +259,15 @@ class MTex(LabelledString):
def get_content(self, is_labelled: bool) -> str:
if is_labelled:
extended_label_span_list = [
span
if span in self.script_content_spans
else (span[0], self.rslide(span[1], self.script_spans))
for span in self.label_span_list
]
extended_label_span_list = []
script_spans_dict = dict(self.script_spans)
for span in self.label_span_list:
if span not in self.script_content_spans:
span_begin, span_end = span
while span_end in script_spans_dict.keys():
span_end = script_spans_dict[span_end]
span = (span_begin, span_end)
extended_label_span_list.append(span)
inserted_string_pairs = [
(span, (
"{{" + self.get_color_command_str(label + 1),
@ -270,8 +276,7 @@ class MTex(LabelledString):
for label, span in enumerate(extended_label_span_list)
]
span_repl_dict = self.generate_span_repl_dict(
inserted_string_pairs,
self.command_repl_items
inserted_string_pairs, self.command_repl_items
)
else:
span_repl_dict = {}
@ -296,15 +301,26 @@ class MTex(LabelledString):
# Post-parsing
def get_cleaned_substr(self, span: Span) -> str:
substr = super().get_cleaned_substr(span)
if not self.brace_index_pairs:
return substr
if not self.brace_spans:
brace_begins, brace_ends = [], []
else:
brace_begins, brace_ends = zip(*self.brace_spans)
left_brace_indices = list(brace_begins)
right_brace_indices = [index - 1 for index in brace_ends]
skippable_indices = list(it.chain(
self.skippable_indices,
left_brace_indices,
right_brace_indices
))
shrinked_span = self.shrink_span(span, skippable_indices)
if shrinked_span[0] >= shrinked_span[1]:
return ""
# Balance braces.
left_brace_indices, right_brace_indices = zip(*self.brace_index_pairs)
unclosed_left_braces = 0
unclosed_right_braces = 0
for index in range(*span):
for index in range(*shrinked_span):
if index in left_brace_indices:
unclosed_left_braces += 1
elif index in right_brace_indices:
@ -314,7 +330,7 @@ class MTex(LabelledString):
unclosed_left_braces -= 1
return "".join([
unclosed_right_braces * "{",
substr,
self.get_substr(shrinked_span),
unclosed_left_braces * "}"
])

View file

@ -20,7 +20,6 @@ from manimlib.utils.config_ops import digest_config
from manimlib.utils.customization import get_customization
from manimlib.utils.directories import get_downloads_dir
from manimlib.utils.directories import get_text_dir
from manimlib.utils.iterables import remove_list_redundancies
from manimlib.utils.tex_file_writing import tex_hash
from typing import TYPE_CHECKING
@ -244,11 +243,9 @@ class MarkupText(LabelledString):
def pre_parse(self) -> None:
super().pre_parse()
self.tag_items_from_markup = self.get_tag_items_from_markup()
self.global_dict_from_config = self.get_global_dict_from_config()
self.local_dicts_from_markup = self.get_local_dicts_from_markup()
self.local_dicts_from_config = self.get_local_dicts_from_config()
self.predefined_attr_dicts = self.get_predefined_attr_dicts()
self.tag_pairs_from_markup = self.get_tag_pairs_from_markup()
self.tag_spans = self.get_tag_spans()
self.items_from_markup = self.get_items_from_markup()
# Toolkits
@ -259,42 +256,9 @@ class MarkupText(LabelledString):
for key, val in attr_dict.items()
])
@staticmethod
def merge_attr_dicts(
attr_dict_items: list[tuple[Span, dict[str, str]]]
) -> list[tuple[Span, dict[str, str]]]:
index_seq = [0]
attr_dict_list = [{}]
for span, attr_dict in attr_dict_items:
if span[0] >= span[1]:
continue
region_indices = [
MarkupText.find_region_index(index_seq, index)
for index in span
]
for flag in (1, 0):
if index_seq[region_indices[flag]] == span[flag]:
continue
region_index = region_indices[flag]
index_seq.insert(region_index + 1, span[flag])
attr_dict_list.insert(
region_index + 1, attr_dict_list[region_index].copy()
)
region_indices[flag] += 1
if flag == 0:
region_indices[1] += 1
for key, val in attr_dict.items():
if not key:
continue
for mid_dict in attr_dict_list[slice(*region_indices)]:
mid_dict[key] = val
return list(zip(
MarkupText.get_neighbouring_pairs(index_seq), attr_dict_list[:-1]
))
# Pre-parsing
def get_tag_items_from_markup(
def get_tag_pairs_from_markup(
self
) -> list[tuple[Span, Span, dict[str, str]]]:
if not self.is_markup:
@ -342,52 +306,64 @@ class MarkupText(LabelledString):
)
return result
def get_global_dict_from_config(self) -> dict[str, str]:
result = {
"line_height": str((
(self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1
) * 0.6),
"font_family": self.font,
"font_size": str(self.font_size * 1024),
"font_style": self.slant,
"font_weight": self.weight
}
result.update(self.global_config)
return result
def get_local_dicts_from_markup(
self
) -> list[Span, dict[str, str]]:
return sorted([
((begin_tag_span[0], end_tag_span[1]), attr_dict)
for begin_tag_span, end_tag_span, attr_dict
in self.tag_items_from_markup
])
def get_local_dicts_from_config(
self
) -> list[Span, dict[str, str]]:
def get_tag_spans(self) -> list[Span]:
return [
(span, {key: val})
for t2x_dict, key in (
(self.t2c, "foreground"),
(self.t2f, "font_family"),
(self.t2s, "font_style"),
(self.t2w, "font_weight")
)
for selector, val in t2x_dict.items()
for span in self.find_spans_by_selector(selector)
] + [
(span, local_config)
for selector, local_config in self.local_configs.items()
for span in self.find_spans_by_selector(selector)
tag_span
for begin_tag, end_tag, _ in self.tag_pairs_from_markup
for tag_span in (begin_tag, end_tag)
]
def get_predefined_attr_dicts(self) -> list[Span, dict[str, str]]:
attr_dict_items = [
(self.full_span, self.global_dict_from_config),
*self.local_dicts_from_markup,
*self.local_dicts_from_config
def get_items_from_markup(self) -> list[Span]:
return [
((begin_tag_span[0], end_tag_span[1]), attr_dict)
for begin_tag_span, end_tag_span, attr_dict
in self.tag_pairs_from_markup
]
# Parsing
def get_skippable_indices(self) -> list[int]:
return self.find_indices(r"\s")
def get_entity_spans(self) -> list[Span]:
result = self.tag_spans.copy()
if self.is_markup:
result.extend(self.find_spans(r"&[\s\S]*?;"))
return result
def get_bracket_spans(self) -> list[Span]:
return [span for span, _ in self.items_from_markup]
def get_extra_isolated_items(self) -> list[tuple[Span, dict[str, str]]]:
result = [
(self.full_span, {
"line_height": str((
(self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1
) * 0.6),
"font_family": self.font,
"font_size": str(self.font_size * 1024),
"font_style": self.slant,
"font_weight": self.weight,
"foreground": self.int_to_hex(self.base_color_int)
}),
(self.full_span, self.global_config),
*self.items_from_markup,
*[
(span, {key: val})
for t2x_dict, key in (
(self.t2c, "foreground"),
(self.t2f, "font_family"),
(self.t2s, "font_style"),
(self.t2w, "font_weight")
)
for selector, val in t2x_dict.items()
for span in self.find_spans_by_selector(selector)
],
*[
(span, local_config)
for selector, local_config in self.local_configs.items()
for span in self.find_spans_by_selector(selector)
]
]
key_conversion_dict = {
key: key_alias_list[0]
@ -399,19 +375,63 @@ class MarkupText(LabelledString):
key_conversion_dict[key.lower()]: val
for key, val in attr_dict.items()
})
for span, attr_dict in result
]
def get_label_span_list(self) -> list[Span]:
interval_spans = sorted(it.chain(
self.tag_spans,
[
(index, index)
for span in self.specified_spans
for index in span
]
))
text_spans = self.get_complement_spans(interval_spans, self.full_span)
if self.is_markup:
pattern = r"[0-9a-zA-Z]+|(?:&[\s\S]*?;|[^0-9a-zA-Z\s])+"
else:
pattern = r"[0-9a-zA-Z]+|[^0-9a-zA-Z\s]+"
return list(it.chain(*[
self.find_spans(pattern, pos=span_begin, endpos=span_end)
for span_begin, span_end in text_spans
]))
def get_content(self, is_labelled: bool) -> str:
if is_labelled:
attr_dict_items = list(it.chain(
[
(span, {
key: BLACK if key in MARKUP_COLOR_KEYS else val
for key, val in attr_dict.items()
})
for span, attr_dict in self.specified_items
],
[
(span, {"foreground": self.int_to_hex(label + 1)})
for label, span in enumerate(self.label_span_list)
]
))
else:
attr_dict_items = list(it.chain(
self.specified_items,
[
(span, {})
for span in self.label_span_list
]
))
inserted_string_pairs = [
(span, (
f"<span {self.get_attr_dict_str(attr_dict)}>",
"</span>"
))
for span, attr_dict in attr_dict_items
]
# Parsing
def get_command_repl_items(self) -> list[tuple[Span, str]]:
result = [
(tag_span, "")
for begin_tag, end_tag, _ in self.tag_items_from_markup
for tag_span in (begin_tag, end_tag)
repl_items = [
(tag_span, "") for tag_span in self.tag_spans
]
if not self.is_markup:
result += [
repl_items.extend([
(span, escaped)
for char, escaped in (
("&", "&amp;"),
@ -419,83 +439,18 @@ class MarkupText(LabelledString):
("<", "&lt;")
)
for span in self.find_spans(re.escape(char))
]
return result
def get_extra_entity_spans(self) -> list[Span]:
if not self.is_markup:
return []
return self.find_spans(r"&.*?;")
def get_extra_ignored_spans(self) -> list[int]:
return []
def get_internal_specified_spans(self) -> list[Span]:
return []
def get_external_specified_spans(self) -> list[Span]:
return [span for span, _ in self.local_dicts_from_config]
def get_label_span_list(self) -> list[Span]:
breakup_indices = remove_list_redundancies(list(it.chain(*it.chain(
self.find_spans(r"\b"),
self.space_spans,
self.specified_spans
))))
breakup_indices = sorted(filter(
self.is_splittable_index, breakup_indices
))
return list(filter(
lambda span: self.get_substr(span).strip(),
self.get_neighbouring_pairs(breakup_indices)
))
def get_content(self, is_labelled: bool) -> str:
filtered_attr_dicts = list(filter(
lambda item: all([
self.is_splittable_index(index)
for index in item[0]
]),
self.predefined_attr_dicts
))
if is_labelled:
attr_dict_items = [
(self.full_span, {"foreground": BLACK}),
*[
(span, {
key: BLACK if key in MARKUP_COLOR_KEYS else val
for key, val in attr_dict.items()
})
for span, attr_dict in filtered_attr_dicts
],
*[
(span, {"foreground": self.int_to_hex(label + 1)})
for label, span in enumerate(self.label_span_list)
]
]
else:
attr_dict_items = [
(self.full_span, {
"foreground": self.int_to_hex(self.base_color_int)
}),
*filtered_attr_dicts,
*[
(span, {})
for span in self.label_span_list
]
]
inserted_string_pairs = [
(span, (
f"<span {self.get_attr_dict_str(attr_dict)}>",
"</span>"
))
for span, attr_dict in self.merge_attr_dicts(attr_dict_items)
]
])
span_repl_dict = self.generate_span_repl_dict(
inserted_string_pairs, self.command_repl_items
inserted_string_pairs, repl_items
)
return self.get_replaced_substr(self.full_span, span_repl_dict)
# Post-parsing
def get_cleaned_substr(self, span: Span) -> str:
repl_dict = dict.fromkeys(self.tag_spans, "")
return self.get_replaced_substr(span, repl_dict).strip()
# Method alias
def get_parts_by_text(self, selector: Selector, **kwargs) -> VGroup: