Refactor LabelledString

This commit is contained in:
YishiMichael 2022-03-30 21:53:00 +08:00
parent 7e8b3a4c6b
commit c5ec47b0e9
No known key found for this signature in database
GPG key ID: EC615C0C5A86BC80
3 changed files with 515 additions and 626 deletions

View file

@ -156,7 +156,7 @@ class TransformMatchingTex(TransformMatchingParts):
class TransformMatchingStrings(AnimationGroup):
CONFIG = {
"key_map": dict(),
"transform_mismatches_class": None,
"transform_mismatches": False,
}
def __init__(self,
@ -168,42 +168,53 @@ class TransformMatchingStrings(AnimationGroup):
assert isinstance(source_mobject, LabelledString)
assert isinstance(target_mobject, LabelledString)
anims = []
rest_source_submobs = source_mobject.submobjects.copy()
rest_target_submobs = target_mobject.submobjects.copy()
rest_source_indices = list(range(len(source_mobject.submobjects)))
rest_target_indices = list(range(len(target_mobject.submobjects)))
def add_anims_from(anim_class, func, source_attrs, target_attrs=None):
if target_attrs is None:
target_attrs = source_attrs.copy()
for source_attr, target_attr in zip(source_attrs, target_attrs):
source_parts = func(source_mobject, source_attr)
target_parts = func(target_mobject, target_attr)
filtered_source_parts = [
submob_part for submob_part in source_parts
if all([
submob in rest_source_submobs
for submob in submob_part
])
]
filtered_target_parts = [
submob_part for submob_part in target_parts
if all([
submob in rest_target_submobs
for submob in submob_part
])
]
if not (filtered_source_parts and filtered_target_parts):
return
anims.append(anim_class(
VGroup(*filtered_source_parts),
VGroup(*filtered_target_parts),
**kwargs
def add_anims_from(anim_class, func, source_args, target_args=None):
if target_args is None:
target_args = source_args.copy()
for source_arg, target_arg in zip(source_args, target_args):
source_parts = func(source_mobject, source_arg)
target_parts = func(target_mobject, target_arg)
source_indices_lists = source_mobject.indices_lists_of_parts(
source_parts
)
target_indices_lists = target_mobject.indices_lists_of_parts(
target_parts
)
filtered_source_indices_lists = list(filter(
lambda indices_list: all([
index in rest_source_indices
for index in indices_list
]), source_indices_lists
))
for submob in it.chain(*filtered_source_parts):
rest_source_submobs.remove(submob)
for submob in it.chain(*filtered_target_parts):
rest_target_submobs.remove(submob)
filtered_target_indices_lists = list(filter(
lambda indices_list: all([
index in rest_target_indices
for index in indices_list
]), target_indices_lists
))
if not all([
filtered_source_indices_lists,
filtered_target_indices_lists
]):
return
anims.append(anim_class(source_parts, target_parts, **kwargs))
for index in it.chain(*filtered_source_indices_lists):
rest_source_indices.remove(index)
for index in it.chain(*filtered_target_indices_lists):
rest_target_indices.remove(index)
def get_submobs_from_keys(mobject, keys):
def get_common_substrs(func):
result = sorted(list(
set(func(source_mobject)).intersection(func(target_mobject))
), key=len, reverse=True)
if "" in result:
result.remove("")
return result
def get_parts_from_keys(mobject, keys):
if not isinstance(keys, tuple):
keys = (keys,)
indices = []
@ -220,55 +231,50 @@ class TransformMatchingStrings(AnimationGroup):
else:
raise TypeError(key)
return VGroup(VGroup(*[
mobject[i] for i in remove_list_redundancies(indices)
mobject[index] for index in remove_list_redundancies(indices)
]))
add_anims_from(
ReplacementTransform, get_submobs_from_keys,
ReplacementTransform, get_parts_from_keys,
self.key_map.keys(), self.key_map.values()
)
common_specified_substrings = sorted(list(
set(source_mobject.get_specified_substrings()).intersection(
target_mobject.get_specified_substrings()
)
), key=len, reverse=True)
if "" in common_specified_substrings:
common_specified_substrings.remove("")
add_anims_from(
FadeTransformPieces,
LabelledString.get_parts_by_string,
common_specified_substrings
get_common_substrs(
lambda mobject: mobject.specified_substrings
)
)
common_submob_strings = {
source_submob.get_string() for source_submob in source_mobject
}.intersection({
target_submob.get_string() for target_submob in target_mobject
})
add_anims_from(
FadeTransformPieces,
lambda mobject, attr: VGroup(*[
VGroup(mob) for mob in mobject
if mob.get_string() == attr
]),
common_submob_strings
LabelledString.get_parts_by_group_substr,
get_common_substrs(
lambda mobject: mobject.group_substrs
)
)
if self.transform_mismatches_class is not None:
anims.append(self.transform_mismatches_class(
fade_source = VGroup(*[
source_mobject[index]
for index in rest_source_indices
])
fade_target = VGroup(*[
target_mobject[index]
for index in rest_target_indices
])
if self.transform_mismatches:
anims.append(ReplacementTransform(
fade_source,
fade_target,
**kwargs
))
else:
anims.append(FadeOutToPoint(
VGroup(*rest_source_submobs),
fade_source,
target_mobject.get_center(),
**kwargs
))
anims.append(FadeInFromPoint(
VGroup(*rest_target_submobs),
fade_target,
source_mobject.get_center(),
**kwargs
))

File diff suppressed because it is too large Load diff

View file

@ -234,6 +234,14 @@ class MarkupText(LabelledString):
pango_width=pango_width
)
def parse(self) -> None:
self.global_items_from_config = self.get_global_items_from_config()
self.tag_items_from_markup = self.get_tag_items_from_markup()
self.local_items_from_markup = self.get_local_items_from_markup()
self.local_items_from_config = self.get_local_items_from_config()
self.predefined_items = self.get_predefined_items()
super().parse()
# Toolkits
@staticmethod
@ -251,6 +259,19 @@ class MarkupText(LabelledString):
def get_end_tag_str() -> str:
return "</span>"
@staticmethod
def rgb_int_to_hex(rgb_int: int) -> str:
return "#{:06x}".format(rgb_int).upper()
@staticmethod
def get_begin_color_command_str(rgb_int: int):
color_hex = MarkupText.rgb_int_to_hex(rgb_int)
return MarkupText.get_begin_tag_str({"foreground": color_hex})
@staticmethod
def get_end_color_command_str() -> str:
return MarkupText.get_end_tag_str()
@staticmethod
def convert_attr_key(key: str) -> str:
return SPAN_ATTR_KEY_CONVERSION[key.lower()]
@ -269,7 +290,7 @@ class MarkupText(LabelledString):
if span[0] >= span[1]:
continue
region_indices = [
MarkupText.find_region_index(index, index_seq)
MarkupText.find_region_index(index_seq, index)
for index in span
]
for flag in (1, 0):
@ -289,19 +310,43 @@ class MarkupText(LabelledString):
MarkupText.get_neighbouring_pairs(index_seq), attr_dict_list[:-1]
))
@staticmethod
def get_begin_color_command_str(r: int, g: int, b: int) -> str:
color_hex = "#{:02x}{:02x}{:02x}".format(r, g, b).upper()
return MarkupText.get_begin_tag_str({"foreground": color_hex})
def find_spans_by_word_or_span(
self, word_or_span: str | Span
) -> list[Span]:
if isinstance(word_or_span, tuple):
return [word_or_span]
return self.find_spans(re.escape(word_or_span))
@staticmethod
def get_end_color_command_str() -> str:
return MarkupText.get_end_tag_str()
# Pre-parsing
# Parser
def get_global_items_from_config(self) -> list[str, str]:
global_attr_dict = {
"line_height": (
(self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1
) * 0.6,
"font_family": self.font or get_customization()["style"]["font"],
"font_size": self.font_size * 1024,
"font_style": self.slant,
"font_weight": self.weight
}
global_attr_dict = {
k: v
for k, v in global_attr_dict.items()
if v is not None
}
result = list(it.chain(
global_attr_dict.items(),
self.global_config.items()
))
return [
(
self.convert_attr_key(key),
self.convert_attr_val(val)
)
for key, val in result
]
@property
def tag_items_from_markup(
def get_tag_items_from_markup(
self
) -> list[tuple[Span, Span, dict[str, str]]]:
if not self.is_markup:
@ -349,36 +394,7 @@ class MarkupText(LabelledString):
)
return result
@property
def global_attr_items_from_config(self) -> list[str, str]:
global_attr_dict = {
"line_height": (
(self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1
) * 0.6,
"font_family": self.font or get_customization()["style"]["font"],
"font_size": self.font_size * 1024,
"font_style": self.slant,
"font_weight": self.weight
}
global_attr_dict = {
k: v
for k, v in global_attr_dict.items()
if v is not None
}
result = list(it.chain(
global_attr_dict.items(),
self.global_config.items()
))
return [
(
self.convert_attr_key(key),
self.convert_attr_val(val)
)
for key, val in result
]
@property
def local_attr_items_from_markup(self) -> list[tuple[Span, str, str]]:
def get_local_items_from_markup(self) -> list[tuple[Span, str, str]]:
return sorted([
(
(begin_tag_span[0], end_tag_span[1]),
@ -390,8 +406,7 @@ class MarkupText(LabelledString):
for key, val in attr_dict.items()
])
@property
def local_attr_items_from_config(self) -> list[tuple[Span, str, str]]:
def get_local_items_from_config(self) -> list[tuple[Span, str, str]]:
result = [
(text_span, key, val)
for t2x_dict, key in (
@ -417,88 +432,19 @@ class MarkupText(LabelledString):
for text_span, key, val in result
]
def find_spans_by_word_or_span(
self, word_or_span: str | Span
) -> list[Span]:
if isinstance(word_or_span, tuple):
return [word_or_span]
return self.find_spans(re.escape(word_or_span))
#@property
#def skipped_spans(self) -> list[Span]:
# return [
# match_obj.span()
# for match_obj in re.finditer(r"\s+", self.string)
# ]
#@property
#def additional_substrings(self) -> list[str]:
# return self.get_substrs_to_isolate(self.isolate)
@property
def internal_specified_spans(self) -> list[Span]:
return [
markup_span
for markup_span, _, _ in self.local_attr_items_from_markup
]
@property
def label_span_list(self) -> list[Span]:
entity_spans = [span for span, _ in self.command_repl_items]
if self.is_markup:
entity_spans += self.find_spans(r"&.*?;")
breakup_indices = sorted(filter(
lambda index: not any([
span[0] < index < span[1]
for span in entity_spans
]),
remove_list_redundancies(list(it.chain(*(
self.specified_spans + self.find_spans(r"\s+", r"\b")
))))
))
return list(filter(
lambda span: self.string[slice(*span)].strip(),
self.get_neighbouring_pairs(breakup_indices)
))
@property
def predefined_items(self) -> list[Span, str, str]:
def get_predefined_items(self) -> list[Span, str, str]:
return list(it.chain(
[
(self.full_span, key, val)
for key, val in self.global_attr_items_from_config
for key, val in self.global_items_from_config
],
self.local_attr_items_from_markup,
self.local_attr_items_from_config
self.local_items_from_markup,
self.local_items_from_config
))
def get_inserted_string_pairs(
self, use_plain_file: bool
) -> list[tuple[Span, tuple[str, str]]]:
attr_items = self.predefined_items
if not use_plain_file:
attr_items = [
(span, key, WHITE if key in COLOR_RELATED_KEYS else val)
for span, key, val in attr_items
] + [
(span, "foreground", "#{:06x}".format(label))
for label, span in enumerate(self.label_span_list)
]
return [
(span, (
self.get_begin_tag_str(attr_dict),
self.get_end_tag_str()
))
for span, attr_dict in self.merge_attr_items(attr_items)
]
# Parsing
#@property
#def inserted_string_pairs(self) -> list[tuple[Span, tuple[str, str]]]:
# return self.get_inserted_string_pairs(use_label=True)
@property
def command_repl_items(self) -> list[tuple[Span, str]]:
def get_command_repl_items(self) -> list[tuple[Span, str]]:
result = [
(tag_span, "")
for begin_tag, end_tag, _ in self.tag_items_from_markup
@ -516,63 +462,80 @@ class MarkupText(LabelledString):
]
return result
def remove_commands_in_plain_file(self) -> bool:
return False
def get_internal_specified_spans(self) -> list[Span]:
return [
markup_span
for markup_span, _, _ in self.local_items_from_markup
]
#@abstractmethod
#def get_command_repl_items(
# self, use_plain_file: bool
#) -> list[tuple[Span, str]]:
# return []
def get_label_span_list(self) -> list[Span]:
breakup_indices = remove_list_redundancies(list(it.chain(*it.chain(
self.space_spans,
self.find_spans(r"\b"),
self.specified_spans
))))
entity_spans = self.command_spans.copy()
if self.is_markup:
entity_spans += self.find_spans(r"&.*?;")
breakup_indices = sorted(filter(
lambda index: not any([
span[0] < index < span[1]
for span in entity_spans
]),
breakup_indices
))
return list(filter(
lambda span: self.string[slice(*span)].strip(),
self.get_neighbouring_pairs(breakup_indices)
))
@property
def has_predefined_colors(self) -> bool:
def get_inserted_string_pairs(
self, use_plain_file: bool
) -> list[tuple[Span, tuple[str, str]]]:
attr_items = self.predefined_items
if not use_plain_file:
attr_items = [
(span, key, WHITE if key in COLOR_RELATED_KEYS else val)
for span, key, val in attr_items
] + [
(span, "foreground", self.rgb_int_to_hex(label))
for label, span in enumerate(self.label_span_list)
]
return [
(span, (
self.get_begin_tag_str(attr_dict),
self.get_end_tag_str()
))
for span, attr_dict in self.merge_attr_items(attr_items)
]
def get_other_repl_items(
self, use_plain_file: bool
) -> list[tuple[Span, str]]:
return self.command_repl_items.copy()
def get_has_predefined_colors(self) -> bool:
return any([
key in COLOR_RELATED_KEYS
for _, key, _ in self.predefined_items
])
#@property
#def plain_string(self) -> str:
# return "".join([
# self.get_begin_tag_str({"foreground": self.base_color}),
# self.replace_str_by_spans(
# self.string, self.get_span_replacement_dict(
# self.get_inserted_string_pairs(use_label=False),
# self.command_repl_items
# )
# ),
# self.get_end_tag_str()
# ])
#@property
#def specified_substrings(self) -> list[str]: # TODO: clean up and merge
# return remove_list_redundancies([
# self.get_cleaned_substr(markup_span)
# for markup_span, _, _ in self.local_attr_items_from_markup
# ] + self.additional_substrings)
# Method alias
def get_parts_by_text(self, substr: str) -> VGroup:
return self.get_parts_by_string(substr)
def get_parts_by_text(self, text: str) -> VGroup:
return self.get_parts_by_string(text)
def get_part_by_text(self, substr: str, index: int = 0) -> VMobject:
return self.get_part_by_string(substr, index)
def get_part_by_text(self, text: str) -> VMobject:
return self.get_part_by_string(text)
def set_color_by_text(self, substr: str, color: ManimColor):
return self.set_color_by_string(substr, color)
def set_color_by_text(self, text: str, color: ManimColor):
return self.set_color_by_string(text, color)
def set_color_by_text_to_color_map(
self, text_to_color_map: dict[str, ManimColor]
):
return self.set_color_by_string_to_color_map(text_to_color_map)
def indices_of_part_by_text(
self, substr: str, index: int = 0
) -> list[int]:
return self.indices_of_part_by_string(substr, index)
def get_text(self) -> str:
return self.get_string()