Write new TransformMatchingStrings

This commit is contained in:
Grant Sanderson 2022-12-28 13:39:46 -08:00
parent c7ba775845
commit 926f3515bf

View file

@ -4,16 +4,19 @@ import itertools as it
import numpy as np import numpy as np
from manimlib.animation.animation import Animation
from manimlib.animation.composition import AnimationGroup from manimlib.animation.composition import AnimationGroup
from manimlib.animation.fading import FadeInFromPoint from manimlib.animation.fading import FadeInFromPoint
from manimlib.animation.fading import FadeOutToPoint from manimlib.animation.fading import FadeOutToPoint
from manimlib.animation.fading import FadeTransformPieces from manimlib.animation.fading import FadeTransformPieces
from manimlib.animation.fading import FadeTransform
from manimlib.animation.transform import ReplacementTransform from manimlib.animation.transform import ReplacementTransform
from manimlib.animation.transform import Transform from manimlib.animation.transform import Transform
from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Mobject
from manimlib.mobject.mobject import Group from manimlib.mobject.mobject import Group
from manimlib.mobject.svg.string_mobject import StringMobject from manimlib.mobject.svg.string_mobject import StringMobject
from manimlib.mobject.svg.old_tex_mobject import OldTex from manimlib.mobject.svg.old_tex_mobject import OldTex
from manimlib.mobject.svg.tex_mobject import Tex
from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.mobject.types.vectorized_mobject import VMobject
@ -93,8 +96,8 @@ class TransformMatchingParts(AnimationGroup):
self.to_remove = mobject self.to_remove = mobject
self.to_add = target_mobject self.to_add = target_mobject
def get_shape_map(self, mobject: Mobject) -> dict[int, VGroup]: def get_shape_map(self, mobject: Mobject) -> dict[int | str, VGroup]:
shape_map: dict[int, VGroup] = {} shape_map: dict[int | str, VGroup] = {}
for sm in self.get_mobject_parts(mobject): for sm in self.get_mobject_parts(mobject):
key = self.get_mobject_key(sm) key = self.get_mobject_key(sm)
if key not in shape_map: if key not in shape_map:
@ -138,128 +141,115 @@ class TransformMatchingShapes(TransformMatchingParts):
return result return result
class TransformMatchingTex(TransformMatchingParts):
mobject_type: type = OldTex
group_type: type = VGroup
@staticmethod
def get_mobject_parts(mobject: OldTex) -> list[SingleStringTex]:
return mobject.submobjects
@staticmethod
def get_mobject_key(mobject: OldTex) -> str:
return mobject.get_tex()
class TransformMatchingStrings(AnimationGroup): class TransformMatchingStrings(AnimationGroup):
def __init__(self, def __init__(
self,
source: StringMobject, source: StringMobject,
target: StringMobject, target: StringMobject,
key_map: dict = {}, matched_keys: list[str] | None = None,
transform_mismatches: bool = False, key_map: dict[str, str] | None = None,
**kwargs match_animation: type = Transform,
mismatch_animation: type = Transform,
run_time=2,
**kwargs,
): ):
assert isinstance(source, StringMobject) self.source = source
assert isinstance(target, StringMobject) self.target = target
matched_keys = matched_keys or list()
key_map = key_map or dict()
self.anim_config = dict(run_time=run_time, **kwargs)
def get_matched_indices_lists(*part_items_list): # We will progressively build up a list of transforms
part_items_list_len = len(part_items_list) # from characters in source to those in target. These
indexed_part_items = sorted(it.chain(*[ # two lists keep track of which characters are accounted
[ # for so far
(substr, items_index, indices_list) self.source_chars = source.family_members_with_points()
for substr, indices_list in part_items self.target_chars = target.family_members_with_points()
] self.anims = []
for items_index, part_items in enumerate(part_items_list)
]))
grouped_part_items = [
(substr, [
[indices_lists for _, _, indices_lists in grouper_2]
for _, grouper_2 in it.groupby(
grouper_1, key=lambda t: t[1]
)
])
for substr, grouper_1 in it.groupby(
indexed_part_items, key=lambda t: t[0]
)
]
return [
tuple(indices_lists_list)
for _, indices_lists_list in sorted(filter(
lambda t: t[0] and len(t[1]) == part_items_list_len,
grouped_part_items
), key=lambda t: len(t[0]), reverse=True)
]
def get_filtered_indices_lists(indices_lists, used_indices): # Start by pairing all matched keys specifically passed in
result = [] for key in matched_keys:
used = [] self.add_transform(
for indices_list in indices_lists: source.select_parts(key),
if not all( target.select_parts(key),
index not in used_indices and index not in used match_animation
for index in indices_list )
): # Then pair those based on the key map
continue for key, value in key_map.items():
result.append(indices_list) self.add_transform(
used.extend(indices_list) source.select_parts(key),
return result, used target.select_parts(value),
mismatch_animation
anim_class_items = [ )
(ReplacementTransform, [ # Now pair by substrings which were isolated in StringMobject
( # initializations
source.get_submob_indices_lists_by_selector(k), specified_substrings = [
target.get_submob_indices_lists_by_selector(v) *source.get_specified_substrings(),
) *target.get_specified_substrings()
for k, v in key_map.items()
]),
(FadeTransformPieces, get_matched_indices_lists(
source.get_specified_part_items(),
target.get_specified_part_items()
)),
(FadeTransformPieces, get_matched_indices_lists(
source.get_group_part_items(),
target.get_group_part_items()
))
] ]
for key in specified_substrings:
anims = [] self.add_transform(
source_used_indices = [] source.select_parts(key),
target_used_indices = [] target.select_parts(key),
for anim_class, pairs in anim_class_items: match_animation
for source_indices_lists, target_indices_lists in pairs:
source_filtered, source_used = get_filtered_indices_lists(
source_indices_lists, source_used_indices
)
target_filtered, target_used = get_filtered_indices_lists(
target_indices_lists, target_used_indices
)
if not source_filtered or not target_filtered:
continue
anims.append(anim_class(
source.build_parts_from_indices_lists(source_filtered),
target.build_parts_from_indices_lists(target_filtered),
**kwargs
))
source_used_indices.extend(source_used)
target_used_indices.extend(target_used)
rest_source = VGroup(*(
submob for index, submob in enumerate(source.submobjects)
if index not in source_used_indices
))
rest_target = VGroup(*(
submob for index, submob in enumerate(target.submobjects)
if index not in target_used_indices
))
if transform_mismatches:
anims.append(
ReplacementTransform(rest_source, rest_target, **kwargs)
)
else:
anims.append(
FadeOutToPoint(rest_source, target.get_center(), **kwargs)
)
anims.append(
FadeInFromPoint(rest_target, source.get_center(), **kwargs)
) )
# Match any pairs with the same shape
pairs = self.find_pairs_with_matching_shapes(self.source_chars, self.target_chars)
for source_char, target_char in pairs:
self.add_transform(source_char, target_char, match_animation)
# Finally, account for mismatches
for source_char in self.source_chars:
self.anims.append(FadeOutToPoint(
source_char, target.get_center(),
**self.anim_config
))
for target_char in self.target_chars:
self.anims.append(FadeInFromPoint(
target_char, source.get_center(),
**self.anim_config
))
super().__init__(*self.anims)
super().__init__(*anims) def add_transform(
self,
source: VMobject,
target: VMobject,
transform_type: type = Transform,
):
new_source_chars = source.family_members_with_points()
new_target_chars = target.family_members_with_points()
source_is_new = all(char in self.source_chars for char in new_source_chars)
target_is_new = all(char in self.target_chars for char in new_target_chars)
if source_is_new and target_is_new:
self.anims.append(transform_type(
source, target, **self.anim_config
))
for char in new_source_chars:
self.source_chars.remove(char)
for char in new_target_chars:
self.target_chars.remove(char)
def find_pairs_with_matching_shapes(self, chars1, chars2) -> list[tuple[VMobject, VMobject]]:
for char in (*chars1, *chars2):
char.save_state()
char.set_height(1)
char.center()
result = []
for char1, char2 in it.product(chars1, chars2):
p1 = char1.get_points()
p2 = char2.get_points()
if len(p1) == len(p2) and np.isclose(p1, p2 , atol=1e-1).all():
result.append((char1, char2))
for char in (*chars1, *chars2):
char.restore()
return result
def clean_up_from_scene(self, scene: Scene) -> None:
super().clean_up_from_scene(scene)
scene.remove(self.mobject)
scene.add(self.target)
class TransformMatchingTex(TransformMatchingStrings):
"""Alias for TransformMatchingStrings"""
pass