diff --git a/manimlib/mobject/svg/labelled_string.py b/manimlib/mobject/svg/labelled_string.py index 4756bd61..a2a9f889 100644 --- a/manimlib/mobject/svg/labelled_string.py +++ b/manimlib/mobject/svg/labelled_string.py @@ -113,26 +113,31 @@ class LabelledString(_StringSVG): def get_substr(self, span: Span) -> str: return self.string[slice(*span)] - def finditer(self, pattern, **kwargs): - return re.compile(pattern).finditer(self.string, **kwargs) + def finditer( + self, pattern: str, flags: int = 0, **kwargs + ) -> Iterable[re.Match]: + return re.compile(pattern, flags).finditer(self.string, **kwargs) - def search(self, pattern, **kwargs): - return re.compile(pattern).search(self.string, **kwargs) + def search(self, pattern: str, flags: int = 0, **kwargs) -> re.Match: + return re.compile(pattern, flags).search(self.string, **kwargs) - def match(self, pattern, **kwargs): - return re.compile(pattern).match(self.string, **kwargs) + def match(self, pattern: str, flags: int = 0, **kwargs) -> re.Match: + return re.compile(pattern, flags).match(self.string, **kwargs) - def find_spans(self, pattern: str) -> list[Span]: - return [match_obj.span() for match_obj in self.finditer(pattern)] + def find_spans(self, pattern: str, **kwargs) -> list[Span]: + return [ + match_obj.span() + for match_obj in self.finditer(pattern, **kwargs) + ] - def find_substr(self, substr: str) -> list[Span]: + def find_substr(self, substr: str, **kwargs) -> list[Span]: if not substr: return [] - return self.find_spans(re.escape(substr)) + return self.find_spans(re.escape(substr), **kwargs) - def find_substrs(self, substrs: list[str]) -> list[Span]: + def find_substrs(self, substrs: list[str], **kwargs) -> list[Span]: return list(it.chain(*[ - self.find_substr(substr) + self.find_substr(substr, **kwargs) for substr in remove_list_redundancies(substrs) ])) @@ -434,17 +439,27 @@ class LabelledString(_StringSVG): # Selector - def find_span_components(self, custom_span: Span) -> list[Span]: + def find_span_components( + self, custom_span: Span, substring: bool = True + ) -> list[Span]: shrinked_span = self.shrink_span(custom_span) if shrinked_span[0] >= shrinked_span[1]: return [] - indices = remove_list_redundancies(list(it.chain( - self.full_span, - *self.label_span_list - ))) - span_begin = self.take_nearest_value(indices, shrinked_span[0], 0) - span_end = self.take_nearest_value(indices, shrinked_span[1] - 1, 1) + if substring: + indices = remove_list_redundancies(list(it.chain( + self.full_span, + *self.label_span_list + ))) + span_begin = self.take_nearest_value( + indices, shrinked_span[0], 0 + ) + span_end = self.take_nearest_value( + indices, shrinked_span[1] - 1, 1 + ) + else: + span_begin, span_end = shrinked_span + span_choices = sorted(filter( lambda span: self.span_contains((span_begin, span_end), span), self.label_span_list @@ -462,12 +477,14 @@ class LabelledString(_StringSVG): span_begin = next_begin return result - def get_parts_by_custom_span(self, custom_span: Span) -> VGroup: + def get_parts_by_custom_span(self, custom_span: Span, **kwargs) -> VGroup: labels = [ label for label, span in enumerate(self.label_span_list) if any([ self.span_contains(span_component, span) - for span_component in self.find_span_components(custom_span) + for span_component in self.find_span_components( + custom_span, **kwargs + ) ]) ] return VGroup(*filter( @@ -475,10 +492,15 @@ class LabelledString(_StringSVG): self.submobjects )) - def get_parts_by_string(self, substr: str) -> VGroup: + def get_parts_by_string( + self, substr: str, case_sensitive: bool = True, **kwargs + ) -> VGroup: + flags = 0 + if not case_sensitive: + flags |= re.I return VGroup(*[ - self.get_parts_by_custom_span(span) - for span in self.find_substr(substr) + self.get_parts_by_custom_span(span, **kwargs) + for span in self.find_substr(substr, flags=flags) ]) def get_parts_by_group_substr(self, substr: str) -> VGroup: @@ -490,18 +512,20 @@ class LabelledString(_StringSVG): if group_substr == substr ]) - def get_part_by_string(self, substr: str, index : int = 0) -> VMobject: - return self.get_parts_by_string(substr)[index] + def get_part_by_string( + self, substr: str, index: int = 0, **kwargs + ) -> VMobject: + return self.get_parts_by_string(substr, **kwargs)[index] - def set_color_by_string(self, substr: str, color: ManimColor): - self.get_parts_by_string(substr).set_color(color) + def set_color_by_string(self, substr: str, color: ManimColor, **kwargs): + self.get_parts_by_string(substr, **kwargs).set_color(color) return self def set_color_by_string_to_color_map( - self, string_to_color_map: dict[str, ManimColor] + self, string_to_color_map: dict[str, ManimColor], **kwargs ): for substr, color in string_to_color_map.items(): - self.set_color_by_string(substr, color) + self.set_color_by_string(substr, color, **kwargs) return self def indices_of_part(self, part: Iterable[VMobject]) -> list[int]: diff --git a/manimlib/mobject/svg/mtex_mobject.py b/manimlib/mobject/svg/mtex_mobject.py index 23209fda..341db072 100644 --- a/manimlib/mobject/svg/mtex_mobject.py +++ b/manimlib/mobject/svg/mtex_mobject.py @@ -314,19 +314,21 @@ class MTex(LabelledString): # Method alias - def get_parts_by_tex(self, tex: str) -> VGroup: - return self.get_parts_by_string(tex) + def get_parts_by_tex(self, tex: str, **kwargs) -> VGroup: + return self.get_parts_by_string(tex, **kwargs) - def get_part_by_tex(self, tex: str) -> VMobject: - return self.get_part_by_string(tex) + def get_part_by_tex(self, tex: str, **kwargs) -> VMobject: + return self.get_part_by_string(tex, **kwargs) - def set_color_by_tex(self, tex: str, color: ManimColor): - return self.set_color_by_string(tex, color) + def set_color_by_tex(self, tex: str, color: ManimColor, **kwargs): + return self.set_color_by_string(tex, color, **kwargs) def set_color_by_tex_to_color_map( - self, tex_to_color_map: dict[str, ManimColor] + self, tex_to_color_map: dict[str, ManimColor], **kwargs ): - return self.set_color_by_string_to_color_map(tex_to_color_map) + return self.set_color_by_string_to_color_map( + tex_to_color_map, **kwargs + ) def get_tex(self) -> str: return self.get_string() diff --git a/manimlib/mobject/svg/text_mobject.py b/manimlib/mobject/svg/text_mobject.py index b8d2f259..8dbd05cc 100644 --- a/manimlib/mobject/svg/text_mobject.py +++ b/manimlib/mobject/svg/text_mobject.py @@ -315,9 +315,13 @@ class MarkupText(LabelledString): return self.find_substr(substr_or_span) span = tuple([ - (index if index >= 0 else index + self.string_len) - if index is not None else substitute - for index, substitute in zip(substr_or_span, self.full_span) + ( + min(index, self.string_len) + if index >= 0 + else max(index + self.string_len, 0) + ) + if index is not None else default_index + for index, default_index in zip(substr_or_span, self.full_span) ]) if span[0] >= span[1]: return [] @@ -517,19 +521,21 @@ class MarkupText(LabelledString): # Method alias - def get_parts_by_text(self, text: str) -> VGroup: - return self.get_parts_by_string(text) + def get_parts_by_text(self, text: str, **kwargs) -> VGroup: + return self.get_parts_by_string(text, **kwargs) - def get_part_by_text(self, text: str) -> VMobject: - return self.get_part_by_string(text) + def get_part_by_text(self, text: str, **kwargs) -> VMobject: + return self.get_part_by_string(text, **kwargs) - def set_color_by_text(self, text: str, color: ManimColor): - return self.set_color_by_string(text, color) + def set_color_by_text(self, text: str, color: ManimColor, **kwargs): + return self.set_color_by_string(text, color, **kwargs) def set_color_by_text_to_color_map( - self, text_to_color_map: dict[str, ManimColor] + self, text_to_color_map: dict[str, ManimColor], **kwargs ): - return self.set_color_by_string_to_color_map(text_to_color_map) + return self.set_color_by_string_to_color_map( + text_to_color_map, **kwargs + ) def get_text(self) -> str: return self.get_string()