From 519e2f4f1e574766bfe0a6660b8214143a076d4d Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Wed, 23 Mar 2022 12:21:40 +0800 Subject: [PATCH 01/48] Adjust some typings --- docs/source/documentation/constants.rst | 2 -- manimlib/constants.py | 2 -- manimlib/mobject/svg/mtex_mobject.py | 41 ++++++++++++--------- manimlib/mobject/svg/svg_mobject.py | 24 ------------- manimlib/mobject/svg/text_mobject.py | 47 +++++++++++++------------ 5 files changed, 48 insertions(+), 68 deletions(-) diff --git a/docs/source/documentation/constants.rst b/docs/source/documentation/constants.rst index cbf96bae..7dabd50e 100644 --- a/docs/source/documentation/constants.rst +++ b/docs/source/documentation/constants.rst @@ -84,8 +84,6 @@ Text .. code-block:: python - START_X = 30 - START_Y = 20 NORMAL = "NORMAL" ITALIC = "ITALIC" OBLIQUE = "OBLIQUE" diff --git a/manimlib/constants.py b/manimlib/constants.py index 6c82bba2..590a9cda 100644 --- a/manimlib/constants.py +++ b/manimlib/constants.py @@ -64,8 +64,6 @@ JOINT_TYPE_MAP = { } # Related to Text -START_X = 30 -START_Y = 20 NORMAL = "NORMAL" ITALIC = "ITALIC" OBLIQUE = "OBLIQUE" diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index d3617d0b..bf126881 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -18,7 +18,12 @@ from manimlib.utils.tex_file_writing import get_tex_config from manimlib.utils.tex_file_writing import display_during_execution from manimlib.logger import log -ManimColor = Union[str, colour.Color, Sequence[float]] + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from manimlib.mobject.types.vectorized_mobject import VMobject + ManimColor = Union[str, colour.Color, Sequence[float]] SCALE_FACTOR_PER_FONT_POINT = 0.001 @@ -29,7 +34,7 @@ def _get_neighbouring_pairs(iterable: Iterable) -> list: class _TexParser(object): - def __init__(self, tex_string: str, additional_substrings: str): + def __init__(self, tex_string: str, additional_substrings: list[str]): self.tex_string = tex_string self.whitespace_indices = self.get_whitespace_indices() self.backslash_indices = self.get_backslash_indices() @@ -173,7 +178,7 @@ class _TexParser(object): def break_up_by_additional_substrings( self, - additional_substrings: Iterable[str] + additional_substrings: list[str] ) -> None: stripped_substrings = sorted(remove_list_redundancies([ string.strip() @@ -257,7 +262,7 @@ class _TexParser(object): "}" ]) - def get_sorted_submob_indices(self, submob_labels: Iterable[int]) -> list[int]: + def get_sorted_submob_indices(self, submob_labels: list[int]) -> list[int]: def script_span_to_submob_range(script_span): tex_span = self.script_span_to_tex_span_dict[script_span] submob_indices = [ @@ -291,7 +296,7 @@ class _TexParser(object): ] return result - def get_submob_tex_strings(self, submob_labels: Iterable[int]) -> list[str]: + def get_submob_tex_strings(self, submob_labels: list[int]) -> list[str]: ordered_tex_spans = [ self.tex_span_list[label] for label in submob_labels ] @@ -385,7 +390,7 @@ class _TexParser(object): def get_containing_labels_by_tex_spans( self, - tex_spans: Iterable[tuple[int, int]] + tex_spans: list[tuple[int, int]] ) -> list[int]: return remove_list_redundancies(list(it.chain(*[ self.containing_labels_dict[tex_span] @@ -503,8 +508,10 @@ class MTex(_TexSVG): self.color_to_label(labelled_glyph.get_fill_color()) for labelled_glyph in labelled_svg_glyphs ] - mob = self.build_mobject(self, glyph_labels) - self.set_submobjects(mob.submobjects) + rearranged_submobs = self.rearrange_submobjects( + self.submobjects, glyph_labels + ) + self.set_submobjects(rearranged_submobs) @staticmethod def color_to_label(color: ManimColor) -> int: @@ -512,13 +519,13 @@ class MTex(_TexSVG): rg = r * 256 + g return rg * 256 + b - def build_mobject( + def rearrange_submobjects( self, - svg_glyphs: _TexSVG | None, - glyph_labels: Iterable[int] - ) -> VGroup: + svg_glyphs: list[VMobject], + glyph_labels: list[int] + ) -> list[VMobject]: if not svg_glyphs: - return VGroup() + return [] # Simply pack together adjacent mobjects with the same label. submobjects = [] @@ -552,11 +559,11 @@ class MTex(_TexSVG): submob.tex_string = submob_tex # Support `get_tex()` method here. submob.get_tex = MethodType(lambda inst: inst.tex_string, submob) - return VGroup(*rearranged_submobjects) + return rearranged_submobjects def get_part_by_tex_spans( self, - tex_spans: Iterable[tuple[int, int]] + tex_spans: list[tuple[int, int]] ) -> VGroup: labels = self.parser.get_containing_labels_by_tex_spans(tex_spans) return VGroup(*filter( @@ -581,7 +588,7 @@ class MTex(_TexSVG): ) ]) - def get_part_by_tex(self, tex: str, index: int = 0) -> VGroup: + def get_part_by_tex(self, tex: str, index: int = 0) -> VMobject: all_parts = self.get_parts_by_tex(tex) return all_parts[index] @@ -597,7 +604,7 @@ class MTex(_TexSVG): self.set_color_by_tex(tex, color) return self - def indices_of_part(self, part: Iterable[VGroup]) -> list[int]: + def indices_of_part(self, part: Iterable[VMobject]) -> list[int]: indices = [ index for index, submob in enumerate(self.submobjects) if submob in part diff --git a/manimlib/mobject/svg/svg_mobject.py b/manimlib/mobject/svg/svg_mobject.py index 86e77f00..f02d7f5d 100644 --- a/manimlib/mobject/svg/svg_mobject.py +++ b/manimlib/mobject/svg/svg_mobject.py @@ -189,30 +189,6 @@ class SVGMobject(VMobject): mob.shift(vec) return mob - def get_mobject_from(self, shape: se.GraphicObject) -> VMobject | None: - shape_class_to_func_map: dict[ - type, Callable[[se.GraphicObject], VMobject] - ] = { - se.Path: self.path_to_mobject, - se.SimpleLine: self.line_to_mobject, - se.Rect: self.rect_to_mobject, - se.Circle: self.circle_to_mobject, - se.Ellipse: self.ellipse_to_mobject, - se.Polygon: self.polygon_to_mobject, - se.Polyline: self.polyline_to_mobject, - # se.Text: self.text_to_mobject, # TODO - } - for shape_class, func in shape_class_to_func_map.items(): - if isinstance(shape, shape_class): - mob = func(shape) - self.apply_style_to_mobject(mob, shape) - return mob - - shape_class_name = shape.__class__.__name__ - if shape_class_name != "SVGElement": - log.warning(f"Unsupported element type: {shape_class_name}") - return None - @staticmethod def apply_style_to_mobject( mob: VMobject, diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index c6abe3c6..a13d1d80 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -12,7 +12,6 @@ from typing import Iterable, Sequence, Union import pygments import pygments.formatters import pygments.lexers -import pygments.styles from manimpango import MarkupUtils @@ -20,6 +19,7 @@ from manimlib.logger import log from manimlib.constants import * from manimlib.mobject.geometry import Dot from manimlib.mobject.svg.svg_mobject import SVGMobject +from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.utils.customization import get_customization from manimlib.utils.tex_file_writing import tex_hash from manimlib.utils.config_ops import digest_config @@ -30,9 +30,7 @@ from manimlib.utils.directories import get_text_dir from typing import TYPE_CHECKING if TYPE_CHECKING: - import colour from manimlib.mobject.types.vectorized_mobject import VMobject - ManimColor = Union[str, colour.Color, Sequence[float]] TEXT_MOB_SCALE_FACTOR = 0.0076 DEFAULT_LINE_SPACING_SCALE = 0.6 @@ -199,14 +197,14 @@ class _TextParser(object): result.append((text_piece, attr_dict)) return result - def get_markup_str_with_attrs(self): + def get_markup_str_with_attrs(self) -> str: return "".join([ f"{text_piece}" for text_piece, attr_dict in self.get_text_pieces() ]) @staticmethod - def get_attr_dict_str(attr_dict: dict[str, str]): + def get_attr_dict_str(attr_dict: dict[str, str]) -> str: return " ".join([ f"{key}='{value}'" for key, value in attr_dict.items() @@ -215,9 +213,14 @@ class _TextParser(object): # Temporary handler class _Alignment: - VAL_LIST = ["LEFT", "CENTER", "RIGHT"] - def __init__(self, s): - self.value = _Alignment.VAL_LIST.index(s.upper()) + VAL_DICT = { + "LEFT": 0, + "CENTER": 1, + "RIGHT": 2 + } + + def __init__(self, s: str): + self.value = _Alignment.VAL_DICT[s.upper()] class Text(SVGMobject): @@ -251,7 +254,7 @@ class Text(SVGMobject): "local_configs": {}, } - def __init__(self, text, **kwargs): + def __init__(self, text: str, **kwargs): self.full2short(kwargs) digest_config(self, kwargs) validate_error = MarkupUtils.validate(text) @@ -267,7 +270,7 @@ class Text(SVGMobject): self.scale(TEXT_MOB_SCALE_FACTOR) @property - def hash_seed(self): + def hash_seed(self) -> tuple: return ( self.__class__.__name__, self.svg_default, @@ -293,7 +296,7 @@ class Text(SVGMobject): self.local_configs ) - def get_file_path(self): + def get_file_path(self) -> str: full_markup = self.get_full_markup_str() svg_file = os.path.join( get_text_dir(), tex_hash(full_markup) + ".svg" @@ -302,7 +305,7 @@ class Text(SVGMobject): self.markup_to_svg(full_markup, svg_file) return svg_file - def get_full_markup_str(self): + def get_full_markup_str(self) -> str: if self.t2g: log.warning( "Manim currently cannot parse gradient from svg. " @@ -346,7 +349,7 @@ class Text(SVGMobject): return self.parser.get_markup_str_with_attrs() - def find_indexes(self, word_or_text_span): + def find_indexes(self, word_or_text_span: str | tuple[int, int]) -> list[tuple[int, int]]: if isinstance(word_or_text_span, tuple): return [word_or_text_span] @@ -355,7 +358,7 @@ class Text(SVGMobject): for match_obj in re.finditer(re.escape(word_or_text_span), self.text) ] - def markup_to_svg(self, markup_str, file_name): + def markup_to_svg(self, markup_str: str, file_name: str) -> str: # `manimpango` is under construction, # so the following code is intended to suit its interface alignment = _Alignment(self.alignment) @@ -384,7 +387,7 @@ class Text(SVGMobject): pango_width=pango_width ) - def generate_mobject(self): + def generate_mobject(self) -> None: super().generate_mobject() # Remove empty paths @@ -395,15 +398,14 @@ class Text(SVGMobject): content_str = self.parser.remove_tags(self.text) if self.is_markup: content_str = saxutils.unescape(content_str) - for char_index, char in enumerate(content_str): - if not re.match(r"\s", char): - continue + for match_obj in re.finditer(r"\s", content_str): + char_index = match_obj.start() space = Dot(radius=0, fill_opacity=0, stroke_opacity=0) space.move_to(submobjects[max(char_index - 1, 0)].get_center()) submobjects.insert(char_index, space) self.set_submobjects(submobjects) - def full2short(self, config): + def full2short(self, config: dict) -> None: conversion_dict = { "line_spacing_height": "lsh", "text2color": "t2c", @@ -417,7 +419,7 @@ class Text(SVGMobject): if long_name in kwargs: kwargs[short_name] = kwargs.pop(long_name) - def get_parts_by_text(self, word): + def get_parts_by_text(self, word: str) -> VGroup: if self.is_markup: log.warning( "Slicing MarkupText via `get_parts_by_text`, " @@ -433,7 +435,7 @@ class Text(SVGMobject): for i, j in self.find_indexes(word) )) - def get_part_by_text(self, word): + def get_part_by_text(self, word: str) -> VMobject | None: parts = self.get_parts_by_text(word) return parts[0] if parts else None @@ -445,7 +447,6 @@ class MarkupText(Text): } - class Code(MarkupText): CONFIG = { "font": "Consolas", @@ -456,7 +457,7 @@ class Code(MarkupText): "code_style": "monokai", } - def __init__(self, code, **kwargs): + def __init__(self, code: str, **kwargs): digest_config(self, kwargs) self.code = code lexer = pygments.lexers.get_lexer_by_name(self.language) From 4a03d196a6fe98e2073980739e3e0be74986d4eb Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Wed, 23 Mar 2022 13:34:30 +0800 Subject: [PATCH 02/48] Adjust typings --- manimlib/mobject/svg/mtex_mobject.py | 10 ++++------ manimlib/mobject/svg/svg_mobject.py | 2 +- manimlib/mobject/svg/tex_mobject.py | 2 +- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index bf126881..73303eab 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -447,9 +447,7 @@ class MTex(_TexSVG): self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size) @property - def hash_seed( - self - ) -> tuple[str, dict[str], dict[str, bool], str, list[str], str, str, bool]: + def hash_seed(self) -> tuple: return ( self.__class__.__name__, self.svg_default, @@ -462,9 +460,9 @@ class MTex(_TexSVG): ) def get_file_path(self) -> str: - return self._get_file_path(self.use_plain_tex) + return self.get_file_path_(use_plain_tex=self.use_plain_tex) - def _get_file_path(self, use_plain_tex: bool) -> str: + def get_file_path_(self, use_plain_tex: bool) -> str: if use_plain_tex: tex_string = self.tex_string else: @@ -501,7 +499,7 @@ class MTex(_TexSVG): if not self.use_plain_tex: labelled_svg_glyphs = self else: - file_path = self._get_file_path(use_plain_tex=False) + file_path = self.get_file_path_(use_plain_tex=False) labelled_svg_glyphs = _TexSVG(file_path) glyph_labels = [ diff --git a/manimlib/mobject/svg/svg_mobject.py b/manimlib/mobject/svg/svg_mobject.py index f02d7f5d..1ba0923b 100644 --- a/manimlib/mobject/svg/svg_mobject.py +++ b/manimlib/mobject/svg/svg_mobject.py @@ -76,7 +76,7 @@ class SVGMobject(VMobject): SVG_HASH_TO_MOB_MAP[hash_val] = self.copy() @property - def hash_seed(self) -> tuple[str, dict[str], dict[str, bool], str]: + def hash_seed(self) -> tuple: # Returns data which can uniquely represent the result of `init_points`. # The hashed value of it is stored as a key in `SVG_HASH_TO_MOB_MAP`. return ( diff --git a/manimlib/mobject/svg/tex_mobject.py b/manimlib/mobject/svg/tex_mobject.py index a7627783..717f1c24 100644 --- a/manimlib/mobject/svg/tex_mobject.py +++ b/manimlib/mobject/svg/tex_mobject.py @@ -50,7 +50,7 @@ class SingleStringTex(SVGMobject): self.organize_submobjects_left_to_right() @property - def hash_seed(self) -> tuple[str, dict[str], dict[str, bool], str, str, bool]: + def hash_seed(self) -> tuple: return ( self.__class__.__name__, self.svg_default, From 9ac1805e7e3217da36f74ae87e1579660640f00f Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Sat, 26 Mar 2022 20:52:28 +0800 Subject: [PATCH 03/48] Refactor MTex --- manimlib/mobject/svg/mtex_mobject.py | 931 ++++++++++++++------------- manimlib/mobject/svg/svg_mobject.py | 2 + manimlib/utils/iterables.py | 29 - 3 files changed, 489 insertions(+), 473 deletions(-) diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 73303eab..53dc15da 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -6,7 +6,7 @@ import itertools as it from types import MethodType from typing import Iterable, Union, Sequence -from manimlib.constants import WHITE +from manimlib.constants import BLACK, WHITE from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.utils.color import color_to_int_rgb @@ -29,227 +29,159 @@ if TYPE_CHECKING: SCALE_FACTOR_PER_FONT_POINT = 0.001 -def _get_neighbouring_pairs(iterable: Iterable) -> list: - return list(adjacent_pairs(iterable))[:-1] +class _TexSVG(SVGMobject): + CONFIG = { + "height": None, + "svg_default": { + "fill_color": WHITE, + }, + "stroke_width": 0, + "stroke_color": WHITE, + "path_string_config": { + "should_subdivide_sharp_curves": True, + "should_remove_null_curves": True, + }, + } -class _TexParser(object): - def __init__(self, tex_string: str, additional_substrings: list[str]): +class MTex(_TexSVG): + CONFIG = { + "font_size": 48, + "alignment": "\\centering", + "tex_environment": "align*", + "isolate": [], + "tex_to_color_map": {}, + "use_plain_tex": False, + } + + def __init__(self, tex_string: str, **kwargs): + digest_config(self, kwargs) + tex_string = tex_string.strip() + # Prevent from passing an empty string. + if not tex_string: + tex_string = "\\quad" self.tex_string = tex_string - self.whitespace_indices = self.get_whitespace_indices() - self.backslash_indices = self.get_backslash_indices() - self.script_indices = self.get_script_indices() - self.brace_indices_dict = self.get_brace_indices_dict() - self.tex_span_list: list[tuple[int, int]] = [] - self.script_span_to_char_dict: dict[tuple[int, int], str] = {} - self.script_span_to_tex_span_dict: dict[ - tuple[int, int], tuple[int, int] - ] = {} - self.neighbouring_script_span_pairs: list[tuple[int, int]] = [] - self.specified_substrings: list[str] = [] - self.add_tex_span((0, len(tex_string))) - self.break_up_by_scripts() - self.break_up_by_double_braces() - self.break_up_by_additional_substrings(additional_substrings) - self.tex_span_list.sort(key=lambda t: (t[0], -t[1])) - self.specified_substrings = remove_list_redundancies( - self.specified_substrings + super().__init__(**kwargs) + + self.set_color_by_tex_to_color_map(self.tex_to_color_map) + self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size) + + @property + def hash_seed(self) -> tuple: + return ( + self.__class__.__name__, + self.svg_default, + self.path_string_config, + self.tex_string, + self.alignment, + self.tex_environment, + self.isolate, + self.tex_to_color_map, + self.use_plain_tex ) - self.containing_labels_dict = self.get_containing_labels_dict() - def add_tex_span(self, tex_span: tuple[int, int]) -> None: - if tex_span not in self.tex_span_list: - self.tex_span_list.append(tex_span) - - def get_whitespace_indices(self) -> list[int]: - return [ - match_obj.start() - for match_obj in re.finditer(r"\s", self.tex_string) - ] - - def get_backslash_indices(self) -> list[int]: - # Newlines (`\\`) don't count. - return [ - match_obj.end() - 1 - for match_obj in re.finditer(r"\\+", self.tex_string) - if len(match_obj.group()) % 2 == 1 - ] - - def filter_out_escaped_characters(self, indices) -> list[int]: - return list(filter( - lambda index: index - 1 not in self.backslash_indices, - indices - )) - - def get_script_indices(self) -> list[int]: - return self.filter_out_escaped_characters([ - match_obj.start() - for match_obj in re.finditer(r"[_^]", self.tex_string) + def get_file_path(self) -> str: + self.init_parser() + self.base_color = self.svg_default["color"] \ + or self.svg_default["fill_color"] or WHITE + self.use_plain_file = any([ + self.use_plain_tex, + self.color_cmd_repl_items, + self.base_color not in (BLACK, WHITE) ]) + return self.get_file_path_(use_plain_file=self.use_plain_file) - def get_brace_indices_dict(self) -> dict[int, int]: - tex_string = self.tex_string - indices = self.filter_out_escaped_characters([ - match_obj.start() - for match_obj in re.finditer(r"[{}]", tex_string) - ]) - result = {} - left_brace_indices_stack = [] - for index in indices: - if tex_string[index] == "{": - left_brace_indices_stack.append(index) - else: - left_brace_index = left_brace_indices_stack.pop() - result[left_brace_index] = index - return result + def get_file_path_(self, use_plain_file: bool) -> str: + if use_plain_file: + tex_string = "".join([ + "{{", + self.get_color_command(int(self.base_color[1:], 16)), + self.tex_string, + "}}" + ]) + else: + tex_string = self.labelled_tex_string - def break_up_by_scripts(self) -> None: - # Match subscripts & superscripts. - tex_string = self.tex_string - whitespace_indices = self.whitespace_indices - brace_indices_dict = self.brace_indices_dict - script_spans = [] - for script_index in self.script_indices: - script_char = tex_string[script_index] - extended_begin = script_index - while extended_begin - 1 in whitespace_indices: - extended_begin -= 1 - script_begin = script_index + 1 - while script_begin in whitespace_indices: - script_begin += 1 - if script_begin in brace_indices_dict.keys(): - script_end = brace_indices_dict[script_begin] + 1 - else: - pattern = re.compile(r"[a-zA-Z0-9]|\\[a-zA-Z]+") - match_obj = pattern.match(tex_string, pos=script_begin) - if not match_obj: - script_name = { - "_": "subscript", - "^": "superscript" - }[script_char] - log.warning( - f"Unclear {script_name} detected while parsing. " - "Please use braces to clarify" - ) - continue - script_end = match_obj.end() - tex_span = (script_begin, script_end) - script_span = (extended_begin, script_end) - script_spans.append(script_span) - self.add_tex_span(tex_span) - self.script_span_to_char_dict[script_span] = script_char - self.script_span_to_tex_span_dict[script_span] = tex_span + full_tex = self.get_tex_file_body(tex_string) + with display_during_execution(f"Writing \"{self.tex_string}\""): + file_path = self.tex_to_svg_file_path(full_tex) + return file_path - if not script_spans: + def get_tex_file_body(self, tex_string: str) -> str: + if self.tex_environment: + tex_string = "\n".join([ + f"\\begin{{{self.tex_environment}}}", + tex_string, + f"\\end{{{self.tex_environment}}}" + ]) + if self.alignment: + tex_string = "\n".join([self.alignment, tex_string]) + + tex_config = get_tex_config() + return tex_config["tex_body"].replace( + tex_config["text_to_replace"], + tex_string + ) + + @staticmethod + def tex_to_svg_file_path(tex_file_content: str) -> str: + return tex_to_svg_file(tex_file_content) + + def generate_mobject(self) -> None: + super().generate_mobject() + + glyphs = self.submobjects + if not glyphs: return - _, sorted_script_spans = zip(*sorted([ - (index, script_span) - for script_span in script_spans - for index in script_span - ])) - for span_0, span_1 in _get_neighbouring_pairs(sorted_script_spans): - if span_0[1] == span_1[0]: - self.neighbouring_script_span_pairs.append((span_0, span_1)) + if self.use_plain_file: + file_path = self.get_file_path_(use_plain_file=False) + labelled_svg_glyphs = _TexSVG( + file_path, svg_default={"fill_color": BLACK} + ) + predefined_colors = [ + labelled_glyph.get_fill_color() + for labelled_glyph in self.submobjects + ] + else: + labelled_svg_glyphs = self + predefined_colors = [self.base_color] * len(glyphs) - def break_up_by_double_braces(self) -> None: - # Match paired double braces (`{{...}}`). - tex_string = self.tex_string - reversed_indices_dict = dict( - item[::-1] for item in self.brace_indices_dict.items() - ) - skip = False - for prev_right_index, right_index in _get_neighbouring_pairs( - list(reversed_indices_dict.keys()) - ): - if skip: - skip = False - continue - if right_index != prev_right_index + 1: - continue - left_index = reversed_indices_dict[right_index] - prev_left_index = reversed_indices_dict[prev_right_index] - if left_index != prev_left_index - 1: - continue - tex_span = (left_index, right_index + 1) - self.add_tex_span(tex_span) - self.specified_substrings.append(tex_string[slice(*tex_span)]) - skip = True - - def break_up_by_additional_substrings( - self, - additional_substrings: list[str] - ) -> None: - stripped_substrings = sorted(remove_list_redundancies([ - string.strip() - for string in additional_substrings - ])) - if "" in stripped_substrings: - stripped_substrings.remove("") - - tex_string = self.tex_string - all_tex_spans = [] - for string in stripped_substrings: - match_objs = list(re.finditer(re.escape(string), tex_string)) - if not match_objs: - continue - self.specified_substrings.append(string) - for match_obj in match_objs: - all_tex_spans.append(match_obj.span()) - - former_script_spans_dict = dict([ - script_span_pair[0][::-1] - for script_span_pair in self.neighbouring_script_span_pairs - ]) - for span_begin, span_end in all_tex_spans: - # Deconstruct spans containing one out of two scripts. - if span_end in former_script_spans_dict.keys(): - span_end = former_script_spans_dict[span_end] - if span_begin >= span_end: - continue - self.add_tex_span((span_begin, span_end)) - - def get_containing_labels_dict(self) -> dict[tuple[int, int], list[int]]: - tex_span_list = self.tex_span_list - result = { - tex_span: [] - for tex_span in tex_span_list - } - overlapping_tex_span_pairs = [] - for index_0, span_0 in enumerate(tex_span_list): - for index_1, span_1 in enumerate(tex_span_list[index_0:]): - if span_0[1] <= span_1[0]: - continue - if span_0[1] < span_1[1]: - overlapping_tex_span_pairs.append((span_0, span_1)) - result[span_0].append(index_0 + index_1) - if overlapping_tex_span_pairs: - tex_string = self.tex_string - log.error("Partially overlapping substrings detected:") - for tex_span_pair in overlapping_tex_span_pairs: - log.error(", ".join( - f"\"{tex_string[slice(*tex_span)]}\"" - for tex_span in tex_span_pair - )) - raise ValueError - return result - - def get_labelled_tex_string(self) -> str: - indices, _, flags, labels = zip(*sorted([ - (*tex_span[::(1, -1)[flag]], flag, label) - for label, tex_span in enumerate(self.tex_span_list) - for flag in range(2) - ], key=lambda t: (t[0], -t[2], -t[1]))) - command_pieces = [ - ("{{" + self.get_color_command(label), "}}")[flag] - for flag, label in zip(flags, labels) - ][1:-1] - command_pieces.insert(0, "") - string_pieces = [ - self.tex_string[slice(*tex_span)] - for tex_span in _get_neighbouring_pairs(indices) + glyph_labels = [ + self.color_to_label(labelled_glyph.get_fill_color()) + for labelled_glyph in labelled_svg_glyphs ] - return "".join(it.chain(*zip(command_pieces, string_pieces))) + for glyph, glyph_color in zip(glyphs, predefined_colors): + glyph.set_fill(glyph_color) + + # Simply pack together adjacent mobjects with the same label. + submob_labels, glyphs_lists = self.group_neighbours( + glyph_labels, glyphs + ) + submobjects = [ + VGroup(*glyph_list) + for glyph_list in glyphs_lists + ] + submob_tex_strings = self.get_submob_tex_strings(submob_labels) + for submob, label, submob_tex in zip( + submobjects, submob_labels, submob_tex_strings + ): + submob.submob_label = label + submob.tex_string = submob_tex + # Support `get_tex()` method here. + submob.get_tex = MethodType(lambda inst: inst.tex_string, submob) + self.set_submobjects(submobjects) + + ## Static methods + + @staticmethod + def color_to_label(color: ManimColor) -> int: + r, g, b = color_to_int_rgb(color) + rg = r * 256 + g + rgb = rg * 256 + b + if rgb == 16777215: # white + return 0 + return rgb @staticmethod def get_color_command(label: int) -> str: @@ -262,43 +194,335 @@ class _TexParser(object): "}" ]) - def get_sorted_submob_indices(self, submob_labels: list[int]) -> list[int]: - def script_span_to_submob_range(script_span): - tex_span = self.script_span_to_tex_span_dict[script_span] - submob_indices = [ - index for index, label in enumerate(submob_labels) - if label in self.containing_labels_dict[tex_span] - ] - return range(submob_indices[0], submob_indices[-1] + 1) + @staticmethod + def get_neighbouring_pairs(iterable: Iterable) -> list: + return list(adjacent_pairs(iterable))[:-1] - filtered_script_span_pairs = filter( - lambda script_span_pair: all([ - self.script_span_to_char_dict[script_span] == character - for script_span, character in zip(script_span_pair, "_^") - ]), - self.neighbouring_script_span_pairs - ) - switch_range_pairs = sorted([ - tuple([ - script_span_to_submob_range(script_span) - for script_span in script_span_pair - ]) - for script_span_pair in filtered_script_span_pairs - ], key=lambda t: (t[0].stop, -t[0].start)) - result = list(range(len(submob_labels))) - for range_0, range_1 in switch_range_pairs: - result = [ - *result[:range_1.start], - *result[range_0.start:range_0.stop], - *result[range_1.stop:range_0.start], - *result[range_1.start:range_1.stop], - *result[range_0.stop:] + @staticmethod + def group_neighbours( + labels: Iterable[object], + objs: Iterable[object] + ) -> tuple[list[object], list[list[object]]]: + # Pack together neighbouring objects sharing the same label. + if not labels: + return [], [] + + group_labels = [] + groups = [] + new_group = [] + current_label = labels[0] + for label, obj in zip(labels, objs): + if label == current_label: + new_group.append(obj) + else: + group_labels.append(current_label) + groups.append(new_group) + new_group = [obj] + current_label = label + group_labels.append(current_label) + groups.append(new_group) + return group_labels, groups + + ## Parser + + def init_parser(self) -> None: + self.additional_substrings = self.get_additional_substrings() + self.backslash_indices = self.get_backslash_indices() + self.left_brace_indices, self.right_brace_indices = \ + self.get_left_and_right_indices() + self.script_char_spans = self.get_script_char_spans() + self.skipped_indices = self.get_skipped_indices() + self.script_spans = self.get_script_spans() + self.script_content_spans = self.get_script_content_spans() + self.double_braces_spans = self.get_double_braces_spans() + self.stripped_substrings = self.get_stripped_substrings() + self.specified_spans = self.get_specified_spans() + self.specified_substrings = self.get_specified_substrings() + self.tex_span_list = self.get_tex_span_list() + self.extended_tex_span_list = self.get_extended_tex_span_list() + self.isolated_substrings = self.get_isolated_substrings() + self.containing_labels_dict = self.get_containing_labels_dict() + self.color_cmd_repl_items = self.get_color_cmd_repl_items() + self.span_repl_dict = self.get_span_repl_dict() + self.labelled_tex_string = self.get_labelled_tex_string() + + def get_additional_substrings(self) -> list[str]: + return list(it.chain( + self.tex_to_color_map.keys(), + self.isolate + )) + + def get_backslash_indices(self) -> list[int]: + # Newlines (`\\`) don't count. + return [ + match_obj.end() - 1 + for match_obj in re.finditer(r"\\+", self.tex_string) + if len(match_obj.group()) % 2 == 1 + ] + + def get_left_and_right_indices(self) -> list[tuple[int, int]]: + tex_string = self.tex_string + indices = list(filter( + lambda index: index - 1 not in self.backslash_indices, + [ + match_obj.start() + for match_obj in re.finditer(r"[{}]", tex_string) ] + )) + left_brace_indices = [] + right_brace_indices = [] + left_brace_indices_stack = [] + for index in indices: + if tex_string[index] == "{": + left_brace_indices_stack.append(index) + else: + if not left_brace_indices_stack: + raise ValueError("Missing '{' inserted") + left_brace_index = left_brace_indices_stack.pop() + left_brace_indices.append(left_brace_index) + right_brace_indices.append(index) + if left_brace_indices_stack: + raise ValueError("Missing '}' inserted") + return left_brace_indices, right_brace_indices + + def get_script_char_spans(self) -> list[tuple[int, int]]: + return [ + match_obj.span() + for match_obj in re.finditer(r"(\s*)[_^]\s*", self.tex_string) + if match_obj.group(1) + or match_obj.start() - 1 not in self.backslash_indices + ] + + def get_skipped_indices(self) -> list[int]: + return sorted(remove_list_redundancies([ + match_obj.start() + for match_obj in re.finditer(r"\s", self.tex_string) + ] + list(it.chain(*[ + range(*script_char_span) + for script_char_span in self.script_char_spans + ])))) + + def get_script_spans(self) -> list[tuple[int, int]]: + tex_string = self.tex_string + result = [] + brace_indices_dict = dict(zip( + self.left_brace_indices, self.right_brace_indices + )) + for char_begin, span_begin in self.script_char_spans: + if span_begin in brace_indices_dict.keys(): + span_end = brace_indices_dict[span_begin] + 1 + else: + pattern = re.compile(r"[a-zA-Z0-9]|\\[a-zA-Z]+") + match_obj = pattern.match(tex_string, pos=span_begin) + if not match_obj: + script_name = { + "_": "subscript", + "^": "superscript" + }[script_char] + log.warning( + f"Unclear {script_name} detected while parsing. " + "Please use braces to clarify" + ) + continue + span_end = match_obj.end() + result.append((char_begin, span_end)) return result + def get_script_content_spans(self) -> list[tuple[int, int]]: + return [ + (script_char_span[1], script_span[1]) + for script_char_span, script_span in zip( + self.script_char_spans, self.script_spans + ) + ] + + def get_double_braces_spans(self) -> list[tuple[int, int]]: + # Match paired double braces (`{{...}}`). + result = [] + reversed_brace_indices_dict = dict(zip( + self.right_brace_indices, self.left_brace_indices + )) + skip = False + for prev_right_index, right_index in self.get_neighbouring_pairs( + sorted(reversed_brace_indices_dict.keys()) + ): + if skip: + skip = False + continue + if right_index != prev_right_index + 1: + continue + left_index = reversed_brace_indices_dict[right_index] + prev_left_index = reversed_brace_indices_dict[prev_right_index] + if left_index != prev_left_index - 1: + continue + result.append((left_index, right_index + 1)) + skip = True + return result + + def get_stripped_substrings(self) -> list[str]: + result = remove_list_redundancies([ + string.strip() + for string in self.additional_substrings + ]) + if "" in result: + result.remove("") + return result + + def get_specified_spans(self) -> list[tuple[int, int]]: + result = self.double_braces_spans.copy() + tex_string = self.tex_string + reversed_script_spans_dict = dict([ + script_span[::-1] for script_span in self.script_spans + ]) + for string in self.stripped_substrings: + for match_obj in re.finditer(re.escape(string), tex_string): + span_begin, span_end = match_obj.span() + while span_end in reversed_script_spans_dict.keys(): + span_end = reversed_script_spans_dict[span_end] + if span_begin >= span_end: + continue + result.append((span_begin, span_end)) + return list(filter( + lambda tex_span: tex_span not in self.script_content_spans, + remove_list_redundancies(result) + )) + + def get_specified_substrings(self) -> list[str]: + return remove_list_redundancies([ + self.tex_string[slice(*double_braces_span)] + for double_braces_span in self.double_braces_spans + ] + list(filter( + lambda s: s in self.tex_string, + self.additional_substrings + ))) + + def get_tex_span_list(self) -> list[tuple[int, int]]: + return [ + (0, len(self.tex_string)), + *self.script_content_spans, + *self.specified_spans + ] + + def get_extended_tex_span_list(self) -> list[tuple[int, int]]: + extended_specified_spans = [] + script_spans_dict = dict(self.script_spans) + for span_begin, span_end in self.specified_spans: + while span_end in script_spans_dict.keys(): + span_end = script_spans_dict[span_end] + extended_specified_spans.append((span_begin, span_end)) + return [ + (0, len(self.tex_string)), + *self.script_content_spans, + *extended_specified_spans + ] + + def get_isolated_substrings(self) -> list[str]: + return remove_list_redundancies([ + self.tex_string[slice(*tex_span)] + for tex_span in self.tex_span_list + ]) + + def get_containing_labels_dict(self) -> dict[tuple[int, int], list[int]]: + tex_span_list = self.tex_span_list + result = { + tex_span: [] + for tex_span in tex_span_list + } + for span_0 in tex_span_list: + for span_index, span_1 in enumerate(tex_span_list): + if span_0[0] <= span_1[0] and span_1[1] <= span_0[1]: + result[span_0].append(span_index) + elif span_0[0] < span_1[0] < span_0[1] < span_1[1]: + string_0, string_1 = [ + self.tex_string[slice(*tex_span)] + for tex_span in [span_0, span_1] + ] + raise ValueError( + "Partially overlapping substrings detected: " + f"'{string_0}' and '{string_1}'" + ) + return result + + def get_color_cmd_repl_items(self) -> list[tuple[tuple[int, int], str]]: + color_related_commands_dict = { + "color": 1, + "textcolor": 1, + "pagecolor": 1, + "colorbox": 1, + "fcolorbox": 2, + } + result = [] + tex_string = self.tex_string + backslash_indices = self.backslash_indices + left_indices = self.left_brace_indices + brace_indices_dict = dict(zip( + self.left_brace_indices, self.right_brace_indices + )) + for cmd_name, n_braces in color_related_commands_dict.items(): + pattern = cmd_name + r"(?![a-zA-Z])" + for match_obj in re.finditer(pattern, tex_string): + span_begin, span_end = match_obj.span() + if span_begin - 1 not in backslash_indices: + continue + repl_str = cmd_name + n_braces * "{black}" + for _ in range(n_braces): + left_index = min(filter( + lambda index: index >= span_end, left_indices + )) + span_end = brace_indices_dict[left_index] + 1 + result.append(((span_begin, span_end), repl_str)) + return result + + def get_span_repl_dict(self) -> dict[tuple[int, int], str]: + indices, _, _, cmd_strings = zip(*sorted([ + ( + tex_span[flag], + -flag, + -tex_span[1 - flag], + ("{{" + self.get_color_command(label), "}}")[flag] + ) + for label, tex_span in enumerate(self.extended_tex_span_list) + for flag in range(2) + ])) + result = { + (index, index): "".join(cmd_strs) + for index, cmd_strs in zip(*self.group_neighbours( + indices, cmd_strings + )) + } + result.update(self.color_cmd_repl_items) + return result + + def get_labelled_tex_string(self) -> str: + if not self.span_repl_dict: + return self.tex_string + + spans = sorted(self.span_repl_dict.keys()) + if not all( + span_0[1] <= span_1[0] + for span_0, span_1 in self.get_neighbouring_pairs(spans) + ): + raise ValueError("Failed to generate the labelled string") + + span_ends, span_begins = zip(*spans) + string_pieces = [ + self.tex_string[slice(*span)] + for span in zip( + (0, *span_begins), + (*span_ends, len(self.tex_string)) + ) + ] + repl_strs = [ + self.span_repl_dict[span] + for span in spans + ] + repl_strs.append("") + return "".join(it.chain(*zip(string_pieces, repl_strs))) + def get_submob_tex_strings(self, submob_labels: list[int]) -> list[str]: ordered_tex_spans = [ - self.tex_span_list[label] for label in submob_labels + self.tex_span_list[label] + for label in submob_labels ] ordered_containing_labels = [ self.containing_labels_dict[tex_span] @@ -323,29 +547,28 @@ class _TexParser(object): string_span_ends.append(ordered_span_ends[-1]) tex_string = self.tex_string - left_brace_indices = sorted(self.brace_indices_dict.keys()) - right_brace_indices = sorted(self.brace_indices_dict.values()) - ignored_indices = sorted(it.chain( - self.whitespace_indices, - left_brace_indices, - right_brace_indices, - self.script_indices + left_indices = self.left_brace_indices + right_indices = self.right_brace_indices + skipped_indices = sorted(it.chain( + self.skipped_indices, + left_indices, + right_indices )) result = [] for span_begin, span_end in zip(string_span_begins, string_span_ends): - while span_begin in ignored_indices: + while span_begin in skipped_indices: span_begin += 1 if span_begin >= span_end: result.append("") continue - while span_end - 1 in ignored_indices: + while span_end - 1 in skipped_indices: span_end -= 1 unclosed_left_brace = 0 unclosed_right_brace = 0 for index in range(span_begin, span_end): - if index in left_brace_indices: + if index in left_indices: unclosed_left_brace += 1 - elif index in right_brace_indices: + elif index in right_indices: if unclosed_left_brace == 0: unclosed_right_brace += 1 else: @@ -357,14 +580,13 @@ class _TexParser(object): ])) return result + ## Selector + def find_span_components_of_custom_span( self, custom_span: tuple[int, int] ) -> list[tuple[int, int]] | None: - skipped_indices = sorted(it.chain( - self.whitespace_indices, - self.script_indices - )) + skipped_indices = self.skipped_indices tex_span_choices = sorted(filter( lambda tex_span: all([ tex_span[0] >= custom_span[0], @@ -388,195 +610,22 @@ class _TexParser(object): span_begin = next_begin return result - def get_containing_labels_by_tex_spans( - self, - tex_spans: list[tuple[int, int]] - ) -> list[int]: - return remove_list_redundancies(list(it.chain(*[ - self.containing_labels_dict[tex_span] - for tex_span in tex_spans - ]))) - - def get_specified_substrings(self) -> list[str]: - return self.specified_substrings - - def get_isolated_substrings(self) -> list[str]: - return remove_list_redundancies([ - self.tex_string[slice(*tex_span)] - for tex_span in self.tex_span_list - ]) - - -class _TexSVG(SVGMobject): - CONFIG = { - "height": None, - "fill_opacity": 1.0, - "stroke_width": 0, - "path_string_config": { - "should_subdivide_sharp_curves": True, - "should_remove_null_curves": True, - }, - } - - -class MTex(_TexSVG): - CONFIG = { - "color": WHITE, - "font_size": 48, - "alignment": "\\centering", - "tex_environment": "align*", - "isolate": [], - "tex_to_color_map": {}, - "use_plain_tex": False, - } - - def __init__(self, tex_string: str, **kwargs): - digest_config(self, kwargs) - tex_string = tex_string.strip() - # Prevent from passing an empty string. - if not tex_string: - tex_string = "\\quad" - self.tex_string = tex_string - self.parser = _TexParser( - self.tex_string, - [*self.tex_to_color_map.keys(), *self.isolate] - ) - super().__init__(**kwargs) - - self.set_color_by_tex_to_color_map(self.tex_to_color_map) - self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size) - - @property - def hash_seed(self) -> tuple: - return ( - self.__class__.__name__, - self.svg_default, - self.path_string_config, - self.tex_string, - self.parser.specified_substrings, - self.alignment, - self.tex_environment, - self.use_plain_tex - ) - - def get_file_path(self) -> str: - return self.get_file_path_(use_plain_tex=self.use_plain_tex) - - def get_file_path_(self, use_plain_tex: bool) -> str: - if use_plain_tex: - tex_string = self.tex_string - else: - tex_string = self.parser.get_labelled_tex_string() - - full_tex = self.get_tex_file_body(tex_string) - with display_during_execution(f"Writing \"{self.tex_string}\""): - file_path = self.tex_to_svg_file_path(full_tex) - return file_path - - def get_tex_file_body(self, tex_string: str) -> str: - if self.tex_environment: - tex_string = "\n".join([ - f"\\begin{{{self.tex_environment}}}", - tex_string, - f"\\end{{{self.tex_environment}}}" - ]) - if self.alignment: - tex_string = "\n".join([self.alignment, tex_string]) - - tex_config = get_tex_config() - return tex_config["tex_body"].replace( - tex_config["text_to_replace"], - tex_string - ) - - @staticmethod - def tex_to_svg_file_path(tex_file_content: str) -> str: - return tex_to_svg_file(tex_file_content) - - def generate_mobject(self) -> None: - super().generate_mobject() - - if not self.use_plain_tex: - labelled_svg_glyphs = self - else: - file_path = self.get_file_path_(use_plain_tex=False) - labelled_svg_glyphs = _TexSVG(file_path) - - glyph_labels = [ - self.color_to_label(labelled_glyph.get_fill_color()) - for labelled_glyph in labelled_svg_glyphs - ] - rearranged_submobs = self.rearrange_submobjects( - self.submobjects, glyph_labels - ) - self.set_submobjects(rearranged_submobs) - - @staticmethod - def color_to_label(color: ManimColor) -> int: - r, g, b = color_to_int_rgb(color) - rg = r * 256 + g - return rg * 256 + b - - def rearrange_submobjects( - self, - svg_glyphs: list[VMobject], - glyph_labels: list[int] - ) -> list[VMobject]: - if not svg_glyphs: - return [] - - # Simply pack together adjacent mobjects with the same label. - submobjects = [] - submob_labels = [] - new_glyphs = [] - current_glyph_label = glyph_labels[0] - for glyph, label in zip(svg_glyphs, glyph_labels): - if label == current_glyph_label: - new_glyphs.append(glyph) - else: - submobject = VGroup(*new_glyphs) - submob_labels.append(current_glyph_label) - submobjects.append(submobject) - new_glyphs = [glyph] - current_glyph_label = label - submobject = VGroup(*new_glyphs) - submob_labels.append(current_glyph_label) - submobjects.append(submobject) - - indices = self.parser.get_sorted_submob_indices(submob_labels) - rearranged_submobjects = [submobjects[index] for index in indices] - rearranged_labels = [submob_labels[index] for index in indices] - - submob_tex_strings = self.parser.get_submob_tex_strings( - rearranged_labels - ) - for submob, label, submob_tex in zip( - rearranged_submobjects, rearranged_labels, submob_tex_strings - ): - submob.submob_label = label - submob.tex_string = submob_tex - # Support `get_tex()` method here. - submob.get_tex = MethodType(lambda inst: inst.tex_string, submob) - return rearranged_submobjects - - def get_part_by_tex_spans( - self, - tex_spans: list[tuple[int, int]] - ) -> VGroup: - labels = self.parser.get_containing_labels_by_tex_spans(tex_spans) - return VGroup(*filter( - lambda submob: submob.submob_label in labels, - self.submobjects - )) - def get_part_by_custom_span(self, custom_span: tuple[int, int]) -> VGroup: - tex_spans = self.parser.find_span_components_of_custom_span( + tex_spans = self.find_span_components_of_custom_span( custom_span ) if tex_spans is None: tex = self.tex_string[slice(*custom_span)] raise ValueError(f"Failed to match mobjects from tex: \"{tex}\"") - return self.get_part_by_tex_spans(tex_spans) + + labels = set(it.chain(*[ + self.containing_labels_dict[tex_span] + for tex_span in tex_spans + ])) + return VGroup(*filter( + lambda submob: submob.submob_label in labels, + self.submobjects + )) def get_parts_by_tex(self, tex: str) -> VGroup: return VGroup(*[ @@ -624,12 +673,6 @@ class MTex(_TexSVG): for submob in self.submobjects ] - def get_specified_substrings(self) -> list[str]: - return self.parser.get_specified_substrings() - - def get_isolated_substrings(self) -> list[str]: - return self.parser.get_isolated_substrings() - class MTexText(MTex): CONFIG = { diff --git a/manimlib/mobject/svg/svg_mobject.py b/manimlib/mobject/svg/svg_mobject.py index 1ba0923b..bc625c83 100644 --- a/manimlib/mobject/svg/svg_mobject.py +++ b/manimlib/mobject/svg/svg_mobject.py @@ -172,6 +172,8 @@ class SVGMobject(VMobject): else: log.warning(f"Unsupported element type: {type(shape)}") continue + if not mob.has_points(): + continue self.apply_style_to_mobject(mob, shape) if isinstance(shape, se.Transformable) and shape.apply: self.handle_transform(mob, shape.transform) diff --git a/manimlib/utils/iterables.py b/manimlib/utils/iterables.py index 6729d359..99788a42 100644 --- a/manimlib/utils/iterables.py +++ b/manimlib/utils/iterables.py @@ -1,6 +1,5 @@ from __future__ import annotations -import itertools as it from typing import Callable, Iterable, Sequence, TypeVar import numpy as np @@ -36,10 +35,6 @@ def list_difference_update(l1: Iterable[T], l2: Iterable[T]) -> list[T]: return [e for e in l1 if e not in l2] -def all_elements_are_instances(iterable: Iterable, Class: type) -> bool: - return all([isinstance(e, Class) for e in iterable]) - - def adjacent_n_tuples(objects: Iterable[T], n: int) -> zip[tuple[T, T]]: return zip(*[ [*objects[k:], *objects[:k]] @@ -133,30 +128,6 @@ def make_even( ) -def make_even_by_cycling( - iterable_1: Iterable[T], - iterable_2: Iterable[S] -) -> tuple[list[T], list[S]]: - length = max(len(iterable_1), len(iterable_2)) - cycle1 = it.cycle(iterable_1) - cycle2 = it.cycle(iterable_2) - return ( - [next(cycle1) for x in range(length)], - [next(cycle2) for x in range(length)] - ) - - -def remove_nones(sequence: Iterable) -> list: - return [x for x in sequence if x] - - -# Note this is redundant with it.chain - - -def concatenate_lists(*list_of_lists): - return [item for l in list_of_lists for item in l] - - def hash_obj(obj: object) -> int: if isinstance(obj, dict): new_obj = {k: hash_obj(v) for k, v in obj.items()} From e44a2fc8c6812bbd8d50fe1754634cd69ad919a0 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Sun, 27 Mar 2022 00:29:22 +0800 Subject: [PATCH 04/48] Refactor MTex --- manimlib/mobject/svg/mtex_mobject.py | 231 ++++++++++++++------------- 1 file changed, 120 insertions(+), 111 deletions(-) diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 53dc15da..1a71783d 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -6,7 +6,7 @@ import itertools as it from types import MethodType from typing import Iterable, Union, Sequence -from manimlib.constants import BLACK, WHITE +from manimlib.constants import WHITE from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.utils.color import color_to_int_rgb @@ -32,9 +32,6 @@ SCALE_FACTOR_PER_FONT_POINT = 0.001 class _TexSVG(SVGMobject): CONFIG = { "height": None, - "svg_default": { - "fill_color": WHITE, - }, "stroke_width": 0, "stroke_color": WHITE, "path_string_config": { @@ -46,6 +43,7 @@ class _TexSVG(SVGMobject): class MTex(_TexSVG): CONFIG = { + "base_color": WHITE, "font_size": 48, "alignment": "\\centering", "tex_environment": "align*", @@ -54,13 +52,14 @@ class MTex(_TexSVG): "use_plain_tex": False, } - def __init__(self, tex_string: str, **kwargs): + def __init__(self, string: str, **kwargs): digest_config(self, kwargs) - tex_string = tex_string.strip() + string = string.strip() # Prevent from passing an empty string. - if not tex_string: - tex_string = "\\quad" - self.tex_string = tex_string + if not string: + string = "\\quad" + self.tex_string = string + self.string = string super().__init__(**kwargs) self.set_color_by_tex_to_color_map(self.tex_to_color_map) @@ -72,7 +71,8 @@ class MTex(_TexSVG): self.__class__.__name__, self.svg_default, self.path_string_config, - self.tex_string, + self.string, + self.base_color, self.alignment, self.tex_environment, self.isolate, @@ -82,45 +82,43 @@ class MTex(_TexSVG): def get_file_path(self) -> str: self.init_parser() - self.base_color = self.svg_default["color"] \ - or self.svg_default["fill_color"] or WHITE self.use_plain_file = any([ self.use_plain_tex, self.color_cmd_repl_items, - self.base_color not in (BLACK, WHITE) + self.base_color != WHITE ]) return self.get_file_path_(use_plain_file=self.use_plain_file) def get_file_path_(self, use_plain_file: bool) -> str: if use_plain_file: - tex_string = "".join([ + content = "".join([ "{{", - self.get_color_command(int(self.base_color[1:], 16)), - self.tex_string, + self.get_color_command(self.color_to_int(self.base_color)), + self.string, "}}" ]) else: - tex_string = self.labelled_tex_string + content = self.get_labelled_string() - full_tex = self.get_tex_file_body(tex_string) - with display_during_execution(f"Writing \"{self.tex_string}\""): + full_tex = self.get_tex_file_body(content) + with display_during_execution(f"Writing \"{self.string}\""): file_path = self.tex_to_svg_file_path(full_tex) return file_path - def get_tex_file_body(self, tex_string: str) -> str: + def get_tex_file_body(self, content: str) -> str: if self.tex_environment: - tex_string = "\n".join([ + content = "\n".join([ f"\\begin{{{self.tex_environment}}}", - tex_string, + content, f"\\end{{{self.tex_environment}}}" ]) if self.alignment: - tex_string = "\n".join([self.alignment, tex_string]) + content = "\n".join([self.alignment, content]) tex_config = get_tex_config() return tex_config["tex_body"].replace( tex_config["text_to_replace"], - tex_string + content ) @staticmethod @@ -136,9 +134,7 @@ class MTex(_TexSVG): if self.use_plain_file: file_path = self.get_file_path_(use_plain_file=False) - labelled_svg_glyphs = _TexSVG( - file_path, svg_default={"fill_color": BLACK} - ) + labelled_svg_glyphs = _TexSVG(file_path) predefined_colors = [ labelled_glyph.get_fill_color() for labelled_glyph in self.submobjects @@ -166,7 +162,7 @@ class MTex(_TexSVG): for submob, label, submob_tex in zip( submobjects, submob_labels, submob_tex_strings ): - submob.submob_label = label + submob.label = label submob.tex_string = submob_tex # Support `get_tex()` method here. submob.get_tex = MethodType(lambda inst: inst.tex_string, submob) @@ -175,13 +171,17 @@ class MTex(_TexSVG): ## Static methods @staticmethod - def color_to_label(color: ManimColor) -> int: + def color_to_int(color: ManimColor) -> int: r, g, b = color_to_int_rgb(color) rg = r * 256 + g - rgb = rg * 256 + b - if rgb == 16777215: # white - return 0 - return rgb + return rg * 256 + b + + @staticmethod + def color_to_label(color: ManimColor) -> int: + result = MTex.color_to_int(color) + if result == 16777215: # white + return -1 + return result @staticmethod def get_color_command(label: int) -> str: @@ -227,6 +227,7 @@ class MTex(_TexSVG): def init_parser(self) -> None: self.additional_substrings = self.get_additional_substrings() + self.full_span = self.get_full_span() self.backslash_indices = self.get_backslash_indices() self.left_brace_indices, self.right_brace_indices = \ self.get_left_and_right_indices() @@ -236,15 +237,15 @@ class MTex(_TexSVG): self.script_content_spans = self.get_script_content_spans() self.double_braces_spans = self.get_double_braces_spans() self.stripped_substrings = self.get_stripped_substrings() - self.specified_spans = self.get_specified_spans() self.specified_substrings = self.get_specified_substrings() + self.specified_spans, self.extended_specified_spans = \ + self.get_specified_spans() self.tex_span_list = self.get_tex_span_list() self.extended_tex_span_list = self.get_extended_tex_span_list() self.isolated_substrings = self.get_isolated_substrings() self.containing_labels_dict = self.get_containing_labels_dict() self.color_cmd_repl_items = self.get_color_cmd_repl_items() self.span_repl_dict = self.get_span_repl_dict() - self.labelled_tex_string = self.get_labelled_tex_string() def get_additional_substrings(self) -> list[str]: return list(it.chain( @@ -252,28 +253,31 @@ class MTex(_TexSVG): self.isolate )) + def get_full_span(self) -> tuple[int, int]: + return (0, len(self.string)) + def get_backslash_indices(self) -> list[int]: # Newlines (`\\`) don't count. return [ match_obj.end() - 1 - for match_obj in re.finditer(r"\\+", self.tex_string) + for match_obj in re.finditer(r"\\+", self.string) if len(match_obj.group()) % 2 == 1 ] def get_left_and_right_indices(self) -> list[tuple[int, int]]: - tex_string = self.tex_string + string = self.string indices = list(filter( lambda index: index - 1 not in self.backslash_indices, [ match_obj.start() - for match_obj in re.finditer(r"[{}]", tex_string) + for match_obj in re.finditer(r"[{}]", string) ] )) left_brace_indices = [] right_brace_indices = [] left_brace_indices_stack = [] for index in indices: - if tex_string[index] == "{": + if string[index] == "{": left_brace_indices_stack.append(index) else: if not left_brace_indices_stack: @@ -288,7 +292,7 @@ class MTex(_TexSVG): def get_script_char_spans(self) -> list[tuple[int, int]]: return [ match_obj.span() - for match_obj in re.finditer(r"(\s*)[_^]\s*", self.tex_string) + for match_obj in re.finditer(r"(\s*)[_^]\s*", self.string) if match_obj.group(1) or match_obj.start() - 1 not in self.backslash_indices ] @@ -296,14 +300,14 @@ class MTex(_TexSVG): def get_skipped_indices(self) -> list[int]: return sorted(remove_list_redundancies([ match_obj.start() - for match_obj in re.finditer(r"\s", self.tex_string) + for match_obj in re.finditer(r"\s", self.string) ] + list(it.chain(*[ range(*script_char_span) for script_char_span in self.script_char_spans ])))) def get_script_spans(self) -> list[tuple[int, int]]: - tex_string = self.tex_string + string = self.string result = [] brace_indices_dict = dict(zip( self.left_brace_indices, self.right_brace_indices @@ -313,7 +317,7 @@ class MTex(_TexSVG): span_end = brace_indices_dict[span_begin] + 1 else: pattern = re.compile(r"[a-zA-Z0-9]|\\[a-zA-Z]+") - match_obj = pattern.match(tex_string, pos=span_begin) + match_obj = pattern.match(string, pos=span_begin) if not match_obj: script_name = { "_": "subscript", @@ -361,64 +365,68 @@ class MTex(_TexSVG): def get_stripped_substrings(self) -> list[str]: result = remove_list_redundancies([ - string.strip() - for string in self.additional_substrings + substr.strip() + for substr in self.additional_substrings ]) if "" in result: result.remove("") return result - def get_specified_spans(self) -> list[tuple[int, int]]: - result = self.double_braces_spans.copy() - tex_string = self.tex_string + def get_specified_substrings(self) -> list[str]: + return remove_list_redundancies([ + self.string[slice(*double_braces_span)] + for double_braces_span in self.double_braces_spans + ] + list(filter( + lambda s: s in self.string, + self.stripped_substrings + ))) + + def get_specified_spans( + self + ) -> tuple[list[tuple[int, int]], list[tuple[int, int]]]: + tex_spans = sorted(remove_list_redundancies([ + self.full_span, + *self.double_braces_spans, + *[ + match_obj.span() + for substr in self.stripped_substrings + for match_obj in re.finditer(re.escape(substr), self.string) + ] + ]), key=lambda t: (t[0], -t[1])) + result = [] + extended_result = [] + script_spans_dict = dict(self.script_spans) reversed_script_spans_dict = dict([ script_span[::-1] for script_span in self.script_spans ]) - for string in self.stripped_substrings: - for match_obj in re.finditer(re.escape(string), tex_string): - span_begin, span_end = match_obj.span() - while span_end in reversed_script_spans_dict.keys(): - span_end = reversed_script_spans_dict[span_end] - if span_begin >= span_end: - continue - result.append((span_begin, span_end)) - return list(filter( - lambda tex_span: tex_span not in self.script_content_spans, - remove_list_redundancies(result) - )) - - def get_specified_substrings(self) -> list[str]: - return remove_list_redundancies([ - self.tex_string[slice(*double_braces_span)] - for double_braces_span in self.double_braces_spans - ] + list(filter( - lambda s: s in self.tex_string, - self.additional_substrings - ))) + for tex_span in tex_spans: + if tex_span in self.script_content_spans: + continue + span_begin, span_end = tex_span + extended_span_end = span_end + while span_end in reversed_script_spans_dict.keys(): + span_end = reversed_script_spans_dict[span_end] + while extended_span_end in script_spans_dict.keys(): + extended_span_end = script_spans_dict[extended_span_end] + specified_span = (span_begin, span_end) + extended_specified_span = (span_begin, extended_span_end) + if span_begin >= span_end: + continue + if extended_specified_span in result: + continue + result.append(specified_span) + extended_result.append(extended_specified_span) + return result, extended_result def get_tex_span_list(self) -> list[tuple[int, int]]: - return [ - (0, len(self.tex_string)), - *self.script_content_spans, - *self.specified_spans - ] + return self.specified_spans + self.script_content_spans def get_extended_tex_span_list(self) -> list[tuple[int, int]]: - extended_specified_spans = [] - script_spans_dict = dict(self.script_spans) - for span_begin, span_end in self.specified_spans: - while span_end in script_spans_dict.keys(): - span_end = script_spans_dict[span_end] - extended_specified_spans.append((span_begin, span_end)) - return [ - (0, len(self.tex_string)), - *self.script_content_spans, - *extended_specified_spans - ] + return self.extended_specified_spans + self.script_content_spans def get_isolated_substrings(self) -> list[str]: return remove_list_redundancies([ - self.tex_string[slice(*tex_span)] + self.string[slice(*tex_span)] for tex_span in self.tex_span_list ]) @@ -434,37 +442,38 @@ class MTex(_TexSVG): result[span_0].append(span_index) elif span_0[0] < span_1[0] < span_0[1] < span_1[1]: string_0, string_1 = [ - self.tex_string[slice(*tex_span)] + self.string[slice(*tex_span)] for tex_span in [span_0, span_1] ] raise ValueError( "Partially overlapping substrings detected: " f"'{string_0}' and '{string_1}'" ) + result[self.full_span] = list(range(len(tex_span_list))) return result def get_color_cmd_repl_items(self) -> list[tuple[tuple[int, int], str]]: - color_related_commands_dict = { - "color": 1, - "textcolor": 1, - "pagecolor": 1, - "colorbox": 1, - "fcolorbox": 2, - } + color_related_command_items = [ + ("color", 1, ""), + ("textcolor", 1, ""), + ("pagecolor", 1, "\\pagecolor{white}"), + ("colorbox", 1, "\\colorbox{white}"), + ("fcolorbox", 2, "\\fcolorbox{white}{white}"), + ] result = [] - tex_string = self.tex_string + string = self.string backslash_indices = self.backslash_indices left_indices = self.left_brace_indices brace_indices_dict = dict(zip( self.left_brace_indices, self.right_brace_indices )) - for cmd_name, n_braces in color_related_commands_dict.items(): + for cmd_name, n_braces, repl_str in color_related_command_items: pattern = cmd_name + r"(?![a-zA-Z])" - for match_obj in re.finditer(pattern, tex_string): + for match_obj in re.finditer(pattern, string): span_begin, span_end = match_obj.span() - if span_begin - 1 not in backslash_indices: + span_begin -= 1 + if span_begin not in backslash_indices: continue - repl_str = cmd_name + n_braces * "{black}" for _ in range(n_braces): left_index = min(filter( lambda index: index >= span_end, left_indices @@ -481,7 +490,7 @@ class MTex(_TexSVG): -tex_span[1 - flag], ("{{" + self.get_color_command(label), "}}")[flag] ) - for label, tex_span in enumerate(self.extended_tex_span_list) + for label, tex_span in enumerate(self.tex_span_list) for flag in range(2) ])) result = { @@ -493,9 +502,9 @@ class MTex(_TexSVG): result.update(self.color_cmd_repl_items) return result - def get_labelled_tex_string(self) -> str: + def get_labelled_string(self) -> str: if not self.span_repl_dict: - return self.tex_string + return self.string spans = sorted(self.span_repl_dict.keys()) if not all( @@ -506,10 +515,10 @@ class MTex(_TexSVG): span_ends, span_begins = zip(*spans) string_pieces = [ - self.tex_string[slice(*span)] + self.string[slice(*span)] for span in zip( (0, *span_begins), - (*span_ends, len(self.tex_string)) + (*span_ends, len(self.string)) ) ] repl_strs = [ @@ -521,7 +530,7 @@ class MTex(_TexSVG): def get_submob_tex_strings(self, submob_labels: list[int]) -> list[str]: ordered_tex_spans = [ - self.tex_span_list[label] + self.tex_span_list[label] if label != -1 else self.full_span for label in submob_labels ] ordered_containing_labels = [ @@ -546,7 +555,7 @@ class MTex(_TexSVG): ] string_span_ends.append(ordered_span_ends[-1]) - tex_string = self.tex_string + string = self.string left_indices = self.left_brace_indices right_indices = self.right_brace_indices skipped_indices = sorted(it.chain( @@ -575,7 +584,7 @@ class MTex(_TexSVG): unclosed_left_brace -= 1 result.append("".join([ unclosed_right_brace * "{", - tex_string[span_begin:span_end], + string[span_begin:span_end], unclosed_left_brace * "}" ])) return result @@ -615,7 +624,7 @@ class MTex(_TexSVG): custom_span ) if tex_spans is None: - tex = self.tex_string[slice(*custom_span)] + tex = self.string[slice(*custom_span)] raise ValueError(f"Failed to match mobjects from tex: \"{tex}\"") labels = set(it.chain(*[ @@ -623,7 +632,7 @@ class MTex(_TexSVG): for tex_span in tex_spans ])) return VGroup(*filter( - lambda submob: submob.submob_label in labels, + lambda submob: submob.label in labels, self.submobjects )) @@ -631,7 +640,7 @@ class MTex(_TexSVG): return VGroup(*[ self.get_part_by_custom_span(match_obj.span()) for match_obj in re.finditer( - re.escape(tex.strip()), self.tex_string + re.escape(tex.strip()), self.string ) ]) @@ -665,7 +674,7 @@ class MTex(_TexSVG): return self.indices_of_part(part) def get_tex(self) -> str: - return self.tex_string + return self.string def get_submob_tex(self) -> list[str]: return [ From 3b01ec48e6aa4dec5411c5ec159e5e048e2fbeda Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Sun, 27 Mar 2022 14:44:50 +0800 Subject: [PATCH 05/48] Refactor MTex --- manimlib/mobject/svg/mtex_mobject.py | 502 +++++++++++++++------------ 1 file changed, 285 insertions(+), 217 deletions(-) diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 1a71783d..a14004e7 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -6,7 +6,7 @@ import itertools as it from types import MethodType from typing import Iterable, Union, Sequence -from manimlib.constants import WHITE +from manimlib.constants import BLACK, WHITE from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.utils.color import color_to_int_rgb @@ -32,6 +32,9 @@ SCALE_FACTOR_PER_FONT_POINT = 0.001 class _TexSVG(SVGMobject): CONFIG = { "height": None, + "svg_default": { + "fill_color": BLACK, + }, "stroke_width": 0, "stroke_color": WHITE, "path_string_config": { @@ -80,25 +83,11 @@ class MTex(_TexSVG): self.use_plain_tex ) - def get_file_path(self) -> str: - self.init_parser() - self.use_plain_file = any([ - self.use_plain_tex, - self.color_cmd_repl_items, - self.base_color != WHITE - ]) - return self.get_file_path_(use_plain_file=self.use_plain_file) - - def get_file_path_(self, use_plain_file: bool) -> str: + def get_file_path(self, use_plain_file: bool = False) -> str: if use_plain_file: - content = "".join([ - "{{", - self.get_color_command(self.color_to_int(self.base_color)), - self.string, - "}}" - ]) + content = self.plain_string else: - content = self.get_labelled_string() + content = self.labelled_string full_tex = self.get_tex_file_body(content) with display_during_execution(f"Writing \"{self.string}\""): @@ -128,27 +117,26 @@ class MTex(_TexSVG): def generate_mobject(self) -> None: super().generate_mobject() - glyphs = self.submobjects - if not glyphs: + if not self.submobjects: return - if self.use_plain_file: - file_path = self.get_file_path_(use_plain_file=False) - labelled_svg_glyphs = _TexSVG(file_path) - predefined_colors = [ - labelled_glyph.get_fill_color() - for labelled_glyph in self.submobjects - ] - else: - labelled_svg_glyphs = self - predefined_colors = [self.base_color] * len(glyphs) - glyph_labels = [ - self.color_to_label(labelled_glyph.get_fill_color()) - for labelled_glyph in labelled_svg_glyphs + self.color_to_label(glyph.get_fill_color()) + for glyph in self.submobjects ] - for glyph, glyph_color in zip(glyphs, predefined_colors): - glyph.set_fill(glyph_color) + + if any([ + self.use_plain_tex, + self.color_cmd_repl_items, + self.base_color in (BLACK, WHITE) + ]): + file_path = self.get_file_path(use_plain_file=True) + glyphs = _TexSVG(file_path).submobjects + for glyph, plain_glyph in zip(self.submobjects, glyphs): + glyph.set_fill(plain_glyph.get_fill_color()) + else: + glyphs = self.submobjects + self.set_fill(self.base_color) # Simply pack together adjacent mobjects with the same label. submob_labels, glyphs_lists = self.group_neighbours( @@ -171,20 +159,18 @@ class MTex(_TexSVG): ## Static methods @staticmethod - def color_to_int(color: ManimColor) -> int: + def color_to_label(color: ManimColor) -> int: r, g, b = color_to_int_rgb(color) rg = r * 256 + g - return rg * 256 + b - - @staticmethod - def color_to_label(color: ManimColor) -> int: - result = MTex.color_to_int(color) - if result == 16777215: # white + rgb = rg * 256 + b + if rgb == 16777215: # white return -1 - return result + return rgb @staticmethod def get_color_command(label: int) -> str: + if label == -1: + label = 16777215 # white rg, b = divmod(label, 256) r, g = divmod(rg, 256) return "".join([ @@ -223,40 +209,83 @@ class MTex(_TexSVG): groups.append(new_group) return group_labels, groups + @staticmethod + def find_region_index(val: int, seq: list[int]) -> int: + # Returns an integer in `range(len(seq) + 1)` satisfying + # `seq[result - 1] <= val < seq[result]` + if not seq: + return 0 + if val >= seq[-1]: + return len(seq) + result = 0 + while val >= seq[result]: + result += 1 + return result + + @staticmethod + def lstrip(index: int, skipped_spans: list[tuple[int, int]]) -> int: + index_seq = list(it.chain(*skipped_spans)) + region_index = MTex.find_region_index(index, index_seq) + if region_index % 2 == 1: + return index_seq[region_index] + return index + + @staticmethod + def rstrip(index: int, skipped_spans: list[tuple[int, int]]) -> int: + index_seq = list(it.chain(*skipped_spans)) + region_index = MTex.find_region_index(index - 1, index_seq) + if region_index % 2 == 1: + return index_seq[region_index - 1] + return index + + @staticmethod + def strip( + tex_span: tuple[int, int], skipped_spans: list[tuple[int, int]] + ) -> tuple[int, int] | None: + result = ( + MTex.lstrip(tex_span[0], skipped_spans), + MTex.rstrip(tex_span[1], skipped_spans) + ) + if result[0] >= result[1]: + return None + return result + + @staticmethod + def lslide(index: int, slid_spans: list[tuple[int, int]]) -> int: + slide_dict = dict(slid_spans) + while index in slide_dict.keys(): + index = slide_dict[index] + return index + + @staticmethod + def rslide(index: int, slid_spans: list[tuple[int, int]]) -> int: + slide_dict = dict([ + slide_span[::-1] for slide_span in slid_spans + ]) + while index in slide_dict.keys(): + index = slide_dict[index] + return index + + @staticmethod + def slide( + tex_span: tuple[int, int], slid_spans: list[tuple[int, int]] + ) -> tuple[int, int] | None: + result = ( + MTex.lslide(tex_span[0], slid_spans), + MTex.rslide(tex_span[1], slid_spans) + ) + if result[0] >= result[1]: + return None + return result + ## Parser - def init_parser(self) -> None: - self.additional_substrings = self.get_additional_substrings() - self.full_span = self.get_full_span() - self.backslash_indices = self.get_backslash_indices() - self.left_brace_indices, self.right_brace_indices = \ - self.get_left_and_right_indices() - self.script_char_spans = self.get_script_char_spans() - self.skipped_indices = self.get_skipped_indices() - self.script_spans = self.get_script_spans() - self.script_content_spans = self.get_script_content_spans() - self.double_braces_spans = self.get_double_braces_spans() - self.stripped_substrings = self.get_stripped_substrings() - self.specified_substrings = self.get_specified_substrings() - self.specified_spans, self.extended_specified_spans = \ - self.get_specified_spans() - self.tex_span_list = self.get_tex_span_list() - self.extended_tex_span_list = self.get_extended_tex_span_list() - self.isolated_substrings = self.get_isolated_substrings() - self.containing_labels_dict = self.get_containing_labels_dict() - self.color_cmd_repl_items = self.get_color_cmd_repl_items() - self.span_repl_dict = self.get_span_repl_dict() - - def get_additional_substrings(self) -> list[str]: - return list(it.chain( - self.tex_to_color_map.keys(), - self.isolate - )) - - def get_full_span(self) -> tuple[int, int]: + @property + def full_span(self) -> tuple[int, int]: return (0, len(self.string)) - def get_backslash_indices(self) -> list[int]: + @property + def backslash_indices(self) -> list[int]: # Newlines (`\\`) don't count. return [ match_obj.end() - 1 @@ -264,7 +293,9 @@ class MTex(_TexSVG): if len(match_obj.group()) % 2 == 1 ] - def get_left_and_right_indices(self) -> list[tuple[int, int]]: + def get_left_and_right_brace_indices( + self + ) -> tuple[list[tuple[int, int]], list[tuple[int, int]]]: string = self.string indices = list(filter( lambda index: index - 1 not in self.backslash_indices, @@ -287,60 +318,69 @@ class MTex(_TexSVG): right_brace_indices.append(index) if left_brace_indices_stack: raise ValueError("Missing '}' inserted") + # `right_brace_indices` is already sorted. return left_brace_indices, right_brace_indices - def get_script_char_spans(self) -> list[tuple[int, int]]: + @property + def left_brace_indices(self) -> list[tuple[int, int]]: + return self.get_left_and_right_brace_indices()[0] + + @property + def right_brace_indices(self) -> list[tuple[int, int]]: + return self.get_left_and_right_brace_indices()[1] + + @property + def skipped_spans(self) -> list[tuple[int, int]]: return [ match_obj.span() - for match_obj in re.finditer(r"(\s*)[_^]\s*", self.string) - if match_obj.group(1) - or match_obj.start() - 1 not in self.backslash_indices + for match_obj in re.finditer(r"\s*([_^])\s*|(\s+)", self.string) + if match_obj.group(2) is not None + or match_obj.start(1) - 1 not in self.backslash_indices ] - def get_skipped_indices(self) -> list[int]: - return sorted(remove_list_redundancies([ - match_obj.start() - for match_obj in re.finditer(r"\s", self.string) - ] + list(it.chain(*[ - range(*script_char_span) - for script_char_span in self.script_char_spans - ])))) + def lstrip_span(self, index: int) -> int: + return self.lstrip(index, self.skipped_spans) - def get_script_spans(self) -> list[tuple[int, int]]: - string = self.string + def rstrip_span(self, index: int) -> int: + return self.rstrip(index, self.skipped_spans) + + def strip_span(self, index: int) -> int: + return self.strip(index, self.skipped_spans) + + @property + def script_char_spans(self) -> list[tuple[int, int]]: + return list(filter( + lambda tex_span: self.string[slice(*tex_span)].strip(), + self.skipped_spans + )) + + @property + def script_content_spans(self) -> list[tuple[int, int]]: result = [] brace_indices_dict = dict(zip( self.left_brace_indices, self.right_brace_indices )) - for char_begin, span_begin in self.script_char_spans: + for _, span_begin in self.script_char_spans: if span_begin in brace_indices_dict.keys(): span_end = brace_indices_dict[span_begin] + 1 else: pattern = re.compile(r"[a-zA-Z0-9]|\\[a-zA-Z]+") - match_obj = pattern.match(string, pos=span_begin) + match_obj = pattern.match(self.string, pos=span_begin) if not match_obj: script_name = { "_": "subscript", "^": "superscript" }[script_char] - log.warning( + raise ValueError( f"Unclear {script_name} detected while parsing. " "Please use braces to clarify" ) - continue span_end = match_obj.end() - result.append((char_begin, span_end)) + result.append((span_begin, span_end)) return result - def get_script_content_spans(self) -> list[tuple[int, int]]: - return [ - (script_char_span[1], script_span[1]) - for script_char_span, script_span in zip( - self.script_char_spans, self.script_spans - ) - ] - - def get_double_braces_spans(self) -> list[tuple[int, int]]: + @property + def double_braces_spans(self) -> list[tuple[int, int]]: # Match paired double braces (`{{...}}`). result = [] reversed_brace_indices_dict = dict(zip( @@ -348,7 +388,7 @@ class MTex(_TexSVG): )) skip = False for prev_right_index, right_index in self.get_neighbouring_pairs( - sorted(reversed_brace_indices_dict.keys()) + self.right_brace_indices ): if skip: skip = False @@ -363,74 +403,64 @@ class MTex(_TexSVG): skip = True return result - def get_stripped_substrings(self) -> list[str]: - result = remove_list_redundancies([ - substr.strip() - for substr in self.additional_substrings - ]) + @property + def additional_substrings(self) -> list[str]: + result = remove_list_redundancies(list(it.chain( + self.tex_to_color_map.keys(), + self.isolate + ))) if "" in result: result.remove("") return result - def get_specified_substrings(self) -> list[str]: - return remove_list_redundancies([ - self.string[slice(*double_braces_span)] - for double_braces_span in self.double_braces_spans - ] + list(filter( - lambda s: s in self.string, - self.stripped_substrings - ))) - - def get_specified_spans( + def get_tex_span_lists( self ) -> tuple[list[tuple[int, int]], list[tuple[int, int]]]: - tex_spans = sorted(remove_list_redundancies([ - self.full_span, - *self.double_braces_spans, - *[ - match_obj.span() - for substr in self.stripped_substrings - for match_obj in re.finditer(re.escape(substr), self.string) - ] - ]), key=lambda t: (t[0], -t[1])) result = [] extended_result = [] - script_spans_dict = dict(self.script_spans) - reversed_script_spans_dict = dict([ - script_span[::-1] for script_span in self.script_spans + script_content_spans = self.script_content_spans + script_spans = [ + (script_char_span[0], script_content_span[1]) + for script_char_span, script_content_span in zip( + self.script_char_spans, script_content_spans + ) + ] + tex_spans = remove_list_redundancies([ + self.full_span, + *self.double_braces_spans, + *filter(lambda stripped_span: stripped_span is not None, [ + self.strip_span(match_obj.span()) + for substr in self.additional_substrings + for match_obj in re.finditer(re.escape(substr), self.string) + ]), + *script_content_spans ]) for tex_span in tex_spans: - if tex_span in self.script_content_spans: + if tex_span in script_content_spans: + result.append(tex_span) + extended_result.append(tex_span) continue span_begin, span_end = tex_span - extended_span_end = span_end - while span_end in reversed_script_spans_dict.keys(): - span_end = reversed_script_spans_dict[span_end] - while extended_span_end in script_spans_dict.keys(): - extended_span_end = script_spans_dict[extended_span_end] - specified_span = (span_begin, span_end) - extended_specified_span = (span_begin, extended_span_end) - if span_begin >= span_end: + shrinked_span = (span_begin, self.rslide(span_end, script_spans)) + extended_span = (span_begin, self.lslide(span_end, script_spans)) + if shrinked_span[0] >= shrinked_span[1]: continue - if extended_specified_span in result: + if shrinked_span in result: continue - result.append(specified_span) - extended_result.append(extended_specified_span) + result.append(shrinked_span) + extended_result.append(extended_span) return result, extended_result - def get_tex_span_list(self) -> list[tuple[int, int]]: - return self.specified_spans + self.script_content_spans + @property + def tex_span_list(self) -> list[tuple[int, int]]: + return self.get_tex_span_lists()[0] - def get_extended_tex_span_list(self) -> list[tuple[int, int]]: - return self.extended_specified_spans + self.script_content_spans + @property + def extended_tex_span_list(self) -> list[tuple[int, int]]: + return self.get_tex_span_lists()[1] - def get_isolated_substrings(self) -> list[str]: - return remove_list_redundancies([ - self.string[slice(*tex_span)] - for tex_span in self.tex_span_list - ]) - - def get_containing_labels_dict(self) -> dict[tuple[int, int], list[int]]: + @property + def containing_labels_dict(self) -> dict[tuple[int, int], list[int]]: tex_span_list = self.tex_span_list result = { tex_span: [] @@ -449,40 +479,46 @@ class MTex(_TexSVG): "Partially overlapping substrings detected: " f"'{string_0}' and '{string_1}'" ) - result[self.full_span] = list(range(len(tex_span_list))) + result[self.full_span] = list(range(-1, len(tex_span_list))) return result - def get_color_cmd_repl_items(self) -> list[tuple[tuple[int, int], str]]: - color_related_command_items = [ - ("color", 1, ""), - ("textcolor", 1, ""), - ("pagecolor", 1, "\\pagecolor{white}"), - ("colorbox", 1, "\\colorbox{white}"), - ("fcolorbox", 2, "\\fcolorbox{white}{white}"), - ] + @property + def color_cmd_repl_items(self) -> list[tuple[tuple[int, int], str]]: + color_related_command_dict = { + "color": (1, False), + "textcolor": (1, False), + "pagecolor": (1, True), + "colorbox": (1, True), + "fcolorbox": (2, True), + } result = [] - string = self.string backslash_indices = self.backslash_indices - left_indices = self.left_brace_indices - brace_indices_dict = dict(zip( - self.left_brace_indices, self.right_brace_indices - )) - for cmd_name, n_braces, repl_str in color_related_command_items: - pattern = cmd_name + r"(?![a-zA-Z])" - for match_obj in re.finditer(pattern, string): - span_begin, span_end = match_obj.span() - span_begin -= 1 - if span_begin not in backslash_indices: - continue - for _ in range(n_braces): - left_index = min(filter( - lambda index: index >= span_end, left_indices - )) - span_end = brace_indices_dict[left_index] + 1 - result.append(((span_begin, span_end), repl_str)) + right_brace_indices = self.right_brace_indices + pattern = "".join([ + r"\\", + "(", + "|".join(color_related_command_dict.keys()), + ")", + r"(?![a-zA-Z])" + ]) + for match_obj in re.finditer(pattern, self.string): + span_begin, cmd_end = match_obj.span() + if span_begin not in backslash_indices: + continue + cmd_name = match_obj.group(1) + n_braces, substitute_cmd = color_related_command_dict[cmd_name] + span_end = right_brace_indices[self.find_region_index( + cmd_end, right_brace_indices + ) + n_braces - 1] + 1 + if substitute_cmd: + repl_str = "\\" + cmd_name + n_braces * "{white}" + else: + repl_str = "" + result.append(((span_begin, span_end), repl_str)) return result - def get_span_repl_dict(self) -> dict[tuple[int, int], str]: + @property + def span_repl_dict(self) -> dict[tuple[int, int], str]: indices, _, _, cmd_strings = zip(*sorted([ ( tex_span[flag], @@ -490,7 +526,7 @@ class MTex(_TexSVG): -tex_span[1 - flag], ("{{" + self.get_color_command(label), "}}")[flag] ) - for label, tex_span in enumerate(self.tex_span_list) + for label, tex_span in enumerate(self.extended_tex_span_list) for flag in range(2) ])) result = { @@ -502,7 +538,17 @@ class MTex(_TexSVG): result.update(self.color_cmd_repl_items) return result - def get_labelled_string(self) -> str: + @property + def plain_string(self) -> str: + return "".join([ + "{{", + self.get_color_command(self.color_to_label(self.base_color)), + self.string, + "}}" + ]) + + @property + def labelled_string(self) -> str: if not self.span_repl_dict: return self.string @@ -545,7 +591,6 @@ class MTex(_TexSVG): ordered_containing_labels[1:], ordered_span_begins[1:] ) ] - string_span_begins.insert(0, ordered_span_begins[0]) string_span_ends = [ next_begin if next_label in containing_labels else curr_end for next_begin, next_label, containing_labels, curr_end in zip( @@ -553,49 +598,72 @@ class MTex(_TexSVG): ordered_containing_labels[:-1], ordered_span_ends[:-1] ) ] - string_span_ends.append(ordered_span_ends[-1]) + string_spans = list(zip( + (ordered_span_begins[0], *string_span_begins), + (*string_span_ends, ordered_span_ends[-1]) + )) string = self.string - left_indices = self.left_brace_indices - right_indices = self.right_brace_indices - skipped_indices = sorted(it.chain( - self.skipped_indices, - left_indices, - right_indices - )) + left_brace_indices = self.left_brace_indices + right_brace_indices = self.right_brace_indices + slid_spans = self.skipped_spans + [ + (index, index + 1) + for index in left_brace_indices + right_brace_indices + ] result = [] - for span_begin, span_end in zip(string_span_begins, string_span_ends): - while span_begin in skipped_indices: - span_begin += 1 - if span_begin >= span_end: - result.append("") + for str_span in string_spans: + str_span = self.strip_span(str_span) + if str_span is None: continue - while span_end - 1 in skipped_indices: - span_end -= 1 - unclosed_left_brace = 0 - unclosed_right_brace = 0 - for index in range(span_begin, span_end): - if index in left_indices: - unclosed_left_brace += 1 - elif index in right_indices: - if unclosed_left_brace == 0: - unclosed_right_brace += 1 + str_span = self.slide(str_span, slid_spans) + if str_span is None: + continue + unclosed_left_braces = 0 + unclosed_right_braces = 0 + for index in range(*str_span): + if index in left_brace_indices: + unclosed_left_braces += 1 + elif index in right_brace_indices: + if unclosed_left_braces == 0: + unclosed_right_braces += 1 else: - unclosed_left_brace -= 1 + unclosed_left_braces -= 1 result.append("".join([ - unclosed_right_brace * "{", - string[span_begin:span_end], - unclosed_left_brace * "}" + unclosed_right_braces * "{", + string[slice(*str_span)], + unclosed_left_braces * "}" ])) return result + @property + def specified_substrings(self) -> list[str]: + return remove_list_redundancies([ + self.string[slice(*double_braces_span)] + for double_braces_span in self.double_braces_spans + ] + list(filter( + lambda s: s in self.string, + self.additional_substrings + ))) + + def get_specified_substrings(self) -> list[str]: + return self.specified_substrings + + @property + def isolated_substrings(self) -> list[str]: + return remove_list_redundancies([ + self.string[slice(*tex_span)] + for tex_span in self.tex_span_list + ]) + + def get_isolated_substrings(self) -> list[str]: + return self.isolated_substrings + ## Selector def find_span_components_of_custom_span( self, custom_span: tuple[int, int] ) -> list[tuple[int, int]] | None: - skipped_indices = self.skipped_indices tex_span_choices = sorted(filter( lambda tex_span: all([ tex_span[0] >= custom_span[0], @@ -606,13 +674,13 @@ class MTex(_TexSVG): # Choose spans that reach the farthest. tex_span_choices_dict = dict(tex_span_choices) - span_begin, span_end = custom_span result = [] + span_begin, span_end = custom_span + span_begin = self.rstrip_span(span_begin) + span_end = self.rstrip_span(span_end) while span_begin != span_end: + span_begin = self.lstrip_span(span_begin) if span_begin not in tex_span_choices_dict.keys(): - if span_begin in skipped_indices: - span_begin += 1 - continue return None next_begin = tex_span_choices_dict[span_begin] result.append((span_begin, next_begin)) @@ -640,7 +708,7 @@ class MTex(_TexSVG): return VGroup(*[ self.get_part_by_custom_span(match_obj.span()) for match_obj in re.finditer( - re.escape(tex.strip()), self.string + re.escape(tex), self.string ) ]) From 473aaea399d2080100b89b2497c79a46703e32d6 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Mon, 28 Mar 2022 17:55:50 +0800 Subject: [PATCH 06/48] Construct LabelledString --- .../animation/transform_matching_parts.py | 44 +- manimlib/mobject/svg/mtex_mobject.py | 897 ++++++++++-------- manimlib/mobject/svg/text_mobject.py | 706 ++++++++------ 3 files changed, 921 insertions(+), 726 deletions(-) diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index 90ffa76f..486007dd 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -12,7 +12,7 @@ from manimlib.animation.transform import ReplacementTransform from manimlib.animation.transform import Transform from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Group -from manimlib.mobject.svg.mtex_mobject import MTex +from manimlib.mobject.svg.mtex_mobject import LabelledString from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.utils.config_ops import digest_config @@ -153,15 +153,16 @@ class TransformMatchingTex(TransformMatchingParts): return mobject.get_tex() -class TransformMatchingMTex(AnimationGroup): +class TransformMatchingString(AnimationGroup): CONFIG = { "key_map": dict(), + "transform_mismatches_class": None, } - def __init__(self, source_mobject: MTex, target_mobject: MTex, **kwargs): + def __init__(self, source_mobject: LabelledString, target_mobject: LabelledString, **kwargs): digest_config(self, kwargs) - assert isinstance(source_mobject, MTex) - assert isinstance(target_mobject, MTex) + assert isinstance(source_mobject, LabelledString) + assert isinstance(target_mobject, LabelledString) anims = [] rest_source_submobs = source_mobject.submobjects.copy() rest_target_submobs = target_mobject.submobjects.copy() @@ -207,7 +208,7 @@ class TransformMatchingMTex(AnimationGroup): elif isinstance(key, range): indices.extend(key) elif isinstance(key, str): - all_parts = mobject.get_parts_by_tex(key) + all_parts = mobject.get_parts_by_string(key) indices.extend(it.chain(*[ mobject.indices_of_part(part) for part in all_parts ])) @@ -228,31 +229,34 @@ class TransformMatchingMTex(AnimationGroup): target_mobject.get_specified_substrings() ) ), key=len, reverse=True) - for part_tex_string in common_specified_substrings: + for part_string in common_specified_substrings: add_anim_from( - FadeTransformPieces, MTex.get_parts_by_tex, part_tex_string + FadeTransformPieces, LabelledString.get_parts_by_string, part_string ) - common_submob_tex_strings = { - source_submob.get_tex() for source_submob in source_mobject + common_submob_strings = { + source_submob.get_string() for source_submob in source_mobject }.intersection({ - target_submob.get_tex() for target_submob in target_mobject + target_submob.get_string() for target_submob in target_mobject }) - for tex_string in common_submob_tex_strings: + for substr in common_submob_strings: add_anim_from( FadeTransformPieces, lambda mobject, attr: VGroup(*[ VGroup(mob) for mob in mobject - if mob.get_tex() == attr + if mob.get_string() == attr ]), - tex_string + substr ) - anims.append(FadeOutToPoint( - VGroup(*rest_source_submobs), target_mobject.get_center(), **kwargs - )) - anims.append(FadeInFromPoint( - VGroup(*rest_target_submobs), source_mobject.get_center(), **kwargs - )) + if self.transform_mismatches_class is not None: + anims.append(self.transform_mismatches_class(fade_source, fade_target, **kwargs)) + else: + anims.append(FadeOutToPoint( + VGroup(*rest_source_submobs), target_mobject.get_center(), **kwargs + )) + anims.append(FadeInFromPoint( + VGroup(*rest_target_submobs), source_mobject.get_center(), **kwargs + )) super().__init__(*anims) diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index a14004e7..efc542b2 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -5,8 +5,9 @@ import colour import itertools as it from types import MethodType from typing import Iterable, Union, Sequence +from abc import abstractmethod -from manimlib.constants import BLACK, WHITE +from manimlib.constants import WHITE from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.utils.color import color_to_int_rgb @@ -24,17 +25,15 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from manimlib.mobject.types.vectorized_mobject import VMobject ManimColor = Union[str, colour.Color, Sequence[float]] + Span = tuple[int, int] SCALE_FACTOR_PER_FONT_POINT = 0.001 -class _TexSVG(SVGMobject): +class _StringSVG(SVGMobject): CONFIG = { "height": None, - "svg_default": { - "fill_color": BLACK, - }, "stroke_width": 0, "stroke_color": WHITE, "path_string_config": { @@ -44,75 +43,29 @@ class _TexSVG(SVGMobject): } -class MTex(_TexSVG): +class LabelledString(_StringSVG): + """ + An abstract base class for `MTex` and `MarkupText` + """ CONFIG = { "base_color": WHITE, - "font_size": 48, - "alignment": "\\centering", - "tex_environment": "align*", - "isolate": [], - "tex_to_color_map": {}, - "use_plain_tex": False, + "use_plain_file": False, } def __init__(self, string: str, **kwargs): - digest_config(self, kwargs) - string = string.strip() - # Prevent from passing an empty string. - if not string: - string = "\\quad" - self.tex_string = string self.string = string super().__init__(**kwargs) - self.set_color_by_tex_to_color_map(self.tex_to_color_map) - self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size) - - @property - def hash_seed(self) -> tuple: - return ( - self.__class__.__name__, - self.svg_default, - self.path_string_config, - self.string, - self.base_color, - self.alignment, - self.tex_environment, - self.isolate, - self.tex_to_color_map, - self.use_plain_tex - ) - def get_file_path(self, use_plain_file: bool = False) -> str: if use_plain_file: content = self.plain_string else: content = self.labelled_string + return self.get_file_path_by_content(content) - full_tex = self.get_tex_file_body(content) - with display_during_execution(f"Writing \"{self.string}\""): - file_path = self.tex_to_svg_file_path(full_tex) - return file_path - - def get_tex_file_body(self, content: str) -> str: - if self.tex_environment: - content = "\n".join([ - f"\\begin{{{self.tex_environment}}}", - content, - f"\\end{{{self.tex_environment}}}" - ]) - if self.alignment: - content = "\n".join([self.alignment, content]) - - tex_config = get_tex_config() - return tex_config["tex_body"].replace( - tex_config["text_to_replace"], - content - ) - - @staticmethod - def tex_to_svg_file_path(tex_file_content: str) -> str: - return tex_to_svg_file(tex_file_content) + @abstractmethod + def get_file_path_by_content(self, content: str) -> str: + return "" def generate_mobject(self) -> None: super().generate_mobject() @@ -125,13 +78,9 @@ class MTex(_TexSVG): for glyph in self.submobjects ] - if any([ - self.use_plain_tex, - self.color_cmd_repl_items, - self.base_color in (BLACK, WHITE) - ]): + if self.use_plain_file or self.has_predefined_colors: file_path = self.get_file_path(use_plain_file=True) - glyphs = _TexSVG(file_path).submobjects + glyphs = _StringSVG(file_path).submobjects for glyph, plain_glyph in zip(self.submobjects, glyphs): glyph.set_fill(plain_glyph.get_fill_color()) else: @@ -142,21 +91,19 @@ class MTex(_TexSVG): submob_labels, glyphs_lists = self.group_neighbours( glyph_labels, glyphs ) - submobjects = [ - VGroup(*glyph_list) - for glyph_list in glyphs_lists - ] - submob_tex_strings = self.get_submob_tex_strings(submob_labels) - for submob, label, submob_tex in zip( - submobjects, submob_labels, submob_tex_strings + submob_strings = self.get_submob_strings(submob_labels) + submobjects = [] + for glyph_list, label, submob_string in zip( + glyphs_lists, submob_labels, submob_strings ): + submob = VGroup(*glyph_list) submob.label = label - submob.tex_string = submob_tex - # Support `get_tex()` method here. - submob.get_tex = MethodType(lambda inst: inst.tex_string, submob) + submob.string = submob_string + submob.get_string = MethodType(lambda inst: inst.string, submob) + submobjects.append(submob) self.set_submobjects(submobjects) - ## Static methods + # Toolkits @staticmethod def color_to_label(color: ManimColor) -> int: @@ -167,23 +114,14 @@ class MTex(_TexSVG): return -1 return rgb - @staticmethod - def get_color_command(label: int) -> str: - if label == -1: - label = 16777215 # white - rg, b = divmod(label, 256) - r, g = divmod(rg, 256) - return "".join([ - "\\color[RGB]", - "{", - ",".join(map(str, (r, g, b))), - "}" - ]) - @staticmethod def get_neighbouring_pairs(iterable: Iterable) -> list: return list(adjacent_pairs(iterable))[:-1] + @staticmethod + def span_contains(span_0: Span, span_1: Span) -> bool: + return span_0[0] <= span_1[0] and span_0[1] >= span_1[1] + @staticmethod def group_neighbours( labels: Iterable[object], @@ -211,79 +149,412 @@ class MTex(_TexSVG): @staticmethod def find_region_index(val: int, seq: list[int]) -> int: - # Returns an integer in `range(len(seq) + 1)` satisfying - # `seq[result - 1] <= val < seq[result]` - if not seq: - return 0 - if val >= seq[-1]: - return len(seq) - result = 0 - while val >= seq[result]: - result += 1 + # Returns an integer in `range(-1, len(seq))` satisfying + # `seq[result] <= val < seq[result + 1]`. + # `seq` should be sorted in ascending order. + if not seq or val < seq[0]: + return -1 + result = len(seq) - 1 + while val < seq[result]: + result -= 1 return result @staticmethod - def lstrip(index: int, skipped_spans: list[tuple[int, int]]) -> int: - index_seq = list(it.chain(*skipped_spans)) - region_index = MTex.find_region_index(index, index_seq) - if region_index % 2 == 1: + def replace_str_by_spans( + substr: str, span_repl_dict: dict[Span, str] + ) -> str: + if not span_repl_dict: + return substr + + spans = sorted(span_repl_dict.keys()) + if not all( + span_0[1] <= span_1[0] + for span_0, span_1 in LabelledString.get_neighbouring_pairs(spans) + ): + raise ValueError("Overlapping replacement") + + span_ends, span_begins = zip(*spans) + pieces = [ + substr[slice(*span)] + for span in zip( + (0, *span_begins), + (*span_ends, len(substr)) + ) + ] + repl_strs = [*[span_repl_dict[span] for span in spans], ""] + return "".join(it.chain(*zip(pieces, repl_strs))) + + @staticmethod + def get_span_replacement_dict( + inserted_string_pairs: list[tuple[Span, tuple[str, str]]], + other_repl_items: list[tuple[Span, str]] + ) -> dict[Span, str]: + if not inserted_string_pairs: + return other_repl_items.copy() + + indices, _, _, inserted_strings = zip(*sorted([ + ( + span[flag], + -flag, + -span[1 - flag], + str_pair[flag] + ) + for span, str_pair in inserted_string_pairs + for flag in range(2) + ])) + result = { + (index, index): "".join(inserted_strs) + for index, inserted_strs in zip(*LabelledString.group_neighbours( + indices, inserted_strings + )) + } + result.update(other_repl_items) + return result + + @property + def skipped_spans(self) -> list[Span]: + return [] + + def lstrip(self, index: int) -> int: + index_seq = list(it.chain(*self.skipped_spans)) + region_index = self.find_region_index(index, index_seq) + if region_index % 2 == 0: + return index_seq[region_index + 1] + return index + + def rstrip(self, index: int) -> int: + index_seq = list(it.chain(*self.skipped_spans)) + region_index = self.find_region_index(index - 1, index_seq) + if region_index % 2 == 0: return index_seq[region_index] return index - @staticmethod - def rstrip(index: int, skipped_spans: list[tuple[int, int]]) -> int: - index_seq = list(it.chain(*skipped_spans)) - region_index = MTex.find_region_index(index - 1, index_seq) - if region_index % 2 == 1: - return index_seq[region_index - 1] - return index - - @staticmethod - def strip( - tex_span: tuple[int, int], skipped_spans: list[tuple[int, int]] - ) -> tuple[int, int] | None: + def strip(self, span: Span) -> Span | None: result = ( - MTex.lstrip(tex_span[0], skipped_spans), - MTex.rstrip(tex_span[1], skipped_spans) + self.lstrip(span[0]), + self.rstrip(span[1]) ) if result[0] >= result[1]: return None return result @staticmethod - def lslide(index: int, slid_spans: list[tuple[int, int]]) -> int: - slide_dict = dict(slid_spans) + def lslide(index: int, slid_spans: list[Span]) -> int: + slide_dict = dict(sorted(slid_spans)) while index in slide_dict.keys(): index = slide_dict[index] return index @staticmethod - def rslide(index: int, slid_spans: list[tuple[int, int]]) -> int: - slide_dict = dict([ + def rslide(index: int, slid_spans: list[Span]) -> int: + slide_dict = dict(sorted([ slide_span[::-1] for slide_span in slid_spans - ]) + ], reverse=True)) while index in slide_dict.keys(): index = slide_dict[index] return index @staticmethod - def slide( - tex_span: tuple[int, int], slid_spans: list[tuple[int, int]] - ) -> tuple[int, int] | None: + def slide(span: Span, slid_spans: list[Span]) -> Span | None: result = ( - MTex.lslide(tex_span[0], slid_spans), - MTex.rslide(tex_span[1], slid_spans) + LabelledString.lslide(span[0], slid_spans), + LabelledString.rslide(span[1], slid_spans) ) if result[0] >= result[1]: return None return result - ## Parser + # Parser @property - def full_span(self) -> tuple[int, int]: + def full_span(self) -> Span: return (0, len(self.string)) + def get_substrs_to_isolate(self, substrs: list[str]) -> list[str]: + result = list(filter( + lambda s: s in self.string, + remove_list_redundancies(substrs) + )) + if "" in result: + result.remove("") + return result + + @property + def label_span_list(self) -> list[Span]: + return [] + + @property + def inserted_string_pairs(self) -> list[tuple[Span, tuple[str, str]]]: + return [] + + @property + def command_repl_items(self) -> list[tuple[Span, str]]: + return [] + + @abstractmethod + def has_predefined_colors(self) -> bool: + return False + + @property + def plain_string(self) -> str: + return self.string + + @property + def labelled_string(self) -> str: + return self.replace_str_by_spans( + self.string, self.get_span_replacement_dict( + self.inserted_string_pairs, + self.command_repl_items + ) + ) + + @property + def ignored_indices_for_submob_strings(self) -> list[int]: + return [] + + def handle_submob_string(self, substr: str, string_span: Span) -> str: + return substr + + def get_submob_strings(self, submob_labels: list[int]) -> list[str]: + ordered_spans = [ + self.label_span_list[label] if label != -1 else self.full_span + for label in submob_labels + ] + ordered_containing_labels = [ + self.containing_labels_dict[span] + for span in ordered_spans + ] + ordered_span_begins, ordered_span_ends = zip(*ordered_spans) + string_span_begins = [ + prev_end if prev_label in containing_labels else curr_begin + for prev_end, prev_label, containing_labels, curr_begin in zip( + ordered_span_ends[:-1], submob_labels[:-1], + ordered_containing_labels[1:], ordered_span_begins[1:] + ) + ] + string_span_ends = [ + next_begin if next_label in containing_labels else curr_end + for next_begin, next_label, containing_labels, curr_end in zip( + ordered_span_begins[1:], submob_labels[1:], + ordered_containing_labels[:-1], ordered_span_ends[:-1] + ) + ] + string_spans = list(zip( + (ordered_span_begins[0], *string_span_begins), + (*string_span_ends, ordered_span_ends[-1]) + )) + + command_spans = [span for span, _ in self.command_repl_items] + slid_spans = list(it.chain( + self.skipped_spans, + command_spans, + [ + (index, index + 1) + for index in self.ignored_indices_for_submob_strings + ] + )) + result = [] + for string_span in string_spans: + string_span = self.slide(string_span, slid_spans) + if string_span is None: + result.append("") + continue + + span_repl_dict = { + tuple([index - string_span[0] for index in cmd_span]): "" + for cmd_span in command_spans + if self.span_contains(string_span, cmd_span) + } + substr = self.string[slice(*string_span)] + substr = self.replace_str_by_spans(substr, span_repl_dict) + substr = self.handle_submob_string(substr, string_span) + result.append(substr) + return result + + # Selector + + @property + def containing_labels_dict(self) -> dict[Span, list[int]]: + label_span_list = self.label_span_list + result = { + span: [] + for span in label_span_list + } + for span_0 in label_span_list: + for span_index, span_1 in enumerate(label_span_list): + if self.span_contains(span_0, span_1): + result[span_0].append(span_index) + elif span_0[0] < span_1[0] < span_0[1] < span_1[1]: + string_0, string_1 = [ + self.string[slice(*span)] + for span in [span_0, span_1] + ] + raise ValueError( + "Partially overlapping substrings detected: " + f"'{string_0}' and '{string_1}'" + ) + result[self.full_span] = list(range(-1, len(label_span_list))) + return result + + def find_span_components_of_custom_span( + self, custom_span: Span + ) -> list[Span] | None: + span_choices = sorted(filter( + lambda span: self.span_contains(custom_span, span), + self.label_span_list + )) + # Choose spans that reach the farthest. + span_choices_dict = dict(span_choices) + + result = [] + span_begin, span_end = custom_span + span_begin = self.rstrip(span_begin) + span_end = self.rstrip(span_end) + while span_begin != span_end: + span_begin = self.lstrip(span_begin) + if span_begin not in span_choices_dict.keys(): + return None + next_begin = span_choices_dict[span_begin] + result.append((span_begin, next_begin)) + span_begin = next_begin + return result + + def get_part_by_custom_span(self, custom_span: Span) -> VGroup: + spans = self.find_span_components_of_custom_span(custom_span) + if spans is None: + substr = self.string[slice(*custom_span)] + raise ValueError(f"Failed to match mobjects from \"{substr}\"") + + labels = set(it.chain(*[ + self.containing_labels_dict[span] + for span in spans + ])) + return VGroup(*filter( + lambda submob: submob.label in labels, + self.submobjects + )) + + def get_parts_by_string(self, substr: str) -> VGroup: + return VGroup(*[ + self.get_part_by_custom_span(match_obj.span()) + for match_obj in re.finditer(re.escape(substr), self.string) + ]) + + def get_part_by_string(self, substr: str, index: int = 0) -> VMobject: + all_parts = self.get_parts_by_string(substr) + return all_parts[index] + + def set_color_by_string(self, substr: str, color: ManimColor): + self.get_parts_by_string(substr).set_color(color) + return self + + def set_color_by_string_to_color_map( + self, string_to_color_map: dict[str, ManimColor] + ): + for substr, color in string_to_color_map.items(): + self.set_color_by_string(substr, color) + return self + + def indices_of_part(self, part: Iterable[VMobject]) -> list[int]: + indices = [ + index for index, submob in enumerate(self.submobjects) + if submob in part + ] + if not indices: + raise ValueError("Failed to find part") + return indices + + def indices_of_part_by_string( + self, substr: str, index: int = 0 + ) -> list[int]: + part = self.get_part_by_string(substr, index=index) + return self.indices_of_part(part) + + @property + def specified_substrings(self) -> list[str]: + return [] + + def get_specified_substrings(self) -> list[str]: + return self.specified_substrings + + @property + def isolated_substrings(self) -> list[str]: + return remove_list_redundancies([ + self.string[slice(*span)] + for span in self.label_span_list + ]) + + def get_isolated_substrings(self) -> list[str]: + return self.isolated_substrings + + def get_string(self) -> str: + return self.string + + +class MTex(LabelledString): + CONFIG = { + "font_size": 48, + "alignment": "\\centering", + "tex_environment": "align*", + "isolate": [], + "tex_to_color_map": {}, + "use_plain_file": False, + } + + def __init__(self, tex_string: str, **kwargs): + tex_string = tex_string.strip() + # Prevent from passing an empty string. + if not tex_string: + tex_string = "\\quad" + self.tex_string = tex_string + super().__init__(tex_string, **kwargs) + + self.set_color_by_tex_to_color_map(self.tex_to_color_map) + self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size) + + @property + def hash_seed(self) -> tuple: + return ( + self.__class__.__name__, + self.svg_default, + self.path_string_config, + self.tex_string, + self.base_color, + self.alignment, + self.tex_environment, + self.isolate, + self.tex_to_color_map, + self.use_plain_file + ) + + def get_file_path_by_content(self, content: str) -> str: + full_tex = self.get_tex_file_body(content) + with display_during_execution(f"Writing \"{self.string}\""): + file_path = self.tex_to_svg_file_path(full_tex) + return file_path + + def get_tex_file_body(self, content: str) -> str: + if self.tex_environment: + content = "\n".join([ + f"\\begin{{{self.tex_environment}}}", + content, + f"\\end{{{self.tex_environment}}}" + ]) + if self.alignment: + content = "\n".join([self.alignment, content]) + + tex_config = get_tex_config() + return tex_config["tex_body"].replace( + tex_config["text_to_replace"], + content + ) + + @staticmethod + def tex_to_svg_file_path(tex_file_content: str) -> str: + return tex_to_svg_file(tex_file_content) + + # Parser + @property def backslash_indices(self) -> list[int]: # Newlines (`\\`) don't count. @@ -293,9 +564,7 @@ class MTex(_TexSVG): if len(match_obj.group()) % 2 == 1 ] - def get_left_and_right_brace_indices( - self - ) -> tuple[list[tuple[int, int]], list[tuple[int, int]]]: + def get_brace_indices_lists(self) -> tuple[list[Span], list[Span]]: string = self.string indices = list(filter( lambda index: index - 1 not in self.backslash_indices, @@ -322,15 +591,15 @@ class MTex(_TexSVG): return left_brace_indices, right_brace_indices @property - def left_brace_indices(self) -> list[tuple[int, int]]: - return self.get_left_and_right_brace_indices()[0] + def left_brace_indices(self) -> list[Span]: + return self.get_brace_indices_lists()[0] @property - def right_brace_indices(self) -> list[tuple[int, int]]: - return self.get_left_and_right_brace_indices()[1] + def right_brace_indices(self) -> list[Span]: + return self.get_brace_indices_lists()[1] @property - def skipped_spans(self) -> list[tuple[int, int]]: + def skipped_spans(self) -> list[Span]: return [ match_obj.span() for match_obj in re.finditer(r"\s*([_^])\s*|(\s+)", self.string) @@ -338,24 +607,15 @@ class MTex(_TexSVG): or match_obj.start(1) - 1 not in self.backslash_indices ] - def lstrip_span(self, index: int) -> int: - return self.lstrip(index, self.skipped_spans) - - def rstrip_span(self, index: int) -> int: - return self.rstrip(index, self.skipped_spans) - - def strip_span(self, index: int) -> int: - return self.strip(index, self.skipped_spans) - @property - def script_char_spans(self) -> list[tuple[int, int]]: + def script_char_spans(self) -> list[Span]: return list(filter( - lambda tex_span: self.string[slice(*tex_span)].strip(), + lambda span: self.string[slice(*span)].strip(), self.skipped_spans )) @property - def script_content_spans(self) -> list[tuple[int, int]]: + def script_content_spans(self) -> list[Span]: result = [] brace_indices_dict = dict(zip( self.left_brace_indices, self.right_brace_indices @@ -380,7 +640,7 @@ class MTex(_TexSVG): return result @property - def double_braces_spans(self) -> list[tuple[int, int]]: + def double_braces_spans(self) -> list[Span]: # Match paired double braces (`{{...}}`). result = [] reversed_brace_indices_dict = dict(zip( @@ -405,19 +665,12 @@ class MTex(_TexSVG): @property def additional_substrings(self) -> list[str]: - result = remove_list_redundancies(list(it.chain( + return self.get_substrs_to_isolate(list(it.chain( self.tex_to_color_map.keys(), self.isolate ))) - if "" in result: - result.remove("") - return result - def get_tex_span_lists( - self - ) -> tuple[list[tuple[int, int]], list[tuple[int, int]]]: - result = [] - extended_result = [] + def get_label_span_list(self, extended: bool) -> list[Span]: script_content_spans = self.script_content_spans script_spans = [ (script_char_span[0], script_content_span[1]) @@ -425,65 +678,54 @@ class MTex(_TexSVG): self.script_char_spans, script_content_spans ) ] - tex_spans = remove_list_redundancies([ + spans = remove_list_redundancies([ self.full_span, *self.double_braces_spans, *filter(lambda stripped_span: stripped_span is not None, [ - self.strip_span(match_obj.span()) + self.strip(match_obj.span()) for substr in self.additional_substrings for match_obj in re.finditer(re.escape(substr), self.string) ]), *script_content_spans ]) - for tex_span in tex_spans: - if tex_span in script_content_spans: - result.append(tex_span) - extended_result.append(tex_span) + result = [] + for span in spans: + if span in script_content_spans: continue - span_begin, span_end = tex_span - shrinked_span = (span_begin, self.rslide(span_end, script_spans)) - extended_span = (span_begin, self.lslide(span_end, script_spans)) - if shrinked_span[0] >= shrinked_span[1]: + span_begin, span_end = span + shrinked_end = self.rslide(span_end, script_spans) + if span_begin >= shrinked_end: continue + shrinked_span = (span_begin, shrinked_end) if shrinked_span in result: continue result.append(shrinked_span) - extended_result.append(extended_span) - return result, extended_result + + if extended: + result = [ + (span_begin, self.lslide(span_end, script_spans)) + for span_begin, span_end in result + ] + return script_content_spans + result @property - def tex_span_list(self) -> list[tuple[int, int]]: - return self.get_tex_span_lists()[0] + def label_span_list(self) -> list[Span]: + return self.get_label_span_list(extended=False) @property - def extended_tex_span_list(self) -> list[tuple[int, int]]: - return self.get_tex_span_lists()[1] + def inserted_string_pairs(self) -> list[tuple[Span, tuple[str, str]]]: + return [ + (span, ( + "{{" + self.get_color_command_by_label(label), + "}}" + )) + for label, span in enumerate( + self.get_label_span_list(extended=True) + ) + ] @property - def containing_labels_dict(self) -> dict[tuple[int, int], list[int]]: - tex_span_list = self.tex_span_list - result = { - tex_span: [] - for tex_span in tex_span_list - } - for span_0 in tex_span_list: - for span_index, span_1 in enumerate(tex_span_list): - if span_0[0] <= span_1[0] and span_1[1] <= span_0[1]: - result[span_0].append(span_index) - elif span_0[0] < span_1[0] < span_0[1] < span_1[1]: - string_0, string_1 = [ - self.string[slice(*tex_span)] - for tex_span in [span_0, span_1] - ] - raise ValueError( - "Partially overlapping substrings detected: " - f"'{string_0}' and '{string_1}'" - ) - result[self.full_span] = list(range(-1, len(tex_span_list))) - return result - - @property - def color_cmd_repl_items(self) -> list[tuple[tuple[int, int], str]]: + def command_repl_items(self) -> list[tuple[Span, str]]: color_related_command_dict = { "color": (1, False), "textcolor": (1, False), @@ -509,7 +751,7 @@ class MTex(_TexSVG): n_braces, substitute_cmd = color_related_command_dict[cmd_name] span_end = right_brace_indices[self.find_region_index( cmd_end, right_brace_indices - ) + n_braces - 1] + 1 + ) + n_braces] + 1 if substitute_cmd: repl_str = "\\" + cmd_name + n_braces * "{white}" else: @@ -518,237 +760,84 @@ class MTex(_TexSVG): return result @property - def span_repl_dict(self) -> dict[tuple[int, int], str]: - indices, _, _, cmd_strings = zip(*sorted([ - ( - tex_span[flag], - -flag, - -tex_span[1 - flag], - ("{{" + self.get_color_command(label), "}}")[flag] - ) - for label, tex_span in enumerate(self.extended_tex_span_list) - for flag in range(2) - ])) - result = { - (index, index): "".join(cmd_strs) - for index, cmd_strs in zip(*self.group_neighbours( - indices, cmd_strings - )) - } - result.update(self.color_cmd_repl_items) - return result + def has_predefined_colors(self) -> bool: + return bool(self.command_repl_items) + + @staticmethod + def get_color_command_by_label(label: int) -> str: + if label == -1: + label = 16777215 # white + rg, b = divmod(label, 256) + r, g = divmod(rg, 256) + return "".join([ + "\\color[RGB]", + "{", + ",".join(map(str, (r, g, b))), + "}" + ]) @property def plain_string(self) -> str: return "".join([ "{{", - self.get_color_command(self.color_to_label(self.base_color)), + self.get_color_command_by_label( + self.color_to_label(self.base_color) + ), self.string, "}}" ]) @property - def labelled_string(self) -> str: - if not self.span_repl_dict: - return self.string + def ignored_indices_for_submob_strings(self) -> list[int]: + return self.left_brace_indices + self.right_brace_indices - spans = sorted(self.span_repl_dict.keys()) - if not all( - span_0[1] <= span_1[0] - for span_0, span_1 in self.get_neighbouring_pairs(spans) - ): - raise ValueError("Failed to generate the labelled string") - - span_ends, span_begins = zip(*spans) - string_pieces = [ - self.string[slice(*span)] - for span in zip( - (0, *span_begins), - (*span_ends, len(self.string)) - ) - ] - repl_strs = [ - self.span_repl_dict[span] - for span in spans - ] - repl_strs.append("") - return "".join(it.chain(*zip(string_pieces, repl_strs))) - - def get_submob_tex_strings(self, submob_labels: list[int]) -> list[str]: - ordered_tex_spans = [ - self.tex_span_list[label] if label != -1 else self.full_span - for label in submob_labels - ] - ordered_containing_labels = [ - self.containing_labels_dict[tex_span] - for tex_span in ordered_tex_spans - ] - ordered_span_begins, ordered_span_ends = zip(*ordered_tex_spans) - string_span_begins = [ - prev_end if prev_label in containing_labels else curr_begin - for prev_end, prev_label, containing_labels, curr_begin in zip( - ordered_span_ends[:-1], submob_labels[:-1], - ordered_containing_labels[1:], ordered_span_begins[1:] - ) - ] - string_span_ends = [ - next_begin if next_label in containing_labels else curr_end - for next_begin, next_label, containing_labels, curr_end in zip( - ordered_span_begins[1:], submob_labels[1:], - ordered_containing_labels[:-1], ordered_span_ends[:-1] - ) - ] - string_spans = list(zip( - (ordered_span_begins[0], *string_span_begins), - (*string_span_ends, ordered_span_ends[-1]) - )) - - string = self.string - left_brace_indices = self.left_brace_indices - right_brace_indices = self.right_brace_indices - slid_spans = self.skipped_spans + [ - (index, index + 1) - for index in left_brace_indices + right_brace_indices - ] - result = [] - for str_span in string_spans: - str_span = self.strip_span(str_span) - if str_span is None: - continue - str_span = self.slide(str_span, slid_spans) - if str_span is None: - continue - unclosed_left_braces = 0 - unclosed_right_braces = 0 - for index in range(*str_span): - if index in left_brace_indices: - unclosed_left_braces += 1 - elif index in right_brace_indices: - if unclosed_left_braces == 0: - unclosed_right_braces += 1 - else: - unclosed_left_braces -= 1 - result.append("".join([ - unclosed_right_braces * "{", - string[slice(*str_span)], - unclosed_left_braces * "}" - ])) - return result + def handle_submob_string(self, substr: str, string_span: Span) -> str: + unclosed_left_braces = 0 + unclosed_right_braces = 0 + for index in range(*string_span): + if index in self.left_brace_indices: + unclosed_left_braces += 1 + elif index in self.right_brace_indices: + if unclosed_left_braces == 0: + unclosed_right_braces += 1 + else: + unclosed_left_braces -= 1 + return "".join([ + unclosed_right_braces * "{", + substr, + unclosed_left_braces * "}" + ]) @property def specified_substrings(self) -> list[str]: return remove_list_redundancies([ self.string[slice(*double_braces_span)] for double_braces_span in self.double_braces_spans - ] + list(filter( - lambda s: s in self.string, - self.additional_substrings - ))) + ] + self.additional_substrings) - def get_specified_substrings(self) -> list[str]: - return self.specified_substrings + # Method alias - @property - def isolated_substrings(self) -> list[str]: - return remove_list_redundancies([ - self.string[slice(*tex_span)] - for tex_span in self.tex_span_list - ]) + def get_parts_by_tex(self, substr: str) -> VGroup: + return self.get_parts_by_string(substr) - def get_isolated_substrings(self) -> list[str]: - return self.isolated_substrings + def get_part_by_tex(self, substr: str, index: int = 0) -> VMobject: + return self.get_part_by_string(substr, index) - ## Selector - - def find_span_components_of_custom_span( - self, - custom_span: tuple[int, int] - ) -> list[tuple[int, int]] | None: - tex_span_choices = sorted(filter( - lambda tex_span: all([ - tex_span[0] >= custom_span[0], - tex_span[1] <= custom_span[1] - ]), - self.tex_span_list - )) - # Choose spans that reach the farthest. - tex_span_choices_dict = dict(tex_span_choices) - - result = [] - span_begin, span_end = custom_span - span_begin = self.rstrip_span(span_begin) - span_end = self.rstrip_span(span_end) - while span_begin != span_end: - span_begin = self.lstrip_span(span_begin) - if span_begin not in tex_span_choices_dict.keys(): - return None - next_begin = tex_span_choices_dict[span_begin] - result.append((span_begin, next_begin)) - span_begin = next_begin - return result - - def get_part_by_custom_span(self, custom_span: tuple[int, int]) -> VGroup: - tex_spans = self.find_span_components_of_custom_span( - custom_span - ) - if tex_spans is None: - tex = self.string[slice(*custom_span)] - raise ValueError(f"Failed to match mobjects from tex: \"{tex}\"") - - labels = set(it.chain(*[ - self.containing_labels_dict[tex_span] - for tex_span in tex_spans - ])) - return VGroup(*filter( - lambda submob: submob.label in labels, - self.submobjects - )) - - def get_parts_by_tex(self, tex: str) -> VGroup: - return VGroup(*[ - self.get_part_by_custom_span(match_obj.span()) - for match_obj in re.finditer( - re.escape(tex), self.string - ) - ]) - - def get_part_by_tex(self, tex: str, index: int = 0) -> VMobject: - all_parts = self.get_parts_by_tex(tex) - return all_parts[index] - - def set_color_by_tex(self, tex: str, color: ManimColor): - self.get_parts_by_tex(tex).set_color(color) - return self + def set_color_by_tex(self, substr: str, color: ManimColor): + return self.set_color_by_string(substr, color) def set_color_by_tex_to_color_map( - self, - tex_to_color_map: dict[str, ManimColor] + self, tex_to_color_map: dict[str, ManimColor] ): - for tex, color in tex_to_color_map.items(): - self.set_color_by_tex(tex, color) - return self + return self.set_color_by_string_to_color_map(tex_to_color_map) - def indices_of_part(self, part: Iterable[VMobject]) -> list[int]: - indices = [ - index for index, submob in enumerate(self.submobjects) - if submob in part - ] - if not indices: - raise ValueError("Failed to find part in tex") - return indices - - def indices_of_part_by_tex(self, tex: str, index: int = 0) -> list[int]: - part = self.get_part_by_tex(tex, index=index) - return self.indices_of_part(part) + def indices_of_part_by_tex( + self, substr: str, index: int = 0 + ) -> list[int]: + return self.indices_of_part_by_string(substr, index) def get_tex(self) -> str: - return self.string - - def get_submob_tex(self) -> list[str]: - return [ - submob.get_tex() - for submob in self.submobjects - ] + return self.get_string() class MTexText(MTex): diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index a13d1d80..24a5b111 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -2,11 +2,11 @@ from __future__ import annotations import os import re -import typing -from pathlib import Path - +import itertools as it import xml.sax.saxutils as saxutils +from pathlib import Path from contextlib import contextmanager +import typing from typing import Iterable, Sequence, Union import pygments @@ -17,198 +17,87 @@ from manimpango import MarkupUtils from manimlib.logger import log from manimlib.constants import * -from manimlib.mobject.geometry import Dot -from manimlib.mobject.svg.svg_mobject import SVGMobject +from manimlib.mobject.svg.mtex_mobject import LabelledString from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.utils.customization import get_customization from manimlib.utils.tex_file_writing import tex_hash from manimlib.utils.config_ops import digest_config from manimlib.utils.directories import get_downloads_dir from manimlib.utils.directories import get_text_dir +from manimlib.utils.iterables import remove_list_redundancies from typing import TYPE_CHECKING if TYPE_CHECKING: from manimlib.mobject.types.vectorized_mobject import VMobject + ManimColor = Union[str, colour.Color, Sequence[float]] + Span = tuple[int, int] TEXT_MOB_SCALE_FACTOR = 0.0076 DEFAULT_LINE_SPACING_SCALE = 0.6 -class _TextParser(object): - # See https://docs.gtk.org/Pango/pango_markup.html - # A tag containing two aliases will cause warning, - # so only use the first key of each group of aliases. - SPAN_ATTR_KEY_ALIAS_LIST = ( - ("font", "font_desc"), - ("font_family", "face"), - ("font_size", "size"), - ("font_style", "style"), - ("font_weight", "weight"), - ("font_variant", "variant"), - ("font_stretch", "stretch"), - ("font_features",), - ("foreground", "fgcolor", "color"), - ("background", "bgcolor"), - ("alpha", "fgalpha"), - ("background_alpha", "bgalpha"), - ("underline",), - ("underline_color",), - ("overline",), - ("overline_color",), - ("rise",), - ("baseline_shift",), - ("font_scale",), - ("strikethrough",), - ("strikethrough_color",), - ("fallback",), - ("lang",), - ("letter_spacing",), - ("gravity",), - ("gravity_hint",), - ("show",), - ("insert_hyphens",), - ("allow_breaks",), - ("line_height",), - ("text_transform",), - ("segment",), - ) - SPAN_ATTR_KEY_CONVERSION = { - key: key_alias_list[0] - for key_alias_list in SPAN_ATTR_KEY_ALIAS_LIST - for key in key_alias_list - } - - TAG_TO_ATTR_DICT = { - "b": {"font_weight": "bold"}, - "big": {"font_size": "larger"}, - "i": {"font_style": "italic"}, - "s": {"strikethrough": "true"}, - "sub": {"baseline_shift": "subscript", "font_scale": "subscript"}, - "sup": {"baseline_shift": "superscript", "font_scale": "superscript"}, - "small": {"font_size": "smaller"}, - "tt": {"font_family": "monospace"}, - "u": {"underline": "single"}, - } - - def __init__(self, text: str = "", is_markup: bool = True): - self.text = text - self.is_markup = is_markup - self.global_attrs = {} - self.local_attrs = {(0, len(self.text)): {}} - self.tag_strings = set() - if is_markup: - self.parse_markup() - - def parse_markup(self) -> None: - tag_pattern = r"""<(/?)(\w+)\s*((\w+\s*\=\s*('[^']*'|"[^"]*")\s*)*)>""" - attr_pattern = r"""(\w+)\s*\=\s*(?:(?:'([^']*)')|(?:"([^"]*)"))""" - start_match_obj_stack = [] - match_obj_pairs = [] - for match_obj in re.finditer(tag_pattern, self.text): - if not match_obj.group(1): - start_match_obj_stack.append(match_obj) - else: - match_obj_pairs.append((start_match_obj_stack.pop(), match_obj)) - self.tag_strings.add(match_obj.group()) - assert not start_match_obj_stack, "Unclosed tag(s) detected" - - for start_match_obj, end_match_obj in match_obj_pairs: - tag_name = start_match_obj.group(2) - assert tag_name == end_match_obj.group(2), "Unmatched tag names" - assert not end_match_obj.group(3), "Attributes shan't exist in ending tags" - if tag_name == "span": - attr_dict = { - match.group(1): match.group(2) or match.group(3) - for match in re.finditer(attr_pattern, start_match_obj.group(3)) - } - elif tag_name in _TextParser.TAG_TO_ATTR_DICT.keys(): - assert not start_match_obj.group(3), f"Attributes shan't exist in tag '{tag_name}'" - attr_dict = _TextParser.TAG_TO_ATTR_DICT[tag_name] - else: - raise AssertionError(f"Unknown tag: '{tag_name}'") - - text_span = (start_match_obj.end(), end_match_obj.start()) - self.update_local_attrs(text_span, attr_dict) - - @staticmethod - def convert_key_alias(key: str) -> str: - return _TextParser.SPAN_ATTR_KEY_CONVERSION[key] - - @staticmethod - def update_attr_dict(attr_dict: dict[str, str], key: str, value: typing.Any) -> None: - converted_key = _TextParser.convert_key_alias(key) - attr_dict[converted_key] = str(value) - - def update_global_attr(self, key: str, value: typing.Any) -> None: - _TextParser.update_attr_dict(self.global_attrs, key, value) - - def update_global_attrs(self, attr_dict: dict[str, typing.Any]) -> None: - for key, value in attr_dict.items(): - self.update_global_attr(key, value) - - def update_local_attr(self, span: tuple[int, int], key: str, value: typing.Any) -> None: - if span[0] >= span[1]: - log.warning(f"Span {span} doesn't match any part of the string") - return - - if span in self.local_attrs.keys(): - _TextParser.update_attr_dict(self.local_attrs[span], key, value) - return - - span_triplets = [] - for sp, attr_dict in self.local_attrs.items(): - if sp[1] <= span[0] or span[1] <= sp[0]: - continue - span_to_become = (max(sp[0], span[0]), min(sp[1], span[1])) - spans_to_add = [] - if sp[0] < span[0]: - spans_to_add.append((sp[0], span[0])) - if span[1] < sp[1]: - spans_to_add.append((span[1], sp[1])) - span_triplets.append((sp, span_to_become, spans_to_add)) - for span_to_remove, span_to_become, spans_to_add in span_triplets: - attr_dict = self.local_attrs.pop(span_to_remove) - for span_to_add in spans_to_add: - self.local_attrs[span_to_add] = attr_dict.copy() - self.local_attrs[span_to_become] = attr_dict - _TextParser.update_attr_dict(self.local_attrs[span_to_become], key, value) - - def update_local_attrs(self, text_span: tuple[int, int], attr_dict: dict[str, typing.Any]) -> None: - for key, value in attr_dict.items(): - self.update_local_attr(text_span, key, value) - - def remove_tags(self, string: str) -> str: - for tag_string in self.tag_strings: - string = string.replace(tag_string, "") - return string - - def get_text_pieces(self) -> list[tuple[str, dict[str, str]]]: - result = [] - for span in sorted(self.local_attrs.keys()): - text_piece = self.remove_tags(self.text[slice(*span)]) - if not text_piece: - continue - if not self.is_markup: - text_piece = saxutils.escape(text_piece) - attr_dict = self.global_attrs.copy() - attr_dict.update(self.local_attrs[span]) - result.append((text_piece, attr_dict)) - return result - - def get_markup_str_with_attrs(self) -> str: - return "".join([ - f"{text_piece}" - for text_piece, attr_dict in self.get_text_pieces() - ]) - - @staticmethod - def get_attr_dict_str(attr_dict: dict[str, str]) -> str: - return " ".join([ - f"{key}='{value}'" - for key, value in attr_dict.items() - ]) +# See https://docs.gtk.org/Pango/pango_markup.html +# A tag containing two aliases will cause warning, +# so only use the first key of each group of aliases. +SPAN_ATTR_KEY_ALIAS_LIST = ( + ("font", "font_desc"), + ("font_family", "face"), + ("font_size", "size"), + ("font_style", "style"), + ("font_weight", "weight"), + ("font_variant", "variant"), + ("font_stretch", "stretch"), + ("font_features",), + ("foreground", "fgcolor", "color"), + ("background", "bgcolor"), + ("alpha", "fgalpha"), + ("background_alpha", "bgalpha"), + ("underline",), + ("underline_color",), + ("overline",), + ("overline_color",), + ("rise",), + ("baseline_shift",), + ("font_scale",), + ("strikethrough",), + ("strikethrough_color",), + ("fallback",), + ("lang",), + ("letter_spacing",), + ("gravity",), + ("gravity_hint",), + ("show",), + ("insert_hyphens",), + ("allow_breaks",), + ("line_height",), + ("text_transform",), + ("segment",), +) +COLOR_RELATED_KEYS = ( + "foreground", + "background", + "underline_color", + "overline_color", + "strikethrough_color" +) +SPAN_ATTR_KEY_CONVERSION = { + key: key_alias_list[0] + for key_alias_list in SPAN_ATTR_KEY_ALIAS_LIST + for key in key_alias_list +} +TAG_TO_ATTR_DICT = { + "b": {"font_weight": "bold"}, + "big": {"font_size": "larger"}, + "i": {"font_style": "italic"}, + "s": {"strikethrough": "true"}, + "sub": {"baseline_shift": "subscript", "font_scale": "subscript"}, + "sup": {"baseline_shift": "superscript", "font_scale": "superscript"}, + "small": {"font_size": "smaller"}, + "tt": {"font_family": "monospace"}, + "u": {"underline": "single"}, +} # Temporary handler @@ -223,16 +112,9 @@ class _Alignment: self.value = _Alignment.VAL_DICT[s.upper()] -class Text(SVGMobject): +class MarkupText(LabelledString): CONFIG = { - # Mobject - "stroke_width": 0, - "svg_default": { - "color": WHITE, - }, - "height": None, - # Text - "is_markup": False, + "is_markup": True, "font_size": 48, "lsh": None, "justify": False, @@ -240,8 +122,6 @@ class Text(SVGMobject): "alignment": "LEFT", "line_width_factor": None, "font": "", - "disable_ligatures": True, - "apply_space_chars": True, "slant": NORMAL, "weight": NORMAL, "gradient": None, @@ -252,6 +132,7 @@ class Text(SVGMobject): "t2w": {}, "global_config": {}, "local_configs": {}, + "isolate": [], } def __init__(self, text: str, **kwargs): @@ -260,10 +141,15 @@ class Text(SVGMobject): validate_error = MarkupUtils.validate(text) if validate_error: raise ValueError(validate_error) - self.text = text - self.parser = _TextParser(text, is_markup=self.is_markup) - super().__init__(**kwargs) + self.text = text + super().__init__(text, **kwargs) + + if self.t2g: + log.warning( + "Manim currently cannot parse gradient from svg. " + "Please set gradient via `set_color_by_gradient`.", + ) if self.gradient: self.set_color_by_gradient(*self.gradient) if self.height is None: @@ -284,8 +170,6 @@ class Text(SVGMobject): self.alignment, self.line_width_factor, self.font, - self.disable_ligatures, - self.apply_space_chars, self.slant, self.weight, self.t2c, @@ -293,71 +177,32 @@ class Text(SVGMobject): self.t2s, self.t2w, self.global_config, - self.local_configs + self.local_configs, + self.isolate ) - def get_file_path(self) -> str: - full_markup = self.get_full_markup_str() + def full2short(self, config: dict) -> None: + conversion_dict = { + "line_spacing_height": "lsh", + "text2color": "t2c", + "text2font": "t2f", + "text2gradient": "t2g", + "text2slant": "t2s", + "text2weight": "t2w" + } + for kwargs in [config, self.CONFIG]: + for long_name, short_name in conversion_dict.items(): + if long_name in kwargs: + kwargs[short_name] = kwargs.pop(long_name) + + def get_file_path_by_content(self, content: str) -> str: svg_file = os.path.join( - get_text_dir(), tex_hash(full_markup) + ".svg" + get_text_dir(), tex_hash(content) + ".svg" ) if not os.path.exists(svg_file): - self.markup_to_svg(full_markup, svg_file) + self.markup_to_svg(content, svg_file) return svg_file - def get_full_markup_str(self) -> str: - if self.t2g: - log.warning( - "Manim currently cannot parse gradient from svg. " - "Please set gradient via `set_color_by_gradient`.", - ) - - config_style_dict = self.generate_config_style_dict() - global_attr_dict = { - "line_height": ((self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1) * 0.6, - "font_family": self.font or get_customization()["style"]["font"], - "font_size": self.font_size * 1024, - "font_style": self.slant, - "font_weight": self.weight, - # TODO, it seems this doesn't work - "font_features": "liga=0,dlig=0,clig=0,hlig=0" if self.disable_ligatures else None, - "foreground": config_style_dict.get("fill", None), - "alpha": config_style_dict.get("fill-opacity", None) - } - global_attr_dict = { - k: v - for k, v in global_attr_dict.items() - if v is not None - } - global_attr_dict.update(self.global_config) - self.parser.update_global_attrs(global_attr_dict) - - local_attr_items = [ - (word_or_text_span, {key: value}) - for t2x_dict, key in ( - (self.t2c, "foreground"), - (self.t2f, "font_family"), - (self.t2s, "font_style"), - (self.t2w, "font_weight") - ) - for word_or_text_span, value in t2x_dict.items() - ] - local_attr_items.extend(self.local_configs.items()) - for word_or_text_span, local_config in local_attr_items: - for text_span in self.find_indexes(word_or_text_span): - self.parser.update_local_attrs(text_span, local_config) - - return self.parser.get_markup_str_with_attrs() - - def find_indexes(self, word_or_text_span: str | tuple[int, int]) -> list[tuple[int, int]]: - if isinstance(word_or_text_span, tuple): - return [word_or_text_span] - - return [ - match_obj.span() - for match_obj in re.finditer(re.escape(word_or_text_span), self.text) - ] - def markup_to_svg(self, markup_str: str, file_name: str) -> str: # `manimpango` is under construction, # so the following code is intended to suit its interface @@ -374,7 +219,7 @@ class Text(SVGMobject): weight="NORMAL", # Already handled size=1, # Already handled _=0, # Empty parameter - disable_liga=False, # Already handled + disable_liga=False, # Need not to handle file_name=file_name, START_X=0, START_Y=0, @@ -387,63 +232,318 @@ class Text(SVGMobject): pango_width=pango_width ) - def generate_mobject(self) -> None: - super().generate_mobject() + # Toolkits - # Remove empty paths - submobjects = list(filter(lambda submob: submob.has_points(), self)) + @staticmethod + def get_attr_dict_str(attr_dict: dict[str, str]) -> str: + return " ".join([ + f"{key}='{value}'" + for key, value in attr_dict.items() + ]) - # Apply space characters - if self.apply_space_chars: - content_str = self.parser.remove_tags(self.text) - if self.is_markup: - content_str = saxutils.unescape(content_str) - for match_obj in re.finditer(r"\s", content_str): - char_index = match_obj.start() - space = Dot(radius=0, fill_opacity=0, stroke_opacity=0) - space.move_to(submobjects[max(char_index - 1, 0)].get_center()) - submobjects.insert(char_index, space) - self.set_submobjects(submobjects) + @staticmethod + def get_begin_tag_str(attr_dict: dict[str, str]) -> str: + return f"" - def full2short(self, config: dict) -> None: - conversion_dict = { - "line_spacing_height": "lsh", - "text2color": "t2c", - "text2font": "t2f", - "text2gradient": "t2g", - "text2slant": "t2s", - "text2weight": "t2w" - } - for kwargs in [config, self.CONFIG]: - for long_name, short_name in conversion_dict.items(): - if long_name in kwargs: - kwargs[short_name] = kwargs.pop(long_name) + @staticmethod + def get_end_tag_str() -> str: + return "" - def get_parts_by_text(self, word: str) -> VGroup: - if self.is_markup: - log.warning( - "Slicing MarkupText via `get_parts_by_text`, " - "the result could be unexpected." - ) - elif not self.apply_space_chars: - log.warning( - "Slicing Text via `get_parts_by_text` without applying spaces, " - "the result could be unexpected." - ) - return VGroup(*( - self[i:j] - for i, j in self.find_indexes(word) + @staticmethod + def convert_attr_key(key: str) -> str: + return SPAN_ATTR_KEY_CONVERSION[key.lower()] + + @staticmethod + def convert_attr_val(val: typing.Any) -> str: + return str(val).lower() + + @staticmethod + def merge_attr_items( + attr_items: list[Span, str, str] + ) -> list[tuple[Span, dict[str, str]]]: + index_seq = [0] + attr_dict_list = [{}] + for span, key, value in attr_items: + if span[0] >= span[1]: + continue + region_indices = [ + MarkupText.find_region_index(index, index_seq) + for index in span + ] + for flag in (1, 0): + if index_seq[region_indices[flag]] == span[flag]: + continue + region_index = region_indices[flag] + index_seq.insert(region_index + 1, span[flag]) + attr_dict_list.insert( + region_index + 1, attr_dict_list[region_index].copy() + ) + region_indices[flag] += 1 + if flag == 0: + region_indices[1] += 1 + for attr_dict in attr_dict_list[slice(*region_indices)]: + attr_dict[key] = value + return list(zip( + MarkupText.get_neighbouring_pairs(index_seq), attr_dict_list[:-1] )) - def get_part_by_text(self, word: str) -> VMobject | None: - parts = self.get_parts_by_text(word) - return parts[0] if parts else None + # Parser + + @property + def tag_items_from_markup( + self + ) -> list[tuple[Span, Span, dict[str, str]]]: + if not self.is_markup: + return [] + + tag_pattern = r"""<(/?)(\w+)\s*((\w+\s*\=\s*('.*?'|".*?")\s*)*)>""" + attr_pattern = r"""(\w+)\s*\=\s*(?:(?:'(.*?)')|(?:"(.*?)"))""" + begin_match_obj_stack = [] + match_obj_pairs = [] + for match_obj in re.finditer(tag_pattern, self.string): + if not match_obj.group(1): + begin_match_obj_stack.append(match_obj) + else: + match_obj_pairs.append( + (begin_match_obj_stack.pop(), match_obj) + ) + if begin_match_obj_stack: + raise ValueError("Unclosed tag(s) detected") + + result = [] + for begin_match_obj, end_match_obj in match_obj_pairs: + tag_name = begin_match_obj.group(2) + if tag_name != end_match_obj.group(2): + raise ValueError("Unmatched tag names") + if end_match_obj.group(3): + raise ValueError("Attributes shan't exist in ending tags") + if tag_name == "span": + attr_dict = dict([ + ( + MarkupText.convert_attr_key(match.group(1)), + MarkupText.convert_attr_val( + match.group(2) or match.group(3) + ) + ) + for match in re.finditer( + attr_pattern, begin_match_obj.group(3) + ) + ]) + elif tag_name in TAG_TO_ATTR_DICT.keys(): + if begin_match_obj.group(3): + raise ValueError( + f"Attributes shan't exist in tag '{tag_name}'" + ) + attr_dict = TAG_TO_ATTR_DICT[tag_name].copy() + else: + raise ValueError(f"Unknown tag: '{tag_name}'") + + result.append( + (begin_match_obj.span(), end_match_obj.span(), attr_dict) + ) + return result + + @property + def global_attr_items_from_config(self) -> list[str, str]: + global_attr_dict = { + "line_height": ( + (self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1 + ) * 0.6, + "font_family": self.font or get_customization()["style"]["font"], + "font_size": self.font_size * 1024, + "font_style": self.slant, + "font_weight": self.weight + } + global_attr_dict = { + k: v + for k, v in global_attr_dict.items() + if v is not None + } + result = list(it.chain( + global_attr_dict.items(), + self.global_config.items() + )) + return [ + ( + self.convert_attr_key(key), + self.convert_attr_val(val) + ) + for key, val in result + ] + + @property + def local_attr_items_from_config(self) -> list[tuple[Span, str, str]]: + result = [ + (text_span, key, val) + for t2x_dict, key in ( + (self.t2c, "foreground"), + (self.t2f, "font_family"), + (self.t2s, "font_style"), + (self.t2w, "font_weight") + ) + for word_or_span, val in t2x_dict.items() + for text_span in self.find_spans(word_or_span) + ] + [ + (text_span, key, val) + for word_or_span, local_config in self.local_configs.items() + for text_span in self.find_spans(word_or_span) + for key, val in local_config.items() + ] + return [ + ( + text_span, + self.convert_attr_key(key), + self.convert_attr_val(val) + ) + for text_span, key, val in result + ] + + def find_spans(self, word_or_span: str | Span) -> list[Span]: + if isinstance(word_or_span, tuple): + return [word_or_span] + + return [ + match_obj.span() + for match_obj in re.finditer(re.escape(word_or_span), self.string) + ] + + @property + def skipped_spans(self) -> list[Span]: + return [ + match_obj.span() + for match_obj in re.finditer(r"\s+", self.string) + ] + + @property + def label_span_list(self) -> list[Span]: + breakup_indices = [ + index + for pattern in [ + r"\s+", + r"\b", + *[ + re.escape(substr) + for substr in self.get_substrs_to_isolate(self.isolate) + ] + ] + for match_obj in re.finditer(pattern, self.string) + for index in match_obj.span() + ] + breakup_indices = sorted(filter( + lambda index: not any([ + span[0] < index < span[1] + for span, _ in self.command_repl_items + ]), + remove_list_redundancies([ + *self.full_span, *breakup_indices + ]) + )) + return list(filter( + lambda span: self.string[slice(*span)].strip(), + self.get_neighbouring_pairs(breakup_indices) + )) + + @property + def predefined_items(self) -> list[Span, str, str]: + return list(it.chain( + [ + (self.full_span, key, val) + for key, val in self.global_attr_items_from_config + ], + sorted([ + ((begin_tag_span[0], end_tag_span[1]), key, val) + for begin_tag_span, end_tag_span, attr_dict + in self.tag_items_from_markup + for key, val in attr_dict.items() + ]), + self.local_attr_items_from_config + )) + + def get_inserted_string_pairs( + self, use_label: bool + ) -> list[tuple[Span, tuple[str, str]]]: + attr_items = self.predefined_items + if use_label: + attr_items = [ + (span, key, WHITE if key in COLOR_RELATED_KEYS else val) + for span, key, val in attr_items + ] + [ + (span, "foreground", "#{:06x}".format(label)) + for label, span in enumerate(self.label_span_list) + ] + return [ + (span, ( + self.get_begin_tag_str(attr_dict), + self.get_end_tag_str() + )) + for span, attr_dict in self.merge_attr_items(attr_items) + ] + + @property + def inserted_string_pairs(self) -> list[tuple[Span, tuple[str, str]]]: + return self.get_inserted_string_pairs(use_label=True) + + @property + def command_repl_items(self) -> list[tuple[Span, str]]: + return [ + (tag_span, "") + for begin_tag, end_tag, _ in self.tag_items_from_markup + for tag_span in (begin_tag, end_tag) + ] + + @property + def has_predefined_colors(self) -> bool: + return any([ + key in COLOR_RELATED_KEYS + for _, key, _ in self.predefined_items + ]) + + @property + def plain_string(self) -> str: + return "".join([ + self.get_begin_tag_str({"foreground": self.base_color}), + self.replace_str_by_spans( + self.string, self.get_span_replacement_dict( + self.get_inserted_string_pairs(use_label=False), + self.command_repl_items + ) + ), + self.get_end_tag_str() + ]) + + def handle_submob_string(self, substr: str, string_span: Span) -> str: + if self.is_markup: + substr = saxutils.unescape(substr) + return substr + + # Method alias + + def get_parts_by_text(self, substr: str) -> VGroup: + return self.get_parts_by_string(substr) + + def get_part_by_text(self, substr: str, index: int = 0) -> VMobject: + return self.get_part_by_string(substr, index) + + def set_color_by_text(self, substr: str, color: ManimColor): + return self.set_color_by_string(substr, color) + + def set_color_by_text_to_color_map( + self, text_to_color_map: dict[str, ManimColor] + ): + return self.set_color_by_string_to_color_map(text_to_color_map) + + def indices_of_part_by_text( + self, substr: str, index: int = 0 + ) -> list[int]: + return self.indices_of_part_by_string(substr, index) + + def get_text(self) -> str: + return self.get_string() -class MarkupText(Text): +class Text(MarkupText): CONFIG = { - "is_markup": True, - "apply_space_chars": False, + "is_markup": False, } @@ -461,7 +561,9 @@ class Code(MarkupText): digest_config(self, kwargs) self.code = code lexer = pygments.lexers.get_lexer_by_name(self.language) - formatter = pygments.formatters.PangoMarkupFormatter(style=self.code_style) + formatter = pygments.formatters.PangoMarkupFormatter( + style=self.code_style + ) markup = pygments.highlight(code, lexer, formatter) markup = re.sub(r"", "", markup) super().__init__(markup, **kwargs) From 0e31ff12e2a9d307ae170a38f3c5834998d14451 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Mon, 28 Mar 2022 18:54:43 +0800 Subject: [PATCH 07/48] Tiny fix for TransformMatchingString --- .../animation/transform_matching_parts.py | 118 ++++++++++-------- manimlib/mobject/svg/text_mobject.py | 2 +- 2 files changed, 67 insertions(+), 53 deletions(-) diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index 486007dd..2e45db37 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -159,7 +159,11 @@ class TransformMatchingString(AnimationGroup): "transform_mismatches_class": None, } - def __init__(self, source_mobject: LabelledString, target_mobject: LabelledString, **kwargs): + def __init__(self, + source_mobject: LabelledString, + target_mobject: LabelledString, + **kwargs + ): digest_config(self, kwargs) assert isinstance(source_mobject, LabelledString) assert isinstance(target_mobject, LabelledString) @@ -167,36 +171,37 @@ class TransformMatchingString(AnimationGroup): rest_source_submobs = source_mobject.submobjects.copy() rest_target_submobs = target_mobject.submobjects.copy() - def add_anim_from(anim_class, func, source_attr, target_attr=None): - if target_attr is None: - target_attr = source_attr - source_parts = func(source_mobject, source_attr) - target_parts = func(target_mobject, target_attr) - filtered_source_parts = [ - submob_part for submob_part in source_parts - if all([ - submob in rest_source_submobs - for submob in submob_part - ]) - ] - filtered_target_parts = [ - submob_part for submob_part in target_parts - if all([ - submob in rest_target_submobs - for submob in submob_part - ]) - ] - if not (filtered_source_parts and filtered_target_parts): - return - anims.append(anim_class( - VGroup(*filtered_source_parts), - VGroup(*filtered_target_parts), - **kwargs - )) - for submob in it.chain(*filtered_source_parts): - rest_source_submobs.remove(submob) - for submob in it.chain(*filtered_target_parts): - rest_target_submobs.remove(submob) + def add_anims_from(anim_class, func, source_attrs, target_attrs=None): + if target_attrs is None: + target_attrs = source_attrs.copy() + for source_attr, target_attr in zip(source_attrs, target_attrs): + source_parts = func(source_mobject, source_attr) + target_parts = func(target_mobject, target_attr) + filtered_source_parts = [ + submob_part for submob_part in source_parts + if all([ + submob in rest_source_submobs + for submob in submob_part + ]) + ] + filtered_target_parts = [ + submob_part for submob_part in target_parts + if all([ + submob in rest_target_submobs + for submob in submob_part + ]) + ] + if not (filtered_source_parts and filtered_target_parts): + return + anims.append(anim_class( + VGroup(*filtered_source_parts), + VGroup(*filtered_target_parts), + **kwargs + )) + for submob in it.chain(*filtered_source_parts): + rest_source_submobs.remove(submob) + for submob in it.chain(*filtered_target_parts): + rest_target_submobs.remove(submob) def get_submobs_from_keys(mobject, keys): if not isinstance(keys, tuple): @@ -218,45 +223,54 @@ class TransformMatchingString(AnimationGroup): mobject[i] for i in remove_list_redundancies(indices) ])) - for source_key, target_key in self.key_map.items(): - add_anim_from( - ReplacementTransform, get_submobs_from_keys, - source_key, target_key - ) + add_anims_from( + ReplacementTransform, get_submobs_from_keys, + self.key_map.keys(), self.key_map.values() + ) common_specified_substrings = sorted(list( set(source_mobject.get_specified_substrings()).intersection( target_mobject.get_specified_substrings() ) ), key=len, reverse=True) - for part_string in common_specified_substrings: - add_anim_from( - FadeTransformPieces, LabelledString.get_parts_by_string, part_string - ) + if "" in common_specified_substrings: + common_specified_substrings.remove("") + add_anims_from( + FadeTransformPieces, + LabelledString.get_parts_by_string, + common_specified_substrings + ) common_submob_strings = { source_submob.get_string() for source_submob in source_mobject }.intersection({ target_submob.get_string() for target_submob in target_mobject }) - for substr in common_submob_strings: - add_anim_from( - FadeTransformPieces, - lambda mobject, attr: VGroup(*[ - VGroup(mob) for mob in mobject - if mob.get_string() == attr - ]), - substr - ) + add_anims_from( + FadeTransformPieces, + lambda mobject, attr: VGroup(*[ + VGroup(mob) for mob in mobject + if mob.get_string() == attr + ]), + common_submob_strings + ) if self.transform_mismatches_class is not None: - anims.append(self.transform_mismatches_class(fade_source, fade_target, **kwargs)) + anims.append(self.transform_mismatches_class( + fade_source, + fade_target, + **kwargs + )) else: anims.append(FadeOutToPoint( - VGroup(*rest_source_submobs), target_mobject.get_center(), **kwargs + VGroup(*rest_source_submobs), + target_mobject.get_center(), + **kwargs )) anims.append(FadeInFromPoint( - VGroup(*rest_target_submobs), source_mobject.get_center(), **kwargs + VGroup(*rest_target_submobs), + source_mobject.get_center(), + **kwargs )) super().__init__(*anims) diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 24a5b111..c44de5cc 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -219,7 +219,7 @@ class MarkupText(LabelledString): weight="NORMAL", # Already handled size=1, # Already handled _=0, # Empty parameter - disable_liga=False, # Need not to handle + disable_liga=False, file_name=file_name, START_X=0, START_Y=0, From 45faa9063bd91a74a607fdba16848ffea2ba455e Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Mon, 28 Mar 2022 19:02:50 +0800 Subject: [PATCH 08/48] Add items for hash_seed --- manimlib/mobject/svg/mtex_mobject.py | 7 +++---- manimlib/mobject/svg/text_mobject.py | 2 ++ 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index efc542b2..c04edb56 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -498,7 +498,6 @@ class MTex(LabelledString): "tex_environment": "align*", "isolate": [], "tex_to_color_map": {}, - "use_plain_file": False, } def __init__(self, tex_string: str, **kwargs): @@ -518,13 +517,13 @@ class MTex(LabelledString): self.__class__.__name__, self.svg_default, self.path_string_config, - self.tex_string, self.base_color, + self.use_plain_file, + self.tex_string, self.alignment, self.tex_environment, self.isolate, - self.tex_to_color_map, - self.use_plain_file + self.tex_to_color_map ) def get_file_path_by_content(self, content: str) -> str: diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index c44de5cc..4fafb138 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -161,6 +161,8 @@ class MarkupText(LabelledString): self.__class__.__name__, self.svg_default, self.path_string_config, + self.base_color, + self.use_plain_file, self.text, self.is_markup, self.font_size, From 89e139009b4ad6bed242f265e8af6526477c8376 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Mon, 28 Mar 2022 19:17:40 +0800 Subject: [PATCH 09/48] Remove an error raising --- manimlib/mobject/svg/mtex_mobject.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index c04edb56..44014e78 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -398,7 +398,7 @@ class LabelledString(_StringSVG): def find_span_components_of_custom_span( self, custom_span: Span - ) -> list[Span] | None: + ) -> list[Span]: span_choices = sorted(filter( lambda span: self.span_contains(custom_span, span), self.label_span_list @@ -413,7 +413,7 @@ class LabelledString(_StringSVG): while span_begin != span_end: span_begin = self.lstrip(span_begin) if span_begin not in span_choices_dict.keys(): - return None + return [] next_begin = span_choices_dict[span_begin] result.append((span_begin, next_begin)) span_begin = next_begin @@ -421,10 +421,6 @@ class LabelledString(_StringSVG): def get_part_by_custom_span(self, custom_span: Span) -> VGroup: spans = self.find_span_components_of_custom_span(custom_span) - if spans is None: - substr = self.string[slice(*custom_span)] - raise ValueError(f"Failed to match mobjects from \"{substr}\"") - labels = set(it.chain(*[ self.containing_labels_dict[span] for span in spans From 82c972b946b63f89c7bbe2265e377c8124223f74 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Mon, 28 Mar 2022 19:31:19 +0800 Subject: [PATCH 10/48] Remove saxutils.unescape process --- manimlib/mobject/svg/text_mobject.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 4fafb138..fda326f1 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -3,7 +3,6 @@ from __future__ import annotations import os import re import itertools as it -import xml.sax.saxutils as saxutils from pathlib import Path from contextlib import contextmanager import typing @@ -513,11 +512,6 @@ class MarkupText(LabelledString): self.get_end_tag_str() ]) - def handle_submob_string(self, substr: str, string_span: Span) -> str: - if self.is_markup: - substr = saxutils.unescape(substr) - return substr - # Method alias def get_parts_by_text(self, substr: str) -> VGroup: From 7e8b3a4c6b6b05cc18a6b7dd0d685e6b06b49c69 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Tue, 29 Mar 2022 23:38:06 +0800 Subject: [PATCH 11/48] Refactor LabelledString --- .../animation/transform_matching_parts.py | 2 +- manimlib/mobject/svg/mtex_mobject.py | 554 ++++++++++-------- manimlib/mobject/svg/text_mobject.py | 170 ++++-- 3 files changed, 424 insertions(+), 302 deletions(-) diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index 2e45db37..8e1dd101 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -153,7 +153,7 @@ class TransformMatchingTex(TransformMatchingParts): return mobject.get_tex() -class TransformMatchingString(AnimationGroup): +class TransformMatchingStrings(AnimationGroup): CONFIG = { "key_map": dict(), "transform_mismatches_class": None, diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 44014e78..02c243f4 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -50,6 +50,7 @@ class LabelledString(_StringSVG): CONFIG = { "base_color": WHITE, "use_plain_file": False, + "isolate": [], } def __init__(self, string: str, **kwargs): @@ -57,10 +58,11 @@ class LabelledString(_StringSVG): super().__init__(**kwargs) def get_file_path(self, use_plain_file: bool = False) -> str: - if use_plain_file: - content = self.plain_string - else: - content = self.labelled_string + #if use_plain_file: + # content = self.plain_string + #else: + # content = self.labelled_string + content = self.get_decorated_string(use_plain_file=use_plain_file) return self.get_file_path_by_content(content) @abstractmethod @@ -87,6 +89,7 @@ class LabelledString(_StringSVG): glyphs = self.submobjects self.set_fill(self.base_color) + # TODO # Simply pack together adjacent mobjects with the same label. submob_labels, glyphs_lists = self.group_neighbours( glyph_labels, glyphs @@ -105,14 +108,12 @@ class LabelledString(_StringSVG): # Toolkits - @staticmethod - def color_to_label(color: ManimColor) -> int: - r, g, b = color_to_int_rgb(color) - rg = r * 256 + g - rgb = rg * 256 + b - if rgb == 16777215: # white - return -1 - return rgb + def find_spans(self, *patterns: str) -> list[Span]: + return [ + match_obj.span() + for pattern in patterns + for match_obj in re.finditer(pattern, self.string) + ] @staticmethod def get_neighbouring_pairs(iterable: Iterable) -> list: @@ -211,109 +212,233 @@ class LabelledString(_StringSVG): result.update(other_repl_items) return result - @property - def skipped_spans(self) -> list[Span]: - return [] + #@property + #def skipped_spans(self) -> list[Span]: + # return [ + # match_obj.span() + # for match_obj in re.finditer(r"\s+", self.string) + # ] - def lstrip(self, index: int) -> int: - index_seq = list(it.chain(*self.skipped_spans)) - region_index = self.find_region_index(index, index_seq) - if region_index % 2 == 0: - return index_seq[region_index + 1] - return index + #def lstrip(self, index: int) -> int: + # index_seq = list(it.chain(*self.skipped_spans)) + # region_index = self.find_region_index(index, index_seq) + # if region_index % 2 == 0: + # return index_seq[region_index + 1] + # return index - def rstrip(self, index: int) -> int: - index_seq = list(it.chain(*self.skipped_spans)) - region_index = self.find_region_index(index - 1, index_seq) - if region_index % 2 == 0: - return index_seq[region_index] - return index + #def rstrip(self, index: int) -> int: + # index_seq = list(it.chain(*self.skipped_spans)) + # region_index = self.find_region_index(index - 1, index_seq) + # if region_index % 2 == 0: + # return index_seq[region_index] + # return index - def strip(self, span: Span) -> Span | None: - result = ( - self.lstrip(span[0]), - self.rstrip(span[1]) - ) - if result[0] >= result[1]: - return None - return result + #def strip(self, span: Span) -> Span | None: + # result = ( + # self.lstrip(span[0]), + # self.rstrip(span[1]) + # ) + # if result[0] >= result[1]: + # return None + # return result @staticmethod - def lslide(index: int, slid_spans: list[Span]) -> int: - slide_dict = dict(sorted(slid_spans)) - while index in slide_dict.keys(): - index = slide_dict[index] + def lstrip(index: int, skipped: list[Span]) -> int: + transfer_dict = dict(sorted(skipped)) + while index in transfer_dict.keys(): + index = transfer_dict[index] return index @staticmethod - def rslide(index: int, slid_spans: list[Span]) -> int: - slide_dict = dict(sorted([ - slide_span[::-1] for slide_span in slid_spans + def rstrip(index: int, skipped: list[Span]) -> int: + transfer_dict = dict(sorted([ + skipped_span[::-1] for skipped_span in skipped ], reverse=True)) - while index in slide_dict.keys(): - index = slide_dict[index] + while index in transfer_dict.keys(): + index = transfer_dict[index] return index @staticmethod - def slide(span: Span, slid_spans: list[Span]) -> Span | None: + def strip(span: Span, skipped: list[Span]) -> Span | None: result = ( - LabelledString.lslide(span[0], slid_spans), - LabelledString.rslide(span[1], slid_spans) + LabelledString.lstrip(span[0], skipped), + LabelledString.rstrip(span[1], skipped) ) if result[0] >= result[1]: return None return result + @abstractmethod + def get_begin_color_command_str(r: int, g: int, b: int) -> str: + return "" + + @abstractmethod + def get_end_color_command_str() -> str: + return "" + + @staticmethod + def color_to_label(color: ManimColor) -> int: + r, g, b = color_to_int_rgb(color) + rg = r * 256 + g + rgb = rg * 256 + b + if rgb == 16777215: # white + return -1 + return rgb + # Parser @property def full_span(self) -> Span: return (0, len(self.string)) - def get_substrs_to_isolate(self, substrs: list[str]) -> list[str]: - result = list(filter( - lambda s: s in self.string, - remove_list_redundancies(substrs) - )) - if "" in result: - result.remove("") - return result + @property + def space_spans(self) -> list[Span]: + return self.find_spans(r"\s+") + + @abstractmethod + def internal_specified_spans(self) -> list[Span]: + return [] @property + def external_specified_spans(self) -> list[Span]: + substrs = remove_list_redundancies(self.isolate) + if "" in substrs: + substrs.remove("") + return self.find_spans(*[ + re.escape(substr.strip()) for substr in substrs + ]) + + @property + def specified_spans(self) -> list[Span]: + return remove_list_redundancies([ + self.full_span, + *self.internal_specified_spans, + *self.external_specified_spans + ]) + + def get_specified_substrings(self) -> list[str]: + return remove_list_redundancies([ + self.string[slice(*span)] + for span in self.specified_spans + ]) + + @abstractmethod def label_span_list(self) -> list[Span]: return [] - @property - def inserted_string_pairs(self) -> list[tuple[Span, tuple[str, str]]]: + @abstractmethod + def get_inserted_string_pairs( + self, use_plain_file: bool + ) -> list[tuple[Span, tuple[str, str]]]: + return [] + + @abstractmethod + def command_repl_items(self) -> list[tuple[Span, str]]: return [] @property - def command_repl_items(self) -> list[tuple[Span, str]]: - return [] + def command_spans(self) -> list[Span]: + return [cmd_span for cmd_span, _ in self.command_repl_items] + + @abstractmethod + def remove_commands_in_plain_file(self) -> bool: + return True + + #@abstractmethod + #def get_command_repl_items( + # self, use_plain_file: bool + #) -> list[tuple[Span, str]]: + # return [] + + def get_decorated_string(self, use_plain_file: bool) -> str: + if use_plain_file and self.remove_commands_in_plain_file: + other_repl_items = [] + else: + other_repl_items = self.command_repl_items + span_repl_dict = self.get_span_replacement_dict( + self.get_inserted_string_pairs(use_plain_file), + other_repl_items + ) + result = self.replace_str_by_spans(self.string, span_repl_dict) + + if not use_plain_file: + return result + return "".join([ + self.get_begin_color_command_str( + *color_to_int_rgb(self.base_color) + ), + result, + self.get_end_color_command_str() + ]) @abstractmethod def has_predefined_colors(self) -> bool: return False - @property - def plain_string(self) -> str: - return self.string + #@property + #def plain_string(self) -> str: + # return self.string + + #@property + #def labelled_string(self) -> str: + # return self.replace_str_by_spans( + # self.string, self.get_span_replacement_dict( + # self.inserted_string_pairs, + # self.command_repl_items + # ) + # ) @property - def labelled_string(self) -> str: - return self.replace_str_by_spans( - self.string, self.get_span_replacement_dict( - self.inserted_string_pairs, - self.command_repl_items - ) - ) - - @property - def ignored_indices_for_submob_strings(self) -> list[int]: + def additionally_ignored_indices(self) -> list[int]: return [] - def handle_submob_string(self, substr: str, string_span: Span) -> str: - return substr + @property + def skipped_spans(self) -> list[Span]: + return list(it.chain( + self.space_spans, + self.command_spans, + [ + (index, index + 1) + for index in self.additionally_ignored_indices + ] + )) + + @property + def containing_labels_dict(self) -> dict[Span, list[int]]: + label_span_list = self.label_span_list + result = { + span: [] + for span in label_span_list + } + for span_0 in label_span_list: + for span_index, span_1 in enumerate(label_span_list): + if self.span_contains(span_0, span_1): + result[span_0].append(span_index) + elif span_0[0] < span_1[0] < span_0[1] < span_1[1]: + string_0, string_1 = [ + self.string[slice(*span)] + for span in [span_0, span_1] + ] + raise ValueError( + "Partially overlapping substrings detected: " + f"'{string_0}' and '{string_1}'" + ) + result[self.full_span] = list(range(-1, len(label_span_list))) + return result + + def get_cleaned_substr(self, string_span: Span) -> str: + span = self.strip(string_span, self.skipped_spans) + if span is None: + return "" + + span_repl_dict = { + tuple([index - span[0] for index in cmd_span]): "" + for cmd_span in self.command_spans + if self.span_contains(span, cmd_span) + } + return self.replace_str_by_spans( + self.string[slice(*span)], span_repl_dict + ) def get_submob_strings(self, submob_labels: list[int]) -> list[str]: ordered_spans = [ @@ -343,62 +468,17 @@ class LabelledString(_StringSVG): (ordered_span_begins[0], *string_span_begins), (*string_span_ends, ordered_span_ends[-1]) )) - - command_spans = [span for span, _ in self.command_repl_items] - slid_spans = list(it.chain( - self.skipped_spans, - command_spans, - [ - (index, index + 1) - for index in self.ignored_indices_for_submob_strings - ] - )) - result = [] - for string_span in string_spans: - string_span = self.slide(string_span, slid_spans) - if string_span is None: - result.append("") - continue - - span_repl_dict = { - tuple([index - string_span[0] for index in cmd_span]): "" - for cmd_span in command_spans - if self.span_contains(string_span, cmd_span) - } - substr = self.string[slice(*string_span)] - substr = self.replace_str_by_spans(substr, span_repl_dict) - substr = self.handle_submob_string(substr, string_span) - result.append(substr) - return result + return [ + self.get_cleaned_substr(string_span) + for string_span in string_spans + ] # Selector - @property - def containing_labels_dict(self) -> dict[Span, list[int]]: - label_span_list = self.label_span_list - result = { - span: [] - for span in label_span_list - } - for span_0 in label_span_list: - for span_index, span_1 in enumerate(label_span_list): - if self.span_contains(span_0, span_1): - result[span_0].append(span_index) - elif span_0[0] < span_1[0] < span_0[1] < span_1[1]: - string_0, string_1 = [ - self.string[slice(*span)] - for span in [span_0, span_1] - ] - raise ValueError( - "Partially overlapping substrings detected: " - f"'{string_0}' and '{string_1}'" - ) - result[self.full_span] = list(range(-1, len(label_span_list))) - return result - def find_span_components_of_custom_span( self, custom_span: Span ) -> list[Span]: + skipped_spans = self.skipped_spans span_choices = sorted(filter( lambda span: self.span_contains(custom_span, span), self.label_span_list @@ -408,10 +488,10 @@ class LabelledString(_StringSVG): result = [] span_begin, span_end = custom_span - span_begin = self.rstrip(span_begin) - span_end = self.rstrip(span_end) + span_begin = self.rstrip(span_begin, skipped_spans) + span_end = self.rstrip(span_end, skipped_spans) while span_begin != span_end: - span_begin = self.lstrip(span_begin) + span_begin = self.lstrip(span_begin, skipped_spans) if span_begin not in span_choices_dict.keys(): return [] next_begin = span_choices_dict[span_begin] @@ -432,8 +512,8 @@ class LabelledString(_StringSVG): def get_parts_by_string(self, substr: str) -> VGroup: return VGroup(*[ - self.get_part_by_custom_span(match_obj.span()) - for match_obj in re.finditer(re.escape(substr), self.string) + self.get_part_by_custom_span(span) + for span in self.find_spans(re.escape(substr.strip())) ]) def get_part_by_string(self, substr: str, index: int = 0) -> VMobject: @@ -466,23 +546,6 @@ class LabelledString(_StringSVG): part = self.get_part_by_string(substr, index=index) return self.indices_of_part(part) - @property - def specified_substrings(self) -> list[str]: - return [] - - def get_specified_substrings(self) -> list[str]: - return self.specified_substrings - - @property - def isolated_substrings(self) -> list[str]: - return remove_list_redundancies([ - self.string[slice(*span)] - for span in self.label_span_list - ]) - - def get_isolated_substrings(self) -> list[str]: - return self.isolated_substrings - def get_string(self) -> str: return self.string @@ -492,16 +555,17 @@ class MTex(LabelledString): "font_size": 48, "alignment": "\\centering", "tex_environment": "align*", - "isolate": [], "tex_to_color_map": {}, } def __init__(self, tex_string: str, **kwargs): + digest_config(self, kwargs) tex_string = tex_string.strip() # Prevent from passing an empty string. if not tex_string: tex_string = "\\quad" self.tex_string = tex_string + self.isolate.extend(self.tex_to_color_map.keys()) super().__init__(tex_string, **kwargs) self.set_color_by_tex_to_color_map(self.tex_to_color_map) @@ -515,10 +579,10 @@ class MTex(LabelledString): self.path_string_config, self.base_color, self.use_plain_file, + self.isolate, self.tex_string, self.alignment, self.tex_environment, - self.isolate, self.tex_to_color_map ) @@ -548,6 +612,28 @@ class MTex(LabelledString): def tex_to_svg_file_path(tex_file_content: str) -> str: return tex_to_svg_file(tex_file_content) + # Toolkits + + #@property + #def skipped_spans(self) -> list[Span]: + # return super().skipped_spans + self.indices_to_spans( + # self.script_char_indices + # ) + + @staticmethod + def get_begin_color_command_str(r: int, g: int, b: int) -> str: + return "".join([ + "{{", + "\\color[RGB]", + "{", + ",".join(map(str, (r, g, b))), + "}" + ]) + + @staticmethod + def get_end_color_command_str() -> str: + return "}}" + # Parser @property @@ -559,15 +645,20 @@ class MTex(LabelledString): if len(match_obj.group()) % 2 == 1 ] - def get_brace_indices_lists(self) -> tuple[list[Span], list[Span]]: - string = self.string - indices = list(filter( + @staticmethod + def get_unescaped_char_indices(*chars: str): + return list(filter( lambda index: index - 1 not in self.backslash_indices, [ match_obj.start() - for match_obj in re.finditer(r"[{}]", string) + for char in chars + for match_obj in re.finditer(re.escape(char), string) ] )) + + def get_brace_indices_lists(self) -> tuple[list[Span], list[Span]]: + string = self.string + indices = self.get_unescaped_char_indices("{", "}") left_brace_indices = [] right_brace_indices = [] left_brace_indices_stack = [] @@ -594,20 +685,8 @@ class MTex(LabelledString): return self.get_brace_indices_lists()[1] @property - def skipped_spans(self) -> list[Span]: - return [ - match_obj.span() - for match_obj in re.finditer(r"\s*([_^])\s*|(\s+)", self.string) - if match_obj.group(2) is not None - or match_obj.start(1) - 1 not in self.backslash_indices - ] - - @property - def script_char_spans(self) -> list[Span]: - return list(filter( - lambda span: self.string[slice(*span)].strip(), - self.skipped_spans - )) + def script_char_indices(self) -> list[Span]: + return self.get_unescaped_char_indices("_", "^") @property def script_content_spans(self) -> list[Span]: @@ -615,7 +694,8 @@ class MTex(LabelledString): brace_indices_dict = dict(zip( self.left_brace_indices, self.right_brace_indices )) - for _, span_begin in self.script_char_spans: + for index in self.script_char_indices: + span_begin = self.lstrip(index, self.space_spans) if span_begin in brace_indices_dict.keys(): span_end = brace_indices_dict[span_begin] + 1 else: @@ -635,7 +715,7 @@ class MTex(LabelledString): return result @property - def double_braces_spans(self) -> list[Span]: + def internal_specified_spans(self) -> list[Span]: # Match paired double braces (`{{...}}`). result = [] reversed_brace_indices_dict = dict(zip( @@ -659,28 +739,16 @@ class MTex(LabelledString): return result @property - def additional_substrings(self) -> list[str]: - return self.get_substrs_to_isolate(list(it.chain( - self.tex_to_color_map.keys(), - self.isolate - ))) - - def get_label_span_list(self, extended: bool) -> list[Span]: + def label_span_list(self) -> list[Span]: script_content_spans = self.script_content_spans script_spans = [ - (script_char_span[0], script_content_span[1]) - for script_char_span, script_content_span in zip( - self.script_char_spans, script_content_spans + (self.rstrip(index, self.space_spans), script_content_span[1]) + for index, script_content_span in zip( + self.script_char_indices, script_content_spans ) ] spans = remove_list_redundancies([ - self.full_span, - *self.double_braces_spans, - *filter(lambda stripped_span: stripped_span is not None, [ - self.strip(match_obj.span()) - for substr in self.additional_substrings - for match_obj in re.finditer(re.escape(substr), self.string) - ]), + *self.specified_spans, *script_content_spans ]) result = [] @@ -688,37 +756,53 @@ class MTex(LabelledString): if span in script_content_spans: continue span_begin, span_end = span - shrinked_end = self.rslide(span_end, script_spans) + shrinked_end = self.rstrip(span_end, script_spans) if span_begin >= shrinked_end: continue - shrinked_span = (span_begin, shrinked_end) - if shrinked_span in result: - continue - result.append(shrinked_span) + result.append((span_begin, self.lstrip(span_end, script_spans))) - if extended: - result = [ - (span_begin, self.lslide(span_end, script_spans)) - for span_begin, span_end in result - ] - return script_content_spans + result + #if extended: + # result = [ + # (span_begin, self.lstrip(span_end, script_spans)) + # for span_begin, span_end in result + # ] + return script_content_spans + remove_list_redundancies(result) - @property - def label_span_list(self) -> list[Span]: - return self.get_label_span_list(extended=False) + #@property + #def label_span_list(self) -> list[Span]: + # return self.get_label_span_list(extended=False) - @property - def inserted_string_pairs(self) -> list[tuple[Span, tuple[str, str]]]: + def get_inserted_string_pairs( + self, use_plain_file: bool + ) -> list[tuple[Span, tuple[str, str]]]: + if use_plain_file: + return [] return [ (span, ( - "{{" + self.get_color_command_by_label(label), - "}}" + self.get_begin_color_command_str( + label // 256 // 256, + label // 256 % 256, + label % 256 + ), + self.get_end_color_command_str() )) for label, span in enumerate( - self.get_label_span_list(extended=True) + self.label_span_list ) ] + #@property + #def inserted_string_pairs(self) -> list[tuple[Span, tuple[str, str]]]: + # return [ + # (span, ( + # "{{" + self.get_color_command_by_label(label), + # "}}" + # )) + # for label, span in enumerate( + # self.get_label_span_list(extended=True) + # ) + # ] + @property def command_repl_items(self) -> list[tuple[Span, str]]: color_related_command_dict = { @@ -754,39 +838,44 @@ class MTex(LabelledString): result.append(((span_begin, span_end), repl_str)) return result + @property + def remove_commands_in_plain_file(self) -> bool: + return True + @property def has_predefined_colors(self) -> bool: return bool(self.command_repl_items) - @staticmethod - def get_color_command_by_label(label: int) -> str: - if label == -1: - label = 16777215 # white - rg, b = divmod(label, 256) - r, g = divmod(rg, 256) - return "".join([ - "\\color[RGB]", - "{", - ",".join(map(str, (r, g, b))), - "}" - ]) + #@staticmethod + #def get_color_command_by_label(label: int) -> str: + # if label == -1: + # label = 16777215 # white + # rg, b = divmod(label, 256) + # r, g = divmod(rg, 256) + # return "".join([ + # "\\color[RGB]", + # "{", + # ",".join(map(str, (r, g, b))), + # "}" + # ]) + + #@property + #def plain_string(self) -> str: + # return "".join([ + # "{{", + # self.get_color_command_by_label( + # self.color_to_label(self.base_color) + # ), + # self.string, + # "}}" + # ]) @property - def plain_string(self) -> str: - return "".join([ - "{{", - self.get_color_command_by_label( - self.color_to_label(self.base_color) - ), - self.string, - "}}" - ]) - - @property - def ignored_indices_for_submob_strings(self) -> list[int]: + def additionally_ignored_indices(self) -> list[int]: return self.left_brace_indices + self.right_brace_indices - def handle_submob_string(self, substr: str, string_span: Span) -> str: + def get_cleaned_substr(self, string_span: Span) -> str: + substr = super().get_cleaned_substr(string_span) unclosed_left_braces = 0 unclosed_right_braces = 0 for index in range(*string_span): @@ -803,13 +892,6 @@ class MTex(LabelledString): unclosed_left_braces * "}" ]) - @property - def specified_substrings(self) -> list[str]: - return remove_list_redundancies([ - self.string[slice(*double_braces_span)] - for double_braces_span in self.double_braces_spans - ] + self.additional_substrings) - # Method alias def get_parts_by_tex(self, substr: str) -> VGroup: diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index fda326f1..2d80305f 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -131,15 +131,16 @@ class MarkupText(LabelledString): "t2w": {}, "global_config": {}, "local_configs": {}, - "isolate": [], } def __init__(self, text: str, **kwargs): self.full2short(kwargs) digest_config(self, kwargs) - validate_error = MarkupUtils.validate(text) - if validate_error: - raise ValueError(validate_error) + + if self.is_markup: + validate_error = MarkupUtils.validate(text) + if validate_error: + raise ValueError(validate_error) self.text = text super().__init__(text, **kwargs) @@ -162,6 +163,7 @@ class MarkupText(LabelledString): self.path_string_config, self.base_color, self.use_plain_file, + self.isolate, self.text, self.is_markup, self.font_size, @@ -178,8 +180,7 @@ class MarkupText(LabelledString): self.t2s, self.t2w, self.global_config, - self.local_configs, - self.isolate + self.local_configs ) def full2short(self, config: dict) -> None: @@ -288,6 +289,15 @@ class MarkupText(LabelledString): MarkupText.get_neighbouring_pairs(index_seq), attr_dict_list[:-1] )) + @staticmethod + def get_begin_color_command_str(r: int, g: int, b: int) -> str: + color_hex = "#{:02x}{:02x}{:02x}".format(r, g, b).upper() + return MarkupText.get_begin_tag_str({"foreground": color_hex}) + + @staticmethod + def get_end_color_command_str() -> str: + return MarkupText.get_end_tag_str() + # Parser @property @@ -319,17 +329,12 @@ class MarkupText(LabelledString): if end_match_obj.group(3): raise ValueError("Attributes shan't exist in ending tags") if tag_name == "span": - attr_dict = dict([ - ( - MarkupText.convert_attr_key(match.group(1)), - MarkupText.convert_attr_val( - match.group(2) or match.group(3) - ) - ) + attr_dict = { + match.group(1): match.group(2) or match.group(3) for match in re.finditer( attr_pattern, begin_match_obj.group(3) ) - ]) + } elif tag_name in TAG_TO_ATTR_DICT.keys(): if begin_match_obj.group(3): raise ValueError( @@ -372,6 +377,19 @@ class MarkupText(LabelledString): for key, val in result ] + @property + def local_attr_items_from_markup(self) -> list[tuple[Span, str, str]]: + return sorted([ + ( + (begin_tag_span[0], end_tag_span[1]), + self.convert_attr_key(key), + self.convert_attr_val(val) + ) + for begin_tag_span, end_tag_span, attr_dict + in self.tag_items_from_markup + for key, val in attr_dict.items() + ]) + @property def local_attr_items_from_config(self) -> list[tuple[Span, str, str]]: result = [ @@ -383,11 +401,11 @@ class MarkupText(LabelledString): (self.t2w, "font_weight") ) for word_or_span, val in t2x_dict.items() - for text_span in self.find_spans(word_or_span) + for text_span in self.find_spans_by_word_or_span(word_or_span) ] + [ (text_span, key, val) for word_or_span, local_config in self.local_configs.items() - for text_span in self.find_spans(word_or_span) + for text_span in self.find_spans_by_word_or_span(word_or_span) for key, val in local_config.items() ] return [ @@ -399,45 +417,45 @@ class MarkupText(LabelledString): for text_span, key, val in result ] - def find_spans(self, word_or_span: str | Span) -> list[Span]: + def find_spans_by_word_or_span( + self, word_or_span: str | Span + ) -> list[Span]: if isinstance(word_or_span, tuple): return [word_or_span] - return [ - match_obj.span() - for match_obj in re.finditer(re.escape(word_or_span), self.string) - ] + return self.find_spans(re.escape(word_or_span)) + + #@property + #def skipped_spans(self) -> list[Span]: + # return [ + # match_obj.span() + # for match_obj in re.finditer(r"\s+", self.string) + # ] + + #@property + #def additional_substrings(self) -> list[str]: + # return self.get_substrs_to_isolate(self.isolate) @property - def skipped_spans(self) -> list[Span]: + def internal_specified_spans(self) -> list[Span]: return [ - match_obj.span() - for match_obj in re.finditer(r"\s+", self.string) + markup_span + for markup_span, _, _ in self.local_attr_items_from_markup ] @property def label_span_list(self) -> list[Span]: - breakup_indices = [ - index - for pattern in [ - r"\s+", - r"\b", - *[ - re.escape(substr) - for substr in self.get_substrs_to_isolate(self.isolate) - ] - ] - for match_obj in re.finditer(pattern, self.string) - for index in match_obj.span() - ] + entity_spans = [span for span, _ in self.command_repl_items] + if self.is_markup: + entity_spans += self.find_spans(r"&.*?;") breakup_indices = sorted(filter( lambda index: not any([ span[0] < index < span[1] - for span, _ in self.command_repl_items + for span in entity_spans ]), - remove_list_redundancies([ - *self.full_span, *breakup_indices - ]) + remove_list_redundancies(list(it.chain(*( + self.specified_spans + self.find_spans(r"\s+", r"\b") + )))) )) return list(filter( lambda span: self.string[slice(*span)].strip(), @@ -451,20 +469,15 @@ class MarkupText(LabelledString): (self.full_span, key, val) for key, val in self.global_attr_items_from_config ], - sorted([ - ((begin_tag_span[0], end_tag_span[1]), key, val) - for begin_tag_span, end_tag_span, attr_dict - in self.tag_items_from_markup - for key, val in attr_dict.items() - ]), + self.local_attr_items_from_markup, self.local_attr_items_from_config )) def get_inserted_string_pairs( - self, use_label: bool + self, use_plain_file: bool ) -> list[tuple[Span, tuple[str, str]]]: attr_items = self.predefined_items - if use_label: + if not use_plain_file: attr_items = [ (span, key, WHITE if key in COLOR_RELATED_KEYS else val) for span, key, val in attr_items @@ -480,17 +493,37 @@ class MarkupText(LabelledString): for span, attr_dict in self.merge_attr_items(attr_items) ] - @property - def inserted_string_pairs(self) -> list[tuple[Span, tuple[str, str]]]: - return self.get_inserted_string_pairs(use_label=True) + #@property + #def inserted_string_pairs(self) -> list[tuple[Span, tuple[str, str]]]: + # return self.get_inserted_string_pairs(use_label=True) @property def command_repl_items(self) -> list[tuple[Span, str]]: - return [ + result = [ (tag_span, "") for begin_tag, end_tag, _ in self.tag_items_from_markup for tag_span in (begin_tag, end_tag) ] + if not self.is_markup: + result += [ + (span, escaped) + for char, escaped in ( + ("&", "&"), + (">", ">"), + ("<", "<") + ) + for span in self.find_spans(re.escape(char)) + ] + return result + + def remove_commands_in_plain_file(self) -> bool: + return False + + #@abstractmethod + #def get_command_repl_items( + # self, use_plain_file: bool + #) -> list[tuple[Span, str]]: + # return [] @property def has_predefined_colors(self) -> bool: @@ -499,18 +532,25 @@ class MarkupText(LabelledString): for _, key, _ in self.predefined_items ]) - @property - def plain_string(self) -> str: - return "".join([ - self.get_begin_tag_str({"foreground": self.base_color}), - self.replace_str_by_spans( - self.string, self.get_span_replacement_dict( - self.get_inserted_string_pairs(use_label=False), - self.command_repl_items - ) - ), - self.get_end_tag_str() - ]) + #@property + #def plain_string(self) -> str: + # return "".join([ + # self.get_begin_tag_str({"foreground": self.base_color}), + # self.replace_str_by_spans( + # self.string, self.get_span_replacement_dict( + # self.get_inserted_string_pairs(use_label=False), + # self.command_repl_items + # ) + # ), + # self.get_end_tag_str() + # ]) + + #@property + #def specified_substrings(self) -> list[str]: # TODO: clean up and merge + # return remove_list_redundancies([ + # self.get_cleaned_substr(markup_span) + # for markup_span, _, _ in self.local_attr_items_from_markup + # ] + self.additional_substrings) # Method alias From c5ec47b0e960746702640f86bf43d104eff6ed57 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Wed, 30 Mar 2022 21:53:00 +0800 Subject: [PATCH 12/48] Refactor LabelledString --- .../animation/transform_matching_parts.py | 126 +-- manimlib/mobject/svg/mtex_mobject.py | 736 ++++++++---------- manimlib/mobject/svg/text_mobject.py | 279 +++---- 3 files changed, 515 insertions(+), 626 deletions(-) diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index 8e1dd101..f92d962d 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -156,7 +156,7 @@ class TransformMatchingTex(TransformMatchingParts): class TransformMatchingStrings(AnimationGroup): CONFIG = { "key_map": dict(), - "transform_mismatches_class": None, + "transform_mismatches": False, } def __init__(self, @@ -168,42 +168,53 @@ class TransformMatchingStrings(AnimationGroup): assert isinstance(source_mobject, LabelledString) assert isinstance(target_mobject, LabelledString) anims = [] - rest_source_submobs = source_mobject.submobjects.copy() - rest_target_submobs = target_mobject.submobjects.copy() + rest_source_indices = list(range(len(source_mobject.submobjects))) + rest_target_indices = list(range(len(target_mobject.submobjects))) - def add_anims_from(anim_class, func, source_attrs, target_attrs=None): - if target_attrs is None: - target_attrs = source_attrs.copy() - for source_attr, target_attr in zip(source_attrs, target_attrs): - source_parts = func(source_mobject, source_attr) - target_parts = func(target_mobject, target_attr) - filtered_source_parts = [ - submob_part for submob_part in source_parts - if all([ - submob in rest_source_submobs - for submob in submob_part - ]) - ] - filtered_target_parts = [ - submob_part for submob_part in target_parts - if all([ - submob in rest_target_submobs - for submob in submob_part - ]) - ] - if not (filtered_source_parts and filtered_target_parts): - return - anims.append(anim_class( - VGroup(*filtered_source_parts), - VGroup(*filtered_target_parts), - **kwargs + def add_anims_from(anim_class, func, source_args, target_args=None): + if target_args is None: + target_args = source_args.copy() + for source_arg, target_arg in zip(source_args, target_args): + source_parts = func(source_mobject, source_arg) + target_parts = func(target_mobject, target_arg) + source_indices_lists = source_mobject.indices_lists_of_parts( + source_parts + ) + target_indices_lists = target_mobject.indices_lists_of_parts( + target_parts + ) + filtered_source_indices_lists = list(filter( + lambda indices_list: all([ + index in rest_source_indices + for index in indices_list + ]), source_indices_lists )) - for submob in it.chain(*filtered_source_parts): - rest_source_submobs.remove(submob) - for submob in it.chain(*filtered_target_parts): - rest_target_submobs.remove(submob) + filtered_target_indices_lists = list(filter( + lambda indices_list: all([ + index in rest_target_indices + for index in indices_list + ]), target_indices_lists + )) + if not all([ + filtered_source_indices_lists, + filtered_target_indices_lists + ]): + return + anims.append(anim_class(source_parts, target_parts, **kwargs)) + for index in it.chain(*filtered_source_indices_lists): + rest_source_indices.remove(index) + for index in it.chain(*filtered_target_indices_lists): + rest_target_indices.remove(index) - def get_submobs_from_keys(mobject, keys): + def get_common_substrs(func): + result = sorted(list( + set(func(source_mobject)).intersection(func(target_mobject)) + ), key=len, reverse=True) + if "" in result: + result.remove("") + return result + + def get_parts_from_keys(mobject, keys): if not isinstance(keys, tuple): keys = (keys,) indices = [] @@ -220,55 +231,50 @@ class TransformMatchingStrings(AnimationGroup): else: raise TypeError(key) return VGroup(VGroup(*[ - mobject[i] for i in remove_list_redundancies(indices) + mobject[index] for index in remove_list_redundancies(indices) ])) add_anims_from( - ReplacementTransform, get_submobs_from_keys, + ReplacementTransform, get_parts_from_keys, self.key_map.keys(), self.key_map.values() ) - - common_specified_substrings = sorted(list( - set(source_mobject.get_specified_substrings()).intersection( - target_mobject.get_specified_substrings() - ) - ), key=len, reverse=True) - if "" in common_specified_substrings: - common_specified_substrings.remove("") add_anims_from( FadeTransformPieces, LabelledString.get_parts_by_string, - common_specified_substrings + get_common_substrs( + lambda mobject: mobject.specified_substrings + ) ) - - common_submob_strings = { - source_submob.get_string() for source_submob in source_mobject - }.intersection({ - target_submob.get_string() for target_submob in target_mobject - }) add_anims_from( FadeTransformPieces, - lambda mobject, attr: VGroup(*[ - VGroup(mob) for mob in mobject - if mob.get_string() == attr - ]), - common_submob_strings + LabelledString.get_parts_by_group_substr, + get_common_substrs( + lambda mobject: mobject.group_substrs + ) ) - if self.transform_mismatches_class is not None: - anims.append(self.transform_mismatches_class( + fade_source = VGroup(*[ + source_mobject[index] + for index in rest_source_indices + ]) + fade_target = VGroup(*[ + target_mobject[index] + for index in rest_target_indices + ]) + if self.transform_mismatches: + anims.append(ReplacementTransform( fade_source, fade_target, **kwargs )) else: anims.append(FadeOutToPoint( - VGroup(*rest_source_submobs), + fade_source, target_mobject.get_center(), **kwargs )) anims.append(FadeInFromPoint( - VGroup(*rest_target_submobs), + fade_target, source_mobject.get_center(), **kwargs )) diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 02c243f4..55909dc6 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -3,7 +3,7 @@ from __future__ import annotations import re import colour import itertools as it -from types import MethodType +#from types import MethodType from typing import Iterable, Union, Sequence from abc import abstractmethod @@ -55,13 +55,12 @@ class LabelledString(_StringSVG): def __init__(self, string: str, **kwargs): self.string = string + digest_config(self, kwargs) + self.pre_parse() + self.parse() super().__init__(**kwargs) def get_file_path(self, use_plain_file: bool = False) -> str: - #if use_plain_file: - # content = self.plain_string - #else: - # content = self.labelled_string content = self.get_decorated_string(use_plain_file=use_plain_file) return self.get_file_path_by_content(content) @@ -72,46 +71,47 @@ class LabelledString(_StringSVG): def generate_mobject(self) -> None: super().generate_mobject() - if not self.submobjects: - return - - glyph_labels = [ - self.color_to_label(glyph.get_fill_color()) - for glyph in self.submobjects + submob_labels = [ + self.color_to_label(submob.get_fill_color()) + for submob in self.submobjects ] - if self.use_plain_file or self.has_predefined_colors: file_path = self.get_file_path(use_plain_file=True) - glyphs = _StringSVG(file_path).submobjects - for glyph, plain_glyph in zip(self.submobjects, glyphs): - glyph.set_fill(plain_glyph.get_fill_color()) + plain_svg = _StringSVG(file_path) + self.set_submobjects(plain_svg.submobjects) else: - glyphs = self.submobjects self.set_fill(self.base_color) - - # TODO - # Simply pack together adjacent mobjects with the same label. - submob_labels, glyphs_lists = self.group_neighbours( - glyph_labels, glyphs - ) - submob_strings = self.get_submob_strings(submob_labels) - submobjects = [] - for glyph_list, label, submob_string in zip( - glyphs_lists, submob_labels, submob_strings - ): - submob = VGroup(*glyph_list) + for submob, label in zip(self.submobjects, submob_labels): submob.label = label - submob.string = submob_string - submob.get_string = MethodType(lambda inst: inst.string, submob) - submobjects.append(submob) - self.set_submobjects(submobjects) + self.submob_labels = submob_labels + self.post_parse() + + def pre_parse(self) -> None: + self.full_span = self.get_full_span() + self.space_spans = self.get_space_spans() + + def parse(self) -> None: + self.command_repl_items = self.get_command_repl_items() + self.command_spans = self.get_command_spans() + self.ignored_indices = self.get_ignored_indices() + self.skipped_spans = self.get_skipped_spans() + self.internal_specified_spans = self.get_internal_specified_spans() + self.external_specified_spans = self.get_external_specified_spans() + self.specified_spans = self.get_specified_spans() + self.label_span_list = self.get_label_span_list() + self.has_predefined_colors = self.get_has_predefined_colors() + + def post_parse(self) -> None: + self.containing_labels_dict = self.get_containing_labels_dict() + self.specified_substrings = self.get_specified_substrings() + self.group_substr_items = self.get_group_substr_items() + self.group_substrs = self.get_group_substrs() # Toolkits - def find_spans(self, *patterns: str) -> list[Span]: + def find_spans(self, pattern: str) -> list[Span]: return [ match_obj.span() - for pattern in patterns for match_obj in re.finditer(pattern, self.string) ] @@ -124,32 +124,23 @@ class LabelledString(_StringSVG): return span_0[0] <= span_1[0] and span_0[1] >= span_1[1] @staticmethod - def group_neighbours( - labels: Iterable[object], - objs: Iterable[object] - ) -> tuple[list[object], list[list[object]]]: - # Pack together neighbouring objects sharing the same label. - if not labels: - return [], [] + def compress_neighbours(vals: list[int]) -> list[tuple[int, Span]]: + if not vals: + return [] - group_labels = [] - groups = [] - new_group = [] - current_label = labels[0] - for label, obj in zip(labels, objs): - if label == current_label: - new_group.append(obj) - else: - group_labels.append(current_label) - groups.append(new_group) - new_group = [obj] - current_label = label - group_labels.append(current_label) - groups.append(new_group) - return group_labels, groups + unique_vals = [vals[0]] + indices = [0] + for index, val in enumerate(vals): + if val == unique_vals[-1]: + continue + unique_vals.append(val) + indices.append(index) + indices.append(len(vals)) + spans = LabelledString.get_neighbouring_pairs(indices) + return list(zip(unique_vals, spans)) @staticmethod - def find_region_index(val: int, seq: list[int]) -> int: + def find_region_index(seq: list[int], val: int) -> int: # Returns an integer in `range(-1, len(seq))` satisfying # `seq[result] <= val < seq[result + 1]`. # `seq` should be sorted in ascending order. @@ -160,6 +151,14 @@ class LabelledString(_StringSVG): result -= 1 return result + @staticmethod + def take_nearest_value(seq: list[int], val: int, index_shift: int) -> int: + sorted_seq = sorted(seq) + index = LabelledString.find_region_index(sorted_seq, val) + if index == -1: + raise IndexError + return sorted_seq[index + index_shift] + @staticmethod def replace_str_by_spans( substr: str, span_repl_dict: dict[Span, str] @@ -204,53 +203,22 @@ class LabelledString(_StringSVG): for flag in range(2) ])) result = { - (index, index): "".join(inserted_strs) - for index, inserted_strs in zip(*LabelledString.group_neighbours( - indices, inserted_strings - )) + (index, index): "".join(inserted_strings[slice(*item_span)]) + for index, item_span + in LabelledString.compress_neighbours(indices) } result.update(other_repl_items) return result - #@property - #def skipped_spans(self) -> list[Span]: - # return [ - # match_obj.span() - # for match_obj in re.finditer(r"\s+", self.string) - # ] - - #def lstrip(self, index: int) -> int: - # index_seq = list(it.chain(*self.skipped_spans)) - # region_index = self.find_region_index(index, index_seq) - # if region_index % 2 == 0: - # return index_seq[region_index + 1] - # return index - - #def rstrip(self, index: int) -> int: - # index_seq = list(it.chain(*self.skipped_spans)) - # region_index = self.find_region_index(index - 1, index_seq) - # if region_index % 2 == 0: - # return index_seq[region_index] - # return index - - #def strip(self, span: Span) -> Span | None: - # result = ( - # self.lstrip(span[0]), - # self.rstrip(span[1]) - # ) - # if result[0] >= result[1]: - # return None - # return result - @staticmethod - def lstrip(index: int, skipped: list[Span]) -> int: + def rslide(index: int, skipped: list[Span]) -> int: transfer_dict = dict(sorted(skipped)) while index in transfer_dict.keys(): index = transfer_dict[index] return index @staticmethod - def rstrip(index: int, skipped: list[Span]) -> int: + def lslide(index: int, skipped: list[Span]) -> int: transfer_dict = dict(sorted([ skipped_span[::-1] for skipped_span in skipped ], reverse=True)) @@ -259,71 +227,102 @@ class LabelledString(_StringSVG): return index @staticmethod - def strip(span: Span, skipped: list[Span]) -> Span | None: - result = ( - LabelledString.lstrip(span[0], skipped), - LabelledString.rstrip(span[1], skipped) + def shrink_span(span: Span, skipped: list[Span]) -> Span: + return ( + LabelledString.rslide(span[0], skipped), + LabelledString.lslide(span[1], skipped) ) - if result[0] >= result[1]: - return None - return result + + @staticmethod + def extend_span(span: Span, skipped: list[Span]) -> Span: + return ( + LabelledString.lslide(span[0], skipped), + LabelledString.rslide(span[1], skipped) + ) + + @staticmethod + def rgb_to_int(rgb_tuple: tuple[int, int, int]) -> int: + r, g, b = rgb_tuple + rg = r * 256 + g + return rg * 256 + b + + @staticmethod + def int_to_rgb(rgb_int: int) -> tuple[int, int, int]: + rg, b = divmod(rgb_int, 256) + r, g = divmod(rg, 256) + return r, g, b + + @staticmethod + def color_to_label(color: ManimColor) -> int: + rgb_tuple = color_to_int_rgb(color) + rgb = LabelledString.rgb_to_int(rgb_tuple) + if rgb == 16777215: # white + return -1 + return rgb @abstractmethod - def get_begin_color_command_str(r: int, g: int, b: int) -> str: + def get_begin_color_command_str(int_rgb: int) -> str: return "" @abstractmethod def get_end_color_command_str() -> str: return "" - @staticmethod - def color_to_label(color: ManimColor) -> int: - r, g, b = color_to_int_rgb(color) - rg = r * 256 + g - rgb = rg * 256 + b - if rgb == 16777215: # white - return -1 - return rgb + # Pre-parsing - # Parser - - @property - def full_span(self) -> Span: + def get_full_span(self) -> Span: return (0, len(self.string)) - @property - def space_spans(self) -> list[Span]: + def get_space_spans(self) -> list[Span]: return self.find_spans(r"\s+") @abstractmethod - def internal_specified_spans(self) -> list[Span]: + def get_command_repl_items(self) -> list[tuple[Span, str]]: return [] - @property - def external_specified_spans(self) -> list[Span]: - substrs = remove_list_redundancies(self.isolate) - if "" in substrs: - substrs.remove("") - return self.find_spans(*[ - re.escape(substr.strip()) for substr in substrs - ]) + def get_command_spans(self) -> list[Span]: + return [cmd_span for cmd_span, _ in self.command_repl_items] - @property - def specified_spans(self) -> list[Span]: - return remove_list_redundancies([ + def get_ignored_indices(self) -> list[int]: + return [] + + def get_skipped_spans(self) -> list[Span]: + return list(it.chain( + self.space_spans, + self.command_spans, + [ + (index, index + 1) + for index in self.ignored_indices + ] + )) + + @abstractmethod + def get_internal_specified_spans(self) -> list[Span]: + return [] + + def get_external_specified_spans(self) -> list[Span]: + return remove_list_redundancies(list(it.chain(*[ + self.find_spans(re.escape(substr.strip())) + for substr in self.isolate + ]))) + + def get_specified_spans(self) -> list[Span]: + spans = [ self.full_span, *self.internal_specified_spans, *self.external_specified_spans - ]) - - def get_specified_substrings(self) -> list[str]: - return remove_list_redundancies([ - self.string[slice(*span)] - for span in self.specified_spans - ]) + ] + shrinked_spans = list(filter( + lambda span: span[0] < span[1], + [ + self.shrink_span(span, self.skipped_spans) + for span in spans + ] + )) + return remove_list_redundancies(shrinked_spans) @abstractmethod - def label_span_list(self) -> list[Span]: + def get_label_span_list(self) -> list[Span]: return [] @abstractmethod @@ -333,31 +332,15 @@ class LabelledString(_StringSVG): return [] @abstractmethod - def command_repl_items(self) -> list[tuple[Span, str]]: + def get_other_repl_items( + self, use_plain_file: bool + ) -> list[tuple[Span, str]]: return [] - @property - def command_spans(self) -> list[Span]: - return [cmd_span for cmd_span, _ in self.command_repl_items] - - @abstractmethod - def remove_commands_in_plain_file(self) -> bool: - return True - - #@abstractmethod - #def get_command_repl_items( - # self, use_plain_file: bool - #) -> list[tuple[Span, str]]: - # return [] - def get_decorated_string(self, use_plain_file: bool) -> str: - if use_plain_file and self.remove_commands_in_plain_file: - other_repl_items = [] - else: - other_repl_items = self.command_repl_items span_repl_dict = self.get_span_replacement_dict( self.get_inserted_string_pairs(use_plain_file), - other_repl_items + self.get_other_repl_items(use_plain_file) ) result = self.replace_str_by_spans(self.string, span_repl_dict) @@ -365,46 +348,19 @@ class LabelledString(_StringSVG): return result return "".join([ self.get_begin_color_command_str( - *color_to_int_rgb(self.base_color) + self.rgb_to_int(color_to_int_rgb(self.base_color)) ), result, self.get_end_color_command_str() ]) @abstractmethod - def has_predefined_colors(self) -> bool: + def get_has_predefined_colors(self) -> bool: return False - #@property - #def plain_string(self) -> str: - # return self.string + # Post-parsing - #@property - #def labelled_string(self) -> str: - # return self.replace_str_by_spans( - # self.string, self.get_span_replacement_dict( - # self.inserted_string_pairs, - # self.command_repl_items - # ) - # ) - - @property - def additionally_ignored_indices(self) -> list[int]: - return [] - - @property - def skipped_spans(self) -> list[Span]: - return list(it.chain( - self.space_spans, - self.command_spans, - [ - (index, index + 1) - for index in self.additionally_ignored_indices - ] - )) - - @property - def containing_labels_dict(self) -> dict[Span, list[int]]: + def get_containing_labels_dict(self) -> dict[Span, list[int]]: label_span_list = self.label_span_list result = { span: [] @@ -423,14 +379,11 @@ class LabelledString(_StringSVG): "Partially overlapping substrings detected: " f"'{string_0}' and '{string_1}'" ) - result[self.full_span] = list(range(-1, len(label_span_list))) + if self.full_span not in result: + result[self.full_span] = list(range(len(label_span_list))) return result - def get_cleaned_substr(self, string_span: Span) -> str: - span = self.strip(string_span, self.skipped_spans) - if span is None: - return "" - + def get_cleaned_substr(self, span: Span) -> str: span_repl_dict = { tuple([index - span[0] for index in cmd_span]): "" for cmd_span in self.command_spans @@ -440,66 +393,85 @@ class LabelledString(_StringSVG): self.string[slice(*span)], span_repl_dict ) - def get_submob_strings(self, submob_labels: list[int]) -> list[str]: + def get_specified_substrings(self) -> list[str]: + return remove_list_redundancies([ + self.get_cleaned_substr(span) + for span in self.specified_spans + ]) + + def get_group_substr_items(self) -> tuple[list[Span], list[str]]: + group_labels, submob_spans = zip( + *self.compress_neighbours(self.submob_labels) + ) ordered_spans = [ self.label_span_list[label] if label != -1 else self.full_span - for label in submob_labels + for label in group_labels ] ordered_containing_labels = [ self.containing_labels_dict[span] for span in ordered_spans ] ordered_span_begins, ordered_span_ends = zip(*ordered_spans) - string_span_begins = [ + span_begins = [ prev_end if prev_label in containing_labels else curr_begin for prev_end, prev_label, containing_labels, curr_begin in zip( - ordered_span_ends[:-1], submob_labels[:-1], + ordered_span_ends[:-1], group_labels[:-1], ordered_containing_labels[1:], ordered_span_begins[1:] ) ] - string_span_ends = [ + span_ends = [ next_begin if next_label in containing_labels else curr_end for next_begin, next_label, containing_labels, curr_end in zip( - ordered_span_begins[1:], submob_labels[1:], + ordered_span_begins[1:], group_labels[1:], ordered_containing_labels[:-1], ordered_span_ends[:-1] ) ] - string_spans = list(zip( - (ordered_span_begins[0], *string_span_begins), - (*string_span_ends, ordered_span_ends[-1]) + spans = list(zip( + (ordered_span_begins[0], *span_begins), + (*span_ends, ordered_span_ends[-1]) )) - return [ - self.get_cleaned_substr(string_span) - for string_span in string_spans + shrinked_spans = [ + self.shrink_span(span, self.skipped_spans) + for span in spans ] + group_substrs = [ + self.get_cleaned_substr(span) if span[0] < span[1] else "" + for span in shrinked_spans + ] + return submob_spans, group_substrs + + def get_group_substrs(self) -> list[str]: + return self.group_substr_items[1] # Selector def find_span_components_of_custom_span( self, custom_span: Span ) -> list[Span]: - skipped_spans = self.skipped_spans + indices = remove_list_redundancies(list(it.chain( + self.full_span, + *self.label_span_list + ))) + span_begin = self.take_nearest_value(indices, custom_span[0], 0) + span_end = self.take_nearest_value(indices, custom_span[1] - 1, 1) span_choices = sorted(filter( - lambda span: self.span_contains(custom_span, span), + lambda span: self.span_contains((span_begin, span_end), span), self.label_span_list )) # Choose spans that reach the farthest. span_choices_dict = dict(span_choices) result = [] - span_begin, span_end = custom_span - span_begin = self.rstrip(span_begin, skipped_spans) - span_end = self.rstrip(span_end, skipped_spans) - while span_begin != span_end: - span_begin = self.lstrip(span_begin, skipped_spans) + while span_begin < span_end: if span_begin not in span_choices_dict.keys(): - return [] + span_begin += 1 + continue next_begin = span_choices_dict[span_begin] result.append((span_begin, next_begin)) span_begin = next_begin return result - def get_part_by_custom_span(self, custom_span: Span) -> VGroup: + def get_parts_by_custom_span(self, custom_span: Span) -> VGroup: spans = self.find_span_components_of_custom_span(custom_span) labels = set(it.chain(*[ self.containing_labels_dict[span] @@ -512,13 +484,19 @@ class LabelledString(_StringSVG): def get_parts_by_string(self, substr: str) -> VGroup: return VGroup(*[ - self.get_part_by_custom_span(span) + self.get_parts_by_custom_span(span) for span in self.find_spans(re.escape(substr.strip())) ]) - def get_part_by_string(self, substr: str, index: int = 0) -> VMobject: - all_parts = self.get_parts_by_string(substr) - return all_parts[index] + def get_parts_by_group_substr(self, substr: str) -> VGroup: + return VGroup(*[ + VGroup(*self.submobjects[slice(*submob_span)]) + for submob_span, group_substr in zip(*self.group_substr_items) + if group_substr == substr + ]) + + def get_part_by_string(self, substr: str, index : int = 0) -> VMobject: + return self.get_parts_by_string(substr)[index] def set_color_by_string(self, substr: str, color: ManimColor): self.get_parts_by_string(substr).set_color(color) @@ -532,19 +510,12 @@ class LabelledString(_StringSVG): return self def indices_of_part(self, part: Iterable[VMobject]) -> list[int]: - indices = [ - index for index, submob in enumerate(self.submobjects) - if submob in part - ] - if not indices: - raise ValueError("Failed to find part") - return indices + return [self.submobjects.index(submob) for submob in part] - def indices_of_part_by_string( - self, substr: str, index: int = 0 - ) -> list[int]: - part = self.get_part_by_string(substr, index=index) - return self.indices_of_part(part) + def indices_lists_of_parts( + self, parts: Iterable[Iterable[VMobject]] + ) -> list[list[int]]: + return [self.indices_of_part(part) for part in parts] def get_string(self) -> str: return self.string @@ -612,21 +583,25 @@ class MTex(LabelledString): def tex_to_svg_file_path(tex_file_content: str) -> str: return tex_to_svg_file(tex_file_content) + def pre_parse(self) -> None: + super().pre_parse() + self.backslash_indices = self.get_backslash_indices() + self.brace_index_pairs = self.get_brace_index_pairs() + self.script_char_indices = self.get_script_char_indices() + self.script_char_spans = self.get_script_char_spans() + self.script_content_spans = self.get_script_content_spans() + self.script_spans = self.get_script_spans() + # Toolkits - #@property - #def skipped_spans(self) -> list[Span]: - # return super().skipped_spans + self.indices_to_spans( - # self.script_char_indices - # ) - @staticmethod - def get_begin_color_command_str(r: int, g: int, b: int) -> str: + def get_begin_color_command_str(rgb_int: int) -> str: + rgb_tuple = MTex.int_to_rgb(rgb_int) return "".join([ "{{", "\\color[RGB]", "{", - ",".join(map(str, (r, g, b))), + ",".join(map(str, rgb_tuple)), "}" ]) @@ -634,29 +609,27 @@ class MTex(LabelledString): def get_end_color_command_str() -> str: return "}}" - # Parser + # Pre-parsing - @property - def backslash_indices(self) -> list[int]: + def get_backslash_indices(self) -> list[int]: # Newlines (`\\`) don't count. return [ - match_obj.end() - 1 - for match_obj in re.finditer(r"\\+", self.string) - if len(match_obj.group()) % 2 == 1 + span[1] - 1 + for span in self.find_spans(r"\\+") + if (span[1] - span[0]) % 2 == 1 ] - @staticmethod - def get_unescaped_char_indices(*chars: str): - return list(filter( + def get_unescaped_char_indices(self, *chars: str): + return sorted(filter( lambda index: index - 1 not in self.backslash_indices, [ - match_obj.start() + span[0] for char in chars - for match_obj in re.finditer(re.escape(char), string) + for span in self.find_spans(re.escape(char)) ] )) - def get_brace_indices_lists(self) -> tuple[list[Span], list[Span]]: + def get_brace_index_pairs(self) -> list[Span]: string = self.string indices = self.get_unescaped_char_indices("{", "}") left_brace_indices = [] @@ -673,29 +646,21 @@ class MTex(LabelledString): right_brace_indices.append(index) if left_brace_indices_stack: raise ValueError("Missing '}' inserted") - # `right_brace_indices` is already sorted. - return left_brace_indices, right_brace_indices + return list(zip(left_brace_indices, right_brace_indices)) - @property - def left_brace_indices(self) -> list[Span]: - return self.get_brace_indices_lists()[0] - - @property - def right_brace_indices(self) -> list[Span]: - return self.get_brace_indices_lists()[1] - - @property - def script_char_indices(self) -> list[Span]: + def get_script_char_indices(self) -> list[int]: return self.get_unescaped_char_indices("_", "^") - @property - def script_content_spans(self) -> list[Span]: + def get_script_char_spans(self) -> list[Span]: + return [ + self.extend_span((index, index + 1), self.space_spans) + for index in self.script_char_indices + ] + + def get_script_content_spans(self) -> list[Span]: result = [] - brace_indices_dict = dict(zip( - self.left_brace_indices, self.right_brace_indices - )) - for index in self.script_char_indices: - span_begin = self.lstrip(index, self.space_spans) + brace_indices_dict = dict(self.brace_index_pairs) + for _, span_begin in self.script_char_spans: if span_begin in brace_indices_dict.keys(): span_end = brace_indices_dict[span_begin] + 1 else: @@ -714,16 +679,65 @@ class MTex(LabelledString): result.append((span_begin, span_end)) return result - @property - def internal_specified_spans(self) -> list[Span]: + def get_script_spans(self) -> list[Span]: + return [ + (script_char_span[0], script_content_span[1]) + for script_char_span, script_content_span in zip( + self.script_char_spans, self.script_content_spans + ) + ] + + # Parsing + + def get_command_repl_items(self) -> list[tuple[Span, str]]: + color_related_command_dict = { + "color": (1, False), + "textcolor": (1, False), + "pagecolor": (1, True), + "colorbox": (1, True), + "fcolorbox": (2, True), + } + result = [] + backslash_indices = self.backslash_indices + right_brace_indices = [ + right_index + for left_index, right_index in self.brace_index_pairs + ] + pattern = "".join([ + r"\\", + "(", + "|".join(color_related_command_dict.keys()), + ")", + r"(?![a-zA-Z])" + ]) + for match_obj in re.finditer(pattern, self.string): + span_begin, cmd_end = match_obj.span() + if span_begin not in backslash_indices: + continue + cmd_name = match_obj.group(1) + n_braces, substitute_cmd = color_related_command_dict[cmd_name] + span_end = self.take_nearest_value( + right_brace_indices, cmd_end, n_braces + ) + 1 + if substitute_cmd: + repl_str = "\\" + cmd_name + n_braces * "{white}" + else: + repl_str = "" + result.append(((span_begin, span_end), repl_str)) + return result + + def get_ignored_indices(self) -> list[int]: + return self.script_char_indices + + def get_internal_specified_spans(self) -> list[Span]: # Match paired double braces (`{{...}}`). result = [] - reversed_brace_indices_dict = dict(zip( - self.right_brace_indices, self.left_brace_indices - )) + reversed_brace_indices_dict = dict([ + pair[::-1] for pair in self.brace_index_pairs + ]) skip = False for prev_right_index, right_index in self.get_neighbouring_pairs( - self.right_brace_indices + list(reversed_brace_indices_dict.keys()) ): if skip: skip = False @@ -738,150 +752,61 @@ class MTex(LabelledString): skip = True return result - @property - def label_span_list(self) -> list[Span]: - script_content_spans = self.script_content_spans - script_spans = [ - (self.rstrip(index, self.space_spans), script_content_span[1]) - for index, script_content_span in zip( - self.script_char_indices, script_content_spans - ) - ] - spans = remove_list_redundancies([ - *self.specified_spans, - *script_content_spans - ]) - result = [] - for span in spans: - if span in script_content_spans: - continue - span_begin, span_end = span - shrinked_end = self.rstrip(span_end, script_spans) + def get_label_span_list(self) -> list[Span]: + result = self.script_content_spans.copy() + for span_begin, span_end in self.specified_spans: + shrinked_end = self.lslide(span_end, self.script_spans) if span_begin >= shrinked_end: continue - result.append((span_begin, self.lstrip(span_end, script_spans))) - - #if extended: - # result = [ - # (span_begin, self.lstrip(span_end, script_spans)) - # for span_begin, span_end in result - # ] - return script_content_spans + remove_list_redundancies(result) - - #@property - #def label_span_list(self) -> list[Span]: - # return self.get_label_span_list(extended=False) + shrinked_span = (span_begin, shrinked_end) + if shrinked_span in result: + continue + result.append(shrinked_span) + return result def get_inserted_string_pairs( self, use_plain_file: bool ) -> list[tuple[Span, tuple[str, str]]]: if use_plain_file: return [] + + extended_label_span_list = [ + span + if span in self.script_content_spans + else (span[0], self.rslide(span[1], self.script_spans)) + for span in self.label_span_list + ] return [ (span, ( - self.get_begin_color_command_str( - label // 256 // 256, - label // 256 % 256, - label % 256 - ), + self.get_begin_color_command_str(label), self.get_end_color_command_str() )) - for label, span in enumerate( - self.label_span_list - ) + for label, span in enumerate(extended_label_span_list) ] - #@property - #def inserted_string_pairs(self) -> list[tuple[Span, tuple[str, str]]]: - # return [ - # (span, ( - # "{{" + self.get_color_command_by_label(label), - # "}}" - # )) - # for label, span in enumerate( - # self.get_label_span_list(extended=True) - # ) - # ] + def get_other_repl_items( + self, use_plain_file: bool + ) -> list[tuple[Span, str]]: + if use_plain_file: + return [] + return self.command_repl_items.copy() - @property - def command_repl_items(self) -> list[tuple[Span, str]]: - color_related_command_dict = { - "color": (1, False), - "textcolor": (1, False), - "pagecolor": (1, True), - "colorbox": (1, True), - "fcolorbox": (2, True), - } - result = [] - backslash_indices = self.backslash_indices - right_brace_indices = self.right_brace_indices - pattern = "".join([ - r"\\", - "(", - "|".join(color_related_command_dict.keys()), - ")", - r"(?![a-zA-Z])" - ]) - for match_obj in re.finditer(pattern, self.string): - span_begin, cmd_end = match_obj.span() - if span_begin not in backslash_indices: - continue - cmd_name = match_obj.group(1) - n_braces, substitute_cmd = color_related_command_dict[cmd_name] - span_end = right_brace_indices[self.find_region_index( - cmd_end, right_brace_indices - ) + n_braces] + 1 - if substitute_cmd: - repl_str = "\\" + cmd_name + n_braces * "{white}" - else: - repl_str = "" - result.append(((span_begin, span_end), repl_str)) - return result - - @property - def remove_commands_in_plain_file(self) -> bool: - return True - - @property - def has_predefined_colors(self) -> bool: + def get_has_predefined_colors(self) -> bool: return bool(self.command_repl_items) - #@staticmethod - #def get_color_command_by_label(label: int) -> str: - # if label == -1: - # label = 16777215 # white - # rg, b = divmod(label, 256) - # r, g = divmod(rg, 256) - # return "".join([ - # "\\color[RGB]", - # "{", - # ",".join(map(str, (r, g, b))), - # "}" - # ]) + def get_cleaned_substr(self, span: Span) -> str: + substr = super().get_cleaned_substr(span) + if not self.brace_index_pairs: + return substr - #@property - #def plain_string(self) -> str: - # return "".join([ - # "{{", - # self.get_color_command_by_label( - # self.color_to_label(self.base_color) - # ), - # self.string, - # "}}" - # ]) - - @property - def additionally_ignored_indices(self) -> list[int]: - return self.left_brace_indices + self.right_brace_indices - - def get_cleaned_substr(self, string_span: Span) -> str: - substr = super().get_cleaned_substr(string_span) + # Balance braces. + left_brace_indices, right_brace_indices = zip(*self.brace_index_pairs) unclosed_left_braces = 0 unclosed_right_braces = 0 - for index in range(*string_span): - if index in self.left_brace_indices: + for index in range(*span): + if index in left_brace_indices: unclosed_left_braces += 1 - elif index in self.right_brace_indices: + elif index in right_brace_indices: if unclosed_left_braces == 0: unclosed_right_braces += 1 else: @@ -894,25 +819,20 @@ class MTex(LabelledString): # Method alias - def get_parts_by_tex(self, substr: str) -> VGroup: - return self.get_parts_by_string(substr) + def get_parts_by_tex(self, tex: str) -> VGroup: + return self.get_parts_by_string(tex) - def get_part_by_tex(self, substr: str, index: int = 0) -> VMobject: - return self.get_part_by_string(substr, index) + def get_part_by_tex(self, tex: str) -> VMobject: + return self.get_part_by_string(tex) - def set_color_by_tex(self, substr: str, color: ManimColor): - return self.set_color_by_string(substr, color) + def set_color_by_tex(self, tex: str, color: ManimColor): + return self.set_color_by_string(tex, color) def set_color_by_tex_to_color_map( self, tex_to_color_map: dict[str, ManimColor] ): return self.set_color_by_string_to_color_map(tex_to_color_map) - def indices_of_part_by_tex( - self, substr: str, index: int = 0 - ) -> list[int]: - return self.indices_of_part_by_string(substr, index) - def get_tex(self) -> str: return self.get_string() diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 2d80305f..98fc6dbe 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -234,6 +234,14 @@ class MarkupText(LabelledString): pango_width=pango_width ) + def parse(self) -> None: + self.global_items_from_config = self.get_global_items_from_config() + self.tag_items_from_markup = self.get_tag_items_from_markup() + self.local_items_from_markup = self.get_local_items_from_markup() + self.local_items_from_config = self.get_local_items_from_config() + self.predefined_items = self.get_predefined_items() + super().parse() + # Toolkits @staticmethod @@ -251,6 +259,19 @@ class MarkupText(LabelledString): def get_end_tag_str() -> str: return "" + @staticmethod + def rgb_int_to_hex(rgb_int: int) -> str: + return "#{:06x}".format(rgb_int).upper() + + @staticmethod + def get_begin_color_command_str(rgb_int: int): + color_hex = MarkupText.rgb_int_to_hex(rgb_int) + return MarkupText.get_begin_tag_str({"foreground": color_hex}) + + @staticmethod + def get_end_color_command_str() -> str: + return MarkupText.get_end_tag_str() + @staticmethod def convert_attr_key(key: str) -> str: return SPAN_ATTR_KEY_CONVERSION[key.lower()] @@ -269,7 +290,7 @@ class MarkupText(LabelledString): if span[0] >= span[1]: continue region_indices = [ - MarkupText.find_region_index(index, index_seq) + MarkupText.find_region_index(index_seq, index) for index in span ] for flag in (1, 0): @@ -289,19 +310,43 @@ class MarkupText(LabelledString): MarkupText.get_neighbouring_pairs(index_seq), attr_dict_list[:-1] )) - @staticmethod - def get_begin_color_command_str(r: int, g: int, b: int) -> str: - color_hex = "#{:02x}{:02x}{:02x}".format(r, g, b).upper() - return MarkupText.get_begin_tag_str({"foreground": color_hex}) + def find_spans_by_word_or_span( + self, word_or_span: str | Span + ) -> list[Span]: + if isinstance(word_or_span, tuple): + return [word_or_span] + return self.find_spans(re.escape(word_or_span)) - @staticmethod - def get_end_color_command_str() -> str: - return MarkupText.get_end_tag_str() + # Pre-parsing - # Parser + def get_global_items_from_config(self) -> list[str, str]: + global_attr_dict = { + "line_height": ( + (self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1 + ) * 0.6, + "font_family": self.font or get_customization()["style"]["font"], + "font_size": self.font_size * 1024, + "font_style": self.slant, + "font_weight": self.weight + } + global_attr_dict = { + k: v + for k, v in global_attr_dict.items() + if v is not None + } + result = list(it.chain( + global_attr_dict.items(), + self.global_config.items() + )) + return [ + ( + self.convert_attr_key(key), + self.convert_attr_val(val) + ) + for key, val in result + ] - @property - def tag_items_from_markup( + def get_tag_items_from_markup( self ) -> list[tuple[Span, Span, dict[str, str]]]: if not self.is_markup: @@ -349,36 +394,7 @@ class MarkupText(LabelledString): ) return result - @property - def global_attr_items_from_config(self) -> list[str, str]: - global_attr_dict = { - "line_height": ( - (self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1 - ) * 0.6, - "font_family": self.font or get_customization()["style"]["font"], - "font_size": self.font_size * 1024, - "font_style": self.slant, - "font_weight": self.weight - } - global_attr_dict = { - k: v - for k, v in global_attr_dict.items() - if v is not None - } - result = list(it.chain( - global_attr_dict.items(), - self.global_config.items() - )) - return [ - ( - self.convert_attr_key(key), - self.convert_attr_val(val) - ) - for key, val in result - ] - - @property - def local_attr_items_from_markup(self) -> list[tuple[Span, str, str]]: + def get_local_items_from_markup(self) -> list[tuple[Span, str, str]]: return sorted([ ( (begin_tag_span[0], end_tag_span[1]), @@ -390,8 +406,7 @@ class MarkupText(LabelledString): for key, val in attr_dict.items() ]) - @property - def local_attr_items_from_config(self) -> list[tuple[Span, str, str]]: + def get_local_items_from_config(self) -> list[tuple[Span, str, str]]: result = [ (text_span, key, val) for t2x_dict, key in ( @@ -417,88 +432,19 @@ class MarkupText(LabelledString): for text_span, key, val in result ] - def find_spans_by_word_or_span( - self, word_or_span: str | Span - ) -> list[Span]: - if isinstance(word_or_span, tuple): - return [word_or_span] - - return self.find_spans(re.escape(word_or_span)) - - #@property - #def skipped_spans(self) -> list[Span]: - # return [ - # match_obj.span() - # for match_obj in re.finditer(r"\s+", self.string) - # ] - - #@property - #def additional_substrings(self) -> list[str]: - # return self.get_substrs_to_isolate(self.isolate) - - @property - def internal_specified_spans(self) -> list[Span]: - return [ - markup_span - for markup_span, _, _ in self.local_attr_items_from_markup - ] - - @property - def label_span_list(self) -> list[Span]: - entity_spans = [span for span, _ in self.command_repl_items] - if self.is_markup: - entity_spans += self.find_spans(r"&.*?;") - breakup_indices = sorted(filter( - lambda index: not any([ - span[0] < index < span[1] - for span in entity_spans - ]), - remove_list_redundancies(list(it.chain(*( - self.specified_spans + self.find_spans(r"\s+", r"\b") - )))) - )) - return list(filter( - lambda span: self.string[slice(*span)].strip(), - self.get_neighbouring_pairs(breakup_indices) - )) - - @property - def predefined_items(self) -> list[Span, str, str]: + def get_predefined_items(self) -> list[Span, str, str]: return list(it.chain( [ (self.full_span, key, val) - for key, val in self.global_attr_items_from_config + for key, val in self.global_items_from_config ], - self.local_attr_items_from_markup, - self.local_attr_items_from_config + self.local_items_from_markup, + self.local_items_from_config )) - def get_inserted_string_pairs( - self, use_plain_file: bool - ) -> list[tuple[Span, tuple[str, str]]]: - attr_items = self.predefined_items - if not use_plain_file: - attr_items = [ - (span, key, WHITE if key in COLOR_RELATED_KEYS else val) - for span, key, val in attr_items - ] + [ - (span, "foreground", "#{:06x}".format(label)) - for label, span in enumerate(self.label_span_list) - ] - return [ - (span, ( - self.get_begin_tag_str(attr_dict), - self.get_end_tag_str() - )) - for span, attr_dict in self.merge_attr_items(attr_items) - ] + # Parsing - #@property - #def inserted_string_pairs(self) -> list[tuple[Span, tuple[str, str]]]: - # return self.get_inserted_string_pairs(use_label=True) - - @property - def command_repl_items(self) -> list[tuple[Span, str]]: + def get_command_repl_items(self) -> list[tuple[Span, str]]: result = [ (tag_span, "") for begin_tag, end_tag, _ in self.tag_items_from_markup @@ -516,63 +462,80 @@ class MarkupText(LabelledString): ] return result - def remove_commands_in_plain_file(self) -> bool: - return False + def get_internal_specified_spans(self) -> list[Span]: + return [ + markup_span + for markup_span, _, _ in self.local_items_from_markup + ] - #@abstractmethod - #def get_command_repl_items( - # self, use_plain_file: bool - #) -> list[tuple[Span, str]]: - # return [] + def get_label_span_list(self) -> list[Span]: + breakup_indices = remove_list_redundancies(list(it.chain(*it.chain( + self.space_spans, + self.find_spans(r"\b"), + self.specified_spans + )))) + entity_spans = self.command_spans.copy() + if self.is_markup: + entity_spans += self.find_spans(r"&.*?;") + breakup_indices = sorted(filter( + lambda index: not any([ + span[0] < index < span[1] + for span in entity_spans + ]), + breakup_indices + )) + return list(filter( + lambda span: self.string[slice(*span)].strip(), + self.get_neighbouring_pairs(breakup_indices) + )) - @property - def has_predefined_colors(self) -> bool: + def get_inserted_string_pairs( + self, use_plain_file: bool + ) -> list[tuple[Span, tuple[str, str]]]: + attr_items = self.predefined_items + if not use_plain_file: + attr_items = [ + (span, key, WHITE if key in COLOR_RELATED_KEYS else val) + for span, key, val in attr_items + ] + [ + (span, "foreground", self.rgb_int_to_hex(label)) + for label, span in enumerate(self.label_span_list) + ] + return [ + (span, ( + self.get_begin_tag_str(attr_dict), + self.get_end_tag_str() + )) + for span, attr_dict in self.merge_attr_items(attr_items) + ] + + def get_other_repl_items( + self, use_plain_file: bool + ) -> list[tuple[Span, str]]: + return self.command_repl_items.copy() + + def get_has_predefined_colors(self) -> bool: return any([ key in COLOR_RELATED_KEYS for _, key, _ in self.predefined_items ]) - #@property - #def plain_string(self) -> str: - # return "".join([ - # self.get_begin_tag_str({"foreground": self.base_color}), - # self.replace_str_by_spans( - # self.string, self.get_span_replacement_dict( - # self.get_inserted_string_pairs(use_label=False), - # self.command_repl_items - # ) - # ), - # self.get_end_tag_str() - # ]) - - #@property - #def specified_substrings(self) -> list[str]: # TODO: clean up and merge - # return remove_list_redundancies([ - # self.get_cleaned_substr(markup_span) - # for markup_span, _, _ in self.local_attr_items_from_markup - # ] + self.additional_substrings) - # Method alias - def get_parts_by_text(self, substr: str) -> VGroup: - return self.get_parts_by_string(substr) + def get_parts_by_text(self, text: str) -> VGroup: + return self.get_parts_by_string(text) - def get_part_by_text(self, substr: str, index: int = 0) -> VMobject: - return self.get_part_by_string(substr, index) + def get_part_by_text(self, text: str) -> VMobject: + return self.get_part_by_string(text) - def set_color_by_text(self, substr: str, color: ManimColor): - return self.set_color_by_string(substr, color) + def set_color_by_text(self, text: str, color: ManimColor): + return self.set_color_by_string(text, color) def set_color_by_text_to_color_map( self, text_to_color_map: dict[str, ManimColor] ): return self.set_color_by_string_to_color_map(text_to_color_map) - def indices_of_part_by_text( - self, substr: str, index: int = 0 - ) -> list[int]: - return self.indices_of_part_by_string(substr, index) - def get_text(self) -> str: return self.get_string() From 0add9b6e3a12ba148d338c85e44225ecd73c9007 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Wed, 30 Mar 2022 21:57:27 +0800 Subject: [PATCH 13/48] Rename file --- .../animation/transform_matching_parts.py | 2 +- manimlib/mobject/svg/labelled_string.py | 843 ++++++++++++++++++ manimlib/mobject/svg/text_mobject.py | 2 +- 3 files changed, 845 insertions(+), 2 deletions(-) create mode 100644 manimlib/mobject/svg/labelled_string.py diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index f92d962d..2b452d2d 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -12,7 +12,7 @@ from manimlib.animation.transform import ReplacementTransform from manimlib.animation.transform import Transform from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Group -from manimlib.mobject.svg.mtex_mobject import LabelledString +from manimlib.mobject.svg.labelled_string import LabelledString from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.utils.config_ops import digest_config diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py new file mode 100644 index 00000000..55909dc6 --- /dev/null +++ b/manimlib/mobject/svg/labelled_string.py @@ -0,0 +1,843 @@ +from __future__ import annotations + +import re +import colour +import itertools as it +#from types import MethodType +from typing import Iterable, Union, Sequence +from abc import abstractmethod + +from manimlib.constants import WHITE +from manimlib.mobject.svg.svg_mobject import SVGMobject +from manimlib.mobject.types.vectorized_mobject import VGroup +from manimlib.utils.color import color_to_int_rgb +from manimlib.utils.config_ops import digest_config +from manimlib.utils.iterables import adjacent_pairs +from manimlib.utils.iterables import remove_list_redundancies +from manimlib.utils.tex_file_writing import tex_to_svg_file +from manimlib.utils.tex_file_writing import get_tex_config +from manimlib.utils.tex_file_writing import display_during_execution +from manimlib.logger import log + + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from manimlib.mobject.types.vectorized_mobject import VMobject + ManimColor = Union[str, colour.Color, Sequence[float]] + Span = tuple[int, int] + + +SCALE_FACTOR_PER_FONT_POINT = 0.001 + + +class _StringSVG(SVGMobject): + CONFIG = { + "height": None, + "stroke_width": 0, + "stroke_color": WHITE, + "path_string_config": { + "should_subdivide_sharp_curves": True, + "should_remove_null_curves": True, + }, + } + + +class LabelledString(_StringSVG): + """ + An abstract base class for `MTex` and `MarkupText` + """ + CONFIG = { + "base_color": WHITE, + "use_plain_file": False, + "isolate": [], + } + + def __init__(self, string: str, **kwargs): + self.string = string + digest_config(self, kwargs) + self.pre_parse() + self.parse() + super().__init__(**kwargs) + + def get_file_path(self, use_plain_file: bool = False) -> str: + content = self.get_decorated_string(use_plain_file=use_plain_file) + return self.get_file_path_by_content(content) + + @abstractmethod + def get_file_path_by_content(self, content: str) -> str: + return "" + + def generate_mobject(self) -> None: + super().generate_mobject() + + submob_labels = [ + self.color_to_label(submob.get_fill_color()) + for submob in self.submobjects + ] + if self.use_plain_file or self.has_predefined_colors: + file_path = self.get_file_path(use_plain_file=True) + plain_svg = _StringSVG(file_path) + self.set_submobjects(plain_svg.submobjects) + else: + self.set_fill(self.base_color) + for submob, label in zip(self.submobjects, submob_labels): + submob.label = label + self.submob_labels = submob_labels + self.post_parse() + + def pre_parse(self) -> None: + self.full_span = self.get_full_span() + self.space_spans = self.get_space_spans() + + def parse(self) -> None: + self.command_repl_items = self.get_command_repl_items() + self.command_spans = self.get_command_spans() + self.ignored_indices = self.get_ignored_indices() + self.skipped_spans = self.get_skipped_spans() + self.internal_specified_spans = self.get_internal_specified_spans() + self.external_specified_spans = self.get_external_specified_spans() + self.specified_spans = self.get_specified_spans() + self.label_span_list = self.get_label_span_list() + self.has_predefined_colors = self.get_has_predefined_colors() + + def post_parse(self) -> None: + self.containing_labels_dict = self.get_containing_labels_dict() + self.specified_substrings = self.get_specified_substrings() + self.group_substr_items = self.get_group_substr_items() + self.group_substrs = self.get_group_substrs() + + # Toolkits + + def find_spans(self, pattern: str) -> list[Span]: + return [ + match_obj.span() + for match_obj in re.finditer(pattern, self.string) + ] + + @staticmethod + def get_neighbouring_pairs(iterable: Iterable) -> list: + return list(adjacent_pairs(iterable))[:-1] + + @staticmethod + def span_contains(span_0: Span, span_1: Span) -> bool: + return span_0[0] <= span_1[0] and span_0[1] >= span_1[1] + + @staticmethod + def compress_neighbours(vals: list[int]) -> list[tuple[int, Span]]: + if not vals: + return [] + + unique_vals = [vals[0]] + indices = [0] + for index, val in enumerate(vals): + if val == unique_vals[-1]: + continue + unique_vals.append(val) + indices.append(index) + indices.append(len(vals)) + spans = LabelledString.get_neighbouring_pairs(indices) + return list(zip(unique_vals, spans)) + + @staticmethod + def find_region_index(seq: list[int], val: int) -> int: + # Returns an integer in `range(-1, len(seq))` satisfying + # `seq[result] <= val < seq[result + 1]`. + # `seq` should be sorted in ascending order. + if not seq or val < seq[0]: + return -1 + result = len(seq) - 1 + while val < seq[result]: + result -= 1 + return result + + @staticmethod + def take_nearest_value(seq: list[int], val: int, index_shift: int) -> int: + sorted_seq = sorted(seq) + index = LabelledString.find_region_index(sorted_seq, val) + if index == -1: + raise IndexError + return sorted_seq[index + index_shift] + + @staticmethod + def replace_str_by_spans( + substr: str, span_repl_dict: dict[Span, str] + ) -> str: + if not span_repl_dict: + return substr + + spans = sorted(span_repl_dict.keys()) + if not all( + span_0[1] <= span_1[0] + for span_0, span_1 in LabelledString.get_neighbouring_pairs(spans) + ): + raise ValueError("Overlapping replacement") + + span_ends, span_begins = zip(*spans) + pieces = [ + substr[slice(*span)] + for span in zip( + (0, *span_begins), + (*span_ends, len(substr)) + ) + ] + repl_strs = [*[span_repl_dict[span] for span in spans], ""] + return "".join(it.chain(*zip(pieces, repl_strs))) + + @staticmethod + def get_span_replacement_dict( + inserted_string_pairs: list[tuple[Span, tuple[str, str]]], + other_repl_items: list[tuple[Span, str]] + ) -> dict[Span, str]: + if not inserted_string_pairs: + return other_repl_items.copy() + + indices, _, _, inserted_strings = zip(*sorted([ + ( + span[flag], + -flag, + -span[1 - flag], + str_pair[flag] + ) + for span, str_pair in inserted_string_pairs + for flag in range(2) + ])) + result = { + (index, index): "".join(inserted_strings[slice(*item_span)]) + for index, item_span + in LabelledString.compress_neighbours(indices) + } + result.update(other_repl_items) + return result + + @staticmethod + def rslide(index: int, skipped: list[Span]) -> int: + transfer_dict = dict(sorted(skipped)) + while index in transfer_dict.keys(): + index = transfer_dict[index] + return index + + @staticmethod + def lslide(index: int, skipped: list[Span]) -> int: + transfer_dict = dict(sorted([ + skipped_span[::-1] for skipped_span in skipped + ], reverse=True)) + while index in transfer_dict.keys(): + index = transfer_dict[index] + return index + + @staticmethod + def shrink_span(span: Span, skipped: list[Span]) -> Span: + return ( + LabelledString.rslide(span[0], skipped), + LabelledString.lslide(span[1], skipped) + ) + + @staticmethod + def extend_span(span: Span, skipped: list[Span]) -> Span: + return ( + LabelledString.lslide(span[0], skipped), + LabelledString.rslide(span[1], skipped) + ) + + @staticmethod + def rgb_to_int(rgb_tuple: tuple[int, int, int]) -> int: + r, g, b = rgb_tuple + rg = r * 256 + g + return rg * 256 + b + + @staticmethod + def int_to_rgb(rgb_int: int) -> tuple[int, int, int]: + rg, b = divmod(rgb_int, 256) + r, g = divmod(rg, 256) + return r, g, b + + @staticmethod + def color_to_label(color: ManimColor) -> int: + rgb_tuple = color_to_int_rgb(color) + rgb = LabelledString.rgb_to_int(rgb_tuple) + if rgb == 16777215: # white + return -1 + return rgb + + @abstractmethod + def get_begin_color_command_str(int_rgb: int) -> str: + return "" + + @abstractmethod + def get_end_color_command_str() -> str: + return "" + + # Pre-parsing + + def get_full_span(self) -> Span: + return (0, len(self.string)) + + def get_space_spans(self) -> list[Span]: + return self.find_spans(r"\s+") + + @abstractmethod + def get_command_repl_items(self) -> list[tuple[Span, str]]: + return [] + + def get_command_spans(self) -> list[Span]: + return [cmd_span for cmd_span, _ in self.command_repl_items] + + def get_ignored_indices(self) -> list[int]: + return [] + + def get_skipped_spans(self) -> list[Span]: + return list(it.chain( + self.space_spans, + self.command_spans, + [ + (index, index + 1) + for index in self.ignored_indices + ] + )) + + @abstractmethod + def get_internal_specified_spans(self) -> list[Span]: + return [] + + def get_external_specified_spans(self) -> list[Span]: + return remove_list_redundancies(list(it.chain(*[ + self.find_spans(re.escape(substr.strip())) + for substr in self.isolate + ]))) + + def get_specified_spans(self) -> list[Span]: + spans = [ + self.full_span, + *self.internal_specified_spans, + *self.external_specified_spans + ] + shrinked_spans = list(filter( + lambda span: span[0] < span[1], + [ + self.shrink_span(span, self.skipped_spans) + for span in spans + ] + )) + return remove_list_redundancies(shrinked_spans) + + @abstractmethod + def get_label_span_list(self) -> list[Span]: + return [] + + @abstractmethod + def get_inserted_string_pairs( + self, use_plain_file: bool + ) -> list[tuple[Span, tuple[str, str]]]: + return [] + + @abstractmethod + def get_other_repl_items( + self, use_plain_file: bool + ) -> list[tuple[Span, str]]: + return [] + + def get_decorated_string(self, use_plain_file: bool) -> str: + span_repl_dict = self.get_span_replacement_dict( + self.get_inserted_string_pairs(use_plain_file), + self.get_other_repl_items(use_plain_file) + ) + result = self.replace_str_by_spans(self.string, span_repl_dict) + + if not use_plain_file: + return result + return "".join([ + self.get_begin_color_command_str( + self.rgb_to_int(color_to_int_rgb(self.base_color)) + ), + result, + self.get_end_color_command_str() + ]) + + @abstractmethod + def get_has_predefined_colors(self) -> bool: + return False + + # Post-parsing + + def get_containing_labels_dict(self) -> dict[Span, list[int]]: + label_span_list = self.label_span_list + result = { + span: [] + for span in label_span_list + } + for span_0 in label_span_list: + for span_index, span_1 in enumerate(label_span_list): + if self.span_contains(span_0, span_1): + result[span_0].append(span_index) + elif span_0[0] < span_1[0] < span_0[1] < span_1[1]: + string_0, string_1 = [ + self.string[slice(*span)] + for span in [span_0, span_1] + ] + raise ValueError( + "Partially overlapping substrings detected: " + f"'{string_0}' and '{string_1}'" + ) + if self.full_span not in result: + result[self.full_span] = list(range(len(label_span_list))) + return result + + def get_cleaned_substr(self, span: Span) -> str: + span_repl_dict = { + tuple([index - span[0] for index in cmd_span]): "" + for cmd_span in self.command_spans + if self.span_contains(span, cmd_span) + } + return self.replace_str_by_spans( + self.string[slice(*span)], span_repl_dict + ) + + def get_specified_substrings(self) -> list[str]: + return remove_list_redundancies([ + self.get_cleaned_substr(span) + for span in self.specified_spans + ]) + + def get_group_substr_items(self) -> tuple[list[Span], list[str]]: + group_labels, submob_spans = zip( + *self.compress_neighbours(self.submob_labels) + ) + ordered_spans = [ + self.label_span_list[label] if label != -1 else self.full_span + for label in group_labels + ] + ordered_containing_labels = [ + self.containing_labels_dict[span] + for span in ordered_spans + ] + ordered_span_begins, ordered_span_ends = zip(*ordered_spans) + span_begins = [ + prev_end if prev_label in containing_labels else curr_begin + for prev_end, prev_label, containing_labels, curr_begin in zip( + ordered_span_ends[:-1], group_labels[:-1], + ordered_containing_labels[1:], ordered_span_begins[1:] + ) + ] + span_ends = [ + next_begin if next_label in containing_labels else curr_end + for next_begin, next_label, containing_labels, curr_end in zip( + ordered_span_begins[1:], group_labels[1:], + ordered_containing_labels[:-1], ordered_span_ends[:-1] + ) + ] + spans = list(zip( + (ordered_span_begins[0], *span_begins), + (*span_ends, ordered_span_ends[-1]) + )) + shrinked_spans = [ + self.shrink_span(span, self.skipped_spans) + for span in spans + ] + group_substrs = [ + self.get_cleaned_substr(span) if span[0] < span[1] else "" + for span in shrinked_spans + ] + return submob_spans, group_substrs + + def get_group_substrs(self) -> list[str]: + return self.group_substr_items[1] + + # Selector + + def find_span_components_of_custom_span( + self, custom_span: Span + ) -> list[Span]: + indices = remove_list_redundancies(list(it.chain( + self.full_span, + *self.label_span_list + ))) + span_begin = self.take_nearest_value(indices, custom_span[0], 0) + span_end = self.take_nearest_value(indices, custom_span[1] - 1, 1) + span_choices = sorted(filter( + lambda span: self.span_contains((span_begin, span_end), span), + self.label_span_list + )) + # Choose spans that reach the farthest. + span_choices_dict = dict(span_choices) + + result = [] + while span_begin < span_end: + if span_begin not in span_choices_dict.keys(): + span_begin += 1 + continue + next_begin = span_choices_dict[span_begin] + result.append((span_begin, next_begin)) + span_begin = next_begin + return result + + def get_parts_by_custom_span(self, custom_span: Span) -> VGroup: + spans = self.find_span_components_of_custom_span(custom_span) + labels = set(it.chain(*[ + self.containing_labels_dict[span] + for span in spans + ])) + return VGroup(*filter( + lambda submob: submob.label in labels, + self.submobjects + )) + + def get_parts_by_string(self, substr: str) -> VGroup: + return VGroup(*[ + self.get_parts_by_custom_span(span) + for span in self.find_spans(re.escape(substr.strip())) + ]) + + def get_parts_by_group_substr(self, substr: str) -> VGroup: + return VGroup(*[ + VGroup(*self.submobjects[slice(*submob_span)]) + for submob_span, group_substr in zip(*self.group_substr_items) + if group_substr == substr + ]) + + def get_part_by_string(self, substr: str, index : int = 0) -> VMobject: + return self.get_parts_by_string(substr)[index] + + def set_color_by_string(self, substr: str, color: ManimColor): + self.get_parts_by_string(substr).set_color(color) + return self + + def set_color_by_string_to_color_map( + self, string_to_color_map: dict[str, ManimColor] + ): + for substr, color in string_to_color_map.items(): + self.set_color_by_string(substr, color) + return self + + def indices_of_part(self, part: Iterable[VMobject]) -> list[int]: + return [self.submobjects.index(submob) for submob in part] + + def indices_lists_of_parts( + self, parts: Iterable[Iterable[VMobject]] + ) -> list[list[int]]: + return [self.indices_of_part(part) for part in parts] + + def get_string(self) -> str: + return self.string + + +class MTex(LabelledString): + CONFIG = { + "font_size": 48, + "alignment": "\\centering", + "tex_environment": "align*", + "tex_to_color_map": {}, + } + + def __init__(self, tex_string: str, **kwargs): + digest_config(self, kwargs) + tex_string = tex_string.strip() + # Prevent from passing an empty string. + if not tex_string: + tex_string = "\\quad" + self.tex_string = tex_string + self.isolate.extend(self.tex_to_color_map.keys()) + super().__init__(tex_string, **kwargs) + + self.set_color_by_tex_to_color_map(self.tex_to_color_map) + self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size) + + @property + def hash_seed(self) -> tuple: + return ( + self.__class__.__name__, + self.svg_default, + self.path_string_config, + self.base_color, + self.use_plain_file, + self.isolate, + self.tex_string, + self.alignment, + self.tex_environment, + self.tex_to_color_map + ) + + def get_file_path_by_content(self, content: str) -> str: + full_tex = self.get_tex_file_body(content) + with display_during_execution(f"Writing \"{self.string}\""): + file_path = self.tex_to_svg_file_path(full_tex) + return file_path + + def get_tex_file_body(self, content: str) -> str: + if self.tex_environment: + content = "\n".join([ + f"\\begin{{{self.tex_environment}}}", + content, + f"\\end{{{self.tex_environment}}}" + ]) + if self.alignment: + content = "\n".join([self.alignment, content]) + + tex_config = get_tex_config() + return tex_config["tex_body"].replace( + tex_config["text_to_replace"], + content + ) + + @staticmethod + def tex_to_svg_file_path(tex_file_content: str) -> str: + return tex_to_svg_file(tex_file_content) + + def pre_parse(self) -> None: + super().pre_parse() + self.backslash_indices = self.get_backslash_indices() + self.brace_index_pairs = self.get_brace_index_pairs() + self.script_char_indices = self.get_script_char_indices() + self.script_char_spans = self.get_script_char_spans() + self.script_content_spans = self.get_script_content_spans() + self.script_spans = self.get_script_spans() + + # Toolkits + + @staticmethod + def get_begin_color_command_str(rgb_int: int) -> str: + rgb_tuple = MTex.int_to_rgb(rgb_int) + return "".join([ + "{{", + "\\color[RGB]", + "{", + ",".join(map(str, rgb_tuple)), + "}" + ]) + + @staticmethod + def get_end_color_command_str() -> str: + return "}}" + + # Pre-parsing + + def get_backslash_indices(self) -> list[int]: + # Newlines (`\\`) don't count. + return [ + span[1] - 1 + for span in self.find_spans(r"\\+") + if (span[1] - span[0]) % 2 == 1 + ] + + def get_unescaped_char_indices(self, *chars: str): + return sorted(filter( + lambda index: index - 1 not in self.backslash_indices, + [ + span[0] + for char in chars + for span in self.find_spans(re.escape(char)) + ] + )) + + def get_brace_index_pairs(self) -> list[Span]: + string = self.string + indices = self.get_unescaped_char_indices("{", "}") + left_brace_indices = [] + right_brace_indices = [] + left_brace_indices_stack = [] + for index in indices: + if string[index] == "{": + left_brace_indices_stack.append(index) + else: + if not left_brace_indices_stack: + raise ValueError("Missing '{' inserted") + left_brace_index = left_brace_indices_stack.pop() + left_brace_indices.append(left_brace_index) + right_brace_indices.append(index) + if left_brace_indices_stack: + raise ValueError("Missing '}' inserted") + return list(zip(left_brace_indices, right_brace_indices)) + + def get_script_char_indices(self) -> list[int]: + return self.get_unescaped_char_indices("_", "^") + + def get_script_char_spans(self) -> list[Span]: + return [ + self.extend_span((index, index + 1), self.space_spans) + for index in self.script_char_indices + ] + + def get_script_content_spans(self) -> list[Span]: + result = [] + brace_indices_dict = dict(self.brace_index_pairs) + for _, span_begin in self.script_char_spans: + if span_begin in brace_indices_dict.keys(): + span_end = brace_indices_dict[span_begin] + 1 + else: + pattern = re.compile(r"[a-zA-Z0-9]|\\[a-zA-Z]+") + match_obj = pattern.match(self.string, pos=span_begin) + if not match_obj: + script_name = { + "_": "subscript", + "^": "superscript" + }[script_char] + raise ValueError( + f"Unclear {script_name} detected while parsing. " + "Please use braces to clarify" + ) + span_end = match_obj.end() + result.append((span_begin, span_end)) + return result + + def get_script_spans(self) -> list[Span]: + return [ + (script_char_span[0], script_content_span[1]) + for script_char_span, script_content_span in zip( + self.script_char_spans, self.script_content_spans + ) + ] + + # Parsing + + def get_command_repl_items(self) -> list[tuple[Span, str]]: + color_related_command_dict = { + "color": (1, False), + "textcolor": (1, False), + "pagecolor": (1, True), + "colorbox": (1, True), + "fcolorbox": (2, True), + } + result = [] + backslash_indices = self.backslash_indices + right_brace_indices = [ + right_index + for left_index, right_index in self.brace_index_pairs + ] + pattern = "".join([ + r"\\", + "(", + "|".join(color_related_command_dict.keys()), + ")", + r"(?![a-zA-Z])" + ]) + for match_obj in re.finditer(pattern, self.string): + span_begin, cmd_end = match_obj.span() + if span_begin not in backslash_indices: + continue + cmd_name = match_obj.group(1) + n_braces, substitute_cmd = color_related_command_dict[cmd_name] + span_end = self.take_nearest_value( + right_brace_indices, cmd_end, n_braces + ) + 1 + if substitute_cmd: + repl_str = "\\" + cmd_name + n_braces * "{white}" + else: + repl_str = "" + result.append(((span_begin, span_end), repl_str)) + return result + + def get_ignored_indices(self) -> list[int]: + return self.script_char_indices + + def get_internal_specified_spans(self) -> list[Span]: + # Match paired double braces (`{{...}}`). + result = [] + reversed_brace_indices_dict = dict([ + pair[::-1] for pair in self.brace_index_pairs + ]) + skip = False + for prev_right_index, right_index in self.get_neighbouring_pairs( + list(reversed_brace_indices_dict.keys()) + ): + if skip: + skip = False + continue + if right_index != prev_right_index + 1: + continue + left_index = reversed_brace_indices_dict[right_index] + prev_left_index = reversed_brace_indices_dict[prev_right_index] + if left_index != prev_left_index - 1: + continue + result.append((left_index, right_index + 1)) + skip = True + return result + + def get_label_span_list(self) -> list[Span]: + result = self.script_content_spans.copy() + for span_begin, span_end in self.specified_spans: + shrinked_end = self.lslide(span_end, self.script_spans) + if span_begin >= shrinked_end: + continue + shrinked_span = (span_begin, shrinked_end) + if shrinked_span in result: + continue + result.append(shrinked_span) + return result + + def get_inserted_string_pairs( + self, use_plain_file: bool + ) -> list[tuple[Span, tuple[str, str]]]: + if use_plain_file: + return [] + + extended_label_span_list = [ + span + if span in self.script_content_spans + else (span[0], self.rslide(span[1], self.script_spans)) + for span in self.label_span_list + ] + return [ + (span, ( + self.get_begin_color_command_str(label), + self.get_end_color_command_str() + )) + for label, span in enumerate(extended_label_span_list) + ] + + def get_other_repl_items( + self, use_plain_file: bool + ) -> list[tuple[Span, str]]: + if use_plain_file: + return [] + return self.command_repl_items.copy() + + def get_has_predefined_colors(self) -> bool: + return bool(self.command_repl_items) + + def get_cleaned_substr(self, span: Span) -> str: + substr = super().get_cleaned_substr(span) + if not self.brace_index_pairs: + return substr + + # Balance braces. + left_brace_indices, right_brace_indices = zip(*self.brace_index_pairs) + unclosed_left_braces = 0 + unclosed_right_braces = 0 + for index in range(*span): + if index in left_brace_indices: + unclosed_left_braces += 1 + elif index in right_brace_indices: + if unclosed_left_braces == 0: + unclosed_right_braces += 1 + else: + unclosed_left_braces -= 1 + return "".join([ + unclosed_right_braces * "{", + substr, + unclosed_left_braces * "}" + ]) + + # Method alias + + def get_parts_by_tex(self, tex: str) -> VGroup: + return self.get_parts_by_string(tex) + + def get_part_by_tex(self, tex: str) -> VMobject: + return self.get_part_by_string(tex) + + def set_color_by_tex(self, tex: str, color: ManimColor): + return self.set_color_by_string(tex, color) + + def set_color_by_tex_to_color_map( + self, tex_to_color_map: dict[str, ManimColor] + ): + return self.set_color_by_string_to_color_map(tex_to_color_map) + + def get_tex(self) -> str: + return self.get_string() + + +class MTexText(MTex): + CONFIG = { + "tex_environment": None, + } diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 98fc6dbe..2ddd53f7 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -16,7 +16,7 @@ from manimpango import MarkupUtils from manimlib.logger import log from manimlib.constants import * -from manimlib.mobject.svg.mtex_mobject import LabelledString +from manimlib.mobject.svg.labelled_string import LabelledString from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.utils.customization import get_customization from manimlib.utils.tex_file_writing import tex_hash From a8039d803ed49934570aa8e24d9c732ebc6d16c9 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Wed, 30 Mar 2022 21:58:27 +0800 Subject: [PATCH 14/48] Rename file --- manimlib/mobject/svg/mtex_mobject.py | 843 --------------------------- 1 file changed, 843 deletions(-) delete mode 100644 manimlib/mobject/svg/mtex_mobject.py diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py deleted file mode 100644 index 55909dc6..00000000 --- a/manimlib/mobject/svg/mtex_mobject.py +++ /dev/null @@ -1,843 +0,0 @@ -from __future__ import annotations - -import re -import colour -import itertools as it -#from types import MethodType -from typing import Iterable, Union, Sequence -from abc import abstractmethod - -from manimlib.constants import WHITE -from manimlib.mobject.svg.svg_mobject import SVGMobject -from manimlib.mobject.types.vectorized_mobject import VGroup -from manimlib.utils.color import color_to_int_rgb -from manimlib.utils.config_ops import digest_config -from manimlib.utils.iterables import adjacent_pairs -from manimlib.utils.iterables import remove_list_redundancies -from manimlib.utils.tex_file_writing import tex_to_svg_file -from manimlib.utils.tex_file_writing import get_tex_config -from manimlib.utils.tex_file_writing import display_during_execution -from manimlib.logger import log - - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from manimlib.mobject.types.vectorized_mobject import VMobject - ManimColor = Union[str, colour.Color, Sequence[float]] - Span = tuple[int, int] - - -SCALE_FACTOR_PER_FONT_POINT = 0.001 - - -class _StringSVG(SVGMobject): - CONFIG = { - "height": None, - "stroke_width": 0, - "stroke_color": WHITE, - "path_string_config": { - "should_subdivide_sharp_curves": True, - "should_remove_null_curves": True, - }, - } - - -class LabelledString(_StringSVG): - """ - An abstract base class for `MTex` and `MarkupText` - """ - CONFIG = { - "base_color": WHITE, - "use_plain_file": False, - "isolate": [], - } - - def __init__(self, string: str, **kwargs): - self.string = string - digest_config(self, kwargs) - self.pre_parse() - self.parse() - super().__init__(**kwargs) - - def get_file_path(self, use_plain_file: bool = False) -> str: - content = self.get_decorated_string(use_plain_file=use_plain_file) - return self.get_file_path_by_content(content) - - @abstractmethod - def get_file_path_by_content(self, content: str) -> str: - return "" - - def generate_mobject(self) -> None: - super().generate_mobject() - - submob_labels = [ - self.color_to_label(submob.get_fill_color()) - for submob in self.submobjects - ] - if self.use_plain_file or self.has_predefined_colors: - file_path = self.get_file_path(use_plain_file=True) - plain_svg = _StringSVG(file_path) - self.set_submobjects(plain_svg.submobjects) - else: - self.set_fill(self.base_color) - for submob, label in zip(self.submobjects, submob_labels): - submob.label = label - self.submob_labels = submob_labels - self.post_parse() - - def pre_parse(self) -> None: - self.full_span = self.get_full_span() - self.space_spans = self.get_space_spans() - - def parse(self) -> None: - self.command_repl_items = self.get_command_repl_items() - self.command_spans = self.get_command_spans() - self.ignored_indices = self.get_ignored_indices() - self.skipped_spans = self.get_skipped_spans() - self.internal_specified_spans = self.get_internal_specified_spans() - self.external_specified_spans = self.get_external_specified_spans() - self.specified_spans = self.get_specified_spans() - self.label_span_list = self.get_label_span_list() - self.has_predefined_colors = self.get_has_predefined_colors() - - def post_parse(self) -> None: - self.containing_labels_dict = self.get_containing_labels_dict() - self.specified_substrings = self.get_specified_substrings() - self.group_substr_items = self.get_group_substr_items() - self.group_substrs = self.get_group_substrs() - - # Toolkits - - def find_spans(self, pattern: str) -> list[Span]: - return [ - match_obj.span() - for match_obj in re.finditer(pattern, self.string) - ] - - @staticmethod - def get_neighbouring_pairs(iterable: Iterable) -> list: - return list(adjacent_pairs(iterable))[:-1] - - @staticmethod - def span_contains(span_0: Span, span_1: Span) -> bool: - return span_0[0] <= span_1[0] and span_0[1] >= span_1[1] - - @staticmethod - def compress_neighbours(vals: list[int]) -> list[tuple[int, Span]]: - if not vals: - return [] - - unique_vals = [vals[0]] - indices = [0] - for index, val in enumerate(vals): - if val == unique_vals[-1]: - continue - unique_vals.append(val) - indices.append(index) - indices.append(len(vals)) - spans = LabelledString.get_neighbouring_pairs(indices) - return list(zip(unique_vals, spans)) - - @staticmethod - def find_region_index(seq: list[int], val: int) -> int: - # Returns an integer in `range(-1, len(seq))` satisfying - # `seq[result] <= val < seq[result + 1]`. - # `seq` should be sorted in ascending order. - if not seq or val < seq[0]: - return -1 - result = len(seq) - 1 - while val < seq[result]: - result -= 1 - return result - - @staticmethod - def take_nearest_value(seq: list[int], val: int, index_shift: int) -> int: - sorted_seq = sorted(seq) - index = LabelledString.find_region_index(sorted_seq, val) - if index == -1: - raise IndexError - return sorted_seq[index + index_shift] - - @staticmethod - def replace_str_by_spans( - substr: str, span_repl_dict: dict[Span, str] - ) -> str: - if not span_repl_dict: - return substr - - spans = sorted(span_repl_dict.keys()) - if not all( - span_0[1] <= span_1[0] - for span_0, span_1 in LabelledString.get_neighbouring_pairs(spans) - ): - raise ValueError("Overlapping replacement") - - span_ends, span_begins = zip(*spans) - pieces = [ - substr[slice(*span)] - for span in zip( - (0, *span_begins), - (*span_ends, len(substr)) - ) - ] - repl_strs = [*[span_repl_dict[span] for span in spans], ""] - return "".join(it.chain(*zip(pieces, repl_strs))) - - @staticmethod - def get_span_replacement_dict( - inserted_string_pairs: list[tuple[Span, tuple[str, str]]], - other_repl_items: list[tuple[Span, str]] - ) -> dict[Span, str]: - if not inserted_string_pairs: - return other_repl_items.copy() - - indices, _, _, inserted_strings = zip(*sorted([ - ( - span[flag], - -flag, - -span[1 - flag], - str_pair[flag] - ) - for span, str_pair in inserted_string_pairs - for flag in range(2) - ])) - result = { - (index, index): "".join(inserted_strings[slice(*item_span)]) - for index, item_span - in LabelledString.compress_neighbours(indices) - } - result.update(other_repl_items) - return result - - @staticmethod - def rslide(index: int, skipped: list[Span]) -> int: - transfer_dict = dict(sorted(skipped)) - while index in transfer_dict.keys(): - index = transfer_dict[index] - return index - - @staticmethod - def lslide(index: int, skipped: list[Span]) -> int: - transfer_dict = dict(sorted([ - skipped_span[::-1] for skipped_span in skipped - ], reverse=True)) - while index in transfer_dict.keys(): - index = transfer_dict[index] - return index - - @staticmethod - def shrink_span(span: Span, skipped: list[Span]) -> Span: - return ( - LabelledString.rslide(span[0], skipped), - LabelledString.lslide(span[1], skipped) - ) - - @staticmethod - def extend_span(span: Span, skipped: list[Span]) -> Span: - return ( - LabelledString.lslide(span[0], skipped), - LabelledString.rslide(span[1], skipped) - ) - - @staticmethod - def rgb_to_int(rgb_tuple: tuple[int, int, int]) -> int: - r, g, b = rgb_tuple - rg = r * 256 + g - return rg * 256 + b - - @staticmethod - def int_to_rgb(rgb_int: int) -> tuple[int, int, int]: - rg, b = divmod(rgb_int, 256) - r, g = divmod(rg, 256) - return r, g, b - - @staticmethod - def color_to_label(color: ManimColor) -> int: - rgb_tuple = color_to_int_rgb(color) - rgb = LabelledString.rgb_to_int(rgb_tuple) - if rgb == 16777215: # white - return -1 - return rgb - - @abstractmethod - def get_begin_color_command_str(int_rgb: int) -> str: - return "" - - @abstractmethod - def get_end_color_command_str() -> str: - return "" - - # Pre-parsing - - def get_full_span(self) -> Span: - return (0, len(self.string)) - - def get_space_spans(self) -> list[Span]: - return self.find_spans(r"\s+") - - @abstractmethod - def get_command_repl_items(self) -> list[tuple[Span, str]]: - return [] - - def get_command_spans(self) -> list[Span]: - return [cmd_span for cmd_span, _ in self.command_repl_items] - - def get_ignored_indices(self) -> list[int]: - return [] - - def get_skipped_spans(self) -> list[Span]: - return list(it.chain( - self.space_spans, - self.command_spans, - [ - (index, index + 1) - for index in self.ignored_indices - ] - )) - - @abstractmethod - def get_internal_specified_spans(self) -> list[Span]: - return [] - - def get_external_specified_spans(self) -> list[Span]: - return remove_list_redundancies(list(it.chain(*[ - self.find_spans(re.escape(substr.strip())) - for substr in self.isolate - ]))) - - def get_specified_spans(self) -> list[Span]: - spans = [ - self.full_span, - *self.internal_specified_spans, - *self.external_specified_spans - ] - shrinked_spans = list(filter( - lambda span: span[0] < span[1], - [ - self.shrink_span(span, self.skipped_spans) - for span in spans - ] - )) - return remove_list_redundancies(shrinked_spans) - - @abstractmethod - def get_label_span_list(self) -> list[Span]: - return [] - - @abstractmethod - def get_inserted_string_pairs( - self, use_plain_file: bool - ) -> list[tuple[Span, tuple[str, str]]]: - return [] - - @abstractmethod - def get_other_repl_items( - self, use_plain_file: bool - ) -> list[tuple[Span, str]]: - return [] - - def get_decorated_string(self, use_plain_file: bool) -> str: - span_repl_dict = self.get_span_replacement_dict( - self.get_inserted_string_pairs(use_plain_file), - self.get_other_repl_items(use_plain_file) - ) - result = self.replace_str_by_spans(self.string, span_repl_dict) - - if not use_plain_file: - return result - return "".join([ - self.get_begin_color_command_str( - self.rgb_to_int(color_to_int_rgb(self.base_color)) - ), - result, - self.get_end_color_command_str() - ]) - - @abstractmethod - def get_has_predefined_colors(self) -> bool: - return False - - # Post-parsing - - def get_containing_labels_dict(self) -> dict[Span, list[int]]: - label_span_list = self.label_span_list - result = { - span: [] - for span in label_span_list - } - for span_0 in label_span_list: - for span_index, span_1 in enumerate(label_span_list): - if self.span_contains(span_0, span_1): - result[span_0].append(span_index) - elif span_0[0] < span_1[0] < span_0[1] < span_1[1]: - string_0, string_1 = [ - self.string[slice(*span)] - for span in [span_0, span_1] - ] - raise ValueError( - "Partially overlapping substrings detected: " - f"'{string_0}' and '{string_1}'" - ) - if self.full_span not in result: - result[self.full_span] = list(range(len(label_span_list))) - return result - - def get_cleaned_substr(self, span: Span) -> str: - span_repl_dict = { - tuple([index - span[0] for index in cmd_span]): "" - for cmd_span in self.command_spans - if self.span_contains(span, cmd_span) - } - return self.replace_str_by_spans( - self.string[slice(*span)], span_repl_dict - ) - - def get_specified_substrings(self) -> list[str]: - return remove_list_redundancies([ - self.get_cleaned_substr(span) - for span in self.specified_spans - ]) - - def get_group_substr_items(self) -> tuple[list[Span], list[str]]: - group_labels, submob_spans = zip( - *self.compress_neighbours(self.submob_labels) - ) - ordered_spans = [ - self.label_span_list[label] if label != -1 else self.full_span - for label in group_labels - ] - ordered_containing_labels = [ - self.containing_labels_dict[span] - for span in ordered_spans - ] - ordered_span_begins, ordered_span_ends = zip(*ordered_spans) - span_begins = [ - prev_end if prev_label in containing_labels else curr_begin - for prev_end, prev_label, containing_labels, curr_begin in zip( - ordered_span_ends[:-1], group_labels[:-1], - ordered_containing_labels[1:], ordered_span_begins[1:] - ) - ] - span_ends = [ - next_begin if next_label in containing_labels else curr_end - for next_begin, next_label, containing_labels, curr_end in zip( - ordered_span_begins[1:], group_labels[1:], - ordered_containing_labels[:-1], ordered_span_ends[:-1] - ) - ] - spans = list(zip( - (ordered_span_begins[0], *span_begins), - (*span_ends, ordered_span_ends[-1]) - )) - shrinked_spans = [ - self.shrink_span(span, self.skipped_spans) - for span in spans - ] - group_substrs = [ - self.get_cleaned_substr(span) if span[0] < span[1] else "" - for span in shrinked_spans - ] - return submob_spans, group_substrs - - def get_group_substrs(self) -> list[str]: - return self.group_substr_items[1] - - # Selector - - def find_span_components_of_custom_span( - self, custom_span: Span - ) -> list[Span]: - indices = remove_list_redundancies(list(it.chain( - self.full_span, - *self.label_span_list - ))) - span_begin = self.take_nearest_value(indices, custom_span[0], 0) - span_end = self.take_nearest_value(indices, custom_span[1] - 1, 1) - span_choices = sorted(filter( - lambda span: self.span_contains((span_begin, span_end), span), - self.label_span_list - )) - # Choose spans that reach the farthest. - span_choices_dict = dict(span_choices) - - result = [] - while span_begin < span_end: - if span_begin not in span_choices_dict.keys(): - span_begin += 1 - continue - next_begin = span_choices_dict[span_begin] - result.append((span_begin, next_begin)) - span_begin = next_begin - return result - - def get_parts_by_custom_span(self, custom_span: Span) -> VGroup: - spans = self.find_span_components_of_custom_span(custom_span) - labels = set(it.chain(*[ - self.containing_labels_dict[span] - for span in spans - ])) - return VGroup(*filter( - lambda submob: submob.label in labels, - self.submobjects - )) - - def get_parts_by_string(self, substr: str) -> VGroup: - return VGroup(*[ - self.get_parts_by_custom_span(span) - for span in self.find_spans(re.escape(substr.strip())) - ]) - - def get_parts_by_group_substr(self, substr: str) -> VGroup: - return VGroup(*[ - VGroup(*self.submobjects[slice(*submob_span)]) - for submob_span, group_substr in zip(*self.group_substr_items) - if group_substr == substr - ]) - - def get_part_by_string(self, substr: str, index : int = 0) -> VMobject: - return self.get_parts_by_string(substr)[index] - - def set_color_by_string(self, substr: str, color: ManimColor): - self.get_parts_by_string(substr).set_color(color) - return self - - def set_color_by_string_to_color_map( - self, string_to_color_map: dict[str, ManimColor] - ): - for substr, color in string_to_color_map.items(): - self.set_color_by_string(substr, color) - return self - - def indices_of_part(self, part: Iterable[VMobject]) -> list[int]: - return [self.submobjects.index(submob) for submob in part] - - def indices_lists_of_parts( - self, parts: Iterable[Iterable[VMobject]] - ) -> list[list[int]]: - return [self.indices_of_part(part) for part in parts] - - def get_string(self) -> str: - return self.string - - -class MTex(LabelledString): - CONFIG = { - "font_size": 48, - "alignment": "\\centering", - "tex_environment": "align*", - "tex_to_color_map": {}, - } - - def __init__(self, tex_string: str, **kwargs): - digest_config(self, kwargs) - tex_string = tex_string.strip() - # Prevent from passing an empty string. - if not tex_string: - tex_string = "\\quad" - self.tex_string = tex_string - self.isolate.extend(self.tex_to_color_map.keys()) - super().__init__(tex_string, **kwargs) - - self.set_color_by_tex_to_color_map(self.tex_to_color_map) - self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size) - - @property - def hash_seed(self) -> tuple: - return ( - self.__class__.__name__, - self.svg_default, - self.path_string_config, - self.base_color, - self.use_plain_file, - self.isolate, - self.tex_string, - self.alignment, - self.tex_environment, - self.tex_to_color_map - ) - - def get_file_path_by_content(self, content: str) -> str: - full_tex = self.get_tex_file_body(content) - with display_during_execution(f"Writing \"{self.string}\""): - file_path = self.tex_to_svg_file_path(full_tex) - return file_path - - def get_tex_file_body(self, content: str) -> str: - if self.tex_environment: - content = "\n".join([ - f"\\begin{{{self.tex_environment}}}", - content, - f"\\end{{{self.tex_environment}}}" - ]) - if self.alignment: - content = "\n".join([self.alignment, content]) - - tex_config = get_tex_config() - return tex_config["tex_body"].replace( - tex_config["text_to_replace"], - content - ) - - @staticmethod - def tex_to_svg_file_path(tex_file_content: str) -> str: - return tex_to_svg_file(tex_file_content) - - def pre_parse(self) -> None: - super().pre_parse() - self.backslash_indices = self.get_backslash_indices() - self.brace_index_pairs = self.get_brace_index_pairs() - self.script_char_indices = self.get_script_char_indices() - self.script_char_spans = self.get_script_char_spans() - self.script_content_spans = self.get_script_content_spans() - self.script_spans = self.get_script_spans() - - # Toolkits - - @staticmethod - def get_begin_color_command_str(rgb_int: int) -> str: - rgb_tuple = MTex.int_to_rgb(rgb_int) - return "".join([ - "{{", - "\\color[RGB]", - "{", - ",".join(map(str, rgb_tuple)), - "}" - ]) - - @staticmethod - def get_end_color_command_str() -> str: - return "}}" - - # Pre-parsing - - def get_backslash_indices(self) -> list[int]: - # Newlines (`\\`) don't count. - return [ - span[1] - 1 - for span in self.find_spans(r"\\+") - if (span[1] - span[0]) % 2 == 1 - ] - - def get_unescaped_char_indices(self, *chars: str): - return sorted(filter( - lambda index: index - 1 not in self.backslash_indices, - [ - span[0] - for char in chars - for span in self.find_spans(re.escape(char)) - ] - )) - - def get_brace_index_pairs(self) -> list[Span]: - string = self.string - indices = self.get_unescaped_char_indices("{", "}") - left_brace_indices = [] - right_brace_indices = [] - left_brace_indices_stack = [] - for index in indices: - if string[index] == "{": - left_brace_indices_stack.append(index) - else: - if not left_brace_indices_stack: - raise ValueError("Missing '{' inserted") - left_brace_index = left_brace_indices_stack.pop() - left_brace_indices.append(left_brace_index) - right_brace_indices.append(index) - if left_brace_indices_stack: - raise ValueError("Missing '}' inserted") - return list(zip(left_brace_indices, right_brace_indices)) - - def get_script_char_indices(self) -> list[int]: - return self.get_unescaped_char_indices("_", "^") - - def get_script_char_spans(self) -> list[Span]: - return [ - self.extend_span((index, index + 1), self.space_spans) - for index in self.script_char_indices - ] - - def get_script_content_spans(self) -> list[Span]: - result = [] - brace_indices_dict = dict(self.brace_index_pairs) - for _, span_begin in self.script_char_spans: - if span_begin in brace_indices_dict.keys(): - span_end = brace_indices_dict[span_begin] + 1 - else: - pattern = re.compile(r"[a-zA-Z0-9]|\\[a-zA-Z]+") - match_obj = pattern.match(self.string, pos=span_begin) - if not match_obj: - script_name = { - "_": "subscript", - "^": "superscript" - }[script_char] - raise ValueError( - f"Unclear {script_name} detected while parsing. " - "Please use braces to clarify" - ) - span_end = match_obj.end() - result.append((span_begin, span_end)) - return result - - def get_script_spans(self) -> list[Span]: - return [ - (script_char_span[0], script_content_span[1]) - for script_char_span, script_content_span in zip( - self.script_char_spans, self.script_content_spans - ) - ] - - # Parsing - - def get_command_repl_items(self) -> list[tuple[Span, str]]: - color_related_command_dict = { - "color": (1, False), - "textcolor": (1, False), - "pagecolor": (1, True), - "colorbox": (1, True), - "fcolorbox": (2, True), - } - result = [] - backslash_indices = self.backslash_indices - right_brace_indices = [ - right_index - for left_index, right_index in self.brace_index_pairs - ] - pattern = "".join([ - r"\\", - "(", - "|".join(color_related_command_dict.keys()), - ")", - r"(?![a-zA-Z])" - ]) - for match_obj in re.finditer(pattern, self.string): - span_begin, cmd_end = match_obj.span() - if span_begin not in backslash_indices: - continue - cmd_name = match_obj.group(1) - n_braces, substitute_cmd = color_related_command_dict[cmd_name] - span_end = self.take_nearest_value( - right_brace_indices, cmd_end, n_braces - ) + 1 - if substitute_cmd: - repl_str = "\\" + cmd_name + n_braces * "{white}" - else: - repl_str = "" - result.append(((span_begin, span_end), repl_str)) - return result - - def get_ignored_indices(self) -> list[int]: - return self.script_char_indices - - def get_internal_specified_spans(self) -> list[Span]: - # Match paired double braces (`{{...}}`). - result = [] - reversed_brace_indices_dict = dict([ - pair[::-1] for pair in self.brace_index_pairs - ]) - skip = False - for prev_right_index, right_index in self.get_neighbouring_pairs( - list(reversed_brace_indices_dict.keys()) - ): - if skip: - skip = False - continue - if right_index != prev_right_index + 1: - continue - left_index = reversed_brace_indices_dict[right_index] - prev_left_index = reversed_brace_indices_dict[prev_right_index] - if left_index != prev_left_index - 1: - continue - result.append((left_index, right_index + 1)) - skip = True - return result - - def get_label_span_list(self) -> list[Span]: - result = self.script_content_spans.copy() - for span_begin, span_end in self.specified_spans: - shrinked_end = self.lslide(span_end, self.script_spans) - if span_begin >= shrinked_end: - continue - shrinked_span = (span_begin, shrinked_end) - if shrinked_span in result: - continue - result.append(shrinked_span) - return result - - def get_inserted_string_pairs( - self, use_plain_file: bool - ) -> list[tuple[Span, tuple[str, str]]]: - if use_plain_file: - return [] - - extended_label_span_list = [ - span - if span in self.script_content_spans - else (span[0], self.rslide(span[1], self.script_spans)) - for span in self.label_span_list - ] - return [ - (span, ( - self.get_begin_color_command_str(label), - self.get_end_color_command_str() - )) - for label, span in enumerate(extended_label_span_list) - ] - - def get_other_repl_items( - self, use_plain_file: bool - ) -> list[tuple[Span, str]]: - if use_plain_file: - return [] - return self.command_repl_items.copy() - - def get_has_predefined_colors(self) -> bool: - return bool(self.command_repl_items) - - def get_cleaned_substr(self, span: Span) -> str: - substr = super().get_cleaned_substr(span) - if not self.brace_index_pairs: - return substr - - # Balance braces. - left_brace_indices, right_brace_indices = zip(*self.brace_index_pairs) - unclosed_left_braces = 0 - unclosed_right_braces = 0 - for index in range(*span): - if index in left_brace_indices: - unclosed_left_braces += 1 - elif index in right_brace_indices: - if unclosed_left_braces == 0: - unclosed_right_braces += 1 - else: - unclosed_left_braces -= 1 - return "".join([ - unclosed_right_braces * "{", - substr, - unclosed_left_braces * "}" - ]) - - # Method alias - - def get_parts_by_tex(self, tex: str) -> VGroup: - return self.get_parts_by_string(tex) - - def get_part_by_tex(self, tex: str) -> VMobject: - return self.get_part_by_string(tex) - - def set_color_by_tex(self, tex: str, color: ManimColor): - return self.set_color_by_string(tex, color) - - def set_color_by_tex_to_color_map( - self, tex_to_color_map: dict[str, ManimColor] - ): - return self.set_color_by_string_to_color_map(tex_to_color_map) - - def get_tex(self) -> str: - return self.get_string() - - -class MTexText(MTex): - CONFIG = { - "tex_environment": None, - } From 9bbbed3a83ab205d114f760157afb41d9bb6c93a Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Wed, 30 Mar 2022 22:04:10 +0800 Subject: [PATCH 15/48] Remove comment --- manimlib/mobject/svg/labelled_string.py | 1 - 1 file changed, 1 deletion(-) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index 55909dc6..144f239c 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -3,7 +3,6 @@ from __future__ import annotations import re import colour import itertools as it -#from types import MethodType from typing import Iterable, Union, Sequence from abc import abstractmethod From 637d7791900dca0db087f4e877c9b04f76c0f197 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Wed, 30 Mar 2022 22:09:26 +0800 Subject: [PATCH 16/48] Fix empty zipping bug --- manimlib/mobject/svg/labelled_string.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index 144f239c..24b1f5fe 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -399,6 +399,9 @@ class LabelledString(_StringSVG): ]) def get_group_substr_items(self) -> tuple[list[Span], list[str]]: + if not self.submob_labels: + return [], [] + group_labels, submob_spans = zip( *self.compress_neighbours(self.submob_labels) ) From fc4f64957064198e0116afe25f849e4d1f304736 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Thu, 31 Mar 2022 10:36:14 +0800 Subject: [PATCH 17/48] Fix bugs brought by empty strings --- manimlib/__init__.py | 2 +- manimlib/animation/creation.py | 25 ++++++-------- manimlib/mobject/svg/labelled_string.py | 46 ++++++++++++++++++------- manimlib/mobject/svg/text_mobject.py | 4 +-- 4 files changed, 47 insertions(+), 30 deletions(-) diff --git a/manimlib/__init__.py b/manimlib/__init__.py index 13e41ec0..954a56d7 100644 --- a/manimlib/__init__.py +++ b/manimlib/__init__.py @@ -37,7 +37,7 @@ from manimlib.mobject.probability import * from manimlib.mobject.shape_matchers import * from manimlib.mobject.svg.brace import * from manimlib.mobject.svg.drawings import * -from manimlib.mobject.svg.mtex_mobject import * +from manimlib.mobject.svg.labelled_string import * from manimlib.mobject.svg.svg_mobject import * from manimlib.mobject.svg.tex_mobject import * from manimlib.mobject.svg.text_mobject import * diff --git a/manimlib/animation/creation.py b/manimlib/animation/creation.py index 00588b46..82fb1605 100644 --- a/manimlib/animation/creation.py +++ b/manimlib/animation/creation.py @@ -7,6 +7,7 @@ import numpy as np from manimlib.animation.animation import Animation from manimlib.animation.composition import Succession +from manimlib.mobject.svg.labelled_string import LabelledString from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.utils.bezier import integer_interpolate from manimlib.utils.config_ops import digest_config @@ -202,23 +203,19 @@ class ShowSubmobjectsOneByOne(ShowIncreasingSubsets): self.mobject.set_submobjects([self.all_submobs[index - 1]]) -# TODO, this is broken... -class AddTextWordByWord(Succession): +class AddTextWordByWord(ShowIncreasingSubsets): CONFIG = { # If given a value for run_time, it will - # override the time_per_char + # override the time_per_word "run_time": None, - "time_per_char": 0.06, + "time_per_word": 0.2, + "rate_func": linear, } - def __init__(self, text_mobject, **kwargs): + def __init__(self, string_mobject, **kwargs): + assert isinstance(string_mobject, LabelledString) + grouped_mobject = string_mobject.get_submob_groups() digest_config(self, kwargs) - tpc = self.time_per_char - anims = it.chain(*[ - [ - ShowIncreasingSubsets(word, run_time=tpc * len(word)), - Animation(word, run_time=0.005 * len(word)**1.5), - ] - for word in text_mobject - ]) - super().__init__(*anims, **kwargs) + if self.run_time is None: + self.run_time = self.time_per_word * len(grouped_mobject) + super().__init__(grouped_mobject, **kwargs) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index 24b1f5fe..a0d66c4d 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -103,7 +103,6 @@ class LabelledString(_StringSVG): def post_parse(self) -> None: self.containing_labels_dict = self.get_containing_labels_dict() self.specified_substrings = self.get_specified_substrings() - self.group_substr_items = self.get_group_substr_items() self.group_substrs = self.get_group_substrs() # Toolkits @@ -275,6 +274,8 @@ class LabelledString(_StringSVG): def get_space_spans(self) -> list[Span]: return self.find_spans(r"\s+") + # Parsing + @abstractmethod def get_command_repl_items(self) -> list[tuple[Span, str]]: return [] @@ -300,8 +301,13 @@ class LabelledString(_StringSVG): return [] def get_external_specified_spans(self) -> list[Span]: + if "" in self.isolate: + return self.get_neighbouring_pairs( + list(range(len(self.string) + 1)) + ) + return remove_list_redundancies(list(it.chain(*[ - self.find_spans(re.escape(substr.strip())) + self.find_spans(re.escape(substr)) for substr in self.isolate ]))) @@ -398,13 +404,15 @@ class LabelledString(_StringSVG): for span in self.specified_spans ]) - def get_group_substr_items(self) -> tuple[list[Span], list[str]]: + def get_group_span_items(self) -> tuple[list[int], list[Span]]: if not self.submob_labels: return [], [] - - group_labels, submob_spans = zip( + return tuple(zip( *self.compress_neighbours(self.submob_labels) - ) + )) + + def get_group_substrs(self) -> list[str]: + group_labels, _ = self.get_group_span_items() ordered_spans = [ self.label_span_list[label] if label != -1 else self.full_span for label in group_labels @@ -440,16 +448,23 @@ class LabelledString(_StringSVG): self.get_cleaned_substr(span) if span[0] < span[1] else "" for span in shrinked_spans ] - return submob_spans, group_substrs + return group_substrs - def get_group_substrs(self) -> list[str]: - return self.group_substr_items[1] + def get_submob_groups(self) -> VGroup: + _, submob_spans = self.get_group_span_items() + return VGroup(*[ + VGroup(*self.submobjects[slice(*submob_span)]) + for submob_span in submob_spans + ]) # Selector def find_span_components_of_custom_span( self, custom_span: Span ) -> list[Span]: + if custom_span[0] >= custom_span[1]: + return [] + indices = remove_list_redundancies(list(it.chain( self.full_span, *self.label_span_list @@ -485,15 +500,19 @@ class LabelledString(_StringSVG): )) def get_parts_by_string(self, substr: str) -> VGroup: + if not substr: + return VGroup() return VGroup(*[ self.get_parts_by_custom_span(span) - for span in self.find_spans(re.escape(substr.strip())) + for span in self.find_spans(re.escape(substr)) ]) def get_parts_by_group_substr(self, substr: str) -> VGroup: return VGroup(*[ - VGroup(*self.submobjects[slice(*submob_span)]) - for submob_span, group_substr in zip(*self.group_substr_items) + group + for group, group_substr in zip( + self.get_submob_groups(), self.group_substrs + ) if group_substr == substr ]) @@ -533,7 +552,6 @@ class MTex(LabelledString): def __init__(self, tex_string: str, **kwargs): digest_config(self, kwargs) - tex_string = tex_string.strip() # Prevent from passing an empty string. if not tex_string: tex_string = "\\quad" @@ -796,6 +814,8 @@ class MTex(LabelledString): def get_has_predefined_colors(self) -> bool: return bool(self.command_repl_items) + # Post-parsing + def get_cleaned_substr(self, span: Span) -> str: substr = super().get_cleaned_substr(span) if not self.brace_index_pairs: diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 2ddd53f7..83637b01 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -234,13 +234,13 @@ class MarkupText(LabelledString): pango_width=pango_width ) - def parse(self) -> None: + def pre_parse(self) -> None: + super().pre_parse() self.global_items_from_config = self.get_global_items_from_config() self.tag_items_from_markup = self.get_tag_items_from_markup() self.local_items_from_markup = self.get_local_items_from_markup() self.local_items_from_config = self.get_local_items_from_config() self.predefined_items = self.get_predefined_items() - super().parse() # Toolkits From 461500637e07ac0b6266a85a18b30f3d13810e1d Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Thu, 31 Mar 2022 10:57:25 +0800 Subject: [PATCH 18/48] Fix type bug --- manimlib/mobject/svg/labelled_string.py | 10 +++++----- manimlib/mobject/svg/text_mobject.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index a0d66c4d..989f2cb9 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -187,8 +187,9 @@ class LabelledString(_StringSVG): inserted_string_pairs: list[tuple[Span, tuple[str, str]]], other_repl_items: list[tuple[Span, str]] ) -> dict[Span, str]: + result = dict(other_repl_items) if not inserted_string_pairs: - return other_repl_items.copy() + return result indices, _, _, inserted_strings = zip(*sorted([ ( @@ -200,12 +201,11 @@ class LabelledString(_StringSVG): for span, str_pair in inserted_string_pairs for flag in range(2) ])) - result = { + result.update({ (index, index): "".join(inserted_strings[slice(*item_span)]) for index, item_span in LabelledString.compress_neighbours(indices) - } - result.update(other_repl_items) + }) return result @staticmethod @@ -272,7 +272,7 @@ class LabelledString(_StringSVG): return (0, len(self.string)) def get_space_spans(self) -> list[Span]: - return self.find_spans(r"\s+") + return self.find_spans(r"\s") # Parsing diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 83637b01..2b6211d8 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -470,7 +470,7 @@ class MarkupText(LabelledString): def get_label_span_list(self) -> list[Span]: breakup_indices = remove_list_redundancies(list(it.chain(*it.chain( - self.space_spans, + self.find_spans(r"\s+"), self.find_spans(r"\b"), self.specified_spans )))) From 724a500cc6d8a281fee434775c19353ab756b82a Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Thu, 31 Mar 2022 11:20:42 +0800 Subject: [PATCH 19/48] Fix shallow copying bug --- manimlib/mobject/svg/labelled_string.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index 989f2cb9..120ce48b 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -92,7 +92,7 @@ class LabelledString(_StringSVG): def parse(self) -> None: self.command_repl_items = self.get_command_repl_items() self.command_spans = self.get_command_spans() - self.ignored_indices = self.get_ignored_indices() + self.ignored_spans = self.get_ignored_spans() self.skipped_spans = self.get_skipped_spans() self.internal_specified_spans = self.get_internal_specified_spans() self.external_specified_spans = self.get_external_specified_spans() @@ -283,17 +283,14 @@ class LabelledString(_StringSVG): def get_command_spans(self) -> list[Span]: return [cmd_span for cmd_span, _ in self.command_repl_items] - def get_ignored_indices(self) -> list[int]: + def get_ignored_spans(self) -> list[int]: return [] def get_skipped_spans(self) -> list[Span]: return list(it.chain( self.space_spans, self.command_spans, - [ - (index, index + 1) - for index in self.ignored_indices - ] + self.ignored_spans )) @abstractmethod @@ -746,8 +743,11 @@ class MTex(LabelledString): result.append(((span_begin, span_end), repl_str)) return result - def get_ignored_indices(self) -> list[int]: - return self.script_char_indices + def get_ignored_spans(self) -> list[int]: + return [ + (index, index + 1) + for index in self.script_char_indices + ] def get_internal_specified_spans(self) -> list[Span]: # Match paired double braces (`{{...}}`). From 106f2a3837738f3ea46166b18be318c6a3fad2e9 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Thu, 31 Mar 2022 11:36:50 +0800 Subject: [PATCH 20/48] Fix shallow copying bug --- manimlib/mobject/svg/labelled_string.py | 30 +++++++++---------------- 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index 120ce48b..41fe4225 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -231,13 +231,6 @@ class LabelledString(_StringSVG): LabelledString.lslide(span[1], skipped) ) - @staticmethod - def extend_span(span: Span, skipped: list[Span]) -> Span: - return ( - LabelledString.lslide(span[0], skipped), - LabelledString.rslide(span[1], skipped) - ) - @staticmethod def rgb_to_int(rgb_tuple: tuple[int, int, int]) -> int: r, g, b = rgb_tuple @@ -604,7 +597,6 @@ class MTex(LabelledString): super().pre_parse() self.backslash_indices = self.get_backslash_indices() self.brace_index_pairs = self.get_brace_index_pairs() - self.script_char_indices = self.get_script_char_indices() self.script_char_spans = self.get_script_char_spans() self.script_content_spans = self.get_script_content_spans() self.script_spans = self.get_script_spans() @@ -665,19 +657,17 @@ class MTex(LabelledString): raise ValueError("Missing '}' inserted") return list(zip(left_brace_indices, right_brace_indices)) - def get_script_char_indices(self) -> list[int]: - return self.get_unescaped_char_indices("_", "^") - - def get_script_char_spans(self) -> list[Span]: + def get_script_char_spans(self) -> list[int]: return [ - self.extend_span((index, index + 1), self.space_spans) - for index in self.script_char_indices + (index, index + 1) + for index in self.get_unescaped_char_indices("_", "^") ] def get_script_content_spans(self) -> list[Span]: result = [] brace_indices_dict = dict(self.brace_index_pairs) - for _, span_begin in self.script_char_spans: + for script_char_span in self.script_char_spans: + span_begin = self.rslide(script_char_span[1], self.space_spans) if span_begin in brace_indices_dict.keys(): span_end = brace_indices_dict[span_begin] + 1 else: @@ -698,7 +688,10 @@ class MTex(LabelledString): def get_script_spans(self) -> list[Span]: return [ - (script_char_span[0], script_content_span[1]) + ( + self.lslide(script_char_span[0], self.space_spans), + script_content_span[1] + ) for script_char_span, script_content_span in zip( self.script_char_spans, self.script_content_spans ) @@ -744,10 +737,7 @@ class MTex(LabelledString): return result def get_ignored_spans(self) -> list[int]: - return [ - (index, index + 1) - for index in self.script_char_indices - ] + return self.script_char_spans.copy() def get_internal_specified_spans(self) -> list[Span]: # Match paired double braces (`{{...}}`). From d5ab9a91c4de06b43234ce39deb1f179148473bf Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Thu, 31 Mar 2022 16:15:58 +0800 Subject: [PATCH 21/48] Reorganize files --- manimlib/__init__.py | 1 + .../animation/transform_matching_parts.py | 20 +- manimlib/mobject/svg/labelled_string.py | 417 ++---------------- manimlib/mobject/svg/mtex_mobject.py | 338 ++++++++++++++ manimlib/mobject/svg/text_mobject.py | 40 +- 5 files changed, 409 insertions(+), 407 deletions(-) create mode 100644 manimlib/mobject/svg/mtex_mobject.py diff --git a/manimlib/__init__.py b/manimlib/__init__.py index 954a56d7..a0147cf7 100644 --- a/manimlib/__init__.py +++ b/manimlib/__init__.py @@ -38,6 +38,7 @@ from manimlib.mobject.shape_matchers import * from manimlib.mobject.svg.brace import * from manimlib.mobject.svg.drawings import * from manimlib.mobject.svg.labelled_string import * +from manimlib.mobject.svg.mtex_mobject import * from manimlib.mobject.svg.svg_mobject import * from manimlib.mobject.svg.tex_mobject import * from manimlib.mobject.svg.text_mobject import * diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index 2b452d2d..f824663d 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -199,7 +199,7 @@ class TransformMatchingStrings(AnimationGroup): filtered_source_indices_lists, filtered_target_indices_lists ]): - return + continue anims.append(anim_class(source_parts, target_parts, **kwargs)) for index in it.chain(*filtered_source_indices_lists): rest_source_indices.remove(index) @@ -207,12 +207,10 @@ class TransformMatchingStrings(AnimationGroup): rest_target_indices.remove(index) def get_common_substrs(func): - result = sorted(list( - set(func(source_mobject)).intersection(func(target_mobject)) - ), key=len, reverse=True) - if "" in result: - result.remove("") - return result + return sorted([ + substr for substr in func(source_mobject) + if substr and substr in func(target_mobject) + ], key=len, reverse=True) def get_parts_from_keys(mobject, keys): if not isinstance(keys, tuple): @@ -241,16 +239,12 @@ class TransformMatchingStrings(AnimationGroup): add_anims_from( FadeTransformPieces, LabelledString.get_parts_by_string, - get_common_substrs( - lambda mobject: mobject.specified_substrings - ) + get_common_substrs(LabelledString.get_specified_substrs) ) add_anims_from( FadeTransformPieces, LabelledString.get_parts_by_group_substr, - get_common_substrs( - lambda mobject: mobject.group_substrs - ) + get_common_substrs(LabelledString.get_group_substrs) ) fade_source = VGroup(*[ diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index 41fe4225..8f89ab2b 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -13,10 +13,6 @@ from manimlib.utils.color import color_to_int_rgb from manimlib.utils.config_ops import digest_config from manimlib.utils.iterables import adjacent_pairs from manimlib.utils.iterables import remove_list_redundancies -from manimlib.utils.tex_file_writing import tex_to_svg_file -from manimlib.utils.tex_file_writing import get_tex_config -from manimlib.utils.tex_file_writing import display_during_execution -from manimlib.logger import log from typing import TYPE_CHECKING @@ -27,9 +23,6 @@ if TYPE_CHECKING: Span = tuple[int, int] -SCALE_FACTOR_PER_FONT_POINT = 0.001 - - class _StringSVG(SVGMobject): CONFIG = { "height": None, @@ -82,8 +75,6 @@ class LabelledString(_StringSVG): self.set_fill(self.base_color) for submob, label in zip(self.submobjects, submob_labels): submob.label = label - self.submob_labels = submob_labels - self.post_parse() def pre_parse(self) -> None: self.full_span = self.get_full_span() @@ -98,12 +89,8 @@ class LabelledString(_StringSVG): self.external_specified_spans = self.get_external_specified_spans() self.specified_spans = self.get_specified_spans() self.label_span_list = self.get_label_span_list() - self.has_predefined_colors = self.get_has_predefined_colors() - - def post_parse(self) -> None: self.containing_labels_dict = self.get_containing_labels_dict() - self.specified_substrings = self.get_specified_substrings() - self.group_substrs = self.get_group_substrs() + self.has_predefined_colors = self.get_has_predefined_colors() # Toolkits @@ -113,6 +100,12 @@ class LabelledString(_StringSVG): for match_obj in re.finditer(pattern, self.string) ] + def find_substr(self, *substrs: str) -> list[Span]: + return list(it.chain(*[ + self.find_spans(re.escape(substr)) if substr else [] + for substr in remove_list_redundancies(substrs) + ])) + @staticmethod def get_neighbouring_pairs(iterable: Iterable) -> list: return list(adjacent_pairs(iterable))[:-1] @@ -290,22 +283,16 @@ class LabelledString(_StringSVG): def get_internal_specified_spans(self) -> list[Span]: return [] + @abstractmethod def get_external_specified_spans(self) -> list[Span]: - if "" in self.isolate: - return self.get_neighbouring_pairs( - list(range(len(self.string) + 1)) - ) - - return remove_list_redundancies(list(it.chain(*[ - self.find_spans(re.escape(substr)) - for substr in self.isolate - ]))) + return [] def get_specified_spans(self) -> list[Span]: spans = [ self.full_span, *self.internal_specified_spans, - *self.external_specified_spans + *self.external_specified_spans, + *self.find_substr(*self.isolate) ] shrinked_spans = list(filter( lambda span: span[0] < span[1], @@ -320,6 +307,29 @@ class LabelledString(_StringSVG): def get_label_span_list(self) -> list[Span]: return [] + def get_containing_labels_dict(self) -> dict[Span, list[int]]: + label_span_list = self.label_span_list + result = { + span: [] + for span in label_span_list + } + for span_0 in label_span_list: + for span_index, span_1 in enumerate(label_span_list): + if self.span_contains(span_0, span_1): + result[span_0].append(span_index) + elif span_0[0] < span_1[0] < span_0[1] < span_1[1]: + string_0, string_1 = [ + self.string[slice(*span)] + for span in [span_0, span_1] + ] + raise ValueError( + "Partially overlapping substrings detected: " + f"'{string_0}' and '{string_1}'" + ) + if self.full_span not in result: + result[self.full_span] = list(range(len(label_span_list))) + return result + @abstractmethod def get_inserted_string_pairs( self, use_plain_file: bool @@ -355,29 +365,6 @@ class LabelledString(_StringSVG): # Post-parsing - def get_containing_labels_dict(self) -> dict[Span, list[int]]: - label_span_list = self.label_span_list - result = { - span: [] - for span in label_span_list - } - for span_0 in label_span_list: - for span_index, span_1 in enumerate(label_span_list): - if self.span_contains(span_0, span_1): - result[span_0].append(span_index) - elif span_0[0] < span_1[0] < span_0[1] < span_1[1]: - string_0, string_1 = [ - self.string[slice(*span)] - for span in [span_0, span_1] - ] - raise ValueError( - "Partially overlapping substrings detected: " - f"'{string_0}' and '{string_1}'" - ) - if self.full_span not in result: - result[self.full_span] = list(range(len(label_span_list))) - return result - def get_cleaned_substr(self, span: Span) -> str: span_repl_dict = { tuple([index - span[0] for index in cmd_span]): "" @@ -388,18 +375,17 @@ class LabelledString(_StringSVG): self.string[slice(*span)], span_repl_dict ) - def get_specified_substrings(self) -> list[str]: + def get_specified_substrs(self) -> list[str]: return remove_list_redundancies([ self.get_cleaned_substr(span) for span in self.specified_spans ]) def get_group_span_items(self) -> tuple[list[int], list[Span]]: - if not self.submob_labels: + submob_labels = [submob.label for submob in self.submobjects] + if not submob_labels: return [], [] - return tuple(zip( - *self.compress_neighbours(self.submob_labels) - )) + return tuple(zip(*self.compress_neighbours(submob_labels))) def get_group_substrs(self) -> list[str]: group_labels, _ = self.get_group_span_items() @@ -494,14 +480,14 @@ class LabelledString(_StringSVG): return VGroup() return VGroup(*[ self.get_parts_by_custom_span(span) - for span in self.find_spans(re.escape(substr)) + for span in self.find_substr(substr) ]) def get_parts_by_group_substr(self, substr: str) -> VGroup: return VGroup(*[ group for group, group_substr in zip( - self.get_submob_groups(), self.group_substrs + self.get_submob_groups(), self.get_group_substrs() ) if group_substr == substr ]) @@ -530,326 +516,3 @@ class LabelledString(_StringSVG): def get_string(self) -> str: return self.string - - -class MTex(LabelledString): - CONFIG = { - "font_size": 48, - "alignment": "\\centering", - "tex_environment": "align*", - "tex_to_color_map": {}, - } - - def __init__(self, tex_string: str, **kwargs): - digest_config(self, kwargs) - # Prevent from passing an empty string. - if not tex_string: - tex_string = "\\quad" - self.tex_string = tex_string - self.isolate.extend(self.tex_to_color_map.keys()) - super().__init__(tex_string, **kwargs) - - self.set_color_by_tex_to_color_map(self.tex_to_color_map) - self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size) - - @property - def hash_seed(self) -> tuple: - return ( - self.__class__.__name__, - self.svg_default, - self.path_string_config, - self.base_color, - self.use_plain_file, - self.isolate, - self.tex_string, - self.alignment, - self.tex_environment, - self.tex_to_color_map - ) - - def get_file_path_by_content(self, content: str) -> str: - full_tex = self.get_tex_file_body(content) - with display_during_execution(f"Writing \"{self.string}\""): - file_path = self.tex_to_svg_file_path(full_tex) - return file_path - - def get_tex_file_body(self, content: str) -> str: - if self.tex_environment: - content = "\n".join([ - f"\\begin{{{self.tex_environment}}}", - content, - f"\\end{{{self.tex_environment}}}" - ]) - if self.alignment: - content = "\n".join([self.alignment, content]) - - tex_config = get_tex_config() - return tex_config["tex_body"].replace( - tex_config["text_to_replace"], - content - ) - - @staticmethod - def tex_to_svg_file_path(tex_file_content: str) -> str: - return tex_to_svg_file(tex_file_content) - - def pre_parse(self) -> None: - super().pre_parse() - self.backslash_indices = self.get_backslash_indices() - self.brace_index_pairs = self.get_brace_index_pairs() - self.script_char_spans = self.get_script_char_spans() - self.script_content_spans = self.get_script_content_spans() - self.script_spans = self.get_script_spans() - - # Toolkits - - @staticmethod - def get_begin_color_command_str(rgb_int: int) -> str: - rgb_tuple = MTex.int_to_rgb(rgb_int) - return "".join([ - "{{", - "\\color[RGB]", - "{", - ",".join(map(str, rgb_tuple)), - "}" - ]) - - @staticmethod - def get_end_color_command_str() -> str: - return "}}" - - # Pre-parsing - - def get_backslash_indices(self) -> list[int]: - # Newlines (`\\`) don't count. - return [ - span[1] - 1 - for span in self.find_spans(r"\\+") - if (span[1] - span[0]) % 2 == 1 - ] - - def get_unescaped_char_indices(self, *chars: str): - return sorted(filter( - lambda index: index - 1 not in self.backslash_indices, - [ - span[0] - for char in chars - for span in self.find_spans(re.escape(char)) - ] - )) - - def get_brace_index_pairs(self) -> list[Span]: - string = self.string - indices = self.get_unescaped_char_indices("{", "}") - left_brace_indices = [] - right_brace_indices = [] - left_brace_indices_stack = [] - for index in indices: - if string[index] == "{": - left_brace_indices_stack.append(index) - else: - if not left_brace_indices_stack: - raise ValueError("Missing '{' inserted") - left_brace_index = left_brace_indices_stack.pop() - left_brace_indices.append(left_brace_index) - right_brace_indices.append(index) - if left_brace_indices_stack: - raise ValueError("Missing '}' inserted") - return list(zip(left_brace_indices, right_brace_indices)) - - def get_script_char_spans(self) -> list[int]: - return [ - (index, index + 1) - for index in self.get_unescaped_char_indices("_", "^") - ] - - def get_script_content_spans(self) -> list[Span]: - result = [] - brace_indices_dict = dict(self.brace_index_pairs) - for script_char_span in self.script_char_spans: - span_begin = self.rslide(script_char_span[1], self.space_spans) - if span_begin in brace_indices_dict.keys(): - span_end = brace_indices_dict[span_begin] + 1 - else: - pattern = re.compile(r"[a-zA-Z0-9]|\\[a-zA-Z]+") - match_obj = pattern.match(self.string, pos=span_begin) - if not match_obj: - script_name = { - "_": "subscript", - "^": "superscript" - }[script_char] - raise ValueError( - f"Unclear {script_name} detected while parsing. " - "Please use braces to clarify" - ) - span_end = match_obj.end() - result.append((span_begin, span_end)) - return result - - def get_script_spans(self) -> list[Span]: - return [ - ( - self.lslide(script_char_span[0], self.space_spans), - script_content_span[1] - ) - for script_char_span, script_content_span in zip( - self.script_char_spans, self.script_content_spans - ) - ] - - # Parsing - - def get_command_repl_items(self) -> list[tuple[Span, str]]: - color_related_command_dict = { - "color": (1, False), - "textcolor": (1, False), - "pagecolor": (1, True), - "colorbox": (1, True), - "fcolorbox": (2, True), - } - result = [] - backslash_indices = self.backslash_indices - right_brace_indices = [ - right_index - for left_index, right_index in self.brace_index_pairs - ] - pattern = "".join([ - r"\\", - "(", - "|".join(color_related_command_dict.keys()), - ")", - r"(?![a-zA-Z])" - ]) - for match_obj in re.finditer(pattern, self.string): - span_begin, cmd_end = match_obj.span() - if span_begin not in backslash_indices: - continue - cmd_name = match_obj.group(1) - n_braces, substitute_cmd = color_related_command_dict[cmd_name] - span_end = self.take_nearest_value( - right_brace_indices, cmd_end, n_braces - ) + 1 - if substitute_cmd: - repl_str = "\\" + cmd_name + n_braces * "{white}" - else: - repl_str = "" - result.append(((span_begin, span_end), repl_str)) - return result - - def get_ignored_spans(self) -> list[int]: - return self.script_char_spans.copy() - - def get_internal_specified_spans(self) -> list[Span]: - # Match paired double braces (`{{...}}`). - result = [] - reversed_brace_indices_dict = dict([ - pair[::-1] for pair in self.brace_index_pairs - ]) - skip = False - for prev_right_index, right_index in self.get_neighbouring_pairs( - list(reversed_brace_indices_dict.keys()) - ): - if skip: - skip = False - continue - if right_index != prev_right_index + 1: - continue - left_index = reversed_brace_indices_dict[right_index] - prev_left_index = reversed_brace_indices_dict[prev_right_index] - if left_index != prev_left_index - 1: - continue - result.append((left_index, right_index + 1)) - skip = True - return result - - def get_label_span_list(self) -> list[Span]: - result = self.script_content_spans.copy() - for span_begin, span_end in self.specified_spans: - shrinked_end = self.lslide(span_end, self.script_spans) - if span_begin >= shrinked_end: - continue - shrinked_span = (span_begin, shrinked_end) - if shrinked_span in result: - continue - result.append(shrinked_span) - return result - - def get_inserted_string_pairs( - self, use_plain_file: bool - ) -> list[tuple[Span, tuple[str, str]]]: - if use_plain_file: - return [] - - extended_label_span_list = [ - span - if span in self.script_content_spans - else (span[0], self.rslide(span[1], self.script_spans)) - for span in self.label_span_list - ] - return [ - (span, ( - self.get_begin_color_command_str(label), - self.get_end_color_command_str() - )) - for label, span in enumerate(extended_label_span_list) - ] - - def get_other_repl_items( - self, use_plain_file: bool - ) -> list[tuple[Span, str]]: - if use_plain_file: - return [] - return self.command_repl_items.copy() - - def get_has_predefined_colors(self) -> bool: - return bool(self.command_repl_items) - - # Post-parsing - - def get_cleaned_substr(self, span: Span) -> str: - substr = super().get_cleaned_substr(span) - if not self.brace_index_pairs: - return substr - - # Balance braces. - left_brace_indices, right_brace_indices = zip(*self.brace_index_pairs) - unclosed_left_braces = 0 - unclosed_right_braces = 0 - for index in range(*span): - if index in left_brace_indices: - unclosed_left_braces += 1 - elif index in right_brace_indices: - if unclosed_left_braces == 0: - unclosed_right_braces += 1 - else: - unclosed_left_braces -= 1 - return "".join([ - unclosed_right_braces * "{", - substr, - unclosed_left_braces * "}" - ]) - - # Method alias - - def get_parts_by_tex(self, tex: str) -> VGroup: - return self.get_parts_by_string(tex) - - def get_part_by_tex(self, tex: str) -> VMobject: - return self.get_part_by_string(tex) - - def set_color_by_tex(self, tex: str, color: ManimColor): - return self.set_color_by_string(tex, color) - - def set_color_by_tex_to_color_map( - self, tex_to_color_map: dict[str, ManimColor] - ): - return self.set_color_by_string_to_color_map(tex_to_color_map) - - def get_tex(self) -> str: - return self.get_string() - - -class MTexText(MTex): - CONFIG = { - "tex_environment": None, - } diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py new file mode 100644 index 00000000..a79c8f7c --- /dev/null +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -0,0 +1,338 @@ +from __future__ import annotations + +import re +import colour +from typing import Union, Sequence + +from manimlib.mobject.svg.labelled_string import LabelledString +from manimlib.utils.tex_file_writing import tex_to_svg_file +from manimlib.utils.tex_file_writing import get_tex_config +from manimlib.utils.tex_file_writing import display_during_execution + + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from manimlib.mobject.types.vectorized_mobject import VMobject + from manimlib.mobject.types.vectorized_mobject import VGroup + ManimColor = Union[str, colour.Color, Sequence[float]] + Span = tuple[int, int] + + +SCALE_FACTOR_PER_FONT_POINT = 0.001 + + +class MTex(LabelledString): + CONFIG = { + "font_size": 48, + "alignment": "\\centering", + "tex_environment": "align*", + "tex_to_color_map": {}, + } + + def __init__(self, tex_string: str, **kwargs): + # Prevent from passing an empty string. + if not tex_string: + tex_string = "\\quad" + self.tex_string = tex_string + super().__init__(tex_string, **kwargs) + + self.set_color_by_tex_to_color_map(self.tex_to_color_map) + self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size) + + @property + def hash_seed(self) -> tuple: + return ( + self.__class__.__name__, + self.svg_default, + self.path_string_config, + self.base_color, + self.use_plain_file, + self.isolate, + self.tex_string, + self.alignment, + self.tex_environment, + self.tex_to_color_map + ) + + def get_file_path_by_content(self, content: str) -> str: + full_tex = self.get_tex_file_body(content) + with display_during_execution(f"Writing \"{self.string}\""): + file_path = self.tex_to_svg_file_path(full_tex) + return file_path + + def get_tex_file_body(self, content: str) -> str: + if self.tex_environment: + content = "\n".join([ + f"\\begin{{{self.tex_environment}}}", + content, + f"\\end{{{self.tex_environment}}}" + ]) + if self.alignment: + content = "\n".join([self.alignment, content]) + + tex_config = get_tex_config() + return tex_config["tex_body"].replace( + tex_config["text_to_replace"], + content + ) + + @staticmethod + def tex_to_svg_file_path(tex_file_content: str) -> str: + return tex_to_svg_file(tex_file_content) + + def pre_parse(self) -> None: + super().pre_parse() + self.backslash_indices = self.get_backslash_indices() + self.brace_index_pairs = self.get_brace_index_pairs() + self.script_char_spans = self.get_script_char_spans() + self.script_content_spans = self.get_script_content_spans() + self.script_spans = self.get_script_spans() + + # Toolkits + + @staticmethod + def get_begin_color_command_str(rgb_int: int) -> str: + rgb_tuple = MTex.int_to_rgb(rgb_int) + return "".join([ + "{{", + "\\color[RGB]", + "{", + ",".join(map(str, rgb_tuple)), + "}" + ]) + + @staticmethod + def get_end_color_command_str() -> str: + return "}}" + + # Pre-parsing + + def get_backslash_indices(self) -> list[int]: + # Newlines (`\\`) don't count. + return [ + span[1] - 1 + for span in self.find_spans(r"\\+") + if (span[1] - span[0]) % 2 == 1 + ] + + def get_unescaped_char_spans(self, *chars: str): + return sorted(filter( + lambda span: span[0] - 1 not in self.backslash_indices, + self.find_substr(*chars) + )) + + def get_brace_index_pairs(self) -> list[Span]: + string = self.string + left_brace_indices = [] + right_brace_indices = [] + left_brace_indices_stack = [] + for index, _ in self.get_unescaped_char_spans("{", "}"): + if string[index] == "{": + left_brace_indices_stack.append(index) + else: + if not left_brace_indices_stack: + raise ValueError("Missing '{' inserted") + left_brace_index = left_brace_indices_stack.pop() + left_brace_indices.append(left_brace_index) + right_brace_indices.append(index) + if left_brace_indices_stack: + raise ValueError("Missing '}' inserted") + return list(zip(left_brace_indices, right_brace_indices)) + + def get_script_char_spans(self) -> list[int]: + return self.get_unescaped_char_spans("_", "^") + + def get_script_content_spans(self) -> list[Span]: + result = [] + brace_indices_dict = dict(self.brace_index_pairs) + for script_char_span in self.script_char_spans: + span_begin = self.rslide(script_char_span[1], self.space_spans) + if span_begin in brace_indices_dict.keys(): + span_end = brace_indices_dict[span_begin] + 1 + else: + pattern = re.compile(r"[a-zA-Z0-9]|\\[a-zA-Z]+") + match_obj = pattern.match(self.string, pos=span_begin) + if not match_obj: + script_name = { + "_": "subscript", + "^": "superscript" + }[script_char] + raise ValueError( + f"Unclear {script_name} detected while parsing. " + "Please use braces to clarify" + ) + span_end = match_obj.end() + result.append((span_begin, span_end)) + return result + + def get_script_spans(self) -> list[Span]: + return [ + ( + self.lslide(script_char_span[0], self.space_spans), + script_content_span[1] + ) + for script_char_span, script_content_span in zip( + self.script_char_spans, self.script_content_spans + ) + ] + + # Parsing + + def get_command_repl_items(self) -> list[tuple[Span, str]]: + color_related_command_dict = { + "color": (1, False), + "textcolor": (1, False), + "pagecolor": (1, True), + "colorbox": (1, True), + "fcolorbox": (2, True), + } + result = [] + backslash_indices = self.backslash_indices + right_brace_indices = [ + right_index + for left_index, right_index in self.brace_index_pairs + ] + pattern = "".join([ + r"\\", + "(", + "|".join(color_related_command_dict.keys()), + ")", + r"(?![a-zA-Z])" + ]) + for match_obj in re.finditer(pattern, self.string): + span_begin, cmd_end = match_obj.span() + if span_begin not in backslash_indices: + continue + cmd_name = match_obj.group(1) + n_braces, substitute_cmd = color_related_command_dict[cmd_name] + span_end = self.take_nearest_value( + right_brace_indices, cmd_end, n_braces + ) + 1 + if substitute_cmd: + repl_str = "\\" + cmd_name + n_braces * "{white}" + else: + repl_str = "" + result.append(((span_begin, span_end), repl_str)) + return result + + def get_ignored_spans(self) -> list[int]: + return self.script_char_spans.copy() + + def get_internal_specified_spans(self) -> list[Span]: + # Match paired double braces (`{{...}}`). + result = [] + reversed_brace_indices_dict = dict([ + pair[::-1] for pair in self.brace_index_pairs + ]) + skip = False + for prev_right_index, right_index in self.get_neighbouring_pairs( + list(reversed_brace_indices_dict.keys()) + ): + if skip: + skip = False + continue + if right_index != prev_right_index + 1: + continue + left_index = reversed_brace_indices_dict[right_index] + prev_left_index = reversed_brace_indices_dict[prev_right_index] + if left_index != prev_left_index - 1: + continue + result.append((left_index, right_index + 1)) + skip = True + return result + + def get_external_specified_spans(self) -> list[Span]: + return self.find_substr(*self.tex_to_color_map.keys()) + + def get_label_span_list(self) -> list[Span]: + result = self.script_content_spans.copy() + for span_begin, span_end in self.specified_spans: + shrinked_end = self.lslide(span_end, self.script_spans) + if span_begin >= shrinked_end: + continue + shrinked_span = (span_begin, shrinked_end) + if shrinked_span in result: + continue + result.append(shrinked_span) + return result + + def get_inserted_string_pairs( + self, use_plain_file: bool + ) -> list[tuple[Span, tuple[str, str]]]: + if use_plain_file: + return [] + + extended_label_span_list = [ + span + if span in self.script_content_spans + else (span[0], self.rslide(span[1], self.script_spans)) + for span in self.label_span_list + ] + return [ + (span, ( + self.get_begin_color_command_str(label), + self.get_end_color_command_str() + )) + for label, span in enumerate(extended_label_span_list) + ] + + def get_other_repl_items( + self, use_plain_file: bool + ) -> list[tuple[Span, str]]: + if use_plain_file: + return [] + return self.command_repl_items.copy() + + def get_has_predefined_colors(self) -> bool: + return bool(self.command_repl_items) + + # Post-parsing + + def get_cleaned_substr(self, span: Span) -> str: + substr = super().get_cleaned_substr(span) + if not self.brace_index_pairs: + return substr + + # Balance braces. + left_brace_indices, right_brace_indices = zip(*self.brace_index_pairs) + unclosed_left_braces = 0 + unclosed_right_braces = 0 + for index in range(*span): + if index in left_brace_indices: + unclosed_left_braces += 1 + elif index in right_brace_indices: + if unclosed_left_braces == 0: + unclosed_right_braces += 1 + else: + unclosed_left_braces -= 1 + return "".join([ + unclosed_right_braces * "{", + substr, + unclosed_left_braces * "}" + ]) + + # Method alias + + def get_parts_by_tex(self, tex: str) -> VGroup: + return self.get_parts_by_string(tex) + + def get_part_by_tex(self, tex: str) -> VMobject: + return self.get_part_by_string(tex) + + def set_color_by_tex(self, tex: str, color: ManimColor): + return self.set_color_by_string(tex, color) + + def set_color_by_tex_to_color_map( + self, tex_to_color_map: dict[str, ManimColor] + ): + return self.set_color_by_string_to_color_map(tex_to_color_map) + + def get_tex(self) -> str: + return self.get_string() + + +class MTexText(MTex): + CONFIG = { + "tex_environment": None, + } diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 2b6211d8..56944209 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -304,19 +304,14 @@ class MarkupText(LabelledString): region_indices[flag] += 1 if flag == 0: region_indices[1] += 1 + if not key: + continue for attr_dict in attr_dict_list[slice(*region_indices)]: attr_dict[key] = value return list(zip( MarkupText.get_neighbouring_pairs(index_seq), attr_dict_list[:-1] )) - def find_spans_by_word_or_span( - self, word_or_span: str | Span - ) -> list[Span]: - if isinstance(word_or_span, tuple): - return [word_or_span] - return self.find_spans(re.escape(word_or_span)) - # Pre-parsing def get_global_items_from_config(self) -> list[str, str]: @@ -408,28 +403,28 @@ class MarkupText(LabelledString): def get_local_items_from_config(self) -> list[tuple[Span, str, str]]: result = [ - (text_span, key, val) + (span, key, val) for t2x_dict, key in ( (self.t2c, "foreground"), (self.t2f, "font_family"), (self.t2s, "font_style"), (self.t2w, "font_weight") ) - for word_or_span, val in t2x_dict.items() - for text_span in self.find_spans_by_word_or_span(word_or_span) + for substr, val in t2x_dict.items() + for span in self.find_substr(substr) ] + [ - (text_span, key, val) - for word_or_span, local_config in self.local_configs.items() - for text_span in self.find_spans_by_word_or_span(word_or_span) + (span, key, val) + for substr, local_config in self.local_configs.items() + for span in self.find_substr(substr) for key, val in local_config.items() ] return [ ( - text_span, + span, self.convert_attr_key(key), self.convert_attr_val(val) ) - for text_span, key, val in result + for span, key, val in result ] def get_predefined_items(self) -> list[Span, str, str]: @@ -458,7 +453,7 @@ class MarkupText(LabelledString): (">", ">"), ("<", "<") ) - for span in self.find_spans(re.escape(char)) + for span in self.find_substr(char) ] return result @@ -468,6 +463,12 @@ class MarkupText(LabelledString): for markup_span, _, _ in self.local_items_from_markup ] + def get_external_specified_spans(self) -> list[Span]: + return [ + markup_span + for markup_span, _, _ in self.local_items_from_config + ] + def get_label_span_list(self) -> list[Span]: breakup_indices = remove_list_redundancies(list(it.chain(*it.chain( self.find_spans(r"\s+"), @@ -492,7 +493,7 @@ class MarkupText(LabelledString): def get_inserted_string_pairs( self, use_plain_file: bool ) -> list[tuple[Span, tuple[str, str]]]: - attr_items = self.predefined_items + attr_items = self.predefined_items.copy() if not use_plain_file: attr_items = [ (span, key, WHITE if key in COLOR_RELATED_KEYS else val) @@ -501,6 +502,11 @@ class MarkupText(LabelledString): (span, "foreground", self.rgb_int_to_hex(label)) for label, span in enumerate(self.label_span_list) ] + else: + attr_items += [ + (span, "", "") + for span in self.label_span_list + ] return [ (span, ( self.get_begin_tag_str(attr_dict), From dc816c9f8d1fd57af25627c10e654406a76417f3 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Thu, 31 Mar 2022 18:08:10 +0800 Subject: [PATCH 22/48] Improve algorithm --- manimlib/mobject/svg/text_mobject.py | 160 ++++++++++++--------------- 1 file changed, 69 insertions(+), 91 deletions(-) diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 56944209..78c1c45d 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -137,6 +137,8 @@ class MarkupText(LabelledString): self.full2short(kwargs) digest_config(self, kwargs) + if not self.font: + self.font = get_customization()["style"]["font"] if self.is_markup: validate_error = MarkupUtils.validate(text) if validate_error: @@ -236,19 +238,19 @@ class MarkupText(LabelledString): def pre_parse(self) -> None: super().pre_parse() - self.global_items_from_config = self.get_global_items_from_config() self.tag_items_from_markup = self.get_tag_items_from_markup() - self.local_items_from_markup = self.get_local_items_from_markup() - self.local_items_from_config = self.get_local_items_from_config() - self.predefined_items = self.get_predefined_items() + self.global_dict_from_config = self.get_global_dict_from_config() + self.local_dicts_from_markup = self.get_local_dicts_from_markup() + self.local_dicts_from_config = self.get_local_dicts_from_config() + self.predefined_attr_dicts = self.get_predefined_attr_dicts() # Toolkits @staticmethod def get_attr_dict_str(attr_dict: dict[str, str]) -> str: return " ".join([ - f"{key}='{value}'" - for key, value in attr_dict.items() + f"{key}='{val}'" + for key, val in attr_dict.items() ]) @staticmethod @@ -273,20 +275,12 @@ class MarkupText(LabelledString): return MarkupText.get_end_tag_str() @staticmethod - def convert_attr_key(key: str) -> str: - return SPAN_ATTR_KEY_CONVERSION[key.lower()] - - @staticmethod - def convert_attr_val(val: typing.Any) -> str: - return str(val).lower() - - @staticmethod - def merge_attr_items( - attr_items: list[Span, str, str] + def merge_attr_dicts( + attr_dict_items: list[Span, str, typing.Any] ) -> list[tuple[Span, dict[str, str]]]: index_seq = [0] attr_dict_list = [{}] - for span, key, value in attr_items: + for span, attr_dict in attr_dict_items: if span[0] >= span[1]: continue region_indices = [ @@ -304,51 +298,25 @@ class MarkupText(LabelledString): region_indices[flag] += 1 if flag == 0: region_indices[1] += 1 - if not key: - continue - for attr_dict in attr_dict_list[slice(*region_indices)]: - attr_dict[key] = value + for key, val in attr_dict.items(): + if not key: + continue + for mid_dict in attr_dict_list[slice(*region_indices)]: + mid_dict[key] = val return list(zip( MarkupText.get_neighbouring_pairs(index_seq), attr_dict_list[:-1] )) # Pre-parsing - def get_global_items_from_config(self) -> list[str, str]: - global_attr_dict = { - "line_height": ( - (self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1 - ) * 0.6, - "font_family": self.font or get_customization()["style"]["font"], - "font_size": self.font_size * 1024, - "font_style": self.slant, - "font_weight": self.weight - } - global_attr_dict = { - k: v - for k, v in global_attr_dict.items() - if v is not None - } - result = list(it.chain( - global_attr_dict.items(), - self.global_config.items() - )) - return [ - ( - self.convert_attr_key(key), - self.convert_attr_val(val) - ) - for key, val in result - ] - def get_tag_items_from_markup( self ) -> list[tuple[Span, Span, dict[str, str]]]: if not self.is_markup: return [] - tag_pattern = r"""<(/?)(\w+)\s*((\w+\s*\=\s*('.*?'|".*?")\s*)*)>""" - attr_pattern = r"""(\w+)\s*\=\s*(?:(?:'(.*?)')|(?:"(.*?)"))""" + tag_pattern = r"""<(/?)(\w+)\s*((?:\w+\s*\=\s*(['"]).*?\4\s*)*)>""" + attr_pattern = r"""(\w+)\s*\=\s*(['"])(.*?)\2""" begin_match_obj_stack = [] match_obj_pairs = [] for match_obj in re.finditer(tag_pattern, self.string): @@ -370,7 +338,7 @@ class MarkupText(LabelledString): raise ValueError("Attributes shan't exist in ending tags") if tag_name == "span": attr_dict = { - match.group(1): match.group(2) or match.group(3) + match.group(1): match.group(3) for match in re.finditer( attr_pattern, begin_match_obj.group(3) ) @@ -389,21 +357,33 @@ class MarkupText(LabelledString): ) return result - def get_local_items_from_markup(self) -> list[tuple[Span, str, str]]: + def get_global_dict_from_config(self) -> dict[str, typing.Any]: + result = { + "line_height": ( + (self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1 + ) * 0.6, + "font_family": self.font, + "font_size": self.font_size * 1024, + "font_style": self.slant, + "font_weight": self.weight + } + result.update(self.global_config) + return result + + def get_local_dicts_from_markup( + self + ) -> list[Span, dict[str, str]]: return sorted([ - ( - (begin_tag_span[0], end_tag_span[1]), - self.convert_attr_key(key), - self.convert_attr_val(val) - ) + ((begin_tag_span[0], end_tag_span[1]), attr_dict) for begin_tag_span, end_tag_span, attr_dict in self.tag_items_from_markup - for key, val in attr_dict.items() ]) - def get_local_items_from_config(self) -> list[tuple[Span, str, str]]: - result = [ - (span, key, val) + def get_local_dicts_from_config( + self + ) -> list[Span, dict[str, typing.Any]]: + return [ + (span, {key: val}) for t2x_dict, key in ( (self.t2c, "foreground"), (self.t2f, "font_family"), @@ -413,29 +393,24 @@ class MarkupText(LabelledString): for substr, val in t2x_dict.items() for span in self.find_substr(substr) ] + [ - (span, key, val) + (span, local_config) for substr, local_config in self.local_configs.items() for span in self.find_substr(substr) - for key, val in local_config.items() - ] - return [ - ( - span, - self.convert_attr_key(key), - self.convert_attr_val(val) - ) - for span, key, val in result ] - def get_predefined_items(self) -> list[Span, str, str]: - return list(it.chain( - [ - (self.full_span, key, val) - for key, val in self.global_items_from_config - ], - self.local_items_from_markup, - self.local_items_from_config - )) + def get_predefined_attr_dicts(self) -> list[Span, dict[str, str]]: + attr_dict_items = [ + (self.full_span, self.global_dict_from_config), + *self.local_dicts_from_markup, + *self.local_dicts_from_config + ] + return [ + (span, { + SPAN_ATTR_KEY_CONVERSION[key.lower()]: str(val) + for key, val in attr_dict.items() + }) + for span, attr_dict in attr_dict_items + ] # Parsing @@ -460,13 +435,13 @@ class MarkupText(LabelledString): def get_internal_specified_spans(self) -> list[Span]: return [ markup_span - for markup_span, _, _ in self.local_items_from_markup + for markup_span, _ in self.local_dicts_from_markup ] def get_external_specified_spans(self) -> list[Span]: return [ markup_span - for markup_span, _, _ in self.local_items_from_config + for markup_span, _, _ in self.local_dicts_from_config ] def get_label_span_list(self) -> list[Span]: @@ -493,18 +468,20 @@ class MarkupText(LabelledString): def get_inserted_string_pairs( self, use_plain_file: bool ) -> list[tuple[Span, tuple[str, str]]]: - attr_items = self.predefined_items.copy() if not use_plain_file: - attr_items = [ - (span, key, WHITE if key in COLOR_RELATED_KEYS else val) - for span, key, val in attr_items + attr_dict_items = [ + (span, { + key: WHITE if key in COLOR_RELATED_KEYS else val + for key, val in attr_dict.items() + }) + for span, attr_dict in self.predefined_attr_dicts ] + [ - (span, "foreground", self.rgb_int_to_hex(label)) + (span, {"foreground": self.rgb_int_to_hex(label)}) for label, span in enumerate(self.label_span_list) ] else: - attr_items += [ - (span, "", "") + attr_dict_items = self.predefined_attr_dicts + [ + (span, {}) for span in self.label_span_list ] return [ @@ -512,7 +489,7 @@ class MarkupText(LabelledString): self.get_begin_tag_str(attr_dict), self.get_end_tag_str() )) - for span, attr_dict in self.merge_attr_items(attr_items) + for span, attr_dict in self.merge_attr_dicts(attr_dict_items) ] def get_other_repl_items( @@ -523,7 +500,8 @@ class MarkupText(LabelledString): def get_has_predefined_colors(self) -> bool: return any([ key in COLOR_RELATED_KEYS - for _, key, _ in self.predefined_items + for _, attr_dict in self.predefined_attr_dicts + for key in attr_dict.keys() ]) # Method alias From 84c56b36241cf60d096f25f8aa611b0489dd2d23 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Thu, 31 Mar 2022 18:11:37 +0800 Subject: [PATCH 23/48] Fix typo --- manimlib/mobject/svg/text_mobject.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 78c1c45d..046b9151 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -433,16 +433,10 @@ class MarkupText(LabelledString): return result def get_internal_specified_spans(self) -> list[Span]: - return [ - markup_span - for markup_span, _ in self.local_dicts_from_markup - ] + return [span for span, _ in self.local_dicts_from_markup] def get_external_specified_spans(self) -> list[Span]: - return [ - markup_span - for markup_span, _, _ in self.local_dicts_from_config - ] + return [span for span, _ in self.local_dicts_from_config] def get_label_span_list(self) -> list[Span]: breakup_indices = remove_list_redundancies(list(it.chain(*it.chain( From 39673a80d7bbbea258c35ce5a1d37a0911aae4f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=B9=A4=E7=BF=94=E4=B8=87=E9=87=8C?= Date: Sat, 2 Apr 2022 22:00:02 +0800 Subject: [PATCH 24/48] fix: add missing import of annotations --- manimlib/utils/init_config.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/manimlib/utils/init_config.py b/manimlib/utils/init_config.py index cb0a1787..36ae9d4b 100644 --- a/manimlib/utils/init_config.py +++ b/manimlib/utils/init_config.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import yaml import inspect From 3c3264d7d6a66400cb396b364a19016962006a92 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Sat, 2 Apr 2022 22:42:19 +0800 Subject: [PATCH 25/48] Support passing in spans directly --- manimlib/mobject/svg/text_mobject.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 046b9151..818125cb 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -307,6 +307,19 @@ class MarkupText(LabelledString): MarkupText.get_neighbouring_pairs(index_seq), attr_dict_list[:-1] )) + def find_substr_or_span( + self, substr_or_span: str | tuple[int | None, int | None] + ) -> list[Span]: + if isinstance(substr_or_span, str): + return self.find_substr(substr) + + span_begin, span_end = substr_or_span + if span_begin is None: + span_begin = 0 + if span_end is None: + span_end = len(self.string) + return [(span_begin, span_end)] + # Pre-parsing def get_tag_items_from_markup( @@ -390,12 +403,12 @@ class MarkupText(LabelledString): (self.t2s, "font_style"), (self.t2w, "font_weight") ) - for substr, val in t2x_dict.items() - for span in self.find_substr(substr) + for substr_or_span, val in t2x_dict.items() + for span in self.find_substr_or_span(substr_or_span) ] + [ (span, local_config) - for substr, local_config in self.local_configs.items() - for span in self.find_substr(substr) + for substr_or_span, local_config in self.local_configs.items() + for span in self.find_substr_or_span(substr_or_span) ] def get_predefined_attr_dicts(self) -> list[Span, dict[str, str]]: From 974d9d5ab07feec9557bad82c526ca8c30ad0d00 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Mon, 4 Apr 2022 14:53:40 +0800 Subject: [PATCH 26/48] Avoid empty spans --- manimlib/mobject/svg/text_mobject.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 818125cb..166280ab 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -313,11 +313,18 @@ class MarkupText(LabelledString): if isinstance(substr_or_span, str): return self.find_substr(substr) + string_len = len(self.string) span_begin, span_end = substr_or_span if span_begin is None: span_begin = 0 + elif span_begin < 0: + span_begin += string_len if span_end is None: - span_end = len(self.string) + span_end = string_len + elif span_end < 0: + span_end += string_len + if span_begin >= span_end: + return [] return [(span_begin, span_end)] # Pre-parsing From 7f616987a3c7bafb8c2aabff01abdf74ef14bb96 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Tue, 5 Apr 2022 14:01:07 +0800 Subject: [PATCH 27/48] Fix typo --- manimlib/mobject/svg/text_mobject.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 166280ab..faa1c0a6 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -311,7 +311,7 @@ class MarkupText(LabelledString): self, substr_or_span: str | tuple[int | None, int | None] ) -> list[Span]: if isinstance(substr_or_span, str): - return self.find_substr(substr) + return self.find_substr(substr_or_span) string_len = len(self.string) span_begin, span_end = substr_or_span From b7647912582622cbea2d908247a997b31cbc6a9e Mon Sep 17 00:00:00 2001 From: YishiMichael <50232075+YishiMichael@users.noreply.github.com> Date: Tue, 5 Apr 2022 14:04:26 +0800 Subject: [PATCH 28/48] Fix typo (#1777) --- manimlib/mobject/svg/text_mobject.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 166280ab..faa1c0a6 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -311,7 +311,7 @@ class MarkupText(LabelledString): self, substr_or_span: str | tuple[int | None, int | None] ) -> list[Span]: if isinstance(substr_or_span, str): - return self.find_substr(substr) + return self.find_substr(substr_or_span) string_len = len(self.string) span_begin, span_end = substr_or_span From 55a91a2354ddb6e40404d7f3e0f28513cffeb680 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Tue, 5 Apr 2022 22:16:26 +0800 Subject: [PATCH 29/48] Remove unnecessary raise statement --- manimlib/mobject/svg/labelled_string.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index 8f89ab2b..0813f3bc 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -146,8 +146,6 @@ class LabelledString(_StringSVG): def take_nearest_value(seq: list[int], val: int, index_shift: int) -> int: sorted_seq = sorted(seq) index = LabelledString.find_region_index(sorted_seq, val) - if index == -1: - raise IndexError return sorted_seq[index + index_shift] @staticmethod From f9d8a76767a3cfae79ce96924dc51831e6f59c55 Mon Sep 17 00:00:00 2001 From: YishiMichael <50232075+YishiMichael@users.noreply.github.com> Date: Tue, 5 Apr 2022 22:22:59 +0800 Subject: [PATCH 30/48] Remove unnecessary raise statement (#1778) * Fix typo * Remove unnecessary raise statement --- manimlib/mobject/svg/labelled_string.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index 8f89ab2b..0813f3bc 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -146,8 +146,6 @@ class LabelledString(_StringSVG): def take_nearest_value(seq: list[int], val: int, index_shift: int) -> int: sorted_seq = sorted(seq) index = LabelledString.find_region_index(sorted_seq, val) - if index == -1: - raise IndexError return sorted_seq[index + index_shift] @staticmethod From e4ccbdfba92ca8c6ec57fb517545ed8af527ddaf Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Wed, 6 Apr 2022 11:14:45 +0800 Subject: [PATCH 31/48] docs: update changelog since v1.5.0 --- docs/source/development/changelog.rst | 54 ++++++++++++++++++++++++--- 1 file changed, 49 insertions(+), 5 deletions(-) diff --git a/docs/source/development/changelog.rst b/docs/source/development/changelog.rst index 6e865dc8..aa97e872 100644 --- a/docs/source/development/changelog.rst +++ b/docs/source/development/changelog.rst @@ -1,6 +1,50 @@ Changelog ========= +Unreleased +---------- + +Breaking changes +^^^^^^^^^^^^^^^^ +- **Python 3.6 is no longer supported** (`#1736 `__) + +Fixed bugs +^^^^^^^^^^ +- Fixed the width of riemann rectangles (`#1762 `__) +- Bug fixed in cases where empty array is passed to shader (`#1764 `__) +- Fixed ``AddTextWordByWord`` (`#1772 `__) + + +New features +^^^^^^^^^^^^ +- Added more functions to ``Text`` (details: `#1751 `__) +- Allowed ``interpolate`` to work on an array of alpha values (`#1764 `__) +- Allowed ``Numberline.number_to_point`` and ``CoordinateSystem.coords_to_point`` to work on an array of inputs (`#1764 `__) +- Added a basic ``Prismify`` to turn a flat ``VMobject`` into something with depth (`#1764 `__) +- Added ``GlowDots``, analogous to ``GlowDot`` (`#1764 `__) +- Added ``TransformMatchingStrings`` which is compatible with ``Text`` and ``MTex`` (`#1772 `__) + + +Refactor +^^^^^^^^ +- Added type hints (`#1736 `__) +- Specifid UTF-8 encoding for tex files (`#1748 `__) +- Refactored ``Text`` with the latest manimpango (`#1751 `__) +- Reorganized getters for ``ParametricCurve`` (`#1757 `__) +- Refactored ``CameraFrame`` to use ``scipy.spatial.transform.Rotation `` (`#1764 `__) +- Refactored rotation methods to use ``scipy.spatial.transform.Rotation`` (`#1764 `__) +- Used ``stroke_color`` to init ``Arrow`` (`#1764 `__) +- Refactored ``Mobject.set_rgba_array_by_color`` (`#1764 `__) +- Made panning more sensitive to mouse movements (`#1764 `__) +- Added loading progress for large SVGs (`#1766 `__) +- Added getter/setter of ``field_of_view`` for ``CameraFrame`` (`#1770 `__) +- Renamed ``focal_distance`` to ``focal_dist_to_height`` and added getter/setter (`#1770 `__) +- Added getter and setter for ``VMobject.joint_type`` (`#1770 `__) +- Refactored ``VCube`` (`#1770 `__) +- Refactored ``Prism`` to receive ``width height depth`` instead of ``dimensions`` (`#1770 `__) +- Refactored ``Text``, ``MarkupText`` and ``MTex`` based on ``LabelledString`` (`#1772 `__) + + v1.5.0 ------ @@ -9,7 +53,7 @@ Fixed bugs - Bug fix for the case of calling ``Write`` on a null object (`#1740 `__) -New Features +New features ^^^^^^^^^^^^ - Added ``TransformMatchingMTex`` (`#1725 `__) - Added ``ImplicitFunction`` (`#1727 `__) @@ -60,7 +104,7 @@ Fixed bugs - Fixed some bugs of SVG path string parser (`#1717 `__) - Fixed some bugs of ``MTex`` (`#1720 `__) -New Features +New features ^^^^^^^^^^^^ - Added option to add ticks on x-axis in ``BarChart`` (`#1694 `__) - Added ``lable_buff`` config parameter for ``Brace`` (`#1704 `__) @@ -99,7 +143,7 @@ Fixed bugs - Fixed bug in ``ShowSubmobjectsOneByOne`` (`bcd0990 `__) - Fixed bug in ``TransformMatchingParts`` (`7023548 `__) -New Features +New features ^^^^^^^^^^^^ - Added CLI flag ``--log-level`` to specify log level (`e10f850 `__) @@ -167,7 +211,7 @@ Fixed bugs - Fixed bug with ``CoordinateSystem.get_lines_parallel_to_axis`` (`c726eb7 `__) - Fixed ``ComplexPlane`` -i display bug (`7732d2f `__) -New Features +New features ^^^^^^^^^^^^ - Supported the elliptical arc command ``A`` for ``SVGMobject`` (`#1598 `__) @@ -230,7 +274,7 @@ Fixed bugs - Rewrote ``earclip_triangulation`` to fix triangulation - Allowed sound_file_name to be taken in without extensions -New Features +New features ^^^^^^^^^^^^ - Added :class:`~manimlib.animation.indication.VShowPassingFlash` From 93f8d3f1caa0c57412ae588fcb689e94b32613b6 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Wed, 6 Apr 2022 22:38:33 +0800 Subject: [PATCH 32/48] Some refactors on LabelledString --- manimlib/mobject/numbers.py | 7 +- manimlib/mobject/svg/labelled_string.py | 280 ++++++++++++------------ manimlib/mobject/svg/mtex_mobject.py | 30 +-- manimlib/mobject/svg/svg_mobject.py | 1 + manimlib/mobject/svg/text_mobject.py | 29 ++- 5 files changed, 175 insertions(+), 172 deletions(-) diff --git a/manimlib/mobject/numbers.py b/manimlib/mobject/numbers.py index bd003fb6..beac837c 100644 --- a/manimlib/mobject/numbers.py +++ b/manimlib/mobject/numbers.py @@ -6,12 +6,9 @@ from manimlib.constants import * from manimlib.mobject.svg.tex_mobject import SingleStringTex from manimlib.mobject.svg.text_mobject import Text from manimlib.mobject.types.vectorized_mobject import VMobject -from manimlib.utils.iterables import hash_obj T = TypeVar("T", bound=VMobject) -string_to_mob_map: dict[str, VMobject] = {} - class DecimalNumber(VMobject): CONFIG = { @@ -92,9 +89,7 @@ class DecimalNumber(VMobject): return self.data["font_size"][0] def string_to_mob(self, string: str, mob_class: Type[T] = Text, **kwargs) -> T: - if (string, hash_obj(kwargs)) not in string_to_mob_map: - string_to_mob_map[(string, hash_obj(kwargs))] = mob_class(string, font_size=1, **kwargs) - mob = string_to_mob_map[(string, hash_obj(kwargs))].copy() + mob = mob_class(string, font_size=1, **kwargs) mob.scale(self.get_font_size()) return mob diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index 0813f3bc..85a6ccf9 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -11,7 +11,6 @@ from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.utils.color import color_to_int_rgb from manimlib.utils.config_ops import digest_config -from manimlib.utils.iterables import adjacent_pairs from manimlib.utils.iterables import remove_list_redundancies @@ -40,19 +39,29 @@ class LabelledString(_StringSVG): An abstract base class for `MTex` and `MarkupText` """ CONFIG = { - "base_color": WHITE, + "base_color": None, "use_plain_file": False, "isolate": [], } def __init__(self, string: str, **kwargs): self.string = string + reserved_svg_default = kwargs.pop("svg_default", {}) digest_config(self, kwargs) + self.reserved_svg_default = reserved_svg_default + self.base_color = self.base_color \ + or reserved_svg_default.get("color", None) \ + or reserved_svg_default.get("fill_color", None) \ + or WHITE + self.pre_parse() self.parse() super().__init__(**kwargs) - def get_file_path(self, use_plain_file: bool = False) -> str: + def get_file_path(self) -> str: + return self.get_file_path_(use_plain_file=False) + + def get_file_path_(self, use_plain_file: bool) -> str: content = self.get_decorated_string(use_plain_file=use_plain_file) return self.get_file_path_by_content(content) @@ -67,9 +76,17 @@ class LabelledString(_StringSVG): self.color_to_label(submob.get_fill_color()) for submob in self.submobjects ] - if self.use_plain_file or self.has_predefined_colors: - file_path = self.get_file_path(use_plain_file=True) - plain_svg = _StringSVG(file_path) + if any([ + self.use_plain_file, + self.reserved_svg_default, + self.has_predefined_colors + ]): + file_path = self.get_file_path_(use_plain_file=True) + plain_svg = _StringSVG( + file_path, + svg_default=self.reserved_svg_default, + path_string_config=self.path_string_config + ) self.set_submobjects(plain_svg.submobjects) else: self.set_fill(self.base_color) @@ -77,8 +94,8 @@ class LabelledString(_StringSVG): submob.label = label def pre_parse(self) -> None: - self.full_span = self.get_full_span() - self.space_spans = self.get_space_spans() + self.string_len = len(self.string) + self.full_span = (0, self.string_len) def parse(self) -> None: self.command_repl_items = self.get_command_repl_items() @@ -89,31 +106,67 @@ class LabelledString(_StringSVG): self.external_specified_spans = self.get_external_specified_spans() self.specified_spans = self.get_specified_spans() self.label_span_list = self.get_label_span_list() - self.containing_labels_dict = self.get_containing_labels_dict() - self.has_predefined_colors = self.get_has_predefined_colors() + self.check_overlapping() # Toolkits - def find_spans(self, pattern: str) -> list[Span]: - return [ - match_obj.span() - for match_obj in re.finditer(pattern, self.string) - ] + def get_substr(self, span: Span) -> str: + return self.string[slice(*span)] - def find_substr(self, *substrs: str) -> list[Span]: + def handle_regex_method(func): + def wrapper(self, pattern, pos=0, endpos=9223372036854775807): + return func()( + re.compile(pattern), self.string, pos=pos, endpos=endpos + ) + return wrapper + + @handle_regex_method + def finditer(): + return re.Pattern.finditer + + @handle_regex_method + def search(): + return re.Pattern.search + + @handle_regex_method + def match(): + return re.Pattern.match + + def find_spans(self, pattern: str) -> list[Span]: + return [match_obj.span() for match_obj in self.finditer(pattern)] + + def find_substr(self, substr: str) -> list[Span]: + if not substr: + return [] + return self.find_spans(re.escape(substr)) + + def find_substrs(self, substrs: list[str]) -> list[Span]: return list(it.chain(*[ - self.find_spans(re.escape(substr)) if substr else [] + self.find_substr(substr) for substr in remove_list_redundancies(substrs) ])) @staticmethod - def get_neighbouring_pairs(iterable: Iterable) -> list: - return list(adjacent_pairs(iterable))[:-1] + def get_neighbouring_pairs(iterable: list) -> list[tuple]: + return list(zip(iterable[:-1], iterable[1:])) @staticmethod def span_contains(span_0: Span, span_1: Span) -> bool: return span_0[0] <= span_1[0] and span_0[1] >= span_1[1] + @staticmethod + def get_complement_spans( + interval_spans: list[Span], universal_span: Span + ) -> list[Span]: + if not interval_spans: + return [universal_span] + + span_ends, span_begins = zip(*interval_spans) + return list(zip( + (universal_span[0], *span_begins), + (*span_ends, universal_span[1]) + )) + @staticmethod def compress_neighbours(vals: list[int]) -> list[tuple[int, Span]]: if not vals: @@ -148,31 +201,6 @@ class LabelledString(_StringSVG): index = LabelledString.find_region_index(sorted_seq, val) return sorted_seq[index + index_shift] - @staticmethod - def replace_str_by_spans( - substr: str, span_repl_dict: dict[Span, str] - ) -> str: - if not span_repl_dict: - return substr - - spans = sorted(span_repl_dict.keys()) - if not all( - span_0[1] <= span_1[0] - for span_0, span_1 in LabelledString.get_neighbouring_pairs(spans) - ): - raise ValueError("Overlapping replacement") - - span_ends, span_begins = zip(*spans) - pieces = [ - substr[slice(*span)] - for span in zip( - (0, *span_begins), - (*span_ends, len(substr)) - ) - ] - repl_strs = [*[span_repl_dict[span] for span in spans], ""] - return "".join(it.chain(*zip(pieces, repl_strs))) - @staticmethod def get_span_replacement_dict( inserted_string_pairs: list[tuple[Span, tuple[str, str]]], @@ -199,6 +227,27 @@ class LabelledString(_StringSVG): }) return result + def get_replaced_substr( + self, span: Span, span_repl_dict: dict[Span, str] + ) -> str: + repl_spans = sorted(filter( + lambda repl_span: self.span_contains(span, repl_span), + span_repl_dict.keys() + )) + if not all( + span_0[1] <= span_1[0] + for span_0, span_1 in self.get_neighbouring_pairs(repl_spans) + ): + raise ValueError("Overlapping replacement") + + pieces = [ + self.get_substr(piece_span) + for piece_span in self.get_complement_spans(repl_spans, span) + ] + repl_strs = [span_repl_dict[repl_span] for repl_span in repl_spans] + repl_strs.append("") + return "".join(it.chain(*zip(pieces, repl_strs))) + @staticmethod def rslide(index: int, skipped: list[Span]) -> int: transfer_dict = dict(sorted(skipped)) @@ -215,13 +264,6 @@ class LabelledString(_StringSVG): index = transfer_dict[index] return index - @staticmethod - def shrink_span(span: Span, skipped: list[Span]) -> Span: - return ( - LabelledString.rslide(span[0], skipped), - LabelledString.lslide(span[1], skipped) - ) - @staticmethod def rgb_to_int(rgb_tuple: tuple[int, int, int]) -> int: r, g, b = rgb_tuple @@ -250,14 +292,6 @@ class LabelledString(_StringSVG): def get_end_color_command_str() -> str: return "" - # Pre-parsing - - def get_full_span(self) -> Span: - return (0, len(self.string)) - - def get_space_spans(self) -> list[Span]: - return self.find_spans(r"\s") - # Parsing @abstractmethod @@ -272,11 +306,17 @@ class LabelledString(_StringSVG): def get_skipped_spans(self) -> list[Span]: return list(it.chain( - self.space_spans, + self.find_spans(r"\s"), self.command_spans, self.ignored_spans )) + def shrink_span(self, span: Span) -> Span: + return ( + self.rslide(span[0], self.skipped_spans), + self.lslide(span[1], self.skipped_spans) + ) + @abstractmethod def get_internal_specified_spans(self) -> list[Span]: return [] @@ -290,14 +330,11 @@ class LabelledString(_StringSVG): self.full_span, *self.internal_specified_spans, *self.external_specified_spans, - *self.find_substr(*self.isolate) + *self.find_substrs(self.isolate) ] shrinked_spans = list(filter( lambda span: span[0] < span[1], - [ - self.shrink_span(span, self.skipped_spans) - for span in spans - ] + [self.shrink_span(span) for span in spans] )) return remove_list_redundancies(shrinked_spans) @@ -305,28 +342,14 @@ class LabelledString(_StringSVG): def get_label_span_list(self) -> list[Span]: return [] - def get_containing_labels_dict(self) -> dict[Span, list[int]]: - label_span_list = self.label_span_list - result = { - span: [] - for span in label_span_list - } - for span_0 in label_span_list: - for span_index, span_1 in enumerate(label_span_list): - if self.span_contains(span_0, span_1): - result[span_0].append(span_index) - elif span_0[0] < span_1[0] < span_0[1] < span_1[1]: - string_0, string_1 = [ - self.string[slice(*span)] - for span in [span_0, span_1] - ] - raise ValueError( - "Partially overlapping substrings detected: " - f"'{string_0}' and '{string_1}'" - ) - if self.full_span not in result: - result[self.full_span] = list(range(len(label_span_list))) - return result + def check_overlapping(self) -> None: + for span_0, span_1 in it.product(self.label_span_list, repeat=2): + if not span_0[0] < span_1[0] < span_0[1] < span_1[1]: + continue + raise ValueError( + "Partially overlapping substrings detected: " + f"'{self.get_substr(span_0)}' and '{self.get_substr(span_1)}'" + ) @abstractmethod def get_inserted_string_pairs( @@ -345,7 +368,7 @@ class LabelledString(_StringSVG): self.get_inserted_string_pairs(use_plain_file), self.get_other_repl_items(use_plain_file) ) - result = self.replace_str_by_spans(self.string, span_repl_dict) + result = self.get_replaced_substr(self.full_span, span_repl_dict) if not use_plain_file: return result @@ -358,20 +381,14 @@ class LabelledString(_StringSVG): ]) @abstractmethod - def get_has_predefined_colors(self) -> bool: + def has_predefined_colors(self) -> bool: return False # Post-parsing def get_cleaned_substr(self, span: Span) -> str: - span_repl_dict = { - tuple([index - span[0] for index in cmd_span]): "" - for cmd_span in self.command_spans - if self.span_contains(span, cmd_span) - } - return self.replace_str_by_spans( - self.string[slice(*span)], span_repl_dict - ) + span_repl_dict = dict.fromkeys(self.command_spans, "") + return self.get_replaced_substr(span, span_repl_dict) def get_specified_substrs(self) -> list[str]: return remove_list_redundancies([ @@ -387,42 +404,36 @@ class LabelledString(_StringSVG): def get_group_substrs(self) -> list[str]: group_labels, _ = self.get_group_span_items() + if not group_labels: + return [] + ordered_spans = [ self.label_span_list[label] if label != -1 else self.full_span for label in group_labels ] - ordered_containing_labels = [ - self.containing_labels_dict[span] - for span in ordered_spans - ] - ordered_span_begins, ordered_span_ends = zip(*ordered_spans) - span_begins = [ - prev_end if prev_label in containing_labels else curr_begin - for prev_end, prev_label, containing_labels, curr_begin in zip( - ordered_span_ends[:-1], group_labels[:-1], - ordered_containing_labels[1:], ordered_span_begins[1:] + interval_spans = [ + ( + next_span[0] + if self.span_contains(prev_span, next_span) + else prev_span[1], + prev_span[1] + if self.span_contains(next_span, prev_span) + else next_span[0] + ) + for prev_span, next_span in self.get_neighbouring_pairs( + ordered_spans ) ] - span_ends = [ - next_begin if next_label in containing_labels else curr_end - for next_begin, next_label, containing_labels, curr_end in zip( - ordered_span_begins[1:], group_labels[1:], - ordered_containing_labels[:-1], ordered_span_ends[:-1] - ) - ] - spans = list(zip( - (ordered_span_begins[0], *span_begins), - (*span_ends, ordered_span_ends[-1]) - )) shrinked_spans = [ - self.shrink_span(span, self.skipped_spans) - for span in spans + self.shrink_span(span) + for span in self.get_complement_spans( + interval_spans, (ordered_spans[0][0], ordered_spans[-1][1]) + ) ] - group_substrs = [ + return [ self.get_cleaned_substr(span) if span[0] < span[1] else "" for span in shrinked_spans ] - return group_substrs def get_submob_groups(self) -> VGroup: _, submob_spans = self.get_group_span_items() @@ -433,18 +444,17 @@ class LabelledString(_StringSVG): # Selector - def find_span_components_of_custom_span( - self, custom_span: Span - ) -> list[Span]: - if custom_span[0] >= custom_span[1]: + def find_span_components(self, custom_span: Span) -> list[Span]: + shrinked_span = self.shrink_span(custom_span) + if shrinked_span[0] >= shrinked_span[1]: return [] indices = remove_list_redundancies(list(it.chain( self.full_span, *self.label_span_list ))) - span_begin = self.take_nearest_value(indices, custom_span[0], 0) - span_end = self.take_nearest_value(indices, custom_span[1] - 1, 1) + span_begin = self.take_nearest_value(indices, shrinked_span[0], 0) + span_end = self.take_nearest_value(indices, shrinked_span[1] - 1, 1) span_choices = sorted(filter( lambda span: self.span_contains((span_begin, span_end), span), self.label_span_list @@ -463,19 +473,19 @@ class LabelledString(_StringSVG): return result def get_parts_by_custom_span(self, custom_span: Span) -> VGroup: - spans = self.find_span_components_of_custom_span(custom_span) - labels = set(it.chain(*[ - self.containing_labels_dict[span] - for span in spans - ])) + labels = [ + label for label, span in enumerate(self.label_span_list) + if any([ + self.span_contains(span_component, span) + for span_component in self.find_span_components(custom_span) + ]) + ] return VGroup(*filter( lambda submob: submob.label in labels, self.submobjects )) def get_parts_by_string(self, substr: str) -> VGroup: - if not substr: - return VGroup() return VGroup(*[ self.get_parts_by_custom_span(span) for span in self.find_substr(substr) diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index a79c8f7c..23209fda 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -1,6 +1,5 @@ from __future__ import annotations -import re import colour from typing import Union, Sequence @@ -57,7 +56,7 @@ class MTex(LabelledString): def get_file_path_by_content(self, content: str) -> str: full_tex = self.get_tex_file_body(content) - with display_during_execution(f"Writing \"{self.string}\""): + with display_during_execution(f"Writing \"{self.tex_string}\""): file_path = self.tex_to_svg_file_path(full_tex) return file_path @@ -116,19 +115,19 @@ class MTex(LabelledString): if (span[1] - span[0]) % 2 == 1 ] - def get_unescaped_char_spans(self, *chars: str): + def get_unescaped_char_spans(self, chars: str): return sorted(filter( lambda span: span[0] - 1 not in self.backslash_indices, - self.find_substr(*chars) + self.find_substrs(list(chars)) )) def get_brace_index_pairs(self) -> list[Span]: - string = self.string left_brace_indices = [] right_brace_indices = [] left_brace_indices_stack = [] - for index, _ in self.get_unescaped_char_spans("{", "}"): - if string[index] == "{": + for span in self.get_unescaped_char_spans("{}"): + index = span[0] + if self.get_substr(span) == "{": left_brace_indices_stack.append(index) else: if not left_brace_indices_stack: @@ -141,18 +140,18 @@ class MTex(LabelledString): return list(zip(left_brace_indices, right_brace_indices)) def get_script_char_spans(self) -> list[int]: - return self.get_unescaped_char_spans("_", "^") + return self.get_unescaped_char_spans("_^") def get_script_content_spans(self) -> list[Span]: result = [] brace_indices_dict = dict(self.brace_index_pairs) + script_pattern = r"[a-zA-Z0-9]|\\[a-zA-Z]+" for script_char_span in self.script_char_spans: - span_begin = self.rslide(script_char_span[1], self.space_spans) + span_begin = self.match(r"\s*", pos=script_char_span[1]).end() if span_begin in brace_indices_dict.keys(): span_end = brace_indices_dict[span_begin] + 1 else: - pattern = re.compile(r"[a-zA-Z0-9]|\\[a-zA-Z]+") - match_obj = pattern.match(self.string, pos=span_begin) + match_obj = self.match(script_pattern, pos=span_begin) if not match_obj: script_name = { "_": "subscript", @@ -169,7 +168,7 @@ class MTex(LabelledString): def get_script_spans(self) -> list[Span]: return [ ( - self.lslide(script_char_span[0], self.space_spans), + self.search(r"\s*$", endpos=script_char_span[0]).start(), script_content_span[1] ) for script_char_span, script_content_span in zip( @@ -200,7 +199,7 @@ class MTex(LabelledString): ")", r"(?![a-zA-Z])" ]) - for match_obj in re.finditer(pattern, self.string): + for match_obj in self.finditer(pattern): span_begin, cmd_end = match_obj.span() if span_begin not in backslash_indices: continue @@ -243,7 +242,7 @@ class MTex(LabelledString): return result def get_external_specified_spans(self) -> list[Span]: - return self.find_substr(*self.tex_to_color_map.keys()) + return self.find_substrs(list(self.tex_to_color_map.keys())) def get_label_span_list(self) -> list[Span]: result = self.script_content_spans.copy() @@ -284,7 +283,8 @@ class MTex(LabelledString): return [] return self.command_repl_items.copy() - def get_has_predefined_colors(self) -> bool: + @property + def has_predefined_colors(self) -> bool: return bool(self.command_repl_items) # Post-parsing diff --git a/manimlib/mobject/svg/svg_mobject.py b/manimlib/mobject/svg/svg_mobject.py index bc625c83..b44c107f 100644 --- a/manimlib/mobject/svg/svg_mobject.py +++ b/manimlib/mobject/svg/svg_mobject.py @@ -58,6 +58,7 @@ class SVGMobject(VMobject): }, "path_string_config": {}, } + def __init__(self, file_name: str | None = None, **kwargs): super().__init__(**kwargs) self.file_name = file_name or self.file_name diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index faa1c0a6..b8d2f259 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -17,7 +17,6 @@ from manimpango import MarkupUtils from manimlib.logger import log from manimlib.constants import * from manimlib.mobject.svg.labelled_string import LabelledString -from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.utils.customization import get_customization from manimlib.utils.tex_file_writing import tex_hash from manimlib.utils.config_ops import digest_config @@ -30,9 +29,11 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from manimlib.mobject.types.vectorized_mobject import VMobject + from manimlib.mobject.types.vectorized_mobject import VGroup ManimColor = Union[str, colour.Color, Sequence[float]] Span = tuple[int, int] + TEXT_MOB_SCALE_FACTOR = 0.0076 DEFAULT_LINE_SPACING_SCALE = 0.6 @@ -313,19 +314,14 @@ class MarkupText(LabelledString): if isinstance(substr_or_span, str): return self.find_substr(substr_or_span) - string_len = len(self.string) - span_begin, span_end = substr_or_span - if span_begin is None: - span_begin = 0 - elif span_begin < 0: - span_begin += string_len - if span_end is None: - span_end = string_len - elif span_end < 0: - span_end += string_len - if span_begin >= span_end: + span = tuple([ + (index if index >= 0 else index + self.string_len) + if index is not None else substitute + for index, substitute in zip(substr_or_span, self.full_span) + ]) + if span[0] >= span[1]: return [] - return [(span_begin, span_end)] + return [span] # Pre-parsing @@ -339,7 +335,7 @@ class MarkupText(LabelledString): attr_pattern = r"""(\w+)\s*\=\s*(['"])(.*?)\2""" begin_match_obj_stack = [] match_obj_pairs = [] - for match_obj in re.finditer(tag_pattern, self.string): + for match_obj in self.finditer(tag_pattern): if not match_obj.group(1): begin_match_obj_stack.append(match_obj) else: @@ -475,7 +471,7 @@ class MarkupText(LabelledString): breakup_indices )) return list(filter( - lambda span: self.string[slice(*span)].strip(), + lambda span: self.get_substr(span).strip(), self.get_neighbouring_pairs(breakup_indices) )) @@ -511,7 +507,8 @@ class MarkupText(LabelledString): ) -> list[tuple[Span, str]]: return self.command_repl_items.copy() - def get_has_predefined_colors(self) -> bool: + @property + def has_predefined_colors(self) -> bool: return any([ key in COLOR_RELATED_KEYS for _, attr_dict in self.predefined_attr_dicts From 18963fb9fec412ad23a4e5d3f7f2c017174d9f10 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Wed, 6 Apr 2022 23:16:59 +0800 Subject: [PATCH 33/48] Some refactors on LabelledString --- manimlib/mobject/svg/labelled_string.py | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index 85a6ccf9..4756bd61 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -113,24 +113,14 @@ class LabelledString(_StringSVG): def get_substr(self, span: Span) -> str: return self.string[slice(*span)] - def handle_regex_method(func): - def wrapper(self, pattern, pos=0, endpos=9223372036854775807): - return func()( - re.compile(pattern), self.string, pos=pos, endpos=endpos - ) - return wrapper + def finditer(self, pattern, **kwargs): + return re.compile(pattern).finditer(self.string, **kwargs) - @handle_regex_method - def finditer(): - return re.Pattern.finditer + def search(self, pattern, **kwargs): + return re.compile(pattern).search(self.string, **kwargs) - @handle_regex_method - def search(): - return re.Pattern.search - - @handle_regex_method - def match(): - return re.Pattern.match + def match(self, pattern, **kwargs): + return re.compile(pattern).match(self.string, **kwargs) def find_spans(self, pattern: str) -> list[Span]: return [match_obj.span() for match_obj in self.finditer(pattern)] From 557707ea75e6973207f2af81fd32c2cfb9eb8eb2 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Thu, 7 Apr 2022 00:46:41 +0800 Subject: [PATCH 34/48] Support substring and case_sensitive parameters --- manimlib/mobject/svg/labelled_string.py | 84 ++++++++++++++++--------- manimlib/mobject/svg/mtex_mobject.py | 18 +++--- manimlib/mobject/svg/text_mobject.py | 18 +++--- 3 files changed, 74 insertions(+), 46 deletions(-) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index 4756bd61..a2a9f889 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -113,26 +113,31 @@ class LabelledString(_StringSVG): def get_substr(self, span: Span) -> str: return self.string[slice(*span)] - def finditer(self, pattern, **kwargs): - return re.compile(pattern).finditer(self.string, **kwargs) + def finditer( + self, pattern: str, flags: int = 0, **kwargs + ) -> Iterable[re.Match]: + return re.compile(pattern, flags).finditer(self.string, **kwargs) - def search(self, pattern, **kwargs): - return re.compile(pattern).search(self.string, **kwargs) + def search(self, pattern: str, flags: int = 0, **kwargs) -> re.Match: + return re.compile(pattern, flags).search(self.string, **kwargs) - def match(self, pattern, **kwargs): - return re.compile(pattern).match(self.string, **kwargs) + def match(self, pattern: str, flags: int = 0, **kwargs) -> re.Match: + return re.compile(pattern, flags).match(self.string, **kwargs) - def find_spans(self, pattern: str) -> list[Span]: - return [match_obj.span() for match_obj in self.finditer(pattern)] + def find_spans(self, pattern: str, **kwargs) -> list[Span]: + return [ + match_obj.span() + for match_obj in self.finditer(pattern, **kwargs) + ] - def find_substr(self, substr: str) -> list[Span]: + def find_substr(self, substr: str, **kwargs) -> list[Span]: if not substr: return [] - return self.find_spans(re.escape(substr)) + return self.find_spans(re.escape(substr), **kwargs) - def find_substrs(self, substrs: list[str]) -> list[Span]: + def find_substrs(self, substrs: list[str], **kwargs) -> list[Span]: return list(it.chain(*[ - self.find_substr(substr) + self.find_substr(substr, **kwargs) for substr in remove_list_redundancies(substrs) ])) @@ -434,17 +439,27 @@ class LabelledString(_StringSVG): # Selector - def find_span_components(self, custom_span: Span) -> list[Span]: + def find_span_components( + self, custom_span: Span, substring: bool = True + ) -> list[Span]: shrinked_span = self.shrink_span(custom_span) if shrinked_span[0] >= shrinked_span[1]: return [] - indices = remove_list_redundancies(list(it.chain( - self.full_span, - *self.label_span_list - ))) - span_begin = self.take_nearest_value(indices, shrinked_span[0], 0) - span_end = self.take_nearest_value(indices, shrinked_span[1] - 1, 1) + if substring: + indices = remove_list_redundancies(list(it.chain( + self.full_span, + *self.label_span_list + ))) + span_begin = self.take_nearest_value( + indices, shrinked_span[0], 0 + ) + span_end = self.take_nearest_value( + indices, shrinked_span[1] - 1, 1 + ) + else: + span_begin, span_end = shrinked_span + span_choices = sorted(filter( lambda span: self.span_contains((span_begin, span_end), span), self.label_span_list @@ -462,12 +477,14 @@ class LabelledString(_StringSVG): span_begin = next_begin return result - def get_parts_by_custom_span(self, custom_span: Span) -> VGroup: + def get_parts_by_custom_span(self, custom_span: Span, **kwargs) -> VGroup: labels = [ label for label, span in enumerate(self.label_span_list) if any([ self.span_contains(span_component, span) - for span_component in self.find_span_components(custom_span) + for span_component in self.find_span_components( + custom_span, **kwargs + ) ]) ] return VGroup(*filter( @@ -475,10 +492,15 @@ class LabelledString(_StringSVG): self.submobjects )) - def get_parts_by_string(self, substr: str) -> VGroup: + def get_parts_by_string( + self, substr: str, case_sensitive: bool = True, **kwargs + ) -> VGroup: + flags = 0 + if not case_sensitive: + flags |= re.I return VGroup(*[ - self.get_parts_by_custom_span(span) - for span in self.find_substr(substr) + self.get_parts_by_custom_span(span, **kwargs) + for span in self.find_substr(substr, flags=flags) ]) def get_parts_by_group_substr(self, substr: str) -> VGroup: @@ -490,18 +512,20 @@ class LabelledString(_StringSVG): if group_substr == substr ]) - def get_part_by_string(self, substr: str, index : int = 0) -> VMobject: - return self.get_parts_by_string(substr)[index] + def get_part_by_string( + self, substr: str, index: int = 0, **kwargs + ) -> VMobject: + return self.get_parts_by_string(substr, **kwargs)[index] - def set_color_by_string(self, substr: str, color: ManimColor): - self.get_parts_by_string(substr).set_color(color) + def set_color_by_string(self, substr: str, color: ManimColor, **kwargs): + self.get_parts_by_string(substr, **kwargs).set_color(color) return self def set_color_by_string_to_color_map( - self, string_to_color_map: dict[str, ManimColor] + self, string_to_color_map: dict[str, ManimColor], **kwargs ): for substr, color in string_to_color_map.items(): - self.set_color_by_string(substr, color) + self.set_color_by_string(substr, color, **kwargs) return self def indices_of_part(self, part: Iterable[VMobject]) -> list[int]: diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 23209fda..341db072 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -314,19 +314,21 @@ class MTex(LabelledString): # Method alias - def get_parts_by_tex(self, tex: str) -> VGroup: - return self.get_parts_by_string(tex) + def get_parts_by_tex(self, tex: str, **kwargs) -> VGroup: + return self.get_parts_by_string(tex, **kwargs) - def get_part_by_tex(self, tex: str) -> VMobject: - return self.get_part_by_string(tex) + def get_part_by_tex(self, tex: str, **kwargs) -> VMobject: + return self.get_part_by_string(tex, **kwargs) - def set_color_by_tex(self, tex: str, color: ManimColor): - return self.set_color_by_string(tex, color) + def set_color_by_tex(self, tex: str, color: ManimColor, **kwargs): + return self.set_color_by_string(tex, color, **kwargs) def set_color_by_tex_to_color_map( - self, tex_to_color_map: dict[str, ManimColor] + self, tex_to_color_map: dict[str, ManimColor], **kwargs ): - return self.set_color_by_string_to_color_map(tex_to_color_map) + return self.set_color_by_string_to_color_map( + tex_to_color_map, **kwargs + ) def get_tex(self) -> str: return self.get_string() diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index b8d2f259..d59c5d8b 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -517,19 +517,21 @@ class MarkupText(LabelledString): # Method alias - def get_parts_by_text(self, text: str) -> VGroup: - return self.get_parts_by_string(text) + def get_parts_by_text(self, text: str, **kwargs) -> VGroup: + return self.get_parts_by_string(text, **kwargs) - def get_part_by_text(self, text: str) -> VMobject: - return self.get_part_by_string(text) + def get_part_by_text(self, text: str, **kwargs) -> VMobject: + return self.get_part_by_string(text, **kwargs) - def set_color_by_text(self, text: str, color: ManimColor): - return self.set_color_by_string(text, color) + def set_color_by_text(self, text: str, color: ManimColor, **kwargs): + return self.set_color_by_string(text, color, **kwargs) def set_color_by_text_to_color_map( - self, text_to_color_map: dict[str, ManimColor] + self, text_to_color_map: dict[str, ManimColor], **kwargs ): - return self.set_color_by_string_to_color_map(text_to_color_map) + return self.set_color_by_string_to_color_map( + text_to_color_map, **kwargs + ) def get_text(self) -> str: return self.get_string() From 3550108ff76a0cfc3267850bdd67898375123912 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Thu, 7 Apr 2022 09:48:44 +0800 Subject: [PATCH 35/48] Handle out-of-bound spans --- manimlib/mobject/svg/text_mobject.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index d59c5d8b..8dbd05cc 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -315,9 +315,13 @@ class MarkupText(LabelledString): return self.find_substr(substr_or_span) span = tuple([ - (index if index >= 0 else index + self.string_len) - if index is not None else substitute - for index, substitute in zip(substr_or_span, self.full_span) + ( + min(index, self.string_len) + if index >= 0 + else max(index + self.string_len, 0) + ) + if index is not None else default_index + for index, default_index in zip(substr_or_span, self.full_span) ]) if span[0] >= span[1]: return [] From d31f3df5af8b064c35f722dee63424ab1b462d3d Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Thu, 7 Apr 2022 10:05:04 +0800 Subject: [PATCH 36/48] docs: update changelog for #1779 #1780 --- docs/source/development/changelog.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/source/development/changelog.rst b/docs/source/development/changelog.rst index aa97e872..2d327ba0 100644 --- a/docs/source/development/changelog.rst +++ b/docs/source/development/changelog.rst @@ -23,6 +23,7 @@ New features - Added a basic ``Prismify`` to turn a flat ``VMobject`` into something with depth (`#1764 `__) - Added ``GlowDots``, analogous to ``GlowDot`` (`#1764 `__) - Added ``TransformMatchingStrings`` which is compatible with ``Text`` and ``MTex`` (`#1772 `__) +- Added support for ``substring`` and ``case_sensitive`` parameters for ``LabelledString.get_parts_by_string`` (`#1780 `__) Refactor @@ -31,7 +32,7 @@ Refactor - Specifid UTF-8 encoding for tex files (`#1748 `__) - Refactored ``Text`` with the latest manimpango (`#1751 `__) - Reorganized getters for ``ParametricCurve`` (`#1757 `__) -- Refactored ``CameraFrame`` to use ``scipy.spatial.transform.Rotation `` (`#1764 `__) +- Refactored ``CameraFrame`` to use ``scipy.spatial.transform.Rotation`` (`#1764 `__) - Refactored rotation methods to use ``scipy.spatial.transform.Rotation`` (`#1764 `__) - Used ``stroke_color`` to init ``Arrow`` (`#1764 `__) - Refactored ``Mobject.set_rgba_array_by_color`` (`#1764 `__) @@ -43,6 +44,7 @@ Refactor - Refactored ``VCube`` (`#1770 `__) - Refactored ``Prism`` to receive ``width height depth`` instead of ``dimensions`` (`#1770 `__) - Refactored ``Text``, ``MarkupText`` and ``MTex`` based on ``LabelledString`` (`#1772 `__) +- Refactored ``LabelledString`` and relevant classes (`#1779 `__) v1.5.0 From 1f32a9e674c58ff7ee6456cc383c430803ee1976 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=B9=A4=E7=BF=94=E4=B8=87=E9=87=8C?= Date: Thu, 7 Apr 2022 10:50:18 +0800 Subject: [PATCH 37/48] Some fix (#1781) * fix: reduce warning from numpy * fix: fix ControlsExample --- example_scenes.py | 2 +- manimlib/utils/space_ops.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/example_scenes.py b/example_scenes.py index 70321ce5..5f61b874 100644 --- a/example_scenes.py +++ b/example_scenes.py @@ -650,7 +650,7 @@ class ControlsExample(Scene): def text_updater(old_text): assert(isinstance(old_text, Text)) - new_text = Text(self.textbox.get_value(), size=old_text.size) + new_text = Text(self.textbox.get_value(), font_size=old_text.font_size) # new_text.align_data_and_family(old_text) new_text.move_to(old_text) if self.checkbox.get_value(): diff --git a/manimlib/utils/space_ops.py b/manimlib/utils/space_ops.py index d570d04d..29c67ab4 100644 --- a/manimlib/utils/space_ops.py +++ b/manimlib/utils/space_ops.py @@ -152,7 +152,7 @@ def angle_between_vectors(v1: np.ndarray, v2: np.ndarray) -> float: """ n1 = get_norm(v1) n2 = get_norm(v2) - cos_angle = np.dot(v1, v2) / (n1 * n2) + cos_angle = np.dot(v1, v2) / np.float64(n1 * n2) return math.acos(clip(cos_angle, -1, 1)) From e8430b38b26c7b9f7bafbd254ebaae3902c1af99 Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Thu, 7 Apr 2022 10:57:21 +0800 Subject: [PATCH 38/48] docs: update changelog for #1781 --- docs/source/development/changelog.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/development/changelog.rst b/docs/source/development/changelog.rst index 2d327ba0..2a77ed5c 100644 --- a/docs/source/development/changelog.rst +++ b/docs/source/development/changelog.rst @@ -13,6 +13,7 @@ Fixed bugs - Fixed the width of riemann rectangles (`#1762 `__) - Bug fixed in cases where empty array is passed to shader (`#1764 `__) - Fixed ``AddTextWordByWord`` (`#1772 `__) +- Fixed ``ControlsExample`` (`#1781 `__) New features From 9d7db7aacd7116a8dbce0781f64ba44f065a7e39 Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Thu, 7 Apr 2022 11:00:43 +0800 Subject: [PATCH 39/48] release: ready to release v1.6.0 --- docs/source/development/changelog.rst | 4 ++-- setup.cfg | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/development/changelog.rst b/docs/source/development/changelog.rst index 2a77ed5c..ab56b032 100644 --- a/docs/source/development/changelog.rst +++ b/docs/source/development/changelog.rst @@ -1,8 +1,8 @@ Changelog ========= -Unreleased ----------- +v1.6.0 +------ Breaking changes ^^^^^^^^^^^^^^^^ diff --git a/setup.cfg b/setup.cfg index b52ed439..dea0e299 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = manimgl -version = 1.5.0 +version = 1.6.0 author = Grant Sanderson author_email= grant@3blue1brown.com description = Animation engine for explanatory math videos From 22776791111e3dcc22cb2bce0d37e12322dbde6f Mon Sep 17 00:00:00 2001 From: EbbDrop Date: Fri, 8 Apr 2022 22:59:06 +0200 Subject: [PATCH 40/48] Added a \overset as a special string --- manimlib/mobject/svg/tex_mobject.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/manimlib/mobject/svg/tex_mobject.py b/manimlib/mobject/svg/tex_mobject.py index 717f1c24..619f5bc9 100644 --- a/manimlib/mobject/svg/tex_mobject.py +++ b/manimlib/mobject/svg/tex_mobject.py @@ -100,6 +100,18 @@ class SingleStringTex(SVGMobject): filler = "{\\quad}" tex += filler + should_add_double_filler = reduce(op.or_, [ + tex == "\\overset", + # TODO: these can't be used since they change + # the latex draw order. + # tex == "\\frac", # you can use \\over as a alternative + # tex == "\\dfrac", + # tex == "\\binom", + ]) + if should_add_double_filler: + filler = "{\\quad}{\\quad}" + tex += filler + if tex == "\\substack": tex = "\\quad" From e23f667c3dce4ab5e571ff949082ebba076a704c Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Sun, 10 Apr 2022 08:36:13 +0800 Subject: [PATCH 41/48] Fix bug when handling multi-line tex --- manimlib/animation/creation.py | 2 +- .../animation/transform_matching_parts.py | 128 ++++++++---------- manimlib/mobject/svg/labelled_string.py | 106 ++++++++------- manimlib/mobject/svg/mtex_mobject.py | 4 +- manimlib/mobject/svg/text_mobject.py | 4 +- 5 files changed, 119 insertions(+), 125 deletions(-) diff --git a/manimlib/animation/creation.py b/manimlib/animation/creation.py index 82fb1605..6499d0af 100644 --- a/manimlib/animation/creation.py +++ b/manimlib/animation/creation.py @@ -214,7 +214,7 @@ class AddTextWordByWord(ShowIncreasingSubsets): def __init__(self, string_mobject, **kwargs): assert isinstance(string_mobject, LabelledString) - grouped_mobject = string_mobject.get_submob_groups() + grouped_mobject = string_mobject.submob_groups digest_config(self, kwargs) if self.run_time is None: self.run_time = self.time_per_word * len(grouped_mobject) diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index f824663d..dab88005 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -160,77 +160,67 @@ class TransformMatchingStrings(AnimationGroup): } def __init__(self, - source_mobject: LabelledString, - target_mobject: LabelledString, + source: LabelledString, + target: LabelledString, **kwargs ): digest_config(self, kwargs) - assert isinstance(source_mobject, LabelledString) - assert isinstance(target_mobject, LabelledString) + assert isinstance(source, LabelledString) + assert isinstance(target, LabelledString) anims = [] - rest_source_indices = list(range(len(source_mobject.submobjects))) - rest_target_indices = list(range(len(target_mobject.submobjects))) + source_indices = list(range(len(source.labelled_submobjects))) + target_indices = list(range(len(target.labelled_submobjects))) + + def get_indices_lists(mobject, parts): + return [ + [ + mobject.labelled_submobjects.index(submob) + for submob in part + ] + for part in parts + ] def add_anims_from(anim_class, func, source_args, target_args=None): if target_args is None: target_args = source_args.copy() for source_arg, target_arg in zip(source_args, target_args): - source_parts = func(source_mobject, source_arg) - target_parts = func(target_mobject, target_arg) - source_indices_lists = source_mobject.indices_lists_of_parts( - source_parts - ) - target_indices_lists = target_mobject.indices_lists_of_parts( - target_parts - ) - filtered_source_indices_lists = list(filter( + source_parts = func(source, source_arg) + target_parts = func(target, target_arg) + source_indices_lists = list(filter( lambda indices_list: all([ - index in rest_source_indices + index in source_indices for index in indices_list - ]), source_indices_lists + ]), get_indices_lists(source, source_parts) )) - filtered_target_indices_lists = list(filter( + target_indices_lists = list(filter( lambda indices_list: all([ - index in rest_target_indices + index in target_indices for index in indices_list - ]), target_indices_lists + ]), get_indices_lists(target, target_parts) )) - if not all([ - filtered_source_indices_lists, - filtered_target_indices_lists - ]): + if not source_indices_lists or not target_indices_lists: continue anims.append(anim_class(source_parts, target_parts, **kwargs)) - for index in it.chain(*filtered_source_indices_lists): - rest_source_indices.remove(index) - for index in it.chain(*filtered_target_indices_lists): - rest_target_indices.remove(index) + for index in it.chain(*source_indices_lists): + source_indices.remove(index) + for index in it.chain(*target_indices_lists): + target_indices.remove(index) - def get_common_substrs(func): + def get_common_substrs(substrs_from_source, substrs_from_target): return sorted([ - substr for substr in func(source_mobject) - if substr and substr in func(target_mobject) + substr for substr in substrs_from_source + if substr and substr in substrs_from_target ], key=len, reverse=True) def get_parts_from_keys(mobject, keys): - if not isinstance(keys, tuple): - keys = (keys,) - indices = [] + if isinstance(keys, str): + keys = [keys] + result = VGroup() for key in keys: - if isinstance(key, int): - indices.append(key) - elif isinstance(key, range): - indices.extend(key) - elif isinstance(key, str): - all_parts = mobject.get_parts_by_string(key) - indices.extend(it.chain(*[ - mobject.indices_of_part(part) for part in all_parts - ])) - else: + if not isinstance(key, str): raise TypeError(key) - return VGroup(VGroup(*[ - mobject[index] for index in remove_list_redundancies(indices) - ])) + result.add(*mobject.get_parts_by_string(key)) + return result add_anims_from( ReplacementTransform, get_parts_from_keys, @@ -239,38 +229,32 @@ class TransformMatchingStrings(AnimationGroup): add_anims_from( FadeTransformPieces, LabelledString.get_parts_by_string, - get_common_substrs(LabelledString.get_specified_substrs) + get_common_substrs( + source.specified_substrs, + target.specified_substrs + ) ) add_anims_from( FadeTransformPieces, LabelledString.get_parts_by_group_substr, - get_common_substrs(LabelledString.get_group_substrs) + get_common_substrs( + source.group_substrs, + target.group_substrs + ) ) - fade_source = VGroup(*[ - source_mobject[index] - for index in rest_source_indices - ]) - fade_target = VGroup(*[ - target_mobject[index] - for index in rest_target_indices - ]) + rest_source = VGroup(*[source[index] for index in source_indices]) + rest_target = VGroup(*[target[index] for index in target_indices]) if self.transform_mismatches: - anims.append(ReplacementTransform( - fade_source, - fade_target, - **kwargs - )) + anims.append( + ReplacementTransform(rest_source, rest_target, **kwargs) + ) else: - anims.append(FadeOutToPoint( - fade_source, - target_mobject.get_center(), - **kwargs - )) - anims.append(FadeInFromPoint( - fade_target, - source_mobject.get_center(), - **kwargs - )) + anims.append( + FadeOutToPoint(rest_source, target.get_center(), **kwargs) + ) + anims.append( + FadeInFromPoint(rest_target, source.get_center(), **kwargs) + ) super().__init__(*anims) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index a2a9f889..32d468a9 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -57,6 +57,7 @@ class LabelledString(_StringSVG): self.pre_parse() self.parse() super().__init__(**kwargs) + self.post_parse() def get_file_path(self) -> str: return self.get_file_path_(use_plain_file=False) @@ -108,6 +109,20 @@ class LabelledString(_StringSVG): self.label_span_list = self.get_label_span_list() self.check_overlapping() + def post_parse(self) -> None: + self.labelled_submobject_items = [ + (submob.label, submob) + for submob in self.submobjects + ] + self.labelled_submobjects = self.get_labelled_submobjects() + self.specified_substrs = self.get_specified_substrs() + self.group_items = self.get_group_items() + self.group_substrs = self.get_group_substrs() + self.submob_groups = self.get_submob_groups() + + def copy(self): + return self.deepcopy() + # Toolkits def get_substr(self, span: Span) -> str: @@ -118,10 +133,14 @@ class LabelledString(_StringSVG): ) -> Iterable[re.Match]: return re.compile(pattern, flags).finditer(self.string, **kwargs) - def search(self, pattern: str, flags: int = 0, **kwargs) -> re.Match: + def search( + self, pattern: str, flags: int = 0, **kwargs + ) -> re.Match | None: return re.compile(pattern, flags).search(self.string, **kwargs) - def match(self, pattern: str, flags: int = 0, **kwargs) -> re.Match: + def match( + self, pattern: str, flags: int = 0, **kwargs + ) -> re.Match | None: return re.compile(pattern, flags).match(self.string, **kwargs) def find_spans(self, pattern: str, **kwargs) -> list[Span]: @@ -275,9 +294,7 @@ class LabelledString(_StringSVG): def color_to_label(color: ManimColor) -> int: rgb_tuple = color_to_int_rgb(color) rgb = LabelledString.rgb_to_int(rgb_tuple) - if rgb == 16777215: # white - return -1 - return rgb + return rgb - 1 @abstractmethod def get_begin_color_command_str(int_rgb: int) -> str: @@ -321,12 +338,11 @@ class LabelledString(_StringSVG): return [] def get_specified_spans(self) -> list[Span]: - spans = [ - self.full_span, - *self.internal_specified_spans, - *self.external_specified_spans, - *self.find_substrs(self.isolate) - ] + spans = list(it.chain( + self.internal_specified_spans, + self.external_specified_spans, + self.find_substrs(self.isolate) + )) shrinked_spans = list(filter( lambda span: span[0] < span[1], [self.shrink_span(span) for span in spans] @@ -381,6 +397,9 @@ class LabelledString(_StringSVG): # Post-parsing + def get_labelled_submobjects(self) -> list[VMobject]: + return [submob for _, submob in self.labelled_submobject_items] + def get_cleaned_substr(self, span: Span) -> str: span_repl_dict = dict.fromkeys(self.command_spans, "") return self.get_replaced_substr(span, span_repl_dict) @@ -391,17 +410,14 @@ class LabelledString(_StringSVG): for span in self.specified_spans ]) - def get_group_span_items(self) -> tuple[list[int], list[Span]]: - submob_labels = [submob.label for submob in self.submobjects] - if not submob_labels: - return [], [] - return tuple(zip(*self.compress_neighbours(submob_labels))) - - def get_group_substrs(self) -> list[str]: - group_labels, _ = self.get_group_span_items() - if not group_labels: + def get_group_items(self) -> list[tuple[str, VGroup]]: + if not self.labelled_submobject_items: return [] + labels, labelled_submobjects = zip(*self.labelled_submobject_items) + group_labels, labelled_submob_spans = zip( + *self.compress_neighbours(labels) + ) ordered_spans = [ self.label_span_list[label] if label != -1 else self.full_span for label in group_labels @@ -425,16 +441,27 @@ class LabelledString(_StringSVG): interval_spans, (ordered_spans[0][0], ordered_spans[-1][1]) ) ] - return [ + group_substrs = [ self.get_cleaned_substr(span) if span[0] < span[1] else "" for span in shrinked_spans ] + submob_groups = VGroup(*[ + VGroup(*labelled_submobjects[slice(*submob_span)]) + for submob_span in labelled_submob_spans + ]) + return list(zip(group_substrs, submob_groups)) - def get_submob_groups(self) -> VGroup: - _, submob_spans = self.get_group_span_items() + def get_group_substrs(self) -> list[str]: + return [group_substr for group_substr, _ in self.group_items] + + def get_submob_groups(self) -> list[VGroup]: + return [submob_group for _, submob_group in self.group_items] + + def get_parts_by_group_substr(self, substr: str) -> VGroup: return VGroup(*[ - VGroup(*self.submobjects[slice(*submob_span)]) - for submob_span in submob_spans + group + for group_substr, group in self.group_items + if group_substr == substr ]) # Selector @@ -477,7 +504,7 @@ class LabelledString(_StringSVG): span_begin = next_begin return result - def get_parts_by_custom_span(self, custom_span: Span, **kwargs) -> VGroup: + def get_part_by_custom_span(self, custom_span: Span, **kwargs) -> VGroup: labels = [ label for label, span in enumerate(self.label_span_list) if any([ @@ -487,10 +514,10 @@ class LabelledString(_StringSVG): ) ]) ] - return VGroup(*filter( - lambda submob: submob.label in labels, - self.submobjects - )) + return VGroup(*[ + submob for label, submob in self.labelled_submobject_items + if label in labels + ]) def get_parts_by_string( self, substr: str, case_sensitive: bool = True, **kwargs @@ -499,19 +526,10 @@ class LabelledString(_StringSVG): if not case_sensitive: flags |= re.I return VGroup(*[ - self.get_parts_by_custom_span(span, **kwargs) + self.get_part_by_custom_span(span, **kwargs) for span in self.find_substr(substr, flags=flags) ]) - def get_parts_by_group_substr(self, substr: str) -> VGroup: - return VGroup(*[ - group - for group, group_substr in zip( - self.get_submob_groups(), self.get_group_substrs() - ) - if group_substr == substr - ]) - def get_part_by_string( self, substr: str, index: int = 0, **kwargs ) -> VMobject: @@ -528,13 +546,5 @@ class LabelledString(_StringSVG): self.set_color_by_string(substr, color, **kwargs) return self - def indices_of_part(self, part: Iterable[VMobject]) -> list[int]: - return [self.submobjects.index(submob) for submob in part] - - def indices_lists_of_parts( - self, parts: Iterable[Iterable[VMobject]] - ) -> list[list[int]]: - return [self.indices_of_part(part) for part in parts] - def get_string(self) -> str: return self.string diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 341db072..5668b183 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -209,7 +209,7 @@ class MTex(LabelledString): right_brace_indices, cmd_end, n_braces ) + 1 if substitute_cmd: - repl_str = "\\" + cmd_name + n_braces * "{white}" + repl_str = "\\" + cmd_name + n_braces * "{black}" else: repl_str = "" result.append(((span_begin, span_end), repl_str)) @@ -270,7 +270,7 @@ class MTex(LabelledString): ] return [ (span, ( - self.get_begin_color_command_str(label), + self.get_begin_color_command_str(label + 1), self.get_end_color_command_str() )) for label, span in enumerate(extended_label_span_list) diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 8dbd05cc..76ae8e38 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -485,12 +485,12 @@ class MarkupText(LabelledString): if not use_plain_file: attr_dict_items = [ (span, { - key: WHITE if key in COLOR_RELATED_KEYS else val + key: BLACK if key in COLOR_RELATED_KEYS else val for key, val in attr_dict.items() }) for span, attr_dict in self.predefined_attr_dicts ] + [ - (span, {"foreground": self.rgb_int_to_hex(label)}) + (span, {"foreground": self.rgb_int_to_hex(label + 1)}) for label, span in enumerate(self.label_span_list) ] else: From 36d62ae1a3955c165151450e41a599fb4ef1368d Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Sun, 10 Apr 2022 09:23:53 +0800 Subject: [PATCH 42/48] Add regex parameter --- manimlib/mobject/svg/labelled_string.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index 32d468a9..b9d7b4fd 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -520,14 +520,17 @@ class LabelledString(_StringSVG): ]) def get_parts_by_string( - self, substr: str, case_sensitive: bool = True, **kwargs + self, substr: str, + case_sensitive: bool = True, regex: bool = False, **kwargs ) -> VGroup: flags = 0 if not case_sensitive: flags |= re.I + pattern = substr if regex else re.escape(substr) return VGroup(*[ self.get_part_by_custom_span(span, **kwargs) - for span in self.find_substr(substr, flags=flags) + for span in self.find_spans(pattern, flags=flags) + if span[0] < span[1] ]) def get_part_by_string( From 12bfe88f40c2fadbfb64b245be845df7a64a233e Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Mon, 11 Apr 2022 23:44:33 +0800 Subject: [PATCH 43/48] Some refactors --- manimlib/mobject/svg/labelled_string.py | 109 +++++++++++------------ manimlib/mobject/svg/mtex_mobject.py | 110 ++++++++++++------------ manimlib/mobject/svg/text_mobject.py | 87 ++++++++----------- 3 files changed, 142 insertions(+), 164 deletions(-) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index b9d7b4fd..58c47094 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -4,12 +4,14 @@ import re import colour import itertools as it from typing import Iterable, Union, Sequence -from abc import abstractmethod +from abc import ABC, abstractmethod -from manimlib.constants import WHITE +from manimlib.constants import BLACK, WHITE from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.utils.color import color_to_int_rgb +from manimlib.utils.color import color_to_rgb +from manimlib.utils.color import rgb_to_hex from manimlib.utils.config_ops import digest_config from manimlib.utils.iterables import remove_list_redundancies @@ -34,36 +36,39 @@ class _StringSVG(SVGMobject): } -class LabelledString(_StringSVG): +class LabelledString(_StringSVG, ABC): """ An abstract base class for `MTex` and `MarkupText` """ CONFIG = { - "base_color": None, + "base_color": WHITE, "use_plain_file": False, "isolate": [], } def __init__(self, string: str, **kwargs): self.string = string - reserved_svg_default = kwargs.pop("svg_default", {}) digest_config(self, kwargs) - self.reserved_svg_default = reserved_svg_default - self.base_color = self.base_color \ - or reserved_svg_default.get("color", None) \ - or reserved_svg_default.get("fill_color", None) \ + + # Convert `base_color` to hex code. + self.base_color = rgb_to_hex(color_to_rgb( + self.base_color \ + or self.svg_default.get("color", None) \ + or self.svg_default.get("fill_color", None) \ or WHITE + )) + self.svg_default["fill_color"] = BLACK self.pre_parse() self.parse() - super().__init__(**kwargs) + super().__init__() self.post_parse() def get_file_path(self) -> str: return self.get_file_path_(use_plain_file=False) def get_file_path_(self, use_plain_file: bool) -> str: - content = self.get_decorated_string(use_plain_file=use_plain_file) + content = self.get_content(use_plain_file) return self.get_file_path_by_content(content) @abstractmethod @@ -77,15 +82,11 @@ class LabelledString(_StringSVG): self.color_to_label(submob.get_fill_color()) for submob in self.submobjects ] - if any([ - self.use_plain_file, - self.reserved_svg_default, - self.has_predefined_colors - ]): + if self.use_plain_file or self.has_predefined_local_colors: file_path = self.get_file_path_(use_plain_file=True) plain_svg = _StringSVG( file_path, - svg_default=self.reserved_svg_default, + svg_default=self.svg_default, path_string_config=self.path_string_config ) self.set_submobjects(plain_svg.submobjects) @@ -101,7 +102,9 @@ class LabelledString(_StringSVG): def parse(self) -> None: self.command_repl_items = self.get_command_repl_items() self.command_spans = self.get_command_spans() - self.ignored_spans = self.get_ignored_spans() + self.extra_entity_spans = self.get_extra_entity_spans() + self.entity_spans = self.get_entity_spans() + self.extra_ignored_spans = self.get_extra_ignored_spans() self.skipped_spans = self.get_skipped_spans() self.internal_specified_spans = self.get_internal_specified_spans() self.external_specified_spans = self.get_external_specified_spans() @@ -216,7 +219,7 @@ class LabelledString(_StringSVG): return sorted_seq[index + index_shift] @staticmethod - def get_span_replacement_dict( + def generate_span_repl_dict( inserted_string_pairs: list[tuple[Span, tuple[str, str]]], other_repl_items: list[tuple[Span, str]] ) -> dict[Span, str]: @@ -290,20 +293,20 @@ class LabelledString(_StringSVG): r, g = divmod(rg, 256) return r, g, b + @staticmethod + def int_to_hex(rgb_int: int) -> str: + return "#{:06x}".format(rgb_int).upper() + + @staticmethod + def hex_to_int(rgb_hex: str) -> int: + return int(rgb_hex[1:], 16) + @staticmethod def color_to_label(color: ManimColor) -> int: rgb_tuple = color_to_int_rgb(color) rgb = LabelledString.rgb_to_int(rgb_tuple) return rgb - 1 - @abstractmethod - def get_begin_color_command_str(int_rgb: int) -> str: - return "" - - @abstractmethod - def get_end_color_command_str() -> str: - return "" - # Parsing @abstractmethod @@ -313,14 +316,25 @@ class LabelledString(_StringSVG): def get_command_spans(self) -> list[Span]: return [cmd_span for cmd_span, _ in self.command_repl_items] - def get_ignored_spans(self) -> list[int]: + @abstractmethod + def get_extra_entity_spans(self) -> list[Span]: + return [] + + def get_entity_spans(self) -> list[Span]: + return list(it.chain( + self.command_spans, + self.extra_entity_spans + )) + + @abstractmethod + def get_extra_ignored_spans(self) -> list[int]: return [] def get_skipped_spans(self) -> list[Span]: return list(it.chain( self.find_spans(r"\s"), self.command_spans, - self.ignored_spans + self.extra_ignored_spans )) def shrink_span(self, span: Span) -> Span: @@ -344,7 +358,11 @@ class LabelledString(_StringSVG): self.find_substrs(self.isolate) )) shrinked_spans = list(filter( - lambda span: span[0] < span[1], + lambda span: span[0] < span[1] and not any([ + entity_span[0] < index < entity_span[1] + for index in span + for entity_span in self.entity_spans + ]), [self.shrink_span(span) for span in spans] )) return remove_list_redundancies(shrinked_spans) @@ -363,36 +381,11 @@ class LabelledString(_StringSVG): ) @abstractmethod - def get_inserted_string_pairs( - self, use_plain_file: bool - ) -> list[tuple[Span, tuple[str, str]]]: - return [] + def get_content(self, use_plain_file: bool) -> str: + return "" @abstractmethod - def get_other_repl_items( - self, use_plain_file: bool - ) -> list[tuple[Span, str]]: - return [] - - def get_decorated_string(self, use_plain_file: bool) -> str: - span_repl_dict = self.get_span_replacement_dict( - self.get_inserted_string_pairs(use_plain_file), - self.get_other_repl_items(use_plain_file) - ) - result = self.get_replaced_substr(self.full_span, span_repl_dict) - - if not use_plain_file: - return result - return "".join([ - self.get_begin_color_command_str( - self.rgb_to_int(color_to_int_rgb(self.base_color)) - ), - result, - self.get_end_color_command_str() - ]) - - @abstractmethod - def has_predefined_colors(self) -> bool: + def has_predefined_local_colors(self) -> bool: return False # Post-parsing diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 5668b183..fb7922e1 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -1,5 +1,6 @@ from __future__ import annotations +import itertools as it import colour from typing import Union, Sequence @@ -32,7 +33,7 @@ class MTex(LabelledString): def __init__(self, tex_string: str, **kwargs): # Prevent from passing an empty string. if not tex_string: - tex_string = "\\quad" + tex_string = "\\\\" self.tex_string = tex_string super().__init__(tex_string, **kwargs) @@ -55,30 +56,14 @@ class MTex(LabelledString): ) def get_file_path_by_content(self, content: str) -> str: - full_tex = self.get_tex_file_body(content) - with display_during_execution(f"Writing \"{self.tex_string}\""): - file_path = self.tex_to_svg_file_path(full_tex) - return file_path - - def get_tex_file_body(self, content: str) -> str: - if self.tex_environment: - content = "\n".join([ - f"\\begin{{{self.tex_environment}}}", - content, - f"\\end{{{self.tex_environment}}}" - ]) - if self.alignment: - content = "\n".join([self.alignment, content]) - tex_config = get_tex_config() - return tex_config["tex_body"].replace( + full_tex = tex_config["tex_body"].replace( tex_config["text_to_replace"], content ) - - @staticmethod - def tex_to_svg_file_path(tex_file_content: str) -> str: - return tex_to_svg_file(tex_file_content) + with display_during_execution(f"Writing \"{self.tex_string}\""): + file_path = tex_to_svg_file(full_tex) + return file_path def pre_parse(self) -> None: super().pre_parse() @@ -91,29 +76,23 @@ class MTex(LabelledString): # Toolkits @staticmethod - def get_begin_color_command_str(rgb_int: int) -> str: + def get_color_command_str(rgb_int: int) -> str: rgb_tuple = MTex.int_to_rgb(rgb_int) return "".join([ - "{{", "\\color[RGB]", "{", ",".join(map(str, rgb_tuple)), "}" ]) - @staticmethod - def get_end_color_command_str() -> str: - return "}}" - # Pre-parsing def get_backslash_indices(self) -> list[int]: - # Newlines (`\\`) don't count. - return [ - span[1] - 1 + # The latter of `\\` doesn't count. + return list(it.chain(*[ + range(span[0], span[1], 2) for span in self.find_spans(r"\\+") - if (span[1] - span[0]) % 2 == 1 - ] + ])) def get_unescaped_char_spans(self, chars: str): return sorted(filter( @@ -215,7 +194,13 @@ class MTex(LabelledString): result.append(((span_begin, span_end), repl_str)) return result - def get_ignored_spans(self) -> list[int]: + def get_extra_entity_spans(self) -> list[Span]: + return [ + self.match(r"\\([a-zA-Z]+|.)", pos=index).span() + for index in self.backslash_indices + ] + + def get_extra_ignored_spans(self) -> list[int]: return self.script_char_spans.copy() def get_internal_specified_spans(self) -> list[Span]: @@ -256,35 +241,46 @@ class MTex(LabelledString): result.append(shrinked_span) return result - def get_inserted_string_pairs( - self, use_plain_file: bool - ) -> list[tuple[Span, tuple[str, str]]]: + def get_content(self, use_plain_file: bool) -> str: if use_plain_file: - return [] + span_repl_dict = {} + else: + extended_label_span_list = [ + span + if span in self.script_content_spans + else (span[0], self.rslide(span[1], self.script_spans)) + for span in self.label_span_list + ] + inserted_string_pairs = [ + (span, ( + "{{" + self.get_color_command_str(label + 1), + "}}" + )) + for label, span in enumerate(extended_label_span_list) + ] + span_repl_dict = self.generate_span_repl_dict( + inserted_string_pairs, + self.command_repl_items + ) + result = self.get_replaced_substr(self.full_span, span_repl_dict) - extended_label_span_list = [ - span - if span in self.script_content_spans - else (span[0], self.rslide(span[1], self.script_spans)) - for span in self.label_span_list - ] - return [ - (span, ( - self.get_begin_color_command_str(label + 1), - self.get_end_color_command_str() - )) - for label, span in enumerate(extended_label_span_list) - ] - - def get_other_repl_items( - self, use_plain_file: bool - ) -> list[tuple[Span, str]]: + if self.tex_environment: + result = "\n".join([ + f"\\begin{{{self.tex_environment}}}", + result, + f"\\end{{{self.tex_environment}}}" + ]) + if self.alignment: + result = "\n".join([self.alignment, result]) if use_plain_file: - return [] - return self.command_repl_items.copy() + result = "\n".join([ + self.get_color_command_str(self.hex_to_int(self.base_color)), + result + ]) + return result @property - def has_predefined_colors(self) -> bool: + def has_predefined_local_colors(self) -> bool: return bool(self.command_repl_items) # Post-parsing diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 76ae8e38..c3c3be19 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -254,27 +254,6 @@ class MarkupText(LabelledString): for key, val in attr_dict.items() ]) - @staticmethod - def get_begin_tag_str(attr_dict: dict[str, str]) -> str: - return f"" - - @staticmethod - def get_end_tag_str() -> str: - return "" - - @staticmethod - def rgb_int_to_hex(rgb_int: int) -> str: - return "#{:06x}".format(rgb_int).upper() - - @staticmethod - def get_begin_color_command_str(rgb_int: int): - color_hex = MarkupText.rgb_int_to_hex(rgb_int) - return MarkupText.get_begin_tag_str({"foreground": color_hex}) - - @staticmethod - def get_end_color_command_str() -> str: - return MarkupText.get_end_tag_str() - @staticmethod def merge_attr_dicts( attr_dict_items: list[Span, str, typing.Any] @@ -452,6 +431,14 @@ class MarkupText(LabelledString): ] return result + def get_extra_entity_spans(self) -> list[Span]: + if not self.is_markup: + return [] + return self.find_spans(r"&.*?;") + + def get_extra_ignored_spans(self) -> list[int]: + return [] + def get_internal_specified_spans(self) -> list[Span]: return [span for span, _ in self.local_dicts_from_markup] @@ -464,13 +451,10 @@ class MarkupText(LabelledString): self.find_spans(r"\b"), self.specified_spans )))) - entity_spans = self.command_spans.copy() - if self.is_markup: - entity_spans += self.find_spans(r"&.*?;") breakup_indices = sorted(filter( lambda index: not any([ span[0] < index < span[1] - for span in entity_spans + for span in self.entity_spans ]), breakup_indices )) @@ -479,40 +463,45 @@ class MarkupText(LabelledString): self.get_neighbouring_pairs(breakup_indices) )) - def get_inserted_string_pairs( - self, use_plain_file: bool - ) -> list[tuple[Span, tuple[str, str]]]: - if not use_plain_file: + def get_content(self, use_plain_file: bool) -> str: + if use_plain_file: attr_dict_items = [ - (span, { - key: BLACK if key in COLOR_RELATED_KEYS else val - for key, val in attr_dict.items() - }) - for span, attr_dict in self.predefined_attr_dicts - ] + [ - (span, {"foreground": self.rgb_int_to_hex(label + 1)}) - for label, span in enumerate(self.label_span_list) + (self.full_span, {"foreground": self.base_color}), + *self.predefined_attr_dicts, + *[ + (span, {}) + for span in self.label_span_list + ] ] else: - attr_dict_items = self.predefined_attr_dicts + [ - (span, {}) - for span in self.label_span_list + attr_dict_items = [ + (self.full_span, {"foreground": BLACK}), + *[ + (span, { + key: BLACK if key in COLOR_RELATED_KEYS else val + for key, val in attr_dict.items() + }) + for span, attr_dict in self.predefined_attr_dicts + ], + *[ + (span, {"foreground": self.int_to_hex(label + 1)}) + for label, span in enumerate(self.label_span_list) + ] ] - return [ + inserted_string_pairs = [ (span, ( - self.get_begin_tag_str(attr_dict), - self.get_end_tag_str() + f"", + "" )) for span, attr_dict in self.merge_attr_dicts(attr_dict_items) ] - - def get_other_repl_items( - self, use_plain_file: bool - ) -> list[tuple[Span, str]]: - return self.command_repl_items.copy() + span_repl_dict = self.generate_span_repl_dict( + inserted_string_pairs, self.command_repl_items + ) + return self.get_replaced_substr(self.full_span, span_repl_dict) @property - def has_predefined_colors(self) -> bool: + def has_predefined_local_colors(self) -> bool: return any([ key in COLOR_RELATED_KEYS for _, attr_dict in self.predefined_attr_dicts From 705f1a528b7ef299f7486181022600cfb641274f Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Mon, 11 Apr 2022 10:47:11 -0700 Subject: [PATCH 44/48] Separate functionality of ordinary linear interpolation from that using np.outer on arrays --- manimlib/utils/bezier.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/manimlib/utils/bezier.py b/manimlib/utils/bezier.py index 16ec3e10..71d3d2b9 100644 --- a/manimlib/utils/bezier.py +++ b/manimlib/utils/bezier.py @@ -80,15 +80,10 @@ def partial_quadratic_bezier_points( # Linear interpolation variants -def interpolate(start: T, end: T, alpha: float) -> T: + +def interpolate(start: T, end: T, alpha: np.ndarray | float) -> T: try: - if isinstance(alpha, float): - return (1 - alpha) * start + alpha * end - # Otherwise, assume alpha is a list or array, and return - # an appropriated shaped array of all corresponding - # interpolations - result = np.outer(1 - alpha, start) + np.outer(alpha, end) - return result.reshape((*np.shape(alpha), *np.shape(start))) + return (1 - alpha) * start + alpha * end except TypeError: log.debug(f"`start` parameter with type `{type(start)}` and dtype `{start.dtype}`") log.debug(f"`end` parameter with type `{type(end)}` and dtype `{end.dtype}`") @@ -97,6 +92,15 @@ def interpolate(start: T, end: T, alpha: float) -> T: sys.exit(2) +def outer_interpolate( + start: np.ndarray | float, + end: np.ndarray | float, + alpha: np.ndarray | float, +) -> T: + result = np.outer(1 - alpha, start) + np.outer(alpha, end) + return result.reshape((*np.shape(alpha), *np.shape(start))) + + def set_array_by_interpolation( arr: np.ndarray, arr1: np.ndarray, From dc4b9bc93c54e07b768aab47b417a95a40c7b85e Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Mon, 11 Apr 2022 10:47:26 -0700 Subject: [PATCH 45/48] Use outer_interpolate for NumberLine.number_to_point --- manimlib/mobject/number_line.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/manimlib/mobject/number_line.py b/manimlib/mobject/number_line.py index 13b6a13b..bc96b55a 100644 --- a/manimlib/mobject/number_line.py +++ b/manimlib/mobject/number_line.py @@ -7,6 +7,7 @@ from manimlib.mobject.geometry import Line from manimlib.mobject.numbers import DecimalNumber from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.utils.bezier import interpolate +from manimlib.utils.bezier import outer_interpolate from manimlib.utils.config_ops import digest_config from manimlib.utils.config_ops import merge_dicts_recursively from manimlib.utils.simple_functions import fdiv @@ -106,7 +107,7 @@ class NumberLine(Line): def number_to_point(self, number: float | np.ndarray) -> np.ndarray: alpha = (number - self.x_min) / (self.x_max - self.x_min) - return interpolate(self.get_start(), self.get_end(), alpha) + return outer_interpolate(self.get_start(), self.get_end(), alpha) def point_to_number(self, point: np.ndarray) -> float: points = self.get_points() From 55684af27df956ce3f9d95a3f1fd61a3263ae637 Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Tue, 12 Apr 2022 20:20:03 +0800 Subject: [PATCH 46/48] fix: fix ImageMobject by overriding set_color method --- manimlib/mobject/types/image_mobject.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/manimlib/mobject/types/image_mobject.py b/manimlib/mobject/types/image_mobject.py index d3f11f2b..54166d36 100644 --- a/manimlib/mobject/types/image_mobject.py +++ b/manimlib/mobject/types/image_mobject.py @@ -48,6 +48,9 @@ class ImageMobject(Mobject): mob.data["opacity"] = np.array([[o] for o in listify(opacity)]) return self + def set_color(self, color, opacity=None, recurse=None): + return self + def point_to_rgb(self, point: np.ndarray) -> np.ndarray: x0, y0 = self.get_corner(UL)[:2] x1, y1 = self.get_corner(DR)[:2] From 9d74e8bce3a6dff675628efbd205d321e7e13fa0 Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Wed, 13 Apr 2022 10:34:59 +0800 Subject: [PATCH 47/48] docs: update changelog for #1783 #1785 #1788 #1791 --- docs/source/development/changelog.rst | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/docs/source/development/changelog.rst b/docs/source/development/changelog.rst index ab56b032..a6b766e9 100644 --- a/docs/source/development/changelog.rst +++ b/docs/source/development/changelog.rst @@ -1,6 +1,20 @@ Changelog ========= +Unreleased +---------- + +Fixed bugs +^^^^^^^^^^ +- Fixed the bug of ``MTex`` with multi-line tex string (`#1785 `__) +- Fixed ``interpolate`` (`#1788 `__) +- Fixed ``ImageMobject`` (`#1791 `__) + +Refactor +^^^^^^^^ +- Added ``\overset`` as a special string in ``Tex`` (`#1783 `__) +- Added ``outer_interpolate`` to perform interpolation using ``np.outer`` on arrays (`#1788 `__) + v1.6.0 ------ From bda7f98d2e8b80b053f24d3e6011693b1e89ce76 Mon Sep 17 00:00:00 2001 From: TonyCrane Date: Wed, 13 Apr 2022 10:36:38 +0800 Subject: [PATCH 48/48] release: ready to release v1.6.1 --- docs/source/development/changelog.rst | 4 ++-- setup.cfg | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/development/changelog.rst b/docs/source/development/changelog.rst index a6b766e9..077e823c 100644 --- a/docs/source/development/changelog.rst +++ b/docs/source/development/changelog.rst @@ -1,8 +1,8 @@ Changelog ========= -Unreleased ----------- +v1.6.1 +------ Fixed bugs ^^^^^^^^^^ diff --git a/setup.cfg b/setup.cfg index dea0e299..934f051c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = manimgl -version = 1.6.0 +version = 1.6.1 author = Grant Sanderson author_email= grant@3blue1brown.com description = Animation engine for explanatory math videos