Have mobjects track family and parents more directly

This commit is contained in:
Grant Sanderson 2020-02-21 10:56:40 -08:00
parent b825b36b60
commit ea59950b2c
13 changed files with 106 additions and 79 deletions

View file

@ -137,7 +137,7 @@ class ShowIncreasingSubsets(Animation):
self.update_submobject_list(index) self.update_submobject_list(index)
def update_submobject_list(self, index): def update_submobject_list(self, index):
self.mobject.submobjects = self.all_submobs[:index] self.mobject.set_submobjects(self.all_submobs[:index])
class ShowSubmobjectsOneByOne(ShowIncreasingSubsets): class ShowSubmobjectsOneByOne(ShowIncreasingSubsets):
@ -152,9 +152,9 @@ class ShowSubmobjectsOneByOne(ShowIncreasingSubsets):
def update_submobject_list(self, index): def update_submobject_list(self, index):
# N = len(self.all_submobs) # N = len(self.all_submobs)
if index == 0: if index == 0:
self.mobject.submobjects = [] self.mobject.set_submobjects([])
else: else:
self.mobject.submobjects = self.all_submobs[index - 1] self.mobject.set_submobjects(self.all_submobs[index - 1])
# TODO, this is broken... # TODO, this is broken...

View file

@ -74,7 +74,6 @@ class PiCreature(SVGMobject):
self.flip() self.flip()
if self.start_corner is not None: if self.start_corner is not None:
self.to_corner(self.start_corner) self.to_corner(self.start_corner)
self.unlock_triangulation()
def align_data(self, mobject): def align_data(self, mobject):
# This ensures that after a transform into a different mode, # This ensures that after a transform into a different mode,
@ -260,10 +259,16 @@ class PiCreature(SVGMobject):
for alpha_range in (self.right_arm_range, self.left_arm_range) for alpha_range in (self.right_arm_range, self.left_arm_range)
]) ])
def prepare_for_animation(self):
self.unlock_triangulation()
def cleanup_from_animation(self):
self.lock_triangulation()
def get_all_pi_creature_modes(): def get_all_pi_creature_modes():
result = [] result = []
prefix = "%s_" % PiCreature.CONFIG["file_name_prefix"] prefix = PiCreature.CONFIG["file_name_prefix"] + "_"
suffix = ".svg" suffix = ".svg"
for file in os.listdir(PI_CREATURE_DIR): for file in os.listdir(PI_CREATURE_DIR):
if file.startswith(prefix) and file.endswith(suffix): if file.startswith(prefix) and file.endswith(suffix):

View file

@ -54,6 +54,8 @@ class Mobject(Container):
def __init__(self, **kwargs): def __init__(self, **kwargs):
Container.__init__(self, **kwargs) Container.__init__(self, **kwargs)
self.submobjects = [] self.submobjects = []
self.parents = []
self.family = [self]
self.color = Color(self.color) self.color = Color(self.color)
if self.name is None: if self.name is None:
self.name = self.__class__.__name__ self.name = self.__class__.__name__
@ -80,21 +82,64 @@ class Mobject(Container):
# Typically implemented in subclass, unless purposefully left blank # Typically implemented in subclass, unless purposefully left blank
pass pass
# Family matters
def __getitem__(self, value):
self_list = self.split()
if isinstance(value, slice):
GroupClass = self.get_group_class()
return GroupClass(*self_list.__getitem__(value))
return self_list.__getitem__(value)
def __iter__(self):
return iter(self.split())
def __len__(self):
return len(self.split())
def split(self):
result = [self] if len(self.points) > 0 else []
return result + self.submobjects
def assemble_family(self):
sub_families = [sm.get_family() for sm in self.submobjects]
self.family = [self, *it.chain(*sub_families)]
for parent in self.parents:
parent.assemble_family()
return self
def get_family(self):
return self.family
def family_members_with_points(self):
return [m for m in self.get_family() if m.get_num_points() > 0]
def add(self, *mobjects): def add(self, *mobjects):
if self in mobjects: if self in mobjects:
raise Exception("Mobject cannot contain self") raise Exception("Mobject cannot contain self")
self.submobjects = list_update(self.submobjects, mobjects) for mobject in mobjects:
return self if mobject not in self.submobjects:
self.submobjects.append(mobject)
def add_to_back(self, *mobjects): if self not in mobject.parents:
self.remove(*mobjects) mobject.parents.append(self)
self.submobjects = list(mobjects) + self.submobjects self.assemble_family()
return self return self
def remove(self, *mobjects): def remove(self, *mobjects):
for mobject in mobjects: for mobject in mobjects:
if mobject in self.submobjects: if mobject in self.submobjects:
self.submobjects.remove(mobject) self.submobjects.remove(mobject)
if self in mobject.parents:
mobject.parents.remove(self)
self.assemble_family()
return self
def add_to_back(self, *mobjects):
self.set_submobjects(list_update(mobjects, self.sub_mobjects))
return self
def set_submobjects(self, submobject_list):
self.remove(*self.submobjects)
self.add(*submobject_list)
return self return self
def get_array_attrs(self): def get_array_attrs(self):
@ -107,7 +152,7 @@ class Mobject(Container):
in the submobjects list. in the submobjects list.
""" """
mobject_attrs = [x for x in list(self.__dict__.values()) if isinstance(x, Mobject)] mobject_attrs = [x for x in list(self.__dict__.values()) if isinstance(x, Mobject)]
self.submobjects = list_update(self.submobjects, mobject_attrs) self.set_submobjects(list_update(self.submobjects, mobject_attrs))
return self return self
def apply_over_attr_arrays(self, func): def apply_over_attr_arrays(self, func):
@ -138,10 +183,12 @@ class Mobject(Container):
copy_mobject = copy.copy(self) copy_mobject = copy.copy(self)
copy_mobject.points = np.array(self.points) copy_mobject.points = np.array(self.points)
copy_mobject.submobjects = [ copy_mobject.parents = []
submob.copy() for submob in self.submobjects copy_mobject.submobjects = []
] copy_mobject.add(*[sm.copy() for sm in self.submobjects])
copy_mobject.updaters = list(self.updaters) copy_mobject.updaters = list(self.updaters)
# Make sure any mobject or numpy array attributes are copied
family = self.get_family() family = self.get_family()
for attr, value in list(self.__dict__.items()): for attr, value in list(self.__dict__.items()):
if isinstance(value, Mobject) and value in family and value is not self: if isinstance(value, Mobject) and value in family and value is not self:
@ -817,7 +864,7 @@ class Mobject(Container):
def get_pieces(self, n_pieces): def get_pieces(self, n_pieces):
template = self.copy() template = self.copy()
template.submobjects = [] template.set_submobjects([])
alphas = np.linspace(0, 1, n_pieces + 1) alphas = np.linspace(0, 1, n_pieces + 1)
return Group(*[ return Group(*[
template.copy().pointwise_become_partial( template.copy().pointwise_become_partial(
@ -893,34 +940,10 @@ class Mobject(Container):
self.set_coord(point[dim], dim, direction) self.set_coord(point[dim], dim, direction)
return self return self
# Family matters
def __getitem__(self, value):
self_list = self.split()
if isinstance(value, slice):
GroupClass = self.get_group_class()
return GroupClass(*self_list.__getitem__(value))
return self_list.__getitem__(value)
def __iter__(self):
return iter(self.split())
def __len__(self):
return len(self.split())
def get_group_class(self): def get_group_class(self):
return Group return Group
def split(self): # Submobject organization
result = [self] if len(self.points) > 0 else []
return result + self.submobjects
def get_family(self):
sub_families = [sm.get_family() for sm in self.submobjects]
return [self, *it.chain(*sub_families)]
def family_members_with_points(self):
return [m for m in self.get_family() if m.get_num_points() > 0]
def arrange(self, direction=RIGHT, center=True, **kwargs): def arrange(self, direction=RIGHT, center=True, **kwargs):
for m1, m2 in zip(self.submobjects, self.submobjects[1:]): for m1, m2 in zip(self.submobjects, self.submobjects[1:]):
m2.next_to(m1, direction, **kwargs) m2.next_to(m1, direction, **kwargs)
@ -971,12 +994,10 @@ class Mobject(Container):
# Alignment # Alignment
def align_data(self, mobject): def align_data(self, mobject):
self.null_point_align(mobject) self.null_point_align(mobject) # Needed?
self.align_submobjects(mobject) self.align_submobjects(mobject)
self.align_points(mobject) for mob1, mob2 in zip(self.get_family(), mobject.get_family()):
# Recurse mob1.align_points(mob2)
for m1, m2 in zip(self.submobjects, mobject.submobjects):
m1.align_data(m2)
def align_points(self, mobject): def align_points(self, mobject):
count1 = self.get_num_points() count1 = self.get_num_points()
@ -997,6 +1018,9 @@ class Mobject(Container):
n2 = len(mob2.submobjects) n2 = len(mob2.submobjects)
mob1.add_n_more_submobjects(max(0, n2 - n1)) mob1.add_n_more_submobjects(max(0, n2 - n1))
mob2.add_n_more_submobjects(max(0, n1 - n2)) mob2.add_n_more_submobjects(max(0, n1 - n2))
# Recurse
for sm1, sm2 in zip(mob1.submobjects, mob2.submobjects):
sm1.align_submobjects(sm2)
return self return self
def null_point_align(self, mobject): def null_point_align(self, mobject):
@ -1013,7 +1037,7 @@ class Mobject(Container):
def push_self_into_submobjects(self): def push_self_into_submobjects(self):
copy = self.deepcopy() copy = self.deepcopy()
copy.submobjects = [] copy.submobjects.set_submobjects([])
self.reset_points() self.reset_points()
self.add(copy) self.add(copy)
return self return self
@ -1025,15 +1049,12 @@ class Mobject(Container):
curr = len(self.submobjects) curr = len(self.submobjects)
if curr == 0: if curr == 0:
# If empty, simply add n point mobjects # If empty, simply add n point mobjects
self.submobjects = [ self.set_submobjects([
self.copy().scale(0) self.copy().scale(0)
for k in range(n) for k in range(n)
] ])
return return
target = curr + n target = curr + n
# TODO, factor this out to utils so as to reuse
# with VMobject.insert_n_curves
repeat_indices = (np.arange(target) * curr) // target repeat_indices = (np.arange(target) * curr) // target
split_factors = [ split_factors = [
(repeat_indices == i).sum() (repeat_indices == i).sum()
@ -1044,12 +1065,9 @@ class Mobject(Container):
new_submobs.append(submob) new_submobs.append(submob)
for k in range(1, sf): for k in range(1, sf):
new_submobs.append(submob.copy().fade(1)) new_submobs.append(submob.copy().fade(1))
self.submobjects = new_submobs self.set_submobjects(new_submobs)
return self return self
def repeat_submobject(self, submob):
return submob.copy()
def interpolate(self, mobject1, mobject2, def interpolate(self, mobject1, mobject2,
alpha, path_func=straight_path): alpha, path_func=straight_path):
""" """
@ -1082,7 +1100,7 @@ class Mobject(Container):
Edit points, colors and submobjects to be idential Edit points, colors and submobjects to be idential
to another mobject to another mobject
""" """
self.align_data(mobject) self.align_submobjects(mobject)
for sm1, sm2 in zip(self.get_family(), mobject.get_family()): for sm1, sm2 in zip(self.get_family(), mobject.get_family()):
sm1.set_points(sm2.points) sm1.set_points(sm2.points)
sm1.interpolate_color(sm1, sm2, 1) sm1.interpolate_color(sm1, sm2, 1)

View file

@ -122,7 +122,7 @@ class DecimalNumber(VMobject):
new_decimal.match_style(self) new_decimal.match_style(self)
old_family = self.get_family() old_family = self.get_family()
self.submobjects = new_decimal.submobjects self.set_submobjects(new_decimal.submobjects)
for mob in old_family: for mob in old_family:
# Dumb hack...due to how scene handles families # Dumb hack...due to how scene handles families
# of animated mobjects # of animated mobjects

View file

@ -97,7 +97,7 @@ class BraceLabel(VMobject):
self.label.scale(self.label_scale) self.label.scale(self.label_scale)
self.brace.put_at_tip(self.label) self.brace.put_at_tip(self.label)
self.submobjects = [self.brace, self.label] self.set_submobjects([self.brace, self.label])
def creation_anim(self, label_anim=FadeIn, brace_anim=GrowFromCenter): def creation_anim(self, label_anim=FadeIn, brace_anim=GrowFromCenter):
return AnimationGroup(brace_anim(self.brace), label_anim(self.label)) return AnimationGroup(brace_anim(self.brace), label_anim(self.label))
@ -128,7 +128,7 @@ class BraceLabel(VMobject):
copy_mobject = copy.copy(self) copy_mobject = copy.copy(self)
copy_mobject.brace = self.brace.copy() copy_mobject.brace = self.brace.copy()
copy_mobject.label = self.label.copy() copy_mobject.label = self.label.copy()
copy_mobject.submobjects = [copy_mobject.brace, copy_mobject.label] copy_mobject.set_submobjects([copy_mobject.brace, copy_mobject.label])
return copy_mobject return copy_mobject

View file

@ -184,14 +184,14 @@ class TexMobject(SingleStringTexMobject):
# For cases like empty tex_strings, we want the corresponing # For cases like empty tex_strings, we want the corresponing
# part of the whole TexMobject to be a VectorizedPoint # part of the whole TexMobject to be a VectorizedPoint
# positioned in the right part of the TexMobject # positioned in the right part of the TexMobject
sub_tex_mob.submobjects = [VectorizedPoint()] sub_tex_mob.set_submobjects([VectorizedPoint()])
last_submob_index = min(curr_index, len(self.submobjects) - 1) last_submob_index = min(curr_index, len(self.submobjects) - 1)
sub_tex_mob.move_to(self.submobjects[last_submob_index], RIGHT) sub_tex_mob.move_to(self.submobjects[last_submob_index], RIGHT)
else: else:
sub_tex_mob.submobjects = self.submobjects[curr_index:new_index] sub_tex_mob.set_submobjects(self.submobjects[curr_index:new_index])
new_submobjects.append(sub_tex_mob) new_submobjects.append(sub_tex_mob)
curr_index = new_index curr_index = new_index
self.submobjects = new_submobjects self.set_submobjects(new_submobjects)
return self return self
def get_parts_by_tex(self, tex, substring=True, case_sensitive=True): def get_parts_by_tex(self, tex, substring=True, case_sensitive=True):

View file

@ -137,7 +137,7 @@ class PMobject(Mobject):
arrays = list(map(self.get_merged_array, attrs)) arrays = list(map(self.get_merged_array, attrs))
for attr, array in zip(attrs, arrays): for attr, array in zip(attrs, arrays):
setattr(self, attr, array) setattr(self, attr, array)
self.submobjects = [] self.set_submobjects([])
return self return self
def get_color(self): def get_color(self):

View file

@ -568,6 +568,10 @@ class VMobject(Mobject):
self.make_smooth() self.make_smooth()
return self return self
def flip(self):
super().flip()
self.refresh_triangulation()
# #
def consider_points_equals(self, p0, p1): def consider_points_equals(self, p0, p1):
return np.allclose( return np.allclose(

View file

@ -153,7 +153,7 @@ class GeneralizedPascalsTriangle(VMobject):
def fill_with_n_choose_k(self): def fill_with_n_choose_k(self):
if not hasattr(self, "coords_to_n_choose_k"): if not hasattr(self, "coords_to_n_choose_k"):
self.generate_n_choose_k_mobs() self.generate_n_choose_k_mobs()
self.submobjects = [] self.set_submobjects([])
self.add(*[ self.add(*[
self.coords_to_n_choose_k[n][k] self.coords_to_n_choose_k[n][k]
for n, k in self.coords for n, k in self.coords

View file

@ -62,11 +62,11 @@ def fractalification_iteration(vmobject, dimension=1.05, num_inserted_anchors_ra
new_anchors += [p1] + inserted_points new_anchors += [p1] + inserted_points
new_anchors.append(original_anchors[-1]) new_anchors.append(original_anchors[-1])
vmobject.set_points_as_corners(new_anchors) vmobject.set_points_as_corners(new_anchors)
vmobject.submobjects = [ vmobject.set_submobjects([
fractalification_iteration( fractalification_iteration(
submob, dimension, num_inserted_anchors_range) submob, dimension, num_inserted_anchors_range)
for submob in vmobject.submobjects for submob in vmobject.submobjects
] ])
return vmobject return vmobject
@ -87,9 +87,9 @@ class SelfSimilarFractal(VMobject):
def init_points(self): def init_points(self):
order_n_self = self.get_order_n_self(self.order) order_n_self = self.get_order_n_self(self.order)
if self.order == 0: if self.order == 0:
self.submobjects = [order_n_self] self.set_submobjects([order_n_self])
else: else:
self.submobjects = order_n_self.submobjects self.set_submobjects(order_n_self.submobjects)
return self return self
def get_order_n_self(self, order): def get_order_n_self(self, order):

View file

@ -67,9 +67,9 @@ class SwitchOff(LaggedStartMap):
if (not isinstance(light, AmbientLight) and not isinstance(light, Spotlight)): if (not isinstance(light, AmbientLight) and not isinstance(light, Spotlight)):
raise Exception( raise Exception(
"Only AmbientLights and Spotlights can be switched off") "Only AmbientLights and Spotlights can be switched off")
light.submobjects = light.submobjects[::-1] light.set_submobjects(light.submobjects[::-1])
LaggedStartMap.__init__(self, FadeOut, light, **kwargs) LaggedStartMap.__init__(self, FadeOut, light, **kwargs)
light.submobjects = light.submobjects[::-1] light.set_submobjects(light.submobjects[::-1])
class Lighthouse(SVGMobject): class Lighthouse(SVGMobject):
@ -182,7 +182,7 @@ class Spotlight(VMobject):
return self.source_point.get_location() return self.source_point.get_location()
def init_points(self): def init_points(self):
self.submobjects = [] self.set_submobjects([])
self.add(self.source_point) self.add(self.source_point)
@ -493,7 +493,7 @@ class LightSource(VMobject):
) )
new_ambient_light.apply_matrix(self.rotation_matrix()) new_ambient_light.apply_matrix(self.rotation_matrix())
new_ambient_light.move_source_to(self.get_source_point()) new_ambient_light.move_source_to(self.get_source_point())
self.ambient_light.submobjects = new_ambient_light.submobjects self.ambient_light.set_submobjects(new_ambient_light.submobjects)
def get_source_point(self): def get_source_point(self):
return self.source_point.get_location() return self.source_point.get_location()

View file

@ -393,16 +393,16 @@ class Scene(Container):
mobject.unlock_shader_data() mobject.unlock_shader_data()
def begin_animations(self, animations): def begin_animations(self, animations):
curr_mobjects = self.get_mobject_family_members()
for animation in animations: for animation in animations:
# Begin animation
animation.begin() animation.begin()
# Anything animated that's not already in the # Anything animated that's not already in the
# scene gets added to the scene # scene gets added to the scene. Note, for
# animated mobjects that are in the family of
# those on screen, this can result in a restructuring
# of the scene.mobjects list, which is usually desired.
mob = animation.mobject mob = animation.mobject
if mob not in curr_mobjects: if mob not in self.mobjects:
self.add(mob) self.add(mob)
curr_mobjects += mob.get_family()
def progress_through_animations(self, animations): def progress_through_animations(self, animations):
last_t = 0 last_t = 0

View file

@ -162,7 +162,7 @@ class SpecialThreeDScene(ThreeDScene):
for piece in new_pieces: for piece in new_pieces:
piece.shade_in_3d = True piece.shade_in_3d = True
new_pieces.match_style(axis.pieces) new_pieces.match_style(axis.pieces)
axis.pieces.submobjects = new_pieces.submobjects axis.pieces.set_submobjects(new_pieces.submobjects)
for tick in axis.tick_marks: for tick in axis.tick_marks:
tick.add(VectorizedPoint( tick.add(VectorizedPoint(
1.5 * tick.get_center(), 1.5 * tick.get_center(),