refactor: refactor StringMobject

This commit is contained in:
YishiMichael 2022-08-20 13:01:59 +08:00
parent 7ffc7b33f7
commit 19c757ec90
No known key found for this signature in database
GPG key ID: EC615C0C5A86BC80
5 changed files with 190 additions and 138 deletions

View file

@ -131,9 +131,9 @@ class MTex(StringMobject):
@staticmethod @staticmethod
def replace_for_matching(match_obj: re.Match) -> str: def replace_for_matching(match_obj: re.Match) -> str:
if match_obj.group("command"): if match_obj.group("script"):
return match_obj.group()
return "" return ""
return match_obj.group()
@staticmethod @staticmethod
def get_attr_dict_from_command_pair( def get_attr_dict_from_command_pair(

View file

@ -90,21 +90,22 @@ class StringMobject(SVGMobject, ABC):
super().generate_mobject() super().generate_mobject()
labels_count = len(self.labelled_spans) labels_count = len(self.labelled_spans)
if not labels_count: if labels_count == 1:
for submob in self.submobjects: for submob in self.submobjects:
submob.label = -1 submob.label = 0
return return
labelled_content = self.get_content(is_labelled=True) labelled_content = self.get_content(is_labelled=True)
file_path = self.get_file_path_by_content(labelled_content) file_path = self.get_file_path_by_content(labelled_content)
labelled_svg = SVGMobject(file_path) labelled_svg = SVGMobject(file_path)
#print(len(self.submobjects), len(labelled_svg.submobjects)) # ????
if len(self.submobjects) != len(labelled_svg.submobjects): if len(self.submobjects) != len(labelled_svg.submobjects):
log.warning( log.warning(
"Cannot align submobjects of the labelled svg " "Cannot align submobjects of the labelled svg "
"to the original svg. Skip the labelling process." "to the original svg. Skip the labelling process."
) )
for submob in self.submobjects: for submob in self.submobjects:
submob.label = -1 submob.label = labels_count - 1
return return
self.rearrange_submobjects_by_positions(labelled_svg) self.rearrange_submobjects_by_positions(labelled_svg)
@ -115,15 +116,18 @@ class StringMobject(SVGMobject, ABC):
color_int = self.hex_to_int(self.color_to_hex( color_int = self.hex_to_int(self.color_to_hex(
labelled_svg_submob.get_fill_color() labelled_svg_submob.get_fill_color()
)) ))
if color_int > labels_count: if color_int >= labels_count:
unrecognizable_colors.append(color_int) unrecognizable_colors.append(color_int)
color_int = 0 color_int = labels_count
submob.label = color_int - 1 submob.label = color_int - 1
if unrecognizable_colors: if unrecognizable_colors:
log.warning( log.warning(
"Unrecognizable color labels detected (%s, etc). " "Unrecognizable color labels detected (%s). "
"The result could be unexpected.", "The result could be unexpected.",
self.int_to_hex(unrecognizable_colors[0]) ", ".join([
self.int_to_hex(color)
for color in unrecognizable_colors
])
) )
def rearrange_submobjects_by_positions( def rearrange_submobjects_by_positions(
@ -186,7 +190,7 @@ class StringMobject(SVGMobject, ABC):
if spans is None: if spans is None:
raise TypeError(f"Invalid selector: '{sel}'") raise TypeError(f"Invalid selector: '{sel}'")
result.extend(spans) result.extend(spans)
return list(filter(lambda span: span[0] < span[1], result)) return list(filter(lambda span: span[0] <= span[1], result))
@staticmethod @staticmethod
def span_contains(span_0: Span, span_1: Span) -> bool: def span_contains(span_0: Span, span_1: Span) -> bool:
@ -214,73 +218,94 @@ class StringMobject(SVGMobject, ABC):
val_list = list(vals) val_list = list(vals)
return list(zip(val_list[:-1], val_list[1:])) return list(zip(val_list[:-1], val_list[1:]))
#def get_complement_spans(
# universal_span: Span, interval_spans: list[Span]
#) -> list[Span]:
# if not interval_spans:
# return [universal_span]
# span_ends, span_starts = zip(*interval_spans)
# return list(zip(
# (universal_span[0], *span_starts),
# (*span_ends, universal_span[1])
# ))
def join_strs(strs: list[str], inserted_strs: list[str]) -> str:
return "".join(it.chain(*zip(strs, (*inserted_strs, ""))))
command_matches = self.get_command_matches(self.string) command_matches = self.get_command_matches(self.string)
command_spans = [match_obj.span() for match_obj in command_matches] #command_spans = [match_obj.span() for match_obj in command_matches]
configured_items = self.get_configured_items() configured_items = self.get_configured_items()
#configured_spans = [span for span, _ in configured_items] #configured_spans = [span for span, _ in configured_items]
#configured_attr_dicts = [d for _, d in configured_items] #configured_attr_dicts = [d for _, d in configured_items]
categorized_spans = [ categorized_spans = [
[(0, len(self.string))], # TODO
[span for span, _ in configured_items], [span for span, _ in configured_items],
self.find_spans_by_selector(self.isolate), self.find_spans_by_selector(self.isolate),
self.find_spans_by_selector(self.protect), self.find_spans_by_selector(self.protect),
command_spans # TODO [match_obj.span() for match_obj in command_matches] # TODO
] ]
sorted_items = sorted([ sorted_items = sorted([
(category, category_index, flag, *span[::flag]) (category, category_index, flag, *span[::flag])
for category, spans in enumerate(categorized_spans) for category, spans in enumerate(categorized_spans)
for category_index, span in enumerate(spans) for category_index, span in enumerate(spans)
for flag in (1, -1) for flag in (1, -1)
], key=lambda t: (t[3], t[2], -t[4], t[2] * (t[0] + 1), t[2] * t[1])) # TODO ], key=lambda t: (
t[3], t[2] * (2 if t[3] != t[4] else -1), -t[4],
t[2] * t[0], t[2] * t[1]
)) # TODO
labelled_spans = [] labelled_spans = []
attr_dicts = [] attr_dicts = []
inserted_items = [] inserted_items = []
#labelled_items = []
label = 0 count = 0
region_index = 0 region_index = 0
protect_level = 0 protect_level = 0
bracket_levels = [0] region_levels = [0]
open_command_stack = [] open_command_stack = []
open_stack = [] open_stack = []
#protect_level_stack = [] #protect_level_stack = []
#bracket_level_stack = [] #bracket_level_stack = []
#inserted_position_stack = [] #inserted_position_stack = []
#index_items_len = 0 # label * 2 #index_items_len = 0 # count * 2
for category, i, flag, _, _ in sorted_items: for category, i, flag, _, _ in sorted_items:
if category in (2, 3): if category >= 3:
if flag == 1: if flag == 1:
protect_level += 1 protect_level += 1
continue continue
protect_level -= 1 protect_level -= 1
if category == 2: if category == 3:
continue continue
region_index += 1 region_index += 1
command_match = command_matches[i] command_match = command_matches[i]
command_flag = self.get_command_flag(command_match) command_flag = self.get_command_flag(command_match)
region_levels.append(region_levels[-1] + command_flag)
if command_flag == 1: if command_flag == 1:
bracket_levels.append(bracket_levels[-1] + 1)
open_command_stack.append( open_command_stack.append(
(command_match, region_index, label) (command_match, region_index, count)
) )
continue continue
elif command_flag == 0: elif command_flag == 0:
continue continue
bracket_levels.append(bracket_levels[-1] - 1) command_match_, region_index_, count_ = open_command_stack.pop()
command_match_, region_index_, label_ = open_command_stack.pop()
attr_dict = self.get_attr_dict_from_command_pair( attr_dict = self.get_attr_dict_from_command_pair(
command_match_, command_match command_match_, command_match
) )
if attr_dict is None: if attr_dict is None:
continue continue
span = (command_match_.end(), command_match.start()) span = (command_match_.end(), command_match.start())
region_span = (region_index_, region_index - 1)
else: else:
if flag == 1: if flag == 1:
open_stack.append( open_stack.append(
(category, i, protect_level, region_index, label) (category, i, protect_level, region_index, count)
) )
continue continue
category_, i_, protect_level_, region_index_, label_ \ category_, i_, protect_level_, region_index_, count_ \
= open_stack.pop() = open_stack.pop()
span = categorized_spans[category][i] span = categorized_spans[category][i]
if (category_, i_) != (category, i): if (category_, i_) != (category, i):
@ -292,80 +317,99 @@ class StringMobject(SVGMobject, ABC):
continue continue
if protect_level_ or protect_level: if protect_level_ or protect_level:
continue continue
levels = bracket_levels[region_index_:region_index] ls = region_levels[region_index_:region_index + 1]
if levels and ( if ls and (any(ls[0] > l for l in ls) or ls[0] < ls[-1]):
any(levels[0] > l for l in levels) or levels[0] < levels[-1]
):
log.warning( log.warning(
"Cannot handle substring '%s'", get_substr(span) "Cannot handle substring '%s'", get_substr(span)
) )
continue continue
attr_dict = {} if category == 1 else configured_items[i][1] attr_dict = configured_items[i][1] if category == 1 else {}
pos = label_ * 2 region_span = (region_index_, region_index)
#labelled_items.append(
# (span, region_span, (count_, count), attr_dict)
#)
pos = count_ * 2
labelled_spans.append(span) labelled_spans.append(span)
attr_dicts.append(attr_dict) attr_dicts.append(attr_dict)
inserted_items.insert(pos, (label, 1, span[0], region_index_)) inserted_items.insert(pos, (count, 1, span[0], region_span[0]))
inserted_items.append((label, -1, span[1], region_index)) inserted_items.append((count, -1, span[1], region_span[1]))
label += 1 count += 1
extended_inserted_items = [ #labelled_spans = []
(-1, 1, 0, 0), #attr_dicts = []
*inserted_items, #inserted_items = []
(-1, -1, len(self.string), len(command_matches))
] #inserted_items.insert(0, (-1, 1, 0, 0))
#inserted_items.append((-1, -1, len(self.string), region_index))
inserted_label_items = [ inserted_label_items = [
(label, flag) (label, flag)
for label, flag, _, _ in extended_inserted_items for label, flag, _, _ in inserted_items
] ]
inserted_indices = [ #inserted_interval_spans = []
index #command_matches_lists = []
for _, _, index, _ in extended_inserted_items #subpieces_lists = []
] content_pieces = []
inserted_region_indices = [ matching_pieces = []
region_index for (_, _, prev_index, prev_region_index), (_, _, next_index, next_region_index) in get_neighbouring_pairs(inserted_items):
for _, _, _, region_index in extended_inserted_items region_matches = command_matches[prev_region_index:next_region_index]
] #command_matches_lists.append(region_matches)
subpieces = [
inserted_interval_spans = get_neighbouring_pairs(inserted_indices) get_substr((start, end))
inserted_interval_region_spans = get_neighbouring_pairs(inserted_region_indices) for start, end in zip(
[prev_index, *(m.end() for m in region_matches)],
def get_complement_spans( [*(m.start() for m in region_matches), next_index]
universal_span: Span, interval_spans: list[Span]
) -> list[Span]:
if not interval_spans:
return [universal_span]
span_ends, span_starts = zip(*interval_spans)
return list(zip(
(universal_span[0], *span_starts),
(*span_ends, universal_span[1])
))
def join_strs(strs: list[str], inserted_strs: list[str]) -> str:
return "".join(it.chain(*zip(strs, (*inserted_strs, ""))))
subpieces_groups = [
[
get_substr(s)
for s in get_complement_spans(
span, command_spans[slice(*region_range)]
) )
] ]
for span, region_range in zip(inserted_interval_spans, inserted_interval_region_spans) content_pieces.append(join_strs(subpieces, [
] self.replace_for_content(m) for m in region_matches
]))
matching_pieces.append(join_strs(subpieces, [
self.replace_for_matching(m) for m in region_matches
]))
#subpieces_lists.append([
# get_substr(s)
# for s in get_complement_spans(
# (prev_index, next_index),
# [m.span() for m in region_matches]
# )
#])
def get_replaced_pieces(replace_func: Callable[[re.Match], str]) -> list[str]:
return [
join_strs(subpieces, [
replace_func(command_match)
for command_match in command_matches[slice(*region_range)]
])
for subpieces, region_range in zip(subpieces_groups, inserted_interval_region_spans)
]
content_pieces = get_replaced_pieces(self.replace_for_content)
matching_pieces = get_replaced_pieces(self.replace_for_matching) #inserted_interval_spans = get_neighbouring_pairs([
# index
# for _, _, index, _ in inserted_items
#])
#command_matches_lists = [
# command_matches[slice(*region_range)]
# for region_range in get_neighbouring_pairs([
# region_index
# for _, _, _, region_index in inserted_items
# ])
#]
#subpieces_lists = [
# [
# get_substr(s)
# for s in get_complement_spans(
# span, [m.span() for m in match_list]
# )
# ]
# for span, match_list in zip(inserted_interval_spans, command_matches_lists)
#]
#def get_replaced_pieces(replace_func: Callable[[re.Match], str]) -> list[str]:
# return [
# join_strs(subpieces, [
# replace_func(command_match)
# for command_match in match_list
# ])
# for subpieces, match_list in zip(subpieces_lists, command_matches_lists)
# ]
#content_pieces = get_replaced_pieces(self.replace_for_content)
#matching_pieces = get_replaced_pieces(self.replace_for_matching)
def get_content(is_labelled: bool) -> str: def get_content(is_labelled: bool) -> str:
inserted_strings = [ inserted_strings = [
@ -400,12 +444,16 @@ class StringMobject(SVGMobject, ABC):
) )
] ]
def get_index(label, flag): def get_region_index(label, flag):
#if label == -1:
# if flag == 1:
# return 0
# return len(inserted_label_items) - 1
return inserted_label_items.index((label, flag)) return inserted_label_items.index((label, flag))
def get_labelled_span(label): def get_labelled_span(label):
if label == -1: #if label == -1:
return (0, len(self.string)) # return (0, len(self.string))
return labelled_spans[label] return labelled_spans[label]
def label_contains(label_0, label_1): def label_contains(label_0, label_1):
@ -413,53 +461,54 @@ class StringMobject(SVGMobject, ABC):
get_labelled_span(label_0), get_labelled_span(label_1) get_labelled_span(label_0), get_labelled_span(label_1)
) )
#piece_starts = [ piece_starts = [
# get_index(group_labels[0], 1), get_region_index(group_labels[0], 1),
# *( *(
# get_index(curr_label, 1) get_region_index(curr_label, 1)
# if label_contains(prev_label, curr_label) if label_contains(prev_label, curr_label)
# else get_index(prev_label, -1) else get_region_index(prev_label, -1)
# for prev_label, curr_label in get_neighbouring_pairs( for prev_label, curr_label in get_neighbouring_pairs(
# group_labels
# )
# )
#]
#piece_ends = [
# *(
# get_index(curr_label, -1)
# if label_contains(next_label, curr_label)
# else get_index(next_label, 1)
# for curr_label, next_label in get_neighbouring_pairs(
# group_labels
# )
# ),
# get_index(group_labels[-1], -1)
#]
piece_ranges = get_complement_spans(
(get_index(group_labels[0], 1), get_index(group_labels[-1], -1)),
[
(
get_index(next_label, 1)
if label_contains(prev_label, next_label)
else get_index(prev_label, -1),
get_index(prev_label, -1)
if label_contains(next_label, prev_label)
else get_index(next_label, 1)
)
for prev_label, next_label in get_neighbouring_pairs(
group_labels group_labels
) )
]
) )
]
piece_ends = [
*(
get_region_index(curr_label, -1)
if label_contains(next_label, curr_label)
else get_region_index(next_label, 1)
for curr_label, next_label in get_neighbouring_pairs(
group_labels
)
),
get_region_index(group_labels[-1], -1)
]
#piece_ranges = get_complement_spans(
# (get_region_index(group_labels[0], 1), get_region_index(group_labels[-1], -1)),
# [
# (
# get_region_index(next_label, 1)
# if label_contains(prev_label, next_label)
# else get_region_index(prev_label, -1),
# get_region_index(prev_label, -1)
# if label_contains(next_label, prev_label)
# else get_region_index(next_label, 1)
# )
# for prev_label, next_label in get_neighbouring_pairs(
# group_labels
# )
# ]
#)
group_substrs = [ group_substrs = [
re.sub(r"\s+", "", "".join( re.sub(r"\s+", "", "".join(
matching_pieces[slice(*piece_ranges)] matching_pieces[start:end]
)) ))
for piece_ranges in piece_ranges for start, end in zip(piece_starts, piece_ends)
] ]
return list(zip(group_substrs, submob_indices_lists)) return list(zip(group_substrs, submob_indices_lists))
#print(labelled_spans)
self.labelled_spans = labelled_spans self.labelled_spans = labelled_spans
self.get_content = get_content self.get_content = get_content
self.get_group_part_items_by_labels = get_group_part_items_by_labels self.get_group_part_items_by_labels = get_group_part_items_by_labels
@ -516,9 +565,7 @@ class StringMobject(SVGMobject, ABC):
return [ return [
submob_index submob_index
for submob_index, label in enumerate(self.labels) for submob_index, label in enumerate(self.labels)
if label != -1 and self.span_contains( if self.span_contains(arbitrary_span, self.labelled_spans[label])
arbitrary_span, self.labelled_spans[label]
)
] ]
def get_specified_part_items(self) -> list[tuple[str, list[int]]]: def get_specified_part_items(self) -> list[tuple[str, list[int]]]:
@ -527,7 +574,7 @@ class StringMobject(SVGMobject, ABC):
self.string[slice(*span)], self.string[slice(*span)],
self.get_submob_indices_list_by_span(span) self.get_submob_indices_list_by_span(span)
) )
for span in self.labelled_spans for span in self.labelled_spans[:-1]
] ]
def get_group_part_items(self) -> list[tuple[str, list[int]]]: def get_group_part_items(self) -> list[tuple[str, list[int]]]:

View file

@ -80,6 +80,7 @@ class MarkupText(StringMobject):
"t2w": {}, "t2w": {},
"global_config": {}, "global_config": {},
"local_configs": {}, "local_configs": {},
"disable_ligatures": True,
"isolate": re.compile(r"\w+", re.U), "isolate": re.compile(r"\w+", re.U),
} }
@ -150,7 +151,8 @@ class MarkupText(StringMobject):
self.t2s, self.t2s,
self.t2w, self.t2w,
self.global_config, self.global_config,
self.local_configs self.local_configs,
self.disable_ligatures
) )
def full2short(self, config: dict) -> None: def full2short(self, config: dict) -> None:
@ -359,9 +361,8 @@ class MarkupText(StringMobject):
"font_family": self.font, "font_family": self.font,
"font_style": self.slant, "font_style": self.slant,
"font_weight": self.weight, "font_weight": self.weight,
"font_size": str(self.font_size * 1024), "font_size": str(round(self.font_size * 1024)),
} }
global_attr_dict.update(self.global_config)
# `line_height` attribute is supported since Pango 1.50. # `line_height` attribute is supported since Pango 1.50.
pango_version = manimpango.pango_version() pango_version = manimpango.pango_version()
if tuple(map(int, pango_version.split("."))) < (1, 50): if tuple(map(int, pango_version.split("."))) < (1, 50):
@ -376,7 +377,10 @@ class MarkupText(StringMobject):
global_attr_dict["line_height"] = str( global_attr_dict["line_height"] = str(
((line_spacing_scale) + 1) * 0.6 ((line_spacing_scale) + 1) * 0.6
) )
if self.disable_ligatures:
global_attr_dict["font_features"] = "liga=0,dlig=0,clig=0,hlig=0"
global_attr_dict.update(self.global_config)
return tuple( return tuple(
self.get_command_string( self.get_command_string(
global_attr_dict, global_attr_dict,
@ -413,8 +417,9 @@ class Text(MarkupText):
} }
@staticmethod @staticmethod
def get_command_pattern() -> str | None: def get_command_matches(string: str) -> list[re.Match]:
return r"""[<>&"']""" pattern = re.compile(r"""[<>&"']""")
return list(pattern.finditer(string))
@staticmethod @staticmethod
def get_command_flag(match_obj: re.Match) -> int: def get_command_flag(match_obj: re.Match) -> int:

View file

@ -253,7 +253,7 @@ gfs_didot: |-
\let\varphi\phi \let\varphi\phi
# GFS NeoHellenic # GFS NeoHellenic
gfs_neoHellenic: |- gfs_neohellenic: |-
\usepackage[T1]{fontenc} \usepackage[T1]{fontenc}
\renewcommand{\rmdefault}{neohellenic} \renewcommand{\rmdefault}{neohellenic}
\usepackage[LGRgreek]{mathastext} \usepackage[LGRgreek]{mathastext}

View file

@ -16,7 +16,7 @@ SAVED_TEX_CONFIG = {}
def get_tex_preamble(template_name: str) -> str: def get_tex_preamble(template_name: str) -> str:
name = re.sub(r"[^a-zA-Z]", "_", template_name).lower() name = template_name.replace(" ", "_").lower()
with open(os.path.join( with open(os.path.join(
get_manim_dir(), "manimlib", "tex_templates.yml" get_manim_dir(), "manimlib", "tex_templates.yml"
), encoding="utf-8") as tex_templates_file: ), encoding="utf-8") as tex_templates_file:
@ -143,7 +143,7 @@ def create_tex_svg(full_tex: str, svg_file: str, compiler: str) -> None:
# TODO, perhaps this should live elsewhere # TODO, perhaps this should live elsewhere
@contextmanager @contextmanager
def display_during_execution(message: str) -> None: def display_during_execution(message: str):
# Merge into a single line # Merge into a single line
to_print = message.replace("\n", " ") to_print = message.replace("\n", " ")
max_characters = os.get_terminal_size().columns - 1 max_characters = os.get_terminal_size().columns - 1