diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index d82c8874..de32bf89 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -4,16 +4,19 @@ import itertools as it import numpy as np +from manimlib.animation.animation import Animation from manimlib.animation.composition import AnimationGroup from manimlib.animation.fading import FadeInFromPoint from manimlib.animation.fading import FadeOutToPoint from manimlib.animation.fading import FadeTransformPieces +from manimlib.animation.fading import FadeTransform from manimlib.animation.transform import ReplacementTransform from manimlib.animation.transform import Transform from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Group from manimlib.mobject.svg.string_mobject import StringMobject from manimlib.mobject.svg.old_tex_mobject import OldTex +from manimlib.mobject.svg.tex_mobject import Tex from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VMobject @@ -93,8 +96,8 @@ class TransformMatchingParts(AnimationGroup): self.to_remove = mobject self.to_add = target_mobject - def get_shape_map(self, mobject: Mobject) -> dict[int, VGroup]: - shape_map: dict[int, VGroup] = {} + def get_shape_map(self, mobject: Mobject) -> dict[int | str, VGroup]: + shape_map: dict[int | str, VGroup] = {} for sm in self.get_mobject_parts(mobject): key = self.get_mobject_key(sm) if key not in shape_map: @@ -138,128 +141,115 @@ class TransformMatchingShapes(TransformMatchingParts): return result -class TransformMatchingTex(TransformMatchingParts): - mobject_type: type = OldTex - group_type: type = VGroup - - @staticmethod - def get_mobject_parts(mobject: OldTex) -> list[SingleStringTex]: - return mobject.submobjects - - @staticmethod - def get_mobject_key(mobject: OldTex) -> str: - return mobject.get_tex() - - class TransformMatchingStrings(AnimationGroup): - def __init__(self, + def __init__( + self, source: StringMobject, target: StringMobject, - key_map: dict = {}, - transform_mismatches: bool = False, - **kwargs + matched_keys: list[str] | None = None, + key_map: dict[str, str] | None = None, + match_animation: type = Transform, + mismatch_animation: type = Transform, + run_time=2, + **kwargs, ): - assert isinstance(source, StringMobject) - assert isinstance(target, StringMobject) + self.source = source + self.target = target + matched_keys = matched_keys or list() + key_map = key_map or dict() + self.anim_config = dict(run_time=run_time, **kwargs) - def get_matched_indices_lists(*part_items_list): - part_items_list_len = len(part_items_list) - indexed_part_items = sorted(it.chain(*[ - [ - (substr, items_index, indices_list) - for substr, indices_list in part_items - ] - for items_index, part_items in enumerate(part_items_list) - ])) - grouped_part_items = [ - (substr, [ - [indices_lists for _, _, indices_lists in grouper_2] - for _, grouper_2 in it.groupby( - grouper_1, key=lambda t: t[1] - ) - ]) - for substr, grouper_1 in it.groupby( - indexed_part_items, key=lambda t: t[0] - ) - ] - return [ - tuple(indices_lists_list) - for _, indices_lists_list in sorted(filter( - lambda t: t[0] and len(t[1]) == part_items_list_len, - grouped_part_items - ), key=lambda t: len(t[0]), reverse=True) - ] + # We will progressively build up a list of transforms + # from characters in source to those in target. These + # two lists keep track of which characters are accounted + # for so far + self.source_chars = source.family_members_with_points() + self.target_chars = target.family_members_with_points() + self.anims = [] - def get_filtered_indices_lists(indices_lists, used_indices): - result = [] - used = [] - for indices_list in indices_lists: - if not all( - index not in used_indices and index not in used - for index in indices_list - ): - continue - result.append(indices_list) - used.extend(indices_list) - return result, used - - anim_class_items = [ - (ReplacementTransform, [ - ( - source.get_submob_indices_lists_by_selector(k), - target.get_submob_indices_lists_by_selector(v) - ) - for k, v in key_map.items() - ]), - (FadeTransformPieces, get_matched_indices_lists( - source.get_specified_part_items(), - target.get_specified_part_items() - )), - (FadeTransformPieces, get_matched_indices_lists( - source.get_group_part_items(), - target.get_group_part_items() - )) + # Start by pairing all matched keys specifically passed in + for key in matched_keys: + self.add_transform( + source.select_parts(key), + target.select_parts(key), + match_animation + ) + # Then pair those based on the key map + for key, value in key_map.items(): + self.add_transform( + source.select_parts(key), + target.select_parts(value), + mismatch_animation + ) + # Now pair by substrings which were isolated in StringMobject + # initializations + specified_substrings = [ + *source.get_specified_substrings(), + *target.get_specified_substrings() ] - - anims = [] - source_used_indices = [] - target_used_indices = [] - for anim_class, pairs in anim_class_items: - for source_indices_lists, target_indices_lists in pairs: - source_filtered, source_used = get_filtered_indices_lists( - source_indices_lists, source_used_indices - ) - target_filtered, target_used = get_filtered_indices_lists( - target_indices_lists, target_used_indices - ) - if not source_filtered or not target_filtered: - continue - anims.append(anim_class( - source.build_parts_from_indices_lists(source_filtered), - target.build_parts_from_indices_lists(target_filtered), - **kwargs - )) - source_used_indices.extend(source_used) - target_used_indices.extend(target_used) - - rest_source = VGroup(*( - submob for index, submob in enumerate(source.submobjects) - if index not in source_used_indices - )) - rest_target = VGroup(*( - submob for index, submob in enumerate(target.submobjects) - if index not in target_used_indices - )) - if transform_mismatches: - anims.append( - ReplacementTransform(rest_source, rest_target, **kwargs) - ) - else: - anims.append( - FadeOutToPoint(rest_source, target.get_center(), **kwargs) - ) - anims.append( - FadeInFromPoint(rest_target, source.get_center(), **kwargs) + for key in specified_substrings: + self.add_transform( + source.select_parts(key), + target.select_parts(key), + match_animation ) + # Match any pairs with the same shape + pairs = self.find_pairs_with_matching_shapes(self.source_chars, self.target_chars) + for source_char, target_char in pairs: + self.add_transform(source_char, target_char, match_animation) + # Finally, account for mismatches + for source_char in self.source_chars: + self.anims.append(FadeOutToPoint( + source_char, target.get_center(), + **self.anim_config + )) + for target_char in self.target_chars: + self.anims.append(FadeInFromPoint( + target_char, source.get_center(), + **self.anim_config + )) + super().__init__(*self.anims) - super().__init__(*anims) + def add_transform( + self, + source: VMobject, + target: VMobject, + transform_type: type = Transform, + ): + new_source_chars = source.family_members_with_points() + new_target_chars = target.family_members_with_points() + source_is_new = all(char in self.source_chars for char in new_source_chars) + target_is_new = all(char in self.target_chars for char in new_target_chars) + if source_is_new and target_is_new: + self.anims.append(transform_type( + source, target, **self.anim_config + )) + for char in new_source_chars: + self.source_chars.remove(char) + for char in new_target_chars: + self.target_chars.remove(char) + + def find_pairs_with_matching_shapes(self, chars1, chars2) -> list[tuple[VMobject, VMobject]]: + for char in (*chars1, *chars2): + char.save_state() + char.set_height(1) + char.center() + result = [] + for char1, char2 in it.product(chars1, chars2): + p1 = char1.get_points() + p2 = char2.get_points() + if len(p1) == len(p2) and np.isclose(p1, p2 , atol=1e-1).all(): + result.append((char1, char2)) + for char in (*chars1, *chars2): + char.restore() + return result + + def clean_up_from_scene(self, scene: Scene) -> None: + super().clean_up_from_scene(scene) + scene.remove(self.mobject) + scene.add(self.target) + + +class TransformMatchingTex(TransformMatchingStrings): + """Alias for TransformMatchingStrings""" + pass