diff --git a/manimlib/utils/space_ops.py b/manimlib/utils/space_ops.py index a3d1ab90..8a6acd07 100644 --- a/manimlib/utils/space_ops.py +++ b/manimlib/utils/space_ops.py @@ -18,29 +18,31 @@ from manimlib.utils.simple_functions import clip from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Callable, Iterable, Sequence - - import numpy.typing as npt + from typing import Callable, Sequence, List, Tuple + from manimlib.typing import ManimColor, Vect2, Vect3, Vect4, VectN, Matrix3x3 -def cross(v1: np.ndarray, v2: np.ndarray) -> list[np.ndarray]: - return [ +def cross(v1: Vect3 | List[float], v2: Vect3 | List[float]) -> Vect3: + return np.array([ v1[1] * v2[2] - v1[2] * v2[1], v1[2] * v2[0] - v1[0] * v2[2], v1[0] * v2[1] - v1[1] * v2[0] - ] + ]) -def get_norm(vect: Iterable) -> float: +def get_norm(vect: VectN | List[float]) -> float: return sum((x**2 for x in vect))**0.5 -def normalize(vect: np.ndarray, fall_back: np.ndarray | None = None) -> np.ndarray: +def normalize( + vect: VectN | List[float], + fall_back: VectN | List[float] | None = None +) -> VectN: norm = get_norm(vect) if norm > 0: return np.array(vect) / norm elif fall_back is not None: - return fall_back + return np.array(fall_back) else: return np.zeros(len(vect)) @@ -48,15 +50,18 @@ def normalize(vect: np.ndarray, fall_back: np.ndarray | None = None) -> np.ndarr # Operations related to rotation -def quaternion_mult(*quats: Sequence[float]) -> list[float]: - # Real part is last entry, which is bizzare, but fits scipy Rotation convention +def quaternion_mult(*quats: Vect4) -> Vect4: + """ + Inputs are treated as quaternions, where the real part is the + last entry, so as to follow the scipy Rotation conventions. + """ if len(quats) == 0: - return [0, 0, 0, 1] - result = quats[0] + return np.array([0, 0, 0, 1]) + result = np.array(quats[0]) for next_quat in quats[1:]: x1, y1, z1, w1 = result x2, y2, z2, w2 = next_quat - result = [ + result[:] = [ w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2, w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2, w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2, @@ -67,67 +72,68 @@ def quaternion_mult(*quats: Sequence[float]) -> list[float]: def quaternion_from_angle_axis( angle: float, - axis: np.ndarray, -) -> list[float]: + axis: Vect3, +) -> Vect4: return Rotation.from_rotvec(angle * normalize(axis)).as_quat() -def angle_axis_from_quaternion(quat: Sequence[float]) -> tuple[float, np.ndarray]: +def angle_axis_from_quaternion(quat: Vect4) -> Tuple[float, Vect3]: rot_vec = Rotation.from_quat(quat).as_rotvec() norm = get_norm(rot_vec) return norm, rot_vec / norm -def quaternion_conjugate(quaternion: Iterable) -> list: - result = list(quaternion) - for i in range(3): - result[i] *= -1 +def quaternion_conjugate(quaternion: Vect4) -> Vect4: + result = np.array(quaternion) + result[:3] *= -1 return result def rotate_vector( - vector: Iterable, + vector: Vect3, angle: float, - axis: np.ndarray = OUT -) -> np.ndarray | list[float]: + axis: Vect3 = OUT +) -> Vect3: rot = Rotation.from_rotvec(angle * normalize(axis)) return np.dot(vector, rot.as_matrix().T) -def rotate_vector_2d(vector: Iterable, angle: float): +def rotate_vector_2d(vector: Vect2, angle: float) -> Vect2: # Use complex numbers...because why not z = complex(*vector) * np.exp(complex(0, angle)) return np.array([z.real, z.imag]) -def rotation_matrix_transpose_from_quaternion(quat: Iterable) -> np.ndarray: +def rotation_matrix_transpose_from_quaternion(quat: Vect4) -> Matrix3x3: return Rotation.from_quat(quat).as_matrix() -def rotation_matrix_from_quaternion(quat: Iterable) -> np.ndarray: +def rotation_matrix_from_quaternion(quat: Vect4) -> Matrix3x3: return np.transpose(rotation_matrix_transpose_from_quaternion(quat)) -def rotation_matrix(angle: float, axis: np.ndarray) -> np.ndarray: +def rotation_matrix(angle: float, axis: Vect3) -> Matrix3x3: """ Rotation in R^3 about a specified axis of rotation. """ return Rotation.from_rotvec(angle * normalize(axis)).as_matrix() -def rotation_matrix_transpose(angle: float, axis: np.ndarray) -> np.ndarray: +def rotation_matrix_transpose(angle: float, axis: Vect3) -> Matrix3x3: return rotation_matrix(angle, axis).T -def rotation_about_z(angle: float) -> list[list[float]]: - return [ - [math.cos(angle), -math.sin(angle), 0], - [math.sin(angle), math.cos(angle), 0], +def rotation_about_z(angle: float) -> Matrix3x3: + cos_a = math.cos(angle) + sin_a = math.sin(angle) + return np.array([ + [cos_a, -sin_a, 0], + [sin_a, cos_a, 0], [0, 0, 1] - ] + ]) -def rotation_between_vectors(v1, v2) -> np.ndarray: +def rotation_between_vectors(v1: Vect3, v2: Vect3) -> Matrix3x3: if np.all(np.isclose(v1, v2)): return np.identity(3) return rotation_matrix( @@ -136,18 +142,18 @@ def rotation_between_vectors(v1, v2) -> np.ndarray: ) -def z_to_vector(vector: np.ndarray) -> np.ndarray: +def z_to_vector(vector: Vect3) -> Matrix3x3: return rotation_between_vectors(OUT, vector) -def angle_of_vector(vector: Sequence[float]) -> float: +def angle_of_vector(vector: Vect2 | Vect3) -> float: """ Returns polar coordinate theta when vector is project on xy plane """ return np.angle(complex(*vector[:2])) -def angle_between_vectors(v1: np.ndarray, v2: np.ndarray) -> float: +def angle_between_vectors(v1: VectN, v2: VectN) -> float: """ Returns the angle between two 3D vectors. This angle will always be btw 0 and pi @@ -160,7 +166,7 @@ def angle_between_vectors(v1: np.ndarray, v2: np.ndarray) -> float: return math.acos(clip(cos_angle, -1, 1)) -def project_along_vector(point: np.ndarray, vector: np.ndarray) -> np.ndarray: +def project_along_vector(point: Vect3, vector: Vect3) -> Vect3: matrix = np.identity(3) - np.outer(vector, vector) return np.dot(point, matrix.T) @@ -177,10 +183,10 @@ def normalize_along_axis( def get_unit_normal( - v1: np.ndarray, - v2: np.ndarray, + v1: Vect3, + v2: Vect3, tol: float = 1e-6 -) -> np.ndarray: +) -> Vect3: v1 = normalize(v1) v2 = normalize(v2) cp = cross(v1, v2) @@ -204,7 +210,7 @@ def thick_diagonal(dim: int, thickness: int = 2) -> np.ndarray: return (np.abs(row_indices - col_indices) < thickness).astype('uint8') -def compass_directions(n: int = 4, start_vect: np.ndarray = RIGHT) -> np.ndarray: +def compass_directions(n: int = 4, start_vect: Vect3 = RIGHT) -> Vect3: angle = TAU / n return np.array([ rotate_vector(start_vect, k * angle) @@ -212,36 +218,32 @@ def compass_directions(n: int = 4, start_vect: np.ndarray = RIGHT) -> np.ndarray ]) -def complex_to_R3(complex_num: complex) -> np.ndarray: +def complex_to_R3(complex_num: complex) -> Vect3: return np.array((complex_num.real, complex_num.imag, 0)) -def R3_to_complex(point: Sequence[float]) -> complex: +def R3_to_complex(point: Vect3) -> complex: return complex(*point[:2]) -def complex_func_to_R3_func( - complex_func: Callable[[complex], complex] -) -> Callable[[np.ndarray], np.ndarray]: - return lambda p: complex_to_R3(complex_func(R3_to_complex(p))) +def complex_func_to_R3_func(complex_func: Callable[[complex], complex]) -> Callable[[Vect3], Vect3]: + def result(p: Vect3): + return complex_to_R3(complex_func(R3_to_complex(p))) + return result -def center_of_mass(points: Iterable[npt.ArrayLike]) -> np.ndarray: - points = [np.array(point).astype("float") for point in points] - return sum(points) / len(points) +def center_of_mass(points: Sequence[Vect3]) -> Vect3: + return np.array(points).sum(0) / len(points) -def midpoint( - point1: Sequence[float], - point2: Sequence[float] -) -> np.ndarray: +def midpoint(point1: VectN, point2: VectN) -> VectN: return center_of_mass([point1, point2]) def line_intersection( - line1: Sequence[Sequence[float]], - line2: Sequence[Sequence[float]] -) -> np.ndarray: + line1: Tuple[Vect3, Vect3], + line2: Tuple[Vect3, Vect3] +) -> Vect3: """ return intersection point of two lines, each defined with a pair of vectors determining @@ -263,12 +265,12 @@ def line_intersection( def find_intersection( - p0: npt.ArrayLike, - v0: npt.ArrayLike, - p1: npt.ArrayLike, - v1: npt.ArrayLike, + p0: Vect3, + v0: Vect3, + p1: Vect3, + v1: Vect3, threshold: float = 1e-5 -) -> np.ndarray: +) -> Vect3: """ Return the intersection of a line passing through p0 in direction v0 with one passing through p1 in direction v1. (Or array of intersections @@ -300,11 +302,7 @@ def find_intersection( return result -def get_closest_point_on_line( - a: np.ndarray, - b: np.ndarray, - p: np.ndarray -) -> np.ndarray: +def get_closest_point_on_line(a: VectN, b: VectN, p: VectN) -> VectN: """ It returns point x such that x is on line ab and xp is perpendicular to ab. @@ -319,7 +317,7 @@ def get_closest_point_on_line( return ((t * a) + ((1 - t) * b)) -def get_winding_number(points: Iterable[float]) -> float: +def get_winding_number(points: Sequence[Vect2 | Vect3]) -> float: total_angle = 0 for p1, p2 in adjacent_pairs(points): d_angle = angle_of_vector(p2) - angle_of_vector(p1) @@ -330,7 +328,7 @@ def get_winding_number(points: Iterable[float]) -> float: ## -def cross2d(a: np.ndarray, b: np.ndarray) -> np.ndarray: +def cross2d(a: Vect2, b: Vect2) -> Vect2: if len(a.shape) == 2: return a[:, 0] * b[:, 1] - a[:, 1] * b[:, 0] else: @@ -338,9 +336,9 @@ def cross2d(a: np.ndarray, b: np.ndarray) -> np.ndarray: def tri_area( - a: Sequence[float], - b: Sequence[float], - c: Sequence[float] + a: Vect2, + b: Vect2, + c: Vect2 ) -> float: return 0.5 * abs( a[0] * (b[1] - c[1]) + @@ -350,10 +348,10 @@ def tri_area( def is_inside_triangle( - p: np.ndarray, - a: np.ndarray, - b: np.ndarray, - c: np.ndarray + p: Vect2, + a: Vect2, + b: Vect2, + c: Vect2 ) -> bool: """ Test if point p is inside triangle abc @@ -363,15 +361,15 @@ def is_inside_triangle( cross2d(p - b, c - p), cross2d(p - c, a - p), ]) - return np.all(crosses > 0) or np.all(crosses < 0) + return bool(np.all(crosses > 0) or np.all(crosses < 0)) -def norm_squared(v: Sequence[float]) -> float: - return v[0] * v[0] + v[1] * v[1] + v[2] * v[2] +def norm_squared(v: VectN | List[float]) -> float: + return sum(x * x for x in v) # TODO, fails for polygons drawn over themselves -def earclip_triangulation(verts: np.ndarray, ring_ends: list[int]) -> list: +def earclip_triangulation(verts: Vect2 | Vect3, ring_ends: list[int]) -> list[int]: """ Returns a list of indices giving a triangulation of a polygon, potentially with holes