diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index e3c65a49..325c2fb0 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -213,14 +213,12 @@ class TransformMatchingStrings(AnimationGroup): ], key=len, reverse=True) def get_parts_from_keys(mobject, keys): - if isinstance(keys, str): + if not isinstance(keys, list): keys = [keys] - result = VGroup() - for key in keys: - if not isinstance(key, str): - raise TypeError(key) - result.add(*mobject.get_parts_by_string(key)) - return result + return VGroup(*it.chain(*[ + mobject.select_parts(key) + for key in keys + ])) add_anims_from( ReplacementTransform, get_parts_from_keys, @@ -228,7 +226,7 @@ class TransformMatchingStrings(AnimationGroup): ) add_anims_from( FadeTransformPieces, - LabelledString.get_parts_by_string, + LabelledString.select_parts, get_common_substrs( source.specified_substrs, target.specified_substrs @@ -236,7 +234,7 @@ class TransformMatchingStrings(AnimationGroup): ) add_anims_from( FadeTransformPieces, - LabelledString.get_parts_by_group_substr, + LabelledString.select_parts_by_group_substr, get_common_substrs( source.group_substrs, target.group_substrs diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index d20b3c11..9c927e0e 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -22,6 +22,11 @@ if TYPE_CHECKING: ManimColor = Union[str, Color] Span = tuple[int, int] + Selector = Union[ + str, + re.Pattern, + tuple[Union[int, None], Union[int, None]] + ] class LabelledString(SVGMobject, ABC): @@ -52,10 +57,10 @@ class LabelledString(SVGMobject, ABC): self.post_parse() def get_file_path(self) -> str: - return self.get_file_path_(use_plain_file=True) + return self.get_file_path_(is_labelled=False) - def get_file_path_(self, use_plain_file: bool) -> str: - content = self.get_content(use_plain_file) + def get_file_path_(self, is_labelled: bool) -> str: + content = self.get_content(is_labelled) return self.get_file_path_by_content(content) @abstractmethod @@ -67,7 +72,7 @@ class LabelledString(SVGMobject, ABC): num_labels = len(self.label_span_list) if num_labels: - file_path = self.get_file_path_(use_plain_file=False) + file_path = self.get_file_path_(is_labelled=True) labelled_svg = SVGMobject(file_path) submob_color_ints = [ self.color_to_int(submob.get_fill_color()) @@ -132,37 +137,31 @@ class LabelledString(SVGMobject, ABC): def get_substr(self, span: Span) -> str: return self.string[slice(*span)] - def finditer( - self, pattern: str, flags: int = 0, **kwargs - ) -> Iterable[re.Match]: - return re.compile(pattern, flags).finditer(self.string, **kwargs) - - def search( - self, pattern: str, flags: int = 0, **kwargs - ) -> re.Match | None: - return re.compile(pattern, flags).search(self.string, **kwargs) - - def match( - self, pattern: str, flags: int = 0, **kwargs - ) -> re.Match | None: - return re.compile(pattern, flags).match(self.string, **kwargs) - - def find_spans(self, pattern: str, **kwargs) -> list[Span]: + def find_spans(self, pattern: str | re.Pattern, **kwargs) -> list[Span]: + if isinstance(pattern, str): + pattern = re.compile(pattern) return [ match_obj.span() - for match_obj in self.finditer(pattern, **kwargs) + for match_obj in pattern.finditer(self.string, **kwargs) ] - def find_substr(self, substr: str, **kwargs) -> list[Span]: - if not substr: - return [] - return self.find_spans(re.escape(substr), **kwargs) - - def find_substrs(self, substrs: list[str], **kwargs) -> list[Span]: - return list(it.chain(*[ - self.find_substr(substr, **kwargs) - for substr in remove_list_redundancies(substrs) - ])) + def find_spans_by_selector(self, selector: Selector) -> list[Span]: + if isinstance(selector, str): + result = self.find_spans(re.escape(selector)) + elif isinstance(selector, re.Pattern): + result = self.find_spans(selector) + else: + span = tuple([ + ( + min(index, self.string_len) + if index >= 0 + else max(index + self.string_len, 0) + ) + if index is not None else default_index + for index, default_index in zip(selector, self.full_span) + ]) + result = [span] + return list(filter(lambda span: span[0] < span[1], result)) @staticmethod def get_neighbouring_pairs(iterable: list) -> list[tuple]: @@ -345,7 +344,10 @@ class LabelledString(SVGMobject, ABC): spans = list(it.chain( self.internal_specified_spans, self.external_specified_spans, - self.find_substrs(self.isolate) + *[ + self.find_spans_by_selector(selector) + for selector in self.isolate + ] )) filtered_spans = list(filter( lambda span: all([ @@ -376,7 +378,7 @@ class LabelledString(SVGMobject, ABC): ) @abstractmethod - def get_content(self, use_plain_file: bool) -> str: + def get_content(self, is_labelled: bool) -> str: return "" # Post-parsing @@ -441,7 +443,7 @@ class LabelledString(SVGMobject, ABC): def get_submob_groups(self) -> list[VGroup]: return [submob_group for _, submob_group in self.group_items] - def get_parts_by_group_substr(self, substr: str) -> VGroup: + def select_parts_by_group_substr(self, substr: str) -> VGroup: return VGroup(*[ group for group_substr, group in self.group_items @@ -488,7 +490,7 @@ class LabelledString(SVGMobject, ABC): span_begin = next_begin return result - def get_part_by_custom_span(self, custom_span: Span, **kwargs) -> VGroup: + def select_part_by_span(self, custom_span: Span, **kwargs) -> VGroup: labels = [ label for label, span in enumerate(self.label_span_list) if any([ @@ -503,34 +505,28 @@ class LabelledString(SVGMobject, ABC): if label in labels ]) - def get_parts_by_string( - self, substr: str, - case_sensitive: bool = True, regex: bool = False, **kwargs - ) -> VGroup: - flags = 0 - if not case_sensitive: - flags |= re.I - pattern = substr if regex else re.escape(substr) + def select_parts(self, selector: Selector, **kwargs) -> VGroup: return VGroup(*[ - self.get_part_by_custom_span(span, **kwargs) - for span in self.find_spans(pattern, flags=flags) - if span[0] < span[1] + self.select_part_by_span(span, **kwargs) + for span in self.find_spans_by_selector(selector) ]) - def get_part_by_string( - self, substr: str, index: int = 0, **kwargs + def select_part( + self, selector: Selector, index: int = 0, **kwargs ) -> VMobject: - return self.get_parts_by_string(substr, **kwargs)[index] + return self.select_parts(selector, **kwargs)[index] - def set_color_by_string(self, substr: str, color: ManimColor, **kwargs): - self.get_parts_by_string(substr, **kwargs).set_color(color) + def set_parts_color( + self, selector: Selector, color: ManimColor, **kwargs + ): + self.select_parts(selector, **kwargs).set_color(color) return self - def set_color_by_string_to_color_map( - self, string_to_color_map: dict[str, ManimColor], **kwargs + def set_parts_color_by_dict( + self, color_map: dict[Selector, ManimColor], **kwargs ): - for substr, color in string_to_color_map.items(): - self.set_color_by_string(substr, color, **kwargs) + for selector, color in color_map.items(): + self.set_parts_color(selector, color, **kwargs) return self def get_string(self) -> str: diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 91d7675b..dad69df5 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -1,6 +1,7 @@ from __future__ import annotations import itertools as it +import re from manimlib.mobject.svg.labelled_string import LabelledString from manimlib.utils.tex_file_writing import display_during_execution @@ -18,6 +19,11 @@ if TYPE_CHECKING: ManimColor = Union[str, Color] Span = tuple[int, int] + Selector = Union[ + str, + re.Pattern, + tuple[Union[int, None], Union[int, None]] + ] SCALE_FACTOR_PER_FONT_POINT = 0.001 @@ -61,7 +67,7 @@ class MTex(LabelledString): tex_config["text_to_replace"], content ) - with display_during_execution(f"Writing \"{self.tex_string}\""): + with display_during_execution(f"Writing \"{self.string}\""): file_path = tex_to_svg_file(full_tex) return file_path @@ -93,7 +99,10 @@ class MTex(LabelledString): def get_unescaped_char_spans(self, chars: str): return sorted(filter( lambda span: span[0] - 1 not in self.backslash_indices, - self.find_substrs(list(chars)) + list(it.chain(*[ + self.find_spans(re.escape(char)) + for char in chars + ])) )) def get_brace_index_pairs(self) -> list[Span]: @@ -121,8 +130,8 @@ class MTex(LabelledString): result = [] brace_indices_dict = dict(self.brace_index_pairs) script_pattern = r"[a-zA-Z0-9]|\\[a-zA-Z]+" - for script_char_span in self.script_char_spans: - span_begin = self.match(r"\s*", pos=script_char_span[1]).end() + for char_span in self.script_char_spans: + span_begin = self.find_spans(r"\s*", pos=char_span[1])[0][1] if span_begin in brace_indices_dict.keys(): span_end = brace_indices_dict[span_begin] + 1 else: @@ -143,10 +152,10 @@ class MTex(LabelledString): def get_script_spans(self) -> list[Span]: return [ ( - self.search(r"\s*$", endpos=script_char_span[0]).start(), + self.find_spans(r"\s*", endpos=char_span[0])[-1][0], script_content_span[1] ) - for script_char_span, script_content_span in zip( + for char_span, script_content_span in zip( self.script_char_spans, self.script_content_spans ) ] @@ -174,7 +183,7 @@ class MTex(LabelledString): ")", r"(?![a-zA-Z])" ]) - for match_obj in self.finditer(pattern): + for match_obj in re.finditer(pattern, self.string): span_begin, cmd_end = match_obj.span() if span_begin not in backslash_indices: continue @@ -192,7 +201,7 @@ class MTex(LabelledString): def get_extra_entity_spans(self) -> list[Span]: return [ - self.match(r"\\([a-zA-Z]+|.)", pos=index).span() + self.find_spans(r"\\([a-zA-Z]+|.?)", pos=index)[0] for index in self.backslash_indices ] @@ -223,7 +232,10 @@ class MTex(LabelledString): return result def get_external_specified_spans(self) -> list[Span]: - return self.find_substrs(list(self.tex_to_color_map.keys())) + return list(it.chain(*[ + self.find_spans_by_selector(selector) + for selector in self.tex_to_color_map.keys() + ])) def get_label_span_list(self) -> list[Span]: result = self.script_content_spans.copy() @@ -237,10 +249,8 @@ class MTex(LabelledString): result.append(shrinked_span) return result - def get_content(self, use_plain_file: bool) -> str: - if use_plain_file: - span_repl_dict = {} - else: + def get_content(self, is_labelled: bool) -> str: + if is_labelled: extended_label_span_list = [ span if span in self.script_content_spans @@ -258,6 +268,8 @@ class MTex(LabelledString): inserted_string_pairs, self.command_repl_items ) + else: + span_repl_dict = {} result = self.get_replaced_substr(self.full_span, span_repl_dict) if self.tex_environment: @@ -269,7 +281,7 @@ class MTex(LabelledString): result = "\n".join([prefix, result, suffix]) if self.alignment: result = "\n".join([self.alignment, result]) - if use_plain_file: + if not is_labelled: result = "\n".join([ self.get_color_command_str(self.base_color_int), result @@ -303,21 +315,21 @@ class MTex(LabelledString): # Method alias - def get_parts_by_tex(self, tex: str, **kwargs) -> VGroup: - return self.get_parts_by_string(tex, **kwargs) + def get_parts_by_tex(self, selector: Selector, **kwargs) -> VGroup: + return self.select_parts(selector, **kwargs) - def get_part_by_tex(self, tex: str, **kwargs) -> VMobject: - return self.get_part_by_string(tex, **kwargs) + def get_part_by_tex(self, selector: Selector, **kwargs) -> VMobject: + return self.select_part(selector, **kwargs) - def set_color_by_tex(self, tex: str, color: ManimColor, **kwargs): - return self.set_color_by_string(tex, color, **kwargs) + def set_color_by_tex( + self, selector: Selector, color: ManimColor, **kwargs + ): + return self.set_parts_color(selector, color, **kwargs) def set_color_by_tex_to_color_map( - self, tex_to_color_map: dict[str, ManimColor], **kwargs + self, color_map: dict[Selector, ManimColor], **kwargs ): - return self.set_color_by_string_to_color_map( - tex_to_color_map, **kwargs - ) + return self.set_parts_color_by_dict(color_map, **kwargs) def get_tex(self) -> str: return self.get_string() diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index 2c076551..f79bb79a 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -34,6 +34,11 @@ if TYPE_CHECKING: ManimColor = Union[str, Color] Span = tuple[int, int] + Selector = Union[ + str, + re.Pattern, + tuple[Union[int, None], Union[int, None]] + ] TEXT_MOB_SCALE_FACTOR = 0.0076 @@ -283,25 +288,6 @@ class MarkupText(LabelledString): MarkupText.get_neighbouring_pairs(index_seq), attr_dict_list[:-1] )) - def find_substr_or_span( - self, substr_or_span: str | tuple[int | None, int | None] - ) -> list[Span]: - if isinstance(substr_or_span, str): - return self.find_substr(substr_or_span) - - span = tuple([ - ( - min(index, self.string_len) - if index >= 0 - else max(index + self.string_len, 0) - ) - if index is not None else default_index - for index, default_index in zip(substr_or_span, self.full_span) - ]) - if span[0] >= span[1]: - return [] - return [span] - # Pre-parsing def get_tag_items_from_markup( @@ -314,7 +300,7 @@ class MarkupText(LabelledString): attr_pattern = r"""(\w+)\s*\=\s*(['"])(.*?)\2""" begin_match_obj_stack = [] match_obj_pairs = [] - for match_obj in self.finditer(tag_pattern): + for match_obj in re.finditer(tag_pattern, self.string): if not match_obj.group(1): begin_match_obj_stack.append(match_obj) else: @@ -385,12 +371,12 @@ class MarkupText(LabelledString): (self.t2s, "font_style"), (self.t2w, "font_weight") ) - for substr_or_span, val in t2x_dict.items() - for span in self.find_substr_or_span(substr_or_span) + for selector, val in t2x_dict.items() + for span in self.find_spans_by_selector(selector) ] + [ (span, local_config) - for substr_or_span, local_config in self.local_configs.items() - for span in self.find_substr_or_span(substr_or_span) + for selector, local_config in self.local_configs.items() + for span in self.find_spans_by_selector(selector) ] def get_predefined_attr_dicts(self) -> list[Span, dict[str, str]]: @@ -428,7 +414,7 @@ class MarkupText(LabelledString): (">", ">"), ("<", "<") ) - for span in self.find_substr(char) + for span in self.find_spans(re.escape(char)) ] return result @@ -460,7 +446,7 @@ class MarkupText(LabelledString): self.get_neighbouring_pairs(breakup_indices) )) - def get_content(self, use_plain_file: bool) -> str: + def get_content(self, is_labelled: bool) -> str: filtered_attr_dicts = list(filter( lambda item: all([ self.is_splittable_index(index) @@ -468,18 +454,7 @@ class MarkupText(LabelledString): ]), self.predefined_attr_dicts )) - if use_plain_file: - attr_dict_items = [ - (self.full_span, { - "foreground": self.int_to_hex(self.base_color_int) - }), - *filtered_attr_dicts, - *[ - (span, {}) - for span in self.label_span_list - ] - ] - else: + if is_labelled: attr_dict_items = [ (self.full_span, {"foreground": BLACK}), *[ @@ -494,6 +469,17 @@ class MarkupText(LabelledString): for label, span in enumerate(self.label_span_list) ] ] + else: + attr_dict_items = [ + (self.full_span, { + "foreground": self.int_to_hex(self.base_color_int) + }), + *filtered_attr_dicts, + *[ + (span, {}) + for span in self.label_span_list + ] + ] inserted_string_pairs = [ (span, ( f"", @@ -508,21 +494,21 @@ class MarkupText(LabelledString): # Method alias - def get_parts_by_text(self, text: str, **kwargs) -> VGroup: - return self.get_parts_by_string(text, **kwargs) + def get_parts_by_text(self, selector: Selector, **kwargs) -> VGroup: + return self.select_parts(selector, **kwargs) - def get_part_by_text(self, text: str, **kwargs) -> VMobject: - return self.get_part_by_string(text, **kwargs) + def get_part_by_text(self, selector: Selector, **kwargs) -> VMobject: + return self.select_part(selector, **kwargs) - def set_color_by_text(self, text: str, color: ManimColor, **kwargs): - return self.set_color_by_string(text, color, **kwargs) + def set_color_by_text( + self, selector: Selector, color: ManimColor, **kwargs + ): + return self.set_parts_color(selector, color, **kwargs) def set_color_by_text_to_color_map( - self, text_to_color_map: dict[str, ManimColor], **kwargs + self, color_map: dict[Selector, ManimColor], **kwargs ): - return self.set_color_by_string_to_color_map( - text_to_color_map, **kwargs - ) + return self.set_parts_color_by_dict(color_map, **kwargs) def get_text(self) -> str: return self.get_string()