From 711438f625c4be9aab6b89a616963daacb215e54 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Sat, 3 Feb 2024 18:00:47 -0600 Subject: [PATCH] Update the types in decorator methods using @wraps This is method to address issues flagged by pyright --- manimlib/mobject/mobject.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 2c9a781d..2ac92bae 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -48,11 +48,12 @@ from manimlib.utils.space_ops import rotation_matrix_transpose from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Callable, Iterable, Iterator, Union, Tuple, Optional + from typing import Callable, Iterable, Iterator, Union, Tuple, Optional, TypeVar, Any import numpy.typing as npt from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, UniformDict, Self from moderngl.context import Context + T = TypeVar('T') TimeBasedUpdater = Callable[["Mobject", float], "Mobject" | None] NonTimeUpdater = Callable[["Mobject"], "Mobject" | None] Updater = Union[TimeBasedUpdater, NonTimeUpdater] @@ -164,20 +165,23 @@ class Mobject(object): mob.note_changed_data() return self - def affects_data(func: Callable): + @staticmethod + def affects_data(func: Callable[..., T]) -> Callable[..., T]: @wraps(func) def wrapper(self, *args, **kwargs): - func(self, *args, **kwargs) + result = func(self, *args, **kwargs) self.note_changed_data() + return result return wrapper - def affects_family_data(func: Callable): + @staticmethod + def affects_family_data(func: Callable[..., T]) -> Callable[..., T]: @wraps(func) def wrapper(self, *args, **kwargs): - func(self, *args, **kwargs) + result = func(self, *args, **kwargs) for mob in self.family_members_with_points(): mob.note_changed_data() - return self + return result return wrapper # Only these methods should directly affect points @@ -578,7 +582,8 @@ class Mobject(object): # Copying and serialization - def stash_mobject_pointers(func: Callable): + @staticmethod + def stash_mobject_pointers(func: Callable[..., T]) -> Callable[..., T]: @wraps(func) def wrapper(self, *args, **kwargs): uncopied_attrs = ["parents", "target", "saved_state"] @@ -1863,7 +1868,8 @@ class Mobject(object): # Operations touching shader uniforms - def affects_shader_info_id(func: Callable): + @staticmethod + def affects_shader_info_id(func: Callable[..., T]) -> Callable[..., T]: @wraps(func) def wrapper(self, *args, **kwargs): result = func(self, *args, **kwargs)