Reorganize Mobject methods and remove ones that are not longer needed

This commit is contained in:
Grant Sanderson 2021-01-14 14:15:58 -10:00
parent 8f6b006cc8
commit 7b67f4556b
5 changed files with 119 additions and 136 deletions

View file

@ -57,7 +57,6 @@ class Animation(object):
def finish(self):
self.interpolate(self.final_alpha_value)
self.mobject.cleanup_from_animation()
if self.suspend_mobject_updating:
self.mobject.resume_updating()

View file

@ -84,7 +84,7 @@ class TracedPath(VMobject):
def update_path(self):
new_point = self.traced_point_func()
if self.has_no_points():
if not self.has_points():
self.start_new_path(new_point)
self.add_line_to(new_point)
else:

View file

@ -450,7 +450,7 @@ class Line(TipableVMobject):
if direction is None:
return mob.get_center()
else:
return mob.get_bounding_box_point_by_direction(direction)
return mob.get_continuous_bounding_box_point(direction)
else:
point = mob_or_point
result = np.zeros(self.dim)

View file

@ -102,8 +102,6 @@ class Mobject(object):
# Typically implemented in subclass, unlpess purposefully left blank
pass
# Related to data dict
def set_data(self, data):
for key in data:
self.data[key] = data[key].copy()
@ -114,6 +112,8 @@ class Mobject(object):
self.uniforms[key] = uniforms[key] # Copy?
return self
# Only these methods should directly affect points
def resize_points(self, new_length, resize_func=resize_array):
if new_length != len(self.data["points"]):
self.data["points"] = resize_func(self.data["points"], new_length)
@ -125,10 +125,8 @@ class Mobject(object):
self.data["points"][:] = points
elif isinstance(points, np.ndarray):
self.data["points"] = points.copy()
# Note that points have been resized?
else:
self.data["points"] = np.array(points)
# Note that points have been resized?
self.refresh_bounding_box()
return self
@ -137,6 +135,35 @@ class Mobject(object):
self.refresh_bounding_box()
return self
def reverse_points(self):
for mob in self.get_family():
for key in mob.data:
mob.data[key] = mob.data[key][::-1]
return self
def apply_points_function(self, func, about_point=None, about_edge=ORIGIN, works_on_bounding_box=False):
if about_point is None and about_edge is not None:
about_point = self.get_bounding_box_point(about_edge)
for mob in self.get_family():
arrs = [mob.get_points()]
if works_on_bounding_box:
arrs.append(mob.get_bounding_box())
for arr in arrs:
if about_point is None:
arr[:] = func(arr)
else:
arr[:] = func(arr - about_point) + about_point
if not works_on_bounding_box:
self.refresh_bounding_box(recurse_down=True)
else:
for parent in self.parents:
parent.refresh_bounding_box()
return self
# Others related to points
def match_points(self, mobject):
self.set_points(mobject.get_points())
@ -149,6 +176,47 @@ class Mobject(object):
def get_num_points(self):
return len(self.data["points"])
def get_all_points(self):
if self.submobjects:
return np.vstack([sm.get_points() for sm in self.get_family()])
else:
return self.get_points()
def has_points(self):
return self.get_num_points() > 0
def get_bounding_box(self):
if not self.needs_new_bounding_box:
return self.data["bounding_box"]
# all_points = self.get_all_points()
all_points = np.vstack([
self.get_points(),
*(
mob.get_bounding_box()
for mob in self.get_family()[1:]
if mob.has_points()
)
])
if len(all_points) == 0:
self.data["bounding_box"] = np.zeros((3, self.dim))
else:
# Lower left and upper right corners
mins = all_points.min(0)
maxs = all_points.max(0)
mids = (mins + maxs) / 2
self.data["bounding_box"] = np.array([mins, mids, maxs])
self.needs_new_bounding_box = False
return self.data["bounding_box"]
def refresh_bounding_box(self, recurse_down=False, recurse_up=True):
for mob in self.get_family(recurse_down):
mob.needs_new_bounding_box = True
if recurse_up:
for parent in self.parents:
parent.refresh_bounding_box()
return self
# Family matters
def __getitem__(self, value):
@ -167,7 +235,7 @@ class Mobject(object):
return self.submobjects
def assemble_family(self):
sub_families = [sm.get_family() for sm in self.submobjects]
sub_families = (sm.get_family() for sm in self.submobjects)
self.family = [self, *it.chain(*sub_families)]
self.refresh_has_updater_status()
self.refresh_bounding_box()
@ -176,7 +244,10 @@ class Mobject(object):
return self
def get_family(self, recurse=True):
return self.family if recurse else [self]
if recurse:
return self.family
else:
return [self]
def family_members_with_points(self):
return [m for m in self.get_family() if m.has_points()]
@ -227,6 +298,8 @@ class Mobject(object):
self.set_submobjects(list_update(self.submobjects, mobject_attrs))
return self
# Copying
def copy(self):
# TODO, either justify reason for shallow copy, or
# remove this redundancy everywhere
@ -276,7 +349,24 @@ class Mobject(object):
self.target = self.copy()
return self.target
def save_state(self, use_deepcopy=False):
if hasattr(self, "saved_state"):
# Prevent exponential growth of data
self.saved_state = None
if use_deepcopy:
self.saved_state = self.deepcopy()
else:
self.saved_state = self.copy()
return self
def restore(self):
if not hasattr(self, "saved_state") or self.save_state is None:
raise Exception("Trying to restore without having saved")
self.become(self.saved_state)
return self
# Updating
def init_updaters(self):
self.time_based_updaters = []
self.non_time_updaters = []
@ -367,6 +457,7 @@ class Mobject(object):
return self
# Transforming operations
def shift(self, vector):
self.apply_points_function(
lambda points: points + vector,
@ -468,34 +559,8 @@ class Mobject(object):
))
return self
def reverse_points(self):
for mob in self.family_members_with_points():
for key in mob.data:
mob.data[key] = mob.data[key][::-1]
return self
def apply_points_function(self, func, about_point=None, about_edge=ORIGIN, works_on_bounding_box=False):
if about_point is None and about_edge is not None:
about_point = self.get_bounding_box_point(about_edge)
for mob in self.get_family():
arrs = [mob.get_points()]
if works_on_bounding_box:
arrs.append(mob.get_bounding_box())
for arr in arrs:
if about_point is None:
arr[:] = func(arr)
else:
arr[:] = func(arr - about_point) + about_point
if not works_on_bounding_box:
self.refresh_bounding_box(recurse_down=True)
else:
for parent in self.parents:
parent.refresh_bounding_box()
return self
# Positioning methods
def center(self):
self.shift(-self.get_center())
return self
@ -680,6 +745,7 @@ class Mobject(object):
return self
# Background rectangle
def add_background_rectangle(self, color=BLACK, opacity=0.75, **kwargs):
# TODO, this does not behave well when the mobject has points,
# since it gets displayed on top
@ -703,6 +769,7 @@ class Mobject(object):
return self
# Color functions
def set_rgba_array(self, color=None, opacity=None, name="rgbas", recurse=True):
if color is not None:
rgbs = np.array([color_to_rgb(c) for c in listify(color)])
@ -790,32 +857,8 @@ class Mobject(object):
mob.uniforms["shadow"] = shadow
return self
##
def save_state(self, use_deepcopy=False):
if hasattr(self, "saved_state"):
# Prevent exponential growth of data
self.saved_state = None
if use_deepcopy:
self.saved_state = self.deepcopy()
else:
self.saved_state = self.copy()
return self
def restore(self):
if not hasattr(self, "saved_state") or self.save_state is None:
raise Exception("Trying to restore without having saved")
self.become(self.saved_state)
return self
# Getters
def get_all_points(self):
if self.submobjects:
return np.vstack([sm.get_points() for sm in self.get_family()])
else:
return self.get_points()
def get_bounding_box_point(self, direction):
bb = self.get_bounding_box()
indices = (np.sign(direction) + 1).astype(int)
@ -824,40 +867,6 @@ class Mobject(object):
for i in range(3)
])
def get_bounding_box(self):
if not self.needs_new_bounding_box:
return self.data["bounding_box"]
# all_points = self.get_all_points()
all_points = np.vstack([
self.get_points(),
*(
mob.get_bounding_box()
for mob in self.get_family()[1:]
if mob.has_points()
)
])
if len(all_points) == 0:
self.data["bounding_box"] = np.zeros((3, self.dim))
else:
# Lower left and upper right corners
mins = all_points.min(0)
maxs = all_points.max(0)
mids = (mins + maxs) / 2
self.data["bounding_box"] = np.array([mins, mids, maxs])
self.needs_new_bounding_box = False
return self.data["bounding_box"]
def refresh_bounding_box(self, recurse_down=False, recurse_up=True):
for mob in self.get_family(recurse_down):
mob.needs_new_bounding_box = True
if recurse_up:
for parent in self.parents:
parent.refresh_bounding_box()
return self
# Pseudonyms for more general get_bounding_box_point method
def get_edge_center(self, direction):
return self.get_bounding_box_point(direction)
@ -878,7 +887,7 @@ class Mobject(object):
index = np.argmax(np.dot(boundary_directions, np.array(direction).T))
return all_points[index]
def get_bounding_box_point_by_direction(self, direction):
def get_continuous_bounding_box_point(self, direction):
dl, center, ur = self.get_bounding_box()
corner_vect = (ur - center)
return center + direction / np.max(np.abs(np.true_divide(
@ -969,12 +978,6 @@ class Mobject(object):
z_index_group = getattr(self, "z_index_group", self)
return z_index_group.get_center()
def has_points(self):
return self.get_num_points() > 0
def has_no_points(self):
return not self.has_points()
# Match other mobject properties
def match_color(self, mobject):
@ -1112,17 +1115,6 @@ class Mobject(object):
random.shuffle(self.submobjects)
return self
# Just here to keep from breaking old scenes.
def arrange_submobjects(self, *args, **kwargs):
return self.arrange(*args, **kwargs)
def sort_submobjects(self, *args, **kwargs):
return self.sort(*args, **kwargs)
def shuffle_submobjects(self, *args, **kwargs):
return self.shuffle(*args, **kwargs)
# Alignment
def align_data_and_family(self, mobject):
@ -1204,17 +1196,20 @@ class Mobject(object):
self.set_submobjects(new_submobs)
return self
# Interpolate
def interpolate(self, mobject1, mobject2, alpha, path_func=straight_path):
"""
Turns self into an interpolation between mobject1
and mobject2.
"""
for key in self.data:
if key in self.locked_data_keys:
continue
if len(self.data[key]) == 0:
continue
func = path_func if key == "points" else interpolate
if key in ("points", "bounding_box"):
func = path_func
else:
func = interpolate
self.data[key][:] = func(
mobject1.data[key],
mobject2.data[key],
@ -1228,23 +1223,18 @@ class Mobject(object):
)
return self
def become_partial(self, mobject, a, b):
def pointwise_become_partial(self, mobject, a, b):
"""
Set points in such a way as to become only
part of mobject.
Inputs 0 <= a < b <= 1 determine what portion
of mobject to become.
"""
pass # To implement in subclasses
# TODO, color?
def pointwise_become_partial(self, mobject, a, b):
pass # To implement in subclass
def become(self, mobject):
"""
Edit points, colors and submobjects to be idential
Edit all data and submobjects to be idential
to another mobject
"""
self.align_family(mobject)
@ -1253,9 +1243,6 @@ class Mobject(object):
sm1.set_uniforms(sm2.uniforms)
return self
def cleanup_from_animation(self):
pass
# Locking data
def lock_data(self, keys):
@ -1316,16 +1303,13 @@ class Mobject(object):
# Shader code manipulation
def replace_shader_code(self, old, new):
# TODO, will this work with VMobject structure, given
# that it does not simpler return shader_wrappers of
# family?
for wrapper in self.get_shader_wrapper_list():
wrapper.replace_code(old, new)
return self
def refresh_shader_code(self):
for wrapper in self.get_shader_wrapper_list():
wrapper.init_program_code()
wrapper.refresh_id()
return self
def set_color_by_code(self, glsl_code):
"""
Takes a snippet of code and inserts it into a
@ -1444,7 +1428,7 @@ class Mobject(object):
# Errors
def throw_error_if_no_points(self):
if self.has_no_points():
if not self.has_points():
message = "Cannot call Mobject.{} " +\
"for a Mobject with no points"
caller_name = sys._getframe(1).f_code.co_name

View file

@ -580,7 +580,7 @@ class VMobject(Mobject):
# the polygon formed by the anchor points, pointing
# in a direction perpendicular to the polygon according
# to the right hand rule.
if self.has_no_points():
if not self.has_points():
return np.zeros(3)
nppc = self.n_points_per_curve
@ -626,7 +626,7 @@ class VMobject(Mobject):
for mob in self, vmobject:
# If there are no points, add one to
# where the "center" is
if mob.has_no_points():
if not mob.has_points():
mob.start_new_path(mob.get_center())
# If there's only one point, turn it into
# a null curve