mirror of
https://github.com/3b1b/manim.git
synced 2025-09-01 00:48:45 +00:00
commit
859680d5ab
2 changed files with 14 additions and 9 deletions
|
@ -7,6 +7,7 @@ from manimlib.mobject.geometry import Line
|
|||
from manimlib.mobject.numbers import DecimalNumber
|
||||
from manimlib.mobject.types.vectorized_mobject import VGroup
|
||||
from manimlib.utils.bezier import interpolate
|
||||
from manimlib.utils.bezier import outer_interpolate
|
||||
from manimlib.utils.config_ops import digest_config
|
||||
from manimlib.utils.config_ops import merge_dicts_recursively
|
||||
from manimlib.utils.simple_functions import fdiv
|
||||
|
@ -106,7 +107,7 @@ class NumberLine(Line):
|
|||
|
||||
def number_to_point(self, number: float | np.ndarray) -> np.ndarray:
|
||||
alpha = (number - self.x_min) / (self.x_max - self.x_min)
|
||||
return interpolate(self.get_start(), self.get_end(), alpha)
|
||||
return outer_interpolate(self.get_start(), self.get_end(), alpha)
|
||||
|
||||
def point_to_number(self, point: np.ndarray) -> float:
|
||||
points = self.get_points()
|
||||
|
|
|
@ -80,15 +80,10 @@ def partial_quadratic_bezier_points(
|
|||
|
||||
# Linear interpolation variants
|
||||
|
||||
def interpolate(start: T, end: T, alpha: float) -> T:
|
||||
|
||||
def interpolate(start: T, end: T, alpha: np.ndarray | float) -> T:
|
||||
try:
|
||||
if isinstance(alpha, float):
|
||||
return (1 - alpha) * start + alpha * end
|
||||
# Otherwise, assume alpha is a list or array, and return
|
||||
# an appropriated shaped array of all corresponding
|
||||
# interpolations
|
||||
result = np.outer(1 - alpha, start) + np.outer(alpha, end)
|
||||
return result.reshape((*np.shape(alpha), *np.shape(start)))
|
||||
return (1 - alpha) * start + alpha * end
|
||||
except TypeError:
|
||||
log.debug(f"`start` parameter with type `{type(start)}` and dtype `{start.dtype}`")
|
||||
log.debug(f"`end` parameter with type `{type(end)}` and dtype `{end.dtype}`")
|
||||
|
@ -97,6 +92,15 @@ def interpolate(start: T, end: T, alpha: float) -> T:
|
|||
sys.exit(2)
|
||||
|
||||
|
||||
def outer_interpolate(
|
||||
start: np.ndarray | float,
|
||||
end: np.ndarray | float,
|
||||
alpha: np.ndarray | float,
|
||||
) -> T:
|
||||
result = np.outer(1 - alpha, start) + np.outer(alpha, end)
|
||||
return result.reshape((*np.shape(alpha), *np.shape(start)))
|
||||
|
||||
|
||||
def set_array_by_interpolation(
|
||||
arr: np.ndarray,
|
||||
arr1: np.ndarray,
|
||||
|
|
Loading…
Add table
Reference in a new issue