diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index b7854dd3..fa30c2d3 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -3,9 +3,11 @@ import re import sys from types import MethodType +from manimlib.constants import BLACK from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.mobject.types.vectorized_mobject import VGroup +from manimlib.utils.color import color_to_int_rgb from manimlib.utils.iterables import adjacent_pairs from manimlib.utils.iterables import remove_list_redundancies from manimlib.utils.tex_file_writing import tex_to_svg_file @@ -28,8 +30,9 @@ def _get_neighbouring_pairs(iterable): return list(adjacent_pairs(iterable))[:-1] -class _PlainTex(SVGMobject): +class _TexSVG(SVGMobject): CONFIG = { + "color": BLACK, "height": None, "path_string_config": { "should_subdivide_sharp_curves": True, @@ -37,35 +40,16 @@ class _PlainTex(SVGMobject): }, } - -class _LabelledTex(_PlainTex): - def __init__(self, file_name=None, **kwargs): - super().__init__(file_name, **kwargs) - for glyph in self: - glyph.glyph_label = _LabelledTex.color_str_to_label(glyph.fill_color) - @staticmethod - def color_str_to_label(color_str): - if len(color_str) == 4: - # "#RGB" => "#RRGGBB" - color_str = "#" + "".join([c * 2 for c in color_str[1:]]) + def color_to_label(fill_color): + r, g, b = color_to_int_rgb(fill_color) + rg = r * 256 + g + return rg * 256 + b - return int(color_str[1:], 16) - 1 - - def get_mobjects_from(self, element, style): - result = super().get_mobjects_from(element, style) - for mob in result: - if not hasattr(mob, "glyph_label"): - mob.glyph_label = -1 - try: - color_str = element.getAttribute("fill") - if color_str: - glyph_label = _LabelledTex.color_str_to_label(color_str) - for mob in result: - mob.glyph_label = glyph_label - except: - pass - return result + def parse_labels(self): + for glyph in self: + glyph.glyph_label = _TexSVG.color_to_label(glyph.fill_color) + return self class _TexSpan(object): @@ -87,21 +71,25 @@ class _TexParser(object): def __init__(self, tex_string, additional_substrings): self.tex_string = tex_string self.tex_spans_dict = {} - self.specified_substrings = [] - self.current_label = 0 + self.current_label = -1 self.brace_index_pairs = self.get_brace_index_pairs() + self.existing_color_command_spans = self.get_existing_color_command_spans() + self.has_existing_color_commands = any(self.existing_color_command_spans.values()) + self.specified_substring_spans = [] self.add_tex_span((0, len(tex_string))) self.break_up_by_double_braces() self.break_up_by_scripts() self.break_up_by_additional_substrings(additional_substrings) + self.specified_substrings = remove_list_redundancies([ + tex_string[slice(*span_tuple)] + for span_tuple in self.specified_substring_spans + ]) self.check_if_overlap() self.analyse_containing_labels() - self.specified_substrings = remove_list_redundancies(self.specified_substrings) @staticmethod def label_to_color_tuple(rgb): - # Get a unique color different from black, - # or the svg file will not include the color information. + # Get a unique color different from black. rg, b = divmod(rgb, 256) r, g = divmod(rg, 256) return r, g, b @@ -112,16 +100,12 @@ class _TexParser(object): if script_type == 0: # Should be additionally labelled. - label = self.current_label self.current_label += 1 + label = self.current_label tex_span = _TexSpan(script_type, label) self.tex_spans_dict[span_tuple] = tex_span - def add_specified_substring(self, span_tuple): - substring = self.tex_string[slice(*span_tuple)] - self.specified_substrings.append(substring) - def get_brace_index_pairs(self): result = [] left_brace_indices = [] @@ -140,6 +124,34 @@ class _TexParser(object): self.raise_tex_parsing_error("unmatched braces") return result + def get_existing_color_command_spans(self): + tex_string = self.tex_string + color_related_commands_dict = _TexParser.get_color_related_commands_dict() + commands = color_related_commands_dict.keys() + result = { + command_name: [] + for command_name in commands + } + brace_index_pairs = self.brace_index_pairs + pattern = "|".join([ + re.escape(command_name) + for command_name in commands + ]) + for match_obj in re.finditer(pattern, tex_string): + span_tuple = match_obj.span() + command_begin_index = span_tuple[0] + command_name = match_obj.group() + n_braces = color_related_commands_dict[command_name] + for _ in range(n_braces): + span_tuple = min(filter( + lambda t: t[0] >= span_tuple[1], + brace_index_pairs + )) + result[command_name].append( + (command_begin_index, span_tuple[1]) + ) + return result + def break_up_by_double_braces(self): # Match paired double braces (`{{...}}`). skip_pair = False @@ -154,7 +166,7 @@ class _TexParser(object): span_tuple[1] == prev_span_tuple[1] + 1 ]): self.add_tex_span(span_tuple) - self.add_specified_substring(span_tuple) + self.specified_substring_spans.append(span_tuple) skip_pair = True def break_up_by_scripts(self): @@ -203,7 +215,7 @@ class _TexParser(object): continue span_tuple = (span_begin, span_end) self.add_tex_span(span_tuple) - self.add_specified_substring(span_tuple) + self.specified_substring_spans.append(span_tuple) def check_if_overlap(self): span_tuples = sorted( @@ -225,7 +237,7 @@ class _TexParser(object): f"\"{tex_string[slice(*span_tuple)]}\"" for span_tuple in span_tuple_pair )) - sys.exit(2) + raise ValueError def analyse_containing_labels(self): for span_0, tex_span_0 in self.tex_spans_dict.items(): @@ -235,9 +247,9 @@ class _TexParser(object): if _contains(span_1, span_0): tex_span_1.containing_labels.append(tex_span_0.label) - def get_labelled_expression(self): + def get_labelled_tex_string(self): tex_string = self.tex_string - if not self.tex_spans_dict: + if self.current_label == 0 and not self.has_existing_color_commands: return tex_string # Remove the span of extire tex string. @@ -248,6 +260,9 @@ class _TexParser(object): for i in range(2) ], key=lambda t: (t[0], -t[1], -t[2]))[1:] + # Prevent from "\\color[RGB]" being replaced. + # Hopefully tex string doesn't contain such a substring... + color_command_placeholder = "{{\\iffalse \\fi}}" result = tex_string[: indices_with_labels[0][0]] for index_with_label, next_index_with_label in _get_neighbouring_pairs( indices_with_labels @@ -259,7 +274,7 @@ class _TexParser(object): color_tuple = _TexParser.label_to_color_tuple(label) result += "".join([ "{{", - "\\color[RGB]", + color_command_placeholder, "{", ",".join(map(str, color_tuple)), "}" @@ -267,11 +282,36 @@ class _TexParser(object): else: result += "}}" result += tex_string[index : next_index] - return result + + color_related_commands_dict = _TexParser.get_color_related_commands_dict() + for command_name, command_spans in self.existing_color_command_spans.items(): + if not command_spans: + continue + n_braces = color_related_commands_dict[command_name] + command_to_replace = command_name + n_braces * "{black}" + commands = { + tex_string[slice(*span_tuple)] + for span_tuple in command_spans + } + for command in commands: + result = result.replace(command, command_to_replace) + + return result.replace(color_command_placeholder, "\\color[RGB]") def raise_tex_parsing_error(self, message): raise ValueError(f"Failed to parse tex ({message}): \"{self.tex_string}\"") + @staticmethod + def get_color_related_commands_dict(): + # Only list a few commands that are commonly used. + return { + "\\color": 1, + "\\textcolor": 1, + "\\pagecolor": 1, + "\\colorbox": 1, + "\\fcolorbox": 2, + } + class MTex(VMobject): CONFIG = { @@ -282,7 +322,7 @@ class MTex(VMobject): "tex_environment": "align*", "isolate": [], "tex_to_color_map": {}, - "generate_plain_tex_file": False, + "use_plain_tex_file": False, } def __init__(self, tex_string, **kwargs): @@ -295,7 +335,6 @@ class MTex(VMobject): self.generate_mobject() - self.init_colors() self.set_color_by_tex_to_color_map(self.tex_to_color_map) self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size) @@ -315,41 +354,55 @@ class MTex(VMobject): tex_parser = self.get_parser() self.tex_spans_dict = tex_parser.tex_spans_dict self.specified_substrings = tex_parser.specified_substrings + fill_color = self.get_fill_color() - plain_full_tex = self.get_tex_file_body(tex_string) - plain_hash_val = hash(plain_full_tex) - if plain_hash_val in TEX_HASH_TO_MOB_MAP: - self.add(*TEX_HASH_TO_MOB_MAP[plain_hash_val].copy()) - return self + # Cannot simultaneously be false, so at least one file is generated. + require_labelled_tex_file = tex_parser.current_label != 0 + require_plain_tex_file = any([ + self.use_plain_tex_file, + tex_parser.has_existing_color_commands, + tex_parser.current_label == 0 + ]) - labelled_expression = tex_parser.get_labelled_expression() - full_tex = self.get_tex_file_body(labelled_expression) + if require_labelled_tex_file: + labelled_full_tex = self.get_tex_file_body(tex_parser.get_labelled_tex_string()) + labelled_hash_val = hash(labelled_full_tex) + if labelled_hash_val in TEX_HASH_TO_MOB_MAP: + self.add(*TEX_HASH_TO_MOB_MAP[labelled_hash_val].copy()) + else: + with display_during_execution(f"Writing \"{tex_string}\""): + filename = tex_to_svg_file(labelled_full_tex) + labelled_svg_glyphs = _TexSVG(filename).parse_labels() + self.add(*labelled_svg_glyphs) + self.build_submobjects() + TEX_HASH_TO_MOB_MAP[labelled_hash_val] = self.copy() + if not require_plain_tex_file: + self.set_fill(color=fill_color) + return self + + # require_plain_tex_file == True + self.set_submobjects([]) + full_tex = self.get_tex_file_body(tex_string, fill_color=fill_color) hash_val = hash(full_tex) - if hash_val in TEX_HASH_TO_MOB_MAP and not self.generate_plain_tex_file: + if hash_val in TEX_HASH_TO_MOB_MAP: self.add(*TEX_HASH_TO_MOB_MAP[hash_val].copy()) - return self - - with display_during_execution(f"Writing \"{tex_string}\""): - filename = tex_to_svg_file(full_tex) - svg_mob = _LabelledTex(filename) - self.add(*svg_mob.copy()) - self.build_submobjects() - TEX_HASH_TO_MOB_MAP[hash_val] = self - if not self.generate_plain_tex_file: - return self - - with display_during_execution(f"Writing \"{tex_string}\""): - filename = tex_to_svg_file(plain_full_tex) - plain_svg_mob = _PlainTex(filename) - svg_mob = TEX_HASH_TO_MOB_MAP[hash_val] - for plain_submob, submob in zip(plain_svg_mob, svg_mob): - plain_submob.glyph_label = submob.glyph_label - self.add(*plain_svg_mob.copy()) - self.build_submobjects() - TEX_HASH_TO_MOB_MAP[plain_hash_val] = self + else: + with display_during_execution(f"Writing \"{tex_string}\""): + filename = tex_to_svg_file(full_tex) + svg_glyphs = _TexSVG(filename) + if require_labelled_tex_file: + labelled_svg_mob = TEX_HASH_TO_MOB_MAP[labelled_hash_val] + for glyph, labelled_glyph in zip(svg_glyphs, it.chain(*labelled_svg_mob)): + glyph.glyph_label = labelled_glyph.glyph_label + else: + for glyph in svg_glyphs: + glyph.glyph_label = 0 + self.add(*svg_glyphs) + self.build_submobjects() + TEX_HASH_TO_MOB_MAP[hash_val] = self.copy() return self - def get_tex_file_body(self, new_tex): + def get_tex_file_body(self, new_tex, fill_color=None): if self.tex_environment: new_tex = "\n".join([ f"\\begin{{{self.tex_environment}}}", @@ -358,6 +411,17 @@ class MTex(VMobject): ]) if self.alignment: new_tex = "\n".join([self.alignment, new_tex]) + if fill_color: + int_rgb = color_to_int_rgb(fill_color) + color_command = "".join([ + "\\color[RGB]", + "{", + ",".join(map(str, int_rgb)), + "}" + ]) + new_tex = "\n".join( + [color_command, new_tex] + ) tex_config = get_tex_config() return tex_config["tex_body"].replace( @@ -368,6 +432,9 @@ class MTex(VMobject): def build_submobjects(self): if not self.submobjects: return + self.init_colors() + for glyph in self.submobjects: + glyph.set_fill(glyph.fill_color) self.group_submobjects() self.sort_scripts_in_tex_order() self.assign_submob_tex_strings() @@ -383,13 +450,13 @@ class MTex(VMobject): new_glyphs = [] current_glyph_label = 0 - for submob in self.submobjects: - if submob.glyph_label == current_glyph_label: - new_glyphs.append(submob) + for glyph in self.submobjects: + if glyph.glyph_label == current_glyph_label: + new_glyphs.append(glyph) else: append_new_submobject(new_glyphs) - new_glyphs = [submob] - current_glyph_label = submob.glyph_label + new_glyphs = [glyph] + current_glyph_label = glyph.glyph_label append_new_submobject(new_glyphs) self.set_submobjects(new_submobjects) @@ -425,7 +492,7 @@ class MTex(VMobject): ) switch_range_pairs.append((submob_range_0, submob_range_1)) - switch_range_pairs.sort(key=lambda pair: (pair[0].stop, -pair[0].start)) + switch_range_pairs.sort(key=lambda t: (t[0].stop, -t[0].start)) indices = list(range(len(self.submobjects))) for submob_range_0, submob_range_1 in switch_range_pairs: indices = [ @@ -486,22 +553,28 @@ class MTex(VMobject): self.submobjects )) - def find_span_components_of_custom_span(self, custom_span_tuple, partial_result=[]): + def find_span_components_of_custom_span(self, custom_span_tuple): + tex_string = self.tex_string + span_choices = sorted(filter( + lambda t: _contains(custom_span_tuple, t), + self.tex_spans_dict.keys() + )) + # Filter out spans that reach the farthest. + span_choices_dict = dict(span_choices) + span_begin, span_end = custom_span_tuple - if span_begin == span_end: - return partial_result - next_begin_choices = sorted([ - span_tuple[1] - for span_tuple in self.tex_spans_dict.keys() - if span_tuple[0] == span_begin and span_tuple[1] <= span_end - ], reverse=True) - for next_begin in next_begin_choices: - result = self.find_span_components_of_custom_span( - (next_begin, span_end), [*partial_result, (span_begin, next_begin)] - ) - if result is not None: - return result - return None + result = [] + while span_begin != span_end: + if span_begin not in span_choices_dict: + if tex_string[span_begin].strip(): + return None + # Whitespaces may occur between spans. + span_begin += 1 + continue + next_begin = span_choices_dict[span_begin] + result.append((span_begin, next_begin)) + span_begin = next_begin + return result def get_part_by_custom_span_tuple(self, custom_span_tuple): span_tuples = self.find_span_components_of_custom_span(custom_span_tuple) @@ -545,12 +618,6 @@ class MTex(VMobject): part = self.get_part_by_tex(tex, index=index) return self.indices_of_part(part) - def indices_of_all_parts_by_tex(self, tex, index=0): - all_parts = self.get_parts_by_tex(tex) - return list(it.chain(*[ - self.indices_of_part(part) for part in all_parts - ])) - def range_of_part(self, part): indices = self.indices_of_part(part) return range(indices[0], indices[-1] + 1) @@ -576,11 +643,8 @@ class MTex(VMobject): for span_tuple in self.tex_spans_dict.keys() ]) - def list_tex_strings_of_submobjects(self): - # Work with `index_labels()`. - log.debug(f"Submobjects of \"{self.get_tex()}\":") - for i, submob in enumerate(self.submobjects): - log.debug(f"{i}: \"{submob.get_tex()}\"") + def get_specified_substrings(self): + return self.specified_substrings class MTexText(MTex):