Refactor LabelledString and relevant classes

This commit is contained in:
YishiMichael 2022-05-06 17:56:27 +08:00
parent b509f62010
commit 7cf0e0ba10
No known key found for this signature in database
GPG key ID: EC615C0C5A86BC80
3 changed files with 62 additions and 79 deletions

View file

@ -211,24 +211,6 @@ class LabelledString(SVGMobject, ABC):
val_ranges = LabelledString.get_neighbouring_pairs(indices) val_ranges = LabelledString.get_neighbouring_pairs(indices)
return list(zip(unique_vals, val_ranges)) return list(zip(unique_vals, val_ranges))
@staticmethod
def sort_obj_pairs_by_spans(
obj_pairs: list[tuple[Span, tuple[T, T]]]
) -> list[tuple[int, T]]:
return sorted([
(index, obj)
for (index, _), obj in [
*sorted([
(span[::-1], end_obj)
for span, (_, end_obj) in reversed(obj_pairs)
], key=lambda t: (t[0][0], -t[0][1])),
*sorted([
(span, begin_obj)
for span, (begin_obj, _) in obj_pairs
], key=lambda t: (t[0][0], -t[0][1]))
]
], key=lambda t: t[0])
@staticmethod @staticmethod
def span_contains(span_0: Span, span_1: Span) -> bool: def span_contains(span_0: Span, span_1: Span) -> bool:
return span_0[0] <= span_1[0] and span_0[1] >= span_1[1] return span_0[0] <= span_1[0] and span_0[1] >= span_1[1]
@ -246,13 +228,11 @@ class LabelledString(SVGMobject, ABC):
(*span_ends, universal_span[1]) (*span_ends, universal_span[1])
)) ))
def replace_string(self, span: Span, repl_items: list[Span, str]): def replace_substr(self, span: Span, repl_items: list[Span, str]):
if not repl_items: if not repl_items:
return self.get_substr(span) return self.get_substr(span)
repl_spans, repl_strs = zip(*sorted( repl_spans, repl_strs = zip(*sorted(repl_items, key=lambda t: t[0]))
repl_items, key=lambda t: t[0]
))
pieces = [ pieces = [
self.get_substr(piece_span) self.get_substr(piece_span)
for piece_span in self.get_complement_spans(span, repl_spans) for piece_span in self.get_complement_spans(span, repl_spans)
@ -278,7 +258,6 @@ class LabelledString(SVGMobject, ABC):
cmd_spans = self.get_cmd_spans() cmd_spans = self.get_cmd_spans()
cmd_substrs = [self.get_substr(span) for span in cmd_spans] cmd_substrs = [self.get_substr(span) for span in cmd_spans]
flags = [self.get_substr_flag(substr) for substr in cmd_substrs] flags = [self.get_substr_flag(substr) for substr in cmd_substrs]
specified_items = self.get_specified_items( specified_items = self.get_specified_items(
self.get_cmd_span_pairs(cmd_spans, flags) self.get_cmd_span_pairs(cmd_spans, flags)
) )
@ -303,13 +282,6 @@ class LabelledString(SVGMobject, ABC):
] ]
self.check_overlapping() self.check_overlapping()
#self.original_content = self.get_content(
# cmd_repl_items_for_content, split_items, is_labelled=False
#)
#self.labelled_content = self.get_content(
# cmd_repl_items_for_content, split_items, is_labelled=True
#)
@abstractmethod @abstractmethod
def get_cmd_spans(self) -> list[Span]: def get_cmd_spans(self) -> list[Span]:
return [] return []
@ -318,6 +290,14 @@ class LabelledString(SVGMobject, ABC):
def get_substr_flag(self, substr: str) -> int: def get_substr_flag(self, substr: str) -> int:
return 0 return 0
@abstractmethod
def get_repl_substr_for_content(self, substr: str) -> str:
return ""
@abstractmethod
def get_repl_substr_for_matching(self, substr: str) -> str:
return ""
@staticmethod @staticmethod
def get_cmd_span_pairs( def get_cmd_span_pairs(
cmd_spans: list[Span], flags: list[int] cmd_spans: list[Span], flags: list[int]
@ -329,11 +309,11 @@ class LabelledString(SVGMobject, ABC):
begin_cmd_spans_stack.append(cmd_span) begin_cmd_spans_stack.append(cmd_span)
elif flag == -1: elif flag == -1:
if not begin_cmd_spans_stack: if not begin_cmd_spans_stack:
raise ValueError("Missing '{' inserted") raise ValueError("Missing open command")
begin_cmd_span = begin_cmd_spans_stack.pop() begin_cmd_span = begin_cmd_spans_stack.pop()
result.append((begin_cmd_span, cmd_span)) result.append((begin_cmd_span, cmd_span))
if begin_cmd_spans_stack: if begin_cmd_spans_stack:
raise ValueError("Missing '}' inserted") raise ValueError("Missing close command")
return result return result
@abstractmethod @abstractmethod
@ -394,14 +374,6 @@ class LabelledString(SVGMobject, ABC):
f"'{self.get_substr(span_0)}' and '{self.get_substr(span_1)}'" f"'{self.get_substr(span_0)}' and '{self.get_substr(span_1)}'"
) )
@abstractmethod
def get_repl_substr_for_content(self, substr: str) -> str:
return ""
@abstractmethod
def get_repl_substr_for_matching(self, substr: str) -> str:
return ""
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def get_cmd_str_pair( def get_cmd_str_pair(
@ -423,16 +395,27 @@ class LabelledString(SVGMobject, ABC):
)) ))
for label, (span, attr_dict) in enumerate(self.split_items) for label, (span, attr_dict) in enumerate(self.split_items)
] ]
inserted_str_items = sorted([
(index, s)
for (index, _), s in [
*sorted([
(span[::-1], end_str)
for span, (_, end_str) in reversed(inserted_str_pairs)
], key=lambda t: (t[0][0], -t[0][1])),
*sorted([
(span, begin_str)
for span, (begin_str, _) in inserted_str_pairs
], key=lambda t: (t[0][0], -t[0][1]))
]
], key=lambda t: t[0])
repl_items = self.cmd_repl_items_for_content + [ repl_items = self.cmd_repl_items_for_content + [
((index, index), inserted_str) ((index, index), inserted_str)
for index, inserted_str in self.sort_obj_pairs_by_spans( for index, inserted_str in inserted_str_items
inserted_str_pairs
)
] ]
prefix, suffix = self.get_content_prefix_and_suffix(is_labelled) prefix, suffix = self.get_content_prefix_and_suffix(is_labelled)
return "".join([ return "".join([
prefix, prefix,
self.replace_string(self.full_span, repl_items), self.replace_substr(self.full_span, repl_items),
suffix suffix
]) ])
@ -483,7 +466,7 @@ class LabelledString(SVGMobject, ABC):
) )
] ]
group_substrs = [ group_substrs = [
re.sub(r"\s+", "", self.replace_string( re.sub(r"\s+", "", self.replace_substr(
span, [ span, [
(cmd_span, repl_str) (cmd_span, repl_str)
for cmd_span, repl_str in self.cmd_repl_items_for_matching for cmd_span, repl_str in self.cmd_repl_items_for_matching

View file

@ -81,6 +81,12 @@ class MTex(LabelledString):
def get_substr_flag(self, substr: str) -> int: def get_substr_flag(self, substr: str) -> int:
return {"{": 1, "}": -1}.get(substr, 0) return {"{": 1, "}": -1}.get(substr, 0)
def get_repl_substr_for_content(self, substr: str) -> str:
return substr
def get_repl_substr_for_matching(self, substr: str) -> str:
return substr if substr.startswith("\\") else ""
def get_specified_items( def get_specified_items(
self, cmd_span_pairs: list[tuple[Span, Span]] self, cmd_span_pairs: list[tuple[Span, Span]]
) -> list[tuple[Span, dict[str, str]]]: ) -> list[tuple[Span, dict[str, str]]]:
@ -108,12 +114,6 @@ class MTex(LabelledString):
] ]
return [(span, {}) for span in specified_spans] return [(span, {}) for span in specified_spans]
def get_repl_substr_for_content(self, substr: str) -> str:
return substr
def get_repl_substr_for_matching(self, substr: str) -> str:
return substr if substr.startswith("\\") else ""
@staticmethod @staticmethod
def get_color_cmd_str(rgb_hex: str) -> str: def get_color_cmd_str(rgb_hex: str) -> str:
rgb = MTex.hex_to_int(rgb_hex) rgb = MTex.hex_to_int(rgb_hex)

View file

@ -247,6 +247,34 @@ class MarkupText(LabelledString):
return -1 return -1
return 0 return 0
def get_repl_substr_for_content(self, substr: str) -> str:
if substr.startswith("<") and substr.endswith(">"):
return ""
return {
"<": "&lt;",
">": "&gt;",
"&": "&amp;",
"\"": "&quot;",
"'": "&apos;"
}.get(substr, substr)
def get_repl_substr_for_matching(self, substr: str) -> str:
if substr.startswith("<") and substr.endswith(">"):
return ""
if substr.startswith("&#") and substr.endswith(";"):
if substr.startswith("&#x"):
char_reference = int(substr[3:-1], 16)
else:
char_reference = int(substr[2:-1], 10)
return chr(char_reference)
return {
"&lt;": "<",
"&gt;": ">",
"&amp;": "&",
"&quot;": "\"",
"&apos;": "'"
}.get(substr, substr)
def get_specified_items( def get_specified_items(
self, cmd_span_pairs: list[tuple[Span, Span]] self, cmd_span_pairs: list[tuple[Span, Span]]
) -> list[tuple[Span, dict[str, str]]]: ) -> list[tuple[Span, dict[str, str]]]:
@ -290,34 +318,6 @@ class MarkupText(LabelledString):
] ]
] ]
def get_repl_substr_for_content(self, substr: str) -> str:
if substr.startswith("<") and substr.endswith(">"):
return ""
return {
"<": "&lt;",
">": "&gt;",
"&": "&amp;",
"\"": "&quot;",
"'": "&apos;"
}.get(substr, substr)
def get_repl_substr_for_matching(self, substr: str) -> str:
if substr.startswith("<") and substr.endswith(">"):
return ""
if substr.startswith("&#") and substr.endswith(";"):
if substr.startswith("&#x"):
char_reference = int(substr[3:-1], 16)
else:
char_reference = int(substr[2:-1], 10)
return chr(char_reference)
return {
"&lt;": "<",
"&gt;": ">",
"&amp;": "&",
"&quot;": "\"",
"&apos;": "'"
}.get(substr, substr)
@staticmethod @staticmethod
def get_cmd_str_pair( def get_cmd_str_pair(
attr_dict: dict[str, str], label_hex: str | None attr_dict: dict[str, str], label_hex: str | None