Tweak type hints for bezier.py

This commit is contained in:
Grant Sanderson 2022-12-17 18:35:26 -08:00
parent 810f2c67ab
commit 24fd6d890e

View file

@ -12,22 +12,23 @@ from manimlib.utils.space_ops import midpoint
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Callable, Iterable, Sequence, TypeVar
import numpy.typing as npt
from typing import Callable, Sequence, TypeVar
from manimlib.typing import VectN, FloatArray
T = TypeVar("T")
Scalable = float | VectN
CLOSED_THRESHOLD = 0.001
def bezier(
points: Iterable[float | np.ndarray]
) -> Callable[[float], float | np.ndarray]:
points: Sequence[Scalable]
) -> Callable[[float], Scalable]:
n = len(points) - 1
def result(t):
def result(t: Scalable) -> Scalable:
return sum(
((1 - t)**(n - k)) * (t**k) * choose(n, k) * point
for k, point in enumerate(points)
@ -37,10 +38,10 @@ def bezier(
def partial_bezier_points(
points: Sequence[np.ndarray],
points: Sequence[Scalable],
a: float,
b: float
) -> list[float]:
) -> list[Scalable]:
"""
Given an list of points which define
a bezier curve, and two numbers 0<=a<b<=1,
@ -67,10 +68,10 @@ def partial_bezier_points(
# Shortened version of partial_bezier_points just for quadratics,
# since this is called a fair amount
def partial_quadratic_bezier_points(
points: Sequence[np.ndarray],
points: Sequence[Scalable],
a: float,
b: float
) -> list[float]:
) -> list[Scalable]:
if a == 1:
return 3 * [points[-1]]
@ -88,7 +89,7 @@ def partial_quadratic_bezier_points(
# Linear interpolation variants
def interpolate(start: T, end: T, alpha: np.ndarray | float) -> T:
def interpolate(start: Scalable, end: Scalable, alpha: Scalable) -> Scalable:
try:
return (1 - alpha) * start + alpha * end
except TypeError:
@ -100,10 +101,10 @@ def interpolate(start: T, end: T, alpha: np.ndarray | float) -> T:
def outer_interpolate(
start: np.ndarray | float,
end: np.ndarray | float,
alpha: np.ndarray | float,
) -> T:
start: Scalable,
end: Scalable,
alpha: Scalable,
) -> np.ndarray:
result = np.outer(1 - alpha, start) + np.outer(alpha, end)
return result.reshape((*np.shape(alpha), *np.shape(start)))
@ -120,8 +121,8 @@ def set_array_by_interpolation(
def integer_interpolate(
start: T,
end: T,
start: int,
end: int,
alpha: float
) -> tuple[int, float]:
"""
@ -144,21 +145,21 @@ def integer_interpolate(
return (value, residue)
def mid(start: T, end: T) -> T:
def mid(start: Scalable, end: Scalable) -> Scalable:
return (start + end) / 2.0
def inverse_interpolate(start: T, end: T, value: T) -> float:
def inverse_interpolate(start: Scalable, end: Scalable, value: Scalable) -> Scalable:
return np.true_divide(value - start, end - start)
def match_interpolate(
new_start: T,
new_end: T,
old_start: T,
old_end: T,
old_value: T
) -> T:
new_start: Scalable,
new_end: Scalable,
old_start: Scalable,
old_end: Scalable,
old_value: Scalable
) -> Scalable:
return interpolate(
new_start, new_end,
inverse_interpolate(old_start, old_end, old_value)
@ -166,8 +167,8 @@ def match_interpolate(
def get_smooth_quadratic_bezier_handle_points(
points: Sequence[np.ndarray]
) -> np.ndarray | list[np.ndarray]:
points: FloatArray
) -> FloatArray:
"""
Figuring out which bezier curves most smoothly connect a sequence of points.
@ -200,8 +201,8 @@ def get_smooth_quadratic_bezier_handle_points(
def get_smooth_cubic_bezier_handle_points(
points: npt.ArrayLike
) -> tuple[np.ndarray, np.ndarray]:
points: Sequence[VectN]
) -> tuple[FloatArray, FloatArray]:
points = np.array(points)
num_handles = len(points) - 1
dim = points.shape[1]
@ -279,17 +280,17 @@ def diag_to_matrix(
return matrix
def is_closed(points: Sequence[np.ndarray]) -> bool:
def is_closed(points: FloatArray) -> bool:
return np.allclose(points[0], points[-1])
# Given 4 control points for a cubic bezier curve (or arrays of such)
# return control points for 2 quadratics (or 2n quadratics) approximating them.
def get_quadratic_approximation_of_cubic(
a0: npt.ArrayLike,
h0: npt.ArrayLike,
h1: npt.ArrayLike,
a1: npt.ArrayLike
a0: FloatArray,
h0: FloatArray,
h1: FloatArray,
a1: FloatArray
) -> np.ndarray:
a0 = np.array(a0, ndmin=2)
h0 = np.array(h0, ndmin=2)
@ -359,7 +360,7 @@ def get_quadratic_approximation_of_cubic(
def get_smooth_quadratic_bezier_path_through(
points: list[np.ndarray]
points: Sequence[VectN]
) -> np.ndarray:
# TODO
h0, h1 = get_smooth_cubic_bezier_handle_points(points)