Merge pull request #1780 from YishiMichael/master

Add support for `substring` and `case_sensitive` parameters
This commit is contained in:
鹤翔万里 2022-04-07 09:58:39 +08:00 committed by GitHub
commit e9bf13882e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 81 additions and 49 deletions

View file

@ -113,26 +113,31 @@ class LabelledString(_StringSVG):
def get_substr(self, span: Span) -> str: def get_substr(self, span: Span) -> str:
return self.string[slice(*span)] return self.string[slice(*span)]
def finditer(self, pattern, **kwargs): def finditer(
return re.compile(pattern).finditer(self.string, **kwargs) self, pattern: str, flags: int = 0, **kwargs
) -> Iterable[re.Match]:
return re.compile(pattern, flags).finditer(self.string, **kwargs)
def search(self, pattern, **kwargs): def search(self, pattern: str, flags: int = 0, **kwargs) -> re.Match:
return re.compile(pattern).search(self.string, **kwargs) return re.compile(pattern, flags).search(self.string, **kwargs)
def match(self, pattern, **kwargs): def match(self, pattern: str, flags: int = 0, **kwargs) -> re.Match:
return re.compile(pattern).match(self.string, **kwargs) return re.compile(pattern, flags).match(self.string, **kwargs)
def find_spans(self, pattern: str) -> list[Span]: def find_spans(self, pattern: str, **kwargs) -> list[Span]:
return [match_obj.span() for match_obj in self.finditer(pattern)] return [
match_obj.span()
for match_obj in self.finditer(pattern, **kwargs)
]
def find_substr(self, substr: str) -> list[Span]: def find_substr(self, substr: str, **kwargs) -> list[Span]:
if not substr: if not substr:
return [] return []
return self.find_spans(re.escape(substr)) return self.find_spans(re.escape(substr), **kwargs)
def find_substrs(self, substrs: list[str]) -> list[Span]: def find_substrs(self, substrs: list[str], **kwargs) -> list[Span]:
return list(it.chain(*[ return list(it.chain(*[
self.find_substr(substr) self.find_substr(substr, **kwargs)
for substr in remove_list_redundancies(substrs) for substr in remove_list_redundancies(substrs)
])) ]))
@ -434,17 +439,27 @@ class LabelledString(_StringSVG):
# Selector # Selector
def find_span_components(self, custom_span: Span) -> list[Span]: def find_span_components(
self, custom_span: Span, substring: bool = True
) -> list[Span]:
shrinked_span = self.shrink_span(custom_span) shrinked_span = self.shrink_span(custom_span)
if shrinked_span[0] >= shrinked_span[1]: if shrinked_span[0] >= shrinked_span[1]:
return [] return []
indices = remove_list_redundancies(list(it.chain( if substring:
self.full_span, indices = remove_list_redundancies(list(it.chain(
*self.label_span_list self.full_span,
))) *self.label_span_list
span_begin = self.take_nearest_value(indices, shrinked_span[0], 0) )))
span_end = self.take_nearest_value(indices, shrinked_span[1] - 1, 1) span_begin = self.take_nearest_value(
indices, shrinked_span[0], 0
)
span_end = self.take_nearest_value(
indices, shrinked_span[1] - 1, 1
)
else:
span_begin, span_end = shrinked_span
span_choices = sorted(filter( span_choices = sorted(filter(
lambda span: self.span_contains((span_begin, span_end), span), lambda span: self.span_contains((span_begin, span_end), span),
self.label_span_list self.label_span_list
@ -462,12 +477,14 @@ class LabelledString(_StringSVG):
span_begin = next_begin span_begin = next_begin
return result return result
def get_parts_by_custom_span(self, custom_span: Span) -> VGroup: def get_parts_by_custom_span(self, custom_span: Span, **kwargs) -> VGroup:
labels = [ labels = [
label for label, span in enumerate(self.label_span_list) label for label, span in enumerate(self.label_span_list)
if any([ if any([
self.span_contains(span_component, span) self.span_contains(span_component, span)
for span_component in self.find_span_components(custom_span) for span_component in self.find_span_components(
custom_span, **kwargs
)
]) ])
] ]
return VGroup(*filter( return VGroup(*filter(
@ -475,10 +492,15 @@ class LabelledString(_StringSVG):
self.submobjects self.submobjects
)) ))
def get_parts_by_string(self, substr: str) -> VGroup: def get_parts_by_string(
self, substr: str, case_sensitive: bool = True, **kwargs
) -> VGroup:
flags = 0
if not case_sensitive:
flags |= re.I
return VGroup(*[ return VGroup(*[
self.get_parts_by_custom_span(span) self.get_parts_by_custom_span(span, **kwargs)
for span in self.find_substr(substr) for span in self.find_substr(substr, flags=flags)
]) ])
def get_parts_by_group_substr(self, substr: str) -> VGroup: def get_parts_by_group_substr(self, substr: str) -> VGroup:
@ -490,18 +512,20 @@ class LabelledString(_StringSVG):
if group_substr == substr if group_substr == substr
]) ])
def get_part_by_string(self, substr: str, index : int = 0) -> VMobject: def get_part_by_string(
return self.get_parts_by_string(substr)[index] self, substr: str, index: int = 0, **kwargs
) -> VMobject:
return self.get_parts_by_string(substr, **kwargs)[index]
def set_color_by_string(self, substr: str, color: ManimColor): def set_color_by_string(self, substr: str, color: ManimColor, **kwargs):
self.get_parts_by_string(substr).set_color(color) self.get_parts_by_string(substr, **kwargs).set_color(color)
return self return self
def set_color_by_string_to_color_map( def set_color_by_string_to_color_map(
self, string_to_color_map: dict[str, ManimColor] self, string_to_color_map: dict[str, ManimColor], **kwargs
): ):
for substr, color in string_to_color_map.items(): for substr, color in string_to_color_map.items():
self.set_color_by_string(substr, color) self.set_color_by_string(substr, color, **kwargs)
return self return self
def indices_of_part(self, part: Iterable[VMobject]) -> list[int]: def indices_of_part(self, part: Iterable[VMobject]) -> list[int]:

View file

@ -314,19 +314,21 @@ class MTex(LabelledString):
# Method alias # Method alias
def get_parts_by_tex(self, tex: str) -> VGroup: def get_parts_by_tex(self, tex: str, **kwargs) -> VGroup:
return self.get_parts_by_string(tex) return self.get_parts_by_string(tex, **kwargs)
def get_part_by_tex(self, tex: str) -> VMobject: def get_part_by_tex(self, tex: str, **kwargs) -> VMobject:
return self.get_part_by_string(tex) return self.get_part_by_string(tex, **kwargs)
def set_color_by_tex(self, tex: str, color: ManimColor): def set_color_by_tex(self, tex: str, color: ManimColor, **kwargs):
return self.set_color_by_string(tex, color) return self.set_color_by_string(tex, color, **kwargs)
def set_color_by_tex_to_color_map( def set_color_by_tex_to_color_map(
self, tex_to_color_map: dict[str, ManimColor] self, tex_to_color_map: dict[str, ManimColor], **kwargs
): ):
return self.set_color_by_string_to_color_map(tex_to_color_map) return self.set_color_by_string_to_color_map(
tex_to_color_map, **kwargs
)
def get_tex(self) -> str: def get_tex(self) -> str:
return self.get_string() return self.get_string()

View file

@ -315,9 +315,13 @@ class MarkupText(LabelledString):
return self.find_substr(substr_or_span) return self.find_substr(substr_or_span)
span = tuple([ span = tuple([
(index if index >= 0 else index + self.string_len) (
if index is not None else substitute min(index, self.string_len)
for index, substitute in zip(substr_or_span, self.full_span) if index >= 0
else max(index + self.string_len, 0)
)
if index is not None else default_index
for index, default_index in zip(substr_or_span, self.full_span)
]) ])
if span[0] >= span[1]: if span[0] >= span[1]:
return [] return []
@ -517,19 +521,21 @@ class MarkupText(LabelledString):
# Method alias # Method alias
def get_parts_by_text(self, text: str) -> VGroup: def get_parts_by_text(self, text: str, **kwargs) -> VGroup:
return self.get_parts_by_string(text) return self.get_parts_by_string(text, **kwargs)
def get_part_by_text(self, text: str) -> VMobject: def get_part_by_text(self, text: str, **kwargs) -> VMobject:
return self.get_part_by_string(text) return self.get_part_by_string(text, **kwargs)
def set_color_by_text(self, text: str, color: ManimColor): def set_color_by_text(self, text: str, color: ManimColor, **kwargs):
return self.set_color_by_string(text, color) return self.set_color_by_string(text, color, **kwargs)
def set_color_by_text_to_color_map( def set_color_by_text_to_color_map(
self, text_to_color_map: dict[str, ManimColor] self, text_to_color_map: dict[str, ManimColor], **kwargs
): ):
return self.set_color_by_string_to_color_map(text_to_color_map) return self.set_color_by_string_to_color_map(
text_to_color_map, **kwargs
)
def get_text(self) -> str: def get_text(self) -> str:
return self.get_string() return self.get_string()