3b1b-manim/manimlib/animation/transform_matching_parts.py

256 lines
9.2 KiB
Python
Raw Normal View History

from __future__ import annotations
import itertools as it
2021-01-30 17:52:02 -08:00
import numpy as np
2022-12-28 13:39:46 -08:00
from manimlib.animation.animation import Animation
from manimlib.animation.composition import AnimationGroup
from manimlib.animation.fading import FadeInFromPoint
from manimlib.animation.fading import FadeOutToPoint
from manimlib.animation.fading import FadeTransformPieces
2022-12-28 13:39:46 -08:00
from manimlib.animation.fading import FadeTransform
from manimlib.animation.transform import ReplacementTransform
from manimlib.animation.transform import Transform
from manimlib.mobject.mobject import Mobject
from manimlib.mobject.mobject import Group
2022-05-06 22:09:58 +08:00
from manimlib.mobject.svg.string_mobject import StringMobject
from manimlib.mobject.svg.old_tex_mobject import OldTex
2022-12-28 13:39:46 -08:00
from manimlib.mobject.svg.tex_mobject import Tex
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.mobject.types.vectorized_mobject import VMobject
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.mobject.svg.old_tex_mobject import SingleStringTex
from manimlib.scene.scene import Scene
class TransformMatchingParts(AnimationGroup):
    """
    Transform ``mobject`` into ``target_mobject`` by matching up their parts.

    Subclasses decide what counts as a "part" (``get_mobject_parts``) and
    which key identifies it (``get_mobject_key``). Source and target parts
    sharing a key are transformed into one another; parts without a
    counterpart are faded out (source) and faded in (target), optionally
    combined with a transform between the leftovers (see ``__init__``).
    """
    # Required type of both input mobjects (checked in __init__), and the
    # container type used to collect parts.
    mobject_type: type = Mobject
    group_type: type = Group

    def __init__(
        self,
        mobject: Mobject,
        target_mobject: Mobject,
        transform_mismatches: bool = False,
        fade_transform_mismatches: bool = False,
        key_map: dict | None = None,
        **kwargs
    ):
        """
        Args:
            mobject: source mobject; removed from the scene on cleanup.
            target_mobject: target mobject; added to the scene on cleanup.
            transform_mismatches: if True, additionally Transform a copy of
                all unmatched source parts into all unmatched target parts.
            fade_transform_mismatches: if True, animate unmatched parts with
                FadeTransformPieces instead of fading them toward / from the
                center of the other mobject.
            key_map: maps a source key to a target key whose parts should be
                fade-transformed into each other even though the keys differ.
            **kwargs: forwarded to every sub-animation (``final_alpha_value``
                is always forced to 0).
        """
        assert isinstance(mobject, self.mobject_type)
        assert isinstance(target_mobject, self.mobject_type)
        source_map = self.get_shape_map(mobject)
        target_map = self.get_shape_map(target_mobject)
        key_map = key_map or dict()

        # Create two mobjects whose submobjects all match each other
        # according to whatever keys are used for source_map and
        # target_map
        transform_source = self.group_type()
        transform_target = self.group_type()
        # NOTE(review): presumably so that finishing the animations leaves
        # the source in its starting state; clean_up_from_scene also resets
        # each sub-animation with update(0).
        kwargs["final_alpha_value"] = 0
        for key in set(source_map).intersection(target_map):
            transform_source.add(source_map[key])
            transform_target.add(target_map[key])
        anims = [Transform(transform_source, transform_target, **kwargs)]

        # User can manually specify when one part should transform
        # into another despite not matching by using key_map
        key_mapped_source = self.group_type()
        key_mapped_target = self.group_type()
        for key1, key2 in key_map.items():
            if key1 in source_map and key2 in target_map:
                key_mapped_source.add(source_map[key1])
                key_mapped_target.add(target_map[key2])
                # Claimed by key_map, so exclude them from the fades below
                source_map.pop(key1, None)
                target_map.pop(key2, None)
        if len(key_mapped_source) > 0:
            # Forward kwargs like every other sub-animation so that e.g. a
            # user-supplied run_time also applies to key-mapped pieces.
            anims.append(FadeTransformPieces(
                key_mapped_source,
                key_mapped_target,
                **kwargs
            ))

        # Everything left without a counterpart
        fade_source = self.group_type()
        fade_target = self.group_type()
        for key in set(source_map).difference(target_map):
            fade_source.add(source_map[key])
        for key in set(target_map).difference(source_map):
            fade_target.add(target_map[key])

        if transform_mismatches:
            # A copy is transformed so fade_source itself can still be
            # faded out by the animations below.
            anims.append(Transform(fade_source.copy(), fade_target, **kwargs))
        if fade_transform_mismatches:
            anims.append(FadeTransformPieces(fade_source, fade_target, **kwargs))
        else:
            anims.append(FadeOutToPoint(
                fade_source, target_mobject.get_center(), **kwargs
            ))
            anims.append(FadeInFromPoint(
                fade_target.copy(), mobject.get_center(), **kwargs
            ))

        super().__init__(*anims)

        self.to_remove = mobject
        self.to_add = target_mobject

    def get_shape_map(self, mobject: Mobject) -> dict[int | str, Mobject]:
        """
        Group the parts of ``mobject`` by their key.

        Returns a dict mapping each key produced by ``get_mobject_key`` to a
        ``group_type`` instance holding every part with that key, in the
        order ``get_mobject_parts`` yields them.
        """
        shape_map: dict[int | str, Mobject] = {}
        for part in self.get_mobject_parts(mobject):
            key = self.get_mobject_key(part)
            if key not in shape_map:
                # Use the class's group type (not a hard-coded VGroup) so
                # parts that are not VMobjects can be grouped as well.
                shape_map[key] = self.group_type()
            shape_map[key].add(part)
        return shape_map

    def clean_up_from_scene(self, scene: Scene) -> None:
        """Reset the sub-animations, then replace the source with the target."""
        # update(0) presumably returns each sub-animation to alpha 0
        for anim in self.animations:
            anim.update(0)
        scene.remove(self.mobject)
        scene.remove(self.to_remove)
        scene.add(self.to_add)

    @staticmethod
    def get_mobject_parts(mobject: Mobject) -> Mobject:
        # To be implemented in subclass; the result is iterated over to
        # obtain the parts.
        return mobject

    @staticmethod
    def get_mobject_key(mobject: Mobject) -> int:
        # To be implemented in subclass; parts with equal keys are matched.
        return hash(mobject)
class TransformMatchingShapes(TransformMatchingParts):
    """
    TransformMatchingParts for VMobjects.

    Parts are the family members that have points; two parts match when
    their points agree after each is centered and scaled to unit height.
    """
    mobject_type: type = VMobject
    group_type: type = VGroup

    @staticmethod
    def get_mobject_parts(mobject: VMobject) -> list[VMobject]:
        return mobject.family_members_with_points()

    @staticmethod
    def get_mobject_key(mobject: VMobject) -> int:
        """
        Hash the shape of ``mobject`` independently of position and size.

        The normalization (center, then unit height) is done on a copy, so
        the mobject itself, and any state previously stored on it with
        ``save_state``, is left untouched.
        """
        normalized = mobject.copy()
        normalized.center()
        normalized.set_height(1)
        # Rounding to 3 decimals makes the key tolerant of tiny floating
        # point differences (values straddling a rounding boundary can
        # still hash differently).
        return hash(np.round(normalized.get_points(), 3).tobytes())
class TransformMatchingStrings(AnimationGroup):
    """
    Transform one StringMobject into another by pairing up their pieces.

    Pieces are paired in this order of priority; a character used by one
    pairing is not reused by a later one:

    1. each key in ``matched_keys`` (animated with ``match_animation``);
    2. each ``source_key -> target_key`` entry of ``key_map``
       (animated with ``mismatch_animation``);
    3. substrings specified when ``source`` or ``target`` was constructed
       (``match_animation``);
    4. single characters whose shapes agree after normalizing position and
       height (``match_animation``).

    Leftover source characters fade out toward the center of ``target`` and
    leftover target characters fade in from the center of ``source``.
    """
    def __init__(
        self,
        source: StringMobject,
        target: StringMobject,
        matched_keys: list[str] | None = None,
        key_map: dict[str, str] | None = None,
        match_animation: type = Transform,
        mismatch_animation: type = Transform,
        run_time=2,
        **kwargs,
    ):
        """
        Args:
            source: string mobject to transform from.
            target: string mobject to transform into.
            matched_keys: substrings to pair between source and target first.
            key_map: maps a source substring to a different target substring
                it should be transformed into.
            match_animation: animation class for matching pieces.
            mismatch_animation: animation class for ``key_map`` pairs.
            run_time: passed, with ``**kwargs``, to every sub-animation.
        """
        self.source = source
        self.target = target
        matched_keys = matched_keys or list()
        key_map = key_map or dict()
        self.anim_config = dict(run_time=run_time, **kwargs)

        # We will progressively build up a list of transforms
        # from characters in source to those in target. These
        # two lists keep track of which characters are still
        # unaccounted for
        self.source_chars = source.family_members_with_points()
        self.target_chars = target.family_members_with_points()
        self.anims = []

        # Start by pairing all matched keys specifically passed in
        for key in matched_keys:
            self.add_transform(
                source.select_parts(key),
                target.select_parts(key),
                match_animation
            )
        # Then pair those based on the key map
        for key, value in key_map.items():
            self.add_transform(
                source.select_parts(key),
                target.select_parts(value),
                mismatch_animation
            )
        # Now pair by substrings which were isolated in StringMobject
        # initializations
        specified_substrings = [
            *source.get_specified_substrings(),
            *target.get_specified_substrings()
        ]
        for key in specified_substrings:
            self.add_transform(
                source.select_parts(key),
                target.select_parts(key),
                match_animation
            )
        # Match any pairs with the same shape
        pairs = self.find_pairs_with_matching_shapes(self.source_chars, self.target_chars)
        for source_char, target_char in pairs:
            self.add_transform(source_char, target_char, match_animation)
        # Finally, account for mismatches
        for source_char in self.source_chars:
            self.anims.append(FadeOutToPoint(
                source_char, target.get_center(),
                **self.anim_config
            ))
        for target_char in self.target_chars:
            self.anims.append(FadeInFromPoint(
                target_char, source.get_center(),
                **self.anim_config
            ))
        super().__init__(*self.anims)

    def add_transform(
        self,
        source: VMobject,
        target: VMobject,
        transform_type: type = Transform,
    ):
        """
        Queue ``transform_type(source, target)`` if its characters are free.

        Nothing is queued when either side has no characters (presumably a
        key that ``select_parts`` could not find), since there is nothing
        to transform from or into, nor when any character on either side
        was already claimed by an earlier pairing. Otherwise the characters
        are removed from the unclaimed pools.
        """
        new_source_chars = source.family_members_with_points()
        new_target_chars = target.family_members_with_points()
        if not new_source_chars or not new_target_chars:
            # Don't animate null sources or null targets; leave the other
            # side's characters for shape matching or fading.
            return
        source_is_new = all(char in self.source_chars for char in new_source_chars)
        target_is_new = all(char in self.target_chars for char in new_target_chars)
        if not (source_is_new and target_is_new):
            return
        self.anims.append(transform_type(
            source, target, **self.anim_config
        ))
        for char in new_source_chars:
            self.source_chars.remove(char)
        for char in new_target_chars:
            self.target_chars.remove(char)

    def find_pairs_with_matching_shapes(
        self,
        chars1: list[VMobject],
        chars2: list[VMobject],
    ) -> list[tuple[VMobject, VMobject]]:
        """
        Return all (char1, char2) pairs, char1 from ``chars1`` and char2 from
        ``chars2``, whose shapes match after each is scaled to unit height
        and centered: same number of points, all within 0.1 of each other.

        Normalization is done on copies, so the characters themselves (and
        any state previously stored on them with ``save_state``) are left
        untouched. Pairs are listed in ``itertools.product`` order.
        """
        def normalized_points(char: VMobject) -> np.ndarray:
            proxy = char.copy()
            proxy.set_height(1)
            proxy.center()
            return proxy.get_points()

        normalized1 = [(char, normalized_points(char)) for char in chars1]
        normalized2 = [(char, normalized_points(char)) for char in chars2]
        return [
            (char1, char2)
            for (char1, p1), (char2, p2) in it.product(normalized1, normalized2)
            if len(p1) == len(p2) and np.isclose(p1, p2, atol=1e-1).all()
        ]

    def clean_up_from_scene(self, scene: Scene) -> None:
        """Run the group cleanup, then leave only ``target`` in the scene."""
        super().clean_up_from_scene(scene)
        scene.remove(self.mobject)
        scene.add(self.target)
class TransformMatchingTex(TransformMatchingStrings):
    """
    Alias for TransformMatchingStrings.

    Adds no behavior of its own; presumably kept so code written against
    the Tex-specific name keeps working.
    """