From 03cb42ba15d5fdf829bbc3e30f3ef594f004dd41 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Mon, 2 May 2022 22:40:06 +0800 Subject: [PATCH] [WIP] Refactor LabelledString and relevant classes --- .../animation/transform_matching_parts.py | 4 +- manimlib/mobject/svg/labelled_string.py | 659 +++++++++++++--- manimlib/mobject/svg/mtex_mobject.py | 716 +++++++++++++----- manimlib/mobject/svg/text_mobject.py | 602 +++++++++++---- 4 files changed, 1525 insertions(+), 456 deletions(-) diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index e84f1d9d..96fd95ce 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -225,8 +225,8 @@ class TransformMatchingStrings(AnimationGroup): VGroup(*target_substr_to_parts_map[substr]) ) for substr in sorted([ - s for s in source_substr_to_parts_map.keys() - if s and s in target_substr_to_parts_map.keys() + s for s in source_substr_to_parts_map + if s and s in target_substr_to_parts_map ], key=len, reverse=True) ] ) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index cb62df9b..23c285de 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -4,8 +4,6 @@ from abc import ABC, abstractmethod import itertools as it import re -import numpy as np - from manimlib.constants import WHITE from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.types.vectorized_mobject import VGroup @@ -56,7 +54,7 @@ class LabelledString(SVGMobject, ABC): digest_config(self, kwargs) if self.base_color is None: self.base_color = WHITE - self.base_color_int = self.color_to_int(self.base_color) + self.base_color_hex = self.color_to_hex(self.base_color) self.full_span = (0, len(self.string)) self.parse() @@ -67,11 +65,7 @@ class LabelledString(SVGMobject, ABC): ] def get_file_path(self) -> str: - return self.get_file_path_(is_labelled=False) - - def get_file_path_(self, is_labelled: bool) -> str: - content = self.get_content(is_labelled) - return self.get_file_path_by_content(content) + return self.get_file_path_by_content(self.original_content) @abstractmethod def get_file_path_by_content(self, content: str) -> str: @@ -80,53 +74,66 @@ class LabelledString(SVGMobject, ABC): def generate_mobject(self) -> None: super().generate_mobject() - num_labels = len(self.label_span_list) - if num_labels: - file_path = self.get_file_path_(is_labelled=True) - labelled_svg = SVGMobject(file_path) - submob_color_ints = [ - self.color_to_int(submob.get_fill_color()) - for submob in labelled_svg.submobjects - ] - else: - submob_color_ints = [0] * len(self.submobjects) - - if len(self.submobjects) != len(submob_color_ints): + file_path = self.get_file_path_by_content(self.labelled_content) + labelled_svg = SVGMobject(file_path) + num_submobjects = len(self.submobjects) + if num_submobjects != len(labelled_svg.submobjects): raise ValueError( "Cannot align submobjects of the labelled svg " "to the original svg" ) + submob_color_ints = [ + self.hex_to_int(self.color_to_hex(submob.get_fill_color())) + for submob in labelled_svg.submobjects + ] unrecognized_color_ints = self.remove_redundancies(sorted(filter( - lambda color_int: color_int > num_labels, + lambda color_int: color_int > len(self.label_span_list), submob_color_ints ))) if unrecognized_color_ints: raise ValueError( "Unrecognized color label(s) detected: " - f"{','.join(map(self.int_to_hex, unrecognized_color_ints))}" + f"{', '.join(map(self.int_to_hex, unrecognized_color_ints))}" ) + #if self.sort_labelled_submobs: + submob_indices = sorted( + range(num_submobjects), + key=lambda index: tuple( + self.submobjects[index].get_center() + ) + ) + labelled_submob_indices = sorted( + range(num_submobjects), + key=lambda index: tuple( + labelled_svg.submobjects[index].get_center() + ) + ) + submob_color_ints = [ + submob_color_ints[ + labelled_submob_indices[submob_indices.index(index)] + ] + for index in range(num_submobjects) + ] + for submob, color_int in zip(self.submobjects, submob_color_ints): submob.label = color_int - 1 - def parse(self) -> None: - self.command_repl_items = self.get_command_repl_items() - self.specified_spans = self.get_specified_spans() - self.check_overlapping() - self.label_span_list = self.get_label_span_list() - if len(self.label_span_list) >= 16777216: - raise ValueError("Cannot handle that many substrings") + #@property + #@abstractmethod + #def sort_labelled_submobs(self) -> bool: + # return False # Toolkits def get_substr(self, span: Span) -> str: return self.string[slice(*span)] - def match(self, pattern: str | re.Pattern, **kwargs) -> re.Pattern | None: - if isinstance(pattern, str): - pattern = re.compile(pattern) - return re.compile(pattern).match(self.string, **kwargs) + #def match(self, pattern: str | re.Pattern, **kwargs) -> re.Pattern | None: + # if isinstance(pattern, str): + # pattern = re.compile(pattern) + # return re.compile(pattern).match(self.string, **kwargs) def find_spans(self, pattern: str | re.Pattern, **kwargs) -> list[Span]: if isinstance(pattern, str): @@ -136,8 +143,8 @@ class LabelledString(SVGMobject, ABC): for match_obj in pattern.finditer(self.string, **kwargs) ] - def find_indices(self, pattern: str | re.Pattern, **kwargs) -> list[int]: - return [index for index, _ in self.find_spans(pattern, **kwargs)] + #def find_indices(self, pattern: str | re.Pattern, **kwargs) -> list[int]: + # return [index for index, _ in self.find_spans(pattern, **kwargs)] def find_spans_by_selector(self, selector: Selector) -> list[Span]: def find_spans_by_single_selector(sel): @@ -145,20 +152,16 @@ class LabelledString(SVGMobject, ABC): return self.find_spans(re.escape(sel)) if isinstance(sel, re.Pattern): return self.find_spans(sel) - if isinstance(sel, tuple) and len(sel) == 2 and all([ + if isinstance(sel, tuple) and len(sel) == 2 and all( isinstance(index, int) or index is None for index in sel - ]): - string_len = self.full_span[1] - span = tuple([ - ( - min(index, string_len) - if index >= 0 - else max(index + string_len, 0) - ) + ): + l = self.full_span[1] + span = tuple( + min(index, l) if index >= 0 else max(index + l, 0) if index is not None else default_index for index, default_index in zip(sel, self.full_span) - ]) + ) return [span] return None @@ -203,13 +206,158 @@ class LabelledString(SVGMobject, ABC): val_ranges = LabelledString.get_neighbouring_pairs(indices) return list(zip(unique_vals, val_ranges)) + @staticmethod + def sort_obj_pairs_by_spans( + obj_pairs: list[tuple[Span, tuple[T, T]]] + ) -> list[tuple[int, T]]: + return [ + (index, obj) + for (index, _), obj in sorted([ + (span, begin_obj) + for span, (begin_obj, _) in obj_pairs + ] + [ + (span[::-1], end_obj) + for span, (_, end_obj) in reversed(obj_pairs) + ], key=lambda t: (t[0][0], -t[0][1])) + ] + @staticmethod def span_contains(span_0: Span, span_1: Span) -> bool: return span_0[0] <= span_1[0] and span_0[1] >= span_1[1] + def get_piece_items( + self, + tag_span_pairs: list[tuple[Span, Span]], + entity_spans: list[Span] + ) -> tuple[list[Span], list[int]]: + tagged_items = sorted(self.chain( + [(begin_cmd_span, 1) for begin_cmd_span, _ in tag_span_pairs], + [(end_cmd_span, -1) for _, end_cmd_span in tag_span_pairs], + [(entity_span, 0) for entity_span in entity_spans], + ), key=lambda t: t[0]) + piece_spans = self.get_complement_spans(self.full_span, [ + interval_span for interval_span, _ in tagged_items + ]) + piece_levels = [0, *it.accumulate([tag for _, tag in tagged_items])] + return piece_spans, piece_levels + + def split_span(self, arbitrary_span: Span) -> list[Span]: + # ignorable_indices -- + # left_bracket_spans + # right_bracket_spans + # entity_spans + #piece_spans, piece_levels = zip(*self.piece_items) + #ignorable_indices = self.ignorable_indices + piece_spans = self.piece_spans + piece_levels = self.piece_levels + #piece_begins, piece_ends = zip(*piece_spans) + #span_begin, span_end = arbitrary_span + #while span_begin in ignorable_indices: + # span_begin += 1 + #while span_end - 1 in ignorable_indices: + # span_end -= 1 + #entity_spans = self.chain( + # left_bracket_spans, right_bracket_spans, entity_spans + #) + index_begin = sum([ + arbitrary_span[0] > piece_end + for _, piece_end in piece_spans + ]) + index_end = sum([ + arbitrary_span[1] >= piece_begin + for piece_begin, _ in piece_spans + ]) + if index_begin >= index_end: + return [] + + lowest_level = min( + piece_levels[index_begin:index_end] + ) + split_piece_indices = [] + target_level = piece_levels[index_begin] + for piece_index in range(index_begin, index_end): + if piece_levels[piece_index] != target_level: + continue + split_piece_indices.append(piece_index) + target_level -= 1 + if target_level < lowest_level: + break + len_indices = len(split_piece_indices) + target_level = piece_levels[index_end - 1] + for piece_index in range(index_begin, index_end)[::-1]: + if piece_levels[piece_index] != target_level: + continue + split_piece_indices.insert(len_indices, piece_index + 1) + target_level -= 1 + if target_level < lowest_level: + break + + span_begins = [ + piece_spans[piece_index][0] + for piece_index in split_piece_indices[:-1] + ] + span_begins[0] = max(arbitrary_span[0], span_begins[0]) + span_ends = [ + piece_spans[piece_index - 1][1] + for piece_index in split_piece_indices[1:] + ] + span_ends[-1] = min(arbitrary_span[1], span_ends[-1]) + return list(zip(span_begins, span_ends)) + #lowest_level_indices = [ + # piece_index + # for piece_index, piece_level in enumerate(piece_levels) + # if left_piece_index <= piece_index <= right_piece_index + # and piece_level == lowest_level + #] + #left_lowest_index = min(lowest_level_indices) + #right_lowest_index = max(lowest_level_indices) + #while right_lowest_index != right_piece_index: + + + #left_parallel_index = max( + # piece_index + # for piece_index, piece_level in enumerate(piece_levels) + # if left_piece_index <= piece_index <= right_piece_index + # and piece_level == piece_levels[left_piece_index] + #) + #right_parallel_index = min( + # piece_index + # for piece_index, piece_level in enumerate(piece_levels) + # if left_piece_index <= piece_index <= right_piece_index + # and piece_level == piece_levels[right_piece_index] + #) + #result.append(( + # piece_spans[left_lowest_index][0], + # piece_spans[right_lowest_index][1] + #)) + #lowest_piece_indices = [ + # piece_index + # for piece_index, piece_level in enumerate( + + # ) + #] + #adjusted_span_begin = max(span_begin, piece_spans[begin_piece_index][0]) ## + #adjusted_span_end = min(span_end, piece_spans[end_piece_index][1]) ## + #begin_level_mismatch = piece_levels[begin_piece_index] - lowest_level + #end_level_mismatch = piece_levels[end_piece_index] - lowest_level + #if begin_level_mismatch: + # span_begin = piece_spans[max([ + # index + # for index, piece_level in enumerate(piece_levels) + # if piece_level == lowest_level and index < begin_piece_index + # ])][1] + # begin_level_mismatch = 0 + #if end_level_mismatch: + # span_end = piece_spans[min([ + # index + # for index, piece_level in enumerate(piece_levels) + # if piece_level == lowest_level and index > end_piece_index + # ])][0] + # end_level_mismatch = 0 + @staticmethod def get_complement_spans( - interval_spans: list[Span], universal_span: Span + universal_span: Span, interval_spans: list[Span] ) -> list[Span]: if not interval_spans: return [universal_span] @@ -220,85 +368,138 @@ class LabelledString(SVGMobject, ABC): (*span_ends, universal_span[1]) )) - @staticmethod - def merge_inserted_strings_from_pairs( - inserted_string_pairs: list[tuple[Span, tuple[str, str]]] - ) -> list[tuple[int, str]]: - if not inserted_string_pairs: - return [] - - indices, *_, inserted_strings = zip(*sorted([ - ( - span[flag], - np.sign(span[1 - flag] - span[flag]), - -span[1 - flag], - flag, - (1, -1)[flag] * item_index, - str_pair[flag] - ) - for item_index, (span, str_pair) in enumerate( - inserted_string_pairs - ) - for flag in range(2) - ])) - return [ - (index, "".join(inserted_strings[slice(*index_range)])) - for index, index_range - in LabelledString.compress_neighbours(indices) - ] - - def get_replaced_substr( - self, span: Span, repl_items: list[tuple[Span, str]] - ) -> str: + def get_replaced_substr(self, span: Span, repl_items: list[Span, str]): # TODO: need `span` attr? if not repl_items: return self.get_substr(span) - repl_spans, repl_strs = zip(*sorted(repl_items)) + repl_spans, repl_strs = zip(*sorted( + repl_items, key=lambda t: t[0] + )) pieces = [ self.get_substr(piece_span) - for piece_span in self.get_complement_spans(repl_spans, span) + for piece_span in self.get_complement_spans(span, repl_spans) ] repl_strs = [*repl_strs, ""] return "".join(self.chain(*zip(pieces, repl_strs))) - def get_replaced_string( - self, - inserted_string_pairs: list[tuple[Span, tuple[str, str]]], - repl_items: list[tuple[Span, str]] - ) -> str: - all_repl_items = self.chain( - repl_items, - [ - ((index, index), inserted_string) - for index, inserted_string - in self.merge_inserted_strings_from_pairs( - inserted_string_pairs - ) - ] - ) - return self.get_replaced_substr(self.full_span, all_repl_items) + #def get_replaced_string( + # self, + # inserted_string_pairs: list[tuple[Span, tuple[str, str]]], + # repl_items: list[tuple[Span, str]] + #) -> str: + # all_repl_items = self.chain( + # repl_items, + # [ + # ((index, index), inserted_string) + # for index, inserted_string + # in self.sort_inserted_strings_from_pairs( + # inserted_string_pairs + # ) + # ] + # ) + # return self.get_replaced_substr(self.full_span, all_repl_items) @staticmethod - def color_to_int(color: ManimColor) -> int: - hex_code = rgb_to_hex(color_to_rgb(color)) - return int(hex_code[1:], 16) + def color_to_hex(color: ManimColor) -> str: + return rgb_to_hex(color_to_rgb(color)) + + @staticmethod + def hex_to_int(rgb_hex: str) -> int: + return int(rgb_hex[1:], 16) @staticmethod def int_to_hex(rgb_int: int) -> str: - return "#{:06x}".format(rgb_int).upper() + return f"#{rgb_int:06x}".upper() + + @staticmethod + @abstractmethod + def get_tag_str( + attr_dict: dict[str, str], escape_color_keys: bool, is_begin_tag: bool + ) -> str: + return "" + + #def get_color_tag_str(self, rgb_int: int, is_begin_tag: bool) -> str: + # return self.get_tag_str({ + # "foreground": self.int_to_hex(rgb_int) + # }, escape_color_keys=False, is_begin_tag=is_begin_tag) # Parsing - @abstractmethod - def get_command_repl_items(self) -> list[tuple[Span, str]]: - return [] + #@abstractmethod + #def get_command_spans(self) -> list[Span]: + # return [] + # #return [ + # # self.match(r"\\(?:[a-zA-Z]+|.)", pos=index).span() + # # for index in self.backslash_indices + # #] - @abstractmethod - def get_specified_spans(self) -> list[Span]: - return [] + #@abstractmethod + #@staticmethod + #def get_command_repl_dict() -> dict[str | re.Pattern, str]: + # return {} - def check_overlapping(self) -> None: - for span_0, span_1 in it.product(self.specified_spans, repeat=2): + #@abstractmethod + #def parse_setup(self) -> None: + # return + + #@abstractmethod + #def get_command_repl_items(self) -> list[tuple[Span, str]]: + # return [] + # #result = [] + # #for cmd_span in self.command_spans: + # # cmd_str = self.get_substr(cmd_span) + # # if + # # repl_str = self.command_repl_dict.get(cmd_str, cmd_str) + # # result.append((cmd_span, repl_str)) + # #return result + + #def span_cuts_at_entity(self, span: Span) -> bool: + # return any([ + # entity_begin < index < entity_end + # for index in span + # for entity_begin, entity_end in self.command_repl_items + # ]) + + #@abstractmethod + #def get_all_specified_items(self) -> list[tuple[Span, dict[str, str]]]: + # return [] + + #def get_specified_items(self) -> list[tuple[Span, dict[str, str]]]: + # return [ + # (span, attr_dict) + # for span, attr_dict in self.get_all_specified_items() + # if not any([ + # entity_begin < index < entity_end + # for index in span + # for entity_begin, entity_end in self.command_repl_items + # ]) + # ] + + #def get_specified_spans(self) -> list[Span]: + # return [span for span, _ in self.specified_items] + + def parse(self) -> None: + self.entity_spans = self.get_entity_spans() + tag_span_pairs, internal_items = self.get_internal_items() + self.piece_spans, self.piece_levels = self.get_piece_items( + tag_span_pairs, self.entity_spans + ) + #self.tag_content_spans = [ + # (content_begin, content_end) + # for (_, content_begin), (content_end, _) in tag_span_pairs + #] + self.tag_spans = self.chain(*tag_span_pairs) + specified_items = self.chain( + internal_items, + self.get_external_items(), + [ + (span, {}) + for span in self.find_spans_by_selector(self.isolate) + ] + ) + print(f"\n{specified_items=}\n") + specified_spans = [span for span, _ in specified_items] + for span_0, span_1 in it.product(specified_spans, repeat=2): if not span_0[0] < span_1[0] < span_0[1] < span_1[1]: continue raise ValueError( @@ -306,13 +507,246 @@ class LabelledString(SVGMobject, ABC): f"'{self.get_substr(span_0)}' and '{self.get_substr(span_1)}'" ) + split_items = [ + (span, attr_dict) + for specified_span, attr_dict in specified_items + for span in self.split_span(specified_span) + ] + print(f"\n{split_items=}\n") + split_spans = [span for span, _ in split_items] + label_span_list = self.get_label_span_list(split_spans) + if len(label_span_list) >= 16777216: + raise ValueError("Cannot handle that many substrings") + + #content_strings = [] + #for is_labelled in (False, True): + # + # content_strings.append(content_string) + + #inserted_str_pairs = self.chain( + # [ + # (span, ( + # self.get_tag_str(attr_dict, escape_color_keys=True, is_begin_tag=True), + # self.get_tag_str(attr_dict, escape_color_keys=True, is_begin_tag=False) + # )) + # for span, attr_dict in split_items + # ], + # [ + # (span, ( + # self.get_color_tag_str(label + 1, is_begin_tag=True), + # self.get_color_tag_str(label + 1, is_begin_tag=False) + # )) + # for span, attr_dict in split_items + # ] + #) + + + #decorated_strings = [ + # self.get_replaced_substr(self.full_span, [ + # (span, str_pair[flag]) + # for span, str_pair in command_repl_items + # ]) + # for flag in range(2) + #] + + self.specified_spans = specified_spans + self.label_span_list = label_span_list + self.original_content = self.get_full_content_string( + label_span_list, split_items, is_labelled=False + ) + self.labelled_content = self.get_full_content_string( + label_span_list, split_items, is_labelled=True + ) + print(self.original_content) + print() + print(self.labelled_content) + + + #self.command_repl_dict = self.get_command_repl_dict() + #self.command_repl_items = [] + #self.bracket_content_spans = [] + ##self.command_spans = self.get_command_spans() + ##self.specified_items = self.get_specified_items() + #self.specified_spans = [] + #self.check_overlapping() ####### + #self.label_span_list = [] + #if len(self.label_span_list) >= 16777216: + # raise ValueError("Cannot handle that many substrings") + @abstractmethod - def get_label_span_list(self) -> list[Span]: + def get_entity_spans(self) -> list[Span]: return [] @abstractmethod - def get_content(self, is_labelled: bool) -> str: - return "" + def get_internal_items( + self + ) -> tuple[list[tuple[Span, Span]], list[tuple[Span, dict[str, str]]]]: + return [], [] + + @abstractmethod + def get_external_items(self) -> list[tuple[Span, dict[str, str]]]: + return [] + + #@abstractmethod + #def get_spans_from_items(self, specified_items: list[tuple[Span, dict[str, str]]]) -> list[Span]: + # return [] + + #def split_span(self, arbitrary_span: Span) -> list[Span]: + # span_begin, span_end = arbitrary_span + # # TODO: improve algorithm + # span_begin += sum([ + # entity_end - span_begin + # for entity_begin, entity_end in self.entity_spans + # if entity_begin < span_begin < entity_end + # ]) + # span_end -= sum([ + # span_end - entity_begin + # for entity_begin, entity_end in self.entity_spans + # if entity_begin < span_end < entity_end + # ]) + # if span_begin >= span_end: + # return [] + + # adjusted_span = (span_begin, span_end) + # result = [] + # span_choices = list(filter( + # lambda span: span[0] < span[1] and self.span_contains( + # adjusted_span, span + # ), + # self.tag_content_spans + # )) + # while span_choices: + # chosen_span = min(span_choices, key=lambda t: (t[0], -t[1])) + # result.append(chosen_span) + # span_choices = list(filter( + # lambda span: chosen_span[1] <= span[0], + # span_choices + # )) + # result.extend(self.chain(*[ + # self.get_complement_spans(span, sorted([ + # (max(tag_span[0], span[0]), min(tag_span[1], span[1])) + # for tag_span in self.tag_spans + # if tag_span[0] < span[1] and span[0] < tag_span[1] + # ])) + # for span in self.get_complement_spans(adjusted_span, result) + # ])) + return list(filter(lambda span: span[0] < span[1], result)) + + #@abstractmethod + #def get_split_items(self, specified_items: list[T]) -> list[T]: + # return [] + + @abstractmethod + def get_label_span_list(self, split_spans: list[Span]) -> list[Span]: + return [] + + #@abstractmethod + #def get_predefined_inserted_str_items( + # self, split_items: list[T] + #) -> list[tuple[Span, tuple[str, str]]]: + # return [] + + #def check_overlapping(self) -> None: + + #for span_0, span_1 in it.product(self.specified_spans, self.bracket_content_spans): + # if not any( + # span_0[0] < span_1[0] <= span_0[1] <= span_1[1], + # span_1[0] <= span_0[0] <= span_1[1] < span_0[1] + # ): + # continue + # raise ValueError( + # f"Invalid substring detected: '{self.get_substr(span_0)}'" + # ) + # TODO: test bracket_content_spans + + #@abstractmethod + #def get_inserted_string_pairs( + # self, is_labelled: bool + #) -> list[tuple[Span, tuple[str, str]]]: + # return [] + + #@abstractmethod + #def get_label_span_list(self) -> list[Span]: + # return [] + + #def get_decorated_string( + # self, is_labelled: bool, replace_commands: bool + #) -> str: + # inserted_string_pairs = [ + # (indices, str_pair) + # for indices, str_pair in self.get_inserted_string_pairs( + # is_labelled=is_labelled + # ) + # if not any( + # cmd_begin < index < cmd_end + # for index in indices + # for (cmd_begin, cmd_end), _ in self.command_repl_items + # ) + # ] + # repl_items = [ + # ((index, index), inserted_string) + # for index, inserted_string + # in self.sort_inserted_strings_from_pairs( + # inserted_string_pairs + # ) + # ] + # if replace_commands: + # repl_items.extend(self.command_repl_items) + # return self.get_replaced_substr(self.full_span, repl_items) + + @abstractmethod + def get_additional_inserted_str_pairs( + self + ) -> list[tuple[Span, tuple[str, str]]]: + return [] + + @abstractmethod + def get_command_repl_items(self, is_labelled: bool) -> list[Span, str]: + return [] + + def get_full_content_string( + self, + label_span_list: list[Span], + split_items: list[tuple[Span, dict[str, str]]], + is_labelled: bool + ) -> str: + label_items = [ + (span, { + "foreground": self.int_to_hex(label + 1) + } if is_labelled else {}) + for label, span in enumerate(label_span_list) + ] + inserted_str_pairs = self.chain( + self.get_additional_inserted_str_pairs(), + [ + (span, tuple( + self.get_tag_str( + attr_dict, + escape_color_keys=is_labelled and not is_label_item, + is_begin_tag=is_begin_tag + ) + for is_begin_tag in (True, False) + )) + for is_label_item, items in enumerate(( + split_items, label_items + )) + for span, attr_dict in items + ] + ) + repl_items = self.chain( + self.get_command_repl_items(is_labelled), + [ + ((index, index), inserted_str) + for index, inserted_str + in self.sort_obj_pairs_by_spans(inserted_str_pairs) + ] + ) + return self.get_replaced_substr( + self.full_span, repl_items + ) + + #def get_content(self, is_labelled: bool) -> str: + # return self.content_strings[int(is_labelled)] # Selector @@ -348,7 +782,7 @@ class LabelledString(SVGMobject, ABC): group_substrs = [ self.get_cleaned_substr(span) if span[0] < span[1] else "" for span in self.get_complement_spans( - interval_spans, (ordered_spans[0][0], ordered_spans[-1][1]) + (ordered_spans[0][0], ordered_spans[-1][1]), interval_spans ) ] submob_groups = VGroup(*[ @@ -366,10 +800,11 @@ class LabelledString(SVGMobject, ABC): for span in self.specified_spans ] - def select_part_by_span(self, custom_span: Span) -> VGroup: + def select_part_by_span(self, arbitrary_span: Span) -> VGroup: return VGroup(*[ submob for label, submob in self.labelled_submobject_items - if self.span_contains(custom_span, self.label_span_list[label]) + if label != -1 + and self.span_contains(arbitrary_span, self.label_span_list[label]) ]) def select_parts(self, selector: Selector) -> VGroup: diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 3edb10b1..93e49a81 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -1,7 +1,5 @@ from __future__ import annotations -import re - from manimlib.mobject.svg.labelled_string import LabelledString from manimlib.utils.tex_file_writing import display_during_execution from manimlib.utils.tex_file_writing import get_tex_config @@ -11,6 +9,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from colour import Color + import re from typing import Iterable, Union from manimlib.mobject.types.vectorized_mobject import VGroup @@ -39,6 +38,7 @@ TEX_COLOR_COMMANDS_DICT = { "\\colorbox": (1, True), "\\fcolorbox": (2, True), } +TEX_COLOR_COMMAND_SUFFIX = "replaced" class MTex(LabelledString): @@ -56,7 +56,7 @@ class MTex(LabelledString): self.tex_string = tex_string super().__init__(tex_string, **kwargs) - self.set_color_by_tex_to_color_map(self.tex_to_color_map) + #self.set_color_by_tex_to_color_map(self.tex_to_color_map) self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size) @property @@ -83,208 +83,490 @@ class MTex(LabelledString): file_path = tex_to_svg_file(full_tex) return file_path - def parse(self) -> None: - self.backslash_indices = self.get_backslash_indices() - self.command_spans = self.get_command_spans() - self.brace_spans = self.get_brace_spans() - self.script_char_indices = self.get_script_char_indices() - self.script_content_spans = self.get_script_content_spans() - self.script_spans = self.get_script_spans() - super().parse() + #@property + #def sort_labelled_submobs(self) -> bool: + # return False # Toolkits @staticmethod - def get_color_command_str(rgb_int: int) -> str: - rg, b = divmod(rgb_int, 256) + def get_color_command_str(rgb_hex: str) -> str: + rgb = MTex.hex_to_int(rgb_hex) + rg, b = divmod(rgb, 256) r, g = divmod(rg, 256) return f"\\color[RGB]{{{r}, {g}, {b}}}" @staticmethod - def shrink_span(span: Span, skippable_indices: list[int]) -> Span: - span_begin, span_end = span - while span_begin in skippable_indices: - span_begin += 1 - while span_end - 1 in skippable_indices: - span_end -= 1 - return (span_begin, span_end) + def get_tag_str( + attr_dict: dict[str, str], escape_color_keys: bool, is_begin_tag: bool + ) -> str: + if escape_color_keys: + return "" + if not is_begin_tag: + return "}}" + if "foreground" not in attr_dict: + return "{{" + return "{{" + MTex.get_color_command_str(attr_dict["foreground"]) + + #@staticmethod + #def shrink_span(span: Span, skippable_indices: list[int]) -> Span: + # span_begin, span_end = span + # while span_begin in skippable_indices: + # span_begin += 1 + # while span_end - 1 in skippable_indices: + # span_end -= 1 + # return (span_begin, span_end) # Parsing - def get_backslash_indices(self) -> list[int]: - # The latter of `\\` doesn't count. - return self.find_indices(r"\\.") + #def parse(self) -> None: # TODO + #command_spans = self.find_spans(r"\\(?:[a-zA-Z]+|.)") - def get_command_spans(self) -> list[Span]: - return [ - self.match(r"\\(?:[a-zA-Z]+|.)", pos=index).span() - for index in self.backslash_indices - ] - def get_unescaped_char_indices(self, char: str) -> list[int]: - return list(filter( - lambda index: index - 1 not in self.backslash_indices, - self.find_indices(re.escape(char)) - )) + #specified_spans = self.chain( + # inner_content_spans, + # *[ + # self.find_spans_by_selector(selector) + # for selector in self.tex_to_color_map.keys() + # ], + # self.find_spans_by_selector(self.isolate) + #) + #print(specified_spans) + #label_span_list = self.remove_redundancies(self.chain(*[ + # self.split_span(span) + # for span in specified_spans + #])) + #print(label_span_list) + #for span in all_specified_spans: + # adjusted_span, _, _ = self.adjust_span(span, align_level=True) + # if adjusted_span[0] > adjusted_span[1]: + # continue + # specified_spans.append(adjusted_span) - def get_brace_spans(self) -> list[Span]: - span_begins = [] - span_ends = [] - span_begins_stack = [] - char_items = sorted([ - (index, char) - for char in "{}" - for index in self.get_unescaped_char_indices(char) - ]) - for index, char in char_items: - if char == "{": - span_begins_stack.append(index) - else: - if not span_begins_stack: - raise ValueError("Missing '{' inserted") - span_begins.append(span_begins_stack.pop()) - span_ends.append(index + 1) - if span_begins_stack: - raise ValueError("Missing '}' inserted") - return list(zip(span_begins, span_ends)) - def get_script_char_indices(self) -> list[int]: - return self.chain(*[ - self.get_unescaped_char_indices(char) - for char in "_^" - ]) - def get_script_content_spans(self) -> list[Span]: - result = [] - script_entity_dict = dict(self.chain( - self.brace_spans, - self.command_spans - )) - for index in self.script_char_indices: - span_begin = self.match(r"\s*", pos=index + 1).end() - if span_begin in script_entity_dict.keys(): - span_end = script_entity_dict[span_begin] - else: - match_obj = self.match(r".", pos=span_begin) - if match_obj is None: - continue - span_end = match_obj.end() - result.append((span_begin, span_end)) - return result + #reversed_script_spans_dict = { + # span_end: span_begin + # for span_begin, _, span_end in script_items + #} + #label_span_list = [ + # (content_begin, span_end) + # for _, content_begin, span_end in script_items + #] + #for span_begin, span_end in specified_spans: + # while span_end in reversed_script_spans_dict: + # span_end = reversed_script_spans_dict[span_end] + # if span_begin >= span_end: + # continue + # shrinked_span = (span_begin, span_end) + # if shrinked_span in label_span_list: + # continue + # label_span_list.append(shrinked_span) - def get_script_spans(self) -> list[Span]: - return [ - ( - self.match(r"[\s\S]*?(\s*)$", endpos=index).start(1), - script_content_span[1] - ) - for index, script_content_span in zip( - self.script_char_indices, self.script_content_spans - ) - ] + #inserted_str_items = [ + # (span, ( + # ("{{", "{{" + self.get_color_command_str(label + 1)), + # ("}}", "}}"), + # )) + # for label, span in enumerate(label_span_list) + #] + #command_repl_items = [ + # ((index, index), str_pair) + # for index, str_pair in self.sort_obj_pairs_by_spans(inserted_str_items) + #] + #for cmd_span in command_spans: + # cmd_str = self.get_substr(cmd_span) + # if cmd_str not in TEX_COLOR_COMMANDS_DICT: + # continue + # repl_str = f"{cmd_str}{TEX_COLOR_COMMAND_SUFFIX}" + # command_repl_items.append((cmd_span, (cmd_str, repl_str))) + #print(decorated_strings) + #return specified_spans, label_span_list, decorated_strings - def get_command_repl_items(self) -> list[tuple[Span, str]]: - result = [] - brace_spans_dict = dict(self.brace_spans) - brace_begins = list(brace_spans_dict.keys()) - for cmd_span in self.command_spans: - cmd_name = self.get_substr(cmd_span) - if cmd_name not in TEX_COLOR_COMMANDS_DICT.keys(): + + + #self.command_spans = self.find_spans(r"\\(?:[a-zA-Z]+|.)") + #self.ignorable_indices = self.get_ignorable_indices() + #self.brace_content_spans = self.get_brace_content_spans() + #self.command_repl_items = self.get_command_repl_items() + ##self.backslash_indices = self.get_backslash_indices() + #self.ignorable_indices = self.get_ignorable_indices() + ##self.script_items = self.get_script_items() + ##self.script_char_indices = self.get_script_char_indices() + ##self.script_content_spans = self.get_script_content_spans() + ##self.script_spans = self.get_script_spans() + #self.specified_spans = self.get_specified_spans() + ##super().parse() + #self.label_span_list = self.get_label_span_list() + + def get_entity_spans(self) -> list[Span]: + return self.find_spans(r"\\(?:[a-zA-Z]+|.)") + + def get_internal_items( + self + ) -> tuple[list[tuple[Span, Span]], list[tuple[Span, dict[str, str]]]]: + command_spans = self.entity_spans + brace_span_pairs = [] + brace_begin_spans_stack = [] + for span in self.find_spans(r"[{}]"): + char_index = span[0] + if (char_index - 1, char_index + 1) in command_spans: continue - n_braces, substitute_cmd = TEX_COLOR_COMMANDS_DICT[cmd_name] - span_begin, span_end = cmd_span - for _ in range(n_braces): - span_end = brace_spans_dict[min(filter( - lambda index: index >= span_end, - brace_begins - ))] - if substitute_cmd: - repl_str = cmd_name + n_braces * "{black}" + if self.get_substr(span) == "{": + brace_begin_spans_stack.append(span) else: - repl_str = "" - result.append(((span_begin, span_end), repl_str)) - return result + if not brace_begin_spans_stack: + raise ValueError("Missing '{' inserted") + brace_span = brace_begin_spans_stack.pop() + brace_span_pairs.append((brace_span, span)) + if brace_begin_spans_stack: + raise ValueError("Missing '}' inserted") - def get_specified_spans(self) -> list[Span]: - # Match paired double braces (`{{...}}`). - sorted_brace_spans = sorted(self.brace_spans, key=lambda t: t[1]) - inner_brace_spans = [ - sorted_brace_spans[range_begin] + #tag_span_pairs = brace_span_pairs.copy() + script_entity_dict = dict(self.chain( + [ + (span_begin, span_end) + for (span_begin, _), (_, span_end) in brace_span_pairs + ], + command_spans + )) + script_additional_brace_spans = [ + (char_index + 1, script_entity_dict.get( + script_begin, script_begin + 1 + )) + for char_index, script_begin in self.find_spans(r"[_^]\s*(?=.)") + if (char_index - 1, char_index + 1) not in command_spans + ] + #for char_index, script_begin in self.find_spans(r"[_^]\s*(?=.)"): + # if (char_index - 1, char_index + 1) in command_spans: + # continue + # script_end = script_entity_dict.get(script_begin, script_begin + 1) + # tag_span_pairs.append( + # ((char_index, char_index + 1), (script_end, script_end)) + # ) + # script_additional_brace_spans.append((char_index + 1, script_end)) + + tag_span_pairs = self.chain( + brace_span_pairs, + [ + ((script_begin - 1, script_begin), (script_end, script_end)) + for script_begin, script_end in script_additional_brace_spans + ] + ) + + brace_content_spans = [ + (span_begin, span_end) + for (_, span_begin), (span_end, _) in brace_span_pairs + ] + internal_items = [ + (brace_content_spans[range_begin], {}) for _, (range_begin, range_end) in self.compress_neighbours([ (span_begin + index, span_end - index) for index, (span_begin, span_end) in enumerate( - sorted_brace_spans + brace_content_spans ) ]) if range_end - range_begin >= 2 ] - inner_brace_content_spans = [ - (span_begin + 1, span_end - 1) - for span_begin, span_end in inner_brace_spans - if span_end - span_begin > 2 + self.script_additional_brace_spans = script_additional_brace_spans + return tag_span_pairs, internal_items + + def get_external_items(self) -> list[tuple[Span, dict[str, str]]]: + return [ + (span, {"foreground": self.color_to_hex(color)}) + for selector, color in self.tex_to_color_map.items() + for span in self.find_spans_by_selector(selector) ] - result = self.remove_redundancies(self.chain( - inner_brace_content_spans, - *[ - self.find_spans_by_selector(selector) - for selector in self.tex_to_color_map.keys() - ], - self.find_spans_by_selector(self.isolate) - )) - return list(filter( - lambda span: not any([ - entity_begin < index < entity_end - for index in span - for entity_begin, entity_end in self.command_spans - ]), - result - )) + #def get_spans_from_items(self, specified_items: list[Span]) -> list[Span]: + # return specified_items - def get_label_span_list(self) -> list[Span]: - reversed_script_spans_dict = dict([ - script_span[::-1] for script_span in self.script_spans - ]) - skippable_indices = self.chain( - self.find_indices(r"\s"), - self.script_char_indices - ) - result = self.script_content_spans.copy() - for span in self.specified_spans: - span_begin, span_end = self.shrink_span(span, skippable_indices) - while span_end in reversed_script_spans_dict.keys(): - span_end = reversed_script_spans_dict[span_end] - if span_begin >= span_end: + #def get_split_items(self, specified_items: list[Span]) -> list[Span]: + # return self.remove_redundancies(self.chain(*[ + # self.split_span(span) + # for span in specified_items + # ])) + + def get_label_span_list(self, split_spans: list[Span]) -> list[Span]: + return split_spans + + def get_additional_inserted_str_pairs( + self + ) -> list[tuple[Span, tuple[str, str]]]: + return [ + (span, ("{", "}")) + for span in self.script_additional_brace_spans + ] + + def get_command_repl_items(self, is_labelled: bool) -> list[Span, str]: + if not is_labelled: + return [] + result = [] + command_spans = self.entity_spans # TODO + for cmd_span in command_spans: + cmd_str = self.get_substr(cmd_span) + if cmd_str not in TEX_COLOR_COMMANDS_DICT: continue - shrinked_span = (span_begin, span_end) - if shrinked_span in result: - continue - result.append(shrinked_span) + repl_str = f"{cmd_str}{TEX_COLOR_COMMAND_SUFFIX}" + result.append((cmd_span, repl_str)) return result - def get_content(self, is_labelled: bool) -> str: - if is_labelled: - extended_label_span_list = [] - script_spans_dict = dict(self.script_spans) - for span in self.label_span_list: - if span not in self.script_content_spans: - span_begin, span_end = span - while span_end in script_spans_dict.keys(): - span_end = script_spans_dict[span_end] - span = (span_begin, span_end) - extended_label_span_list.append(span) - inserted_string_pairs = [ - (span, ( - "{{" + self.get_color_command_str(label + 1), - "}}" - )) - for label, span in enumerate(extended_label_span_list) - ] - result = self.get_replaced_string( - inserted_string_pairs, self.command_repl_items - ) - else: - result = self.string + #def get_predefined_inserted_str_items( + # self, split_items: list[Span] + #) -> list[tuple[Span, tuple[str, str]]]: + # return [] + + #def get_ignorable_indices(self) -> list[int]: + # return self.chain( + # [ + # index + # for index, _ in self.find_spans(r"\s") + # ], + # [ + # index + # for index, _ in self.find_spans(r"[_^{}]") + # if (index - 1, index + 1) not in self.command_spans + # ], + # ) + + #def get_bracket_content_spans(self) -> list[Span]: + # span_begins = [] + # span_ends = [] + # span_begins_stack = [] + # for match_obj in re.finditer(r"[{}]", self.string): + # index = match_obj.start() + # if (index - 1, index + 1) in command_spans: + # continue + # if match_obj.group() == "{": + # span_begins_stack.append(index + 1) + # else: + # if not span_begins_stack: + # raise ValueError("Missing '{' inserted") + # span_begins.append(span_begins_stack.pop()) + # span_ends.append(index) + # if span_begins_stack: + # raise ValueError("Missing '}' inserted") + # return list(zip(span_begins, span_ends)) + + #def get_command_repl_items(self) -> list[tuple[Span, str]]: + # result = [] + # for cmd_span in self.command_spans: + # cmd_str = self.get_substr(cmd_span) + # if cmd_str in TEX_COLOR_COMMANDS_DICT: + # repl_str = f"{cmd_str}{TEX_COLOR_COMMAND_SUFFIX}" + # else: + # repl_str = cmd_str + # result.append((cmd_span, repl_str)) + # return result + + #def get_specified_spans(self) -> list[Span]: + # # Match paired double braces (`{{...}}`). + # sorted_content_spans = sorted( + # self.bracket_content_spans, key=lambda t: t[1] + # ) + # inner_content_spans = [ + # sorted_content_spans[range_begin] + # for _, (range_begin, range_end) in self.compress_neighbours([ + # (span_begin + index, span_end - index) + # for index, (span_begin, span_end) in enumerate( + # sorted_content_spans + # ) + # ]) + # if range_end - range_begin >= 2 + # ] + # #inner_content_spans = [ + # # (span_begin + 1, span_end - 1) + # # for span_begin, span_end in inner_brace_spans + # # if span_end - span_begin > 2 + # #] + + # return self.remove_redundancies(self.chain( + # inner_content_spans, + # *[ + # self.find_spans_by_selector(selector) + # for selector in self.tex_to_color_map.keys() + # ], + # self.find_spans_by_selector(self.isolate) + # )) + # #return list(filter( + # # lambda span: not any([ + # # entity_begin < index < entity_end + # # for index in span + # # for entity_begin, entity_end in self.command_spans + # # ]), + # # result + # #)) + + #def get_label_span_list(self) -> tuple[list[int], list[Span]]: + # script_entity_dict = dict(self.chain( + # [ + # (span_begin - 1, span_end + 1) + # for span_begin, span_end in self.bracket_content_spans + # ], + # self.command_spans + # )) + # script_items = [] + # for match_obj in re.finditer(r"\s*([_^])\s*(?=.)", self.string): + # char_index = match_obj.start(1) + # if (char_index - 1, char_index + 1) in self.command_spans: + # continue + # span_begin, content_begin = match_obj.span() + # span_end = script_entity_dict.get(span_begin, content_begin + 1) + # script_items.append( + # (span_begin, char_index, content_begin, span_end) + # ) + + # reversed_script_spans_dict = { + # span_end: span_begin + # for span_begin, _, _, span_end in script_items + # } + # ignorable_indices = self.chain( + # [index for index, _ in self.find_spans(r"\s")], + # [char_index for _, char_index, _, _ in script_items] + # ) + # result = [ + # (content_begin, span_end) + # for _, _, content_begin, span_end in script_items + # ] + # for span in self.specified_spans: + # span_begin, span_end = self.shrink_span(span, ignorable_indices) + # while span_end in reversed_script_spans_dict: + # span_end = reversed_script_spans_dict[span_end] + # if span_begin >= span_end: + # continue + # shrinked_span = (span_begin, span_end) + # if shrinked_span in result: + # continue + # result.append(shrinked_span) + # return result + + #def get_command_spans(self) -> list[Span]: + # return self.find_spans() + + #def get_command_repl_items(self) -> list[Span]: + # return [ + # (span, self.get_substr(span)) + # for span in self.find_spans(r"\\(?:[a-zA-Z]+|.)") + # ] + + #def get_command_spans(self) -> list[Span]: + # return self.find_spans(r"\\(?:[a-zA-Z]+|.)") + #return [ + # self.match(r"\\(?:[a-zA-Z]+|.)", pos=index).span() + # for index in self.backslash_indices + #] + + #@staticmethod + #def get_command_repl_dict() -> dict[str | re.Pattern, str]: + # return { + # cmd_name: f"{cmd_name}replaced" + # for cmd_name in TEX_COLOR_COMMANDS_DICT + # } + + #def get_backslash_indices(self) -> list[int]: + # # The latter of `\\` doesn't count. + # return self.find_indices(r"\\.") + + #def get_unescaped_char_indices(self, char: str) -> list[int]: + # return list(filter( + # lambda index: index - 1 not in self.backslash_indices, + # self.find_indices(re.escape(char)) + # )) + + #def get_script_items(self) -> list[tuple[int, int, int, int]]: + # script_entity_dict = dict(self.chain( + # self.brace_spans, + # self.command_spans + # )) + # result = [] + # for match_obj in re.finditer(r"\s*([_^])\s*(?=.)", self.string): + # char_index = match_obj.start(1) + # if char_index - 1 in self.backslash_indices: + # continue + # span_begin, content_begin = match_obj.span() + # span_end = script_entity_dict.get(span_begin, content_begin + 1) + # result.append((span_begin, char_index, content_begin, span_end)) + # return result + + #def get_script_char_indices(self) -> list[int]: + # return self.chain(*[ + # self.get_unescaped_char_indices(char) + # for char in "_^" + # ]) + + #def get_script_content_spans(self) -> list[Span]: + # result = [] + # script_entity_dict = dict(self.chain( + # self.brace_spans, + # self.command_spans + # )) + # for index in self.script_char_indices: + # span_begin = self.match(r"\s*", pos=index + 1).end() + # if span_begin in script_entity_dict.keys(): + # span_end = script_entity_dict[span_begin] + # else: + # match_obj = self.match(r".", pos=span_begin) + # if match_obj is None: + # continue + # span_end = match_obj.end() + # result.append((span_begin, span_end)) + # return result + + #def get_script_spans(self) -> list[Span]: + # return [ + # ( + # self.match(r"[\s\S]*?(\s*)$", endpos=index).start(1), + # script_content_span[1] + # ) + # for index, script_content_span in zip( + # self.script_char_indices, self.script_content_spans + # ) + # ] + + #def get_command_repl_items(self) -> list[tuple[Span, str]]: + # result = [] + # brace_spans_dict = dict(self.brace_spans) + # brace_begins = list(brace_spans_dict.keys()) + # for cmd_span in self.command_spans: + # cmd_name = self.get_substr(cmd_span) + # if cmd_name not in TEX_COLOR_COMMANDS_DICT: + # continue + # n_braces, substitute_cmd = TEX_COLOR_COMMANDS_DICT[cmd_name] + # span_begin, span_end = cmd_span + # for _ in range(n_braces): + # span_end = brace_spans_dict[min(filter( + # lambda index: index >= span_end, + # brace_begins + # ))] + # if substitute_cmd: + # repl_str = cmd_name + n_braces * "{black}" + # else: + # repl_str = "" + # result.append(((span_begin, span_end), repl_str)) + # return result + + #def get_inserted_string_pairs( + # self, is_labelled: bool + #) -> list[tuple[Span, tuple[str, str]]]: + # if not is_labelled: + # return [] + # return [ + # (span, ( + # "{{" + self.get_color_command_str(label + 1), + # "}}" + # )) + # for label, span in enumerate(self.label_span_list) + # ] + + def get_full_content_string( + self, + label_span_list: list[Span], + split_items: list[tuple[Span, dict[str, str]]], + is_labelled: bool + ) -> str: + result = super().get_full_content_string( + label_span_list, split_items, is_labelled + ) if self.tex_environment: if isinstance(self.tex_environment, str): @@ -295,9 +577,28 @@ class MTex(LabelledString): result = "\n".join([prefix, result, suffix]) if self.alignment: result = "\n".join([self.alignment, result]) - if not is_labelled: + + if is_labelled: + occurred_commands = [ + # TODO + self.get_substr(span) for span in self.entity_spans + ] + newcommand_lines = [ + "".join([ + f"\\newcommand{cmd_name}{TEX_COLOR_COMMAND_SUFFIX}", + f"[{n_braces + 1}][]", + "{", + cmd_name + "{black}" * n_braces if substitute_cmd else "", + "}" + ]) + for cmd_name, (n_braces, substitute_cmd) + in TEX_COLOR_COMMANDS_DICT.items() + if cmd_name in occurred_commands + ] + result = "\n".join([*newcommand_lines, result]) + else: result = "\n".join([ - self.get_color_command_str(self.base_color_int), + self.get_color_command_str(self.base_color_hex), result ]) return result @@ -305,41 +606,44 @@ class MTex(LabelledString): # Selector def get_cleaned_substr(self, span: Span) -> str: - left_brace_indices = [ - span_begin - for span_begin, _ in self.brace_spans - ] - right_brace_indices = [ - span_end - 1 - for _, span_end in self.brace_spans - ] - skippable_indices = self.chain( - self.find_indices(r"\s"), - self.script_char_indices, - left_brace_indices, - right_brace_indices - ) - shrinked_span = self.shrink_span(span, skippable_indices) + return self.get_substr(span) # TODO: test + #left_brace_indices = [ + # span_begin - 1 + # for span_begin, _ in self.brace_content_spans + #] + #right_brace_indices = [ + # span_end + # for _, span_end in self.brace_content_spans + #] + #skippable_indices = self.chain( + # self.ignorable_indices, + # #self.script_char_indices, + # left_brace_indices, + # right_brace_indices + #) + #shrinked_span = self.shrink_span(span, skippable_indices) - if shrinked_span[0] >= shrinked_span[1]: - return "" + ##if shrinked_span[0] >= shrinked_span[1]: + ## return "" - # Balance braces. - unclosed_left_braces = 0 - unclosed_right_braces = 0 - for index in range(*shrinked_span): - if index in left_brace_indices: - unclosed_left_braces += 1 - elif index in right_brace_indices: - if unclosed_left_braces == 0: - unclosed_right_braces += 1 - else: - unclosed_left_braces -= 1 - return "".join([ - unclosed_right_braces * "{", - self.get_substr(shrinked_span), - unclosed_left_braces * "}" - ]) + ## Balance braces. + #unclosed_left_braces = 0 + #unclosed_right_braces = 0 + #for index in range(*shrinked_span): + # if index in left_brace_indices: + # unclosed_left_braces += 1 + # elif index in right_brace_indices: + # if unclosed_left_braces == 0: + # unclosed_right_braces += 1 + # else: + # unclosed_left_braces -= 1 + ##adjusted_span, unclosed_left_braces, unclosed_right_braces \ + ## = self.adjust_span(span, align_level=False) + #return "".join([ + # unclosed_right_braces * "{", + # self.get_substr(shrinked_span), + # unclosed_left_braces * "}" + #]) # Method alias diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 0fe113ff..3b07a07e 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -50,13 +50,16 @@ DEFAULT_CANVAS_HEIGHT = 16384 # See https://docs.gtk.org/Pango/pango_markup.html -MARKUP_COLOR_KEYS = ( - "foreground", "fgcolor", "color", - "background", "bgcolor", - "underline_color", - "overline_color", - "strikethrough_color" -) +MARKUP_COLOR_KEYS_DICT = { + "foreground": False, + "fgcolor": False, + "color": False, + "background": True, + "bgcolor": True, + "underline_color": True, + "overline_color": True, + "strikethrough_color": True, +} MARKUP_TAG_CONVERSION_DICT = { "b": {"font_weight": "bold"}, "big": {"font_size": "larger"}, @@ -66,8 +69,17 @@ MARKUP_TAG_CONVERSION_DICT = { "sup": {"baseline_shift": "superscript", "font_scale": "superscript"}, "small": {"font_size": "smaller"}, "tt": {"font_family": "monospace"}, - "u": {"underline": "single"} + "u": {"underline": "single"}, } +# See https://gitlab.gnome.org/GNOME/glib/-/blob/main/glib/gmarkup.c +# Line 629, 2204 +XML_ENTITIES = ( + ("<", "<"), + (">", ">"), + ("&", "&"), + ("\"", """), + ("'", "'") +) # Temporary handler @@ -223,28 +235,47 @@ class MarkupText(LabelledString): f"{validate_error}" ) - def parse(self) -> None: - self.global_attr_dict = self.get_global_attr_dict() - self.tag_pairs_from_markup = self.get_tag_pairs_from_markup() - self.tag_spans = self.get_tag_spans() - self.items_from_markup = self.get_items_from_markup() - self.specified_items = self.get_specified_items() - super().parse() + #def parse(self) -> None: + # #self.global_attr_dict = self.get_global_attr_dict() + # #self.items_from_markup = self.get_items_from_markup() + # #self.tag_spans = self.get_tag_spans() + # ##self.items_from_markup = self.get_items_from_markup() + # #self.specified_items = self.get_specified_items() + # super().parse() + + #@property + #def sort_labelled_submobs(self) -> bool: + # return True # Toolkits @staticmethod - def get_attr_dict_str(attr_dict: dict[str, str]) -> str: - return " ".join([ + def get_tag_str( + attr_dict: dict[str, str], escape_color_keys: bool, is_begin_tag: bool + ) -> str: + if not is_begin_tag: + return "" + if escape_color_keys: + converted_attr_dict = {} + for key, val in attr_dict.items(): + substitute_key = MARKUP_COLOR_KEYS_DICT.get(key.lower(), None) + if substitute_key is None: + converted_attr_dict[key] = val + elif substitute_key: + converted_attr_dict[key] = "black" + else: + converted_attr_dict[key] = "black" + else: + converted_attr_dict = attr_dict.copy() + result = " ".join([ f"{key}='{val}'" - for key, val in attr_dict.items() + for key, val in converted_attr_dict.items() ]) - - # Parsing + return f"" def get_global_attr_dict(self) -> dict[str, str]: result = { - "foreground": self.int_to_hex(self.base_color_int), + "foreground": self.base_color_hex, "font_family": self.font, "font_style": self.slant, "font_weight": self.weight, @@ -263,60 +294,227 @@ class MarkupText(LabelledString): result["line_height"] = str(((line_spacing_scale) + 1) * 0.6) return result - def get_tag_pairs_from_markup( - self - ) -> list[tuple[Span, Span, dict[str, str]]]: + # Parsing + + #def parse(self) -> None: + # self.bracket_content_spans, self.command_repl_items \ + # = self.get_items_from_markup() + # #self.bracket_content_spans = [ + # # span for span, _ in items_from_markup + # #] + # #specified_items = self.get_specified_items() + # #self.command_repl_items = self.get_command_repl_items() + # #self.specified_spans = self.remove_redundancies([ + # # span for span, _ in specified_items + # #]) + # #self.label_span_list = self.get_label_span_list() + # #self.predefined_items = [ + # # (self.full_span, self.get_global_attr_dict()), + # # (self.full_span, self.global_config), + # # *specified_items + # #] + + #def parse(self) -> None: # TODO: type + # if not self.is_markup: + # return [], [], [ + # (span, (escaped, escaped)) + # for char, escaped in XML_ENTITIES + # for span in self.find_spans(re.escape(char)) + # ] + + #self.entity_spans = self.find_spans(r"&[\s\S]*?;") + + #tag_spans = [span for span, _ in command_repl_items] + #begin_tag_spans = [ + # begin_tag_span for begin_tag_span, _, _ in markup_tag_items + #] + #end_tag_spans = [ + # end_tag_span for _, end_tag_span, _ in markup_tag_items + #] + #tag_spans = self.chain(begin_tag_spans, end_tag_spans) + #command_repl_items = [ + # (tag_span, "") for tag_span in tag_spans + #] + #self.chain( + # [ + # (begin_tag_span, ( + # f"", + # f"" + # )) + # for begin_tag_span, _, attr_dict in markup_tag_items + # ], + # [ + # (end_tag_span, ("", "")) + # for _, end_tag_span, _ in markup_tag_items + # ] + #) + #self.piece_spans, self.piece_levels = self.init_piece_items( + # begin_tag_spans, end_tag_spans, self.find_spans(r"&[\s\S]*?;") + #) + #command_repl_items.extend([ + # (span, (self.get_substr(span), self.get_substr(span))) + # for span in self.find_spans(r"&[\s\S]*?;") + #]) + # Needed in plain text + + #specified_items = self.chain( + # [ + # ((span_begin, span_end), attr_dict) + # for (_, span_begin), (span_end, _), attr_dict + # in markup_tag_items + # ], + # self.get_specified_items() + #) + #specified_spans = self.remove_redundancies([ + # span for span, _ in specified_items + #]) + #specified_items = [] + #for span, attr_dict in all_specified_items: + # for + # adjusted_span, _, _ = self.adjust_span(span, align_level=True) + # if adjusted_span[0] > adjusted_span[1]: + # continue + # specified_items.append(adjusted_span, attr_dict) + + + #predefined_items = [ + # (self.full_span, self.get_global_attr_dict()), + # (self.full_span, self.global_config), + # *split_items + #] + #inserted_str_items = self.chain( + # [ + # (span, ( + # ( + # f"", + # f"" + # ), + # ("", "") + # )) + # for span, attr_dict in predefined_items + # ], + # [ + # (span, ( + # ("", f""), + # ("", ""), + # )) + # for label, span in enumerate(label_span_list) + # ] + #) + #command_repl_items = self.chain( + # [ + # (tag_span, ("", "")) for tag_span in self.tag_spans + # ], + # [ + # ((index, index), str_pair) + # for index, str_pair in self.sort_obj_pairs_by_spans(inserted_str_items) + # ] + #) + #decorated_strings = [ + # self.get_replaced_substr(self.full_span, [ + # (span, str_pair[flag]) + # for span, str_pair in command_repl_items + # ]) + # for flag in range(2) + #] + #return specified_spans, label_span_list, decorated_strings + + + + + + #if is_labelled: + # attr_dict_items = self.chain( + # [ + # (span, { + # key: + # "black" if key.lower() in MARKUP_COLOR_KEYS else val + # for key, val in attr_dict.items() + # }) + # for span, attr_dict in self.predefined_items + # ], + # [ + # (span, {"foreground": self.int_to_hex(label + 1)}) + # for label, span in enumerate(self.label_span_list) + # ] + # ) + #else: + # attr_dict_items = self.chain( + # self.predefined_items, + # [ + # (span, {}) + # for span in self.label_span_list + # ] + # ) + #return [ + # (span, ( + # f"", + # "" + # )) + # for span, attr_dict in attr_dict_items + #] + #inserted_string_pairs = [ + # (indices, str_pair) + # for indices, str_pair in self.get_inserted_string_pairs( + # is_labelled=is_labelled + # ) + # if not any( + # cmd_begin < index < cmd_end + # for index in indices + # for (cmd_begin, cmd_end), _ in self.command_repl_items + # ) + #] + #return bracket_content_spans, label_span_list, command_repl_items + + def get_entity_spans(self) -> list[Span]: if not self.is_markup: return [] + return self.find_spans(r"&[\s\S]*?;") + + def get_internal_items( + self + ) -> tuple[list[tuple[Span, Span]], list[tuple[Span, dict[str, str]]]]: + if not self.is_markup: + return [], [] tag_pattern = r"""<(/?)(\w+)\s*((\w+\s*\=\s*(['"])[\s\S]*?\5\s*)*)>""" attr_pattern = r"""(\w+)\s*\=\s*(['"])([\s\S]*?)\2""" begin_match_obj_stack = [] - match_obj_pairs = [] + markup_tag_items = [] for match_obj in re.finditer(tag_pattern, self.string): if not match_obj.group(1): begin_match_obj_stack.append(match_obj) - else: - match_obj_pairs.append( - (begin_match_obj_stack.pop(), match_obj) - ) - - result = [] - for begin_match_obj, end_match_obj in match_obj_pairs: + continue + begin_match_obj = begin_match_obj_stack.pop() tag_name = begin_match_obj.group(2) if tag_name == "span": attr_dict = { - match.group(1): match.group(3) - for match in re.finditer( + attr_match_obj.group(1): attr_match_obj.group(3) + for attr_match_obj in re.finditer( attr_pattern, begin_match_obj.group(3) ) } else: attr_dict = MARKUP_TAG_CONVERSION_DICT.get(tag_name, {}) - - result.append( - (begin_match_obj.span(), end_match_obj.span(), attr_dict) + markup_tag_items.append( + (begin_match_obj.span(), match_obj.span(), attr_dict) ) - return result - def get_tag_spans(self) -> list[Span]: - return self.chain( - (begin_tag_span, end_tag_span) - for begin_tag_span, end_tag_span, _ in self.tag_pairs_from_markup - ) - - def get_items_from_markup(self) -> list[tuple[Span, dict[str, str]]]: - return [ - ((span_begin, span_end), attr_dict) - for (_, span_begin), (span_end, _), attr_dict - in self.tag_pairs_from_markup - if span_begin < span_end + tag_span_pairs = [ + (tag_begin_span, tag_end_span) + for tag_begin_span, tag_end_span, _ in markup_tag_items ] + internal_items = [ + ((span_begin, span_end), attr_dict) + for (_, span_begin), (span_end, _), attr_dict in markup_tag_items + ] + return tag_span_pairs, internal_items - def get_specified_items(self) -> list[tuple[Span, dict[str, str]]]: - result = self.chain( - self.items_from_markup, - [ + def get_external_items(self) -> list[tuple[Span, dict[str, str]]]: + return [ + (self.full_span, self.get_global_attr_dict()), + (self.full_span, self.global_config), + *[ (span, {key: val}) for t2x_dict, key in ( (self.t2c, "foreground"), @@ -327,60 +525,37 @@ class MarkupText(LabelledString): for selector, val in t2x_dict.items() for span in self.find_spans_by_selector(selector) ], - [ + *[ (span, local_config) for selector, local_config in self.local_configs.items() for span in self.find_spans_by_selector(selector) - ], - [ - (span, {}) - for span in self.find_spans_by_selector(self.isolate) ] - ) - entity_spans = self.tag_spans.copy() - if self.is_markup: - entity_spans.extend(self.find_spans(r"&[\s\S]*?;")) - return [ - (span, attr_dict) - for span, attr_dict in result - if not any([ - entity_begin < index < entity_end - for index in span - for entity_begin, entity_end in entity_spans - ]) ] - def get_command_repl_items(self) -> list[tuple[Span, str]]: - result = [ - (tag_span, "") for tag_span in self.tag_spans - ] - if not self.is_markup: - result.extend([ - (span, escaped) - for char, escaped in ( - ("&", "&"), - (">", ">"), - ("<", "<") - ) - for span in self.find_spans(re.escape(char)) - ]) - return result + #def get_spans_from_items( + # self, specified_items: list[tuple[Span, dict[str, str]]] + #) -> list[Span]: + # return [span for span, _ in specified_items] - def get_specified_spans(self) -> list[Span]: - return self.remove_redundancies([ - span for span, _ in self.specified_items - ]) + #def get_split_items( + # self, specified_items: list[tuple[Span, dict[str, str]]] + #) -> list[tuple[Span, dict[str, str]]]: + # return [ + # (span, attr_dict) + # for specified_span, attr_dict in specified_items + # for span in self.split_span(specified_span) + # ] - def get_label_span_list(self) -> list[Span]: + def get_label_span_list(self, split_spans: list[Span]) -> list[Span]: interval_spans = sorted(self.chain( self.tag_spans, [ (index, index) - for span in self.specified_spans + for span in split_spans for index in span ] )) - text_spans = self.get_complement_spans(interval_spans, self.full_span) + text_spans = self.get_complement_spans(self.full_span, interval_spans) if self.is_markup: pattern = r"[0-9a-zA-Z]+|(?:&[\s\S]*?;|[^0-9a-zA-Z\s])+" else: @@ -390,54 +565,209 @@ class MarkupText(LabelledString): for span_begin, span_end in text_spans ]) + def get_additional_inserted_str_pairs( + self + ) -> list[tuple[Span, tuple[str, str]]]: + return [] + + def get_command_repl_items(self, is_labelled: bool) -> list[Span, str]: + result = [ + (tag_span, "") for tag_span in self.tag_spans + ] + if not self.is_markup: + result.extend([ + (span, escaped) + for char, escaped in XML_ENTITIES + for span in self.find_spans(re.escape(char)) + ]) + return result + + #def get_predefined_inserted_str_items( + # self, split_items: list[tuple[Span, dict[str, str]]] + #) -> list[tuple[Span, tuple[str, str]]]: + # predefined_items = [ + # (self.full_span, self.get_global_attr_dict()), + # (self.full_span, self.global_config), + # *split_items + # ] + # return [ + # (span, ( + # ( + # self.get_tag_str(attr_dict, escape_color_keys=False, is_begin_tag=True), + # self.get_tag_str(attr_dict, escape_color_keys=True, is_begin_tag=True) + # ), + # ( + # self.get_tag_str(attr_dict, escape_color_keys=False, is_begin_tag=False), + # self.get_tag_str(attr_dict, escape_color_keys=True, is_begin_tag=False) + # ) + # )) + # for span, attr_dict in predefined_items + # ] + + #def get_full_content_string(self, replaced_string: str) -> str: + # return replaced_string + + #def get_tag_spans(self) -> list[Span]: + # return self.chain( + # (begin_tag_span, end_tag_span) + # for begin_tag_span, end_tag_span, _ in self.items_from_markup + # ) + + #def get_items_from_markup(self) -> list[tuple[Span, dict[str, str]]]: + # return [ + # ((span_begin, span_end), attr_dict) + # for (_, span_begin), (span_end, _), attr_dict + # in self.items_from_markup + # if span_begin < span_end + # ] + + #def get_command_repl_items(self) -> list[tuple[Span, str]]: + # result = [ + # (tag_span, "") + # for tag_span in self.tag_spans + # ] + # if self.is_markup: + # result.extend([ + # (span, self.get_substr(span)) + # for span in self.find_spans(r"&[\s\S]*?;") + # ]) + # else: + # result.extend([ + # (span, escaped) + # for char, escaped in ( + # ("&", "&"), + # (">", ">"), + # ("<", "<") + # ) + # for span in self.find_spans(re.escape(char)) + # ]) + # return result + + #def get_command_spans(self) -> list[Span]: + # result = self.tag_spans.copy() + # if self.is_markup: + # result.extend(self.find_spans(r"&[\s\S]*?;")) + # else: + # result.extend(self.find_spans(r"[&<>]")) + # return result + + #@staticmethod + #def get_command_repl_dict() -> dict[str | re.Pattern, str]: + # return { + # re.compile(r"<.*>"): "", + # "&": "&", + # "<": "<", + # ">": ">" + # } + # #result = [ + # # (tag_span, "") for tag_span in self.tag_spans + # #] + # #if self.is_markup: + # # result.extend([ + # # (span, self.get_substr(span)) + # # for span in self.find_spans(r"&[\s\S]*?;") + # # ]) + # #else: + # # result.extend([ + # # (span, escaped) + # # for char, escaped in ( + # # ("&", "&"), + # # (">", ">"), + # # ("<", "<") + # # ) + # # for span in self.find_spans(re.escape(char)) + # # ]) + # #return result + #entity_spans = self.tag_spans.copy() + #if self.is_markup: + # entity_spans.extend(self.find_spans(r"&[\s\S]*?;")) + #return [ + # (span, attr_dict) + # for span, attr_dict in result + # if not self.span_cuts_at_entity(span) + # #if not any([ + # # entity_begin < index < entity_end + # # for index in span + # # for entity_begin, entity_end in entity_spans + # #]) + #] + + #def get_specified_spans(self) -> list[Span]: + # return self.remove_redundancies([ + # span for span, _ in self.specified_items + # ]) + + #def get_label_span_list(self) -> list[Span]: + # interval_spans = sorted(self.chain( + # self.tag_spans, + # [ + # (index, index) + # for span in self.specified_spans + # for index in span + # ] + # )) + # text_spans = self.get_complement_spans(interval_spans, self.full_span) + # if self.is_markup: + # pattern = r"[0-9a-zA-Z]+|(?:&[\s\S]*?;|[^0-9a-zA-Z\s])+" + # else: + # pattern = r"[0-9a-zA-Z]+|[^0-9a-zA-Z\s]+" + # return self.chain(*[ + # self.find_spans(pattern, pos=span_begin, endpos=span_end) + # for span_begin, span_end in text_spans + # ]) + + #def get_inserted_string_pairs( + # self, is_labelled: bool + #) -> list[tuple[Span, tuple[str, str]]]: + # #predefined_items = [ + # # (self.full_span, self.global_attr_dict), + # # (self.full_span, self.global_config), + # # *self.specified_items + # #] + # if is_labelled: + # attr_dict_items = self.chain( + # [ + # (span, { + # key: + # "black" if key.lower() in MARKUP_COLOR_KEYS else val + # for key, val in attr_dict.items() + # }) + # for span, attr_dict in self.predefined_items + # ], + # [ + # (span, {"foreground": self.int_to_hex(label + 1)}) + # for label, span in enumerate(self.label_span_list) + # ] + # ) + # else: + # attr_dict_items = self.chain( + # self.predefined_items, + # [ + # (span, {}) + # for span in self.label_span_list + # ] + # ) + # return [ + # (span, ( + # f"", + # "" + # )) + # for span, attr_dict in attr_dict_items + # ] + def get_content(self, is_labelled: bool) -> str: - predefined_items = [ - (self.full_span, self.global_attr_dict), - (self.full_span, self.global_config), - *self.specified_items - ] - if is_labelled: - attr_dict_items = self.chain( - [ - (span, { - key: - "black" if key.lower() in MARKUP_COLOR_KEYS else val - for key, val in attr_dict.items() - }) - for span, attr_dict in predefined_items - ], - [ - (span, {"foreground": self.int_to_hex(label + 1)}) - for label, span in enumerate(self.label_span_list) - ] - ) - else: - attr_dict_items = self.chain( - predefined_items, - [ - (span, {}) - for span in self.label_span_list - ] - ) - inserted_string_pairs = [ - (span, ( - f"", - "" - )) - for span, attr_dict in attr_dict_items if attr_dict - ] - return self.get_replaced_string( - inserted_string_pairs, self.command_repl_items - ) + return self.decorated_strings[is_labelled] # Selector def get_cleaned_substr(self, span: Span) -> str: - repl_items = list(filter( - lambda repl_item: self.span_contains(span, repl_item[0]), - self.command_repl_items - )) - return self.get_replaced_substr(span, repl_items).strip() + return self.get_substr(span) # TODO: test + #repl_items = [ + # (cmd_span, repl_str) + # for cmd_span, (repl_str, _) in self.command_repl_items + # if self.span_contains(span, cmd_span) + #] + #return self.get_replaced_substr(span, repl_items).strip() # Method alias