diff --git a/manimlib/__init__.py b/manimlib/__init__.py index 2043738c..67703a8f 100644 --- a/manimlib/__init__.py +++ b/manimlib/__init__.py @@ -38,8 +38,8 @@ from manimlib.mobject.probability import * from manimlib.mobject.shape_matchers import * from manimlib.mobject.svg.brace import * from manimlib.mobject.svg.drawings import * -from manimlib.mobject.svg.labelled_string import * from manimlib.mobject.svg.mtex_mobject import * +from manimlib.mobject.svg.string_mobject import * from manimlib.mobject.svg.svg_mobject import * from manimlib.mobject.svg.tex_mobject import * from manimlib.mobject.svg.text_mobject import * diff --git a/manimlib/animation/creation.py b/manimlib/animation/creation.py index 27460899..86c5ca05 100644 --- a/manimlib/animation/creation.py +++ b/manimlib/animation/creation.py @@ -1,12 +1,12 @@ from __future__ import annotations -import itertools as it -from abc import abstractmethod +from abc import ABC, abstractmethod import numpy as np from manimlib.animation.animation import Animation -from manimlib.mobject.svg.labelled_string import LabelledString +from manimlib.mobject.svg.string_mobject import StringMobject +from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.utils.bezier import integer_interpolate from manimlib.utils.config_ops import digest_config @@ -17,10 +17,10 @@ from manimlib.utils.rate_functions import smooth from typing import TYPE_CHECKING if TYPE_CHECKING: - from manimlib.mobject.mobject import Group + from manimlib.mobject.mobject import Mobject -class ShowPartial(Animation): +class ShowPartial(Animation, ABC): """ Abstract class for ShowCreation and ShowPassingFlash """ @@ -176,7 +176,7 @@ class ShowIncreasingSubsets(Animation): "int_func": np.round, } - def __init__(self, group: Group, **kwargs): + def __init__(self, group: Mobject, **kwargs): self.all_submobs = list(group.submobjects) super().__init__(group, **kwargs) @@ -212,8 +212,8 @@ class AddTextWordByWord(ShowIncreasingSubsets): } def __init__(self, string_mobject, **kwargs): - assert isinstance(string_mobject, LabelledString) - grouped_mobject = string_mobject.submob_groups + assert isinstance(string_mobject, StringMobject) + grouped_mobject = string_mobject.build_groups() digest_config(self, kwargs) if self.run_time is None: self.run_time = self.time_per_word * len(grouped_mobject) diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index dab88005..e82bafaf 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -5,24 +5,24 @@ import itertools as it import numpy as np from manimlib.animation.composition import AnimationGroup -from manimlib.animation.fading import FadeTransformPieces from manimlib.animation.fading import FadeInFromPoint from manimlib.animation.fading import FadeOutToPoint +from manimlib.animation.fading import FadeTransformPieces from manimlib.animation.transform import ReplacementTransform from manimlib.animation.transform import Transform from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Group -from manimlib.mobject.svg.labelled_string import LabelledString +from manimlib.mobject.svg.string_mobject import StringMobject from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.utils.config_ops import digest_config -from manimlib.utils.iterables import remove_list_redundancies from typing import TYPE_CHECKING if TYPE_CHECKING: + from manimlib.mobject.svg.tex_mobject import SingleStringTex + from manimlib.mobject.svg.tex_mobject import Tex from manimlib.scene.scene import Scene - from manimlib.mobject.svg.tex_mobject import Tex, SingleStringTex class TransformMatchingParts(AnimationGroup): @@ -155,92 +155,89 @@ class TransformMatchingTex(TransformMatchingParts): class TransformMatchingStrings(AnimationGroup): CONFIG = { - "key_map": dict(), + "key_map": {}, "transform_mismatches": False, } def __init__(self, - source: LabelledString, - target: LabelledString, + source: StringMobject, + target: StringMobject, **kwargs ): digest_config(self, kwargs) - assert isinstance(source, LabelledString) - assert isinstance(target, LabelledString) + assert isinstance(source, StringMobject) + assert isinstance(target, StringMobject) anims = [] - source_indices = list(range(len(source.labelled_submobjects))) - target_indices = list(range(len(target.labelled_submobjects))) + source_indices = list(range(len(source.labels))) + target_indices = list(range(len(target.labels))) - def get_indices_lists(mobject, parts): - return [ - [ - mobject.labelled_submobjects.index(submob) - for submob in part - ] - for part in parts - ] - - def add_anims_from(anim_class, func, source_args, target_args=None): - if target_args is None: - target_args = source_args.copy() - for source_arg, target_arg in zip(source_args, target_args): - source_parts = func(source, source_arg) - target_parts = func(target, target_arg) - source_indices_lists = list(filter( - lambda indices_list: all([ - index in source_indices - for index in indices_list - ]), get_indices_lists(source, source_parts) - )) - target_indices_lists = list(filter( - lambda indices_list: all([ - index in target_indices - for index in indices_list - ]), get_indices_lists(target, target_parts) - )) - if not source_indices_lists or not target_indices_lists: + def get_filtered_indices_lists(indices_lists, rest_indices): + result = [] + for indices_list in indices_lists: + if not indices_list: continue - anims.append(anim_class(source_parts, target_parts, **kwargs)) - for index in it.chain(*source_indices_lists): - source_indices.remove(index) - for index in it.chain(*target_indices_lists): - target_indices.remove(index) - - def get_common_substrs(substrs_from_source, substrs_from_target): - return sorted([ - substr for substr in substrs_from_source - if substr and substr in substrs_from_target - ], key=len, reverse=True) - - def get_parts_from_keys(mobject, keys): - if isinstance(keys, str): - keys = [keys] - result = VGroup() - for key in keys: - if not isinstance(key, str): - raise TypeError(key) - result.add(*mobject.get_parts_by_string(key)) + if not all(index in rest_indices for index in indices_list): + continue + result.append(indices_list) + for index in indices_list: + rest_indices.remove(index) return result - add_anims_from( - ReplacementTransform, get_parts_from_keys, - self.key_map.keys(), self.key_map.values() + def add_anims(anim_class, indices_lists_pairs): + for source_indices_lists, target_indices_lists in indices_lists_pairs: + source_indices_lists = get_filtered_indices_lists( + source_indices_lists, source_indices + ) + target_indices_lists = get_filtered_indices_lists( + target_indices_lists, target_indices + ) + if not source_indices_lists or not target_indices_lists: + continue + anims.append(anim_class( + source.build_parts_from_indices_lists(source_indices_lists), + target.build_parts_from_indices_lists(target_indices_lists), + **kwargs + )) + + def get_substr_to_indices_lists_map(part_items): + result = {} + for substr, indices_list in part_items: + if substr not in result: + result[substr] = [] + result[substr].append(indices_list) + return result + + def add_anims_from(anim_class, func): + source_substr_map = get_substr_to_indices_lists_map(func(source)) + target_substr_map = get_substr_to_indices_lists_map(func(target)) + common_substrings = sorted([ + s for s in source_substr_map if s and s in target_substr_map + ], key=len, reverse=True) + add_anims( + anim_class, + [ + (source_substr_map[substr], target_substr_map[substr]) + for substr in common_substrings + ] + ) + + add_anims( + ReplacementTransform, + [ + ( + source.get_submob_indices_lists_by_selector(k), + target.get_submob_indices_lists_by_selector(v) + ) + for k, v in self.key_map.items() + ] ) add_anims_from( FadeTransformPieces, - LabelledString.get_parts_by_string, - get_common_substrs( - source.specified_substrs, - target.specified_substrs - ) + StringMobject.get_specified_part_items ) add_anims_from( FadeTransformPieces, - LabelledString.get_parts_by_group_substr, - get_common_substrs( - source.group_substrs, - target.group_substrs - ) + StringMobject.get_group_part_items ) rest_source = VGroup(*[source[index] for index in source_indices]) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py deleted file mode 100644 index f1354f0c..00000000 --- a/manimlib/mobject/svg/labelled_string.py +++ /dev/null @@ -1,543 +0,0 @@ -from __future__ import annotations - -import re -import colour -import itertools as it -from typing import Iterable, Union, Sequence -from abc import ABC, abstractmethod - -from manimlib.constants import BLACK, WHITE -from manimlib.mobject.svg.svg_mobject import SVGMobject -from manimlib.mobject.types.vectorized_mobject import VGroup -from manimlib.utils.color import color_to_int_rgb -from manimlib.utils.color import color_to_rgb -from manimlib.utils.color import rgb_to_hex -from manimlib.utils.config_ops import digest_config -from manimlib.utils.iterables import remove_list_redundancies - - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from manimlib.mobject.types.vectorized_mobject import VMobject - ManimColor = Union[str, colour.Color, Sequence[float]] - Span = tuple[int, int] - - -class _StringSVG(SVGMobject): - CONFIG = { - "height": None, - "stroke_width": 0, - "stroke_color": WHITE, - "path_string_config": { - "should_subdivide_sharp_curves": True, - "should_remove_null_curves": True, - }, - } - - -class LabelledString(_StringSVG, ABC): - """ - An abstract base class for `MTex` and `MarkupText` - """ - CONFIG = { - "base_color": WHITE, - "use_plain_file": False, - "isolate": [], - } - - def __init__(self, string: str, **kwargs): - self.string = string - digest_config(self, kwargs) - - # Convert `base_color` to hex code. - self.base_color = rgb_to_hex(color_to_rgb( - self.base_color \ - or self.svg_default.get("color", None) \ - or self.svg_default.get("fill_color", None) \ - or WHITE - )) - self.svg_default["fill_color"] = BLACK - - self.pre_parse() - self.parse() - super().__init__() - self.post_parse() - - def get_file_path(self) -> str: - return self.get_file_path_(use_plain_file=False) - - def get_file_path_(self, use_plain_file: bool) -> str: - content = self.get_content(use_plain_file) - return self.get_file_path_by_content(content) - - @abstractmethod - def get_file_path_by_content(self, content: str) -> str: - return "" - - def generate_mobject(self) -> None: - super().generate_mobject() - - submob_labels = [ - self.color_to_label(submob.get_fill_color()) - for submob in self.submobjects - ] - if self.use_plain_file or self.has_predefined_local_colors: - file_path = self.get_file_path_(use_plain_file=True) - plain_svg = _StringSVG( - file_path, - svg_default=self.svg_default, - path_string_config=self.path_string_config - ) - self.set_submobjects(plain_svg.submobjects) - else: - self.set_fill(self.base_color) - for submob, label in zip(self.submobjects, submob_labels): - submob.label = label - - def pre_parse(self) -> None: - self.string_len = len(self.string) - self.full_span = (0, self.string_len) - - def parse(self) -> None: - self.command_repl_items = self.get_command_repl_items() - self.command_spans = self.get_command_spans() - self.extra_entity_spans = self.get_extra_entity_spans() - self.entity_spans = self.get_entity_spans() - self.extra_ignored_spans = self.get_extra_ignored_spans() - self.skipped_spans = self.get_skipped_spans() - self.internal_specified_spans = self.get_internal_specified_spans() - self.external_specified_spans = self.get_external_specified_spans() - self.specified_spans = self.get_specified_spans() - self.label_span_list = self.get_label_span_list() - self.check_overlapping() - - def post_parse(self) -> None: - self.labelled_submobject_items = [ - (submob.label, submob) - for submob in self.submobjects - ] - self.labelled_submobjects = self.get_labelled_submobjects() - self.specified_substrs = self.get_specified_substrs() - self.group_items = self.get_group_items() - self.group_substrs = self.get_group_substrs() - self.submob_groups = self.get_submob_groups() - - # Toolkits - - def get_substr(self, span: Span) -> str: - return self.string[slice(*span)] - - def finditer( - self, pattern: str, flags: int = 0, **kwargs - ) -> Iterable[re.Match]: - return re.compile(pattern, flags).finditer(self.string, **kwargs) - - def search( - self, pattern: str, flags: int = 0, **kwargs - ) -> re.Match | None: - return re.compile(pattern, flags).search(self.string, **kwargs) - - def match( - self, pattern: str, flags: int = 0, **kwargs - ) -> re.Match | None: - return re.compile(pattern, flags).match(self.string, **kwargs) - - def find_spans(self, pattern: str, **kwargs) -> list[Span]: - return [ - match_obj.span() - for match_obj in self.finditer(pattern, **kwargs) - ] - - def find_substr(self, substr: str, **kwargs) -> list[Span]: - if not substr: - return [] - return self.find_spans(re.escape(substr), **kwargs) - - def find_substrs(self, substrs: list[str], **kwargs) -> list[Span]: - return list(it.chain(*[ - self.find_substr(substr, **kwargs) - for substr in remove_list_redundancies(substrs) - ])) - - @staticmethod - def get_neighbouring_pairs(iterable: list) -> list[tuple]: - return list(zip(iterable[:-1], iterable[1:])) - - @staticmethod - def span_contains(span_0: Span, span_1: Span) -> bool: - return span_0[0] <= span_1[0] and span_0[1] >= span_1[1] - - @staticmethod - def get_complement_spans( - interval_spans: list[Span], universal_span: Span - ) -> list[Span]: - if not interval_spans: - return [universal_span] - - span_ends, span_begins = zip(*interval_spans) - return list(zip( - (universal_span[0], *span_begins), - (*span_ends, universal_span[1]) - )) - - @staticmethod - def compress_neighbours(vals: list[int]) -> list[tuple[int, Span]]: - if not vals: - return [] - - unique_vals = [vals[0]] - indices = [0] - for index, val in enumerate(vals): - if val == unique_vals[-1]: - continue - unique_vals.append(val) - indices.append(index) - indices.append(len(vals)) - spans = LabelledString.get_neighbouring_pairs(indices) - return list(zip(unique_vals, spans)) - - @staticmethod - def find_region_index(seq: list[int], val: int) -> int: - # Returns an integer in `range(-1, len(seq))` satisfying - # `seq[result] <= val < seq[result + 1]`. - # `seq` should be sorted in ascending order. - if not seq or val < seq[0]: - return -1 - result = len(seq) - 1 - while val < seq[result]: - result -= 1 - return result - - @staticmethod - def take_nearest_value(seq: list[int], val: int, index_shift: int) -> int: - sorted_seq = sorted(seq) - index = LabelledString.find_region_index(sorted_seq, val) - return sorted_seq[index + index_shift] - - @staticmethod - def generate_span_repl_dict( - inserted_string_pairs: list[tuple[Span, tuple[str, str]]], - other_repl_items: list[tuple[Span, str]] - ) -> dict[Span, str]: - result = dict(other_repl_items) - if not inserted_string_pairs: - return result - - indices, _, _, inserted_strings = zip(*sorted([ - ( - span[flag], - -flag, - -span[1 - flag], - str_pair[flag] - ) - for span, str_pair in inserted_string_pairs - for flag in range(2) - ])) - result.update({ - (index, index): "".join(inserted_strings[slice(*item_span)]) - for index, item_span - in LabelledString.compress_neighbours(indices) - }) - return result - - def get_replaced_substr( - self, span: Span, span_repl_dict: dict[Span, str] - ) -> str: - repl_spans = sorted(filter( - lambda repl_span: self.span_contains(span, repl_span), - span_repl_dict.keys() - )) - if not all( - span_0[1] <= span_1[0] - for span_0, span_1 in self.get_neighbouring_pairs(repl_spans) - ): - raise ValueError("Overlapping replacement") - - pieces = [ - self.get_substr(piece_span) - for piece_span in self.get_complement_spans(repl_spans, span) - ] - repl_strs = [span_repl_dict[repl_span] for repl_span in repl_spans] - repl_strs.append("") - return "".join(it.chain(*zip(pieces, repl_strs))) - - @staticmethod - def rslide(index: int, skipped: list[Span]) -> int: - transfer_dict = dict(sorted(skipped)) - while index in transfer_dict.keys(): - index = transfer_dict[index] - return index - - @staticmethod - def lslide(index: int, skipped: list[Span]) -> int: - transfer_dict = dict(sorted([ - skipped_span[::-1] for skipped_span in skipped - ], reverse=True)) - while index in transfer_dict.keys(): - index = transfer_dict[index] - return index - - @staticmethod - def rgb_to_int(rgb_tuple: tuple[int, int, int]) -> int: - r, g, b = rgb_tuple - rg = r * 256 + g - return rg * 256 + b - - @staticmethod - def int_to_rgb(rgb_int: int) -> tuple[int, int, int]: - rg, b = divmod(rgb_int, 256) - r, g = divmod(rg, 256) - return r, g, b - - @staticmethod - def int_to_hex(rgb_int: int) -> str: - return "#{:06x}".format(rgb_int).upper() - - @staticmethod - def hex_to_int(rgb_hex: str) -> int: - return int(rgb_hex[1:], 16) - - @staticmethod - def color_to_label(color: ManimColor) -> int: - rgb_tuple = color_to_int_rgb(color) - rgb = LabelledString.rgb_to_int(rgb_tuple) - return rgb - 1 - - # Parsing - - @abstractmethod - def get_command_repl_items(self) -> list[tuple[Span, str]]: - return [] - - def get_command_spans(self) -> list[Span]: - return [cmd_span for cmd_span, _ in self.command_repl_items] - - @abstractmethod - def get_extra_entity_spans(self) -> list[Span]: - return [] - - def get_entity_spans(self) -> list[Span]: - return list(it.chain( - self.command_spans, - self.extra_entity_spans - )) - - @abstractmethod - def get_extra_ignored_spans(self) -> list[int]: - return [] - - def get_skipped_spans(self) -> list[Span]: - return list(it.chain( - self.find_spans(r"\s"), - self.command_spans, - self.extra_ignored_spans - )) - - def shrink_span(self, span: Span) -> Span: - return ( - self.rslide(span[0], self.skipped_spans), - self.lslide(span[1], self.skipped_spans) - ) - - @abstractmethod - def get_internal_specified_spans(self) -> list[Span]: - return [] - - @abstractmethod - def get_external_specified_spans(self) -> list[Span]: - return [] - - def get_specified_spans(self) -> list[Span]: - spans = list(it.chain( - self.internal_specified_spans, - self.external_specified_spans, - self.find_substrs(self.isolate) - )) - shrinked_spans = list(filter( - lambda span: span[0] < span[1] and not any([ - entity_span[0] < index < entity_span[1] - for index in span - for entity_span in self.entity_spans - ]), - [self.shrink_span(span) for span in spans] - )) - return remove_list_redundancies(shrinked_spans) - - @abstractmethod - def get_label_span_list(self) -> list[Span]: - return [] - - def check_overlapping(self) -> None: - for span_0, span_1 in it.product(self.label_span_list, repeat=2): - if not span_0[0] < span_1[0] < span_0[1] < span_1[1]: - continue - raise ValueError( - "Partially overlapping substrings detected: " - f"'{self.get_substr(span_0)}' and '{self.get_substr(span_1)}'" - ) - - @abstractmethod - def get_content(self, use_plain_file: bool) -> str: - return "" - - @abstractmethod - def has_predefined_local_colors(self) -> bool: - return False - - # Post-parsing - - def get_labelled_submobjects(self) -> list[VMobject]: - return [submob for _, submob in self.labelled_submobject_items] - - def get_cleaned_substr(self, span: Span) -> str: - span_repl_dict = dict.fromkeys(self.command_spans, "") - return self.get_replaced_substr(span, span_repl_dict) - - def get_specified_substrs(self) -> list[str]: - return remove_list_redundancies([ - self.get_cleaned_substr(span) - for span in self.specified_spans - ]) - - def get_group_items(self) -> list[tuple[str, VGroup]]: - if not self.labelled_submobject_items: - return [] - - labels, labelled_submobjects = zip(*self.labelled_submobject_items) - group_labels, labelled_submob_spans = zip( - *self.compress_neighbours(labels) - ) - ordered_spans = [ - self.label_span_list[label] if label != -1 else self.full_span - for label in group_labels - ] - interval_spans = [ - ( - next_span[0] - if self.span_contains(prev_span, next_span) - else prev_span[1], - prev_span[1] - if self.span_contains(next_span, prev_span) - else next_span[0] - ) - for prev_span, next_span in self.get_neighbouring_pairs( - ordered_spans - ) - ] - shrinked_spans = [ - self.shrink_span(span) - for span in self.get_complement_spans( - interval_spans, (ordered_spans[0][0], ordered_spans[-1][1]) - ) - ] - group_substrs = [ - self.get_cleaned_substr(span) if span[0] < span[1] else "" - for span in shrinked_spans - ] - submob_groups = VGroup(*[ - VGroup(*labelled_submobjects[slice(*submob_span)]) - for submob_span in labelled_submob_spans - ]) - return list(zip(group_substrs, submob_groups)) - - def get_group_substrs(self) -> list[str]: - return [group_substr for group_substr, _ in self.group_items] - - def get_submob_groups(self) -> list[VGroup]: - return [submob_group for _, submob_group in self.group_items] - - def get_parts_by_group_substr(self, substr: str) -> VGroup: - return VGroup(*[ - group - for group_substr, group in self.group_items - if group_substr == substr - ]) - - # Selector - - def find_span_components( - self, custom_span: Span, substring: bool = True - ) -> list[Span]: - shrinked_span = self.shrink_span(custom_span) - if shrinked_span[0] >= shrinked_span[1]: - return [] - - if substring: - indices = remove_list_redundancies(list(it.chain( - self.full_span, - *self.label_span_list - ))) - span_begin = self.take_nearest_value( - indices, shrinked_span[0], 0 - ) - span_end = self.take_nearest_value( - indices, shrinked_span[1] - 1, 1 - ) - else: - span_begin, span_end = shrinked_span - - span_choices = sorted(filter( - lambda span: self.span_contains((span_begin, span_end), span), - self.label_span_list - )) - # Choose spans that reach the farthest. - span_choices_dict = dict(span_choices) - - result = [] - while span_begin < span_end: - if span_begin not in span_choices_dict.keys(): - span_begin += 1 - continue - next_begin = span_choices_dict[span_begin] - result.append((span_begin, next_begin)) - span_begin = next_begin - return result - - def get_part_by_custom_span(self, custom_span: Span, **kwargs) -> VGroup: - labels = [ - label for label, span in enumerate(self.label_span_list) - if any([ - self.span_contains(span_component, span) - for span_component in self.find_span_components( - custom_span, **kwargs - ) - ]) - ] - return VGroup(*[ - submob for label, submob in self.labelled_submobject_items - if label in labels - ]) - - def get_parts_by_string( - self, substr: str, - case_sensitive: bool = True, regex: bool = False, **kwargs - ) -> VGroup: - flags = 0 - if not case_sensitive: - flags |= re.I - pattern = substr if regex else re.escape(substr) - return VGroup(*[ - self.get_part_by_custom_span(span, **kwargs) - for span in self.find_spans(pattern, flags=flags) - if span[0] < span[1] - ]) - - def get_part_by_string( - self, substr: str, index: int = 0, **kwargs - ) -> VMobject: - return self.get_parts_by_string(substr, **kwargs)[index] - - def set_color_by_string(self, substr: str, color: ManimColor, **kwargs): - self.get_parts_by_string(substr, **kwargs).set_color(color) - return self - - def set_color_by_string_to_color_map( - self, string_to_color_map: dict[str, ManimColor], **kwargs - ): - for substr, color in string_to_color_map.items(): - self.set_color_by_string(substr, color, **kwargs) - return self - - def get_string(self) -> str: - return self.string diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index fb7922e1..149f313f 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -1,28 +1,37 @@ from __future__ import annotations -import itertools as it -import colour -from typing import Union, Sequence - -from manimlib.mobject.svg.labelled_string import LabelledString -from manimlib.utils.tex_file_writing import tex_to_svg_file -from manimlib.utils.tex_file_writing import get_tex_config +from manimlib.mobject.svg.string_mobject import StringMobject from manimlib.utils.tex_file_writing import display_during_execution - +from manimlib.utils.tex_file_writing import get_tex_config +from manimlib.utils.tex_file_writing import tex_to_svg_file from typing import TYPE_CHECKING if TYPE_CHECKING: - from manimlib.mobject.types.vectorized_mobject import VMobject + from colour import Color + import re + from typing import Iterable, Union + from manimlib.mobject.types.vectorized_mobject import VGroup - ManimColor = Union[str, colour.Color, Sequence[float]] + + ManimColor = Union[str, Color] Span = tuple[int, int] + Selector = Union[ + str, + re.Pattern, + tuple[Union[int, None], Union[int, None]], + Iterable[Union[ + str, + re.Pattern, + tuple[Union[int, None], Union[int, None]] + ]] + ] SCALE_FACTOR_PER_FONT_POINT = 0.001 -class MTex(LabelledString): +class MTex(StringMobject): CONFIG = { "font_size": 48, "alignment": "\\centering", @@ -32,7 +41,7 @@ class MTex(LabelledString): def __init__(self, tex_string: str, **kwargs): # Prevent from passing an empty string. - if not tex_string: + if not tex_string.strip(): tex_string = "\\\\" self.tex_string = tex_string super().__init__(tex_string, **kwargs) @@ -47,7 +56,6 @@ class MTex(LabelledString): self.svg_default, self.path_string_config, self.base_color, - self.use_plain_file, self.isolate, self.tex_string, self.alignment, @@ -61,270 +69,103 @@ class MTex(LabelledString): tex_config["text_to_replace"], content ) - with display_during_execution(f"Writing \"{self.tex_string}\""): + with display_during_execution(f"Writing \"{self.string}\""): file_path = tex_to_svg_file(full_tex) return file_path - def pre_parse(self) -> None: - super().pre_parse() - self.backslash_indices = self.get_backslash_indices() - self.brace_index_pairs = self.get_brace_index_pairs() - self.script_char_spans = self.get_script_char_spans() - self.script_content_spans = self.get_script_content_spans() - self.script_spans = self.get_script_spans() - - # Toolkits - - @staticmethod - def get_color_command_str(rgb_int: int) -> str: - rgb_tuple = MTex.int_to_rgb(rgb_int) - return "".join([ - "\\color[RGB]", - "{", - ",".join(map(str, rgb_tuple)), - "}" - ]) - - # Pre-parsing - - def get_backslash_indices(self) -> list[int]: - # The latter of `\\` doesn't count. - return list(it.chain(*[ - range(span[0], span[1], 2) - for span in self.find_spans(r"\\+") - ])) - - def get_unescaped_char_spans(self, chars: str): - return sorted(filter( - lambda span: span[0] - 1 not in self.backslash_indices, - self.find_substrs(list(chars)) - )) - - def get_brace_index_pairs(self) -> list[Span]: - left_brace_indices = [] - right_brace_indices = [] - left_brace_indices_stack = [] - for span in self.get_unescaped_char_spans("{}"): - index = span[0] - if self.get_substr(span) == "{": - left_brace_indices_stack.append(index) - else: - if not left_brace_indices_stack: - raise ValueError("Missing '{' inserted") - left_brace_index = left_brace_indices_stack.pop() - left_brace_indices.append(left_brace_index) - right_brace_indices.append(index) - if left_brace_indices_stack: - raise ValueError("Missing '}' inserted") - return list(zip(left_brace_indices, right_brace_indices)) - - def get_script_char_spans(self) -> list[int]: - return self.get_unescaped_char_spans("_^") - - def get_script_content_spans(self) -> list[Span]: - result = [] - brace_indices_dict = dict(self.brace_index_pairs) - script_pattern = r"[a-zA-Z0-9]|\\[a-zA-Z]+" - for script_char_span in self.script_char_spans: - span_begin = self.match(r"\s*", pos=script_char_span[1]).end() - if span_begin in brace_indices_dict.keys(): - span_end = brace_indices_dict[span_begin] + 1 - else: - match_obj = self.match(script_pattern, pos=span_begin) - if not match_obj: - script_name = { - "_": "subscript", - "^": "superscript" - }[script_char] - raise ValueError( - f"Unclear {script_name} detected while parsing. " - "Please use braces to clarify" - ) - span_end = match_obj.end() - result.append((span_begin, span_end)) - return result - - def get_script_spans(self) -> list[Span]: - return [ - ( - self.search(r"\s*$", endpos=script_char_span[0]).start(), - script_content_span[1] - ) - for script_char_span, script_content_span in zip( - self.script_char_spans, self.script_content_spans - ) - ] - # Parsing - def get_command_repl_items(self) -> list[tuple[Span, str]]: - color_related_command_dict = { - "color": (1, False), - "textcolor": (1, False), - "pagecolor": (1, True), - "colorbox": (1, True), - "fcolorbox": (2, True), - } - result = [] - backslash_indices = self.backslash_indices - right_brace_indices = [ - right_index - for left_index, right_index in self.brace_index_pairs + def get_cmd_spans(self) -> list[Span]: + return self.find_spans(r"\\(?:[a-zA-Z]+|\s|\S)|[_^{}]") + + def get_substr_flag(self, substr: str) -> int: + return {"{": 1, "}": -1}.get(substr, 0) + + def get_repl_substr_for_content(self, substr: str) -> str: + return substr + + def get_repl_substr_for_matching(self, substr: str) -> str: + return substr if substr.startswith("\\") else "" + + def get_specified_items( + self, cmd_span_pairs: list[tuple[Span, Span]] + ) -> list[tuple[Span, dict[str, str]]]: + cmd_content_spans = [ + (span_begin, span_end) + for (_, span_begin), (span_end, _) in cmd_span_pairs ] - pattern = "".join([ - r"\\", - "(", - "|".join(color_related_command_dict.keys()), - ")", - r"(?![a-zA-Z])" - ]) - for match_obj in self.finditer(pattern): - span_begin, cmd_end = match_obj.span() - if span_begin not in backslash_indices: - continue - cmd_name = match_obj.group(1) - n_braces, substitute_cmd = color_related_command_dict[cmd_name] - span_end = self.take_nearest_value( - right_brace_indices, cmd_end, n_braces - ) + 1 - if substitute_cmd: - repl_str = "\\" + cmd_name + n_braces * "{black}" - else: - repl_str = "" - result.append(((span_begin, span_end), repl_str)) - return result - - def get_extra_entity_spans(self) -> list[Span]: - return [ - self.match(r"\\([a-zA-Z]+|.)", pos=index).span() - for index in self.backslash_indices - ] - - def get_extra_ignored_spans(self) -> list[int]: - return self.script_char_spans.copy() - - def get_internal_specified_spans(self) -> list[Span]: - # Match paired double braces (`{{...}}`). - result = [] - reversed_brace_indices_dict = dict([ - pair[::-1] for pair in self.brace_index_pairs - ]) - skip = False - for prev_right_index, right_index in self.get_neighbouring_pairs( - list(reversed_brace_indices_dict.keys()) - ): - if skip: - skip = False - continue - if right_index != prev_right_index + 1: - continue - left_index = reversed_brace_indices_dict[right_index] - prev_left_index = reversed_brace_indices_dict[prev_right_index] - if left_index != prev_left_index - 1: - continue - result.append((left_index, right_index + 1)) - skip = True - return result - - def get_external_specified_spans(self) -> list[Span]: - return self.find_substrs(list(self.tex_to_color_map.keys())) - - def get_label_span_list(self) -> list[Span]: - result = self.script_content_spans.copy() - for span_begin, span_end in self.specified_spans: - shrinked_end = self.lslide(span_end, self.script_spans) - if span_begin >= shrinked_end: - continue - shrinked_span = (span_begin, shrinked_end) - if shrinked_span in result: - continue - result.append(shrinked_span) - return result - - def get_content(self, use_plain_file: bool) -> str: - if use_plain_file: - span_repl_dict = {} - else: - extended_label_span_list = [ + specified_spans = [ + *[ + cmd_content_spans[range_begin] + for _, (range_begin, range_end) in self.compress_neighbours([ + (span_begin + index, span_end - index) + for index, (span_begin, span_end) in enumerate( + cmd_content_spans + ) + ]) + if range_end - range_begin >= 2 + ], + *[ span - if span in self.script_content_spans - else (span[0], self.rslide(span[1], self.script_spans)) - for span in self.label_span_list - ] - inserted_string_pairs = [ - (span, ( - "{{" + self.get_color_command_str(label + 1), - "}}" - )) - for label, span in enumerate(extended_label_span_list) - ] - span_repl_dict = self.generate_span_repl_dict( - inserted_string_pairs, - self.command_repl_items - ) - result = self.get_replaced_substr(self.full_span, span_repl_dict) + for selector in self.tex_to_color_map + for span in self.find_spans_by_selector(selector) + ], + *self.find_spans_by_selector(self.isolate) + ] + return [(span, {}) for span in specified_spans] - if self.tex_environment: - result = "\n".join([ - f"\\begin{{{self.tex_environment}}}", - result, - f"\\end{{{self.tex_environment}}}" - ]) + @staticmethod + def get_color_cmd_str(rgb_hex: str) -> str: + rgb = MTex.hex_to_int(rgb_hex) + rg, b = divmod(rgb, 256) + r, g = divmod(rg, 256) + return f"\\color[RGB]{{{r}, {g}, {b}}}" + + @staticmethod + def get_cmd_str_pair( + attr_dict: dict[str, str], label_hex: str | None + ) -> tuple[str, str]: + if label_hex is None: + return "", "" + return "{{" + MTex.get_color_cmd_str(label_hex), "}}" + + def get_content_prefix_and_suffix( + self, is_labelled: bool + ) -> tuple[str, str]: + prefix_lines = [] + suffix_lines = [] + if not is_labelled: + prefix_lines.append(self.get_color_cmd_str(self.base_color_hex)) if self.alignment: - result = "\n".join([self.alignment, result]) - if use_plain_file: - result = "\n".join([ - self.get_color_command_str(self.hex_to_int(self.base_color)), - result - ]) - return result - - @property - def has_predefined_local_colors(self) -> bool: - return bool(self.command_repl_items) - - # Post-parsing - - def get_cleaned_substr(self, span: Span) -> str: - substr = super().get_cleaned_substr(span) - if not self.brace_index_pairs: - return substr - - # Balance braces. - left_brace_indices, right_brace_indices = zip(*self.brace_index_pairs) - unclosed_left_braces = 0 - unclosed_right_braces = 0 - for index in range(*span): - if index in left_brace_indices: - unclosed_left_braces += 1 - elif index in right_brace_indices: - if unclosed_left_braces == 0: - unclosed_right_braces += 1 - else: - unclosed_left_braces -= 1 - return "".join([ - unclosed_right_braces * "{", - substr, - unclosed_left_braces * "}" - ]) + prefix_lines.append(self.alignment) + if self.tex_environment: + if isinstance(self.tex_environment, str): + env_prefix = f"\\begin{{{self.tex_environment}}}" + env_suffix = f"\\end{{{self.tex_environment}}}" + else: + env_prefix, env_suffix = self.tex_environment + prefix_lines.append(env_prefix) + suffix_lines.append(env_suffix) + return ( + "".join([line + "\n" for line in prefix_lines]), + "".join(["\n" + line for line in suffix_lines]) + ) # Method alias - def get_parts_by_tex(self, tex: str, **kwargs) -> VGroup: - return self.get_parts_by_string(tex, **kwargs) + def get_parts_by_tex(self, selector: Selector) -> VGroup: + return self.select_parts(selector) - def get_part_by_tex(self, tex: str, **kwargs) -> VMobject: - return self.get_part_by_string(tex, **kwargs) + def get_part_by_tex(self, selector: Selector) -> VGroup: + return self.select_part(selector) - def set_color_by_tex(self, tex: str, color: ManimColor, **kwargs): - return self.set_color_by_string(tex, color, **kwargs) + def set_color_by_tex(self, selector: Selector, color: ManimColor): + return self.set_parts_color(selector, color) def set_color_by_tex_to_color_map( - self, tex_to_color_map: dict[str, ManimColor], **kwargs + self, color_map: dict[Selector, ManimColor] ): - return self.set_color_by_string_to_color_map( - tex_to_color_map, **kwargs - ) + return self.set_parts_color_by_dict(color_map) def get_tex(self) -> str: return self.get_string() diff --git a/manimlib/mobject/svg/string_mobject.py b/manimlib/mobject/svg/string_mobject.py new file mode 100644 index 00000000..5004960e --- /dev/null +++ b/manimlib/mobject/svg/string_mobject.py @@ -0,0 +1,532 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +import itertools as it +import re +from scipy.optimize import linear_sum_assignment +from scipy.spatial.distance import cdist + +from manimlib.constants import WHITE +from manimlib.logger import log +from manimlib.mobject.svg.svg_mobject import SVGMobject +from manimlib.mobject.types.vectorized_mobject import VGroup +from manimlib.utils.color import color_to_rgb +from manimlib.utils.color import rgb_to_hex +from manimlib.utils.config_ops import digest_config + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from colour import Color + from typing import Iterable, Sequence, TypeVar, Union + + ManimColor = Union[str, Color] + Span = tuple[int, int] + Selector = Union[ + str, + re.Pattern, + tuple[Union[int, None], Union[int, None]], + Iterable[Union[ + str, + re.Pattern, + tuple[Union[int, None], Union[int, None]] + ]] + ] + T = TypeVar("T") + + +class StringMobject(SVGMobject, ABC): + """ + An abstract base class for `MTex` and `MarkupText` + + This class aims to optimize the logic of "slicing submobjects + via substrings". This could be much clearer and more user-friendly + than slicing through numerical indices explicitly. + + Users are expected to specify substrings in `isolate` parameter + if they want to do anything with their corresponding submobjects. + `isolate` parameter can be either a string, a `re.Pattern` object, + or a 2-tuple containing integers or None, or a collection of the above. + Note, substrings specified cannot *partially* overlap with each other. + + Each instance of `StringMobject` generates 2 svg files. + The additional one is generated with some color commands inserted, + so that each submobject of the original `SVGMobject` will be labelled + by the color of its paired submobject from the additional `SVGMobject`. + """ + CONFIG = { + "height": None, + "stroke_width": 0, + "stroke_color": WHITE, + "path_string_config": { + "should_subdivide_sharp_curves": True, + "should_remove_null_curves": True, + }, + "base_color": WHITE, + "isolate": (), + } + + def __init__(self, string: str, **kwargs): + self.string = string + digest_config(self, kwargs) + if self.base_color is None: + self.base_color = WHITE + self.base_color_hex = self.color_to_hex(self.base_color) + + self.full_span = (0, len(self.string)) + self.parse() + super().__init__(**kwargs) + self.labels = [submob.label for submob in self.submobjects] + + def get_file_path(self) -> str: + original_content = self.get_content(is_labelled=False) + return self.get_file_path_by_content(original_content) + + @abstractmethod + def get_file_path_by_content(self, content: str) -> str: + return "" + + def generate_mobject(self) -> None: + super().generate_mobject() + + labels_count = len(self.labelled_spans) + if not labels_count: + for submob in self.submobjects: + submob.label = -1 + return + + labelled_content = self.get_content(is_labelled=True) + file_path = self.get_file_path_by_content(labelled_content) + labelled_svg = SVGMobject(file_path) + if len(self.submobjects) != len(labelled_svg.submobjects): + log.warning( + "Cannot align submobjects of the labelled svg " + "to the original svg. Skip the labelling process." + ) + for submob in self.submobjects: + submob.label = -1 + return + + self.rearrange_submobjects_by_positions(labelled_svg) + unrecognizable_colors = [] + for submob, labelled_svg_submob in zip( + self.submobjects, labelled_svg.submobjects + ): + color_int = self.hex_to_int(self.color_to_hex( + labelled_svg_submob.get_fill_color() + )) + if color_int > labels_count: + unrecognizable_colors.append(color_int) + color_int = 0 + submob.label = color_int - 1 + if unrecognizable_colors: + log.warning( + "Unrecognizable color labels detected (%s, etc). " + "The result could be unexpected.", + self.int_to_hex(unrecognizable_colors[0]) + ) + + def rearrange_submobjects_by_positions( + self, labelled_svg: SVGMobject + ) -> None: + # Rearrange submobjects of `labelled_svg` so that + # each submobject is labelled by the nearest one of `labelled_svg`. + # The correctness cannot be ensured, since the svg may + # change significantly after inserting color commands. + if not labelled_svg.submobjects: + return + + bb_0 = self.get_bounding_box() + bb_1 = labelled_svg.get_bounding_box() + scale_factor = abs((bb_0[2] - bb_0[0]) / (bb_1[2] - bb_1[0])) + labelled_svg.move_to(self).scale(scale_factor) + + distance_matrix = cdist( + [submob.get_center() for submob in self.submobjects], + [submob.get_center() for submob in labelled_svg.submobjects] + ) + _, indices = linear_sum_assignment(distance_matrix) + labelled_svg.set_submobjects([ + labelled_svg.submobjects[index] + for index in indices + ]) + + # Toolkits + + def get_substr(self, span: Span) -> str: + return self.string[slice(*span)] + + def find_spans(self, pattern: str | re.Pattern) -> list[Span]: + return [ + match_obj.span() + for match_obj in re.finditer(pattern, self.string) + ] + + def find_spans_by_selector(self, selector: Selector) -> list[Span]: + def find_spans_by_single_selector(sel): + if isinstance(sel, str): + return self.find_spans(re.escape(sel)) + if isinstance(sel, re.Pattern): + return self.find_spans(sel) + if isinstance(sel, tuple) and len(sel) == 2 and all( + isinstance(index, int) or index is None + for index in sel + ): + l = self.full_span[1] + span = tuple( + min(index, l) if index >= 0 else max(index + l, 0) + if index is not None else default_index + for index, default_index in zip(sel, self.full_span) + ) + return [span] + return None + + result = find_spans_by_single_selector(selector) + if result is None: + result = [] + for sel in selector: + spans = find_spans_by_single_selector(sel) + if spans is None: + raise TypeError(f"Invalid selector: '{sel}'") + result.extend(spans) + return result + + @staticmethod + def get_neighbouring_pairs(vals: Sequence[T]) -> list[tuple[T, T]]: + return list(zip(vals[:-1], vals[1:])) + + @staticmethod + def compress_neighbours(vals: Sequence[T]) -> list[tuple[T, Span]]: + if not vals: + return [] + + unique_vals = [vals[0]] + indices = [0] + for index, val in enumerate(vals): + if val == unique_vals[-1]: + continue + unique_vals.append(val) + indices.append(index) + indices.append(len(vals)) + val_ranges = StringMobject.get_neighbouring_pairs(indices) + return list(zip(unique_vals, val_ranges)) + + @staticmethod + def span_contains(span_0: Span, span_1: Span) -> bool: + return span_0[0] <= span_1[0] and span_0[1] >= span_1[1] + + @staticmethod + def get_complement_spans( + universal_span: Span, interval_spans: list[Span] + ) -> list[Span]: + if not interval_spans: + return [universal_span] + + span_ends, span_begins = zip(*interval_spans) + return list(zip( + (universal_span[0], *span_begins), + (*span_ends, universal_span[1]) + )) + + def replace_substr(self, span: Span, repl_items: list[Span, str]): + if not repl_items: + return self.get_substr(span) + + repl_spans, repl_strs = zip(*sorted(repl_items, key=lambda t: t[0])) + pieces = [ + self.get_substr(piece_span) + for piece_span in self.get_complement_spans(span, repl_spans) + ] + repl_strs = [*repl_strs, ""] + return "".join(it.chain(*zip(pieces, repl_strs))) + + @staticmethod + def color_to_hex(color: ManimColor) -> str: + return rgb_to_hex(color_to_rgb(color)) + + @staticmethod + def hex_to_int(rgb_hex: str) -> int: + return int(rgb_hex[1:], 16) + + @staticmethod + def int_to_hex(rgb_int: int) -> str: + return f"#{rgb_int:06x}".upper() + + # Parsing + + def parse(self) -> None: + cmd_spans = self.get_cmd_spans() + cmd_substrs = [self.get_substr(span) for span in cmd_spans] + flags = [self.get_substr_flag(substr) for substr in cmd_substrs] + specified_items = self.get_specified_items( + self.get_cmd_span_pairs(cmd_spans, flags) + ) + split_items = [ + (span, attr_dict) + for specified_span, attr_dict in specified_items + for span in self.split_span_by_levels( + specified_span, cmd_spans, flags + ) + ] + + self.specified_spans = [span for span, _ in specified_items] + self.split_items = split_items + self.labelled_spans = [span for span, _ in split_items] + self.cmd_repl_items_for_content = [ + (span, self.get_repl_substr_for_content(substr)) + for span, substr in zip(cmd_spans, cmd_substrs) + ] + self.cmd_repl_items_for_matching = [ + (span, self.get_repl_substr_for_matching(substr)) + for span, substr in zip(cmd_spans, cmd_substrs) + ] + self.check_overlapping() + + @abstractmethod + def get_cmd_spans(self) -> list[Span]: + return [] + + @abstractmethod + def get_substr_flag(self, substr: str) -> int: + return 0 + + @abstractmethod + def get_repl_substr_for_content(self, substr: str) -> str: + return "" + + @abstractmethod + def get_repl_substr_for_matching(self, substr: str) -> str: + return "" + + @staticmethod + def get_cmd_span_pairs( + cmd_spans: list[Span], flags: list[int] + ) -> list[tuple[Span, Span]]: + result = [] + begin_cmd_spans_stack = [] + for cmd_span, flag in zip(cmd_spans, flags): + if flag == 1: + begin_cmd_spans_stack.append(cmd_span) + elif flag == -1: + if not begin_cmd_spans_stack: + raise ValueError("Missing open command") + begin_cmd_span = begin_cmd_spans_stack.pop() + result.append((begin_cmd_span, cmd_span)) + if begin_cmd_spans_stack: + raise ValueError("Missing close command") + return result + + @abstractmethod + def get_specified_items( + self, cmd_span_pairs: list[tuple[Span, Span]] + ) -> list[tuple[Span, dict[str, str]]]: + return [] + + def split_span_by_levels( + self, arbitrary_span: Span, cmd_spans: list[Span], flags: list[int] + ) -> list[Span]: + cmd_range = ( + sum([ + arbitrary_span[0] > interval_begin + for interval_begin, _ in cmd_spans + ]), + sum([ + arbitrary_span[1] >= interval_end + for _, interval_end in cmd_spans + ]) + ) + complement_spans = self.get_complement_spans( + self.full_span, cmd_spans + ) + adjusted_span = ( + max(arbitrary_span[0], complement_spans[cmd_range[0]][0]), + min(arbitrary_span[1], complement_spans[cmd_range[1]][1]) + ) + if adjusted_span[0] > adjusted_span[1]: + return [] + + upward_cmd_spans = [] + downward_cmd_spans = [] + for cmd_span, flag in list(zip(cmd_spans, flags))[slice(*cmd_range)]: + if flag == 1: + upward_cmd_spans.append(cmd_span) + elif flag == -1: + if upward_cmd_spans: + upward_cmd_spans.pop() + else: + downward_cmd_spans.append(cmd_span) + return list(filter( + lambda span: self.get_substr(span).strip(), + self.get_complement_spans( + adjusted_span, downward_cmd_spans + upward_cmd_spans + ) + )) + + def check_overlapping(self) -> None: + labelled_spans = self.labelled_spans + if len(labelled_spans) >= 16777216: + raise ValueError("Cannot handle that many substrings") + for span_0, span_1 in it.product(labelled_spans, repeat=2): + if not span_0[0] < span_1[0] < span_0[1] < span_1[1]: + continue + raise ValueError( + "Partially overlapping substrings detected: " + f"'{self.get_substr(span_0)}' and '{self.get_substr(span_1)}'" + ) + + @staticmethod + @abstractmethod + def get_cmd_str_pair( + attr_dict: dict[str, str], label_hex: str | None + ) -> tuple[str, str]: + return "", "" + + @abstractmethod + def get_content_prefix_and_suffix( + self, is_labelled: bool + ) -> tuple[str, str]: + return "", "" + + def get_content(self, is_labelled: bool) -> str: + inserted_str_pairs = [ + (span, self.get_cmd_str_pair( + attr_dict, + label_hex=self.int_to_hex(label + 1) if is_labelled else None + )) + for label, (span, attr_dict) in enumerate(self.split_items) + ] + inserted_str_items = sorted([ + (index, s) + for (index, _), s in [ + *sorted([ + (span[::-1], end_str) + for span, (_, end_str) in reversed(inserted_str_pairs) + ], key=lambda t: (t[0][0], -t[0][1])), + *sorted([ + (span, begin_str) + for span, (begin_str, _) in inserted_str_pairs + ], key=lambda t: (t[0][0], -t[0][1])) + ] + ], key=lambda t: t[0]) + repl_items = self.cmd_repl_items_for_content + [ + ((index, index), inserted_str) + for index, inserted_str in inserted_str_items + ] + prefix, suffix = self.get_content_prefix_and_suffix(is_labelled) + return "".join([ + prefix, + self.replace_substr(self.full_span, repl_items), + suffix + ]) + + # Selector + + def get_submob_indices_list_by_span( + self, arbitrary_span: Span + ) -> list[int]: + return [ + submob_index + for submob_index, label in enumerate(self.labels) + if label != -1 and self.span_contains( + arbitrary_span, self.labelled_spans[label] + ) + ] + + def get_specified_part_items(self) -> list[tuple[str, list[int]]]: + return [ + ( + self.get_substr(span), + self.get_submob_indices_list_by_span(span) + ) + for span in self.specified_spans + ] + + def get_group_part_items(self) -> list[tuple[str, list[int]]]: + if not self.labels: + return [] + + group_labels, labelled_submob_ranges = zip( + *self.compress_neighbours(self.labels) + ) + ordered_spans = [ + self.labelled_spans[label] if label != -1 else self.full_span + for label in group_labels + ] + interval_spans = [ + ( + next_span[0] + if self.span_contains(prev_span, next_span) + else prev_span[1], + prev_span[1] + if self.span_contains(next_span, prev_span) + else next_span[0] + ) + for prev_span, next_span in self.get_neighbouring_pairs( + ordered_spans + ) + ] + group_substrs = [ + re.sub(r"\s+", "", self.replace_substr( + span, [ + (cmd_span, repl_str) + for cmd_span, repl_str in self.cmd_repl_items_for_matching + if self.span_contains(span, cmd_span) + ] + )) + for span in self.get_complement_spans( + (ordered_spans[0][0], ordered_spans[-1][1]), interval_spans + ) + ] + submob_indices_lists = [ + list(range(*submob_range)) + for submob_range in labelled_submob_ranges + ] + return list(zip(group_substrs, submob_indices_lists)) + + def get_submob_indices_lists_by_selector( + self, selector: Selector + ) -> list[list[int]]: + return list(filter( + lambda indices_list: indices_list, + [ + self.get_submob_indices_list_by_span(span) + for span in self.find_spans_by_selector(selector) + ] + )) + + def build_parts_from_indices_lists( + self, indices_lists: list[list[int]] + ) -> VGroup: + return VGroup(*[ + VGroup(*[ + self.submobjects[submob_index] + for submob_index in indices_list + ]) + for indices_list in indices_lists + ]) + + def build_groups(self) -> VGroup: + return self.build_parts_from_indices_lists([ + indices_list + for _, indices_list in self.get_group_part_items() + ]) + + def select_parts(self, selector: Selector) -> VGroup: + return self.build_parts_from_indices_lists( + self.get_submob_indices_lists_by_selector(selector) + ) + + def select_part(self, selector: Selector, index: int = 0) -> VGroup: + return self.select_parts(selector)[index] + + def set_parts_color(self, selector: Selector, color: ManimColor): + self.select_parts(selector).set_color(color) + return self + + def set_parts_color_by_dict(self, color_map: dict[Selector, ManimColor]): + for selector, color in color_map.items(): + self.set_parts_color(selector, color) + return self + + def get_string(self) -> str: + return self.string diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index c3c3be19..93623c31 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -1,103 +1,52 @@ from __future__ import annotations -import os -import re -import itertools as it -from pathlib import Path from contextlib import contextmanager -import typing -from typing import Iterable, Sequence, Union +import os +from pathlib import Path +import re +import manimpango import pygments import pygments.formatters import pygments.lexers -from manimpango import MarkupUtils - +from manimlib.constants import DEFAULT_PIXEL_WIDTH, FRAME_WIDTH +from manimlib.constants import NORMAL from manimlib.logger import log -from manimlib.constants import * -from manimlib.mobject.svg.labelled_string import LabelledString -from manimlib.utils.customization import get_customization -from manimlib.utils.tex_file_writing import tex_hash +from manimlib.mobject.svg.string_mobject import StringMobject from manimlib.utils.config_ops import digest_config +from manimlib.utils.customization import get_customization from manimlib.utils.directories import get_downloads_dir from manimlib.utils.directories import get_text_dir -from manimlib.utils.iterables import remove_list_redundancies - +from manimlib.utils.tex_file_writing import tex_hash from typing import TYPE_CHECKING if TYPE_CHECKING: - from manimlib.mobject.types.vectorized_mobject import VMobject + from colour import Color + from typing import Iterable, Union + from manimlib.mobject.types.vectorized_mobject import VGroup - ManimColor = Union[str, colour.Color, Sequence[float]] + + ManimColor = Union[str, Color] Span = tuple[int, int] + Selector = Union[ + str, + re.Pattern, + tuple[Union[int, None], Union[int, None]], + Iterable[Union[ + str, + re.Pattern, + tuple[Union[int, None], Union[int, None]] + ]] + ] TEXT_MOB_SCALE_FACTOR = 0.0076 DEFAULT_LINE_SPACING_SCALE = 0.6 - - -# See https://docs.gtk.org/Pango/pango_markup.html -# A tag containing two aliases will cause warning, -# so only use the first key of each group of aliases. -SPAN_ATTR_KEY_ALIAS_LIST = ( - ("font", "font_desc"), - ("font_family", "face"), - ("font_size", "size"), - ("font_style", "style"), - ("font_weight", "weight"), - ("font_variant", "variant"), - ("font_stretch", "stretch"), - ("font_features",), - ("foreground", "fgcolor", "color"), - ("background", "bgcolor"), - ("alpha", "fgalpha"), - ("background_alpha", "bgalpha"), - ("underline",), - ("underline_color",), - ("overline",), - ("overline_color",), - ("rise",), - ("baseline_shift",), - ("font_scale",), - ("strikethrough",), - ("strikethrough_color",), - ("fallback",), - ("lang",), - ("letter_spacing",), - ("gravity",), - ("gravity_hint",), - ("show",), - ("insert_hyphens",), - ("allow_breaks",), - ("line_height",), - ("text_transform",), - ("segment",), -) -COLOR_RELATED_KEYS = ( - "foreground", - "background", - "underline_color", - "overline_color", - "strikethrough_color" -) -SPAN_ATTR_KEY_CONVERSION = { - key: key_alias_list[0] - for key_alias_list in SPAN_ATTR_KEY_ALIAS_LIST - for key in key_alias_list -} -TAG_TO_ATTR_DICT = { - "b": {"font_weight": "bold"}, - "big": {"font_size": "larger"}, - "i": {"font_style": "italic"}, - "s": {"strikethrough": "true"}, - "sub": {"baseline_shift": "subscript", "font_scale": "subscript"}, - "sup": {"baseline_shift": "superscript", "font_scale": "superscript"}, - "small": {"font_size": "smaller"}, - "tt": {"font_family": "monospace"}, - "u": {"underline": "single"}, -} +# Ensure the canvas is large enough to hold all glyphs. +DEFAULT_CANVAS_WIDTH = 16384 +DEFAULT_CANVAS_HEIGHT = 16384 # Temporary handler @@ -112,7 +61,7 @@ class _Alignment: self.value = _Alignment.VAL_DICT[s.upper()] -class MarkupText(LabelledString): +class MarkupText(StringMobject): CONFIG = { "is_markup": True, "font_size": 48, @@ -120,7 +69,7 @@ class MarkupText(LabelledString): "justify": False, "indent": 0, "alignment": "LEFT", - "line_width_factor": None, + "line_width": None, "font": "", "slant": NORMAL, "weight": NORMAL, @@ -132,6 +81,31 @@ class MarkupText(LabelledString): "t2w": {}, "global_config": {}, "local_configs": {}, + # For backward compatibility + "isolate": (re.compile(r"[a-zA-Z]+"), re.compile(r"\S+")), + } + + # See https://docs.gtk.org/Pango/pango_markup.html + MARKUP_COLOR_KEYS = { + "foreground": False, + "fgcolor": False, + "color": False, + "background": True, + "bgcolor": True, + "underline_color": True, + "overline_color": True, + "strikethrough_color": True, + } + MARKUP_TAGS = { + "b": {"font_weight": "bold"}, + "big": {"font_size": "larger"}, + "i": {"font_style": "italic"}, + "s": {"strikethrough": "true"}, + "sub": {"baseline_shift": "subscript", "font_scale": "subscript"}, + "sup": {"baseline_shift": "superscript", "font_scale": "superscript"}, + "small": {"font_size": "smaller"}, + "tt": {"font_family": "monospace"}, + "u": {"underline": "single"}, } def __init__(self, text: str, **kwargs): @@ -141,9 +115,7 @@ class MarkupText(LabelledString): if not self.font: self.font = get_customization()["style"]["font"] if self.is_markup: - validate_error = MarkupUtils.validate(text) - if validate_error: - raise ValueError(validate_error) + self.validate_markup_string(text) self.text = text super().__init__(text, **kwargs) @@ -165,7 +137,6 @@ class MarkupText(LabelledString): self.svg_default, self.path_string_config, self.base_color, - self.use_plain_file, self.isolate, self.text, self.is_markup, @@ -174,7 +145,7 @@ class MarkupText(LabelledString): self.justify, self.indent, self.alignment, - self.line_width_factor, + self.line_width, self.font, self.slant, self.weight, @@ -201,23 +172,32 @@ class MarkupText(LabelledString): kwargs[short_name] = kwargs.pop(long_name) def get_file_path_by_content(self, content: str) -> str: + hash_content = str(( + content, + self.justify, + self.indent, + self.alignment, + self.line_width + )) svg_file = os.path.join( - get_text_dir(), tex_hash(content) + ".svg" + get_text_dir(), tex_hash(hash_content) + ".svg" ) if not os.path.exists(svg_file): self.markup_to_svg(content, svg_file) return svg_file def markup_to_svg(self, markup_str: str, file_name: str) -> str: + self.validate_markup_string(markup_str) + # `manimpango` is under construction, # so the following code is intended to suit its interface alignment = _Alignment(self.alignment) - if self.line_width_factor is None: + if self.line_width is None: pango_width = -1 else: - pango_width = self.line_width_factor * DEFAULT_PIXEL_WIDTH + pango_width = self.line_width / FRAME_WIDTH * DEFAULT_PIXEL_WIDTH - return MarkupUtils.text2svg( + return manimpango.MarkupUtils.text2svg( text=markup_str, font="", # Already handled slant="NORMAL", # Already handled @@ -228,8 +208,8 @@ class MarkupText(LabelledString): file_name=file_name, START_X=0, START_Y=0, - width=DEFAULT_PIXEL_WIDTH, - height=DEFAULT_PIXEL_HEIGHT, + width=DEFAULT_CANVAS_WIDTH, + height=DEFAULT_CANVAS_HEIGHT, justify=self.justify, indent=self.indent, line_spacing=None, # Already handled @@ -237,294 +217,173 @@ class MarkupText(LabelledString): pango_width=pango_width ) - def pre_parse(self) -> None: - super().pre_parse() - self.tag_items_from_markup = self.get_tag_items_from_markup() - self.global_dict_from_config = self.get_global_dict_from_config() - self.local_dicts_from_markup = self.get_local_dicts_from_markup() - self.local_dicts_from_config = self.get_local_dicts_from_config() - self.predefined_attr_dicts = self.get_predefined_attr_dicts() - - # Toolkits - @staticmethod - def get_attr_dict_str(attr_dict: dict[str, str]) -> str: - return " ".join([ - f"{key}='{val}'" - for key, val in attr_dict.items() - ]) - - @staticmethod - def merge_attr_dicts( - attr_dict_items: list[Span, str, typing.Any] - ) -> list[tuple[Span, dict[str, str]]]: - index_seq = [0] - attr_dict_list = [{}] - for span, attr_dict in attr_dict_items: - if span[0] >= span[1]: - continue - region_indices = [ - MarkupText.find_region_index(index_seq, index) - for index in span - ] - for flag in (1, 0): - if index_seq[region_indices[flag]] == span[flag]: - continue - region_index = region_indices[flag] - index_seq.insert(region_index + 1, span[flag]) - attr_dict_list.insert( - region_index + 1, attr_dict_list[region_index].copy() - ) - region_indices[flag] += 1 - if flag == 0: - region_indices[1] += 1 - for key, val in attr_dict.items(): - if not key: - continue - for mid_dict in attr_dict_list[slice(*region_indices)]: - mid_dict[key] = val - return list(zip( - MarkupText.get_neighbouring_pairs(index_seq), attr_dict_list[:-1] - )) - - def find_substr_or_span( - self, substr_or_span: str | tuple[int | None, int | None] - ) -> list[Span]: - if isinstance(substr_or_span, str): - return self.find_substr(substr_or_span) - - span = tuple([ - ( - min(index, self.string_len) - if index >= 0 - else max(index + self.string_len, 0) - ) - if index is not None else default_index - for index, default_index in zip(substr_or_span, self.full_span) - ]) - if span[0] >= span[1]: - return [] - return [span] - - # Pre-parsing - - def get_tag_items_from_markup( - self - ) -> list[tuple[Span, Span, dict[str, str]]]: - if not self.is_markup: - return [] - - tag_pattern = r"""<(/?)(\w+)\s*((?:\w+\s*\=\s*(['"]).*?\4\s*)*)>""" - attr_pattern = r"""(\w+)\s*\=\s*(['"])(.*?)\2""" - begin_match_obj_stack = [] - match_obj_pairs = [] - for match_obj in self.finditer(tag_pattern): - if not match_obj.group(1): - begin_match_obj_stack.append(match_obj) - else: - match_obj_pairs.append( - (begin_match_obj_stack.pop(), match_obj) - ) - if begin_match_obj_stack: - raise ValueError("Unclosed tag(s) detected") - - result = [] - for begin_match_obj, end_match_obj in match_obj_pairs: - tag_name = begin_match_obj.group(2) - if tag_name != end_match_obj.group(2): - raise ValueError("Unmatched tag names") - if end_match_obj.group(3): - raise ValueError("Attributes shan't exist in ending tags") - if tag_name == "span": - attr_dict = { - match.group(1): match.group(3) - for match in re.finditer( - attr_pattern, begin_match_obj.group(3) - ) - } - elif tag_name in TAG_TO_ATTR_DICT.keys(): - if begin_match_obj.group(3): - raise ValueError( - f"Attributes shan't exist in tag '{tag_name}'" - ) - attr_dict = TAG_TO_ATTR_DICT[tag_name].copy() - else: - raise ValueError(f"Unknown tag: '{tag_name}'") - - result.append( - (begin_match_obj.span(), end_match_obj.span(), attr_dict) - ) - return result - - def get_global_dict_from_config(self) -> dict[str, typing.Any]: - result = { - "line_height": ( - (self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1 - ) * 0.6, - "font_family": self.font, - "font_size": self.font_size * 1024, - "font_style": self.slant, - "font_weight": self.weight - } - result.update(self.global_config) - return result - - def get_local_dicts_from_markup( - self - ) -> list[Span, dict[str, str]]: - return sorted([ - ((begin_tag_span[0], end_tag_span[1]), attr_dict) - for begin_tag_span, end_tag_span, attr_dict - in self.tag_items_from_markup - ]) - - def get_local_dicts_from_config( - self - ) -> list[Span, dict[str, typing.Any]]: - return [ - (span, {key: val}) - for t2x_dict, key in ( - (self.t2c, "foreground"), - (self.t2f, "font_family"), - (self.t2s, "font_style"), - (self.t2w, "font_weight") - ) - for substr_or_span, val in t2x_dict.items() - for span in self.find_substr_or_span(substr_or_span) - ] + [ - (span, local_config) - for substr_or_span, local_config in self.local_configs.items() - for span in self.find_substr_or_span(substr_or_span) - ] - - def get_predefined_attr_dicts(self) -> list[Span, dict[str, str]]: - attr_dict_items = [ - (self.full_span, self.global_dict_from_config), - *self.local_dicts_from_markup, - *self.local_dicts_from_config - ] - return [ - (span, { - SPAN_ATTR_KEY_CONVERSION[key.lower()]: str(val) - for key, val in attr_dict.items() - }) - for span, attr_dict in attr_dict_items - ] + def validate_markup_string(markup_str: str) -> None: + validate_error = manimpango.MarkupUtils.validate(markup_str) + if not validate_error: + return + raise ValueError( + f"Invalid markup string \"{markup_str}\"\n" + f"{validate_error}" + ) # Parsing - def get_command_repl_items(self) -> list[tuple[Span, str]]: - result = [ - (tag_span, "") - for begin_tag, end_tag, _ in self.tag_items_from_markup - for tag_span in (begin_tag, end_tag) - ] + def get_cmd_spans(self) -> list[Span]: if not self.is_markup: - result += [ - (span, escaped) - for char, escaped in ( - ("&", "&"), - (">", ">"), - ("<", "<") - ) - for span in self.find_substr(char) - ] - return result + return self.find_spans(r"""[<>&"']""") - def get_extra_entity_spans(self) -> list[Span]: - if not self.is_markup: - return [] - return self.find_spans(r"&.*?;") - - def get_extra_ignored_spans(self) -> list[int]: - return [] - - def get_internal_specified_spans(self) -> list[Span]: - return [span for span, _ in self.local_dicts_from_markup] - - def get_external_specified_spans(self) -> list[Span]: - return [span for span, _ in self.local_dicts_from_config] - - def get_label_span_list(self) -> list[Span]: - breakup_indices = remove_list_redundancies(list(it.chain(*it.chain( - self.find_spans(r"\s+"), - self.find_spans(r"\b"), - self.specified_spans - )))) - breakup_indices = sorted(filter( - lambda index: not any([ - span[0] < index < span[1] - for span in self.entity_spans - ]), - breakup_indices - )) - return list(filter( - lambda span: self.get_substr(span).strip(), - self.get_neighbouring_pairs(breakup_indices) - )) - - def get_content(self, use_plain_file: bool) -> str: - if use_plain_file: - attr_dict_items = [ - (self.full_span, {"foreground": self.base_color}), - *self.predefined_attr_dicts, - *[ - (span, {}) - for span in self.label_span_list - ] - ] - else: - attr_dict_items = [ - (self.full_span, {"foreground": BLACK}), - *[ - (span, { - key: BLACK if key in COLOR_RELATED_KEYS else val - for key, val in attr_dict.items() - }) - for span, attr_dict in self.predefined_attr_dicts - ], - *[ - (span, {"foreground": self.int_to_hex(label + 1)}) - for label, span in enumerate(self.label_span_list) - ] - ] - inserted_string_pairs = [ - (span, ( - f"", - "" - )) - for span, attr_dict in self.merge_attr_dicts(attr_dict_items) - ] - span_repl_dict = self.generate_span_repl_dict( - inserted_string_pairs, self.command_repl_items + # Unsupported passthroughs: + # "", "", "", "" + # See https://gitlab.gnome.org/GNOME/glib/-/blob/main/glib/gmarkup.c + return self.find_spans( + r"""&[\s\S]*?;|[>"']|""" ) - return self.get_replaced_substr(self.full_span, span_repl_dict) - @property - def has_predefined_local_colors(self) -> bool: - return any([ - key in COLOR_RELATED_KEYS - for _, attr_dict in self.predefined_attr_dicts - for key in attr_dict.keys() + def get_substr_flag(self, substr: str) -> int: + if re.fullmatch(r"<\w[\s\S]*[^/]>", substr): + return 1 + if substr.startswith(" str: + if substr.startswith("<") and substr.endswith(">"): + return "" + return { + "<": "<", + ">": ">", + "&": "&", + "\"": """, + "'": "'" + }.get(substr, substr) + + def get_repl_substr_for_matching(self, substr: str) -> str: + if substr.startswith("<") and substr.endswith(">"): + return "" + if substr.startswith("&#") and substr.endswith(";"): + if substr.startswith("&#x"): + char_reference = int(substr[3:-1], 16) + else: + char_reference = int(substr[2:-1], 10) + return chr(char_reference) + return { + "<": "<", + ">": ">", + "&": "&", + """: "\"", + "'": "'" + }.get(substr, substr) + + def get_specified_items( + self, cmd_span_pairs: list[tuple[Span, Span]] + ) -> list[tuple[Span, dict[str, str]]]: + attr_pattern = r"""(\w+)\s*\=\s*(["'])([\s\S]*?)\2""" + internal_items = [] + for begin_cmd_span, end_cmd_span in cmd_span_pairs: + begin_tag = self.get_substr(begin_cmd_span) + tag_name = re.match(r"<(\w+)", begin_tag).group(1) + if tag_name == "span": + attr_dict = { + attr_match_obj.group(1): attr_match_obj.group(3) + for attr_match_obj in re.finditer(attr_pattern, begin_tag) + } + else: + attr_dict = MarkupText.MARKUP_TAGS.get(tag_name, {}) + internal_items.append( + ((begin_cmd_span[1], end_cmd_span[0]), attr_dict) + ) + + return [ + *internal_items, + *[ + (span, {key: val}) + for t2x_dict, key in ( + (self.t2c, "foreground"), + (self.t2f, "font_family"), + (self.t2s, "font_style"), + (self.t2w, "font_weight") + ) + for selector, val in t2x_dict.items() + for span in self.find_spans_by_selector(selector) + ], + *[ + (span, local_config) + for selector, local_config in self.local_configs.items() + for span in self.find_spans_by_selector(selector) + ], + *[ + (span, {}) + for span in self.find_spans_by_selector(self.isolate) + ] + ] + + @staticmethod + def get_cmd_str_pair( + attr_dict: dict[str, str], label_hex: str | None + ) -> tuple[str, str]: + if label_hex is not None: + converted_attr_dict = {"foreground": label_hex} + for key, val in attr_dict.items(): + substitute_key = MarkupText.MARKUP_COLOR_KEYS.get(key, None) + if substitute_key is None: + converted_attr_dict[key] = val + elif substitute_key: + converted_attr_dict[key] = "black" + else: + converted_attr_dict = attr_dict.copy() + attrs_str = " ".join([ + f"{key}='{val}'" + for key, val in converted_attr_dict.items() ]) + return f"", "" + + def get_content_prefix_and_suffix( + self, is_labelled: bool + ) -> tuple[str, str]: + global_attr_dict = { + "foreground": self.base_color_hex, + "font_family": self.font, + "font_style": self.slant, + "font_weight": self.weight, + "font_size": str(self.font_size * 1024), + } + global_attr_dict.update(self.global_config) + # `line_height` attribute is supported since Pango 1.50. + pango_version = manimpango.pango_version() + if tuple(map(int, pango_version.split("."))) < (1, 50): + if self.lsh is not None: + log.warning( + "Pango version %s found (< 1.50), " + "unable to set `line_height` attribute", + pango_version + ) + else: + line_spacing_scale = self.lsh or DEFAULT_LINE_SPACING_SCALE + global_attr_dict["line_height"] = str( + ((line_spacing_scale) + 1) * 0.6 + ) + + return self.get_cmd_str_pair( + global_attr_dict, + label_hex=self.int_to_hex(0) if is_labelled else None + ) # Method alias - def get_parts_by_text(self, text: str, **kwargs) -> VGroup: - return self.get_parts_by_string(text, **kwargs) + def get_parts_by_text(self, selector: Selector) -> VGroup: + return self.select_parts(selector) - def get_part_by_text(self, text: str, **kwargs) -> VMobject: - return self.get_part_by_string(text, **kwargs) + def get_part_by_text(self, selector: Selector) -> VGroup: + return self.select_part(selector) - def set_color_by_text(self, text: str, color: ManimColor, **kwargs): - return self.set_color_by_string(text, color, **kwargs) + def set_color_by_text(self, selector: Selector, color: ManimColor): + return self.set_parts_color(selector, color) def set_color_by_text_to_color_map( - self, text_to_color_map: dict[str, ManimColor], **kwargs + self, color_map: dict[Selector, ManimColor] ): - return self.set_color_by_string_to_color_map( - text_to_color_map, **kwargs - ) + return self.set_parts_color_by_dict(color_map) def get_text(self) -> str: return self.get_string() diff --git a/manimlib/scene/interactive_scene.py b/manimlib/scene/interactive_scene.py index 40ec18c5..3b7397b8 100644 --- a/manimlib/scene/interactive_scene.py +++ b/manimlib/scene/interactive_scene.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import itertools as it import numpy as np import pyperclip diff --git a/manimlib/utils/iterables.py b/manimlib/utils/iterables.py index bdaa76c2..fa54e68d 100644 --- a/manimlib/utils/iterables.py +++ b/manimlib/utils/iterables.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: S = TypeVar("S") -def remove_list_redundancies(l: Iterable[T]) -> list[T]: +def remove_list_redundancies(l: Sequence[T]) -> list[T]: """ Used instead of list(set(l)) to maintain order Keeps the last occurrence of each element @@ -40,14 +40,14 @@ def list_difference_update(l1: Iterable[T], l2: Iterable[T]) -> list[T]: return [e for e in l1 if e not in l2] -def adjacent_n_tuples(objects: Iterable[T], n: int) -> zip[tuple[T, T]]: +def adjacent_n_tuples(objects: Sequence[T], n: int) -> zip[tuple[T, T]]: return zip(*[ [*objects[k:], *objects[:k]] for k in range(n) ]) -def adjacent_pairs(objects: Iterable[T]) -> zip[tuple[T, T]]: +def adjacent_pairs(objects: Sequence[T]) -> zip[tuple[T, T]]: return adjacent_n_tuples(objects, 2)