diff --git a/manimlib/animation/creation.py b/manimlib/animation/creation.py index 27460899..6ad6a9bd 100644 --- a/manimlib/animation/creation.py +++ b/manimlib/animation/creation.py @@ -1,12 +1,12 @@ from __future__ import annotations -import itertools as it -from abc import abstractmethod +from abc import ABC, abstractmethod import numpy as np from manimlib.animation.animation import Animation from manimlib.mobject.svg.labelled_string import LabelledString +from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.utils.bezier import integer_interpolate from manimlib.utils.config_ops import digest_config @@ -17,10 +17,10 @@ from manimlib.utils.rate_functions import smooth from typing import TYPE_CHECKING if TYPE_CHECKING: - from manimlib.mobject.mobject import Group + from manimlib.mobject.mobject import Mobject -class ShowPartial(Animation): +class ShowPartial(Animation, ABC): """ Abstract class for ShowCreation and ShowPassingFlash """ @@ -176,7 +176,7 @@ class ShowIncreasingSubsets(Animation): "int_func": np.round, } - def __init__(self, group: Group, **kwargs): + def __init__(self, group: Mobject, **kwargs): self.all_submobs = list(group.submobjects) super().__init__(group, **kwargs) @@ -213,7 +213,9 @@ class AddTextWordByWord(ShowIncreasingSubsets): def __init__(self, string_mobject, **kwargs): assert isinstance(string_mobject, LabelledString) - grouped_mobject = string_mobject.submob_groups + grouped_mobject = VGroup(*[ + part for _, part in string_mobject.get_group_part_items() + ]) digest_config(self, kwargs) if self.run_time is None: self.run_time = self.time_per_word * len(grouped_mobject) diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index dab88005..e84f1d9d 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -5,9 +5,9 @@ import itertools as it import numpy as np from manimlib.animation.composition import AnimationGroup -from manimlib.animation.fading import FadeTransformPieces from manimlib.animation.fading import FadeInFromPoint from manimlib.animation.fading import FadeOutToPoint +from manimlib.animation.fading import FadeTransformPieces from manimlib.animation.transform import ReplacementTransform from manimlib.animation.transform import Transform from manimlib.mobject.mobject import Mobject @@ -16,13 +16,13 @@ from manimlib.mobject.svg.labelled_string import LabelledString from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.utils.config_ops import digest_config -from manimlib.utils.iterables import remove_list_redundancies from typing import TYPE_CHECKING if TYPE_CHECKING: + from manimlib.mobject.svg.tex_mobject import SingleStringTex + from manimlib.mobject.svg.tex_mobject import Tex from manimlib.scene.scene import Scene - from manimlib.mobject.svg.tex_mobject import Tex, SingleStringTex class TransformMatchingParts(AnimationGroup): @@ -168,36 +168,36 @@ class TransformMatchingStrings(AnimationGroup): assert isinstance(source, LabelledString) assert isinstance(target, LabelledString) anims = [] - source_indices = list(range(len(source.labelled_submobjects))) - target_indices = list(range(len(target.labelled_submobjects))) - def get_indices_lists(mobject, parts): - return [ + source_submobs = [ + submob for _, submob in source.labelled_submobject_items + ] + target_submobs = [ + submob for _, submob in target.labelled_submobject_items + ] + source_indices = list(range(len(source_submobs))) + target_indices = list(range(len(target_submobs))) + + def get_filtered_indices_lists(parts, submobs, rest_indices): + return list(filter( + lambda indices_list: all([ + index in rest_indices + for index in indices_list + ]), [ - mobject.labelled_submobjects.index(submob) - for submob in part + [submobs.index(submob) for submob in part] + for part in parts ] - for part in parts - ] + )) - def add_anims_from(anim_class, func, source_args, target_args=None): - if target_args is None: - target_args = source_args.copy() - for source_arg, target_arg in zip(source_args, target_args): - source_parts = func(source, source_arg) - target_parts = func(target, target_arg) - source_indices_lists = list(filter( - lambda indices_list: all([ - index in source_indices - for index in indices_list - ]), get_indices_lists(source, source_parts) - )) - target_indices_lists = list(filter( - lambda indices_list: all([ - index in target_indices - for index in indices_list - ]), get_indices_lists(target, target_parts) - )) + def add_anims(anim_class, parts_pairs): + for source_parts, target_parts in parts_pairs: + source_indices_lists = get_filtered_indices_lists( + source_parts, source_submobs, source_indices + ) + target_indices_lists = get_filtered_indices_lists( + target_parts, target_submobs, target_indices + ) if not source_indices_lists or not target_indices_lists: continue anims.append(anim_class(source_parts, target_parts, **kwargs)) @@ -206,41 +206,45 @@ class TransformMatchingStrings(AnimationGroup): for index in it.chain(*target_indices_lists): target_indices.remove(index) - def get_common_substrs(substrs_from_source, substrs_from_target): - return sorted([ - substr for substr in substrs_from_source - if substr and substr in substrs_from_target - ], key=len, reverse=True) - - def get_parts_from_keys(mobject, keys): - if isinstance(keys, str): - keys = [keys] - result = VGroup() - for key in keys: - if not isinstance(key, str): - raise TypeError(key) - result.add(*mobject.get_parts_by_string(key)) + def get_substr_to_parts_map(part_items): + result = {} + for substr, part in part_items: + if substr not in result: + result[substr] = [] + result[substr].append(part) return result - add_anims_from( - ReplacementTransform, get_parts_from_keys, - self.key_map.keys(), self.key_map.values() + def add_anims_from(anim_class, func): + source_substr_to_parts_map = get_substr_to_parts_map(func(source)) + target_substr_to_parts_map = get_substr_to_parts_map(func(target)) + add_anims( + anim_class, + [ + ( + VGroup(*source_substr_to_parts_map[substr]), + VGroup(*target_substr_to_parts_map[substr]) + ) + for substr in sorted([ + s for s in source_substr_to_parts_map.keys() + if s and s in target_substr_to_parts_map.keys() + ], key=len, reverse=True) + ] + ) + + add_anims( + ReplacementTransform, + [ + (source.select_parts(k), target.select_parts(v)) + for k, v in self.key_map.items() + ] ) add_anims_from( FadeTransformPieces, - LabelledString.get_parts_by_string, - get_common_substrs( - source.specified_substrs, - target.specified_substrs - ) + LabelledString.get_specified_part_items ) add_anims_from( FadeTransformPieces, - LabelledString.get_parts_by_group_substr, - get_common_substrs( - source.group_substrs, - target.group_substrs - ) + LabelledString.get_group_part_items ) rest_source = VGroup(*[source[index] for index in source_indices]) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index f1354f0c..1c3f0afd 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -1,30 +1,43 @@ from __future__ import annotations -import re -import colour -import itertools as it -from typing import Iterable, Union, Sequence from abc import ABC, abstractmethod +import itertools as it +import numpy as np +import re -from manimlib.constants import BLACK, WHITE +from manimlib.constants import WHITE from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.types.vectorized_mobject import VGroup -from manimlib.utils.color import color_to_int_rgb from manimlib.utils.color import color_to_rgb from manimlib.utils.color import rgb_to_hex from manimlib.utils.config_ops import digest_config from manimlib.utils.iterables import remove_list_redundancies - from typing import TYPE_CHECKING if TYPE_CHECKING: - from manimlib.mobject.types.vectorized_mobject import VMobject - ManimColor = Union[str, colour.Color, Sequence[float]] + from colour import Color + from typing import Iterable, Sequence, TypeVar, Union + + ManimColor = Union[str, Color] Span = tuple[int, int] + Selector = Union[ + str, + re.Pattern, + tuple[Union[int, None], Union[int, None]], + Iterable[Union[ + str, + re.Pattern, + tuple[Union[int, None], Union[int, None]] + ]] + ] + T = TypeVar("T") -class _StringSVG(SVGMobject): +class LabelledString(SVGMobject, ABC): + """ + An abstract base class for `MTex` and `MarkupText` + """ CONFIG = { "height": None, "stroke_width": 0, @@ -33,42 +46,30 @@ class _StringSVG(SVGMobject): "should_subdivide_sharp_curves": True, "should_remove_null_curves": True, }, - } - - -class LabelledString(_StringSVG, ABC): - """ - An abstract base class for `MTex` and `MarkupText` - """ - CONFIG = { "base_color": WHITE, - "use_plain_file": False, "isolate": [], } def __init__(self, string: str, **kwargs): self.string = string digest_config(self, kwargs) + if self.base_color is None: + self.base_color = WHITE + self.base_color_int = self.color_to_int(self.base_color) - # Convert `base_color` to hex code. - self.base_color = rgb_to_hex(color_to_rgb( - self.base_color \ - or self.svg_default.get("color", None) \ - or self.svg_default.get("fill_color", None) \ - or WHITE - )) - self.svg_default["fill_color"] = BLACK - - self.pre_parse() + self.full_span = (0, len(self.string)) self.parse() - super().__init__() - self.post_parse() + super().__init__(**kwargs) + self.labelled_submobject_items = [ + (submob.label, submob) + for submob in self.submobjects + ] def get_file_path(self) -> str: - return self.get_file_path_(use_plain_file=False) + return self.get_file_path_(is_labelled=False) - def get_file_path_(self, use_plain_file: bool) -> str: - content = self.get_content(use_plain_file) + def get_file_path_(self, is_labelled: bool) -> str: + content = self.get_content(is_labelled) return self.get_file_path_by_content(content) @abstractmethod @@ -78,91 +79,135 @@ class LabelledString(_StringSVG, ABC): def generate_mobject(self) -> None: super().generate_mobject() - submob_labels = [ - self.color_to_label(submob.get_fill_color()) - for submob in self.submobjects - ] - if self.use_plain_file or self.has_predefined_local_colors: - file_path = self.get_file_path_(use_plain_file=True) - plain_svg = _StringSVG( - file_path, - svg_default=self.svg_default, - path_string_config=self.path_string_config - ) - self.set_submobjects(plain_svg.submobjects) + num_labels = len(self.label_span_list) + if num_labels: + file_path = self.get_file_path_(is_labelled=True) + labelled_svg = SVGMobject(file_path) + submob_color_ints = [ + self.color_to_int(submob.get_fill_color()) + for submob in labelled_svg.submobjects + ] else: - self.set_fill(self.base_color) - for submob, label in zip(self.submobjects, submob_labels): - submob.label = label + submob_color_ints = [0] * len(self.submobjects) - def pre_parse(self) -> None: - self.string_len = len(self.string) - self.full_span = (0, self.string_len) + if len(self.submobjects) != len(submob_color_ints): + raise ValueError( + "Cannot align submobjects of the labelled svg " + "to the original svg" + ) + + unrecognized_color_ints = self.remove_redundancies(sorted(filter( + lambda color_int: color_int > num_labels, + submob_color_ints + ))) + if unrecognized_color_ints: + raise ValueError( + "Unrecognized color label(s) detected: " + f"{','.join(map(self.int_to_hex, unrecognized_color_ints))}" + ) + + for submob, color_int in zip(self.submobjects, submob_color_ints): + submob.label = color_int - 1 def parse(self) -> None: self.command_repl_items = self.get_command_repl_items() - self.command_spans = self.get_command_spans() - self.extra_entity_spans = self.get_extra_entity_spans() - self.entity_spans = self.get_entity_spans() - self.extra_ignored_spans = self.get_extra_ignored_spans() - self.skipped_spans = self.get_skipped_spans() - self.internal_specified_spans = self.get_internal_specified_spans() - self.external_specified_spans = self.get_external_specified_spans() self.specified_spans = self.get_specified_spans() - self.label_span_list = self.get_label_span_list() self.check_overlapping() - - def post_parse(self) -> None: - self.labelled_submobject_items = [ - (submob.label, submob) - for submob in self.submobjects - ] - self.labelled_submobjects = self.get_labelled_submobjects() - self.specified_substrs = self.get_specified_substrs() - self.group_items = self.get_group_items() - self.group_substrs = self.get_group_substrs() - self.submob_groups = self.get_submob_groups() + self.label_span_list = self.get_label_span_list() + if len(self.label_span_list) >= 16777216: + raise ValueError("Cannot handle that many substrings") # Toolkits def get_substr(self, span: Span) -> str: return self.string[slice(*span)] - def finditer( - self, pattern: str, flags: int = 0, **kwargs - ) -> Iterable[re.Match]: - return re.compile(pattern, flags).finditer(self.string, **kwargs) + def match(self, pattern: str | re.Pattern, **kwargs) -> re.Pattern | None: + if isinstance(pattern, str): + pattern = re.compile(pattern) + return re.compile(pattern).match(self.string, **kwargs) - def search( - self, pattern: str, flags: int = 0, **kwargs - ) -> re.Match | None: - return re.compile(pattern, flags).search(self.string, **kwargs) - - def match( - self, pattern: str, flags: int = 0, **kwargs - ) -> re.Match | None: - return re.compile(pattern, flags).match(self.string, **kwargs) - - def find_spans(self, pattern: str, **kwargs) -> list[Span]: + def find_spans(self, pattern: str | re.Pattern, **kwargs) -> list[Span]: + if isinstance(pattern, str): + pattern = re.compile(pattern) return [ match_obj.span() - for match_obj in self.finditer(pattern, **kwargs) + for match_obj in pattern.finditer(self.string, **kwargs) ] - def find_substr(self, substr: str, **kwargs) -> list[Span]: - if not substr: - return [] - return self.find_spans(re.escape(substr), **kwargs) - - def find_substrs(self, substrs: list[str], **kwargs) -> list[Span]: - return list(it.chain(*[ - self.find_substr(substr, **kwargs) - for substr in remove_list_redundancies(substrs) - ])) + def find_indices(self, pattern: str | re.Pattern, **kwargs) -> list[int]: + return [index for index, _ in self.find_spans(pattern, **kwargs)] @staticmethod - def get_neighbouring_pairs(iterable: list) -> list[tuple]: - return list(zip(iterable[:-1], iterable[1:])) + def is_single_selector(selector: Selector) -> bool: + if isinstance(selector, str): + return True + if isinstance(selector, re.Pattern): + return True + if isinstance(selector, tuple): + if len(selector) == 2 and all([ + isinstance(index, int) or index is None + for index in selector + ]): + return True + return False + + def find_spans_by_selector(self, selector: Selector) -> list[Span]: + if self.is_single_selector(selector): + selector = (selector,) + result = [] + for sel in selector: + if not self.is_single_selector(sel): + raise TypeError(f"Invalid selector: '{sel}'") + if isinstance(sel, str): + spans = self.find_spans(re.escape(sel)) + elif isinstance(sel, re.Pattern): + spans = self.find_spans(sel) + else: + string_len = self.full_span[1] + span = tuple([ + ( + min(index, string_len) + if index >= 0 + else max(index + string_len, 0) + ) + if index is not None else default_index + for index, default_index in zip(sel, self.full_span) + ]) + spans = [span] + result.extend(spans) + return sorted(filter( + lambda span: span[0] < span[1], + self.remove_redundancies(result) + )) + + @staticmethod + def chain(*iterables: Iterable[T]) -> list[T]: + return list(it.chain(*iterables)) + + @staticmethod + def remove_redundancies(vals: Sequence[T]) -> list[T]: + return remove_list_redundancies(vals) + + @staticmethod + def get_neighbouring_pairs(vals: Sequence[T]) -> list[tuple[T, T]]: + return list(zip(vals[:-1], vals[1:])) + + @staticmethod + def compress_neighbours(vals: Sequence[T]) -> list[tuple[T, Span]]: + if not vals: + return [] + + unique_vals = [vals[0]] + indices = [0] + for index, val in enumerate(vals): + if val == unique_vals[-1]: + continue + unique_vals.append(val) + indices.append(index) + indices.append(len(vals)) + spans = LabelledString.get_neighbouring_pairs(indices) + return list(zip(unique_vals, spans)) @staticmethod def span_contains(span_0: Span, span_1: Span) -> bool: @@ -182,194 +227,88 @@ class LabelledString(_StringSVG, ABC): )) @staticmethod - def compress_neighbours(vals: list[int]) -> list[tuple[int, Span]]: - if not vals: + def merge_inserted_strings_from_pairs( + inserted_string_pairs: list[tuple[Span, tuple[str, str]]] + ) -> list[tuple[int, str]]: + if not inserted_string_pairs: return [] - unique_vals = [vals[0]] - indices = [0] - for index, val in enumerate(vals): - if val == unique_vals[-1]: - continue - unique_vals.append(val) - indices.append(index) - indices.append(len(vals)) - spans = LabelledString.get_neighbouring_pairs(indices) - return list(zip(unique_vals, spans)) - - @staticmethod - def find_region_index(seq: list[int], val: int) -> int: - # Returns an integer in `range(-1, len(seq))` satisfying - # `seq[result] <= val < seq[result + 1]`. - # `seq` should be sorted in ascending order. - if not seq or val < seq[0]: - return -1 - result = len(seq) - 1 - while val < seq[result]: - result -= 1 - return result - - @staticmethod - def take_nearest_value(seq: list[int], val: int, index_shift: int) -> int: - sorted_seq = sorted(seq) - index = LabelledString.find_region_index(sorted_seq, val) - return sorted_seq[index + index_shift] - - @staticmethod - def generate_span_repl_dict( - inserted_string_pairs: list[tuple[Span, tuple[str, str]]], - other_repl_items: list[tuple[Span, str]] - ) -> dict[Span, str]: - result = dict(other_repl_items) - if not inserted_string_pairs: - return result - - indices, _, _, inserted_strings = zip(*sorted([ - ( - span[flag], - -flag, - -span[1 - flag], - str_pair[flag] + spans = [ + span for span, _ in inserted_string_pairs + ] + sorted_index_flag_pairs = sorted( + it.product(range(len(spans)), range(2)), + key=lambda t: ( + spans[t[0]][t[1]], + np.sign(spans[t[0]][1 - t[1]] - spans[t[0]][t[1]]), + -spans[t[0]][1 - t[1]], + t[1], + (1, -1)[t[1]] * t[0] ) - for span, str_pair in inserted_string_pairs - for flag in range(2) - ])) - result.update({ - (index, index): "".join(inserted_strings[slice(*item_span)]) + ) + indices, inserted_strings = zip(*[ + list(zip(*inserted_string_pairs[item_index]))[flag] + for item_index, flag in sorted_index_flag_pairs + ]) + return [ + (index, "".join(inserted_strings[slice(*item_span)])) for index, item_span in LabelledString.compress_neighbours(indices) - }) - return result + ] def get_replaced_substr( - self, span: Span, span_repl_dict: dict[Span, str] + self, span: Span, repl_items: list[tuple[Span, str]] ) -> str: - repl_spans = sorted(filter( - lambda repl_span: self.span_contains(span, repl_span), - span_repl_dict.keys() - )) - if not all( - span_0[1] <= span_1[0] - for span_0, span_1 in self.get_neighbouring_pairs(repl_spans) - ): - raise ValueError("Overlapping replacement") + if not repl_items: + return self.get_substr(span) + sorted_repl_items = sorted(repl_items, key=lambda t: t[0]) + repl_spans, repl_strs = zip(*sorted_repl_items) pieces = [ self.get_substr(piece_span) for piece_span in self.get_complement_spans(repl_spans, span) ] - repl_strs = [span_repl_dict[repl_span] for repl_span in repl_spans] - repl_strs.append("") - return "".join(it.chain(*zip(pieces, repl_strs))) + repl_strs = [*repl_strs, ""] + return "".join(self.chain(*zip(pieces, repl_strs))) + + def get_replaced_string( + self, + inserted_string_pairs: list[tuple[Span, tuple[str, str]]], + repl_items: list[tuple[Span, str]] + ) -> str: + all_repl_items = self.chain( + repl_items, + [ + ((index, index), inserted_string) + for index, inserted_string + in self.merge_inserted_strings_from_pairs( + inserted_string_pairs + ) + ] + ) + return self.get_replaced_substr(self.full_span, all_repl_items) @staticmethod - def rslide(index: int, skipped: list[Span]) -> int: - transfer_dict = dict(sorted(skipped)) - while index in transfer_dict.keys(): - index = transfer_dict[index] - return index - - @staticmethod - def lslide(index: int, skipped: list[Span]) -> int: - transfer_dict = dict(sorted([ - skipped_span[::-1] for skipped_span in skipped - ], reverse=True)) - while index in transfer_dict.keys(): - index = transfer_dict[index] - return index - - @staticmethod - def rgb_to_int(rgb_tuple: tuple[int, int, int]) -> int: - r, g, b = rgb_tuple - rg = r * 256 + g - return rg * 256 + b - - @staticmethod - def int_to_rgb(rgb_int: int) -> tuple[int, int, int]: - rg, b = divmod(rgb_int, 256) - r, g = divmod(rg, 256) - return r, g, b + def color_to_int(color: ManimColor) -> int: + hex_code = rgb_to_hex(color_to_rgb(color)) + return int(hex_code[1:], 16) @staticmethod def int_to_hex(rgb_int: int) -> str: return "#{:06x}".format(rgb_int).upper() - @staticmethod - def hex_to_int(rgb_hex: str) -> int: - return int(rgb_hex[1:], 16) - - @staticmethod - def color_to_label(color: ManimColor) -> int: - rgb_tuple = color_to_int_rgb(color) - rgb = LabelledString.rgb_to_int(rgb_tuple) - return rgb - 1 - # Parsing @abstractmethod def get_command_repl_items(self) -> list[tuple[Span, str]]: return [] - def get_command_spans(self) -> list[Span]: - return [cmd_span for cmd_span, _ in self.command_repl_items] - @abstractmethod - def get_extra_entity_spans(self) -> list[Span]: - return [] - - def get_entity_spans(self) -> list[Span]: - return list(it.chain( - self.command_spans, - self.extra_entity_spans - )) - - @abstractmethod - def get_extra_ignored_spans(self) -> list[int]: - return [] - - def get_skipped_spans(self) -> list[Span]: - return list(it.chain( - self.find_spans(r"\s"), - self.command_spans, - self.extra_ignored_spans - )) - - def shrink_span(self, span: Span) -> Span: - return ( - self.rslide(span[0], self.skipped_spans), - self.lslide(span[1], self.skipped_spans) - ) - - @abstractmethod - def get_internal_specified_spans(self) -> list[Span]: - return [] - - @abstractmethod - def get_external_specified_spans(self) -> list[Span]: - return [] - def get_specified_spans(self) -> list[Span]: - spans = list(it.chain( - self.internal_specified_spans, - self.external_specified_spans, - self.find_substrs(self.isolate) - )) - shrinked_spans = list(filter( - lambda span: span[0] < span[1] and not any([ - entity_span[0] < index < entity_span[1] - for index in span - for entity_span in self.entity_spans - ]), - [self.shrink_span(span) for span in spans] - )) - return remove_list_redundancies(shrinked_spans) - - @abstractmethod - def get_label_span_list(self) -> list[Span]: return [] def check_overlapping(self) -> None: - for span_0, span_1 in it.product(self.label_span_list, repeat=2): + for span_0, span_1 in it.product(self.specified_spans, repeat=2): if not span_0[0] < span_1[0] < span_0[1] < span_1[1]: continue raise ValueError( @@ -378,29 +317,20 @@ class LabelledString(_StringSVG, ABC): ) @abstractmethod - def get_content(self, use_plain_file: bool) -> str: - return "" + def get_label_span_list(self) -> list[Span]: + return [] @abstractmethod - def has_predefined_local_colors(self) -> bool: - return False + def get_content(self, is_labelled: bool) -> str: + return "" - # Post-parsing - - def get_labelled_submobjects(self) -> list[VMobject]: - return [submob for _, submob in self.labelled_submobject_items] + # Selector + @abstractmethod def get_cleaned_substr(self, span: Span) -> str: - span_repl_dict = dict.fromkeys(self.command_spans, "") - return self.get_replaced_substr(span, span_repl_dict) + return "" - def get_specified_substrs(self) -> list[str]: - return remove_list_redundancies([ - self.get_cleaned_substr(span) - for span in self.specified_spans - ]) - - def get_group_items(self) -> list[tuple[str, VGroup]]: + def get_group_part_items(self) -> list[tuple[str, VGroup]]: if not self.labelled_submobject_items: return [] @@ -425,118 +355,56 @@ class LabelledString(_StringSVG, ABC): ordered_spans ) ] - shrinked_spans = [ - self.shrink_span(span) + group_substrs = [ + self.get_cleaned_substr(span) if span[0] < span[1] else "" for span in self.get_complement_spans( interval_spans, (ordered_spans[0][0], ordered_spans[-1][1]) ) ] - group_substrs = [ - self.get_cleaned_substr(span) if span[0] < span[1] else "" - for span in shrinked_spans - ] submob_groups = VGroup(*[ VGroup(*labelled_submobjects[slice(*submob_span)]) for submob_span in labelled_submob_spans ]) return list(zip(group_substrs, submob_groups)) - def get_group_substrs(self) -> list[str]: - return [group_substr for group_substr, _ in self.group_items] - - def get_submob_groups(self) -> list[VGroup]: - return [submob_group for _, submob_group in self.group_items] - - def get_parts_by_group_substr(self, substr: str) -> VGroup: - return VGroup(*[ - group - for group_substr, group in self.group_items - if group_substr == substr - ]) - - # Selector - - def find_span_components( - self, custom_span: Span, substring: bool = True - ) -> list[Span]: - shrinked_span = self.shrink_span(custom_span) - if shrinked_span[0] >= shrinked_span[1]: - return [] - - if substring: - indices = remove_list_redundancies(list(it.chain( - self.full_span, - *self.label_span_list - ))) - span_begin = self.take_nearest_value( - indices, shrinked_span[0], 0 + def get_specified_part_items(self) -> list[tuple[str, VGroup]]: + return [ + ( + self.get_substr(span), + self.select_part_by_span(span) ) - span_end = self.take_nearest_value( - indices, shrinked_span[1] - 1, 1 - ) - else: - span_begin, span_end = shrinked_span + for span in self.specified_spans + ] - span_choices = sorted(filter( - lambda span: self.span_contains((span_begin, span_end), span), - self.label_span_list - )) - # Choose spans that reach the farthest. - span_choices_dict = dict(span_choices) - - result = [] - while span_begin < span_end: - if span_begin not in span_choices_dict.keys(): - span_begin += 1 - continue - next_begin = span_choices_dict[span_begin] - result.append((span_begin, next_begin)) - span_begin = next_begin - return result - - def get_part_by_custom_span(self, custom_span: Span, **kwargs) -> VGroup: + def select_part_by_span(self, custom_span: Span) -> VGroup: labels = [ label for label, span in enumerate(self.label_span_list) - if any([ - self.span_contains(span_component, span) - for span_component in self.find_span_components( - custom_span, **kwargs - ) - ]) + if self.span_contains(custom_span, span) ] return VGroup(*[ submob for label, submob in self.labelled_submobject_items if label in labels ]) - def get_parts_by_string( - self, substr: str, - case_sensitive: bool = True, regex: bool = False, **kwargs - ) -> VGroup: - flags = 0 - if not case_sensitive: - flags |= re.I - pattern = substr if regex else re.escape(substr) - return VGroup(*[ - self.get_part_by_custom_span(span, **kwargs) - for span in self.find_spans(pattern, flags=flags) - if span[0] < span[1] - ]) + def select_parts(self, selector: Selector) -> VGroup: + return VGroup(*filter( + lambda part: part.submobjects, + [ + self.select_part_by_span(span) + for span in self.find_spans_by_selector(selector) + ] + )) - def get_part_by_string( - self, substr: str, index: int = 0, **kwargs - ) -> VMobject: - return self.get_parts_by_string(substr, **kwargs)[index] + def select_part(self, selector: Selector, index: int = 0) -> VGroup: + return self.select_parts(selector)[index] - def set_color_by_string(self, substr: str, color: ManimColor, **kwargs): - self.get_parts_by_string(substr, **kwargs).set_color(color) + def set_parts_color(self, selector: Selector, color: ManimColor): + self.select_parts(selector).set_color(color) return self - def set_color_by_string_to_color_map( - self, string_to_color_map: dict[str, ManimColor], **kwargs - ): - for substr, color in string_to_color_map.items(): - self.set_color_by_string(substr, color, **kwargs) + def set_parts_color_by_dict(self, color_map: dict[Selector, ManimColor]): + for selector, color in color_map.items(): + self.set_parts_color(selector, color) return self def get_string(self) -> str: diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index fb7922e1..ed7273ee 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -1,27 +1,46 @@ from __future__ import annotations -import itertools as it -import colour -from typing import Union, Sequence +import re from manimlib.mobject.svg.labelled_string import LabelledString -from manimlib.utils.tex_file_writing import tex_to_svg_file -from manimlib.utils.tex_file_writing import get_tex_config from manimlib.utils.tex_file_writing import display_during_execution - +from manimlib.utils.tex_file_writing import get_tex_config +from manimlib.utils.tex_file_writing import tex_to_svg_file from typing import TYPE_CHECKING if TYPE_CHECKING: - from manimlib.mobject.types.vectorized_mobject import VMobject + from colour import Color + from typing import Iterable, Union + from manimlib.mobject.types.vectorized_mobject import VGroup - ManimColor = Union[str, colour.Color, Sequence[float]] + + ManimColor = Union[str, Color] Span = tuple[int, int] + Selector = Union[ + str, + re.Pattern, + tuple[Union[int, None], Union[int, None]], + Iterable[Union[ + str, + re.Pattern, + tuple[Union[int, None], Union[int, None]] + ]] + ] SCALE_FACTOR_PER_FONT_POINT = 0.001 +TEX_COLOR_COMMANDS_DICT = { + "\\color": (1, False), + "\\textcolor": (1, False), + "\\pagecolor": (1, True), + "\\colorbox": (1, True), + "\\fcolorbox": (2, True), +} + + class MTex(LabelledString): CONFIG = { "font_size": 48, @@ -32,7 +51,7 @@ class MTex(LabelledString): def __init__(self, tex_string: str, **kwargs): # Prevent from passing an empty string. - if not tex_string: + if not tex_string.strip(): tex_string = "\\\\" self.tex_string = tex_string super().__init__(tex_string, **kwargs) @@ -47,7 +66,6 @@ class MTex(LabelledString): self.svg_default, self.path_string_config, self.base_color, - self.use_plain_file, self.isolate, self.tex_string, self.alignment, @@ -61,85 +79,95 @@ class MTex(LabelledString): tex_config["text_to_replace"], content ) - with display_during_execution(f"Writing \"{self.tex_string}\""): + with display_during_execution(f"Writing \"{self.string}\""): file_path = tex_to_svg_file(full_tex) return file_path - def pre_parse(self) -> None: - super().pre_parse() + def parse(self) -> None: self.backslash_indices = self.get_backslash_indices() - self.brace_index_pairs = self.get_brace_index_pairs() - self.script_char_spans = self.get_script_char_spans() + self.command_spans = self.get_command_spans() + self.brace_spans = self.get_brace_spans() + self.script_char_indices = self.get_script_char_indices() self.script_content_spans = self.get_script_content_spans() self.script_spans = self.get_script_spans() + super().parse() # Toolkits @staticmethod def get_color_command_str(rgb_int: int) -> str: - rgb_tuple = MTex.int_to_rgb(rgb_int) - return "".join([ - "\\color[RGB]", - "{", - ",".join(map(str, rgb_tuple)), - "}" - ]) + rg, b = divmod(rgb_int, 256) + r, g = divmod(rg, 256) + return f"\\color[RGB]{{{r}, {g}, {b}}}" - # Pre-parsing + @staticmethod + def shrink_span(span: Span, skippable_indices: list[int]) -> Span: + span_begin, span_end = span + while span_begin in skippable_indices: + span_begin += 1 + while span_end - 1 in skippable_indices: + span_end -= 1 + return (span_begin, span_end) + + # Parsing def get_backslash_indices(self) -> list[int]: # The latter of `\\` doesn't count. - return list(it.chain(*[ - range(span[0], span[1], 2) - for span in self.find_spans(r"\\+") - ])) + return self.find_indices(r"\\.") - def get_unescaped_char_spans(self, chars: str): - return sorted(filter( - lambda span: span[0] - 1 not in self.backslash_indices, - self.find_substrs(list(chars)) + def get_command_spans(self) -> list[Span]: + return [ + self.match(r"\\(?:[a-zA-Z]+|.)", pos=index).span() + for index in self.backslash_indices + ] + + def get_unescaped_char_indices(self, char: str) -> list[int]: + return list(filter( + lambda index: index - 1 not in self.backslash_indices, + self.find_indices(re.escape(char)) )) - def get_brace_index_pairs(self) -> list[Span]: - left_brace_indices = [] - right_brace_indices = [] - left_brace_indices_stack = [] - for span in self.get_unescaped_char_spans("{}"): - index = span[0] - if self.get_substr(span) == "{": - left_brace_indices_stack.append(index) + def get_brace_spans(self) -> list[Span]: + span_begins = [] + span_ends = [] + span_begins_stack = [] + char_items = sorted([ + (index, char) + for char in "{}" + for index in self.get_unescaped_char_indices(char) + ]) + for index, char in char_items: + if char == "{": + span_begins_stack.append(index) else: - if not left_brace_indices_stack: + if not span_begins_stack: raise ValueError("Missing '{' inserted") - left_brace_index = left_brace_indices_stack.pop() - left_brace_indices.append(left_brace_index) - right_brace_indices.append(index) - if left_brace_indices_stack: + span_begins.append(span_begins_stack.pop()) + span_ends.append(index + 1) + if span_begins_stack: raise ValueError("Missing '}' inserted") - return list(zip(left_brace_indices, right_brace_indices)) + return list(zip(span_begins, span_ends)) - def get_script_char_spans(self) -> list[int]: - return self.get_unescaped_char_spans("_^") + def get_script_char_indices(self) -> list[int]: + return self.chain(*[ + self.get_unescaped_char_indices(char) + for char in "_^" + ]) def get_script_content_spans(self) -> list[Span]: result = [] - brace_indices_dict = dict(self.brace_index_pairs) - script_pattern = r"[a-zA-Z0-9]|\\[a-zA-Z]+" - for script_char_span in self.script_char_spans: - span_begin = self.match(r"\s*", pos=script_char_span[1]).end() - if span_begin in brace_indices_dict.keys(): - span_end = brace_indices_dict[span_begin] + 1 + script_entity_dict = dict(self.chain( + self.brace_spans, + self.command_spans + )) + for index in self.script_char_indices: + span_begin = self.match(r"\s*", pos=index + 1).end() + if span_begin in script_entity_dict.keys(): + span_end = script_entity_dict[span_begin] else: - match_obj = self.match(script_pattern, pos=span_begin) - if not match_obj: - script_name = { - "_": "subscript", - "^": "superscript" - }[script_char] - raise ValueError( - f"Unclear {script_name} detected while parsing. " - "Please use braces to clarify" - ) + match_obj = self.match(r".", pos=span_begin) + if match_obj is None: + continue span_end = match_obj.end() result.append((span_begin, span_end)) return result @@ -147,110 +175,102 @@ class MTex(LabelledString): def get_script_spans(self) -> list[Span]: return [ ( - self.search(r"\s*$", endpos=script_char_span[0]).start(), + self.match(r"[\s\S]*?(\s*)$", endpos=index).start(1), script_content_span[1] ) - for script_char_span, script_content_span in zip( - self.script_char_spans, self.script_content_spans + for index, script_content_span in zip( + self.script_char_indices, self.script_content_spans ) ] - # Parsing - def get_command_repl_items(self) -> list[tuple[Span, str]]: - color_related_command_dict = { - "color": (1, False), - "textcolor": (1, False), - "pagecolor": (1, True), - "colorbox": (1, True), - "fcolorbox": (2, True), - } result = [] - backslash_indices = self.backslash_indices - right_brace_indices = [ - right_index - for left_index, right_index in self.brace_index_pairs - ] - pattern = "".join([ - r"\\", - "(", - "|".join(color_related_command_dict.keys()), - ")", - r"(?![a-zA-Z])" - ]) - for match_obj in self.finditer(pattern): - span_begin, cmd_end = match_obj.span() - if span_begin not in backslash_indices: + brace_spans_dict = dict(self.brace_spans) + brace_begins = list(brace_spans_dict.keys()) + for cmd_span in self.command_spans: + cmd_name = self.get_substr(cmd_span) + if cmd_name not in TEX_COLOR_COMMANDS_DICT.keys(): continue - cmd_name = match_obj.group(1) - n_braces, substitute_cmd = color_related_command_dict[cmd_name] - span_end = self.take_nearest_value( - right_brace_indices, cmd_end, n_braces - ) + 1 + n_braces, substitute_cmd = TEX_COLOR_COMMANDS_DICT[cmd_name] + span_begin, span_end = cmd_span + for _ in range(n_braces): + span_end = brace_spans_dict[min(filter( + lambda index: index >= span_end, + brace_begins + ))] if substitute_cmd: - repl_str = "\\" + cmd_name + n_braces * "{black}" + repl_str = cmd_name + n_braces * "{black}" else: repl_str = "" result.append(((span_begin, span_end), repl_str)) return result - def get_extra_entity_spans(self) -> list[Span]: - return [ - self.match(r"\\([a-zA-Z]+|.)", pos=index).span() - for index in self.backslash_indices + def get_specified_spans(self) -> list[Span]: + # Match paired double braces (`{{...}}`). + sorted_brace_spans = sorted(self.brace_spans, key=lambda t: t[1]) + inner_brace_spans = [ + sorted_brace_spans[span_span[0]] + for _, span_span in self.compress_neighbours([ + (brace_span[0] + index, brace_span[1] - index) + for index, brace_span in enumerate(sorted_brace_spans) + ]) + if span_span[1] - span_span[0] >= 2 + ] + inner_brace_content_spans = [ + (span[0] + 1, span[1] - 1) + for span in inner_brace_spans + if span[1] - span[0] > 2 ] - def get_extra_ignored_spans(self) -> list[int]: - return self.script_char_spans.copy() - - def get_internal_specified_spans(self) -> list[Span]: - # Match paired double braces (`{{...}}`). - result = [] - reversed_brace_indices_dict = dict([ - pair[::-1] for pair in self.brace_index_pairs - ]) - skip = False - for prev_right_index, right_index in self.get_neighbouring_pairs( - list(reversed_brace_indices_dict.keys()) - ): - if skip: - skip = False - continue - if right_index != prev_right_index + 1: - continue - left_index = reversed_brace_indices_dict[right_index] - prev_left_index = reversed_brace_indices_dict[prev_right_index] - if left_index != prev_left_index - 1: - continue - result.append((left_index, right_index + 1)) - skip = True - return result - - def get_external_specified_spans(self) -> list[Span]: - return self.find_substrs(list(self.tex_to_color_map.keys())) + result = self.remove_redundancies(self.chain( + inner_brace_content_spans, + *[ + self.find_spans_by_selector(selector) + for selector in self.tex_to_color_map.keys() + ], + self.find_spans_by_selector(self.isolate) + )) + return list(filter( + lambda span: not any([ + entity_begin < index < entity_end + for index in span + for entity_begin, entity_end in self.command_spans + ]), + result + )) def get_label_span_list(self) -> list[Span]: + reversed_script_spans_dict = dict([ + script_span[::-1] for script_span in self.script_spans + ]) + skippable_indices = self.chain( + self.find_indices(r"\s"), + self.script_char_indices + ) result = self.script_content_spans.copy() - for span_begin, span_end in self.specified_spans: - shrinked_end = self.lslide(span_end, self.script_spans) - if span_begin >= shrinked_end: + for span in self.specified_spans: + span_begin, span_end = self.shrink_span(span, skippable_indices) + while span_end in reversed_script_spans_dict.keys(): + span_end = reversed_script_spans_dict[span_end] + if span_begin >= span_end: continue - shrinked_span = (span_begin, shrinked_end) + shrinked_span = (span_begin, span_end) if shrinked_span in result: continue result.append(shrinked_span) return result - def get_content(self, use_plain_file: bool) -> str: - if use_plain_file: - span_repl_dict = {} - else: - extended_label_span_list = [ - span - if span in self.script_content_spans - else (span[0], self.rslide(span[1], self.script_spans)) - for span in self.label_span_list - ] + def get_content(self, is_labelled: bool) -> str: + if is_labelled: + extended_label_span_list = [] + script_spans_dict = dict(self.script_spans) + for span in self.label_span_list: + if span not in self.script_content_spans: + span_begin, span_end = span + while span_end in script_spans_dict.keys(): + span_end = script_spans_dict[span_end] + span = (span_begin, span_end) + extended_label_span_list.append(span) inserted_string_pairs = [ (span, ( "{{" + self.get_color_command_str(label + 1), @@ -258,43 +278,52 @@ class MTex(LabelledString): )) for label, span in enumerate(extended_label_span_list) ] - span_repl_dict = self.generate_span_repl_dict( - inserted_string_pairs, - self.command_repl_items + result = self.get_replaced_string( + inserted_string_pairs, self.command_repl_items ) - result = self.get_replaced_substr(self.full_span, span_repl_dict) + else: + result = self.string if self.tex_environment: - result = "\n".join([ - f"\\begin{{{self.tex_environment}}}", - result, - f"\\end{{{self.tex_environment}}}" - ]) + if isinstance(self.tex_environment, str): + prefix = f"\\begin{{{self.tex_environment}}}" + suffix = f"\\end{{{self.tex_environment}}}" + else: + prefix, suffix = self.tex_environment + result = "\n".join([prefix, result, suffix]) if self.alignment: result = "\n".join([self.alignment, result]) - if use_plain_file: + if not is_labelled: result = "\n".join([ - self.get_color_command_str(self.hex_to_int(self.base_color)), + self.get_color_command_str(self.base_color_int), result ]) return result - @property - def has_predefined_local_colors(self) -> bool: - return bool(self.command_repl_items) - - # Post-parsing + # Selector def get_cleaned_substr(self, span: Span) -> str: - substr = super().get_cleaned_substr(span) - if not self.brace_index_pairs: - return substr + if not self.brace_spans: + brace_begins, brace_ends = [], [] + else: + brace_begins, brace_ends = zip(*self.brace_spans) + left_brace_indices = list(brace_begins) + right_brace_indices = [index - 1 for index in brace_ends] + skippable_indices = self.chain( + self.find_indices(r"\s"), + self.script_char_indices, + left_brace_indices, + right_brace_indices + ) + shrinked_span = self.shrink_span(span, skippable_indices) + + if shrinked_span[0] >= shrinked_span[1]: + return "" # Balance braces. - left_brace_indices, right_brace_indices = zip(*self.brace_index_pairs) unclosed_left_braces = 0 unclosed_right_braces = 0 - for index in range(*span): + for index in range(*shrinked_span): if index in left_brace_indices: unclosed_left_braces += 1 elif index in right_brace_indices: @@ -304,27 +333,25 @@ class MTex(LabelledString): unclosed_left_braces -= 1 return "".join([ unclosed_right_braces * "{", - substr, + self.get_substr(shrinked_span), unclosed_left_braces * "}" ]) # Method alias - def get_parts_by_tex(self, tex: str, **kwargs) -> VGroup: - return self.get_parts_by_string(tex, **kwargs) + def get_parts_by_tex(self, selector: Selector) -> VGroup: + return self.select_parts(selector) - def get_part_by_tex(self, tex: str, **kwargs) -> VMobject: - return self.get_part_by_string(tex, **kwargs) + def get_part_by_tex(self, selector: Selector) -> VGroup: + return self.select_part(selector) - def set_color_by_tex(self, tex: str, color: ManimColor, **kwargs): - return self.set_color_by_string(tex, color, **kwargs) + def set_color_by_tex(self, selector: Selector, color: ManimColor): + return self.set_parts_color(selector, color) def set_color_by_tex_to_color_map( - self, tex_to_color_map: dict[str, ManimColor], **kwargs + self, color_map: dict[Selector, ManimColor] ): - return self.set_color_by_string_to_color_map( - tex_to_color_map, **kwargs - ) + return self.set_parts_color_by_dict(color_map) def get_tex(self) -> str: return self.get_string() diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index c3c3be19..740eeee6 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -1,93 +1,63 @@ from __future__ import annotations -import os -import re -import itertools as it -from pathlib import Path from contextlib import contextmanager -import typing -from typing import Iterable, Sequence, Union +import os +from pathlib import Path +import re +import manimpango import pygments import pygments.formatters import pygments.lexers -from manimpango import MarkupUtils - +from manimlib.constants import DEFAULT_PIXEL_WIDTH, FRAME_WIDTH +from manimlib.constants import NORMAL from manimlib.logger import log -from manimlib.constants import * from manimlib.mobject.svg.labelled_string import LabelledString -from manimlib.utils.customization import get_customization -from manimlib.utils.tex_file_writing import tex_hash from manimlib.utils.config_ops import digest_config +from manimlib.utils.customization import get_customization from manimlib.utils.directories import get_downloads_dir from manimlib.utils.directories import get_text_dir -from manimlib.utils.iterables import remove_list_redundancies - +from manimlib.utils.tex_file_writing import tex_hash from typing import TYPE_CHECKING if TYPE_CHECKING: - from manimlib.mobject.types.vectorized_mobject import VMobject + from colour import Color + from typing import Iterable, Union + from manimlib.mobject.types.vectorized_mobject import VGroup - ManimColor = Union[str, colour.Color, Sequence[float]] + + ManimColor = Union[str, Color] Span = tuple[int, int] + Selector = Union[ + str, + re.Pattern, + tuple[Union[int, None], Union[int, None]], + Iterable[Union[ + str, + re.Pattern, + tuple[Union[int, None], Union[int, None]] + ]] + ] TEXT_MOB_SCALE_FACTOR = 0.0076 DEFAULT_LINE_SPACING_SCALE = 0.6 +# Ensure the canvas is large enough to hold all glyphs. +DEFAULT_CANVAS_WIDTH = 16384 +DEFAULT_CANVAS_HEIGHT = 16384 # See https://docs.gtk.org/Pango/pango_markup.html -# A tag containing two aliases will cause warning, -# so only use the first key of each group of aliases. -SPAN_ATTR_KEY_ALIAS_LIST = ( - ("font", "font_desc"), - ("font_family", "face"), - ("font_size", "size"), - ("font_style", "style"), - ("font_weight", "weight"), - ("font_variant", "variant"), - ("font_stretch", "stretch"), - ("font_features",), - ("foreground", "fgcolor", "color"), - ("background", "bgcolor"), - ("alpha", "fgalpha"), - ("background_alpha", "bgalpha"), - ("underline",), - ("underline_color",), - ("overline",), - ("overline_color",), - ("rise",), - ("baseline_shift",), - ("font_scale",), - ("strikethrough",), - ("strikethrough_color",), - ("fallback",), - ("lang",), - ("letter_spacing",), - ("gravity",), - ("gravity_hint",), - ("show",), - ("insert_hyphens",), - ("allow_breaks",), - ("line_height",), - ("text_transform",), - ("segment",), +MARKUP_COLOR_KEYS = ( + "foreground", "fgcolor", "color", + "background", "bgcolor", + "underline_color", + "overline_color", + "strikethrough_color" ) -COLOR_RELATED_KEYS = ( - "foreground", - "background", - "underline_color", - "overline_color", - "strikethrough_color" -) -SPAN_ATTR_KEY_CONVERSION = { - key: key_alias_list[0] - for key_alias_list in SPAN_ATTR_KEY_ALIAS_LIST - for key in key_alias_list -} -TAG_TO_ATTR_DICT = { +MARKUP_TAG_CONVERSION_DICT = { "b": {"font_weight": "bold"}, "big": {"font_size": "larger"}, "i": {"font_style": "italic"}, @@ -96,7 +66,7 @@ TAG_TO_ATTR_DICT = { "sup": {"baseline_shift": "superscript", "font_scale": "superscript"}, "small": {"font_size": "smaller"}, "tt": {"font_family": "monospace"}, - "u": {"underline": "single"}, + "u": {"underline": "single"} } @@ -120,7 +90,7 @@ class MarkupText(LabelledString): "justify": False, "indent": 0, "alignment": "LEFT", - "line_width_factor": None, + "line_width": None, "font": "", "slant": NORMAL, "weight": NORMAL, @@ -141,9 +111,7 @@ class MarkupText(LabelledString): if not self.font: self.font = get_customization()["style"]["font"] if self.is_markup: - validate_error = MarkupUtils.validate(text) - if validate_error: - raise ValueError(validate_error) + self.validate_markup_string(text) self.text = text super().__init__(text, **kwargs) @@ -165,7 +133,6 @@ class MarkupText(LabelledString): self.svg_default, self.path_string_config, self.base_color, - self.use_plain_file, self.isolate, self.text, self.is_markup, @@ -174,7 +141,7 @@ class MarkupText(LabelledString): self.justify, self.indent, self.alignment, - self.line_width_factor, + self.line_width, self.font, self.slant, self.weight, @@ -201,23 +168,32 @@ class MarkupText(LabelledString): kwargs[short_name] = kwargs.pop(long_name) def get_file_path_by_content(self, content: str) -> str: + hash_content = str(( + content, + self.justify, + self.indent, + self.alignment, + self.line_width + )) svg_file = os.path.join( - get_text_dir(), tex_hash(content) + ".svg" + get_text_dir(), tex_hash(hash_content) + ".svg" ) if not os.path.exists(svg_file): self.markup_to_svg(content, svg_file) return svg_file def markup_to_svg(self, markup_str: str, file_name: str) -> str: + self.validate_markup_string(markup_str) + # `manimpango` is under construction, # so the following code is intended to suit its interface alignment = _Alignment(self.alignment) - if self.line_width_factor is None: + if self.line_width is None: pango_width = -1 else: - pango_width = self.line_width_factor * DEFAULT_PIXEL_WIDTH + pango_width = self.line_width / FRAME_WIDTH * DEFAULT_PIXEL_WIDTH - return MarkupUtils.text2svg( + return manimpango.MarkupUtils.text2svg( text=markup_str, font="", # Already handled slant="NORMAL", # Already handled @@ -228,8 +204,8 @@ class MarkupText(LabelledString): file_name=file_name, START_X=0, START_Y=0, - width=DEFAULT_PIXEL_WIDTH, - height=DEFAULT_PIXEL_HEIGHT, + width=DEFAULT_CANVAS_WIDTH, + height=DEFAULT_CANVAS_HEIGHT, justify=self.justify, indent=self.indent, line_spacing=None, # Already handled @@ -237,13 +213,23 @@ class MarkupText(LabelledString): pango_width=pango_width ) - def pre_parse(self) -> None: - super().pre_parse() - self.tag_items_from_markup = self.get_tag_items_from_markup() - self.global_dict_from_config = self.get_global_dict_from_config() - self.local_dicts_from_markup = self.get_local_dicts_from_markup() - self.local_dicts_from_config = self.get_local_dicts_from_config() - self.predefined_attr_dicts = self.get_predefined_attr_dicts() + @staticmethod + def validate_markup_string(markup_str: str) -> None: + validate_error = manimpango.MarkupUtils.validate(markup_str) + if not validate_error: + return + raise ValueError( + f"Invalid markup string \"{markup_str}\"\n" + f"{validate_error}" + ) + + def parse(self) -> None: + self.global_attr_dict = self.get_global_attr_dict() + self.tag_pairs_from_markup = self.get_tag_pairs_from_markup() + self.tag_spans = self.get_tag_spans() + self.items_from_markup = self.get_items_from_markup() + self.specified_items = self.get_specified_items() + super().parse() # Toolkits @@ -254,87 +240,50 @@ class MarkupText(LabelledString): for key, val in attr_dict.items() ]) - @staticmethod - def merge_attr_dicts( - attr_dict_items: list[Span, str, typing.Any] - ) -> list[tuple[Span, dict[str, str]]]: - index_seq = [0] - attr_dict_list = [{}] - for span, attr_dict in attr_dict_items: - if span[0] >= span[1]: - continue - region_indices = [ - MarkupText.find_region_index(index_seq, index) - for index in span - ] - for flag in (1, 0): - if index_seq[region_indices[flag]] == span[flag]: - continue - region_index = region_indices[flag] - index_seq.insert(region_index + 1, span[flag]) - attr_dict_list.insert( - region_index + 1, attr_dict_list[region_index].copy() + # Parsing + + def get_global_attr_dict(self) -> dict[str, str]: + result = { + "foreground": self.int_to_hex(self.base_color_int), + "font_family": self.font, + "font_style": self.slant, + "font_weight": self.weight, + "font_size": str(self.font_size * 1024), + } + # `line_height` attribute is supported since Pango 1.50. + pango_version = manimpango.pango_version() + if tuple(map(int, pango_version.split("."))) < (1, 50): + if self.lsh is not None: + log.warning( + f"Pango version {pango_version} found (< 1.50), " + "unable to set `line_height` attribute" ) - region_indices[flag] += 1 - if flag == 0: - region_indices[1] += 1 - for key, val in attr_dict.items(): - if not key: - continue - for mid_dict in attr_dict_list[slice(*region_indices)]: - mid_dict[key] = val - return list(zip( - MarkupText.get_neighbouring_pairs(index_seq), attr_dict_list[:-1] - )) + else: + line_spacing_scale = self.lsh or DEFAULT_LINE_SPACING_SCALE + result["line_height"] = str(((line_spacing_scale) + 1) * 0.6) + return result - def find_substr_or_span( - self, substr_or_span: str | tuple[int | None, int | None] - ) -> list[Span]: - if isinstance(substr_or_span, str): - return self.find_substr(substr_or_span) - - span = tuple([ - ( - min(index, self.string_len) - if index >= 0 - else max(index + self.string_len, 0) - ) - if index is not None else default_index - for index, default_index in zip(substr_or_span, self.full_span) - ]) - if span[0] >= span[1]: - return [] - return [span] - - # Pre-parsing - - def get_tag_items_from_markup( + def get_tag_pairs_from_markup( self ) -> list[tuple[Span, Span, dict[str, str]]]: if not self.is_markup: return [] - tag_pattern = r"""<(/?)(\w+)\s*((?:\w+\s*\=\s*(['"]).*?\4\s*)*)>""" - attr_pattern = r"""(\w+)\s*\=\s*(['"])(.*?)\2""" + tag_pattern = r"""<(/?)(\w+)\s*((\w+\s*\=\s*(['"])[\s\S]*?\5\s*)*)>""" + attr_pattern = r"""(\w+)\s*\=\s*(['"])([\s\S]*?)\2""" begin_match_obj_stack = [] match_obj_pairs = [] - for match_obj in self.finditer(tag_pattern): + for match_obj in re.finditer(tag_pattern, self.string): if not match_obj.group(1): begin_match_obj_stack.append(match_obj) else: match_obj_pairs.append( (begin_match_obj_stack.pop(), match_obj) ) - if begin_match_obj_stack: - raise ValueError("Unclosed tag(s) detected") result = [] for begin_match_obj, end_match_obj in match_obj_pairs: tag_name = begin_match_obj.group(2) - if tag_name != end_match_obj.group(2): - raise ValueError("Unmatched tag names") - if end_match_obj.group(3): - raise ValueError("Attributes shan't exist in ending tags") if tag_name == "span": attr_dict = { match.group(1): match.group(3) @@ -342,189 +291,170 @@ class MarkupText(LabelledString): attr_pattern, begin_match_obj.group(3) ) } - elif tag_name in TAG_TO_ATTR_DICT.keys(): - if begin_match_obj.group(3): - raise ValueError( - f"Attributes shan't exist in tag '{tag_name}'" - ) - attr_dict = TAG_TO_ATTR_DICT[tag_name].copy() else: - raise ValueError(f"Unknown tag: '{tag_name}'") + attr_dict = MARKUP_TAG_CONVERSION_DICT.get(tag_name, {}) result.append( (begin_match_obj.span(), end_match_obj.span(), attr_dict) ) return result - def get_global_dict_from_config(self) -> dict[str, typing.Any]: - result = { - "line_height": ( - (self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1 - ) * 0.6, - "font_family": self.font, - "font_size": self.font_size * 1024, - "font_style": self.slant, - "font_weight": self.weight - } - result.update(self.global_config) - return result + def get_tag_spans(self) -> list[Span]: + return [ + tag_span + for begin_tag, end_tag, _ in self.tag_pairs_from_markup + for tag_span in (begin_tag, end_tag) + ] - def get_local_dicts_from_markup( - self - ) -> list[Span, dict[str, str]]: - return sorted([ - ((begin_tag_span[0], end_tag_span[1]), attr_dict) + def get_items_from_markup(self) -> list[Span]: + return [ + ((begin_tag_span[1], end_tag_span[0]), attr_dict) for begin_tag_span, end_tag_span, attr_dict - in self.tag_items_from_markup - ]) + in self.tag_pairs_from_markup + if begin_tag_span[1] < end_tag_span[0] + ] - def get_local_dicts_from_config( - self - ) -> list[Span, dict[str, typing.Any]]: + def get_specified_items(self) -> list[tuple[Span, dict[str, str]]]: + result = self.chain( + self.items_from_markup, + [ + (span, {key: val}) + for t2x_dict, key in ( + (self.t2c, "foreground"), + (self.t2f, "font_family"), + (self.t2s, "font_style"), + (self.t2w, "font_weight") + ) + for selector, val in t2x_dict.items() + for span in self.find_spans_by_selector(selector) + ], + [ + (span, local_config) + for selector, local_config in self.local_configs.items() + for span in self.find_spans_by_selector(selector) + ], + [ + (span, {}) + for span in self.find_spans_by_selector(self.isolate) + ] + ) + entity_spans = self.tag_spans.copy() + if self.is_markup: + entity_spans.extend(self.find_spans(r"&[\s\S]*?;")) return [ - (span, {key: val}) - for t2x_dict, key in ( - (self.t2c, "foreground"), - (self.t2f, "font_family"), - (self.t2s, "font_style"), - (self.t2w, "font_weight") - ) - for substr_or_span, val in t2x_dict.items() - for span in self.find_substr_or_span(substr_or_span) - ] + [ - (span, local_config) - for substr_or_span, local_config in self.local_configs.items() - for span in self.find_substr_or_span(substr_or_span) + (span, attr_dict) + for span, attr_dict in result + if not any([ + entity_begin < index < entity_end + for index in span + for entity_begin, entity_end in entity_spans + ]) ] - def get_predefined_attr_dicts(self) -> list[Span, dict[str, str]]: - attr_dict_items = [ - (self.full_span, self.global_dict_from_config), - *self.local_dicts_from_markup, - *self.local_dicts_from_config - ] - return [ - (span, { - SPAN_ATTR_KEY_CONVERSION[key.lower()]: str(val) - for key, val in attr_dict.items() - }) - for span, attr_dict in attr_dict_items - ] - - # Parsing - def get_command_repl_items(self) -> list[tuple[Span, str]]: result = [ - (tag_span, "") - for begin_tag, end_tag, _ in self.tag_items_from_markup - for tag_span in (begin_tag, end_tag) + (tag_span, "") for tag_span in self.tag_spans ] if not self.is_markup: - result += [ + result.extend([ (span, escaped) for char, escaped in ( ("&", "&"), (">", ">"), ("<", "<") ) - for span in self.find_substr(char) - ] + for span in self.find_spans(re.escape(char)) + ]) return result - def get_extra_entity_spans(self) -> list[Span]: - if not self.is_markup: - return [] - return self.find_spans(r"&.*?;") - - def get_extra_ignored_spans(self) -> list[int]: - return [] - - def get_internal_specified_spans(self) -> list[Span]: - return [span for span, _ in self.local_dicts_from_markup] - - def get_external_specified_spans(self) -> list[Span]: - return [span for span, _ in self.local_dicts_from_config] + def get_specified_spans(self) -> list[Span]: + return self.remove_redundancies([ + span for span, _ in self.specified_items + ]) def get_label_span_list(self) -> list[Span]: - breakup_indices = remove_list_redundancies(list(it.chain(*it.chain( - self.find_spans(r"\s+"), - self.find_spans(r"\b"), - self.specified_spans - )))) - breakup_indices = sorted(filter( - lambda index: not any([ - span[0] < index < span[1] - for span in self.entity_spans - ]), - breakup_indices - )) - return list(filter( - lambda span: self.get_substr(span).strip(), - self.get_neighbouring_pairs(breakup_indices) - )) - - def get_content(self, use_plain_file: bool) -> str: - if use_plain_file: - attr_dict_items = [ - (self.full_span, {"foreground": self.base_color}), - *self.predefined_attr_dicts, - *[ - (span, {}) - for span in self.label_span_list - ] + interval_spans = sorted(self.chain( + self.tag_spans, + [ + (index, index) + for span in self.specified_spans + for index in span ] + )) + text_spans = self.get_complement_spans(interval_spans, self.full_span) + if self.is_markup: + pattern = r"[0-9a-zA-Z]+|(?:&[\s\S]*?;|[^0-9a-zA-Z\s])+" else: - attr_dict_items = [ - (self.full_span, {"foreground": BLACK}), - *[ + pattern = r"[0-9a-zA-Z]+|[^0-9a-zA-Z\s]+" + return self.chain(*[ + self.find_spans(pattern, pos=span_begin, endpos=span_end) + for span_begin, span_end in text_spans + ]) + + def get_content(self, is_labelled: bool) -> str: + predefined_items = [ + (self.full_span, self.global_attr_dict), + (self.full_span, self.global_config), + *self.specified_items + ] + if is_labelled: + attr_dict_items = self.chain( + [ (span, { - key: BLACK if key in COLOR_RELATED_KEYS else val + key: + "black" if key.lower() in MARKUP_COLOR_KEYS else val for key, val in attr_dict.items() }) - for span, attr_dict in self.predefined_attr_dicts + for span, attr_dict in predefined_items ], - *[ + [ (span, {"foreground": self.int_to_hex(label + 1)}) for label, span in enumerate(self.label_span_list) ] - ] + ) + else: + attr_dict_items = self.chain( + predefined_items, + [ + (span, {}) + for span in self.label_span_list + ] + ) inserted_string_pairs = [ (span, ( f"", "" )) - for span, attr_dict in self.merge_attr_dicts(attr_dict_items) + for span, attr_dict in attr_dict_items if attr_dict ] - span_repl_dict = self.generate_span_repl_dict( + return self.get_replaced_string( inserted_string_pairs, self.command_repl_items ) - return self.get_replaced_substr(self.full_span, span_repl_dict) - @property - def has_predefined_local_colors(self) -> bool: - return any([ - key in COLOR_RELATED_KEYS - for _, attr_dict in self.predefined_attr_dicts - for key in attr_dict.keys() - ]) + # Selector + + def get_cleaned_substr(self, span: Span) -> str: + repl_items = list(filter( + lambda repl_item: self.span_contains(span, repl_item[0]), + self.command_repl_items + )) + return self.get_replaced_substr(span, repl_items).strip() # Method alias - def get_parts_by_text(self, text: str, **kwargs) -> VGroup: - return self.get_parts_by_string(text, **kwargs) + def get_parts_by_text(self, selector: Selector) -> VGroup: + return self.select_parts(selector) - def get_part_by_text(self, text: str, **kwargs) -> VMobject: - return self.get_part_by_string(text, **kwargs) + def get_part_by_text(self, selector: Selector) -> VGroup: + return self.select_part(selector) - def set_color_by_text(self, text: str, color: ManimColor, **kwargs): - return self.set_color_by_string(text, color, **kwargs) + def set_color_by_text(self, selector: Selector, color: ManimColor): + return self.set_parts_color(selector, color) def set_color_by_text_to_color_map( - self, text_to_color_map: dict[str, ManimColor], **kwargs + self, color_map: dict[Selector, ManimColor] ): - return self.set_color_by_string_to_color_map( - text_to_color_map, **kwargs - ) + return self.set_parts_color_by_dict(color_map) def get_text(self) -> str: return self.get_string() diff --git a/manimlib/utils/iterables.py b/manimlib/utils/iterables.py index bdaa76c2..fa54e68d 100644 --- a/manimlib/utils/iterables.py +++ b/manimlib/utils/iterables.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: S = TypeVar("S") -def remove_list_redundancies(l: Iterable[T]) -> list[T]: +def remove_list_redundancies(l: Sequence[T]) -> list[T]: """ Used instead of list(set(l)) to maintain order Keeps the last occurrence of each element @@ -40,14 +40,14 @@ def list_difference_update(l1: Iterable[T], l2: Iterable[T]) -> list[T]: return [e for e in l1 if e not in l2] -def adjacent_n_tuples(objects: Iterable[T], n: int) -> zip[tuple[T, T]]: +def adjacent_n_tuples(objects: Sequence[T], n: int) -> zip[tuple[T, T]]: return zip(*[ [*objects[k:], *objects[:k]] for k in range(n) ]) -def adjacent_pairs(objects: Iterable[T]) -> zip[tuple[T, T]]: +def adjacent_pairs(objects: Sequence[T]) -> zip[tuple[T, T]]: return adjacent_n_tuples(objects, 2)