Merge pull request #1780 from YishiMichael/master

Add support for `substring` and `case_sensitive` parameters
This commit is contained in:
鹤翔万里 2022-04-07 09:58:39 +08:00 committed by GitHub
commit e9bf13882e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 81 additions and 49 deletions

View file

@ -113,26 +113,31 @@ class LabelledString(_StringSVG):
def get_substr(self, span: Span) -> str:
return self.string[slice(*span)]
def finditer(self, pattern, **kwargs):
return re.compile(pattern).finditer(self.string, **kwargs)
def finditer(
self, pattern: str, flags: int = 0, **kwargs
) -> Iterable[re.Match]:
return re.compile(pattern, flags).finditer(self.string, **kwargs)
def search(self, pattern, **kwargs):
return re.compile(pattern).search(self.string, **kwargs)
def search(self, pattern: str, flags: int = 0, **kwargs) -> re.Match:
return re.compile(pattern, flags).search(self.string, **kwargs)
def match(self, pattern, **kwargs):
return re.compile(pattern).match(self.string, **kwargs)
def match(self, pattern: str, flags: int = 0, **kwargs) -> re.Match:
return re.compile(pattern, flags).match(self.string, **kwargs)
def find_spans(self, pattern: str) -> list[Span]:
return [match_obj.span() for match_obj in self.finditer(pattern)]
def find_spans(self, pattern: str, **kwargs) -> list[Span]:
return [
match_obj.span()
for match_obj in self.finditer(pattern, **kwargs)
]
def find_substr(self, substr: str) -> list[Span]:
def find_substr(self, substr: str, **kwargs) -> list[Span]:
if not substr:
return []
return self.find_spans(re.escape(substr))
return self.find_spans(re.escape(substr), **kwargs)
def find_substrs(self, substrs: list[str]) -> list[Span]:
def find_substrs(self, substrs: list[str], **kwargs) -> list[Span]:
return list(it.chain(*[
self.find_substr(substr)
self.find_substr(substr, **kwargs)
for substr in remove_list_redundancies(substrs)
]))
@ -434,17 +439,27 @@ class LabelledString(_StringSVG):
# Selector
def find_span_components(self, custom_span: Span) -> list[Span]:
def find_span_components(
self, custom_span: Span, substring: bool = True
) -> list[Span]:
shrinked_span = self.shrink_span(custom_span)
if shrinked_span[0] >= shrinked_span[1]:
return []
indices = remove_list_redundancies(list(it.chain(
self.full_span,
*self.label_span_list
)))
span_begin = self.take_nearest_value(indices, shrinked_span[0], 0)
span_end = self.take_nearest_value(indices, shrinked_span[1] - 1, 1)
if substring:
indices = remove_list_redundancies(list(it.chain(
self.full_span,
*self.label_span_list
)))
span_begin = self.take_nearest_value(
indices, shrinked_span[0], 0
)
span_end = self.take_nearest_value(
indices, shrinked_span[1] - 1, 1
)
else:
span_begin, span_end = shrinked_span
span_choices = sorted(filter(
lambda span: self.span_contains((span_begin, span_end), span),
self.label_span_list
@ -462,12 +477,14 @@ class LabelledString(_StringSVG):
span_begin = next_begin
return result
def get_parts_by_custom_span(self, custom_span: Span) -> VGroup:
def get_parts_by_custom_span(self, custom_span: Span, **kwargs) -> VGroup:
labels = [
label for label, span in enumerate(self.label_span_list)
if any([
self.span_contains(span_component, span)
for span_component in self.find_span_components(custom_span)
for span_component in self.find_span_components(
custom_span, **kwargs
)
])
]
return VGroup(*filter(
@ -475,10 +492,15 @@ class LabelledString(_StringSVG):
self.submobjects
))
def get_parts_by_string(self, substr: str) -> VGroup:
def get_parts_by_string(
self, substr: str, case_sensitive: bool = True, **kwargs
) -> VGroup:
flags = 0
if not case_sensitive:
flags |= re.I
return VGroup(*[
self.get_parts_by_custom_span(span)
for span in self.find_substr(substr)
self.get_parts_by_custom_span(span, **kwargs)
for span in self.find_substr(substr, flags=flags)
])
def get_parts_by_group_substr(self, substr: str) -> VGroup:
@ -490,18 +512,20 @@ class LabelledString(_StringSVG):
if group_substr == substr
])
def get_part_by_string(self, substr: str, index : int = 0) -> VMobject:
return self.get_parts_by_string(substr)[index]
def get_part_by_string(
self, substr: str, index: int = 0, **kwargs
) -> VMobject:
return self.get_parts_by_string(substr, **kwargs)[index]
def set_color_by_string(self, substr: str, color: ManimColor):
self.get_parts_by_string(substr).set_color(color)
def set_color_by_string(self, substr: str, color: ManimColor, **kwargs):
self.get_parts_by_string(substr, **kwargs).set_color(color)
return self
def set_color_by_string_to_color_map(
self, string_to_color_map: dict[str, ManimColor]
self, string_to_color_map: dict[str, ManimColor], **kwargs
):
for substr, color in string_to_color_map.items():
self.set_color_by_string(substr, color)
self.set_color_by_string(substr, color, **kwargs)
return self
def indices_of_part(self, part: Iterable[VMobject]) -> list[int]:

View file

@ -314,19 +314,21 @@ class MTex(LabelledString):
# Method alias
def get_parts_by_tex(self, tex: str) -> VGroup:
return self.get_parts_by_string(tex)
def get_parts_by_tex(self, tex: str, **kwargs) -> VGroup:
return self.get_parts_by_string(tex, **kwargs)
def get_part_by_tex(self, tex: str) -> VMobject:
return self.get_part_by_string(tex)
def get_part_by_tex(self, tex: str, **kwargs) -> VMobject:
return self.get_part_by_string(tex, **kwargs)
def set_color_by_tex(self, tex: str, color: ManimColor):
return self.set_color_by_string(tex, color)
def set_color_by_tex(self, tex: str, color: ManimColor, **kwargs):
return self.set_color_by_string(tex, color, **kwargs)
def set_color_by_tex_to_color_map(
self, tex_to_color_map: dict[str, ManimColor]
self, tex_to_color_map: dict[str, ManimColor], **kwargs
):
return self.set_color_by_string_to_color_map(tex_to_color_map)
return self.set_color_by_string_to_color_map(
tex_to_color_map, **kwargs
)
def get_tex(self) -> str:
return self.get_string()

View file

@ -315,9 +315,13 @@ class MarkupText(LabelledString):
return self.find_substr(substr_or_span)
span = tuple([
(index if index >= 0 else index + self.string_len)
if index is not None else substitute
for index, substitute in zip(substr_or_span, self.full_span)
(
min(index, self.string_len)
if index >= 0
else max(index + self.string_len, 0)
)
if index is not None else default_index
for index, default_index in zip(substr_or_span, self.full_span)
])
if span[0] >= span[1]:
return []
@ -517,19 +521,21 @@ class MarkupText(LabelledString):
# Method alias
def get_parts_by_text(self, text: str) -> VGroup:
return self.get_parts_by_string(text)
def get_parts_by_text(self, text: str, **kwargs) -> VGroup:
return self.get_parts_by_string(text, **kwargs)
def get_part_by_text(self, text: str) -> VMobject:
return self.get_part_by_string(text)
def get_part_by_text(self, text: str, **kwargs) -> VMobject:
return self.get_part_by_string(text, **kwargs)
def set_color_by_text(self, text: str, color: ManimColor):
return self.set_color_by_string(text, color)
def set_color_by_text(self, text: str, color: ManimColor, **kwargs):
return self.set_color_by_string(text, color, **kwargs)
def set_color_by_text_to_color_map(
self, text_to_color_map: dict[str, ManimColor]
self, text_to_color_map: dict[str, ManimColor], **kwargs
):
return self.set_color_by_string_to_color_map(text_to_color_map)
return self.set_color_by_string_to_color_map(
text_to_color_map, **kwargs
)
def get_text(self) -> str:
return self.get_string()