Abstract away logic of TransformMatchingStrings to a new version of TransformMatchingParts

This commit is contained in:
Grant Sanderson 2022-12-30 13:54:55 -08:00
parent 96bc95ef38
commit 4335e85659

View file

@ -18,218 +18,90 @@ from manimlib.mobject.types.vectorized_mobject import VMobject
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Iterable
from manimlib.scene.scene import Scene from manimlib.scene.scene import Scene
class TransformMatchingParts(AnimationGroup): class TransformMatchingParts(AnimationGroup):
mobject_type: type = Mobject
group_type: type = Group
def __init__( def __init__(
self, self,
mobject: Mobject, source: Mobject,
target_mobject: Mobject, target: Mobject,
transform_mismatches: bool = False, matched_pairs: Iterable[tuple[Mobject, Mobject]] = [],
fade_transform_mismatches: bool = False,
key_map: dict | None = None,
**kwargs
):
assert(isinstance(mobject, self.mobject_type))
assert(isinstance(target_mobject, self.mobject_type))
source_map = self.get_shape_map(mobject)
target_map = self.get_shape_map(target_mobject)
key_map = key_map or dict()
# Create two mobjects whose submobjects all match each other
# according to whatever keys are used for source_map and
# target_map
transform_source = self.group_type()
transform_target = self.group_type()
kwargs["final_alpha_value"] = 0
for key in set(source_map).intersection(target_map):
transform_source.add(source_map[key])
transform_target.add(target_map[key])
anims = [Transform(transform_source, transform_target, **kwargs)]
# User can manually specify when one part should transform
# into another despite not matching by using key_map
key_mapped_source = self.group_type()
key_mapped_target = self.group_type()
for key1, key2 in key_map.items():
if key1 in source_map and key2 in target_map:
key_mapped_source.add(source_map[key1])
key_mapped_target.add(target_map[key2])
source_map.pop(key1, None)
target_map.pop(key2, None)
if len(key_mapped_source) > 0:
anims.append(FadeTransformPieces(
key_mapped_source,
key_mapped_target,
))
fade_source = self.group_type()
fade_target = self.group_type()
for key in set(source_map).difference(target_map):
fade_source.add(source_map[key])
for key in set(target_map).difference(source_map):
fade_target.add(target_map[key])
if transform_mismatches:
anims.append(Transform(fade_source.copy(), fade_target, **kwargs))
if fade_transform_mismatches:
anims.append(FadeTransformPieces(fade_source, fade_target, **kwargs))
else:
anims.append(FadeOutToPoint(
fade_source, target_mobject.get_center(), **kwargs
))
anims.append(FadeInFromPoint(
fade_target.copy(), mobject.get_center(), **kwargs
))
super().__init__(*anims)
self.to_remove = mobject
self.to_add = target_mobject
def get_shape_map(self, mobject: Mobject) -> dict[int | str, VGroup]:
shape_map: dict[int | str, VGroup] = {}
for sm in self.get_mobject_parts(mobject):
key = self.get_mobject_key(sm)
if key not in shape_map:
shape_map[key] = VGroup()
shape_map[key].add(sm)
return shape_map
def clean_up_from_scene(self, scene: Scene) -> None:
for anim in self.animations:
anim.update(0)
scene.remove(self.mobject)
scene.remove(self.to_remove)
scene.add(self.to_add)
@staticmethod
def get_mobject_parts(mobject: Mobject) -> Mobject:
# To be implemented in subclass
return mobject
@staticmethod
def get_mobject_key(mobject: Mobject) -> int:
# To be implemented in subclass
return hash(mobject)
class TransformMatchingShapes(TransformMatchingParts):
mobject_type: type = VMobject
group_type: type = VGroup
@staticmethod
def get_mobject_parts(mobject: VMobject) -> list[VMobject]:
return mobject.family_members_with_points()
@staticmethod
def get_mobject_key(mobject: VMobject) -> int:
mobject.save_state()
mobject.center()
mobject.set_height(1)
result = hash(np.round(mobject.get_points(), 3).tobytes())
mobject.restore()
return result
class TransformMatchingStrings(AnimationGroup):
def __init__(
self,
source: StringMobject,
target: StringMobject,
matched_keys: list[str] | None = None,
key_map: dict[str, str] | None = None,
match_animation: type = Transform, match_animation: type = Transform,
mismatch_animation: type = Transform, mismatch_animation: type = Transform,
run_time=2, run_time: float = 2,
lag_ratio=0, lag_ratio: float = 0,
group_type: type = Group,
**kwargs, **kwargs,
): ):
self.source = source self.source = source
self.target = target self.target = target
matched_keys = matched_keys or list() self.match_animation = match_animation
key_map = key_map or dict() self.mismatch_animation = mismatch_animation
self.anim_config = dict(**kwargs) self.anim_config = dict(**kwargs)
# We will progressively build up a list of transforms # We will progressively build up a list of transforms
# from characters in source to those in target. These # from characters in source to those in target. These
# two lists keep track of which characters are accounted # two lists keep track of which characters are accounted
# for so far # for so far
self.source_chars = source.family_members_with_points() self.source_pieces = source.family_members_with_points()
self.target_chars = target.family_members_with_points() self.target_pieces = target.family_members_with_points()
self.anims = [] self.anims = []
# Start by pairing all matched keys specifically passed in for pair in matched_pairs:
for key in matched_keys: self.add_transform(*pair)
self.add_transform(
source.select_parts(key),
target.select_parts(key),
match_animation
)
# Then pair those based on the key map
for key, value in key_map.items():
self.add_transform(
source.select_parts(key),
target.select_parts(value),
mismatch_animation
)
# Now pair by substrings which were isolated in StringMobject
# initializations
specified_substrings = [
*source.get_specified_substrings(),
*target.get_specified_substrings()
]
for key in specified_substrings:
self.add_transform(
source.select_parts(key),
target.select_parts(key),
match_animation
)
# Match any pairs with the same shape # Match any pairs with the same shape
pairs = self.find_pairs_with_matching_shapes(self.source_chars, self.target_chars) for pair in self.find_pairs_with_matching_shapes(self.source_pieces, self.target_pieces):
for source_char, target_char in pairs: self.add_transform(*pair)
self.add_transform(source_char, target_char, match_animation)
# Finally, account for mismatches # Finally, account for mismatches
for source_char in self.source_chars: for source_char in self.source_pieces:
self.anims.append(FadeOutToPoint( self.anims.append(FadeOutToPoint(
source_char, target.get_center(), source_char, target.get_center(),
**self.anim_config **self.anim_config
)) ))
for target_char in self.target_chars: for target_char in self.target_pieces:
self.anims.append(FadeInFromPoint( self.anims.append(FadeInFromPoint(
target_char, source.get_center(), target_char, source.get_center(),
**self.anim_config **self.anim_config
)) ))
super().__init__( super().__init__(
*self.anims, *self.anims,
run_time=run_time, run_time=run_time,
lag_ratio=lag_ratio, lag_ratio=lag_ratio,
group_type=VGroup, group_type=group_type,
) )
def add_transform( def add_transform(
self, self,
source: VMobject, source: Mobject,
target: VMobject, target: Mobject,
transform_type: type = Transform,
): ):
new_source_chars = source.family_members_with_points() new_source_pieces = source.family_members_with_points()
new_target_chars = target.family_members_with_points() new_target_pieces = target.family_members_with_points()
source_is_new = all(char in self.source_chars for char in new_source_chars) if len(new_source_pieces) == 0 or len(new_target_pieces) == 0:
target_is_new = all(char in self.target_chars for char in new_target_chars) # Don't animate null sorces
return
source_is_new = all(char in self.source_pieces for char in new_source_pieces)
target_is_new = all(char in self.target_pieces for char in new_target_pieces)
if source_is_new and target_is_new: if source_is_new and target_is_new:
self.anims.append(transform_type( transform_type = self.mismatch_animation
source, target, **self.anim_config if source.has_same_shape_as(target):
)) transform_type = self.match_animation
for char in new_source_chars: self.anims.append(transform_type(source, target, **self.anim_config))
self.source_chars.remove(char) for char in new_source_pieces:
for char in new_target_chars: self.source_pieces.remove(char)
self.target_chars.remove(char) for char in new_target_pieces:
self.target_pieces.remove(char)
def find_pairs_with_matching_shapes(self, chars1, chars2) -> list[tuple[VMobject, VMobject]]: def find_pairs_with_matching_shapes(
self,
chars1: list[Mobject],
chars2: list[Mobject]
) -> list[tuple[Mobject, Mobject]]:
result = [] result = []
for char1, char2 in it.product(chars1, chars2): for char1, char2 in it.product(chars1, chars2):
if char1.has_same_shape_as(char2): if char1.has_same_shape_as(char2):
@ -242,6 +114,41 @@ class TransformMatchingStrings(AnimationGroup):
scene.add(self.target) scene.add(self.target)
class TransformMatchingShapes(TransformMatchingParts):
"""Alias for TransformMatchingParts"""
pass
class TransformMatchingStrings(TransformMatchingParts):
def __init__(
self,
source: StringMobject,
target: StringMobject,
matched_keys: Iterable[str] = [],
key_map: dict[str, str] = dict(),
**kwargs,
):
matched_pairs = [
*[(source[key], target[key]) for key in matched_keys],
*[(source[key1], target[key2]) for key1, key2 in key_map.items()],
*[
(source[substr], target[substr])
for substr in [
*source.get_specified_substrings(),
*target.get_specified_substrings(),
*source.get_string(), # TODO, have this return symbols, not characters
*target.get_string(),
]
]
]
super().__init__(
source, target,
matched_pairs=matched_pairs,
group_type=VGroup,
**kwargs,
)
class TransformMatchingTex(TransformMatchingStrings): class TransformMatchingTex(TransformMatchingStrings):
"""Alias for TransformMatchingStrings""" """Alias for TransformMatchingStrings"""
pass pass