Update some type hints in bezier

This commit is contained in:
Grant Sanderson 2023-01-11 14:19:17 -08:00
parent 2808710d60
commit 40b9e22b6e

View file

@ -12,8 +12,8 @@ from manimlib.utils.space_ops import midpoint
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Callable, Sequence, TypeVar
from manimlib.typing import VectN, FloatArray
from typing import Callable, Sequence, TypeVar, Tuple
from manimlib.typing import VectN, FloatArray, VectNArray
Scalable = TypeVar("Scalable", float, FloatArray)
@ -22,7 +22,7 @@ CLOSED_THRESHOLD = 0.001
def bezier(
points: Sequence[Scalable]
points: Sequence[Scalable] | VectNArray
) -> Callable[[float], Scalable]:
if len(points) == 0:
raise Exception("bezier cannot be calld on an empty list")
@ -69,10 +69,10 @@ def partial_bezier_points(
# Shortened version of partial_bezier_points just for quadratics,
# since this is called a fair amount
def partial_quadratic_bezier_points(
points: Sequence[Scalable],
points: Sequence[VectN] | VectNArray,
a: float,
b: float
) -> list[Scalable]:
) -> list[VectN]:
if a == 1:
return 3 * [points[-1]]
@ -202,7 +202,7 @@ def get_smooth_quadratic_bezier_handle_points(
def get_smooth_cubic_bezier_handle_points(
points: Sequence[VectN]
points: Sequence[VectN] | VectNArray
) -> tuple[FloatArray, FloatArray]:
points = np.array(points)
num_handles = len(points) - 1
@ -292,7 +292,7 @@ def get_quadratic_approximation_of_cubic(
h0: FloatArray,
h1: FloatArray,
a1: FloatArray
) -> np.ndarray:
) -> FloatArray:
a0 = np.array(a0, ndmin=2)
h0 = np.array(h0, ndmin=2)
h1 = np.array(h1, ndmin=2)