diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index 452333fd..f69f79c1 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -13,6 +13,8 @@ from manimlib.animation.transform import Transform from manimlib.mobject.mobject import Mobject from manimlib.mobject.mobject import Group from manimlib.mobject.svg.string_mobject import StringMobject +from manimlib.mobject.svg.tex_mobject import Tex +from manimlib.mobject.svg.text_mobject import TEXT_MOB_SCALE_FACTOR from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.utils.config_ops import digest_config @@ -26,16 +28,18 @@ if TYPE_CHECKING: class TransformMatchingParts(AnimationGroup): - CONFIG = { - "mobject_type": Mobject, - "group_type": Group, - "transform_mismatches": False, - "fade_transform_mismatches": False, - "key_map": dict(), - } + mobject_type: type = Mobject + group_type: type = Group - def __init__(self, mobject: Mobject, target_mobject: Mobject, **kwargs): - digest_config(self, kwargs) + def __init__( + self, + mobject: Mobject, + target_mobject: Mobject, + transform_mismatches: bool = False, + fade_transform_mismatches: bool = False, + key_map: dict = dict(), + **kwargs + ): assert(isinstance(mobject, self.mobject_type)) assert(isinstance(target_mobject, self.mobject_type)) source_map = self.get_shape_map(mobject) @@ -55,7 +59,7 @@ class TransformMatchingParts(AnimationGroup): # into another despite not matching by using key_map key_mapped_source = self.group_type() key_mapped_target = self.group_type() - for key1, key2 in self.key_map.items(): + for key1, key2 in key_map.items(): if key1 in source_map and key2 in target_map: key_mapped_source.add(source_map[key1]) key_mapped_target.add(target_map[key2]) @@ -74,9 +78,9 @@ class TransformMatchingParts(AnimationGroup): for key in set(target_map).difference(source_map): fade_target.add(target_map[key]) - if self.transform_mismatches: + if transform_mismatches: anims.append(Transform(fade_source.copy(), fade_target, **kwargs)) - if self.fade_transform_mismatches: + if fade_transform_mismatches: anims.append(FadeTransformPieces(fade_source, fade_target, **kwargs)) else: anims.append(FadeOutToPoint( @@ -119,10 +123,8 @@ class TransformMatchingParts(AnimationGroup): class TransformMatchingShapes(TransformMatchingParts): - CONFIG = { - "mobject_type": VMobject, - "group_type": VGroup, - } + mobject_type: type = VMobject + group_type: type = VGroup @staticmethod def get_mobject_parts(mobject: VMobject) -> list[VMobject]: @@ -139,10 +141,8 @@ class TransformMatchingShapes(TransformMatchingParts): class TransformMatchingTex(TransformMatchingParts): - CONFIG = { - "mobject_type": VMobject, - "group_type": VGroup, - } + mobject_type: type = Tex + group_type: type = VGroup @staticmethod def get_mobject_parts(mobject: Tex) -> list[SingleStringTex]: @@ -154,14 +154,11 @@ class TransformMatchingTex(TransformMatchingParts): class TransformMatchingStrings(AnimationGroup): - CONFIG = { - "key_map": {}, - "transform_mismatches": False, - } - def __init__(self, source: StringMobject, target: StringMobject, + key_map: dict = {}, + transform_mismatches: bool = False, **kwargs ): digest_config(self, kwargs) @@ -215,7 +212,7 @@ class TransformMatchingStrings(AnimationGroup): source.get_submob_indices_lists_by_selector(k), target.get_submob_indices_lists_by_selector(v) ) - for k, v in self.key_map.items() + for k, v in key_map.items() ]), (FadeTransformPieces, get_matched_indices_lists( source.get_specified_part_items(), @@ -248,15 +245,15 @@ class TransformMatchingStrings(AnimationGroup): source_used_indices.extend(source_used) target_used_indices.extend(target_used) - rest_source = VGroup(*[ + rest_source = VGroup(*( submob for index, submob in enumerate(source.submobjects) if index not in source_used_indices - ]) - rest_target = VGroup(*[ + )) + rest_target = VGroup(*( submob for index, submob in enumerate(target.submobjects) if index not in target_used_indices - ]) - if self.transform_mismatches: + )) + if transform_mismatches: anims.append( ReplacementTransform(rest_source, rest_target, **kwargs) )