3b1b-manim/manimlib/animation/transform_matching_parts.py

249 lines
8.7 KiB
Python
Raw Normal View History

from __future__ import annotations
import itertools as it
2021-01-30 17:52:02 -08:00
import numpy as np
from manimlib.animation.composition import AnimationGroup
from manimlib.animation.fading import FadeInFromPoint
from manimlib.animation.fading import FadeOutToPoint
2022-04-12 19:19:59 +08:00
from manimlib.animation.fading import FadeTransformPieces
from manimlib.animation.transform import ReplacementTransform
from manimlib.animation.transform import Transform
from manimlib.mobject.mobject import Mobject
from manimlib.mobject.mobject import Group
2022-03-30 21:57:27 +08:00
from manimlib.mobject.svg.labelled_string import LabelledString
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.config_ops import digest_config
from typing import TYPE_CHECKING
if TYPE_CHECKING:
2022-04-12 19:19:59 +08:00
from manimlib.mobject.svg.tex_mobject import SingleStringTex
from manimlib.mobject.svg.tex_mobject import Tex
from manimlib.scene.scene import Scene
class TransformMatchingParts(AnimationGroup):
    """Animate one mobject into another by pairing parts that share a key.

    Both mobjects are split into parts with ``get_mobject_parts`` and each
    part is given a key by ``get_mobject_key``.  The resulting animations are:

    * a ``Transform`` between the parts whose key occurs on both sides;
    * a ``FadeTransformPieces`` for the pairs listed in ``key_map``
      (source key -> target key);
    * for everything left over, exactly one of: a ``Transform``
      (``transform_mismatches``), a ``FadeTransformPieces``
      (``fade_transform_mismatches``), or a ``FadeOutToPoint`` /
      ``FadeInFromPoint`` pair (the default).

    Subclasses customise the matching by overriding ``get_mobject_parts``
    and ``get_mobject_key``.
    """
    CONFIG = {
        # Type both arguments of __init__ are asserted to be instances of
        "mobject_type": Mobject,
        # Container type used to collect matched / unmatched parts
        "group_type": Group,
        # Transform unmatched source parts into unmatched target parts
        "transform_mismatches": False,
        # Fade-transform unmatched source parts into unmatched target parts
        # (only consulted when transform_mismatches is False)
        "fade_transform_mismatches": False,
        # Manual pairing: source key -> target key
        "key_map": dict(),
    }

    def __init__(self, mobject: Mobject, target_mobject: Mobject, **kwargs):
        """Build the sub-animations taking ``mobject`` to ``target_mobject``.

        ``kwargs`` is first run through ``digest_config`` and then forwarded
        to the sub-animations (except the ``key_map`` one, which receives no
        extra arguments).
        """
        digest_config(self, kwargs)
        assert(isinstance(mobject, self.mobject_type))
        assert(isinstance(target_mobject, self.mobject_type))
        source_map = self.get_shape_map(mobject)
        target_map = self.get_shape_map(target_mobject)

        # Create two mobjects whose submobjects all match each other
        # according to whatever keys are used for source_map and
        # target_map
        transform_source = self.group_type()
        transform_target = self.group_type()
        # NOTE(review): presumably makes every sub-animation finish at
        # alpha 0 so that clean_up_from_scene can swap in target_mobject
        # itself -- confirm against Animation.
        kwargs["final_alpha_value"] = 0
        for key in set(source_map).intersection(target_map):
            transform_source.add(source_map[key])
            transform_target.add(target_map[key])
        anims = [Transform(transform_source, transform_target, **kwargs)]

        # User can manually specify when one part should transform
        # into another despite not matching by using key_map
        key_mapped_source = self.group_type()
        key_mapped_target = self.group_type()
        for key1, key2 in self.key_map.items():
            if key1 in source_map and key2 in target_map:
                key_mapped_source.add(source_map[key1])
                key_mapped_target.add(target_map[key2])
                # Drop the pair so it is not treated as a mismatch below
                source_map.pop(key1, None)
                target_map.pop(key2, None)
        if len(key_mapped_source) > 0:
            anims.append(FadeTransformPieces(
                key_mapped_source,
                key_mapped_target,
            ))

        # Whatever has no counterpart on the other side
        fade_source = self.group_type()
        fade_target = self.group_type()
        for key in set(source_map).difference(target_map):
            fade_source.add(source_map[key])
        for key in set(target_map).difference(source_map):
            fade_target.add(target_map[key])

        # The three mismatch strategies are mutually exclusive.  This must
        # be an if/elif/else chain: with two independent ``if`` statements,
        # transform_mismatches=True would also fall into the final ``else``
        # and animate the same parts a second time with fades.
        if self.transform_mismatches:
            anims.append(Transform(fade_source.copy(), fade_target, **kwargs))
        elif self.fade_transform_mismatches:
            anims.append(FadeTransformPieces(fade_source, fade_target, **kwargs))
        else:
            anims.append(FadeOutToPoint(
                fade_source, target_mobject.get_center(), **kwargs
            ))
            anims.append(FadeInFromPoint(
                fade_target.copy(), mobject.get_center(), **kwargs
            ))

        super().__init__(*anims)

        # Consumed by clean_up_from_scene
        self.to_remove = mobject
        self.to_add = target_mobject

    def get_shape_map(self, mobject: Mobject) -> dict[int, VGroup]:
        """Group the parts of ``mobject`` by key.

        Returns a dict mapping each key produced by ``get_mobject_key`` to a
        ``VGroup`` holding every part (from ``get_mobject_parts``) with that
        key.
        """
        shape_map: dict[int, VGroup] = {}
        for sm in self.get_mobject_parts(mobject):
            key = self.get_mobject_key(sm)
            if key not in shape_map:
                shape_map[key] = VGroup()
            shape_map[key].add(sm)
        return shape_map

    def clean_up_from_scene(self, scene: Scene) -> None:
        """Reset the sub-animations and replace the source by the target.

        Every sub-animation is updated to alpha 0, then this animation's
        mobject and the original source are removed from ``scene`` and the
        target mobject is added.
        """
        for anim in self.animations:
            anim.update(0)
        scene.remove(self.mobject)
        scene.remove(self.to_remove)
        scene.add(self.to_add)

    @staticmethod
    def get_mobject_parts(mobject: Mobject) -> Mobject:
        """Return the iterable of parts to match; here the mobject itself."""
        # To be implemented in subclass
        return mobject

    @staticmethod
    def get_mobject_key(mobject: Mobject) -> int:
        """Return the matching key for a part; here simply its hash."""
        # To be implemented in subclass
        return hash(mobject)
class TransformMatchingShapes(TransformMatchingParts):
    """Match the point-carrying family members of two VMobjects by shape.

    Two parts get the same key when their points, after centering and
    rescaling to height 1, agree once rounded to three decimals.
    """
    CONFIG = dict(
        mobject_type=VMobject,
        group_type=VGroup,
    )

    @staticmethod
    def get_mobject_parts(mobject: VMobject) -> list[VMobject]:
        """Return the family members of ``mobject`` that have points."""
        parts = mobject.family_members_with_points()
        return parts

    @staticmethod
    def get_mobject_key(mobject: VMobject) -> int:
        """Return a hash of the normalised point data of ``mobject``.

        The mobject is temporarily centered and set to height 1 so that the
        key does not depend on its position or size; its saved state is
        restored before returning.
        """
        mobject.save_state()
        mobject.center()
        mobject.set_height(1)
        # Round before hashing so tiny numerical differences do not
        # produce different keys.
        rounded_points = np.round(mobject.get_points(), 3)
        shape_key = hash(rounded_points.tobytes())
        mobject.restore()
        return shape_key
class TransformMatchingTex(TransformMatchingParts):
    """Match the direct submobjects of two Tex mobjects by their tex string."""
    CONFIG = dict(
        mobject_type=VMobject,
        group_type=VGroup,
    )

    @staticmethod
    def get_mobject_parts(mobject: Tex) -> list[SingleStringTex]:
        """Return the direct submobjects of ``mobject`` as the parts."""
        tex_parts = mobject.submobjects
        return tex_parts

    @staticmethod
    def get_mobject_key(mobject: Tex) -> str:
        """Return the tex string of ``mobject`` as its matching key."""
        tex_string = mobject.get_tex()
        return tex_string
2022-03-29 23:38:06 +08:00
class TransformMatchingStrings(AnimationGroup):
    """Animate one LabelledString into another by matching substrings.

    Animations are collected in three passes, each one only claiming
    labelled submobjects that no earlier pass has claimed:

    1. ``key_map`` pairs (source selector -> target selector), selected with
       ``LabelledString.select_parts`` and animated by
       ``ReplacementTransform``;
    2. non-empty substrings found in both ``specified_substrs``, selected
       with ``LabelledString.select_parts`` and animated by
       ``FadeTransformPieces``;
    3. non-empty substrings found in both ``group_substrs``, selected with
       ``LabelledString.select_parts_by_group_substr`` and animated by
       ``FadeTransformPieces``.

    Whatever remains is either transformed as a whole
    (``transform_mismatches``) or faded out / in.
    """
    CONFIG = dict(
        key_map=dict(),
        transform_mismatches=False,
    )

    def __init__(self,
        source: LabelledString,
        target: LabelledString,
        **kwargs
    ):
        """Build the sub-animations taking ``source`` to ``target``.

        ``kwargs`` is run through ``digest_config`` and forwarded to every
        sub-animation.
        """
        digest_config(self, kwargs)
        assert isinstance(source, LabelledString)
        assert isinstance(target, LabelledString)

        anims = []
        # Positions within labelled_submobjects that no animation has
        # claimed yet; the passes below remove entries as they go.
        source_indices = list(range(len(source.labelled_submobjects)))
        target_indices = list(range(len(target.labelled_submobjects)))

        def get_indices_lists(mobject, parts):
            # For each part, the positions of its submobjects within
            # mobject.labelled_submobjects.
            indices_lists = []
            for part in parts:
                positions = []
                for submob in part:
                    positions.append(
                        mobject.labelled_submobjects.index(submob)
                    )
                indices_lists.append(positions)
            return indices_lists

        def keep_unclaimed(indices_lists, available):
            # Keep only the lists made up entirely of available indices.
            return [
                indices_list
                for indices_list in indices_lists
                if all(index in available for index in indices_list)
            ]

        def add_anims_from(anim_class, func, source_args, target_args=None):
            # Without explicit target selectors, use the same selector on
            # both sides.
            if target_args is None:
                target_args = source_args.copy()
            for source_arg, target_arg in zip(source_args, target_args):
                source_parts = func(source, source_arg)
                target_parts = func(target, target_arg)
                source_indices_lists = keep_unclaimed(
                    get_indices_lists(source, source_parts), source_indices
                )
                target_indices_lists = keep_unclaimed(
                    get_indices_lists(target, target_parts), target_indices
                )
                # Skip this selector unless both sides still have at least
                # one entirely unclaimed part.
                if not (source_indices_lists and target_indices_lists):
                    continue
                anims.append(anim_class(source_parts, target_parts, **kwargs))
                # Mark the indices just used as claimed
                for index in it.chain.from_iterable(source_indices_lists):
                    source_indices.remove(index)
                for index in it.chain.from_iterable(target_indices_lists):
                    target_indices.remove(index)

        def get_common_substrs(substrs_from_source, substrs_from_target):
            # Non-empty substrings present on both sides, longest first
            common = [
                substr for substr in substrs_from_source
                if substr and substr in substrs_from_target
            ]
            common.sort(key=len, reverse=True)
            return common

        # Pass 1: pairs explicitly requested through key_map
        add_anims_from(
            ReplacementTransform, LabelledString.select_parts,
            self.key_map.keys(), self.key_map.values()
        )
        # Pass 2: substrings specified on both strings
        shared_specified = get_common_substrs(
            source.specified_substrs,
            target.specified_substrs
        )
        add_anims_from(
            FadeTransformPieces, LabelledString.select_parts,
            shared_specified
        )
        # Pass 3: group substrings shared by both strings
        shared_groups = get_common_substrs(
            source.group_substrs,
            target.group_substrs
        )
        add_anims_from(
            FadeTransformPieces, LabelledString.select_parts_by_group_substr,
            shared_groups
        )

        # Everything that no pass claimed
        rest_source = VGroup(*(source[index] for index in source_indices))
        rest_target = VGroup(*(target[index] for index in target_indices))
        if self.transform_mismatches:
            anims.append(
                ReplacementTransform(rest_source, rest_target, **kwargs)
            )
        else:
            fade_out = FadeOutToPoint(
                rest_source, target.get_center(), **kwargs
            )
            fade_in = FadeInFromPoint(
                rest_target, source.get_center(), **kwargs
            )
            anims.extend([fade_out, fade_in])

        super().__init__(*anims)