Merge pull request #1788 from 3b1b/interpolate-fix

Interpolate fix
This commit is contained in:
Grant Sanderson 2022-04-11 10:51:05 -07:00 committed by GitHub
commit 859680d5ab
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 9 deletions

View file

@ -7,6 +7,7 @@ from manimlib.mobject.geometry import Line
from manimlib.mobject.numbers import DecimalNumber
from manimlib.mobject.types.vectorized_mobject import VGroup
from manimlib.utils.bezier import interpolate
from manimlib.utils.bezier import outer_interpolate
from manimlib.utils.config_ops import digest_config
from manimlib.utils.config_ops import merge_dicts_recursively
from manimlib.utils.simple_functions import fdiv
@ -106,7 +107,7 @@ class NumberLine(Line):
def number_to_point(self, number: float | np.ndarray) -> np.ndarray:
alpha = (number - self.x_min) / (self.x_max - self.x_min)
return interpolate(self.get_start(), self.get_end(), alpha)
return outer_interpolate(self.get_start(), self.get_end(), alpha)
def point_to_number(self, point: np.ndarray) -> float:
points = self.get_points()

View file

@ -80,15 +80,10 @@ def partial_quadratic_bezier_points(
# Linear interpolation variants
def interpolate(start: T, end: T, alpha: float) -> T:
def interpolate(start: T, end: T, alpha: np.ndarray | float) -> T:
try:
if isinstance(alpha, float):
return (1 - alpha) * start + alpha * end
# Otherwise, assume alpha is a list or array, and return
# an appropriated shaped array of all corresponding
# interpolations
result = np.outer(1 - alpha, start) + np.outer(alpha, end)
return result.reshape((*np.shape(alpha), *np.shape(start)))
return (1 - alpha) * start + alpha * end
except TypeError:
log.debug(f"`start` parameter with type `{type(start)}` and dtype `{start.dtype}`")
log.debug(f"`end` parameter with type `{type(end)}` and dtype `{end.dtype}`")
@ -97,6 +92,15 @@ def interpolate(start: T, end: T, alpha: float) -> T:
sys.exit(2)
def outer_interpolate(
start: np.ndarray | float,
end: np.ndarray | float,
alpha: np.ndarray | float,
) -> T:
result = np.outer(1 - alpha, start) + np.outer(alpha, end)
return result.reshape((*np.shape(alpha), *np.shape(start)))
def set_array_by_interpolation(
arr: np.ndarray,
arr1: np.ndarray,