mirror of
https://github.com/3b1b/manim.git
synced 2025-08-05 16:49:03 +00:00
Tweak type hints for bezier.py
This commit is contained in:
parent
810f2c67ab
commit
24fd6d890e
1 changed files with 36 additions and 35 deletions
|
@ -12,22 +12,23 @@ from manimlib.utils.space_ops import midpoint
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Callable, Iterable, Sequence, TypeVar
|
||||
|
||||
import numpy.typing as npt
|
||||
from typing import Callable, Sequence, TypeVar
|
||||
from manimlib.typing import VectN, FloatArray
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
Scalable = float | VectN
|
||||
|
||||
|
||||
CLOSED_THRESHOLD = 0.001
|
||||
|
||||
|
||||
def bezier(
|
||||
points: Iterable[float | np.ndarray]
|
||||
) -> Callable[[float], float | np.ndarray]:
|
||||
points: Sequence[Scalable]
|
||||
) -> Callable[[float], Scalable]:
|
||||
n = len(points) - 1
|
||||
|
||||
def result(t):
|
||||
def result(t: Scalable) -> Scalable:
|
||||
return sum(
|
||||
((1 - t)**(n - k)) * (t**k) * choose(n, k) * point
|
||||
for k, point in enumerate(points)
|
||||
|
@ -37,10 +38,10 @@ def bezier(
|
|||
|
||||
|
||||
def partial_bezier_points(
|
||||
points: Sequence[np.ndarray],
|
||||
points: Sequence[Scalable],
|
||||
a: float,
|
||||
b: float
|
||||
) -> list[float]:
|
||||
) -> list[Scalable]:
|
||||
"""
|
||||
Given an list of points which define
|
||||
a bezier curve, and two numbers 0<=a<b<=1,
|
||||
|
@ -67,10 +68,10 @@ def partial_bezier_points(
|
|||
# Shortened version of partial_bezier_points just for quadratics,
|
||||
# since this is called a fair amount
|
||||
def partial_quadratic_bezier_points(
|
||||
points: Sequence[np.ndarray],
|
||||
points: Sequence[Scalable],
|
||||
a: float,
|
||||
b: float
|
||||
) -> list[float]:
|
||||
) -> list[Scalable]:
|
||||
if a == 1:
|
||||
return 3 * [points[-1]]
|
||||
|
||||
|
@ -88,7 +89,7 @@ def partial_quadratic_bezier_points(
|
|||
# Linear interpolation variants
|
||||
|
||||
|
||||
def interpolate(start: T, end: T, alpha: np.ndarray | float) -> T:
|
||||
def interpolate(start: Scalable, end: Scalable, alpha: Scalable) -> Scalable:
|
||||
try:
|
||||
return (1 - alpha) * start + alpha * end
|
||||
except TypeError:
|
||||
|
@ -100,10 +101,10 @@ def interpolate(start: T, end: T, alpha: np.ndarray | float) -> T:
|
|||
|
||||
|
||||
def outer_interpolate(
|
||||
start: np.ndarray | float,
|
||||
end: np.ndarray | float,
|
||||
alpha: np.ndarray | float,
|
||||
) -> T:
|
||||
start: Scalable,
|
||||
end: Scalable,
|
||||
alpha: Scalable,
|
||||
) -> np.ndarray:
|
||||
result = np.outer(1 - alpha, start) + np.outer(alpha, end)
|
||||
return result.reshape((*np.shape(alpha), *np.shape(start)))
|
||||
|
||||
|
@ -120,8 +121,8 @@ def set_array_by_interpolation(
|
|||
|
||||
|
||||
def integer_interpolate(
|
||||
start: T,
|
||||
end: T,
|
||||
start: int,
|
||||
end: int,
|
||||
alpha: float
|
||||
) -> tuple[int, float]:
|
||||
"""
|
||||
|
@ -144,21 +145,21 @@ def integer_interpolate(
|
|||
return (value, residue)
|
||||
|
||||
|
||||
def mid(start: T, end: T) -> T:
|
||||
def mid(start: Scalable, end: Scalable) -> Scalable:
|
||||
return (start + end) / 2.0
|
||||
|
||||
|
||||
def inverse_interpolate(start: T, end: T, value: T) -> float:
|
||||
def inverse_interpolate(start: Scalable, end: Scalable, value: Scalable) -> Scalable:
|
||||
return np.true_divide(value - start, end - start)
|
||||
|
||||
|
||||
def match_interpolate(
|
||||
new_start: T,
|
||||
new_end: T,
|
||||
old_start: T,
|
||||
old_end: T,
|
||||
old_value: T
|
||||
) -> T:
|
||||
new_start: Scalable,
|
||||
new_end: Scalable,
|
||||
old_start: Scalable,
|
||||
old_end: Scalable,
|
||||
old_value: Scalable
|
||||
) -> Scalable:
|
||||
return interpolate(
|
||||
new_start, new_end,
|
||||
inverse_interpolate(old_start, old_end, old_value)
|
||||
|
@ -166,8 +167,8 @@ def match_interpolate(
|
|||
|
||||
|
||||
def get_smooth_quadratic_bezier_handle_points(
|
||||
points: Sequence[np.ndarray]
|
||||
) -> np.ndarray | list[np.ndarray]:
|
||||
points: FloatArray
|
||||
) -> FloatArray:
|
||||
"""
|
||||
Figuring out which bezier curves most smoothly connect a sequence of points.
|
||||
|
||||
|
@ -200,8 +201,8 @@ def get_smooth_quadratic_bezier_handle_points(
|
|||
|
||||
|
||||
def get_smooth_cubic_bezier_handle_points(
|
||||
points: npt.ArrayLike
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
points: Sequence[VectN]
|
||||
) -> tuple[FloatArray, FloatArray]:
|
||||
points = np.array(points)
|
||||
num_handles = len(points) - 1
|
||||
dim = points.shape[1]
|
||||
|
@ -279,17 +280,17 @@ def diag_to_matrix(
|
|||
return matrix
|
||||
|
||||
|
||||
def is_closed(points: Sequence[np.ndarray]) -> bool:
|
||||
def is_closed(points: FloatArray) -> bool:
|
||||
return np.allclose(points[0], points[-1])
|
||||
|
||||
|
||||
# Given 4 control points for a cubic bezier curve (or arrays of such)
|
||||
# return control points for 2 quadratics (or 2n quadratics) approximating them.
|
||||
def get_quadratic_approximation_of_cubic(
|
||||
a0: npt.ArrayLike,
|
||||
h0: npt.ArrayLike,
|
||||
h1: npt.ArrayLike,
|
||||
a1: npt.ArrayLike
|
||||
a0: FloatArray,
|
||||
h0: FloatArray,
|
||||
h1: FloatArray,
|
||||
a1: FloatArray
|
||||
) -> np.ndarray:
|
||||
a0 = np.array(a0, ndmin=2)
|
||||
h0 = np.array(h0, ndmin=2)
|
||||
|
@ -359,7 +360,7 @@ def get_quadratic_approximation_of_cubic(
|
|||
|
||||
|
||||
def get_smooth_quadratic_bezier_path_through(
|
||||
points: list[np.ndarray]
|
||||
points: Sequence[VectN]
|
||||
) -> np.ndarray:
|
||||
# TODO
|
||||
h0, h1 = get_smooth_cubic_bezier_handle_points(points)
|
||||
|
|
Loading…
Add table
Reference in a new issue