Add TransformMatchingShapes and TransformMatchingTex

This commit is contained in:
Grant Sanderson 2021-01-05 18:03:06 -08:00
parent 92386f4e20
commit c0b90b398c

View file

@ -1,13 +1,26 @@
from manimlib.animation.composition import AnimationGroup from manimlib.animation.composition import AnimationGroup
from manimlib.animation.fading import FadeTransformPieces from manimlib.animation.fading import FadeTransformPieces
from manimlib.animation.fading import FadeInFromPoint
from manimlib.animation.fading import FadeOutToPoint
from manimlib.animation.transform import Transform from manimlib.animation.transform import Transform
from manimlib.mobject.mobject import Mobject
from manimlib.mobject.mobject import Group
from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.mobject.svg.tex_mobject import TexMobject
from manimlib.utils.config_ops import digest_config
class TransformMatchingParts(AnimationGroup): class TransformMatchingParts(AnimationGroup):
CONFIG = {
"mobject_type": Mobject,
"group_type": Group,
"fade_transform_mismatches": False,
}
def __init__(self, mobject, target_mobject, **kwargs): def __init__(self, mobject, target_mobject, **kwargs):
digest_config(self, kwargs)
assert(isinstance(mobject, VMobject) and isinstance(target_mobject, VMobject)) assert(isinstance(mobject, VMobject) and isinstance(target_mobject, VMobject))
source_map = self.get_shape_map(mobject) source_map = self.get_shape_map(mobject)
target_map = self.get_shape_map(target_mobject) target_map = self.get_shape_map(target_mobject)
@ -17,27 +30,36 @@ class TransformMatchingParts(AnimationGroup):
fade_source = VGroup() fade_source = VGroup()
fade_target = VGroup() fade_target = VGroup()
kwargs["final_alpha_value"] = 0
for key in set(source_map).intersection(target_map): for key in set(source_map).intersection(target_map):
transform_source.add(source_map[key]) transform_source.add(source_map[key])
transform_target.add(target_map[key]) transform_target.add(target_map[key])
anims = [Transform(transform_source, transform_target, **kwargs)]
for key in set(source_map).difference(target_map): for key in set(source_map).difference(target_map):
fade_source.add(source_map[key]) fade_source.add(source_map[key])
for key in set(target_map).difference(source_map): for key in set(target_map).difference(source_map):
fade_target.add(target_map[key]) fade_target.add(target_map[key])
kwargs["final_alpha_value"] = 0 if self.fade_transform_mismatches:
super().__init__( anims.append(FadeTransformPieces(fade_source, fade_target, **kwargs))
Transform(transform_source, transform_target, **kwargs), else:
FadeTransformPieces(fade_source, fade_target, **kwargs), anims.append(FadeOutToPoint(
) fade_source, fade_target.get_center(), **kwargs
))
anims.append(FadeInFromPoint(
fade_target.copy(), fade_source.get_center(), **kwargs
))
super().__init__(*anims)
self.to_remove = mobject self.to_remove = mobject
self.to_add = target_mobject self.to_add = target_mobject
def get_shape_map(self, mobject): def get_shape_map(self, mobject):
shape_map = {} shape_map = {}
for sm in mobject.family_members_with_points(): for sm in self.get_mobject_parts(mobject):
key = hash(sm.get_triangulation().tobytes()) key = self.get_mobject_key(sm)
if key not in shape_map: if key not in shape_map:
shape_map[key] = VGroup() shape_map[key] = VGroup()
shape_map[key].add(sm) shape_map[key].add(sm)
@ -49,3 +71,37 @@ class TransformMatchingParts(AnimationGroup):
scene.remove(self.mobject) scene.remove(self.mobject)
scene.remove(self.to_remove) scene.remove(self.to_remove)
scene.add(self.to_add) scene.add(self.to_add)
def get_mobject_parts(self, mobject):
# To be implemented in subclass
return mobject
def get_mobject_key(self, mobject):
# To be implemented in subclass
return hash(mobject)
class TransformMatchingShapes(TransformMatchingParts):
CONFIG = {
"mobject_type": VMobject,
"group_type": VGroup,
}
def get_mobject_parts(self, mobject):
return mobject.family_members_with_points()
def get_mobject_key(self, mobject):
return hash(mobject.get_triangulation().tobytes())
class TransformMatchingTex(TransformMatchingParts):
CONFIG = {
"mobject_type": TexMobject,
"group_type": VGroup,
}
def get_mobject_parts(self, mobject):
return mobject.submobjects
def get_mobject_key(self, mobject):
return mobject.get_tex_string()