Refactor LabelledString and relevant classes

This commit is contained in:
YishiMichael 2022-05-06 16:43:20 +08:00
parent 642602155d
commit b509f62010
No known key found for this signature in database
GPG key ID: EC615C0C5A86BC80
3 changed files with 178 additions and 188 deletions

View file

@ -3,6 +3,8 @@ from __future__ import annotations
from abc import ABC, abstractmethod
import itertools as it
import re
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist
from manimlib.constants import WHITE
from manimlib.logger import log
@ -36,6 +38,21 @@ if TYPE_CHECKING:
class LabelledString(SVGMobject, ABC):
"""
An abstract base class for `MTex` and `MarkupText`
This class aims to optimize the logic of "slicing submobjects
via substrings". This could be much clearer and more user-friendly
than slicing through numerical indices explicitly.
Users are expected to specify substrings in `isolate` parameter
if they want to do anything with their corresponding submobjects.
`isolate` parameter can be either a string, a `re.Pattern` object,
or a 2-tuple containing integers or None, or a collection of the above.
Note, substrings specified cannot *partially* overlap with each other.
Each instance of `LabelledString` generates 2 svg files.
The additional one is generated with some color commands inserted,
so that each submobject of the original `SVGMobject` will be "labelled"
by the color of its paired submobject from the additional `SVGMobject`.
"""
CONFIG = {
"height": None,
@ -62,7 +79,8 @@ class LabelledString(SVGMobject, ABC):
self.labels = [submob.label for submob in self.submobjects]
def get_file_path(self) -> str:
return self.get_file_path_by_content(self.original_content)
original_content = self.get_content(is_labelled=False)
return self.get_file_path_by_content(original_content)
@abstractmethod
def get_file_path_by_content(self, content: str) -> str:
@ -71,62 +89,74 @@ class LabelledString(SVGMobject, ABC):
def generate_mobject(self) -> None:
super().generate_mobject()
file_path = self.get_file_path_by_content(self.labelled_content)
labels_count = len(self.labelled_spans)
if not labels_count:
for submob in self.submobjects:
submob.label = -1
return
labelled_content = self.get_content(is_labelled=True)
file_path = self.get_file_path_by_content(labelled_content)
labelled_svg = SVGMobject(file_path)
num_submobjects = len(self.submobjects)
if num_submobjects != len(labelled_svg.submobjects):
if len(self.submobjects) != len(labelled_svg.submobjects):
log.warning(
"Cannot align submobjects of the labelled svg "
"to the original svg. Skip the labelling process."
)
submob_color_ints = [0] * num_submobjects
else:
submob_color_ints = [
self.hex_to_int(self.color_to_hex(submob.get_fill_color()))
for submob in labelled_svg.submobjects
]
unrecognized_colors = list(filter(
lambda color_int: color_int > len(self.labelled_spans),
submob_color_ints
for submob in self.submobjects:
submob.label = -1
return
self.rearrange_submobjects_by_positions(labelled_svg)
unrecognizable_colors = []
for submob, labelled_svg_submob in zip(
self.submobjects, labelled_svg.submobjects
):
color_int = self.hex_to_int(self.color_to_hex(
labelled_svg_submob.get_fill_color()
))
if unrecognized_colors:
log.warning(
"Unrecognized color labels detected (%s, etc). "
"Skip the labelling process.",
self.int_to_hex(unrecognized_colors[0])
)
submob_color_ints = [0] * num_submobjects
# Rearrange colors so that the n-th submobject from the left
# is labelled by the n-th submobject of `labelled_svg` from the left.
submob_indices = sorted(
range(num_submobjects),
key=lambda index: tuple(
self.submobjects[index].get_center()
)
)
labelled_submob_indices = sorted(
range(num_submobjects),
key=lambda index: tuple(
labelled_svg.submobjects[index].get_center()
)
)
submob_color_ints = [
submob_color_ints[
labelled_submob_indices[submob_indices.index(index)]
]
for index in range(num_submobjects)
]
for submob, color_int in zip(self.submobjects, submob_color_ints):
if color_int > labels_count:
unrecognizable_colors.append(color_int)
color_int = 0
submob.label = color_int - 1
if unrecognizable_colors:
log.warning(
"Unrecognizable color labels detected (%s, etc). "
"The result could be unexpected.",
self.int_to_hex(unrecognizable_colors[0])
)
def rearrange_submobjects_by_positions(
self, labelled_svg: SVGMobject
) -> None:
# Rearrange submobjects of `labelled_svg` so that
# each submobject is labelled by the nearest one of `labelled_svg`.
# The correctness cannot be ensured, since the svg may
# change significantly after inserting color commands.
if not labelled_svg.submobjects:
return
bb_0 = self.get_bounding_box()
bb_1 = labelled_svg.get_bounding_box()
scale_factor = abs((bb_0[2] - bb_0[0]) / (bb_1[2] - bb_1[0]))
labelled_svg.move_to(self).scale(scale_factor)
distance_matrix = cdist(
[submob.get_center() for submob in self.submobjects],
[submob.get_center() for submob in labelled_svg.submobjects]
)
_, indices = linear_sum_assignment(distance_matrix)
labelled_svg.set_submobjects([
labelled_svg.submobjects[index]
for index in indices
])
# Toolkits
def get_substr(self, span: Span) -> str:
return self.string[slice(*span)]
def find_spans(self, pattern: str) -> list[Span]:
def find_spans(self, pattern: str | re.Pattern) -> list[Span]:
return [
match_obj.span()
for match_obj in re.finditer(pattern, self.string)
@ -137,18 +167,7 @@ class LabelledString(SVGMobject, ABC):
if isinstance(sel, str):
return self.find_spans(re.escape(sel))
if isinstance(sel, re.Pattern):
result_iterator = sel.finditer(self.string)
if not sel.groups:
return [
match_obj.span()
for match_obj in result_iterator
]
return [
span
for match_obj in result_iterator
for span in match_obj.regs[1:]
if span != (-1, -1)
]
return self.find_spans(sel)
if isinstance(sel, tuple) and len(sel) == 2 and all(
isinstance(index, int) or index is None
for index in sel
@ -196,16 +215,19 @@ class LabelledString(SVGMobject, ABC):
def sort_obj_pairs_by_spans(
obj_pairs: list[tuple[Span, tuple[T, T]]]
) -> list[tuple[int, T]]:
return [
return sorted([
(index, obj)
for (index, _), obj in sorted([
(span, begin_obj)
for span, (begin_obj, _) in obj_pairs
] + [
for (index, _), obj in [
*sorted([
(span[::-1], end_obj)
for span, (_, end_obj) in reversed(obj_pairs)
], key=lambda t: (t[0][0], -t[0][1])),
*sorted([
(span, begin_obj)
for span, (begin_obj, _) in obj_pairs
], key=lambda t: (t[0][0], -t[0][1]))
]
], key=lambda t: t[0])
@staticmethod
def span_contains(span_0: Span, span_1: Span) -> bool:
@ -253,17 +275,12 @@ class LabelledString(SVGMobject, ABC):
# Parsing
def parse(self) -> None:
begin_cmd_spans, end_cmd_spans, other_cmd_spans = self.get_cmd_spans()
cmd_span_items = sorted(it.chain(
[(begin_cmd_span, 1) for begin_cmd_span in begin_cmd_spans],
[(end_cmd_span, -1) for end_cmd_span in end_cmd_spans],
[(cmd_span, 0) for cmd_span in other_cmd_spans],
), key=lambda t: t[0])
cmd_spans = [span for span, _ in cmd_span_items]
flags = [flag for _, flag in cmd_span_items]
cmd_spans = self.get_cmd_spans()
cmd_substrs = [self.get_substr(span) for span in cmd_spans]
flags = [self.get_substr_flag(substr) for substr in cmd_substrs]
specified_items = self.get_specified_items(
self.get_cmd_span_pairs(cmd_span_items)
self.get_cmd_span_pairs(cmd_spans, flags)
)
split_items = [
(span, attr_dict)
@ -274,36 +291,40 @@ class LabelledString(SVGMobject, ABC):
]
self.specified_spans = [span for span, _ in specified_items]
self.split_items = split_items
self.labelled_spans = [span for span, _ in split_items]
self.check_overlapping()
cmd_repl_items_for_content = [
(span, self.get_repl_substr_for_content(self.get_substr(span)))
for span in cmd_spans
self.cmd_repl_items_for_content = [
(span, self.get_repl_substr_for_content(substr))
for span, substr in zip(cmd_spans, cmd_substrs)
]
self.cmd_repl_items_for_matching = [
(span, self.get_repl_substr_for_matching(self.get_substr(span)))
for span in cmd_spans
(span, self.get_repl_substr_for_matching(substr))
for span, substr in zip(cmd_spans, cmd_substrs)
]
self.check_overlapping()
self.original_content = self.get_content(
cmd_repl_items_for_content, split_items, is_labelled=False
)
self.labelled_content = self.get_content(
cmd_repl_items_for_content, split_items, is_labelled=True
)
#self.original_content = self.get_content(
# cmd_repl_items_for_content, split_items, is_labelled=False
#)
#self.labelled_content = self.get_content(
# cmd_repl_items_for_content, split_items, is_labelled=True
#)
@abstractmethod
def get_cmd_spans(self) -> tuple[list[Span], list[Span], list[Span]]:
return [], [], []
def get_cmd_spans(self) -> list[Span]:
return []
@abstractmethod
def get_substr_flag(self, substr: str) -> int:
return 0
@staticmethod
def get_cmd_span_pairs(
cmd_span_items: list[tuple[Span, int]]
cmd_spans: list[Span], flags: list[int]
) -> list[tuple[Span, Span]]:
result = []
begin_cmd_spans_stack = []
for cmd_span, flag in cmd_span_items:
for cmd_span, flag in zip(cmd_spans, flags):
if flag == 1:
begin_cmd_spans_stack.append(cmd_span)
elif flag == -1:
@ -354,9 +375,12 @@ class LabelledString(SVGMobject, ABC):
upward_cmd_spans.pop()
else:
downward_cmd_spans.append(cmd_span)
return self.get_complement_spans(
return list(filter(
lambda span: self.get_substr(span).strip(),
self.get_complement_spans(
adjusted_span, downward_cmd_spans + upward_cmd_spans
)
))
def check_overlapping(self) -> None:
labelled_spans = self.labelled_spans
@ -391,18 +415,15 @@ class LabelledString(SVGMobject, ABC):
) -> tuple[str, str]:
return "", ""
def get_content(
self, cmd_repl_items_for_content: list[Span, str],
split_items: list[tuple[Span, dict[str, str]]], is_labelled: bool
) -> str:
def get_content(self, is_labelled: bool) -> str:
inserted_str_pairs = [
(span, self.get_cmd_str_pair(
attr_dict,
label_hex=self.int_to_hex(label + 1) if is_labelled else None
))
for label, (span, attr_dict) in enumerate(split_items)
for label, (span, attr_dict) in enumerate(self.split_items)
]
repl_items = cmd_repl_items_for_content + [
repl_items = self.cmd_repl_items_for_content + [
((index, index), inserted_str)
for index, inserted_str in self.sort_obj_pairs_by_spans(
inserted_str_pairs

View file

@ -75,19 +75,11 @@ class MTex(LabelledString):
# Parsing
def get_cmd_spans(self) -> tuple[list[Span], list[Span], list[Span]]:
backslash_spans = self.find_spans(r"\\(?:[a-zA-Z]+|\s|\S)")
def find_unescaped_spans(pattern):
return list(filter(
lambda span: (span[0] - 1, span[1]) not in backslash_spans,
self.find_spans(pattern)
))
def get_cmd_spans(self) -> list[Span]:
return self.find_spans(r"\\(?:[a-zA-Z]+|\s|\S)|[_^{}]")
return (
find_unescaped_spans(r"{"),
find_unescaped_spans(r"}"),
backslash_spans + find_unescaped_spans(r"[_^]")
)
def get_substr_flag(self, substr: str) -> int:
return {"{": 1, "}": -1}.get(substr, 0)
def get_specified_items(
self, cmd_span_pairs: list[tuple[Span, Span]]

View file

@ -49,32 +49,6 @@ DEFAULT_CANVAS_WIDTH = 16384
DEFAULT_CANVAS_HEIGHT = 16384
# See https://docs.gtk.org/Pango/pango_markup.html
MARKUP_COLOR_KEYS_DICT = {
"foreground": False,
"fgcolor": False,
"color": False,
"background": True,
"bgcolor": True,
"underline_color": True,
"overline_color": True,
"strikethrough_color": True,
}
MARKUP_TAG_CONVERSION_DICT = {
"b": {"font_weight": "bold"},
"big": {"font_size": "larger"},
"i": {"font_style": "italic"},
"s": {"strikethrough": "true"},
"sub": {"baseline_shift": "subscript", "font_scale": "subscript"},
"sup": {"baseline_shift": "superscript", "font_scale": "superscript"},
"small": {"font_size": "smaller"},
"tt": {"font_family": "monospace"},
"u": {"underline": "single"},
}
XML_ENTITIES = ("<", ">", "&", """, "'")
XML_ENTITY_CHARS = "<>&\"'"
# Temporary handler
class _Alignment:
VAL_DICT = {
@ -107,13 +81,33 @@ class MarkupText(LabelledString):
"t2w": {},
"global_config": {},
"local_configs": {},
# When attempting to slice submobs via `get_part_by_text` thereafter,
# it's recommended to explicitly specify them in `isolate` attribute
# when initializing.
# For backward compatibility
"isolate": (re.compile(r"[a-zA-Z]+"), re.compile(r"\S+")),
}
# See https://docs.gtk.org/Pango/pango_markup.html
MARKUP_COLOR_KEYS = {
"foreground": False,
"fgcolor": False,
"color": False,
"background": True,
"bgcolor": True,
"underline_color": True,
"overline_color": True,
"strikethrough_color": True,
}
MARKUP_TAGS = {
"b": {"font_weight": "bold"},
"big": {"font_size": "larger"},
"i": {"font_style": "italic"},
"s": {"strikethrough": "true"},
"sub": {"baseline_shift": "subscript", "font_scale": "subscript"},
"sup": {"baseline_shift": "superscript", "font_scale": "superscript"},
"small": {"font_size": "smaller"},
"tt": {"font_family": "monospace"},
"u": {"underline": "single"},
}
def __init__(self, text: str, **kwargs):
self.full2short(kwargs)
digest_config(self, kwargs)
@ -235,53 +229,28 @@ class MarkupText(LabelledString):
# Parsing
def get_cmd_spans(self) -> tuple[list[Span], list[Span], list[Span]]:
def get_cmd_spans(self) -> list[Span]:
if not self.is_markup:
return [], [], self.find_spans(r"[<>&\x22']")
return self.find_spans(r"""[<>&"']""")
# Unsupported passthroughs:
# "<?...?>", "<!--...-->", "<![CDATA[...]]>", "<!DOCTYPE...>"
# See https://gitlab.gnome.org/GNOME/glib/-/blob/main/glib/gmarkup.c
string = self.string
cmd_spans = []
cmd_pattern = re.compile(r"""
&[\s\S]*?; # entity & character reference
|</?\w+(?:\s*\w+\s*\=\s*(['"])[\s\S]*?\1)*/?> # tag
|<\?[\s\S]*?\?>|<\?> # instruction
|<!--[\s\S]*?-->|<!---?> # comment
|<!\[CDATA\[[\s\S]*?\]\]> # cdata
|<!DOCTYPE # doctype (require balancing groups)
|[>"'] # characters to escape
""", re.X)
match_obj = cmd_pattern.search(string)
while match_obj:
span_begin, span_end = match_obj.span()
if match_obj.group() == "<!DOCTYPE":
balance = 1
while balance != 0:
angle_match_obj = re.compile(r"[<>]").search(
string, pos=span_end
return self.find_spans(
r"""&[\s\S]*?;|[>"']|</?\w+(?:\s*\w+\s*\=\s*(["'])[\s\S]*?\1)*/?>"""
)
balance += {"<": 1, ">": -1}[angle_match_obj.group()]
span_end = angle_match_obj.end()
cmd_spans.append((span_begin, span_end))
match_obj = cmd_pattern.search(string, pos=span_end)
begin_cmd_spans = []
end_cmd_spans = []
other_cmd_spans = []
for cmd_span in cmd_spans:
substr = self.get_substr(cmd_span)
def get_substr_flag(self, substr: str) -> int:
if re.fullmatch(r"<\w[\s\S]*[^/]>", substr):
begin_cmd_spans.append(cmd_span)
elif substr.startswith("</"):
end_cmd_spans.append(cmd_span)
else:
other_cmd_spans.append(cmd_span)
return begin_cmd_spans, end_cmd_spans, other_cmd_spans
return 1
if substr.startswith("</"):
return -1
return 0
def get_specified_items(
self, cmd_span_pairs: list[tuple[Span, Span]]
) -> list[tuple[Span, dict[str, str]]]:
attr_pattern = r"(\w+)\s*\=\s*(['\x22])([\s\S]*?)\2"
attr_pattern = r"""(\w+)\s*\=\s*(["'])([\s\S]*?)\2"""
internal_items = []
for begin_cmd_span, end_cmd_span in cmd_span_pairs:
begin_tag = self.get_substr(begin_cmd_span)
@ -292,7 +261,7 @@ class MarkupText(LabelledString):
for attr_match_obj in re.finditer(attr_pattern, begin_tag)
}
else:
attr_dict = MARKUP_TAG_CONVERSION_DICT.get(tag_name, {})
attr_dict = MarkupText.MARKUP_TAGS.get(tag_name, {})
internal_items.append(
((begin_cmd_span[1], end_cmd_span[0]), attr_dict)
)
@ -324,22 +293,30 @@ class MarkupText(LabelledString):
def get_repl_substr_for_content(self, substr: str) -> str:
if substr.startswith("<") and substr.endswith(">"):
return ""
if substr in XML_ENTITY_CHARS:
return XML_ENTITIES[XML_ENTITY_CHARS.index(substr)]
return substr
return {
"<": "&lt;",
">": "&gt;",
"&": "&amp;",
"\"": "&quot;",
"'": "&apos;"
}.get(substr, substr)
def get_repl_substr_for_matching(self, substr: str) -> str:
if substr.startswith("<") and substr.endswith(">"):
return ""
if substr in XML_ENTITIES:
return XML_ENTITY_CHARS[XML_ENTITIES.index(substr)]
if substr.startswith("&#") and substr.endswith(";"):
if substr.startswith("&#x"):
char_reference = int(substr[3:-1], 16)
else:
char_reference = int(substr[2:-1], 10)
return chr(char_reference)
return substr
return {
"&lt;": "<",
"&gt;": ">",
"&amp;": "&",
"&quot;": "\"",
"&apos;": "'"
}.get(substr, substr)
@staticmethod
def get_cmd_str_pair(
@ -348,7 +325,7 @@ class MarkupText(LabelledString):
if label_hex is not None:
converted_attr_dict = {"foreground": label_hex}
for key, val in attr_dict.items():
substitute_key = MARKUP_COLOR_KEYS_DICT.get(key.lower(), None)
substitute_key = MarkupText.MARKUP_COLOR_KEYS.get(key, None)
if substitute_key is None:
converted_attr_dict[key] = val
elif substitute_key: