diff --git a/manimlib/utils/space_ops.py b/manimlib/utils/space_ops.py index 2958aa4d..b77587a2 100644 --- a/manimlib/utils/space_ops.py +++ b/manimlib/utils/space_ops.py @@ -22,12 +22,20 @@ if TYPE_CHECKING: from manimlib.typing import Vect2, Vect3, Vect4, VectN, Matrix3x3, Vect3Array, Vect2Array -def cross(v1: Vect3 | List[float], v2: Vect3 | List[float]) -> Vect3: - return np.array([ - v1[1] * v2[2] - v1[2] * v2[1], - v1[2] * v2[0] - v1[0] * v2[2], - v1[0] * v2[1] - v1[1] * v2[0] +def cross(v1: Vect3 | List[float], v2: Vect3 | List[float]) -> Vect3 | Vect3Array: + is2d = isinstance(v1, np.ndarray) and len(v1.shape) == 2 + if is2d: + x1, y1, z1 = v1[:, 0], v1[:, 1], v1[:, 2] + x2, y2, z2 = v2[:, 0], v2[:, 1], v2[:, 2] + else: + x1, y1, z1 = v1 + x2, y2, z2 = v2 + result = np.array([ + y1 * z2 - z1 * y2, + z1 * x2 - x1 * z2, + x1 * y2 - y1 * x2, ]) + return result.T if is2d else result def get_norm(vect: VectN | List[float]) -> float: @@ -292,8 +300,8 @@ def find_intersection( m, n = np.shape(p0) assert(n in [2, 3]) - numer = np.cross(v1, p1 - p0) - denom = np.cross(v1, v0) + numer = cross(v1, p1 - p0) + denom = cross(v1, v0) if n == 3: d = len(np.shape(numer)) new_numer = np.multiply(numer, numer).sum(d - 1)