Refactor LabelledString

This commit is contained in:
YishiMichael 2022-03-30 21:53:00 +08:00
parent 7e8b3a4c6b
commit c5ec47b0e9
No known key found for this signature in database
GPG key ID: EC615C0C5A86BC80
3 changed files with 515 additions and 626 deletions

View file

@ -156,7 +156,7 @@ class TransformMatchingTex(TransformMatchingParts):
class TransformMatchingStrings(AnimationGroup): class TransformMatchingStrings(AnimationGroup):
CONFIG = { CONFIG = {
"key_map": dict(), "key_map": dict(),
"transform_mismatches_class": None, "transform_mismatches": False,
} }
def __init__(self, def __init__(self,
@ -168,42 +168,53 @@ class TransformMatchingStrings(AnimationGroup):
assert isinstance(source_mobject, LabelledString) assert isinstance(source_mobject, LabelledString)
assert isinstance(target_mobject, LabelledString) assert isinstance(target_mobject, LabelledString)
anims = [] anims = []
rest_source_submobs = source_mobject.submobjects.copy() rest_source_indices = list(range(len(source_mobject.submobjects)))
rest_target_submobs = target_mobject.submobjects.copy() rest_target_indices = list(range(len(target_mobject.submobjects)))
def add_anims_from(anim_class, func, source_attrs, target_attrs=None): def add_anims_from(anim_class, func, source_args, target_args=None):
if target_attrs is None: if target_args is None:
target_attrs = source_attrs.copy() target_args = source_args.copy()
for source_attr, target_attr in zip(source_attrs, target_attrs): for source_arg, target_arg in zip(source_args, target_args):
source_parts = func(source_mobject, source_attr) source_parts = func(source_mobject, source_arg)
target_parts = func(target_mobject, target_attr) target_parts = func(target_mobject, target_arg)
filtered_source_parts = [ source_indices_lists = source_mobject.indices_lists_of_parts(
submob_part for submob_part in source_parts source_parts
if all([ )
submob in rest_source_submobs target_indices_lists = target_mobject.indices_lists_of_parts(
for submob in submob_part target_parts
]) )
] filtered_source_indices_lists = list(filter(
filtered_target_parts = [ lambda indices_list: all([
submob_part for submob_part in target_parts index in rest_source_indices
if all([ for index in indices_list
submob in rest_target_submobs ]), source_indices_lists
for submob in submob_part
])
]
if not (filtered_source_parts and filtered_target_parts):
return
anims.append(anim_class(
VGroup(*filtered_source_parts),
VGroup(*filtered_target_parts),
**kwargs
)) ))
for submob in it.chain(*filtered_source_parts): filtered_target_indices_lists = list(filter(
rest_source_submobs.remove(submob) lambda indices_list: all([
for submob in it.chain(*filtered_target_parts): index in rest_target_indices
rest_target_submobs.remove(submob) for index in indices_list
]), target_indices_lists
))
if not all([
filtered_source_indices_lists,
filtered_target_indices_lists
]):
return
anims.append(anim_class(source_parts, target_parts, **kwargs))
for index in it.chain(*filtered_source_indices_lists):
rest_source_indices.remove(index)
for index in it.chain(*filtered_target_indices_lists):
rest_target_indices.remove(index)
def get_submobs_from_keys(mobject, keys): def get_common_substrs(func):
result = sorted(list(
set(func(source_mobject)).intersection(func(target_mobject))
), key=len, reverse=True)
if "" in result:
result.remove("")
return result
def get_parts_from_keys(mobject, keys):
if not isinstance(keys, tuple): if not isinstance(keys, tuple):
keys = (keys,) keys = (keys,)
indices = [] indices = []
@ -220,55 +231,50 @@ class TransformMatchingStrings(AnimationGroup):
else: else:
raise TypeError(key) raise TypeError(key)
return VGroup(VGroup(*[ return VGroup(VGroup(*[
mobject[i] for i in remove_list_redundancies(indices) mobject[index] for index in remove_list_redundancies(indices)
])) ]))
add_anims_from( add_anims_from(
ReplacementTransform, get_submobs_from_keys, ReplacementTransform, get_parts_from_keys,
self.key_map.keys(), self.key_map.values() self.key_map.keys(), self.key_map.values()
) )
common_specified_substrings = sorted(list(
set(source_mobject.get_specified_substrings()).intersection(
target_mobject.get_specified_substrings()
)
), key=len, reverse=True)
if "" in common_specified_substrings:
common_specified_substrings.remove("")
add_anims_from( add_anims_from(
FadeTransformPieces, FadeTransformPieces,
LabelledString.get_parts_by_string, LabelledString.get_parts_by_string,
common_specified_substrings get_common_substrs(
lambda mobject: mobject.specified_substrings
)
) )
common_submob_strings = {
source_submob.get_string() for source_submob in source_mobject
}.intersection({
target_submob.get_string() for target_submob in target_mobject
})
add_anims_from( add_anims_from(
FadeTransformPieces, FadeTransformPieces,
lambda mobject, attr: VGroup(*[ LabelledString.get_parts_by_group_substr,
VGroup(mob) for mob in mobject get_common_substrs(
if mob.get_string() == attr lambda mobject: mobject.group_substrs
]), )
common_submob_strings
) )
if self.transform_mismatches_class is not None: fade_source = VGroup(*[
anims.append(self.transform_mismatches_class( source_mobject[index]
for index in rest_source_indices
])
fade_target = VGroup(*[
target_mobject[index]
for index in rest_target_indices
])
if self.transform_mismatches:
anims.append(ReplacementTransform(
fade_source, fade_source,
fade_target, fade_target,
**kwargs **kwargs
)) ))
else: else:
anims.append(FadeOutToPoint( anims.append(FadeOutToPoint(
VGroup(*rest_source_submobs), fade_source,
target_mobject.get_center(), target_mobject.get_center(),
**kwargs **kwargs
)) ))
anims.append(FadeInFromPoint( anims.append(FadeInFromPoint(
VGroup(*rest_target_submobs), fade_target,
source_mobject.get_center(), source_mobject.get_center(),
**kwargs **kwargs
)) ))

File diff suppressed because it is too large Load diff

View file

@ -234,6 +234,14 @@ class MarkupText(LabelledString):
pango_width=pango_width pango_width=pango_width
) )
def parse(self) -> None:
self.global_items_from_config = self.get_global_items_from_config()
self.tag_items_from_markup = self.get_tag_items_from_markup()
self.local_items_from_markup = self.get_local_items_from_markup()
self.local_items_from_config = self.get_local_items_from_config()
self.predefined_items = self.get_predefined_items()
super().parse()
# Toolkits # Toolkits
@staticmethod @staticmethod
@ -251,6 +259,19 @@ class MarkupText(LabelledString):
def get_end_tag_str() -> str: def get_end_tag_str() -> str:
return "</span>" return "</span>"
@staticmethod
def rgb_int_to_hex(rgb_int: int) -> str:
return "#{:06x}".format(rgb_int).upper()
@staticmethod
def get_begin_color_command_str(rgb_int: int):
color_hex = MarkupText.rgb_int_to_hex(rgb_int)
return MarkupText.get_begin_tag_str({"foreground": color_hex})
@staticmethod
def get_end_color_command_str() -> str:
return MarkupText.get_end_tag_str()
@staticmethod @staticmethod
def convert_attr_key(key: str) -> str: def convert_attr_key(key: str) -> str:
return SPAN_ATTR_KEY_CONVERSION[key.lower()] return SPAN_ATTR_KEY_CONVERSION[key.lower()]
@ -269,7 +290,7 @@ class MarkupText(LabelledString):
if span[0] >= span[1]: if span[0] >= span[1]:
continue continue
region_indices = [ region_indices = [
MarkupText.find_region_index(index, index_seq) MarkupText.find_region_index(index_seq, index)
for index in span for index in span
] ]
for flag in (1, 0): for flag in (1, 0):
@ -289,19 +310,43 @@ class MarkupText(LabelledString):
MarkupText.get_neighbouring_pairs(index_seq), attr_dict_list[:-1] MarkupText.get_neighbouring_pairs(index_seq), attr_dict_list[:-1]
)) ))
@staticmethod def find_spans_by_word_or_span(
def get_begin_color_command_str(r: int, g: int, b: int) -> str: self, word_or_span: str | Span
color_hex = "#{:02x}{:02x}{:02x}".format(r, g, b).upper() ) -> list[Span]:
return MarkupText.get_begin_tag_str({"foreground": color_hex}) if isinstance(word_or_span, tuple):
return [word_or_span]
return self.find_spans(re.escape(word_or_span))
@staticmethod # Pre-parsing
def get_end_color_command_str() -> str:
return MarkupText.get_end_tag_str()
# Parser def get_global_items_from_config(self) -> list[str, str]:
global_attr_dict = {
"line_height": (
(self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1
) * 0.6,
"font_family": self.font or get_customization()["style"]["font"],
"font_size": self.font_size * 1024,
"font_style": self.slant,
"font_weight": self.weight
}
global_attr_dict = {
k: v
for k, v in global_attr_dict.items()
if v is not None
}
result = list(it.chain(
global_attr_dict.items(),
self.global_config.items()
))
return [
(
self.convert_attr_key(key),
self.convert_attr_val(val)
)
for key, val in result
]
@property def get_tag_items_from_markup(
def tag_items_from_markup(
self self
) -> list[tuple[Span, Span, dict[str, str]]]: ) -> list[tuple[Span, Span, dict[str, str]]]:
if not self.is_markup: if not self.is_markup:
@ -349,36 +394,7 @@ class MarkupText(LabelledString):
) )
return result return result
@property def get_local_items_from_markup(self) -> list[tuple[Span, str, str]]:
def global_attr_items_from_config(self) -> list[str, str]:
global_attr_dict = {
"line_height": (
(self.lsh or DEFAULT_LINE_SPACING_SCALE) + 1
) * 0.6,
"font_family": self.font or get_customization()["style"]["font"],
"font_size": self.font_size * 1024,
"font_style": self.slant,
"font_weight": self.weight
}
global_attr_dict = {
k: v
for k, v in global_attr_dict.items()
if v is not None
}
result = list(it.chain(
global_attr_dict.items(),
self.global_config.items()
))
return [
(
self.convert_attr_key(key),
self.convert_attr_val(val)
)
for key, val in result
]
@property
def local_attr_items_from_markup(self) -> list[tuple[Span, str, str]]:
return sorted([ return sorted([
( (
(begin_tag_span[0], end_tag_span[1]), (begin_tag_span[0], end_tag_span[1]),
@ -390,8 +406,7 @@ class MarkupText(LabelledString):
for key, val in attr_dict.items() for key, val in attr_dict.items()
]) ])
@property def get_local_items_from_config(self) -> list[tuple[Span, str, str]]:
def local_attr_items_from_config(self) -> list[tuple[Span, str, str]]:
result = [ result = [
(text_span, key, val) (text_span, key, val)
for t2x_dict, key in ( for t2x_dict, key in (
@ -417,88 +432,19 @@ class MarkupText(LabelledString):
for text_span, key, val in result for text_span, key, val in result
] ]
def find_spans_by_word_or_span( def get_predefined_items(self) -> list[Span, str, str]:
self, word_or_span: str | Span
) -> list[Span]:
if isinstance(word_or_span, tuple):
return [word_or_span]
return self.find_spans(re.escape(word_or_span))
#@property
#def skipped_spans(self) -> list[Span]:
# return [
# match_obj.span()
# for match_obj in re.finditer(r"\s+", self.string)
# ]
#@property
#def additional_substrings(self) -> list[str]:
# return self.get_substrs_to_isolate(self.isolate)
@property
def internal_specified_spans(self) -> list[Span]:
return [
markup_span
for markup_span, _, _ in self.local_attr_items_from_markup
]
@property
def label_span_list(self) -> list[Span]:
entity_spans = [span for span, _ in self.command_repl_items]
if self.is_markup:
entity_spans += self.find_spans(r"&.*?;")
breakup_indices = sorted(filter(
lambda index: not any([
span[0] < index < span[1]
for span in entity_spans
]),
remove_list_redundancies(list(it.chain(*(
self.specified_spans + self.find_spans(r"\s+", r"\b")
))))
))
return list(filter(
lambda span: self.string[slice(*span)].strip(),
self.get_neighbouring_pairs(breakup_indices)
))
@property
def predefined_items(self) -> list[Span, str, str]:
return list(it.chain( return list(it.chain(
[ [
(self.full_span, key, val) (self.full_span, key, val)
for key, val in self.global_attr_items_from_config for key, val in self.global_items_from_config
], ],
self.local_attr_items_from_markup, self.local_items_from_markup,
self.local_attr_items_from_config self.local_items_from_config
)) ))
def get_inserted_string_pairs( # Parsing
self, use_plain_file: bool
) -> list[tuple[Span, tuple[str, str]]]:
attr_items = self.predefined_items
if not use_plain_file:
attr_items = [
(span, key, WHITE if key in COLOR_RELATED_KEYS else val)
for span, key, val in attr_items
] + [
(span, "foreground", "#{:06x}".format(label))
for label, span in enumerate(self.label_span_list)
]
return [
(span, (
self.get_begin_tag_str(attr_dict),
self.get_end_tag_str()
))
for span, attr_dict in self.merge_attr_items(attr_items)
]
#@property def get_command_repl_items(self) -> list[tuple[Span, str]]:
#def inserted_string_pairs(self) -> list[tuple[Span, tuple[str, str]]]:
# return self.get_inserted_string_pairs(use_label=True)
@property
def command_repl_items(self) -> list[tuple[Span, str]]:
result = [ result = [
(tag_span, "") (tag_span, "")
for begin_tag, end_tag, _ in self.tag_items_from_markup for begin_tag, end_tag, _ in self.tag_items_from_markup
@ -516,63 +462,80 @@ class MarkupText(LabelledString):
] ]
return result return result
def remove_commands_in_plain_file(self) -> bool: def get_internal_specified_spans(self) -> list[Span]:
return False return [
markup_span
for markup_span, _, _ in self.local_items_from_markup
]
#@abstractmethod def get_label_span_list(self) -> list[Span]:
#def get_command_repl_items( breakup_indices = remove_list_redundancies(list(it.chain(*it.chain(
# self, use_plain_file: bool self.space_spans,
#) -> list[tuple[Span, str]]: self.find_spans(r"\b"),
# return [] self.specified_spans
))))
entity_spans = self.command_spans.copy()
if self.is_markup:
entity_spans += self.find_spans(r"&.*?;")
breakup_indices = sorted(filter(
lambda index: not any([
span[0] < index < span[1]
for span in entity_spans
]),
breakup_indices
))
return list(filter(
lambda span: self.string[slice(*span)].strip(),
self.get_neighbouring_pairs(breakup_indices)
))
@property def get_inserted_string_pairs(
def has_predefined_colors(self) -> bool: self, use_plain_file: bool
) -> list[tuple[Span, tuple[str, str]]]:
attr_items = self.predefined_items
if not use_plain_file:
attr_items = [
(span, key, WHITE if key in COLOR_RELATED_KEYS else val)
for span, key, val in attr_items
] + [
(span, "foreground", self.rgb_int_to_hex(label))
for label, span in enumerate(self.label_span_list)
]
return [
(span, (
self.get_begin_tag_str(attr_dict),
self.get_end_tag_str()
))
for span, attr_dict in self.merge_attr_items(attr_items)
]
def get_other_repl_items(
self, use_plain_file: bool
) -> list[tuple[Span, str]]:
return self.command_repl_items.copy()
def get_has_predefined_colors(self) -> bool:
return any([ return any([
key in COLOR_RELATED_KEYS key in COLOR_RELATED_KEYS
for _, key, _ in self.predefined_items for _, key, _ in self.predefined_items
]) ])
#@property
#def plain_string(self) -> str:
# return "".join([
# self.get_begin_tag_str({"foreground": self.base_color}),
# self.replace_str_by_spans(
# self.string, self.get_span_replacement_dict(
# self.get_inserted_string_pairs(use_label=False),
# self.command_repl_items
# )
# ),
# self.get_end_tag_str()
# ])
#@property
#def specified_substrings(self) -> list[str]: # TODO: clean up and merge
# return remove_list_redundancies([
# self.get_cleaned_substr(markup_span)
# for markup_span, _, _ in self.local_attr_items_from_markup
# ] + self.additional_substrings)
# Method alias # Method alias
def get_parts_by_text(self, substr: str) -> VGroup: def get_parts_by_text(self, text: str) -> VGroup:
return self.get_parts_by_string(substr) return self.get_parts_by_string(text)
def get_part_by_text(self, substr: str, index: int = 0) -> VMobject: def get_part_by_text(self, text: str) -> VMobject:
return self.get_part_by_string(substr, index) return self.get_part_by_string(text)
def set_color_by_text(self, substr: str, color: ManimColor): def set_color_by_text(self, text: str, color: ManimColor):
return self.set_color_by_string(substr, color) return self.set_color_by_string(text, color)
def set_color_by_text_to_color_map( def set_color_by_text_to_color_map(
self, text_to_color_map: dict[str, ManimColor] self, text_to_color_map: dict[str, ManimColor]
): ):
return self.set_color_by_string_to_color_map(text_to_color_map) return self.set_color_by_string_to_color_map(text_to_color_map)
def indices_of_part_by_text(
self, substr: str, index: int = 0
) -> list[int]:
return self.indices_of_part_by_string(substr, index)
def get_text(self) -> str: def get_text(self) -> str:
return self.get_string() return self.get_string()