Refactor LabelledString

This commit is contained in:
YishiMichael 2022-04-16 00:24:55 +08:00
parent a1e77b0ce2
commit 4690edec3e
No known key found for this signature in database
GPG key ID: EC615C0C5A86BC80
3 changed files with 13 additions and 9 deletions

View file

@ -103,6 +103,7 @@ class LabelledString(SVGMobject, ABC):
def pre_parse(self) -> None: def pre_parse(self) -> None:
self.string_len = len(self.string) self.string_len = len(self.string)
self.full_span = (0, self.string_len) self.full_span = (0, self.string_len)
self.space_spans = self.find_spans(r"\s+")
self.base_color_int = self.color_to_int(self.base_color) self.base_color_int = self.color_to_int(self.base_color)
def parse(self) -> None: def parse(self) -> None:
@ -137,14 +138,17 @@ class LabelledString(SVGMobject, ABC):
def get_substr(self, span: Span) -> str: def get_substr(self, span: Span) -> str:
return self.string[slice(*span)] return self.string[slice(*span)]
def find_spans(self, pattern: str | re.Pattern, **kwargs) -> list[Span]: def find_spans(self, pattern: str | re.Pattern) -> list[Span]:
if isinstance(pattern, str): if isinstance(pattern, str):
pattern = re.compile(pattern) pattern = re.compile(pattern)
return [ return [
match_obj.span() match_obj.span()
for match_obj in pattern.finditer(self.string, **kwargs) for match_obj in pattern.finditer(self.string)
] ]
def match_at(self, pattern: str, pos: int) -> re.Pattern | None:
return re.compile(pattern).match(self.string, pos=pos)
def find_spans_by_selector(self, selector: Selector) -> list[Span]: def find_spans_by_selector(self, selector: Selector) -> list[Span]:
if isinstance(selector, str): if isinstance(selector, str):
result = self.find_spans(re.escape(selector)) result = self.find_spans(re.escape(selector))

View file

@ -131,12 +131,12 @@ class MTex(LabelledString):
brace_indices_dict = dict(self.brace_index_pairs) brace_indices_dict = dict(self.brace_index_pairs)
script_pattern = r"[a-zA-Z0-9]|\\[a-zA-Z]+" script_pattern = r"[a-zA-Z0-9]|\\[a-zA-Z]+"
for char_span in self.script_char_spans: for char_span in self.script_char_spans:
span_begin = self.find_spans(r"\s*", pos=char_span[1])[0][1] span_begin = self.rslide(char_span[1], self.space_spans)
if span_begin in brace_indices_dict.keys(): if span_begin in brace_indices_dict.keys():
span_end = brace_indices_dict[span_begin] + 1 span_end = brace_indices_dict[span_begin] + 1
else: else:
spans = self.find_spans(script_pattern, pos=span_begin) match_obj = self.match_at(script_pattern, span_begin)
if not spans or spans[0][0] != span_begin: if match_obj is None:
script_name = { script_name = {
"_": "subscript", "_": "subscript",
"^": "superscript" "^": "superscript"
@ -146,14 +146,14 @@ class MTex(LabelledString):
f"(position {char_span[0]}). " f"(position {char_span[0]}). "
"Please use braces to clarify" "Please use braces to clarify"
) )
span_end = spans[0][1] span_end = match_obj.end()
result.append((span_begin, span_end)) result.append((span_begin, span_end))
return result return result
def get_script_spans(self) -> list[Span]: def get_script_spans(self) -> list[Span]:
return [ return [
( (
self.find_spans(r"\s*", endpos=char_span[0])[-1][0], self.lslide(char_span[0], self.space_spans),
script_content_span[1] script_content_span[1]
) )
for char_span, script_content_span in zip( for char_span, script_content_span in zip(
@ -202,7 +202,7 @@ class MTex(LabelledString):
def get_extra_entity_spans(self) -> list[Span]: def get_extra_entity_spans(self) -> list[Span]:
return [ return [
self.find_spans(r"\\([a-zA-Z]+|.?)", pos=index)[0] self.match_at(r"\\([a-zA-Z]+|.?)", index).span()
for index in self.backslash_indices for index in self.backslash_indices
] ]

View file

@ -434,8 +434,8 @@ class MarkupText(LabelledString):
def get_label_span_list(self) -> list[Span]: def get_label_span_list(self) -> list[Span]:
breakup_indices = remove_list_redundancies(list(it.chain(*it.chain( breakup_indices = remove_list_redundancies(list(it.chain(*it.chain(
self.find_spans(r"\s+"),
self.find_spans(r"\b"), self.find_spans(r"\b"),
self.space_spans,
self.specified_spans self.specified_spans
)))) ))))
breakup_indices = sorted(filter( breakup_indices = sorted(filter(