refactor: refactor StringMobject

This commit is contained in:
YishiMichael 2022-08-20 13:01:59 +08:00
parent 7ffc7b33f7
commit 19c757ec90
No known key found for this signature in database
GPG key ID: EC615C0C5A86BC80
5 changed files with 190 additions and 138 deletions

View file

@ -131,9 +131,9 @@ class MTex(StringMobject):
@staticmethod
def replace_for_matching(match_obj: re.Match) -> str:
if match_obj.group("command"):
return match_obj.group()
return ""
if match_obj.group("script"):
return ""
return match_obj.group()
@staticmethod
def get_attr_dict_from_command_pair(

View file

@ -90,21 +90,22 @@ class StringMobject(SVGMobject, ABC):
super().generate_mobject()
labels_count = len(self.labelled_spans)
if not labels_count:
if labels_count == 1:
for submob in self.submobjects:
submob.label = -1
submob.label = 0
return
labelled_content = self.get_content(is_labelled=True)
file_path = self.get_file_path_by_content(labelled_content)
labelled_svg = SVGMobject(file_path)
#print(len(self.submobjects), len(labelled_svg.submobjects)) # ????
if len(self.submobjects) != len(labelled_svg.submobjects):
log.warning(
"Cannot align submobjects of the labelled svg "
"to the original svg. Skip the labelling process."
)
for submob in self.submobjects:
submob.label = -1
submob.label = labels_count - 1
return
self.rearrange_submobjects_by_positions(labelled_svg)
@ -115,15 +116,18 @@ class StringMobject(SVGMobject, ABC):
color_int = self.hex_to_int(self.color_to_hex(
labelled_svg_submob.get_fill_color()
))
if color_int > labels_count:
if color_int >= labels_count:
unrecognizable_colors.append(color_int)
color_int = 0
color_int = labels_count
submob.label = color_int - 1
if unrecognizable_colors:
log.warning(
"Unrecognizable color labels detected (%s, etc). "
"Unrecognizable color labels detected (%s). "
"The result could be unexpected.",
self.int_to_hex(unrecognizable_colors[0])
", ".join([
self.int_to_hex(color)
for color in unrecognizable_colors
])
)
def rearrange_submobjects_by_positions(
@ -186,7 +190,7 @@ class StringMobject(SVGMobject, ABC):
if spans is None:
raise TypeError(f"Invalid selector: '{sel}'")
result.extend(spans)
return list(filter(lambda span: span[0] < span[1], result))
return list(filter(lambda span: span[0] <= span[1], result))
@staticmethod
def span_contains(span_0: Span, span_1: Span) -> bool:
@ -214,73 +218,94 @@ class StringMobject(SVGMobject, ABC):
val_list = list(vals)
return list(zip(val_list[:-1], val_list[1:]))
#def get_complement_spans(
# universal_span: Span, interval_spans: list[Span]
#) -> list[Span]:
# if not interval_spans:
# return [universal_span]
# span_ends, span_starts = zip(*interval_spans)
# return list(zip(
# (universal_span[0], *span_starts),
# (*span_ends, universal_span[1])
# ))
def join_strs(strs: list[str], inserted_strs: list[str]) -> str:
return "".join(it.chain(*zip(strs, (*inserted_strs, ""))))
command_matches = self.get_command_matches(self.string)
command_spans = [match_obj.span() for match_obj in command_matches]
#command_spans = [match_obj.span() for match_obj in command_matches]
configured_items = self.get_configured_items()
#configured_spans = [span for span, _ in configured_items]
#configured_attr_dicts = [d for _, d in configured_items]
categorized_spans = [
[(0, len(self.string))], # TODO
[span for span, _ in configured_items],
self.find_spans_by_selector(self.isolate),
self.find_spans_by_selector(self.protect),
command_spans # TODO
[match_obj.span() for match_obj in command_matches] # TODO
]
sorted_items = sorted([
(category, category_index, flag, *span[::flag])
for category, spans in enumerate(categorized_spans)
for category_index, span in enumerate(spans)
for flag in (1, -1)
], key=lambda t: (t[3], t[2], -t[4], t[2] * (t[0] + 1), t[2] * t[1])) # TODO
], key=lambda t: (
t[3], t[2] * (2 if t[3] != t[4] else -1), -t[4],
t[2] * t[0], t[2] * t[1]
)) # TODO
labelled_spans = []
attr_dicts = []
inserted_items = []
#labelled_items = []
label = 0
count = 0
region_index = 0
protect_level = 0
bracket_levels = [0]
region_levels = [0]
open_command_stack = []
open_stack = []
#protect_level_stack = []
#bracket_level_stack = []
#inserted_position_stack = []
#index_items_len = 0 # label * 2
#index_items_len = 0 # count * 2
for category, i, flag, _, _ in sorted_items:
if category in (2, 3):
if category >= 3:
if flag == 1:
protect_level += 1
continue
protect_level -= 1
if category == 2:
if category == 3:
continue
region_index += 1
command_match = command_matches[i]
command_flag = self.get_command_flag(command_match)
region_levels.append(region_levels[-1] + command_flag)
if command_flag == 1:
bracket_levels.append(bracket_levels[-1] + 1)
open_command_stack.append(
(command_match, region_index, label)
(command_match, region_index, count)
)
continue
elif command_flag == 0:
continue
bracket_levels.append(bracket_levels[-1] - 1)
command_match_, region_index_, label_ = open_command_stack.pop()
command_match_, region_index_, count_ = open_command_stack.pop()
attr_dict = self.get_attr_dict_from_command_pair(
command_match_, command_match
)
if attr_dict is None:
continue
span = (command_match_.end(), command_match.start())
region_span = (region_index_, region_index - 1)
else:
if flag == 1:
open_stack.append(
(category, i, protect_level, region_index, label)
(category, i, protect_level, region_index, count)
)
continue
category_, i_, protect_level_, region_index_, label_ \
category_, i_, protect_level_, region_index_, count_ \
= open_stack.pop()
span = categorized_spans[category][i]
if (category_, i_) != (category, i):
@ -292,80 +317,99 @@ class StringMobject(SVGMobject, ABC):
continue
if protect_level_ or protect_level:
continue
levels = bracket_levels[region_index_:region_index]
if levels and (
any(levels[0] > l for l in levels) or levels[0] < levels[-1]
):
ls = region_levels[region_index_:region_index + 1]
if ls and (any(ls[0] > l for l in ls) or ls[0] < ls[-1]):
log.warning(
"Cannot handle substring '%s'", get_substr(span)
)
continue
attr_dict = {} if category == 1 else configured_items[i][1]
pos = label_ * 2
attr_dict = configured_items[i][1] if category == 1 else {}
region_span = (region_index_, region_index)
#labelled_items.append(
# (span, region_span, (count_, count), attr_dict)
#)
pos = count_ * 2
labelled_spans.append(span)
attr_dicts.append(attr_dict)
inserted_items.insert(pos, (label, 1, span[0], region_index_))
inserted_items.append((label, -1, span[1], region_index))
label += 1
inserted_items.insert(pos, (count, 1, span[0], region_span[0]))
inserted_items.append((count, -1, span[1], region_span[1]))
count += 1
extended_inserted_items = [
(-1, 1, 0, 0),
*inserted_items,
(-1, -1, len(self.string), len(command_matches))
]
#labelled_spans = []
#attr_dicts = []
#inserted_items = []
#inserted_items.insert(0, (-1, 1, 0, 0))
#inserted_items.append((-1, -1, len(self.string), region_index))
inserted_label_items = [
(label, flag)
for label, flag, _, _ in extended_inserted_items
for label, flag, _, _ in inserted_items
]
inserted_indices = [
index
for _, _, index, _ in extended_inserted_items
]
inserted_region_indices = [
region_index
for _, _, _, region_index in extended_inserted_items
]
inserted_interval_spans = get_neighbouring_pairs(inserted_indices)
inserted_interval_region_spans = get_neighbouring_pairs(inserted_region_indices)
def get_complement_spans(
universal_span: Span, interval_spans: list[Span]
) -> list[Span]:
if not interval_spans:
return [universal_span]
span_ends, span_starts = zip(*interval_spans)
return list(zip(
(universal_span[0], *span_starts),
(*span_ends, universal_span[1])
))
def join_strs(strs: list[str], inserted_strs: list[str]) -> str:
return "".join(it.chain(*zip(strs, (*inserted_strs, ""))))
subpieces_groups = [
[
get_substr(s)
for s in get_complement_spans(
span, command_spans[slice(*region_range)]
#inserted_interval_spans = []
#command_matches_lists = []
#subpieces_lists = []
content_pieces = []
matching_pieces = []
for (_, _, prev_index, prev_region_index), (_, _, next_index, next_region_index) in get_neighbouring_pairs(inserted_items):
region_matches = command_matches[prev_region_index:next_region_index]
#command_matches_lists.append(region_matches)
subpieces = [
get_substr((start, end))
for start, end in zip(
[prev_index, *(m.end() for m in region_matches)],
[*(m.start() for m in region_matches), next_index]
)
]
for span, region_range in zip(inserted_interval_spans, inserted_interval_region_spans)
]
content_pieces.append(join_strs(subpieces, [
self.replace_for_content(m) for m in region_matches
]))
matching_pieces.append(join_strs(subpieces, [
self.replace_for_matching(m) for m in region_matches
]))
#subpieces_lists.append([
# get_substr(s)
# for s in get_complement_spans(
# (prev_index, next_index),
# [m.span() for m in region_matches]
# )
#])
def get_replaced_pieces(replace_func: Callable[[re.Match], str]) -> list[str]:
return [
join_strs(subpieces, [
replace_func(command_match)
for command_match in command_matches[slice(*region_range)]
])
for subpieces, region_range in zip(subpieces_groups, inserted_interval_region_spans)
]
content_pieces = get_replaced_pieces(self.replace_for_content)
matching_pieces = get_replaced_pieces(self.replace_for_matching)
#inserted_interval_spans = get_neighbouring_pairs([
# index
# for _, _, index, _ in inserted_items
#])
#command_matches_lists = [
# command_matches[slice(*region_range)]
# for region_range in get_neighbouring_pairs([
# region_index
# for _, _, _, region_index in inserted_items
# ])
#]
#subpieces_lists = [
# [
# get_substr(s)
# for s in get_complement_spans(
# span, [m.span() for m in match_list]
# )
# ]
# for span, match_list in zip(inserted_interval_spans, command_matches_lists)
#]
#def get_replaced_pieces(replace_func: Callable[[re.Match], str]) -> list[str]:
# return [
# join_strs(subpieces, [
# replace_func(command_match)
# for command_match in match_list
# ])
# for subpieces, match_list in zip(subpieces_lists, command_matches_lists)
# ]
#content_pieces = get_replaced_pieces(self.replace_for_content)
#matching_pieces = get_replaced_pieces(self.replace_for_matching)
def get_content(is_labelled: bool) -> str:
inserted_strings = [
@ -400,12 +444,16 @@ class StringMobject(SVGMobject, ABC):
)
]
def get_index(label, flag):
def get_region_index(label, flag):
#if label == -1:
# if flag == 1:
# return 0
# return len(inserted_label_items) - 1
return inserted_label_items.index((label, flag))
def get_labelled_span(label):
if label == -1:
return (0, len(self.string))
#if label == -1:
# return (0, len(self.string))
return labelled_spans[label]
def label_contains(label_0, label_1):
@ -413,53 +461,54 @@ class StringMobject(SVGMobject, ABC):
get_labelled_span(label_0), get_labelled_span(label_1)
)
#piece_starts = [
# get_index(group_labels[0], 1),
# *(
# get_index(curr_label, 1)
# if label_contains(prev_label, curr_label)
# else get_index(prev_label, -1)
# for prev_label, curr_label in get_neighbouring_pairs(
# group_labels
# )
# )
#]
#piece_ends = [
# *(
# get_index(curr_label, -1)
# if label_contains(next_label, curr_label)
# else get_index(next_label, 1)
# for curr_label, next_label in get_neighbouring_pairs(
# group_labels
# )
# ),
# get_index(group_labels[-1], -1)
#]
piece_ranges = get_complement_spans(
(get_index(group_labels[0], 1), get_index(group_labels[-1], -1)),
[
(
get_index(next_label, 1)
if label_contains(prev_label, next_label)
else get_index(prev_label, -1),
get_index(prev_label, -1)
if label_contains(next_label, prev_label)
else get_index(next_label, 1)
)
for prev_label, next_label in get_neighbouring_pairs(
piece_starts = [
get_region_index(group_labels[0], 1),
*(
get_region_index(curr_label, 1)
if label_contains(prev_label, curr_label)
else get_region_index(prev_label, -1)
for prev_label, curr_label in get_neighbouring_pairs(
group_labels
)
]
)
)
]
piece_ends = [
*(
get_region_index(curr_label, -1)
if label_contains(next_label, curr_label)
else get_region_index(next_label, 1)
for curr_label, next_label in get_neighbouring_pairs(
group_labels
)
),
get_region_index(group_labels[-1], -1)
]
#piece_ranges = get_complement_spans(
# (get_region_index(group_labels[0], 1), get_region_index(group_labels[-1], -1)),
# [
# (
# get_region_index(next_label, 1)
# if label_contains(prev_label, next_label)
# else get_region_index(prev_label, -1),
# get_region_index(prev_label, -1)
# if label_contains(next_label, prev_label)
# else get_region_index(next_label, 1)
# )
# for prev_label, next_label in get_neighbouring_pairs(
# group_labels
# )
# ]
#)
group_substrs = [
re.sub(r"\s+", "", "".join(
matching_pieces[slice(*piece_ranges)]
matching_pieces[start:end]
))
for piece_ranges in piece_ranges
for start, end in zip(piece_starts, piece_ends)
]
return list(zip(group_substrs, submob_indices_lists))
#print(labelled_spans)
self.labelled_spans = labelled_spans
self.get_content = get_content
self.get_group_part_items_by_labels = get_group_part_items_by_labels
@ -516,9 +565,7 @@ class StringMobject(SVGMobject, ABC):
return [
submob_index
for submob_index, label in enumerate(self.labels)
if label != -1 and self.span_contains(
arbitrary_span, self.labelled_spans[label]
)
if self.span_contains(arbitrary_span, self.labelled_spans[label])
]
def get_specified_part_items(self) -> list[tuple[str, list[int]]]:
@ -527,7 +574,7 @@ class StringMobject(SVGMobject, ABC):
self.string[slice(*span)],
self.get_submob_indices_list_by_span(span)
)
for span in self.labelled_spans
for span in self.labelled_spans[:-1]
]
def get_group_part_items(self) -> list[tuple[str, list[int]]]:

View file

@ -80,6 +80,7 @@ class MarkupText(StringMobject):
"t2w": {},
"global_config": {},
"local_configs": {},
"disable_ligatures": True,
"isolate": re.compile(r"\w+", re.U),
}
@ -150,7 +151,8 @@ class MarkupText(StringMobject):
self.t2s,
self.t2w,
self.global_config,
self.local_configs
self.local_configs,
self.disable_ligatures
)
def full2short(self, config: dict) -> None:
@ -359,9 +361,8 @@ class MarkupText(StringMobject):
"font_family": self.font,
"font_style": self.slant,
"font_weight": self.weight,
"font_size": str(self.font_size * 1024),
"font_size": str(round(self.font_size * 1024)),
}
global_attr_dict.update(self.global_config)
# `line_height` attribute is supported since Pango 1.50.
pango_version = manimpango.pango_version()
if tuple(map(int, pango_version.split("."))) < (1, 50):
@ -376,7 +377,10 @@ class MarkupText(StringMobject):
global_attr_dict["line_height"] = str(
((line_spacing_scale) + 1) * 0.6
)
if self.disable_ligatures:
global_attr_dict["font_features"] = "liga=0,dlig=0,clig=0,hlig=0"
global_attr_dict.update(self.global_config)
return tuple(
self.get_command_string(
global_attr_dict,
@ -413,8 +417,9 @@ class Text(MarkupText):
}
@staticmethod
def get_command_pattern() -> str | None:
return r"""[<>&"']"""
def get_command_matches(string: str) -> list[re.Match]:
pattern = re.compile(r"""[<>&"']""")
return list(pattern.finditer(string))
@staticmethod
def get_command_flag(match_obj: re.Match) -> int:

View file

@ -253,7 +253,7 @@ gfs_didot: |-
\let\varphi\phi
# GFS NeoHellenic
gfs_neoHellenic: |-
gfs_neohellenic: |-
\usepackage[T1]{fontenc}
\renewcommand{\rmdefault}{neohellenic}
\usepackage[LGRgreek]{mathastext}

View file

@ -16,7 +16,7 @@ SAVED_TEX_CONFIG = {}
def get_tex_preamble(template_name: str) -> str:
name = re.sub(r"[^a-zA-Z]", "_", template_name).lower()
name = template_name.replace(" ", "_").lower()
with open(os.path.join(
get_manim_dir(), "manimlib", "tex_templates.yml"
), encoding="utf-8") as tex_templates_file:
@ -143,7 +143,7 @@ def create_tex_svg(full_tex: str, svg_file: str, compiler: str) -> None:
# TODO, perhaps this should live elsewhere
@contextmanager
def display_during_execution(message: str) -> None:
def display_during_execution(message: str):
# Merge into a single line
to_print = message.replace("\n", " ")
max_characters = os.get_terminal_size().columns - 1