Refactor MTex

This commit is contained in:
YishiMichael 2022-03-27 00:29:22 +08:00
parent 9ac1805e7e
commit e44a2fc8c6
No known key found for this signature in database
GPG key ID: EC615C0C5A86BC80

View file

@ -6,7 +6,7 @@ import itertools as it
from types import MethodType from types import MethodType
from typing import Iterable, Union, Sequence from typing import Iterable, Union, Sequence
from manimlib.constants import BLACK, WHITE from manimlib.constants import WHITE
from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.svg.svg_mobject import SVGMobject
from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.utils.color import color_to_int_rgb from manimlib.utils.color import color_to_int_rgb
@ -32,9 +32,6 @@ SCALE_FACTOR_PER_FONT_POINT = 0.001
class _TexSVG(SVGMobject): class _TexSVG(SVGMobject):
CONFIG = { CONFIG = {
"height": None, "height": None,
"svg_default": {
"fill_color": WHITE,
},
"stroke_width": 0, "stroke_width": 0,
"stroke_color": WHITE, "stroke_color": WHITE,
"path_string_config": { "path_string_config": {
@ -46,6 +43,7 @@ class _TexSVG(SVGMobject):
class MTex(_TexSVG): class MTex(_TexSVG):
CONFIG = { CONFIG = {
"base_color": WHITE,
"font_size": 48, "font_size": 48,
"alignment": "\\centering", "alignment": "\\centering",
"tex_environment": "align*", "tex_environment": "align*",
@ -54,13 +52,14 @@ class MTex(_TexSVG):
"use_plain_tex": False, "use_plain_tex": False,
} }
def __init__(self, tex_string: str, **kwargs): def __init__(self, string: str, **kwargs):
digest_config(self, kwargs) digest_config(self, kwargs)
tex_string = tex_string.strip() string = string.strip()
# Prevent from passing an empty string. # Prevent from passing an empty string.
if not tex_string: if not string:
tex_string = "\\quad" string = "\\quad"
self.tex_string = tex_string self.tex_string = string
self.string = string
super().__init__(**kwargs) super().__init__(**kwargs)
self.set_color_by_tex_to_color_map(self.tex_to_color_map) self.set_color_by_tex_to_color_map(self.tex_to_color_map)
@ -72,7 +71,8 @@ class MTex(_TexSVG):
self.__class__.__name__, self.__class__.__name__,
self.svg_default, self.svg_default,
self.path_string_config, self.path_string_config,
self.tex_string, self.string,
self.base_color,
self.alignment, self.alignment,
self.tex_environment, self.tex_environment,
self.isolate, self.isolate,
@ -82,45 +82,43 @@ class MTex(_TexSVG):
def get_file_path(self) -> str: def get_file_path(self) -> str:
self.init_parser() self.init_parser()
self.base_color = self.svg_default["color"] \
or self.svg_default["fill_color"] or WHITE
self.use_plain_file = any([ self.use_plain_file = any([
self.use_plain_tex, self.use_plain_tex,
self.color_cmd_repl_items, self.color_cmd_repl_items,
self.base_color not in (BLACK, WHITE) self.base_color != WHITE
]) ])
return self.get_file_path_(use_plain_file=self.use_plain_file) return self.get_file_path_(use_plain_file=self.use_plain_file)
def get_file_path_(self, use_plain_file: bool) -> str: def get_file_path_(self, use_plain_file: bool) -> str:
if use_plain_file: if use_plain_file:
tex_string = "".join([ content = "".join([
"{{", "{{",
self.get_color_command(int(self.base_color[1:], 16)), self.get_color_command(self.color_to_int(self.base_color)),
self.tex_string, self.string,
"}}" "}}"
]) ])
else: else:
tex_string = self.labelled_tex_string content = self.get_labelled_string()
full_tex = self.get_tex_file_body(tex_string) full_tex = self.get_tex_file_body(content)
with display_during_execution(f"Writing \"{self.tex_string}\""): with display_during_execution(f"Writing \"{self.string}\""):
file_path = self.tex_to_svg_file_path(full_tex) file_path = self.tex_to_svg_file_path(full_tex)
return file_path return file_path
def get_tex_file_body(self, tex_string: str) -> str: def get_tex_file_body(self, content: str) -> str:
if self.tex_environment: if self.tex_environment:
tex_string = "\n".join([ content = "\n".join([
f"\\begin{{{self.tex_environment}}}", f"\\begin{{{self.tex_environment}}}",
tex_string, content,
f"\\end{{{self.tex_environment}}}" f"\\end{{{self.tex_environment}}}"
]) ])
if self.alignment: if self.alignment:
tex_string = "\n".join([self.alignment, tex_string]) content = "\n".join([self.alignment, content])
tex_config = get_tex_config() tex_config = get_tex_config()
return tex_config["tex_body"].replace( return tex_config["tex_body"].replace(
tex_config["text_to_replace"], tex_config["text_to_replace"],
tex_string content
) )
@staticmethod @staticmethod
@ -136,9 +134,7 @@ class MTex(_TexSVG):
if self.use_plain_file: if self.use_plain_file:
file_path = self.get_file_path_(use_plain_file=False) file_path = self.get_file_path_(use_plain_file=False)
labelled_svg_glyphs = _TexSVG( labelled_svg_glyphs = _TexSVG(file_path)
file_path, svg_default={"fill_color": BLACK}
)
predefined_colors = [ predefined_colors = [
labelled_glyph.get_fill_color() labelled_glyph.get_fill_color()
for labelled_glyph in self.submobjects for labelled_glyph in self.submobjects
@ -166,7 +162,7 @@ class MTex(_TexSVG):
for submob, label, submob_tex in zip( for submob, label, submob_tex in zip(
submobjects, submob_labels, submob_tex_strings submobjects, submob_labels, submob_tex_strings
): ):
submob.submob_label = label submob.label = label
submob.tex_string = submob_tex submob.tex_string = submob_tex
# Support `get_tex()` method here. # Support `get_tex()` method here.
submob.get_tex = MethodType(lambda inst: inst.tex_string, submob) submob.get_tex = MethodType(lambda inst: inst.tex_string, submob)
@ -175,13 +171,17 @@ class MTex(_TexSVG):
## Static methods ## Static methods
@staticmethod @staticmethod
def color_to_label(color: ManimColor) -> int: def color_to_int(color: ManimColor) -> int:
r, g, b = color_to_int_rgb(color) r, g, b = color_to_int_rgb(color)
rg = r * 256 + g rg = r * 256 + g
rgb = rg * 256 + b return rg * 256 + b
if rgb == 16777215: # white
return 0 @staticmethod
return rgb def color_to_label(color: ManimColor) -> int:
result = MTex.color_to_int(color)
if result == 16777215: # white
return -1
return result
@staticmethod @staticmethod
def get_color_command(label: int) -> str: def get_color_command(label: int) -> str:
@ -227,6 +227,7 @@ class MTex(_TexSVG):
def init_parser(self) -> None: def init_parser(self) -> None:
self.additional_substrings = self.get_additional_substrings() self.additional_substrings = self.get_additional_substrings()
self.full_span = self.get_full_span()
self.backslash_indices = self.get_backslash_indices() self.backslash_indices = self.get_backslash_indices()
self.left_brace_indices, self.right_brace_indices = \ self.left_brace_indices, self.right_brace_indices = \
self.get_left_and_right_indices() self.get_left_and_right_indices()
@ -236,15 +237,15 @@ class MTex(_TexSVG):
self.script_content_spans = self.get_script_content_spans() self.script_content_spans = self.get_script_content_spans()
self.double_braces_spans = self.get_double_braces_spans() self.double_braces_spans = self.get_double_braces_spans()
self.stripped_substrings = self.get_stripped_substrings() self.stripped_substrings = self.get_stripped_substrings()
self.specified_spans = self.get_specified_spans()
self.specified_substrings = self.get_specified_substrings() self.specified_substrings = self.get_specified_substrings()
self.specified_spans, self.extended_specified_spans = \
self.get_specified_spans()
self.tex_span_list = self.get_tex_span_list() self.tex_span_list = self.get_tex_span_list()
self.extended_tex_span_list = self.get_extended_tex_span_list() self.extended_tex_span_list = self.get_extended_tex_span_list()
self.isolated_substrings = self.get_isolated_substrings() self.isolated_substrings = self.get_isolated_substrings()
self.containing_labels_dict = self.get_containing_labels_dict() self.containing_labels_dict = self.get_containing_labels_dict()
self.color_cmd_repl_items = self.get_color_cmd_repl_items() self.color_cmd_repl_items = self.get_color_cmd_repl_items()
self.span_repl_dict = self.get_span_repl_dict() self.span_repl_dict = self.get_span_repl_dict()
self.labelled_tex_string = self.get_labelled_tex_string()
def get_additional_substrings(self) -> list[str]: def get_additional_substrings(self) -> list[str]:
return list(it.chain( return list(it.chain(
@ -252,28 +253,31 @@ class MTex(_TexSVG):
self.isolate self.isolate
)) ))
def get_full_span(self) -> tuple[int, int]:
return (0, len(self.string))
def get_backslash_indices(self) -> list[int]: def get_backslash_indices(self) -> list[int]:
# Newlines (`\\`) don't count. # Newlines (`\\`) don't count.
return [ return [
match_obj.end() - 1 match_obj.end() - 1
for match_obj in re.finditer(r"\\+", self.tex_string) for match_obj in re.finditer(r"\\+", self.string)
if len(match_obj.group()) % 2 == 1 if len(match_obj.group()) % 2 == 1
] ]
def get_left_and_right_indices(self) -> list[tuple[int, int]]: def get_left_and_right_indices(self) -> list[tuple[int, int]]:
tex_string = self.tex_string string = self.string
indices = list(filter( indices = list(filter(
lambda index: index - 1 not in self.backslash_indices, lambda index: index - 1 not in self.backslash_indices,
[ [
match_obj.start() match_obj.start()
for match_obj in re.finditer(r"[{}]", tex_string) for match_obj in re.finditer(r"[{}]", string)
] ]
)) ))
left_brace_indices = [] left_brace_indices = []
right_brace_indices = [] right_brace_indices = []
left_brace_indices_stack = [] left_brace_indices_stack = []
for index in indices: for index in indices:
if tex_string[index] == "{": if string[index] == "{":
left_brace_indices_stack.append(index) left_brace_indices_stack.append(index)
else: else:
if not left_brace_indices_stack: if not left_brace_indices_stack:
@ -288,7 +292,7 @@ class MTex(_TexSVG):
def get_script_char_spans(self) -> list[tuple[int, int]]: def get_script_char_spans(self) -> list[tuple[int, int]]:
return [ return [
match_obj.span() match_obj.span()
for match_obj in re.finditer(r"(\s*)[_^]\s*", self.tex_string) for match_obj in re.finditer(r"(\s*)[_^]\s*", self.string)
if match_obj.group(1) if match_obj.group(1)
or match_obj.start() - 1 not in self.backslash_indices or match_obj.start() - 1 not in self.backslash_indices
] ]
@ -296,14 +300,14 @@ class MTex(_TexSVG):
def get_skipped_indices(self) -> list[int]: def get_skipped_indices(self) -> list[int]:
return sorted(remove_list_redundancies([ return sorted(remove_list_redundancies([
match_obj.start() match_obj.start()
for match_obj in re.finditer(r"\s", self.tex_string) for match_obj in re.finditer(r"\s", self.string)
] + list(it.chain(*[ ] + list(it.chain(*[
range(*script_char_span) range(*script_char_span)
for script_char_span in self.script_char_spans for script_char_span in self.script_char_spans
])))) ]))))
def get_script_spans(self) -> list[tuple[int, int]]: def get_script_spans(self) -> list[tuple[int, int]]:
tex_string = self.tex_string string = self.string
result = [] result = []
brace_indices_dict = dict(zip( brace_indices_dict = dict(zip(
self.left_brace_indices, self.right_brace_indices self.left_brace_indices, self.right_brace_indices
@ -313,7 +317,7 @@ class MTex(_TexSVG):
span_end = brace_indices_dict[span_begin] + 1 span_end = brace_indices_dict[span_begin] + 1
else: else:
pattern = re.compile(r"[a-zA-Z0-9]|\\[a-zA-Z]+") pattern = re.compile(r"[a-zA-Z0-9]|\\[a-zA-Z]+")
match_obj = pattern.match(tex_string, pos=span_begin) match_obj = pattern.match(string, pos=span_begin)
if not match_obj: if not match_obj:
script_name = { script_name = {
"_": "subscript", "_": "subscript",
@ -361,64 +365,68 @@ class MTex(_TexSVG):
def get_stripped_substrings(self) -> list[str]: def get_stripped_substrings(self) -> list[str]:
result = remove_list_redundancies([ result = remove_list_redundancies([
string.strip() substr.strip()
for string in self.additional_substrings for substr in self.additional_substrings
]) ])
if "" in result: if "" in result:
result.remove("") result.remove("")
return result return result
def get_specified_spans(self) -> list[tuple[int, int]]: def get_specified_substrings(self) -> list[str]:
result = self.double_braces_spans.copy() return remove_list_redundancies([
tex_string = self.tex_string self.string[slice(*double_braces_span)]
for double_braces_span in self.double_braces_spans
] + list(filter(
lambda s: s in self.string,
self.stripped_substrings
)))
def get_specified_spans(
self
) -> tuple[list[tuple[int, int]], list[tuple[int, int]]]:
tex_spans = sorted(remove_list_redundancies([
self.full_span,
*self.double_braces_spans,
*[
match_obj.span()
for substr in self.stripped_substrings
for match_obj in re.finditer(re.escape(substr), self.string)
]
]), key=lambda t: (t[0], -t[1]))
result = []
extended_result = []
script_spans_dict = dict(self.script_spans)
reversed_script_spans_dict = dict([ reversed_script_spans_dict = dict([
script_span[::-1] for script_span in self.script_spans script_span[::-1] for script_span in self.script_spans
]) ])
for string in self.stripped_substrings: for tex_span in tex_spans:
for match_obj in re.finditer(re.escape(string), tex_string): if tex_span in self.script_content_spans:
span_begin, span_end = match_obj.span() continue
while span_end in reversed_script_spans_dict.keys(): span_begin, span_end = tex_span
span_end = reversed_script_spans_dict[span_end] extended_span_end = span_end
if span_begin >= span_end: while span_end in reversed_script_spans_dict.keys():
continue span_end = reversed_script_spans_dict[span_end]
result.append((span_begin, span_end)) while extended_span_end in script_spans_dict.keys():
return list(filter( extended_span_end = script_spans_dict[extended_span_end]
lambda tex_span: tex_span not in self.script_content_spans, specified_span = (span_begin, span_end)
remove_list_redundancies(result) extended_specified_span = (span_begin, extended_span_end)
)) if span_begin >= span_end:
continue
def get_specified_substrings(self) -> list[str]: if extended_specified_span in result:
return remove_list_redundancies([ continue
self.tex_string[slice(*double_braces_span)] result.append(specified_span)
for double_braces_span in self.double_braces_spans extended_result.append(extended_specified_span)
] + list(filter( return result, extended_result
lambda s: s in self.tex_string,
self.additional_substrings
)))
def get_tex_span_list(self) -> list[tuple[int, int]]: def get_tex_span_list(self) -> list[tuple[int, int]]:
return [ return self.specified_spans + self.script_content_spans
(0, len(self.tex_string)),
*self.script_content_spans,
*self.specified_spans
]
def get_extended_tex_span_list(self) -> list[tuple[int, int]]: def get_extended_tex_span_list(self) -> list[tuple[int, int]]:
extended_specified_spans = [] return self.extended_specified_spans + self.script_content_spans
script_spans_dict = dict(self.script_spans)
for span_begin, span_end in self.specified_spans:
while span_end in script_spans_dict.keys():
span_end = script_spans_dict[span_end]
extended_specified_spans.append((span_begin, span_end))
return [
(0, len(self.tex_string)),
*self.script_content_spans,
*extended_specified_spans
]
def get_isolated_substrings(self) -> list[str]: def get_isolated_substrings(self) -> list[str]:
return remove_list_redundancies([ return remove_list_redundancies([
self.tex_string[slice(*tex_span)] self.string[slice(*tex_span)]
for tex_span in self.tex_span_list for tex_span in self.tex_span_list
]) ])
@ -434,37 +442,38 @@ class MTex(_TexSVG):
result[span_0].append(span_index) result[span_0].append(span_index)
elif span_0[0] < span_1[0] < span_0[1] < span_1[1]: elif span_0[0] < span_1[0] < span_0[1] < span_1[1]:
string_0, string_1 = [ string_0, string_1 = [
self.tex_string[slice(*tex_span)] self.string[slice(*tex_span)]
for tex_span in [span_0, span_1] for tex_span in [span_0, span_1]
] ]
raise ValueError( raise ValueError(
"Partially overlapping substrings detected: " "Partially overlapping substrings detected: "
f"'{string_0}' and '{string_1}'" f"'{string_0}' and '{string_1}'"
) )
result[self.full_span] = list(range(len(tex_span_list)))
return result return result
def get_color_cmd_repl_items(self) -> list[tuple[tuple[int, int], str]]: def get_color_cmd_repl_items(self) -> list[tuple[tuple[int, int], str]]:
color_related_commands_dict = { color_related_command_items = [
"color": 1, ("color", 1, ""),
"textcolor": 1, ("textcolor", 1, ""),
"pagecolor": 1, ("pagecolor", 1, "\\pagecolor{white}"),
"colorbox": 1, ("colorbox", 1, "\\colorbox{white}"),
"fcolorbox": 2, ("fcolorbox", 2, "\\fcolorbox{white}{white}"),
} ]
result = [] result = []
tex_string = self.tex_string string = self.string
backslash_indices = self.backslash_indices backslash_indices = self.backslash_indices
left_indices = self.left_brace_indices left_indices = self.left_brace_indices
brace_indices_dict = dict(zip( brace_indices_dict = dict(zip(
self.left_brace_indices, self.right_brace_indices self.left_brace_indices, self.right_brace_indices
)) ))
for cmd_name, n_braces in color_related_commands_dict.items(): for cmd_name, n_braces, repl_str in color_related_command_items:
pattern = cmd_name + r"(?![a-zA-Z])" pattern = cmd_name + r"(?![a-zA-Z])"
for match_obj in re.finditer(pattern, tex_string): for match_obj in re.finditer(pattern, string):
span_begin, span_end = match_obj.span() span_begin, span_end = match_obj.span()
if span_begin - 1 not in backslash_indices: span_begin -= 1
if span_begin not in backslash_indices:
continue continue
repl_str = cmd_name + n_braces * "{black}"
for _ in range(n_braces): for _ in range(n_braces):
left_index = min(filter( left_index = min(filter(
lambda index: index >= span_end, left_indices lambda index: index >= span_end, left_indices
@ -481,7 +490,7 @@ class MTex(_TexSVG):
-tex_span[1 - flag], -tex_span[1 - flag],
("{{" + self.get_color_command(label), "}}")[flag] ("{{" + self.get_color_command(label), "}}")[flag]
) )
for label, tex_span in enumerate(self.extended_tex_span_list) for label, tex_span in enumerate(self.tex_span_list)
for flag in range(2) for flag in range(2)
])) ]))
result = { result = {
@ -493,9 +502,9 @@ class MTex(_TexSVG):
result.update(self.color_cmd_repl_items) result.update(self.color_cmd_repl_items)
return result return result
def get_labelled_tex_string(self) -> str: def get_labelled_string(self) -> str:
if not self.span_repl_dict: if not self.span_repl_dict:
return self.tex_string return self.string
spans = sorted(self.span_repl_dict.keys()) spans = sorted(self.span_repl_dict.keys())
if not all( if not all(
@ -506,10 +515,10 @@ class MTex(_TexSVG):
span_ends, span_begins = zip(*spans) span_ends, span_begins = zip(*spans)
string_pieces = [ string_pieces = [
self.tex_string[slice(*span)] self.string[slice(*span)]
for span in zip( for span in zip(
(0, *span_begins), (0, *span_begins),
(*span_ends, len(self.tex_string)) (*span_ends, len(self.string))
) )
] ]
repl_strs = [ repl_strs = [
@ -521,7 +530,7 @@ class MTex(_TexSVG):
def get_submob_tex_strings(self, submob_labels: list[int]) -> list[str]: def get_submob_tex_strings(self, submob_labels: list[int]) -> list[str]:
ordered_tex_spans = [ ordered_tex_spans = [
self.tex_span_list[label] self.tex_span_list[label] if label != -1 else self.full_span
for label in submob_labels for label in submob_labels
] ]
ordered_containing_labels = [ ordered_containing_labels = [
@ -546,7 +555,7 @@ class MTex(_TexSVG):
] ]
string_span_ends.append(ordered_span_ends[-1]) string_span_ends.append(ordered_span_ends[-1])
tex_string = self.tex_string string = self.string
left_indices = self.left_brace_indices left_indices = self.left_brace_indices
right_indices = self.right_brace_indices right_indices = self.right_brace_indices
skipped_indices = sorted(it.chain( skipped_indices = sorted(it.chain(
@ -575,7 +584,7 @@ class MTex(_TexSVG):
unclosed_left_brace -= 1 unclosed_left_brace -= 1
result.append("".join([ result.append("".join([
unclosed_right_brace * "{", unclosed_right_brace * "{",
tex_string[span_begin:span_end], string[span_begin:span_end],
unclosed_left_brace * "}" unclosed_left_brace * "}"
])) ]))
return result return result
@ -615,7 +624,7 @@ class MTex(_TexSVG):
custom_span custom_span
) )
if tex_spans is None: if tex_spans is None:
tex = self.tex_string[slice(*custom_span)] tex = self.string[slice(*custom_span)]
raise ValueError(f"Failed to match mobjects from tex: \"{tex}\"") raise ValueError(f"Failed to match mobjects from tex: \"{tex}\"")
labels = set(it.chain(*[ labels = set(it.chain(*[
@ -623,7 +632,7 @@ class MTex(_TexSVG):
for tex_span in tex_spans for tex_span in tex_spans
])) ]))
return VGroup(*filter( return VGroup(*filter(
lambda submob: submob.submob_label in labels, lambda submob: submob.label in labels,
self.submobjects self.submobjects
)) ))
@ -631,7 +640,7 @@ class MTex(_TexSVG):
return VGroup(*[ return VGroup(*[
self.get_part_by_custom_span(match_obj.span()) self.get_part_by_custom_span(match_obj.span())
for match_obj in re.finditer( for match_obj in re.finditer(
re.escape(tex.strip()), self.tex_string re.escape(tex.strip()), self.string
) )
]) ])
@ -665,7 +674,7 @@ class MTex(_TexSVG):
return self.indices_of_part(part) return self.indices_of_part(part)
def get_tex(self) -> str: def get_tex(self) -> str:
return self.tex_string return self.string
def get_submob_tex(self) -> list[str]: def get_submob_tex(self) -> list[str]:
return [ return [