diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py deleted file mode 100644 index 55909dc6..00000000 --- a/manimlib/mobject/svg/mtex_mobject.py +++ /dev/null @@ -1,843 +0,0 @@ -from __future__ import annotations - -import re -import colour -import itertools as it -#from types import MethodType -from typing import Iterable, Union, Sequence -from abc import abstractmethod - -from manimlib.constants import WHITE -from manimlib.mobject.svg.svg_mobject import SVGMobject -from manimlib.mobject.types.vectorized_mobject import VGroup -from manimlib.utils.color import color_to_int_rgb -from manimlib.utils.config_ops import digest_config -from manimlib.utils.iterables import adjacent_pairs -from manimlib.utils.iterables import remove_list_redundancies -from manimlib.utils.tex_file_writing import tex_to_svg_file -from manimlib.utils.tex_file_writing import get_tex_config -from manimlib.utils.tex_file_writing import display_during_execution -from manimlib.logger import log - - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from manimlib.mobject.types.vectorized_mobject import VMobject - ManimColor = Union[str, colour.Color, Sequence[float]] - Span = tuple[int, int] - - -SCALE_FACTOR_PER_FONT_POINT = 0.001 - - -class _StringSVG(SVGMobject): - CONFIG = { - "height": None, - "stroke_width": 0, - "stroke_color": WHITE, - "path_string_config": { - "should_subdivide_sharp_curves": True, - "should_remove_null_curves": True, - }, - } - - -class LabelledString(_StringSVG): - """ - An abstract base class for `MTex` and `MarkupText` - """ - CONFIG = { - "base_color": WHITE, - "use_plain_file": False, - "isolate": [], - } - - def __init__(self, string: str, **kwargs): - self.string = string - digest_config(self, kwargs) - self.pre_parse() - self.parse() - super().__init__(**kwargs) - - def get_file_path(self, use_plain_file: bool = False) -> str: - content = self.get_decorated_string(use_plain_file=use_plain_file) - return self.get_file_path_by_content(content) - - @abstractmethod - def get_file_path_by_content(self, content: str) -> str: - return "" - - def generate_mobject(self) -> None: - super().generate_mobject() - - submob_labels = [ - self.color_to_label(submob.get_fill_color()) - for submob in self.submobjects - ] - if self.use_plain_file or self.has_predefined_colors: - file_path = self.get_file_path(use_plain_file=True) - plain_svg = _StringSVG(file_path) - self.set_submobjects(plain_svg.submobjects) - else: - self.set_fill(self.base_color) - for submob, label in zip(self.submobjects, submob_labels): - submob.label = label - self.submob_labels = submob_labels - self.post_parse() - - def pre_parse(self) -> None: - self.full_span = self.get_full_span() - self.space_spans = self.get_space_spans() - - def parse(self) -> None: - self.command_repl_items = self.get_command_repl_items() - self.command_spans = self.get_command_spans() - self.ignored_indices = self.get_ignored_indices() - self.skipped_spans = self.get_skipped_spans() - self.internal_specified_spans = self.get_internal_specified_spans() - self.external_specified_spans = self.get_external_specified_spans() - self.specified_spans = self.get_specified_spans() - self.label_span_list = self.get_label_span_list() - self.has_predefined_colors = self.get_has_predefined_colors() - - def post_parse(self) -> None: - self.containing_labels_dict = self.get_containing_labels_dict() - self.specified_substrings = self.get_specified_substrings() - self.group_substr_items = self.get_group_substr_items() - self.group_substrs = self.get_group_substrs() - - # Toolkits - - def find_spans(self, pattern: str) -> list[Span]: - return [ - match_obj.span() - for match_obj in re.finditer(pattern, self.string) - ] - - @staticmethod - def get_neighbouring_pairs(iterable: Iterable) -> list: - return list(adjacent_pairs(iterable))[:-1] - - @staticmethod - def span_contains(span_0: Span, span_1: Span) -> bool: - return span_0[0] <= span_1[0] and span_0[1] >= span_1[1] - - @staticmethod - def compress_neighbours(vals: list[int]) -> list[tuple[int, Span]]: - if not vals: - return [] - - unique_vals = [vals[0]] - indices = [0] - for index, val in enumerate(vals): - if val == unique_vals[-1]: - continue - unique_vals.append(val) - indices.append(index) - indices.append(len(vals)) - spans = LabelledString.get_neighbouring_pairs(indices) - return list(zip(unique_vals, spans)) - - @staticmethod - def find_region_index(seq: list[int], val: int) -> int: - # Returns an integer in `range(-1, len(seq))` satisfying - # `seq[result] <= val < seq[result + 1]`. - # `seq` should be sorted in ascending order. - if not seq or val < seq[0]: - return -1 - result = len(seq) - 1 - while val < seq[result]: - result -= 1 - return result - - @staticmethod - def take_nearest_value(seq: list[int], val: int, index_shift: int) -> int: - sorted_seq = sorted(seq) - index = LabelledString.find_region_index(sorted_seq, val) - if index == -1: - raise IndexError - return sorted_seq[index + index_shift] - - @staticmethod - def replace_str_by_spans( - substr: str, span_repl_dict: dict[Span, str] - ) -> str: - if not span_repl_dict: - return substr - - spans = sorted(span_repl_dict.keys()) - if not all( - span_0[1] <= span_1[0] - for span_0, span_1 in LabelledString.get_neighbouring_pairs(spans) - ): - raise ValueError("Overlapping replacement") - - span_ends, span_begins = zip(*spans) - pieces = [ - substr[slice(*span)] - for span in zip( - (0, *span_begins), - (*span_ends, len(substr)) - ) - ] - repl_strs = [*[span_repl_dict[span] for span in spans], ""] - return "".join(it.chain(*zip(pieces, repl_strs))) - - @staticmethod - def get_span_replacement_dict( - inserted_string_pairs: list[tuple[Span, tuple[str, str]]], - other_repl_items: list[tuple[Span, str]] - ) -> dict[Span, str]: - if not inserted_string_pairs: - return other_repl_items.copy() - - indices, _, _, inserted_strings = zip(*sorted([ - ( - span[flag], - -flag, - -span[1 - flag], - str_pair[flag] - ) - for span, str_pair in inserted_string_pairs - for flag in range(2) - ])) - result = { - (index, index): "".join(inserted_strings[slice(*item_span)]) - for index, item_span - in LabelledString.compress_neighbours(indices) - } - result.update(other_repl_items) - return result - - @staticmethod - def rslide(index: int, skipped: list[Span]) -> int: - transfer_dict = dict(sorted(skipped)) - while index in transfer_dict.keys(): - index = transfer_dict[index] - return index - - @staticmethod - def lslide(index: int, skipped: list[Span]) -> int: - transfer_dict = dict(sorted([ - skipped_span[::-1] for skipped_span in skipped - ], reverse=True)) - while index in transfer_dict.keys(): - index = transfer_dict[index] - return index - - @staticmethod - def shrink_span(span: Span, skipped: list[Span]) -> Span: - return ( - LabelledString.rslide(span[0], skipped), - LabelledString.lslide(span[1], skipped) - ) - - @staticmethod - def extend_span(span: Span, skipped: list[Span]) -> Span: - return ( - LabelledString.lslide(span[0], skipped), - LabelledString.rslide(span[1], skipped) - ) - - @staticmethod - def rgb_to_int(rgb_tuple: tuple[int, int, int]) -> int: - r, g, b = rgb_tuple - rg = r * 256 + g - return rg * 256 + b - - @staticmethod - def int_to_rgb(rgb_int: int) -> tuple[int, int, int]: - rg, b = divmod(rgb_int, 256) - r, g = divmod(rg, 256) - return r, g, b - - @staticmethod - def color_to_label(color: ManimColor) -> int: - rgb_tuple = color_to_int_rgb(color) - rgb = LabelledString.rgb_to_int(rgb_tuple) - if rgb == 16777215: # white - return -1 - return rgb - - @abstractmethod - def get_begin_color_command_str(int_rgb: int) -> str: - return "" - - @abstractmethod - def get_end_color_command_str() -> str: - return "" - - # Pre-parsing - - def get_full_span(self) -> Span: - return (0, len(self.string)) - - def get_space_spans(self) -> list[Span]: - return self.find_spans(r"\s+") - - @abstractmethod - def get_command_repl_items(self) -> list[tuple[Span, str]]: - return [] - - def get_command_spans(self) -> list[Span]: - return [cmd_span for cmd_span, _ in self.command_repl_items] - - def get_ignored_indices(self) -> list[int]: - return [] - - def get_skipped_spans(self) -> list[Span]: - return list(it.chain( - self.space_spans, - self.command_spans, - [ - (index, index + 1) - for index in self.ignored_indices - ] - )) - - @abstractmethod - def get_internal_specified_spans(self) -> list[Span]: - return [] - - def get_external_specified_spans(self) -> list[Span]: - return remove_list_redundancies(list(it.chain(*[ - self.find_spans(re.escape(substr.strip())) - for substr in self.isolate - ]))) - - def get_specified_spans(self) -> list[Span]: - spans = [ - self.full_span, - *self.internal_specified_spans, - *self.external_specified_spans - ] - shrinked_spans = list(filter( - lambda span: span[0] < span[1], - [ - self.shrink_span(span, self.skipped_spans) - for span in spans - ] - )) - return remove_list_redundancies(shrinked_spans) - - @abstractmethod - def get_label_span_list(self) -> list[Span]: - return [] - - @abstractmethod - def get_inserted_string_pairs( - self, use_plain_file: bool - ) -> list[tuple[Span, tuple[str, str]]]: - return [] - - @abstractmethod - def get_other_repl_items( - self, use_plain_file: bool - ) -> list[tuple[Span, str]]: - return [] - - def get_decorated_string(self, use_plain_file: bool) -> str: - span_repl_dict = self.get_span_replacement_dict( - self.get_inserted_string_pairs(use_plain_file), - self.get_other_repl_items(use_plain_file) - ) - result = self.replace_str_by_spans(self.string, span_repl_dict) - - if not use_plain_file: - return result - return "".join([ - self.get_begin_color_command_str( - self.rgb_to_int(color_to_int_rgb(self.base_color)) - ), - result, - self.get_end_color_command_str() - ]) - - @abstractmethod - def get_has_predefined_colors(self) -> bool: - return False - - # Post-parsing - - def get_containing_labels_dict(self) -> dict[Span, list[int]]: - label_span_list = self.label_span_list - result = { - span: [] - for span in label_span_list - } - for span_0 in label_span_list: - for span_index, span_1 in enumerate(label_span_list): - if self.span_contains(span_0, span_1): - result[span_0].append(span_index) - elif span_0[0] < span_1[0] < span_0[1] < span_1[1]: - string_0, string_1 = [ - self.string[slice(*span)] - for span in [span_0, span_1] - ] - raise ValueError( - "Partially overlapping substrings detected: " - f"'{string_0}' and '{string_1}'" - ) - if self.full_span not in result: - result[self.full_span] = list(range(len(label_span_list))) - return result - - def get_cleaned_substr(self, span: Span) -> str: - span_repl_dict = { - tuple([index - span[0] for index in cmd_span]): "" - for cmd_span in self.command_spans - if self.span_contains(span, cmd_span) - } - return self.replace_str_by_spans( - self.string[slice(*span)], span_repl_dict - ) - - def get_specified_substrings(self) -> list[str]: - return remove_list_redundancies([ - self.get_cleaned_substr(span) - for span in self.specified_spans - ]) - - def get_group_substr_items(self) -> tuple[list[Span], list[str]]: - group_labels, submob_spans = zip( - *self.compress_neighbours(self.submob_labels) - ) - ordered_spans = [ - self.label_span_list[label] if label != -1 else self.full_span - for label in group_labels - ] - ordered_containing_labels = [ - self.containing_labels_dict[span] - for span in ordered_spans - ] - ordered_span_begins, ordered_span_ends = zip(*ordered_spans) - span_begins = [ - prev_end if prev_label in containing_labels else curr_begin - for prev_end, prev_label, containing_labels, curr_begin in zip( - ordered_span_ends[:-1], group_labels[:-1], - ordered_containing_labels[1:], ordered_span_begins[1:] - ) - ] - span_ends = [ - next_begin if next_label in containing_labels else curr_end - for next_begin, next_label, containing_labels, curr_end in zip( - ordered_span_begins[1:], group_labels[1:], - ordered_containing_labels[:-1], ordered_span_ends[:-1] - ) - ] - spans = list(zip( - (ordered_span_begins[0], *span_begins), - (*span_ends, ordered_span_ends[-1]) - )) - shrinked_spans = [ - self.shrink_span(span, self.skipped_spans) - for span in spans - ] - group_substrs = [ - self.get_cleaned_substr(span) if span[0] < span[1] else "" - for span in shrinked_spans - ] - return submob_spans, group_substrs - - def get_group_substrs(self) -> list[str]: - return self.group_substr_items[1] - - # Selector - - def find_span_components_of_custom_span( - self, custom_span: Span - ) -> list[Span]: - indices = remove_list_redundancies(list(it.chain( - self.full_span, - *self.label_span_list - ))) - span_begin = self.take_nearest_value(indices, custom_span[0], 0) - span_end = self.take_nearest_value(indices, custom_span[1] - 1, 1) - span_choices = sorted(filter( - lambda span: self.span_contains((span_begin, span_end), span), - self.label_span_list - )) - # Choose spans that reach the farthest. - span_choices_dict = dict(span_choices) - - result = [] - while span_begin < span_end: - if span_begin not in span_choices_dict.keys(): - span_begin += 1 - continue - next_begin = span_choices_dict[span_begin] - result.append((span_begin, next_begin)) - span_begin = next_begin - return result - - def get_parts_by_custom_span(self, custom_span: Span) -> VGroup: - spans = self.find_span_components_of_custom_span(custom_span) - labels = set(it.chain(*[ - self.containing_labels_dict[span] - for span in spans - ])) - return VGroup(*filter( - lambda submob: submob.label in labels, - self.submobjects - )) - - def get_parts_by_string(self, substr: str) -> VGroup: - return VGroup(*[ - self.get_parts_by_custom_span(span) - for span in self.find_spans(re.escape(substr.strip())) - ]) - - def get_parts_by_group_substr(self, substr: str) -> VGroup: - return VGroup(*[ - VGroup(*self.submobjects[slice(*submob_span)]) - for submob_span, group_substr in zip(*self.group_substr_items) - if group_substr == substr - ]) - - def get_part_by_string(self, substr: str, index : int = 0) -> VMobject: - return self.get_parts_by_string(substr)[index] - - def set_color_by_string(self, substr: str, color: ManimColor): - self.get_parts_by_string(substr).set_color(color) - return self - - def set_color_by_string_to_color_map( - self, string_to_color_map: dict[str, ManimColor] - ): - for substr, color in string_to_color_map.items(): - self.set_color_by_string(substr, color) - return self - - def indices_of_part(self, part: Iterable[VMobject]) -> list[int]: - return [self.submobjects.index(submob) for submob in part] - - def indices_lists_of_parts( - self, parts: Iterable[Iterable[VMobject]] - ) -> list[list[int]]: - return [self.indices_of_part(part) for part in parts] - - def get_string(self) -> str: - return self.string - - -class MTex(LabelledString): - CONFIG = { - "font_size": 48, - "alignment": "\\centering", - "tex_environment": "align*", - "tex_to_color_map": {}, - } - - def __init__(self, tex_string: str, **kwargs): - digest_config(self, kwargs) - tex_string = tex_string.strip() - # Prevent from passing an empty string. - if not tex_string: - tex_string = "\\quad" - self.tex_string = tex_string - self.isolate.extend(self.tex_to_color_map.keys()) - super().__init__(tex_string, **kwargs) - - self.set_color_by_tex_to_color_map(self.tex_to_color_map) - self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size) - - @property - def hash_seed(self) -> tuple: - return ( - self.__class__.__name__, - self.svg_default, - self.path_string_config, - self.base_color, - self.use_plain_file, - self.isolate, - self.tex_string, - self.alignment, - self.tex_environment, - self.tex_to_color_map - ) - - def get_file_path_by_content(self, content: str) -> str: - full_tex = self.get_tex_file_body(content) - with display_during_execution(f"Writing \"{self.string}\""): - file_path = self.tex_to_svg_file_path(full_tex) - return file_path - - def get_tex_file_body(self, content: str) -> str: - if self.tex_environment: - content = "\n".join([ - f"\\begin{{{self.tex_environment}}}", - content, - f"\\end{{{self.tex_environment}}}" - ]) - if self.alignment: - content = "\n".join([self.alignment, content]) - - tex_config = get_tex_config() - return tex_config["tex_body"].replace( - tex_config["text_to_replace"], - content - ) - - @staticmethod - def tex_to_svg_file_path(tex_file_content: str) -> str: - return tex_to_svg_file(tex_file_content) - - def pre_parse(self) -> None: - super().pre_parse() - self.backslash_indices = self.get_backslash_indices() - self.brace_index_pairs = self.get_brace_index_pairs() - self.script_char_indices = self.get_script_char_indices() - self.script_char_spans = self.get_script_char_spans() - self.script_content_spans = self.get_script_content_spans() - self.script_spans = self.get_script_spans() - - # Toolkits - - @staticmethod - def get_begin_color_command_str(rgb_int: int) -> str: - rgb_tuple = MTex.int_to_rgb(rgb_int) - return "".join([ - "{{", - "\\color[RGB]", - "{", - ",".join(map(str, rgb_tuple)), - "}" - ]) - - @staticmethod - def get_end_color_command_str() -> str: - return "}}" - - # Pre-parsing - - def get_backslash_indices(self) -> list[int]: - # Newlines (`\\`) don't count. - return [ - span[1] - 1 - for span in self.find_spans(r"\\+") - if (span[1] - span[0]) % 2 == 1 - ] - - def get_unescaped_char_indices(self, *chars: str): - return sorted(filter( - lambda index: index - 1 not in self.backslash_indices, - [ - span[0] - for char in chars - for span in self.find_spans(re.escape(char)) - ] - )) - - def get_brace_index_pairs(self) -> list[Span]: - string = self.string - indices = self.get_unescaped_char_indices("{", "}") - left_brace_indices = [] - right_brace_indices = [] - left_brace_indices_stack = [] - for index in indices: - if string[index] == "{": - left_brace_indices_stack.append(index) - else: - if not left_brace_indices_stack: - raise ValueError("Missing '{' inserted") - left_brace_index = left_brace_indices_stack.pop() - left_brace_indices.append(left_brace_index) - right_brace_indices.append(index) - if left_brace_indices_stack: - raise ValueError("Missing '}' inserted") - return list(zip(left_brace_indices, right_brace_indices)) - - def get_script_char_indices(self) -> list[int]: - return self.get_unescaped_char_indices("_", "^") - - def get_script_char_spans(self) -> list[Span]: - return [ - self.extend_span((index, index + 1), self.space_spans) - for index in self.script_char_indices - ] - - def get_script_content_spans(self) -> list[Span]: - result = [] - brace_indices_dict = dict(self.brace_index_pairs) - for _, span_begin in self.script_char_spans: - if span_begin in brace_indices_dict.keys(): - span_end = brace_indices_dict[span_begin] + 1 - else: - pattern = re.compile(r"[a-zA-Z0-9]|\\[a-zA-Z]+") - match_obj = pattern.match(self.string, pos=span_begin) - if not match_obj: - script_name = { - "_": "subscript", - "^": "superscript" - }[script_char] - raise ValueError( - f"Unclear {script_name} detected while parsing. " - "Please use braces to clarify" - ) - span_end = match_obj.end() - result.append((span_begin, span_end)) - return result - - def get_script_spans(self) -> list[Span]: - return [ - (script_char_span[0], script_content_span[1]) - for script_char_span, script_content_span in zip( - self.script_char_spans, self.script_content_spans - ) - ] - - # Parsing - - def get_command_repl_items(self) -> list[tuple[Span, str]]: - color_related_command_dict = { - "color": (1, False), - "textcolor": (1, False), - "pagecolor": (1, True), - "colorbox": (1, True), - "fcolorbox": (2, True), - } - result = [] - backslash_indices = self.backslash_indices - right_brace_indices = [ - right_index - for left_index, right_index in self.brace_index_pairs - ] - pattern = "".join([ - r"\\", - "(", - "|".join(color_related_command_dict.keys()), - ")", - r"(?![a-zA-Z])" - ]) - for match_obj in re.finditer(pattern, self.string): - span_begin, cmd_end = match_obj.span() - if span_begin not in backslash_indices: - continue - cmd_name = match_obj.group(1) - n_braces, substitute_cmd = color_related_command_dict[cmd_name] - span_end = self.take_nearest_value( - right_brace_indices, cmd_end, n_braces - ) + 1 - if substitute_cmd: - repl_str = "\\" + cmd_name + n_braces * "{white}" - else: - repl_str = "" - result.append(((span_begin, span_end), repl_str)) - return result - - def get_ignored_indices(self) -> list[int]: - return self.script_char_indices - - def get_internal_specified_spans(self) -> list[Span]: - # Match paired double braces (`{{...}}`). - result = [] - reversed_brace_indices_dict = dict([ - pair[::-1] for pair in self.brace_index_pairs - ]) - skip = False - for prev_right_index, right_index in self.get_neighbouring_pairs( - list(reversed_brace_indices_dict.keys()) - ): - if skip: - skip = False - continue - if right_index != prev_right_index + 1: - continue - left_index = reversed_brace_indices_dict[right_index] - prev_left_index = reversed_brace_indices_dict[prev_right_index] - if left_index != prev_left_index - 1: - continue - result.append((left_index, right_index + 1)) - skip = True - return result - - def get_label_span_list(self) -> list[Span]: - result = self.script_content_spans.copy() - for span_begin, span_end in self.specified_spans: - shrinked_end = self.lslide(span_end, self.script_spans) - if span_begin >= shrinked_end: - continue - shrinked_span = (span_begin, shrinked_end) - if shrinked_span in result: - continue - result.append(shrinked_span) - return result - - def get_inserted_string_pairs( - self, use_plain_file: bool - ) -> list[tuple[Span, tuple[str, str]]]: - if use_plain_file: - return [] - - extended_label_span_list = [ - span - if span in self.script_content_spans - else (span[0], self.rslide(span[1], self.script_spans)) - for span in self.label_span_list - ] - return [ - (span, ( - self.get_begin_color_command_str(label), - self.get_end_color_command_str() - )) - for label, span in enumerate(extended_label_span_list) - ] - - def get_other_repl_items( - self, use_plain_file: bool - ) -> list[tuple[Span, str]]: - if use_plain_file: - return [] - return self.command_repl_items.copy() - - def get_has_predefined_colors(self) -> bool: - return bool(self.command_repl_items) - - def get_cleaned_substr(self, span: Span) -> str: - substr = super().get_cleaned_substr(span) - if not self.brace_index_pairs: - return substr - - # Balance braces. - left_brace_indices, right_brace_indices = zip(*self.brace_index_pairs) - unclosed_left_braces = 0 - unclosed_right_braces = 0 - for index in range(*span): - if index in left_brace_indices: - unclosed_left_braces += 1 - elif index in right_brace_indices: - if unclosed_left_braces == 0: - unclosed_right_braces += 1 - else: - unclosed_left_braces -= 1 - return "".join([ - unclosed_right_braces * "{", - substr, - unclosed_left_braces * "}" - ]) - - # Method alias - - def get_parts_by_tex(self, tex: str) -> VGroup: - return self.get_parts_by_string(tex) - - def get_part_by_tex(self, tex: str) -> VMobject: - return self.get_part_by_string(tex) - - def set_color_by_tex(self, tex: str, color: ManimColor): - return self.set_color_by_string(tex, color) - - def set_color_by_tex_to_color_map( - self, tex_to_color_map: dict[str, ManimColor] - ): - return self.set_color_by_string_to_color_map(tex_to_color_map) - - def get_tex(self) -> str: - return self.get_string() - - -class MTexText(MTex): - CONFIG = { - "tex_environment": None, - }