Add better types + Small refactors on space_ops

This commit is contained in:
Grant Sanderson 2022-12-16 20:35:45 -08:00
parent dec11a4b17
commit cef6506920

View file

@ -18,29 +18,31 @@ from manimlib.utils.simple_functions import clip
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Callable, Iterable, Sequence from typing import Callable, Sequence, List, Tuple
from manimlib.typing import ManimColor, Vect2, Vect3, Vect4, VectN, Matrix3x3
import numpy.typing as npt
def cross(v1: np.ndarray, v2: np.ndarray) -> list[np.ndarray]: def cross(v1: Vect3 | List[float], v2: Vect3 | List[float]) -> Vect3:
return [ return np.array([
v1[1] * v2[2] - v1[2] * v2[1], v1[1] * v2[2] - v1[2] * v2[1],
v1[2] * v2[0] - v1[0] * v2[2], v1[2] * v2[0] - v1[0] * v2[2],
v1[0] * v2[1] - v1[1] * v2[0] v1[0] * v2[1] - v1[1] * v2[0]
] ])
def get_norm(vect: Iterable) -> float: def get_norm(vect: VectN | List[float]) -> float:
return sum((x**2 for x in vect))**0.5 return sum((x**2 for x in vect))**0.5
def normalize(vect: np.ndarray, fall_back: np.ndarray | None = None) -> np.ndarray: def normalize(
vect: VectN | List[float],
fall_back: VectN | List[float] | None = None
) -> VectN:
norm = get_norm(vect) norm = get_norm(vect)
if norm > 0: if norm > 0:
return np.array(vect) / norm return np.array(vect) / norm
elif fall_back is not None: elif fall_back is not None:
return fall_back return np.array(fall_back)
else: else:
return np.zeros(len(vect)) return np.zeros(len(vect))
@ -48,15 +50,18 @@ def normalize(vect: np.ndarray, fall_back: np.ndarray | None = None) -> np.ndarr
# Operations related to rotation # Operations related to rotation
def quaternion_mult(*quats: Sequence[float]) -> list[float]: def quaternion_mult(*quats: Vect4) -> Vect4:
# Real part is last entry, which is bizzare, but fits scipy Rotation convention """
Inputs are treated as quaternions, where the real part is the
last entry, so as to follow the scipy Rotation conventions.
"""
if len(quats) == 0: if len(quats) == 0:
return [0, 0, 0, 1] return np.array([0, 0, 0, 1])
result = quats[0] result = np.array(quats[0])
for next_quat in quats[1:]: for next_quat in quats[1:]:
x1, y1, z1, w1 = result x1, y1, z1, w1 = result
x2, y2, z2, w2 = next_quat x2, y2, z2, w2 = next_quat
result = [ result[:] = [
w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2, w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2, w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2,
w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2, w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2,
@ -67,67 +72,68 @@ def quaternion_mult(*quats: Sequence[float]) -> list[float]:
def quaternion_from_angle_axis( def quaternion_from_angle_axis(
angle: float, angle: float,
axis: np.ndarray, axis: Vect3,
) -> list[float]: ) -> Vect4:
return Rotation.from_rotvec(angle * normalize(axis)).as_quat() return Rotation.from_rotvec(angle * normalize(axis)).as_quat()
def angle_axis_from_quaternion(quat: Sequence[float]) -> tuple[float, np.ndarray]: def angle_axis_from_quaternion(quat: Vect4) -> Tuple[float, Vect3]:
rot_vec = Rotation.from_quat(quat).as_rotvec() rot_vec = Rotation.from_quat(quat).as_rotvec()
norm = get_norm(rot_vec) norm = get_norm(rot_vec)
return norm, rot_vec / norm return norm, rot_vec / norm
def quaternion_conjugate(quaternion: Iterable) -> list: def quaternion_conjugate(quaternion: Vect4) -> Vect4:
result = list(quaternion) result = np.array(quaternion)
for i in range(3): result[:3] *= -1
result[i] *= -1
return result return result
def rotate_vector( def rotate_vector(
vector: Iterable, vector: Vect3,
angle: float, angle: float,
axis: np.ndarray = OUT axis: Vect3 = OUT
) -> np.ndarray | list[float]: ) -> Vect3:
rot = Rotation.from_rotvec(angle * normalize(axis)) rot = Rotation.from_rotvec(angle * normalize(axis))
return np.dot(vector, rot.as_matrix().T) return np.dot(vector, rot.as_matrix().T)
def rotate_vector_2d(vector: Iterable, angle: float): def rotate_vector_2d(vector: Vect2, angle: float) -> Vect2:
# Use complex numbers...because why not # Use complex numbers...because why not
z = complex(*vector) * np.exp(complex(0, angle)) z = complex(*vector) * np.exp(complex(0, angle))
return np.array([z.real, z.imag]) return np.array([z.real, z.imag])
def rotation_matrix_transpose_from_quaternion(quat: Iterable) -> np.ndarray: def rotation_matrix_transpose_from_quaternion(quat: Vect4) -> Matrix3x3:
return Rotation.from_quat(quat).as_matrix() return Rotation.from_quat(quat).as_matrix()
def rotation_matrix_from_quaternion(quat: Iterable) -> np.ndarray: def rotation_matrix_from_quaternion(quat: Vect4) -> Matrix3x3:
return np.transpose(rotation_matrix_transpose_from_quaternion(quat)) return np.transpose(rotation_matrix_transpose_from_quaternion(quat))
def rotation_matrix(angle: float, axis: np.ndarray) -> np.ndarray: def rotation_matrix(angle: float, axis: Vect3) -> Matrix3x3:
""" """
Rotation in R^3 about a specified axis of rotation. Rotation in R^3 about a specified axis of rotation.
""" """
return Rotation.from_rotvec(angle * normalize(axis)).as_matrix() return Rotation.from_rotvec(angle * normalize(axis)).as_matrix()
def rotation_matrix_transpose(angle: float, axis: np.ndarray) -> np.ndarray: def rotation_matrix_transpose(angle: float, axis: Vect3) -> Matrix3x3:
return rotation_matrix(angle, axis).T return rotation_matrix(angle, axis).T
def rotation_about_z(angle: float) -> list[list[float]]: def rotation_about_z(angle: float) -> Matrix3x3:
return [ cos_a = math.cos(angle)
[math.cos(angle), -math.sin(angle), 0], sin_a = math.sin(angle)
[math.sin(angle), math.cos(angle), 0], return np.array([
[cos_a, -sin_a, 0],
[sin_a, cos_a, 0],
[0, 0, 1] [0, 0, 1]
] ])
def rotation_between_vectors(v1, v2) -> np.ndarray: def rotation_between_vectors(v1: Vect3, v2: Vect3) -> Matrix3x3:
if np.all(np.isclose(v1, v2)): if np.all(np.isclose(v1, v2)):
return np.identity(3) return np.identity(3)
return rotation_matrix( return rotation_matrix(
@ -136,18 +142,18 @@ def rotation_between_vectors(v1, v2) -> np.ndarray:
) )
def z_to_vector(vector: np.ndarray) -> np.ndarray: def z_to_vector(vector: Vect3) -> Matrix3x3:
return rotation_between_vectors(OUT, vector) return rotation_between_vectors(OUT, vector)
def angle_of_vector(vector: Sequence[float]) -> float: def angle_of_vector(vector: Vect2 | Vect3) -> float:
""" """
Returns polar coordinate theta when vector is project on xy plane Returns polar coordinate theta when vector is project on xy plane
""" """
return np.angle(complex(*vector[:2])) return np.angle(complex(*vector[:2]))
def angle_between_vectors(v1: np.ndarray, v2: np.ndarray) -> float: def angle_between_vectors(v1: VectN, v2: VectN) -> float:
""" """
Returns the angle between two 3D vectors. Returns the angle between two 3D vectors.
This angle will always be btw 0 and pi This angle will always be btw 0 and pi
@ -160,7 +166,7 @@ def angle_between_vectors(v1: np.ndarray, v2: np.ndarray) -> float:
return math.acos(clip(cos_angle, -1, 1)) return math.acos(clip(cos_angle, -1, 1))
def project_along_vector(point: np.ndarray, vector: np.ndarray) -> np.ndarray: def project_along_vector(point: Vect3, vector: Vect3) -> Vect3:
matrix = np.identity(3) - np.outer(vector, vector) matrix = np.identity(3) - np.outer(vector, vector)
return np.dot(point, matrix.T) return np.dot(point, matrix.T)
@ -177,10 +183,10 @@ def normalize_along_axis(
def get_unit_normal( def get_unit_normal(
v1: np.ndarray, v1: Vect3,
v2: np.ndarray, v2: Vect3,
tol: float = 1e-6 tol: float = 1e-6
) -> np.ndarray: ) -> Vect3:
v1 = normalize(v1) v1 = normalize(v1)
v2 = normalize(v2) v2 = normalize(v2)
cp = cross(v1, v2) cp = cross(v1, v2)
@ -204,7 +210,7 @@ def thick_diagonal(dim: int, thickness: int = 2) -> np.ndarray:
return (np.abs(row_indices - col_indices) < thickness).astype('uint8') return (np.abs(row_indices - col_indices) < thickness).astype('uint8')
def compass_directions(n: int = 4, start_vect: np.ndarray = RIGHT) -> np.ndarray: def compass_directions(n: int = 4, start_vect: Vect3 = RIGHT) -> Vect3:
angle = TAU / n angle = TAU / n
return np.array([ return np.array([
rotate_vector(start_vect, k * angle) rotate_vector(start_vect, k * angle)
@ -212,36 +218,32 @@ def compass_directions(n: int = 4, start_vect: np.ndarray = RIGHT) -> np.ndarray
]) ])
def complex_to_R3(complex_num: complex) -> np.ndarray: def complex_to_R3(complex_num: complex) -> Vect3:
return np.array((complex_num.real, complex_num.imag, 0)) return np.array((complex_num.real, complex_num.imag, 0))
def R3_to_complex(point: Sequence[float]) -> complex: def R3_to_complex(point: Vect3) -> complex:
return complex(*point[:2]) return complex(*point[:2])
def complex_func_to_R3_func( def complex_func_to_R3_func(complex_func: Callable[[complex], complex]) -> Callable[[Vect3], Vect3]:
complex_func: Callable[[complex], complex] def result(p: Vect3):
) -> Callable[[np.ndarray], np.ndarray]: return complex_to_R3(complex_func(R3_to_complex(p)))
return lambda p: complex_to_R3(complex_func(R3_to_complex(p))) return result
def center_of_mass(points: Iterable[npt.ArrayLike]) -> np.ndarray: def center_of_mass(points: Sequence[Vect3]) -> Vect3:
points = [np.array(point).astype("float") for point in points] return np.array(points).sum(0) / len(points)
return sum(points) / len(points)
def midpoint( def midpoint(point1: VectN, point2: VectN) -> VectN:
point1: Sequence[float],
point2: Sequence[float]
) -> np.ndarray:
return center_of_mass([point1, point2]) return center_of_mass([point1, point2])
def line_intersection( def line_intersection(
line1: Sequence[Sequence[float]], line1: Tuple[Vect3, Vect3],
line2: Sequence[Sequence[float]] line2: Tuple[Vect3, Vect3]
) -> np.ndarray: ) -> Vect3:
""" """
return intersection point of two lines, return intersection point of two lines,
each defined with a pair of vectors determining each defined with a pair of vectors determining
@ -263,12 +265,12 @@ def line_intersection(
def find_intersection( def find_intersection(
p0: npt.ArrayLike, p0: Vect3,
v0: npt.ArrayLike, v0: Vect3,
p1: npt.ArrayLike, p1: Vect3,
v1: npt.ArrayLike, v1: Vect3,
threshold: float = 1e-5 threshold: float = 1e-5
) -> np.ndarray: ) -> Vect3:
""" """
Return the intersection of a line passing through p0 in direction v0 Return the intersection of a line passing through p0 in direction v0
with one passing through p1 in direction v1. (Or array of intersections with one passing through p1 in direction v1. (Or array of intersections
@ -300,11 +302,7 @@ def find_intersection(
return result return result
def get_closest_point_on_line( def get_closest_point_on_line(a: VectN, b: VectN, p: VectN) -> VectN:
a: np.ndarray,
b: np.ndarray,
p: np.ndarray
) -> np.ndarray:
""" """
It returns point x such that It returns point x such that
x is on line ab and xp is perpendicular to ab. x is on line ab and xp is perpendicular to ab.
@ -319,7 +317,7 @@ def get_closest_point_on_line(
return ((t * a) + ((1 - t) * b)) return ((t * a) + ((1 - t) * b))
def get_winding_number(points: Iterable[float]) -> float: def get_winding_number(points: Sequence[Vect2 | Vect3]) -> float:
total_angle = 0 total_angle = 0
for p1, p2 in adjacent_pairs(points): for p1, p2 in adjacent_pairs(points):
d_angle = angle_of_vector(p2) - angle_of_vector(p1) d_angle = angle_of_vector(p2) - angle_of_vector(p1)
@ -330,7 +328,7 @@ def get_winding_number(points: Iterable[float]) -> float:
## ##
def cross2d(a: np.ndarray, b: np.ndarray) -> np.ndarray: def cross2d(a: Vect2, b: Vect2) -> Vect2:
if len(a.shape) == 2: if len(a.shape) == 2:
return a[:, 0] * b[:, 1] - a[:, 1] * b[:, 0] return a[:, 0] * b[:, 1] - a[:, 1] * b[:, 0]
else: else:
@ -338,9 +336,9 @@ def cross2d(a: np.ndarray, b: np.ndarray) -> np.ndarray:
def tri_area( def tri_area(
a: Sequence[float], a: Vect2,
b: Sequence[float], b: Vect2,
c: Sequence[float] c: Vect2
) -> float: ) -> float:
return 0.5 * abs( return 0.5 * abs(
a[0] * (b[1] - c[1]) + a[0] * (b[1] - c[1]) +
@ -350,10 +348,10 @@ def tri_area(
def is_inside_triangle( def is_inside_triangle(
p: np.ndarray, p: Vect2,
a: np.ndarray, a: Vect2,
b: np.ndarray, b: Vect2,
c: np.ndarray c: Vect2
) -> bool: ) -> bool:
""" """
Test if point p is inside triangle abc Test if point p is inside triangle abc
@ -363,15 +361,15 @@ def is_inside_triangle(
cross2d(p - b, c - p), cross2d(p - b, c - p),
cross2d(p - c, a - p), cross2d(p - c, a - p),
]) ])
return np.all(crosses > 0) or np.all(crosses < 0) return bool(np.all(crosses > 0) or np.all(crosses < 0))
def norm_squared(v: Sequence[float]) -> float: def norm_squared(v: VectN | List[float]) -> float:
return v[0] * v[0] + v[1] * v[1] + v[2] * v[2] return sum(x * x for x in v)
# TODO, fails for polygons drawn over themselves # TODO, fails for polygons drawn over themselves
def earclip_triangulation(verts: np.ndarray, ring_ends: list[int]) -> list: def earclip_triangulation(verts: Vect2 | Vect3, ring_ends: list[int]) -> list[int]:
""" """
Returns a list of indices giving a triangulation Returns a list of indices giving a triangulation
of a polygon, potentially with holes of a polygon, potentially with holes