diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 9d15fbad..13acbdad 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -40,6 +40,7 @@ from manimlib.utils.bezier import integer_interpolate from manimlib.utils.bezier import interpolate from manimlib.utils.paths import straight_path from manimlib.utils.simple_functions import get_parameters +from manimlib.utils.simple_functions import get_num_args from manimlib.utils.shaders import get_colormap_code from manimlib.utils.space_ops import angle_of_vector from manimlib.utils.space_ops import get_norm @@ -51,7 +52,7 @@ SubmobjectType = TypeVar('SubmobjectType', bound='Mobject') if TYPE_CHECKING: - from typing import Callable, Iterator, Union, Tuple, Optional + from typing import Callable, Iterator, Union, Tuple, Optional, Any import numpy.typing as npt from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, UniformDict, Self from moderngl.context import Context @@ -831,9 +832,9 @@ class Mobject(object): # Updating def init_updaters(self): - self.time_based_updaters: list[TimeBasedUpdater] = [] - self.non_time_updaters: list[NonTimeUpdater] = [] - self.has_updaters: bool = False + self.time_based_updaters: list[TimeBasedUpdater] = list() + self.non_time_updaters: list[NonTimeUpdater] = list() + self._has_updaters_in_family: Optional[bool] = False self.updating_suspended: bool = False def update(self, dt: float = 0, recurse: bool = True) -> Self: @@ -848,36 +849,23 @@ class Mobject(object): updater(self) return self - def get_time_based_updaters(self) -> list[TimeBasedUpdater]: - return self.time_based_updaters - - def has_time_based_updater(self) -> bool: - return len(self.time_based_updaters) > 0 - def get_updaters(self) -> list[Updater]: - return self.time_based_updaters + self.non_time_updaters + return [*self.time_based_updaters, *self.non_time_updaters] - def get_family_updaters(self) -> list[Updater]: - return list(it.chain(*[sm.get_updaters() for sm in self.get_family()])) - - def add_updater( - self, - update_func: Updater, - call_updater: bool = True - ) -> Self: - if "dt" in get_parameters(update_func): + def add_updater(self, update_func: Updater, call: bool = True) -> Self: + if get_num_args(update_func) > 1: self.time_based_updaters.append(update_func) else: self.non_time_updaters.append(update_func) - if call_updater: + if call: self.update(dt=0) self.refresh_has_updater_status() return self def insert_updater(self, update_func: Updater, index=0): - if "dt" in get_parameters(update_func): + if get_num_args(update_func) > 1: self.time_based_updaters.insert(index, update_func) else: self.non_time_updaters.insert(index, update_func) diff --git a/manimlib/scene/scene.py b/manimlib/scene/scene.py index c15682fb..a80d1866 100644 --- a/manimlib/scene/scene.py +++ b/manimlib/scene/scene.py @@ -345,17 +345,9 @@ class Scene(object): mobject.update(dt) def should_update_mobjects(self) -> bool: - return self.always_update_mobjects or any([ - len(mob.get_family_updaters()) > 0 - for mob in self.mobjects - ]) - - def has_time_based_updaters(self) -> bool: - return any([ - sm.has_time_based_updater() - for mob in self.mobjects() - for sm in mob.get_family() - ]) + return self.always_update_mobjects or any( + mob.has_updaters() for mob in self.mobjects + ) # Related to time diff --git a/manimlib/utils/simple_functions.py b/manimlib/utils/simple_functions.py index 33997857..7c7fa211 100644 --- a/manimlib/utils/simple_functions.py +++ b/manimlib/utils/simple_functions.py @@ -9,7 +9,7 @@ import numpy as np from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Callable, TypeVar + from typing import Callable, TypeVar, Iterable from manimlib.typing import FloatArray Scalable = TypeVar("Scalable", float, FloatArray) @@ -30,11 +30,11 @@ def gen_choose(n: int, r: int) -> int: def get_num_args(function: Callable) -> int: - return len(get_parameters(function)) + return len(list(get_parameters(function))) -def get_parameters(function: Callable) -> list: - return list(inspect.signature(function).parameters.keys()) +def get_parameters(function: Callable) -> Iterable[str]: + return inspect.signature(function).parameters.keys() # Just to have a less heavyweight name for this extremely common operation #