diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index ce0807a2..b1621c4f 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -109,8 +109,8 @@ class Mobject(object): "reflectiveness": self.reflectiveness, } - def init_colors(self, override=True): - self.set_color(self.color, self.opacity, override) + def init_colors(self): + self.set_color(self.color, self.opacity) def init_points(self): # Typically implemented in subclass, unlpess purposefully left blank diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index b7854dd3..b1f34147 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -49,23 +49,9 @@ class _LabelledTex(_PlainTex): if len(color_str) == 4: # "#RGB" => "#RRGGBB" color_str = "#" + "".join([c * 2 for c in color_str[1:]]) - - return int(color_str[1:], 16) - 1 - - def get_mobjects_from(self, element, style): - result = super().get_mobjects_from(element, style) - for mob in result: - if not hasattr(mob, "glyph_label"): - mob.glyph_label = -1 - try: - color_str = element.getAttribute("fill") - if color_str: - glyph_label = _LabelledTex.color_str_to_label(color_str) - for mob in result: - mob.glyph_label = glyph_label - except: - pass - return result + if color_str == "#ffffff": + return 0 + return int(color_str[1:], 16) class _TexSpan(object): @@ -87,16 +73,16 @@ class _TexParser(object): def __init__(self, tex_string, additional_substrings): self.tex_string = tex_string self.tex_spans_dict = {} - self.specified_substrings = [] - self.current_label = 0 + self.current_label = -1 self.brace_index_pairs = self.get_brace_index_pairs() + self.existing_color_command_spans = self.get_existing_color_command_spans() + self.has_existing_color_commands = any(self.existing_color_command_spans.values()) self.add_tex_span((0, len(tex_string))) self.break_up_by_double_braces() self.break_up_by_scripts() self.break_up_by_additional_substrings(additional_substrings) self.check_if_overlap() self.analyse_containing_labels() - self.specified_substrings = remove_list_redundancies(self.specified_substrings) @staticmethod def label_to_color_tuple(rgb): @@ -112,16 +98,12 @@ class _TexParser(object): if script_type == 0: # Should be additionally labelled. - label = self.current_label self.current_label += 1 + label = self.current_label tex_span = _TexSpan(script_type, label) self.tex_spans_dict[span_tuple] = tex_span - def add_specified_substring(self, span_tuple): - substring = self.tex_string[slice(*span_tuple)] - self.specified_substrings.append(substring) - def get_brace_index_pairs(self): result = [] left_brace_indices = [] @@ -140,6 +122,34 @@ class _TexParser(object): self.raise_tex_parsing_error("unmatched braces") return result + def get_existing_color_command_spans(self): + tex_string = self.tex_string + color_related_commands_dict = _TexParser.get_color_related_commands_dict() + commands = color_related_commands_dict.keys() + result = { + command_name: [] + for command_name in commands + } + brace_index_pairs = self.brace_index_pairs + pattern = "|".join([ + re.escape(command_name) + for command_name in commands + ]) + for match_obj in re.finditer(pattern, tex_string): + span_tuple = match_obj.span() + command_begin_index = span_tuple[0] + command_name = match_obj.group() + n_braces = color_related_commands_dict[command_name] + for _ in range(n_braces): + span_tuple = min(filter( + lambda t: t[0] >= span_tuple[1], + brace_index_pairs + )) + result[command_name].append( + (command_begin_index, span_tuple[1]) + ) + return result + def break_up_by_double_braces(self): # Match paired double braces (`{{...}}`). skip_pair = False @@ -154,7 +164,6 @@ class _TexParser(object): span_tuple[1] == prev_span_tuple[1] + 1 ]): self.add_tex_span(span_tuple) - self.add_specified_substring(span_tuple) skip_pair = True def break_up_by_scripts(self): @@ -201,9 +210,7 @@ class _TexParser(object): span_end = script_spans_dict[span_end] if span_begin >= span_end: continue - span_tuple = (span_begin, span_end) - self.add_tex_span(span_tuple) - self.add_specified_substring(span_tuple) + self.add_tex_span((span_begin, span_end)) def check_if_overlap(self): span_tuples = sorted( @@ -237,7 +244,7 @@ class _TexParser(object): def get_labelled_expression(self): tex_string = self.tex_string - if not self.tex_spans_dict: + if self.current_label == 0 and not self.has_existing_color_commands: return tex_string # Remove the span of extire tex string. @@ -248,6 +255,9 @@ class _TexParser(object): for i in range(2) ], key=lambda t: (t[0], -t[1], -t[2]))[1:] + # Prevent from "\\color[RGB]" being replaced. + # Hopefully tex string doesn't contain such a substring... + color_command_placeholder = "{{\\iffalse \\fi}}" result = tex_string[: indices_with_labels[0][0]] for index_with_label, next_index_with_label in _get_neighbouring_pairs( indices_with_labels @@ -259,7 +269,7 @@ class _TexParser(object): color_tuple = _TexParser.label_to_color_tuple(label) result += "".join([ "{{", - "\\color[RGB]", + color_command_placeholder, "{", ",".join(map(str, color_tuple)), "}" @@ -267,11 +277,35 @@ class _TexParser(object): else: result += "}}" result += tex_string[index : next_index] - return result + + color_related_commands_dict = _TexParser.get_color_related_commands_dict() + for command_name, command_spans in self.existing_color_command_spans.items(): + if not command_spans: + continue + n_braces = color_related_commands_dict[command_name] + command_to_replace = command_name + n_braces * "{black}" + commands = { + tex_string[slice(*span_tuple)] + for span_tuple in command_spans + } + for command in commands: + result = result.replace(command, command_to_replace) + + return result.replace(color_command_placeholder, "\\color[RGB]") def raise_tex_parsing_error(self, message): raise ValueError(f"Failed to parse tex ({message}): \"{self.tex_string}\"") + @staticmethod + def get_color_related_commands_dict(): + return { + "\\color": 1, + "\\textcolor": 1, + "\\pagecolor": 1, + "\\colorbox": 1, + "\\fcolorbox": 2, + } + class MTex(VMobject): CONFIG = { @@ -282,7 +316,7 @@ class MTex(VMobject): "tex_environment": "align*", "isolate": [], "tex_to_color_map": {}, - "generate_plain_tex_file": False, + "use_plain_tex_file": False, } def __init__(self, tex_string, **kwargs): @@ -293,9 +327,8 @@ class MTex(VMobject): tex_string = "\\quad" self.tex_string = tex_string - self.generate_mobject() + self.generate_tex() - self.init_colors() self.set_color_by_tex_to_color_map(self.tex_to_color_map) self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size) @@ -310,43 +343,67 @@ class MTex(VMobject): def get_parser(self): return _TexParser(self.tex_string, self.get_additional_substrings_to_break_up()) - def generate_mobject(self): + def generate_tex(self): tex_string = self.tex_string tex_parser = self.get_parser() self.tex_spans_dict = tex_parser.tex_spans_dict - self.specified_substrings = tex_parser.specified_substrings - plain_full_tex = self.get_tex_file_body(tex_string) - plain_hash_val = hash(plain_full_tex) - if plain_hash_val in TEX_HASH_TO_MOB_MAP: - self.add(*TEX_HASH_TO_MOB_MAP[plain_hash_val].copy()) - return self + fill_color = self.get_fill_color() + stroke_width = self.get_stroke_width() - labelled_expression = tex_parser.get_labelled_expression() - full_tex = self.get_tex_file_body(labelled_expression) - hash_val = hash(full_tex) - if hash_val in TEX_HASH_TO_MOB_MAP and not self.generate_plain_tex_file: + # Cannot simultaneously be false, so at least one file is generated. + require_labelled_tex_file = tex_parser.current_label != 0 + require_plain_tex_file = any([ + self.use_plain_tex_file, + tex_parser.has_existing_color_commands, + fill_color != "#ffffff", + tex_parser.current_label == 0 + ]) + + if require_labelled_tex_file: + labelled_full_tex = self.get_tex_file_body(tex_parser.get_labelled_expression()) + labelled_hash_val = hash((labelled_full_tex, stroke_width)) + if labelled_hash_val in TEX_HASH_TO_MOB_MAP: + if not require_plain_tex_file: + self.add(*TEX_HASH_TO_MOB_MAP[labelled_hash_val].copy()) + return self + else: + with display_during_execution(f"Writing \"{tex_string}\""): + filename = tex_to_svg_file(labelled_full_tex) + labelled_svg_glyphs = _LabelledTex( + filename, + stroke_width=stroke_width + ) + self.add(*labelled_svg_glyphs) + self.build_submobjects() + TEX_HASH_TO_MOB_MAP[labelled_hash_val] = self.copy() + if not require_plain_tex_file: + return self + + # require_plain_tex_file == True + self.set_submobjects([]) + full_tex = self.get_tex_file_body(tex_string) + hash_val = hash((full_tex, fill_color, stroke_width)) + if hash_val in TEX_HASH_TO_MOB_MAP: self.add(*TEX_HASH_TO_MOB_MAP[hash_val].copy()) return self - with display_during_execution(f"Writing \"{tex_string}\""): filename = tex_to_svg_file(full_tex) - svg_mob = _LabelledTex(filename) - self.add(*svg_mob.copy()) + svg_glyphs = _PlainTex( + filename, + fill_color=fill_color, + stroke_width=stroke_width + ) + if require_labelled_tex_file: + labelled_svg_mob = TEX_HASH_TO_MOB_MAP[labelled_hash_val] + for glyph, labelled_glyph in zip(svg_glyphs, it.chain(*labelled_svg_mob)): + glyph.glyph_label = labelled_glyph.glyph_label + else: + for glyph in svg_glyphs: + glyph.glyph_label = 0 + self.add(*svg_glyphs) self.build_submobjects() - TEX_HASH_TO_MOB_MAP[hash_val] = self - if not self.generate_plain_tex_file: - return self - - with display_during_execution(f"Writing \"{tex_string}\""): - filename = tex_to_svg_file(plain_full_tex) - plain_svg_mob = _PlainTex(filename) - svg_mob = TEX_HASH_TO_MOB_MAP[hash_val] - for plain_submob, submob in zip(plain_svg_mob, svg_mob): - plain_submob.glyph_label = submob.glyph_label - self.add(*plain_svg_mob.copy()) - self.build_submobjects() - TEX_HASH_TO_MOB_MAP[plain_hash_val] = self + TEX_HASH_TO_MOB_MAP[hash_val] = self.copy() return self def get_tex_file_body(self, new_tex): @@ -368,6 +425,9 @@ class MTex(VMobject): def build_submobjects(self): if not self.submobjects: return + self.init_colors() + for glyph in self.submobjects: + glyph.set_fill(glyph.fill_color) self.group_submobjects() self.sort_scripts_in_tex_order() self.assign_submob_tex_strings() @@ -383,13 +443,13 @@ class MTex(VMobject): new_glyphs = [] current_glyph_label = 0 - for submob in self.submobjects: - if submob.glyph_label == current_glyph_label: - new_glyphs.append(submob) + for glyph in self.submobjects: + if glyph.glyph_label == current_glyph_label: + new_glyphs.append(glyph) else: append_new_submobject(new_glyphs) - new_glyphs = [submob] - current_glyph_label = submob.glyph_label + new_glyphs = [glyph] + current_glyph_label = glyph.glyph_label append_new_submobject(new_glyphs) self.set_submobjects(new_submobjects) @@ -425,7 +485,7 @@ class MTex(VMobject): ) switch_range_pairs.append((submob_range_0, submob_range_1)) - switch_range_pairs.sort(key=lambda pair: (pair[0].stop, -pair[0].start)) + switch_range_pairs.sort(key=lambda t: (t[0].stop, -t[0].start)) indices = list(range(len(self.submobjects))) for submob_range_0, submob_range_1 in switch_range_pairs: indices = [ @@ -486,22 +546,28 @@ class MTex(VMobject): self.submobjects )) - def find_span_components_of_custom_span(self, custom_span_tuple, partial_result=[]): + def find_span_components_of_custom_span(self, custom_span_tuple): + tex_string = self.tex_string + span_choices = sorted(filter( + lambda t: _contains(custom_span_tuple, t), + self.tex_spans_dict.keys() + )) + # Filter out spans that reach the farthest. + span_choices_dict = dict(span_choices) + span_begin, span_end = custom_span_tuple - if span_begin == span_end: - return partial_result - next_begin_choices = sorted([ - span_tuple[1] - for span_tuple in self.tex_spans_dict.keys() - if span_tuple[0] == span_begin and span_tuple[1] <= span_end - ], reverse=True) - for next_begin in next_begin_choices: - result = self.find_span_components_of_custom_span( - (next_begin, span_end), [*partial_result, (span_begin, next_begin)] - ) - if result is not None: - return result - return None + result = [] + while span_begin != span_end: + if span_begin not in span_choices_dict: + if tex_string[span_begin].strip(): + return None + # Whitespaces may occur between spans. + span_begin += 1 + continue + next_begin = span_choices_dict[span_begin] + result.append((span_begin, next_begin)) + span_begin = next_begin + return result def get_part_by_custom_span_tuple(self, custom_span_tuple): span_tuples = self.find_span_components_of_custom_span(custom_span_tuple)