diff --git a/manimlib/mobject/changing.py b/manimlib/mobject/changing.py index ab4313d8..68b44104 100644 --- a/manimlib/mobject/changing.py +++ b/manimlib/mobject/changing.py @@ -11,7 +11,7 @@ from manimlib.utils.rate_functions import smooth from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Callable, List, Iterable + from typing import Callable, List, Iterable, Self from manimlib.typing import ManimColor, Vect3 @@ -49,7 +49,7 @@ class AnimatedBoundary(VGroup): lambda m, dt: self.update_boundary_copies(dt) ) - def update_boundary_copies(self, dt: float) -> None: + def update_boundary_copies(self, dt: float) -> Self: # Not actual time, but something which passes at # an altered rate to make the implementation below # cleaner @@ -79,6 +79,7 @@ class AnimatedBoundary(VGroup): ) self.total_time += dt + return self def full_family_become_partial( self, @@ -86,7 +87,7 @@ class AnimatedBoundary(VGroup): mob2: VMobject, a: float, b: float - ): + ) -> Self: family1 = mob1.family_members_with_points() family2 = mob2.family_members_with_points() for sm1, sm2 in zip(family1, family2): @@ -118,7 +119,7 @@ class TracedPath(VMobject): self.traced_points: list[np.ndarray] = [] self.add_updater(lambda m, dt: m.update_path(dt)) - def update_path(self, dt: float): + def update_path(self, dt: float) -> Self: if dt == 0: return self point = self.traced_point_func().copy() diff --git a/manimlib/mobject/coordinate_systems.py b/manimlib/mobject/coordinate_systems.py index 516036b4..fd7bb8af 100644 --- a/manimlib/mobject/coordinate_systems.py +++ b/manimlib/mobject/coordinate_systems.py @@ -21,6 +21,7 @@ from manimlib.mobject.svg.tex_mobject import Tex from manimlib.mobject.types.dot_cloud import DotCloud from manimlib.mobject.types.surface import ParametricSurface from manimlib.mobject.types.vectorized_mobject import VGroup +from manimlib.mobject.types.vectorized_mobject import VMobject from manimlib.utils.dict_ops import merge_dicts_recursively from manimlib.utils.simple_functions import binary_search from manimlib.utils.space_ops import angle_of_vector @@ -31,7 +32,7 @@ from manimlib.utils.space_ops import normalize from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Callable, Iterable, Sequence, Type, TypeVar + from typing import Callable, Iterable, Sequence, Type, TypeVar, Optional, Self from manimlib.mobject.mobject import Mobject from manimlib.typing import ManimColor, Vect3, Vect3Array, VectN, RangeSpecifier @@ -235,7 +236,13 @@ class CoordinateSystem(ABC): """ return self.input_to_graph_point(x, graph) - def bind_graph_to_func(self, graph, func, jagged=False, get_discontinuities=None): + def bind_graph_to_func( + self, + graph: VMobject, + func: Callable[[Vect3], Vect3], + jagged: bool = False, + get_discontinuities: Optional[Callable[[], Vect3]] = None + ) -> VMobject: """ Use for graphing functions which might change over time, or change with conditions @@ -659,7 +666,7 @@ class NumberPlane(Axes): kwargs["buff"] = 0 return Arrow(self.c2p(0, 0), self.c2p(*coords), **kwargs) - def prepare_for_nonlinear_transform(self, num_inserted_curves: int = 50): + def prepare_for_nonlinear_transform(self, num_inserted_curves: int = 50) -> Self: for mob in self.family_members_with_points(): num_curves = mob.get_num_curves() if num_inserted_curves > num_curves: @@ -698,7 +705,7 @@ class ComplexPlane(NumberPlane): skip_first: bool = True, font_size: int = 36, **kwargs - ): + ) -> Self: if numbers is None: numbers = self.get_default_coordinate_values(skip_first) diff --git a/manimlib/mobject/geometry.py b/manimlib/mobject/geometry.py index d418271b..18ee53c8 100644 --- a/manimlib/mobject/geometry.py +++ b/manimlib/mobject/geometry.py @@ -30,7 +30,7 @@ from manimlib.utils.space_ops import rotation_matrix_transpose from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Iterable + from typing import Iterable, Self, Optional from manimlib.typing import ManimColor, Vect3, Vect3Array @@ -67,7 +67,7 @@ class TipableVMobject(VMobject): ) # Adding, Creating, Modifying tips - def add_tip(self, at_start: bool = False, **kwargs): + def add_tip(self, at_start: bool = False, **kwargs) -> Self: """ Adds a tip to the TipableVMobject instance, recognising that the endpoints might need to be switched if it's @@ -112,7 +112,7 @@ class TipableVMobject(VMobject): tip.shift(anchor - tip.get_tip_point()) return tip - def reset_endpoints_based_on_tip(self, tip: ArrowTip, at_start: bool): + def reset_endpoints_based_on_tip(self, tip: ArrowTip, at_start: bool) -> Self: if self.get_length() == 0: # Zero length, put_start_and_end_on wouldn't # work @@ -127,7 +127,7 @@ class TipableVMobject(VMobject): self.put_start_and_end_on(start, end) return self - def asign_tip_attr(self, tip: ArrowTip, at_start: bool): + def asign_tip_attr(self, tip: ArrowTip, at_start: bool) -> Self: if at_start: self.start_tip = tip else: @@ -258,7 +258,7 @@ class Arc(TipableVMobject): angle = angle_of_vector(self.get_end() - self.get_arc_center()) return angle % TAU - def move_arc_center_to(self, point: Vect3): + def move_arc_center_to(self, point: Vect3) -> Self: self.shift(point - self.get_arc_center()) return self @@ -318,7 +318,7 @@ class Circle(Arc): dim_to_match: int = 0, stretch: bool = False, buff: float = MED_SMALL_BUFF - ): + ) -> Self: self.replace(mobject, dim_to_match, stretch) self.stretch((self.get_width() + 2 * buff) / self.get_width(), 0) self.stretch((self.get_height() + 2 * buff) / self.get_height(), 1) @@ -475,7 +475,7 @@ class Line(TipableVMobject): end: Vect3, buff: float = 0, path_arc: float = 0 - ): + ) -> Self: vect = end - start dist = get_norm(vect) if np.isclose(dist, 0): @@ -504,9 +504,10 @@ class Line(TipableVMobject): self.set_points_as_corners([start, end]) return self - def set_path_arc(self, new_value: float) -> None: + def set_path_arc(self, new_value: float) -> Self: self.path_arc = new_value self.init_points() + return self def set_start_and_end_attrs(self, start: Vect3 | Mobject, end: Vect3 | Mobject): # If either start or end are Mobjects, this @@ -541,7 +542,7 @@ class Line(TipableVMobject): result[:len(point)] = point return result - def put_start_and_end_on(self, start: Vect3, end: Vect3): + def put_start_and_end_on(self, start: Vect3, end: Vect3) -> Self: curr_start, curr_end = self.get_start_and_end() if np.isclose(curr_start, curr_end).all(): # Handle null lines more gracefully @@ -569,7 +570,7 @@ class Line(TipableVMobject): def get_slope(self) -> float: return np.tan(self.get_angle()) - def set_angle(self, angle: float, about_point: Vect3 | None = None): + def set_angle(self, angle: float, about_point: Optional[Vect3] = None) -> Self: if about_point is None: about_point = self.get_start() self.rotate( @@ -695,13 +696,13 @@ class Arrow(Line): end: Vect3, buff: float = 0, path_arc: float = 0 - ): + ) -> Self: super().set_points_by_ends(start, end, buff, path_arc) self.insert_tip_anchor() self.create_tip_with_stroke_width() return self - def insert_tip_anchor(self): + def insert_tip_anchor(self) -> Self: prev_end = self.get_end() arc_len = self.get_arc_length() tip_len = self.get_stroke_width() * self.width_to_tip_len * self.tip_width_ratio @@ -716,7 +717,7 @@ class Arrow(Line): return self @Mobject.affects_data - def create_tip_with_stroke_width(self): + def create_tip_with_stroke_width(self) -> Self: if self.get_num_points() < 3: return self tip_width = self.tip_width_ratio * min( @@ -727,7 +728,7 @@ class Arrow(Line): self.data['stroke_width'][-3:, 0] = tip_width * np.linspace(1, 0, 3) return self - def reset_tip(self): + def reset_tip(self) -> Self: self.set_points_by_ends( self.get_start(), self.get_end(), path_arc=self.path_arc @@ -739,13 +740,13 @@ class Arrow(Line): color: ManimColor | Iterable[ManimColor] | None = None, width: float | Iterable[float] | None = None, *args, **kwargs - ): + ) -> Self: super().set_stroke(color=color, width=width, *args, **kwargs) if self.has_points(): self.reset_tip() return self - def _handle_scale_side_effects(self, scale_factor: float): + def _handle_scale_side_effects(self, scale_factor: float) -> Self: if scale_factor != 1.0: self.reset_tip() return self @@ -787,7 +788,7 @@ class FillArrow(Line): end: Vect3, buff: float = 0, path_arc: float = 0 - ) -> None: + ) -> Self: # Find the right tip length and thickness vect = end - start length = max(get_norm(vect), 1e-8) @@ -848,8 +849,9 @@ class FillArrow(Line): axis=rotate_vector(self.get_unit_vector(), -PI / 2), ) self.shift(start - self.get_start()) + return self - def reset_points_around_ends(self): + def reset_points_around_ends(self) -> Self: self.set_points_by_ends( self.get_start().copy(), self.get_end().copy(), @@ -864,21 +866,21 @@ class FillArrow(Line): def get_end(self) -> Vect3: return self.get_points()[self.tip_index] - def put_start_and_end_on(self, start: Vect3, end: Vect3): + def put_start_and_end_on(self, start: Vect3, end: Vect3) -> Self: self.set_points_by_ends(start, end, buff=0, path_arc=self.path_arc) return self - def scale(self, *args, **kwargs): + def scale(self, *args, **kwargs) -> Self: super().scale(*args, **kwargs) self.reset_points_around_ends() return self - def set_thickness(self, thickness: float): + def set_thickness(self, thickness: float) -> Self: self.thickness = thickness self.reset_points_around_ends() return self - def set_path_arc(self, path_arc: float): + def set_path_arc(self, path_arc: float) -> Self: self.path_arc = path_arc self.reset_points_around_ends() return self @@ -921,7 +923,7 @@ class Polygon(VMobject): def get_vertices(self) -> Vect3Array: return self.get_start_anchors() - def round_corners(self, radius: float | None = None): + def round_corners(self, radius: Optional[float] = None) -> Self: if radius is None: verts = self.get_vertices() min_edge_length = min( diff --git a/manimlib/mobject/matrix.py b/manimlib/mobject/matrix.py index 67abd686..cbf5458a 100644 --- a/manimlib/mobject/matrix.py +++ b/manimlib/mobject/matrix.py @@ -18,7 +18,7 @@ from manimlib.mobject.types.vectorized_mobject import VMobject from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Sequence + from typing import Sequence, Self import numpy.typing as npt from manimlib.mobject.mobject import Mobject from manimlib.typing import ManimColor, Vect3 @@ -129,7 +129,7 @@ class Matrix(VMobject): v_buff: float, h_buff: float, aligned_corner: Vect3, - ): + ) -> Self: for i, row in enumerate(matrix): for j, elem in enumerate(row): mob = matrix[i][j] @@ -139,7 +139,7 @@ class Matrix(VMobject): ) return self - def add_brackets(self, v_buff: float, h_buff: float): + def add_brackets(self, v_buff: float, h_buff: float) -> Self: height = len(self.mob_matrix) brackets = Tex("".join(( R"\left[\begin{array}{c}", @@ -168,13 +168,13 @@ class Matrix(VMobject): for row in self.mob_matrix ]) - def set_column_colors(self, *colors: ManimColor): + def set_column_colors(self, *colors: ManimColor) -> Self: columns = self.get_columns() for color, column in zip(colors, columns): column.set_color(color) return self - def add_background_to_entries(self): + def add_background_to_entries(self) -> Self: for mob in self.get_entries(): mob.add_background_rectangle() return self diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index 19f758fd..f6b27744 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -48,7 +48,7 @@ from manimlib.utils.space_ops import rotation_matrix_transpose from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Callable, Iterable, Union, Tuple, Optional + from typing import Callable, Iterable, Union, Tuple, Optional, Self import numpy.typing as npt from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, UniformDict from moderngl.context import Context @@ -143,7 +143,7 @@ class Mobject(object): # Typically implemented in subclass, unlpess purposefully left blank pass - def set_uniforms(self, uniforms: dict): + def set_uniforms(self, uniforms: dict) -> Self: for key, value in uniforms.items(): if isinstance(value, np.ndarray): value = value.copy() @@ -151,15 +151,16 @@ class Mobject(object): return self @property - def animate(self): + def animate(self) -> _AnimationBuilder: # Borrowed from https://github.com/ManimCommunity/manim/ return _AnimationBuilder(self) - def note_changed_data(self, recurse_up: bool = True): + def note_changed_data(self, recurse_up: bool = True) -> Self: self._data_has_changed = True if recurse_up: for mob in self.parents: mob.note_changed_data() + return self def affects_data(func: Callable): @wraps(func) @@ -179,7 +180,7 @@ class Mobject(object): # Only these methods should directly affect points @affects_data - def set_data(self, data: np.ndarray): + def set_data(self, data: np.ndarray) -> Self: assert(data.dtype == self.data.dtype) self.data = data.copy() return self @@ -189,7 +190,7 @@ class Mobject(object): self, new_length: int, resize_func: Callable[[np.ndarray, int], np.ndarray] = resize_array - ): + ) -> Self: if new_length == 0: if len(self.data) > 0: self._data_defaults[:1] = self.data[:1] @@ -201,13 +202,13 @@ class Mobject(object): return self @affects_data - def set_points(self, points: Vect3Array): + def set_points(self, points: Vect3Array) -> Self: self.resize_points(len(points), resize_func=resize_preserving_order) self.data["point"][:] = points return self @affects_data - def append_points(self, new_points: Vect3Array): + def append_points(self, new_points: Vect3Array) -> Self: n = self.get_num_points() self.resize_points(n + len(new_points)) # Have most data default to the last value @@ -218,7 +219,7 @@ class Mobject(object): return self @affects_family_data - def reverse_points(self): + def reverse_points(self) -> Self: for mob in self.get_family(): mob.data = mob.data[::-1] return self @@ -230,7 +231,7 @@ class Mobject(object): about_point: Vect3 | None = None, about_edge: Vect3 = ORIGIN, works_on_bounding_box: bool = False - ): + ) -> Self: if about_point is None and about_edge is not None: about_point = self.get_bounding_box_point(about_edge) @@ -257,15 +258,16 @@ class Mobject(object): # Others related to points - def match_points(self, mobject: Mobject): + def match_points(self, mobject: Mobject) -> Self: self.set_points(mobject.get_points()) return self def get_points(self) -> Vect3Array: return self.data["point"] - def clear_points(self) -> None: + def clear_points(self) -> Self: self.resize_points(0) + return self def get_num_points(self) -> int: return len(self.get_points()) @@ -307,7 +309,7 @@ class Mobject(object): self, recurse_down: bool = False, recurse_up: bool = True - ): + ) -> Self: for mob in self.get_family(recurse_down): mob.needs_new_bounding_box = True if recurse_up: @@ -342,23 +344,23 @@ class Mobject(object): # Family matters - def __getitem__(self, value: int | slice) -> Mobject: + def __getitem__(self, value: int | slice) -> Self: if isinstance(value, slice): GroupClass = self.get_group_class() return GroupClass(*self.split().__getitem__(value)) return self.split().__getitem__(value) - def __iter__(self): + def __iter__(self) -> Iterable[Self]: return iter(self.split()) - def __len__(self): + def __len__(self) -> int: return len(self.split()) - def split(self) -> list[Mobject]: + def split(self) -> list[Self]: return self.submobjects @affects_data - def assemble_family(self): + def assemble_family(self) -> Self: sub_families = (sm.get_family() for sm in self.submobjects) self.family = [self, *it.chain(*sub_families)] self.refresh_has_updater_status() @@ -367,13 +369,13 @@ class Mobject(object): parent.assemble_family() return self - def get_family(self, recurse: bool = True) -> list[Mobject]: + def get_family(self, recurse: bool = True) -> list[Self]: if recurse: return self.family else: return [self] - def family_members_with_points(self): + def family_members_with_points(self) -> list[Self]: return [m for m in self.get_family() if m.has_points()] def get_ancestors(self, extended: bool = False) -> list[Mobject]: @@ -397,7 +399,7 @@ class Mobject(object): # Remove list redundancies while preserving order return list(dict.fromkeys(ancestors)) - def add(self, *mobjects: Mobject): + def add(self, *mobjects: Mobject) -> Self: if self in mobjects: raise Exception("Mobject cannot contain self") for mobject in mobjects: @@ -413,7 +415,7 @@ class Mobject(object): *to_remove: Mobject, reassemble: bool = True, recurse: bool = True - ): + ) -> Self: for parent in self.get_family(recurse): for child in to_remove: if child in parent.submobjects: @@ -424,14 +426,15 @@ class Mobject(object): parent.assemble_family() return self - def clear(self): + def clear(self) -> Self: self.remove(*self.submobjects, recurse=False) + return self - def add_to_back(self, *mobjects: Mobject): + def add_to_back(self, *mobjects: Mobject) -> Self: self.set_submobjects(list_update(mobjects, self.submobjects)) return self - def replace_submobject(self, index: int, new_submob: Mobject): + def replace_submobject(self, index: int, new_submob: Mobject) -> Self: old_submob = self.submobjects[index] if self in old_submob.parents: old_submob.parents.remove(self) @@ -440,17 +443,17 @@ class Mobject(object): self.assemble_family() return self - def insert_submobject(self, index: int, new_submob: Mobject): + def insert_submobject(self, index: int, new_submob: Mobject) -> Self: self.submobjects.insert(index, new_submob) self.assemble_family() return self - def set_submobjects(self, submobject_list: list[Mobject]): + def set_submobjects(self, submobject_list: list[Mobject]) -> Self: self.remove(*self.submobjects, reassemble=False) self.add(*submobject_list) return self - def digest_mobject_attrs(self): + def digest_mobject_attrs(self) -> Self: """ Ensures all attributes which are mobjects are included in the submobjects list. @@ -466,7 +469,7 @@ class Mobject(object): direction: Vect3 = RIGHT, center: bool = True, **kwargs - ): + ) -> Self: for m1, m2 in zip(self.submobjects, self.submobjects[1:]): m2.next_to(m1, direction, **kwargs) if center: @@ -485,7 +488,7 @@ class Mobject(object): v_buff_ratio: float = 0.5, aligned_edge: Vect3 = ORIGIN, fill_rows_first: bool = True - ): + ) -> Self: submobs = self.submobjects if n_rows is None and n_cols is None: n_rows = int(np.sqrt(len(submobs))) @@ -519,7 +522,7 @@ class Mobject(object): self.center() return self - def arrange_to_fit_dim(self, length: float, dim: int, about_edge=ORIGIN): + def arrange_to_fit_dim(self, length: float, dim: int, about_edge=ORIGIN) -> Self: ref_point = self.get_bounding_box_point(about_edge) n_submobs = len(self.submobjects) if n_submobs <= 1: @@ -535,20 +538,20 @@ class Mobject(object): self.move_to(ref_point, about_edge) return self - def arrange_to_fit_width(self, width: float, about_edge=ORIGIN): + def arrange_to_fit_width(self, width: float, about_edge=ORIGIN) -> Self: return self.arrange_to_fit_dim(width, 0, about_edge) - def arrange_to_fit_height(self, height: float, about_edge=ORIGIN): + def arrange_to_fit_height(self, height: float, about_edge=ORIGIN) -> Self: return self.arrange_to_fit_dim(height, 1, about_edge) - def arrange_to_fit_depth(self, depth: float, about_edge=ORIGIN): + def arrange_to_fit_depth(self, depth: float, about_edge=ORIGIN) -> Self: return self.arrange_to_fit_dim(depth, 2, about_edge) def sort( self, point_to_num_func: Callable[[np.ndarray], float] = lambda p: p[0], submob_func: Callable[[Mobject]] | None = None - ): + ) -> Self: if submob_func is not None: self.submobjects.sort(key=submob_func) else: @@ -556,7 +559,7 @@ class Mobject(object): self.assemble_family() return self - def shuffle(self, recurse: bool = False): + def shuffle(self, recurse: bool = False) -> Self: if recurse: for submob in self.submobjects: submob.shuffle(recurse=True) @@ -564,7 +567,7 @@ class Mobject(object): self.assemble_family() return self - def reverse_submobjects(self): + def reverse_submobjects(self) -> Self: self.submobjects.reverse() self.assemble_family() return self @@ -588,19 +591,20 @@ class Mobject(object): return wrapper @stash_mobject_pointers - def serialize(self): + def serialize(self) -> bytes: return pickle.dumps(self) - def deserialize(self, data: bytes): + def deserialize(self, data: bytes) -> Self: self.become(pickle.loads(data)) return self - def deepcopy(self): + def deepcopy(self) -> Self: result = copy.deepcopy(self) result._shaders_initialized = False result._data_has_changed = True + return result - def copy(self, deep: bool = False): + def copy(self, deep: bool = False) -> Self: if deep: return self.deepcopy() @@ -644,30 +648,30 @@ class Mobject(object): setattr(result, attr, value.copy()) return result - def generate_target(self, use_deepcopy: bool = False): + def generate_target(self, use_deepcopy: bool = False) -> Self: self.target = self.copy(deep=use_deepcopy) self.target.saved_state = self.saved_state return self.target - def save_state(self, use_deepcopy: bool = False): + def save_state(self, use_deepcopy: bool = False) -> Self: self.saved_state = self.copy(deep=use_deepcopy) self.saved_state.target = self.target return self - def restore(self): + def restore(self) -> Self: if not hasattr(self, "saved_state") or self.saved_state is None: raise Exception("Trying to restore without having saved") self.become(self.saved_state) return self - def save_to_file(self, file_path: str, supress_overwrite_warning: bool = False): + def save_to_file(self, file_path: str) -> Self: with open(file_path, "wb") as fp: fp.write(self.serialize()) log.info(f"Saved mobject to {file_path}") return self @staticmethod - def load(file_path): + def load(file_path) -> Mobject: if not os.path.exists(file_path): log.error(f"No file found at {file_path}") sys.exit(2) @@ -675,7 +679,7 @@ class Mobject(object): mobject = pickle.load(fp) return mobject - def become(self, mobject: Mobject, match_updaters=False): + def become(self, mobject: Mobject, match_updaters=False) -> Self: """ Edit all data and submobjects to be idential to another mobject @@ -735,18 +739,20 @@ class Mobject(object): # Creating new Mobjects from this one - def replicate(self, n: int) -> Group: + def replicate(self, n: int) -> Self: group_class = self.get_group_class() return group_class(*(self.copy() for _ in range(n))) - def get_grid(self, - n_rows: int, - n_cols: int, - height: float | None = None, - width: float | None = None, - group_by_rows: bool = False, - group_by_cols: bool = False, - **kwargs) -> Group: + def get_grid( + self, + n_rows: int, + n_cols: int, + height: float | None = None, + width: float | None = None, + group_by_rows: bool = False, + group_by_cols: bool = False, + **kwargs + ) -> Self: """ Returns a new mobject containing multiple copies of this one arranged in a grid @@ -777,7 +783,7 @@ class Mobject(object): self.has_updaters: bool = False self.updating_suspended: bool = False - def update(self, dt: float = 0, recurse: bool = True): + def update(self, dt: float = 0, recurse: bool = True) -> Self: if not self.has_updaters or self.updating_suspended: return self for updater in self.time_based_updaters: @@ -806,7 +812,7 @@ class Mobject(object): update_function: Updater, index: int | None = None, call_updater: bool = True - ): + ) -> Self: if "dt" in get_parameters(update_function): updater_list = self.time_based_updaters else: @@ -824,14 +830,14 @@ class Mobject(object): self.update(dt=0) return self - def remove_updater(self, update_function: Updater): + def remove_updater(self, update_function: Updater) -> Self: for updater_list in [self.time_based_updaters, self.non_time_updaters]: while update_function in updater_list: updater_list.remove(update_function) self.refresh_has_updater_status() return self - def clear_updaters(self, recurse: bool = True): + def clear_updaters(self, recurse: bool = True) -> Self: self.time_based_updaters = [] self.non_time_updaters = [] if recurse: @@ -840,20 +846,20 @@ class Mobject(object): self.refresh_has_updater_status() return self - def match_updaters(self, mobject: Mobject): + def match_updaters(self, mobject: Mobject) -> Self: self.clear_updaters() for updater in mobject.get_updaters(): self.add_updater(updater) return self - def suspend_updating(self, recurse: bool = True): + def suspend_updating(self, recurse: bool = True) -> Self: self.updating_suspended = True if recurse: for submob in self.submobjects: submob.suspend_updating(recurse) return self - def resume_updating(self, recurse: bool = True, call_updater: bool = True): + def resume_updating(self, recurse: bool = True, call_updater: bool = True) -> Self: self.updating_suspended = False if recurse: for submob in self.submobjects: @@ -864,7 +870,7 @@ class Mobject(object): self.update(dt=0, recurse=recurse) return self - def refresh_has_updater_status(self): + def refresh_has_updater_status(self) -> Self: self.has_updaters = any(mob.get_updaters() for mob in self.get_family()) return self @@ -873,14 +879,14 @@ class Mobject(object): def is_changing(self) -> bool: return self._is_animating or self.has_updaters - def set_animating_status(self, is_animating: bool, recurse: bool = True): + def set_animating_status(self, is_animating: bool, recurse: bool = True) -> Self: for mob in (*self.get_family(recurse), *self.get_ancestors()): mob._is_animating = is_animating return self # Transforming operations - def shift(self, vector: Vect3): + def shift(self, vector: Vect3) -> Self: self.apply_points_function( lambda points: points + vector, about_edge=None, @@ -894,7 +900,7 @@ class Mobject(object): min_scale_factor: float = 1e-8, about_point: Vect3 | None = None, about_edge: Vect3 = ORIGIN - ): + ) -> Self: """ Default behavior is to scale about the center of the mobject. The argument about_edge can be a vector, indicating which side of @@ -923,14 +929,14 @@ class Mobject(object): # any other changes when the size gets altered pass - def stretch(self, factor: float, dim: int, **kwargs): + def stretch(self, factor: float, dim: int, **kwargs) -> Self: def func(points): points[:, dim] *= factor return points self.apply_points_function(func, works_on_bounding_box=True, **kwargs) return self - def rotate_about_origin(self, angle: float, axis: Vect3 = OUT): + def rotate_about_origin(self, angle: float, axis: Vect3 = OUT) -> Self: return self.rotate(angle, axis, about_point=ORIGIN) def rotate( @@ -939,7 +945,7 @@ class Mobject(object): axis: Vect3 = OUT, about_point: Vect3 | None = None, **kwargs - ): + ) -> Self: rot_matrix_T = rotation_matrix_transpose(angle, axis) self.apply_points_function( lambda points: np.dot(points, rot_matrix_T), @@ -948,10 +954,10 @@ class Mobject(object): ) return self - def flip(self, axis: Vect3 = UP, **kwargs): + def flip(self, axis: Vect3 = UP, **kwargs) -> Self: return self.rotate(TAU / 2, axis, **kwargs) - def apply_function(self, function: Callable[[np.ndarray], np.ndarray], **kwargs): + def apply_function(self, function: Callable[[np.ndarray], np.ndarray], **kwargs) -> Self: # Default to applying matrix about the origin, not mobjects center if len(kwargs) == 0: kwargs["about_point"] = ORIGIN @@ -961,19 +967,19 @@ class Mobject(object): ) return self - def apply_function_to_position(self, function: Callable[[np.ndarray], np.ndarray]): + def apply_function_to_position(self, function: Callable[[np.ndarray], np.ndarray]) -> Self: self.move_to(function(self.get_center())) return self def apply_function_to_submobject_positions( self, function: Callable[[np.ndarray], np.ndarray] - ): + ) -> Self: for submob in self.submobjects: submob.apply_function_to_position(function) return self - def apply_matrix(self, matrix: npt.ArrayLike, **kwargs): + def apply_matrix(self, matrix: npt.ArrayLike, **kwargs) -> Self: # Default to applying matrix about the origin, not mobjects center if ("about_point" not in kwargs) and ("about_edge" not in kwargs): kwargs["about_point"] = ORIGIN @@ -986,7 +992,7 @@ class Mobject(object): ) return self - def apply_complex_function(self, function: Callable[[complex], complex], **kwargs): + def apply_complex_function(self, function: Callable[[complex], complex], **kwargs) -> Self: def R3_func(point): x, y, z = point xy_complex = function(complex(x, y)) @@ -1002,7 +1008,7 @@ class Mobject(object): direction: Vect3 = RIGHT, axis: Vect3 = DOWN, wag_factor: float = 1.0 - ): + ) -> Self: for mob in self.family_members_with_points(): alphas = np.dot(mob.get_points(), np.transpose(axis)) alphas -= min(alphas) @@ -1016,7 +1022,7 @@ class Mobject(object): # Positioning methods - def center(self): + def center(self) -> Self: self.shift(-self.get_center()) return self @@ -1024,7 +1030,7 @@ class Mobject(object): self, direction: Vect3, buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER - ): + ) -> Self: """ Direction just needs to be a vector pointing towards side or corner in the 2d plane. @@ -1040,14 +1046,14 @@ class Mobject(object): self, corner: Vect3 = LEFT + DOWN, buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER - ): + ) -> Self: return self.align_on_border(corner, buff) def to_edge( self, edge: Vect3 = LEFT, buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER - ): + ) -> Self: return self.align_on_border(edge, buff) def next_to( @@ -1059,7 +1065,7 @@ class Mobject(object): submobject_to_align: Mobject | None = None, index_of_submobject_to_align: int | slice | None = None, coor_mask: Vect3 = np.array([1, 1, 1]), - ): + ) -> Self: if isinstance(mobject_or_point, Mobject): mob = mobject_or_point if index_of_submobject_to_align is not None: @@ -1081,7 +1087,7 @@ class Mobject(object): self.shift((target_point - point_to_align + buff * direction) * coor_mask) return self - def shift_onto_screen(self, **kwargs): + def shift_onto_screen(self, **kwargs) -> Self: space_lengths = [FRAME_X_RADIUS, FRAME_Y_RADIUS] for vect in UP, DOWN, LEFT, RIGHT: dim = np.argmax(np.abs(vect)) @@ -1092,7 +1098,7 @@ class Mobject(object): self.to_edge(vect, **kwargs) return self - def is_off_screen(self): + def is_off_screen(self) -> bool: if self.get_left()[0] > FRAME_X_RADIUS: return True if self.get_right()[0] < -FRAME_X_RADIUS: @@ -1103,14 +1109,14 @@ class Mobject(object): return True return False - def stretch_about_point(self, factor: float, dim: int, point: Vect3): + def stretch_about_point(self, factor: float, dim: int, point: Vect3) -> Self: return self.stretch(factor, dim, about_point=point) - def stretch_in_place(self, factor: float, dim: int): + def stretch_in_place(self, factor: float, dim: int) -> Self: # Now redundant with stretch return self.stretch(factor, dim) - def rescale_to_fit(self, length: float, dim: int, stretch: bool = False, **kwargs): + def rescale_to_fit(self, length: float, dim: int, stretch: bool = False, **kwargs) -> Self: old_length = self.length_over_dim(dim) if old_length == 0: return self @@ -1120,50 +1126,50 @@ class Mobject(object): self.scale(length / old_length, **kwargs) return self - def stretch_to_fit_width(self, width: float, **kwargs): + def stretch_to_fit_width(self, width: float, **kwargs) -> Self: return self.rescale_to_fit(width, 0, stretch=True, **kwargs) - def stretch_to_fit_height(self, height: float, **kwargs): + def stretch_to_fit_height(self, height: float, **kwargs) -> Self: return self.rescale_to_fit(height, 1, stretch=True, **kwargs) - def stretch_to_fit_depth(self, depth: float, **kwargs): + def stretch_to_fit_depth(self, depth: float, **kwargs) -> Self: return self.rescale_to_fit(depth, 2, stretch=True, **kwargs) - def set_width(self, width: float, stretch: bool = False, **kwargs): + def set_width(self, width: float, stretch: bool = False, **kwargs) -> Self: return self.rescale_to_fit(width, 0, stretch=stretch, **kwargs) - def set_height(self, height: float, stretch: bool = False, **kwargs): + def set_height(self, height: float, stretch: bool = False, **kwargs) -> Self: return self.rescale_to_fit(height, 1, stretch=stretch, **kwargs) - def set_depth(self, depth: float, stretch: bool = False, **kwargs): + def set_depth(self, depth: float, stretch: bool = False, **kwargs) -> Self: return self.rescale_to_fit(depth, 2, stretch=stretch, **kwargs) - def set_max_width(self, max_width: float, **kwargs): + def set_max_width(self, max_width: float, **kwargs) -> Self: if self.get_width() > max_width: self.set_width(max_width, **kwargs) return self - def set_max_height(self, max_height: float, **kwargs): + def set_max_height(self, max_height: float, **kwargs) -> Self: if self.get_height() > max_height: self.set_height(max_height, **kwargs) return self - def set_max_depth(self, max_depth: float, **kwargs): + def set_max_depth(self, max_depth: float, **kwargs) -> Self: if self.get_depth() > max_depth: self.set_depth(max_depth, **kwargs) return self - def set_min_width(self, min_width: float, **kwargs): + def set_min_width(self, min_width: float, **kwargs) -> Self: if self.get_width() < min_width: self.set_width(min_width, **kwargs) return self - def set_min_height(self, min_height: float, **kwargs): + def set_min_height(self, min_height: float, **kwargs) -> Self: if self.get_height() < min_height: self.set_height(min_height, **kwargs) return self - def set_min_depth(self, min_depth: float, **kwargs): + def set_min_depth(self, min_depth: float, **kwargs) -> Self: if self.get_depth() < min_depth: self.set_depth(min_depth, **kwargs) return self @@ -1174,7 +1180,7 @@ class Mobject(object): height: Optional[float] = None, depth: Optional[float] = None, **kwargs - ): + ) -> Self: if width is not None: self.set_width(width, stretch=True, **kwargs) if height is not None: @@ -1183,23 +1189,23 @@ class Mobject(object): self.set_depth(depth, stretch=True, **kwargs) return self - def set_coord(self, value: float, dim: int, direction: Vect3 = ORIGIN): + def set_coord(self, value: float, dim: int, direction: Vect3 = ORIGIN) -> Self: curr = self.get_coord(dim, direction) shift_vect = np.zeros(self.dim) shift_vect[dim] = value - curr self.shift(shift_vect) return self - def set_x(self, x: float, direction: Vect3 = ORIGIN): + def set_x(self, x: float, direction: Vect3 = ORIGIN) -> Self: return self.set_coord(x, 0, direction) - def set_y(self, y: float, direction: Vect3 = ORIGIN): + def set_y(self, y: float, direction: Vect3 = ORIGIN) -> Self: return self.set_coord(y, 1, direction) - def set_z(self, z: float, direction: Vect3 = ORIGIN): + def set_z(self, z: float, direction: Vect3 = ORIGIN) -> Self: return self.set_coord(z, 2, direction) - def space_out_submobjects(self, factor: float = 1.5, **kwargs): + def space_out_submobjects(self, factor: float = 1.5, **kwargs) -> Self: self.scale(factor, **kwargs) for submob in self.submobjects: submob.scale(1. / factor) @@ -1210,7 +1216,7 @@ class Mobject(object): point_or_mobject: Mobject | Vect3, aligned_edge: Vect3 = ORIGIN, coor_mask: Vect3 = np.array([1, 1, 1]) - ): + ) -> Self: if isinstance(point_or_mobject, Mobject): target = point_or_mobject.get_bounding_box_point(aligned_edge) else: @@ -1219,7 +1225,7 @@ class Mobject(object): self.shift((target - point_to_align) * coor_mask) return self - def replace(self, mobject: Mobject, dim_to_match: int = 0, stretch: bool = False): + def replace(self, mobject: Mobject, dim_to_match: int = 0, stretch: bool = False) -> Self: if not mobject.get_num_points() and not mobject.submobjects: self.scale(0) return self @@ -1241,13 +1247,13 @@ class Mobject(object): dim_to_match: int = 0, stretch: bool = False, buff: float = MED_SMALL_BUFF - ): + ) -> Self: self.replace(mobject, dim_to_match, stretch) length = mobject.length_over_dim(dim_to_match) self.scale((length + buff) / length) return self - def put_start_and_end_on(self, start: Vect3, end: Vect3): + def put_start_and_end_on(self, start: Vect3, end: Vect3) -> Self: curr_start, curr_end = self.get_start_and_end() curr_vect = curr_end - curr_start if np.all(curr_vect == 0): @@ -1275,7 +1281,7 @@ class Mobject(object): rgba_array: npt.ArrayLike, name: str = "rgba", recurse: bool = False - ): + ) -> Self: for mob in self.get_family(recurse): data = mob.data if mob.get_num_points() > 0 else mob._data_defaults data[name][:] = rgba_array @@ -1285,7 +1291,7 @@ class Mobject(object): self, func: Callable[[Vect3], Vect4], recurse: bool = True - ): + ) -> Self: """ Func should take in a point in R3 and output an rgba value """ @@ -1299,7 +1305,7 @@ class Mobject(object): func: Callable[[Vect3], Vect3], opacity: float = 1, recurse: bool = True - ): + ) -> Self: """ Func should take in a point in R3 and output an rgb value """ @@ -1315,7 +1321,7 @@ class Mobject(object): opacity: float | Iterable[float] | None = None, name: str = "rgba", recurse: bool = True - ): + ) -> Self: for mob in self.get_family(recurse): data = mob.data if mob.has_points() > 0 else mob._data_defaults if color is not None: @@ -1334,7 +1340,7 @@ class Mobject(object): color: ManimColor | Iterable[ManimColor] | None, opacity: float | Iterable[float] | None = None, recurse: bool = True - ): + ) -> Self: self.set_rgba_array_by_color(color, opacity, recurse=False) # Recurse to submobjects differently from how set_rgba_array_by_color # in case they implement set_color differently @@ -1347,7 +1353,7 @@ class Mobject(object): self, opacity: float | Iterable[float] | None, recurse: bool = True - ): + ) -> Self: self.set_rgba_array_by_color(color=None, opacity=opacity, recurse=False) if recurse: for submob in self.submobjects: @@ -1360,14 +1366,14 @@ class Mobject(object): def get_opacity(self) -> float: return self.data["rgba"][0, 3] - def set_color_by_gradient(self, *colors: ManimColor): + def set_color_by_gradient(self, *colors: ManimColor) -> Self: if self.has_points(): self.set_color(colors) else: self.set_submobject_colors_by_gradient(*colors) return self - def set_submobject_colors_by_gradient(self, *colors: ManimColor): + def set_submobject_colors_by_gradient(self, *colors: ManimColor) -> Self: if len(colors) == 0: raise Exception("Need at least one color") elif len(colors) == 1: @@ -1381,7 +1387,7 @@ class Mobject(object): mob.set_color(color) return self - def fade(self, darkness: float = 0.5, recurse: bool = True): + def fade(self, darkness: float = 0.5, recurse: bool = True) -> Self: self.set_opacity(1.0 - darkness, recurse=recurse) def get_shading(self) -> np.ndarray: @@ -1393,7 +1399,7 @@ class Mobject(object): gloss: float | None = None, shadow: float | None = None, recurse: bool = True - ): + ) -> Self: """ Larger reflectiveness makes things brighter when facing the light Larger shadow makes faces opposite the light darker @@ -1414,15 +1420,15 @@ class Mobject(object): def get_shadow(self) -> float: return self.get_shading()[2] - def set_reflectiveness(self, reflectiveness: float, recurse: bool = True): + def set_reflectiveness(self, reflectiveness: float, recurse: bool = True) -> Self: self.set_shading(reflectiveness=reflectiveness, recurse=recurse) return self - def set_gloss(self, gloss: float, recurse: bool = True): + def set_gloss(self, gloss: float, recurse: bool = True) -> Self: self.set_shading(gloss=gloss, recurse=recurse) return self - def set_shadow(self, shadow: float, recurse: bool = True): + def set_shadow(self, shadow: float, recurse: bool = True) -> Self: self.set_shading(shadow=shadow, recurse=recurse) return self @@ -1433,7 +1439,7 @@ class Mobject(object): color: ManimColor | None = None, opacity: float = 0.75, **kwargs - ): + ) -> Self: # TODO, this does not behave well when the mobject has points, # since it gets displayed on top from manimlib.mobject.shape_matchers import BackgroundRectangle @@ -1445,12 +1451,12 @@ class Mobject(object): self.add_to_back(self.background_rectangle) return self - def add_background_rectangle_to_submobjects(self, **kwargs): + def add_background_rectangle_to_submobjects(self, **kwargs) -> Self: for submobject in self.submobjects: submobject.add_background_rectangle(**kwargs) return self - def add_background_rectangle_to_family_members_with_points(self, **kwargs): + def add_background_rectangle_to_family_members_with_points(self, **kwargs) -> Self: for mob in self.family_members_with_points(): mob.add_background_rectangle(**kwargs) return self @@ -1580,29 +1586,29 @@ class Mobject(object): for a1, a2 in zip(alphas[:-1], alphas[1:]) ]) - def get_z_index_reference_point(self): + def get_z_index_reference_point(self) -> Vect3: # TODO, better place to define default z_index_group? z_index_group = getattr(self, "z_index_group", self) return z_index_group.get_center() # Match other mobject properties - def match_color(self, mobject: Mobject): + def match_color(self, mobject: Mobject) -> Self: return self.set_color(mobject.get_color()) - def match_dim_size(self, mobject: Mobject, dim: int, **kwargs): + def match_dim_size(self, mobject: Mobject, dim: int, **kwargs) -> Self: return self.rescale_to_fit( mobject.length_over_dim(dim), dim, **kwargs ) - def match_width(self, mobject: Mobject, **kwargs): + def match_width(self, mobject: Mobject, **kwargs) -> Self: return self.match_dim_size(mobject, 0, **kwargs) - def match_height(self, mobject: Mobject, **kwargs): + def match_height(self, mobject: Mobject, **kwargs) -> Self: return self.match_dim_size(mobject, 1, **kwargs) - def match_depth(self, mobject: Mobject, **kwargs): + def match_depth(self, mobject: Mobject, **kwargs) -> Self: return self.match_dim_size(mobject, 2, **kwargs) def match_coord( @@ -1610,7 +1616,7 @@ class Mobject(object): mobject_or_point: Mobject | Vect3, dim: int, direction: Vect3 = ORIGIN - ): + ) -> Self: if isinstance(mobject_or_point, Mobject): coord = mobject_or_point.get_coord(dim, direction) else: @@ -1621,28 +1627,28 @@ class Mobject(object): self, mobject_or_point: Mobject | Vect3, direction: Vect3 = ORIGIN - ): + ) -> Self: return self.match_coord(mobject_or_point, 0, direction) def match_y( self, mobject_or_point: Mobject | Vect3, direction: Vect3 = ORIGIN - ): + ) -> Self: return self.match_coord(mobject_or_point, 1, direction) def match_z( self, mobject_or_point: Mobject | Vect3, direction: Vect3 = ORIGIN - ): + ) -> Self: return self.match_coord(mobject_or_point, 2, direction) def align_to( self, mobject_or_point: Mobject | Vect3, direction: Vect3 = ORIGIN - ): + ) -> Self: """ Examples: mob1.align_to(mob2, UP) moves mob1 vertically so that its @@ -1667,21 +1673,23 @@ class Mobject(object): # Alignment - def align_data_and_family(self, mobject: Mobject) -> None: + def align_data_and_family(self, mobject: Mobject) -> Self: self.align_family(mobject) self.align_data(mobject) + return self - def align_data(self, mobject: Mobject) -> None: + def align_data(self, mobject: Mobject) -> Self: for mob1, mob2 in zip(self.get_family(), mobject.get_family()): mob1.align_points(mob2) + return self - def align_points(self, mobject: Mobject): + def align_points(self, mobject: Mobject) -> Self: max_len = max(self.get_num_points(), mobject.get_num_points()) for mob in (self, mobject): mob.resize_points(max_len, resize_func=resize_preserving_order) return self - def align_family(self, mobject: Mobject): + def align_family(self, mobject: Mobject) -> Self: mob1 = self mob2 = mobject n1 = len(mob1) @@ -1694,14 +1702,14 @@ class Mobject(object): sm1.align_family(sm2) return self - def push_self_into_submobjects(self): + def push_self_into_submobjects(self) -> Self: copy = self.copy() copy.set_submobjects([]) self.resize_points(0) self.add(copy) return self - def add_n_more_submobjects(self, n: int): + def add_n_more_submobjects(self, n: int) -> Self: if n == 0: return self @@ -1729,7 +1737,7 @@ class Mobject(object): self.set_submobjects(new_submobs) return self - def invisible_copy(self): + def invisible_copy(self) -> Self: return self.copy().set_opacity(0) # Interpolate @@ -1740,7 +1748,7 @@ class Mobject(object): mobject2: Mobject, alpha: float, path_func: Callable[[np.ndarray, np.ndarray, float], np.ndarray] = straight_path - ): + ) -> Self: keys = [k for k in self.data.dtype.names if k not in self.locked_data_keys] if keys: self.note_changed_data() @@ -1766,18 +1774,19 @@ class Mobject(object): ) return self - def pointwise_become_partial(self, mobject, a, b): + def pointwise_become_partial(self, mobject, a, b) -> Self: """ Set points in such a way as to become only part of mobject. Inputs 0 <= a < b <= 1 determine what portion of mobject to become. """ - pass # To implement in subclass + # To be implemented in subclass + return self # Locking data - def lock_data(self, keys: Iterable[str]): + def lock_data(self, keys: Iterable[str]) -> Self: """ To speed up some animations, particularly transformations, it can be handy to acknowledge which pieces of data @@ -1788,8 +1797,9 @@ class Mobject(object): if self.has_updaters: return self.locked_data_keys = set(keys) + return self - def lock_matching_data(self, mobject1: Mobject, mobject2: Mobject): + def lock_matching_data(self, mobject1: Mobject, mobject2: Mobject) -> Self: tuples = zip( self.get_family(), mobject1.get_family(), @@ -1813,10 +1823,11 @@ class Mobject(object): return self - def unlock_data(self): + def unlock_data(self) -> Self: for mob in self.get_family(): mob.locked_data_keys = set() mob.const_data_keys = set() + return self # Operations touching shader uniforms @@ -1829,18 +1840,18 @@ class Mobject(object): return wrapper @affects_shader_info_id - def set_uniform(self, recurse: bool = True, **new_uniforms): + def set_uniform(self, recurse: bool = True, **new_uniforms) -> Self: for mob in self.get_family(recurse): mob.uniforms.update(new_uniforms) return self @affects_shader_info_id - def fix_in_frame(self, recurse: bool = True): + def fix_in_frame(self, recurse: bool = True) -> Self: self.set_uniform(recurse, is_fixed_in_frame=1.0) return self @affects_shader_info_id - def unfix_from_frame(self, recurse: bool = True): + def unfix_from_frame(self, recurse: bool = True) -> Self: self.set_uniform(recurse, is_fixed_in_frame=0.0) return self @@ -1848,20 +1859,20 @@ class Mobject(object): return bool(self.uniforms["is_fixed_in_frame"]) @affects_shader_info_id - def apply_depth_test(self, recurse: bool = True): + def apply_depth_test(self, recurse: bool = True) -> Self: for mob in self.get_family(recurse): mob.depth_test = True return self @affects_shader_info_id - def deactivate_depth_test(self, recurse: bool = True): + def deactivate_depth_test(self, recurse: bool = True) -> Self: for mob in self.get_family(recurse): mob.depth_test = False return self # Shader code manipulation - def replace_shader_code(self, old: str, new: str): + def replace_shader_code(self, old: str, new: str) -> Self: # TODO, will this work with VMobject structure, given # that it does not simpler return shader_wrappers of # family? @@ -1869,7 +1880,7 @@ class Mobject(object): wrapper.replace_code(old, new) return self - def set_color_by_code(self, glsl_code: str): + def set_color_by_code(self, glsl_code: str) -> Self: """ Takes a snippet of code and inserts it into a context which has the following variables: @@ -1888,7 +1899,7 @@ class Mobject(object): min_value: float = -5.0, max_value: float = 5.0, colormap: str = "viridis" - ): + ) -> Self: """ Pass in a glsl expression in terms of x, y and z which returns a float. @@ -2091,7 +2102,7 @@ class Group(Mobject): if any(m.is_fixed_in_frame() for m in mobjects): self.fix_in_frame() - def __add__(self, other: Mobject | Group): + def __add__(self, other: Mobject | Group) -> Self: assert(isinstance(other, Mobject)) return self.add(other) @@ -2121,8 +2132,9 @@ class Point(Mobject): def get_bounding_box_point(self, *args, **kwargs) -> Vect3: return self.get_location() - def set_location(self, new_loc: npt.ArrayLike): + def set_location(self, new_loc: npt.ArrayLike) -> Self: self.set_points(np.array(new_loc, ndmin=2, dtype=float)) + return self class _AnimationBuilder: diff --git a/manimlib/mobject/numbers.py b/manimlib/mobject/numbers.py index 68e9bdf8..3ede97f1 100644 --- a/manimlib/mobject/numbers.py +++ b/manimlib/mobject/numbers.py @@ -11,7 +11,7 @@ from manimlib.mobject.types.vectorized_mobject import VMobject from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import TypeVar + from typing import TypeVar, Self from manimlib.typing import ManimColor, Vect3 T = TypeVar("T", bound=VMobject) @@ -163,7 +163,7 @@ class DecimalNumber(VMobject): def get_tex(self): return self.num_string - def set_value(self, number: float | complex): + def set_value(self, number: float | complex) -> Self: move_to_point = self.get_edge_center(self.edge_to_fix) style = self.family_members_with_points()[0].get_style() self.set_submobjects_from_number(number) @@ -171,14 +171,16 @@ class DecimalNumber(VMobject): self.set_style(**style) return self - def _handle_scale_side_effects(self, scale_factor: float) -> None: + def _handle_scale_side_effects(self, scale_factor: float) -> Self: self.uniforms["font_size"] = scale_factor * self.uniforms["font_size"] + return self def get_value(self) -> float | complex: return self.number - def increment_value(self, delta_t: float | complex = 1) -> None: + def increment_value(self, delta_t: float | complex = 1) -> Self: self.set_value(self.get_value() + delta_t) + return self class Integer(DecimalNumber): diff --git a/manimlib/mobject/shape_matchers.py b/manimlib/mobject/shape_matchers.py index 6b20344d..07accaea 100644 --- a/manimlib/mobject/shape_matchers.py +++ b/manimlib/mobject/shape_matchers.py @@ -14,7 +14,7 @@ from manimlib.utils.customization import get_customization from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Sequence + from typing import Sequence, Self from manimlib.mobject.mobject import Mobject from manimlib.typing import ManimColor @@ -60,20 +60,20 @@ class BackgroundRectangle(SurroundingRectangle): ) self.original_fill_opacity = fill_opacity - def pointwise_become_partial(self, mobject: Mobject, a: float, b: float): + def pointwise_become_partial(self, mobject: Mobject, a: float, b: float) -> Self: self.set_fill(opacity=b * self.original_fill_opacity) return self - def set_style_data( + def set_style( self, stroke_color: ManimColor | None = None, stroke_width: float | None = None, fill_color: ManimColor | None = None, fill_opacity: float | None = None, family: bool = True - ): + ) -> Self: # Unchangeable style, except for fill_opacity - VMobject.set_style_data( + VMobject.set_style( self, stroke_color=BLACK, stroke_width=0, diff --git a/manimlib/mobject/three_dimensions.py b/manimlib/mobject/three_dimensions.py index 3fc9c719..05a12ede 100644 --- a/manimlib/mobject/three_dimensions.py +++ b/manimlib/mobject/three_dimensions.py @@ -166,7 +166,6 @@ class Cylinder(Surface): self.scale(self.radius) self.set_depth(self.height, stretch=True) self.apply_matrix(z_to_vector(self.axis)) - return self def uv_func(self, u: float, v: float) -> np.ndarray: return np.array([np.cos(u), np.sin(u), v]) @@ -186,6 +185,7 @@ class Line3D(Cylinder): height=get_norm(axis), radius=width / 2, axis=axis, + resolution=resolution, **kwargs ) self.shift((start + end) / 2) @@ -376,16 +376,6 @@ class Dodecahedron(VGroup3D): super().__init__(*pentagons, **style) - # # Rotate those two pentagons by all the axis permuations to fill - # # out the dodecahedron - # Id = np.identity(3) - # for i in range(3): - # perm = [j % 3 for j in range(i, i + 3)] - # for b in [1, -1]: - # matrix = b * np.array([Id[0][perm], Id[1][perm], Id[2][perm]]) - # self.add(pentagon1.copy().apply_matrix(matrix, about_point=ORIGIN)) - # self.add(pentagon2.copy().apply_matrix(matrix, about_point=ORIGIN)) - class Prismify(VGroup3D): def __init__(self, vmobject, depth=1.0, direction=IN, **kwargs): diff --git a/manimlib/mobject/types/dot_cloud.py b/manimlib/mobject/types/dot_cloud.py index 83aa9b07..ad65647c 100644 --- a/manimlib/mobject/types/dot_cloud.py +++ b/manimlib/mobject/types/dot_cloud.py @@ -13,7 +13,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: import numpy.typing as npt - from typing import Sequence, Tuple + from typing import Sequence, Tuple, Self from manimlib.typing import ManimColor, Vect3, Vect3Array @@ -70,7 +70,7 @@ class DotCloud(PMobject): v_buff_ratio: float = 1.0, d_buff_ratio: float = 1.0, height: float = DEFAULT_GRID_HEIGHT, - ): + ) -> Self: n_points = n_rows * n_cols * n_layers points = np.repeat(range(n_points), 3, axis=0).reshape((n_points, 3)) points[:, 0] = points[:, 0] % n_cols @@ -96,7 +96,7 @@ class DotCloud(PMobject): return self @Mobject.affects_data - def set_radii(self, radii: npt.ArrayLike): + def set_radii(self, radii: npt.ArrayLike) -> Self: n_points = self.get_num_points() radii = np.array(radii).reshape((len(radii), 1)) self.data["radius"][:] = resize_with_interpolation(radii, n_points) @@ -107,7 +107,7 @@ class DotCloud(PMobject): return self.data["radius"] @Mobject.affects_data - def set_radius(self, radius: float): + def set_radius(self, radius: float) -> Self: data = self.data if self.get_num_points() > 0 else self._data_defaults data["radius"][:] = radius self.refresh_bounding_box() @@ -116,13 +116,14 @@ class DotCloud(PMobject): def get_radius(self) -> float: return self.get_radii().max() - def set_glow_factor(self, glow_factor: float) -> None: + def set_glow_factor(self, glow_factor: float) -> Self: self.uniforms["glow_factor"] = glow_factor + return self def get_glow_factor(self) -> float: return self.uniforms["glow_factor"] - def compute_bounding_box(self) -> np.ndarray: + def compute_bounding_box(self) -> Vect3Array: bb = super().compute_bounding_box() radius = self.get_radius() bb[0] += np.full((3,), -radius) @@ -134,7 +135,7 @@ class DotCloud(PMobject): scale_factor: float | npt.ArrayLike, scale_radii: bool = True, **kwargs - ): + ) -> Self: super().scale(scale_factor, **kwargs) if scale_radii: self.set_radii(scale_factor * self.get_radii()) @@ -145,7 +146,7 @@ class DotCloud(PMobject): reflectiveness: float = 0.5, gloss: float = 0.1, shadow: float = 0.2 - ): + ) -> Self: self.set_shading(reflectiveness, gloss, shadow) self.apply_depth_test() return self diff --git a/manimlib/mobject/types/point_cloud_mobject.py b/manimlib/mobject/types/point_cloud_mobject.py index f6180b3f..ab958473 100644 --- a/manimlib/mobject/types/point_cloud_mobject.py +++ b/manimlib/mobject/types/point_cloud_mobject.py @@ -10,7 +10,7 @@ from manimlib.utils.iterables import resize_with_interpolation from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Callable + from typing import Callable, Self from manimlib.typing import ManimColor, Vect3, Vect3Array, Vect4Array @@ -28,7 +28,7 @@ class PMobject(Mobject): rgbas: Vect4Array | None = None, color: ManimColor | None = None, opacity: float | None = None - ): + ) -> Self: """ points must be a Nx3 numpy array, as must rgbas if it is not None """ @@ -46,13 +46,13 @@ class PMobject(Mobject): self.data["rgba"][-len(rgbas):] = rgbas return self - def add_point(self, point: Vect3, rgba=None, color=None, opacity=None): + def add_point(self, point: Vect3, rgba=None, color=None, opacity=None) -> Self: rgbas = None if rgba is None else [rgba] self.add_points([point], rgbas, color, opacity) return self @Mobject.affects_data - def set_color_by_gradient(self, *colors: ManimColor): + def set_color_by_gradient(self, *colors: ManimColor) -> Self: self.data["rgba"][:] = np.array(list(map( color_to_rgba, color_gradient(colors, self.get_num_points()) @@ -60,20 +60,20 @@ class PMobject(Mobject): return self @Mobject.affects_data - def match_colors(self, pmobject: PMobject): + def match_colors(self, pmobject: PMobject) -> Self: self.data["rgba"][:] = resize_with_interpolation( pmobject.data["rgba"], self.get_num_points() ) return self @Mobject.affects_data - def filter_out(self, condition: Callable[[np.ndarray], bool]): + def filter_out(self, condition: Callable[[np.ndarray], bool]) -> Self: for mob in self.family_members_with_points(): mob.data = mob.data[~np.apply_along_axis(condition, 1, mob.get_points())] return self @Mobject.affects_data - def sort_points(self, function: Callable[[Vect3], None] = lambda p: p[0]): + def sort_points(self, function: Callable[[Vect3], None] = lambda p: p[0]) -> Self: """ function is any map from R^3 to R """ @@ -85,7 +85,7 @@ class PMobject(Mobject): return self @Mobject.affects_data - def ingest_submobjects(self): + def ingest_submobjects(self) -> Self: self.data = np.vstack([ sm.data for sm in self.get_family() ]) @@ -96,7 +96,7 @@ class PMobject(Mobject): return self.get_points()[int(index)] @Mobject.affects_data - def pointwise_become_partial(self, pmobject: PMobject, a: float, b: float): + def pointwise_become_partial(self, pmobject: PMobject, a: float, b: float) -> Self: lower_index = int(a * pmobject.get_num_points()) upper_index = int(b * pmobject.get_num_points()) self.data = pmobject.data[lower_index:upper_index].copy() diff --git a/manimlib/mobject/types/surface.py b/manimlib/mobject/types/surface.py index 10d4ab77..736409d2 100644 --- a/manimlib/mobject/types/surface.py +++ b/manimlib/mobject/types/surface.py @@ -17,7 +17,7 @@ from manimlib.utils.space_ops import cross from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Callable, Iterable, Sequence, Tuple + from typing import Callable, Iterable, Sequence, Tuple, Self from manimlib.camera.camera import Camera from manimlib.typing import ManimColor, Vect3, Vect3Array @@ -100,18 +100,19 @@ class Surface(Mobject): (dv_points - points) / self.epsilon, ), 1) - def apply_points_function(self, *args, **kwargs): + def apply_points_function(self, *args, **kwargs) -> Self: super().apply_points_function(*args, **kwargs) self.get_unit_normals() + return self - def compute_triangle_indices(self): + def compute_triangle_indices(self) -> np.ndarray: # TODO, if there is an event which changes # the resolution of the surface, make sure # this is called. nu, nv = self.resolution if nu == 0 or nv == 0: self.triangle_indices = np.zeros(0, dtype=int) - return + return self.triangle_indices index_grid = np.arange(nu * nv).reshape((nu, nv)) indices = np.zeros(6 * (nu - 1) * (nv - 1), dtype=int) indices[0::6] = index_grid[:-1, :-1].flatten() # Top left @@ -121,6 +122,7 @@ class Surface(Mobject): indices[4::6] = index_grid[+1:, :-1].flatten() # Bottom left indices[5::6] = index_grid[+1:, +1:].flatten() # Bottom right self.triangle_indices = indices + return self.triangle_indices def get_triangle_indices(self) -> np.ndarray: return self.triangle_indices @@ -154,7 +156,7 @@ class Surface(Mobject): a: float, b: float, axis: int | None = None - ): + ) -> Self: assert(isinstance(smobject, Surface)) if axis is None: axis = self.prefered_creation_axis @@ -211,7 +213,7 @@ class Surface(Mobject): return points.reshape((nu * nv, *resolution[2:])) @Mobject.affects_data - def sort_faces_back_to_front(self, vect: Vect3 = OUT): + def sort_faces_back_to_front(self, vect: Vect3 = OUT) -> Self: tri_is = self.triangle_indices points = self.get_points() @@ -221,24 +223,25 @@ class Surface(Mobject): tri_is[k::3] = tri_is[k::3][indices] return self - def always_sort_to_camera(self, camera: Camera): + def always_sort_to_camera(self, camera: Camera) -> Self: def updater(surface: Surface): vect = camera.get_location() - surface.get_center() surface.sort_faces_back_to_front(vect) self.add_updater(updater) + return self def set_clip_plane( self, vect: Vect3 | None = None, threshold: float | None = None - ): + ) -> Self: if vect is not None: self.uniforms["clip_plane"][:3] = vect if threshold is not None: self.uniforms["clip_plane"][3] = threshold return self - def deactivate_clip_plane(self): + def deactivate_clip_plane(self) -> Self: self.uniforms["clip_plane"][:] = 0 return self @@ -335,7 +338,7 @@ class TexturedSurface(Surface): self.uniforms["num_textures"] = self.num_textures @Mobject.affects_data - def set_opacity(self, opacity: float | Iterable[float]): + def set_opacity(self, opacity: float | Iterable[float]) -> Self: op_arr = np.array(listify(opacity)) self.data["opacity"][:, 0] = resize_with_interpolation(op_arr, len(self.data)) return self @@ -345,7 +348,7 @@ class TexturedSurface(Surface): color: ManimColor | Iterable[ManimColor] | None, opacity: float | Iterable[float] | None = None, recurse: bool = True - ): + ) -> Self: if opacity is not None: self.set_opacity(opacity) return self @@ -356,7 +359,7 @@ class TexturedSurface(Surface): a: float, b: float, axis: int = 1 - ): + ) -> Self: super().pointwise_become_partial(tsmobject, a, b, axis) im_coords = self.data["im_coords"] im_coords[:] = tsmobject.data["im_coords"] diff --git a/manimlib/mobject/types/vectorized_mobject.py b/manimlib/mobject/types/vectorized_mobject.py index 27051f79..a4fd7452 100644 --- a/manimlib/mobject/types/vectorized_mobject.py +++ b/manimlib/mobject/types/vectorized_mobject.py @@ -45,7 +45,7 @@ from manimlib.shader_wrapper import FillShaderWrapper from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Callable, Iterable, Tuple + from typing import Callable, Iterable, Tuple, Any, Self from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, Vect4Array from moderngl.context import Context @@ -128,29 +128,10 @@ class VMobject(Mobject): self.uniforms["joint_type"] = JOINT_TYPE_MAP[self.joint_type] self.uniforms["flat_stroke"] = float(self.flat_stroke) - # These are here just to make type checkers happy - def get_family(self, recurse: bool = True) -> list[VMobject]: - return super().get_family(recurse) - - def family_members_with_points(self) -> list[VMobject]: - return super().family_members_with_points() - - def replicate(self, n: int) -> VGroup: - return super().replicate(n) - - def get_grid(self, *args, **kwargs) -> VGroup: - return super().get_grid(*args, **kwargs) - - def __getitem__(self, value: int | slice) -> VMobject: - return super().__getitem__(value) - - def __iter__(self) -> Iterable[VMobject]: - return super().__iter__() - - def add(self, *vmobjects: VMobject): + def add(self, *vmobjects: VMobject) -> Self: if not all((isinstance(m, VMobject) for m in vmobjects)): raise Exception("All submobjects must be of type VMobject") - super().add(*vmobjects) + return super().add(*vmobjects) # Colors def init_colors(self): @@ -175,7 +156,7 @@ class VMobject(Mobject): rgba_array: Vect4Array, name: str | None = None, recurse: bool = False - ): + ) -> Self: if name is None: names = ["fill_rgba", "stroke_rgba"] else: @@ -191,7 +172,7 @@ class VMobject(Mobject): opacity: float | Iterable[float] | None = None, border_width: float | None = None, recurse: bool = True - ): + ) -> Self: self.set_rgba_array_by_color(color, opacity, 'fill_rgba', recurse) if border_width is not None: for mob in self.get_family(recurse): @@ -205,7 +186,7 @@ class VMobject(Mobject): opacity: float | Iterable[float] | None = None, background: bool | None = None, recurse: bool = True - ): + ) -> Self: self.set_rgba_array_by_color(color, opacity, 'stroke_rgba', recurse) if width is not None: @@ -228,7 +209,7 @@ class VMobject(Mobject): color: ManimColor | Iterable[ManimColor] = BLACK, width: float | Iterable[float] = 3, background: bool = True - ): + ) -> Self: self.set_stroke(color, width, background=background) return self @@ -245,7 +226,7 @@ class VMobject(Mobject): stroke_background: bool = True, shading: Tuple[float, float, float] | None = None, recurse: bool = True - ): + ) -> Self: for mob in self.get_family(recurse): if fill_rgba is not None: mob.data['fill_rgba'][:] = resize_with_interpolation(fill_rgba, len(mob.data['fill_rgba'])) @@ -276,7 +257,7 @@ class VMobject(Mobject): mob.set_shading(*shading, recurse=False) return self - def get_style(self): + def get_style(self) -> dict[str, Any]: data = self.data if self.get_num_points() > 0 else self._data_defaults return { "fill_rgba": data['fill_rgba'].copy(), @@ -286,7 +267,7 @@ class VMobject(Mobject): "shading": self.get_shading(), } - def match_style(self, vmobject: VMobject, recurse: bool = True): + def match_style(self, vmobject: VMobject, recurse: bool = True) -> Self: self.set_style(**vmobject.get_style(), recurse=False) if recurse: # Does its best to match up submobject lists, and @@ -305,7 +286,7 @@ class VMobject(Mobject): color: ManimColor | Iterable[ManimColor] | None, opacity: float | Iterable[float] | None = None, recurse: bool = True - ): + ) -> Self: self.set_fill(color, opacity=opacity, recurse=recurse) self.set_stroke(color, opacity=opacity, recurse=recurse) return self @@ -314,16 +295,16 @@ class VMobject(Mobject): self, opacity: float | Iterable[float] | None, recurse: bool = True - ): + ) -> Self: self.set_fill(opacity=opacity, recurse=recurse) self.set_stroke(opacity=opacity, recurse=recurse) return self - def set_anti_alias_width(self, anti_alias_width: float, recurse: bool = True): + def set_anti_alias_width(self, anti_alias_width: float, recurse: bool = True) -> Self: self.set_uniform(recurse, anti_alias_width=anti_alias_width) return self - def fade(self, darkness: float = 0.5, recurse: bool = True): + def fade(self, darkness: float = 0.5, recurse: bool = True) -> Self: mobs = self.get_family() if recurse else [self] for mob in mobs: factor = 1.0 - darkness @@ -407,7 +388,7 @@ class VMobject(Mobject): return self.get_fill_opacity() return self.get_stroke_opacity() - def set_flat_stroke(self, flat_stroke: bool = True, recurse: bool = True): + def set_flat_stroke(self, flat_stroke: bool = True, recurse: bool = True) -> Self: for mob in self.get_family(recurse): mob.uniforms["flat_stroke"] = float(flat_stroke) return self @@ -415,7 +396,7 @@ class VMobject(Mobject): def get_flat_stroke(self) -> bool: return self.uniforms["flat_stroke"] == 1.0 - def set_joint_type(self, joint_type: str, recurse: bool = True): + def set_joint_type(self, joint_type: str, recurse: bool = True) -> Self: for mob in self.get_family(recurse): mob.uniforms["joint_type"] = JOINT_TYPE_MAP[joint_type] return self @@ -428,7 +409,7 @@ class VMobject(Mobject): anti_alias_width: float = 0, fill_border_width: float = 0, recurse: bool=True - ): + ) -> Self: super().apply_depth_test(recurse) self.set_anti_alias_width(anti_alias_width) self.set_fill(border_width=fill_border_width) @@ -439,14 +420,14 @@ class VMobject(Mobject): anti_alias_width: float = 1.0, fill_border_width: float = 0.5, recurse: bool=True - ): + ) -> Self: super().apply_depth_test(recurse) self.set_anti_alias_width(anti_alias_width) self.set_fill(border_width=fill_border_width) return self @Mobject.affects_family_data - def use_winding_fill(self, value: bool = True, recurse: bool = True): + def use_winding_fill(self, value: bool = True, recurse: bool = True) -> Self: for submob in self.get_family(recurse): submob._use_winding_fill = value if not value and submob.has_points(): @@ -458,7 +439,7 @@ class VMobject(Mobject): self, anchors: Vect3Array, handles: Vect3Array, - ): + ) -> Self: assert(len(anchors) == len(handles) + 1) points = resize_array(self.get_points(), 2 * len(anchors) - 1) points[0::2] = anchors @@ -466,7 +447,7 @@ class VMobject(Mobject): self.set_points(points) return self - def start_new_path(self, point: Vect3): + def start_new_path(self, point: Vect3) -> Self: # Path ends are signaled by a handle point sitting directly # on top of the previous anchor if self.has_points(): @@ -481,7 +462,7 @@ class VMobject(Mobject): handle1: Vect3, handle2: Vect3, anchor2: Vect3 - ): + ) -> Self: self.start_new_path(anchor1) self.add_cubic_bezier_curve_to(handle1, handle2, anchor2) return self @@ -491,7 +472,7 @@ class VMobject(Mobject): handle1: Vect3, handle2: Vect3, anchor: Vect3, - ): + ) -> Self: """ Add cubic bezier curve to the path. """ @@ -513,7 +494,7 @@ class VMobject(Mobject): self.append_points(quad_approx[1:]) return self - def add_quadratic_bezier_curve_to(self, handle: Vect3, anchor: Vect3): + def add_quadratic_bezier_curve_to(self, handle: Vect3, anchor: Vect3) -> Self: self.throw_error_if_no_points() last_point = self.get_last_point() if self.consider_points_equal(handle, last_point): @@ -522,14 +503,14 @@ class VMobject(Mobject): self.append_points([handle, anchor]) return self - def add_line_to(self, point: Vect3): + def add_line_to(self, point: Vect3) -> Self: self.throw_error_if_no_points() last_point = self.get_last_point() alphas = np.linspace(0, 1, 5 if self.long_lines else 3) self.append_points(outer_interpolate(last_point, point, alphas[1:])) return self - def add_smooth_curve_to(self, point: Vect3): + def add_smooth_curve_to(self, point: Vect3) -> Self: if self.has_new_path_started(): self.add_line_to(point) else: @@ -538,7 +519,7 @@ class VMobject(Mobject): self.add_quadratic_bezier_curve_to(new_handle, point) return self - def add_smooth_cubic_curve_to(self, handle: Vect3, point: Vect3): + def add_smooth_cubic_curve_to(self, handle: Vect3, point: Vect3) -> Self: self.throw_error_if_no_points() if self.get_num_points() == 1: new_handle = handle @@ -559,7 +540,7 @@ class VMobject(Mobject): points = self.get_points() return 2 * points[-1] - points[-2] - def close_path(self, smooth: bool = False): + def close_path(self, smooth: bool = False) -> Self: if self.is_closed(): return self last_path_start = self.get_subpaths()[-1][0] @@ -577,7 +558,7 @@ class VMobject(Mobject): self, tuple_to_subdivisions: Callable, recurse: bool = True - ): + ) -> Self: for vmob in self.get_family(recurse): if not vmob.has_points(): continue @@ -599,7 +580,7 @@ class VMobject(Mobject): self, angle_threshold: float = 30 * DEGREES, recurse: bool = True - ): + ) -> Self: def tuple_to_subdivisions(b0, b1, b2): angle = angle_between_vectors(b1 - b0, b2 - b1) return int(angle / angle_threshold) @@ -607,7 +588,7 @@ class VMobject(Mobject): self.subdivide_curves_by_condition(tuple_to_subdivisions, recurse) return self - def subdivide_intersections(self, recurse: bool = True, n_subdivisions: int = 1): + def subdivide_intersections(self, recurse: bool = True, n_subdivisions: int = 1) -> Self: path = self.get_anchors() def tuple_to_subdivisions(b0, b1, b2): if line_intersects_path(b0, b1, path): @@ -617,12 +598,12 @@ class VMobject(Mobject): self.subdivide_curves_by_condition(tuple_to_subdivisions, recurse) return self - def add_points_as_corners(self, points: Iterable[Vect3]): + def add_points_as_corners(self, points: Iterable[Vect3]) -> Self: for point in points: self.add_line_to(point) - return points + return self - def set_points_as_corners(self, points: Iterable[Vect3]): + def set_points_as_corners(self, points: Iterable[Vect3]) -> Self: anchors = np.array(points) handles = 0.5 * (anchors[:-1] + anchors[1:]) self.set_anchors_and_handles(anchors, handles) @@ -632,7 +613,7 @@ class VMobject(Mobject): self, points: Iterable[Vect3], approx: bool = True - ): + ) -> Self: self.set_points_as_corners(points) self.make_smooth(approx=approx) return self @@ -641,7 +622,7 @@ class VMobject(Mobject): dots = self.get_joint_products()[::2, 3] return bool((dots > 1 - 1e-3).all()) - def change_anchor_mode(self, mode: str): + def change_anchor_mode(self, mode: str) -> Self: assert(mode in ("jagged", "approx_smooth", "true_smooth")) subpaths = self.get_subpaths() self.clear_points() @@ -664,7 +645,7 @@ class VMobject(Mobject): self.add_subpath(new_subpath) return self - def make_smooth(self, approx=False, recurse=True): + def make_smooth(self, approx=False, recurse=True) -> Self: """ Edits the path so as to pass smoothly through all the current anchor points. @@ -679,15 +660,16 @@ class VMobject(Mobject): submob.change_anchor_mode(mode) return self - def make_approximately_smooth(self, recurse=True): + def make_approximately_smooth(self, recurse=True) -> Self: self.make_smooth(approx=True, recurse=recurse) + return self - def make_jagged(self, recurse=True): + def make_jagged(self, recurse=True) -> Self: for submob in self.get_family(recurse): submob.change_anchor_mode("jagged") return self - def add_subpath(self, points: Vect3Array): + def add_subpath(self, points: Vect3Array) -> Self: assert(len(points) % 2 == 1 or len(points) == 0) if not self.has_points(): self.set_points(points) @@ -697,7 +679,7 @@ class VMobject(Mobject): self.append_points(points[1:]) return self - def append_vectorized_mobject(self, vmobject: VMobject): + def append_vectorized_mobject(self, vmobject: VMobject) -> Self: self.add_subpath(vmobject.get_points()) n = vmobject.get_num_points() self.data[-n:] = vmobject.data @@ -715,7 +697,7 @@ class VMobject(Mobject): def get_bezier_tuples(self) -> Iterable[Vect3Array]: return self.get_bezier_tuples_from_points(self.get_points()) - def get_subpath_end_indices_from_points(self, points: Vect3Array): + def get_subpath_end_indices_from_points(self, points: Vect3Array) -> np.ndarray: atol = self.tolerance_for_point_equality a0, h, a1 = points[0:-1:2], points[1::2], points[2::2] # An anchor point is considered the end of a path @@ -731,7 +713,7 @@ class VMobject(Mobject): is_end[:-1] = is_end[:-1] & ~is_end[1:] return np.array([2 * n for n, end in enumerate(is_end) if end]) - def get_subpath_end_indices(self): + def get_subpath_end_indices(self) -> np.ndarray: return self.get_subpath_end_indices_from_points(self.get_points()) def get_subpaths_from_points(self, points: Vect3Array) -> list[Vect3Array]: @@ -860,7 +842,7 @@ class VMobject(Mobject): self.data["unit_normal"][:] = normal return normal - def refresh_unit_normal(self): + def refresh_unit_normal(self) -> Self: self.get_unit_normal() return self @@ -870,20 +852,20 @@ class VMobject(Mobject): axis: Vect3 = OUT, about_point: Vect3 | None = None, **kwargs - ): + ) -> Self: super().rotate(angle, axis, about_point, **kwargs) for mob in self.get_family(): mob.refresh_unit_normal() return self - def ensure_positive_orientation(self, recurse=True): + def ensure_positive_orientation(self, recurse=True) -> Self: for mob in self.get_family(recurse): if mob.get_unit_normal()[2] < 0: mob.reverse_points() return self # Alignment - def align_points(self, vmobject: VMobject): + def align_points(self, vmobject: VMobject) -> Self: winding = self._use_winding_fill and vmobject._use_winding_fill self.use_winding_fill(winding) vmobject.use_winding_fill(winding) @@ -940,7 +922,7 @@ class VMobject(Mobject): mob.get_joint_products() return self - def invisible_copy(self): + def invisible_copy(self) -> Self: result = self.copy() if not result.has_fill() or result.get_num_points() == 0: return result @@ -948,14 +930,14 @@ class VMobject(Mobject): result.set_opacity(0) return result - def insert_n_curves(self, n: int, recurse: bool = True): + def insert_n_curves(self, n: int, recurse: bool = True) -> Self: for mob in self.get_family(recurse): if mob.get_num_curves() > 0: new_points = mob.insert_n_curves_to_point_list(n, mob.get_points()) mob.set_points(new_points) return self - def insert_n_curves_to_point_list(self, n: int, points: Vect3Array): + def insert_n_curves_to_point_list(self, n: int, points: Vect3Array) -> Vect3Array: if len(points) == 1: return np.repeat(points, 2 * n + 1, 0) @@ -988,7 +970,7 @@ class VMobject(Mobject): mobject2: VMobject, alpha: float, *args, **kwargs - ): + ) -> Self: super().interpolate(mobject1, mobject2, alpha, *args, **kwargs) if self.has_fill() and not self._use_winding_fill: tri1 = mobject1.get_triangulation() @@ -997,7 +979,7 @@ class VMobject(Mobject): self.refresh_triangulation() return self - def pointwise_become_partial(self, vmobject: VMobject, a: float, b: float): + def pointwise_become_partial(self, vmobject: VMobject, a: float, b: float) -> Self: assert(isinstance(vmobject, VMobject)) vm_points = vmobject.get_points() self.data["joint_product"] = vmobject.data["joint_product"] @@ -1040,12 +1022,12 @@ class VMobject(Mobject): self.set_points(new_points, refresh_joints=False) return self - def get_subcurve(self, a: float, b: float) -> VMobject: + def get_subcurve(self, a: float, b: float) -> Self: vmob = self.copy() vmob.pointwise_become_partial(self, a, b) return vmob - def get_outer_vert_indices(self): + def get_outer_vert_indices(self) -> np.ndarray: """ Returns the pattern (0, 1, 2, 2, 3, 4, 4, 5, 6, ...) """ @@ -1056,12 +1038,12 @@ class VMobject(Mobject): # Data for shaders that may need refreshing - def refresh_triangulation(self): + def refresh_triangulation(self) -> Self: for mob in self.get_family(): mob.needs_new_triangulation = True return self - def get_triangulation(self): + def get_triangulation(self) -> np.ndarray: # Figure out how to triangulate the interior to know # how to send the points as to the vertex shader. # First triangles come directly from the points @@ -1118,12 +1100,12 @@ class VMobject(Mobject): self.needs_new_triangulation = False return tri_indices - def refresh_joint_products(self): + def refresh_joint_products(self) -> Self: for mob in self.get_family(): mob.needs_new_joint_products = True return self - def get_joint_products(self, refresh: bool = False): + def get_joint_products(self, refresh: bool = False) -> np.ndarray: """ The 'joint product' is a 4-vector holding the cross and dot product between tangent vectors at a joint @@ -1174,10 +1156,11 @@ class VMobject(Mobject): self.data["joint_product"][:, 3] = (vect_to_vert * vect_from_vert).sum(1) return self.data["joint_product"] - def lock_matching_data(self, vmobject1: VMobject, vmobject2: VMobject): + def lock_matching_data(self, vmobject1: VMobject, vmobject2: VMobject) -> Self: for mob in [self, vmobject1, vmobject2]: mob.get_joint_products() super().lock_matching_data(vmobject1, vmobject2) + return self def triggers_refreshed_triangulation(func: Callable): @wraps(func) @@ -1189,7 +1172,7 @@ class VMobject(Mobject): return self return wrapper - def set_points(self, points: Vect3Array, refresh_joints: bool = True): + def set_points(self, points: Vect3Array, refresh_joints: bool = True) -> Self: assert(len(points) == 0 or len(points) % 2 == 1) super().set_points(points) self.refresh_triangulation() @@ -1199,13 +1182,13 @@ class VMobject(Mobject): return self @triggers_refreshed_triangulation - def append_points(self, points: Vect3Array): + def append_points(self, points: Vect3Array) -> Self: assert(len(points) % 2 == 0) super().append_points(points) return self @triggers_refreshed_triangulation - def reverse_points(self, recurse: bool = True): + def reverse_points(self, recurse: bool = True) -> Self: # This will reset which anchors are # considered path ends for mob in self.get_family(recurse): @@ -1218,7 +1201,7 @@ class VMobject(Mobject): return self @triggers_refreshed_triangulation - def set_data(self, data: np.ndarray): + def set_data(self, data: np.ndarray) -> Self: super().set_data(data) return self @@ -1229,15 +1212,16 @@ class VMobject(Mobject): function: Callable[[Vect3], Vect3], make_smooth: bool = False, **kwargs - ): + ) -> Self: super().apply_function(function, **kwargs) if self.make_smooth_after_applying_functions or make_smooth: self.make_smooth(approx=True) return self - def apply_points_function(self, *args, **kwargs): + def apply_points_function(self, *args, **kwargs) -> Self: super().apply_points_function(*args, **kwargs) self.refresh_joint_products() + return self # For shaders def init_shader_data(self, ctx: Context): @@ -1272,7 +1256,7 @@ class VMobject(Mobject): self.stroke_shader_wrapper, ] - def refresh_shader_wrapper_id(self): + def refresh_shader_wrapper_id(self) -> Self: if not self._shaders_initialized: return self for wrapper in self.shader_wrappers: @@ -1339,7 +1323,7 @@ class VGroup(VMobject): super().__init__(**kwargs) self.add(*vmobjects) - def __add__(self, other: VMobject | VGroup): + def __add__(self, other: VMobject) -> Self: assert(isinstance(other, VMobject)) return self.add(other) diff --git a/manimlib/mobject/value_tracker.py b/manimlib/mobject/value_tracker.py index 35b39557..78b74a7f 100644 --- a/manimlib/mobject/value_tracker.py +++ b/manimlib/mobject/value_tracker.py @@ -1,7 +1,7 @@ from __future__ import annotations import numpy as np - +from typing import Self from manimlib.mobject.mobject import Mobject from manimlib.utils.iterables import listify @@ -36,7 +36,7 @@ class ValueTracker(Mobject): return result[0] return result - def set_value(self, value: float | complex | np.ndarray): + def set_value(self, value: float | complex | np.ndarray) -> Self: self.uniforms["value"][:] = value return self diff --git a/manimlib/mobject/vector_field.py b/manimlib/mobject/vector_field.py index 86c9a181..02463de2 100644 --- a/manimlib/mobject/vector_field.py +++ b/manimlib/mobject/vector_field.py @@ -22,7 +22,6 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Callable, Iterable, Sequence, TypeVar, Tuple - import numpy.typing as npt from manimlib.typing import ManimColor, Vect3, VectN, Vect3Array from manimlib.mobject.coordinate_systems import CoordinateSystem