diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index b2da38f3..f6b27744 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -344,7 +344,7 @@ class Mobject(object): # Family matters - def __getitem__(self, value: int | slice) -> Mobject: + def __getitem__(self, value: int | slice) -> Self: if isinstance(value, slice): GroupClass = self.get_group_class() return GroupClass(*self.split().__getitem__(value)) @@ -739,7 +739,7 @@ class Mobject(object): # Creating new Mobjects from this one - def replicate(self, n: int) -> Group: + def replicate(self, n: int) -> Self: group_class = self.get_group_class() return group_class(*(self.copy() for _ in range(n))) @@ -752,7 +752,7 @@ class Mobject(object): group_by_rows: bool = False, group_by_cols: bool = False, **kwargs - ) -> Group: + ) -> Self: """ Returns a new mobject containing multiple copies of this one arranged in a grid diff --git a/manimlib/mobject/types/vectorized_mobject.py b/manimlib/mobject/types/vectorized_mobject.py index 3b19b9dd..e6cf4f74 100644 --- a/manimlib/mobject/types/vectorized_mobject.py +++ b/manimlib/mobject/types/vectorized_mobject.py @@ -45,7 +45,7 @@ from manimlib.shader_wrapper import FillShaderWrapper from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Callable, Iterable, Tuple + from typing import Callable, Iterable, Tuple, Any, Self from manimlib.typing import ManimColor, Vect3, Vect4, Vect3Array, Vect4Array from moderngl.context import Context @@ -128,29 +128,10 @@ class VMobject(Mobject): self.uniforms["joint_type"] = JOINT_TYPE_MAP[self.joint_type] self.uniforms["flat_stroke"] = float(self.flat_stroke) - # These are here just to make type checkers happy - def get_family(self, recurse: bool = True) -> list[VMobject]: - return super().get_family(recurse) - - def family_members_with_points(self) -> list[VMobject]: - return super().family_members_with_points() - - def replicate(self, n: int) -> VGroup: - return super().replicate(n) - - def get_grid(self, *args, **kwargs) -> VGroup: - return super().get_grid(*args, **kwargs) - - def __getitem__(self, value: int | slice) -> VMobject: - return super().__getitem__(value) - - def __iter__(self) -> Iterable[VMobject]: - return super().__iter__() - - def add(self, *vmobjects: VMobject): + def add(self, *vmobjects: VMobject) -> Self: if not all((isinstance(m, VMobject) for m in vmobjects)): raise Exception("All submobjects must be of type VMobject") - super().add(*vmobjects) + return super().add(*vmobjects) # Colors def init_colors(self): @@ -175,7 +156,7 @@ class VMobject(Mobject): rgba_array: Vect4Array, name: str | None = None, recurse: bool = False - ): + ) -> Self: if name is None: names = ["fill_rgba", "stroke_rgba"] else: @@ -191,7 +172,7 @@ class VMobject(Mobject): opacity: float | Iterable[float] | None = None, border_width: float | None = None, recurse: bool = True - ): + ) -> Self: self.set_rgba_array_by_color(color, opacity, 'fill_rgba', recurse) if border_width is not None: for mob in self.get_family(recurse): @@ -205,7 +186,7 @@ class VMobject(Mobject): opacity: float | Iterable[float] | None = None, background: bool | None = None, recurse: bool = True - ): + ) -> Self: self.set_rgba_array_by_color(color, opacity, 'stroke_rgba', recurse) if width is not None: @@ -228,7 +209,7 @@ class VMobject(Mobject): color: ManimColor | Iterable[ManimColor] = BLACK, width: float | Iterable[float] = 3, background: bool = True - ): + ) -> Self: self.set_stroke(color, width, background=background) return self @@ -245,7 +226,7 @@ class VMobject(Mobject): stroke_background: bool = True, shading: Tuple[float, float, float] | None = None, recurse: bool = True - ): + ) -> Self: for mob in self.get_family(recurse): if fill_rgba is not None: mob.data['fill_rgba'][:] = resize_with_interpolation(fill_rgba, len(mob.data['fill_rgba'])) @@ -276,7 +257,7 @@ class VMobject(Mobject): mob.set_shading(*shading, recurse=False) return self - def get_style(self): + def get_style(self) -> dict[str, Any]: data = self.data if self.get_num_points() > 0 else self._data_defaults return { "fill_rgba": data['fill_rgba'].copy(), @@ -286,7 +267,7 @@ class VMobject(Mobject): "shading": self.get_shading(), } - def match_style(self, vmobject: VMobject, recurse: bool = True): + def match_style(self, vmobject: VMobject, recurse: bool = True) -> Self: self.set_style(**vmobject.get_style(), recurse=False) if recurse: # Does its best to match up submobject lists, and @@ -305,7 +286,7 @@ class VMobject(Mobject): color: ManimColor | Iterable[ManimColor] | None, opacity: float | Iterable[float] | None = None, recurse: bool = True - ): + ) -> Self: self.set_fill(color, opacity=opacity, recurse=recurse) self.set_stroke(color, opacity=opacity, recurse=recurse) return self @@ -314,16 +295,16 @@ class VMobject(Mobject): self, opacity: float | Iterable[float] | None, recurse: bool = True - ): + ) -> Self: self.set_fill(opacity=opacity, recurse=recurse) self.set_stroke(opacity=opacity, recurse=recurse) return self - def set_anti_alias_width(self, anti_alias_width: float, recurse: bool = True): + def set_anti_alias_width(self, anti_alias_width: float, recurse: bool = True) -> Self: self.set_uniform(recurse, anti_alias_width=anti_alias_width) return self - def fade(self, darkness: float = 0.5, recurse: bool = True): + def fade(self, darkness: float = 0.5, recurse: bool = True) -> Self: mobs = self.get_family() if recurse else [self] for mob in mobs: factor = 1.0 - darkness @@ -407,7 +388,7 @@ class VMobject(Mobject): return self.get_fill_opacity() return self.get_stroke_opacity() - def set_flat_stroke(self, flat_stroke: bool = True, recurse: bool = True): + def set_flat_stroke(self, flat_stroke: bool = True, recurse: bool = True) -> Self: for mob in self.get_family(recurse): mob.uniforms["flat_stroke"] = float(flat_stroke) return self @@ -415,7 +396,7 @@ class VMobject(Mobject): def get_flat_stroke(self) -> bool: return self.uniforms["flat_stroke"] == 1.0 - def set_joint_type(self, joint_type: str, recurse: bool = True): + def set_joint_type(self, joint_type: str, recurse: bool = True) -> Self: for mob in self.get_family(recurse): mob.uniforms["joint_type"] = JOINT_TYPE_MAP[joint_type] return self @@ -428,7 +409,7 @@ class VMobject(Mobject): anti_alias_width: float = 0, fill_border_width: float = 0, recurse: bool=True - ): + ) -> Self: super().apply_depth_test(recurse) self.set_anti_alias_width(anti_alias_width) self.set_fill(border_width=fill_border_width) @@ -439,14 +420,14 @@ class VMobject(Mobject): anti_alias_width: float = 1.0, fill_border_width: float = 0.5, recurse: bool=True - ): + ) -> Self: super().apply_depth_test(recurse) self.set_anti_alias_width(anti_alias_width) self.set_fill(border_width=fill_border_width) return self @Mobject.affects_family_data - def use_winding_fill(self, value: bool = True, recurse: bool = True): + def use_winding_fill(self, value: bool = True, recurse: bool = True) -> Self: for submob in self.get_family(recurse): submob._use_winding_fill = value if not value and submob.has_points(): @@ -458,7 +439,7 @@ class VMobject(Mobject): self, anchors: Vect3Array, handles: Vect3Array, - ): + ) -> Self: assert(len(anchors) == len(handles) + 1) points = resize_array(self.get_points(), 2 * len(anchors) - 1) points[0::2] = anchors @@ -466,7 +447,7 @@ class VMobject(Mobject): self.set_points(points) return self - def start_new_path(self, point: Vect3): + def start_new_path(self, point: Vect3) -> Self: # Path ends are signaled by a handle point sitting directly # on top of the previous anchor if self.has_points(): @@ -481,7 +462,7 @@ class VMobject(Mobject): handle1: Vect3, handle2: Vect3, anchor2: Vect3 - ): + ) -> Self: self.start_new_path(anchor1) self.add_cubic_bezier_curve_to(handle1, handle2, anchor2) return self @@ -491,7 +472,7 @@ class VMobject(Mobject): handle1: Vect3, handle2: Vect3, anchor: Vect3, - ): + ) -> Self: """ Add cubic bezier curve to the path. """ @@ -513,7 +494,7 @@ class VMobject(Mobject): self.append_points(quad_approx[1:]) return self - def add_quadratic_bezier_curve_to(self, handle: Vect3, anchor: Vect3): + def add_quadratic_bezier_curve_to(self, handle: Vect3, anchor: Vect3) -> Self: self.throw_error_if_no_points() last_point = self.get_last_point() if self.consider_points_equal(handle, last_point): @@ -522,14 +503,14 @@ class VMobject(Mobject): self.append_points([handle, anchor]) return self - def add_line_to(self, point: Vect3): + def add_line_to(self, point: Vect3) -> Self: self.throw_error_if_no_points() last_point = self.get_last_point() alphas = np.linspace(0, 1, 5 if self.long_lines else 3) self.append_points(outer_interpolate(last_point, point, alphas[1:])) return self - def add_smooth_curve_to(self, point: Vect3): + def add_smooth_curve_to(self, point: Vect3) -> Self: if self.has_new_path_started(): self.add_line_to(point) else: @@ -538,7 +519,7 @@ class VMobject(Mobject): self.add_quadratic_bezier_curve_to(new_handle, point) return self - def add_smooth_cubic_curve_to(self, handle: Vect3, point: Vect3): + def add_smooth_cubic_curve_to(self, handle: Vect3, point: Vect3) -> Self: self.throw_error_if_no_points() if self.get_num_points() == 1: new_handle = handle @@ -559,7 +540,7 @@ class VMobject(Mobject): points = self.get_points() return 2 * points[-1] - points[-2] - def close_path(self, smooth: bool = False): + def close_path(self, smooth: bool = False) -> Self: if self.is_closed(): return self last_path_start = self.get_subpaths()[-1][0] @@ -577,7 +558,7 @@ class VMobject(Mobject): self, tuple_to_subdivisions: Callable, recurse: bool = True - ): + ) -> Self: for vmob in self.get_family(recurse): if not vmob.has_points(): continue @@ -599,7 +580,7 @@ class VMobject(Mobject): self, angle_threshold: float = 30 * DEGREES, recurse: bool = True - ): + ) -> Self: def tuple_to_subdivisions(b0, b1, b2): angle = angle_between_vectors(b1 - b0, b2 - b1) return int(angle / angle_threshold) @@ -607,7 +588,7 @@ class VMobject(Mobject): self.subdivide_curves_by_condition(tuple_to_subdivisions, recurse) return self - def subdivide_intersections(self, recurse: bool = True, n_subdivisions: int = 1): + def subdivide_intersections(self, recurse: bool = True, n_subdivisions: int = 1) -> Self: path = self.get_anchors() def tuple_to_subdivisions(b0, b1, b2): if line_intersects_path(b0, b1, path): @@ -617,12 +598,12 @@ class VMobject(Mobject): self.subdivide_curves_by_condition(tuple_to_subdivisions, recurse) return self - def add_points_as_corners(self, points: Iterable[Vect3]): + def add_points_as_corners(self, points: Iterable[Vect3]) -> Self: for point in points: self.add_line_to(point) - return points + return self - def set_points_as_corners(self, points: Iterable[Vect3]): + def set_points_as_corners(self, points: Iterable[Vect3]) -> Self: anchors = np.array(points) handles = 0.5 * (anchors[:-1] + anchors[1:]) self.set_anchors_and_handles(anchors, handles) @@ -632,7 +613,7 @@ class VMobject(Mobject): self, points: Iterable[Vect3], approx: bool = True - ): + ) -> Self: self.set_points_as_corners(points) self.make_smooth(approx=approx) return self @@ -641,7 +622,7 @@ class VMobject(Mobject): dots = self.get_joint_products()[::2, 3] return bool((dots > 1 - 1e-3).all()) - def change_anchor_mode(self, mode: str): + def change_anchor_mode(self, mode: str) -> Self: assert(mode in ("jagged", "approx_smooth", "true_smooth")) subpaths = self.get_subpaths() self.clear_points() @@ -664,7 +645,7 @@ class VMobject(Mobject): self.add_subpath(new_subpath) return self - def make_smooth(self, approx=False, recurse=True): + def make_smooth(self, approx=False, recurse=True) -> Self: """ Edits the path so as to pass smoothly through all the current anchor points. @@ -679,15 +660,16 @@ class VMobject(Mobject): submob.change_anchor_mode(mode) return self - def make_approximately_smooth(self, recurse=True): + def make_approximately_smooth(self, recurse=True) -> Self: self.make_smooth(approx=True, recurse=recurse) + return self - def make_jagged(self, recurse=True): + def make_jagged(self, recurse=True) -> Self: for submob in self.get_family(recurse): submob.change_anchor_mode("jagged") return self - def add_subpath(self, points: Vect3Array): + def add_subpath(self, points: Vect3Array) -> Self: assert(len(points) % 2 == 1 or len(points) == 0) if not self.has_points(): self.set_points(points) @@ -697,7 +679,7 @@ class VMobject(Mobject): self.append_points(points[1:]) return self - def append_vectorized_mobject(self, vmobject: VMobject): + def append_vectorized_mobject(self, vmobject: VMobject) -> Self: self.add_subpath(vmobject.get_points()) n = vmobject.get_num_points() self.data[-n:] = vmobject.data @@ -715,7 +697,7 @@ class VMobject(Mobject): def get_bezier_tuples(self) -> Iterable[Vect3Array]: return self.get_bezier_tuples_from_points(self.get_points()) - def get_subpath_end_indices_from_points(self, points: Vect3Array): + def get_subpath_end_indices_from_points(self, points: Vect3Array) -> np.ndarray: atol = self.tolerance_for_point_equality a0, h, a1 = points[0:-1:2], points[1::2], points[2::2] # An anchor point is considered the end of a path @@ -731,7 +713,7 @@ class VMobject(Mobject): is_end[:-1] = is_end[:-1] & ~is_end[1:] return np.array([2 * n for n, end in enumerate(is_end) if end]) - def get_subpath_end_indices(self): + def get_subpath_end_indices(self) -> np.ndarray: return self.get_subpath_end_indices_from_points(self.get_points()) def get_subpaths_from_points(self, points: Vect3Array) -> list[Vect3Array]: @@ -860,7 +842,7 @@ class VMobject(Mobject): self.data["unit_normal"][:] = normal return normal - def refresh_unit_normal(self): + def refresh_unit_normal(self) -> Self: self.get_unit_normal() return self @@ -870,20 +852,20 @@ class VMobject(Mobject): axis: Vect3 = OUT, about_point: Vect3 | None = None, **kwargs - ): + ) -> Self: super().rotate(angle, axis, about_point, **kwargs) for mob in self.get_family(): mob.refresh_unit_normal() return self - def ensure_positive_orientation(self, recurse=True): + def ensure_positive_orientation(self, recurse=True) -> Self: for mob in self.get_family(recurse): if mob.get_unit_normal()[2] < 0: mob.reverse_points() return self # Alignment - def align_points(self, vmobject: VMobject): + def align_points(self, vmobject: VMobject) -> Self: winding = self._use_winding_fill and vmobject._use_winding_fill self.use_winding_fill(winding) vmobject.use_winding_fill(winding) @@ -940,7 +922,7 @@ class VMobject(Mobject): mob.get_joint_products() return self - def invisible_copy(self): + def invisible_copy(self) -> Self: result = self.copy() if not result.has_fill() or result.get_num_points() == 0: return result @@ -948,14 +930,14 @@ class VMobject(Mobject): result.set_opacity(0) return result - def insert_n_curves(self, n: int, recurse: bool = True): + def insert_n_curves(self, n: int, recurse: bool = True) -> Self: for mob in self.get_family(recurse): if mob.get_num_curves() > 0: new_points = mob.insert_n_curves_to_point_list(n, mob.get_points()) mob.set_points(new_points) return self - def insert_n_curves_to_point_list(self, n: int, points: Vect3Array): + def insert_n_curves_to_point_list(self, n: int, points: Vect3Array) -> Vect3Array: if len(points) == 1: return np.repeat(points, 2 * n + 1, 0) @@ -988,7 +970,7 @@ class VMobject(Mobject): mobject2: VMobject, alpha: float, *args, **kwargs - ): + ) -> Self: super().interpolate(mobject1, mobject2, alpha, *args, **kwargs) if self.has_fill() and not self._use_winding_fill: tri1 = mobject1.get_triangulation() @@ -997,7 +979,7 @@ class VMobject(Mobject): self.refresh_triangulation() return self - def pointwise_become_partial(self, vmobject: VMobject, a: float, b: float): + def pointwise_become_partial(self, vmobject: VMobject, a: float, b: float) -> Self: assert(isinstance(vmobject, VMobject)) vm_points = vmobject.get_points() self.data["joint_product"] = vmobject.data["joint_product"] @@ -1040,12 +1022,12 @@ class VMobject(Mobject): self.set_points(new_points, refresh_joints=False) return self - def get_subcurve(self, a: float, b: float) -> VMobject: + def get_subcurve(self, a: float, b: float) -> Self: vmob = self.copy() vmob.pointwise_become_partial(self, a, b) return vmob - def get_outer_vert_indices(self): + def get_outer_vert_indices(self) -> np.ndarray: """ Returns the pattern (0, 1, 2, 2, 3, 4, 4, 5, 6, ...) """ @@ -1056,12 +1038,12 @@ class VMobject(Mobject): # Data for shaders that may need refreshing - def refresh_triangulation(self): + def refresh_triangulation(self) -> Self: for mob in self.get_family(): mob.needs_new_triangulation = True return self - def get_triangulation(self): + def get_triangulation(self) -> np.ndarray: # Figure out how to triangulate the interior to know # how to send the points as to the vertex shader. # First triangles come directly from the points @@ -1118,12 +1100,12 @@ class VMobject(Mobject): self.needs_new_triangulation = False return tri_indices - def refresh_joint_products(self): + def refresh_joint_products(self) -> Self: for mob in self.get_family(): mob.needs_new_joint_products = True return self - def get_joint_products(self, refresh: bool = False): + def get_joint_products(self, refresh: bool = False) -> np.ndarray: """ The 'joint product' is a 4-vector holding the cross and dot product between tangent vectors at a joint @@ -1174,10 +1156,11 @@ class VMobject(Mobject): self.data["joint_product"][:, 3] = (vect_to_vert * vect_from_vert).sum(1) return self.data["joint_product"] - def lock_matching_data(self, vmobject1: VMobject, vmobject2: VMobject): + def lock_matching_data(self, vmobject1: VMobject, vmobject2: VMobject) -> Self: for mob in [self, vmobject1, vmobject2]: mob.get_joint_products() super().lock_matching_data(vmobject1, vmobject2) + return self def triggers_refreshed_triangulation(func: Callable): @wraps(func) @@ -1189,7 +1172,7 @@ class VMobject(Mobject): return self return wrapper - def set_points(self, points: Vect3Array, refresh_joints: bool = True): + def set_points(self, points: Vect3Array, refresh_joints: bool = True) -> Self: assert(len(points) == 0 or len(points) % 2 == 1) super().set_points(points) self.refresh_triangulation() @@ -1199,13 +1182,13 @@ class VMobject(Mobject): return self @triggers_refreshed_triangulation - def append_points(self, points: Vect3Array): + def append_points(self, points: Vect3Array) -> Self: assert(len(points) % 2 == 0) super().append_points(points) return self @triggers_refreshed_triangulation - def reverse_points(self, recurse: bool = True): + def reverse_points(self, recurse: bool = True) -> Self: # This will reset which anchors are # considered path ends for mob in self.get_family(recurse): @@ -1218,7 +1201,7 @@ class VMobject(Mobject): return self @triggers_refreshed_triangulation - def set_data(self, data: np.ndarray): + def set_data(self, data: np.ndarray) -> Self: super().set_data(data) return self @@ -1229,15 +1212,16 @@ class VMobject(Mobject): function: Callable[[Vect3], Vect3], make_smooth: bool = False, **kwargs - ): + ) -> Self: super().apply_function(function, **kwargs) if self.make_smooth_after_applying_functions or make_smooth: self.make_smooth(approx=True) return self - def apply_points_function(self, *args, **kwargs): + def apply_points_function(self, *args, **kwargs) -> Self: super().apply_points_function(*args, **kwargs) self.refresh_joint_products() + return self # For shaders def init_shader_data(self, ctx: Context): @@ -1272,7 +1256,7 @@ class VMobject(Mobject): self.stroke_shader_wrapper, ] - def refresh_shader_wrapper_id(self): + def refresh_shader_wrapper_id(self) -> Self: if not self._shaders_initialized: return self for wrapper in self.shader_wrappers: @@ -1343,7 +1327,7 @@ class VGroup(VMobject): super().__init__(**kwargs) self.add(*vmobjects) - def __add__(self, other: VMobject | VGroup): + def __add__(self, other: VMobject) -> Self: assert(isinstance(other, VMobject)) return self.add(other)