Save bounding box information, which speeds up most mobject manipulations

This commit is contained in:
Grant Sanderson 2021-01-14 01:01:43 -10:00
parent d53dbba346
commit f90e4147b6

View file

@ -66,6 +66,7 @@ class Mobject(object):
self.parents = []
self.family = [self]
self.locked_data_keys = set()
self.needs_new_bounding_box = True
self.init_data()
self.init_uniforms()
@ -83,6 +84,7 @@ class Mobject(object):
def init_data(self):
self.data = {
"points": np.zeros((0, 3)),
"bounding_box": np.zeros((3, 3)),
"rgbas": np.zeros((1, 4)),
}
@ -115,16 +117,24 @@ class Mobject(object):
def resize_points(self, new_length, resize_func=resize_array):
if new_length != len(self.data["points"]):
self.data["points"] = resize_func(self.data["points"], new_length)
self.refresh_bounding_box()
return self
def set_points(self, points):
self.resize_points(len(points))
self.data["points"][:] = points
if len(points) == len(self.data["points"]):
self.data["points"][:] = points
elif isinstance(points, np.ndarray):
self.data["points"] = points.copy()
# Note that points have been resized?
else:
self.data["points"] = np.array(points)
# Note that points have been resized?
self.refresh_bounding_box()
return self
def append_points(self, new_points):
self.resize_points(self.get_num_points() + len(new_points))
self.data["points"][-len(new_points):] = new_points
self.data["points"] = np.vstack([self.data["points"], new_points])
self.refresh_bounding_box()
return self
def match_points(self, mobject):
@ -138,9 +148,9 @@ class Mobject(object):
def get_num_points(self):
return len(self.data["points"])
#
# Family matters
def __getitem__(self, value):
if isinstance(value, slice):
GroupClass = self.get_group_class()
@ -205,12 +215,9 @@ class Mobject(object):
def set_submobjects(self, submobject_list):
self.remove(*self.submobjects)
self.add(*submobject_list)
self.refresh_bounding_box()
return self
def get_array_attrs(self):
# May be more for other Mobject types
return ["points"]
def digest_mobject_attrs(self):
"""
Ensures all attributes which are mobjects are included
@ -220,11 +227,6 @@ class Mobject(object):
self.set_submobjects(list_update(self.submobjects, mobject_attrs))
return self
def apply_over_attr_arrays(self, func):
for attr in self.get_array_attrs():
setattr(self, attr, func(getattr(self, attr)))
return self
def copy(self):
# TODO, either justify reason for shallow copy, or
# remove this redundancy everywhere
@ -362,14 +364,12 @@ class Mobject(object):
return self
# Transforming operations
def apply_to_family(self, func):
for mob in self.family_members_with_points():
func(mob)
def shift(self, *vectors):
total_vector = reduce(op.add, vectors)
for mob in self.get_family():
mob.set_points(mob.get_points() + total_vector)
def shift(self, vector):
self.apply_points_function(
lambda points: points + vector,
about_edge=None,
works_on_bounding_box=True,
)
return self
def scale(self, scale_factor, **kwargs):
@ -384,11 +384,19 @@ class Mobject(object):
"""
self.apply_points_function(
lambda points: scale_factor * points,
works_on_bounding_box=True,
**kwargs
)
return self
def rotate_about_origin(self, angle, axis=OUT, axes=[]):
def stretch(self, factor, dim, **kwargs):
def func(points):
points[:, dim] *= factor
return points
self.apply_points_function(func, works_on_bounding_box=True, **kwargs)
return self
def rotate_about_origin(self, angle, axis=OUT):
return self.rotate(angle, axis, about_point=ORIGIN)
def rotate(self, angle, axis=OUT, **kwargs):
@ -402,13 +410,6 @@ class Mobject(object):
def flip(self, axis=UP, **kwargs):
return self.rotate(TAU / 2, axis, **kwargs)
def stretch(self, factor, dim, **kwargs):
def func(points):
points[:, dim] *= factor
return points
self.apply_points_function(func, **kwargs)
return self
def apply_function(self, function, **kwargs):
# Default to applying matrix about the origin, not mobjects center
if len(kwargs) == 0:
@ -466,25 +467,29 @@ class Mobject(object):
def reverse_points(self):
for mob in self.family_members_with_points():
mob.apply_over_attr_arrays(lambda arr: arr[::-1])
for key in mob.data:
mob.data[key] = mob.data[key][::-1]
return self
def repeat(self, count):
"""
This can make transition animations nicer
"""
for mob in self.family_members_with_points():
mob.apply_over_attr_arrays(lambda arr: np.vstack([arr] * count))
return self
def apply_points_function(self, func, about_point=None, about_edge=None):
if about_point is None:
if about_edge is None:
about_edge = ORIGIN
def apply_points_function(self, func, about_point=None, about_edge=ORIGIN, works_on_bounding_box=False):
if about_point is None and about_edge is not None:
about_point = self.get_bounding_box_point(about_edge)
for mob in self.family_members_with_points():
points = mob.get_points()
points[:] = func(points - about_point) + about_point
for mob in self.get_family():
arrs = [mob.get_points()]
if works_on_bounding_box:
arrs.append(mob.get_bounding_box())
for arr in arrs:
if about_point is None:
arr[:] = func(arr)
else:
arr[:] = func(arr - about_point) + about_point
if not works_on_bounding_box:
self.refresh_bounding_box(recurse_down=True)
else:
for parent in self.parents:
parent.refresh_bounding_box()
return self
# Positioning methods
@ -800,16 +805,7 @@ class Mobject(object):
self.become(self.saved_state)
return self
##
def get_merged_array(self, array_attr):
if self.submobjects:
return np.vstack([
getattr(sm, array_attr)
for sm in self.get_family()
])
else:
return getattr(self, array_attr)
# Getters
def get_all_points(self):
if self.submobjects:
@ -817,11 +813,6 @@ class Mobject(object):
else:
return self.get_points()
# Getters
def get_points_defining_boundary(self):
return self.get_all_points()
def get_bounding_box_point(self, direction):
result = np.zeros(self.dim)
bb = self.get_bounding_box()
@ -831,15 +822,28 @@ class Mobject(object):
return result
def get_bounding_box(self):
all_points = self.get_points_defining_boundary()
if not self.needs_new_bounding_box:
return self.data["bounding_box"]
all_points = self.get_all_points()
if len(all_points) == 0:
return np.zeros((3, self.dim))
self.data["bounding_box"] = np.zeros((3, self.dim))
else:
# Lower left and upper right corners
mins = all_points.min(0)
maxs = all_points.max(0)
mids = (mins + maxs) / 2
return np.array([mins, mids, maxs])
self.data["bounding_box"] = np.array([mins, mids, maxs])
self.needs_new_bounding_box = False
return self.data["bounding_box"]
def refresh_bounding_box(self, recurse_down=False, recurse_up=True):
for mob in self.get_family(recurse_down):
mob.needs_new_bounding_box = True
if recurse_up:
for parent in self.parents:
parent.refresh_bounding_box()
return self
# Pseudonyms for more general get_bounding_box_point method
@ -850,13 +854,13 @@ class Mobject(object):
return self.get_bounding_box_point(direction)
def get_center(self):
return self.get_bounding_box_point(np.zeros(self.dim))
return self.get_bounding_box_point(ORIGIN)
def get_center_of_mass(self):
return self.get_all_points().mean(0)
def get_boundary_point(self, direction):
all_points = self.get_points_defining_boundary()
all_points = self.get_all_points()
boundary_directions = all_points - self.get_center()
norms = np.linalg.norm(boundary_directions, axis=1)
boundary_directions /= np.repeat(norms, 3).reshape((len(norms), 3))
@ -930,7 +934,9 @@ class Mobject(object):
return self.get_start(), self.get_end()
def point_from_proportion(self, alpha):
raise Exception("Not implemented")
points = self.get_points()
i, subalpha = integer_interpolate(0, len(points) - 1, alpha)
return interpolate(points[i], points[i + 1], subalpha)
def pfp(self, alpha):
"""Abbreviation fo point_from_proportion"""
@ -1167,7 +1173,7 @@ class Mobject(object):
self.copy().set_points([center])
for k in range(n)
])
return
return self
target = curr + n
repeat_indices = (np.arange(target) * curr) // target
split_factors = [