From e085c2e21428c9c79688a52b1de865ed132ae71a Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Sat, 23 Apr 2022 17:17:43 +0800 Subject: [PATCH 01/11] Refactor LabelledString and relevant classes --- manimlib/animation/creation.py | 14 +- .../animation/transform_matching_parts.py | 116 ++-- manimlib/mobject/svg/labelled_string.py | 590 +++++++----------- manimlib/mobject/svg/mtex_mobject.py | 375 +++++------ manimlib/mobject/svg/text_mobject.py | 480 ++++++-------- manimlib/utils/iterables.py | 32 +- 6 files changed, 725 insertions(+), 882 deletions(-) diff --git a/manimlib/animation/creation.py b/manimlib/animation/creation.py index 27460899..6ad6a9bd 100644 --- a/manimlib/animation/creation.py +++ b/manimlib/animation/creation.py @@ -1,12 +1,12 @@ from __future__ import annotations -import itertools as it -from abc import abstractmethod +from abc import ABC, abstractmethod import numpy as np from manimlib.animation.animation import Animation from manimlib.mobject.svg.labelled_string import LabelledString +from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.utils.bezier import integer_interpolate from manimlib.utils.config_ops import digest_config @@ -17,10 +17,10 @@ from manimlib.utils.rate_functions import smooth from typing import TYPE_CHECKING if TYPE_CHECKING: - from manimlib.mobject.mobject import Group + from manimlib.mobject.mobject import Mobject -class ShowPartial(Animation): +class ShowPartial(Animation, ABC): """ Abstract class for ShowCreation and ShowPassingFlash """ @@ -176,7 +176,7 @@ class ShowIncreasingSubsets(Animation): "int_func": np.round, } - def __init__(self, group: Group, **kwargs): + def __init__(self, group: Mobject, **kwargs): self.all_submobs = list(group.submobjects) super().__init__(group, **kwargs) @@ -213,7 +213,9 @@ class AddTextWordByWord(ShowIncreasingSubsets): def __init__(self, string_mobject, **kwargs): assert isinstance(string_mobject, LabelledString) - grouped_mobject = string_mobject.submob_groups + grouped_mobject = VGroup(*[ + part for _, part in string_mobject.get_group_part_items() + ]) digest_config(self, kwargs) if self.run_time is None: self.run_time = self.time_per_word * len(grouped_mobject) diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index dab88005..e84f1d9d 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -5,9 +5,9 @@ import itertools as it import numpy as np from manimlib.animation.composition import AnimationGroup -from manimlib.animation.fading import FadeTransformPieces from manimlib.animation.fading import FadeInFromPoint from manimlib.animation.fading import FadeOutToPoint +from manimlib.animation.fading import FadeTransformPieces from manimlib.animation.transform import ReplacementTransform from manimlib.animation.transform import Transform from manimlib.mobject.mobject import Mobject @@ -16,13 +16,13 @@ from manimlib.mobject.svg.labelled_string import LabelledString from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.utils.config_ops import digest_config -from manimlib.utils.iterables import remove_list_redundancies from typing import TYPE_CHECKING if TYPE_CHECKING: + from manimlib.mobject.svg.tex_mobject import SingleStringTex + from manimlib.mobject.svg.tex_mobject import Tex from manimlib.scene.scene import Scene - from manimlib.mobject.svg.tex_mobject import Tex, SingleStringTex class TransformMatchingParts(AnimationGroup): @@ -168,36 +168,36 @@ class TransformMatchingStrings(AnimationGroup): assert isinstance(source, LabelledString) assert isinstance(target, LabelledString) anims = [] - source_indices = list(range(len(source.labelled_submobjects))) - target_indices = list(range(len(target.labelled_submobjects))) - def get_indices_lists(mobject, parts): - return [ + source_submobs = [ + submob for _, submob in source.labelled_submobject_items + ] + target_submobs = [ + submob for _, submob in target.labelled_submobject_items + ] + source_indices = list(range(len(source_submobs))) + target_indices = list(range(len(target_submobs))) + + def get_filtered_indices_lists(parts, submobs, rest_indices): + return list(filter( + lambda indices_list: all([ + index in rest_indices + for index in indices_list + ]), [ - mobject.labelled_submobjects.index(submob) - for submob in part + [submobs.index(submob) for submob in part] + for part in parts ] - for part in parts - ] + )) - def add_anims_from(anim_class, func, source_args, target_args=None): - if target_args is None: - target_args = source_args.copy() - for source_arg, target_arg in zip(source_args, target_args): - source_parts = func(source, source_arg) - target_parts = func(target, target_arg) - source_indices_lists = list(filter( - lambda indices_list: all([ - index in source_indices - for index in indices_list - ]), get_indices_lists(source, source_parts) - )) - target_indices_lists = list(filter( - lambda indices_list: all([ - index in target_indices - for index in indices_list - ]), get_indices_lists(target, target_parts) - )) + def add_anims(anim_class, parts_pairs): + for source_parts, target_parts in parts_pairs: + source_indices_lists = get_filtered_indices_lists( + source_parts, source_submobs, source_indices + ) + target_indices_lists = get_filtered_indices_lists( + target_parts, target_submobs, target_indices + ) if not source_indices_lists or not target_indices_lists: continue anims.append(anim_class(source_parts, target_parts, **kwargs)) @@ -206,41 +206,45 @@ class TransformMatchingStrings(AnimationGroup): for index in it.chain(*target_indices_lists): target_indices.remove(index) - def get_common_substrs(substrs_from_source, substrs_from_target): - return sorted([ - substr for substr in substrs_from_source - if substr and substr in substrs_from_target - ], key=len, reverse=True) - - def get_parts_from_keys(mobject, keys): - if isinstance(keys, str): - keys = [keys] - result = VGroup() - for key in keys: - if not isinstance(key, str): - raise TypeError(key) - result.add(*mobject.get_parts_by_string(key)) + def get_substr_to_parts_map(part_items): + result = {} + for substr, part in part_items: + if substr not in result: + result[substr] = [] + result[substr].append(part) return result - add_anims_from( - ReplacementTransform, get_parts_from_keys, - self.key_map.keys(), self.key_map.values() + def add_anims_from(anim_class, func): + source_substr_to_parts_map = get_substr_to_parts_map(func(source)) + target_substr_to_parts_map = get_substr_to_parts_map(func(target)) + add_anims( + anim_class, + [ + ( + VGroup(*source_substr_to_parts_map[substr]), + VGroup(*target_substr_to_parts_map[substr]) + ) + for substr in sorted([ + s for s in source_substr_to_parts_map.keys() + if s and s in target_substr_to_parts_map.keys() + ], key=len, reverse=True) + ] + ) + + add_anims( + ReplacementTransform, + [ + (source.select_parts(k), target.select_parts(v)) + for k, v in self.key_map.items() + ] ) add_anims_from( FadeTransformPieces, - LabelledString.get_parts_by_string, - get_common_substrs( - source.specified_substrs, - target.specified_substrs - ) + LabelledString.get_specified_part_items ) add_anims_from( FadeTransformPieces, - LabelledString.get_parts_by_group_substr, - get_common_substrs( - source.group_substrs, - target.group_substrs - ) + LabelledString.get_group_part_items ) rest_source = VGroup(*[source[index] for index in source_indices]) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index f1354f0c..1c3f0afd 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -1,30 +1,43 @@ from __future__ import annotations -import re -import colour -import itertools as it -from typing import Iterable, Union, Sequence from abc import ABC, abstractmethod +import itertools as it +import numpy as np +import re -from manimlib.constants import BLACK, WHITE +from manimlib.constants import WHITE from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.types.vectorized_mobject import VGroup -from manimlib.utils.color import color_to_int_rgb from manimlib.utils.color import color_to_rgb from manimlib.utils.color import rgb_to_hex from manimlib.utils.config_ops import digest_config from manimlib.utils.iterables import remove_list_redundancies - from typing import TYPE_CHECKING if TYPE_CHECKING: - from manimlib.mobject.types.vectorized_mobject import VMobject - ManimColor = Union[str, colour.Color, Sequence[float]] + from colour import Color + from typing import Iterable, Sequence, TypeVar, Union + + ManimColor = Union[str, Color] Span = tuple[int, int] + Selector = Union[ + str, + re.Pattern, + tuple[Union[int, None], Union[int, None]], + Iterable[Union[ + str, + re.Pattern, + tuple[Union[int, None], Union[int, None]] + ]] + ] + T = TypeVar("T") -class _StringSVG(SVGMobject): +class LabelledString(SVGMobject, ABC): + """ + An abstract base class for `MTex` and `MarkupText` + """ CONFIG = { "height": None, "stroke_width": 0, @@ -33,42 +46,30 @@ class _StringSVG(SVGMobject): "should_subdivide_sharp_curves": True, "should_remove_null_curves": True, }, - } - - -class LabelledString(_StringSVG, ABC): - """ - An abstract base class for `MTex` and `MarkupText` - """ - CONFIG = { "base_color": WHITE, - "use_plain_file": False, "isolate": [], } def __init__(self, string: str, **kwargs): self.string = string digest_config(self, kwargs) + if self.base_color is None: + self.base_color = WHITE + self.base_color_int = self.color_to_int(self.base_color) - # Convert `base_color` to hex code. - self.base_color = rgb_to_hex(color_to_rgb( - self.base_color \ - or self.svg_default.get("color", None) \ - or self.svg_default.get("fill_color", None) \ - or WHITE - )) - self.svg_default["fill_color"] = BLACK - - self.pre_parse() + self.full_span = (0, len(self.string)) self.parse() - super().__init__() - self.post_parse() + super().__init__(**kwargs) + self.labelled_submobject_items = [ + (submob.label, submob) + for submob in self.submobjects + ] def get_file_path(self) -> str: - return self.get_file_path_(use_plain_file=False) + return self.get_file_path_(is_labelled=False) - def get_file_path_(self, use_plain_file: bool) -> str: - content = self.get_content(use_plain_file) + def get_file_path_(self, is_labelled: bool) -> str: + content = self.get_content(is_labelled) return self.get_file_path_by_content(content) @abstractmethod @@ -78,91 +79,135 @@ class LabelledString(_StringSVG, ABC): def generate_mobject(self) -> None: super().generate_mobject() - submob_labels = [ - self.color_to_label(submob.get_fill_color()) - for submob in self.submobjects - ] - if self.use_plain_file or self.has_predefined_local_colors: - file_path = self.get_file_path_(use_plain_file=True) - plain_svg = _StringSVG( - file_path, - svg_default=self.svg_default, - path_string_config=self.path_string_config - ) - self.set_submobjects(plain_svg.submobjects) + num_labels = len(self.label_span_list) + if num_labels: + file_path = self.get_file_path_(is_labelled=True) + labelled_svg = SVGMobject(file_path) + submob_color_ints = [ + self.color_to_int(submob.get_fill_color()) + for submob in labelled_svg.submobjects + ] else: - self.set_fill(self.base_color) - for submob, label in zip(self.submobjects, submob_labels): - submob.label = label + submob_color_ints = [0] * len(self.submobjects) - def pre_parse(self) -> None: - self.string_len = len(self.string) - self.full_span = (0, self.string_len) + if len(self.submobjects) != len(submob_color_ints): + raise ValueError( + "Cannot align submobjects of the labelled svg " + "to the original svg" + ) + + unrecognized_color_ints = self.remove_redundancies(sorted(filter( + lambda color_int: color_int > num_labels, + submob_color_ints + ))) + if unrecognized_color_ints: + raise ValueError( + "Unrecognized color label(s) detected: " + f"{','.join(map(self.int_to_hex, unrecognized_color_ints))}" + ) + + for submob, color_int in zip(self.submobjects, submob_color_ints): + submob.label = color_int - 1 def parse(self) -> None: self.command_repl_items = self.get_command_repl_items() - self.command_spans = self.get_command_spans() - self.extra_entity_spans = self.get_extra_entity_spans() - self.entity_spans = self.get_entity_spans() - self.extra_ignored_spans = self.get_extra_ignored_spans() - self.skipped_spans = self.get_skipped_spans() - self.internal_specified_spans = self.get_internal_specified_spans() - self.external_specified_spans = self.get_external_specified_spans() self.specified_spans = self.get_specified_spans() - self.label_span_list = self.get_label_span_list() self.check_overlapping() - - def post_parse(self) -> None: - self.labelled_submobject_items = [ - (submob.label, submob) - for submob in self.submobjects - ] - self.labelled_submobjects = self.get_labelled_submobjects() - self.specified_substrs = self.get_specified_substrs() - self.group_items = self.get_group_items() - self.group_substrs = self.get_group_substrs() - self.submob_groups = self.get_submob_groups() + self.label_span_list = self.get_label_span_list() + if len(self.label_span_list) >= 16777216: + raise ValueError("Cannot handle that many substrings") # Toolkits def get_substr(self, span: Span) -> str: return self.string[slice(*span)] - def finditer( - self, pattern: str, flags: int = 0, **kwargs - ) -> Iterable[re.Match]: - return re.compile(pattern, flags).finditer(self.string, **kwargs) + def match(self, pattern: str | re.Pattern, **kwargs) -> re.Pattern | None: + if isinstance(pattern, str): + pattern = re.compile(pattern) + return re.compile(pattern).match(self.string, **kwargs) - def search( - self, pattern: str, flags: int = 0, **kwargs - ) -> re.Match | None: - return re.compile(pattern, flags).search(self.string, **kwargs) - - def match( - self, pattern: str, flags: int = 0, **kwargs - ) -> re.Match | None: - return re.compile(pattern, flags).match(self.string, **kwargs) - - def find_spans(self, pattern: str, **kwargs) -> list[Span]: + def find_spans(self, pattern: str | re.Pattern, **kwargs) -> list[Span]: + if isinstance(pattern, str): + pattern = re.compile(pattern) return [ match_obj.span() - for match_obj in self.finditer(pattern, **kwargs) + for match_obj in pattern.finditer(self.string, **kwargs) ] - def find_substr(self, substr: str, **kwargs) -> list[Span]: - if not substr: - return [] - return self.find_spans(re.escape(substr), **kwargs) - - def find_substrs(self, substrs: list[str], **kwargs) -> list[Span]: - return list(it.chain(*[ - self.find_substr(substr, **kwargs) - for substr in remove_list_redundancies(substrs) - ])) + def find_indices(self, pattern: str | re.Pattern, **kwargs) -> list[int]: + return [index for index, _ in self.find_spans(pattern, **kwargs)] @staticmethod - def get_neighbouring_pairs(iterable: list) -> list[tuple]: - return list(zip(iterable[:-1], iterable[1:])) + def is_single_selector(selector: Selector) -> bool: + if isinstance(selector, str): + return True + if isinstance(selector, re.Pattern): + return True + if isinstance(selector, tuple): + if len(selector) == 2 and all([ + isinstance(index, int) or index is None + for index in selector + ]): + return True + return False + + def find_spans_by_selector(self, selector: Selector) -> list[Span]: + if self.is_single_selector(selector): + selector = (selector,) + result = [] + for sel in selector: + if not self.is_single_selector(sel): + raise TypeError(f"Invalid selector: '{sel}'") + if isinstance(sel, str): + spans = self.find_spans(re.escape(sel)) + elif isinstance(sel, re.Pattern): + spans = self.find_spans(sel) + else: + string_len = self.full_span[1] + span = tuple([ + ( + min(index, string_len) + if index >= 0 + else max(index + string_len, 0) + ) + if index is not None else default_index + for index, default_index in zip(sel, self.full_span) + ]) + spans = [span] + result.extend(spans) + return sorted(filter( + lambda span: span[0] < span[1], + self.remove_redundancies(result) + )) + + @staticmethod + def chain(*iterables: Iterable[T]) -> list[T]: + return list(it.chain(*iterables)) + + @staticmethod + def remove_redundancies(vals: Sequence[T]) -> list[T]: + return remove_list_redundancies(vals) + + @staticmethod + def get_neighbouring_pairs(vals: Sequence[T]) -> list[tuple[T, T]]: + return list(zip(vals[:-1], vals[1:])) + + @staticmethod + def compress_neighbours(vals: Sequence[T]) -> list[tuple[T, Span]]: + if not vals: + return [] + + unique_vals = [vals[0]] + indices = [0] + for index, val in enumerate(vals): + if val == unique_vals[-1]: + continue + unique_vals.append(val) + indices.append(index) + indices.append(len(vals)) + spans = LabelledString.get_neighbouring_pairs(indices) + return list(zip(unique_vals, spans)) @staticmethod def span_contains(span_0: Span, span_1: Span) -> bool: @@ -182,194 +227,88 @@ class LabelledString(_StringSVG, ABC): )) @staticmethod - def compress_neighbours(vals: list[int]) -> list[tuple[int, Span]]: - if not vals: + def merge_inserted_strings_from_pairs( + inserted_string_pairs: list[tuple[Span, tuple[str, str]]] + ) -> list[tuple[int, str]]: + if not inserted_string_pairs: return [] - unique_vals = [vals[0]] - indices = [0] - for index, val in enumerate(vals): - if val == unique_vals[-1]: - continue - unique_vals.append(val) - indices.append(index) - indices.append(len(vals)) - spans = LabelledString.get_neighbouring_pairs(indices) - return list(zip(unique_vals, spans)) - - @staticmethod - def find_region_index(seq: list[int], val: int) -> int: - # Returns an integer in `range(-1, len(seq))` satisfying - # `seq[result] <= val < seq[result + 1]`. - # `seq` should be sorted in ascending order. - if not seq or val < seq[0]: - return -1 - result = len(seq) - 1 - while val < seq[result]: - result -= 1 - return result - - @staticmethod - def take_nearest_value(seq: list[int], val: int, index_shift: int) -> int: - sorted_seq = sorted(seq) - index = LabelledString.find_region_index(sorted_seq, val) - return sorted_seq[index + index_shift] - - @staticmethod - def generate_span_repl_dict( - inserted_string_pairs: list[tuple[Span, tuple[str, str]]], - other_repl_items: list[tuple[Span, str]] - ) -> dict[Span, str]: - result = dict(other_repl_items) - if not inserted_string_pairs: - return result - - indices, _, _, inserted_strings = zip(*sorted([ - ( - span[flag], - -flag, - -span[1 - flag], - str_pair[flag] + spans = [ + span for span, _ in inserted_string_pairs + ] + sorted_index_flag_pairs = sorted( + it.product(range(len(spans)), range(2)), + key=lambda t: ( + spans[t[0]][t[1]], + np.sign(spans[t[0]][1 - t[1]] - spans[t[0]][t[1]]), + -spans[t[0]][1 - t[1]], + t[1], + (1, -1)[t[1]] * t[0] ) - for span, str_pair in inserted_string_pairs - for flag in range(2) - ])) - result.update({ - (index, index): "".join(inserted_strings[slice(*item_span)]) + ) + indices, inserted_strings = zip(*[ + list(zip(*inserted_string_pairs[item_index]))[flag] + for item_index, flag in sorted_index_flag_pairs + ]) + return [ + (index, "".join(inserted_strings[slice(*item_span)])) for index, item_span in LabelledString.compress_neighbours(indices) - }) - return result + ] def get_replaced_substr( - self, span: Span, span_repl_dict: dict[Span, str] + self, span: Span, repl_items: list[tuple[Span, str]] ) -> str: - repl_spans = sorted(filter( - lambda repl_span: self.span_contains(span, repl_span), - span_repl_dict.keys() - )) - if not all( - span_0[1] <= span_1[0] - for span_0, span_1 in self.get_neighbouring_pairs(repl_spans) - ): - raise ValueError("Overlapping replacement") + if not repl_items: + return self.get_substr(span) + sorted_repl_items = sorted(repl_items, key=lambda t: t[0]) + repl_spans, repl_strs = zip(*sorted_repl_items) pieces = [ self.get_substr(piece_span) for piece_span in self.get_complement_spans(repl_spans, span) ] - repl_strs = [span_repl_dict[repl_span] for repl_span in repl_spans] - repl_strs.append("") - return "".join(it.chain(*zip(pieces, repl_strs))) + repl_strs = [*repl_strs, ""] + return "".join(self.chain(*zip(pieces, repl_strs))) + + def get_replaced_string( + self, + inserted_string_pairs: list[tuple[Span, tuple[str, str]]], + repl_items: list[tuple[Span, str]] + ) -> str: + all_repl_items = self.chain( + repl_items, + [ + ((index, index), inserted_string) + for index, inserted_string + in self.merge_inserted_strings_from_pairs( + inserted_string_pairs + ) + ] + ) + return self.get_replaced_substr(self.full_span, all_repl_items) @staticmethod - def rslide(index: int, skipped: list[Span]) -> int: - transfer_dict = dict(sorted(skipped)) - while index in transfer_dict.keys(): - index = transfer_dict[index] - return index - - @staticmethod - def lslide(index: int, skipped: list[Span]) -> int: - transfer_dict = dict(sorted([ - skipped_span[::-1] for skipped_span in skipped - ], reverse=True)) - while index in transfer_dict.keys(): - index = transfer_dict[index] - return index - - @staticmethod - def rgb_to_int(rgb_tuple: tuple[int, int, int]) -> int: - r, g, b = rgb_tuple - rg = r * 256 + g - return rg * 256 + b - - @staticmethod - def int_to_rgb(rgb_int: int) -> tuple[int, int, int]: - rg, b = divmod(rgb_int, 256) - r, g = divmod(rg, 256) - return r, g, b + def color_to_int(color: ManimColor) -> int: + hex_code = rgb_to_hex(color_to_rgb(color)) + return int(hex_code[1:], 16) @staticmethod def int_to_hex(rgb_int: int) -> str: return "#{:06x}".format(rgb_int).upper() - @staticmethod - def hex_to_int(rgb_hex: str) -> int: - return int(rgb_hex[1:], 16) - - @staticmethod - def color_to_label(color: ManimColor) -> int: - rgb_tuple = color_to_int_rgb(color) - rgb = LabelledString.rgb_to_int(rgb_tuple) - return rgb - 1 - # Parsing @abstractmethod def get_command_repl_items(self) -> list[tuple[Span, str]]: return [] - def get_command_spans(self) -> list[Span]: - return [cmd_span for cmd_span, _ in self.command_repl_items] - @abstractmethod - def get_extra_entity_spans(self) -> list[Span]: - return [] - - def get_entity_spans(self) -> list[Span]: - return list(it.chain( - self.command_spans, - self.extra_entity_spans - )) - - @abstractmethod - def get_extra_ignored_spans(self) -> list[int]: - return [] - - def get_skipped_spans(self) -> list[Span]: - return list(it.chain( - self.find_spans(r"\s"), - self.command_spans, - self.extra_ignored_spans - )) - - def shrink_span(self, span: Span) -> Span: - return ( - self.rslide(span[0], self.skipped_spans), - self.lslide(span[1], self.skipped_spans) - ) - - @abstractmethod - def get_internal_specified_spans(self) -> list[Span]: - return [] - - @abstractmethod - def get_external_specified_spans(self) -> list[Span]: - return [] - def get_specified_spans(self) -> list[Span]: - spans = list(it.chain( - self.internal_specified_spans, - self.external_specified_spans, - self.find_substrs(self.isolate) - )) - shrinked_spans = list(filter( - lambda span: span[0] < span[1] and not any([ - entity_span[0] < index < entity_span[1] - for index in span - for entity_span in self.entity_spans - ]), - [self.shrink_span(span) for span in spans] - )) - return remove_list_redundancies(shrinked_spans) - - @abstractmethod - def get_label_span_list(self) -> list[Span]: return [] def check_overlapping(self) -> None: - for span_0, span_1 in it.product(self.label_span_list, repeat=2): + for span_0, span_1 in it.product(self.specified_spans, repeat=2): if not span_0[0] < span_1[0] < span_0[1] < span_1[1]: continue raise ValueError( @@ -378,29 +317,20 @@ class LabelledString(_StringSVG, ABC): ) @abstractmethod - def get_content(self, use_plain_file: bool) -> str: - return "" + def get_label_span_list(self) -> list[Span]: + return [] @abstractmethod - def has_predefined_local_colors(self) -> bool: - return False + def get_content(self, is_labelled: bool) -> str: + return "" - # Post-parsing - - def get_labelled_submobjects(self) -> list[VMobject]: - return [submob for _, submob in self.labelled_submobject_items] + # Selector + @abstractmethod def get_cleaned_substr(self, span: Span) -> str: - span_repl_dict = dict.fromkeys(self.command_spans, "") - return self.get_replaced_substr(span, span_repl_dict) + return "" - def get_specified_substrs(self) -> list[str]: - return remove_list_redundancies([ - self.get_cleaned_substr(span) - for span in self.specified_spans - ]) - - def get_group_items(self) -> list[tuple[str, VGroup]]: + def get_group_part_items(self) -> list[tuple[str, VGroup]]: if not self.labelled_submobject_items: return [] @@ -425,118 +355,56 @@ class LabelledString(_StringSVG, ABC): ordered_spans ) ] - shrinked_spans = [ - self.shrink_span(span) + group_substrs = [ + self.get_cleaned_substr(span) if span[0] < span[1] else "" for span in self.get_complement_spans( interval_spans, (ordered_spans[0][0], ordered_spans[-1][1]) ) ] - group_substrs = [ - self.get_cleaned_substr(span) if span[0] < span[1] else "" - for span in shrinked_spans - ] submob_groups = VGroup(*[ VGroup(*labelled_submobjects[slice(*submob_span)]) for submob_span in labelled_submob_spans ]) return list(zip(group_substrs, submob_groups)) - def get_group_substrs(self) -> list[str]: - return [group_substr for group_substr, _ in self.group_items] - - def get_submob_groups(self) -> list[VGroup]: - return [submob_group for _, submob_group in self.group_items] - - def get_parts_by_group_substr(self, substr: str) -> VGroup: - return VGroup(*[ - group - for group_substr, group in self.group_items - if group_substr == substr - ]) - - # Selector - - def find_span_components( - self, custom_span: Span, substring: bool = True - ) -> list[Span]: - shrinked_span = self.shrink_span(custom_span) - if shrinked_span[0] >= shrinked_span[1]: - return [] - - if substring: - indices = remove_list_redundancies(list(it.chain( - self.full_span, - *self.label_span_list - ))) - span_begin = self.take_nearest_value( - indices, shrinked_span[0], 0 + def get_specified_part_items(self) -> list[tuple[str, VGroup]]: + return [ + ( + self.get_substr(span), + self.select_part_by_span(span) ) - span_end = self.take_nearest_value( - indices, shrinked_span[1] - 1, 1 - ) - else: - span_begin, span_end = shrinked_span + for span in self.specified_spans + ] - span_choices = sorted(filter( - lambda span: self.span_contains((span_begin, span_end), span), - self.label_span_list - )) - # Choose spans that reach the farthest. - span_choices_dict = dict(span_choices) - - result = [] - while span_begin < span_end: - if span_begin not in span_choices_dict.keys(): - span_begin += 1 - continue - next_begin = span_choices_dict[span_begin] - result.append((span_begin, next_begin)) - span_begin = next_begin - return result - - def get_part_by_custom_span(self, custom_span: Span, **kwargs) -> VGroup: + def select_part_by_span(self, custom_span: Span) -> VGroup: labels = [ label for label, span in enumerate(self.label_span_list) - if any([ - self.span_contains(span_component, span) - for span_component in self.find_span_components( - custom_span, **kwargs - ) - ]) + if self.span_contains(custom_span, span) ] return VGroup(*[ submob for label, submob in self.labelled_submobject_items if label in labels ]) - def get_parts_by_string( - self, substr: str, - case_sensitive: bool = True, regex: bool = False, **kwargs - ) -> VGroup: - flags = 0 - if not case_sensitive: - flags |= re.I - pattern = substr if regex else re.escape(substr) - return VGroup(*[ - self.get_part_by_custom_span(span, **kwargs) - for span in self.find_spans(pattern, flags=flags) - if span[0] < span[1] - ]) + def select_parts(self, selector: Selector) -> VGroup: + return VGroup(*filter( + lambda part: part.submobjects, + [ + self.select_part_by_span(span) + for span in self.find_spans_by_selector(selector) + ] + )) - def get_part_by_string( - self, substr: str, index: int = 0, **kwargs - ) -> VMobject: - return self.get_parts_by_string(substr, **kwargs)[index] + def select_part(self, selector: Selector, index: int = 0) -> VGroup: + return self.select_parts(selector)[index] - def set_color_by_string(self, substr: str, color: ManimColor, **kwargs): - self.get_parts_by_string(substr, **kwargs).set_color(color) + def set_parts_color(self, selector: Selector, color: ManimColor): + self.select_parts(selector).set_color(color) return self - def set_color_by_string_to_color_map( - self, string_to_color_map: dict[str, ManimColor], **kwargs - ): - for substr, color in string_to_color_map.items(): - self.set_color_by_string(substr, color, **kwargs) + def set_parts_color_by_dict(self, color_map: dict[Selector, ManimColor]): + for selector, color in color_map.items(): + self.set_parts_color(selector, color) return self def get_string(self) -> str: diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index fb7922e1..ed7273ee 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -1,27 +1,46 @@ from __future__ import annotations -import itertools as it -import colour -from typing import Union, Sequence +import re from manimlib.mobject.svg.labelled_string import LabelledString -from manimlib.utils.tex_file_writing import tex_to_svg_file -from manimlib.utils.tex_file_writing import get_tex_config from manimlib.utils.tex_file_writing import display_during_execution - +from manimlib.utils.tex_file_writing import get_tex_config +from manimlib.utils.tex_file_writing import tex_to_svg_file from typing import TYPE_CHECKING if TYPE_CHECKING: - from manimlib.mobject.types.vectorized_mobject import VMobject + from colour import Color + from typing import Iterable, Union + from manimlib.mobject.types.vectorized_mobject import VGroup - ManimColor = Union[str, colour.Color, Sequence[float]] + + ManimColor = Union[str, Color] Span = tuple[int, int] + Selector = Union[ + str, + re.Pattern, + tuple[Union[int, None], Union[int, None]], + Iterable[Union[ + str, + re.Pattern, + tuple[Union[int, None], Union[int, None]] + ]] + ] SCALE_FACTOR_PER_FONT_POINT = 0.001 +TEX_COLOR_COMMANDS_DICT = { + "\\color": (1, False), + "\\textcolor": (1, False), + "\\pagecolor": (1, True), + "\\colorbox": (1, True), + "\\fcolorbox": (2, True), +} + + class MTex(LabelledString): CONFIG = { "font_size": 48, @@ -32,7 +51,7 @@ class MTex(LabelledString): def __init__(self, tex_string: str, **kwargs): # Prevent from passing an empty string. - if not tex_string: + if not tex_string.strip(): tex_string = "\\\\" self.tex_string = tex_string super().__init__(tex_string, **kwargs) @@ -47,7 +66,6 @@ class MTex(LabelledString): self.svg_default, self.path_string_config, self.base_color, - self.use_plain_file, self.isolate, self.tex_string, self.alignment, @@ -61,85 +79,95 @@ class MTex(LabelledString): tex_config["text_to_replace"], content ) - with display_during_execution(f"Writing \"{self.tex_string}\""): + with display_during_execution(f"Writing \"{self.string}\""): file_path = tex_to_svg_file(full_tex) return file_path - def pre_parse(self) -> None: - super().pre_parse() + def parse(self) -> None: self.backslash_indices = self.get_backslash_indices() - self.brace_index_pairs = self.get_brace_index_pairs() - self.script_char_spans = self.get_script_char_spans() + self.command_spans = self.get_command_spans() + self.brace_spans = self.get_brace_spans() + self.script_char_indices = self.get_script_char_indices() self.script_content_spans = self.get_script_content_spans() self.script_spans = self.get_script_spans() + super().parse() # Toolkits @staticmethod def get_color_command_str(rgb_int: int) -> str: - rgb_tuple = MTex.int_to_rgb(rgb_int) - return "".join([ - "\\color[RGB]", - "{", - ",".join(map(str, rgb_tuple)), - "}" - ]) + rg, b = divmod(rgb_int, 256) + r, g = divmod(rg, 256) + return f"\\color[RGB]{{{r}, {g}, {b}}}" - # Pre-parsing + @staticmethod + def shrink_span(span: Span, skippable_indices: list[int]) -> Span: + span_begin, span_end = span + while span_begin in skippable_indices: + span_begin += 1 + while span_end - 1 in skippable_indices: + span_end -= 1 + return (span_begin, span_end) + + # Parsing def get_backslash_indices(self) -> list[int]: # The latter of `\\` doesn't count. - return list(it.chain(*[ - range(span[0], span[1], 2) - for span in self.find_spans(r"\\+") - ])) + return self.find_indices(r"\\.") - def get_unescaped_char_spans(self, chars: str): - return sorted(filter( - lambda span: span[0] - 1 not in self.backslash_indices, - self.find_substrs(list(chars)) + def get_command_spans(self) -> list[Span]: + return [ + self.match(r"\\(?:[a-zA-Z]+|.)", pos=index).span() + for index in self.backslash_indices + ] + + def get_unescaped_char_indices(self, char: str) -> list[int]: + return list(filter( + lambda index: index - 1 not in self.backslash_indices, + self.find_indices(re.escape(char)) )) - def get_brace_index_pairs(self) -> list[Span]: - left_brace_indices = [] - right_brace_indices = [] - left_brace_indices_stack = [] - for span in self.get_unescaped_char_spans("{}"): - index = span[0] - if self.get_substr(span) == "{": - left_brace_indices_stack.append(index) + def get_brace_spans(self) -> list[Span]: + span_begins = [] + span_ends = [] + span_begins_stack = [] + char_items = sorted([ + (index, char) + for char in "{}" + for index in self.get_unescaped_char_indices(char) + ]) + for index, char in char_items: + if char == "{": + span_begins_stack.append(index) else: - if not left_brace_indices_stack: + if not span_begins_stack: raise ValueError("Missing '{' inserted") - left_brace_index = left_brace_indices_stack.pop() - left_brace_indices.append(left_brace_index) - right_brace_indices.append(index) - if left_brace_indices_stack: + span_begins.append(span_begins_stack.pop()) + span_ends.append(index + 1) + if span_begins_stack: raise ValueError("Missing '}' inserted") - return list(zip(left_brace_indices, right_brace_indices)) + return list(zip(span_begins, span_ends)) - def get_script_char_spans(self) -> list[int]: - return self.get_unescaped_char_spans("_^") + def get_script_char_indices(self) -> list[int]: + return self.chain(*[ + self.get_unescaped_char_indices(char) + for char in "_^" + ]) def get_script_content_spans(self) -> list[Span]: result = [] - brace_indices_dict = dict(self.brace_index_pairs) - script_pattern = r"[a-zA-Z0-9]|\\[a-zA-Z]+" - for script_char_span in self.script_char_spans: - span_begin = self.match(r"\s*", pos=script_char_span[1]).end() - if span_begin in brace_indices_dict.keys(): - span_end = brace_indices_dict[span_begin] + 1 + script_entity_dict = dict(self.chain( + self.brace_spans, + self.command_spans + )) + for index in self.script_char_indices: + span_begin = self.match(r"\s*", pos=index + 1).end() + if span_begin in script_entity_dict.keys(): + span_end = script_entity_dict[span_begin] else: - match_obj = self.match(script_pattern, pos=span_begin) - if not match_obj: - script_name = { - "_": "subscript", - "^": "superscript" - }[script_char] - raise ValueError( - f"Unclear {script_name} detected while parsing. " - "Please use braces to clarify" - ) + match_obj = self.match(r".", pos=span_begin) + if match_obj is None: + continue span_end = match_obj.end() result.append((span_begin, span_end)) return result @@ -147,110 +175,102 @@ class MTex(LabelledString): def get_script_spans(self) -> list[Span]: return [ ( - self.search(r"\s*$", endpos=script_char_span[0]).start(), + self.match(r"[\s\S]*?(\s*)$", endpos=index).start(1), script_content_span[1] ) - for script_char_span, script_content_span in zip( - self.script_char_spans, self.script_content_spans + for index, script_content_span in zip( + self.script_char_indices, self.script_content_spans ) ] - # Parsing - def get_command_repl_items(self) -> list[tuple[Span, str]]: - color_related_command_dict = { - "color": (1, False), - "textcolor": (1, False), - "pagecolor": (1, True), - "colorbox": (1, True), - "fcolorbox": (2, True), - } result = [] - backslash_indices = self.backslash_indices - right_brace_indices = [ - right_index - for left_index, right_index in self.brace_index_pairs - ] - pattern = "".join([ - r"\\", - "(", - "|".join(color_related_command_dict.keys()), - ")", - r"(?![a-zA-Z])" - ]) - for match_obj in self.finditer(pattern): - span_begin, cmd_end = match_obj.span() - if span_begin not in backslash_indices: + brace_spans_dict = dict(self.brace_spans) + brace_begins = list(brace_spans_dict.keys()) + for cmd_span in self.command_spans: + cmd_name = self.get_substr(cmd_span) + if cmd_name not in TEX_COLOR_COMMANDS_DICT.keys(): continue - cmd_name = match_obj.group(1) - n_braces, substitute_cmd = color_related_command_dict[cmd_name] - span_end = self.take_nearest_value( - right_brace_indices, cmd_end, n_braces - ) + 1 + n_braces, substitute_cmd = TEX_COLOR_COMMANDS_DICT[cmd_name] + span_begin, span_end = cmd_span + for _ in range(n_braces): + span_end = brace_spans_dict[min(filter( + lambda index: index >= span_end, + brace_begins + ))] if substitute_cmd: - repl_str = "\\" + cmd_name + n_braces * "{black}" + repl_str = cmd_name + n_braces * "{black}" else: repl_str = "" result.append(((span_begin, span_end), repl_str)) return result - def get_extra_entity_spans(self) -> list[Span]: - return [ - self.match(r"\\([a-zA-Z]+|.)", pos=index).span() - for index in self.backslash_indices + def get_specified_spans(self) -> list[Span]: + # Match paired double braces (`{{...}}`). + sorted_brace_spans = sorted(self.brace_spans, key=lambda t: t[1]) + inner_brace_spans = [ + sorted_brace_spans[span_span[0]] + for _, span_span in self.compress_neighbours([ + (brace_span[0] + index, brace_span[1] - index) + for index, brace_span in enumerate(sorted_brace_spans) + ]) + if span_span[1] - span_span[0] >= 2 + ] + inner_brace_content_spans = [ + (span[0] + 1, span[1] - 1) + for span in inner_brace_spans + if span[1] - span[0] > 2 ] - def get_extra_ignored_spans(self) -> list[int]: - return self.script_char_spans.copy() - - def get_internal_specified_spans(self) -> list[Span]: - # Match paired double braces (`{{...}}`). - result = [] - reversed_brace_indices_dict = dict([ - pair[::-1] for pair in self.brace_index_pairs - ]) - skip = False - for prev_right_index, right_index in self.get_neighbouring_pairs( - list(reversed_brace_indices_dict.keys()) - ): - if skip: - skip = False - continue - if right_index != prev_right_index + 1: - continue - left_index = reversed_brace_indices_dict[right_index] - prev_left_index = reversed_brace_indices_dict[prev_right_index] - if left_index != prev_left_index - 1: - continue - result.append((left_index, right_index + 1)) - skip = True - return result - - def get_external_specified_spans(self) -> list[Span]: - return self.find_substrs(list(self.tex_to_color_map.keys())) + result = self.remove_redundancies(self.chain( + inner_brace_content_spans, + *[ + self.find_spans_by_selector(selector) + for selector in self.tex_to_color_map.keys() + ], + self.find_spans_by_selector(self.isolate) + )) + return list(filter( + lambda span: not any([ + entity_begin < index < entity_end + for index in span + for entity_begin, entity_end in self.command_spans + ]), + result + )) def get_label_span_list(self) -> list[Span]: + reversed_script_spans_dict = dict([ + script_span[::-1] for script_span in self.script_spans + ]) + skippable_indices = self.chain( + self.find_indices(r"\s"), + self.script_char_indices + ) result = self.script_content_spans.copy() - for span_begin, span_end in self.specified_spans: - shrinked_end = self.lslide(span_end, self.script_spans) - if span_begin >= shrinked_end: + for span in self.specified_spans: + span_begin, span_end = self.shrink_span(span, skippable_indices) + while span_end in reversed_script_spans_dict.keys(): + span_end = reversed_script_spans_dict[span_end] + if span_begin >= span_end: continue - shrinked_span = (span_begin, shrinked_end) + shrinked_span = (span_begin, span_end) if shrinked_span in result: continue result.append(shrinked_span) return result - def get_content(self, use_plain_file: bool) -> str: - if use_plain_file: - span_repl_dict = {} - else: - extended_label_span_list = [ - span - if span in self.script_content_spans - else (span[0], self.rslide(span[1], self.script_spans)) - for span in self.label_span_list - ] + def get_content(self, is_labelled: bool) -> str: + if is_labelled: + extended_label_span_list = [] + script_spans_dict = dict(self.script_spans) + for span in self.label_span_list: + if span not in self.script_content_spans: + span_begin, span_end = span + while span_end in script_spans_dict.keys(): + span_end = script_spans_dict[span_end] + span = (span_begin, span_end) + extended_label_span_list.append(span) inserted_string_pairs = [ (span, ( "{{" + self.get_color_command_str(label + 1), @@ -258,43 +278,52 @@ class MTex(LabelledString): )) for label, span in enumerate(extended_label_span_list) ] - span_repl_dict = self.generate_span_repl_dict( - inserted_string_pairs, - self.command_repl_items + result = self.get_replaced_string( + inserted_string_pairs, self.command_repl_items ) - result = self.get_replaced_substr(self.full_span, span_repl_dict) + else: + result = self.string if self.tex_environment: - result = "\n".join([ - f"\\begin{{{self.tex_environment}}}", - result, - f"\\end{{{self.tex_environment}}}" - ]) + if isinstance(self.tex_environment, str): + prefix = f"\\begin{{{self.tex_environment}}}" + suffix = f"\\end{{{self.tex_environment}}}" + else: + prefix, suffix = self.tex_environment + result = "\n".join([prefix, result, suffix]) if self.alignment: result = "\n".join([self.alignment, result]) - if use_plain_file: + if not is_labelled: result = "\n".join([ - self.get_color_command_str(self.hex_to_int(self.base_color)), + self.get_color_command_str(self.base_color_int), result ]) return result - @property - def has_predefined_local_colors(self) -> bool: - return bool(self.command_repl_items) - - # Post-parsing + # Selector def get_cleaned_substr(self, span: Span) -> str: - substr = super().get_cleaned_substr(span) - if not self.brace_index_pairs: - return substr + if not self.brace_spans: + brace_begins, brace_ends = [], [] + else: + brace_begins, brace_ends = zip(*self.brace_spans) + left_brace_indices = list(brace_begins) + right_brace_indices = [index - 1 for index in brace_ends] + skippable_indices = self.chain( + self.find_indices(r"\s"), + self.script_char_indices, + left_brace_indices, + right_brace_indices + ) + shrinked_span = self.shrink_span(span, skippable_indices) + + if shrinked_span[0] >= shrinked_span[1]: + return "" # Balance braces. - left_brace_indices, right_brace_indices = zip(*self.brace_index_pairs) unclosed_left_braces = 0 unclosed_right_braces = 0 - for index in range(*span): + for index in range(*shrinked_span): if index in left_brace_indices: unclosed_left_braces += 1 elif index in right_brace_indices: @@ -304,27 +333,25 @@ class MTex(LabelledString): unclosed_left_braces -= 1 return "".join([ unclosed_right_braces * "{", - substr, + self.get_substr(shrinked_span), unclosed_left_braces * "}" ]) # Method alias - def get_parts_by_tex(self, tex: str, **kwargs) -> VGroup: - return self.get_parts_by_string(tex, **kwargs) + def get_parts_by_tex(self, selector: Selector) -> VGroup: + return self.select_parts(selector) - def get_part_by_tex(self, tex: str, **kwargs) -> VMobject: - return self.get_part_by_string(tex, **kwargs) + def get_part_by_tex(self, selector: Selector) -> VGroup: + return self.select_part(selector) - def set_color_by_tex(self, tex: str, color: ManimColor, **kwargs): - return self.set_color_by_string(tex, color, **kwargs) + def set_color_by_tex(self, selector: Selector, color: ManimColor): + return self.set_parts_color(selector, color) def set_color_by_tex_to_color_map( - self, tex_to_color_map: dict[str, ManimColor], **kwargs + self, color_map: dict[Selector, ManimColor] ): - return self.set_color_by_string_to_color_map( - tex_to_color_map, **kwargs - ) + return self.set_parts_color_by_dict(color_map) def get_tex(self) -> str: return self.get_string() diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index c3c3be19..740eeee6 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -1,93 +1,63 @@ from __future__ import annotations -import os -import re -import itertools as it -from pathlib import Path from contextlib import contextmanager -import typing -from typing import Iterable, Sequence, Union +import os +from pathlib import Path +import re +import manimpango import pygments import pygments.formatters import pygments.lexers -from manimpango import MarkupUtils - +from manimlib.constants import DEFAULT_PIXEL_WIDTH, FRAME_WIDTH +from manimlib.constants import NORMAL from manimlib.logger import log -from manimlib.constants import * from manimlib.mobject.svg.labelled_string import LabelledString -from manimlib.utils.customization import get_customization -from manimlib.utils.tex_file_writing import tex_hash from manimlib.utils.config_ops import digest_config +from manimlib.utils.customization import get_customization from manimlib.utils.directories import get_downloads_dir from manimlib.utils.directories import get_text_dir -from manimlib.utils.iterables import remove_list_redundancies - +from manimlib.utils.tex_file_writing import tex_hash from typing import TYPE_CHECKING if TYPE_CHECKING: - from manimlib.mobject.types.vectorized_mobject import VMobject + from colour import Color + from typing import Iterable, Union + from manimlib.mobject.types.vectorized_mobject import VGroup - ManimColor = Union[str, colour.Color, Sequence[float]] + + ManimColor = Union[str, Color] Span = tuple[int, int] + Selector = Union[ + str, + re.Pattern, + tuple[Union[int, None], Union[int, None]], + Iterable[Union[ + str, + re.Pattern, + tuple[Union[int, None], Union[int, None]] + ]] + ] TEXT_MOB_SCALE_FACTOR = 0.0076 DEFAULT_LINE_SPACING_SCALE = 0.6 +# Ensure the canvas is large enough to hold all glyphs. +DEFAULT_CANVAS_WIDTH = 16384 +DEFAULT_CANVAS_HEIGHT = 16384 # See https://docs.gtk.org/Pango/pango_markup.html -# A tag containing two aliases will cause warning, -# so only use the first key of each group of aliases. -SPAN_ATTR_KEY_ALIAS_LIST = ( - ("font", "font_desc"), - ("font_family", "face"), - ("font_size", "size"), - ("font_style", "style"), - ("font_weight", "weight"), - ("font_variant", "variant"), - ("font_stretch", "stretch"), - ("font_features",), - ("foreground", "fgcolor", "color"), - ("background", "bgcolor"), - ("alpha", "fgalpha"), - ("background_alpha", "bgalpha"), - ("underline",), - ("underline_color",), - ("overline",), - ("overline_color",), - ("rise",), - ("baseline_shift",), - ("font_scale",), - ("strikethrough",), - ("strikethrough_color",), - ("fallback",), - ("lang",), - ("letter_spacing",), - ("gravity",), - ("gravity_hint",), - ("show",), - ("insert_hyphens",), - ("allow_breaks",), - ("line_height",), - ("text_transform",), - ("segment",), +MARKUP_COLOR_KEYS = ( + "foreground", "fgcolor", "color", + "background", "bgcolor", + "underline_color", + "overline_color", + "strikethrough_color" ) -COLOR_RELATED_KEYS = ( - "foreground", - "background", - "underline_color", - "overline_color", - "strikethrough_color" -) -SPAN_ATTR_KEY_CONVERSION = { - key: key_alias_list[0] - for key_alias_list in SPAN_ATTR_KEY_ALIAS_LIST - for key in key_alias_list -} -TAG_TO_ATTR_DICT = { +MARKUP_TAG_CONVERSION_DICT = { "b": {"font_weight": "bold"}, "big": {"font_size": "larger"}, "i": {"font_style": "italic"}, @@ -96,7 +66,7 @@ TAG_TO_ATTR_DICT = { "sup": {"baseline_shift": "superscript", "font_scale": "superscript"}, "small": {"font_size": "smaller"}, "tt": {"font_family": "monospace"}, - "u": {"underline": "single"}, + "u": {"underline": "single"} } @@ -120,7 +90,7 @@ class MarkupText(LabelledString): "justify": False, "indent": 0, "alignment": "LEFT", - "line_width_factor": None, + "line_width": None, "font": "", "slant": NORMAL, "weight": NORMAL, @@ -141,9 +111,7 @@ class MarkupText(LabelledString): if not self.font: self.font = get_customization()["style"]["font"] if self.is_markup: - validate_error = MarkupUtils.validate(text) - if validate_error: - raise ValueError(validate_error) + self.validate_markup_string(text) self.text = text super().__init__(text, **kwargs) @@ -165,7 +133,6 @@ class MarkupText(LabelledString): self.svg_default, self.path_string_config, self.base_color, - self.use_plain_file, self.isolate, self.text, self.is_markup, @@ -174,7 +141,7 @@ class MarkupText(LabelledString): self.justify, self.indent, self.alignment, - self.line_width_factor, + self.line_width, self.font, self.slant, self.weight, @@ -201,23 +168,32 @@ class MarkupText(LabelledString): kwargs[short_name] = kwargs.pop(long_name) def get_file_path_by_content(self, content: str) -> str: + hash_content = str(( + content, + self.justify, + self.indent, + self.alignment, + self.line_width + )) svg_file = os.path.join( - get_text_dir(), tex_hash(content) + ".svg" + get_text_dir(), tex_hash(hash_content) + ".svg" ) if not os.path.exists(svg_file): self.markup_to_svg(content, svg_file) return svg_file def markup_to_svg(self, markup_str: str, file_name: str) -> str: + self.validate_markup_string(markup_str) + # `manimpango` is under construction, # so the following code is intended to suit its interface alignment = _Alignment(self.alignment) - if self.line_width_factor is None: + if self.line_width is None: pango_width = -1 else: - pango_width = self.line_width_factor * DEFAULT_PIXEL_WIDTH + pango_width = self.line_width / FRAME_WIDTH * DEFAULT_PIXEL_WIDTH - return MarkupUtils.text2svg( + return manimpango.MarkupUtils.text2svg( text=markup_str, font="", # Already handled slant="NORMAL", # Already handled @@ -228,8 +204,8 @@ class MarkupText(LabelledString): file_name=file_name, START_X=0, START_Y=0, - width=DEFAULT_PIXEL_WIDTH, - height=DEFAULT_PIXEL_HEIGHT, + width=DEFAULT_CANVAS_WIDTH, + height=DEFAULT_CANVAS_HEIGHT, justify=self.justify, indent=self.indent, line_spacing=None, # Already handled @@ -237,13 +213,23 @@ class MarkupText(LabelledString): pango_width=pango_width ) - def pre_parse(self) -> None: - super().pre_parse() - self.tag_items_from_markup = self.get_tag_items_from_markup() - self.global_dict_from_config = self.get_global_dict_from_config() - self.local_dicts_from_markup = self.get_local_dicts_from_markup() - self.local_dicts_from_config = self.get_local_dicts_from_config() - self.predefined_attr_dicts = self.get_predefined_attr_dicts() + @staticmethod + def validate_markup_string(markup_str: str) -> None: + validate_error = manimpango.MarkupUtils.validate(markup_str) + if not validate_error: + return + raise ValueError( + f"Invalid markup string \"{markup_str}\"\n" + f"{validate_error}" + ) + + def parse(self) -> None: + self.global_attr_dict = self.get_global_attr_dict() + self.tag_pairs_from_markup = self.get_tag_pairs_from_markup() + self.tag_spans = self.get_tag_spans() + self.items_from_markup = self.get_items_from_markup() + self.specified_items = self.get_specified_items() + super().parse() # Toolkits @@ -254,87 +240,50 @@ class MarkupText(LabelledString): for key, val in attr_dict.items() ]) - @staticmethod - def merge_attr_dicts( - attr_dict_items: list[Span, str, typing.Any] - ) -> list[tuple[Span, dict[str, str]]]: - index_seq = [0] - attr_dict_list = [{}] - for span, attr_dict in attr_dict_items: - if span[0] >= span[1]: - continue - region_indices = [ - MarkupText.find_region_index(index_seq, index) - for index in span - ] - for flag in (1, 0): - if index_seq[region_indices[flag]] == span[flag]: - continue - region_index = region_indices[flag] - index_seq.insert(region_index + 1, span[flag]) - attr_dict_list.insert( - region_index + 1, attr_dict_list[region_index].copy() + # Parsing + + def get_global_attr_dict(self) -> dict[str, str]: + result = { + "foreground": self.int_to_hex(self.base_color_int), + "font_family": self.font, + "font_style": self.slant, + "font_weight": self.weight, + "font_size": str(self.font_size * 1024), + } + # `line_height` attribute is supported since Pango 1.50. + pango_version = manimpango.pango_version() + if tuple(map(int, pango_version.split("."))) < (1, 50): + if self.lsh is not None: + log.warning( + f"Pango version {pango_version} found (< 1.50), " + "unable to set `line_height` attribute" ) - region_indices[flag] += 1 - if flag == 0: - region_indices[1] += 1 - for key, val in attr_dict.items(): - if not key: - continue - for mid_dict in attr_dict_list[slice(*region_indices)]: - mid_dict[key] = val - return list(zip( - MarkupText.get_neighbouring_pairs(index_seq), attr_dict_list[:-1] - )) + else: + line_spacing_scale = self.lsh or DEFAULT_LINE_SPACING_SCALE + result["line_height"] = str(((line_spacing_scale) + 1) * 0.6) + return result - def find_substr_or_span( - self, substr_or_span: str | tuple[int | None, int | None] - ) -> list[Span]: - if isinstance(substr_or_span, str): - return self.find_substr(substr_or_span) - - span = tuple([ - ( - min(index, self.string_len) - if index >= 0 - else max(index + self.string_len, 0) - ) - if index is not None else default_index - for index, default_index in zip(substr_or_span, self.full_span) - ]) - if span[0] >= span[1]: - return [] - return [span] - - # Pre-parsing - - def get_tag_items_from_markup( + def get_tag_pairs_from_markup( self ) -> list[tuple[Span, Span, dict[str, str]]]: if not self.is_markup: return [] - tag_pattern = r"""<(/?)(\w+)\s*((?:\w+\s*\=\s*(['"]).*?\4\s*)*)>""" - attr_pattern = r"""(\w+)\s*\=\s*(['"])(.*?)\2""" + tag_pattern = r"""<(/?)(\w+)\s*((\w+\s*\=\s*(['"])[\s\S]*?\5\s*)*)>""" + attr_pattern = r"""(\w+)\s*\=\s*(['"])([\s\S]*?)\2""" begin_match_obj_stack = [] match_obj_pairs = [] - for match_obj in self.finditer(tag_pattern): + for match_obj in re.finditer(tag_pattern, self.string): if not match_obj.group(1): begin_match_obj_stack.append(match_obj) else: match_obj_pairs.append( (begin_match_obj_stack.pop(), match_obj) ) - if begin_match_obj_stack: - raise ValueError("Unclosed tag(s) detected") result = [] for begin_match_obj, end_match_obj in match_obj_pairs: tag_name = begin_match_obj.group(2) - if tag_name != end_match_obj.group(2): - raise ValueError("Unmatched tag names") - if end_match_obj.group(3): - raise ValueError("Attributes shan't exist in ending tags") if tag_name == "span": attr_dict = { match.group(1): match.group(3) @@ -342,189 +291,170 @@ class MarkupText(LabelledString): attr_pattern, begin_match_obj.group(3) ) } - elif tag_name in TAG_TO_ATTR_DICT.keys(): - if begin_match_obj.group(3): - raise ValueError( - f"Attributes shan't exist in tag '{tag_name}'" - ) - attr_dict = TAG_TO_ATTR_DICT[tag_name].copy() else: - raise ValueError(f"Unknown tag: '{tag_name}'") + attr_dict = MARKUP_TAG_CONVERSION_DICT.get(tag_name, {}) result.append( (begin_match_obj.span(), end_match_obj.span(), attr_dict) ) return result - def get_global_dict_from_config(self) -> dict[str, typing.Any]: - result = { - "line_height": ( - (self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1 - ) * 0.6, - "font_family": self.font, - "font_size": self.font_size * 1024, - "font_style": self.slant, - "font_weight": self.weight - } - result.update(self.global_config) - return result + def get_tag_spans(self) -> list[Span]: + return [ + tag_span + for begin_tag, end_tag, _ in self.tag_pairs_from_markup + for tag_span in (begin_tag, end_tag) + ] - def get_local_dicts_from_markup( - self - ) -> list[Span, dict[str, str]]: - return sorted([ - ((begin_tag_span[0], end_tag_span[1]), attr_dict) + def get_items_from_markup(self) -> list[Span]: + return [ + ((begin_tag_span[1], end_tag_span[0]), attr_dict) for begin_tag_span, end_tag_span, attr_dict - in self.tag_items_from_markup - ]) + in self.tag_pairs_from_markup + if begin_tag_span[1] < end_tag_span[0] + ] - def get_local_dicts_from_config( - self - ) -> list[Span, dict[str, typing.Any]]: + def get_specified_items(self) -> list[tuple[Span, dict[str, str]]]: + result = self.chain( + self.items_from_markup, + [ + (span, {key: val}) + for t2x_dict, key in ( + (self.t2c, "foreground"), + (self.t2f, "font_family"), + (self.t2s, "font_style"), + (self.t2w, "font_weight") + ) + for selector, val in t2x_dict.items() + for span in self.find_spans_by_selector(selector) + ], + [ + (span, local_config) + for selector, local_config in self.local_configs.items() + for span in self.find_spans_by_selector(selector) + ], + [ + (span, {}) + for span in self.find_spans_by_selector(self.isolate) + ] + ) + entity_spans = self.tag_spans.copy() + if self.is_markup: + entity_spans.extend(self.find_spans(r"&[\s\S]*?;")) return [ - (span, {key: val}) - for t2x_dict, key in ( - (self.t2c, "foreground"), - (self.t2f, "font_family"), - (self.t2s, "font_style"), - (self.t2w, "font_weight") - ) - for substr_or_span, val in t2x_dict.items() - for span in self.find_substr_or_span(substr_or_span) - ] + [ - (span, local_config) - for substr_or_span, local_config in self.local_configs.items() - for span in self.find_substr_or_span(substr_or_span) + (span, attr_dict) + for span, attr_dict in result + if not any([ + entity_begin < index < entity_end + for index in span + for entity_begin, entity_end in entity_spans + ]) ] - def get_predefined_attr_dicts(self) -> list[Span, dict[str, str]]: - attr_dict_items = [ - (self.full_span, self.global_dict_from_config), - *self.local_dicts_from_markup, - *self.local_dicts_from_config - ] - return [ - (span, { - SPAN_ATTR_KEY_CONVERSION[key.lower()]: str(val) - for key, val in attr_dict.items() - }) - for span, attr_dict in attr_dict_items - ] - - # Parsing - def get_command_repl_items(self) -> list[tuple[Span, str]]: result = [ - (tag_span, "") - for begin_tag, end_tag, _ in self.tag_items_from_markup - for tag_span in (begin_tag, end_tag) + (tag_span, "") for tag_span in self.tag_spans ] if not self.is_markup: - result += [ + result.extend([ (span, escaped) for char, escaped in ( ("&", "&"), (">", ">"), ("<", "<") ) - for span in self.find_substr(char) - ] + for span in self.find_spans(re.escape(char)) + ]) return result - def get_extra_entity_spans(self) -> list[Span]: - if not self.is_markup: - return [] - return self.find_spans(r"&.*?;") - - def get_extra_ignored_spans(self) -> list[int]: - return [] - - def get_internal_specified_spans(self) -> list[Span]: - return [span for span, _ in self.local_dicts_from_markup] - - def get_external_specified_spans(self) -> list[Span]: - return [span for span, _ in self.local_dicts_from_config] + def get_specified_spans(self) -> list[Span]: + return self.remove_redundancies([ + span for span, _ in self.specified_items + ]) def get_label_span_list(self) -> list[Span]: - breakup_indices = remove_list_redundancies(list(it.chain(*it.chain( - self.find_spans(r"\s+"), - self.find_spans(r"\b"), - self.specified_spans - )))) - breakup_indices = sorted(filter( - lambda index: not any([ - span[0] < index < span[1] - for span in self.entity_spans - ]), - breakup_indices - )) - return list(filter( - lambda span: self.get_substr(span).strip(), - self.get_neighbouring_pairs(breakup_indices) - )) - - def get_content(self, use_plain_file: bool) -> str: - if use_plain_file: - attr_dict_items = [ - (self.full_span, {"foreground": self.base_color}), - *self.predefined_attr_dicts, - *[ - (span, {}) - for span in self.label_span_list - ] + interval_spans = sorted(self.chain( + self.tag_spans, + [ + (index, index) + for span in self.specified_spans + for index in span ] + )) + text_spans = self.get_complement_spans(interval_spans, self.full_span) + if self.is_markup: + pattern = r"[0-9a-zA-Z]+|(?:&[\s\S]*?;|[^0-9a-zA-Z\s])+" else: - attr_dict_items = [ - (self.full_span, {"foreground": BLACK}), - *[ + pattern = r"[0-9a-zA-Z]+|[^0-9a-zA-Z\s]+" + return self.chain(*[ + self.find_spans(pattern, pos=span_begin, endpos=span_end) + for span_begin, span_end in text_spans + ]) + + def get_content(self, is_labelled: bool) -> str: + predefined_items = [ + (self.full_span, self.global_attr_dict), + (self.full_span, self.global_config), + *self.specified_items + ] + if is_labelled: + attr_dict_items = self.chain( + [ (span, { - key: BLACK if key in COLOR_RELATED_KEYS else val + key: + "black" if key.lower() in MARKUP_COLOR_KEYS else val for key, val in attr_dict.items() }) - for span, attr_dict in self.predefined_attr_dicts + for span, attr_dict in predefined_items ], - *[ + [ (span, {"foreground": self.int_to_hex(label + 1)}) for label, span in enumerate(self.label_span_list) ] - ] + ) + else: + attr_dict_items = self.chain( + predefined_items, + [ + (span, {}) + for span in self.label_span_list + ] + ) inserted_string_pairs = [ (span, ( f"", "" )) - for span, attr_dict in self.merge_attr_dicts(attr_dict_items) + for span, attr_dict in attr_dict_items if attr_dict ] - span_repl_dict = self.generate_span_repl_dict( + return self.get_replaced_string( inserted_string_pairs, self.command_repl_items ) - return self.get_replaced_substr(self.full_span, span_repl_dict) - @property - def has_predefined_local_colors(self) -> bool: - return any([ - key in COLOR_RELATED_KEYS - for _, attr_dict in self.predefined_attr_dicts - for key in attr_dict.keys() - ]) + # Selector + + def get_cleaned_substr(self, span: Span) -> str: + repl_items = list(filter( + lambda repl_item: self.span_contains(span, repl_item[0]), + self.command_repl_items + )) + return self.get_replaced_substr(span, repl_items).strip() # Method alias - def get_parts_by_text(self, text: str, **kwargs) -> VGroup: - return self.get_parts_by_string(text, **kwargs) + def get_parts_by_text(self, selector: Selector) -> VGroup: + return self.select_parts(selector) - def get_part_by_text(self, text: str, **kwargs) -> VMobject: - return self.get_part_by_string(text, **kwargs) + def get_part_by_text(self, selector: Selector) -> VGroup: + return self.select_part(selector) - def set_color_by_text(self, text: str, color: ManimColor, **kwargs): - return self.set_color_by_string(text, color, **kwargs) + def set_color_by_text(self, selector: Selector, color: ManimColor): + return self.set_parts_color(selector, color) def set_color_by_text_to_color_map( - self, text_to_color_map: dict[str, ManimColor], **kwargs + self, color_map: dict[Selector, ManimColor] ): - return self.set_color_by_string_to_color_map( - text_to_color_map, **kwargs - ) + return self.set_parts_color_by_dict(color_map) def get_text(self) -> str: return self.get_string() diff --git a/manimlib/utils/iterables.py b/manimlib/utils/iterables.py index 99788a42..fa54e68d 100644 --- a/manimlib/utils/iterables.py +++ b/manimlib/utils/iterables.py @@ -1,14 +1,19 @@ from __future__ import annotations -from typing import Callable, Iterable, Sequence, TypeVar +from colour import Color import numpy as np -T = TypeVar("T") -S = TypeVar("S") +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Callable, Iterable, Sequence, TypeVar + + T = TypeVar("T") + S = TypeVar("S") -def remove_list_redundancies(l: Iterable[T]) -> list[T]: +def remove_list_redundancies(l: Sequence[T]) -> list[T]: """ Used instead of list(set(l)) to maintain order Keeps the last occurrence of each element @@ -35,14 +40,14 @@ def list_difference_update(l1: Iterable[T], l2: Iterable[T]) -> list[T]: return [e for e in l1 if e not in l2] -def adjacent_n_tuples(objects: Iterable[T], n: int) -> zip[tuple[T, T]]: +def adjacent_n_tuples(objects: Sequence[T], n: int) -> zip[tuple[T, T]]: return zip(*[ [*objects[k:], *objects[:k]] for k in range(n) ]) -def adjacent_pairs(objects: Iterable[T]) -> zip[tuple[T, T]]: +def adjacent_pairs(objects: Sequence[T]) -> zip[tuple[T, T]]: return adjacent_n_tuples(objects, 2) @@ -76,7 +81,7 @@ def batch_by_property( return batch_prop_pairs -def listify(obj) -> list: +def listify(obj: object) -> list: if isinstance(obj, str): return [obj] try: @@ -130,10 +135,17 @@ def make_even( def hash_obj(obj: object) -> int: if isinstance(obj, dict): - new_obj = {k: hash_obj(v) for k, v in obj.items()} - return hash(tuple(frozenset(sorted(new_obj.items())))) + return hash(tuple(sorted([ + (hash_obj(k), hash_obj(v)) for k, v in obj.items() + ]))) - if isinstance(obj, (set, tuple, list)): + if isinstance(obj, set): + return hash(tuple(sorted(hash_obj(e) for e in obj))) + + if isinstance(obj, (tuple, list)): return hash(tuple(hash_obj(e) for e in obj)) + if isinstance(obj, Color): + return hash(obj.get_rgb()) + return hash(obj) From 30e33b1baaf7e677d59702d65a04c280603909a5 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Sun, 24 Apr 2022 08:24:27 +0800 Subject: [PATCH 02/11] Merge branch 'refactor' into master --- manimlib/animation/creation.py | 14 +- .../animation/transform_matching_parts.py | 116 ++-- manimlib/mobject/svg/labelled_string.py | 590 +++++++----------- manimlib/mobject/svg/mtex_mobject.py | 375 +++++------ manimlib/mobject/svg/text_mobject.py | 480 ++++++-------- manimlib/utils/iterables.py | 6 +- 6 files changed, 706 insertions(+), 875 deletions(-) diff --git a/manimlib/animation/creation.py b/manimlib/animation/creation.py index 27460899..6ad6a9bd 100644 --- a/manimlib/animation/creation.py +++ b/manimlib/animation/creation.py @@ -1,12 +1,12 @@ from __future__ import annotations -import itertools as it -from abc import abstractmethod +from abc import ABC, abstractmethod import numpy as np from manimlib.animation.animation import Animation from manimlib.mobject.svg.labelled_string import LabelledString +from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.utils.bezier import integer_interpolate from manimlib.utils.config_ops import digest_config @@ -17,10 +17,10 @@ from manimlib.utils.rate_functions import smooth from typing import TYPE_CHECKING if TYPE_CHECKING: - from manimlib.mobject.mobject import Group + from manimlib.mobject.mobject import Mobject -class ShowPartial(Animation): +class ShowPartial(Animation, ABC): """ Abstract class for ShowCreation and ShowPassingFlash """ @@ -176,7 +176,7 @@ class ShowIncreasingSubsets(Animation): "int_func": np.round, } - def __init__(self, group: Group, **kwargs): + def __init__(self, group: Mobject, **kwargs): self.all_submobs = list(group.submobjects) super().__init__(group, **kwargs) @@ -213,7 +213,9 @@ class AddTextWordByWord(ShowIncreasingSubsets): def __init__(self, string_mobject, **kwargs): assert isinstance(string_mobject, LabelledString) - grouped_mobject = string_mobject.submob_groups + grouped_mobject = VGroup(*[ + part for _, part in string_mobject.get_group_part_items() + ]) digest_config(self, kwargs) if self.run_time is None: self.run_time = self.time_per_word * len(grouped_mobject) diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index dab88005..e84f1d9d 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -5,9 +5,9 @@ import itertools as it import numpy as np from manimlib.animation.composition import AnimationGroup -from manimlib.animation.fading import FadeTransformPieces from manimlib.animation.fading import FadeInFromPoint from manimlib.animation.fading import FadeOutToPoint +from manimlib.animation.fading import FadeTransformPieces from manimlib.animation.transform import ReplacementTransform from manimlib.animation.transform import Transform from manimlib.mobject.mobject import Mobject @@ -16,13 +16,13 @@ from manimlib.mobject.svg.labelled_string import LabelledString from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.utils.config_ops import digest_config -from manimlib.utils.iterables import remove_list_redundancies from typing import TYPE_CHECKING if TYPE_CHECKING: + from manimlib.mobject.svg.tex_mobject import SingleStringTex + from manimlib.mobject.svg.tex_mobject import Tex from manimlib.scene.scene import Scene - from manimlib.mobject.svg.tex_mobject import Tex, SingleStringTex class TransformMatchingParts(AnimationGroup): @@ -168,36 +168,36 @@ class TransformMatchingStrings(AnimationGroup): assert isinstance(source, LabelledString) assert isinstance(target, LabelledString) anims = [] - source_indices = list(range(len(source.labelled_submobjects))) - target_indices = list(range(len(target.labelled_submobjects))) - def get_indices_lists(mobject, parts): - return [ + source_submobs = [ + submob for _, submob in source.labelled_submobject_items + ] + target_submobs = [ + submob for _, submob in target.labelled_submobject_items + ] + source_indices = list(range(len(source_submobs))) + target_indices = list(range(len(target_submobs))) + + def get_filtered_indices_lists(parts, submobs, rest_indices): + return list(filter( + lambda indices_list: all([ + index in rest_indices + for index in indices_list + ]), [ - mobject.labelled_submobjects.index(submob) - for submob in part + [submobs.index(submob) for submob in part] + for part in parts ] - for part in parts - ] + )) - def add_anims_from(anim_class, func, source_args, target_args=None): - if target_args is None: - target_args = source_args.copy() - for source_arg, target_arg in zip(source_args, target_args): - source_parts = func(source, source_arg) - target_parts = func(target, target_arg) - source_indices_lists = list(filter( - lambda indices_list: all([ - index in source_indices - for index in indices_list - ]), get_indices_lists(source, source_parts) - )) - target_indices_lists = list(filter( - lambda indices_list: all([ - index in target_indices - for index in indices_list - ]), get_indices_lists(target, target_parts) - )) + def add_anims(anim_class, parts_pairs): + for source_parts, target_parts in parts_pairs: + source_indices_lists = get_filtered_indices_lists( + source_parts, source_submobs, source_indices + ) + target_indices_lists = get_filtered_indices_lists( + target_parts, target_submobs, target_indices + ) if not source_indices_lists or not target_indices_lists: continue anims.append(anim_class(source_parts, target_parts, **kwargs)) @@ -206,41 +206,45 @@ class TransformMatchingStrings(AnimationGroup): for index in it.chain(*target_indices_lists): target_indices.remove(index) - def get_common_substrs(substrs_from_source, substrs_from_target): - return sorted([ - substr for substr in substrs_from_source - if substr and substr in substrs_from_target - ], key=len, reverse=True) - - def get_parts_from_keys(mobject, keys): - if isinstance(keys, str): - keys = [keys] - result = VGroup() - for key in keys: - if not isinstance(key, str): - raise TypeError(key) - result.add(*mobject.get_parts_by_string(key)) + def get_substr_to_parts_map(part_items): + result = {} + for substr, part in part_items: + if substr not in result: + result[substr] = [] + result[substr].append(part) return result - add_anims_from( - ReplacementTransform, get_parts_from_keys, - self.key_map.keys(), self.key_map.values() + def add_anims_from(anim_class, func): + source_substr_to_parts_map = get_substr_to_parts_map(func(source)) + target_substr_to_parts_map = get_substr_to_parts_map(func(target)) + add_anims( + anim_class, + [ + ( + VGroup(*source_substr_to_parts_map[substr]), + VGroup(*target_substr_to_parts_map[substr]) + ) + for substr in sorted([ + s for s in source_substr_to_parts_map.keys() + if s and s in target_substr_to_parts_map.keys() + ], key=len, reverse=True) + ] + ) + + add_anims( + ReplacementTransform, + [ + (source.select_parts(k), target.select_parts(v)) + for k, v in self.key_map.items() + ] ) add_anims_from( FadeTransformPieces, - LabelledString.get_parts_by_string, - get_common_substrs( - source.specified_substrs, - target.specified_substrs - ) + LabelledString.get_specified_part_items ) add_anims_from( FadeTransformPieces, - LabelledString.get_parts_by_group_substr, - get_common_substrs( - source.group_substrs, - target.group_substrs - ) + LabelledString.get_group_part_items ) rest_source = VGroup(*[source[index] for index in source_indices]) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index f1354f0c..1c3f0afd 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -1,30 +1,43 @@ from __future__ import annotations -import re -import colour -import itertools as it -from typing import Iterable, Union, Sequence from abc import ABC, abstractmethod +import itertools as it +import numpy as np +import re -from manimlib.constants import BLACK, WHITE +from manimlib.constants import WHITE from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.types.vectorized_mobject import VGroup -from manimlib.utils.color import color_to_int_rgb from manimlib.utils.color import color_to_rgb from manimlib.utils.color import rgb_to_hex from manimlib.utils.config_ops import digest_config from manimlib.utils.iterables import remove_list_redundancies - from typing import TYPE_CHECKING if TYPE_CHECKING: - from manimlib.mobject.types.vectorized_mobject import VMobject - ManimColor = Union[str, colour.Color, Sequence[float]] + from colour import Color + from typing import Iterable, Sequence, TypeVar, Union + + ManimColor = Union[str, Color] Span = tuple[int, int] + Selector = Union[ + str, + re.Pattern, + tuple[Union[int, None], Union[int, None]], + Iterable[Union[ + str, + re.Pattern, + tuple[Union[int, None], Union[int, None]] + ]] + ] + T = TypeVar("T") -class _StringSVG(SVGMobject): +class LabelledString(SVGMobject, ABC): + """ + An abstract base class for `MTex` and `MarkupText` + """ CONFIG = { "height": None, "stroke_width": 0, @@ -33,42 +46,30 @@ class _StringSVG(SVGMobject): "should_subdivide_sharp_curves": True, "should_remove_null_curves": True, }, - } - - -class LabelledString(_StringSVG, ABC): - """ - An abstract base class for `MTex` and `MarkupText` - """ - CONFIG = { "base_color": WHITE, - "use_plain_file": False, "isolate": [], } def __init__(self, string: str, **kwargs): self.string = string digest_config(self, kwargs) + if self.base_color is None: + self.base_color = WHITE + self.base_color_int = self.color_to_int(self.base_color) - # Convert `base_color` to hex code. - self.base_color = rgb_to_hex(color_to_rgb( - self.base_color \ - or self.svg_default.get("color", None) \ - or self.svg_default.get("fill_color", None) \ - or WHITE - )) - self.svg_default["fill_color"] = BLACK - - self.pre_parse() + self.full_span = (0, len(self.string)) self.parse() - super().__init__() - self.post_parse() + super().__init__(**kwargs) + self.labelled_submobject_items = [ + (submob.label, submob) + for submob in self.submobjects + ] def get_file_path(self) -> str: - return self.get_file_path_(use_plain_file=False) + return self.get_file_path_(is_labelled=False) - def get_file_path_(self, use_plain_file: bool) -> str: - content = self.get_content(use_plain_file) + def get_file_path_(self, is_labelled: bool) -> str: + content = self.get_content(is_labelled) return self.get_file_path_by_content(content) @abstractmethod @@ -78,91 +79,135 @@ class LabelledString(_StringSVG, ABC): def generate_mobject(self) -> None: super().generate_mobject() - submob_labels = [ - self.color_to_label(submob.get_fill_color()) - for submob in self.submobjects - ] - if self.use_plain_file or self.has_predefined_local_colors: - file_path = self.get_file_path_(use_plain_file=True) - plain_svg = _StringSVG( - file_path, - svg_default=self.svg_default, - path_string_config=self.path_string_config - ) - self.set_submobjects(plain_svg.submobjects) + num_labels = len(self.label_span_list) + if num_labels: + file_path = self.get_file_path_(is_labelled=True) + labelled_svg = SVGMobject(file_path) + submob_color_ints = [ + self.color_to_int(submob.get_fill_color()) + for submob in labelled_svg.submobjects + ] else: - self.set_fill(self.base_color) - for submob, label in zip(self.submobjects, submob_labels): - submob.label = label + submob_color_ints = [0] * len(self.submobjects) - def pre_parse(self) -> None: - self.string_len = len(self.string) - self.full_span = (0, self.string_len) + if len(self.submobjects) != len(submob_color_ints): + raise ValueError( + "Cannot align submobjects of the labelled svg " + "to the original svg" + ) + + unrecognized_color_ints = self.remove_redundancies(sorted(filter( + lambda color_int: color_int > num_labels, + submob_color_ints + ))) + if unrecognized_color_ints: + raise ValueError( + "Unrecognized color label(s) detected: " + f"{','.join(map(self.int_to_hex, unrecognized_color_ints))}" + ) + + for submob, color_int in zip(self.submobjects, submob_color_ints): + submob.label = color_int - 1 def parse(self) -> None: self.command_repl_items = self.get_command_repl_items() - self.command_spans = self.get_command_spans() - self.extra_entity_spans = self.get_extra_entity_spans() - self.entity_spans = self.get_entity_spans() - self.extra_ignored_spans = self.get_extra_ignored_spans() - self.skipped_spans = self.get_skipped_spans() - self.internal_specified_spans = self.get_internal_specified_spans() - self.external_specified_spans = self.get_external_specified_spans() self.specified_spans = self.get_specified_spans() - self.label_span_list = self.get_label_span_list() self.check_overlapping() - - def post_parse(self) -> None: - self.labelled_submobject_items = [ - (submob.label, submob) - for submob in self.submobjects - ] - self.labelled_submobjects = self.get_labelled_submobjects() - self.specified_substrs = self.get_specified_substrs() - self.group_items = self.get_group_items() - self.group_substrs = self.get_group_substrs() - self.submob_groups = self.get_submob_groups() + self.label_span_list = self.get_label_span_list() + if len(self.label_span_list) >= 16777216: + raise ValueError("Cannot handle that many substrings") # Toolkits def get_substr(self, span: Span) -> str: return self.string[slice(*span)] - def finditer( - self, pattern: str, flags: int = 0, **kwargs - ) -> Iterable[re.Match]: - return re.compile(pattern, flags).finditer(self.string, **kwargs) + def match(self, pattern: str | re.Pattern, **kwargs) -> re.Pattern | None: + if isinstance(pattern, str): + pattern = re.compile(pattern) + return re.compile(pattern).match(self.string, **kwargs) - def search( - self, pattern: str, flags: int = 0, **kwargs - ) -> re.Match | None: - return re.compile(pattern, flags).search(self.string, **kwargs) - - def match( - self, pattern: str, flags: int = 0, **kwargs - ) -> re.Match | None: - return re.compile(pattern, flags).match(self.string, **kwargs) - - def find_spans(self, pattern: str, **kwargs) -> list[Span]: + def find_spans(self, pattern: str | re.Pattern, **kwargs) -> list[Span]: + if isinstance(pattern, str): + pattern = re.compile(pattern) return [ match_obj.span() - for match_obj in self.finditer(pattern, **kwargs) + for match_obj in pattern.finditer(self.string, **kwargs) ] - def find_substr(self, substr: str, **kwargs) -> list[Span]: - if not substr: - return [] - return self.find_spans(re.escape(substr), **kwargs) - - def find_substrs(self, substrs: list[str], **kwargs) -> list[Span]: - return list(it.chain(*[ - self.find_substr(substr, **kwargs) - for substr in remove_list_redundancies(substrs) - ])) + def find_indices(self, pattern: str | re.Pattern, **kwargs) -> list[int]: + return [index for index, _ in self.find_spans(pattern, **kwargs)] @staticmethod - def get_neighbouring_pairs(iterable: list) -> list[tuple]: - return list(zip(iterable[:-1], iterable[1:])) + def is_single_selector(selector: Selector) -> bool: + if isinstance(selector, str): + return True + if isinstance(selector, re.Pattern): + return True + if isinstance(selector, tuple): + if len(selector) == 2 and all([ + isinstance(index, int) or index is None + for index in selector + ]): + return True + return False + + def find_spans_by_selector(self, selector: Selector) -> list[Span]: + if self.is_single_selector(selector): + selector = (selector,) + result = [] + for sel in selector: + if not self.is_single_selector(sel): + raise TypeError(f"Invalid selector: '{sel}'") + if isinstance(sel, str): + spans = self.find_spans(re.escape(sel)) + elif isinstance(sel, re.Pattern): + spans = self.find_spans(sel) + else: + string_len = self.full_span[1] + span = tuple([ + ( + min(index, string_len) + if index >= 0 + else max(index + string_len, 0) + ) + if index is not None else default_index + for index, default_index in zip(sel, self.full_span) + ]) + spans = [span] + result.extend(spans) + return sorted(filter( + lambda span: span[0] < span[1], + self.remove_redundancies(result) + )) + + @staticmethod + def chain(*iterables: Iterable[T]) -> list[T]: + return list(it.chain(*iterables)) + + @staticmethod + def remove_redundancies(vals: Sequence[T]) -> list[T]: + return remove_list_redundancies(vals) + + @staticmethod + def get_neighbouring_pairs(vals: Sequence[T]) -> list[tuple[T, T]]: + return list(zip(vals[:-1], vals[1:])) + + @staticmethod + def compress_neighbours(vals: Sequence[T]) -> list[tuple[T, Span]]: + if not vals: + return [] + + unique_vals = [vals[0]] + indices = [0] + for index, val in enumerate(vals): + if val == unique_vals[-1]: + continue + unique_vals.append(val) + indices.append(index) + indices.append(len(vals)) + spans = LabelledString.get_neighbouring_pairs(indices) + return list(zip(unique_vals, spans)) @staticmethod def span_contains(span_0: Span, span_1: Span) -> bool: @@ -182,194 +227,88 @@ class LabelledString(_StringSVG, ABC): )) @staticmethod - def compress_neighbours(vals: list[int]) -> list[tuple[int, Span]]: - if not vals: + def merge_inserted_strings_from_pairs( + inserted_string_pairs: list[tuple[Span, tuple[str, str]]] + ) -> list[tuple[int, str]]: + if not inserted_string_pairs: return [] - unique_vals = [vals[0]] - indices = [0] - for index, val in enumerate(vals): - if val == unique_vals[-1]: - continue - unique_vals.append(val) - indices.append(index) - indices.append(len(vals)) - spans = LabelledString.get_neighbouring_pairs(indices) - return list(zip(unique_vals, spans)) - - @staticmethod - def find_region_index(seq: list[int], val: int) -> int: - # Returns an integer in `range(-1, len(seq))` satisfying - # `seq[result] <= val < seq[result + 1]`. - # `seq` should be sorted in ascending order. - if not seq or val < seq[0]: - return -1 - result = len(seq) - 1 - while val < seq[result]: - result -= 1 - return result - - @staticmethod - def take_nearest_value(seq: list[int], val: int, index_shift: int) -> int: - sorted_seq = sorted(seq) - index = LabelledString.find_region_index(sorted_seq, val) - return sorted_seq[index + index_shift] - - @staticmethod - def generate_span_repl_dict( - inserted_string_pairs: list[tuple[Span, tuple[str, str]]], - other_repl_items: list[tuple[Span, str]] - ) -> dict[Span, str]: - result = dict(other_repl_items) - if not inserted_string_pairs: - return result - - indices, _, _, inserted_strings = zip(*sorted([ - ( - span[flag], - -flag, - -span[1 - flag], - str_pair[flag] + spans = [ + span for span, _ in inserted_string_pairs + ] + sorted_index_flag_pairs = sorted( + it.product(range(len(spans)), range(2)), + key=lambda t: ( + spans[t[0]][t[1]], + np.sign(spans[t[0]][1 - t[1]] - spans[t[0]][t[1]]), + -spans[t[0]][1 - t[1]], + t[1], + (1, -1)[t[1]] * t[0] ) - for span, str_pair in inserted_string_pairs - for flag in range(2) - ])) - result.update({ - (index, index): "".join(inserted_strings[slice(*item_span)]) + ) + indices, inserted_strings = zip(*[ + list(zip(*inserted_string_pairs[item_index]))[flag] + for item_index, flag in sorted_index_flag_pairs + ]) + return [ + (index, "".join(inserted_strings[slice(*item_span)])) for index, item_span in LabelledString.compress_neighbours(indices) - }) - return result + ] def get_replaced_substr( - self, span: Span, span_repl_dict: dict[Span, str] + self, span: Span, repl_items: list[tuple[Span, str]] ) -> str: - repl_spans = sorted(filter( - lambda repl_span: self.span_contains(span, repl_span), - span_repl_dict.keys() - )) - if not all( - span_0[1] <= span_1[0] - for span_0, span_1 in self.get_neighbouring_pairs(repl_spans) - ): - raise ValueError("Overlapping replacement") + if not repl_items: + return self.get_substr(span) + sorted_repl_items = sorted(repl_items, key=lambda t: t[0]) + repl_spans, repl_strs = zip(*sorted_repl_items) pieces = [ self.get_substr(piece_span) for piece_span in self.get_complement_spans(repl_spans, span) ] - repl_strs = [span_repl_dict[repl_span] for repl_span in repl_spans] - repl_strs.append("") - return "".join(it.chain(*zip(pieces, repl_strs))) + repl_strs = [*repl_strs, ""] + return "".join(self.chain(*zip(pieces, repl_strs))) + + def get_replaced_string( + self, + inserted_string_pairs: list[tuple[Span, tuple[str, str]]], + repl_items: list[tuple[Span, str]] + ) -> str: + all_repl_items = self.chain( + repl_items, + [ + ((index, index), inserted_string) + for index, inserted_string + in self.merge_inserted_strings_from_pairs( + inserted_string_pairs + ) + ] + ) + return self.get_replaced_substr(self.full_span, all_repl_items) @staticmethod - def rslide(index: int, skipped: list[Span]) -> int: - transfer_dict = dict(sorted(skipped)) - while index in transfer_dict.keys(): - index = transfer_dict[index] - return index - - @staticmethod - def lslide(index: int, skipped: list[Span]) -> int: - transfer_dict = dict(sorted([ - skipped_span[::-1] for skipped_span in skipped - ], reverse=True)) - while index in transfer_dict.keys(): - index = transfer_dict[index] - return index - - @staticmethod - def rgb_to_int(rgb_tuple: tuple[int, int, int]) -> int: - r, g, b = rgb_tuple - rg = r * 256 + g - return rg * 256 + b - - @staticmethod - def int_to_rgb(rgb_int: int) -> tuple[int, int, int]: - rg, b = divmod(rgb_int, 256) - r, g = divmod(rg, 256) - return r, g, b + def color_to_int(color: ManimColor) -> int: + hex_code = rgb_to_hex(color_to_rgb(color)) + return int(hex_code[1:], 16) @staticmethod def int_to_hex(rgb_int: int) -> str: return "#{:06x}".format(rgb_int).upper() - @staticmethod - def hex_to_int(rgb_hex: str) -> int: - return int(rgb_hex[1:], 16) - - @staticmethod - def color_to_label(color: ManimColor) -> int: - rgb_tuple = color_to_int_rgb(color) - rgb = LabelledString.rgb_to_int(rgb_tuple) - return rgb - 1 - # Parsing @abstractmethod def get_command_repl_items(self) -> list[tuple[Span, str]]: return [] - def get_command_spans(self) -> list[Span]: - return [cmd_span for cmd_span, _ in self.command_repl_items] - @abstractmethod - def get_extra_entity_spans(self) -> list[Span]: - return [] - - def get_entity_spans(self) -> list[Span]: - return list(it.chain( - self.command_spans, - self.extra_entity_spans - )) - - @abstractmethod - def get_extra_ignored_spans(self) -> list[int]: - return [] - - def get_skipped_spans(self) -> list[Span]: - return list(it.chain( - self.find_spans(r"\s"), - self.command_spans, - self.extra_ignored_spans - )) - - def shrink_span(self, span: Span) -> Span: - return ( - self.rslide(span[0], self.skipped_spans), - self.lslide(span[1], self.skipped_spans) - ) - - @abstractmethod - def get_internal_specified_spans(self) -> list[Span]: - return [] - - @abstractmethod - def get_external_specified_spans(self) -> list[Span]: - return [] - def get_specified_spans(self) -> list[Span]: - spans = list(it.chain( - self.internal_specified_spans, - self.external_specified_spans, - self.find_substrs(self.isolate) - )) - shrinked_spans = list(filter( - lambda span: span[0] < span[1] and not any([ - entity_span[0] < index < entity_span[1] - for index in span - for entity_span in self.entity_spans - ]), - [self.shrink_span(span) for span in spans] - )) - return remove_list_redundancies(shrinked_spans) - - @abstractmethod - def get_label_span_list(self) -> list[Span]: return [] def check_overlapping(self) -> None: - for span_0, span_1 in it.product(self.label_span_list, repeat=2): + for span_0, span_1 in it.product(self.specified_spans, repeat=2): if not span_0[0] < span_1[0] < span_0[1] < span_1[1]: continue raise ValueError( @@ -378,29 +317,20 @@ class LabelledString(_StringSVG, ABC): ) @abstractmethod - def get_content(self, use_plain_file: bool) -> str: - return "" + def get_label_span_list(self) -> list[Span]: + return [] @abstractmethod - def has_predefined_local_colors(self) -> bool: - return False + def get_content(self, is_labelled: bool) -> str: + return "" - # Post-parsing - - def get_labelled_submobjects(self) -> list[VMobject]: - return [submob for _, submob in self.labelled_submobject_items] + # Selector + @abstractmethod def get_cleaned_substr(self, span: Span) -> str: - span_repl_dict = dict.fromkeys(self.command_spans, "") - return self.get_replaced_substr(span, span_repl_dict) + return "" - def get_specified_substrs(self) -> list[str]: - return remove_list_redundancies([ - self.get_cleaned_substr(span) - for span in self.specified_spans - ]) - - def get_group_items(self) -> list[tuple[str, VGroup]]: + def get_group_part_items(self) -> list[tuple[str, VGroup]]: if not self.labelled_submobject_items: return [] @@ -425,118 +355,56 @@ class LabelledString(_StringSVG, ABC): ordered_spans ) ] - shrinked_spans = [ - self.shrink_span(span) + group_substrs = [ + self.get_cleaned_substr(span) if span[0] < span[1] else "" for span in self.get_complement_spans( interval_spans, (ordered_spans[0][0], ordered_spans[-1][1]) ) ] - group_substrs = [ - self.get_cleaned_substr(span) if span[0] < span[1] else "" - for span in shrinked_spans - ] submob_groups = VGroup(*[ VGroup(*labelled_submobjects[slice(*submob_span)]) for submob_span in labelled_submob_spans ]) return list(zip(group_substrs, submob_groups)) - def get_group_substrs(self) -> list[str]: - return [group_substr for group_substr, _ in self.group_items] - - def get_submob_groups(self) -> list[VGroup]: - return [submob_group for _, submob_group in self.group_items] - - def get_parts_by_group_substr(self, substr: str) -> VGroup: - return VGroup(*[ - group - for group_substr, group in self.group_items - if group_substr == substr - ]) - - # Selector - - def find_span_components( - self, custom_span: Span, substring: bool = True - ) -> list[Span]: - shrinked_span = self.shrink_span(custom_span) - if shrinked_span[0] >= shrinked_span[1]: - return [] - - if substring: - indices = remove_list_redundancies(list(it.chain( - self.full_span, - *self.label_span_list - ))) - span_begin = self.take_nearest_value( - indices, shrinked_span[0], 0 + def get_specified_part_items(self) -> list[tuple[str, VGroup]]: + return [ + ( + self.get_substr(span), + self.select_part_by_span(span) ) - span_end = self.take_nearest_value( - indices, shrinked_span[1] - 1, 1 - ) - else: - span_begin, span_end = shrinked_span + for span in self.specified_spans + ] - span_choices = sorted(filter( - lambda span: self.span_contains((span_begin, span_end), span), - self.label_span_list - )) - # Choose spans that reach the farthest. - span_choices_dict = dict(span_choices) - - result = [] - while span_begin < span_end: - if span_begin not in span_choices_dict.keys(): - span_begin += 1 - continue - next_begin = span_choices_dict[span_begin] - result.append((span_begin, next_begin)) - span_begin = next_begin - return result - - def get_part_by_custom_span(self, custom_span: Span, **kwargs) -> VGroup: + def select_part_by_span(self, custom_span: Span) -> VGroup: labels = [ label for label, span in enumerate(self.label_span_list) - if any([ - self.span_contains(span_component, span) - for span_component in self.find_span_components( - custom_span, **kwargs - ) - ]) + if self.span_contains(custom_span, span) ] return VGroup(*[ submob for label, submob in self.labelled_submobject_items if label in labels ]) - def get_parts_by_string( - self, substr: str, - case_sensitive: bool = True, regex: bool = False, **kwargs - ) -> VGroup: - flags = 0 - if not case_sensitive: - flags |= re.I - pattern = substr if regex else re.escape(substr) - return VGroup(*[ - self.get_part_by_custom_span(span, **kwargs) - for span in self.find_spans(pattern, flags=flags) - if span[0] < span[1] - ]) + def select_parts(self, selector: Selector) -> VGroup: + return VGroup(*filter( + lambda part: part.submobjects, + [ + self.select_part_by_span(span) + for span in self.find_spans_by_selector(selector) + ] + )) - def get_part_by_string( - self, substr: str, index: int = 0, **kwargs - ) -> VMobject: - return self.get_parts_by_string(substr, **kwargs)[index] + def select_part(self, selector: Selector, index: int = 0) -> VGroup: + return self.select_parts(selector)[index] - def set_color_by_string(self, substr: str, color: ManimColor, **kwargs): - self.get_parts_by_string(substr, **kwargs).set_color(color) + def set_parts_color(self, selector: Selector, color: ManimColor): + self.select_parts(selector).set_color(color) return self - def set_color_by_string_to_color_map( - self, string_to_color_map: dict[str, ManimColor], **kwargs - ): - for substr, color in string_to_color_map.items(): - self.set_color_by_string(substr, color, **kwargs) + def set_parts_color_by_dict(self, color_map: dict[Selector, ManimColor]): + for selector, color in color_map.items(): + self.set_parts_color(selector, color) return self def get_string(self) -> str: diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index fb7922e1..ed7273ee 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -1,27 +1,46 @@ from __future__ import annotations -import itertools as it -import colour -from typing import Union, Sequence +import re from manimlib.mobject.svg.labelled_string import LabelledString -from manimlib.utils.tex_file_writing import tex_to_svg_file -from manimlib.utils.tex_file_writing import get_tex_config from manimlib.utils.tex_file_writing import display_during_execution - +from manimlib.utils.tex_file_writing import get_tex_config +from manimlib.utils.tex_file_writing import tex_to_svg_file from typing import TYPE_CHECKING if TYPE_CHECKING: - from manimlib.mobject.types.vectorized_mobject import VMobject + from colour import Color + from typing import Iterable, Union + from manimlib.mobject.types.vectorized_mobject import VGroup - ManimColor = Union[str, colour.Color, Sequence[float]] + + ManimColor = Union[str, Color] Span = tuple[int, int] + Selector = Union[ + str, + re.Pattern, + tuple[Union[int, None], Union[int, None]], + Iterable[Union[ + str, + re.Pattern, + tuple[Union[int, None], Union[int, None]] + ]] + ] SCALE_FACTOR_PER_FONT_POINT = 0.001 +TEX_COLOR_COMMANDS_DICT = { + "\\color": (1, False), + "\\textcolor": (1, False), + "\\pagecolor": (1, True), + "\\colorbox": (1, True), + "\\fcolorbox": (2, True), +} + + class MTex(LabelledString): CONFIG = { "font_size": 48, @@ -32,7 +51,7 @@ class MTex(LabelledString): def __init__(self, tex_string: str, **kwargs): # Prevent from passing an empty string. - if not tex_string: + if not tex_string.strip(): tex_string = "\\\\" self.tex_string = tex_string super().__init__(tex_string, **kwargs) @@ -47,7 +66,6 @@ class MTex(LabelledString): self.svg_default, self.path_string_config, self.base_color, - self.use_plain_file, self.isolate, self.tex_string, self.alignment, @@ -61,85 +79,95 @@ class MTex(LabelledString): tex_config["text_to_replace"], content ) - with display_during_execution(f"Writing \"{self.tex_string}\""): + with display_during_execution(f"Writing \"{self.string}\""): file_path = tex_to_svg_file(full_tex) return file_path - def pre_parse(self) -> None: - super().pre_parse() + def parse(self) -> None: self.backslash_indices = self.get_backslash_indices() - self.brace_index_pairs = self.get_brace_index_pairs() - self.script_char_spans = self.get_script_char_spans() + self.command_spans = self.get_command_spans() + self.brace_spans = self.get_brace_spans() + self.script_char_indices = self.get_script_char_indices() self.script_content_spans = self.get_script_content_spans() self.script_spans = self.get_script_spans() + super().parse() # Toolkits @staticmethod def get_color_command_str(rgb_int: int) -> str: - rgb_tuple = MTex.int_to_rgb(rgb_int) - return "".join([ - "\\color[RGB]", - "{", - ",".join(map(str, rgb_tuple)), - "}" - ]) + rg, b = divmod(rgb_int, 256) + r, g = divmod(rg, 256) + return f"\\color[RGB]{{{r}, {g}, {b}}}" - # Pre-parsing + @staticmethod + def shrink_span(span: Span, skippable_indices: list[int]) -> Span: + span_begin, span_end = span + while span_begin in skippable_indices: + span_begin += 1 + while span_end - 1 in skippable_indices: + span_end -= 1 + return (span_begin, span_end) + + # Parsing def get_backslash_indices(self) -> list[int]: # The latter of `\\` doesn't count. - return list(it.chain(*[ - range(span[0], span[1], 2) - for span in self.find_spans(r"\\+") - ])) + return self.find_indices(r"\\.") - def get_unescaped_char_spans(self, chars: str): - return sorted(filter( - lambda span: span[0] - 1 not in self.backslash_indices, - self.find_substrs(list(chars)) + def get_command_spans(self) -> list[Span]: + return [ + self.match(r"\\(?:[a-zA-Z]+|.)", pos=index).span() + for index in self.backslash_indices + ] + + def get_unescaped_char_indices(self, char: str) -> list[int]: + return list(filter( + lambda index: index - 1 not in self.backslash_indices, + self.find_indices(re.escape(char)) )) - def get_brace_index_pairs(self) -> list[Span]: - left_brace_indices = [] - right_brace_indices = [] - left_brace_indices_stack = [] - for span in self.get_unescaped_char_spans("{}"): - index = span[0] - if self.get_substr(span) == "{": - left_brace_indices_stack.append(index) + def get_brace_spans(self) -> list[Span]: + span_begins = [] + span_ends = [] + span_begins_stack = [] + char_items = sorted([ + (index, char) + for char in "{}" + for index in self.get_unescaped_char_indices(char) + ]) + for index, char in char_items: + if char == "{": + span_begins_stack.append(index) else: - if not left_brace_indices_stack: + if not span_begins_stack: raise ValueError("Missing '{' inserted") - left_brace_index = left_brace_indices_stack.pop() - left_brace_indices.append(left_brace_index) - right_brace_indices.append(index) - if left_brace_indices_stack: + span_begins.append(span_begins_stack.pop()) + span_ends.append(index + 1) + if span_begins_stack: raise ValueError("Missing '}' inserted") - return list(zip(left_brace_indices, right_brace_indices)) + return list(zip(span_begins, span_ends)) - def get_script_char_spans(self) -> list[int]: - return self.get_unescaped_char_spans("_^") + def get_script_char_indices(self) -> list[int]: + return self.chain(*[ + self.get_unescaped_char_indices(char) + for char in "_^" + ]) def get_script_content_spans(self) -> list[Span]: result = [] - brace_indices_dict = dict(self.brace_index_pairs) - script_pattern = r"[a-zA-Z0-9]|\\[a-zA-Z]+" - for script_char_span in self.script_char_spans: - span_begin = self.match(r"\s*", pos=script_char_span[1]).end() - if span_begin in brace_indices_dict.keys(): - span_end = brace_indices_dict[span_begin] + 1 + script_entity_dict = dict(self.chain( + self.brace_spans, + self.command_spans + )) + for index in self.script_char_indices: + span_begin = self.match(r"\s*", pos=index + 1).end() + if span_begin in script_entity_dict.keys(): + span_end = script_entity_dict[span_begin] else: - match_obj = self.match(script_pattern, pos=span_begin) - if not match_obj: - script_name = { - "_": "subscript", - "^": "superscript" - }[script_char] - raise ValueError( - f"Unclear {script_name} detected while parsing. " - "Please use braces to clarify" - ) + match_obj = self.match(r".", pos=span_begin) + if match_obj is None: + continue span_end = match_obj.end() result.append((span_begin, span_end)) return result @@ -147,110 +175,102 @@ class MTex(LabelledString): def get_script_spans(self) -> list[Span]: return [ ( - self.search(r"\s*$", endpos=script_char_span[0]).start(), + self.match(r"[\s\S]*?(\s*)$", endpos=index).start(1), script_content_span[1] ) - for script_char_span, script_content_span in zip( - self.script_char_spans, self.script_content_spans + for index, script_content_span in zip( + self.script_char_indices, self.script_content_spans ) ] - # Parsing - def get_command_repl_items(self) -> list[tuple[Span, str]]: - color_related_command_dict = { - "color": (1, False), - "textcolor": (1, False), - "pagecolor": (1, True), - "colorbox": (1, True), - "fcolorbox": (2, True), - } result = [] - backslash_indices = self.backslash_indices - right_brace_indices = [ - right_index - for left_index, right_index in self.brace_index_pairs - ] - pattern = "".join([ - r"\\", - "(", - "|".join(color_related_command_dict.keys()), - ")", - r"(?![a-zA-Z])" - ]) - for match_obj in self.finditer(pattern): - span_begin, cmd_end = match_obj.span() - if span_begin not in backslash_indices: + brace_spans_dict = dict(self.brace_spans) + brace_begins = list(brace_spans_dict.keys()) + for cmd_span in self.command_spans: + cmd_name = self.get_substr(cmd_span) + if cmd_name not in TEX_COLOR_COMMANDS_DICT.keys(): continue - cmd_name = match_obj.group(1) - n_braces, substitute_cmd = color_related_command_dict[cmd_name] - span_end = self.take_nearest_value( - right_brace_indices, cmd_end, n_braces - ) + 1 + n_braces, substitute_cmd = TEX_COLOR_COMMANDS_DICT[cmd_name] + span_begin, span_end = cmd_span + for _ in range(n_braces): + span_end = brace_spans_dict[min(filter( + lambda index: index >= span_end, + brace_begins + ))] if substitute_cmd: - repl_str = "\\" + cmd_name + n_braces * "{black}" + repl_str = cmd_name + n_braces * "{black}" else: repl_str = "" result.append(((span_begin, span_end), repl_str)) return result - def get_extra_entity_spans(self) -> list[Span]: - return [ - self.match(r"\\([a-zA-Z]+|.)", pos=index).span() - for index in self.backslash_indices + def get_specified_spans(self) -> list[Span]: + # Match paired double braces (`{{...}}`). + sorted_brace_spans = sorted(self.brace_spans, key=lambda t: t[1]) + inner_brace_spans = [ + sorted_brace_spans[span_span[0]] + for _, span_span in self.compress_neighbours([ + (brace_span[0] + index, brace_span[1] - index) + for index, brace_span in enumerate(sorted_brace_spans) + ]) + if span_span[1] - span_span[0] >= 2 + ] + inner_brace_content_spans = [ + (span[0] + 1, span[1] - 1) + for span in inner_brace_spans + if span[1] - span[0] > 2 ] - def get_extra_ignored_spans(self) -> list[int]: - return self.script_char_spans.copy() - - def get_internal_specified_spans(self) -> list[Span]: - # Match paired double braces (`{{...}}`). - result = [] - reversed_brace_indices_dict = dict([ - pair[::-1] for pair in self.brace_index_pairs - ]) - skip = False - for prev_right_index, right_index in self.get_neighbouring_pairs( - list(reversed_brace_indices_dict.keys()) - ): - if skip: - skip = False - continue - if right_index != prev_right_index + 1: - continue - left_index = reversed_brace_indices_dict[right_index] - prev_left_index = reversed_brace_indices_dict[prev_right_index] - if left_index != prev_left_index - 1: - continue - result.append((left_index, right_index + 1)) - skip = True - return result - - def get_external_specified_spans(self) -> list[Span]: - return self.find_substrs(list(self.tex_to_color_map.keys())) + result = self.remove_redundancies(self.chain( + inner_brace_content_spans, + *[ + self.find_spans_by_selector(selector) + for selector in self.tex_to_color_map.keys() + ], + self.find_spans_by_selector(self.isolate) + )) + return list(filter( + lambda span: not any([ + entity_begin < index < entity_end + for index in span + for entity_begin, entity_end in self.command_spans + ]), + result + )) def get_label_span_list(self) -> list[Span]: + reversed_script_spans_dict = dict([ + script_span[::-1] for script_span in self.script_spans + ]) + skippable_indices = self.chain( + self.find_indices(r"\s"), + self.script_char_indices + ) result = self.script_content_spans.copy() - for span_begin, span_end in self.specified_spans: - shrinked_end = self.lslide(span_end, self.script_spans) - if span_begin >= shrinked_end: + for span in self.specified_spans: + span_begin, span_end = self.shrink_span(span, skippable_indices) + while span_end in reversed_script_spans_dict.keys(): + span_end = reversed_script_spans_dict[span_end] + if span_begin >= span_end: continue - shrinked_span = (span_begin, shrinked_end) + shrinked_span = (span_begin, span_end) if shrinked_span in result: continue result.append(shrinked_span) return result - def get_content(self, use_plain_file: bool) -> str: - if use_plain_file: - span_repl_dict = {} - else: - extended_label_span_list = [ - span - if span in self.script_content_spans - else (span[0], self.rslide(span[1], self.script_spans)) - for span in self.label_span_list - ] + def get_content(self, is_labelled: bool) -> str: + if is_labelled: + extended_label_span_list = [] + script_spans_dict = dict(self.script_spans) + for span in self.label_span_list: + if span not in self.script_content_spans: + span_begin, span_end = span + while span_end in script_spans_dict.keys(): + span_end = script_spans_dict[span_end] + span = (span_begin, span_end) + extended_label_span_list.append(span) inserted_string_pairs = [ (span, ( "{{" + self.get_color_command_str(label + 1), @@ -258,43 +278,52 @@ class MTex(LabelledString): )) for label, span in enumerate(extended_label_span_list) ] - span_repl_dict = self.generate_span_repl_dict( - inserted_string_pairs, - self.command_repl_items + result = self.get_replaced_string( + inserted_string_pairs, self.command_repl_items ) - result = self.get_replaced_substr(self.full_span, span_repl_dict) + else: + result = self.string if self.tex_environment: - result = "\n".join([ - f"\\begin{{{self.tex_environment}}}", - result, - f"\\end{{{self.tex_environment}}}" - ]) + if isinstance(self.tex_environment, str): + prefix = f"\\begin{{{self.tex_environment}}}" + suffix = f"\\end{{{self.tex_environment}}}" + else: + prefix, suffix = self.tex_environment + result = "\n".join([prefix, result, suffix]) if self.alignment: result = "\n".join([self.alignment, result]) - if use_plain_file: + if not is_labelled: result = "\n".join([ - self.get_color_command_str(self.hex_to_int(self.base_color)), + self.get_color_command_str(self.base_color_int), result ]) return result - @property - def has_predefined_local_colors(self) -> bool: - return bool(self.command_repl_items) - - # Post-parsing + # Selector def get_cleaned_substr(self, span: Span) -> str: - substr = super().get_cleaned_substr(span) - if not self.brace_index_pairs: - return substr + if not self.brace_spans: + brace_begins, brace_ends = [], [] + else: + brace_begins, brace_ends = zip(*self.brace_spans) + left_brace_indices = list(brace_begins) + right_brace_indices = [index - 1 for index in brace_ends] + skippable_indices = self.chain( + self.find_indices(r"\s"), + self.script_char_indices, + left_brace_indices, + right_brace_indices + ) + shrinked_span = self.shrink_span(span, skippable_indices) + + if shrinked_span[0] >= shrinked_span[1]: + return "" # Balance braces. - left_brace_indices, right_brace_indices = zip(*self.brace_index_pairs) unclosed_left_braces = 0 unclosed_right_braces = 0 - for index in range(*span): + for index in range(*shrinked_span): if index in left_brace_indices: unclosed_left_braces += 1 elif index in right_brace_indices: @@ -304,27 +333,25 @@ class MTex(LabelledString): unclosed_left_braces -= 1 return "".join([ unclosed_right_braces * "{", - substr, + self.get_substr(shrinked_span), unclosed_left_braces * "}" ]) # Method alias - def get_parts_by_tex(self, tex: str, **kwargs) -> VGroup: - return self.get_parts_by_string(tex, **kwargs) + def get_parts_by_tex(self, selector: Selector) -> VGroup: + return self.select_parts(selector) - def get_part_by_tex(self, tex: str, **kwargs) -> VMobject: - return self.get_part_by_string(tex, **kwargs) + def get_part_by_tex(self, selector: Selector) -> VGroup: + return self.select_part(selector) - def set_color_by_tex(self, tex: str, color: ManimColor, **kwargs): - return self.set_color_by_string(tex, color, **kwargs) + def set_color_by_tex(self, selector: Selector, color: ManimColor): + return self.set_parts_color(selector, color) def set_color_by_tex_to_color_map( - self, tex_to_color_map: dict[str, ManimColor], **kwargs + self, color_map: dict[Selector, ManimColor] ): - return self.set_color_by_string_to_color_map( - tex_to_color_map, **kwargs - ) + return self.set_parts_color_by_dict(color_map) def get_tex(self) -> str: return self.get_string() diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index c3c3be19..740eeee6 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -1,93 +1,63 @@ from __future__ import annotations -import os -import re -import itertools as it -from pathlib import Path from contextlib import contextmanager -import typing -from typing import Iterable, Sequence, Union +import os +from pathlib import Path +import re +import manimpango import pygments import pygments.formatters import pygments.lexers -from manimpango import MarkupUtils - +from manimlib.constants import DEFAULT_PIXEL_WIDTH, FRAME_WIDTH +from manimlib.constants import NORMAL from manimlib.logger import log -from manimlib.constants import * from manimlib.mobject.svg.labelled_string import LabelledString -from manimlib.utils.customization import get_customization -from manimlib.utils.tex_file_writing import tex_hash from manimlib.utils.config_ops import digest_config +from manimlib.utils.customization import get_customization from manimlib.utils.directories import get_downloads_dir from manimlib.utils.directories import get_text_dir -from manimlib.utils.iterables import remove_list_redundancies - +from manimlib.utils.tex_file_writing import tex_hash from typing import TYPE_CHECKING if TYPE_CHECKING: - from manimlib.mobject.types.vectorized_mobject import VMobject + from colour import Color + from typing import Iterable, Union + from manimlib.mobject.types.vectorized_mobject import VGroup - ManimColor = Union[str, colour.Color, Sequence[float]] + + ManimColor = Union[str, Color] Span = tuple[int, int] + Selector = Union[ + str, + re.Pattern, + tuple[Union[int, None], Union[int, None]], + Iterable[Union[ + str, + re.Pattern, + tuple[Union[int, None], Union[int, None]] + ]] + ] TEXT_MOB_SCALE_FACTOR = 0.0076 DEFAULT_LINE_SPACING_SCALE = 0.6 +# Ensure the canvas is large enough to hold all glyphs. +DEFAULT_CANVAS_WIDTH = 16384 +DEFAULT_CANVAS_HEIGHT = 16384 # See https://docs.gtk.org/Pango/pango_markup.html -# A tag containing two aliases will cause warning, -# so only use the first key of each group of aliases. -SPAN_ATTR_KEY_ALIAS_LIST = ( - ("font", "font_desc"), - ("font_family", "face"), - ("font_size", "size"), - ("font_style", "style"), - ("font_weight", "weight"), - ("font_variant", "variant"), - ("font_stretch", "stretch"), - ("font_features",), - ("foreground", "fgcolor", "color"), - ("background", "bgcolor"), - ("alpha", "fgalpha"), - ("background_alpha", "bgalpha"), - ("underline",), - ("underline_color",), - ("overline",), - ("overline_color",), - ("rise",), - ("baseline_shift",), - ("font_scale",), - ("strikethrough",), - ("strikethrough_color",), - ("fallback",), - ("lang",), - ("letter_spacing",), - ("gravity",), - ("gravity_hint",), - ("show",), - ("insert_hyphens",), - ("allow_breaks",), - ("line_height",), - ("text_transform",), - ("segment",), +MARKUP_COLOR_KEYS = ( + "foreground", "fgcolor", "color", + "background", "bgcolor", + "underline_color", + "overline_color", + "strikethrough_color" ) -COLOR_RELATED_KEYS = ( - "foreground", - "background", - "underline_color", - "overline_color", - "strikethrough_color" -) -SPAN_ATTR_KEY_CONVERSION = { - key: key_alias_list[0] - for key_alias_list in SPAN_ATTR_KEY_ALIAS_LIST - for key in key_alias_list -} -TAG_TO_ATTR_DICT = { +MARKUP_TAG_CONVERSION_DICT = { "b": {"font_weight": "bold"}, "big": {"font_size": "larger"}, "i": {"font_style": "italic"}, @@ -96,7 +66,7 @@ TAG_TO_ATTR_DICT = { "sup": {"baseline_shift": "superscript", "font_scale": "superscript"}, "small": {"font_size": "smaller"}, "tt": {"font_family": "monospace"}, - "u": {"underline": "single"}, + "u": {"underline": "single"} } @@ -120,7 +90,7 @@ class MarkupText(LabelledString): "justify": False, "indent": 0, "alignment": "LEFT", - "line_width_factor": None, + "line_width": None, "font": "", "slant": NORMAL, "weight": NORMAL, @@ -141,9 +111,7 @@ class MarkupText(LabelledString): if not self.font: self.font = get_customization()["style"]["font"] if self.is_markup: - validate_error = MarkupUtils.validate(text) - if validate_error: - raise ValueError(validate_error) + self.validate_markup_string(text) self.text = text super().__init__(text, **kwargs) @@ -165,7 +133,6 @@ class MarkupText(LabelledString): self.svg_default, self.path_string_config, self.base_color, - self.use_plain_file, self.isolate, self.text, self.is_markup, @@ -174,7 +141,7 @@ class MarkupText(LabelledString): self.justify, self.indent, self.alignment, - self.line_width_factor, + self.line_width, self.font, self.slant, self.weight, @@ -201,23 +168,32 @@ class MarkupText(LabelledString): kwargs[short_name] = kwargs.pop(long_name) def get_file_path_by_content(self, content: str) -> str: + hash_content = str(( + content, + self.justify, + self.indent, + self.alignment, + self.line_width + )) svg_file = os.path.join( - get_text_dir(), tex_hash(content) + ".svg" + get_text_dir(), tex_hash(hash_content) + ".svg" ) if not os.path.exists(svg_file): self.markup_to_svg(content, svg_file) return svg_file def markup_to_svg(self, markup_str: str, file_name: str) -> str: + self.validate_markup_string(markup_str) + # `manimpango` is under construction, # so the following code is intended to suit its interface alignment = _Alignment(self.alignment) - if self.line_width_factor is None: + if self.line_width is None: pango_width = -1 else: - pango_width = self.line_width_factor * DEFAULT_PIXEL_WIDTH + pango_width = self.line_width / FRAME_WIDTH * DEFAULT_PIXEL_WIDTH - return MarkupUtils.text2svg( + return manimpango.MarkupUtils.text2svg( text=markup_str, font="", # Already handled slant="NORMAL", # Already handled @@ -228,8 +204,8 @@ class MarkupText(LabelledString): file_name=file_name, START_X=0, START_Y=0, - width=DEFAULT_PIXEL_WIDTH, - height=DEFAULT_PIXEL_HEIGHT, + width=DEFAULT_CANVAS_WIDTH, + height=DEFAULT_CANVAS_HEIGHT, justify=self.justify, indent=self.indent, line_spacing=None, # Already handled @@ -237,13 +213,23 @@ class MarkupText(LabelledString): pango_width=pango_width ) - def pre_parse(self) -> None: - super().pre_parse() - self.tag_items_from_markup = self.get_tag_items_from_markup() - self.global_dict_from_config = self.get_global_dict_from_config() - self.local_dicts_from_markup = self.get_local_dicts_from_markup() - self.local_dicts_from_config = self.get_local_dicts_from_config() - self.predefined_attr_dicts = self.get_predefined_attr_dicts() + @staticmethod + def validate_markup_string(markup_str: str) -> None: + validate_error = manimpango.MarkupUtils.validate(markup_str) + if not validate_error: + return + raise ValueError( + f"Invalid markup string \"{markup_str}\"\n" + f"{validate_error}" + ) + + def parse(self) -> None: + self.global_attr_dict = self.get_global_attr_dict() + self.tag_pairs_from_markup = self.get_tag_pairs_from_markup() + self.tag_spans = self.get_tag_spans() + self.items_from_markup = self.get_items_from_markup() + self.specified_items = self.get_specified_items() + super().parse() # Toolkits @@ -254,87 +240,50 @@ class MarkupText(LabelledString): for key, val in attr_dict.items() ]) - @staticmethod - def merge_attr_dicts( - attr_dict_items: list[Span, str, typing.Any] - ) -> list[tuple[Span, dict[str, str]]]: - index_seq = [0] - attr_dict_list = [{}] - for span, attr_dict in attr_dict_items: - if span[0] >= span[1]: - continue - region_indices = [ - MarkupText.find_region_index(index_seq, index) - for index in span - ] - for flag in (1, 0): - if index_seq[region_indices[flag]] == span[flag]: - continue - region_index = region_indices[flag] - index_seq.insert(region_index + 1, span[flag]) - attr_dict_list.insert( - region_index + 1, attr_dict_list[region_index].copy() + # Parsing + + def get_global_attr_dict(self) -> dict[str, str]: + result = { + "foreground": self.int_to_hex(self.base_color_int), + "font_family": self.font, + "font_style": self.slant, + "font_weight": self.weight, + "font_size": str(self.font_size * 1024), + } + # `line_height` attribute is supported since Pango 1.50. + pango_version = manimpango.pango_version() + if tuple(map(int, pango_version.split("."))) < (1, 50): + if self.lsh is not None: + log.warning( + f"Pango version {pango_version} found (< 1.50), " + "unable to set `line_height` attribute" ) - region_indices[flag] += 1 - if flag == 0: - region_indices[1] += 1 - for key, val in attr_dict.items(): - if not key: - continue - for mid_dict in attr_dict_list[slice(*region_indices)]: - mid_dict[key] = val - return list(zip( - MarkupText.get_neighbouring_pairs(index_seq), attr_dict_list[:-1] - )) + else: + line_spacing_scale = self.lsh or DEFAULT_LINE_SPACING_SCALE + result["line_height"] = str(((line_spacing_scale) + 1) * 0.6) + return result - def find_substr_or_span( - self, substr_or_span: str | tuple[int | None, int | None] - ) -> list[Span]: - if isinstance(substr_or_span, str): - return self.find_substr(substr_or_span) - - span = tuple([ - ( - min(index, self.string_len) - if index >= 0 - else max(index + self.string_len, 0) - ) - if index is not None else default_index - for index, default_index in zip(substr_or_span, self.full_span) - ]) - if span[0] >= span[1]: - return [] - return [span] - - # Pre-parsing - - def get_tag_items_from_markup( + def get_tag_pairs_from_markup( self ) -> list[tuple[Span, Span, dict[str, str]]]: if not self.is_markup: return [] - tag_pattern = r"""<(/?)(\w+)\s*((?:\w+\s*\=\s*(['"]).*?\4\s*)*)>""" - attr_pattern = r"""(\w+)\s*\=\s*(['"])(.*?)\2""" + tag_pattern = r"""<(/?)(\w+)\s*((\w+\s*\=\s*(['"])[\s\S]*?\5\s*)*)>""" + attr_pattern = r"""(\w+)\s*\=\s*(['"])([\s\S]*?)\2""" begin_match_obj_stack = [] match_obj_pairs = [] - for match_obj in self.finditer(tag_pattern): + for match_obj in re.finditer(tag_pattern, self.string): if not match_obj.group(1): begin_match_obj_stack.append(match_obj) else: match_obj_pairs.append( (begin_match_obj_stack.pop(), match_obj) ) - if begin_match_obj_stack: - raise ValueError("Unclosed tag(s) detected") result = [] for begin_match_obj, end_match_obj in match_obj_pairs: tag_name = begin_match_obj.group(2) - if tag_name != end_match_obj.group(2): - raise ValueError("Unmatched tag names") - if end_match_obj.group(3): - raise ValueError("Attributes shan't exist in ending tags") if tag_name == "span": attr_dict = { match.group(1): match.group(3) @@ -342,189 +291,170 @@ class MarkupText(LabelledString): attr_pattern, begin_match_obj.group(3) ) } - elif tag_name in TAG_TO_ATTR_DICT.keys(): - if begin_match_obj.group(3): - raise ValueError( - f"Attributes shan't exist in tag '{tag_name}'" - ) - attr_dict = TAG_TO_ATTR_DICT[tag_name].copy() else: - raise ValueError(f"Unknown tag: '{tag_name}'") + attr_dict = MARKUP_TAG_CONVERSION_DICT.get(tag_name, {}) result.append( (begin_match_obj.span(), end_match_obj.span(), attr_dict) ) return result - def get_global_dict_from_config(self) -> dict[str, typing.Any]: - result = { - "line_height": ( - (self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1 - ) * 0.6, - "font_family": self.font, - "font_size": self.font_size * 1024, - "font_style": self.slant, - "font_weight": self.weight - } - result.update(self.global_config) - return result + def get_tag_spans(self) -> list[Span]: + return [ + tag_span + for begin_tag, end_tag, _ in self.tag_pairs_from_markup + for tag_span in (begin_tag, end_tag) + ] - def get_local_dicts_from_markup( - self - ) -> list[Span, dict[str, str]]: - return sorted([ - ((begin_tag_span[0], end_tag_span[1]), attr_dict) + def get_items_from_markup(self) -> list[Span]: + return [ + ((begin_tag_span[1], end_tag_span[0]), attr_dict) for begin_tag_span, end_tag_span, attr_dict - in self.tag_items_from_markup - ]) + in self.tag_pairs_from_markup + if begin_tag_span[1] < end_tag_span[0] + ] - def get_local_dicts_from_config( - self - ) -> list[Span, dict[str, typing.Any]]: + def get_specified_items(self) -> list[tuple[Span, dict[str, str]]]: + result = self.chain( + self.items_from_markup, + [ + (span, {key: val}) + for t2x_dict, key in ( + (self.t2c, "foreground"), + (self.t2f, "font_family"), + (self.t2s, "font_style"), + (self.t2w, "font_weight") + ) + for selector, val in t2x_dict.items() + for span in self.find_spans_by_selector(selector) + ], + [ + (span, local_config) + for selector, local_config in self.local_configs.items() + for span in self.find_spans_by_selector(selector) + ], + [ + (span, {}) + for span in self.find_spans_by_selector(self.isolate) + ] + ) + entity_spans = self.tag_spans.copy() + if self.is_markup: + entity_spans.extend(self.find_spans(r"&[\s\S]*?;")) return [ - (span, {key: val}) - for t2x_dict, key in ( - (self.t2c, "foreground"), - (self.t2f, "font_family"), - (self.t2s, "font_style"), - (self.t2w, "font_weight") - ) - for substr_or_span, val in t2x_dict.items() - for span in self.find_substr_or_span(substr_or_span) - ] + [ - (span, local_config) - for substr_or_span, local_config in self.local_configs.items() - for span in self.find_substr_or_span(substr_or_span) + (span, attr_dict) + for span, attr_dict in result + if not any([ + entity_begin < index < entity_end + for index in span + for entity_begin, entity_end in entity_spans + ]) ] - def get_predefined_attr_dicts(self) -> list[Span, dict[str, str]]: - attr_dict_items = [ - (self.full_span, self.global_dict_from_config), - *self.local_dicts_from_markup, - *self.local_dicts_from_config - ] - return [ - (span, { - SPAN_ATTR_KEY_CONVERSION[key.lower()]: str(val) - for key, val in attr_dict.items() - }) - for span, attr_dict in attr_dict_items - ] - - # Parsing - def get_command_repl_items(self) -> list[tuple[Span, str]]: result = [ - (tag_span, "") - for begin_tag, end_tag, _ in self.tag_items_from_markup - for tag_span in (begin_tag, end_tag) + (tag_span, "") for tag_span in self.tag_spans ] if not self.is_markup: - result += [ + result.extend([ (span, escaped) for char, escaped in ( ("&", "&"), (">", ">"), ("<", "<") ) - for span in self.find_substr(char) - ] + for span in self.find_spans(re.escape(char)) + ]) return result - def get_extra_entity_spans(self) -> list[Span]: - if not self.is_markup: - return [] - return self.find_spans(r"&.*?;") - - def get_extra_ignored_spans(self) -> list[int]: - return [] - - def get_internal_specified_spans(self) -> list[Span]: - return [span for span, _ in self.local_dicts_from_markup] - - def get_external_specified_spans(self) -> list[Span]: - return [span for span, _ in self.local_dicts_from_config] + def get_specified_spans(self) -> list[Span]: + return self.remove_redundancies([ + span for span, _ in self.specified_items + ]) def get_label_span_list(self) -> list[Span]: - breakup_indices = remove_list_redundancies(list(it.chain(*it.chain( - self.find_spans(r"\s+"), - self.find_spans(r"\b"), - self.specified_spans - )))) - breakup_indices = sorted(filter( - lambda index: not any([ - span[0] < index < span[1] - for span in self.entity_spans - ]), - breakup_indices - )) - return list(filter( - lambda span: self.get_substr(span).strip(), - self.get_neighbouring_pairs(breakup_indices) - )) - - def get_content(self, use_plain_file: bool) -> str: - if use_plain_file: - attr_dict_items = [ - (self.full_span, {"foreground": self.base_color}), - *self.predefined_attr_dicts, - *[ - (span, {}) - for span in self.label_span_list - ] + interval_spans = sorted(self.chain( + self.tag_spans, + [ + (index, index) + for span in self.specified_spans + for index in span ] + )) + text_spans = self.get_complement_spans(interval_spans, self.full_span) + if self.is_markup: + pattern = r"[0-9a-zA-Z]+|(?:&[\s\S]*?;|[^0-9a-zA-Z\s])+" else: - attr_dict_items = [ - (self.full_span, {"foreground": BLACK}), - *[ + pattern = r"[0-9a-zA-Z]+|[^0-9a-zA-Z\s]+" + return self.chain(*[ + self.find_spans(pattern, pos=span_begin, endpos=span_end) + for span_begin, span_end in text_spans + ]) + + def get_content(self, is_labelled: bool) -> str: + predefined_items = [ + (self.full_span, self.global_attr_dict), + (self.full_span, self.global_config), + *self.specified_items + ] + if is_labelled: + attr_dict_items = self.chain( + [ (span, { - key: BLACK if key in COLOR_RELATED_KEYS else val + key: + "black" if key.lower() in MARKUP_COLOR_KEYS else val for key, val in attr_dict.items() }) - for span, attr_dict in self.predefined_attr_dicts + for span, attr_dict in predefined_items ], - *[ + [ (span, {"foreground": self.int_to_hex(label + 1)}) for label, span in enumerate(self.label_span_list) ] - ] + ) + else: + attr_dict_items = self.chain( + predefined_items, + [ + (span, {}) + for span in self.label_span_list + ] + ) inserted_string_pairs = [ (span, ( f"", "" )) - for span, attr_dict in self.merge_attr_dicts(attr_dict_items) + for span, attr_dict in attr_dict_items if attr_dict ] - span_repl_dict = self.generate_span_repl_dict( + return self.get_replaced_string( inserted_string_pairs, self.command_repl_items ) - return self.get_replaced_substr(self.full_span, span_repl_dict) - @property - def has_predefined_local_colors(self) -> bool: - return any([ - key in COLOR_RELATED_KEYS - for _, attr_dict in self.predefined_attr_dicts - for key in attr_dict.keys() - ]) + # Selector + + def get_cleaned_substr(self, span: Span) -> str: + repl_items = list(filter( + lambda repl_item: self.span_contains(span, repl_item[0]), + self.command_repl_items + )) + return self.get_replaced_substr(span, repl_items).strip() # Method alias - def get_parts_by_text(self, text: str, **kwargs) -> VGroup: - return self.get_parts_by_string(text, **kwargs) + def get_parts_by_text(self, selector: Selector) -> VGroup: + return self.select_parts(selector) - def get_part_by_text(self, text: str, **kwargs) -> VMobject: - return self.get_part_by_string(text, **kwargs) + def get_part_by_text(self, selector: Selector) -> VGroup: + return self.select_part(selector) - def set_color_by_text(self, text: str, color: ManimColor, **kwargs): - return self.set_color_by_string(text, color, **kwargs) + def set_color_by_text(self, selector: Selector, color: ManimColor): + return self.set_parts_color(selector, color) def set_color_by_text_to_color_map( - self, text_to_color_map: dict[str, ManimColor], **kwargs + self, color_map: dict[Selector, ManimColor] ): - return self.set_color_by_string_to_color_map( - text_to_color_map, **kwargs - ) + return self.set_parts_color_by_dict(color_map) def get_text(self) -> str: return self.get_string() diff --git a/manimlib/utils/iterables.py b/manimlib/utils/iterables.py index bdaa76c2..fa54e68d 100644 --- a/manimlib/utils/iterables.py +++ b/manimlib/utils/iterables.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: S = TypeVar("S") -def remove_list_redundancies(l: Iterable[T]) -> list[T]: +def remove_list_redundancies(l: Sequence[T]) -> list[T]: """ Used instead of list(set(l)) to maintain order Keeps the last occurrence of each element @@ -40,14 +40,14 @@ def list_difference_update(l1: Iterable[T], l2: Iterable[T]) -> list[T]: return [e for e in l1 if e not in l2] -def adjacent_n_tuples(objects: Iterable[T], n: int) -> zip[tuple[T, T]]: +def adjacent_n_tuples(objects: Sequence[T], n: int) -> zip[tuple[T, T]]: return zip(*[ [*objects[k:], *objects[:k]] for k in range(n) ]) -def adjacent_pairs(objects: Iterable[T]) -> zip[tuple[T, T]]: +def adjacent_pairs(objects: Sequence[T]) -> zip[tuple[T, T]]: return adjacent_n_tuples(objects, 2) From 065900c6ac4394ae1dd942eee5c3e000af660d1f Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Wed, 27 Apr 2022 23:04:24 +0800 Subject: [PATCH 03/11] Some refactors --- manimlib/mobject/svg/labelled_string.py | 100 ++++++++++-------------- manimlib/mobject/svg/mtex_mobject.py | 32 ++++---- manimlib/mobject/svg/text_mobject.py | 17 ++-- manimlib/scene/interactive_scene.py | 2 + 4 files changed, 71 insertions(+), 80 deletions(-) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index 1c3f0afd..cb62df9b 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -2,9 +2,10 @@ from __future__ import annotations from abc import ABC, abstractmethod import itertools as it -import numpy as np import re +import numpy as np + from manimlib.constants import WHITE from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.types.vectorized_mobject import VGroup @@ -138,32 +139,16 @@ class LabelledString(SVGMobject, ABC): def find_indices(self, pattern: str | re.Pattern, **kwargs) -> list[int]: return [index for index, _ in self.find_spans(pattern, **kwargs)] - @staticmethod - def is_single_selector(selector: Selector) -> bool: - if isinstance(selector, str): - return True - if isinstance(selector, re.Pattern): - return True - if isinstance(selector, tuple): - if len(selector) == 2 and all([ - isinstance(index, int) or index is None - for index in selector - ]): - return True - return False - def find_spans_by_selector(self, selector: Selector) -> list[Span]: - if self.is_single_selector(selector): - selector = (selector,) - result = [] - for sel in selector: - if not self.is_single_selector(sel): - raise TypeError(f"Invalid selector: '{sel}'") + def find_spans_by_single_selector(sel): if isinstance(sel, str): - spans = self.find_spans(re.escape(sel)) - elif isinstance(sel, re.Pattern): - spans = self.find_spans(sel) - else: + return self.find_spans(re.escape(sel)) + if isinstance(sel, re.Pattern): + return self.find_spans(sel) + if isinstance(sel, tuple) and len(sel) == 2 and all([ + isinstance(index, int) or index is None + for index in sel + ]): string_len = self.full_span[1] span = tuple([ ( @@ -174,8 +159,17 @@ class LabelledString(SVGMobject, ABC): if index is not None else default_index for index, default_index in zip(sel, self.full_span) ]) - spans = [span] - result.extend(spans) + return [span] + return None + + result = find_spans_by_single_selector(selector) + if result is None: + result = [] + for sel in selector: + spans = find_spans_by_single_selector(sel) + if spans is None: + raise TypeError(f"Invalid selector: '{sel}'") + result.extend(spans) return sorted(filter( lambda span: span[0] < span[1], self.remove_redundancies(result) @@ -206,8 +200,8 @@ class LabelledString(SVGMobject, ABC): unique_vals.append(val) indices.append(index) indices.append(len(vals)) - spans = LabelledString.get_neighbouring_pairs(indices) - return list(zip(unique_vals, spans)) + val_ranges = LabelledString.get_neighbouring_pairs(indices) + return list(zip(unique_vals, val_ranges)) @staticmethod def span_contains(span_0: Span, span_1: Span) -> bool: @@ -233,26 +227,23 @@ class LabelledString(SVGMobject, ABC): if not inserted_string_pairs: return [] - spans = [ - span for span, _ in inserted_string_pairs - ] - sorted_index_flag_pairs = sorted( - it.product(range(len(spans)), range(2)), - key=lambda t: ( - spans[t[0]][t[1]], - np.sign(spans[t[0]][1 - t[1]] - spans[t[0]][t[1]]), - -spans[t[0]][1 - t[1]], - t[1], - (1, -1)[t[1]] * t[0] + indices, *_, inserted_strings = zip(*sorted([ + ( + span[flag], + np.sign(span[1 - flag] - span[flag]), + -span[1 - flag], + flag, + (1, -1)[flag] * item_index, + str_pair[flag] ) - ) - indices, inserted_strings = zip(*[ - list(zip(*inserted_string_pairs[item_index]))[flag] - for item_index, flag in sorted_index_flag_pairs - ]) + for item_index, (span, str_pair) in enumerate( + inserted_string_pairs + ) + for flag in range(2) + ])) return [ - (index, "".join(inserted_strings[slice(*item_span)])) - for index, item_span + (index, "".join(inserted_strings[slice(*index_range)])) + for index, index_range in LabelledString.compress_neighbours(indices) ] @@ -262,8 +253,7 @@ class LabelledString(SVGMobject, ABC): if not repl_items: return self.get_substr(span) - sorted_repl_items = sorted(repl_items, key=lambda t: t[0]) - repl_spans, repl_strs = zip(*sorted_repl_items) + repl_spans, repl_strs = zip(*sorted(repl_items)) pieces = [ self.get_substr(piece_span) for piece_span in self.get_complement_spans(repl_spans, span) @@ -335,7 +325,7 @@ class LabelledString(SVGMobject, ABC): return [] labels, labelled_submobjects = zip(*self.labelled_submobject_items) - group_labels, labelled_submob_spans = zip( + group_labels, labelled_submob_ranges = zip( *self.compress_neighbours(labels) ) ordered_spans = [ @@ -362,8 +352,8 @@ class LabelledString(SVGMobject, ABC): ) ] submob_groups = VGroup(*[ - VGroup(*labelled_submobjects[slice(*submob_span)]) - for submob_span in labelled_submob_spans + VGroup(*labelled_submobjects[slice(*submob_range)]) + for submob_range in labelled_submob_ranges ]) return list(zip(group_substrs, submob_groups)) @@ -377,13 +367,9 @@ class LabelledString(SVGMobject, ABC): ] def select_part_by_span(self, custom_span: Span) -> VGroup: - labels = [ - label for label, span in enumerate(self.label_span_list) - if self.span_contains(custom_span, span) - ] return VGroup(*[ submob for label, submob in self.labelled_submobject_items - if label in labels + if self.span_contains(custom_span, self.label_span_list[label]) ]) def select_parts(self, selector: Selector) -> VGroup: diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index ed7273ee..3edb10b1 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -209,17 +209,19 @@ class MTex(LabelledString): # Match paired double braces (`{{...}}`). sorted_brace_spans = sorted(self.brace_spans, key=lambda t: t[1]) inner_brace_spans = [ - sorted_brace_spans[span_span[0]] - for _, span_span in self.compress_neighbours([ - (brace_span[0] + index, brace_span[1] - index) - for index, brace_span in enumerate(sorted_brace_spans) + sorted_brace_spans[range_begin] + for _, (range_begin, range_end) in self.compress_neighbours([ + (span_begin + index, span_end - index) + for index, (span_begin, span_end) in enumerate( + sorted_brace_spans + ) ]) - if span_span[1] - span_span[0] >= 2 + if range_end - range_begin >= 2 ] inner_brace_content_spans = [ - (span[0] + 1, span[1] - 1) - for span in inner_brace_spans - if span[1] - span[0] > 2 + (span_begin + 1, span_end - 1) + for span_begin, span_end in inner_brace_spans + if span_end - span_begin > 2 ] result = self.remove_redundancies(self.chain( @@ -303,12 +305,14 @@ class MTex(LabelledString): # Selector def get_cleaned_substr(self, span: Span) -> str: - if not self.brace_spans: - brace_begins, brace_ends = [], [] - else: - brace_begins, brace_ends = zip(*self.brace_spans) - left_brace_indices = list(brace_begins) - right_brace_indices = [index - 1 for index in brace_ends] + left_brace_indices = [ + span_begin + for span_begin, _ in self.brace_spans + ] + right_brace_indices = [ + span_end - 1 + for _, span_end in self.brace_spans + ] skippable_indices = self.chain( self.find_indices(r"\s"), self.script_char_indices, diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 740eeee6..0fe113ff 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -300,18 +300,17 @@ class MarkupText(LabelledString): return result def get_tag_spans(self) -> list[Span]: - return [ - tag_span - for begin_tag, end_tag, _ in self.tag_pairs_from_markup - for tag_span in (begin_tag, end_tag) - ] + return self.chain( + (begin_tag_span, end_tag_span) + for begin_tag_span, end_tag_span, _ in self.tag_pairs_from_markup + ) - def get_items_from_markup(self) -> list[Span]: + def get_items_from_markup(self) -> list[tuple[Span, dict[str, str]]]: return [ - ((begin_tag_span[1], end_tag_span[0]), attr_dict) - for begin_tag_span, end_tag_span, attr_dict + ((span_begin, span_end), attr_dict) + for (_, span_begin), (span_end, _), attr_dict in self.tag_pairs_from_markup - if begin_tag_span[1] < end_tag_span[0] + if span_begin < span_end ] def get_specified_items(self) -> list[tuple[Span, dict[str, str]]]: diff --git a/manimlib/scene/interactive_scene.py b/manimlib/scene/interactive_scene.py index 40ec18c5..3b7397b8 100644 --- a/manimlib/scene/interactive_scene.py +++ b/manimlib/scene/interactive_scene.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import itertools as it import numpy as np import pyperclip From 03cb42ba15d5fdf829bbc3e30f3ef594f004dd41 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Mon, 2 May 2022 22:40:06 +0800 Subject: [PATCH 04/11] [WIP] Refactor LabelledString and relevant classes --- .../animation/transform_matching_parts.py | 4 +- manimlib/mobject/svg/labelled_string.py | 659 +++++++++++++--- manimlib/mobject/svg/mtex_mobject.py | 716 +++++++++++++----- manimlib/mobject/svg/text_mobject.py | 602 +++++++++++---- 4 files changed, 1525 insertions(+), 456 deletions(-) diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index e84f1d9d..96fd95ce 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -225,8 +225,8 @@ class TransformMatchingStrings(AnimationGroup): VGroup(*target_substr_to_parts_map[substr]) ) for substr in sorted([ - s for s in source_substr_to_parts_map.keys() - if s and s in target_substr_to_parts_map.keys() + s for s in source_substr_to_parts_map + if s and s in target_substr_to_parts_map ], key=len, reverse=True) ] ) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index cb62df9b..23c285de 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -4,8 +4,6 @@ from abc import ABC, abstractmethod import itertools as it import re -import numpy as np - from manimlib.constants import WHITE from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.types.vectorized_mobject import VGroup @@ -56,7 +54,7 @@ class LabelledString(SVGMobject, ABC): digest_config(self, kwargs) if self.base_color is None: self.base_color = WHITE - self.base_color_int = self.color_to_int(self.base_color) + self.base_color_hex = self.color_to_hex(self.base_color) self.full_span = (0, len(self.string)) self.parse() @@ -67,11 +65,7 @@ class LabelledString(SVGMobject, ABC): ] def get_file_path(self) -> str: - return self.get_file_path_(is_labelled=False) - - def get_file_path_(self, is_labelled: bool) -> str: - content = self.get_content(is_labelled) - return self.get_file_path_by_content(content) + return self.get_file_path_by_content(self.original_content) @abstractmethod def get_file_path_by_content(self, content: str) -> str: @@ -80,53 +74,66 @@ class LabelledString(SVGMobject, ABC): def generate_mobject(self) -> None: super().generate_mobject() - num_labels = len(self.label_span_list) - if num_labels: - file_path = self.get_file_path_(is_labelled=True) - labelled_svg = SVGMobject(file_path) - submob_color_ints = [ - self.color_to_int(submob.get_fill_color()) - for submob in labelled_svg.submobjects - ] - else: - submob_color_ints = [0] * len(self.submobjects) - - if len(self.submobjects) != len(submob_color_ints): + file_path = self.get_file_path_by_content(self.labelled_content) + labelled_svg = SVGMobject(file_path) + num_submobjects = len(self.submobjects) + if num_submobjects != len(labelled_svg.submobjects): raise ValueError( "Cannot align submobjects of the labelled svg " "to the original svg" ) + submob_color_ints = [ + self.hex_to_int(self.color_to_hex(submob.get_fill_color())) + for submob in labelled_svg.submobjects + ] unrecognized_color_ints = self.remove_redundancies(sorted(filter( - lambda color_int: color_int > num_labels, + lambda color_int: color_int > len(self.label_span_list), submob_color_ints ))) if unrecognized_color_ints: raise ValueError( "Unrecognized color label(s) detected: " - f"{','.join(map(self.int_to_hex, unrecognized_color_ints))}" + f"{', '.join(map(self.int_to_hex, unrecognized_color_ints))}" ) + #if self.sort_labelled_submobs: + submob_indices = sorted( + range(num_submobjects), + key=lambda index: tuple( + self.submobjects[index].get_center() + ) + ) + labelled_submob_indices = sorted( + range(num_submobjects), + key=lambda index: tuple( + labelled_svg.submobjects[index].get_center() + ) + ) + submob_color_ints = [ + submob_color_ints[ + labelled_submob_indices[submob_indices.index(index)] + ] + for index in range(num_submobjects) + ] + for submob, color_int in zip(self.submobjects, submob_color_ints): submob.label = color_int - 1 - def parse(self) -> None: - self.command_repl_items = self.get_command_repl_items() - self.specified_spans = self.get_specified_spans() - self.check_overlapping() - self.label_span_list = self.get_label_span_list() - if len(self.label_span_list) >= 16777216: - raise ValueError("Cannot handle that many substrings") + #@property + #@abstractmethod + #def sort_labelled_submobs(self) -> bool: + # return False # Toolkits def get_substr(self, span: Span) -> str: return self.string[slice(*span)] - def match(self, pattern: str | re.Pattern, **kwargs) -> re.Pattern | None: - if isinstance(pattern, str): - pattern = re.compile(pattern) - return re.compile(pattern).match(self.string, **kwargs) + #def match(self, pattern: str | re.Pattern, **kwargs) -> re.Pattern | None: + # if isinstance(pattern, str): + # pattern = re.compile(pattern) + # return re.compile(pattern).match(self.string, **kwargs) def find_spans(self, pattern: str | re.Pattern, **kwargs) -> list[Span]: if isinstance(pattern, str): @@ -136,8 +143,8 @@ class LabelledString(SVGMobject, ABC): for match_obj in pattern.finditer(self.string, **kwargs) ] - def find_indices(self, pattern: str | re.Pattern, **kwargs) -> list[int]: - return [index for index, _ in self.find_spans(pattern, **kwargs)] + #def find_indices(self, pattern: str | re.Pattern, **kwargs) -> list[int]: + # return [index for index, _ in self.find_spans(pattern, **kwargs)] def find_spans_by_selector(self, selector: Selector) -> list[Span]: def find_spans_by_single_selector(sel): @@ -145,20 +152,16 @@ class LabelledString(SVGMobject, ABC): return self.find_spans(re.escape(sel)) if isinstance(sel, re.Pattern): return self.find_spans(sel) - if isinstance(sel, tuple) and len(sel) == 2 and all([ + if isinstance(sel, tuple) and len(sel) == 2 and all( isinstance(index, int) or index is None for index in sel - ]): - string_len = self.full_span[1] - span = tuple([ - ( - min(index, string_len) - if index >= 0 - else max(index + string_len, 0) - ) + ): + l = self.full_span[1] + span = tuple( + min(index, l) if index >= 0 else max(index + l, 0) if index is not None else default_index for index, default_index in zip(sel, self.full_span) - ]) + ) return [span] return None @@ -203,13 +206,158 @@ class LabelledString(SVGMobject, ABC): val_ranges = LabelledString.get_neighbouring_pairs(indices) return list(zip(unique_vals, val_ranges)) + @staticmethod + def sort_obj_pairs_by_spans( + obj_pairs: list[tuple[Span, tuple[T, T]]] + ) -> list[tuple[int, T]]: + return [ + (index, obj) + for (index, _), obj in sorted([ + (span, begin_obj) + for span, (begin_obj, _) in obj_pairs + ] + [ + (span[::-1], end_obj) + for span, (_, end_obj) in reversed(obj_pairs) + ], key=lambda t: (t[0][0], -t[0][1])) + ] + @staticmethod def span_contains(span_0: Span, span_1: Span) -> bool: return span_0[0] <= span_1[0] and span_0[1] >= span_1[1] + def get_piece_items( + self, + tag_span_pairs: list[tuple[Span, Span]], + entity_spans: list[Span] + ) -> tuple[list[Span], list[int]]: + tagged_items = sorted(self.chain( + [(begin_cmd_span, 1) for begin_cmd_span, _ in tag_span_pairs], + [(end_cmd_span, -1) for _, end_cmd_span in tag_span_pairs], + [(entity_span, 0) for entity_span in entity_spans], + ), key=lambda t: t[0]) + piece_spans = self.get_complement_spans(self.full_span, [ + interval_span for interval_span, _ in tagged_items + ]) + piece_levels = [0, *it.accumulate([tag for _, tag in tagged_items])] + return piece_spans, piece_levels + + def split_span(self, arbitrary_span: Span) -> list[Span]: + # ignorable_indices -- + # left_bracket_spans + # right_bracket_spans + # entity_spans + #piece_spans, piece_levels = zip(*self.piece_items) + #ignorable_indices = self.ignorable_indices + piece_spans = self.piece_spans + piece_levels = self.piece_levels + #piece_begins, piece_ends = zip(*piece_spans) + #span_begin, span_end = arbitrary_span + #while span_begin in ignorable_indices: + # span_begin += 1 + #while span_end - 1 in ignorable_indices: + # span_end -= 1 + #entity_spans = self.chain( + # left_bracket_spans, right_bracket_spans, entity_spans + #) + index_begin = sum([ + arbitrary_span[0] > piece_end + for _, piece_end in piece_spans + ]) + index_end = sum([ + arbitrary_span[1] >= piece_begin + for piece_begin, _ in piece_spans + ]) + if index_begin >= index_end: + return [] + + lowest_level = min( + piece_levels[index_begin:index_end] + ) + split_piece_indices = [] + target_level = piece_levels[index_begin] + for piece_index in range(index_begin, index_end): + if piece_levels[piece_index] != target_level: + continue + split_piece_indices.append(piece_index) + target_level -= 1 + if target_level < lowest_level: + break + len_indices = len(split_piece_indices) + target_level = piece_levels[index_end - 1] + for piece_index in range(index_begin, index_end)[::-1]: + if piece_levels[piece_index] != target_level: + continue + split_piece_indices.insert(len_indices, piece_index + 1) + target_level -= 1 + if target_level < lowest_level: + break + + span_begins = [ + piece_spans[piece_index][0] + for piece_index in split_piece_indices[:-1] + ] + span_begins[0] = max(arbitrary_span[0], span_begins[0]) + span_ends = [ + piece_spans[piece_index - 1][1] + for piece_index in split_piece_indices[1:] + ] + span_ends[-1] = min(arbitrary_span[1], span_ends[-1]) + return list(zip(span_begins, span_ends)) + #lowest_level_indices = [ + # piece_index + # for piece_index, piece_level in enumerate(piece_levels) + # if left_piece_index <= piece_index <= right_piece_index + # and piece_level == lowest_level + #] + #left_lowest_index = min(lowest_level_indices) + #right_lowest_index = max(lowest_level_indices) + #while right_lowest_index != right_piece_index: + + + #left_parallel_index = max( + # piece_index + # for piece_index, piece_level in enumerate(piece_levels) + # if left_piece_index <= piece_index <= right_piece_index + # and piece_level == piece_levels[left_piece_index] + #) + #right_parallel_index = min( + # piece_index + # for piece_index, piece_level in enumerate(piece_levels) + # if left_piece_index <= piece_index <= right_piece_index + # and piece_level == piece_levels[right_piece_index] + #) + #result.append(( + # piece_spans[left_lowest_index][0], + # piece_spans[right_lowest_index][1] + #)) + #lowest_piece_indices = [ + # piece_index + # for piece_index, piece_level in enumerate( + + # ) + #] + #adjusted_span_begin = max(span_begin, piece_spans[begin_piece_index][0]) ## + #adjusted_span_end = min(span_end, piece_spans[end_piece_index][1]) ## + #begin_level_mismatch = piece_levels[begin_piece_index] - lowest_level + #end_level_mismatch = piece_levels[end_piece_index] - lowest_level + #if begin_level_mismatch: + # span_begin = piece_spans[max([ + # index + # for index, piece_level in enumerate(piece_levels) + # if piece_level == lowest_level and index < begin_piece_index + # ])][1] + # begin_level_mismatch = 0 + #if end_level_mismatch: + # span_end = piece_spans[min([ + # index + # for index, piece_level in enumerate(piece_levels) + # if piece_level == lowest_level and index > end_piece_index + # ])][0] + # end_level_mismatch = 0 + @staticmethod def get_complement_spans( - interval_spans: list[Span], universal_span: Span + universal_span: Span, interval_spans: list[Span] ) -> list[Span]: if not interval_spans: return [universal_span] @@ -220,85 +368,138 @@ class LabelledString(SVGMobject, ABC): (*span_ends, universal_span[1]) )) - @staticmethod - def merge_inserted_strings_from_pairs( - inserted_string_pairs: list[tuple[Span, tuple[str, str]]] - ) -> list[tuple[int, str]]: - if not inserted_string_pairs: - return [] - - indices, *_, inserted_strings = zip(*sorted([ - ( - span[flag], - np.sign(span[1 - flag] - span[flag]), - -span[1 - flag], - flag, - (1, -1)[flag] * item_index, - str_pair[flag] - ) - for item_index, (span, str_pair) in enumerate( - inserted_string_pairs - ) - for flag in range(2) - ])) - return [ - (index, "".join(inserted_strings[slice(*index_range)])) - for index, index_range - in LabelledString.compress_neighbours(indices) - ] - - def get_replaced_substr( - self, span: Span, repl_items: list[tuple[Span, str]] - ) -> str: + def get_replaced_substr(self, span: Span, repl_items: list[Span, str]): # TODO: need `span` attr? if not repl_items: return self.get_substr(span) - repl_spans, repl_strs = zip(*sorted(repl_items)) + repl_spans, repl_strs = zip(*sorted( + repl_items, key=lambda t: t[0] + )) pieces = [ self.get_substr(piece_span) - for piece_span in self.get_complement_spans(repl_spans, span) + for piece_span in self.get_complement_spans(span, repl_spans) ] repl_strs = [*repl_strs, ""] return "".join(self.chain(*zip(pieces, repl_strs))) - def get_replaced_string( - self, - inserted_string_pairs: list[tuple[Span, tuple[str, str]]], - repl_items: list[tuple[Span, str]] - ) -> str: - all_repl_items = self.chain( - repl_items, - [ - ((index, index), inserted_string) - for index, inserted_string - in self.merge_inserted_strings_from_pairs( - inserted_string_pairs - ) - ] - ) - return self.get_replaced_substr(self.full_span, all_repl_items) + #def get_replaced_string( + # self, + # inserted_string_pairs: list[tuple[Span, tuple[str, str]]], + # repl_items: list[tuple[Span, str]] + #) -> str: + # all_repl_items = self.chain( + # repl_items, + # [ + # ((index, index), inserted_string) + # for index, inserted_string + # in self.sort_inserted_strings_from_pairs( + # inserted_string_pairs + # ) + # ] + # ) + # return self.get_replaced_substr(self.full_span, all_repl_items) @staticmethod - def color_to_int(color: ManimColor) -> int: - hex_code = rgb_to_hex(color_to_rgb(color)) - return int(hex_code[1:], 16) + def color_to_hex(color: ManimColor) -> str: + return rgb_to_hex(color_to_rgb(color)) + + @staticmethod + def hex_to_int(rgb_hex: str) -> int: + return int(rgb_hex[1:], 16) @staticmethod def int_to_hex(rgb_int: int) -> str: - return "#{:06x}".format(rgb_int).upper() + return f"#{rgb_int:06x}".upper() + + @staticmethod + @abstractmethod + def get_tag_str( + attr_dict: dict[str, str], escape_color_keys: bool, is_begin_tag: bool + ) -> str: + return "" + + #def get_color_tag_str(self, rgb_int: int, is_begin_tag: bool) -> str: + # return self.get_tag_str({ + # "foreground": self.int_to_hex(rgb_int) + # }, escape_color_keys=False, is_begin_tag=is_begin_tag) # Parsing - @abstractmethod - def get_command_repl_items(self) -> list[tuple[Span, str]]: - return [] + #@abstractmethod + #def get_command_spans(self) -> list[Span]: + # return [] + # #return [ + # # self.match(r"\\(?:[a-zA-Z]+|.)", pos=index).span() + # # for index in self.backslash_indices + # #] - @abstractmethod - def get_specified_spans(self) -> list[Span]: - return [] + #@abstractmethod + #@staticmethod + #def get_command_repl_dict() -> dict[str | re.Pattern, str]: + # return {} - def check_overlapping(self) -> None: - for span_0, span_1 in it.product(self.specified_spans, repeat=2): + #@abstractmethod + #def parse_setup(self) -> None: + # return + + #@abstractmethod + #def get_command_repl_items(self) -> list[tuple[Span, str]]: + # return [] + # #result = [] + # #for cmd_span in self.command_spans: + # # cmd_str = self.get_substr(cmd_span) + # # if + # # repl_str = self.command_repl_dict.get(cmd_str, cmd_str) + # # result.append((cmd_span, repl_str)) + # #return result + + #def span_cuts_at_entity(self, span: Span) -> bool: + # return any([ + # entity_begin < index < entity_end + # for index in span + # for entity_begin, entity_end in self.command_repl_items + # ]) + + #@abstractmethod + #def get_all_specified_items(self) -> list[tuple[Span, dict[str, str]]]: + # return [] + + #def get_specified_items(self) -> list[tuple[Span, dict[str, str]]]: + # return [ + # (span, attr_dict) + # for span, attr_dict in self.get_all_specified_items() + # if not any([ + # entity_begin < index < entity_end + # for index in span + # for entity_begin, entity_end in self.command_repl_items + # ]) + # ] + + #def get_specified_spans(self) -> list[Span]: + # return [span for span, _ in self.specified_items] + + def parse(self) -> None: + self.entity_spans = self.get_entity_spans() + tag_span_pairs, internal_items = self.get_internal_items() + self.piece_spans, self.piece_levels = self.get_piece_items( + tag_span_pairs, self.entity_spans + ) + #self.tag_content_spans = [ + # (content_begin, content_end) + # for (_, content_begin), (content_end, _) in tag_span_pairs + #] + self.tag_spans = self.chain(*tag_span_pairs) + specified_items = self.chain( + internal_items, + self.get_external_items(), + [ + (span, {}) + for span in self.find_spans_by_selector(self.isolate) + ] + ) + print(f"\n{specified_items=}\n") + specified_spans = [span for span, _ in specified_items] + for span_0, span_1 in it.product(specified_spans, repeat=2): if not span_0[0] < span_1[0] < span_0[1] < span_1[1]: continue raise ValueError( @@ -306,13 +507,246 @@ class LabelledString(SVGMobject, ABC): f"'{self.get_substr(span_0)}' and '{self.get_substr(span_1)}'" ) + split_items = [ + (span, attr_dict) + for specified_span, attr_dict in specified_items + for span in self.split_span(specified_span) + ] + print(f"\n{split_items=}\n") + split_spans = [span for span, _ in split_items] + label_span_list = self.get_label_span_list(split_spans) + if len(label_span_list) >= 16777216: + raise ValueError("Cannot handle that many substrings") + + #content_strings = [] + #for is_labelled in (False, True): + # + # content_strings.append(content_string) + + #inserted_str_pairs = self.chain( + # [ + # (span, ( + # self.get_tag_str(attr_dict, escape_color_keys=True, is_begin_tag=True), + # self.get_tag_str(attr_dict, escape_color_keys=True, is_begin_tag=False) + # )) + # for span, attr_dict in split_items + # ], + # [ + # (span, ( + # self.get_color_tag_str(label + 1, is_begin_tag=True), + # self.get_color_tag_str(label + 1, is_begin_tag=False) + # )) + # for span, attr_dict in split_items + # ] + #) + + + #decorated_strings = [ + # self.get_replaced_substr(self.full_span, [ + # (span, str_pair[flag]) + # for span, str_pair in command_repl_items + # ]) + # for flag in range(2) + #] + + self.specified_spans = specified_spans + self.label_span_list = label_span_list + self.original_content = self.get_full_content_string( + label_span_list, split_items, is_labelled=False + ) + self.labelled_content = self.get_full_content_string( + label_span_list, split_items, is_labelled=True + ) + print(self.original_content) + print() + print(self.labelled_content) + + + #self.command_repl_dict = self.get_command_repl_dict() + #self.command_repl_items = [] + #self.bracket_content_spans = [] + ##self.command_spans = self.get_command_spans() + ##self.specified_items = self.get_specified_items() + #self.specified_spans = [] + #self.check_overlapping() ####### + #self.label_span_list = [] + #if len(self.label_span_list) >= 16777216: + # raise ValueError("Cannot handle that many substrings") + @abstractmethod - def get_label_span_list(self) -> list[Span]: + def get_entity_spans(self) -> list[Span]: return [] @abstractmethod - def get_content(self, is_labelled: bool) -> str: - return "" + def get_internal_items( + self + ) -> tuple[list[tuple[Span, Span]], list[tuple[Span, dict[str, str]]]]: + return [], [] + + @abstractmethod + def get_external_items(self) -> list[tuple[Span, dict[str, str]]]: + return [] + + #@abstractmethod + #def get_spans_from_items(self, specified_items: list[tuple[Span, dict[str, str]]]) -> list[Span]: + # return [] + + #def split_span(self, arbitrary_span: Span) -> list[Span]: + # span_begin, span_end = arbitrary_span + # # TODO: improve algorithm + # span_begin += sum([ + # entity_end - span_begin + # for entity_begin, entity_end in self.entity_spans + # if entity_begin < span_begin < entity_end + # ]) + # span_end -= sum([ + # span_end - entity_begin + # for entity_begin, entity_end in self.entity_spans + # if entity_begin < span_end < entity_end + # ]) + # if span_begin >= span_end: + # return [] + + # adjusted_span = (span_begin, span_end) + # result = [] + # span_choices = list(filter( + # lambda span: span[0] < span[1] and self.span_contains( + # adjusted_span, span + # ), + # self.tag_content_spans + # )) + # while span_choices: + # chosen_span = min(span_choices, key=lambda t: (t[0], -t[1])) + # result.append(chosen_span) + # span_choices = list(filter( + # lambda span: chosen_span[1] <= span[0], + # span_choices + # )) + # result.extend(self.chain(*[ + # self.get_complement_spans(span, sorted([ + # (max(tag_span[0], span[0]), min(tag_span[1], span[1])) + # for tag_span in self.tag_spans + # if tag_span[0] < span[1] and span[0] < tag_span[1] + # ])) + # for span in self.get_complement_spans(adjusted_span, result) + # ])) + return list(filter(lambda span: span[0] < span[1], result)) + + #@abstractmethod + #def get_split_items(self, specified_items: list[T]) -> list[T]: + # return [] + + @abstractmethod + def get_label_span_list(self, split_spans: list[Span]) -> list[Span]: + return [] + + #@abstractmethod + #def get_predefined_inserted_str_items( + # self, split_items: list[T] + #) -> list[tuple[Span, tuple[str, str]]]: + # return [] + + #def check_overlapping(self) -> None: + + #for span_0, span_1 in it.product(self.specified_spans, self.bracket_content_spans): + # if not any( + # span_0[0] < span_1[0] <= span_0[1] <= span_1[1], + # span_1[0] <= span_0[0] <= span_1[1] < span_0[1] + # ): + # continue + # raise ValueError( + # f"Invalid substring detected: '{self.get_substr(span_0)}'" + # ) + # TODO: test bracket_content_spans + + #@abstractmethod + #def get_inserted_string_pairs( + # self, is_labelled: bool + #) -> list[tuple[Span, tuple[str, str]]]: + # return [] + + #@abstractmethod + #def get_label_span_list(self) -> list[Span]: + # return [] + + #def get_decorated_string( + # self, is_labelled: bool, replace_commands: bool + #) -> str: + # inserted_string_pairs = [ + # (indices, str_pair) + # for indices, str_pair in self.get_inserted_string_pairs( + # is_labelled=is_labelled + # ) + # if not any( + # cmd_begin < index < cmd_end + # for index in indices + # for (cmd_begin, cmd_end), _ in self.command_repl_items + # ) + # ] + # repl_items = [ + # ((index, index), inserted_string) + # for index, inserted_string + # in self.sort_inserted_strings_from_pairs( + # inserted_string_pairs + # ) + # ] + # if replace_commands: + # repl_items.extend(self.command_repl_items) + # return self.get_replaced_substr(self.full_span, repl_items) + + @abstractmethod + def get_additional_inserted_str_pairs( + self + ) -> list[tuple[Span, tuple[str, str]]]: + return [] + + @abstractmethod + def get_command_repl_items(self, is_labelled: bool) -> list[Span, str]: + return [] + + def get_full_content_string( + self, + label_span_list: list[Span], + split_items: list[tuple[Span, dict[str, str]]], + is_labelled: bool + ) -> str: + label_items = [ + (span, { + "foreground": self.int_to_hex(label + 1) + } if is_labelled else {}) + for label, span in enumerate(label_span_list) + ] + inserted_str_pairs = self.chain( + self.get_additional_inserted_str_pairs(), + [ + (span, tuple( + self.get_tag_str( + attr_dict, + escape_color_keys=is_labelled and not is_label_item, + is_begin_tag=is_begin_tag + ) + for is_begin_tag in (True, False) + )) + for is_label_item, items in enumerate(( + split_items, label_items + )) + for span, attr_dict in items + ] + ) + repl_items = self.chain( + self.get_command_repl_items(is_labelled), + [ + ((index, index), inserted_str) + for index, inserted_str + in self.sort_obj_pairs_by_spans(inserted_str_pairs) + ] + ) + return self.get_replaced_substr( + self.full_span, repl_items + ) + + #def get_content(self, is_labelled: bool) -> str: + # return self.content_strings[int(is_labelled)] # Selector @@ -348,7 +782,7 @@ class LabelledString(SVGMobject, ABC): group_substrs = [ self.get_cleaned_substr(span) if span[0] < span[1] else "" for span in self.get_complement_spans( - interval_spans, (ordered_spans[0][0], ordered_spans[-1][1]) + (ordered_spans[0][0], ordered_spans[-1][1]), interval_spans ) ] submob_groups = VGroup(*[ @@ -366,10 +800,11 @@ class LabelledString(SVGMobject, ABC): for span in self.specified_spans ] - def select_part_by_span(self, custom_span: Span) -> VGroup: + def select_part_by_span(self, arbitrary_span: Span) -> VGroup: return VGroup(*[ submob for label, submob in self.labelled_submobject_items - if self.span_contains(custom_span, self.label_span_list[label]) + if label != -1 + and self.span_contains(arbitrary_span, self.label_span_list[label]) ]) def select_parts(self, selector: Selector) -> VGroup: diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 3edb10b1..93e49a81 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -1,7 +1,5 @@ from __future__ import annotations -import re - from manimlib.mobject.svg.labelled_string import LabelledString from manimlib.utils.tex_file_writing import display_during_execution from manimlib.utils.tex_file_writing import get_tex_config @@ -11,6 +9,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from colour import Color + import re from typing import Iterable, Union from manimlib.mobject.types.vectorized_mobject import VGroup @@ -39,6 +38,7 @@ TEX_COLOR_COMMANDS_DICT = { "\\colorbox": (1, True), "\\fcolorbox": (2, True), } +TEX_COLOR_COMMAND_SUFFIX = "replaced" class MTex(LabelledString): @@ -56,7 +56,7 @@ class MTex(LabelledString): self.tex_string = tex_string super().__init__(tex_string, **kwargs) - self.set_color_by_tex_to_color_map(self.tex_to_color_map) + #self.set_color_by_tex_to_color_map(self.tex_to_color_map) self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size) @property @@ -83,208 +83,490 @@ class MTex(LabelledString): file_path = tex_to_svg_file(full_tex) return file_path - def parse(self) -> None: - self.backslash_indices = self.get_backslash_indices() - self.command_spans = self.get_command_spans() - self.brace_spans = self.get_brace_spans() - self.script_char_indices = self.get_script_char_indices() - self.script_content_spans = self.get_script_content_spans() - self.script_spans = self.get_script_spans() - super().parse() + #@property + #def sort_labelled_submobs(self) -> bool: + # return False # Toolkits @staticmethod - def get_color_command_str(rgb_int: int) -> str: - rg, b = divmod(rgb_int, 256) + def get_color_command_str(rgb_hex: str) -> str: + rgb = MTex.hex_to_int(rgb_hex) + rg, b = divmod(rgb, 256) r, g = divmod(rg, 256) return f"\\color[RGB]{{{r}, {g}, {b}}}" @staticmethod - def shrink_span(span: Span, skippable_indices: list[int]) -> Span: - span_begin, span_end = span - while span_begin in skippable_indices: - span_begin += 1 - while span_end - 1 in skippable_indices: - span_end -= 1 - return (span_begin, span_end) + def get_tag_str( + attr_dict: dict[str, str], escape_color_keys: bool, is_begin_tag: bool + ) -> str: + if escape_color_keys: + return "" + if not is_begin_tag: + return "}}" + if "foreground" not in attr_dict: + return "{{" + return "{{" + MTex.get_color_command_str(attr_dict["foreground"]) + + #@staticmethod + #def shrink_span(span: Span, skippable_indices: list[int]) -> Span: + # span_begin, span_end = span + # while span_begin in skippable_indices: + # span_begin += 1 + # while span_end - 1 in skippable_indices: + # span_end -= 1 + # return (span_begin, span_end) # Parsing - def get_backslash_indices(self) -> list[int]: - # The latter of `\\` doesn't count. - return self.find_indices(r"\\.") + #def parse(self) -> None: # TODO + #command_spans = self.find_spans(r"\\(?:[a-zA-Z]+|.)") - def get_command_spans(self) -> list[Span]: - return [ - self.match(r"\\(?:[a-zA-Z]+|.)", pos=index).span() - for index in self.backslash_indices - ] - def get_unescaped_char_indices(self, char: str) -> list[int]: - return list(filter( - lambda index: index - 1 not in self.backslash_indices, - self.find_indices(re.escape(char)) - )) + #specified_spans = self.chain( + # inner_content_spans, + # *[ + # self.find_spans_by_selector(selector) + # for selector in self.tex_to_color_map.keys() + # ], + # self.find_spans_by_selector(self.isolate) + #) + #print(specified_spans) + #label_span_list = self.remove_redundancies(self.chain(*[ + # self.split_span(span) + # for span in specified_spans + #])) + #print(label_span_list) + #for span in all_specified_spans: + # adjusted_span, _, _ = self.adjust_span(span, align_level=True) + # if adjusted_span[0] > adjusted_span[1]: + # continue + # specified_spans.append(adjusted_span) - def get_brace_spans(self) -> list[Span]: - span_begins = [] - span_ends = [] - span_begins_stack = [] - char_items = sorted([ - (index, char) - for char in "{}" - for index in self.get_unescaped_char_indices(char) - ]) - for index, char in char_items: - if char == "{": - span_begins_stack.append(index) - else: - if not span_begins_stack: - raise ValueError("Missing '{' inserted") - span_begins.append(span_begins_stack.pop()) - span_ends.append(index + 1) - if span_begins_stack: - raise ValueError("Missing '}' inserted") - return list(zip(span_begins, span_ends)) - def get_script_char_indices(self) -> list[int]: - return self.chain(*[ - self.get_unescaped_char_indices(char) - for char in "_^" - ]) - def get_script_content_spans(self) -> list[Span]: - result = [] - script_entity_dict = dict(self.chain( - self.brace_spans, - self.command_spans - )) - for index in self.script_char_indices: - span_begin = self.match(r"\s*", pos=index + 1).end() - if span_begin in script_entity_dict.keys(): - span_end = script_entity_dict[span_begin] - else: - match_obj = self.match(r".", pos=span_begin) - if match_obj is None: - continue - span_end = match_obj.end() - result.append((span_begin, span_end)) - return result + #reversed_script_spans_dict = { + # span_end: span_begin + # for span_begin, _, span_end in script_items + #} + #label_span_list = [ + # (content_begin, span_end) + # for _, content_begin, span_end in script_items + #] + #for span_begin, span_end in specified_spans: + # while span_end in reversed_script_spans_dict: + # span_end = reversed_script_spans_dict[span_end] + # if span_begin >= span_end: + # continue + # shrinked_span = (span_begin, span_end) + # if shrinked_span in label_span_list: + # continue + # label_span_list.append(shrinked_span) - def get_script_spans(self) -> list[Span]: - return [ - ( - self.match(r"[\s\S]*?(\s*)$", endpos=index).start(1), - script_content_span[1] - ) - for index, script_content_span in zip( - self.script_char_indices, self.script_content_spans - ) - ] + #inserted_str_items = [ + # (span, ( + # ("{{", "{{" + self.get_color_command_str(label + 1)), + # ("}}", "}}"), + # )) + # for label, span in enumerate(label_span_list) + #] + #command_repl_items = [ + # ((index, index), str_pair) + # for index, str_pair in self.sort_obj_pairs_by_spans(inserted_str_items) + #] + #for cmd_span in command_spans: + # cmd_str = self.get_substr(cmd_span) + # if cmd_str not in TEX_COLOR_COMMANDS_DICT: + # continue + # repl_str = f"{cmd_str}{TEX_COLOR_COMMAND_SUFFIX}" + # command_repl_items.append((cmd_span, (cmd_str, repl_str))) + #print(decorated_strings) + #return specified_spans, label_span_list, decorated_strings - def get_command_repl_items(self) -> list[tuple[Span, str]]: - result = [] - brace_spans_dict = dict(self.brace_spans) - brace_begins = list(brace_spans_dict.keys()) - for cmd_span in self.command_spans: - cmd_name = self.get_substr(cmd_span) - if cmd_name not in TEX_COLOR_COMMANDS_DICT.keys(): + + + #self.command_spans = self.find_spans(r"\\(?:[a-zA-Z]+|.)") + #self.ignorable_indices = self.get_ignorable_indices() + #self.brace_content_spans = self.get_brace_content_spans() + #self.command_repl_items = self.get_command_repl_items() + ##self.backslash_indices = self.get_backslash_indices() + #self.ignorable_indices = self.get_ignorable_indices() + ##self.script_items = self.get_script_items() + ##self.script_char_indices = self.get_script_char_indices() + ##self.script_content_spans = self.get_script_content_spans() + ##self.script_spans = self.get_script_spans() + #self.specified_spans = self.get_specified_spans() + ##super().parse() + #self.label_span_list = self.get_label_span_list() + + def get_entity_spans(self) -> list[Span]: + return self.find_spans(r"\\(?:[a-zA-Z]+|.)") + + def get_internal_items( + self + ) -> tuple[list[tuple[Span, Span]], list[tuple[Span, dict[str, str]]]]: + command_spans = self.entity_spans + brace_span_pairs = [] + brace_begin_spans_stack = [] + for span in self.find_spans(r"[{}]"): + char_index = span[0] + if (char_index - 1, char_index + 1) in command_spans: continue - n_braces, substitute_cmd = TEX_COLOR_COMMANDS_DICT[cmd_name] - span_begin, span_end = cmd_span - for _ in range(n_braces): - span_end = brace_spans_dict[min(filter( - lambda index: index >= span_end, - brace_begins - ))] - if substitute_cmd: - repl_str = cmd_name + n_braces * "{black}" + if self.get_substr(span) == "{": + brace_begin_spans_stack.append(span) else: - repl_str = "" - result.append(((span_begin, span_end), repl_str)) - return result + if not brace_begin_spans_stack: + raise ValueError("Missing '{' inserted") + brace_span = brace_begin_spans_stack.pop() + brace_span_pairs.append((brace_span, span)) + if brace_begin_spans_stack: + raise ValueError("Missing '}' inserted") - def get_specified_spans(self) -> list[Span]: - # Match paired double braces (`{{...}}`). - sorted_brace_spans = sorted(self.brace_spans, key=lambda t: t[1]) - inner_brace_spans = [ - sorted_brace_spans[range_begin] + #tag_span_pairs = brace_span_pairs.copy() + script_entity_dict = dict(self.chain( + [ + (span_begin, span_end) + for (span_begin, _), (_, span_end) in brace_span_pairs + ], + command_spans + )) + script_additional_brace_spans = [ + (char_index + 1, script_entity_dict.get( + script_begin, script_begin + 1 + )) + for char_index, script_begin in self.find_spans(r"[_^]\s*(?=.)") + if (char_index - 1, char_index + 1) not in command_spans + ] + #for char_index, script_begin in self.find_spans(r"[_^]\s*(?=.)"): + # if (char_index - 1, char_index + 1) in command_spans: + # continue + # script_end = script_entity_dict.get(script_begin, script_begin + 1) + # tag_span_pairs.append( + # ((char_index, char_index + 1), (script_end, script_end)) + # ) + # script_additional_brace_spans.append((char_index + 1, script_end)) + + tag_span_pairs = self.chain( + brace_span_pairs, + [ + ((script_begin - 1, script_begin), (script_end, script_end)) + for script_begin, script_end in script_additional_brace_spans + ] + ) + + brace_content_spans = [ + (span_begin, span_end) + for (_, span_begin), (span_end, _) in brace_span_pairs + ] + internal_items = [ + (brace_content_spans[range_begin], {}) for _, (range_begin, range_end) in self.compress_neighbours([ (span_begin + index, span_end - index) for index, (span_begin, span_end) in enumerate( - sorted_brace_spans + brace_content_spans ) ]) if range_end - range_begin >= 2 ] - inner_brace_content_spans = [ - (span_begin + 1, span_end - 1) - for span_begin, span_end in inner_brace_spans - if span_end - span_begin > 2 + self.script_additional_brace_spans = script_additional_brace_spans + return tag_span_pairs, internal_items + + def get_external_items(self) -> list[tuple[Span, dict[str, str]]]: + return [ + (span, {"foreground": self.color_to_hex(color)}) + for selector, color in self.tex_to_color_map.items() + for span in self.find_spans_by_selector(selector) ] - result = self.remove_redundancies(self.chain( - inner_brace_content_spans, - *[ - self.find_spans_by_selector(selector) - for selector in self.tex_to_color_map.keys() - ], - self.find_spans_by_selector(self.isolate) - )) - return list(filter( - lambda span: not any([ - entity_begin < index < entity_end - for index in span - for entity_begin, entity_end in self.command_spans - ]), - result - )) + #def get_spans_from_items(self, specified_items: list[Span]) -> list[Span]: + # return specified_items - def get_label_span_list(self) -> list[Span]: - reversed_script_spans_dict = dict([ - script_span[::-1] for script_span in self.script_spans - ]) - skippable_indices = self.chain( - self.find_indices(r"\s"), - self.script_char_indices - ) - result = self.script_content_spans.copy() - for span in self.specified_spans: - span_begin, span_end = self.shrink_span(span, skippable_indices) - while span_end in reversed_script_spans_dict.keys(): - span_end = reversed_script_spans_dict[span_end] - if span_begin >= span_end: + #def get_split_items(self, specified_items: list[Span]) -> list[Span]: + # return self.remove_redundancies(self.chain(*[ + # self.split_span(span) + # for span in specified_items + # ])) + + def get_label_span_list(self, split_spans: list[Span]) -> list[Span]: + return split_spans + + def get_additional_inserted_str_pairs( + self + ) -> list[tuple[Span, tuple[str, str]]]: + return [ + (span, ("{", "}")) + for span in self.script_additional_brace_spans + ] + + def get_command_repl_items(self, is_labelled: bool) -> list[Span, str]: + if not is_labelled: + return [] + result = [] + command_spans = self.entity_spans # TODO + for cmd_span in command_spans: + cmd_str = self.get_substr(cmd_span) + if cmd_str not in TEX_COLOR_COMMANDS_DICT: continue - shrinked_span = (span_begin, span_end) - if shrinked_span in result: - continue - result.append(shrinked_span) + repl_str = f"{cmd_str}{TEX_COLOR_COMMAND_SUFFIX}" + result.append((cmd_span, repl_str)) return result - def get_content(self, is_labelled: bool) -> str: - if is_labelled: - extended_label_span_list = [] - script_spans_dict = dict(self.script_spans) - for span in self.label_span_list: - if span not in self.script_content_spans: - span_begin, span_end = span - while span_end in script_spans_dict.keys(): - span_end = script_spans_dict[span_end] - span = (span_begin, span_end) - extended_label_span_list.append(span) - inserted_string_pairs = [ - (span, ( - "{{" + self.get_color_command_str(label + 1), - "}}" - )) - for label, span in enumerate(extended_label_span_list) - ] - result = self.get_replaced_string( - inserted_string_pairs, self.command_repl_items - ) - else: - result = self.string + #def get_predefined_inserted_str_items( + # self, split_items: list[Span] + #) -> list[tuple[Span, tuple[str, str]]]: + # return [] + + #def get_ignorable_indices(self) -> list[int]: + # return self.chain( + # [ + # index + # for index, _ in self.find_spans(r"\s") + # ], + # [ + # index + # for index, _ in self.find_spans(r"[_^{}]") + # if (index - 1, index + 1) not in self.command_spans + # ], + # ) + + #def get_bracket_content_spans(self) -> list[Span]: + # span_begins = [] + # span_ends = [] + # span_begins_stack = [] + # for match_obj in re.finditer(r"[{}]", self.string): + # index = match_obj.start() + # if (index - 1, index + 1) in command_spans: + # continue + # if match_obj.group() == "{": + # span_begins_stack.append(index + 1) + # else: + # if not span_begins_stack: + # raise ValueError("Missing '{' inserted") + # span_begins.append(span_begins_stack.pop()) + # span_ends.append(index) + # if span_begins_stack: + # raise ValueError("Missing '}' inserted") + # return list(zip(span_begins, span_ends)) + + #def get_command_repl_items(self) -> list[tuple[Span, str]]: + # result = [] + # for cmd_span in self.command_spans: + # cmd_str = self.get_substr(cmd_span) + # if cmd_str in TEX_COLOR_COMMANDS_DICT: + # repl_str = f"{cmd_str}{TEX_COLOR_COMMAND_SUFFIX}" + # else: + # repl_str = cmd_str + # result.append((cmd_span, repl_str)) + # return result + + #def get_specified_spans(self) -> list[Span]: + # # Match paired double braces (`{{...}}`). + # sorted_content_spans = sorted( + # self.bracket_content_spans, key=lambda t: t[1] + # ) + # inner_content_spans = [ + # sorted_content_spans[range_begin] + # for _, (range_begin, range_end) in self.compress_neighbours([ + # (span_begin + index, span_end - index) + # for index, (span_begin, span_end) in enumerate( + # sorted_content_spans + # ) + # ]) + # if range_end - range_begin >= 2 + # ] + # #inner_content_spans = [ + # # (span_begin + 1, span_end - 1) + # # for span_begin, span_end in inner_brace_spans + # # if span_end - span_begin > 2 + # #] + + # return self.remove_redundancies(self.chain( + # inner_content_spans, + # *[ + # self.find_spans_by_selector(selector) + # for selector in self.tex_to_color_map.keys() + # ], + # self.find_spans_by_selector(self.isolate) + # )) + # #return list(filter( + # # lambda span: not any([ + # # entity_begin < index < entity_end + # # for index in span + # # for entity_begin, entity_end in self.command_spans + # # ]), + # # result + # #)) + + #def get_label_span_list(self) -> tuple[list[int], list[Span]]: + # script_entity_dict = dict(self.chain( + # [ + # (span_begin - 1, span_end + 1) + # for span_begin, span_end in self.bracket_content_spans + # ], + # self.command_spans + # )) + # script_items = [] + # for match_obj in re.finditer(r"\s*([_^])\s*(?=.)", self.string): + # char_index = match_obj.start(1) + # if (char_index - 1, char_index + 1) in self.command_spans: + # continue + # span_begin, content_begin = match_obj.span() + # span_end = script_entity_dict.get(span_begin, content_begin + 1) + # script_items.append( + # (span_begin, char_index, content_begin, span_end) + # ) + + # reversed_script_spans_dict = { + # span_end: span_begin + # for span_begin, _, _, span_end in script_items + # } + # ignorable_indices = self.chain( + # [index for index, _ in self.find_spans(r"\s")], + # [char_index for _, char_index, _, _ in script_items] + # ) + # result = [ + # (content_begin, span_end) + # for _, _, content_begin, span_end in script_items + # ] + # for span in self.specified_spans: + # span_begin, span_end = self.shrink_span(span, ignorable_indices) + # while span_end in reversed_script_spans_dict: + # span_end = reversed_script_spans_dict[span_end] + # if span_begin >= span_end: + # continue + # shrinked_span = (span_begin, span_end) + # if shrinked_span in result: + # continue + # result.append(shrinked_span) + # return result + + #def get_command_spans(self) -> list[Span]: + # return self.find_spans() + + #def get_command_repl_items(self) -> list[Span]: + # return [ + # (span, self.get_substr(span)) + # for span in self.find_spans(r"\\(?:[a-zA-Z]+|.)") + # ] + + #def get_command_spans(self) -> list[Span]: + # return self.find_spans(r"\\(?:[a-zA-Z]+|.)") + #return [ + # self.match(r"\\(?:[a-zA-Z]+|.)", pos=index).span() + # for index in self.backslash_indices + #] + + #@staticmethod + #def get_command_repl_dict() -> dict[str | re.Pattern, str]: + # return { + # cmd_name: f"{cmd_name}replaced" + # for cmd_name in TEX_COLOR_COMMANDS_DICT + # } + + #def get_backslash_indices(self) -> list[int]: + # # The latter of `\\` doesn't count. + # return self.find_indices(r"\\.") + + #def get_unescaped_char_indices(self, char: str) -> list[int]: + # return list(filter( + # lambda index: index - 1 not in self.backslash_indices, + # self.find_indices(re.escape(char)) + # )) + + #def get_script_items(self) -> list[tuple[int, int, int, int]]: + # script_entity_dict = dict(self.chain( + # self.brace_spans, + # self.command_spans + # )) + # result = [] + # for match_obj in re.finditer(r"\s*([_^])\s*(?=.)", self.string): + # char_index = match_obj.start(1) + # if char_index - 1 in self.backslash_indices: + # continue + # span_begin, content_begin = match_obj.span() + # span_end = script_entity_dict.get(span_begin, content_begin + 1) + # result.append((span_begin, char_index, content_begin, span_end)) + # return result + + #def get_script_char_indices(self) -> list[int]: + # return self.chain(*[ + # self.get_unescaped_char_indices(char) + # for char in "_^" + # ]) + + #def get_script_content_spans(self) -> list[Span]: + # result = [] + # script_entity_dict = dict(self.chain( + # self.brace_spans, + # self.command_spans + # )) + # for index in self.script_char_indices: + # span_begin = self.match(r"\s*", pos=index + 1).end() + # if span_begin in script_entity_dict.keys(): + # span_end = script_entity_dict[span_begin] + # else: + # match_obj = self.match(r".", pos=span_begin) + # if match_obj is None: + # continue + # span_end = match_obj.end() + # result.append((span_begin, span_end)) + # return result + + #def get_script_spans(self) -> list[Span]: + # return [ + # ( + # self.match(r"[\s\S]*?(\s*)$", endpos=index).start(1), + # script_content_span[1] + # ) + # for index, script_content_span in zip( + # self.script_char_indices, self.script_content_spans + # ) + # ] + + #def get_command_repl_items(self) -> list[tuple[Span, str]]: + # result = [] + # brace_spans_dict = dict(self.brace_spans) + # brace_begins = list(brace_spans_dict.keys()) + # for cmd_span in self.command_spans: + # cmd_name = self.get_substr(cmd_span) + # if cmd_name not in TEX_COLOR_COMMANDS_DICT: + # continue + # n_braces, substitute_cmd = TEX_COLOR_COMMANDS_DICT[cmd_name] + # span_begin, span_end = cmd_span + # for _ in range(n_braces): + # span_end = brace_spans_dict[min(filter( + # lambda index: index >= span_end, + # brace_begins + # ))] + # if substitute_cmd: + # repl_str = cmd_name + n_braces * "{black}" + # else: + # repl_str = "" + # result.append(((span_begin, span_end), repl_str)) + # return result + + #def get_inserted_string_pairs( + # self, is_labelled: bool + #) -> list[tuple[Span, tuple[str, str]]]: + # if not is_labelled: + # return [] + # return [ + # (span, ( + # "{{" + self.get_color_command_str(label + 1), + # "}}" + # )) + # for label, span in enumerate(self.label_span_list) + # ] + + def get_full_content_string( + self, + label_span_list: list[Span], + split_items: list[tuple[Span, dict[str, str]]], + is_labelled: bool + ) -> str: + result = super().get_full_content_string( + label_span_list, split_items, is_labelled + ) if self.tex_environment: if isinstance(self.tex_environment, str): @@ -295,9 +577,28 @@ class MTex(LabelledString): result = "\n".join([prefix, result, suffix]) if self.alignment: result = "\n".join([self.alignment, result]) - if not is_labelled: + + if is_labelled: + occurred_commands = [ + # TODO + self.get_substr(span) for span in self.entity_spans + ] + newcommand_lines = [ + "".join([ + f"\\newcommand{cmd_name}{TEX_COLOR_COMMAND_SUFFIX}", + f"[{n_braces + 1}][]", + "{", + cmd_name + "{black}" * n_braces if substitute_cmd else "", + "}" + ]) + for cmd_name, (n_braces, substitute_cmd) + in TEX_COLOR_COMMANDS_DICT.items() + if cmd_name in occurred_commands + ] + result = "\n".join([*newcommand_lines, result]) + else: result = "\n".join([ - self.get_color_command_str(self.base_color_int), + self.get_color_command_str(self.base_color_hex), result ]) return result @@ -305,41 +606,44 @@ class MTex(LabelledString): # Selector def get_cleaned_substr(self, span: Span) -> str: - left_brace_indices = [ - span_begin - for span_begin, _ in self.brace_spans - ] - right_brace_indices = [ - span_end - 1 - for _, span_end in self.brace_spans - ] - skippable_indices = self.chain( - self.find_indices(r"\s"), - self.script_char_indices, - left_brace_indices, - right_brace_indices - ) - shrinked_span = self.shrink_span(span, skippable_indices) + return self.get_substr(span) # TODO: test + #left_brace_indices = [ + # span_begin - 1 + # for span_begin, _ in self.brace_content_spans + #] + #right_brace_indices = [ + # span_end + # for _, span_end in self.brace_content_spans + #] + #skippable_indices = self.chain( + # self.ignorable_indices, + # #self.script_char_indices, + # left_brace_indices, + # right_brace_indices + #) + #shrinked_span = self.shrink_span(span, skippable_indices) - if shrinked_span[0] >= shrinked_span[1]: - return "" + ##if shrinked_span[0] >= shrinked_span[1]: + ## return "" - # Balance braces. - unclosed_left_braces = 0 - unclosed_right_braces = 0 - for index in range(*shrinked_span): - if index in left_brace_indices: - unclosed_left_braces += 1 - elif index in right_brace_indices: - if unclosed_left_braces == 0: - unclosed_right_braces += 1 - else: - unclosed_left_braces -= 1 - return "".join([ - unclosed_right_braces * "{", - self.get_substr(shrinked_span), - unclosed_left_braces * "}" - ]) + ## Balance braces. + #unclosed_left_braces = 0 + #unclosed_right_braces = 0 + #for index in range(*shrinked_span): + # if index in left_brace_indices: + # unclosed_left_braces += 1 + # elif index in right_brace_indices: + # if unclosed_left_braces == 0: + # unclosed_right_braces += 1 + # else: + # unclosed_left_braces -= 1 + ##adjusted_span, unclosed_left_braces, unclosed_right_braces \ + ## = self.adjust_span(span, align_level=False) + #return "".join([ + # unclosed_right_braces * "{", + # self.get_substr(shrinked_span), + # unclosed_left_braces * "}" + #]) # Method alias diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 0fe113ff..3b07a07e 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -50,13 +50,16 @@ DEFAULT_CANVAS_HEIGHT = 16384 # See https://docs.gtk.org/Pango/pango_markup.html -MARKUP_COLOR_KEYS = ( - "foreground", "fgcolor", "color", - "background", "bgcolor", - "underline_color", - "overline_color", - "strikethrough_color" -) +MARKUP_COLOR_KEYS_DICT = { + "foreground": False, + "fgcolor": False, + "color": False, + "background": True, + "bgcolor": True, + "underline_color": True, + "overline_color": True, + "strikethrough_color": True, +} MARKUP_TAG_CONVERSION_DICT = { "b": {"font_weight": "bold"}, "big": {"font_size": "larger"}, @@ -66,8 +69,17 @@ MARKUP_TAG_CONVERSION_DICT = { "sup": {"baseline_shift": "superscript", "font_scale": "superscript"}, "small": {"font_size": "smaller"}, "tt": {"font_family": "monospace"}, - "u": {"underline": "single"} + "u": {"underline": "single"}, } +# See https://gitlab.gnome.org/GNOME/glib/-/blob/main/glib/gmarkup.c +# Line 629, 2204 +XML_ENTITIES = ( + ("<", "<"), + (">", ">"), + ("&", "&"), + ("\"", """), + ("'", "'") +) # Temporary handler @@ -223,28 +235,47 @@ class MarkupText(LabelledString): f"{validate_error}" ) - def parse(self) -> None: - self.global_attr_dict = self.get_global_attr_dict() - self.tag_pairs_from_markup = self.get_tag_pairs_from_markup() - self.tag_spans = self.get_tag_spans() - self.items_from_markup = self.get_items_from_markup() - self.specified_items = self.get_specified_items() - super().parse() + #def parse(self) -> None: + # #self.global_attr_dict = self.get_global_attr_dict() + # #self.items_from_markup = self.get_items_from_markup() + # #self.tag_spans = self.get_tag_spans() + # ##self.items_from_markup = self.get_items_from_markup() + # #self.specified_items = self.get_specified_items() + # super().parse() + + #@property + #def sort_labelled_submobs(self) -> bool: + # return True # Toolkits @staticmethod - def get_attr_dict_str(attr_dict: dict[str, str]) -> str: - return " ".join([ + def get_tag_str( + attr_dict: dict[str, str], escape_color_keys: bool, is_begin_tag: bool + ) -> str: + if not is_begin_tag: + return "" + if escape_color_keys: + converted_attr_dict = {} + for key, val in attr_dict.items(): + substitute_key = MARKUP_COLOR_KEYS_DICT.get(key.lower(), None) + if substitute_key is None: + converted_attr_dict[key] = val + elif substitute_key: + converted_attr_dict[key] = "black" + else: + converted_attr_dict[key] = "black" + else: + converted_attr_dict = attr_dict.copy() + result = " ".join([ f"{key}='{val}'" - for key, val in attr_dict.items() + for key, val in converted_attr_dict.items() ]) - - # Parsing + return f"" def get_global_attr_dict(self) -> dict[str, str]: result = { - "foreground": self.int_to_hex(self.base_color_int), + "foreground": self.base_color_hex, "font_family": self.font, "font_style": self.slant, "font_weight": self.weight, @@ -263,60 +294,227 @@ class MarkupText(LabelledString): result["line_height"] = str(((line_spacing_scale) + 1) * 0.6) return result - def get_tag_pairs_from_markup( - self - ) -> list[tuple[Span, Span, dict[str, str]]]: + # Parsing + + #def parse(self) -> None: + # self.bracket_content_spans, self.command_repl_items \ + # = self.get_items_from_markup() + # #self.bracket_content_spans = [ + # # span for span, _ in items_from_markup + # #] + # #specified_items = self.get_specified_items() + # #self.command_repl_items = self.get_command_repl_items() + # #self.specified_spans = self.remove_redundancies([ + # # span for span, _ in specified_items + # #]) + # #self.label_span_list = self.get_label_span_list() + # #self.predefined_items = [ + # # (self.full_span, self.get_global_attr_dict()), + # # (self.full_span, self.global_config), + # # *specified_items + # #] + + #def parse(self) -> None: # TODO: type + # if not self.is_markup: + # return [], [], [ + # (span, (escaped, escaped)) + # for char, escaped in XML_ENTITIES + # for span in self.find_spans(re.escape(char)) + # ] + + #self.entity_spans = self.find_spans(r"&[\s\S]*?;") + + #tag_spans = [span for span, _ in command_repl_items] + #begin_tag_spans = [ + # begin_tag_span for begin_tag_span, _, _ in markup_tag_items + #] + #end_tag_spans = [ + # end_tag_span for _, end_tag_span, _ in markup_tag_items + #] + #tag_spans = self.chain(begin_tag_spans, end_tag_spans) + #command_repl_items = [ + # (tag_span, "") for tag_span in tag_spans + #] + #self.chain( + # [ + # (begin_tag_span, ( + # f"", + # f"" + # )) + # for begin_tag_span, _, attr_dict in markup_tag_items + # ], + # [ + # (end_tag_span, ("", "")) + # for _, end_tag_span, _ in markup_tag_items + # ] + #) + #self.piece_spans, self.piece_levels = self.init_piece_items( + # begin_tag_spans, end_tag_spans, self.find_spans(r"&[\s\S]*?;") + #) + #command_repl_items.extend([ + # (span, (self.get_substr(span), self.get_substr(span))) + # for span in self.find_spans(r"&[\s\S]*?;") + #]) + # Needed in plain text + + #specified_items = self.chain( + # [ + # ((span_begin, span_end), attr_dict) + # for (_, span_begin), (span_end, _), attr_dict + # in markup_tag_items + # ], + # self.get_specified_items() + #) + #specified_spans = self.remove_redundancies([ + # span for span, _ in specified_items + #]) + #specified_items = [] + #for span, attr_dict in all_specified_items: + # for + # adjusted_span, _, _ = self.adjust_span(span, align_level=True) + # if adjusted_span[0] > adjusted_span[1]: + # continue + # specified_items.append(adjusted_span, attr_dict) + + + #predefined_items = [ + # (self.full_span, self.get_global_attr_dict()), + # (self.full_span, self.global_config), + # *split_items + #] + #inserted_str_items = self.chain( + # [ + # (span, ( + # ( + # f"", + # f"" + # ), + # ("", "") + # )) + # for span, attr_dict in predefined_items + # ], + # [ + # (span, ( + # ("", f""), + # ("", ""), + # )) + # for label, span in enumerate(label_span_list) + # ] + #) + #command_repl_items = self.chain( + # [ + # (tag_span, ("", "")) for tag_span in self.tag_spans + # ], + # [ + # ((index, index), str_pair) + # for index, str_pair in self.sort_obj_pairs_by_spans(inserted_str_items) + # ] + #) + #decorated_strings = [ + # self.get_replaced_substr(self.full_span, [ + # (span, str_pair[flag]) + # for span, str_pair in command_repl_items + # ]) + # for flag in range(2) + #] + #return specified_spans, label_span_list, decorated_strings + + + + + + #if is_labelled: + # attr_dict_items = self.chain( + # [ + # (span, { + # key: + # "black" if key.lower() in MARKUP_COLOR_KEYS else val + # for key, val in attr_dict.items() + # }) + # for span, attr_dict in self.predefined_items + # ], + # [ + # (span, {"foreground": self.int_to_hex(label + 1)}) + # for label, span in enumerate(self.label_span_list) + # ] + # ) + #else: + # attr_dict_items = self.chain( + # self.predefined_items, + # [ + # (span, {}) + # for span in self.label_span_list + # ] + # ) + #return [ + # (span, ( + # f"", + # "" + # )) + # for span, attr_dict in attr_dict_items + #] + #inserted_string_pairs = [ + # (indices, str_pair) + # for indices, str_pair in self.get_inserted_string_pairs( + # is_labelled=is_labelled + # ) + # if not any( + # cmd_begin < index < cmd_end + # for index in indices + # for (cmd_begin, cmd_end), _ in self.command_repl_items + # ) + #] + #return bracket_content_spans, label_span_list, command_repl_items + + def get_entity_spans(self) -> list[Span]: if not self.is_markup: return [] + return self.find_spans(r"&[\s\S]*?;") + + def get_internal_items( + self + ) -> tuple[list[tuple[Span, Span]], list[tuple[Span, dict[str, str]]]]: + if not self.is_markup: + return [], [] tag_pattern = r"""<(/?)(\w+)\s*((\w+\s*\=\s*(['"])[\s\S]*?\5\s*)*)>""" attr_pattern = r"""(\w+)\s*\=\s*(['"])([\s\S]*?)\2""" begin_match_obj_stack = [] - match_obj_pairs = [] + markup_tag_items = [] for match_obj in re.finditer(tag_pattern, self.string): if not match_obj.group(1): begin_match_obj_stack.append(match_obj) - else: - match_obj_pairs.append( - (begin_match_obj_stack.pop(), match_obj) - ) - - result = [] - for begin_match_obj, end_match_obj in match_obj_pairs: + continue + begin_match_obj = begin_match_obj_stack.pop() tag_name = begin_match_obj.group(2) if tag_name == "span": attr_dict = { - match.group(1): match.group(3) - for match in re.finditer( + attr_match_obj.group(1): attr_match_obj.group(3) + for attr_match_obj in re.finditer( attr_pattern, begin_match_obj.group(3) ) } else: attr_dict = MARKUP_TAG_CONVERSION_DICT.get(tag_name, {}) - - result.append( - (begin_match_obj.span(), end_match_obj.span(), attr_dict) + markup_tag_items.append( + (begin_match_obj.span(), match_obj.span(), attr_dict) ) - return result - def get_tag_spans(self) -> list[Span]: - return self.chain( - (begin_tag_span, end_tag_span) - for begin_tag_span, end_tag_span, _ in self.tag_pairs_from_markup - ) - - def get_items_from_markup(self) -> list[tuple[Span, dict[str, str]]]: - return [ - ((span_begin, span_end), attr_dict) - for (_, span_begin), (span_end, _), attr_dict - in self.tag_pairs_from_markup - if span_begin < span_end + tag_span_pairs = [ + (tag_begin_span, tag_end_span) + for tag_begin_span, tag_end_span, _ in markup_tag_items ] + internal_items = [ + ((span_begin, span_end), attr_dict) + for (_, span_begin), (span_end, _), attr_dict in markup_tag_items + ] + return tag_span_pairs, internal_items - def get_specified_items(self) -> list[tuple[Span, dict[str, str]]]: - result = self.chain( - self.items_from_markup, - [ + def get_external_items(self) -> list[tuple[Span, dict[str, str]]]: + return [ + (self.full_span, self.get_global_attr_dict()), + (self.full_span, self.global_config), + *[ (span, {key: val}) for t2x_dict, key in ( (self.t2c, "foreground"), @@ -327,60 +525,37 @@ class MarkupText(LabelledString): for selector, val in t2x_dict.items() for span in self.find_spans_by_selector(selector) ], - [ + *[ (span, local_config) for selector, local_config in self.local_configs.items() for span in self.find_spans_by_selector(selector) - ], - [ - (span, {}) - for span in self.find_spans_by_selector(self.isolate) ] - ) - entity_spans = self.tag_spans.copy() - if self.is_markup: - entity_spans.extend(self.find_spans(r"&[\s\S]*?;")) - return [ - (span, attr_dict) - for span, attr_dict in result - if not any([ - entity_begin < index < entity_end - for index in span - for entity_begin, entity_end in entity_spans - ]) ] - def get_command_repl_items(self) -> list[tuple[Span, str]]: - result = [ - (tag_span, "") for tag_span in self.tag_spans - ] - if not self.is_markup: - result.extend([ - (span, escaped) - for char, escaped in ( - ("&", "&"), - (">", ">"), - ("<", "<") - ) - for span in self.find_spans(re.escape(char)) - ]) - return result + #def get_spans_from_items( + # self, specified_items: list[tuple[Span, dict[str, str]]] + #) -> list[Span]: + # return [span for span, _ in specified_items] - def get_specified_spans(self) -> list[Span]: - return self.remove_redundancies([ - span for span, _ in self.specified_items - ]) + #def get_split_items( + # self, specified_items: list[tuple[Span, dict[str, str]]] + #) -> list[tuple[Span, dict[str, str]]]: + # return [ + # (span, attr_dict) + # for specified_span, attr_dict in specified_items + # for span in self.split_span(specified_span) + # ] - def get_label_span_list(self) -> list[Span]: + def get_label_span_list(self, split_spans: list[Span]) -> list[Span]: interval_spans = sorted(self.chain( self.tag_spans, [ (index, index) - for span in self.specified_spans + for span in split_spans for index in span ] )) - text_spans = self.get_complement_spans(interval_spans, self.full_span) + text_spans = self.get_complement_spans(self.full_span, interval_spans) if self.is_markup: pattern = r"[0-9a-zA-Z]+|(?:&[\s\S]*?;|[^0-9a-zA-Z\s])+" else: @@ -390,54 +565,209 @@ class MarkupText(LabelledString): for span_begin, span_end in text_spans ]) + def get_additional_inserted_str_pairs( + self + ) -> list[tuple[Span, tuple[str, str]]]: + return [] + + def get_command_repl_items(self, is_labelled: bool) -> list[Span, str]: + result = [ + (tag_span, "") for tag_span in self.tag_spans + ] + if not self.is_markup: + result.extend([ + (span, escaped) + for char, escaped in XML_ENTITIES + for span in self.find_spans(re.escape(char)) + ]) + return result + + #def get_predefined_inserted_str_items( + # self, split_items: list[tuple[Span, dict[str, str]]] + #) -> list[tuple[Span, tuple[str, str]]]: + # predefined_items = [ + # (self.full_span, self.get_global_attr_dict()), + # (self.full_span, self.global_config), + # *split_items + # ] + # return [ + # (span, ( + # ( + # self.get_tag_str(attr_dict, escape_color_keys=False, is_begin_tag=True), + # self.get_tag_str(attr_dict, escape_color_keys=True, is_begin_tag=True) + # ), + # ( + # self.get_tag_str(attr_dict, escape_color_keys=False, is_begin_tag=False), + # self.get_tag_str(attr_dict, escape_color_keys=True, is_begin_tag=False) + # ) + # )) + # for span, attr_dict in predefined_items + # ] + + #def get_full_content_string(self, replaced_string: str) -> str: + # return replaced_string + + #def get_tag_spans(self) -> list[Span]: + # return self.chain( + # (begin_tag_span, end_tag_span) + # for begin_tag_span, end_tag_span, _ in self.items_from_markup + # ) + + #def get_items_from_markup(self) -> list[tuple[Span, dict[str, str]]]: + # return [ + # ((span_begin, span_end), attr_dict) + # for (_, span_begin), (span_end, _), attr_dict + # in self.items_from_markup + # if span_begin < span_end + # ] + + #def get_command_repl_items(self) -> list[tuple[Span, str]]: + # result = [ + # (tag_span, "") + # for tag_span in self.tag_spans + # ] + # if self.is_markup: + # result.extend([ + # (span, self.get_substr(span)) + # for span in self.find_spans(r"&[\s\S]*?;") + # ]) + # else: + # result.extend([ + # (span, escaped) + # for char, escaped in ( + # ("&", "&"), + # (">", ">"), + # ("<", "<") + # ) + # for span in self.find_spans(re.escape(char)) + # ]) + # return result + + #def get_command_spans(self) -> list[Span]: + # result = self.tag_spans.copy() + # if self.is_markup: + # result.extend(self.find_spans(r"&[\s\S]*?;")) + # else: + # result.extend(self.find_spans(r"[&<>]")) + # return result + + #@staticmethod + #def get_command_repl_dict() -> dict[str | re.Pattern, str]: + # return { + # re.compile(r"<.*>"): "", + # "&": "&", + # "<": "<", + # ">": ">" + # } + # #result = [ + # # (tag_span, "") for tag_span in self.tag_spans + # #] + # #if self.is_markup: + # # result.extend([ + # # (span, self.get_substr(span)) + # # for span in self.find_spans(r"&[\s\S]*?;") + # # ]) + # #else: + # # result.extend([ + # # (span, escaped) + # # for char, escaped in ( + # # ("&", "&"), + # # (">", ">"), + # # ("<", "<") + # # ) + # # for span in self.find_spans(re.escape(char)) + # # ]) + # #return result + #entity_spans = self.tag_spans.copy() + #if self.is_markup: + # entity_spans.extend(self.find_spans(r"&[\s\S]*?;")) + #return [ + # (span, attr_dict) + # for span, attr_dict in result + # if not self.span_cuts_at_entity(span) + # #if not any([ + # # entity_begin < index < entity_end + # # for index in span + # # for entity_begin, entity_end in entity_spans + # #]) + #] + + #def get_specified_spans(self) -> list[Span]: + # return self.remove_redundancies([ + # span for span, _ in self.specified_items + # ]) + + #def get_label_span_list(self) -> list[Span]: + # interval_spans = sorted(self.chain( + # self.tag_spans, + # [ + # (index, index) + # for span in self.specified_spans + # for index in span + # ] + # )) + # text_spans = self.get_complement_spans(interval_spans, self.full_span) + # if self.is_markup: + # pattern = r"[0-9a-zA-Z]+|(?:&[\s\S]*?;|[^0-9a-zA-Z\s])+" + # else: + # pattern = r"[0-9a-zA-Z]+|[^0-9a-zA-Z\s]+" + # return self.chain(*[ + # self.find_spans(pattern, pos=span_begin, endpos=span_end) + # for span_begin, span_end in text_spans + # ]) + + #def get_inserted_string_pairs( + # self, is_labelled: bool + #) -> list[tuple[Span, tuple[str, str]]]: + # #predefined_items = [ + # # (self.full_span, self.global_attr_dict), + # # (self.full_span, self.global_config), + # # *self.specified_items + # #] + # if is_labelled: + # attr_dict_items = self.chain( + # [ + # (span, { + # key: + # "black" if key.lower() in MARKUP_COLOR_KEYS else val + # for key, val in attr_dict.items() + # }) + # for span, attr_dict in self.predefined_items + # ], + # [ + # (span, {"foreground": self.int_to_hex(label + 1)}) + # for label, span in enumerate(self.label_span_list) + # ] + # ) + # else: + # attr_dict_items = self.chain( + # self.predefined_items, + # [ + # (span, {}) + # for span in self.label_span_list + # ] + # ) + # return [ + # (span, ( + # f"", + # "" + # )) + # for span, attr_dict in attr_dict_items + # ] + def get_content(self, is_labelled: bool) -> str: - predefined_items = [ - (self.full_span, self.global_attr_dict), - (self.full_span, self.global_config), - *self.specified_items - ] - if is_labelled: - attr_dict_items = self.chain( - [ - (span, { - key: - "black" if key.lower() in MARKUP_COLOR_KEYS else val - for key, val in attr_dict.items() - }) - for span, attr_dict in predefined_items - ], - [ - (span, {"foreground": self.int_to_hex(label + 1)}) - for label, span in enumerate(self.label_span_list) - ] - ) - else: - attr_dict_items = self.chain( - predefined_items, - [ - (span, {}) - for span in self.label_span_list - ] - ) - inserted_string_pairs = [ - (span, ( - f"", - "" - )) - for span, attr_dict in attr_dict_items if attr_dict - ] - return self.get_replaced_string( - inserted_string_pairs, self.command_repl_items - ) + return self.decorated_strings[is_labelled] # Selector def get_cleaned_substr(self, span: Span) -> str: - repl_items = list(filter( - lambda repl_item: self.span_contains(span, repl_item[0]), - self.command_repl_items - )) - return self.get_replaced_substr(span, repl_items).strip() + return self.get_substr(span) # TODO: test + #repl_items = [ + # (cmd_span, repl_str) + # for cmd_span, (repl_str, _) in self.command_repl_items + # if self.span_contains(span, cmd_span) + #] + #return self.get_replaced_substr(span, repl_items).strip() # Method alias From ab8f78f40fc68c1d8d940dc4bf835284c3355793 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Tue, 3 May 2022 23:39:37 +0800 Subject: [PATCH 05/11] [WIP] Refactor LabelledString and relevant classes --- manimlib/animation/creation.py | 5 +- .../animation/transform_matching_parts.py | 64 ++-- manimlib/mobject/svg/labelled_string.py | 303 ++++++++++-------- manimlib/mobject/svg/mtex_mobject.py | 179 +++++------ manimlib/mobject/svg/text_mobject.py | 96 +++--- 5 files changed, 348 insertions(+), 299 deletions(-) diff --git a/manimlib/animation/creation.py b/manimlib/animation/creation.py index 6ad6a9bd..42ca4bf8 100644 --- a/manimlib/animation/creation.py +++ b/manimlib/animation/creation.py @@ -213,8 +213,9 @@ class AddTextWordByWord(ShowIncreasingSubsets): def __init__(self, string_mobject, **kwargs): assert isinstance(string_mobject, LabelledString) - grouped_mobject = VGroup(*[ - part for _, part in string_mobject.get_group_part_items() + grouped_mobject = string_mobject.build_parts_from_indices_lists([ + indices_list + for _, indices_list in string_mobject.get_group_part_items() ]) digest_config(self, kwargs) if self.run_time is None: diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index 96fd95ce..afa7e1ab 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -168,73 +168,67 @@ class TransformMatchingStrings(AnimationGroup): assert isinstance(source, LabelledString) assert isinstance(target, LabelledString) anims = [] + source_indices = list(range(len(source.labels))) + target_indices = list(range(len(target.labels))) - source_submobs = [ - submob for _, submob in source.labelled_submobject_items - ] - target_submobs = [ - submob for _, submob in target.labelled_submobject_items - ] - source_indices = list(range(len(source_submobs))) - target_indices = list(range(len(target_submobs))) - - def get_filtered_indices_lists(parts, submobs, rest_indices): + def get_filtered_indices_lists(indices_lists, rest_indices): return list(filter( - lambda indices_list: all([ + lambda indices_list: all( index in rest_indices for index in indices_list - ]), - [ - [submobs.index(submob) for submob in part] - for part in parts - ] + ), + indices_lists )) - def add_anims(anim_class, parts_pairs): - for source_parts, target_parts in parts_pairs: + def add_anims(anim_class, indices_lists_pairs): + for source_indices_lists, target_indices_lists in indices_lists_pairs: source_indices_lists = get_filtered_indices_lists( - source_parts, source_submobs, source_indices + source_indices_lists, source_indices ) target_indices_lists = get_filtered_indices_lists( - target_parts, target_submobs, target_indices + target_indices_lists, target_indices ) if not source_indices_lists or not target_indices_lists: continue - anims.append(anim_class(source_parts, target_parts, **kwargs)) + anims.append(anim_class( + source.build_parts_from_indices_lists(source_indices_lists), + target.build_parts_from_indices_lists(target_indices_lists), + **kwargs + )) for index in it.chain(*source_indices_lists): source_indices.remove(index) for index in it.chain(*target_indices_lists): target_indices.remove(index) - def get_substr_to_parts_map(part_items): + def get_substr_to_indices_lists_map(part_items): result = {} - for substr, part in part_items: + for substr, indices_list in part_items: if substr not in result: result[substr] = [] - result[substr].append(part) + result[substr].append(indices_list) return result def add_anims_from(anim_class, func): - source_substr_to_parts_map = get_substr_to_parts_map(func(source)) - target_substr_to_parts_map = get_substr_to_parts_map(func(target)) + source_substr_map = get_substr_to_indices_lists_map(func(source)) + target_substr_map = get_substr_to_indices_lists_map(func(target)) + common_substrings = sorted([ + s for s in source_substr_map if s and s in target_substr_map + ], key=len, reverse=True) add_anims( anim_class, [ - ( - VGroup(*source_substr_to_parts_map[substr]), - VGroup(*target_substr_to_parts_map[substr]) - ) - for substr in sorted([ - s for s in source_substr_to_parts_map - if s and s in target_substr_to_parts_map - ], key=len, reverse=True) + (source_substr_map[substr], target_substr_map[substr]) + for substr in common_substrings ] ) add_anims( ReplacementTransform, [ - (source.select_parts(k), target.select_parts(v)) + ( + source.get_submob_indices_lists_by_selector(k), + target.get_submob_indices_lists_by_selector(v) + ) for k, v in self.key_map.items() ] ) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index 23c285de..ef26ecf1 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -5,6 +5,7 @@ import itertools as it import re from manimlib.constants import WHITE +from manimlib.logger import log from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.utils.color import color_to_rgb @@ -63,6 +64,7 @@ class LabelledString(SVGMobject, ABC): (submob.label, submob) for submob in self.submobjects ] + self.labels = [label for label, _ in self.labelled_submobject_items] def get_file_path(self) -> str: return self.get_file_path_by_content(self.original_content) @@ -78,26 +80,30 @@ class LabelledString(SVGMobject, ABC): labelled_svg = SVGMobject(file_path) num_submobjects = len(self.submobjects) if num_submobjects != len(labelled_svg.submobjects): - raise ValueError( + log.warning( "Cannot align submobjects of the labelled svg " - "to the original svg" - ) - - submob_color_ints = [ - self.hex_to_int(self.color_to_hex(submob.get_fill_color())) - for submob in labelled_svg.submobjects - ] - unrecognized_color_ints = self.remove_redundancies(sorted(filter( - lambda color_int: color_int > len(self.label_span_list), - submob_color_ints - ))) - if unrecognized_color_ints: - raise ValueError( - "Unrecognized color label(s) detected: " - f"{', '.join(map(self.int_to_hex, unrecognized_color_ints))}" + "to the original svg. Skip the labelling process." ) + submob_color_ints = [0] * num_submobjects + else: + submob_color_ints = [ + self.hex_to_int(self.color_to_hex(submob.get_fill_color())) + for submob in labelled_svg.submobjects + ] + unrecognized_colors = list(filter( + lambda color_int: color_int > len(self.labelled_spans), + submob_color_ints + )) + if unrecognized_colors: + log.warning( + "Unrecognized color label(s) detected (%s, etc). " + "Skip the labelling process.", + self.int_to_hex(unrecognized_colors[0]) + ) + submob_color_ints = [0] * num_submobjects #if self.sort_labelled_submobs: + # TODO: remove this submob_indices = sorted( range(num_submobjects), key=lambda index: tuple( @@ -135,12 +141,10 @@ class LabelledString(SVGMobject, ABC): # pattern = re.compile(pattern) # return re.compile(pattern).match(self.string, **kwargs) - def find_spans(self, pattern: str | re.Pattern, **kwargs) -> list[Span]: - if isinstance(pattern, str): - pattern = re.compile(pattern) + def find_spans(self, pattern: str) -> list[Span]: return [ match_obj.span() - for match_obj in pattern.finditer(self.string, **kwargs) + for match_obj in re.finditer(pattern, self.string) ] #def find_indices(self, pattern: str | re.Pattern, **kwargs) -> list[int]: @@ -151,7 +155,18 @@ class LabelledString(SVGMobject, ABC): if isinstance(sel, str): return self.find_spans(re.escape(sel)) if isinstance(sel, re.Pattern): - return self.find_spans(sel) + result_iterator = sel.finditer(self.string) + if not sel.groups: + return [ + match_obj.span() + for match_obj in result_iterator + ] + return [ + span + for match_obj in result_iterator + for span in match_obj.regs[1:] + if span != (-1, -1) + ] if isinstance(sel, tuple) and len(sel) == 2 and all( isinstance(index, int) or index is None for index in sel @@ -225,7 +240,7 @@ class LabelledString(SVGMobject, ABC): def span_contains(span_0: Span, span_1: Span) -> bool: return span_0[0] <= span_1[0] and span_0[1] >= span_1[1] - def get_piece_items( + def get_level_items( self, tag_span_pairs: list[tuple[Span, Span]], entity_spans: list[Span] @@ -241,7 +256,7 @@ class LabelledString(SVGMobject, ABC): piece_levels = [0, *it.accumulate([tag for _, tag in tagged_items])] return piece_spans, piece_levels - def split_span(self, arbitrary_span: Span) -> list[Span]: + def split_span_by_levels(self, arbitrary_span: Span) -> list[Span]: # ignorable_indices -- # left_bracket_spans # right_bracket_spans @@ -413,10 +428,10 @@ class LabelledString(SVGMobject, ABC): @staticmethod @abstractmethod - def get_tag_str( - attr_dict: dict[str, str], escape_color_keys: bool, is_begin_tag: bool - ) -> str: - return "" + def get_tag_string_pair( + attr_dict: dict[str, str], label_hex: str | None + ) -> tuple[str, str]: + return ("", "") #def get_color_tag_str(self, rgb_int: int, is_begin_tag: bool) -> str: # return self.get_tag_str({ @@ -481,7 +496,7 @@ class LabelledString(SVGMobject, ABC): def parse(self) -> None: self.entity_spans = self.get_entity_spans() tag_span_pairs, internal_items = self.get_internal_items() - self.piece_spans, self.piece_levels = self.get_piece_items( + self.piece_spans, self.piece_levels = self.get_level_items( tag_span_pairs, self.entity_spans ) #self.tag_content_spans = [ @@ -497,26 +512,19 @@ class LabelledString(SVGMobject, ABC): for span in self.find_spans_by_selector(self.isolate) ] ) - print(f"\n{specified_items=}\n") - specified_spans = [span for span, _ in specified_items] - for span_0, span_1 in it.product(specified_spans, repeat=2): - if not span_0[0] < span_1[0] < span_0[1] < span_1[1]: - continue - raise ValueError( - "Partially overlapping substrings detected: " - f"'{self.get_substr(span_0)}' and '{self.get_substr(span_1)}'" - ) + #print(f"\n{specified_items=}\n") + #specified_spans = split_items = [ (span, attr_dict) for specified_span, attr_dict in specified_items - for span in self.split_span(specified_span) + for span in self.split_span_by_levels(specified_span) ] - print(f"\n{split_items=}\n") - split_spans = [span for span, _ in split_items] - label_span_list = self.get_label_span_list(split_spans) - if len(label_span_list) >= 16777216: - raise ValueError("Cannot handle that many substrings") + #print(f"\n{split_items=}\n") + #labelled_spans = [span for span, _ in split_items] + #labelled_spans = self.get_labelled_spans(split_spans) + #if len(labelled_spans) >= 16777216: + # raise ValueError("Cannot handle that many substrings") #content_strings = [] #for is_labelled in (False, True): @@ -549,17 +557,66 @@ class LabelledString(SVGMobject, ABC): # for flag in range(2) #] - self.specified_spans = specified_spans - self.label_span_list = label_span_list - self.original_content = self.get_full_content_string( - label_span_list, split_items, is_labelled=False + command_repl_items = self.get_command_repl_items() + + #full_content_strings = {} + #for is_labelled in (False, True): + # inserted_str_pairs = [ + # (span, self.get_tag_string_pair( + # attr_dict, + # rgb_hex=self.int_to_hex(label + 1) if is_labelled else None + # )) + # for label, (span, attr_dict) in enumerate(split_items) + # ] + # repl_items = self.chain( + # command_repl_items, + # [ + # ((index, index), inserted_str) + # for index, inserted_str + # in self.sort_obj_pairs_by_spans(inserted_str_pairs) + # ] + # ) + # content_string = self.get_replaced_substr( + # self.full_span, repl_items + # ) + # full_content_string = self.get_full_content_string(content_string) + # #full_content_strings[is_labelled] = full_content_string + + self.specified_spans = [span for span, _ in specified_items] + self.labelled_spans = [span for span, _ in split_items] + for span_0, span_1 in it.product(self.labelled_spans, repeat=2): + if not span_0[0] < span_1[0] < span_0[1] < span_1[1]: + continue + raise ValueError( + "Partially overlapping substrings detected: " + f"'{self.get_substr(span_0)}' and '{self.get_substr(span_1)}'" + ) + + self.original_content, self.labelled_content = ( + self.get_full_content_string(self.get_replaced_substr( + self.full_span, self.chain( + command_repl_items, + [ + ((index, index), inserted_str) + for index, inserted_str in self.sort_obj_pairs_by_spans([ + (span, self.get_tag_string_pair( + attr_dict, + label_hex=self.int_to_hex(label + 1) if is_labelled else None + )) + for label, (span, attr_dict) in enumerate(split_items) + ]) + ] + ) + ), is_labelled=is_labelled) + for is_labelled in (False, True) ) - self.labelled_content = self.get_full_content_string( - label_span_list, split_items, is_labelled=True - ) - print(self.original_content) - print() - print(self.labelled_content) + + + #self.original_content = full_content_strings[False] + #self.labelled_content = full_content_strings[True] + #print(self.original_content) + #print() + #print(self.labelled_content) #self.command_repl_dict = self.get_command_repl_dict() @@ -569,8 +626,8 @@ class LabelledString(SVGMobject, ABC): ##self.specified_items = self.get_specified_items() #self.specified_spans = [] #self.check_overlapping() ####### - #self.label_span_list = [] - #if len(self.label_span_list) >= 16777216: + #self.labelled_spans = [] + #if len(self.labelled_spans) >= 16777216: # raise ValueError("Cannot handle that many substrings") @abstractmethod @@ -636,9 +693,9 @@ class LabelledString(SVGMobject, ABC): #def get_split_items(self, specified_items: list[T]) -> list[T]: # return [] - @abstractmethod - def get_label_span_list(self, split_spans: list[Span]) -> list[Span]: - return [] + #@abstractmethod + #def get_labelled_spans(self, split_spans: list[Span]) -> list[Span]: + # return [] #@abstractmethod #def get_predefined_inserted_str_items( @@ -666,7 +723,7 @@ class LabelledString(SVGMobject, ABC): # return [] #@abstractmethod - #def get_label_span_list(self) -> list[Span]: + #def get_labelled_spans(self) -> list[Span]: # return [] #def get_decorated_string( @@ -694,56 +751,19 @@ class LabelledString(SVGMobject, ABC): # repl_items.extend(self.command_repl_items) # return self.get_replaced_substr(self.full_span, repl_items) + #@abstractmethod + #def get_additional_inserted_str_pairs( + # self + #) -> list[tuple[Span, tuple[str, str]]]: + # return [] + @abstractmethod - def get_additional_inserted_str_pairs( - self - ) -> list[tuple[Span, tuple[str, str]]]: + def get_command_repl_items(self) -> list[Span, str]: return [] @abstractmethod - def get_command_repl_items(self, is_labelled: bool) -> list[Span, str]: - return [] - - def get_full_content_string( - self, - label_span_list: list[Span], - split_items: list[tuple[Span, dict[str, str]]], - is_labelled: bool - ) -> str: - label_items = [ - (span, { - "foreground": self.int_to_hex(label + 1) - } if is_labelled else {}) - for label, span in enumerate(label_span_list) - ] - inserted_str_pairs = self.chain( - self.get_additional_inserted_str_pairs(), - [ - (span, tuple( - self.get_tag_str( - attr_dict, - escape_color_keys=is_labelled and not is_label_item, - is_begin_tag=is_begin_tag - ) - for is_begin_tag in (True, False) - )) - for is_label_item, items in enumerate(( - split_items, label_items - )) - for span, attr_dict in items - ] - ) - repl_items = self.chain( - self.get_command_repl_items(is_labelled), - [ - ((index, index), inserted_str) - for index, inserted_str - in self.sort_obj_pairs_by_spans(inserted_str_pairs) - ] - ) - return self.get_replaced_substr( - self.full_span, repl_items - ) + def get_full_content_string(self, content_string: str, is_labelled: bool) -> str: + return "" #def get_content(self, is_labelled: bool) -> str: # return self.content_strings[int(is_labelled)] @@ -754,16 +774,15 @@ class LabelledString(SVGMobject, ABC): def get_cleaned_substr(self, span: Span) -> str: return "" - def get_group_part_items(self) -> list[tuple[str, VGroup]]: - if not self.labelled_submobject_items: + def get_group_part_items(self) -> list[tuple[str, list[int]]]: + if not self.labels: return [] - labels, labelled_submobjects = zip(*self.labelled_submobject_items) group_labels, labelled_submob_ranges = zip( - *self.compress_neighbours(labels) + *self.compress_neighbours(self.labels) ) ordered_spans = [ - self.label_span_list[label] if label != -1 else self.full_span + self.labelled_spans[label] if label != -1 else self.full_span for label in group_labels ] interval_spans = [ @@ -785,37 +804,67 @@ class LabelledString(SVGMobject, ABC): (ordered_spans[0][0], ordered_spans[-1][1]), interval_spans ) ] - submob_groups = VGroup(*[ - VGroup(*labelled_submobjects[slice(*submob_range)]) + submob_indices_lists = [ + list(range(*submob_range)) for submob_range in labelled_submob_ranges - ]) - return list(zip(group_substrs, submob_groups)) + ] + return list(zip(group_substrs, submob_indices_lists)) - def get_specified_part_items(self) -> list[tuple[str, VGroup]]: + def get_submob_indices_list_by_span( + self, arbitrary_span: Span + ) -> list[int]: + return [ + submob_index + for submob_index, label in enumerate(self.labels) + if label != -1 and self.span_contains( + arbitrary_span, self.labelled_spans[label] + ) + ] + + def get_specified_part_items(self) -> list[tuple[str, list[int]]]: return [ ( self.get_substr(span), - self.select_part_by_span(span) + self.get_submob_indices_list_by_span(span) ) for span in self.specified_spans ] - def select_part_by_span(self, arbitrary_span: Span) -> VGroup: - return VGroup(*[ - submob for label, submob in self.labelled_submobject_items - if label != -1 - and self.span_contains(arbitrary_span, self.label_span_list[label]) - ]) - - def select_parts(self, selector: Selector) -> VGroup: - return VGroup(*filter( - lambda part: part.submobjects, + def get_submob_indices_lists_by_selector( + self, selector: Selector + ) -> list[list[int]]: + return list(filter( + lambda indices_list: indices_list, [ - self.select_part_by_span(span) + self.get_submob_indices_list_by_span(span) for span in self.find_spans_by_selector(selector) ] )) + def build_parts_from_indices_lists( + self, submob_indices_lists: list[list[int]] + ) -> VGroup: + return VGroup(*[ + VGroup(*[ + self.labelled_submobject_items[submob_index][1] + for submob_index in indices_list + ]) + for indices_list in submob_indices_lists + ]) + + #def select_part_by_span(self, arbitrary_span: Span) -> VGroup: + # return VGroup(*[ + # self.labelled_submobject_items[submob_index] + # for submob_index in self.get_submob_indices_list_by_span( + # arbitrary_span + # ) + # ]) + + def select_parts(self, selector: Selector) -> VGroup: + return self.build_parts_from_indices_lists( + self.get_submob_indices_lists_by_selector(selector) + ) + def select_part(self, selector: Selector, index: int = 0) -> VGroup: return self.select_parts(selector)[index] diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 93e49a81..4a709271 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -31,14 +31,14 @@ if TYPE_CHECKING: SCALE_FACTOR_PER_FONT_POINT = 0.001 -TEX_COLOR_COMMANDS_DICT = { - "\\color": (1, False), - "\\textcolor": (1, False), - "\\pagecolor": (1, True), - "\\colorbox": (1, True), - "\\fcolorbox": (2, True), -} -TEX_COLOR_COMMAND_SUFFIX = "replaced" +#TEX_COLOR_COMMANDS_DICT = { +# "\\color": (1, False), +# "\\textcolor": (1, False), +# "\\pagecolor": (1, True), +# "\\colorbox": (1, True), +# "\\fcolorbox": (2, True), +#} +#TEX_COLOR_COMMAND_SUFFIX = "replaced" class MTex(LabelledString): @@ -56,7 +56,7 @@ class MTex(LabelledString): self.tex_string = tex_string super().__init__(tex_string, **kwargs) - #self.set_color_by_tex_to_color_map(self.tex_to_color_map) + self.set_color_by_tex_to_color_map(self.tex_to_color_map) self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size) @property @@ -97,16 +97,12 @@ class MTex(LabelledString): return f"\\color[RGB]{{{r}, {g}, {b}}}" @staticmethod - def get_tag_str( - attr_dict: dict[str, str], escape_color_keys: bool, is_begin_tag: bool - ) -> str: - if escape_color_keys: - return "" - if not is_begin_tag: - return "}}" - if "foreground" not in attr_dict: - return "{{" - return "{{" + MTex.get_color_command_str(attr_dict["foreground"]) + def get_tag_string_pair( + attr_dict: dict[str, str], label_hex: str | None + ) -> tuple[str, str]: + if label_hex is None: + return ("", "") + return ("{{" + MTex.get_color_command_str(label_hex), "}}") #@staticmethod #def shrink_span(span: Span, skippable_indices: list[int]) -> Span: @@ -223,20 +219,20 @@ class MTex(LabelledString): raise ValueError("Missing '}' inserted") #tag_span_pairs = brace_span_pairs.copy() - script_entity_dict = dict(self.chain( - [ - (span_begin, span_end) - for (span_begin, _), (_, span_end) in brace_span_pairs - ], - command_spans - )) - script_additional_brace_spans = [ - (char_index + 1, script_entity_dict.get( - script_begin, script_begin + 1 - )) - for char_index, script_begin in self.find_spans(r"[_^]\s*(?=.)") - if (char_index - 1, char_index + 1) not in command_spans - ] + #script_entity_dict = dict(self.chain( + # [ + # (span_begin, span_end) + # for (span_begin, _), (_, span_end) in brace_span_pairs + # ], + # command_spans + #)) + #script_additional_brace_spans = [ + # (char_index + 1, script_entity_dict.get( + # script_begin, script_begin + 1 + # )) + # for char_index, script_begin in self.find_spans(r"[_^]\s*(?=.)") + # if (char_index - 1, char_index + 1) not in command_spans + #] #for char_index, script_begin in self.find_spans(r"[_^]\s*(?=.)"): # if (char_index - 1, char_index + 1) in command_spans: # continue @@ -246,13 +242,13 @@ class MTex(LabelledString): # ) # script_additional_brace_spans.append((char_index + 1, script_end)) - tag_span_pairs = self.chain( - brace_span_pairs, - [ - ((script_begin - 1, script_begin), (script_end, script_end)) - for script_begin, script_end in script_additional_brace_spans - ] - ) + #tag_span_pairs = self.chain( + # brace_span_pairs, + # [ + # ((script_begin - 1, script_begin), (script_end, script_end)) + # for script_begin, script_end in script_additional_brace_spans + # ] + #) brace_content_spans = [ (span_begin, span_end) @@ -268,16 +264,19 @@ class MTex(LabelledString): ]) if range_end - range_begin >= 2 ] - self.script_additional_brace_spans = script_additional_brace_spans - return tag_span_pairs, internal_items + #self.script_additional_brace_spans = script_additional_brace_spans + return brace_span_pairs, internal_items def get_external_items(self) -> list[tuple[Span, dict[str, str]]]: return [ - (span, {"foreground": self.color_to_hex(color)}) - for selector, color in self.tex_to_color_map.items() + (span, {}) + for selector in self.tex_to_color_map for span in self.find_spans_by_selector(selector) ] + #def get_label_span_list(self, split_spans: list[Span]) -> list[Span]: + # return split_spans.copy() + #def get_spans_from_items(self, specified_items: list[Span]) -> list[Span]: # return specified_items @@ -287,29 +286,30 @@ class MTex(LabelledString): # for span in specified_items # ])) - def get_label_span_list(self, split_spans: list[Span]) -> list[Span]: - return split_spans + #def get_label_span_list(self, split_spans: list[Span]) -> list[Span]: + # return split_spans - def get_additional_inserted_str_pairs( - self - ) -> list[tuple[Span, tuple[str, str]]]: - return [ - (span, ("{", "}")) - for span in self.script_additional_brace_spans - ] + #def get_additional_inserted_str_pairs( + # self + #) -> list[tuple[Span, tuple[str, str]]]: + # return [ + # (span, ("{", "}")) + # for span in self.script_additional_brace_spans + # ] - def get_command_repl_items(self, is_labelled: bool) -> list[Span, str]: - if not is_labelled: - return [] - result = [] - command_spans = self.entity_spans # TODO - for cmd_span in command_spans: - cmd_str = self.get_substr(cmd_span) - if cmd_str not in TEX_COLOR_COMMANDS_DICT: - continue - repl_str = f"{cmd_str}{TEX_COLOR_COMMAND_SUFFIX}" - result.append((cmd_span, repl_str)) - return result + def get_command_repl_items(self) -> list[Span, str]: + return [] + #if not is_labelled: + # return [] + #result = [] + #command_spans = self.entity_spans # TODO + #for cmd_span in command_spans: + # cmd_str = self.get_substr(cmd_span) + # if cmd_str not in TEX_COLOR_COMMANDS_DICT: + # continue + # repl_str = f"{cmd_str}{TEX_COLOR_COMMAND_SUFFIX}" + # result.append((cmd_span, repl_str)) + #return result #def get_predefined_inserted_str_items( # self, split_items: list[Span] @@ -558,15 +558,8 @@ class MTex(LabelledString): # for label, span in enumerate(self.label_span_list) # ] - def get_full_content_string( - self, - label_span_list: list[Span], - split_items: list[tuple[Span, dict[str, str]]], - is_labelled: bool - ) -> str: - result = super().get_full_content_string( - label_span_list, split_items, is_labelled - ) + def get_full_content_string(self, content_string: str, is_labelled: bool) -> str: + result = content_string if self.tex_environment: if isinstance(self.tex_environment, str): @@ -578,25 +571,25 @@ class MTex(LabelledString): if self.alignment: result = "\n".join([self.alignment, result]) - if is_labelled: - occurred_commands = [ - # TODO - self.get_substr(span) for span in self.entity_spans - ] - newcommand_lines = [ - "".join([ - f"\\newcommand{cmd_name}{TEX_COLOR_COMMAND_SUFFIX}", - f"[{n_braces + 1}][]", - "{", - cmd_name + "{black}" * n_braces if substitute_cmd else "", - "}" - ]) - for cmd_name, (n_braces, substitute_cmd) - in TEX_COLOR_COMMANDS_DICT.items() - if cmd_name in occurred_commands - ] - result = "\n".join([*newcommand_lines, result]) - else: + #if is_labelled: + # occurred_commands = [ + # # TODO + # self.get_substr(span) for span in self.entity_spans + # ] + # newcommand_lines = [ + # "".join([ + # f"\\newcommand{cmd_name}{TEX_COLOR_COMMAND_SUFFIX}", + # f"[{n_braces + 1}][]", + # "{", + # cmd_name + "{black}" * n_braces if substitute_cmd else "", + # "}" + # ]) + # for cmd_name, (n_braces, substitute_cmd) + # in TEX_COLOR_COMMANDS_DICT.items() + # if cmd_name in occurred_commands + # ] + # result = "\n".join([*newcommand_lines, result]) + if not is_labelled: result = "\n".join([ self.get_color_command_str(self.base_color_hex), result diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 3b07a07e..9b82ae2f 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -114,6 +114,7 @@ class MarkupText(LabelledString): "t2w": {}, "global_config": {}, "local_configs": {}, + "split_words": True, } def __init__(self, text: str, **kwargs): @@ -162,7 +163,8 @@ class MarkupText(LabelledString): self.t2s, self.t2w, self.global_config, - self.local_configs + self.local_configs, + self.split_words ) def full2short(self, config: dict) -> None: @@ -250,28 +252,26 @@ class MarkupText(LabelledString): # Toolkits @staticmethod - def get_tag_str( - attr_dict: dict[str, str], escape_color_keys: bool, is_begin_tag: bool - ) -> str: - if not is_begin_tag: - return "" - if escape_color_keys: - converted_attr_dict = {} + def get_tag_string_pair( + attr_dict: dict[str, str], label_hex: str | None + ) -> tuple[str, str]: + if label_hex is not None: + converted_attr_dict = {"foreground": label_hex} for key, val in attr_dict.items(): substitute_key = MARKUP_COLOR_KEYS_DICT.get(key.lower(), None) if substitute_key is None: converted_attr_dict[key] = val elif substitute_key: converted_attr_dict[key] = "black" - else: - converted_attr_dict[key] = "black" + #else: + # converted_attr_dict[key] = "black" else: converted_attr_dict = attr_dict.copy() - result = " ".join([ + attrs_str = " ".join([ f"{key}='{val}'" for key, val in converted_attr_dict.items() ]) - return f"" + return (f"", "") def get_global_attr_dict(self) -> dict[str, str]: result = { @@ -286,8 +286,9 @@ class MarkupText(LabelledString): if tuple(map(int, pango_version.split("."))) < (1, 50): if self.lsh is not None: log.warning( - f"Pango version {pango_version} found (< 1.50), " - "unable to set `line_height` attribute" + "Pango version %s found (< 1.50), " + "unable to set `line_height` attribute", + pango_version ) else: line_spacing_scale = self.lsh or DEFAULT_LINE_SPACING_SCALE @@ -477,8 +478,8 @@ class MarkupText(LabelledString): if not self.is_markup: return [], [] - tag_pattern = r"""<(/?)(\w+)\s*((\w+\s*\=\s*(['"])[\s\S]*?\5\s*)*)>""" - attr_pattern = r"""(\w+)\s*\=\s*(['"])([\s\S]*?)\2""" + tag_pattern = r"<(/?)(\w+)\s*((\w+\s*\=\s*(['\x22])[\s\S]*?\5\s*)*)>" + attr_pattern = r"(\w+)\s*\=\s*(['\x22])([\s\S]*?)\2" begin_match_obj_stack = [] markup_tag_items = [] for match_obj in re.finditer(tag_pattern, self.string): @@ -511,7 +512,7 @@ class MarkupText(LabelledString): return tag_span_pairs, internal_items def get_external_items(self) -> list[tuple[Span, dict[str, str]]]: - return [ + result = [ (self.full_span, self.get_global_attr_dict()), (self.full_span, self.global_config), *[ @@ -531,6 +532,17 @@ class MarkupText(LabelledString): for span in self.find_spans_by_selector(selector) ] ] + if self.split_words: + # For backward compatibility + result.extend([ + (span, {}) + for span in self.find_spans(r"[a-zA-Z]+") + for pattern in (r"[a-zA-Z]+", r"\S+") + ]) + return result + + + #def get_label_span_list(self, split_spans: list[Span]) -> list[Span]: #def get_spans_from_items( # self, specified_items: list[tuple[Span, dict[str, str]]] @@ -546,31 +558,31 @@ class MarkupText(LabelledString): # for span in self.split_span(specified_span) # ] - def get_label_span_list(self, split_spans: list[Span]) -> list[Span]: - interval_spans = sorted(self.chain( - self.tag_spans, - [ - (index, index) - for span in split_spans - for index in span - ] - )) - text_spans = self.get_complement_spans(self.full_span, interval_spans) - if self.is_markup: - pattern = r"[0-9a-zA-Z]+|(?:&[\s\S]*?;|[^0-9a-zA-Z\s])+" - else: - pattern = r"[0-9a-zA-Z]+|[^0-9a-zA-Z\s]+" - return self.chain(*[ - self.find_spans(pattern, pos=span_begin, endpos=span_end) - for span_begin, span_end in text_spans - ]) + #def get_label_span_list(self, split_spans: list[Span]) -> list[Span]: + # interval_spans = sorted(self.chain( + # self.tag_spans, + # [ + # (index, index) + # for span in split_spans + # for index in span + # ] + # )) + # text_spans = self.get_complement_spans(self.full_span, interval_spans) + # if self.is_markup: + # pattern = r"[0-9a-zA-Z]+|(?:&[\s\S]*?;|[^0-9a-zA-Z\s])+" + # else: + # pattern = r"[0-9a-zA-Z]+|[^0-9a-zA-Z\s]+" + # return self.chain(*[ + # self.find_spans(pattern, pos=span_begin, endpos=span_end) + # for span_begin, span_end in text_spans + # ]) - def get_additional_inserted_str_pairs( - self - ) -> list[tuple[Span, tuple[str, str]]]: - return [] + #def get_additional_inserted_str_pairs( + # self + #) -> list[tuple[Span, tuple[str, str]]]: + # return [] - def get_command_repl_items(self, is_labelled: bool) -> list[Span, str]: + def get_command_repl_items(self) -> list[Span, str]: result = [ (tag_span, "") for tag_span in self.tag_spans ] @@ -755,8 +767,8 @@ class MarkupText(LabelledString): # for span, attr_dict in attr_dict_items # ] - def get_content(self, is_labelled: bool) -> str: - return self.decorated_strings[is_labelled] + def get_full_content_string(self, content_string: str, is_labelled: bool) -> str: + return content_string # Selector From 1cb740114165fd0e849e701276c157986f0d3b26 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Wed, 4 May 2022 21:56:13 +0800 Subject: [PATCH 06/11] [WIP] Refactor LabelledString and relevant classes --- .../animation/transform_matching_parts.py | 23 +- manimlib/mobject/svg/labelled_string.py | 354 ++++++++++++------ manimlib/mobject/svg/mtex_mobject.py | 225 ++++++++--- manimlib/mobject/svg/text_mobject.py | 188 +++++++--- 4 files changed, 571 insertions(+), 219 deletions(-) diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index afa7e1ab..464d36e7 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -155,7 +155,7 @@ class TransformMatchingTex(TransformMatchingParts): class TransformMatchingStrings(AnimationGroup): CONFIG = { - "key_map": dict(), + "key_map": {}, "transform_mismatches": False, } @@ -172,13 +172,16 @@ class TransformMatchingStrings(AnimationGroup): target_indices = list(range(len(target.labels))) def get_filtered_indices_lists(indices_lists, rest_indices): - return list(filter( - lambda indices_list: all( - index in rest_indices - for index in indices_list - ), - indices_lists - )) + result = [] + for indices_list in indices_lists: + if not indices_list: + continue + if not all(index in rest_indices for index in indices_list): + continue + result.append(indices_list) + for index in indices_list: + rest_indices.remove(index) + return result def add_anims(anim_class, indices_lists_pairs): for source_indices_lists, target_indices_lists in indices_lists_pairs: @@ -195,10 +198,6 @@ class TransformMatchingStrings(AnimationGroup): target.build_parts_from_indices_lists(target_indices_lists), **kwargs )) - for index in it.chain(*source_indices_lists): - source_indices.remove(index) - for index in it.chain(*target_indices_lists): - target_indices.remove(index) def get_substr_to_indices_lists_map(part_items): result = {} diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index ef26ecf1..da72fa67 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -47,7 +47,7 @@ class LabelledString(SVGMobject, ABC): "should_remove_null_curves": True, }, "base_color": WHITE, - "isolate": [], + "isolate": (), } def __init__(self, string: str, **kwargs): @@ -60,11 +60,11 @@ class LabelledString(SVGMobject, ABC): self.full_span = (0, len(self.string)) self.parse() super().__init__(**kwargs) - self.labelled_submobject_items = [ - (submob.label, submob) - for submob in self.submobjects - ] - self.labels = [label for label, _ in self.labelled_submobject_items] + #self.labelled_submobject_items = [ + # (submob.label, submob) + # for submob in self.submobjects + #] + self.labels = [submob.label for submob in self.submobjects] def get_file_path(self) -> str: return self.get_file_path_by_content(self.original_content) @@ -188,10 +188,11 @@ class LabelledString(SVGMobject, ABC): if spans is None: raise TypeError(f"Invalid selector: '{sel}'") result.extend(spans) - return sorted(filter( - lambda span: span[0] < span[1], - self.remove_redundancies(result) - )) + #return sorted(filter( + # lambda span: span[0] < span[1], + # self.remove_redundancies(result) + #)) + return result @staticmethod def chain(*iterables: Iterable[T]) -> list[T]: @@ -240,31 +241,33 @@ class LabelledString(SVGMobject, ABC): def span_contains(span_0: Span, span_1: Span) -> bool: return span_0[0] <= span_1[0] and span_0[1] >= span_1[1] - def get_level_items( - self, - tag_span_pairs: list[tuple[Span, Span]], - entity_spans: list[Span] - ) -> tuple[list[Span], list[int]]: - tagged_items = sorted(self.chain( - [(begin_cmd_span, 1) for begin_cmd_span, _ in tag_span_pairs], - [(end_cmd_span, -1) for _, end_cmd_span in tag_span_pairs], - [(entity_span, 0) for entity_span in entity_spans], - ), key=lambda t: t[0]) - piece_spans = self.get_complement_spans(self.full_span, [ - interval_span for interval_span, _ in tagged_items - ]) - piece_levels = [0, *it.accumulate([tag for _, tag in tagged_items])] - return piece_spans, piece_levels + #def get_level_interval_spans( + # self, + # tag_span_pairs: list[tuple[Span, Span]], + # entity_spans: list[Span] + #) -> list[tuple[Span, int]]: + # return sorted(self.chain( + # [(begin_cmd_span, 1) for begin_cmd_span, _ in tag_span_pairs], + # [(end_cmd_span, -1) for _, end_cmd_span in tag_span_pairs], + # [(entity_span, 0) for entity_span in entity_spans], + # ), key=lambda t: t[0]) + # #piece_spans = self.get_complement_spans(self.full_span, [ + # # interval_span for interval_span, _ in level_interval_spans + # #]) + # #piece_levels = [0, *it.accumulate([tag for _, tag in level_interval_spans])] + # #return piece_spans, piece_levels - def split_span_by_levels(self, arbitrary_span: Span) -> list[Span]: + def split_span_by_levels( + self, arbitrary_span: Span + ) -> tuple[list[Span], int, int]: # ignorable_indices -- # left_bracket_spans # right_bracket_spans # entity_spans #piece_spans, piece_levels = zip(*self.piece_items) #ignorable_indices = self.ignorable_indices - piece_spans = self.piece_spans - piece_levels = self.piece_levels + #piece_spans = self.piece_spans + #piece_levels = self.piece_levels #piece_begins, piece_ends = zip(*piece_spans) #span_begin, span_end = arbitrary_span #while span_begin in ignorable_indices: @@ -274,50 +277,141 @@ class LabelledString(SVGMobject, ABC): #entity_spans = self.chain( # left_bracket_spans, right_bracket_spans, entity_spans #) - index_begin = sum([ - arbitrary_span[0] > piece_end - for _, piece_end in piece_spans - ]) - index_end = sum([ - arbitrary_span[1] >= piece_begin - for piece_begin, _ in piece_spans - ]) - if index_begin >= index_end: - return [] + #if arbitrary_span[0] > arbitrary_span[1]: + # return [] - lowest_level = min( - piece_levels[index_begin:index_end] + #level_interval_span_items = self.level_interval_span_items + #if not level_interval_span_items: + # #if + # return [arbitrary_span] + + #span_begin, span_end = arbitrary_span + #print(level_interval_span_items) + #level_interval_spans, level_shifts = zip(*level_interval_span_items) # TODO: avoid empty list + interval_span_items = self.cmd_span_items + interval_spans = [span for span, _ in interval_span_items] + #level_interval_spans = self.level_interval_spans + #level_shifts = self.level_shifts + #print(level_interval_span_items, arbitrary_span) + #index_begin = sum([ + # arbitrary_span[0] > piece_end + # for _, piece_end in piece_spans + #]) + #interval_index_begin = sum([ + # span_begin >= interval_begin + # for interval_begin, _ in level_interval_spans + #]) + #index_end = sum([ + # arbitrary_span[1] >= piece_begin + # for piece_begin, _ in piece_spans + #]) + #interval_index_end = sum([ + # span_end >= interval_end + # for _, interval_end in level_interval_spans + #]) + #interval_range = ( + # sum([ + # arbitrary_span[0] >= interval_begin + # for interval_begin, _ in interval_spans + # ]), + # sum([ + # arbitrary_span[1] >= interval_end + # for _, interval_end in interval_spans + # ]) + #) + #interval_range = (interval_range[0], interval_range[1] - len(level_interval_spans)) + #print(interval_index_begin, interval_index_end) + #complement_spans = self.get_complement_spans(self.full_span, interval_spans) + #adjusted_span = ( + # #max(arbitrary_span[0], level_interval_spans[interval_range[0] - 1][1]), + # #if interval_range[0] > 0 else arbitrary_span[0], + # #min(arbitrary_span[1], level_interval_spans[interval_range[1]][0]) + # #if interval_range[1] < len(level_interval_spans) else arbitrary_span[1] + #) + #adjusted_span = ( + # max(arbitrary_span[0], complement_spans[interval_range[0]][0]), + # min(arbitrary_span[1], complement_spans[interval_range[1]][1]) + #) + #print(arbitrary_span, adjusted_span) + + interval_range = ( + sum([ + arbitrary_span[0] > interval_begin + for interval_begin, _ in interval_spans + ]), + sum([ + arbitrary_span[1] >= interval_end + for _, interval_end in interval_spans + ]) ) - split_piece_indices = [] - target_level = piece_levels[index_begin] - for piece_index in range(index_begin, index_end): - if piece_levels[piece_index] != target_level: - continue - split_piece_indices.append(piece_index) - target_level -= 1 - if target_level < lowest_level: - break - len_indices = len(split_piece_indices) - target_level = piece_levels[index_end - 1] - for piece_index in range(index_begin, index_end)[::-1]: - if piece_levels[piece_index] != target_level: - continue - split_piece_indices.insert(len_indices, piece_index + 1) - target_level -= 1 - if target_level < lowest_level: - break + complement_spans = self.get_complement_spans(self.full_span, interval_spans) + adjusted_span = ( + max(arbitrary_span[0], complement_spans[interval_range[0]][0]), + min(arbitrary_span[1], complement_spans[interval_range[1]][1]) + ) + if adjusted_span[0] > adjusted_span[1]: + #print([]) + return [], 0, 0 - span_begins = [ - piece_spans[piece_index][0] - for piece_index in split_piece_indices[:-1] + #lowest_level = min( + # piece_levels[index_begin:index_end] + #) + #split_piece_indices = [] + #target_level = piece_levels[index_begin] + #for piece_index in range(index_begin, index_end): + # if piece_levels[piece_index] != target_level: + # continue + # split_piece_indices.append(piece_index) + # target_level -= 1 + # if target_level < lowest_level: + # break + #len_indices = len(split_piece_indices) + #target_level = piece_levels[index_end - 1] + #for piece_index in range(index_begin, index_end)[::-1]: + # if piece_levels[piece_index] != target_level: + # continue + # split_piece_indices.insert(len_indices, piece_index + 1) + # target_level -= 1 + # if target_level < lowest_level: + # break + upwards_stack = [] + downwards_stack = [] + for interval_index in range(*interval_range): + _, level_shift = interval_span_items[interval_index] + if level_shift == 1: + upwards_stack.append(interval_index) + elif level_shift == -1: + if upwards_stack: + upwards_stack.pop() + else: + downwards_stack.append(interval_index) + #split_piece_indices = downwards_stack + upwards_stack + #print(split_piece_indices) + + covered_interval_spans = [ + interval_spans[piece_index] + for piece_index in self.chain(downwards_stack, upwards_stack) ] - span_begins[0] = max(arbitrary_span[0], span_begins[0]) - span_ends = [ - piece_spans[piece_index - 1][1] - for piece_index in split_piece_indices[1:] - ] - span_ends[-1] = min(arbitrary_span[1], span_ends[-1]) - return list(zip(span_begins, span_ends)) + result = self.get_complement_spans(adjusted_span, covered_interval_spans) + return result, len(downwards_stack), len(upwards_stack) + #if interval_index_begin > 0: + # span_begin = max(span_begin, level_interval_spans[interval_index_begin - 1][1]) + #if interval_index_end < len(level_interval_spans): + # span_end = min(span_end, level_interval_spans[interval_index_end][0]) + #universal_span = (span_begin, span_end) + #print(universal_span, self.get_complement_spans(universal_span, interval_spans)) + #print(self.get_complement_spans(adjusted_span, interval_spans)) + #span_begins = [ + # level_interval_spans[piece_index][0][1] + # for piece_index in split_piece_indices + #] + #span_begins[0] = max(arbitrary_span[0], span_begins[0]) + #span_ends = [ + # level_interval_spans[piece_index - 1][0][1] + # for piece_index in split_piece_indices[1:] + #] + #span_ends[-1] = min(arbitrary_span[1], span_ends[-1]) + #return list(zip(span_begins, span_ends)) #lowest_level_indices = [ # piece_index # for piece_index, piece_level in enumerate(piece_levels) @@ -383,7 +477,7 @@ class LabelledString(SVGMobject, ABC): (*span_ends, universal_span[1]) )) - def get_replaced_substr(self, span: Span, repl_items: list[Span, str]): # TODO: need `span` attr? + def replace_string(self, span: Span, repl_items: list[Span, str]): # TODO: need `span` attr? if not repl_items: return self.get_substr(span) @@ -412,7 +506,7 @@ class LabelledString(SVGMobject, ABC): # ) # ] # ) - # return self.get_replaced_substr(self.full_span, all_repl_items) + # return self.replace_string(self.full_span, all_repl_items) @staticmethod def color_to_hex(color: ManimColor) -> str: @@ -494,32 +588,69 @@ class LabelledString(SVGMobject, ABC): # return [span for span, _ in self.specified_items] def parse(self) -> None: - self.entity_spans = self.get_entity_spans() - tag_span_pairs, internal_items = self.get_internal_items() - self.piece_spans, self.piece_levels = self.get_level_items( - tag_span_pairs, self.entity_spans - ) + begin_cmd_spans, end_cmd_spans, cmd_spans = self.get_command_spans() + + cmd_span_items = sorted(self.chain( + [(begin_cmd_span, 1) for begin_cmd_span in begin_cmd_spans], + [(end_cmd_span, -1) for end_cmd_span in end_cmd_spans], + [(cmd_span, 0) for cmd_span in cmd_spans], + ), key=lambda t: t[0]) + self.cmd_span_items = cmd_span_items + + cmd_span_pairs = [] + begin_cmd_spans_stack = [] + for cmd_span, flag in cmd_span_items: + if flag == 1: + begin_cmd_spans_stack.append(cmd_span) + elif flag == -1: + if not begin_cmd_spans_stack: + raise ValueError("Missing '{' inserted") + begin_cmd_span = begin_cmd_spans_stack.pop() + cmd_span_pairs.append((begin_cmd_span, cmd_span)) + if begin_cmd_spans_stack: + raise ValueError("Missing '}' inserted") + + specified_items = self.get_specified_items(cmd_span_pairs) + + #entity_spans = self.get_entity_spans() + #self.entity_spans = entity_spans + #tag_span_pairs, internal_items = self.get_internal_items() + #self.level_interval_spans = self.get_level_interval_spans( + # tag_span_pairs, self.entity_spans + #) + #self.level_interval_spans = [ + # level_interval_span + # for level_interval_span, _ in level_interval_span_items + #] + #self.level_shifts = [ + # level_shift + # for _, level_shift in level_interval_span_items + #] # TODO #self.tag_content_spans = [ # (content_begin, content_end) # for (_, content_begin), (content_end, _) in tag_span_pairs #] - self.tag_spans = self.chain(*tag_span_pairs) - specified_items = self.chain( - internal_items, - self.get_external_items(), - [ - (span, {}) - for span in self.find_spans_by_selector(self.isolate) - ] - ) + #self.tag_spans = self.chain(*tag_span_pairs) + #specified_items = self.chain( + # self.get_specified_items(cmd_span_pairs) + # internal_items, + # self.get_external_items(), + # [ + # (span, {}) + # for span in self.find_spans_by_selector(self.isolate) + # ] + #) #print(f"\n{specified_items=}\n") #specified_spans = + split_items = [ (span, attr_dict) for specified_span, attr_dict in specified_items - for span in self.split_span_by_levels(specified_span) + for span in self.split_span_by_levels(specified_span)[0] ] + #print([self.get_substr(span) for span, _ in specified_items]) + #print([self.get_substr(span) for span, _ in split_items]) #print(f"\n{split_items=}\n") #labelled_spans = [span for span, _ in split_items] #labelled_spans = self.get_labelled_spans(split_spans) @@ -550,15 +681,13 @@ class LabelledString(SVGMobject, ABC): #decorated_strings = [ - # self.get_replaced_substr(self.full_span, [ + # self.replace_string(self.full_span, [ # (span, str_pair[flag]) # for span, str_pair in command_repl_items # ]) # for flag in range(2) #] - command_repl_items = self.get_command_repl_items() - #full_content_strings = {} #for is_labelled in (False, True): # inserted_str_pairs = [ @@ -576,12 +705,18 @@ class LabelledString(SVGMobject, ABC): # in self.sort_obj_pairs_by_spans(inserted_str_pairs) # ] # ) - # content_string = self.get_replaced_substr( + # content_string = self.replace_string( # self.full_span, repl_items # ) # full_content_string = self.get_full_content_string(content_string) # #full_content_strings[is_labelled] = full_content_string + command_repl_items = [ + (span, self.get_replaced_substr(self.get_substr(span), flag)) + for span, flag in cmd_span_items + ] + self.command_repl_items = command_repl_items + self.specified_spans = [span for span, _ in specified_items] self.labelled_spans = [span for span, _ in split_items] for span_0, span_1 in it.product(self.labelled_spans, repeat=2): @@ -593,7 +728,7 @@ class LabelledString(SVGMobject, ABC): ) self.original_content, self.labelled_content = ( - self.get_full_content_string(self.get_replaced_substr( + self.get_full_content_string(self.replace_string( self.full_span, self.chain( command_repl_items, [ @@ -610,6 +745,9 @@ class LabelledString(SVGMobject, ABC): ), is_labelled=is_labelled) for is_labelled in (False, True) ) + print(self.original_content) + print() + print(self.labelled_content) #self.original_content = full_content_strings[False] @@ -631,17 +769,23 @@ class LabelledString(SVGMobject, ABC): # raise ValueError("Cannot handle that many substrings") @abstractmethod - def get_entity_spans(self) -> list[Span]: - return [] + def get_command_spans(self) -> tuple[list[Span], list[Span], list[Span]]: + return [], [], [] + + #@abstractmethod + #def get_entity_spans(self) -> list[Span]: + # return [] + + #@abstractmethod + #def get_internal_items( + # self + #) -> tuple[list[tuple[Span, Span]], list[tuple[Span, dict[str, str]]]]: + # return [], [] @abstractmethod - def get_internal_items( - self - ) -> tuple[list[tuple[Span, Span]], list[tuple[Span, dict[str, str]]]]: - return [], [] - - @abstractmethod - def get_external_items(self) -> list[tuple[Span, dict[str, str]]]: + def get_specified_items( + self, cmd_span_pairs: list[tuple[Span, Span]] + ) -> list[tuple[Span, dict[str, str]]]: return [] #@abstractmethod @@ -687,7 +831,7 @@ class LabelledString(SVGMobject, ABC): # ])) # for span in self.get_complement_spans(adjusted_span, result) # ])) - return list(filter(lambda span: span[0] < span[1], result)) + # return list(filter(lambda span: span[0] < span[1], result)) #@abstractmethod #def get_split_items(self, specified_items: list[T]) -> list[T]: @@ -758,8 +902,8 @@ class LabelledString(SVGMobject, ABC): # return [] @abstractmethod - def get_command_repl_items(self) -> list[Span, str]: - return [] + def get_replaced_substr(self, substr: str, flag: int) -> str: + return "" @abstractmethod def get_full_content_string(self, content_string: str, is_labelled: bool) -> str: @@ -842,14 +986,14 @@ class LabelledString(SVGMobject, ABC): )) def build_parts_from_indices_lists( - self, submob_indices_lists: list[list[int]] + self, indices_lists: list[list[int]] ) -> VGroup: return VGroup(*[ VGroup(*[ - self.labelled_submobject_items[submob_index][1] + self.submobjects[submob_index] for submob_index in indices_list ]) - for indices_list in submob_indices_lists + for indices_list in indices_lists ]) #def select_part_by_span(self, arbitrary_span: Span) -> VGroup: diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 4a709271..03896e9c 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -195,28 +195,42 @@ class MTex(LabelledString): ##super().parse() #self.label_span_list = self.get_label_span_list() - def get_entity_spans(self) -> list[Span]: - return self.find_spans(r"\\(?:[a-zA-Z]+|.)") + def get_command_spans(self) -> tuple[list[Span], list[Span], list[Span]]: + cmd_spans = self.find_spans(r"\\(?:[a-zA-Z]+|\s|\S)") + begin_cmd_spans = [ + span + for span in self.find_spans("{") + if (span[0] - 1, span[1]) not in cmd_spans + ] + end_cmd_spans = [ + span + for span in self.find_spans("}") + if (span[0] - 1, span[1]) not in cmd_spans + ] + return begin_cmd_spans, end_cmd_spans, cmd_spans - def get_internal_items( - self - ) -> tuple[list[tuple[Span, Span]], list[tuple[Span, dict[str, str]]]]: - command_spans = self.entity_spans - brace_span_pairs = [] - brace_begin_spans_stack = [] - for span in self.find_spans(r"[{}]"): - char_index = span[0] - if (char_index - 1, char_index + 1) in command_spans: - continue - if self.get_substr(span) == "{": - brace_begin_spans_stack.append(span) - else: - if not brace_begin_spans_stack: - raise ValueError("Missing '{' inserted") - brace_span = brace_begin_spans_stack.pop() - brace_span_pairs.append((brace_span, span)) - if brace_begin_spans_stack: - raise ValueError("Missing '}' inserted") + #def get_entity_spans(self) -> list[Span]: + # return self.find_spans(r"\\(?:[a-zA-Z]+|.)") + + #def get_internal_items( + # self + #) -> tuple[list[tuple[Span, Span]], list[tuple[Span, dict[str, str]]]]: + # command_spans = self.entity_spans + # brace_span_pairs = [] + # brace_begin_spans_stack = [] + # for span in self.find_spans(r"[{}]"): + # char_index = span[0] + # if (char_index - 1, char_index + 1) in command_spans: + # continue + # if self.get_substr(span) == "{": + # brace_begin_spans_stack.append(span) + # else: + # if not brace_begin_spans_stack: + # raise ValueError("Missing '{' inserted") + # brace_span = brace_begin_spans_stack.pop() + # brace_span_pairs.append((brace_span, span)) + # if brace_begin_spans_stack: + # raise ValueError("Missing '}' inserted") #tag_span_pairs = brace_span_pairs.copy() #script_entity_dict = dict(self.chain( @@ -250,29 +264,56 @@ class MTex(LabelledString): # ] #) - brace_content_spans = [ - (span_begin, span_end) - for (_, span_begin), (span_end, _) in brace_span_pairs - ] - internal_items = [ - (brace_content_spans[range_begin], {}) - for _, (range_begin, range_end) in self.compress_neighbours([ - (span_begin + index, span_end - index) - for index, (span_begin, span_end) in enumerate( - brace_content_spans - ) - ]) - if range_end - range_begin >= 2 - ] - #self.script_additional_brace_spans = script_additional_brace_spans - return brace_span_pairs, internal_items + #brace_content_spans = [ + # (span_begin, span_end) + # for (_, span_begin), (span_end, _) in brace_span_pairs + #] + #internal_items = [ + # (brace_content_spans[range_begin], {}) + # for _, (range_begin, range_end) in self.compress_neighbours([ + # (span_begin + index, span_end - index) + # for index, (span_begin, span_end) in enumerate( + # brace_content_spans + # ) + # ]) + # if range_end - range_begin >= 2 + #] + ##self.script_additional_brace_spans = script_additional_brace_spans + #return brace_span_pairs, internal_items - def get_external_items(self) -> list[tuple[Span, dict[str, str]]]: - return [ - (span, {}) - for selector in self.tex_to_color_map - for span in self.find_spans_by_selector(selector) + #def get_external_items(self) -> list[tuple[Span, dict[str, str]]]: + # return [ + # (span, {}) + # for selector in self.tex_to_color_map + # for span in self.find_spans_by_selector(selector) + # ] + + def get_specified_items( + self, cmd_span_pairs: list[tuple[Span, Span]] + ) -> list[tuple[Span, dict[str, str]]]: + cmd_content_spans = [ + (span_begin, span_end) + for (_, span_begin), (span_end, _) in cmd_span_pairs ] + specified_spans = self.chain( + [ + cmd_content_spans[range_begin] + for _, (range_begin, range_end) in self.compress_neighbours([ + (span_begin + index, span_end - index) + for index, (span_begin, span_end) in enumerate( + cmd_content_spans + ) + ]) + if range_end - range_begin >= 2 + ], + [ + span + for selector in self.tex_to_color_map + for span in self.find_spans_by_selector(selector) + ], + self.find_spans_by_selector(self.isolate) + ) + return [(span, {}) for span in specified_spans] #def get_label_span_list(self, split_spans: list[Span]) -> list[Span]: # return split_spans.copy() @@ -297,8 +338,8 @@ class MTex(LabelledString): # for span in self.script_additional_brace_spans # ] - def get_command_repl_items(self) -> list[Span, str]: - return [] + #def get_command_repl_items(self) -> list[Span, str]: + # return [] #if not is_labelled: # return [] #result = [] @@ -558,6 +599,9 @@ class MTex(LabelledString): # for label, span in enumerate(self.label_span_list) # ] + def get_replaced_substr(self, substr: str, flag: int) -> str: + return substr # TODO: replace color commands + def get_full_content_string(self, content_string: str, is_labelled: bool) -> str: result = content_string @@ -599,7 +643,98 @@ class MTex(LabelledString): # Selector def get_cleaned_substr(self, span: Span) -> str: - return self.get_substr(span) # TODO: test + backslash_indices = [ + index for index, _ in self.find_spans(r"\\[\s\S]") + ] + #ignored_spans = [ + # ignored_span + # for ignored_span in self.find_spans(r"[\s_^{}]+") + # if ignored_span[0] - 1 not in backslash_indices + #] + #shrinked_span, _ = self.adjust_span(span, ignored_spans) + ignored_indices = [ + index + for index, _ in self.find_spans(r"[\s_^{}]") + if index - 1 not in backslash_indices + ] + span_begin, span_end = span + while span_begin in ignored_indices: + span_begin += 1 + while span_end - 1 in ignored_indices: + span_end -= 1 + shrinked_span = (span_begin, span_end) + #if span_begin >= span_end: + # return "" + + #shrinked_span = (span_begin, span_end) + _, unclosed_right_braces, unclosed_left_braces = self.split_span_by_levels(shrinked_span) + + whitespace_repl_items = [] + for whitespace_span in self.find_spans(r"\s+"): + if not self.span_contains(shrinked_span, whitespace_span): + continue + if whitespace_span[0] - 1 in backslash_indices: + whitespace_span = (whitespace_span[0] + 1, whitespace_span[1]) + if all( + self.get_substr((index, index + 1)).isalpha() + for index in (whitespace_span[0] - 1, whitespace_span[1]) + ): + replaced_substr = " " + else: + replaced_substr = "" + whitespace_repl_items.append((whitespace_span, replaced_substr)) + + return "".join([ + unclosed_right_braces * "{", + self.replace_string(shrinked_span, whitespace_repl_items), + unclosed_left_braces * "}" + ]) + + + #interval_spans = [ + # span + # if span[0] - 1 not in backslash_indices + # else (span[0] + 1, span[1]) + # for span in self.find_spans(r"[\s_^{}]+") + #] + #adjusted_span, _ = self.adjust_span(span, interval_spans) + #if adjusted_span[0] >= adjusted_span[1]: + # return "" + + #left_brace_indices = list(filter( + # lambda index: self.get_substr((index, index + 1)) == "{", + # ignored_indices + #)) + #right_brace_indices = list(filter( + # lambda index: self.get_substr((index, index + 1)) == "}", + # ignored_indices + #)) + #unclosed_left_braces = 0 + #unclosed_right_braces = 0 + #for index in range(*adjusted_span): + # if index in left_brace_indices: + # unclosed_left_braces += 1 + # elif index in right_brace_indices: + # if unclosed_left_braces == 0: + # unclosed_right_braces += 1 + # else: + # unclosed_left_braces -= 1 + #adjusted_span, unclosed_left_braces, unclosed_right_braces \ + # = self.adjust_span(span, align_level=False) + #print(self.get_substr(span), "".join([ + # unclosed_right_braces * "{", + # self.get_substr(shrinked_span), + # unclosed_left_braces * "}" + #])) + #result = "".join([ + # unclosed_right_braces * "{", + # self.get_substr(shrinked_span), + # unclosed_left_braces * "}" + #]) + #return re.sub(r"\s+", " ", result) + + #return (span_begin, span_end) + #return self.get_substr(span) # TODO: test #left_brace_indices = [ # span_begin - 1 # for span_begin, _ in self.brace_content_spans diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 9b82ae2f..fc4ff56f 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -114,7 +114,11 @@ class MarkupText(LabelledString): "t2w": {}, "global_config": {}, "local_configs": {}, - "split_words": True, + # When attempting to slice submobs via `get_part_by_text` thereafter, + # it's recommended to explicitly specify them in `isolate` attribute + # when initializing. + # For backward compatibility + "isolate": (re.compile(r"[a-zA-Z]+"), re.compile(r"\S+")), } def __init__(self, text: str, **kwargs): @@ -163,8 +167,7 @@ class MarkupText(LabelledString): self.t2s, self.t2w, self.global_config, - self.local_configs, - self.split_words + self.local_configs ) def full2short(self, config: dict) -> None: @@ -467,54 +470,114 @@ class MarkupText(LabelledString): #] #return bracket_content_spans, label_span_list, command_repl_items - def get_entity_spans(self) -> list[Span]: + def get_command_spans(self) -> tuple[list[Span], list[Span], list[Span]]: + begin_cmd_spans = self.find_spans( + r"<\w+\s*(?:\w+\s*\=\s*(['\x22])[\s\S]*?\1\s*)*>" + ) + end_cmd_spans = self.find_spans(r"") if not self.is_markup: - return [] - return self.find_spans(r"&[\s\S]*?;") + cmd_spans = [] + else: + cmd_spans = self.find_spans(r"&[\s\S]*?;") # TODO + return begin_cmd_spans, end_cmd_spans, cmd_spans - def get_internal_items( - self - ) -> tuple[list[tuple[Span, Span]], list[tuple[Span, dict[str, str]]]]: - if not self.is_markup: - return [], [] + #def get_entity_spans(self) -> list[Span]: + # if not self.is_markup: + # return [] + # return self.find_spans(r"&[\s\S]*?;") - tag_pattern = r"<(/?)(\w+)\s*((\w+\s*\=\s*(['\x22])[\s\S]*?\5\s*)*)>" + #def get_internal_items( + # self + #) -> tuple[list[tuple[Span, Span]], list[tuple[Span, dict[str, str]]]]: + # if not self.is_markup: + # return [], [] + + # tag_pattern = r"<(/?)(\w+)\s*((\w+\s*\=\s*(['\x22])[\s\S]*?\5\s*)*)>" + # attr_pattern = r"(\w+)\s*\=\s*(['\x22])([\s\S]*?)\2" + # begin_match_obj_stack = [] + # markup_tag_items = [] + # for match_obj in re.finditer(tag_pattern, self.string): + # if not match_obj.group(1): + # begin_match_obj_stack.append(match_obj) + # continue + # begin_match_obj = begin_match_obj_stack.pop() + # tag_name = begin_match_obj.group(2) + # if tag_name == "span": + # attr_dict = { + # attr_match_obj.group(1): attr_match_obj.group(3) + # for attr_match_obj in re.finditer( + # attr_pattern, begin_match_obj.group(3) + # ) + # } + # else: + # attr_dict = MARKUP_TAG_CONVERSION_DICT.get(tag_name, {}) + # markup_tag_items.append( + # (begin_match_obj.span(), match_obj.span(), attr_dict) + # ) + + # tag_span_pairs = [ + # (tag_begin_span, tag_end_span) + # for tag_begin_span, tag_end_span, _ in markup_tag_items + # ] + # internal_items = [ + # ((span_begin, span_end), attr_dict) + # for (_, span_begin), (span_end, _), attr_dict in markup_tag_items + # ] + # return tag_span_pairs, internal_items + + #def get_external_items(self) -> list[tuple[Span, dict[str, str]]]: + # return [ + # (self.full_span, self.get_global_attr_dict()), + # (self.full_span, self.global_config), + # *[ + # (span, {key: val}) + # for t2x_dict, key in ( + # (self.t2c, "foreground"), + # (self.t2f, "font_family"), + # (self.t2s, "font_style"), + # (self.t2w, "font_weight") + # ) + # for selector, val in t2x_dict.items() + # for span in self.find_spans_by_selector(selector) + # ], + # *[ + # (span, local_config) + # for selector, local_config in self.local_configs.items() + # for span in self.find_spans_by_selector(selector) + # ] + # ] + #if self.split_words: + # # For backward compatibility + # result.extend([ + # (span, {}) + # for pattern in (r"[a-zA-Z]+", r"\S+") + # for span in self.find_spans(pattern) + # ]) + #return result + + def get_specified_items( + self, cmd_span_pairs: list[tuple[Span, Span]] + ) -> list[tuple[Span, dict[str, str]]]: attr_pattern = r"(\w+)\s*\=\s*(['\x22])([\s\S]*?)\2" - begin_match_obj_stack = [] - markup_tag_items = [] - for match_obj in re.finditer(tag_pattern, self.string): - if not match_obj.group(1): - begin_match_obj_stack.append(match_obj) - continue - begin_match_obj = begin_match_obj_stack.pop() - tag_name = begin_match_obj.group(2) + internal_items = [] + for begin_cmd_span, end_cmd_span in cmd_span_pairs: + begin_tag = self.get_substr(begin_cmd_span) + tag_name = re.match(r"<(\w+)", begin_tag).group(1) if tag_name == "span": attr_dict = { attr_match_obj.group(1): attr_match_obj.group(3) - for attr_match_obj in re.finditer( - attr_pattern, begin_match_obj.group(3) - ) + for attr_match_obj in re.finditer(attr_pattern, begin_tag) } else: attr_dict = MARKUP_TAG_CONVERSION_DICT.get(tag_name, {}) - markup_tag_items.append( - (begin_match_obj.span(), match_obj.span(), attr_dict) + internal_items.append( + ((begin_cmd_span[1], end_cmd_span[0]), attr_dict) ) - tag_span_pairs = [ - (tag_begin_span, tag_end_span) - for tag_begin_span, tag_end_span, _ in markup_tag_items - ] - internal_items = [ - ((span_begin, span_end), attr_dict) - for (_, span_begin), (span_end, _), attr_dict in markup_tag_items - ] - return tag_span_pairs, internal_items - - def get_external_items(self) -> list[tuple[Span, dict[str, str]]]: - result = [ + return [ (self.full_span, self.get_global_attr_dict()), (self.full_span, self.global_config), + *internal_items, *[ (span, {key: val}) for t2x_dict, key in ( @@ -532,14 +595,7 @@ class MarkupText(LabelledString): for span in self.find_spans_by_selector(selector) ] ] - if self.split_words: - # For backward compatibility - result.extend([ - (span, {}) - for span in self.find_spans(r"[a-zA-Z]+") - for pattern in (r"[a-zA-Z]+", r"\S+") - ]) - return result + #def get_label_span_list(self, split_spans: list[Span]) -> list[Span]: @@ -582,17 +638,17 @@ class MarkupText(LabelledString): #) -> list[tuple[Span, tuple[str, str]]]: # return [] - def get_command_repl_items(self) -> list[Span, str]: - result = [ - (tag_span, "") for tag_span in self.tag_spans - ] - if not self.is_markup: - result.extend([ - (span, escaped) - for char, escaped in XML_ENTITIES - for span in self.find_spans(re.escape(char)) - ]) - return result + #def get_command_repl_items(self) -> list[Span, str]: + # result = [ + # (tag_span, "") for tag_span in self.tag_spans # TODO + # ] + # if not self.is_markup: + # result.extend([ + # (span, escaped) + # for char, escaped in XML_ENTITIES + # for span in self.find_spans(re.escape(char)) + # ]) + # return result #def get_predefined_inserted_str_items( # self, split_items: list[tuple[Span, dict[str, str]]] @@ -767,13 +823,31 @@ class MarkupText(LabelledString): # for span, attr_dict in attr_dict_items # ] + def get_replaced_substr(self, substr: str, flag: int) -> str: + if flag: + return "" + return dict(XML_ENTITIES).get(substr, substr) + def get_full_content_string(self, content_string: str, is_labelled: bool) -> str: return content_string # Selector def get_cleaned_substr(self, span: Span) -> str: - return self.get_substr(span) # TODO: test + filtered_repl_items = [] + entity_to_char_dict = { + entity: char + for char, entity in XML_ENTITIES + } + for cmd_span, replaced_substr in self.command_repl_items: + if not self.span_contains(span, cmd_span): + continue + if re.fullmatch(r"&[\s\S]*;", replaced_substr): + if replaced_substr in entity_to_char_dict: + replaced_substr = entity_to_char_dict[replaced_substr] + filtered_repl_items.append((cmd_span, replaced_substr)) + + return self.replace_string(span, filtered_repl_items).strip() # TODO: test #repl_items = [ # (cmd_span, repl_str) # for cmd_span, (repl_str, _) in self.command_repl_items From 511a3aab3d3f50a621663d49f0988dd1c6ae70f2 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Wed, 4 May 2022 22:18:19 +0800 Subject: [PATCH 07/11] [WIP] Remove comments --- manimlib/animation/creation.py | 5 +- manimlib/mobject/svg/labelled_string.py | 611 ++---------------------- manimlib/mobject/svg/mtex_mobject.py | 577 +--------------------- manimlib/mobject/svg/text_mobject.py | 487 +------------------ 4 files changed, 55 insertions(+), 1625 deletions(-) diff --git a/manimlib/animation/creation.py b/manimlib/animation/creation.py index 42ca4bf8..d3a3dd23 100644 --- a/manimlib/animation/creation.py +++ b/manimlib/animation/creation.py @@ -213,10 +213,7 @@ class AddTextWordByWord(ShowIncreasingSubsets): def __init__(self, string_mobject, **kwargs): assert isinstance(string_mobject, LabelledString) - grouped_mobject = string_mobject.build_parts_from_indices_lists([ - indices_list - for _, indices_list in string_mobject.get_group_part_items() - ]) + grouped_mobject = string_mobject.build_groups() digest_config(self, kwargs) if self.run_time is None: self.run_time = self.time_per_word * len(grouped_mobject) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index da72fa67..03b5da8b 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -60,10 +60,6 @@ class LabelledString(SVGMobject, ABC): self.full_span = (0, len(self.string)) self.parse() super().__init__(**kwargs) - #self.labelled_submobject_items = [ - # (submob.label, submob) - # for submob in self.submobjects - #] self.labels = [submob.label for submob in self.submobjects] def get_file_path(self) -> str: @@ -102,8 +98,8 @@ class LabelledString(SVGMobject, ABC): ) submob_color_ints = [0] * num_submobjects + #TODO: remove this #if self.sort_labelled_submobs: - # TODO: remove this submob_indices = sorted( range(num_submobjects), key=lambda index: tuple( @@ -136,20 +132,12 @@ class LabelledString(SVGMobject, ABC): def get_substr(self, span: Span) -> str: return self.string[slice(*span)] - #def match(self, pattern: str | re.Pattern, **kwargs) -> re.Pattern | None: - # if isinstance(pattern, str): - # pattern = re.compile(pattern) - # return re.compile(pattern).match(self.string, **kwargs) - def find_spans(self, pattern: str) -> list[Span]: return [ match_obj.span() for match_obj in re.finditer(pattern, self.string) ] - #def find_indices(self, pattern: str | re.Pattern, **kwargs) -> list[int]: - # return [index for index, _ in self.find_spans(pattern, **kwargs)] - def find_spans_by_selector(self, selector: Selector) -> list[Span]: def find_spans_by_single_selector(sel): if isinstance(sel, str): @@ -241,229 +229,6 @@ class LabelledString(SVGMobject, ABC): def span_contains(span_0: Span, span_1: Span) -> bool: return span_0[0] <= span_1[0] and span_0[1] >= span_1[1] - #def get_level_interval_spans( - # self, - # tag_span_pairs: list[tuple[Span, Span]], - # entity_spans: list[Span] - #) -> list[tuple[Span, int]]: - # return sorted(self.chain( - # [(begin_cmd_span, 1) for begin_cmd_span, _ in tag_span_pairs], - # [(end_cmd_span, -1) for _, end_cmd_span in tag_span_pairs], - # [(entity_span, 0) for entity_span in entity_spans], - # ), key=lambda t: t[0]) - # #piece_spans = self.get_complement_spans(self.full_span, [ - # # interval_span for interval_span, _ in level_interval_spans - # #]) - # #piece_levels = [0, *it.accumulate([tag for _, tag in level_interval_spans])] - # #return piece_spans, piece_levels - - def split_span_by_levels( - self, arbitrary_span: Span - ) -> tuple[list[Span], int, int]: - # ignorable_indices -- - # left_bracket_spans - # right_bracket_spans - # entity_spans - #piece_spans, piece_levels = zip(*self.piece_items) - #ignorable_indices = self.ignorable_indices - #piece_spans = self.piece_spans - #piece_levels = self.piece_levels - #piece_begins, piece_ends = zip(*piece_spans) - #span_begin, span_end = arbitrary_span - #while span_begin in ignorable_indices: - # span_begin += 1 - #while span_end - 1 in ignorable_indices: - # span_end -= 1 - #entity_spans = self.chain( - # left_bracket_spans, right_bracket_spans, entity_spans - #) - #if arbitrary_span[0] > arbitrary_span[1]: - # return [] - - #level_interval_span_items = self.level_interval_span_items - #if not level_interval_span_items: - # #if - # return [arbitrary_span] - - #span_begin, span_end = arbitrary_span - #print(level_interval_span_items) - #level_interval_spans, level_shifts = zip(*level_interval_span_items) # TODO: avoid empty list - interval_span_items = self.cmd_span_items - interval_spans = [span for span, _ in interval_span_items] - #level_interval_spans = self.level_interval_spans - #level_shifts = self.level_shifts - #print(level_interval_span_items, arbitrary_span) - #index_begin = sum([ - # arbitrary_span[0] > piece_end - # for _, piece_end in piece_spans - #]) - #interval_index_begin = sum([ - # span_begin >= interval_begin - # for interval_begin, _ in level_interval_spans - #]) - #index_end = sum([ - # arbitrary_span[1] >= piece_begin - # for piece_begin, _ in piece_spans - #]) - #interval_index_end = sum([ - # span_end >= interval_end - # for _, interval_end in level_interval_spans - #]) - #interval_range = ( - # sum([ - # arbitrary_span[0] >= interval_begin - # for interval_begin, _ in interval_spans - # ]), - # sum([ - # arbitrary_span[1] >= interval_end - # for _, interval_end in interval_spans - # ]) - #) - #interval_range = (interval_range[0], interval_range[1] - len(level_interval_spans)) - #print(interval_index_begin, interval_index_end) - #complement_spans = self.get_complement_spans(self.full_span, interval_spans) - #adjusted_span = ( - # #max(arbitrary_span[0], level_interval_spans[interval_range[0] - 1][1]), - # #if interval_range[0] > 0 else arbitrary_span[0], - # #min(arbitrary_span[1], level_interval_spans[interval_range[1]][0]) - # #if interval_range[1] < len(level_interval_spans) else arbitrary_span[1] - #) - #adjusted_span = ( - # max(arbitrary_span[0], complement_spans[interval_range[0]][0]), - # min(arbitrary_span[1], complement_spans[interval_range[1]][1]) - #) - #print(arbitrary_span, adjusted_span) - - interval_range = ( - sum([ - arbitrary_span[0] > interval_begin - for interval_begin, _ in interval_spans - ]), - sum([ - arbitrary_span[1] >= interval_end - for _, interval_end in interval_spans - ]) - ) - complement_spans = self.get_complement_spans(self.full_span, interval_spans) - adjusted_span = ( - max(arbitrary_span[0], complement_spans[interval_range[0]][0]), - min(arbitrary_span[1], complement_spans[interval_range[1]][1]) - ) - if adjusted_span[0] > adjusted_span[1]: - #print([]) - return [], 0, 0 - - #lowest_level = min( - # piece_levels[index_begin:index_end] - #) - #split_piece_indices = [] - #target_level = piece_levels[index_begin] - #for piece_index in range(index_begin, index_end): - # if piece_levels[piece_index] != target_level: - # continue - # split_piece_indices.append(piece_index) - # target_level -= 1 - # if target_level < lowest_level: - # break - #len_indices = len(split_piece_indices) - #target_level = piece_levels[index_end - 1] - #for piece_index in range(index_begin, index_end)[::-1]: - # if piece_levels[piece_index] != target_level: - # continue - # split_piece_indices.insert(len_indices, piece_index + 1) - # target_level -= 1 - # if target_level < lowest_level: - # break - upwards_stack = [] - downwards_stack = [] - for interval_index in range(*interval_range): - _, level_shift = interval_span_items[interval_index] - if level_shift == 1: - upwards_stack.append(interval_index) - elif level_shift == -1: - if upwards_stack: - upwards_stack.pop() - else: - downwards_stack.append(interval_index) - #split_piece_indices = downwards_stack + upwards_stack - #print(split_piece_indices) - - covered_interval_spans = [ - interval_spans[piece_index] - for piece_index in self.chain(downwards_stack, upwards_stack) - ] - result = self.get_complement_spans(adjusted_span, covered_interval_spans) - return result, len(downwards_stack), len(upwards_stack) - #if interval_index_begin > 0: - # span_begin = max(span_begin, level_interval_spans[interval_index_begin - 1][1]) - #if interval_index_end < len(level_interval_spans): - # span_end = min(span_end, level_interval_spans[interval_index_end][0]) - #universal_span = (span_begin, span_end) - #print(universal_span, self.get_complement_spans(universal_span, interval_spans)) - #print(self.get_complement_spans(adjusted_span, interval_spans)) - #span_begins = [ - # level_interval_spans[piece_index][0][1] - # for piece_index in split_piece_indices - #] - #span_begins[0] = max(arbitrary_span[0], span_begins[0]) - #span_ends = [ - # level_interval_spans[piece_index - 1][0][1] - # for piece_index in split_piece_indices[1:] - #] - #span_ends[-1] = min(arbitrary_span[1], span_ends[-1]) - #return list(zip(span_begins, span_ends)) - #lowest_level_indices = [ - # piece_index - # for piece_index, piece_level in enumerate(piece_levels) - # if left_piece_index <= piece_index <= right_piece_index - # and piece_level == lowest_level - #] - #left_lowest_index = min(lowest_level_indices) - #right_lowest_index = max(lowest_level_indices) - #while right_lowest_index != right_piece_index: - - - #left_parallel_index = max( - # piece_index - # for piece_index, piece_level in enumerate(piece_levels) - # if left_piece_index <= piece_index <= right_piece_index - # and piece_level == piece_levels[left_piece_index] - #) - #right_parallel_index = min( - # piece_index - # for piece_index, piece_level in enumerate(piece_levels) - # if left_piece_index <= piece_index <= right_piece_index - # and piece_level == piece_levels[right_piece_index] - #) - #result.append(( - # piece_spans[left_lowest_index][0], - # piece_spans[right_lowest_index][1] - #)) - #lowest_piece_indices = [ - # piece_index - # for piece_index, piece_level in enumerate( - - # ) - #] - #adjusted_span_begin = max(span_begin, piece_spans[begin_piece_index][0]) ## - #adjusted_span_end = min(span_end, piece_spans[end_piece_index][1]) ## - #begin_level_mismatch = piece_levels[begin_piece_index] - lowest_level - #end_level_mismatch = piece_levels[end_piece_index] - lowest_level - #if begin_level_mismatch: - # span_begin = piece_spans[max([ - # index - # for index, piece_level in enumerate(piece_levels) - # if piece_level == lowest_level and index < begin_piece_index - # ])][1] - # begin_level_mismatch = 0 - #if end_level_mismatch: - # span_end = piece_spans[min([ - # index - # for index, piece_level in enumerate(piece_levels) - # if piece_level == lowest_level and index > end_piece_index - # ])][0] - # end_level_mismatch = 0 - @staticmethod def get_complement_spans( universal_span: Span, interval_spans: list[Span] @@ -477,7 +242,7 @@ class LabelledString(SVGMobject, ABC): (*span_ends, universal_span[1]) )) - def replace_string(self, span: Span, repl_items: list[Span, str]): # TODO: need `span` attr? + def replace_string(self, span: Span, repl_items: list[Span, str]): if not repl_items: return self.get_substr(span) @@ -491,23 +256,6 @@ class LabelledString(SVGMobject, ABC): repl_strs = [*repl_strs, ""] return "".join(self.chain(*zip(pieces, repl_strs))) - #def get_replaced_string( - # self, - # inserted_string_pairs: list[tuple[Span, tuple[str, str]]], - # repl_items: list[tuple[Span, str]] - #) -> str: - # all_repl_items = self.chain( - # repl_items, - # [ - # ((index, index), inserted_string) - # for index, inserted_string - # in self.sort_inserted_strings_from_pairs( - # inserted_string_pairs - # ) - # ] - # ) - # return self.replace_string(self.full_span, all_repl_items) - @staticmethod def color_to_hex(color: ManimColor) -> str: return rgb_to_hex(color_to_rgb(color)) @@ -527,66 +275,8 @@ class LabelledString(SVGMobject, ABC): ) -> tuple[str, str]: return ("", "") - #def get_color_tag_str(self, rgb_int: int, is_begin_tag: bool) -> str: - # return self.get_tag_str({ - # "foreground": self.int_to_hex(rgb_int) - # }, escape_color_keys=False, is_begin_tag=is_begin_tag) - # Parsing - #@abstractmethod - #def get_command_spans(self) -> list[Span]: - # return [] - # #return [ - # # self.match(r"\\(?:[a-zA-Z]+|.)", pos=index).span() - # # for index in self.backslash_indices - # #] - - #@abstractmethod - #@staticmethod - #def get_command_repl_dict() -> dict[str | re.Pattern, str]: - # return {} - - #@abstractmethod - #def parse_setup(self) -> None: - # return - - #@abstractmethod - #def get_command_repl_items(self) -> list[tuple[Span, str]]: - # return [] - # #result = [] - # #for cmd_span in self.command_spans: - # # cmd_str = self.get_substr(cmd_span) - # # if - # # repl_str = self.command_repl_dict.get(cmd_str, cmd_str) - # # result.append((cmd_span, repl_str)) - # #return result - - #def span_cuts_at_entity(self, span: Span) -> bool: - # return any([ - # entity_begin < index < entity_end - # for index in span - # for entity_begin, entity_end in self.command_repl_items - # ]) - - #@abstractmethod - #def get_all_specified_items(self) -> list[tuple[Span, dict[str, str]]]: - # return [] - - #def get_specified_items(self) -> list[tuple[Span, dict[str, str]]]: - # return [ - # (span, attr_dict) - # for span, attr_dict in self.get_all_specified_items() - # if not any([ - # entity_begin < index < entity_end - # for index in span - # for entity_begin, entity_end in self.command_repl_items - # ]) - # ] - - #def get_specified_spans(self) -> list[Span]: - # return [span for span, _ in self.specified_items] - def parse(self) -> None: begin_cmd_spans, end_cmd_spans, cmd_spans = self.get_command_spans() @@ -611,105 +301,11 @@ class LabelledString(SVGMobject, ABC): raise ValueError("Missing '}' inserted") specified_items = self.get_specified_items(cmd_span_pairs) - - #entity_spans = self.get_entity_spans() - #self.entity_spans = entity_spans - #tag_span_pairs, internal_items = self.get_internal_items() - #self.level_interval_spans = self.get_level_interval_spans( - # tag_span_pairs, self.entity_spans - #) - #self.level_interval_spans = [ - # level_interval_span - # for level_interval_span, _ in level_interval_span_items - #] - #self.level_shifts = [ - # level_shift - # for _, level_shift in level_interval_span_items - #] # TODO - #self.tag_content_spans = [ - # (content_begin, content_end) - # for (_, content_begin), (content_end, _) in tag_span_pairs - #] - #self.tag_spans = self.chain(*tag_span_pairs) - #specified_items = self.chain( - # self.get_specified_items(cmd_span_pairs) - # internal_items, - # self.get_external_items(), - # [ - # (span, {}) - # for span in self.find_spans_by_selector(self.isolate) - # ] - #) - #print(f"\n{specified_items=}\n") - #specified_spans = - - split_items = [ (span, attr_dict) for specified_span, attr_dict in specified_items for span in self.split_span_by_levels(specified_span)[0] ] - #print([self.get_substr(span) for span, _ in specified_items]) - #print([self.get_substr(span) for span, _ in split_items]) - #print(f"\n{split_items=}\n") - #labelled_spans = [span for span, _ in split_items] - #labelled_spans = self.get_labelled_spans(split_spans) - #if len(labelled_spans) >= 16777216: - # raise ValueError("Cannot handle that many substrings") - - #content_strings = [] - #for is_labelled in (False, True): - # - # content_strings.append(content_string) - - #inserted_str_pairs = self.chain( - # [ - # (span, ( - # self.get_tag_str(attr_dict, escape_color_keys=True, is_begin_tag=True), - # self.get_tag_str(attr_dict, escape_color_keys=True, is_begin_tag=False) - # )) - # for span, attr_dict in split_items - # ], - # [ - # (span, ( - # self.get_color_tag_str(label + 1, is_begin_tag=True), - # self.get_color_tag_str(label + 1, is_begin_tag=False) - # )) - # for span, attr_dict in split_items - # ] - #) - - - #decorated_strings = [ - # self.replace_string(self.full_span, [ - # (span, str_pair[flag]) - # for span, str_pair in command_repl_items - # ]) - # for flag in range(2) - #] - - #full_content_strings = {} - #for is_labelled in (False, True): - # inserted_str_pairs = [ - # (span, self.get_tag_string_pair( - # attr_dict, - # rgb_hex=self.int_to_hex(label + 1) if is_labelled else None - # )) - # for label, (span, attr_dict) in enumerate(split_items) - # ] - # repl_items = self.chain( - # command_repl_items, - # [ - # ((index, index), inserted_str) - # for index, inserted_str - # in self.sort_obj_pairs_by_spans(inserted_str_pairs) - # ] - # ) - # content_string = self.replace_string( - # self.full_span, repl_items - # ) - # full_content_string = self.get_full_content_string(content_string) - # #full_content_strings[is_labelled] = full_content_string command_repl_items = [ (span, self.get_replaced_substr(self.get_substr(span), flag)) @@ -718,14 +314,17 @@ class LabelledString(SVGMobject, ABC): self.command_repl_items = command_repl_items self.specified_spans = [span for span, _ in specified_items] - self.labelled_spans = [span for span, _ in split_items] - for span_0, span_1 in it.product(self.labelled_spans, repeat=2): + labelled_spans = [span for span, _ in split_items] + if len(labelled_spans) >= 16777216: + raise ValueError("Cannot handle that many substrings") + for span_0, span_1 in it.product(labelled_spans, repeat=2): if not span_0[0] < span_1[0] < span_0[1] < span_1[1]: continue raise ValueError( "Partially overlapping substrings detected: " f"'{self.get_substr(span_0)}' and '{self.get_substr(span_1)}'" ) + self.labelled_spans = labelled_spans self.original_content, self.labelled_content = ( self.get_full_content_string(self.replace_string( @@ -745,162 +344,59 @@ class LabelledString(SVGMobject, ABC): ), is_labelled=is_labelled) for is_labelled in (False, True) ) - print(self.original_content) - print() - print(self.labelled_content) + def split_span_by_levels( + self, arbitrary_span: Span + ) -> tuple[list[Span], int, int]: + interval_span_items = self.cmd_span_items + interval_spans = [span for span, _ in interval_span_items] + interval_range = ( + sum([ + arbitrary_span[0] > interval_begin + for interval_begin, _ in interval_spans + ]), + sum([ + arbitrary_span[1] >= interval_end + for _, interval_end in interval_spans + ]) + ) + complement_spans = self.get_complement_spans(self.full_span, interval_spans) + adjusted_span = ( + max(arbitrary_span[0], complement_spans[interval_range[0]][0]), + min(arbitrary_span[1], complement_spans[interval_range[1]][1]) + ) + if adjusted_span[0] > adjusted_span[1]: + return [], 0, 0 - #self.original_content = full_content_strings[False] - #self.labelled_content = full_content_strings[True] - #print(self.original_content) - #print() - #print(self.labelled_content) + upwards_stack = [] + downwards_stack = [] + for interval_index in range(*interval_range): + _, level_shift = interval_span_items[interval_index] + if level_shift == 1: + upwards_stack.append(interval_index) + elif level_shift == -1: + if upwards_stack: + upwards_stack.pop() + else: + downwards_stack.append(interval_index) - - #self.command_repl_dict = self.get_command_repl_dict() - #self.command_repl_items = [] - #self.bracket_content_spans = [] - ##self.command_spans = self.get_command_spans() - ##self.specified_items = self.get_specified_items() - #self.specified_spans = [] - #self.check_overlapping() ####### - #self.labelled_spans = [] - #if len(self.labelled_spans) >= 16777216: - # raise ValueError("Cannot handle that many substrings") + covered_interval_spans = [ + interval_spans[piece_index] + for piece_index in self.chain(downwards_stack, upwards_stack) + ] + result = self.get_complement_spans(adjusted_span, covered_interval_spans) + return result, len(downwards_stack), len(upwards_stack) @abstractmethod def get_command_spans(self) -> tuple[list[Span], list[Span], list[Span]]: return [], [], [] - #@abstractmethod - #def get_entity_spans(self) -> list[Span]: - # return [] - - #@abstractmethod - #def get_internal_items( - # self - #) -> tuple[list[tuple[Span, Span]], list[tuple[Span, dict[str, str]]]]: - # return [], [] - @abstractmethod def get_specified_items( self, cmd_span_pairs: list[tuple[Span, Span]] ) -> list[tuple[Span, dict[str, str]]]: return [] - #@abstractmethod - #def get_spans_from_items(self, specified_items: list[tuple[Span, dict[str, str]]]) -> list[Span]: - # return [] - - #def split_span(self, arbitrary_span: Span) -> list[Span]: - # span_begin, span_end = arbitrary_span - # # TODO: improve algorithm - # span_begin += sum([ - # entity_end - span_begin - # for entity_begin, entity_end in self.entity_spans - # if entity_begin < span_begin < entity_end - # ]) - # span_end -= sum([ - # span_end - entity_begin - # for entity_begin, entity_end in self.entity_spans - # if entity_begin < span_end < entity_end - # ]) - # if span_begin >= span_end: - # return [] - - # adjusted_span = (span_begin, span_end) - # result = [] - # span_choices = list(filter( - # lambda span: span[0] < span[1] and self.span_contains( - # adjusted_span, span - # ), - # self.tag_content_spans - # )) - # while span_choices: - # chosen_span = min(span_choices, key=lambda t: (t[0], -t[1])) - # result.append(chosen_span) - # span_choices = list(filter( - # lambda span: chosen_span[1] <= span[0], - # span_choices - # )) - # result.extend(self.chain(*[ - # self.get_complement_spans(span, sorted([ - # (max(tag_span[0], span[0]), min(tag_span[1], span[1])) - # for tag_span in self.tag_spans - # if tag_span[0] < span[1] and span[0] < tag_span[1] - # ])) - # for span in self.get_complement_spans(adjusted_span, result) - # ])) - # return list(filter(lambda span: span[0] < span[1], result)) - - #@abstractmethod - #def get_split_items(self, specified_items: list[T]) -> list[T]: - # return [] - - #@abstractmethod - #def get_labelled_spans(self, split_spans: list[Span]) -> list[Span]: - # return [] - - #@abstractmethod - #def get_predefined_inserted_str_items( - # self, split_items: list[T] - #) -> list[tuple[Span, tuple[str, str]]]: - # return [] - - #def check_overlapping(self) -> None: - - #for span_0, span_1 in it.product(self.specified_spans, self.bracket_content_spans): - # if not any( - # span_0[0] < span_1[0] <= span_0[1] <= span_1[1], - # span_1[0] <= span_0[0] <= span_1[1] < span_0[1] - # ): - # continue - # raise ValueError( - # f"Invalid substring detected: '{self.get_substr(span_0)}'" - # ) - # TODO: test bracket_content_spans - - #@abstractmethod - #def get_inserted_string_pairs( - # self, is_labelled: bool - #) -> list[tuple[Span, tuple[str, str]]]: - # return [] - - #@abstractmethod - #def get_labelled_spans(self) -> list[Span]: - # return [] - - #def get_decorated_string( - # self, is_labelled: bool, replace_commands: bool - #) -> str: - # inserted_string_pairs = [ - # (indices, str_pair) - # for indices, str_pair in self.get_inserted_string_pairs( - # is_labelled=is_labelled - # ) - # if not any( - # cmd_begin < index < cmd_end - # for index in indices - # for (cmd_begin, cmd_end), _ in self.command_repl_items - # ) - # ] - # repl_items = [ - # ((index, index), inserted_string) - # for index, inserted_string - # in self.sort_inserted_strings_from_pairs( - # inserted_string_pairs - # ) - # ] - # if replace_commands: - # repl_items.extend(self.command_repl_items) - # return self.get_replaced_substr(self.full_span, repl_items) - - #@abstractmethod - #def get_additional_inserted_str_pairs( - # self - #) -> list[tuple[Span, tuple[str, str]]]: - # return [] - @abstractmethod def get_replaced_substr(self, substr: str, flag: int) -> str: return "" @@ -909,9 +405,6 @@ class LabelledString(SVGMobject, ABC): def get_full_content_string(self, content_string: str, is_labelled: bool) -> str: return "" - #def get_content(self, is_labelled: bool) -> str: - # return self.content_strings[int(is_labelled)] - # Selector @abstractmethod @@ -996,13 +489,11 @@ class LabelledString(SVGMobject, ABC): for indices_list in indices_lists ]) - #def select_part_by_span(self, arbitrary_span: Span) -> VGroup: - # return VGroup(*[ - # self.labelled_submobject_items[submob_index] - # for submob_index in self.get_submob_indices_list_by_span( - # arbitrary_span - # ) - # ]) + def build_groups(self) -> VGroup: + return self.build_parts_from_indices_lists([ + indices_list + for _, indices_list in self.get_group_part_items() + ]) def select_parts(self, selector: Selector) -> VGroup: return self.build_parts_from_indices_lists( diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 03896e9c..823e4f81 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -31,16 +31,6 @@ if TYPE_CHECKING: SCALE_FACTOR_PER_FONT_POINT = 0.001 -#TEX_COLOR_COMMANDS_DICT = { -# "\\color": (1, False), -# "\\textcolor": (1, False), -# "\\pagecolor": (1, True), -# "\\colorbox": (1, True), -# "\\fcolorbox": (2, True), -#} -#TEX_COLOR_COMMAND_SUFFIX = "replaced" - - class MTex(LabelledString): CONFIG = { "font_size": 48, @@ -104,97 +94,8 @@ class MTex(LabelledString): return ("", "") return ("{{" + MTex.get_color_command_str(label_hex), "}}") - #@staticmethod - #def shrink_span(span: Span, skippable_indices: list[int]) -> Span: - # span_begin, span_end = span - # while span_begin in skippable_indices: - # span_begin += 1 - # while span_end - 1 in skippable_indices: - # span_end -= 1 - # return (span_begin, span_end) - # Parsing - #def parse(self) -> None: # TODO - #command_spans = self.find_spans(r"\\(?:[a-zA-Z]+|.)") - - - #specified_spans = self.chain( - # inner_content_spans, - # *[ - # self.find_spans_by_selector(selector) - # for selector in self.tex_to_color_map.keys() - # ], - # self.find_spans_by_selector(self.isolate) - #) - #print(specified_spans) - #label_span_list = self.remove_redundancies(self.chain(*[ - # self.split_span(span) - # for span in specified_spans - #])) - #print(label_span_list) - #for span in all_specified_spans: - # adjusted_span, _, _ = self.adjust_span(span, align_level=True) - # if adjusted_span[0] > adjusted_span[1]: - # continue - # specified_spans.append(adjusted_span) - - - - #reversed_script_spans_dict = { - # span_end: span_begin - # for span_begin, _, span_end in script_items - #} - #label_span_list = [ - # (content_begin, span_end) - # for _, content_begin, span_end in script_items - #] - #for span_begin, span_end in specified_spans: - # while span_end in reversed_script_spans_dict: - # span_end = reversed_script_spans_dict[span_end] - # if span_begin >= span_end: - # continue - # shrinked_span = (span_begin, span_end) - # if shrinked_span in label_span_list: - # continue - # label_span_list.append(shrinked_span) - - #inserted_str_items = [ - # (span, ( - # ("{{", "{{" + self.get_color_command_str(label + 1)), - # ("}}", "}}"), - # )) - # for label, span in enumerate(label_span_list) - #] - #command_repl_items = [ - # ((index, index), str_pair) - # for index, str_pair in self.sort_obj_pairs_by_spans(inserted_str_items) - #] - #for cmd_span in command_spans: - # cmd_str = self.get_substr(cmd_span) - # if cmd_str not in TEX_COLOR_COMMANDS_DICT: - # continue - # repl_str = f"{cmd_str}{TEX_COLOR_COMMAND_SUFFIX}" - # command_repl_items.append((cmd_span, (cmd_str, repl_str))) - #print(decorated_strings) - #return specified_spans, label_span_list, decorated_strings - - - - #self.command_spans = self.find_spans(r"\\(?:[a-zA-Z]+|.)") - #self.ignorable_indices = self.get_ignorable_indices() - #self.brace_content_spans = self.get_brace_content_spans() - #self.command_repl_items = self.get_command_repl_items() - ##self.backslash_indices = self.get_backslash_indices() - #self.ignorable_indices = self.get_ignorable_indices() - ##self.script_items = self.get_script_items() - ##self.script_char_indices = self.get_script_char_indices() - ##self.script_content_spans = self.get_script_content_spans() - ##self.script_spans = self.get_script_spans() - #self.specified_spans = self.get_specified_spans() - ##super().parse() - #self.label_span_list = self.get_label_span_list() - def get_command_spans(self) -> tuple[list[Span], list[Span], list[Span]]: cmd_spans = self.find_spans(r"\\(?:[a-zA-Z]+|\s|\S)") begin_cmd_spans = [ @@ -209,85 +110,6 @@ class MTex(LabelledString): ] return begin_cmd_spans, end_cmd_spans, cmd_spans - #def get_entity_spans(self) -> list[Span]: - # return self.find_spans(r"\\(?:[a-zA-Z]+|.)") - - #def get_internal_items( - # self - #) -> tuple[list[tuple[Span, Span]], list[tuple[Span, dict[str, str]]]]: - # command_spans = self.entity_spans - # brace_span_pairs = [] - # brace_begin_spans_stack = [] - # for span in self.find_spans(r"[{}]"): - # char_index = span[0] - # if (char_index - 1, char_index + 1) in command_spans: - # continue - # if self.get_substr(span) == "{": - # brace_begin_spans_stack.append(span) - # else: - # if not brace_begin_spans_stack: - # raise ValueError("Missing '{' inserted") - # brace_span = brace_begin_spans_stack.pop() - # brace_span_pairs.append((brace_span, span)) - # if brace_begin_spans_stack: - # raise ValueError("Missing '}' inserted") - - #tag_span_pairs = brace_span_pairs.copy() - #script_entity_dict = dict(self.chain( - # [ - # (span_begin, span_end) - # for (span_begin, _), (_, span_end) in brace_span_pairs - # ], - # command_spans - #)) - #script_additional_brace_spans = [ - # (char_index + 1, script_entity_dict.get( - # script_begin, script_begin + 1 - # )) - # for char_index, script_begin in self.find_spans(r"[_^]\s*(?=.)") - # if (char_index - 1, char_index + 1) not in command_spans - #] - #for char_index, script_begin in self.find_spans(r"[_^]\s*(?=.)"): - # if (char_index - 1, char_index + 1) in command_spans: - # continue - # script_end = script_entity_dict.get(script_begin, script_begin + 1) - # tag_span_pairs.append( - # ((char_index, char_index + 1), (script_end, script_end)) - # ) - # script_additional_brace_spans.append((char_index + 1, script_end)) - - #tag_span_pairs = self.chain( - # brace_span_pairs, - # [ - # ((script_begin - 1, script_begin), (script_end, script_end)) - # for script_begin, script_end in script_additional_brace_spans - # ] - #) - - #brace_content_spans = [ - # (span_begin, span_end) - # for (_, span_begin), (span_end, _) in brace_span_pairs - #] - #internal_items = [ - # (brace_content_spans[range_begin], {}) - # for _, (range_begin, range_end) in self.compress_neighbours([ - # (span_begin + index, span_end - index) - # for index, (span_begin, span_end) in enumerate( - # brace_content_spans - # ) - # ]) - # if range_end - range_begin >= 2 - #] - ##self.script_additional_brace_spans = script_additional_brace_spans - #return brace_span_pairs, internal_items - - #def get_external_items(self) -> list[tuple[Span, dict[str, str]]]: - # return [ - # (span, {}) - # for selector in self.tex_to_color_map - # for span in self.find_spans_by_selector(selector) - # ] - def get_specified_items( self, cmd_span_pairs: list[tuple[Span, Span]] ) -> list[tuple[Span, dict[str, str]]]: @@ -315,292 +137,8 @@ class MTex(LabelledString): ) return [(span, {}) for span in specified_spans] - #def get_label_span_list(self, split_spans: list[Span]) -> list[Span]: - # return split_spans.copy() - - #def get_spans_from_items(self, specified_items: list[Span]) -> list[Span]: - # return specified_items - - #def get_split_items(self, specified_items: list[Span]) -> list[Span]: - # return self.remove_redundancies(self.chain(*[ - # self.split_span(span) - # for span in specified_items - # ])) - - #def get_label_span_list(self, split_spans: list[Span]) -> list[Span]: - # return split_spans - - #def get_additional_inserted_str_pairs( - # self - #) -> list[tuple[Span, tuple[str, str]]]: - # return [ - # (span, ("{", "}")) - # for span in self.script_additional_brace_spans - # ] - - #def get_command_repl_items(self) -> list[Span, str]: - # return [] - #if not is_labelled: - # return [] - #result = [] - #command_spans = self.entity_spans # TODO - #for cmd_span in command_spans: - # cmd_str = self.get_substr(cmd_span) - # if cmd_str not in TEX_COLOR_COMMANDS_DICT: - # continue - # repl_str = f"{cmd_str}{TEX_COLOR_COMMAND_SUFFIX}" - # result.append((cmd_span, repl_str)) - #return result - - #def get_predefined_inserted_str_items( - # self, split_items: list[Span] - #) -> list[tuple[Span, tuple[str, str]]]: - # return [] - - #def get_ignorable_indices(self) -> list[int]: - # return self.chain( - # [ - # index - # for index, _ in self.find_spans(r"\s") - # ], - # [ - # index - # for index, _ in self.find_spans(r"[_^{}]") - # if (index - 1, index + 1) not in self.command_spans - # ], - # ) - - #def get_bracket_content_spans(self) -> list[Span]: - # span_begins = [] - # span_ends = [] - # span_begins_stack = [] - # for match_obj in re.finditer(r"[{}]", self.string): - # index = match_obj.start() - # if (index - 1, index + 1) in command_spans: - # continue - # if match_obj.group() == "{": - # span_begins_stack.append(index + 1) - # else: - # if not span_begins_stack: - # raise ValueError("Missing '{' inserted") - # span_begins.append(span_begins_stack.pop()) - # span_ends.append(index) - # if span_begins_stack: - # raise ValueError("Missing '}' inserted") - # return list(zip(span_begins, span_ends)) - - #def get_command_repl_items(self) -> list[tuple[Span, str]]: - # result = [] - # for cmd_span in self.command_spans: - # cmd_str = self.get_substr(cmd_span) - # if cmd_str in TEX_COLOR_COMMANDS_DICT: - # repl_str = f"{cmd_str}{TEX_COLOR_COMMAND_SUFFIX}" - # else: - # repl_str = cmd_str - # result.append((cmd_span, repl_str)) - # return result - - #def get_specified_spans(self) -> list[Span]: - # # Match paired double braces (`{{...}}`). - # sorted_content_spans = sorted( - # self.bracket_content_spans, key=lambda t: t[1] - # ) - # inner_content_spans = [ - # sorted_content_spans[range_begin] - # for _, (range_begin, range_end) in self.compress_neighbours([ - # (span_begin + index, span_end - index) - # for index, (span_begin, span_end) in enumerate( - # sorted_content_spans - # ) - # ]) - # if range_end - range_begin >= 2 - # ] - # #inner_content_spans = [ - # # (span_begin + 1, span_end - 1) - # # for span_begin, span_end in inner_brace_spans - # # if span_end - span_begin > 2 - # #] - - # return self.remove_redundancies(self.chain( - # inner_content_spans, - # *[ - # self.find_spans_by_selector(selector) - # for selector in self.tex_to_color_map.keys() - # ], - # self.find_spans_by_selector(self.isolate) - # )) - # #return list(filter( - # # lambda span: not any([ - # # entity_begin < index < entity_end - # # for index in span - # # for entity_begin, entity_end in self.command_spans - # # ]), - # # result - # #)) - - #def get_label_span_list(self) -> tuple[list[int], list[Span]]: - # script_entity_dict = dict(self.chain( - # [ - # (span_begin - 1, span_end + 1) - # for span_begin, span_end in self.bracket_content_spans - # ], - # self.command_spans - # )) - # script_items = [] - # for match_obj in re.finditer(r"\s*([_^])\s*(?=.)", self.string): - # char_index = match_obj.start(1) - # if (char_index - 1, char_index + 1) in self.command_spans: - # continue - # span_begin, content_begin = match_obj.span() - # span_end = script_entity_dict.get(span_begin, content_begin + 1) - # script_items.append( - # (span_begin, char_index, content_begin, span_end) - # ) - - # reversed_script_spans_dict = { - # span_end: span_begin - # for span_begin, _, _, span_end in script_items - # } - # ignorable_indices = self.chain( - # [index for index, _ in self.find_spans(r"\s")], - # [char_index for _, char_index, _, _ in script_items] - # ) - # result = [ - # (content_begin, span_end) - # for _, _, content_begin, span_end in script_items - # ] - # for span in self.specified_spans: - # span_begin, span_end = self.shrink_span(span, ignorable_indices) - # while span_end in reversed_script_spans_dict: - # span_end = reversed_script_spans_dict[span_end] - # if span_begin >= span_end: - # continue - # shrinked_span = (span_begin, span_end) - # if shrinked_span in result: - # continue - # result.append(shrinked_span) - # return result - - #def get_command_spans(self) -> list[Span]: - # return self.find_spans() - - #def get_command_repl_items(self) -> list[Span]: - # return [ - # (span, self.get_substr(span)) - # for span in self.find_spans(r"\\(?:[a-zA-Z]+|.)") - # ] - - #def get_command_spans(self) -> list[Span]: - # return self.find_spans(r"\\(?:[a-zA-Z]+|.)") - #return [ - # self.match(r"\\(?:[a-zA-Z]+|.)", pos=index).span() - # for index in self.backslash_indices - #] - - #@staticmethod - #def get_command_repl_dict() -> dict[str | re.Pattern, str]: - # return { - # cmd_name: f"{cmd_name}replaced" - # for cmd_name in TEX_COLOR_COMMANDS_DICT - # } - - #def get_backslash_indices(self) -> list[int]: - # # The latter of `\\` doesn't count. - # return self.find_indices(r"\\.") - - #def get_unescaped_char_indices(self, char: str) -> list[int]: - # return list(filter( - # lambda index: index - 1 not in self.backslash_indices, - # self.find_indices(re.escape(char)) - # )) - - #def get_script_items(self) -> list[tuple[int, int, int, int]]: - # script_entity_dict = dict(self.chain( - # self.brace_spans, - # self.command_spans - # )) - # result = [] - # for match_obj in re.finditer(r"\s*([_^])\s*(?=.)", self.string): - # char_index = match_obj.start(1) - # if char_index - 1 in self.backslash_indices: - # continue - # span_begin, content_begin = match_obj.span() - # span_end = script_entity_dict.get(span_begin, content_begin + 1) - # result.append((span_begin, char_index, content_begin, span_end)) - # return result - - #def get_script_char_indices(self) -> list[int]: - # return self.chain(*[ - # self.get_unescaped_char_indices(char) - # for char in "_^" - # ]) - - #def get_script_content_spans(self) -> list[Span]: - # result = [] - # script_entity_dict = dict(self.chain( - # self.brace_spans, - # self.command_spans - # )) - # for index in self.script_char_indices: - # span_begin = self.match(r"\s*", pos=index + 1).end() - # if span_begin in script_entity_dict.keys(): - # span_end = script_entity_dict[span_begin] - # else: - # match_obj = self.match(r".", pos=span_begin) - # if match_obj is None: - # continue - # span_end = match_obj.end() - # result.append((span_begin, span_end)) - # return result - - #def get_script_spans(self) -> list[Span]: - # return [ - # ( - # self.match(r"[\s\S]*?(\s*)$", endpos=index).start(1), - # script_content_span[1] - # ) - # for index, script_content_span in zip( - # self.script_char_indices, self.script_content_spans - # ) - # ] - - #def get_command_repl_items(self) -> list[tuple[Span, str]]: - # result = [] - # brace_spans_dict = dict(self.brace_spans) - # brace_begins = list(brace_spans_dict.keys()) - # for cmd_span in self.command_spans: - # cmd_name = self.get_substr(cmd_span) - # if cmd_name not in TEX_COLOR_COMMANDS_DICT: - # continue - # n_braces, substitute_cmd = TEX_COLOR_COMMANDS_DICT[cmd_name] - # span_begin, span_end = cmd_span - # for _ in range(n_braces): - # span_end = brace_spans_dict[min(filter( - # lambda index: index >= span_end, - # brace_begins - # ))] - # if substitute_cmd: - # repl_str = cmd_name + n_braces * "{black}" - # else: - # repl_str = "" - # result.append(((span_begin, span_end), repl_str)) - # return result - - #def get_inserted_string_pairs( - # self, is_labelled: bool - #) -> list[tuple[Span, tuple[str, str]]]: - # if not is_labelled: - # return [] - # return [ - # (span, ( - # "{{" + self.get_color_command_str(label + 1), - # "}}" - # )) - # for label, span in enumerate(self.label_span_list) - # ] - def get_replaced_substr(self, substr: str, flag: int) -> str: - return substr # TODO: replace color commands + return substr def get_full_content_string(self, content_string: str, is_labelled: bool) -> str: result = content_string @@ -615,24 +153,6 @@ class MTex(LabelledString): if self.alignment: result = "\n".join([self.alignment, result]) - #if is_labelled: - # occurred_commands = [ - # # TODO - # self.get_substr(span) for span in self.entity_spans - # ] - # newcommand_lines = [ - # "".join([ - # f"\\newcommand{cmd_name}{TEX_COLOR_COMMAND_SUFFIX}", - # f"[{n_braces + 1}][]", - # "{", - # cmd_name + "{black}" * n_braces if substitute_cmd else "", - # "}" - # ]) - # for cmd_name, (n_braces, substitute_cmd) - # in TEX_COLOR_COMMANDS_DICT.items() - # if cmd_name in occurred_commands - # ] - # result = "\n".join([*newcommand_lines, result]) if not is_labelled: result = "\n".join([ self.get_color_command_str(self.base_color_hex), @@ -646,12 +166,6 @@ class MTex(LabelledString): backslash_indices = [ index for index, _ in self.find_spans(r"\\[\s\S]") ] - #ignored_spans = [ - # ignored_span - # for ignored_span in self.find_spans(r"[\s_^{}]+") - # if ignored_span[0] - 1 not in backslash_indices - #] - #shrinked_span, _ = self.adjust_span(span, ignored_spans) ignored_indices = [ index for index, _ in self.find_spans(r"[\s_^{}]") @@ -663,11 +177,6 @@ class MTex(LabelledString): while span_end - 1 in ignored_indices: span_end -= 1 shrinked_span = (span_begin, span_end) - #if span_begin >= span_end: - # return "" - - #shrinked_span = (span_begin, span_end) - _, unclosed_right_braces, unclosed_left_braces = self.split_span_by_levels(shrinked_span) whitespace_repl_items = [] for whitespace_span in self.find_spans(r"\s+"): @@ -684,95 +193,13 @@ class MTex(LabelledString): replaced_substr = "" whitespace_repl_items.append((whitespace_span, replaced_substr)) + _, unclosed_right_braces, unclosed_left_braces = self.split_span_by_levels(shrinked_span) return "".join([ unclosed_right_braces * "{", self.replace_string(shrinked_span, whitespace_repl_items), unclosed_left_braces * "}" ]) - - #interval_spans = [ - # span - # if span[0] - 1 not in backslash_indices - # else (span[0] + 1, span[1]) - # for span in self.find_spans(r"[\s_^{}]+") - #] - #adjusted_span, _ = self.adjust_span(span, interval_spans) - #if adjusted_span[0] >= adjusted_span[1]: - # return "" - - #left_brace_indices = list(filter( - # lambda index: self.get_substr((index, index + 1)) == "{", - # ignored_indices - #)) - #right_brace_indices = list(filter( - # lambda index: self.get_substr((index, index + 1)) == "}", - # ignored_indices - #)) - #unclosed_left_braces = 0 - #unclosed_right_braces = 0 - #for index in range(*adjusted_span): - # if index in left_brace_indices: - # unclosed_left_braces += 1 - # elif index in right_brace_indices: - # if unclosed_left_braces == 0: - # unclosed_right_braces += 1 - # else: - # unclosed_left_braces -= 1 - #adjusted_span, unclosed_left_braces, unclosed_right_braces \ - # = self.adjust_span(span, align_level=False) - #print(self.get_substr(span), "".join([ - # unclosed_right_braces * "{", - # self.get_substr(shrinked_span), - # unclosed_left_braces * "}" - #])) - #result = "".join([ - # unclosed_right_braces * "{", - # self.get_substr(shrinked_span), - # unclosed_left_braces * "}" - #]) - #return re.sub(r"\s+", " ", result) - - #return (span_begin, span_end) - #return self.get_substr(span) # TODO: test - #left_brace_indices = [ - # span_begin - 1 - # for span_begin, _ in self.brace_content_spans - #] - #right_brace_indices = [ - # span_end - # for _, span_end in self.brace_content_spans - #] - #skippable_indices = self.chain( - # self.ignorable_indices, - # #self.script_char_indices, - # left_brace_indices, - # right_brace_indices - #) - #shrinked_span = self.shrink_span(span, skippable_indices) - - ##if shrinked_span[0] >= shrinked_span[1]: - ## return "" - - ## Balance braces. - #unclosed_left_braces = 0 - #unclosed_right_braces = 0 - #for index in range(*shrinked_span): - # if index in left_brace_indices: - # unclosed_left_braces += 1 - # elif index in right_brace_indices: - # if unclosed_left_braces == 0: - # unclosed_right_braces += 1 - # else: - # unclosed_left_braces -= 1 - ##adjusted_span, unclosed_left_braces, unclosed_right_braces \ - ## = self.adjust_span(span, align_level=False) - #return "".join([ - # unclosed_right_braces * "{", - # self.get_substr(shrinked_span), - # unclosed_left_braces * "}" - #]) - # Method alias def get_parts_by_tex(self, selector: Selector) -> VGroup: diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index fc4ff56f..c46c955f 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -240,14 +240,6 @@ class MarkupText(LabelledString): f"{validate_error}" ) - #def parse(self) -> None: - # #self.global_attr_dict = self.get_global_attr_dict() - # #self.items_from_markup = self.get_items_from_markup() - # #self.tag_spans = self.get_tag_spans() - # ##self.items_from_markup = self.get_items_from_markup() - # #self.specified_items = self.get_specified_items() - # super().parse() - #@property #def sort_labelled_submobs(self) -> bool: # return True @@ -300,176 +292,6 @@ class MarkupText(LabelledString): # Parsing - #def parse(self) -> None: - # self.bracket_content_spans, self.command_repl_items \ - # = self.get_items_from_markup() - # #self.bracket_content_spans = [ - # # span for span, _ in items_from_markup - # #] - # #specified_items = self.get_specified_items() - # #self.command_repl_items = self.get_command_repl_items() - # #self.specified_spans = self.remove_redundancies([ - # # span for span, _ in specified_items - # #]) - # #self.label_span_list = self.get_label_span_list() - # #self.predefined_items = [ - # # (self.full_span, self.get_global_attr_dict()), - # # (self.full_span, self.global_config), - # # *specified_items - # #] - - #def parse(self) -> None: # TODO: type - # if not self.is_markup: - # return [], [], [ - # (span, (escaped, escaped)) - # for char, escaped in XML_ENTITIES - # for span in self.find_spans(re.escape(char)) - # ] - - #self.entity_spans = self.find_spans(r"&[\s\S]*?;") - - #tag_spans = [span for span, _ in command_repl_items] - #begin_tag_spans = [ - # begin_tag_span for begin_tag_span, _, _ in markup_tag_items - #] - #end_tag_spans = [ - # end_tag_span for _, end_tag_span, _ in markup_tag_items - #] - #tag_spans = self.chain(begin_tag_spans, end_tag_spans) - #command_repl_items = [ - # (tag_span, "") for tag_span in tag_spans - #] - #self.chain( - # [ - # (begin_tag_span, ( - # f"", - # f"" - # )) - # for begin_tag_span, _, attr_dict in markup_tag_items - # ], - # [ - # (end_tag_span, ("", "")) - # for _, end_tag_span, _ in markup_tag_items - # ] - #) - #self.piece_spans, self.piece_levels = self.init_piece_items( - # begin_tag_spans, end_tag_spans, self.find_spans(r"&[\s\S]*?;") - #) - #command_repl_items.extend([ - # (span, (self.get_substr(span), self.get_substr(span))) - # for span in self.find_spans(r"&[\s\S]*?;") - #]) - # Needed in plain text - - #specified_items = self.chain( - # [ - # ((span_begin, span_end), attr_dict) - # for (_, span_begin), (span_end, _), attr_dict - # in markup_tag_items - # ], - # self.get_specified_items() - #) - #specified_spans = self.remove_redundancies([ - # span for span, _ in specified_items - #]) - #specified_items = [] - #for span, attr_dict in all_specified_items: - # for - # adjusted_span, _, _ = self.adjust_span(span, align_level=True) - # if adjusted_span[0] > adjusted_span[1]: - # continue - # specified_items.append(adjusted_span, attr_dict) - - - #predefined_items = [ - # (self.full_span, self.get_global_attr_dict()), - # (self.full_span, self.global_config), - # *split_items - #] - #inserted_str_items = self.chain( - # [ - # (span, ( - # ( - # f"", - # f"" - # ), - # ("", "") - # )) - # for span, attr_dict in predefined_items - # ], - # [ - # (span, ( - # ("", f""), - # ("", ""), - # )) - # for label, span in enumerate(label_span_list) - # ] - #) - #command_repl_items = self.chain( - # [ - # (tag_span, ("", "")) for tag_span in self.tag_spans - # ], - # [ - # ((index, index), str_pair) - # for index, str_pair in self.sort_obj_pairs_by_spans(inserted_str_items) - # ] - #) - #decorated_strings = [ - # self.get_replaced_substr(self.full_span, [ - # (span, str_pair[flag]) - # for span, str_pair in command_repl_items - # ]) - # for flag in range(2) - #] - #return specified_spans, label_span_list, decorated_strings - - - - - - #if is_labelled: - # attr_dict_items = self.chain( - # [ - # (span, { - # key: - # "black" if key.lower() in MARKUP_COLOR_KEYS else val - # for key, val in attr_dict.items() - # }) - # for span, attr_dict in self.predefined_items - # ], - # [ - # (span, {"foreground": self.int_to_hex(label + 1)}) - # for label, span in enumerate(self.label_span_list) - # ] - # ) - #else: - # attr_dict_items = self.chain( - # self.predefined_items, - # [ - # (span, {}) - # for span in self.label_span_list - # ] - # ) - #return [ - # (span, ( - # f"", - # "" - # )) - # for span, attr_dict in attr_dict_items - #] - #inserted_string_pairs = [ - # (indices, str_pair) - # for indices, str_pair in self.get_inserted_string_pairs( - # is_labelled=is_labelled - # ) - # if not any( - # cmd_begin < index < cmd_end - # for index in indices - # for (cmd_begin, cmd_end), _ in self.command_repl_items - # ) - #] - #return bracket_content_spans, label_span_list, command_repl_items - def get_command_spans(self) -> tuple[list[Span], list[Span], list[Span]]: begin_cmd_spans = self.find_spans( r"<\w+\s*(?:\w+\s*\=\s*(['\x22])[\s\S]*?\1\s*)*>" @@ -481,80 +303,6 @@ class MarkupText(LabelledString): cmd_spans = self.find_spans(r"&[\s\S]*?;") # TODO return begin_cmd_spans, end_cmd_spans, cmd_spans - #def get_entity_spans(self) -> list[Span]: - # if not self.is_markup: - # return [] - # return self.find_spans(r"&[\s\S]*?;") - - #def get_internal_items( - # self - #) -> tuple[list[tuple[Span, Span]], list[tuple[Span, dict[str, str]]]]: - # if not self.is_markup: - # return [], [] - - # tag_pattern = r"<(/?)(\w+)\s*((\w+\s*\=\s*(['\x22])[\s\S]*?\5\s*)*)>" - # attr_pattern = r"(\w+)\s*\=\s*(['\x22])([\s\S]*?)\2" - # begin_match_obj_stack = [] - # markup_tag_items = [] - # for match_obj in re.finditer(tag_pattern, self.string): - # if not match_obj.group(1): - # begin_match_obj_stack.append(match_obj) - # continue - # begin_match_obj = begin_match_obj_stack.pop() - # tag_name = begin_match_obj.group(2) - # if tag_name == "span": - # attr_dict = { - # attr_match_obj.group(1): attr_match_obj.group(3) - # for attr_match_obj in re.finditer( - # attr_pattern, begin_match_obj.group(3) - # ) - # } - # else: - # attr_dict = MARKUP_TAG_CONVERSION_DICT.get(tag_name, {}) - # markup_tag_items.append( - # (begin_match_obj.span(), match_obj.span(), attr_dict) - # ) - - # tag_span_pairs = [ - # (tag_begin_span, tag_end_span) - # for tag_begin_span, tag_end_span, _ in markup_tag_items - # ] - # internal_items = [ - # ((span_begin, span_end), attr_dict) - # for (_, span_begin), (span_end, _), attr_dict in markup_tag_items - # ] - # return tag_span_pairs, internal_items - - #def get_external_items(self) -> list[tuple[Span, dict[str, str]]]: - # return [ - # (self.full_span, self.get_global_attr_dict()), - # (self.full_span, self.global_config), - # *[ - # (span, {key: val}) - # for t2x_dict, key in ( - # (self.t2c, "foreground"), - # (self.t2f, "font_family"), - # (self.t2s, "font_style"), - # (self.t2w, "font_weight") - # ) - # for selector, val in t2x_dict.items() - # for span in self.find_spans_by_selector(selector) - # ], - # *[ - # (span, local_config) - # for selector, local_config in self.local_configs.items() - # for span in self.find_spans_by_selector(selector) - # ] - # ] - #if self.split_words: - # # For backward compatibility - # result.extend([ - # (span, {}) - # for pattern in (r"[a-zA-Z]+", r"\S+") - # for span in self.find_spans(pattern) - # ]) - #return result - def get_specified_items( self, cmd_span_pairs: list[tuple[Span, Span]] ) -> list[tuple[Span, dict[str, str]]]: @@ -596,233 +344,6 @@ class MarkupText(LabelledString): ] ] - - - #def get_label_span_list(self, split_spans: list[Span]) -> list[Span]: - - #def get_spans_from_items( - # self, specified_items: list[tuple[Span, dict[str, str]]] - #) -> list[Span]: - # return [span for span, _ in specified_items] - - #def get_split_items( - # self, specified_items: list[tuple[Span, dict[str, str]]] - #) -> list[tuple[Span, dict[str, str]]]: - # return [ - # (span, attr_dict) - # for specified_span, attr_dict in specified_items - # for span in self.split_span(specified_span) - # ] - - #def get_label_span_list(self, split_spans: list[Span]) -> list[Span]: - # interval_spans = sorted(self.chain( - # self.tag_spans, - # [ - # (index, index) - # for span in split_spans - # for index in span - # ] - # )) - # text_spans = self.get_complement_spans(self.full_span, interval_spans) - # if self.is_markup: - # pattern = r"[0-9a-zA-Z]+|(?:&[\s\S]*?;|[^0-9a-zA-Z\s])+" - # else: - # pattern = r"[0-9a-zA-Z]+|[^0-9a-zA-Z\s]+" - # return self.chain(*[ - # self.find_spans(pattern, pos=span_begin, endpos=span_end) - # for span_begin, span_end in text_spans - # ]) - - #def get_additional_inserted_str_pairs( - # self - #) -> list[tuple[Span, tuple[str, str]]]: - # return [] - - #def get_command_repl_items(self) -> list[Span, str]: - # result = [ - # (tag_span, "") for tag_span in self.tag_spans # TODO - # ] - # if not self.is_markup: - # result.extend([ - # (span, escaped) - # for char, escaped in XML_ENTITIES - # for span in self.find_spans(re.escape(char)) - # ]) - # return result - - #def get_predefined_inserted_str_items( - # self, split_items: list[tuple[Span, dict[str, str]]] - #) -> list[tuple[Span, tuple[str, str]]]: - # predefined_items = [ - # (self.full_span, self.get_global_attr_dict()), - # (self.full_span, self.global_config), - # *split_items - # ] - # return [ - # (span, ( - # ( - # self.get_tag_str(attr_dict, escape_color_keys=False, is_begin_tag=True), - # self.get_tag_str(attr_dict, escape_color_keys=True, is_begin_tag=True) - # ), - # ( - # self.get_tag_str(attr_dict, escape_color_keys=False, is_begin_tag=False), - # self.get_tag_str(attr_dict, escape_color_keys=True, is_begin_tag=False) - # ) - # )) - # for span, attr_dict in predefined_items - # ] - - #def get_full_content_string(self, replaced_string: str) -> str: - # return replaced_string - - #def get_tag_spans(self) -> list[Span]: - # return self.chain( - # (begin_tag_span, end_tag_span) - # for begin_tag_span, end_tag_span, _ in self.items_from_markup - # ) - - #def get_items_from_markup(self) -> list[tuple[Span, dict[str, str]]]: - # return [ - # ((span_begin, span_end), attr_dict) - # for (_, span_begin), (span_end, _), attr_dict - # in self.items_from_markup - # if span_begin < span_end - # ] - - #def get_command_repl_items(self) -> list[tuple[Span, str]]: - # result = [ - # (tag_span, "") - # for tag_span in self.tag_spans - # ] - # if self.is_markup: - # result.extend([ - # (span, self.get_substr(span)) - # for span in self.find_spans(r"&[\s\S]*?;") - # ]) - # else: - # result.extend([ - # (span, escaped) - # for char, escaped in ( - # ("&", "&"), - # (">", ">"), - # ("<", "<") - # ) - # for span in self.find_spans(re.escape(char)) - # ]) - # return result - - #def get_command_spans(self) -> list[Span]: - # result = self.tag_spans.copy() - # if self.is_markup: - # result.extend(self.find_spans(r"&[\s\S]*?;")) - # else: - # result.extend(self.find_spans(r"[&<>]")) - # return result - - #@staticmethod - #def get_command_repl_dict() -> dict[str | re.Pattern, str]: - # return { - # re.compile(r"<.*>"): "", - # "&": "&", - # "<": "<", - # ">": ">" - # } - # #result = [ - # # (tag_span, "") for tag_span in self.tag_spans - # #] - # #if self.is_markup: - # # result.extend([ - # # (span, self.get_substr(span)) - # # for span in self.find_spans(r"&[\s\S]*?;") - # # ]) - # #else: - # # result.extend([ - # # (span, escaped) - # # for char, escaped in ( - # # ("&", "&"), - # # (">", ">"), - # # ("<", "<") - # # ) - # # for span in self.find_spans(re.escape(char)) - # # ]) - # #return result - #entity_spans = self.tag_spans.copy() - #if self.is_markup: - # entity_spans.extend(self.find_spans(r"&[\s\S]*?;")) - #return [ - # (span, attr_dict) - # for span, attr_dict in result - # if not self.span_cuts_at_entity(span) - # #if not any([ - # # entity_begin < index < entity_end - # # for index in span - # # for entity_begin, entity_end in entity_spans - # #]) - #] - - #def get_specified_spans(self) -> list[Span]: - # return self.remove_redundancies([ - # span for span, _ in self.specified_items - # ]) - - #def get_label_span_list(self) -> list[Span]: - # interval_spans = sorted(self.chain( - # self.tag_spans, - # [ - # (index, index) - # for span in self.specified_spans - # for index in span - # ] - # )) - # text_spans = self.get_complement_spans(interval_spans, self.full_span) - # if self.is_markup: - # pattern = r"[0-9a-zA-Z]+|(?:&[\s\S]*?;|[^0-9a-zA-Z\s])+" - # else: - # pattern = r"[0-9a-zA-Z]+|[^0-9a-zA-Z\s]+" - # return self.chain(*[ - # self.find_spans(pattern, pos=span_begin, endpos=span_end) - # for span_begin, span_end in text_spans - # ]) - - #def get_inserted_string_pairs( - # self, is_labelled: bool - #) -> list[tuple[Span, tuple[str, str]]]: - # #predefined_items = [ - # # (self.full_span, self.global_attr_dict), - # # (self.full_span, self.global_config), - # # *self.specified_items - # #] - # if is_labelled: - # attr_dict_items = self.chain( - # [ - # (span, { - # key: - # "black" if key.lower() in MARKUP_COLOR_KEYS else val - # for key, val in attr_dict.items() - # }) - # for span, attr_dict in self.predefined_items - # ], - # [ - # (span, {"foreground": self.int_to_hex(label + 1)}) - # for label, span in enumerate(self.label_span_list) - # ] - # ) - # else: - # attr_dict_items = self.chain( - # self.predefined_items, - # [ - # (span, {}) - # for span in self.label_span_list - # ] - # ) - # return [ - # (span, ( - # f"", - # "" - # )) - # for span, attr_dict in attr_dict_items - # ] - def get_replaced_substr(self, substr: str, flag: int) -> str: if flag: return "" @@ -847,13 +368,7 @@ class MarkupText(LabelledString): replaced_substr = entity_to_char_dict[replaced_substr] filtered_repl_items.append((cmd_span, replaced_substr)) - return self.replace_string(span, filtered_repl_items).strip() # TODO: test - #repl_items = [ - # (cmd_span, repl_str) - # for cmd_span, (repl_str, _) in self.command_repl_items - # if self.span_contains(span, cmd_span) - #] - #return self.get_replaced_substr(span, repl_items).strip() + return self.replace_string(span, filtered_repl_items).strip() # TODO # Method alias From 642602155db803871bd813d2aec3ad432d6ee68f Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Thu, 5 May 2022 23:03:02 +0800 Subject: [PATCH 08/11] [WIP] Refactor LabelledString and relevant classes --- manimlib/mobject/svg/labelled_string.py | 308 ++++++++++++------------ manimlib/mobject/svg/mtex_mobject.py | 149 +++++------- manimlib/mobject/svg/text_mobject.py | 200 ++++++++------- 3 files changed, 325 insertions(+), 332 deletions(-) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index 03b5da8b..a1a693c2 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -11,7 +11,6 @@ from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.utils.color import color_to_rgb from manimlib.utils.color import rgb_to_hex from manimlib.utils.config_ops import digest_config -from manimlib.utils.iterables import remove_list_redundancies from typing import TYPE_CHECKING @@ -92,14 +91,14 @@ class LabelledString(SVGMobject, ABC): )) if unrecognized_colors: log.warning( - "Unrecognized color label(s) detected (%s, etc). " + "Unrecognized color labels detected (%s, etc). " "Skip the labelling process.", self.int_to_hex(unrecognized_colors[0]) ) submob_color_ints = [0] * num_submobjects - #TODO: remove this - #if self.sort_labelled_submobs: + # Rearrange colors so that the n-th submobject from the left + # is labelled by the n-th submobject of `labelled_svg` from the left. submob_indices = sorted( range(num_submobjects), key=lambda index: tuple( @@ -122,11 +121,6 @@ class LabelledString(SVGMobject, ABC): for submob, color_int in zip(self.submobjects, submob_color_ints): submob.label = color_int - 1 - #@property - #@abstractmethod - #def sort_labelled_submobs(self) -> bool: - # return False - # Toolkits def get_substr(self, span: Span) -> str: @@ -176,20 +170,8 @@ class LabelledString(SVGMobject, ABC): if spans is None: raise TypeError(f"Invalid selector: '{sel}'") result.extend(spans) - #return sorted(filter( - # lambda span: span[0] < span[1], - # self.remove_redundancies(result) - #)) return result - @staticmethod - def chain(*iterables: Iterable[T]) -> list[T]: - return list(it.chain(*iterables)) - - @staticmethod - def remove_redundancies(vals: Sequence[T]) -> list[T]: - return remove_list_redundancies(vals) - @staticmethod def get_neighbouring_pairs(vals: Sequence[T]) -> list[tuple[T, T]]: return list(zip(vals[:-1], vals[1:])) @@ -254,7 +236,7 @@ class LabelledString(SVGMobject, ABC): for piece_span in self.get_complement_spans(span, repl_spans) ] repl_strs = [*repl_strs, ""] - return "".join(self.chain(*zip(pieces, repl_strs))) + return "".join(it.chain(*zip(pieces, repl_strs))) @staticmethod def color_to_hex(color: ManimColor) -> str: @@ -268,26 +250,58 @@ class LabelledString(SVGMobject, ABC): def int_to_hex(rgb_int: int) -> str: return f"#{rgb_int:06x}".upper() - @staticmethod - @abstractmethod - def get_tag_string_pair( - attr_dict: dict[str, str], label_hex: str | None - ) -> tuple[str, str]: - return ("", "") - # Parsing def parse(self) -> None: - begin_cmd_spans, end_cmd_spans, cmd_spans = self.get_command_spans() - - cmd_span_items = sorted(self.chain( + begin_cmd_spans, end_cmd_spans, other_cmd_spans = self.get_cmd_spans() + cmd_span_items = sorted(it.chain( [(begin_cmd_span, 1) for begin_cmd_span in begin_cmd_spans], [(end_cmd_span, -1) for end_cmd_span in end_cmd_spans], - [(cmd_span, 0) for cmd_span in cmd_spans], + [(cmd_span, 0) for cmd_span in other_cmd_spans], ), key=lambda t: t[0]) - self.cmd_span_items = cmd_span_items + cmd_spans = [span for span, _ in cmd_span_items] + flags = [flag for _, flag in cmd_span_items] - cmd_span_pairs = [] + specified_items = self.get_specified_items( + self.get_cmd_span_pairs(cmd_span_items) + ) + split_items = [ + (span, attr_dict) + for specified_span, attr_dict in specified_items + for span in self.split_span_by_levels( + specified_span, cmd_spans, flags + ) + ] + + self.specified_spans = [span for span, _ in specified_items] + self.labelled_spans = [span for span, _ in split_items] + self.check_overlapping() + + cmd_repl_items_for_content = [ + (span, self.get_repl_substr_for_content(self.get_substr(span))) + for span in cmd_spans + ] + self.cmd_repl_items_for_matching = [ + (span, self.get_repl_substr_for_matching(self.get_substr(span))) + for span in cmd_spans + ] + + self.original_content = self.get_content( + cmd_repl_items_for_content, split_items, is_labelled=False + ) + self.labelled_content = self.get_content( + cmd_repl_items_for_content, split_items, is_labelled=True + ) + + @abstractmethod + def get_cmd_spans(self) -> tuple[list[Span], list[Span], list[Span]]: + return [], [], [] + + @staticmethod + def get_cmd_span_pairs( + cmd_span_items: list[tuple[Span, int]] + ) -> list[tuple[Span, Span]]: + result = [] begin_cmd_spans_stack = [] for cmd_span, flag in cmd_span_items: if flag == 1: @@ -296,25 +310,56 @@ class LabelledString(SVGMobject, ABC): if not begin_cmd_spans_stack: raise ValueError("Missing '{' inserted") begin_cmd_span = begin_cmd_spans_stack.pop() - cmd_span_pairs.append((begin_cmd_span, cmd_span)) + result.append((begin_cmd_span, cmd_span)) if begin_cmd_spans_stack: raise ValueError("Missing '}' inserted") + return result - specified_items = self.get_specified_items(cmd_span_pairs) - split_items = [ - (span, attr_dict) - for specified_span, attr_dict in specified_items - for span in self.split_span_by_levels(specified_span)[0] - ] + @abstractmethod + def get_specified_items( + self, cmd_span_pairs: list[tuple[Span, Span]] + ) -> list[tuple[Span, dict[str, str]]]: + return [] - command_repl_items = [ - (span, self.get_replaced_substr(self.get_substr(span), flag)) - for span, flag in cmd_span_items - ] - self.command_repl_items = command_repl_items + def split_span_by_levels( + self, arbitrary_span: Span, cmd_spans: list[Span], flags: list[int] + ) -> list[Span]: + cmd_range = ( + sum([ + arbitrary_span[0] > interval_begin + for interval_begin, _ in cmd_spans + ]), + sum([ + arbitrary_span[1] >= interval_end + for _, interval_end in cmd_spans + ]) + ) + complement_spans = self.get_complement_spans( + self.full_span, cmd_spans + ) + adjusted_span = ( + max(arbitrary_span[0], complement_spans[cmd_range[0]][0]), + min(arbitrary_span[1], complement_spans[cmd_range[1]][1]) + ) + if adjusted_span[0] > adjusted_span[1]: + return [] - self.specified_spans = [span for span, _ in specified_items] - labelled_spans = [span for span, _ in split_items] + upward_cmd_spans = [] + downward_cmd_spans = [] + for cmd_span, flag in list(zip(cmd_spans, flags))[slice(*cmd_range)]: + if flag == 1: + upward_cmd_spans.append(cmd_span) + elif flag == -1: + if upward_cmd_spans: + upward_cmd_spans.pop() + else: + downward_cmd_spans.append(cmd_span) + return self.get_complement_spans( + adjusted_span, downward_cmd_spans + upward_cmd_spans + ) + + def check_overlapping(self) -> None: + labelled_spans = self.labelled_spans if len(labelled_spans) >= 16777216: raise ValueError("Cannot handle that many substrings") for span_0, span_1 in it.product(labelled_spans, repeat=2): @@ -324,92 +369,73 @@ class LabelledString(SVGMobject, ABC): "Partially overlapping substrings detected: " f"'{self.get_substr(span_0)}' and '{self.get_substr(span_1)}'" ) - self.labelled_spans = labelled_spans - self.original_content, self.labelled_content = ( - self.get_full_content_string(self.replace_string( - self.full_span, self.chain( - command_repl_items, - [ - ((index, index), inserted_str) - for index, inserted_str in self.sort_obj_pairs_by_spans([ - (span, self.get_tag_string_pair( - attr_dict, - label_hex=self.int_to_hex(label + 1) if is_labelled else None - )) - for label, (span, attr_dict) in enumerate(split_items) - ]) - ] - ) - ), is_labelled=is_labelled) - for is_labelled in (False, True) - ) + @abstractmethod + def get_repl_substr_for_content(self, substr: str) -> str: + return "" - def split_span_by_levels( - self, arbitrary_span: Span - ) -> tuple[list[Span], int, int]: - interval_span_items = self.cmd_span_items - interval_spans = [span for span, _ in interval_span_items] - interval_range = ( - sum([ - arbitrary_span[0] > interval_begin - for interval_begin, _ in interval_spans - ]), - sum([ - arbitrary_span[1] >= interval_end - for _, interval_end in interval_spans - ]) - ) - complement_spans = self.get_complement_spans(self.full_span, interval_spans) - adjusted_span = ( - max(arbitrary_span[0], complement_spans[interval_range[0]][0]), - min(arbitrary_span[1], complement_spans[interval_range[1]][1]) - ) - if adjusted_span[0] > adjusted_span[1]: - return [], 0, 0 + @abstractmethod + def get_repl_substr_for_matching(self, substr: str) -> str: + return "" - upwards_stack = [] - downwards_stack = [] - for interval_index in range(*interval_range): - _, level_shift = interval_span_items[interval_index] - if level_shift == 1: - upwards_stack.append(interval_index) - elif level_shift == -1: - if upwards_stack: - upwards_stack.pop() - else: - downwards_stack.append(interval_index) + @staticmethod + @abstractmethod + def get_cmd_str_pair( + attr_dict: dict[str, str], label_hex: str | None + ) -> tuple[str, str]: + return "", "" - covered_interval_spans = [ - interval_spans[piece_index] - for piece_index in self.chain(downwards_stack, upwards_stack) + @abstractmethod + def get_content_prefix_and_suffix( + self, is_labelled: bool + ) -> tuple[str, str]: + return "", "" + + def get_content( + self, cmd_repl_items_for_content: list[Span, str], + split_items: list[tuple[Span, dict[str, str]]], is_labelled: bool + ) -> str: + inserted_str_pairs = [ + (span, self.get_cmd_str_pair( + attr_dict, + label_hex=self.int_to_hex(label + 1) if is_labelled else None + )) + for label, (span, attr_dict) in enumerate(split_items) ] - result = self.get_complement_spans(adjusted_span, covered_interval_spans) - return result, len(downwards_stack), len(upwards_stack) - - @abstractmethod - def get_command_spans(self) -> tuple[list[Span], list[Span], list[Span]]: - return [], [], [] - - @abstractmethod - def get_specified_items( - self, cmd_span_pairs: list[tuple[Span, Span]] - ) -> list[tuple[Span, dict[str, str]]]: - return [] - - @abstractmethod - def get_replaced_substr(self, substr: str, flag: int) -> str: - return "" - - @abstractmethod - def get_full_content_string(self, content_string: str, is_labelled: bool) -> str: - return "" + repl_items = cmd_repl_items_for_content + [ + ((index, index), inserted_str) + for index, inserted_str in self.sort_obj_pairs_by_spans( + inserted_str_pairs + ) + ] + prefix, suffix = self.get_content_prefix_and_suffix(is_labelled) + return "".join([ + prefix, + self.replace_string(self.full_span, repl_items), + suffix + ]) # Selector - @abstractmethod - def get_cleaned_substr(self, span: Span) -> str: - return "" + def get_submob_indices_list_by_span( + self, arbitrary_span: Span + ) -> list[int]: + return [ + submob_index + for submob_index, label in enumerate(self.labels) + if label != -1 and self.span_contains( + arbitrary_span, self.labelled_spans[label] + ) + ] + + def get_specified_part_items(self) -> list[tuple[str, list[int]]]: + return [ + ( + self.get_substr(span), + self.get_submob_indices_list_by_span(span) + ) + for span in self.specified_spans + ] def get_group_part_items(self) -> list[tuple[str, list[int]]]: if not self.labels: @@ -436,7 +462,13 @@ class LabelledString(SVGMobject, ABC): ) ] group_substrs = [ - self.get_cleaned_substr(span) if span[0] < span[1] else "" + re.sub(r"\s+", "", self.replace_string( + span, [ + (cmd_span, repl_str) + for cmd_span, repl_str in self.cmd_repl_items_for_matching + if self.span_contains(span, cmd_span) + ] + )) for span in self.get_complement_spans( (ordered_spans[0][0], ordered_spans[-1][1]), interval_spans ) @@ -447,26 +479,6 @@ class LabelledString(SVGMobject, ABC): ] return list(zip(group_substrs, submob_indices_lists)) - def get_submob_indices_list_by_span( - self, arbitrary_span: Span - ) -> list[int]: - return [ - submob_index - for submob_index, label in enumerate(self.labels) - if label != -1 and self.span_contains( - arbitrary_span, self.labelled_spans[label] - ) - ] - - def get_specified_part_items(self) -> list[tuple[str, list[int]]]: - return [ - ( - self.get_substr(span), - self.get_submob_indices_list_by_span(span) - ) - for span in self.specified_spans - ] - def get_submob_indices_lists_by_selector( self, selector: Selector ) -> list[list[int]]: diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 823e4f81..f90d354b 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -73,42 +73,21 @@ class MTex(LabelledString): file_path = tex_to_svg_file(full_tex) return file_path - #@property - #def sort_labelled_submobs(self) -> bool: - # return False - - # Toolkits - - @staticmethod - def get_color_command_str(rgb_hex: str) -> str: - rgb = MTex.hex_to_int(rgb_hex) - rg, b = divmod(rgb, 256) - r, g = divmod(rg, 256) - return f"\\color[RGB]{{{r}, {g}, {b}}}" - - @staticmethod - def get_tag_string_pair( - attr_dict: dict[str, str], label_hex: str | None - ) -> tuple[str, str]: - if label_hex is None: - return ("", "") - return ("{{" + MTex.get_color_command_str(label_hex), "}}") - # Parsing - def get_command_spans(self) -> tuple[list[Span], list[Span], list[Span]]: - cmd_spans = self.find_spans(r"\\(?:[a-zA-Z]+|\s|\S)") - begin_cmd_spans = [ - span - for span in self.find_spans("{") - if (span[0] - 1, span[1]) not in cmd_spans - ] - end_cmd_spans = [ - span - for span in self.find_spans("}") - if (span[0] - 1, span[1]) not in cmd_spans - ] - return begin_cmd_spans, end_cmd_spans, cmd_spans + def get_cmd_spans(self) -> tuple[list[Span], list[Span], list[Span]]: + backslash_spans = self.find_spans(r"\\(?:[a-zA-Z]+|\s|\S)") + def find_unescaped_spans(pattern): + return list(filter( + lambda span: (span[0] - 1, span[1]) not in backslash_spans, + self.find_spans(pattern) + )) + + return ( + find_unescaped_spans(r"{"), + find_unescaped_spans(r"}"), + backslash_spans + find_unescaped_spans(r"[_^]") + ) def get_specified_items( self, cmd_span_pairs: list[tuple[Span, Span]] @@ -117,8 +96,8 @@ class MTex(LabelledString): (span_begin, span_end) for (_, span_begin), (span_end, _) in cmd_span_pairs ] - specified_spans = self.chain( - [ + specified_spans = [ + *[ cmd_content_spans[range_begin] for _, (range_begin, range_end) in self.compress_neighbours([ (span_begin + index, span_end - index) @@ -128,77 +107,57 @@ class MTex(LabelledString): ]) if range_end - range_begin >= 2 ], - [ + *[ span for selector in self.tex_to_color_map for span in self.find_spans_by_selector(selector) ], - self.find_spans_by_selector(self.isolate) - ) + *self.find_spans_by_selector(self.isolate) + ] return [(span, {}) for span in specified_spans] - def get_replaced_substr(self, substr: str, flag: int) -> str: + def get_repl_substr_for_content(self, substr: str) -> str: return substr - def get_full_content_string(self, content_string: str, is_labelled: bool) -> str: - result = content_string + def get_repl_substr_for_matching(self, substr: str) -> str: + return substr if substr.startswith("\\") else "" + @staticmethod + def get_color_cmd_str(rgb_hex: str) -> str: + rgb = MTex.hex_to_int(rgb_hex) + rg, b = divmod(rgb, 256) + r, g = divmod(rg, 256) + return f"\\color[RGB]{{{r}, {g}, {b}}}" + + @staticmethod + def get_cmd_str_pair( + attr_dict: dict[str, str], label_hex: str | None + ) -> tuple[str, str]: + if label_hex is None: + return "", "" + return "{{" + MTex.get_color_cmd_str(label_hex), "}}" + + def get_content_prefix_and_suffix( + self, is_labelled: bool + ) -> tuple[str, str]: + prefix_lines = [] + suffix_lines = [] + if not is_labelled: + prefix_lines.append(self.get_color_cmd_str(self.base_color_hex)) + if self.alignment: + prefix_lines.append(self.alignment) if self.tex_environment: if isinstance(self.tex_environment, str): - prefix = f"\\begin{{{self.tex_environment}}}" - suffix = f"\\end{{{self.tex_environment}}}" + env_prefix = f"\\begin{{{self.tex_environment}}}" + env_suffix = f"\\end{{{self.tex_environment}}}" else: - prefix, suffix = self.tex_environment - result = "\n".join([prefix, result, suffix]) - if self.alignment: - result = "\n".join([self.alignment, result]) - - if not is_labelled: - result = "\n".join([ - self.get_color_command_str(self.base_color_hex), - result - ]) - return result - - # Selector - - def get_cleaned_substr(self, span: Span) -> str: - backslash_indices = [ - index for index, _ in self.find_spans(r"\\[\s\S]") - ] - ignored_indices = [ - index - for index, _ in self.find_spans(r"[\s_^{}]") - if index - 1 not in backslash_indices - ] - span_begin, span_end = span - while span_begin in ignored_indices: - span_begin += 1 - while span_end - 1 in ignored_indices: - span_end -= 1 - shrinked_span = (span_begin, span_end) - - whitespace_repl_items = [] - for whitespace_span in self.find_spans(r"\s+"): - if not self.span_contains(shrinked_span, whitespace_span): - continue - if whitespace_span[0] - 1 in backslash_indices: - whitespace_span = (whitespace_span[0] + 1, whitespace_span[1]) - if all( - self.get_substr((index, index + 1)).isalpha() - for index in (whitespace_span[0] - 1, whitespace_span[1]) - ): - replaced_substr = " " - else: - replaced_substr = "" - whitespace_repl_items.append((whitespace_span, replaced_substr)) - - _, unclosed_right_braces, unclosed_left_braces = self.split_span_by_levels(shrinked_span) - return "".join([ - unclosed_right_braces * "{", - self.replace_string(shrinked_span, whitespace_repl_items), - unclosed_left_braces * "}" - ]) + env_prefix, env_suffix = self.tex_environment + prefix_lines.append(env_prefix) + suffix_lines.append(env_suffix) + return ( + "".join([line + "\n" for line in prefix_lines]), + "".join(["\n" + line for line in suffix_lines]) + ) # Method alias diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index c46c955f..e61392bf 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -71,15 +71,8 @@ MARKUP_TAG_CONVERSION_DICT = { "tt": {"font_family": "monospace"}, "u": {"underline": "single"}, } -# See https://gitlab.gnome.org/GNOME/glib/-/blob/main/glib/gmarkup.c -# Line 629, 2204 -XML_ENTITIES = ( - ("<", "<"), - (">", ">"), - ("&", "&"), - ("\"", """), - ("'", "'") -) +XML_ENTITIES = ("<", ">", "&", """, "'") +XML_ENTITY_CHARS = "<>&\"'" # Temporary handler @@ -240,68 +233,50 @@ class MarkupText(LabelledString): f"{validate_error}" ) - #@property - #def sort_labelled_submobs(self) -> bool: - # return True - - # Toolkits - - @staticmethod - def get_tag_string_pair( - attr_dict: dict[str, str], label_hex: str | None - ) -> tuple[str, str]: - if label_hex is not None: - converted_attr_dict = {"foreground": label_hex} - for key, val in attr_dict.items(): - substitute_key = MARKUP_COLOR_KEYS_DICT.get(key.lower(), None) - if substitute_key is None: - converted_attr_dict[key] = val - elif substitute_key: - converted_attr_dict[key] = "black" - #else: - # converted_attr_dict[key] = "black" - else: - converted_attr_dict = attr_dict.copy() - attrs_str = " ".join([ - f"{key}='{val}'" - for key, val in converted_attr_dict.items() - ]) - return (f"", "") - - def get_global_attr_dict(self) -> dict[str, str]: - result = { - "foreground": self.base_color_hex, - "font_family": self.font, - "font_style": self.slant, - "font_weight": self.weight, - "font_size": str(self.font_size * 1024), - } - # `line_height` attribute is supported since Pango 1.50. - pango_version = manimpango.pango_version() - if tuple(map(int, pango_version.split("."))) < (1, 50): - if self.lsh is not None: - log.warning( - "Pango version %s found (< 1.50), " - "unable to set `line_height` attribute", - pango_version - ) - else: - line_spacing_scale = self.lsh or DEFAULT_LINE_SPACING_SCALE - result["line_height"] = str(((line_spacing_scale) + 1) * 0.6) - return result - # Parsing - def get_command_spans(self) -> tuple[list[Span], list[Span], list[Span]]: - begin_cmd_spans = self.find_spans( - r"<\w+\s*(?:\w+\s*\=\s*(['\x22])[\s\S]*?\1\s*)*>" - ) - end_cmd_spans = self.find_spans(r"") + def get_cmd_spans(self) -> tuple[list[Span], list[Span], list[Span]]: if not self.is_markup: - cmd_spans = [] - else: - cmd_spans = self.find_spans(r"&[\s\S]*?;") # TODO - return begin_cmd_spans, end_cmd_spans, cmd_spans + return [], [], self.find_spans(r"[<>&\x22']") + + # See https://gitlab.gnome.org/GNOME/glib/-/blob/main/glib/gmarkup.c + string = self.string + cmd_spans = [] + cmd_pattern = re.compile(r""" + &[\s\S]*?; # entity & character reference + | # tag + |<\?[\s\S]*?\?>|<\?> # instruction + ||", "", "" # See https://gitlab.gnome.org/GNOME/glib/-/blob/main/glib/gmarkup.c - string = self.string - cmd_spans = [] - cmd_pattern = re.compile(r""" - &[\s\S]*?; # entity & character reference - | # tag - |<\?[\s\S]*?\?>|<\?> # instruction - ||