diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 53dc15da..1a71783d 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -6,7 +6,7 @@ import itertools as it from types import MethodType from typing import Iterable, Union, Sequence -from manimlib.constants import BLACK, WHITE +from manimlib.constants import WHITE from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.utils.color import color_to_int_rgb @@ -32,9 +32,6 @@ SCALE_FACTOR_PER_FONT_POINT = 0.001 class _TexSVG(SVGMobject): CONFIG = { "height": None, - "svg_default": { - "fill_color": WHITE, - }, "stroke_width": 0, "stroke_color": WHITE, "path_string_config": { @@ -46,6 +43,7 @@ class _TexSVG(SVGMobject): class MTex(_TexSVG): CONFIG = { + "base_color": WHITE, "font_size": 48, "alignment": "\\centering", "tex_environment": "align*", @@ -54,13 +52,14 @@ class MTex(_TexSVG): "use_plain_tex": False, } - def __init__(self, tex_string: str, **kwargs): + def __init__(self, string: str, **kwargs): digest_config(self, kwargs) - tex_string = tex_string.strip() + string = string.strip() # Prevent from passing an empty string. - if not tex_string: - tex_string = "\\quad" - self.tex_string = tex_string + if not string: + string = "\\quad" + self.tex_string = string + self.string = string super().__init__(**kwargs) self.set_color_by_tex_to_color_map(self.tex_to_color_map) @@ -72,7 +71,8 @@ class MTex(_TexSVG): self.__class__.__name__, self.svg_default, self.path_string_config, - self.tex_string, + self.string, + self.base_color, self.alignment, self.tex_environment, self.isolate, @@ -82,45 +82,43 @@ class MTex(_TexSVG): def get_file_path(self) -> str: self.init_parser() - self.base_color = self.svg_default["color"] \ - or self.svg_default["fill_color"] or WHITE self.use_plain_file = any([ self.use_plain_tex, self.color_cmd_repl_items, - self.base_color not in (BLACK, WHITE) + self.base_color != WHITE ]) return self.get_file_path_(use_plain_file=self.use_plain_file) def get_file_path_(self, use_plain_file: bool) -> str: if use_plain_file: - tex_string = "".join([ + content = "".join([ "{{", - self.get_color_command(int(self.base_color[1:], 16)), - self.tex_string, + self.get_color_command(self.color_to_int(self.base_color)), + self.string, "}}" ]) else: - tex_string = self.labelled_tex_string + content = self.get_labelled_string() - full_tex = self.get_tex_file_body(tex_string) - with display_during_execution(f"Writing \"{self.tex_string}\""): + full_tex = self.get_tex_file_body(content) + with display_during_execution(f"Writing \"{self.string}\""): file_path = self.tex_to_svg_file_path(full_tex) return file_path - def get_tex_file_body(self, tex_string: str) -> str: + def get_tex_file_body(self, content: str) -> str: if self.tex_environment: - tex_string = "\n".join([ + content = "\n".join([ f"\\begin{{{self.tex_environment}}}", - tex_string, + content, f"\\end{{{self.tex_environment}}}" ]) if self.alignment: - tex_string = "\n".join([self.alignment, tex_string]) + content = "\n".join([self.alignment, content]) tex_config = get_tex_config() return tex_config["tex_body"].replace( tex_config["text_to_replace"], - tex_string + content ) @staticmethod @@ -136,9 +134,7 @@ class MTex(_TexSVG): if self.use_plain_file: file_path = self.get_file_path_(use_plain_file=False) - labelled_svg_glyphs = _TexSVG( - file_path, svg_default={"fill_color": BLACK} - ) + labelled_svg_glyphs = _TexSVG(file_path) predefined_colors = [ labelled_glyph.get_fill_color() for labelled_glyph in self.submobjects @@ -166,7 +162,7 @@ class MTex(_TexSVG): for submob, label, submob_tex in zip( submobjects, submob_labels, submob_tex_strings ): - submob.submob_label = label + submob.label = label submob.tex_string = submob_tex # Support `get_tex()` method here. submob.get_tex = MethodType(lambda inst: inst.tex_string, submob) @@ -175,13 +171,17 @@ class MTex(_TexSVG): ## Static methods @staticmethod - def color_to_label(color: ManimColor) -> int: + def color_to_int(color: ManimColor) -> int: r, g, b = color_to_int_rgb(color) rg = r * 256 + g - rgb = rg * 256 + b - if rgb == 16777215: # white - return 0 - return rgb + return rg * 256 + b + + @staticmethod + def color_to_label(color: ManimColor) -> int: + result = MTex.color_to_int(color) + if result == 16777215: # white + return -1 + return result @staticmethod def get_color_command(label: int) -> str: @@ -227,6 +227,7 @@ class MTex(_TexSVG): def init_parser(self) -> None: self.additional_substrings = self.get_additional_substrings() + self.full_span = self.get_full_span() self.backslash_indices = self.get_backslash_indices() self.left_brace_indices, self.right_brace_indices = \ self.get_left_and_right_indices() @@ -236,15 +237,15 @@ class MTex(_TexSVG): self.script_content_spans = self.get_script_content_spans() self.double_braces_spans = self.get_double_braces_spans() self.stripped_substrings = self.get_stripped_substrings() - self.specified_spans = self.get_specified_spans() self.specified_substrings = self.get_specified_substrings() + self.specified_spans, self.extended_specified_spans = \ + self.get_specified_spans() self.tex_span_list = self.get_tex_span_list() self.extended_tex_span_list = self.get_extended_tex_span_list() self.isolated_substrings = self.get_isolated_substrings() self.containing_labels_dict = self.get_containing_labels_dict() self.color_cmd_repl_items = self.get_color_cmd_repl_items() self.span_repl_dict = self.get_span_repl_dict() - self.labelled_tex_string = self.get_labelled_tex_string() def get_additional_substrings(self) -> list[str]: return list(it.chain( @@ -252,28 +253,31 @@ class MTex(_TexSVG): self.isolate )) + def get_full_span(self) -> tuple[int, int]: + return (0, len(self.string)) + def get_backslash_indices(self) -> list[int]: # Newlines (`\\`) don't count. return [ match_obj.end() - 1 - for match_obj in re.finditer(r"\\+", self.tex_string) + for match_obj in re.finditer(r"\\+", self.string) if len(match_obj.group()) % 2 == 1 ] def get_left_and_right_indices(self) -> list[tuple[int, int]]: - tex_string = self.tex_string + string = self.string indices = list(filter( lambda index: index - 1 not in self.backslash_indices, [ match_obj.start() - for match_obj in re.finditer(r"[{}]", tex_string) + for match_obj in re.finditer(r"[{}]", string) ] )) left_brace_indices = [] right_brace_indices = [] left_brace_indices_stack = [] for index in indices: - if tex_string[index] == "{": + if string[index] == "{": left_brace_indices_stack.append(index) else: if not left_brace_indices_stack: @@ -288,7 +292,7 @@ class MTex(_TexSVG): def get_script_char_spans(self) -> list[tuple[int, int]]: return [ match_obj.span() - for match_obj in re.finditer(r"(\s*)[_^]\s*", self.tex_string) + for match_obj in re.finditer(r"(\s*)[_^]\s*", self.string) if match_obj.group(1) or match_obj.start() - 1 not in self.backslash_indices ] @@ -296,14 +300,14 @@ class MTex(_TexSVG): def get_skipped_indices(self) -> list[int]: return sorted(remove_list_redundancies([ match_obj.start() - for match_obj in re.finditer(r"\s", self.tex_string) + for match_obj in re.finditer(r"\s", self.string) ] + list(it.chain(*[ range(*script_char_span) for script_char_span in self.script_char_spans ])))) def get_script_spans(self) -> list[tuple[int, int]]: - tex_string = self.tex_string + string = self.string result = [] brace_indices_dict = dict(zip( self.left_brace_indices, self.right_brace_indices @@ -313,7 +317,7 @@ class MTex(_TexSVG): span_end = brace_indices_dict[span_begin] + 1 else: pattern = re.compile(r"[a-zA-Z0-9]|\\[a-zA-Z]+") - match_obj = pattern.match(tex_string, pos=span_begin) + match_obj = pattern.match(string, pos=span_begin) if not match_obj: script_name = { "_": "subscript", @@ -361,64 +365,68 @@ class MTex(_TexSVG): def get_stripped_substrings(self) -> list[str]: result = remove_list_redundancies([ - string.strip() - for string in self.additional_substrings + substr.strip() + for substr in self.additional_substrings ]) if "" in result: result.remove("") return result - def get_specified_spans(self) -> list[tuple[int, int]]: - result = self.double_braces_spans.copy() - tex_string = self.tex_string + def get_specified_substrings(self) -> list[str]: + return remove_list_redundancies([ + self.string[slice(*double_braces_span)] + for double_braces_span in self.double_braces_spans + ] + list(filter( + lambda s: s in self.string, + self.stripped_substrings + ))) + + def get_specified_spans( + self + ) -> tuple[list[tuple[int, int]], list[tuple[int, int]]]: + tex_spans = sorted(remove_list_redundancies([ + self.full_span, + *self.double_braces_spans, + *[ + match_obj.span() + for substr in self.stripped_substrings + for match_obj in re.finditer(re.escape(substr), self.string) + ] + ]), key=lambda t: (t[0], -t[1])) + result = [] + extended_result = [] + script_spans_dict = dict(self.script_spans) reversed_script_spans_dict = dict([ script_span[::-1] for script_span in self.script_spans ]) - for string in self.stripped_substrings: - for match_obj in re.finditer(re.escape(string), tex_string): - span_begin, span_end = match_obj.span() - while span_end in reversed_script_spans_dict.keys(): - span_end = reversed_script_spans_dict[span_end] - if span_begin >= span_end: - continue - result.append((span_begin, span_end)) - return list(filter( - lambda tex_span: tex_span not in self.script_content_spans, - remove_list_redundancies(result) - )) - - def get_specified_substrings(self) -> list[str]: - return remove_list_redundancies([ - self.tex_string[slice(*double_braces_span)] - for double_braces_span in self.double_braces_spans - ] + list(filter( - lambda s: s in self.tex_string, - self.additional_substrings - ))) + for tex_span in tex_spans: + if tex_span in self.script_content_spans: + continue + span_begin, span_end = tex_span + extended_span_end = span_end + while span_end in reversed_script_spans_dict.keys(): + span_end = reversed_script_spans_dict[span_end] + while extended_span_end in script_spans_dict.keys(): + extended_span_end = script_spans_dict[extended_span_end] + specified_span = (span_begin, span_end) + extended_specified_span = (span_begin, extended_span_end) + if span_begin >= span_end: + continue + if extended_specified_span in result: + continue + result.append(specified_span) + extended_result.append(extended_specified_span) + return result, extended_result def get_tex_span_list(self) -> list[tuple[int, int]]: - return [ - (0, len(self.tex_string)), - *self.script_content_spans, - *self.specified_spans - ] + return self.specified_spans + self.script_content_spans def get_extended_tex_span_list(self) -> list[tuple[int, int]]: - extended_specified_spans = [] - script_spans_dict = dict(self.script_spans) - for span_begin, span_end in self.specified_spans: - while span_end in script_spans_dict.keys(): - span_end = script_spans_dict[span_end] - extended_specified_spans.append((span_begin, span_end)) - return [ - (0, len(self.tex_string)), - *self.script_content_spans, - *extended_specified_spans - ] + return self.extended_specified_spans + self.script_content_spans def get_isolated_substrings(self) -> list[str]: return remove_list_redundancies([ - self.tex_string[slice(*tex_span)] + self.string[slice(*tex_span)] for tex_span in self.tex_span_list ]) @@ -434,37 +442,38 @@ class MTex(_TexSVG): result[span_0].append(span_index) elif span_0[0] < span_1[0] < span_0[1] < span_1[1]: string_0, string_1 = [ - self.tex_string[slice(*tex_span)] + self.string[slice(*tex_span)] for tex_span in [span_0, span_1] ] raise ValueError( "Partially overlapping substrings detected: " f"'{string_0}' and '{string_1}'" ) + result[self.full_span] = list(range(len(tex_span_list))) return result def get_color_cmd_repl_items(self) -> list[tuple[tuple[int, int], str]]: - color_related_commands_dict = { - "color": 1, - "textcolor": 1, - "pagecolor": 1, - "colorbox": 1, - "fcolorbox": 2, - } + color_related_command_items = [ + ("color", 1, ""), + ("textcolor", 1, ""), + ("pagecolor", 1, "\\pagecolor{white}"), + ("colorbox", 1, "\\colorbox{white}"), + ("fcolorbox", 2, "\\fcolorbox{white}{white}"), + ] result = [] - tex_string = self.tex_string + string = self.string backslash_indices = self.backslash_indices left_indices = self.left_brace_indices brace_indices_dict = dict(zip( self.left_brace_indices, self.right_brace_indices )) - for cmd_name, n_braces in color_related_commands_dict.items(): + for cmd_name, n_braces, repl_str in color_related_command_items: pattern = cmd_name + r"(?![a-zA-Z])" - for match_obj in re.finditer(pattern, tex_string): + for match_obj in re.finditer(pattern, string): span_begin, span_end = match_obj.span() - if span_begin - 1 not in backslash_indices: + span_begin -= 1 + if span_begin not in backslash_indices: continue - repl_str = cmd_name + n_braces * "{black}" for _ in range(n_braces): left_index = min(filter( lambda index: index >= span_end, left_indices @@ -481,7 +490,7 @@ class MTex(_TexSVG): -tex_span[1 - flag], ("{{" + self.get_color_command(label), "}}")[flag] ) - for label, tex_span in enumerate(self.extended_tex_span_list) + for label, tex_span in enumerate(self.tex_span_list) for flag in range(2) ])) result = { @@ -493,9 +502,9 @@ class MTex(_TexSVG): result.update(self.color_cmd_repl_items) return result - def get_labelled_tex_string(self) -> str: + def get_labelled_string(self) -> str: if not self.span_repl_dict: - return self.tex_string + return self.string spans = sorted(self.span_repl_dict.keys()) if not all( @@ -506,10 +515,10 @@ class MTex(_TexSVG): span_ends, span_begins = zip(*spans) string_pieces = [ - self.tex_string[slice(*span)] + self.string[slice(*span)] for span in zip( (0, *span_begins), - (*span_ends, len(self.tex_string)) + (*span_ends, len(self.string)) ) ] repl_strs = [ @@ -521,7 +530,7 @@ class MTex(_TexSVG): def get_submob_tex_strings(self, submob_labels: list[int]) -> list[str]: ordered_tex_spans = [ - self.tex_span_list[label] + self.tex_span_list[label] if label != -1 else self.full_span for label in submob_labels ] ordered_containing_labels = [ @@ -546,7 +555,7 @@ class MTex(_TexSVG): ] string_span_ends.append(ordered_span_ends[-1]) - tex_string = self.tex_string + string = self.string left_indices = self.left_brace_indices right_indices = self.right_brace_indices skipped_indices = sorted(it.chain( @@ -575,7 +584,7 @@ class MTex(_TexSVG): unclosed_left_brace -= 1 result.append("".join([ unclosed_right_brace * "{", - tex_string[span_begin:span_end], + string[span_begin:span_end], unclosed_left_brace * "}" ])) return result @@ -615,7 +624,7 @@ class MTex(_TexSVG): custom_span ) if tex_spans is None: - tex = self.tex_string[slice(*custom_span)] + tex = self.string[slice(*custom_span)] raise ValueError(f"Failed to match mobjects from tex: \"{tex}\"") labels = set(it.chain(*[ @@ -623,7 +632,7 @@ class MTex(_TexSVG): for tex_span in tex_spans ])) return VGroup(*filter( - lambda submob: submob.submob_label in labels, + lambda submob: submob.label in labels, self.submobjects )) @@ -631,7 +640,7 @@ class MTex(_TexSVG): return VGroup(*[ self.get_part_by_custom_span(match_obj.span()) for match_obj in re.finditer( - re.escape(tex.strip()), self.tex_string + re.escape(tex.strip()), self.string ) ]) @@ -665,7 +674,7 @@ class MTex(_TexSVG): return self.indices_of_part(part) def get_tex(self) -> str: - return self.tex_string + return self.string def get_submob_tex(self) -> list[str]: return [