Refactor StringMobject and relevant classes

This commit is contained in:
YishiMichael 2022-08-07 00:50:29 +08:00
parent 093af347aa
commit f434eb93e2
No known key found for this signature in database
GPG key ID: EC615C0C5A86BC80
3 changed files with 278 additions and 274 deletions

View file

@ -1,5 +1,7 @@
from __future__ import annotations
import itertools as it
from manimlib.mobject.svg.string_mobject import StringMobject
from manimlib.utils.tex_file_writing import display_during_execution
from manimlib.utils.tex_file_writing import tex_content_to_svg_file
@ -40,16 +42,6 @@ class MTex(StringMobject):
"additional_preamble": "",
}
#CMD_PATTERN = r"\\(?:[a-zA-Z]+|.)|[_^{}]"
#FLAG_DICT = {
# r"{": 1,
# r"}": -1
#}
#CONTENT_REPL = {}
#MATCH_REPL = {
# r"[_^{}]": ""
#}
def __init__(self, tex_string: str, **kwargs):
# Prevent passing an empty string.
if not tex_string.strip():
@ -86,17 +78,20 @@ class MTex(StringMobject):
# Parsing
@staticmethod
def get_cmd_pattern() -> str | None:
return r"(\\(?:[a-zA-Z]+|.))|([_^])|([{}])"
def get_command_pattern() -> str:
return r"""
(?P<command>\\(?:[a-zA-Z]+|.))
|(?P<script>[_^])
|(?P<open>{)
|(?P<close>})
"""
@staticmethod
def get_matched_flag(match_obj: re.Match) -> int:
substr = match_obj.group()
if match_obj.group(3):
if substr == "{":
return 1
if substr == "}":
return -1
def get_command_flag(match_obj: re.Match) -> int:
if match_obj.group("open"):
return 1
if match_obj.group("close"):
return -1
return 0
@staticmethod
@ -105,7 +100,7 @@ class MTex(StringMobject):
@staticmethod
def replace_for_matching(match_obj: re.Match) -> str:
if not match_obj.group(1):
if not match_obj.group("command"):
return ""
return match_obj.group()
@ -114,26 +109,15 @@ class MTex(StringMobject):
cmd_match_pairs: list[tuple[re.Match, re.Match]]
) -> list[tuple[Span, dict[str, str]]]:
cmd_content_spans = [
(begin_match.end(), end_match.start())
for begin_match, end_match in cmd_match_pairs
(start_match.end(), end_match.start())
for start_match, end_match in cmd_match_pairs
]
#print(MTex.get_neighbouring_pairs(cmd_content_spans))
return [
(span, {})
for span, next_span
in MTex.get_neighbouring_pairs(cmd_content_spans)
if span[0] == next_span[0] + 1 and span[1] == next_span[1] - 1
]
#return [
# (cmd_content_spans[range_begin], {})
# for _, (range_begin, range_end) in self.group_neighbours([
# (span_begin + index, span_end - index)
# for index, (span_begin, span_end) in enumerate(
# cmd_content_spans
# )
# ])
# if range_end - range_begin >= 2
#]
def get_external_specified_items(
self
@ -152,12 +136,14 @@ class MTex(StringMobject):
return f"\\color[RGB]{{{r}, {g}, {b}}}"
@staticmethod
def get_cmd_str_pair(
attr_dict: dict[str, str], label_hex: str | None
) -> tuple[str, str]:
def get_command_string(
attr_dict: dict[str, str], is_end: bool, label_hex: str | None
) -> str:
if label_hex is None:
return "", ""
return "{{" + MTex.get_color_cmd_str(label_hex), "}}"
return ""
if is_end:
return "}}"
return "{{" + MTex.get_color_cmd_str(label_hex)
def get_content_prefix_and_suffix(
self, is_labelled: bool

View file

@ -18,7 +18,7 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from colour import Color
from typing import Callable, Iterable, TypeVar, Union
from typing import Iterable, TypeVar, Union
ManimColor = Union[str, Color]
Span = tuple[int, int]
@ -47,7 +47,7 @@ class StringMobject(SVGMobject, ABC):
if they want to do anything with their corresponding submobjects.
`isolate` parameter can be either a string, a `re.Pattern` object,
or a 2-tuple containing integers or None, or a collection of the above.
Note, substrings specified cannot *partially* overlap with each other.
Note, substrings specified cannot *partly* overlap with each other.
Each instance of `StringMobject` generates 2 svg files.
The additional one is generated with some color commands inserted,
@ -186,7 +186,7 @@ class StringMobject(SVGMobject, ABC):
if spans is None:
raise TypeError(f"Invalid selector: '{sel}'")
result.extend(spans)
return result
return list(filter(lambda span: span[0] < span[1], result))
def get_substr(self, span: Span) -> str:
    """Return the part of ``self.string`` selected by ``span``.

    ``span`` is unpacked into ``slice(...)``, so a ``(start, stop)`` pair
    selects ``self.string[start:stop]``.
    """
    selection = slice(*span)
    return self.string[selection]
@ -196,6 +196,10 @@ class StringMobject(SVGMobject, ABC):
val_list = list(vals)
return list(zip(val_list[:-1], val_list[1:]))
@staticmethod
def join_strs(strs, inserted_strs):
    """Concatenate ``strs`` with ``inserted_strs`` placed between them.

    Each string of ``strs`` is followed by the string at the same position
    in ``inserted_strs``; an empty string is appended to ``inserted_strs``
    so that the last piece has a partner. Pairing stops at the shorter of
    the two sequences, exactly as ``zip`` does.
    """
    followers = (*inserted_strs, "")
    return "".join(
        piece + follower
        for piece, follower in zip(strs, followers)
    )
@staticmethod
def span_contains(span_0: Span, span_1: Span) -> bool:
    """Tell whether ``span_0`` covers ``span_1`` entirely.

    True when ``span_0`` starts no later than ``span_1`` and ends no
    earlier than it; equal spans contain each other.
    """
    if span_0[0] > span_1[0]:
        # span_1 starts before span_0 does.
        return False
    return span_1[1] <= span_0[1]
@ -207,9 +211,9 @@ class StringMobject(SVGMobject, ABC):
if not interval_spans:
return [universal_span]
span_ends, span_begins = zip(*interval_spans)
span_ends, span_starts = zip(*interval_spans)
return list(zip(
(universal_span[0], *span_begins),
(universal_span[0], *span_starts),
(*span_ends, universal_span[1])
))
@ -228,45 +232,115 @@ class StringMobject(SVGMobject, ABC):
# Parsing
def parse(self) -> None:
cmd_matches = list(re.finditer(
self.get_cmd_pattern(), self.string, flags=re.S
command_matches = list(re.finditer(
self.get_command_pattern(), self.string, re.X | re.S
))
cmd_spans = [match_obj.span() for match_obj in cmd_matches]
flags = [
self.get_matched_flag(match_obj)
for match_obj in cmd_matches
command_flags = [
self.get_command_flag(command_match)
for command_match in command_matches
]
specified_items = [
*self.get_internal_specified_items(
self.get_cmd_match_pairs(cmd_matches, flags)
),
command_match_pairs = self.get_command_match_pairs(
command_matches, command_flags
)
all_specified_items = [
*self.get_internal_specified_items(command_match_pairs),
*self.get_external_specified_items(),
*[
(span, {})
for span in self.find_spans_by_selector(self.isolate)
]
]
split_items = [
(span, attr_dict)
for specified_span, attr_dict in specified_items
for span in self.split_span_by_levels(
specified_span, cmd_spans, flags
)
command_spans = [match_obj.span() for match_obj in command_matches]
region_spans = self.get_complement_spans(
(0, self.full_len), command_spans
)
def get_region_index(index):
for region_index, (start, end) in enumerate(region_spans):
if start <= index <= end:
return region_index
return -1
labelled_spans = []
attr_dicts = []
for span, attr_dict in all_specified_items:
region_range = tuple(get_region_index(index) for index in span)
if -1 in region_range:
continue
levels = list(it.accumulate(command_flags[slice(*region_range)]))
if levels and any([
*(level < 0 for level in levels), levels[-1] > 0
]):
log.warning(
"Cannot handle substring '%s', ignored",
self.get_substr(span)
)
continue
overlapped_spans = [
s for s in labelled_spans if any([
s[0] < span[0] < s[1] < span[1],
span[0] < s[0] < span[1] < s[1]
])
]
if overlapped_spans:
log.warning(
"Substring '%s' partly overlaps with '%s', ignored",
self.get_substr(span),
self.get_substr(overlapped_spans[0])
)
continue
labelled_spans.append(span)
attr_dicts.append(attr_dict)
insertion_items = [
label_flag_pair
for _, label_flag_pair in sorted(it.chain(*(
sorted([
(span[::flag], (label, flag))
for label, span in list(enumerate(labelled_spans))[::flag]
], key=lambda t: (t[0][0], -t[0][1]))
for flag in (-1, 1)
)), key=lambda t: t[0][0])
]
insertion_interval_items = [
tuple(zip(*pair))
for pair in self.get_neighbouring_pairs([
(index, get_region_index(index))
for index in [0, *(
labelled_spans[label][flag < 0]
for label, flag in insertion_items
), self.full_len]
])
]
self.specified_spans = [span for span, _ in specified_items]
self.split_items = split_items
self.labelled_spans = [span for span, _ in split_items]
self.check_overlapping()
def get_replaced_pieces(replace_func):
return [
self.join_strs([
self.get_substr(s)
for s in self.get_complement_spans(
span, command_spans[slice(*region_range)]
)
], [
replace_func(command_match)
for command_match in command_matches[slice(*region_range)]
])
for span, region_range in insertion_interval_items
]
self.labelled_spans = labelled_spans
self.attr_dicts = attr_dicts
self.insertion_items = insertion_items
self.content_pieces = get_replaced_pieces(self.replace_for_content)
self.matching_pieces = get_replaced_pieces(self.replace_for_matching)
@staticmethod
@abstractmethod
def get_cmd_pattern() -> str:
return ""
def get_command_pattern() -> str:
return r"(?!)"
@staticmethod
@abstractmethod
def get_matched_flag(match_obj: re.Match) -> int:
def get_command_flag(match_obj: re.Match) -> int:
return 0
@staticmethod
@ -279,95 +353,23 @@ class StringMobject(SVGMobject, ABC):
def replace_for_matching(match_obj: re.Match) -> str:
return ""
@staticmethod
def get_cmd_match_pairs(
    cmd_matches: list[re.Match], flags: list[int]
) -> list[tuple[re.Match, re.Match]]:
    """Pair every opening command match with its closing counterpart.

    ``flags`` runs parallel to ``cmd_matches``: ``1`` marks an opening
    command, ``-1`` a closing one, and any other value is ignored.
    Pairs are appended in the order their closing command is met.

    Raises ``ValueError`` when a closing command has no pending opener,
    or when openers are left unmatched at the end.
    """
    pairs = []
    unclosed = []
    for match_obj, flag in zip(cmd_matches, flags):
        if flag == 1:
            unclosed.append(match_obj)
            continue
        if flag != -1:
            continue
        if not unclosed:
            raise ValueError("Missing open command")
        # The most recent opener is the one this closer belongs to.
        pairs.append((unclosed.pop(), match_obj))
    if unclosed:
        raise ValueError("Missing close command")
    return pairs
@staticmethod
@abstractmethod
def get_internal_specified_items(
cmd_match_pairs: list[tuple[re.Match, re.Match]]
command_match_pairs: list[tuple[re.Match, re.Match]]
) -> list[tuple[Span, dict[str, str]]]:
return []
@abstractmethod
def get_external_specified_items(
self
) -> list[tuple[Span, dict[str, str]]]:
def get_external_specified_items(self) -> list[tuple[Span, dict[str, str]]]:
return []
# Split `arbitrary_span` into sub-spans by cutting out the command spans
# collected below, and drop the pieces whose substring is only whitespace.
# `cmd_spans` and `flags` are parallel lists (they are zipped together).
def split_span_by_levels(
self, arbitrary_span: Span, cmd_spans: list[Span], flags: list[int]
) -> list[Span]:
# cmd_range[0]: how many command spans begin strictly before the span start.
# cmd_range[1]: how many command spans end at or before the span end.
cmd_range = (
sum([
arbitrary_span[0] > interval_begin
for interval_begin, _ in cmd_spans
]),
sum([
arbitrary_span[1] >= interval_end
for _, interval_end in cmd_spans
])
)
# NOTE(review): get_complement_spans is only partly visible here; presumably
# it returns the gaps left between `cmd_spans` inside (0, full_len) - confirm.
complement_spans = self.get_complement_spans(
(0, self.full_len), cmd_spans
)
# Clamp both ends of the span using the complement spans picked by cmd_range.
adjusted_span = (
max(arbitrary_span[0], complement_spans[cmd_range[0]][0]),
min(arbitrary_span[1], complement_spans[cmd_range[1]][1])
)
# Nothing is left once the clamped bounds cross each other.
if adjusted_span[0] > adjusted_span[1]:
return []
upward_cmd_spans = []
downward_cmd_spans = []
# Only the command spans indexed by cmd_range are examined.
# Flag 1 pushes the span onto upward_cmd_spans; flag -1 pops a pending one.
# NOTE(review): indentation is lost in this view; the `else` presumably pairs
# with `if upward_cmd_spans:` (a -1 with nothing pending is recorded in
# downward_cmd_spans) - confirm against the real file.
for cmd_span, flag in list(zip(cmd_spans, flags))[slice(*cmd_range)]:
if flag == 1:
upward_cmd_spans.append(cmd_span)
elif flag == -1:
if upward_cmd_spans:
upward_cmd_spans.pop()
else:
downward_cmd_spans.append(cmd_span)
# Keep only the pieces whose substring is non-empty after strip().
return list(filter(
lambda span: self.get_substr(span).strip(),
self.get_complement_spans(
adjusted_span, downward_cmd_spans + upward_cmd_spans
)
))
def check_overlapping(self) -> None:
    """Validate ``self.labelled_spans``.

    Raises ``ValueError`` when there are 16777216 or more spans, or when
    any ordered pair of spans overlaps partially, i.e. the second span
    starts strictly inside the first one and ends strictly after it.
    """
    spans = self.labelled_spans
    if len(spans) >= 16777216:
        raise ValueError("Cannot handle that many substrings")
    for first, second in it.product(spans, repeat=2):
        if first[0] < second[0] < first[1] < second[1]:
            raise ValueError(
                "Partially overlapping substrings detected: "
                f"'{self.get_substr(first)}' and '{self.get_substr(second)}'"
            )
@staticmethod
@abstractmethod
def get_cmd_str_pair(
attr_dict: dict[str, str], label_hex: str | None
) -> tuple[str, str]:
return "", ""
def get_command_string(
attr_dict: dict[str, str], is_end: bool, label_hex: str | None
) -> str:
return ""
@abstractmethod
def get_content_prefix_and_suffix(
@ -375,54 +377,88 @@ class StringMobject(SVGMobject, ABC):
) -> tuple[str, str]:
return "", ""
def replace_substr(
    self, span: Span, replace_func: Callable[[re.Match], str]
) -> str:
    """Return the substring at ``span`` with every command match rewritten.

    The pattern from ``self.get_cmd_pattern()`` is applied with ``re.S``
    and each match is replaced by the result of ``replace_func``.
    """
    command_regex = re.compile(self.get_cmd_pattern(), flags=re.S)
    return command_regex.sub(replace_func, self.get_substr(span))
@staticmethod
def get_command_match_pairs(
    command_matches: list[re.Match], command_flags: list[int]
) -> list[tuple[re.Match, re.Match]]:
    """Match each opening command with the closing command that ends it.

    ``command_flags`` is parallel to ``command_matches``: ``1`` denotes an
    opening command, ``-1`` a closing command, anything else is skipped.
    The returned pairs are ordered by the position of their closing match.

    Raises ``ValueError`` if a closing command appears with no opener
    pending, or if some opener is never closed.
    """
    matched_pairs = []
    pending_openers = []
    for command_match, flag in zip(command_matches, command_flags):
        if flag not in (1, -1):
            continue
        if flag == 1:
            pending_openers.append(command_match)
            continue
        if not pending_openers:
            raise ValueError("Missing open command")
        # Closers bind to the innermost opener still waiting.
        opener = pending_openers.pop()
        matched_pairs.append((opener, command_match))
    if pending_openers:
        raise ValueError("Missing close command")
    return matched_pairs
def get_content(self, is_labelled: bool) -> str:
inserted_str_pairs = [
(span, self.get_cmd_str_pair(
attr_dict,
insertion_strings = [
self.get_command_string(
self.attr_dicts[label],
is_end=flag < 0,
label_hex=self.int_to_hex(label + 1) if is_labelled else None
))
for label, (span, attr_dict) in enumerate(self.split_items)
)
for label, flag in self.insertion_items
]
if inserted_str_pairs:
indices, inserted_strs = zip(*sorted([
(index, s)
for (index, _), s in [
*sorted([
(span[::-1], end_str)
for span, (_, end_str) in reversed(inserted_str_pairs)
], key=lambda t: (t[0][0], -t[0][1])),
*sorted([
(span, begin_str)
for span, (begin_str, _) in inserted_str_pairs
], key=lambda t: (t[0][0], -t[0][1]))
]
], key=lambda t: t[0]))
else:
indices = ()
inserted_strs = ()
replaced_pieces = [
self.replace_substr(span, self.replace_for_content)
for span in zip((0, *indices), (*indices, self.full_len))
]
#repl_items = self.cmd_repl_items_for_content + [
# ((index, index), inserted_str)
# for index, inserted_str in inserted_str_items
#]
prefix, suffix = self.get_content_prefix_and_suffix(is_labelled)
prefix, suffix = self.get_content_prefix_and_suffix(
is_labelled=is_labelled
)
return "".join([
prefix,
*it.chain(*zip(replaced_pieces, (*inserted_strs, ""))),
self.join_strs(self.content_pieces, insertion_strings),
suffix
])
# Build one string per computed piece range by joining slices of
# self.matching_pieces and removing all whitespace from the result.
# Returns an empty list when `group_labels` is empty.
def get_group_substrs(self, group_labels: list[int]) -> list[str]:
if not group_labels:
return []
insertion_items = self.insertion_items
# Index of the (label, flag) entry in insertion_items, shifted by one.
# Label -1 maps to 0 for flag 1 and to len(insertion_items) + 1 otherwise.
def get_index(label, flag):
if label == -1:
return 0 if flag == 1 else len(insertion_items) + 1
return insertion_items.index((label, flag)) + 1
# Label -1 stands for the span of the whole string, (0, self.full_len).
def get_labelled_span(label):
if label == -1:
return (0, self.full_len)
return self.labelled_spans[label]
# Whether the span of label_0 contains the span of label_1.
def label_contains(label_0, label_1):
return self.span_contains(
get_labelled_span(label_0), get_labelled_span(label_1)
)
# The universal range goes from the flag-1 index of the first label to the
# flag -1 index of the last label. For each pair of neighbouring labels an
# interval is built whose two bounds depend on which label contains the other.
# NOTE(review): get_complement_spans and get_neighbouring_pairs are only
# partly visible here; presumably the result is the index ranges left between
# those intervals - confirm.
piece_ranges = self.get_complement_spans(
(get_index(group_labels[0], 1), get_index(group_labels[-1], -1)),
[
(
get_index(next_label, 1)
if label_contains(prev_label, next_label)
else get_index(prev_label, -1),
get_index(prev_label, -1)
if label_contains(next_label, prev_label)
else get_index(next_label, 1)
)
for prev_label, next_label in self.get_neighbouring_pairs(
group_labels
)
]
)
# Each range slices self.matching_pieces; the joined text loses all whitespace.
return [
re.sub(r"\s+", "", "".join(
self.matching_pieces[slice(*piece_range)]
))
for piece_range in piece_ranges
]
# Selector
def get_submob_indices_list_by_span(
@ -442,49 +478,24 @@ class StringMobject(SVGMobject, ABC):
self.get_substr(span),
self.get_submob_indices_list_by_span(span)
)
for span in self.specified_spans
for span in self.labelled_spans
]
def get_group_part_items(self) -> list[tuple[str, list[int]]]:
if not self.labels:
return []
group_labels, range_lens = zip(*(
(val, len(list(grouper)))
range_lens, group_labels = zip(*(
(len(list(grouper)), val)
for val, grouper in it.groupby(self.labels)
))
labelled_submob_ranges = self.get_neighbouring_pairs(
[0, *it.accumulate(range_lens)]
)
ordered_spans = [
self.labelled_spans[label] if label != -1 else (0, self.full_len)
for label in group_labels
]
interval_spans = [
(
next_span[0]
if self.span_contains(prev_span, next_span)
else prev_span[1],
prev_span[1]
if self.span_contains(next_span, prev_span)
else next_span[0]
)
for prev_span, next_span in self.get_neighbouring_pairs(
ordered_spans
)
]
group_substrs = [
re.sub(r"\s+", "", self.replace_substr(
span, self.replace_for_matching
))
for span in self.get_complement_spans(
(ordered_spans[0][0], ordered_spans[-1][1]), interval_spans
)
]
submob_indices_lists = [
list(range(*submob_range))
for submob_range in labelled_submob_ranges
for submob_range in self.get_neighbouring_pairs(
[0, *it.accumulate(range_lens)]
)
]
group_substrs = self.get_group_substrs(list(group_labels))
return list(zip(group_substrs, submob_indices_lists))
def get_submob_indices_lists_by_selector(

View file

@ -80,8 +80,7 @@ class MarkupText(StringMobject):
"t2w": {},
"global_config": {},
"local_configs": {},
# For backward compatibility
"isolate": (re.compile(r"[a-zA-Z]+"), re.compile(r"\S+")),
"isolate": re.compile(r"\w+", re.U),
}
# See https://docs.gtk.org/Pango/pango_markup.html
@ -103,24 +102,6 @@ class MarkupText(StringMobject):
"\"": "&quot;",
"'": "&apos;"
}
#ENTITIES, ENTITY_CHARS = zip(
# ("&lt;", "<"),
# ("&gt;", ">"),
# ("&amp;", "&"),
# ("&quot;", "\""),
# ("&apos;", "'")
#)
#CMD_PATTERN = r"""[<>&"']"""
#FLAG_DICT = {}
#CONTENT_REPL = {
# r"<": "&lt;",
# r">": "&gt;",
# r"&": "&amp;",
# r"\"": "&quot;",
# r"'": "&apos;"
#}
#MATCH_REPL = {}
def __init__(self, text: str, **kwargs):
self.full2short(kwargs)
@ -258,62 +239,76 @@ class MarkupText(StringMobject):
# Parsing
@staticmethod
def get_cmd_pattern() -> str | None:
# Unsupported passthroughs:
# "<?...?>", "<!--...-->", "<![CDATA[...]]>", "<!DOCTYPE...>"
# See https://gitlab.gnome.org/GNOME/glib/-/blob/main/glib/gmarkup.c
return r"""(<(/)?\w+(?:\s*\w+\s*\=\s*(["']).*?\3)*(/)?>)|(&(#(x)?)?(.*?);)|([>"'])"""
def get_command_pattern() -> str:
return r"""
(?P<tag>
<
(?P<close_slash>/)?
(?P<tag_name>\w+)\s*
(?P<attr_list>(?:\w+\s*\=\s*(?P<quot>["']).*?(?P=quot)\s*)*) # TODO: test wsp
(?P<elision_slash>/)?
>
)
|(?P<passthrough>
<\?.*?\?>|<!--.*?-->|<!\[CDATA\[.*?\]\]>|<!DOCTYPE.*?>
)
|(?P<entity>&(?P<unicode>\#(?P<hex>x)?)?(?P<content>.*?);)
|(?P<char>[>"'])
"""
@staticmethod
def get_matched_flag(match_obj: re.Match) -> int:
if match_obj.group(1):
if match_obj.group(2):
def get_command_flag(match_obj: re.Match) -> int:
if match_obj.group("tag"):
if match_obj.group("close_slash"):
return -1
if not match_obj.group(4):
if not match_obj.group("elision_slash"):
return 1
return 0
@staticmethod
def replace_for_content(match_obj: re.Match) -> str:
substr = match_obj.group()
if match_obj.group(9):
return MarkupText.escape_markup_char(substr)
return substr
if match_obj.group("tag"):
return ""
if match_obj.group("char"):
return MarkupText.escape_markup_char(match_obj.group("char"))
return match_obj.group()
@staticmethod
def replace_for_matching(match_obj: re.Match) -> str:
substr = match_obj.group()
if match_obj.group(1):
if match_obj.group("tag"):
return ""
if match_obj.group(5):
if match_obj.group(6):
if match_obj.group("entity"):
if match_obj.group("unicode"):
base = 10
if match_obj.group(7):
if match_obj.group("hex"):
base = 16
return chr(int(match_obj.group(8), base))
return MarkupText.unescape_markup_char(substr)
return substr
return chr(int(match_obj.group("content"), base))
return MarkupText.unescape_markup_char(match_obj.group("entity"))
return match_obj.group()
@staticmethod
def get_internal_specified_items(
cmd_match_pairs: list[tuple[re.Match, re.Match]]
command_match_pairs: list[tuple[re.Match, re.Match]]
) -> list[tuple[Span, dict[str, str]]]:
attr_pattern = r"""(\w+)\s*\=\s*(["'])(.*?)\2"""
pattern = r"""
(?P<attr_name>\w+)
\s*\=\s*
(?P<quot>["'])(?P<attr_val>.*?)(?P=quot)
"""
result = []
for begin_match, end_match in cmd_match_pairs:
begin_tag = begin_match.group()
tag_name = re.search(r"\w+", begin_tag).group()
for start_match, end_match in command_match_pairs:
tag_name = start_match.group("tag_name")
if tag_name == "span":
attr_dict = {
attr_match_obj.group(1): attr_match_obj.group(3)
for attr_match_obj in re.finditer(
attr_pattern, begin_tag, re.S
match_obj.group("attr_name"): match_obj.group("attr_val")
for match_obj in re.finditer(
pattern, start_match.group("attr_list"), re.S | re.X
)
}
else:
attr_dict = MarkupText.MARKUP_TAGS.get(tag_name, {})
result.append(
((begin_match.end(), end_match.start()), attr_dict)
((start_match.end(), end_match.start()), attr_dict)
)
return result
@ -340,9 +335,12 @@ class MarkupText(StringMobject):
]
@staticmethod
def get_cmd_str_pair(
attr_dict: dict[str, str], label_hex: str | None
) -> tuple[str, str]:
def get_command_string(
attr_dict: dict[str, str], is_end: bool, label_hex: str | None
) -> str:
if is_end:
return "</span>"
if label_hex is not None:
converted_attr_dict = {"foreground": label_hex}
for key, val in attr_dict.items():
@ -359,7 +357,7 @@ class MarkupText(StringMobject):
f"{key}='{val}'"
for key, val in converted_attr_dict.items()
])
return f"<span {attrs_str}>", "</span>"
return f"<span {attrs_str}>"
def get_content_prefix_and_suffix(
self, is_labelled: bool
@ -387,9 +385,13 @@ class MarkupText(StringMobject):
((line_spacing_scale) + 1) * 0.6
)
return self.get_cmd_str_pair(
global_attr_dict,
label_hex=self.int_to_hex(0) if is_labelled else None
return tuple(
self.get_command_string(
global_attr_dict,
is_end=is_end,
label_hex=self.int_to_hex(0) if is_labelled else None
)
for is_end in (False, True)
)
# Method alias
@ -413,12 +415,17 @@ class MarkupText(StringMobject):
class Text(MarkupText):
CONFIG = {
# For backward compatibility
"isolate": (re.compile(r"\w+", re.U), re.compile(r"\S+", re.U)),
}
@staticmethod
def get_cmd_pattern() -> str | None:
def get_command_pattern() -> str | None:
return r"""[<>&"']"""
@staticmethod
def get_matched_flag(match_obj: re.Match) -> int:
def get_command_flag(match_obj: re.Match) -> int:
return 0
@staticmethod