diff --git a/manimlib/utils/bezier.py b/manimlib/utils/bezier.py index 293b2228..9f18e0a3 100644 --- a/manimlib/utils/bezier.py +++ b/manimlib/utils/bezier.py @@ -12,22 +12,23 @@ from manimlib.utils.space_ops import midpoint from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Callable, Iterable, Sequence, TypeVar - - import numpy.typing as npt + from typing import Callable, Sequence, TypeVar + from manimlib.typing import VectN, FloatArray T = TypeVar("T") + Scalable = float | VectN + CLOSED_THRESHOLD = 0.001 def bezier( - points: Iterable[float | np.ndarray] -) -> Callable[[float], float | np.ndarray]: + points: Sequence[Scalable] +) -> Callable[[float], Scalable]: n = len(points) - 1 - def result(t): + def result(t: Scalable) -> Scalable: return sum( ((1 - t)**(n - k)) * (t**k) * choose(n, k) * point for k, point in enumerate(points) @@ -37,10 +38,10 @@ def bezier( def partial_bezier_points( - points: Sequence[np.ndarray], + points: Sequence[Scalable], a: float, b: float -) -> list[float]: +) -> list[Scalable]: """ Given an list of points which define a bezier curve, and two numbers 0<=a list[float]: +) -> list[Scalable]: if a == 1: return 3 * [points[-1]] @@ -88,7 +89,7 @@ def partial_quadratic_bezier_points( # Linear interpolation variants -def interpolate(start: T, end: T, alpha: np.ndarray | float) -> T: +def interpolate(start: Scalable, end: Scalable, alpha: Scalable) -> Scalable: try: return (1 - alpha) * start + alpha * end except TypeError: @@ -100,10 +101,10 @@ def interpolate(start: T, end: T, alpha: np.ndarray | float) -> T: def outer_interpolate( - start: np.ndarray | float, - end: np.ndarray | float, - alpha: np.ndarray | float, -) -> T: + start: Scalable, + end: Scalable, + alpha: Scalable, +) -> np.ndarray: result = np.outer(1 - alpha, start) + np.outer(alpha, end) return result.reshape((*np.shape(alpha), *np.shape(start))) @@ -120,8 +121,8 @@ def set_array_by_interpolation( def integer_interpolate( - start: T, - end: T, + start: int, + end: int, alpha: float ) -> tuple[int, float]: """ @@ -144,21 +145,21 @@ def integer_interpolate( return (value, residue) -def mid(start: T, end: T) -> T: +def mid(start: Scalable, end: Scalable) -> Scalable: return (start + end) / 2.0 -def inverse_interpolate(start: T, end: T, value: T) -> float: +def inverse_interpolate(start: Scalable, end: Scalable, value: Scalable) -> Scalable: return np.true_divide(value - start, end - start) def match_interpolate( - new_start: T, - new_end: T, - old_start: T, - old_end: T, - old_value: T -) -> T: + new_start: Scalable, + new_end: Scalable, + old_start: Scalable, + old_end: Scalable, + old_value: Scalable +) -> Scalable: return interpolate( new_start, new_end, inverse_interpolate(old_start, old_end, old_value) @@ -166,8 +167,8 @@ def match_interpolate( def get_smooth_quadratic_bezier_handle_points( - points: Sequence[np.ndarray] -) -> np.ndarray | list[np.ndarray]: + points: FloatArray +) -> FloatArray: """ Figuring out which bezier curves most smoothly connect a sequence of points. @@ -200,8 +201,8 @@ def get_smooth_quadratic_bezier_handle_points( def get_smooth_cubic_bezier_handle_points( - points: npt.ArrayLike -) -> tuple[np.ndarray, np.ndarray]: + points: Sequence[VectN] +) -> tuple[FloatArray, FloatArray]: points = np.array(points) num_handles = len(points) - 1 dim = points.shape[1] @@ -279,17 +280,17 @@ def diag_to_matrix( return matrix -def is_closed(points: Sequence[np.ndarray]) -> bool: +def is_closed(points: FloatArray) -> bool: return np.allclose(points[0], points[-1]) # Given 4 control points for a cubic bezier curve (or arrays of such) # return control points for 2 quadratics (or 2n quadratics) approximating them. def get_quadratic_approximation_of_cubic( - a0: npt.ArrayLike, - h0: npt.ArrayLike, - h1: npt.ArrayLike, - a1: npt.ArrayLike + a0: FloatArray, + h0: FloatArray, + h1: FloatArray, + a1: FloatArray ) -> np.ndarray: a0 = np.array(a0, ndmin=2) h0 = np.array(h0, ndmin=2) @@ -359,7 +360,7 @@ def get_quadratic_approximation_of_cubic( def get_smooth_quadratic_bezier_path_through( - points: list[np.ndarray] + points: Sequence[VectN] ) -> np.ndarray: # TODO h0, h1 = get_smooth_cubic_bezier_handle_points(points)