Add shorthand for type np.ndarray[int, np.dtype[np.float64]]

This commit is contained in:
Grant Sanderson 2022-12-14 16:17:15 -08:00
parent ca1ba67a85
commit 53994f0650

View file

@ -56,6 +56,7 @@ if TYPE_CHECKING:
Updater = Union[TimeBasedUpdater, NonTimeUpdater]
from manimlib.constants import ManimColor
np_vector = np.ndarray[int, np.dtype[np.float64]]
class Mobject(object):
"""
@ -194,8 +195,8 @@ class Mobject(object):
def apply_points_function(
self,
func: Callable[[np.ndarray], np.ndarray],
about_point: np.ndarray = None,
about_edge: np.ndarray = ORIGIN,
about_point: np_vector = None,
about_edge: np_vector = ORIGIN,
works_on_bounding_box: bool = False
):
if about_point is None and about_edge is not None:
@ -227,7 +228,7 @@ class Mobject(object):
self.set_points(mobject.get_points())
return self
def get_points(self) -> np.ndarray:
def get_points(self) -> np_vector:
return self.data["points"]
def clear_points(self) -> None:
@ -236,7 +237,7 @@ class Mobject(object):
def get_num_points(self) -> int:
return len(self.data["points"])
def get_all_points(self) -> np.ndarray:
def get_all_points(self) -> np_vector:
if self.submobjects:
return np.vstack([sm.get_points() for sm in self.get_family()])
else:
@ -245,13 +246,13 @@ class Mobject(object):
def has_points(self) -> bool:
return self.get_num_points() > 0
def get_bounding_box(self) -> np.ndarray:
def get_bounding_box(self) -> np_vector:
if self.needs_new_bounding_box:
self.data["bounding_box"] = self.compute_bounding_box()
self.needs_new_bounding_box = False
return self.data["bounding_box"]
def compute_bounding_box(self) -> np.ndarray:
def compute_bounding_box(self) -> np_vector:
all_points = np.vstack([
self.get_points(),
*(
@ -283,7 +284,7 @@ class Mobject(object):
def are_points_touching(
self,
points: np.ndarray,
points: np_vector,
buff: float = 0
) -> bool:
bb = self.get_bounding_box()
@ -293,7 +294,7 @@ class Mobject(object):
def is_point_touching(
self,
point: np.ndarray,
point: np_vector,
buff: float = 0
) -> bool:
return self.are_points_touching(np.array(point, ndmin=2), buff)[0]
@ -418,7 +419,7 @@ class Mobject(object):
def arrange(
self,
direction: np.ndarray = RIGHT,
direction: np_vector = RIGHT,
center: bool = True,
**kwargs
):
@ -438,7 +439,7 @@ class Mobject(object):
buff_ratio: float | None = None,
h_buff_ratio: float = 0.5,
v_buff_ratio: float = 0.5,
aligned_edge: np.ndarray = ORIGIN,
aligned_edge: np_vector = ORIGIN,
fill_rows_first: bool = True
):
submobs = self.submobjects
@ -821,7 +822,7 @@ class Mobject(object):
# Transforming operations
def shift(self, vector: np.ndarray):
def shift(self, vector: np_vector):
self.apply_points_function(
lambda points: points + vector,
about_edge=None,
@ -833,8 +834,8 @@ class Mobject(object):
self,
scale_factor: float | npt.ArrayLike,
min_scale_factor: float = 1e-8,
about_point: np.ndarray | None = None,
about_edge: np.ndarray = ORIGIN
about_point: np_vector | None = None,
about_edge: np_vector = ORIGIN
):
"""
Default behavior is to scale about the center of the mobject.
@ -871,14 +872,14 @@ class Mobject(object):
self.apply_points_function(func, works_on_bounding_box=True, **kwargs)
return self
def rotate_about_origin(self, angle: float, axis: np.ndarray = OUT):
def rotate_about_origin(self, angle: float, axis: np_vector = OUT):
return self.rotate(angle, axis, about_point=ORIGIN)
def rotate(
self,
angle: float,
axis: np.ndarray = OUT,
about_point: np.ndarray | None = None,
axis: np_vector = OUT,
about_point: np_vector | None = None,
**kwargs
):
rot_matrix_T = rotation_matrix_transpose(angle, axis)
@ -889,7 +890,7 @@ class Mobject(object):
)
return self
def flip(self, axis: np.ndarray = UP, **kwargs):
def flip(self, axis: np_vector = UP, **kwargs):
return self.rotate(TAU / 2, axis, **kwargs)
def apply_function(self, function: Callable[[np.ndarray], np.ndarray], **kwargs):
@ -940,8 +941,8 @@ class Mobject(object):
def wag(
self,
direction: np.ndarray = RIGHT,
axis: np.ndarray = DOWN,
direction: np_vector = RIGHT,
axis: np_vector = DOWN,
wag_factor: float = 1.0
):
for mob in self.family_members_with_points():
@ -963,7 +964,7 @@ class Mobject(object):
def align_on_border(
self,
direction: np.ndarray,
direction: np_vector,
buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER
):
"""
@ -979,27 +980,27 @@ class Mobject(object):
def to_corner(
self,
corner: np.ndarray = LEFT + DOWN,
corner: np_vector = LEFT + DOWN,
buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER
):
return self.align_on_border(corner, buff)
def to_edge(
self,
edge: np.ndarray = LEFT,
edge: np_vector = LEFT,
buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER
):
return self.align_on_border(edge, buff)
def next_to(
self,
mobject_or_point: Mobject | np.ndarray,
direction: np.ndarray = RIGHT,
mobject_or_point: Mobject | np_vector,
direction: np_vector = RIGHT,
buff: float = DEFAULT_MOBJECT_TO_MOBJECT_BUFFER,
aligned_edge: np.ndarray = ORIGIN,
aligned_edge: np_vector = ORIGIN,
submobject_to_align: Mobject | None = None,
index_of_submobject_to_align: int | slice | None = None,
coor_mask: np.ndarray = np.array([1, 1, 1]),
coor_mask: np_vector = np.array([1, 1, 1]),
):
if isinstance(mobject_or_point, Mobject):
mob = mobject_or_point
@ -1044,7 +1045,7 @@ class Mobject(object):
return True
return False
def stretch_about_point(self, factor: float, dim: int, point: np.ndarray):
def stretch_about_point(self, factor: float, dim: int, point: np_vector):
return self.stretch(factor, dim, about_point=point)
def stretch_in_place(self, factor: float, dim: int):
@ -1109,20 +1110,20 @@ class Mobject(object):
self.set_depth(min_depth, **kwargs)
return self
def set_coord(self, value: float, dim: int, direction: np.ndarray = ORIGIN):
def set_coord(self, value: float, dim: int, direction: np_vector = ORIGIN):
curr = self.get_coord(dim, direction)
shift_vect = np.zeros(self.dim)
shift_vect[dim] = value - curr
self.shift(shift_vect)
return self
def set_x(self, x: float, direction: np.ndarray = ORIGIN):
def set_x(self, x: float, direction: np_vector = ORIGIN):
return self.set_coord(x, 0, direction)
def set_y(self, y: float, direction: np.ndarray = ORIGIN):
def set_y(self, y: float, direction: np_vector = ORIGIN):
return self.set_coord(y, 1, direction)
def set_z(self, z: float, direction: np.ndarray = ORIGIN):
def set_z(self, z: float, direction: np_vector = ORIGIN):
return self.set_coord(z, 2, direction)
def space_out_submobjects(self, factor: float = 1.5, **kwargs):
@ -1133,9 +1134,9 @@ class Mobject(object):
def move_to(
self,
point_or_mobject: Mobject | np.ndarray,
aligned_edge: np.ndarray = ORIGIN,
coor_mask: np.ndarray = np.array([1, 1, 1])
point_or_mobject: Mobject | np_vector,
aligned_edge: np_vector = ORIGIN,
coor_mask: np_vector = np.array([1, 1, 1])
):
if isinstance(point_or_mobject, Mobject):
target = point_or_mobject.get_bounding_box_point(aligned_edge)
@ -1173,7 +1174,7 @@ class Mobject(object):
self.scale((length + buff) / length)
return self
def put_start_and_end_on(self, start: np.ndarray, end: np.ndarray):
def put_start_and_end_on(self, start: np_vector, end: np_vector):
curr_start, curr_end = self.get_start_and_end()
curr_vect = curr_end - curr_start
if np.all(curr_vect == 0):
@ -1367,7 +1368,7 @@ class Mobject(object):
# Getters
def get_bounding_box_point(self, direction: np.ndarray) -> np.ndarray:
def get_bounding_box_point(self, direction: np_vector) -> np_vector:
bb = self.get_bounding_box()
indices = (np.sign(direction) + 1).astype(int)
return np.array([
@ -1375,10 +1376,10 @@ class Mobject(object):
for i in range(3)
])
def get_edge_center(self, direction: np.ndarray) -> np.ndarray:
def get_edge_center(self, direction: np_vector) -> np_vector:
return self.get_bounding_box_point(direction)
def get_corner(self, direction: np.ndarray) -> np.ndarray:
def get_corner(self, direction: np_vector) -> np_vector:
return self.get_bounding_box_point(direction)
def get_all_corners(self):
@ -1388,13 +1389,13 @@ class Mobject(object):
for indices in it.product([0, 2], repeat=3)
])
def get_center(self) -> np.ndarray:
def get_center(self) -> np_vector:
return self.get_bounding_box()[1]
def get_center_of_mass(self) -> np.ndarray:
def get_center_of_mass(self) -> np_vector:
return self.get_all_points().mean(0)
def get_boundary_point(self, direction: np.ndarray) -> np.ndarray:
def get_boundary_point(self, direction: np_vector) -> np_vector:
all_points = self.get_all_points()
boundary_directions = all_points - self.get_center()
norms = np.linalg.norm(boundary_directions, axis=1)
@ -1402,7 +1403,7 @@ class Mobject(object):
index = np.argmax(np.dot(boundary_directions, np.array(direction).T))
return all_points[index]
def get_continuous_bounding_box_point(self, direction: np.ndarray) -> np.ndarray:
def get_continuous_bounding_box_point(self, direction: np_vector) -> np_vector:
dl, center, ur = self.get_bounding_box()
corner_vect = (ur - center)
return center + direction / np.max(np.abs(np.true_divide(
@ -1411,22 +1412,22 @@ class Mobject(object):
where=((corner_vect) != 0)
)))
def get_top(self) -> np.ndarray:
def get_top(self) -> np_vector:
return self.get_edge_center(UP)
def get_bottom(self) -> np.ndarray:
def get_bottom(self) -> np_vector:
return self.get_edge_center(DOWN)
def get_right(self) -> np.ndarray:
def get_right(self) -> np_vector:
return self.get_edge_center(RIGHT)
def get_left(self) -> np.ndarray:
def get_left(self) -> np_vector:
return self.get_edge_center(LEFT)
def get_zenith(self) -> np.ndarray:
def get_zenith(self) -> np_vector:
return self.get_edge_center(OUT)
def get_nadir(self) -> np.ndarray:
def get_nadir(self) -> np_vector:
return self.get_edge_center(IN)
def length_over_dim(self, dim: int) -> float:
@ -1442,7 +1443,7 @@ class Mobject(object):
def get_depth(self) -> float:
return self.length_over_dim(2)
def get_coord(self, dim: int, direction: np.ndarray = ORIGIN) -> float:
def get_coord(self, dim: int, direction: np_vector = ORIGIN) -> float:
"""
Meant to generalize get_x, get_y, get_z
"""
@ -1457,20 +1458,20 @@ class Mobject(object):
def get_z(self, direction=ORIGIN) -> float:
return self.get_coord(2, direction)
def get_start(self) -> np.ndarray:
def get_start(self) -> np_vector:
self.throw_error_if_no_points()
return self.get_points()[0].copy()
def get_end(self) -> np.ndarray:
def get_end(self) -> np_vector:
self.throw_error_if_no_points()
return self.get_points()[-1].copy()
def get_start_and_end(self) -> tuple(np.ndarray, np.ndarray):
def get_start_and_end(self) -> tuple[np_vector, np_vector]:
self.throw_error_if_no_points()
points = self.get_points()
return (points[0].copy(), points[-1].copy())
def point_from_proportion(self, alpha: float) -> np.ndarray:
def point_from_proportion(self, alpha: float) -> np_vector:
points = self.get_points()
i, subalpha = integer_interpolate(0, len(points) - 1, alpha)
return interpolate(points[i], points[i + 1], subalpha)
@ -1517,9 +1518,9 @@ class Mobject(object):
def match_coord(
self,
mobject_or_point: Mobject | np.ndarray,
mobject_or_point: Mobject | np_vector,
dim: int,
direction: np.ndarray = ORIGIN
direction: np_vector = ORIGIN
):
if isinstance(mobject_or_point, Mobject):
coord = mobject_or_point.get_coord(dim, direction)
@ -1529,29 +1530,29 @@ class Mobject(object):
def match_x(
self,
mobject_or_point: Mobject | np.ndarray,
direction: np.ndarray = ORIGIN
mobject_or_point: Mobject | np_vector,
direction: np_vector = ORIGIN
):
return self.match_coord(mobject_or_point, 0, direction)
def match_y(
self,
mobject_or_point: Mobject | np.ndarray,
direction: np.ndarray = ORIGIN
mobject_or_point: Mobject | np_vector,
direction: np_vector = ORIGIN
):
return self.match_coord(mobject_or_point, 1, direction)
def match_z(
self,
mobject_or_point: Mobject | np.ndarray,
direction: np.ndarray = ORIGIN
mobject_or_point: Mobject | np_vector,
direction: np_vector = ORIGIN
):
return self.match_coord(mobject_or_point, 2, direction)
def align_to(
self,
mobject_or_point: Mobject | np.ndarray,
direction: np.ndarray = ORIGIN
mobject_or_point: Mobject | np_vector,
direction: np_vector = ORIGIN
):
"""
Examples:
@ -1865,7 +1866,7 @@ class Mobject(object):
)
return self
def get_resized_shader_data_array(self, length: int) -> np.ndarray:
def get_resized_shader_data_array(self, length: int) -> np_vector:
# If possible, try to populate an existing array, rather
# than recreating it each frame
if len(self.shader_data) != length:
@ -1874,7 +1875,7 @@ class Mobject(object):
def read_data_to_shader(
self,
shader_data: np.ndarray,
shader_data: np_vector,
shader_data_key: str,
data_key: str
):
@ -2031,10 +2032,10 @@ class Point(Mobject):
def get_height(self) -> float:
return self.artificial_height
def get_location(self) -> np.ndarray:
def get_location(self) -> np_vector:
return self.get_points()[0].copy()
def get_bounding_box_point(self, *args, **kwargs) -> np.ndarray:
def get_bounding_box_point(self, *args, **kwargs) -> np_vector:
return self.get_location()
def set_location(self, new_loc: npt.ArrayLike):