From d6b20a7306af96e32e1a9ed837976fcaa938cc3c Mon Sep 17 00:00:00 2001 From: YishiMichael <50232075+YishiMichael@users.noreply.github.com> Date: Tue, 8 Feb 2022 00:21:53 +0800 Subject: [PATCH] Refactor MTex and implement TransformMatchingMTex (#1725) * Some small refactors * Refactor MTex * Implement TransformMatchingMTex * Some refactors * Some refactors * Some small refactors * Strip strings before matching * Implement get_submob_tex * Use RGB color mode * Some small refactors --- .../animation/transform_matching_parts.py | 109 +++ manimlib/mobject/svg/mtex_mobject.py | 924 ++++++++---------- 2 files changed, 543 insertions(+), 490 deletions(-) diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index 84ee3a4a..ce396404 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -1,16 +1,20 @@ import numpy as np +import itertools as it from manimlib.animation.composition import AnimationGroup from manimlib.animation.fading import FadeTransformPieces from manimlib.animation.fading import FadeInFromPoint from manimlib.animation.fading import FadeOutToPoint +from manimlib.animation.transform import ReplacementTransform from manimlib.animation.transform import Transform from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Group +from manimlib.mobject.svg.mtex_mobject import MTex from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.utils.config_ops import digest_config +from manimlib.utils.iterables import remove_list_redundancies class TransformMatchingParts(AnimationGroup): @@ -139,3 +143,108 @@ class TransformMatchingTex(TransformMatchingParts): @staticmethod def get_mobject_key(mobject): return mobject.get_tex() + + +class TransformMatchingMTex(AnimationGroup): + CONFIG = { + "key_map": dict(), + } + + def __init__(self, source_mobject, target_mobject, **kwargs): + digest_config(self, kwargs) + assert isinstance(source_mobject, MTex) + assert isinstance(target_mobject, MTex) + anims = [] + rest_source_submobs = source_mobject.submobjects.copy() + rest_target_submobs = target_mobject.submobjects.copy() + + def add_anim_from(anim_class, func, source_attr, target_attr=None): + if target_attr is None: + target_attr = source_attr + source_parts = func(source_mobject, source_attr) + target_parts = func(target_mobject, target_attr) + filtered_source_parts = [ + submob_part for submob_part in source_parts + if all([ + submob in rest_source_submobs + for submob in submob_part + ]) + ] + filtered_target_parts = [ + submob_part for submob_part in target_parts + if all([ + submob in rest_target_submobs + for submob in submob_part + ]) + ] + if not (filtered_source_parts and filtered_target_parts): + return + anims.append(anim_class( + VGroup(*filtered_source_parts), + VGroup(*filtered_target_parts), + **kwargs + )) + for submob in it.chain(*filtered_source_parts): + rest_source_submobs.remove(submob) + for submob in it.chain(*filtered_target_parts): + rest_target_submobs.remove(submob) + + def get_submobs_from_keys(mobject, keys): + if not isinstance(keys, tuple): + keys = (keys,) + indices = [] + for key in keys: + if isinstance(key, int): + indices.append(key) + elif isinstance(key, range): + indices.extend(key) + elif isinstance(key, str): + all_parts = mobject.get_parts_by_tex(key) + indices.extend(it.chain(*[ + mobject.indices_of_part(part) for part in all_parts + ])) + else: + raise TypeError(key) + return VGroup(VGroup(*[ + mobject[i] for i in remove_list_redundancies(indices) + ])) + + for source_key, target_key in self.key_map.items(): + add_anim_from( + ReplacementTransform, get_submobs_from_keys, + source_key, target_key + ) + + common_specified_substrings = sorted(list( + set(source_mobject.get_specified_substrings()).intersection( + target_mobject.get_specified_substrings() + ) + ), key=len, reverse=True) + for part_tex_string in common_specified_substrings: + add_anim_from( + FadeTransformPieces, MTex.get_parts_by_tex, part_tex_string + ) + + common_submob_tex_strings = { + source_submob.get_tex() for source_submob in source_mobject + }.intersection({ + target_submob.get_tex() for target_submob in target_mobject + }) + for tex_string in common_submob_tex_strings: + add_anim_from( + FadeTransformPieces, + lambda mobject, attr: VGroup(*[ + VGroup(mob) for mob in mobject + if mob.get_tex() == attr + ]), + tex_string + ) + + anims.append(FadeOutToPoint( + VGroup(*rest_source_submobs), target_mobject.get_center(), **kwargs + )) + anims.append(FadeInFromPoint( + VGroup(*rest_target_submobs), source_mobject.get_center(), **kwargs + )) + + super().__init__(*anims) diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index da09297b..3e9daca5 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -1,6 +1,5 @@ import itertools as it import re -import sys from types import MethodType from manimlib.constants import BLACK @@ -22,10 +21,6 @@ SCALE_FACTOR_PER_FONT_POINT = 0.001 TEX_HASH_TO_MOB_MAP = {} -def _contains(span_0, span_1): - return span_0[0] <= span_1[0] and span_1[1] <= span_0[1] - - def _get_neighbouring_pairs(iterable): return list(adjacent_pairs(iterable))[:-1] @@ -40,277 +35,368 @@ class _TexSVG(SVGMobject): }, } - @staticmethod - def color_to_label(fill_color): - r, g, b = color_to_int_rgb(fill_color) - rg = r * 256 + g - return rg * 256 + b - - def parse_labels(self): - for glyph in self: - glyph.glyph_label = _TexSVG.color_to_label(glyph.fill_color) - return self - - -class _TexSpan(object): - def __init__(self, script_type, label): - # `script_type`: 0 for normal, 1 for subscript, 2 for superscript. - # Only those spans with `script_type == 0` will be colored. - self.script_type = script_type - self.label = label - self.containing_labels = [] - - def __repr__(self): - return "_TexSpan(" + ", ".join([ - attrib_name + "=" + str(getattr(self, attrib_name)) - for attrib_name in ["script_type", "label", "containing_labels"] - ]) + ")" - class _TexParser(object): def __init__(self, tex_string, additional_substrings): self.tex_string = tex_string - self.tex_spans_dict = {} - self.current_label = -1 - self.brace_index_pairs = self.get_brace_index_pairs() - self.existing_color_command_spans = self.get_existing_color_command_spans() - self.has_existing_color_commands = any(self.existing_color_command_spans.values()) - self.specified_substring_spans = [] + self.whitespace_indices = self.get_whitespace_indices() + self.backslash_indices = self.get_backslash_indices() + self.script_indices = self.get_script_indices() + self.brace_indices_dict = self.get_brace_indices_dict() + self.tex_span_list = [] + self.script_span_to_char_dict = {} + self.script_span_to_tex_span_dict = {} + self.neighbouring_script_span_pairs = [] + self.specified_substrings = [] self.add_tex_span((0, len(tex_string))) - self.break_up_by_double_braces() self.break_up_by_scripts() + self.break_up_by_double_braces() self.break_up_by_additional_substrings(additional_substrings) - self.specified_substrings = remove_list_redundancies([ - tex_string[slice(*span_tuple)] - for span_tuple in self.specified_substring_spans + self.tex_span_list.sort(key=lambda t: (t[0], -t[1])) + self.specified_substrings = remove_list_redundancies( + self.specified_substrings + ) + self.containing_labels_dict = self.get_containing_labels_dict() + + def add_tex_span(self, tex_span): + if tex_span not in self.tex_span_list: + self.tex_span_list.append(tex_span) + + def get_whitespace_indices(self): + return [ + match_obj.start() + for match_obj in re.finditer(r"\s", self.tex_string) + ] + + def get_backslash_indices(self): + # Newlines (`\\`) don't count. + return [ + match_obj.end() - 1 + for match_obj in re.finditer(r"\\+", self.tex_string) + if len(match_obj.group()) % 2 == 1 + ] + + def filter_out_escaped_characters(self, indices): + return list(filter( + lambda index: index - 1 not in self.backslash_indices, + indices + )) + + def get_script_indices(self): + return self.filter_out_escaped_characters([ + match_obj.start() + for match_obj in re.finditer(r"[_^]", self.tex_string) ]) - self.check_if_overlap() - self.analyse_containing_labels() - @staticmethod - def label_to_color_tuple(rgb): - # Get a unique color different from black. - rg, b = divmod(rgb, 256) - r, g = divmod(rg, 256) - return r, g, b - - def add_tex_span(self, span_tuple, script_type=0, label=-1): - if span_tuple in self.tex_spans_dict: - return - - if script_type == 0: - # Should be additionally labelled. - self.current_label += 1 - label = self.current_label - - tex_span = _TexSpan(script_type, label) - self.tex_spans_dict[span_tuple] = tex_span - - def get_brace_index_pairs(self): - result = [] - left_brace_indices = [] - for match_obj in re.finditer(r"(\\*)(\{|\})", self.tex_string): - # Braces following even numbers of backslashes are counted. - if len(match_obj.group(1)) % 2 == 1: - continue - if match_obj.group(2) == "{": - left_brace_index = match_obj.span(2)[0] - left_brace_indices.append(left_brace_index) - else: - left_brace_index = left_brace_indices.pop() - right_brace_index = match_obj.span(2)[1] - result.append((left_brace_index, right_brace_index)) - if left_brace_indices: - self.raise_tex_parsing_error("unmatched braces") - return result - - def get_existing_color_command_spans(self): + def get_brace_indices_dict(self): tex_string = self.tex_string - color_related_commands_dict = _TexParser.get_color_related_commands_dict() - commands = color_related_commands_dict.keys() - result = { - command_name: [] - for command_name in commands - } - brace_index_pairs = self.brace_index_pairs - pattern = "|".join([ - re.escape(command_name) - for command_name in commands + indices = self.filter_out_escaped_characters([ + match_obj.start() + for match_obj in re.finditer(r"[{}]", tex_string) ]) - for match_obj in re.finditer(pattern, tex_string): - span_tuple = match_obj.span() - command_begin_index = span_tuple[0] - command_name = match_obj.group() - n_braces = color_related_commands_dict[command_name] - for _ in range(n_braces): - span_tuple = min(filter( - lambda t: t[0] >= span_tuple[1], - brace_index_pairs - )) - result[command_name].append( - (command_begin_index, span_tuple[1]) - ) + result = {} + left_brace_indices_stack = [] + for index in indices: + if tex_string[index] == "{": + left_brace_indices_stack.append(index) + else: + left_brace_index = left_brace_indices_stack.pop() + result[left_brace_index] = index return result - def break_up_by_double_braces(self): - # Match paired double braces (`{{...}}`). - skip_pair = False - for prev_span_tuple, span_tuple in _get_neighbouring_pairs( - self.brace_index_pairs - ): - if skip_pair: - skip_pair = False - continue - if all([ - span_tuple[0] == prev_span_tuple[0] - 1, - span_tuple[1] == prev_span_tuple[1] + 1 - ]): - self.add_tex_span(span_tuple) - self.specified_substring_spans.append(span_tuple) - skip_pair = True - def break_up_by_scripts(self): # Match subscripts & superscripts. tex_string = self.tex_string - brace_indices_dict = dict(self.brace_index_pairs) - for match_obj in re.finditer(r"((?= span_end: continue - span_tuple = (span_begin, span_end) - self.add_tex_span(span_tuple) - self.specified_substring_spans.append(span_tuple) + self.add_tex_span((span_begin, span_end)) - def check_if_overlap(self): - span_tuples = sorted( - self.tex_spans_dict.keys(), - key=lambda t: (t[0], -t[1]) - ) - overlapping_span_pairs = [] - for i, span_0 in enumerate(span_tuples): - for span_1 in span_tuples[i + 1 :]: + def get_containing_labels_dict(self): + tex_span_list = self.tex_span_list + result = { + tex_span: [] + for tex_span in tex_span_list + } + overlapping_tex_span_pairs = [] + for index_0, span_0 in enumerate(tex_span_list): + for index_1, span_1 in enumerate(tex_span_list[index_0:]): if span_0[1] <= span_1[0]: continue if span_0[1] < span_1[1]: - overlapping_span_pairs.append((span_0, span_1)) - if overlapping_span_pairs: + overlapping_tex_span_pairs.append((span_0, span_1)) + result[span_0].append(index_0 + index_1) + if overlapping_tex_span_pairs: tex_string = self.tex_string - log.error("Overlapping substring pairs occur in MTex:") - for span_tuple_pair in overlapping_span_pairs: + log.error("Partially overlapping substrings detected:") + for tex_span_pair in overlapping_tex_span_pairs: log.error(", ".join( - f"\"{tex_string[slice(*span_tuple)]}\"" - for span_tuple in span_tuple_pair + f"\"{tex_string[slice(*tex_span)]}\"" + for tex_span in tex_span_pair )) raise ValueError - - def analyse_containing_labels(self): - for span_0, tex_span_0 in self.tex_spans_dict.items(): - if tex_span_0.script_type != 0: - continue - for span_1, tex_span_1 in self.tex_spans_dict.items(): - if _contains(span_1, span_0): - tex_span_1.containing_labels.append(tex_span_0.label) + return result def get_labelled_tex_string(self): - tex_string = self.tex_string - if self.current_label == 0 and not self.has_existing_color_commands: - return tex_string - - # Remove the span of extire tex string. - indices_with_labels = sorted([ - (span_tuple[i], i, span_tuple[1 - i], tex_span.label) - for span_tuple, tex_span in self.tex_spans_dict.items() - if tex_span.script_type == 0 - for i in range(2) - ], key=lambda t: (t[0], -t[1], -t[2]))[1:] - - # Prevent from "\\color[RGB]" being replaced. - # Hopefully tex string doesn't contain such a substring... - color_command_placeholder = "{{\\iffalse \\fi}}" - result = tex_string[: indices_with_labels[0][0]] - for index_with_label, next_index_with_label in _get_neighbouring_pairs( - indices_with_labels - ): - index, flag, _, label = index_with_label - next_index, *_ = next_index_with_label - # Adding one more pair of braces will help maintain the glyghs of tex file... - if flag == 0: - color_tuple = _TexParser.label_to_color_tuple(label) - result += "".join([ - "{{", - color_command_placeholder, - "{", - ",".join(map(str, color_tuple)), - "}" - ]) - else: - result += "}}" - result += tex_string[index : next_index] - - color_related_commands_dict = _TexParser.get_color_related_commands_dict() - for command_name, command_spans in self.existing_color_command_spans.items(): - if not command_spans: - continue - n_braces = color_related_commands_dict[command_name] - command_to_replace = command_name + n_braces * "{black}" - commands = { - tex_string[slice(*span_tuple)] - for span_tuple in command_spans - } - for command in commands: - result = result.replace(command, command_to_replace) - - return result.replace(color_command_placeholder, "\\color[RGB]") - - def raise_tex_parsing_error(self, message): - raise ValueError(f"Failed to parse tex ({message}): \"{self.tex_string}\"") + indices, _, flags, labels = zip(*sorted([ + (*tex_span[::(1, -1)[flag]], flag, label) + for label, tex_span in enumerate(self.tex_span_list) + for flag in range(2) + ], key=lambda t: (t[0], -t[2], -t[1]))) + command_pieces = [ + ("{{" + self.get_color_command(label), "}}")[flag] + for flag, label in zip(flags, labels) + ][1:-1] + command_pieces.insert(0, "") + string_pieces = [ + self.tex_string[slice(*tex_span)] + for tex_span in _get_neighbouring_pairs(indices) + ] + return "".join(it.chain(*zip(command_pieces, string_pieces))) @staticmethod - def get_color_related_commands_dict(): - # Only list a few commands that are commonly used. - return { - "\\color": 1, - "\\textcolor": 1, - "\\pagecolor": 1, - "\\colorbox": 1, - "\\fcolorbox": 2, - } + def get_color_command(label): + rg, b = divmod(label, 256) + r, g = divmod(rg, 256) + return "".join([ + "\\color[RGB]", + "{", + ",".join(map(str, (r, g, b))), + "}" + ]) + + def get_sorted_submob_indices(self, submob_labels): + def script_span_to_submob_range(script_span): + tex_span = self.script_span_to_tex_span_dict[script_span] + submob_indices = [ + index for index, label in enumerate(submob_labels) + if label in self.containing_labels_dict[tex_span] + ] + return range(submob_indices[0], submob_indices[-1] + 1) + + filtered_script_span_pairs = filter( + lambda script_span_pair: all([ + self.script_span_to_char_dict[script_span] == character + for script_span, character in zip(script_span_pair, "_^") + ]), + self.neighbouring_script_span_pairs + ) + switch_range_pairs = sorted([ + tuple([ + script_span_to_submob_range(script_span) + for script_span in script_span_pair + ]) + for script_span_pair in filtered_script_span_pairs + ], key=lambda t: (t[0].stop, -t[0].start)) + result = list(range(len(submob_labels))) + for range_0, range_1 in switch_range_pairs: + result = [ + *result[:range_1.start], + *result[range_0.start:range_0.stop], + *result[range_1.stop:range_0.start], + *result[range_1.start:range_1.stop], + *result[range_0.stop:] + ] + return result + + def get_submob_tex_strings(self, submob_labels): + ordered_tex_spans = [ + self.tex_span_list[label] for label in submob_labels + ] + ordered_containing_labels = [ + self.containing_labels_dict[tex_span] + for tex_span in ordered_tex_spans + ] + ordered_span_begins, ordered_span_ends = zip(*ordered_tex_spans) + string_span_begins = [ + prev_end if prev_label in containing_labels else curr_begin + for prev_end, prev_label, containing_labels, curr_begin in zip( + ordered_span_ends[:-1], submob_labels[:-1], + ordered_containing_labels[1:], ordered_span_begins[1:] + ) + ] + string_span_begins.insert(0, ordered_span_begins[0]) + string_span_ends = [ + next_begin if next_label in containing_labels else curr_end + for next_begin, next_label, containing_labels, curr_end in zip( + ordered_span_begins[1:], submob_labels[1:], + ordered_containing_labels[:-1], ordered_span_ends[:-1] + ) + ] + string_span_ends.append(ordered_span_ends[-1]) + + tex_string = self.tex_string + left_brace_indices = sorted(self.brace_indices_dict.keys()) + right_brace_indices = sorted(self.brace_indices_dict.values()) + ignored_indices = sorted(it.chain( + self.whitespace_indices, + left_brace_indices, + right_brace_indices, + self.script_indices + )) + result = [] + for span_begin, span_end in zip(string_span_begins, string_span_ends): + while span_begin in ignored_indices: + span_begin += 1 + if span_begin >= span_end: + result.append("") + continue + while span_end - 1 in ignored_indices: + span_end -= 1 + unclosed_left_brace = 0 + unclosed_right_brace = 0 + for index in range(span_begin, span_end): + if index in left_brace_indices: + unclosed_left_brace += 1 + elif index in right_brace_indices: + if unclosed_left_brace == 0: + unclosed_right_brace += 1 + else: + unclosed_left_brace -= 1 + result.append("".join([ + unclosed_right_brace * "{", + tex_string[span_begin:span_end], + unclosed_left_brace * "}" + ])) + return result + + def find_span_components_of_custom_span(self, custom_span): + skipped_indices = sorted(it.chain( + self.whitespace_indices, + self.script_indices + )) + tex_span_choices = sorted(filter( + lambda tex_span: all([ + tex_span[0] >= custom_span[0], + tex_span[1] <= custom_span[1] + ]), + self.tex_span_list + )) + # Choose spans that reach the farthest. + tex_span_choices_dict = dict(tex_span_choices) + + span_begin, span_end = custom_span + result = [] + while span_begin != span_end: + if span_begin not in tex_span_choices_dict.keys(): + if span_begin in skipped_indices: + span_begin += 1 + continue + return None + next_begin = tex_span_choices_dict[span_begin] + result.append((span_begin, next_begin)) + span_begin = next_begin + return result + + def get_containing_labels_by_tex_spans(self, tex_spans): + return remove_list_redundancies(list(it.chain(*[ + self.containing_labels_dict[tex_span] + for tex_span in tex_spans + ]))) + + def get_specified_substrings(self): + return self.specified_substrings + + def get_isolated_substrings(self): + return remove_list_redundancies([ + self.tex_string[slice(*tex_span)] + for tex_span in self.tex_span_list + ]) class MTex(VMobject): @@ -322,7 +408,7 @@ class MTex(VMobject): "tex_environment": "align*", "isolate": [], "tex_to_color_map": {}, - "use_plain_tex_file": False, + "use_plain_tex": False, } def __init__(self, tex_string, **kwargs): @@ -333,264 +419,138 @@ class MTex(VMobject): tex_string = "\\quad" self.tex_string = tex_string - self.generate_mobject() - + self.__parser = _TexParser( + self.tex_string, + [*self.tex_to_color_map.keys(), *self.isolate] + ) + mob = self.generate_mobject() + self.add(*mob.copy()) + self.init_colors() self.set_color_by_tex_to_color_map(self.tex_to_color_map) self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size) - def get_additional_substrings_to_break_up(self): - result = remove_list_redundancies([ - *self.tex_to_color_map.keys(), *self.isolate - ]) - if "" in result: - result.remove("") - return result - - def get_parser(self): - return _TexParser(self.tex_string, self.get_additional_substrings_to_break_up()) + @staticmethod + def color_to_label(color): + r, g, b = color_to_int_rgb(color) + rg = r * 256 + g + return rg * 256 + b def generate_mobject(self): - tex_string = self.tex_string - tex_parser = self.get_parser() - self.tex_spans_dict = tex_parser.tex_spans_dict - self.specified_substrings = tex_parser.specified_substrings - fill_color = self.get_fill_color() + labelled_tex_string = self.__parser.get_labelled_tex_string() + labelled_tex_content = self.get_tex_file_content(labelled_tex_string) + hash_val = hash((labelled_tex_content, self.use_plain_tex)) - # Cannot simultaneously be false, so at least one file is generated. - require_labelled_tex_file = tex_parser.current_label != 0 - require_plain_tex_file = any([ - self.use_plain_tex_file, - tex_parser.has_existing_color_commands, - tex_parser.current_label == 0 - ]) - - if require_labelled_tex_file: - labelled_full_tex = self.get_tex_file_body(tex_parser.get_labelled_tex_string()) - labelled_hash_val = hash(labelled_full_tex) - if labelled_hash_val in TEX_HASH_TO_MOB_MAP: - self.add(*TEX_HASH_TO_MOB_MAP[labelled_hash_val].copy()) - else: - with display_during_execution(f"Writing \"{tex_string}\""): - labelled_svg_glyphs = MTex.get_svg_glyphs(labelled_full_tex) - labelled_svg_glyphs.parse_labels() - self.add(*labelled_svg_glyphs) - self.build_submobjects() - TEX_HASH_TO_MOB_MAP[labelled_hash_val] = self.copy() - if not require_plain_tex_file: - self.set_fill(color=fill_color) - return self - - # require_plain_tex_file == True - self.set_submobjects([]) - full_tex = self.get_tex_file_body(tex_string, fill_color=fill_color) - hash_val = hash(full_tex) if hash_val in TEX_HASH_TO_MOB_MAP: - self.add(*TEX_HASH_TO_MOB_MAP[hash_val].copy()) - else: - with display_during_execution(f"Writing \"{tex_string}\""): - svg_glyphs = MTex.get_svg_glyphs(full_tex) - if require_labelled_tex_file: - labelled_svg_mob = TEX_HASH_TO_MOB_MAP[labelled_hash_val] - for glyph, labelled_glyph in zip(svg_glyphs, it.chain(*labelled_svg_mob)): - glyph.glyph_label = labelled_glyph.glyph_label - else: - for glyph in svg_glyphs: - glyph.glyph_label = 0 - self.add(*svg_glyphs) - self.build_submobjects() - TEX_HASH_TO_MOB_MAP[hash_val] = self.copy() - return self + return TEX_HASH_TO_MOB_MAP[hash_val] - def get_tex_file_body(self, new_tex, fill_color=None): + if not self.use_plain_tex: + with display_during_execution(f"Writing \"{self.tex_string}\""): + labelled_svg_glyphs = self.tex_content_to_glyphs( + labelled_tex_content + ) + glyph_labels = [ + self.color_to_label(labelled_glyph.fill_color) + for labelled_glyph in labelled_svg_glyphs + ] + mob = self.build_mobject(labelled_svg_glyphs, glyph_labels) + TEX_HASH_TO_MOB_MAP[hash_val] = mob + return mob + + with display_during_execution(f"Writing \"{self.tex_string}\""): + labelled_svg_glyphs = self.tex_content_to_glyphs( + labelled_tex_content + ) + tex_content = self.get_tex_file_content(self.tex_string) + svg_glyphs = self.tex_content_to_glyphs(tex_content) + glyph_labels = [ + self.color_to_label(labelled_glyph.fill_color) + for labelled_glyph in labelled_svg_glyphs + ] + mob = self.build_mobject(svg_glyphs, glyph_labels) + TEX_HASH_TO_MOB_MAP[hash_val] = mob + return mob + + def get_tex_file_content(self, tex_string): if self.tex_environment: - new_tex = "\n".join([ + tex_string = "\n".join([ f"\\begin{{{self.tex_environment}}}", - new_tex, + tex_string, f"\\end{{{self.tex_environment}}}" ]) if self.alignment: - new_tex = "\n".join([self.alignment, new_tex]) - if fill_color: - int_rgb = color_to_int_rgb(fill_color) - color_command = "".join([ - "\\color[RGB]", - "{", - ",".join(map(str, int_rgb)), - "}" - ]) - new_tex = "\n".join( - [color_command, new_tex] - ) - - tex_config = get_tex_config() - return tex_config["tex_body"].replace( - tex_config["text_to_replace"], - new_tex - ) + tex_string = "\n".join([self.alignment, tex_string]) + return tex_string @staticmethod - def get_svg_glyphs(full_tex): + def tex_content_to_glyphs(tex_content): + tex_config = get_tex_config() + full_tex = tex_config["tex_body"].replace( + tex_config["text_to_replace"], + tex_content + ) filename = tex_to_svg_file(full_tex) return _TexSVG(filename) - def build_submobjects(self): - if not self.submobjects: - return - self.init_colors() - for glyph in self.submobjects: - glyph.set_fill(glyph.fill_color) - self.group_submobjects() - self.sort_scripts_in_tex_order() - self.assign_submob_tex_strings() + def build_mobject(self, svg_glyphs, glyph_labels): + if not svg_glyphs: + return VGroup() - def group_submobjects(self): # Simply pack together adjacent mobjects with the same label. - new_submobjects = [] - def append_new_submobject(glyphs): - if glyphs: - submobject = VGroup(*glyphs) - submobject.submob_label = glyphs[0].glyph_label - new_submobjects.append(submobject) - + submobjects = [] + submob_labels = [] new_glyphs = [] - current_glyph_label = 0 - for glyph in self.submobjects: - if glyph.glyph_label == current_glyph_label: + current_glyph_label = glyph_labels[0] + for glyph, label in zip(svg_glyphs, glyph_labels): + if label == current_glyph_label: new_glyphs.append(glyph) else: - append_new_submobject(new_glyphs) + submobject = VGroup(*new_glyphs) + submob_labels.append(current_glyph_label) + submobjects.append(submobject) new_glyphs = [glyph] - current_glyph_label = glyph.glyph_label - append_new_submobject(new_glyphs) - self.set_submobjects(new_submobjects) + current_glyph_label = label + submobject = VGroup(*new_glyphs) + submob_labels.append(current_glyph_label) + submobjects.append(submobject) - def sort_scripts_in_tex_order(self): - # LaTeX always puts superscripts before subscripts. - # This function sorts the submobjects of scripts in the order of tex given. - tex_spans_dict = self.tex_spans_dict - index_and_span_list = sorted([ - (index, span_tuple) - for span_tuple, tex_span in tex_spans_dict.items() - if tex_span.script_type != 0 - for index in span_tuple - ]) + indices = self.__parser.get_sorted_submob_indices(submob_labels) + rearranged_submobjects = [submobjects[index] for index in indices] + rearranged_labels = [submob_labels[index] for index in indices] - switch_range_pairs = [] - for index_and_span_0, index_and_span_1 in _get_neighbouring_pairs( - index_and_span_list + submob_tex_strings = self.__parser.get_submob_tex_strings( + rearranged_labels + ) + for submob, label, submob_tex in zip( + rearranged_submobjects, rearranged_labels, submob_tex_strings ): - index_0, span_tuple_0 = index_and_span_0 - index_1, span_tuple_1 = index_and_span_1 - if index_0 != index_1: - continue - if not all([ - tex_spans_dict[span_tuple_0].script_type == 1, - tex_spans_dict[span_tuple_1].script_type == 2 - ]): - continue - submob_range_0 = self.range_of_part( - self.get_part_by_span_tuples([span_tuple_0]) - ) - submob_range_1 = self.range_of_part( - self.get_part_by_span_tuples([span_tuple_1]) - ) - switch_range_pairs.append((submob_range_0, submob_range_1)) - - switch_range_pairs.sort(key=lambda t: (t[0].stop, -t[0].start)) - indices = list(range(len(self.submobjects))) - for submob_range_0, submob_range_1 in switch_range_pairs: - indices = [ - *indices[: submob_range_1.start], - *indices[submob_range_0.start : submob_range_0.stop], - *indices[submob_range_1.stop : submob_range_0.start], - *indices[submob_range_1.start : submob_range_1.stop], - *indices[submob_range_0.stop :] - ] - - submobs = self.submobjects - self.set_submobjects([submobs[i] for i in indices]) - - def assign_submob_tex_strings(self): - # Not sure whether this is the best practice... - # This temporarily supports `TransformMatchingTex`. - tex_string = self.tex_string - tex_spans_dict = self.tex_spans_dict - # Use tex strings including "_", "^". - label_dict = {} - for span_tuple, tex_span in tex_spans_dict.items(): - if tex_span.script_type != 0: - label_dict[tex_span.label] = span_tuple - else: - if tex_span.label not in label_dict: - label_dict[tex_span.label] = span_tuple - - curr_labels = [submob.submob_label for submob in self.submobjects] - prev_labels = [curr_labels[-1], *curr_labels[:-1]] - next_labels = [*curr_labels[1:], curr_labels[0]] - tex_string_spans = [] - for curr_label, prev_label, next_label in zip( - curr_labels, prev_labels, next_labels - ): - curr_span_tuple = label_dict[curr_label] - prev_span_tuple = label_dict[prev_label] - next_span_tuple = label_dict[next_label] - containing_labels = tex_spans_dict[curr_span_tuple].containing_labels - tex_string_spans.append([ - prev_span_tuple[1] if prev_label in containing_labels else curr_span_tuple[0], - next_span_tuple[0] if next_label in containing_labels else curr_span_tuple[1] - ]) - tex_string_spans[0][0] = label_dict[curr_labels[0]][0] - tex_string_spans[-1][1] = label_dict[curr_labels[-1]][1] - for submob, tex_string_span in zip(self.submobjects, tex_string_spans): - submob.tex_string = tex_string[slice(*tex_string_span)] + submob.submob_label = label + submob.tex_string = submob_tex # Support `get_tex()` method here. submob.get_tex = MethodType(lambda inst: inst.tex_string, submob) + return VGroup(*rearranged_submobjects) - def get_part_by_span_tuples(self, span_tuples): - tex_spans_dict = self.tex_spans_dict - labels = set(it.chain(*[ - tex_spans_dict[span_tuple].containing_labels - for span_tuple in span_tuples - ])) + def get_part_by_tex_spans(self, tex_spans): + labels = self.__parser.get_containing_labels_by_tex_spans(tex_spans) return VGroup(*filter( lambda submob: submob.submob_label in labels, self.submobjects )) - def find_span_components_of_custom_span(self, custom_span_tuple): - tex_string = self.tex_string - span_choices = sorted(filter( - lambda t: _contains(custom_span_tuple, t), - self.tex_spans_dict.keys() - )) - # Filter out spans that reach the farthest. - span_choices_dict = dict(span_choices) - - span_begin, span_end = custom_span_tuple - result = [] - while span_begin != span_end: - if span_begin not in span_choices_dict: - if tex_string[span_begin].strip(): - return None - # Whitespaces may occur between spans. - span_begin += 1 - continue - next_begin = span_choices_dict[span_begin] - result.append((span_begin, next_begin)) - span_begin = next_begin - return result - - def get_part_by_custom_span_tuple(self, custom_span_tuple): - span_tuples = self.find_span_components_of_custom_span(custom_span_tuple) - if span_tuples is None: - tex = self.tex_string[slice(*custom_span_tuple)] - raise ValueError(f"Failed to get span of tex: \"{tex}\"") - return self.get_part_by_span_tuples(span_tuples) + def get_part_by_custom_span(self, custom_span): + tex_spans = self.__parser.find_span_components_of_custom_span( + custom_span + ) + if tex_spans is None: + tex = self.tex_string[slice(*custom_span)] + raise ValueError(f"Failed to match mobjects from tex: \"{tex}\"") + return self.get_part_by_tex_spans(tex_spans) def get_parts_by_tex(self, tex): return VGroup(*[ - self.get_part_by_custom_span_tuple(match_obj.span()) - for match_obj in re.finditer(re.escape(tex), self.tex_string) + self.get_part_by_custom_span(match_obj.span()) + for match_obj in re.finditer( + re.escape(tex.strip()), self.tex_string + ) ]) def get_part_by_tex(self, tex, index=0): @@ -602,16 +562,13 @@ class MTex(VMobject): return self def set_color_by_tex_to_color_map(self, tex_to_color_map): - for tex, color in list(tex_to_color_map.items()): - try: - self.set_color_by_tex(tex, color) - except: - pass + for tex, color in tex_to_color_map.items(): + self.set_color_by_tex(tex, color) return self def indices_of_part(self, part): indices = [ - i for i, submob in enumerate(self.submobjects) + index for index, submob in enumerate(self.submobjects) if submob in part ] if not indices: @@ -622,33 +579,20 @@ class MTex(VMobject): part = self.get_part_by_tex(tex, index=index) return self.indices_of_part(part) - def range_of_part(self, part): - indices = self.indices_of_part(part) - return range(indices[0], indices[-1] + 1) - - def range_of_part_by_tex(self, tex, index=0): - part = self.get_part_by_tex(tex, index=index) - return self.range_of_part(part) - - def index_of_part(self, part): - return self.indices_of_part(part)[0] - - def index_of_part_by_tex(self, tex, index=0): - part = self.get_part_by_tex(tex, index=index) - return self.index_of_part(part) - def get_tex(self): return self.tex_string - def get_all_isolated_substrings(self): - tex_string = self.tex_string - return remove_list_redundancies([ - tex_string[slice(*span_tuple)] - for span_tuple in self.tex_spans_dict.keys() - ]) + def get_submob_tex(self): + return [ + submob.get_tex() + for submob in self.submobjects + ] def get_specified_substrings(self): - return self.specified_substrings + return self.__parser.get_specified_substrings() + + def get_isolated_substrings(self): + return self.__parser.get_isolated_substrings() class MTexText(MTex):