From 40b9e22b6ee4fb92639c8ebc553c9cc57874c082 Mon Sep 17 00:00:00 2001 From: Grant Sanderson Date: Wed, 11 Jan 2023 14:19:17 -0800 Subject: [PATCH] Update some type hints in bezier --- manimlib/utils/bezier.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/manimlib/utils/bezier.py b/manimlib/utils/bezier.py index 82290a1b..1ec3257d 100644 --- a/manimlib/utils/bezier.py +++ b/manimlib/utils/bezier.py @@ -12,8 +12,8 @@ from manimlib.utils.space_ops import midpoint from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Callable, Sequence, TypeVar - from manimlib.typing import VectN, FloatArray + from typing import Callable, Sequence, TypeVar, Tuple + from manimlib.typing import VectN, FloatArray, VectNArray Scalable = TypeVar("Scalable", float, FloatArray) @@ -22,7 +22,7 @@ CLOSED_THRESHOLD = 0.001 def bezier( - points: Sequence[Scalable] + points: Sequence[Scalable] | VectNArray ) -> Callable[[float], Scalable]: if len(points) == 0: raise Exception("bezier cannot be calld on an empty list") @@ -69,10 +69,10 @@ def partial_bezier_points( # Shortened version of partial_bezier_points just for quadratics, # since this is called a fair amount def partial_quadratic_bezier_points( - points: Sequence[Scalable], + points: Sequence[VectN] | VectNArray, a: float, b: float -) -> list[Scalable]: +) -> list[VectN]: if a == 1: return 3 * [points[-1]] @@ -202,7 +202,7 @@ def get_smooth_quadratic_bezier_handle_points( def get_smooth_cubic_bezier_handle_points( - points: Sequence[VectN] + points: Sequence[VectN] | VectNArray ) -> tuple[FloatArray, FloatArray]: points = np.array(points) num_handles = len(points) - 1 @@ -292,7 +292,7 @@ def get_quadratic_approximation_of_cubic( h0: FloatArray, h1: FloatArray, a1: FloatArray -) -> np.ndarray: +) -> FloatArray: a0 = np.array(a0, ndmin=2) h0 = np.array(h0, ndmin=2) h1 = np.array(h1, ndmin=2)