mirror of
https://github.com/3b1b/manim.git
synced 2025-09-01 00:48:45 +00:00
Write new TransformMatchingStrings
This commit is contained in:
parent
c7ba775845
commit
926f3515bf
1 changed files with 108 additions and 118 deletions
|
@ -4,16 +4,19 @@ import itertools as it
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from manimlib.animation.animation import Animation
|
||||||
from manimlib.animation.composition import AnimationGroup
|
from manimlib.animation.composition import AnimationGroup
|
||||||
from manimlib.animation.fading import FadeInFromPoint
|
from manimlib.animation.fading import FadeInFromPoint
|
||||||
from manimlib.animation.fading import FadeOutToPoint
|
from manimlib.animation.fading import FadeOutToPoint
|
||||||
from manimlib.animation.fading import FadeTransformPieces
|
from manimlib.animation.fading import FadeTransformPieces
|
||||||
|
from manimlib.animation.fading import FadeTransform
|
||||||
from manimlib.animation.transform import ReplacementTransform
|
from manimlib.animation.transform import ReplacementTransform
|
||||||
from manimlib.animation.transform import Transform
|
from manimlib.animation.transform import Transform
|
||||||
from manimlib.mobject.mobject import Mobject
|
from manimlib.mobject.mobject import Mobject
|
||||||
from manimlib.mobject.mobject import Group
|
from manimlib.mobject.mobject import Group
|
||||||
from manimlib.mobject.svg.string_mobject import StringMobject
|
from manimlib.mobject.svg.string_mobject import StringMobject
|
||||||
from manimlib.mobject.svg.old_tex_mobject import OldTex
|
from manimlib.mobject.svg.old_tex_mobject import OldTex
|
||||||
|
from manimlib.mobject.svg.tex_mobject import Tex
|
||||||
from manimlib.mobject.types.vectorized_mobject import VGroup
|
from manimlib.mobject.types.vectorized_mobject import VGroup
|
||||||
from manimlib.mobject.types.vectorized_mobject import VMobject
|
from manimlib.mobject.types.vectorized_mobject import VMobject
|
||||||
|
|
||||||
|
@ -93,8 +96,8 @@ class TransformMatchingParts(AnimationGroup):
|
||||||
self.to_remove = mobject
|
self.to_remove = mobject
|
||||||
self.to_add = target_mobject
|
self.to_add = target_mobject
|
||||||
|
|
||||||
def get_shape_map(self, mobject: Mobject) -> dict[int, VGroup]:
|
def get_shape_map(self, mobject: Mobject) -> dict[int | str, VGroup]:
|
||||||
shape_map: dict[int, VGroup] = {}
|
shape_map: dict[int | str, VGroup] = {}
|
||||||
for sm in self.get_mobject_parts(mobject):
|
for sm in self.get_mobject_parts(mobject):
|
||||||
key = self.get_mobject_key(sm)
|
key = self.get_mobject_key(sm)
|
||||||
if key not in shape_map:
|
if key not in shape_map:
|
||||||
|
@ -138,128 +141,115 @@ class TransformMatchingShapes(TransformMatchingParts):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
class TransformMatchingTex(TransformMatchingParts):
|
|
||||||
mobject_type: type = OldTex
|
|
||||||
group_type: type = VGroup
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_mobject_parts(mobject: OldTex) -> list[SingleStringTex]:
|
|
||||||
return mobject.submobjects
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_mobject_key(mobject: OldTex) -> str:
|
|
||||||
return mobject.get_tex()
|
|
||||||
|
|
||||||
|
|
||||||
class TransformMatchingStrings(AnimationGroup):
|
class TransformMatchingStrings(AnimationGroup):
|
||||||
def __init__(self,
|
def __init__(
|
||||||
|
self,
|
||||||
source: StringMobject,
|
source: StringMobject,
|
||||||
target: StringMobject,
|
target: StringMobject,
|
||||||
key_map: dict = {},
|
matched_keys: list[str] | None = None,
|
||||||
transform_mismatches: bool = False,
|
key_map: dict[str, str] | None = None,
|
||||||
**kwargs
|
match_animation: type = Transform,
|
||||||
|
mismatch_animation: type = Transform,
|
||||||
|
run_time=2,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
assert isinstance(source, StringMobject)
|
self.source = source
|
||||||
assert isinstance(target, StringMobject)
|
self.target = target
|
||||||
|
matched_keys = matched_keys or list()
|
||||||
|
key_map = key_map or dict()
|
||||||
|
self.anim_config = dict(run_time=run_time, **kwargs)
|
||||||
|
|
||||||
def get_matched_indices_lists(*part_items_list):
|
# We will progressively build up a list of transforms
|
||||||
part_items_list_len = len(part_items_list)
|
# from characters in source to those in target. These
|
||||||
indexed_part_items = sorted(it.chain(*[
|
# two lists keep track of which characters are accounted
|
||||||
[
|
# for so far
|
||||||
(substr, items_index, indices_list)
|
self.source_chars = source.family_members_with_points()
|
||||||
for substr, indices_list in part_items
|
self.target_chars = target.family_members_with_points()
|
||||||
]
|
self.anims = []
|
||||||
for items_index, part_items in enumerate(part_items_list)
|
|
||||||
]))
|
|
||||||
grouped_part_items = [
|
|
||||||
(substr, [
|
|
||||||
[indices_lists for _, _, indices_lists in grouper_2]
|
|
||||||
for _, grouper_2 in it.groupby(
|
|
||||||
grouper_1, key=lambda t: t[1]
|
|
||||||
)
|
|
||||||
])
|
|
||||||
for substr, grouper_1 in it.groupby(
|
|
||||||
indexed_part_items, key=lambda t: t[0]
|
|
||||||
)
|
|
||||||
]
|
|
||||||
return [
|
|
||||||
tuple(indices_lists_list)
|
|
||||||
for _, indices_lists_list in sorted(filter(
|
|
||||||
lambda t: t[0] and len(t[1]) == part_items_list_len,
|
|
||||||
grouped_part_items
|
|
||||||
), key=lambda t: len(t[0]), reverse=True)
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_filtered_indices_lists(indices_lists, used_indices):
|
# Start by pairing all matched keys specifically passed in
|
||||||
result = []
|
for key in matched_keys:
|
||||||
used = []
|
self.add_transform(
|
||||||
for indices_list in indices_lists:
|
source.select_parts(key),
|
||||||
if not all(
|
target.select_parts(key),
|
||||||
index not in used_indices and index not in used
|
match_animation
|
||||||
for index in indices_list
|
)
|
||||||
):
|
# Then pair those based on the key map
|
||||||
continue
|
for key, value in key_map.items():
|
||||||
result.append(indices_list)
|
self.add_transform(
|
||||||
used.extend(indices_list)
|
source.select_parts(key),
|
||||||
return result, used
|
target.select_parts(value),
|
||||||
|
mismatch_animation
|
||||||
anim_class_items = [
|
)
|
||||||
(ReplacementTransform, [
|
# Now pair by substrings which were isolated in StringMobject
|
||||||
(
|
# initializations
|
||||||
source.get_submob_indices_lists_by_selector(k),
|
specified_substrings = [
|
||||||
target.get_submob_indices_lists_by_selector(v)
|
*source.get_specified_substrings(),
|
||||||
)
|
*target.get_specified_substrings()
|
||||||
for k, v in key_map.items()
|
|
||||||
]),
|
|
||||||
(FadeTransformPieces, get_matched_indices_lists(
|
|
||||||
source.get_specified_part_items(),
|
|
||||||
target.get_specified_part_items()
|
|
||||||
)),
|
|
||||||
(FadeTransformPieces, get_matched_indices_lists(
|
|
||||||
source.get_group_part_items(),
|
|
||||||
target.get_group_part_items()
|
|
||||||
))
|
|
||||||
]
|
]
|
||||||
|
for key in specified_substrings:
|
||||||
anims = []
|
self.add_transform(
|
||||||
source_used_indices = []
|
source.select_parts(key),
|
||||||
target_used_indices = []
|
target.select_parts(key),
|
||||||
for anim_class, pairs in anim_class_items:
|
match_animation
|
||||||
for source_indices_lists, target_indices_lists in pairs:
|
|
||||||
source_filtered, source_used = get_filtered_indices_lists(
|
|
||||||
source_indices_lists, source_used_indices
|
|
||||||
)
|
|
||||||
target_filtered, target_used = get_filtered_indices_lists(
|
|
||||||
target_indices_lists, target_used_indices
|
|
||||||
)
|
|
||||||
if not source_filtered or not target_filtered:
|
|
||||||
continue
|
|
||||||
anims.append(anim_class(
|
|
||||||
source.build_parts_from_indices_lists(source_filtered),
|
|
||||||
target.build_parts_from_indices_lists(target_filtered),
|
|
||||||
**kwargs
|
|
||||||
))
|
|
||||||
source_used_indices.extend(source_used)
|
|
||||||
target_used_indices.extend(target_used)
|
|
||||||
|
|
||||||
rest_source = VGroup(*(
|
|
||||||
submob for index, submob in enumerate(source.submobjects)
|
|
||||||
if index not in source_used_indices
|
|
||||||
))
|
|
||||||
rest_target = VGroup(*(
|
|
||||||
submob for index, submob in enumerate(target.submobjects)
|
|
||||||
if index not in target_used_indices
|
|
||||||
))
|
|
||||||
if transform_mismatches:
|
|
||||||
anims.append(
|
|
||||||
ReplacementTransform(rest_source, rest_target, **kwargs)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
anims.append(
|
|
||||||
FadeOutToPoint(rest_source, target.get_center(), **kwargs)
|
|
||||||
)
|
|
||||||
anims.append(
|
|
||||||
FadeInFromPoint(rest_target, source.get_center(), **kwargs)
|
|
||||||
)
|
)
|
||||||
|
# Match any pairs with the same shape
|
||||||
|
pairs = self.find_pairs_with_matching_shapes(self.source_chars, self.target_chars)
|
||||||
|
for source_char, target_char in pairs:
|
||||||
|
self.add_transform(source_char, target_char, match_animation)
|
||||||
|
# Finally, account for mismatches
|
||||||
|
for source_char in self.source_chars:
|
||||||
|
self.anims.append(FadeOutToPoint(
|
||||||
|
source_char, target.get_center(),
|
||||||
|
**self.anim_config
|
||||||
|
))
|
||||||
|
for target_char in self.target_chars:
|
||||||
|
self.anims.append(FadeInFromPoint(
|
||||||
|
target_char, source.get_center(),
|
||||||
|
**self.anim_config
|
||||||
|
))
|
||||||
|
super().__init__(*self.anims)
|
||||||
|
|
||||||
super().__init__(*anims)
|
def add_transform(
|
||||||
|
self,
|
||||||
|
source: VMobject,
|
||||||
|
target: VMobject,
|
||||||
|
transform_type: type = Transform,
|
||||||
|
):
|
||||||
|
new_source_chars = source.family_members_with_points()
|
||||||
|
new_target_chars = target.family_members_with_points()
|
||||||
|
source_is_new = all(char in self.source_chars for char in new_source_chars)
|
||||||
|
target_is_new = all(char in self.target_chars for char in new_target_chars)
|
||||||
|
if source_is_new and target_is_new:
|
||||||
|
self.anims.append(transform_type(
|
||||||
|
source, target, **self.anim_config
|
||||||
|
))
|
||||||
|
for char in new_source_chars:
|
||||||
|
self.source_chars.remove(char)
|
||||||
|
for char in new_target_chars:
|
||||||
|
self.target_chars.remove(char)
|
||||||
|
|
||||||
|
def find_pairs_with_matching_shapes(self, chars1, chars2) -> list[tuple[VMobject, VMobject]]:
|
||||||
|
for char in (*chars1, *chars2):
|
||||||
|
char.save_state()
|
||||||
|
char.set_height(1)
|
||||||
|
char.center()
|
||||||
|
result = []
|
||||||
|
for char1, char2 in it.product(chars1, chars2):
|
||||||
|
p1 = char1.get_points()
|
||||||
|
p2 = char2.get_points()
|
||||||
|
if len(p1) == len(p2) and np.isclose(p1, p2 , atol=1e-1).all():
|
||||||
|
result.append((char1, char2))
|
||||||
|
for char in (*chars1, *chars2):
|
||||||
|
char.restore()
|
||||||
|
return result
|
||||||
|
|
||||||
|
def clean_up_from_scene(self, scene: Scene) -> None:
|
||||||
|
super().clean_up_from_scene(scene)
|
||||||
|
scene.remove(self.mobject)
|
||||||
|
scene.add(self.target)
|
||||||
|
|
||||||
|
|
||||||
|
class TransformMatchingTex(TransformMatchingStrings):
|
||||||
|
"""Alias for TransformMatchingStrings"""
|
||||||
|
pass
|
||||||
|
|
Loading…
Add table
Reference in a new issue