mirror of
https://github.com/3b1b/manim.git
synced 2025-08-05 16:49:03 +00:00
refactor: refactor StringMobject
This commit is contained in:
parent
7ffc7b33f7
commit
19c757ec90
5 changed files with 190 additions and 138 deletions
|
@ -131,9 +131,9 @@ class MTex(StringMobject):
|
|||
|
||||
@staticmethod
|
||||
def replace_for_matching(match_obj: re.Match) -> str:
|
||||
if match_obj.group("command"):
|
||||
return match_obj.group()
|
||||
return ""
|
||||
if match_obj.group("script"):
|
||||
return ""
|
||||
return match_obj.group()
|
||||
|
||||
@staticmethod
|
||||
def get_attr_dict_from_command_pair(
|
||||
|
|
|
@ -90,21 +90,22 @@ class StringMobject(SVGMobject, ABC):
|
|||
super().generate_mobject()
|
||||
|
||||
labels_count = len(self.labelled_spans)
|
||||
if not labels_count:
|
||||
if labels_count == 1:
|
||||
for submob in self.submobjects:
|
||||
submob.label = -1
|
||||
submob.label = 0
|
||||
return
|
||||
|
||||
labelled_content = self.get_content(is_labelled=True)
|
||||
file_path = self.get_file_path_by_content(labelled_content)
|
||||
labelled_svg = SVGMobject(file_path)
|
||||
#print(len(self.submobjects), len(labelled_svg.submobjects)) # ????
|
||||
if len(self.submobjects) != len(labelled_svg.submobjects):
|
||||
log.warning(
|
||||
"Cannot align submobjects of the labelled svg "
|
||||
"to the original svg. Skip the labelling process."
|
||||
)
|
||||
for submob in self.submobjects:
|
||||
submob.label = -1
|
||||
submob.label = labels_count - 1
|
||||
return
|
||||
|
||||
self.rearrange_submobjects_by_positions(labelled_svg)
|
||||
|
@ -115,15 +116,18 @@ class StringMobject(SVGMobject, ABC):
|
|||
color_int = self.hex_to_int(self.color_to_hex(
|
||||
labelled_svg_submob.get_fill_color()
|
||||
))
|
||||
if color_int > labels_count:
|
||||
if color_int >= labels_count:
|
||||
unrecognizable_colors.append(color_int)
|
||||
color_int = 0
|
||||
color_int = labels_count
|
||||
submob.label = color_int - 1
|
||||
if unrecognizable_colors:
|
||||
log.warning(
|
||||
"Unrecognizable color labels detected (%s, etc). "
|
||||
"Unrecognizable color labels detected (%s). "
|
||||
"The result could be unexpected.",
|
||||
self.int_to_hex(unrecognizable_colors[0])
|
||||
", ".join([
|
||||
self.int_to_hex(color)
|
||||
for color in unrecognizable_colors
|
||||
])
|
||||
)
|
||||
|
||||
def rearrange_submobjects_by_positions(
|
||||
|
@ -186,7 +190,7 @@ class StringMobject(SVGMobject, ABC):
|
|||
if spans is None:
|
||||
raise TypeError(f"Invalid selector: '{sel}'")
|
||||
result.extend(spans)
|
||||
return list(filter(lambda span: span[0] < span[1], result))
|
||||
return list(filter(lambda span: span[0] <= span[1], result))
|
||||
|
||||
@staticmethod
|
||||
def span_contains(span_0: Span, span_1: Span) -> bool:
|
||||
|
@ -214,73 +218,94 @@ class StringMobject(SVGMobject, ABC):
|
|||
val_list = list(vals)
|
||||
return list(zip(val_list[:-1], val_list[1:]))
|
||||
|
||||
#def get_complement_spans(
|
||||
# universal_span: Span, interval_spans: list[Span]
|
||||
#) -> list[Span]:
|
||||
# if not interval_spans:
|
||||
# return [universal_span]
|
||||
|
||||
# span_ends, span_starts = zip(*interval_spans)
|
||||
# return list(zip(
|
||||
# (universal_span[0], *span_starts),
|
||||
# (*span_ends, universal_span[1])
|
||||
# ))
|
||||
|
||||
def join_strs(strs: list[str], inserted_strs: list[str]) -> str:
|
||||
return "".join(it.chain(*zip(strs, (*inserted_strs, ""))))
|
||||
|
||||
command_matches = self.get_command_matches(self.string)
|
||||
command_spans = [match_obj.span() for match_obj in command_matches]
|
||||
#command_spans = [match_obj.span() for match_obj in command_matches]
|
||||
configured_items = self.get_configured_items()
|
||||
#configured_spans = [span for span, _ in configured_items]
|
||||
#configured_attr_dicts = [d for _, d in configured_items]
|
||||
categorized_spans = [
|
||||
[(0, len(self.string))], # TODO
|
||||
[span for span, _ in configured_items],
|
||||
self.find_spans_by_selector(self.isolate),
|
||||
self.find_spans_by_selector(self.protect),
|
||||
command_spans # TODO
|
||||
[match_obj.span() for match_obj in command_matches] # TODO
|
||||
]
|
||||
|
||||
sorted_items = sorted([
|
||||
(category, category_index, flag, *span[::flag])
|
||||
for category, spans in enumerate(categorized_spans)
|
||||
for category_index, span in enumerate(spans)
|
||||
for flag in (1, -1)
|
||||
], key=lambda t: (t[3], t[2], -t[4], t[2] * (t[0] + 1), t[2] * t[1])) # TODO
|
||||
], key=lambda t: (
|
||||
t[3], t[2] * (2 if t[3] != t[4] else -1), -t[4],
|
||||
t[2] * t[0], t[2] * t[1]
|
||||
)) # TODO
|
||||
|
||||
labelled_spans = []
|
||||
attr_dicts = []
|
||||
inserted_items = []
|
||||
#labelled_items = []
|
||||
|
||||
label = 0
|
||||
count = 0
|
||||
region_index = 0
|
||||
protect_level = 0
|
||||
bracket_levels = [0]
|
||||
region_levels = [0]
|
||||
open_command_stack = []
|
||||
open_stack = []
|
||||
|
||||
#protect_level_stack = []
|
||||
#bracket_level_stack = []
|
||||
#inserted_position_stack = []
|
||||
#index_items_len = 0 # label * 2
|
||||
#index_items_len = 0 # count * 2
|
||||
for category, i, flag, _, _ in sorted_items:
|
||||
if category in (2, 3):
|
||||
if category >= 3:
|
||||
if flag == 1:
|
||||
protect_level += 1
|
||||
continue
|
||||
protect_level -= 1
|
||||
if category == 2:
|
||||
if category == 3:
|
||||
continue
|
||||
region_index += 1
|
||||
command_match = command_matches[i]
|
||||
command_flag = self.get_command_flag(command_match)
|
||||
region_levels.append(region_levels[-1] + command_flag)
|
||||
if command_flag == 1:
|
||||
bracket_levels.append(bracket_levels[-1] + 1)
|
||||
open_command_stack.append(
|
||||
(command_match, region_index, label)
|
||||
(command_match, region_index, count)
|
||||
)
|
||||
continue
|
||||
elif command_flag == 0:
|
||||
continue
|
||||
bracket_levels.append(bracket_levels[-1] - 1)
|
||||
command_match_, region_index_, label_ = open_command_stack.pop()
|
||||
command_match_, region_index_, count_ = open_command_stack.pop()
|
||||
attr_dict = self.get_attr_dict_from_command_pair(
|
||||
command_match_, command_match
|
||||
)
|
||||
if attr_dict is None:
|
||||
continue
|
||||
span = (command_match_.end(), command_match.start())
|
||||
region_span = (region_index_, region_index - 1)
|
||||
else:
|
||||
if flag == 1:
|
||||
open_stack.append(
|
||||
(category, i, protect_level, region_index, label)
|
||||
(category, i, protect_level, region_index, count)
|
||||
)
|
||||
continue
|
||||
category_, i_, protect_level_, region_index_, label_ \
|
||||
category_, i_, protect_level_, region_index_, count_ \
|
||||
= open_stack.pop()
|
||||
span = categorized_spans[category][i]
|
||||
if (category_, i_) != (category, i):
|
||||
|
@ -292,80 +317,99 @@ class StringMobject(SVGMobject, ABC):
|
|||
continue
|
||||
if protect_level_ or protect_level:
|
||||
continue
|
||||
levels = bracket_levels[region_index_:region_index]
|
||||
if levels and (
|
||||
any(levels[0] > l for l in levels) or levels[0] < levels[-1]
|
||||
):
|
||||
ls = region_levels[region_index_:region_index + 1]
|
||||
if ls and (any(ls[0] > l for l in ls) or ls[0] < ls[-1]):
|
||||
log.warning(
|
||||
"Cannot handle substring '%s'", get_substr(span)
|
||||
)
|
||||
continue
|
||||
attr_dict = {} if category == 1 else configured_items[i][1]
|
||||
pos = label_ * 2
|
||||
attr_dict = configured_items[i][1] if category == 1 else {}
|
||||
region_span = (region_index_, region_index)
|
||||
#labelled_items.append(
|
||||
# (span, region_span, (count_, count), attr_dict)
|
||||
#)
|
||||
pos = count_ * 2
|
||||
labelled_spans.append(span)
|
||||
attr_dicts.append(attr_dict)
|
||||
inserted_items.insert(pos, (label, 1, span[0], region_index_))
|
||||
inserted_items.append((label, -1, span[1], region_index))
|
||||
label += 1
|
||||
inserted_items.insert(pos, (count, 1, span[0], region_span[0]))
|
||||
inserted_items.append((count, -1, span[1], region_span[1]))
|
||||
count += 1
|
||||
|
||||
extended_inserted_items = [
|
||||
(-1, 1, 0, 0),
|
||||
*inserted_items,
|
||||
(-1, -1, len(self.string), len(command_matches))
|
||||
]
|
||||
#labelled_spans = []
|
||||
#attr_dicts = []
|
||||
#inserted_items = []
|
||||
|
||||
#inserted_items.insert(0, (-1, 1, 0, 0))
|
||||
#inserted_items.append((-1, -1, len(self.string), region_index))
|
||||
|
||||
inserted_label_items = [
|
||||
(label, flag)
|
||||
for label, flag, _, _ in extended_inserted_items
|
||||
for label, flag, _, _ in inserted_items
|
||||
]
|
||||
inserted_indices = [
|
||||
index
|
||||
for _, _, index, _ in extended_inserted_items
|
||||
]
|
||||
inserted_region_indices = [
|
||||
region_index
|
||||
for _, _, _, region_index in extended_inserted_items
|
||||
]
|
||||
|
||||
inserted_interval_spans = get_neighbouring_pairs(inserted_indices)
|
||||
inserted_interval_region_spans = get_neighbouring_pairs(inserted_region_indices)
|
||||
|
||||
def get_complement_spans(
|
||||
universal_span: Span, interval_spans: list[Span]
|
||||
) -> list[Span]:
|
||||
if not interval_spans:
|
||||
return [universal_span]
|
||||
|
||||
span_ends, span_starts = zip(*interval_spans)
|
||||
return list(zip(
|
||||
(universal_span[0], *span_starts),
|
||||
(*span_ends, universal_span[1])
|
||||
))
|
||||
|
||||
def join_strs(strs: list[str], inserted_strs: list[str]) -> str:
|
||||
return "".join(it.chain(*zip(strs, (*inserted_strs, ""))))
|
||||
|
||||
subpieces_groups = [
|
||||
[
|
||||
get_substr(s)
|
||||
for s in get_complement_spans(
|
||||
span, command_spans[slice(*region_range)]
|
||||
#inserted_interval_spans = []
|
||||
#command_matches_lists = []
|
||||
#subpieces_lists = []
|
||||
content_pieces = []
|
||||
matching_pieces = []
|
||||
for (_, _, prev_index, prev_region_index), (_, _, next_index, next_region_index) in get_neighbouring_pairs(inserted_items):
|
||||
region_matches = command_matches[prev_region_index:next_region_index]
|
||||
#command_matches_lists.append(region_matches)
|
||||
subpieces = [
|
||||
get_substr((start, end))
|
||||
for start, end in zip(
|
||||
[prev_index, *(m.end() for m in region_matches)],
|
||||
[*(m.start() for m in region_matches), next_index]
|
||||
)
|
||||
]
|
||||
for span, region_range in zip(inserted_interval_spans, inserted_interval_region_spans)
|
||||
]
|
||||
content_pieces.append(join_strs(subpieces, [
|
||||
self.replace_for_content(m) for m in region_matches
|
||||
]))
|
||||
matching_pieces.append(join_strs(subpieces, [
|
||||
self.replace_for_matching(m) for m in region_matches
|
||||
]))
|
||||
#subpieces_lists.append([
|
||||
# get_substr(s)
|
||||
# for s in get_complement_spans(
|
||||
# (prev_index, next_index),
|
||||
# [m.span() for m in region_matches]
|
||||
# )
|
||||
#])
|
||||
|
||||
def get_replaced_pieces(replace_func: Callable[[re.Match], str]) -> list[str]:
|
||||
return [
|
||||
join_strs(subpieces, [
|
||||
replace_func(command_match)
|
||||
for command_match in command_matches[slice(*region_range)]
|
||||
])
|
||||
for subpieces, region_range in zip(subpieces_groups, inserted_interval_region_spans)
|
||||
]
|
||||
|
||||
content_pieces = get_replaced_pieces(self.replace_for_content)
|
||||
matching_pieces = get_replaced_pieces(self.replace_for_matching)
|
||||
|
||||
#inserted_interval_spans = get_neighbouring_pairs([
|
||||
# index
|
||||
# for _, _, index, _ in inserted_items
|
||||
#])
|
||||
#command_matches_lists = [
|
||||
# command_matches[slice(*region_range)]
|
||||
# for region_range in get_neighbouring_pairs([
|
||||
# region_index
|
||||
# for _, _, _, region_index in inserted_items
|
||||
# ])
|
||||
#]
|
||||
|
||||
#subpieces_lists = [
|
||||
# [
|
||||
# get_substr(s)
|
||||
# for s in get_complement_spans(
|
||||
# span, [m.span() for m in match_list]
|
||||
# )
|
||||
# ]
|
||||
# for span, match_list in zip(inserted_interval_spans, command_matches_lists)
|
||||
#]
|
||||
|
||||
#def get_replaced_pieces(replace_func: Callable[[re.Match], str]) -> list[str]:
|
||||
# return [
|
||||
# join_strs(subpieces, [
|
||||
# replace_func(command_match)
|
||||
# for command_match in match_list
|
||||
# ])
|
||||
# for subpieces, match_list in zip(subpieces_lists, command_matches_lists)
|
||||
# ]
|
||||
|
||||
#content_pieces = get_replaced_pieces(self.replace_for_content)
|
||||
#matching_pieces = get_replaced_pieces(self.replace_for_matching)
|
||||
|
||||
def get_content(is_labelled: bool) -> str:
|
||||
inserted_strings = [
|
||||
|
@ -400,12 +444,16 @@ class StringMobject(SVGMobject, ABC):
|
|||
)
|
||||
]
|
||||
|
||||
def get_index(label, flag):
|
||||
def get_region_index(label, flag):
|
||||
#if label == -1:
|
||||
# if flag == 1:
|
||||
# return 0
|
||||
# return len(inserted_label_items) - 1
|
||||
return inserted_label_items.index((label, flag))
|
||||
|
||||
def get_labelled_span(label):
|
||||
if label == -1:
|
||||
return (0, len(self.string))
|
||||
#if label == -1:
|
||||
# return (0, len(self.string))
|
||||
return labelled_spans[label]
|
||||
|
||||
def label_contains(label_0, label_1):
|
||||
|
@ -413,53 +461,54 @@ class StringMobject(SVGMobject, ABC):
|
|||
get_labelled_span(label_0), get_labelled_span(label_1)
|
||||
)
|
||||
|
||||
#piece_starts = [
|
||||
# get_index(group_labels[0], 1),
|
||||
# *(
|
||||
# get_index(curr_label, 1)
|
||||
# if label_contains(prev_label, curr_label)
|
||||
# else get_index(prev_label, -1)
|
||||
# for prev_label, curr_label in get_neighbouring_pairs(
|
||||
# group_labels
|
||||
# )
|
||||
# )
|
||||
#]
|
||||
#piece_ends = [
|
||||
# *(
|
||||
# get_index(curr_label, -1)
|
||||
# if label_contains(next_label, curr_label)
|
||||
# else get_index(next_label, 1)
|
||||
# for curr_label, next_label in get_neighbouring_pairs(
|
||||
# group_labels
|
||||
# )
|
||||
# ),
|
||||
# get_index(group_labels[-1], -1)
|
||||
#]
|
||||
|
||||
piece_ranges = get_complement_spans(
|
||||
(get_index(group_labels[0], 1), get_index(group_labels[-1], -1)),
|
||||
[
|
||||
(
|
||||
get_index(next_label, 1)
|
||||
if label_contains(prev_label, next_label)
|
||||
else get_index(prev_label, -1),
|
||||
get_index(prev_label, -1)
|
||||
if label_contains(next_label, prev_label)
|
||||
else get_index(next_label, 1)
|
||||
)
|
||||
for prev_label, next_label in get_neighbouring_pairs(
|
||||
piece_starts = [
|
||||
get_region_index(group_labels[0], 1),
|
||||
*(
|
||||
get_region_index(curr_label, 1)
|
||||
if label_contains(prev_label, curr_label)
|
||||
else get_region_index(prev_label, -1)
|
||||
for prev_label, curr_label in get_neighbouring_pairs(
|
||||
group_labels
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
]
|
||||
piece_ends = [
|
||||
*(
|
||||
get_region_index(curr_label, -1)
|
||||
if label_contains(next_label, curr_label)
|
||||
else get_region_index(next_label, 1)
|
||||
for curr_label, next_label in get_neighbouring_pairs(
|
||||
group_labels
|
||||
)
|
||||
),
|
||||
get_region_index(group_labels[-1], -1)
|
||||
]
|
||||
|
||||
#piece_ranges = get_complement_spans(
|
||||
# (get_region_index(group_labels[0], 1), get_region_index(group_labels[-1], -1)),
|
||||
# [
|
||||
# (
|
||||
# get_region_index(next_label, 1)
|
||||
# if label_contains(prev_label, next_label)
|
||||
# else get_region_index(prev_label, -1),
|
||||
# get_region_index(prev_label, -1)
|
||||
# if label_contains(next_label, prev_label)
|
||||
# else get_region_index(next_label, 1)
|
||||
# )
|
||||
# for prev_label, next_label in get_neighbouring_pairs(
|
||||
# group_labels
|
||||
# )
|
||||
# ]
|
||||
#)
|
||||
group_substrs = [
|
||||
re.sub(r"\s+", "", "".join(
|
||||
matching_pieces[slice(*piece_ranges)]
|
||||
matching_pieces[start:end]
|
||||
))
|
||||
for piece_ranges in piece_ranges
|
||||
for start, end in zip(piece_starts, piece_ends)
|
||||
]
|
||||
return list(zip(group_substrs, submob_indices_lists))
|
||||
|
||||
#print(labelled_spans)
|
||||
self.labelled_spans = labelled_spans
|
||||
self.get_content = get_content
|
||||
self.get_group_part_items_by_labels = get_group_part_items_by_labels
|
||||
|
@ -516,9 +565,7 @@ class StringMobject(SVGMobject, ABC):
|
|||
return [
|
||||
submob_index
|
||||
for submob_index, label in enumerate(self.labels)
|
||||
if label != -1 and self.span_contains(
|
||||
arbitrary_span, self.labelled_spans[label]
|
||||
)
|
||||
if self.span_contains(arbitrary_span, self.labelled_spans[label])
|
||||
]
|
||||
|
||||
def get_specified_part_items(self) -> list[tuple[str, list[int]]]:
|
||||
|
@ -527,7 +574,7 @@ class StringMobject(SVGMobject, ABC):
|
|||
self.string[slice(*span)],
|
||||
self.get_submob_indices_list_by_span(span)
|
||||
)
|
||||
for span in self.labelled_spans
|
||||
for span in self.labelled_spans[:-1]
|
||||
]
|
||||
|
||||
def get_group_part_items(self) -> list[tuple[str, list[int]]]:
|
||||
|
|
|
@ -80,6 +80,7 @@ class MarkupText(StringMobject):
|
|||
"t2w": {},
|
||||
"global_config": {},
|
||||
"local_configs": {},
|
||||
"disable_ligatures": True,
|
||||
"isolate": re.compile(r"\w+", re.U),
|
||||
}
|
||||
|
||||
|
@ -150,7 +151,8 @@ class MarkupText(StringMobject):
|
|||
self.t2s,
|
||||
self.t2w,
|
||||
self.global_config,
|
||||
self.local_configs
|
||||
self.local_configs,
|
||||
self.disable_ligatures
|
||||
)
|
||||
|
||||
def full2short(self, config: dict) -> None:
|
||||
|
@ -359,9 +361,8 @@ class MarkupText(StringMobject):
|
|||
"font_family": self.font,
|
||||
"font_style": self.slant,
|
||||
"font_weight": self.weight,
|
||||
"font_size": str(self.font_size * 1024),
|
||||
"font_size": str(round(self.font_size * 1024)),
|
||||
}
|
||||
global_attr_dict.update(self.global_config)
|
||||
# `line_height` attribute is supported since Pango 1.50.
|
||||
pango_version = manimpango.pango_version()
|
||||
if tuple(map(int, pango_version.split("."))) < (1, 50):
|
||||
|
@ -376,7 +377,10 @@ class MarkupText(StringMobject):
|
|||
global_attr_dict["line_height"] = str(
|
||||
((line_spacing_scale) + 1) * 0.6
|
||||
)
|
||||
if self.disable_ligatures:
|
||||
global_attr_dict["font_features"] = "liga=0,dlig=0,clig=0,hlig=0"
|
||||
|
||||
global_attr_dict.update(self.global_config)
|
||||
return tuple(
|
||||
self.get_command_string(
|
||||
global_attr_dict,
|
||||
|
@ -413,8 +417,9 @@ class Text(MarkupText):
|
|||
}
|
||||
|
||||
@staticmethod
|
||||
def get_command_pattern() -> str | None:
|
||||
return r"""[<>&"']"""
|
||||
def get_command_matches(string: str) -> list[re.Match]:
|
||||
pattern = re.compile(r"""[<>&"']""")
|
||||
return list(pattern.finditer(string))
|
||||
|
||||
@staticmethod
|
||||
def get_command_flag(match_obj: re.Match) -> int:
|
||||
|
|
|
@ -253,7 +253,7 @@ gfs_didot: |-
|
|||
\let\varphi\phi
|
||||
|
||||
# GFS NeoHellenic
|
||||
gfs_neoHellenic: |-
|
||||
gfs_neohellenic: |-
|
||||
\usepackage[T1]{fontenc}
|
||||
\renewcommand{\rmdefault}{neohellenic}
|
||||
\usepackage[LGRgreek]{mathastext}
|
||||
|
|
|
@ -16,7 +16,7 @@ SAVED_TEX_CONFIG = {}
|
|||
|
||||
|
||||
def get_tex_preamble(template_name: str) -> str:
|
||||
name = re.sub(r"[^a-zA-Z]", "_", template_name).lower()
|
||||
name = template_name.replace(" ", "_").lower()
|
||||
with open(os.path.join(
|
||||
get_manim_dir(), "manimlib", "tex_templates.yml"
|
||||
), encoding="utf-8") as tex_templates_file:
|
||||
|
@ -143,7 +143,7 @@ def create_tex_svg(full_tex: str, svg_file: str, compiler: str) -> None:
|
|||
|
||||
# TODO, perhaps this should live elsewhere
|
||||
@contextmanager
|
||||
def display_during_execution(message: str) -> None:
|
||||
def display_during_execution(message: str):
|
||||
# Merge into a single line
|
||||
to_print = message.replace("\n", " ")
|
||||
max_characters = os.get_terminal_size().columns - 1
|
||||
|
|
Loading…
Add table
Reference in a new issue