From ab8f78f40fc68c1d8d940dc4bf835284c3355793 Mon Sep 17 00:00:00 2001 From: YishiMichael Date: Tue, 3 May 2022 23:39:37 +0800 Subject: [PATCH] [WIP] Refactor LabelledString and relevant classes --- manimlib/animation/creation.py | 5 +- .../animation/transform_matching_parts.py | 64 ++-- manimlib/mobject/svg/labelled_string.py | 303 ++++++++++-------- manimlib/mobject/svg/mtex_mobject.py | 179 +++++------ manimlib/mobject/svg/text_mobject.py | 96 +++--- 5 files changed, 348 insertions(+), 299 deletions(-) diff --git a/manimlib/animation/creation.py b/manimlib/animation/creation.py index 6ad6a9bd..42ca4bf8 100644 --- a/manimlib/animation/creation.py +++ b/manimlib/animation/creation.py @@ -213,8 +213,9 @@ class AddTextWordByWord(ShowIncreasingSubsets): def __init__(self, string_mobject, **kwargs): assert isinstance(string_mobject, LabelledString) - grouped_mobject = VGroup(*[ - part for _, part in string_mobject.get_group_part_items() + grouped_mobject = string_mobject.build_parts_from_indices_lists([ + indices_list + for _, indices_list in string_mobject.get_group_part_items() ]) digest_config(self, kwargs) if self.run_time is None: diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index 96fd95ce..afa7e1ab 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -168,73 +168,67 @@ class TransformMatchingStrings(AnimationGroup): assert isinstance(source, LabelledString) assert isinstance(target, LabelledString) anims = [] + source_indices = list(range(len(source.labels))) + target_indices = list(range(len(target.labels))) - source_submobs = [ - submob for _, submob in source.labelled_submobject_items - ] - target_submobs = [ - submob for _, submob in target.labelled_submobject_items - ] - source_indices = list(range(len(source_submobs))) - target_indices = list(range(len(target_submobs))) - - def get_filtered_indices_lists(parts, submobs, rest_indices): + def get_filtered_indices_lists(indices_lists, rest_indices): return list(filter( - lambda indices_list: all([ + lambda indices_list: all( index in rest_indices for index in indices_list - ]), - [ - [submobs.index(submob) for submob in part] - for part in parts - ] + ), + indices_lists )) - def add_anims(anim_class, parts_pairs): - for source_parts, target_parts in parts_pairs: + def add_anims(anim_class, indices_lists_pairs): + for source_indices_lists, target_indices_lists in indices_lists_pairs: source_indices_lists = get_filtered_indices_lists( - source_parts, source_submobs, source_indices + source_indices_lists, source_indices ) target_indices_lists = get_filtered_indices_lists( - target_parts, target_submobs, target_indices + target_indices_lists, target_indices ) if not source_indices_lists or not target_indices_lists: continue - anims.append(anim_class(source_parts, target_parts, **kwargs)) + anims.append(anim_class( + source.build_parts_from_indices_lists(source_indices_lists), + target.build_parts_from_indices_lists(target_indices_lists), + **kwargs + )) for index in it.chain(*source_indices_lists): source_indices.remove(index) for index in it.chain(*target_indices_lists): target_indices.remove(index) - def get_substr_to_parts_map(part_items): + def get_substr_to_indices_lists_map(part_items): result = {} - for substr, part in part_items: + for substr, indices_list in part_items: if substr not in result: result[substr] = [] - result[substr].append(part) + result[substr].append(indices_list) return result def add_anims_from(anim_class, func): - source_substr_to_parts_map = get_substr_to_parts_map(func(source)) - target_substr_to_parts_map = get_substr_to_parts_map(func(target)) + source_substr_map = get_substr_to_indices_lists_map(func(source)) + target_substr_map = get_substr_to_indices_lists_map(func(target)) + common_substrings = sorted([ + s for s in source_substr_map if s and s in target_substr_map + ], key=len, reverse=True) add_anims( anim_class, [ - ( - VGroup(*source_substr_to_parts_map[substr]), - VGroup(*target_substr_to_parts_map[substr]) - ) - for substr in sorted([ - s for s in source_substr_to_parts_map - if s and s in target_substr_to_parts_map - ], key=len, reverse=True) + (source_substr_map[substr], target_substr_map[substr]) + for substr in common_substrings ] ) add_anims( ReplacementTransform, [ - (source.select_parts(k), target.select_parts(v)) + ( + source.get_submob_indices_lists_by_selector(k), + target.get_submob_indices_lists_by_selector(v) + ) for k, v in self.key_map.items() ] ) diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index 23c285de..ef26ecf1 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -5,6 +5,7 @@ import itertools as it import re from manimlib.constants import WHITE +from manimlib.logger import log from manimlib.mobject.svg.svg_mobject import SVGMobject from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.utils.color import color_to_rgb @@ -63,6 +64,7 @@ class LabelledString(SVGMobject, ABC): (submob.label, submob) for submob in self.submobjects ] + self.labels = [label for label, _ in self.labelled_submobject_items] def get_file_path(self) -> str: return self.get_file_path_by_content(self.original_content) @@ -78,26 +80,30 @@ class LabelledString(SVGMobject, ABC): labelled_svg = SVGMobject(file_path) num_submobjects = len(self.submobjects) if num_submobjects != len(labelled_svg.submobjects): - raise ValueError( + log.warning( "Cannot align submobjects of the labelled svg " - "to the original svg" - ) - - submob_color_ints = [ - self.hex_to_int(self.color_to_hex(submob.get_fill_color())) - for submob in labelled_svg.submobjects - ] - unrecognized_color_ints = self.remove_redundancies(sorted(filter( - lambda color_int: color_int > len(self.label_span_list), - submob_color_ints - ))) - if unrecognized_color_ints: - raise ValueError( - "Unrecognized color label(s) detected: " - f"{', '.join(map(self.int_to_hex, unrecognized_color_ints))}" + "to the original svg. Skip the labelling process." ) + submob_color_ints = [0] * num_submobjects + else: + submob_color_ints = [ + self.hex_to_int(self.color_to_hex(submob.get_fill_color())) + for submob in labelled_svg.submobjects + ] + unrecognized_colors = list(filter( + lambda color_int: color_int > len(self.labelled_spans), + submob_color_ints + )) + if unrecognized_colors: + log.warning( + "Unrecognized color label(s) detected (%s, etc). " + "Skip the labelling process.", + self.int_to_hex(unrecognized_colors[0]) + ) + submob_color_ints = [0] * num_submobjects #if self.sort_labelled_submobs: + # TODO: remove this submob_indices = sorted( range(num_submobjects), key=lambda index: tuple( @@ -135,12 +141,10 @@ class LabelledString(SVGMobject, ABC): # pattern = re.compile(pattern) # return re.compile(pattern).match(self.string, **kwargs) - def find_spans(self, pattern: str | re.Pattern, **kwargs) -> list[Span]: - if isinstance(pattern, str): - pattern = re.compile(pattern) + def find_spans(self, pattern: str) -> list[Span]: return [ match_obj.span() - for match_obj in pattern.finditer(self.string, **kwargs) + for match_obj in re.finditer(pattern, self.string) ] #def find_indices(self, pattern: str | re.Pattern, **kwargs) -> list[int]: @@ -151,7 +155,18 @@ class LabelledString(SVGMobject, ABC): if isinstance(sel, str): return self.find_spans(re.escape(sel)) if isinstance(sel, re.Pattern): - return self.find_spans(sel) + result_iterator = sel.finditer(self.string) + if not sel.groups: + return [ + match_obj.span() + for match_obj in result_iterator + ] + return [ + span + for match_obj in result_iterator + for span in match_obj.regs[1:] + if span != (-1, -1) + ] if isinstance(sel, tuple) and len(sel) == 2 and all( isinstance(index, int) or index is None for index in sel @@ -225,7 +240,7 @@ class LabelledString(SVGMobject, ABC): def span_contains(span_0: Span, span_1: Span) -> bool: return span_0[0] <= span_1[0] and span_0[1] >= span_1[1] - def get_piece_items( + def get_level_items( self, tag_span_pairs: list[tuple[Span, Span]], entity_spans: list[Span] @@ -241,7 +256,7 @@ class LabelledString(SVGMobject, ABC): piece_levels = [0, *it.accumulate([tag for _, tag in tagged_items])] return piece_spans, piece_levels - def split_span(self, arbitrary_span: Span) -> list[Span]: + def split_span_by_levels(self, arbitrary_span: Span) -> list[Span]: # ignorable_indices -- # left_bracket_spans # right_bracket_spans @@ -413,10 +428,10 @@ class LabelledString(SVGMobject, ABC): @staticmethod @abstractmethod - def get_tag_str( - attr_dict: dict[str, str], escape_color_keys: bool, is_begin_tag: bool - ) -> str: - return "" + def get_tag_string_pair( + attr_dict: dict[str, str], label_hex: str | None + ) -> tuple[str, str]: + return ("", "") #def get_color_tag_str(self, rgb_int: int, is_begin_tag: bool) -> str: # return self.get_tag_str({ @@ -481,7 +496,7 @@ class LabelledString(SVGMobject, ABC): def parse(self) -> None: self.entity_spans = self.get_entity_spans() tag_span_pairs, internal_items = self.get_internal_items() - self.piece_spans, self.piece_levels = self.get_piece_items( + self.piece_spans, self.piece_levels = self.get_level_items( tag_span_pairs, self.entity_spans ) #self.tag_content_spans = [ @@ -497,26 +512,19 @@ class LabelledString(SVGMobject, ABC): for span in self.find_spans_by_selector(self.isolate) ] ) - print(f"\n{specified_items=}\n") - specified_spans = [span for span, _ in specified_items] - for span_0, span_1 in it.product(specified_spans, repeat=2): - if not span_0[0] < span_1[0] < span_0[1] < span_1[1]: - continue - raise ValueError( - "Partially overlapping substrings detected: " - f"'{self.get_substr(span_0)}' and '{self.get_substr(span_1)}'" - ) + #print(f"\n{specified_items=}\n") + #specified_spans = split_items = [ (span, attr_dict) for specified_span, attr_dict in specified_items - for span in self.split_span(specified_span) + for span in self.split_span_by_levels(specified_span) ] - print(f"\n{split_items=}\n") - split_spans = [span for span, _ in split_items] - label_span_list = self.get_label_span_list(split_spans) - if len(label_span_list) >= 16777216: - raise ValueError("Cannot handle that many substrings") + #print(f"\n{split_items=}\n") + #labelled_spans = [span for span, _ in split_items] + #labelled_spans = self.get_labelled_spans(split_spans) + #if len(labelled_spans) >= 16777216: + # raise ValueError("Cannot handle that many substrings") #content_strings = [] #for is_labelled in (False, True): @@ -549,17 +557,66 @@ class LabelledString(SVGMobject, ABC): # for flag in range(2) #] - self.specified_spans = specified_spans - self.label_span_list = label_span_list - self.original_content = self.get_full_content_string( - label_span_list, split_items, is_labelled=False + command_repl_items = self.get_command_repl_items() + + #full_content_strings = {} + #for is_labelled in (False, True): + # inserted_str_pairs = [ + # (span, self.get_tag_string_pair( + # attr_dict, + # rgb_hex=self.int_to_hex(label + 1) if is_labelled else None + # )) + # for label, (span, attr_dict) in enumerate(split_items) + # ] + # repl_items = self.chain( + # command_repl_items, + # [ + # ((index, index), inserted_str) + # for index, inserted_str + # in self.sort_obj_pairs_by_spans(inserted_str_pairs) + # ] + # ) + # content_string = self.get_replaced_substr( + # self.full_span, repl_items + # ) + # full_content_string = self.get_full_content_string(content_string) + # #full_content_strings[is_labelled] = full_content_string + + self.specified_spans = [span for span, _ in specified_items] + self.labelled_spans = [span for span, _ in split_items] + for span_0, span_1 in it.product(self.labelled_spans, repeat=2): + if not span_0[0] < span_1[0] < span_0[1] < span_1[1]: + continue + raise ValueError( + "Partially overlapping substrings detected: " + f"'{self.get_substr(span_0)}' and '{self.get_substr(span_1)}'" + ) + + self.original_content, self.labelled_content = ( + self.get_full_content_string(self.get_replaced_substr( + self.full_span, self.chain( + command_repl_items, + [ + ((index, index), inserted_str) + for index, inserted_str in self.sort_obj_pairs_by_spans([ + (span, self.get_tag_string_pair( + attr_dict, + label_hex=self.int_to_hex(label + 1) if is_labelled else None + )) + for label, (span, attr_dict) in enumerate(split_items) + ]) + ] + ) + ), is_labelled=is_labelled) + for is_labelled in (False, True) ) - self.labelled_content = self.get_full_content_string( - label_span_list, split_items, is_labelled=True - ) - print(self.original_content) - print() - print(self.labelled_content) + + + #self.original_content = full_content_strings[False] + #self.labelled_content = full_content_strings[True] + #print(self.original_content) + #print() + #print(self.labelled_content) #self.command_repl_dict = self.get_command_repl_dict() @@ -569,8 +626,8 @@ class LabelledString(SVGMobject, ABC): ##self.specified_items = self.get_specified_items() #self.specified_spans = [] #self.check_overlapping() ####### - #self.label_span_list = [] - #if len(self.label_span_list) >= 16777216: + #self.labelled_spans = [] + #if len(self.labelled_spans) >= 16777216: # raise ValueError("Cannot handle that many substrings") @abstractmethod @@ -636,9 +693,9 @@ class LabelledString(SVGMobject, ABC): #def get_split_items(self, specified_items: list[T]) -> list[T]: # return [] - @abstractmethod - def get_label_span_list(self, split_spans: list[Span]) -> list[Span]: - return [] + #@abstractmethod + #def get_labelled_spans(self, split_spans: list[Span]) -> list[Span]: + # return [] #@abstractmethod #def get_predefined_inserted_str_items( @@ -666,7 +723,7 @@ class LabelledString(SVGMobject, ABC): # return [] #@abstractmethod - #def get_label_span_list(self) -> list[Span]: + #def get_labelled_spans(self) -> list[Span]: # return [] #def get_decorated_string( @@ -694,56 +751,19 @@ class LabelledString(SVGMobject, ABC): # repl_items.extend(self.command_repl_items) # return self.get_replaced_substr(self.full_span, repl_items) + #@abstractmethod + #def get_additional_inserted_str_pairs( + # self + #) -> list[tuple[Span, tuple[str, str]]]: + # return [] + @abstractmethod - def get_additional_inserted_str_pairs( - self - ) -> list[tuple[Span, tuple[str, str]]]: + def get_command_repl_items(self) -> list[Span, str]: return [] @abstractmethod - def get_command_repl_items(self, is_labelled: bool) -> list[Span, str]: - return [] - - def get_full_content_string( - self, - label_span_list: list[Span], - split_items: list[tuple[Span, dict[str, str]]], - is_labelled: bool - ) -> str: - label_items = [ - (span, { - "foreground": self.int_to_hex(label + 1) - } if is_labelled else {}) - for label, span in enumerate(label_span_list) - ] - inserted_str_pairs = self.chain( - self.get_additional_inserted_str_pairs(), - [ - (span, tuple( - self.get_tag_str( - attr_dict, - escape_color_keys=is_labelled and not is_label_item, - is_begin_tag=is_begin_tag - ) - for is_begin_tag in (True, False) - )) - for is_label_item, items in enumerate(( - split_items, label_items - )) - for span, attr_dict in items - ] - ) - repl_items = self.chain( - self.get_command_repl_items(is_labelled), - [ - ((index, index), inserted_str) - for index, inserted_str - in self.sort_obj_pairs_by_spans(inserted_str_pairs) - ] - ) - return self.get_replaced_substr( - self.full_span, repl_items - ) + def get_full_content_string(self, content_string: str, is_labelled: bool) -> str: + return "" #def get_content(self, is_labelled: bool) -> str: # return self.content_strings[int(is_labelled)] @@ -754,16 +774,15 @@ class LabelledString(SVGMobject, ABC): def get_cleaned_substr(self, span: Span) -> str: return "" - def get_group_part_items(self) -> list[tuple[str, VGroup]]: - if not self.labelled_submobject_items: + def get_group_part_items(self) -> list[tuple[str, list[int]]]: + if not self.labels: return [] - labels, labelled_submobjects = zip(*self.labelled_submobject_items) group_labels, labelled_submob_ranges = zip( - *self.compress_neighbours(labels) + *self.compress_neighbours(self.labels) ) ordered_spans = [ - self.label_span_list[label] if label != -1 else self.full_span + self.labelled_spans[label] if label != -1 else self.full_span for label in group_labels ] interval_spans = [ @@ -785,37 +804,67 @@ class LabelledString(SVGMobject, ABC): (ordered_spans[0][0], ordered_spans[-1][1]), interval_spans ) ] - submob_groups = VGroup(*[ - VGroup(*labelled_submobjects[slice(*submob_range)]) + submob_indices_lists = [ + list(range(*submob_range)) for submob_range in labelled_submob_ranges - ]) - return list(zip(group_substrs, submob_groups)) + ] + return list(zip(group_substrs, submob_indices_lists)) - def get_specified_part_items(self) -> list[tuple[str, VGroup]]: + def get_submob_indices_list_by_span( + self, arbitrary_span: Span + ) -> list[int]: + return [ + submob_index + for submob_index, label in enumerate(self.labels) + if label != -1 and self.span_contains( + arbitrary_span, self.labelled_spans[label] + ) + ] + + def get_specified_part_items(self) -> list[tuple[str, list[int]]]: return [ ( self.get_substr(span), - self.select_part_by_span(span) + self.get_submob_indices_list_by_span(span) ) for span in self.specified_spans ] - def select_part_by_span(self, arbitrary_span: Span) -> VGroup: - return VGroup(*[ - submob for label, submob in self.labelled_submobject_items - if label != -1 - and self.span_contains(arbitrary_span, self.label_span_list[label]) - ]) - - def select_parts(self, selector: Selector) -> VGroup: - return VGroup(*filter( - lambda part: part.submobjects, + def get_submob_indices_lists_by_selector( + self, selector: Selector + ) -> list[list[int]]: + return list(filter( + lambda indices_list: indices_list, [ - self.select_part_by_span(span) + self.get_submob_indices_list_by_span(span) for span in self.find_spans_by_selector(selector) ] )) + def build_parts_from_indices_lists( + self, submob_indices_lists: list[list[int]] + ) -> VGroup: + return VGroup(*[ + VGroup(*[ + self.labelled_submobject_items[submob_index][1] + for submob_index in indices_list + ]) + for indices_list in submob_indices_lists + ]) + + #def select_part_by_span(self, arbitrary_span: Span) -> VGroup: + # return VGroup(*[ + # self.labelled_submobject_items[submob_index] + # for submob_index in self.get_submob_indices_list_by_span( + # arbitrary_span + # ) + # ]) + + def select_parts(self, selector: Selector) -> VGroup: + return self.build_parts_from_indices_lists( + self.get_submob_indices_lists_by_selector(selector) + ) + def select_part(self, selector: Selector, index: int = 0) -> VGroup: return self.select_parts(selector)[index] diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 93e49a81..4a709271 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -31,14 +31,14 @@ if TYPE_CHECKING: SCALE_FACTOR_PER_FONT_POINT = 0.001 -TEX_COLOR_COMMANDS_DICT = { - "\\color": (1, False), - "\\textcolor": (1, False), - "\\pagecolor": (1, True), - "\\colorbox": (1, True), - "\\fcolorbox": (2, True), -} -TEX_COLOR_COMMAND_SUFFIX = "replaced" +#TEX_COLOR_COMMANDS_DICT = { +# "\\color": (1, False), +# "\\textcolor": (1, False), +# "\\pagecolor": (1, True), +# "\\colorbox": (1, True), +# "\\fcolorbox": (2, True), +#} +#TEX_COLOR_COMMAND_SUFFIX = "replaced" class MTex(LabelledString): @@ -56,7 +56,7 @@ class MTex(LabelledString): self.tex_string = tex_string super().__init__(tex_string, **kwargs) - #self.set_color_by_tex_to_color_map(self.tex_to_color_map) + self.set_color_by_tex_to_color_map(self.tex_to_color_map) self.scale(SCALE_FACTOR_PER_FONT_POINT * self.font_size) @property @@ -97,16 +97,12 @@ class MTex(LabelledString): return f"\\color[RGB]{{{r}, {g}, {b}}}" @staticmethod - def get_tag_str( - attr_dict: dict[str, str], escape_color_keys: bool, is_begin_tag: bool - ) -> str: - if escape_color_keys: - return "" - if not is_begin_tag: - return "}}" - if "foreground" not in attr_dict: - return "{{" - return "{{" + MTex.get_color_command_str(attr_dict["foreground"]) + def get_tag_string_pair( + attr_dict: dict[str, str], label_hex: str | None + ) -> tuple[str, str]: + if label_hex is None: + return ("", "") + return ("{{" + MTex.get_color_command_str(label_hex), "}}") #@staticmethod #def shrink_span(span: Span, skippable_indices: list[int]) -> Span: @@ -223,20 +219,20 @@ class MTex(LabelledString): raise ValueError("Missing '}' inserted") #tag_span_pairs = brace_span_pairs.copy() - script_entity_dict = dict(self.chain( - [ - (span_begin, span_end) - for (span_begin, _), (_, span_end) in brace_span_pairs - ], - command_spans - )) - script_additional_brace_spans = [ - (char_index + 1, script_entity_dict.get( - script_begin, script_begin + 1 - )) - for char_index, script_begin in self.find_spans(r"[_^]\s*(?=.)") - if (char_index - 1, char_index + 1) not in command_spans - ] + #script_entity_dict = dict(self.chain( + # [ + # (span_begin, span_end) + # for (span_begin, _), (_, span_end) in brace_span_pairs + # ], + # command_spans + #)) + #script_additional_brace_spans = [ + # (char_index + 1, script_entity_dict.get( + # script_begin, script_begin + 1 + # )) + # for char_index, script_begin in self.find_spans(r"[_^]\s*(?=.)") + # if (char_index - 1, char_index + 1) not in command_spans + #] #for char_index, script_begin in self.find_spans(r"[_^]\s*(?=.)"): # if (char_index - 1, char_index + 1) in command_spans: # continue @@ -246,13 +242,13 @@ class MTex(LabelledString): # ) # script_additional_brace_spans.append((char_index + 1, script_end)) - tag_span_pairs = self.chain( - brace_span_pairs, - [ - ((script_begin - 1, script_begin), (script_end, script_end)) - for script_begin, script_end in script_additional_brace_spans - ] - ) + #tag_span_pairs = self.chain( + # brace_span_pairs, + # [ + # ((script_begin - 1, script_begin), (script_end, script_end)) + # for script_begin, script_end in script_additional_brace_spans + # ] + #) brace_content_spans = [ (span_begin, span_end) @@ -268,16 +264,19 @@ class MTex(LabelledString): ]) if range_end - range_begin >= 2 ] - self.script_additional_brace_spans = script_additional_brace_spans - return tag_span_pairs, internal_items + #self.script_additional_brace_spans = script_additional_brace_spans + return brace_span_pairs, internal_items def get_external_items(self) -> list[tuple[Span, dict[str, str]]]: return [ - (span, {"foreground": self.color_to_hex(color)}) - for selector, color in self.tex_to_color_map.items() + (span, {}) + for selector in self.tex_to_color_map for span in self.find_spans_by_selector(selector) ] + #def get_label_span_list(self, split_spans: list[Span]) -> list[Span]: + # return split_spans.copy() + #def get_spans_from_items(self, specified_items: list[Span]) -> list[Span]: # return specified_items @@ -287,29 +286,30 @@ class MTex(LabelledString): # for span in specified_items # ])) - def get_label_span_list(self, split_spans: list[Span]) -> list[Span]: - return split_spans + #def get_label_span_list(self, split_spans: list[Span]) -> list[Span]: + # return split_spans - def get_additional_inserted_str_pairs( - self - ) -> list[tuple[Span, tuple[str, str]]]: - return [ - (span, ("{", "}")) - for span in self.script_additional_brace_spans - ] + #def get_additional_inserted_str_pairs( + # self + #) -> list[tuple[Span, tuple[str, str]]]: + # return [ + # (span, ("{", "}")) + # for span in self.script_additional_brace_spans + # ] - def get_command_repl_items(self, is_labelled: bool) -> list[Span, str]: - if not is_labelled: - return [] - result = [] - command_spans = self.entity_spans # TODO - for cmd_span in command_spans: - cmd_str = self.get_substr(cmd_span) - if cmd_str not in TEX_COLOR_COMMANDS_DICT: - continue - repl_str = f"{cmd_str}{TEX_COLOR_COMMAND_SUFFIX}" - result.append((cmd_span, repl_str)) - return result + def get_command_repl_items(self) -> list[Span, str]: + return [] + #if not is_labelled: + # return [] + #result = [] + #command_spans = self.entity_spans # TODO + #for cmd_span in command_spans: + # cmd_str = self.get_substr(cmd_span) + # if cmd_str not in TEX_COLOR_COMMANDS_DICT: + # continue + # repl_str = f"{cmd_str}{TEX_COLOR_COMMAND_SUFFIX}" + # result.append((cmd_span, repl_str)) + #return result #def get_predefined_inserted_str_items( # self, split_items: list[Span] @@ -558,15 +558,8 @@ class MTex(LabelledString): # for label, span in enumerate(self.label_span_list) # ] - def get_full_content_string( - self, - label_span_list: list[Span], - split_items: list[tuple[Span, dict[str, str]]], - is_labelled: bool - ) -> str: - result = super().get_full_content_string( - label_span_list, split_items, is_labelled - ) + def get_full_content_string(self, content_string: str, is_labelled: bool) -> str: + result = content_string if self.tex_environment: if isinstance(self.tex_environment, str): @@ -578,25 +571,25 @@ class MTex(LabelledString): if self.alignment: result = "\n".join([self.alignment, result]) - if is_labelled: - occurred_commands = [ - # TODO - self.get_substr(span) for span in self.entity_spans - ] - newcommand_lines = [ - "".join([ - f"\\newcommand{cmd_name}{TEX_COLOR_COMMAND_SUFFIX}", - f"[{n_braces + 1}][]", - "{", - cmd_name + "{black}" * n_braces if substitute_cmd else "", - "}" - ]) - for cmd_name, (n_braces, substitute_cmd) - in TEX_COLOR_COMMANDS_DICT.items() - if cmd_name in occurred_commands - ] - result = "\n".join([*newcommand_lines, result]) - else: + #if is_labelled: + # occurred_commands = [ + # # TODO + # self.get_substr(span) for span in self.entity_spans + # ] + # newcommand_lines = [ + # "".join([ + # f"\\newcommand{cmd_name}{TEX_COLOR_COMMAND_SUFFIX}", + # f"[{n_braces + 1}][]", + # "{", + # cmd_name + "{black}" * n_braces if substitute_cmd else "", + # "}" + # ]) + # for cmd_name, (n_braces, substitute_cmd) + # in TEX_COLOR_COMMANDS_DICT.items() + # if cmd_name in occurred_commands + # ] + # result = "\n".join([*newcommand_lines, result]) + if not is_labelled: result = "\n".join([ self.get_color_command_str(self.base_color_hex), result diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 3b07a07e..9b82ae2f 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -114,6 +114,7 @@ class MarkupText(LabelledString): "t2w": {}, "global_config": {}, "local_configs": {}, + "split_words": True, } def __init__(self, text: str, **kwargs): @@ -162,7 +163,8 @@ class MarkupText(LabelledString): self.t2s, self.t2w, self.global_config, - self.local_configs + self.local_configs, + self.split_words ) def full2short(self, config: dict) -> None: @@ -250,28 +252,26 @@ class MarkupText(LabelledString): # Toolkits @staticmethod - def get_tag_str( - attr_dict: dict[str, str], escape_color_keys: bool, is_begin_tag: bool - ) -> str: - if not is_begin_tag: - return "" - if escape_color_keys: - converted_attr_dict = {} + def get_tag_string_pair( + attr_dict: dict[str, str], label_hex: str | None + ) -> tuple[str, str]: + if label_hex is not None: + converted_attr_dict = {"foreground": label_hex} for key, val in attr_dict.items(): substitute_key = MARKUP_COLOR_KEYS_DICT.get(key.lower(), None) if substitute_key is None: converted_attr_dict[key] = val elif substitute_key: converted_attr_dict[key] = "black" - else: - converted_attr_dict[key] = "black" + #else: + # converted_attr_dict[key] = "black" else: converted_attr_dict = attr_dict.copy() - result = " ".join([ + attrs_str = " ".join([ f"{key}='{val}'" for key, val in converted_attr_dict.items() ]) - return f"" + return (f"", "") def get_global_attr_dict(self) -> dict[str, str]: result = { @@ -286,8 +286,9 @@ class MarkupText(LabelledString): if tuple(map(int, pango_version.split("."))) < (1, 50): if self.lsh is not None: log.warning( - f"Pango version {pango_version} found (< 1.50), " - "unable to set `line_height` attribute" + "Pango version %s found (< 1.50), " + "unable to set `line_height` attribute", + pango_version ) else: line_spacing_scale = self.lsh or DEFAULT_LINE_SPACING_SCALE @@ -477,8 +478,8 @@ class MarkupText(LabelledString): if not self.is_markup: return [], [] - tag_pattern = r"""<(/?)(\w+)\s*((\w+\s*\=\s*(['"])[\s\S]*?\5\s*)*)>""" - attr_pattern = r"""(\w+)\s*\=\s*(['"])([\s\S]*?)\2""" + tag_pattern = r"<(/?)(\w+)\s*((\w+\s*\=\s*(['\x22])[\s\S]*?\5\s*)*)>" + attr_pattern = r"(\w+)\s*\=\s*(['\x22])([\s\S]*?)\2" begin_match_obj_stack = [] markup_tag_items = [] for match_obj in re.finditer(tag_pattern, self.string): @@ -511,7 +512,7 @@ class MarkupText(LabelledString): return tag_span_pairs, internal_items def get_external_items(self) -> list[tuple[Span, dict[str, str]]]: - return [ + result = [ (self.full_span, self.get_global_attr_dict()), (self.full_span, self.global_config), *[ @@ -531,6 +532,17 @@ class MarkupText(LabelledString): for span in self.find_spans_by_selector(selector) ] ] + if self.split_words: + # For backward compatibility + result.extend([ + (span, {}) + for span in self.find_spans(r"[a-zA-Z]+") + for pattern in (r"[a-zA-Z]+", r"\S+") + ]) + return result + + + #def get_label_span_list(self, split_spans: list[Span]) -> list[Span]: #def get_spans_from_items( # self, specified_items: list[tuple[Span, dict[str, str]]] @@ -546,31 +558,31 @@ class MarkupText(LabelledString): # for span in self.split_span(specified_span) # ] - def get_label_span_list(self, split_spans: list[Span]) -> list[Span]: - interval_spans = sorted(self.chain( - self.tag_spans, - [ - (index, index) - for span in split_spans - for index in span - ] - )) - text_spans = self.get_complement_spans(self.full_span, interval_spans) - if self.is_markup: - pattern = r"[0-9a-zA-Z]+|(?:&[\s\S]*?;|[^0-9a-zA-Z\s])+" - else: - pattern = r"[0-9a-zA-Z]+|[^0-9a-zA-Z\s]+" - return self.chain(*[ - self.find_spans(pattern, pos=span_begin, endpos=span_end) - for span_begin, span_end in text_spans - ]) + #def get_label_span_list(self, split_spans: list[Span]) -> list[Span]: + # interval_spans = sorted(self.chain( + # self.tag_spans, + # [ + # (index, index) + # for span in split_spans + # for index in span + # ] + # )) + # text_spans = self.get_complement_spans(self.full_span, interval_spans) + # if self.is_markup: + # pattern = r"[0-9a-zA-Z]+|(?:&[\s\S]*?;|[^0-9a-zA-Z\s])+" + # else: + # pattern = r"[0-9a-zA-Z]+|[^0-9a-zA-Z\s]+" + # return self.chain(*[ + # self.find_spans(pattern, pos=span_begin, endpos=span_end) + # for span_begin, span_end in text_spans + # ]) - def get_additional_inserted_str_pairs( - self - ) -> list[tuple[Span, tuple[str, str]]]: - return [] + #def get_additional_inserted_str_pairs( + # self + #) -> list[tuple[Span, tuple[str, str]]]: + # return [] - def get_command_repl_items(self, is_labelled: bool) -> list[Span, str]: + def get_command_repl_items(self) -> list[Span, str]: result = [ (tag_span, "") for tag_span in self.tag_spans ] @@ -755,8 +767,8 @@ class MarkupText(LabelledString): # for span, attr_dict in attr_dict_items # ] - def get_content(self, is_labelled: bool) -> str: - return self.decorated_strings[is_labelled] + def get_full_content_string(self, content_string: str, is_labelled: bool) -> str: + return content_string # Selector