mirror of
https://github.com/3b1b/manim.git
synced 2025-09-01 00:48:45 +00:00
Write new TransformMatchingStrings
This commit is contained in:
parent
c7ba775845
commit
926f3515bf
1 changed files with 108 additions and 118 deletions
|
@ -4,16 +4,19 @@ import itertools as it
|
|||
|
||||
import numpy as np
|
||||
|
||||
from manimlib.animation.animation import Animation
|
||||
from manimlib.animation.composition import AnimationGroup
|
||||
from manimlib.animation.fading import FadeInFromPoint
|
||||
from manimlib.animation.fading import FadeOutToPoint
|
||||
from manimlib.animation.fading import FadeTransformPieces
|
||||
from manimlib.animation.fading import FadeTransform
|
||||
from manimlib.animation.transform import ReplacementTransform
|
||||
from manimlib.animation.transform import Transform
|
||||
from manimlib.mobject.mobject import Mobject
|
||||
from manimlib.mobject.mobject import Group
|
||||
from manimlib.mobject.svg.string_mobject import StringMobject
|
||||
from manimlib.mobject.svg.old_tex_mobject import OldTex
|
||||
from manimlib.mobject.svg.tex_mobject import Tex
|
||||
from manimlib.mobject.types.vectorized_mobject import VGroup
|
||||
from manimlib.mobject.types.vectorized_mobject import VMobject
|
||||
|
||||
|
@ -93,8 +96,8 @@ class TransformMatchingParts(AnimationGroup):
|
|||
self.to_remove = mobject
|
||||
self.to_add = target_mobject
|
||||
|
||||
def get_shape_map(self, mobject: Mobject) -> dict[int, VGroup]:
|
||||
shape_map: dict[int, VGroup] = {}
|
||||
def get_shape_map(self, mobject: Mobject) -> dict[int | str, VGroup]:
|
||||
shape_map: dict[int | str, VGroup] = {}
|
||||
for sm in self.get_mobject_parts(mobject):
|
||||
key = self.get_mobject_key(sm)
|
||||
if key not in shape_map:
|
||||
|
@ -138,128 +141,115 @@ class TransformMatchingShapes(TransformMatchingParts):
|
|||
return result
|
||||
|
||||
|
||||
class TransformMatchingTex(TransformMatchingParts):
|
||||
mobject_type: type = OldTex
|
||||
group_type: type = VGroup
|
||||
|
||||
@staticmethod
|
||||
def get_mobject_parts(mobject: OldTex) -> list[SingleStringTex]:
|
||||
return mobject.submobjects
|
||||
|
||||
@staticmethod
|
||||
def get_mobject_key(mobject: OldTex) -> str:
|
||||
return mobject.get_tex()
|
||||
|
||||
|
||||
class TransformMatchingStrings(AnimationGroup):
|
||||
def __init__(self,
|
||||
def __init__(
|
||||
self,
|
||||
source: StringMobject,
|
||||
target: StringMobject,
|
||||
key_map: dict = {},
|
||||
transform_mismatches: bool = False,
|
||||
**kwargs
|
||||
matched_keys: list[str] | None = None,
|
||||
key_map: dict[str, str] | None = None,
|
||||
match_animation: type = Transform,
|
||||
mismatch_animation: type = Transform,
|
||||
run_time=2,
|
||||
**kwargs,
|
||||
):
|
||||
assert isinstance(source, StringMobject)
|
||||
assert isinstance(target, StringMobject)
|
||||
self.source = source
|
||||
self.target = target
|
||||
matched_keys = matched_keys or list()
|
||||
key_map = key_map or dict()
|
||||
self.anim_config = dict(run_time=run_time, **kwargs)
|
||||
|
||||
def get_matched_indices_lists(*part_items_list):
|
||||
part_items_list_len = len(part_items_list)
|
||||
indexed_part_items = sorted(it.chain(*[
|
||||
[
|
||||
(substr, items_index, indices_list)
|
||||
for substr, indices_list in part_items
|
||||
]
|
||||
for items_index, part_items in enumerate(part_items_list)
|
||||
]))
|
||||
grouped_part_items = [
|
||||
(substr, [
|
||||
[indices_lists for _, _, indices_lists in grouper_2]
|
||||
for _, grouper_2 in it.groupby(
|
||||
grouper_1, key=lambda t: t[1]
|
||||
)
|
||||
])
|
||||
for substr, grouper_1 in it.groupby(
|
||||
indexed_part_items, key=lambda t: t[0]
|
||||
)
|
||||
]
|
||||
return [
|
||||
tuple(indices_lists_list)
|
||||
for _, indices_lists_list in sorted(filter(
|
||||
lambda t: t[0] and len(t[1]) == part_items_list_len,
|
||||
grouped_part_items
|
||||
), key=lambda t: len(t[0]), reverse=True)
|
||||
]
|
||||
# We will progressively build up a list of transforms
|
||||
# from characters in source to those in target. These
|
||||
# two lists keep track of which characters are accounted
|
||||
# for so far
|
||||
self.source_chars = source.family_members_with_points()
|
||||
self.target_chars = target.family_members_with_points()
|
||||
self.anims = []
|
||||
|
||||
def get_filtered_indices_lists(indices_lists, used_indices):
|
||||
result = []
|
||||
used = []
|
||||
for indices_list in indices_lists:
|
||||
if not all(
|
||||
index not in used_indices and index not in used
|
||||
for index in indices_list
|
||||
):
|
||||
continue
|
||||
result.append(indices_list)
|
||||
used.extend(indices_list)
|
||||
return result, used
|
||||
|
||||
anim_class_items = [
|
||||
(ReplacementTransform, [
|
||||
(
|
||||
source.get_submob_indices_lists_by_selector(k),
|
||||
target.get_submob_indices_lists_by_selector(v)
|
||||
)
|
||||
for k, v in key_map.items()
|
||||
]),
|
||||
(FadeTransformPieces, get_matched_indices_lists(
|
||||
source.get_specified_part_items(),
|
||||
target.get_specified_part_items()
|
||||
)),
|
||||
(FadeTransformPieces, get_matched_indices_lists(
|
||||
source.get_group_part_items(),
|
||||
target.get_group_part_items()
|
||||
))
|
||||
# Start by pairing all matched keys specifically passed in
|
||||
for key in matched_keys:
|
||||
self.add_transform(
|
||||
source.select_parts(key),
|
||||
target.select_parts(key),
|
||||
match_animation
|
||||
)
|
||||
# Then pair those based on the key map
|
||||
for key, value in key_map.items():
|
||||
self.add_transform(
|
||||
source.select_parts(key),
|
||||
target.select_parts(value),
|
||||
mismatch_animation
|
||||
)
|
||||
# Now pair by substrings which were isolated in StringMobject
|
||||
# initializations
|
||||
specified_substrings = [
|
||||
*source.get_specified_substrings(),
|
||||
*target.get_specified_substrings()
|
||||
]
|
||||
|
||||
anims = []
|
||||
source_used_indices = []
|
||||
target_used_indices = []
|
||||
for anim_class, pairs in anim_class_items:
|
||||
for source_indices_lists, target_indices_lists in pairs:
|
||||
source_filtered, source_used = get_filtered_indices_lists(
|
||||
source_indices_lists, source_used_indices
|
||||
)
|
||||
target_filtered, target_used = get_filtered_indices_lists(
|
||||
target_indices_lists, target_used_indices
|
||||
)
|
||||
if not source_filtered or not target_filtered:
|
||||
continue
|
||||
anims.append(anim_class(
|
||||
source.build_parts_from_indices_lists(source_filtered),
|
||||
target.build_parts_from_indices_lists(target_filtered),
|
||||
**kwargs
|
||||
))
|
||||
source_used_indices.extend(source_used)
|
||||
target_used_indices.extend(target_used)
|
||||
|
||||
rest_source = VGroup(*(
|
||||
submob for index, submob in enumerate(source.submobjects)
|
||||
if index not in source_used_indices
|
||||
))
|
||||
rest_target = VGroup(*(
|
||||
submob for index, submob in enumerate(target.submobjects)
|
||||
if index not in target_used_indices
|
||||
))
|
||||
if transform_mismatches:
|
||||
anims.append(
|
||||
ReplacementTransform(rest_source, rest_target, **kwargs)
|
||||
)
|
||||
else:
|
||||
anims.append(
|
||||
FadeOutToPoint(rest_source, target.get_center(), **kwargs)
|
||||
)
|
||||
anims.append(
|
||||
FadeInFromPoint(rest_target, source.get_center(), **kwargs)
|
||||
for key in specified_substrings:
|
||||
self.add_transform(
|
||||
source.select_parts(key),
|
||||
target.select_parts(key),
|
||||
match_animation
|
||||
)
|
||||
# Match any pairs with the same shape
|
||||
pairs = self.find_pairs_with_matching_shapes(self.source_chars, self.target_chars)
|
||||
for source_char, target_char in pairs:
|
||||
self.add_transform(source_char, target_char, match_animation)
|
||||
# Finally, account for mismatches
|
||||
for source_char in self.source_chars:
|
||||
self.anims.append(FadeOutToPoint(
|
||||
source_char, target.get_center(),
|
||||
**self.anim_config
|
||||
))
|
||||
for target_char in self.target_chars:
|
||||
self.anims.append(FadeInFromPoint(
|
||||
target_char, source.get_center(),
|
||||
**self.anim_config
|
||||
))
|
||||
super().__init__(*self.anims)
|
||||
|
||||
super().__init__(*anims)
|
||||
def add_transform(
|
||||
self,
|
||||
source: VMobject,
|
||||
target: VMobject,
|
||||
transform_type: type = Transform,
|
||||
):
|
||||
new_source_chars = source.family_members_with_points()
|
||||
new_target_chars = target.family_members_with_points()
|
||||
source_is_new = all(char in self.source_chars for char in new_source_chars)
|
||||
target_is_new = all(char in self.target_chars for char in new_target_chars)
|
||||
if source_is_new and target_is_new:
|
||||
self.anims.append(transform_type(
|
||||
source, target, **self.anim_config
|
||||
))
|
||||
for char in new_source_chars:
|
||||
self.source_chars.remove(char)
|
||||
for char in new_target_chars:
|
||||
self.target_chars.remove(char)
|
||||
|
||||
def find_pairs_with_matching_shapes(self, chars1, chars2) -> list[tuple[VMobject, VMobject]]:
|
||||
for char in (*chars1, *chars2):
|
||||
char.save_state()
|
||||
char.set_height(1)
|
||||
char.center()
|
||||
result = []
|
||||
for char1, char2 in it.product(chars1, chars2):
|
||||
p1 = char1.get_points()
|
||||
p2 = char2.get_points()
|
||||
if len(p1) == len(p2) and np.isclose(p1, p2 , atol=1e-1).all():
|
||||
result.append((char1, char2))
|
||||
for char in (*chars1, *chars2):
|
||||
char.restore()
|
||||
return result
|
||||
|
||||
def clean_up_from_scene(self, scene: Scene) -> None:
|
||||
super().clean_up_from_scene(scene)
|
||||
scene.remove(self.mobject)
|
||||
scene.add(self.target)
|
||||
|
||||
|
||||
class TransformMatchingTex(TransformMatchingStrings):
|
||||
"""Alias for TransformMatchingStrings"""
|
||||
pass
|
||||
|
|
Loading…
Add table
Reference in a new issue