chore: add type hints to manimlib.animation

This commit is contained in:
TonyCrane 2022-02-15 18:39:45 +08:00
parent d19e0cb9ab
commit 41c4023986
No known key found for this signature in database
GPG key ID: 2313A5058A9C637C
13 changed files with 443 additions and 182 deletions

View file

@ -1,4 +1,7 @@
from __future__ import annotations
from copy import deepcopy
from typing import Callable
from manimlib.mobject.mobject import _AnimationBuilder
from manimlib.mobject.mobject import Mobject
@ -6,6 +9,10 @@ from manimlib.utils.config_ops import digest_config
from manimlib.utils.rate_functions import smooth
from manimlib.utils.simple_functions import clip
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.scene.scene import Scene
DEFAULT_ANIMATION_RUN_TIME = 1.0
DEFAULT_ANIMATION_LAG_RATIO = 0
@ -29,17 +36,17 @@ class Animation(object):
"suspend_mobject_updating": True,
}
def __init__(self, mobject, **kwargs):
def __init__(self, mobject: Mobject, **kwargs):
assert(isinstance(mobject, Mobject))
digest_config(self, kwargs)
self.mobject = mobject
def __str__(self):
def __str__(self) -> str:
if self.name:
return self.name
return self.__class__.__name__ + str(self.mobject)
def begin(self):
def begin(self) -> None:
# This is called right as an animation is being
# played. As much initialization as possible,
# especially any mobject copying, should live in
@ -56,32 +63,32 @@ class Animation(object):
self.families = list(self.get_all_families_zipped())
self.interpolate(0)
def finish(self):
def finish(self) -> None:
self.interpolate(self.final_alpha_value)
if self.suspend_mobject_updating:
self.mobject.resume_updating()
def clean_up_from_scene(self, scene):
def clean_up_from_scene(self, scene: Scene) -> None:
if self.is_remover():
scene.remove(self.mobject)
def create_starting_mobject(self):
def create_starting_mobject(self) -> Mobject:
# Keep track of where the mobject starts
return self.mobject.copy()
def get_all_mobjects(self):
def get_all_mobjects(self) -> tuple[Mobject, Mobject]:
"""
Ordering must match the ording of arguments to interpolate_submobject
"""
return self.mobject, self.starting_mobject
def get_all_families_zipped(self):
def get_all_families_zipped(self) -> zip[tuple[Mobject]]:
return zip(*[
mob.get_family()
for mob in self.get_all_mobjects()
])
def update_mobjects(self, dt):
def update_mobjects(self, dt: float) -> None:
"""
Updates things like starting_mobject, and (for
Transforms) target_mobject. Note, since typically
@ -92,7 +99,7 @@ class Animation(object):
for mob in self.get_all_mobjects_to_update():
mob.update(dt)
def get_all_mobjects_to_update(self):
def get_all_mobjects_to_update(self) -> list[Mobject]:
# The surrounding scene typically handles
# updating of self.mobject. Besides, in
# most cases its updating is suspended anyway
@ -109,27 +116,37 @@ class Animation(object):
return self
# Methods for interpolation, the mean of an Animation
def interpolate(self, alpha):
def interpolate(self, alpha: float) -> None:
alpha = clip(alpha, 0, 1)
self.interpolate_mobject(self.rate_func(alpha))
def update(self, alpha):
def update(self, alpha: float) -> None:
"""
This method shouldn't exist, but it's here to
keep many old scenes from breaking
"""
self.interpolate(alpha)
def interpolate_mobject(self, alpha):
def interpolate_mobject(self, alpha: float) -> None:
for i, mobs in enumerate(self.families):
sub_alpha = self.get_sub_alpha(alpha, i, len(self.families))
self.interpolate_submobject(*mobs, sub_alpha)
def interpolate_submobject(self, submobject, starting_sumobject, alpha):
def interpolate_submobject(
self,
submobject: Mobject,
starting_submobject: Mobject,
alpha: float
):
# Typically ipmlemented by subclass
pass
def get_sub_alpha(self, alpha, index, num_submobjects):
def get_sub_alpha(
self,
alpha: float,
index: int,
num_submobjects: int
) -> float:
# TODO, make this more understanable, and/or combine
# its functionality with AnimationGroup's method
# build_animations_with_timings
@ -140,29 +157,29 @@ class Animation(object):
return clip((value - lower), 0, 1)
# Getters and setters
def set_run_time(self, run_time):
def set_run_time(self, run_time: float):
self.run_time = run_time
return self
def get_run_time(self):
def get_run_time(self) -> float:
return self.run_time
def set_rate_func(self, rate_func):
def set_rate_func(self, rate_func: Callable[[float], float]):
self.rate_func = rate_func
return self
def get_rate_func(self):
def get_rate_func(self) -> Callable[[float], float]:
return self.rate_func
def set_name(self, name):
def set_name(self, name: str):
self.name = name
return self
def is_remover(self):
def is_remover(self) -> bool:
return self.remover
def prepare_animation(anim):
def prepare_animation(anim: Animation | _AnimationBuilder):
if isinstance(anim, _AnimationBuilder):
return anim.build()

View file

@ -1,4 +1,7 @@
from __future__ import annotations
import numpy as np
from typing import Callable
from manimlib.animation.animation import Animation, prepare_animation
from manimlib.mobject.mobject import Group
@ -9,6 +12,11 @@ from manimlib.utils.iterables import remove_list_redundancies
from manimlib.utils.rate_functions import linear
from manimlib.utils.simple_functions import clip
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.scene.scene import Scene
from manimlib.mobject.mobject import Mobject
DEFAULT_LAGGED_START_LAG_RATIO = 0.05
@ -27,7 +35,7 @@ class AnimationGroup(Animation):
"group": None,
}
def __init__(self, *animations, **kwargs):
def __init__(self, *animations: Animation, **kwargs):
digest_config(self, kwargs)
self.animations = [prepare_animation(anim) for anim in animations]
if self.group is None:
@ -37,27 +45,27 @@ class AnimationGroup(Animation):
self.init_run_time()
Animation.__init__(self, self.group, **kwargs)
def get_all_mobjects(self):
def get_all_mobjects(self) -> Group:
return self.group
def begin(self):
def begin(self) -> None:
for anim in self.animations:
anim.begin()
# self.init_run_time()
def finish(self):
def finish(self) -> None:
for anim in self.animations:
anim.finish()
def clean_up_from_scene(self, scene):
def clean_up_from_scene(self, scene: Scene) -> None:
for anim in self.animations:
anim.clean_up_from_scene(scene)
def update_mobjects(self, dt):
def update_mobjects(self, dt: float) -> None:
for anim in self.animations:
anim.update_mobjects(dt)
def init_run_time(self):
def init_run_time(self) -> None:
self.build_animations_with_timings()
if self.anims_with_timings:
self.max_end_time = np.max([
@ -68,7 +76,7 @@ class AnimationGroup(Animation):
if self.run_time is None:
self.run_time = self.max_end_time
def build_animations_with_timings(self):
def build_animations_with_timings(self) -> None:
"""
Creates a list of triplets of the form
(anim, start_time, end_time)
@ -87,7 +95,7 @@ class AnimationGroup(Animation):
start_time, end_time, self.lag_ratio
)
def interpolate(self, alpha):
def interpolate(self, alpha: float) -> None:
# Note, if the run_time of AnimationGroup has been
# set to something other than its default, these
# times might not correspond to actual times,
@ -111,19 +119,19 @@ class Succession(AnimationGroup):
"lag_ratio": 1,
}
def begin(self):
def begin(self) -> None:
assert(len(self.animations) > 0)
self.init_run_time()
self.active_animation = self.animations[0]
self.active_animation.begin()
def finish(self):
def finish(self) -> None:
self.active_animation.finish()
def update_mobjects(self, dt):
def update_mobjects(self, dt: float) -> None:
self.active_animation.update_mobjects(dt)
def interpolate(self, alpha):
def interpolate(self, alpha: float) -> None:
index, subalpha = integer_interpolate(
0, len(self.animations), alpha
)
@ -146,7 +154,13 @@ class LaggedStartMap(LaggedStart):
"run_time": 2,
}
def __init__(self, AnimationClass, mobject, arg_creator=None, **kwargs):
def __init__(
self,
AnimationClass: type,
mobject: Mobject,
arg_creator: Callable[[Mobject], tuple] | None = None,
**kwargs
):
args_list = []
for submob in mobject:
if arg_creator:

View file

@ -1,3 +1,10 @@
from __future__ import annotations
import itertools as it
from abc import abstractmethod
import numpy as np
from manimlib.animation.animation import Animation
from manimlib.animation.composition import Succession
from manimlib.mobject.types.vectorized_mobject import VMobject
@ -7,8 +14,9 @@ from manimlib.utils.rate_functions import linear
from manimlib.utils.rate_functions import double_smooth
from manimlib.utils.rate_functions import smooth
import numpy as np
import itertools as it
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.mobject.mobject import Group
class ShowPartial(Animation):
@ -19,21 +27,27 @@ class ShowPartial(Animation):
"should_match_start": False,
}
def begin(self):
def begin(self) -> None:
super().begin()
if not self.should_match_start:
self.mobject.lock_matching_data(self.mobject, self.starting_mobject)
def finish(self):
def finish(self) -> None:
super().finish()
self.mobject.unlock_data()
def interpolate_submobject(self, submob, start_submob, alpha):
def interpolate_submobject(
self,
submob: VMobject,
start_submob: VMobject,
alpha: float
) -> None:
submob.pointwise_become_partial(
start_submob, *self.get_bounds(alpha)
)
def get_bounds(self, alpha):
@abstractmethod
def get_bounds(self, alpha: float) -> tuple[float, float]:
raise Exception("Not Implemented")
@ -42,7 +56,7 @@ class ShowCreation(ShowPartial):
"lag_ratio": 1,
}
def get_bounds(self, alpha):
def get_bounds(self, alpha: float) -> tuple[float, float]:
return (0, alpha)
@ -64,7 +78,7 @@ class DrawBorderThenFill(Animation):
"fill_animation_config": {},
}
def __init__(self, vmobject, **kwargs):
def __init__(self, vmobject: VMobject, **kwargs):
assert(isinstance(vmobject, VMobject))
self.sm_to_index = dict([
(hash(sm), 0)
@ -72,7 +86,7 @@ class DrawBorderThenFill(Animation):
])
super().__init__(vmobject, **kwargs)
def begin(self):
def begin(self) -> None:
# Trigger triangulation calculation
for submob in self.mobject.get_family():
submob.get_triangulation()
@ -82,11 +96,11 @@ class DrawBorderThenFill(Animation):
self.mobject.match_style(self.outline)
self.mobject.lock_matching_data(self.mobject, self.outline)
def finish(self):
def finish(self) -> None:
super().finish()
self.mobject.unlock_data()
def get_outline(self):
def get_outline(self) -> VMobject:
outline = self.mobject.copy()
outline.set_fill(opacity=0)
for sm in outline.get_family():
@ -96,17 +110,23 @@ class DrawBorderThenFill(Animation):
)
return outline
def get_stroke_color(self, vmobject):
def get_stroke_color(self, vmobject: VMobject) -> str:
if self.stroke_color:
return self.stroke_color
elif vmobject.get_stroke_width() > 0:
return vmobject.get_stroke_color()
return vmobject.get_color()
def get_all_mobjects(self):
def get_all_mobjects(self) -> list[VMobject]:
return [*super().get_all_mobjects(), self.outline]
def interpolate_submobject(self, submob, start, outline, alpha):
def interpolate_submobject(
self,
submob: VMobject,
start: VMobject,
outline: VMobject,
alpha: float
) -> None:
index, subalpha = integer_interpolate(0, 2, alpha)
if index == 1 and self.sm_to_index[hash(submob)] == 0:
@ -133,13 +153,13 @@ class Write(DrawBorderThenFill):
"rate_func": linear,
}
def __init__(self, mobject, **kwargs):
def __init__(self, vmobject: VMobject, **kwargs):
digest_config(self, kwargs)
self.set_default_config_from_length(mobject)
super().__init__(mobject, **kwargs)
self.set_default_config_from_length(vmobject)
super().__init__(vmobject, **kwargs)
def set_default_config_from_length(self, mobject):
length = len(mobject.family_members_with_points())
def set_default_config_from_length(self, vmobject: VMobject) -> None:
length = len(vmobject.family_members_with_points())
if self.run_time is None:
if length < 15:
self.run_time = 1
@ -155,16 +175,16 @@ class ShowIncreasingSubsets(Animation):
"int_func": np.round,
}
def __init__(self, group, **kwargs):
def __init__(self, group: Group, **kwargs):
self.all_submobs = list(group.submobjects)
super().__init__(group, **kwargs)
def interpolate_mobject(self, alpha):
def interpolate_mobject(self, alpha: float) -> None:
n_submobs = len(self.all_submobs)
index = int(self.int_func(alpha * n_submobs))
self.update_submobject_list(index)
def update_submobject_list(self, index):
def update_submobject_list(self, index: int) -> None:
self.mobject.set_submobjects(self.all_submobs[:index])
@ -173,7 +193,7 @@ class ShowSubmobjectsOneByOne(ShowIncreasingSubsets):
"int_func": np.ceil,
}
def update_submobject_list(self, index):
def update_submobject_list(self, index: int) -> None:
# N = len(self.all_submobs)
if index == 0:
self.mobject.set_submobjects([])

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import numpy as np
from manimlib.animation.animation import Animation
@ -7,6 +9,12 @@ from manimlib.constants import ORIGIN
from manimlib.utils.bezier import interpolate
from manimlib.utils.rate_functions import there_and_back
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.scene.scene import Scene
from manimlib.mobject.mobject import Mobject
from manimlib.mobject.types.vectorized_mobject import VMobject
DEFAULT_FADE_LAG_RATIO = 0
@ -16,7 +24,13 @@ class Fade(Transform):
"lag_ratio": DEFAULT_FADE_LAG_RATIO,
}
def __init__(self, mobject, shift=ORIGIN, scale=1, **kwargs):
def __init__(
self,
mobject: Mobject,
shift: np.ndarray = ORIGIN,
scale: float = 1,
**kwargs
):
self.shift_vect = shift
self.scale_factor = scale
super().__init__(mobject, **kwargs)
@ -27,10 +41,10 @@ class FadeIn(Fade):
"lag_ratio": DEFAULT_FADE_LAG_RATIO,
}
def create_target(self):
def create_target(self) -> Mobject:
return self.mobject
def create_starting_mobject(self):
def create_starting_mobject(self) -> Mobject:
start = super().create_starting_mobject()
start.set_opacity(0)
start.scale(1.0 / self.scale_factor)
@ -45,7 +59,7 @@ class FadeOut(Fade):
"final_alpha_value": 0,
}
def create_target(self):
def create_target(self) -> Mobject:
result = self.mobject.copy()
result.set_opacity(0)
result.shift(self.shift_vect)
@ -54,7 +68,7 @@ class FadeOut(Fade):
class FadeInFromPoint(FadeIn):
def __init__(self, mobject, point, **kwargs):
def __init__(self, mobject: Mobject, point: np.ndarray, **kwargs):
super().__init__(
mobject,
shift=mobject.get_center() - point,
@ -64,7 +78,7 @@ class FadeInFromPoint(FadeIn):
class FadeOutToPoint(FadeOut):
def __init__(self, mobject, point, **kwargs):
def __init__(self, mobject: Mobject, point: np.ndarray, **kwargs):
super().__init__(
mobject,
shift=point - mobject.get_center(),
@ -79,7 +93,7 @@ class FadeTransform(Transform):
"dim_to_match": 1,
}
def __init__(self, mobject, target_mobject, **kwargs):
def __init__(self, mobject: Mobject, target_mobject: Mobject, **kwargs):
self.to_add_on_completion = target_mobject
mobject.save_state()
super().__init__(
@ -87,7 +101,7 @@ class FadeTransform(Transform):
**kwargs
)
def begin(self):
def begin(self) -> None:
self.ending_mobject = self.mobject.copy()
Animation.begin(self)
# Both 'start' and 'end' consists of the source and target mobjects.
@ -97,21 +111,21 @@ class FadeTransform(Transform):
for m0, m1 in ((start[1], start[0]), (end[0], end[1])):
self.ghost_to(m0, m1)
def ghost_to(self, source, target):
def ghost_to(self, source: Mobject, target: Mobject) -> None:
source.replace(target, stretch=self.stretch, dim_to_match=self.dim_to_match)
source.set_opacity(0)
def get_all_mobjects(self):
def get_all_mobjects(self) -> list[Mobject]:
return [
self.mobject,
self.starting_mobject,
self.ending_mobject,
]
def get_all_families_zipped(self):
def get_all_families_zipped(self) -> zip[tuple[Mobject]]:
return Animation.get_all_families_zipped(self)
def clean_up_from_scene(self, scene):
def clean_up_from_scene(self, scene: Scene) -> None:
Animation.clean_up_from_scene(self, scene)
scene.remove(self.mobject)
self.mobject[0].restore()
@ -119,11 +133,11 @@ class FadeTransform(Transform):
class FadeTransformPieces(FadeTransform):
def begin(self):
def begin(self) -> None:
self.mobject[0].align_family(self.mobject[1])
super().begin()
def ghost_to(self, source, target):
def ghost_to(self, source: Mobject, target: Mobject) -> None:
for sm0, sm1 in zip(source.get_family(), target.get_family()):
super().ghost_to(sm0, sm1)
@ -136,7 +150,12 @@ class VFadeIn(Animation):
"suspend_mobject_updating": False,
}
def interpolate_submobject(self, submob, start, alpha):
def interpolate_submobject(
self,
submob: VMobject,
start: VMobject,
alpha: float
) -> None:
submob.set_stroke(
opacity=interpolate(0, start.get_stroke_opacity(), alpha)
)
@ -152,7 +171,12 @@ class VFadeOut(VFadeIn):
"final_alpha_value": 0,
}
def interpolate_submobject(self, submob, start, alpha):
def interpolate_submobject(
self,
submob: VMobject,
start: VMobject,
alpha: float
) -> None:
super().interpolate_submobject(submob, start, 1 - alpha)

View file

@ -1,6 +1,13 @@
from manimlib.animation.transform import Transform
# from manimlib.utils.paths import counterclockwise_path
from __future__ import annotations
from manimlib.constants import PI
from manimlib.animation.transform import Transform
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import numpy as np
from manimlib.mobject.mobject import Mobject
from manimlib.mobject.geometry import Arrow
class GrowFromPoint(Transform):
@ -8,14 +15,14 @@ class GrowFromPoint(Transform):
"point_color": None,
}
def __init__(self, mobject, point, **kwargs):
def __init__(self, mobject: Mobject, point: np.ndarray, **kwargs):
self.point = point
super().__init__(mobject, **kwargs)
def create_target(self):
def create_target(self) -> Mobject:
return self.mobject
def create_starting_mobject(self):
def create_starting_mobject(self) -> Mobject:
start = super().create_starting_mobject()
start.scale(0)
start.move_to(self.point)
@ -25,19 +32,19 @@ class GrowFromPoint(Transform):
class GrowFromCenter(GrowFromPoint):
def __init__(self, mobject, **kwargs):
def __init__(self, mobject: Mobject, **kwargs):
point = mobject.get_center()
super().__init__(mobject, point, **kwargs)
class GrowFromEdge(GrowFromPoint):
def __init__(self, mobject, edge, **kwargs):
def __init__(self, mobject: Mobject, edge: np.ndarray, **kwargs):
point = mobject.get_bounding_box_point(edge)
super().__init__(mobject, point, **kwargs)
class GrowArrow(GrowFromPoint):
def __init__(self, arrow, **kwargs):
def __init__(self, arrow: Arrow, **kwargs):
point = arrow.get_start()
super().__init__(arrow, point, **kwargs)

View file

@ -1,5 +1,9 @@
import numpy as np
from __future__ import annotations
import math
from typing import Union, Sequence
import numpy as np
from manimlib.constants import *
from manimlib.animation.animation import Animation
@ -10,7 +14,7 @@ from manimlib.animation.creation import ShowCreation
from manimlib.animation.creation import ShowPartial
from manimlib.animation.fading import FadeOut
from manimlib.animation.fading import FadeIn
from manimlib.animation.transform import Transform
from manimlib.animation.transform import ManimColor, Transform
from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.mobject.geometry import Circle
from manimlib.mobject.geometry import Dot
@ -25,6 +29,12 @@ from manimlib.utils.rate_functions import wiggle
from manimlib.utils.rate_functions import smooth
from manimlib.utils.rate_functions import squish_rate_func
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import colour
from manimlib.mobject.mobject import Mobject
ManimColor = Union[str, colour.Color, Sequence[float]]
class FocusOn(Transform):
CONFIG = {
@ -34,13 +44,13 @@ class FocusOn(Transform):
"remover": True,
}
def __init__(self, focus_point, **kwargs):
def __init__(self, focus_point: np.ndarray, **kwargs):
self.focus_point = focus_point
# Initialize with blank mobject, while create_target
# and create_starting_mobject handle the meat
super().__init__(VMobject(), **kwargs)
def create_target(self):
def create_target(self) -> Dot:
little_dot = Dot(radius=0)
little_dot.set_fill(self.color, opacity=self.opacity)
little_dot.add_updater(
@ -48,7 +58,7 @@ class FocusOn(Transform):
)
return little_dot
def create_starting_mobject(self):
def create_starting_mobject(self) -> Dot:
return Dot(
radius=FRAME_X_RADIUS + FRAME_Y_RADIUS,
stroke_width=0,
@ -64,7 +74,7 @@ class Indicate(Transform):
"color": YELLOW,
}
def create_target(self):
def create_target(self) -> Mobject:
target = self.mobject.copy()
target.scale(self.scale_factor)
target.set_color(self.color)
@ -80,7 +90,12 @@ class Flash(AnimationGroup):
"run_time": 1,
}
def __init__(self, point, color=YELLOW, **kwargs):
def __init__(
self,
point: np.ndarray,
color: ManimColor = YELLOW,
**kwargs
):
self.point = point
self.color = color
digest_config(self, kwargs)
@ -92,7 +107,7 @@ class Flash(AnimationGroup):
**kwargs,
)
def create_lines(self):
def create_lines(self) -> VGroup:
lines = VGroup()
for angle in np.arange(0, TAU, TAU / self.num_lines):
line = Line(ORIGIN, self.line_length * RIGHT)
@ -106,7 +121,7 @@ class Flash(AnimationGroup):
lines.add_updater(lambda l: l.move_to(self.point))
return lines
def create_line_anims(self):
def create_line_anims(self) -> list[Animation]:
return [
ShowCreationThenDestruction(line)
for line in self.lines
@ -122,17 +137,17 @@ class CircleIndicate(Indicate):
},
}
def __init__(self, mobject, **kwargs):
def __init__(self, mobject: Mobject, **kwargs):
digest_config(self, kwargs)
circle = self.get_circle(mobject)
super().__init__(circle, **kwargs)
def get_circle(self, mobject):
def get_circle(self, mobject: Mobject) -> Circle:
circle = Circle(**self.circle_config)
circle.add_updater(lambda c: c.surround(mobject))
return circle
def interpolate_mobject(self, alpha):
def interpolate_mobject(self, alpha: float) -> None:
super().interpolate_mobject(alpha)
self.mobject.set_stroke(opacity=alpha)
@ -143,7 +158,7 @@ class ShowPassingFlash(ShowPartial):
"remover": True,
}
def get_bounds(self, alpha):
def get_bounds(self, alpha: float) -> tuple[float, float]:
tw = self.time_width
upper = interpolate(0, 1 + tw, alpha)
lower = upper - tw
@ -151,7 +166,7 @@ class ShowPassingFlash(ShowPartial):
lower = max(lower, 0)
return (lower, upper)
def finish(self):
def finish(self) -> None:
super().finish()
for submob, start in self.get_all_families_zipped():
submob.pointwise_become_partial(start, 0, 1)
@ -164,7 +179,7 @@ class VShowPassingFlash(Animation):
"remover": True,
}
def begin(self):
def begin(self) -> None:
self.mobject.align_stroke_width_data_to_points()
# Compute an array of stroke widths for each submobject
# which tapers out at either end
@ -184,7 +199,12 @@ class VShowPassingFlash(Animation):
self.submob_to_anchor_widths[hash(sm)] = anchor_widths * taper_array
super().begin()
def interpolate_submobject(self, submobject, starting_sumobject, alpha):
def interpolate_submobject(
self,
submobject: VMobject,
starting_sumobject: None,
alpha: float
) -> None:
anchor_widths = self.submob_to_anchor_widths[hash(submobject)]
# Create a gaussian such that 3 sigmas out on either side
# will equals time_width
@ -206,7 +226,7 @@ class VShowPassingFlash(Animation):
new_widths[1::3] = (new_widths[0::3] + new_widths[2::3]) / 2
submobject.set_stroke(width=new_widths)
def finish(self):
def finish(self) -> None:
super().finish()
for submob, start in self.get_all_families_zipped():
submob.match_style(start)
@ -221,7 +241,7 @@ class FlashAround(VShowPassingFlash):
"n_inserted_curves": 20,
}
def __init__(self, mobject, **kwargs):
def __init__(self, mobject: Mobject, **kwargs):
digest_config(self, kwargs)
path = self.get_path(mobject)
if mobject.is_fixed_in_frame:
@ -231,12 +251,12 @@ class FlashAround(VShowPassingFlash):
path.set_stroke(self.color, self.stroke_width)
super().__init__(path, **kwargs)
def get_path(self, mobject):
def get_path(self, mobject: Mobject) -> SurroundingRectangle:
return SurroundingRectangle(mobject, buff=self.buff)
class FlashUnder(FlashAround):
def get_path(self, mobject):
def get_path(self, mobject: Mobject) -> Underline:
return Underline(mobject, buff=self.buff)
@ -252,7 +272,7 @@ class ShowCreationThenFadeOut(Succession):
"remover": True,
}
def __init__(self, mobject, **kwargs):
def __init__(self, mobject: Mobject, **kwargs):
super().__init__(
ShowCreation(mobject),
FadeOut(mobject),
@ -269,7 +289,7 @@ class AnimationOnSurroundingRectangle(AnimationGroup):
"rect_animation": Animation
}
def __init__(self, mobject, **kwargs):
def __init__(self, mobject: Mobject, **kwargs):
digest_config(self, kwargs)
if "surrounding_rectangle_config" in kwargs:
kwargs.pop("surrounding_rectangle_config")
@ -282,7 +302,7 @@ class AnimationOnSurroundingRectangle(AnimationGroup):
self.rect_animation(rect, **kwargs),
)
def get_rect(self):
def get_rect(self) -> SurroundingRectangle:
return SurroundingRectangle(
self.mobject_to_surround,
**self.surrounding_rectangle_config
@ -314,7 +334,7 @@ class ApplyWave(Homotopy):
"run_time": 1,
}
def __init__(self, mobject, **kwargs):
def __init__(self, mobject: Mobject, **kwargs):
digest_config(self, kwargs, locals())
left_x = mobject.get_left()[0]
right_x = mobject.get_right()[0]
@ -339,15 +359,20 @@ class WiggleOutThenIn(Animation):
"rotate_about_point": None,
}
def get_scale_about_point(self):
def get_scale_about_point(self) -> np.ndarray:
if self.scale_about_point is None:
return self.mobject.get_center()
def get_rotate_about_point(self):
def get_rotate_about_point(self) -> np.ndarray:
if self.rotate_about_point is None:
return self.mobject.get_center()
def interpolate_submobject(self, submobject, starting_sumobject, alpha):
def interpolate_submobject(
self,
submobject: Mobject,
starting_sumobject: Mobject,
alpha: float
) -> None:
submobject.match_points(starting_sumobject)
submobject.scale(
interpolate(1, self.scale_value, there_and_back(alpha)),
@ -364,7 +389,7 @@ class TurnInsideOut(Transform):
"path_arc": TAU / 4,
}
def create_target(self):
def create_target(self) -> Mobject:
return self.mobject.copy().reverse_points()
@ -373,7 +398,7 @@ class FlashyFadeIn(AnimationGroup):
"fade_lag": 0,
}
def __init__(self, vmobject, stroke_width=2, **kwargs):
def __init__(self, vmobject: VMobject, stroke_width: float = 2, **kwargs):
digest_config(self, kwargs)
outline = vmobject.copy()
outline.set_fill(opacity=0)

View file

@ -1,6 +1,15 @@
from __future__ import annotations
from typing import Callable, Sequence
from manimlib.animation.animation import Animation
from manimlib.utils.rate_functions import linear
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import numpy as np
from manimlib.mobject.mobject import Mobject
class Homotopy(Animation):
CONFIG = {
@ -8,7 +17,12 @@ class Homotopy(Animation):
"apply_function_kwargs": {},
}
def __init__(self, homotopy, mobject, **kwargs):
def __init__(
self,
homotopy: Callable[[float, float, float, float], Sequence[float]],
mobject: Mobject,
**kwargs
):
"""
Homotopy is a function from
(x, y, z, t) to (x', y', z')
@ -16,10 +30,18 @@ class Homotopy(Animation):
self.homotopy = homotopy
super().__init__(mobject, **kwargs)
def function_at_time_t(self, t):
def function_at_time_t(
self,
t: float
) -> Callable[[np.ndarray], Sequence[float]]:
return lambda p: self.homotopy(*p, t)
def interpolate_submobject(self, submob, start, alpha):
def interpolate_submobject(
self,
submob: Mobject,
start: Mobject,
alpha: float
) -> None:
submob.match_points(start)
submob.apply_function(
self.function_at_time_t(alpha),
@ -34,7 +56,12 @@ class SmoothedVectorizedHomotopy(Homotopy):
class ComplexHomotopy(Homotopy):
def __init__(self, complex_homotopy, mobject, **kwargs):
def __init__(
self,
complex_homotopy: Callable[[complex, float], Sequence[float]],
mobject: Mobject,
**kwargs
):
"""
Given a function form (z, t) -> w, where z and w
are complex numbers and t is time, this animates
@ -53,11 +80,16 @@ class PhaseFlow(Animation):
"suspend_mobject_updating": False,
}
def __init__(self, function, mobject, **kwargs):
def __init__(
self,
function: Callable[[np.ndarray], np.ndarray],
mobject: Mobject,
**kwargs
):
self.function = function
super().__init__(mobject, **kwargs)
def interpolate_mobject(self, alpha):
def interpolate_mobject(self, alpha: float) -> None:
if hasattr(self, "last_alpha"):
dt = self.virtual_time * (alpha - self.last_alpha)
self.mobject.apply_function(
@ -71,10 +103,10 @@ class MoveAlongPath(Animation):
"suspend_mobject_updating": False,
}
def __init__(self, mobject, path, **kwargs):
def __init__(self, mobject: Mobject, path: Mobject, **kwargs):
self.path = path
super().__init__(mobject, **kwargs)
def interpolate_mobject(self, alpha):
def interpolate_mobject(self, alpha: float) -> None:
point = self.path.point_from_proportion(alpha)
self.mobject.move_to(point)

View file

@ -1,3 +1,7 @@
from __future__ import annotations
from typing import Callable
from manimlib.animation.animation import Animation
from manimlib.mobject.numbers import DecimalNumber
from manimlib.utils.bezier import interpolate
@ -8,19 +12,29 @@ class ChangingDecimal(Animation):
"suspend_mobject_updating": False,
}
def __init__(self, decimal_mob, number_update_func, **kwargs):
def __init__(
self,
decimal_mob: DecimalNumber,
number_update_func: Callable[[float], float],
**kwargs
):
assert(isinstance(decimal_mob, DecimalNumber))
self.number_update_func = number_update_func
super().__init__(decimal_mob, **kwargs)
def interpolate_mobject(self, alpha):
def interpolate_mobject(self, alpha: float) -> None:
self.mobject.set_value(
self.number_update_func(alpha)
)
class ChangeDecimalToValue(ChangingDecimal):
def __init__(self, decimal_mob, target_number, **kwargs):
def __init__(
self,
decimal_mob: DecimalNumber,
target_number: float | complex,
**kwargs
):
start_number = decimal_mob.number
super().__init__(
decimal_mob,
@ -30,7 +44,12 @@ class ChangeDecimalToValue(ChangingDecimal):
class CountInFrom(ChangingDecimal):
def __init__(self, decimal_mob, source_number=0, **kwargs):
def __init__(
self,
decimal_mob: DecimalNumber,
source_number: float | complex = 0,
**kwargs
):
start_number = decimal_mob.number
super().__init__(
decimal_mob,

View file

@ -1,3 +1,5 @@
from __future__ import annotations
from manimlib.animation.animation import Animation
from manimlib.constants import OUT
from manimlib.constants import PI
@ -6,6 +8,11 @@ from manimlib.constants import ORIGIN
from manimlib.utils.rate_functions import linear
from manimlib.utils.rate_functions import smooth
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import numpy as np
from manimlib.mobject.mobject import Mobject
class Rotating(Animation):
CONFIG = {
@ -18,12 +25,18 @@ class Rotating(Animation):
"suspend_mobject_updating": False,
}
def __init__(self, mobject, angle=TAU, axis=OUT, **kwargs):
def __init__(
self,
mobject: Mobject,
angle: float = TAU,
axis: np.ndarray = OUT,
**kwargs
):
self.angle = angle
self.axis = axis
super().__init__(mobject, **kwargs)
def interpolate_mobject(self, alpha):
def interpolate_mobject(self, alpha: float) -> None:
for sm1, sm2 in self.get_all_families_zipped():
sm1.set_points(sm2.get_points())
self.mobject.rotate(
@ -41,5 +54,11 @@ class Rotate(Rotating):
"about_edge": ORIGIN,
}
def __init__(self, mobject, angle=PI, axis=OUT, **kwargs):
def __init__(
self,
mobject: Mobject,
angle: float = PI,
axis: np.ndarray = OUT,
**kwargs
):
super().__init__(mobject, angle, axis, **kwargs)

View file

@ -1,3 +1,7 @@
from __future__ import annotations
import numpy as np
from manimlib.animation.composition import LaggedStart
from manimlib.animation.transform import Restore
from manimlib.constants import WHITE
@ -19,7 +23,7 @@ class Broadcast(LaggedStart):
"run_time": 3,
}
def __init__(self, focal_point, **kwargs):
def __init__(self, focal_point: np.ndarray, **kwargs):
digest_config(self, kwargs)
circles = VGroup()
for x in range(self.n_circles):

View file

@ -1,6 +1,10 @@
from __future__ import annotations
import inspect
from typing import Callable, Union, Sequence
import numpy as np
import numpy.typing as npt
from manimlib.animation.animation import Animation
from manimlib.constants import DEFAULT_POINTWISE_FUNCTION_RUN_TIME
@ -14,6 +18,12 @@ from manimlib.utils.paths import straight_path
from manimlib.utils.rate_functions import smooth
from manimlib.utils.rate_functions import squish_rate_func
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import colour
from manimlib.scene.scene import Scene
ManimColor = Union[str, colour.Color, Sequence[float]]
class Transform(Animation):
CONFIG = {
@ -23,12 +33,17 @@ class Transform(Animation):
"replace_mobject_with_target_in_scene": False,
}
def __init__(self, mobject, target_mobject=None, **kwargs):
def __init__(
self,
mobject: Mobject,
target_mobject: Mobject | None = None,
**kwargs
):
super().__init__(mobject, **kwargs)
self.target_mobject = target_mobject
self.init_path_func()
def init_path_func(self):
def init_path_func(self) -> None:
if self.path_func is not None:
return
elif self.path_arc == 0:
@ -39,7 +54,7 @@ class Transform(Animation):
self.path_arc_axis,
)
def begin(self):
def begin(self) -> None:
self.target_mobject = self.create_target()
self.check_target_mobject_validity()
# Use a copy of target_mobject for the align_data_and_family
@ -54,28 +69,28 @@ class Transform(Animation):
self.target_copy,
)
def finish(self):
def finish(self) -> None:
super().finish()
self.mobject.unlock_data()
def create_target(self):
def create_target(self) -> Mobject:
# Has no meaningful effect here, but may be useful
# in subclasses
return self.target_mobject
def check_target_mobject_validity(self):
def check_target_mobject_validity(self) -> None:
if self.target_mobject is None:
raise Exception(
f"{self.__class__.__name__}.create_target not properly implemented"
)
def clean_up_from_scene(self, scene):
def clean_up_from_scene(self, scene: Scene) -> None:
super().clean_up_from_scene(scene)
if self.replace_mobject_with_target_in_scene:
scene.remove(self.mobject)
scene.add(self.target_mobject)
def update_config(self, **kwargs):
def update_config(self, **kwargs) -> None:
Animation.update_config(self, **kwargs)
if "path_arc" in kwargs:
self.path_func = path_along_arc(
@ -83,7 +98,7 @@ class Transform(Animation):
kwargs.get("path_arc_axis", OUT)
)
def get_all_mobjects(self):
def get_all_mobjects(self) -> list[Mobject]:
return [
self.mobject,
self.starting_mobject,
@ -91,7 +106,7 @@ class Transform(Animation):
self.target_copy,
]
def get_all_families_zipped(self):
def get_all_families_zipped(self) -> zip[tuple[Mobject]]:
return zip(*[
mob.get_family()
for mob in [
@ -101,7 +116,13 @@ class Transform(Animation):
]
])
def interpolate_submobject(self, submob, start, target_copy, alpha):
def interpolate_submobject(
self,
submob: Mobject,
start: Mobject,
target_copy: Mobject,
alpha: float
):
submob.interpolate(start, target_copy, alpha, self.path_func)
return self
@ -117,10 +138,10 @@ class TransformFromCopy(Transform):
Performs a reversed Transform
"""
def __init__(self, mobject, target_mobject, **kwargs):
def __init__(self, mobject: Mobject, target_mobject: Mobject, **kwargs):
super().__init__(target_mobject, mobject, **kwargs)
def interpolate(self, alpha):
def interpolate(self, alpha: float) -> None:
super().interpolate(1 - alpha)
@ -137,11 +158,11 @@ class CounterclockwiseTransform(Transform):
class MoveToTarget(Transform):
def __init__(self, mobject, **kwargs):
def __init__(self, mobject: Mobject, **kwargs):
self.check_validity_of_input(mobject)
super().__init__(mobject, mobject.target, **kwargs)
def check_validity_of_input(self, mobject):
def check_validity_of_input(self, mobject: Mobject) -> None:
if not hasattr(mobject, "target"):
raise Exception(
"MoveToTarget called on mobject"
@ -150,13 +171,13 @@ class MoveToTarget(Transform):
class _MethodAnimation(MoveToTarget):
def __init__(self, mobject, methods):
def __init__(self, mobject: Mobject, methods: Callable):
self.methods = methods
super().__init__(mobject)
class ApplyMethod(Transform):
def __init__(self, method, *args, **kwargs):
def __init__(self, method: Callable, *args, **kwargs):
"""
method is a method of Mobject, *args are arguments for
that method. Key word arguments should be passed in
@ -170,7 +191,7 @@ class ApplyMethod(Transform):
self.method_args = args
super().__init__(method.__self__, **kwargs)
def check_validity_of_input(self, method):
def check_validity_of_input(self, method: Callable) -> None:
if not inspect.ismethod(method):
raise Exception(
"Whoops, looks like you accidentally invoked "
@ -178,7 +199,7 @@ class ApplyMethod(Transform):
)
assert(isinstance(method.__self__, Mobject))
def create_target(self):
def create_target(self) -> Mobject:
method = self.method
# Make sure it's a list so that args.pop() works
args = list(self.method_args)
@ -197,16 +218,26 @@ class ApplyPointwiseFunction(ApplyMethod):
"run_time": DEFAULT_POINTWISE_FUNCTION_RUN_TIME
}
def __init__(self, function, mobject, **kwargs):
def __init__(
self,
function: Callable[[np.ndarray], np.ndarray],
mobject: Mobject,
**kwargs
):
super().__init__(mobject.apply_function, function, **kwargs)
class ApplyPointwiseFunctionToCenter(ApplyPointwiseFunction):
def __init__(self, function, mobject, **kwargs):
def __init__(
self,
function: Callable[[np.ndarray], np.ndarray],
mobject: Mobject,
**kwargs
):
self.function = function
super().__init__(mobject.move_to, **kwargs)
def begin(self):
def begin(self) -> None:
self.method_args = [
self.function(self.mobject.get_center())
]
@ -214,31 +245,46 @@ class ApplyPointwiseFunctionToCenter(ApplyPointwiseFunction):
class FadeToColor(ApplyMethod):
def __init__(self, mobject, color, **kwargs):
def __init__(
self,
mobject: Mobject,
color: ManimColor,
**kwargs
):
super().__init__(mobject.set_color, color, **kwargs)
class ScaleInPlace(ApplyMethod):
def __init__(self, mobject, scale_factor, **kwargs):
def __init__(
self,
mobject: Mobject,
scale_factor: npt.ArrayLike,
**kwargs
):
super().__init__(mobject.scale, scale_factor, **kwargs)
class ShrinkToCenter(ScaleInPlace):
def __init__(self, mobject, **kwargs):
def __init__(self, mobject: Mobject, **kwargs):
super().__init__(mobject, 0, **kwargs)
class Restore(ApplyMethod):
def __init__(self, mobject, **kwargs):
def __init__(self, mobject: Mobject, **kwargs):
super().__init__(mobject.restore, **kwargs)
class ApplyFunction(Transform):
def __init__(self, function, mobject, **kwargs):
def __init__(
self,
function: Callable[[Mobject], Mobject],
mobject: Mobject,
**kwargs
):
self.function = function
super().__init__(mobject, **kwargs)
def create_target(self):
def create_target(self) -> Mobject:
target = self.function(self.mobject.copy())
if not isinstance(target, Mobject):
raise Exception("Functions passed to ApplyFunction must return object of type Mobject")
@ -246,7 +292,12 @@ class ApplyFunction(Transform):
class ApplyMatrix(ApplyPointwiseFunction):
def __init__(self, matrix, mobject, **kwargs):
def __init__(
self,
matrix: npt.ArrayLike,
mobject: Mobject,
**kwargs
):
matrix = self.initialize_matrix(matrix)
def func(p):
@ -254,7 +305,7 @@ class ApplyMatrix(ApplyPointwiseFunction):
super().__init__(func, mobject, **kwargs)
def initialize_matrix(self, matrix):
def initialize_matrix(self, matrix: npt.ArrayLike) -> np.ndarray:
matrix = np.array(matrix)
if matrix.shape == (2, 2):
new_matrix = np.identity(3)
@ -266,12 +317,17 @@ class ApplyMatrix(ApplyPointwiseFunction):
class ApplyComplexFunction(ApplyMethod):
def __init__(self, function, mobject, **kwargs):
def __init__(
self,
function: Callable[[complex], complex],
mobject: Mobject,
**kwargs
):
self.function = function
method = mobject.apply_complex_function
super().__init__(method, function, **kwargs)
def init_path_func(self):
def init_path_func(self) -> None:
func1 = self.function(complex(1))
self.path_arc = np.log(func1).imag
super().init_path_func()
@ -284,11 +340,11 @@ class CyclicReplace(Transform):
"path_arc": 90 * DEGREES,
}
def __init__(self, *mobjects, **kwargs):
def __init__(self, *mobjects: Mobject, **kwargs):
self.group = Group(*mobjects)
super().__init__(self.group, **kwargs)
def create_target(self):
def create_target(self) -> Mobject:
target = self.group.copy()
cycled_targets = [target[-1], *target[:-1]]
for m1, m2 in zip(cycled_targets, self.group):
@ -306,7 +362,7 @@ class TransformAnimations(Transform):
"rate_func": squish_rate_func(smooth)
}
def __init__(self, start_anim, end_anim, **kwargs):
def __init__(self, start_anim: Animation, end_anim: Animation, **kwargs):
digest_config(self, kwargs, locals())
if "run_time" in kwargs:
self.run_time = kwargs.pop("run_time")
@ -327,7 +383,7 @@ class TransformAnimations(Transform):
start_anim.mobject = self.starting_mobject
end_anim.mobject = self.target_mobject
def interpolate(self, alpha):
def interpolate(self, alpha: float) -> None:
self.start_anim.interpolate(alpha)
self.end_anim.interpolate(alpha)
Transform.interpolate(self, alpha)

View file

@ -1,13 +1,15 @@
import numpy as np
from __future__ import annotations
import itertools as it
import numpy as np
from manimlib.animation.composition import AnimationGroup
from manimlib.animation.fading import FadeTransformPieces
from manimlib.animation.fading import FadeInFromPoint
from manimlib.animation.fading import FadeOutToPoint
from manimlib.animation.transform import ReplacementTransform
from manimlib.animation.transform import Transform
from manimlib.mobject.mobject import Mobject
from manimlib.mobject.mobject import Group
from manimlib.mobject.svg.mtex_mobject import MTex
@ -16,6 +18,11 @@ from manimlib.mobject.types.vectorized_mobject import VMobject
from manimlib.utils.config_ops import digest_config
from manimlib.utils.iterables import remove_list_redundancies
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.scene.scene import Scene
from manimlib.mobject.svg.tex_mobject import Tex, SingleStringTex
class TransformMatchingParts(AnimationGroup):
CONFIG = {
@ -26,7 +33,7 @@ class TransformMatchingParts(AnimationGroup):
"key_map": dict(),
}
def __init__(self, mobject, target_mobject, **kwargs):
def __init__(self, mobject: Mobject, target_mobject: Mobject, **kwargs):
digest_config(self, kwargs)
assert(isinstance(mobject, self.mobject_type))
assert(isinstance(target_mobject, self.mobject_type))
@ -83,8 +90,8 @@ class TransformMatchingParts(AnimationGroup):
self.to_remove = mobject
self.to_add = target_mobject
def get_shape_map(self, mobject):
shape_map = {}
def get_shape_map(self, mobject: Mobject) -> dict[int, VGroup]:
shape_map: dict[int, VGroup] = {}
for sm in self.get_mobject_parts(mobject):
key = self.get_mobject_key(sm)
if key not in shape_map:
@ -92,7 +99,7 @@ class TransformMatchingParts(AnimationGroup):
shape_map[key].add(sm)
return shape_map
def clean_up_from_scene(self, scene):
def clean_up_from_scene(self, scene: Scene) -> None:
for anim in self.animations:
anim.update(0)
scene.remove(self.mobject)
@ -100,12 +107,12 @@ class TransformMatchingParts(AnimationGroup):
scene.add(self.to_add)
@staticmethod
def get_mobject_parts(mobject):
def get_mobject_parts(mobject: Mobject) -> Mobject:
# To be implemented in subclass
return mobject
@staticmethod
def get_mobject_key(mobject):
def get_mobject_key(mobject: Mobject) -> int:
# To be implemented in subclass
return hash(mobject)
@ -117,11 +124,11 @@ class TransformMatchingShapes(TransformMatchingParts):
}
@staticmethod
def get_mobject_parts(mobject):
def get_mobject_parts(mobject: VMobject) -> list[VMobject]:
return mobject.family_members_with_points()
@staticmethod
def get_mobject_key(mobject):
def get_mobject_key(mobject: VMobject) -> int:
mobject.save_state()
mobject.center()
mobject.set_height(1)
@ -137,11 +144,11 @@ class TransformMatchingTex(TransformMatchingParts):
}
@staticmethod
def get_mobject_parts(mobject):
def get_mobject_parts(mobject: Tex) -> list[SingleStringTex]:
return mobject.submobjects
@staticmethod
def get_mobject_key(mobject):
def get_mobject_key(mobject: Tex) -> str:
return mobject.get_tex()
@ -150,7 +157,7 @@ class TransformMatchingMTex(AnimationGroup):
"key_map": dict(),
}
def __init__(self, source_mobject, target_mobject, **kwargs):
def __init__(self, source_mobject: MTex, target_mobject: MTex, **kwargs):
digest_config(self, kwargs)
assert isinstance(source_mobject, MTex)
assert isinstance(target_mobject, MTex)

View file

@ -1,7 +1,14 @@
from __future__ import annotations
import operator as op
from typing import Callable
from manimlib.animation.animation import Animation
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from manimlib.mobject.mobject import Mobject
class UpdateFromFunc(Animation):
"""
@ -13,21 +20,31 @@ class UpdateFromFunc(Animation):
"suspend_mobject_updating": False,
}
def __init__(self, mobject, update_function, **kwargs):
def __init__(
self,
mobject: Mobject,
update_function: Callable[[Mobject]],
**kwargs
):
self.update_function = update_function
super().__init__(mobject, **kwargs)
def interpolate_mobject(self, alpha):
def interpolate_mobject(self, alpha: float) -> None:
self.update_function(self.mobject)
class UpdateFromAlphaFunc(UpdateFromFunc):
def interpolate_mobject(self, alpha):
def interpolate_mobject(self, alpha: float) -> None:
self.update_function(self.mobject, alpha)
class MaintainPositionRelativeTo(Animation):
def __init__(self, mobject, tracked_mobject, **kwargs):
def __init__(
self,
mobject: Mobject,
tracked_mobject: Mobject,
**kwargs
):
self.tracked_mobject = tracked_mobject
self.diff = op.sub(
mobject.get_center(),
@ -35,7 +52,7 @@ class MaintainPositionRelativeTo(Animation):
)
super().__init__(mobject, **kwargs)
def interpolate_mobject(self, alpha):
def interpolate_mobject(self, alpha: float) -> None:
target = self.tracked_mobject.get_center()
location = self.mobject.get_center()
self.mobject.shift(target - location + self.diff)