Clean up updater matters, prune unused functions

This commit is contained in:
Grant Sanderson 2024-03-07 16:07:39 -03:00
parent 70b839e188
commit 83cd5d6246
3 changed files with 17 additions and 37 deletions

View file

@ -40,6 +40,7 @@ from manimlib.utils.bezier import integer_interpolate
from manimlib.utils.bezier import interpolate from manimlib.utils.bezier import interpolate
from manimlib.utils.paths import straight_path from manimlib.utils.paths import straight_path
from manimlib.utils.simple_functions import get_parameters from manimlib.utils.simple_functions import get_parameters
from manimlib.utils.simple_functions import get_num_args
from manimlib.utils.shaders import get_colormap_code from manimlib.utils.shaders import get_colormap_code
from manimlib.utils.space_ops import angle_of_vector from manimlib.utils.space_ops import angle_of_vector
from manimlib.utils.space_ops import get_norm from manimlib.utils.space_ops import get_norm
@ -51,7 +52,7 @@ SubmobjectType = TypeVar('SubmobjectType', bound='Mobject')
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Callable, Iterator, Union, Tuple, Optional from typing import Callable, Iterator, Union, Tuple, Optional, Any
import numpy.typing as npt import numpy.typing as npt
from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, UniformDict, Self from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, UniformDict, Self
from moderngl.context import Context from moderngl.context import Context
@ -831,9 +832,9 @@ class Mobject(object):
# Updating # Updating
def init_updaters(self): def init_updaters(self):
self.time_based_updaters: list[TimeBasedUpdater] = [] self.time_based_updaters: list[TimeBasedUpdater] = list()
self.non_time_updaters: list[NonTimeUpdater] = [] self.non_time_updaters: list[NonTimeUpdater] = list()
self.has_updaters: bool = False self._has_updaters_in_family: Optional[bool] = False
self.updating_suspended: bool = False self.updating_suspended: bool = False
def update(self, dt: float = 0, recurse: bool = True) -> Self: def update(self, dt: float = 0, recurse: bool = True) -> Self:
@ -848,36 +849,23 @@ class Mobject(object):
updater(self) updater(self)
return self return self
def get_time_based_updaters(self) -> list[TimeBasedUpdater]:
return self.time_based_updaters
def has_time_based_updater(self) -> bool:
return len(self.time_based_updaters) > 0
def get_updaters(self) -> list[Updater]: def get_updaters(self) -> list[Updater]:
return self.time_based_updaters + self.non_time_updaters return [*self.time_based_updaters, *self.non_time_updaters]
def get_family_updaters(self) -> list[Updater]: def add_updater(self, update_func: Updater, call: bool = True) -> Self:
return list(it.chain(*[sm.get_updaters() for sm in self.get_family()])) if get_num_args(update_func) > 1:
def add_updater(
self,
update_func: Updater,
call_updater: bool = True
) -> Self:
if "dt" in get_parameters(update_func):
self.time_based_updaters.append(update_func) self.time_based_updaters.append(update_func)
else: else:
self.non_time_updaters.append(update_func) self.non_time_updaters.append(update_func)
if call_updater: if call:
self.update(dt=0) self.update(dt=0)
self.refresh_has_updater_status() self.refresh_has_updater_status()
return self return self
def insert_updater(self, update_func: Updater, index=0): def insert_updater(self, update_func: Updater, index=0):
if "dt" in get_parameters(update_func): if get_num_args(update_func) > 1:
self.time_based_updaters.insert(index, update_func) self.time_based_updaters.insert(index, update_func)
else: else:
self.non_time_updaters.insert(index, update_func) self.non_time_updaters.insert(index, update_func)

View file

@ -345,17 +345,9 @@ class Scene(object):
mobject.update(dt) mobject.update(dt)
def should_update_mobjects(self) -> bool: def should_update_mobjects(self) -> bool:
return self.always_update_mobjects or any([ return self.always_update_mobjects or any(
len(mob.get_family_updaters()) > 0 mob.has_updaters() for mob in self.mobjects
for mob in self.mobjects )
])
def has_time_based_updaters(self) -> bool:
return any([
sm.has_time_based_updater()
for mob in self.mobjects()
for sm in mob.get_family()
])
# Related to time # Related to time

View file

@ -9,7 +9,7 @@ import numpy as np
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Callable, TypeVar from typing import Callable, TypeVar, Iterable
from manimlib.typing import FloatArray from manimlib.typing import FloatArray
Scalable = TypeVar("Scalable", float, FloatArray) Scalable = TypeVar("Scalable", float, FloatArray)
@ -30,11 +30,11 @@ def gen_choose(n: int, r: int) -> int:
def get_num_args(function: Callable) -> int: def get_num_args(function: Callable) -> int:
return len(get_parameters(function)) return len(list(get_parameters(function)))
def get_parameters(function: Callable) -> list: def get_parameters(function: Callable) -> Iterable[str]:
return list(inspect.signature(function).parameters.keys()) return inspect.signature(function).parameters.keys()
# Just to have a less heavyweight name for this extremely common operation # Just to have a less heavyweight name for this extremely common operation
# #