diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 77cc9400..2b399e48 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -6,6 +6,7 @@ from types import MethodType from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.mobject.types.vectorized_mobject import VGroup +from manimlib.utils.color import color_to_int_rgb from manimlib.utils.iterables import adjacent_pairs from manimlib.utils.iterables import remove_list_redundancies from manimlib.utils.tex_file_writing import tex_to_svg_file @@ -28,7 +29,7 @@ def _get_neighbouring_pairs(iterable): return list(adjacent_pairs(iterable))[:-1] -class _PlainTex(SVGMobject): +class _TexSVG(SVGMobject): CONFIG = { "height": None, "path_string_config": { @@ -37,13 +38,6 @@ class _PlainTex(SVGMobject): }, } - -class _LabelledTex(_PlainTex): - def __init__(self, file_name=None, **kwargs): - super().__init__(file_name, **kwargs) - for glyph in self: - glyph.glyph_label = _LabelledTex.color_str_to_label(glyph.fill_color) - @staticmethod def color_str_to_label(color_str): if len(color_str) == 4: @@ -53,6 +47,11 @@ class _LabelledTex(_PlainTex): return 0 return int(color_str[1:], 16) + def parse_labels(self): + for glyph in self: + glyph.glyph_label = _TexSVG.color_str_to_label(glyph.fill_color) + return self + class _TexSpan(object): def __init__(self, script_type, label): @@ -77,10 +76,15 @@ class _TexParser(object): self.brace_index_pairs = self.get_brace_index_pairs() self.existing_color_command_spans = self.get_existing_color_command_spans() self.has_existing_color_commands = any(self.existing_color_command_spans.values()) + self.specified_substring_spans = [] self.add_tex_span((0, len(tex_string))) self.break_up_by_double_braces() self.break_up_by_scripts() self.break_up_by_additional_substrings(additional_substrings) + self.specified_substrings = remove_list_redundancies([ + tex_string[slice(*span_tuple)] + for span_tuple in self.specified_substring_spans + ]) self.check_if_overlap() self.analyse_containing_labels() @@ -164,6 +168,7 @@ class _TexParser(object): span_tuple[1] == prev_span_tuple[1] + 1 ]): self.add_tex_span(span_tuple) + self.specified_substring_spans.append(span_tuple) skip_pair = True def break_up_by_scripts(self): @@ -210,7 +215,9 @@ class _TexParser(object): span_end = script_spans_dict[span_end] if span_begin >= span_end: continue - self.add_tex_span((span_begin, span_end)) + span_tuple = (span_begin, span_end) + self.add_tex_span(span_tuple) + self.specified_substring_spans.append(span_tuple) def check_if_overlap(self): span_tuples = sorted( @@ -242,7 +249,7 @@ class _TexParser(object): if _contains(span_1, span_0): tex_span_1.containing_labels.append(tex_span_0.label) - def get_labelled_expression(self): + def get_labelled_tex_string(self): tex_string = self.tex_string if self.current_label == 0 and not self.has_existing_color_commands: return tex_string @@ -294,7 +301,8 @@ class _TexParser(object): return result.replace(color_command_placeholder, "\\color[RGB]") def raise_tex_parsing_error(self, message): - raise ValueError(f"Failed to parse tex ({message}): \"{self.tex_string}\"") + log.error(f"Failed to parse tex ({message}): \"{self.tex_string}\"") + sys.exit(2) @staticmethod def get_color_related_commands_dict(): @@ -327,7 +335,7 @@ class MTex(VMobject): tex_string = "\\quad" self.tex_string = tex_string - self.generate_tex() + self.generate_mobject() self.set_color_by_tex_to_color_map(self.tex_to_color_map) self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size) @@ -343,61 +351,60 @@ class MTex(VMobject): def get_parser(self): return _TexParser(self.tex_string, self.get_additional_substrings_to_break_up()) - def generate_tex(self): + def generate_mobject(self): tex_string = self.tex_string tex_parser = self.get_parser() self.tex_spans_dict = tex_parser.tex_spans_dict - + self.specified_substrings = tex_parser.specified_substrings fill_color = self.get_fill_color() + # Cannot simultaneously be false, so at least one file is generated. require_labelled_tex_file = tex_parser.current_label != 0 require_plain_tex_file = any([ self.use_plain_tex_file, tex_parser.has_existing_color_commands, - fill_color != "#ffffff", tex_parser.current_label == 0 ]) if require_labelled_tex_file: - labelled_full_tex = self.get_tex_file_body(tex_parser.get_labelled_expression()) + labelled_full_tex = self.get_tex_file_body(tex_parser.get_labelled_tex_string()) labelled_hash_val = hash(labelled_full_tex) if labelled_hash_val in TEX_HASH_TO_MOB_MAP: - if not require_plain_tex_file: - self.add(*TEX_HASH_TO_MOB_MAP[labelled_hash_val].copy()) - return self + self.add(*TEX_HASH_TO_MOB_MAP[labelled_hash_val].copy()) else: with display_during_execution(f"Writing \"{tex_string}\""): filename = tex_to_svg_file(labelled_full_tex) - labelled_svg_glyphs = _LabelledTex(filename) + labelled_svg_glyphs = _TexSVG(filename).parse_labels() self.add(*labelled_svg_glyphs) self.build_submobjects() - TEX_HASH_TO_MOB_MAP[labelled_hash_val] = self.copy() + TEX_HASH_TO_MOB_MAP[labelled_hash_val] = self.copy() if not require_plain_tex_file: + self.set_fill(color=fill_color) return self # require_plain_tex_file == True self.set_submobjects([]) - full_tex = self.get_tex_file_body(tex_string) - hash_val = hash((full_tex, fill_color)) + full_tex = self.get_tex_file_body(tex_string, fill_color=fill_color) + hash_val = hash(full_tex) if hash_val in TEX_HASH_TO_MOB_MAP: self.add(*TEX_HASH_TO_MOB_MAP[hash_val].copy()) - return self - with display_during_execution(f"Writing \"{tex_string}\""): - filename = tex_to_svg_file(full_tex) - svg_glyphs = _PlainTex(filename, fill_color=fill_color) - if require_labelled_tex_file: - labelled_svg_mob = TEX_HASH_TO_MOB_MAP[labelled_hash_val] - for glyph, labelled_glyph in zip(svg_glyphs, it.chain(*labelled_svg_mob)): - glyph.glyph_label = labelled_glyph.glyph_label - else: - for glyph in svg_glyphs: - glyph.glyph_label = 0 - self.add(*svg_glyphs) - self.build_submobjects() - TEX_HASH_TO_MOB_MAP[hash_val] = self.copy() + else: + with display_during_execution(f"Writing \"{tex_string}\""): + filename = tex_to_svg_file(full_tex) + svg_glyphs = _TexSVG(filename) + if require_labelled_tex_file: + labelled_svg_mob = TEX_HASH_TO_MOB_MAP[labelled_hash_val] + for glyph, labelled_glyph in zip(svg_glyphs, it.chain(*labelled_svg_mob)): + glyph.glyph_label = labelled_glyph.glyph_label + else: + for glyph in svg_glyphs: + glyph.glyph_label = 0 + self.add(*svg_glyphs) + self.build_submobjects() + TEX_HASH_TO_MOB_MAP[hash_val] = self.copy() return self - def get_tex_file_body(self, new_tex): + def get_tex_file_body(self, new_tex, fill_color=None): if self.tex_environment: new_tex = "\n".join([ f"\\begin{{{self.tex_environment}}}", @@ -406,6 +413,17 @@ class MTex(VMobject): ]) if self.alignment: new_tex = "\n".join([self.alignment, new_tex]) + if fill_color: + int_rgb = color_to_int_rgb(fill_color) + color_command = "".join([ + "\\color[RGB]", + "{", + ",".join(map(str, int_rgb)), + "}" + ]) + new_tex = "\n".join( + [color_command, new_tex] + ) tex_config = get_tex_config() return tex_config["tex_body"].replace( @@ -564,7 +582,8 @@ class MTex(VMobject): span_tuples = self.find_span_components_of_custom_span(custom_span_tuple) if span_tuples is None: tex = self.tex_string[slice(*custom_span_tuple)] - raise ValueError(f"Failed to get span of tex: \"{tex}\"") + log.error(f"Failed to get span of tex: \"{tex}\"") + sys.exit(2) return self.get_part_by_span_tuples(span_tuples) def get_parts_by_tex(self, tex): @@ -595,7 +614,8 @@ class MTex(VMobject): if submob in part ] if not indices: - raise ValueError("Failed to find part in tex") + log.error("Failed to find part in tex") + sys.exit(2) return indices def indices_of_part_by_tex(self, tex, index=0): @@ -633,11 +653,8 @@ class MTex(VMobject): for span_tuple in self.tex_spans_dict.keys() ]) - def list_tex_strings_of_submobjects(self): - # Work with `index_labels()`. - log.debug(f"Submobjects of \"{self.get_tex()}\":") - for i, submob in enumerate(self.submobjects): - log.debug(f"{i}: \"{submob.get_tex()}\"") + def get_specified_substrings(self): + return self.specified_substrings class MTexText(MTex):