diff --git a/manimlib/mobject/mobject.py b/manimlib/mobject/mobject.py index c49a65b7..a1426ba7 100644 --- a/manimlib/mobject/mobject.py +++ b/manimlib/mobject/mobject.py @@ -1,12 +1,16 @@ -import copy -import itertools as it -import random -import sys -import moderngl -from functools import wraps -from collections.abc import Iterable +from __future__ import annotations +import sys +import copy +import random +import itertools as it +from functools import wraps +from typing import Iterable, TypeVar, Callable, Union, Sequence + +import colour +import moderngl import numpy as np +import numpy.typing as npt from manimlib.constants import * from manimlib.utils.color import color_gradient @@ -35,6 +39,13 @@ from manimlib.event_handler.event_listner import EventListner from manimlib.event_handler.event_type import EventType +Self = TypeVar("Self", bound="Mobject") +TimeBasedUpdater = Callable[["Mobject", float], None] +NonTimeUpdater = Callable[["Mobject"], None] +Updater = Union[TimeBasedUpdater, NonTimeUpdater] +Color = Union[str, colour.Color, Sequence[float]] + + class Mobject(object): """ Mathematical Object @@ -66,11 +77,11 @@ class Mobject(object): def __init__(self, **kwargs): digest_config(self, kwargs) - self.submobjects = [] - self.parents = [] - self.family = [self] - self.locked_data_keys = set() - self.needs_new_bounding_box = True + self.submobjects: list["Mobject"] = [] + self.parents: list["Mobject"] = [] + self.family: list["Mobject"] = [self] + self.locked_data_keys: set[str] = set() + self.needs_new_bounding_box: bool = True self.init_data() self.init_uniforms() @@ -86,23 +97,23 @@ class Mobject(object): def __str__(self): return self.__class__.__name__ - def __add__(self, other: 'Mobject') -> 'Mobject': + def __add__(self, other: "Mobject") -> "Mobject": assert(isinstance(other, Mobject)) return self.get_group_class()(self, other) - def __mul__(self, other: 'int') -> 'Mobject': + def __mul__(self, other: int) -> "Mobject": assert(isinstance(other, int)) return self.replicate(other) def init_data(self): - self.data = { + self.data: dict[str, np.ndarray] = { "points": np.zeros((0, 3)), "bounding_box": np.zeros((3, 3)), "rgbas": np.zeros((1, 4)), } def init_uniforms(self): - self.uniforms = { + self.uniforms: dict[str, float] = { "is_fixed_in_frame": float(self.is_fixed_in_frame), "gloss": self.gloss, "shadow": self.shadow, @@ -116,12 +127,12 @@ class Mobject(object): # Typically implemented in subclass, unlpess purposefully left blank pass - def set_data(self, data): + def set_data(self, data: dict): for key in data: self.data[key] = data[key].copy() return self - def set_uniforms(self, uniforms): + def set_uniforms(self, uniforms: dict): for key in uniforms: self.uniforms[key] = uniforms[key] # Copy? return self @@ -133,13 +144,17 @@ class Mobject(object): # Only these methods should directly affect points - def resize_points(self, new_length, resize_func=resize_array): + def resize_points( + self, + new_length: int, + resize_func: Callable[[np.ndarray, int], np.ndarray] = resize_array + ): if new_length != len(self.data["points"]): self.data["points"] = resize_func(self.data["points"], new_length) self.refresh_bounding_box() return self - def set_points(self, points): + def set_points(self, points: npt.ArrayLike): if len(points) == len(self.data["points"]): self.data["points"][:] = points elif isinstance(points, np.ndarray): @@ -149,7 +164,7 @@ class Mobject(object): self.refresh_bounding_box() return self - def append_points(self, new_points): + def append_points(self, new_points: npt.ArrayLike): self.data["points"] = np.vstack([self.data["points"], new_points]) self.refresh_bounding_box() return self @@ -161,7 +176,13 @@ class Mobject(object): self.refresh_unit_normal() return self - def apply_points_function(self, func, about_point=None, about_edge=ORIGIN, works_on_bounding_box=False): + def apply_points_function( + self, + func: Callable[[np.ndarray], np.ndarray], + about_point: np.ndarray = None, + about_edge: np.ndarray = ORIGIN, + works_on_bounding_box: bool = False + ): if about_point is None and about_edge is not None: about_point = self.get_bounding_box_point(about_edge) @@ -187,35 +208,35 @@ class Mobject(object): # Others related to points - def match_points(self, mobject): + def match_points(self, mobject: "Mobject"): self.set_points(mobject.get_points()) return self - def get_points(self): + def get_points(self) -> np.ndarray: return self.data["points"] - def clear_points(self): + def clear_points(self) -> None: self.resize_points(0) - def get_num_points(self): + def get_num_points(self) -> int: return len(self.data["points"]) - def get_all_points(self): + def get_all_points(self) -> np.ndarray: if self.submobjects: return np.vstack([sm.get_points() for sm in self.get_family()]) else: return self.get_points() - def has_points(self): + def has_points(self) -> bool: return self.get_num_points() > 0 - def get_bounding_box(self): + def get_bounding_box(self) -> np.ndarray: if self.needs_new_bounding_box: self.data["bounding_box"] = self.compute_bounding_box() self.needs_new_bounding_box = False return self.data["bounding_box"] - def compute_bounding_box(self): + def compute_bounding_box(self) -> np.ndarray: all_points = np.vstack([ self.get_points(), *( @@ -233,7 +254,11 @@ class Mobject(object): mids = (mins + maxs) / 2 return np.array([mins, mids, maxs]) - def refresh_bounding_box(self, recurse_down=False, recurse_up=True): + def refresh_bounding_box( + self, + recurse_down: bool = False, + recurse_up: bool = True + ): for mob in self.get_family(recurse_down): mob.needs_new_bounding_box = True if recurse_up: @@ -241,7 +266,11 @@ class Mobject(object): parent.refresh_bounding_box() return self - def is_point_touching(self, point, buff=MED_SMALL_BUFF): + def is_point_touching( + self, + point: np.ndarray, + buff: float = MED_SMALL_BUFF + ) -> bool: bb = self.get_bounding_box() mins = (bb[0] - buff) maxs = (bb[2] + buff) @@ -273,7 +302,7 @@ class Mobject(object): parent.assemble_family() return self - def get_family(self, recurse=True): + def get_family(self, recurse: bool = True): if recurse: return self.family else: @@ -282,7 +311,7 @@ class Mobject(object): def family_members_with_points(self): return [m for m in self.get_family() if m.has_points()] - def add(self, *mobjects): + def add(self, *mobjects: "Mobject"): if self in mobjects: raise Exception("Mobject cannot contain self") for mobject in mobjects: @@ -293,7 +322,7 @@ class Mobject(object): self.assemble_family() return self - def remove(self, *mobjects): + def remove(self, *mobjects: "Mobject"): for mobject in mobjects: if mobject in self.submobjects: self.submobjects.remove(mobject) @@ -302,11 +331,11 @@ class Mobject(object): self.assemble_family() return self - def add_to_back(self, *mobjects): + def add_to_back(self, *mobjects: "Mobject"): self.set_submobjects(list_update(mobjects, self.submobjects)) return self - def replace_submobject(self, index, new_submob): + def replace_submobject(self, index: int, new_submob: "Mobject"): old_submob = self.submobjects[index] if self in old_submob.parents: old_submob.parents.remove(self) @@ -314,12 +343,12 @@ class Mobject(object): self.assemble_family() return self - def insert_submobject(self, index, new_submob): + def insert_submobject(self, index: int, new_submob: "Mobject"): self.submobjects.insert(index, new_submob) self.assemble_family() return self - def set_submobjects(self, submobject_list): + def set_submobjects(self, submobject_list: list["Mobject"]): self.remove(*self.submobjects) self.add(*submobject_list) return self @@ -335,22 +364,31 @@ class Mobject(object): # Submobject organization - def arrange(self, direction=RIGHT, center=True, **kwargs): + def arrange( + self, + direction: np.ndarray = RIGHT, + center: bool = True, + **kwargs + ): for m1, m2 in zip(self.submobjects, self.submobjects[1:]): m2.next_to(m1, direction, **kwargs) if center: self.center() return self - def arrange_in_grid(self, n_rows=None, n_cols=None, - buff=None, - h_buff=None, - v_buff=None, - buff_ratio=None, - h_buff_ratio=0.5, - v_buff_ratio=0.5, - aligned_edge=ORIGIN, - fill_rows_first=True): + def arrange_in_grid( + self, + n_rows: int | None = None, + n_cols: int | None = None, + buff: float | None = None, + h_buff: float | None = None, + v_buff: float | None = None, + buff_ratio: float | None = None, + h_buff_ratio: float =0.5, + v_buff_ratio: float = 0.5, + aligned_edge: np.ndarray = ORIGIN, + fill_rows_first: bool = True + ): submobs = self.submobjects if n_rows is None and n_cols is None: n_rows = int(np.sqrt(len(submobs))) @@ -384,12 +422,12 @@ class Mobject(object): self.center() return self - def replicate(self, n): + def replicate(self, n: int) -> Group: return self.get_group_class()( *(self.copy() for x in range(n)) ) - def get_grid(self, n_rows, n_cols, height=None, **kwargs): + def get_grid(self, n_rows: int, n_cols: int, height: float | None = None, **kwargs): """ Returns a new mobject containing multiple copies of this one arranged in a grid @@ -400,7 +438,11 @@ class Mobject(object): grid.set_height(height) return grid - def sort(self, point_to_num_func=lambda p: p[0], submob_func=None): + def sort( + self, + point_to_num_func: Callable[[np.ndarray], float] = lambda p: p[0], + submob_func: Callable[["Mobject"]] | None = None + ): if submob_func is not None: self.submobjects.sort(key=submob_func) else: @@ -408,7 +450,7 @@ class Mobject(object): self.assemble_family() return self - def shuffle(self, recurse=False): + def shuffle(self, recurse: bool = False): if recurse: for submob in self.submobjects: submob.shuffle(recurse=True) @@ -461,7 +503,7 @@ class Mobject(object): self.parents = parents return result - def generate_target(self, use_deepcopy=False): + def generate_target(self, use_deepcopy: bool = False): self.target = None # Prevent exponential explosion if use_deepcopy: self.target = self.deepcopy() @@ -469,7 +511,7 @@ class Mobject(object): self.target = self.copy() return self.target - def save_state(self, use_deepcopy=False): + def save_state(self, use_deepcopy: bool = False): if hasattr(self, "saved_state"): # Prevent exponential growth of data self.saved_state = None @@ -488,12 +530,12 @@ class Mobject(object): # Updating def init_updaters(self): - self.time_based_updaters = [] - self.non_time_updaters = [] - self.has_updaters = False - self.updating_suspended = False + self.time_based_updaters: list[TimeBasedUpdater] = [] + self.non_time_updaters: list[NonTimeUpdater] = [] + self.has_updaters: bool = False + self.updating_suspended: bool = False - def update(self, dt=0, recurse=True): + def update(self, dt: float = 0, recurse: bool = True): if not self.has_updaters or self.updating_suspended: return self for updater in self.time_based_updaters: @@ -505,19 +547,24 @@ class Mobject(object): submob.update(dt, recurse) return self - def get_time_based_updaters(self): + def get_time_based_updaters(self) -> list[TimeBasedUpdater]: return self.time_based_updaters - def has_time_based_updater(self): + def has_time_based_updater(self) -> bool: return len(self.time_based_updaters) > 0 - def get_updaters(self): + def get_updaters(self) -> list[Updater]: return self.time_based_updaters + self.non_time_updaters - def get_family_updaters(self): + def get_family_updaters(self) -> list[Updater]: return list(it.chain(*[sm.get_updaters() for sm in self.get_family()])) - def add_updater(self, update_function, index=None, call_updater=True): + def add_updater( + self, + update_function: Updater, + index: int | None = None, + call_updater: bool = True + ): if "dt" in get_parameters(update_function): updater_list = self.time_based_updaters else: @@ -533,14 +580,14 @@ class Mobject(object): self.update(dt=0) return self - def remove_updater(self, update_function): + def remove_updater(self, update_function: Updater): for updater_list in [self.time_based_updaters, self.non_time_updaters]: while update_function in updater_list: updater_list.remove(update_function) self.refresh_has_updater_status() return self - def clear_updaters(self, recurse=True): + def clear_updaters(self, recurse: bool = True): self.time_based_updaters = [] self.non_time_updaters = [] self.refresh_has_updater_status() @@ -549,20 +596,20 @@ class Mobject(object): submob.clear_updaters() return self - def match_updaters(self, mobject): + def match_updaters(self, mobject: "Mobject"): self.clear_updaters() for updater in mobject.get_updaters(): self.add_updater(updater) return self - def suspend_updating(self, recurse=True): + def suspend_updating(self, recurse: bool = True): self.updating_suspended = True if recurse: for submob in self.submobjects: submob.suspend_updating(recurse) return self - def resume_updating(self, recurse=True, call_updater=True): + def resume_updating(self, recurse: bool = True, call_updater: bool = True): self.updating_suspended = False if recurse: for submob in self.submobjects: @@ -579,7 +626,7 @@ class Mobject(object): # Transforming operations - def shift(self, vector): + def shift(self, vector: np.ndarray): self.apply_points_function( lambda points: points + vector, about_edge=None, @@ -587,7 +634,13 @@ class Mobject(object): ) return self - def scale(self, scale_factor, min_scale_factor=1e-8, about_point=None, about_edge=ORIGIN): + def scale( + self, + scale_factor: float | npt.ArrayLike, + min_scale_factor: float = 1e-8, + about_point: np.ndarray | None = None, + about_edge: np.ndarray = ORIGIN + ): """ Default behavior is to scale about the center of the mobject. The argument about_edge can be a vector, indicating which side of @@ -597,7 +650,7 @@ class Mobject(object): Otherwise, if about_point is given a value, scaling is done with respect to that point. """ - if isinstance(scale_factor, Iterable): + if isinstance(scale_factor, npt.ArrayLike): scale_factor = np.array(scale_factor).clip(min=min_scale_factor) else: scale_factor = max(scale_factor, min_scale_factor) @@ -616,28 +669,35 @@ class Mobject(object): # any other changes when the size gets altered pass - def stretch(self, factor, dim, **kwargs): + def stretch(self, factor: float, dim: int, **kwargs): def func(points): points[:, dim] *= factor return points self.apply_points_function(func, works_on_bounding_box=True, **kwargs) return self - def rotate_about_origin(self, angle, axis=OUT): + def rotate_about_origin(self, angle: float, axis: np.ndarray = OUT): return self.rotate(angle, axis, about_point=ORIGIN) - def rotate(self, angle, axis=OUT, **kwargs): + def rotate( + self, + angle: float, + axis: np.ndarray = OUT, + about_point: np.ndarray | None = None, + **kwargs + ): rot_matrix_T = rotation_matrix_transpose(angle, axis) self.apply_points_function( lambda points: np.dot(points, rot_matrix_T), + about_point, **kwargs ) return self - def flip(self, axis=UP, **kwargs): + def flip(self, axis: np.ndarray = UP, **kwargs): return self.rotate(TAU / 2, axis, **kwargs) - def apply_function(self, function, **kwargs): + def apply_function(self, function: Callable[[np.ndarray], np.ndarray], **kwargs): # Default to applying matrix about the origin, not mobjects center if len(kwargs) == 0: kwargs["about_point"] = ORIGIN @@ -647,16 +707,19 @@ class Mobject(object): ) return self - def apply_function_to_position(self, function): + def apply_function_to_position(self, function: Callable[[np.ndarray], np.ndarray]): self.move_to(function(self.get_center())) return self - def apply_function_to_submobject_positions(self, function): + def apply_function_to_submobject_positions( + self, + function: Callable[[np.ndarray], np.ndarray] + ): for submob in self.submobjects: submob.apply_function_to_position(function) return self - def apply_matrix(self, matrix, **kwargs): + def apply_matrix(self, matrix: npt.ArrayLike, **kwargs): # Default to applying matrix about the origin, not mobjects center if ("about_point" not in kwargs) and ("about_edge" not in kwargs): kwargs["about_point"] = ORIGIN @@ -669,7 +732,7 @@ class Mobject(object): ) return self - def apply_complex_function(self, function, **kwargs): + def apply_complex_function(self, function: Callable[[complex], complex], **kwargs): def R3_func(point): x, y, z = point xy_complex = function(complex(x, y)) @@ -678,9 +741,14 @@ class Mobject(object): xy_complex.imag, z ] - return self.apply_function(R3_func) + return self.apply_function(R3_func, **kwargs) - def wag(self, direction=RIGHT, axis=DOWN, wag_factor=1.0): + def wag( + self, + direction: np.ndarray = RIGHT, + axis: np.ndarray = DOWN, + wag_factor: float = 1.0 + ): for mob in self.family_members_with_points(): alphas = np.dot(mob.get_points(), np.transpose(axis)) alphas -= min(alphas) @@ -698,7 +766,11 @@ class Mobject(object): self.shift(-self.get_center()) return self - def align_on_border(self, direction, buff=DEFAULT_MOBJECT_TO_EDGE_BUFFER): + def align_on_border( + self, + direction: np.ndarray, + buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER + ): """ Direction just needs to be a vector pointing towards side or corner in the 2d plane. @@ -710,20 +782,30 @@ class Mobject(object): self.shift(shift_val) return self - def to_corner(self, corner=LEFT + DOWN, buff=DEFAULT_MOBJECT_TO_EDGE_BUFFER): + def to_corner( + self, + corner: np.ndarray = LEFT + DOWN, + buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER + ): return self.align_on_border(corner, buff) - def to_edge(self, edge=LEFT, buff=DEFAULT_MOBJECT_TO_EDGE_BUFFER): + def to_edge( + self, + edge: np.ndarray = LEFT, + buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER + ): return self.align_on_border(edge, buff) - def next_to(self, mobject_or_point, - direction=RIGHT, - buff=DEFAULT_MOBJECT_TO_MOBJECT_BUFFER, - aligned_edge=ORIGIN, - submobject_to_align=None, - index_of_submobject_to_align=None, - coor_mask=np.array([1, 1, 1]), - ): + def next_to( + self, + mobject_or_point: "Mobject" | np.ndarray, + direction: np.ndarray = RIGHT, + buff: float = DEFAULT_MOBJECT_TO_MOBJECT_BUFFER, + aligned_edge: np.ndarray = ORIGIN, + submobject_to_align: "Mobject" | None = None, + index_of_submobject_to_align: int | slice | None = None, + coor_mask: np.ndarray = np.array([1, 1, 1]), + ): if isinstance(mobject_or_point, Mobject): mob = mobject_or_point if index_of_submobject_to_align is not None: @@ -767,14 +849,14 @@ class Mobject(object): return True return False - def stretch_about_point(self, factor, dim, point): + def stretch_about_point(self, factor: float, dim: int, point: np.ndarray): return self.stretch(factor, dim, about_point=point) - def stretch_in_place(self, factor, dim): + def stretch_in_place(self, factor: float, dim: int): # Now redundant with stretch return self.stretch(factor, dim) - def rescale_to_fit(self, length, dim, stretch=False, **kwargs): + def rescale_to_fit(self, length: float, dim: int, stretch: bool = False, **kwargs): old_length = self.length_over_dim(dim) if old_length == 0: return self @@ -784,63 +866,67 @@ class Mobject(object): self.scale(length / old_length, **kwargs) return self - def stretch_to_fit_width(self, width, **kwargs): + def stretch_to_fit_width(self, width: float, **kwargs): return self.rescale_to_fit(width, 0, stretch=True, **kwargs) - def stretch_to_fit_height(self, height, **kwargs): + def stretch_to_fit_height(self, height: float, **kwargs): return self.rescale_to_fit(height, 1, stretch=True, **kwargs) - def stretch_to_fit_depth(self, depth, **kwargs): + def stretch_to_fit_depth(self, depth: float, **kwargs): return self.rescale_to_fit(depth, 2, stretch=True, **kwargs) - def set_width(self, width, stretch=False, **kwargs): + def set_width(self, width: float, stretch: bool = False, **kwargs): return self.rescale_to_fit(width, 0, stretch=stretch, **kwargs) - def set_height(self, height, stretch=False, **kwargs): + def set_height(self, height: float, stretch: bool = False, **kwargs): return self.rescale_to_fit(height, 1, stretch=stretch, **kwargs) - def set_depth(self, depth, stretch=False, **kwargs): + def set_depth(self, depth: float, stretch: bool = False, **kwargs): return self.rescale_to_fit(depth, 2, stretch=stretch, **kwargs) - def set_max_width(self, max_width, **kwargs): + def set_max_width(self, max_width: float, **kwargs): if self.get_width() > max_width: self.set_width(max_width, **kwargs) return self - def set_max_height(self, max_height, **kwargs): + def set_max_height(self, max_height: float, **kwargs): if self.get_height() > max_height: self.set_height(max_height, **kwargs) return self - def set_max_depth(self, max_depth, **kwargs): + def set_max_depth(self, max_depth: float, **kwargs): if self.get_depth() > max_depth: self.set_depth(max_depth, **kwargs) return self - def set_coord(self, value, dim, direction=ORIGIN): + def set_coord(self, value: float, dim: int, direction: np.ndarray = ORIGIN): curr = self.get_coord(dim, direction) shift_vect = np.zeros(self.dim) shift_vect[dim] = value - curr self.shift(shift_vect) return self - def set_x(self, x, direction=ORIGIN): + def set_x(self, x: float, direction: np.ndarray = ORIGIN): return self.set_coord(x, 0, direction) - def set_y(self, y, direction=ORIGIN): + def set_y(self, y: float, direction: np.ndarray = ORIGIN): return self.set_coord(y, 1, direction) - def set_z(self, z, direction=ORIGIN): + def set_z(self, z: float, direction: np.ndarray = ORIGIN): return self.set_coord(z, 2, direction) - def space_out_submobjects(self, factor=1.5, **kwargs): + def space_out_submobjects(self, factor: float = 1.5, **kwargs): self.scale(factor, **kwargs) for submob in self.submobjects: submob.scale(1. / factor) return self - def move_to(self, point_or_mobject, aligned_edge=ORIGIN, - coor_mask=np.array([1, 1, 1])): + def move_to( + self, + point_or_mobject: "Mobject" | np.ndarray, + aligned_edge: np.ndarray = ORIGIN, + coor_mask: np.ndarray = np.array([1, 1, 1]) + ): if isinstance(point_or_mobject, Mobject): target = point_or_mobject.get_bounding_box_point(aligned_edge) else: @@ -849,7 +935,7 @@ class Mobject(object): self.shift((target - point_to_align) * coor_mask) return self - def replace(self, mobject, dim_to_match=0, stretch=False): + def replace(self, mobject: "Mobject", dim_to_match: int = 0, stretch: bool = False): if not mobject.get_num_points() and not mobject.submobjects: self.scale(0) return self @@ -865,16 +951,19 @@ class Mobject(object): self.shift(mobject.get_center() - self.get_center()) return self - def surround(self, mobject, - dim_to_match=0, - stretch=False, - buff=MED_SMALL_BUFF): + def surround( + self, + mobject: "Mobject", + dim_to_match: int = 0, + stretch: bool = False, + buff: float = MED_SMALL_BUFF + ): self.replace(mobject, dim_to_match, stretch) length = mobject.length_over_dim(dim_to_match) self.scale((length + buff) / length) return self - def put_start_and_end_on(self, start, end): + def put_start_and_end_on(self, start: np.ndarray, end: np.ndarray): curr_start, curr_end = self.get_start_and_end() curr_vect = curr_end - curr_start if np.all(curr_vect == 0): @@ -896,12 +985,21 @@ class Mobject(object): # Color functions - def set_rgba_array(self, rgba_array, name="rgbas", recurse=False): + def set_rgba_array( + self, + rgba_array: npt.ArrayLike, + name: str = "rgbas", + recurse: bool = False + ): for mob in self.get_family(recurse): mob.data[name] = np.array(rgba_array) return self - def set_color_by_rgba_func(self, func, recurse=True): + def set_color_by_rgba_func( + self, + func: Callable[[np.ndarray], Sequence[float]], + recurse: bool = True + ): """ Func should take in a point in R3 and output an rgba value """ @@ -910,7 +1008,12 @@ class Mobject(object): mob.set_rgba_array(rgba_array) return self - def set_color_by_rgb_func(self, func, opacity=1, recurse=True): + def set_color_by_rgb_func( + self, + func: Callable[[np.ndarray], Sequence[float]], + opacity: float = 1, + recurse: bool = True + ): """ Func should take in a point in R3 and output an rgb value """ @@ -919,7 +1022,13 @@ class Mobject(object): mob.set_rgba_array(rgba_array) return self - def set_rgba_array_by_color(self, color=None, opacity=None, name="rgbas", recurse=True): + def set_rgba_array_by_color( + self, + color: Color | None = None, + opacity: float | None = None, + name: str = "rgbas", + recurse: bool = True + ): if color is not None: rgbs = np.array([color_to_rgb(c) for c in listify(color)]) if opacity is not None: @@ -947,7 +1056,7 @@ class Mobject(object): mob.data[name] = rgbas.copy() return self - def set_color(self, color, opacity=None, recurse=True): + def set_color(self, color: Color, opacity: float | None = None, recurse: bool = True): self.set_rgba_array_by_color(color, opacity, recurse=False) # Recurse to submobjects differently from how set_rgba_array_by_color # in case they implement set_color differently @@ -956,24 +1065,24 @@ class Mobject(object): submob.set_color(color, recurse=True) return self - def set_opacity(self, opacity, recurse=True): + def set_opacity(self, opacity: float, recurse: bool = True): self.set_rgba_array_by_color(color=None, opacity=opacity, recurse=False) if recurse: for submob in self.submobjects: submob.set_opacity(opacity, recurse=True) return self - def get_color(self): + def get_color(self) -> str: return rgb_to_hex(self.data["rgbas"][0, :3]) - def get_opacity(self): + def get_opacity(self) -> float: return self.data["rgbas"][0, 3] - def set_color_by_gradient(self, *colors): + def set_color_by_gradient(self, *colors: Color): self.set_submobject_colors_by_gradient(*colors) return self - def set_submobject_colors_by_gradient(self, *colors): + def set_submobject_colors_by_gradient(self, *colors: Color): if len(colors) == 0: raise Exception("Need at least one color") elif len(colors) == 1: @@ -987,36 +1096,41 @@ class Mobject(object): mob.set_color(color) return self - def fade(self, darkness=0.5, recurse=True): + def fade(self, darkness: float = 0.5, recurse: bool = True): self.set_opacity(1.0 - darkness, recurse=recurse) - def get_reflectiveness(self): + def get_reflectiveness(self) -> float: return self.uniforms["reflectiveness"] - def set_reflectiveness(self, reflectiveness, recurse=True): + def set_reflectiveness(self, reflectiveness: float, recurse: bool = True): for mob in self.get_family(recurse): mob.uniforms["reflectiveness"] = reflectiveness return self - def get_shadow(self): + def get_shadow(self) -> float: return self.uniforms["shadow"] - def set_shadow(self, shadow, recurse=True): + def set_shadow(self, shadow: float, recurse: bool = True): for mob in self.get_family(recurse): mob.uniforms["shadow"] = shadow return self - def get_gloss(self): + def get_gloss(self) -> float: return self.uniforms["gloss"] - def set_gloss(self, gloss, recurse=True): + def set_gloss(self, gloss: float, recurse: bool = True): for mob in self.get_family(recurse): mob.uniforms["gloss"] = gloss return self # Background rectangle - def add_background_rectangle(self, color=None, opacity=0.75, **kwargs): + def add_background_rectangle( + self, + color: Color | None = None, + opacity: float = 0.75, + **kwargs + ): # TODO, this does not behave well when the mobject has points, # since it gets displayed on top from manimlib.mobject.shape_matchers import BackgroundRectangle @@ -1040,7 +1154,7 @@ class Mobject(object): # Getters - def get_bounding_box_point(self, direction): + def get_bounding_box_point(self, direction: np.ndarray) -> np.ndarray: bb = self.get_bounding_box() indices = (np.sign(direction) + 1).astype(int) return np.array([ @@ -1048,19 +1162,19 @@ class Mobject(object): for i in range(3) ]) - def get_edge_center(self, direction): + def get_edge_center(self, direction: np.ndarray) -> np.ndarray: return self.get_bounding_box_point(direction) - def get_corner(self, direction): + def get_corner(self, direction: np.ndarray) -> np.ndarray: return self.get_bounding_box_point(direction) - def get_center(self): + def get_center(self) -> np.ndarray: return self.get_bounding_box()[1] - def get_center_of_mass(self): + def get_center_of_mass(self) -> np.ndarray: return self.get_all_points().mean(0) - def get_boundary_point(self, direction): + def get_boundary_point(self, direction: np.ndarray) -> np.ndarray: all_points = self.get_all_points() boundary_directions = all_points - self.get_center() norms = np.linalg.norm(boundary_directions, axis=1) @@ -1068,7 +1182,7 @@ class Mobject(object): index = np.argmax(np.dot(boundary_directions, np.array(direction).T)) return all_points[index] - def get_continuous_bounding_box_point(self, direction): + def get_continuous_bounding_box_point(self, direction: np.ndarray) -> np.ndarray: dl, center, ur = self.get_bounding_box() corner_vect = (ur - center) return center + direction / np.max(np.abs(np.true_divide( @@ -1077,66 +1191,66 @@ class Mobject(object): where=((corner_vect) != 0) ))) - def get_top(self): + def get_top(self) -> np.ndarray: return self.get_edge_center(UP) - def get_bottom(self): + def get_bottom(self) -> np.ndarray: return self.get_edge_center(DOWN) - def get_right(self): + def get_right(self) -> np.ndarray: return self.get_edge_center(RIGHT) - def get_left(self): + def get_left(self) -> np.ndarray: return self.get_edge_center(LEFT) - def get_zenith(self): + def get_zenith(self) -> np.ndarray: return self.get_edge_center(OUT) - def get_nadir(self): + def get_nadir(self) -> np.ndarray: return self.get_edge_center(IN) - def length_over_dim(self, dim): + def length_over_dim(self, dim: int) -> float: bb = self.get_bounding_box() return abs((bb[2] - bb[0])[dim]) - def get_width(self): + def get_width(self) -> float: return self.length_over_dim(0) - def get_height(self): + def get_height(self) -> float: return self.length_over_dim(1) - def get_depth(self): + def get_depth(self) -> float: return self.length_over_dim(2) - def get_coord(self, dim, direction=ORIGIN): + def get_coord(self, dim: int, direction: np.ndarray = ORIGIN) -> float: """ Meant to generalize get_x, get_y, get_z """ return self.get_bounding_box_point(direction)[dim] - def get_x(self, direction=ORIGIN): + def get_x(self, direction=ORIGIN) -> float: return self.get_coord(0, direction) - def get_y(self, direction=ORIGIN): + def get_y(self, direction=ORIGIN) -> float: return self.get_coord(1, direction) - def get_z(self, direction=ORIGIN): + def get_z(self, direction=ORIGIN) -> float: return self.get_coord(2, direction) - def get_start(self): + def get_start(self) -> np.ndarray: self.throw_error_if_no_points() return self.get_points()[0].copy() - def get_end(self): + def get_end(self) -> np.ndarray: self.throw_error_if_no_points() return self.get_points()[-1].copy() - def get_start_and_end(self): + def get_start_and_end(self) -> tuple(np.ndarray, np.ndarray): self.throw_error_if_no_points() points = self.get_points() return (points[0].copy(), points[-1].copy()) - def point_from_proportion(self, alpha): + def point_from_proportion(self, alpha: float) -> np.ndarray: points = self.get_points() i, subalpha = integer_interpolate(0, len(points) - 1, alpha) return interpolate(points[i], points[i + 1], subalpha) @@ -1145,7 +1259,7 @@ class Mobject(object): """Abbreviation fo point_from_proportion""" return self.point_from_proportion(alpha) - def get_pieces(self, n_pieces): + def get_pieces(self, n_pieces: int) -> Group: template = self.copy() template.set_submobjects([]) alphas = np.linspace(0, 1, n_pieces + 1) @@ -1163,41 +1277,45 @@ class Mobject(object): # Match other mobject properties - def match_color(self, mobject): + def match_color(self, mobject: "Mobject"): return self.set_color(mobject.get_color()) - def match_dim_size(self, mobject, dim, **kwargs): + def match_dim_size(self, mobject: "Mobject", dim: int, **kwargs): return self.rescale_to_fit( mobject.length_over_dim(dim), dim, **kwargs ) - def match_width(self, mobject, **kwargs): + def match_width(self, mobject: "Mobject", **kwargs): return self.match_dim_size(mobject, 0, **kwargs) - def match_height(self, mobject, **kwargs): + def match_height(self, mobject: "Mobject", **kwargs): return self.match_dim_size(mobject, 1, **kwargs) - def match_depth(self, mobject, **kwargs): + def match_depth(self, mobject: "Mobject", **kwargs): return self.match_dim_size(mobject, 2, **kwargs) - def match_coord(self, mobject, dim, direction=ORIGIN): + def match_coord(self, mobject: "Mobject", dim: int, direction: np.ndarray = ORIGIN): return self.set_coord( mobject.get_coord(dim, direction), dim=dim, direction=direction, ) - def match_x(self, mobject, direction=ORIGIN): + def match_x(self, mobject: "Mobject", direction: np.ndarray = ORIGIN): return self.match_coord(mobject, 0, direction) - def match_y(self, mobject, direction=ORIGIN): + def match_y(self, mobject: "Mobject", direction: np.ndarray = ORIGIN): return self.match_coord(mobject, 1, direction) - def match_z(self, mobject, direction=ORIGIN): + def match_z(self, mobject: "Mobject", direction: np.ndarray = ORIGIN): return self.match_coord(mobject, 2, direction) - def align_to(self, mobject_or_point, direction=ORIGIN): + def align_to( + self, + mobject_or_point: "Mobject" | np.ndarray, + direction: np.ndarray = ORIGIN + ): """ Examples: mob1.align_to(mob2, UP) moves mob1 vertically so that its @@ -1222,11 +1340,11 @@ class Mobject(object): # Alignment - def align_data_and_family(self, mobject): + def align_data_and_family(self, mobject: "Mobject") -> None: self.align_family(mobject) self.align_data(mobject) - def align_data(self, mobject): + def align_data(self, mobject: "Mobject") -> None: # In case any data arrays get resized when aligned to shader data self.refresh_shader_data() for mob1, mob2 in zip(self.get_family(), mobject.get_family()): @@ -1243,13 +1361,13 @@ class Mobject(object): elif len(arr1) > len(arr2): mob2.data[key] = resize_preserving_order(arr2, len(arr1)) - def align_points(self, mobject): + def align_points(self, mobject: "Mobject"): max_len = max(self.get_num_points(), mobject.get_num_points()) for mob in (self, mobject): mob.resize_points(max_len, resize_func=resize_preserving_order) return self - def align_family(self, mobject): + def align_family(self, mobject: "Mobject"): mob1 = self mob2 = mobject n1 = len(mob1) @@ -1269,7 +1387,7 @@ class Mobject(object): self.add(copy) return self - def add_n_more_submobjects(self, n): + def add_n_more_submobjects(self, n: int): if n == 0: return self @@ -1304,7 +1422,13 @@ class Mobject(object): # Interpolate - def interpolate(self, mobject1, mobject2, alpha, path_func=straight_path): + def interpolate( + self, + mobject1: "Mobject", + mobject2: "Mobject", + alpha: float, + path_func: Callable[[np.ndarray, np.ndarray, float], np.ndarray] = straight_path + ): for key in self.data: if key in self.locked_data_keys: continue @@ -1340,7 +1464,7 @@ class Mobject(object): """ pass # To implement in subclass - def become(self, mobject): + def become(self, mobject: "Mobject"): """ Edit all data and submobjects to be idential to another mobject @@ -1354,7 +1478,7 @@ class Mobject(object): # Locking data - def lock_data(self, keys): + def lock_data(self, keys: Iterable[str]): """ To speed up some animations, particularly transformations, it can be handy to acknowledge which pieces of data @@ -1368,7 +1492,7 @@ class Mobject(object): self.refresh_shader_data() self.locked_data_keys = set(keys) - def lock_matching_data(self, mobject1, mobject2): + def lock_matching_data(self, mobject1: "Mobject", mobject2: "Mobject"): for sm, sm1, sm2 in zip(self.get_family(), mobject1.get_family(), mobject2.get_family()): keys = sm.data.keys() & sm1.data.keys() & sm2.data.keys() sm.lock_data(list(filter( @@ -1416,7 +1540,7 @@ class Mobject(object): # Shader code manipulation - def replace_shader_code(self, old, new): + def replace_shader_code(self, old: str, new: str): # TODO, will this work with VMobject structure, given # that it does not simpler return shader_wrappers of # family? @@ -1424,7 +1548,7 @@ class Mobject(object): wrapper.replace_code(old, new) return self - def set_color_by_code(self, glsl_code): + def set_color_by_code(self, glsl_code: str): """ Takes a snippet of code and inserts it into a context which has the following variables: @@ -1437,9 +1561,13 @@ class Mobject(object): ) return self - def set_color_by_xyz_func(self, glsl_snippet, - min_value=-5.0, max_value=5.0, - colormap="viridis"): + def set_color_by_xyz_func( + self, + glsl_snippet: str, + min_value: float = -5.0, + max_value: float = 5.0, + colormap: str = "viridis" + ): """ Pass in a glsl expression in terms of x, y and z which returns a float. @@ -1484,7 +1612,7 @@ class Mobject(object): self.shader_wrapper.depth_test = self.depth_test return self.shader_wrapper - def get_shader_wrapper_list(self): + def get_shader_wrapper_list(self) -> list[ShaderWrapper]: shader_wrappers = it.chain( [self.get_shader_wrapper()], *[sm.get_shader_wrapper_list() for sm in self.submobjects] @@ -1501,7 +1629,7 @@ class Mobject(object): result.append(shader_wrapper) return result - def check_data_alignment(self, array, data_key): + def check_data_alignment(self, array: Iterable, data_key: str): # Makes sure that self.data[key] can be broadcast into # the given array, meaning its length has to be either 1 # or the length of the array @@ -1512,14 +1640,19 @@ class Mobject(object): ) return self - def get_resized_shader_data_array(self, length): + def get_resized_shader_data_array(self, length: int) -> np.ndarray: # If possible, try to populate an existing array, rather # than recreating it each frame if len(self.shader_data) != length: self.shader_data = resize_array(self.shader_data, length) return self.shader_data - def read_data_to_shader(self, shader_data, shader_data_key, data_key): + def read_data_to_shader( + self, + shader_data: np.ndarray, + shader_data_key: str, + data_key: str + ): if data_key in self.locked_data_keys: return self.check_data_alignment(shader_data, data_key) @@ -1551,22 +1684,22 @@ class Mobject(object): """ def init_event_listners(self): - self.event_listners = [] + self.event_listners: list[EventListner] = [] - def add_event_listner(self, event_type, event_callback): + def add_event_listner(self, event_type: EventType, event_callback: Callable): event_listner = EventListner(self, event_type, event_callback) self.event_listners.append(event_listner) EVENT_DISPATCHER.add_listner(event_listner) return self - def remove_event_listner(self, event_type, event_callback): + def remove_event_listner(self, event_type: EventType, event_callback: Callable): event_listner = EventListner(self, event_type, event_callback) while event_listner in self.event_listners: self.event_listners.remove(event_listner) EVENT_DISPATCHER.remove_listner(event_listner) return self - def clear_event_listners(self, recurse=True): + def clear_event_listners(self, recurse: bool = True): self.event_listners = [] if recurse: for submob in self.submobjects: @@ -1638,13 +1771,13 @@ class Mobject(object): class Group(Mobject): - def __init__(self, *mobjects, **kwargs): + def __init__(self, *mobjects: "Mobject", **kwargs): if not all([isinstance(m, Mobject) for m in mobjects]): raise Exception("All submobjects must be of type Mobject") Mobject.__init__(self, **kwargs) self.add(*mobjects) - def __add__(self, other: 'Mobject' or 'Group'): + def __add__(self, other: "Mobject" | "Group"): assert(isinstance(other, Mobject)) return self.add(other) @@ -1655,35 +1788,35 @@ class Point(Mobject): "artificial_height": 1e-6, } - def __init__(self, location=ORIGIN, **kwargs): + def __init__(self, location: npt.ArrayLike = ORIGIN, **kwargs): Mobject.__init__(self, **kwargs) self.set_location(location) - def get_width(self): + def get_width(self) -> float: return self.artificial_width - def get_height(self): + def get_height(self) -> float: return self.artificial_height - def get_location(self): + def get_location(self) -> np.ndarray: return self.get_points()[0].copy() - def get_bounding_box_point(self, *args, **kwargs): + def get_bounding_box_point(self, *args, **kwargs) -> np.ndarray: return self.get_location() - def set_location(self, new_loc): + def set_location(self, new_loc: npt.ArrayLike): self.set_points(np.array(new_loc, ndmin=2, dtype=float)) class _AnimationBuilder: - def __init__(self, mobject): + def __init__(self, mobject: Mobject): self.mobject = mobject self.overridden_animation = None self.mobject.generate_target() self.is_chaining = False self.methods = [] - def __getattr__(self, method_name): + def __getattr__(self, method_name: str): method = getattr(self.mobject.target, method_name) self.methods.append(method) has_overridden_animation = hasattr(method, "_override_animate")