Support get_tex() for submobjects of MTex

This commit is contained in:
Michael W 2021-12-07 00:34:07 +08:00 committed by GitHub
parent d7dcc9d76f
commit 88d863c1d7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1,5 +1,6 @@
import itertools as it
import re
from types import MethodType
from manimlib.mobject.svg.svg_mobject import SVGMobject
from manimlib.mobject.types.vectorized_mobject import VMobject
@ -33,20 +34,11 @@ class _LabelledTex(SVGMobject):
color_str = "#" + "".join([c * 2 for c in color_str[1:]])
return int(color_str[1:], 16) - 1
@staticmethod
def href_str_to_id(href_str):
match_obj = re.match(r"^#g(\d+)-(\d+)$", href_str)
if not match_obj:
return -1
return match_obj.group(2)
def get_mobjects_from(self, element):
result = super().get_mobjects_from(element)
for mob in result:
if not hasattr(mob, "glyph_label"):
mob.glyph_label = -1
if not hasattr(mob, "glyph_id"):
mob.glyph_id = -1
try:
color_str = element.getAttribute("fill")
if color_str:
@ -55,21 +47,13 @@ class _LabelledTex(SVGMobject):
mob.glyph_label = glyph_label
except:
pass
try:
href_str = element.getAttribute("xlink:href")
if href_str:
glyph_id = _LabelledTex.href_str_to_id(href_str)
for mob in result:
mob.glyph_id = glyph_id
except:
pass
return result
class _TexSpan(object):
def __init__(self, script_type, label):
# script_type: 0 for normal, 1 for subscript, 2 for superscript.
# Only those spans with `label != -1` will be colored.
# Only those spans with `script_type == 0` will be colored.
self.script_type = script_type
self.label = label
self.containing_labels = []
@ -113,6 +97,7 @@ class MTex(VMobject):
])
self.build_submobjects()
self.sort_scripts_in_tex_order()
self.assign_submob_tex_strings()
self.init_colors()
self.set_color_by_tex_to_color_map(self.tex_to_color_map)
@ -135,13 +120,11 @@ class MTex(VMobject):
def get_neighbouring_pairs(iterable):
return list(adjacent_pairs(iterable))[:-1]
def add_tex_span(self, span_tuple, script_type=0):
def add_tex_span(self, span_tuple, script_type=0, label=-1):
if script_type == 0:
# Should be additionally labelled.
label = self.current_label
self.current_label += 1
else:
label = -1
tex_span = _TexSpan(script_type, label)
self.tex_spans_dict[span_tuple] = tex_span
@ -203,12 +186,13 @@ class MTex(VMobject):
label = self.tex_spans_dict[content_span].label
self.add_tex_span(
(token_begin, content_span[1]),
script_type=script_type
script_type=script_type,
label=label
)
def break_up_by_additional_strings(self):
additional_strings_to_break_up = remove_list_redundancies([
*self.isolate, *self.tex_to_color_map.keys()
*self.isolate, *self.tex_to_color_map.keys(), self.tex_string
])
if "" in additional_strings_to_break_up:
additional_strings_to_break_up.remove("")
@ -218,28 +202,29 @@ class MTex(VMobject):
tex_string = self.tex_string
all_span_tuples = list(self.tex_spans_dict.keys())
for string in additional_strings_to_break_up:
# Only matches non-overlapping strings.
# Only matches non-crossing strings.
for match_obj in re.finditer(re.escape(string), tex_string):
all_span_tuples.append(match_obj.span())
# Deconstruct spans with subscripts & superscripts.
script_spans_dict = dict([
span_tuple[::-1]
for span_tuple, tex_span in self.tex_spans_dict.items()
if tex_span.script_type != 0
])
for span_begin, span_end in all_span_tuples:
while span_end in script_spans_dict:
span_end = script_spans_dict[span_end]
if span_begin >= span_end:
continue
if span_begin in script_spans_dict.values():
# Deconstruct spans with subscripts & superscripts.
while span_end in script_spans_dict:
span_end = script_spans_dict[span_end]
if span_begin >= span_end:
continue
span_tuple = (span_begin, span_end)
if span_tuple not in self.tex_spans_dict:
self.add_tex_span(span_tuple)
def analyse_containing_labels(self):
for span_0, tex_span_0 in self.tex_spans_dict.items():
if tex_span_0.label == -1:
if tex_span_0.script_type != 0:
continue
for span_1, tex_span_1 in self.tex_spans_dict.items():
if span_1[0] <= span_0[0] and span_0[1] <= span_1[1]:
@ -269,11 +254,11 @@ class MTex(VMobject):
indices_with_labels = sorted([
(span_tuple[i], i, span_tuple[1 - i], tex_span.label)
for span_tuple, tex_span in self.tex_spans_dict.items()
if tex_span.label != -1
if tex_span.script_type == 0
for i in range(2)
], key=lambda t: (t[0], -t[1], -t[2]))
# Add one more item to ensure all the substrings are joined.
indices_with_labels.append((len(tex_string), 0, 0, -1))
indices_with_labels.append((len(tex_string), 0, 0, 0))
result = tex_string[: indices_with_labels[0][0]]
index_with_label_pairs = MTex.get_neighbouring_pairs(indices_with_labels)
@ -301,7 +286,6 @@ class MTex(VMobject):
if glyphs:
submobject = VGroup(*glyphs)
submobject.submob_label = glyphs[0].glyph_label
submobject.submob_id_tuple = tuple([glyph.glyph_id for glyph in glyphs])
new_submobjects.append(submobject)
new_glyphs = []
@ -350,6 +334,42 @@ class MTex(VMobject):
])
return self
def assign_submob_tex_strings(self):
tex_string = self.tex_string
label_dict = {
tex_span.label: (span_tuple, tex_span.containing_labels)
for span_tuple, tex_span in self.tex_spans_dict.items()
if tex_span.script_type == 0
}
# Use tex strings with "_", "^" included.
label_dict.update({
tex_span.label: (span_tuple, tex_span.containing_labels)
for span_tuple, tex_span in self.tex_spans_dict.items()
if tex_span.script_type != 0
})
curr_labels = [submob.submob_label for submob in self.submobjects]
prev_labels = [curr_labels[-1], *curr_labels[:-1]]
next_labels = [*curr_labels[1:], curr_labels[0]]
tex_string_spans = []
for curr_submob_label, prev_submob_label, next_submob_label in zip(
curr_labels, prev_labels, next_labels
):
curr_span_tuple, containing_labels = label_dict[curr_submob_label]
prev_span_tuple, _ = label_dict[prev_submob_label]
next_span_tuple, _ = label_dict[next_submob_label]
tex_string_spans.append([
prev_span_tuple[1] if prev_submob_label in containing_labels else curr_span_tuple[0],
next_span_tuple[0] if next_submob_label in containing_labels else curr_span_tuple[1]
])
tex_string_spans[0][0] = label_dict[curr_labels[0]][0][0]
tex_string_spans[-1][1] = label_dict[curr_labels[-1]][0][1]
for submob, tex_string_span in zip(self.submobjects, tex_string_spans):
submob.tex_string = tex_string[slice(*tex_string_span)]
# Support `get_tex()` method here.
submob.get_tex = MethodType(lambda inst: inst.tex_string, submob)
return self
def get_part_by_span_tuples(self, span_tuples):
labels = remove_list_redundancies(list(it.chain(*[
self.tex_spans_dict[span_tuple].containing_labels