Refactor StringMobject and relevant classes

This commit is contained in:
YishiMichael 2022-08-07 11:29:31 +08:00
parent f434eb93e2
commit 28e4240475
No known key found for this signature in database
GPG key ID: EC615C0C5A86BC80
3 changed files with 153 additions and 124 deletions

View file

@ -106,16 +106,16 @@ class MTex(StringMobject):
@staticmethod
def get_internal_specified_items(
cmd_match_pairs: list[tuple[re.Match, re.Match]]
command_match_pairs: list[tuple[re.Match, re.Match]]
) -> list[tuple[Span, dict[str, str]]]:
cmd_content_spans = [
command_content_spans = [
(start_match.end(), end_match.start())
for start_match, end_match in cmd_match_pairs
for start_match, end_match in command_match_pairs
]
return [
(span, {})
for span, next_span
in MTex.get_neighbouring_pairs(cmd_content_spans)
in MTex.get_neighbouring_pairs(command_content_spans)
if span[0] == next_span[0] + 1 and span[1] == next_span[1] - 1
]
@ -129,7 +129,7 @@ class MTex(StringMobject):
]
@staticmethod
def get_color_cmd_str(rgb_hex: str) -> str:
def get_color_command(rgb_hex: str) -> str:
rgb = MTex.hex_to_int(rgb_hex)
rg, b = divmod(rgb, 256)
r, g = divmod(rg, 256)
@ -143,7 +143,7 @@ class MTex(StringMobject):
return ""
if is_end:
return "}}"
return "{{" + MTex.get_color_cmd_str(label_hex)
return "{{" + MTex.get_color_command(label_hex)
def get_content_prefix_and_suffix(
self, is_labelled: bool
@ -151,7 +151,9 @@ class MTex(StringMobject):
prefix_lines = []
suffix_lines = []
if not is_labelled:
prefix_lines.append(self.get_color_cmd_str(self.base_color_hex))
prefix_lines.append(self.get_color_command(
self.color_to_hex(self.base_color)
))
if self.alignment:
prefix_lines.append(self.alignment)
if self.tex_environment:

View file

@ -14,7 +14,7 @@ from manimlib.utils.color import color_to_rgb
from manimlib.utils.color import rgb_to_hex
from manimlib.utils.config_ops import digest_config
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable
if TYPE_CHECKING:
from colour import Color
@ -71,9 +71,8 @@ class StringMobject(SVGMobject, ABC):
digest_config(self, kwargs)
if self.base_color is None:
self.base_color = WHITE
self.base_color_hex = self.color_to_hex(self.base_color)
#self.base_color_hex = self.color_to_hex(self.base_color)
self.full_len = len(self.string)
self.parse()
super().__init__(**kwargs)
self.labels = [submob.label for submob in self.submobjects]
@ -169,7 +168,7 @@ class StringMobject(SVGMobject, ABC):
isinstance(index, int) or index is None
for index in sel
):
l = self.full_len
l = len(self.string)
span = tuple(
default_index if index is None else
min(index, l) if index >= 0 else max(index + l, 0)
@ -188,35 +187,15 @@ class StringMobject(SVGMobject, ABC):
result.extend(spans)
return list(filter(lambda span: span[0] < span[1], result))
def get_substr(self, span: Span) -> str:
return self.string[slice(*span)]
@staticmethod
def get_neighbouring_pairs(vals: Iterable[T]) -> list[tuple[T, T]]:
val_list = list(vals)
return list(zip(val_list[:-1], val_list[1:]))
@staticmethod
def join_strs(strs, inserted_strs):
return "".join(it.chain(*zip(strs, (*inserted_strs, ""))))
@staticmethod
def span_contains(span_0: Span, span_1: Span) -> bool:
return span_0[0] <= span_1[0] and span_0[1] >= span_1[1]
@staticmethod
def get_complement_spans(
universal_span: Span, interval_spans: list[Span]
) -> list[Span]:
if not interval_spans:
return [universal_span]
span_ends, span_starts = zip(*interval_spans)
return list(zip(
(universal_span[0], *span_starts),
(*span_ends, universal_span[1])
))
@staticmethod
def color_to_hex(color: ManimColor) -> str:
return rgb_to_hex(color_to_rgb(color))
@ -232,6 +211,23 @@ class StringMobject(SVGMobject, ABC):
# Parsing
def parse(self) -> None:
def get_substr(span: Span) -> str:
return self.string[slice(*span)]
def get_complement_spans(
universal_span: Span, interval_spans: list[Span]
) -> list[Span]:
if not interval_spans:
return [universal_span]
span_ends, span_starts = zip(*interval_spans)
return list(zip(
(universal_span[0], *span_starts),
(*span_ends, universal_span[1])
))
full_len = len(self.string)
command_matches = list(re.finditer(
self.get_command_pattern(), self.string, re.X | re.S
))
@ -251,11 +247,11 @@ class StringMobject(SVGMobject, ABC):
]
]
command_spans = [match_obj.span() for match_obj in command_matches]
region_spans = self.get_complement_spans(
(0, self.full_len), command_spans
region_spans = get_complement_spans(
(0, full_len), command_spans
)
def get_region_index(index):
def get_region_index(index: int) -> int:
for region_index, (start, end) in enumerate(region_spans):
if start <= index <= end:
return region_index
@ -273,7 +269,7 @@ class StringMobject(SVGMobject, ABC):
]):
log.warning(
"Cannot handle substring '%s', ignored",
self.get_substr(span)
get_substr(span)
)
continue
overlapped_spans = [
@ -285,14 +281,14 @@ class StringMobject(SVGMobject, ABC):
if overlapped_spans:
log.warning(
"Substring '%s' partly overlaps with '%s', ignored",
self.get_substr(span),
self.get_substr(overlapped_spans[0])
get_substr(span),
get_substr(overlapped_spans[0])
)
continue
labelled_spans.append(span)
attr_dicts.append(attr_dict)
insertion_items = [
inserted_items = [
label_flag_pair
for _, label_flag_pair in sorted(it.chain(*(
sorted([
@ -302,36 +298,127 @@ class StringMobject(SVGMobject, ABC):
for flag in (-1, 1)
)), key=lambda t: t[0][0])
]
insertion_interval_items = [
#inserted_indices = [0, *(
# labelled_spans[label][flag < 0]
# for label, flag in inserted_items
#), full_len]
inserted_interval_items = [
tuple(zip(*pair))
for pair in self.get_neighbouring_pairs([
(index, get_region_index(index))
for index in [0, *(
labelled_spans[label][flag < 0]
for label, flag in insertion_items
), self.full_len]
for label, flag in inserted_items
), full_len]
])
]
def get_replaced_pieces(replace_func):
def join_strs(strs: list[str], inserted_strs: list[str]) -> str:
return "".join(it.chain(*zip(strs, (*inserted_strs, ""))))
def get_replaced_pieces(replace_func: Callable[[re.Match], str]) -> list[str]:
return [
self.join_strs([
self.get_substr(s)
for s in self.get_complement_spans(
join_strs([
get_substr(s)
for s in get_complement_spans(
span, command_spans[slice(*region_range)]
)
], [
replace_func(command_match)
for command_match in command_matches[slice(*region_range)]
])
for span, region_range in insertion_interval_items
for span, region_range in inserted_interval_items
]
content_pieces = get_replaced_pieces(self.replace_for_content)
matching_pieces = get_replaced_pieces(self.replace_for_matching)
def get_content(is_labelled: bool) -> str:
inserted_strings = [
self.get_command_string(
attr_dicts[label],
is_end=flag < 0,
label_hex=self.int_to_hex(label + 1) if is_labelled else None
)
for label, flag in inserted_items
]
prefix, suffix = self.get_content_prefix_and_suffix(
is_labelled=is_labelled
)
return "".join([
prefix,
join_strs(content_pieces, inserted_strings),
suffix
])
def get_group_substrs(group_labels: list[int]) -> list[str]:
if not group_labels:
return []
def get_index(label, flag):
if label == -1:
return 0 if flag == 1 else len(inserted_items) + 1
return inserted_items.index((label, flag)) + 1
def get_labelled_span(label):
if label == -1:
return (0, full_len)
return labelled_spans[label]
def label_contains(label_0, label_1):
return self.span_contains(
get_labelled_span(label_0), get_labelled_span(label_1)
)
#piece_starts = [
# get_index(group_labels[0], 1),
# *(
# get_index(curr_label, 1)
# if label_contains(prev_label, curr_label)
# else get_index(prev_label, -1)
# for prev_label, curr_label in self.get_neighbouring_pairs(
# group_labels
# )
# )
#]
#piece_ends = [
# *(
# get_index(curr_label, -1)
# if label_contains(next_label, curr_label)
# else get_index(next_label, 1)
# for curr_label, next_label in self.get_neighbouring_pairs(
# group_labels
# )
# ),
# get_index(group_labels[-1], -1)
#]
piece_ranges = get_complement_spans(
(get_index(group_labels[0], 1), get_index(group_labels[-1], -1)),
[
(
get_index(next_label, 1)
if label_contains(prev_label, next_label)
else get_index(prev_label, -1),
get_index(prev_label, -1)
if label_contains(next_label, prev_label)
else get_index(next_label, 1)
)
for prev_label, next_label in self.get_neighbouring_pairs(
group_labels
)
]
)
return [
re.sub(r"\s+", "", "".join(
matching_pieces[slice(*piece_ranges)]
))
for piece_ranges in piece_ranges
]
self.labelled_spans = labelled_spans
self.attr_dicts = attr_dicts
self.insertion_items = insertion_items
self.content_pieces = get_replaced_pieces(self.replace_for_content)
self.matching_pieces = get_replaced_pieces(self.replace_for_matching)
self.get_content = get_content
self.get_group_substrs = get_group_substrs
@staticmethod
@abstractmethod
@ -361,7 +448,9 @@ class StringMobject(SVGMobject, ABC):
return []
@abstractmethod
def get_external_specified_items(self) -> list[tuple[Span, dict[str, str]]]:
def get_external_specified_items(
self
) -> list[tuple[Span, dict[str, str]]]:
return []
@staticmethod
@ -382,83 +471,21 @@ class StringMobject(SVGMobject, ABC):
command_matches: list[re.Match], command_flags: list[int]
) -> list[tuple[re.Match, re.Match]]:
result = []
start_cmd_matches_stack = []
for cmd_match, command_flag in zip(command_matches, command_flags):
if command_flag == 1:
start_cmd_matches_stack.append(cmd_match)
elif command_flag == -1:
if not start_cmd_matches_stack:
open_stack = []
for command_match, flag in zip(command_matches, command_flags):
if flag == 1:
open_stack.append(command_match)
elif flag == -1:
if not open_stack:
raise ValueError("Missing open command")
start_cmd_match = start_cmd_matches_stack.pop()
open_command_match = open_stack.pop()
result.append(
(start_cmd_match, cmd_match)
(open_command_match, command_match)
)
if start_cmd_matches_stack:
if open_stack:
raise ValueError("Missing close command")
return result
def get_content(self, is_labelled: bool) -> str:
insertion_strings = [
self.get_command_string(
self.attr_dicts[label],
is_end=flag < 0,
label_hex=self.int_to_hex(label + 1) if is_labelled else None
)
for label, flag in self.insertion_items
]
prefix, suffix = self.get_content_prefix_and_suffix(
is_labelled=is_labelled
)
return "".join([
prefix,
self.join_strs(self.content_pieces, insertion_strings),
suffix
])
def get_group_substrs(self, group_labels: list[int]) -> list[str]:
if not group_labels:
return []
insertion_items = self.insertion_items
def get_index(label, flag):
if label == -1:
return 0 if flag == 1 else len(insertion_items) + 1
return insertion_items.index((label, flag)) + 1
def get_labelled_span(label):
if label == -1:
return (0, self.full_len)
return self.labelled_spans[label]
def label_contains(label_0, label_1):
return self.span_contains(
get_labelled_span(label_0), get_labelled_span(label_1)
)
piece_ranges = self.get_complement_spans(
(get_index(group_labels[0], 1), get_index(group_labels[-1], -1)),
[
(
get_index(next_label, 1)
if label_contains(prev_label, next_label)
else get_index(prev_label, -1),
get_index(prev_label, -1)
if label_contains(next_label, prev_label)
else get_index(next_label, 1)
)
for prev_label, next_label in self.get_neighbouring_pairs(
group_labels
)
]
)
return [
re.sub(r"\s+", "", "".join(
self.matching_pieces[slice(*piece_range)]
))
for piece_range in piece_ranges
]
# Selector
def get_submob_indices_list_by_span(
@ -475,7 +502,7 @@ class StringMobject(SVGMobject, ABC):
def get_specified_part_items(self) -> list[tuple[str, list[int]]]:
return [
(
self.get_substr(span),
self.string[slice(*span)],
self.get_submob_indices_list_by_span(span)
)
for span in self.labelled_spans

View file

@ -363,7 +363,7 @@ class MarkupText(StringMobject):
self, is_labelled: bool
) -> tuple[str, str]:
global_attr_dict = {
"foreground": self.base_color_hex,
"foreground": self.color_to_hex(self.base_color),
"font_family": self.font,
"font_style": self.slant,
"font_weight": self.weight,