Update some type hints in bezier

This commit is contained in:
Grant Sanderson 2023-01-11 14:19:17 -08:00
parent 2808710d60
commit 40b9e22b6e

View file

@ -12,8 +12,8 @@ from manimlib.utils.space_ops import midpoint
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Callable, Sequence, TypeVar from typing import Callable, Sequence, TypeVar, Tuple
from manimlib.typing import VectN, FloatArray from manimlib.typing import VectN, FloatArray, VectNArray
Scalable = TypeVar("Scalable", float, FloatArray) Scalable = TypeVar("Scalable", float, FloatArray)
@ -22,7 +22,7 @@ CLOSED_THRESHOLD = 0.001
def bezier( def bezier(
points: Sequence[Scalable] points: Sequence[Scalable] | VectNArray
) -> Callable[[float], Scalable]: ) -> Callable[[float], Scalable]:
if len(points) == 0: if len(points) == 0:
raise Exception("bezier cannot be calld on an empty list") raise Exception("bezier cannot be calld on an empty list")
@ -69,10 +69,10 @@ def partial_bezier_points(
# Shortened version of partial_bezier_points just for quadratics, # Shortened version of partial_bezier_points just for quadratics,
# since this is called a fair amount # since this is called a fair amount
def partial_quadratic_bezier_points( def partial_quadratic_bezier_points(
points: Sequence[Scalable], points: Sequence[VectN] | VectNArray,
a: float, a: float,
b: float b: float
) -> list[Scalable]: ) -> list[VectN]:
if a == 1: if a == 1:
return 3 * [points[-1]] return 3 * [points[-1]]
@ -202,7 +202,7 @@ def get_smooth_quadratic_bezier_handle_points(
def get_smooth_cubic_bezier_handle_points( def get_smooth_cubic_bezier_handle_points(
points: Sequence[VectN] points: Sequence[VectN] | VectNArray
) -> tuple[FloatArray, FloatArray]: ) -> tuple[FloatArray, FloatArray]:
points = np.array(points) points = np.array(points)
num_handles = len(points) - 1 num_handles = len(points) - 1
@ -292,7 +292,7 @@ def get_quadratic_approximation_of_cubic(
h0: FloatArray, h0: FloatArray,
h1: FloatArray, h1: FloatArray,
a1: FloatArray a1: FloatArray
) -> np.ndarray: ) -> FloatArray:
a0 = np.array(a0, ndmin=2) a0 = np.array(a0, ndmin=2)
h0 = np.array(h0, ndmin=2) h0 = np.array(h0, ndmin=2)
h1 = np.array(h1, ndmin=2) h1 = np.array(h1, ndmin=2)