diff --git a/example_scenes.py b/example_scenes.py index f9268aae..28b75e4e 100644 --- a/example_scenes.py +++ b/example_scenes.py @@ -184,7 +184,7 @@ class TexTransformExample(Scene): # line up. If it's not specified, the animation # will try its best, but may not quite give the # intended effect - matched_keys=["A^2", "B^2", "C^2", "="], + matched_keys=["A^2", "B^2", "C^2"], # When you want a substring from the source # to go to a non-equal substring from the target, # use the key map. @@ -193,12 +193,14 @@ class TexTransformExample(Scene): ), ) self.wait() - self.play(TransformMatchingStrings(lines[1].copy(), lines[2])) + self.play(TransformMatchingStrings( + lines[1].copy(), lines[2], + matched_keys=["A^2"] + )) self.wait() self.play( TransformMatchingStrings( lines[2].copy(), lines[3], - matched_keys=["="], key_map={"2": R"\sqrt"}, path_arc=-30 * DEGREES, ), @@ -217,6 +219,30 @@ class TexTransformExample(Scene): self.wait() self.play(LaggedStartMap(FadeOut, lines, shift=2 * RIGHT)) + # Indexing by substrings like this may not work when + # the order in which Latex draws symbols does not match + # the order in which they show up in the string. + # For example, here the infinity is drawn before the sigma + # so we don't get the desired behavior. + equation = Tex(R"\sum_{n = 1}^\infty \frac{1}{n^2} = \frac{\pi^2}{6}") + self.play(FadeIn(equation)) + self.play(equation[R"\infty"].animate.set_color(RED)) # Doesn't hit the infinity + self.wait() + self.play(FadeOut(equation)) + + # However you can always fix this by explicitly passing in + # a string you might want to isolate later. Also, using + # \over instead of \frac helps to avoid the issue for fractions + equation = Tex( + R"\sum_{n = 1}^\infty {1 \over n^2} = {\pi^2 \over 6}", + # Explicitly mark "\infty" as a substring you might want to access + isolate=[R"\infty"] + ) + self.play(FadeIn(equation)) + self.play(equation[R"\infty"].animate.set_color(RED)) # Got it! + self.wait() + self.play(FadeOut(equation)) + # TransformMatchingShapes will try to line up all pieces of a # source mobject with those of a target, regardless of the # what Mobject type they are. diff --git a/manimlib/animation/fading.py b/manimlib/animation/fading.py index 1ed2996f..c8c29065 100644 --- a/manimlib/animation/fading.py +++ b/manimlib/animation/fading.py @@ -6,6 +6,8 @@ from manimlib.animation.animation import Animation from manimlib.animation.transform import Transform from manimlib.constants import ORIGIN from manimlib.mobject.mobject import Group +from manimlib.mobject.types.vectorized_mobject import VMobject +from manimlib.mobject.types.vectorized_mobject import VGroup from manimlib.utils.bezier import interpolate from manimlib.utils.rate_functions import there_and_back @@ -14,8 +16,8 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Callable from manimlib.mobject.mobject import Mobject - from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.scene.scene import Scene + from manimlib.typing import Vect3 @@ -48,7 +50,7 @@ class FadeOut(Fade): def __init__( self, mobject: Mobject, - shift: np.ndarray = ORIGIN, + shift: Vect3 = ORIGIN, remover: bool = True, final_alpha_value: float = 0.0, # Put it back in original state when done, **kwargs @@ -69,7 +71,7 @@ class FadeOut(Fade): class FadeInFromPoint(FadeIn): - def __init__(self, mobject: Mobject, point: np.ndarray[int, np.dtype[np.float64]], **kwargs): + def __init__(self, mobject: Mobject, point: Vect3, **kwargs): super().__init__( mobject, shift=mobject.get_center() - point, @@ -79,7 +81,7 @@ class FadeInFromPoint(FadeIn): class FadeOutToPoint(FadeOut): - def __init__(self, mobject: Mobject, point: np.ndarray[int, np.dtype[np.float64]], **kwargs): + def __init__(self, mobject: Mobject, point: Vect3, **kwargs): super().__init__( mobject, shift=point - mobject.get_center(), @@ -101,8 +103,12 @@ class FadeTransform(Transform): self.stretch = stretch self.dim_to_match = dim_to_match + group_type = Group + if isinstance(mobject, VMobject) and isinstance(target_mobject, VMobject): + group_type = VGroup + mobject.save_state() - super().__init__(Group(mobject, target_mobject.copy()), **kwargs) + super().__init__(group_type(mobject, target_mobject.copy()), **kwargs) def begin(self) -> None: self.ending_mobject = self.mobject.copy() diff --git a/manimlib/animation/transform_matching_parts.py b/manimlib/animation/transform_matching_parts.py index 6c608cae..c37c16c3 100644 --- a/manimlib/animation/transform_matching_parts.py +++ b/manimlib/animation/transform_matching_parts.py @@ -18,218 +18,93 @@ from manimlib.mobject.types.vectorized_mobject import VMobject from typing import TYPE_CHECKING if TYPE_CHECKING: + from typing import Iterable from manimlib.scene.scene import Scene class TransformMatchingParts(AnimationGroup): - mobject_type: type = Mobject - group_type: type = Group - def __init__( self, - mobject: Mobject, - target_mobject: Mobject, - transform_mismatches: bool = False, - fade_transform_mismatches: bool = False, - key_map: dict | None = None, - **kwargs - ): - assert(isinstance(mobject, self.mobject_type)) - assert(isinstance(target_mobject, self.mobject_type)) - source_map = self.get_shape_map(mobject) - target_map = self.get_shape_map(target_mobject) - key_map = key_map or dict() - - # Create two mobjects whose submobjects all match each other - # according to whatever keys are used for source_map and - # target_map - transform_source = self.group_type() - transform_target = self.group_type() - kwargs["final_alpha_value"] = 0 - for key in set(source_map).intersection(target_map): - transform_source.add(source_map[key]) - transform_target.add(target_map[key]) - anims = [Transform(transform_source, transform_target, **kwargs)] - # User can manually specify when one part should transform - # into another despite not matching by using key_map - key_mapped_source = self.group_type() - key_mapped_target = self.group_type() - for key1, key2 in key_map.items(): - if key1 in source_map and key2 in target_map: - key_mapped_source.add(source_map[key1]) - key_mapped_target.add(target_map[key2]) - source_map.pop(key1, None) - target_map.pop(key2, None) - if len(key_mapped_source) > 0: - anims.append(FadeTransformPieces( - key_mapped_source, - key_mapped_target, - )) - - fade_source = self.group_type() - fade_target = self.group_type() - for key in set(source_map).difference(target_map): - fade_source.add(source_map[key]) - for key in set(target_map).difference(source_map): - fade_target.add(target_map[key]) - - if transform_mismatches: - anims.append(Transform(fade_source.copy(), fade_target, **kwargs)) - if fade_transform_mismatches: - anims.append(FadeTransformPieces(fade_source, fade_target, **kwargs)) - else: - anims.append(FadeOutToPoint( - fade_source, target_mobject.get_center(), **kwargs - )) - anims.append(FadeInFromPoint( - fade_target.copy(), mobject.get_center(), **kwargs - )) - - super().__init__(*anims) - - self.to_remove = mobject - self.to_add = target_mobject - - def get_shape_map(self, mobject: Mobject) -> dict[int | str, VGroup]: - shape_map: dict[int | str, VGroup] = {} - for sm in self.get_mobject_parts(mobject): - key = self.get_mobject_key(sm) - if key not in shape_map: - shape_map[key] = VGroup() - shape_map[key].add(sm) - return shape_map - - def clean_up_from_scene(self, scene: Scene) -> None: - for anim in self.animations: - anim.update(0) - scene.remove(self.mobject) - scene.remove(self.to_remove) - scene.add(self.to_add) - - @staticmethod - def get_mobject_parts(mobject: Mobject) -> Mobject: - # To be implemented in subclass - return mobject - - @staticmethod - def get_mobject_key(mobject: Mobject) -> int: - # To be implemented in subclass - return hash(mobject) - - -class TransformMatchingShapes(TransformMatchingParts): - mobject_type: type = VMobject - group_type: type = VGroup - - @staticmethod - def get_mobject_parts(mobject: VMobject) -> list[VMobject]: - return mobject.family_members_with_points() - - @staticmethod - def get_mobject_key(mobject: VMobject) -> int: - mobject.save_state() - mobject.center() - mobject.set_height(1) - result = hash(np.round(mobject.get_points(), 3).tobytes()) - mobject.restore() - return result - - -class TransformMatchingStrings(AnimationGroup): - def __init__( - self, - source: StringMobject, - target: StringMobject, - matched_keys: list[str] | None = None, - key_map: dict[str, str] | None = None, + source: Mobject, + target: Mobject, + matched_pairs: Iterable[tuple[Mobject, Mobject]] = [], match_animation: type = Transform, mismatch_animation: type = Transform, - run_time=2, - lag_ratio=0, + run_time: float = 2, + lag_ratio: float = 0, + group_type: type = Group, **kwargs, ): self.source = source self.target = target - matched_keys = matched_keys or list() - key_map = key_map or dict() + self.match_animation = match_animation + self.mismatch_animation = mismatch_animation self.anim_config = dict(**kwargs) # We will progressively build up a list of transforms # from characters in source to those in target. These # two lists keep track of which characters are accounted # for so far - self.source_chars = source.family_members_with_points() - self.target_chars = target.family_members_with_points() + self.source_pieces = source.family_members_with_points() + self.target_pieces = target.family_members_with_points() self.anims = [] - # Start by pairing all matched keys specifically passed in - for key in matched_keys: - self.add_transform( - source.select_parts(key), - target.select_parts(key), - match_animation - ) - # Then pair those based on the key map - for key, value in key_map.items(): - self.add_transform( - source.select_parts(key), - target.select_parts(value), - mismatch_animation - ) - # Now pair by substrings which were isolated in StringMobject - # initializations - specified_substrings = [ - *source.get_specified_substrings(), - *target.get_specified_substrings() - ] - for key in specified_substrings: - self.add_transform( - source.select_parts(key), - target.select_parts(key), - match_animation - ) + for pair in matched_pairs: + self.add_transform(*pair) + # Match any pairs with the same shape - pairs = self.find_pairs_with_matching_shapes(self.source_chars, self.target_chars) - for source_char, target_char in pairs: - self.add_transform(source_char, target_char, match_animation) + for pair in self.find_pairs_with_matching_shapes(self.source_pieces, self.target_pieces): + self.add_transform(*pair) + # Finally, account for mismatches - for source_char in self.source_chars: + for source_char in self.source_pieces: self.anims.append(FadeOutToPoint( source_char, target.get_center(), **self.anim_config )) - for target_char in self.target_chars: + for target_char in self.target_pieces: self.anims.append(FadeInFromPoint( target_char, source.get_center(), **self.anim_config )) + super().__init__( *self.anims, run_time=run_time, lag_ratio=lag_ratio, - group_type=VGroup, + group_type=group_type, ) def add_transform( self, - source: VMobject, - target: VMobject, - transform_type: type = Transform, + source: Mobject, + target: Mobject, ): - new_source_chars = source.family_members_with_points() - new_target_chars = target.family_members_with_points() - source_is_new = all(char in self.source_chars for char in new_source_chars) - target_is_new = all(char in self.target_chars for char in new_target_chars) - if source_is_new and target_is_new: - self.anims.append(transform_type( - source, target, **self.anim_config - )) - for char in new_source_chars: - self.source_chars.remove(char) - for char in new_target_chars: - self.target_chars.remove(char) + new_source_pieces = source.family_members_with_points() + new_target_pieces = target.family_members_with_points() + if len(new_source_pieces) == 0 or len(new_target_pieces) == 0: + # Don't animate null sorces or null targets + return + source_is_new = all(char in self.source_pieces for char in new_source_pieces) + target_is_new = all(char in self.target_pieces for char in new_target_pieces) + if not source_is_new or not target_is_new: + return - def find_pairs_with_matching_shapes(self, chars1, chars2) -> list[tuple[VMobject, VMobject]]: + transform_type = self.mismatch_animation + if source.has_same_shape_as(target): + transform_type = self.match_animation + + self.anims.append(transform_type(source, target, **self.anim_config)) + for char in new_source_pieces: + self.source_pieces.remove(char) + for char in new_target_pieces: + self.target_pieces.remove(char) + + def find_pairs_with_matching_shapes( + self, + chars1: list[Mobject], + chars2: list[Mobject] + ) -> list[tuple[Mobject, Mobject]]: result = [] for char1, char2 in it.product(chars1, chars2): if char1.has_same_shape_as(char2): @@ -242,6 +117,41 @@ class TransformMatchingStrings(AnimationGroup): scene.add(self.target) +class TransformMatchingShapes(TransformMatchingParts): + """Alias for TransformMatchingParts""" + pass + + +class TransformMatchingStrings(TransformMatchingParts): + def __init__( + self, + source: StringMobject, + target: StringMobject, + matched_keys: Iterable[str] = [], + key_map: dict[str, str] = dict(), + **kwargs, + ): + matched_pairs = [ + *[(source[key], target[key]) for key in matched_keys], + *[(source[key1], target[key2]) for key1, key2 in key_map.items()], + *[ + (source[substr], target[substr]) + for substr in [ + *source.get_specified_substrings(), + *target.get_specified_substrings(), + *source.get_symbol_substrings(), + *target.get_symbol_substrings(), + ] + ] + ] + super().__init__( + source, target, + matched_pairs=matched_pairs, + group_type=VGroup, + **kwargs, + ) + + class TransformMatchingTex(TransformMatchingStrings): """Alias for TransformMatchingStrings""" pass diff --git a/manimlib/mobject/svg/string_mobject.py b/manimlib/mobject/svg/string_mobject.py index 3c81ccc1..21ff45fb 100644 --- a/manimlib/mobject/svg/string_mobject.py +++ b/manimlib/mobject/svg/string_mobject.py @@ -561,7 +561,10 @@ class StringMobject(SVGMobject, ABC): return self.select_parts(selector)[index] def substr_to_path_count(self, substr: str) -> int: - return len(re.sub(R"\s", "", substr)) + return len(re.sub(r"\s", "", substr)) + + def get_symbol_substrings(self): + return list(re.sub(r"\s", "", self.string)) def select_unisolated_substring(self, pattern: str | re.Pattern) -> VGroup: if isinstance(pattern, str): diff --git a/manimlib/mobject/svg/tex_mobject.py b/manimlib/mobject/svg/tex_mobject.py index 27e51df3..144ca66e 100644 --- a/manimlib/mobject/svg/tex_mobject.py +++ b/manimlib/mobject/svg/tex_mobject.py @@ -42,7 +42,7 @@ class Tex(StringMobject): isolate = [isolate] isolate = [*isolate, *tex_strings] - tex_string = " ".join(tex_strings) + tex_string = (" ".join(tex_strings)).strip() # Prevent from passing an empty string. if not tex_string.strip(): @@ -222,6 +222,15 @@ class Tex(StringMobject): log.warning(f"Estimated size of {tex} does not match true size") return num_tex_symbols(substr) + def get_symbol_substrings(self): + pattern = "|".join(( + # Tex commands + r"\\[a-zA-Z]+", + # And most single characters, with these exceptions + r"[^\^\{\}\s\_\$\\\&]", + )) + return re.findall(pattern, self.string) + def make_number_changable( self, value: float | int | str,