Tweak type hints for bezier.py

This commit is contained in:
Grant Sanderson 2022-12-17 18:35:26 -08:00
parent 810f2c67ab
commit 24fd6d890e

View file

@ -12,22 +12,23 @@ from manimlib.utils.space_ops import midpoint
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Callable, Iterable, Sequence, TypeVar from typing import Callable, Sequence, TypeVar
from manimlib.typing import VectN, FloatArray
import numpy.typing as npt
T = TypeVar("T") T = TypeVar("T")
Scalable = float | VectN
CLOSED_THRESHOLD = 0.001 CLOSED_THRESHOLD = 0.001
def bezier( def bezier(
points: Iterable[float | np.ndarray] points: Sequence[Scalable]
) -> Callable[[float], float | np.ndarray]: ) -> Callable[[float], Scalable]:
n = len(points) - 1 n = len(points) - 1
def result(t): def result(t: Scalable) -> Scalable:
return sum( return sum(
((1 - t)**(n - k)) * (t**k) * choose(n, k) * point ((1 - t)**(n - k)) * (t**k) * choose(n, k) * point
for k, point in enumerate(points) for k, point in enumerate(points)
@ -37,10 +38,10 @@ def bezier(
def partial_bezier_points( def partial_bezier_points(
points: Sequence[np.ndarray], points: Sequence[Scalable],
a: float, a: float,
b: float b: float
) -> list[float]: ) -> list[Scalable]:
""" """
Given an list of points which define Given an list of points which define
a bezier curve, and two numbers 0<=a<b<=1, a bezier curve, and two numbers 0<=a<b<=1,
@ -67,10 +68,10 @@ def partial_bezier_points(
# Shortened version of partial_bezier_points just for quadratics, # Shortened version of partial_bezier_points just for quadratics,
# since this is called a fair amount # since this is called a fair amount
def partial_quadratic_bezier_points( def partial_quadratic_bezier_points(
points: Sequence[np.ndarray], points: Sequence[Scalable],
a: float, a: float,
b: float b: float
) -> list[float]: ) -> list[Scalable]:
if a == 1: if a == 1:
return 3 * [points[-1]] return 3 * [points[-1]]
@ -88,7 +89,7 @@ def partial_quadratic_bezier_points(
# Linear interpolation variants # Linear interpolation variants
def interpolate(start: T, end: T, alpha: np.ndarray | float) -> T: def interpolate(start: Scalable, end: Scalable, alpha: Scalable) -> Scalable:
try: try:
return (1 - alpha) * start + alpha * end return (1 - alpha) * start + alpha * end
except TypeError: except TypeError:
@ -100,10 +101,10 @@ def interpolate(start: T, end: T, alpha: np.ndarray | float) -> T:
def outer_interpolate( def outer_interpolate(
start: np.ndarray | float, start: Scalable,
end: np.ndarray | float, end: Scalable,
alpha: np.ndarray | float, alpha: Scalable,
) -> T: ) -> np.ndarray:
result = np.outer(1 - alpha, start) + np.outer(alpha, end) result = np.outer(1 - alpha, start) + np.outer(alpha, end)
return result.reshape((*np.shape(alpha), *np.shape(start))) return result.reshape((*np.shape(alpha), *np.shape(start)))
@ -120,8 +121,8 @@ def set_array_by_interpolation(
def integer_interpolate( def integer_interpolate(
start: T, start: int,
end: T, end: int,
alpha: float alpha: float
) -> tuple[int, float]: ) -> tuple[int, float]:
""" """
@ -144,21 +145,21 @@ def integer_interpolate(
return (value, residue) return (value, residue)
def mid(start: T, end: T) -> T: def mid(start: Scalable, end: Scalable) -> Scalable:
return (start + end) / 2.0 return (start + end) / 2.0
def inverse_interpolate(start: T, end: T, value: T) -> float: def inverse_interpolate(start: Scalable, end: Scalable, value: Scalable) -> Scalable:
return np.true_divide(value - start, end - start) return np.true_divide(value - start, end - start)
def match_interpolate( def match_interpolate(
new_start: T, new_start: Scalable,
new_end: T, new_end: Scalable,
old_start: T, old_start: Scalable,
old_end: T, old_end: Scalable,
old_value: T old_value: Scalable
) -> T: ) -> Scalable:
return interpolate( return interpolate(
new_start, new_end, new_start, new_end,
inverse_interpolate(old_start, old_end, old_value) inverse_interpolate(old_start, old_end, old_value)
@ -166,8 +167,8 @@ def match_interpolate(
def get_smooth_quadratic_bezier_handle_points( def get_smooth_quadratic_bezier_handle_points(
points: Sequence[np.ndarray] points: FloatArray
) -> np.ndarray | list[np.ndarray]: ) -> FloatArray:
""" """
Figuring out which bezier curves most smoothly connect a sequence of points. Figuring out which bezier curves most smoothly connect a sequence of points.
@ -200,8 +201,8 @@ def get_smooth_quadratic_bezier_handle_points(
def get_smooth_cubic_bezier_handle_points( def get_smooth_cubic_bezier_handle_points(
points: npt.ArrayLike points: Sequence[VectN]
) -> tuple[np.ndarray, np.ndarray]: ) -> tuple[FloatArray, FloatArray]:
points = np.array(points) points = np.array(points)
num_handles = len(points) - 1 num_handles = len(points) - 1
dim = points.shape[1] dim = points.shape[1]
@ -279,17 +280,17 @@ def diag_to_matrix(
return matrix return matrix
def is_closed(points: Sequence[np.ndarray]) -> bool: def is_closed(points: FloatArray) -> bool:
return np.allclose(points[0], points[-1]) return np.allclose(points[0], points[-1])
# Given 4 control points for a cubic bezier curve (or arrays of such) # Given 4 control points for a cubic bezier curve (or arrays of such)
# return control points for 2 quadratics (or 2n quadratics) approximating them. # return control points for 2 quadratics (or 2n quadratics) approximating them.
def get_quadratic_approximation_of_cubic( def get_quadratic_approximation_of_cubic(
a0: npt.ArrayLike, a0: FloatArray,
h0: npt.ArrayLike, h0: FloatArray,
h1: npt.ArrayLike, h1: FloatArray,
a1: npt.ArrayLike a1: FloatArray
) -> np.ndarray: ) -> np.ndarray:
a0 = np.array(a0, ndmin=2) a0 = np.array(a0, ndmin=2)
h0 = np.array(h0, ndmin=2) h0 = np.array(h0, ndmin=2)
@ -359,7 +360,7 @@ def get_quadratic_approximation_of_cubic(
def get_smooth_quadratic_bezier_path_through( def get_smooth_quadratic_bezier_path_through(
points: list[np.ndarray] points: Sequence[VectN]
) -> np.ndarray: ) -> np.ndarray:
# TODO # TODO
h0, h1 = get_smooth_cubic_bezier_handle_points(points) h0, h1 = get_smooth_cubic_bezier_handle_points(points)